From ac56ab79a7f6d2a9ca3e7c76e37f123d11c90b7d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 29 May 2024 21:05:42 +1000 Subject: [PATCH] fix(app): add dynamic validator to AnyInvocation & AnyInvocationOutput This fixes the tests and slightly changes output types. --- invokeai/app/services/shared/graph.py | 22 ++++++++++++++++------ tests/test_node_graph.py | 3 ++- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 1a60a8cc0e..d745e73823 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -14,7 +14,7 @@ from pydantic import ( ) from pydantic.fields import Field from pydantic.json_schema import JsonSchemaValue -from pydantic_core import CoreSchema +from pydantic_core import core_schema # Importing * is bad karma but needed here for node detection from invokeai.app.invocations import * # noqa: F401 F403 @@ -280,11 +280,16 @@ class CollectInvocation(BaseInvocation): class AnyInvocation(BaseInvocation): @classmethod - def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler): - return BaseInvocation.get_typeadapter().core_schema + def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + def validate_invocation(v: Any) -> "AnyInvocation": + return BaseInvocation.get_typeadapter().validate_python(v) + + return core_schema.no_info_plain_validator_function(validate_invocation) @classmethod - def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: + def __get_pydantic_json_schema__( + cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: # Nodes are too powerful, we have to make our own OpenAPI schema manually # No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually oneOf: list[dict[str, str]] = [] @@ -297,10 +302,15 @@ class AnyInvocation(BaseInvocation): class AnyInvocationOutput(BaseInvocationOutput): @classmethod def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler): - return BaseInvocationOutput.get_typeadapter().core_schema + def validate_invocation_output(v: Any) -> "AnyInvocationOutput": + return BaseInvocationOutput.get_typeadapter().validate_python(v) + + return core_schema.no_info_plain_validator_function(validate_invocation_output) @classmethod - def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: + def __get_pydantic_json_schema__( + cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: # Nodes are too powerful, we have to make our own OpenAPI schema manually # No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually diff --git a/tests/test_node_graph.py b/tests/test_node_graph.py index 87a4948af4..861f1bd07b 100644 --- a/tests/test_node_graph.py +++ b/tests/test_node_graph.py @@ -1,5 +1,6 @@ import pytest from pydantic import TypeAdapter +from pydantic.json_schema import models_json_schema from invokeai.app.invocations.baseinvocation import ( BaseInvocation, @@ -713,4 +714,4 @@ def test_iterate_accepts_collection(): def test_graph_can_generate_schema(): # Not throwing on this line is sufficient # NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation - _ = Graph.model_json_schema() + models_json_schema([(Graph, "serialization")])