mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
import_model and list_install_jobs router APIs written
This commit is contained in:
parent
ec510d34b5
commit
8aefe2cefe
invokeai
app
api
services
backend/model_manager
@ -25,6 +25,7 @@ from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
|
||||
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
|
||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||
from ..services.model_records import ModelRecordServiceSQL
|
||||
from ..services.model_install import ModelInstallService
|
||||
from ..services.names.names_default import SimpleNameService
|
||||
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
@ -87,6 +88,7 @@ class ApiDependencies:
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||
model_manager = ModelManagerService(config, logger)
|
||||
model_record_service = ModelRecordServiceSQL(db=db)
|
||||
model_install_service = ModelInstallService(app_config=config, record_store=model_record_service, event_bus=events)
|
||||
names = SimpleNameService()
|
||||
performance_statistics = InvocationStatsService()
|
||||
processor = DefaultInvocationProcessor()
|
||||
@ -114,6 +116,7 @@ class ApiDependencies:
|
||||
logger=logger,
|
||||
model_manager=model_manager,
|
||||
model_records=model_record_service,
|
||||
model_install=model_install_service,
|
||||
names=names,
|
||||
performance_statistics=performance_statistics,
|
||||
processor=processor,
|
||||
|
@ -4,7 +4,7 @@
|
||||
|
||||
from hashlib import sha1
|
||||
from random import randbytes
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Any, Dict
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
@ -22,6 +22,7 @@ from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.app.services.model_install import ModelInstallJob, ModelSource
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
@ -162,3 +163,95 @@ async def add_model_record(
|
||||
|
||||
# now fetch it out
|
||||
return record_store.get_model(config.key)
|
||||
|
||||
|
||||
@model_records_router.post(
|
||||
"/import",
|
||||
operation_id="import_model_record",
|
||||
responses={
|
||||
201: {"description": "The model imported successfully"},
|
||||
404: {"description": "The model could not be found"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def import_model(
|
||||
source: ModelSource = Body(
|
||||
description="A model path, repo_id or URL to import. NOTE: only model path is implemented currently!"
|
||||
),
|
||||
metadata: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values, such as name, description and prediction_type ",
|
||||
default=None,
|
||||
),
|
||||
variant: Optional[str] = Body(
|
||||
description="When fetching a repo_id, force variant type to fetch such as 'fp16'",
|
||||
default=None,
|
||||
),
|
||||
subfolder: Optional[str] = Body(
|
||||
description="When fetching a repo_id, specify subfolder to fetch model from",
|
||||
default=None,
|
||||
),
|
||||
access_token: Optional[str] = Body(
|
||||
description="When fetching a repo_id or URL, access token for web access",
|
||||
default=None,
|
||||
),
|
||||
) -> ModelInstallJob:
|
||||
"""Add a model using its local path, repo_id, or remote URL.
|
||||
|
||||
Models will be downloaded, probed, configured and installed in a
|
||||
series of background threads. The return object has `status` attribute
|
||||
that can be used to monitor progress.
|
||||
|
||||
The model's configuration record will be probed and filled in
|
||||
automatically. To override the default guesses, pass "metadata"
|
||||
with a Dict containing the attributes you wish to override.
|
||||
|
||||
Listen on the event bus for the following events:
|
||||
"model_install_started", "model_install_completed", and "model_install_error."
|
||||
On successful completion, the event's payload will contain the field "key"
|
||||
containing the installed ID of the model. On an error, the event's payload
|
||||
will contain the fields "error_type" and "error" describing the nature of the
|
||||
error and its traceback, respectively.
|
||||
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installer = ApiDependencies.invoker.services.model_install
|
||||
result: ModelInstallJob = installer.import_model(
|
||||
source,
|
||||
metadata=metadata,
|
||||
variant=variant,
|
||||
subfolder=subfolder,
|
||||
access_token=access_token,
|
||||
)
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return result
|
||||
|
||||
@model_records_router.get(
|
||||
"/import",
|
||||
operation_id="list_model_install_jobs",
|
||||
)
|
||||
async def list_install_jobs(
|
||||
source: Optional[str] = Query(description="Filter list by install source, partial string match.",
|
||||
default=None,
|
||||
)
|
||||
) -> List[ModelInstallJob]:
|
||||
"""
|
||||
Return list of model install jobs.
|
||||
|
||||
If the optional 'source' argument is provided, then the list will be filtered
|
||||
for partial string matches against the install source.
|
||||
"""
|
||||
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs(source)
|
||||
return jobs
|
||||
|
@ -20,6 +20,7 @@ class SocketIO:
|
||||
self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
|
||||
self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
|
||||
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)
|
||||
|
||||
async def _handle_queue_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
@ -28,10 +29,17 @@ class SocketIO:
|
||||
room=event[1]["data"]["queue_id"],
|
||||
)
|
||||
|
||||
async def _handle_sub_queue(self, sid, data, *args, **kwargs):
|
||||
async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None:
|
||||
if "queue_id" in data:
|
||||
await self.__sio.enter_room(sid, data["queue_id"])
|
||||
|
||||
async def _handle_unsub_queue(self, sid, data, *args, **kwargs):
|
||||
async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None:
|
||||
if "queue_id" in data:
|
||||
await self.__sio.leave_room(sid, data["queue_id"])
|
||||
|
||||
|
||||
async def _handle_model_event(self, event: Event) -> None:
|
||||
await self.__sio.emit(
|
||||
event=event[1]["event"],
|
||||
data=event[1]["data"]
|
||||
)
|
||||
|
@ -17,6 +17,7 @@ from invokeai.backend.model_management.models.base import BaseModelType, ModelTy
|
||||
|
||||
class EventServiceBase:
|
||||
queue_event: str = "queue_event"
|
||||
model_event: str = "model_event"
|
||||
|
||||
"""Basic event bus, to have an empty stand-in when not needed"""
|
||||
|
||||
@ -31,6 +32,13 @@ class EventServiceBase:
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
def __emit_model_event(self, event_name: str, payload: dict) -> None:
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.model_event,
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
# Define events here for every event in the system.
|
||||
# This will make them easier to integrate until we find a schema generator.
|
||||
def emit_generator_progress(
|
||||
@ -321,7 +329,7 @@ class EventServiceBase:
|
||||
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
"""
|
||||
self.__emit_queue_event(
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_started",
|
||||
payload={
|
||||
"source": source
|
||||
@ -335,7 +343,7 @@ class EventServiceBase:
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
:param key: Model config record key
|
||||
"""
|
||||
self.__emit_queue_event(
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_completed",
|
||||
payload={
|
||||
"source": source,
|
||||
@ -354,7 +362,7 @@ class EventServiceBase:
|
||||
:param source: Source of the model
|
||||
:param exception: The exception that raised the error
|
||||
"""
|
||||
self.__emit_queue_event(
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_error",
|
||||
payload={
|
||||
"source": source,
|
||||
|
@ -23,6 +23,7 @@ if TYPE_CHECKING:
|
||||
from .latents_storage.latents_storage_base import LatentsStorageBase
|
||||
from .model_manager.model_manager_base import ModelManagerServiceBase
|
||||
from .model_records import ModelRecordServiceBase
|
||||
from .model_install import ModelInstallServiceBase
|
||||
from .names.names_base import NameServiceBase
|
||||
from .session_processor.session_processor_base import SessionProcessorBase
|
||||
from .session_queue.session_queue_base import SessionQueueBase
|
||||
@ -51,6 +52,7 @@ class InvocationServices:
|
||||
logger: "Logger"
|
||||
model_manager: "ModelManagerServiceBase"
|
||||
model_records: "ModelRecordServiceBase"
|
||||
model_install: "ModelRecordInstallServiceBase"
|
||||
processor: "InvocationProcessorABC"
|
||||
performance_statistics: "InvocationStatsServiceBase"
|
||||
queue: "InvocationQueueABC"
|
||||
@ -79,6 +81,7 @@ class InvocationServices:
|
||||
logger: "Logger",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
model_records: "ModelRecordServiceBase",
|
||||
model_install: "ModelInstallServiceBase",
|
||||
processor: "InvocationProcessorABC",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
queue: "InvocationQueueABC",
|
||||
@ -105,6 +108,7 @@ class InvocationServices:
|
||||
self.logger = logger
|
||||
self.model_manager = model_manager
|
||||
self.model_records = model_records
|
||||
self.model_install = model_install
|
||||
self.processor = processor
|
||||
self.performance_statistics = performance_statistics
|
||||
self.queue = queue
|
||||
|
@ -1,6 +1,12 @@
|
||||
"""Initialization file for model install service package."""
|
||||
|
||||
from .model_install_base import InstallStatus, ModelInstallServiceBase, ModelInstallJob, UnknownInstallJobException
|
||||
from .model_install_base import InstallStatus, ModelInstallServiceBase, ModelInstallJob, UnknownInstallJobException, ModelSource
|
||||
from .model_install_default import ModelInstallService
|
||||
|
||||
__all__ = ['ModelInstallServiceBase', 'ModelInstallService', 'InstallStatus', 'ModelInstallJob', 'UnknownInstallJobException']
|
||||
__all__ = ['ModelInstallServiceBase',
|
||||
'ModelInstallService',
|
||||
'InstallStatus',
|
||||
'ModelInstallJob',
|
||||
'UnknownInstallJobException',
|
||||
'ModelSource',
|
||||
]
|
||||
|
@ -183,8 +183,12 @@ class ModelInstallServiceBase(ABC):
|
||||
"""Return the ModelInstallJob corresponding to the provided source."""
|
||||
|
||||
@abstractmethod
|
||||
def get_jobs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102
|
||||
"""Return a dict in which keys are model sources and values are corresponding model install jobs."""
|
||||
def list_jobs(self, source: Optional[ModelSource]=None) -> List[ModelInstallJob]: # noqa D102
|
||||
"""
|
||||
List active and complete install jobs.
|
||||
|
||||
:param source: Filter by jobs whose sources are a partial match to the argument.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def prune_jobs(self) -> None:
|
||||
|
@ -12,10 +12,9 @@ from pydantic.networks import AnyHttpUrl
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, DuplicateModelException
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
DuplicateModelException,
|
||||
InvalidModelConfigException,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import ModelType, BaseModelType
|
||||
@ -26,6 +25,7 @@ from invokeai.backend.util import Chdir, InvokeAILogger
|
||||
|
||||
from .model_install_base import ModelSource, InstallStatus, ModelInstallJob, ModelInstallServiceBase, UnknownInstallJobException
|
||||
|
||||
|
||||
# marker that the queue is done and that thread should exit
|
||||
STOP_JOB = ModelInstallJob(source="stop", local_path=Path("/dev/null"))
|
||||
|
||||
@ -76,9 +76,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
|
||||
return self._event_bus
|
||||
|
||||
def get_jobs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102
|
||||
return self._install_jobs
|
||||
|
||||
def _start_installer_thread(self) -> None:
|
||||
threading.Thread(target=self._install_next_item, daemon=True).start()
|
||||
|
||||
@ -184,6 +181,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
else: # waiting for download queue implementation
|
||||
raise NotImplementedError
|
||||
|
||||
def list_jobs(self, source: Optional[ModelSource]=None) -> List[ModelInstallJob]: # noqa D102
|
||||
jobs = self._install_jobs
|
||||
if not source:
|
||||
return jobs.values()
|
||||
else:
|
||||
return [jobs[x] for x in jobs if source in str(x)]
|
||||
|
||||
def get_job(self, source: ModelSource) -> ModelInstallJob: # noqa D102
|
||||
try:
|
||||
return self._install_jobs[source]
|
||||
|
@ -3,7 +3,6 @@
|
||||
from .probe import ModelProbe
|
||||
from .config import (
|
||||
InvalidModelConfigException,
|
||||
DuplicateModelException,
|
||||
ModelConfigFactory,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
@ -17,7 +16,6 @@ from .search import ModelSearch
|
||||
|
||||
__all__ = ['ModelProbe', 'ModelSearch',
|
||||
'InvalidModelConfigException',
|
||||
'DuplicateModelException',
|
||||
'ModelConfigFactory',
|
||||
'BaseModelType',
|
||||
'ModelType',
|
||||
|
@ -29,9 +29,6 @@ from typing_extensions import Annotated, Any, Dict
|
||||
class InvalidModelConfigException(Exception):
|
||||
"""Exception for when config parser doesn't recognized this combination of model type and format."""
|
||||
|
||||
class DuplicateModelException(Exception):
|
||||
"""Exception for when a duplicate model is detected during installation."""
|
||||
|
||||
|
||||
class BaseModelType(str, Enum):
|
||||
"""Base model type."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user