diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index a0a899a319..4f94395a86 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -30,7 +30,7 @@ from diffusers import ( UNet2DConditionModel, SchedulerMixin, logging as dlogging, -) +) from huggingface_hub import scan_cache_dir from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig @@ -68,7 +68,7 @@ class SDModelComponent(Enum): scheduler="scheduler" safety_checker="safety_checker" feature_extractor="feature_extractor" - + DEFAULT_MAX_MODELS = 2 class ModelManager(object): @@ -182,7 +182,7 @@ class ModelManager(object): vae from the model currently in the GPU. """ return self._get_sub_model(model_name, SDModelComponent.vae) - + def get_model_tokenizer(self, model_name: str=None)->CLIPTokenizer: """Given a model name identified in models.yaml, load the model into GPU if necessary and return its assigned CLIPTokenizer. If no @@ -190,12 +190,12 @@ class ModelManager(object): currently in the GPU. """ return self._get_sub_model(model_name, SDModelComponent.tokenizer) - + def get_model_unet(self, model_name: str=None)->UNet2DConditionModel: """Given a model name identified in models.yaml, load the model into GPU if necessary and return its assigned UNet2DConditionModel. If no model name is provided, return the UNet from the model - currently in the GPU. + currently in the GPU. """ return self._get_sub_model(model_name, SDModelComponent.unet) @@ -222,7 +222,7 @@ class ModelManager(object): currently in the GPU. """ return self._get_sub_model(model_name, SDModelComponent.scheduler) - + def _get_sub_model( self, model_name: str=None, @@ -1228,7 +1228,7 @@ class ModelManager(object): sha.update(chunk) hash = sha.hexdigest() toc = time.time() - self.logger.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic)) + self.logger.debug(f"sha256 = {hash} ({count} files hashed in {toc - tic:4.2f}s)") with open(hashpath, "w") as f: f.write(hash) return hash