"""Model classes for Stable Diffusion 1.x and 2.x main models, covering both checkpoint and diffusers formats."""

import os
import json
from enum import Enum
from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
from .base import (
    ModelBase,
    ModelConfigBase,
    BaseModelType,
    ModelType,
    SubModelType,
    ModelVariantType,
    DiffusersModel,
    SchedulerPredictionType,
    SilenceWarnings,
    read_checkpoint_meta,
    classproperty,
)
from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf


class StableDiffusion1ModelFormat(str, Enum):
    Checkpoint = "checkpoint"
    Diffusers = "diffusers"


class StableDiffusion1Model(DiffusersModel):
    class DiffusersConfig(ModelConfigBase):
        model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
        vae: Optional[str] = Field(None)
        variant: ModelVariantType

    class CheckpointConfig(ModelConfigBase):
        model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
        vae: Optional[str] = Field(None)
        config: str
        variant: ModelVariantType

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        assert base_model == BaseModelType.StableDiffusion1
        assert model_type == ModelType.Main
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion1,
            model_type=ModelType.Main,
        )

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        model_format = cls.detect_format(path)
        ckpt_config_path = kwargs.get("config", None)
        if model_format == StableDiffusion1ModelFormat.Checkpoint:
            if ckpt_config_path:
                ckpt_config = OmegaConf.load(ckpt_config_path)
                in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
            else:
                checkpoint = read_checkpoint_meta(path)
                checkpoint = checkpoint.get("state_dict", checkpoint)
                in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]

        elif model_format == StableDiffusion1ModelFormat.Diffusers:
            unet_config_path = os.path.join(path, "unet", "config.json")
            if os.path.exists(unet_config_path):
                with open(unet_config_path, "r") as f:
                    unet_config = json.load(f)
                in_channels = unet_config["in_channels"]
            else:
                raise NotImplementedError(f"{path} is not a supported stable diffusion diffusers format")

        else:
            raise NotImplementedError(f"Unknown stable diffusion 1.* format: {model_format}")

        # The UNet's input channel count distinguishes the inpainting variant (9) from the normal one (4).
        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            raise Exception("Unknown stable diffusion 1.* model format")

        if ckpt_config_path is None:
            ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusion1, variant)

        return cls.create_config(
            path=path,
            model_format=model_format,
            config=ckpt_config_path,
            variant=variant,
        )

    @classproperty
    def save_to_config(cls) -> bool:
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        if os.path.isdir(model_path):
            return StableDiffusion1ModelFormat.Diffusers
        else:
            return StableDiffusion1ModelFormat.Checkpoint

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        if isinstance(config, cls.CheckpointConfig):
            return _convert_ckpt_and_cache(
                version=BaseModelType.StableDiffusion1,
                model_config=config,
                output_path=output_path,
            )
        else:
            return model_path


class StableDiffusion2ModelFormat(str, Enum):
    Checkpoint = "checkpoint"
    Diffusers = "diffusers"


class StableDiffusion2Model(DiffusersModel):
    # TODO: check that configs are overwritten properly
    class DiffusersConfig(ModelConfigBase):
        model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
        vae: Optional[str] = Field(None)
        variant: ModelVariantType

    class CheckpointConfig(ModelConfigBase):
        model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
        vae: Optional[str] = Field(None)
        config: str
        variant: ModelVariantType

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        assert base_model == BaseModelType.StableDiffusion2
        assert model_type == ModelType.Main
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion2,
            model_type=ModelType.Main,
        )

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        model_format = cls.detect_format(path)
        ckpt_config_path = kwargs.get("config", None)
        if model_format == StableDiffusion2ModelFormat.Checkpoint:
            if ckpt_config_path:
                ckpt_config = OmegaConf.load(ckpt_config_path)
                in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
            else:
                checkpoint = read_checkpoint_meta(path)
                checkpoint = checkpoint.get("state_dict", checkpoint)
                in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]

        elif model_format == StableDiffusion2ModelFormat.Diffusers:
            unet_config_path = os.path.join(path, "unet", "config.json")
            if os.path.exists(unet_config_path):
                with open(unet_config_path, "r") as f:
                    unet_config = json.load(f)
                in_channels = unet_config["in_channels"]
            else:
                raise Exception("Not a supported stable diffusion diffusers format (possibly ONNX?)")

        else:
            raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")

        # The UNet's input channel count distinguishes the inpainting (9), depth (5) and normal (4) variants.
        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 5:
            variant = ModelVariantType.Depth
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            raise Exception("Unknown stable diffusion 2.* model format")

        if ckpt_config_path is None:
            ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusion2, variant)

        return cls.create_config(
            path=path,
            model_format=model_format,
            config=ckpt_config_path,
            variant=variant,
        )

    @classproperty
    def save_to_config(cls) -> bool:
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        if os.path.isdir(model_path):
            return StableDiffusion2ModelFormat.Diffusers
        else:
            return StableDiffusion2ModelFormat.Checkpoint

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        if isinstance(config, cls.CheckpointConfig):
            return _convert_ckpt_and_cache(
                version=BaseModelType.StableDiffusion2,
                model_config=config,
                output_path=output_path,
            )
        else:
            return model_path


def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
    ckpt_configs = {
        BaseModelType.StableDiffusion1: {
            ModelVariantType.Normal: "v1-inference.yaml",
            ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
        },
        BaseModelType.StableDiffusion2: {
            ModelVariantType.Normal: "v2-inference-v.yaml",  # best guess, as we can't differentiate from the base (512) model
            ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
            ModelVariantType.Depth: "v2-midas-inference.yaml",
        },
    }

    app_config = InvokeAIAppConfig.get_config()
    try:
        config_path = app_config.legacy_conf_path / ckpt_configs[version][variant]
        if config_path.is_relative_to(app_config.root_path):
            config_path = config_path.relative_to(app_config.root_path)
        return str(config_path)
    except Exception:
        return None


# TODO: rework
def _convert_ckpt_and_cache(
    version: BaseModelType,
    model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig],
    output_path: str,
) -> str:
    """
    Convert the checkpoint model indicated in model_config into diffusers format,
    cache it to disk, and return the path to the converted model.
    If it is already on disk, just return the path.
    """
    app_config = InvokeAIAppConfig.get_config()

    weights = app_config.root_path / model_config.path
    config_file = app_config.root_path / model_config.config
    output_path = Path(output_path)

    # return the cached version if it exists
    if output_path.exists():
        return output_path

    # imported here to avoid circular import errors
    from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers

    with SilenceWarnings():
        convert_ckpt_to_diffusers(
            weights,
            output_path,
            model_version=version,
            model_variant=model_config.variant,
            original_config_file=config_file,
            extract_ema=True,
            scan_needed=True,
        )
    return output_path