refactor(events): use pydantic schemas for events

Our events handling and implementation has a couple pain points:
- Adding or removing data from event payloads requires changes wherever the events are dispatched from.
- We have no type safety for events and need to rely on string matching and dict access when interacting with events.
- Frontend types for socket events must be manually typed. This has caused several bugs.

`fastapi-events` has a neat feature where you can create a pydantic model as an event payload, give it an `__event_name__` attr, and then dispatch the model directly.

This allows us to eliminate a layer of indirection and some unpleasant complexity:
- Event handler callbacks get type hints for their event payloads, and can use `isinstance` on them if needed.
- Event payload construction is now the responsibility of the event itself (a pydantic model), not the service. Every event model has a `build` class method, encapsulating this logic. The build methods are provided as few args as possible. For example, `InvocationStartedEvent.build()` gets the invocation instance and queue item, and can choose the data it wants to include in the event payload.
- Frontend event types may be autogenerated from the OpenAPI schema. We use the payload registry feature of `fastapi-events` to collect all payload models into one place, making it trivial to keep our schema and frontend types in sync.

This commit moves the backend over to this improved event handling setup.
This commit is contained in:
psychedelicious
2024-03-14 19:04:19 +11:00
parent 461e857824
commit 9bd78823a3
21 changed files with 1263 additions and 1025 deletions

View File

@ -5,7 +5,7 @@ import socket
from contextlib import asynccontextmanager
from inspect import signature
from pathlib import Path
from typing import Any
from typing import Any, cast
import torch
import uvicorn
@ -17,6 +17,8 @@ from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from fastapi_events.registry.payload_schema import registry as fastapi_events_registry
from pydantic import BaseModel
from pydantic.json_schema import models_json_schema
from torch.backends.mps import is_available as is_mps_available
@ -182,23 +184,16 @@ def custom_openapi() -> dict[str, Any]:
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
invoker_schema["class"] = "invocation"
# This code no longer seems to be necessary?
# Leave it here just in case
#
# from invokeai.backend.model_manager import get_model_config_formats
# formats = get_model_config_formats()
# for model_config_name, enum_set in formats.items():
# if model_config_name in openapi_schema["components"]["schemas"]:
# # print(f"Config with name {name} already defined")
# continue
# openapi_schema["components"]["schemas"][model_config_name] = {
# "title": model_config_name,
# "description": "An enumeration.",
# "type": "string",
# "enum": [v.value for v in enum_set],
# }
# Add all pydantic event schemas registered with fastapi-events
for payload in fastapi_events_registry.data.values():
json_schema = cast(BaseModel, payload).model_json_schema(
mode="serialization", ref_template="#/components/schemas/{model}"
)
if "$defs" in json_schema:
for schema_key, schema in json_schema["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = schema
del json_schema["$defs"]
openapi_schema["components"]["schemas"][payload.__name__] = json_schema
app.openapi_schema = openapi_schema
return app.openapi_schema