mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into lstein/feat/simple-mm2-api
This commit is contained in:
@ -3,9 +3,7 @@ import logging
|
||||
import mimetypes
|
||||
import socket
|
||||
from contextlib import asynccontextmanager
|
||||
from inspect import signature
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import uvicorn
|
||||
@ -13,11 +11,9 @@ from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from pydantic.json_schema import models_json_schema
|
||||
from torch.backends.mps import is_available as is_mps_available
|
||||
|
||||
# for PyCharm:
|
||||
@ -25,10 +21,8 @@ from torch.backends.mps import is_available as is_mps_available
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
import invokeai.frontend.web as web_dir
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.events.events_common import EventBase
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.app.util.custom_openapi import get_openapi_func
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
@ -45,11 +39,6 @@ from .api.routers import (
|
||||
workflows,
|
||||
)
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
UIConfigBase,
|
||||
)
|
||||
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
||||
|
||||
app_config = get_config()
|
||||
|
||||
@ -119,84 +108,7 @@ app.include_router(app_info.app_router, prefix="/api")
|
||||
app.include_router(session_queue.session_queue_router, prefix="/api")
|
||||
app.include_router(workflows.workflows_router, prefix="/api")
|
||||
|
||||
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||
def custom_openapi() -> dict[str, Any]:
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
description="An API for invoking AI image operations",
|
||||
version="1.0.0",
|
||||
routes=app.routes,
|
||||
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
|
||||
)
|
||||
|
||||
# Add all outputs
|
||||
all_invocations = BaseInvocation.get_invocations()
|
||||
output_types = set()
|
||||
output_type_titles = {}
|
||||
for invoker in all_invocations:
|
||||
output_type = signature(invoker.invoke).return_annotation
|
||||
output_types.add(output_type)
|
||||
|
||||
output_schemas = models_json_schema(
|
||||
models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
|
||||
)
|
||||
for schema_key, output_schema in output_schemas[1]["$defs"].items():
|
||||
# TODO: note that we assume the schema_key here is the TYPE.__name__
|
||||
# This could break in some cases, figure out a better way to do it
|
||||
output_type_titles[schema_key] = output_schema["title"]
|
||||
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
||||
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
|
||||
|
||||
# Some models don't end up in the schemas as standalone definitions
|
||||
additional_schemas = models_json_schema(
|
||||
[
|
||||
(UIConfigBase, "serialization"),
|
||||
(InputFieldJSONSchemaExtra, "serialization"),
|
||||
(OutputFieldJSONSchemaExtra, "serialization"),
|
||||
(ModelIdentifierField, "serialization"),
|
||||
(ProgressImage, "serialization"),
|
||||
],
|
||||
ref_template="#/components/schemas/{model}",
|
||||
)
|
||||
for schema_key, schema_json in additional_schemas[1]["$defs"].items():
|
||||
openapi_schema["components"]["schemas"][schema_key] = schema_json
|
||||
|
||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
# Add a reference to the output type to additionalProperties of the invoker schema
|
||||
for invoker in all_invocations:
|
||||
invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute
|
||||
output_type = signature(obj=invoker.invoke).return_annotation
|
||||
output_type_title = output_type_titles[output_type.__name__]
|
||||
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
|
||||
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
||||
invoker_schema["output"] = outputs_ref
|
||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["properties"][invoker.get_type()] = outputs_ref
|
||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
|
||||
invoker_schema["class"] = "invocation"
|
||||
|
||||
# Add all event schemas
|
||||
for event in sorted(EventBase.get_events(), key=lambda e: e.__name__):
|
||||
json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
|
||||
if "$defs" in json_schema:
|
||||
for schema_key, schema in json_schema["$defs"].items():
|
||||
openapi_schema["components"]["schemas"][schema_key] = schema
|
||||
del json_schema["$defs"]
|
||||
openapi_schema["components"]["schemas"][event.__name__] = json_schema
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
|
||||
app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment
|
||||
app.openapi = get_openapi_func(app)
|
||||
|
||||
|
||||
@app.get("/docs", include_in_schema=False)
|
||||
|
@ -98,11 +98,13 @@ class BaseInvocationOutput(BaseModel):
|
||||
|
||||
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
|
||||
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
||||
_typeadapter_needs_update: ClassVar[bool] = False
|
||||
|
||||
@classmethod
|
||||
def register_output(cls, output: BaseInvocationOutput) -> None:
|
||||
"""Registers an invocation output."""
|
||||
cls._output_classes.add(output)
|
||||
cls._typeadapter_needs_update = True
|
||||
|
||||
@classmethod
|
||||
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
|
||||
@ -112,11 +114,12 @@ class BaseInvocationOutput(BaseModel):
|
||||
@classmethod
|
||||
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
|
||||
if not cls._typeadapter:
|
||||
InvocationOutputsUnion = TypeAliasType(
|
||||
"InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
|
||||
if not cls._typeadapter or cls._typeadapter_needs_update:
|
||||
AnyInvocationOutput = TypeAliasType(
|
||||
"AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
|
||||
cls._typeadapter = TypeAdapter(AnyInvocationOutput)
|
||||
cls._typeadapter_needs_update = False
|
||||
return cls._typeadapter
|
||||
|
||||
@classmethod
|
||||
@ -125,12 +128,13 @@ class BaseInvocationOutput(BaseModel):
|
||||
return (i.get_type() for i in BaseInvocationOutput.get_outputs())
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
|
||||
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
|
||||
# Because we use a pydantic Literal field with default value for the invocation type,
|
||||
# it will be typed as optional in the OpenAPI schema. Make it required manually.
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = []
|
||||
schema["class"] = "output"
|
||||
schema["required"].extend(["type"])
|
||||
|
||||
@classmethod
|
||||
@ -167,6 +171,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
|
||||
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
|
||||
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
||||
_typeadapter_needs_update: ClassVar[bool] = False
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> str:
|
||||
@ -177,15 +182,17 @@ class BaseInvocation(ABC, BaseModel):
|
||||
def register_invocation(cls, invocation: BaseInvocation) -> None:
|
||||
"""Registers an invocation."""
|
||||
cls._invocation_classes.add(invocation)
|
||||
cls._typeadapter_needs_update = True
|
||||
|
||||
@classmethod
|
||||
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
|
||||
if not cls._typeadapter:
|
||||
InvocationsUnion = TypeAliasType(
|
||||
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
|
||||
if not cls._typeadapter or cls._typeadapter_needs_update:
|
||||
AnyInvocation = TypeAliasType(
|
||||
"AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(InvocationsUnion)
|
||||
cls._typeadapter = TypeAdapter(AnyInvocation)
|
||||
cls._typeadapter_needs_update = False
|
||||
return cls._typeadapter
|
||||
|
||||
@classmethod
|
||||
@ -221,7 +228,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
return signature(cls.invoke).return_annotation
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
|
||||
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
|
||||
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
|
||||
if uiconfig is not None:
|
||||
@ -237,6 +244,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
schema["version"] = uiconfig.version
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = []
|
||||
schema["class"] = "invocation"
|
||||
schema["required"].extend(["type", "id"])
|
||||
|
||||
@abstractmethod
|
||||
@ -310,7 +318,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
protected_namespaces=(),
|
||||
validate_assignment=True,
|
||||
json_schema_extra=json_schema_extra,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
json_schema_serialization_defaults_required=False,
|
||||
coerce_numbers_to_str=True,
|
||||
)
|
||||
|
||||
|
@ -1,11 +1,10 @@
|
||||
from math import floor
|
||||
from typing import TYPE_CHECKING, Any, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.registry.payload_schema import registry as payload_schema
|
||||
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
QUEUE_ITEM_STATUS,
|
||||
@ -14,6 +13,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
SessionQueueItem,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
@ -33,6 +33,7 @@ class EventBase(BaseModel):
|
||||
A timestamp is automatically added to the event when it is created.
|
||||
"""
|
||||
|
||||
__event_name__: ClassVar[str]
|
||||
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
|
||||
|
||||
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
|
||||
@ -97,7 +98,7 @@ class InvocationEventBase(QueueItemEventBase):
|
||||
item_id: int = Field(description="The ID of the queue item")
|
||||
batch_id: str = Field(description="The ID of the queue batch")
|
||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||
invocation: SerializeAsAny[BaseInvocation] = Field(description="The ID of the invocation")
|
||||
invocation: AnyInvocation = Field(description="The ID of the invocation")
|
||||
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
|
||||
|
||||
|
||||
@ -108,7 +109,7 @@ class InvocationStartedEvent(InvocationEventBase):
|
||||
__event_name__ = "invocation_started"
|
||||
|
||||
@classmethod
|
||||
def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent":
|
||||
def build(cls, queue_item: SessionQueueItem, invocation: AnyInvocation) -> "InvocationStartedEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
@ -135,7 +136,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
|
||||
def build(
|
||||
cls,
|
||||
queue_item: SessionQueueItem,
|
||||
invocation: BaseInvocation,
|
||||
invocation: AnyInvocation,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
progress_image: ProgressImage,
|
||||
) -> "InvocationDenoiseProgressEvent":
|
||||
@ -173,11 +174,11 @@ class InvocationCompleteEvent(InvocationEventBase):
|
||||
|
||||
__event_name__ = "invocation_complete"
|
||||
|
||||
result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation")
|
||||
result: AnyInvocationOutput = Field(description="The result of the invocation")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput
|
||||
cls, queue_item: SessionQueueItem, invocation: AnyInvocation, result: AnyInvocationOutput
|
||||
) -> "InvocationCompleteEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
@ -206,7 +207,7 @@ class InvocationErrorEvent(InvocationEventBase):
|
||||
def build(
|
||||
cls,
|
||||
queue_item: SessionQueueItem,
|
||||
invocation: BaseInvocation,
|
||||
invocation: AnyInvocation,
|
||||
error_type: str,
|
||||
error_message: str,
|
||||
error_traceback: str,
|
||||
|
@ -2,18 +2,19 @@
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||
from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
import networkx as nx
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
GetCoreSchemaHandler,
|
||||
GetJsonSchemaHandler,
|
||||
ValidationError,
|
||||
field_validator,
|
||||
)
|
||||
from pydantic.fields import Field
|
||||
from pydantic.json_schema import JsonSchemaValue
|
||||
from pydantic_core import CoreSchema
|
||||
from pydantic_core import core_schema
|
||||
|
||||
# Importing * is bad karma but needed here for node detection
|
||||
from invokeai.app.invocations import * # noqa: F401 F403
|
||||
@ -277,73 +278,58 @@ class CollectInvocation(BaseInvocation):
|
||||
return CollectInvocationOutput(collection=copy.copy(self.collection))
|
||||
|
||||
|
||||
class AnyInvocation(BaseInvocation):
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
def validate_invocation(v: Any) -> "AnyInvocation":
|
||||
return BaseInvocation.get_typeadapter().validate_python(v)
|
||||
|
||||
return core_schema.no_info_plain_validator_function(validate_invocation)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(
|
||||
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
|
||||
) -> JsonSchemaValue:
|
||||
# Nodes are too powerful, we have to make our own OpenAPI schema manually
|
||||
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
|
||||
oneOf: list[dict[str, str]] = []
|
||||
names = [i.__name__ for i in BaseInvocation.get_invocations()]
|
||||
for name in sorted(names):
|
||||
oneOf.append({"$ref": f"#/components/schemas/{name}"})
|
||||
return {"oneOf": oneOf}
|
||||
|
||||
|
||||
class AnyInvocationOutput(BaseInvocationOutput):
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
|
||||
def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
|
||||
return BaseInvocationOutput.get_typeadapter().validate_python(v)
|
||||
|
||||
return core_schema.no_info_plain_validator_function(validate_invocation_output)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(
|
||||
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
|
||||
) -> JsonSchemaValue:
|
||||
# Nodes are too powerful, we have to make our own OpenAPI schema manually
|
||||
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
|
||||
|
||||
oneOf: list[dict[str, str]] = []
|
||||
names = [i.__name__ for i in BaseInvocationOutput.get_outputs()]
|
||||
for name in sorted(names):
|
||||
oneOf.append({"$ref": f"#/components/schemas/{name}"})
|
||||
return {"oneOf": oneOf}
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
id: str = Field(description="The id of this graph", default_factory=uuid_string)
|
||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||
nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict)
|
||||
nodes: dict[str, AnyInvocation] = Field(description="The nodes in this graph", default_factory=dict)
|
||||
edges: list[Edge] = Field(
|
||||
description="The connections between nodes and their fields in this graph",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
@field_validator("nodes", mode="plain")
|
||||
@classmethod
|
||||
def validate_nodes(cls, v: dict[str, Any]):
|
||||
"""Validates the nodes in the graph by retrieving a union of all node types and validating each node."""
|
||||
|
||||
# Invocations register themselves as their python modules are executed. The union of all invocations is
|
||||
# constructed at runtime. We use pydantic to validate `Graph.nodes` using that union.
|
||||
#
|
||||
# It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If
|
||||
# we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing
|
||||
# invocations will cause a graph to fail if they are used.
|
||||
#
|
||||
# We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the
|
||||
# pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime.
|
||||
#
|
||||
# This same pattern is used in `GraphExecutionState`.
|
||||
|
||||
nodes: dict[str, BaseInvocation] = {}
|
||||
typeadapter = BaseInvocation.get_typeadapter()
|
||||
for node_id, node in v.items():
|
||||
nodes[node_id] = typeadapter.validate_python(node)
|
||||
return nodes
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
# We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for
|
||||
# fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to
|
||||
# the generated schema as options for the `nodes` field.
|
||||
#
|
||||
# The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and
|
||||
# with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as
|
||||
# expected.
|
||||
#
|
||||
# You might be tempted to do something like this:
|
||||
#
|
||||
# ```py
|
||||
# cloned_model = create_model(cls.__name__, __base__=cls, nodes=...)
|
||||
# delattr(cloned_model, "validate_nodes")
|
||||
# cloned_model.model_rebuild(force=True)
|
||||
# json_schema = handler(cloned_model.__pydantic_core_schema__)
|
||||
# ```
|
||||
#
|
||||
# Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts
|
||||
# to build the JSON Schema for the cloned model. Instead, we have to manually clone the model.
|
||||
#
|
||||
# This same pattern is used in `GraphExecutionState`.
|
||||
|
||||
class Graph(BaseModel):
|
||||
id: Optional[str] = Field(default=None, description="The id of this graph")
|
||||
nodes: dict[
|
||||
str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")]
|
||||
] = Field(description="The nodes in this graph")
|
||||
edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph")
|
||||
|
||||
json_schema = handler(Graph.__pydantic_core_schema__)
|
||||
json_schema = handler.resolve_ref_schema(json_schema)
|
||||
return json_schema
|
||||
|
||||
def add_node(self, node: BaseInvocation) -> None:
|
||||
"""Adds a node to a graph
|
||||
|
||||
@ -774,7 +760,7 @@ class GraphExecutionState(BaseModel):
|
||||
)
|
||||
|
||||
# The results of executed nodes
|
||||
results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
|
||||
results: dict[str, AnyInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
|
||||
|
||||
# Errors raised when executing nodes
|
||||
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
|
||||
@ -791,52 +777,12 @@ class GraphExecutionState(BaseModel):
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
@field_validator("results", mode="plain")
|
||||
@classmethod
|
||||
def validate_results(cls, v: dict[str, BaseInvocationOutput]):
|
||||
"""Validates the results in the GES by retrieving a union of all output types and validating each result."""
|
||||
|
||||
# See the comment in `Graph.validate_nodes` for an explanation of this logic.
|
||||
results: dict[str, BaseInvocationOutput] = {}
|
||||
typeadapter = BaseInvocationOutput.get_typeadapter()
|
||||
for result_id, result in v.items():
|
||||
results[result_id] = typeadapter.validate_python(result)
|
||||
return results
|
||||
|
||||
@field_validator("graph")
|
||||
def graph_is_valid(cls, v: Graph):
|
||||
"""Validates that the graph is valid"""
|
||||
v.validate_self()
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
# See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic.
|
||||
class GraphExecutionState(BaseModel):
|
||||
"""Tracks the state of a graph execution"""
|
||||
|
||||
id: str = Field(description="The id of the execution state")
|
||||
graph: Graph = Field(description="The graph being executed")
|
||||
execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes")
|
||||
executed: set[str] = Field(description="The set of node ids that have been executed")
|
||||
executed_history: list[str] = Field(
|
||||
description="The list of node ids that have been executed, in order of execution"
|
||||
)
|
||||
results: dict[
|
||||
str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")]
|
||||
] = Field(description="The results of node executions")
|
||||
errors: dict[str, str] = Field(description="Errors raised when executing nodes")
|
||||
prepared_source_mapping: dict[str, str] = Field(
|
||||
description="The map of prepared nodes to original graph nodes"
|
||||
)
|
||||
source_prepared_mapping: dict[str, set[str]] = Field(
|
||||
description="The map of original graph nodes to prepared nodes"
|
||||
)
|
||||
|
||||
json_schema = handler(GraphExecutionState.__pydantic_core_schema__)
|
||||
json_schema = handler.resolve_ref_schema(json_schema)
|
||||
return json_schema
|
||||
|
||||
def next(self) -> Optional[BaseInvocation]:
|
||||
"""Gets the next node ready to execute."""
|
||||
|
||||
|
116
invokeai/app/util/custom_openapi.py
Normal file
116
invokeai/app/util/custom_openapi.py
Normal file
@ -0,0 +1,116 @@
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from pydantic.json_schema import models_json_schema
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase
|
||||
from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.events.events_common import EventBase
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
|
||||
|
||||
def move_defs_to_top_level(openapi_schema: dict[str, Any], component_schema: dict[str, Any]) -> None:
|
||||
"""Moves a component schema's $defs to the top level of the openapi schema. Useful when generating a schema
|
||||
for a single model that needs to be added back to the top level of the schema. Mutates openapi_schema and
|
||||
component_schema."""
|
||||
|
||||
defs = component_schema.pop("$defs", {})
|
||||
for schema_key, json_schema in defs.items():
|
||||
if schema_key in openapi_schema["components"]["schemas"]:
|
||||
continue
|
||||
openapi_schema["components"]["schemas"][schema_key] = json_schema
|
||||
|
||||
|
||||
def get_openapi_func(
|
||||
app: FastAPI, post_transform: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None
|
||||
) -> Callable[[], dict[str, Any]]:
|
||||
"""Gets the OpenAPI schema generator function.
|
||||
|
||||
Args:
|
||||
app (FastAPI): The FastAPI app to generate the schema for.
|
||||
post_transform (Optional[Callable[[dict[str, Any]], dict[str, Any]]], optional): A function to apply to the
|
||||
generated schema before returning it. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Callable[[], dict[str, Any]]: The OpenAPI schema generator function. When first called, the generated schema is
|
||||
cached in `app.openapi_schema`. On subsequent calls, the cached schema is returned. This caching behaviour
|
||||
matches FastAPI's default schema generation caching.
|
||||
"""
|
||||
|
||||
def openapi() -> dict[str, Any]:
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
description="An API for invoking AI image operations",
|
||||
version="1.0.0",
|
||||
routes=app.routes,
|
||||
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
|
||||
)
|
||||
|
||||
# We'll create a map of invocation type to output schema to make some types simpler on the client.
|
||||
invocation_output_map_properties: dict[str, Any] = {}
|
||||
invocation_output_map_required: list[str] = []
|
||||
|
||||
# We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
|
||||
for output in BaseInvocationOutput.get_outputs():
|
||||
json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
|
||||
move_defs_to_top_level(openapi_schema, json_schema)
|
||||
openapi_schema["components"]["schemas"][output.__name__] = json_schema
|
||||
|
||||
# Technically, invocations are added to the schema by pydantic, but we still need to manually set their output
|
||||
# property, so we'll just do it all manually.
|
||||
for invocation in BaseInvocation.get_invocations():
|
||||
json_schema = invocation.model_json_schema(
|
||||
mode="serialization", ref_template="#/components/schemas/{model}"
|
||||
)
|
||||
move_defs_to_top_level(openapi_schema, json_schema)
|
||||
output_title = invocation.get_output_annotation().__name__
|
||||
outputs_ref = {"$ref": f"#/components/schemas/{output_title}"}
|
||||
json_schema["output"] = outputs_ref
|
||||
openapi_schema["components"]["schemas"][invocation.__name__] = json_schema
|
||||
|
||||
# Add this invocation and its output to the output map
|
||||
invocation_type = invocation.get_type()
|
||||
invocation_output_map_properties[invocation_type] = json_schema["output"]
|
||||
invocation_output_map_required.append(invocation_type)
|
||||
|
||||
# Add the output map to the schema
|
||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
|
||||
"type": "object",
|
||||
"properties": invocation_output_map_properties,
|
||||
"required": invocation_output_map_required,
|
||||
}
|
||||
|
||||
# Some models don't end up in the schemas as standalone definitions because they aren't used directly in the API.
|
||||
# We need to add them manually here. WARNING: Pydantic can choke if you call `model.model_json_schema()` to get
|
||||
# a schema. This has something to do with schema refs - not totally clear. For whatever reason, using
|
||||
# `models_json_schema` seems to work fine.
|
||||
additional_models = [
|
||||
*EventBase.get_events(),
|
||||
UIConfigBase,
|
||||
InputFieldJSONSchemaExtra,
|
||||
OutputFieldJSONSchemaExtra,
|
||||
ModelIdentifierField,
|
||||
ProgressImage,
|
||||
]
|
||||
|
||||
additional_schemas = models_json_schema(
|
||||
[(m, "serialization") for m in additional_models],
|
||||
ref_template="#/components/schemas/{model}",
|
||||
)
|
||||
# additional_schemas[1] is a dict of $defs that we need to add to the top level of the schema
|
||||
move_defs_to_top_level(openapi_schema, additional_schemas[1])
|
||||
|
||||
if post_transform is not None:
|
||||
openapi_schema = post_transform(openapi_schema)
|
||||
|
||||
openapi_schema["components"]["schemas"] = dict(sorted(openapi_schema["components"]["schemas"].items()))
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
return openapi
|
Reference in New Issue
Block a user