diff --git a/invokeai/backend/__init__.py b/invokeai/backend/__init__.py index 2575a92a5c..8d1392cf93 100644 --- a/invokeai/backend/__init__.py +++ b/invokeai/backend/__init__.py @@ -9,5 +9,5 @@ from .generator import ( Img2Img, Inpaint ) -from .model_management import ModelManager, ModelCache, SDModelType, SDModelInfo +from .model_management import ModelManager, ModelCache, ModelType, ModelInfo from .safety_checker import SafetyChecker diff --git a/invokeai/backend/model_management/__init__.py b/invokeai/backend/model_management/__init__.py index 2ed7e21ef8..6fcd705c44 100644 --- a/invokeai/backend/model_management/__init__.py +++ b/invokeai/backend/model_management/__init__.py @@ -1,5 +1,6 @@ """ Initialization file for invokeai.backend.model_management """ -from .model_manager import ModelManager, SDModelInfo -from .model_cache import ModelCache, SDModelType +from .model_manager import ModelManager, ModelInfo +from .model_cache import ModelCache +from .models import ModelType diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py index 7174db595d..9d5338c8ce 100644 --- a/invokeai/backend/model_management/model_cache.py +++ b/invokeai/backend/model_management/model_cache.py @@ -29,11 +29,8 @@ import torch from diffusers import logging as diffusers_logging from transformers import logging as transformers_logging - import invokeai.backend.util.logging as logger - -from .model_manager import SDModelInfo, ModelType, SubModelType, ModelBase - +from .models import ModelType, SubModelType, ModelBase # Maximum size of the cache, in gigs # Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously @@ -50,6 +47,10 @@ class ModelCache(object): "Forward declaration" pass +class SDModelInfo(object): + """Forward declaration""" + pass + class _CacheRecord: size: int model: Any diff --git a/invokeai/backend/model_management/model_install.py b/invokeai/backend/model_management/model_install.py new file mode 100644 index 0000000000..e1ac441fee --- /dev/null +++ b/invokeai/backend/model_management/model_install.py @@ -0,0 +1,201 @@ +""" +Routines for downloading and installing models. +""" +import json +import safetensors +import safetensors.torch +import torch +import traceback +from dataclasses import dataclass +from diffusers import ModelMixin +from enum import Enum +from typing import Callable +from pathlib import Path +from picklescan.scanner import scan_file_path + +import invokeai.backend.util.logging as logger +from .models import BaseModelType, ModelType + +class CheckpointProbe(object): + PROBES = dict() # see below for redefinition + + def __init__(self, + checkpoint_path: Path, + checkpoint: dict = None, + helper: Callable[[Path], BaseModelType]=None + ): + checkpoint = checkpoint or self._scan_and_load_checkpoint(self.checkpoint_path) + self.checkpoint = checkpoint + self.checkpoint_path = checkpoint_path + self.helper = helper + + def probe(self) -> ModelVariantInfo: + ''' + Probes the checkpoint at path `checkpoint_path` and return + a ModelType object indicating the model base, model type and + model variant for the checkpoint. 
+ ''' + checkpoint = self.checkpoint + state_dict = checkpoint.get("state_dict") or checkpoint + + model_info = None + + try: + model_type = self.get_checkpoint_type(state_dict) + if not model_type: + if self.checkpoint_path.name == "learned_embeds.bin": + model_type = ModelType.TextualInversion + else: + return None # we give up + probe = self.PROBES[model_type]() + base_type = probe.get_base_type(checkpoint, self.checkpoint_path, self.helper) + variant_type = probe.get_variant_type(model_type, checkpoint) + + model_info = ModelVariantInfo( + model_type = model_type, + base_type = base_type, + variant_type = variant_type, + ) + except (KeyError, ValueError) as e: + logger.error(f'An error occurred while probing {self.checkpoint_path}: {str(e)}') + logger.error(traceback.format_exc()) + + return model_info + + class CheckpointProbeBase(object): + def get_base_type(self, + checkpoint: dict, + checkpoint_path: Path = None, + helper: Callable[[Path],BaseModelType] = None + )->BaseModelType: + pass + + def get_variant_type(self, + model_type: ModelType, + checkpoint: dict, + )-> VariantType: + if model_type != ModelType.Pipeline: + return None + state_dict = checkpoint.get('state_dict') or checkpoint + in_channels = state_dict[ + "model.diffusion_model.input_blocks.0.0.weight" + ].shape[1] + if in_channels == 9: + return VariantType.Inpaint + elif in_channels == 5: + return VariantType.depth + else: + return None + + + class CheckpointProbe(CheckpointProbeBase): + def get_base_type(self, + checkpoint: dict, + checkpoint_path: Path = None, + helper: Callable[[Path],BaseModelType] = None + )->BaseModelType: + state_dict = checkpoint.get('state_dict') or checkpoint + key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" + if key_name in state_dict and state_dict[key_name].shape[-1] == 768: + return BaseModelType.StableDiffusion1_5 + if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: + if 'global_step' in checkpoint: + if checkpoint['global_step'] == 220000: + return BaseModelType.StableDiffusion2Base + elif checkpoint["global_step"] == 110000: + return BaseModelType.StableDiffusion2 + if checkpoint_path and helper: + return helper(checkpoint_path) + else: + return None + + class VaeProbe(CheckpointProbeBase): + def get_base_type(self, + checkpoint: dict, + checkpoint_path: Path = None, + helper: Callable[[Path],BaseModelType] = None + )->BaseModelType: + # I can't find any standalone 2.X VAEs to test with! 
+ return BaseModelType.StableDiffusion1_5 + + class LoRAProbe(CheckpointProbeBase): + def get_base_type(self, + checkpoint: dict, + checkpoint_path: Path = None, + helper: Callable[[Path],BaseModelType] = None + )->BaseModelType: + key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight" + key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a" + lora_token_vector_length = ( + checkpoint[key1].shape[1] + if key1 in checkpoint + else checkpoint[key2].shape[0] + if key2 in checkpoint + else 768 + ) + if lora_token_vector_length == 768: + return BaseModelType.StableDiffusion1_5 + elif lora_token_vector_length == 1024: + return BaseModelType.StableDiffusion2 + else: + return None + + class TextualInversionProbe(CheckpointProbeBase): + def get_base_type(self, + checkpoint: dict, + checkpoint_path: Path = None, + helper: Callable[[Path],BaseModelType] = None + )->BaseModelType: + + if 'string_to_token' in checkpoint: + token_dim = list(checkpoint['string_to_param'].values())[0].shape[-1] + elif 'emb_params' in checkpoint: + token_dim = checkpoint['emb_params'].shape[-1] + else: + token_dim = list(checkpoint.values())[0].shape[0] + if token_dim == 768: + return BaseModelType.StableDiffusion1_5 + elif token_dim == 1024: + return BaseModelType.StableDiffusion2Base + else: + return None + + class ControlNetProbe(CheckpointProbeBase): + def get_base_type(self, + checkpoint: dict, + checkpoint_path: Path = None, + helper: Callable[[Path],BaseModelType] = None + )->BaseModelType: + for key_name in ('control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight', + 'input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight' + ): + if key_name not in checkpoint: + continue + if checkpoint[key_name].shape[-1] == 768: + return BaseModelType.StableDiffusion1_5 + elif checkpoint_path and helper: + return helper(checkpoint_path) + PROBES = { + ModelType.Pipeline: CheckpointProbe, + ModelType.Vae: VaeProbe, + ModelType.Lora: LoRAProbe, + ModelType.TextualInversion: TextualInversionProbe, + ModelType.ControlNet: ControlNetProbe, + } + + @classmethod + def get_checkpoint_type(cls, state_dict: dict) -> ModelType: + if any([x.startswith("model.diffusion_model") for x in state_dict.keys()]): + return ModelType.Pipeline + if any([x.startswith("encoder.conv_in") for x in state_dict.keys()]): + return ModelType.Vae + if "string_to_token" in state_dict or "emb_params" in state_dict: + return ModelType.TextualInversion + if any([x.startswith("lora") for x in state_dict.keys()]): + return ModelType.Lora + if any([x.startswith("control_model") for x in state_dict.keys()]): + return ModelType.ControlNet + if any([x.startswith("input_blocks") for x in state_dict.keys()]): + return ModelType.ControlNet + return None # give up + diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index ba931c28c7..19798cd0c5 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -3,13 +3,13 @@ symbolic diffusers model names to the paths and repo_ids used by the underlying `from_pretrained()` call. For fetching models, use manager.get_model('symbolic name'). This will -return a SDModelInfo object that contains the following attributes: +return a ModelInfo object that contains the following attributes: * context -- a context manager Generator that loads and locks the model into GPU VRAM and returns the model for use. See below for usage. 
* name -- symbolic name of the model - * type -- SDModelType of the model + * type -- SubModelType of the model * hash -- unique hash for the model * location -- path or repo_id of the model * revision -- revision of the model if coming from a repo id, @@ -25,7 +25,7 @@ Typical usage: max_cache_size=8 ) # gigabytes - model_info = manager.get_model('stable-diffusion-1.5', SDModelType.Diffusers) + model_info = manager.get_model('stable-diffusion-1.5', SubModelType.Diffusers) with model_info.context as my_model: my_model.latents_from_embeddings(...) @@ -43,7 +43,7 @@ parameter: model_info = manager.get_model( 'clip-tokenizer', - model_type=SDModelType.Tokenizer + model_type=SubModelType.Tokenizer ) This will raise an InvalidModelError if the format defined in the @@ -63,7 +63,7 @@ The general format of a models.yaml section is: The type of model is given in the stanza key, and is one of {diffusers, ckpt, vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, lora, textual_inversion, -controlnet}, and correspond to items in the SDModelType enum defined +controlnet}, and correspond to items in the SubModelType enum defined in model_cache.py The format indicates whether the model is organized as a folder with @@ -96,7 +96,7 @@ SUBMODELS: It is also possible to fetch an isolated submodel from a diffusers model. Use the `submodel` parameter to select which part: - vae = manager.get_model('stable-diffusion-1.5',submodel=SDModelType.Vae) + vae = manager.get_model('stable-diffusion-1.5',submodel=SubModelType.Vae) with vae.context as my_vae: print(type(my_vae)) # "AutoencoderKL" @@ -128,8 +128,8 @@ separated by "/". Example: You can now use the `model_type` argument to indicate which model you want: - tokenizer = mgr.get('clip-large',model_type=SDModelType.Tokenizer) - encoder = mgr.get('clip-large',model_type=SDModelType.TextEncoder) + tokenizer = mgr.get('clip-large',model_type=SubModelType.Tokenizer) + encoder = mgr.get('clip-large',model_type=SubModelType.TextEncoder) OTHER FUNCTIONS: @@ -164,15 +164,24 @@ import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.util import CUDA_DEVICE, download_with_resume from .model_cache import ModelCache, ModelLocker -from .models import BaseModelType, ModelType, SubModelType, MODEL_CLASSES +from .models import BaseModelType, SubModelType, MODEL_CLASSES + # We are only starting to number the config file with release 3. # The config file version doesn't have to start at release version, but it will help # reduce confusion. CONFIG_FILE_VERSION='3.0.0' -# wanted to use pydantic here, but Generator objects not supported +# temporary forward definitions to avoid circular import errors. +class ModelLocker(object): + "Forward declaration" + pass + +class ModelCache(object): + "Forward declaration" + pass + @dataclass -class SDModelInfo(): +class ModelInfo(): context: ModelLocker name: str type: ModelType @@ -303,7 +312,7 @@ class ModelManager(object): submodel_type: Optional[SubModelType] = None ): """Given a model named identified in models.yaml, return - an SDModelInfo object describing it. + an ModelInfo object describing it. 
:param model_name: symbolic name of the model in models.yaml :param model_type: ModelType enum indicating the type of model to return :param submode_typel: an ModelType enum indicating the portion of @@ -389,7 +398,7 @@ class ModelManager(object): hash = "" # TODO: - return SDModelInfo( + return ModelInfo( context = model_context, name = model_name, base_model = base_model, @@ -746,62 +755,3 @@ class ModelManager(object): if self.config_path: self.commit() - def _delete_defunct_models(self): - ''' - Remove models no longer on disk. - ''' - config = self.config - - to_delete = set() - for key in config: - if 'path' not in config[key]: - continue - path = self.globals.root_dir / config[key].path - if path.exists(): - continue - to_delete.add(key) - - for key in to_delete: - self.logger.warn(f'Removing model {key} from in-memory config because its path is no longer on disk') - config.pop(key) - - def scan_models_directory(self, include_diffusers:bool=False): - ''' - Scan the models directory for loras, textual_inversions and controlnets - and create appropriate entries in the in-memory omegaconf. Diffusers - will not be added unless include_diffusers is true. - ''' - self._delete_defunct_models() - - model_directory = self.globals.models_path - config = self.config - - for root, dirs, files in os.walk(model_directory): - parents = root.split('/') - subpaths = parents[parents.index('models')+1:] - if len(subpaths) < 2: - continue - base, model_type, *_ = subpaths - - if model_type == "diffusers" and not include_diffusers: - continue - - for d in dirs: - config[f'{model_type}/{d}'] = dict( - path = os.path.join(root,d), - description = f'{model_type} model {d}', - format = 'folder', - base = base, - ) - - for f in files: - basename = Path(f).stem - format = Path(f).suffix[1:] - config[f'{model_type}/{basename}'] = dict( - path = os.path.join(root,f), - description = f'{model_type} model {basename}', - format = format, - base = base, - ) - - diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py new file mode 100644 index 0000000000..bb2bbc2a85 --- /dev/null +++ b/invokeai/backend/model_management/model_probe.py @@ -0,0 +1,333 @@ +import json +import traceback +import torch +import safetensors.torch + +from dataclasses import dataclass +from diffusers import ModelMixin +from pathlib import Path +from typing import Callable, Literal, Union, Dict +from picklescan.scanner import scan_file_path + +import invokeai.backend.util.logging as logger +from .models import BaseModelType, ModelType, VariantType + +@dataclass +class ModelVariantInfo(object): + model_type: ModelType + base_type: BaseModelType + variant_type: VariantType + +class ProbeBase(object): + '''forward declaration''' + pass + +class ModelProbe(object): + + PROBES = { + 'folder': { }, + 'file': { }, + } + + CLASS2TYPE = { + "StableDiffusionPipeline" : ModelType.Pipeline, + "AutoencoderKL": ModelType.Vae, + "ControlNetModel" : ModelType.ControlNet, + + } + + @classmethod + def register_probe(cls, + format: Literal['folder','file'], + model_type: ModelType, + probe_class: ProbeBase): + cls.PROBES[format][model_type] = probe_class + + @classmethod + def probe(cls, + model_path: Path, + model: Union[Dict, ModelMixin] = None, + base_helper: Callable[[Path],BaseModelType] = None)->ModelVariantInfo: + ''' + Probe the model at model_path and return sufficient information about it + to place it somewhere in the models directory hierarchy. 
If the model is + already loaded into memory, you may provide it as model in order to avoid + opening it a second time. The base_helper callable is a function that receives + the path to the model and returns the BaseModelType. It is called to distinguish + between V2-Base and V2-768 SD models. + ''' + format = 'folder' if model_path.is_dir() else 'file' + model_info = None + try: + model_type = cls.get_model_type_from_folder(model_path, model) \ + if format == 'folder' \ + else cls.get_model_type_from_checkpoint(model_path, model) + probe_class = cls.PROBES[format].get(model_type) + if not probe_class: + return None + probe = probe_class(model_path, model, base_helper) + base_type = probe.get_base_type() + variant_type = probe.get_variant_type() + model_info = ModelVariantInfo( + model_type = model_type, + base_type = base_type, + variant_type = variant_type, + ) + except (KeyError, ValueError) as e: + logger.error(f'An error occurred while probing {model_path}: {str(e)}') + logger.error(traceback.format_exc()) + + return model_info + + @classmethod + def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict)->ModelType: + checkpoint = checkpoint or cls._scan_and_load_checkpoint(model_path) + state_dict = checkpoint.get("state_dict") or checkpoint + if any([x.startswith("model.diffusion_model") for x in state_dict.keys()]): + return ModelType.Pipeline + if any([x.startswith("encoder.conv_in") for x in state_dict.keys()]): + return ModelType.Vae + if "string_to_token" in state_dict or "emb_params" in state_dict: + return ModelType.TextualInversion + if any([x.startswith("lora") for x in state_dict.keys()]): + return ModelType.Lora + if any([x.startswith("control_model") for x in state_dict.keys()]): + return ModelType.ControlNet + if any([x.startswith("input_blocks") for x in state_dict.keys()]): + return ModelType.ControlNet + return None # give up + + @classmethod + def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType: + ''' + Get the model type of a hugging-face style folder. + ''' + if (folder_path / 'learned_embeds.bin').exists(): + return ModelType.TextualInversion + + if (folder_path / 'pytorch_lora_weights.bin').exists(): + return ModelType.Lora + + i = folder_path / 'model_index.json' + c = folder_path / 'config.json' + config_path = i if i.exists() else c if c.exists() else None + + if config_path: + conf = json.loads(config_path) + class_name = conf['_class_name'] + if type := cls.CLASS2TYPE.get(class_name): + return type + + # give up + raise ValueError("Unable to determine model type of {model_path}") + + @classmethod + def _scan_and_load_checkpoint(cls,model_path: Path)->dict: + if model_path.suffix.endswith((".ckpt", ".pt", ".bin")): + cls._scan_model(model_path, model_path) + return torch.load(model_path) + else: + return safetensors.torch.load_file(model_path) + + @classmethod + def _scan_model(cls, model_name, checkpoint): + """ + Apply picklescanner to the indicated checkpoint and issue a warning + and option to exit if an infected file is identified. + """ + # scan model + scan_result = scan_file_path(checkpoint) + if scan_result.infected_files != 0: + raise "The model {model_name} is potentially infected by malware. Aborting import." 
+ +###################################################3 +# Checkpoint probing +###################################################3 +class ProbeBase(object): + def get_base_type(self)->BaseModelType: + pass + + def get_variant_type(self)->VariantType: + pass + +class CheckpointProbeBase(ProbeBase): + def __init__(self, + checkpoint_path: Path, + checkpoint: dict, + helper: Callable[[Path],BaseModelType] = None + )->BaseModelType: + self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path) + self.checkpoint_path = checkpoint_path + self.helper = helper + + def get_base_type(self)->BaseModelType: + pass + + def get_variant_type(self)-> VariantType: + model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path,self.checkpoint) + if model_type != ModelType.Pipeline: + return VariantType.Normal + state_dict = self.checkpoint.get('state_dict') or self.checkpoint + in_channels = state_dict[ + "model.diffusion_model.input_blocks.0.0.weight" + ].shape[1] + if in_channels == 9: + return VariantType.Inpaint + elif in_channels == 5: + return VariantType.Depth + else: + return None + +class PipelineCheckpointProbe(CheckpointProbeBase): + def get_base_type(self)->BaseModelType: + checkpoint = self.checkpoint + helper = self.helper + state_dict = self.checkpoint.get('state_dict') or checkpoint + key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" + if key_name in state_dict and state_dict[key_name].shape[-1] == 768: + return BaseModelType.StableDiffusion1_5 + if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: + if 'global_step' in checkpoint: + if checkpoint['global_step'] == 220000: + return BaseModelType.StableDiffusion2Base + elif checkpoint["global_step"] == 110000: + return BaseModelType.StableDiffusion2 + if self.checkpoint_path and helper: + return helper(self.checkpoint_path) + else: + return None + +class VaeCheckpointProbe(CheckpointProbeBase): + def get_base_type(self)->BaseModelType: + # I can't find any standalone 2.X VAEs to test with! 
+ return BaseModelType.StableDiffusion1_5 + +class LoRACheckpointProbe(CheckpointProbeBase): + def get_base_type(self)->BaseModelType: + checkpoint = self.checkpoint + key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight" + key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a" + lora_token_vector_length = ( + checkpoint[key1].shape[1] + if key1 in checkpoint + else checkpoint[key2].shape[0] + if key2 in checkpoint + else 768 + ) + if lora_token_vector_length == 768: + return BaseModelType.StableDiffusion1_5 + elif lora_token_vector_length == 1024: + return BaseModelType.StableDiffusion2 + else: + return None + +class TextualInversionCheckpointProbe(CheckpointProbeBase): + def get_base_type(self)->BaseModelType: + checkpoint = self.checkpoint + if 'string_to_token' in checkpoint: + token_dim = list(checkpoint['string_to_param'].values())[0].shape[-1] + elif 'emb_params' in checkpoint: + token_dim = checkpoint['emb_params'].shape[-1] + else: + token_dim = list(checkpoint.values())[0].shape[0] + if token_dim == 768: + return BaseModelType.StableDiffusion1_5 + elif token_dim == 1024: + return BaseModelType.StableDiffusion2Base + else: + return None + +class ControlNetCheckpointProbe(CheckpointProbeBase): + def get_base_type(self)->BaseModelType: + checkpoint = self.checkpoint + for key_name in ('control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight', + 'input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight' + ): + if key_name not in checkpoint: + continue + if checkpoint[key_name].shape[-1] == 768: + return BaseModelType.StableDiffusion1_5 + elif self.checkpoint_path and self.helper: + return self.helper(self.checkpoint_path) + +######################################################## +# classes for probing folders +####################################################### +class FolderProbeBase(ProbeBase): + def __init__(self, + model: ModelMixin, + folder_path: Path, + helper: Callable=None # not used + ): + self.model = model + self.folder_path = folder_path + + def get_variant_type(self)->VariantType: + + # only works for pipelines + config_file = self.folder_path / 'unet' / 'config.json' + if not config_file.exists(): + return VariantType.Normal + + conf = json.loads(config_file) + channels = conf['in_channels'] + if channels == 9: + return VariantType.Inpainting + elif channels == 5: + return VariantType.Depth + elif channels == 4: + return VariantType.Normal + else: + return VariantType.Normal + +class PipelineFolderProbe(FolderProbeBase): + def get_base_type(self)->BaseModelType: + config_file = self.folder_path / 'scheduler' / 'scheduler_config.json' + if not config_file.exists(): + return None + conf = json.load(config_file) + if conf['prediction_type'] == "v_prediction": + return BaseModelType.StableDiffusion2 + elif conf['prediction_type'] == 'epsilon': + return BaseModelType.StableDiffusion2Base + else: + return BaseModelType.StableDiffusion2 + +class VaeFolderProbe(FolderProbeBase): + def get_base_type(self)->BaseModelType: + return BaseModelType.StableDiffusion1_5 + +class TextualInversionFolderProbe(FolderProbeBase): + def get_base_type(self)->BaseModelType: + path = self.folder_path / 'learned_embeds.bin' + if not path.exists(): + return None + checkpoint = ModelProbe._scan_and_load_checkpoint(path) + return TextualInversionCheckpointProbe(checkpoint).get_base_type + +class ControlNetFolderProbe(FolderProbeBase): + def get_base_type(self)->BaseModelType: + config_file = self.folder_path / 'scheduler_config.json' + if not 
config_file.exists(): + return None + config = json.load(config_file) + # no obvious way to distinguish between sd2-base and sd2-768 + return BaseModelType.StableDiffusion1_5 \ + if config['cross_attention_dim']==768 \ + else BaseModelType.StableDiffusion2 + +class LoRAFolderProbe(FolderProbeBase): + # I've never seen one of these in the wild, so this is a noop + pass + +############## register probe classes ###### +ModelProbe.register_probe('folder', ModelType.Pipeline, PipelineFolderProbe) +ModelProbe.register_probe('folder', ModelType.Vae, VaeFolderProbe) +ModelProbe.register_probe('folder', ModelType.Lora, LoRAFolderProbe) +ModelProbe.register_probe('folder', ModelType.TextualInversion, TextualInversionFolderProbe) +ModelProbe.register_probe('folder', ModelType.ControlNet, ControlNetFolderProbe) +ModelProbe.register_probe('file', ModelType.Pipeline, PipelineCheckpointProbe) +ModelProbe.register_probe('file', ModelType.Vae, VaeCheckpointProbe) +ModelProbe.register_probe('file', ModelType.Lora, LoRACheckpointProbe) +ModelProbe.register_probe('file', ModelType.TextualInversion, TextualInversionCheckpointProbe) +ModelProbe.register_probe('file', ModelType.ControlNet, ControlNetCheckpointProbe) diff --git a/invokeai/backend/model_management/models.py b/invokeai/backend/model_management/models.py index 953eaca383..c09f6f6c30 100644 --- a/invokeai/backend/model_management/models.py +++ b/invokeai/backend/model_management/models.py @@ -1,8 +1,14 @@ import sys +from dataclasses import dataclass from enum import Enum import torch import safetensors.torch +from diffusers import ConfigMixin from diffusers.utils import is_safetensors_available +from omegaconf import DictConfig +from pathlib import Path +from pydantic import BaseModel, Field, root_validator +from typing import Union, List, Type, Optional class BaseModelType(str, Enum): # TODO: maybe then add sample size(512/768)? 
@@ -26,42 +32,13 @@ class SubModelType: Vae = "vae" Scheduler = "scheduler" SafetyChecker = "safety_checker" + FeatureExtractor = "feature_extractor" #MoVQ = "movq" -MODEL_CLASSES = { - BaseModel.StableDiffusion1_5: { - ModelType.Pipeline: StableDiffusionModel, - ModelType.Classifier: ClassifierModel, - ModelType.Vae: VaeModel, - ModelType.Lora: LoraModel, - ModelType.ControlNet: ControlNetModel, - ModelType.TextualInversion: TextualInversionModel, - }, - BaseModel.StableDiffusion2: { - ModelType.Pipeline: StableDiffusionModel, - ModelType.Classifier: ClassifierModel, - ModelType.Vae: VaeModel, - ModelType.Lora: LoraModel, - ModelType.ControlNet: ControlNetModel, - ModelType.TextualInversion: TextualInversionModel, - }, - BaseModel.StableDiffusion2Base: { - ModelType.Pipeline: StableDiffusionModel, - ModelType.Classifier: ClassifierModel, - ModelType.Vae: VaeModel, - ModelType.Lora: LoraModel, - ModelType.ControlNet: ControlNetModel, - ModelType.TextualInversion: TextualInversionModel, - }, - #BaseModel.Kandinsky2_1: { - # ModelType.Pipeline: Kandinsky2_1Model, - # ModelType.Classifier: ClassifierModel, - # ModelType.MoVQ: MoVQModel, - # ModelType.Lora: LoraModel, - # ModelType.ControlNet: ControlNetModel, - # ModelType.TextualInversion: TextualInversionModel, - #}, -} +class VariantType(str, Enum): + Normal = "normal" + Inpaint = "inpaint" + Depth = "depth" class EmptyConfigLoader(ConfigMixin): @classmethod @@ -323,7 +300,7 @@ class ClassifierModel(ModelBase): #child_sizes: Dict[str, int] def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == SDModelType.Classifier + assert model_type == ModelType.Classifier super().__init__(model_path, base_model, model_type) self.child_types: Dict[str, Type] = dict() @@ -354,8 +331,8 @@ class ClassifierModel(ModelBase): else: raise Exception("Invalid classifier model! (Failed to detect tokenizer type)") - self.child_types[SDModelType.Tokenizer] = self._hf_definition_to_type(["transformers", tokenizer_class_name]) - self.child_sizes[SDModelType.Tokenizer] = 0 + self.child_types[SubModelType.Tokenizer] = self._hf_definition_to_type(["transformers", tokenizer_class_name]) + self.child_sizes[SubModelType.Tokenizer] = 0 def _load_text_encoder(self, main_config: dict): @@ -366,12 +343,12 @@ class ClassifierModel(ModelBase): else: raise Exception("Invalid classifier model! 
(Failed to detect text_encoder type)") - self.child_types[SDModelType.TextEncoder] = self._hf_definition_to_type(["transformers", text_encoder_class_name]) - self.child_sizes[SDModelType.TextEncoder] = calc_model_size_by_fs(self.model_path) + self.child_types[SubModelType.TextEncoder] = self._hf_definition_to_type(["transformers", text_encoder_class_name]) + self.child_sizes[SubModelType.TextEncoder] = calc_model_size_by_fs(self.model_path) def _load_feature_extractor(self, main_config: dict): - self.child_sizes[SDModelType.FeatureExtractor] = 0 + self.child_sizes[SubModelType.FeatureExtractor] = 0 try: feature_extractor_config = EmptyConfigLoader.load_config(self.model_path, config_name="preprocessor_config.json") except: @@ -379,12 +356,12 @@ class ClassifierModel(ModelBase): try: feature_extractor_class_name = feature_extractor_config["feature_extractor_type"] - self.child_types[SDModelType.FeatureExtractor] = self._hf_definition_to_type(["transformers", feature_extractor_class_name]) + self.child_types[SubModelType.FeatureExtractor] = self._hf_definition_to_type(["transformers", feature_extractor_class_name]) except: raise Exception("Invalid classifier model! (Unknown feature_extrator type)") - def get_size(self, child_type: Optional[SDModelType] = None): + def get_size(self, child_type: Optional[SubModelType] = None): if child_type is None: return sum(self.child_sizes.values()) else: @@ -394,7 +371,7 @@ class ClassifierModel(ModelBase): def get_model( self, torch_dtype: Optional[torch.dtype], - child_type: Optional[SDModelType] = None, + child_type: Optional[SubModelType] = None, ): if child_type is None: raise Exception("Child model type can't be null on classififer model") @@ -437,7 +414,7 @@ class VaeModel(ModelBase): except: raise Exception("Invalid vae model! 
(Unkown vae type)") - def get_size(self, child_type: Optional[SDModelType] = None): + def get_size(self, child_type: Optional[SubModelType] = None): if child_type is not None: raise Exception("There is no child models in vae model") return self.model_size @@ -445,7 +422,7 @@ class VaeModel(ModelBase): def get_model( self, torch_dtype: Optional[torch.dtype], - child_type: Optional[SDModelType] = None, + child_type: Optional[SubModelType] = None, ): if child_type is not None: raise Exception("There is no child models in vae model") @@ -476,7 +453,7 @@ class LoRAModel(ModelBase): self.model_size = os.path.getsize(self.model_path) - def get_size(self, child_type: Optional[SDModelType] = None): + def get_size(self, child_type: Optional[SubModelType] = None): if child_type is not None: raise Exception("There is no child models in lora") return self.model_size @@ -484,7 +461,7 @@ class LoRAModel(ModelBase): def get_model( self, torch_dtype: Optional[torch.dtype], - child_type: Optional[SDModelType] = None, + child_type: Optional[SubModelType] = None, ): if child_type is not None: raise Exception("There is no child models in lora") @@ -505,7 +482,6 @@ class LoRAModel(ModelBase): # TODO: add diffusers lora when it stabilizes a bit return model_path - class TextualInversionModel(ModelBase): #model_size: int @@ -515,7 +491,7 @@ class TextualInversionModel(ModelBase): self.model_size = os.path.getsize(self.model_path) - def get_size(self, child_type: Optional[SDModelType] = None): + def get_size(self, child_type: Optional[SubModelType] = None): if child_type is not None: raise Exception("There is no child models in textual inversion") return self.model_size @@ -523,7 +499,7 @@ class TextualInversionModel(ModelBase): def get_model( self, torch_dtype: Optional[torch.dtype], - child_type: Optional[SDModelType] = None, + child_type: Optional[SubModelType] = None, ): if child_type is not None: raise Exception("There is no child models in textual inversion") @@ -542,6 +518,10 @@ class TextualInversionModel(ModelBase): model_path = Path(model_path) return model_path +class ControlNetModel(ModelBase): + """requires implementation""" + pass + def calc_model_size_by_fs( model_path: str, subfolder: Optional[str] = None, @@ -710,3 +690,39 @@ def _convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> Path: safe_serialization=is_safetensors_available() ) return diffusers_path + +MODEL_CLASSES = { + BaseModelType.StableDiffusion1_5: { + ModelType.Pipeline: StableDiffusionModel, + ModelType.Classifier: ClassifierModel, + ModelType.Vae: VaeModel, + ModelType.Lora: LoRAModel, + ModelType.ControlNet: ControlNetModel, + ModelType.TextualInversion: TextualInversionModel, + }, + BaseModelType.StableDiffusion2: { + ModelType.Pipeline: StableDiffusionModel, + ModelType.Classifier: ClassifierModel, + ModelType.Vae: VaeModel, + ModelType.Lora: LoRAModel, + ModelType.ControlNet: ControlNetModel, + ModelType.TextualInversion: TextualInversionModel, + }, + BaseModelType.StableDiffusion2Base: { + ModelType.Pipeline: StableDiffusionModel, + ModelType.Classifier: ClassifierModel, + ModelType.Vae: VaeModel, + ModelType.Lora: LoRAModel, + ModelType.ControlNet: ControlNetModel, + ModelType.TextualInversion: TextualInversionModel, + }, + #BaseModel.Kandinsky2_1: { + # ModelType.Pipeline: Kandinsky2_1Model, + # ModelType.Classifier: ClassifierModel, + # ModelType.MoVQ: MoVQModel, + # ModelType.Lora: LoraModel, + # ModelType.ControlNet: ControlNetModel, + # ModelType.TextualInversion: TextualInversionModel, + #}, +} + diff --git 
a/scripts/scan_models_directory.py b/scripts/scan_models_directory.py old mode 100644 new mode 100755 index 72d17a6755..778a6c5ed5 --- a/scripts/scan_models_directory.py +++ b/scripts/scan_models_directory.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python + ''' Scan the models directory and print out a new models.yaml ''' @@ -12,6 +14,12 @@ from omegaconf import OmegaConf def main(): parser = argparse.ArgumentParser(description="Model directory scanner") parser.add_argument('models_directory') + parser.add_argument('--all-models', + default=False, + action='store_true', + help='If true, then generates stanzas for all models; otherwise just diffusers' + ) + args = parser.parse_args() directory = args.models_directory @@ -19,29 +27,30 @@ def main(): conf['_version'] = '3.0.0' for root, dirs, files in os.walk(directory): - for d in dirs: - parents = root.split('/') - subpaths = parents[parents.index('models')+1:] - if len(subpaths) < 2: - continue - base, model_type, *_ = subpaths - - conf[f'{model_type}/{d}'] = dict( - path = os.path.join(root,d), - description = f'{model_type} model {d}', - format = 'folder', - base = base, - ) - - for f in files: - basename = Path(f).stem - format = Path(f).suffix[1:] - conf[f'{model_type}/{basename}'] = dict( - path = os.path.join(root,f), - description = f'{model_type} model {basename}', - format = format, - base = base, - ) + parents = root.split('/') + subpaths = parents[parents.index('models')+1:] + if len(subpaths) < 2: + continue + base, model_type, *_ = subpaths + + if args.all_models or model_type=='diffusers': + for d in dirs: + conf[f'{base}/{model_type}/{d}'] = dict( + path = os.path.join(root,d), + description = f'{model_type} model {d}', + format = 'folder', + base = base, + ) + + for f in files: + basename = Path(f).stem + format = Path(f).suffix[1:] + conf[f'{base}/{model_type}/{basename}'] = dict( + path = os.path.join(root,f), + description = f'{model_type} model {basename}', + format = format, + base = base, + ) OmegaConf.save(config=dict(sorted(conf.items())), f=sys.stdout)
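
For reference, a minimal usage sketch of the ModelProbe API added by this patch (this is not part of the diff; the checkpoint path is a placeholder, and the base_helper callback is shown only to illustrate how the probe delegates the SD2-base vs. SD2-768 decision when it cannot tell them apart on its own):

    from pathlib import Path

    from invokeai.backend.model_management.model_probe import ModelProbe
    from invokeai.backend.model_management.models import BaseModelType

    def ask_user_for_base(checkpoint_path: Path) -> BaseModelType:
        # Placeholder helper: called only when the probe cannot distinguish
        # an SD2-base checkpoint from an SD2-768 one.
        return BaseModelType.StableDiffusion2Base

    # Probe a checkpoint file or diffusers folder on disk. Returns a
    # ModelVariantInfo (model_type, base_type, variant_type), or None when
    # the model type cannot be determined.
    info = ModelProbe.probe(
        Path("/path/to/model.safetensors"),  # placeholder path
        base_helper=ask_user_for_base,
    )
    if info is not None:
        print(info.base_type, info.model_type, info.variant_type)

The ModelManager side of the new API is already illustrated by the docstring updated earlier in this patch (manager.get_model(...) returning a ModelInfo whose .context locks the model into VRAM for the duration of the with-block).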