refactor(ModelManager): factor out _get_implementation method

This commit is contained in:
Kevin Turner 2023-07-28 21:03:27 -07:00
parent ac22652686
commit e70bedba7d

View File

@ -258,7 +258,7 @@ from .models import (
ModelConfigBase,
ModelNotFoundException,
InvalidModelException,
DuplicateModelException,
DuplicateModelException, ModelBase,
)
# We are only starting to number the config file with release 3.
@ -361,7 +361,7 @@ class ModelManager(object):
if model_key.startswith("_"):
continue
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
model_class = self._get_implementation(base_model, model_type)
# alias for config file
model_config["model_format"] = model_config.pop("format")
self.models[model_key] = model_class.create_config(**model_config)
@ -446,7 +446,7 @@ class ModelManager(object):
:param submode_typel: an ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae)
"""
model_class = MODEL_CLASSES[base_model][model_type]
model_class = self._get_implementation(base_model, model_type)
model_key = self.create_key(model_name, base_model, model_type)
# if model not found try to find it (maybe file just pasted)
@ -475,7 +475,7 @@ class ModelManager(object):
model_path = self.app_config.root_path / override_path
model_type = submodel_type
submodel_type = None
model_class = MODEL_CLASSES[base_model][model_type]
model_class = self._get_implementation(base_model, model_type)
# TODO: path
# TODO: is it accurate to use path as id
@ -513,6 +513,10 @@ class ModelManager(object):
_cache=self.cache,
)
def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
model_class = MODEL_CLASSES[base_model][model_type]
return model_class
def model_info(
self,
model_name: str,
@ -659,7 +663,7 @@ class ModelManager(object):
if Path(path).is_relative_to(self.app_config.root_path):
model_attributes["path"] = str(Path(path).relative_to(self.app_config.root_path))
model_class = MODEL_CLASSES[base_model][model_type]
model_class = self._get_implementation(base_model, model_type)
model_config = model_class.create_config(**model_attributes)
model_key = self.create_key(model_name, base_model, model_type)
@ -837,7 +841,7 @@ class ModelManager(object):
for model_key, model_config in self.models.items():
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
model_class = self._get_implementation(base_model, model_type)
if model_class.save_to_config:
# TODO: or exclude_unset better fits here?
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
@ -888,7 +892,7 @@ class ModelManager(object):
model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
model_path = self.app_config.root_path.absolute() / model_config.path
if not model_path.exists():
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
model_class = self._get_implementation(cur_base_model, cur_model_type)
if model_class.save_to_config:
model_config.error = ModelError.NotFound
self.models.pop(model_key, None)
@ -904,7 +908,7 @@ class ModelManager(object):
for cur_model_type in ModelType:
if model_type is not None and cur_model_type != model_type:
continue
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
model_class = self._get_implementation(cur_base_model, cur_model_type)
models_dir = self.app_config.models_path / cur_base_model.value / cur_model_type.value
if not models_dir.exists():