fix(mm): add config field to ckpt vaes

This commit is contained in:
psychedelicious 2024-03-01 15:23:18 +11:00
parent 76cbc745e1
commit 16a5d718bf
2 changed files with 20 additions and 12 deletions

View File

@ -13,6 +13,7 @@ from invokeai.backend.model_manager import (
ModelFormat, ModelFormat,
ModelType, ModelType,
) )
from invokeai.backend.model_manager.config import CheckpointConfigBase
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
from .. import ModelLoaderRegistry from .. import ModelLoaderRegistry
@ -26,7 +27,7 @@ class VaeLoader(GenericDiffusersLoader):
"""Class to load VAE models.""" """Class to load VAE models."""
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if config.format != ModelFormat.Checkpoint: if not isinstance(config, CheckpointConfigBase):
return False return False
elif ( elif (
dest_path.exists() dest_path.exists()
@ -38,13 +39,12 @@ class VaeLoader(GenericDiffusersLoader):
return True return True
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
# TO DO: check whether sdxl VAE models convert. # TODO(MM2): check whether sdxl VAE models convert.
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"Vae conversion not supported for model type: {config.base}") raise Exception(f"VAE conversion not supported for model type: {config.base}")
else: else:
config_file = ( assert isinstance(config, CheckpointConfigBase)
"v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml" config_file = config.config
)
if model_path.suffix == ".safetensors": if model_path.suffix == ".safetensors":
checkpoint = safetensors_load_file(model_path, device="cpu") checkpoint = safetensors_load_file(model_path, device="cpu")
@ -55,7 +55,7 @@ class VaeLoader(GenericDiffusersLoader):
if "state_dict" in checkpoint: if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"] checkpoint = checkpoint["state_dict"]
ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file) ckpt_config = OmegaConf.load(self._app_config.root_path / config_file)
assert isinstance(ckpt_config, DictConfig) assert isinstance(ckpt_config, DictConfig)
vae_model = convert_ldm_vae_to_diffusers( vae_model = convert_ldm_vae_to_diffusers(

View File

@ -137,7 +137,7 @@ class ModelProbe(object):
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
model_info = None model_info = None
model_type = None model_type = None
if format_type == "diffusers": if format_type is ModelFormat.Diffusers:
model_type = cls.get_model_type_from_folder(model_path) model_type = cls.get_model_type_from_folder(model_path)
else: else:
model_type = cls.get_model_type_from_checkpoint(model_path) model_type = cls.get_model_type_from_checkpoint(model_path)
@ -168,7 +168,7 @@ class ModelProbe(object):
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant() fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
# additional fields needed for main and controlnet models # additional fields needed for main and controlnet models
if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint: if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.Vae] and fields["format"] == ModelFormat.Checkpoint:
fields["config"] = cls._get_checkpoint_config_path( fields["config"] = cls._get_checkpoint_config_path(
model_path, model_path,
model_type=fields["type"], model_type=fields["type"],
@ -285,13 +285,21 @@ class ModelProbe(object):
if possible_conf.exists(): if possible_conf.exists():
return possible_conf.absolute() return possible_conf.absolute()
if model_type == ModelType.Main: if model_type is ModelType.Main:
config_file = LEGACY_CONFIGS[base_type][variant_type] config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models if isinstance(config_file, dict): # need another tier for sd-2.x models
config_file = config_file[prediction_type] config_file = config_file[prediction_type]
elif model_type == ModelType.ControlNet: elif model_type is ModelType.ControlNet:
config_file = ( config_file = (
"../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml" "../controlnet/cldm_v15.yaml"
if base_type is BaseModelType.StableDiffusion1
else "../controlnet/cldm_v21.yaml"
)
elif model_type is ModelType.Vae:
config_file = (
"../stable-diffusion/v1-inference.yaml"
if base_type is BaseModelType.StableDiffusion1
else "../stable-diffusion/v2-inference.yaml"
) )
else: else:
raise InvalidModelConfigException( raise InvalidModelConfigException(