From 9cb962cad7fedcbfcb138fa4d7c2420b0ce1b037 Mon Sep 17 00:00:00 2001
From: Lincoln Stein <lstein@gmail.com>
Date: Mon, 8 May 2023 23:39:44 -0400
Subject: [PATCH] ckpt model conversion now done in ModelManager

---
 invokeai/backend/generate.py                  |   6 +-
 invokeai/backend/globals.py                   |   2 +
 .../backend/model_management/model_cache.py   |  34 ++--
 .../backend/model_management/model_manager.py | 147 ++++++++++++++----
 invokeai/frontend/CLI/CLI.py                  |   6 +-
 5 files changed, 134 insertions(+), 61 deletions(-)

diff --git a/invokeai/backend/generate.py b/invokeai/backend/generate.py
index c7e2558db1..8cddc1496b 100644
--- a/invokeai/backend/generate.py
+++ b/invokeai/backend/generate.py
@@ -150,7 +150,7 @@ class Generate:
         esrgan=None,
         free_gpu_mem: bool = False,
         safety_checker: bool = False,
-        max_loaded_models: int = 2,
+        max_cache_size: int = 6,
         # these are deprecated; if present they override values in the conf file
         weights=None,
         config=None,
@@ -183,7 +183,7 @@ class Generate:
         self.codeformer = codeformer
         self.esrgan = esrgan
         self.free_gpu_mem = free_gpu_mem
-        self.max_loaded_models = (max_loaded_models,)
+        self.max_cache_size = max_cache_size
         self.size_matters = True  # used to warn once about large image sizes and VRAM
         self.txt2mask = None
         self.safety_checker = None
@@ -220,7 +220,7 @@ class Generate:
             conf,
             self.device,
             torch_dtype(self.device),
-            max_loaded_models=max_loaded_models,
+            max_cache_size=max_cache_size,
             sequential_offload=self.free_gpu_mem,
 #            embedding_path=Path(self.embedding_path),
         )
diff --git a/invokeai/backend/globals.py b/invokeai/backend/globals.py
index 37a59b1135..5106ddb67d 100644
--- a/invokeai/backend/globals.py
+++ b/invokeai/backend/globals.py
@@ -94,6 +94,8 @@ def global_set_root(root_dir: Union[str, Path]):
     Globals.root = root_dir
 
 def global_resolve_path(path: Union[str,Path]):
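+    # a None path passes through unchanged so that optional config
+    # entries (e.g. a missing 'path' or 'vae') are not coerced into Globals.root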
+    if path is None:
+        return None
     return Path(Globals.root,path).resolve()
 
 def global_cache_dir(subdir: Union[str, Path] = "") -> Path:
diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py
index 173fd87623..b8f44f82ec 100644
--- a/invokeai/backend/model_management/model_cache.py
+++ b/invokeai/backend/model_management/model_cache.py
@@ -361,9 +361,10 @@ class ModelCache(object):
                )->ModelStatus:
         key = self._model_key(
             repo_id_or_path,
-            model_type.value,
             revision,
-            subfolder)
+            subfolder,
+            model_type.value,
+        )
         if key not in self.models:
             return ModelStatus.not_loaded
         if key in self.loaded_models:
@@ -384,9 +385,7 @@ class ModelCache(object):
         :param revision: optional revision string (if fetching a HF repo_id)
         '''
         revision = revision or "main"
-        if self.is_legacy_ckpt(repo_id_or_path):
-            return self._legacy_model_hash(repo_id_or_path)
-        elif Path(repo_id_or_path).is_dir():
+        if Path(repo_id_or_path).is_dir():
             return self._local_model_hash(repo_id_or_path)
         else:
             return self._hf_commit_hash(repo_id_or_path,revision)
@@ -395,15 +394,6 @@ class ModelCache(object):
         "Return the current size of the cache, in GB"
         return self.current_cache_size / GIG
 
-    @classmethod
-    def is_legacy_ckpt(cls, repo_id_or_path: Union[str,Path])->bool:
-        '''
-        Return true if the indicated path is a legacy checkpoint
-        :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
-        '''
-        path = Path(repo_id_or_path)
-        return path.suffix in [".ckpt",".safetensors",".pt"]
-
     @classmethod
     def scan_model(cls, model_name, checkpoint):
         """
@@ -482,16 +472,12 @@ class ModelCache(object):
         '''
         # silence transformer and diffuser warnings
         with SilenceWarnings():
-            # !!! NOTE: conversion should not happen here, but in ModelManager
-            if self.is_legacy_ckpt(repo_id_or_path):
-                model = self._load_ckpt_from_storage(repo_id_or_path, legacy_info)
-            else:
-                model = self._load_diffusers_from_storage(
-                    repo_id_or_path,
-                    subfolder,
-                    revision,
-                    model_class,
-                )
+            model = self._load_diffusers_from_storage(
+                repo_id_or_path,
+                subfolder,
+                revision,
+                model_class,
+            )
             if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
                 model.enable_offload_submodels(self.execution_device)
         return model
diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py
index 94e514a013..4c0b1b3ad9 100644
--- a/invokeai/backend/model_management/model_manager.py
+++ b/invokeai/backend/model_management/model_manager.py
@@ -143,7 +143,7 @@ from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
 
 from invokeai.backend.globals import Globals, global_cache_dir, global_resolve_path
-from .model_cache import ModelCache, ModelLocker, SDModelType, ModelStatus, LegacyInfo
+from .model_cache import ModelCache, ModelLocker, SDModelType, ModelStatus, SilenceWarnings
 
 from ..util import CUDA_DEVICE
 
@@ -225,12 +225,16 @@ class ModelManager(object):
         self.cache_keys = dict()
         self.logger = logger
 
-    def valid_model(self, model_name: str) -> bool:
+    def valid_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> bool:
         """
         Given a model name, returns True if it is a valid
         identifier.
         """
-        return model_name in self.config
+        try:
+            self._disambiguate_name(model_name, model_type)
+            return True
+        except InvalidModelError:
+            return False
 
     def get_model(self,
                   model_name: str,
@@ -294,17 +298,17 @@ class ModelManager(object):
         model_parts = dict([(x.name,x) for x in SDModelType])
         legacy = None
         
-        if format=='ckpt':
-            location = global_resolve_path(mconfig.weights)
-            legacy = LegacyInfo(
-                config_file = global_resolve_path(mconfig.config),
-            )
-            if mconfig.get('vae'):
-                legacy.vae_file = global_resolve_path(mconfig.vae)
-        elif format=='diffusers':
-            location = mconfig.get('repo_id') or mconfig.get('path')
+        if format == 'diffusers':
+            # intercept stanzas that point to checkpoint weights and replace them
+            # with the equivalent diffusers model
+            if 'weights' in mconfig:
+                location = self.convert_ckpt_and_cache(mconfig)
+            else:
+                location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id')
         elif format in model_parts:
-            location = mconfig.get('repo_id') or mconfig.get('path') or mconfig.get('weights')
+            location = global_resolve_path(mconfig.get('path')) \
+                or mconfig.get('repo_id') \
+                or global_resolve_path(mconfig.get('weights'))
         else:
             raise InvalidModelError(
                 f'"{model_key}" has an unknown format {format}'
@@ -531,7 +535,7 @@ class ModelManager(object):
         else:
             assert "weights" in model_attributes and "description" in model_attributes
 
-        model_key = f'{model_name}/{format}'
+        model_key = f'{model_name}/{model_attributes["format"]}'
 
         assert (
             clobber or model_key not in omega
@@ -776,7 +780,7 @@ class ModelManager(object):
         # another round of heuristics to guess the correct config file.
         checkpoint = None
         if model_path.suffix in [".ckpt", ".pt"]:
-            self.scan_model(model_path, model_path)
+            self.cache.scan_model(model_path, model_path)
             checkpoint = torch.load(model_path)
         else:
             checkpoint = safetensors.torch.load_file(model_path)
@@ -840,19 +844,86 @@ class ModelManager(object):
         diffuser_path = Path(
             Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
         )
-        model_name = self.convert_and_import(
-            model_path,
-            diffusers_path=diffuser_path,
-            vae=vae,
-            vae_path=str(vae_path),
-            model_name=model_name,
-            model_description=description,
-            original_config_file=model_config_file,
-            commit_to_conf=commit_to_conf,
-            scan_needed=False,
-        )
+        with SilenceWarnings():
+            model_name = self.convert_and_import(
+                model_path,
+                diffusers_path=diffuser_path,
+                vae=vae,
+                vae_path=str(vae_path),
+                model_name=model_name,
+                model_description=description,
+                original_config_file=model_config_file,
+                commit_to_conf=commit_to_conf,
+                scan_needed=False,
+            )
         return model_name
 
+    def convert_ckpt_and_cache(self, mconfig:DictConfig)->Path:
+        """
+        Convert the checkpoint model indicated in mconfig into a
+        diffusers model, cache it to disk, and return the Path to the
+        converted model. If it is already on disk, just return the Path.
+        """
+        weights = global_resolve_path(mconfig.weights)
+        config_file = global_resolve_path(mconfig.config)
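+        # converted copies are cached on disk at <Globals.root>/models/<converted_ckpts_dir>/<checkpoint stem>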
+        diffusers_path = global_resolve_path(Path('models',Globals.converted_ckpts_dir)) / weights.stem
+
+        # return cached version if it exists
+        if diffusers_path.exists():
+            return diffusers_path
+
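+        # at most one of (vae_ckpt_path, vae_model) comes back non-None; see _get_vae_for_conversion() below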
+        vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
+        # to avoid circular import errors
+        from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
+        with SilenceWarnings():
+            convert_ckpt_to_diffusers(
+                weights,
+                diffusers_path,
+                extract_ema=True,
+                original_config_file=config_file,
+                vae=vae_model,
+                vae_path=str(global_resolve_path(vae_ckpt_path)) if vae_ckpt_path else None,
+                scan_needed=True,
+            )
+        return diffusers_path
+
+    def _get_vae_for_conversion(self,
+                                weights: Path,
+                                mconfig: DictConfig
+                                )->tuple:  # (vae checkpoint Path or None, loaded diffusers VAE or None)
+        # VAE handling is convoluted
+        # 1. If there is a .vae.ckpt file sharing same stem as weights, then use
+        # it as the vae_path passed to convert
+        vae_ckpt_path = None
+        vae_diffusers_location = None
+        vae_model = None
+        for suffix in ["pt", "ckpt", "safetensors"]:
+            if (weights.with_suffix(f".vae.{suffix}")).exists():
+                vae_ckpt_path = weights.with_suffix(f".vae.{suffix}")
+                self.logger.debug(f"Using VAE file {vae_ckpt_path.name}")
+        if vae_ckpt_path:
+            return (vae_ckpt_path, None)
+
+        # 2. If mconfig has a vae weights path, then we use that as vae_path
+        vae_config = mconfig.get('vae')
+        if vae_config and isinstance(vae_config,str):
+            vae_ckpt_path = vae_config
+            return (vae_ckpt_path, None)
+
+        # 3. If mconfig has a vae dict, then we use it as the diffusers-style vae
+        if vae_config and isinstance(vae_config,DictConfig):
+            vae_diffusers_location = global_resolve_path(vae_config.get('path')) or vae_config.get('repo_id')
+
+        # 4. Otherwise, we use stabilityai/sd-vae-ft-mse "because it works"
+        else:
+            vae_diffusers_location = "stabilityai/sd-vae-ft-mse"
+
+        if vae_diffusers_location:
+            vae_model = self.cache.get_model(vae_diffusers_location, SDModelType.vae).model
+            return (None, vae_model)
+
+        return (None, None)
+
     def convert_and_import(
         self,
         ckpt_path: Path,
@@ -895,7 +966,8 @@ class ModelManager(object):
             # will be built into the model rather than tacked on afterward via the config file
             vae_model = None
             if vae:
-                vae_model = self._load_vae(vae)
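+                # resolve a local path first, falling back to a HF repo_id, and load the VAE through the model cache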
+                vae_location = global_resolve_path(vae.get('path')) or vae.get('repo_id')
+                vae_model = self.cache.get_model(vae_location,SDModelType.vae).model
                 vae_path = None
             convert_ckpt_to_diffusers(
                 ckpt_path,
@@ -982,9 +1054,9 @@ class ModelManager(object):
     def _disambiguate_name(self, model_name:str, model_type:SDModelType)->str:
         model_type = model_type or SDModelType.diffusers
         full_name = f"{model_name}/{model_type.name}"
-        if self.valid_model(full_name):
+        if full_name in self.config:
             return full_name
-        if self.valid_model(model_name):
+        if model_name in self.config:
             return model_name
         raise InvalidModelError(
             f'Neither "{model_name}" nor "{full_name}" are known model names. Please check your models.yaml file'
@@ -1014,3 +1086,20 @@ class ModelManager(object):
             return path
         return Path(Globals.root, path).resolve()
 
+    # Unlike global_resolve_path(), this also handles remote URLs by
+    # downloading them into dest_directory; local relative paths are
+    # still resolved against Globals.root.
+    def _resolve_path(
+        self, source: Union[str, Path], dest_directory: str
+    ) -> Optional[Path]:
+        resolved_path = None
+        if str(source).startswith(("http:", "https:", "ftp:")):
+            dest_directory = Path(dest_directory)
+            if not dest_directory.is_absolute():
+                dest_directory = Globals.root / dest_directory
+            dest_directory.mkdir(parents=True, exist_ok=True)
+            resolved_path = download_with_resume(str(source), dest_directory)
+        else:
+            if not os.path.isabs(source):
+                source = os.path.join(Globals.root, source)
+            resolved_path = Path(source)
+        return resolved_path
diff --git a/invokeai/frontend/CLI/CLI.py b/invokeai/frontend/CLI/CLI.py
index 0c984080a6..8525853e93 100644
--- a/invokeai/frontend/CLI/CLI.py
+++ b/invokeai/frontend/CLI/CLI.py
@@ -54,10 +54,6 @@ def main():
             "--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead."
         )
         sys.exit(-1)
-    if args.max_loaded_models is not None:
-        if args.max_loaded_models <= 0:
-            print("--max_loaded_models must be >= 1; using 1")
-            args.max_loaded_models = 1
 
     # alert - setting a few globals here
     Globals.try_patchmatch = args.patchmatch
@@ -136,7 +132,7 @@ def main():
             esrgan=esrgan,
             free_gpu_mem=opt.free_gpu_mem,
             safety_checker=opt.safety_checker,
-            max_loaded_models=opt.max_loaded_models,
+            max_cache_size=opt.max_cache_size,
         )
     except (FileNotFoundError, TypeError, AssertionError) as e:
         report_model_error(opt, e)