import_model and list_install_jobs router APIs written

This commit is contained in:
Lincoln Stein 2023-11-25 21:45:59 -05:00
parent ec510d34b5
commit 8aefe2cefe
10 changed files with 145 additions and 20 deletions

View File

@ -25,6 +25,7 @@ from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
from ..services.model_manager.model_manager_default import ModelManagerService from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL from ..services.model_records import ModelRecordServiceSQL
from ..services.model_install import ModelInstallService
from ..services.names.names_default import SimpleNameService from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor from ..services.session_processor.session_processor_default import DefaultSessionProcessor
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
@ -87,6 +88,7 @@ class ApiDependencies:
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")) latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
model_manager = ModelManagerService(config, logger) model_manager = ModelManagerService(config, logger)
model_record_service = ModelRecordServiceSQL(db=db) model_record_service = ModelRecordServiceSQL(db=db)
model_install_service = ModelInstallService(app_config=config, record_store=model_record_service, event_bus=events)
names = SimpleNameService() names = SimpleNameService()
performance_statistics = InvocationStatsService() performance_statistics = InvocationStatsService()
processor = DefaultInvocationProcessor() processor = DefaultInvocationProcessor()
@ -114,6 +116,7 @@ class ApiDependencies:
logger=logger, logger=logger,
model_manager=model_manager, model_manager=model_manager,
model_records=model_record_service, model_records=model_record_service,
model_install=model_install_service,
names=names, names=names,
performance_statistics=performance_statistics, performance_statistics=performance_statistics,
processor=processor, processor=processor,

View File

@ -4,7 +4,7 @@
from hashlib import sha1 from hashlib import sha1
from random import randbytes from random import randbytes
from typing import List, Optional from typing import List, Optional, Any, Dict
from fastapi import Body, Path, Query, Response from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
@ -22,6 +22,7 @@ from invokeai.backend.model_manager.config import (
BaseModelType, BaseModelType,
ModelType, ModelType,
) )
from invokeai.app.services.model_install import ModelInstallJob, ModelSource
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
@ -162,3 +163,95 @@ async def add_model_record(
# now fetch it out # now fetch it out
return record_store.get_model(config.key) return record_store.get_model(config.key)
@model_records_router.post(
"/import",
operation_id="import_model_record",
responses={
201: {"description": "The model imported successfully"},
404: {"description": "The model could not be found"},
415: {"description": "Unrecognized file/folder format"},
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
)
async def import_model(
source: ModelSource = Body(
description="A model path, repo_id or URL to import. NOTE: only model path is implemented currently!"
),
metadata: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values, such as name, description and prediction_type ",
default=None,
),
variant: Optional[str] = Body(
description="When fetching a repo_id, force variant type to fetch such as 'fp16'",
default=None,
),
subfolder: Optional[str] = Body(
description="When fetching a repo_id, specify subfolder to fetch model from",
default=None,
),
access_token: Optional[str] = Body(
description="When fetching a repo_id or URL, access token for web access",
default=None,
),
) -> ModelInstallJob:
"""Add a model using its local path, repo_id, or remote URL.
Models will be downloaded, probed, configured and installed in a
series of background threads. The return object has `status` attribute
that can be used to monitor progress.
The model's configuration record will be probed and filled in
automatically. To override the default guesses, pass "metadata"
with a Dict containing the attributes you wish to override.
Listen on the event bus for the following events:
"model_install_started", "model_install_completed", and "model_install_error."
On successful completion, the event's payload will contain the field "key"
containing the installed ID of the model. On an error, the event's payload
will contain the fields "error_type" and "error" describing the nature of the
error and its traceback, respectively.
"""
logger = ApiDependencies.invoker.services.logger
try:
installer = ApiDependencies.invoker.services.model_install
result: ModelInstallJob = installer.import_model(
source,
metadata=metadata,
variant=variant,
subfolder=subfolder,
access_token=access_token,
)
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return result
@model_records_router.get(
"/import",
operation_id="list_model_install_jobs",
)
async def list_install_jobs(
source: Optional[str] = Query(description="Filter list by install source, partial string match.",
default=None,
)
) -> List[ModelInstallJob]:
"""
Return list of model install jobs.
If the optional 'source' argument is provided, then the list will be filtered
for partial string matches against the install source.
"""
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs(source)
return jobs

View File

@ -20,6 +20,7 @@ class SocketIO:
self.__sio.on("subscribe_queue", handler=self._handle_sub_queue) self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue) self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event) local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)
async def _handle_queue_event(self, event: Event): async def _handle_queue_event(self, event: Event):
await self.__sio.emit( await self.__sio.emit(
@ -28,10 +29,17 @@ class SocketIO:
room=event[1]["data"]["queue_id"], room=event[1]["data"]["queue_id"],
) )
async def _handle_sub_queue(self, sid, data, *args, **kwargs): async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None:
if "queue_id" in data: if "queue_id" in data:
await self.__sio.enter_room(sid, data["queue_id"]) await self.__sio.enter_room(sid, data["queue_id"])
async def _handle_unsub_queue(self, sid, data, *args, **kwargs): async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None:
if "queue_id" in data: if "queue_id" in data:
await self.__sio.leave_room(sid, data["queue_id"]) await self.__sio.leave_room(sid, data["queue_id"])
async def _handle_model_event(self, event: Event) -> None:
await self.__sio.emit(
event=event[1]["event"],
data=event[1]["data"]
)

