refactor(events): use pydantic schemas for events

Our events handling and implementation has a couple pain points:
- Adding or removing data from event payloads requires changes wherever the events are dispatched from.
- We have no type safety for events and need to rely on string matching and dict access when interacting with events.
- Frontend types for socket events must be manually typed. This has caused several bugs.

`fastapi-events` has a neat feature where you can create a pydantic model as an event payload, give it an `__event_name__` attr, and then dispatch the model directly.

This allows us to eliminate a layer of indirection and some unpleasant complexity:
- Event handler callbacks get type hints for their event payloads, and can use `isinstance` on them if needed.
- Event payload construction is now the responsibility of the event itself (a pydantic model), not the service. Every event model has a `build` class method, encapsulating this logic. The build methods are provided as few args as possible. For example, `InvocationStartedEvent.build()` gets the invocation instance and queue item, and can choose the data it wants to include in the event payload.
- Frontend event types may be autogenerated from the OpenAPI schema. We use the payload registry feature of `fastapi-events` to collect all payload models into one place, making it trivial to keep our schema and frontend types in sync.

This commit moves the backend over to this improved event handling setup.
This commit is contained in:
psychedelicious 2024-03-14 19:04:19 +11:00
parent 461e857824
commit 9bd78823a3
21 changed files with 1263 additions and 1025 deletions

View File

@ -18,6 +18,7 @@ from ..services.boards.boards_default import BoardService
from ..services.bulk_download.bulk_download_default import BulkDownloadService from ..services.bulk_download.bulk_download_default import BulkDownloadService
from ..services.config import InvokeAIAppConfig from ..services.config import InvokeAIAppConfig
from ..services.download import DownloadQueueService from ..services.download import DownloadQueueService
from ..services.events.events_fastapievents import FastAPIEventService
from ..services.image_files.image_files_disk import DiskImageFileStorage from ..services.image_files.image_files_disk import DiskImageFileStorage
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
from ..services.images.images_default import ImageService from ..services.images.images_default import ImageService
@ -33,7 +34,6 @@ from ..services.session_processor.session_processor_default import DefaultSessio
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.urls.urls_default import LocalUrlService from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from .events import FastAPIEventService
# TODO: is there a better way to achieve this? # TODO: is there a better way to achieve this?

View File

@ -1,52 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
import threading
from queue import Empty, Queue
from typing import Any
from fastapi_events.dispatcher import dispatch
from ..services.events.events_base import EventServiceBase
class FastAPIEventService(EventServiceBase):
event_handler_id: int
__queue: Queue
__stop_event: threading.Event
def __init__(self, event_handler_id: int) -> None:
self.event_handler_id = event_handler_id
self.__queue = Queue()
self.__stop_event = threading.Event()
asyncio.create_task(self.__dispatch_from_queue(stop_event=self.__stop_event))
super().__init__()
def stop(self, *args, **kwargs):
self.__stop_event.set()
self.__queue.put(None)
def dispatch(self, event_name: str, payload: Any) -> None:
self.__queue.put({"event_name": event_name, "payload": payload})
async def __dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set():
try:
event = self.__queue.get(block=False)
if not event: # Probably stopping
continue
dispatch(
event.get("event_name"),
payload=event.get("payload"),
middleware_id=self.event_handler_id,
)
except Empty:
await asyncio.sleep(0.1)
pass
except asyncio.CancelledError as e:
raise e # Raise a proper error

View File

@ -17,7 +17,7 @@ from starlette.exceptions import HTTPException
from typing_extensions import Annotated from typing_extensions import Annotated
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
from invokeai.app.services.model_install import ModelInstallJob from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.model_records import ( from invokeai.app.services.model_records import (
DuplicateModelException, DuplicateModelException,
InvalidModelException, InvalidModelException,

View File

@ -1,66 +1,119 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any
from fastapi import FastAPI from fastapi import FastAPI
from fastapi_events.handlers.local import local_handler from pydantic import BaseModel
from fastapi_events.typing import Event
from socketio import ASGIApp, AsyncServer from socketio import ASGIApp, AsyncServer
from ..services.events.events_base import EventServiceBase from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
BulkDownloadEvent,
BulkDownloadStartedEvent,
FastAPIEvent,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationStartedEvent,
ModelEvent,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueEvent,
QueueItemStatusChangedEvent,
SessionCanceledEvent,
SessionCompleteEvent,
SessionStartedEvent,
register_events,
)
class QueueSubscriptionEvent(BaseModel):
queue_id: str
class BulkDownloadSubscriptionEvent(BaseModel):
bulk_download_id: str
class SocketIO: class SocketIO:
__sio: AsyncServer _sub_queue = "subscribe_queue"
__app: ASGIApp _unsub_queue = "unsubscribe_queue"
__sub_queue: str = "subscribe_queue" _sub_bulk_download = "subscribe_bulk_download"
__unsub_queue: str = "unsubscribe_queue" _unsub_bulk_download = "unsubscribe_bulk_download"
__sub_bulk_download: str = "subscribe_bulk_download"
__unsub_bulk_download: str = "unsubscribe_bulk_download"
def __init__(self, app: FastAPI): def __init__(self, app: FastAPI):
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*") self._sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io") self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io")
app.mount("/ws", self.__app) app.mount("/ws", self._app)
self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue) self._sio.on(self._sub_queue, handler=self._handle_sub_queue)
self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue) self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event) self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download)
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event) self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download)
self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download) register_events(
self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download) {
local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event) InvocationStartedEvent,
InvocationDenoiseProgressEvent,
async def _handle_queue_event(self, event: Event): InvocationCompleteEvent,
await self.__sio.emit( InvocationErrorEvent,
event=event[1]["event"], SessionStartedEvent,
data=event[1]["data"], SessionCompleteEvent,
room=event[1]["data"]["queue_id"], SessionCanceledEvent,
QueueItemStatusChangedEvent,
BatchEnqueuedEvent,
QueueClearedEvent,
},
self._handle_queue_event,
) )
async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None: register_events(
if "queue_id" in data: {
await self.__sio.enter_room(sid, data["queue_id"]) ModelLoadStartedEvent,
ModelLoadCompleteEvent,
async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None: ModelInstallDownloadProgressEvent,
if "queue_id" in data: ModelInstallStartedEvent,
await self.__sio.leave_room(sid, data["queue_id"]) ModelInstallCompleteEvent,
ModelInstallCancelledEvent,
async def _handle_model_event(self, event: Event) -> None: ModelInstallErrorEvent,
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"]) },
self._handle_model_event,
async def _handle_bulk_download_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["bulk_download_id"],
) )
async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs): register_events(
if "bulk_download_id" in data: {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent},
await self.__sio.enter_room(sid, data["bulk_download_id"]) self._handle_bulk_image_download_event,
)
async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs): async def _handle_sub_queue(self, sid: str, data: Any) -> None:
if "bulk_download_id" in data: await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id)
await self.__sio.leave_room(sid, data["bulk_download_id"])
async def _handle_unsub_queue(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id)
async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None:
await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
async def _handle_queue_event(self, event: FastAPIEvent[QueueEvent]):
event_name, payload = event
await self._sio.emit(event=event_name, data=payload.model_dump(), room=payload.queue_id)
async def _handle_model_event(self, event: FastAPIEvent[ModelEvent]) -> None:
event_name, payload = event
await self._sio.emit(event=event_name, data=payload.model_dump())
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEvent]) -> None:
event_name, payload = event
await self._sio.emit(event=event_name, data=payload.model_dump())

View File

@ -5,7 +5,7 @@ import socket
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from inspect import signature from inspect import signature
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, cast
import torch import torch
import uvicorn import uvicorn
@ -17,6 +17,8 @@ from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from fastapi_events.handlers.local import local_handler from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware from fastapi_events.middleware import EventHandlerASGIMiddleware
from fastapi_events.registry.payload_schema import registry as fastapi_events_registry
from pydantic import BaseModel
from pydantic.json_schema import models_json_schema from pydantic.json_schema import models_json_schema
from torch.backends.mps import is_available as is_mps_available from torch.backends.mps import is_available as is_mps_available
@ -182,23 +184,16 @@ def custom_openapi() -> dict[str, Any]:
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type()) openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
invoker_schema["class"] = "invocation" invoker_schema["class"] = "invocation"
# This code no longer seems to be necessary? # Add all pydantic event schemas registered with fastapi-events
# Leave it here just in case for payload in fastapi_events_registry.data.values():
# json_schema = cast(BaseModel, payload).model_json_schema(
# from invokeai.backend.model_manager import get_model_config_formats mode="serialization", ref_template="#/components/schemas/{model}"
# formats = get_model_config_formats() )
# for model_config_name, enum_set in formats.items(): if "$defs" in json_schema:
for schema_key, schema in json_schema["$defs"].items():
# if model_config_name in openapi_schema["components"]["schemas"]: openapi_schema["components"]["schemas"][schema_key] = schema
# # print(f"Config with name {name} already defined") del json_schema["$defs"]
# continue openapi_schema["components"]["schemas"][payload.__name__] = json_schema
# openapi_schema["components"]["schemas"][model_config_name] = {
# "title": model_config_name,
# "description": "An enumeration.",
# "type": "string",
# "enum": [v.value for v in enum_set],
# }
app.openapi_schema = openapi_schema app.openapi_schema = openapi_schema
return app.openapi_schema return app.openapi_schema

