mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(app): add dynamic validator to AnyInvocation & AnyInvocationOutput
This fixes the tests and slightly changes output types.
This commit is contained in:
parent
50d3030471
commit
ac56ab79a7
@ -14,7 +14,7 @@ from pydantic import (
|
|||||||
)
|
)
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
from pydantic.json_schema import JsonSchemaValue
|
from pydantic.json_schema import JsonSchemaValue
|
||||||
from pydantic_core import CoreSchema
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
# Importing * is bad karma but needed here for node detection
|
# Importing * is bad karma but needed here for node detection
|
||||||
from invokeai.app.invocations import * # noqa: F401 F403
|
from invokeai.app.invocations import * # noqa: F401 F403
|
||||||
@ -280,11 +280,16 @@ class CollectInvocation(BaseInvocation):
|
|||||||
|
|
||||||
class AnyInvocation(BaseInvocation):
|
class AnyInvocation(BaseInvocation):
|
||||||
@classmethod
|
@classmethod
|
||||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
|
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||||
return BaseInvocation.get_typeadapter().core_schema
|
def validate_invocation(v: Any) -> "AnyInvocation":
|
||||||
|
return BaseInvocation.get_typeadapter().validate_python(v)
|
||||||
|
|
||||||
|
return core_schema.no_info_plain_validator_function(validate_invocation)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
def __get_pydantic_json_schema__(
|
||||||
|
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
|
||||||
|
) -> JsonSchemaValue:
|
||||||
# Nodes are too powerful, we have to make our own OpenAPI schema manually
|
# Nodes are too powerful, we have to make our own OpenAPI schema manually
|
||||||
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
|
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
|
||||||
oneOf: list[dict[str, str]] = []
|
oneOf: list[dict[str, str]] = []
|
||||||
@ -297,10 +302,15 @@ class AnyInvocation(BaseInvocation):
|
|||||||
class AnyInvocationOutput(BaseInvocationOutput):
|
class AnyInvocationOutput(BaseInvocationOutput):
|
||||||
@classmethod
|
@classmethod
|
||||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
|
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
|
||||||
return BaseInvocationOutput.get_typeadapter().core_schema
|
def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
|
||||||
|
return BaseInvocationOutput.get_typeadapter().validate_python(v)
|
||||||
|
|
||||||
|
return core_schema.no_info_plain_validator_function(validate_invocation_output)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
def __get_pydantic_json_schema__(
|
||||||
|
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
|
||||||
|
) -> JsonSchemaValue:
|
||||||
# Nodes are too powerful, we have to make our own OpenAPI schema manually
|
# Nodes are too powerful, we have to make our own OpenAPI schema manually
|
||||||
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
|
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
|
from pydantic.json_schema import models_json_schema
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
@ -713,4 +714,4 @@ def test_iterate_accepts_collection():
|
|||||||
def test_graph_can_generate_schema():
|
def test_graph_can_generate_schema():
|
||||||
# Not throwing on this line is sufficient
|
# Not throwing on this line is sufficient
|
||||||
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
|
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
|
||||||
_ = Graph.model_json_schema()
|
models_json_schema([(Graph, "serialization")])
|
||||||
|
Loading…
Reference in New Issue
Block a user