commit 8aefe2cefe
parent ec510d34b5

    import_model and list_install_jobs router APIs written
@@ -25,6 +25,7 @@ from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
 from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
 from ..services.model_manager.model_manager_default import ModelManagerService
 from ..services.model_records import ModelRecordServiceSQL
+from ..services.model_install import ModelInstallService
 from ..services.names.names_default import SimpleNameService
 from ..services.session_processor.session_processor_default import DefaultSessionProcessor
 from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
@@ -87,6 +88,7 @@ class ApiDependencies:
         latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
         model_manager = ModelManagerService(config, logger)
         model_record_service = ModelRecordServiceSQL(db=db)
+        model_install_service = ModelInstallService(app_config=config, record_store=model_record_service, event_bus=events)
         names = SimpleNameService()
         performance_statistics = InvocationStatsService()
         processor = DefaultInvocationProcessor()
@@ -114,6 +116,7 @@ class ApiDependencies:
             logger=logger,
             model_manager=model_manager,
             model_records=model_record_service,
+            model_install=model_install_service,
             names=names,
             performance_statistics=performance_statistics,
             processor=processor,
@@ -4,7 +4,7 @@

 from hashlib import sha1
 from random import randbytes
-from typing import List, Optional
+from typing import List, Optional, Any, Dict

 from fastapi import Body, Path, Query, Response
 from fastapi.routing import APIRouter
@@ -22,6 +22,7 @@ from invokeai.backend.model_manager.config import (
     BaseModelType,
     ModelType,
 )
+from invokeai.app.services.model_install import ModelInstallJob, ModelSource

 from ..dependencies import ApiDependencies

@@ -162,3 +163,95 @@ async def add_model_record(

     # now fetch it out
     return record_store.get_model(config.key)
+
+
+@model_records_router.post(
+    "/import",
+    operation_id="import_model_record",
+    responses={
+        201: {"description": "The model imported successfully"},
+        404: {"description": "The model could not be found"},
+        415: {"description": "Unrecognized file/folder format"},
+        424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
+        409: {"description": "There is already a model corresponding to this path or repo_id"},
+    },
+    status_code=201,
+)
+async def import_model(
+    source: ModelSource = Body(
+        description="A model path, repo_id or URL to import. NOTE: only model path is implemented currently!"
+    ),
+    metadata: Optional[Dict[str, Any]] = Body(
+        description="Dict of fields that override auto-probed values, such as name, description and prediction_type",
+        default=None,
+    ),
+    variant: Optional[str] = Body(
+        description="When fetching a repo_id, force variant type to fetch such as 'fp16'",
+        default=None,
+    ),
+    subfolder: Optional[str] = Body(
+        description="When fetching a repo_id, specify subfolder to fetch model from",
+        default=None,
+    ),
+    access_token: Optional[str] = Body(
+        description="When fetching a repo_id or URL, access token for web access",
+        default=None,
+    ),
+) -> ModelInstallJob:
+    """Add a model using its local path, repo_id, or remote URL.
+
+    Models will be downloaded, probed, configured and installed in a
+    series of background threads. The return object has a `status` attribute
+    that can be used to monitor progress.
+
+    The model's configuration record will be probed and filled in
+    automatically. To override the default guesses, pass "metadata"
+    with a Dict containing the attributes you wish to override.
+
+    Listen on the event bus for the following events:
+    "model_install_started", "model_install_completed", and "model_install_error".
+    On successful completion, the event's payload will contain the field "key"
+    containing the installed ID of the model. On an error, the event's payload
+    will contain the fields "error_type" and "error" describing the nature of the
+    error and its traceback, respectively.
+    """
+    logger = ApiDependencies.invoker.services.logger
+
+    try:
+        installer = ApiDependencies.invoker.services.model_install
+        result: ModelInstallJob = installer.import_model(
+            source,
+            metadata=metadata,
+            variant=variant,
+            subfolder=subfolder,
+            access_token=access_token,
+        )
+    except UnknownModelException as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=404, detail=str(e))
+    except InvalidModelException as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=415)
+    except ValueError as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=409, detail=str(e))
+    return result
+
+
+@model_records_router.get(
+    "/import",
+    operation_id="list_model_install_jobs",
+)
+async def list_install_jobs(
+    source: Optional[str] = Query(
+        description="Filter list by install source, partial string match.",
+        default=None,
+    )
+) -> List[ModelInstallJob]:
+    """
+    Return list of model install jobs.
+
+    If the optional 'source' argument is provided, then the list will be filtered
+    for partial string matches against the install source.
+    """
+    jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs(source)
+    return jobs
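To make the new endpoints concrete, here is a rough client-side sketch. It assumes a local server and that model_records_router is mounted under /api/v1/model/records; the base URL, port, and the example model path are placeholders rather than part of this commit.

# Sketch only: base URL, port and the model path are assumptions, not part of this diff.
import requests

BASE = "http://localhost:9090/api/v1/model/records"

# Queue an install job; only local-path sources are implemented at this point.
resp = requests.post(
    f"{BASE}/import",
    json={
        "source": "/path/to/model.safetensors",                # hypothetical local path
        "metadata": {"name": "my-model", "description": "demo import"},
    },
    timeout=30,
)
resp.raise_for_status()
job = resp.json()
print("queued, status:", job["status"])

# The companion GET /import route lists jobs, optionally filtered by a
# partial match against the install source.
for j in requests.get(f"{BASE}/import", params={"source": "model.safetensors"}, timeout=30).json():
    print(j["status"], j.get("source"))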
@@ -20,6 +20,7 @@ class SocketIO:
         self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
         self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)
         local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
+        local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)

     async def _handle_queue_event(self, event: Event):
         await self.__sio.emit(
@@ -28,10 +29,17 @@ class SocketIO:
             room=event[1]["data"]["queue_id"],
         )

-    async def _handle_sub_queue(self, sid, data, *args, **kwargs):
+    async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None:
         if "queue_id" in data:
             await self.__sio.enter_room(sid, data["queue_id"])

-    async def _handle_unsub_queue(self, sid, data, *args, **kwargs):
+    async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None:
         if "queue_id" in data:
             await self.__sio.leave_room(sid, data["queue_id"])
+
+    async def _handle_model_event(self, event: Event) -> None:
+        await self.__sio.emit(
+            event=event[1]["event"],
+            data=event[1]["data"]
+        )
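Client code can subscribe to the relayed events by name, since _handle_model_event re-emits each payload under the event name carried inside it. A minimal python-socketio sketch follows; the connect URL and Socket.IO path are assumptions, not part of this diff.

# Sketch: the connect URL and socketio_path are assumptions, not part of this diff.
import socketio

sio = socketio.Client()

@sio.on("model_install_started")
def on_started(data):
    print("install started:", data["source"])

@sio.on("model_install_completed")
def on_completed(data):
    print("installed, model key:", data["key"])

@sio.on("model_install_error")
def on_error(data):
    print("install failed:", data["error_type"], data["error"])

sio.connect("http://localhost:9090", socketio_path="/ws/socket.io")
sio.wait()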
@@ -17,6 +17,7 @@ from invokeai.backend.model_management.models.base import BaseModelType, ModelTy

 class EventServiceBase:
     queue_event: str = "queue_event"
+    model_event: str = "model_event"

     """Basic event bus, to have an empty stand-in when not needed"""

@@ -31,6 +32,13 @@ class EventServiceBase:
             payload={"event": event_name, "data": payload},
         )

+    def __emit_model_event(self, event_name: str, payload: dict) -> None:
+        payload["timestamp"] = get_timestamp()
+        self.dispatch(
+            event_name=EventServiceBase.model_event,
+            payload={"event": event_name, "data": payload},
+        )
+
     # Define events here for every event in the system.
     # This will make them easier to integrate until we find a schema generator.
     def emit_generator_progress(
@@ -321,7 +329,7 @@ class EventServiceBase:

         :param source: Source of the model; local path, repo_id or url
         """
-        self.__emit_queue_event(
+        self.__emit_model_event(
             event_name="model_install_started",
             payload={
                 "source": source
@@ -335,7 +343,7 @@ class EventServiceBase:
         :param source: Source of the model; local path, repo_id or url
         :param key: Model config record key
         """
-        self.__emit_queue_event(
+        self.__emit_model_event(
             event_name="model_install_completed",
             payload={
                 "source": source,
@@ -354,7 +362,7 @@ class EventServiceBase:
         :param source: Source of the model
         :param exception: The exception that raised the error
         """
-        self.__emit_queue_event(
+        self.__emit_model_event(
             event_name="model_install_error",
             payload={
                 "source": source,
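For reference, after these changes a subscriber registered for EventServiceBase.model_event receives a (name, payload) pair shaped roughly as below; the field values are illustrative only.

# Illustrative shape of a dispatched model_event (values are made up).
event = (
    "model_event",                               # EventServiceBase.model_event
    {
        "event": "model_install_completed",      # one of the model_install_* names
        "data": {
            "source": "/path/to/model.safetensors",
            "key": "abc123",                     # config record key of the installed model
            "timestamp": 1700000000,             # added by __emit_model_event
        },
    },
)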
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
     from .latents_storage.latents_storage_base import LatentsStorageBase
     from .model_manager.model_manager_base import ModelManagerServiceBase
     from .model_records import ModelRecordServiceBase
+    from .model_install import ModelInstallServiceBase
     from .names.names_base import NameServiceBase
     from .session_processor.session_processor_base import SessionProcessorBase
     from .session_queue.session_queue_base import SessionQueueBase
@@ -51,6 +52,7 @@ class InvocationServices:
     logger: "Logger"
     model_manager: "ModelManagerServiceBase"
     model_records: "ModelRecordServiceBase"
+    model_install: "ModelInstallServiceBase"
     processor: "InvocationProcessorABC"
     performance_statistics: "InvocationStatsServiceBase"
     queue: "InvocationQueueABC"
@@ -79,6 +81,7 @@ class InvocationServices:
         logger: "Logger",
         model_manager: "ModelManagerServiceBase",
         model_records: "ModelRecordServiceBase",
+        model_install: "ModelInstallServiceBase",
         processor: "InvocationProcessorABC",
         performance_statistics: "InvocationStatsServiceBase",
         queue: "InvocationQueueABC",
@@ -105,6 +108,7 @@ class InvocationServices:
         self.logger = logger
         self.model_manager = model_manager
         self.model_records = model_records
+        self.model_install = model_install
         self.processor = processor
         self.performance_statistics = performance_statistics
         self.queue = queue
@@ -1,6 +1,12 @@
 """Initialization file for model install service package."""

-from .model_install_base import InstallStatus, ModelInstallServiceBase, ModelInstallJob, UnknownInstallJobException
+from .model_install_base import InstallStatus, ModelInstallServiceBase, ModelInstallJob, UnknownInstallJobException, ModelSource
 from .model_install_default import ModelInstallService

-__all__ = ['ModelInstallServiceBase', 'ModelInstallService', 'InstallStatus', 'ModelInstallJob', 'UnknownInstallJobException']
+__all__ = ['ModelInstallServiceBase',
+           'ModelInstallService',
+           'InstallStatus',
+           'ModelInstallJob',
+           'UnknownInstallJobException',
+           'ModelSource',
+           ]
@@ -183,8 +183,12 @@ class ModelInstallServiceBase(ABC):
         """Return the ModelInstallJob corresponding to the provided source."""

     @abstractmethod
-    def get_jobs(self) -> Dict[ModelSource, ModelInstallJob]:  # noqa D102
-        """Return a dict in which keys are model sources and values are corresponding model install jobs."""
+    def list_jobs(self, source: Optional[ModelSource] = None) -> List[ModelInstallJob]:  # noqa D102
+        """
+        List active and complete install jobs.
+
+        :param source: Filter by jobs whose sources are a partial match to the argument.
+        """

     @abstractmethod
     def prune_jobs(self) -> None:
@@ -12,10 +12,9 @@ from pydantic.networks import AnyHttpUrl

 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.app.services.events import EventServiceBase
-from invokeai.app.services.model_records import ModelRecordServiceBase
+from invokeai.app.services.model_records import ModelRecordServiceBase, DuplicateModelException
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
-    DuplicateModelException,
     InvalidModelConfigException,
 )
 from invokeai.backend.model_manager.config import ModelType, BaseModelType
@@ -26,6 +25,7 @@ from invokeai.backend.util import Chdir, InvokeAILogger

 from .model_install_base import ModelSource, InstallStatus, ModelInstallJob, ModelInstallServiceBase, UnknownInstallJobException

+
 # marker that the queue is done and that thread should exit
 STOP_JOB = ModelInstallJob(source="stop", local_path=Path("/dev/null"))

@@ -76,9 +76,6 @@ class ModelInstallService(ModelInstallServiceBase):
     def event_bus(self) -> Optional[EventServiceBase]:  # noqa D102
         return self._event_bus

-    def get_jobs(self) -> Dict[ModelSource, ModelInstallJob]:  # noqa D102
-        return self._install_jobs
-
     def _start_installer_thread(self) -> None:
         threading.Thread(target=self._install_next_item, daemon=True).start()

@@ -184,6 +181,13 @@ class ModelInstallService(ModelInstallServiceBase):
         else:  # waiting for download queue implementation
             raise NotImplementedError

+    def list_jobs(self, source: Optional[ModelSource] = None) -> List[ModelInstallJob]:  # noqa D102
+        jobs = self._install_jobs
+        if not source:
+            return list(jobs.values())
+        else:
+            return [jobs[x] for x in jobs if source in str(x)]
+
     def get_job(self, source: ModelSource) -> ModelInstallJob:  # noqa D102
         try:
             return self._install_jobs[source]
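A small in-process sketch of how the installer service is meant to be driven, mirroring what the new /import route does; the model path and metadata values are placeholders, and only local-path sources are implemented in this commit.

# Sketch: driving the installer directly. Paths and metadata are placeholders.
from invokeai.app.services.model_install import ModelInstallService

def queue_and_list(installer: ModelInstallService) -> None:
    # import_model() queues a background install and returns the job immediately.
    job = installer.import_model(
        "/path/to/model.safetensors",      # hypothetical local path
        metadata={"name": "my-model"},     # overrides auto-probed fields
    )
    print("queued:", job.source, job.status)

    # list_jobs() returns every known job, or filters by partial source match.
    for j in installer.list_jobs("model.safetensors"):
        print(j.source, j.status)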
@@ -3,7 +3,6 @@
 from .probe import ModelProbe
 from .config import (
     InvalidModelConfigException,
-    DuplicateModelException,
     ModelConfigFactory,
     BaseModelType,
     ModelType,
@@ -17,7 +16,6 @@ from .search import ModelSearch

 __all__ = ['ModelProbe', 'ModelSearch',
            'InvalidModelConfigException',
-           'DuplicateModelException',
            'ModelConfigFactory',
            'BaseModelType',
            'ModelType',
@@ -29,9 +29,6 @@ from typing_extensions import Annotated, Any, Dict
 class InvalidModelConfigException(Exception):
     """Exception for when config parser doesn't recognized this combination of model type and format."""

-class DuplicateModelException(Exception):
-    """Exception for when a duplicate model is detected during installation."""
-

 class BaseModelType(str, Enum):
     """Base model type."""