View File

@ -106,9 +106,7 @@ class BulkDownloadService(BulkDownloadBase):
if self._invoker: if self._invoker:
assert bulk_download_id is not None assert bulk_download_id is not None
self._invoker.services.events.emit_bulk_download_started( self._invoker.services.events.emit_bulk_download_started(
bulk_download_id=bulk_download_id, bulk_download_id, bulk_download_item_id, bulk_download_item_name
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
) )
def _signal_job_completed( def _signal_job_completed(
@ -118,10 +116,8 @@ class BulkDownloadService(BulkDownloadBase):
if self._invoker: if self._invoker:
assert bulk_download_id is not None assert bulk_download_id is not None
assert bulk_download_item_name is not None assert bulk_download_item_name is not None
self._invoker.services.events.emit_bulk_download_completed( self._invoker.services.events.emit_bulk_download_complete(
bulk_download_id=bulk_download_id, bulk_download_id, bulk_download_item_id, bulk_download_item_name
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
) )
def _signal_job_failed( def _signal_job_failed(
@ -131,11 +127,8 @@ class BulkDownloadService(BulkDownloadBase):
if self._invoker: if self._invoker:
assert bulk_download_id is not None assert bulk_download_id is not None
assert exception is not None assert exception is not None
self._invoker.services.events.emit_bulk_download_failed( self._invoker.services.events.emit_bulk_download_error(
bulk_download_id=bulk_download_id, bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception)
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
error=str(exception),
) )
def stop(self, *args, **kwargs): def stop(self, *args, **kwargs):

View File

@ -8,14 +8,13 @@ import time
import traceback import traceback
from pathlib import Path from pathlib import Path
from queue import Empty, PriorityQueue from queue import Empty, PriorityQueue
from typing import Any, Dict, List, Optional, Set from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
import requests import requests
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
from requests import HTTPError from requests import HTTPError
from tqdm import tqdm from tqdm import tqdm
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.util.misc import get_iso_timestamp from invokeai.app.util.misc import get_iso_timestamp
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
@ -30,6 +29,9 @@ from .download_base import (
UnknownJobIDException, UnknownJobIDException,
) )
if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
# Maximum number of bytes to download during each call to requests.iter_content() # Maximum number of bytes to download during each call to requests.iter_content()
DOWNLOAD_CHUNK_SIZE = 100000 DOWNLOAD_CHUNK_SIZE = 100000
@ -40,7 +42,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
def __init__( def __init__(
self, self,
max_parallel_dl: int = 5, max_parallel_dl: int = 5,
event_bus: Optional[EventServiceBase] = None, event_bus: Optional["EventServiceBase"] = None,
requests_session: Optional[requests.sessions.Session] = None, requests_session: Optional[requests.sessions.Session] = None,
): ):
""" """

View File

@ -1,494 +1,188 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Optional
from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.events.events_common import (
from invokeai.app.services.session_queue.session_queue_common import ( BaseEvent,
BatchStatus, BatchEnqueuedEvent,
EnqueueBatchResult, BulkDownloadCompleteEvent,
SessionQueueItem, BulkDownloadErrorEvent,
SessionQueueStatus, BulkDownloadStartedEvent,
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationStartedEvent,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueItemStatusChangedEvent,
SessionCanceledEvent,
SessionCompleteEvent,
SessionStartedEvent,
) )
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager import AnyModelConfig if TYPE_CHECKING:
from invokeai.backend.model_manager.config import SubModelType from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.events.events_common import BaseEvent
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus,
EnqueueBatchResult,
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
class EventServiceBase: class EventServiceBase:
queue_event: str = "queue_event"
bulk_download_event: str = "bulk_download_event"
download_event: str = "download_event"
model_event: str = "model_event"
"""Basic event bus, to have an empty stand-in when not needed""" """Basic event bus, to have an empty stand-in when not needed"""
def dispatch(self, event_name: str, payload: Any) -> None: def dispatch(self, event: "BaseEvent") -> None:
pass pass
def _emit_bulk_download_event(self, event_name: str, payload: dict) -> None: # region: Invocation
"""Bulk download events are emitted to a room with queue_id as the room name"""
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.bulk_download_event,
payload={"event": event_name, "data": payload},
)
def __emit_queue_event(self, event_name: str, payload: dict) -> None: def emit_invocation_started(self, queue_item: "SessionQueueItem", invocation: "BaseInvocation") -> None:
"""Queue events are emitted to a room with queue_id as the room name""" self.dispatch(InvocationStartedEvent.build(queue_item, invocation))
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.queue_event,
payload={"event": event_name, "data": payload},
)
def __emit_download_event(self, event_name: str, payload: dict) -> None: def emit_invocation_denoise_progress(
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.download_event,
payload={"event": event_name, "data": payload},
)
def __emit_model_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.model_event,
payload={"event": event_name, "data": payload},
)
# Define events here for every event in the system.
# This will make them easier to integrate until we find a schema generator.
def emit_generator_progress(
self, self,
queue_id: str, queue_item: "SessionQueueItem",
queue_item_id: int, invocation: "BaseInvocation",
queue_batch_id: str,
graph_execution_state_id: str,
node_id: str,
source_node_id: str,
progress_image: Optional[ProgressImage],
step: int, step: int,
order: int,
total_steps: int, total_steps: int,
progress_image: "ProgressImage",
) -> None: ) -> None:
"""Emitted when there is generation progress""" self.dispatch(InvocationDenoiseProgressEvent.build(queue_item, invocation, step, total_steps, progress_image))
self.__emit_queue_event(
event_name="generator_progress",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node_id": node_id,
"source_node_id": source_node_id,
"progress_image": progress_image.model_dump(mode="json") if progress_image is not None else None,
"step": step,
"order": order,
"total_steps": total_steps,
},
)
def emit_invocation_complete( def emit_invocation_complete(
self, self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput"
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
result: dict,
node: dict,
source_node_id: str,
) -> None: ) -> None:
"""Emitted when an invocation has completed""" self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output))
self.__emit_queue_event(
event_name="invocation_complete",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
"result": result,
},
)
def emit_invocation_error( def emit_invocation_error(
self, self,
queue_id: str, queue_item: "SessionQueueItem",
queue_item_id: int, invocation: "BaseInvocation",
queue_batch_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
error_type: str, error_type: str,
error_message: str, error_message: str,
error_traceback: str, error_traceback: str,
user_id: str | None,
project_id: str | None,
) -> None: ) -> None:
"""Emitted when an invocation has completed""" self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error_message, error_traceback))
self.__emit_queue_event(
event_name="invocation_error",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
"error_type": error_type,
"error_message": error_message,
"error_traceback": error_traceback,
"user_id": user_id,
"project_id": project_id,
},
)
def emit_invocation_started( # endregion
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
) -> None:
"""Emitted when an invocation has started"""
self.__emit_queue_event(
event_name="invocation_started",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
},
)
def emit_graph_execution_complete( # region Session
self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str
) -> None:
"""Emitted when a session has completed all invocations"""
self.__emit_queue_event(
event_name="graph_execution_state_complete",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
},
)
def emit_model_load_started( def emit_session_started(self, queue_item: "SessionQueueItem") -> None:
self, self.dispatch(SessionStartedEvent.build(queue_item))
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Emitted when a model is requested"""
self.__emit_queue_event(
event_name="model_load_started",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_config": model_config.model_dump(mode="json"),
"submodel_type": submodel_type,
},
)
def emit_model_load_completed( def emit_session_complete(self, queue_item: "SessionQueueItem") -> None:
self, self.dispatch(SessionCompleteEvent.build(queue_item))
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_queue_event(
event_name="model_load_completed",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_config": model_config.model_dump(mode="json"),
"submodel_type": submodel_type,
},
)
def emit_session_canceled( def emit_session_canceled(self, queue_item: "SessionQueueItem") -> None:
self, self.dispatch(SessionCanceledEvent.build(queue_item))
queue_id: str,
queue_item_id: int, # endregion
queue_batch_id: str,
graph_execution_state_id: str, # region Queue
) -> None:
"""Emitted when a session is canceled"""
self.__emit_queue_event(
event_name="session_canceled",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
},
)
def emit_queue_item_status_changed( def emit_queue_item_status_changed(
self, self, queue_item: "SessionQueueItem", batch_status: "BatchStatus", queue_status: "SessionQueueStatus"
session_queue_item: SessionQueueItem,
batch_status: BatchStatus,
queue_status: SessionQueueStatus,
) -> None: ) -> None:
"""Emitted when a queue item's status changes""" self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status))
self.__emit_queue_event(
event_name="queue_item_status_changed",
payload={
"queue_id": queue_status.queue_id,
"queue_item": {
"queue_id": session_queue_item.queue_id,
"item_id": session_queue_item.item_id,
"status": session_queue_item.status,
"batch_id": session_queue_item.batch_id,
"session_id": session_queue_item.session_id,
"error_type": session_queue_item.error_type,
"error_message": session_queue_item.error_message,
"error_traceback": session_queue_item.error_traceback,
"created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None,
"updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
},
"batch_status": batch_status.model_dump(mode="json"),
"queue_status": queue_status.model_dump(mode="json"),
},
)
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None: def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None:
"""Emitted when a batch is enqueued""" self.dispatch(BatchEnqueuedEvent.build(enqueue_result))
self.__emit_queue_event(
event_name="batch_enqueued",
payload={
"queue_id": enqueue_result.queue_id,
"batch_id": enqueue_result.batch.batch_id,
"enqueued": enqueue_result.enqueued,
},
)
def emit_queue_cleared(self, queue_id: str) -> None: def emit_queue_cleared(self, queue_id: str) -> None:
"""Emitted when the queue is cleared""" self.dispatch(QueueClearedEvent.build(queue_id))
self.__emit_queue_event(
event_name="queue_cleared", # endregion
payload={"queue_id": queue_id},
) # region Download
def emit_download_started(self, source: str, download_path: str) -> None: def emit_download_started(self, source: str, download_path: str) -> None:
""" self.dispatch(DownloadStartedEvent.build(source, download_path))
Emit when a download job is started.
:param url: The downloaded url
"""
self.__emit_download_event(
event_name="download_started",
payload={"source": source, "download_path": download_path},
)
def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None: def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None:
""" self.dispatch(DownloadProgressEvent.build(source, download_path, current_bytes, total_bytes))
Emit "download_progress" events at regular intervals during a download job.
:param source: The downloaded source
:param download_path: The local downloaded file
:param current_bytes: Number of bytes downloaded so far
:param total_bytes: The size of the file being downloaded (if known)
"""
self.__emit_download_event(
event_name="download_progress",
payload={
"source": source,
"download_path": download_path,
"current_bytes": current_bytes,
"total_bytes": total_bytes,
},
)
def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None: def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None:
""" self.dispatch(DownloadCompleteEvent.build(source, download_path, total_bytes))
Emit a "download_complete" event at the end of a successful download.
:param source: Source URL
:param download_path: Path to the locally downloaded file
:param total_bytes: The size of the downloaded file
"""
self.__emit_download_event(
event_name="download_complete",
payload={
"source": source,
"download_path": download_path,
"total_bytes": total_bytes,
},
)
def emit_download_cancelled(self, source: str) -> None: def emit_download_cancelled(self, source: str) -> None:
"""Emit a "download_cancelled" event in the event that the download was cancelled by user.""" self.dispatch(DownloadCancelledEvent.build(source))
self.__emit_download_event(
event_name="download_cancelled",
payload={
"source": source,
},
)
def emit_download_error(self, source: str, error_type: str, error: str) -> None: def emit_download_error(self, source: str, error_type: str, error: str) -> None:
""" self.dispatch(DownloadErrorEvent.build(source, error_type, error))
Emit a "download_error" event when an download job encounters an exception.
:param source: Source URL # endregion
:param error_type: The name of the exception that raised the error
:param error: The traceback from this error
"""
self.__emit_download_event(
event_name="download_error",
payload={
"source": source,
"error_type": error_type,
"error": error,
},
)
def emit_model_install_downloading( # region Model loading
self,
source: str, def emit_model_load_started(self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None) -> None:
local_path: str, self.dispatch(ModelLoadStartedEvent.build(config, submodel_type))
bytes: int,
total_bytes: int, def emit_model_load_complete(
parts: List[Dict[str, Union[str, int]]], self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None
id: int,
) -> None: ) -> None:
""" self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type))
Emit at intervals while the install job is in progress (remote models only).
:param source: Source of the model # endregion
:param local_path: Where model is downloading to
:param parts: Progress of downloading URLs that comprise the model, if any.
:param bytes: Number of bytes downloaded so far.
:param total_bytes: Total size of download, including all files.
This emits a Dict with keys "source", "local_path", "bytes" and "total_bytes".
"""
self.__emit_model_event(
event_name="model_install_downloading",
payload={
"source": source,
"local_path": local_path,
"bytes": bytes,
"total_bytes": total_bytes,
"parts": parts,
"id": id,
},
)
def emit_model_install_downloads_done(self, source: str) -> None: # region Model install
"""
Emit once when all parts are downloaded, but before the probing and registration start.
:param source: Source of the model; local path, repo_id or url def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
""" self.dispatch(ModelInstallDownloadProgressEvent.build(job))
self.__emit_model_event(
event_name="model_install_downloads_done",
payload={"source": source},
)
def emit_model_install_running(self, source: str) -> None: def emit_model_install_downloads_complete(self, job: "ModelInstallJob") -> None:
""" self.dispatch(ModelInstallDownloadsCompleteEvent.build(job))
Emit once when an install job becomes active.
:param source: Source of the model; local path, repo_id or url def emit_model_install_started(self, job: "ModelInstallJob") -> None:
""" self.dispatch(ModelInstallStartedEvent.build(job))
self.__emit_model_event(
event_name="model_install_running",
payload={"source": source},
)
def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None: def emit_model_install_complete(self, job: "ModelInstallJob") -> None:
""" self.dispatch(ModelInstallCompleteEvent.build(job))
Emit when an install job is completed successfully.
:param source: Source of the model; local path, repo_id or url def emit_model_install_cancelled(self, job: "ModelInstallJob") -> None:
:param key: Model config record key self.dispatch(ModelInstallCancelledEvent.build(job))
:param total_bytes: Size of the model (may be None for installation of a local path)
"""
self.__emit_model_event(
event_name="model_install_completed",
payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id},
)
def emit_model_install_cancelled(self, source: str, id: int) -> None: def emit_model_install_error(self, job: "ModelInstallJob") -> None:
""" self.dispatch(ModelInstallErrorEvent.build(job))
Emit when an install job is cancelled.
:param source: Source of the model; local path, repo_id or url # endregion
"""
self.__emit_model_event(
event_name="model_install_cancelled",
payload={"source": source, "id": id},
)
def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None: # region Bulk image download
"""
Emit when an install job encounters an exception.
:param source: Source of the model
:param error_type: The name of the exception
:param error: A text description of the exception
"""
self.__emit_model_event(
event_name="model_install_error",
payload={"source": source, "error_type": error_type, "error": error, "id": id},
)
def emit_bulk_download_started( def emit_bulk_download_started(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None: ) -> None:
"""Emitted when a bulk download starts""" self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
self._emit_bulk_download_event(
event_name="bulk_download_started",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
},
)
def emit_bulk_download_completed( def emit_bulk_download_complete(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None: ) -> None:
"""Emitted when a bulk download completes""" self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
self._emit_bulk_download_event(
event_name="bulk_download_completed",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
},
)
def emit_bulk_download_failed( def emit_bulk_download_error(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
) -> None: ) -> None:
"""Emitted when a bulk download fails""" self.dispatch(
self._emit_bulk_download_event( BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error)
event_name="bulk_download_failed",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
"error": error,
},
) )
# endregion

View File

@ -0,0 +1,636 @@
from abc import ABC
from enum import Enum
from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Optional, Protocol, TypeAlias, TypeVar
from fastapi_events.handlers.local import local_handler
from fastapi_events.registry.payload_schema import registry as payload_schema
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
BatchStatus,
EnqueueBatchResult,
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
if TYPE_CHECKING:
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
class EventType(str, Enum):
QUEUE = "queue"
MODEL = "model"
DOWNLOAD = "download"
BULK_IMAGE_DOWNLOAD = "bulk_image_download"
class BaseEvent(BaseModel, ABC):
"""Base class for all events. All events must inherit from this class.
Events must define the following class attributes:
- `__event_name__: str`: The name of the event
- `__event_type__: EventType`: The type of the event
All other attributes should be defined as normal for a pydantic model.
A timestamp is automatically added to the event when it is created.
"""
__event_name__: ClassVar[str] = ... # pyright: ignore [reportAssignmentType]
__event_type__: ClassVar[EventType] = ... # pyright: ignore [reportAssignmentType]
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
def __init_subclass__(cls, **kwargs: ConfigDict):
for required_attr in ("__event_name__", "__event_type__"):
if getattr(cls, required_attr) is ...:
raise TypeError(f"{cls.__name__} must define {required_attr}")
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
TEvent = TypeVar("TEvent", bound=BaseEvent)
FastAPIEvent: TypeAlias = tuple[str, TEvent]
"""
A tuple representing a `fastapi-events` event, with the event name and payload.
Provide a generic type to `TEvent` to specify the payload type.
"""
class FastAPIEventFunc(Protocol):
def __call__(self, event: FastAPIEvent[Any]) -> Optional[Coroutine[Any, Any, None]]: ...
def register_events(events: set[type[TEvent]], func: FastAPIEventFunc) -> None:
"""Register a function to handle a list of events.
:param events: A list of event classes to handle
:param func: The function to handle the events
"""
for event in events:
local_handler.register(event_name=event.__event_name__, _func=func)
class QueueEvent(BaseEvent, ABC):
"""Base class for queue events"""
__event_type__ = EventType.QUEUE
__event_name__ = "queue_event"
queue_id: str = Field(description="The ID of the queue")
class QueueItemEvent(QueueEvent, ABC):
"""Base class for queue item events"""
__event_name__ = "queue_item_event"
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
class SessionEvent(QueueItemEvent, ABC):
"""Base class for session (aka graph execution state) events"""
__event_name__ = "session_event"
session_id: str = Field(description="The ID of the session (aka graph execution state)")
class InvocationEvent(SessionEvent, ABC):
"""Base class for invocation events"""
__event_name__ = "invocation_event"
queue_id: str = Field(description="The ID of the queue")
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
invocation_id: str = Field(description="The ID of the invocation")
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
invocation_type: str = Field(description="The type of invocation")
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class InvocationStartedEvent(InvocationEvent):
"""Emitted when an invocation is started"""
__event_name__ = "invocation_started"
@classmethod
def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation_id=invocation.id,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
invocation_type=invocation.get_type(),
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class InvocationDenoiseProgressEvent(InvocationEvent):
"""Emitted at each step during denoising of an invocation."""
__event_name__ = "invocation_denoise_progress"
progress_image: ProgressImage = Field(description="The progress image sent at each step during processing")
step: int = Field(description="The current step of the invocation")
total_steps: int = Field(description="The total number of steps in the invocation")
@classmethod
def build(
cls,
queue_item: SessionQueueItem,
invocation: BaseInvocation,
step: int,
total_steps: int,
progress_image: ProgressImage,
) -> "InvocationDenoiseProgressEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation_id=invocation.id,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
invocation_type=invocation.get_type(),
progress_image=progress_image,
step=step,
total_steps=total_steps,
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class InvocationCompleteEvent(InvocationEvent):
"""Emitted when an invocation is complete"""
__event_name__ = "invocation_complete"
result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation")
@classmethod
def build(
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput
) -> "InvocationCompleteEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation_id=invocation.id,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
invocation_type=invocation.get_type(),
result=result,
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class InvocationErrorEvent(InvocationEvent):
"""Emitted when an invocation encounters an error"""
__event_name__ = "invocation_error"
error_type: str = Field(description="The error type")
error_message: str = Field(description="The error message")
error_traceback: str = Field(description="The error traceback")
@classmethod
def build(
cls,
queue_item: SessionQueueItem,
invocation: BaseInvocation,
error_type: str,
error_message: str,
error_traceback: str,
) -> "InvocationErrorEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation_id=invocation.id,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
invocation_type=invocation.get_type(),
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class SessionStartedEvent(SessionEvent):
"""Emitted when a session has started"""
__event_name__ = "session_started"
@classmethod
def build(cls, queue_item: SessionQueueItem) -> "SessionStartedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class SessionCompleteEvent(SessionEvent):
"""Emitted when a session has completed all invocations"""
__event_name__ = "session_complete"
@classmethod
def build(cls, queue_item: SessionQueueItem) -> "SessionCompleteEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class SessionCanceledEvent(SessionEvent):
"""Emitted when a session is canceled"""
__event_name__ = "session_canceled"
@classmethod
def build(cls, queue_item: SessionQueueItem) -> "SessionCanceledEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class QueueItemStatusChangedEvent(QueueItemEvent):
"""Emitted when a queue item's status changes"""
__event_name__ = "queue_item_status_changed"
status: QUEUE_ITEM_STATUS = Field(description="The new status of the queue item")
error_type: Optional[str] = Field(default=None, description="The error type, if any")
error_message: Optional[str] = Field(default=None, description="The error message, if any")
error_traceback: Optional[str] = Field(default=None, description="The error traceback, if any")
created_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was created")
updated_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was last updated")
started_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was started")
completed_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was completed")
batch_status: BatchStatus = Field(description="The status of the batch")
queue_status: SessionQueueStatus = Field(description="The status of the queue")
@classmethod
def build(
cls, queue_item: SessionQueueItem, batch_status: BatchStatus, queue_status: SessionQueueStatus
) -> "QueueItemStatusChangedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
status=queue_item.status,
error_type=queue_item.error_type,
error_message=queue_item.error_message,
error_traceback=queue_item.error_traceback,
created_at=str(queue_item.created_at) if queue_item.created_at else None,
updated_at=str(queue_item.updated_at) if queue_item.updated_at else None,
started_at=str(queue_item.started_at) if queue_item.started_at else None,
completed_at=str(queue_item.completed_at) if queue_item.completed_at else None,
batch_status=batch_status,
queue_status=queue_status,
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class BatchEnqueuedEvent(QueueEvent):
"""Emitted when a batch is enqueued"""
__event_name__ = "batch_enqueued"
batch_id: str = Field(description="The ID of the batch")
enqueued: int = Field(description="The number of invocations enqueued")
requested: int = Field(
description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
)
priority: int = Field(description="The priority of the batch")
@classmethod
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
return cls(
queue_id=enqueue_result.queue_id,
batch_id=enqueue_result.batch.batch_id,
enqueued=enqueue_result.enqueued,
requested=enqueue_result.requested,
priority=enqueue_result.priority,
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class QueueClearedEvent(QueueEvent):
"""Emitted when a queue is cleared"""
__event_name__ = "queue_cleared"
@classmethod
def build(cls, queue_id: str) -> "QueueClearedEvent":
return cls(queue_id=queue_id)
class DownloadEvent(BaseEvent, ABC):
"""Base class for events associated with a download"""
__event_type__ = EventType.DOWNLOAD
__event_name__ = "download_event"
source: str = Field(description="The source of the download")
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class DownloadStartedEvent(DownloadEvent):
"""Emitted when a download is started"""
__event_name__ = "download_started"
download_path: str = Field(description="The local path where the download is saved")
@classmethod
def build(cls, source: str, download_path: str) -> "DownloadStartedEvent":
return cls(source=source, download_path=download_path)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class DownloadProgressEvent(DownloadEvent):
"""Emitted at intervals during a download"""
__event_name__ = "download_progress"
download_path: str = Field(description="The local path where the download is saved")
current_bytes: int = Field(description="The number of bytes downloaded so far")
total_bytes: int = Field(description="The total number of bytes to be downloaded")
@classmethod
def build(cls, source: str, download_path: str, current_bytes: int, total_bytes: int) -> "DownloadProgressEvent":
return cls(source=source, download_path=download_path, current_bytes=current_bytes, total_bytes=total_bytes)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class DownloadCompleteEvent(DownloadEvent):
"""Emitted when a download is completed"""
__event_name__ = "download_complete"
download_path: str = Field(description="The local path where the download is saved")
total_bytes: int = Field(description="The total number of bytes downloaded")
@classmethod
def build(cls, source: str, download_path: str, total_bytes: int) -> "DownloadCompleteEvent":
return cls(source=source, download_path=download_path, total_bytes=total_bytes)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class DownloadCancelledEvent(DownloadEvent):
"""Emitted when a download is cancelled"""
__event_name__ = "download_cancelled"
@classmethod
def build(cls, source: str) -> "DownloadCancelledEvent":
return cls(source=source)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class DownloadErrorEvent(DownloadEvent):
"""Emitted when a download encounters an error"""
__event_name__ = "download_error"
error_type: str = Field(description="The type of error")
error: str = Field(description="The error message")
@classmethod
def build(cls, source: str, error_type: str, error: str) -> "DownloadErrorEvent":
return cls(source=source, error_type=error_type, error=error)
class ModelEvent(BaseEvent, ABC):
"""Base class for events associated with a model"""
__event_type__ = EventType.MODEL
__event_name__ = "model_event"
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelLoadStartedEvent(ModelEvent):
"""Emitted when a model is requested"""
__event_name__ = "model_load_started"
config: AnyModelConfig = Field(description="The model's config")
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadStartedEvent":
return cls(config=config, submodel_type=submodel_type)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelLoadCompleteEvent(ModelEvent):
"""Emitted when a model is requested"""
__event_name__ = "model_load_complete"
config: AnyModelConfig = Field(description="The model's config")
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadCompleteEvent":
return cls(config=config, submodel_type=submodel_type)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelInstallDownloadProgressEvent(ModelEvent):
"""Emitted at intervals while the install job is in progress (remote models only)."""
__event_name__ = "model_install_download_progress"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
local_path: str = Field(description="Where model is downloading to")
bytes: int = Field(description="Number of bytes downloaded so far")
total_bytes: int = Field(description="Total size of download, including all files")
parts: list[dict[str, int | str]] = Field(
description="Progress of downloading URLs that comprise the model, if any"
)
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadProgressEvent":
parts: list[dict[str, str | int]] = [
{
"url": str(x.source),
"local_path": str(x.download_path),
"bytes": x.bytes,
"total_bytes": x.total_bytes,
}
for x in job.download_parts
]
return cls(
id=job.id,
source=str(job.source),
local_path=job.local_path.as_posix(),
parts=parts,
bytes=job.bytes,
total_bytes=job.total_bytes,
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelInstallDownloadsCompleteEvent(ModelEvent):
"""Emitted once when an install job becomes active."""
__event_name__ = "model_install_downloads_complete"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadsCompleteEvent":
return cls(id=job.id, source=str(job.source))
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelInstallStartedEvent(ModelEvent):
"""Emitted once when an install job becomes active."""
__event_name__ = "model_install_started"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallStartedEvent":
return cls(id=job.id, source=str(job.source))
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelInstallCompleteEvent(ModelEvent):
"""Emitted when an install job is completed successfully."""
__event_name__ = "model_install_complete"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
key: str = Field(description="Model config record key")
total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallCompleteEvent":
assert job.config_out is not None
return cls(id=job.id, source=str(job.source), key=(job.config_out.key), total_bytes=job.total_bytes)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelInstallCancelledEvent(ModelEvent):
"""Emitted when an install job is cancelled."""
__event_name__ = "model_install_cancelled"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallCancelledEvent":
return cls(id=job.id, source=str(job.source))
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class ModelInstallErrorEvent(ModelEvent):
"""Emitted when an install job encounters an exception."""
__event_name__ = "model_install_error"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
error_type: str = Field(description="The name of the exception")
error: str = Field(description="A text description of the exception")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallErrorEvent":
assert job.error_type is not None
assert job.error is not None
return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error)
class BulkDownloadEvent(BaseEvent, ABC):
"""Base class for events associated with a bulk image download"""
__event_type__ = EventType.BULK_IMAGE_DOWNLOAD
__event_name__ = "bulk_image_download_event"
bulk_download_id: str = Field(description="The ID of the bulk image download")
bulk_download_item_id: str = Field(description="The ID of the bulk image download item")
bulk_download_item_name: str = Field(description="The name of the bulk image download item")
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class BulkDownloadStartedEvent(BulkDownloadEvent):
"""Emitted when a bulk image download is started"""
__event_name__ = "bulk_download_started"
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> "BulkDownloadStartedEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class BulkDownloadCompleteEvent(BulkDownloadEvent):
"""Emitted when a bulk image download is started"""
__event_name__ = "bulk_download_complete"
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> "BulkDownloadCompleteEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
@payload_schema.register # pyright: ignore [reportUnknownMemberType]
class BulkDownloadErrorEvent(BulkDownloadEvent):
"""Emitted when a bulk image download is started"""
__event_name__ = "bulk_download_error"
error: str = Field(description="The error message")
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
) -> "BulkDownloadErrorEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
error=error,
)

