diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py index 5a3228658e..d6d61ee71d 100644 --- a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py +++ b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py @@ -422,8 +422,11 @@ def convert_ldm_unet_checkpoint( ) for key in keys: if key.startswith("model.diffusion_model"): - flat_ema_key = "model_ema." + ".".join(key.split(".")[1:]) - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + for delimiter in ['','.']: + flat_ema_key = "model_ema." + delimiter.join(key.split(".")[1:]) + if checkpoint.get(flat_ema_key) is not None: + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + break else: if sum(k.startswith("model_ema") for k in keys) > 100: logger.warning( @@ -1070,7 +1073,7 @@ def convert_controlnet_checkpoint( extract_ema, use_linear_projection=None, cross_attention_dim=None, - precision: torch.dtype = torch.float32, + precision: torch.dtype = None, ): ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) ctrlnet_config["upcast_attention"] = upcast_attention @@ -1121,7 +1124,7 @@ def download_from_original_stable_diffusion_ckpt( prediction_type: str = None, model_type: str = None, extract_ema: bool = False, - precision: torch.dtype = torch.float32, + precision: torch.dtype = None, scheduler_type: str = "pndm", num_in_channels: Optional[int] = None, upcast_attention: Optional[bool] = None, @@ -1194,6 +1197,8 @@ def download_from_original_stable_diffusion_ckpt( [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if needed. + precision (`torch.dtype`, *optional*, defauts to `None`): + If not provided the precision will be set to the precision of the original file. return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file. """ @@ -1251,6 +1256,10 @@ def download_from_original_stable_diffusion_ckpt( checkpoint = checkpoint["state_dict"] logger.debug(f"model_type = {model_type}; original_config_file = {original_config_file}") + + precision_probing_key = "model.diffusion_model.input_blocks.0.0.bias" + logger.debug(f"original checkpoint precision == {checkpoint[precision_probing_key].dtype}") + precision = precision or checkpoint[precision_probing_key].dtype if original_config_file is None: key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" @@ -1281,7 +1290,7 @@ def download_from_original_stable_diffusion_ckpt( original_config = OmegaConf.load(original_config_file) if ( model_version == BaseModelType.StableDiffusion2 - and original_config["model"]["params"]["parameterization"] == "v" + and original_config["model"]["params"].get("parameterization") == "v" ): prediction_type = "v_prediction" upcast_attention = True @@ -1447,7 +1456,7 @@ def download_from_original_stable_diffusion_ckpt( if controlnet: pipe = pipeline_class( vae=vae.to(precision), - text_encoder=text_model, + text_encoder=text_model.to(precision), tokenizer=tokenizer, unet=unet.to(precision), scheduler=scheduler, @@ -1459,7 +1468,7 @@ def download_from_original_stable_diffusion_ckpt( else: pipe = pipeline_class( vae=vae.to(precision), - text_encoder=text_model, + text_encoder=text_model.to(precision), tokenizer=tokenizer, unet=unet.to(precision), scheduler=scheduler, @@ -1484,8 +1493,8 @@ def download_from_original_stable_diffusion_ckpt( image_noising_scheduler=image_noising_scheduler, # regular denoising components tokenizer=tokenizer, - text_encoder=text_model, - unet=unet, + text_encoder=text_model.to(precision), + unet=unet.to(precision), scheduler=scheduler, # vae vae=vae, @@ -1560,7 +1569,7 @@ def download_from_original_stable_diffusion_ckpt( if controlnet: pipe = pipeline_class( vae=vae.to(precision), - text_encoder=text_model, + text_encoder=text_model.to(precision), tokenizer=tokenizer, unet=unet.to(precision), controlnet=controlnet, @@ -1571,7 +1580,7 @@ def download_from_original_stable_diffusion_ckpt( else: pipe = pipeline_class( vae=vae.to(precision), - text_encoder=text_model, + text_encoder=text_model.to(precision), tokenizer=tokenizer, unet=unet.to(precision), scheduler=scheduler, @@ -1594,9 +1603,9 @@ def download_from_original_stable_diffusion_ckpt( pipe = StableDiffusionXLPipeline( vae=vae.to(precision), - text_encoder=text_encoder, + text_encoder=text_encoder.to(precision), tokenizer=tokenizer, - text_encoder_2=text_encoder_2, + text_encoder_2=text_encoder_2.to(precision), tokenizer_2=tokenizer_2, unet=unet.to(precision), scheduler=scheduler, @@ -1639,7 +1648,7 @@ def download_controlnet_from_original_ckpt( original_config_file: str, image_size: int = 512, extract_ema: bool = False, - precision: torch.dtype = torch.float32, + precision: torch.dtype = None, num_in_channels: Optional[int] = None, upcast_attention: Optional[bool] = None, device: str = None, @@ -1680,6 +1689,12 @@ def download_controlnet_from_original_ckpt( while "state_dict" in checkpoint: checkpoint = checkpoint["state_dict"] + # use original precision + precision_probing_key = 'input_blocks.0.0.bias' + ckpt_precision = checkpoint[precision_probing_key].dtype + logger.debug(f'original controlnet precision = {ckpt_precision}') + precision = precision or ckpt_precision + original_config = OmegaConf.load(original_config_file) if num_in_channels is not None: @@ -1699,7 +1714,7 @@ def download_controlnet_from_original_ckpt( cross_attention_dim=cross_attention_dim, ) - return controlnet + return controlnet.to(precision) def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL: diff --git a/invokeai/backend/model_management/models/controlnet.py b/invokeai/backend/model_management/models/controlnet.py index e075843a56..ed1e7316dc 100644 --- a/invokeai/backend/model_management/models/controlnet.py +++ b/invokeai/backend/model_management/models/controlnet.py @@ -17,7 +17,7 @@ from .base import ( ModelNotFoundException, ) from invokeai.app.services.config import InvokeAIAppConfig - +import invokeai.backend.util.logging as logger class ControlNetModelFormat(str, Enum): Checkpoint = "checkpoint" @@ -66,7 +66,7 @@ class ControlNetModel(ModelBase): child_type: Optional[SubModelType] = None, ): if child_type is not None: - raise Exception("There is no child models in controlnet model") + raise Exception("There are no child models in controlnet model") model = None for variant in ["fp16", None]: @@ -123,10 +123,7 @@ class ControlNetModel(ModelBase): else: return model_path - -@classmethod def _convert_controlnet_ckpt_and_cache( - cls, model_path: str, output_path: str, base_model: BaseModelType, @@ -140,7 +137,8 @@ def _convert_controlnet_ckpt_and_cache( app_config = InvokeAIAppConfig.get_config() weights = app_config.root_path / model_path output_path = Path(output_path) - + + logger.info(f"Converting {weights} to diffusers format") # return cached version if it exists if output_path.exists(): return output_path diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management/models/stable_diffusion.py index 76b4833f9c..e4396a9582 100644 --- a/invokeai/backend/model_management/models/stable_diffusion.py +++ b/invokeai/backend/model_management/models/stable_diffusion.py @@ -123,6 +123,7 @@ class StableDiffusion1Model(DiffusersModel): return _convert_ckpt_and_cache( version=BaseModelType.StableDiffusion1, model_config=config, + load_safety_checker=False, output_path=output_path, ) else: