mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
convert script handles more ckpt variants
This commit is contained in:
parent
72c519c6ad
commit
2a2d988928
@ -422,8 +422,11 @@ def convert_ldm_unet_checkpoint(
|
|||||||
)
|
)
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if key.startswith("model.diffusion_model"):
|
if key.startswith("model.diffusion_model"):
|
||||||
flat_ema_key = "model_ema." + ".".join(key.split(".")[1:])
|
for delimiter in ['','.']:
|
||||||
|
flat_ema_key = "model_ema." + delimiter.join(key.split(".")[1:])
|
||||||
|
if checkpoint.get(flat_ema_key) is not None:
|
||||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
||||||
|
break
|
||||||
else:
|
else:
|
||||||
if sum(k.startswith("model_ema") for k in keys) > 100:
|
if sum(k.startswith("model_ema") for k in keys) > 100:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@ -1070,7 +1073,7 @@ def convert_controlnet_checkpoint(
|
|||||||
extract_ema,
|
extract_ema,
|
||||||
use_linear_projection=None,
|
use_linear_projection=None,
|
||||||
cross_attention_dim=None,
|
cross_attention_dim=None,
|
||||||
precision: torch.dtype = torch.float32,
|
precision: torch.dtype = None,
|
||||||
):
|
):
|
||||||
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
|
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
|
||||||
ctrlnet_config["upcast_attention"] = upcast_attention
|
ctrlnet_config["upcast_attention"] = upcast_attention
|
||||||
@ -1121,7 +1124,7 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
prediction_type: str = None,
|
prediction_type: str = None,
|
||||||
model_type: str = None,
|
model_type: str = None,
|
||||||
extract_ema: bool = False,
|
extract_ema: bool = False,
|
||||||
precision: torch.dtype = torch.float32,
|
precision: torch.dtype = None,
|
||||||
scheduler_type: str = "pndm",
|
scheduler_type: str = "pndm",
|
||||||
num_in_channels: Optional[int] = None,
|
num_in_channels: Optional[int] = None,
|
||||||
upcast_attention: Optional[bool] = None,
|
upcast_attention: Optional[bool] = None,
|
||||||
@ -1194,6 +1197,8 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
|
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
|
||||||
to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if
|
to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if
|
||||||
needed.
|
needed.
|
||||||
|
precision (`torch.dtype`, *optional*, defauts to `None`):
|
||||||
|
If not provided the precision will be set to the precision of the original file.
|
||||||
return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
|
return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -1252,6 +1257,10 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
|
|
||||||
logger.debug(f"model_type = {model_type}; original_config_file = {original_config_file}")
|
logger.debug(f"model_type = {model_type}; original_config_file = {original_config_file}")
|
||||||
|
|
||||||
|
precision_probing_key = "model.diffusion_model.input_blocks.0.0.bias"
|
||||||
|
logger.debug(f"original checkpoint precision == {checkpoint[precision_probing_key].dtype}")
|
||||||
|
precision = precision or checkpoint[precision_probing_key].dtype
|
||||||
|
|
||||||
if original_config_file is None:
|
if original_config_file is None:
|
||||||
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||||
key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
|
key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
|
||||||
@ -1281,7 +1290,7 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
original_config = OmegaConf.load(original_config_file)
|
original_config = OmegaConf.load(original_config_file)
|
||||||
if (
|
if (
|
||||||
model_version == BaseModelType.StableDiffusion2
|
model_version == BaseModelType.StableDiffusion2
|
||||||
and original_config["model"]["params"]["parameterization"] == "v"
|
and original_config["model"]["params"].get("parameterization") == "v"
|
||||||
):
|
):
|
||||||
prediction_type = "v_prediction"
|
prediction_type = "v_prediction"
|
||||||
upcast_attention = True
|
upcast_attention = True
|
||||||
@ -1447,7 +1456,7 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
if controlnet:
|
if controlnet:
|
||||||
pipe = pipeline_class(
|
pipe = pipeline_class(
|
||||||
vae=vae.to(precision),
|
vae=vae.to(precision),
|
||||||
text_encoder=text_model,
|
text_encoder=text_model.to(precision),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
unet=unet.to(precision),
|
unet=unet.to(precision),
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
@ -1459,7 +1468,7 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
else:
|
else:
|
||||||
pipe = pipeline_class(
|
pipe = pipeline_class(
|
||||||
vae=vae.to(precision),
|
vae=vae.to(precision),
|
||||||
text_encoder=text_model,
|
text_encoder=text_model.to(precision),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
unet=unet.to(precision),
|
unet=unet.to(precision),
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
@ -1484,8 +1493,8 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
image_noising_scheduler=image_noising_scheduler,
|
image_noising_scheduler=image_noising_scheduler,
|
||||||
# regular denoising components
|
# regular denoising components
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
text_encoder=text_model,
|
text_encoder=text_model.to(precision),
|
||||||
unet=unet,
|
unet=unet.to(precision),
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
# vae
|
# vae
|
||||||
vae=vae,
|
vae=vae,
|
||||||
@ -1560,7 +1569,7 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
if controlnet:
|
if controlnet:
|
||||||
pipe = pipeline_class(
|
pipe = pipeline_class(
|
||||||
vae=vae.to(precision),
|
vae=vae.to(precision),
|
||||||
text_encoder=text_model,
|
text_encoder=text_model.to(precision),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
unet=unet.to(precision),
|
unet=unet.to(precision),
|
||||||
controlnet=controlnet,
|
controlnet=controlnet,
|
||||||
@ -1571,7 +1580,7 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
else:
|
else:
|
||||||
pipe = pipeline_class(
|
pipe = pipeline_class(
|
||||||
vae=vae.to(precision),
|
vae=vae.to(precision),
|
||||||
text_encoder=text_model,
|
text_encoder=text_model.to(precision),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
unet=unet.to(precision),
|
unet=unet.to(precision),
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
@ -1594,9 +1603,9 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
|
|
||||||
pipe = StableDiffusionXLPipeline(
|
pipe = StableDiffusionXLPipeline(
|
||||||
vae=vae.to(precision),
|
vae=vae.to(precision),
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder.to(precision),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
text_encoder_2=text_encoder_2,
|
text_encoder_2=text_encoder_2.to(precision),
|
||||||
tokenizer_2=tokenizer_2,
|
tokenizer_2=tokenizer_2,
|
||||||
unet=unet.to(precision),
|
unet=unet.to(precision),
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
@ -1639,7 +1648,7 @@ def download_controlnet_from_original_ckpt(
|
|||||||
original_config_file: str,
|
original_config_file: str,
|
||||||
image_size: int = 512,
|
image_size: int = 512,
|
||||||
extract_ema: bool = False,
|
extract_ema: bool = False,
|
||||||
precision: torch.dtype = torch.float32,
|
precision: torch.dtype = None,
|
||||||
num_in_channels: Optional[int] = None,
|
num_in_channels: Optional[int] = None,
|
||||||
upcast_attention: Optional[bool] = None,
|
upcast_attention: Optional[bool] = None,
|
||||||
device: str = None,
|
device: str = None,
|
||||||
@ -1680,6 +1689,12 @@ def download_controlnet_from_original_ckpt(
|
|||||||
while "state_dict" in checkpoint:
|
while "state_dict" in checkpoint:
|
||||||
checkpoint = checkpoint["state_dict"]
|
checkpoint = checkpoint["state_dict"]
|
||||||
|
|
||||||
|
# use original precision
|
||||||
|
precision_probing_key = 'input_blocks.0.0.bias'
|
||||||
|
ckpt_precision = checkpoint[precision_probing_key].dtype
|
||||||
|
logger.debug(f'original controlnet precision = {ckpt_precision}')
|
||||||
|
precision = precision or ckpt_precision
|
||||||
|
|
||||||
original_config = OmegaConf.load(original_config_file)
|
original_config = OmegaConf.load(original_config_file)
|
||||||
|
|
||||||
if num_in_channels is not None:
|
if num_in_channels is not None:
|
||||||
@ -1699,7 +1714,7 @@ def download_controlnet_from_original_ckpt(
|
|||||||
cross_attention_dim=cross_attention_dim,
|
cross_attention_dim=cross_attention_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
return controlnet
|
return controlnet.to(precision)
|
||||||
|
|
||||||
|
|
||||||
def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL:
|
def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL:
|
||||||
|
@ -17,7 +17,7 @@ from .base import (
|
|||||||
ModelNotFoundException,
|
ModelNotFoundException,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
class ControlNetModelFormat(str, Enum):
|
class ControlNetModelFormat(str, Enum):
|
||||||
Checkpoint = "checkpoint"
|
Checkpoint = "checkpoint"
|
||||||
@ -66,7 +66,7 @@ class ControlNetModel(ModelBase):
|
|||||||
child_type: Optional[SubModelType] = None,
|
child_type: Optional[SubModelType] = None,
|
||||||
):
|
):
|
||||||
if child_type is not None:
|
if child_type is not None:
|
||||||
raise Exception("There is no child models in controlnet model")
|
raise Exception("There are no child models in controlnet model")
|
||||||
|
|
||||||
model = None
|
model = None
|
||||||
for variant in ["fp16", None]:
|
for variant in ["fp16", None]:
|
||||||
@ -123,10 +123,7 @@ class ControlNetModel(ModelBase):
|
|||||||
else:
|
else:
|
||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _convert_controlnet_ckpt_and_cache(
|
def _convert_controlnet_ckpt_and_cache(
|
||||||
cls,
|
|
||||||
model_path: str,
|
model_path: str,
|
||||||
output_path: str,
|
output_path: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
@ -141,6 +138,7 @@ def _convert_controlnet_ckpt_and_cache(
|
|||||||
weights = app_config.root_path / model_path
|
weights = app_config.root_path / model_path
|
||||||
output_path = Path(output_path)
|
output_path = Path(output_path)
|
||||||
|
|
||||||
|
logger.info(f"Converting {weights} to diffusers format")
|
||||||
# return cached version if it exists
|
# return cached version if it exists
|
||||||
if output_path.exists():
|
if output_path.exists():
|
||||||
return output_path
|
return output_path
|
||||||
|
@ -123,6 +123,7 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
return _convert_ckpt_and_cache(
|
return _convert_ckpt_and_cache(
|
||||||
version=BaseModelType.StableDiffusion1,
|
version=BaseModelType.StableDiffusion1,
|
||||||
model_config=config,
|
model_config=config,
|
||||||
|
load_safety_checker=False,
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user