mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(app): add dynamic validator to AnyInvocation & AnyInvocationOutput
This fixes the tests and slightly changes output types.
This commit is contained in:
parent
50d3030471
commit
ac56ab79a7
@ -14,7 +14,7 @@ from pydantic import (
|
||||
)
|
||||
from pydantic.fields import Field
|
||||
from pydantic.json_schema import JsonSchemaValue
|
||||
from pydantic_core import CoreSchema
|
||||
from pydantic_core import core_schema
|
||||
|
||||
# Importing * is bad karma but needed here for node detection
|
||||
from invokeai.app.invocations import * # noqa: F401 F403
|
||||
@ -280,11 +280,16 @@ class CollectInvocation(BaseInvocation):
|
||||
|
||||
class AnyInvocation(BaseInvocation):
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
|
||||
return BaseInvocation.get_typeadapter().core_schema
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
def validate_invocation(v: Any) -> "AnyInvocation":
|
||||
return BaseInvocation.get_typeadapter().validate_python(v)
|
||||
|
||||
return core_schema.no_info_plain_validator_function(validate_invocation)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
def __get_pydantic_json_schema__(
|
||||
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
|
||||
) -> JsonSchemaValue:
|
||||
# Nodes are too powerful, we have to make our own OpenAPI schema manually
|
||||
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
|
||||
oneOf: list[dict[str, str]] = []
|
||||
@ -297,10 +302,15 @@ class AnyInvocation(BaseInvocation):
|
||||
class AnyInvocationOutput(BaseInvocationOutput):
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
|
||||
return BaseInvocationOutput.get_typeadapter().core_schema
|
||||
def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
|
||||
return BaseInvocationOutput.get_typeadapter().validate_python(v)
|
||||
|
||||
return core_schema.no_info_plain_validator_function(validate_invocation_output)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
def __get_pydantic_json_schema__(
|
||||
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
|
||||
) -> JsonSchemaValue:
|
||||
# Nodes are too powerful, we have to make our own OpenAPI schema manually
|
||||
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
import pytest
|
||||
from pydantic import TypeAdapter
|
||||
from pydantic.json_schema import models_json_schema
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
@ -713,4 +714,4 @@ def test_iterate_accepts_collection():
|
||||
def test_graph_can_generate_schema():
|
||||
# Not throwing on this line is sufficient
|
||||
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
|
||||
_ = Graph.model_json_schema()
|
||||
models_json_schema([(Graph, "serialization")])
|
||||
|
Loading…
Reference in New Issue
Block a user