import os
import json
from enum import Enum
from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
from .base import (
    ModelConfigBase,
    BaseModelType,
    ModelType,
    ModelVariantType,
    DiffusersModel,
    SilenceWarnings,
    read_checkpoint_meta,
    classproperty,
    InvalidModelException,
    ModelNotFoundException,
)
from .sdxl import StableDiffusionXLModel
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf

# File suffixes (with leading dot) recognised as single-file checkpoints.
_CHECKPOINT_SUFFIXES = (".safetensors", ".ckpt", ".pt")


def _read_checkpoint_in_channels(path: str, ckpt_config_path: Optional[str]):
    """Return the UNet ``in_channels`` value for a checkpoint model.

    When ``ckpt_config_path`` is given, the value is read from that legacy
    config file (``model.params.unet_config.params.in_channels``). Otherwise
    it is taken from dimension 1 of the first input-block weight found in the
    checkpoint metadata at ``path``.
    """
    if ckpt_config_path:
        ckpt_config = OmegaConf.load(ckpt_config_path)
        # The original code evaluated this expression without assigning it,
        # leaving ``in_channels`` unbound in the caller.
        return ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]

    checkpoint = read_checkpoint_meta(path)
    # Some checkpoints nest their tensors under "state_dict"; others are flat.
    checkpoint = checkpoint.get("state_dict", checkpoint)
    return checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]


def _read_unet_in_channels(unet_config_path: str):
    """Return ``in_channels`` from a diffusers ``unet/config.json`` file."""
    with open(unet_config_path, "r", encoding="utf-8") as f:
        unet_config = json.load(f)
    return unet_config["in_channels"]


class StableDiffusion1ModelFormat(str, Enum):
    """On-disk formats supported for Stable Diffusion 1.x main models."""

    Checkpoint = "checkpoint"
    Diffusers = "diffusers"


class StableDiffusion1Model(DiffusersModel):
    """Stable Diffusion 1.x main model (checkpoint or diffusers format)."""

    class DiffusersConfig(ModelConfigBase):
        model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
        vae: Optional[str] = Field(None)
        variant: ModelVariantType

    class CheckpointConfig(ModelConfigBase):
        model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
        vae: Optional[str] = Field(None)
        config: str
        variant: ModelVariantType

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        # NOTE(review): asserts are stripped under ``python -O``; kept as-is so
        # callers that rely on AssertionError keep working.
        assert base_model == BaseModelType.StableDiffusion1
        assert model_type == ModelType.Main
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion1,
            model_type=ModelType.Main,
        )

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        """Build a model config for the model found at ``path``.

        The format is detected from disk, and the variant (normal / inpaint)
        is derived from the UNet input-channel count. An optional ``config``
        keyword names the legacy checkpoint config file; when it is absent a
        default one is selected for the detected variant.

        Raises:
            NotImplementedError: unsupported diffusers layout or unknown format.
            Exception: the input-channel count matches no known variant.
        """
        model_format = cls.detect_format(path)
        ckpt_config_path = kwargs.get("config", None)

        if model_format == StableDiffusion1ModelFormat.Checkpoint:
            in_channels = _read_checkpoint_in_channels(path, ckpt_config_path)
        elif model_format == StableDiffusion1ModelFormat.Diffusers:
            unet_config_path = os.path.join(path, "unet", "config.json")
            if not os.path.exists(unet_config_path):
                raise NotImplementedError(f"{path} is not a supported stable diffusion diffusers format")
            in_channels = _read_unet_in_channels(unet_config_path)
        else:
            raise NotImplementedError(f"Unknown stable diffusion 1.* format: {model_format}")

        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            raise Exception("Unkown stable diffusion 1.* model format")

        if ckpt_config_path is None:
            ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusion1, variant)

        return cls.create_config(
            path=path,
            model_format=model_format,
            config=ckpt_config_path,
            variant=variant,
        )

    @classproperty
    def save_to_config(cls) -> bool:
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        """Return the format of the model at ``model_path``.

        A directory containing ``model_index.json`` is a diffusers model; a
        file ending in ``.safetensors``, ``.ckpt`` or ``.pt`` is a checkpoint.

        Raises:
            ModelNotFoundException: ``model_path`` does not exist.
            InvalidModelException: the path matches neither format.
        """
        if not os.path.exists(model_path):
            raise ModelNotFoundException()

        if os.path.isdir(model_path):
            if os.path.exists(os.path.join(model_path, "model_index.json")):
                return StableDiffusion1ModelFormat.Diffusers

        if os.path.isfile(model_path):
            if model_path.endswith(_CHECKPOINT_SUFFIXES):
                return StableDiffusion1ModelFormat.Checkpoint

        raise InvalidModelException(f"Not a valid model: {model_path}")

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        """Convert a checkpoint model to diffusers format if needed.

        Returns the path of the converted model for checkpoint configs, and
        ``model_path`` unchanged otherwise.
        """
        if isinstance(config, cls.CheckpointConfig):
            return _convert_ckpt_and_cache(
                version=BaseModelType.StableDiffusion1,
                model_config=config,
                output_path=output_path,
            )
        else:
            return model_path


