mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor(ModelManager): factor out _get_implementation method
This commit is contained in:
parent
ac22652686
commit
e70bedba7d
@ -258,7 +258,7 @@ from .models import (
|
||||
ModelConfigBase,
|
||||
ModelNotFoundException,
|
||||
InvalidModelException,
|
||||
DuplicateModelException,
|
||||
DuplicateModelException, ModelBase,
|
||||
)
|
||||
|
||||
# We are only starting to number the config file with release 3.
|
||||
@ -361,7 +361,7 @@ class ModelManager(object):
|
||||
if model_key.startswith("_"):
|
||||
continue
|
||||
model_name, base_model, model_type = self.parse_key(model_key)
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
model_class = self._get_implementation(base_model, model_type)
|
||||
# alias for config file
|
||||
model_config["model_format"] = model_config.pop("format")
|
||||
self.models[model_key] = model_class.create_config(**model_config)
|
||||
@ -446,7 +446,7 @@ class ModelManager(object):
|
||||
:param submode_typel: an ModelType enum indicating the portion of
|
||||
the model to retrieve (e.g. ModelType.Vae)
|
||||
"""
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
model_class = self._get_implementation(base_model, model_type)
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
|
||||
# if model not found try to find it (maybe file just pasted)
|
||||
@ -475,7 +475,7 @@ class ModelManager(object):
|
||||
model_path = self.app_config.root_path / override_path
|
||||
model_type = submodel_type
|
||||
submodel_type = None
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
model_class = self._get_implementation(base_model, model_type)
|
||||
|
||||
# TODO: path
|
||||
# TODO: is it accurate to use path as id
|
||||
@ -513,6 +513,10 @@ class ModelManager(object):
|
||||
_cache=self.cache,
|
||||
)
|
||||
|
||||
def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
return model_class
|
||||
|
||||
def model_info(
|
||||
self,
|
||||
model_name: str,
|
||||
@ -659,7 +663,7 @@ class ModelManager(object):
|
||||
if Path(path).is_relative_to(self.app_config.root_path):
|
||||
model_attributes["path"] = str(Path(path).relative_to(self.app_config.root_path))
|
||||
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
model_class = self._get_implementation(base_model, model_type)
|
||||
model_config = model_class.create_config(**model_attributes)
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
|
||||
@ -837,7 +841,7 @@ class ModelManager(object):
|
||||
|
||||
for model_key, model_config in self.models.items():
|
||||
model_name, base_model, model_type = self.parse_key(model_key)
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
model_class = self._get_implementation(base_model, model_type)
|
||||
if model_class.save_to_config:
|
||||
# TODO: or exclude_unset better fits here?
|
||||
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
|
||||
@ -888,7 +892,7 @@ class ModelManager(object):
|
||||
model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
||||
model_path = self.app_config.root_path.absolute() / model_config.path
|
||||
if not model_path.exists():
|
||||
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
|
||||
model_class = self._get_implementation(cur_base_model, cur_model_type)
|
||||
if model_class.save_to_config:
|
||||
model_config.error = ModelError.NotFound
|
||||
self.models.pop(model_key, None)
|
||||
@ -904,7 +908,7 @@ class ModelManager(object):
|
||||
for cur_model_type in ModelType:
|
||||
if model_type is not None and cur_model_type != model_type:
|
||||
continue
|
||||
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
|
||||
model_class = self._get_implementation(cur_base_model, cur_model_type)
|
||||
models_dir = self.app_config.models_path / cur_base_model.value / cur_model_type.value
|
||||
|
||||
if not models_dir.exists():
|
||||
|
Loading…
Reference in New Issue
Block a user