View File

@ -0,0 +1,46 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
import threading
from queue import Empty, Queue
from fastapi_events.dispatcher import dispatch
from invokeai.app.services.events.events_common import (
BaseEvent,
)
from .events_base import EventServiceBase
class FastAPIEventService(EventServiceBase):
def __init__(self, event_handler_id: int) -> None:
self.event_handler_id = event_handler_id
self._queue = Queue[BaseEvent | None]()
self._stop_event = threading.Event()
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
super().__init__()
def stop(self, *args, **kwargs):
self._stop_event.set()
self._queue.put(None)
def dispatch(self, event: BaseEvent) -> None:
self._queue.put(event)
async def _dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set():
try:
event = self._queue.get(block=False)
if not event: # Probably stopping
continue
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
except Empty:
await asyncio.sleep(0.1)
pass
except asyncio.CancelledError as e:
raise e # Raise a proper error

View File

@ -1,11 +1,13 @@
"""Initialization file for model install service package.""" """Initialization file for model install service package."""
from .model_install_base import ( from .model_install_base import (
ModelInstallServiceBase,
)
from .model_install_common import (
HFModelSource, HFModelSource,
InstallStatus, InstallStatus,
LocalModelSource, LocalModelSource,
ModelInstallJob, ModelInstallJob,
ModelInstallServiceBase,
ModelSource, ModelSource,
UnknownInstallJobException, UnknownInstallJobException,
URLModelSource, URLModelSource,

View File

@ -1,244 +1,19 @@
# Copyright 2023 Lincoln D. Stein and the InvokeAI development team # Copyright 2023 Lincoln D. Stein and the InvokeAI development team
"""Baseclass definitions for the model installer.""" """Baseclass definitions for the model installer."""
import re
import traceback
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Set, Union from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase from invokeai.app.services.download import DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
class InstallStatus(str, Enum):
"""State of an install job running in the background."""
WAITING = "waiting" # waiting to be dequeued
DOWNLOADING = "downloading" # downloading of model files in process
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
RUNNING = "running" # being processed
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
CANCELLED = "cancelled" # terminated with an error message
class ModelInstallPart(BaseModel):
url: AnyHttpUrl
path: Path
bytes: int = 0
total_bytes: int = 0
class UnknownInstallJobException(Exception):
"""Raised when the status of an unknown job is requested."""
class StringLikeSource(BaseModel):
"""
Base class for model sources, implements functions that lets the source be sorted and indexed.
These shenanigans let this stuff work:
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
mydict = {source1: 'model 1'}
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
assert source1 == source2
assert source1 == 'C:/users/mort/foo.safetensors'
"""
def __hash__(self) -> int:
"""Return hash of the path field, for indexing."""
return hash(str(self))
def __lt__(self, other: object) -> int:
"""Return comparison of the stringified version, for sorting."""
return str(self) < str(other)
def __eq__(self, other: object) -> bool:
"""Return equality on the stringified version."""
if isinstance(other, Path):
return str(self) == other.as_posix()
else:
return str(self) == str(other)
class LocalModelSource(StringLikeSource):
"""A local file or directory path."""
path: str | Path
inplace: Optional[bool] = False
type: Literal["local"] = "local"
# these methods allow the source to be used in a string-like way,
# for example as an index into a dict
def __str__(self) -> str:
"""Return string version of path when string rep needed."""
return Path(self.path).as_posix()
class HFModelSource(StringLikeSource):
"""
A HuggingFace repo_id with optional variant, sub-folder and access token.
Note that the variant option, if not provided to the constructor, will default to fp16, which is
what people (almost) always want.
"""
repo_id: str
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
subfolder: Optional[Path] = None
access_token: Optional[str] = None
type: Literal["hf"] = "hf"
@field_validator("repo_id")
@classmethod
def proper_repo_id(cls, v: str) -> str: # noqa D102
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
raise ValueError(f"{v}: invalid repo_id format")
return v
def __str__(self) -> str:
"""Return string version of repoid when string rep needed."""
base: str = self.repo_id
if self.variant:
base += f":{self.variant or ''}"
if self.subfolder:
base += f":{self.subfolder}"
return base
class URLModelSource(StringLikeSource):
"""A generic URL point to a checkpoint file."""
url: AnyHttpUrl
access_token: Optional[str] = None
type: Literal["url"] = "url"
def __str__(self) -> str:
"""Return string version of the url when string rep needed."""
return str(self.url)
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
MODEL_SOURCE_TO_TYPE_MAP = {
URLModelSource: ModelSourceType.Url,
HFModelSource: ModelSourceType.HFRepoID,
LocalModelSource: ModelSourceType.Path,
}
class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request."""
id: int = Field(description="Unique ID for this job")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
config_in: Dict[str, Any] = Field(
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
)
config_out: Optional[AnyModelConfig] = Field(
default=None, description="After successful installation, this will hold the configuration object."
)
inplace: bool = Field(
default=False, description="Leave model in its current location; otherwise install under models directory"
)
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
bytes: int = Field(
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
)
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
source_metadata: Optional[AnyModelRepoMetadata] = Field(
default=None, description="Metadata provided by the model source"
)
download_parts: Set[DownloadJob] = Field(
default_factory=set, description="Download jobs contributing to this install"
)
error: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the text of the exception"
)
error_traceback: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the exception traceback"
)
# internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None)
def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception."""
self._exception = e
self.error = str(e)
self.error_traceback = self._format_error(e)
self.status = InstallStatus.ERROR
self.error_reason = self._exception.__class__.__name__ if self._exception else None
def cancel(self) -> None:
"""Call to cancel the job."""
self.status = InstallStatus.CANCELLED
@property
def error_type(self) -> Optional[str]:
"""Class name of the exception that led to status==ERROR."""
return self._exception.__class__.__name__ if self._exception else None
def _format_error(self, exception: Exception) -> str:
"""Error traceback."""
return "".join(traceback.format_exception(exception))
@property
def cancelled(self) -> bool:
"""Set status to CANCELLED."""
return self.status == InstallStatus.CANCELLED
@property
def errored(self) -> bool:
"""Return true if job has errored."""
return self.status == InstallStatus.ERROR
@property
def waiting(self) -> bool:
"""Return true if job is waiting to run."""
return self.status == InstallStatus.WAITING
@property
def downloading(self) -> bool:
"""Return true if job is downloading."""
return self.status == InstallStatus.DOWNLOADING
@property
def downloads_done(self) -> bool:
"""Return true if job's downloads ae done."""
return self.status == InstallStatus.DOWNLOADS_DONE
@property
def running(self) -> bool:
"""Return true if job is running."""
return self.status == InstallStatus.RUNNING
@property
def complete(self) -> bool:
"""Return true if job completed without errors."""
return self.status == InstallStatus.COMPLETED
@property
def in_terminal_state(self) -> bool:
"""Return true if job is in a terminal state."""
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]
class ModelInstallServiceBase(ABC): class ModelInstallServiceBase(ABC):
@ -282,7 +57,7 @@ class ModelInstallServiceBase(ABC):
@property @property
@abstractmethod @abstractmethod
def event_bus(self) -> Optional[EventServiceBase]: def event_bus(self) -> Optional["EventServiceBase"]:
"""Return the event service base object associated with the installer.""" """Return the event service base object associated with the installer."""
@abstractmethod @abstractmethod

View File

@ -0,0 +1,231 @@
import re
import traceback
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Set, Union
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
from invokeai.app.services.download import DownloadJob
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
class InstallStatus(str, Enum):
"""State of an install job running in the background."""
WAITING = "waiting" # waiting to be dequeued
DOWNLOADING = "downloading" # downloading of model files in process
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
RUNNING = "running" # being processed
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
CANCELLED = "cancelled" # terminated with an error message
class ModelInstallPart(BaseModel):
url: AnyHttpUrl
path: Path
bytes: int = 0
total_bytes: int = 0
class UnknownInstallJobException(Exception):
"""Raised when the status of an unknown job is requested."""
class StringLikeSource(BaseModel):
"""
Base class for model sources, implements functions that lets the source be sorted and indexed.
These shenanigans let this stuff work:
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
mydict = {source1: 'model 1'}
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
assert source1 == source2
assert source1 == 'C:/users/mort/foo.safetensors'
"""
def __hash__(self) -> int:
"""Return hash of the path field, for indexing."""
return hash(str(self))
def __lt__(self, other: object) -> int:
"""Return comparison of the stringified version, for sorting."""
return str(self) < str(other)
def __eq__(self, other: object) -> bool:
"""Return equality on the stringified version."""
if isinstance(other, Path):
return str(self) == other.as_posix()
else:
return str(self) == str(other)
class LocalModelSource(StringLikeSource):
"""A local file or directory path."""
path: str | Path
inplace: Optional[bool] = False
type: Literal["local"] = "local"
# these methods allow the source to be used in a string-like way,
# for example as an index into a dict
def __str__(self) -> str:
"""Return string version of path when string rep needed."""
return Path(self.path).as_posix()
class HFModelSource(StringLikeSource):
"""
A HuggingFace repo_id with optional variant, sub-folder and access token.
Note that the variant option, if not provided to the constructor, will default to fp16, which is
what people (almost) always want.
"""
repo_id: str
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
subfolder: Optional[Path] = None
access_token: Optional[str] = None
type: Literal["hf"] = "hf"
@field_validator("repo_id")
@classmethod
def proper_repo_id(cls, v: str) -> str: # noqa D102
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
raise ValueError(f"{v}: invalid repo_id format")
return v
def __str__(self) -> str:
"""Return string version of repoid when string rep needed."""
base: str = self.repo_id
base += f":{self.variant or ''}"
base += f":{self.subfolder}" if self.subfolder else ""
return base
class URLModelSource(StringLikeSource):
"""A generic URL point to a checkpoint file."""
url: AnyHttpUrl
access_token: Optional[str] = None
type: Literal["url"] = "url"
def __str__(self) -> str:
"""Return string version of the url when string rep needed."""
return str(self.url)
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
MODEL_SOURCE_TO_TYPE_MAP = {
URLModelSource: ModelSourceType.Url,
HFModelSource: ModelSourceType.HFRepoID,
LocalModelSource: ModelSourceType.Path,
}
class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request."""
id: int = Field(description="Unique ID for this job")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
config_in: Dict[str, Any] = Field(
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
)
config_out: Optional[AnyModelConfig] = Field(
default=None, description="After successful installation, this will hold the configuration object."
)
inplace: bool = Field(
default=False, description="Leave model in its current location; otherwise install under models directory"
)
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
bytes: int = Field(
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
)
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
source_metadata: Optional[AnyModelRepoMetadata] = Field(
default=None, description="Metadata provided by the model source"
)
download_parts: Set[DownloadJob] = Field(
default_factory=set, description="Download jobs contributing to this install"
)
error: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the text of the exception"
)
error_traceback: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the exception traceback"
)
# internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None)
def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception."""
self._exception = e
self.error = str(e)
self.error_traceback = self._format_error(e)
self.status = InstallStatus.ERROR
self.error_reason = self._exception.__class__.__name__ if self._exception else None
def cancel(self) -> None:
"""Call to cancel the job."""
self.status = InstallStatus.CANCELLED
@property
def error_type(self) -> Optional[str]:
"""Class name of the exception that led to status==ERROR."""
return self._exception.__class__.__name__ if self._exception else None
def _format_error(self, exception: Exception) -> str:
"""Error traceback."""
return "".join(traceback.format_exception(exception))
@property
def cancelled(self) -> bool:
"""Set status to CANCELLED."""
return self.status == InstallStatus.CANCELLED
@property
def errored(self) -> bool:
"""Return true if job has errored."""
return self.status == InstallStatus.ERROR
@property
def waiting(self) -> bool:
"""Return true if job is waiting to run."""
return self.status == InstallStatus.WAITING
@property
def downloading(self) -> bool:
"""Return true if job is downloading."""
return self.status == InstallStatus.DOWNLOADING
@property
def downloads_done(self) -> bool:
"""Return true if job's downloads ae done."""
return self.status == InstallStatus.DOWNLOADS_DONE
@property
def running(self) -> bool:
"""Return true if job is running."""
return self.status == InstallStatus.RUNNING
@property
def complete(self) -> bool:
"""Return true if job completed without errors."""
return self.status == InstallStatus.COMPLETED
@property
def in_terminal_state(self) -> bool:
"""Return true if job is in a terminal state."""
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]