View File

@ -17,6 +17,7 @@ from invokeai.backend.model_management.models.base import BaseModelType, ModelTy
class EventServiceBase: class EventServiceBase:
queue_event: str = "queue_event" queue_event: str = "queue_event"
model_event: str = "model_event"
"""Basic event bus, to have an empty stand-in when not needed""" """Basic event bus, to have an empty stand-in when not needed"""
@ -31,6 +32,13 @@ class EventServiceBase:
payload={"event": event_name, "data": payload}, payload={"event": event_name, "data": payload},
) )
def __emit_model_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.model_event,
payload={"event": event_name, "data": payload},
)
# Define events here for every event in the system. # Define events here for every event in the system.
# This will make them easier to integrate until we find a schema generator. # This will make them easier to integrate until we find a schema generator.
def emit_generator_progress( def emit_generator_progress(
@ -321,7 +329,7 @@ class EventServiceBase:
:param source: Source of the model; local path, repo_id or url :param source: Source of the model; local path, repo_id or url
""" """
self.__emit_queue_event( self.__emit_model_event(
event_name="model_install_started", event_name="model_install_started",
payload={ payload={
"source": source "source": source
@ -335,7 +343,7 @@ class EventServiceBase:
:param source: Source of the model; local path, repo_id or url :param source: Source of the model; local path, repo_id or url
:param key: Model config record key :param key: Model config record key
""" """
self.__emit_queue_event( self.__emit_model_event(
event_name="model_install_completed", event_name="model_install_completed",
payload={ payload={
"source": source, "source": source,
@ -354,7 +362,7 @@ class EventServiceBase:
:param source: Source of the model :param source: Source of the model
:param exception: The exception that raised the error :param exception: The exception that raised the error
""" """
self.__emit_queue_event( self.__emit_model_event(
event_name="model_install_error", event_name="model_install_error",
payload={ payload={
"source": source, "source": source,

View File

@ -23,6 +23,7 @@ if TYPE_CHECKING:
from .latents_storage.latents_storage_base import LatentsStorageBase from .latents_storage.latents_storage_base import LatentsStorageBase
from .model_manager.model_manager_base import ModelManagerServiceBase from .model_manager.model_manager_base import ModelManagerServiceBase
from .model_records import ModelRecordServiceBase from .model_records import ModelRecordServiceBase
from .model_install import ModelInstallServiceBase
from .names.names_base import NameServiceBase from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase from .session_processor.session_processor_base import SessionProcessorBase
from .session_queue.session_queue_base import SessionQueueBase from .session_queue.session_queue_base import SessionQueueBase
@ -51,6 +52,7 @@ class InvocationServices:
logger: "Logger" logger: "Logger"
model_manager: "ModelManagerServiceBase" model_manager: "ModelManagerServiceBase"
model_records: "ModelRecordServiceBase" model_records: "ModelRecordServiceBase"
model_install: "ModelRecordInstallServiceBase"
processor: "InvocationProcessorABC" processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase" performance_statistics: "InvocationStatsServiceBase"
queue: "InvocationQueueABC" queue: "InvocationQueueABC"
@ -79,6 +81,7 @@ class InvocationServices:
logger: "Logger", logger: "Logger",
model_manager: "ModelManagerServiceBase", model_manager: "ModelManagerServiceBase",
model_records: "ModelRecordServiceBase", model_records: "ModelRecordServiceBase",
model_install: "ModelInstallServiceBase",
processor: "InvocationProcessorABC", processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase", performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC", queue: "InvocationQueueABC",
@ -105,6 +108,7 @@ class InvocationServices:
self.logger = logger self.logger = logger
self.model_manager = model_manager self.model_manager = model_manager
self.model_records = model_records self.model_records = model_records
self.model_install = model_install
self.processor = processor self.processor = processor
self.performance_statistics = performance_statistics self.performance_statistics = performance_statistics
self.queue = queue self.queue = queue

View File

@ -1,6 +1,12 @@
"""Initialization file for model install service package.""" """Initialization file for model install service package."""
from .model_install_base import InstallStatus, ModelInstallServiceBase, ModelInstallJob, UnknownInstallJobException from .model_install_base import InstallStatus, ModelInstallServiceBase, ModelInstallJob, UnknownInstallJobException, ModelSource
from .model_install_default import ModelInstallService from .model_install_default import ModelInstallService
__all__ = ['ModelInstallServiceBase', 'ModelInstallService', 'InstallStatus', 'ModelInstallJob', 'UnknownInstallJobException'] __all__ = ['ModelInstallServiceBase',
'ModelInstallService',
'InstallStatus',
'ModelInstallJob',
'UnknownInstallJobException',
'ModelSource',
]

View File

@ -183,8 +183,12 @@ class ModelInstallServiceBase(ABC):
"""Return the ModelInstallJob corresponding to the provided source.""" """Return the ModelInstallJob corresponding to the provided source."""
@abstractmethod @abstractmethod
def get_jobs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102 def list_jobs(self, source: Optional[ModelSource]=None) -> List[ModelInstallJob]: # noqa D102
"""Return a dict in which keys are model sources and values are corresponding model install jobs.""" """
List active and complete install jobs.
:param source: Filter by jobs whose sources are a partial match to the argument.
"""
@abstractmethod @abstractmethod
def prune_jobs(self) -> None: def prune_jobs(self) -> None:

View File

@ -12,10 +12,9 @@ from pydantic.networks import AnyHttpUrl
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events import EventServiceBase from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.app.services.model_records import ModelRecordServiceBase, DuplicateModelException
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
AnyModelConfig, AnyModelConfig,
DuplicateModelException,
InvalidModelConfigException, InvalidModelConfigException,
) )
from invokeai.backend.model_manager.config import ModelType, BaseModelType from invokeai.backend.model_manager.config import ModelType, BaseModelType
@ -26,6 +25,7 @@ from invokeai.backend.util import Chdir, InvokeAILogger
from .model_install_base import ModelSource, InstallStatus, ModelInstallJob, ModelInstallServiceBase, UnknownInstallJobException from .model_install_base import ModelSource, InstallStatus, ModelInstallJob, ModelInstallServiceBase, UnknownInstallJobException
# marker that the queue is done and that thread should exit # marker that the queue is done and that thread should exit
STOP_JOB = ModelInstallJob(source="stop", local_path=Path("/dev/null")) STOP_JOB = ModelInstallJob(source="stop", local_path=Path("/dev/null"))
@ -76,9 +76,6 @@ class ModelInstallService(ModelInstallServiceBase):
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102 def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
return self._event_bus return self._event_bus
def get_jobs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102
return self._install_jobs
def _start_installer_thread(self) -> None: def _start_installer_thread(self) -> None:
threading.Thread(target=self._install_next_item, daemon=True).start() threading.Thread(target=self._install_next_item, daemon=True).start()
@ -184,6 +181,13 @@ class ModelInstallService(ModelInstallServiceBase):
else: # waiting for download queue implementation else: # waiting for download queue implementation
raise NotImplementedError raise NotImplementedError
def list_jobs(self, source: Optional[ModelSource]=None) -> List[ModelInstallJob]: # noqa D102
jobs = self._install_jobs
if not source:
return jobs.values()
else:
return [jobs[x] for x in jobs if source in str(x)]
def get_job(self, source: ModelSource) -> ModelInstallJob: # noqa D102 def get_job(self, source: ModelSource) -> ModelInstallJob: # noqa D102
try: try:
return self._install_jobs[source] return self._install_jobs[source]

View File

@ -3,7 +3,6 @@
from .probe import ModelProbe from .probe import ModelProbe
from .config import ( from .config import (
InvalidModelConfigException, InvalidModelConfigException,
DuplicateModelException,
ModelConfigFactory, ModelConfigFactory,
BaseModelType, BaseModelType,
ModelType, ModelType,
@ -17,7 +16,6 @@ from .search import ModelSearch
__all__ = ['ModelProbe', 'ModelSearch', __all__ = ['ModelProbe', 'ModelSearch',
'InvalidModelConfigException', 'InvalidModelConfigException',
'DuplicateModelException',
'ModelConfigFactory', 'ModelConfigFactory',
'BaseModelType', 'BaseModelType',
'ModelType', 'ModelType',

View File

@ -29,9 +29,6 @@ from typing_extensions import Annotated, Any, Dict
class InvalidModelConfigException(Exception): class InvalidModelConfigException(Exception):
"""Exception for when config parser doesn't recognized this combination of model type and format.""" """Exception for when config parser doesn't recognized this combination of model type and format."""
class DuplicateModelException(Exception):
"""Exception for when a duplicate model is detected during installation."""
class BaseModelType(str, Enum): class BaseModelType(str, Enum):
"""Base model type.""" """Base model type."""