SDXL checkpoint models now convert and load; needs refactor

This commit is contained in:
Lincoln Stein 2023-07-23 00:00:31 -04:00
parent b1d7c9b306
commit 5e59edfaf1
2 changed files with 23 additions and 19 deletions

View File

@@ -1071,6 +1071,7 @@ def convert_controlnet_checkpoint(
extract_ema, extract_ema,
use_linear_projection=None, use_linear_projection=None,
cross_attention_dim=None, cross_attention_dim=None,
precision: torch.dtype=torch.float32,
): ):
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
ctrlnet_config["upcast_attention"] = upcast_attention ctrlnet_config["upcast_attention"] = upcast_attention
@@ -1108,9 +1109,9 @@ def convert_controlnet_checkpoint(
controlnet.load_state_dict(converted_ctrl_checkpoint) controlnet.load_state_dict(converted_ctrl_checkpoint)
return controlnet return controlnet.to(precision)
# TO DO - PASS PRECISION
def download_from_original_stable_diffusion_ckpt( def download_from_original_stable_diffusion_ckpt(
checkpoint_path: str, checkpoint_path: str,
model_version: BaseModelType, model_version: BaseModelType,
@@ -1120,6 +1121,7 @@ def download_from_original_stable_diffusion_ckpt(
prediction_type: str = None, prediction_type: str = None,
model_type: str = None, model_type: str = None,
extract_ema: bool = False, extract_ema: bool = False,
precision: torch.dtype = torch.float16,
scheduler_type: str = "pndm", scheduler_type: str = "pndm",
num_in_channels: Optional[int] = None, num_in_channels: Optional[int] = None,
upcast_attention: Optional[bool] = None, upcast_attention: Optional[bool] = None,
@@ -1395,10 +1397,10 @@ def download_from_original_stable_diffusion_ckpt(
with ctx(): with ctx():
unet = UNet2DConditionModel(**unet_config) unet = UNet2DConditionModel(**unet_config)
# if is_accelerate_available(): if is_accelerate_available():
# for param_name, param in converted_unet_checkpoint.items(): for param_name, param in converted_unet_checkpoint.items():
# set_module_tensor_to_device(unet, param_name, "cpu", value=param) set_module_tensor_to_device(unet, param_name, "cpu", value=param)
# else: else:
unet.load_state_dict(converted_unet_checkpoint) unet.load_state_dict(converted_unet_checkpoint)
# Convert the VAE model. # Convert the VAE model.
@@ -1439,10 +1441,10 @@ def download_from_original_stable_diffusion_ckpt(
if stable_unclip is None: if stable_unclip is None:
if controlnet: if controlnet:
pipe = pipeline_class( pipe = pipeline_class(
vae=vae, vae=vae.to(precision),
text_encoder=text_model, text_encoder=text_model,
tokenizer=tokenizer, tokenizer=tokenizer,
unet=unet, unet=unet.to(precision),
scheduler=scheduler, scheduler=scheduler,
controlnet=controlnet, controlnet=controlnet,
safety_checker=None, safety_checker=None,
@@ -1451,10 +1453,10 @@ def download_from_original_stable_diffusion_ckpt(
) )
else: else:
pipe = pipeline_class( pipe = pipeline_class(
vae=vae, vae=vae.to(precision),
text_encoder=text_model, text_encoder=text_model,
tokenizer=tokenizer, tokenizer=tokenizer,
unet=unet, unet=unet.to(precision),
scheduler=scheduler, scheduler=scheduler,
safety_checker=None, safety_checker=None,
feature_extractor=None, feature_extractor=None,
@@ -1542,10 +1544,10 @@ def download_from_original_stable_diffusion_ckpt(
if controlnet: if controlnet:
pipe = pipeline_class( pipe = pipeline_class(
vae=vae, vae=vae.to(precision),
text_encoder=text_model, text_encoder=text_model,
tokenizer=tokenizer, tokenizer=tokenizer,
unet=unet, unet=unet.to(precision),
controlnet=controlnet, controlnet=controlnet,
scheduler=scheduler, scheduler=scheduler,
safety_checker=safety_checker, safety_checker=safety_checker,
@@ -1553,10 +1555,10 @@ def download_from_original_stable_diffusion_ckpt(
) )
else: else:
pipe = pipeline_class( pipe = pipeline_class(
vae=vae, vae=vae.to(precision),
text_encoder=text_model, text_encoder=text_model,
tokenizer=tokenizer, tokenizer=tokenizer,
unet=unet, unet=unet.to(precision),
scheduler=scheduler, scheduler=scheduler,
safety_checker=safety_checker, safety_checker=safety_checker,
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
@@ -1576,12 +1578,12 @@ def download_from_original_stable_diffusion_ckpt(
) )
pipe = StableDiffusionXLPipeline ( pipe = StableDiffusionXLPipeline (
vae=vae, vae=vae.to(precision),
text_encoder=text_encoder, text_encoder=text_encoder,
tokenizer=tokenizer, tokenizer=tokenizer,
text_encoder_2=text_encoder_2, text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2, tokenizer_2=tokenizer_2,
unet=unet, unet=unet.to(precision),
scheduler=scheduler, scheduler=scheduler,
force_zeros_for_empty_prompt=True, force_zeros_for_empty_prompt=True,
) )
@@ -1598,12 +1600,12 @@ def download_from_original_stable_diffusion_ckpt(
) )
pipe = StableDiffusionXLImg2ImgPipeline( pipe = StableDiffusionXLImg2ImgPipeline(
vae=vae, vae=vae.to(precision),
text_encoder=text_encoder, text_encoder=text_encoder,
tokenizer=tokenizer, tokenizer=tokenizer,
text_encoder_2=text_encoder_2, text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2, tokenizer_2=tokenizer_2,
unet=unet, unet=unet.to(precision),
scheduler=scheduler, scheduler=scheduler,
requires_aesthetics_score=True, requires_aesthetics_score=True,
force_zeros_for_empty_prompt=False, force_zeros_for_empty_prompt=False,
@@ -1622,6 +1624,7 @@ def download_controlnet_from_original_ckpt(
original_config_file: str, original_config_file: str,
image_size: int = 512, image_size: int = 512,
extract_ema: bool = False, extract_ema: bool = False,
precision: torch.dtype = torch.float16,
num_in_channels: Optional[int] = None, num_in_channels: Optional[int] = None,
upcast_attention: Optional[bool] = None, upcast_attention: Optional[bool] = None,
device: str = None, device: str = None,

View File

@@ -237,6 +237,7 @@ class StableDiffusion2Model(DiffusersModel):
return model_path return model_path
# TODO: rework # TODO: rework
# pass precision - currently defaulting to fp16
def _convert_ckpt_and_cache( def _convert_ckpt_and_cache(
version: BaseModelType, version: BaseModelType,
model_config: Union[StableDiffusion1Model.CheckpointConfig, model_config: Union[StableDiffusion1Model.CheckpointConfig,