diff --git a/invokeai/backend/__init__.py b/invokeai/backend/__init__.py
index 61351a1851..55782bc445 100644
--- a/invokeai/backend/__init__.py
+++ b/invokeai/backend/__init__.py
@@ -9,5 +9,8 @@ from .generator import (
     Img2Img,
     Inpaint
 )
-from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubmodelType, ModelInfo
+from .model_management import (
+    ModelManager, ModelCache, BaseModelType,
+    ModelType, SubModelType, ModelInfo
+    )
 from .safety_checker import SafetyChecker
diff --git a/invokeai/backend/model_management/model_install.py b/invokeai/backend/model_management/model_install.py
index e1ac441fee..6016f5f3f5 100644
--- a/invokeai/backend/model_management/model_install.py
+++ b/invokeai/backend/model_management/model_install.py
@@ -11,191 +11,87 @@ from diffusers import ModelMixin
 from enum import Enum
 from typing import Callable
 from pathlib import Path
-from picklescan.scanner import scan_file_path
 import invokeai.backend.util.logging as logger
+from invokeai.app.services.config import InvokeAIAppConfig
 from .models import BaseModelType, ModelType
+from .model_probe import ModelProbe, ModelVariantInfo
 
-class CheckpointProbe(object):
-    PROBES = dict() # see below for redefinition
-
+class ModelInstall(object):
+    '''
+    This class downloads and installs several different kinds of InvokeAI
+    models. The helper function, if provided, is called to distinguish
+    between v2-base and v2-768 stable diffusion pipelines. This usually
+    involves asking the user to select the proper type, as there is no way
+    of distinguishing the two kinds of v2 file programmatically (as far as
+    I know).
+    '''
     def __init__(self,
-                 checkpoint_path: Path,
-                 checkpoint: dict = None,
-                 helper: Callable[[Path], BaseModelType]=None
+                 config: InvokeAIAppConfig,
+                 model_base_helper: Callable[[Path], BaseModelType]=None,
+                 clobber: bool = False
                  ):
-        checkpoint = checkpoint or self._scan_and_load_checkpoint(self.checkpoint_path)
-        self.checkpoint = checkpoint
-        self.checkpoint_path = checkpoint_path
-        self.helper = helper
+        '''
+        :param config: InvokeAI configuration object
+        :param model_base_helper: A function that accepts the Path to a checkpoint model and returns a BaseModelType enum
+        :param clobber: If true, models with colliding names will be overwritten
+        '''
+        self.config = config
+        self.clobber = clobber
+        self.helper = model_base_helper
+        self.prober = ModelProbe()
+
+    def install_checkpoint_file(self, checkpoint: Path)->dict:
+        '''
+        Install the checkpoint file at the given path and return a
+        configuration entry that can be added to `models.yaml`.
+        Model checkpoints and VAEs will be converted into diffusers
+        format before installation. Note that the model manager
+        does not hold entries for anything but diffusers pipelines,
+        so the configuration stanzas returned for non-pipeline models
+        can be safely ignored.
+        '''
+        model_info = self.prober.probe(checkpoint, self.helper)
+        if not model_info:
+            raise ValueError(f"Unable to determine type of checkpoint file {checkpoint}")
+
+        # non-pipeline; no conversion needed, just copy into right place
+        if model_info.model_type != ModelType.Pipeline:
+            model_name = checkpoint.stem
+            base_model = model_info.base_type
+            model_type = model_info.model_type
+            destination_path = self._dest_path(model_info) / checkpoint.name
+            self._check_for_collision(destination_path)
+            shutil.copyfile(checkpoint, destination_path)
+            key = ModelManager.create_key(
+                model_name = model_name,
+                base_model = base_model,
+                model_type = model_type,
+            )
+            return {
+                key: dict(
+                    name = model_name,
+                    description = f'{model_type} model {model_name}',
+                    path = str(destination_path),
+                    format = 'checkpoint',
+                    base = str(base_model),
+                    type = str(model_type),
+                    variant = str(model_info.variant_type),
+                )
+            }
+
+        # TODO: pipeline checkpoints are converted into diffusers format before
+        # installation (see docstring above); the conversion step is not yet
+        # wired up here.
+        destination_path = self._dest_path(model_info) / checkpoint.stem
+
+    def _check_for_collision(self, path: Path):
+        if not path.exists():
+            return
+        if self.clobber:
+            shutil.rmtree(path)
+        else:
+            raise ValueError(f"Destination {path} already exists. Won't overwrite unless clobber=True.")
+
+    def _staging_directory(self)->tempfile.TemporaryDirectory:
+        return tempfile.TemporaryDirectory(dir=self.config.root_path)
+
 
