mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into lama-infill
This commit is contained in:
commit
db4af7c287
@ -244,8 +244,12 @@ copy-paste the template above.
|
||||
We can use the `@invocation` decorator to provide some additional info to the
|
||||
UI, like a custom title, tags and category.
|
||||
|
||||
We also encourage providing a version. This must be a
|
||||
[semver](https://semver.org/) version string ("$MAJOR.$MINOR.$PATCH"). The UI
|
||||
will let users know if their workflow is using a mismatched version of the node.
|
||||
|
||||
```python
|
||||
@invocation("resize", title="My Resizer", tags=["resize", "image"], category="My Invocations")
|
||||
@invocation("resize", title="My Resizer", tags=["resize", "image"], category="My Invocations", version="1.0.0")
|
||||
class ResizeInvocation(BaseInvocation):
|
||||
"""Resizes an image"""
|
||||
|
||||
@ -279,8 +283,6 @@ take a look a at our [contributing nodes overview](contributingNodes).
|
||||
|
||||
## Advanced
|
||||
|
||||
-->
|
||||
|
||||
### Custom Output Types
|
||||
|
||||
Like with custom inputs, sometimes you might find yourself needing custom
|
||||
|
@ -26,11 +26,16 @@ from typing import (
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic.fields import Undefined, ModelField
|
||||
from pydantic.typing import NoArgAnyCallable
|
||||
import semver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
|
||||
class InvalidVersionError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class FieldDescriptions:
|
||||
denoising_start = "When to start denoising, expressed a percentage of total steps"
|
||||
denoising_end = "When to stop denoising, expressed a percentage of total steps"
|
||||
@ -105,24 +110,39 @@ class UIType(str, Enum):
|
||||
"""
|
||||
|
||||
# region Primitives
|
||||
Integer = "integer"
|
||||
Float = "float"
|
||||
Boolean = "boolean"
|
||||
String = "string"
|
||||
Array = "array"
|
||||
Image = "ImageField"
|
||||
Latents = "LatentsField"
|
||||
Color = "ColorField"
|
||||
Conditioning = "ConditioningField"
|
||||
Control = "ControlField"
|
||||
Color = "ColorField"
|
||||
ImageCollection = "ImageCollection"
|
||||
ConditioningCollection = "ConditioningCollection"
|
||||
ColorCollection = "ColorCollection"
|
||||
LatentsCollection = "LatentsCollection"
|
||||
IntegerCollection = "IntegerCollection"
|
||||
FloatCollection = "FloatCollection"
|
||||
StringCollection = "StringCollection"
|
||||
Float = "float"
|
||||
Image = "ImageField"
|
||||
Integer = "integer"
|
||||
Latents = "LatentsField"
|
||||
String = "string"
|
||||
# endregion
|
||||
|
||||
# region Collection Primitives
|
||||
BooleanCollection = "BooleanCollection"
|
||||
ColorCollection = "ColorCollection"
|
||||
ConditioningCollection = "ConditioningCollection"
|
||||
ControlCollection = "ControlCollection"
|
||||
FloatCollection = "FloatCollection"
|
||||
ImageCollection = "ImageCollection"
|
||||
IntegerCollection = "IntegerCollection"
|
||||
LatentsCollection = "LatentsCollection"
|
||||
StringCollection = "StringCollection"
|
||||
# endregion
|
||||
|
||||
# region Polymorphic Primitives
|
||||
BooleanPolymorphic = "BooleanPolymorphic"
|
||||
ColorPolymorphic = "ColorPolymorphic"
|
||||
ConditioningPolymorphic = "ConditioningPolymorphic"
|
||||
ControlPolymorphic = "ControlPolymorphic"
|
||||
FloatPolymorphic = "FloatPolymorphic"
|
||||
ImagePolymorphic = "ImagePolymorphic"
|
||||
IntegerPolymorphic = "IntegerPolymorphic"
|
||||
LatentsPolymorphic = "LatentsPolymorphic"
|
||||
StringPolymorphic = "StringPolymorphic"
|
||||
# endregion
|
||||
|
||||
# region Models
|
||||
@ -176,6 +196,7 @@ class _InputField(BaseModel):
|
||||
ui_type: Optional[UIType]
|
||||
ui_component: Optional[UIComponent]
|
||||
ui_order: Optional[int]
|
||||
item_default: Optional[Any]
|
||||
|
||||
|
||||
class _OutputField(BaseModel):
|
||||
@ -223,6 +244,7 @@ def InputField(
|
||||
ui_component: Optional[UIComponent] = None,
|
||||
ui_hidden: bool = False,
|
||||
ui_order: Optional[int] = None,
|
||||
item_default: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""
|
||||
@ -249,6 +271,11 @@ def InputField(
|
||||
For this case, you could provide `UIComponent.Textarea`.
|
||||
|
||||
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
|
||||
|
||||
: param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
||||
|
||||
: param bool item_default: [None] Specifies the default item value, if this is a collection input. \
|
||||
Ignored for non-collection fields..
|
||||
"""
|
||||
return Field(
|
||||
*args,
|
||||
@ -282,6 +309,7 @@ def InputField(
|
||||
ui_component=ui_component,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
item_default=item_default,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -332,6 +360,8 @@ def OutputField(
|
||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||
|
||||
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
|
||||
|
||||
: param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
||||
"""
|
||||
return Field(
|
||||
*args,
|
||||
@ -376,6 +406,9 @@ class UIConfigBase(BaseModel):
|
||||
tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
|
||||
title: Optional[str] = Field(default=None, description="The node's display name")
|
||||
category: Optional[str] = Field(default=None, description="The node's category")
|
||||
version: Optional[str] = Field(
|
||||
default=None, description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".'
|
||||
)
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
@ -474,6 +507,8 @@ class BaseInvocation(ABC, BaseModel):
|
||||
schema["tags"] = uiconfig.tags
|
||||
if uiconfig and hasattr(uiconfig, "category"):
|
||||
schema["category"] = uiconfig.category
|
||||
if uiconfig and hasattr(uiconfig, "version"):
|
||||
schema["version"] = uiconfig.version
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = list()
|
||||
schema["required"].extend(["type", "id"])
|
||||
@ -542,7 +577,11 @@ GenericBaseInvocation = TypeVar("GenericBaseInvocation", bound=BaseInvocation)
|
||||
|
||||
|
||||
def invocation(
|
||||
invocation_type: str, title: Optional[str] = None, tags: Optional[list[str]] = None, category: Optional[str] = None
|
||||
invocation_type: str,
|
||||
title: Optional[str] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
category: Optional[str] = None,
|
||||
version: Optional[str] = None,
|
||||
) -> Callable[[Type[GenericBaseInvocation]], Type[GenericBaseInvocation]]:
|
||||
"""
|
||||
Adds metadata to an invocation.
|
||||
@ -569,6 +608,12 @@ def invocation(
|
||||
cls.UIConfig.tags = tags
|
||||
if category is not None:
|
||||
cls.UIConfig.category = category
|
||||
if version is not None:
|
||||
try:
|
||||
semver.Version.parse(version)
|
||||
except ValueError as e:
|
||||
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
|
||||
cls.UIConfig.version = version
|
||||
|
||||
# Add the invocation type to the pydantic model of the invocation
|
||||
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
||||
|
@ -10,7 +10,9 @@ from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||
|
||||
|
||||
@invocation("range", title="Integer Range", tags=["collection", "integer", "range"], category="collections")
|
||||
@invocation(
|
||||
"range", title="Integer Range", tags=["collection", "integer", "range"], category="collections", version="1.0.0"
|
||||
)
|
||||
class RangeInvocation(BaseInvocation):
|
||||
"""Creates a range of numbers from start to stop with step"""
|
||||
|
||||
@ -33,6 +35,7 @@ class RangeInvocation(BaseInvocation):
|
||||
title="Integer Range of Size",
|
||||
tags=["collection", "integer", "size", "range"],
|
||||
category="collections",
|
||||
version="1.0.0",
|
||||
)
|
||||
class RangeOfSizeInvocation(BaseInvocation):
|
||||
"""Creates a range from start to start + size with step"""
|
||||
@ -50,6 +53,7 @@ class RangeOfSizeInvocation(BaseInvocation):
|
||||
title="Random Range",
|
||||
tags=["range", "integer", "random", "collection"],
|
||||
category="collections",
|
||||
version="1.0.0",
|
||||
)
|
||||
class RandomRangeInvocation(BaseInvocation):
|
||||
"""Creates a collection of random numbers"""
|
||||
|
@ -44,7 +44,7 @@ class ConditioningFieldData:
|
||||
# PerpNeg = "perp_neg"
|
||||
|
||||
|
||||
@invocation("compel", title="Prompt", tags=["prompt", "compel"], category="conditioning")
|
||||
@invocation("compel", title="Prompt", tags=["prompt", "compel"], category="conditioning", version="1.0.0")
|
||||
class CompelInvocation(BaseInvocation):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
@ -267,6 +267,7 @@ class SDXLPromptInvocationBase:
|
||||
title="SDXL Prompt",
|
||||
tags=["sdxl", "compel", "prompt"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@ -351,6 +352,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
title="SDXL Refiner Prompt",
|
||||
tags=["sdxl", "compel", "prompt"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@ -403,7 +405,7 @@ class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation("clip_skip", title="CLIP Skip", tags=["clipskip", "clip", "skip"], category="conditioning")
|
||||
@invocation("clip_skip", title="CLIP Skip", tags=["clipskip", "clip", "skip"], category="conditioning", version="1.0.0")
|
||||
class ClipSkipInvocation(BaseInvocation):
|
||||
"""Skip layers in clip text_encoder model."""
|
||||
|
||||
|
@ -95,14 +95,12 @@ class ControlOutput(BaseInvocationOutput):
|
||||
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||
|
||||
|
||||
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet")
|
||||
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.0.0")
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ControlNetModelField = InputField(
|
||||
default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
|
||||
)
|
||||
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
|
||||
)
|
||||
@ -129,7 +127,9 @@ class ControlNetInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("image_processor", title="Base Image Processor", tags=["controlnet"], category="controlnet")
|
||||
@invocation(
|
||||
"image_processor", title="Base Image Processor", tags=["controlnet"], category="controlnet", version="1.0.0"
|
||||
)
|
||||
class ImageProcessorInvocation(BaseInvocation):
|
||||
"""Base class for invocations that preprocess images for ControlNet"""
|
||||
|
||||
@ -173,6 +173,7 @@ class ImageProcessorInvocation(BaseInvocation):
|
||||
title="Canny Processor",
|
||||
tags=["controlnet", "canny"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Canny edge detection for ControlNet"""
|
||||
@ -195,6 +196,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="HED (softedge) Processor",
|
||||
tags=["controlnet", "hed", "softedge"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies HED edge detection to image"""
|
||||
@ -223,6 +225,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Lineart Processor",
|
||||
tags=["controlnet", "lineart"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art processing to image"""
|
||||
@ -244,6 +247,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Lineart Anime Processor",
|
||||
tags=["controlnet", "lineart", "anime"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art anime processing to image"""
|
||||
@ -266,6 +270,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Openpose Processor",
|
||||
tags=["controlnet", "openpose", "pose"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Openpose processing to image"""
|
||||
@ -290,6 +295,7 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Midas Depth Processor",
|
||||
tags=["controlnet", "midas"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Midas depth processing to image"""
|
||||
@ -316,6 +322,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Normal BAE Processor",
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies NormalBae processing to image"""
|
||||
@ -331,7 +338,9 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation("mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet")
|
||||
@invocation(
|
||||
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.0.0"
|
||||
)
|
||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies MLSD processing to image"""
|
||||
|
||||
@ -352,7 +361,9 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation("pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet")
|
||||
@invocation(
|
||||
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.0.0"
|
||||
)
|
||||
class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies PIDI processing to image"""
|
||||
|
||||
@ -378,6 +389,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Content Shuffle Processor",
|
||||
tags=["controlnet", "contentshuffle"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies content shuffle processing to image"""
|
||||
@ -407,6 +419,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Zoe (Depth) Processor",
|
||||
tags=["controlnet", "zoe", "depth"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Zoe depth processing to image"""
|
||||
@ -422,6 +435,7 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Mediapipe Face Processor",
|
||||
tags=["controlnet", "mediapipe", "face"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies mediapipe face processing to image"""
|
||||
@ -444,6 +458,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Leres (Depth) Processor",
|
||||
tags=["controlnet", "leres", "depth"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies leres processing to image"""
|
||||
@ -472,6 +487,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Tile Resample Processor",
|
||||
tags=["controlnet", "tile"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Tile resampler processor"""
|
||||
@ -511,6 +527,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Segment Anything Processor",
|
||||
tags=["controlnet", "segmentanything"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies segment anything processing to image"""
|
||||
|
@ -10,12 +10,7 @@ from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||
|
||||
|
||||
@invocation(
|
||||
"cv_inpaint",
|
||||
title="OpenCV Inpaint",
|
||||
tags=["opencv", "inpaint"],
|
||||
category="inpaint",
|
||||
)
|
||||
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.0.0")
|
||||
class CvInpaintInvocation(BaseInvocation):
|
||||
"""Simple inpaint using opencv."""
|
||||
|
||||
|
@ -16,7 +16,7 @@ from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation
|
||||
|
||||
|
||||
@invocation("show_image", title="Show Image", tags=["image"], category="image")
|
||||
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0")
|
||||
class ShowImageInvocation(BaseInvocation):
|
||||
"""Displays a provided image using the OS image viewer, and passes it forward in the pipeline."""
|
||||
|
||||
@ -36,7 +36,7 @@ class ShowImageInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("blank_image", title="Blank Image", tags=["image"], category="image")
|
||||
@invocation("blank_image", title="Blank Image", tags=["image"], category="image", version="1.0.0")
|
||||
class BlankImageInvocation(BaseInvocation):
|
||||
"""Creates a blank image and forwards it to the pipeline"""
|
||||
|
||||
@ -65,7 +65,7 @@ class BlankImageInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image")
|
||||
@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image", version="1.0.0")
|
||||
class ImageCropInvocation(BaseInvocation):
|
||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||
|
||||
@ -98,7 +98,7 @@ class ImageCropInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image")
|
||||
@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image", version="1.0.0")
|
||||
class ImagePasteInvocation(BaseInvocation):
|
||||
"""Pastes an image into another image."""
|
||||
|
||||
@ -146,7 +146,7 @@ class ImagePasteInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image")
|
||||
@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image", version="1.0.0")
|
||||
class MaskFromAlphaInvocation(BaseInvocation):
|
||||
"""Extracts the alpha channel of an image as a mask."""
|
||||
|
||||
@ -177,7 +177,7 @@ class MaskFromAlphaInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image")
|
||||
@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image", version="1.0.0")
|
||||
class ImageMultiplyInvocation(BaseInvocation):
|
||||
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
||||
|
||||
@ -210,7 +210,7 @@ class ImageMultiplyInvocation(BaseInvocation):
|
||||
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
||||
|
||||
|
||||
@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image")
|
||||
@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image", version="1.0.0")
|
||||
class ImageChannelInvocation(BaseInvocation):
|
||||
"""Gets a channel from an image."""
|
||||
|
||||
@ -242,7 +242,7 @@ class ImageChannelInvocation(BaseInvocation):
|
||||
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
||||
|
||||
|
||||
@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image")
|
||||
@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image", version="1.0.0")
|
||||
class ImageConvertInvocation(BaseInvocation):
|
||||
"""Converts an image to a different mode."""
|
||||
|
||||
@ -271,7 +271,7 @@ class ImageConvertInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image")
|
||||
@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image", version="1.0.0")
|
||||
class ImageBlurInvocation(BaseInvocation):
|
||||
"""Blurs an image"""
|
||||
|
||||
@ -325,7 +325,7 @@ PIL_RESAMPLING_MAP = {
|
||||
}
|
||||
|
||||
|
||||
@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image")
|
||||
@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image", version="1.0.0")
|
||||
class ImageResizeInvocation(BaseInvocation):
|
||||
"""Resizes an image to specific dimensions"""
|
||||
|
||||
@ -365,7 +365,7 @@ class ImageResizeInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image")
|
||||
@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image", version="1.0.0")
|
||||
class ImageScaleInvocation(BaseInvocation):
|
||||
"""Scales an image by a factor"""
|
||||
|
||||
@ -406,7 +406,7 @@ class ImageScaleInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image")
|
||||
@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image", version="1.0.0")
|
||||
class ImageLerpInvocation(BaseInvocation):
|
||||
"""Linear interpolation of all pixels of an image"""
|
||||
|
||||
@ -439,7 +439,7 @@ class ImageLerpInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image")
|
||||
@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", version="1.0.0")
|
||||
class ImageInverseLerpInvocation(BaseInvocation):
|
||||
"""Inverse linear interpolation of all pixels of an image"""
|
||||
|
||||
@ -472,7 +472,7 @@ class ImageInverseLerpInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image")
|
||||
@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image", version="1.0.0")
|
||||
class ImageNSFWBlurInvocation(BaseInvocation):
|
||||
"""Add blur to NSFW-flagged images"""
|
||||
|
||||
@ -517,7 +517,9 @@ class ImageNSFWBlurInvocation(BaseInvocation):
|
||||
return caution.resize((caution.width // 2, caution.height // 2))
|
||||
|
||||
|
||||
@invocation("img_watermark", title="Add Invisible Watermark", tags=["image", "watermark"], category="image")
|
||||
@invocation(
|
||||
"img_watermark", title="Add Invisible Watermark", tags=["image", "watermark"], category="image", version="1.0.0"
|
||||
)
|
||||
class ImageWatermarkInvocation(BaseInvocation):
|
||||
"""Add an invisible watermark to an image"""
|
||||
|
||||
@ -548,7 +550,7 @@ class ImageWatermarkInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image")
|
||||
@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", version="1.0.0")
|
||||
class MaskEdgeInvocation(BaseInvocation):
|
||||
"""Applies an edge mask to an image"""
|
||||
|
||||
@ -593,7 +595,9 @@ class MaskEdgeInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("mask_combine", title="Combine Masks", tags=["image", "mask", "multiply"], category="image")
|
||||
@invocation(
|
||||
"mask_combine", title="Combine Masks", tags=["image", "mask", "multiply"], category="image", version="1.0.0"
|
||||
)
|
||||
class MaskCombineInvocation(BaseInvocation):
|
||||
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
|
||||
|
||||
@ -623,7 +627,7 @@ class MaskCombineInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image")
|
||||
@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image", version="1.0.0")
|
||||
class ColorCorrectInvocation(BaseInvocation):
|
||||
"""
|
||||
Shifts the colors of a target image to match the reference image, optionally
|
||||
@ -733,7 +737,7 @@ class ColorCorrectInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image")
|
||||
@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image", version="1.0.0")
|
||||
class ImageHueAdjustmentInvocation(BaseInvocation):
|
||||
"""Adjusts the Hue of an image."""
|
||||
|
||||
@ -779,6 +783,7 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
|
||||
title="Adjust Image Luminosity",
|
||||
tags=["image", "luminosity", "hsl"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageLuminosityAdjustmentInvocation(BaseInvocation):
|
||||
"""Adjusts the Luminosity (Value) of an image."""
|
||||
@ -831,6 +836,7 @@ class ImageLuminosityAdjustmentInvocation(BaseInvocation):
|
||||
title="Adjust Image Saturation",
|
||||
tags=["image", "saturation", "hsl"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageSaturationAdjustmentInvocation(BaseInvocation):
|
||||
"""Adjusts the Saturation of an image."""
|
||||
|
@ -118,7 +118,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
|
||||
return si
|
||||
|
||||
|
||||
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint")
|
||||
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
|
||||
class InfillColorInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image with a solid color"""
|
||||
|
||||
@ -153,7 +153,7 @@ class InfillColorInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint")
|
||||
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
|
||||
class InfillTileInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image with tiles of the image"""
|
||||
|
||||
@ -189,7 +189,9 @@ class InfillTileInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint")
|
||||
@invocation(
|
||||
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0"
|
||||
)
|
||||
class InfillPatchMatchInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||
|
||||
@ -242,7 +244,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint")
|
||||
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
|
||||
class LaMaInfillInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image using the LaMa model"""
|
||||
|
||||
|
@ -74,7 +74,7 @@ class SchedulerOutput(BaseInvocationOutput):
|
||||
scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
|
||||
|
||||
|
||||
@invocation("scheduler", title="Scheduler", tags=["scheduler"], category="latents")
|
||||
@invocation("scheduler", title="Scheduler", tags=["scheduler"], category="latents", version="1.0.0")
|
||||
class SchedulerInvocation(BaseInvocation):
|
||||
"""Selects a scheduler."""
|
||||
|
||||
@ -86,7 +86,9 @@ class SchedulerInvocation(BaseInvocation):
|
||||
return SchedulerOutput(scheduler=self.scheduler)
|
||||
|
||||
|
||||
@invocation("create_denoise_mask", title="Create Denoise Mask", tags=["mask", "denoise"], category="latents")
|
||||
@invocation(
|
||||
"create_denoise_mask", title="Create Denoise Mask", tags=["mask", "denoise"], category="latents", version="1.0.0"
|
||||
)
|
||||
class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
"""Creates mask for denoising model run."""
|
||||
|
||||
@ -186,6 +188,7 @@ def get_scheduler(
|
||||
title="Denoise Latents",
|
||||
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class DenoiseLatentsInvocation(BaseInvocation):
|
||||
"""Denoises noisy latents to decodable images"""
|
||||
@ -208,12 +211,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2)
|
||||
control: Union[ControlField, list[ControlField]] = InputField(
|
||||
default=None, description=FieldDescriptions.control, input=Input.Connection, ui_order=5
|
||||
default=None,
|
||||
description=FieldDescriptions.control,
|
||||
input=Input.Connection,
|
||||
ui_order=5,
|
||||
)
|
||||
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.mask,
|
||||
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=6
|
||||
)
|
||||
|
||||
@validator("cfg_scale")
|
||||
@ -317,7 +322,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
context: InvocationContext,
|
||||
# really only need model for dtype and device
|
||||
model: StableDiffusionGeneratorPipeline,
|
||||
control_input: List[ControlField],
|
||||
control_input: Union[ControlField, List[ControlField]],
|
||||
latents_shape: List[int],
|
||||
exit_stack: ExitStack,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
@ -542,7 +547,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
|
||||
|
||||
|
||||
@invocation("l2i", title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents")
|
||||
@invocation(
|
||||
"l2i", title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents", version="1.0.0"
|
||||
)
|
||||
class LatentsToImageInvocation(BaseInvocation):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
@ -639,7 +646,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
||||
|
||||
|
||||
@invocation("lresize", title="Resize Latents", tags=["latents", "resize"], category="latents")
|
||||
@invocation("lresize", title="Resize Latents", tags=["latents", "resize"], category="latents", version="1.0.0")
|
||||
class ResizeLatentsInvocation(BaseInvocation):
|
||||
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
||||
|
||||
@ -683,7 +690,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
|
||||
|
||||
@invocation("lscale", title="Scale Latents", tags=["latents", "resize"], category="latents")
|
||||
@invocation("lscale", title="Scale Latents", tags=["latents", "resize"], category="latents", version="1.0.0")
|
||||
class ScaleLatentsInvocation(BaseInvocation):
|
||||
"""Scales latents by a given factor."""
|
||||
|
||||
@ -719,7 +726,9 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
|
||||
|
||||
@invocation("i2l", title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents")
|
||||
@invocation(
|
||||
"i2l", title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents", version="1.0.0"
|
||||
)
|
||||
class ImageToLatentsInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
|
||||
@ -799,7 +808,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
return build_latents_output(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
|
||||
@invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents")
|
||||
@invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents", version="1.0.0")
|
||||
class BlendLatentsInvocation(BaseInvocation):
|
||||
"""Blend two latents using a given alpha. Latents must have same size."""
|
||||
|
||||
|
@ -7,7 +7,7 @@ from invokeai.app.invocations.primitives import IntegerOutput
|
||||
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation
|
||||
|
||||
|
||||
@invocation("add", title="Add Integers", tags=["math", "add"], category="math")
|
||||
@invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.0")
|
||||
class AddInvocation(BaseInvocation):
|
||||
"""Adds two numbers"""
|
||||
|
||||
@ -18,7 +18,7 @@ class AddInvocation(BaseInvocation):
|
||||
return IntegerOutput(value=self.a + self.b)
|
||||
|
||||
|
||||
@invocation("sub", title="Subtract Integers", tags=["math", "subtract"], category="math")
|
||||
@invocation("sub", title="Subtract Integers", tags=["math", "subtract"], category="math", version="1.0.0")
|
||||
class SubtractInvocation(BaseInvocation):
|
||||
"""Subtracts two numbers"""
|
||||
|
||||
@ -29,7 +29,7 @@ class SubtractInvocation(BaseInvocation):
|
||||
return IntegerOutput(value=self.a - self.b)
|
||||
|
||||
|
||||
@invocation("mul", title="Multiply Integers", tags=["math", "multiply"], category="math")
|
||||
@invocation("mul", title="Multiply Integers", tags=["math", "multiply"], category="math", version="1.0.0")
|
||||
class MultiplyInvocation(BaseInvocation):
|
||||
"""Multiplies two numbers"""
|
||||
|
||||
@ -40,7 +40,7 @@ class MultiplyInvocation(BaseInvocation):
|
||||
return IntegerOutput(value=self.a * self.b)
|
||||
|
||||
|
||||
@invocation("div", title="Divide Integers", tags=["math", "divide"], category="math")
|
||||
@invocation("div", title="Divide Integers", tags=["math", "divide"], category="math", version="1.0.0")
|
||||
class DivideInvocation(BaseInvocation):
|
||||
"""Divides two numbers"""
|
||||
|
||||
@ -51,7 +51,7 @@ class DivideInvocation(BaseInvocation):
|
||||
return IntegerOutput(value=int(self.a / self.b))
|
||||
|
||||
|
||||
@invocation("rand_int", title="Random Integer", tags=["math", "random"], category="math")
|
||||
@invocation("rand_int", title="Random Integer", tags=["math", "random"], category="math", version="1.0.0")
|
||||
class RandomIntInvocation(BaseInvocation):
|
||||
"""Outputs a single random integer."""
|
||||
|
||||
|
@ -98,7 +98,9 @@ class MetadataAccumulatorOutput(BaseInvocationOutput):
|
||||
metadata: CoreMetadata = OutputField(description="The core metadata for the image")
|
||||
|
||||
|
||||
@invocation("metadata_accumulator", title="Metadata Accumulator", tags=["metadata"], category="metadata")
|
||||
@invocation(
|
||||
"metadata_accumulator", title="Metadata Accumulator", tags=["metadata"], category="metadata", version="1.0.0"
|
||||
)
|
||||
class MetadataAccumulatorInvocation(BaseInvocation):
|
||||
"""Outputs a Core Metadata Object"""
|
||||
|
||||
|
@ -73,7 +73,7 @@ class LoRAModelField(BaseModel):
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
|
||||
@invocation("main_model_loader", title="Main Model", tags=["model"], category="model")
|
||||
@invocation("main_model_loader", title="Main Model", tags=["model"], category="model", version="1.0.0")
|
||||
class MainModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
@ -173,7 +173,7 @@ class LoraLoaderOutput(BaseInvocationOutput):
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation("lora_loader", title="LoRA", tags=["model"], category="model")
|
||||
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.0")
|
||||
class LoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
@ -244,7 +244,7 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
||||
|
||||
|
||||
@invocation("sdxl_lora_loader", title="SDXL LoRA", tags=["lora", "model"], category="model")
|
||||
@invocation("sdxl_lora_loader", title="SDXL LoRA", tags=["lora", "model"], category="model", version="1.0.0")
|
||||
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
@ -338,7 +338,7 @@ class VaeLoaderOutput(BaseInvocationOutput):
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model")
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
|
||||
class VaeLoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
|
||||
@ -376,7 +376,7 @@ class SeamlessModeOutput(BaseInvocationOutput):
|
||||
vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation("seamless", title="Seamless", tags=["seamless", "model"], category="model")
|
||||
@invocation("seamless", title="Seamless", tags=["seamless", "model"], category="model", version="1.0.0")
|
||||
class SeamlessModeInvocation(BaseInvocation):
|
||||
"""Applies the seamless transformation to the Model UNet and VAE."""
|
||||
|
||||
|
@ -78,7 +78,7 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
|
||||
)
|
||||
|
||||
|
||||
@invocation("noise", title="Noise", tags=["latents", "noise"], category="latents")
|
||||
@invocation("noise", title="Noise", tags=["latents", "noise"], category="latents", version="1.0.0")
|
||||
class NoiseInvocation(BaseInvocation):
|
||||
"""Generates latent noise."""
|
||||
|
||||
|
@ -56,7 +56,7 @@ ORT_TO_NP_TYPE = {
|
||||
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
|
||||
|
||||
|
||||
@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning")
|
||||
@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0")
|
||||
class ONNXPromptInvocation(BaseInvocation):
|
||||
prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
@ -143,6 +143,7 @@ class ONNXPromptInvocation(BaseInvocation):
|
||||
title="ONNX Text to Latents",
|
||||
tags=["latents", "inference", "txt2img", "onnx"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
"""Generates latents from conditionings."""
|
||||
@ -319,6 +320,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
title="ONNX Latents to Image",
|
||||
tags=["latents", "image", "vae", "onnx"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ONNXLatentsToImageInvocation(BaseInvocation):
|
||||
"""Generates an image from latents."""
|
||||
@ -403,7 +405,7 @@ class OnnxModelField(BaseModel):
|
||||
model_type: ModelType = Field(description="Model Type")
|
||||
|
||||
|
||||
@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model")
|
||||
@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0")
|
||||
class OnnxModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
|
@ -45,7 +45,7 @@ from invokeai.app.invocations.primitives import FloatCollectionOutput
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||
|
||||
|
||||
@invocation("float_range", title="Float Range", tags=["math", "range"], category="math")
|
||||
@invocation("float_range", title="Float Range", tags=["math", "range"], category="math", version="1.0.0")
|
||||
class FloatLinearRangeInvocation(BaseInvocation):
|
||||
"""Creates a range"""
|
||||
|
||||
@ -96,7 +96,7 @@ EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
|
||||
|
||||
|
||||
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||
@invocation("step_param_easing", title="Step Param Easing", tags=["step", "easing"], category="step")
|
||||
@invocation("step_param_easing", title="Step Param Easing", tags=["step", "easing"], category="step", version="1.0.0")
|
||||
class StepParamEasingInvocation(BaseInvocation):
|
||||
"""Experimental per-step parameter easing for denoising steps"""
|
||||
|
||||
|
@ -14,7 +14,6 @@ from .baseinvocation import (
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
UIType,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -40,10 +39,14 @@ class BooleanOutput(BaseInvocationOutput):
|
||||
class BooleanCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of booleans"""
|
||||
|
||||
collection: list[bool] = OutputField(description="The output boolean collection", ui_type=UIType.BooleanCollection)
|
||||
collection: list[bool] = OutputField(
|
||||
description="The output boolean collection",
|
||||
)
|
||||
|
||||
|
||||
@invocation("boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives")
|
||||
@invocation(
|
||||
"boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives", version="1.0.0"
|
||||
)
|
||||
class BooleanInvocation(BaseInvocation):
|
||||
"""A boolean primitive value"""
|
||||
|
||||
@ -58,13 +61,12 @@ class BooleanInvocation(BaseInvocation):
|
||||
title="Boolean Collection Primitive",
|
||||
tags=["primitives", "boolean", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
)
|
||||
class BooleanCollectionInvocation(BaseInvocation):
|
||||
"""A collection of boolean primitive values"""
|
||||
|
||||
collection: list[bool] = InputField(
|
||||
default_factory=list, description="The collection of boolean values", ui_type=UIType.BooleanCollection
|
||||
)
|
||||
collection: list[bool] = InputField(default_factory=list, description="The collection of boolean values")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
|
||||
return BooleanCollectionOutput(collection=self.collection)
|
||||
@ -86,10 +88,14 @@ class IntegerOutput(BaseInvocationOutput):
|
||||
class IntegerCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of integers"""
|
||||
|
||||
collection: list[int] = OutputField(description="The int collection", ui_type=UIType.IntegerCollection)
|
||||
collection: list[int] = OutputField(
|
||||
description="The int collection",
|
||||
)
|
||||
|
||||
|
||||
@invocation("integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives")
|
||||
@invocation(
|
||||
"integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives", version="1.0.0"
|
||||
)
|
||||
class IntegerInvocation(BaseInvocation):
|
||||
"""An integer primitive value"""
|
||||
|
||||
@ -104,13 +110,12 @@ class IntegerInvocation(BaseInvocation):
|
||||
title="Integer Collection Primitive",
|
||||
tags=["primitives", "integer", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
)
|
||||
class IntegerCollectionInvocation(BaseInvocation):
|
||||
"""A collection of integer primitive values"""
|
||||
|
||||
collection: list[int] = InputField(
|
||||
default_factory=list, description="The collection of integer values", ui_type=UIType.IntegerCollection
|
||||
)
|
||||
collection: list[int] = InputField(default_factory=list, description="The collection of integer values")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
||||
return IntegerCollectionOutput(collection=self.collection)
|
||||
@ -132,10 +137,12 @@ class FloatOutput(BaseInvocationOutput):
|
||||
class FloatCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of floats"""
|
||||
|
||||
collection: list[float] = OutputField(description="The float collection", ui_type=UIType.FloatCollection)
|
||||
collection: list[float] = OutputField(
|
||||
description="The float collection",
|
||||
)
|
||||
|
||||
|
||||
@invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives")
|
||||
@invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives", version="1.0.0")
|
||||
class FloatInvocation(BaseInvocation):
|
||||
"""A float primitive value"""
|
||||
|
||||
@ -150,13 +157,12 @@ class FloatInvocation(BaseInvocation):
|
||||
title="Float Collection Primitive",
|
||||
tags=["primitives", "float", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FloatCollectionInvocation(BaseInvocation):
|
||||
"""A collection of float primitive values"""
|
||||
|
||||
collection: list[float] = InputField(
|
||||
default_factory=list, description="The collection of float values", ui_type=UIType.FloatCollection
|
||||
)
|
||||
collection: list[float] = InputField(default_factory=list, description="The collection of float values")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
return FloatCollectionOutput(collection=self.collection)
|
||||
@ -178,10 +184,12 @@ class StringOutput(BaseInvocationOutput):
|
||||
class StringCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of strings"""
|
||||
|
||||
collection: list[str] = OutputField(description="The output strings", ui_type=UIType.StringCollection)
|
||||
collection: list[str] = OutputField(
|
||||
description="The output strings",
|
||||
)
|
||||
|
||||
|
||||
@invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives")
|
||||
@invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives", version="1.0.0")
|
||||
class StringInvocation(BaseInvocation):
|
||||
"""A string primitive value"""
|
||||
|
||||
@ -196,13 +204,12 @@ class StringInvocation(BaseInvocation):
|
||||
title="String Collection Primitive",
|
||||
tags=["primitives", "string", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
)
|
||||
class StringCollectionInvocation(BaseInvocation):
|
||||
"""A collection of string primitive values"""
|
||||
|
||||
collection: list[str] = InputField(
|
||||
default_factory=list, description="The collection of string values", ui_type=UIType.StringCollection
|
||||
)
|
||||
collection: list[str] = InputField(default_factory=list, description="The collection of string values")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||
return StringCollectionOutput(collection=self.collection)
|
||||
@ -232,10 +239,12 @@ class ImageOutput(BaseInvocationOutput):
|
||||
class ImageCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of images"""
|
||||
|
||||
collection: list[ImageField] = OutputField(description="The output images", ui_type=UIType.ImageCollection)
|
||||
collection: list[ImageField] = OutputField(
|
||||
description="The output images",
|
||||
)
|
||||
|
||||
|
||||
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives")
|
||||
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.0")
|
||||
class ImageInvocation(BaseInvocation):
|
||||
"""An image primitive value"""
|
||||
|
||||
@ -256,13 +265,12 @@ class ImageInvocation(BaseInvocation):
|
||||
title="Image Collection Primitive",
|
||||
tags=["primitives", "image", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageCollectionInvocation(BaseInvocation):
|
||||
"""A collection of image primitive values"""
|
||||
|
||||
collection: list[ImageField] = InputField(
|
||||
default_factory=list, description="The collection of image values", ui_type=UIType.ImageCollection
|
||||
)
|
||||
collection: list[ImageField] = InputField(description="The collection of image values")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
|
||||
return ImageCollectionOutput(collection=self.collection)
|
||||
@ -316,11 +324,12 @@ class LatentsCollectionOutput(BaseInvocationOutput):
|
||||
|
||||
collection: list[LatentsField] = OutputField(
|
||||
description=FieldDescriptions.latents,
|
||||
ui_type=UIType.LatentsCollection,
|
||||
)
|
||||
|
||||
|
||||
@invocation("latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives")
|
||||
@invocation(
|
||||
"latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.0"
|
||||
)
|
||||
class LatentsInvocation(BaseInvocation):
|
||||
"""A latents tensor primitive value"""
|
||||
|
||||
@ -337,12 +346,13 @@ class LatentsInvocation(BaseInvocation):
|
||||
title="Latents Collection Primitive",
|
||||
tags=["primitives", "latents", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
)
|
||||
class LatentsCollectionInvocation(BaseInvocation):
|
||||
"""A collection of latents tensor primitive values"""
|
||||
|
||||
collection: list[LatentsField] = InputField(
|
||||
description="The collection of latents tensors", ui_type=UIType.LatentsCollection
|
||||
description="The collection of latents tensors",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
|
||||
@ -385,10 +395,12 @@ class ColorOutput(BaseInvocationOutput):
|
||||
class ColorCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of colors"""
|
||||
|
||||
collection: list[ColorField] = OutputField(description="The output colors", ui_type=UIType.ColorCollection)
|
||||
collection: list[ColorField] = OutputField(
|
||||
description="The output colors",
|
||||
)
|
||||
|
||||
|
||||
@invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives")
|
||||
@invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives", version="1.0.0")
|
||||
class ColorInvocation(BaseInvocation):
|
||||
"""A color primitive value"""
|
||||
|
||||
@ -422,7 +434,6 @@ class ConditioningCollectionOutput(BaseInvocationOutput):
|
||||
|
||||
collection: list[ConditioningField] = OutputField(
|
||||
description="The output conditioning tensors",
|
||||
ui_type=UIType.ConditioningCollection,
|
||||
)
|
||||
|
||||
|
||||
@ -431,6 +442,7 @@ class ConditioningCollectionOutput(BaseInvocationOutput):
|
||||
title="Conditioning Primitive",
|
||||
tags=["primitives", "conditioning"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ConditioningInvocation(BaseInvocation):
|
||||
"""A conditioning tensor primitive value"""
|
||||
@ -446,6 +458,7 @@ class ConditioningInvocation(BaseInvocation):
|
||||
title="Conditioning Collection Primitive",
|
||||
tags=["primitives", "conditioning", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ConditioningCollectionInvocation(BaseInvocation):
|
||||
"""A collection of conditioning tensor primitive values"""
|
||||
@ -453,7 +466,6 @@ class ConditioningCollectionInvocation(BaseInvocation):
|
||||
collection: list[ConditioningField] = InputField(
|
||||
default_factory=list,
|
||||
description="The collection of conditioning tensors",
|
||||
ui_type=UIType.ConditioningCollection,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:
|
||||
|
@ -10,7 +10,7 @@ from invokeai.app.invocations.primitives import StringCollectionOutput
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation
|
||||
|
||||
|
||||
@invocation("dynamic_prompt", title="Dynamic Prompt", tags=["prompt", "collection"], category="prompt")
|
||||
@invocation("dynamic_prompt", title="Dynamic Prompt", tags=["prompt", "collection"], category="prompt", version="1.0.0")
|
||||
class DynamicPromptInvocation(BaseInvocation):
|
||||
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
|
||||
|
||||
@ -29,7 +29,7 @@ class DynamicPromptInvocation(BaseInvocation):
|
||||
return StringCollectionOutput(collection=prompts)
|
||||
|
||||
|
||||
@invocation("prompt_from_file", title="Prompts from File", tags=["prompt", "file"], category="prompt")
|
||||
@invocation("prompt_from_file", title="Prompts from File", tags=["prompt", "file"], category="prompt", version="1.0.0")
|
||||
class PromptsFromFileInvocation(BaseInvocation):
|
||||
"""Loads prompts from a text file"""
|
||||
|
||||
|
@ -33,7 +33,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model")
|
||||
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.0")
|
||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl base model, outputting its submodels."""
|
||||
|
||||
@ -119,6 +119,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
title="SDXL Refiner Model",
|
||||
tags=["model", "sdxl", "refiner"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||
|
@ -23,7 +23,7 @@ ESRGAN_MODELS = Literal[
|
||||
]
|
||||
|
||||
|
||||
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan")
|
||||
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.0.0")
|
||||
class ESRGANInvocation(BaseInvocation):
|
||||
"""Upscales an image using RealESRGAN."""
|
||||
|
||||
|
@ -112,6 +112,10 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
|
||||
if to_type in get_args(from_type):
|
||||
return True
|
||||
|
||||
# allow int -> float, pydantic will cast for us
|
||||
if from_type is int and to_type is float:
|
||||
return True
|
||||
|
||||
# if not issubclass(from_type, to_type):
|
||||
if not is_union_subtype(from_type, to_type):
|
||||
return False
|
||||
|
@ -75,6 +75,7 @@
|
||||
"@reduxjs/toolkit": "^1.9.5",
|
||||
"@roarr/browser-log-writer": "^1.1.5",
|
||||
"@stevebel/png": "^1.5.1",
|
||||
"compare-versions": "^6.1.0",
|
||||
"dateformat": "^5.0.3",
|
||||
"formik": "^2.4.3",
|
||||
"framer-motion": "^10.16.1",
|
||||
|
@ -84,6 +84,7 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
|
||||
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
||||
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
||||
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
||||
import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
@ -202,6 +203,9 @@ addBoardIdSelectedListener();
|
||||
// Node schemas
|
||||
addReceivedOpenAPISchemaListener();
|
||||
|
||||
// Workflows
|
||||
addWorkflowLoadedListener();
|
||||
|
||||
// DND
|
||||
addImageDroppedListener();
|
||||
|
||||
|
@ -0,0 +1,55 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { workflowLoadRequested } from 'features/nodes/store/actions';
|
||||
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||
import { validateWorkflow } from 'features/nodes/util/validateWorkflow';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
export const addWorkflowLoadedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: workflowLoadRequested,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const log = logger('nodes');
|
||||
const workflow = action.payload;
|
||||
const nodeTemplates = getState().nodes.nodeTemplates;
|
||||
|
||||
const { workflow: validatedWorkflow, errors } = validateWorkflow(
|
||||
workflow,
|
||||
nodeTemplates
|
||||
);
|
||||
|
||||
dispatch(workflowLoaded(validatedWorkflow));
|
||||
|
||||
if (!errors.length) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: 'Workflow Loaded',
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
} else {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: 'Workflow Loaded with Warnings',
|
||||
status: 'warning',
|
||||
})
|
||||
)
|
||||
);
|
||||
errors.forEach(({ message, ...rest }) => {
|
||||
log.warn(rest, message);
|
||||
});
|
||||
}
|
||||
|
||||
dispatch(setActiveTab('nodes'));
|
||||
requestAnimationFrame(() => {
|
||||
$flow.get()?.fitView();
|
||||
});
|
||||
},
|
||||
});
|
||||
};
|
@ -63,7 +63,11 @@ const selector = createSelector(
|
||||
return;
|
||||
}
|
||||
|
||||
if (fieldTemplate.required && !field.value && !hasConnection) {
|
||||
if (
|
||||
fieldTemplate.required &&
|
||||
field.value === undefined &&
|
||||
!hasConnection
|
||||
) {
|
||||
reasons.push(
|
||||
`${node.data.label || nodeTemplate.title} -> ${
|
||||
field.label || fieldTemplate.title
|
||||
|
@ -1,2 +1,2 @@
|
||||
export const colorTokenToCssVar = (colorToken: string) =>
|
||||
`var(--invokeai-colors-${colorToken.split('.').join('-')}`;
|
||||
`var(--invokeai-colors-${colorToken.split('.').join('-')})`;
|
||||
|
@ -17,16 +17,13 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton';
|
||||
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
||||
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
||||
import { workflowLoadRequested } from 'features/nodes/store/actions';
|
||||
import ParamUpscalePopover from 'features/parameters/components/Parameters/Upscale/ParamUpscaleSettings';
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import {
|
||||
setActiveTab,
|
||||
setShouldShowImageDetails,
|
||||
setShouldShowProgressInViewer,
|
||||
} from 'features/ui/store/uiSlice';
|
||||
@ -124,16 +121,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
if (!workflow) {
|
||||
return;
|
||||
}
|
||||
dispatch(workflowLoaded(workflow));
|
||||
dispatch(setActiveTab('nodes'));
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: 'Workflow Loaded',
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
dispatch(workflowLoadRequested(workflow));
|
||||
}, [dispatch, workflow]);
|
||||
|
||||
const handleClickUseAllParameters = useCallback(() => {
|
||||
|
@ -7,12 +7,9 @@ import {
|
||||
isModalOpenChanged,
|
||||
} from 'features/changeBoardModal/store/slice';
|
||||
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
||||
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
@ -36,6 +33,7 @@ import {
|
||||
} from 'services/api/endpoints/images';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
|
||||
import { workflowLoadRequested } from 'features/nodes/store/actions';
|
||||
|
||||
type SingleSelectionMenuItemsProps = {
|
||||
imageDTO: ImageDTO;
|
||||
@ -102,16 +100,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
if (!workflow) {
|
||||
return;
|
||||
}
|
||||
dispatch(workflowLoaded(workflow));
|
||||
dispatch(setActiveTab('nodes'));
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: 'Workflow Loaded',
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
dispatch(workflowLoadRequested(workflow));
|
||||
}, [dispatch, workflow]);
|
||||
|
||||
const handleSendToImageToImage = useCallback(() => {
|
||||
|
@ -3,6 +3,7 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||
import { contextMenusClosed } from 'features/ui/store/uiSlice';
|
||||
import { useCallback } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
@ -13,6 +14,7 @@ import {
|
||||
OnConnectStart,
|
||||
OnEdgesChange,
|
||||
OnEdgesDelete,
|
||||
OnInit,
|
||||
OnMoveEnd,
|
||||
OnNodesChange,
|
||||
OnNodesDelete,
|
||||
@ -147,6 +149,11 @@ export const Flow = () => {
|
||||
dispatch(contextMenusClosed());
|
||||
}, [dispatch]);
|
||||
|
||||
const onInit: OnInit = useCallback((flow) => {
|
||||
$flow.set(flow);
|
||||
flow.fitView();
|
||||
}, []);
|
||||
|
||||
useHotkeys(['Ctrl+c', 'Meta+c'], (e) => {
|
||||
e.preventDefault();
|
||||
dispatch(selectionCopied());
|
||||
@ -170,6 +177,7 @@ export const Flow = () => {
|
||||
edgeTypes={edgeTypes}
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
onInit={onInit}
|
||||
onNodesChange={onNodesChange}
|
||||
onEdgesChange={onEdgesChange}
|
||||
onEdgesDelete={onEdgesDelete}
|
||||
|
@ -12,6 +12,7 @@ import {
|
||||
Tooltip,
|
||||
useDisclosure,
|
||||
} from '@chakra-ui/react';
|
||||
import { compare } from 'compare-versions';
|
||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
@ -20,6 +21,7 @@ import { isInvocationNodeData } from 'features/nodes/types/types';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { FaInfoCircle } from 'react-icons/fa';
|
||||
import NotesTextarea from './NotesTextarea';
|
||||
import { useDoNodeVersionsMatch } from 'features/nodes/hooks/useDoNodeVersionsMatch';
|
||||
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
@ -29,6 +31,7 @@ const InvocationNodeNotes = ({ nodeId }: Props) => {
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
const label = useNodeLabel(nodeId);
|
||||
const title = useNodeTemplateTitle(nodeId);
|
||||
const doVersionsMatch = useDoNodeVersionsMatch(nodeId);
|
||||
|
||||
return (
|
||||
<>
|
||||
@ -50,7 +53,11 @@ const InvocationNodeNotes = ({ nodeId }: Props) => {
|
||||
>
|
||||
<Icon
|
||||
as={FaInfoCircle}
|
||||
sx={{ boxSize: 4, w: 8, color: 'base.400' }}
|
||||
sx={{
|
||||
boxSize: 4,
|
||||
w: 8,
|
||||
color: doVersionsMatch ? 'base.400' : 'error.400',
|
||||
}}
|
||||
/>
|
||||
</Flex>
|
||||
</Tooltip>
|
||||
@ -92,16 +99,59 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
|
||||
return 'Unknown Node';
|
||||
}, [data, nodeTemplate]);
|
||||
|
||||
const versionComponent = useMemo(() => {
|
||||
if (!isInvocationNodeData(data) || !nodeTemplate) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!data.version) {
|
||||
return (
|
||||
<Text as="span" sx={{ color: 'error.500' }}>
|
||||
Version unknown
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
if (!nodeTemplate.version) {
|
||||
return (
|
||||
<Text as="span" sx={{ color: 'error.500' }}>
|
||||
Version {data.version} (unknown template)
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
if (compare(data.version, nodeTemplate.version, '<')) {
|
||||
return (
|
||||
<Text as="span" sx={{ color: 'error.500' }}>
|
||||
Version {data.version} (update node)
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
if (compare(data.version, nodeTemplate.version, '>')) {
|
||||
return (
|
||||
<Text as="span" sx={{ color: 'error.500' }}>
|
||||
Version {data.version} (update app)
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
return <Text as="span">Version {data.version}</Text>;
|
||||
}, [data, nodeTemplate]);
|
||||
|
||||
if (!isInvocationNodeData(data)) {
|
||||
return <Text sx={{ fontWeight: 600 }}>Unknown Node</Text>;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex sx={{ flexDir: 'column' }}>
|
||||
<Text sx={{ fontWeight: 600 }}>{title}</Text>
|
||||
<Text as="span" sx={{ fontWeight: 600 }}>
|
||||
{title}
|
||||
</Text>
|
||||
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
|
||||
{nodeTemplate?.description}
|
||||
</Text>
|
||||
{versionComponent}
|
||||
{data?.notes && <Text>{data.notes}</Text>}
|
||||
</Flex>
|
||||
);
|
||||
|
@ -1,8 +1,11 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import {
|
||||
COLLECTION_TYPES,
|
||||
FIELDS,
|
||||
HANDLE_TOOLTIP_OPEN_DELAY,
|
||||
MODEL_TYPES,
|
||||
POLYMORPHIC_TYPES,
|
||||
} from 'features/nodes/types/constants';
|
||||
import {
|
||||
InputFieldTemplate,
|
||||
@ -18,6 +21,7 @@ export const handleBaseStyles: CSSProperties = {
|
||||
borderWidth: 0,
|
||||
zIndex: 1,
|
||||
};
|
||||
``;
|
||||
|
||||
export const inputHandleStyles: CSSProperties = {
|
||||
left: '-1rem',
|
||||
@ -44,15 +48,25 @@ const FieldHandle = (props: FieldHandleProps) => {
|
||||
connectionError,
|
||||
} = props;
|
||||
const { name, type } = fieldTemplate;
|
||||
const { color, title } = FIELDS[type];
|
||||
const { color: typeColor, title } = FIELDS[type];
|
||||
|
||||
const styles: CSSProperties = useMemo(() => {
|
||||
const isCollectionType = COLLECTION_TYPES.includes(type);
|
||||
const isPolymorphicType = POLYMORPHIC_TYPES.includes(type);
|
||||
const isModelType = MODEL_TYPES.includes(type);
|
||||
const color = colorTokenToCssVar(typeColor);
|
||||
const s: CSSProperties = {
|
||||
backgroundColor: colorTokenToCssVar(color),
|
||||
backgroundColor:
|
||||
isCollectionType || isPolymorphicType
|
||||
? 'var(--invokeai-colors-base-900)'
|
||||
: color,
|
||||
position: 'absolute',
|
||||
width: '1rem',
|
||||
height: '1rem',
|
||||
borderWidth: 0,
|
||||
borderWidth: isCollectionType || isPolymorphicType ? 4 : 0,
|
||||
borderStyle: 'solid',
|
||||
borderColor: color,
|
||||
borderRadius: isModelType ? 4 : '100%',
|
||||
zIndex: 1,
|
||||
};
|
||||
|
||||
@ -78,11 +92,12 @@ const FieldHandle = (props: FieldHandleProps) => {
|
||||
|
||||
return s;
|
||||
}, [
|
||||
color,
|
||||
connectionError,
|
||||
handleType,
|
||||
isConnectionInProgress,
|
||||
isConnectionStartField,
|
||||
type,
|
||||
typeColor,
|
||||
]);
|
||||
|
||||
const tooltip = useMemo(() => {
|
||||
|
@ -75,6 +75,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
sx={{
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
h: 'full',
|
||||
mb: 0,
|
||||
px: 1,
|
||||
gap: 2,
|
||||
|
@ -3,18 +3,10 @@ import { useFieldData } from 'features/nodes/hooks/useFieldData';
|
||||
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
||||
import { memo } from 'react';
|
||||
import BooleanInputField from './inputs/BooleanInputField';
|
||||
import ClipInputField from './inputs/ClipInputField';
|
||||
import CollectionInputField from './inputs/CollectionInputField';
|
||||
import CollectionItemInputField from './inputs/CollectionItemInputField';
|
||||
import ColorInputField from './inputs/ColorInputField';
|
||||
import ConditioningInputField from './inputs/ConditioningInputField';
|
||||
import ControlInputField from './inputs/ControlInputField';
|
||||
import ControlNetModelInputField from './inputs/ControlNetModelInputField';
|
||||
import DenoiseMaskInputField from './inputs/DenoiseMaskInputField';
|
||||
import EnumInputField from './inputs/EnumInputField';
|
||||
import ImageCollectionInputField from './inputs/ImageCollectionInputField';
|
||||
import ImageInputField from './inputs/ImageInputField';
|
||||
import LatentsInputField from './inputs/LatentsInputField';
|
||||
import LoRAModelInputField from './inputs/LoRAModelInputField';
|
||||
import MainModelInputField from './inputs/MainModelInputField';
|
||||
import NumberInputField from './inputs/NumberInputField';
|
||||
@ -22,8 +14,6 @@ import RefinerModelInputField from './inputs/RefinerModelInputField';
|
||||
import SDXLMainModelInputField from './inputs/SDXLMainModelInputField';
|
||||
import SchedulerInputField from './inputs/SchedulerInputField';
|
||||
import StringInputField from './inputs/StringInputField';
|
||||
import UnetInputField from './inputs/UnetInputField';
|
||||
import VaeInputField from './inputs/VaeInputField';
|
||||
import VaeModelInputField from './inputs/VaeModelInputField';
|
||||
|
||||
type InputFieldProps = {
|
||||
@ -31,7 +21,6 @@ type InputFieldProps = {
|
||||
fieldName: string;
|
||||
};
|
||||
|
||||
// build an individual input element based on the schema
|
||||
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
const field = useFieldData(nodeId, fieldName);
|
||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
|
||||
@ -93,88 +82,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'LatentsField' &&
|
||||
fieldTemplate?.type === 'LatentsField'
|
||||
) {
|
||||
return (
|
||||
<LatentsInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'DenoiseMaskField' &&
|
||||
fieldTemplate?.type === 'DenoiseMaskField'
|
||||
) {
|
||||
return (
|
||||
<DenoiseMaskInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'ConditioningField' &&
|
||||
fieldTemplate?.type === 'ConditioningField'
|
||||
) {
|
||||
return (
|
||||
<ConditioningInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'UNetField' && fieldTemplate?.type === 'UNetField') {
|
||||
return (
|
||||
<UnetInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'ClipField' && fieldTemplate?.type === 'ClipField') {
|
||||
return (
|
||||
<ClipInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'VaeField' && fieldTemplate?.type === 'VaeField') {
|
||||
return (
|
||||
<VaeInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'ControlField' &&
|
||||
fieldTemplate?.type === 'ControlField'
|
||||
) {
|
||||
return (
|
||||
<ControlInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'MainModelField' &&
|
||||
fieldTemplate?.type === 'MainModelField'
|
||||
@ -240,29 +147,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'Collection' && fieldTemplate?.type === 'Collection') {
|
||||
return (
|
||||
<CollectionInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'CollectionItem' &&
|
||||
fieldTemplate?.type === 'CollectionItem'
|
||||
) {
|
||||
return (
|
||||
<CollectionItemInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
|
||||
return (
|
||||
<ColorInputField
|
||||
@ -273,19 +157,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'ImageCollection' &&
|
||||
fieldTemplate?.type === 'ImageCollection'
|
||||
) {
|
||||
return (
|
||||
<ImageCollectionInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'SDXLMainModelField' &&
|
||||
fieldTemplate?.type === 'SDXLMainModelField'
|
||||
@ -309,6 +180,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (field && fieldTemplate) {
|
||||
// Fallback for when there is no component for the type
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Box p={1}>
|
||||
<Text
|
||||
|
@ -1,12 +1,17 @@
|
||||
import {
|
||||
ControlInputFieldTemplate,
|
||||
ControlInputFieldValue,
|
||||
ControlPolymorphicInputFieldTemplate,
|
||||
ControlPolymorphicInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
|
||||
const ControlInputFieldComponent = (
|
||||
_props: FieldComponentProps<ControlInputFieldValue, ControlInputFieldTemplate>
|
||||
_props: FieldComponentProps<
|
||||
ControlInputFieldValue | ControlPolymorphicInputFieldValue,
|
||||
ControlInputFieldTemplate | ControlPolymorphicInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
return null;
|
||||
};
|
||||
|
@ -9,9 +9,9 @@ import {
|
||||
} from 'features/dnd/types';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
FieldComponentProps,
|
||||
ImageInputFieldTemplate,
|
||||
ImageInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { FaUndo } from 'react-icons/fa';
|
||||
|
@ -2,11 +2,16 @@ import {
|
||||
LatentsInputFieldTemplate,
|
||||
LatentsInputFieldValue,
|
||||
FieldComponentProps,
|
||||
LatentsPolymorphicInputFieldValue,
|
||||
LatentsPolymorphicInputFieldTemplate,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
|
||||
const LatentsInputFieldComponent = (
|
||||
_props: FieldComponentProps<LatentsInputFieldValue, LatentsInputFieldTemplate>
|
||||
_props: FieldComponentProps<
|
||||
LatentsInputFieldValue | LatentsPolymorphicInputFieldValue,
|
||||
LatentsInputFieldTemplate | LatentsPolymorphicInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
return null;
|
||||
};
|
||||
|
@ -9,11 +9,11 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { numberStringRegex } from 'common/components/IAINumberInput';
|
||||
import { fieldNumberValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
FieldComponentProps,
|
||||
FloatInputFieldTemplate,
|
||||
FloatInputFieldValue,
|
||||
IntegerInputFieldTemplate,
|
||||
IntegerInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
|
@ -138,13 +138,14 @@ export const useBuildNodeData = () => {
|
||||
data: {
|
||||
id: nodeId,
|
||||
type,
|
||||
inputs,
|
||||
outputs,
|
||||
isOpen: true,
|
||||
version: template.version,
|
||||
label: '',
|
||||
notes: '',
|
||||
isOpen: true,
|
||||
embedWorkflow: false,
|
||||
isIntermediate: true,
|
||||
inputs,
|
||||
outputs,
|
||||
},
|
||||
};
|
||||
|
||||
|
@ -0,0 +1,33 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { compareVersions } from 'compare-versions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
|
||||
export const useDoNodeVersionsMatch = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
|
||||
if (!nodeTemplate?.version || !node.data?.version) {
|
||||
return false;
|
||||
}
|
||||
return compareVersions(nodeTemplate.version, node.data.version) === 0;
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const nodeTemplate = useAppSelector(selector);
|
||||
|
||||
return nodeTemplate;
|
||||
};
|
@ -15,7 +15,7 @@ export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => {
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
return Boolean(node?.data.inputs[fieldName]?.value);
|
||||
return node?.data.inputs[fieldName]?.value !== undefined;
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
|
@ -3,9 +3,19 @@ import graphlib from '@dagrejs/graphlib';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useCallback } from 'react';
|
||||
import { Connection, Edge, Node, useReactFlow } from 'reactflow';
|
||||
import { COLLECTION_TYPES } from '../types/constants';
|
||||
import {
|
||||
COLLECTION_MAP,
|
||||
COLLECTION_TYPES,
|
||||
POLYMORPHIC_TO_SINGLE_MAP,
|
||||
POLYMORPHIC_TYPES,
|
||||
} from '../types/constants';
|
||||
import { InvocationNodeData } from '../types/types';
|
||||
|
||||
/**
|
||||
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts`
|
||||
* TODO: Figure out how to do this without duplicating all the logic
|
||||
*/
|
||||
|
||||
export const useIsValidConnection = () => {
|
||||
const flow = useReactFlow();
|
||||
const shouldValidateGraph = useAppSelector(
|
||||
@ -42,6 +52,19 @@ export const useIsValidConnection = () => {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (
|
||||
edges
|
||||
.filter((edge) => {
|
||||
return edge.target === target && edge.targetHandle === targetHandle;
|
||||
})
|
||||
.find((edge) => {
|
||||
edge.source === source && edge.sourceHandle === sourceHandle;
|
||||
})
|
||||
) {
|
||||
// We already have a connection from this source to this target
|
||||
return false;
|
||||
}
|
||||
|
||||
// Connection is invalid if target already has a connection
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
@ -53,21 +76,62 @@ export const useIsValidConnection = () => {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Connection types must be the same for a connection
|
||||
if (
|
||||
sourceType !== targetType &&
|
||||
sourceType !== 'CollectionItem' &&
|
||||
targetType !== 'CollectionItem'
|
||||
) {
|
||||
if (
|
||||
!(
|
||||
COLLECTION_TYPES.includes(targetType) &&
|
||||
COLLECTION_TYPES.includes(sourceType)
|
||||
)
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
/**
|
||||
* Connection types must be the same for a connection, with exceptions:
|
||||
* - CollectionItem can connect to any non-Collection
|
||||
* - Non-Collections can connect to CollectionItem
|
||||
* - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type
|
||||
* - Generic Collection can connect to any other Collection or Polymorphic
|
||||
* - Any Collection can connect to a Generic Collection
|
||||
*/
|
||||
|
||||
if (sourceType !== targetType) {
|
||||
const isCollectionItemToNonCollection =
|
||||
sourceType === 'CollectionItem' &&
|
||||
!COLLECTION_TYPES.includes(targetType);
|
||||
|
||||
const isNonCollectionToCollectionItem =
|
||||
targetType === 'CollectionItem' &&
|
||||
!COLLECTION_TYPES.includes(sourceType) &&
|
||||
!POLYMORPHIC_TYPES.includes(sourceType);
|
||||
|
||||
const isAnythingToPolymorphicOfSameBaseType =
|
||||
POLYMORPHIC_TYPES.includes(targetType) &&
|
||||
(() => {
|
||||
if (!POLYMORPHIC_TYPES.includes(targetType)) {
|
||||
return false;
|
||||
}
|
||||
const baseType =
|
||||
POLYMORPHIC_TO_SINGLE_MAP[
|
||||
targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP
|
||||
];
|
||||
|
||||
const collectionType =
|
||||
COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP];
|
||||
|
||||
return sourceType === baseType || sourceType === collectionType;
|
||||
})();
|
||||
|
||||
const isGenericCollectionToAnyCollectionOrPolymorphic =
|
||||
sourceType === 'Collection' &&
|
||||
(COLLECTION_TYPES.includes(targetType) ||
|
||||
POLYMORPHIC_TYPES.includes(targetType));
|
||||
|
||||
const isCollectionToGenericCollection =
|
||||
targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);
|
||||
|
||||
const isIntToFloat = sourceType === 'integer' && targetType === 'float';
|
||||
|
||||
return (
|
||||
isCollectionItemToNonCollection ||
|
||||
isNonCollectionToCollectionItem ||
|
||||
isAnythingToPolymorphicOfSameBaseType ||
|
||||
isGenericCollectionToAnyCollectionOrPolymorphic ||
|
||||
isCollectionToGenericCollection ||
|
||||
isIntToFloat
|
||||
);
|
||||
}
|
||||
|
||||
// Graphs much be acyclic (no loops!)
|
||||
return getIsGraphAcyclic(source, target, nodes, edges);
|
||||
},
|
||||
|
@ -2,13 +2,13 @@ import { ListItem, Text, UnorderedList } from '@chakra-ui/react';
|
||||
import { useLogger } from 'app/logging/useLogger';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
||||
import { zValidatedWorkflow } from 'features/nodes/types/types';
|
||||
import { zWorkflow } from 'features/nodes/types/types';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { ZodError } from 'zod';
|
||||
import { fromZodError, fromZodIssue } from 'zod-validation-error';
|
||||
import { workflowLoadRequested } from '../store/actions';
|
||||
|
||||
export const useLoadWorkflowFromFile = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
@ -24,7 +24,7 @@ export const useLoadWorkflowFromFile = () => {
|
||||
|
||||
try {
|
||||
const parsedJSON = JSON.parse(String(rawJSON));
|
||||
const result = zValidatedWorkflow.safeParse(parsedJSON);
|
||||
const result = zWorkflow.safeParse(parsedJSON);
|
||||
|
||||
if (!result.success) {
|
||||
const { message } = fromZodError(result.error, {
|
||||
@ -45,32 +45,8 @@ export const useLoadWorkflowFromFile = () => {
|
||||
reader.abort();
|
||||
return;
|
||||
}
|
||||
dispatch(workflowLoaded(result.data.workflow));
|
||||
|
||||
if (!result.data.warnings.length) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: 'Workflow Loaded',
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
reader.abort();
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: 'Workflow Loaded with Warnings',
|
||||
status: 'warning',
|
||||
})
|
||||
)
|
||||
);
|
||||
result.data.warnings.forEach(({ message, ...rest }) => {
|
||||
logger.warn(rest, message);
|
||||
});
|
||||
dispatch(workflowLoadRequested(result.data));
|
||||
|
||||
reader.abort();
|
||||
} catch {
|
||||
|
@ -1,5 +1,6 @@
|
||||
import { createAction, isAnyOf } from '@reduxjs/toolkit';
|
||||
import { Graph } from 'services/api/types';
|
||||
import { Workflow } from '../types/types';
|
||||
|
||||
export const textToImageGraphBuilt = createAction<Graph>(
|
||||
'nodes/textToImageGraphBuilt'
|
||||
@ -16,3 +17,7 @@ export const isAnyGraphBuilt = isAnyOf(
|
||||
canvasGraphBuilt,
|
||||
nodesGraphBuilt
|
||||
);
|
||||
|
||||
export const workflowLoadRequested = createAction<Workflow>(
|
||||
'nodes/workflowLoadRequested'
|
||||
);
|
||||
|
@ -0,0 +1,4 @@
|
||||
import { atom } from 'nanostores';
|
||||
import { ReactFlowInstance } from 'reactflow';
|
||||
|
||||
export const $flow = atom<ReactFlowInstance | null>(null);
|
@ -1,10 +1,20 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { getIsGraphAcyclic } from 'features/nodes/hooks/useIsValidConnection';
|
||||
import { COLLECTION_TYPES } from 'features/nodes/types/constants';
|
||||
import {
|
||||
COLLECTION_MAP,
|
||||
COLLECTION_TYPES,
|
||||
POLYMORPHIC_TO_SINGLE_MAP,
|
||||
POLYMORPHIC_TYPES,
|
||||
} from 'features/nodes/types/constants';
|
||||
import { FieldType } from 'features/nodes/types/types';
|
||||
import { HandleType } from 'reactflow';
|
||||
|
||||
/**
|
||||
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts`
|
||||
* TODO: Figure out how to do this without duplicating all the logic
|
||||
*/
|
||||
|
||||
export const makeConnectionErrorSelector = (
|
||||
nodeId: string,
|
||||
fieldName: string,
|
||||
@ -19,11 +29,6 @@ export const makeConnectionErrorSelector = (
|
||||
const { currentConnectionFieldType, connectionStartParams, nodes, edges } =
|
||||
state.nodes;
|
||||
|
||||
if (!state.nodes.shouldValidateGraph) {
|
||||
// manual override!
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!connectionStartParams || !currentConnectionFieldType) {
|
||||
return 'No connection in progress';
|
||||
}
|
||||
@ -38,9 +43,9 @@ export const makeConnectionErrorSelector = (
|
||||
return 'No connection data';
|
||||
}
|
||||
|
||||
const targetFieldType =
|
||||
const targetType =
|
||||
handleType === 'target' ? fieldType : currentConnectionFieldType;
|
||||
const sourceFieldType =
|
||||
const sourceType =
|
||||
handleType === 'source' ? fieldType : currentConnectionFieldType;
|
||||
|
||||
if (nodeId === connectionNodeId) {
|
||||
@ -55,30 +60,73 @@ export const makeConnectionErrorSelector = (
|
||||
}
|
||||
|
||||
if (
|
||||
fieldType !== currentConnectionFieldType &&
|
||||
fieldType !== 'CollectionItem' &&
|
||||
currentConnectionFieldType !== 'CollectionItem'
|
||||
) {
|
||||
if (
|
||||
!(
|
||||
COLLECTION_TYPES.includes(targetFieldType) &&
|
||||
COLLECTION_TYPES.includes(sourceFieldType)
|
||||
)
|
||||
) {
|
||||
// except for collection items, field types must match
|
||||
return 'Field types must match';
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
handleType === 'target' &&
|
||||
edges.find((edge) => {
|
||||
return edge.target === nodeId && edge.targetHandle === fieldName;
|
||||
}) &&
|
||||
// except CollectionItem inputs can have multiples
|
||||
targetFieldType !== 'CollectionItem'
|
||||
targetType !== 'CollectionItem'
|
||||
) {
|
||||
return 'Inputs may only have one connection';
|
||||
return 'Input may only have one connection';
|
||||
}
|
||||
|
||||
/**
|
||||
* Connection types must be the same for a connection, with exceptions:
|
||||
* - CollectionItem can connect to any non-Collection
|
||||
* - Non-Collections can connect to CollectionItem
|
||||
* - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type
|
||||
* - Generic Collection can connect to any other Collection or Polymorphic
|
||||
* - Any Collection can connect to a Generic Collection
|
||||
*/
|
||||
|
||||
if (sourceType !== targetType) {
|
||||
const isCollectionItemToNonCollection =
|
||||
sourceType === 'CollectionItem' &&
|
||||
!COLLECTION_TYPES.includes(targetType);
|
||||
|
||||
const isNonCollectionToCollectionItem =
|
||||
targetType === 'CollectionItem' &&
|
||||
!COLLECTION_TYPES.includes(sourceType) &&
|
||||
!POLYMORPHIC_TYPES.includes(sourceType);
|
||||
|
||||
const isAnythingToPolymorphicOfSameBaseType =
|
||||
POLYMORPHIC_TYPES.includes(targetType) &&
|
||||
(() => {
|
||||
if (!POLYMORPHIC_TYPES.includes(targetType)) {
|
||||
return false;
|
||||
}
|
||||
const baseType =
|
||||
POLYMORPHIC_TO_SINGLE_MAP[
|
||||
targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP
|
||||
];
|
||||
|
||||
const collectionType =
|
||||
COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP];
|
||||
|
||||
return sourceType === baseType || sourceType === collectionType;
|
||||
})();
|
||||
|
||||
const isGenericCollectionToAnyCollectionOrPolymorphic =
|
||||
sourceType === 'Collection' &&
|
||||
(COLLECTION_TYPES.includes(targetType) ||
|
||||
POLYMORPHIC_TYPES.includes(targetType));
|
||||
|
||||
const isCollectionToGenericCollection =
|
||||
targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);
|
||||
|
||||
const isIntToFloat = sourceType === 'integer' && targetType === 'float';
|
||||
|
||||
if (
|
||||
!(
|
||||
isCollectionItemToNonCollection ||
|
||||
isNonCollectionToCollectionItem ||
|
||||
isAnythingToPolymorphicOfSameBaseType ||
|
||||
isGenericCollectionToAnyCollectionOrPolymorphic ||
|
||||
isCollectionToGenericCollection ||
|
||||
isIntToFloat
|
||||
)
|
||||
) {
|
||||
return 'Field types must match';
|
||||
}
|
||||
}
|
||||
|
||||
const isGraphAcyclic = getIsGraphAcyclic(
|
||||
|
@ -17,176 +17,297 @@ export const KIND_MAP = {
|
||||
export const COLLECTION_TYPES: FieldType[] = [
|
||||
'Collection',
|
||||
'IntegerCollection',
|
||||
'BooleanCollection',
|
||||
'FloatCollection',
|
||||
'StringCollection',
|
||||
'BooleanCollection',
|
||||
'ImageCollection',
|
||||
'LatentsCollection',
|
||||
'ConditioningCollection',
|
||||
'ControlCollection',
|
||||
'ColorCollection',
|
||||
];
|
||||
|
||||
export const POLYMORPHIC_TYPES = [
|
||||
'IntegerPolymorphic',
|
||||
'BooleanPolymorphic',
|
||||
'FloatPolymorphic',
|
||||
'StringPolymorphic',
|
||||
'ImagePolymorphic',
|
||||
'LatentsPolymorphic',
|
||||
'ConditioningPolymorphic',
|
||||
'ControlPolymorphic',
|
||||
'ColorPolymorphic',
|
||||
];
|
||||
|
||||
export const MODEL_TYPES = [
|
||||
'ControlNetModelField',
|
||||
'LoRAModelField',
|
||||
'MainModelField',
|
||||
'ONNXModelField',
|
||||
'SDXLMainModelField',
|
||||
'SDXLRefinerModelField',
|
||||
'VaeModelField',
|
||||
'UNetField',
|
||||
'VaeField',
|
||||
'ClipField',
|
||||
];
|
||||
|
||||
export const COLLECTION_MAP = {
|
||||
integer: 'IntegerCollection',
|
||||
boolean: 'BooleanCollection',
|
||||
number: 'FloatCollection',
|
||||
float: 'FloatCollection',
|
||||
string: 'StringCollection',
|
||||
ImageField: 'ImageCollection',
|
||||
LatentsField: 'LatentsCollection',
|
||||
ConditioningField: 'ConditioningCollection',
|
||||
ControlField: 'ControlCollection',
|
||||
ColorField: 'ColorCollection',
|
||||
};
|
||||
export const isCollectionItemType = (
|
||||
itemType: string | undefined
|
||||
): itemType is keyof typeof COLLECTION_MAP =>
|
||||
Boolean(itemType && itemType in COLLECTION_MAP);
|
||||
|
||||
export const SINGLE_TO_POLYMORPHIC_MAP = {
|
||||
integer: 'IntegerPolymorphic',
|
||||
boolean: 'BooleanPolymorphic',
|
||||
number: 'FloatPolymorphic',
|
||||
float: 'FloatPolymorphic',
|
||||
string: 'StringPolymorphic',
|
||||
ImageField: 'ImagePolymorphic',
|
||||
LatentsField: 'LatentsPolymorphic',
|
||||
ConditioningField: 'ConditioningPolymorphic',
|
||||
ControlField: 'ControlPolymorphic',
|
||||
ColorField: 'ColorPolymorphic',
|
||||
};
|
||||
|
||||
export const POLYMORPHIC_TO_SINGLE_MAP = {
|
||||
IntegerPolymorphic: 'integer',
|
||||
BooleanPolymorphic: 'boolean',
|
||||
FloatPolymorphic: 'float',
|
||||
StringPolymorphic: 'string',
|
||||
ImagePolymorphic: 'ImageField',
|
||||
LatentsPolymorphic: 'LatentsField',
|
||||
ConditioningPolymorphic: 'ConditioningField',
|
||||
ControlPolymorphic: 'ControlField',
|
||||
ColorPolymorphic: 'ColorField',
|
||||
};
|
||||
|
||||
export const isPolymorphicItemType = (
|
||||
itemType: string | undefined
|
||||
): itemType is keyof typeof SINGLE_TO_POLYMORPHIC_MAP =>
|
||||
Boolean(itemType && itemType in SINGLE_TO_POLYMORPHIC_MAP);
|
||||
|
||||
export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
integer: {
|
||||
title: 'Integer',
|
||||
description: 'Integers are whole numbers, without a decimal point.',
|
||||
color: 'red.500',
|
||||
},
|
||||
float: {
|
||||
title: 'Float',
|
||||
description: 'Floats are numbers with a decimal point.',
|
||||
color: 'orange.500',
|
||||
},
|
||||
string: {
|
||||
title: 'String',
|
||||
description: 'Strings are text.',
|
||||
color: 'yellow.500',
|
||||
},
|
||||
boolean: {
|
||||
title: 'Boolean',
|
||||
color: 'green.500',
|
||||
description: 'Booleans are true or false.',
|
||||
title: 'Boolean',
|
||||
},
|
||||
enum: {
|
||||
title: 'Enum',
|
||||
description: 'Enums are values that may be one of a number of options.',
|
||||
color: 'blue.500',
|
||||
BooleanCollection: {
|
||||
color: 'green.500',
|
||||
description: 'A collection of booleans.',
|
||||
title: 'Boolean Collection',
|
||||
},
|
||||
array: {
|
||||
title: 'Array',
|
||||
description: 'Enums are values that may be one of a number of options.',
|
||||
color: 'base.500',
|
||||
},
|
||||
ImageField: {
|
||||
title: 'Image',
|
||||
description: 'Images may be passed between nodes.',
|
||||
color: 'purple.500',
|
||||
},
|
||||
DenoiseMaskField: {
|
||||
title: 'Denoise Mask',
|
||||
description: 'Denoise Mask may be passed between nodes',
|
||||
color: 'base.500',
|
||||
},
|
||||
LatentsField: {
|
||||
title: 'Latents',
|
||||
description: 'Latents may be passed between nodes.',
|
||||
color: 'pink.500',
|
||||
},
|
||||
LatentsCollection: {
|
||||
title: 'Latents Collection',
|
||||
description: 'Latents may be passed between nodes.',
|
||||
color: 'pink.500',
|
||||
},
|
||||
ConditioningField: {
|
||||
color: 'cyan.500',
|
||||
title: 'Conditioning',
|
||||
description: 'Conditioning may be passed between nodes.',
|
||||
},
|
||||
ConditioningCollection: {
|
||||
color: 'cyan.500',
|
||||
title: 'Conditioning Collection',
|
||||
description: 'Conditioning may be passed between nodes.',
|
||||
},
|
||||
ImageCollection: {
|
||||
title: 'Image Collection',
|
||||
description: 'A collection of images.',
|
||||
color: 'base.300',
|
||||
},
|
||||
UNetField: {
|
||||
color: 'red.500',
|
||||
title: 'UNet',
|
||||
description: 'UNet submodel.',
|
||||
BooleanPolymorphic: {
|
||||
color: 'green.500',
|
||||
description: 'A collection of booleans.',
|
||||
title: 'Boolean Polymorphic',
|
||||
},
|
||||
ClipField: {
|
||||
color: 'green.500',
|
||||
title: 'Clip',
|
||||
description: 'Tokenizer and text_encoder submodels.',
|
||||
},
|
||||
VaeField: {
|
||||
color: 'blue.500',
|
||||
title: 'Vae',
|
||||
description: 'Vae submodel.',
|
||||
},
|
||||
ControlField: {
|
||||
color: 'cyan.500',
|
||||
title: 'Control',
|
||||
description: 'Control info passed between nodes.',
|
||||
},
|
||||
MainModelField: {
|
||||
color: 'teal.500',
|
||||
title: 'Model',
|
||||
description: 'TODO',
|
||||
},
|
||||
SDXLRefinerModelField: {
|
||||
color: 'teal.500',
|
||||
title: 'Refiner Model',
|
||||
description: 'TODO',
|
||||
},
|
||||
VaeModelField: {
|
||||
color: 'teal.500',
|
||||
title: 'VAE',
|
||||
description: 'TODO',
|
||||
},
|
||||
LoRAModelField: {
|
||||
color: 'teal.500',
|
||||
title: 'LoRA',
|
||||
description: 'TODO',
|
||||
},
|
||||
ControlNetModelField: {
|
||||
color: 'teal.500',
|
||||
title: 'ControlNet',
|
||||
description: 'TODO',
|
||||
},
|
||||
Scheduler: {
|
||||
color: 'base.500',
|
||||
title: 'Scheduler',
|
||||
description: 'TODO',
|
||||
title: 'Clip',
|
||||
},
|
||||
Collection: {
|
||||
color: 'base.500',
|
||||
title: 'Collection',
|
||||
description: 'TODO',
|
||||
title: 'Collection',
|
||||
},
|
||||
CollectionItem: {
|
||||
color: 'base.500',
|
||||
title: 'Collection Item',
|
||||
description: 'TODO',
|
||||
title: 'Collection Item',
|
||||
},
|
||||
ColorCollection: {
|
||||
color: 'pink.300',
|
||||
description: 'A collection of colors.',
|
||||
title: 'Color Collection',
|
||||
},
|
||||
ColorField: {
|
||||
title: 'Color',
|
||||
color: 'pink.300',
|
||||
description: 'A RGBA color.',
|
||||
color: 'base.500',
|
||||
title: 'Color',
|
||||
},
|
||||
BooleanCollection: {
|
||||
title: 'Boolean Collection',
|
||||
description: 'A collection of booleans.',
|
||||
color: 'green.500',
|
||||
ColorPolymorphic: {
|
||||
color: 'pink.300',
|
||||
description: 'A collection of colors.',
|
||||
title: 'Color Polymorphic',
|
||||
},
|
||||
IntegerCollection: {
|
||||
title: 'Integer Collection',
|
||||
description: 'A collection of integers.',
|
||||
color: 'red.500',
|
||||
ConditioningCollection: {
|
||||
color: 'cyan.500',
|
||||
description: 'Conditioning may be passed between nodes.',
|
||||
title: 'Conditioning Collection',
|
||||
},
|
||||
ConditioningField: {
|
||||
color: 'cyan.500',
|
||||
description: 'Conditioning may be passed between nodes.',
|
||||
title: 'Conditioning',
|
||||
},
|
||||
ConditioningPolymorphic: {
|
||||
color: 'cyan.500',
|
||||
description: 'Conditioning may be passed between nodes.',
|
||||
title: 'Conditioning Polymorphic',
|
||||
},
|
||||
ControlCollection: {
|
||||
color: 'teal.500',
|
||||
description: 'Control info passed between nodes.',
|
||||
title: 'Control Collection',
|
||||
},
|
||||
ControlField: {
|
||||
color: 'teal.500',
|
||||
description: 'Control info passed between nodes.',
|
||||
title: 'Control',
|
||||
},
|
||||
ControlNetModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'TODO',
|
||||
title: 'ControlNet',
|
||||
},
|
||||
ControlPolymorphic: {
|
||||
color: 'teal.500',
|
||||
description: 'Control info passed between nodes.',
|
||||
title: 'Control Polymorphic',
|
||||
},
|
||||
DenoiseMaskField: {
|
||||
color: 'blue.300',
|
||||
description: 'Denoise Mask may be passed between nodes',
|
||||
title: 'Denoise Mask',
|
||||
},
|
||||
enum: {
|
||||
color: 'blue.500',
|
||||
description: 'Enums are values that may be one of a number of options.',
|
||||
title: 'Enum',
|
||||
},
|
||||
float: {
|
||||
color: 'orange.500',
|
||||
description: 'Floats are numbers with a decimal point.',
|
||||
title: 'Float',
|
||||
},
|
||||
FloatCollection: {
|
||||
color: 'orange.500',
|
||||
title: 'Float Collection',
|
||||
description: 'A collection of floats.',
|
||||
title: 'Float Collection',
|
||||
},
|
||||
ColorCollection: {
|
||||
color: 'base.500',
|
||||
title: 'Color Collection',
|
||||
description: 'A collection of colors.',
|
||||
FloatPolymorphic: {
|
||||
color: 'orange.500',
|
||||
description: 'A collection of floats.',
|
||||
title: 'Float Polymorphic',
|
||||
},
|
||||
ImageCollection: {
|
||||
color: 'purple.500',
|
||||
description: 'A collection of images.',
|
||||
title: 'Image Collection',
|
||||
},
|
||||
ImageField: {
|
||||
color: 'purple.500',
|
||||
description: 'Images may be passed between nodes.',
|
||||
title: 'Image',
|
||||
},
|
||||
ImagePolymorphic: {
|
||||
color: 'purple.500',
|
||||
description: 'A collection of images.',
|
||||
title: 'Image Polymorphic',
|
||||
},
|
||||
integer: {
|
||||
color: 'red.500',
|
||||
description: 'Integers are whole numbers, without a decimal point.',
|
||||
title: 'Integer',
|
||||
},
|
||||
IntegerCollection: {
|
||||
color: 'red.500',
|
||||
description: 'A collection of integers.',
|
||||
title: 'Integer Collection',
|
||||
},
|
||||
IntegerPolymorphic: {
|
||||
color: 'red.500',
|
||||
description: 'A collection of integers.',
|
||||
title: 'Integer Polymorphic',
|
||||
},
|
||||
LatentsCollection: {
|
||||
color: 'pink.500',
|
||||
description: 'Latents may be passed between nodes.',
|
||||
title: 'Latents Collection',
|
||||
},
|
||||
LatentsField: {
|
||||
color: 'pink.500',
|
||||
description: 'Latents may be passed between nodes.',
|
||||
title: 'Latents',
|
||||
},
|
||||
LatentsPolymorphic: {
|
||||
color: 'pink.500',
|
||||
description: 'Latents may be passed between nodes.',
|
||||
title: 'Latents Polymorphic',
|
||||
},
|
||||
LoRAModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'TODO',
|
||||
title: 'LoRA',
|
||||
},
|
||||
MainModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'TODO',
|
||||
title: 'Model',
|
||||
},
|
||||
ONNXModelField: {
|
||||
color: 'base.500',
|
||||
title: 'ONNX Model',
|
||||
color: 'teal.500',
|
||||
description: 'ONNX model field.',
|
||||
title: 'ONNX Model',
|
||||
},
|
||||
Scheduler: {
|
||||
color: 'base.500',
|
||||
description: 'TODO',
|
||||
title: 'Scheduler',
|
||||
},
|
||||
SDXLMainModelField: {
|
||||
color: 'base.500',
|
||||
title: 'SDXL Model',
|
||||
color: 'teal.500',
|
||||
description: 'SDXL model field.',
|
||||
title: 'SDXL Model',
|
||||
},
|
||||
SDXLRefinerModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'TODO',
|
||||
title: 'Refiner Model',
|
||||
},
|
||||
string: {
|
||||
color: 'yellow.500',
|
||||
description: 'Strings are text.',
|
||||
title: 'String',
|
||||
},
|
||||
StringCollection: {
|
||||
color: 'yellow.500',
|
||||
title: 'String Collection',
|
||||
description: 'A collection of strings.',
|
||||
title: 'String Collection',
|
||||
},
|
||||
StringPolymorphic: {
|
||||
color: 'yellow.500',
|
||||
description: 'A collection of strings.',
|
||||
title: 'String Polymorphic',
|
||||
},
|
||||
UNetField: {
|
||||
color: 'red.500',
|
||||
description: 'UNet submodel.',
|
||||
title: 'UNet',
|
||||
},
|
||||
VaeField: {
|
||||
color: 'blue.500',
|
||||
description: 'Vae submodel.',
|
||||
title: 'Vae',
|
||||
},
|
||||
VaeModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'TODO',
|
||||
title: 'VAE',
|
||||
},
|
||||
};
|
||||
|
@ -11,7 +11,7 @@ import { keyBy } from 'lodash-es';
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
import { RgbaColor } from 'react-colorful';
|
||||
import { Node } from 'reactflow';
|
||||
import { Graph, ImageDTO, _InputField, _OutputField } from 'services/api/types';
|
||||
import { Graph, _InputField, _OutputField } from 'services/api/types';
|
||||
import {
|
||||
AnyInvocationType,
|
||||
AnyResult,
|
||||
@ -52,6 +52,10 @@ export type InvocationTemplate = {
|
||||
* The type of this node's output
|
||||
*/
|
||||
outputType: string; // TODO: generate a union of output types
|
||||
/**
|
||||
* The invocation's version.
|
||||
*/
|
||||
version?: string;
|
||||
};
|
||||
|
||||
export type FieldUIConfig = {
|
||||
@ -62,50 +66,48 @@ export type FieldUIConfig = {
|
||||
|
||||
// TODO: Get this from the OpenAPI schema? may be tricky...
|
||||
export const zFieldType = z.enum([
|
||||
// region Primitives
|
||||
'integer',
|
||||
'float',
|
||||
'boolean',
|
||||
'string',
|
||||
'array',
|
||||
'ImageField',
|
||||
'DenoiseMaskField',
|
||||
'LatentsField',
|
||||
'ConditioningField',
|
||||
'ControlField',
|
||||
'ColorField',
|
||||
'ImageCollection',
|
||||
'ConditioningCollection',
|
||||
'ColorCollection',
|
||||
'LatentsCollection',
|
||||
'IntegerCollection',
|
||||
'FloatCollection',
|
||||
'StringCollection',
|
||||
'BooleanCollection',
|
||||
// endregion
|
||||
|
||||
// region Models
|
||||
'MainModelField',
|
||||
'SDXLMainModelField',
|
||||
'SDXLRefinerModelField',
|
||||
'ONNXModelField',
|
||||
'VaeModelField',
|
||||
'LoRAModelField',
|
||||
'ControlNetModelField',
|
||||
'UNetField',
|
||||
'VaeField',
|
||||
'BooleanPolymorphic',
|
||||
'ClipField',
|
||||
// endregion
|
||||
|
||||
// region Iterate/Collect
|
||||
'Collection',
|
||||
'CollectionItem',
|
||||
// endregion
|
||||
|
||||
// region Misc
|
||||
'ColorCollection',
|
||||
'ColorField',
|
||||
'ColorPolymorphic',
|
||||
'ConditioningCollection',
|
||||
'ConditioningField',
|
||||
'ConditioningPolymorphic',
|
||||
'ControlCollection',
|
||||
'ControlField',
|
||||
'ControlNetModelField',
|
||||
'ControlPolymorphic',
|
||||
'DenoiseMaskField',
|
||||
'enum',
|
||||
'float',
|
||||
'FloatCollection',
|
||||
'FloatPolymorphic',
|
||||
'ImageCollection',
|
||||
'ImageField',
|
||||
'ImagePolymorphic',
|
||||
'integer',
|
||||
'IntegerCollection',
|
||||
'IntegerPolymorphic',
|
||||
'LatentsCollection',
|
||||
'LatentsField',
|
||||
'LatentsPolymorphic',
|
||||
'LoRAModelField',
|
||||
'MainModelField',
|
||||
'ONNXModelField',
|
||||
'Scheduler',
|
||||
// endregion
|
||||
'SDXLMainModelField',
|
||||
'SDXLRefinerModelField',
|
||||
'string',
|
||||
'StringCollection',
|
||||
'StringPolymorphic',
|
||||
'UNetField',
|
||||
'VaeField',
|
||||
'VaeModelField',
|
||||
]);
|
||||
|
||||
export type FieldType = z.infer<typeof zFieldType>;
|
||||
@ -122,38 +124,6 @@ export const isFieldType = (value: unknown): value is FieldType =>
|
||||
zFieldType.safeParse(value).success ||
|
||||
zReservedFieldType.safeParse(value).success;
|
||||
|
||||
/**
|
||||
* An input field template is generated on each page load from the OpenAPI schema.
|
||||
*
|
||||
* The template provides the field type and other field metadata (e.g. title, description,
|
||||
* maximum length, pattern to match, etc).
|
||||
*/
|
||||
export type InputFieldTemplate =
|
||||
| IntegerInputFieldTemplate
|
||||
| FloatInputFieldTemplate
|
||||
| StringInputFieldTemplate
|
||||
| BooleanInputFieldTemplate
|
||||
| ImageInputFieldTemplate
|
||||
| DenoiseMaskInputFieldTemplate
|
||||
| LatentsInputFieldTemplate
|
||||
| ConditioningInputFieldTemplate
|
||||
| UNetInputFieldTemplate
|
||||
| ClipInputFieldTemplate
|
||||
| VaeInputFieldTemplate
|
||||
| ControlInputFieldTemplate
|
||||
| EnumInputFieldTemplate
|
||||
| MainModelInputFieldTemplate
|
||||
| SDXLMainModelInputFieldTemplate
|
||||
| SDXLRefinerModelInputFieldTemplate
|
||||
| VaeModelInputFieldTemplate
|
||||
| LoRAModelInputFieldTemplate
|
||||
| ControlNetModelInputFieldTemplate
|
||||
| CollectionInputFieldTemplate
|
||||
| CollectionItemInputFieldTemplate
|
||||
| ColorInputFieldTemplate
|
||||
| ImageCollectionInputFieldTemplate
|
||||
| SchedulerInputFieldTemplate;
|
||||
|
||||
/**
|
||||
* Indicates the kind of input(s) this field may have.
|
||||
*/
|
||||
@ -232,24 +202,88 @@ export const zIntegerInputFieldValue = zInputFieldValueBase.extend({
|
||||
});
|
||||
export type IntegerInputFieldValue = z.infer<typeof zIntegerInputFieldValue>;
|
||||
|
||||
export const zIntegerCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('IntegerCollection'),
|
||||
value: z.array(z.number().int()).optional(),
|
||||
});
|
||||
export type IntegerCollectionInputFieldValue = z.infer<
|
||||
typeof zIntegerCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zIntegerPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('IntegerPolymorphic'),
|
||||
value: z.union([z.number().int(), z.array(z.number().int())]).optional(),
|
||||
});
|
||||
export type IntegerPolymorphicInputFieldValue = z.infer<
|
||||
typeof zIntegerPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zFloatInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('float'),
|
||||
value: z.number().optional(),
|
||||
});
|
||||
export type FloatInputFieldValue = z.infer<typeof zFloatInputFieldValue>;
|
||||
|
||||
export const zFloatCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('FloatCollection'),
|
||||
value: z.array(z.number()).optional(),
|
||||
});
|
||||
export type FloatCollectionInputFieldValue = z.infer<
|
||||
typeof zFloatCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zFloatPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('FloatPolymorphic'),
|
||||
value: z.union([z.number(), z.array(z.number())]).optional(),
|
||||
});
|
||||
export type FloatPolymorphicInputFieldValue = z.infer<
|
||||
typeof zFloatPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zStringInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('string'),
|
||||
value: z.string().optional(),
|
||||
});
|
||||
export type StringInputFieldValue = z.infer<typeof zStringInputFieldValue>;
|
||||
|
||||
export const zStringCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('StringCollection'),
|
||||
value: z.array(z.string()).optional(),
|
||||
});
|
||||
export type StringCollectionInputFieldValue = z.infer<
|
||||
typeof zStringCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zStringPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('StringPolymorphic'),
|
||||
value: z.union([z.string(), z.array(z.string())]).optional(),
|
||||
});
|
||||
export type StringPolymorphicInputFieldValue = z.infer<
|
||||
typeof zStringPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zBooleanInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('boolean'),
|
||||
value: z.boolean().optional(),
|
||||
});
|
||||
export type BooleanInputFieldValue = z.infer<typeof zBooleanInputFieldValue>;
|
||||
|
||||
export const zBooleanCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('BooleanCollection'),
|
||||
value: z.array(z.boolean()).optional(),
|
||||
});
|
||||
export type BooleanCollectionInputFieldValue = z.infer<
|
||||
typeof zBooleanCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zBooleanPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('BooleanPolymorphic'),
|
||||
value: z.union([z.boolean(), z.array(z.boolean())]).optional(),
|
||||
});
|
||||
export type BooleanPolymorphicInputFieldValue = z.infer<
|
||||
typeof zBooleanPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zEnumInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('enum'),
|
||||
value: z.union([z.string(), z.number()]).optional(),
|
||||
@ -262,6 +296,22 @@ export const zLatentsInputFieldValue = zInputFieldValueBase.extend({
|
||||
});
|
||||
export type LatentsInputFieldValue = z.infer<typeof zLatentsInputFieldValue>;
|
||||
|
||||
export const zLatentsCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('LatentsCollection'),
|
||||
value: z.array(zLatentsField).optional(),
|
||||
});
|
||||
export type LatentsCollectionInputFieldValue = z.infer<
|
||||
typeof zLatentsCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zLatentsPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('LatentsPolymorphic'),
|
||||
value: z.union([zLatentsField, z.array(zLatentsField)]).optional(),
|
||||
});
|
||||
export type LatentsPolymorphicInputFieldValue = z.infer<
|
||||
typeof zLatentsPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zDenoiseMaskInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('DenoiseMaskField'),
|
||||
value: zDenoiseMaskField.optional(),
|
||||
@ -278,6 +328,26 @@ export type ConditioningInputFieldValue = z.infer<
|
||||
typeof zConditioningInputFieldValue
|
||||
>;
|
||||
|
||||
export const zConditioningCollectionInputFieldValue =
|
||||
zInputFieldValueBase.extend({
|
||||
type: z.literal('ConditioningCollection'),
|
||||
value: z.array(zConditioningField).optional(),
|
||||
});
|
||||
export type ConditioningCollectionInputFieldValue = z.infer<
|
||||
typeof zConditioningCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zConditioningPolymorphicInputFieldValue =
|
||||
zInputFieldValueBase.extend({
|
||||
type: z.literal('ConditioningPolymorphic'),
|
||||
value: z
|
||||
.union([zConditioningField, z.array(zConditioningField)])
|
||||
.optional(),
|
||||
});
|
||||
export type ConditioningPolymorphicInputFieldValue = z.infer<
|
||||
typeof zConditioningPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zControlNetModel = zModelIdentifier;
|
||||
export type ControlNetModel = z.infer<typeof zControlNetModel>;
|
||||
|
||||
@ -302,6 +372,22 @@ export const zControlInputFieldValue = zInputFieldValueBase.extend({
|
||||
});
|
||||
export type ControlInputFieldValue = z.infer<typeof zControlInputFieldValue>;
|
||||
|
||||
export const zControlPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('ControlPolymorphic'),
|
||||
value: z.union([zControlField, z.array(zControlField)]).optional(),
|
||||
});
|
||||
export type ControlPolymorphicInputFieldValue = z.infer<
|
||||
typeof zControlPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zControlCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('ControlCollection'),
|
||||
value: z.array(zControlField).optional(),
|
||||
});
|
||||
export type ControlCollectionInputFieldValue = z.infer<
|
||||
typeof zControlCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zModelType = z.enum([
|
||||
'onnx',
|
||||
'main',
|
||||
@ -381,6 +467,14 @@ export const zImageInputFieldValue = zInputFieldValueBase.extend({
|
||||
});
|
||||
export type ImageInputFieldValue = z.infer<typeof zImageInputFieldValue>;
|
||||
|
||||
export const zImagePolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('ImagePolymorphic'),
|
||||
value: z.union([zImageField, z.array(zImageField)]).optional(),
|
||||
});
|
||||
export type ImagePolymorphicInputFieldValue = z.infer<
|
||||
typeof zImagePolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zImageCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('ImageCollection'),
|
||||
value: z.array(zImageField).optional(),
|
||||
@ -473,6 +567,22 @@ export const zColorInputFieldValue = zInputFieldValueBase.extend({
|
||||
});
|
||||
export type ColorInputFieldValue = z.infer<typeof zColorInputFieldValue>;
|
||||
|
||||
export const zColorCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('ColorCollection'),
|
||||
value: z.array(zColorField).optional(),
|
||||
});
|
||||
export type ColorCollectionInputFieldValue = z.infer<
|
||||
typeof zColorCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zColorPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('ColorPolymorphic'),
|
||||
value: z.union([zColorField, z.array(zColorField)]).optional(),
|
||||
});
|
||||
export type ColorPolymorphicInputFieldValue = z.infer<
|
||||
typeof zColorPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zSchedulerInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('Scheduler'),
|
||||
value: zScheduler.optional(),
|
||||
@ -482,30 +592,47 @@ export type SchedulerInputFieldValue = z.infer<
|
||||
>;
|
||||
|
||||
export const zInputFieldValue = z.discriminatedUnion('type', [
|
||||
zIntegerInputFieldValue,
|
||||
zFloatInputFieldValue,
|
||||
zStringInputFieldValue,
|
||||
zBooleanCollectionInputFieldValue,
|
||||
zBooleanInputFieldValue,
|
||||
zImageInputFieldValue,
|
||||
zLatentsInputFieldValue,
|
||||
zDenoiseMaskInputFieldValue,
|
||||
zConditioningInputFieldValue,
|
||||
zUNetInputFieldValue,
|
||||
zBooleanPolymorphicInputFieldValue,
|
||||
zClipInputFieldValue,
|
||||
zVaeInputFieldValue,
|
||||
zControlInputFieldValue,
|
||||
zEnumInputFieldValue,
|
||||
zMainModelInputFieldValue,
|
||||
zSDXLMainModelInputFieldValue,
|
||||
zSDXLRefinerModelInputFieldValue,
|
||||
zVaeModelInputFieldValue,
|
||||
zLoRAModelInputFieldValue,
|
||||
zControlNetModelInputFieldValue,
|
||||
zCollectionInputFieldValue,
|
||||
zCollectionItemInputFieldValue,
|
||||
zColorInputFieldValue,
|
||||
zColorCollectionInputFieldValue,
|
||||
zColorPolymorphicInputFieldValue,
|
||||
zConditioningInputFieldValue,
|
||||
zConditioningCollectionInputFieldValue,
|
||||
zConditioningPolymorphicInputFieldValue,
|
||||
zControlInputFieldValue,
|
||||
zControlNetModelInputFieldValue,
|
||||
zControlCollectionInputFieldValue,
|
||||
zControlPolymorphicInputFieldValue,
|
||||
zDenoiseMaskInputFieldValue,
|
||||
zEnumInputFieldValue,
|
||||
zFloatCollectionInputFieldValue,
|
||||
zFloatInputFieldValue,
|
||||
zFloatPolymorphicInputFieldValue,
|
||||
zImageCollectionInputFieldValue,
|
||||
zImagePolymorphicInputFieldValue,
|
||||
zImageInputFieldValue,
|
||||
zIntegerCollectionInputFieldValue,
|
||||
zIntegerPolymorphicInputFieldValue,
|
||||
zIntegerInputFieldValue,
|
||||
zLatentsInputFieldValue,
|
||||
zLatentsCollectionInputFieldValue,
|
||||
zLatentsPolymorphicInputFieldValue,
|
||||
zLoRAModelInputFieldValue,
|
||||
zMainModelInputFieldValue,
|
||||
zSchedulerInputFieldValue,
|
||||
zSDXLMainModelInputFieldValue,
|
||||
zSDXLRefinerModelInputFieldValue,
|
||||
zStringCollectionInputFieldValue,
|
||||
zStringPolymorphicInputFieldValue,
|
||||
zStringInputFieldValue,
|
||||
zUNetInputFieldValue,
|
||||
zVaeInputFieldValue,
|
||||
zVaeModelInputFieldValue,
|
||||
]);
|
||||
|
||||
export type InputFieldValue = z.infer<typeof zInputFieldValue>;
|
||||
@ -514,7 +641,6 @@ export type InputFieldTemplateBase = {
|
||||
name: string;
|
||||
title: string;
|
||||
description: string;
|
||||
type: FieldType;
|
||||
required: boolean;
|
||||
fieldKind: 'input';
|
||||
} & _InputField;
|
||||
@ -529,6 +655,19 @@ export type IntegerInputFieldTemplate = InputFieldTemplateBase & {
|
||||
exclusiveMinimum?: boolean;
|
||||
};
|
||||
|
||||
export type IntegerCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'IntegerCollection';
|
||||
default: number[];
|
||||
item_default?: number;
|
||||
};
|
||||
|
||||
export type IntegerPolymorphicInputFieldTemplate = Omit<
|
||||
IntegerInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'IntegerPolymorphic';
|
||||
};
|
||||
|
||||
export type FloatInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'float';
|
||||
default: number;
|
||||
@ -539,6 +678,19 @@ export type FloatInputFieldTemplate = InputFieldTemplateBase & {
|
||||
exclusiveMinimum?: boolean;
|
||||
};
|
||||
|
||||
export type FloatCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'FloatCollection';
|
||||
default: number[];
|
||||
item_default?: number;
|
||||
};
|
||||
|
||||
export type FloatPolymorphicInputFieldTemplate = Omit<
|
||||
FloatInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'FloatPolymorphic';
|
||||
};
|
||||
|
||||
export type StringInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'string';
|
||||
default: string;
|
||||
@ -547,19 +699,53 @@ export type StringInputFieldTemplate = InputFieldTemplateBase & {
|
||||
pattern?: string;
|
||||
};
|
||||
|
||||
export type StringCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'StringCollection';
|
||||
default: string[];
|
||||
item_default?: string;
|
||||
};
|
||||
|
||||
export type StringPolymorphicInputFieldTemplate = Omit<
|
||||
StringInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'StringPolymorphic';
|
||||
};
|
||||
|
||||
export type BooleanInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: boolean;
|
||||
type: 'boolean';
|
||||
};
|
||||
|
||||
export type BooleanCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'BooleanCollection';
|
||||
default: boolean[];
|
||||
item_default?: boolean;
|
||||
};
|
||||
|
||||
export type BooleanPolymorphicInputFieldTemplate = Omit<
|
||||
BooleanInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'BooleanPolymorphic';
|
||||
};
|
||||
|
||||
export type ImageInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: ImageDTO;
|
||||
default: ImageField;
|
||||
type: 'ImageField';
|
||||
};
|
||||
|
||||
export type ImageCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: ImageField[];
|
||||
type: 'ImageCollection';
|
||||
item_default?: ImageField;
|
||||
};
|
||||
|
||||
export type ImagePolymorphicInputFieldTemplate = Omit<
|
||||
ImageInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'ImagePolymorphic';
|
||||
};
|
||||
|
||||
export type DenoiseMaskInputFieldTemplate = InputFieldTemplateBase & {
|
||||
@ -568,15 +754,40 @@ export type DenoiseMaskInputFieldTemplate = InputFieldTemplateBase & {
|
||||
};
|
||||
|
||||
export type LatentsInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: string;
|
||||
default: LatentsField;
|
||||
type: 'LatentsField';
|
||||
};
|
||||
|
||||
export type LatentsCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: LatentsField[];
|
||||
type: 'LatentsCollection';
|
||||
item_default?: LatentsField;
|
||||
};
|
||||
|
||||
export type LatentsPolymorphicInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: LatentsField;
|
||||
type: 'LatentsPolymorphic';
|
||||
};
|
||||
|
||||
export type ConditioningInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'ConditioningField';
|
||||
};
|
||||
|
||||
export type ConditioningCollectionInputFieldTemplate =
|
||||
InputFieldTemplateBase & {
|
||||
default: ConditioningField[];
|
||||
type: 'ConditioningCollection';
|
||||
item_default?: ConditioningField;
|
||||
};
|
||||
|
||||
export type ConditioningPolymorphicInputFieldTemplate = Omit<
|
||||
ConditioningInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'ConditioningPolymorphic';
|
||||
};
|
||||
|
||||
export type UNetInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'UNetField';
|
||||
@ -597,6 +808,19 @@ export type ControlInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'ControlField';
|
||||
};
|
||||
|
||||
export type ControlCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'ControlCollection';
|
||||
item_default?: ControlField;
|
||||
};
|
||||
|
||||
export type ControlPolymorphicInputFieldTemplate = Omit<
|
||||
ControlInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'ControlPolymorphic';
|
||||
};
|
||||
|
||||
export type EnumInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: string | number;
|
||||
type: 'enum';
|
||||
@ -649,6 +873,18 @@ export type ColorInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'ColorField';
|
||||
};
|
||||
|
||||
export type ColorPolymorphicInputFieldTemplate = Omit<
|
||||
ColorInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'ColorPolymorphic';
|
||||
};
|
||||
|
||||
export type ColorCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: [];
|
||||
type: 'ColorCollection';
|
||||
};
|
||||
|
||||
export type SchedulerInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: SchedulerParam;
|
||||
type: 'Scheduler';
|
||||
@ -659,6 +895,55 @@ export type WorkflowInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'WorkflowField';
|
||||
};
|
||||
|
||||
/**
|
||||
* An input field template is generated on each page load from the OpenAPI schema.
|
||||
*
|
||||
* The template provides the field type and other field metadata (e.g. title, description,
|
||||
* maximum length, pattern to match, etc).
|
||||
*/
|
||||
export type InputFieldTemplate =
|
||||
| BooleanCollectionInputFieldTemplate
|
||||
| BooleanPolymorphicInputFieldTemplate
|
||||
| BooleanInputFieldTemplate
|
||||
| ClipInputFieldTemplate
|
||||
| CollectionInputFieldTemplate
|
||||
| CollectionItemInputFieldTemplate
|
||||
| ColorInputFieldTemplate
|
||||
| ColorCollectionInputFieldTemplate
|
||||
| ColorPolymorphicInputFieldTemplate
|
||||
| ConditioningInputFieldTemplate
|
||||
| ConditioningCollectionInputFieldTemplate
|
||||
| ConditioningPolymorphicInputFieldTemplate
|
||||
| ControlInputFieldTemplate
|
||||
| ControlCollectionInputFieldTemplate
|
||||
| ControlNetModelInputFieldTemplate
|
||||
| ControlPolymorphicInputFieldTemplate
|
||||
| DenoiseMaskInputFieldTemplate
|
||||
| EnumInputFieldTemplate
|
||||
| FloatCollectionInputFieldTemplate
|
||||
| FloatInputFieldTemplate
|
||||
| FloatPolymorphicInputFieldTemplate
|
||||
| ImageCollectionInputFieldTemplate
|
||||
| ImagePolymorphicInputFieldTemplate
|
||||
| ImageInputFieldTemplate
|
||||
| IntegerCollectionInputFieldTemplate
|
||||
| IntegerPolymorphicInputFieldTemplate
|
||||
| IntegerInputFieldTemplate
|
||||
| LatentsInputFieldTemplate
|
||||
| LatentsCollectionInputFieldTemplate
|
||||
| LatentsPolymorphicInputFieldTemplate
|
||||
| LoRAModelInputFieldTemplate
|
||||
| MainModelInputFieldTemplate
|
||||
| SchedulerInputFieldTemplate
|
||||
| SDXLMainModelInputFieldTemplate
|
||||
| SDXLRefinerModelInputFieldTemplate
|
||||
| StringCollectionInputFieldTemplate
|
||||
| StringPolymorphicInputFieldTemplate
|
||||
| StringInputFieldTemplate
|
||||
| UNetInputFieldTemplate
|
||||
| VaeInputFieldTemplate
|
||||
| VaeModelInputFieldTemplate;
|
||||
|
||||
export const isInputFieldValue = (
|
||||
field?: InputFieldValue | OutputFieldValue
|
||||
): field is InputFieldValue => Boolean(field && field.fieldKind === 'input');
|
||||
@ -681,6 +966,7 @@ export type InvocationSchemaExtra = {
|
||||
title: string;
|
||||
category?: string;
|
||||
tags?: string[];
|
||||
version?: string;
|
||||
properties: Omit<
|
||||
NonNullable<OpenAPIV3.SchemaObject['properties']> &
|
||||
(_InputField | _OutputField),
|
||||
@ -731,8 +1017,22 @@ export type InvocationSchemaObject = (
|
||||
) & { class: 'invocation' };
|
||||
|
||||
export const isSchemaObject = (
|
||||
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject
|
||||
): obj is OpenAPIV3.SchemaObject => !('$ref' in obj);
|
||||
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
|
||||
): obj is OpenAPIV3.SchemaObject => Boolean(obj && !('$ref' in obj));
|
||||
|
||||
export const isArraySchemaObject = (
|
||||
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
|
||||
): obj is OpenAPIV3.ArraySchemaObject =>
|
||||
Boolean(obj && !('$ref' in obj) && obj.type === 'array');
|
||||
|
||||
export const isNonArraySchemaObject = (
|
||||
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
|
||||
): obj is OpenAPIV3.NonArraySchemaObject =>
|
||||
Boolean(obj && !('$ref' in obj) && obj.type !== 'array');
|
||||
|
||||
export const isRefObject = (
|
||||
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
|
||||
): obj is OpenAPIV3.ReferenceObject => Boolean(obj && '$ref' in obj);
|
||||
|
||||
export const isInvocationSchemaObject = (
|
||||
obj:
|
||||
@ -800,6 +1100,29 @@ export const zCoreMetadata = z
|
||||
|
||||
export type CoreMetadata = z.infer<typeof zCoreMetadata>;
|
||||
|
||||
export const zSemVer = z.string().refine((val) => {
|
||||
const [major, minor, patch] = val.split('.');
|
||||
return (
|
||||
major !== undefined &&
|
||||
Number.isInteger(Number(major)) &&
|
||||
minor !== undefined &&
|
||||
Number.isInteger(Number(minor)) &&
|
||||
patch !== undefined &&
|
||||
Number.isInteger(Number(patch))
|
||||
);
|
||||
});
|
||||
|
||||
export const zParsedSemver = zSemVer.transform((val) => {
|
||||
const [major, minor, patch] = val.split('.');
|
||||
return {
|
||||
major: Number(major),
|
||||
minor: Number(minor),
|
||||
patch: Number(patch),
|
||||
};
|
||||
});
|
||||
|
||||
export type SemVer = z.infer<typeof zSemVer>;
|
||||
|
||||
export const zInvocationNodeData = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
// no easy way to build this dynamically, and we don't want to anyways, because this will be used
|
||||
@ -812,6 +1135,7 @@ export const zInvocationNodeData = z.object({
|
||||
notes: z.string(),
|
||||
embedWorkflow: z.boolean(),
|
||||
isIntermediate: z.boolean(),
|
||||
version: zSemVer.optional(),
|
||||
});
|
||||
|
||||
// Massage this to get better type safety while developing
|
||||
@ -900,20 +1224,6 @@ export const zFieldIdentifier = z.object({
|
||||
|
||||
export type FieldIdentifier = z.infer<typeof zFieldIdentifier>;
|
||||
|
||||
export const zSemVer = z.string().refine((val) => {
|
||||
const [major, minor, patch] = val.split('.');
|
||||
return (
|
||||
major !== undefined &&
|
||||
minor !== undefined &&
|
||||
patch !== undefined &&
|
||||
Number.isInteger(Number(major)) &&
|
||||
Number.isInteger(Number(minor)) &&
|
||||
Number.isInteger(Number(patch))
|
||||
);
|
||||
});
|
||||
|
||||
export type SemVer = z.infer<typeof zSemVer>;
|
||||
|
||||
export type WorkflowWarning = {
|
||||
message: string;
|
||||
issues: string[];
|
||||
|
@ -1,5 +1,14 @@
|
||||
import { isBoolean, isInteger, isNumber, isString } from 'lodash-es';
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
import {
|
||||
COLLECTION_MAP,
|
||||
POLYMORPHIC_TYPES,
|
||||
SINGLE_TO_POLYMORPHIC_MAP,
|
||||
isCollectionItemType,
|
||||
isPolymorphicItemType,
|
||||
} from '../types/constants';
|
||||
import {
|
||||
BooleanCollectionInputFieldTemplate,
|
||||
BooleanInputFieldTemplate,
|
||||
ClipInputFieldTemplate,
|
||||
CollectionInputFieldTemplate,
|
||||
@ -11,10 +20,13 @@ import {
|
||||
DenoiseMaskInputFieldTemplate,
|
||||
EnumInputFieldTemplate,
|
||||
FieldType,
|
||||
FloatCollectionInputFieldTemplate,
|
||||
FloatPolymorphicInputFieldTemplate,
|
||||
FloatInputFieldTemplate,
|
||||
ImageCollectionInputFieldTemplate,
|
||||
ImageInputFieldTemplate,
|
||||
InputFieldTemplateBase,
|
||||
IntegerCollectionInputFieldTemplate,
|
||||
IntegerInputFieldTemplate,
|
||||
InvocationFieldSchema,
|
||||
InvocationSchemaObject,
|
||||
@ -24,11 +36,32 @@ import {
|
||||
SDXLMainModelInputFieldTemplate,
|
||||
SDXLRefinerModelInputFieldTemplate,
|
||||
SchedulerInputFieldTemplate,
|
||||
StringCollectionInputFieldTemplate,
|
||||
StringInputFieldTemplate,
|
||||
UNetInputFieldTemplate,
|
||||
VaeInputFieldTemplate,
|
||||
VaeModelInputFieldTemplate,
|
||||
isArraySchemaObject,
|
||||
isNonArraySchemaObject,
|
||||
isRefObject,
|
||||
isSchemaObject,
|
||||
ControlPolymorphicInputFieldTemplate,
|
||||
ColorPolymorphicInputFieldTemplate,
|
||||
ColorCollectionInputFieldTemplate,
|
||||
IntegerPolymorphicInputFieldTemplate,
|
||||
StringPolymorphicInputFieldTemplate,
|
||||
BooleanPolymorphicInputFieldTemplate,
|
||||
ImagePolymorphicInputFieldTemplate,
|
||||
LatentsPolymorphicInputFieldTemplate,
|
||||
LatentsCollectionInputFieldTemplate,
|
||||
ConditioningPolymorphicInputFieldTemplate,
|
||||
ConditioningCollectionInputFieldTemplate,
|
||||
ControlCollectionInputFieldTemplate,
|
||||
ImageField,
|
||||
LatentsField,
|
||||
ConditioningField,
|
||||
} from '../types/types';
|
||||
import { ControlField } from 'services/api/types';
|
||||
|
||||
export type BaseFieldProperties = 'name' | 'title' | 'description';
|
||||
|
||||
@ -45,15 +78,8 @@ export type BuildInputFieldArg = {
|
||||
* @example
|
||||
* refObjectToFieldType({ "$ref": "#/components/schemas/ImageField" }) --> 'ImageField'
|
||||
*/
|
||||
export const refObjectToFieldType = (
|
||||
refObject: OpenAPIV3.ReferenceObject
|
||||
): FieldType => {
|
||||
const name = refObject.$ref.split('/').slice(-1)[0];
|
||||
if (!name) {
|
||||
throw `Unknown field type: ${name}`;
|
||||
}
|
||||
return name as FieldType;
|
||||
};
|
||||
export const refObjectToSchemaName = (refObject: OpenAPIV3.ReferenceObject) =>
|
||||
refObject.$ref.split('/').slice(-1)[0];
|
||||
|
||||
const buildIntegerInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
@ -88,6 +114,57 @@ const buildIntegerInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildIntegerPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): IntegerPolymorphicInputFieldTemplate => {
|
||||
const template: IntegerPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'IntegerPolymorphic',
|
||||
default: schemaObject.default ?? 0,
|
||||
};
|
||||
|
||||
if (schemaObject.multipleOf !== undefined) {
|
||||
template.multipleOf = schemaObject.multipleOf;
|
||||
}
|
||||
|
||||
if (schemaObject.maximum !== undefined) {
|
||||
template.maximum = schemaObject.maximum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMaximum !== undefined) {
|
||||
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
||||
}
|
||||
|
||||
if (schemaObject.minimum !== undefined) {
|
||||
template.minimum = schemaObject.minimum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMinimum !== undefined) {
|
||||
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
||||
}
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildIntegerCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): IntegerCollectionInputFieldTemplate => {
|
||||
const item_default =
|
||||
isNumber(schemaObject.item_default) && isInteger(schemaObject.item_default)
|
||||
? schemaObject.item_default
|
||||
: 0;
|
||||
const template: IntegerCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'IntegerCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
item_default,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFloatInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -121,6 +198,54 @@ const buildFloatInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFloatPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): FloatPolymorphicInputFieldTemplate => {
|
||||
const template: FloatPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'FloatPolymorphic',
|
||||
default: schemaObject.default ?? 0,
|
||||
};
|
||||
if (schemaObject.multipleOf !== undefined) {
|
||||
template.multipleOf = schemaObject.multipleOf;
|
||||
}
|
||||
|
||||
if (schemaObject.maximum !== undefined) {
|
||||
template.maximum = schemaObject.maximum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMaximum !== undefined) {
|
||||
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
||||
}
|
||||
|
||||
if (schemaObject.minimum !== undefined) {
|
||||
template.minimum = schemaObject.minimum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMinimum !== undefined) {
|
||||
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
||||
}
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFloatCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): FloatCollectionInputFieldTemplate => {
|
||||
const item_default = isNumber(schemaObject.item_default)
|
||||
? schemaObject.item_default
|
||||
: 0;
|
||||
const template: FloatCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'FloatCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
item_default,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildStringInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -146,6 +271,48 @@ const buildStringInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildStringPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): StringPolymorphicInputFieldTemplate => {
|
||||
const template: StringPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'StringPolymorphic',
|
||||
default: schemaObject.default ?? '',
|
||||
};
|
||||
|
||||
if (schemaObject.minLength !== undefined) {
|
||||
template.minLength = schemaObject.minLength;
|
||||
}
|
||||
|
||||
if (schemaObject.maxLength !== undefined) {
|
||||
template.maxLength = schemaObject.maxLength;
|
||||
}
|
||||
|
||||
if (schemaObject.pattern !== undefined) {
|
||||
template.pattern = schemaObject.pattern;
|
||||
}
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildStringCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): StringCollectionInputFieldTemplate => {
|
||||
const item_default = isString(schemaObject.item_default)
|
||||
? schemaObject.item_default
|
||||
: '';
|
||||
const template: StringCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'StringCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
item_default,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildBooleanInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -159,6 +326,37 @@ const buildBooleanInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildBooleanPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): BooleanPolymorphicInputFieldTemplate => {
|
||||
const template: BooleanPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'BooleanPolymorphic',
|
||||
default: schemaObject.default ?? false,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildBooleanCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): BooleanCollectionInputFieldTemplate => {
|
||||
const item_default =
|
||||
schemaObject.item_default && isBoolean(schemaObject.item_default)
|
||||
? schemaObject.item_default
|
||||
: false;
|
||||
const template: BooleanCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'BooleanCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
item_default,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildMainModelInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -250,6 +448,19 @@ const buildImageInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildImagePolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ImagePolymorphicInputFieldTemplate => {
|
||||
const template: ImagePolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'ImagePolymorphic',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildImageCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -257,7 +468,8 @@ const buildImageCollectionInputFieldTemplate = ({
|
||||
const template: ImageCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'ImageCollection',
|
||||
default: schemaObject.default ?? undefined,
|
||||
default: schemaObject.default ?? [],
|
||||
item_default: (schemaObject.item_default as ImageField) ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
@ -289,6 +501,33 @@ const buildLatentsInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildLatentsPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): LatentsPolymorphicInputFieldTemplate => {
|
||||
const template: LatentsPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'LatentsPolymorphic',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildLatentsCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): LatentsCollectionInputFieldTemplate => {
|
||||
const template: LatentsCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'LatentsCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
item_default: (schemaObject.item_default as LatentsField) ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildConditioningInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -302,6 +541,33 @@ const buildConditioningInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildConditioningPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ConditioningPolymorphicInputFieldTemplate => {
|
||||
const template: ConditioningPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'ConditioningPolymorphic',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildConditioningCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ConditioningCollectionInputFieldTemplate => {
|
||||
const template: ConditioningCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'ConditioningCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
item_default: (schemaObject.item_default as ConditioningField) ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildUNetInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -355,6 +621,33 @@ const buildControlInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildControlPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ControlPolymorphicInputFieldTemplate => {
|
||||
const template: ControlPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'ControlPolymorphic',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildControlCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ControlCollectionInputFieldTemplate => {
|
||||
const template: ControlCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'ControlCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
item_default: (schemaObject.item_default as ControlField) ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildEnumInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -408,6 +701,32 @@ const buildColorInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildColorPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ColorPolymorphicInputFieldTemplate => {
|
||||
const template: ColorPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'ColorPolymorphic',
|
||||
default: schemaObject.default ?? { r: 127, g: 127, b: 127, a: 255 },
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildColorCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ColorCollectionInputFieldTemplate => {
|
||||
const template: ColorCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'ColorCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildSchedulerInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -421,45 +740,138 @@ const buildSchedulerInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
export const getFieldType = (schemaObject: InvocationFieldSchema): string => {
|
||||
let fieldType = '';
|
||||
|
||||
const { ui_type } = schemaObject;
|
||||
if (ui_type) {
|
||||
fieldType = ui_type;
|
||||
export const getFieldType = (
|
||||
schemaObject: InvocationFieldSchema
|
||||
): string | undefined => {
|
||||
if (schemaObject?.ui_type) {
|
||||
return schemaObject.ui_type;
|
||||
} else if (!schemaObject.type) {
|
||||
// console.log('refObject', schemaObject);
|
||||
// if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
|
||||
|
||||
if (schemaObject.allOf) {
|
||||
fieldType = refObjectToFieldType(
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
schemaObject.allOf![0] as OpenAPIV3.ReferenceObject
|
||||
);
|
||||
const allOf = schemaObject.allOf;
|
||||
if (allOf && allOf[0] && isRefObject(allOf[0])) {
|
||||
return refObjectToSchemaName(allOf[0]);
|
||||
}
|
||||
} else if (schemaObject.anyOf) {
|
||||
fieldType = refObjectToFieldType(
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
schemaObject.anyOf![0] as OpenAPIV3.ReferenceObject
|
||||
);
|
||||
} else if (schemaObject.oneOf) {
|
||||
fieldType = refObjectToFieldType(
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
schemaObject.oneOf![0] as OpenAPIV3.ReferenceObject
|
||||
);
|
||||
const anyOf = schemaObject.anyOf;
|
||||
/**
|
||||
* Handle Polymorphic inputs, eg string | string[]. In OpenAPI, this is:
|
||||
* - an `anyOf` with two items
|
||||
* - one is an `ArraySchemaObject` with a single `SchemaObject or ReferenceObject` of type T in its `items`
|
||||
* - the other is a `SchemaObject` or `ReferenceObject` of type T
|
||||
*
|
||||
* Any other cases we ignore.
|
||||
*/
|
||||
|
||||
let firstType: string | undefined;
|
||||
let secondType: string | undefined;
|
||||
|
||||
if (isArraySchemaObject(anyOf[0])) {
|
||||
// first is array, second is not
|
||||
const first = anyOf[0].items;
|
||||
const second = anyOf[1];
|
||||
if (isRefObject(first) && isRefObject(second)) {
|
||||
firstType = refObjectToSchemaName(first);
|
||||
secondType = refObjectToSchemaName(second);
|
||||
} else if (
|
||||
isNonArraySchemaObject(first) &&
|
||||
isNonArraySchemaObject(second)
|
||||
) {
|
||||
firstType = first.type;
|
||||
secondType = second.type;
|
||||
}
|
||||
} else if (isArraySchemaObject(anyOf[1])) {
|
||||
// first is not array, second is
|
||||
const first = anyOf[0];
|
||||
const second = anyOf[1].items;
|
||||
if (isRefObject(first) && isRefObject(second)) {
|
||||
firstType = refObjectToSchemaName(first);
|
||||
secondType = refObjectToSchemaName(second);
|
||||
} else if (
|
||||
isNonArraySchemaObject(first) &&
|
||||
isNonArraySchemaObject(second)
|
||||
) {
|
||||
firstType = first.type;
|
||||
secondType = second.type;
|
||||
}
|
||||
}
|
||||
if (firstType === secondType && isPolymorphicItemType(firstType)) {
|
||||
return SINGLE_TO_POLYMORPHIC_MAP[firstType];
|
||||
}
|
||||
}
|
||||
} else if (schemaObject.enum) {
|
||||
fieldType = 'enum';
|
||||
return 'enum';
|
||||
} else if (schemaObject.type) {
|
||||
if (schemaObject.type === 'number') {
|
||||
// floats are "number" in OpenAPI, while ints are "integer"
|
||||
fieldType = 'float';
|
||||
// floats are "number" in OpenAPI, while ints are "integer" - we need to distinguish them
|
||||
return 'float';
|
||||
} else if (schemaObject.type === 'array') {
|
||||
const itemType = isSchemaObject(schemaObject.items)
|
||||
? schemaObject.items.type
|
||||
: refObjectToSchemaName(schemaObject.items);
|
||||
|
||||
if (isCollectionItemType(itemType)) {
|
||||
return COLLECTION_MAP[itemType];
|
||||
}
|
||||
|
||||
return;
|
||||
} else {
|
||||
fieldType = schemaObject.type;
|
||||
return schemaObject.type;
|
||||
}
|
||||
}
|
||||
|
||||
return fieldType;
|
||||
return;
|
||||
};
|
||||
|
||||
const TEMPLATE_BUILDER_MAP = {
|
||||
boolean: buildBooleanInputFieldTemplate,
|
||||
BooleanCollection: buildBooleanCollectionInputFieldTemplate,
|
||||
BooleanPolymorphic: buildBooleanPolymorphicInputFieldTemplate,
|
||||
ClipField: buildClipInputFieldTemplate,
|
||||
Collection: buildCollectionInputFieldTemplate,
|
||||
CollectionItem: buildCollectionItemInputFieldTemplate,
|
||||
ColorCollection: buildColorCollectionInputFieldTemplate,
|
||||
ColorField: buildColorInputFieldTemplate,
|
||||
ColorPolymorphic: buildColorPolymorphicInputFieldTemplate,
|
||||
ConditioningCollection: buildConditioningCollectionInputFieldTemplate,
|
||||
ConditioningField: buildConditioningInputFieldTemplate,
|
||||
ConditioningPolymorphic: buildConditioningPolymorphicInputFieldTemplate,
|
||||
ControlCollection: buildControlCollectionInputFieldTemplate,
|
||||
ControlField: buildControlInputFieldTemplate,
|
||||
ControlNetModelField: buildControlNetModelInputFieldTemplate,
|
||||
ControlPolymorphic: buildControlPolymorphicInputFieldTemplate,
|
||||
DenoiseMaskField: buildDenoiseMaskInputFieldTemplate,
|
||||
enum: buildEnumInputFieldTemplate,
|
||||
float: buildFloatInputFieldTemplate,
|
||||
FloatCollection: buildFloatCollectionInputFieldTemplate,
|
||||
FloatPolymorphic: buildFloatPolymorphicInputFieldTemplate,
|
||||
ImageCollection: buildImageCollectionInputFieldTemplate,
|
||||
ImageField: buildImageInputFieldTemplate,
|
||||
ImagePolymorphic: buildImagePolymorphicInputFieldTemplate,
|
||||
integer: buildIntegerInputFieldTemplate,
|
||||
IntegerCollection: buildIntegerCollectionInputFieldTemplate,
|
||||
IntegerPolymorphic: buildIntegerPolymorphicInputFieldTemplate,
|
||||
LatentsCollection: buildLatentsCollectionInputFieldTemplate,
|
||||
LatentsField: buildLatentsInputFieldTemplate,
|
||||
LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate,
|
||||
LoRAModelField: buildLoRAModelInputFieldTemplate,
|
||||
MainModelField: buildMainModelInputFieldTemplate,
|
||||
Scheduler: buildSchedulerInputFieldTemplate,
|
||||
SDXLMainModelField: buildSDXLMainModelInputFieldTemplate,
|
||||
SDXLRefinerModelField: buildRefinerModelInputFieldTemplate,
|
||||
string: buildStringInputFieldTemplate,
|
||||
StringCollection: buildStringCollectionInputFieldTemplate,
|
||||
StringPolymorphic: buildStringPolymorphicInputFieldTemplate,
|
||||
UNetField: buildUNetInputFieldTemplate,
|
||||
VaeField: buildVaeInputFieldTemplate,
|
||||
VaeModelField: buildVaeModelInputFieldTemplate,
|
||||
};
|
||||
|
||||
const isTemplatedFieldType = (
|
||||
fieldType: string | undefined
|
||||
): fieldType is keyof typeof TEMPLATE_BUILDER_MAP =>
|
||||
Boolean(fieldType && fieldType in TEMPLATE_BUILDER_MAP);
|
||||
|
||||
/**
|
||||
* Builds an input field from an invocation schema property.
|
||||
* @param fieldSchema The schema object
|
||||
@ -474,7 +886,8 @@ export const buildInputFieldTemplate = (
|
||||
const { input, ui_hidden, ui_component, ui_type, ui_order } = fieldSchema;
|
||||
|
||||
const extra = {
|
||||
input,
|
||||
// TODO: Can we support polymorphic inputs in the UI?
|
||||
input: POLYMORPHIC_TYPES.includes(fieldType) ? 'connection' : input,
|
||||
ui_hidden,
|
||||
ui_component,
|
||||
ui_type,
|
||||
@ -490,146 +903,12 @@ export const buildInputFieldTemplate = (
|
||||
...extra,
|
||||
};
|
||||
|
||||
if (fieldType === 'ImageField') {
|
||||
return buildImageInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
if (!isTemplatedFieldType(fieldType)) {
|
||||
return;
|
||||
}
|
||||
if (fieldType === 'ImageCollection') {
|
||||
return buildImageCollectionInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'DenoiseMaskField') {
|
||||
return buildDenoiseMaskInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'LatentsField') {
|
||||
return buildLatentsInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'ConditioningField') {
|
||||
return buildConditioningInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'UNetField') {
|
||||
return buildUNetInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'ClipField') {
|
||||
return buildClipInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'VaeField') {
|
||||
return buildVaeInputFieldTemplate({ schemaObject: fieldSchema, baseField });
|
||||
}
|
||||
if (fieldType === 'ControlField') {
|
||||
return buildControlInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'MainModelField') {
|
||||
return buildMainModelInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'SDXLRefinerModelField') {
|
||||
return buildRefinerModelInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'SDXLMainModelField') {
|
||||
return buildSDXLMainModelInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'VaeModelField') {
|
||||
return buildVaeModelInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'LoRAModelField') {
|
||||
return buildLoRAModelInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'ControlNetModelField') {
|
||||
return buildControlNetModelInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'enum') {
|
||||
return buildEnumInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'integer') {
|
||||
return buildIntegerInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'float') {
|
||||
return buildFloatInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'string') {
|
||||
return buildStringInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'boolean') {
|
||||
return buildBooleanInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'Collection') {
|
||||
return buildCollectionInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'CollectionItem') {
|
||||
return buildCollectionItemInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'ColorField') {
|
||||
return buildColorInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'Scheduler') {
|
||||
return buildSchedulerInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
return;
|
||||
|
||||
return TEMPLATE_BUILDER_MAP[fieldType]({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
};
|
||||
|
@ -1,104 +1,79 @@
|
||||
import { InputFieldTemplate, InputFieldValue } from '../types/types';
|
||||
|
||||
const FIELD_VALUE_FALLBACK_MAP = {
|
||||
'enum.number': 0,
|
||||
'enum.string': '',
|
||||
boolean: false,
|
||||
BooleanCollection: [],
|
||||
BooleanPolymorphic: false,
|
||||
ClipField: undefined,
|
||||
Collection: [],
|
||||
CollectionItem: undefined,
|
||||
ColorCollection: [],
|
||||
ColorField: undefined,
|
||||
ColorPolymorphic: undefined,
|
||||
ConditioningCollection: [],
|
||||
ConditioningField: undefined,
|
||||
ConditioningPolymorphic: undefined,
|
||||
ControlCollection: [],
|
||||
ControlField: undefined,
|
||||
ControlNetModelField: undefined,
|
||||
ControlPolymorphic: undefined,
|
||||
DenoiseMaskField: undefined,
|
||||
float: 0,
|
||||
FloatCollection: [],
|
||||
FloatPolymorphic: 0,
|
||||
ImageCollection: [],
|
||||
ImageField: undefined,
|
||||
ImagePolymorphic: undefined,
|
||||
integer: 0,
|
||||
IntegerCollection: [],
|
||||
IntegerPolymorphic: 0,
|
||||
LatentsCollection: [],
|
||||
LatentsField: undefined,
|
||||
LatentsPolymorphic: undefined,
|
||||
LoRAModelField: undefined,
|
||||
MainModelField: undefined,
|
||||
ONNXModelField: undefined,
|
||||
Scheduler: 'euler',
|
||||
SDXLMainModelField: undefined,
|
||||
SDXLRefinerModelField: undefined,
|
||||
string: '',
|
||||
StringCollection: [],
|
||||
StringPolymorphic: '',
|
||||
UNetField: undefined,
|
||||
VaeField: undefined,
|
||||
VaeModelField: undefined,
|
||||
};
|
||||
|
||||
export const buildInputFieldValue = (
|
||||
id: string,
|
||||
template: InputFieldTemplate
|
||||
): InputFieldValue => {
|
||||
const fieldValue: InputFieldValue = {
|
||||
// TODO: this should be `fieldValue: InputFieldValue`, but that introduces a TS issue I couldn't
|
||||
// resolve - for some reason, it doesn't like `template.type`, which is the discriminant for both
|
||||
// `InputFieldTemplate` union. It is (type-structurally) equal to the discriminant for the
|
||||
// `InputFieldValue` union, but TS doesn't seem to like it...
|
||||
const fieldValue = {
|
||||
id,
|
||||
name: template.name,
|
||||
type: template.type,
|
||||
label: '',
|
||||
fieldKind: 'input',
|
||||
};
|
||||
|
||||
if (template.type === 'string') {
|
||||
fieldValue.value = template.default ?? '';
|
||||
}
|
||||
|
||||
if (template.type === 'integer') {
|
||||
fieldValue.value = template.default ?? 0;
|
||||
}
|
||||
|
||||
if (template.type === 'float') {
|
||||
fieldValue.value = template.default ?? 0;
|
||||
}
|
||||
|
||||
if (template.type === 'boolean') {
|
||||
fieldValue.value = template.default ?? false;
|
||||
}
|
||||
} as InputFieldValue;
|
||||
|
||||
if (template.type === 'enum') {
|
||||
if (template.enumType === 'number') {
|
||||
fieldValue.value = template.default ?? 0;
|
||||
fieldValue.value =
|
||||
template.default ?? FIELD_VALUE_FALLBACK_MAP['enum.number'];
|
||||
}
|
||||
if (template.enumType === 'string') {
|
||||
fieldValue.value = template.default ?? '';
|
||||
fieldValue.value =
|
||||
template.default ?? FIELD_VALUE_FALLBACK_MAP['enum.string'];
|
||||
}
|
||||
}
|
||||
|
||||
if (template.type === 'Collection') {
|
||||
fieldValue.value = template.default ?? 1;
|
||||
}
|
||||
|
||||
if (template.type === 'ImageField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'ImageCollection') {
|
||||
fieldValue.value = [];
|
||||
}
|
||||
|
||||
if (template.type === 'DenoiseMaskField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'LatentsField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'ConditioningField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'UNetField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'ClipField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'VaeField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'ControlField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'MainModelField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'SDXLRefinerModelField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'VaeModelField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'LoRAModelField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'ControlNetModelField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'Scheduler') {
|
||||
fieldValue.value = 'euler';
|
||||
} else {
|
||||
fieldValue.value =
|
||||
template.default ?? FIELD_VALUE_FALLBACK_MAP[template.type];
|
||||
}
|
||||
|
||||
return fieldValue;
|
||||
|
@ -73,6 +73,7 @@ export const parseSchema = (
|
||||
const title = schema.title.replace('Invocation', '');
|
||||
const tags = schema.tags ?? [];
|
||||
const description = schema.description ?? '';
|
||||
const version = schema.version ?? '';
|
||||
|
||||
const inputs = reduce(
|
||||
schema.properties,
|
||||
@ -225,11 +226,12 @@ export const parseSchema = (
|
||||
const invocation: InvocationTemplate = {
|
||||
title,
|
||||
type,
|
||||
version,
|
||||
tags,
|
||||
description,
|
||||
outputType,
|
||||
inputs,
|
||||
outputs,
|
||||
outputType,
|
||||
};
|
||||
|
||||
Object.assign(invocationsAccumulator, { [type]: invocation });
|
||||
|
@ -0,0 +1,96 @@
|
||||
import { compareVersions } from 'compare-versions';
|
||||
import { cloneDeep, keyBy } from 'lodash-es';
|
||||
import {
|
||||
InvocationTemplate,
|
||||
Workflow,
|
||||
WorkflowWarning,
|
||||
isWorkflowInvocationNode,
|
||||
} from '../types/types';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
|
||||
export const validateWorkflow = (
|
||||
workflow: Workflow,
|
||||
nodeTemplates: Record<string, InvocationTemplate>
|
||||
) => {
|
||||
const clone = cloneDeep(workflow);
|
||||
const { nodes, edges } = clone;
|
||||
const errors: WorkflowWarning[] = [];
|
||||
const invocationNodes = nodes.filter(isWorkflowInvocationNode);
|
||||
const keyedNodes = keyBy(invocationNodes, 'id');
|
||||
nodes.forEach((node) => {
|
||||
if (!isWorkflowInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const nodeTemplate = nodeTemplates[node.data.type];
|
||||
if (!nodeTemplate) {
|
||||
errors.push({
|
||||
message: `Node "${node.data.type}" skipped`,
|
||||
issues: [`Node type "${node.data.type}" does not exist`],
|
||||
data: node,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (
|
||||
nodeTemplate.version &&
|
||||
node.data.version &&
|
||||
compareVersions(nodeTemplate.version, node.data.version) !== 0
|
||||
) {
|
||||
errors.push({
|
||||
message: `Node "${node.data.type}" has mismatched version`,
|
||||
issues: [
|
||||
`Node "${node.data.type}" v${node.data.version} may be incompatible with installed v${nodeTemplate.version}`,
|
||||
],
|
||||
data: { node, nodeTemplate: parseify(nodeTemplate) },
|
||||
});
|
||||
return;
|
||||
}
|
||||
});
|
||||
edges.forEach((edge, i) => {
|
||||
const sourceNode = keyedNodes[edge.source];
|
||||
const targetNode = keyedNodes[edge.target];
|
||||
const issues: string[] = [];
|
||||
if (!sourceNode) {
|
||||
issues.push(`Output node ${edge.source} does not exist`);
|
||||
} else if (
|
||||
edge.type === 'default' &&
|
||||
!(edge.sourceHandle in sourceNode.data.outputs)
|
||||
) {
|
||||
issues.push(
|
||||
`Output field "${edge.source}.${edge.sourceHandle}" does not exist`
|
||||
);
|
||||
}
|
||||
if (!targetNode) {
|
||||
issues.push(`Input node ${edge.target} does not exist`);
|
||||
} else if (
|
||||
edge.type === 'default' &&
|
||||
!(edge.targetHandle in targetNode.data.inputs)
|
||||
) {
|
||||
issues.push(
|
||||
`Input field "${edge.target}.${edge.targetHandle}" does not exist`
|
||||
);
|
||||
}
|
||||
if (!nodeTemplates[sourceNode?.data.type ?? '__UNKNOWN_NODE_TYPE__']) {
|
||||
issues.push(
|
||||
`Source node "${edge.source}" missing template "${sourceNode?.data.type}"`
|
||||
);
|
||||
}
|
||||
if (!nodeTemplates[targetNode?.data.type ?? '__UNKNOWN_NODE_TYPE__']) {
|
||||
issues.push(
|
||||
`Source node "${edge.target}" missing template "${targetNode?.data.type}"`
|
||||
);
|
||||
}
|
||||
if (issues.length) {
|
||||
delete edges[i];
|
||||
const src = edge.type === 'default' ? edge.sourceHandle : edge.source;
|
||||
const tgt = edge.type === 'default' ? edge.targetHandle : edge.target;
|
||||
errors.push({
|
||||
message: `Edge "${src} -> ${tgt}" skipped`,
|
||||
issues,
|
||||
data: edge,
|
||||
});
|
||||
}
|
||||
});
|
||||
return { workflow: clone, errors };
|
||||
};
|
File diff suppressed because one or more lines are too long
@ -2970,6 +2970,11 @@ commondir@^1.0.1:
|
||||
resolved "https://registry.yarnpkg.com/commondir/-/commondir-1.0.1.tgz#ddd800da0c66127393cca5950ea968a3aaf1253b"
|
||||
integrity sha512-W9pAhw0ja1Edb5GVdIF1mjZw/ASI0AlShXM83UUGe2DVr5TdAPEA1OA8m/g8zWp9x6On7gqufY+FatDbC3MDQg==
|
||||
|
||||
compare-versions@^6.1.0:
|
||||
version "6.1.0"
|
||||
resolved "https://registry.yarnpkg.com/compare-versions/-/compare-versions-6.1.0.tgz#3f2131e3ae93577df111dba133e6db876ffe127a"
|
||||
integrity sha512-LNZQXhqUvqUTotpZ00qLSaify3b4VFD588aRr8MKFw4CMUr98ytzCW5wDH5qx/DEY5kCDXcbcRuCqL0szEf2tg==
|
||||
|
||||
compute-scroll-into-view@1.0.20:
|
||||
version "1.0.20"
|
||||
resolved "https://registry.yarnpkg.com/compute-scroll-into-view/-/compute-scroll-into-view-1.0.20.tgz#1768b5522d1172754f5d0c9b02de3af6be506a43"
|
||||
|
@ -74,6 +74,7 @@ dependencies = [
|
||||
"rich~=13.3",
|
||||
"safetensors==0.3.1",
|
||||
"scikit-image~=0.21.0",
|
||||
"semver~=3.0.1",
|
||||
"send2trash",
|
||||
"test-tube~=0.7.5",
|
||||
"torch~=2.0.1",
|
||||
|
@ -1,3 +1,10 @@
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvalidVersionError,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from .test_nodes import (
|
||||
ImageToImageTestInvocation,
|
||||
TextToImageTestInvocation,
|
||||
@ -20,7 +27,7 @@ from invokeai.app.invocations.upscale import ESRGANInvocation
|
||||
|
||||
from invokeai.app.invocations.image import ShowImageInvocation
|
||||
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
||||
from invokeai.app.invocations.primitives import IntegerInvocation
|
||||
from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation
|
||||
from invokeai.app.services.default_graphs import create_text_to_image
|
||||
import pytest
|
||||
|
||||
@ -610,6 +617,79 @@ def test_graph_can_deserialize():
|
||||
assert g2.edges[0].destination.field == "image"
|
||||
|
||||
|
||||
def test_invocation_decorator():
|
||||
invocation_type = "test_invocation"
|
||||
title = "Test Invocation"
|
||||
tags = ["first", "second", "third"]
|
||||
category = "category"
|
||||
version = "1.2.3"
|
||||
|
||||
@invocation(invocation_type, title=title, tags=tags, category=category, version=version)
|
||||
class TestInvocation(BaseInvocation):
|
||||
def invoke(self):
|
||||
pass
|
||||
|
||||
schema = TestInvocation.schema()
|
||||
|
||||
assert schema.get("title") == title
|
||||
assert schema.get("tags") == tags
|
||||
assert schema.get("category") == category
|
||||
assert schema.get("version") == version
|
||||
assert TestInvocation(id="1").type == invocation_type # type: ignore (type is dynamically added)
|
||||
|
||||
|
||||
def test_invocation_version_must_be_semver():
|
||||
invocation_type = "test_invocation"
|
||||
valid_version = "1.0.0"
|
||||
invalid_version = "not_semver"
|
||||
|
||||
@invocation(invocation_type, version=valid_version)
|
||||
class ValidVersionInvocation(BaseInvocation):
|
||||
def invoke(self):
|
||||
pass
|
||||
|
||||
with pytest.raises(InvalidVersionError):
|
||||
|
||||
@invocation(invocation_type, version=invalid_version)
|
||||
class InvalidVersionInvocation(BaseInvocation):
|
||||
def invoke(self):
|
||||
pass
|
||||
|
||||
|
||||
def test_invocation_output_decorator():
|
||||
output_type = "test_output"
|
||||
|
||||
@invocation_output(output_type)
|
||||
class TestOutput(BaseInvocationOutput):
|
||||
pass
|
||||
|
||||
assert TestOutput().type == output_type # type: ignore (type is dynamically added)
|
||||
|
||||
|
||||
def test_floats_accept_ints():
|
||||
g = Graph()
|
||||
n1 = IntegerInvocation(id="1", value=1)
|
||||
n2 = FloatInvocation(id="2")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
e = create_edge(n1.id, "value", n2.id, "value")
|
||||
|
||||
# Not throwing on this line is sufficient
|
||||
g.add_edge(e)
|
||||
|
||||
|
||||
def test_ints_do_not_accept_floats():
|
||||
g = Graph()
|
||||
n1 = FloatInvocation(id="1", value=1.0)
|
||||
n2 = IntegerInvocation(id="2")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
e = create_edge(n1.id, "value", n2.id, "value")
|
||||
|
||||
with pytest.raises(InvalidEdgeError):
|
||||
g.add_edge(e)
|
||||
|
||||
|
||||
def test_graph_can_generate_schema():
|
||||
# Not throwing on this line is sufficient
|
||||
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
|
||||
|
Loading…
Reference in New Issue
Block a user