tidy(mm): remove convenience methods from high level model manager service

These were added as a hold-me-over for the nodes API changes, no longer needed. A followup commit will fix the nodes API to not rely on these.
This commit is contained in:
psychedelicious 2024-03-06 19:04:33 +11:00
parent 4eefed12f0
commit afd9ae7712
5 changed files with 10 additions and 99 deletions

View File

@ -539,7 +539,7 @@ async def convert_model(
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.") raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
# loading the model will convert it into a cached diffusers file # loading the model will convert it into a cached diffusers file
model_manager.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler) model_manager.load.load_model(model_config, submodel_type=SubModelType.Scheduler)
# Get the path of the converted model from the loader # Get the path of the converted model from the loader
cache_path = loader.convert_cache.cache_path(key) cache_path = loader.convert_cache.cache_path(key)

View File

@ -1,15 +1,11 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional
import torch import torch
from typing_extensions import Self from typing_extensions import Self
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from ..config import InvokeAIAppConfig from ..config import InvokeAIAppConfig
from ..download import DownloadQueueServiceBase from ..download import DownloadQueueServiceBase
@ -70,32 +66,3 @@ class ModelManagerServiceBase(ABC):
@abstractmethod @abstractmethod
def stop(self, invoker: Invoker) -> None: def stop(self, invoker: Invoker) -> None:
pass pass
@abstractmethod
def load_model_by_config(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
pass
@abstractmethod
def load_model_by_key(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
pass
@abstractmethod
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
pass

View File

@ -1,14 +1,11 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase.""" """Implementation of ModelManagerServiceBase."""
from typing import Optional
import torch import torch
from typing_extensions import Self from typing_extensions import Self
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.util.devices import choose_torch_device from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
@ -18,7 +15,7 @@ from ..download import DownloadQueueServiceBase
from ..events.events_base import EventServiceBase from ..events.events_base import EventServiceBase
from ..model_install import ModelInstallService, ModelInstallServiceBase from ..model_install import ModelInstallService, ModelInstallServiceBase
from ..model_load import ModelLoadService, ModelLoadServiceBase from ..model_load import ModelLoadService, ModelLoadServiceBase
from ..model_records import ModelRecordServiceBase, UnknownModelException from ..model_records import ModelRecordServiceBase
from .model_manager_base import ModelManagerServiceBase from .model_manager_base import ModelManagerServiceBase
@ -64,56 +61,6 @@ class ModelManagerService(ModelManagerServiceBase):
if hasattr(service, "stop"): if hasattr(service, "stop"):
service.stop(invoker) service.stop(invoker)
def load_model_by_config(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
return self.load.load_model(model_config, submodel_type, context_data)
def load_model_by_key(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
config = self.store.get_model(key)
return self.load.load_model(config, submodel_type, context_data)
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
"""
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
This is provided for API compatability with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that ws returned.
:param model_name: Name of to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context.
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
ValueError -- more than one model matches this combination
"""
configs = self.store.search_by_attr(model_name, base_model, model_type)
if len(configs) == 0:
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
elif len(configs) > 1:
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
else:
return self.load.load_model(configs[0], submodel, context_data)
@classmethod @classmethod
def build_model_manager( def build_model_manager(
cls, cls,

View File

@ -42,9 +42,10 @@ def install_and_load_model(
# If the requested model is already installed, return its LoadedModel # If the requested model is already installed, return its LoadedModel
with contextlib.suppress(UnknownModelException): with contextlib.suppress(UnknownModelException):
# TODO: Replace with wrapper call # TODO: Replace with wrapper call
loaded_model: LoadedModel = model_manager.load_model_by_attr( configs = model_manager.store.search_by_attr(
model_name=model_name, base_model=base_model, model_type=model_type model_name=model_name, base_model=base_model, model_type=model_type
) )
loaded_model: LoadedModel = model_manager.load.load_model(configs[0])
return loaded_model return loaded_model
# Install the requested model. # Install the requested model.
@ -53,7 +54,7 @@ def install_and_load_model(
assert job.complete assert job.complete
try: try:
loaded_model = model_manager.load_model_by_config(job.config_out) loaded_model = model_manager.load.load_model(job.config_out)
return loaded_model return loaded_model
except UnknownModelException as e: except UnknownModelException as e:
raise Exception( raise Exception(

View File

@ -14,17 +14,13 @@ def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Pat
matches = store.search_by_attr(model_name="test_embedding") matches = store.search_by_attr(model_name="test_embedding")
assert len(matches) == 0 assert len(matches) == 0
key = mm2_model_manager.install.register_path(embedding_file) key = mm2_model_manager.install.register_path(embedding_file)
loaded_model = mm2_model_manager.load_model_by_config(store.get_model(key)) loaded_model = mm2_model_manager.load.load_model(store.get_model(key))
assert loaded_model is not None assert loaded_model is not None
assert loaded_model.config.key == key assert loaded_model.config.key == key
with loaded_model as model: with loaded_model as model:
assert isinstance(model, TextualInversionModelRaw) assert isinstance(model, TextualInversionModelRaw)
loaded_model_2 = mm2_model_manager.load_model_by_key(key)
assert loaded_model.config.key == loaded_model_2.config.key
loaded_model_3 = mm2_model_manager.load_model_by_attr( config = mm2_model_manager.store.get_model(key)
model_name=loaded_model.config.name, loaded_model_2 = mm2_model_manager.load.load_model(config)
model_type=loaded_model.config.type,
base_model=loaded_model.config.base, assert loaded_model.config.key == loaded_model_2.config.key
)
assert loaded_model.config.key == loaded_model_3.config.key