View File

@ -10,7 +10,7 @@ from pathlib import Path
from queue import Empty, Queue from queue import Empty, Queue
from shutil import copyfile, copytree, move, rmtree from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import torch import torch
import yaml import yaml
@ -20,8 +20,8 @@ from requests import Session
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
@ -45,13 +45,12 @@ from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.catch_sigint import catch_sigint from invokeai.backend.util.catch_sigint import catch_sigint
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
from .model_install_base import ( from .model_install_common import (
MODEL_SOURCE_TO_TYPE_MAP, MODEL_SOURCE_TO_TYPE_MAP,
HFModelSource, HFModelSource,
InstallStatus, InstallStatus,
LocalModelSource, LocalModelSource,
ModelInstallJob, ModelInstallJob,
ModelInstallServiceBase,
ModelSource, ModelSource,
StringLikeSource, StringLikeSource,
URLModelSource, URLModelSource,
@ -59,6 +58,9 @@ from .model_install_base import (
TMPDIR_PREFIX = "tmpinstall_" TMPDIR_PREFIX = "tmpinstall_"
if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
class ModelInstallService(ModelInstallServiceBase): class ModelInstallService(ModelInstallServiceBase):
"""class for InvokeAI model installation.""" """class for InvokeAI model installation."""
@ -68,7 +70,7 @@ class ModelInstallService(ModelInstallServiceBase):
app_config: InvokeAIAppConfig, app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase, record_store: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase, download_queue: DownloadQueueServiceBase,
event_bus: Optional[EventServiceBase] = None, event_bus: Optional["EventServiceBase"] = None,
session: Optional[Session] = None, session: Optional[Session] = None,
): ):
""" """
@ -104,7 +106,7 @@ class ModelInstallService(ModelInstallServiceBase):
return self._record_store return self._record_store
@property @property
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102 def event_bus(self) -> Optional["EventServiceBase"]: # noqa D102
return self._event_bus return self._event_bus
# make the invoker optional here because we don't need it and it # make the invoker optional here because we don't need it and it
@ -855,35 +857,17 @@ class ModelInstallService(ModelInstallServiceBase):
job.status = InstallStatus.RUNNING job.status = InstallStatus.RUNNING
self._logger.info(f"Model install started: {job.source}") self._logger.info(f"Model install started: {job.source}")
if self._event_bus: if self._event_bus:
self._event_bus.emit_model_install_running(str(job.source)) self._event_bus.emit_model_install_started(job)
def _signal_job_downloading(self, job: ModelInstallJob) -> None: def _signal_job_downloading(self, job: ModelInstallJob) -> None:
if self._event_bus: if self._event_bus:
parts: List[Dict[str, str | int]] = [ self._event_bus.emit_model_install_download_progress(job)
{
"url": str(x.source),
"local_path": str(x.download_path),
"bytes": x.bytes,
"total_bytes": x.total_bytes,
}
for x in job.download_parts
]
assert job.bytes is not None
assert job.total_bytes is not None
self._event_bus.emit_model_install_downloading(
str(job.source),
local_path=job.local_path.as_posix(),
parts=parts,
bytes=job.bytes,
total_bytes=job.total_bytes,
id=job.id,
)
def _signal_job_downloads_done(self, job: ModelInstallJob) -> None: def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.DOWNLOADS_DONE job.status = InstallStatus.DOWNLOADS_DONE
self._logger.info(f"Model download complete: {job.source}") self._logger.info(f"Model download complete: {job.source}")
if self._event_bus: if self._event_bus:
self._event_bus.emit_model_install_downloads_done(str(job.source)) self._event_bus.emit_model_install_downloads_complete(job)
def _signal_job_completed(self, job: ModelInstallJob) -> None: def _signal_job_completed(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.COMPLETED job.status = InstallStatus.COMPLETED
@ -891,24 +875,19 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info(f"Model install complete: {job.source}") self._logger.info(f"Model install complete: {job.source}")
self._logger.debug(f"{job.local_path} registered key {job.config_out.key}") self._logger.debug(f"{job.local_path} registered key {job.config_out.key}")
if self._event_bus: if self._event_bus:
assert job.local_path is not None self._event_bus.emit_model_install_complete(job)
assert job.config_out is not None
key = job.config_out.key
self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id)
def _signal_job_errored(self, job: ModelInstallJob) -> None: def _signal_job_errored(self, job: ModelInstallJob) -> None:
self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}") self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}")
if self._event_bus: if self._event_bus:
error_type = job.error_type assert job.error_type is not None
error = job.error assert job.error is not None
assert error_type is not None self._event_bus.emit_model_install_error(job)
assert error is not None
self._event_bus.emit_model_install_error(str(job.source), error_type, error, id=job.id)
def _signal_job_cancelled(self, job: ModelInstallJob) -> None: def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
self._logger.info(f"Model install canceled: {job.source}") self._logger.info(f"Model install canceled: {job.source}")
if self._event_bus: if self._event_bus:
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id) self._event_bus.emit_model_install_cancelled(job)
@staticmethod @staticmethod
def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase: def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:

View File

@ -4,7 +4,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import Optional
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel from invokeai.backend.model_manager.load import LoadedModel
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
@ -15,18 +14,12 @@ class ModelLoadServiceBase(ABC):
"""Wrapper around AnyModelLoader.""" """Wrapper around AnyModelLoader."""
@abstractmethod @abstractmethod
def load_model( def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
""" """
Given a model's configuration, load it and return the LoadedModel object. Given a model's configuration, load it and return the LoadedModel object.
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model()) :param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
:param submodel: For main (pipeline models), the submodel to fetch. :param submodel: For main (pipeline models), the submodel to fetch.
:param context_data: Invocation context data used for event reporting
""" """
@property @property

View File

@ -5,7 +5,6 @@ from typing import Optional, Type
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import ( from invokeai.backend.model_manager.load import (
LoadedModel, LoadedModel,
@ -51,25 +50,15 @@ class ModelLoadService(ModelLoadServiceBase):
"""Return the checkpoint convert cache used by this loader.""" """Return the checkpoint convert cache used by this loader."""
return self._convert_cache return self._convert_cache
def load_model( def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
""" """
Given a model's configuration, load it and return the LoadedModel object. Given a model's configuration, load it and return the LoadedModel object.
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model()) :param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
:param submodel: For main (pipeline models), the submodel to fetch. :param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting
""" """
if context_data:
self._emit_load_event( self._invoker.services.events.emit_model_load_started(model_config, submodel_type)
context_data=context_data,
model_config=model_config,
submodel_type=submodel_type,
)
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
loaded_model: LoadedModel = implementation( loaded_model: LoadedModel = implementation(
@ -79,40 +68,6 @@ class ModelLoadService(ModelLoadServiceBase):
convert_cache=self._convert_cache, convert_cache=self._convert_cache,
).load_model(model_config, submodel_type) ).load_model(model_config, submodel_type)
if context_data: self._invoker.services.events.emit_model_load_started(model_config, submodel_type)
self._emit_load_event(
context_data=context_data,
model_config=model_config,
submodel_type=submodel_type,
loaded=True,
)
return loaded_model return loaded_model
def _emit_load_event(
self,
context_data: InvocationContextData,
model_config: AnyModelConfig,
loaded: Optional[bool] = False,
submodel_type: Optional[SubModelType] = None,
) -> None:
if not self._invoker:
return
if not loaded:
self._invoker.services.events.emit_model_load_started(
queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
model_config=model_config,
submodel_type=submodel_type,
)
else:
self._invoker.services.events.emit_model_load_completed(
queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
model_config=model_config,
submodel_type=submodel_type,
)

View File

@ -4,11 +4,16 @@ from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent from threading import Event as ThreadEvent
from typing import Optional from typing import Optional
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
FastAPIEvent,
QueueClearedEvent,
QueueEvent,
QueueItemStatusChangedEvent,
SessionCanceledEvent,
register_events,
)
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
from invokeai.app.services.session_processor.session_processor_base import ( from invokeai.app.services.session_processor.session_processor_base import (
OnAfterRunNode, OnAfterRunNode,
@ -182,12 +187,7 @@ class DefaultSessionRunner(SessionRunnerBase):
# TODO(psyche): This feels jumbled - we should review separation of concerns here. # TODO(psyche): This feels jumbled - we should review separation of concerns here.
# Send complete event. The events service will receive this and update the queue item's status. # Send complete event. The events service will receive this and update the queue item's status.
self._services.events.emit_graph_execution_complete( self._services.events.emit_session_complete(queue_item=queue_item)
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error. # we don't care about that - suppress the error.
@ -208,14 +208,7 @@ class DefaultSessionRunner(SessionRunnerBase):
) )
# Send starting event # Send starting event
self._services.events.emit_invocation_started( self._services.events.emit_invocation_started(queue_item=queue_item, invocation=invocation)
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session_id,
node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
)
for callback in self._on_before_run_node_callbacks: for callback in self._on_before_run_node_callbacks:
callback(invocation=invocation, queue_item=queue_item) callback(invocation=invocation, queue_item=queue_item)
@ -230,15 +223,7 @@ class DefaultSessionRunner(SessionRunnerBase):
) )
# Send complete event on successful runs # Send complete event on successful runs
self._services.events.emit_invocation_complete( self._services.events.emit_invocation_complete(invocation=invocation, queue_item=queue_item, output=output)
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
result=output.model_dump(),
)
for callback in self._on_after_run_node_callbacks: for callback in self._on_after_run_node_callbacks:
callback(invocation=invocation, queue_item=queue_item, output=output) callback(invocation=invocation, queue_item=queue_item, output=output)
@ -267,17 +252,11 @@ class DefaultSessionRunner(SessionRunnerBase):
# Send error event # Send error event
self._services.events.emit_invocation_error( self._services.events.emit_invocation_error(
queue_batch_id=queue_item.session_id, queue_item=queue_item,
queue_item_id=queue_item.item_id, invocation=invocation,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
error_type=error_type, error_type=error_type,
error_message=error_message, error_message=error_message,
error_traceback=error_traceback, error_traceback=error_traceback,
user_id=getattr(queue_item, "user_id", None),
project_id=getattr(queue_item, "project_id", None),
) )
for callback in self._on_node_error_callbacks: for callback in self._on_node_error_callbacks:
@ -315,7 +294,10 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._poll_now_event = ThreadEvent() self._poll_now_event = ThreadEvent()
self._cancel_event = ThreadEvent() self._cancel_event = ThreadEvent()
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event) register_events(
events={SessionCanceledEvent, QueueClearedEvent, BatchEnqueuedEvent, QueueItemStatusChangedEvent},
func=self._on_queue_event,
)
self._thread_semaphore = BoundedSemaphore(self._thread_limit) self._thread_semaphore = BoundedSemaphore(self._thread_limit)
@ -350,30 +332,25 @@ class DefaultSessionProcessor(SessionProcessorBase):
def _poll_now(self) -> None: def _poll_now(self) -> None:
self._poll_now_event.set() self._poll_now_event.set()
async def _on_queue_event(self, event: FastAPIEvent) -> None: async def _on_queue_event(self, event: FastAPIEvent[QueueEvent]) -> None:
event_name = event[1]["event"] _event_name, payload = event
if ( if (
event_name == "session_canceled" isinstance(payload, SessionCanceledEvent)
and self._queue_item and self._queue_item
and self._queue_item.item_id == event[1]["data"]["queue_item_id"] and self._queue_item.item_id == payload.item_id
): ):
self._cancel_event.set() self._cancel_event.set()
self._poll_now() self._poll_now()
elif ( elif (
event_name == "queue_cleared" isinstance(payload, QueueClearedEvent)
and self._queue_item and self._queue_item
and self._queue_item.queue_id == event[1]["data"]["queue_id"] and self._queue_item.queue_id == payload.queue_id
): ):
self._cancel_event.set() self._cancel_event.set()
self._poll_now() self._poll_now()
elif event_name == "batch_enqueued": elif isinstance(payload, BatchEnqueuedEvent):
self._poll_now() self._poll_now()
elif event_name == "queue_item_status_changed" and event[1]["data"]["queue_item"]["status"] in [ elif isinstance(payload, QueueItemStatusChangedEvent) and payload.status in ["completed", "failed", "canceled"]:
"completed",
"failed",
"canceled",
]:
self._poll_now() self._poll_now()
def resume(self) -> SessionProcessorStatus: def resume(self) -> SessionProcessorStatus:
@ -422,6 +399,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
poll_now_event.wait(self._polling_interval) poll_now_event.wait(self._polling_interval)
continue continue
self._invoker.services.events.emit_session_started(self._queue_item)
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}") self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear() cancel_event.clear()

