SDXL checkpoint models now convert and load; needs refactor

Lincoln Stein 2023-07-23 00:00:31 -04:00
parent b1d7c9b306
commit 5e59edfaf1
2 changed files with 23 additions and 19 deletions
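
The diff touches the checkpoint-to-diffusers converter and the model manager. As a usage sketch only (the import paths, the BaseModelType member name, and the file name below are assumptions, not taken from the commit), converting an SDXL checkpoint through the function changed here would look roughly like this, with the new precision argument controlling the dtype of the returned pipeline's UNet and VAE:

    import torch
    from diffusers import StableDiffusionXLPipeline

    # Assumed module paths and enum member; the commit does not show them.
    from invokeai.backend.model_management.convert_ckpt_to_diffusers import (
        download_from_original_stable_diffusion_ckpt,
    )
    from invokeai.backend.model_management.models.base import BaseModelType

    pipe = download_from_original_stable_diffusion_ckpt(
        checkpoint_path="sd_xl_base_1.0.ckpt",          # example file name only
        model_version=BaseModelType.StableDiffusionXL,  # assumed enum member
        precision=torch.float16,                        # parameter added by this commit
        # original config, scheduler, and safetensors handling omitted for brevity
    )

    # For an SDXL base checkpoint the converter builds a StableDiffusionXLPipeline
    # whose UNet and VAE have been cast to the requested precision.
    assert isinstance(pipe, StableDiffusionXLPipeline)
    pipe.save_pretrained("sdxl-diffusers", safe_serialization=True)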

@@ -1071,6 +1071,7 @@ def convert_controlnet_checkpoint(
extract_ema,
use_linear_projection=None,
cross_attention_dim=None,
+precision: torch.dtype=torch.float32,
):
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
ctrlnet_config["upcast_attention"] = upcast_attention
@@ -1108,9 +1109,9 @@ def convert_controlnet_checkpoint(
controlnet.load_state_dict(converted_ctrl_checkpoint)
-return controlnet
+return controlnet.to(precision)
# TO DO - PASS PRECISION
def download_from_original_stable_diffusion_ckpt(
checkpoint_path: str,
model_version: BaseModelType,
@@ -1120,6 +1121,7 @@ def download_from_original_stable_diffusion_ckpt(
prediction_type: str = None,
model_type: str = None,
extract_ema: bool = False,
+precision: torch.dtype = torch.float16,
scheduler_type: str = "pndm",
num_in_channels: Optional[int] = None,
upcast_attention: Optional[bool] = None,
@@ -1395,11 +1397,11 @@ def download_from_original_stable_diffusion_ckpt(
with ctx():
unet = UNet2DConditionModel(**unet_config)
-# if is_accelerate_available():
-# for param_name, param in converted_unet_checkpoint.items():
-# set_module_tensor_to_device(unet, param_name, "cpu", value=param)
-# else:
-unet.load_state_dict(converted_unet_checkpoint)
+if is_accelerate_available():
+    for param_name, param in converted_unet_checkpoint.items():
+        set_module_tensor_to_device(unet, param_name, "cpu", value=param)
+else:
+    unet.load_state_dict(converted_unet_checkpoint)
# Convert the VAE model.
if vae_path is None:
@@ -1439,10 +1441,10 @@ def download_from_original_stable_diffusion_ckpt(
if stable_unclip is None:
if controlnet:
pipe = pipeline_class(
-vae=vae,
+vae=vae.to(precision),
text_encoder=text_model,
tokenizer=tokenizer,
-unet=unet,
+unet=unet.to(precision),
scheduler=scheduler,
controlnet=controlnet,
safety_checker=None,
@@ -1451,10 +1453,10 @@ def download_from_original_stable_diffusion_ckpt(
)
else:
pipe = pipeline_class(
-vae=vae,
+vae=vae.to(precision),
text_encoder=text_model,
tokenizer=tokenizer,
-unet=unet,
+unet=unet.to(precision),
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
@@ -1542,10 +1544,10 @@ def download_from_original_stable_diffusion_ckpt(
if controlnet:
pipe = pipeline_class(
-vae=vae,
+vae=vae.to(precision),
text_encoder=text_model,
tokenizer=tokenizer,
-unet=unet,
+unet=unet.to(precision),
controlnet=controlnet,
scheduler=scheduler,
safety_checker=safety_checker,
@@ -1553,10 +1555,10 @@ def download_from_original_stable_diffusion_ckpt(
)
else:
pipe = pipeline_class(
-vae=vae,
+vae=vae.to(precision),
text_encoder=text_model,
tokenizer=tokenizer,
-unet=unet,
+unet=unet.to(precision),
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
@@ -1576,12 +1578,12 @@ def download_from_original_stable_diffusion_ckpt(
)
pipe = StableDiffusionXLPipeline (
-vae=vae,
+vae=vae.to(precision),
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
-unet=unet,
+unet=unet.to(precision),
scheduler=scheduler,
force_zeros_for_empty_prompt=True,
)
@@ -1598,12 +1600,12 @@ def download_from_original_stable_diffusion_ckpt(
)
pipe = StableDiffusionXLImg2ImgPipeline(
-vae=vae,
+vae=vae.to(precision),
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
-unet=unet,
+unet=unet.to(precision),
scheduler=scheduler,
requires_aesthetics_score=True,
force_zeros_for_empty_prompt=False,
@@ -1622,6 +1624,7 @@ def download_controlnet_from_original_ckpt(
original_config_file: str,
image_size: int = 512,
extract_ema: bool = False,
+precision: torch.dtype = torch.float16,
num_in_channels: Optional[int] = None,
upcast_attention: Optional[bool] = None,
device: str = None,
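
Two patterns recur in the converter hunks above. When accelerate is available, the UNet is created without allocating weights (ctx() is presumably accelerate's init_empty_weights, as in the upstream diffusers conversion script) and each converted tensor is materialized individually with set_module_tensor_to_device; the finished vae and unet modules are then cast once with .to(precision) as the pipeline is assembled. A minimal sketch of that loading pattern, with placeholder values standing in for what the real script derives from the checkpoint:

    import torch
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device
    from diffusers import UNet2DConditionModel

    # Placeholders for what the real script derives from the original checkpoint.
    unet_config: dict = {}                                   # empty dict -> default UNet config
    converted_unet_checkpoint: dict[str, torch.Tensor] = {}  # convert_ldm_unet_checkpoint(...) output

    with init_empty_weights():  # parameters are created on the "meta" device, no memory allocated
        unet = UNet2DConditionModel(**unet_config)

    # A meta-initialized module cannot be filled by a plain load_state_dict() copy,
    # so each converted tensor is materialized on CPU explicitly.
    for param_name, param in converted_unet_checkpoint.items():
        set_module_tensor_to_device(unet, param_name, "cpu", value=param)

    # The caller then casts the finished module once to the requested dtype,
    # mirroring unet=unet.to(precision) in the pipeline constructors above.
    unet = unet.to(torch.float16)

The remaining hunk comes from the model-manager side, where _convert_ckpt_and_cache drives this conversion.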

@@ -237,6 +237,7 @@ class StableDiffusion2Model(DiffusersModel):
return model_path
# TODO: rework
+# pass precision - currently defaulting to fp16
def _convert_ckpt_and_cache(
version: BaseModelType,
model_config: Union[StableDiffusion1Model.CheckpointConfig,
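
The hunk above only adds a reminder: _convert_ckpt_and_cache still lets the converter fall back to its float16 default rather than forwarding a caller-chosen precision. Purely as a hypothetical illustration of that TODO (none of the names below appear in the commit), the model manager could map its configured precision setting to a torch dtype and pass it through:

    import torch

    # Hypothetical helper, not part of the commit.
    _PRECISION_MAP = {"float32": torch.float32, "float16": torch.float16}

    def resolve_precision(setting: str) -> torch.dtype:
        """Translate a user-facing precision setting into the dtype the converter expects."""
        return _PRECISION_MAP.get(setting, torch.float16)

    # _convert_ckpt_and_cache could then forward the result, for example:
    #   download_from_original_stable_diffusion_ckpt(
    #       checkpoint_path,
    #       model_version=version,
    #       precision=resolve_precision(config.precision),
    #   )
    print(resolve_precision("float32"))  # -> torch.float32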