mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(app): openapi schema generation
Some tech debt related to dynamic pydantic schemas for invocations became problematic. Including the invocations and results in the event schemas was breaking pydantic's handling of ref schemas. I don't really understand why - I think it's a pydantic bug in a remote edge case that we are hitting. After many failed attempts I landed on this implementation, which is actually much tidier than what was in there before. - Create pydantic-enabled types for `AnyInvocation` and `AnyInvocationOutput` and use these in place of the janky dynamic unions. Actually, they are kinda the same, but better encapsulated. Use these in `Graph`, `GraphExecutionState`, `InvocationEventBase` and `InvocationCompleteEvent`. - Revise the custom openapi function to work with the new models. - Split out the custom openapi function to a separate file. Add a `post_transform` callback so consumers can customize the output schema. - Update makefile scripts.
This commit is contained in:
@ -113,10 +113,10 @@ class BaseInvocationOutput(BaseModel):
|
||||
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
|
||||
if not cls._typeadapter:
|
||||
InvocationOutputsUnion = TypeAliasType(
|
||||
"InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
|
||||
AnyInvocationOutput = TypeAliasType(
|
||||
"AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
|
||||
cls._typeadapter = TypeAdapter(AnyInvocationOutput)
|
||||
return cls._typeadapter
|
||||
|
||||
@classmethod
|
||||
@ -125,12 +125,13 @@ class BaseInvocationOutput(BaseModel):
|
||||
return (i.get_type() for i in BaseInvocationOutput.get_outputs())
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
|
||||
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
|
||||
# Because we use a pydantic Literal field with default value for the invocation type,
|
||||
# it will be typed as optional in the OpenAPI schema. Make it required manually.
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = []
|
||||
schema["class"] = "output"
|
||||
schema["required"].extend(["type"])
|
||||
|
||||
@classmethod
|
||||
@ -182,10 +183,10 @@ class BaseInvocation(ABC, BaseModel):
|
||||
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
|
||||
if not cls._typeadapter:
|
||||
InvocationsUnion = TypeAliasType(
|
||||
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
|
||||
AnyInvocation = TypeAliasType(
|
||||
"AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(InvocationsUnion)
|
||||
cls._typeadapter = TypeAdapter(AnyInvocation)
|
||||
return cls._typeadapter
|
||||
|
||||
@classmethod
|
||||
@ -221,7 +222,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
return signature(cls.invoke).return_annotation
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
|
||||
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
|
||||
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
|
||||
if uiconfig is not None:
|
||||
@ -237,6 +238,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
schema["version"] = uiconfig.version
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = []
|
||||
schema["class"] = "invocation"
|
||||
schema["required"].extend(["type", "id"])
|
||||
|
||||
@abstractmethod
|
||||
@ -310,7 +312,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
protected_namespaces=(),
|
||||
validate_assignment=True,
|
||||
json_schema_extra=json_schema_extra,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
json_schema_serialization_defaults_required=False,
|
||||
coerce_numbers_to_str=True,
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user