import_model and list_install_jobs router APIs written

This commit is contained in:
Lincoln Stein 2023-11-25 21:45:59 -05:00
parent ec510d34b5
commit 8aefe2cefe
10 changed files with 145 additions and 20 deletions

@ -25,6 +25,7 @@ from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL
from ..services.model_install import ModelInstallService
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
@ -87,6 +88,7 @@ class ApiDependencies:
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
model_manager = ModelManagerService(config, logger)
model_record_service = ModelRecordServiceSQL(db=db)
model_install_service = ModelInstallService(app_config=config, record_store=model_record_service, event_bus=events)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
processor = DefaultInvocationProcessor()
@ -114,6 +116,7 @@ class ApiDependencies:
logger=logger,
model_manager=model_manager,
model_records=model_record_service,
model_install=model_install_service,
names=names,
performance_statistics=performance_statistics,
processor=processor,

@ -4,7 +4,7 @@
from hashlib import sha1
from random import randbytes
from typing import List, Optional
from typing import List, Optional, Any, Dict
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
@ -22,6 +22,7 @@ from invokeai.backend.model_manager.config import (
BaseModelType,
ModelType,
)
from invokeai.app.services.model_install import ModelInstallJob, ModelSource
from ..dependencies import ApiDependencies
@ -162,3 +163,95 @@ async def add_model_record(
# now fetch it out
return record_store.get_model(config.key)
@model_records_router.post(
"/import",
operation_id="import_model_record",
responses={
201: {"description": "The model imported successfully"},
404: {"description": "The model could not be found"},
415: {"description": "Unrecognized file/folder format"},
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
)
async def import_model(
source: ModelSource = Body(
description="A model path, repo_id or URL to import. NOTE: only model path is implemented currently!"
),
metadata: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values, such as name, description and prediction_type ",
default=None,
),
variant: Optional[str] = Body(
description="When fetching a repo_id, force variant type to fetch such as 'fp16'",
default=None,
),
subfolder: Optional[str] = Body(
description="When fetching a repo_id, specify subfolder to fetch model from",
default=None,
),
access_token: Optional[str] = Body(
description="When fetching a repo_id or URL, access token for web access",
default=None,
),
) -> ModelInstallJob:
"""Add a model using its local path, repo_id, or remote URL.
Models will be downloaded, probed, configured and installed in a
series of background threads. The return object has `status` attribute
that can be used to monitor progress.
The model's configuration record will be probed and filled in
automatically. To override the default guesses, pass "metadata"
with a Dict containing the attributes you wish to override.
Listen on the event bus for the following events:
"model_install_started", "model_install_completed", and "model_install_error."
On successful completion, the event's payload will contain the field "key"
containing the installed ID of the model. On an error, the event's payload
will contain the fields "error_type" and "error" describing the nature of the
error and its traceback, respectively.
"""
logger = ApiDependencies.invoker.services.logger
try:
installer = ApiDependencies.invoker.services.model_install
result: ModelInstallJob = installer.import_model(
source,
metadata=metadata,
variant=variant,
subfolder=subfolder,
access_token=access_token,
)
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return result
@model_records_router.get(
"/import",
operation_id="list_model_install_jobs",
)
async def list_install_jobs(
source: Optional[str] = Query(description="Filter list by install source, partial string match.",
default=None,
)
) -> List[ModelInstallJob]:
"""
Return list of model install jobs.
If the optional 'source' argument is provided, then the list will be filtered
for partial string matches against the install source.
"""
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs(source)
return jobs

@ -20,6 +20,7 @@ class SocketIO:
self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)
async def _handle_queue_event(self, event: Event):
await self.__sio.emit(
@ -28,10 +29,17 @@ class SocketIO:
room=event[1]["data"]["queue_id"],
)
async def _handle_sub_queue(self, sid, data, *args, **kwargs):
async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None:
if "queue_id" in data:
await self.__sio.enter_room(sid, data["queue_id"])
async def _handle_unsub_queue(self, sid, data, *args, **kwargs):
async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None:
if "queue_id" in data:
await self.__sio.leave_room(sid, data["queue_id"])
async def _handle_model_event(self, event: Event) -> None:
await self.__sio.emit(
event=event[1]["event"],
data=event[1]["data"]
)