class StableDiffusion2ModelFormat(str, Enum):
    """On-disk formats supported for Stable Diffusion 2.x main models."""

    Checkpoint = "checkpoint"
    Diffusers = "diffusers"


class StableDiffusion2Model(DiffusersModel):
    """Stable Diffusion 2.x main model (checkpoint or diffusers format)."""

    # TODO: check that configs overwritten properly
    class DiffusersConfig(ModelConfigBase):
        model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
        vae: Optional[str] = Field(None)
        variant: ModelVariantType

    class CheckpointConfig(ModelConfigBase):
        model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
        vae: Optional[str] = Field(None)
        config: str
        variant: ModelVariantType

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        # NOTE(review): asserts are stripped under ``python -O``; kept as-is so
        # callers that rely on AssertionError keep working.
        assert base_model == BaseModelType.StableDiffusion2
        assert model_type == ModelType.Main
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion2,
            model_type=ModelType.Main,
        )

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        """Build a model config for the model found at ``path``.

        The format is detected from disk, and the variant (normal / depth /
        inpaint) is derived from the UNet input-channel count. An optional
        ``config`` keyword names the legacy checkpoint config file; when it is
        absent a default one is selected for the detected variant.

        Raises:
            NotImplementedError: unknown model format.
            Exception: unsupported diffusers layout, or the input-channel
                count matches no known variant.
        """
        model_format = cls.detect_format(path)
        ckpt_config_path = kwargs.get("config", None)

        if model_format == StableDiffusion2ModelFormat.Checkpoint:
            in_channels = _read_checkpoint_in_channels(path, ckpt_config_path)
        elif model_format == StableDiffusion2ModelFormat.Diffusers:
            unet_config_path = os.path.join(path, "unet", "config.json")
            if not os.path.exists(unet_config_path):
                raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
            in_channels = _read_unet_in_channels(unet_config_path)
        else:
            raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")

        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 5:
            variant = ModelVariantType.Depth
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            raise Exception("Unkown stable diffusion 2.* model format")

        if ckpt_config_path is None:
            ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusion2, variant)

        return cls.create_config(
            path=path,
            model_format=model_format,
            config=ckpt_config_path,
            variant=variant,
        )

    @classproperty
    def save_to_config(cls) -> bool:
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        """Return the format of the model at ``model_path``.

        A directory containing ``model_index.json`` is a diffusers model; a
        file ending in ``.safetensors``, ``.ckpt`` or ``.pt`` is a checkpoint.

        Raises:
            ModelNotFoundException: ``model_path`` does not exist.
            InvalidModelException: the path matches neither format.
        """
        if not os.path.exists(model_path):
            raise ModelNotFoundException()

        if os.path.isdir(model_path):
            if os.path.exists(os.path.join(model_path, "model_index.json")):
                return StableDiffusion2ModelFormat.Diffusers

        if os.path.isfile(model_path):
            if model_path.endswith(_CHECKPOINT_SUFFIXES):
                return StableDiffusion2ModelFormat.Checkpoint

        raise InvalidModelException(f"Not a valid model: {model_path}")

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        """Convert a checkpoint model to diffusers format if needed.

        Returns the path of the converted model for checkpoint configs, and
        ``model_path`` unchanged otherwise.
        """
        if isinstance(config, cls.CheckpointConfig):
            return _convert_ckpt_and_cache(
                version=BaseModelType.StableDiffusion2,
                model_config=config,
                output_path=output_path,
            )
        else:
            return model_path


