mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
1062fc4796
Initial support for polymorphic field types. Polymorphic types are either a single instance of a specific type or a list of that type. For example, `Union[str, list[str]]`. Polymorphics do not yet have support for direct input in the UI (this will come in the future). They will be forcibly set as Connection-only fields, in which case users will not be able to provide direct input to the field. If a polymorphic should present as a singleton type - which would allow direct input - the node must provide an explicit type hint. For example, `DenoiseLatents`' `CFG Scale` is polymorphic, but in the node editor, we want to present this as a number input. In the node definition, the field is given `ui_type=UIType.Float`, which tells the UI to treat this as a `float` field. The connection validation logic will prevent connecting a collection to `CFG Scale` in this situation, because it is typed as `float`. The workaround is to disable validation from the settings to make this specific connection. A future improvement will resolve this. This also introduces better support for collection field types. Like polymorphics, collection types are parsed automatically by the client and do not need any specific type hints. Also like polymorphics, there is no support yet for direct input of collection types in the UI. - Disabling validation in the workflow editor now displays the visual hints for valid connections, but lets you connect to anything. - Added `ui_order: int` to `InputField` and `OutputField`. The UI will use this, if present, to order fields in a node UI. See usage in `DenoiseLatents` for an example. - Updated the field colors - duplicate colors have just been lightened a bit. It's not perfect, but it was a quick fix. - Field handles for collections are the same color as their single counterparts, but have a dark dot in the center of them. - Field handles for polymorphics are a rounded square with a dot in the middle.
- Removed all fields that just render `null` from `InputFieldRenderer`, replacing them with a single fallback. - Removed logic in `zValidatedWorkflow`, which checked for the existence of node templates for each node in a workflow. This logic introduced a circular dependency, due to importing the global redux `store` in order to get the node templates within a zod schema. It's actually fine to just leave this out entirely; the case of a missing node template is handled by the UI. Fixing it otherwise would introduce a substantial headache. - Fixed the `ControlNetInvocation.control_model` field default, which was set to a string even though the field shouldn't have a default at all.
462 lines
14 KiB
Python
462 lines
14 KiB
Python
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
|
|
|
from typing import Optional, Tuple
|
|
|
|
import torch
|
|
from pydantic import BaseModel, Field
|
|
|
|
from .baseinvocation import (
|
|
BaseInvocation,
|
|
BaseInvocationOutput,
|
|
FieldDescriptions,
|
|
Input,
|
|
InputField,
|
|
InvocationContext,
|
|
OutputField,
|
|
UIComponent,
|
|
invocation,
|
|
invocation_output,
|
|
)
|
|
|
|
"""
|
|
Primitives: Boolean, Integer, Float, String, Image, Latents, Conditioning, Color
|
|
- primitive nodes
|
|
- primitive outputs
|
|
- primitive collection outputs
|
|
"""
|
|
|
|
# region Boolean
|
|
|
|
|
|
@invocation_output("boolean_output")
|
|
class BooleanOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a single boolean"""
|
|
|
|
value: bool = OutputField(description="The output boolean")
|
|
|
|
|
|
@invocation_output("boolean_collection_output")
|
|
class BooleanCollectionOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a collection of booleans"""
|
|
|
|
collection: list[bool] = OutputField(
|
|
description="The output boolean collection",
|
|
)
|
|
|
|
|
|
@invocation("boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives")
|
|
class BooleanInvocation(BaseInvocation):
|
|
"""A boolean primitive value"""
|
|
|
|
value: bool = InputField(default=False, description="The boolean value")
|
|
|
|
def invoke(self, context: InvocationContext) -> BooleanOutput:
|
|
return BooleanOutput(value=self.value)
|
|
|
|
|
|
@invocation(
    "boolean_collection",
    title="Boolean Collection Primitive",
    tags=["primitives", "boolean", "collection"],
    category="primitives",
)
class BooleanCollectionInvocation(BaseInvocation):
    """A collection of boolean primitive values"""

    # default_factory gives every node instance its own empty list.
    collection: list[bool] = InputField(
        description="The collection of boolean values",
        default_factory=list,
    )

    def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
        # The configured list is forwarded unchanged to the output.
        values = self.collection
        return BooleanCollectionOutput(collection=values)
|
|
|
|
|
|
# endregion
|
|
|
|
# region Integer
|
|
|
|
|
|
@invocation_output("integer_output")
|
|
class IntegerOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a single integer"""
|
|
|
|
value: int = OutputField(description="The output integer")
|
|
|
|
|
|
@invocation_output("integer_collection_output")
|
|
class IntegerCollectionOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a collection of integers"""
|
|
|
|
collection: list[int] = OutputField(
|
|
description="The int collection",
|
|
)
|
|
|
|
|
|
@invocation("integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives")
|
|
class IntegerInvocation(BaseInvocation):
|
|
"""An integer primitive value"""
|
|
|
|
value: int = InputField(default=0, description="The integer value")
|
|
|
|
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
|
return IntegerOutput(value=self.value)
|
|
|
|
|
|
@invocation(
    "integer_collection",
    title="Integer Collection Primitive",
    tags=["primitives", "integer", "collection"],
    category="primitives",
)
class IntegerCollectionInvocation(BaseInvocation):
    """A collection of integer primitive values"""

    # default_factory gives every node instance its own empty list.
    collection: list[int] = InputField(
        description="The collection of integer values",
        default_factory=list,
    )

    def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
        # The configured list is forwarded unchanged to the output.
        values = self.collection
        return IntegerCollectionOutput(collection=values)
|
|
|
|
|
|
# endregion
|
|
|
|
# region Float
|
|
|
|
|
|
@invocation_output("float_output")
|
|
class FloatOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a single float"""
|
|
|
|
value: float = OutputField(description="The output float")
|
|
|
|
|
|
@invocation_output("float_collection_output")
|
|
class FloatCollectionOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a collection of floats"""
|
|
|
|
collection: list[float] = OutputField(
|
|
description="The float collection",
|
|
)
|
|
|
|
|
|
@invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives")
|
|
class FloatInvocation(BaseInvocation):
|
|
"""A float primitive value"""
|
|
|
|
value: float = InputField(default=0.0, description="The float value")
|
|
|
|
def invoke(self, context: InvocationContext) -> FloatOutput:
|
|
return FloatOutput(value=self.value)
|
|
|
|
|
|
@invocation(
    "float_collection",
    title="Float Collection Primitive",
    tags=["primitives", "float", "collection"],
    category="primitives",
)
class FloatCollectionInvocation(BaseInvocation):
    """A collection of float primitive values"""

    # default_factory gives every node instance its own empty list.
    collection: list[float] = InputField(
        description="The collection of float values",
        default_factory=list,
    )

    def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
        # The configured list is forwarded unchanged to the output.
        values = self.collection
        return FloatCollectionOutput(collection=values)
|
|
|
|
|
|
# endregion
|
|
|
|
# region String
|
|
|
|
|
|
@invocation_output("string_output")
|
|
class StringOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a single string"""
|
|
|
|
value: str = OutputField(description="The output string")
|
|
|
|
|
|
@invocation_output("string_collection_output")
|
|
class StringCollectionOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a collection of strings"""
|
|
|
|
collection: list[str] = OutputField(
|
|
description="The output strings",
|
|
)
|
|
|
|
|
|
@invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives")
|
|
class StringInvocation(BaseInvocation):
|
|
"""A string primitive value"""
|
|
|
|
value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)
|
|
|
|
def invoke(self, context: InvocationContext) -> StringOutput:
|
|
return StringOutput(value=self.value)
|
|
|
|
|
|
@invocation(
    "string_collection",
    title="String Collection Primitive",
    tags=["primitives", "string", "collection"],
    category="primitives",
)
class StringCollectionInvocation(BaseInvocation):
    """A collection of string primitive values"""

    # default_factory gives every node instance its own empty list.
    collection: list[str] = InputField(
        description="The collection of string values",
        default_factory=list,
    )

    def invoke(self, context: InvocationContext) -> StringCollectionOutput:
        # The configured list is forwarded unchanged to the output.
        values = self.collection
        return StringCollectionOutput(collection=values)
|
|
|
|
|
|
# endregion
|
|
|
|
# region Image
|
|
|
|
|
|
class ImageField(BaseModel):
    """An image primitive field"""

    # Only the image's name is stored; pixel data is looked up by name
    # (see ImageInvocation.invoke).
    image_name: str = Field(
        description="The name of the image",
    )
|
|
|
|
|
|
@invocation_output("image_output")
|
|
class ImageOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a single image"""
|
|
|
|
image: ImageField = OutputField(description="The output image")
|
|
width: int = OutputField(description="The width of the image in pixels")
|
|
height: int = OutputField(description="The height of the image in pixels")
|
|
|
|
|
|
@invocation_output("image_collection_output")
|
|
class ImageCollectionOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a collection of images"""
|
|
|
|
collection: list[ImageField] = OutputField(
|
|
description="The output images",
|
|
)
|
|
|
|
|
|
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives")
|
|
class ImageInvocation(BaseInvocation):
|
|
"""An image primitive value"""
|
|
|
|
image: ImageField = InputField(description="The image to load")
|
|
|
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
image = context.services.images.get_pil_image(self.image.image_name)
|
|
|
|
return ImageOutput(
|
|
image=ImageField(image_name=self.image.image_name),
|
|
width=image.width,
|
|
height=image.height,
|
|
)
|
|
|
|
|
|
@invocation(
    "image_collection",
    title="Image Collection Primitive",
    tags=["primitives", "image", "collection"],
    category="primitives",
)
class ImageCollectionInvocation(BaseInvocation):
    """A collection of image primitive values"""

    # default_factory gives every node instance its own empty list.
    collection: list[ImageField] = InputField(
        description="The collection of image values",
        default_factory=list,
    )

    def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
        # The configured list of image references is forwarded unchanged.
        images = self.collection
        return ImageCollectionOutput(collection=images)
|
|
|
|
|
|
# endregion
|
|
|
|
# region DenoiseMask
|
|
|
|
|
|
class DenoiseMaskField(BaseModel):
    """An inpaint mask field"""

    # Name of the stored mask image.
    mask_name: str = Field(description="The name of the mask image")
    # Optional. The default is explicit so the field stays optional regardless of
    # pydantic's implicit-None handling for Optional annotations, matching
    # LatentsField.seed.
    masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents")
|
|
|
|
|
|
@invocation_output("denoise_mask_output")
|
|
class DenoiseMaskOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a single image"""
|
|
|
|
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
|
|
|
|
|
|
# endregion
|
|
|
|
# region Latents
|
|
|
|
|
|
class LatentsField(BaseModel):
    """A latents tensor primitive field"""

    # Name under which the tensor is fetched via context.services.latents.
    latents_name: str = Field(
        description="The name of the latents",
    )
    # Optional; None when no seed was recorded.
    seed: Optional[int] = Field(
        description="Seed used to generate this latents",
        default=None,
    )
|
|
|
|
|
|
@invocation_output("latents_output")
|
|
class LatentsOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a single latents tensor"""
|
|
|
|
latents: LatentsField = OutputField(
|
|
description=FieldDescriptions.latents,
|
|
)
|
|
width: int = OutputField(description=FieldDescriptions.width)
|
|
height: int = OutputField(description=FieldDescriptions.height)
|
|
|
|
|
|
@invocation_output("latents_collection_output")
|
|
class LatentsCollectionOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a collection of latents tensors"""
|
|
|
|
collection: list[LatentsField] = OutputField(
|
|
description=FieldDescriptions.latents,
|
|
)
|
|
|
|
|
|
@invocation("latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives")
|
|
class LatentsInvocation(BaseInvocation):
|
|
"""A latents tensor primitive value"""
|
|
|
|
latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
|
|
|
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
|
latents = context.services.latents.get(self.latents.latents_name)
|
|
|
|
return build_latents_output(self.latents.latents_name, latents)
|
|
|
|
|
|
@invocation(
    "latents_collection",
    title="Latents Collection Primitive",
    tags=["primitives", "latents", "collection"],
    category="primitives",
)
class LatentsCollectionInvocation(BaseInvocation):
    """A collection of latents tensor primitive values"""

    # Unlike most collection primitives in this file, no default is given here.
    # NOTE(review): presumably this makes the field required — confirm that is intended.
    collection: list[LatentsField] = InputField(description="The collection of latents tensors")

    def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
        # The configured list of latents references is forwarded unchanged.
        items = self.collection
        return LatentsCollectionOutput(collection=items)
|
|
|
|
|
|
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> LatentsOutput:
    """Build a ``LatentsOutput`` that references stored latents by name.

    Args:
        latents_name: Name under which the latents tensor is stored.
        latents: The tensor itself; only its size is read, to fill width/height.
        seed: Optional seed recorded on the emitted ``LatentsField``.

    Returns:
        A ``LatentsOutput`` whose width/height are the tensor's dims 3 and 2
        multiplied by 8.
    """
    # NOTE(review): assumes `latents` is laid out as (batch, channels, height, width);
    # dims 2/3 are read as height/width. The factor of 8 is presumably the
    # latent-to-pixel downscale factor — confirm.
    shape = latents.size()
    height, width = shape[2], shape[3]
    return LatentsOutput(
        latents=LatentsField(latents_name=latents_name, seed=seed),
        width=width * 8,
        height=height * 8,
    )
|
|
|
|
|
|
# endregion
|
|
|
|
# region Color
|
|
|
|
|
|
class ColorField(BaseModel):
    """A color primitive field"""

    # Every channel is validated to lie in 0..255 inclusive.
    r: int = Field(description="The red component", ge=0, le=255)
    g: int = Field(description="The green component", ge=0, le=255)
    b: int = Field(description="The blue component", ge=0, le=255)
    a: int = Field(description="The alpha component", ge=0, le=255)

    def tuple(self) -> Tuple[int, int, int, int]:
        """Return the channels as an ``(r, g, b, a)`` tuple."""
        return self.r, self.g, self.b, self.a
|
|
|
|
|
|
@invocation_output("color_output")
|
|
class ColorOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a single color"""
|
|
|
|
color: ColorField = OutputField(description="The output color")
|
|
|
|
|
|
@invocation_output("color_collection_output")
|
|
class ColorCollectionOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a collection of colors"""
|
|
|
|
collection: list[ColorField] = OutputField(
|
|
description="The output colors",
|
|
)
|
|
|
|
|
|
@invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives")
|
|
class ColorInvocation(BaseInvocation):
|
|
"""A color primitive value"""
|
|
|
|
color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value")
|
|
|
|
def invoke(self, context: InvocationContext) -> ColorOutput:
|
|
return ColorOutput(color=self.color)
|
|
|
|
|
|
# endregion
|
|
|
|
# region Conditioning
|
|
|
|
|
|
class ConditioningField(BaseModel):
    """A conditioning tensor primitive value"""

    # Only the tensor's name is stored here, not the tensor itself.
    conditioning_name: str = Field(
        description="The name of conditioning tensor",
    )
|
|
|
|
|
|
@invocation_output("conditioning_output")
|
|
class ConditioningOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a single conditioning tensor"""
|
|
|
|
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
|
|
|
|
|
|
@invocation_output("conditioning_collection_output")
|
|
class ConditioningCollectionOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a collection of conditioning tensors"""
|
|
|
|
collection: list[ConditioningField] = OutputField(
|
|
description="The output conditioning tensors",
|
|
)
|
|
|
|
|
|
@invocation("conditioning", title="Conditioning Primitive", tags=["primitives", "conditioning"], category="primitives")
class ConditioningInvocation(BaseInvocation):
    """A conditioning tensor primitive value"""

    # Connection-only input: conditioning cannot be entered directly in the UI.
    conditioning: ConditioningField = InputField(
        input=Input.Connection,
        description=FieldDescriptions.cond,
    )

    def invoke(self, context: InvocationContext) -> ConditioningOutput:
        # Pass-through node: `context` is unused and the reference is re-emitted.
        cond = self.conditioning
        return ConditioningOutput(conditioning=cond)
|
|
|
|
|
|
@invocation(
    "conditioning_collection",
    title="Conditioning Collection Primitive",
    tags=["primitives", "conditioning", "collection"],
    category="primitives",
)
class ConditioningCollectionInvocation(BaseInvocation):
    """A collection of conditioning tensor primitive values"""

    # default_factory gives every node instance its own empty list.
    collection: list[ConditioningField] = InputField(description="The collection of conditioning tensors", default_factory=list)

    def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:
        # The configured list of conditioning references is forwarded unchanged.
        items = self.collection
        return ConditioningCollectionOutput(collection=items)
|
|
|
|
|
|
# endregion
|