Merge branch 'main' into docs/ui/update-ui-readme

This commit is contained in:
blessedcoolant 2023-05-12 15:04:12 +12:00 committed by GitHub
commit 85d03dcd90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -30,7 +30,7 @@ from diffusers import (
UNet2DConditionModel, UNet2DConditionModel,
SchedulerMixin, SchedulerMixin,
logging as dlogging, logging as dlogging,
) )
from huggingface_hub import scan_cache_dir from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig from omegaconf.dictconfig import DictConfig
@ -68,7 +68,7 @@ class SDModelComponent(Enum):
scheduler="scheduler" scheduler="scheduler"
safety_checker="safety_checker" safety_checker="safety_checker"
feature_extractor="feature_extractor" feature_extractor="feature_extractor"
DEFAULT_MAX_MODELS = 2 DEFAULT_MAX_MODELS = 2
class ModelManager(object): class ModelManager(object):
@ -182,7 +182,7 @@ class ModelManager(object):
vae from the model currently in the GPU. vae from the model currently in the GPU.
""" """
return self._get_sub_model(model_name, SDModelComponent.vae) return self._get_sub_model(model_name, SDModelComponent.vae)
def get_model_tokenizer(self, model_name: str=None)->CLIPTokenizer: def get_model_tokenizer(self, model_name: str=None)->CLIPTokenizer:
"""Given a model name identified in models.yaml, load the model into """Given a model name identified in models.yaml, load the model into
GPU if necessary and return its assigned CLIPTokenizer. If no GPU if necessary and return its assigned CLIPTokenizer. If no
@ -190,12 +190,12 @@ class ModelManager(object):
currently in the GPU. currently in the GPU.
""" """
return self._get_sub_model(model_name, SDModelComponent.tokenizer) return self._get_sub_model(model_name, SDModelComponent.tokenizer)
def get_model_unet(self, model_name: str=None)->UNet2DConditionModel: def get_model_unet(self, model_name: str=None)->UNet2DConditionModel:
"""Given a model name identified in models.yaml, load the model into """Given a model name identified in models.yaml, load the model into
GPU if necessary and return its assigned UNet2DConditionModel. If no model GPU if necessary and return its assigned UNet2DConditionModel. If no model
name is provided, return the UNet from the model name is provided, return the UNet from the model
currently in the GPU. currently in the GPU.
""" """
return self._get_sub_model(model_name, SDModelComponent.unet) return self._get_sub_model(model_name, SDModelComponent.unet)
@ -222,7 +222,7 @@ class ModelManager(object):
currently in the GPU. currently in the GPU.
""" """
return self._get_sub_model(model_name, SDModelComponent.scheduler) return self._get_sub_model(model_name, SDModelComponent.scheduler)
def _get_sub_model( def _get_sub_model(
self, self,
model_name: str=None, model_name: str=None,
@ -1228,7 +1228,7 @@ class ModelManager(object):
sha.update(chunk) sha.update(chunk)
hash = sha.hexdigest() hash = sha.hexdigest()
toc = time.time() toc = time.time()
self.logger.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic)) self.logger.debug(f"sha256 = {hash} ({count} files hashed in {toc - tic:4.2f}s)")
with open(hashpath, "w") as f: with open(hashpath, "w") as f:
f.write(hash) f.write(hash)
return hash return hash