convert script handles more ckpt variants

Lincoln Stein 2023-07-29 15:28:39 -04:00
parent 72c519c6ad
commit 2a2d988928
3 changed files with 35 additions and 21 deletions

View File

@@ -422,8 +422,11 @@ def convert_ldm_unet_checkpoint(
         )
         for key in keys:
             if key.startswith("model.diffusion_model"):
-                flat_ema_key = "model_ema." + ".".join(key.split(".")[1:])
-                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+                for delimiter in ['', '.']:
+                    flat_ema_key = "model_ema." + delimiter.join(key.split(".")[1:])
+                    if checkpoint.get(flat_ema_key) is not None:
+                        unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+                        break
     else:
         if sum(k.startswith("model_ema") for k in keys) > 100:
             logger.warning(
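The two delimiters cover the two flat-EMA naming schemes seen in circulating checkpoints: the original CompVis convention joins the key path with no separator at all, while some exported variants keep the dots. A minimal sketch of the candidate keys the new loop probes (the example key is just one representative UNet entry):

    # one representative UNet key from a .ckpt state dict
    key = "model.diffusion_model.input_blocks.0.0.weight"

    for delimiter in ['', '.']:
        flat_ema_key = "model_ema." + delimiter.join(key.split(".")[1:])
        print(flat_ema_key)
    # delimiter ''  -> model_ema.diffusion_modelinput_blocks00weight
    # delimiter '.' -> model_ema.diffusion_model.input_blocks.0.0.weight

Whichever candidate actually exists in the checkpoint wins, and the break stops probing once the EMA weight has been moved over.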
@@ -1070,7 +1073,7 @@ def convert_controlnet_checkpoint(
     extract_ema,
     use_linear_projection=None,
     cross_attention_dim=None,
-    precision: torch.dtype = torch.float32,
+    precision: torch.dtype = None,
 ):
     ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
     ctrlnet_config["upcast_attention"] = upcast_attention
@@ -1121,7 +1124,7 @@ def download_from_original_stable_diffusion_ckpt(
     prediction_type: str = None,
     model_type: str = None,
     extract_ema: bool = False,
-    precision: torch.dtype = torch.float32,
+    precision: torch.dtype = None,
     scheduler_type: str = "pndm",
     num_in_channels: Optional[int] = None,
     upcast_attention: Optional[bool] = None,
@@ -1194,6 +1197,8 @@ def download_from_original_stable_diffusion_ckpt(
             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
             to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if
             needed.
+        precision (`torch.dtype`, *optional*, defaults to `None`):
+            If not provided, the precision will be set to the precision of the original file.
     return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
     """
@@ -1252,6 +1257,10 @@ def download_from_original_stable_diffusion_ckpt(
     logger.debug(f"model_type = {model_type}; original_config_file = {original_config_file}")
 
+    precision_probing_key = "model.diffusion_model.input_blocks.0.0.bias"
+    logger.debug(f"original checkpoint precision == {checkpoint[precision_probing_key].dtype}")
+    precision = precision or checkpoint[precision_probing_key].dtype
+
     if original_config_file is None:
         key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
         key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
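Since a checkpoint is saved with one uniform storage dtype, probing a single tensor that every Stable Diffusion checkpoint contains is enough to recover the file's precision, and `precision or ...` keeps an explicit caller override ahead of the probed value. A self-contained sketch of the idiom, with a stand-in checkpoint dict:

    import torch

    def resolve_precision(checkpoint: dict, precision: torch.dtype = None) -> torch.dtype:
        # this bias exists in every SD checkpoint, so its dtype reflects the whole file
        probe = checkpoint["model.diffusion_model.input_blocks.0.0.bias"]
        return precision or probe.dtype

    ckpt = {"model.diffusion_model.input_blocks.0.0.bias": torch.zeros(320, dtype=torch.float16)}
    assert resolve_precision(ckpt) == torch.float16                  # inferred from the file
    assert resolve_precision(ckpt, torch.float32) == torch.float32   # explicit request wins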
@@ -1281,7 +1290,7 @@ def download_from_original_stable_diffusion_ckpt(
         original_config = OmegaConf.load(original_config_file)
         if (
             model_version == BaseModelType.StableDiffusion2
-            and original_config["model"]["params"]["parameterization"] == "v"
+            and original_config["model"]["params"].get("parameterization") == "v"
         ):
             prediction_type = "v_prediction"
             upcast_attention = True
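Switching from a bracket lookup to `.get("parameterization")` is what lets epsilon-prediction SD-2 configs through: those YAML files simply omit the key, so the old code raised KeyError before the v-prediction test could evaluate. In short:

    params = {"linear_start": 0.00085}        # epsilon-prediction config: no "parameterization" key
    params.get("parameterization") == "v"     # False -> falls through cleanly
    params["parameterization"] == "v"         # KeyError -> conversion aborts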
@@ -1447,7 +1456,7 @@ def download_from_original_stable_diffusion_ckpt(
         if controlnet:
             pipe = pipeline_class(
                 vae=vae.to(precision),
-                text_encoder=text_model,
+                text_encoder=text_model.to(precision),
                 tokenizer=tokenizer,
                 unet=unet.to(precision),
                 scheduler=scheduler,
@ -1459,7 +1468,7 @@ def download_from_original_stable_diffusion_ckpt(
else: else:
pipe = pipeline_class( pipe = pipeline_class(
vae=vae.to(precision), vae=vae.to(precision),
text_encoder=text_model, text_encoder=text_model.to(precision),
tokenizer=tokenizer, tokenizer=tokenizer,
unet=unet.to(precision), unet=unet.to(precision),
scheduler=scheduler, scheduler=scheduler,
@@ -1484,8 +1493,8 @@ def download_from_original_stable_diffusion_ckpt(
                 image_noising_scheduler=image_noising_scheduler,
                 # regular denoising components
                 tokenizer=tokenizer,
-                text_encoder=text_model,
-                unet=unet,
+                text_encoder=text_model.to(precision),
+                unet=unet.to(precision),
                 scheduler=scheduler,
                 # vae
                 vae=vae,
@@ -1560,7 +1569,7 @@ def download_from_original_stable_diffusion_ckpt(
         if controlnet:
             pipe = pipeline_class(
                 vae=vae.to(precision),
-                text_encoder=text_model,
+                text_encoder=text_model.to(precision),
                 tokenizer=tokenizer,
                 unet=unet.to(precision),
                 controlnet=controlnet,
@@ -1571,7 +1580,7 @@ def download_from_original_stable_diffusion_ckpt(
         else:
             pipe = pipeline_class(
                 vae=vae.to(precision),
-                text_encoder=text_model,
+                text_encoder=text_model.to(precision),
                 tokenizer=tokenizer,
                 unet=unet.to(precision),
                 scheduler=scheduler,
@@ -1594,9 +1603,9 @@ def download_from_original_stable_diffusion_ckpt(
         pipe = StableDiffusionXLPipeline(
             vae=vae.to(precision),
-            text_encoder=text_encoder,
+            text_encoder=text_encoder.to(precision),
             tokenizer=tokenizer,
-            text_encoder_2=text_encoder_2,
+            text_encoder_2=text_encoder_2.to(precision),
             tokenizer_2=tokenizer_2,
             unet=unet.to(precision),
             scheduler=scheduler,
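The pattern repeated across these pipeline constructors is a single fix: every weight-bearing module (vae, unet, and now the text encoders) is cast with `nn.Module.to(dtype)`, while tokenizers and schedulers are untouched because they hold no tensors. Leaving the text encoders out, as before, produced mixed-precision pipelines whenever the probed precision was float16. A sketch with stand-in modules:

    import torch
    from torch import nn

    # stand-ins for the pipeline's weight-bearing components
    vae, unet, text_encoder = nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 4)
    for module in (vae, unet, text_encoder):
        module.to(torch.float16)   # .to(dtype) casts parameters and buffers in place

    assert all(p.dtype == torch.float16 for p in unet.parameters())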
@@ -1639,7 +1648,7 @@ def download_controlnet_from_original_ckpt(
     original_config_file: str,
     image_size: int = 512,
     extract_ema: bool = False,
-    precision: torch.dtype = torch.float32,
+    precision: torch.dtype = None,
     num_in_channels: Optional[int] = None,
     upcast_attention: Optional[bool] = None,
     device: str = None,
@@ -1680,6 +1689,12 @@ def download_controlnet_from_original_ckpt(
     while "state_dict" in checkpoint:
         checkpoint = checkpoint["state_dict"]
 
+    # use original precision
+    precision_probing_key = 'input_blocks.0.0.bias'
+    ckpt_precision = checkpoint[precision_probing_key].dtype
+    logger.debug(f'original controlnet precision = {ckpt_precision}')
+    precision = precision or ckpt_precision
+
     original_config = OmegaConf.load(original_config_file)
 
     if num_in_channels is not None:
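Same probing idiom as the main conversion path, with two wrinkles: ControlNet exports sometimes nest the weights under one or more "state_dict" layers, and their keys are relative to the control model itself, hence the probe key without the `model.diffusion_model.` prefix. A stand-in sketch:

    import torch

    raw = {"state_dict": {"state_dict": {"input_blocks.0.0.bias": torch.zeros(16, dtype=torch.float16)}}}

    checkpoint = raw
    while "state_dict" in checkpoint:        # unwrap however many nesting levels exist
        checkpoint = checkpoint["state_dict"]

    precision = checkpoint["input_blocks.0.0.bias"].dtype   # torch.float16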
@@ -1699,7 +1714,7 @@ def download_controlnet_from_original_ckpt(
         cross_attention_dim=cross_attention_dim,
     )
 
-    return controlnet
+    return controlnet.to(precision)
 
 
 def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL:

View File

@@ -17,7 +17,7 @@ from .base import (
     ModelNotFoundException,
 )
 from invokeai.app.services.config import InvokeAIAppConfig
+import invokeai.backend.util.logging as logger
 
 
 class ControlNetModelFormat(str, Enum):
     Checkpoint = "checkpoint"
@@ -66,7 +66,7 @@ class ControlNetModel(ModelBase):
         child_type: Optional[SubModelType] = None,
     ):
         if child_type is not None:
-            raise Exception("There is no child models in controlnet model")
+            raise Exception("There are no child models in controlnet model")
 
         model = None
         for variant in ["fp16", None]:
@@ -123,10 +123,7 @@ class ControlNetModel(ModelBase):
         else:
             return model_path
 
-@classmethod
 def _convert_controlnet_ckpt_and_cache(
-    cls,
     model_path: str,
     output_path: str,
     base_model: BaseModelType,
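Dropping `@classmethod` here is a genuine bug fix rather than a cleanup: the decorator was applied to a module-level function, which leaves a bare classmethod descriptor that raises TypeError when invoked directly (Python 3.10 made bare staticmethod objects callable, but not classmethod), plus a `cls` parameter that no caller supplies. A minimal reproduction:

    @classmethod
    def broken(cls, x):   # module level: nothing ever binds cls
        return x

    broken(1)             # TypeError: 'classmethod' object is not callable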
@@ -141,6 +138,7 @@ def _convert_controlnet_ckpt_and_cache(
     weights = app_config.root_path / model_path
     output_path = Path(output_path)
 
+    logger.info(f"Converting {weights} to diffusers format")
     # return cached version if it exists
     if output_path.exists():
         return output_path
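Note that the new log line sits above the cache check, so it prints even when a previously converted copy is returned; the early return below it is what makes the routine idempotent. The shape of the pattern, with the conversion body elided:

    from pathlib import Path

    def convert_and_cache(weights: Path, output_path: Path) -> Path:
        output_path = Path(output_path)
        if output_path.exists():   # already converted on a previous run
            return output_path
        ...                        # expensive ckpt -> diffusers conversion goes here
        return output_path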

View File

@@ -123,6 +123,7 @@ class StableDiffusion1Model(DiffusersModel):
             return _convert_ckpt_and_cache(
                 version=BaseModelType.StableDiffusion1,
                 model_config=config,
+                load_safety_checker=False,
                 output_path=output_path,
             )
         else:
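Passing `load_safety_checker=False` stops the underlying conversion routine from downloading and attaching Stable Diffusion's safety-checker model, which InvokeAI manages separately, so every ckpt conversion no longer pulls those extra weights from the Hub. Assuming the vendored function keeps diffusers' parameter of the same name, usage looks like (the checkpoint path is illustrative):

    pipe = download_from_original_stable_diffusion_ckpt(
        checkpoint_path="v1-5-pruned-emaonly.safetensors",   # illustrative path
        load_safety_checker=False,   # don't fetch StableDiffusionSafetyChecker weights
    )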