From afd9ae771262250865d2a44be6e2213b94d4851c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 6 Mar 2024 19:04:33 +1100 Subject: [PATCH] tidy(mm): remove convenience methods from high level model manager service These were added as a hold-me-over for the nodes API changes, no longer needed. A followup commit will fix the nodes API to not rely on these. --- invokeai/app/api/routers/model_manager.py | 2 +- .../model_manager/model_manager_base.py | 33 ----------- .../model_manager/model_manager_default.py | 55 +------------------ invokeai/backend/util/test_utils.py | 5 +- .../model_loading/test_model_load.py | 14 ++--- 5 files changed, 10 insertions(+), 99 deletions(-) diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index 22f6e07719..9b88f82fa3 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -539,7 +539,7 @@ async def convert_model( raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.") # loading the model will convert it into a cached diffusers file - model_manager.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler) + model_manager.load.load_model(model_config, submodel_type=SubModelType.Scheduler) # Get the path of the converted model from the loader cache_path = loader.convert_cache.cache_path(key) diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index 6e886df652..af1b68e1ec 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -1,15 +1,11 @@ # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team from abc import ABC, abstractmethod -from typing import Optional import torch from typing_extensions import Self from invokeai.app.services.invoker import Invoker -from invokeai.app.services.shared.invocation_context import InvocationContextData -from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType -from invokeai.backend.model_manager.load.load_base import LoadedModel from ..config import InvokeAIAppConfig from ..download import DownloadQueueServiceBase @@ -70,32 +66,3 @@ class ModelManagerServiceBase(ABC): @abstractmethod def stop(self, invoker: Invoker) -> None: pass - - @abstractmethod - def load_model_by_config( - self, - model_config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: - pass - - @abstractmethod - def load_model_by_key( - self, - key: str, - submodel_type: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: - pass - - @abstractmethod - def load_model_by_attr( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: - pass diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index 7d4b248323..83632d0c0f 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -1,14 +1,11 @@ # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team """Implementation of ModelManagerServiceBase.""" -from typing import Optional import torch from typing_extensions import Self from invokeai.app.services.invoker import Invoker -from invokeai.app.services.shared.invocation_context import InvocationContextData -from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry from invokeai.backend.util.devices import choose_torch_device from invokeai.backend.util.logging import InvokeAILogger @@ -18,7 +15,7 @@ from ..download import DownloadQueueServiceBase from ..events.events_base import EventServiceBase from ..model_install import ModelInstallService, ModelInstallServiceBase from ..model_load import ModelLoadService, ModelLoadServiceBase -from ..model_records import ModelRecordServiceBase, UnknownModelException +from ..model_records import ModelRecordServiceBase from .model_manager_base import ModelManagerServiceBase @@ -64,56 +61,6 @@ class ModelManagerService(ModelManagerServiceBase): if hasattr(service, "stop"): service.stop(invoker) - def load_model_by_config( - self, - model_config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: - return self.load.load_model(model_config, submodel_type, context_data) - - def load_model_by_key( - self, - key: str, - submodel_type: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: - config = self.store.get_model(key) - return self.load.load_model(config, submodel_type, context_data) - - def load_model_by_attr( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: - """ - Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object. - - This is provided for API compatability with the get_model() method - in the original model manager. However, note that LoadedModel is - not the same as the original ModelInfo that ws returned. - - :param model_name: Name of to be fetched. - :param base_model: Base model - :param model_type: Type of the model - :param submodel: For main (pipeline models), the submodel to fetch - :param context: The invocation context. - - Exceptions: UnknownModelException -- model with this key not known - NotImplementedException -- a model loader was not provided at initialization time - ValueError -- more than one model matches this combination - """ - configs = self.store.search_by_attr(model_name, base_model, model_type) - if len(configs) == 0: - raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model") - elif len(configs) > 1: - raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.") - else: - return self.load.load_model(configs[0], submodel, context_data) - @classmethod def build_model_manager( cls, diff --git a/invokeai/backend/util/test_utils.py b/invokeai/backend/util/test_utils.py index 0d76c4633c..add394e71b 100644 --- a/invokeai/backend/util/test_utils.py +++ b/invokeai/backend/util/test_utils.py @@ -42,9 +42,10 @@ def install_and_load_model( # If the requested model is already installed, return its LoadedModel with contextlib.suppress(UnknownModelException): # TODO: Replace with wrapper call - loaded_model: LoadedModel = model_manager.load_model_by_attr( + configs = model_manager.store.search_by_attr( model_name=model_name, base_model=base_model, model_type=model_type ) + loaded_model: LoadedModel = model_manager.load.load_model(configs[0]) return loaded_model # Install the requested model. @@ -53,7 +54,7 @@ def install_and_load_model( assert job.complete try: - loaded_model = model_manager.load_model_by_config(job.config_out) + loaded_model = model_manager.load.load_model(job.config_out) return loaded_model except UnknownModelException as e: raise Exception( diff --git a/tests/backend/model_manager/model_loading/test_model_load.py b/tests/backend/model_manager/model_loading/test_model_load.py index c1fde504ea..3f12f7f8ee 100644 --- a/tests/backend/model_manager/model_loading/test_model_load.py +++ b/tests/backend/model_manager/model_loading/test_model_load.py @@ -14,17 +14,13 @@ def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Pat matches = store.search_by_attr(model_name="test_embedding") assert len(matches) == 0 key = mm2_model_manager.install.register_path(embedding_file) - loaded_model = mm2_model_manager.load_model_by_config(store.get_model(key)) + loaded_model = mm2_model_manager.load.load_model(store.get_model(key)) assert loaded_model is not None assert loaded_model.config.key == key with loaded_model as model: assert isinstance(model, TextualInversionModelRaw) - loaded_model_2 = mm2_model_manager.load_model_by_key(key) - assert loaded_model.config.key == loaded_model_2.config.key - loaded_model_3 = mm2_model_manager.load_model_by_attr( - model_name=loaded_model.config.name, - model_type=loaded_model.config.type, - base_model=loaded_model.config.base, - ) - assert loaded_model.config.key == loaded_model_3.config.key + config = mm2_model_manager.store.get_model(key) + loaded_model_2 = mm2_model_manager.load.load_model(config) + + assert loaded_model.config.key == loaded_model_2.config.key