diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py index f8a03f5670..11f165c65a 100644 --- a/invokeai/backend/model_manager/load/model_loaders/vae.py +++ b/invokeai/backend/model_manager/load/model_loaders/vae.py @@ -13,6 +13,7 @@ from invokeai.backend.model_manager import ( ModelFormat, ModelType, ) +from invokeai.backend.model_manager.config import CheckpointConfigBase from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers from .. import ModelLoaderRegistry @@ -26,7 +27,7 @@ class VaeLoader(GenericDiffusersLoader): """Class to load VAE models.""" def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: - if config.format != ModelFormat.Checkpoint: + if not isinstance(config, CheckpointConfigBase): return False elif ( dest_path.exists() @@ -38,13 +39,12 @@ class VaeLoader(GenericDiffusersLoader): return True def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: - # TO DO: check whether sdxl VAE models convert. + # TODO(MM2): check whether sdxl VAE models convert. if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: - raise Exception(f"Vae conversion not supported for model type: {config.base}") + raise Exception(f"VAE conversion not supported for model type: {config.base}") else: - config_file = ( - "v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml" - ) + assert isinstance(config, CheckpointConfigBase) + config_file = config.config if model_path.suffix == ".safetensors": checkpoint = safetensors_load_file(model_path, device="cpu") @@ -55,7 +55,7 @@ class VaeLoader(GenericDiffusersLoader): if "state_dict" in checkpoint: checkpoint = checkpoint["state_dict"] - ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file) + ckpt_config = OmegaConf.load(self._app_config.root_path / config_file) assert isinstance(ckpt_config, DictConfig) vae_model = convert_ldm_vae_to_diffusers( diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index a7250f33d1..608d9e3c59 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -137,7 +137,7 @@ class ModelProbe(object): format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint model_info = None model_type = None - if format_type == "diffusers": + if format_type is ModelFormat.Diffusers: model_type = cls.get_model_type_from_folder(model_path) else: model_type = cls.get_model_type_from_checkpoint(model_path) @@ -168,7 +168,7 @@ class ModelProbe(object): fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant() # additional fields needed for main and controlnet models - if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint: + if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.Vae] and fields["format"] == ModelFormat.Checkpoint: fields["config"] = cls._get_checkpoint_config_path( model_path, model_type=fields["type"], @@ -285,13 +285,21 @@ class ModelProbe(object): if possible_conf.exists(): return possible_conf.absolute() - if model_type == ModelType.Main: + if model_type is ModelType.Main: config_file = LEGACY_CONFIGS[base_type][variant_type] if isinstance(config_file, dict): # need another tier for sd-2.x models config_file = config_file[prediction_type] - elif model_type == ModelType.ControlNet: + elif model_type is ModelType.ControlNet: config_file = ( - "../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml" + "../controlnet/cldm_v15.yaml" + if base_type is BaseModelType.StableDiffusion1 + else "../controlnet/cldm_v21.yaml" + ) + elif model_type is ModelType.Vae: + config_file = ( + "../stable-diffusion/v1-inference.yaml" + if base_type is BaseModelType.StableDiffusion1 + else "../stable-diffusion/v2-inference.yaml" ) else: raise InvalidModelConfigException(