mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor(ModelManager): factor out get_model_path
This commit is contained in:
parent
b163ae6a4d
commit
bc9a5038fd
@ -451,37 +451,33 @@ class ModelManager(object):
|
|||||||
:param model_name: symbolic name of the model in models.yaml
|
:param model_name: symbolic name of the model in models.yaml
|
||||||
:param model_type: ModelType enum indicating the type of model to return
|
:param model_type: ModelType enum indicating the type of model to return
|
||||||
:param base_model: BaseModelType enum indicating the base model used by this model
|
:param base_model: BaseModelType enum indicating the base model used by this model
|
||||||
:param submode_typel: an ModelType enum indicating the portion of
|
:param submodel_type: an ModelType enum indicating the portion of
|
||||||
the model to retrieve (e.g. ModelType.Vae)
|
the model to retrieve (e.g. ModelType.Vae)
|
||||||
"""
|
"""
|
||||||
model_class = self._get_implementation(base_model, model_type)
|
|
||||||
model_key = self.create_key(model_name, base_model, model_type)
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
|
|
||||||
if not self.model_exists(model_name, base_model, model_type, rescan=True):
|
if not self.model_exists(model_name, base_model, model_type, rescan=True):
|
||||||
raise ModelNotFoundException(f"Model not found - {model_key}")
|
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||||
|
|
||||||
model_config = self._get_model_config(base_model, model_name, model_type)
|
model_config = self._get_model_config(base_model, model_name, model_type)
|
||||||
model_path = self.app_config.root_path / model_config.path
|
|
||||||
|
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
|
||||||
|
|
||||||
|
if is_submodel_override:
|
||||||
|
model_type = submodel_type
|
||||||
|
submodel_type = None
|
||||||
|
|
||||||
|
model_class = self._get_implementation(base_model, model_type)
|
||||||
|
|
||||||
if not model_path.exists():
|
if not model_path.exists():
|
||||||
if model_class.save_to_config:
|
if model_class.save_to_config:
|
||||||
self.models[model_key].error = ModelError.NotFound
|
self.models[model_key].error = ModelError.NotFound
|
||||||
raise Exception(f'Files for model "{model_key}" not found')
|
raise Exception(f'Files for model "{model_key}" not found at {model_path}')
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.models.pop(model_key, None)
|
self.models.pop(model_key, None)
|
||||||
raise ModelNotFoundException(f"Model not found - {model_key}")
|
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||||
|
|
||||||
# vae/movq override
|
|
||||||
# TODO:
|
|
||||||
if submodel_type is not None and hasattr(model_config, submodel_type):
|
|
||||||
override_path = getattr(model_config, submodel_type)
|
|
||||||
if override_path:
|
|
||||||
model_path = self.app_config.root_path / override_path
|
|
||||||
model_type = submodel_type
|
|
||||||
submodel_type = None
|
|
||||||
model_class = self._get_implementation(base_model, model_type)
|
|
||||||
|
|
||||||
# TODO: path
|
# TODO: path
|
||||||
# TODO: is it accurate to use path as id
|
# TODO: is it accurate to use path as id
|
||||||
dst_convert_path = self._get_model_cache_path(model_path)
|
dst_convert_path = self._get_model_cache_path(model_path)
|
||||||
@ -518,6 +514,20 @@ class ModelManager(object):
|
|||||||
_cache=self.cache,
|
_cache=self.cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_model_path(self, model_config: ModelConfigBase, submodel_type: SubModelType = None) -> (Path, bool):
|
||||||
|
model_path = model_config.path
|
||||||
|
is_submodel_override = False
|
||||||
|
|
||||||
|
# Does the config explicitly override the submodel?
|
||||||
|
if submodel_type is not None and hasattr(model_config, submodel_type):
|
||||||
|
submodel_path = getattr(model_config, submodel_type)
|
||||||
|
if submodel_path is not None:
|
||||||
|
model_path = getattr(model_config, submodel_type)
|
||||||
|
is_submodel_override = True
|
||||||
|
|
||||||
|
model_path = self.app_config.root_path / model_path
|
||||||
|
return model_path, is_submodel_override
|
||||||
|
|
||||||
def _get_model_config(self, base_model, model_name, model_type) -> ModelConfigBase:
|
def _get_model_config(self, base_model, model_name, model_type) -> ModelConfigBase:
|
||||||
model_key = self.create_key(model_name, base_model, model_type)
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
try:
|
try:
|
||||||
|
Loading…
Reference in New Issue
Block a user