make model manager v2 ready for PR review

- Replace legacy model manager service with the v2 manager.

- Update invocations to use new load interface.

- Fixed many but not all type checking errors in the invocations. Most
  were unrelated to model manager

- Updated routes. All the new routes live under the route tag
  `model_manager_v2`. To avoid confusion with the old routes,
  they have the URL prefix `/api/v2/models`. The old routes
  have been de-registered.

- Added a pytest for the loader.

- Updated documentation in contributing/MODEL_MANAGER.md
This commit is contained in:
Lincoln Stein
2024-02-10 18:09:45 -05:00
committed by psychedelicious
parent 2b1dc74080
commit 94e8d1b6d5
36 changed files with 680 additions and 435 deletions

View File

@ -27,9 +27,7 @@ if TYPE_CHECKING:
from .invocation_queue.invocation_queue_base import InvocationQueueABC
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from .item_storage.item_storage_base import ItemStorageABC
from .model_install import ModelInstallServiceBase
from .model_manager.model_manager_base import ModelManagerServiceBase
from .model_records import ModelRecordServiceBase
from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase
from .session_queue.session_queue_base import SessionQueueBase
@ -55,9 +53,7 @@ class InvocationServices:
image_records: "ImageRecordStorageBase",
logger: "Logger",
model_manager: "ModelManagerServiceBase",
model_records: "ModelRecordServiceBase",
download_queue: "DownloadQueueServiceBase",
model_install: "ModelInstallServiceBase",
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
@ -82,9 +78,7 @@ class InvocationServices:
self.image_records = image_records
self.logger = logger
self.model_manager = model_manager
self.model_records = model_records
self.download_queue = download_queue
self.model_install = model_install
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue

View File

@ -43,8 +43,10 @@ class InvocationStatsService(InvocationStatsServiceBase):
@contextmanager
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Iterator[None]:
# This is to handle case of the model manager not being initialized, which happens
# during some tests.
services = self._invoker.services
if services.model_records is None or services.model_records.loader is None:
if services.model_manager is None or services.model_manager.load is None:
yield None
if not self._stats.get(graph_execution_state_id):
# First time we're seeing this graph_execution_state_id.
@ -60,9 +62,8 @@ class InvocationStatsService(InvocationStatsServiceBase):
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
# TO DO [LS]: clean up loader service - shouldn't be an attribute of model records
assert services.model_records.loader is not None
services.model_records.loader.ram_cache.stats = self._cache_stats[graph_execution_state_id]
assert services.model_manager.load is not None
services.model_manager.load.ram_cache.stats = self._cache_stats[graph_execution_state_id]
try:
# Let the invocation run.

View File

@ -4,7 +4,8 @@
from abc import ABC, abstractmethod
from typing import Optional
from invokeai.backend.model_manager import AnyModelConfig, SubModelType
from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load import LoadedModel
@ -12,11 +13,60 @@ class ModelLoadServiceBase(ABC):
"""Wrapper around AnyModelLoader."""
@abstractmethod
def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""Given a model's key, load it and return the LoadedModel object."""
def load_model_by_key(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Given a model's key, load it and return the LoadedModel object.
:param key: Key of model config to be fetched.
:param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting
"""
pass
@abstractmethod
def load_model_by_config(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""Given a model's configuration, load it and return the LoadedModel object."""
def load_model_by_config(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
:param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting
"""
pass
@abstractmethod
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
This is provided for API compatability with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that ws returned.
:param model_name: Name of to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context.
Exceptions: UnknownModelException -- model with these attributes not known
NotImplementedException -- a model loader was not provided at initialization time
ValueError -- more than one model matches this combination
"""

View File

@ -3,12 +3,14 @@
from typing import Optional
from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, SubModelType
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
from invokeai.app.services.model_records import ModelRecordServiceBase, UnknownModelException
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.ram_cache import ModelCacheBase
from invokeai.backend.model_manager.load.model_cache import ModelCacheBase
from invokeai.backend.util.logging import InvokeAILogger
from .model_load_base import ModelLoadServiceBase
@ -21,7 +23,7 @@ class ModelLoadService(ModelLoadServiceBase):
self,
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
ram_cache: Optional[ModelCacheBase] = None,
ram_cache: Optional[ModelCacheBase[AnyModel]] = None,
convert_cache: Optional[ModelConvertCacheBase] = None,
):
"""Initialize the model load service."""
@ -44,11 +46,104 @@ class ModelLoadService(ModelLoadServiceBase):
),
)
def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""Given a model's key, load it and return the LoadedModel object."""
config = self._store.get_model(key)
return self.load_model_by_config(config, submodel_type)
def load_model_by_key(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Given a model's key, load it and return the LoadedModel object.
def load_model_by_config(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""Given a model's configuration, load it and return the LoadedModel object."""
return self._any_loader.load_model(config, submodel_type)
:param key: Key of model config to be fetched.
:param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting
"""
config = self._store.get_model(key)
return self.load_model_by_config(config, submodel_type, context)
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
This is provided for API compatability with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that ws returned.
:param model_name: Name of to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context.
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
ValueError -- more than one model matches this combination
"""
configs = self._store.search_by_attr(model_name, base_model, model_type)
if len(configs) == 0:
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
elif len(configs) > 1:
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
else:
return self.load_model_by_key(configs[0].key, submodel)
def load_model_by_config(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
:param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting
"""
if context:
self._emit_load_event(
context=context,
model_config=model_config,
)
loaded_model = self._any_loader.load_model(model_config, submodel_type)
if context:
self._emit_load_event(
context=context,
model_config=model_config,
loaded=True,
)
return loaded_model
def _emit_load_event(
self,
context: InvocationContext,
model_config: AnyModelConfig,
loaded: Optional[bool] = False,
) -> None:
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException()
if not loaded:
context.services.events.emit_model_load_started(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_config=model_config,
)
else:
context.services.events.emit_model_load_completed(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_config=model_config,
)

View File

@ -2,9 +2,10 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel, Field
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from ..config import InvokeAIAppConfig
from ..download import DownloadQueueServiceBase
from ..events.events_base import EventServiceBase
@ -14,12 +15,13 @@ from ..model_records import ModelRecordServiceBase
from ..shared.sqlite.sqlite_database import SqliteDatabase
class ModelManagerServiceBase(BaseModel, ABC):
class ModelManagerServiceBase(ABC):
"""Abstract base class for the model manager service."""
store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.")
install: ModelInstallServiceBase = Field(description="An instance of the model install service.")
load: ModelLoadServiceBase = Field(description="An instance of the model load service.")
# attributes:
# store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.")
# install: ModelInstallServiceBase = Field(description="An instance of the model install service.")
# load: ModelLoadServiceBase = Field(description="An instance of the model load service.")
@classmethod
@abstractmethod
@ -37,3 +39,29 @@ class ModelManagerServiceBase(BaseModel, ABC):
method simplifies the construction considerably.
"""
pass
@property
@abstractmethod
def store(self) -> ModelRecordServiceBase:
"""Return the ModelRecordServiceBase used to store and retrieve configuration records."""
pass
@property
@abstractmethod
def load(self) -> ModelLoadServiceBase:
"""Return the ModelLoadServiceBase used to load models from their configuration records."""
pass
@property
@abstractmethod
def install(self) -> ModelInstallServiceBase:
"""Return the ModelInstallServiceBase used to download and manipulate model files."""
pass
@abstractmethod
def start(self, invoker: Invoker) -> None:
pass
@abstractmethod
def stop(self, invoker: Invoker) -> None:
pass

View File

@ -3,6 +3,7 @@
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.util.logging import InvokeAILogger
@ -10,9 +11,9 @@ from invokeai.backend.util.logging import InvokeAILogger
from ..config import InvokeAIAppConfig
from ..download import DownloadQueueServiceBase
from ..events.events_base import EventServiceBase
from ..model_install import ModelInstallService
from ..model_load import ModelLoadService
from ..model_records import ModelRecordServiceSQL
from ..model_install import ModelInstallService, ModelInstallServiceBase
from ..model_load import ModelLoadService, ModelLoadServiceBase
from ..model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_manager_base import ModelManagerServiceBase
@ -27,6 +28,38 @@ class ModelManagerService(ModelManagerServiceBase):
model_manager.load -- Routines to load models into memory.
"""
def __init__(
self,
store: ModelRecordServiceBase,
install: ModelInstallServiceBase,
load: ModelLoadServiceBase,
):
self._store = store
self._install = install
self._load = load
@property
def store(self) -> ModelRecordServiceBase:
return self._store
@property
def install(self) -> ModelInstallServiceBase:
return self._install
@property
def load(self) -> ModelLoadServiceBase:
return self._load
def start(self, invoker: Invoker) -> None:
for service in [self._store, self._install, self._load]:
if hasattr(service, "start"):
service.start(invoker)
def stop(self, invoker: Invoker) -> None:
for service in [self._store, self._install, self._load]:
if hasattr(service, "stop"):
service.stop(invoker)
@classmethod
def build_model_manager(
cls,

View File

@ -10,15 +10,12 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager import (
AnyModelConfig,
BaseModelType,
LoadedModel,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load import AnyModelLoader
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
@ -111,52 +108,6 @@ class ModelRecordServiceBase(ABC):
"""
pass
@abstractmethod
def load_model(
self,
key: str,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Load the indicated model into memory and return a LoadedModel object.
:param key: Key of model config to be fetched.
:param submodel: For main (pipeline models), the submodel to fetch
:param context: Invocation context, used for event issuing.
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
"""
pass
@abstractmethod
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Load the indicated model into memory and return a LoadedModel object.
This is provided for API compatability with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that ws returned.
:param model_name: Key of model config to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context.
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
"""
pass
@property
@abstractmethod
def metadata_store(self) -> ModelMetadataStore:

View File

@ -46,8 +46,6 @@ from math import ceil
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import (
AnyModelConfig,
@ -55,9 +53,8 @@ from invokeai.backend.model_manager.config import (
ModelConfigFactory,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel
from invokeai.backend.model_manager.load import AnyModelLoader
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
from ..shared.sqlite.sqlite_database import SqliteDatabase
@ -220,74 +217,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def load_model(
self,
key: str,
submodel: Optional[SubModelType],
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Load the indicated model into memory and return a LoadedModel object.
:param key: Key of model config to be fetched.
:param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
"""
if not self._loader:
raise NotImplementedError(f"Class {self.__class__} was not initialized with a model loader")
# we can emit model loading events if we are executing with access to the invocation context
model_config = self.get_model(key)
if context:
self._emit_load_event(
context=context,
model_config=model_config,
)
loaded_model = self._loader.load_model(model_config, submodel)
if context:
self._emit_load_event(
context=context,
model_config=model_config,
loaded=True,
)
return loaded_model
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Load the indicated model into memory and return a LoadedModel object.
This is provided for API compatability with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that ws returned.
:param model_name: Key of model config to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context.
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
ValueError -- more than one model matches this combination
"""
configs = self.search_by_attr(model_name, base_model, model_type)
if len(configs) == 0:
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
elif len(configs) > 1:
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
else:
return self.load_model(configs[0].key, submodel)
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the databse.
@ -476,29 +405,3 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
return PaginatedResults(
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
)
def _emit_load_event(
self,
context: InvocationContext,
model_config: AnyModelConfig,
loaded: Optional[bool] = False,
) -> None:
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException()
if not loaded:
context.services.events.emit_model_load_started(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_config=model_config,
)
else:
context.services.events.emit_model_load_completed(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_config=model_config,
)

View File

@ -6,6 +6,7 @@ from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import
class Migration6Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._recreate_model_triggers(cursor)
self._delete_ip_adapters(cursor)
def _recreate_model_triggers(self, cursor: sqlite3.Cursor) -> None:
"""
@ -26,6 +27,22 @@ class Migration6Callback:
"""
)
def _delete_ip_adapters(self, cursor: sqlite3.Cursor) -> None:
"""
Delete all the IP adapters.
The model manager will automatically find and re-add them after the migration
is done. This allows the manager to add the correct image encoder to their
configuration records.
"""
cursor.execute(
"""--sql
DELETE FROM model_config
WHERE type='ip_adapter';
"""
)
def build_migration_6() -> Migration:
"""
@ -33,6 +50,8 @@ def build_migration_6() -> Migration:
This migration does the following:
- Adds the model_config_updated_at trigger if it does not exist
- Delete all ip_adapter models so that the model prober can find and
update with the correct image processor model.
"""
migration_6 = Migration(
from_version=5,