From b03073d888b5f8cd77a119a4d8612c235c480859 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 19 Jun 2024 22:57:27 -0400 Subject: [PATCH] [MM] Add support for probing and loading SDXL VAE checkpoint files (#6524) * add support for probing and loading SDXL VAE checkpoint files * broaden regexp probe for SDXL VAEs --------- Co-authored-by: Lincoln Stein --- .../backend/model_manager/load/model_loaders/vae.py | 11 +++-------- invokeai/backend/model_manager/probe.py | 12 ++++++++++-- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py index 122b2f0797..f51c551f09 100644 --- a/invokeai/backend/model_manager/load/model_loaders/vae.py +++ b/invokeai/backend/model_manager/load/model_loaders/vae.py @@ -22,8 +22,7 @@ from .generic_diffusers import GenericDiffusersLoader @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers) -@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint) -@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Checkpoint) class VAELoader(GenericDiffusersLoader): """Class to load VAE models.""" @@ -40,12 +39,8 @@ class VAELoader(GenericDiffusersLoader): return True def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel: - # TODO(MM2): check whether sdxl VAE models convert. - if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: - raise Exception(f"VAE conversion not supported for model type: {config.base}") - else: - assert isinstance(config, CheckpointConfigBase) - config_file = self._app_config.legacy_conf_path / config.config_path + assert isinstance(config, CheckpointConfigBase) + config_file = self._app_config.legacy_conf_path / config.config_path if model_path.suffix == ".safetensors": checkpoint = safetensors_load_file(model_path, device="cpu") diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 4aef281bac..a19a772764 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -451,8 +451,16 @@ class PipelineCheckpointProbe(CheckpointProbeBase): class VaeCheckpointProbe(CheckpointProbeBase): def get_base_type(self) -> BaseModelType: - # I can't find any standalone 2.X VAEs to test with! - return BaseModelType.StableDiffusion1 + # VAEs of all base types have the same structure, so we wimp out and + # guess using the name. + for regexp, basetype in [ + (r"xl", BaseModelType.StableDiffusionXL), + (r"sd2", BaseModelType.StableDiffusion2), + (r"vae", BaseModelType.StableDiffusion1), + ]: + if re.search(regexp, self.model_path.name, re.IGNORECASE): + return basetype + raise InvalidModelConfigException("Cannot determine base type") class LoRACheckpointProbe(CheckpointProbeBase):