fix(nodes): fix schema generation for output classes

All output classes need to have their properties flagged as `required` for the schema generation to work as needed.
This commit is contained in:
psychedelicious 2023-03-26 13:26:59 +11:00
parent c34ac91ff0
commit 4221cf7731
3 changed files with 46 additions and 1 deletions

View File

@ -28,12 +28,28 @@ class ImageOutput(BaseInvocationOutput):
image: ImageField = Field(default=None, description="The output image") image: ImageField = Field(default=None, description="The output image")
#fmt: on #fmt: on
class Config:
schema_extra = {
'required': [
'type',
'image',
]
}
class MaskOutput(BaseInvocationOutput): class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask""" """Base class for invocations that output a mask"""
#fmt: off #fmt: off
type: Literal["mask"] = "mask" type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask") mask: ImageField = Field(default=None, description="The output mask")
#fomt: on #fmt: on
class Config:
schema_extra = {
'required': [
'type',
'mask',
]
}
# TODO: this isn't really necessary anymore # TODO: this isn't really necessary anymore
class LoadImageInvocation(BaseInvocation): class LoadImageInvocation(BaseInvocation):

View File

@ -12,3 +12,11 @@ class PromptOutput(BaseInvocationOutput):
prompt: str = Field(default=None, description="The output prompt") prompt: str = Field(default=None, description="The output prompt")
#fmt: on #fmt: on
class Config:
schema_extra = {
'required': [
'type',
'prompt',
]
}

View File

@ -127,6 +127,13 @@ class NodeAlreadyExecutedError(Exception):
class GraphInvocationOutput(BaseInvocationOutput): class GraphInvocationOutput(BaseInvocationOutput):
type: Literal["graph_output"] = "graph_output" type: Literal["graph_output"] = "graph_output"
class Config:
schema_extra = {
'required': [
'type',
'image',
]
}
# TODO: Fill this out and move to invocations # TODO: Fill this out and move to invocations
class GraphInvocation(BaseInvocation): class GraphInvocation(BaseInvocation):
@ -147,6 +154,13 @@ class IterateInvocationOutput(BaseInvocationOutput):
item: Any = Field(description="The item being iterated over") item: Any = Field(description="The item being iterated over")
class Config:
schema_extra = {
'required': [
'type',
'item',
]
}
# TODO: Fill this out and move to invocations # TODO: Fill this out and move to invocations
class IterateInvocation(BaseInvocation): class IterateInvocation(BaseInvocation):
@ -169,6 +183,13 @@ class CollectInvocationOutput(BaseInvocationOutput):
collection: list[Any] = Field(description="The collection of input items") collection: list[Any] = Field(description="The collection of input items")
class Config:
schema_extra = {
'required': [
'type',
'collection',
]
}
class CollectInvocation(BaseInvocation): class CollectInvocation(BaseInvocation):
"""Collects values into a collection""" """Collects values into a collection"""