diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index aa16974a8f..4e8103d8d3 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -18,6 +18,7 @@ from ..services.boards.boards_default import BoardService from ..services.bulk_download.bulk_download_default import BulkDownloadService from ..services.config import InvokeAIAppConfig from ..services.download import DownloadQueueService +from ..services.events.events_fastapievents import FastAPIEventService from ..services.image_files.image_files_disk import DiskImageFileStorage from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage from ..services.images.images_default import ImageService @@ -33,7 +34,6 @@ from ..services.session_processor.session_processor_default import DefaultSessio from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue from ..services.urls.urls_default import LocalUrlService from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage -from .events import FastAPIEventService # TODO: is there a better way to achieve this? diff --git a/invokeai/app/api/events.py b/invokeai/app/api/events.py deleted file mode 100644 index 2ac07e6dfe..0000000000 --- a/invokeai/app/api/events.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -import asyncio -import threading -from queue import Empty, Queue -from typing import Any - -from fastapi_events.dispatcher import dispatch - -from ..services.events.events_base import EventServiceBase - - -class FastAPIEventService(EventServiceBase): - event_handler_id: int - __queue: Queue - __stop_event: threading.Event - - def __init__(self, event_handler_id: int) -> None: - self.event_handler_id = event_handler_id - self.__queue = Queue() - self.__stop_event = threading.Event() - asyncio.create_task(self.__dispatch_from_queue(stop_event=self.__stop_event)) - - super().__init__() - - def stop(self, *args, **kwargs): - self.__stop_event.set() - self.__queue.put(None) - - def dispatch(self, event_name: str, payload: Any) -> None: - self.__queue.put({"event_name": event_name, "payload": payload}) - - async def __dispatch_from_queue(self, stop_event: threading.Event): - """Get events on from the queue and dispatch them, from the correct thread""" - while not stop_event.is_set(): - try: - event = self.__queue.get(block=False) - if not event: # Probably stopping - continue - - dispatch( - event.get("event_name"), - payload=event.get("payload"), - middleware_id=self.event_handler_id, - ) - - except Empty: - await asyncio.sleep(0.1) - pass - - except asyncio.CancelledError as e: - raise e # Raise a proper error diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index 1ba3e30e07..b1221f7a34 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -17,7 +17,7 @@ from starlette.exceptions import HTTPException from typing_extensions import Annotated from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException -from invokeai.app.services.model_install import ModelInstallJob +from invokeai.app.services.model_install.model_install_common import ModelInstallJob from invokeai.app.services.model_records import ( DuplicateModelException, InvalidModelException, diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index 463545d9bc..972d222f47 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -1,66 +1,119 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) +from typing import Any + from fastapi import FastAPI -from fastapi_events.handlers.local import local_handler -from fastapi_events.typing import Event +from pydantic import BaseModel from socketio import ASGIApp, AsyncServer -from ..services.events.events_base import EventServiceBase +from invokeai.app.services.events.events_common import ( + BatchEnqueuedEvent, + BulkDownloadCompleteEvent, + BulkDownloadErrorEvent, + BulkDownloadEvent, + BulkDownloadStartedEvent, + FastAPIEvent, + InvocationCompleteEvent, + InvocationDenoiseProgressEvent, + InvocationErrorEvent, + InvocationStartedEvent, + ModelEvent, + ModelInstallCancelledEvent, + ModelInstallCompleteEvent, + ModelInstallDownloadProgressEvent, + ModelInstallErrorEvent, + ModelInstallStartedEvent, + ModelLoadCompleteEvent, + ModelLoadStartedEvent, + QueueClearedEvent, + QueueEvent, + QueueItemStatusChangedEvent, + SessionCanceledEvent, + SessionCompleteEvent, + SessionStartedEvent, + register_events, +) + + +class QueueSubscriptionEvent(BaseModel): + queue_id: str + + +class BulkDownloadSubscriptionEvent(BaseModel): + bulk_download_id: str class SocketIO: - __sio: AsyncServer - __app: ASGIApp + _sub_queue = "subscribe_queue" + _unsub_queue = "unsubscribe_queue" - __sub_queue: str = "subscribe_queue" - __unsub_queue: str = "unsubscribe_queue" - - __sub_bulk_download: str = "subscribe_bulk_download" - __unsub_bulk_download: str = "unsubscribe_bulk_download" + _sub_bulk_download = "subscribe_bulk_download" + _unsub_bulk_download = "unsubscribe_bulk_download" def __init__(self, app: FastAPI): - self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*") - self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io") - app.mount("/ws", self.__app) + self._sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*") + self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io") + app.mount("/ws", self._app) - self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue) - self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue) - local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event) - local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event) + self._sio.on(self._sub_queue, handler=self._handle_sub_queue) + self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue) + self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download) + self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download) - self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download) - self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download) - local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event) - - async def _handle_queue_event(self, event: Event): - await self.__sio.emit( - event=event[1]["event"], - data=event[1]["data"], - room=event[1]["data"]["queue_id"], + register_events( + { + InvocationStartedEvent, + InvocationDenoiseProgressEvent, + InvocationCompleteEvent, + InvocationErrorEvent, + SessionStartedEvent, + SessionCompleteEvent, + SessionCanceledEvent, + QueueItemStatusChangedEvent, + BatchEnqueuedEvent, + QueueClearedEvent, + }, + self._handle_queue_event, ) - async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None: - if "queue_id" in data: - await self.__sio.enter_room(sid, data["queue_id"]) - - async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None: - if "queue_id" in data: - await self.__sio.leave_room(sid, data["queue_id"]) - - async def _handle_model_event(self, event: Event) -> None: - await self.__sio.emit(event=event[1]["event"], data=event[1]["data"]) - - async def _handle_bulk_download_event(self, event: Event): - await self.__sio.emit( - event=event[1]["event"], - data=event[1]["data"], - room=event[1]["data"]["bulk_download_id"], + register_events( + { + ModelLoadStartedEvent, + ModelLoadCompleteEvent, + ModelInstallDownloadProgressEvent, + ModelInstallStartedEvent, + ModelInstallCompleteEvent, + ModelInstallCancelledEvent, + ModelInstallErrorEvent, + }, + self._handle_model_event, ) - async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs): - if "bulk_download_id" in data: - await self.__sio.enter_room(sid, data["bulk_download_id"]) + register_events( + {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent}, + self._handle_bulk_image_download_event, + ) - async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs): - if "bulk_download_id" in data: - await self.__sio.leave_room(sid, data["bulk_download_id"]) + async def _handle_sub_queue(self, sid: str, data: Any) -> None: + await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id) + + async def _handle_unsub_queue(self, sid: str, data: Any) -> None: + await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id) + + async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None: + await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id) + + async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None: + await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id) + + async def _handle_queue_event(self, event: FastAPIEvent[QueueEvent]): + event_name, payload = event + await self._sio.emit(event=event_name, data=payload.model_dump(), room=payload.queue_id) + + async def _handle_model_event(self, event: FastAPIEvent[ModelEvent]) -> None: + event_name, payload = event + await self._sio.emit(event=event_name, data=payload.model_dump()) + + async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEvent]) -> None: + event_name, payload = event + await self._sio.emit(event=event_name, data=payload.model_dump()) diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 062682f7d0..710c878c9b 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -5,7 +5,7 @@ import socket from contextlib import asynccontextmanager from inspect import signature from pathlib import Path -from typing import Any +from typing import Any, cast import torch import uvicorn @@ -17,6 +17,8 @@ from fastapi.openapi.utils import get_openapi from fastapi.responses import HTMLResponse from fastapi_events.handlers.local import local_handler from fastapi_events.middleware import EventHandlerASGIMiddleware +from fastapi_events.registry.payload_schema import registry as fastapi_events_registry +from pydantic import BaseModel from pydantic.json_schema import models_json_schema from torch.backends.mps import is_available as is_mps_available @@ -182,23 +184,16 @@ def custom_openapi() -> dict[str, Any]: openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type()) invoker_schema["class"] = "invocation" - # This code no longer seems to be necessary? - # Leave it here just in case - # - # from invokeai.backend.model_manager import get_model_config_formats - # formats = get_model_config_formats() - # for model_config_name, enum_set in formats.items(): - - # if model_config_name in openapi_schema["components"]["schemas"]: - # # print(f"Config with name {name} already defined") - # continue - - # openapi_schema["components"]["schemas"][model_config_name] = { - # "title": model_config_name, - # "description": "An enumeration.", - # "type": "string", - # "enum": [v.value for v in enum_set], - # } + # Add all pydantic event schemas registered with fastapi-events + for payload in fastapi_events_registry.data.values(): + json_schema = cast(BaseModel, payload).model_json_schema( + mode="serialization", ref_template="#/components/schemas/{model}" + ) + if "$defs" in json_schema: + for schema_key, schema in json_schema["$defs"].items(): + openapi_schema["components"]["schemas"][schema_key] = schema + del json_schema["$defs"] + openapi_schema["components"]["schemas"][payload.__name__] = json_schema app.openapi_schema = openapi_schema return app.openapi_schema diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index 04cec928f4..d4bf059b8f 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -106,9 +106,7 @@ class BulkDownloadService(BulkDownloadBase): if self._invoker: assert bulk_download_id is not None self._invoker.services.events.emit_bulk_download_started( - bulk_download_id=bulk_download_id, - bulk_download_item_id=bulk_download_item_id, - bulk_download_item_name=bulk_download_item_name, + bulk_download_id, bulk_download_item_id, bulk_download_item_name ) def _signal_job_completed( @@ -118,10 +116,8 @@ class BulkDownloadService(BulkDownloadBase): if self._invoker: assert bulk_download_id is not None assert bulk_download_item_name is not None - self._invoker.services.events.emit_bulk_download_completed( - bulk_download_id=bulk_download_id, - bulk_download_item_id=bulk_download_item_id, - bulk_download_item_name=bulk_download_item_name, + self._invoker.services.events.emit_bulk_download_complete( + bulk_download_id, bulk_download_item_id, bulk_download_item_name ) def _signal_job_failed( @@ -131,11 +127,8 @@ class BulkDownloadService(BulkDownloadBase): if self._invoker: assert bulk_download_id is not None assert exception is not None - self._invoker.services.events.emit_bulk_download_failed( - bulk_download_id=bulk_download_id, - bulk_download_item_id=bulk_download_item_id, - bulk_download_item_name=bulk_download_item_name, - error=str(exception), + self._invoker.services.events.emit_bulk_download_error( + bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception) ) def stop(self, *args, **kwargs): diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 7d8229fba1..4c9d9bda13 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -8,14 +8,13 @@ import time import traceback from pathlib import Path from queue import Empty, PriorityQueue -from typing import Any, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set import requests from pydantic.networks import AnyHttpUrl from requests import HTTPError from tqdm import tqdm -from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.util.misc import get_iso_timestamp from invokeai.backend.util.logging import InvokeAILogger @@ -30,6 +29,9 @@ from .download_base import ( UnknownJobIDException, ) +if TYPE_CHECKING: + from invokeai.app.services.events.events_base import EventServiceBase + # Maximum number of bytes to download during each call to requests.iter_content() DOWNLOAD_CHUNK_SIZE = 100000 @@ -40,7 +42,7 @@ class DownloadQueueService(DownloadQueueServiceBase): def __init__( self, max_parallel_dl: int = 5, - event_bus: Optional[EventServiceBase] = None, + event_bus: Optional["EventServiceBase"] = None, requests_session: Optional[requests.sessions.Session] = None, ): """ diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index 1ddda2921d..b489f29fd7 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -1,494 +1,188 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Optional -from invokeai.app.services.session_processor.session_processor_common import ProgressImage -from invokeai.app.services.session_queue.session_queue_common import ( - BatchStatus, - EnqueueBatchResult, - SessionQueueItem, - SessionQueueStatus, +from invokeai.app.services.events.events_common import ( + BaseEvent, + BatchEnqueuedEvent, + BulkDownloadCompleteEvent, + BulkDownloadErrorEvent, + BulkDownloadStartedEvent, + DownloadCancelledEvent, + DownloadCompleteEvent, + DownloadErrorEvent, + DownloadProgressEvent, + DownloadStartedEvent, + InvocationCompleteEvent, + InvocationDenoiseProgressEvent, + InvocationErrorEvent, + InvocationStartedEvent, + ModelInstallCancelledEvent, + ModelInstallCompleteEvent, + ModelInstallDownloadProgressEvent, + ModelInstallDownloadsCompleteEvent, + ModelInstallErrorEvent, + ModelInstallStartedEvent, + ModelLoadCompleteEvent, + ModelLoadStartedEvent, + QueueClearedEvent, + QueueItemStatusChangedEvent, + SessionCanceledEvent, + SessionCompleteEvent, + SessionStartedEvent, ) -from invokeai.app.util.misc import get_timestamp -from invokeai.backend.model_manager import AnyModelConfig -from invokeai.backend.model_manager.config import SubModelType + +if TYPE_CHECKING: + from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput + from invokeai.app.services.events.events_common import BaseEvent + from invokeai.app.services.model_install.model_install_common import ModelInstallJob + from invokeai.app.services.session_processor.session_processor_common import ProgressImage + from invokeai.app.services.session_queue.session_queue_common import ( + BatchStatus, + EnqueueBatchResult, + SessionQueueItem, + SessionQueueStatus, + ) + from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType class EventServiceBase: - queue_event: str = "queue_event" - bulk_download_event: str = "bulk_download_event" - download_event: str = "download_event" - model_event: str = "model_event" - """Basic event bus, to have an empty stand-in when not needed""" - def dispatch(self, event_name: str, payload: Any) -> None: + def dispatch(self, event: "BaseEvent") -> None: pass - def _emit_bulk_download_event(self, event_name: str, payload: dict) -> None: - """Bulk download events are emitted to a room with queue_id as the room name""" - payload["timestamp"] = get_timestamp() - self.dispatch( - event_name=EventServiceBase.bulk_download_event, - payload={"event": event_name, "data": payload}, - ) + # region: Invocation - def __emit_queue_event(self, event_name: str, payload: dict) -> None: - """Queue events are emitted to a room with queue_id as the room name""" - payload["timestamp"] = get_timestamp() - self.dispatch( - event_name=EventServiceBase.queue_event, - payload={"event": event_name, "data": payload}, - ) + def emit_invocation_started(self, queue_item: "SessionQueueItem", invocation: "BaseInvocation") -> None: + self.dispatch(InvocationStartedEvent.build(queue_item, invocation)) - def __emit_download_event(self, event_name: str, payload: dict) -> None: - payload["timestamp"] = get_timestamp() - self.dispatch( - event_name=EventServiceBase.download_event, - payload={"event": event_name, "data": payload}, - ) - - def __emit_model_event(self, event_name: str, payload: dict) -> None: - payload["timestamp"] = get_timestamp() - self.dispatch( - event_name=EventServiceBase.model_event, - payload={"event": event_name, "data": payload}, - ) - - # Define events here for every event in the system. - # This will make them easier to integrate until we find a schema generator. - def emit_generator_progress( + def emit_invocation_denoise_progress( self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - node_id: str, - source_node_id: str, - progress_image: Optional[ProgressImage], + queue_item: "SessionQueueItem", + invocation: "BaseInvocation", step: int, - order: int, total_steps: int, + progress_image: "ProgressImage", ) -> None: - """Emitted when there is generation progress""" - self.__emit_queue_event( - event_name="generator_progress", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "node_id": node_id, - "source_node_id": source_node_id, - "progress_image": progress_image.model_dump(mode="json") if progress_image is not None else None, - "step": step, - "order": order, - "total_steps": total_steps, - }, - ) + self.dispatch(InvocationDenoiseProgressEvent.build(queue_item, invocation, step, total_steps, progress_image)) def emit_invocation_complete( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - result: dict, - node: dict, - source_node_id: str, + self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput" ) -> None: - """Emitted when an invocation has completed""" - self.__emit_queue_event( - event_name="invocation_complete", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "node": node, - "source_node_id": source_node_id, - "result": result, - }, - ) + self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output)) def emit_invocation_error( self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - node: dict, - source_node_id: str, + queue_item: "SessionQueueItem", + invocation: "BaseInvocation", error_type: str, error_message: str, error_traceback: str, - user_id: str | None, - project_id: str | None, ) -> None: - """Emitted when an invocation has completed""" - self.__emit_queue_event( - event_name="invocation_error", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "node": node, - "source_node_id": source_node_id, - "error_type": error_type, - "error_message": error_message, - "error_traceback": error_traceback, - "user_id": user_id, - "project_id": project_id, - }, - ) + self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error_message, error_traceback)) - def emit_invocation_started( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - node: dict, - source_node_id: str, - ) -> None: - """Emitted when an invocation has started""" - self.__emit_queue_event( - event_name="invocation_started", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "node": node, - "source_node_id": source_node_id, - }, - ) + # endregion - def emit_graph_execution_complete( - self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str - ) -> None: - """Emitted when a session has completed all invocations""" - self.__emit_queue_event( - event_name="graph_execution_state_complete", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - }, - ) + # region Session - def emit_model_load_started( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - model_config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> None: - """Emitted when a model is requested""" - self.__emit_queue_event( - event_name="model_load_started", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "model_config": model_config.model_dump(mode="json"), - "submodel_type": submodel_type, - }, - ) + def emit_session_started(self, queue_item: "SessionQueueItem") -> None: + self.dispatch(SessionStartedEvent.build(queue_item)) - def emit_model_load_completed( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - model_config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> None: - """Emitted when a model is correctly loaded (returns model info)""" - self.__emit_queue_event( - event_name="model_load_completed", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "model_config": model_config.model_dump(mode="json"), - "submodel_type": submodel_type, - }, - ) + def emit_session_complete(self, queue_item: "SessionQueueItem") -> None: + self.dispatch(SessionCompleteEvent.build(queue_item)) - def emit_session_canceled( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - ) -> None: - """Emitted when a session is canceled""" - self.__emit_queue_event( - event_name="session_canceled", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - }, - ) + def emit_session_canceled(self, queue_item: "SessionQueueItem") -> None: + self.dispatch(SessionCanceledEvent.build(queue_item)) + + # endregion + + # region Queue def emit_queue_item_status_changed( - self, - session_queue_item: SessionQueueItem, - batch_status: BatchStatus, - queue_status: SessionQueueStatus, + self, queue_item: "SessionQueueItem", batch_status: "BatchStatus", queue_status: "SessionQueueStatus" ) -> None: - """Emitted when a queue item's status changes""" - self.__emit_queue_event( - event_name="queue_item_status_changed", - payload={ - "queue_id": queue_status.queue_id, - "queue_item": { - "queue_id": session_queue_item.queue_id, - "item_id": session_queue_item.item_id, - "status": session_queue_item.status, - "batch_id": session_queue_item.batch_id, - "session_id": session_queue_item.session_id, - "error_type": session_queue_item.error_type, - "error_message": session_queue_item.error_message, - "error_traceback": session_queue_item.error_traceback, - "created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None, - "updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None, - "started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None, - "completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None, - }, - "batch_status": batch_status.model_dump(mode="json"), - "queue_status": queue_status.model_dump(mode="json"), - }, - ) + self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status)) - def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None: - """Emitted when a batch is enqueued""" - self.__emit_queue_event( - event_name="batch_enqueued", - payload={ - "queue_id": enqueue_result.queue_id, - "batch_id": enqueue_result.batch.batch_id, - "enqueued": enqueue_result.enqueued, - }, - ) + def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None: + self.dispatch(BatchEnqueuedEvent.build(enqueue_result)) def emit_queue_cleared(self, queue_id: str) -> None: - """Emitted when the queue is cleared""" - self.__emit_queue_event( - event_name="queue_cleared", - payload={"queue_id": queue_id}, - ) + self.dispatch(QueueClearedEvent.build(queue_id)) + + # endregion + + # region Download def emit_download_started(self, source: str, download_path: str) -> None: - """ - Emit when a download job is started. - - :param url: The downloaded url - """ - self.__emit_download_event( - event_name="download_started", - payload={"source": source, "download_path": download_path}, - ) + self.dispatch(DownloadStartedEvent.build(source, download_path)) def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None: - """ - Emit "download_progress" events at regular intervals during a download job. - - :param source: The downloaded source - :param download_path: The local downloaded file - :param current_bytes: Number of bytes downloaded so far - :param total_bytes: The size of the file being downloaded (if known) - """ - self.__emit_download_event( - event_name="download_progress", - payload={ - "source": source, - "download_path": download_path, - "current_bytes": current_bytes, - "total_bytes": total_bytes, - }, - ) + self.dispatch(DownloadProgressEvent.build(source, download_path, current_bytes, total_bytes)) def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None: - """ - Emit a "download_complete" event at the end of a successful download. - - :param source: Source URL - :param download_path: Path to the locally downloaded file - :param total_bytes: The size of the downloaded file - """ - self.__emit_download_event( - event_name="download_complete", - payload={ - "source": source, - "download_path": download_path, - "total_bytes": total_bytes, - }, - ) + self.dispatch(DownloadCompleteEvent.build(source, download_path, total_bytes)) def emit_download_cancelled(self, source: str) -> None: - """Emit a "download_cancelled" event in the event that the download was cancelled by user.""" - self.__emit_download_event( - event_name="download_cancelled", - payload={ - "source": source, - }, - ) + self.dispatch(DownloadCancelledEvent.build(source)) def emit_download_error(self, source: str, error_type: str, error: str) -> None: - """ - Emit a "download_error" event when an download job encounters an exception. + self.dispatch(DownloadErrorEvent.build(source, error_type, error)) - :param source: Source URL - :param error_type: The name of the exception that raised the error - :param error: The traceback from this error - """ - self.__emit_download_event( - event_name="download_error", - payload={ - "source": source, - "error_type": error_type, - "error": error, - }, - ) + # endregion - def emit_model_install_downloading( - self, - source: str, - local_path: str, - bytes: int, - total_bytes: int, - parts: List[Dict[str, Union[str, int]]], - id: int, + # region Model loading + + def emit_model_load_started(self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None) -> None: + self.dispatch(ModelLoadStartedEvent.build(config, submodel_type)) + + def emit_model_load_complete( + self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None ) -> None: - """ - Emit at intervals while the install job is in progress (remote models only). + self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type)) - :param source: Source of the model - :param local_path: Where model is downloading to - :param parts: Progress of downloading URLs that comprise the model, if any. - :param bytes: Number of bytes downloaded so far. - :param total_bytes: Total size of download, including all files. - This emits a Dict with keys "source", "local_path", "bytes" and "total_bytes". - """ - self.__emit_model_event( - event_name="model_install_downloading", - payload={ - "source": source, - "local_path": local_path, - "bytes": bytes, - "total_bytes": total_bytes, - "parts": parts, - "id": id, - }, - ) + # endregion - def emit_model_install_downloads_done(self, source: str) -> None: - """ - Emit once when all parts are downloaded, but before the probing and registration start. + # region Model install - :param source: Source of the model; local path, repo_id or url - """ - self.__emit_model_event( - event_name="model_install_downloads_done", - payload={"source": source}, - ) + def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None: + self.dispatch(ModelInstallDownloadProgressEvent.build(job)) - def emit_model_install_running(self, source: str) -> None: - """ - Emit once when an install job becomes active. + def emit_model_install_downloads_complete(self, job: "ModelInstallJob") -> None: + self.dispatch(ModelInstallDownloadsCompleteEvent.build(job)) - :param source: Source of the model; local path, repo_id or url - """ - self.__emit_model_event( - event_name="model_install_running", - payload={"source": source}, - ) + def emit_model_install_started(self, job: "ModelInstallJob") -> None: + self.dispatch(ModelInstallStartedEvent.build(job)) - def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None: - """ - Emit when an install job is completed successfully. + def emit_model_install_complete(self, job: "ModelInstallJob") -> None: + self.dispatch(ModelInstallCompleteEvent.build(job)) - :param source: Source of the model; local path, repo_id or url - :param key: Model config record key - :param total_bytes: Size of the model (may be None for installation of a local path) - """ - self.__emit_model_event( - event_name="model_install_completed", - payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id}, - ) + def emit_model_install_cancelled(self, job: "ModelInstallJob") -> None: + self.dispatch(ModelInstallCancelledEvent.build(job)) - def emit_model_install_cancelled(self, source: str, id: int) -> None: - """ - Emit when an install job is cancelled. + def emit_model_install_error(self, job: "ModelInstallJob") -> None: + self.dispatch(ModelInstallErrorEvent.build(job)) - :param source: Source of the model; local path, repo_id or url - """ - self.__emit_model_event( - event_name="model_install_cancelled", - payload={"source": source, "id": id}, - ) + # endregion - def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None: - """ - Emit when an install job encounters an exception. - - :param source: Source of the model - :param error_type: The name of the exception - :param error: A text description of the exception - """ - self.__emit_model_event( - event_name="model_install_error", - payload={"source": source, "error_type": error_type, "error": error, "id": id}, - ) + # region Bulk image download def emit_bulk_download_started( self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str ) -> None: - """Emitted when a bulk download starts""" - self._emit_bulk_download_event( - event_name="bulk_download_started", - payload={ - "bulk_download_id": bulk_download_id, - "bulk_download_item_id": bulk_download_item_id, - "bulk_download_item_name": bulk_download_item_name, - }, - ) + self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name)) - def emit_bulk_download_completed( + def emit_bulk_download_complete( self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str ) -> None: - """Emitted when a bulk download completes""" - self._emit_bulk_download_event( - event_name="bulk_download_completed", - payload={ - "bulk_download_id": bulk_download_id, - "bulk_download_item_id": bulk_download_item_id, - "bulk_download_item_name": bulk_download_item_name, - }, - ) + self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name)) - def emit_bulk_download_failed( + def emit_bulk_download_error( self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str ) -> None: - """Emitted when a bulk download fails""" - self._emit_bulk_download_event( - event_name="bulk_download_failed", - payload={ - "bulk_download_id": bulk_download_id, - "bulk_download_item_id": bulk_download_item_id, - "bulk_download_item_name": bulk_download_item_name, - "error": error, - }, + self.dispatch( + BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error) ) + + # endregion diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py new file mode 100644 index 0000000000..373b422863 --- /dev/null +++ b/invokeai/app/services/events/events_common.py @@ -0,0 +1,636 @@ +from abc import ABC +from enum import Enum +from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Optional, Protocol, TypeAlias, TypeVar + +from fastapi_events.handlers.local import local_handler +from fastapi_events.registry.payload_schema import registry as payload_schema +from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny + +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput +from invokeai.app.services.session_processor.session_processor_common import ProgressImage +from invokeai.app.services.session_queue.session_queue_common import ( + QUEUE_ITEM_STATUS, + BatchStatus, + EnqueueBatchResult, + SessionQueueItem, + SessionQueueStatus, +) +from invokeai.app.util.misc import get_timestamp +from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType + +if TYPE_CHECKING: + from invokeai.app.services.model_install.model_install_common import ModelInstallJob + + +class EventType(str, Enum): + QUEUE = "queue" + MODEL = "model" + DOWNLOAD = "download" + BULK_IMAGE_DOWNLOAD = "bulk_image_download" + + +class BaseEvent(BaseModel, ABC): + """Base class for all events. All events must inherit from this class. + + Events must define the following class attributes: + - `__event_name__: str`: The name of the event + - `__event_type__: EventType`: The type of the event + + All other attributes should be defined as normal for a pydantic model. + + A timestamp is automatically added to the event when it is created. + """ + + __event_name__: ClassVar[str] = ... # pyright: ignore [reportAssignmentType] + __event_type__: ClassVar[EventType] = ... # pyright: ignore [reportAssignmentType] + + timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp) + + def __init_subclass__(cls, **kwargs: ConfigDict): + for required_attr in ("__event_name__", "__event_type__"): + if getattr(cls, required_attr) is ...: + raise TypeError(f"{cls.__name__} must define {required_attr}") + + model_config = ConfigDict(json_schema_serialization_defaults_required=True) + + +TEvent = TypeVar("TEvent", bound=BaseEvent) + +FastAPIEvent: TypeAlias = tuple[str, TEvent] +""" +A tuple representing a `fastapi-events` event, with the event name and payload. +Provide a generic type to `TEvent` to specify the payload type. +""" + + +class FastAPIEventFunc(Protocol): + def __call__(self, event: FastAPIEvent[Any]) -> Optional[Coroutine[Any, Any, None]]: ... + + +def register_events(events: set[type[TEvent]], func: FastAPIEventFunc) -> None: + """Register a function to handle a list of events. + + :param events: A list of event classes to handle + :param func: The function to handle the events + """ + for event in events: + local_handler.register(event_name=event.__event_name__, _func=func) + + +class QueueEvent(BaseEvent, ABC): + """Base class for queue events""" + + __event_type__ = EventType.QUEUE + __event_name__ = "queue_event" + + queue_id: str = Field(description="The ID of the queue") + + +class QueueItemEvent(QueueEvent, ABC): + """Base class for queue item events""" + + __event_name__ = "queue_item_event" + + item_id: int = Field(description="The ID of the queue item") + batch_id: str = Field(description="The ID of the queue batch") + + +class SessionEvent(QueueItemEvent, ABC): + """Base class for session (aka graph execution state) events""" + + __event_name__ = "session_event" + + session_id: str = Field(description="The ID of the session (aka graph execution state)") + + +class InvocationEvent(SessionEvent, ABC): + """Base class for invocation events""" + + __event_name__ = "invocation_event" + + queue_id: str = Field(description="The ID of the queue") + item_id: int = Field(description="The ID of the queue item") + batch_id: str = Field(description="The ID of the queue batch") + session_id: str = Field(description="The ID of the session (aka graph execution state)") + invocation_id: str = Field(description="The ID of the invocation") + invocation_source_id: str = Field(description="The ID of the prepared invocation's source node") + invocation_type: str = Field(description="The type of invocation") + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class InvocationStartedEvent(InvocationEvent): + """Emitted when an invocation is started""" + + __event_name__ = "invocation_started" + + @classmethod + def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + invocation_id=invocation.id, + invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], + invocation_type=invocation.get_type(), + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class InvocationDenoiseProgressEvent(InvocationEvent): + """Emitted at each step during denoising of an invocation.""" + + __event_name__ = "invocation_denoise_progress" + + progress_image: ProgressImage = Field(description="The progress image sent at each step during processing") + step: int = Field(description="The current step of the invocation") + total_steps: int = Field(description="The total number of steps in the invocation") + + @classmethod + def build( + cls, + queue_item: SessionQueueItem, + invocation: BaseInvocation, + step: int, + total_steps: int, + progress_image: ProgressImage, + ) -> "InvocationDenoiseProgressEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + invocation_id=invocation.id, + invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], + invocation_type=invocation.get_type(), + progress_image=progress_image, + step=step, + total_steps=total_steps, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class InvocationCompleteEvent(InvocationEvent): + """Emitted when an invocation is complete""" + + __event_name__ = "invocation_complete" + + result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation") + + @classmethod + def build( + cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput + ) -> "InvocationCompleteEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + invocation_id=invocation.id, + invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], + invocation_type=invocation.get_type(), + result=result, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class InvocationErrorEvent(InvocationEvent): + """Emitted when an invocation encounters an error""" + + __event_name__ = "invocation_error" + + error_type: str = Field(description="The error type") + error_message: str = Field(description="The error message") + error_traceback: str = Field(description="The error traceback") + + @classmethod + def build( + cls, + queue_item: SessionQueueItem, + invocation: BaseInvocation, + error_type: str, + error_message: str, + error_traceback: str, + ) -> "InvocationErrorEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + invocation_id=invocation.id, + invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], + invocation_type=invocation.get_type(), + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class SessionStartedEvent(SessionEvent): + """Emitted when a session has started""" + + __event_name__ = "session_started" + + @classmethod + def build(cls, queue_item: SessionQueueItem) -> "SessionStartedEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class SessionCompleteEvent(SessionEvent): + """Emitted when a session has completed all invocations""" + + __event_name__ = "session_complete" + + @classmethod + def build(cls, queue_item: SessionQueueItem) -> "SessionCompleteEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class SessionCanceledEvent(SessionEvent): + """Emitted when a session is canceled""" + + __event_name__ = "session_canceled" + + @classmethod + def build(cls, queue_item: SessionQueueItem) -> "SessionCanceledEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class QueueItemStatusChangedEvent(QueueItemEvent): + """Emitted when a queue item's status changes""" + + __event_name__ = "queue_item_status_changed" + + status: QUEUE_ITEM_STATUS = Field(description="The new status of the queue item") + error_type: Optional[str] = Field(default=None, description="The error type, if any") + error_message: Optional[str] = Field(default=None, description="The error message, if any") + error_traceback: Optional[str] = Field(default=None, description="The error traceback, if any") + created_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was created") + updated_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was last updated") + started_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was started") + completed_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was completed") + batch_status: BatchStatus = Field(description="The status of the batch") + queue_status: SessionQueueStatus = Field(description="The status of the queue") + + @classmethod + def build( + cls, queue_item: SessionQueueItem, batch_status: BatchStatus, queue_status: SessionQueueStatus + ) -> "QueueItemStatusChangedEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + status=queue_item.status, + error_type=queue_item.error_type, + error_message=queue_item.error_message, + error_traceback=queue_item.error_traceback, + created_at=str(queue_item.created_at) if queue_item.created_at else None, + updated_at=str(queue_item.updated_at) if queue_item.updated_at else None, + started_at=str(queue_item.started_at) if queue_item.started_at else None, + completed_at=str(queue_item.completed_at) if queue_item.completed_at else None, + batch_status=batch_status, + queue_status=queue_status, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class BatchEnqueuedEvent(QueueEvent): + """Emitted when a batch is enqueued""" + + __event_name__ = "batch_enqueued" + + batch_id: str = Field(description="The ID of the batch") + enqueued: int = Field(description="The number of invocations enqueued") + requested: int = Field( + description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)" + ) + priority: int = Field(description="The priority of the batch") + + @classmethod + def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent": + return cls( + queue_id=enqueue_result.queue_id, + batch_id=enqueue_result.batch.batch_id, + enqueued=enqueue_result.enqueued, + requested=enqueue_result.requested, + priority=enqueue_result.priority, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class QueueClearedEvent(QueueEvent): + """Emitted when a queue is cleared""" + + __event_name__ = "queue_cleared" + + @classmethod + def build(cls, queue_id: str) -> "QueueClearedEvent": + return cls(queue_id=queue_id) + + +class DownloadEvent(BaseEvent, ABC): + """Base class for events associated with a download""" + + __event_type__ = EventType.DOWNLOAD + __event_name__ = "download_event" + + source: str = Field(description="The source of the download") + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class DownloadStartedEvent(DownloadEvent): + """Emitted when a download is started""" + + __event_name__ = "download_started" + + download_path: str = Field(description="The local path where the download is saved") + + @classmethod + def build(cls, source: str, download_path: str) -> "DownloadStartedEvent": + return cls(source=source, download_path=download_path) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class DownloadProgressEvent(DownloadEvent): + """Emitted at intervals during a download""" + + __event_name__ = "download_progress" + + download_path: str = Field(description="The local path where the download is saved") + current_bytes: int = Field(description="The number of bytes downloaded so far") + total_bytes: int = Field(description="The total number of bytes to be downloaded") + + @classmethod + def build(cls, source: str, download_path: str, current_bytes: int, total_bytes: int) -> "DownloadProgressEvent": + return cls(source=source, download_path=download_path, current_bytes=current_bytes, total_bytes=total_bytes) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class DownloadCompleteEvent(DownloadEvent): + """Emitted when a download is completed""" + + __event_name__ = "download_complete" + + download_path: str = Field(description="The local path where the download is saved") + total_bytes: int = Field(description="The total number of bytes downloaded") + + @classmethod + def build(cls, source: str, download_path: str, total_bytes: int) -> "DownloadCompleteEvent": + return cls(source=source, download_path=download_path, total_bytes=total_bytes) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class DownloadCancelledEvent(DownloadEvent): + """Emitted when a download is cancelled""" + + __event_name__ = "download_cancelled" + + @classmethod + def build(cls, source: str) -> "DownloadCancelledEvent": + return cls(source=source) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class DownloadErrorEvent(DownloadEvent): + """Emitted when a download encounters an error""" + + __event_name__ = "download_error" + + error_type: str = Field(description="The type of error") + error: str = Field(description="The error message") + + @classmethod + def build(cls, source: str, error_type: str, error: str) -> "DownloadErrorEvent": + return cls(source=source, error_type=error_type, error=error) + + +class ModelEvent(BaseEvent, ABC): + """Base class for events associated with a model""" + + __event_type__ = EventType.MODEL + __event_name__ = "model_event" + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class ModelLoadStartedEvent(ModelEvent): + """Emitted when a model is requested""" + + __event_name__ = "model_load_started" + + config: AnyModelConfig = Field(description="The model's config") + submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any") + + @classmethod + def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadStartedEvent": + return cls(config=config, submodel_type=submodel_type) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class ModelLoadCompleteEvent(ModelEvent): + """Emitted when a model is requested""" + + __event_name__ = "model_load_complete" + + config: AnyModelConfig = Field(description="The model's config") + submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any") + + @classmethod + def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadCompleteEvent": + return cls(config=config, submodel_type=submodel_type) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class ModelInstallDownloadProgressEvent(ModelEvent): + """Emitted at intervals while the install job is in progress (remote models only).""" + + __event_name__ = "model_install_download_progress" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + local_path: str = Field(description="Where model is downloading to") + bytes: int = Field(description="Number of bytes downloaded so far") + total_bytes: int = Field(description="Total size of download, including all files") + parts: list[dict[str, int | str]] = Field( + description="Progress of downloading URLs that comprise the model, if any" + ) + + @classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadProgressEvent": + parts: list[dict[str, str | int]] = [ + { + "url": str(x.source), + "local_path": str(x.download_path), + "bytes": x.bytes, + "total_bytes": x.total_bytes, + } + for x in job.download_parts + ] + return cls( + id=job.id, + source=str(job.source), + local_path=job.local_path.as_posix(), + parts=parts, + bytes=job.bytes, + total_bytes=job.total_bytes, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class ModelInstallDownloadsCompleteEvent(ModelEvent): + """Emitted once when an install job becomes active.""" + + __event_name__ = "model_install_downloads_complete" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + + @classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadsCompleteEvent": + return cls(id=job.id, source=str(job.source)) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class ModelInstallStartedEvent(ModelEvent): + """Emitted once when an install job becomes active.""" + + __event_name__ = "model_install_started" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + + @classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallStartedEvent": + return cls(id=job.id, source=str(job.source)) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class ModelInstallCompleteEvent(ModelEvent): + """Emitted when an install job is completed successfully.""" + + __event_name__ = "model_install_complete" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + key: str = Field(description="Model config record key") + total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)") + + @classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallCompleteEvent": + assert job.config_out is not None + return cls(id=job.id, source=str(job.source), key=(job.config_out.key), total_bytes=job.total_bytes) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class ModelInstallCancelledEvent(ModelEvent): + """Emitted when an install job is cancelled.""" + + __event_name__ = "model_install_cancelled" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + + @classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallCancelledEvent": + return cls(id=job.id, source=str(job.source)) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class ModelInstallErrorEvent(ModelEvent): + """Emitted when an install job encounters an exception.""" + + __event_name__ = "model_install_error" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + error_type: str = Field(description="The name of the exception") + error: str = Field(description="A text description of the exception") + + @classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallErrorEvent": + assert job.error_type is not None + assert job.error is not None + return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error) + + +class BulkDownloadEvent(BaseEvent, ABC): + """Base class for events associated with a bulk image download""" + + __event_type__ = EventType.BULK_IMAGE_DOWNLOAD + __event_name__ = "bulk_image_download_event" + + bulk_download_id: str = Field(description="The ID of the bulk image download") + bulk_download_item_id: str = Field(description="The ID of the bulk image download item") + bulk_download_item_name: str = Field(description="The name of the bulk image download item") + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class BulkDownloadStartedEvent(BulkDownloadEvent): + """Emitted when a bulk image download is started""" + + __event_name__ = "bulk_download_started" + + @classmethod + def build( + cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + ) -> "BulkDownloadStartedEvent": + return cls( + bulk_download_id=bulk_download_id, + bulk_download_item_id=bulk_download_item_id, + bulk_download_item_name=bulk_download_item_name, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class BulkDownloadCompleteEvent(BulkDownloadEvent): + """Emitted when a bulk image download is started""" + + __event_name__ = "bulk_download_complete" + + @classmethod + def build( + cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + ) -> "BulkDownloadCompleteEvent": + return cls( + bulk_download_id=bulk_download_id, + bulk_download_item_id=bulk_download_item_id, + bulk_download_item_name=bulk_download_item_name, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class BulkDownloadErrorEvent(BulkDownloadEvent): + """Emitted when a bulk image download is started""" + + __event_name__ = "bulk_download_error" + + error: str = Field(description="The error message") + + @classmethod + def build( + cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str + ) -> "BulkDownloadErrorEvent": + return cls( + bulk_download_id=bulk_download_id, + bulk_download_item_id=bulk_download_item_id, + bulk_download_item_name=bulk_download_item_name, + error=error, + ) diff --git a/invokeai/app/services/events/events_fastapievents.py b/invokeai/app/services/events/events_fastapievents.py new file mode 100644 index 0000000000..a8317911cf --- /dev/null +++ b/invokeai/app/services/events/events_fastapievents.py @@ -0,0 +1,46 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +import asyncio +import threading +from queue import Empty, Queue + +from fastapi_events.dispatcher import dispatch + +from invokeai.app.services.events.events_common import ( + BaseEvent, +) + +from .events_base import EventServiceBase + + +class FastAPIEventService(EventServiceBase): + def __init__(self, event_handler_id: int) -> None: + self.event_handler_id = event_handler_id + self._queue = Queue[BaseEvent | None]() + self._stop_event = threading.Event() + asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event)) + + super().__init__() + + def stop(self, *args, **kwargs): + self._stop_event.set() + self._queue.put(None) + + def dispatch(self, event: BaseEvent) -> None: + self._queue.put(event) + + async def _dispatch_from_queue(self, stop_event: threading.Event): + """Get events on from the queue and dispatch them, from the correct thread""" + while not stop_event.is_set(): + try: + event = self._queue.get(block=False) + if not event: # Probably stopping + continue + dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False) + + except Empty: + await asyncio.sleep(0.1) + pass + + except asyncio.CancelledError as e: + raise e # Raise a proper error diff --git a/invokeai/app/services/model_install/__init__.py b/invokeai/app/services/model_install/__init__.py index 00a33c203e..941485a134 100644 --- a/invokeai/app/services/model_install/__init__.py +++ b/invokeai/app/services/model_install/__init__.py @@ -1,11 +1,13 @@ """Initialization file for model install service package.""" from .model_install_base import ( + ModelInstallServiceBase, +) +from .model_install_common import ( HFModelSource, InstallStatus, LocalModelSource, ModelInstallJob, - ModelInstallServiceBase, ModelSource, UnknownInstallJobException, URLModelSource, diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 0ea901fb46..6ee671062d 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -1,244 +1,19 @@ # Copyright 2023 Lincoln D. Stein and the InvokeAI development team """Baseclass definitions for the model installer.""" -import re -import traceback from abc import ABC, abstractmethod -from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Set, Union +from typing import Any, Dict, List, Optional, Union -from pydantic import BaseModel, Field, PrivateAttr, field_validator from pydantic.networks import AnyHttpUrl -from typing_extensions import Annotated from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase +from invokeai.app.services.download import DownloadQueueServiceBase from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker +from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource from invokeai.app.services.model_records import ModelRecordServiceBase -from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant -from invokeai.backend.model_manager.config import ModelSourceType -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata - - -class InstallStatus(str, Enum): - """State of an install job running in the background.""" - - WAITING = "waiting" # waiting to be dequeued - DOWNLOADING = "downloading" # downloading of model files in process - DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run - RUNNING = "running" # being processed - COMPLETED = "completed" # finished running - ERROR = "error" # terminated with an error message - CANCELLED = "cancelled" # terminated with an error message - - -class ModelInstallPart(BaseModel): - url: AnyHttpUrl - path: Path - bytes: int = 0 - total_bytes: int = 0 - - -class UnknownInstallJobException(Exception): - """Raised when the status of an unknown job is requested.""" - - -class StringLikeSource(BaseModel): - """ - Base class for model sources, implements functions that lets the source be sorted and indexed. - - These shenanigans let this stuff work: - - source1 = LocalModelSource(path='C:/users/mort/foo.safetensors') - mydict = {source1: 'model 1'} - assert mydict['C:/users/mort/foo.safetensors'] == 'model 1' - assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1' - - source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors')) - assert source1 == source2 - assert source1 == 'C:/users/mort/foo.safetensors' - """ - - def __hash__(self) -> int: - """Return hash of the path field, for indexing.""" - return hash(str(self)) - - def __lt__(self, other: object) -> int: - """Return comparison of the stringified version, for sorting.""" - return str(self) < str(other) - - def __eq__(self, other: object) -> bool: - """Return equality on the stringified version.""" - if isinstance(other, Path): - return str(self) == other.as_posix() - else: - return str(self) == str(other) - - -class LocalModelSource(StringLikeSource): - """A local file or directory path.""" - - path: str | Path - inplace: Optional[bool] = False - type: Literal["local"] = "local" - - # these methods allow the source to be used in a string-like way, - # for example as an index into a dict - def __str__(self) -> str: - """Return string version of path when string rep needed.""" - return Path(self.path).as_posix() - - -class HFModelSource(StringLikeSource): - """ - A HuggingFace repo_id with optional variant, sub-folder and access token. - Note that the variant option, if not provided to the constructor, will default to fp16, which is - what people (almost) always want. - """ - - repo_id: str - variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16 - subfolder: Optional[Path] = None - access_token: Optional[str] = None - type: Literal["hf"] = "hf" - - @field_validator("repo_id") - @classmethod - def proper_repo_id(cls, v: str) -> str: # noqa D102 - if not re.match(r"^([.\w-]+/[.\w-]+)$", v): - raise ValueError(f"{v}: invalid repo_id format") - return v - - def __str__(self) -> str: - """Return string version of repoid when string rep needed.""" - base: str = self.repo_id - if self.variant: - base += f":{self.variant or ''}" - if self.subfolder: - base += f":{self.subfolder}" - return base - - -class URLModelSource(StringLikeSource): - """A generic URL point to a checkpoint file.""" - - url: AnyHttpUrl - access_token: Optional[str] = None - type: Literal["url"] = "url" - - def __str__(self) -> str: - """Return string version of the url when string rep needed.""" - return str(self.url) - - -ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")] - -MODEL_SOURCE_TO_TYPE_MAP = { - URLModelSource: ModelSourceType.Url, - HFModelSource: ModelSourceType.HFRepoID, - LocalModelSource: ModelSourceType.Path, -} - - -class ModelInstallJob(BaseModel): - """Object that tracks the current status of an install request.""" - - id: int = Field(description="Unique ID for this job") - status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process") - error_reason: Optional[str] = Field(default=None, description="Information about why the job failed") - config_in: Dict[str, Any] = Field( - default_factory=dict, description="Configuration information (e.g. 'description') to apply to model." - ) - config_out: Optional[AnyModelConfig] = Field( - default=None, description="After successful installation, this will hold the configuration object." - ) - inplace: bool = Field( - default=False, description="Leave model in its current location; otherwise install under models directory" - ) - source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model") - local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source") - bytes: int = Field( - default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)" - ) - total_bytes: int = Field(default=0, description="Total size of the model to be installed") - source_metadata: Optional[AnyModelRepoMetadata] = Field( - default=None, description="Metadata provided by the model source" - ) - download_parts: Set[DownloadJob] = Field( - default_factory=set, description="Download jobs contributing to this install" - ) - error: Optional[str] = Field( - default=None, description="On an error condition, this field will contain the text of the exception" - ) - error_traceback: Optional[str] = Field( - default=None, description="On an error condition, this field will contain the exception traceback" - ) - # internal flags and transitory settings - _install_tmpdir: Optional[Path] = PrivateAttr(default=None) - _exception: Optional[Exception] = PrivateAttr(default=None) - - def set_error(self, e: Exception) -> None: - """Record the error and traceback from an exception.""" - self._exception = e - self.error = str(e) - self.error_traceback = self._format_error(e) - self.status = InstallStatus.ERROR - self.error_reason = self._exception.__class__.__name__ if self._exception else None - - def cancel(self) -> None: - """Call to cancel the job.""" - self.status = InstallStatus.CANCELLED - - @property - def error_type(self) -> Optional[str]: - """Class name of the exception that led to status==ERROR.""" - return self._exception.__class__.__name__ if self._exception else None - - def _format_error(self, exception: Exception) -> str: - """Error traceback.""" - return "".join(traceback.format_exception(exception)) - - @property - def cancelled(self) -> bool: - """Set status to CANCELLED.""" - return self.status == InstallStatus.CANCELLED - - @property - def errored(self) -> bool: - """Return true if job has errored.""" - return self.status == InstallStatus.ERROR - - @property - def waiting(self) -> bool: - """Return true if job is waiting to run.""" - return self.status == InstallStatus.WAITING - - @property - def downloading(self) -> bool: - """Return true if job is downloading.""" - return self.status == InstallStatus.DOWNLOADING - - @property - def downloads_done(self) -> bool: - """Return true if job's downloads ae done.""" - return self.status == InstallStatus.DOWNLOADS_DONE - - @property - def running(self) -> bool: - """Return true if job is running.""" - return self.status == InstallStatus.RUNNING - - @property - def complete(self) -> bool: - """Return true if job completed without errors.""" - return self.status == InstallStatus.COMPLETED - - @property - def in_terminal_state(self) -> bool: - """Return true if job is in a terminal state.""" - return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED] +from invokeai.backend.model_manager.config import AnyModelConfig class ModelInstallServiceBase(ABC): @@ -282,7 +57,7 @@ class ModelInstallServiceBase(ABC): @property @abstractmethod - def event_bus(self) -> Optional[EventServiceBase]: + def event_bus(self) -> Optional["EventServiceBase"]: """Return the event service base object associated with the installer.""" @abstractmethod diff --git a/invokeai/app/services/model_install/model_install_common.py b/invokeai/app/services/model_install/model_install_common.py new file mode 100644 index 0000000000..2de1db5474 --- /dev/null +++ b/invokeai/app/services/model_install/model_install_common.py @@ -0,0 +1,231 @@ +import re +import traceback +from enum import Enum +from pathlib import Path +from typing import Any, Dict, Literal, Optional, Set, Union + +from pydantic import BaseModel, Field, PrivateAttr, field_validator +from pydantic.networks import AnyHttpUrl +from typing_extensions import Annotated + +from invokeai.app.services.download import DownloadJob +from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant +from invokeai.backend.model_manager.config import ModelSourceType +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata + + +class InstallStatus(str, Enum): + """State of an install job running in the background.""" + + WAITING = "waiting" # waiting to be dequeued + DOWNLOADING = "downloading" # downloading of model files in process + DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run + RUNNING = "running" # being processed + COMPLETED = "completed" # finished running + ERROR = "error" # terminated with an error message + CANCELLED = "cancelled" # terminated with an error message + + +class ModelInstallPart(BaseModel): + url: AnyHttpUrl + path: Path + bytes: int = 0 + total_bytes: int = 0 + + +class UnknownInstallJobException(Exception): + """Raised when the status of an unknown job is requested.""" + + +class StringLikeSource(BaseModel): + """ + Base class for model sources, implements functions that lets the source be sorted and indexed. + + These shenanigans let this stuff work: + + source1 = LocalModelSource(path='C:/users/mort/foo.safetensors') + mydict = {source1: 'model 1'} + assert mydict['C:/users/mort/foo.safetensors'] == 'model 1' + assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1' + + source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors')) + assert source1 == source2 + assert source1 == 'C:/users/mort/foo.safetensors' + """ + + def __hash__(self) -> int: + """Return hash of the path field, for indexing.""" + return hash(str(self)) + + def __lt__(self, other: object) -> int: + """Return comparison of the stringified version, for sorting.""" + return str(self) < str(other) + + def __eq__(self, other: object) -> bool: + """Return equality on the stringified version.""" + if isinstance(other, Path): + return str(self) == other.as_posix() + else: + return str(self) == str(other) + + +class LocalModelSource(StringLikeSource): + """A local file or directory path.""" + + path: str | Path + inplace: Optional[bool] = False + type: Literal["local"] = "local" + + # these methods allow the source to be used in a string-like way, + # for example as an index into a dict + def __str__(self) -> str: + """Return string version of path when string rep needed.""" + return Path(self.path).as_posix() + + +class HFModelSource(StringLikeSource): + """ + A HuggingFace repo_id with optional variant, sub-folder and access token. + Note that the variant option, if not provided to the constructor, will default to fp16, which is + what people (almost) always want. + """ + + repo_id: str + variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16 + subfolder: Optional[Path] = None + access_token: Optional[str] = None + type: Literal["hf"] = "hf" + + @field_validator("repo_id") + @classmethod + def proper_repo_id(cls, v: str) -> str: # noqa D102 + if not re.match(r"^([.\w-]+/[.\w-]+)$", v): + raise ValueError(f"{v}: invalid repo_id format") + return v + + def __str__(self) -> str: + """Return string version of repoid when string rep needed.""" + base: str = self.repo_id + base += f":{self.variant or ''}" + base += f":{self.subfolder}" if self.subfolder else "" + return base + + +class URLModelSource(StringLikeSource): + """A generic URL point to a checkpoint file.""" + + url: AnyHttpUrl + access_token: Optional[str] = None + type: Literal["url"] = "url" + + def __str__(self) -> str: + """Return string version of the url when string rep needed.""" + return str(self.url) + + +ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")] + +MODEL_SOURCE_TO_TYPE_MAP = { + URLModelSource: ModelSourceType.Url, + HFModelSource: ModelSourceType.HFRepoID, + LocalModelSource: ModelSourceType.Path, +} + + +class ModelInstallJob(BaseModel): + """Object that tracks the current status of an install request.""" + + id: int = Field(description="Unique ID for this job") + status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process") + error_reason: Optional[str] = Field(default=None, description="Information about why the job failed") + config_in: Dict[str, Any] = Field( + default_factory=dict, description="Configuration information (e.g. 'description') to apply to model." + ) + config_out: Optional[AnyModelConfig] = Field( + default=None, description="After successful installation, this will hold the configuration object." + ) + inplace: bool = Field( + default=False, description="Leave model in its current location; otherwise install under models directory" + ) + source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model") + local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source") + bytes: int = Field( + default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)" + ) + total_bytes: int = Field(default=0, description="Total size of the model to be installed") + source_metadata: Optional[AnyModelRepoMetadata] = Field( + default=None, description="Metadata provided by the model source" + ) + download_parts: Set[DownloadJob] = Field( + default_factory=set, description="Download jobs contributing to this install" + ) + error: Optional[str] = Field( + default=None, description="On an error condition, this field will contain the text of the exception" + ) + error_traceback: Optional[str] = Field( + default=None, description="On an error condition, this field will contain the exception traceback" + ) + # internal flags and transitory settings + _install_tmpdir: Optional[Path] = PrivateAttr(default=None) + _exception: Optional[Exception] = PrivateAttr(default=None) + + def set_error(self, e: Exception) -> None: + """Record the error and traceback from an exception.""" + self._exception = e + self.error = str(e) + self.error_traceback = self._format_error(e) + self.status = InstallStatus.ERROR + self.error_reason = self._exception.__class__.__name__ if self._exception else None + + def cancel(self) -> None: + """Call to cancel the job.""" + self.status = InstallStatus.CANCELLED + + @property + def error_type(self) -> Optional[str]: + """Class name of the exception that led to status==ERROR.""" + return self._exception.__class__.__name__ if self._exception else None + + def _format_error(self, exception: Exception) -> str: + """Error traceback.""" + return "".join(traceback.format_exception(exception)) + + @property + def cancelled(self) -> bool: + """Set status to CANCELLED.""" + return self.status == InstallStatus.CANCELLED + + @property + def errored(self) -> bool: + """Return true if job has errored.""" + return self.status == InstallStatus.ERROR + + @property + def waiting(self) -> bool: + """Return true if job is waiting to run.""" + return self.status == InstallStatus.WAITING + + @property + def downloading(self) -> bool: + """Return true if job is downloading.""" + return self.status == InstallStatus.DOWNLOADING + + @property + def downloads_done(self) -> bool: + """Return true if job's downloads ae done.""" + return self.status == InstallStatus.DOWNLOADS_DONE + + @property + def running(self) -> bool: + """Return true if job is running.""" + return self.status == InstallStatus.RUNNING + + @property + def complete(self) -> bool: + """Return true if job completed without errors.""" + return self.status == InstallStatus.COMPLETED + + @property + def in_terminal_state(self) -> bool: + """Return true if job is in a terminal state.""" + return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED] diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 6eb9549ef0..df060caff3 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -10,7 +10,7 @@ from pathlib import Path from queue import Empty, Queue from shutil import copyfile, copytree, move, rmtree from tempfile import mkdtemp -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import torch import yaml @@ -20,8 +20,8 @@ from requests import Session from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress -from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker +from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase from invokeai.app.services.model_records.model_records_base import ModelRecordChanges from invokeai.backend.model_manager.config import ( @@ -45,13 +45,12 @@ from invokeai.backend.util import InvokeAILogger from invokeai.backend.util.catch_sigint import catch_sigint from invokeai.backend.util.devices import TorchDevice -from .model_install_base import ( +from .model_install_common import ( MODEL_SOURCE_TO_TYPE_MAP, HFModelSource, InstallStatus, LocalModelSource, ModelInstallJob, - ModelInstallServiceBase, ModelSource, StringLikeSource, URLModelSource, @@ -59,6 +58,9 @@ from .model_install_base import ( TMPDIR_PREFIX = "tmpinstall_" +if TYPE_CHECKING: + from invokeai.app.services.events.events_base import EventServiceBase + class ModelInstallService(ModelInstallServiceBase): """class for InvokeAI model installation.""" @@ -68,7 +70,7 @@ class ModelInstallService(ModelInstallServiceBase): app_config: InvokeAIAppConfig, record_store: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, - event_bus: Optional[EventServiceBase] = None, + event_bus: Optional["EventServiceBase"] = None, session: Optional[Session] = None, ): """ @@ -104,7 +106,7 @@ class ModelInstallService(ModelInstallServiceBase): return self._record_store @property - def event_bus(self) -> Optional[EventServiceBase]: # noqa D102 + def event_bus(self) -> Optional["EventServiceBase"]: # noqa D102 return self._event_bus # make the invoker optional here because we don't need it and it @@ -855,35 +857,17 @@ class ModelInstallService(ModelInstallServiceBase): job.status = InstallStatus.RUNNING self._logger.info(f"Model install started: {job.source}") if self._event_bus: - self._event_bus.emit_model_install_running(str(job.source)) + self._event_bus.emit_model_install_started(job) def _signal_job_downloading(self, job: ModelInstallJob) -> None: if self._event_bus: - parts: List[Dict[str, str | int]] = [ - { - "url": str(x.source), - "local_path": str(x.download_path), - "bytes": x.bytes, - "total_bytes": x.total_bytes, - } - for x in job.download_parts - ] - assert job.bytes is not None - assert job.total_bytes is not None - self._event_bus.emit_model_install_downloading( - str(job.source), - local_path=job.local_path.as_posix(), - parts=parts, - bytes=job.bytes, - total_bytes=job.total_bytes, - id=job.id, - ) + self._event_bus.emit_model_install_download_progress(job) def _signal_job_downloads_done(self, job: ModelInstallJob) -> None: job.status = InstallStatus.DOWNLOADS_DONE self._logger.info(f"Model download complete: {job.source}") if self._event_bus: - self._event_bus.emit_model_install_downloads_done(str(job.source)) + self._event_bus.emit_model_install_downloads_complete(job) def _signal_job_completed(self, job: ModelInstallJob) -> None: job.status = InstallStatus.COMPLETED @@ -891,24 +875,19 @@ class ModelInstallService(ModelInstallServiceBase): self._logger.info(f"Model install complete: {job.source}") self._logger.debug(f"{job.local_path} registered key {job.config_out.key}") if self._event_bus: - assert job.local_path is not None - assert job.config_out is not None - key = job.config_out.key - self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id) + self._event_bus.emit_model_install_complete(job) def _signal_job_errored(self, job: ModelInstallJob) -> None: self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}") if self._event_bus: - error_type = job.error_type - error = job.error - assert error_type is not None - assert error is not None - self._event_bus.emit_model_install_error(str(job.source), error_type, error, id=job.id) + assert job.error_type is not None + assert job.error is not None + self._event_bus.emit_model_install_error(job) def _signal_job_cancelled(self, job: ModelInstallJob) -> None: self._logger.info(f"Model install canceled: {job.source}") if self._event_bus: - self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id) + self._event_bus.emit_model_install_cancelled(job) @staticmethod def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase: diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index cc80333e93..9d75aafde1 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -4,7 +4,6 @@ from abc import ABC, abstractmethod from typing import Optional -from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager.load import LoadedModel from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase @@ -15,18 +14,12 @@ class ModelLoadServiceBase(ABC): """Wrapper around AnyModelLoader.""" @abstractmethod - def load_model( - self, - model_config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: + def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ Given a model's configuration, load it and return the LoadedModel object. :param model_config: Model configuration record (as returned by ModelRecordBase.get_model()) :param submodel: For main (pipeline models), the submodel to fetch. - :param context_data: Invocation context data used for event reporting """ @property diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index 21d3c56f36..1d6423af5a 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -5,7 +5,6 @@ from typing import Optional, Type from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.invoker import Invoker -from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager.load import ( LoadedModel, @@ -51,25 +50,15 @@ class ModelLoadService(ModelLoadServiceBase): """Return the checkpoint convert cache used by this loader.""" return self._convert_cache - def load_model( - self, - model_config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: + def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ Given a model's configuration, load it and return the LoadedModel object. :param model_config: Model configuration record (as returned by ModelRecordBase.get_model()) :param submodel: For main (pipeline models), the submodel to fetch. - :param context: Invocation context used for event reporting """ - if context_data: - self._emit_load_event( - context_data=context_data, - model_config=model_config, - submodel_type=submodel_type, - ) + + self._invoker.services.events.emit_model_load_started(model_config, submodel_type) implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore loaded_model: LoadedModel = implementation( @@ -79,40 +68,6 @@ class ModelLoadService(ModelLoadServiceBase): convert_cache=self._convert_cache, ).load_model(model_config, submodel_type) - if context_data: - self._emit_load_event( - context_data=context_data, - model_config=model_config, - submodel_type=submodel_type, - loaded=True, - ) + self._invoker.services.events.emit_model_load_started(model_config, submodel_type) + return loaded_model - - def _emit_load_event( - self, - context_data: InvocationContextData, - model_config: AnyModelConfig, - loaded: Optional[bool] = False, - submodel_type: Optional[SubModelType] = None, - ) -> None: - if not self._invoker: - return - - if not loaded: - self._invoker.services.events.emit_model_load_started( - queue_id=context_data.queue_item.queue_id, - queue_item_id=context_data.queue_item.item_id, - queue_batch_id=context_data.queue_item.batch_id, - graph_execution_state_id=context_data.queue_item.session_id, - model_config=model_config, - submodel_type=submodel_type, - ) - else: - self._invoker.services.events.emit_model_load_completed( - queue_id=context_data.queue_item.queue_id, - queue_item_id=context_data.queue_item.item_id, - queue_batch_id=context_data.queue_item.batch_id, - graph_execution_state_id=context_data.queue_item.session_id, - model_config=model_config, - submodel_type=submodel_type, - ) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 2207e71176..ef455d00b0 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -4,11 +4,16 @@ from threading import BoundedSemaphore, Thread from threading import Event as ThreadEvent from typing import Optional -from fastapi_events.handlers.local import local_handler -from fastapi_events.typing import Event as FastAPIEvent - from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput -from invokeai.app.services.events.events_base import EventServiceBase +from invokeai.app.services.events.events_common import ( + BatchEnqueuedEvent, + FastAPIEvent, + QueueClearedEvent, + QueueEvent, + QueueItemStatusChangedEvent, + SessionCanceledEvent, + register_events, +) from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError from invokeai.app.services.session_processor.session_processor_base import ( OnAfterRunNode, @@ -182,12 +187,7 @@ class DefaultSessionRunner(SessionRunnerBase): # TODO(psyche): This feels jumbled - we should review separation of concerns here. # Send complete event. The events service will receive this and update the queue item's status. - self._services.events.emit_graph_execution_complete( - queue_batch_id=queue_item.batch_id, - queue_item_id=queue_item.item_id, - queue_id=queue_item.queue_id, - graph_execution_state_id=queue_item.session.id, - ) + self._services.events.emit_session_complete(queue_item=queue_item) # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor # we don't care about that - suppress the error. @@ -208,14 +208,7 @@ class DefaultSessionRunner(SessionRunnerBase): ) # Send starting event - self._services.events.emit_invocation_started( - queue_batch_id=queue_item.batch_id, - queue_item_id=queue_item.item_id, - queue_id=queue_item.queue_id, - graph_execution_state_id=queue_item.session_id, - node=invocation.model_dump(), - source_node_id=queue_item.session.prepared_source_mapping[invocation.id], - ) + self._services.events.emit_invocation_started(queue_item=queue_item, invocation=invocation) for callback in self._on_before_run_node_callbacks: callback(invocation=invocation, queue_item=queue_item) @@ -230,15 +223,7 @@ class DefaultSessionRunner(SessionRunnerBase): ) # Send complete event on successful runs - self._services.events.emit_invocation_complete( - queue_batch_id=queue_item.batch_id, - queue_item_id=queue_item.item_id, - queue_id=queue_item.queue_id, - graph_execution_state_id=queue_item.session.id, - node=invocation.model_dump(), - source_node_id=queue_item.session.prepared_source_mapping[invocation.id], - result=output.model_dump(), - ) + self._services.events.emit_invocation_complete(invocation=invocation, queue_item=queue_item, output=output) for callback in self._on_after_run_node_callbacks: callback(invocation=invocation, queue_item=queue_item, output=output) @@ -267,17 +252,11 @@ class DefaultSessionRunner(SessionRunnerBase): # Send error event self._services.events.emit_invocation_error( - queue_batch_id=queue_item.session_id, - queue_item_id=queue_item.item_id, - queue_id=queue_item.queue_id, - graph_execution_state_id=queue_item.session.id, - node=invocation.model_dump(), - source_node_id=queue_item.session.prepared_source_mapping[invocation.id], + queue_item=queue_item, + invocation=invocation, error_type=error_type, error_message=error_message, error_traceback=error_traceback, - user_id=getattr(queue_item, "user_id", None), - project_id=getattr(queue_item, "project_id", None), ) for callback in self._on_node_error_callbacks: @@ -315,7 +294,10 @@ class DefaultSessionProcessor(SessionProcessorBase): self._poll_now_event = ThreadEvent() self._cancel_event = ThreadEvent() - local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event) + register_events( + events={SessionCanceledEvent, QueueClearedEvent, BatchEnqueuedEvent, QueueItemStatusChangedEvent}, + func=self._on_queue_event, + ) self._thread_semaphore = BoundedSemaphore(self._thread_limit) @@ -350,30 +332,25 @@ class DefaultSessionProcessor(SessionProcessorBase): def _poll_now(self) -> None: self._poll_now_event.set() - async def _on_queue_event(self, event: FastAPIEvent) -> None: - event_name = event[1]["event"] - + async def _on_queue_event(self, event: FastAPIEvent[QueueEvent]) -> None: + _event_name, payload = event if ( - event_name == "session_canceled" + isinstance(payload, SessionCanceledEvent) and self._queue_item - and self._queue_item.item_id == event[1]["data"]["queue_item_id"] + and self._queue_item.item_id == payload.item_id ): self._cancel_event.set() self._poll_now() elif ( - event_name == "queue_cleared" + isinstance(payload, QueueClearedEvent) and self._queue_item - and self._queue_item.queue_id == event[1]["data"]["queue_id"] + and self._queue_item.queue_id == payload.queue_id ): self._cancel_event.set() self._poll_now() - elif event_name == "batch_enqueued": + elif isinstance(payload, BatchEnqueuedEvent): self._poll_now() - elif event_name == "queue_item_status_changed" and event[1]["data"]["queue_item"]["status"] in [ - "completed", - "failed", - "canceled", - ]: + elif isinstance(payload, QueueItemStatusChangedEvent) and payload.status in ["completed", "failed", "canceled"]: self._poll_now() def resume(self) -> SessionProcessorStatus: @@ -422,6 +399,7 @@ class DefaultSessionProcessor(SessionProcessorBase): poll_now_event.wait(self._polling_interval) continue + self._invoker.services.events.emit_session_started(self._queue_item) self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}") cancel_event.clear() diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index 9401eabecf..8b4be151ed 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -2,10 +2,13 @@ import sqlite3 import threading from typing import Optional, Union, cast -from fastapi_events.handlers.local import local_handler -from fastapi_events.typing import Event as FastAPIEvent - -from invokeai.app.services.events.events_base import EventServiceBase +from invokeai.app.services.events.events_common import ( + FastAPIEvent, + InvocationErrorEvent, + SessionCanceledEvent, + SessionCompleteEvent, + register_events, +) from invokeai.app.services.invoker import Invoker from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase from invokeai.app.services.session_queue.session_queue_common import ( @@ -42,7 +45,11 @@ class SqliteSessionQueue(SessionQueueBase): self.__invoker = invoker self._set_in_progress_to_canceled() prune_result = self.prune(DEFAULT_QUEUE_ID) - local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event) + + register_events(events={InvocationErrorEvent}, func=self._handle_error_event) + register_events(events={SessionCompleteEvent}, func=self._handle_complete_event) + register_events(events={SessionCanceledEvent}, func=self._handle_cancel_event) + if prune_result.deleted > 0: self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items") @@ -52,59 +59,41 @@ class SqliteSessionQueue(SessionQueueBase): self.__conn = db.conn self.__cursor = self.__conn.cursor() - def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool: - return event[1]["event"] in match_in - - async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent: - event_name = event[1]["event"] - - # This was a match statement, but match is not supported on python 3.9 - if event_name == "graph_execution_state_complete": - await self._handle_complete_event(event) - elif event_name == "invocation_error": - await self._handle_error_event(event) - elif event_name == "session_canceled": - await self._handle_cancel_event(event) - return event - - async def _handle_complete_event(self, event: FastAPIEvent) -> None: + async def _handle_complete_event(self, event: FastAPIEvent[SessionCompleteEvent]) -> None: try: - item_id = event[1]["data"]["queue_item_id"] # When a queue item has an error, we get an error event, then a completed event. # Mark the queue item completed only if it isn't already marked completed, e.g. # by a previously-handled error event. - queue_item = self.get_queue_item(item_id) - if queue_item.status not in ["completed", "failed", "canceled"]: - queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed") - except SessionQueueItemNotFoundError: - return + _event_name, payload = event - async def _handle_error_event(self, event: FastAPIEvent) -> None: + queue_item = self.get_queue_item(payload.item_id) + if queue_item.status not in ["completed", "failed", "canceled"]: + self._set_queue_item_status(item_id=payload.item_id, status="completed") + except SessionQueueItemNotFoundError: + pass + + async def _handle_error_event(self, event: FastAPIEvent[InvocationErrorEvent]) -> None: try: - item_id = event[1]["data"]["queue_item_id"] - error_type = event[1]["data"]["error_type"] - error_message = event[1]["data"]["error_message"] - error_traceback = event[1]["data"]["error_traceback"] - queue_item = self.get_queue_item(item_id) + _event_name, payload = event # always set to failed if have an error, even if previously the item was marked completed or canceled - queue_item = self._set_queue_item_status( - item_id=queue_item.item_id, + self._set_queue_item_status( + item_id=payload.item_id, status="failed", - error_type=error_type, - error_message=error_message, - error_traceback=error_traceback, + error_type=payload.error_type, + error_message=payload.error_message, + error_traceback=payload.error_traceback, ) except SessionQueueItemNotFoundError: - return + pass - async def _handle_cancel_event(self, event: FastAPIEvent) -> None: + async def _handle_cancel_event(self, event: FastAPIEvent[SessionCanceledEvent]) -> None: try: - item_id = event[1]["data"]["queue_item_id"] - queue_item = self.get_queue_item(item_id) + _event_name, payload = event + queue_item = self.get_queue_item(payload.item_id) if queue_item.status not in ["completed", "failed", "canceled"]: - queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled") + self._set_queue_item_status(item_id=payload.item_id, status="canceled") except SessionQueueItemNotFoundError: - return + pass def _set_in_progress_to_canceled(self) -> None: """ @@ -306,11 +295,7 @@ class SqliteSessionQueue(SessionQueueBase): queue_item = self.get_queue_item(item_id) batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id) queue_status = self.get_queue_status(queue_id=queue_item.queue_id) - self.__invoker.services.events.emit_queue_item_status_changed( - session_queue_item=queue_item, - batch_status=batch_status, - queue_status=queue_status, - ) + self.__invoker.services.events.emit_queue_item_status_changed(queue_item, batch_status, queue_status) return queue_item def is_empty(self, queue_id: str) -> IsEmptyResult: @@ -422,12 +407,7 @@ class SqliteSessionQueue(SessionQueueBase): queue_item = self.get_queue_item(item_id) if queue_item.status not in ["canceled", "failed", "completed"]: queue_item = self._set_queue_item_status(item_id=item_id, status="canceled") - self.__invoker.services.events.emit_session_canceled( - queue_item_id=queue_item.item_id, - queue_id=queue_item.queue_id, - queue_batch_id=queue_item.batch_id, - graph_execution_state_id=queue_item.session_id, - ) + self.__invoker.services.events.emit_session_canceled(queue_item) return queue_item def fail_queue_item( @@ -446,12 +426,7 @@ class SqliteSessionQueue(SessionQueueBase): error_message=error_message, error_traceback=error_traceback, ) - self.__invoker.services.events.emit_session_canceled( - queue_item_id=queue_item.item_id, - queue_id=queue_item.queue_id, - queue_batch_id=queue_item.batch_id, - graph_execution_state_id=queue_item.session_id, - ) + self.__invoker.services.events.emit_session_canceled(queue_item) return queue_item def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult: @@ -487,18 +462,11 @@ class SqliteSessionQueue(SessionQueueBase): ) self.__conn.commit() if current_queue_item is not None and current_queue_item.batch_id in batch_ids: - self.__invoker.services.events.emit_session_canceled( - queue_item_id=current_queue_item.item_id, - queue_id=current_queue_item.queue_id, - queue_batch_id=current_queue_item.batch_id, - graph_execution_state_id=current_queue_item.session_id, - ) + self.__invoker.services.events.emit_session_canceled(current_queue_item) batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id) queue_status = self.get_queue_status(queue_id=queue_id) self.__invoker.services.events.emit_queue_item_status_changed( - session_queue_item=current_queue_item, - batch_status=batch_status, - queue_status=queue_status, + current_queue_item, batch_status, queue_status ) except Exception: self.__conn.rollback() @@ -538,18 +506,11 @@ class SqliteSessionQueue(SessionQueueBase): ) self.__conn.commit() if current_queue_item is not None and current_queue_item.queue_id == queue_id: - self.__invoker.services.events.emit_session_canceled( - queue_item_id=current_queue_item.item_id, - queue_id=current_queue_item.queue_id, - queue_batch_id=current_queue_item.batch_id, - graph_execution_state_id=current_queue_item.session_id, - ) + self.__invoker.services.events.emit_session_canceled(current_queue_item) batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id) queue_status = self.get_queue_status(queue_id=queue_id) self.__invoker.services.events.emit_queue_item_status_changed( - session_queue_item=current_queue_item, - batch_status=batch_status, - queue_status=queue_status, + current_queue_item, batch_status, queue_status ) except Exception: self.__conn.rollback() diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index de31a42665..f977a5ba9c 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -353,11 +353,11 @@ class ModelsInterface(InvocationContextInterface): if isinstance(identifier, str): model = self._services.model_manager.store.get_model(identifier) - return self._services.model_manager.load.load_model(model, submodel_type, self._data) + return self._services.model_manager.load.load_model(model, submodel_type) else: _submodel_type = submodel_type or identifier.submodel_type model = self._services.model_manager.store.get_model(identifier.key) - return self._services.model_manager.load.load_model(model, _submodel_type, self._data) + return self._services.model_manager.load.load_model(model, _submodel_type) def load_by_attrs( self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None @@ -382,7 +382,7 @@ class ModelsInterface(InvocationContextInterface): if len(configs) > 1: raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}") - return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data) + return self._services.model_manager.load.load_model(configs[0], submodel_type) def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig: """Gets a model's config. diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index 8cb59f5b3a..1bbd6bc8d0 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -113,15 +113,10 @@ def stable_diffusion_step_callback( dataURL = image_to_dataURL(image, image_format="JPEG") - events.emit_generator_progress( - queue_id=context_data.queue_item.queue_id, - queue_item_id=context_data.queue_item.item_id, - queue_batch_id=context_data.queue_item.batch_id, - graph_execution_state_id=context_data.queue_item.session_id, - node_id=context_data.invocation.id, - source_node_id=context_data.source_invocation_id, - progress_image=ProgressImage(width=width, height=height, dataURL=dataURL), - step=intermediate_state.step, - order=intermediate_state.order, - total_steps=intermediate_state.total_steps, + events.emit_invocation_denoise_progress( + context_data.queue_item, + context_data.invocation, + intermediate_state.step, + intermediate_state.total_steps * intermediate_state.order, + ProgressImage(dataURL=dataURL, width=width, height=height), ) diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index c755d3c491..cf7ebe8d29 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -14,10 +14,12 @@ from pydantic_core import Url from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.model_install import ( + ModelInstallServiceBase, +) +from invokeai.app.services.model_install.model_install_common import ( InstallStatus, LocalModelSource, ModelInstallJob, - ModelInstallServiceBase, URLModelSource, ) from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException