mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
make model manager v2 ready for PR review
- Replace legacy model manager service with the v2 manager. - Update invocations to use new load interface. - Fixed many but not all type checking errors in the invocations. Most were unrelated to model manager - Updated routes. All the new routes live under the route tag `model_manager_v2`. To avoid confusion with the old routes, they have the URL prefix `/api/v2/models`. The old routes have been de-registered. - Added a pytest for the loader. - Updated documentation in contributing/MODEL_MANAGER.md
This commit is contained in:
committed by
psychedelicious
parent
2b1dc74080
commit
94e8d1b6d5
@ -27,9 +27,7 @@ if TYPE_CHECKING:
|
||||
from .invocation_queue.invocation_queue_base import InvocationQueueABC
|
||||
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
|
||||
from .item_storage.item_storage_base import ItemStorageABC
|
||||
from .model_install import ModelInstallServiceBase
|
||||
from .model_manager.model_manager_base import ModelManagerServiceBase
|
||||
from .model_records import ModelRecordServiceBase
|
||||
from .names.names_base import NameServiceBase
|
||||
from .session_processor.session_processor_base import SessionProcessorBase
|
||||
from .session_queue.session_queue_base import SessionQueueBase
|
||||
@ -55,9 +53,7 @@ class InvocationServices:
|
||||
image_records: "ImageRecordStorageBase",
|
||||
logger: "Logger",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
model_records: "ModelRecordServiceBase",
|
||||
download_queue: "DownloadQueueServiceBase",
|
||||
model_install: "ModelInstallServiceBase",
|
||||
processor: "InvocationProcessorABC",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
queue: "InvocationQueueABC",
|
||||
@ -82,9 +78,7 @@ class InvocationServices:
|
||||
self.image_records = image_records
|
||||
self.logger = logger
|
||||
self.model_manager = model_manager
|
||||
self.model_records = model_records
|
||||
self.download_queue = download_queue
|
||||
self.model_install = model_install
|
||||
self.processor = processor
|
||||
self.performance_statistics = performance_statistics
|
||||
self.queue = queue
|
||||
|
@ -43,8 +43,10 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
|
||||
@contextmanager
|
||||
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Iterator[None]:
|
||||
# This is to handle case of the model manager not being initialized, which happens
|
||||
# during some tests.
|
||||
services = self._invoker.services
|
||||
if services.model_records is None or services.model_records.loader is None:
|
||||
if services.model_manager is None or services.model_manager.load is None:
|
||||
yield None
|
||||
if not self._stats.get(graph_execution_state_id):
|
||||
# First time we're seeing this graph_execution_state_id.
|
||||
@ -60,9 +62,8 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
# TO DO [LS]: clean up loader service - shouldn't be an attribute of model records
|
||||
assert services.model_records.loader is not None
|
||||
services.model_records.loader.ram_cache.stats = self._cache_stats[graph_execution_state_id]
|
||||
assert services.model_manager.load is not None
|
||||
services.model_manager.load.ram_cache.stats = self._cache_stats[graph_execution_state_id]
|
||||
|
||||
try:
|
||||
# Let the invocation run.
|
||||
|
@ -4,7 +4,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import AnyModelConfig, SubModelType
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
|
||||
|
||||
@ -12,11 +13,60 @@ class ModelLoadServiceBase(ABC):
|
||||
"""Wrapper around AnyModelLoader."""
|
||||
|
||||
@abstractmethod
|
||||
def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""Given a model's key, load it and return the LoadedModel object."""
|
||||
def load_model_by_key(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's key, load it and return the LoadedModel object.
|
||||
|
||||
:param key: Key of model config to be fetched.
|
||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||
:param context: Invocation context used for event reporting
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model_by_config(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""Given a model's configuration, load it and return the LoadedModel object."""
|
||||
def load_model_by_config(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
|
||||
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
|
||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||
:param context: Invocation context used for event reporting
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model_by_attr(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
|
||||
|
||||
This is provided for API compatability with the get_model() method
|
||||
in the original model manager. However, note that LoadedModel is
|
||||
not the same as the original ModelInfo that ws returned.
|
||||
|
||||
:param model_name: Name of to be fetched.
|
||||
:param base_model: Base model
|
||||
:param model_type: Type of the model
|
||||
:param submodel: For main (pipeline models), the submodel to fetch
|
||||
:param context: The invocation context.
|
||||
|
||||
Exceptions: UnknownModelException -- model with these attributes not known
|
||||
NotImplementedException -- a model loader was not provided at initialization time
|
||||
ValueError -- more than one model matches this combination
|
||||
"""
|
||||
|
@ -3,12 +3,14 @@
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import AnyModelConfig, SubModelType
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, UnknownModelException
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.ram_cache import ModelCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache import ModelCacheBase
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .model_load_base import ModelLoadServiceBase
|
||||
@ -21,7 +23,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
record_store: ModelRecordServiceBase,
|
||||
ram_cache: Optional[ModelCacheBase] = None,
|
||||
ram_cache: Optional[ModelCacheBase[AnyModel]] = None,
|
||||
convert_cache: Optional[ModelConvertCacheBase] = None,
|
||||
):
|
||||
"""Initialize the model load service."""
|
||||
@ -44,11 +46,104 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
),
|
||||
)
|
||||
|
||||
def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""Given a model's key, load it and return the LoadedModel object."""
|
||||
config = self._store.get_model(key)
|
||||
return self.load_model_by_config(config, submodel_type)
|
||||
def load_model_by_key(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's key, load it and return the LoadedModel object.
|
||||
|
||||
def load_model_by_config(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""Given a model's configuration, load it and return the LoadedModel object."""
|
||||
return self._any_loader.load_model(config, submodel_type)
|
||||
:param key: Key of model config to be fetched.
|
||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||
:param context: Invocation context used for event reporting
|
||||
"""
|
||||
config = self._store.get_model(key)
|
||||
return self.load_model_by_config(config, submodel_type, context)
|
||||
|
||||
def load_model_by_attr(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
|
||||
|
||||
This is provided for API compatability with the get_model() method
|
||||
in the original model manager. However, note that LoadedModel is
|
||||
not the same as the original ModelInfo that ws returned.
|
||||
|
||||
:param model_name: Name of to be fetched.
|
||||
:param base_model: Base model
|
||||
:param model_type: Type of the model
|
||||
:param submodel: For main (pipeline models), the submodel to fetch
|
||||
:param context: The invocation context.
|
||||
|
||||
Exceptions: UnknownModelException -- model with this key not known
|
||||
NotImplementedException -- a model loader was not provided at initialization time
|
||||
ValueError -- more than one model matches this combination
|
||||
"""
|
||||
configs = self._store.search_by_attr(model_name, base_model, model_type)
|
||||
if len(configs) == 0:
|
||||
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
|
||||
elif len(configs) > 1:
|
||||
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
|
||||
else:
|
||||
return self.load_model_by_key(configs[0].key, submodel)
|
||||
|
||||
def load_model_by_config(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
|
||||
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
|
||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||
:param context: Invocation context used for event reporting
|
||||
"""
|
||||
if context:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
model_config=model_config,
|
||||
)
|
||||
loaded_model = self._any_loader.load_model(model_config, submodel_type)
|
||||
if context:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
model_config=model_config,
|
||||
loaded=True,
|
||||
)
|
||||
return loaded_model
|
||||
|
||||
def _emit_load_event(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
model_config: AnyModelConfig,
|
||||
loaded: Optional[bool] = False,
|
||||
) -> None:
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
raise CanceledException()
|
||||
|
||||
if not loaded:
|
||||
context.services.events.emit_model_load_started(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_config=model_config,
|
||||
)
|
||||
else:
|
||||
context.services.events.emit_model_load_completed(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
@ -2,9 +2,10 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
from ..config import InvokeAIAppConfig
|
||||
from ..download import DownloadQueueServiceBase
|
||||
from ..events.events_base import EventServiceBase
|
||||
@ -14,12 +15,13 @@ from ..model_records import ModelRecordServiceBase
|
||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
class ModelManagerServiceBase(BaseModel, ABC):
|
||||
class ModelManagerServiceBase(ABC):
|
||||
"""Abstract base class for the model manager service."""
|
||||
|
||||
store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.")
|
||||
install: ModelInstallServiceBase = Field(description="An instance of the model install service.")
|
||||
load: ModelLoadServiceBase = Field(description="An instance of the model load service.")
|
||||
# attributes:
|
||||
# store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.")
|
||||
# install: ModelInstallServiceBase = Field(description="An instance of the model install service.")
|
||||
# load: ModelLoadServiceBase = Field(description="An instance of the model load service.")
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
@ -37,3 +39,29 @@ class ModelManagerServiceBase(BaseModel, ABC):
|
||||
method simplifies the construction considerably.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def store(self) -> ModelRecordServiceBase:
|
||||
"""Return the ModelRecordServiceBase used to store and retrieve configuration records."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def load(self) -> ModelLoadServiceBase:
|
||||
"""Return the ModelLoadServiceBase used to load models from their configuration records."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def install(self) -> ModelInstallServiceBase:
|
||||
"""Return the ModelInstallServiceBase used to download and manipulate model files."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stop(self, invoker: Invoker) -> None:
|
||||
pass
|
||||
|
@ -3,6 +3,7 @@
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
|
||||
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
@ -10,9 +11,9 @@ from invokeai.backend.util.logging import InvokeAILogger
|
||||
from ..config import InvokeAIAppConfig
|
||||
from ..download import DownloadQueueServiceBase
|
||||
from ..events.events_base import EventServiceBase
|
||||
from ..model_install import ModelInstallService
|
||||
from ..model_load import ModelLoadService
|
||||
from ..model_records import ModelRecordServiceSQL
|
||||
from ..model_install import ModelInstallService, ModelInstallServiceBase
|
||||
from ..model_load import ModelLoadService, ModelLoadServiceBase
|
||||
from ..model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from .model_manager_base import ModelManagerServiceBase
|
||||
|
||||
@ -27,6 +28,38 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
model_manager.load -- Routines to load models into memory.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
store: ModelRecordServiceBase,
|
||||
install: ModelInstallServiceBase,
|
||||
load: ModelLoadServiceBase,
|
||||
):
|
||||
self._store = store
|
||||
self._install = install
|
||||
self._load = load
|
||||
|
||||
@property
|
||||
def store(self) -> ModelRecordServiceBase:
|
||||
return self._store
|
||||
|
||||
@property
|
||||
def install(self) -> ModelInstallServiceBase:
|
||||
return self._install
|
||||
|
||||
@property
|
||||
def load(self) -> ModelLoadServiceBase:
|
||||
return self._load
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
for service in [self._store, self._install, self._load]:
|
||||
if hasattr(service, "start"):
|
||||
service.start(invoker)
|
||||
|
||||
def stop(self, invoker: Invoker) -> None:
|
||||
for service in [self._store, self._install, self._load]:
|
||||
if hasattr(service, "stop"):
|
||||
service.stop(invoker)
|
||||
|
||||
@classmethod
|
||||
def build_model_manager(
|
||||
cls,
|
||||
|
@ -10,15 +10,12 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
LoadedModel,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load import AnyModelLoader
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
|
||||
@ -111,52 +108,6 @@ class ModelRecordServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model(
|
||||
self,
|
||||
key: str,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Load the indicated model into memory and return a LoadedModel object.
|
||||
|
||||
:param key: Key of model config to be fetched.
|
||||
:param submodel: For main (pipeline models), the submodel to fetch
|
||||
:param context: Invocation context, used for event issuing.
|
||||
|
||||
Exceptions: UnknownModelException -- model with this key not known
|
||||
NotImplementedException -- a model loader was not provided at initialization time
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model_by_attr(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Load the indicated model into memory and return a LoadedModel object.
|
||||
|
||||
This is provided for API compatability with the get_model() method
|
||||
in the original model manager. However, note that LoadedModel is
|
||||
not the same as the original ModelInfo that ws returned.
|
||||
|
||||
:param model_name: Key of model config to be fetched.
|
||||
:param base_model: Base model
|
||||
:param model_type: Type of the model
|
||||
:param submodel: For main (pipeline models), the submodel to fetch
|
||||
:param context: The invocation context.
|
||||
|
||||
Exceptions: UnknownModelException -- model with this key not known
|
||||
NotImplementedException -- a model loader was not provided at initialization time
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def metadata_store(self) -> ModelMetadataStore:
|
||||
|
@ -46,8 +46,6 @@ from math import ceil
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
@ -55,9 +53,8 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelConfigFactory,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel
|
||||
from invokeai.backend.model_manager.load import AnyModelLoader
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
|
||||
|
||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||
@ -220,74 +217,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
return model
|
||||
|
||||
def load_model(
|
||||
self,
|
||||
key: str,
|
||||
submodel: Optional[SubModelType],
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Load the indicated model into memory and return a LoadedModel object.
|
||||
|
||||
:param key: Key of model config to be fetched.
|
||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||
:param context: Invocation context used for event reporting
|
||||
|
||||
Exceptions: UnknownModelException -- model with this key not known
|
||||
NotImplementedException -- a model loader was not provided at initialization time
|
||||
"""
|
||||
if not self._loader:
|
||||
raise NotImplementedError(f"Class {self.__class__} was not initialized with a model loader")
|
||||
# we can emit model loading events if we are executing with access to the invocation context
|
||||
|
||||
model_config = self.get_model(key)
|
||||
if context:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
model_config=model_config,
|
||||
)
|
||||
loaded_model = self._loader.load_model(model_config, submodel)
|
||||
if context:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
model_config=model_config,
|
||||
loaded=True,
|
||||
)
|
||||
return loaded_model
|
||||
|
||||
def load_model_by_attr(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Load the indicated model into memory and return a LoadedModel object.
|
||||
|
||||
This is provided for API compatability with the get_model() method
|
||||
in the original model manager. However, note that LoadedModel is
|
||||
not the same as the original ModelInfo that ws returned.
|
||||
|
||||
:param model_name: Key of model config to be fetched.
|
||||
:param base_model: Base model
|
||||
:param model_type: Type of the model
|
||||
:param submodel: For main (pipeline models), the submodel to fetch
|
||||
:param context: The invocation context.
|
||||
|
||||
Exceptions: UnknownModelException -- model with this key not known
|
||||
NotImplementedException -- a model loader was not provided at initialization time
|
||||
ValueError -- more than one model matches this combination
|
||||
"""
|
||||
configs = self.search_by_attr(model_name, base_model, model_type)
|
||||
if len(configs) == 0:
|
||||
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
|
||||
elif len(configs) > 1:
|
||||
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
|
||||
else:
|
||||
return self.load_model(configs[0].key, submodel)
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Return True if a model with the indicated key exists in the databse.
|
||||
@ -476,29 +405,3 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
return PaginatedResults(
|
||||
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
|
||||
)
|
||||
|
||||
def _emit_load_event(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
model_config: AnyModelConfig,
|
||||
loaded: Optional[bool] = False,
|
||||
) -> None:
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
raise CanceledException()
|
||||
|
||||
if not loaded:
|
||||
context.services.events.emit_model_load_started(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_config=model_config,
|
||||
)
|
||||
else:
|
||||
context.services.events.emit_model_load_completed(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
@ -6,6 +6,7 @@ from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import
|
||||
class Migration6Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._recreate_model_triggers(cursor)
|
||||
self._delete_ip_adapters(cursor)
|
||||
|
||||
def _recreate_model_triggers(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""
|
||||
@ -26,6 +27,22 @@ class Migration6Callback:
|
||||
"""
|
||||
)
|
||||
|
||||
def _delete_ip_adapters(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""
|
||||
Delete all the IP adapters.
|
||||
|
||||
The model manager will automatically find and re-add them after the migration
|
||||
is done. This allows the manager to add the correct image encoder to their
|
||||
configuration records.
|
||||
"""
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_config
|
||||
WHERE type='ip_adapter';
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def build_migration_6() -> Migration:
|
||||
"""
|
||||
@ -33,6 +50,8 @@ def build_migration_6() -> Migration:
|
||||
|
||||
This migration does the following:
|
||||
- Adds the model_config_updated_at trigger if it does not exist
|
||||
- Delete all ip_adapter models so that the model prober can find and
|
||||
update with the correct image processor model.
|
||||
"""
|
||||
migration_6 = Migration(
|
||||
from_version=5,
|
||||
|
Reference in New Issue
Block a user