fix: openapi stuff (#6454)
## Summary

Fix some issues with OpenAPI schema generation. See commits for details.

## Related Issues / Discussions

https://discord.com/channels/1020123559063990373/1049495067846524939/1245141831394529352

## QA Instructions

App should work, workflows should work.

## Merge Plan

n/a

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [ ] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
Commit: cfb12615e1
Makefile (4 changed lines):

```diff
@@ -18,6 +18,7 @@ help:
 	@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
 	@echo "installer-zip Build the installer .zip file for the current version"
 	@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
+	@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"

 # Runs ruff, fixing any safely-fixable errors and formatting
 ruff:
@@ -70,3 +71,6 @@ installer-zip:
 tag-release:
 	cd installer && ./tag_release.sh

+# Generate the OpenAPI Schema for the app
+openapi:
+	python scripts/generate_openapi_schema.py
```
```diff
@@ -3,9 +3,7 @@ import logging
 import mimetypes
 import socket
 from contextlib import asynccontextmanager
-from inspect import signature
 from pathlib import Path
-from typing import Any

 import torch
 import uvicorn
@@ -13,11 +11,9 @@ from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
 from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
-from fastapi.openapi.utils import get_openapi
 from fastapi.responses import HTMLResponse
 from fastapi_events.handlers.local import local_handler
 from fastapi_events.middleware import EventHandlerASGIMiddleware
-from pydantic.json_schema import models_json_schema
 from torch.backends.mps import is_available as is_mps_available

 # for PyCharm:
@@ -25,10 +21,8 @@ from torch.backends.mps import is_available as is_mps_available
 import invokeai.backend.util.hotfixes  # noqa: F401 (monkeypatching on import)
 import invokeai.frontend.web as web_dir
 from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
-from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.services.config.config_default import get_config
-from invokeai.app.services.events.events_common import EventBase
-from invokeai.app.services.session_processor.session_processor_common import ProgressImage
+from invokeai.app.util.custom_openapi import get_openapi_func
 from invokeai.backend.util.devices import TorchDevice

 from ..backend.util.logging import InvokeAILogger
@@ -45,11 +39,6 @@ from .api.routers import (
     workflows,
 )
 from .api.sockets import SocketIO
-from .invocations.baseinvocation import (
-    BaseInvocation,
-    UIConfigBase,
-)
-from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra

 app_config = get_config()

@@ -119,84 +108,7 @@ app.include_router(app_info.app_router, prefix="/api")
 app.include_router(session_queue.session_queue_router, prefix="/api")
 app.include_router(workflows.workflows_router, prefix="/api")

-
-# Build a custom OpenAPI to include all outputs
-# TODO: can outputs be included on metadata of invocation schemas somehow?
-def custom_openapi() -> dict[str, Any]:
-    if app.openapi_schema:
-        return app.openapi_schema
-    openapi_schema = get_openapi(
-        title=app.title,
-        description="An API for invoking AI image operations",
-        version="1.0.0",
-        routes=app.routes,
-        separate_input_output_schemas=False,  # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
-    )
-
-    # Add all outputs
-    all_invocations = BaseInvocation.get_invocations()
-    output_types = set()
-    output_type_titles = {}
-    for invoker in all_invocations:
-        output_type = signature(invoker.invoke).return_annotation
-        output_types.add(output_type)
-
-    output_schemas = models_json_schema(
-        models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
-    )
-    for schema_key, output_schema in output_schemas[1]["$defs"].items():
-        # TODO: note that we assume the schema_key here is the TYPE.__name__
-        # This could break in some cases, figure out a better way to do it
-        output_type_titles[schema_key] = output_schema["title"]
-        openapi_schema["components"]["schemas"][schema_key] = output_schema
-        openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
-
-    # Some models don't end up in the schemas as standalone definitions
-    additional_schemas = models_json_schema(
-        [
-            (UIConfigBase, "serialization"),
-            (InputFieldJSONSchemaExtra, "serialization"),
-            (OutputFieldJSONSchemaExtra, "serialization"),
-            (ModelIdentifierField, "serialization"),
-            (ProgressImage, "serialization"),
-        ],
-        ref_template="#/components/schemas/{model}",
-    )
-    for schema_key, schema_json in additional_schemas[1]["$defs"].items():
-        openapi_schema["components"]["schemas"][schema_key] = schema_json
-
-    openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
-        "type": "object",
-        "properties": {},
-        "required": [],
-    }
-
-    # Add a reference to the output type to additionalProperties of the invoker schema
-    for invoker in all_invocations:
-        invoker_name = invoker.__name__  # type: ignore [attr-defined] # this is a valid attribute
-        output_type = signature(obj=invoker.invoke).return_annotation
-        output_type_title = output_type_titles[output_type.__name__]
-        invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
-        outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
-        invoker_schema["output"] = outputs_ref
-        openapi_schema["components"]["schemas"]["InvocationOutputMap"]["properties"][invoker.get_type()] = outputs_ref
-        openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
-        invoker_schema["class"] = "invocation"
-
-    # Add all event schemas
-    for event in sorted(EventBase.get_events(), key=lambda e: e.__name__):
-        json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
-        if "$defs" in json_schema:
-            for schema_key, schema in json_schema["$defs"].items():
-                openapi_schema["components"]["schemas"][schema_key] = schema
-            del json_schema["$defs"]
-        openapi_schema["components"]["schemas"][event.__name__] = json_schema
-
-    app.openapi_schema = openapi_schema
-    return app.openapi_schema
-
-
-app.openapi = custom_openapi  # type: ignore [method-assign] # this is a valid assignment
+app.openapi = get_openapi_func(app)

 @app.get("/docs", include_in_schema=False)
```
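The net effect of this hunk is that the app module no longer builds the OpenAPI document inline; it delegates to `get_openapi_func`. A minimal sketch of the new wiring, assuming this hunk is the FastAPI app module (the script change at the bottom of this diff imports `app` from `invokeai.app.api_app`); only the last line comes from the diff, the constructor arguments are placeholders:

```py
# Hedged sketch of the new wiring; FastAPI constructor args are placeholders.
from fastapi import FastAPI

from invokeai.app.util.custom_openapi import get_openapi_func

app = FastAPI(title="Invoke AI")  # placeholder; the real app configures much more

# Replaces the old inline custom_openapi(). The returned callable caches its result
# in app.openapi_schema, matching FastAPI's default schema-caching behaviour.
app.openapi = get_openapi_func(app)
```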
```diff
@@ -98,11 +98,13 @@ class BaseInvocationOutput(BaseModel):

     _output_classes: ClassVar[set[BaseInvocationOutput]] = set()
     _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
+    _typeadapter_needs_update: ClassVar[bool] = False

     @classmethod
     def register_output(cls, output: BaseInvocationOutput) -> None:
         """Registers an invocation output."""
         cls._output_classes.add(output)
+        cls._typeadapter_needs_update = True

     @classmethod
     def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
@@ -112,11 +114,12 @@ class BaseInvocationOutput(BaseModel):
     @classmethod
     def get_typeadapter(cls) -> TypeAdapter[Any]:
         """Gets a pydantc TypeAdapter for the union of all invocation output types."""
-        if not cls._typeadapter:
-            InvocationOutputsUnion = TypeAliasType(
-                "InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
+        if not cls._typeadapter or cls._typeadapter_needs_update:
+            AnyInvocationOutput = TypeAliasType(
+                "AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
             )
-            cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
+            cls._typeadapter = TypeAdapter(AnyInvocationOutput)
+            cls._typeadapter_needs_update = False
         return cls._typeadapter

     @classmethod
@@ -125,12 +128,13 @@ class BaseInvocationOutput(BaseModel):
         return (i.get_type() for i in BaseInvocationOutput.get_outputs())

     @staticmethod
-    def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
+    def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
         """Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
         # Because we use a pydantic Literal field with default value for the invocation type,
         # it will be typed as optional in the OpenAPI schema. Make it required manually.
         if "required" not in schema or not isinstance(schema["required"], list):
             schema["required"] = []
+        schema["class"] = "output"
         schema["required"].extend(["type"])

     @classmethod
@@ -167,6 +171,7 @@ class BaseInvocation(ABC, BaseModel):

     _invocation_classes: ClassVar[set[BaseInvocation]] = set()
     _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
+    _typeadapter_needs_update: ClassVar[bool] = False

     @classmethod
     def get_type(cls) -> str:
@@ -177,15 +182,17 @@ class BaseInvocation(ABC, BaseModel):
     def register_invocation(cls, invocation: BaseInvocation) -> None:
         """Registers an invocation."""
         cls._invocation_classes.add(invocation)
+        cls._typeadapter_needs_update = True

     @classmethod
     def get_typeadapter(cls) -> TypeAdapter[Any]:
         """Gets a pydantc TypeAdapter for the union of all invocation types."""
-        if not cls._typeadapter:
-            InvocationsUnion = TypeAliasType(
-                "InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
+        if not cls._typeadapter or cls._typeadapter_needs_update:
+            AnyInvocation = TypeAliasType(
+                "AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
             )
-            cls._typeadapter = TypeAdapter(InvocationsUnion)
+            cls._typeadapter = TypeAdapter(AnyInvocation)
+            cls._typeadapter_needs_update = False
         return cls._typeadapter

     @classmethod
@@ -221,7 +228,7 @@ class BaseInvocation(ABC, BaseModel):
         return signature(cls.invoke).return_annotation

     @staticmethod
-    def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
+    def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
         """Adds various UI-facing attributes to the invocation's OpenAPI schema."""
         uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
         if uiconfig is not None:
@@ -237,6 +244,7 @@ class BaseInvocation(ABC, BaseModel):
             schema["version"] = uiconfig.version
         if "required" not in schema or not isinstance(schema["required"], list):
             schema["required"] = []
+        schema["class"] = "invocation"
         schema["required"].extend(["type", "id"])

     @abstractmethod
@@ -310,7 +318,7 @@ class BaseInvocation(ABC, BaseModel):
         protected_namespaces=(),
         validate_assignment=True,
         json_schema_extra=json_schema_extra,
-        json_schema_serialization_defaults_required=True,
+        json_schema_serialization_defaults_required=False,
         coerce_numbers_to_str=True,
     )

```
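The new `_typeadapter_needs_update` flag is a cache-invalidation bit: the discriminated-union `TypeAdapter` is only rebuilt when an invocation or output class registers after the adapter was last built. A self-contained toy sketch of that pattern (the `Node`/`NodeA`/`NodeB` names are hypothetical, not part of the codebase):

```py
# Toy illustration of the cached-TypeAdapter-with-dirty-flag pattern from the hunk above.
from typing import Annotated, Any, ClassVar, Literal, Optional, Union

from pydantic import BaseModel, Field, TypeAdapter


class Node(BaseModel):
    _classes: ClassVar[set[type["Node"]]] = set()
    _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
    _typeadapter_needs_update: ClassVar[bool] = False

    @classmethod
    def register(cls, node_cls: type["Node"]) -> type["Node"]:
        cls._classes.add(node_cls)
        cls._typeadapter_needs_update = True  # invalidate the cached adapter
        return node_cls

    @classmethod
    def get_typeadapter(cls) -> TypeAdapter[Any]:
        # Rebuild only if no adapter exists yet or a class registered since the last build.
        if not cls._typeadapter or cls._typeadapter_needs_update:
            union = Annotated[Union[tuple(cls._classes)], Field(discriminator="type")]
            cls._typeadapter = TypeAdapter(union)
            cls._typeadapter_needs_update = False
        return cls._typeadapter


@Node.register
class NodeA(Node):
    type: Literal["a"] = "a"


@Node.register
class NodeB(Node):
    type: Literal["b"] = "b"


# Validates against the union of everything registered so far -> NodeB
print(type(Node.get_typeadapter().validate_python({"type": "b"})).__name__)
```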
```diff
@@ -3,9 +3,8 @@ from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, P

 from fastapi_events.handlers.local import local_handler
 from fastapi_events.registry.payload_schema import registry as payload_schema
-from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, field_validator
+from pydantic import BaseModel, ConfigDict, Field

-from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
 from invokeai.app.services.session_processor.session_processor_common import ProgressImage
 from invokeai.app.services.session_queue.session_queue_common import (
     QUEUE_ITEM_STATUS,
@@ -14,6 +13,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
     SessionQueueItem,
     SessionQueueStatus,
 )
+from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
 from invokeai.app.util.misc import get_timestamp
 from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
@@ -98,17 +98,9 @@ class InvocationEventBase(QueueItemEventBase):
     item_id: int = Field(description="The ID of the queue item")
     batch_id: str = Field(description="The ID of the queue batch")
     session_id: str = Field(description="The ID of the session (aka graph execution state)")
-    invocation: SerializeAsAny[BaseInvocation] = Field(description="The ID of the invocation")
+    invocation: AnyInvocation = Field(description="The ID of the invocation")
     invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")

-    @field_validator("invocation", mode="plain")
-    @classmethod
-    def validate_invocation(cls, v: Any):
-        """Validates the invocation using the dynamic type adapter."""
-
-        invocation = BaseInvocation.get_typeadapter().validate_python(v)
-        return invocation
-

 @payload_schema.register
 class InvocationStartedEvent(InvocationEventBase):
@@ -117,7 +109,7 @@ class InvocationStartedEvent(InvocationEventBase):
     __event_name__ = "invocation_started"

     @classmethod
-    def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent":
+    def build(cls, queue_item: SessionQueueItem, invocation: AnyInvocation) -> "InvocationStartedEvent":
         return cls(
             queue_id=queue_item.queue_id,
             item_id=queue_item.item_id,
@@ -144,7 +136,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
     def build(
         cls,
         queue_item: SessionQueueItem,
-        invocation: BaseInvocation,
+        invocation: AnyInvocation,
         intermediate_state: PipelineIntermediateState,
         progress_image: ProgressImage,
     ) -> "InvocationDenoiseProgressEvent":
@@ -182,19 +174,11 @@ class InvocationCompleteEvent(InvocationEventBase):

     __event_name__ = "invocation_complete"

-    result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation")
+    result: AnyInvocationOutput = Field(description="The result of the invocation")

-    @field_validator("result", mode="plain")
-    @classmethod
-    def validate_results(cls, v: Any):
-        """Validates the invocation result using the dynamic type adapter."""
-
-        result = BaseInvocationOutput.get_typeadapter().validate_python(v)
-        return result
-
     @classmethod
     def build(
-        cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput
+        cls, queue_item: SessionQueueItem, invocation: AnyInvocation, result: AnyInvocationOutput
     ) -> "InvocationCompleteEvent":
         return cls(
             queue_id=queue_item.queue_id,
@@ -223,7 +207,7 @@ class InvocationErrorEvent(InvocationEventBase):
     def build(
         cls,
         queue_item: SessionQueueItem,
-        invocation: BaseInvocation,
+        invocation: AnyInvocation,
         error_type: str,
         error_message: str,
         error_traceback: str,
```
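The per-field `@field_validator(..., mode="plain")` hooks can be deleted because `AnyInvocation` and `AnyInvocationOutput` (defined in the graph module below) carry their own plain validators via `__get_pydantic_core_schema__`, so every model that annotates a field with them gets the dynamic validation for free. A toy sketch of that mechanism, with hypothetical `Registry`/`AnyThing`/`SomeEvent` names standing in for the real registry and event classes:

```py
# Toy sketch: a type that supplies its own plain validator removes the need for
# per-model @field_validator hooks on every event that uses it.
from typing import Any

from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import core_schema


class Registry:
    """Stand-in for BaseInvocation.get_typeadapter() in the real code."""

    @staticmethod
    def parse(v: Any) -> dict[str, Any]:
        if not isinstance(v, dict) or "type" not in v:
            raise ValueError("expected a dict with a 'type' key")
        return v


class AnyThing:
    @classmethod
    def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
        # Same shape as AnyInvocation in this PR: delegate validation to a runtime registry.
        return core_schema.no_info_plain_validator_function(Registry.parse)


class SomeEvent(BaseModel):
    invocation: AnyThing  # no @field_validator needed on the event model itself


print(SomeEvent.model_validate({"invocation": {"type": "add"}}))
```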
```diff
@@ -2,18 +2,19 @@

 import copy
 import itertools
-from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
+from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints

 import networkx as nx
 from pydantic import (
     BaseModel,
+    GetCoreSchemaHandler,
     GetJsonSchemaHandler,
     ValidationError,
     field_validator,
 )
 from pydantic.fields import Field
 from pydantic.json_schema import JsonSchemaValue
-from pydantic_core import CoreSchema
+from pydantic_core import core_schema

 # Importing * is bad karma but needed here for node detection
 from invokeai.app.invocations import *  # noqa: F401 F403
@@ -277,73 +278,58 @@ class CollectInvocation(BaseInvocation):
         return CollectInvocationOutput(collection=copy.copy(self.collection))


+class AnyInvocation(BaseInvocation):
+    @classmethod
+    def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
+        def validate_invocation(v: Any) -> "AnyInvocation":
+            return BaseInvocation.get_typeadapter().validate_python(v)
+
+        return core_schema.no_info_plain_validator_function(validate_invocation)
+
+    @classmethod
+    def __get_pydantic_json_schema__(
+        cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
+    ) -> JsonSchemaValue:
+        # Nodes are too powerful, we have to make our own OpenAPI schema manually
+        # No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
+        oneOf: list[dict[str, str]] = []
+        names = [i.__name__ for i in BaseInvocation.get_invocations()]
+        for name in sorted(names):
+            oneOf.append({"$ref": f"#/components/schemas/{name}"})
+        return {"oneOf": oneOf}
+
+
+class AnyInvocationOutput(BaseInvocationOutput):
+    @classmethod
+    def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
+        def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
+            return BaseInvocationOutput.get_typeadapter().validate_python(v)
+
+        return core_schema.no_info_plain_validator_function(validate_invocation_output)
+
+    @classmethod
+    def __get_pydantic_json_schema__(
+        cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
+    ) -> JsonSchemaValue:
+        # Nodes are too powerful, we have to make our own OpenAPI schema manually
+        # No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
+
+        oneOf: list[dict[str, str]] = []
+        names = [i.__name__ for i in BaseInvocationOutput.get_outputs()]
+        for name in sorted(names):
+            oneOf.append({"$ref": f"#/components/schemas/{name}"})
+        return {"oneOf": oneOf}
+
+
 class Graph(BaseModel):
     id: str = Field(description="The id of this graph", default_factory=uuid_string)
     # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
-    nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict)
+    nodes: dict[str, AnyInvocation] = Field(description="The nodes in this graph", default_factory=dict)
     edges: list[Edge] = Field(
         description="The connections between nodes and their fields in this graph",
         default_factory=list,
     )

-    @field_validator("nodes", mode="plain")
-    @classmethod
-    def validate_nodes(cls, v: dict[str, Any]):
-        """Validates the nodes in the graph by retrieving a union of all node types and validating each node."""
-
-        # Invocations register themselves as their python modules are executed. The union of all invocations is
-        # constructed at runtime. We use pydantic to validate `Graph.nodes` using that union.
-        #
-        # It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If
-        # we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing
-        # invocations will cause a graph to fail if they are used.
-        #
-        # We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the
-        # pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime.
-        #
-        # This same pattern is used in `GraphExecutionState`.
-
-        nodes: dict[str, BaseInvocation] = {}
-        typeadapter = BaseInvocation.get_typeadapter()
-        for node_id, node in v.items():
-            nodes[node_id] = typeadapter.validate_python(node)
-        return nodes
-
-    @classmethod
-    def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
-        # We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for
-        # fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to
-        # the generated schema as options for the `nodes` field.
-        #
-        # The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and
-        # with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as
-        # expected.
-        #
-        # You might be tempted to do something like this:
-        #
-        # ```py
-        # cloned_model = create_model(cls.__name__, __base__=cls, nodes=...)
-        # delattr(cloned_model, "validate_nodes")
-        # cloned_model.model_rebuild(force=True)
-        # json_schema = handler(cloned_model.__pydantic_core_schema__)
-        # ```
-        #
-        # Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts
-        # to build the JSON Schema for the cloned model. Instead, we have to manually clone the model.
-        #
-        # This same pattern is used in `GraphExecutionState`.
-
-        class Graph(BaseModel):
-            id: Optional[str] = Field(default=None, description="The id of this graph")
-            nodes: dict[
-                str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")]
-            ] = Field(description="The nodes in this graph")
-            edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph")
-
-        json_schema = handler(Graph.__pydantic_core_schema__)
-        json_schema = handler.resolve_ref_schema(json_schema)
-        return json_schema
-
     def add_node(self, node: BaseInvocation) -> None:
         """Adds a node to a graph

@@ -774,7 +760,7 @@ class GraphExecutionState(BaseModel):
     )

     # The results of executed nodes
-    results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
+    results: dict[str, AnyInvocationOutput] = Field(description="The results of node executions", default_factory=dict)

     # Errors raised when executing nodes
     errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
@@ -791,52 +777,12 @@ class GraphExecutionState(BaseModel):
         default_factory=dict,
     )

-    @field_validator("results", mode="plain")
-    @classmethod
-    def validate_results(cls, v: dict[str, BaseInvocationOutput]):
-        """Validates the results in the GES by retrieving a union of all output types and validating each result."""
-
-        # See the comment in `Graph.validate_nodes` for an explanation of this logic.
-        results: dict[str, BaseInvocationOutput] = {}
-        typeadapter = BaseInvocationOutput.get_typeadapter()
-        for result_id, result in v.items():
-            results[result_id] = typeadapter.validate_python(result)
-        return results
-
     @field_validator("graph")
     def graph_is_valid(cls, v: Graph):
         """Validates that the graph is valid"""
         v.validate_self()
         return v

-    @classmethod
-    def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
-        # See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic.
-
-        class GraphExecutionState(BaseModel):
-            """Tracks the state of a graph execution"""
-
-            id: str = Field(description="The id of the execution state")
-            graph: Graph = Field(description="The graph being executed")
-            execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes")
-            executed: set[str] = Field(description="The set of node ids that have been executed")
-            executed_history: list[str] = Field(
-                description="The list of node ids that have been executed, in order of execution"
-            )
-            results: dict[
-                str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")]
-            ] = Field(description="The results of node executions")
-            errors: dict[str, str] = Field(description="Errors raised when executing nodes")
-            prepared_source_mapping: dict[str, str] = Field(
-                description="The map of prepared nodes to original graph nodes"
-            )
-            source_prepared_mapping: dict[str, set[str]] = Field(
-                description="The map of original graph nodes to prepared nodes"
-            )
-
-        json_schema = handler(GraphExecutionState.__pydantic_core_schema__)
-        json_schema = handler.resolve_ref_schema(json_schema)
-        return json_schema
-
     def next(self) -> Optional[BaseInvocation]:
         """Gets the next node ready to execute."""
```
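The other half of the trick is `__get_pydantic_json_schema__`: pydantic cannot derive a JSON schema from a plain validator, so `AnyInvocation`/`AnyInvocationOutput` emit a `oneOf` over every registered class by hand, which is what lets the old model-cloning workarounds in `Graph` and `GraphExecutionState` be deleted. A toy sketch of that hook (hypothetical `AnyNode`/`Graphish` names and a hard-coded stand-in registry):

```py
# Toy sketch of the hand-built oneOf JSON-schema hook used by AnyInvocation/AnyInvocationOutput.
from typing import Any

from pydantic import BaseModel, GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema

# Stand-in for BaseInvocation.get_invocations(); the real list is built at runtime.
REGISTERED = ["AddInvocation", "ResizeInvocation"]


class AnyNode:
    @classmethod
    def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
        # Plain validator: pydantic can't generate a JSON schema from this on its own.
        return core_schema.no_info_plain_validator_function(lambda v: v)

    @classmethod
    def __get_pydantic_json_schema__(
        cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> JsonSchemaValue:
        # So we build it by hand: one $ref per registered class, as in the diff.
        return {"oneOf": [{"$ref": f"#/components/schemas/{name}"} for name in sorted(REGISTERED)]}


class Graphish(BaseModel):
    nodes: dict[str, AnyNode] = {}


# The nodes field picks up the oneOf schema from the hook above.
print(Graphish.model_json_schema()["properties"]["nodes"])
```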
invokeai/app/util/custom_openapi.py (new file, 116 lines):

```diff
@@ -0,0 +1,116 @@
+from typing import Any, Callable, Optional
+
+from fastapi import FastAPI
+from fastapi.openapi.utils import get_openapi
+from pydantic.json_schema import models_json_schema
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase
+from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
+from invokeai.app.invocations.model import ModelIdentifierField
+from invokeai.app.services.events.events_common import EventBase
+from invokeai.app.services.session_processor.session_processor_common import ProgressImage
+
+
+def move_defs_to_top_level(openapi_schema: dict[str, Any], component_schema: dict[str, Any]) -> None:
+    """Moves a component schema's $defs to the top level of the openapi schema. Useful when generating a schema
+    for a single model that needs to be added back to the top level of the schema. Mutates openapi_schema and
+    component_schema."""
+
+    defs = component_schema.pop("$defs", {})
+    for schema_key, json_schema in defs.items():
+        if schema_key in openapi_schema["components"]["schemas"]:
+            continue
+        openapi_schema["components"]["schemas"][schema_key] = json_schema
+
+
+def get_openapi_func(
+    app: FastAPI, post_transform: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None
+) -> Callable[[], dict[str, Any]]:
+    """Gets the OpenAPI schema generator function.
+
+    Args:
+        app (FastAPI): The FastAPI app to generate the schema for.
+        post_transform (Optional[Callable[[dict[str, Any]], dict[str, Any]]], optional): A function to apply to the
+            generated schema before returning it. Defaults to None.
+
+    Returns:
+        Callable[[], dict[str, Any]]: The OpenAPI schema generator function. When first called, the generated schema is
+            cached in `app.openapi_schema`. On subsequent calls, the cached schema is returned. This caching behaviour
+            matches FastAPI's default schema generation caching.
+    """
+
+    def openapi() -> dict[str, Any]:
+        if app.openapi_schema:
+            return app.openapi_schema
+
+        openapi_schema = get_openapi(
+            title=app.title,
+            description="An API for invoking AI image operations",
+            version="1.0.0",
+            routes=app.routes,
+            separate_input_output_schemas=False,  # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
+        )
+
+        # We'll create a map of invocation type to output schema to make some types simpler on the client.
+        invocation_output_map_properties: dict[str, Any] = {}
+        invocation_output_map_required: list[str] = []
+
+        # We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
+        for output in BaseInvocationOutput.get_outputs():
+            json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
+            move_defs_to_top_level(openapi_schema, json_schema)
+            openapi_schema["components"]["schemas"][output.__name__] = json_schema
+
+        # Technically, invocations are added to the schema by pydantic, but we still need to manually set their output
+        # property, so we'll just do it all manually.
+        for invocation in BaseInvocation.get_invocations():
+            json_schema = invocation.model_json_schema(
+                mode="serialization", ref_template="#/components/schemas/{model}"
+            )
+            move_defs_to_top_level(openapi_schema, json_schema)
+            output_title = invocation.get_output_annotation().__name__
+            outputs_ref = {"$ref": f"#/components/schemas/{output_title}"}
+            json_schema["output"] = outputs_ref
+            openapi_schema["components"]["schemas"][invocation.__name__] = json_schema
+
+            # Add this invocation and its output to the output map
+            invocation_type = invocation.get_type()
+            invocation_output_map_properties[invocation_type] = json_schema["output"]
+            invocation_output_map_required.append(invocation_type)
+
+        # Add the output map to the schema
+        openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
+            "type": "object",
+            "properties": invocation_output_map_properties,
+            "required": invocation_output_map_required,
+        }
+
+        # Some models don't end up in the schemas as standalone definitions because they aren't used directly in the API.
+        # We need to add them manually here. WARNING: Pydantic can choke if you call `model.model_json_schema()` to get
+        # a schema. This has something to do with schema refs - not totally clear. For whatever reason, using
+        # `models_json_schema` seems to work fine.
+        additional_models = [
+            *EventBase.get_events(),
+            UIConfigBase,
+            InputFieldJSONSchemaExtra,
+            OutputFieldJSONSchemaExtra,
+            ModelIdentifierField,
+            ProgressImage,
+        ]
+
+        additional_schemas = models_json_schema(
+            [(m, "serialization") for m in additional_models],
+            ref_template="#/components/schemas/{model}",
+        )
+        # additional_schemas[1] is a dict of $defs that we need to add to the top level of the schema
+        move_defs_to_top_level(openapi_schema, additional_schemas[1])
+
+        if post_transform is not None:
+            openapi_schema = post_transform(openapi_schema)
+
+        openapi_schema["components"]["schemas"] = dict(sorted(openapi_schema["components"]["schemas"].items()))
+
+        app.openapi_schema = openapi_schema
+        return app.openapi_schema
+
+    return openapi
```
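A usage sketch of the new module, mirroring the updated `scripts/generate_openapi_schema.py` later in this diff: build the generator for the app, call it once, and dump the schema as JSON.

```py
# Mirrors the script change in this PR: generate the schema and write it to stdout.
import json
import sys

from invokeai.app.api_app import app
from invokeai.app.util.custom_openapi import get_openapi_func

schema = get_openapi_func(app)()  # first call generates and caches app.openapi_schema
json.dump(schema, sys.stdout, indent=2)
```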
```diff
@@ -13,7 +13,6 @@ import {
   isControlAdapterLayer,
 } from 'features/controlLayers/store/controlLayersSlice';
 import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters';
-import { isImageOutput } from 'features/nodes/types/common';
 import { toast } from 'features/toast/toast';
 import { t } from 'i18next';
 import { isEqual } from 'lodash-es';
@@ -139,7 +138,7 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni

     // We still have to check the output type
     assert(
-      isImageOutput(invocationCompleteAction.payload.data.result),
+      invocationCompleteAction.payload.data.result.type === 'image_output',
       `Processor did not return an image output, got: ${invocationCompleteAction.payload.data.result}`
     );
     const { image_name } = invocationCompleteAction.payload.data.result.image;
```
```diff
@@ -9,7 +9,6 @@ import {
   selectControlAdapterById,
 } from 'features/controlAdapters/store/controlAdaptersSlice';
 import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
-import { isImageOutput } from 'features/nodes/types/common';
 import { toast } from 'features/toast/toast';
 import { t } from 'i18next';
 import { imagesApi } from 'services/api/endpoints/images';
@@ -74,7 +73,7 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL
     );

     // We still have to check the output type
-    if (isImageOutput(invocationCompleteAction.payload.data.result)) {
+    if (invocationCompleteAction.payload.data.result.type === 'image_output') {
       const { image_name } = invocationCompleteAction.payload.data.result.image;

       // Wait for the ImageDTO to be received
```
```diff
@@ -11,7 +11,6 @@ import {
 } from 'features/gallery/store/gallerySlice';
 import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
 import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
-import { isImageOutput } from 'features/nodes/types/common';
 import { zNodeStatus } from 'features/nodes/types/invocation';
 import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
 import { boardsApi } from 'services/api/endpoints/boards';
@@ -33,7 +32,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi

       const { result, invocation_source_id } = data;
       // This complete event has an associated image output
-      if (isImageOutput(data.result) && !nodeTypeDenylist.includes(data.invocation.type)) {
+      if (data.result.type === 'image_output' && !nodeTypeDenylist.includes(data.invocation.type)) {
         const { image_name } = data.result.image;
         const { canvas, gallery } = getState();

```
```diff
@@ -11,8 +11,7 @@ import { selectLastSelectedNode } from 'features/nodes/store/selectors';
 import { isInvocationNode } from 'features/nodes/types/invocation';
 import { memo, useMemo } from 'react';
 import { useTranslation } from 'react-i18next';
-import type { ImageOutput } from 'services/api/types';
-import type { AnyResult } from 'services/events/types';
+import type { AnyInvocationOutput, ImageOutput } from 'services/api/types';

 import ImageOutputPreview from './outputs/ImageOutputPreview';

@@ -66,4 +65,4 @@ const InspectorOutputsTab = () => {

 export default memo(InspectorOutputsTab);

-const getKey = (result: AnyResult, i: number) => `${result.type}-${i}`;
+const getKey = (result: AnyInvocationOutput, i: number) => `${result.type}-${i}`;
```
```diff
@@ -144,5 +144,4 @@ const zImageOutput = z.object({
   type: z.literal('image_output'),
 });
 export type ImageOutput = z.infer<typeof zImageOutput>;
-export const isImageOutput = (output: unknown): output is ImageOutput => zImageOutput.safeParse(output).success;
 // #endregion
```
```diff
@@ -1,8 +1,7 @@
 import type { NodesState } from 'features/nodes/store/types';
 import { isInvocationNode } from 'features/nodes/types/invocation';
 import { omit, reduce } from 'lodash-es';
-import type { Graph } from 'services/api/types';
-import type { AnyInvocation } from 'services/events/types';
+import type { AnyInvocation, Graph } from 'services/api/types';
 import { v4 as uuidv4 } from 'uuid';

 /**
```
File diff suppressed because one or more lines are too long
```diff
@@ -122,7 +122,6 @@ export type ModelInstallStatus = S['InstallStatus'];
 // Graphs
 export type Graph = S['Graph'];
 export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>;
-export type GraphExecutionState = S['GraphExecutionState'];
 export type Batch = S['Batch'];
 export type SessionQueueItemDTO = S['SessionQueueItemDTO'];
 export type WorkflowRecordOrderBy = S['WorkflowRecordOrderBy'];
@@ -132,14 +131,14 @@ export type WorkflowRecordListItemDTO = S['WorkflowRecordListItemDTO'];
 type KeysOfUnion<T> = T extends T ? keyof T : never;

 export type AnyInvocation = Exclude<
-  Graph['nodes'][string],
+  NonNullable<S['Graph']['nodes']>[string],
   S['CoreMetadataInvocation'] | S['MetadataInvocation'] | S['MetadataItemInvocation'] | S['MergeMetadataInvocation']
 >;
-export type AnyInvocationIncMetadata = S['Graph']['nodes'][string];
+export type AnyInvocationIncMetadata = NonNullable<S['Graph']['nodes']>[string];

 export type InvocationType = AnyInvocation['type'];
 type InvocationOutputMap = S['InvocationOutputMap'];
-type AnyInvocationOutput = InvocationOutputMap[InvocationType];
+export type AnyInvocationOutput = InvocationOutputMap[InvocationType];

 export type Invocation<T extends InvocationType> = Extract<AnyInvocation, { type: T }>;
 // export type InvocationOutput<T extends InvocationType> = InvocationOutputMap[T];
```
```diff
@@ -1,21 +1,12 @@
-import type { Graph, GraphExecutionState, S } from 'services/api/types';
+import type { S } from 'services/api/types';

-export type AnyInvocation = NonNullable<NonNullable<Graph['nodes']>[string]>;
-
-export type AnyResult = NonNullable<GraphExecutionState['results'][string]>;
-
 export type ModelLoadStartedEvent = S['ModelLoadStartedEvent'];
 export type ModelLoadCompleteEvent = S['ModelLoadCompleteEvent'];

-export type InvocationStartedEvent = Omit<S['InvocationStartedEvent'], 'invocation'> & { invocation: AnyInvocation };
-export type InvocationDenoiseProgressEvent = Omit<S['InvocationDenoiseProgressEvent'], 'invocation'> & {
-  invocation: AnyInvocation;
-};
-export type InvocationCompleteEvent = Omit<S['InvocationCompleteEvent'], 'result' | 'invocation'> & {
-  result: AnyResult;
-  invocation: AnyInvocation;
-};
-export type InvocationErrorEvent = Omit<S['InvocationErrorEvent'], 'invocation'> & { invocation: AnyInvocation };
+export type InvocationStartedEvent = S['InvocationStartedEvent'];
+export type InvocationDenoiseProgressEvent = S['InvocationDenoiseProgressEvent'];
+export type InvocationCompleteEvent = S['InvocationCompleteEvent'];
+export type InvocationErrorEvent = S['InvocationErrorEvent'];
 export type ProgressImage = InvocationDenoiseProgressEvent['progress_image'];

 export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent'];
```
```diff
@@ -55,10 +55,10 @@ dependencies = [

   # Core application dependencies, pinned for reproducible builds.
   "fastapi-events==0.11.0",
-  "fastapi==0.110.0",
+  "fastapi==0.111.0",
   "huggingface-hub==0.23.1",
   "pydantic-settings==2.2.1",
-  "pydantic==2.6.3",
+  "pydantic==2.7.2",
   "python-socketio==5.11.1",
   "uvicorn[standard]==0.28.0",

```
```diff
@@ -7,9 +7,10 @@ def main():
     # Change working directory to the repo root
     os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

-    from invokeai.app.api_app import custom_openapi
+    from invokeai.app.api_app import app
+    from invokeai.app.util.custom_openapi import get_openapi_func

-    schema = custom_openapi()
+    schema = get_openapi_func(app)()
     json.dump(schema, sys.stdout, indent=2)

```
```diff
@@ -1,5 +1,6 @@
 import pytest
 from pydantic import TypeAdapter
+from pydantic.json_schema import models_json_schema

 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
@@ -713,4 +714,4 @@ def test_iterate_accepts_collection():
 def test_graph_can_generate_schema():
     # Not throwing on this line is sufficient
     # NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
-    _ = Graph.model_json_schema()
+    models_json_schema([(Graph, "serialization")])
```