mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(mm): remove convenience methods from high level model manager service
These were added as a hold-me-over for the nodes API changes, no longer needed. A followup commit will fix the nodes API to not rely on these.
This commit is contained in:
parent
4eefed12f0
commit
afd9ae7712
@ -539,7 +539,7 @@ async def convert_model(
|
|||||||
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
|
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
|
||||||
|
|
||||||
# loading the model will convert it into a cached diffusers file
|
# loading the model will convert it into a cached diffusers file
|
||||||
model_manager.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
|
model_manager.load.load_model(model_config, submodel_type=SubModelType.Scheduler)
|
||||||
|
|
||||||
# Get the path of the converted model from the loader
|
# Get the path of the converted model from the loader
|
||||||
cache_path = loader.convert_cache.cache_path(key)
|
cache_path = loader.convert_cache.cache_path(key)
|
||||||
|
@ -1,15 +1,11 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
|
||||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
|
|
||||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
|
||||||
|
|
||||||
from ..config import InvokeAIAppConfig
|
from ..config import InvokeAIAppConfig
|
||||||
from ..download import DownloadQueueServiceBase
|
from ..download import DownloadQueueServiceBase
|
||||||
@ -70,32 +66,3 @@ class ModelManagerServiceBase(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def stop(self, invoker: Invoker) -> None:
|
def stop(self, invoker: Invoker) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def load_model_by_config(
|
|
||||||
self,
|
|
||||||
model_config: AnyModelConfig,
|
|
||||||
submodel_type: Optional[SubModelType] = None,
|
|
||||||
context_data: Optional[InvocationContextData] = None,
|
|
||||||
) -> LoadedModel:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def load_model_by_key(
|
|
||||||
self,
|
|
||||||
key: str,
|
|
||||||
submodel_type: Optional[SubModelType] = None,
|
|
||||||
context_data: Optional[InvocationContextData] = None,
|
|
||||||
) -> LoadedModel:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def load_model_by_attr(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
submodel: Optional[SubModelType] = None,
|
|
||||||
context_data: Optional[InvocationContextData] = None,
|
|
||||||
) -> LoadedModel:
|
|
||||||
pass
|
|
||||||
|
@ -1,14 +1,11 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||||
"""Implementation of ModelManagerServiceBase."""
|
"""Implementation of ModelManagerServiceBase."""
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
|
||||||
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType
|
|
||||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
|
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
|
||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.util.devices import choose_torch_device
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
@ -18,7 +15,7 @@ from ..download import DownloadQueueServiceBase
|
|||||||
from ..events.events_base import EventServiceBase
|
from ..events.events_base import EventServiceBase
|
||||||
from ..model_install import ModelInstallService, ModelInstallServiceBase
|
from ..model_install import ModelInstallService, ModelInstallServiceBase
|
||||||
from ..model_load import ModelLoadService, ModelLoadServiceBase
|
from ..model_load import ModelLoadService, ModelLoadServiceBase
|
||||||
from ..model_records import ModelRecordServiceBase, UnknownModelException
|
from ..model_records import ModelRecordServiceBase
|
||||||
from .model_manager_base import ModelManagerServiceBase
|
from .model_manager_base import ModelManagerServiceBase
|
||||||
|
|
||||||
|
|
||||||
@ -64,56 +61,6 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
if hasattr(service, "stop"):
|
if hasattr(service, "stop"):
|
||||||
service.stop(invoker)
|
service.stop(invoker)
|
||||||
|
|
||||||
def load_model_by_config(
|
|
||||||
self,
|
|
||||||
model_config: AnyModelConfig,
|
|
||||||
submodel_type: Optional[SubModelType] = None,
|
|
||||||
context_data: Optional[InvocationContextData] = None,
|
|
||||||
) -> LoadedModel:
|
|
||||||
return self.load.load_model(model_config, submodel_type, context_data)
|
|
||||||
|
|
||||||
def load_model_by_key(
|
|
||||||
self,
|
|
||||||
key: str,
|
|
||||||
submodel_type: Optional[SubModelType] = None,
|
|
||||||
context_data: Optional[InvocationContextData] = None,
|
|
||||||
) -> LoadedModel:
|
|
||||||
config = self.store.get_model(key)
|
|
||||||
return self.load.load_model(config, submodel_type, context_data)
|
|
||||||
|
|
||||||
def load_model_by_attr(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
submodel: Optional[SubModelType] = None,
|
|
||||||
context_data: Optional[InvocationContextData] = None,
|
|
||||||
) -> LoadedModel:
|
|
||||||
"""
|
|
||||||
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
|
|
||||||
|
|
||||||
This is provided for API compatability with the get_model() method
|
|
||||||
in the original model manager. However, note that LoadedModel is
|
|
||||||
not the same as the original ModelInfo that ws returned.
|
|
||||||
|
|
||||||
:param model_name: Name of to be fetched.
|
|
||||||
:param base_model: Base model
|
|
||||||
:param model_type: Type of the model
|
|
||||||
:param submodel: For main (pipeline models), the submodel to fetch
|
|
||||||
:param context: The invocation context.
|
|
||||||
|
|
||||||
Exceptions: UnknownModelException -- model with this key not known
|
|
||||||
NotImplementedException -- a model loader was not provided at initialization time
|
|
||||||
ValueError -- more than one model matches this combination
|
|
||||||
"""
|
|
||||||
configs = self.store.search_by_attr(model_name, base_model, model_type)
|
|
||||||
if len(configs) == 0:
|
|
||||||
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
|
|
||||||
elif len(configs) > 1:
|
|
||||||
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
|
|
||||||
else:
|
|
||||||
return self.load.load_model(configs[0], submodel, context_data)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_model_manager(
|
def build_model_manager(
|
||||||
cls,
|
cls,
|
||||||
|
@ -42,9 +42,10 @@ def install_and_load_model(
|
|||||||
# If the requested model is already installed, return its LoadedModel
|
# If the requested model is already installed, return its LoadedModel
|
||||||
with contextlib.suppress(UnknownModelException):
|
with contextlib.suppress(UnknownModelException):
|
||||||
# TODO: Replace with wrapper call
|
# TODO: Replace with wrapper call
|
||||||
loaded_model: LoadedModel = model_manager.load_model_by_attr(
|
configs = model_manager.store.search_by_attr(
|
||||||
model_name=model_name, base_model=base_model, model_type=model_type
|
model_name=model_name, base_model=base_model, model_type=model_type
|
||||||
)
|
)
|
||||||
|
loaded_model: LoadedModel = model_manager.load.load_model(configs[0])
|
||||||
return loaded_model
|
return loaded_model
|
||||||
|
|
||||||
# Install the requested model.
|
# Install the requested model.
|
||||||
@ -53,7 +54,7 @@ def install_and_load_model(
|
|||||||
assert job.complete
|
assert job.complete
|
||||||
|
|
||||||
try:
|
try:
|
||||||
loaded_model = model_manager.load_model_by_config(job.config_out)
|
loaded_model = model_manager.load.load_model(job.config_out)
|
||||||
return loaded_model
|
return loaded_model
|
||||||
except UnknownModelException as e:
|
except UnknownModelException as e:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
@ -14,17 +14,13 @@ def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Pat
|
|||||||
matches = store.search_by_attr(model_name="test_embedding")
|
matches = store.search_by_attr(model_name="test_embedding")
|
||||||
assert len(matches) == 0
|
assert len(matches) == 0
|
||||||
key = mm2_model_manager.install.register_path(embedding_file)
|
key = mm2_model_manager.install.register_path(embedding_file)
|
||||||
loaded_model = mm2_model_manager.load_model_by_config(store.get_model(key))
|
loaded_model = mm2_model_manager.load.load_model(store.get_model(key))
|
||||||
assert loaded_model is not None
|
assert loaded_model is not None
|
||||||
assert loaded_model.config.key == key
|
assert loaded_model.config.key == key
|
||||||
with loaded_model as model:
|
with loaded_model as model:
|
||||||
assert isinstance(model, TextualInversionModelRaw)
|
assert isinstance(model, TextualInversionModelRaw)
|
||||||
loaded_model_2 = mm2_model_manager.load_model_by_key(key)
|
|
||||||
assert loaded_model.config.key == loaded_model_2.config.key
|
|
||||||
|
|
||||||
loaded_model_3 = mm2_model_manager.load_model_by_attr(
|
config = mm2_model_manager.store.get_model(key)
|
||||||
model_name=loaded_model.config.name,
|
loaded_model_2 = mm2_model_manager.load.load_model(config)
|
||||||
model_type=loaded_model.config.type,
|
|
||||||
base_model=loaded_model.config.base,
|
assert loaded_model.config.key == loaded_model_2.config.key
|
||||||
)
|
|
||||||
assert loaded_model.config.key == loaded_model_3.config.key
|
|
||||||
|
Loading…
Reference in New Issue
Block a user