View File

@ -2,10 +2,13 @@ import sqlite3
import threading import threading
from typing import Optional, Union, cast from typing import Optional, Union, cast
from fastapi_events.handlers.local import local_handler from invokeai.app.services.events.events_common import (
from fastapi_events.typing import Event as FastAPIEvent FastAPIEvent,
InvocationErrorEvent,
from invokeai.app.services.events.events_base import EventServiceBase SessionCanceledEvent,
SessionCompleteEvent,
register_events,
)
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
from invokeai.app.services.session_queue.session_queue_common import ( from invokeai.app.services.session_queue.session_queue_common import (
@ -42,7 +45,11 @@ class SqliteSessionQueue(SessionQueueBase):
self.__invoker = invoker self.__invoker = invoker
self._set_in_progress_to_canceled() self._set_in_progress_to_canceled()
prune_result = self.prune(DEFAULT_QUEUE_ID) prune_result = self.prune(DEFAULT_QUEUE_ID)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
register_events(events={InvocationErrorEvent}, func=self._handle_error_event)
register_events(events={SessionCompleteEvent}, func=self._handle_complete_event)
register_events(events={SessionCanceledEvent}, func=self._handle_cancel_event)
if prune_result.deleted > 0: if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items") self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
@ -52,59 +59,41 @@ class SqliteSessionQueue(SessionQueueBase):
self.__conn = db.conn self.__conn = db.conn
self.__cursor = self.__conn.cursor() self.__cursor = self.__conn.cursor()
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool: async def _handle_complete_event(self, event: FastAPIEvent[SessionCompleteEvent]) -> None:
return event[1]["event"] in match_in
async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent:
event_name = event[1]["event"]
# This was a match statement, but match is not supported on python 3.9
if event_name == "graph_execution_state_complete":
await self._handle_complete_event(event)
elif event_name == "invocation_error":
await self._handle_error_event(event)
elif event_name == "session_canceled":
await self._handle_cancel_event(event)
return event
async def _handle_complete_event(self, event: FastAPIEvent) -> None:
try: try:
item_id = event[1]["data"]["queue_item_id"]
# When a queue item has an error, we get an error event, then a completed event. # When a queue item has an error, we get an error event, then a completed event.
# Mark the queue item completed only if it isn't already marked completed, e.g. # Mark the queue item completed only if it isn't already marked completed, e.g.
# by a previously-handled error event. # by a previously-handled error event.
queue_item = self.get_queue_item(item_id) _event_name, payload = event
if queue_item.status not in ["completed", "failed", "canceled"]:
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed")
except SessionQueueItemNotFoundError:
return
async def _handle_error_event(self, event: FastAPIEvent) -> None: queue_item = self.get_queue_item(payload.item_id)
if queue_item.status not in ["completed", "failed", "canceled"]:
self._set_queue_item_status(item_id=payload.item_id, status="completed")
except SessionQueueItemNotFoundError:
pass
async def _handle_error_event(self, event: FastAPIEvent[InvocationErrorEvent]) -> None:
try: try:
item_id = event[1]["data"]["queue_item_id"] _event_name, payload = event
error_type = event[1]["data"]["error_type"]
error_message = event[1]["data"]["error_message"]
error_traceback = event[1]["data"]["error_traceback"]
queue_item = self.get_queue_item(item_id)
# always set to failed if have an error, even if previously the item was marked completed or canceled # always set to failed if have an error, even if previously the item was marked completed or canceled
queue_item = self._set_queue_item_status( self._set_queue_item_status(
item_id=queue_item.item_id, item_id=payload.item_id,
status="failed", status="failed",
error_type=error_type, error_type=payload.error_type,
error_message=error_message, error_message=payload.error_message,
error_traceback=error_traceback, error_traceback=payload.error_traceback,
) )
except SessionQueueItemNotFoundError: except SessionQueueItemNotFoundError:
return pass
async def _handle_cancel_event(self, event: FastAPIEvent) -> None: async def _handle_cancel_event(self, event: FastAPIEvent[SessionCanceledEvent]) -> None:
try: try:
item_id = event[1]["data"]["queue_item_id"] _event_name, payload = event
queue_item = self.get_queue_item(item_id) queue_item = self.get_queue_item(payload.item_id)
if queue_item.status not in ["completed", "failed", "canceled"]: if queue_item.status not in ["completed", "failed", "canceled"]:
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled") self._set_queue_item_status(item_id=payload.item_id, status="canceled")
except SessionQueueItemNotFoundError: except SessionQueueItemNotFoundError:
return pass
def _set_in_progress_to_canceled(self) -> None: def _set_in_progress_to_canceled(self) -> None:
""" """
@ -306,11 +295,7 @@ class SqliteSessionQueue(SessionQueueBase):
queue_item = self.get_queue_item(item_id) queue_item = self.get_queue_item(item_id)
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id) batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_item.queue_id) queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
self.__invoker.services.events.emit_queue_item_status_changed( self.__invoker.services.events.emit_queue_item_status_changed(queue_item, batch_status, queue_status)
session_queue_item=queue_item,
batch_status=batch_status,
queue_status=queue_status,
)
return queue_item return queue_item
def is_empty(self, queue_id: str) -> IsEmptyResult: def is_empty(self, queue_id: str) -> IsEmptyResult:
@ -422,12 +407,7 @@ class SqliteSessionQueue(SessionQueueBase):
queue_item = self.get_queue_item(item_id) queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["canceled", "failed", "completed"]: if queue_item.status not in ["canceled", "failed", "completed"]:
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled") queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
self.__invoker.services.events.emit_session_canceled( self.__invoker.services.events.emit_session_canceled(queue_item)
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
queue_batch_id=queue_item.batch_id,
graph_execution_state_id=queue_item.session_id,
)
return queue_item return queue_item
def fail_queue_item( def fail_queue_item(
@ -446,12 +426,7 @@ class SqliteSessionQueue(SessionQueueBase):
error_message=error_message, error_message=error_message,
error_traceback=error_traceback, error_traceback=error_traceback,
) )
self.__invoker.services.events.emit_session_canceled( self.__invoker.services.events.emit_session_canceled(queue_item)
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
queue_batch_id=queue_item.batch_id,
graph_execution_state_id=queue_item.session_id,
)
return queue_item return queue_item
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult: def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
@ -487,18 +462,11 @@ class SqliteSessionQueue(SessionQueueBase):
) )
self.__conn.commit() self.__conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids: if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self.__invoker.services.events.emit_session_canceled( self.__invoker.services.events.emit_session_canceled(current_queue_item)
queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id,
queue_batch_id=current_queue_item.batch_id,
graph_execution_state_id=current_queue_item.session_id,
)
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id) batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id) queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed( self.__invoker.services.events.emit_queue_item_status_changed(
session_queue_item=current_queue_item, current_queue_item, batch_status, queue_status
batch_status=batch_status,
queue_status=queue_status,
) )
except Exception: except Exception:
self.__conn.rollback() self.__conn.rollback()
@ -538,18 +506,11 @@ class SqliteSessionQueue(SessionQueueBase):
) )
self.__conn.commit() self.__conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id: if current_queue_item is not None and current_queue_item.queue_id == queue_id:
self.__invoker.services.events.emit_session_canceled( self.__invoker.services.events.emit_session_canceled(current_queue_item)
queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id,
queue_batch_id=current_queue_item.batch_id,
graph_execution_state_id=current_queue_item.session_id,
)
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id) batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id) queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed( self.__invoker.services.events.emit_queue_item_status_changed(
session_queue_item=current_queue_item, current_queue_item, batch_status, queue_status
batch_status=batch_status,
queue_status=queue_status,
) )
except Exception: except Exception:
self.__conn.rollback() self.__conn.rollback()

