mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
35 Commits
v4.2.2
...
psyche/fea
Author | SHA1 | Date | |
---|---|---|---|
8eb5316c9f | |||
12ce095bb2 | |||
242b2a0b59 | |||
86e201612f | |||
f6d1e1be22 | |||
edf043d1d6 | |||
70a21eda78 | |||
452f4fe0e6 | |||
87f2d04ddd | |||
9c45cbe8f7 | |||
5f7c852493 | |||
35ef02bdf7 | |||
7da283f433 | |||
812cf277b8 | |||
182cb51bf0 | |||
64a3adfc64 | |||
a48ef9f7a7 | |||
9aeabf10df | |||
7b93cc8538 | |||
a2480c16e7 | |||
b1e2dd222e | |||
1f92e9eec2 | |||
fb402f3b46 | |||
0abc328ddf | |||
cfa4e5f88e | |||
24d0d4932d | |||
20db93b901 | |||
500a733d79 | |||
338d5f158b | |||
63e4b224b2 | |||
e9043ff060 | |||
c725851c64 | |||
a1c4ef55d7 | |||
e25b39aca2 | |||
32a02b3329 |
@ -18,6 +18,7 @@ from ..services.boards.boards_default import BoardService
|
||||
from ..services.bulk_download.bulk_download_default import BulkDownloadService
|
||||
from ..services.config import InvokeAIAppConfig
|
||||
from ..services.download import DownloadQueueService
|
||||
from ..services.events.events_fastapievents import FastAPIEventService
|
||||
from ..services.image_files.image_files_disk import DiskImageFileStorage
|
||||
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
|
||||
from ..services.images.images_default import ImageService
|
||||
@ -33,7 +34,6 @@ from ..services.session_processor.session_processor_default import DefaultSessio
|
||||
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
from ..services.urls.urls_default import LocalUrlService
|
||||
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||
from .events import FastAPIEventService
|
||||
|
||||
|
||||
# TODO: is there a better way to achieve this?
|
||||
|
@ -1,52 +0,0 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from queue import Empty, Queue
|
||||
from typing import Any
|
||||
|
||||
from fastapi_events.dispatcher import dispatch
|
||||
|
||||
from ..services.events.events_base import EventServiceBase
|
||||
|
||||
|
||||
class FastAPIEventService(EventServiceBase):
|
||||
event_handler_id: int
|
||||
__queue: Queue
|
||||
__stop_event: threading.Event
|
||||
|
||||
def __init__(self, event_handler_id: int) -> None:
|
||||
self.event_handler_id = event_handler_id
|
||||
self.__queue = Queue()
|
||||
self.__stop_event = threading.Event()
|
||||
asyncio.create_task(self.__dispatch_from_queue(stop_event=self.__stop_event))
|
||||
|
||||
super().__init__()
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
self.__stop_event.set()
|
||||
self.__queue.put(None)
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
self.__queue.put({"event_name": event_name, "payload": payload})
|
||||
|
||||
async def __dispatch_from_queue(self, stop_event: threading.Event):
|
||||
"""Get events on from the queue and dispatch them, from the correct thread"""
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
event = self.__queue.get(block=False)
|
||||
if not event: # Probably stopping
|
||||
continue
|
||||
|
||||
dispatch(
|
||||
event.get("event_name"),
|
||||
payload=event.get("payload"),
|
||||
middleware_id=self.event_handler_id,
|
||||
)
|
||||
|
||||
except Empty:
|
||||
await asyncio.sleep(0.1)
|
||||
pass
|
||||
|
||||
except asyncio.CancelledError as e:
|
||||
raise e # Raise a proper error
|
@ -17,7 +17,7 @@ from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
|
||||
from invokeai.app.services.model_install import ModelInstallJob
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
|
@ -203,6 +203,7 @@ async def get_batch_status(
|
||||
responses={
|
||||
200: {"model": SessionQueueItem},
|
||||
},
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
async def get_queue_item(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
|
@ -1,66 +1,131 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event
|
||||
from pydantic import BaseModel
|
||||
from socketio import ASGIApp, AsyncServer
|
||||
|
||||
from ..services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.events.events_common import (
|
||||
BatchEnqueuedEvent,
|
||||
BulkDownloadCompleteEvent,
|
||||
BulkDownloadErrorEvent,
|
||||
BulkDownloadEventBase,
|
||||
BulkDownloadStartedEvent,
|
||||
DownloadCancelledEvent,
|
||||
DownloadCompleteEvent,
|
||||
DownloadErrorEvent,
|
||||
DownloadProgressEvent,
|
||||
DownloadStartedEvent,
|
||||
FastAPIEvent,
|
||||
InvocationCompleteEvent,
|
||||
InvocationDenoiseProgressEvent,
|
||||
InvocationErrorEvent,
|
||||
InvocationStartedEvent,
|
||||
ModelEventBase,
|
||||
ModelInstallCancelledEvent,
|
||||
ModelInstallCompleteEvent,
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallErrorEvent,
|
||||
ModelInstallStartedEvent,
|
||||
ModelLoadCompleteEvent,
|
||||
ModelLoadStartedEvent,
|
||||
QueueClearedEvent,
|
||||
QueueEventBase,
|
||||
QueueItemStatusChangedEvent,
|
||||
SessionCanceledEvent,
|
||||
SessionCompleteEvent,
|
||||
SessionStartedEvent,
|
||||
register_events,
|
||||
)
|
||||
|
||||
|
||||
class QueueSubscriptionEvent(BaseModel):
|
||||
queue_id: str
|
||||
|
||||
|
||||
class BulkDownloadSubscriptionEvent(BaseModel):
|
||||
bulk_download_id: str
|
||||
|
||||
|
||||
class SocketIO:
|
||||
__sio: AsyncServer
|
||||
__app: ASGIApp
|
||||
_sub_queue = "subscribe_queue"
|
||||
_unsub_queue = "unsubscribe_queue"
|
||||
|
||||
__sub_queue: str = "subscribe_queue"
|
||||
__unsub_queue: str = "unsubscribe_queue"
|
||||
|
||||
__sub_bulk_download: str = "subscribe_bulk_download"
|
||||
__unsub_bulk_download: str = "unsubscribe_bulk_download"
|
||||
_sub_bulk_download = "subscribe_bulk_download"
|
||||
_unsub_bulk_download = "unsubscribe_bulk_download"
|
||||
|
||||
def __init__(self, app: FastAPI):
|
||||
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
|
||||
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io")
|
||||
app.mount("/ws", self.__app)
|
||||
self._sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
|
||||
self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io")
|
||||
app.mount("/ws", self._app)
|
||||
|
||||
self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue)
|
||||
self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue)
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
|
||||
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)
|
||||
self._sio.on(self._sub_queue, handler=self._handle_sub_queue)
|
||||
self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue)
|
||||
self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download)
|
||||
self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download)
|
||||
|
||||
self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download)
|
||||
self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download)
|
||||
local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event)
|
||||
|
||||
async def _handle_queue_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
event=event[1]["event"],
|
||||
data=event[1]["data"],
|
||||
room=event[1]["data"]["queue_id"],
|
||||
register_events(
|
||||
{
|
||||
InvocationStartedEvent,
|
||||
InvocationDenoiseProgressEvent,
|
||||
InvocationCompleteEvent,
|
||||
InvocationErrorEvent,
|
||||
SessionStartedEvent,
|
||||
SessionCompleteEvent,
|
||||
SessionCanceledEvent,
|
||||
QueueItemStatusChangedEvent,
|
||||
BatchEnqueuedEvent,
|
||||
QueueClearedEvent,
|
||||
},
|
||||
self._handle_queue_event,
|
||||
)
|
||||
|
||||
async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None:
|
||||
if "queue_id" in data:
|
||||
await self.__sio.enter_room(sid, data["queue_id"])
|
||||
|
||||
async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None:
|
||||
if "queue_id" in data:
|
||||
await self.__sio.leave_room(sid, data["queue_id"])
|
||||
|
||||
async def _handle_model_event(self, event: Event) -> None:
|
||||
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"])
|
||||
|
||||
async def _handle_bulk_download_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
event=event[1]["event"],
|
||||
data=event[1]["data"],
|
||||
room=event[1]["data"]["bulk_download_id"],
|
||||
register_events(
|
||||
{
|
||||
DownloadCancelledEvent,
|
||||
DownloadCompleteEvent,
|
||||
DownloadErrorEvent,
|
||||
DownloadProgressEvent,
|
||||
DownloadStartedEvent,
|
||||
ModelLoadStartedEvent,
|
||||
ModelLoadCompleteEvent,
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallStartedEvent,
|
||||
ModelInstallCompleteEvent,
|
||||
ModelInstallCancelledEvent,
|
||||
ModelInstallErrorEvent,
|
||||
},
|
||||
self._handle_model_event,
|
||||
)
|
||||
|
||||
async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs):
|
||||
if "bulk_download_id" in data:
|
||||
await self.__sio.enter_room(sid, data["bulk_download_id"])
|
||||
register_events(
|
||||
{BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent},
|
||||
self._handle_bulk_image_download_event,
|
||||
)
|
||||
|
||||
async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs):
|
||||
if "bulk_download_id" in data:
|
||||
await self.__sio.leave_room(sid, data["bulk_download_id"])
|
||||
async def _handle_sub_queue(self, sid: str, data: Any) -> None:
|
||||
await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id)
|
||||
|
||||
async def _handle_unsub_queue(self, sid: str, data: Any) -> None:
|
||||
await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id)
|
||||
|
||||
async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None:
|
||||
await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
|
||||
|
||||
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
|
||||
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
|
||||
|
||||
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
|
||||
event_name, payload = event
|
||||
await self._sio.emit(event=event_name, data=payload.model_dump(mode="json"), room=payload.queue_id)
|
||||
|
||||
async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase]) -> None:
|
||||
event_name, payload = event
|
||||
await self._sio.emit(event=event_name, data=payload.model_dump(mode="json"))
|
||||
|
||||
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
|
||||
event_name, payload = event
|
||||
await self._sio.emit(event=event_name, data=payload.model_dump(mode="json"), room=payload.bulk_download_id)
|
||||
|
@ -27,6 +27,7 @@ import invokeai.frontend.web as web_dir
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.events.events_common import EventBase
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@ -182,23 +183,14 @@ def custom_openapi() -> dict[str, Any]:
|
||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
|
||||
invoker_schema["class"] = "invocation"
|
||||
|
||||
# This code no longer seems to be necessary?
|
||||
# Leave it here just in case
|
||||
#
|
||||
# from invokeai.backend.model_manager import get_model_config_formats
|
||||
# formats = get_model_config_formats()
|
||||
# for model_config_name, enum_set in formats.items():
|
||||
|
||||
# if model_config_name in openapi_schema["components"]["schemas"]:
|
||||
# # print(f"Config with name {name} already defined")
|
||||
# continue
|
||||
|
||||
# openapi_schema["components"]["schemas"][model_config_name] = {
|
||||
# "title": model_config_name,
|
||||
# "description": "An enumeration.",
|
||||
# "type": "string",
|
||||
# "enum": [v.value for v in enum_set],
|
||||
# }
|
||||
# Add all event schemas
|
||||
for event in sorted(EventBase.get_events(), key=lambda e: e.__name__):
|
||||
json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
|
||||
if "$defs" in json_schema:
|
||||
for schema_key, schema in json_schema["$defs"].items():
|
||||
openapi_schema["components"]["schemas"][schema_key] = schema
|
||||
del json_schema["$defs"]
|
||||
openapi_schema["components"]["schemas"][event.__name__] = json_schema
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
@ -106,9 +106,7 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
if self._invoker:
|
||||
assert bulk_download_id is not None
|
||||
self._invoker.services.events.emit_bulk_download_started(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
bulk_download_id, bulk_download_item_id, bulk_download_item_name
|
||||
)
|
||||
|
||||
def _signal_job_completed(
|
||||
@ -118,10 +116,8 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
if self._invoker:
|
||||
assert bulk_download_id is not None
|
||||
assert bulk_download_item_name is not None
|
||||
self._invoker.services.events.emit_bulk_download_completed(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
self._invoker.services.events.emit_bulk_download_complete(
|
||||
bulk_download_id, bulk_download_item_id, bulk_download_item_name
|
||||
)
|
||||
|
||||
def _signal_job_failed(
|
||||
@ -131,11 +127,8 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
if self._invoker:
|
||||
assert bulk_download_id is not None
|
||||
assert exception is not None
|
||||
self._invoker.services.events.emit_bulk_download_failed(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
error=str(exception),
|
||||
self._invoker.services.events.emit_bulk_download_error(
|
||||
bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception)
|
||||
)
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
|
@ -8,14 +8,13 @@ import time
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from queue import Empty, PriorityQueue
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
|
||||
|
||||
import requests
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests import HTTPError
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
@ -30,6 +29,9 @@ from .download_base import (
|
||||
UnknownJobIDException,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
|
||||
# Maximum number of bytes to download during each call to requests.iter_content()
|
||||
DOWNLOAD_CHUNK_SIZE = 100000
|
||||
|
||||
@ -40,7 +42,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
def __init__(
|
||||
self,
|
||||
max_parallel_dl: int = 5,
|
||||
event_bus: Optional[EventServiceBase] = None,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
requests_session: Optional[requests.sessions.Session] = None,
|
||||
):
|
||||
"""
|
||||
@ -343,8 +345,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
assert job.download_path
|
||||
self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix())
|
||||
self._event_bus.emit_download_started(job)
|
||||
|
||||
def _signal_job_progress(self, job: DownloadJob) -> None:
|
||||
if job.on_progress:
|
||||
@ -355,13 +356,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
assert job.download_path
|
||||
self._event_bus.emit_download_progress(
|
||||
str(job.source),
|
||||
download_path=job.download_path.as_posix(),
|
||||
current_bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
)
|
||||
self._event_bus.emit_download_progress(job)
|
||||
|
||||
def _signal_job_complete(self, job: DownloadJob) -> None:
|
||||
job.status = DownloadJobStatus.COMPLETED
|
||||
@ -373,10 +368,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
assert job.download_path
|
||||
self._event_bus.emit_download_complete(
|
||||
str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes
|
||||
)
|
||||
self._event_bus.emit_download_complete(job)
|
||||
|
||||
def _signal_job_cancelled(self, job: DownloadJob) -> None:
|
||||
if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
|
||||
@ -390,7 +382,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_cancelled(str(job.source))
|
||||
self._event_bus.emit_download_cancelled(job)
|
||||
|
||||
def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
job.status = DownloadJobStatus.ERROR
|
||||
@ -403,9 +395,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
assert job.error_type
|
||||
assert job.error
|
||||
self._event_bus.emit_download_error(str(job.source), error_type=job.error_type, error=job.error)
|
||||
self._event_bus.emit_download_error(job)
|
||||
|
||||
def _cleanup_cancelled_job(self, job: DownloadJob) -> None:
|
||||
self._logger.debug(f"Cleaning up leftover files from cancelled download job {job.download_path}")
|
||||
|
@ -1,490 +1,253 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
BatchStatus,
|
||||
EnqueueBatchResult,
|
||||
SessionQueueItem,
|
||||
SessionQueueStatus,
|
||||
from invokeai.app.services.events.events_common import (
|
||||
BatchEnqueuedEvent,
|
||||
BulkDownloadCompleteEvent,
|
||||
BulkDownloadErrorEvent,
|
||||
BulkDownloadStartedEvent,
|
||||
DownloadCancelledEvent,
|
||||
DownloadCompleteEvent,
|
||||
DownloadErrorEvent,
|
||||
DownloadProgressEvent,
|
||||
DownloadStartedEvent,
|
||||
EventBase,
|
||||
ExtraData,
|
||||
InvocationCompleteEvent,
|
||||
InvocationDenoiseProgressEvent,
|
||||
InvocationErrorEvent,
|
||||
InvocationStartedEvent,
|
||||
ModelInstallCancelledEvent,
|
||||
ModelInstallCompleteEvent,
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallErrorEvent,
|
||||
ModelInstallStartedEvent,
|
||||
ModelLoadCompleteEvent,
|
||||
ModelLoadStartedEvent,
|
||||
QueueClearedEvent,
|
||||
QueueItemStatusChangedEvent,
|
||||
SessionCanceledEvent,
|
||||
SessionCompleteEvent,
|
||||
SessionStartedEvent,
|
||||
)
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.backend.model_manager import AnyModelConfig
|
||||
from invokeai.backend.model_manager.config import SubModelType
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.download.download_base import DownloadJob
|
||||
from invokeai.app.services.events.events_common import EventBase
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
BatchStatus,
|
||||
EnqueueBatchResult,
|
||||
SessionQueueItem,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
|
||||
|
||||
|
||||
class EventServiceBase:
|
||||
queue_event: str = "queue_event"
|
||||
bulk_download_event: str = "bulk_download_event"
|
||||
download_event: str = "download_event"
|
||||
model_event: str = "model_event"
|
||||
|
||||
"""Basic event bus, to have an empty stand-in when not needed"""
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
def dispatch(self, event: "EventBase") -> None:
|
||||
pass
|
||||
|
||||
def _emit_bulk_download_event(self, event_name: str, payload: dict) -> None:
|
||||
"""Bulk download events are emitted to a room with queue_id as the room name"""
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.bulk_download_event,
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
# region: Invocation
|
||||
|
||||
def __emit_queue_event(self, event_name: str, payload: dict) -> None:
|
||||
"""Queue events are emitted to a room with queue_id as the room name"""
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.queue_event,
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
def __emit_download_event(self, event_name: str, payload: dict) -> None:
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.download_event,
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
def __emit_model_event(self, event_name: str, payload: dict) -> None:
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.model_event,
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
# Define events here for every event in the system.
|
||||
# This will make them easier to integrate until we find a schema generator.
|
||||
def emit_generator_progress(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node_id: str,
|
||||
source_node_id: str,
|
||||
progress_image: Optional[ProgressImage],
|
||||
step: int,
|
||||
order: int,
|
||||
total_steps: int,
|
||||
def emit_invocation_started(
|
||||
self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", extra: Optional[ExtraData] = None
|
||||
) -> None:
|
||||
"""Emitted when there is generation progress"""
|
||||
self.__emit_queue_event(
|
||||
event_name="generator_progress",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node_id": node_id,
|
||||
"source_node_id": source_node_id,
|
||||
"progress_image": progress_image.model_dump(mode="json") if progress_image is not None else None,
|
||||
"step": step,
|
||||
"order": order,
|
||||
"total_steps": total_steps,
|
||||
},
|
||||
"""Emitted when an invocation is started"""
|
||||
self.dispatch(InvocationStartedEvent.build(queue_item, invocation, extra))
|
||||
|
||||
def emit_invocation_denoise_progress(
|
||||
self,
|
||||
queue_item: "SessionQueueItem",
|
||||
invocation: "BaseInvocation",
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
progress_image: "ProgressImage",
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted at each step during denoising of an invocation."""
|
||||
self.dispatch(
|
||||
InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image, extra)
|
||||
)
|
||||
|
||||
def emit_invocation_complete(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
result: dict,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
queue_item: "SessionQueueItem",
|
||||
invocation: "BaseInvocation",
|
||||
output: "BaseInvocationOutput",
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_complete",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node": node,
|
||||
"source_node_id": source_node_id,
|
||||
"result": result,
|
||||
},
|
||||
)
|
||||
"""Emitted when an invocation is complete"""
|
||||
self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output, extra))
|
||||
|
||||
def emit_invocation_error(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
queue_item: "SessionQueueItem",
|
||||
invocation: "BaseInvocation",
|
||||
error_type: str,
|
||||
error: str,
|
||||
user_id: str | None,
|
||||
project_id: str | None,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_error",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node": node,
|
||||
"source_node_id": source_node_id,
|
||||
"error_type": error_type,
|
||||
"error": error,
|
||||
"user_id": user_id,
|
||||
"project_id": project_id,
|
||||
},
|
||||
)
|
||||
"""Emitted when an invocation encounters an error"""
|
||||
self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error, extra))
|
||||
|
||||
def emit_invocation_started(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
) -> None:
|
||||
"""Emitted when an invocation has started"""
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_started",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node": node,
|
||||
"source_node_id": source_node_id,
|
||||
},
|
||||
)
|
||||
# endregion
|
||||
|
||||
def emit_graph_execution_complete(
|
||||
self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str
|
||||
) -> None:
|
||||
# region Session
|
||||
|
||||
def emit_session_started(self, queue_item: "SessionQueueItem", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a session has started"""
|
||||
self.dispatch(SessionStartedEvent.build(queue_item, extra))
|
||||
|
||||
def emit_session_complete(self, queue_item: "SessionQueueItem", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a session has completed all invocations"""
|
||||
self.__emit_queue_event(
|
||||
event_name="graph_execution_state_complete",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
},
|
||||
)
|
||||
self.dispatch(SessionCompleteEvent.build(queue_item, extra))
|
||||
|
||||
def emit_model_load_started(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Emitted when a model is requested"""
|
||||
self.__emit_queue_event(
|
||||
event_name="model_load_started",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"model_config": model_config.model_dump(mode="json"),
|
||||
"submodel_type": submodel_type,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_model_load_completed(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Emitted when a model is correctly loaded (returns model info)"""
|
||||
self.__emit_queue_event(
|
||||
event_name="model_load_completed",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"model_config": model_config.model_dump(mode="json"),
|
||||
"submodel_type": submodel_type,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_session_canceled(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
) -> None:
|
||||
def emit_session_canceled(self, queue_item: "SessionQueueItem", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a session is canceled"""
|
||||
self.__emit_queue_event(
|
||||
event_name="session_canceled",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
},
|
||||
)
|
||||
self.dispatch(SessionCanceledEvent.build(queue_item, extra))
|
||||
|
||||
# endregion
|
||||
|
||||
# region Queue
|
||||
|
||||
def emit_queue_item_status_changed(
|
||||
self,
|
||||
session_queue_item: SessionQueueItem,
|
||||
batch_status: BatchStatus,
|
||||
queue_status: SessionQueueStatus,
|
||||
queue_item: "SessionQueueItem",
|
||||
batch_status: "BatchStatus",
|
||||
queue_status: "SessionQueueStatus",
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted when a queue item's status changes"""
|
||||
self.__emit_queue_event(
|
||||
event_name="queue_item_status_changed",
|
||||
payload={
|
||||
"queue_id": queue_status.queue_id,
|
||||
"queue_item": {
|
||||
"queue_id": session_queue_item.queue_id,
|
||||
"item_id": session_queue_item.item_id,
|
||||
"status": session_queue_item.status,
|
||||
"batch_id": session_queue_item.batch_id,
|
||||
"session_id": session_queue_item.session_id,
|
||||
"error": session_queue_item.error,
|
||||
"created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None,
|
||||
"updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
|
||||
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
|
||||
"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
|
||||
},
|
||||
"batch_status": batch_status.model_dump(mode="json"),
|
||||
"queue_status": queue_status.model_dump(mode="json"),
|
||||
},
|
||||
)
|
||||
self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status, extra))
|
||||
|
||||
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
|
||||
def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a batch is enqueued"""
|
||||
self.__emit_queue_event(
|
||||
event_name="batch_enqueued",
|
||||
payload={
|
||||
"queue_id": enqueue_result.queue_id,
|
||||
"batch_id": enqueue_result.batch.batch_id,
|
||||
"enqueued": enqueue_result.enqueued,
|
||||
},
|
||||
)
|
||||
self.dispatch(BatchEnqueuedEvent.build(enqueue_result, extra))
|
||||
|
||||
def emit_queue_cleared(self, queue_id: str) -> None:
|
||||
"""Emitted when the queue is cleared"""
|
||||
self.__emit_queue_event(
|
||||
event_name="queue_cleared",
|
||||
payload={"queue_id": queue_id},
|
||||
)
|
||||
def emit_queue_cleared(self, queue_id: str, extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a queue is cleared"""
|
||||
self.dispatch(QueueClearedEvent.build(queue_id, extra))
|
||||
|
||||
def emit_download_started(self, source: str, download_path: str) -> None:
|
||||
"""
|
||||
Emit when a download job is started.
|
||||
# endregion
|
||||
|
||||
:param url: The downloaded url
|
||||
"""
|
||||
self.__emit_download_event(
|
||||
event_name="download_started",
|
||||
payload={"source": source, "download_path": download_path},
|
||||
)
|
||||
# region Download
|
||||
|
||||
def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None:
|
||||
"""
|
||||
Emit "download_progress" events at regular intervals during a download job.
|
||||
def emit_download_started(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a download is started"""
|
||||
self.dispatch(DownloadStartedEvent.build(job, extra))
|
||||
|
||||
:param source: The downloaded source
|
||||
:param download_path: The local downloaded file
|
||||
:param current_bytes: Number of bytes downloaded so far
|
||||
:param total_bytes: The size of the file being downloaded (if known)
|
||||
"""
|
||||
self.__emit_download_event(
|
||||
event_name="download_progress",
|
||||
payload={
|
||||
"source": source,
|
||||
"download_path": download_path,
|
||||
"current_bytes": current_bytes,
|
||||
"total_bytes": total_bytes,
|
||||
},
|
||||
)
|
||||
def emit_download_progress(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted at intervals during a download"""
|
||||
self.dispatch(DownloadProgressEvent.build(job, extra))
|
||||
|
||||
def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None:
|
||||
"""
|
||||
Emit a "download_complete" event at the end of a successful download.
|
||||
def emit_download_complete(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a download is completed"""
|
||||
self.dispatch(DownloadCompleteEvent.build(job, extra))
|
||||
|
||||
:param source: Source URL
|
||||
:param download_path: Path to the locally downloaded file
|
||||
:param total_bytes: The size of the downloaded file
|
||||
"""
|
||||
self.__emit_download_event(
|
||||
event_name="download_complete",
|
||||
payload={
|
||||
"source": source,
|
||||
"download_path": download_path,
|
||||
"total_bytes": total_bytes,
|
||||
},
|
||||
)
|
||||
def emit_download_cancelled(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a download is cancelled"""
|
||||
self.dispatch(DownloadCancelledEvent.build(job, extra))
|
||||
|
||||
def emit_download_cancelled(self, source: str) -> None:
|
||||
"""Emit a "download_cancelled" event in the event that the download was cancelled by user."""
|
||||
self.__emit_download_event(
|
||||
event_name="download_cancelled",
|
||||
payload={
|
||||
"source": source,
|
||||
},
|
||||
)
|
||||
def emit_download_error(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a download encounters an error"""
|
||||
self.dispatch(DownloadErrorEvent.build(job, extra))
|
||||
|
||||
def emit_download_error(self, source: str, error_type: str, error: str) -> None:
|
||||
"""
|
||||
Emit a "download_error" event when an download job encounters an exception.
|
||||
# endregion
|
||||
|
||||
:param source: Source URL
|
||||
:param error_type: The name of the exception that raised the error
|
||||
:param error: The traceback from this error
|
||||
"""
|
||||
self.__emit_download_event(
|
||||
event_name="download_error",
|
||||
payload={
|
||||
"source": source,
|
||||
"error_type": error_type,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
# region Model loading
|
||||
|
||||
def emit_model_install_downloading(
|
||||
def emit_model_load_started(
|
||||
self,
|
||||
source: str,
|
||||
local_path: str,
|
||||
bytes: int,
|
||||
total_bytes: int,
|
||||
parts: List[Dict[str, Union[str, int]]],
|
||||
id: int,
|
||||
config: "AnyModelConfig",
|
||||
submodel_type: Optional["SubModelType"] = None,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Emit at intervals while the install job is in progress (remote models only).
|
||||
"""Emitted when a model load is started."""
|
||||
self.dispatch(ModelLoadStartedEvent.build(config, submodel_type, extra))
|
||||
|
||||
:param source: Source of the model
|
||||
:param local_path: Where model is downloading to
|
||||
:param parts: Progress of downloading URLs that comprise the model, if any.
|
||||
:param bytes: Number of bytes downloaded so far.
|
||||
:param total_bytes: Total size of download, including all files.
|
||||
This emits a Dict with keys "source", "local_path", "bytes" and "total_bytes".
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_downloading",
|
||||
payload={
|
||||
"source": source,
|
||||
"local_path": local_path,
|
||||
"bytes": bytes,
|
||||
"total_bytes": total_bytes,
|
||||
"parts": parts,
|
||||
"id": id,
|
||||
},
|
||||
)
|
||||
def emit_model_load_complete(
|
||||
self,
|
||||
config: "AnyModelConfig",
|
||||
submodel_type: Optional["SubModelType"] = None,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted when a model load is complete."""
|
||||
self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type, extra))
|
||||
|
||||
def emit_model_install_downloads_done(self, source: str) -> None:
|
||||
"""
|
||||
Emit once when all parts are downloaded, but before the probing and registration start.
|
||||
# endregion
|
||||
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_downloads_done",
|
||||
payload={"source": source},
|
||||
)
|
||||
# region Model install
|
||||
|
||||
def emit_model_install_running(self, source: str) -> None:
|
||||
"""
|
||||
Emit once when an install job becomes active.
|
||||
def emit_model_install_download_progress(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted at intervals while the install job is in progress (remote models only)."""
|
||||
self.dispatch(ModelInstallDownloadProgressEvent.build(job, extra))
|
||||
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_running",
|
||||
payload={"source": source},
|
||||
)
|
||||
def emit_model_install_downloads_complete(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
|
||||
self.dispatch(ModelInstallDownloadsCompleteEvent.build(job, extra))
|
||||
|
||||
def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None:
|
||||
"""
|
||||
Emit when an install job is completed successfully.
|
||||
def emit_model_install_started(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted once when an install job is started (after any download)."""
|
||||
self.dispatch(ModelInstallStartedEvent.build(job, extra))
|
||||
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
:param key: Model config record key
|
||||
:param total_bytes: Size of the model (may be None for installation of a local path)
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_completed",
|
||||
payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id},
|
||||
)
|
||||
def emit_model_install_complete(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when an install job is completed successfully."""
|
||||
self.dispatch(ModelInstallCompleteEvent.build(job, extra))
|
||||
|
||||
def emit_model_install_cancelled(self, source: str, id: int) -> None:
|
||||
"""
|
||||
Emit when an install job is cancelled.
|
||||
def emit_model_install_cancelled(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when an install job is cancelled."""
|
||||
self.dispatch(ModelInstallCancelledEvent.build(job, extra))
|
||||
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_cancelled",
|
||||
payload={"source": source, "id": id},
|
||||
)
|
||||
def emit_model_install_error(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when an install job encounters an exception."""
|
||||
self.dispatch(ModelInstallErrorEvent.build(job, extra))
|
||||
|
||||
def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None:
|
||||
"""
|
||||
Emit when an install job encounters an exception.
|
||||
# endregion
|
||||
|
||||
:param source: Source of the model
|
||||
:param error_type: The name of the exception
|
||||
:param error: A text description of the exception
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_error",
|
||||
payload={"source": source, "error_type": error_type, "error": error, "id": id},
|
||||
)
|
||||
# region Bulk image download
|
||||
|
||||
def emit_bulk_download_started(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
self,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted when a bulk download starts"""
|
||||
self._emit_bulk_download_event(
|
||||
event_name="bulk_download_started",
|
||||
payload={
|
||||
"bulk_download_id": bulk_download_id,
|
||||
"bulk_download_item_id": bulk_download_item_id,
|
||||
"bulk_download_item_name": bulk_download_item_name,
|
||||
},
|
||||
"""Emitted when a bulk image download is started"""
|
||||
self.dispatch(
|
||||
BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, extra)
|
||||
)
|
||||
|
||||
def emit_bulk_download_completed(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
def emit_bulk_download_complete(
|
||||
self,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted when a bulk download completes"""
|
||||
self._emit_bulk_download_event(
|
||||
event_name="bulk_download_completed",
|
||||
payload={
|
||||
"bulk_download_id": bulk_download_id,
|
||||
"bulk_download_item_id": bulk_download_item_id,
|
||||
"bulk_download_item_name": bulk_download_item_name,
|
||||
},
|
||||
"""Emitted when a bulk image download is complete"""
|
||||
self.dispatch(
|
||||
BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, extra)
|
||||
)
|
||||
|
||||
def emit_bulk_download_failed(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
|
||||
def emit_bulk_download_error(
|
||||
self,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
error: str,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted when a bulk download fails"""
|
||||
self._emit_bulk_download_event(
|
||||
event_name="bulk_download_failed",
|
||||
payload={
|
||||
"bulk_download_id": bulk_download_id,
|
||||
"bulk_download_item_id": bulk_download_item_id,
|
||||
"bulk_download_item_name": bulk_download_item_name,
|
||||
"error": error,
|
||||
},
|
||||
"""Emitted when a bulk image download has an error"""
|
||||
self.dispatch(
|
||||
BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error, extra)
|
||||
)
|
||||
|
||||
# endregion
|
||||
|
707
invokeai/app/services/events/events_common.py
Normal file
707
invokeai/app/services/events/events_common.py
Normal file
@ -0,0 +1,707 @@
|
||||
from math import floor
|
||||
from typing import TYPE_CHECKING, Any, Coroutine, Optional, Protocol, TypeAlias, TypeVar
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
QUEUE_ITEM_STATUS,
|
||||
BatchStatus,
|
||||
EnqueueBatchResult,
|
||||
SessionQueueItem,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.download.download_base import DownloadJob
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
||||
|
||||
|
||||
ExtraData: TypeAlias = dict[str, Any]
|
||||
|
||||
|
||||
class EventBase(BaseModel):
|
||||
"""Base class for all events. All events must inherit from this class.
|
||||
|
||||
Events must define a class attribute `__event_name__` to identify the event.
|
||||
|
||||
All other attributes should be defined as normal for a pydantic model.
|
||||
|
||||
A timestamp is automatically added to the event when it is created.
|
||||
"""
|
||||
|
||||
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
|
||||
extra: Optional[ExtraData] = Field(default=None, description="Extra data to include with the event")
|
||||
|
||||
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
|
||||
|
||||
@classmethod
|
||||
def get_events(cls) -> set[type["EventBase"]]:
|
||||
"""Get a set of all event models."""
|
||||
|
||||
event_subclasses: set[type["EventBase"]] = set()
|
||||
for subclass in cls.__subclasses__():
|
||||
# We only want to include subclasses that are event models, not intermediary classes
|
||||
if hasattr(subclass, "__event_name__"):
|
||||
event_subclasses.add(subclass)
|
||||
event_subclasses.update(subclass.get_events())
|
||||
|
||||
return event_subclasses
|
||||
|
||||
|
||||
TEvent = TypeVar("TEvent", bound=EventBase)
|
||||
|
||||
FastAPIEvent: TypeAlias = tuple[str, TEvent]
|
||||
"""
|
||||
A tuple representing a `fastapi-events` event, with the event name and payload.
|
||||
Provide a generic type to `TEvent` to specify the payload type.
|
||||
"""
|
||||
|
||||
|
||||
class FastAPIEventFunc(Protocol):
|
||||
def __call__(self, event: FastAPIEvent[Any]) -> Optional[Coroutine[Any, Any, None]]: ...
|
||||
|
||||
|
||||
def register_events(events: set[type[TEvent]], func: FastAPIEventFunc) -> None:
|
||||
"""Register a function to handle a list of events.
|
||||
|
||||
:param events: A list of event classes to handle
|
||||
:param func: The function to handle the events
|
||||
"""
|
||||
for event in events:
|
||||
assert hasattr(event, "__event_name__")
|
||||
local_handler.register(event_name=event.__event_name__, _func=func) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
|
||||
|
||||
|
||||
class QueueEventBase(EventBase):
|
||||
"""Base class for queue events"""
|
||||
|
||||
queue_id: str = Field(description="The ID of the queue")
|
||||
|
||||
|
||||
class QueueItemEventBase(QueueEventBase):
|
||||
"""Base class for queue item events"""
|
||||
|
||||
item_id: int = Field(description="The ID of the queue item")
|
||||
batch_id: str = Field(description="The ID of the queue batch")
|
||||
|
||||
|
||||
class SessionEventBase(QueueItemEventBase):
|
||||
"""Base class for session (aka graph execution state) events"""
|
||||
|
||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||
|
||||
|
||||
class InvocationEventBase(SessionEventBase):
|
||||
"""Base class for invocation events"""
|
||||
|
||||
queue_id: str = Field(description="The ID of the queue")
|
||||
item_id: int = Field(description="The ID of the queue item")
|
||||
batch_id: str = Field(description="The ID of the queue batch")
|
||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||
invocation_id: str = Field(description="The ID of the invocation")
|
||||
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
|
||||
invocation_type: str = Field(description="The type of invocation")
|
||||
|
||||
|
||||
class InvocationStartedEvent(InvocationEventBase):
|
||||
"""Event model for invocation_started"""
|
||||
|
||||
__event_name__ = "invocation_started"
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, extra: Optional[ExtraData] = None
|
||||
) -> "InvocationStartedEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
invocation_id=invocation.id,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
invocation_type=invocation.get_type(),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class InvocationDenoiseProgressEvent(InvocationEventBase):
|
||||
"""Event model for invocation_denoise_progress"""
|
||||
|
||||
__event_name__ = "invocation_denoise_progress"
|
||||
|
||||
progress_image: ProgressImage = Field(description="The progress image sent at each step during processing")
|
||||
step: int = Field(description="The current step of the invocation")
|
||||
total_steps: int = Field(description="The total number of steps in the invocation")
|
||||
order: int = Field(description="The order of the invocation in the session")
|
||||
percentage: float = Field(description="The percentage of completion of the invocation")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
queue_item: SessionQueueItem,
|
||||
invocation: BaseInvocation,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
progress_image: ProgressImage,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> "InvocationDenoiseProgressEvent":
|
||||
step = intermediate_state.step
|
||||
total_steps = intermediate_state.total_steps
|
||||
order = intermediate_state.order
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
invocation_id=invocation.id,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
invocation_type=invocation.get_type(),
|
||||
progress_image=progress_image,
|
||||
step=step,
|
||||
total_steps=total_steps,
|
||||
order=order,
|
||||
percentage=cls.calc_percentage(step, total_steps, order),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def calc_percentage(step: int, total_steps: int, scheduler_order: float) -> float:
|
||||
"""Calculate the percentage of completion of denoising."""
|
||||
if total_steps == 0:
|
||||
return 0.0
|
||||
if scheduler_order == 2:
|
||||
return floor((step + 1 + 1) / 2) / floor((total_steps + 1) / 2)
|
||||
# order == 1
|
||||
return (step + 1 + 1) / (total_steps + 1)
|
||||
|
||||
|
||||
class InvocationCompleteEvent(InvocationEventBase):
|
||||
"""Event model for invocation_complete"""
|
||||
|
||||
__event_name__ = "invocation_complete"
|
||||
|
||||
result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
queue_item: SessionQueueItem,
|
||||
invocation: BaseInvocation,
|
||||
result: BaseInvocationOutput,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> "InvocationCompleteEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
invocation_id=invocation.id,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
invocation_type=invocation.get_type(),
|
||||
result=result,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class InvocationErrorEvent(InvocationEventBase):
|
||||
"""Event model for invocation_error"""
|
||||
|
||||
__event_name__ = "invocation_error"
|
||||
|
||||
error_type: str = Field(description="The type of error")
|
||||
error: str = Field(description="The error message")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
queue_item: SessionQueueItem,
|
||||
invocation: BaseInvocation,
|
||||
error_type: str,
|
||||
error: str,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> "InvocationErrorEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
invocation_id=invocation.id,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
invocation_type=invocation.get_type(),
|
||||
error_type=error_type,
|
||||
error=error,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class SessionStartedEvent(SessionEventBase):
|
||||
"""Event model for session_started"""
|
||||
|
||||
__event_name__ = "session_started"
|
||||
|
||||
@classmethod
|
||||
def build(cls, queue_item: SessionQueueItem, extra: Optional[ExtraData] = None) -> "SessionStartedEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class SessionCompleteEvent(SessionEventBase):
|
||||
"""Event model for session_complete"""
|
||||
|
||||
__event_name__ = "session_complete"
|
||||
|
||||
@classmethod
|
||||
def build(cls, queue_item: SessionQueueItem, extra: Optional[ExtraData] = None) -> "SessionCompleteEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class SessionCanceledEvent(SessionEventBase):
|
||||
"""Event model for session_canceled"""
|
||||
|
||||
__event_name__ = "session_canceled"
|
||||
|
||||
@classmethod
|
||||
def build(cls, queue_item: SessionQueueItem, extra: Optional[ExtraData] = None) -> "SessionCanceledEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class QueueItemStatusChangedEvent(QueueItemEventBase):
|
||||
"""Event model for queue_item_status_changed"""
|
||||
|
||||
__event_name__ = "queue_item_status_changed"
|
||||
|
||||
status: QUEUE_ITEM_STATUS = Field(description="The new status of the queue item")
|
||||
error: Optional[str] = Field(default=None, description="The error message, if any")
|
||||
created_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was created")
|
||||
updated_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was last updated")
|
||||
started_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was started")
|
||||
completed_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was completed")
|
||||
batch_status: BatchStatus = Field(description="The status of the batch")
|
||||
queue_status: SessionQueueStatus = Field(description="The status of the queue")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
queue_item: SessionQueueItem,
|
||||
batch_status: BatchStatus,
|
||||
queue_status: SessionQueueStatus,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> "QueueItemStatusChangedEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
status=queue_item.status,
|
||||
error=queue_item.error,
|
||||
created_at=str(queue_item.created_at) if queue_item.created_at else None,
|
||||
updated_at=str(queue_item.updated_at) if queue_item.updated_at else None,
|
||||
started_at=str(queue_item.started_at) if queue_item.started_at else None,
|
||||
completed_at=str(queue_item.completed_at) if queue_item.completed_at else None,
|
||||
batch_status=batch_status,
|
||||
queue_status=queue_status,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class BatchEnqueuedEvent(QueueEventBase):
|
||||
"""Event model for batch_enqueued"""
|
||||
|
||||
__event_name__ = "batch_enqueued"
|
||||
|
||||
batch_id: str = Field(description="The ID of the batch")
|
||||
enqueued: int = Field(description="The number of invocations enqueued")
|
||||
requested: int = Field(
|
||||
description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
|
||||
)
|
||||
priority: int = Field(description="The priority of the batch")
|
||||
|
||||
@classmethod
|
||||
def build(cls, enqueue_result: EnqueueBatchResult, extra: Optional[ExtraData] = None) -> "BatchEnqueuedEvent":
|
||||
return cls(
|
||||
queue_id=enqueue_result.queue_id,
|
||||
batch_id=enqueue_result.batch.batch_id,
|
||||
enqueued=enqueue_result.enqueued,
|
||||
requested=enqueue_result.requested,
|
||||
priority=enqueue_result.priority,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class QueueClearedEvent(QueueEventBase):
|
||||
"""Event model for queue_cleared"""
|
||||
|
||||
__event_name__ = "queue_cleared"
|
||||
|
||||
@classmethod
|
||||
def build(cls, queue_id: str, extra: Optional[ExtraData] = None) -> "QueueClearedEvent":
|
||||
return cls(
|
||||
queue_id=queue_id,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class DownloadEventBase(EventBase):
|
||||
"""Base class for events associated with a download"""
|
||||
|
||||
source: str = Field(description="The source of the download")
|
||||
|
||||
|
||||
class DownloadStartedEvent(DownloadEventBase):
|
||||
"""Event model for download_started"""
|
||||
|
||||
__event_name__ = "download_started"
|
||||
|
||||
download_path: str = Field(description="The local path where the download is saved")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadStartedEvent":
|
||||
assert job.download_path
|
||||
return cls(
|
||||
source=str(job.source),
|
||||
download_path=job.download_path.as_posix(),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class DownloadProgressEvent(DownloadEventBase):
|
||||
"""Event model for download_progress"""
|
||||
|
||||
__event_name__ = "download_progress"
|
||||
|
||||
download_path: str = Field(description="The local path where the download is saved")
|
||||
current_bytes: int = Field(description="The number of bytes downloaded so far")
|
||||
total_bytes: int = Field(description="The total number of bytes to be downloaded")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadProgressEvent":
|
||||
assert job.download_path
|
||||
return cls(
|
||||
source=str(job.source),
|
||||
download_path=job.download_path.as_posix(),
|
||||
current_bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class DownloadCompleteEvent(DownloadEventBase):
|
||||
"""Event model for download_complete"""
|
||||
|
||||
__event_name__ = "download_complete"
|
||||
|
||||
download_path: str = Field(description="The local path where the download is saved")
|
||||
total_bytes: int = Field(description="The total number of bytes downloaded")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadCompleteEvent":
|
||||
assert job.download_path
|
||||
return cls(
|
||||
source=str(job.source),
|
||||
download_path=job.download_path.as_posix(),
|
||||
total_bytes=job.total_bytes,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class DownloadCancelledEvent(DownloadEventBase):
|
||||
"""Event model for download_cancelled"""
|
||||
|
||||
__event_name__ = "download_cancelled"
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadCancelledEvent":
|
||||
return cls(
|
||||
source=str(job.source),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class DownloadErrorEvent(DownloadEventBase):
|
||||
"""Event model for download_error"""
|
||||
|
||||
__event_name__ = "download_error"
|
||||
|
||||
error_type: str = Field(description="The type of error")
|
||||
error: str = Field(description="The error message")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadErrorEvent":
|
||||
assert job.error_type
|
||||
assert job.error
|
||||
return cls(
|
||||
source=str(job.source),
|
||||
error_type=job.error_type,
|
||||
error=job.error,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class ModelEventBase(EventBase):
|
||||
"""Base class for events associated with a model"""
|
||||
|
||||
|
||||
class ModelLoadStartedEvent(ModelEventBase):
|
||||
"""Event model for model_load_started"""
|
||||
|
||||
__event_name__ = "model_load_started"
|
||||
|
||||
config: AnyModelConfig = Field(description="The model's config")
|
||||
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, extra: Optional[ExtraData] = None
|
||||
) -> "ModelLoadStartedEvent":
|
||||
return cls(
|
||||
config=config,
|
||||
submodel_type=submodel_type,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class ModelLoadCompleteEvent(ModelEventBase):
|
||||
"""Event model for model_load_complete"""
|
||||
|
||||
__event_name__ = "model_load_complete"
|
||||
|
||||
config: AnyModelConfig = Field(description="The model's config")
|
||||
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, extra: Optional[ExtraData] = None
|
||||
) -> "ModelLoadCompleteEvent":
|
||||
return cls(
|
||||
config=config,
|
||||
submodel_type=submodel_type,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallDownloadProgressEvent(ModelEventBase):
|
||||
"""Event model for model_install_download_progress"""
|
||||
|
||||
__event_name__ = "model_install_download_progress"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
local_path: str = Field(description="Where model is downloading to")
|
||||
bytes: int = Field(description="Number of bytes downloaded so far")
|
||||
total_bytes: int = Field(description="Total size of download, including all files")
|
||||
parts: list[dict[str, int | str]] = Field(
|
||||
description="Progress of downloading URLs that comprise the model, if any"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallDownloadProgressEvent":
|
||||
parts: list[dict[str, str | int]] = [
|
||||
{
|
||||
"url": str(x.source),
|
||||
"local_path": str(x.download_path),
|
||||
"bytes": x.bytes,
|
||||
"total_bytes": x.total_bytes,
|
||||
}
|
||||
for x in job.download_parts
|
||||
]
|
||||
return cls(
|
||||
id=job.id,
|
||||
source=str(job.source),
|
||||
local_path=job.local_path.as_posix(),
|
||||
parts=parts,
|
||||
bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallDownloadsCompleteEvent(ModelEventBase):
|
||||
"""Emitted once when an install job becomes active."""
|
||||
|
||||
__event_name__ = "model_install_downloads_complete"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallDownloadsCompleteEvent":
|
||||
return cls(
|
||||
id=job.id,
|
||||
source=str(job.source),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallStartedEvent(ModelEventBase):
|
||||
"""Event model for model_install_started"""
|
||||
|
||||
__event_name__ = "model_install_started"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallStartedEvent":
|
||||
return cls(
|
||||
id=job.id,
|
||||
source=str(job.source),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallCompleteEvent(ModelEventBase):
|
||||
"""Event model for model_install_complete"""
|
||||
|
||||
__event_name__ = "model_install_complete"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
key: str = Field(description="Model config record key")
|
||||
total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallCompleteEvent":
|
||||
assert job.config_out is not None
|
||||
return cls(
|
||||
id=job.id,
|
||||
source=str(job.source),
|
||||
key=(job.config_out.key),
|
||||
total_bytes=job.total_bytes,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallCancelledEvent(ModelEventBase):
|
||||
"""Event model for model_install_cancelled"""
|
||||
|
||||
__event_name__ = "model_install_cancelled"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallCancelledEvent":
|
||||
return cls(
|
||||
id=job.id,
|
||||
source=str(job.source),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallErrorEvent(ModelEventBase):
|
||||
"""Event model for model_install_error"""
|
||||
|
||||
__event_name__ = "model_install_error"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
error_type: str = Field(description="The name of the exception")
|
||||
error: str = Field(description="A text description of the exception")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallErrorEvent":
|
||||
assert job.error_type is not None
|
||||
assert job.error is not None
|
||||
return cls(
|
||||
id=job.id,
|
||||
source=str(job.source),
|
||||
error_type=job.error_type,
|
||||
error=job.error,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class BulkDownloadEventBase(EventBase):
|
||||
"""Base class for events associated with a bulk image download"""
|
||||
|
||||
bulk_download_id: str = Field(description="The ID of the bulk image download")
|
||||
bulk_download_item_id: str = Field(description="The ID of the bulk image download item")
|
||||
bulk_download_item_name: str = Field(description="The name of the bulk image download item")
|
||||
|
||||
|
||||
class BulkDownloadStartedEvent(BulkDownloadEventBase):
|
||||
"""Event model for bulk_download_started"""
|
||||
|
||||
__event_name__ = "bulk_download_started"
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> "BulkDownloadStartedEvent":
|
||||
return cls(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class BulkDownloadCompleteEvent(BulkDownloadEventBase):
|
||||
"""Event model for bulk_download_complete"""
|
||||
|
||||
__event_name__ = "bulk_download_complete"
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> "BulkDownloadCompleteEvent":
|
||||
return cls(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class BulkDownloadErrorEvent(BulkDownloadEventBase):
|
||||
"""Event model for bulk_download_error"""
|
||||
|
||||
__event_name__ = "bulk_download_error"
|
||||
|
||||
error: str = Field(description="The error message")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
error: str,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> "BulkDownloadErrorEvent":
|
||||
return cls(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
error=error,
|
||||
extra=extra,
|
||||
)
|
46
invokeai/app/services/events/events_fastapievents.py
Normal file
46
invokeai/app/services/events/events_fastapievents.py
Normal file
@ -0,0 +1,46 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from queue import Empty, Queue
|
||||
|
||||
from fastapi_events.dispatcher import dispatch
|
||||
|
||||
from invokeai.app.services.events.events_common import (
|
||||
EventBase,
|
||||
)
|
||||
|
||||
from .events_base import EventServiceBase
|
||||
|
||||
|
||||
class FastAPIEventService(EventServiceBase):
|
||||
def __init__(self, event_handler_id: int) -> None:
|
||||
self.event_handler_id = event_handler_id
|
||||
self._queue = Queue[EventBase | None]()
|
||||
self._stop_event = threading.Event()
|
||||
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
|
||||
|
||||
super().__init__()
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
self._stop_event.set()
|
||||
self._queue.put(None)
|
||||
|
||||
def dispatch(self, event: EventBase) -> None:
|
||||
self._queue.put(event)
|
||||
|
||||
async def _dispatch_from_queue(self, stop_event: threading.Event):
|
||||
"""Get events on from the queue and dispatch them, from the correct thread"""
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
event = self._queue.get(block=False)
|
||||
if not event: # Probably stopping
|
||||
continue
|
||||
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
|
||||
|
||||
except Empty:
|
||||
await asyncio.sleep(0.1)
|
||||
pass
|
||||
|
||||
except asyncio.CancelledError as e:
|
||||
raise e # Raise a proper error
|
@ -1,11 +1,13 @@
|
||||
"""Initialization file for model install service package."""
|
||||
|
||||
from .model_install_base import (
|
||||
ModelInstallServiceBase,
|
||||
)
|
||||
from .model_install_common import (
|
||||
HFModelSource,
|
||||
InstallStatus,
|
||||
LocalModelSource,
|
||||
ModelInstallJob,
|
||||
ModelInstallServiceBase,
|
||||
ModelSource,
|
||||
UnknownInstallJobException,
|
||||
URLModelSource,
|
||||
|
@ -1,244 +1,19 @@
|
||||
# Copyright 2023 Lincoln D. Stein and the InvokeAI development team
|
||||
"""Baseclass definitions for the model installer."""
|
||||
|
||||
import re
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
||||
from invokeai.app.services.download import DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
"""State of an install job running in the background."""
|
||||
|
||||
WAITING = "waiting" # waiting to be dequeued
|
||||
DOWNLOADING = "downloading" # downloading of model files in process
|
||||
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
|
||||
RUNNING = "running" # being processed
|
||||
COMPLETED = "completed" # finished running
|
||||
ERROR = "error" # terminated with an error message
|
||||
CANCELLED = "cancelled" # terminated with an error message
|
||||
|
||||
|
||||
class ModelInstallPart(BaseModel):
|
||||
url: AnyHttpUrl
|
||||
path: Path
|
||||
bytes: int = 0
|
||||
total_bytes: int = 0
|
||||
|
||||
|
||||
class UnknownInstallJobException(Exception):
|
||||
"""Raised when the status of an unknown job is requested."""
|
||||
|
||||
|
||||
class StringLikeSource(BaseModel):
|
||||
"""
|
||||
Base class for model sources, implements functions that lets the source be sorted and indexed.
|
||||
|
||||
These shenanigans let this stuff work:
|
||||
|
||||
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
|
||||
mydict = {source1: 'model 1'}
|
||||
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
|
||||
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
|
||||
|
||||
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
|
||||
assert source1 == source2
|
||||
assert source1 == 'C:/users/mort/foo.safetensors'
|
||||
"""
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return hash of the path field, for indexing."""
|
||||
return hash(str(self))
|
||||
|
||||
def __lt__(self, other: object) -> int:
|
||||
"""Return comparison of the stringified version, for sorting."""
|
||||
return str(self) < str(other)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Return equality on the stringified version."""
|
||||
if isinstance(other, Path):
|
||||
return str(self) == other.as_posix()
|
||||
else:
|
||||
return str(self) == str(other)
|
||||
|
||||
|
||||
class LocalModelSource(StringLikeSource):
|
||||
"""A local file or directory path."""
|
||||
|
||||
path: str | Path
|
||||
inplace: Optional[bool] = False
|
||||
type: Literal["local"] = "local"
|
||||
|
||||
# these methods allow the source to be used in a string-like way,
|
||||
# for example as an index into a dict
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of path when string rep needed."""
|
||||
return Path(self.path).as_posix()
|
||||
|
||||
|
||||
class HFModelSource(StringLikeSource):
|
||||
"""
|
||||
A HuggingFace repo_id with optional variant, sub-folder and access token.
|
||||
Note that the variant option, if not provided to the constructor, will default to fp16, which is
|
||||
what people (almost) always want.
|
||||
"""
|
||||
|
||||
repo_id: str
|
||||
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
|
||||
subfolder: Optional[Path] = None
|
||||
access_token: Optional[str] = None
|
||||
type: Literal["hf"] = "hf"
|
||||
|
||||
@field_validator("repo_id")
|
||||
@classmethod
|
||||
def proper_repo_id(cls, v: str) -> str: # noqa D102
|
||||
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
|
||||
raise ValueError(f"{v}: invalid repo_id format")
|
||||
return v
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of repoid when string rep needed."""
|
||||
base: str = self.repo_id
|
||||
if self.variant:
|
||||
base += f":{self.variant or ''}"
|
||||
if self.subfolder:
|
||||
base += f":{self.subfolder}"
|
||||
return base
|
||||
|
||||
|
||||
class URLModelSource(StringLikeSource):
|
||||
"""A generic URL point to a checkpoint file."""
|
||||
|
||||
url: AnyHttpUrl
|
||||
access_token: Optional[str] = None
|
||||
type: Literal["url"] = "url"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of the url when string rep needed."""
|
||||
return str(self.url)
|
||||
|
||||
|
||||
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
|
||||
|
||||
MODEL_SOURCE_TO_TYPE_MAP = {
|
||||
URLModelSource: ModelSourceType.Url,
|
||||
HFModelSource: ModelSourceType.HFRepoID,
|
||||
LocalModelSource: ModelSourceType.Path,
|
||||
}
|
||||
|
||||
|
||||
class ModelInstallJob(BaseModel):
|
||||
"""Object that tracks the current status of an install request."""
|
||||
|
||||
id: int = Field(description="Unique ID for this job")
|
||||
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
|
||||
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
|
||||
config_in: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
|
||||
)
|
||||
config_out: Optional[AnyModelConfig] = Field(
|
||||
default=None, description="After successful installation, this will hold the configuration object."
|
||||
)
|
||||
inplace: bool = Field(
|
||||
default=False, description="Leave model in its current location; otherwise install under models directory"
|
||||
)
|
||||
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
|
||||
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
|
||||
bytes: int = Field(
|
||||
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
|
||||
)
|
||||
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
|
||||
source_metadata: Optional[AnyModelRepoMetadata] = Field(
|
||||
default=None, description="Metadata provided by the model source"
|
||||
)
|
||||
download_parts: Set[DownloadJob] = Field(
|
||||
default_factory=set, description="Download jobs contributing to this install"
|
||||
)
|
||||
error: Optional[str] = Field(
|
||||
default=None, description="On an error condition, this field will contain the text of the exception"
|
||||
)
|
||||
error_traceback: Optional[str] = Field(
|
||||
default=None, description="On an error condition, this field will contain the exception traceback"
|
||||
)
|
||||
# internal flags and transitory settings
|
||||
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
||||
_exception: Optional[Exception] = PrivateAttr(default=None)
|
||||
|
||||
def set_error(self, e: Exception) -> None:
|
||||
"""Record the error and traceback from an exception."""
|
||||
self._exception = e
|
||||
self.error = str(e)
|
||||
self.error_traceback = self._format_error(e)
|
||||
self.status = InstallStatus.ERROR
|
||||
self.error_reason = self._exception.__class__.__name__ if self._exception else None
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Call to cancel the job."""
|
||||
self.status = InstallStatus.CANCELLED
|
||||
|
||||
@property
|
||||
def error_type(self) -> Optional[str]:
|
||||
"""Class name of the exception that led to status==ERROR."""
|
||||
return self._exception.__class__.__name__ if self._exception else None
|
||||
|
||||
def _format_error(self, exception: Exception) -> str:
|
||||
"""Error traceback."""
|
||||
return "".join(traceback.format_exception(exception))
|
||||
|
||||
@property
|
||||
def cancelled(self) -> bool:
|
||||
"""Set status to CANCELLED."""
|
||||
return self.status == InstallStatus.CANCELLED
|
||||
|
||||
@property
|
||||
def errored(self) -> bool:
|
||||
"""Return true if job has errored."""
|
||||
return self.status == InstallStatus.ERROR
|
||||
|
||||
@property
|
||||
def waiting(self) -> bool:
|
||||
"""Return true if job is waiting to run."""
|
||||
return self.status == InstallStatus.WAITING
|
||||
|
||||
@property
|
||||
def downloading(self) -> bool:
|
||||
"""Return true if job is downloading."""
|
||||
return self.status == InstallStatus.DOWNLOADING
|
||||
|
||||
@property
|
||||
def downloads_done(self) -> bool:
|
||||
"""Return true if job's downloads ae done."""
|
||||
return self.status == InstallStatus.DOWNLOADS_DONE
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
"""Return true if job is running."""
|
||||
return self.status == InstallStatus.RUNNING
|
||||
|
||||
@property
|
||||
def complete(self) -> bool:
|
||||
"""Return true if job completed without errors."""
|
||||
return self.status == InstallStatus.COMPLETED
|
||||
|
||||
@property
|
||||
def in_terminal_state(self) -> bool:
|
||||
"""Return true if job is in a terminal state."""
|
||||
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
|
||||
|
||||
class ModelInstallServiceBase(ABC):
|
||||
@ -282,7 +57,7 @@ class ModelInstallServiceBase(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def event_bus(self) -> Optional[EventServiceBase]:
|
||||
def event_bus(self) -> Optional["EventServiceBase"]:
|
||||
"""Return the event service base object associated with the installer."""
|
||||
|
||||
@abstractmethod
|
||||
|
233
invokeai/app/services/model_install/model_install_common.py
Normal file
233
invokeai/app/services/model_install/model_install_common.py
Normal file
@ -0,0 +1,233 @@
|
||||
import re
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Literal, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.download import DownloadJob
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
"""State of an install job running in the background."""
|
||||
|
||||
WAITING = "waiting" # waiting to be dequeued
|
||||
DOWNLOADING = "downloading" # downloading of model files in process
|
||||
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
|
||||
RUNNING = "running" # being processed
|
||||
COMPLETED = "completed" # finished running
|
||||
ERROR = "error" # terminated with an error message
|
||||
CANCELLED = "cancelled" # terminated with an error message
|
||||
|
||||
|
||||
class ModelInstallPart(BaseModel):
|
||||
url: AnyHttpUrl
|
||||
path: Path
|
||||
bytes: int = 0
|
||||
total_bytes: int = 0
|
||||
|
||||
|
||||
class UnknownInstallJobException(Exception):
|
||||
"""Raised when the status of an unknown job is requested."""
|
||||
|
||||
|
||||
class StringLikeSource(BaseModel):
|
||||
"""
|
||||
Base class for model sources, implements functions that lets the source be sorted and indexed.
|
||||
|
||||
These shenanigans let this stuff work:
|
||||
|
||||
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
|
||||
mydict = {source1: 'model 1'}
|
||||
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
|
||||
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
|
||||
|
||||
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
|
||||
assert source1 == source2
|
||||
assert source1 == 'C:/users/mort/foo.safetensors'
|
||||
"""
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return hash of the path field, for indexing."""
|
||||
return hash(str(self))
|
||||
|
||||
def __lt__(self, other: object) -> int:
|
||||
"""Return comparison of the stringified version, for sorting."""
|
||||
return str(self) < str(other)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Return equality on the stringified version."""
|
||||
if isinstance(other, Path):
|
||||
return str(self) == other.as_posix()
|
||||
else:
|
||||
return str(self) == str(other)
|
||||
|
||||
|
||||
class LocalModelSource(StringLikeSource):
|
||||
"""A local file or directory path."""
|
||||
|
||||
path: str | Path
|
||||
inplace: Optional[bool] = False
|
||||
type: Literal["local"] = "local"
|
||||
|
||||
# these methods allow the source to be used in a string-like way,
|
||||
# for example as an index into a dict
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of path when string rep needed."""
|
||||
return Path(self.path).as_posix()
|
||||
|
||||
|
||||
class HFModelSource(StringLikeSource):
|
||||
"""
|
||||
A HuggingFace repo_id with optional variant, sub-folder and access token.
|
||||
Note that the variant option, if not provided to the constructor, will default to fp16, which is
|
||||
what people (almost) always want.
|
||||
"""
|
||||
|
||||
repo_id: str
|
||||
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
|
||||
subfolder: Optional[Path] = None
|
||||
access_token: Optional[str] = None
|
||||
type: Literal["hf"] = "hf"
|
||||
|
||||
@field_validator("repo_id")
|
||||
@classmethod
|
||||
def proper_repo_id(cls, v: str) -> str: # noqa D102
|
||||
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
|
||||
raise ValueError(f"{v}: invalid repo_id format")
|
||||
return v
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of repoid when string rep needed."""
|
||||
base: str = self.repo_id
|
||||
if self.variant:
|
||||
base += f":{self.variant or ''}"
|
||||
if self.subfolder:
|
||||
base += f":{self.subfolder}"
|
||||
return base
|
||||
|
||||
|
||||
class URLModelSource(StringLikeSource):
|
||||
"""A generic URL point to a checkpoint file."""
|
||||
|
||||
url: AnyHttpUrl
|
||||
access_token: Optional[str] = None
|
||||
type: Literal["url"] = "url"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of the url when string rep needed."""
|
||||
return str(self.url)
|
||||
|
||||
|
||||
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
|
||||
|
||||
MODEL_SOURCE_TO_TYPE_MAP = {
|
||||
URLModelSource: ModelSourceType.Url,
|
||||
HFModelSource: ModelSourceType.HFRepoID,
|
||||
LocalModelSource: ModelSourceType.Path,
|
||||
}
|
||||
|
||||
|
||||
class ModelInstallJob(BaseModel):
|
||||
"""Object that tracks the current status of an install request."""
|
||||
|
||||
id: int = Field(description="Unique ID for this job")
|
||||
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
|
||||
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
|
||||
config_in: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
|
||||
)
|
||||
config_out: Optional[AnyModelConfig] = Field(
|
||||
default=None, description="After successful installation, this will hold the configuration object."
|
||||
)
|
||||
inplace: bool = Field(
|
||||
default=False, description="Leave model in its current location; otherwise install under models directory"
|
||||
)
|
||||
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
|
||||
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
|
||||
bytes: int = Field(
|
||||
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
|
||||
)
|
||||
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
|
||||
source_metadata: Optional[AnyModelRepoMetadata] = Field(
|
||||
default=None, description="Metadata provided by the model source"
|
||||
)
|
||||
download_parts: Set[DownloadJob] = Field(
|
||||
default_factory=set, description="Download jobs contributing to this install"
|
||||
)
|
||||
error: Optional[str] = Field(
|
||||
default=None, description="On an error condition, this field will contain the text of the exception"
|
||||
)
|
||||
error_traceback: Optional[str] = Field(
|
||||
default=None, description="On an error condition, this field will contain the exception traceback"
|
||||
)
|
||||
# internal flags and transitory settings
|
||||
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
||||
_exception: Optional[Exception] = PrivateAttr(default=None)
|
||||
|
||||
def set_error(self, e: Exception) -> None:
|
||||
"""Record the error and traceback from an exception."""
|
||||
self._exception = e
|
||||
self.error = str(e)
|
||||
self.error_traceback = self._format_error(e)
|
||||
self.status = InstallStatus.ERROR
|
||||
self.error_reason = self._exception.__class__.__name__ if self._exception else None
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Call to cancel the job."""
|
||||
self.status = InstallStatus.CANCELLED
|
||||
|
||||
@property
|
||||
def error_type(self) -> Optional[str]:
|
||||
"""Class name of the exception that led to status==ERROR."""
|
||||
return self._exception.__class__.__name__ if self._exception else None
|
||||
|
||||
def _format_error(self, exception: Exception) -> str:
|
||||
"""Error traceback."""
|
||||
return "".join(traceback.format_exception(exception))
|
||||
|
||||
@property
|
||||
def cancelled(self) -> bool:
|
||||
"""Set status to CANCELLED."""
|
||||
return self.status == InstallStatus.CANCELLED
|
||||
|
||||
@property
|
||||
def errored(self) -> bool:
|
||||
"""Return true if job has errored."""
|
||||
return self.status == InstallStatus.ERROR
|
||||
|
||||
@property
|
||||
def waiting(self) -> bool:
|
||||
"""Return true if job is waiting to run."""
|
||||
return self.status == InstallStatus.WAITING
|
||||
|
||||
@property
|
||||
def downloading(self) -> bool:
|
||||
"""Return true if job is downloading."""
|
||||
return self.status == InstallStatus.DOWNLOADING
|
||||
|
||||
@property
|
||||
def downloads_done(self) -> bool:
|
||||
"""Return true if job's downloads ae done."""
|
||||
return self.status == InstallStatus.DOWNLOADS_DONE
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
"""Return true if job is running."""
|
||||
return self.status == InstallStatus.RUNNING
|
||||
|
||||
@property
|
||||
def complete(self) -> bool:
|
||||
"""Return true if job completed without errors."""
|
||||
return self.status == InstallStatus.COMPLETED
|
||||
|
||||
@property
|
||||
def in_terminal_state(self) -> bool:
|
||||
"""Return true if job is in a terminal state."""
|
||||
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]
|
@ -10,7 +10,7 @@ from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
from shutil import copyfile, copytree, move, rmtree
|
||||
from tempfile import mkdtemp
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
@ -20,8 +20,8 @@ from requests import Session
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.backend.model_manager.config import (
|
||||
@ -45,13 +45,12 @@ from invokeai.backend.util import InvokeAILogger
|
||||
from invokeai.backend.util.catch_sigint import catch_sigint
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from .model_install_base import (
|
||||
from .model_install_common import (
|
||||
MODEL_SOURCE_TO_TYPE_MAP,
|
||||
HFModelSource,
|
||||
InstallStatus,
|
||||
LocalModelSource,
|
||||
ModelInstallJob,
|
||||
ModelInstallServiceBase,
|
||||
ModelSource,
|
||||
StringLikeSource,
|
||||
URLModelSource,
|
||||
@ -59,6 +58,9 @@ from .model_install_base import (
|
||||
|
||||
TMPDIR_PREFIX = "tmpinstall_"
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
|
||||
|
||||
class ModelInstallService(ModelInstallServiceBase):
|
||||
"""class for InvokeAI model installation."""
|
||||
@ -68,7 +70,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
app_config: InvokeAIAppConfig,
|
||||
record_store: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
event_bus: Optional[EventServiceBase] = None,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
session: Optional[Session] = None,
|
||||
):
|
||||
"""
|
||||
@ -104,7 +106,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
return self._record_store
|
||||
|
||||
@property
|
||||
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
|
||||
def event_bus(self) -> Optional["EventServiceBase"]: # noqa D102
|
||||
return self._event_bus
|
||||
|
||||
# make the invoker optional here because we don't need it and it
|
||||
@ -855,35 +857,17 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
job.status = InstallStatus.RUNNING
|
||||
self._logger.info(f"Model install started: {job.source}")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_running(str(job.source))
|
||||
self._event_bus.emit_model_install_started(job)
|
||||
|
||||
def _signal_job_downloading(self, job: ModelInstallJob) -> None:
|
||||
if self._event_bus:
|
||||
parts: List[Dict[str, str | int]] = [
|
||||
{
|
||||
"url": str(x.source),
|
||||
"local_path": str(x.download_path),
|
||||
"bytes": x.bytes,
|
||||
"total_bytes": x.total_bytes,
|
||||
}
|
||||
for x in job.download_parts
|
||||
]
|
||||
assert job.bytes is not None
|
||||
assert job.total_bytes is not None
|
||||
self._event_bus.emit_model_install_downloading(
|
||||
str(job.source),
|
||||
local_path=job.local_path.as_posix(),
|
||||
parts=parts,
|
||||
bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
id=job.id,
|
||||
)
|
||||
self._event_bus.emit_model_install_download_progress(job)
|
||||
|
||||
def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
|
||||
job.status = InstallStatus.DOWNLOADS_DONE
|
||||
self._logger.info(f"Model download complete: {job.source}")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_downloads_done(str(job.source))
|
||||
self._event_bus.emit_model_install_downloads_complete(job)
|
||||
|
||||
def _signal_job_completed(self, job: ModelInstallJob) -> None:
|
||||
job.status = InstallStatus.COMPLETED
|
||||
@ -891,24 +875,19 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._logger.info(f"Model install complete: {job.source}")
|
||||
self._logger.debug(f"{job.local_path} registered key {job.config_out.key}")
|
||||
if self._event_bus:
|
||||
assert job.local_path is not None
|
||||
assert job.config_out is not None
|
||||
key = job.config_out.key
|
||||
self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id)
|
||||
self._event_bus.emit_model_install_complete(job)
|
||||
|
||||
def _signal_job_errored(self, job: ModelInstallJob) -> None:
|
||||
self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}")
|
||||
if self._event_bus:
|
||||
error_type = job.error_type
|
||||
error = job.error
|
||||
assert error_type is not None
|
||||
assert error is not None
|
||||
self._event_bus.emit_model_install_error(str(job.source), error_type, error, id=job.id)
|
||||
assert job.error_type is not None
|
||||
assert job.error is not None
|
||||
self._event_bus.emit_model_install_error(job)
|
||||
|
||||
def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
|
||||
self._logger.info(f"Model install canceled: {job.source}")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)
|
||||
self._event_bus.emit_model_install_cancelled(job)
|
||||
|
||||
@staticmethod
|
||||
def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:
|
||||
|
@ -4,7 +4,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
@ -15,18 +14,12 @@ class ModelLoadServiceBase(ABC):
|
||||
"""Wrapper around AnyModelLoader."""
|
||||
|
||||
@abstractmethod
|
||||
def load_model(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
|
||||
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
|
||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||
:param context_data: Invocation context data used for event reporting
|
||||
"""
|
||||
|
||||
@property
|
||||
|
@ -5,7 +5,6 @@ from typing import Optional, Type
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import (
|
||||
LoadedModel,
|
||||
@ -51,25 +50,18 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
"""Return the checkpoint convert cache used by this loader."""
|
||||
return self._convert_cache
|
||||
|
||||
def load_model(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
|
||||
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
|
||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||
:param context: Invocation context used for event reporting
|
||||
"""
|
||||
if context_data:
|
||||
self._emit_load_event(
|
||||
context_data=context_data,
|
||||
model_config=model_config,
|
||||
submodel_type=submodel_type,
|
||||
)
|
||||
|
||||
# We don't have an invoker during testing
|
||||
# TODO(psyche): Mock this method on the invoker in the tests
|
||||
if hasattr(self, "_invoker"):
|
||||
self._invoker.services.events.emit_model_load_started(model_config, submodel_type)
|
||||
|
||||
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
|
||||
loaded_model: LoadedModel = implementation(
|
||||
@ -79,40 +71,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
convert_cache=self._convert_cache,
|
||||
).load_model(model_config, submodel_type)
|
||||
|
||||
if context_data:
|
||||
self._emit_load_event(
|
||||
context_data=context_data,
|
||||
model_config=model_config,
|
||||
submodel_type=submodel_type,
|
||||
loaded=True,
|
||||
)
|
||||
if hasattr(self, "_invoker"):
|
||||
self._invoker.services.events.emit_model_load_started(model_config, submodel_type)
|
||||
|
||||
return loaded_model
|
||||
|
||||
def _emit_load_event(
|
||||
self,
|
||||
context_data: InvocationContextData,
|
||||
model_config: AnyModelConfig,
|
||||
loaded: Optional[bool] = False,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
if not self._invoker:
|
||||
return
|
||||
|
||||
if not loaded:
|
||||
self._invoker.services.events.emit_model_load_started(
|
||||
queue_id=context_data.queue_item.queue_id,
|
||||
queue_item_id=context_data.queue_item.item_id,
|
||||
queue_batch_id=context_data.queue_item.batch_id,
|
||||
graph_execution_state_id=context_data.queue_item.session_id,
|
||||
model_config=model_config,
|
||||
submodel_type=submodel_type,
|
||||
)
|
||||
else:
|
||||
self._invoker.services.events.emit_model_load_completed(
|
||||
queue_id=context_data.queue_item.queue_id,
|
||||
queue_item_id=context_data.queue_item.item_id,
|
||||
queue_batch_id=context_data.queue_item.batch_id,
|
||||
graph_execution_state_id=context_data.queue_item.session_id,
|
||||
model_config=model_config,
|
||||
submodel_type=submodel_type,
|
||||
)
|
||||
|
@ -4,11 +4,16 @@ from threading import BoundedSemaphore, Thread
|
||||
from threading import Event as ThreadEvent
|
||||
from typing import Optional
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event as FastAPIEvent
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.events.events_common import (
|
||||
BatchEnqueuedEvent,
|
||||
FastAPIEvent,
|
||||
QueueClearedEvent,
|
||||
QueueEventBase,
|
||||
QueueItemStatusChangedEvent,
|
||||
SessionCanceledEvent,
|
||||
register_events,
|
||||
)
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
@ -31,8 +36,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._poll_now_event = ThreadEvent()
|
||||
self._cancel_event = ThreadEvent()
|
||||
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
|
||||
|
||||
self._thread_limit = thread_limit
|
||||
self._thread_semaphore = BoundedSemaphore(thread_limit)
|
||||
self._polling_interval = polling_interval
|
||||
@ -49,6 +52,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
else None
|
||||
)
|
||||
|
||||
register_events({SessionCanceledEvent, QueueClearedEvent, BatchEnqueuedEvent}, self._on_queue_event)
|
||||
|
||||
self._thread = Thread(
|
||||
name="session_processor",
|
||||
target=self._process,
|
||||
@ -67,30 +72,25 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
def _poll_now(self) -> None:
|
||||
self._poll_now_event.set()
|
||||
|
||||
async def _on_queue_event(self, event: FastAPIEvent) -> None:
|
||||
event_name = event[1]["event"]
|
||||
|
||||
async def _on_queue_event(self, event: FastAPIEvent[QueueEventBase]) -> None:
|
||||
_event_name, payload = event
|
||||
if (
|
||||
event_name == "session_canceled"
|
||||
isinstance(payload, SessionCanceledEvent)
|
||||
and self._queue_item
|
||||
and self._queue_item.item_id == event[1]["data"]["queue_item_id"]
|
||||
and self._queue_item.item_id == payload.item_id
|
||||
):
|
||||
self._cancel_event.set()
|
||||
self._poll_now()
|
||||
elif (
|
||||
event_name == "queue_cleared"
|
||||
isinstance(payload, QueueClearedEvent)
|
||||
and self._queue_item
|
||||
and self._queue_item.queue_id == event[1]["data"]["queue_id"]
|
||||
and self._queue_item.queue_id == payload.queue_id
|
||||
):
|
||||
self._cancel_event.set()
|
||||
self._poll_now()
|
||||
elif event_name == "batch_enqueued":
|
||||
elif isinstance(payload, BatchEnqueuedEvent):
|
||||
self._poll_now()
|
||||
elif event_name == "queue_item_status_changed" and event[1]["data"]["queue_item"]["status"] in [
|
||||
"completed",
|
||||
"failed",
|
||||
"canceled",
|
||||
]:
|
||||
elif isinstance(payload, QueueItemStatusChangedEvent) and payload.status in ("completed", "failed", "canceled"):
|
||||
self._poll_now()
|
||||
|
||||
def resume(self) -> SessionProcessorStatus:
|
||||
@ -139,6 +139,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
poll_now_event.wait(self._polling_interval)
|
||||
continue
|
||||
|
||||
self._invoker.services.events.emit_session_started(self._queue_item)
|
||||
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
||||
cancel_event.clear()
|
||||
|
||||
@ -153,16 +154,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
while self._invocation is not None and not cancel_event.is_set():
|
||||
# get the source node id to provide to clients (the prepared node id is not as useful)
|
||||
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
|
||||
|
||||
# Send starting event
|
||||
self._invoker.services.events.emit_invocation_started(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session_id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
)
|
||||
self._invoker.services.events.emit_invocation_started(self._queue_item, self._invocation)
|
||||
|
||||
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
||||
try:
|
||||
@ -189,19 +181,12 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
# Save outputs and history
|
||||
self._queue_item.session.complete(self._invocation.id, outputs)
|
||||
|
||||
# Send complete event
|
||||
self._invoker.services.events.emit_invocation_complete(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
result=outputs.model_dump(),
|
||||
self._queue_item, self._invocation, outputs
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
# TODO(MM2): Create an event for this
|
||||
# TODO(MM2): I don't think this is ever raised...
|
||||
pass
|
||||
|
||||
except CanceledException:
|
||||
@ -227,30 +212,20 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
)
|
||||
self._invoker.services.logger.error(error)
|
||||
|
||||
# Send error event
|
||||
self._invoker.services.events.emit_invocation_error(
|
||||
queue_batch_id=self._queue_item.session_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
queue_item=self._queue_item,
|
||||
invocation=self._invocation,
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
user_id=None,
|
||||
project_id=None,
|
||||
)
|
||||
pass
|
||||
|
||||
# The session is complete if the all invocations are complete or there was an error
|
||||
if self._queue_item.session.is_complete() or cancel_event.is_set():
|
||||
# Send complete event
|
||||
self._invoker.services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
self._invoker.services.session_queue.set_queue_item_session(
|
||||
self._queue_item.item_id, self._queue_item.session
|
||||
)
|
||||
self._invoker.services.events.emit_session_complete(self._queue_item)
|
||||
# If we are profiling, stop the profiler and dump the profile & stats
|
||||
if self._profiler:
|
||||
profile_path = self._profiler.stop()
|
||||
@ -281,6 +256,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
)
|
||||
# Cancel the queue item
|
||||
if self._queue_item is not None:
|
||||
self._invoker.services.session_queue.set_queue_item_session(
|
||||
self._queue_item.item_id, self._queue_item.session
|
||||
)
|
||||
self._invoker.services.session_queue.cancel_queue_item(
|
||||
self._queue_item.item_id, error=traceback.format_exc()
|
||||
)
|
||||
|
@ -16,6 +16,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
SessionQueueItemDTO,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.services.shared.graph import GraphExecutionState
|
||||
from invokeai.app.services.shared.pagination import CursorPaginatedResults
|
||||
|
||||
|
||||
@ -103,3 +104,8 @@ class SessionQueueBase(ABC):
|
||||
def get_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
"""Gets a session queue item by ID"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
|
||||
"""Sets the session for a session queue item. Use this to update the session state."""
|
||||
pass
|
||||
|
@ -2,10 +2,13 @@ import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event as FastAPIEvent
|
||||
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.events.events_common import (
|
||||
FastAPIEvent,
|
||||
InvocationErrorEvent,
|
||||
SessionCanceledEvent,
|
||||
SessionCompleteEvent,
|
||||
register_events,
|
||||
)
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
@ -27,6 +30,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
calc_session_count,
|
||||
prepare_values_to_insert,
|
||||
)
|
||||
from invokeai.app.services.shared.graph import GraphExecutionState
|
||||
from invokeai.app.services.shared.pagination import CursorPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
@ -41,7 +45,11 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.__invoker = invoker
|
||||
self._set_in_progress_to_canceled()
|
||||
prune_result = self.prune(DEFAULT_QUEUE_ID)
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
|
||||
|
||||
register_events(events={InvocationErrorEvent}, func=self._handle_error_event)
|
||||
register_events(events={SessionCompleteEvent}, func=self._handle_complete_event)
|
||||
register_events(events={SessionCanceledEvent}, func=self._handle_cancel_event)
|
||||
|
||||
if prune_result.deleted > 0:
|
||||
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
|
||||
|
||||
@ -51,51 +59,35 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.__conn = db.conn
|
||||
self.__cursor = self.__conn.cursor()
|
||||
|
||||
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
|
||||
return event[1]["event"] in match_in
|
||||
|
||||
async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent:
|
||||
event_name = event[1]["event"]
|
||||
|
||||
# This was a match statement, but match is not supported on python 3.9
|
||||
if event_name == "graph_execution_state_complete":
|
||||
await self._handle_complete_event(event)
|
||||
elif event_name == "invocation_error":
|
||||
await self._handle_error_event(event)
|
||||
elif event_name == "session_canceled":
|
||||
await self._handle_cancel_event(event)
|
||||
return event
|
||||
|
||||
async def _handle_complete_event(self, event: FastAPIEvent) -> None:
|
||||
async def _handle_complete_event(self, event: FastAPIEvent[SessionCompleteEvent]) -> None:
|
||||
try:
|
||||
item_id = event[1]["data"]["queue_item_id"]
|
||||
# When a queue item has an error, we get an error event, then a completed event.
|
||||
# Mark the queue item completed only if it isn't already marked completed, e.g.
|
||||
# by a previously-handled error event.
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
if queue_item.status not in ["completed", "failed", "canceled"]:
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed")
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
_event_name, payload = event
|
||||
|
||||
async def _handle_error_event(self, event: FastAPIEvent) -> None:
|
||||
queue_item = self.get_queue_item(payload.item_id)
|
||||
if queue_item.status not in ["completed", "failed", "canceled"]:
|
||||
self._set_queue_item_status(item_id=payload.item_id, status="completed")
|
||||
except SessionQueueItemNotFoundError:
|
||||
pass
|
||||
|
||||
async def _handle_error_event(self, event: FastAPIEvent[InvocationErrorEvent]) -> None:
|
||||
try:
|
||||
item_id = event[1]["data"]["queue_item_id"]
|
||||
error = event[1]["data"]["error"]
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
_event_name, payload = event
|
||||
# always set to failed if have an error, even if previously the item was marked completed or canceled
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error)
|
||||
self._set_queue_item_status(item_id=payload.item_id, status="failed", error=payload.error)
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
pass
|
||||
|
||||
async def _handle_cancel_event(self, event: FastAPIEvent) -> None:
|
||||
async def _handle_cancel_event(self, event: FastAPIEvent[SessionCanceledEvent]) -> None:
|
||||
try:
|
||||
item_id = event[1]["data"]["queue_item_id"]
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
_event_name, payload = event
|
||||
queue_item = self.get_queue_item(payload.item_id)
|
||||
if queue_item.status not in ["completed", "failed", "canceled"]:
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled")
|
||||
self._set_queue_item_status(item_id=payload.item_id, status="canceled")
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
pass
|
||||
|
||||
def _set_in_progress_to_canceled(self) -> None:
|
||||
"""
|
||||
@ -292,11 +284,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(
|
||||
session_queue_item=queue_item,
|
||||
batch_status=batch_status,
|
||||
queue_status=queue_status,
|
||||
)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(queue_item, batch_status, queue_status)
|
||||
return queue_item
|
||||
|
||||
def is_empty(self, queue_id: str) -> IsEmptyResult:
|
||||
@ -429,12 +417,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
if queue_item.status not in ["canceled", "failed", "completed"]:
|
||||
status = "failed" if error is not None else "canceled"
|
||||
queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error) # type: ignore [arg-type] # mypy seems to not narrow the Literals here
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
graph_execution_state_id=queue_item.session_id,
|
||||
)
|
||||
self.__invoker.services.events.emit_session_canceled(queue_item)
|
||||
return queue_item
|
||||
|
||||
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
|
||||
@ -470,18 +453,11 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=current_queue_item.item_id,
|
||||
queue_id=current_queue_item.queue_id,
|
||||
queue_batch_id=current_queue_item.batch_id,
|
||||
graph_execution_state_id=current_queue_item.session_id,
|
||||
)
|
||||
self.__invoker.services.events.emit_session_canceled(current_queue_item)
|
||||
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(
|
||||
session_queue_item=current_queue_item,
|
||||
batch_status=batch_status,
|
||||
queue_status=queue_status,
|
||||
current_queue_item, batch_status, queue_status
|
||||
)
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
@ -521,18 +497,11 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=current_queue_item.item_id,
|
||||
queue_id=current_queue_item.queue_id,
|
||||
queue_batch_id=current_queue_item.batch_id,
|
||||
graph_execution_state_id=current_queue_item.session_id,
|
||||
)
|
||||
self.__invoker.services.events.emit_session_canceled(current_queue_item)
|
||||
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(
|
||||
session_queue_item=current_queue_item,
|
||||
batch_status=batch_status,
|
||||
queue_status=queue_status,
|
||||
current_queue_item, batch_status, queue_status
|
||||
)
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
@ -562,6 +531,29 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
||||
return SessionQueueItem.queue_item_from_dict(dict(result))
|
||||
|
||||
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
|
||||
try:
|
||||
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
|
||||
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
|
||||
# during execution.
|
||||
session_json = session.model_dump_json(warnings=False, exclude_none=True)
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
SET session = ?
|
||||
WHERE item_id = ?
|
||||
""",
|
||||
(session_json, item_id),
|
||||
)
|
||||
self.__conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return self.get_queue_item(item_id)
|
||||
|
||||
def list_queue_items(
|
||||
self,
|
||||
queue_id: str,
|
||||
|
@ -353,11 +353,11 @@ class ModelsInterface(InvocationContextInterface):
|
||||
|
||||
if isinstance(identifier, str):
|
||||
model = self._services.model_manager.store.get_model(identifier)
|
||||
return self._services.model_manager.load.load_model(model, submodel_type, self._data)
|
||||
return self._services.model_manager.load.load_model(model, submodel_type)
|
||||
else:
|
||||
_submodel_type = submodel_type or identifier.submodel_type
|
||||
model = self._services.model_manager.store.get_model(identifier.key)
|
||||
return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
|
||||
return self._services.model_manager.load.load_model(model, _submodel_type)
|
||||
|
||||
def load_by_attrs(
|
||||
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
|
||||
@ -382,7 +382,7 @@ class ModelsInterface(InvocationContextInterface):
|
||||
if len(configs) > 1:
|
||||
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
|
||||
|
||||
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
|
||||
return self._services.model_manager.load.load_model(configs[0], submodel_type)
|
||||
|
||||
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
|
||||
"""Gets a model's config.
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
from typing import TYPE_CHECKING, Callable, Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
@ -13,8 +13,36 @@ if TYPE_CHECKING:
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
|
||||
# fast latents preview matrix for sdxl
|
||||
# generated by @StAlKeR7779
|
||||
SDXL_LATENT_RGB_FACTORS = [
|
||||
# R G B
|
||||
[0.3816, 0.4930, 0.5320],
|
||||
[-0.3753, 0.1631, 0.1739],
|
||||
[0.1770, 0.3588, -0.2048],
|
||||
[-0.4350, -0.2644, -0.4289],
|
||||
]
|
||||
SDXL_SMOOTH_MATRIX = [
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
[0.0964, 0.4711, 0.0964],
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
]
|
||||
|
||||
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
|
||||
# origingally adapted from code by @erucipe and @keturn here:
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||
# these updated numbers for v1.5 are from @torridgristle
|
||||
SD1_5_LATENT_RGB_FACTORS = [
|
||||
# R G B
|
||||
[0.3444, 0.1385, 0.0670], # L1
|
||||
[0.1247, 0.4027, 0.1494], # L2
|
||||
[-0.3192, 0.2513, 0.2103], # L3
|
||||
[-0.1307, -0.1874, -0.7445], # L4
|
||||
]
|
||||
|
||||
|
||||
def sample_to_lowres_estimated_image(
|
||||
samples: torch.Tensor, latent_rgb_factors: torch.Tensor, smooth_matrix: Optional[torch.Tensor] = None
|
||||
):
|
||||
latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors
|
||||
|
||||
if smooth_matrix is not None:
|
||||
@ -47,64 +75,12 @@ def stable_diffusion_step_callback(
|
||||
else:
|
||||
sample = intermediate_state.latents
|
||||
|
||||
# TODO: This does not seem to be needed any more?
|
||||
# # txt2img provides a Tensor in the step_callback
|
||||
# # img2img provides a PipelineIntermediateState
|
||||
# if isinstance(sample, PipelineIntermediateState):
|
||||
# # this was an img2img
|
||||
# print('img2img')
|
||||
# latents = sample.latents
|
||||
# step = sample.step
|
||||
# else:
|
||||
# print('txt2img')
|
||||
# latents = sample
|
||||
# step = intermediate_state.step
|
||||
|
||||
# TODO: only output a preview image when requested
|
||||
|
||||
if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
|
||||
# fast latents preview matrix for sdxl
|
||||
# generated by @StAlKeR7779
|
||||
sdxl_latent_rgb_factors = torch.tensor(
|
||||
[
|
||||
# R G B
|
||||
[0.3816, 0.4930, 0.5320],
|
||||
[-0.3753, 0.1631, 0.1739],
|
||||
[0.1770, 0.3588, -0.2048],
|
||||
[-0.4350, -0.2644, -0.4289],
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
sdxl_smooth_matrix = torch.tensor(
|
||||
[
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
[0.0964, 0.4711, 0.0964],
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
sdxl_latent_rgb_factors = torch.tensor(SDXL_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
|
||||
sdxl_smooth_matrix = torch.tensor(SDXL_SMOOTH_MATRIX, dtype=sample.dtype, device=sample.device)
|
||||
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
|
||||
else:
|
||||
# origingally adapted from code by @erucipe and @keturn here:
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||
|
||||
# these updated numbers for v1.5 are from @torridgristle
|
||||
v1_5_latent_rgb_factors = torch.tensor(
|
||||
[
|
||||
# R G B
|
||||
[0.3444, 0.1385, 0.0670], # L1
|
||||
[0.1247, 0.4027, 0.1494], # L2
|
||||
[-0.3192, 0.2513, 0.2103], # L3
|
||||
[-0.1307, -0.1874, -0.7445], # L4
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
v1_5_latent_rgb_factors = torch.tensor(SD1_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
|
||||
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
|
||||
|
||||
(width, height) = image.size
|
||||
@ -113,15 +89,9 @@ def stable_diffusion_step_callback(
|
||||
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
events.emit_generator_progress(
|
||||
queue_id=context_data.queue_item.queue_id,
|
||||
queue_item_id=context_data.queue_item.item_id,
|
||||
queue_batch_id=context_data.queue_item.batch_id,
|
||||
graph_execution_state_id=context_data.queue_item.session_id,
|
||||
node_id=context_data.invocation.id,
|
||||
source_node_id=context_data.source_invocation_id,
|
||||
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
|
||||
step=intermediate_state.step,
|
||||
order=intermediate_state.order,
|
||||
total_steps=intermediate_state.total_steps,
|
||||
events.emit_invocation_denoise_progress(
|
||||
context_data.queue_item,
|
||||
context_data.invocation,
|
||||
intermediate_state,
|
||||
ProgressImage(dataURL=dataURL, width=width, height=height),
|
||||
)
|
||||
|
@ -35,28 +35,23 @@ import { addImageUploadedFulfilledListener } from 'app/store/middleware/listener
|
||||
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
|
||||
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
|
||||
import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged';
|
||||
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
|
||||
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
|
||||
import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
|
||||
import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress';
|
||||
import { addGraphExecutionStateCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete';
|
||||
import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete';
|
||||
import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError';
|
||||
import { addInvocationRetrievalErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError';
|
||||
import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted';
|
||||
import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall';
|
||||
import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad';
|
||||
import { addSocketQueueItemStatusChangedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged';
|
||||
import { addSessionRetrievalErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError';
|
||||
import { addSocketSubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed';
|
||||
import { addSocketUnsubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed';
|
||||
import { addStagingAreaImageSavedListener } from 'app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved';
|
||||
import { addUpdateAllNodesRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested';
|
||||
import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
|
||||
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
|
||||
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
|
||||
@ -110,12 +105,8 @@ addInvocationErrorEventListener(startAppListening);
|
||||
addInvocationStartedEventListener(startAppListening);
|
||||
addSocketConnectedEventListener(startAppListening);
|
||||
addSocketDisconnectedEventListener(startAppListening);
|
||||
addSocketSubscribedEventListener(startAppListening);
|
||||
addSocketUnsubscribedEventListener(startAppListening);
|
||||
addModelLoadEventListener(startAppListening);
|
||||
addModelInstallEventListener(startAppListening);
|
||||
addSessionRetrievalErrorEventListener(startAppListening);
|
||||
addInvocationRetrievalErrorEventListener(startAppListening);
|
||||
addSocketQueueItemStatusChangedEventListener(startAppListening);
|
||||
addBulkDownloadListeners(startAppListening);
|
||||
|
||||
|
@ -6,8 +6,8 @@ import { toast } from 'common/util/toast';
|
||||
import { t } from 'i18next';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import {
|
||||
socketBulkDownloadCompleted,
|
||||
socketBulkDownloadFailed,
|
||||
socketBulkDownloadComplete,
|
||||
socketBulkDownloadError,
|
||||
socketBulkDownloadStarted,
|
||||
} from 'services/events/actions';
|
||||
|
||||
@ -56,7 +56,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketBulkDownloadCompleted,
|
||||
actionCreator: socketBulkDownloadComplete,
|
||||
effect: async (action) => {
|
||||
log.debug(action.payload.data, 'Bulk download preparation completed');
|
||||
|
||||
@ -89,7 +89,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketBulkDownloadFailed,
|
||||
actionCreator: socketBulkDownloadError,
|
||||
effect: async (action) => {
|
||||
log.debug(action.payload.data, 'Bulk download preparation failed');
|
||||
|
||||
|
@ -133,8 +133,8 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
|
||||
const [invocationCompleteAction] = await take(
|
||||
(action): action is ReturnType<typeof socketInvocationComplete> =>
|
||||
socketInvocationComplete.match(action) &&
|
||||
action.payload.data.queue_batch_id === enqueueResult.batch.batch_id &&
|
||||
action.payload.data.source_node_id === processorNode.id
|
||||
action.payload.data.batch_id === enqueueResult.batch.batch_id &&
|
||||
action.payload.data.invocation_source_id === processorNode.id
|
||||
);
|
||||
|
||||
// We still have to check the output type
|
||||
|
@ -69,8 +69,8 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL
|
||||
const [invocationCompleteAction] = await take(
|
||||
(action): action is ReturnType<typeof socketInvocationComplete> =>
|
||||
socketInvocationComplete.match(action) &&
|
||||
action.payload.data.queue_batch_id === enqueueResult.batch.batch_id &&
|
||||
action.payload.data.source_node_id === nodeId
|
||||
action.payload.data.batch_id === enqueueResult.batch.batch_id &&
|
||||
action.payload.data.invocation_source_id === nodeId
|
||||
);
|
||||
|
||||
// We still have to check the output type
|
||||
|
@ -12,8 +12,8 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis
|
||||
actionCreator: socketGeneratorProgress,
|
||||
effect: (action) => {
|
||||
log.trace(action.payload, `Generator progress`);
|
||||
const { source_node_id, step, total_steps, progress_image } = action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||
const { invocation_source_id, step, total_steps, progress_image } = action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.IN_PROGRESS;
|
||||
nes.progress = (step + 1) / total_steps;
|
||||
|
@ -29,12 +29,12 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
actionCreator: socketInvocationComplete,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const { data } = action.payload;
|
||||
log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`);
|
||||
log.debug({ data: parseify(data) }, `Invocation complete (${data.invocation_type})`);
|
||||
|
||||
const { result, node, queue_batch_id, source_node_id } = data;
|
||||
const { result, invocation_source_id } = data;
|
||||
// This complete event has an associated image output
|
||||
if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) {
|
||||
const { image_name } = result.image;
|
||||
if (isImageOutput(data.result) && !nodeTypeDenylist.includes(data.invocation_type)) {
|
||||
const { image_name } = data.result.image;
|
||||
const { canvas, gallery } = getState();
|
||||
|
||||
// This populates the `getImageDTO` cache
|
||||
@ -48,7 +48,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
imageDTORequest.unsubscribe();
|
||||
|
||||
// Add canvas images to the staging area
|
||||
if (canvas.batchIds.includes(queue_batch_id) && data.source_node_id === CANVAS_OUTPUT) {
|
||||
if (canvas.batchIds.includes(data.batch_id) && data.invocation_source_id === CANVAS_OUTPUT) {
|
||||
dispatch(addImageToStagingArea(imageDTO));
|
||||
}
|
||||
|
||||
@ -114,7 +114,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
}
|
||||
}
|
||||
|
||||
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.COMPLETED;
|
||||
if (nes.progress !== null) {
|
||||
|
@ -11,9 +11,9 @@ export const addInvocationErrorEventListener = (startAppListening: AppStartListe
|
||||
startAppListening({
|
||||
actionCreator: socketInvocationError,
|
||||
effect: (action) => {
|
||||
log.error(action.payload, `Invocation error (${action.payload.data.node.type})`);
|
||||
const { source_node_id } = action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||
log.error(action.payload, `Invocation error (${action.payload.data.invocation_type})`);
|
||||
const { invocation_source_id } = action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.FAILED;
|
||||
nes.error = action.payload.data.error;
|
||||
|
@ -1,14 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { socketInvocationRetrievalError } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
|
||||
export const addInvocationRetrievalErrorEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: socketInvocationRetrievalError,
|
||||
effect: (action) => {
|
||||
log.error(action.payload, `Invocation retrieval error (${action.payload.data.graph_execution_state_id})`);
|
||||
},
|
||||
});
|
||||
};
|
@ -11,9 +11,9 @@ export const addInvocationStartedEventListener = (startAppListening: AppStartLis
|
||||
startAppListening({
|
||||
actionCreator: socketInvocationStarted,
|
||||
effect: (action) => {
|
||||
log.debug(action.payload, `Invocation started (${action.payload.data.node.type})`);
|
||||
const { source_node_id } = action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||
log.debug(action.payload, `Invocation started (${action.payload.data.invocation_type})`);
|
||||
const { invocation_source_id } = action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.IN_PROGRESS;
|
||||
upsertExecutionState(nes.nodeId, nes);
|
||||
|
@ -3,14 +3,14 @@ import { api, LIST_TAG } from 'services/api';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
import {
|
||||
socketModelInstallCancelled,
|
||||
socketModelInstallCompleted,
|
||||
socketModelInstallDownloading,
|
||||
socketModelInstallComplete,
|
||||
socketModelInstallDownloadProgress,
|
||||
socketModelInstallError,
|
||||
} from 'services/events/actions';
|
||||
|
||||
export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: socketModelInstallDownloading,
|
||||
actionCreator: socketModelInstallDownloadProgress,
|
||||
effect: async (action, { dispatch }) => {
|
||||
const { bytes, total_bytes, id } = action.payload.data;
|
||||
|
||||
@ -29,7 +29,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketModelInstallCompleted,
|
||||
actionCreator: socketModelInstallComplete,
|
||||
effect: (action, { dispatch }) => {
|
||||
const { id } = action.payload.data;
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { socketModelLoadCompleted, socketModelLoadStarted } from 'services/events/actions';
|
||||
import { socketModelLoadComplete, socketModelLoadStarted } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
|
||||
@ -8,10 +8,11 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
|
||||
startAppListening({
|
||||
actionCreator: socketModelLoadStarted,
|
||||
effect: (action) => {
|
||||
const { model_config, submodel_type } = action.payload.data;
|
||||
const { name, base, type } = model_config;
|
||||
const { config, submodel_type } = action.payload.data;
|
||||
const { name, base, type } = config;
|
||||
|
||||
const extras: string[] = [base, type];
|
||||
|
||||
if (submodel_type) {
|
||||
extras.push(submodel_type);
|
||||
}
|
||||
@ -23,10 +24,10 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketModelLoadCompleted,
|
||||
actionCreator: socketModelLoadComplete,
|
||||
effect: (action) => {
|
||||
const { model_config, submodel_type } = action.payload.data;
|
||||
const { name, base, type } = model_config;
|
||||
const { config, submodel_type } = action.payload.data;
|
||||
const { name, base, type } = config;
|
||||
|
||||
const extras: string[] = [base, type];
|
||||
if (submodel_type) {
|
||||
|
@ -14,16 +14,23 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
|
||||
actionCreator: socketQueueItemStatusChanged,
|
||||
effect: async (action, { dispatch }) => {
|
||||
// we've got new status for the queue item, batch and queue
|
||||
const { queue_item, batch_status, queue_status } = action.payload.data;
|
||||
const { item_id, status, started_at, updated_at, error, completed_at, batch_status, queue_status } =
|
||||
action.payload.data;
|
||||
|
||||
log.debug(action.payload, `Queue item ${queue_item.item_id} status updated: ${queue_item.status}`);
|
||||
log.debug(action.payload, `Queue item ${item_id} status updated: ${status}`);
|
||||
|
||||
// Update this specific queue item in the list of queue items (this is the queue item DTO, without the session)
|
||||
dispatch(
|
||||
queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => {
|
||||
queueItemsAdapter.updateOne(draft, {
|
||||
id: String(queue_item.item_id),
|
||||
changes: queue_item,
|
||||
id: String(item_id),
|
||||
changes: {
|
||||
status,
|
||||
started_at,
|
||||
updated_at: updated_at ?? undefined,
|
||||
error,
|
||||
completed_at: completed_at ?? undefined,
|
||||
},
|
||||
});
|
||||
})
|
||||
);
|
||||
@ -43,23 +50,18 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
|
||||
queueApi.util.updateQueryData('getBatchStatus', { batch_id: batch_status.batch_id }, () => batch_status)
|
||||
);
|
||||
|
||||
// Update the queue item status (this is the full queue item, including the session)
|
||||
dispatch(
|
||||
queueApi.util.updateQueryData('getQueueItem', queue_item.item_id, (draft) => {
|
||||
if (!draft) {
|
||||
return;
|
||||
}
|
||||
Object.assign(draft, queue_item);
|
||||
})
|
||||
);
|
||||
|
||||
// Invalidate caches for things we cannot update
|
||||
// TODO: technically, we could possibly update the current session queue item, but feels safer to just request it again
|
||||
dispatch(
|
||||
queueApi.util.invalidateTags(['CurrentSessionQueueItem', 'NextSessionQueueItem', 'InvocationCacheStatus'])
|
||||
queueApi.util.invalidateTags([
|
||||
'CurrentSessionQueueItem',
|
||||
'NextSessionQueueItem',
|
||||
'InvocationCacheStatus',
|
||||
{ type: 'SessionQueueItem', id: item_id },
|
||||
])
|
||||
);
|
||||
|
||||
if (['in_progress'].includes(action.payload.data.queue_item.status)) {
|
||||
if (['in_progress'].includes(action.payload.data.status)) {
|
||||
forEach($nodeExecutionStates.get(), (nes) => {
|
||||
if (!nes) {
|
||||
return;
|
||||
|
@ -1,14 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { socketSessionRetrievalError } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
|
||||
export const addSessionRetrievalErrorEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: socketSessionRetrievalError,
|
||||
effect: (action) => {
|
||||
log.error(action.payload, `Session retrieval error (${action.payload.data.graph_execution_state_id})`);
|
||||
},
|
||||
});
|
||||
};
|
@ -1,14 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { socketSubscribedSession } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
|
||||
export const addSocketSubscribedEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: socketSubscribedSession,
|
||||
effect: (action) => {
|
||||
log.debug(action.payload, 'Subscribed');
|
||||
},
|
||||
});
|
||||
};
|
@ -1,13 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { socketUnsubscribedSession } from 'services/events/actions';
|
||||
const log = logger('socketio');
|
||||
|
||||
export const addSocketUnsubscribedEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: socketUnsubscribedSession,
|
||||
effect: (action) => {
|
||||
log.debug(action.payload, 'Unsubscribed');
|
||||
},
|
||||
});
|
||||
};
|
@ -613,7 +613,7 @@ export const canvasSlice = createSlice({
|
||||
state.batchIds = state.batchIds.filter((id) => id !== batch_status.batch_id);
|
||||
}
|
||||
|
||||
const queueItemStatus = action.payload.data.queue_item.status;
|
||||
const queueItemStatus = action.payload.data.status;
|
||||
if (queueItemStatus === 'canceled' || queueItemStatus === 'failed') {
|
||||
resetStagingAreaIfEmpty(state);
|
||||
}
|
||||
|
@ -616,24 +616,12 @@ export const controlLayersSlice = createSlice({
|
||||
iiLayerAdded: {
|
||||
reducer: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => {
|
||||
const { layerId, imageDTO } = action.payload;
|
||||
|
||||
// Retain opacity and denoising strength of existing initial image layer if exists
|
||||
let opacity = 1;
|
||||
let denoisingStrength = 0.75;
|
||||
const iiLayer = state.layers.find((l) => l.id === layerId);
|
||||
if (iiLayer) {
|
||||
assert(isInitialImageLayer(iiLayer));
|
||||
opacity = iiLayer.opacity;
|
||||
denoisingStrength = iiLayer.denoisingStrength;
|
||||
}
|
||||
|
||||
// Highlander! There can be only one!
|
||||
state.layers = state.layers.filter((l) => (isInitialImageLayer(l) ? false : true));
|
||||
|
||||
const layer: InitialImageLayer = {
|
||||
id: layerId,
|
||||
type: 'initial_image_layer',
|
||||
opacity,
|
||||
opacity: 1,
|
||||
x: 0,
|
||||
y: 0,
|
||||
bbox: null,
|
||||
@ -641,7 +629,7 @@ export const controlLayersSlice = createSlice({
|
||||
isEnabled: true,
|
||||
image: imageDTO ? imageDTOToImageWithDims(imageDTO) : null,
|
||||
isSelected: true,
|
||||
denoisingStrength,
|
||||
denoisingStrength: 0.75,
|
||||
};
|
||||
state.layers.push(layer);
|
||||
exclusivelySelectLayer(state, layer.id);
|
||||
|
@ -1,8 +1,7 @@
|
||||
import type { UseToastOptions } from '@invoke-ai/ui-library';
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { calculateStepPercentage } from 'features/system/util/calculateStepPercentage';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { t } from 'i18next';
|
||||
import { startCase } from 'lodash-es';
|
||||
@ -14,12 +13,10 @@ import {
|
||||
socketGraphExecutionStateComplete,
|
||||
socketInvocationComplete,
|
||||
socketInvocationError,
|
||||
socketInvocationRetrievalError,
|
||||
socketInvocationStarted,
|
||||
socketModelLoadCompleted,
|
||||
socketModelLoadComplete,
|
||||
socketModelLoadStarted,
|
||||
socketQueueItemStatusChanged,
|
||||
socketSessionRetrievalError,
|
||||
} from 'services/events/actions';
|
||||
|
||||
import type { Language, SystemState } from './types';
|
||||
@ -110,20 +107,12 @@ export const systemSlice = createSlice({
|
||||
* Generator Progress
|
||||
*/
|
||||
builder.addCase(socketGeneratorProgress, (state, action) => {
|
||||
const {
|
||||
step,
|
||||
total_steps,
|
||||
order,
|
||||
progress_image,
|
||||
graph_execution_state_id: session_id,
|
||||
queue_batch_id: batch_id,
|
||||
} = action.payload.data;
|
||||
const { step, total_steps, progress_image, session_id, batch_id, percentage } = action.payload.data;
|
||||
|
||||
state.denoiseProgress = {
|
||||
step,
|
||||
total_steps,
|
||||
order,
|
||||
percentage: calculateStepPercentage(step, total_steps, order),
|
||||
percentage,
|
||||
progress_image,
|
||||
session_id,
|
||||
batch_id,
|
||||
@ -152,12 +141,12 @@ export const systemSlice = createSlice({
|
||||
state.status = 'LOADING_MODEL';
|
||||
});
|
||||
|
||||
builder.addCase(socketModelLoadCompleted, (state) => {
|
||||
builder.addCase(socketModelLoadComplete, (state) => {
|
||||
state.status = 'CONNECTED';
|
||||
});
|
||||
|
||||
builder.addCase(socketQueueItemStatusChanged, (state, action) => {
|
||||
if (['completed', 'canceled', 'failed'].includes(action.payload.data.queue_item.status)) {
|
||||
if (['completed', 'canceled', 'failed'].includes(action.payload.data.status)) {
|
||||
state.status = 'CONNECTED';
|
||||
state.denoiseProgress = null;
|
||||
}
|
||||
@ -168,7 +157,7 @@ export const systemSlice = createSlice({
|
||||
/**
|
||||
* Any server error
|
||||
*/
|
||||
builder.addMatcher(isAnyServerError, (state, action) => {
|
||||
builder.addCase(socketInvocationError, (state, action) => {
|
||||
state.toastQueue.push(
|
||||
makeToast({
|
||||
title: t('toast.serverError'),
|
||||
@ -194,8 +183,6 @@ export const {
|
||||
setShouldEnableInformationalPopovers,
|
||||
} = systemSlice.actions;
|
||||
|
||||
const isAnyServerError = isAnyOf(socketInvocationError, socketSessionRetrievalError, socketInvocationRetrievalError);
|
||||
|
||||
export const selectSystemSlice = (state: RootState) => state.system;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
|
@ -11,7 +11,6 @@ type DenoiseProgress = {
|
||||
progress_image: ProgressImage | null | undefined;
|
||||
step: number;
|
||||
total_steps: number;
|
||||
order: number;
|
||||
percentage: number;
|
||||
};
|
||||
|
||||
|
@ -1,13 +0,0 @@
|
||||
export const calculateStepPercentage = (step: number, total_steps: number, order: number) => {
|
||||
if (total_steps === 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// we add one extra to step so that the progress bar will be full when denoise completes
|
||||
|
||||
if (order === 2) {
|
||||
return Math.floor((step + 1 + 1) / 2) / Math.floor((total_steps + 1) / 2);
|
||||
}
|
||||
|
||||
return (step + 1 + 1) / (total_steps + 1);
|
||||
};
|
File diff suppressed because one or more lines are too long
@ -39,7 +39,6 @@ export type OffsetPaginatedResults_ImageDTO_ = S['OffsetPaginatedResults_ImageDT
|
||||
|
||||
// Models
|
||||
export type ModelType = S['ModelType'];
|
||||
export type SubModelType = S['SubModelType'];
|
||||
export type BaseModelType = S['BaseModelType'];
|
||||
|
||||
// Model Configs
|
||||
|
@ -1,22 +1,29 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import type {
|
||||
BulkDownloadCompletedEvent,
|
||||
BulkDownloadCompleteEvent,
|
||||
BulkDownloadFailedEvent,
|
||||
BulkDownloadStartedEvent,
|
||||
GeneratorProgressEvent,
|
||||
GraphExecutionStateCompleteEvent,
|
||||
DownloadCancelledEvent,
|
||||
DownloadCompleteEvent,
|
||||
DownloadErrorEvent,
|
||||
DownloadProgressEvent,
|
||||
DownloadStartedEvent,
|
||||
InvocationCompleteEvent,
|
||||
InvocationDenoiseProgressEvent,
|
||||
InvocationErrorEvent,
|
||||
InvocationRetrievalErrorEvent,
|
||||
InvocationStartedEvent,
|
||||
ModelInstallCancelledEvent,
|
||||
ModelInstallCompletedEvent,
|
||||
ModelInstallDownloadingEvent,
|
||||
ModelInstallCompleteEvent,
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallErrorEvent,
|
||||
ModelLoadCompletedEvent,
|
||||
ModelInstallStartedEvent,
|
||||
ModelLoadCompleteEvent,
|
||||
ModelLoadStartedEvent,
|
||||
QueueItemStatusChangedEvent,
|
||||
SessionRetrievalErrorEvent,
|
||||
SessionCanceledEvent,
|
||||
SessionCompleteEvent,
|
||||
SessionStartedEvent,
|
||||
} from 'services/events/types';
|
||||
|
||||
// Create actions for each socket
|
||||
@ -26,12 +33,6 @@ export const socketConnected = createAction('socket/socketConnected');
|
||||
|
||||
export const socketDisconnected = createAction('socket/socketDisconnected');
|
||||
|
||||
export const socketSubscribedSession = createAction<{
|
||||
sessionId: string;
|
||||
}>('socket/socketSubscribedSession');
|
||||
|
||||
export const socketUnsubscribedSession = createAction<{ sessionId: string }>('socket/socketUnsubscribedSession');
|
||||
|
||||
export const socketInvocationStarted = createAction<{
|
||||
data: InvocationStartedEvent;
|
||||
}>('socket/socketInvocationStarted');
|
||||
@ -44,29 +45,65 @@ export const socketInvocationError = createAction<{
|
||||
data: InvocationErrorEvent;
|
||||
}>('socket/socketInvocationError');
|
||||
|
||||
export const socketSessionStarted = createAction<{
|
||||
data: SessionStartedEvent;
|
||||
}>('socket/socketSessionStarted');
|
||||
|
||||
export const socketGraphExecutionStateComplete = createAction<{
|
||||
data: GraphExecutionStateCompleteEvent;
|
||||
data: SessionCompleteEvent;
|
||||
}>('socket/socketGraphExecutionStateComplete');
|
||||
|
||||
export const socketSessionCanceled = createAction<{
|
||||
data: SessionCanceledEvent;
|
||||
}>('socket/socketSessionCanceled');
|
||||
|
||||
export const socketGeneratorProgress = createAction<{
|
||||
data: GeneratorProgressEvent;
|
||||
data: InvocationDenoiseProgressEvent;
|
||||
}>('socket/socketGeneratorProgress');
|
||||
|
||||
export const socketModelLoadStarted = createAction<{
|
||||
data: ModelLoadStartedEvent;
|
||||
}>('socket/socketModelLoadStarted');
|
||||
|
||||
export const socketModelLoadCompleted = createAction<{
|
||||
data: ModelLoadCompletedEvent;
|
||||
}>('socket/socketModelLoadCompleted');
|
||||
export const socketModelLoadComplete = createAction<{
|
||||
data: ModelLoadCompleteEvent;
|
||||
}>('socket/socketModelLoadComplete');
|
||||
|
||||
export const socketModelInstallDownloading = createAction<{
|
||||
data: ModelInstallDownloadingEvent;
|
||||
}>('socket/socketModelInstallDownloading');
|
||||
export const socketDownloadStarted = createAction<{
|
||||
data: DownloadStartedEvent;
|
||||
}>('socket/socketDownloadStarted');
|
||||
|
||||
export const socketModelInstallCompleted = createAction<{
|
||||
data: ModelInstallCompletedEvent;
|
||||
}>('socket/socketModelInstallCompleted');
|
||||
export const socketDownloadProgress = createAction<{
|
||||
data: DownloadProgressEvent;
|
||||
}>('socket/socketDownloadProgress');
|
||||
|
||||
export const socketDownloadComplete = createAction<{
|
||||
data: DownloadCompleteEvent;
|
||||
}>('socket/socketDownloadComplete');
|
||||
|
||||
export const socketDownloadCancelled = createAction<{
|
||||
data: DownloadCancelledEvent;
|
||||
}>('socket/socketDownloadCancelled');
|
||||
|
||||
export const socketDownloadError = createAction<{
|
||||
data: DownloadErrorEvent;
|
||||
}>('socket/socketDownloadError');
|
||||
|
||||
export const socketModelInstallStarted = createAction<{
|
||||
data: ModelInstallStartedEvent;
|
||||
}>('socket/socketModelInstallStarted');
|
||||
|
||||
export const socketModelInstallDownloadProgress = createAction<{
|
||||
data: ModelInstallDownloadProgressEvent;
|
||||
}>('socket/socketModelInstallDownloadProgress');
|
||||
|
||||
export const socketModelInstallDownloadsComplete = createAction<{
|
||||
data: ModelInstallDownloadsCompleteEvent;
|
||||
}>('socket/socketModelInstallDownloadsComplete');
|
||||
|
||||
export const socketModelInstallComplete = createAction<{
|
||||
data: ModelInstallCompleteEvent;
|
||||
}>('socket/socketModelInstallComplete');
|
||||
|
||||
export const socketModelInstallError = createAction<{
|
||||
data: ModelInstallErrorEvent;
|
||||
@ -76,14 +113,6 @@ export const socketModelInstallCancelled = createAction<{
|
||||
data: ModelInstallCancelledEvent;
|
||||
}>('socket/socketModelInstallCancelled');
|
||||
|
||||
export const socketSessionRetrievalError = createAction<{
|
||||
data: SessionRetrievalErrorEvent;
|
||||
}>('socket/socketSessionRetrievalError');
|
||||
|
||||
export const socketInvocationRetrievalError = createAction<{
|
||||
data: InvocationRetrievalErrorEvent;
|
||||
}>('socket/socketInvocationRetrievalError');
|
||||
|
||||
export const socketQueueItemStatusChanged = createAction<{
|
||||
data: QueueItemStatusChangedEvent;
|
||||
}>('socket/socketQueueItemStatusChanged');
|
||||
@ -92,10 +121,10 @@ export const socketBulkDownloadStarted = createAction<{
|
||||
data: BulkDownloadStartedEvent;
|
||||
}>('socket/socketBulkDownloadStarted');
|
||||
|
||||
export const socketBulkDownloadCompleted = createAction<{
|
||||
data: BulkDownloadCompletedEvent;
|
||||
}>('socket/socketBulkDownloadCompleted');
|
||||
export const socketBulkDownloadComplete = createAction<{
|
||||
data: BulkDownloadCompleteEvent;
|
||||
}>('socket/socketBulkDownloadComplete');
|
||||
|
||||
export const socketBulkDownloadFailed = createAction<{
|
||||
export const socketBulkDownloadError = createAction<{
|
||||
data: BulkDownloadFailedEvent;
|
||||
}>('socket/socketBulkDownloadFailed');
|
||||
}>('socket/socketBulkDownloadError');
|
||||
|
@ -1,275 +1,75 @@
|
||||
import type { components } from 'services/api/schema';
|
||||
import type { AnyModelConfig, Graph, GraphExecutionState, SubModelType } from 'services/api/types';
|
||||
|
||||
/**
|
||||
* A progress image, we get one for each step in the generation
|
||||
*/
|
||||
export type ProgressImage = {
|
||||
dataURL: string;
|
||||
width: number;
|
||||
height: number;
|
||||
};
|
||||
import type { Graph, GraphExecutionState, S } from 'services/api/types';
|
||||
|
||||
export type AnyInvocation = NonNullable<NonNullable<Graph['nodes']>[string]>;
|
||||
|
||||
export type AnyResult = NonNullable<GraphExecutionState['results'][string]>;
|
||||
|
||||
type BaseNode = {
|
||||
id: string;
|
||||
type: string;
|
||||
[key: string]: AnyInvocation[keyof AnyInvocation];
|
||||
};
|
||||
export type ModelLoadStartedEvent = S['ModelLoadStartedEvent'];
|
||||
export type ModelLoadCompleteEvent = S['ModelLoadCompleteEvent'];
|
||||
|
||||
export type ModelLoadStartedEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: number;
|
||||
queue_batch_id: string;
|
||||
graph_execution_state_id: string;
|
||||
model_config: AnyModelConfig;
|
||||
submodel_type?: SubModelType | null;
|
||||
};
|
||||
export type InvocationStartedEvent = S['InvocationStartedEvent'];
|
||||
export type InvocationDenoiseProgressEvent = S['InvocationDenoiseProgressEvent'];
|
||||
export type InvocationCompleteEvent = Omit<S['InvocationCompleteEvent'], 'result'> & { result: AnyResult };
|
||||
export type InvocationErrorEvent = S['InvocationErrorEvent'];
|
||||
export type ProgressImage = InvocationDenoiseProgressEvent['progress_image'];
|
||||
|
||||
export type ModelLoadCompletedEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: number;
|
||||
queue_batch_id: string;
|
||||
graph_execution_state_id: string;
|
||||
model_config: AnyModelConfig;
|
||||
submodel_type?: SubModelType | null;
|
||||
};
|
||||
export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent'];
|
||||
export type ModelInstallDownloadsCompleteEvent = S['ModelInstallDownloadsCompleteEvent'];
|
||||
export type ModelInstallCompleteEvent = S['ModelInstallCompleteEvent'];
|
||||
export type ModelInstallErrorEvent = S['ModelInstallErrorEvent'];
|
||||
export type ModelInstallStartedEvent = S['ModelInstallStartedEvent'];
|
||||
export type ModelInstallCancelledEvent = S['ModelInstallCancelledEvent'];
|
||||
|
||||
export type ModelInstallDownloadingEvent = {
|
||||
bytes: number;
|
||||
local_path: string;
|
||||
source: string;
|
||||
timestamp: number;
|
||||
total_bytes: number;
|
||||
id: number;
|
||||
};
|
||||
export type DownloadStartedEvent = S['DownloadStartedEvent'];
|
||||
export type DownloadProgressEvent = S['DownloadProgressEvent'];
|
||||
export type DownloadCompleteEvent = S['DownloadCompleteEvent'];
|
||||
export type DownloadCancelledEvent = S['DownloadCancelledEvent'];
|
||||
export type DownloadErrorEvent = S['DownloadErrorEvent'];
|
||||
|
||||
export type ModelInstallCompletedEvent = {
|
||||
key: number;
|
||||
source: string;
|
||||
timestamp: number;
|
||||
id: number;
|
||||
};
|
||||
export type SessionStartedEvent = S['SessionStartedEvent'];
|
||||
export type SessionCompleteEvent = S['SessionCompleteEvent'];
|
||||
export type SessionCanceledEvent = S['SessionCanceledEvent'];
|
||||
|
||||
export type ModelInstallErrorEvent = {
|
||||
error: string;
|
||||
error_type: string;
|
||||
source: string;
|
||||
timestamp: number;
|
||||
id: number;
|
||||
};
|
||||
export type QueueItemStatusChangedEvent = S['QueueItemStatusChangedEvent'];
|
||||
|
||||
export type ModelInstallCancelledEvent = {
|
||||
source: string;
|
||||
timestamp: number;
|
||||
id: number;
|
||||
};
|
||||
|
||||
/**
|
||||
* A `generator_progress` socket.io event.
|
||||
*
|
||||
* @example socket.on('generator_progress', (data: GeneratorProgressEvent) => { ... }
|
||||
*/
|
||||
export type GeneratorProgressEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: number;
|
||||
queue_batch_id: string;
|
||||
graph_execution_state_id: string;
|
||||
node_id: string;
|
||||
source_node_id: string;
|
||||
progress_image?: ProgressImage;
|
||||
step: number;
|
||||
order: number;
|
||||
total_steps: number;
|
||||
};
|
||||
|
||||
/**
|
||||
* A `invocation_complete` socket.io event.
|
||||
*
|
||||
* `result` is a discriminated union with a `type` property as the discriminant.
|
||||
*
|
||||
* @example socket.on('invocation_complete', (data: InvocationCompleteEvent) => { ... }
|
||||
*/
|
||||
export type InvocationCompleteEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: number;
|
||||
queue_batch_id: string;
|
||||
graph_execution_state_id: string;
|
||||
node: BaseNode;
|
||||
source_node_id: string;
|
||||
result: AnyResult;
|
||||
};
|
||||
|
||||
/**
|
||||
* A `invocation_error` socket.io event.
|
||||
*
|
||||
* @example socket.on('invocation_error', (data: InvocationErrorEvent) => { ... }
|
||||
*/
|
||||
export type InvocationErrorEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: number;
|
||||
queue_batch_id: string;
|
||||
graph_execution_state_id: string;
|
||||
node: BaseNode;
|
||||
source_node_id: string;
|
||||
error_type: string;
|
||||
error: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* A `invocation_started` socket.io event.
|
||||
*
|
||||
* @example socket.on('invocation_started', (data: InvocationStartedEvent) => { ... }
|
||||
*/
|
||||
export type InvocationStartedEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: number;
|
||||
queue_batch_id: string;
|
||||
graph_execution_state_id: string;
|
||||
node: BaseNode;
|
||||
source_node_id: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* A `graph_execution_state_complete` socket.io event.
|
||||
*
|
||||
* @example socket.on('graph_execution_state_complete', (data: GraphExecutionStateCompleteEvent) => { ... }
|
||||
*/
|
||||
export type GraphExecutionStateCompleteEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: number;
|
||||
queue_batch_id: string;
|
||||
graph_execution_state_id: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* A `session_retrieval_error` socket.io event.
|
||||
*
|
||||
* @example socket.on('session_retrieval_error', (data: SessionRetrievalErrorEvent) => { ... }
|
||||
*/
|
||||
export type SessionRetrievalErrorEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: number;
|
||||
queue_batch_id: string;
|
||||
graph_execution_state_id: string;
|
||||
error_type: string;
|
||||
error: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* A `invocation_retrieval_error` socket.io event.
|
||||
*
|
||||
* @example socket.on('invocation_retrieval_error', (data: InvocationRetrievalErrorEvent) => { ... }
|
||||
*/
|
||||
export type InvocationRetrievalErrorEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: number;
|
||||
queue_batch_id: string;
|
||||
graph_execution_state_id: string;
|
||||
node_id: string;
|
||||
error_type: string;
|
||||
error: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* A `queue_item_status_changed` socket.io event.
|
||||
*
|
||||
* @example socket.on('queue_item_status_changed', (data: QueueItemStatusChangedEvent) => { ... }
|
||||
*/
|
||||
export type QueueItemStatusChangedEvent = {
|
||||
queue_id: string;
|
||||
queue_item: {
|
||||
queue_id: string;
|
||||
item_id: number;
|
||||
batch_id: string;
|
||||
session_id: string;
|
||||
status: components['schemas']['SessionQueueItemDTO']['status'];
|
||||
error: string | undefined;
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
started_at: string | undefined;
|
||||
completed_at: string | undefined;
|
||||
};
|
||||
batch_status: {
|
||||
queue_id: string;
|
||||
batch_id: string;
|
||||
pending: number;
|
||||
in_progress: number;
|
||||
completed: number;
|
||||
failed: number;
|
||||
canceled: number;
|
||||
total: number;
|
||||
};
|
||||
queue_status: {
|
||||
queue_id: string;
|
||||
item_id?: number;
|
||||
batch_id?: string;
|
||||
session_id?: string;
|
||||
pending: number;
|
||||
in_progress: number;
|
||||
completed: number;
|
||||
failed: number;
|
||||
canceled: number;
|
||||
total: number;
|
||||
};
|
||||
};
|
||||
export type BulkDownloadStartedEvent = S['BulkDownloadStartedEvent'];
|
||||
export type BulkDownloadCompleteEvent = S['BulkDownloadCompleteEvent'];
|
||||
export type BulkDownloadFailedEvent = S['BulkDownloadErrorEvent'];
|
||||
|
||||
type ClientEmitSubscribeQueue = {
|
||||
queue_id: string;
|
||||
};
|
||||
|
||||
type ClientEmitUnsubscribeQueue = {
|
||||
queue_id: string;
|
||||
};
|
||||
|
||||
export type BulkDownloadStartedEvent = {
|
||||
bulk_download_id: string;
|
||||
bulk_download_item_id: string;
|
||||
bulk_download_item_name: string;
|
||||
};
|
||||
|
||||
export type BulkDownloadCompletedEvent = {
|
||||
bulk_download_id: string;
|
||||
bulk_download_item_id: string;
|
||||
bulk_download_item_name: string;
|
||||
};
|
||||
|
||||
export type BulkDownloadFailedEvent = {
|
||||
bulk_download_id: string;
|
||||
bulk_download_item_id: string;
|
||||
bulk_download_item_name: string;
|
||||
error: string;
|
||||
};
|
||||
|
||||
type ClientEmitUnsubscribeQueue = ClientEmitSubscribeQueue;
|
||||
type ClientEmitSubscribeBulkDownload = {
|
||||
bulk_download_id: string;
|
||||
};
|
||||
|
||||
type ClientEmitUnsubscribeBulkDownload = {
|
||||
bulk_download_id: string;
|
||||
};
|
||||
type ClientEmitUnsubscribeBulkDownload = ClientEmitSubscribeBulkDownload;
|
||||
|
||||
export type ServerToClientEvents = {
|
||||
generator_progress: (payload: GeneratorProgressEvent) => void;
|
||||
invocation_denoise_progress: (payload: InvocationDenoiseProgressEvent) => void;
|
||||
invocation_complete: (payload: InvocationCompleteEvent) => void;
|
||||
invocation_error: (payload: InvocationErrorEvent) => void;
|
||||
invocation_started: (payload: InvocationStartedEvent) => void;
|
||||
graph_execution_state_complete: (payload: GraphExecutionStateCompleteEvent) => void;
|
||||
session_started: (payload: SessionStartedEvent) => void;
|
||||
session_complete: (payload: SessionCompleteEvent) => void;
|
||||
session_canceled: (payload: SessionCanceledEvent) => void;
|
||||
download_started: (payload: DownloadStartedEvent) => void;
|
||||
download_progress: (payload: DownloadProgressEvent) => void;
|
||||
download_complete: (payload: DownloadCompleteEvent) => void;
|
||||
download_cancelled: (payload: DownloadCancelledEvent) => void;
|
||||
download_error: (payload: DownloadErrorEvent) => void;
|
||||
model_load_started: (payload: ModelLoadStartedEvent) => void;
|
||||
model_load_completed: (payload: ModelLoadCompletedEvent) => void;
|
||||
model_install_downloading: (payload: ModelInstallDownloadingEvent) => void;
|
||||
model_install_completed: (payload: ModelInstallCompletedEvent) => void;
|
||||
model_install_started: (payload: ModelInstallStartedEvent) => void;
|
||||
model_install_download_progress: (payload: ModelInstallDownloadProgressEvent) => void;
|
||||
model_install_downloads_complete: (payload: ModelInstallDownloadsCompleteEvent) => void;
|
||||
model_install_complete: (payload: ModelInstallCompleteEvent) => void;
|
||||
model_install_error: (payload: ModelInstallErrorEvent) => void;
|
||||
model_install_canceled: (payload: ModelInstallCancelledEvent) => void;
|
||||
session_retrieval_error: (payload: SessionRetrievalErrorEvent) => void;
|
||||
invocation_retrieval_error: (payload: InvocationRetrievalErrorEvent) => void;
|
||||
model_install_cancelled: (payload: ModelInstallCancelledEvent) => void;
|
||||
model_load_complete: (payload: ModelLoadCompleteEvent) => void;
|
||||
queue_item_status_changed: (payload: QueueItemStatusChangedEvent) => void;
|
||||
bulk_download_started: (payload: BulkDownloadStartedEvent) => void;
|
||||
bulk_download_completed: (payload: BulkDownloadCompletedEvent) => void;
|
||||
bulk_download_failed: (payload: BulkDownloadFailedEvent) => void;
|
||||
bulk_download_complete: (payload: BulkDownloadCompleteEvent) => void;
|
||||
bulk_download_error: (payload: BulkDownloadFailedEvent) => void;
|
||||
};
|
||||
|
||||
export type ClientToServerEvents = {
|
||||
|
@ -5,24 +5,32 @@ import type { AppDispatch } from 'app/store/store';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import {
|
||||
socketBulkDownloadCompleted,
|
||||
socketBulkDownloadFailed,
|
||||
socketBulkDownloadComplete,
|
||||
socketBulkDownloadError,
|
||||
socketBulkDownloadStarted,
|
||||
socketConnected,
|
||||
socketDisconnected,
|
||||
socketDownloadCancelled,
|
||||
socketDownloadComplete,
|
||||
socketDownloadError,
|
||||
socketDownloadProgress,
|
||||
socketDownloadStarted,
|
||||
socketGeneratorProgress,
|
||||
socketGraphExecutionStateComplete,
|
||||
socketInvocationComplete,
|
||||
socketInvocationError,
|
||||
socketInvocationRetrievalError,
|
||||
socketInvocationStarted,
|
||||
socketModelInstallCompleted,
|
||||
socketModelInstallDownloading,
|
||||
socketModelInstallCancelled,
|
||||
socketModelInstallComplete,
|
||||
socketModelInstallDownloadProgress,
|
||||
socketModelInstallDownloadsComplete,
|
||||
socketModelInstallError,
|
||||
socketModelLoadCompleted,
|
||||
socketModelInstallStarted,
|
||||
socketModelLoadComplete,
|
||||
socketModelLoadStarted,
|
||||
socketQueueItemStatusChanged,
|
||||
socketSessionRetrievalError,
|
||||
socketSessionCanceled,
|
||||
socketSessionStarted,
|
||||
} from 'services/events/actions';
|
||||
import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types';
|
||||
import type { Socket } from 'socket.io-client';
|
||||
@ -65,131 +73,87 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Disconnect
|
||||
*/
|
||||
socket.on('disconnect', () => {
|
||||
dispatch(socketDisconnected());
|
||||
});
|
||||
|
||||
/**
|
||||
* Invocation started
|
||||
*/
|
||||
socket.on('invocation_started', (data) => {
|
||||
dispatch(socketInvocationStarted({ data }));
|
||||
});
|
||||
|
||||
/**
|
||||
* Generator progress
|
||||
*/
|
||||
socket.on('generator_progress', (data) => {
|
||||
socket.on('invocation_denoise_progress', (data) => {
|
||||
dispatch(socketGeneratorProgress({ data }));
|
||||
});
|
||||
|
||||
/**
|
||||
* Invocation error
|
||||
*/
|
||||
socket.on('invocation_error', (data) => {
|
||||
dispatch(socketInvocationError({ data }));
|
||||
});
|
||||
|
||||
/**
|
||||
* Invocation complete
|
||||
*/
|
||||
socket.on('invocation_complete', (data) => {
|
||||
dispatch(
|
||||
socketInvocationComplete({
|
||||
data,
|
||||
})
|
||||
);
|
||||
dispatch(socketInvocationComplete({ data }));
|
||||
});
|
||||
|
||||
/**
|
||||
* Graph complete
|
||||
*/
|
||||
socket.on('graph_execution_state_complete', (data) => {
|
||||
dispatch(
|
||||
socketGraphExecutionStateComplete({
|
||||
data,
|
||||
})
|
||||
);
|
||||
socket.on('session_started', (data) => {
|
||||
dispatch(socketSessionStarted({ data }));
|
||||
});
|
||||
|
||||
socket.on('session_complete', (data) => {
|
||||
dispatch(socketGraphExecutionStateComplete({ data }));
|
||||
});
|
||||
|
||||
socket.on('session_canceled', (data) => {
|
||||
dispatch(socketSessionCanceled({ data }));
|
||||
});
|
||||
|
||||
/**
|
||||
* Model load started
|
||||
*/
|
||||
socket.on('model_load_started', (data) => {
|
||||
dispatch(
|
||||
socketModelLoadStarted({
|
||||
data,
|
||||
})
|
||||
);
|
||||
dispatch(socketModelLoadStarted({ data }));
|
||||
});
|
||||
|
||||
/**
|
||||
* Model load completed
|
||||
*/
|
||||
socket.on('model_load_completed', (data) => {
|
||||
dispatch(
|
||||
socketModelLoadCompleted({
|
||||
data,
|
||||
})
|
||||
);
|
||||
socket.on('model_load_complete', (data) => {
|
||||
dispatch(socketModelLoadComplete({ data }));
|
||||
});
|
||||
|
||||
/**
|
||||
* Model Install Downloading
|
||||
*/
|
||||
socket.on('model_install_downloading', (data) => {
|
||||
dispatch(
|
||||
socketModelInstallDownloading({
|
||||
data,
|
||||
})
|
||||
);
|
||||
socket.on('download_started', (data) => {
|
||||
dispatch(socketDownloadStarted({ data }));
|
||||
});
|
||||
|
||||
/**
|
||||
* Model Install Completed
|
||||
*/
|
||||
socket.on('model_install_completed', (data) => {
|
||||
dispatch(
|
||||
socketModelInstallCompleted({
|
||||
data,
|
||||
})
|
||||
);
|
||||
socket.on('download_progress', (data) => {
|
||||
dispatch(socketDownloadProgress({ data }));
|
||||
});
|
||||
|
||||
socket.on('download_complete', (data) => {
|
||||
dispatch(socketDownloadComplete({ data }));
|
||||
});
|
||||
|
||||
socket.on('download_cancelled', (data) => {
|
||||
dispatch(socketDownloadCancelled({ data }));
|
||||
});
|
||||
|
||||
socket.on('download_error', (data) => {
|
||||
dispatch(socketDownloadError({ data }));
|
||||
});
|
||||
|
||||
socket.on('model_install_started', (data) => {
|
||||
dispatch(socketModelInstallStarted({ data }));
|
||||
});
|
||||
|
||||
socket.on('model_install_download_progress', (data) => {
|
||||
dispatch(socketModelInstallDownloadProgress({ data }));
|
||||
});
|
||||
|
||||
socket.on('model_install_downloads_complete', (data) => {
|
||||
dispatch(socketModelInstallDownloadsComplete({ data }));
|
||||
});
|
||||
|
||||
socket.on('model_install_complete', (data) => {
|
||||
dispatch(socketModelInstallComplete({ data }));
|
||||
});
|
||||
|
||||
/**
|
||||
* Model Install Error
|
||||
*/
|
||||
socket.on('model_install_error', (data) => {
|
||||
dispatch(
|
||||
socketModelInstallError({
|
||||
data,
|
||||
})
|
||||
);
|
||||
dispatch(socketModelInstallError({ data }));
|
||||
});
|
||||
|
||||
/**
|
||||
* Session retrieval error
|
||||
*/
|
||||
socket.on('session_retrieval_error', (data) => {
|
||||
dispatch(
|
||||
socketSessionRetrievalError({
|
||||
data,
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
/**
|
||||
* Invocation retrieval error
|
||||
*/
|
||||
socket.on('invocation_retrieval_error', (data) => {
|
||||
dispatch(
|
||||
socketInvocationRetrievalError({
|
||||
data,
|
||||
})
|
||||
);
|
||||
socket.on('model_install_cancelled', (data) => {
|
||||
dispatch(socketModelInstallCancelled({ data }));
|
||||
});
|
||||
|
||||
socket.on('queue_item_status_changed', (data) => {
|
||||
@ -200,11 +164,11 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
|
||||
dispatch(socketBulkDownloadStarted({ data }));
|
||||
});
|
||||
|
||||
socket.on('bulk_download_completed', (data) => {
|
||||
dispatch(socketBulkDownloadCompleted({ data }));
|
||||
socket.on('bulk_download_complete', (data) => {
|
||||
dispatch(socketBulkDownloadComplete({ data }));
|
||||
});
|
||||
|
||||
socket.on('bulk_download_failed', (data) => {
|
||||
dispatch(socketBulkDownloadFailed({ data }));
|
||||
socket.on('bulk_download_error', (data) => {
|
||||
dispatch(socketBulkDownloadError({ data }));
|
||||
});
|
||||
};
|
||||
|
@ -1 +1 @@
|
||||
__version__ = "4.2.2"
|
||||
__version__ = "4.2.1"
|
||||
|
@ -9,6 +9,11 @@ import pytest
|
||||
from invokeai.app.services.board_records.board_records_common import BoardRecord, BoardRecordNotFoundException
|
||||
from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadTargetException
|
||||
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
|
||||
from invokeai.app.services.events.events_common import (
|
||||
BulkDownloadCompleteEvent,
|
||||
BulkDownloadErrorEvent,
|
||||
BulkDownloadStartedEvent,
|
||||
)
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageRecordNotFoundException,
|
||||
@ -281,9 +286,9 @@ def assert_handler_success(
|
||||
|
||||
# Check that the correct events were emitted
|
||||
assert len(event_bus.events) == 2
|
||||
assert event_bus.events[0].event_name == "bulk_download_started"
|
||||
assert event_bus.events[1].event_name == "bulk_download_completed"
|
||||
assert event_bus.events[1].payload["bulk_download_item_name"] == os.path.basename(expected_zip_path)
|
||||
assert isinstance(event_bus.events[0], BulkDownloadStartedEvent)
|
||||
assert isinstance(event_bus.events[1], BulkDownloadCompleteEvent)
|
||||
assert event_bus.events[1].bulk_download_item_name == os.path.basename(expected_zip_path)
|
||||
|
||||
|
||||
def test_handler_on_image_not_found(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
|
||||
@ -329,9 +334,9 @@ def test_handler_on_generic_exception(
|
||||
event_bus: TestEventService = mock_invoker.services.events
|
||||
|
||||
assert len(event_bus.events) == 2
|
||||
assert event_bus.events[0].event_name == "bulk_download_started"
|
||||
assert event_bus.events[1].event_name == "bulk_download_failed"
|
||||
assert event_bus.events[1].payload["error"] == exception.__str__()
|
||||
assert isinstance(event_bus.events[0], BulkDownloadStartedEvent)
|
||||
assert isinstance(event_bus.events[1], BulkDownloadErrorEvent)
|
||||
assert event_bus.events[1].error == exception.__str__()
|
||||
|
||||
|
||||
def execute_handler_test_on_error(
|
||||
@ -344,9 +349,9 @@ def execute_handler_test_on_error(
|
||||
event_bus: TestEventService = mock_invoker.services.events
|
||||
|
||||
assert len(event_bus.events) == 2
|
||||
assert event_bus.events[0].event_name == "bulk_download_started"
|
||||
assert event_bus.events[1].event_name == "bulk_download_failed"
|
||||
assert event_bus.events[1].payload["error"] == error.__str__()
|
||||
assert isinstance(event_bus.events[0], BulkDownloadStartedEvent)
|
||||
assert isinstance(event_bus.events[1], BulkDownloadErrorEvent)
|
||||
assert event_bus.events[1].error == error.__str__()
|
||||
|
||||
|
||||
def test_delete(tmp_path: Path):
|
||||
|
@ -10,6 +10,13 @@ from requests.sessions import Session
|
||||
from requests_testadapter import TestAdapter, TestSession
|
||||
|
||||
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
|
||||
from invokeai.app.services.events.events_common import (
|
||||
DownloadCancelledEvent,
|
||||
DownloadCompleteEvent,
|
||||
DownloadErrorEvent,
|
||||
DownloadProgressEvent,
|
||||
DownloadStartedEvent,
|
||||
)
|
||||
from tests.test_nodes import TestEventService
|
||||
|
||||
# Prevent pytest deprecation warnings
|
||||
@ -116,14 +123,14 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
|
||||
queue.join()
|
||||
events = event_bus.events
|
||||
assert len(events) == 3
|
||||
assert events[0].payload["timestamp"] <= events[1].payload["timestamp"]
|
||||
assert events[1].payload["timestamp"] <= events[2].payload["timestamp"]
|
||||
assert events[0].event_name == "download_started"
|
||||
assert events[1].event_name == "download_progress"
|
||||
assert events[1].payload["total_bytes"] > 0
|
||||
assert events[1].payload["current_bytes"] <= events[1].payload["total_bytes"]
|
||||
assert events[2].event_name == "download_complete"
|
||||
assert events[2].payload["total_bytes"] == 32029
|
||||
assert isinstance(events[0], DownloadStartedEvent)
|
||||
assert isinstance(events[1], DownloadProgressEvent)
|
||||
assert isinstance(events[2], DownloadCompleteEvent)
|
||||
assert events[0].timestamp <= events[1].timestamp
|
||||
assert events[1].timestamp <= events[2].timestamp
|
||||
assert events[1].total_bytes > 0
|
||||
assert events[1].current_bytes <= events[1].total_bytes
|
||||
assert events[2].total_bytes == 32029
|
||||
|
||||
# test a failure
|
||||
event_bus.events = [] # reset our accumulator
|
||||
@ -132,10 +139,10 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
|
||||
events = event_bus.events
|
||||
print("\n".join([x.model_dump_json() for x in events]))
|
||||
assert len(events) == 1
|
||||
assert events[0].event_name == "download_error"
|
||||
assert events[0].payload["error_type"] == "HTTPError(NOT FOUND)"
|
||||
assert events[0].payload["error"] is not None
|
||||
assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].payload["error"])
|
||||
assert isinstance(events[0], DownloadErrorEvent)
|
||||
assert events[0].error_type == "HTTPError(NOT FOUND)"
|
||||
assert events[0].error is not None
|
||||
assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].error)
|
||||
queue.stop()
|
||||
|
||||
|
||||
@ -202,6 +209,6 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
|
||||
assert job.status == DownloadJobStatus.CANCELLED
|
||||
assert cancelled
|
||||
events = event_bus.events
|
||||
assert events[-1].event_name == "download_cancelled"
|
||||
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
|
||||
assert isinstance(events[-1], DownloadCancelledEvent)
|
||||
assert events[-1].source == "http://www.civitai.com/models/12345"
|
||||
queue.stop()
|
||||
|
@ -13,16 +13,24 @@ from pydantic_core import Url
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.events.events_common import (
|
||||
ModelInstallCompleteEvent,
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallStartedEvent,
|
||||
)
|
||||
from invokeai.app.services.model_install import (
|
||||
ModelInstallServiceBase,
|
||||
)
|
||||
from invokeai.app.services.model_install.model_install_common import (
|
||||
InstallStatus,
|
||||
LocalModelSource,
|
||||
ModelInstallJob,
|
||||
ModelInstallServiceBase,
|
||||
URLModelSource,
|
||||
)
|
||||
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
|
||||
from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
from tests.test_nodes import TestEventService
|
||||
|
||||
OS = platform.uname().system
|
||||
|
||||
@ -130,17 +138,16 @@ def test_background_install(
|
||||
assert job.total_bytes == size
|
||||
|
||||
# test that the expected events were issued
|
||||
bus = mm2_installer.event_bus
|
||||
bus: TestEventService = mm2_installer.event_bus
|
||||
assert bus
|
||||
assert hasattr(bus, "events")
|
||||
|
||||
assert len(bus.events) == 2
|
||||
event_names = [x.event_name for x in bus.events]
|
||||
assert "model_install_running" in event_names
|
||||
assert "model_install_completed" in event_names
|
||||
assert Path(bus.events[0].payload["source"]) == source
|
||||
assert Path(bus.events[1].payload["source"]) == source
|
||||
key = bus.events[1].payload["key"]
|
||||
assert isinstance(bus.events[0], ModelInstallStartedEvent)
|
||||
assert isinstance(bus.events[1], ModelInstallCompleteEvent)
|
||||
assert Path(bus.events[0].source) == source
|
||||
assert Path(bus.events[1].source) == source
|
||||
key = bus.events[1].key
|
||||
assert key is not None
|
||||
|
||||
# see if the thing actually got installed at the expected location
|
||||
@ -219,7 +226,7 @@ def test_delete_register(
|
||||
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
|
||||
|
||||
bus = mm2_installer.event_bus
|
||||
bus: TestEventService = mm2_installer.event_bus
|
||||
store = mm2_installer.record_store
|
||||
assert store is not None
|
||||
assert bus is not None
|
||||
@ -237,20 +244,17 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
|
||||
assert (mm2_app_config.models_path / model_record.path).exists()
|
||||
|
||||
assert len(bus.events) == 4
|
||||
event_names = [x.event_name for x in bus.events]
|
||||
assert event_names == [
|
||||
"model_install_downloading",
|
||||
"model_install_downloads_done",
|
||||
"model_install_running",
|
||||
"model_install_completed",
|
||||
]
|
||||
assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent)
|
||||
assert isinstance(bus.events[1], ModelInstallDownloadsCompleteEvent)
|
||||
assert isinstance(bus.events[2], ModelInstallStartedEvent)
|
||||
assert isinstance(bus.events[3], ModelInstallCompleteEvent)
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||
|
||||
bus = mm2_installer.event_bus
|
||||
bus: TestEventService = mm2_installer.event_bus
|
||||
store = mm2_installer.record_store
|
||||
assert isinstance(bus, EventServiceBase)
|
||||
assert store is not None
|
||||
@ -267,15 +271,10 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co
|
||||
assert model_record.type == ModelType.Main
|
||||
assert model_record.format == ModelFormat.Diffusers
|
||||
|
||||
assert hasattr(bus, "events") # the dummyeventservice has this
|
||||
assert any(isinstance(x, ModelInstallStartedEvent) for x in bus.events)
|
||||
assert any(isinstance(x, ModelInstallDownloadProgressEvent) for x in bus.events)
|
||||
assert any(isinstance(x, ModelInstallCompleteEvent) for x in bus.events)
|
||||
assert len(bus.events) >= 3
|
||||
event_names = {x.event_name for x in bus.events}
|
||||
assert event_names == {
|
||||
"model_install_downloading",
|
||||
"model_install_downloads_done",
|
||||
"model_install_running",
|
||||
"model_install_completed",
|
||||
}
|
||||
|
||||
|
||||
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
|
@ -3,16 +3,13 @@
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from requests.sessions import Session
|
||||
from requests_testadapter import TestAdapter, TestSession
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadQueueService, DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
|
||||
from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase
|
||||
from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase
|
||||
@ -39,27 +36,7 @@ from tests.backend.model_manager.model_metadata.metadata_examples import (
|
||||
RepoHFModelJson1,
|
||||
)
|
||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
||||
|
||||
|
||||
class DummyEvent(BaseModel):
|
||||
"""Dummy Event to use with Dummy Event service."""
|
||||
|
||||
event_name: str
|
||||
payload: Dict[str, Any]
|
||||
|
||||
|
||||
class DummyEventService(EventServiceBase):
|
||||
"""Dummy event service for testing."""
|
||||
|
||||
events: List[DummyEvent]
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.events = []
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
"""Dispatch an event by appending it to self.events."""
|
||||
self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))
|
||||
from tests.test_nodes import TestEventService
|
||||
|
||||
|
||||
# Create a temporary directory using the contents of `./data/invokeai_root` as the template
|
||||
@ -127,7 +104,7 @@ def mm2_installer(
|
||||
) -> ModelInstallServiceBase:
|
||||
logger = InvokeAILogger.get_logger()
|
||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||
events = DummyEventService()
|
||||
events = TestEventService()
|
||||
store = ModelRecordServiceSQL(db)
|
||||
|
||||
installer = ModelInstallService(
|
||||
|
@ -20,6 +20,7 @@ from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa: F403
|
||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database # noqa: F401
|
||||
from tests.test_nodes import TestEventService
|
||||
|
||||
|
@ -1,7 +1,5 @@
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@ -10,6 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
)
|
||||
from invokeai.app.invocations.fields import InputField, OutputField
|
||||
from invokeai.app.invocations.image import ImageField
|
||||
from invokeai.app.services.events.events_common import EventBase
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@ -117,11 +116,10 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg
|
||||
)
|
||||
|
||||
|
||||
class TestEvent(BaseModel):
|
||||
class TestEvent(EventBase):
|
||||
__test__ = False # not a pytest test case
|
||||
|
||||
event_name: str
|
||||
payload: Any
|
||||
__event_name__ = "test_event"
|
||||
|
||||
|
||||
class TestEventService(EventServiceBase):
|
||||
@ -129,10 +127,10 @@ class TestEventService(EventServiceBase):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.events: list[TestEvent] = []
|
||||
self.events: list[EventBase] = []
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
self.events.append(TestEvent(event_name=payload["event"], payload=payload["data"]))
|
||||
def dispatch(self, event: EventBase) -> None:
|
||||
self.events.append(event)
|
||||
pass
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user