Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00

SDXL checkpoint models now convert and load; needs refactor

This commit is contained in:
  parent b1d7c9b306
  commit 5e59edfaf1
@@ -1071,6 +1071,7 @@ def convert_controlnet_checkpoint(
     extract_ema,
     use_linear_projection=None,
     cross_attention_dim=None,
+    precision: torch.dtype=torch.float32,
 ):
     ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
     ctrlnet_config["upcast_attention"] = upcast_attention
@@ -1108,9 +1109,9 @@ def convert_controlnet_checkpoint(
 
     controlnet.load_state_dict(converted_ctrl_checkpoint)
 
-    return controlnet
-
+    return controlnet.to(precision)
 
+# TO DO - PASS PRECISION
 def download_from_original_stable_diffusion_ckpt(
     checkpoint_path: str,
     model_version: BaseModelType,
@@ -1120,6 +1121,7 @@ def download_from_original_stable_diffusion_ckpt(
     prediction_type: str = None,
     model_type: str = None,
     extract_ema: bool = False,
+    precision: torch.dtype = torch.float16,
     scheduler_type: str = "pndm",
     num_in_channels: Optional[int] = None,
     upcast_attention: Optional[bool] = None,
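With the new precision keyword, callers of download_from_original_stable_diffusion_ckpt can pick the dtype of the returned pipeline at conversion time instead of converting in fp32 and casting afterwards. A minimal caller sketch, not taken from this commit: the checkpoint path is a placeholder, the function and a BaseModelType value are assumed to already be imported, and only keywords visible in this diff are passed.

# Hypothetical usage sketch; assumes download_from_original_stable_diffusion_ckpt
# and a BaseModelType value for model_version are already in scope.
import torch

pipe = download_from_original_stable_diffusion_ckpt(
    checkpoint_path="/path/to/checkpoint.safetensors",  # placeholder path
    model_version=model_version,                        # a BaseModelType member
    extract_ema=False,
    scheduler_type="pndm",
    precision=torch.float16,  # new keyword added by this diff (default: torch.float16)
)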
@@ -1395,11 +1397,11 @@ def download_from_original_stable_diffusion_ckpt(
     with ctx():
         unet = UNet2DConditionModel(**unet_config)
 
-    # if is_accelerate_available():
-    #     for param_name, param in converted_unet_checkpoint.items():
-    #         set_module_tensor_to_device(unet, param_name, "cpu", value=param)
-    # else:
-    unet.load_state_dict(converted_unet_checkpoint)
+    if is_accelerate_available():
+        for param_name, param in converted_unet_checkpoint.items():
+            set_module_tensor_to_device(unet, param_name, "cpu", value=param)
+    else:
+        unet.load_state_dict(converted_unet_checkpoint)
 
     # Convert the VAE model.
     if vae_path is None:
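This hunk re-enables the previously commented-out accelerate path: the UNet skeleton is created with empty (meta-device) weights inside ctx(), and each converted tensor is then materialized on the CPU, so the model is never allocated twice. A self-contained sketch of that pattern; the load_converted_unet wrapper is illustrative and not a function from the repository.

from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from diffusers import UNet2DConditionModel


def load_converted_unet(unet_config: dict, converted_unet_checkpoint: dict) -> UNet2DConditionModel:
    # Build the module on the meta device: no real parameter memory is allocated yet.
    with init_empty_weights():
        unet = UNet2DConditionModel(**unet_config)
    # Copy each converted tensor straight into the empty module on the CPU.
    for param_name, param in converted_unet_checkpoint.items():
        set_module_tensor_to_device(unet, param_name, "cpu", value=param)
    return unet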
@@ -1439,10 +1441,10 @@ def download_from_original_stable_diffusion_ckpt(
     if stable_unclip is None:
         if controlnet:
             pipe = pipeline_class(
-                vae=vae,
+                vae=vae.to(precision),
                 text_encoder=text_model,
                 tokenizer=tokenizer,
-                unet=unet,
+                unet=unet.to(precision),
                 scheduler=scheduler,
                 controlnet=controlnet,
                 safety_checker=None,
@@ -1451,10 +1453,10 @@ def download_from_original_stable_diffusion_ckpt(
             )
         else:
             pipe = pipeline_class(
-                vae=vae,
+                vae=vae.to(precision),
                 text_encoder=text_model,
                 tokenizer=tokenizer,
-                unet=unet,
+                unet=unet.to(precision),
                 scheduler=scheduler,
                 safety_checker=None,
                 feature_extractor=None,
@@ -1542,10 +1544,10 @@ def download_from_original_stable_diffusion_ckpt(
 
         if controlnet:
             pipe = pipeline_class(
-                vae=vae,
+                vae=vae.to(precision),
                 text_encoder=text_model,
                 tokenizer=tokenizer,
-                unet=unet,
+                unet=unet.to(precision),
                 controlnet=controlnet,
                 scheduler=scheduler,
                 safety_checker=safety_checker,
@@ -1553,10 +1555,10 @@ def download_from_original_stable_diffusion_ckpt(
             )
         else:
             pipe = pipeline_class(
-                vae=vae,
+                vae=vae.to(precision),
                 text_encoder=text_model,
                 tokenizer=tokenizer,
-                unet=unet,
+                unet=unet.to(precision),
                 scheduler=scheduler,
                 safety_checker=safety_checker,
                 feature_extractor=feature_extractor,
@@ -1576,12 +1578,12 @@ def download_from_original_stable_diffusion_ckpt(
             )
 
             pipe = StableDiffusionXLPipeline(
-                vae=vae,
+                vae=vae.to(precision),
                 text_encoder=text_encoder,
                 tokenizer=tokenizer,
                 text_encoder_2=text_encoder_2,
                 tokenizer_2=tokenizer_2,
-                unet=unet,
+                unet=unet.to(precision),
                 scheduler=scheduler,
                 force_zeros_for_empty_prompt=True,
             )
@@ -1598,12 +1600,12 @@ def download_from_original_stable_diffusion_ckpt(
             )
 
             pipe = StableDiffusionXLImg2ImgPipeline(
-                vae=vae,
+                vae=vae.to(precision),
                 text_encoder=text_encoder,
                 tokenizer=tokenizer,
                 text_encoder_2=text_encoder_2,
                 tokenizer_2=tokenizer_2,
-                unet=unet,
+                unet=unet.to(precision),
                 scheduler=scheduler,
                 requires_aesthetics_score=True,
                 force_zeros_for_empty_prompt=False,
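The repeated vae=vae.to(precision) and unet=unet.to(precision) edits above all do the same thing: torch.nn.Module.to(dtype) casts every parameter and buffer of the submodule, so each pipeline is assembled with weights already at the requested precision. A small stand-alone illustration using a default-config diffusers VAE:

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL()                # small default config, fp32 parameters
print(next(vae.parameters()).dtype)  # torch.float32
vae = vae.to(torch.float16)          # what vae.to(precision) does in the hunks above
print(next(vae.parameters()).dtype)  # torch.float16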
@@ -1622,6 +1624,7 @@ def download_controlnet_from_original_ckpt(
     original_config_file: str,
     image_size: int = 512,
     extract_ema: bool = False,
+    precision: torch.dtype = torch.float16,
     num_in_channels: Optional[int] = None,
     upcast_attention: Optional[bool] = None,
     device: str = None,

@@ -237,6 +237,7 @@ class StableDiffusion2Model(DiffusersModel):
         return model_path
 
 # TODO: rework
+# pass precision - currently defaulting to fp16
 def _convert_ckpt_and_cache(
     version: BaseModelType,
     model_config: Union[StableDiffusion1Model.CheckpointConfig,