fix(model manager): fix string formatting error on model checksum timer (#3397)

The error occurs when loading a model for the first time. (or after
removing its checksum file, probably.)
This commit is contained in:
blessedcoolant
2023-05-12 15:04:01 +12:00
committed by GitHub

View File

@ -30,7 +30,7 @@ from diffusers import (
UNet2DConditionModel,
SchedulerMixin,
logging as dlogging,
)
)
from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
@ -68,7 +68,7 @@ class SDModelComponent(Enum):
scheduler="scheduler"
safety_checker="safety_checker"
feature_extractor="feature_extractor"
DEFAULT_MAX_MODELS = 2
class ModelManager(object):
@ -182,7 +182,7 @@ class ModelManager(object):
vae from the model currently in the GPU.
"""
return self._get_sub_model(model_name, SDModelComponent.vae)
def get_model_tokenizer(self, model_name: str=None)->CLIPTokenizer:
"""Given a model name identified in models.yaml, load the model into
GPU if necessary and return its assigned CLIPTokenizer. If no
@ -190,12 +190,12 @@ class ModelManager(object):
currently in the GPU.
"""
return self._get_sub_model(model_name, SDModelComponent.tokenizer)
def get_model_unet(self, model_name: str=None)->UNet2DConditionModel:
"""Given a model name identified in models.yaml, load the model into
GPU if necessary and return its assigned UNet2DConditionModel. If no model
name is provided, return the UNet from the model
currently in the GPU.
currently in the GPU.
"""
return self._get_sub_model(model_name, SDModelComponent.unet)
@ -222,7 +222,7 @@ class ModelManager(object):
currently in the GPU.
"""
return self._get_sub_model(model_name, SDModelComponent.scheduler)
def _get_sub_model(
self,
model_name: str=None,
@ -1228,7 +1228,7 @@ class ModelManager(object):
sha.update(chunk)
hash = sha.hexdigest()
toc = time.time()
self.logger.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
self.logger.debug(f"sha256 = {hash} ({count} files hashed in {toc - tic:4.2f}s)")
with open(hashpath, "w") as f:
f.write(hash)
return hash