-    def probe(self) -> ModelVariantInfo:
-        '''
-        Probes the checkpoint at path `checkpoint_path` and return
-        a ModelType object indicating the model base, model type and
-        model variant for the checkpoint.
-        '''
-        checkpoint = self.checkpoint
-        state_dict = checkpoint.get("state_dict") or checkpoint
-
-        model_info = None
-
-        try:
-            model_type = self.get_checkpoint_type(state_dict)
-            if not model_type:
-                if self.checkpoint_path.name == "learned_embeds.bin":
-                    model_type = ModelType.TextualInversion
-                else:
-                    return None # we give up
-            probe = self.PROBES[model_type]()
-            base_type = probe.get_base_type(checkpoint, self.checkpoint_path, self.helper)
-            variant_type = probe.get_variant_type(model_type, checkpoint)
-
-            model_info = ModelVariantInfo(
-                model_type = model_type,
-                base_type = base_type,
-                variant_type = variant_type,
-            )
-        except (KeyError, ValueError) as e:
-            logger.error(f'An error occurred while probing {self.checkpoint_path}: {str(e)}')
-            logger.error(traceback.format_exc())
-
-        return model_info
-
-    class CheckpointProbeBase(object):
-        def get_base_type(self,
-                          checkpoint: dict,
-                          checkpoint_path: Path = None,
-                          helper: Callable[[Path],BaseModelType] = None
-                          )->BaseModelType:
-            pass
-        def get_variant_type(self,
-                             model_type: ModelType,
-                             checkpoint: dict,
-                             )-> VariantType:
-            if model_type != ModelType.Pipeline:
-                return None
-            state_dict = checkpoint.get('state_dict') or checkpoint
-            in_channels = state_dict[
-                "model.diffusion_model.input_blocks.0.0.weight"
-            ].shape[1]
-            if in_channels == 9:
-                return VariantType.Inpaint
-            elif in_channels == 5:
-                return VariantType.depth
-            else:
-                return None
-
-
-    class CheckpointProbe(CheckpointProbeBase):
-        def get_base_type(self,
-                          checkpoint: dict,
-                          checkpoint_path: Path = None,
-                          helper: Callable[[Path],BaseModelType] = None
-                          )->BaseModelType:
-            state_dict = checkpoint.get('state_dict') or checkpoint
-            key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
-            if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
-                return BaseModelType.StableDiffusion1_5
-            if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
-                if 'global_step' in checkpoint:
-                    if checkpoint['global_step'] == 220000:
-                        return
BaseModelType.StableDiffusion2Base - elif checkpoint["global_step"] == 110000: - return BaseModelType.StableDiffusion2 - if checkpoint_path and helper: - return helper(checkpoint_path) - else: - return None - - class VaeProbe(CheckpointProbeBase): - def get_base_type(self, - checkpoint: dict, - checkpoint_path: Path = None, - helper: Callable[[Path],BaseModelType] = None - )->BaseModelType: - # I can't find any standalone 2.X VAEs to test with! - return BaseModelType.StableDiffusion1_5 - - class LoRAProbe(CheckpointProbeBase): - def get_base_type(self, - checkpoint: dict, - checkpoint_path: Path = None, - helper: Callable[[Path],BaseModelType] = None - )->BaseModelType: - key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight" - key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a" - lora_token_vector_length = ( - checkpoint[key1].shape[1] - if key1 in checkpoint - else checkpoint[key2].shape[0] - if key2 in checkpoint - else 768 - ) - if lora_token_vector_length == 768: - return BaseModelType.StableDiffusion1_5 - elif lora_token_vector_length == 1024: - return BaseModelType.StableDiffusion2 - else: - return None - - class TextualInversionProbe(CheckpointProbeBase): - def get_base_type(self, - checkpoint: dict, - checkpoint_path: Path = None, - helper: Callable[[Path],BaseModelType] = None - )->BaseModelType: - - if 'string_to_token' in checkpoint: - token_dim = list(checkpoint['string_to_param'].values())[0].shape[-1] - elif 'emb_params' in checkpoint: - token_dim = checkpoint['emb_params'].shape[-1] - else: - token_dim = list(checkpoint.values())[0].shape[0] - if token_dim == 768: - return BaseModelType.StableDiffusion1_5 - elif token_dim == 1024: - return BaseModelType.StableDiffusion2Base - else: - return None - - class ControlNetProbe(CheckpointProbeBase): - def get_base_type(self, - checkpoint: dict, - checkpoint_path: Path = None, - helper: Callable[[Path],BaseModelType] = None - )->BaseModelType: - for key_name in ('control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight', - 'input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight' - ): - if key_name not in checkpoint: - continue - if checkpoint[key_name].shape[-1] == 768: - return BaseModelType.StableDiffusion1_5 - elif checkpoint_path and helper: - return helper(checkpoint_path) - PROBES = { - ModelType.Pipeline: CheckpointProbe, - ModelType.Vae: VaeProbe, - ModelType.Lora: LoRAProbe, - ModelType.TextualInversion: TextualInversionProbe, - ModelType.ControlNet: ControlNetProbe, - } - - @classmethod - def get_checkpoint_type(cls, state_dict: dict) -> ModelType: - if any([x.startswith("model.diffusion_model") for x in state_dict.keys()]): - return ModelType.Pipeline - if any([x.startswith("encoder.conv_in") for x in state_dict.keys()]): - return ModelType.Vae - if "string_to_token" in state_dict or "emb_params" in state_dict: - return ModelType.TextualInversion - if any([x.startswith("lora") for x in state_dict.keys()]): - return ModelType.Lora - if any([x.startswith("control_model") for x in state_dict.keys()]): - return ModelType.ControlNet - if any([x.startswith("input_blocks") for x in state_dict.keys()]): - return ModelType.ControlNet - return None # give up - diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 41d5c09ed5..4ab2381109 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -331,8 +331,9 @@ class ModelManager(object): model_key = 
self.create_key(model_name, base_model, model_type) return model_key in self.models + @classmethod def create_key( - self, + cls, model_name: str, base_model: BaseModelType, model_type: ModelType, diff --git a/invokeai/backend/model_management/models.py b/invokeai/backend/model_management/models.py deleted file mode 100644 index c09f6f6c30..0000000000 --- a/invokeai/backend/model_management/models.py +++ /dev/null @@ -1,728 +0,0 @@ -import sys -from dataclasses import dataclass -from enum import Enum -import torch -import safetensors.torch -from diffusers import ConfigMixin -from diffusers.utils import is_safetensors_available -from omegaconf import DictConfig -from pathlib import Path -from pydantic import BaseModel, Field, root_validator -from typing import Union, List, Type, Optional - -class BaseModelType(str, Enum): - # TODO: maybe then add sample size(512/768)? - StableDiffusion1_5 = "SD-1" - StableDiffusion2Base = "SD-2-base" # 512 pixels; this will have epsilon parameterization - StableDiffusion2 = "SD-2" # 768 pixels; this will have v-prediction parameterization - #Kandinsky2_1 = "kandinsky_2_1" - -class ModelType(str, Enum): - Pipeline = "pipeline" - Classifier = "classifier" - Vae = "vae" - Lora = "lora" - ControlNet = "controlnet" - TextualInversion = "embedding" - -class SubModelType: - UNet = "unet" - TextEncoder = "text_encoder" - Tokenizer = "tokenizer" - Vae = "vae" - Scheduler = "scheduler" - SafetyChecker = "safety_checker" - FeatureExtractor = "feature_extractor" - #MoVQ = "movq" - -class VariantType(str, Enum): - Normal = "normal" - Inpaint = "inpaint" - Depth = "depth" - -class EmptyConfigLoader(ConfigMixin): - @classmethod - def load_config(cls, *args, **kwargs): - cls.config_name = kwargs.pop("config_name") - return super().load_config(*args, **kwargs) - -class ModelBase: - #model_path: str - #base_model: BaseModelType - #model_type: ModelType - - def __init__( - self, - model_path: str, - base_model: BaseModelType, - model_type: ModelType, - ): - self.model_path = model_path - self.base_model = base_model - self.model_type = model_type - - def _hf_definition_to_type(self, subtypes: List[str]) -> Type: - if len(subtypes) < 2: - raise Exception("Invalid subfolder definition!") - if subtypes[0] in ["diffusers", "transformers"]: - res_type = sys.modules[subtypes[0]] - subtypes = subtypes[1:] - - else: - res_type = sys.modules["diffusers"] - res_type = getattr(res_type, "pipelines") - - - for subtype in subtypes: - res_type = getattr(res_type, subtype) - return res_type - - -class DiffusersModel(ModelBase): - #child_types: Dict[str, Type] - #child_sizes: Dict[str, int] - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - super().__init__(model_path, base_model, model_type) - - self.child_types: Dict[str, Type] = dict() - self.child_sizes: Dict[str, int] = dict() - - try: - config_data = DiffusionPipeline.load_config(self.model_path) - #config_data = json.loads(os.path.join(self.model_path, "model_index.json")) - except: - raise Exception("Invalid diffusers model! 
(model_index.json not found or invalid)") - - config_data.pop("_ignore_files", None) - - # retrieve all folder_names that contain relevant files - child_components = [k for k, v in config_data.items() if isinstance(v, list)] - - for child_name in child_components: - child_type = self._hf_definition_to_type(config_data[child_name]) - self.child_types[child_name] = child_type - self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name) - - - def get_size(self, child_type: Optional[SubModelType] = None): - if child_type is None: - return sum(self.child_sizes.values()) - else: - return self.child_sizes[child_type] - - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ): - # return pipeline in different function to pass more arguments - if child_type is None: - raise Exception("Child model type can't be null on diffusers model") - if child_type not in self.child_types: - return None # TODO: or raise - - if torch_dtype == torch.float16: - variants = ["fp16", None] - else: - variants = [None, "fp16"] - - # TODO: better error handling(differentiate not found from others) - for variant in variants: - try: - # TODO: set cache_dir to /dev/null to be sure that cache not used? - model = self.child_types[child_type].from_pretrained( - self.model_path, - subfolder=child_type.value, - torch_dtype=torch_dtype, - variant=variant, - local_files_only=True, - ) - break - except Exception as e: - print("====ERR LOAD====") - print(f"{variant}: {e}") - - # calc more accurate size - self.child_sizes[child_type] = calc_model_size_by_data(model) - return model - - #def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path: - - -class StableDiffusionModel(DiffusersModel): - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert base_model in { - BaseModelType.StableDiffusion1_5, - BaseModelType.StableDiffusion2, - BaseModelType.StableDiffusion2Base, - } - assert model_type == ModelType.Pipeline - super().__init__(model_path, base_model, model_type) - - @staticmethod - def convert_if_required(model_path: Union[str, Path], dst_path: str, config: Optional[dict]) -> Path: - if not isinstance(model_path, Path): - model_path = Path(model_path) - - # TODO: args - # TODO: set model_path, to config? pass dst_path as arg? - # TODO: check - return _convert_ckpt_and_cache(config) - -class classproperty(object): # pylint: disable=invalid-name - """Class property decorator. 
- - Example usage: - - class MyClass(object): - - @classproperty - def value(cls): - return '123' - - > print MyClass.value - 123 - """ - - def __init__(self, func): - self._func = func - - def __get__(self, owner_self, owner_cls): - return self._func(owner_cls) - -class ModelConfigBase(BaseModel): - path: str # or Path - name: str - description: Optional[str] - - -class StableDiffusionDModel(DiffusersModel): - class Config(ModelConfigBase): - format: str - vae: Optional[str] = Field(None) - config: Optional[str] = Field(None) - - @root_validator - def validator(cls, values): - if values["format"] not in {"checkpoint", "diffusers"}: - raise ValueError(f"Unkown stable diffusion model format: {values['format']}") - if values["config"] is not None and values["format"] != "checkpoint": - raise ValueError(f"Custom config field allowed only in checkpoint stable diffusion model") - return values - - # return config only for checkpoint format - def dict(self, *args, **kwargs): - result = super().dict(*args, **kwargs) - if self.format != "checkpoint": - result.pop("config", None) - return result - - @classproperty - def has_config(self): - return True - - def build_config(self, **kwargs) -> dict: - try: - res = dict( - path=kwargs["path"], - name=kwargs["name"], - description=kwargs.get("description", None), - - format=kwargs["format"], - vae=kwargs.get("vae", None), - ) - if res["format"] not in {"checkpoint", "diffusers"}: - raise Exception(f"Unkonwn stable diffusion model format: {res['format']}") - if res["format"] == "checkpoint": - res["config"] = kwargs.get("config", None) - # TODO: raise if config specified for diffusers? - - return res - - except KeyError as e: - raise Exception(f"Field \"{e.args[0]}\" not found!") - - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert base_model == BaseModelType.StableDiffusion1_5 - assert model_type == ModelType.Pipeline - super().__init__(model_path, base_model, model_type) - - @classmethod - def convert_if_required(cls, model_path: str, dst_path: str, config: Optional[dict]) -> str: - model_config = cls.Config( - **config, - path=model_path, - name="", - ) - - if hasattr(model_config, "config"): - convert_ckpt_and_cache( - model_path=model_path, - dst_path=dst_path, - config=config, - ) - return dst_path - - else: - return model_path - -class StableDiffusion15CheckpointModel(DiffusersModel): - class Cnfig(ModelConfigBase): - vae: Optional[str] = Field(None) - config: Optional[str] = Field(None) - -class StableDiffusion2BaseDiffusersModel(DiffusersModel): - class Config(ModelConfigBase): - vae: Optional[str] = Field(None) - -class StableDiffusion2BaseCheckpointModel(DiffusersModel): - class Cnfig(ModelConfigBase): - vae: Optional[str] = Field(None) - config: Optional[str] = Field(None) - -class StableDiffusion2DiffusersModel(DiffusersModel): - class Config(ModelConfigBase): - vae: Optional[str] = Field(None) - attention_upscale: bool = Field(True) - -class StableDiffusion2CheckpointModel(DiffusersModel): - class Config(ModelConfigBase): - vae: Optional[str] = Field(None) - config: Optional[str] = Field(None) - attention_upscale: bool = Field(True) - - -class ClassifierModel(ModelBase): - #child_types: Dict[str, Type] - #child_sizes: Dict[str, int] - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.Classifier - super().__init__(model_path, base_model, model_type) - - self.child_types: Dict[str, Type] = dict() - self.child_sizes: 
Dict[str, int] = dict() - - try: - main_config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json") - #main_config = json.loads(os.path.join(self.model_path, "config.json")) - except: - raise Exception("Invalid classifier model! (config.json not found or invalid)") - - self._load_tokenizer(main_config) - self._load_text_encoder(main_config) - self._load_feature_extractor(main_config) - - - def _load_tokenizer(self, main_config: dict): - try: - tokenizer_config = EmptyConfigLoader.load_config(self.model_path, config_name="tokenizer_config.json") - #tokenizer_config = json.loads(os.path.join(self.model_path, "tokenizer_config.json")) - except: - raise Exception("Invalid classifier model! (Failed to load tokenizer_config.json)") - - if "tokenizer_class" in tokenizer_config: - tokenizer_class_name = tokenizer_config["tokenizer_class"] - elif "model_type" in main_config: - tokenizer_class_name = transformers.models.auto.tokenization_auto.TOKENIZER_MAPPING_NAMES[main_config["model_type"]] - else: - raise Exception("Invalid classifier model! (Failed to detect tokenizer type)") - - self.child_types[SubModelType.Tokenizer] = self._hf_definition_to_type(["transformers", tokenizer_class_name]) - self.child_sizes[SubModelType.Tokenizer] = 0 - - - def _load_text_encoder(self, main_config: dict): - if "architectures" in main_config and len(main_config["architectures"]) > 0: - text_encoder_class_name = main_config["architectures"][0] - elif "model_type" in main_config: - text_encoder_class_name = transformers.models.auto.modeling_auto.MODEL_FOR_PRETRAINING_MAPPING_NAMES[main_config["model_type"]] - else: - raise Exception("Invalid classifier model! (Failed to detect text_encoder type)") - - self.child_types[SubModelType.TextEncoder] = self._hf_definition_to_type(["transformers", text_encoder_class_name]) - self.child_sizes[SubModelType.TextEncoder] = calc_model_size_by_fs(self.model_path) - - - def _load_feature_extractor(self, main_config: dict): - self.child_sizes[SubModelType.FeatureExtractor] = 0 - try: - feature_extractor_config = EmptyConfigLoader.load_config(self.model_path, config_name="preprocessor_config.json") - except: - return # feature extractor not passed with t5 - - try: - feature_extractor_class_name = feature_extractor_config["feature_extractor_type"] - self.child_types[SubModelType.FeatureExtractor] = self._hf_definition_to_type(["transformers", feature_extractor_class_name]) - except: - raise Exception("Invalid classifier model! 
(Unknown feature_extrator type)") - - - def get_size(self, child_type: Optional[SubModelType] = None): - if child_type is None: - return sum(self.child_sizes.values()) - else: - return self.child_sizes[child_type] - - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ): - if child_type is None: - raise Exception("Child model type can't be null on classififer model") - if child_type not in self.child_types: - return None # TODO: or raise - - model = self.child_types[child_type].from_pretrained( - self.model_path, - subfolder=child_type.value, - torch_dtype=torch_dtype, - ) - # calc more accurate size - self.child_sizes[child_type] = calc_model_size_by_data(model) - return model - - @staticmethod - def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path: - if not isinstance(model_path, Path): - model_path = Path(model_path) - return model_path - -class VaeModel(ModelBase): - #vae_class: Type - #model_size: int - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.Vae - super().__init__(model_path, base_model, model_type) - - try: - config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json") - #config = json.loads(os.path.join(self.model_path, "config.json")) - except: - raise Exception("Invalid vae model! (config.json not found or invalid)") - - try: - vae_class_name = config.get("_class_name", "AutoencoderKL") - self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name]) - self.model_size = calc_model_size_by_fs(self.model_path) - except: - raise Exception("Invalid vae model! (Unkown vae type)") - - def get_size(self, child_type: Optional[SubModelType] = None): - if child_type is not None: - raise Exception("There is no child models in vae model") - return self.model_size - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ): - if child_type is not None: - raise Exception("There is no child models in vae model") - - model = self.vae_class.from_pretrained( - self.model_path, - torch_dtype=torch_dtype, - ) - # calc more accurate size - self.model_size = calc_model_size_by_data(model) - return model - - @staticmethod - def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path: - if not isinstance(model_path, Path): - model_path = Path(model_path) - # TODO: - #_convert_vae_ckpt_and_cache - raise Exception("TODO: ") - - -class LoRAModel(ModelBase): - #model_size: int - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.Lora - super().__init__(model_path, base_model, model_type) - - self.model_size = os.path.getsize(self.model_path) - - def get_size(self, child_type: Optional[SubModelType] = None): - if child_type is not None: - raise Exception("There is no child models in lora") - return self.model_size - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ): - if child_type is not None: - raise Exception("There is no child models in lora") - - model = LoRAModel.from_checkpoint( - file_path=self.model_path, - dtype=torch_dtype, - ) - - self.model_size = model.calc_size() - return model - - @staticmethod - def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path: - if not isinstance(model_path, Path): - 
model_path = Path(model_path) - - # TODO: add diffusers lora when it stabilizes a bit - return model_path - -class TextualInversionModel(ModelBase): - #model_size: int - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.TextualInversion - super().__init__(model_path, base_model, model_type) - - self.model_size = os.path.getsize(self.model_path) - - def get_size(self, child_type: Optional[SubModelType] = None): - if child_type is not None: - raise Exception("There is no child models in textual inversion") - return self.model_size - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ): - if child_type is not None: - raise Exception("There is no child models in textual inversion") - - model = TextualInversionModel.from_checkpoint( - file_path=self.model_path, - dtype=torch_dtype, - ) - - self.model_size = model.embedding.nelement() * model.embedding.element_size() - return model - - @staticmethod - def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path: - if not isinstance(model_path, Path): - model_path = Path(model_path) - return model_path - -class ControlNetModel(ModelBase): - """requires implementation""" - pass - -def calc_model_size_by_fs( - model_path: str, - subfolder: Optional[str] = None, - variant: Optional[str] = None -): - if subfolder is not None: - model_path = os.path.join(model_path, subfolder) - - # this can happen when, for example, the safety checker - # is not downloaded. - if not os.path.exists(model_path): - return 0 - - all_files = os.listdir(model_path) - all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))] - - fp16_files = set([f for f in all_files if ".fp16." in f or ".fp16-" in f]) - bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f]) - other_files = set(all_files) - fp16_files - bit8_files - - if variant is None: - files = other_files - elif variant == "fp16": - files = fp16_files - elif variant == "8bit": - files = bit8_files - else: - raise NotImplementedError(f"Unknown variant: {variant}") - - # try read from index if exists - index_postfix = ".index.json" - if variant is not None: - index_postfix = f".index.{variant}.json" - - for file in files: - if not file.endswith(index_postfix): - continue - try: - with open(os.path.join(model_path, file), "r") as f: - index_data = json.loads(f.read()) - return int(index_data["metadata"]["total_size"]) - except: - pass - - # calculate files size if there is no index file - formats = [ - (".safetensors",), # safetensors - (".bin",), # torch - (".onnx", ".pb"), # onnx - (".msgpack",), # flax - (".ckpt",), # tf - (".h5",), # tf2 - ] - - for file_format in formats: - model_files = [f for f in files if f.endswith(file_format)] - if len(model_files) == 0: - continue - - model_size = 0 - for model_file in model_files: - file_stats = os.stat(os.path.join(model_path, model_file)) - model_size += file_stats.st_size - return model_size - - #raise NotImplementedError(f"Unknown model structure! 
Files: {all_files}") - return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu - - -def calc_model_size_by_data(model) -> int: - if isinstance(model, DiffusionPipeline): - return _calc_pipeline_by_data(model) - elif isinstance(model, torch.nn.Module): - return _calc_model_by_data(model) - else: - return 0 - - -def _calc_pipeline_by_data(pipeline) -> int: - res = 0 - for submodel_key in pipeline.components.keys(): - submodel = getattr(pipeline, submodel_key) - if submodel is not None and isinstance(submodel, torch.nn.Module): - res += _calc_model_by_data(submodel) - return res - - -def _calc_model_by_data(model) -> int: - mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()]) - mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()]) - mem = mem_params + mem_bufs # in bytes - return mem - - -def _convert_ckpt_and_cache(self, mconfig: DictConfig) -> Path: - """ - Convert the checkpoint model indicated in mconfig into a - diffusers, cache it to disk, and return Path to converted - file. If already on disk then just returns Path. - """ - app_config = InvokeAIAppConfig.get_config() - weights = app_config.root_dir / mconfig.path - config_file = app_config.root_dir / mconfig.config - diffusers_path = app_config.converted_ckpts_dir / weights.stem - - # return cached version if it exists - if diffusers_path.exists(): - return diffusers_path - - # TODO: I think that it more correctly to convert with embedded vae - # as if user will delete custom vae he will got not embedded but also custom vae - #vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig) - vae_ckpt_path, vae_model = None, None - - # to avoid circular import errors - from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers - with SilenceWarnings(): - convert_ckpt_to_diffusers( - weights, - diffusers_path, - extract_ema=True, - original_config_file=config_file, - vae=vae_model, - vae_path=str(app_config.root_dir / vae_ckpt_path) if vae_ckpt_path else None, - scan_needed=True, - ) - return diffusers_path - -def _convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> Path: - """ - Convert the VAE indicated in mconfig into a diffusers AutoencoderKL - object, cache it to disk, and return Path to converted - file. If already on disk then just returns Path. 
- """ - app_config = InvokeAIAppConfig.get_config() - root = app_config.root_dir - weights_file = root / mconfig.path - config_file = root / mconfig.config - diffusers_path = app_config.converted_ckpts_dir / weights_file.stem - image_size = mconfig.get('width') or mconfig.get('height') or 512 - - # return cached version if it exists - if diffusers_path.exists(): - return diffusers_path - - # this avoids circular import error - from .convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers - if weights_file.suffix == '.safetensors': - checkpoint = safetensors.torch.load_file(weights_file) - else: - checkpoint = torch.load(weights_file, map_location="cpu") - - # sometimes weights are hidden under "state_dict", and sometimes not - if "state_dict" in checkpoint: - checkpoint = checkpoint["state_dict"] - - config = OmegaConf.load(config_file) - - vae_model = convert_ldm_vae_to_diffusers( - checkpoint = checkpoint, - vae_config = config, - image_size = image_size - ) - vae_model.save_pretrained( - diffusers_path, - safe_serialization=is_safetensors_available() - ) - return diffusers_path - -MODEL_CLASSES = { - BaseModelType.StableDiffusion1_5: { - ModelType.Pipeline: StableDiffusionModel, - ModelType.Classifier: ClassifierModel, - ModelType.Vae: VaeModel, - ModelType.Lora: LoRAModel, - ModelType.ControlNet: ControlNetModel, - ModelType.TextualInversion: TextualInversionModel, - }, - BaseModelType.StableDiffusion2: { - ModelType.Pipeline: StableDiffusionModel, - ModelType.Classifier: ClassifierModel, - ModelType.Vae: VaeModel, - ModelType.Lora: LoRAModel, - ModelType.ControlNet: ControlNetModel, - ModelType.TextualInversion: TextualInversionModel, - }, - BaseModelType.StableDiffusion2Base: { - ModelType.Pipeline: StableDiffusionModel, - ModelType.Classifier: ClassifierModel, - ModelType.Vae: VaeModel, - ModelType.Lora: LoRAModel, - ModelType.ControlNet: ControlNetModel, - ModelType.TextualInversion: TextualInversionModel, - }, - #BaseModel.Kandinsky2_1: { - # ModelType.Pipeline: Kandinsky2_1Model, - # ModelType.Classifier: ClassifierModel, - # ModelType.MoVQ: MoVQModel, - # ModelType.Lora: LoraModel, - # ModelType.ControlNet: ControlNetModel, - # ModelType.TextualInversion: TextualInversionModel, - #}, -} - diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py index 0edbd40afd..7415dcac0a 100644 --- a/invokeai/backend/model_management/models/base.py +++ b/invokeai/backend/model_management/models/base.py @@ -35,6 +35,11 @@ class SubModelType(str, Enum): SafetyChecker = "safety_checker" #MoVQ = "movq" +class VariantType(str, Enum): + Normal = "normal" + Inpaint = "inpaint" + Depth = "depth" + class ModelError(str, Enum): NotFound = "not_found"