diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py index dfd7fd100c..aa4dc0ecbc 100644 --- a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py +++ b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py @@ -1071,6 +1071,7 @@ def convert_controlnet_checkpoint( extract_ema, use_linear_projection=None, cross_attention_dim=None, + precision: torch.dtype=torch.float32, ): ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) ctrlnet_config["upcast_attention"] = upcast_attention @@ -1108,9 +1109,9 @@ def convert_controlnet_checkpoint( controlnet.load_state_dict(converted_ctrl_checkpoint) - return controlnet - + return controlnet.to(precision) +# TO DO - PASS PRECISION def download_from_original_stable_diffusion_ckpt( checkpoint_path: str, model_version: BaseModelType, @@ -1120,6 +1121,7 @@ def download_from_original_stable_diffusion_ckpt( prediction_type: str = None, model_type: str = None, extract_ema: bool = False, + precision: torch.dtype = torch.float16, scheduler_type: str = "pndm", num_in_channels: Optional[int] = None, upcast_attention: Optional[bool] = None, @@ -1395,11 +1397,11 @@ def download_from_original_stable_diffusion_ckpt( with ctx(): unet = UNet2DConditionModel(**unet_config) - # if is_accelerate_available(): - # for param_name, param in converted_unet_checkpoint.items(): - # set_module_tensor_to_device(unet, param_name, "cpu", value=param) - # else: - unet.load_state_dict(converted_unet_checkpoint) + if is_accelerate_available(): + for param_name, param in converted_unet_checkpoint.items(): + set_module_tensor_to_device(unet, param_name, "cpu", value=param) + else: + unet.load_state_dict(converted_unet_checkpoint) # Convert the VAE model. if vae_path is None: @@ -1439,10 +1441,10 @@ def download_from_original_stable_diffusion_ckpt( if stable_unclip is None: if controlnet: pipe = pipeline_class( - vae=vae, + vae=vae.to(precision), text_encoder=text_model, tokenizer=tokenizer, - unet=unet, + unet=unet.to(precision), scheduler=scheduler, controlnet=controlnet, safety_checker=None, @@ -1451,10 +1453,10 @@ def download_from_original_stable_diffusion_ckpt( ) else: pipe = pipeline_class( - vae=vae, + vae=vae.to(precision), text_encoder=text_model, tokenizer=tokenizer, - unet=unet, + unet=unet.to(precision), scheduler=scheduler, safety_checker=None, feature_extractor=None, @@ -1542,10 +1544,10 @@ def download_from_original_stable_diffusion_ckpt( if controlnet: pipe = pipeline_class( - vae=vae, + vae=vae.to(precision), text_encoder=text_model, tokenizer=tokenizer, - unet=unet, + unet=unet.to(precision), controlnet=controlnet, scheduler=scheduler, safety_checker=safety_checker, @@ -1553,10 +1555,10 @@ def download_from_original_stable_diffusion_ckpt( ) else: pipe = pipeline_class( - vae=vae, + vae=vae.to(precision), text_encoder=text_model, tokenizer=tokenizer, - unet=unet, + unet=unet.to(precision), scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, @@ -1576,12 +1578,12 @@ def download_from_original_stable_diffusion_ckpt( ) pipe = StableDiffusionXLPipeline ( - vae=vae, + vae=vae.to(precision), text_encoder=text_encoder, tokenizer=tokenizer, text_encoder_2=text_encoder_2, tokenizer_2=tokenizer_2, - unet=unet, + unet=unet.to(precision), scheduler=scheduler, force_zeros_for_empty_prompt=True, ) @@ -1598,12 +1600,12 @@ def download_from_original_stable_diffusion_ckpt( ) pipe = StableDiffusionXLImg2ImgPipeline( - vae=vae, + vae=vae.to(precision), text_encoder=text_encoder, tokenizer=tokenizer, text_encoder_2=text_encoder_2, tokenizer_2=tokenizer_2, - unet=unet, + unet=unet.to(precision), scheduler=scheduler, requires_aesthetics_score=True, force_zeros_for_empty_prompt=False, @@ -1622,6 +1624,7 @@ def download_controlnet_from_original_ckpt( original_config_file: str, image_size: int = 512, extract_ema: bool = False, + precision: torch.dtype = torch.float16, num_in_channels: Optional[int] = None, upcast_attention: Optional[bool] = None, device: str = None, diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management/models/stable_diffusion.py index 6a8329e911..cf34686d6f 100644 --- a/invokeai/backend/model_management/models/stable_diffusion.py +++ b/invokeai/backend/model_management/models/stable_diffusion.py @@ -237,6 +237,7 @@ class StableDiffusion2Model(DiffusersModel): return model_path # TODO: rework +# pass precision - currently defaulting to fp16 def _convert_ckpt_and_cache( version: BaseModelType, model_config: Union[StableDiffusion1Model.CheckpointConfig,