@ -17,6 +17,7 @@ from invokeai.backend.model_management.models.base import BaseModelType, ModelTy
class EventServiceBase:
queue_event: str = "queue_event"
model_event: str = "model_event"
"""Basic event bus, to have an empty stand-in when not needed"""
@ -31,6 +32,13 @@ class EventServiceBase:
payload={"event": event_name, "data": payload},
)
def __emit_model_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.model_event,
payload={"event": event_name, "data": payload},
)
# Define events here for every event in the system.
# This will make them easier to integrate until we find a schema generator.
def emit_generator_progress(
@ -321,7 +329,7 @@ class EventServiceBase:
:param source: Source of the model; local path, repo_id or url
"""
self.__emit_queue_event(
self.__emit_model_event(
event_name="model_install_started",
payload={
"source": source
@ -335,7 +343,7 @@ class EventServiceBase:
:param source: Source of the model; local path, repo_id or url
:param key: Model config record key
"""
self.__emit_queue_event(
self.__emit_model_event(
event_name="model_install_completed",
payload={
"source": source,
@ -354,7 +362,7 @@ class EventServiceBase:
:param source: Source of the model
:param exception: The exception that raised the error
"""
self.__emit_queue_event(
self.__emit_model_event(
event_name="model_install_error",
payload={
"source": source,

@ -23,6 +23,7 @@ if TYPE_CHECKING:
from .latents_storage.latents_storage_base import LatentsStorageBase
from .model_manager.model_manager_base import ModelManagerServiceBase
from .model_records import ModelRecordServiceBase
from .model_install import ModelInstallServiceBase
from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase
from .session_queue.session_queue_base import SessionQueueBase
@ -51,6 +52,7 @@ class InvocationServices:
logger: "Logger"
model_manager: "ModelManagerServiceBase"
model_records: "ModelRecordServiceBase"
model_install: "ModelRecordInstallServiceBase"
processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase"
queue: "InvocationQueueABC"
@ -79,6 +81,7 @@ class InvocationServices:
logger: "Logger",
model_manager: "ModelManagerServiceBase",
model_records: "ModelRecordServiceBase",
model_install: "ModelInstallServiceBase",
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
@ -105,6 +108,7 @@ class InvocationServices:
self.logger = logger
self.model_manager = model_manager
self.model_records = model_records
self.model_install = model_install
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue

@ -1,6 +1,12 @@
"""Initialization file for model install service package."""
from .model_install_base import InstallStatus, ModelInstallServiceBase, ModelInstallJob, UnknownInstallJobException
from .model_install_base import InstallStatus, ModelInstallServiceBase, ModelInstallJob, UnknownInstallJobException, ModelSource
from .model_install_default import ModelInstallService
__all__ = ['ModelInstallServiceBase', 'ModelInstallService', 'InstallStatus', 'ModelInstallJob', 'UnknownInstallJobException']
__all__ = ['ModelInstallServiceBase',
'ModelInstallService',
'InstallStatus',
'ModelInstallJob',
'UnknownInstallJobException',
'ModelSource',
]

@ -183,8 +183,12 @@ class ModelInstallServiceBase(ABC):
"""Return the ModelInstallJob corresponding to the provided source."""
@abstractmethod
def get_jobs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102
"""Return a dict in which keys are model sources and values are corresponding model install jobs."""
def list_jobs(self, source: Optional[ModelSource]=None) -> List[ModelInstallJob]: # noqa D102
"""
List active and complete install jobs.
:param source: Filter by jobs whose sources are a partial match to the argument.
"""
@abstractmethod
def prune_jobs(self) -> None:

@ -12,10 +12,9 @@ from pydantic.networks import AnyHttpUrl
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.app.services.model_records import ModelRecordServiceBase, DuplicateModelException
from invokeai.backend.model_manager.config import (
AnyModelConfig,
DuplicateModelException,
InvalidModelConfigException,
)
from invokeai.backend.model_manager.config import ModelType, BaseModelType
@ -26,6 +25,7 @@ from invokeai.backend.util import Chdir, InvokeAILogger
from .model_install_base import ModelSource, InstallStatus, ModelInstallJob, ModelInstallServiceBase, UnknownInstallJobException
# marker that the queue is done and that thread should exit
STOP_JOB = ModelInstallJob(source="stop", local_path=Path("/dev/null"))
@ -76,9 +76,6 @@ class ModelInstallService(ModelInstallServiceBase):
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
return self._event_bus
def get_jobs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102
return self._install_jobs
def _start_installer_thread(self) -> None:
threading.Thread(target=self._install_next_item, daemon=True).start()
@ -184,6 +181,13 @@ class ModelInstallService(ModelInstallServiceBase):
else: # waiting for download queue implementation
raise NotImplementedError
def list_jobs(self, source: Optional[ModelSource]=None) -> List[ModelInstallJob]: # noqa D102
jobs = self._install_jobs
if not source:
return jobs.values()
else:
return [jobs[x] for x in jobs if source in str(x)]
def get_job(self, source: ModelSource) -> ModelInstallJob: # noqa D102
try:
return self._install_jobs[source]

@ -3,7 +3,6 @@
from .probe import ModelProbe
from .config import (
InvalidModelConfigException,
DuplicateModelException,
ModelConfigFactory,
BaseModelType,
ModelType,
@ -17,7 +16,6 @@ from .search import ModelSearch
__all__ = ['ModelProbe', 'ModelSearch',
'InvalidModelConfigException',
'DuplicateModelException',
'ModelConfigFactory',
'BaseModelType',
'ModelType',

@ -29,9 +29,6 @@ from typing_extensions import Annotated, Any, Dict
class InvalidModelConfigException(Exception):
"""Exception for when config parser doesn't recognized this combination of model type and format."""
class DuplicateModelException(Exception):
"""Exception for when a duplicate model is detected during installation."""
class BaseModelType(str, Enum):
"""Base model type."""