View File

@ -353,11 +353,11 @@ class ModelsInterface(InvocationContextInterface):
if isinstance(identifier, str): if isinstance(identifier, str):
model = self._services.model_manager.store.get_model(identifier) model = self._services.model_manager.store.get_model(identifier)
return self._services.model_manager.load.load_model(model, submodel_type, self._data) return self._services.model_manager.load.load_model(model, submodel_type)
else: else:
_submodel_type = submodel_type or identifier.submodel_type _submodel_type = submodel_type or identifier.submodel_type
model = self._services.model_manager.store.get_model(identifier.key) model = self._services.model_manager.store.get_model(identifier.key)
return self._services.model_manager.load.load_model(model, _submodel_type, self._data) return self._services.model_manager.load.load_model(model, _submodel_type)
def load_by_attrs( def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
@ -382,7 +382,7 @@ class ModelsInterface(InvocationContextInterface):
if len(configs) > 1: if len(configs) > 1:
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}") raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data) return self._services.model_manager.load.load_model(configs[0], submodel_type)
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig: def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
"""Gets a model's config. """Gets a model's config.

View File

@ -113,15 +113,10 @@ def stable_diffusion_step_callback(
dataURL = image_to_dataURL(image, image_format="JPEG") dataURL = image_to_dataURL(image, image_format="JPEG")
events.emit_generator_progress( events.emit_invocation_denoise_progress(
queue_id=context_data.queue_item.queue_id, context_data.queue_item,
queue_item_id=context_data.queue_item.item_id, context_data.invocation,
queue_batch_id=context_data.queue_item.batch_id, intermediate_state.step,
graph_execution_state_id=context_data.queue_item.session_id, intermediate_state.total_steps * intermediate_state.order,
node_id=context_data.invocation.id, ProgressImage(dataURL=dataURL, width=width, height=height),
source_node_id=context_data.source_invocation_id,
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
step=intermediate_state.step,
order=intermediate_state.order,
total_steps=intermediate_state.total_steps,
) )

View File

@ -14,10 +14,12 @@ from pydantic_core import Url
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.model_install import ( from invokeai.app.services.model_install import (
ModelInstallServiceBase,
)
from invokeai.app.services.model_install.model_install_common import (
InstallStatus, InstallStatus,
LocalModelSource, LocalModelSource,
ModelInstallJob, ModelInstallJob,
ModelInstallServiceBase,
URLModelSource, URLModelSource,
) )
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException