diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index f80e845f17..bd70e23f5b 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -133,6 +133,7 @@ from enum import Enum, auto from pathlib import Path from shutil import rmtree from typing import Union, Callable, types +from contextlib import suppress import safetensors import safetensors.torch @@ -192,13 +193,13 @@ class ModelManager(object): logger: types.ModuleType = logger def __init__( - self, - config: Union[Path, DictConfig, str], - device_type: torch.device = CUDA_DEVICE, - precision: torch.dtype = torch.float16, - max_cache_size=MAX_CACHE_SIZE, - sequential_offload=False, - logger: types.ModuleType = logger, + self, + config: Union[Path, DictConfig, str], + device_type: torch.device = CUDA_DEVICE, + precision: torch.dtype = torch.float16, + max_cache_size=MAX_CACHE_SIZE, + sequential_offload=False, + logger: types.ModuleType = logger, ): """ Initialize with the path to the models.yaml config file. @@ -225,22 +226,36 @@ class ModelManager(object): self.cache_keys = dict() self.logger = logger - def valid_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> bool: + # TODO: rename to smth like - is_model_exists + def valid_model( + self, + model_name: str, + model_type: SDModelType = SDModelType.diffusers, + ) -> bool: """ Given a model name, returns True if it is a valid identifier. """ - try: - self._disambiguate_name(model_name, model_type) - return True - except InvalidModelError: - return False + model_key = self.create_key(model_name, model_class) + return model_key in self.config - def get_model(self, - model_name: str, - model_type: SDModelType=None, - submodel: SDModelType=None, - ) -> SDModelInfo: + def create_key(self, model_name: str, model_type: SDModelType) -> str: + return f"{model_type.name}/{model_name}" + + def parse_key(self, model_key: str) -> Tuple[str, SDModelType]: + model_type_str, model_name = model_key.split('/', 1) + if model_type_str not in SDModelType.__members__: + # TODO: + raise Exception(f"Unkown model type: {model_type_str}") + + return (model_name, SDModelType[model_type_str]) + + def get_model( + self, + model_name: str, + model_type: SDModelType=None, + submodel: SDModelType=None, + ) -> SDModelInfo: """Given a model named identified in models.yaml, return an SDModelInfo object describing it. :param model_name: symbolic name of the model in models.yaml @@ -254,85 +269,77 @@ class ModelManager(object): assume a diffusers pipeline. The behavior is illustrated here: [models.yaml] - test1/diffusers: + diffusers/test1: repo_id: foo/bar - format: diffusers description: Typical diffusers pipeline - test1/lora: + lora/test1: repo_id: /tmp/loras/test1.safetensors - format: lora description: Typical lora file test1_pipeline = mgr.get_model('test1') # returns a StableDiffusionGeneratorPipeline - test1_vae1 = mgr.get_model('test1',submodel=SDModelType.vae) + test1_vae1 = mgr.get_model('test1', submodel=SDModelType.vae) # returns the VAE part of a diffusers model as an AutoencoderKL - test1_vae2 = mgr.get_model('test1',model_type=SDModelType.diffusers,submodel=SDModelType.vae) + test1_vae2 = mgr.get_model('test1', model_type=SDModelType.diffusers, submodel=SDModelType.vae) # does the same thing as the previous statement. Note that model_type # is for the parent model, and submodel is for the part - test1_lora = mgr.get_model('test1',model_type=SDModelType.lora) + test1_lora = mgr.get_model('test1', model_type=SDModelType.lora) # returns a LoRA embed (as a 'dict' of tensors) - test1_encoder = mgr.get_modelI('test1',model_type=SDModelType.textencoder) + test1_encoder = mgr.get_modelI('test1', model_type=SDModelType.textencoder) # raises an InvalidModelError """ - if not model_name: - model_name = self.default_model() + # TODO: delete default model or add check that this stable diffusion model + # if not model_name: + # model_name = self.default_model() - model_key = self._disambiguate_name(model_name, model_type) + model_key = self.create_key(model_name, model_type) + if model_key not in self.config: + raise InvalidModelError( + f'"{model_key}" is not a known model name. Please check your models.yaml file' + ) # get the required loading info out of the config file mconfig = self.config[model_key] - format = mconfig.get('format','diffusers') - if model_type and model_type.name != format: - raise InvalidModelError( - f'Inconsistent model definition; {model_key} has format {format}, but type {model_type.name} was requested' - ) - - model_parts = dict([(x.name,x) for x in SDModelType]) - - if format == 'diffusers': + # type already checked as it's part of key + if model_type == SDModelType.diffusers: # intercept stanzas that point to checkpoint weights and replace them # with the equivalent diffusers model if 'weights' in mconfig: location = self.convert_ckpt_and_cache(mconfig) else: location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id') - elif format in model_parts: - location = global_resolve_path(mconfig.get('path')) \ - or mconfig.get('repo_id') \ - or global_resolve_path(mconfig.get('weights')) else: - raise InvalidModelError( - f'"{model_key}" has an unknown format {format}' + location = global_resolve_path( + mconfig.get('path')) \ + or mconfig.get('repo_id') \ + or global_resolve_path(mconfig.get('weights') ) - model_type = model_parts[format] subfolder = mconfig.get('subfolder') revision = mconfig.get('revision') - hash = self.cache.model_hash(location,revision) + hash = self.cache.model_hash(location, revision) # to support the traditional way of attaching a VAE # to a model, we hacked in `attach_model_part` - vae = (None,None) - try: + vae = (None, None) + with suppress(Exception): vae_id = mconfig.vae.repo_id - vae = (SDModelType.vae,vae_id) - except Exception: - pass + vae = (SDModelType.vae, vae_id) + model_context = self.cache.get_model( location, model_type = model_type, revision = revision, subfolder = subfolder, submodel = submodel, - attach_model_part=vae, + attach_model_part = vae, ) # in case we need to communicate information about this @@ -402,27 +409,28 @@ class ModelManager(object): def list_models(self) -> dict: """ - Return a dict of models in the format: - { model_name1: {'status': ('active'|'cached'|'not loaded'), - 'description': description, - 'format': ('ckpt'|'diffusers'|'vae'), - }, - model_name2: { etc } + Return a dict of models Please use model_manager.models() to get all the model names, model_manager.model_info('model-name') to get the stanza for the model named 'model-name', and model_manager.config to get the full OmegaConf object derived from models.yaml """ models = {} - for name in sorted(self.config, key=str.casefold): - stanza = self.config[name] + for model_key in sorted(self.config, key=str.casefold): + stanza = self.config[model_key] # don't include VAEs in listing (legacy style) if "config" in stanza and "/VAE/" in stanza["config"]: continue - models[name] = dict() - format = stanza.get("format", "ckpt") # Determine Format + model_name, model_type = self.parse_key(model_key) + models[model_name] = dict() + + # TODO: return all models in future + if model_type != SDModelType.diffusers: + continue + + model_format = "ckpt" if "weights" in stanza else "diffusers" # Common Attribs status = self.cache.status( @@ -431,37 +439,38 @@ class ModelManager(object): subfolder=stanza.get('subfolder') ) description = stanza.get("description", None) - models[name].update( + models[model_name].update( description=description, - format=format, + type=model_type, + format=model_format, status=status.value ) # Checkpoint Config Parse - if format == "ckpt": - models[name].update( - config=str(stanza.get("config", None)), - weights=str(stanza.get("weights", None)), - vae=str(stanza.get("vae", None)), - width=str(stanza.get("width", 512)), - height=str(stanza.get("height", 512)), + if model_format == "ckpt": + models[model_name].update( + config = str(stanza.get("config", None)), + weights = str(stanza.get("weights", None)), + vae = str(stanza.get("vae", None)), + width = str(stanza.get("width", 512)), + height = str(stanza.get("height", 512)), ) # Diffusers Config Parse - if vae := stanza.get("vae", None): - if isinstance(vae, DictConfig): - vae = dict( - repo_id=str(vae.get("repo_id", None)), - path=str(vae.get("path", None)), - subfolder=str(vae.get("subfolder", None)), - ) + elif model_format == "diffusers": + if vae := stanza.get("vae", None): + if isinstance(vae, DictConfig): + vae = dict( + repo_id = str(vae.get("repo_id", None)), + path = str(vae.get("path", None)), + subfolder = str(vae.get("subfolder", None)), + ) - if format == "diffusers": - models[name].update( - vae=vae, - repo_id=str(stanza.get("repo_id", None)), - path=str(stanza.get("path", None)), + models[model_name].update( + vae = vae, + repo_id = str(stanza.get("repo_id", None)), + path = str(stanza.get("path", None)), ) return models @@ -472,44 +481,60 @@ class ModelManager(object): """ models = self.list_models() for name in models: - if models[name]["format"] == "vae": + if models[name]["type"] == "vae": continue - line = f'{name:25s} {models[name]["status"]:>15s} {models[name]["format"]:10s} {models[name]["description"]}' + line = f'{name:25s} {models[name]["status"]:>15s} {models[name]["type"]:10s} {models[name]["description"]}' if models[name]["status"] == "active": line = f"\033[1m{line}\033[0m" print(line) - def del_model(self, model_name: str, model_type: SDModelType.diffusers, delete_files: bool = False): + def del_model( + self, + model_name: str, + model_type: SDModelType.diffusers, + delete_files: bool = False + ): """ Delete the named model. """ - model_name = self._disambiguate_name(model_name, model_type) - omega = self.config - if model_name not in omega: - self.logger.error(f"Unknown model {model_name}") - return - # save these for use in deletion later - conf = omega[model_name] - repo_id = conf.get("repo_id", None) - path = self._abs_path(conf.get("path", None)) - weights = self._abs_path(conf.get("weights", None)) + model_key = self.create_key(model_name, model_type) + model_cfg = self.pop(model_key, None) + + if model_cfg is None: + self.logger.error( + f"Unknown model {model_key}" + ) + return + + # TODO: some legacy? + #if model_name in self.stack: + # self.stack.remove(model_name) - del omega[model_name] - if model_name in self.stack: - self.stack.remove(model_name) if delete_files: - if weights: + repo_id = conf.get("repo_id", None) + path = self._abs_path(conf.get("path", None)) + weights = self._abs_path(conf.get("weights", None)) + if "weights" in model_cfg: + weights = self._abs_path(model_cfg["weights"]) self.logger.info(f"Deleting file {weights}") Path(weights).unlink(missing_ok=True) - elif path: + + elif "path" in model_cfg: + path = self._abs_path(model_cfg["path"]) self.logger.info(f"Deleting directory {path}") rmtree(path, ignore_errors=True) - elif repo_id: + + elif "repo_id" in model_cfg: + repo_id = model_cfg["repo_id"] self.logger.info(f"Deleting the cached model directory for {repo_id}") self._delete_model_from_cache(repo_id) def add_model( - self, model_name: str, model_attributes: dict, clobber: bool = False + self, + model_name: str, + model_type: SDModelType, + model_attributes: dict, + clobber: bool = False ) -> None: """ Update the named model with a dictionary of attributes. Will fail with an @@ -518,37 +543,47 @@ class ModelManager(object): method will return True. Will fail with an assertion error if provided attributes are incorrect or the model name is missing. """ - omega = self.config - assert "format" in model_attributes, 'missing required field "format"' - if model_attributes["format"] == "diffusers": - assert ( - "description" in model_attributes - ), 'required field "description" is missing' - assert ( - "path" in model_attributes or "repo_id" in model_attributes - ), 'model must have either the "path" or "repo_id" fields defined' - elif model_attributes["format"] == "ckpt": - for field in ("description", "weights", "height", "width", "config"): - assert field in model_attributes, f"required field {field} is missing" + if model_type == SDModelType.diffusers: + # TODO: automaticaly or manualy? + #assert "format" in model_attributes, 'missing required field "format"' + model_format = "ckpt" if "weights" in model_attributes else "diffusers" + + if model_format == "diffusers": + assert ( + "description" in model_attributes + ), 'required field "description" is missing' + assert ( + "path" in model_attributes or "repo_id" in model_attributes + ), 'model must have either the "path" or "repo_id" fields defined' + + elif model_format == "ckpt": + for field in ("description", "weights", "height", "width", "config"): + assert field in model_attributes, f"required field {field} is missing" + else: assert "weights" in model_attributes and "description" in model_attributes - model_key = f'{model_name}/{model_attributes["format"]}' + model_key = self.create_key(model_name, model_type) assert ( - clobber or model_key not in omega + clobber or model_key not in self.config ), f'attempt to overwrite existing model definition "{model_key}"' - omega[model_key] = model_attributes + self.config[model_key] = model_attributes - if "weights" in omega[model_key]: - omega[model_key]["weights"].replace("\\", "/") + if "weights" in self.config[model_key]: + self.config[model_key]["weights"].replace("\\", "/") if clobber and model_key in self.cache_keys: self.cache.uncache_model(self.cache_keys[model_key]) del self.cache_keys[model_key] + + + + + def import_diffuser_model( self, repo_or_path: Union[str, Path], @@ -587,10 +622,10 @@ class ModelManager(object): return model_key def import_lora( - self, - path: Path, - model_name: str=None, - description: str=None, + self, + path: Path, + model_name: str=None, + description: str=None, ): """ Creates an entry for the indicated lora file. Call @@ -599,20 +634,21 @@ class ModelManager(object): path = Path(path) model_name = model_name or path.stem model_description = description or f"LoRA model {model_name}" - self.add_model(f'{model_name}/{SDModelType.lora.name}', - dict( - format="lora", - weights=str(path), - description=model_description, - ), - True - ) + self.add_model( + f'{model_name}/{SDModelType.lora.name}', + dict( + format="lora", + weights=str(path), + description=model_description, + ), + True + ) def import_embedding( - self, - path: Path, - model_name: str=None, - description: str=None, + self, + path: Path, + model_name: str=None, + description: str=None, ): """ Creates an entry for the indicated lora file. Call @@ -626,14 +662,15 @@ class ModelManager(object): model_name = model_name or path.stem model_description = description or f"Textual embedding model {model_name}" - self.add_model(f'{model_name}/{SDModelType.textual_inversion.name}', - dict( - format="textual_inversion", - weights=str(weights), - description=model_description, - ), - True - ) + self.add_model( + f'{model_name}/{SDModelType.textual_inversion.name}', + dict( + format="textual_inversion", + weights=str(weights), + description=model_description, + ), + True + ) @classmethod def probe_model_type(self, checkpoint: dict) -> SDLegacyType: @@ -857,7 +894,7 @@ class ModelManager(object): ) return model_name - def convert_ckpt_and_cache(self, mconfig:DictConfig)->Path: + def convert_ckpt_and_cache(self, mconfig: DictConfig)->Path: """ Convert the checkpoint model indicated in mconfig into a diffusers, cache it to disk, and return Path to converted @@ -872,6 +909,7 @@ class ModelManager(object): return diffusers_path vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig) + # to avoid circular import errors from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers with SilenceWarnings(): @@ -881,15 +919,16 @@ class ModelManager(object): extract_ema=True, original_config_file=config_file, vae=vae_model, - vae_path=str(global_resolve_path(vae_ckpt_path)), + vae_path=str(global_resolve_path(vae_ckpt_path)) if vae_ckpt_path else None, scan_needed=True, ) return diffusers_path - def _get_vae_for_conversion(self, - weights: Path, - mconfig: DictConfig - )->tuple(Path,SDModelType.vae): + def _get_vae_for_conversion( + self, + weights: Path, + mconfig: DictConfig + ) -> Tuple[Path, SDModelType.vae]: # VAE handling is convoluted # 1. If there is a .vae.ckpt file sharing same stem as weights, then use # it as the vae_path passed to convert @@ -1047,21 +1086,7 @@ class ModelManager(object): # model requires a model config file, a weights file, # and the width and height of the images it # was trained on. - """ - ) - - def _disambiguate_name(self, model_name:str, model_type:SDModelType)->str: - model_type = model_type or SDModelType.diffusers - full_name = f"{model_name}/{model_type.name}" - if full_name in self.config: - return full_name - # special case - if diffusers requested, then allow name without type appended - if model_type==SDModelType.diffusers \ - and model_name in self.config \ - and self.config[model_name].format=='diffusers': - return model_name - raise InvalidModelError( - f'"{full_name}" is not a known model name. Please check your models.yaml file' + """ )