fix(app): add dynamic validator to AnyInvocation & AnyInvocationOutput

This fixes the tests and slightly changes output types.
This commit is contained in:
psychedelicious 2024-05-29 21:05:42 +10:00
parent 50d3030471
commit ac56ab79a7
2 changed files with 18 additions and 7 deletions

View File

@ -14,7 +14,7 @@ from pydantic import (
) )
from pydantic.fields import Field from pydantic.fields import Field
from pydantic.json_schema import JsonSchemaValue from pydantic.json_schema import JsonSchemaValue
from pydantic_core import CoreSchema from pydantic_core import core_schema
# Importing * is bad karma but needed here for node detection # Importing * is bad karma but needed here for node detection
from invokeai.app.invocations import * # noqa: F401 F403 from invokeai.app.invocations import * # noqa: F401 F403
@ -280,11 +280,16 @@ class CollectInvocation(BaseInvocation):
class AnyInvocation(BaseInvocation): class AnyInvocation(BaseInvocation):
@classmethod @classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler): def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
return BaseInvocation.get_typeadapter().core_schema def validate_invocation(v: Any) -> "AnyInvocation":
return BaseInvocation.get_typeadapter().validate_python(v)
return core_schema.no_info_plain_validator_function(validate_invocation)
@classmethod @classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: def __get_pydantic_json_schema__(
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
# Nodes are too powerful, we have to make our own OpenAPI schema manually # Nodes are too powerful, we have to make our own OpenAPI schema manually
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually # No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = [] oneOf: list[dict[str, str]] = []
@ -297,10 +302,15 @@ class AnyInvocation(BaseInvocation):
class AnyInvocationOutput(BaseInvocationOutput): class AnyInvocationOutput(BaseInvocationOutput):
@classmethod @classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler): def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
return BaseInvocationOutput.get_typeadapter().core_schema def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
return BaseInvocationOutput.get_typeadapter().validate_python(v)
return core_schema.no_info_plain_validator_function(validate_invocation_output)
@classmethod @classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: def __get_pydantic_json_schema__(
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
# Nodes are too powerful, we have to make our own OpenAPI schema manually # Nodes are too powerful, we have to make our own OpenAPI schema manually
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually # No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually

View File

@ -1,5 +1,6 @@
import pytest import pytest
from pydantic import TypeAdapter from pydantic import TypeAdapter
from pydantic.json_schema import models_json_schema
from invokeai.app.invocations.baseinvocation import ( from invokeai.app.invocations.baseinvocation import (
BaseInvocation, BaseInvocation,
@ -713,4 +714,4 @@ def test_iterate_accepts_collection():
def test_graph_can_generate_schema(): def test_graph_can_generate_schema():
# Not throwing on this line is sufficient # Not throwing on this line is sufficient
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation # NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
_ = Graph.model_json_schema() models_json_schema([(Graph, "serialization")])