diff --git a/invokeai/backend/__init__.py b/invokeai/backend/__init__.py index dd126a322d..06066dd6b1 100644 --- a/invokeai/backend/__init__.py +++ b/invokeai/backend/__init__.py @@ -10,7 +10,7 @@ from .generator import ( Img2Img, Inpaint ) -from .model_management import ModelManager, SDModelComponent +from .model_management import ModelManager from .safety_checker import SafetyChecker from .args import Args from .globals import Globals diff --git a/invokeai/backend/model_management/__init__.py b/invokeai/backend/model_management/__init__.py index 4a82da849a..1d290050d4 100644 --- a/invokeai/backend/model_management/__init__.py +++ b/invokeai/backend/model_management/__init__.py @@ -5,6 +5,6 @@ from .convert_ckpt_to_diffusers import ( convert_ckpt_to_diffusers, load_pipeline_from_original_stable_diffusion_ckpt, ) -from .model_manager import ModelManager,SDModelComponent +from .model_manager import ModelManager diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index c76be93e8f..a51a2fec22 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -27,6 +27,7 @@ import transformers from diffusers import ( AutoencoderKL, UNet2DConditionModel, + SchedulerMixin, logging as dlogging, ) from huggingface_hub import scan_cache_dir @@ -169,7 +170,55 @@ class ModelManager(object): "hash": hash, } - def get_sub_model( + def get_model_vae(self, model_name: str=None)->AutoencoderKL: + """Given a model name identified in models.yaml, load the model into + GPU if necessary and return its assigned VAE as an + AutoencoderKL object. If no model name is provided, return the + vae from the model currently in the GPU. + """ + return self._get_sub_model(model_name, SDModelComponent.vae) + + def get_model_tokenizer(self, model_name: str=None)->CLIPTokenizer: + """Given a model name identified in models.yaml, load the model into + GPU if necessary and return its assigned CLIPTokenizer. If no + model name is provided, return the tokenizer from the model + currently in the GPU. + """ + return self._get_sub_model(model_name, SDModelComponent.tokenizer) + + def get_model_unet(self, model_name: str=None)->UNet2DConditionModel: + """Given a model name identified in models.yaml, load the model into + GPU if necessary and return its assigned UNet2DConditionModel. If no model + name is provided, return the UNet from the model + currently in the GPU. + """ + return self._get_sub_model(model_name, SDModelComponent.unet) + + def get_model_text_encoder(self, model_name: str=None)->CLIPTextModel: + """Given a model name identified in models.yaml, load the model into + GPU if necessary and return its assigned CLIPTextModel. If no + model name is provided, return the text encoder from the model + currently in the GPU. + """ + return self._get_sub_model(model_name, SDModelComponent.text_encoder) + + def get_model_feature_extractor(self, model_name: str=None)->CLIPFeatureExtractor: + """Given a model name identified in models.yaml, load the model into + GPU if necessary and return its assigned CLIPFeatureExtractor. If no + model name is provided, return the text encoder from the model + currently in the GPU. + """ + return self._get_sub_model(model_name, SDModelComponent.feature_extractor) + + def get_model_scheduler(self, model_name: str=None)->SchedulerMixin: + """Given a model name identified in models.yaml, load the model into + GPU if necessary and return its assigned scheduler. If no + model name is provided, return the text encoder from the model + currently in the GPU. + """ + return self._get_sub_model(model_name, SDModelComponent.scheduler) + + def _get_sub_model( self, model_name: str=None, model_part: SDModelComponent=SDModelComponent.vae, @@ -181,7 +230,7 @@ class ModelManager(object): CLIPTextModel, StableDiffusionSafetyChecker, ]: - """Given a model named identified in models.yaml, and the part of the + """Given a model name identified in models.yaml, and the part of the model you wish to retrieve, return that part. Parts are in an Enum class named SDModelComponent, and consist of: SDModelComponent.vae @@ -190,7 +239,7 @@ class ModelManager(object): SDModelComponent.unet SDModelComponent.scheduler SDModelComponent.safety_checker - SDModelComponent.feature_etractor + SDModelComponent.feature_extractor """ model_dict = self.get_model(model_name) model = model_dict["model"]