From 32a02b3329a6bab64fec1bbda201bb3bcec0ca35 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 14 Mar 2024 19:04:19 +1100 Subject: [PATCH] refactor(events): use pydantic schemas for events Our event handling and implementation have a couple of pain points: - Adding or removing data from event payloads requires changes wherever the events are dispatched from. - We have no type safety for events and need to rely on string matching and dict access when interacting with events. - Frontend types for socket events must be manually typed. This has caused several bugs. `fastapi-events` has a neat feature where you can create a pydantic model as an event payload, give it an `__event_name__` attr, and then dispatch the model directly. This allows us to eliminate a layer of indirection and some unpleasant complexity: - Event handler callbacks get type hints for their event payloads, and can use `isinstance` on them if needed. - Event payload construction is now the responsibility of the event itself (a pydantic model), not the service. Every event model has a `build` class method, encapsulating this logic. The build methods are provided with as few args as possible. For example, `InvocationStartedEvent.build()` gets the invocation instance and queue item, and can choose the data it wants to include in the event payload. - Frontend event types may be autogenerated from the OpenAPI schema. We use the payload registry feature of `fastapi-events` to collect all payload models into one place, making it trivial to keep our schema and frontend types in sync. This commit moves the backend over to this improved event handling setup. 
--- invokeai/app/api/dependencies.py | 2 +- invokeai/app/api/events.py | 52 -- invokeai/app/api/routers/model_manager.py | 2 +- invokeai/app/api/sockets.py | 149 +++-- invokeai/app/api_app.py | 31 +- .../bulk_download/bulk_download_default.py | 17 +- .../app/services/download/download_default.py | 8 +- invokeai/app/services/events/events_base.py | 533 ++++----------- invokeai/app/services/events/events_common.py | 625 ++++++++++++++++++ .../services/events/events_fastapievents.py | 46 ++ .../app/services/model_install/__init__.py | 4 +- .../model_install/model_install_base.py | 235 +------ .../model_install/model_install_common.py | 231 +++++++ .../model_install/model_install_default.py | 53 +- .../services/model_load/model_load_base.py | 9 +- .../services/model_load/model_load_default.py | 55 +- .../session_processor_default.py | 67 +- .../session_queue/session_queue_sqlite.py | 102 +-- .../app/services/shared/invocation_context.py | 6 +- invokeai/app/util/step_callback.py | 17 +- .../model_install/test_model_install.py | 4 +- 21 files changed, 1241 insertions(+), 1007 deletions(-) delete mode 100644 invokeai/app/api/events.py create mode 100644 invokeai/app/services/events/events_common.py create mode 100644 invokeai/app/services/events/events_fastapievents.py create mode 100644 invokeai/app/services/model_install/model_install_common.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 9a6c7416f6..c017baf9aa 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -18,6 +18,7 @@ from ..services.boards.boards_default import BoardService from ..services.bulk_download.bulk_download_default import BulkDownloadService from ..services.config import InvokeAIAppConfig from ..services.download import DownloadQueueService +from ..services.events.events_fastapievents import FastAPIEventService from ..services.image_files.image_files_disk import DiskImageFileStorage from 
..services.image_records.image_records_sqlite import SqliteImageRecordStorage from ..services.images.images_default import ImageService @@ -33,7 +34,6 @@ from ..services.session_processor.session_processor_default import DefaultSessio from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue from ..services.urls.urls_default import LocalUrlService from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage -from .events import FastAPIEventService # TODO: is there a better way to achieve this? diff --git a/invokeai/app/api/events.py b/invokeai/app/api/events.py deleted file mode 100644 index 2ac07e6dfe..0000000000 --- a/invokeai/app/api/events.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -import asyncio -import threading -from queue import Empty, Queue -from typing import Any - -from fastapi_events.dispatcher import dispatch - -from ..services.events.events_base import EventServiceBase - - -class FastAPIEventService(EventServiceBase): - event_handler_id: int - __queue: Queue - __stop_event: threading.Event - - def __init__(self, event_handler_id: int) -> None: - self.event_handler_id = event_handler_id - self.__queue = Queue() - self.__stop_event = threading.Event() - asyncio.create_task(self.__dispatch_from_queue(stop_event=self.__stop_event)) - - super().__init__() - - def stop(self, *args, **kwargs): - self.__stop_event.set() - self.__queue.put(None) - - def dispatch(self, event_name: str, payload: Any) -> None: - self.__queue.put({"event_name": event_name, "payload": payload}) - - async def __dispatch_from_queue(self, stop_event: threading.Event): - """Get events on from the queue and dispatch them, from the correct thread""" - while not stop_event.is_set(): - try: - event = self.__queue.get(block=False) - if not event: # Probably stopping - continue - - dispatch( - event.get("event_name"), - payload=event.get("payload"), - 
middleware_id=self.event_handler_id, - ) - - except Empty: - await asyncio.sleep(0.1) - pass - - except asyncio.CancelledError as e: - raise e # Raise a proper error diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index 1ba3e30e07..b1221f7a34 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -17,7 +17,7 @@ from starlette.exceptions import HTTPException from typing_extensions import Annotated from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException -from invokeai.app.services.model_install import ModelInstallJob +from invokeai.app.services.model_install.model_install_common import ModelInstallJob from invokeai.app.services.model_records import ( DuplicateModelException, InvalidModelException, diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index 463545d9bc..972d222f47 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -1,66 +1,119 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) +from typing import Any + from fastapi import FastAPI -from fastapi_events.handlers.local import local_handler -from fastapi_events.typing import Event +from pydantic import BaseModel from socketio import ASGIApp, AsyncServer -from ..services.events.events_base import EventServiceBase +from invokeai.app.services.events.events_common import ( + BatchEnqueuedEvent, + BulkDownloadCompleteEvent, + BulkDownloadErrorEvent, + BulkDownloadEvent, + BulkDownloadStartedEvent, + FastAPIEvent, + InvocationCompleteEvent, + InvocationDenoiseProgressEvent, + InvocationErrorEvent, + InvocationStartedEvent, + ModelEvent, + ModelInstallCancelledEvent, + ModelInstallCompleteEvent, + ModelInstallDownloadProgressEvent, + ModelInstallErrorEvent, + ModelInstallStartedEvent, + ModelLoadCompleteEvent, + ModelLoadStartedEvent, + QueueClearedEvent, + QueueEvent, + QueueItemStatusChangedEvent, + 
SessionCanceledEvent, + SessionCompleteEvent, + SessionStartedEvent, + register_events, +) + + +class QueueSubscriptionEvent(BaseModel): + queue_id: str + + +class BulkDownloadSubscriptionEvent(BaseModel): + bulk_download_id: str class SocketIO: - __sio: AsyncServer - __app: ASGIApp + _sub_queue = "subscribe_queue" + _unsub_queue = "unsubscribe_queue" - __sub_queue: str = "subscribe_queue" - __unsub_queue: str = "unsubscribe_queue" - - __sub_bulk_download: str = "subscribe_bulk_download" - __unsub_bulk_download: str = "unsubscribe_bulk_download" + _sub_bulk_download = "subscribe_bulk_download" + _unsub_bulk_download = "unsubscribe_bulk_download" def __init__(self, app: FastAPI): - self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*") - self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io") - app.mount("/ws", self.__app) + self._sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*") + self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io") + app.mount("/ws", self._app) - self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue) - self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue) - local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event) - local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event) + self._sio.on(self._sub_queue, handler=self._handle_sub_queue) + self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue) + self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download) + self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download) - self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download) - self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download) - local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event) - - async def 
_handle_queue_event(self, event: Event): - await self.__sio.emit( - event=event[1]["event"], - data=event[1]["data"], - room=event[1]["data"]["queue_id"], + register_events( + { + InvocationStartedEvent, + InvocationDenoiseProgressEvent, + InvocationCompleteEvent, + InvocationErrorEvent, + SessionStartedEvent, + SessionCompleteEvent, + SessionCanceledEvent, + QueueItemStatusChangedEvent, + BatchEnqueuedEvent, + QueueClearedEvent, + }, + self._handle_queue_event, ) - async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None: - if "queue_id" in data: - await self.__sio.enter_room(sid, data["queue_id"]) - - async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None: - if "queue_id" in data: - await self.__sio.leave_room(sid, data["queue_id"]) - - async def _handle_model_event(self, event: Event) -> None: - await self.__sio.emit(event=event[1]["event"], data=event[1]["data"]) - - async def _handle_bulk_download_event(self, event: Event): - await self.__sio.emit( - event=event[1]["event"], - data=event[1]["data"], - room=event[1]["data"]["bulk_download_id"], + register_events( + { + ModelLoadStartedEvent, + ModelLoadCompleteEvent, + ModelInstallDownloadProgressEvent, + ModelInstallStartedEvent, + ModelInstallCompleteEvent, + ModelInstallCancelledEvent, + ModelInstallErrorEvent, + }, + self._handle_model_event, ) - async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs): - if "bulk_download_id" in data: - await self.__sio.enter_room(sid, data["bulk_download_id"]) + register_events( + {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent}, + self._handle_bulk_image_download_event, + ) - async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs): - if "bulk_download_id" in data: - await self.__sio.leave_room(sid, data["bulk_download_id"]) + async def _handle_sub_queue(self, sid: str, data: Any) -> None: + await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id) + + async def 
_handle_unsub_queue(self, sid: str, data: Any) -> None: + await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id) + + async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None: + await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id) + + async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None: + await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id) + + async def _handle_queue_event(self, event: FastAPIEvent[QueueEvent]): + event_name, payload = event + await self._sio.emit(event=event_name, data=payload.model_dump(), room=payload.queue_id) + + async def _handle_model_event(self, event: FastAPIEvent[ModelEvent]) -> None: + event_name, payload = event + await self._sio.emit(event=event_name, data=payload.model_dump()) + + async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEvent]) -> None: + event_name, payload = event + await self._sio.emit(event=event_name, data=payload.model_dump()) diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 062682f7d0..710c878c9b 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -5,7 +5,7 @@ import socket from contextlib import asynccontextmanager from inspect import signature from pathlib import Path -from typing import Any +from typing import Any, cast import torch import uvicorn @@ -17,6 +17,8 @@ from fastapi.openapi.utils import get_openapi from fastapi.responses import HTMLResponse from fastapi_events.handlers.local import local_handler from fastapi_events.middleware import EventHandlerASGIMiddleware +from fastapi_events.registry.payload_schema import registry as fastapi_events_registry +from pydantic import BaseModel from pydantic.json_schema import models_json_schema from torch.backends.mps import is_available as is_mps_available @@ -182,23 +184,16 @@ def custom_openapi() -> dict[str, Any]: 
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type()) invoker_schema["class"] = "invocation" - # This code no longer seems to be necessary? - # Leave it here just in case - # - # from invokeai.backend.model_manager import get_model_config_formats - # formats = get_model_config_formats() - # for model_config_name, enum_set in formats.items(): - - # if model_config_name in openapi_schema["components"]["schemas"]: - # # print(f"Config with name {name} already defined") - # continue - - # openapi_schema["components"]["schemas"][model_config_name] = { - # "title": model_config_name, - # "description": "An enumeration.", - # "type": "string", - # "enum": [v.value for v in enum_set], - # } + # Add all pydantic event schemas registered with fastapi-events + for payload in fastapi_events_registry.data.values(): + json_schema = cast(BaseModel, payload).model_json_schema( + mode="serialization", ref_template="#/components/schemas/{model}" + ) + if "$defs" in json_schema: + for schema_key, schema in json_schema["$defs"].items(): + openapi_schema["components"]["schemas"][schema_key] = schema + del json_schema["$defs"] + openapi_schema["components"]["schemas"][payload.__name__] = json_schema app.openapi_schema = openapi_schema return app.openapi_schema diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index 04cec928f4..d4bf059b8f 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -106,9 +106,7 @@ class BulkDownloadService(BulkDownloadBase): if self._invoker: assert bulk_download_id is not None self._invoker.services.events.emit_bulk_download_started( - bulk_download_id=bulk_download_id, - bulk_download_item_id=bulk_download_item_id, - bulk_download_item_name=bulk_download_item_name, + bulk_download_id, bulk_download_item_id, bulk_download_item_name ) def 
_signal_job_completed( @@ -118,10 +116,8 @@ class BulkDownloadService(BulkDownloadBase): if self._invoker: assert bulk_download_id is not None assert bulk_download_item_name is not None - self._invoker.services.events.emit_bulk_download_completed( - bulk_download_id=bulk_download_id, - bulk_download_item_id=bulk_download_item_id, - bulk_download_item_name=bulk_download_item_name, + self._invoker.services.events.emit_bulk_download_complete( + bulk_download_id, bulk_download_item_id, bulk_download_item_name ) def _signal_job_failed( @@ -131,11 +127,8 @@ class BulkDownloadService(BulkDownloadBase): if self._invoker: assert bulk_download_id is not None assert exception is not None - self._invoker.services.events.emit_bulk_download_failed( - bulk_download_id=bulk_download_id, - bulk_download_item_id=bulk_download_item_id, - bulk_download_item_name=bulk_download_item_name, - error=str(exception), + self._invoker.services.events.emit_bulk_download_error( + bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception) ) def stop(self, *args, **kwargs): diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 7d8229fba1..4c9d9bda13 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -8,14 +8,13 @@ import time import traceback from pathlib import Path from queue import Empty, PriorityQueue -from typing import Any, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set import requests from pydantic.networks import AnyHttpUrl from requests import HTTPError from tqdm import tqdm -from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.util.misc import get_iso_timestamp from invokeai.backend.util.logging import InvokeAILogger @@ -30,6 +29,9 @@ from .download_base import ( UnknownJobIDException, ) +if TYPE_CHECKING: + from 
invokeai.app.services.events.events_base import EventServiceBase + # Maximum number of bytes to download during each call to requests.iter_content() DOWNLOAD_CHUNK_SIZE = 100000 @@ -40,7 +42,7 @@ class DownloadQueueService(DownloadQueueServiceBase): def __init__( self, max_parallel_dl: int = 5, - event_bus: Optional[EventServiceBase] = None, + event_bus: Optional["EventServiceBase"] = None, requests_session: Optional[requests.sessions.Session] = None, ): """ diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index aa91cdaec8..8d305f530d 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -1,490 +1,181 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Optional -from invokeai.app.services.session_processor.session_processor_common import ProgressImage -from invokeai.app.services.session_queue.session_queue_common import ( - BatchStatus, - EnqueueBatchResult, - SessionQueueItem, - SessionQueueStatus, +from invokeai.app.services.events.events_common import ( + BaseEvent, + BatchEnqueuedEvent, + BulkDownloadCompleteEvent, + BulkDownloadErrorEvent, + BulkDownloadStartedEvent, + DownloadCancelledEvent, + DownloadCompleteEvent, + DownloadErrorEvent, + DownloadProgressEvent, + DownloadStartedEvent, + InvocationCompleteEvent, + InvocationDenoiseProgressEvent, + InvocationErrorEvent, + InvocationStartedEvent, + ModelInstallCancelledEvent, + ModelInstallCompleteEvent, + ModelInstallDownloadProgressEvent, + ModelInstallDownloadsCompleteEvent, + ModelInstallErrorEvent, + ModelInstallStartedEvent, + ModelLoadCompleteEvent, + ModelLoadStartedEvent, + QueueClearedEvent, + QueueItemStatusChangedEvent, + SessionCanceledEvent, + SessionCompleteEvent, + SessionStartedEvent, ) -from invokeai.app.util.misc import get_timestamp -from invokeai.backend.model_manager import 
AnyModelConfig -from invokeai.backend.model_manager.config import SubModelType + +if TYPE_CHECKING: + from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput + from invokeai.app.services.events.events_common import BaseEvent + from invokeai.app.services.model_install.model_install_common import ModelInstallJob + from invokeai.app.services.session_processor.session_processor_common import ProgressImage + from invokeai.app.services.session_queue.session_queue_common import ( + BatchStatus, + EnqueueBatchResult, + SessionQueueItem, + SessionQueueStatus, + ) + from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType class EventServiceBase: - queue_event: str = "queue_event" - bulk_download_event: str = "bulk_download_event" - download_event: str = "download_event" - model_event: str = "model_event" - """Basic event bus, to have an empty stand-in when not needed""" - def dispatch(self, event_name: str, payload: Any) -> None: + def dispatch(self, event: "BaseEvent") -> None: pass - def _emit_bulk_download_event(self, event_name: str, payload: dict) -> None: - """Bulk download events are emitted to a room with queue_id as the room name""" - payload["timestamp"] = get_timestamp() - self.dispatch( - event_name=EventServiceBase.bulk_download_event, - payload={"event": event_name, "data": payload}, - ) + # region: Invocation - def __emit_queue_event(self, event_name: str, payload: dict) -> None: - """Queue events are emitted to a room with queue_id as the room name""" - payload["timestamp"] = get_timestamp() - self.dispatch( - event_name=EventServiceBase.queue_event, - payload={"event": event_name, "data": payload}, - ) + def emit_invocation_started(self, queue_item: "SessionQueueItem", invocation: "BaseInvocation") -> None: + self.dispatch(InvocationStartedEvent.build(queue_item, invocation)) - def __emit_download_event(self, event_name: str, payload: dict) -> None: - payload["timestamp"] = get_timestamp() - 
self.dispatch( - event_name=EventServiceBase.download_event, - payload={"event": event_name, "data": payload}, - ) - - def __emit_model_event(self, event_name: str, payload: dict) -> None: - payload["timestamp"] = get_timestamp() - self.dispatch( - event_name=EventServiceBase.model_event, - payload={"event": event_name, "data": payload}, - ) - - # Define events here for every event in the system. - # This will make them easier to integrate until we find a schema generator. - def emit_generator_progress( + def emit_invocation_denoise_progress( self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - node_id: str, - source_node_id: str, - progress_image: Optional[ProgressImage], + queue_item: "SessionQueueItem", + invocation: "BaseInvocation", step: int, - order: int, total_steps: int, + progress_image: "ProgressImage", ) -> None: - """Emitted when there is generation progress""" - self.__emit_queue_event( - event_name="generator_progress", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "node_id": node_id, - "source_node_id": source_node_id, - "progress_image": progress_image.model_dump(mode="json") if progress_image is not None else None, - "step": step, - "order": order, - "total_steps": total_steps, - }, - ) + self.dispatch(InvocationDenoiseProgressEvent.build(queue_item, invocation, step, total_steps, progress_image)) def emit_invocation_complete( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - result: dict, - node: dict, - source_node_id: str, + self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput" ) -> None: - """Emitted when an invocation has completed""" - self.__emit_queue_event( - event_name="invocation_complete", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - 
"queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "node": node, - "source_node_id": source_node_id, - "result": result, - }, - ) + self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output)) def emit_invocation_error( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - node: dict, - source_node_id: str, - error_type: str, - error: str, - user_id: str | None, - project_id: str | None, + self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", error_type: str, error: str ) -> None: - """Emitted when an invocation has completed""" - self.__emit_queue_event( - event_name="invocation_error", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "node": node, - "source_node_id": source_node_id, - "error_type": error_type, - "error": error, - "user_id": user_id, - "project_id": project_id, - }, - ) + self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error)) - def emit_invocation_started( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - node: dict, - source_node_id: str, - ) -> None: - """Emitted when an invocation has started""" - self.__emit_queue_event( - event_name="invocation_started", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "node": node, - "source_node_id": source_node_id, - }, - ) + # endregion - def emit_graph_execution_complete( - self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str - ) -> None: - """Emitted when a session has completed all invocations""" - self.__emit_queue_event( - event_name="graph_execution_state_complete", - payload={ - "queue_id": queue_id, - "queue_item_id": 
queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - }, - ) + # region Session - def emit_model_load_started( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - model_config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> None: - """Emitted when a model is requested""" - self.__emit_queue_event( - event_name="model_load_started", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "model_config": model_config.model_dump(mode="json"), - "submodel_type": submodel_type, - }, - ) + def emit_session_started(self, queue_item: "SessionQueueItem") -> None: + self.dispatch(SessionStartedEvent.build(queue_item)) - def emit_model_load_completed( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - model_config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> None: - """Emitted when a model is correctly loaded (returns model info)""" - self.__emit_queue_event( - event_name="model_load_completed", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "model_config": model_config.model_dump(mode="json"), - "submodel_type": submodel_type, - }, - ) + def emit_session_complete(self, queue_item: "SessionQueueItem") -> None: + self.dispatch(SessionCompleteEvent.build(queue_item)) - def emit_session_canceled( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - ) -> None: - """Emitted when a session is canceled""" - self.__emit_queue_event( - event_name="session_canceled", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": 
graph_execution_state_id, - }, - ) + def emit_session_canceled(self, queue_item: "SessionQueueItem") -> None: + self.dispatch(SessionCanceledEvent.build(queue_item)) + + # endregion + + # region Queue def emit_queue_item_status_changed( - self, - session_queue_item: SessionQueueItem, - batch_status: BatchStatus, - queue_status: SessionQueueStatus, + self, queue_item: "SessionQueueItem", batch_status: "BatchStatus", queue_status: "SessionQueueStatus" ) -> None: - """Emitted when a queue item's status changes""" - self.__emit_queue_event( - event_name="queue_item_status_changed", - payload={ - "queue_id": queue_status.queue_id, - "queue_item": { - "queue_id": session_queue_item.queue_id, - "item_id": session_queue_item.item_id, - "status": session_queue_item.status, - "batch_id": session_queue_item.batch_id, - "session_id": session_queue_item.session_id, - "error": session_queue_item.error, - "created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None, - "updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None, - "started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None, - "completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None, - }, - "batch_status": batch_status.model_dump(mode="json"), - "queue_status": queue_status.model_dump(mode="json"), - }, - ) + self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status)) - def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None: - """Emitted when a batch is enqueued""" - self.__emit_queue_event( - event_name="batch_enqueued", - payload={ - "queue_id": enqueue_result.queue_id, - "batch_id": enqueue_result.batch.batch_id, - "enqueued": enqueue_result.enqueued, - }, - ) + def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None: + self.dispatch(BatchEnqueuedEvent.build(enqueue_result)) def emit_queue_cleared(self, 
queue_id: str) -> None: - """Emitted when the queue is cleared""" - self.__emit_queue_event( - event_name="queue_cleared", - payload={"queue_id": queue_id}, - ) + self.dispatch(QueueClearedEvent.build(queue_id)) + + # endregion + + # region Download def emit_download_started(self, source: str, download_path: str) -> None: - """ - Emit when a download job is started. - - :param url: The downloaded url - """ - self.__emit_download_event( - event_name="download_started", - payload={"source": source, "download_path": download_path}, - ) + self.dispatch(DownloadStartedEvent.build(source, download_path)) def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None: - """ - Emit "download_progress" events at regular intervals during a download job. - - :param source: The downloaded source - :param download_path: The local downloaded file - :param current_bytes: Number of bytes downloaded so far - :param total_bytes: The size of the file being downloaded (if known) - """ - self.__emit_download_event( - event_name="download_progress", - payload={ - "source": source, - "download_path": download_path, - "current_bytes": current_bytes, - "total_bytes": total_bytes, - }, - ) + self.dispatch(DownloadProgressEvent.build(source, download_path, current_bytes, total_bytes)) def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None: - """ - Emit a "download_complete" event at the end of a successful download. 
- - :param source: Source URL - :param download_path: Path to the locally downloaded file - :param total_bytes: The size of the downloaded file - """ - self.__emit_download_event( - event_name="download_complete", - payload={ - "source": source, - "download_path": download_path, - "total_bytes": total_bytes, - }, - ) + self.dispatch(DownloadCompleteEvent.build(source, download_path, total_bytes)) def emit_download_cancelled(self, source: str) -> None: - """Emit a "download_cancelled" event in the event that the download was cancelled by user.""" - self.__emit_download_event( - event_name="download_cancelled", - payload={ - "source": source, - }, - ) + self.dispatch(DownloadCancelledEvent.build(source)) def emit_download_error(self, source: str, error_type: str, error: str) -> None: - """ - Emit a "download_error" event when an download job encounters an exception. + self.dispatch(DownloadErrorEvent.build(source, error_type, error)) - :param source: Source URL - :param error_type: The name of the exception that raised the error - :param error: The traceback from this error - """ - self.__emit_download_event( - event_name="download_error", - payload={ - "source": source, - "error_type": error_type, - "error": error, - }, - ) + # endregion - def emit_model_install_downloading( - self, - source: str, - local_path: str, - bytes: int, - total_bytes: int, - parts: List[Dict[str, Union[str, int]]], - id: int, - ) -> None: - """ - Emit at intervals while the install job is in progress (remote models only). + # region Model loading - :param source: Source of the model - :param local_path: Where model is downloading to - :param parts: Progress of downloading URLs that comprise the model, if any. - :param bytes: Number of bytes downloaded so far. - :param total_bytes: Total size of download, including all files. - This emits a Dict with keys "source", "local_path", "bytes" and "total_bytes". 
- """ - self.__emit_model_event( - event_name="model_install_downloading", - payload={ - "source": source, - "local_path": local_path, - "bytes": bytes, - "total_bytes": total_bytes, - "parts": parts, - "id": id, - }, - ) + def emit_model_load_started(self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None) -> None: + self.dispatch(ModelLoadStartedEvent.build(config, submodel_type)) - def emit_model_install_downloads_done(self, source: str) -> None: - """ - Emit once when all parts are downloaded, but before the probing and registration start. + def emit_model_load_complete(self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None) -> None: + self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type)) - :param source: Source of the model; local path, repo_id or url - """ - self.__emit_model_event( - event_name="model_install_downloads_done", - payload={"source": source}, - ) + # endregion - def emit_model_install_running(self, source: str) -> None: - """ - Emit once when an install job becomes active. + # region Model install - :param source: Source of the model; local path, repo_id or url - """ - self.__emit_model_event( - event_name="model_install_running", - payload={"source": source}, - ) + def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None: + self.dispatch(ModelInstallDownloadProgressEvent.build(job)) - def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None: - """ - Emit when an install job is completed successfully. 
+ def emit_model_install_downloads_complete(self, job: "ModelInstallJob") -> None: + self.dispatch(ModelInstallDownloadsCompleteEvent.build(job)) - :param source: Source of the model; local path, repo_id or url - :param key: Model config record key - :param total_bytes: Size of the model (may be None for installation of a local path) - """ - self.__emit_model_event( - event_name="model_install_completed", - payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id}, - ) + def emit_model_install_started(self, job: "ModelInstallJob") -> None: + self.dispatch(ModelInstallStartedEvent.build(job)) - def emit_model_install_cancelled(self, source: str, id: int) -> None: - """ - Emit when an install job is cancelled. + def emit_model_install_complete(self, job: "ModelInstallJob") -> None: + self.dispatch(ModelInstallCompleteEvent.build(job)) - :param source: Source of the model; local path, repo_id or url - """ - self.__emit_model_event( - event_name="model_install_cancelled", - payload={"source": source, "id": id}, - ) + def emit_model_install_cancelled(self, job: "ModelInstallJob") -> None: + self.dispatch(ModelInstallCancelledEvent.build(job)) - def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None: - """ - Emit when an install job encounters an exception. 
+ def emit_model_install_error(self, job: "ModelInstallJob") -> None: + self.dispatch(ModelInstallErrorEvent.build(job)) - :param source: Source of the model - :param error_type: The name of the exception - :param error: A text description of the exception - """ - self.__emit_model_event( - event_name="model_install_error", - payload={"source": source, "error_type": error_type, "error": error, "id": id}, - ) + # endregion + + # region Bulk image download def emit_bulk_download_started( self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str ) -> None: - """Emitted when a bulk download starts""" - self._emit_bulk_download_event( - event_name="bulk_download_started", - payload={ - "bulk_download_id": bulk_download_id, - "bulk_download_item_id": bulk_download_item_id, - "bulk_download_item_name": bulk_download_item_name, - }, - ) + self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name)) - def emit_bulk_download_completed( + def emit_bulk_download_complete( self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str ) -> None: - """Emitted when a bulk download completes""" - self._emit_bulk_download_event( - event_name="bulk_download_completed", - payload={ - "bulk_download_id": bulk_download_id, - "bulk_download_item_id": bulk_download_item_id, - "bulk_download_item_name": bulk_download_item_name, - }, - ) + self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name)) - def emit_bulk_download_failed( + def emit_bulk_download_error( self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str ) -> None: - """Emitted when a bulk download fails""" - self._emit_bulk_download_event( - event_name="bulk_download_failed", - payload={ - "bulk_download_id": bulk_download_id, - "bulk_download_item_id": bulk_download_item_id, - "bulk_download_item_name": bulk_download_item_name, - 
"error": error, - }, + self.dispatch( + BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error) ) + + # endregion diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py new file mode 100644 index 0000000000..e0af97d121 --- /dev/null +++ b/invokeai/app/services/events/events_common.py @@ -0,0 +1,625 @@ +from abc import ABC +from enum import Enum +from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Optional, Protocol, TypeAlias, TypeVar + +from fastapi_events.handlers.local import local_handler +from fastapi_events.registry.payload_schema import registry as payload_schema +from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny + +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput +from invokeai.app.services.session_processor.session_processor_common import ProgressImage +from invokeai.app.services.session_queue.session_queue_common import ( + QUEUE_ITEM_STATUS, + BatchStatus, + EnqueueBatchResult, + SessionQueueItem, + SessionQueueStatus, +) +from invokeai.app.util.misc import get_timestamp +from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType + +if TYPE_CHECKING: + from invokeai.app.services.model_install.model_install_common import ModelInstallJob + + +class EventType(str, Enum): + QUEUE = "queue" + MODEL = "model" + DOWNLOAD = "download" + BULK_IMAGE_DOWNLOAD = "bulk_image_download" + + +class BaseEvent(BaseModel, ABC): + """Base class for all events. All events must inherit from this class. + + Events must define the following class attributes: + - `__event_name__: str`: The name of the event + - `__event_type__: EventType`: The type of the event + + All other attributes should be defined as normal for a pydantic model. + + A timestamp is automatically added to the event when it is created. + """ + + __event_name__: ClassVar[str] = ... 
# pyright: ignore [reportAssignmentType] + __event_type__: ClassVar[EventType] = ... # pyright: ignore [reportAssignmentType] + + timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp) + + def __init_subclass__(cls, **kwargs: ConfigDict): + for required_attr in ("__event_name__", "__event_type__"): + if getattr(cls, required_attr) is ...: + raise TypeError(f"{cls.__name__} must define {required_attr}") + + model_config = ConfigDict(json_schema_serialization_defaults_required=True) + + +TEvent = TypeVar("TEvent", bound=BaseEvent) + +FastAPIEvent: TypeAlias = tuple[str, TEvent] +""" +A tuple representing a `fastapi-events` event, with the event name and payload. +Provide a generic type to `TEvent` to specify the payload type. +""" + + +class FastAPIEventFunc(Protocol): + def __call__(self, event: FastAPIEvent[Any]) -> Optional[Coroutine[Any, Any, None]]: ... + + +def register_events(events: set[type[TEvent]], func: FastAPIEventFunc) -> None: + """Register a function to handle a list of events. 
+ + :param events: A list of event classes to handle + :param func: The function to handle the events + """ + for event in events: + local_handler.register(event_name=event.__event_name__, _func=func) + + +class QueueEvent(BaseEvent, ABC): + """Base class for queue events""" + + __event_type__ = EventType.QUEUE + __event_name__ = "queue_event" + + queue_id: str = Field(description="The ID of the queue") + + +class QueueItemEvent(QueueEvent, ABC): + """Base class for queue item events""" + + __event_name__ = "queue_item_event" + + item_id: int = Field(description="The ID of the queue item") + batch_id: str = Field(description="The ID of the queue batch") + + +class SessionEvent(QueueItemEvent, ABC): + """Base class for session (aka graph execution state) events""" + + __event_name__ = "session_event" + + session_id: str = Field(description="The ID of the session (aka graph execution state)") + + +class InvocationEvent(SessionEvent, ABC): + """Base class for invocation events""" + + __event_name__ = "invocation_event" + + queue_id: str = Field(description="The ID of the queue") + item_id: int = Field(description="The ID of the queue item") + batch_id: str = Field(description="The ID of the queue batch") + session_id: str = Field(description="The ID of the session (aka graph execution state)") + invocation_id: str = Field(description="The ID of the invocation") + invocation_source_id: str = Field(description="The ID of the prepared invocation's source node") + invocation_type: str = Field(description="The type of invocation") + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class InvocationStartedEvent(InvocationEvent): + """Emitted when an invocation is started""" + + __event_name__ = "invocation_started" + + @classmethod + def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + 
session_id=queue_item.session_id, + invocation_id=invocation.id, + invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], + invocation_type=invocation.get_type(), + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class InvocationDenoiseProgressEvent(InvocationEvent): + """Emitted at each step during denoising of an invocation.""" + + __event_name__ = "invocation_denoise_progress" + + progress_image: ProgressImage = Field(description="The progress image sent at each step during processing") + step: int = Field(description="The current step of the invocation") + total_steps: int = Field(description="The total number of steps in the invocation") + + @classmethod + def build( + cls, + queue_item: SessionQueueItem, + invocation: BaseInvocation, + step: int, + total_steps: int, + progress_image: ProgressImage, + ) -> "InvocationDenoiseProgressEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + invocation_id=invocation.id, + invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], + invocation_type=invocation.get_type(), + progress_image=progress_image, + step=step, + total_steps=total_steps, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class InvocationCompleteEvent(InvocationEvent): + """Emitted when an invocation is complete""" + + __event_name__ = "invocation_complete" + + result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation") + + @classmethod + def build( + cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput + ) -> "InvocationCompleteEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + invocation_id=invocation.id, + 
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], + invocation_type=invocation.get_type(), + result=result, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class InvocationErrorEvent(InvocationEvent): + """Emitted when an invocation encounters an error""" + + __event_name__ = "invocation_error" + + error_type: str = Field(description="The type of error") + error: str = Field(description="The error message") + + @classmethod + def build( + cls, queue_item: SessionQueueItem, invocation: BaseInvocation, error_type: str, error: str + ) -> "InvocationErrorEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + invocation_id=invocation.id, + invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], + invocation_type=invocation.get_type(), + error_type=error_type, + error=error, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class SessionStartedEvent(SessionEvent): + """Emitted when a session has started""" + + __event_name__ = "session_started" + + @classmethod + def build(cls, queue_item: SessionQueueItem) -> "SessionStartedEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class SessionCompleteEvent(SessionEvent): + """Emitted when a session has completed all invocations""" + + __event_name__ = "session_complete" + + @classmethod + def build(cls, queue_item: SessionQueueItem) -> "SessionCompleteEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class SessionCanceledEvent(SessionEvent): + """Emitted when a 
session is canceled""" + + __event_name__ = "session_canceled" + + @classmethod + def build(cls, queue_item: SessionQueueItem) -> "SessionCanceledEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class QueueItemStatusChangedEvent(QueueItemEvent): + """Emitted when a queue item's status changes""" + + __event_name__ = "queue_item_status_changed" + + status: QUEUE_ITEM_STATUS = Field(description="The new status of the queue item") + error: Optional[str] = Field(default=None, description="The error message, if any") + created_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was created") + updated_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was last updated") + started_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was started") + completed_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was completed") + batch_status: BatchStatus = Field(description="The status of the batch") + queue_status: SessionQueueStatus = Field(description="The status of the queue") + + @classmethod + def build( + cls, queue_item: SessionQueueItem, batch_status: BatchStatus, queue_status: SessionQueueStatus + ) -> "QueueItemStatusChangedEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + status=queue_item.status, + error=queue_item.error, + created_at=str(queue_item.created_at) if queue_item.created_at else None, + updated_at=str(queue_item.updated_at) if queue_item.updated_at else None, + started_at=str(queue_item.started_at) if queue_item.started_at else None, + completed_at=str(queue_item.completed_at) if queue_item.completed_at else None, + batch_status=batch_status, + 
queue_status=queue_status, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class BatchEnqueuedEvent(QueueEvent): + """Emitted when a batch is enqueued""" + + __event_name__ = "batch_enqueued" + + batch_id: str = Field(description="The ID of the batch") + enqueued: int = Field(description="The number of invocations enqueued") + requested: int = Field( + description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)" + ) + priority: int = Field(description="The priority of the batch") + + @classmethod + def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent": + return cls( + queue_id=enqueue_result.queue_id, + batch_id=enqueue_result.batch.batch_id, + enqueued=enqueue_result.enqueued, + requested=enqueue_result.requested, + priority=enqueue_result.priority, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class QueueClearedEvent(QueueEvent): + """Emitted when a queue is cleared""" + + __event_name__ = "queue_cleared" + + @classmethod + def build(cls, queue_id: str) -> "QueueClearedEvent": + return cls(queue_id=queue_id) + + +class DownloadEvent(BaseEvent, ABC): + """Base class for events associated with a download""" + + __event_type__ = EventType.DOWNLOAD + __event_name__ = "download_event" + + source: str = Field(description="The source of the download") + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class DownloadStartedEvent(DownloadEvent): + """Emitted when a download is started""" + + __event_name__ = "download_started" + + download_path: str = Field(description="The local path where the download is saved") + + @classmethod + def build(cls, source: str, download_path: str) -> "DownloadStartedEvent": + return cls(source=source, download_path=download_path) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class DownloadProgressEvent(DownloadEvent): + """Emitted at intervals 
during a download""" + + __event_name__ = "download_progress" + + download_path: str = Field(description="The local path where the download is saved") + current_bytes: int = Field(description="The number of bytes downloaded so far") + total_bytes: int = Field(description="The total number of bytes to be downloaded") + + @classmethod + def build(cls, source: str, download_path: str, current_bytes: int, total_bytes: int) -> "DownloadProgressEvent": + return cls(source=source, download_path=download_path, current_bytes=current_bytes, total_bytes=total_bytes) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class DownloadCompleteEvent(DownloadEvent): + """Emitted when a download is completed""" + + __event_name__ = "download_complete" + + download_path: str = Field(description="The local path where the download is saved") + total_bytes: int = Field(description="The total number of bytes downloaded") + + @classmethod + def build(cls, source: str, download_path: str, total_bytes: int) -> "DownloadCompleteEvent": + return cls(source=source, download_path=download_path, total_bytes=total_bytes) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class DownloadCancelledEvent(DownloadEvent): + """Emitted when a download is cancelled""" + + __event_name__ = "download_cancelled" + + @classmethod + def build(cls, source: str) -> "DownloadCancelledEvent": + return cls(source=source) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class DownloadErrorEvent(DownloadEvent): + """Emitted when a download encounters an error""" + + __event_name__ = "download_error" + + error_type: str = Field(description="The type of error") + error: str = Field(description="The error message") + + @classmethod + def build(cls, source: str, error_type: str, error: str) -> "DownloadErrorEvent": + return cls(source=source, error_type=error_type, error=error) + + +class ModelEvent(BaseEvent, ABC): + """Base class for events 
associated with a model""" + + __event_type__ = EventType.MODEL + __event_name__ = "model_event" + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class ModelLoadStartedEvent(ModelEvent): + """Emitted when a model load is started""" + + __event_name__ = "model_load_started" + + config: AnyModelConfig = Field(description="The model's config") + submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any") + + @classmethod + def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadStartedEvent": + return cls(config=config, submodel_type=submodel_type) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class ModelLoadCompleteEvent(ModelEvent): + """Emitted when a model load is complete""" + + __event_name__ = "model_load_complete" + + config: AnyModelConfig = Field(description="The model's config") + submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any") + + @classmethod + def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadCompleteEvent": + return cls(config=config, submodel_type=submodel_type) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class ModelInstallDownloadProgressEvent(ModelEvent): + """Emitted at intervals while the install job is in progress (remote models only).""" + + __event_name__ = "model_install_download_progress" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + local_path: str = Field(description="Where model is downloading to") + bytes: int = Field(description="Number of bytes downloaded so far") + total_bytes: int = Field(description="Total size of download, including all files") + parts: list[dict[str, int | str]] = Field( + description="Progress of downloading URLs that comprise the model, if any" + ) + + 
@classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadProgressEvent": + parts: list[dict[str, str | int]] = [ + { + "url": str(x.source), + "local_path": str(x.download_path), + "bytes": x.bytes, + "total_bytes": x.total_bytes, + } + for x in job.download_parts + ] + return cls( + id=job.id, + source=str(job.source), + local_path=job.local_path.as_posix(), + parts=parts, + bytes=job.bytes, + total_bytes=job.total_bytes, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class ModelInstallDownloadsCompleteEvent(ModelEvent): + """Emitted once when all parts are downloaded, but before the probing and registration start.""" + + __event_name__ = "model_install_downloads_complete" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + + @classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadsCompleteEvent": + return cls(id=job.id, source=str(job.source)) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class ModelInstallStartedEvent(ModelEvent): + """Emitted once when an install job becomes active.""" + + __event_name__ = "model_install_started" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + + @classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallStartedEvent": + return cls(id=job.id, source=str(job.source)) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class ModelInstallCompleteEvent(ModelEvent): + """Emitted when an install job is completed successfully.""" + + __event_name__ = "model_install_complete" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + key: str = Field(description="Model config record key") + total_bytes: Optional[int] = Field(description="Size of the model (may be 
None for installation of a local path)") + + @classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallCompleteEvent": + assert job.config_out is not None + return cls(id=job.id, source=str(job.source), key=(job.config_out.key), total_bytes=job.total_bytes) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class ModelInstallCancelledEvent(ModelEvent): + """Emitted when an install job is cancelled.""" + + __event_name__ = "model_install_cancelled" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + + @classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallCancelledEvent": + return cls(id=job.id, source=str(job.source)) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class ModelInstallErrorEvent(ModelEvent): + """Emitted when an install job encounters an exception.""" + + __event_name__ = "model_install_error" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + error_type: str = Field(description="The name of the exception") + error: str = Field(description="A text description of the exception") + + @classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallErrorEvent": + assert job.error_type is not None + assert job.error is not None + return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error) + + +class BulkDownloadEvent(BaseEvent, ABC): + """Base class for events associated with a bulk image download""" + + __event_type__ = EventType.BULK_IMAGE_DOWNLOAD + __event_name__ = "bulk_image_download_event" + + bulk_download_id: str = Field(description="The ID of the bulk image download") + bulk_download_item_id: str = Field(description="The ID of the bulk image download item") + bulk_download_item_name: str = Field(description="The name of the bulk image download 
item") + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class BulkDownloadStartedEvent(BulkDownloadEvent): + """Emitted when a bulk image download is started""" + + __event_name__ = "bulk_download_started" + + @classmethod + def build( + cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + ) -> "BulkDownloadStartedEvent": + return cls( + bulk_download_id=bulk_download_id, + bulk_download_item_id=bulk_download_item_id, + bulk_download_item_name=bulk_download_item_name, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class BulkDownloadCompleteEvent(BulkDownloadEvent): + """Emitted when a bulk image download is complete""" + + __event_name__ = "bulk_download_complete" + + @classmethod + def build( + cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + ) -> "BulkDownloadCompleteEvent": + return cls( + bulk_download_id=bulk_download_id, + bulk_download_item_id=bulk_download_item_id, + bulk_download_item_name=bulk_download_item_name, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class BulkDownloadErrorEvent(BulkDownloadEvent): + """Emitted when a bulk image download encounters an error""" + + __event_name__ = "bulk_download_error" + + error: str = Field(description="The error message") + + @classmethod + def build( + cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str + ) -> "BulkDownloadErrorEvent": + return cls( + bulk_download_id=bulk_download_id, + bulk_download_item_id=bulk_download_item_id, + bulk_download_item_name=bulk_download_item_name, + error=error, + ) diff --git a/invokeai/app/services/events/events_fastapievents.py b/invokeai/app/services/events/events_fastapievents.py new file mode 100644 index 0000000000..a8317911cf --- /dev/null +++ b/invokeai/app/services/events/events_fastapievents.py @@ -0,0 +1,46 @@ +# Copyright (c) 2022 Kyle Schouviller 
(https://github.com/kyle0654) + +import asyncio +import threading +from queue import Empty, Queue + +from fastapi_events.dispatcher import dispatch + +from invokeai.app.services.events.events_common import ( + BaseEvent, +) + +from .events_base import EventServiceBase + + +class FastAPIEventService(EventServiceBase): + def __init__(self, event_handler_id: int) -> None: + self.event_handler_id = event_handler_id + self._queue = Queue[BaseEvent | None]() + self._stop_event = threading.Event() + asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event)) + + super().__init__() + + def stop(self, *args, **kwargs): + self._stop_event.set() + self._queue.put(None) + + def dispatch(self, event: BaseEvent) -> None: + self._queue.put(event) + + async def _dispatch_from_queue(self, stop_event: threading.Event): + """Get events on from the queue and dispatch them, from the correct thread""" + while not stop_event.is_set(): + try: + event = self._queue.get(block=False) + if not event: # Probably stopping + continue + dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False) + + except Empty: + await asyncio.sleep(0.1) + pass + + except asyncio.CancelledError as e: + raise e # Raise a proper error diff --git a/invokeai/app/services/model_install/__init__.py b/invokeai/app/services/model_install/__init__.py index 00a33c203e..941485a134 100644 --- a/invokeai/app/services/model_install/__init__.py +++ b/invokeai/app/services/model_install/__init__.py @@ -1,11 +1,13 @@ """Initialization file for model install service package.""" from .model_install_base import ( + ModelInstallServiceBase, +) +from .model_install_common import ( HFModelSource, InstallStatus, LocalModelSource, ModelInstallJob, - ModelInstallServiceBase, ModelSource, UnknownInstallJobException, URLModelSource, diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 0ea901fb46..6ee671062d 100644 --- 
a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -1,244 +1,19 @@ # Copyright 2023 Lincoln D. Stein and the InvokeAI development team """Baseclass definitions for the model installer.""" -import re -import traceback from abc import ABC, abstractmethod -from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Set, Union +from typing import Any, Dict, List, Optional, Union -from pydantic import BaseModel, Field, PrivateAttr, field_validator from pydantic.networks import AnyHttpUrl -from typing_extensions import Annotated from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase +from invokeai.app.services.download import DownloadQueueServiceBase from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker +from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource from invokeai.app.services.model_records import ModelRecordServiceBase -from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant -from invokeai.backend.model_manager.config import ModelSourceType -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata - - -class InstallStatus(str, Enum): - """State of an install job running in the background.""" - - WAITING = "waiting" # waiting to be dequeued - DOWNLOADING = "downloading" # downloading of model files in process - DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run - RUNNING = "running" # being processed - COMPLETED = "completed" # finished running - ERROR = "error" # terminated with an error message - CANCELLED = "cancelled" # terminated with an error message - - -class ModelInstallPart(BaseModel): - url: AnyHttpUrl - path: Path - bytes: int = 0 - total_bytes: int = 0 - - -class 
UnknownInstallJobException(Exception): - """Raised when the status of an unknown job is requested.""" - - -class StringLikeSource(BaseModel): - """ - Base class for model sources, implements functions that lets the source be sorted and indexed. - - These shenanigans let this stuff work: - - source1 = LocalModelSource(path='C:/users/mort/foo.safetensors') - mydict = {source1: 'model 1'} - assert mydict['C:/users/mort/foo.safetensors'] == 'model 1' - assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1' - - source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors')) - assert source1 == source2 - assert source1 == 'C:/users/mort/foo.safetensors' - """ - - def __hash__(self) -> int: - """Return hash of the path field, for indexing.""" - return hash(str(self)) - - def __lt__(self, other: object) -> int: - """Return comparison of the stringified version, for sorting.""" - return str(self) < str(other) - - def __eq__(self, other: object) -> bool: - """Return equality on the stringified version.""" - if isinstance(other, Path): - return str(self) == other.as_posix() - else: - return str(self) == str(other) - - -class LocalModelSource(StringLikeSource): - """A local file or directory path.""" - - path: str | Path - inplace: Optional[bool] = False - type: Literal["local"] = "local" - - # these methods allow the source to be used in a string-like way, - # for example as an index into a dict - def __str__(self) -> str: - """Return string version of path when string rep needed.""" - return Path(self.path).as_posix() - - -class HFModelSource(StringLikeSource): - """ - A HuggingFace repo_id with optional variant, sub-folder and access token. - Note that the variant option, if not provided to the constructor, will default to fp16, which is - what people (almost) always want. 
- """ - - repo_id: str - variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16 - subfolder: Optional[Path] = None - access_token: Optional[str] = None - type: Literal["hf"] = "hf" - - @field_validator("repo_id") - @classmethod - def proper_repo_id(cls, v: str) -> str: # noqa D102 - if not re.match(r"^([.\w-]+/[.\w-]+)$", v): - raise ValueError(f"{v}: invalid repo_id format") - return v - - def __str__(self) -> str: - """Return string version of repoid when string rep needed.""" - base: str = self.repo_id - if self.variant: - base += f":{self.variant or ''}" - if self.subfolder: - base += f":{self.subfolder}" - return base - - -class URLModelSource(StringLikeSource): - """A generic URL point to a checkpoint file.""" - - url: AnyHttpUrl - access_token: Optional[str] = None - type: Literal["url"] = "url" - - def __str__(self) -> str: - """Return string version of the url when string rep needed.""" - return str(self.url) - - -ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")] - -MODEL_SOURCE_TO_TYPE_MAP = { - URLModelSource: ModelSourceType.Url, - HFModelSource: ModelSourceType.HFRepoID, - LocalModelSource: ModelSourceType.Path, -} - - -class ModelInstallJob(BaseModel): - """Object that tracks the current status of an install request.""" - - id: int = Field(description="Unique ID for this job") - status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process") - error_reason: Optional[str] = Field(default=None, description="Information about why the job failed") - config_in: Dict[str, Any] = Field( - default_factory=dict, description="Configuration information (e.g. 'description') to apply to model." - ) - config_out: Optional[AnyModelConfig] = Field( - default=None, description="After successful installation, this will hold the configuration object." 
- ) - inplace: bool = Field( - default=False, description="Leave model in its current location; otherwise install under models directory" - ) - source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model") - local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source") - bytes: int = Field( - default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)" - ) - total_bytes: int = Field(default=0, description="Total size of the model to be installed") - source_metadata: Optional[AnyModelRepoMetadata] = Field( - default=None, description="Metadata provided by the model source" - ) - download_parts: Set[DownloadJob] = Field( - default_factory=set, description="Download jobs contributing to this install" - ) - error: Optional[str] = Field( - default=None, description="On an error condition, this field will contain the text of the exception" - ) - error_traceback: Optional[str] = Field( - default=None, description="On an error condition, this field will contain the exception traceback" - ) - # internal flags and transitory settings - _install_tmpdir: Optional[Path] = PrivateAttr(default=None) - _exception: Optional[Exception] = PrivateAttr(default=None) - - def set_error(self, e: Exception) -> None: - """Record the error and traceback from an exception.""" - self._exception = e - self.error = str(e) - self.error_traceback = self._format_error(e) - self.status = InstallStatus.ERROR - self.error_reason = self._exception.__class__.__name__ if self._exception else None - - def cancel(self) -> None: - """Call to cancel the job.""" - self.status = InstallStatus.CANCELLED - - @property - def error_type(self) -> Optional[str]: - """Class name of the exception that led to status==ERROR.""" - return self._exception.__class__.__name__ if self._exception else None - - def _format_error(self, exception: Exception) -> str: - """Error traceback.""" - return 
"".join(traceback.format_exception(exception)) - - @property - def cancelled(self) -> bool: - """Set status to CANCELLED.""" - return self.status == InstallStatus.CANCELLED - - @property - def errored(self) -> bool: - """Return true if job has errored.""" - return self.status == InstallStatus.ERROR - - @property - def waiting(self) -> bool: - """Return true if job is waiting to run.""" - return self.status == InstallStatus.WAITING - - @property - def downloading(self) -> bool: - """Return true if job is downloading.""" - return self.status == InstallStatus.DOWNLOADING - - @property - def downloads_done(self) -> bool: - """Return true if job's downloads ae done.""" - return self.status == InstallStatus.DOWNLOADS_DONE - - @property - def running(self) -> bool: - """Return true if job is running.""" - return self.status == InstallStatus.RUNNING - - @property - def complete(self) -> bool: - """Return true if job completed without errors.""" - return self.status == InstallStatus.COMPLETED - - @property - def in_terminal_state(self) -> bool: - """Return true if job is in a terminal state.""" - return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED] +from invokeai.backend.model_manager.config import AnyModelConfig class ModelInstallServiceBase(ABC): @@ -282,7 +57,7 @@ class ModelInstallServiceBase(ABC): @property @abstractmethod - def event_bus(self) -> Optional[EventServiceBase]: + def event_bus(self) -> Optional["EventServiceBase"]: """Return the event service base object associated with the installer.""" @abstractmethod diff --git a/invokeai/app/services/model_install/model_install_common.py b/invokeai/app/services/model_install/model_install_common.py new file mode 100644 index 0000000000..2de1db5474 --- /dev/null +++ b/invokeai/app/services/model_install/model_install_common.py @@ -0,0 +1,231 @@ +import re +import traceback +from enum import Enum +from pathlib import Path +from typing import Any, Dict, Literal, Optional, Set, 
Union + +from pydantic import BaseModel, Field, PrivateAttr, field_validator +from pydantic.networks import AnyHttpUrl +from typing_extensions import Annotated + +from invokeai.app.services.download import DownloadJob +from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant +from invokeai.backend.model_manager.config import ModelSourceType +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata + + +class InstallStatus(str, Enum): + """State of an install job running in the background.""" + + WAITING = "waiting" # waiting to be dequeued + DOWNLOADING = "downloading" # downloading of model files in process + DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run + RUNNING = "running" # being processed + COMPLETED = "completed" # finished running + ERROR = "error" # terminated with an error message + CANCELLED = "cancelled" # terminated with an error message + + +class ModelInstallPart(BaseModel): + url: AnyHttpUrl + path: Path + bytes: int = 0 + total_bytes: int = 0 + + +class UnknownInstallJobException(Exception): + """Raised when the status of an unknown job is requested.""" + + +class StringLikeSource(BaseModel): + """ + Base class for model sources, implements functions that lets the source be sorted and indexed. 
+ + These shenanigans let this stuff work: + + source1 = LocalModelSource(path='C:/users/mort/foo.safetensors') + mydict = {source1: 'model 1'} + assert mydict['C:/users/mort/foo.safetensors'] == 'model 1' + assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1' + + source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors')) + assert source1 == source2 + assert source1 == 'C:/users/mort/foo.safetensors' + """ + + def __hash__(self) -> int: + """Return hash of the path field, for indexing.""" + return hash(str(self)) + + def __lt__(self, other: object) -> int: + """Return comparison of the stringified version, for sorting.""" + return str(self) < str(other) + + def __eq__(self, other: object) -> bool: + """Return equality on the stringified version.""" + if isinstance(other, Path): + return str(self) == other.as_posix() + else: + return str(self) == str(other) + + +class LocalModelSource(StringLikeSource): + """A local file or directory path.""" + + path: str | Path + inplace: Optional[bool] = False + type: Literal["local"] = "local" + + # these methods allow the source to be used in a string-like way, + # for example as an index into a dict + def __str__(self) -> str: + """Return string version of path when string rep needed.""" + return Path(self.path).as_posix() + + +class HFModelSource(StringLikeSource): + """ + A HuggingFace repo_id with optional variant, sub-folder and access token. + Note that the variant option, if not provided to the constructor, will default to fp16, which is + what people (almost) always want. 
+ """ + + repo_id: str + variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16 + subfolder: Optional[Path] = None + access_token: Optional[str] = None + type: Literal["hf"] = "hf" + + @field_validator("repo_id") + @classmethod + def proper_repo_id(cls, v: str) -> str: # noqa D102 + if not re.match(r"^([.\w-]+/[.\w-]+)$", v): + raise ValueError(f"{v}: invalid repo_id format") + return v + + def __str__(self) -> str: + """Return string version of repoid when string rep needed.""" + base: str = self.repo_id + base += f":{self.variant or ''}" + base += f":{self.subfolder}" if self.subfolder else "" + return base + + +class URLModelSource(StringLikeSource): + """A generic URL point to a checkpoint file.""" + + url: AnyHttpUrl + access_token: Optional[str] = None + type: Literal["url"] = "url" + + def __str__(self) -> str: + """Return string version of the url when string rep needed.""" + return str(self.url) + + +ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")] + +MODEL_SOURCE_TO_TYPE_MAP = { + URLModelSource: ModelSourceType.Url, + HFModelSource: ModelSourceType.HFRepoID, + LocalModelSource: ModelSourceType.Path, +} + + +class ModelInstallJob(BaseModel): + """Object that tracks the current status of an install request.""" + + id: int = Field(description="Unique ID for this job") + status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process") + error_reason: Optional[str] = Field(default=None, description="Information about why the job failed") + config_in: Dict[str, Any] = Field( + default_factory=dict, description="Configuration information (e.g. 'description') to apply to model." + ) + config_out: Optional[AnyModelConfig] = Field( + default=None, description="After successful installation, this will hold the configuration object." 
+ ) + inplace: bool = Field( + default=False, description="Leave model in its current location; otherwise install under models directory" + ) + source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model") + local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source") + bytes: int = Field( + default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)" + ) + total_bytes: int = Field(default=0, description="Total size of the model to be installed") + source_metadata: Optional[AnyModelRepoMetadata] = Field( + default=None, description="Metadata provided by the model source" + ) + download_parts: Set[DownloadJob] = Field( + default_factory=set, description="Download jobs contributing to this install" + ) + error: Optional[str] = Field( + default=None, description="On an error condition, this field will contain the text of the exception" + ) + error_traceback: Optional[str] = Field( + default=None, description="On an error condition, this field will contain the exception traceback" + ) + # internal flags and transitory settings + _install_tmpdir: Optional[Path] = PrivateAttr(default=None) + _exception: Optional[Exception] = PrivateAttr(default=None) + + def set_error(self, e: Exception) -> None: + """Record the error and traceback from an exception.""" + self._exception = e + self.error = str(e) + self.error_traceback = self._format_error(e) + self.status = InstallStatus.ERROR + self.error_reason = self._exception.__class__.__name__ if self._exception else None + + def cancel(self) -> None: + """Call to cancel the job.""" + self.status = InstallStatus.CANCELLED + + @property + def error_type(self) -> Optional[str]: + """Class name of the exception that led to status==ERROR.""" + return self._exception.__class__.__name__ if self._exception else None + + def _format_error(self, exception: Exception) -> str: + """Error traceback.""" + return 
"".join(traceback.format_exception(exception)) + + @property + def cancelled(self) -> bool: + """Set status to CANCELLED.""" + return self.status == InstallStatus.CANCELLED + + @property + def errored(self) -> bool: + """Return true if job has errored.""" + return self.status == InstallStatus.ERROR + + @property + def waiting(self) -> bool: + """Return true if job is waiting to run.""" + return self.status == InstallStatus.WAITING + + @property + def downloading(self) -> bool: + """Return true if job is downloading.""" + return self.status == InstallStatus.DOWNLOADING + + @property + def downloads_done(self) -> bool: + """Return true if job's downloads ae done.""" + return self.status == InstallStatus.DOWNLOADS_DONE + + @property + def running(self) -> bool: + """Return true if job is running.""" + return self.status == InstallStatus.RUNNING + + @property + def complete(self) -> bool: + """Return true if job completed without errors.""" + return self.status == InstallStatus.COMPLETED + + @property + def in_terminal_state(self) -> bool: + """Return true if job is in a terminal state.""" + return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED] diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 6eb9549ef0..df060caff3 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -10,7 +10,7 @@ from pathlib import Path from queue import Empty, Queue from shutil import copyfile, copytree, move, rmtree from tempfile import mkdtemp -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import torch import yaml @@ -20,8 +20,8 @@ from requests import Session from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, 
TqdmProgress -from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker +from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase from invokeai.app.services.model_records.model_records_base import ModelRecordChanges from invokeai.backend.model_manager.config import ( @@ -45,13 +45,12 @@ from invokeai.backend.util import InvokeAILogger from invokeai.backend.util.catch_sigint import catch_sigint from invokeai.backend.util.devices import TorchDevice -from .model_install_base import ( +from .model_install_common import ( MODEL_SOURCE_TO_TYPE_MAP, HFModelSource, InstallStatus, LocalModelSource, ModelInstallJob, - ModelInstallServiceBase, ModelSource, StringLikeSource, URLModelSource, @@ -59,6 +58,9 @@ from .model_install_base import ( TMPDIR_PREFIX = "tmpinstall_" +if TYPE_CHECKING: + from invokeai.app.services.events.events_base import EventServiceBase + class ModelInstallService(ModelInstallServiceBase): """class for InvokeAI model installation.""" @@ -68,7 +70,7 @@ class ModelInstallService(ModelInstallServiceBase): app_config: InvokeAIAppConfig, record_store: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, - event_bus: Optional[EventServiceBase] = None, + event_bus: Optional["EventServiceBase"] = None, session: Optional[Session] = None, ): """ @@ -104,7 +106,7 @@ class ModelInstallService(ModelInstallServiceBase): return self._record_store @property - def event_bus(self) -> Optional[EventServiceBase]: # noqa D102 + def event_bus(self) -> Optional["EventServiceBase"]: # noqa D102 return self._event_bus # make the invoker optional here because we don't need it and it @@ -855,35 +857,17 @@ class ModelInstallService(ModelInstallServiceBase): job.status = InstallStatus.RUNNING self._logger.info(f"Model install started: {job.source}") if self._event_bus: - 
self._event_bus.emit_model_install_running(str(job.source)) + self._event_bus.emit_model_install_started(job) def _signal_job_downloading(self, job: ModelInstallJob) -> None: if self._event_bus: - parts: List[Dict[str, str | int]] = [ - { - "url": str(x.source), - "local_path": str(x.download_path), - "bytes": x.bytes, - "total_bytes": x.total_bytes, - } - for x in job.download_parts - ] - assert job.bytes is not None - assert job.total_bytes is not None - self._event_bus.emit_model_install_downloading( - str(job.source), - local_path=job.local_path.as_posix(), - parts=parts, - bytes=job.bytes, - total_bytes=job.total_bytes, - id=job.id, - ) + self._event_bus.emit_model_install_download_progress(job) def _signal_job_downloads_done(self, job: ModelInstallJob) -> None: job.status = InstallStatus.DOWNLOADS_DONE self._logger.info(f"Model download complete: {job.source}") if self._event_bus: - self._event_bus.emit_model_install_downloads_done(str(job.source)) + self._event_bus.emit_model_install_downloads_complete(job) def _signal_job_completed(self, job: ModelInstallJob) -> None: job.status = InstallStatus.COMPLETED @@ -891,24 +875,19 @@ class ModelInstallService(ModelInstallServiceBase): self._logger.info(f"Model install complete: {job.source}") self._logger.debug(f"{job.local_path} registered key {job.config_out.key}") if self._event_bus: - assert job.local_path is not None - assert job.config_out is not None - key = job.config_out.key - self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id) + self._event_bus.emit_model_install_complete(job) def _signal_job_errored(self, job: ModelInstallJob) -> None: self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}") if self._event_bus: - error_type = job.error_type - error = job.error - assert error_type is not None - assert error is not None - self._event_bus.emit_model_install_error(str(job.source), error_type, error, id=job.id) + assert job.error_type is not None + 
assert job.error is not None + self._event_bus.emit_model_install_error(job) def _signal_job_cancelled(self, job: ModelInstallJob) -> None: self._logger.info(f"Model install canceled: {job.source}") if self._event_bus: - self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id) + self._event_bus.emit_model_install_cancelled(job) @staticmethod def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase: diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index cc80333e93..9d75aafde1 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -4,7 +4,6 @@ from abc import ABC, abstractmethod from typing import Optional -from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager.load import LoadedModel from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase @@ -15,18 +14,12 @@ class ModelLoadServiceBase(ABC): """Wrapper around AnyModelLoader.""" @abstractmethod - def load_model( - self, - model_config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: + def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ Given a model's configuration, load it and return the LoadedModel object. :param model_config: Model configuration record (as returned by ModelRecordBase.get_model()) :param submodel: For main (pipeline models), the submodel to fetch. 
- :param context_data: Invocation context data used for event reporting """ @property diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index 21d3c56f36..1d6423af5a 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -5,7 +5,6 @@ from typing import Optional, Type from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.invoker import Invoker -from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager.load import ( LoadedModel, @@ -51,25 +50,15 @@ class ModelLoadService(ModelLoadServiceBase): """Return the checkpoint convert cache used by this loader.""" return self._convert_cache - def load_model( - self, - model_config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: + def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ Given a model's configuration, load it and return the LoadedModel object. :param model_config: Model configuration record (as returned by ModelRecordBase.get_model()) :param submodel: For main (pipeline models), the submodel to fetch. 
- :param context: Invocation context used for event reporting """ - if context_data: - self._emit_load_event( - context_data=context_data, - model_config=model_config, - submodel_type=submodel_type, - ) + + self._invoker.services.events.emit_model_load_started(model_config, submodel_type) implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore loaded_model: LoadedModel = implementation( @@ -79,40 +68,6 @@ class ModelLoadService(ModelLoadServiceBase): convert_cache=self._convert_cache, ).load_model(model_config, submodel_type) - if context_data: - self._emit_load_event( - context_data=context_data, - model_config=model_config, - submodel_type=submodel_type, - loaded=True, - ) + self._invoker.services.events.emit_model_load_complete(model_config, submodel_type) + return loaded_model - - def _emit_load_event( - self, - context_data: InvocationContextData, - model_config: AnyModelConfig, - loaded: Optional[bool] = False, - submodel_type: Optional[SubModelType] = None, - ) -> None: - if not self._invoker: - return - - if not loaded: - self._invoker.services.events.emit_model_load_started( - queue_id=context_data.queue_item.queue_id, - queue_item_id=context_data.queue_item.item_id, - queue_batch_id=context_data.queue_item.batch_id, - graph_execution_state_id=context_data.queue_item.session_id, - model_config=model_config, - submodel_type=submodel_type, - ) - else: - self._invoker.services.events.emit_model_load_completed( - queue_id=context_data.queue_item.queue_id, - queue_item_id=context_data.queue_item.item_id, - queue_batch_id=context_data.queue_item.batch_id, - graph_execution_state_id=context_data.queue_item.session_id, - model_config=model_config, - submodel_type=submodel_type, - ) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 894996b1e6..0b4e4927e6 100644 ---
a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -4,11 +4,15 @@ from threading import BoundedSemaphore, Thread from threading import Event as ThreadEvent from typing import Optional -from fastapi_events.handlers.local import local_handler -from fastapi_events.typing import Event as FastAPIEvent - from invokeai.app.invocations.baseinvocation import BaseInvocation -from invokeai.app.services.events.events_base import EventServiceBase +from invokeai.app.services.events.events_common import ( + BatchEnqueuedEvent, + FastAPIEvent, + QueueClearedEvent, + QueueEvent, + SessionCanceledEvent, + register_events, +) from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError from invokeai.app.services.session_processor.session_processor_common import CanceledException from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem @@ -31,8 +35,6 @@ class DefaultSessionProcessor(SessionProcessorBase): self._poll_now_event = ThreadEvent() self._cancel_event = ThreadEvent() - local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event) - self._thread_limit = thread_limit self._thread_semaphore = BoundedSemaphore(thread_limit) self._polling_interval = polling_interval @@ -49,6 +51,8 @@ class DefaultSessionProcessor(SessionProcessorBase): else None ) + register_events({SessionCanceledEvent, QueueClearedEvent, BatchEnqueuedEvent}, self._on_queue_event) + self._thread = Thread( name="session_processor", target=self._process, @@ -67,24 +71,23 @@ class DefaultSessionProcessor(SessionProcessorBase): def _poll_now(self) -> None: self._poll_now_event.set() - async def _on_queue_event(self, event: FastAPIEvent) -> None: - event_name = event[1]["event"] - + async def _on_queue_event(self, event: FastAPIEvent[QueueEvent]) -> None: + _event_name, payload = event if ( - event_name == "session_canceled" + 
isinstance(payload, SessionCanceledEvent) and self._queue_item - and self._queue_item.item_id == event[1]["data"]["queue_item_id"] + and self._queue_item.item_id == payload.item_id ): self._cancel_event.set() self._poll_now() elif ( - event_name == "queue_cleared" + isinstance(payload, QueueClearedEvent) and self._queue_item - and self._queue_item.queue_id == event[1]["data"]["queue_id"] + and self._queue_item.queue_id == payload.queue_id ): self._cancel_event.set() self._poll_now() - elif event_name == "batch_enqueued": + elif isinstance(payload, BatchEnqueuedEvent): self._poll_now() - elif event_name == "queue_item_status_changed" and event[1]["data"]["queue_item"]["status"] in [ + elif _event_name == "queue_item_status_changed" and event[1]["data"]["queue_item"]["status"] in [ "completed", @@ -139,6 +142,7 @@ class DefaultSessionProcessor(SessionProcessorBase): poll_now_event.wait(self._polling_interval) continue + self._invoker.services.events.emit_session_started(self._queue_item) self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}") cancel_event.clear() @@ -154,15 +158,7 @@ # get the source node id to provide to clients (the prepared node id is not as useful) source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id] - # Send starting event - self._invoker.services.events.emit_invocation_started( - queue_batch_id=self._queue_item.batch_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session_id, - node=self._invocation.model_dump(), - source_node_id=source_invocation_id, - ) + self._invoker.services.events.emit_invocation_started(self._queue_item, self._invocation) # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph try: @@ -189,15 +185,9 @@ class DefaultSessionProcessor(SessionProcessorBase): # Save outputs and history self._queue_item.session.complete(self._invocation.id, outputs) - # Send complete event + #
Dispatch invocation complete event self._invoker.services.events.emit_invocation_complete( - queue_batch_id=self._queue_item.batch_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session.id, - node=self._invocation.model_dump(), - source_node_id=source_invocation_id, - result=outputs.model_dump(), + self._queue_item, self._invocation, outputs ) except KeyboardInterrupt: @@ -229,12 +219,8 @@ class DefaultSessionProcessor(SessionProcessorBase): # Send error event self._invoker.services.events.emit_invocation_error( - queue_batch_id=self._queue_item.session_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session.id, - node=self._invocation.model_dump(), - source_node_id=source_invocation_id, + queue_item=self._queue_item, + invocation=self._invocation, error_type=e.__class__.__name__, error=error, user_id=None, @@ -245,12 +231,7 @@ class DefaultSessionProcessor(SessionProcessorBase): # The session is complete if the all invocations are complete or there was an error if self._queue_item.session.is_complete() or cancel_event.is_set(): # Send complete event - self._invoker.services.events.emit_graph_execution_complete( - queue_batch_id=self._queue_item.batch_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session.id, - ) + self._invoker.services.events.emit_session_complete(self._queue_item) # If we are profiling, stop the profiler and dump the profile & stats if self._profiler: profile_path = self._profiler.stop() diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index ffcd7c40ca..7696e3bd04 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -2,10 +2,13 @@ 
import sqlite3 import threading from typing import Optional, Union, cast -from fastapi_events.handlers.local import local_handler -from fastapi_events.typing import Event as FastAPIEvent - -from invokeai.app.services.events.events_base import EventServiceBase +from invokeai.app.services.events.events_common import ( + FastAPIEvent, + InvocationErrorEvent, + SessionCanceledEvent, + SessionCompleteEvent, + register_events, +) from invokeai.app.services.invoker import Invoker from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase from invokeai.app.services.session_queue.session_queue_common import ( @@ -41,7 +44,11 @@ class SqliteSessionQueue(SessionQueueBase): self.__invoker = invoker self._set_in_progress_to_canceled() prune_result = self.prune(DEFAULT_QUEUE_ID) - local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event) + + register_events(events={InvocationErrorEvent}, func=self._handle_error_event) + register_events(events={SessionCompleteEvent}, func=self._handle_complete_event) + register_events(events={SessionCanceledEvent}, func=self._handle_cancel_event) + if prune_result.deleted > 0: self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items") @@ -51,51 +58,35 @@ class SqliteSessionQueue(SessionQueueBase): self.__conn = db.conn self.__cursor = self.__conn.cursor() - def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool: - return event[1]["event"] in match_in - - async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent: - event_name = event[1]["event"] - - # This was a match statement, but match is not supported on python 3.9 - if event_name == "graph_execution_state_complete": - await self._handle_complete_event(event) - elif event_name == "invocation_error": - await self._handle_error_event(event) - elif event_name == "session_canceled": - await self._handle_cancel_event(event) - return event - - async def 
_handle_complete_event(self, event: FastAPIEvent) -> None: + async def _handle_complete_event(self, event: FastAPIEvent[SessionCompleteEvent]) -> None: try: - item_id = event[1]["data"]["queue_item_id"] # When a queue item has an error, we get an error event, then a completed event. # Mark the queue item completed only if it isn't already marked completed, e.g. # by a previously-handled error event. - queue_item = self.get_queue_item(item_id) - if queue_item.status not in ["completed", "failed", "canceled"]: - queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed") - except SessionQueueItemNotFoundError: - return + _event_name, payload = event - async def _handle_error_event(self, event: FastAPIEvent) -> None: + queue_item = self.get_queue_item(payload.item_id) + if queue_item.status not in ["completed", "failed", "canceled"]: + self._set_queue_item_status(item_id=payload.item_id, status="completed") + except SessionQueueItemNotFoundError: + pass + + async def _handle_error_event(self, event: FastAPIEvent[InvocationErrorEvent]) -> None: try: - item_id = event[1]["data"]["queue_item_id"] - error = event[1]["data"]["error"] - queue_item = self.get_queue_item(item_id) + _event_name, payload = event # always set to failed if have an error, even if previously the item was marked completed or canceled - queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error) + self._set_queue_item_status(item_id=payload.item_id, status="failed", error=payload.error) except SessionQueueItemNotFoundError: - return + pass - async def _handle_cancel_event(self, event: FastAPIEvent) -> None: + async def _handle_cancel_event(self, event: FastAPIEvent[SessionCanceledEvent]) -> None: try: - item_id = event[1]["data"]["queue_item_id"] - queue_item = self.get_queue_item(item_id) + _event_name, payload = event + queue_item = self.get_queue_item(payload.item_id) if queue_item.status not in ["completed", "failed", 
"canceled"]: - queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled") + self._set_queue_item_status(item_id=payload.item_id, status="canceled") except SessionQueueItemNotFoundError: - return + pass def _set_in_progress_to_canceled(self) -> None: """ @@ -292,11 +283,7 @@ class SqliteSessionQueue(SessionQueueBase): queue_item = self.get_queue_item(item_id) batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id) queue_status = self.get_queue_status(queue_id=queue_item.queue_id) - self.__invoker.services.events.emit_queue_item_status_changed( - session_queue_item=queue_item, - batch_status=batch_status, - queue_status=queue_status, - ) + self.__invoker.services.events.emit_queue_item_status_changed(queue_item, batch_status, queue_status) return queue_item def is_empty(self, queue_id: str) -> IsEmptyResult: @@ -429,12 +416,7 @@ class SqliteSessionQueue(SessionQueueBase): if queue_item.status not in ["canceled", "failed", "completed"]: status = "failed" if error is not None else "canceled" queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error) # type: ignore [arg-type] # mypy seems to not narrow the Literals here - self.__invoker.services.events.emit_session_canceled( - queue_item_id=queue_item.item_id, - queue_id=queue_item.queue_id, - queue_batch_id=queue_item.batch_id, - graph_execution_state_id=queue_item.session_id, - ) + self.__invoker.services.events.emit_session_canceled(queue_item) return queue_item def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult: @@ -470,18 +452,11 @@ class SqliteSessionQueue(SessionQueueBase): ) self.__conn.commit() if current_queue_item is not None and current_queue_item.batch_id in batch_ids: - self.__invoker.services.events.emit_session_canceled( - queue_item_id=current_queue_item.item_id, - queue_id=current_queue_item.queue_id, - queue_batch_id=current_queue_item.batch_id, - 
graph_execution_state_id=current_queue_item.session_id, - ) + self.__invoker.services.events.emit_session_canceled(current_queue_item) batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id) queue_status = self.get_queue_status(queue_id=queue_id) self.__invoker.services.events.emit_queue_item_status_changed( - session_queue_item=current_queue_item, - batch_status=batch_status, - queue_status=queue_status, + current_queue_item, batch_status, queue_status ) except Exception: self.__conn.rollback() @@ -521,18 +496,11 @@ class SqliteSessionQueue(SessionQueueBase): ) self.__conn.commit() if current_queue_item is not None and current_queue_item.queue_id == queue_id: - self.__invoker.services.events.emit_session_canceled( - queue_item_id=current_queue_item.item_id, - queue_id=current_queue_item.queue_id, - queue_batch_id=current_queue_item.batch_id, - graph_execution_state_id=current_queue_item.session_id, - ) + self.__invoker.services.events.emit_session_canceled(current_queue_item) batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id) queue_status = self.get_queue_status(queue_id=queue_id) self.__invoker.services.events.emit_queue_item_status_changed( - session_queue_item=current_queue_item, - batch_status=batch_status, - queue_status=queue_status, + current_queue_item, batch_status, queue_status ) except Exception: self.__conn.rollback() diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index de31a42665..f977a5ba9c 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -353,11 +353,11 @@ class ModelsInterface(InvocationContextInterface): if isinstance(identifier, str): model = self._services.model_manager.store.get_model(identifier) - return self._services.model_manager.load.load_model(model, submodel_type, self._data) + return 
self._services.model_manager.load.load_model(model, submodel_type) else: _submodel_type = submodel_type or identifier.submodel_type model = self._services.model_manager.store.get_model(identifier.key) - return self._services.model_manager.load.load_model(model, _submodel_type, self._data) + return self._services.model_manager.load.load_model(model, _submodel_type) def load_by_attrs( self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None @@ -382,7 +382,7 @@ class ModelsInterface(InvocationContextInterface): if len(configs) > 1: raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}") - return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data) + return self._services.model_manager.load.load_model(configs[0], submodel_type) def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig: """Gets a model's config. diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index 8cb59f5b3a..1bbd6bc8d0 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -113,15 +113,10 @@ def stable_diffusion_step_callback( dataURL = image_to_dataURL(image, image_format="JPEG") - events.emit_generator_progress( - queue_id=context_data.queue_item.queue_id, - queue_item_id=context_data.queue_item.item_id, - queue_batch_id=context_data.queue_item.batch_id, - graph_execution_state_id=context_data.queue_item.session_id, - node_id=context_data.invocation.id, - source_node_id=context_data.source_invocation_id, - progress_image=ProgressImage(width=width, height=height, dataURL=dataURL), - step=intermediate_state.step, - order=intermediate_state.order, - total_steps=intermediate_state.total_steps, + events.emit_invocation_denoise_progress( + context_data.queue_item, + context_data.invocation, + intermediate_state.step, + intermediate_state.total_steps * intermediate_state.order, + 
ProgressImage(dataURL=dataURL, width=width, height=height), ) diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index c755d3c491..cf7ebe8d29 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -14,10 +14,12 @@ from pydantic_core import Url from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.model_install import ( + ModelInstallServiceBase, +) +from invokeai.app.services.model_install.model_install_common import ( InstallStatus, LocalModelSource, ModelInstallJob, - ModelInstallServiceBase, URLModelSource, ) from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException