diff --git a/invokeai/backend/generate.py b/invokeai/backend/generate.py
index c7e2558db1..8cddc1496b 100644
--- a/invokeai/backend/generate.py
+++ b/invokeai/backend/generate.py
@@ -150,7 +150,7 @@ class Generate:
         esrgan=None,
         free_gpu_mem: bool = False,
         safety_checker: bool = False,
-        max_loaded_models: int = 2,
+        max_cache_size: int = 6,
         # these are deprecated; if present they override values in the conf file
         weights=None,
         config=None,
@@ -183,7 +183,7 @@ class Generate:
         self.codeformer = codeformer
         self.esrgan = esrgan
         self.free_gpu_mem = free_gpu_mem
-        self.max_loaded_models = (max_loaded_models,)
+        self.max_cache_size = max_cache_size
         self.size_matters = True  # used to warn once about large image sizes and VRAM
         self.txt2mask = None
         self.safety_checker = None
@@ -220,7 +220,7 @@ class Generate:
             conf,
             self.device,
             torch_dtype(self.device),
-            max_loaded_models=max_loaded_models,
+            max_cache_size=max_cache_size,
             sequential_offload=self.free_gpu_mem,
             # embedding_path=Path(self.embedding_path),
         )
diff --git a/invokeai/backend/globals.py b/invokeai/backend/globals.py
index 37a59b1135..5106ddb67d 100644
--- a/invokeai/backend/globals.py
+++ b/invokeai/backend/globals.py
@@ -94,6 +94,8 @@ def global_set_root(root_dir: Union[str, Path]):
     Globals.root = root_dir
 
 def global_resolve_path(path: Union[str,Path]):
+    if path is None:
+        return None
     return Path(Globals.root,path).resolve()
 
 def global_cache_dir(subdir: Union[str, Path] = "") -> Path:
diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py
index 173fd87623..b8f44f82ec 100644
--- a/invokeai/backend/model_management/model_cache.py
+++ b/invokeai/backend/model_management/model_cache.py
@@ -361,9 +361,10 @@ class ModelCache(object):
                )->ModelStatus:
         key = self._model_key(
             repo_id_or_path,
-            model_type.value,
             revision,
-            subfolder)
+            subfolder,
+            model_type.value,
+        )
         if key not in self.models:
             return ModelStatus.not_loaded
         if key in self.loaded_models:
@@ -384,9 +385,7 @@ class ModelCache(object):
         :param revision: optional revision string (if fetching a HF repo_id)
         '''
         revision = revision or "main"
-        if self.is_legacy_ckpt(repo_id_or_path):
-            return self._legacy_model_hash(repo_id_or_path)
-        elif Path(repo_id_or_path).is_dir():
+        if Path(repo_id_or_path).is_dir():
             return self._local_model_hash(repo_id_or_path)
         else:
             return self._hf_commit_hash(repo_id_or_path,revision)
@@ -395,15 +394,6 @@ class ModelCache(object):
         "Return the current size of the cache, in GB"
         return self.current_cache_size / GIG
 
-    @classmethod
-    def is_legacy_ckpt(cls, repo_id_or_path: Union[str,Path])->bool:
-        '''
-        Return true if the indicated path is a legacy checkpoint
-        :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
-        '''
-        path = Path(repo_id_or_path)
-        return path.suffix in [".ckpt",".safetensors",".pt"]
-
     @classmethod
     def scan_model(cls, model_name, checkpoint):
         """
@@ -482,16 +472,12 @@ class ModelCache(object):
         '''
         # silence transformer and diffuser warnings
         with SilenceWarnings():
-            # !!! NOTE: conversion should not happen here, but in ModelManager
-            if self.is_legacy_ckpt(repo_id_or_path):
-                model = self._load_ckpt_from_storage(repo_id_or_path, legacy_info)
-            else:
-                model = self._load_diffusers_from_storage(
-                    repo_id_or_path,
-                    subfolder,
-                    revision,
-                    model_class,
-                )
+            model = self._load_diffusers_from_storage(
+                repo_id_or_path,
+                subfolder,
+                revision,
+                model_class,
+            )
         if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
             model.enable_offload_submodels(self.execution_device)
         return model
diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py
index 94e514a013..4c0b1b3ad9 100644
--- a/invokeai/backend/model_management/model_manager.py
+++ b/invokeai/backend/model_management/model_manager.py
@@ -143,7 +143,7 @@
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
 
 from invokeai.backend.globals import Globals, global_cache_dir, global_resolve_path
-from .model_cache import ModelCache, ModelLocker, SDModelType, ModelStatus, LegacyInfo
+from .model_cache import ModelCache, ModelLocker, SDModelType, ModelStatus, SilenceWarnings
 from ..util import CUDA_DEVICE
 
@@ -225,12 +225,16 @@ class ModelManager(object):
         self.cache_keys = dict()
         self.logger = logger
 
-    def valid_model(self, model_name: str) -> bool:
+    def valid_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> bool:
         """
         Given a model name, returns True if it is a valid
         identifier.
         """
-        return model_name in self.config
+        try:
+            self._disambiguate_name(model_name, model_type)
+            return True
+        except InvalidModelError:
+            return False
 
     def get_model(self,
                   model_name: str,
@@ -294,17 +298,17 @@ class ModelManager(object):
         model_parts = dict([(x.name,x) for x in SDModelType])
 
         legacy = None
-        if format=='ckpt':
-            location = global_resolve_path(mconfig.weights)
-            legacy = LegacyInfo(
-                config_file = global_resolve_path(mconfig.config),
-            )
-            if mconfig.get('vae'):
-                legacy.vae_file = global_resolve_path(mconfig.vae)
-        elif format=='diffusers':
-            location = mconfig.get('repo_id') or mconfig.get('path')
+        if format == 'diffusers':
+            # intercept stanzas that point to checkpoint weights and replace them
+            # with the equivalent diffusers model
+            if 'weights' in mconfig:
+                location = self.convert_ckpt_and_cache(mconfig)
+            else:
+                location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id')
         elif format in model_parts:
-            location = mconfig.get('repo_id') or mconfig.get('path') or mconfig.get('weights')
+            location = global_resolve_path(mconfig.get('path')) \
+                       or mconfig.get('repo_id') \
+                       or global_resolve_path(mconfig.get('weights'))
         else:
             raise InvalidModelError(
                 f'"{model_key}" has an unknown format {format}'
@@ -531,7 +535,7 @@ class ModelManager(object):
         else:
             assert "weights" in model_attributes and "description" in model_attributes
 
-        model_key = f'{model_name}/{format}'
+        model_key = f'{model_name}/{model_attributes["format"]}'
 
         assert (
             clobber or model_key not in omega
@@ -776,7 +780,7 @@ class ModelManager(object):
         # another round of heuristics to guess the correct config file.
         checkpoint = None
         if model_path.suffix in [".ckpt", ".pt"]:
-            self.scan_model(model_path, model_path)
+            self.cache.scan_model(model_path, model_path)
             checkpoint = torch.load(model_path)
         else:
             checkpoint = safetensors.torch.load_file(model_path)
@@ -840,19 +844,86 @@ class ModelManager(object):
         diffuser_path = Path(
             Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
         )
-        model_name = self.convert_and_import(
-            model_path,
-            diffusers_path=diffuser_path,
-            vae=vae,
-            vae_path=str(vae_path),
-            model_name=model_name,
-            model_description=description,
-            original_config_file=model_config_file,
-            commit_to_conf=commit_to_conf,
-            scan_needed=False,
-        )
+        with SilenceWarnings():
+            model_name = self.convert_and_import(
+                model_path,
+                diffusers_path=diffuser_path,
+                vae=vae,
+                vae_path=str(vae_path),
+                model_name=model_name,
+                model_description=description,
+                original_config_file=model_config_file,
+                commit_to_conf=commit_to_conf,
+                scan_needed=False,
+            )
         return model_name
 
+    def convert_ckpt_and_cache(self, mconfig:DictConfig)->Path:
+        """
+        Convert the checkpoint model indicated in mconfig into a
+        diffusers, cache it to disk, and return Path to converted
+        file. If already on disk then just returns Path.
+        """
+        weights = global_resolve_path(mconfig.weights)
+        config_file = global_resolve_path(mconfig.config)
+        diffusers_path = global_resolve_path(Path('models',Globals.converted_ckpts_dir)) / weights.stem
+
+        # return cached version if it exists
+        if diffusers_path.exists():
+            return diffusers_path
+
+        vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
+        # to avoid circular import errors
+        from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
+        with SilenceWarnings():
+            convert_ckpt_to_diffusers(
+                weights,
+                diffusers_path,
+                extract_ema=True,
+                original_config_file=config_file,
+                vae=vae_model,
+                vae_path=str(global_resolve_path(vae_ckpt_path)),
+                scan_needed=True,
+            )
+        return diffusers_path
+
+    def _get_vae_for_conversion(self,
+                                weights: Path,
+                                mconfig: DictConfig
+                                )->tuple(Path,SDModelType.vae):
+        # VAE handling is convoluted
+        # 1. If there is a .vae.ckpt file sharing same stem as weights, then use
+        #    it as the vae_path passed to convert
+        vae_ckpt_path = None
+        vae_diffusers_location = None
+        vae_model = None
+        for suffix in ["pt", "ckpt", "safetensors"]:
+            if (weights.with_suffix(f".vae.{suffix}")).exists():
+                vae_ckpt_path = weights.with_suffix(f".vae.{suffix}")
+                self.logger.debug(f"Using VAE file {vae_ckpt_path.name}")
+        if vae_ckpt_path:
+            return (vae_ckpt_path, None)
+
+        # 2. If mconfig has a vae weights path, then we use that as vae_path
+        vae_config = mconfig.get('vae')
+        if vae_config and isinstance(vae_config,str):
+            vae_ckpt_path = vae_config
+            return (vae_ckpt_path, None)
+
+        # 3. If mconfig has a vae dict, then we use it as the diffusers-style vae
+        if vae_config and isinstance(vae_config,DictConfig):
+            vae_diffusers_location = global_resolve_path(vae_config.get('path')) or vae_config.get('repo_id')
+
+        # 4. Otherwise, we use stabilityai/sd-vae-ft-mse "because it works"
+        else:
+            vae_diffusers_location = "stabilityai/sd-vae-ft-mse"
+
+        if vae_diffusers_location:
+            vae_model = self.cache.get_model(vae_diffusers_location, SDModelType.vae).model
+            return (None, vae_model)
+
+        return (None, None)
+
     def convert_and_import(
         self,
         ckpt_path: Path,
@@ -895,7 +966,8 @@ class ModelManager(object):
         # will be built into the model rather than tacked on afterward via the config file
         vae_model = None
         if vae:
-            vae_model = self._load_vae(vae)
+            vae_location = global_resolve_path(vae.get('path')) or vae.get('repo_id')
+            vae_model = self.cache.get_model(vae_location,SDModelType.vae).model
             vae_path = None
         convert_ckpt_to_diffusers(
             ckpt_path,
@@ -982,9 +1054,9 @@ class ModelManager(object):
     def _disambiguate_name(self, model_name:str, model_type:SDModelType)->str:
         model_type = model_type or SDModelType.diffusers
         full_name = f"{model_name}/{model_type.name}"
-        if self.valid_model(full_name):
+        if full_name in self.config:
             return full_name
-        if self.valid_model(model_name):
+        if model_name in self.config:
             return model_name
         raise InvalidModelError(
             f'Neither "{model_name}" nor "{full_name}" are known model names. Please check your models.yaml file'
@@ -1014,3 +1086,20 @@ class ModelManager(object):
             return path
         return Path(Globals.root, path).resolve()
 
+    # This is not the same as global_resolve_path(), which prepends
+    # Globals.root.
+    def _resolve_path(
+        self, source: Union[str, Path], dest_directory: str
+    ) -> Optional[Path]:
+        resolved_path = None
+        if str(source).startswith(("http:", "https:", "ftp:")):
+            dest_directory = Path(dest_directory)
+            if not dest_directory.is_absolute():
+                dest_directory = Globals.root / dest_directory
+            dest_directory.mkdir(parents=True, exist_ok=True)
+            resolved_path = download_with_resume(str(source), dest_directory)
+        else:
+            if not os.path.isabs(source):
+                source = os.path.join(Globals.root, source)
+            resolved_path = Path(source)
+        return resolved_path
diff --git a/invokeai/frontend/CLI/CLI.py b/invokeai/frontend/CLI/CLI.py
index 0c984080a6..8525853e93 100644
--- a/invokeai/frontend/CLI/CLI.py
+++ b/invokeai/frontend/CLI/CLI.py
@@ -54,10 +54,6 @@ def main():
             "--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead."
         )
         sys.exit(-1)
-    if args.max_loaded_models is not None:
-        if args.max_loaded_models <= 0:
-            print("--max_loaded_models must be >= 1; using 1")
-            args.max_loaded_models = 1
 
     # alert - setting a few globals here
     Globals.try_patchmatch = args.patchmatch
@@ -136,7 +132,7 @@ def main():
             esrgan=esrgan,
             free_gpu_mem=opt.free_gpu_mem,
             safety_checker=opt.safety_checker,
-            max_loaded_models=opt.max_loaded_models,
+            max_cache_size=opt.max_cache_size,
         )
     except (FileNotFoundError, TypeError, AssertionError) as e:
         report_model_error(opt, e)
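Reviewer note, not part of the patch: a minimal usage sketch of the reworked cache plumbing. It assumes only what is visible in the hunks above, namely the positional (conf, device, dtype) arguments plus the max_cache_size and sequential_offload keywords passed in the generate.py hunk, and the type-aware valid_model()/get_model() signatures in model_manager.py. The config path and model name below are placeholders, and the constructor call is an assumption drawn from that generate.py call site rather than a documented API.

    # Hypothetical sketch: wiring mirrors the generate.py hunk above.
    import torch
    from omegaconf import OmegaConf

    from invokeai.backend.model_management.model_cache import SDModelType
    from invokeai.backend.model_management.model_manager import ModelManager

    conf = OmegaConf.load("configs/models.yaml")   # placeholder config path
    manager = ModelManager(
        conf,
        torch.device("cuda"),
        torch.float16,
        max_cache_size=6,        # cache budget is now expressed in GB, replacing max_loaded_models
        sequential_offload=False,
    )

    # valid_model() now takes an optional model type and resolves "name" vs. "name/<type>" keys
    if manager.valid_model("stable-diffusion-1.5", SDModelType.diffusers):   # placeholder model name
        model_info = manager.get_model("stable-diffusion-1.5")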