mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor get_submodels() into individual methods
This commit is contained in:
parent
d44151d6ff
commit
4c339dd4b0
@ -10,7 +10,7 @@ from .generator import (
|
|||||||
Img2Img,
|
Img2Img,
|
||||||
Inpaint
|
Inpaint
|
||||||
)
|
)
|
||||||
from .model_management import ModelManager, SDModelComponent
|
from .model_management import ModelManager
|
||||||
from .safety_checker import SafetyChecker
|
from .safety_checker import SafetyChecker
|
||||||
from .args import Args
|
from .args import Args
|
||||||
from .globals import Globals
|
from .globals import Globals
|
||||||
|
@ -5,6 +5,6 @@ from .convert_ckpt_to_diffusers import (
|
|||||||
convert_ckpt_to_diffusers,
|
convert_ckpt_to_diffusers,
|
||||||
load_pipeline_from_original_stable_diffusion_ckpt,
|
load_pipeline_from_original_stable_diffusion_ckpt,
|
||||||
)
|
)
|
||||||
from .model_manager import ModelManager,SDModelComponent
|
from .model_manager import ModelManager
|
||||||
|
|
||||||
|
|
||||||
|
@ -27,6 +27,7 @@ import transformers
|
|||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
UNet2DConditionModel,
|
UNet2DConditionModel,
|
||||||
|
SchedulerMixin,
|
||||||
logging as dlogging,
|
logging as dlogging,
|
||||||
)
|
)
|
||||||
from huggingface_hub import scan_cache_dir
|
from huggingface_hub import scan_cache_dir
|
||||||
@ -169,7 +170,55 @@ class ModelManager(object):
|
|||||||
"hash": hash,
|
"hash": hash,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_sub_model(
|
def get_model_vae(self, model_name: str=None)->AutoencoderKL:
|
||||||
|
"""Given a model name identified in models.yaml, load the model into
|
||||||
|
GPU if necessary and return its assigned VAE as an
|
||||||
|
AutoencoderKL object. If no model name is provided, return the
|
||||||
|
vae from the model currently in the GPU.
|
||||||
|
"""
|
||||||
|
return self._get_sub_model(model_name, SDModelComponent.vae)
|
||||||
|
|
||||||
|
def get_model_tokenizer(self, model_name: str=None)->CLIPTokenizer:
|
||||||
|
"""Given a model name identified in models.yaml, load the model into
|
||||||
|
GPU if necessary and return its assigned CLIPTokenizer. If no
|
||||||
|
model name is provided, return the tokenizer from the model
|
||||||
|
currently in the GPU.
|
||||||
|
"""
|
||||||
|
return self._get_sub_model(model_name, SDModelComponent.tokenizer)
|
||||||
|
|
||||||
|
def get_model_unet(self, model_name: str=None)->UNet2DConditionModel:
|
||||||
|
"""Given a model name identified in models.yaml, load the model into
|
||||||
|
GPU if necessary and return its assigned UNet2DConditionModel. If no model
|
||||||
|
name is provided, return the UNet from the model
|
||||||
|
currently in the GPU.
|
||||||
|
"""
|
||||||
|
return self._get_sub_model(model_name, SDModelComponent.unet)
|
||||||
|
|
||||||
|
def get_model_text_encoder(self, model_name: str=None)->CLIPTextModel:
|
||||||
|
"""Given a model name identified in models.yaml, load the model into
|
||||||
|
GPU if necessary and return its assigned CLIPTextModel. If no
|
||||||
|
model name is provided, return the text encoder from the model
|
||||||
|
currently in the GPU.
|
||||||
|
"""
|
||||||
|
return self._get_sub_model(model_name, SDModelComponent.text_encoder)
|
||||||
|
|
||||||
|
def get_model_feature_extractor(self, model_name: str=None)->CLIPFeatureExtractor:
|
||||||
|
"""Given a model name identified in models.yaml, load the model into
|
||||||
|
GPU if necessary and return its assigned CLIPFeatureExtractor. If no
|
||||||
|
model name is provided, return the text encoder from the model
|
||||||
|
currently in the GPU.
|
||||||
|
"""
|
||||||
|
return self._get_sub_model(model_name, SDModelComponent.feature_extractor)
|
||||||
|
|
||||||
|
def get_model_scheduler(self, model_name: str=None)->SchedulerMixin:
|
||||||
|
"""Given a model name identified in models.yaml, load the model into
|
||||||
|
GPU if necessary and return its assigned scheduler. If no
|
||||||
|
model name is provided, return the text encoder from the model
|
||||||
|
currently in the GPU.
|
||||||
|
"""
|
||||||
|
return self._get_sub_model(model_name, SDModelComponent.scheduler)
|
||||||
|
|
||||||
|
def _get_sub_model(
|
||||||
self,
|
self,
|
||||||
model_name: str=None,
|
model_name: str=None,
|
||||||
model_part: SDModelComponent=SDModelComponent.vae,
|
model_part: SDModelComponent=SDModelComponent.vae,
|
||||||
@ -181,7 +230,7 @@ class ModelManager(object):
|
|||||||
CLIPTextModel,
|
CLIPTextModel,
|
||||||
StableDiffusionSafetyChecker,
|
StableDiffusionSafetyChecker,
|
||||||
]:
|
]:
|
||||||
"""Given a model named identified in models.yaml, and the part of the
|
"""Given a model name identified in models.yaml, and the part of the
|
||||||
model you wish to retrieve, return that part. Parts are in an Enum
|
model you wish to retrieve, return that part. Parts are in an Enum
|
||||||
class named SDModelComponent, and consist of:
|
class named SDModelComponent, and consist of:
|
||||||
SDModelComponent.vae
|
SDModelComponent.vae
|
||||||
@ -190,7 +239,7 @@ class ModelManager(object):
|
|||||||
SDModelComponent.unet
|
SDModelComponent.unet
|
||||||
SDModelComponent.scheduler
|
SDModelComponent.scheduler
|
||||||
SDModelComponent.safety_checker
|
SDModelComponent.safety_checker
|
||||||
SDModelComponent.feature_etractor
|
SDModelComponent.feature_extractor
|
||||||
"""
|
"""
|
||||||
model_dict = self.get_model(model_name)
|
model_dict = self.get_model(model_name)
|
||||||
model = model_dict["model"]
|
model = model_dict["model"]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user