fix(model manager): fix string formatting error on model checksum timer (#3397)

The error occurs when loading a model for the first time. (or after
removing its checksum file, probably.)
This commit is contained in:
blessedcoolant 2023-05-12 15:04:01 +12:00 committed by GitHub
commit 032555bcfe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -30,7 +30,7 @@ from diffusers import (
UNet2DConditionModel,
SchedulerMixin,
logging as dlogging,
)
)
from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
@ -68,7 +68,7 @@ class SDModelComponent(Enum):
scheduler="scheduler"
safety_checker="safety_checker"
feature_extractor="feature_extractor"
DEFAULT_MAX_MODELS = 2
class ModelManager(object):
@ -182,7 +182,7 @@ class ModelManager(object):
vae from the model currently in the GPU.
"""
return self._get_sub_model(model_name, SDModelComponent.vae)
def get_model_tokenizer(self, model_name: str=None)->CLIPTokenizer:
"""Given a model name identified in models.yaml, load the model into
GPU if necessary and return its assigned CLIPTokenizer. If no
@ -190,12 +190,12 @@ class ModelManager(object):
currently in the GPU.
"""
return self._get_sub_model(model_name, SDModelComponent.tokenizer)
def get_model_unet(self, model_name: str=None)->UNet2DConditionModel:
"""Given a model name identified in models.yaml, load the model into
GPU if necessary and return its assigned UNet2DConditionModel. If no model
name is provided, return the UNet from the model
currently in the GPU.
currently in the GPU.
"""
return self._get_sub_model(model_name, SDModelComponent.unet)
@ -222,7 +222,7 @@ class ModelManager(object):
currently in the GPU.
"""
return self._get_sub_model(model_name, SDModelComponent.scheduler)
def _get_sub_model(
self,
model_name: str=None,
@ -1228,7 +1228,7 @@ class ModelManager(object):
sha.update(chunk)
hash = sha.hexdigest()
toc = time.time()
self.logger.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
self.logger.debug(f"sha256 = {hash} ({count} files hashed in {toc - tic:4.2f}s)")
with open(hashpath, "w") as f:
f.write(hash)
return hash