diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py index 122b2f0797..f51c551f09 100644 --- a/invokeai/backend/model_manager/load/model_loaders/vae.py +++ b/invokeai/backend/model_manager/load/model_loaders/vae.py @@ -22,8 +22,7 @@ from .generic_diffusers import GenericDiffusersLoader @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers) -@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint) -@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Checkpoint) class VAELoader(GenericDiffusersLoader): """Class to load VAE models.""" @@ -40,12 +39,8 @@ class VAELoader(GenericDiffusersLoader): return True def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel: - # TODO(MM2): check whether sdxl VAE models convert. - if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: - raise Exception(f"VAE conversion not supported for model type: {config.base}") - else: - assert isinstance(config, CheckpointConfigBase) - config_file = self._app_config.legacy_conf_path / config.config_path + assert isinstance(config, CheckpointConfigBase) + config_file = self._app_config.legacy_conf_path / config.config_path if model_path.suffix == ".safetensors": checkpoint = safetensors_load_file(model_path, device="cpu") diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 4aef281bac..a19a772764 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -451,8 +451,16 @@ class PipelineCheckpointProbe(CheckpointProbeBase): class VaeCheckpointProbe(CheckpointProbeBase): def get_base_type(self) -> BaseModelType: - # I can't find any standalone 2.X VAEs to test with! - return BaseModelType.StableDiffusion1 + # VAEs of all base types have the same structure, so we wimp out and + # guess using the name. + for regexp, basetype in [ + (r"xl", BaseModelType.StableDiffusionXL), + (r"sd2", BaseModelType.StableDiffusion2), + (r"vae", BaseModelType.StableDiffusion1), + ]: + if re.search(regexp, self.model_path.name, re.IGNORECASE): + return basetype + raise InvalidModelConfigException("Cannot determine base type") class LoRACheckpointProbe(CheckpointProbeBase):