# TODO: rework
# pass precision - currently defaulting to fp16
def _convert_ckpt_and_cache(
    version: BaseModelType,
    model_config: Union[
        StableDiffusion1Model.CheckpointConfig,
        StableDiffusion2Model.CheckpointConfig,
        StableDiffusionXLModel.CheckpointConfig,
    ],
    output_path: str,
    use_save_model: bool = False,
    **kwargs,
) -> str:
    """
    Convert the checkpoint model indicated in model_config into a diffusers,
    cache it to disk, and return Path to converted file. If already on disk
    then just returns Path.

    ``use_save_model`` is accepted but not used by this function. Extra
    keyword arguments are forwarded to ``convert_ckpt_to_diffusers``.
    """
    # NOTE(review): the annotation says ``str`` but a ``Path`` is returned;
    # left unchanged so callers see the same runtime type as before.
    app_config = InvokeAIAppConfig.get_config()

    weights = app_config.models_path / model_config.path
    config_file = app_config.root_path / model_config.config
    output_path = Path(output_path)

    # return cached version if it exists
    if output_path.exists():
        return output_path

    # to avoid circular import errors
    from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
    from ...util.devices import choose_torch_device, torch_dtype

    model_base_to_model_type = {
        BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
        BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
        BaseModelType.StableDiffusionXL: "SDXL",
        BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
    }
    logger.info(f"Converting {weights} to diffusers format")
    with SilenceWarnings():
        convert_ckpt_to_diffusers(
            weights,
            output_path,
            model_type=model_base_to_model_type[version],
            model_version=version,
            model_variant=model_config.variant,
            original_config_file=config_file,
            extract_ema=True,
            scan_needed=True,
            from_safetensors=weights.suffix == ".safetensors",
            precision=torch_dtype(choose_torch_device()),
            **kwargs,
        )
    return output_path


def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType) -> Optional[str]:
    """Pick the default legacy config file for a base model and variant.

    Returns the config path as a string, made relative to the application
    root when it lies beneath it. Returns ``None`` when no config is
    registered for the combination or the path cannot be resolved.
    """
    ckpt_configs = {
        BaseModelType.StableDiffusion1: {
            ModelVariantType.Normal: "v1-inference.yaml",
            ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
        },
        BaseModelType.StableDiffusion2: {
            # best guess, as we can't differentiate with base(512)
            ModelVariantType.Normal: "v2-inference-v.yaml",
            ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
            ModelVariantType.Depth: "v2-midas-inference.yaml",
        },
        BaseModelType.StableDiffusionXL: {
            ModelVariantType.Normal: "sd_xl_base.yaml",
            ModelVariantType.Inpaint: None,
            ModelVariantType.Depth: None,
        },
        BaseModelType.StableDiffusionXLRefiner: {
            ModelVariantType.Normal: "sd_xl_refiner.yaml",
            ModelVariantType.Inpaint: None,
            ModelVariantType.Depth: None,
        },
    }

    app_config = InvokeAIAppConfig.get_config()

    # Unknown versions/variants and explicit ``None`` entries mean "no config".
    config_name = ckpt_configs.get(version, {}).get(variant)
    if config_name is None:
        return None

    # Best-effort: keep returning None when the path cannot be resolved, but
    # no longer swallow KeyboardInterrupt/SystemExit as the bare except did.
    try:
        config_path = app_config.legacy_conf_path / config_name
        if config_path.is_relative_to(app_config.root_path):
            config_path = config_path.relative_to(app_config.root_path)
    except (AttributeError, TypeError, ValueError):
        return None
    return str(config_path)