diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 7330cd73be..65ea4c3edb 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -28,12 +28,28 @@ class ImageOutput(BaseInvocationOutput): image: ImageField = Field(default=None, description="The output image") #fmt: on + class Config: + schema_extra = { + 'required': [ + 'type', + 'image', + ] + } + class MaskOutput(BaseInvocationOutput): """Base class for invocations that output a mask""" #fmt: off type: Literal["mask"] = "mask" mask: ImageField = Field(default=None, description="The output mask") - #fomt: on + #fmt: on + + class Config: + schema_extra = { + 'required': [ + 'type', + 'mask', + ] + } # TODO: this isn't really necessary anymore class LoadImageInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py index 3544f30859..0c7e3069df 100644 --- a/invokeai/app/invocations/prompt.py +++ b/invokeai/app/invocations/prompt.py @@ -12,3 +12,11 @@ class PromptOutput(BaseInvocationOutput): prompt: str = Field(default=None, description="The output prompt") #fmt: on + + class Config: + schema_extra = { + 'required': [ + 'type', + 'prompt', + ] + } diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py index 0d4102c416..171d86c9e3 100644 --- a/invokeai/app/services/graph.py +++ b/invokeai/app/services/graph.py @@ -127,6 +127,13 @@ class NodeAlreadyExecutedError(Exception): class GraphInvocationOutput(BaseInvocationOutput): type: Literal["graph_output"] = "graph_output" + class Config: + schema_extra = { + 'required': [ + 'type', + 'image', + ] + } # TODO: Fill this out and move to invocations class GraphInvocation(BaseInvocation): @@ -147,6 +154,13 @@ class IterateInvocationOutput(BaseInvocationOutput): item: Any = Field(description="The item being iterated over") + class Config: + schema_extra = { + 'required': [ + 'type', + 'item', + ] + } # TODO: Fill this out and move to invocations class IterateInvocation(BaseInvocation): @@ -169,6 +183,13 @@ class CollectInvocationOutput(BaseInvocationOutput): collection: list[Any] = Field(description="The collection of input items") + class Config: + schema_extra = { + 'required': [ + 'type', + 'collection', + ] + } class CollectInvocation(BaseInvocation): """Collects values into a collection"""