fix(app): add dynamic validator to AnyInvocation & AnyInvocationOutput

This fixes the tests and slightly changes output types.
This commit is contained in:
psychedelicious
2024-05-29 21:05:42 +10:00
parent 50d3030471
commit ac56ab79a7
2 changed files with 18 additions and 7 deletions

View File

@ -14,7 +14,7 @@ from pydantic import (
)
from pydantic.fields import Field
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import CoreSchema
from pydantic_core import core_schema
# Importing * is bad karma but needed here for node detection
from invokeai.app.invocations import * # noqa: F401 F403
@ -280,11 +280,16 @@ class CollectInvocation(BaseInvocation):
class AnyInvocation(BaseInvocation):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
return BaseInvocation.get_typeadapter().core_schema
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
def validate_invocation(v: Any) -> "AnyInvocation":
return BaseInvocation.get_typeadapter().validate_python(v)
return core_schema.no_info_plain_validator_function(validate_invocation)
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
def __get_pydantic_json_schema__(
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
# Nodes are too powerful, we have to make our own OpenAPI schema manually
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = []
@ -297,10 +302,15 @@ class AnyInvocation(BaseInvocation):
class AnyInvocationOutput(BaseInvocationOutput):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
return BaseInvocationOutput.get_typeadapter().core_schema
def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
return BaseInvocationOutput.get_typeadapter().validate_python(v)
return core_schema.no_info_plain_validator_function(validate_invocation_output)
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
def __get_pydantic_json_schema__(
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
# Nodes are too powerful, we have to make our own OpenAPI schema manually
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually