mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
make model manager v2 ready for PR review
- Replace legacy model manager service with the v2 manager. - Update invocations to use new load interface. - Fixed many but not all type checking errors in the invocations. Most were unrelated to model manager - Updated routes. All the new routes live under the route tag `model_manager_v2`. To avoid confusion with the old routes, they have the URL prefix `/api/v2/models`. The old routes have been de-registered. - Added a pytest for the loader. - Updated documentation in contributing/MODEL_MANAGER.md
This commit is contained in:
committed by
psychedelicious
parent
7956602b19
commit
a23dedd2ee
@ -2,9 +2,10 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
from ..config import InvokeAIAppConfig
|
||||
from ..download import DownloadQueueServiceBase
|
||||
from ..events.events_base import EventServiceBase
|
||||
@ -14,12 +15,13 @@ from ..model_records import ModelRecordServiceBase
|
||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
class ModelManagerServiceBase(BaseModel, ABC):
|
||||
class ModelManagerServiceBase(ABC):
|
||||
"""Abstract base class for the model manager service."""
|
||||
|
||||
store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.")
|
||||
install: ModelInstallServiceBase = Field(description="An instance of the model install service.")
|
||||
load: ModelLoadServiceBase = Field(description="An instance of the model load service.")
|
||||
# attributes:
|
||||
# store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.")
|
||||
# install: ModelInstallServiceBase = Field(description="An instance of the model install service.")
|
||||
# load: ModelLoadServiceBase = Field(description="An instance of the model load service.")
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
@ -37,3 +39,29 @@ class ModelManagerServiceBase(BaseModel, ABC):
|
||||
method simplifies the construction considerably.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def store(self) -> ModelRecordServiceBase:
|
||||
"""Return the ModelRecordServiceBase used to store and retrieve configuration records."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def load(self) -> ModelLoadServiceBase:
|
||||
"""Return the ModelLoadServiceBase used to load models from their configuration records."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def install(self) -> ModelInstallServiceBase:
|
||||
"""Return the ModelInstallServiceBase used to download and manipulate model files."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stop(self, invoker: Invoker) -> None:
|
||||
pass
|
||||
|
@ -3,6 +3,7 @@
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
|
||||
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
@ -10,9 +11,9 @@ from invokeai.backend.util.logging import InvokeAILogger
|
||||
from ..config import InvokeAIAppConfig
|
||||
from ..download import DownloadQueueServiceBase
|
||||
from ..events.events_base import EventServiceBase
|
||||
from ..model_install import ModelInstallService
|
||||
from ..model_load import ModelLoadService
|
||||
from ..model_records import ModelRecordServiceSQL
|
||||
from ..model_install import ModelInstallService, ModelInstallServiceBase
|
||||
from ..model_load import ModelLoadService, ModelLoadServiceBase
|
||||
from ..model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from .model_manager_base import ModelManagerServiceBase
|
||||
|
||||
@ -27,6 +28,38 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
model_manager.load -- Routines to load models into memory.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
store: ModelRecordServiceBase,
|
||||
install: ModelInstallServiceBase,
|
||||
load: ModelLoadServiceBase,
|
||||
):
|
||||
self._store = store
|
||||
self._install = install
|
||||
self._load = load
|
||||
|
||||
@property
|
||||
def store(self) -> ModelRecordServiceBase:
|
||||
return self._store
|
||||
|
||||
@property
|
||||
def install(self) -> ModelInstallServiceBase:
|
||||
return self._install
|
||||
|
||||
@property
|
||||
def load(self) -> ModelLoadServiceBase:
|
||||
return self._load
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
for service in [self._store, self._install, self._load]:
|
||||
if hasattr(service, "start"):
|
||||
service.start(invoker)
|
||||
|
||||
def stop(self, invoker: Invoker) -> None:
|
||||
for service in [self._store, self._install, self._load]:
|
||||
if hasattr(service, "stop"):
|
||||
service.stop(invoker)
|
||||
|
||||
@classmethod
|
||||
def build_model_manager(
|
||||
cls,
|
||||
|
Reference in New Issue
Block a user