mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into textfontimage
This commit is contained in:
commit
eb90ea41fd
@ -244,8 +244,12 @@ copy-paste the template above.
|
|||||||
We can use the `@invocation` decorator to provide some additional info to the
|
We can use the `@invocation` decorator to provide some additional info to the
|
||||||
UI, like a custom title, tags and category.
|
UI, like a custom title, tags and category.
|
||||||
|
|
||||||
|
We also encourage providing a version. This must be a
|
||||||
|
[semver](https://semver.org/) version string ("$MAJOR.$MINOR.$PATCH"). The UI
|
||||||
|
will let users know if their workflow is using a mismatched version of the node.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@invocation("resize", title="My Resizer", tags=["resize", "image"], category="My Invocations")
|
@invocation("resize", title="My Resizer", tags=["resize", "image"], category="My Invocations", version="1.0.0")
|
||||||
class ResizeInvocation(BaseInvocation):
|
class ResizeInvocation(BaseInvocation):
|
||||||
"""Resizes an image"""
|
"""Resizes an image"""
|
||||||
|
|
||||||
@ -279,8 +283,6 @@ take a look a at our [contributing nodes overview](contributingNodes).
|
|||||||
|
|
||||||
## Advanced
|
## Advanced
|
||||||
|
|
||||||
-->
|
|
||||||
|
|
||||||
### Custom Output Types
|
### Custom Output Types
|
||||||
|
|
||||||
Like with custom inputs, sometimes you might find yourself needing custom
|
Like with custom inputs, sometimes you might find yourself needing custom
|
||||||
|
@ -109,6 +109,73 @@ a Text-Generation-Webui instance (might work remotely too, but I never tried it)
|
|||||||
|
|
||||||
This node works best with SDXL models, especially as the style can be described independantly of the LLM's output.
|
This node works best with SDXL models, especially as the style can be described independantly of the LLM's output.
|
||||||
|
|
||||||
|
--------------------------------
|
||||||
|
### Depth Map from Wavefront OBJ
|
||||||
|
|
||||||
|
**Description:** Render depth maps from Wavefront .obj files (triangulated) using this simple 3D renderer utilizing numpy and matplotlib to compute and color the scene. There are simple parameters to change the FOV, camera position, and model orientation.
|
||||||
|
|
||||||
|
To be imported, an .obj must use triangulated meshes, so make sure to enable that option if exporting from a 3D modeling program. This renderer makes each triangle a solid color based on its average depth, so it will cause anomalies if your .obj has large triangles. In Blender, the Remesh modifier can be helpful to subdivide a mesh into small pieces that work well given these limitations.
|
||||||
|
|
||||||
|
**Node Link:** https://github.com/dwringer/depth-from-obj-node
|
||||||
|
|
||||||
|
**Example Usage:**
|
||||||
|

|
||||||
|
|
||||||
|
--------------------------------
|
||||||
|
### Enhance Image (simple adjustments)
|
||||||
|
|
||||||
|
**Description:** Boost or reduce color saturation, contrast, brightness, sharpness, or invert colors of any image at any stage with this simple wrapper for pillow [PIL]'s ImageEnhance module.
|
||||||
|
|
||||||
|
Color inversion is toggled with a simple switch, while each of the four enhancer modes are activated by entering a value other than 1 in each corresponding input field. Values less than 1 will reduce the corresponding property, while values greater than 1 will enhance it.
|
||||||
|
|
||||||
|
**Node Link:** https://github.com/dwringer/image-enhance-node
|
||||||
|
|
||||||
|
**Example Usage:**
|
||||||
|

|
||||||
|
|
||||||
|
--------------------------------
|
||||||
|
### Generative Grammar-Based Prompt Nodes
|
||||||
|
|
||||||
|
**Description:** This set of 3 nodes generates prompts from simple user-defined grammar rules (loaded from custom files - examples provided below). The prompts are made by recursively expanding a special template string, replacing nonterminal "parts-of-speech" until no more nonterminal terms remain in the string.
|
||||||
|
|
||||||
|
This includes 3 Nodes:
|
||||||
|
- *Lookup Table from File* - loads a YAML file "prompt" section (or of a whole folder of YAML's) into a JSON-ified dictionary (Lookups output)
|
||||||
|
- *Lookups Entry from Prompt* - places a single entry in a new Lookups output under the specified heading
|
||||||
|
- *Prompt from Lookup Table* - uses a Collection of Lookups as grammar rules from which to randomly generate prompts.
|
||||||
|
|
||||||
|
**Node Link:** https://github.com/dwringer/generative-grammar-prompt-nodes
|
||||||
|
|
||||||
|
**Example Usage:**
|
||||||
|

|
||||||
|
|
||||||
|
--------------------------------
|
||||||
|
### Image and Mask Composition Pack
|
||||||
|
|
||||||
|
**Description:** This is a pack of nodes for composing masks and images, including a simple text mask creator and both image and latent offset nodes. The offsets wrap around, so these can be used in conjunction with the Seamless node to progressively generate centered on different parts of the seamless tiling.
|
||||||
|
|
||||||
|
This includes 4 Nodes:
|
||||||
|
- *Text Mask (simple 2D)* - create and position a white on black (or black on white) line of text using any font locally available to Invoke.
|
||||||
|
- *Image Compositor* - Take a subject from an image with a flat backdrop and layer it on another image using a chroma key or flood select background removal.
|
||||||
|
- *Offset Latents* - Offset a latents tensor in the vertical and/or horizontal dimensions, wrapping it around.
|
||||||
|
- *Offset Image* - Offset an image in the vertical and/or horizontal dimensions, wrapping it around.
|
||||||
|
|
||||||
|
**Node Link:** https://github.com/dwringer/composition-nodes
|
||||||
|
|
||||||
|
**Example Usage:**
|
||||||
|

|
||||||
|
|
||||||
|
--------------------------------
|
||||||
|
### Size Stepper Nodes
|
||||||
|
|
||||||
|
**Description:** This is a set of nodes for calculating the necessary size increments for doing upscaling workflows. Use the *Final Size & Orientation* node to enter your full size dimensions and orientation (portrait/landscape/random), then plug that and your initial generation dimensions into the *Ideal Size Stepper* and get 1, 2, or 3 intermediate pairs of dimensions for upscaling. Note this does not output the initial size or full size dimensions: the 1, 2, or 3 outputs of this node are only the intermediate sizes.
|
||||||
|
|
||||||
|
A third node is included, *Random Switch (Integers)*, which is just a generic version of Final Size with no orientation selection.
|
||||||
|
|
||||||
|
**Node Link:** https://github.com/dwringer/size-stepper-nodes
|
||||||
|
|
||||||
|
**Example Usage:**
|
||||||
|

|
||||||
|
|
||||||
--------------------------------
|
--------------------------------
|
||||||
|
|
||||||
### Text font to Image
|
### Text font to Image
|
||||||
|
@ -35,13 +35,13 @@ The table below contains a list of the default nodes shipped with InvokeAI and t
|
|||||||
|Inverse Lerp Image | Inverse linear interpolation of all pixels of an image|
|
|Inverse Lerp Image | Inverse linear interpolation of all pixels of an image|
|
||||||
|Image Primitive | An image primitive value|
|
|Image Primitive | An image primitive value|
|
||||||
|Lerp Image | Linear interpolation of all pixels of an image|
|
|Lerp Image | Linear interpolation of all pixels of an image|
|
||||||
|Image Luminosity Adjustment | Adjusts the Luminosity (Value) of an image.|
|
|Offset Image Channel | Add to or subtract from an image color channel by a uniform value.|
|
||||||
|
|Multiply Image Channel | Multiply or Invert an image color channel by a scalar value.|
|
||||||
|Multiply Images | Multiplies two images together using `PIL.ImageChops.multiply()`.|
|
|Multiply Images | Multiplies two images together using `PIL.ImageChops.multiply()`.|
|
||||||
|Blur NSFW Image | Add blur to NSFW-flagged images|
|
|Blur NSFW Image | Add blur to NSFW-flagged images|
|
||||||
|Paste Image | Pastes an image into another image.|
|
|Paste Image | Pastes an image into another image.|
|
||||||
|ImageProcessor | Base class for invocations that preprocess images for ControlNet|
|
|ImageProcessor | Base class for invocations that preprocess images for ControlNet|
|
||||||
|Resize Image | Resizes an image to specific dimensions|
|
|Resize Image | Resizes an image to specific dimensions|
|
||||||
|Image Saturation Adjustment | Adjusts the Saturation of an image.|
|
|
||||||
|Scale Image | Scales an image by a factor|
|
|Scale Image | Scales an image by a factor|
|
||||||
|Image to Latents | Encodes an image into latents.|
|
|Image to Latents | Encodes an image into latents.|
|
||||||
|Add Invisible Watermark | Add an invisible watermark to an image|
|
|Add Invisible Watermark | Add an invisible watermark to an image|
|
||||||
|
@ -1,19 +1,19 @@
|
|||||||
import typing
|
import typing
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from fastapi import Body
|
from fastapi import Body
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from pathlib import Path
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
||||||
|
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
from invokeai.backend.util.logging import logging
|
||||||
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
|
||||||
|
|
||||||
from invokeai.version import __version__
|
from invokeai.version import __version__
|
||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
from invokeai.backend.util.logging import logging
|
|
||||||
|
|
||||||
|
|
||||||
class LogLevel(int, Enum):
|
class LogLevel(int, Enum):
|
||||||
@ -55,7 +55,7 @@ async def get_version() -> AppVersion:
|
|||||||
|
|
||||||
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
|
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
|
||||||
async def get_config() -> AppConfig:
|
async def get_config() -> AppConfig:
|
||||||
infill_methods = ["tile", "lama"]
|
infill_methods = ["tile", "lama", "cv2"]
|
||||||
if PatchMatch.patchmatch_available():
|
if PatchMatch.patchmatch_available():
|
||||||
infill_methods.append("patchmatch")
|
infill_methods.append("patchmatch")
|
||||||
|
|
||||||
|
@ -26,11 +26,16 @@ from typing import (
|
|||||||
from pydantic import BaseModel, Field, validator
|
from pydantic import BaseModel, Field, validator
|
||||||
from pydantic.fields import Undefined, ModelField
|
from pydantic.fields import Undefined, ModelField
|
||||||
from pydantic.typing import NoArgAnyCallable
|
from pydantic.typing import NoArgAnyCallable
|
||||||
|
import semver
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..services.invocation_services import InvocationServices
|
from ..services.invocation_services import InvocationServices
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidVersionError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class FieldDescriptions:
|
class FieldDescriptions:
|
||||||
denoising_start = "When to start denoising, expressed a percentage of total steps"
|
denoising_start = "When to start denoising, expressed a percentage of total steps"
|
||||||
denoising_end = "When to stop denoising, expressed a percentage of total steps"
|
denoising_end = "When to stop denoising, expressed a percentage of total steps"
|
||||||
@ -105,24 +110,39 @@ class UIType(str, Enum):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# region Primitives
|
# region Primitives
|
||||||
Integer = "integer"
|
|
||||||
Float = "float"
|
|
||||||
Boolean = "boolean"
|
Boolean = "boolean"
|
||||||
String = "string"
|
Color = "ColorField"
|
||||||
Array = "array"
|
|
||||||
Image = "ImageField"
|
|
||||||
Latents = "LatentsField"
|
|
||||||
Conditioning = "ConditioningField"
|
Conditioning = "ConditioningField"
|
||||||
Control = "ControlField"
|
Control = "ControlField"
|
||||||
Color = "ColorField"
|
Float = "float"
|
||||||
ImageCollection = "ImageCollection"
|
Image = "ImageField"
|
||||||
ConditioningCollection = "ConditioningCollection"
|
Integer = "integer"
|
||||||
ColorCollection = "ColorCollection"
|
Latents = "LatentsField"
|
||||||
LatentsCollection = "LatentsCollection"
|
String = "string"
|
||||||
IntegerCollection = "IntegerCollection"
|
# endregion
|
||||||
FloatCollection = "FloatCollection"
|
|
||||||
StringCollection = "StringCollection"
|
# region Collection Primitives
|
||||||
BooleanCollection = "BooleanCollection"
|
BooleanCollection = "BooleanCollection"
|
||||||
|
ColorCollection = "ColorCollection"
|
||||||
|
ConditioningCollection = "ConditioningCollection"
|
||||||
|
ControlCollection = "ControlCollection"
|
||||||
|
FloatCollection = "FloatCollection"
|
||||||
|
ImageCollection = "ImageCollection"
|
||||||
|
IntegerCollection = "IntegerCollection"
|
||||||
|
LatentsCollection = "LatentsCollection"
|
||||||
|
StringCollection = "StringCollection"
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
# region Polymorphic Primitives
|
||||||
|
BooleanPolymorphic = "BooleanPolymorphic"
|
||||||
|
ColorPolymorphic = "ColorPolymorphic"
|
||||||
|
ConditioningPolymorphic = "ConditioningPolymorphic"
|
||||||
|
ControlPolymorphic = "ControlPolymorphic"
|
||||||
|
FloatPolymorphic = "FloatPolymorphic"
|
||||||
|
ImagePolymorphic = "ImagePolymorphic"
|
||||||
|
IntegerPolymorphic = "IntegerPolymorphic"
|
||||||
|
LatentsPolymorphic = "LatentsPolymorphic"
|
||||||
|
StringPolymorphic = "StringPolymorphic"
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Models
|
# region Models
|
||||||
@ -176,6 +196,7 @@ class _InputField(BaseModel):
|
|||||||
ui_type: Optional[UIType]
|
ui_type: Optional[UIType]
|
||||||
ui_component: Optional[UIComponent]
|
ui_component: Optional[UIComponent]
|
||||||
ui_order: Optional[int]
|
ui_order: Optional[int]
|
||||||
|
item_default: Optional[Any]
|
||||||
|
|
||||||
|
|
||||||
class _OutputField(BaseModel):
|
class _OutputField(BaseModel):
|
||||||
@ -223,6 +244,7 @@ def InputField(
|
|||||||
ui_component: Optional[UIComponent] = None,
|
ui_component: Optional[UIComponent] = None,
|
||||||
ui_hidden: bool = False,
|
ui_hidden: bool = False,
|
||||||
ui_order: Optional[int] = None,
|
ui_order: Optional[int] = None,
|
||||||
|
item_default: Optional[Any] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
@ -249,6 +271,11 @@ def InputField(
|
|||||||
For this case, you could provide `UIComponent.Textarea`.
|
For this case, you could provide `UIComponent.Textarea`.
|
||||||
|
|
||||||
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
|
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
|
||||||
|
|
||||||
|
: param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
||||||
|
|
||||||
|
: param bool item_default: [None] Specifies the default item value, if this is a collection input. \
|
||||||
|
Ignored for non-collection fields..
|
||||||
"""
|
"""
|
||||||
return Field(
|
return Field(
|
||||||
*args,
|
*args,
|
||||||
@ -282,6 +309,7 @@ def InputField(
|
|||||||
ui_component=ui_component,
|
ui_component=ui_component,
|
||||||
ui_hidden=ui_hidden,
|
ui_hidden=ui_hidden,
|
||||||
ui_order=ui_order,
|
ui_order=ui_order,
|
||||||
|
item_default=item_default,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -332,6 +360,8 @@ def OutputField(
|
|||||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||||
|
|
||||||
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
|
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
|
||||||
|
|
||||||
|
: param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
||||||
"""
|
"""
|
||||||
return Field(
|
return Field(
|
||||||
*args,
|
*args,
|
||||||
@ -376,6 +406,9 @@ class UIConfigBase(BaseModel):
|
|||||||
tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
|
tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
|
||||||
title: Optional[str] = Field(default=None, description="The node's display name")
|
title: Optional[str] = Field(default=None, description="The node's display name")
|
||||||
category: Optional[str] = Field(default=None, description="The node's category")
|
category: Optional[str] = Field(default=None, description="The node's category")
|
||||||
|
version: Optional[str] = Field(
|
||||||
|
default=None, description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class InvocationContext:
|
class InvocationContext:
|
||||||
@ -474,6 +507,8 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
schema["tags"] = uiconfig.tags
|
schema["tags"] = uiconfig.tags
|
||||||
if uiconfig and hasattr(uiconfig, "category"):
|
if uiconfig and hasattr(uiconfig, "category"):
|
||||||
schema["category"] = uiconfig.category
|
schema["category"] = uiconfig.category
|
||||||
|
if uiconfig and hasattr(uiconfig, "version"):
|
||||||
|
schema["version"] = uiconfig.version
|
||||||
if "required" not in schema or not isinstance(schema["required"], list):
|
if "required" not in schema or not isinstance(schema["required"], list):
|
||||||
schema["required"] = list()
|
schema["required"] = list()
|
||||||
schema["required"].extend(["type", "id"])
|
schema["required"].extend(["type", "id"])
|
||||||
@ -542,7 +577,11 @@ GenericBaseInvocation = TypeVar("GenericBaseInvocation", bound=BaseInvocation)
|
|||||||
|
|
||||||
|
|
||||||
def invocation(
|
def invocation(
|
||||||
invocation_type: str, title: Optional[str] = None, tags: Optional[list[str]] = None, category: Optional[str] = None
|
invocation_type: str,
|
||||||
|
title: Optional[str] = None,
|
||||||
|
tags: Optional[list[str]] = None,
|
||||||
|
category: Optional[str] = None,
|
||||||
|
version: Optional[str] = None,
|
||||||
) -> Callable[[Type[GenericBaseInvocation]], Type[GenericBaseInvocation]]:
|
) -> Callable[[Type[GenericBaseInvocation]], Type[GenericBaseInvocation]]:
|
||||||
"""
|
"""
|
||||||
Adds metadata to an invocation.
|
Adds metadata to an invocation.
|
||||||
@ -569,6 +608,12 @@ def invocation(
|
|||||||
cls.UIConfig.tags = tags
|
cls.UIConfig.tags = tags
|
||||||
if category is not None:
|
if category is not None:
|
||||||
cls.UIConfig.category = category
|
cls.UIConfig.category = category
|
||||||
|
if version is not None:
|
||||||
|
try:
|
||||||
|
semver.Version.parse(version)
|
||||||
|
except ValueError as e:
|
||||||
|
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
|
||||||
|
cls.UIConfig.version = version
|
||||||
|
|
||||||
# Add the invocation type to the pydantic model of the invocation
|
# Add the invocation type to the pydantic model of the invocation
|
||||||
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
||||||
|
@ -10,7 +10,9 @@ from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
|||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||||
|
|
||||||
|
|
||||||
@invocation("range", title="Integer Range", tags=["collection", "integer", "range"], category="collections")
|
@invocation(
|
||||||
|
"range", title="Integer Range", tags=["collection", "integer", "range"], category="collections", version="1.0.0"
|
||||||
|
)
|
||||||
class RangeInvocation(BaseInvocation):
|
class RangeInvocation(BaseInvocation):
|
||||||
"""Creates a range of numbers from start to stop with step"""
|
"""Creates a range of numbers from start to stop with step"""
|
||||||
|
|
||||||
@ -33,6 +35,7 @@ class RangeInvocation(BaseInvocation):
|
|||||||
title="Integer Range of Size",
|
title="Integer Range of Size",
|
||||||
tags=["collection", "integer", "size", "range"],
|
tags=["collection", "integer", "size", "range"],
|
||||||
category="collections",
|
category="collections",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class RangeOfSizeInvocation(BaseInvocation):
|
class RangeOfSizeInvocation(BaseInvocation):
|
||||||
"""Creates a range from start to start + size with step"""
|
"""Creates a range from start to start + size with step"""
|
||||||
@ -50,6 +53,7 @@ class RangeOfSizeInvocation(BaseInvocation):
|
|||||||
title="Random Range",
|
title="Random Range",
|
||||||
tags=["range", "integer", "random", "collection"],
|
tags=["range", "integer", "random", "collection"],
|
||||||
category="collections",
|
category="collections",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class RandomRangeInvocation(BaseInvocation):
|
class RandomRangeInvocation(BaseInvocation):
|
||||||
"""Creates a collection of random numbers"""
|
"""Creates a collection of random numbers"""
|
||||||
|
@ -44,7 +44,7 @@ class ConditioningFieldData:
|
|||||||
# PerpNeg = "perp_neg"
|
# PerpNeg = "perp_neg"
|
||||||
|
|
||||||
|
|
||||||
@invocation("compel", title="Prompt", tags=["prompt", "compel"], category="conditioning")
|
@invocation("compel", title="Prompt", tags=["prompt", "compel"], category="conditioning", version="1.0.0")
|
||||||
class CompelInvocation(BaseInvocation):
|
class CompelInvocation(BaseInvocation):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
@ -267,6 +267,7 @@ class SDXLPromptInvocationBase:
|
|||||||
title="SDXL Prompt",
|
title="SDXL Prompt",
|
||||||
tags=["sdxl", "compel", "prompt"],
|
tags=["sdxl", "compel", "prompt"],
|
||||||
category="conditioning",
|
category="conditioning",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
@ -351,6 +352,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
title="SDXL Refiner Prompt",
|
title="SDXL Refiner Prompt",
|
||||||
tags=["sdxl", "compel", "prompt"],
|
tags=["sdxl", "compel", "prompt"],
|
||||||
category="conditioning",
|
category="conditioning",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
@ -403,7 +405,7 @@ class ClipSkipInvocationOutput(BaseInvocationOutput):
|
|||||||
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||||
|
|
||||||
|
|
||||||
@invocation("clip_skip", title="CLIP Skip", tags=["clipskip", "clip", "skip"], category="conditioning")
|
@invocation("clip_skip", title="CLIP Skip", tags=["clipskip", "clip", "skip"], category="conditioning", version="1.0.0")
|
||||||
class ClipSkipInvocation(BaseInvocation):
|
class ClipSkipInvocation(BaseInvocation):
|
||||||
"""Skip layers in clip text_encoder model."""
|
"""Skip layers in clip text_encoder model."""
|
||||||
|
|
||||||
|
@ -95,14 +95,12 @@ class ControlOutput(BaseInvocationOutput):
|
|||||||
control: ControlField = OutputField(description=FieldDescriptions.control)
|
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||||
|
|
||||||
|
|
||||||
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet")
|
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.0.0")
|
||||||
class ControlNetInvocation(BaseInvocation):
|
class ControlNetInvocation(BaseInvocation):
|
||||||
"""Collects ControlNet info to pass to other nodes"""
|
"""Collects ControlNet info to pass to other nodes"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The control image")
|
image: ImageField = InputField(description="The control image")
|
||||||
control_model: ControlNetModelField = InputField(
|
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
|
||||||
default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
|
|
||||||
)
|
|
||||||
control_weight: Union[float, List[float]] = InputField(
|
control_weight: Union[float, List[float]] = InputField(
|
||||||
default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
|
default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
|
||||||
)
|
)
|
||||||
@ -129,7 +127,9 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("image_processor", title="Base Image Processor", tags=["controlnet"], category="controlnet")
|
@invocation(
|
||||||
|
"image_processor", title="Base Image Processor", tags=["controlnet"], category="controlnet", version="1.0.0"
|
||||||
|
)
|
||||||
class ImageProcessorInvocation(BaseInvocation):
|
class ImageProcessorInvocation(BaseInvocation):
|
||||||
"""Base class for invocations that preprocess images for ControlNet"""
|
"""Base class for invocations that preprocess images for ControlNet"""
|
||||||
|
|
||||||
@ -173,6 +173,7 @@ class ImageProcessorInvocation(BaseInvocation):
|
|||||||
title="Canny Processor",
|
title="Canny Processor",
|
||||||
tags=["controlnet", "canny"],
|
tags=["controlnet", "canny"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Canny edge detection for ControlNet"""
|
"""Canny edge detection for ControlNet"""
|
||||||
@ -195,6 +196,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
title="HED (softedge) Processor",
|
title="HED (softedge) Processor",
|
||||||
tags=["controlnet", "hed", "softedge"],
|
tags=["controlnet", "hed", "softedge"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class HedImageProcessorInvocation(ImageProcessorInvocation):
|
class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies HED edge detection to image"""
|
"""Applies HED edge detection to image"""
|
||||||
@ -223,6 +225,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
title="Lineart Processor",
|
title="Lineart Processor",
|
||||||
tags=["controlnet", "lineart"],
|
tags=["controlnet", "lineart"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies line art processing to image"""
|
"""Applies line art processing to image"""
|
||||||
@ -244,6 +247,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
title="Lineart Anime Processor",
|
title="Lineart Anime Processor",
|
||||||
tags=["controlnet", "lineart", "anime"],
|
tags=["controlnet", "lineart", "anime"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies line art anime processing to image"""
|
"""Applies line art anime processing to image"""
|
||||||
@ -266,6 +270,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
title="Openpose Processor",
|
title="Openpose Processor",
|
||||||
tags=["controlnet", "openpose", "pose"],
|
tags=["controlnet", "openpose", "pose"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies Openpose processing to image"""
|
"""Applies Openpose processing to image"""
|
||||||
@ -290,6 +295,7 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
title="Midas Depth Processor",
|
title="Midas Depth Processor",
|
||||||
tags=["controlnet", "midas"],
|
tags=["controlnet", "midas"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies Midas depth processing to image"""
|
"""Applies Midas depth processing to image"""
|
||||||
@ -316,6 +322,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
title="Normal BAE Processor",
|
title="Normal BAE Processor",
|
||||||
tags=["controlnet"],
|
tags=["controlnet"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies NormalBae processing to image"""
|
"""Applies NormalBae processing to image"""
|
||||||
@ -331,7 +338,9 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
@invocation("mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet")
|
@invocation(
|
||||||
|
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.0.0"
|
||||||
|
)
|
||||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies MLSD processing to image"""
|
"""Applies MLSD processing to image"""
|
||||||
|
|
||||||
@ -352,7 +361,9 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
@invocation("pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet")
|
@invocation(
|
||||||
|
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.0.0"
|
||||||
|
)
|
||||||
class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies PIDI processing to image"""
|
"""Applies PIDI processing to image"""
|
||||||
|
|
||||||
@ -378,6 +389,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
title="Content Shuffle Processor",
|
title="Content Shuffle Processor",
|
||||||
tags=["controlnet", "contentshuffle"],
|
tags=["controlnet", "contentshuffle"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies content shuffle processing to image"""
|
"""Applies content shuffle processing to image"""
|
||||||
@ -407,6 +419,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
title="Zoe (Depth) Processor",
|
title="Zoe (Depth) Processor",
|
||||||
tags=["controlnet", "zoe", "depth"],
|
tags=["controlnet", "zoe", "depth"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies Zoe depth processing to image"""
|
"""Applies Zoe depth processing to image"""
|
||||||
@ -422,6 +435,7 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
title="Mediapipe Face Processor",
|
title="Mediapipe Face Processor",
|
||||||
tags=["controlnet", "mediapipe", "face"],
|
tags=["controlnet", "mediapipe", "face"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies mediapipe face processing to image"""
|
"""Applies mediapipe face processing to image"""
|
||||||
@ -444,6 +458,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
|||||||
title="Leres (Depth) Processor",
|
title="Leres (Depth) Processor",
|
||||||
tags=["controlnet", "leres", "depth"],
|
tags=["controlnet", "leres", "depth"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies leres processing to image"""
|
"""Applies leres processing to image"""
|
||||||
@ -472,6 +487,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
title="Tile Resample Processor",
|
title="Tile Resample Processor",
|
||||||
tags=["controlnet", "tile"],
|
tags=["controlnet", "tile"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Tile resampler processor"""
|
"""Tile resampler processor"""
|
||||||
@ -511,6 +527,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
|||||||
title="Segment Anything Processor",
|
title="Segment Anything Processor",
|
||||||
tags=["controlnet", "segmentanything"],
|
tags=["controlnet", "segmentanything"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies segment anything processing to image"""
|
"""Applies segment anything processing to image"""
|
||||||
|
@ -10,12 +10,7 @@ from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
|||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.0.0")
|
||||||
"cv_inpaint",
|
|
||||||
title="OpenCV Inpaint",
|
|
||||||
tags=["opencv", "inpaint"],
|
|
||||||
category="inpaint",
|
|
||||||
)
|
|
||||||
class CvInpaintInvocation(BaseInvocation):
|
class CvInpaintInvocation(BaseInvocation):
|
||||||
"""Simple inpaint using opencv."""
|
"""Simple inpaint using opencv."""
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ from ..models.image import ImageCategory, ResourceOrigin
|
|||||||
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation
|
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation
|
||||||
|
|
||||||
|
|
||||||
@invocation("show_image", title="Show Image", tags=["image"], category="image")
|
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0")
|
||||||
class ShowImageInvocation(BaseInvocation):
|
class ShowImageInvocation(BaseInvocation):
|
||||||
"""Displays a provided image using the OS image viewer, and passes it forward in the pipeline."""
|
"""Displays a provided image using the OS image viewer, and passes it forward in the pipeline."""
|
||||||
|
|
||||||
@ -36,7 +36,7 @@ class ShowImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("blank_image", title="Blank Image", tags=["image"], category="image")
|
@invocation("blank_image", title="Blank Image", tags=["image"], category="image", version="1.0.0")
|
||||||
class BlankImageInvocation(BaseInvocation):
|
class BlankImageInvocation(BaseInvocation):
|
||||||
"""Creates a blank image and forwards it to the pipeline"""
|
"""Creates a blank image and forwards it to the pipeline"""
|
||||||
|
|
||||||
@ -65,7 +65,7 @@ class BlankImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image")
|
@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image", version="1.0.0")
|
||||||
class ImageCropInvocation(BaseInvocation):
|
class ImageCropInvocation(BaseInvocation):
|
||||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||||
|
|
||||||
@ -98,7 +98,7 @@ class ImageCropInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image")
|
@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image", version="1.0.0")
|
||||||
class ImagePasteInvocation(BaseInvocation):
|
class ImagePasteInvocation(BaseInvocation):
|
||||||
"""Pastes an image into another image."""
|
"""Pastes an image into another image."""
|
||||||
|
|
||||||
@ -146,7 +146,7 @@ class ImagePasteInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image")
|
@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image", version="1.0.0")
|
||||||
class MaskFromAlphaInvocation(BaseInvocation):
|
class MaskFromAlphaInvocation(BaseInvocation):
|
||||||
"""Extracts the alpha channel of an image as a mask."""
|
"""Extracts the alpha channel of an image as a mask."""
|
||||||
|
|
||||||
@ -177,7 +177,7 @@ class MaskFromAlphaInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image")
|
@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image", version="1.0.0")
|
||||||
class ImageMultiplyInvocation(BaseInvocation):
|
class ImageMultiplyInvocation(BaseInvocation):
|
||||||
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
||||||
|
|
||||||
@ -210,7 +210,7 @@ class ImageMultiplyInvocation(BaseInvocation):
|
|||||||
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image")
|
@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image", version="1.0.0")
|
||||||
class ImageChannelInvocation(BaseInvocation):
|
class ImageChannelInvocation(BaseInvocation):
|
||||||
"""Gets a channel from an image."""
|
"""Gets a channel from an image."""
|
||||||
|
|
||||||
@ -242,7 +242,7 @@ class ImageChannelInvocation(BaseInvocation):
|
|||||||
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image")
|
@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image", version="1.0.0")
|
||||||
class ImageConvertInvocation(BaseInvocation):
|
class ImageConvertInvocation(BaseInvocation):
|
||||||
"""Converts an image to a different mode."""
|
"""Converts an image to a different mode."""
|
||||||
|
|
||||||
@ -271,7 +271,7 @@ class ImageConvertInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image")
|
@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image", version="1.0.0")
|
||||||
class ImageBlurInvocation(BaseInvocation):
|
class ImageBlurInvocation(BaseInvocation):
|
||||||
"""Blurs an image"""
|
"""Blurs an image"""
|
||||||
|
|
||||||
@ -325,7 +325,7 @@ PIL_RESAMPLING_MAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image")
|
@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image", version="1.0.0")
|
||||||
class ImageResizeInvocation(BaseInvocation):
|
class ImageResizeInvocation(BaseInvocation):
|
||||||
"""Resizes an image to specific dimensions"""
|
"""Resizes an image to specific dimensions"""
|
||||||
|
|
||||||
@ -365,7 +365,7 @@ class ImageResizeInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image")
|
@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image", version="1.0.0")
|
||||||
class ImageScaleInvocation(BaseInvocation):
|
class ImageScaleInvocation(BaseInvocation):
|
||||||
"""Scales an image by a factor"""
|
"""Scales an image by a factor"""
|
||||||
|
|
||||||
@ -406,7 +406,7 @@ class ImageScaleInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image")
|
@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image", version="1.0.0")
|
||||||
class ImageLerpInvocation(BaseInvocation):
|
class ImageLerpInvocation(BaseInvocation):
|
||||||
"""Linear interpolation of all pixels of an image"""
|
"""Linear interpolation of all pixels of an image"""
|
||||||
|
|
||||||
@ -439,7 +439,7 @@ class ImageLerpInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image")
|
@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", version="1.0.0")
|
||||||
class ImageInverseLerpInvocation(BaseInvocation):
|
class ImageInverseLerpInvocation(BaseInvocation):
|
||||||
"""Inverse linear interpolation of all pixels of an image"""
|
"""Inverse linear interpolation of all pixels of an image"""
|
||||||
|
|
||||||
@ -472,7 +472,7 @@ class ImageInverseLerpInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image")
|
@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image", version="1.0.0")
|
||||||
class ImageNSFWBlurInvocation(BaseInvocation):
|
class ImageNSFWBlurInvocation(BaseInvocation):
|
||||||
"""Add blur to NSFW-flagged images"""
|
"""Add blur to NSFW-flagged images"""
|
||||||
|
|
||||||
@ -517,7 +517,9 @@ class ImageNSFWBlurInvocation(BaseInvocation):
|
|||||||
return caution.resize((caution.width // 2, caution.height // 2))
|
return caution.resize((caution.width // 2, caution.height // 2))
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_watermark", title="Add Invisible Watermark", tags=["image", "watermark"], category="image")
|
@invocation(
|
||||||
|
"img_watermark", title="Add Invisible Watermark", tags=["image", "watermark"], category="image", version="1.0.0"
|
||||||
|
)
|
||||||
class ImageWatermarkInvocation(BaseInvocation):
|
class ImageWatermarkInvocation(BaseInvocation):
|
||||||
"""Add an invisible watermark to an image"""
|
"""Add an invisible watermark to an image"""
|
||||||
|
|
||||||
@ -548,7 +550,7 @@ class ImageWatermarkInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image")
|
@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", version="1.0.0")
|
||||||
class MaskEdgeInvocation(BaseInvocation):
|
class MaskEdgeInvocation(BaseInvocation):
|
||||||
"""Applies an edge mask to an image"""
|
"""Applies an edge mask to an image"""
|
||||||
|
|
||||||
@ -561,7 +563,7 @@ class MaskEdgeInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
mask = context.services.images.get_pil_image(self.image.image_name)
|
mask = context.services.images.get_pil_image(self.image.image_name).convert("L")
|
||||||
|
|
||||||
npimg = numpy.asarray(mask, dtype=numpy.uint8)
|
npimg = numpy.asarray(mask, dtype=numpy.uint8)
|
||||||
npgradient = numpy.uint8(255 * (1.0 - numpy.floor(numpy.abs(0.5 - numpy.float32(npimg) / 255.0) * 2.0)))
|
npgradient = numpy.uint8(255 * (1.0 - numpy.floor(numpy.abs(0.5 - numpy.float32(npimg) / 255.0) * 2.0)))
|
||||||
@ -593,7 +595,9 @@ class MaskEdgeInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("mask_combine", title="Combine Masks", tags=["image", "mask", "multiply"], category="image")
|
@invocation(
|
||||||
|
"mask_combine", title="Combine Masks", tags=["image", "mask", "multiply"], category="image", version="1.0.0"
|
||||||
|
)
|
||||||
class MaskCombineInvocation(BaseInvocation):
|
class MaskCombineInvocation(BaseInvocation):
|
||||||
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
|
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
|
||||||
|
|
||||||
@ -623,7 +627,7 @@ class MaskCombineInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image")
|
@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image", version="1.0.0")
|
||||||
class ColorCorrectInvocation(BaseInvocation):
|
class ColorCorrectInvocation(BaseInvocation):
|
||||||
"""
|
"""
|
||||||
Shifts the colors of a target image to match the reference image, optionally
|
Shifts the colors of a target image to match the reference image, optionally
|
||||||
@ -696,8 +700,13 @@ class ColorCorrectInvocation(BaseInvocation):
|
|||||||
# Blur the mask out (into init image) by specified amount
|
# Blur the mask out (into init image) by specified amount
|
||||||
if self.mask_blur_radius > 0:
|
if self.mask_blur_radius > 0:
|
||||||
nm = numpy.asarray(pil_init_mask, dtype=numpy.uint8)
|
nm = numpy.asarray(pil_init_mask, dtype=numpy.uint8)
|
||||||
|
inverted_nm = 255 - nm
|
||||||
|
dilation_size = int(round(self.mask_blur_radius) + 20)
|
||||||
|
dilating_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilation_size, dilation_size))
|
||||||
|
inverted_dilated_nm = cv2.dilate(inverted_nm, dilating_kernel)
|
||||||
|
dilated_nm = 255 - inverted_dilated_nm
|
||||||
nmd = cv2.erode(
|
nmd = cv2.erode(
|
||||||
nm,
|
dilated_nm,
|
||||||
kernel=numpy.ones((3, 3), dtype=numpy.uint8),
|
kernel=numpy.ones((3, 3), dtype=numpy.uint8),
|
||||||
iterations=int(self.mask_blur_radius / 2),
|
iterations=int(self.mask_blur_radius / 2),
|
||||||
)
|
)
|
||||||
@ -728,7 +737,7 @@ class ColorCorrectInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image")
|
@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image", version="1.0.0")
|
||||||
class ImageHueAdjustmentInvocation(BaseInvocation):
|
class ImageHueAdjustmentInvocation(BaseInvocation):
|
||||||
"""Adjusts the Hue of an image."""
|
"""Adjusts the Hue of an image."""
|
||||||
|
|
||||||
@ -769,38 +778,95 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
COLOR_CHANNELS = Literal[
|
||||||
|
"Red (RGBA)",
|
||||||
|
"Green (RGBA)",
|
||||||
|
"Blue (RGBA)",
|
||||||
|
"Alpha (RGBA)",
|
||||||
|
"Cyan (CMYK)",
|
||||||
|
"Magenta (CMYK)",
|
||||||
|
"Yellow (CMYK)",
|
||||||
|
"Black (CMYK)",
|
||||||
|
"Hue (HSV)",
|
||||||
|
"Saturation (HSV)",
|
||||||
|
"Value (HSV)",
|
||||||
|
"Luminosity (LAB)",
|
||||||
|
"A (LAB)",
|
||||||
|
"B (LAB)",
|
||||||
|
"Y (YCbCr)",
|
||||||
|
"Cb (YCbCr)",
|
||||||
|
"Cr (YCbCr)",
|
||||||
|
]
|
||||||
|
|
||||||
|
CHANNEL_FORMATS = {
|
||||||
|
"Red (RGBA)": ("RGBA", 0),
|
||||||
|
"Green (RGBA)": ("RGBA", 1),
|
||||||
|
"Blue (RGBA)": ("RGBA", 2),
|
||||||
|
"Alpha (RGBA)": ("RGBA", 3),
|
||||||
|
"Cyan (CMYK)": ("CMYK", 0),
|
||||||
|
"Magenta (CMYK)": ("CMYK", 1),
|
||||||
|
"Yellow (CMYK)": ("CMYK", 2),
|
||||||
|
"Black (CMYK)": ("CMYK", 3),
|
||||||
|
"Hue (HSV)": ("HSV", 0),
|
||||||
|
"Saturation (HSV)": ("HSV", 1),
|
||||||
|
"Value (HSV)": ("HSV", 2),
|
||||||
|
"Luminosity (LAB)": ("LAB", 0),
|
||||||
|
"A (LAB)": ("LAB", 1),
|
||||||
|
"B (LAB)": ("LAB", 2),
|
||||||
|
"Y (YCbCr)": ("YCbCr", 0),
|
||||||
|
"Cb (YCbCr)": ("YCbCr", 1),
|
||||||
|
"Cr (YCbCr)": ("YCbCr", 2),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"img_luminosity_adjust",
|
"img_channel_offset",
|
||||||
title="Adjust Image Luminosity",
|
title="Offset Image Channel",
|
||||||
tags=["image", "luminosity", "hsl"],
|
tags=[
|
||||||
|
"image",
|
||||||
|
"offset",
|
||||||
|
"red",
|
||||||
|
"green",
|
||||||
|
"blue",
|
||||||
|
"alpha",
|
||||||
|
"cyan",
|
||||||
|
"magenta",
|
||||||
|
"yellow",
|
||||||
|
"black",
|
||||||
|
"hue",
|
||||||
|
"saturation",
|
||||||
|
"luminosity",
|
||||||
|
"value",
|
||||||
|
],
|
||||||
category="image",
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class ImageLuminosityAdjustmentInvocation(BaseInvocation):
|
class ImageChannelOffsetInvocation(BaseInvocation):
|
||||||
"""Adjusts the Luminosity (Value) of an image."""
|
"""Add or subtract a value from a specific color channel of an image."""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to adjust")
|
image: ImageField = InputField(description="The image to adjust")
|
||||||
luminosity: float = InputField(
|
channel: COLOR_CHANNELS = InputField(description="Which channel to adjust")
|
||||||
default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)"
|
offset: int = InputField(default=0, ge=-255, le=255, description="The amount to adjust the channel by")
|
||||||
)
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
# Convert PIL image to OpenCV format (numpy array), note color channel
|
# extract the channel and mode from the input and reference tuple
|
||||||
# ordering is changed from RGB to BGR
|
mode = CHANNEL_FORMATS[self.channel][0]
|
||||||
image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
|
channel_number = CHANNEL_FORMATS[self.channel][1]
|
||||||
|
|
||||||
# Convert image to HSV color space
|
# Convert PIL image to new format
|
||||||
hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
converted_image = numpy.array(pil_image.convert(mode)).astype(int)
|
||||||
|
image_channel = converted_image[:, :, channel_number]
|
||||||
|
|
||||||
# Adjust the luminosity (value)
|
# Adjust the value, clipping to 0..255
|
||||||
hsv_image[:, :, 2] = numpy.clip(hsv_image[:, :, 2] * self.luminosity, 0, 255)
|
image_channel = numpy.clip(image_channel + self.offset, 0, 255)
|
||||||
|
|
||||||
# Convert image back to BGR color space
|
# Put the channel back into the image
|
||||||
image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
|
converted_image[:, :, channel_number] = image_channel
|
||||||
|
|
||||||
# Convert back to PIL format and to original color mode
|
# Convert back to RGBA format and output
|
||||||
pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
|
pil_image = Image.fromarray(converted_image.astype(numpy.uint8), mode=mode).convert("RGBA")
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=pil_image,
|
image=pil_image,
|
||||||
@ -822,35 +888,60 @@ class ImageLuminosityAdjustmentInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"img_saturation_adjust",
|
"img_channel_multiply",
|
||||||
title="Adjust Image Saturation",
|
title="Multiply Image Channel",
|
||||||
tags=["image", "saturation", "hsl"],
|
tags=[
|
||||||
|
"image",
|
||||||
|
"invert",
|
||||||
|
"scale",
|
||||||
|
"multiply",
|
||||||
|
"red",
|
||||||
|
"green",
|
||||||
|
"blue",
|
||||||
|
"alpha",
|
||||||
|
"cyan",
|
||||||
|
"magenta",
|
||||||
|
"yellow",
|
||||||
|
"black",
|
||||||
|
"hue",
|
||||||
|
"saturation",
|
||||||
|
"luminosity",
|
||||||
|
"value",
|
||||||
|
],
|
||||||
category="image",
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class ImageSaturationAdjustmentInvocation(BaseInvocation):
|
class ImageChannelMultiplyInvocation(BaseInvocation):
|
||||||
"""Adjusts the Saturation of an image."""
|
"""Scale a specific color channel of an image."""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to adjust")
|
image: ImageField = InputField(description="The image to adjust")
|
||||||
saturation: float = InputField(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
|
channel: COLOR_CHANNELS = InputField(description="Which channel to adjust")
|
||||||
|
scale: float = InputField(default=1.0, ge=0.0, description="The amount to scale the channel by.")
|
||||||
|
invert_channel: bool = InputField(default=False, description="Invert the channel after scaling")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
# Convert PIL image to OpenCV format (numpy array), note color channel
|
# extract the channel and mode from the input and reference tuple
|
||||||
# ordering is changed from RGB to BGR
|
mode = CHANNEL_FORMATS[self.channel][0]
|
||||||
image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
|
channel_number = CHANNEL_FORMATS[self.channel][1]
|
||||||
|
|
||||||
# Convert image to HSV color space
|
# Convert PIL image to new format
|
||||||
hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
converted_image = numpy.array(pil_image.convert(mode)).astype(float)
|
||||||
|
image_channel = converted_image[:, :, channel_number]
|
||||||
|
|
||||||
# Adjust the saturation
|
# Adjust the value, clipping to 0..255
|
||||||
hsv_image[:, :, 1] = numpy.clip(hsv_image[:, :, 1] * self.saturation, 0, 255)
|
image_channel = numpy.clip(image_channel * self.scale, 0, 255)
|
||||||
|
|
||||||
# Convert image back to BGR color space
|
# Invert the channel if requested
|
||||||
image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
|
if self.invert_channel:
|
||||||
|
image_channel = 255 - image_channel
|
||||||
|
|
||||||
# Convert back to PIL format and to original color mode
|
# Put the channel back into the image
|
||||||
pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
|
converted_image[:, :, channel_number] = image_channel
|
||||||
|
|
||||||
|
# Convert back to RGBA format and output
|
||||||
|
pil_image = Image.fromarray(converted_image.astype(numpy.uint8), mode=mode).convert("RGBA")
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=pil_image,
|
image=pil_image,
|
||||||
|
@ -8,19 +8,17 @@ from PIL import Image, ImageOps
|
|||||||
|
|
||||||
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
|
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
|
||||||
from invokeai.backend.image_util.lama import LaMA
|
from invokeai.backend.image_util.lama import LaMA
|
||||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||||
|
|
||||||
from ..models.image import ImageCategory, ResourceOrigin
|
from ..models.image import ImageCategory, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||||
|
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
|
||||||
|
|
||||||
|
|
||||||
def infill_methods() -> list[str]:
|
def infill_methods() -> list[str]:
|
||||||
methods = [
|
methods = ["tile", "solid", "lama", "cv2"]
|
||||||
"tile",
|
|
||||||
"solid",
|
|
||||||
"lama",
|
|
||||||
]
|
|
||||||
if PatchMatch.patchmatch_available():
|
if PatchMatch.patchmatch_available():
|
||||||
methods.insert(0, "patchmatch")
|
methods.insert(0, "patchmatch")
|
||||||
return methods
|
return methods
|
||||||
@ -49,6 +47,10 @@ def infill_patchmatch(im: Image.Image) -> Image.Image:
|
|||||||
return im_patched
|
return im_patched
|
||||||
|
|
||||||
|
|
||||||
|
def infill_cv2(im: Image.Image) -> Image.Image:
|
||||||
|
return cv2_inpaint(im)
|
||||||
|
|
||||||
|
|
||||||
def get_tile_images(image: np.ndarray, width=8, height=8):
|
def get_tile_images(image: np.ndarray, width=8, height=8):
|
||||||
_nrows, _ncols, depth = image.shape
|
_nrows, _ncols, depth = image.shape
|
||||||
_strides = image.strides
|
_strides = image.strides
|
||||||
@ -116,7 +118,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
|
|||||||
return si
|
return si
|
||||||
|
|
||||||
|
|
||||||
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint")
|
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
|
||||||
class InfillColorInvocation(BaseInvocation):
|
class InfillColorInvocation(BaseInvocation):
|
||||||
"""Infills transparent areas of an image with a solid color"""
|
"""Infills transparent areas of an image with a solid color"""
|
||||||
|
|
||||||
@ -151,7 +153,7 @@ class InfillColorInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint")
|
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
|
||||||
class InfillTileInvocation(BaseInvocation):
|
class InfillTileInvocation(BaseInvocation):
|
||||||
"""Infills transparent areas of an image with tiles of the image"""
|
"""Infills transparent areas of an image with tiles of the image"""
|
||||||
|
|
||||||
@ -187,20 +189,42 @@ class InfillTileInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint")
|
@invocation(
|
||||||
|
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0"
|
||||||
|
)
|
||||||
class InfillPatchMatchInvocation(BaseInvocation):
|
class InfillPatchMatchInvocation(BaseInvocation):
|
||||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to infill")
|
image: ImageField = InputField(description="The image to infill")
|
||||||
|
downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
|
||||||
|
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name).convert("RGBA")
|
||||||
|
|
||||||
|
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||||
|
|
||||||
|
infill_image = image.copy()
|
||||||
|
width = int(image.width / self.downscale)
|
||||||
|
height = int(image.height / self.downscale)
|
||||||
|
infill_image = infill_image.resize(
|
||||||
|
(width, height),
|
||||||
|
resample=resample_mode,
|
||||||
|
)
|
||||||
|
|
||||||
if PatchMatch.patchmatch_available():
|
if PatchMatch.patchmatch_available():
|
||||||
infilled = infill_patchmatch(image.copy())
|
infilled = infill_patchmatch(infill_image)
|
||||||
else:
|
else:
|
||||||
raise ValueError("PatchMatch is not available on this system")
|
raise ValueError("PatchMatch is not available on this system")
|
||||||
|
|
||||||
|
infilled = infilled.resize(
|
||||||
|
(image.width, image.height),
|
||||||
|
resample=resample_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
infilled.paste(image, (0, 0), mask=image.split()[-1])
|
||||||
|
# image.paste(infilled, (0, 0), mask=image.split()[-1])
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=infilled,
|
image=infilled,
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
@ -218,7 +242,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint")
|
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
|
||||||
class LaMaInfillInvocation(BaseInvocation):
|
class LaMaInfillInvocation(BaseInvocation):
|
||||||
"""Infills transparent areas of an image using the LaMa model"""
|
"""Infills transparent areas of an image using the LaMa model"""
|
||||||
|
|
||||||
@ -243,3 +267,30 @@ class LaMaInfillInvocation(BaseInvocation):
|
|||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint")
|
||||||
|
class CV2InfillInvocation(BaseInvocation):
|
||||||
|
"""Infills transparent areas of an image using OpenCV Inpainting"""
|
||||||
|
|
||||||
|
image: ImageField = InputField(description="The image to infill")
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
|
infilled = infill_cv2(image.copy())
|
||||||
|
|
||||||
|
image_dto = context.services.images.create(
|
||||||
|
image=infilled,
|
||||||
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ImageOutput(
|
||||||
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
|
)
|
||||||
|
@ -74,7 +74,7 @@ class SchedulerOutput(BaseInvocationOutput):
|
|||||||
scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
|
scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
|
||||||
|
|
||||||
|
|
||||||
@invocation("scheduler", title="Scheduler", tags=["scheduler"], category="latents")
|
@invocation("scheduler", title="Scheduler", tags=["scheduler"], category="latents", version="1.0.0")
|
||||||
class SchedulerInvocation(BaseInvocation):
|
class SchedulerInvocation(BaseInvocation):
|
||||||
"""Selects a scheduler."""
|
"""Selects a scheduler."""
|
||||||
|
|
||||||
@ -86,7 +86,9 @@ class SchedulerInvocation(BaseInvocation):
|
|||||||
return SchedulerOutput(scheduler=self.scheduler)
|
return SchedulerOutput(scheduler=self.scheduler)
|
||||||
|
|
||||||
|
|
||||||
@invocation("create_denoise_mask", title="Create Denoise Mask", tags=["mask", "denoise"], category="latents")
|
@invocation(
|
||||||
|
"create_denoise_mask", title="Create Denoise Mask", tags=["mask", "denoise"], category="latents", version="1.0.0"
|
||||||
|
)
|
||||||
class CreateDenoiseMaskInvocation(BaseInvocation):
|
class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||||
"""Creates mask for denoising model run."""
|
"""Creates mask for denoising model run."""
|
||||||
|
|
||||||
@ -186,6 +188,7 @@ def get_scheduler(
|
|||||||
title="Denoise Latents",
|
title="Denoise Latents",
|
||||||
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
||||||
category="latents",
|
category="latents",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class DenoiseLatentsInvocation(BaseInvocation):
|
class DenoiseLatentsInvocation(BaseInvocation):
|
||||||
"""Denoises noisy latents to decodable images"""
|
"""Denoises noisy latents to decodable images"""
|
||||||
@ -208,12 +211,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2)
|
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2)
|
||||||
control: Union[ControlField, list[ControlField]] = InputField(
|
control: Union[ControlField, list[ControlField]] = InputField(
|
||||||
default=None, description=FieldDescriptions.control, input=Input.Connection, ui_order=5
|
default=None,
|
||||||
|
description=FieldDescriptions.control,
|
||||||
|
input=Input.Connection,
|
||||||
|
ui_order=5,
|
||||||
)
|
)
|
||||||
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
||||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||||
default=None,
|
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=6
|
||||||
description=FieldDescriptions.mask,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@validator("cfg_scale")
|
@validator("cfg_scale")
|
||||||
@ -317,7 +322,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
# really only need model for dtype and device
|
# really only need model for dtype and device
|
||||||
model: StableDiffusionGeneratorPipeline,
|
model: StableDiffusionGeneratorPipeline,
|
||||||
control_input: List[ControlField],
|
control_input: Union[ControlField, List[ControlField]],
|
||||||
latents_shape: List[int],
|
latents_shape: List[int],
|
||||||
exit_stack: ExitStack,
|
exit_stack: ExitStack,
|
||||||
do_classifier_free_guidance: bool = True,
|
do_classifier_free_guidance: bool = True,
|
||||||
@ -542,7 +547,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
|
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
|
||||||
|
|
||||||
|
|
||||||
@invocation("l2i", title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents")
|
@invocation(
|
||||||
|
"l2i", title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents", version="1.0.0"
|
||||||
|
)
|
||||||
class LatentsToImageInvocation(BaseInvocation):
|
class LatentsToImageInvocation(BaseInvocation):
|
||||||
"""Generates an image from latents."""
|
"""Generates an image from latents."""
|
||||||
|
|
||||||
@ -639,7 +646,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
||||||
|
|
||||||
|
|
||||||
@invocation("lresize", title="Resize Latents", tags=["latents", "resize"], category="latents")
|
@invocation("lresize", title="Resize Latents", tags=["latents", "resize"], category="latents", version="1.0.0")
|
||||||
class ResizeLatentsInvocation(BaseInvocation):
|
class ResizeLatentsInvocation(BaseInvocation):
|
||||||
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
||||||
|
|
||||||
@ -683,7 +690,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||||
|
|
||||||
|
|
||||||
@invocation("lscale", title="Scale Latents", tags=["latents", "resize"], category="latents")
|
@invocation("lscale", title="Scale Latents", tags=["latents", "resize"], category="latents", version="1.0.0")
|
||||||
class ScaleLatentsInvocation(BaseInvocation):
|
class ScaleLatentsInvocation(BaseInvocation):
|
||||||
"""Scales latents by a given factor."""
|
"""Scales latents by a given factor."""
|
||||||
|
|
||||||
@ -719,7 +726,9 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||||
|
|
||||||
|
|
||||||
@invocation("i2l", title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents")
|
@invocation(
|
||||||
|
"i2l", title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents", version="1.0.0"
|
||||||
|
)
|
||||||
class ImageToLatentsInvocation(BaseInvocation):
|
class ImageToLatentsInvocation(BaseInvocation):
|
||||||
"""Encodes an image into latents."""
|
"""Encodes an image into latents."""
|
||||||
|
|
||||||
@ -799,7 +808,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
return build_latents_output(latents_name=name, latents=latents, seed=None)
|
return build_latents_output(latents_name=name, latents=latents, seed=None)
|
||||||
|
|
||||||
|
|
||||||
@invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents")
|
@invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents", version="1.0.0")
|
||||||
class BlendLatentsInvocation(BaseInvocation):
|
class BlendLatentsInvocation(BaseInvocation):
|
||||||
"""Blend two latents using a given alpha. Latents must have same size."""
|
"""Blend two latents using a given alpha. Latents must have same size."""
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ from invokeai.app.invocations.primitives import IntegerOutput
|
|||||||
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation
|
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation
|
||||||
|
|
||||||
|
|
||||||
@invocation("add", title="Add Integers", tags=["math", "add"], category="math")
|
@invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.0")
|
||||||
class AddInvocation(BaseInvocation):
|
class AddInvocation(BaseInvocation):
|
||||||
"""Adds two numbers"""
|
"""Adds two numbers"""
|
||||||
|
|
||||||
@ -18,7 +18,7 @@ class AddInvocation(BaseInvocation):
|
|||||||
return IntegerOutput(value=self.a + self.b)
|
return IntegerOutput(value=self.a + self.b)
|
||||||
|
|
||||||
|
|
||||||
@invocation("sub", title="Subtract Integers", tags=["math", "subtract"], category="math")
|
@invocation("sub", title="Subtract Integers", tags=["math", "subtract"], category="math", version="1.0.0")
|
||||||
class SubtractInvocation(BaseInvocation):
|
class SubtractInvocation(BaseInvocation):
|
||||||
"""Subtracts two numbers"""
|
"""Subtracts two numbers"""
|
||||||
|
|
||||||
@ -29,7 +29,7 @@ class SubtractInvocation(BaseInvocation):
|
|||||||
return IntegerOutput(value=self.a - self.b)
|
return IntegerOutput(value=self.a - self.b)
|
||||||
|
|
||||||
|
|
||||||
@invocation("mul", title="Multiply Integers", tags=["math", "multiply"], category="math")
|
@invocation("mul", title="Multiply Integers", tags=["math", "multiply"], category="math", version="1.0.0")
|
||||||
class MultiplyInvocation(BaseInvocation):
|
class MultiplyInvocation(BaseInvocation):
|
||||||
"""Multiplies two numbers"""
|
"""Multiplies two numbers"""
|
||||||
|
|
||||||
@ -40,7 +40,7 @@ class MultiplyInvocation(BaseInvocation):
|
|||||||
return IntegerOutput(value=self.a * self.b)
|
return IntegerOutput(value=self.a * self.b)
|
||||||
|
|
||||||
|
|
||||||
@invocation("div", title="Divide Integers", tags=["math", "divide"], category="math")
|
@invocation("div", title="Divide Integers", tags=["math", "divide"], category="math", version="1.0.0")
|
||||||
class DivideInvocation(BaseInvocation):
|
class DivideInvocation(BaseInvocation):
|
||||||
"""Divides two numbers"""
|
"""Divides two numbers"""
|
||||||
|
|
||||||
@ -51,7 +51,7 @@ class DivideInvocation(BaseInvocation):
|
|||||||
return IntegerOutput(value=int(self.a / self.b))
|
return IntegerOutput(value=int(self.a / self.b))
|
||||||
|
|
||||||
|
|
||||||
@invocation("rand_int", title="Random Integer", tags=["math", "random"], category="math")
|
@invocation("rand_int", title="Random Integer", tags=["math", "random"], category="math", version="1.0.0")
|
||||||
class RandomIntInvocation(BaseInvocation):
|
class RandomIntInvocation(BaseInvocation):
|
||||||
"""Outputs a single random integer."""
|
"""Outputs a single random integer."""
|
||||||
|
|
||||||
|
@ -98,7 +98,9 @@ class MetadataAccumulatorOutput(BaseInvocationOutput):
|
|||||||
metadata: CoreMetadata = OutputField(description="The core metadata for the image")
|
metadata: CoreMetadata = OutputField(description="The core metadata for the image")
|
||||||
|
|
||||||
|
|
||||||
@invocation("metadata_accumulator", title="Metadata Accumulator", tags=["metadata"], category="metadata")
|
@invocation(
|
||||||
|
"metadata_accumulator", title="Metadata Accumulator", tags=["metadata"], category="metadata", version="1.0.0"
|
||||||
|
)
|
||||||
class MetadataAccumulatorInvocation(BaseInvocation):
|
class MetadataAccumulatorInvocation(BaseInvocation):
|
||||||
"""Outputs a Core Metadata Object"""
|
"""Outputs a Core Metadata Object"""
|
||||||
|
|
||||||
|
@ -73,7 +73,7 @@ class LoRAModelField(BaseModel):
|
|||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
|
||||||
@invocation("main_model_loader", title="Main Model", tags=["model"], category="model")
|
@invocation("main_model_loader", title="Main Model", tags=["model"], category="model", version="1.0.0")
|
||||||
class MainModelLoaderInvocation(BaseInvocation):
|
class MainModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a main model, outputting its submodels."""
|
"""Loads a main model, outputting its submodels."""
|
||||||
|
|
||||||
@ -173,7 +173,7 @@ class LoraLoaderOutput(BaseInvocationOutput):
|
|||||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||||
|
|
||||||
|
|
||||||
@invocation("lora_loader", title="LoRA", tags=["model"], category="model")
|
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.0")
|
||||||
class LoraLoaderInvocation(BaseInvocation):
|
class LoraLoaderInvocation(BaseInvocation):
|
||||||
"""Apply selected lora to unet and text_encoder."""
|
"""Apply selected lora to unet and text_encoder."""
|
||||||
|
|
||||||
@ -244,7 +244,7 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
|||||||
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
||||||
|
|
||||||
|
|
||||||
@invocation("sdxl_lora_loader", title="SDXL LoRA", tags=["lora", "model"], category="model")
|
@invocation("sdxl_lora_loader", title="SDXL LoRA", tags=["lora", "model"], category="model", version="1.0.0")
|
||||||
class SDXLLoraLoaderInvocation(BaseInvocation):
|
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||||
"""Apply selected lora to unet and text_encoder."""
|
"""Apply selected lora to unet and text_encoder."""
|
||||||
|
|
||||||
@ -338,7 +338,7 @@ class VaeLoaderOutput(BaseInvocationOutput):
|
|||||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
|
|
||||||
|
|
||||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model")
|
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
|
||||||
class VaeLoaderInvocation(BaseInvocation):
|
class VaeLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||||
|
|
||||||
@ -376,7 +376,7 @@ class SeamlessModeOutput(BaseInvocationOutput):
|
|||||||
vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE")
|
vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
|
|
||||||
|
|
||||||
@invocation("seamless", title="Seamless", tags=["seamless", "model"], category="model")
|
@invocation("seamless", title="Seamless", tags=["seamless", "model"], category="model", version="1.0.0")
|
||||||
class SeamlessModeInvocation(BaseInvocation):
|
class SeamlessModeInvocation(BaseInvocation):
|
||||||
"""Applies the seamless transformation to the Model UNet and VAE."""
|
"""Applies the seamless transformation to the Model UNet and VAE."""
|
||||||
|
|
||||||
|
@ -78,7 +78,7 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("noise", title="Noise", tags=["latents", "noise"], category="latents")
|
@invocation("noise", title="Noise", tags=["latents", "noise"], category="latents", version="1.0.0")
|
||||||
class NoiseInvocation(BaseInvocation):
|
class NoiseInvocation(BaseInvocation):
|
||||||
"""Generates latent noise."""
|
"""Generates latent noise."""
|
||||||
|
|
||||||
|
@ -56,7 +56,7 @@ ORT_TO_NP_TYPE = {
|
|||||||
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
|
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
|
||||||
|
|
||||||
|
|
||||||
@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning")
|
@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0")
|
||||||
class ONNXPromptInvocation(BaseInvocation):
|
class ONNXPromptInvocation(BaseInvocation):
|
||||||
prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
|
prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
|
||||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||||
@ -143,6 +143,7 @@ class ONNXPromptInvocation(BaseInvocation):
|
|||||||
title="ONNX Text to Latents",
|
title="ONNX Text to Latents",
|
||||||
tags=["latents", "inference", "txt2img", "onnx"],
|
tags=["latents", "inference", "txt2img", "onnx"],
|
||||||
category="latents",
|
category="latents",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class ONNXTextToLatentsInvocation(BaseInvocation):
|
class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||||
"""Generates latents from conditionings."""
|
"""Generates latents from conditionings."""
|
||||||
@ -319,6 +320,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
|||||||
title="ONNX Latents to Image",
|
title="ONNX Latents to Image",
|
||||||
tags=["latents", "image", "vae", "onnx"],
|
tags=["latents", "image", "vae", "onnx"],
|
||||||
category="image",
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class ONNXLatentsToImageInvocation(BaseInvocation):
|
class ONNXLatentsToImageInvocation(BaseInvocation):
|
||||||
"""Generates an image from latents."""
|
"""Generates an image from latents."""
|
||||||
@ -403,7 +405,7 @@ class OnnxModelField(BaseModel):
|
|||||||
model_type: ModelType = Field(description="Model Type")
|
model_type: ModelType = Field(description="Model Type")
|
||||||
|
|
||||||
|
|
||||||
@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model")
|
@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0")
|
||||||
class OnnxModelLoaderInvocation(BaseInvocation):
|
class OnnxModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a main model, outputting its submodels."""
|
"""Loads a main model, outputting its submodels."""
|
||||||
|
|
||||||
|
@ -45,7 +45,7 @@ from invokeai.app.invocations.primitives import FloatCollectionOutput
|
|||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||||
|
|
||||||
|
|
||||||
@invocation("float_range", title="Float Range", tags=["math", "range"], category="math")
|
@invocation("float_range", title="Float Range", tags=["math", "range"], category="math", version="1.0.0")
|
||||||
class FloatLinearRangeInvocation(BaseInvocation):
|
class FloatLinearRangeInvocation(BaseInvocation):
|
||||||
"""Creates a range"""
|
"""Creates a range"""
|
||||||
|
|
||||||
@ -96,7 +96,7 @@ EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
|
|||||||
|
|
||||||
|
|
||||||
# actually I think for now could just use CollectionOutput (which is list[Any]
|
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||||
@invocation("step_param_easing", title="Step Param Easing", tags=["step", "easing"], category="step")
|
@invocation("step_param_easing", title="Step Param Easing", tags=["step", "easing"], category="step", version="1.0.0")
|
||||||
class StepParamEasingInvocation(BaseInvocation):
|
class StepParamEasingInvocation(BaseInvocation):
|
||||||
"""Experimental per-step parameter easing for denoising steps"""
|
"""Experimental per-step parameter easing for denoising steps"""
|
||||||
|
|
||||||
|
@ -14,7 +14,6 @@ from .baseinvocation import (
|
|||||||
InvocationContext,
|
InvocationContext,
|
||||||
OutputField,
|
OutputField,
|
||||||
UIComponent,
|
UIComponent,
|
||||||
UIType,
|
|
||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
@ -40,10 +39,14 @@ class BooleanOutput(BaseInvocationOutput):
|
|||||||
class BooleanCollectionOutput(BaseInvocationOutput):
|
class BooleanCollectionOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a collection of booleans"""
|
"""Base class for nodes that output a collection of booleans"""
|
||||||
|
|
||||||
collection: list[bool] = OutputField(description="The output boolean collection", ui_type=UIType.BooleanCollection)
|
collection: list[bool] = OutputField(
|
||||||
|
description="The output boolean collection",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives")
|
@invocation(
|
||||||
|
"boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives", version="1.0.0"
|
||||||
|
)
|
||||||
class BooleanInvocation(BaseInvocation):
|
class BooleanInvocation(BaseInvocation):
|
||||||
"""A boolean primitive value"""
|
"""A boolean primitive value"""
|
||||||
|
|
||||||
@ -58,13 +61,12 @@ class BooleanInvocation(BaseInvocation):
|
|||||||
title="Boolean Collection Primitive",
|
title="Boolean Collection Primitive",
|
||||||
tags=["primitives", "boolean", "collection"],
|
tags=["primitives", "boolean", "collection"],
|
||||||
category="primitives",
|
category="primitives",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class BooleanCollectionInvocation(BaseInvocation):
|
class BooleanCollectionInvocation(BaseInvocation):
|
||||||
"""A collection of boolean primitive values"""
|
"""A collection of boolean primitive values"""
|
||||||
|
|
||||||
collection: list[bool] = InputField(
|
collection: list[bool] = InputField(default_factory=list, description="The collection of boolean values")
|
||||||
default_factory=list, description="The collection of boolean values", ui_type=UIType.BooleanCollection
|
|
||||||
)
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
|
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
|
||||||
return BooleanCollectionOutput(collection=self.collection)
|
return BooleanCollectionOutput(collection=self.collection)
|
||||||
@ -86,10 +88,14 @@ class IntegerOutput(BaseInvocationOutput):
|
|||||||
class IntegerCollectionOutput(BaseInvocationOutput):
|
class IntegerCollectionOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a collection of integers"""
|
"""Base class for nodes that output a collection of integers"""
|
||||||
|
|
||||||
collection: list[int] = OutputField(description="The int collection", ui_type=UIType.IntegerCollection)
|
collection: list[int] = OutputField(
|
||||||
|
description="The int collection",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives")
|
@invocation(
|
||||||
|
"integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives", version="1.0.0"
|
||||||
|
)
|
||||||
class IntegerInvocation(BaseInvocation):
|
class IntegerInvocation(BaseInvocation):
|
||||||
"""An integer primitive value"""
|
"""An integer primitive value"""
|
||||||
|
|
||||||
@ -104,13 +110,12 @@ class IntegerInvocation(BaseInvocation):
|
|||||||
title="Integer Collection Primitive",
|
title="Integer Collection Primitive",
|
||||||
tags=["primitives", "integer", "collection"],
|
tags=["primitives", "integer", "collection"],
|
||||||
category="primitives",
|
category="primitives",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class IntegerCollectionInvocation(BaseInvocation):
|
class IntegerCollectionInvocation(BaseInvocation):
|
||||||
"""A collection of integer primitive values"""
|
"""A collection of integer primitive values"""
|
||||||
|
|
||||||
collection: list[int] = InputField(
|
collection: list[int] = InputField(default_factory=list, description="The collection of integer values")
|
||||||
default_factory=list, description="The collection of integer values", ui_type=UIType.IntegerCollection
|
|
||||||
)
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
||||||
return IntegerCollectionOutput(collection=self.collection)
|
return IntegerCollectionOutput(collection=self.collection)
|
||||||
@ -132,10 +137,12 @@ class FloatOutput(BaseInvocationOutput):
|
|||||||
class FloatCollectionOutput(BaseInvocationOutput):
|
class FloatCollectionOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a collection of floats"""
|
"""Base class for nodes that output a collection of floats"""
|
||||||
|
|
||||||
collection: list[float] = OutputField(description="The float collection", ui_type=UIType.FloatCollection)
|
collection: list[float] = OutputField(
|
||||||
|
description="The float collection",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives")
|
@invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives", version="1.0.0")
|
||||||
class FloatInvocation(BaseInvocation):
|
class FloatInvocation(BaseInvocation):
|
||||||
"""A float primitive value"""
|
"""A float primitive value"""
|
||||||
|
|
||||||
@ -150,13 +157,12 @@ class FloatInvocation(BaseInvocation):
|
|||||||
title="Float Collection Primitive",
|
title="Float Collection Primitive",
|
||||||
tags=["primitives", "float", "collection"],
|
tags=["primitives", "float", "collection"],
|
||||||
category="primitives",
|
category="primitives",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class FloatCollectionInvocation(BaseInvocation):
|
class FloatCollectionInvocation(BaseInvocation):
|
||||||
"""A collection of float primitive values"""
|
"""A collection of float primitive values"""
|
||||||
|
|
||||||
collection: list[float] = InputField(
|
collection: list[float] = InputField(default_factory=list, description="The collection of float values")
|
||||||
default_factory=list, description="The collection of float values", ui_type=UIType.FloatCollection
|
|
||||||
)
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||||
return FloatCollectionOutput(collection=self.collection)
|
return FloatCollectionOutput(collection=self.collection)
|
||||||
@ -178,10 +184,12 @@ class StringOutput(BaseInvocationOutput):
|
|||||||
class StringCollectionOutput(BaseInvocationOutput):
|
class StringCollectionOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a collection of strings"""
|
"""Base class for nodes that output a collection of strings"""
|
||||||
|
|
||||||
collection: list[str] = OutputField(description="The output strings", ui_type=UIType.StringCollection)
|
collection: list[str] = OutputField(
|
||||||
|
description="The output strings",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives")
|
@invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives", version="1.0.0")
|
||||||
class StringInvocation(BaseInvocation):
|
class StringInvocation(BaseInvocation):
|
||||||
"""A string primitive value"""
|
"""A string primitive value"""
|
||||||
|
|
||||||
@ -196,13 +204,12 @@ class StringInvocation(BaseInvocation):
|
|||||||
title="String Collection Primitive",
|
title="String Collection Primitive",
|
||||||
tags=["primitives", "string", "collection"],
|
tags=["primitives", "string", "collection"],
|
||||||
category="primitives",
|
category="primitives",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class StringCollectionInvocation(BaseInvocation):
|
class StringCollectionInvocation(BaseInvocation):
|
||||||
"""A collection of string primitive values"""
|
"""A collection of string primitive values"""
|
||||||
|
|
||||||
collection: list[str] = InputField(
|
collection: list[str] = InputField(default_factory=list, description="The collection of string values")
|
||||||
default_factory=list, description="The collection of string values", ui_type=UIType.StringCollection
|
|
||||||
)
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||||
return StringCollectionOutput(collection=self.collection)
|
return StringCollectionOutput(collection=self.collection)
|
||||||
@ -232,10 +239,12 @@ class ImageOutput(BaseInvocationOutput):
|
|||||||
class ImageCollectionOutput(BaseInvocationOutput):
|
class ImageCollectionOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a collection of images"""
|
"""Base class for nodes that output a collection of images"""
|
||||||
|
|
||||||
collection: list[ImageField] = OutputField(description="The output images", ui_type=UIType.ImageCollection)
|
collection: list[ImageField] = OutputField(
|
||||||
|
description="The output images",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives")
|
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.0")
|
||||||
class ImageInvocation(BaseInvocation):
|
class ImageInvocation(BaseInvocation):
|
||||||
"""An image primitive value"""
|
"""An image primitive value"""
|
||||||
|
|
||||||
@ -256,13 +265,12 @@ class ImageInvocation(BaseInvocation):
|
|||||||
title="Image Collection Primitive",
|
title="Image Collection Primitive",
|
||||||
tags=["primitives", "image", "collection"],
|
tags=["primitives", "image", "collection"],
|
||||||
category="primitives",
|
category="primitives",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class ImageCollectionInvocation(BaseInvocation):
|
class ImageCollectionInvocation(BaseInvocation):
|
||||||
"""A collection of image primitive values"""
|
"""A collection of image primitive values"""
|
||||||
|
|
||||||
collection: list[ImageField] = InputField(
|
collection: list[ImageField] = InputField(description="The collection of image values")
|
||||||
default_factory=list, description="The collection of image values", ui_type=UIType.ImageCollection
|
|
||||||
)
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
|
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
|
||||||
return ImageCollectionOutput(collection=self.collection)
|
return ImageCollectionOutput(collection=self.collection)
|
||||||
@ -316,11 +324,12 @@ class LatentsCollectionOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
collection: list[LatentsField] = OutputField(
|
collection: list[LatentsField] = OutputField(
|
||||||
description=FieldDescriptions.latents,
|
description=FieldDescriptions.latents,
|
||||||
ui_type=UIType.LatentsCollection,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives")
|
@invocation(
|
||||||
|
"latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.0"
|
||||||
|
)
|
||||||
class LatentsInvocation(BaseInvocation):
|
class LatentsInvocation(BaseInvocation):
|
||||||
"""A latents tensor primitive value"""
|
"""A latents tensor primitive value"""
|
||||||
|
|
||||||
@ -337,12 +346,13 @@ class LatentsInvocation(BaseInvocation):
|
|||||||
title="Latents Collection Primitive",
|
title="Latents Collection Primitive",
|
||||||
tags=["primitives", "latents", "collection"],
|
tags=["primitives", "latents", "collection"],
|
||||||
category="primitives",
|
category="primitives",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class LatentsCollectionInvocation(BaseInvocation):
|
class LatentsCollectionInvocation(BaseInvocation):
|
||||||
"""A collection of latents tensor primitive values"""
|
"""A collection of latents tensor primitive values"""
|
||||||
|
|
||||||
collection: list[LatentsField] = InputField(
|
collection: list[LatentsField] = InputField(
|
||||||
description="The collection of latents tensors", ui_type=UIType.LatentsCollection
|
description="The collection of latents tensors",
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
|
def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
|
||||||
@ -385,10 +395,12 @@ class ColorOutput(BaseInvocationOutput):
|
|||||||
class ColorCollectionOutput(BaseInvocationOutput):
|
class ColorCollectionOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a collection of colors"""
|
"""Base class for nodes that output a collection of colors"""
|
||||||
|
|
||||||
collection: list[ColorField] = OutputField(description="The output colors", ui_type=UIType.ColorCollection)
|
collection: list[ColorField] = OutputField(
|
||||||
|
description="The output colors",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives")
|
@invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives", version="1.0.0")
|
||||||
class ColorInvocation(BaseInvocation):
|
class ColorInvocation(BaseInvocation):
|
||||||
"""A color primitive value"""
|
"""A color primitive value"""
|
||||||
|
|
||||||
@ -422,7 +434,6 @@ class ConditioningCollectionOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
collection: list[ConditioningField] = OutputField(
|
collection: list[ConditioningField] = OutputField(
|
||||||
description="The output conditioning tensors",
|
description="The output conditioning tensors",
|
||||||
ui_type=UIType.ConditioningCollection,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -431,6 +442,7 @@ class ConditioningCollectionOutput(BaseInvocationOutput):
|
|||||||
title="Conditioning Primitive",
|
title="Conditioning Primitive",
|
||||||
tags=["primitives", "conditioning"],
|
tags=["primitives", "conditioning"],
|
||||||
category="primitives",
|
category="primitives",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class ConditioningInvocation(BaseInvocation):
|
class ConditioningInvocation(BaseInvocation):
|
||||||
"""A conditioning tensor primitive value"""
|
"""A conditioning tensor primitive value"""
|
||||||
@ -446,6 +458,7 @@ class ConditioningInvocation(BaseInvocation):
|
|||||||
title="Conditioning Collection Primitive",
|
title="Conditioning Collection Primitive",
|
||||||
tags=["primitives", "conditioning", "collection"],
|
tags=["primitives", "conditioning", "collection"],
|
||||||
category="primitives",
|
category="primitives",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class ConditioningCollectionInvocation(BaseInvocation):
|
class ConditioningCollectionInvocation(BaseInvocation):
|
||||||
"""A collection of conditioning tensor primitive values"""
|
"""A collection of conditioning tensor primitive values"""
|
||||||
@ -453,7 +466,6 @@ class ConditioningCollectionInvocation(BaseInvocation):
|
|||||||
collection: list[ConditioningField] = InputField(
|
collection: list[ConditioningField] = InputField(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
description="The collection of conditioning tensors",
|
description="The collection of conditioning tensors",
|
||||||
ui_type=UIType.ConditioningCollection,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:
|
||||||
|
@ -10,7 +10,7 @@ from invokeai.app.invocations.primitives import StringCollectionOutput
|
|||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation
|
||||||
|
|
||||||
|
|
||||||
@invocation("dynamic_prompt", title="Dynamic Prompt", tags=["prompt", "collection"], category="prompt")
|
@invocation("dynamic_prompt", title="Dynamic Prompt", tags=["prompt", "collection"], category="prompt", version="1.0.0")
|
||||||
class DynamicPromptInvocation(BaseInvocation):
|
class DynamicPromptInvocation(BaseInvocation):
|
||||||
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
|
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
|
||||||
|
|
||||||
@ -29,7 +29,7 @@ class DynamicPromptInvocation(BaseInvocation):
|
|||||||
return StringCollectionOutput(collection=prompts)
|
return StringCollectionOutput(collection=prompts)
|
||||||
|
|
||||||
|
|
||||||
@invocation("prompt_from_file", title="Prompts from File", tags=["prompt", "file"], category="prompt")
|
@invocation("prompt_from_file", title="Prompts from File", tags=["prompt", "file"], category="prompt", version="1.0.0")
|
||||||
class PromptsFromFileInvocation(BaseInvocation):
|
class PromptsFromFileInvocation(BaseInvocation):
|
||||||
"""Loads prompts from a text file"""
|
"""Loads prompts from a text file"""
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
|||||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
|
|
||||||
|
|
||||||
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model")
|
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.0")
|
||||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads an sdxl base model, outputting its submodels."""
|
"""Loads an sdxl base model, outputting its submodels."""
|
||||||
|
|
||||||
@ -119,6 +119,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
|||||||
title="SDXL Refiner Model",
|
title="SDXL Refiner Model",
|
||||||
tags=["model", "sdxl", "refiner"],
|
tags=["model", "sdxl", "refiner"],
|
||||||
category="model",
|
category="model",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||||
|
@ -23,7 +23,7 @@ ESRGAN_MODELS = Literal[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan")
|
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.0.0")
|
||||||
class ESRGANInvocation(BaseInvocation):
|
class ESRGANInvocation(BaseInvocation):
|
||||||
"""Upscales an image using RealESRGAN."""
|
"""Upscales an image using RealESRGAN."""
|
||||||
|
|
||||||
|
@ -112,6 +112,10 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
|
|||||||
if to_type in get_args(from_type):
|
if to_type in get_args(from_type):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
# allow int -> float, pydantic will cast for us
|
||||||
|
if from_type is int and to_type is float:
|
||||||
|
return True
|
||||||
|
|
||||||
# if not issubclass(from_type, to_type):
|
# if not issubclass(from_type, to_type):
|
||||||
if not is_union_subtype(from_type, to_type):
|
if not is_union_subtype(from_type, to_type):
|
||||||
return False
|
return False
|
||||||
|
20
invokeai/backend/image_util/cv2_inpaint.py
Normal file
20
invokeai/backend/image_util/cv2_inpaint.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def cv2_inpaint(image: Image.Image) -> Image.Image:
|
||||||
|
# Prepare Image
|
||||||
|
image_array = np.array(image.convert("RGB"))
|
||||||
|
image_cv = cv2.cvtColor(image_array, cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
|
# Prepare Mask From Alpha Channel
|
||||||
|
mask = image.split()[3].convert("RGB")
|
||||||
|
mask_array = np.array(mask)
|
||||||
|
mask_cv = cv2.cvtColor(mask_array, cv2.COLOR_BGR2GRAY)
|
||||||
|
mask_inv = cv2.bitwise_not(mask_cv)
|
||||||
|
|
||||||
|
# Inpaint Image
|
||||||
|
inpainted_result = cv2.inpaint(image_cv, mask_inv, 3, cv2.INPAINT_TELEA)
|
||||||
|
inpainted_image = Image.fromarray(cv2.cvtColor(inpainted_result, cv2.COLOR_BGR2RGB))
|
||||||
|
return inpainted_image
|
@ -5,6 +5,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import get_invokeai_config
|
||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.util.devices import choose_torch_device
|
||||||
|
|
||||||
@ -19,7 +20,7 @@ def norm_img(np_img):
|
|||||||
|
|
||||||
def load_jit_model(url_or_path, device):
|
def load_jit_model(url_or_path, device):
|
||||||
model_path = url_or_path
|
model_path = url_or_path
|
||||||
print(f"Loading model from: {model_path}")
|
logger.info(f"Loading model from: {model_path}")
|
||||||
model = torch.jit.load(model_path, map_location="cpu").to(device)
|
model = torch.jit.load(model_path, map_location="cpu").to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
return model
|
return model
|
||||||
@ -52,5 +53,6 @@ class LaMA:
|
|||||||
|
|
||||||
del model
|
del model
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
return infilled_image
|
return infilled_image
|
||||||
|
@ -290,9 +290,20 @@ def download_realesrgan():
|
|||||||
download_with_progress_bar(model["url"], config.models_path / model["dest"], model["description"])
|
download_with_progress_bar(model["url"], config.models_path / model["dest"], model["description"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------
|
||||||
|
def download_lama():
|
||||||
|
logger.info("Installing lama infill model")
|
||||||
|
download_with_progress_bar(
|
||||||
|
"https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||||
|
config.models_path / "core/misc/lama/lama.pt",
|
||||||
|
"lama infill model",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
def download_support_models():
|
def download_support_models():
|
||||||
download_realesrgan()
|
download_realesrgan()
|
||||||
|
download_lama()
|
||||||
download_conversion_models()
|
download_conversion_models()
|
||||||
|
|
||||||
|
|
||||||
|
@ -75,6 +75,7 @@
|
|||||||
"@reduxjs/toolkit": "^1.9.5",
|
"@reduxjs/toolkit": "^1.9.5",
|
||||||
"@roarr/browser-log-writer": "^1.1.5",
|
"@roarr/browser-log-writer": "^1.1.5",
|
||||||
"@stevebel/png": "^1.5.1",
|
"@stevebel/png": "^1.5.1",
|
||||||
|
"compare-versions": "^6.1.0",
|
||||||
"dateformat": "^5.0.3",
|
"dateformat": "^5.0.3",
|
||||||
"formik": "^2.4.3",
|
"formik": "^2.4.3",
|
||||||
"framer-motion": "^10.16.1",
|
"framer-motion": "^10.16.1",
|
||||||
|
@ -511,6 +511,7 @@
|
|||||||
"maskBlur": "Blur",
|
"maskBlur": "Blur",
|
||||||
"maskBlurMethod": "Blur Method",
|
"maskBlurMethod": "Blur Method",
|
||||||
"coherencePassHeader": "Coherence Pass",
|
"coherencePassHeader": "Coherence Pass",
|
||||||
|
"coherenceMode": "Mode",
|
||||||
"coherenceSteps": "Steps",
|
"coherenceSteps": "Steps",
|
||||||
"coherenceStrength": "Strength",
|
"coherenceStrength": "Strength",
|
||||||
"seamLowThreshold": "Low",
|
"seamLowThreshold": "Low",
|
||||||
@ -520,6 +521,7 @@
|
|||||||
"scaledHeight": "Scaled H",
|
"scaledHeight": "Scaled H",
|
||||||
"infillMethod": "Infill Method",
|
"infillMethod": "Infill Method",
|
||||||
"tileSize": "Tile Size",
|
"tileSize": "Tile Size",
|
||||||
|
"patchmatchDownScaleSize": "Downscale",
|
||||||
"boundingBoxHeader": "Bounding Box",
|
"boundingBoxHeader": "Bounding Box",
|
||||||
"seamCorrectionHeader": "Seam Correction",
|
"seamCorrectionHeader": "Seam Correction",
|
||||||
"infillScalingHeader": "Infill and Scaling",
|
"infillScalingHeader": "Infill and Scaling",
|
||||||
|
@ -84,6 +84,7 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
|
|||||||
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
||||||
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
||||||
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
||||||
|
import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
|
||||||
|
|
||||||
export const listenerMiddleware = createListenerMiddleware();
|
export const listenerMiddleware = createListenerMiddleware();
|
||||||
|
|
||||||
@ -202,6 +203,9 @@ addBoardIdSelectedListener();
|
|||||||
// Node schemas
|
// Node schemas
|
||||||
addReceivedOpenAPISchemaListener();
|
addReceivedOpenAPISchemaListener();
|
||||||
|
|
||||||
|
// Workflows
|
||||||
|
addWorkflowLoadedListener();
|
||||||
|
|
||||||
// DND
|
// DND
|
||||||
addImageDroppedListener();
|
addImageDroppedListener();
|
||||||
|
|
||||||
|
@ -0,0 +1,55 @@
|
|||||||
|
import { logger } from 'app/logging/logger';
|
||||||
|
import { workflowLoadRequested } from 'features/nodes/store/actions';
|
||||||
|
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
||||||
|
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||||
|
import { validateWorkflow } from 'features/nodes/util/validateWorkflow';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
|
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
|
export const addWorkflowLoadedListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: workflowLoadRequested,
|
||||||
|
effect: (action, { dispatch, getState }) => {
|
||||||
|
const log = logger('nodes');
|
||||||
|
const workflow = action.payload;
|
||||||
|
const nodeTemplates = getState().nodes.nodeTemplates;
|
||||||
|
|
||||||
|
const { workflow: validatedWorkflow, errors } = validateWorkflow(
|
||||||
|
workflow,
|
||||||
|
nodeTemplates
|
||||||
|
);
|
||||||
|
|
||||||
|
dispatch(workflowLoaded(validatedWorkflow));
|
||||||
|
|
||||||
|
if (!errors.length) {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: 'Workflow Loaded',
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: 'Workflow Loaded with Warnings',
|
||||||
|
status: 'warning',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
errors.forEach(({ message, ...rest }) => {
|
||||||
|
log.warn(rest, message);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(setActiveTab('nodes'));
|
||||||
|
requestAnimationFrame(() => {
|
||||||
|
$flow.get()?.fitView();
|
||||||
|
});
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -31,48 +31,54 @@ const selector = createSelector(
|
|||||||
reasons.push('No initial image selected');
|
reasons.push('No initial image selected');
|
||||||
}
|
}
|
||||||
|
|
||||||
if (activeTabName === 'nodes' && nodes.shouldValidateGraph) {
|
if (activeTabName === 'nodes') {
|
||||||
if (!nodes.nodes.length) {
|
if (nodes.shouldValidateGraph) {
|
||||||
reasons.push('No nodes in graph');
|
if (!nodes.nodes.length) {
|
||||||
}
|
reasons.push('No nodes in graph');
|
||||||
|
|
||||||
nodes.nodes.forEach((node) => {
|
|
||||||
if (!isInvocationNode(node)) {
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const nodeTemplate = nodes.nodeTemplates[node.data.type];
|
nodes.nodes.forEach((node) => {
|
||||||
|
if (!isInvocationNode(node)) {
|
||||||
if (!nodeTemplate) {
|
|
||||||
// Node type not found
|
|
||||||
reasons.push('Missing node template');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const connectedEdges = getConnectedEdges([node], nodes.edges);
|
|
||||||
|
|
||||||
forEach(node.data.inputs, (field) => {
|
|
||||||
const fieldTemplate = nodeTemplate.inputs[field.name];
|
|
||||||
const hasConnection = connectedEdges.some(
|
|
||||||
(edge) =>
|
|
||||||
edge.target === node.id && edge.targetHandle === field.name
|
|
||||||
);
|
|
||||||
|
|
||||||
if (!fieldTemplate) {
|
|
||||||
reasons.push('Missing field template');
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (fieldTemplate.required && !field.value && !hasConnection) {
|
const nodeTemplate = nodes.nodeTemplates[node.data.type];
|
||||||
reasons.push(
|
|
||||||
`${node.data.label || nodeTemplate.title} -> ${
|
if (!nodeTemplate) {
|
||||||
field.label || fieldTemplate.title
|
// Node type not found
|
||||||
} missing input`
|
reasons.push('Missing node template');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const connectedEdges = getConnectedEdges([node], nodes.edges);
|
||||||
|
|
||||||
|
forEach(node.data.inputs, (field) => {
|
||||||
|
const fieldTemplate = nodeTemplate.inputs[field.name];
|
||||||
|
const hasConnection = connectedEdges.some(
|
||||||
|
(edge) =>
|
||||||
|
edge.target === node.id && edge.targetHandle === field.name
|
||||||
);
|
);
|
||||||
return;
|
|
||||||
}
|
if (!fieldTemplate) {
|
||||||
|
reasons.push('Missing field template');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
fieldTemplate.required &&
|
||||||
|
field.value === undefined &&
|
||||||
|
!hasConnection
|
||||||
|
) {
|
||||||
|
reasons.push(
|
||||||
|
`${node.data.label || nodeTemplate.title} -> ${
|
||||||
|
field.label || fieldTemplate.title
|
||||||
|
} missing input`
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
});
|
||||||
});
|
});
|
||||||
});
|
}
|
||||||
} else {
|
} else {
|
||||||
if (!model) {
|
if (!model) {
|
||||||
reasons.push('No model selected');
|
reasons.push('No model selected');
|
||||||
|
@ -1,2 +1,2 @@
|
|||||||
export const colorTokenToCssVar = (colorToken: string) =>
|
export const colorTokenToCssVar = (colorToken: string) =>
|
||||||
`var(--invokeai-colors-${colorToken.split('.').join('-')}`;
|
`var(--invokeai-colors-${colorToken.split('.').join('-')})`;
|
||||||
|
@ -118,7 +118,11 @@ const IAICanvasToolChooserOptions = () => {
|
|||||||
useHotkeys(
|
useHotkeys(
|
||||||
['BracketLeft'],
|
['BracketLeft'],
|
||||||
() => {
|
() => {
|
||||||
dispatch(setBrushSize(Math.max(brushSize - 5, 5)));
|
if (brushSize - 5 <= 5) {
|
||||||
|
dispatch(setBrushSize(Math.max(brushSize - 1, 1)));
|
||||||
|
} else {
|
||||||
|
dispatch(setBrushSize(Math.max(brushSize - 5, 1)));
|
||||||
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
enabled: () => !isStaging,
|
enabled: () => !isStaging,
|
||||||
|
@ -17,16 +17,13 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton';
|
import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton';
|
||||||
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
||||||
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
import { workflowLoadRequested } from 'features/nodes/store/actions';
|
||||||
import ParamUpscalePopover from 'features/parameters/components/Parameters/Upscale/ParamUpscaleSettings';
|
import ParamUpscalePopover from 'features/parameters/components/Parameters/Upscale/ParamUpscaleSettings';
|
||||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import {
|
import {
|
||||||
setActiveTab,
|
|
||||||
setShouldShowImageDetails,
|
setShouldShowImageDetails,
|
||||||
setShouldShowProgressInViewer,
|
setShouldShowProgressInViewer,
|
||||||
} from 'features/ui/store/uiSlice';
|
} from 'features/ui/store/uiSlice';
|
||||||
@ -124,16 +121,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
if (!workflow) {
|
if (!workflow) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
dispatch(workflowLoaded(workflow));
|
dispatch(workflowLoadRequested(workflow));
|
||||||
dispatch(setActiveTab('nodes'));
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: 'Workflow Loaded',
|
|
||||||
status: 'success',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}, [dispatch, workflow]);
|
}, [dispatch, workflow]);
|
||||||
|
|
||||||
const handleClickUseAllParameters = useCallback(() => {
|
const handleClickUseAllParameters = useCallback(() => {
|
||||||
|
@ -7,12 +7,9 @@ import {
|
|||||||
isModalOpenChanged,
|
isModalOpenChanged,
|
||||||
} from 'features/changeBoardModal/store/slice';
|
} from 'features/changeBoardModal/store/slice';
|
||||||
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
||||||
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
|
||||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
|
||||||
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
|
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
|
||||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
@ -36,6 +33,7 @@ import {
|
|||||||
} from 'services/api/endpoints/images';
|
} from 'services/api/endpoints/images';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
|
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
|
||||||
|
import { workflowLoadRequested } from 'features/nodes/store/actions';
|
||||||
|
|
||||||
type SingleSelectionMenuItemsProps = {
|
type SingleSelectionMenuItemsProps = {
|
||||||
imageDTO: ImageDTO;
|
imageDTO: ImageDTO;
|
||||||
@ -102,16 +100,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
if (!workflow) {
|
if (!workflow) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
dispatch(workflowLoaded(workflow));
|
dispatch(workflowLoadRequested(workflow));
|
||||||
dispatch(setActiveTab('nodes'));
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: 'Workflow Loaded',
|
|
||||||
status: 'success',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}, [dispatch, workflow]);
|
}, [dispatch, workflow]);
|
||||||
|
|
||||||
const handleSendToImageToImage = useCallback(() => {
|
const handleSendToImageToImage = useCallback(() => {
|
||||||
|
@ -3,6 +3,7 @@ import { createSelector } from '@reduxjs/toolkit';
|
|||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||||
import { contextMenusClosed } from 'features/ui/store/uiSlice';
|
import { contextMenusClosed } from 'features/ui/store/uiSlice';
|
||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
@ -13,6 +14,7 @@ import {
|
|||||||
OnConnectStart,
|
OnConnectStart,
|
||||||
OnEdgesChange,
|
OnEdgesChange,
|
||||||
OnEdgesDelete,
|
OnEdgesDelete,
|
||||||
|
OnInit,
|
||||||
OnMoveEnd,
|
OnMoveEnd,
|
||||||
OnNodesChange,
|
OnNodesChange,
|
||||||
OnNodesDelete,
|
OnNodesDelete,
|
||||||
@ -147,6 +149,11 @@ export const Flow = () => {
|
|||||||
dispatch(contextMenusClosed());
|
dispatch(contextMenusClosed());
|
||||||
}, [dispatch]);
|
}, [dispatch]);
|
||||||
|
|
||||||
|
const onInit: OnInit = useCallback((flow) => {
|
||||||
|
$flow.set(flow);
|
||||||
|
flow.fitView();
|
||||||
|
}, []);
|
||||||
|
|
||||||
useHotkeys(['Ctrl+c', 'Meta+c'], (e) => {
|
useHotkeys(['Ctrl+c', 'Meta+c'], (e) => {
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
dispatch(selectionCopied());
|
dispatch(selectionCopied());
|
||||||
@ -170,6 +177,7 @@ export const Flow = () => {
|
|||||||
edgeTypes={edgeTypes}
|
edgeTypes={edgeTypes}
|
||||||
nodes={nodes}
|
nodes={nodes}
|
||||||
edges={edges}
|
edges={edges}
|
||||||
|
onInit={onInit}
|
||||||
onNodesChange={onNodesChange}
|
onNodesChange={onNodesChange}
|
||||||
onEdgesChange={onEdgesChange}
|
onEdgesChange={onEdgesChange}
|
||||||
onEdgesDelete={onEdgesDelete}
|
onEdgesDelete={onEdgesDelete}
|
||||||
|
@ -12,6 +12,7 @@ import {
|
|||||||
Tooltip,
|
Tooltip,
|
||||||
useDisclosure,
|
useDisclosure,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
|
import { compare } from 'compare-versions';
|
||||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||||
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
|
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
|
||||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||||
@ -20,6 +21,7 @@ import { isInvocationNodeData } from 'features/nodes/types/types';
|
|||||||
import { memo, useMemo } from 'react';
|
import { memo, useMemo } from 'react';
|
||||||
import { FaInfoCircle } from 'react-icons/fa';
|
import { FaInfoCircle } from 'react-icons/fa';
|
||||||
import NotesTextarea from './NotesTextarea';
|
import NotesTextarea from './NotesTextarea';
|
||||||
|
import { useDoNodeVersionsMatch } from 'features/nodes/hooks/useDoNodeVersionsMatch';
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
@ -29,6 +31,7 @@ const InvocationNodeNotes = ({ nodeId }: Props) => {
|
|||||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||||
const label = useNodeLabel(nodeId);
|
const label = useNodeLabel(nodeId);
|
||||||
const title = useNodeTemplateTitle(nodeId);
|
const title = useNodeTemplateTitle(nodeId);
|
||||||
|
const doVersionsMatch = useDoNodeVersionsMatch(nodeId);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
@ -50,7 +53,11 @@ const InvocationNodeNotes = ({ nodeId }: Props) => {
|
|||||||
>
|
>
|
||||||
<Icon
|
<Icon
|
||||||
as={FaInfoCircle}
|
as={FaInfoCircle}
|
||||||
sx={{ boxSize: 4, w: 8, color: 'base.400' }}
|
sx={{
|
||||||
|
boxSize: 4,
|
||||||
|
w: 8,
|
||||||
|
color: doVersionsMatch ? 'base.400' : 'error.400',
|
||||||
|
}}
|
||||||
/>
|
/>
|
||||||
</Flex>
|
</Flex>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
@ -92,16 +99,59 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
|
|||||||
return 'Unknown Node';
|
return 'Unknown Node';
|
||||||
}, [data, nodeTemplate]);
|
}, [data, nodeTemplate]);
|
||||||
|
|
||||||
|
const versionComponent = useMemo(() => {
|
||||||
|
if (!isInvocationNodeData(data) || !nodeTemplate) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!data.version) {
|
||||||
|
return (
|
||||||
|
<Text as="span" sx={{ color: 'error.500' }}>
|
||||||
|
Version unknown
|
||||||
|
</Text>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!nodeTemplate.version) {
|
||||||
|
return (
|
||||||
|
<Text as="span" sx={{ color: 'error.500' }}>
|
||||||
|
Version {data.version} (unknown template)
|
||||||
|
</Text>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (compare(data.version, nodeTemplate.version, '<')) {
|
||||||
|
return (
|
||||||
|
<Text as="span" sx={{ color: 'error.500' }}>
|
||||||
|
Version {data.version} (update node)
|
||||||
|
</Text>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (compare(data.version, nodeTemplate.version, '>')) {
|
||||||
|
return (
|
||||||
|
<Text as="span" sx={{ color: 'error.500' }}>
|
||||||
|
Version {data.version} (update app)
|
||||||
|
</Text>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return <Text as="span">Version {data.version}</Text>;
|
||||||
|
}, [data, nodeTemplate]);
|
||||||
|
|
||||||
if (!isInvocationNodeData(data)) {
|
if (!isInvocationNodeData(data)) {
|
||||||
return <Text sx={{ fontWeight: 600 }}>Unknown Node</Text>;
|
return <Text sx={{ fontWeight: 600 }}>Unknown Node</Text>;
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex sx={{ flexDir: 'column' }}>
|
<Flex sx={{ flexDir: 'column' }}>
|
||||||
<Text sx={{ fontWeight: 600 }}>{title}</Text>
|
<Text as="span" sx={{ fontWeight: 600 }}>
|
||||||
|
{title}
|
||||||
|
</Text>
|
||||||
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
|
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
|
||||||
{nodeTemplate?.description}
|
{nodeTemplate?.description}
|
||||||
</Text>
|
</Text>
|
||||||
|
{versionComponent}
|
||||||
{data?.notes && <Text>{data.notes}</Text>}
|
{data?.notes && <Text>{data.notes}</Text>}
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
|
@ -1,8 +1,11 @@
|
|||||||
import { Tooltip } from '@chakra-ui/react';
|
import { Tooltip } from '@chakra-ui/react';
|
||||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||||
import {
|
import {
|
||||||
|
COLLECTION_TYPES,
|
||||||
FIELDS,
|
FIELDS,
|
||||||
HANDLE_TOOLTIP_OPEN_DELAY,
|
HANDLE_TOOLTIP_OPEN_DELAY,
|
||||||
|
MODEL_TYPES,
|
||||||
|
POLYMORPHIC_TYPES,
|
||||||
} from 'features/nodes/types/constants';
|
} from 'features/nodes/types/constants';
|
||||||
import {
|
import {
|
||||||
InputFieldTemplate,
|
InputFieldTemplate,
|
||||||
@ -18,6 +21,7 @@ export const handleBaseStyles: CSSProperties = {
|
|||||||
borderWidth: 0,
|
borderWidth: 0,
|
||||||
zIndex: 1,
|
zIndex: 1,
|
||||||
};
|
};
|
||||||
|
``;
|
||||||
|
|
||||||
export const inputHandleStyles: CSSProperties = {
|
export const inputHandleStyles: CSSProperties = {
|
||||||
left: '-1rem',
|
left: '-1rem',
|
||||||
@ -44,15 +48,25 @@ const FieldHandle = (props: FieldHandleProps) => {
|
|||||||
connectionError,
|
connectionError,
|
||||||
} = props;
|
} = props;
|
||||||
const { name, type } = fieldTemplate;
|
const { name, type } = fieldTemplate;
|
||||||
const { color, title } = FIELDS[type];
|
const { color: typeColor, title } = FIELDS[type];
|
||||||
|
|
||||||
const styles: CSSProperties = useMemo(() => {
|
const styles: CSSProperties = useMemo(() => {
|
||||||
|
const isCollectionType = COLLECTION_TYPES.includes(type);
|
||||||
|
const isPolymorphicType = POLYMORPHIC_TYPES.includes(type);
|
||||||
|
const isModelType = MODEL_TYPES.includes(type);
|
||||||
|
const color = colorTokenToCssVar(typeColor);
|
||||||
const s: CSSProperties = {
|
const s: CSSProperties = {
|
||||||
backgroundColor: colorTokenToCssVar(color),
|
backgroundColor:
|
||||||
|
isCollectionType || isPolymorphicType
|
||||||
|
? 'var(--invokeai-colors-base-900)'
|
||||||
|
: color,
|
||||||
position: 'absolute',
|
position: 'absolute',
|
||||||
width: '1rem',
|
width: '1rem',
|
||||||
height: '1rem',
|
height: '1rem',
|
||||||
borderWidth: 0,
|
borderWidth: isCollectionType || isPolymorphicType ? 4 : 0,
|
||||||
|
borderStyle: 'solid',
|
||||||
|
borderColor: color,
|
||||||
|
borderRadius: isModelType ? 4 : '100%',
|
||||||
zIndex: 1,
|
zIndex: 1,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -78,11 +92,12 @@ const FieldHandle = (props: FieldHandleProps) => {
|
|||||||
|
|
||||||
return s;
|
return s;
|
||||||
}, [
|
}, [
|
||||||
color,
|
|
||||||
connectionError,
|
connectionError,
|
||||||
handleType,
|
handleType,
|
||||||
isConnectionInProgress,
|
isConnectionInProgress,
|
||||||
isConnectionStartField,
|
isConnectionStartField,
|
||||||
|
type,
|
||||||
|
typeColor,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
const tooltip = useMemo(() => {
|
const tooltip = useMemo(() => {
|
||||||
|
@ -75,6 +75,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
|||||||
sx={{
|
sx={{
|
||||||
display: 'flex',
|
display: 'flex',
|
||||||
alignItems: 'center',
|
alignItems: 'center',
|
||||||
|
h: 'full',
|
||||||
mb: 0,
|
mb: 0,
|
||||||
px: 1,
|
px: 1,
|
||||||
gap: 2,
|
gap: 2,
|
||||||
|
@ -3,18 +3,10 @@ import { useFieldData } from 'features/nodes/hooks/useFieldData';
|
|||||||
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import BooleanInputField from './inputs/BooleanInputField';
|
import BooleanInputField from './inputs/BooleanInputField';
|
||||||
import ClipInputField from './inputs/ClipInputField';
|
|
||||||
import CollectionInputField from './inputs/CollectionInputField';
|
|
||||||
import CollectionItemInputField from './inputs/CollectionItemInputField';
|
|
||||||
import ColorInputField from './inputs/ColorInputField';
|
import ColorInputField from './inputs/ColorInputField';
|
||||||
import ConditioningInputField from './inputs/ConditioningInputField';
|
|
||||||
import ControlInputField from './inputs/ControlInputField';
|
|
||||||
import ControlNetModelInputField from './inputs/ControlNetModelInputField';
|
import ControlNetModelInputField from './inputs/ControlNetModelInputField';
|
||||||
import DenoiseMaskInputField from './inputs/DenoiseMaskInputField';
|
|
||||||
import EnumInputField from './inputs/EnumInputField';
|
import EnumInputField from './inputs/EnumInputField';
|
||||||
import ImageCollectionInputField from './inputs/ImageCollectionInputField';
|
|
||||||
import ImageInputField from './inputs/ImageInputField';
|
import ImageInputField from './inputs/ImageInputField';
|
||||||
import LatentsInputField from './inputs/LatentsInputField';
|
|
||||||
import LoRAModelInputField from './inputs/LoRAModelInputField';
|
import LoRAModelInputField from './inputs/LoRAModelInputField';
|
||||||
import MainModelInputField from './inputs/MainModelInputField';
|
import MainModelInputField from './inputs/MainModelInputField';
|
||||||
import NumberInputField from './inputs/NumberInputField';
|
import NumberInputField from './inputs/NumberInputField';
|
||||||
@ -22,8 +14,6 @@ import RefinerModelInputField from './inputs/RefinerModelInputField';
|
|||||||
import SDXLMainModelInputField from './inputs/SDXLMainModelInputField';
|
import SDXLMainModelInputField from './inputs/SDXLMainModelInputField';
|
||||||
import SchedulerInputField from './inputs/SchedulerInputField';
|
import SchedulerInputField from './inputs/SchedulerInputField';
|
||||||
import StringInputField from './inputs/StringInputField';
|
import StringInputField from './inputs/StringInputField';
|
||||||
import UnetInputField from './inputs/UnetInputField';
|
|
||||||
import VaeInputField from './inputs/VaeInputField';
|
|
||||||
import VaeModelInputField from './inputs/VaeModelInputField';
|
import VaeModelInputField from './inputs/VaeModelInputField';
|
||||||
|
|
||||||
type InputFieldProps = {
|
type InputFieldProps = {
|
||||||
@ -31,7 +21,6 @@ type InputFieldProps = {
|
|||||||
fieldName: string;
|
fieldName: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
// build an individual input element based on the schema
|
|
||||||
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||||
const field = useFieldData(nodeId, fieldName);
|
const field = useFieldData(nodeId, fieldName);
|
||||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
|
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
|
||||||
@ -93,88 +82,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (
|
|
||||||
field?.type === 'LatentsField' &&
|
|
||||||
fieldTemplate?.type === 'LatentsField'
|
|
||||||
) {
|
|
||||||
return (
|
|
||||||
<LatentsInputField
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
fieldTemplate={fieldTemplate}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (
|
|
||||||
field?.type === 'DenoiseMaskField' &&
|
|
||||||
fieldTemplate?.type === 'DenoiseMaskField'
|
|
||||||
) {
|
|
||||||
return (
|
|
||||||
<DenoiseMaskInputField
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
fieldTemplate={fieldTemplate}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (
|
|
||||||
field?.type === 'ConditioningField' &&
|
|
||||||
fieldTemplate?.type === 'ConditioningField'
|
|
||||||
) {
|
|
||||||
return (
|
|
||||||
<ConditioningInputField
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
fieldTemplate={fieldTemplate}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (field?.type === 'UNetField' && fieldTemplate?.type === 'UNetField') {
|
|
||||||
return (
|
|
||||||
<UnetInputField
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
fieldTemplate={fieldTemplate}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (field?.type === 'ClipField' && fieldTemplate?.type === 'ClipField') {
|
|
||||||
return (
|
|
||||||
<ClipInputField
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
fieldTemplate={fieldTemplate}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (field?.type === 'VaeField' && fieldTemplate?.type === 'VaeField') {
|
|
||||||
return (
|
|
||||||
<VaeInputField
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
fieldTemplate={fieldTemplate}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (
|
|
||||||
field?.type === 'ControlField' &&
|
|
||||||
fieldTemplate?.type === 'ControlField'
|
|
||||||
) {
|
|
||||||
return (
|
|
||||||
<ControlInputField
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
fieldTemplate={fieldTemplate}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
field?.type === 'MainModelField' &&
|
field?.type === 'MainModelField' &&
|
||||||
fieldTemplate?.type === 'MainModelField'
|
fieldTemplate?.type === 'MainModelField'
|
||||||
@ -240,29 +147,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (field?.type === 'Collection' && fieldTemplate?.type === 'Collection') {
|
|
||||||
return (
|
|
||||||
<CollectionInputField
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
fieldTemplate={fieldTemplate}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (
|
|
||||||
field?.type === 'CollectionItem' &&
|
|
||||||
fieldTemplate?.type === 'CollectionItem'
|
|
||||||
) {
|
|
||||||
return (
|
|
||||||
<CollectionItemInputField
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
fieldTemplate={fieldTemplate}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
|
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
|
||||||
return (
|
return (
|
||||||
<ColorInputField
|
<ColorInputField
|
||||||
@ -273,19 +157,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (
|
|
||||||
field?.type === 'ImageCollection' &&
|
|
||||||
fieldTemplate?.type === 'ImageCollection'
|
|
||||||
) {
|
|
||||||
return (
|
|
||||||
<ImageCollectionInputField
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
fieldTemplate={fieldTemplate}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
field?.type === 'SDXLMainModelField' &&
|
field?.type === 'SDXLMainModelField' &&
|
||||||
fieldTemplate?.type === 'SDXLMainModelField'
|
fieldTemplate?.type === 'SDXLMainModelField'
|
||||||
@ -309,6 +180,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (field && fieldTemplate) {
|
||||||
|
// Fallback for when there is no component for the type
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Box p={1}>
|
<Box p={1}>
|
||||||
<Text
|
<Text
|
||||||
|
@ -1,12 +1,17 @@
|
|||||||
import {
|
import {
|
||||||
ControlInputFieldTemplate,
|
ControlInputFieldTemplate,
|
||||||
ControlInputFieldValue,
|
ControlInputFieldValue,
|
||||||
|
ControlPolymorphicInputFieldTemplate,
|
||||||
|
ControlPolymorphicInputFieldValue,
|
||||||
FieldComponentProps,
|
FieldComponentProps,
|
||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
|
|
||||||
const ControlInputFieldComponent = (
|
const ControlInputFieldComponent = (
|
||||||
_props: FieldComponentProps<ControlInputFieldValue, ControlInputFieldTemplate>
|
_props: FieldComponentProps<
|
||||||
|
ControlInputFieldValue | ControlPolymorphicInputFieldValue,
|
||||||
|
ControlInputFieldTemplate | ControlPolymorphicInputFieldTemplate
|
||||||
|
>
|
||||||
) => {
|
) => {
|
||||||
return null;
|
return null;
|
||||||
};
|
};
|
||||||
|
@ -9,9 +9,9 @@ import {
|
|||||||
} from 'features/dnd/types';
|
} from 'features/dnd/types';
|
||||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import {
|
import {
|
||||||
|
FieldComponentProps,
|
||||||
ImageInputFieldTemplate,
|
ImageInputFieldTemplate,
|
||||||
ImageInputFieldValue,
|
ImageInputFieldValue,
|
||||||
FieldComponentProps,
|
|
||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { FaUndo } from 'react-icons/fa';
|
import { FaUndo } from 'react-icons/fa';
|
||||||
|
@ -2,11 +2,16 @@ import {
|
|||||||
LatentsInputFieldTemplate,
|
LatentsInputFieldTemplate,
|
||||||
LatentsInputFieldValue,
|
LatentsInputFieldValue,
|
||||||
FieldComponentProps,
|
FieldComponentProps,
|
||||||
|
LatentsPolymorphicInputFieldValue,
|
||||||
|
LatentsPolymorphicInputFieldTemplate,
|
||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
|
|
||||||
const LatentsInputFieldComponent = (
|
const LatentsInputFieldComponent = (
|
||||||
_props: FieldComponentProps<LatentsInputFieldValue, LatentsInputFieldTemplate>
|
_props: FieldComponentProps<
|
||||||
|
LatentsInputFieldValue | LatentsPolymorphicInputFieldValue,
|
||||||
|
LatentsInputFieldTemplate | LatentsPolymorphicInputFieldTemplate
|
||||||
|
>
|
||||||
) => {
|
) => {
|
||||||
return null;
|
return null;
|
||||||
};
|
};
|
||||||
|
@ -9,11 +9,11 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
|||||||
import { numberStringRegex } from 'common/components/IAINumberInput';
|
import { numberStringRegex } from 'common/components/IAINumberInput';
|
||||||
import { fieldNumberValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldNumberValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import {
|
import {
|
||||||
|
FieldComponentProps,
|
||||||
FloatInputFieldTemplate,
|
FloatInputFieldTemplate,
|
||||||
FloatInputFieldValue,
|
FloatInputFieldValue,
|
||||||
IntegerInputFieldTemplate,
|
IntegerInputFieldTemplate,
|
||||||
IntegerInputFieldValue,
|
IntegerInputFieldValue,
|
||||||
FieldComponentProps,
|
|
||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
import { memo, useEffect, useMemo, useState } from 'react';
|
import { memo, useEffect, useMemo, useState } from 'react';
|
||||||
|
|
||||||
|
@ -138,13 +138,14 @@ export const useBuildNodeData = () => {
|
|||||||
data: {
|
data: {
|
||||||
id: nodeId,
|
id: nodeId,
|
||||||
type,
|
type,
|
||||||
inputs,
|
version: template.version,
|
||||||
outputs,
|
|
||||||
isOpen: true,
|
|
||||||
label: '',
|
label: '',
|
||||||
notes: '',
|
notes: '',
|
||||||
|
isOpen: true,
|
||||||
embedWorkflow: false,
|
embedWorkflow: false,
|
||||||
isIntermediate: true,
|
isIntermediate: true,
|
||||||
|
inputs,
|
||||||
|
outputs,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -0,0 +1,33 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import { compareVersions } from 'compare-versions';
|
||||||
|
import { useMemo } from 'react';
|
||||||
|
import { isInvocationNode } from '../types/types';
|
||||||
|
|
||||||
|
export const useDoNodeVersionsMatch = (nodeId: string) => {
|
||||||
|
const selector = useMemo(
|
||||||
|
() =>
|
||||||
|
createSelector(
|
||||||
|
stateSelector,
|
||||||
|
({ nodes }) => {
|
||||||
|
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||||
|
if (!isInvocationNode(node)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
|
||||||
|
if (!nodeTemplate?.version || !node.data?.version) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return compareVersions(nodeTemplate.version, node.data.version) === 0;
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
),
|
||||||
|
[nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
const nodeTemplate = useAppSelector(selector);
|
||||||
|
|
||||||
|
return nodeTemplate;
|
||||||
|
};
|
@ -15,7 +15,7 @@ export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => {
|
|||||||
if (!isInvocationNode(node)) {
|
if (!isInvocationNode(node)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
return Boolean(node?.data.inputs[fieldName]?.value);
|
return node?.data.inputs[fieldName]?.value !== undefined;
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
),
|
),
|
||||||
|
@ -3,9 +3,19 @@ import graphlib from '@dagrejs/graphlib';
|
|||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import { Connection, Edge, Node, useReactFlow } from 'reactflow';
|
import { Connection, Edge, Node, useReactFlow } from 'reactflow';
|
||||||
import { COLLECTION_TYPES } from '../types/constants';
|
import {
|
||||||
|
COLLECTION_MAP,
|
||||||
|
COLLECTION_TYPES,
|
||||||
|
POLYMORPHIC_TO_SINGLE_MAP,
|
||||||
|
POLYMORPHIC_TYPES,
|
||||||
|
} from '../types/constants';
|
||||||
import { InvocationNodeData } from '../types/types';
|
import { InvocationNodeData } from '../types/types';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts`
|
||||||
|
* TODO: Figure out how to do this without duplicating all the logic
|
||||||
|
*/
|
||||||
|
|
||||||
export const useIsValidConnection = () => {
|
export const useIsValidConnection = () => {
|
||||||
const flow = useReactFlow();
|
const flow = useReactFlow();
|
||||||
const shouldValidateGraph = useAppSelector(
|
const shouldValidateGraph = useAppSelector(
|
||||||
@ -42,6 +52,19 @@ export const useIsValidConnection = () => {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
edges
|
||||||
|
.filter((edge) => {
|
||||||
|
return edge.target === target && edge.targetHandle === targetHandle;
|
||||||
|
})
|
||||||
|
.find((edge) => {
|
||||||
|
edge.source === source && edge.sourceHandle === sourceHandle;
|
||||||
|
})
|
||||||
|
) {
|
||||||
|
// We already have a connection from this source to this target
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
// Connection is invalid if target already has a connection
|
// Connection is invalid if target already has a connection
|
||||||
if (
|
if (
|
||||||
edges.find((edge) => {
|
edges.find((edge) => {
|
||||||
@ -53,21 +76,62 @@ export const useIsValidConnection = () => {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connection types must be the same for a connection
|
/**
|
||||||
if (
|
* Connection types must be the same for a connection, with exceptions:
|
||||||
sourceType !== targetType &&
|
* - CollectionItem can connect to any non-Collection
|
||||||
sourceType !== 'CollectionItem' &&
|
* - Non-Collections can connect to CollectionItem
|
||||||
targetType !== 'CollectionItem'
|
* - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type
|
||||||
) {
|
* - Generic Collection can connect to any other Collection or Polymorphic
|
||||||
if (
|
* - Any Collection can connect to a Generic Collection
|
||||||
!(
|
*/
|
||||||
COLLECTION_TYPES.includes(targetType) &&
|
|
||||||
COLLECTION_TYPES.includes(sourceType)
|
if (sourceType !== targetType) {
|
||||||
)
|
const isCollectionItemToNonCollection =
|
||||||
) {
|
sourceType === 'CollectionItem' &&
|
||||||
return false;
|
!COLLECTION_TYPES.includes(targetType);
|
||||||
}
|
|
||||||
|
const isNonCollectionToCollectionItem =
|
||||||
|
targetType === 'CollectionItem' &&
|
||||||
|
!COLLECTION_TYPES.includes(sourceType) &&
|
||||||
|
!POLYMORPHIC_TYPES.includes(sourceType);
|
||||||
|
|
||||||
|
const isAnythingToPolymorphicOfSameBaseType =
|
||||||
|
POLYMORPHIC_TYPES.includes(targetType) &&
|
||||||
|
(() => {
|
||||||
|
if (!POLYMORPHIC_TYPES.includes(targetType)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const baseType =
|
||||||
|
POLYMORPHIC_TO_SINGLE_MAP[
|
||||||
|
targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP
|
||||||
|
];
|
||||||
|
|
||||||
|
const collectionType =
|
||||||
|
COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP];
|
||||||
|
|
||||||
|
return sourceType === baseType || sourceType === collectionType;
|
||||||
|
})();
|
||||||
|
|
||||||
|
const isGenericCollectionToAnyCollectionOrPolymorphic =
|
||||||
|
sourceType === 'Collection' &&
|
||||||
|
(COLLECTION_TYPES.includes(targetType) ||
|
||||||
|
POLYMORPHIC_TYPES.includes(targetType));
|
||||||
|
|
||||||
|
const isCollectionToGenericCollection =
|
||||||
|
targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);
|
||||||
|
|
||||||
|
const isIntToFloat = sourceType === 'integer' && targetType === 'float';
|
||||||
|
|
||||||
|
return (
|
||||||
|
isCollectionItemToNonCollection ||
|
||||||
|
isNonCollectionToCollectionItem ||
|
||||||
|
isAnythingToPolymorphicOfSameBaseType ||
|
||||||
|
isGenericCollectionToAnyCollectionOrPolymorphic ||
|
||||||
|
isCollectionToGenericCollection ||
|
||||||
|
isIntToFloat
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Graphs much be acyclic (no loops!)
|
// Graphs much be acyclic (no loops!)
|
||||||
return getIsGraphAcyclic(source, target, nodes, edges);
|
return getIsGraphAcyclic(source, target, nodes, edges);
|
||||||
},
|
},
|
||||||
|
@ -2,13 +2,13 @@ import { ListItem, Text, UnorderedList } from '@chakra-ui/react';
|
|||||||
import { useLogger } from 'app/logging/useLogger';
|
import { useLogger } from 'app/logging/useLogger';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { parseify } from 'common/util/serialize';
|
import { parseify } from 'common/util/serialize';
|
||||||
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
import { zWorkflow } from 'features/nodes/types/types';
|
||||||
import { zValidatedWorkflow } from 'features/nodes/types/types';
|
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { ZodError } from 'zod';
|
import { ZodError } from 'zod';
|
||||||
import { fromZodError, fromZodIssue } from 'zod-validation-error';
|
import { fromZodError, fromZodIssue } from 'zod-validation-error';
|
||||||
|
import { workflowLoadRequested } from '../store/actions';
|
||||||
|
|
||||||
export const useLoadWorkflowFromFile = () => {
|
export const useLoadWorkflowFromFile = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
@ -24,7 +24,7 @@ export const useLoadWorkflowFromFile = () => {
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
const parsedJSON = JSON.parse(String(rawJSON));
|
const parsedJSON = JSON.parse(String(rawJSON));
|
||||||
const result = zValidatedWorkflow.safeParse(parsedJSON);
|
const result = zWorkflow.safeParse(parsedJSON);
|
||||||
|
|
||||||
if (!result.success) {
|
if (!result.success) {
|
||||||
const { message } = fromZodError(result.error, {
|
const { message } = fromZodError(result.error, {
|
||||||
@ -45,32 +45,8 @@ export const useLoadWorkflowFromFile = () => {
|
|||||||
reader.abort();
|
reader.abort();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
dispatch(workflowLoaded(result.data.workflow));
|
|
||||||
|
|
||||||
if (!result.data.warnings.length) {
|
dispatch(workflowLoadRequested(result.data));
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: 'Workflow Loaded',
|
|
||||||
status: 'success',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
reader.abort();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: 'Workflow Loaded with Warnings',
|
|
||||||
status: 'warning',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
result.data.warnings.forEach(({ message, ...rest }) => {
|
|
||||||
logger.warn(rest, message);
|
|
||||||
});
|
|
||||||
|
|
||||||
reader.abort();
|
reader.abort();
|
||||||
} catch {
|
} catch {
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import { createAction, isAnyOf } from '@reduxjs/toolkit';
|
import { createAction, isAnyOf } from '@reduxjs/toolkit';
|
||||||
import { Graph } from 'services/api/types';
|
import { Graph } from 'services/api/types';
|
||||||
|
import { Workflow } from '../types/types';
|
||||||
|
|
||||||
export const textToImageGraphBuilt = createAction<Graph>(
|
export const textToImageGraphBuilt = createAction<Graph>(
|
||||||
'nodes/textToImageGraphBuilt'
|
'nodes/textToImageGraphBuilt'
|
||||||
@ -16,3 +17,7 @@ export const isAnyGraphBuilt = isAnyOf(
|
|||||||
canvasGraphBuilt,
|
canvasGraphBuilt,
|
||||||
nodesGraphBuilt
|
nodesGraphBuilt
|
||||||
);
|
);
|
||||||
|
|
||||||
|
export const workflowLoadRequested = createAction<Workflow>(
|
||||||
|
'nodes/workflowLoadRequested'
|
||||||
|
);
|
||||||
|
@ -0,0 +1,4 @@
|
|||||||
|
import { atom } from 'nanostores';
|
||||||
|
import { ReactFlowInstance } from 'reactflow';
|
||||||
|
|
||||||
|
export const $flow = atom<ReactFlowInstance | null>(null);
|
@ -1,10 +1,20 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { getIsGraphAcyclic } from 'features/nodes/hooks/useIsValidConnection';
|
import { getIsGraphAcyclic } from 'features/nodes/hooks/useIsValidConnection';
|
||||||
import { COLLECTION_TYPES } from 'features/nodes/types/constants';
|
import {
|
||||||
|
COLLECTION_MAP,
|
||||||
|
COLLECTION_TYPES,
|
||||||
|
POLYMORPHIC_TO_SINGLE_MAP,
|
||||||
|
POLYMORPHIC_TYPES,
|
||||||
|
} from 'features/nodes/types/constants';
|
||||||
import { FieldType } from 'features/nodes/types/types';
|
import { FieldType } from 'features/nodes/types/types';
|
||||||
import { HandleType } from 'reactflow';
|
import { HandleType } from 'reactflow';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts`
|
||||||
|
* TODO: Figure out how to do this without duplicating all the logic
|
||||||
|
*/
|
||||||
|
|
||||||
export const makeConnectionErrorSelector = (
|
export const makeConnectionErrorSelector = (
|
||||||
nodeId: string,
|
nodeId: string,
|
||||||
fieldName: string,
|
fieldName: string,
|
||||||
@ -19,11 +29,6 @@ export const makeConnectionErrorSelector = (
|
|||||||
const { currentConnectionFieldType, connectionStartParams, nodes, edges } =
|
const { currentConnectionFieldType, connectionStartParams, nodes, edges } =
|
||||||
state.nodes;
|
state.nodes;
|
||||||
|
|
||||||
if (!state.nodes.shouldValidateGraph) {
|
|
||||||
// manual override!
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!connectionStartParams || !currentConnectionFieldType) {
|
if (!connectionStartParams || !currentConnectionFieldType) {
|
||||||
return 'No connection in progress';
|
return 'No connection in progress';
|
||||||
}
|
}
|
||||||
@ -38,9 +43,9 @@ export const makeConnectionErrorSelector = (
|
|||||||
return 'No connection data';
|
return 'No connection data';
|
||||||
}
|
}
|
||||||
|
|
||||||
const targetFieldType =
|
const targetType =
|
||||||
handleType === 'target' ? fieldType : currentConnectionFieldType;
|
handleType === 'target' ? fieldType : currentConnectionFieldType;
|
||||||
const sourceFieldType =
|
const sourceType =
|
||||||
handleType === 'source' ? fieldType : currentConnectionFieldType;
|
handleType === 'source' ? fieldType : currentConnectionFieldType;
|
||||||
|
|
||||||
if (nodeId === connectionNodeId) {
|
if (nodeId === connectionNodeId) {
|
||||||
@ -55,30 +60,73 @@ export const makeConnectionErrorSelector = (
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
fieldType !== currentConnectionFieldType &&
|
|
||||||
fieldType !== 'CollectionItem' &&
|
|
||||||
currentConnectionFieldType !== 'CollectionItem'
|
|
||||||
) {
|
|
||||||
if (
|
|
||||||
!(
|
|
||||||
COLLECTION_TYPES.includes(targetFieldType) &&
|
|
||||||
COLLECTION_TYPES.includes(sourceFieldType)
|
|
||||||
)
|
|
||||||
) {
|
|
||||||
// except for collection items, field types must match
|
|
||||||
return 'Field types must match';
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (
|
|
||||||
handleType === 'target' &&
|
|
||||||
edges.find((edge) => {
|
edges.find((edge) => {
|
||||||
return edge.target === nodeId && edge.targetHandle === fieldName;
|
return edge.target === nodeId && edge.targetHandle === fieldName;
|
||||||
}) &&
|
}) &&
|
||||||
// except CollectionItem inputs can have multiples
|
// except CollectionItem inputs can have multiples
|
||||||
targetFieldType !== 'CollectionItem'
|
targetType !== 'CollectionItem'
|
||||||
) {
|
) {
|
||||||
return 'Inputs may only have one connection';
|
return 'Input may only have one connection';
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Connection types must be the same for a connection, with exceptions:
|
||||||
|
* - CollectionItem can connect to any non-Collection
|
||||||
|
* - Non-Collections can connect to CollectionItem
|
||||||
|
* - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type
|
||||||
|
* - Generic Collection can connect to any other Collection or Polymorphic
|
||||||
|
* - Any Collection can connect to a Generic Collection
|
||||||
|
*/
|
||||||
|
|
||||||
|
if (sourceType !== targetType) {
|
||||||
|
const isCollectionItemToNonCollection =
|
||||||
|
sourceType === 'CollectionItem' &&
|
||||||
|
!COLLECTION_TYPES.includes(targetType);
|
||||||
|
|
||||||
|
const isNonCollectionToCollectionItem =
|
||||||
|
targetType === 'CollectionItem' &&
|
||||||
|
!COLLECTION_TYPES.includes(sourceType) &&
|
||||||
|
!POLYMORPHIC_TYPES.includes(sourceType);
|
||||||
|
|
||||||
|
const isAnythingToPolymorphicOfSameBaseType =
|
||||||
|
POLYMORPHIC_TYPES.includes(targetType) &&
|
||||||
|
(() => {
|
||||||
|
if (!POLYMORPHIC_TYPES.includes(targetType)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const baseType =
|
||||||
|
POLYMORPHIC_TO_SINGLE_MAP[
|
||||||
|
targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP
|
||||||
|
];
|
||||||
|
|
||||||
|
const collectionType =
|
||||||
|
COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP];
|
||||||
|
|
||||||
|
return sourceType === baseType || sourceType === collectionType;
|
||||||
|
})();
|
||||||
|
|
||||||
|
const isGenericCollectionToAnyCollectionOrPolymorphic =
|
||||||
|
sourceType === 'Collection' &&
|
||||||
|
(COLLECTION_TYPES.includes(targetType) ||
|
||||||
|
POLYMORPHIC_TYPES.includes(targetType));
|
||||||
|
|
||||||
|
const isCollectionToGenericCollection =
|
||||||
|
targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);
|
||||||
|
|
||||||
|
const isIntToFloat = sourceType === 'integer' && targetType === 'float';
|
||||||
|
|
||||||
|
if (
|
||||||
|
!(
|
||||||
|
isCollectionItemToNonCollection ||
|
||||||
|
isNonCollectionToCollectionItem ||
|
||||||
|
isAnythingToPolymorphicOfSameBaseType ||
|
||||||
|
isGenericCollectionToAnyCollectionOrPolymorphic ||
|
||||||
|
isCollectionToGenericCollection ||
|
||||||
|
isIntToFloat
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
return 'Field types must match';
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const isGraphAcyclic = getIsGraphAcyclic(
|
const isGraphAcyclic = getIsGraphAcyclic(
|
||||||
|
@ -17,176 +17,297 @@ export const KIND_MAP = {
|
|||||||
export const COLLECTION_TYPES: FieldType[] = [
|
export const COLLECTION_TYPES: FieldType[] = [
|
||||||
'Collection',
|
'Collection',
|
||||||
'IntegerCollection',
|
'IntegerCollection',
|
||||||
|
'BooleanCollection',
|
||||||
'FloatCollection',
|
'FloatCollection',
|
||||||
'StringCollection',
|
'StringCollection',
|
||||||
'BooleanCollection',
|
|
||||||
'ImageCollection',
|
'ImageCollection',
|
||||||
|
'LatentsCollection',
|
||||||
|
'ConditioningCollection',
|
||||||
|
'ControlCollection',
|
||||||
|
'ColorCollection',
|
||||||
];
|
];
|
||||||
|
|
||||||
|
export const POLYMORPHIC_TYPES = [
|
||||||
|
'IntegerPolymorphic',
|
||||||
|
'BooleanPolymorphic',
|
||||||
|
'FloatPolymorphic',
|
||||||
|
'StringPolymorphic',
|
||||||
|
'ImagePolymorphic',
|
||||||
|
'LatentsPolymorphic',
|
||||||
|
'ConditioningPolymorphic',
|
||||||
|
'ControlPolymorphic',
|
||||||
|
'ColorPolymorphic',
|
||||||
|
];
|
||||||
|
|
||||||
|
export const MODEL_TYPES = [
|
||||||
|
'ControlNetModelField',
|
||||||
|
'LoRAModelField',
|
||||||
|
'MainModelField',
|
||||||
|
'ONNXModelField',
|
||||||
|
'SDXLMainModelField',
|
||||||
|
'SDXLRefinerModelField',
|
||||||
|
'VaeModelField',
|
||||||
|
'UNetField',
|
||||||
|
'VaeField',
|
||||||
|
'ClipField',
|
||||||
|
];
|
||||||
|
|
||||||
|
export const COLLECTION_MAP = {
|
||||||
|
integer: 'IntegerCollection',
|
||||||
|
boolean: 'BooleanCollection',
|
||||||
|
number: 'FloatCollection',
|
||||||
|
float: 'FloatCollection',
|
||||||
|
string: 'StringCollection',
|
||||||
|
ImageField: 'ImageCollection',
|
||||||
|
LatentsField: 'LatentsCollection',
|
||||||
|
ConditioningField: 'ConditioningCollection',
|
||||||
|
ControlField: 'ControlCollection',
|
||||||
|
ColorField: 'ColorCollection',
|
||||||
|
};
|
||||||
|
export const isCollectionItemType = (
|
||||||
|
itemType: string | undefined
|
||||||
|
): itemType is keyof typeof COLLECTION_MAP =>
|
||||||
|
Boolean(itemType && itemType in COLLECTION_MAP);
|
||||||
|
|
||||||
|
export const SINGLE_TO_POLYMORPHIC_MAP = {
|
||||||
|
integer: 'IntegerPolymorphic',
|
||||||
|
boolean: 'BooleanPolymorphic',
|
||||||
|
number: 'FloatPolymorphic',
|
||||||
|
float: 'FloatPolymorphic',
|
||||||
|
string: 'StringPolymorphic',
|
||||||
|
ImageField: 'ImagePolymorphic',
|
||||||
|
LatentsField: 'LatentsPolymorphic',
|
||||||
|
ConditioningField: 'ConditioningPolymorphic',
|
||||||
|
ControlField: 'ControlPolymorphic',
|
||||||
|
ColorField: 'ColorPolymorphic',
|
||||||
|
};
|
||||||
|
|
||||||
|
export const POLYMORPHIC_TO_SINGLE_MAP = {
|
||||||
|
IntegerPolymorphic: 'integer',
|
||||||
|
BooleanPolymorphic: 'boolean',
|
||||||
|
FloatPolymorphic: 'float',
|
||||||
|
StringPolymorphic: 'string',
|
||||||
|
ImagePolymorphic: 'ImageField',
|
||||||
|
LatentsPolymorphic: 'LatentsField',
|
||||||
|
ConditioningPolymorphic: 'ConditioningField',
|
||||||
|
ControlPolymorphic: 'ControlField',
|
||||||
|
ColorPolymorphic: 'ColorField',
|
||||||
|
};
|
||||||
|
|
||||||
|
export const isPolymorphicItemType = (
|
||||||
|
itemType: string | undefined
|
||||||
|
): itemType is keyof typeof SINGLE_TO_POLYMORPHIC_MAP =>
|
||||||
|
Boolean(itemType && itemType in SINGLE_TO_POLYMORPHIC_MAP);
|
||||||
|
|
||||||
export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||||
integer: {
|
|
||||||
title: 'Integer',
|
|
||||||
description: 'Integers are whole numbers, without a decimal point.',
|
|
||||||
color: 'red.500',
|
|
||||||
},
|
|
||||||
float: {
|
|
||||||
title: 'Float',
|
|
||||||
description: 'Floats are numbers with a decimal point.',
|
|
||||||
color: 'orange.500',
|
|
||||||
},
|
|
||||||
string: {
|
|
||||||
title: 'String',
|
|
||||||
description: 'Strings are text.',
|
|
||||||
color: 'yellow.500',
|
|
||||||
},
|
|
||||||
boolean: {
|
boolean: {
|
||||||
title: 'Boolean',
|
|
||||||
color: 'green.500',
|
color: 'green.500',
|
||||||
description: 'Booleans are true or false.',
|
description: 'Booleans are true or false.',
|
||||||
|
title: 'Boolean',
|
||||||
},
|
},
|
||||||
enum: {
|
BooleanCollection: {
|
||||||
title: 'Enum',
|
color: 'green.500',
|
||||||
description: 'Enums are values that may be one of a number of options.',
|
description: 'A collection of booleans.',
|
||||||
color: 'blue.500',
|
title: 'Boolean Collection',
|
||||||
},
|
},
|
||||||
array: {
|
BooleanPolymorphic: {
|
||||||
title: 'Array',
|
color: 'green.500',
|
||||||
description: 'Enums are values that may be one of a number of options.',
|
description: 'A collection of booleans.',
|
||||||
color: 'base.500',
|
title: 'Boolean Polymorphic',
|
||||||
},
|
|
||||||
ImageField: {
|
|
||||||
title: 'Image',
|
|
||||||
description: 'Images may be passed between nodes.',
|
|
||||||
color: 'purple.500',
|
|
||||||
},
|
|
||||||
DenoiseMaskField: {
|
|
||||||
title: 'Denoise Mask',
|
|
||||||
description: 'Denoise Mask may be passed between nodes',
|
|
||||||
color: 'base.500',
|
|
||||||
},
|
|
||||||
LatentsField: {
|
|
||||||
title: 'Latents',
|
|
||||||
description: 'Latents may be passed between nodes.',
|
|
||||||
color: 'pink.500',
|
|
||||||
},
|
|
||||||
LatentsCollection: {
|
|
||||||
title: 'Latents Collection',
|
|
||||||
description: 'Latents may be passed between nodes.',
|
|
||||||
color: 'pink.500',
|
|
||||||
},
|
|
||||||
ConditioningField: {
|
|
||||||
color: 'cyan.500',
|
|
||||||
title: 'Conditioning',
|
|
||||||
description: 'Conditioning may be passed between nodes.',
|
|
||||||
},
|
|
||||||
ConditioningCollection: {
|
|
||||||
color: 'cyan.500',
|
|
||||||
title: 'Conditioning Collection',
|
|
||||||
description: 'Conditioning may be passed between nodes.',
|
|
||||||
},
|
|
||||||
ImageCollection: {
|
|
||||||
title: 'Image Collection',
|
|
||||||
description: 'A collection of images.',
|
|
||||||
color: 'base.300',
|
|
||||||
},
|
|
||||||
UNetField: {
|
|
||||||
color: 'red.500',
|
|
||||||
title: 'UNet',
|
|
||||||
description: 'UNet submodel.',
|
|
||||||
},
|
},
|
||||||
ClipField: {
|
ClipField: {
|
||||||
color: 'green.500',
|
color: 'green.500',
|
||||||
title: 'Clip',
|
|
||||||
description: 'Tokenizer and text_encoder submodels.',
|
description: 'Tokenizer and text_encoder submodels.',
|
||||||
},
|
title: 'Clip',
|
||||||
VaeField: {
|
|
||||||
color: 'blue.500',
|
|
||||||
title: 'Vae',
|
|
||||||
description: 'Vae submodel.',
|
|
||||||
},
|
|
||||||
ControlField: {
|
|
||||||
color: 'cyan.500',
|
|
||||||
title: 'Control',
|
|
||||||
description: 'Control info passed between nodes.',
|
|
||||||
},
|
|
||||||
MainModelField: {
|
|
||||||
color: 'teal.500',
|
|
||||||
title: 'Model',
|
|
||||||
description: 'TODO',
|
|
||||||
},
|
|
||||||
SDXLRefinerModelField: {
|
|
||||||
color: 'teal.500',
|
|
||||||
title: 'Refiner Model',
|
|
||||||
description: 'TODO',
|
|
||||||
},
|
|
||||||
VaeModelField: {
|
|
||||||
color: 'teal.500',
|
|
||||||
title: 'VAE',
|
|
||||||
description: 'TODO',
|
|
||||||
},
|
|
||||||
LoRAModelField: {
|
|
||||||
color: 'teal.500',
|
|
||||||
title: 'LoRA',
|
|
||||||
description: 'TODO',
|
|
||||||
},
|
|
||||||
ControlNetModelField: {
|
|
||||||
color: 'teal.500',
|
|
||||||
title: 'ControlNet',
|
|
||||||
description: 'TODO',
|
|
||||||
},
|
|
||||||
Scheduler: {
|
|
||||||
color: 'base.500',
|
|
||||||
title: 'Scheduler',
|
|
||||||
description: 'TODO',
|
|
||||||
},
|
},
|
||||||
Collection: {
|
Collection: {
|
||||||
color: 'base.500',
|
color: 'base.500',
|
||||||
title: 'Collection',
|
|
||||||
description: 'TODO',
|
description: 'TODO',
|
||||||
|
title: 'Collection',
|
||||||
},
|
},
|
||||||
CollectionItem: {
|
CollectionItem: {
|
||||||
color: 'base.500',
|
color: 'base.500',
|
||||||
title: 'Collection Item',
|
|
||||||
description: 'TODO',
|
description: 'TODO',
|
||||||
|
title: 'Collection Item',
|
||||||
|
},
|
||||||
|
ColorCollection: {
|
||||||
|
color: 'pink.300',
|
||||||
|
description: 'A collection of colors.',
|
||||||
|
title: 'Color Collection',
|
||||||
},
|
},
|
||||||
ColorField: {
|
ColorField: {
|
||||||
title: 'Color',
|
color: 'pink.300',
|
||||||
description: 'A RGBA color.',
|
description: 'A RGBA color.',
|
||||||
color: 'base.500',
|
title: 'Color',
|
||||||
},
|
},
|
||||||
BooleanCollection: {
|
ColorPolymorphic: {
|
||||||
title: 'Boolean Collection',
|
color: 'pink.300',
|
||||||
description: 'A collection of booleans.',
|
description: 'A collection of colors.',
|
||||||
color: 'green.500',
|
title: 'Color Polymorphic',
|
||||||
},
|
},
|
||||||
IntegerCollection: {
|
ConditioningCollection: {
|
||||||
title: 'Integer Collection',
|
color: 'cyan.500',
|
||||||
description: 'A collection of integers.',
|
description: 'Conditioning may be passed between nodes.',
|
||||||
color: 'red.500',
|
title: 'Conditioning Collection',
|
||||||
|
},
|
||||||
|
ConditioningField: {
|
||||||
|
color: 'cyan.500',
|
||||||
|
description: 'Conditioning may be passed between nodes.',
|
||||||
|
title: 'Conditioning',
|
||||||
|
},
|
||||||
|
ConditioningPolymorphic: {
|
||||||
|
color: 'cyan.500',
|
||||||
|
description: 'Conditioning may be passed between nodes.',
|
||||||
|
title: 'Conditioning Polymorphic',
|
||||||
|
},
|
||||||
|
ControlCollection: {
|
||||||
|
color: 'teal.500',
|
||||||
|
description: 'Control info passed between nodes.',
|
||||||
|
title: 'Control Collection',
|
||||||
|
},
|
||||||
|
ControlField: {
|
||||||
|
color: 'teal.500',
|
||||||
|
description: 'Control info passed between nodes.',
|
||||||
|
title: 'Control',
|
||||||
|
},
|
||||||
|
ControlNetModelField: {
|
||||||
|
color: 'teal.500',
|
||||||
|
description: 'TODO',
|
||||||
|
title: 'ControlNet',
|
||||||
|
},
|
||||||
|
ControlPolymorphic: {
|
||||||
|
color: 'teal.500',
|
||||||
|
description: 'Control info passed between nodes.',
|
||||||
|
title: 'Control Polymorphic',
|
||||||
|
},
|
||||||
|
DenoiseMaskField: {
|
||||||
|
color: 'blue.300',
|
||||||
|
description: 'Denoise Mask may be passed between nodes',
|
||||||
|
title: 'Denoise Mask',
|
||||||
|
},
|
||||||
|
enum: {
|
||||||
|
color: 'blue.500',
|
||||||
|
description: 'Enums are values that may be one of a number of options.',
|
||||||
|
title: 'Enum',
|
||||||
|
},
|
||||||
|
float: {
|
||||||
|
color: 'orange.500',
|
||||||
|
description: 'Floats are numbers with a decimal point.',
|
||||||
|
title: 'Float',
|
||||||
},
|
},
|
||||||
FloatCollection: {
|
FloatCollection: {
|
||||||
color: 'orange.500',
|
color: 'orange.500',
|
||||||
title: 'Float Collection',
|
|
||||||
description: 'A collection of floats.',
|
description: 'A collection of floats.',
|
||||||
|
title: 'Float Collection',
|
||||||
},
|
},
|
||||||
ColorCollection: {
|
FloatPolymorphic: {
|
||||||
color: 'base.500',
|
color: 'orange.500',
|
||||||
title: 'Color Collection',
|
description: 'A collection of floats.',
|
||||||
description: 'A collection of colors.',
|
title: 'Float Polymorphic',
|
||||||
|
},
|
||||||
|
ImageCollection: {
|
||||||
|
color: 'purple.500',
|
||||||
|
description: 'A collection of images.',
|
||||||
|
title: 'Image Collection',
|
||||||
|
},
|
||||||
|
ImageField: {
|
||||||
|
color: 'purple.500',
|
||||||
|
description: 'Images may be passed between nodes.',
|
||||||
|
title: 'Image',
|
||||||
|
},
|
||||||
|
ImagePolymorphic: {
|
||||||
|
color: 'purple.500',
|
||||||
|
description: 'A collection of images.',
|
||||||
|
title: 'Image Polymorphic',
|
||||||
|
},
|
||||||
|
integer: {
|
||||||
|
color: 'red.500',
|
||||||
|
description: 'Integers are whole numbers, without a decimal point.',
|
||||||
|
title: 'Integer',
|
||||||
|
},
|
||||||
|
IntegerCollection: {
|
||||||
|
color: 'red.500',
|
||||||
|
description: 'A collection of integers.',
|
||||||
|
title: 'Integer Collection',
|
||||||
|
},
|
||||||
|
IntegerPolymorphic: {
|
||||||
|
color: 'red.500',
|
||||||
|
description: 'A collection of integers.',
|
||||||
|
title: 'Integer Polymorphic',
|
||||||
|
},
|
||||||
|
LatentsCollection: {
|
||||||
|
color: 'pink.500',
|
||||||
|
description: 'Latents may be passed between nodes.',
|
||||||
|
title: 'Latents Collection',
|
||||||
|
},
|
||||||
|
LatentsField: {
|
||||||
|
color: 'pink.500',
|
||||||
|
description: 'Latents may be passed between nodes.',
|
||||||
|
title: 'Latents',
|
||||||
|
},
|
||||||
|
LatentsPolymorphic: {
|
||||||
|
color: 'pink.500',
|
||||||
|
description: 'Latents may be passed between nodes.',
|
||||||
|
title: 'Latents Polymorphic',
|
||||||
|
},
|
||||||
|
LoRAModelField: {
|
||||||
|
color: 'teal.500',
|
||||||
|
description: 'TODO',
|
||||||
|
title: 'LoRA',
|
||||||
|
},
|
||||||
|
MainModelField: {
|
||||||
|
color: 'teal.500',
|
||||||
|
description: 'TODO',
|
||||||
|
title: 'Model',
|
||||||
},
|
},
|
||||||
ONNXModelField: {
|
ONNXModelField: {
|
||||||
color: 'base.500',
|
color: 'teal.500',
|
||||||
title: 'ONNX Model',
|
|
||||||
description: 'ONNX model field.',
|
description: 'ONNX model field.',
|
||||||
|
title: 'ONNX Model',
|
||||||
|
},
|
||||||
|
Scheduler: {
|
||||||
|
color: 'base.500',
|
||||||
|
description: 'TODO',
|
||||||
|
title: 'Scheduler',
|
||||||
},
|
},
|
||||||
SDXLMainModelField: {
|
SDXLMainModelField: {
|
||||||
color: 'base.500',
|
color: 'teal.500',
|
||||||
title: 'SDXL Model',
|
|
||||||
description: 'SDXL model field.',
|
description: 'SDXL model field.',
|
||||||
|
title: 'SDXL Model',
|
||||||
|
},
|
||||||
|
SDXLRefinerModelField: {
|
||||||
|
color: 'teal.500',
|
||||||
|
description: 'TODO',
|
||||||
|
title: 'Refiner Model',
|
||||||
|
},
|
||||||
|
string: {
|
||||||
|
color: 'yellow.500',
|
||||||
|
description: 'Strings are text.',
|
||||||
|
title: 'String',
|
||||||
},
|
},
|
||||||
StringCollection: {
|
StringCollection: {
|
||||||
color: 'yellow.500',
|
color: 'yellow.500',
|
||||||
title: 'String Collection',
|
|
||||||
description: 'A collection of strings.',
|
description: 'A collection of strings.',
|
||||||
|
title: 'String Collection',
|
||||||
|
},
|
||||||
|
StringPolymorphic: {
|
||||||
|
color: 'yellow.500',
|
||||||
|
description: 'A collection of strings.',
|
||||||
|
title: 'String Polymorphic',
|
||||||
|
},
|
||||||
|
UNetField: {
|
||||||
|
color: 'red.500',
|
||||||
|
description: 'UNet submodel.',
|
||||||
|
title: 'UNet',
|
||||||
|
},
|
||||||
|
VaeField: {
|
||||||
|
color: 'blue.500',
|
||||||
|
description: 'Vae submodel.',
|
||||||
|
title: 'Vae',
|
||||||
|
},
|
||||||
|
VaeModelField: {
|
||||||
|
color: 'teal.500',
|
||||||
|
description: 'TODO',
|
||||||
|
title: 'VAE',
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
@ -11,7 +11,7 @@ import { keyBy } from 'lodash-es';
|
|||||||
import { OpenAPIV3 } from 'openapi-types';
|
import { OpenAPIV3 } from 'openapi-types';
|
||||||
import { RgbaColor } from 'react-colorful';
|
import { RgbaColor } from 'react-colorful';
|
||||||
import { Node } from 'reactflow';
|
import { Node } from 'reactflow';
|
||||||
import { Graph, ImageDTO, _InputField, _OutputField } from 'services/api/types';
|
import { Graph, _InputField, _OutputField } from 'services/api/types';
|
||||||
import {
|
import {
|
||||||
AnyInvocationType,
|
AnyInvocationType,
|
||||||
AnyResult,
|
AnyResult,
|
||||||
@ -52,6 +52,10 @@ export type InvocationTemplate = {
|
|||||||
* The type of this node's output
|
* The type of this node's output
|
||||||
*/
|
*/
|
||||||
outputType: string; // TODO: generate a union of output types
|
outputType: string; // TODO: generate a union of output types
|
||||||
|
/**
|
||||||
|
* The invocation's version.
|
||||||
|
*/
|
||||||
|
version?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type FieldUIConfig = {
|
export type FieldUIConfig = {
|
||||||
@ -62,50 +66,48 @@ export type FieldUIConfig = {
|
|||||||
|
|
||||||
// TODO: Get this from the OpenAPI schema? may be tricky...
|
// TODO: Get this from the OpenAPI schema? may be tricky...
|
||||||
export const zFieldType = z.enum([
|
export const zFieldType = z.enum([
|
||||||
// region Primitives
|
|
||||||
'integer',
|
|
||||||
'float',
|
|
||||||
'boolean',
|
'boolean',
|
||||||
'string',
|
|
||||||
'array',
|
|
||||||
'ImageField',
|
|
||||||
'DenoiseMaskField',
|
|
||||||
'LatentsField',
|
|
||||||
'ConditioningField',
|
|
||||||
'ControlField',
|
|
||||||
'ColorField',
|
|
||||||
'ImageCollection',
|
|
||||||
'ConditioningCollection',
|
|
||||||
'ColorCollection',
|
|
||||||
'LatentsCollection',
|
|
||||||
'IntegerCollection',
|
|
||||||
'FloatCollection',
|
|
||||||
'StringCollection',
|
|
||||||
'BooleanCollection',
|
'BooleanCollection',
|
||||||
// endregion
|
'BooleanPolymorphic',
|
||||||
|
|
||||||
// region Models
|
|
||||||
'MainModelField',
|
|
||||||
'SDXLMainModelField',
|
|
||||||
'SDXLRefinerModelField',
|
|
||||||
'ONNXModelField',
|
|
||||||
'VaeModelField',
|
|
||||||
'LoRAModelField',
|
|
||||||
'ControlNetModelField',
|
|
||||||
'UNetField',
|
|
||||||
'VaeField',
|
|
||||||
'ClipField',
|
'ClipField',
|
||||||
// endregion
|
|
||||||
|
|
||||||
// region Iterate/Collect
|
|
||||||
'Collection',
|
'Collection',
|
||||||
'CollectionItem',
|
'CollectionItem',
|
||||||
// endregion
|
'ColorCollection',
|
||||||
|
'ColorField',
|
||||||
// region Misc
|
'ColorPolymorphic',
|
||||||
|
'ConditioningCollection',
|
||||||
|
'ConditioningField',
|
||||||
|
'ConditioningPolymorphic',
|
||||||
|
'ControlCollection',
|
||||||
|
'ControlField',
|
||||||
|
'ControlNetModelField',
|
||||||
|
'ControlPolymorphic',
|
||||||
|
'DenoiseMaskField',
|
||||||
'enum',
|
'enum',
|
||||||
|
'float',
|
||||||
|
'FloatCollection',
|
||||||
|
'FloatPolymorphic',
|
||||||
|
'ImageCollection',
|
||||||
|
'ImageField',
|
||||||
|
'ImagePolymorphic',
|
||||||
|
'integer',
|
||||||
|
'IntegerCollection',
|
||||||
|
'IntegerPolymorphic',
|
||||||
|
'LatentsCollection',
|
||||||
|
'LatentsField',
|
||||||
|
'LatentsPolymorphic',
|
||||||
|
'LoRAModelField',
|
||||||
|
'MainModelField',
|
||||||
|
'ONNXModelField',
|
||||||
'Scheduler',
|
'Scheduler',
|
||||||
// endregion
|
'SDXLMainModelField',
|
||||||
|
'SDXLRefinerModelField',
|
||||||
|
'string',
|
||||||
|
'StringCollection',
|
||||||
|
'StringPolymorphic',
|
||||||
|
'UNetField',
|
||||||
|
'VaeField',
|
||||||
|
'VaeModelField',
|
||||||
]);
|
]);
|
||||||
|
|
||||||
export type FieldType = z.infer<typeof zFieldType>;
|
export type FieldType = z.infer<typeof zFieldType>;
|
||||||
@ -122,38 +124,6 @@ export const isFieldType = (value: unknown): value is FieldType =>
|
|||||||
zFieldType.safeParse(value).success ||
|
zFieldType.safeParse(value).success ||
|
||||||
zReservedFieldType.safeParse(value).success;
|
zReservedFieldType.safeParse(value).success;
|
||||||
|
|
||||||
/**
|
|
||||||
* An input field template is generated on each page load from the OpenAPI schema.
|
|
||||||
*
|
|
||||||
* The template provides the field type and other field metadata (e.g. title, description,
|
|
||||||
* maximum length, pattern to match, etc).
|
|
||||||
*/
|
|
||||||
export type InputFieldTemplate =
|
|
||||||
| IntegerInputFieldTemplate
|
|
||||||
| FloatInputFieldTemplate
|
|
||||||
| StringInputFieldTemplate
|
|
||||||
| BooleanInputFieldTemplate
|
|
||||||
| ImageInputFieldTemplate
|
|
||||||
| DenoiseMaskInputFieldTemplate
|
|
||||||
| LatentsInputFieldTemplate
|
|
||||||
| ConditioningInputFieldTemplate
|
|
||||||
| UNetInputFieldTemplate
|
|
||||||
| ClipInputFieldTemplate
|
|
||||||
| VaeInputFieldTemplate
|
|
||||||
| ControlInputFieldTemplate
|
|
||||||
| EnumInputFieldTemplate
|
|
||||||
| MainModelInputFieldTemplate
|
|
||||||
| SDXLMainModelInputFieldTemplate
|
|
||||||
| SDXLRefinerModelInputFieldTemplate
|
|
||||||
| VaeModelInputFieldTemplate
|
|
||||||
| LoRAModelInputFieldTemplate
|
|
||||||
| ControlNetModelInputFieldTemplate
|
|
||||||
| CollectionInputFieldTemplate
|
|
||||||
| CollectionItemInputFieldTemplate
|
|
||||||
| ColorInputFieldTemplate
|
|
||||||
| ImageCollectionInputFieldTemplate
|
|
||||||
| SchedulerInputFieldTemplate;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Indicates the kind of input(s) this field may have.
|
* Indicates the kind of input(s) this field may have.
|
||||||
*/
|
*/
|
||||||
@ -232,24 +202,88 @@ export const zIntegerInputFieldValue = zInputFieldValueBase.extend({
|
|||||||
});
|
});
|
||||||
export type IntegerInputFieldValue = z.infer<typeof zIntegerInputFieldValue>;
|
export type IntegerInputFieldValue = z.infer<typeof zIntegerInputFieldValue>;
|
||||||
|
|
||||||
|
export const zIntegerCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('IntegerCollection'),
|
||||||
|
value: z.array(z.number().int()).optional(),
|
||||||
|
});
|
||||||
|
export type IntegerCollectionInputFieldValue = z.infer<
|
||||||
|
typeof zIntegerCollectionInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
|
export const zIntegerPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('IntegerPolymorphic'),
|
||||||
|
value: z.union([z.number().int(), z.array(z.number().int())]).optional(),
|
||||||
|
});
|
||||||
|
export type IntegerPolymorphicInputFieldValue = z.infer<
|
||||||
|
typeof zIntegerPolymorphicInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
export const zFloatInputFieldValue = zInputFieldValueBase.extend({
|
export const zFloatInputFieldValue = zInputFieldValueBase.extend({
|
||||||
type: z.literal('float'),
|
type: z.literal('float'),
|
||||||
value: z.number().optional(),
|
value: z.number().optional(),
|
||||||
});
|
});
|
||||||
export type FloatInputFieldValue = z.infer<typeof zFloatInputFieldValue>;
|
export type FloatInputFieldValue = z.infer<typeof zFloatInputFieldValue>;
|
||||||
|
|
||||||
|
export const zFloatCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('FloatCollection'),
|
||||||
|
value: z.array(z.number()).optional(),
|
||||||
|
});
|
||||||
|
export type FloatCollectionInputFieldValue = z.infer<
|
||||||
|
typeof zFloatCollectionInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
|
export const zFloatPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('FloatPolymorphic'),
|
||||||
|
value: z.union([z.number(), z.array(z.number())]).optional(),
|
||||||
|
});
|
||||||
|
export type FloatPolymorphicInputFieldValue = z.infer<
|
||||||
|
typeof zFloatPolymorphicInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
export const zStringInputFieldValue = zInputFieldValueBase.extend({
|
export const zStringInputFieldValue = zInputFieldValueBase.extend({
|
||||||
type: z.literal('string'),
|
type: z.literal('string'),
|
||||||
value: z.string().optional(),
|
value: z.string().optional(),
|
||||||
});
|
});
|
||||||
export type StringInputFieldValue = z.infer<typeof zStringInputFieldValue>;
|
export type StringInputFieldValue = z.infer<typeof zStringInputFieldValue>;
|
||||||
|
|
||||||
|
export const zStringCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('StringCollection'),
|
||||||
|
value: z.array(z.string()).optional(),
|
||||||
|
});
|
||||||
|
export type StringCollectionInputFieldValue = z.infer<
|
||||||
|
typeof zStringCollectionInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
|
export const zStringPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('StringPolymorphic'),
|
||||||
|
value: z.union([z.string(), z.array(z.string())]).optional(),
|
||||||
|
});
|
||||||
|
export type StringPolymorphicInputFieldValue = z.infer<
|
||||||
|
typeof zStringPolymorphicInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
export const zBooleanInputFieldValue = zInputFieldValueBase.extend({
|
export const zBooleanInputFieldValue = zInputFieldValueBase.extend({
|
||||||
type: z.literal('boolean'),
|
type: z.literal('boolean'),
|
||||||
value: z.boolean().optional(),
|
value: z.boolean().optional(),
|
||||||
});
|
});
|
||||||
export type BooleanInputFieldValue = z.infer<typeof zBooleanInputFieldValue>;
|
export type BooleanInputFieldValue = z.infer<typeof zBooleanInputFieldValue>;
|
||||||
|
|
||||||
|
export const zBooleanCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('BooleanCollection'),
|
||||||
|
value: z.array(z.boolean()).optional(),
|
||||||
|
});
|
||||||
|
export type BooleanCollectionInputFieldValue = z.infer<
|
||||||
|
typeof zBooleanCollectionInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
|
export const zBooleanPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('BooleanPolymorphic'),
|
||||||
|
value: z.union([z.boolean(), z.array(z.boolean())]).optional(),
|
||||||
|
});
|
||||||
|
export type BooleanPolymorphicInputFieldValue = z.infer<
|
||||||
|
typeof zBooleanPolymorphicInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
export const zEnumInputFieldValue = zInputFieldValueBase.extend({
|
export const zEnumInputFieldValue = zInputFieldValueBase.extend({
|
||||||
type: z.literal('enum'),
|
type: z.literal('enum'),
|
||||||
value: z.union([z.string(), z.number()]).optional(),
|
value: z.union([z.string(), z.number()]).optional(),
|
||||||
@ -262,6 +296,22 @@ export const zLatentsInputFieldValue = zInputFieldValueBase.extend({
|
|||||||
});
|
});
|
||||||
export type LatentsInputFieldValue = z.infer<typeof zLatentsInputFieldValue>;
|
export type LatentsInputFieldValue = z.infer<typeof zLatentsInputFieldValue>;
|
||||||
|
|
||||||
|
export const zLatentsCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('LatentsCollection'),
|
||||||
|
value: z.array(zLatentsField).optional(),
|
||||||
|
});
|
||||||
|
export type LatentsCollectionInputFieldValue = z.infer<
|
||||||
|
typeof zLatentsCollectionInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
|
export const zLatentsPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('LatentsPolymorphic'),
|
||||||
|
value: z.union([zLatentsField, z.array(zLatentsField)]).optional(),
|
||||||
|
});
|
||||||
|
export type LatentsPolymorphicInputFieldValue = z.infer<
|
||||||
|
typeof zLatentsPolymorphicInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
export const zDenoiseMaskInputFieldValue = zInputFieldValueBase.extend({
|
export const zDenoiseMaskInputFieldValue = zInputFieldValueBase.extend({
|
||||||
type: z.literal('DenoiseMaskField'),
|
type: z.literal('DenoiseMaskField'),
|
||||||
value: zDenoiseMaskField.optional(),
|
value: zDenoiseMaskField.optional(),
|
||||||
@ -278,6 +328,26 @@ export type ConditioningInputFieldValue = z.infer<
|
|||||||
typeof zConditioningInputFieldValue
|
typeof zConditioningInputFieldValue
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
export const zConditioningCollectionInputFieldValue =
|
||||||
|
zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('ConditioningCollection'),
|
||||||
|
value: z.array(zConditioningField).optional(),
|
||||||
|
});
|
||||||
|
export type ConditioningCollectionInputFieldValue = z.infer<
|
||||||
|
typeof zConditioningCollectionInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
|
export const zConditioningPolymorphicInputFieldValue =
|
||||||
|
zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('ConditioningPolymorphic'),
|
||||||
|
value: z
|
||||||
|
.union([zConditioningField, z.array(zConditioningField)])
|
||||||
|
.optional(),
|
||||||
|
});
|
||||||
|
export type ConditioningPolymorphicInputFieldValue = z.infer<
|
||||||
|
typeof zConditioningPolymorphicInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
export const zControlNetModel = zModelIdentifier;
|
export const zControlNetModel = zModelIdentifier;
|
||||||
export type ControlNetModel = z.infer<typeof zControlNetModel>;
|
export type ControlNetModel = z.infer<typeof zControlNetModel>;
|
||||||
|
|
||||||
@ -302,6 +372,22 @@ export const zControlInputFieldValue = zInputFieldValueBase.extend({
|
|||||||
});
|
});
|
||||||
export type ControlInputFieldValue = z.infer<typeof zControlInputFieldValue>;
|
export type ControlInputFieldValue = z.infer<typeof zControlInputFieldValue>;
|
||||||
|
|
||||||
|
export const zControlPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('ControlPolymorphic'),
|
||||||
|
value: z.union([zControlField, z.array(zControlField)]).optional(),
|
||||||
|
});
|
||||||
|
export type ControlPolymorphicInputFieldValue = z.infer<
|
||||||
|
typeof zControlPolymorphicInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
|
export const zControlCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('ControlCollection'),
|
||||||
|
value: z.array(zControlField).optional(),
|
||||||
|
});
|
||||||
|
export type ControlCollectionInputFieldValue = z.infer<
|
||||||
|
typeof zControlCollectionInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
export const zModelType = z.enum([
|
export const zModelType = z.enum([
|
||||||
'onnx',
|
'onnx',
|
||||||
'main',
|
'main',
|
||||||
@ -381,6 +467,14 @@ export const zImageInputFieldValue = zInputFieldValueBase.extend({
|
|||||||
});
|
});
|
||||||
export type ImageInputFieldValue = z.infer<typeof zImageInputFieldValue>;
|
export type ImageInputFieldValue = z.infer<typeof zImageInputFieldValue>;
|
||||||
|
|
||||||
|
export const zImagePolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('ImagePolymorphic'),
|
||||||
|
value: z.union([zImageField, z.array(zImageField)]).optional(),
|
||||||
|
});
|
||||||
|
export type ImagePolymorphicInputFieldValue = z.infer<
|
||||||
|
typeof zImagePolymorphicInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
export const zImageCollectionInputFieldValue = zInputFieldValueBase.extend({
|
export const zImageCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||||
type: z.literal('ImageCollection'),
|
type: z.literal('ImageCollection'),
|
||||||
value: z.array(zImageField).optional(),
|
value: z.array(zImageField).optional(),
|
||||||
@ -473,6 +567,22 @@ export const zColorInputFieldValue = zInputFieldValueBase.extend({
|
|||||||
});
|
});
|
||||||
export type ColorInputFieldValue = z.infer<typeof zColorInputFieldValue>;
|
export type ColorInputFieldValue = z.infer<typeof zColorInputFieldValue>;
|
||||||
|
|
||||||
|
export const zColorCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('ColorCollection'),
|
||||||
|
value: z.array(zColorField).optional(),
|
||||||
|
});
|
||||||
|
export type ColorCollectionInputFieldValue = z.infer<
|
||||||
|
typeof zColorCollectionInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
|
export const zColorPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('ColorPolymorphic'),
|
||||||
|
value: z.union([zColorField, z.array(zColorField)]).optional(),
|
||||||
|
});
|
||||||
|
export type ColorPolymorphicInputFieldValue = z.infer<
|
||||||
|
typeof zColorPolymorphicInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
export const zSchedulerInputFieldValue = zInputFieldValueBase.extend({
|
export const zSchedulerInputFieldValue = zInputFieldValueBase.extend({
|
||||||
type: z.literal('Scheduler'),
|
type: z.literal('Scheduler'),
|
||||||
value: zScheduler.optional(),
|
value: zScheduler.optional(),
|
||||||
@ -482,30 +592,47 @@ export type SchedulerInputFieldValue = z.infer<
|
|||||||
>;
|
>;
|
||||||
|
|
||||||
export const zInputFieldValue = z.discriminatedUnion('type', [
|
export const zInputFieldValue = z.discriminatedUnion('type', [
|
||||||
zIntegerInputFieldValue,
|
zBooleanCollectionInputFieldValue,
|
||||||
zFloatInputFieldValue,
|
|
||||||
zStringInputFieldValue,
|
|
||||||
zBooleanInputFieldValue,
|
zBooleanInputFieldValue,
|
||||||
zImageInputFieldValue,
|
zBooleanPolymorphicInputFieldValue,
|
||||||
zLatentsInputFieldValue,
|
|
||||||
zDenoiseMaskInputFieldValue,
|
|
||||||
zConditioningInputFieldValue,
|
|
||||||
zUNetInputFieldValue,
|
|
||||||
zClipInputFieldValue,
|
zClipInputFieldValue,
|
||||||
zVaeInputFieldValue,
|
|
||||||
zControlInputFieldValue,
|
|
||||||
zEnumInputFieldValue,
|
|
||||||
zMainModelInputFieldValue,
|
|
||||||
zSDXLMainModelInputFieldValue,
|
|
||||||
zSDXLRefinerModelInputFieldValue,
|
|
||||||
zVaeModelInputFieldValue,
|
|
||||||
zLoRAModelInputFieldValue,
|
|
||||||
zControlNetModelInputFieldValue,
|
|
||||||
zCollectionInputFieldValue,
|
zCollectionInputFieldValue,
|
||||||
zCollectionItemInputFieldValue,
|
zCollectionItemInputFieldValue,
|
||||||
zColorInputFieldValue,
|
zColorInputFieldValue,
|
||||||
|
zColorCollectionInputFieldValue,
|
||||||
|
zColorPolymorphicInputFieldValue,
|
||||||
|
zConditioningInputFieldValue,
|
||||||
|
zConditioningCollectionInputFieldValue,
|
||||||
|
zConditioningPolymorphicInputFieldValue,
|
||||||
|
zControlInputFieldValue,
|
||||||
|
zControlNetModelInputFieldValue,
|
||||||
|
zControlCollectionInputFieldValue,
|
||||||
|
zControlPolymorphicInputFieldValue,
|
||||||
|
zDenoiseMaskInputFieldValue,
|
||||||
|
zEnumInputFieldValue,
|
||||||
|
zFloatCollectionInputFieldValue,
|
||||||
|
zFloatInputFieldValue,
|
||||||
|
zFloatPolymorphicInputFieldValue,
|
||||||
zImageCollectionInputFieldValue,
|
zImageCollectionInputFieldValue,
|
||||||
|
zImagePolymorphicInputFieldValue,
|
||||||
|
zImageInputFieldValue,
|
||||||
|
zIntegerCollectionInputFieldValue,
|
||||||
|
zIntegerPolymorphicInputFieldValue,
|
||||||
|
zIntegerInputFieldValue,
|
||||||
|
zLatentsInputFieldValue,
|
||||||
|
zLatentsCollectionInputFieldValue,
|
||||||
|
zLatentsPolymorphicInputFieldValue,
|
||||||
|
zLoRAModelInputFieldValue,
|
||||||
|
zMainModelInputFieldValue,
|
||||||
zSchedulerInputFieldValue,
|
zSchedulerInputFieldValue,
|
||||||
|
zSDXLMainModelInputFieldValue,
|
||||||
|
zSDXLRefinerModelInputFieldValue,
|
||||||
|
zStringCollectionInputFieldValue,
|
||||||
|
zStringPolymorphicInputFieldValue,
|
||||||
|
zStringInputFieldValue,
|
||||||
|
zUNetInputFieldValue,
|
||||||
|
zVaeInputFieldValue,
|
||||||
|
zVaeModelInputFieldValue,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
export type InputFieldValue = z.infer<typeof zInputFieldValue>;
|
export type InputFieldValue = z.infer<typeof zInputFieldValue>;
|
||||||
@ -514,7 +641,6 @@ export type InputFieldTemplateBase = {
|
|||||||
name: string;
|
name: string;
|
||||||
title: string;
|
title: string;
|
||||||
description: string;
|
description: string;
|
||||||
type: FieldType;
|
|
||||||
required: boolean;
|
required: boolean;
|
||||||
fieldKind: 'input';
|
fieldKind: 'input';
|
||||||
} & _InputField;
|
} & _InputField;
|
||||||
@ -529,6 +655,19 @@ export type IntegerInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
exclusiveMinimum?: boolean;
|
exclusiveMinimum?: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type IntegerCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
type: 'IntegerCollection';
|
||||||
|
default: number[];
|
||||||
|
item_default?: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type IntegerPolymorphicInputFieldTemplate = Omit<
|
||||||
|
IntegerInputFieldTemplate,
|
||||||
|
'type'
|
||||||
|
> & {
|
||||||
|
type: 'IntegerPolymorphic';
|
||||||
|
};
|
||||||
|
|
||||||
export type FloatInputFieldTemplate = InputFieldTemplateBase & {
|
export type FloatInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
type: 'float';
|
type: 'float';
|
||||||
default: number;
|
default: number;
|
||||||
@ -539,6 +678,19 @@ export type FloatInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
exclusiveMinimum?: boolean;
|
exclusiveMinimum?: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type FloatCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
type: 'FloatCollection';
|
||||||
|
default: number[];
|
||||||
|
item_default?: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type FloatPolymorphicInputFieldTemplate = Omit<
|
||||||
|
FloatInputFieldTemplate,
|
||||||
|
'type'
|
||||||
|
> & {
|
||||||
|
type: 'FloatPolymorphic';
|
||||||
|
};
|
||||||
|
|
||||||
export type StringInputFieldTemplate = InputFieldTemplateBase & {
|
export type StringInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
type: 'string';
|
type: 'string';
|
||||||
default: string;
|
default: string;
|
||||||
@ -547,19 +699,53 @@ export type StringInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
pattern?: string;
|
pattern?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type StringCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
type: 'StringCollection';
|
||||||
|
default: string[];
|
||||||
|
item_default?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type StringPolymorphicInputFieldTemplate = Omit<
|
||||||
|
StringInputFieldTemplate,
|
||||||
|
'type'
|
||||||
|
> & {
|
||||||
|
type: 'StringPolymorphic';
|
||||||
|
};
|
||||||
|
|
||||||
export type BooleanInputFieldTemplate = InputFieldTemplateBase & {
|
export type BooleanInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: boolean;
|
default: boolean;
|
||||||
type: 'boolean';
|
type: 'boolean';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type BooleanCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
type: 'BooleanCollection';
|
||||||
|
default: boolean[];
|
||||||
|
item_default?: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type BooleanPolymorphicInputFieldTemplate = Omit<
|
||||||
|
BooleanInputFieldTemplate,
|
||||||
|
'type'
|
||||||
|
> & {
|
||||||
|
type: 'BooleanPolymorphic';
|
||||||
|
};
|
||||||
|
|
||||||
export type ImageInputFieldTemplate = InputFieldTemplateBase & {
|
export type ImageInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: ImageDTO;
|
default: ImageField;
|
||||||
type: 'ImageField';
|
type: 'ImageField';
|
||||||
};
|
};
|
||||||
|
|
||||||
export type ImageCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
export type ImageCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: ImageField[];
|
default: ImageField[];
|
||||||
type: 'ImageCollection';
|
type: 'ImageCollection';
|
||||||
|
item_default?: ImageField;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ImagePolymorphicInputFieldTemplate = Omit<
|
||||||
|
ImageInputFieldTemplate,
|
||||||
|
'type'
|
||||||
|
> & {
|
||||||
|
type: 'ImagePolymorphic';
|
||||||
};
|
};
|
||||||
|
|
||||||
export type DenoiseMaskInputFieldTemplate = InputFieldTemplateBase & {
|
export type DenoiseMaskInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
@ -568,15 +754,40 @@ export type DenoiseMaskInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export type LatentsInputFieldTemplate = InputFieldTemplateBase & {
|
export type LatentsInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: string;
|
default: LatentsField;
|
||||||
type: 'LatentsField';
|
type: 'LatentsField';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type LatentsCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: LatentsField[];
|
||||||
|
type: 'LatentsCollection';
|
||||||
|
item_default?: LatentsField;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type LatentsPolymorphicInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: LatentsField;
|
||||||
|
type: 'LatentsPolymorphic';
|
||||||
|
};
|
||||||
|
|
||||||
export type ConditioningInputFieldTemplate = InputFieldTemplateBase & {
|
export type ConditioningInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: undefined;
|
default: undefined;
|
||||||
type: 'ConditioningField';
|
type: 'ConditioningField';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type ConditioningCollectionInputFieldTemplate =
|
||||||
|
InputFieldTemplateBase & {
|
||||||
|
default: ConditioningField[];
|
||||||
|
type: 'ConditioningCollection';
|
||||||
|
item_default?: ConditioningField;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ConditioningPolymorphicInputFieldTemplate = Omit<
|
||||||
|
ConditioningInputFieldTemplate,
|
||||||
|
'type'
|
||||||
|
> & {
|
||||||
|
type: 'ConditioningPolymorphic';
|
||||||
|
};
|
||||||
|
|
||||||
export type UNetInputFieldTemplate = InputFieldTemplateBase & {
|
export type UNetInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: undefined;
|
default: undefined;
|
||||||
type: 'UNetField';
|
type: 'UNetField';
|
||||||
@ -597,6 +808,19 @@ export type ControlInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
type: 'ControlField';
|
type: 'ControlField';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type ControlCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: undefined;
|
||||||
|
type: 'ControlCollection';
|
||||||
|
item_default?: ControlField;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ControlPolymorphicInputFieldTemplate = Omit<
|
||||||
|
ControlInputFieldTemplate,
|
||||||
|
'type'
|
||||||
|
> & {
|
||||||
|
type: 'ControlPolymorphic';
|
||||||
|
};
|
||||||
|
|
||||||
export type EnumInputFieldTemplate = InputFieldTemplateBase & {
|
export type EnumInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: string | number;
|
default: string | number;
|
||||||
type: 'enum';
|
type: 'enum';
|
||||||
@ -649,6 +873,18 @@ export type ColorInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
type: 'ColorField';
|
type: 'ColorField';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type ColorPolymorphicInputFieldTemplate = Omit<
|
||||||
|
ColorInputFieldTemplate,
|
||||||
|
'type'
|
||||||
|
> & {
|
||||||
|
type: 'ColorPolymorphic';
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ColorCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: [];
|
||||||
|
type: 'ColorCollection';
|
||||||
|
};
|
||||||
|
|
||||||
export type SchedulerInputFieldTemplate = InputFieldTemplateBase & {
|
export type SchedulerInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: SchedulerParam;
|
default: SchedulerParam;
|
||||||
type: 'Scheduler';
|
type: 'Scheduler';
|
||||||
@ -659,6 +895,55 @@ export type WorkflowInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
type: 'WorkflowField';
|
type: 'WorkflowField';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An input field template is generated on each page load from the OpenAPI schema.
|
||||||
|
*
|
||||||
|
* The template provides the field type and other field metadata (e.g. title, description,
|
||||||
|
* maximum length, pattern to match, etc).
|
||||||
|
*/
|
||||||
|
export type InputFieldTemplate =
|
||||||
|
| BooleanCollectionInputFieldTemplate
|
||||||
|
| BooleanPolymorphicInputFieldTemplate
|
||||||
|
| BooleanInputFieldTemplate
|
||||||
|
| ClipInputFieldTemplate
|
||||||
|
| CollectionInputFieldTemplate
|
||||||
|
| CollectionItemInputFieldTemplate
|
||||||
|
| ColorInputFieldTemplate
|
||||||
|
| ColorCollectionInputFieldTemplate
|
||||||
|
| ColorPolymorphicInputFieldTemplate
|
||||||
|
| ConditioningInputFieldTemplate
|
||||||
|
| ConditioningCollectionInputFieldTemplate
|
||||||
|
| ConditioningPolymorphicInputFieldTemplate
|
||||||
|
| ControlInputFieldTemplate
|
||||||
|
| ControlCollectionInputFieldTemplate
|
||||||
|
| ControlNetModelInputFieldTemplate
|
||||||
|
| ControlPolymorphicInputFieldTemplate
|
||||||
|
| DenoiseMaskInputFieldTemplate
|
||||||
|
| EnumInputFieldTemplate
|
||||||
|
| FloatCollectionInputFieldTemplate
|
||||||
|
| FloatInputFieldTemplate
|
||||||
|
| FloatPolymorphicInputFieldTemplate
|
||||||
|
| ImageCollectionInputFieldTemplate
|
||||||
|
| ImagePolymorphicInputFieldTemplate
|
||||||
|
| ImageInputFieldTemplate
|
||||||
|
| IntegerCollectionInputFieldTemplate
|
||||||
|
| IntegerPolymorphicInputFieldTemplate
|
||||||
|
| IntegerInputFieldTemplate
|
||||||
|
| LatentsInputFieldTemplate
|
||||||
|
| LatentsCollectionInputFieldTemplate
|
||||||
|
| LatentsPolymorphicInputFieldTemplate
|
||||||
|
| LoRAModelInputFieldTemplate
|
||||||
|
| MainModelInputFieldTemplate
|
||||||
|
| SchedulerInputFieldTemplate
|
||||||
|
| SDXLMainModelInputFieldTemplate
|
||||||
|
| SDXLRefinerModelInputFieldTemplate
|
||||||
|
| StringCollectionInputFieldTemplate
|
||||||
|
| StringPolymorphicInputFieldTemplate
|
||||||
|
| StringInputFieldTemplate
|
||||||
|
| UNetInputFieldTemplate
|
||||||
|
| VaeInputFieldTemplate
|
||||||
|
| VaeModelInputFieldTemplate;
|
||||||
|
|
||||||
export const isInputFieldValue = (
|
export const isInputFieldValue = (
|
||||||
field?: InputFieldValue | OutputFieldValue
|
field?: InputFieldValue | OutputFieldValue
|
||||||
): field is InputFieldValue => Boolean(field && field.fieldKind === 'input');
|
): field is InputFieldValue => Boolean(field && field.fieldKind === 'input');
|
||||||
@ -681,6 +966,7 @@ export type InvocationSchemaExtra = {
|
|||||||
title: string;
|
title: string;
|
||||||
category?: string;
|
category?: string;
|
||||||
tags?: string[];
|
tags?: string[];
|
||||||
|
version?: string;
|
||||||
properties: Omit<
|
properties: Omit<
|
||||||
NonNullable<OpenAPIV3.SchemaObject['properties']> &
|
NonNullable<OpenAPIV3.SchemaObject['properties']> &
|
||||||
(_InputField | _OutputField),
|
(_InputField | _OutputField),
|
||||||
@ -731,8 +1017,22 @@ export type InvocationSchemaObject = (
|
|||||||
) & { class: 'invocation' };
|
) & { class: 'invocation' };
|
||||||
|
|
||||||
export const isSchemaObject = (
|
export const isSchemaObject = (
|
||||||
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject
|
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
|
||||||
): obj is OpenAPIV3.SchemaObject => !('$ref' in obj);
|
): obj is OpenAPIV3.SchemaObject => Boolean(obj && !('$ref' in obj));
|
||||||
|
|
||||||
|
export const isArraySchemaObject = (
|
||||||
|
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
|
||||||
|
): obj is OpenAPIV3.ArraySchemaObject =>
|
||||||
|
Boolean(obj && !('$ref' in obj) && obj.type === 'array');
|
||||||
|
|
||||||
|
export const isNonArraySchemaObject = (
|
||||||
|
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
|
||||||
|
): obj is OpenAPIV3.NonArraySchemaObject =>
|
||||||
|
Boolean(obj && !('$ref' in obj) && obj.type !== 'array');
|
||||||
|
|
||||||
|
export const isRefObject = (
|
||||||
|
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
|
||||||
|
): obj is OpenAPIV3.ReferenceObject => Boolean(obj && '$ref' in obj);
|
||||||
|
|
||||||
export const isInvocationSchemaObject = (
|
export const isInvocationSchemaObject = (
|
||||||
obj:
|
obj:
|
||||||
@ -800,6 +1100,29 @@ export const zCoreMetadata = z
|
|||||||
|
|
||||||
export type CoreMetadata = z.infer<typeof zCoreMetadata>;
|
export type CoreMetadata = z.infer<typeof zCoreMetadata>;
|
||||||
|
|
||||||
|
export const zSemVer = z.string().refine((val) => {
|
||||||
|
const [major, minor, patch] = val.split('.');
|
||||||
|
return (
|
||||||
|
major !== undefined &&
|
||||||
|
Number.isInteger(Number(major)) &&
|
||||||
|
minor !== undefined &&
|
||||||
|
Number.isInteger(Number(minor)) &&
|
||||||
|
patch !== undefined &&
|
||||||
|
Number.isInteger(Number(patch))
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
export const zParsedSemver = zSemVer.transform((val) => {
|
||||||
|
const [major, minor, patch] = val.split('.');
|
||||||
|
return {
|
||||||
|
major: Number(major),
|
||||||
|
minor: Number(minor),
|
||||||
|
patch: Number(patch),
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
export type SemVer = z.infer<typeof zSemVer>;
|
||||||
|
|
||||||
export const zInvocationNodeData = z.object({
|
export const zInvocationNodeData = z.object({
|
||||||
id: z.string().trim().min(1),
|
id: z.string().trim().min(1),
|
||||||
// no easy way to build this dynamically, and we don't want to anyways, because this will be used
|
// no easy way to build this dynamically, and we don't want to anyways, because this will be used
|
||||||
@ -812,6 +1135,7 @@ export const zInvocationNodeData = z.object({
|
|||||||
notes: z.string(),
|
notes: z.string(),
|
||||||
embedWorkflow: z.boolean(),
|
embedWorkflow: z.boolean(),
|
||||||
isIntermediate: z.boolean(),
|
isIntermediate: z.boolean(),
|
||||||
|
version: zSemVer.optional(),
|
||||||
});
|
});
|
||||||
|
|
||||||
// Massage this to get better type safety while developing
|
// Massage this to get better type safety while developing
|
||||||
@ -900,20 +1224,6 @@ export const zFieldIdentifier = z.object({
|
|||||||
|
|
||||||
export type FieldIdentifier = z.infer<typeof zFieldIdentifier>;
|
export type FieldIdentifier = z.infer<typeof zFieldIdentifier>;
|
||||||
|
|
||||||
export const zSemVer = z.string().refine((val) => {
|
|
||||||
const [major, minor, patch] = val.split('.');
|
|
||||||
return (
|
|
||||||
major !== undefined &&
|
|
||||||
minor !== undefined &&
|
|
||||||
patch !== undefined &&
|
|
||||||
Number.isInteger(Number(major)) &&
|
|
||||||
Number.isInteger(Number(minor)) &&
|
|
||||||
Number.isInteger(Number(patch))
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
export type SemVer = z.infer<typeof zSemVer>;
|
|
||||||
|
|
||||||
export type WorkflowWarning = {
|
export type WorkflowWarning = {
|
||||||
message: string;
|
message: string;
|
||||||
issues: string[];
|
issues: string[];
|
||||||
|
@ -1,5 +1,14 @@
|
|||||||
|
import { isBoolean, isInteger, isNumber, isString } from 'lodash-es';
|
||||||
import { OpenAPIV3 } from 'openapi-types';
|
import { OpenAPIV3 } from 'openapi-types';
|
||||||
import {
|
import {
|
||||||
|
COLLECTION_MAP,
|
||||||
|
POLYMORPHIC_TYPES,
|
||||||
|
SINGLE_TO_POLYMORPHIC_MAP,
|
||||||
|
isCollectionItemType,
|
||||||
|
isPolymorphicItemType,
|
||||||
|
} from '../types/constants';
|
||||||
|
import {
|
||||||
|
BooleanCollectionInputFieldTemplate,
|
||||||
BooleanInputFieldTemplate,
|
BooleanInputFieldTemplate,
|
||||||
ClipInputFieldTemplate,
|
ClipInputFieldTemplate,
|
||||||
CollectionInputFieldTemplate,
|
CollectionInputFieldTemplate,
|
||||||
@ -11,10 +20,13 @@ import {
|
|||||||
DenoiseMaskInputFieldTemplate,
|
DenoiseMaskInputFieldTemplate,
|
||||||
EnumInputFieldTemplate,
|
EnumInputFieldTemplate,
|
||||||
FieldType,
|
FieldType,
|
||||||
|
FloatCollectionInputFieldTemplate,
|
||||||
|
FloatPolymorphicInputFieldTemplate,
|
||||||
FloatInputFieldTemplate,
|
FloatInputFieldTemplate,
|
||||||
ImageCollectionInputFieldTemplate,
|
ImageCollectionInputFieldTemplate,
|
||||||
ImageInputFieldTemplate,
|
ImageInputFieldTemplate,
|
||||||
InputFieldTemplateBase,
|
InputFieldTemplateBase,
|
||||||
|
IntegerCollectionInputFieldTemplate,
|
||||||
IntegerInputFieldTemplate,
|
IntegerInputFieldTemplate,
|
||||||
InvocationFieldSchema,
|
InvocationFieldSchema,
|
||||||
InvocationSchemaObject,
|
InvocationSchemaObject,
|
||||||
@ -24,11 +36,32 @@ import {
|
|||||||
SDXLMainModelInputFieldTemplate,
|
SDXLMainModelInputFieldTemplate,
|
||||||
SDXLRefinerModelInputFieldTemplate,
|
SDXLRefinerModelInputFieldTemplate,
|
||||||
SchedulerInputFieldTemplate,
|
SchedulerInputFieldTemplate,
|
||||||
|
StringCollectionInputFieldTemplate,
|
||||||
StringInputFieldTemplate,
|
StringInputFieldTemplate,
|
||||||
UNetInputFieldTemplate,
|
UNetInputFieldTemplate,
|
||||||
VaeInputFieldTemplate,
|
VaeInputFieldTemplate,
|
||||||
VaeModelInputFieldTemplate,
|
VaeModelInputFieldTemplate,
|
||||||
|
isArraySchemaObject,
|
||||||
|
isNonArraySchemaObject,
|
||||||
|
isRefObject,
|
||||||
|
isSchemaObject,
|
||||||
|
ControlPolymorphicInputFieldTemplate,
|
||||||
|
ColorPolymorphicInputFieldTemplate,
|
||||||
|
ColorCollectionInputFieldTemplate,
|
||||||
|
IntegerPolymorphicInputFieldTemplate,
|
||||||
|
StringPolymorphicInputFieldTemplate,
|
||||||
|
BooleanPolymorphicInputFieldTemplate,
|
||||||
|
ImagePolymorphicInputFieldTemplate,
|
||||||
|
LatentsPolymorphicInputFieldTemplate,
|
||||||
|
LatentsCollectionInputFieldTemplate,
|
||||||
|
ConditioningPolymorphicInputFieldTemplate,
|
||||||
|
ConditioningCollectionInputFieldTemplate,
|
||||||
|
ControlCollectionInputFieldTemplate,
|
||||||
|
ImageField,
|
||||||
|
LatentsField,
|
||||||
|
ConditioningField,
|
||||||
} from '../types/types';
|
} from '../types/types';
|
||||||
|
import { ControlField } from 'services/api/types';
|
||||||
|
|
||||||
export type BaseFieldProperties = 'name' | 'title' | 'description';
|
export type BaseFieldProperties = 'name' | 'title' | 'description';
|
||||||
|
|
||||||
@ -45,15 +78,8 @@ export type BuildInputFieldArg = {
|
|||||||
* @example
|
* @example
|
||||||
* refObjectToFieldType({ "$ref": "#/components/schemas/ImageField" }) --> 'ImageField'
|
* refObjectToFieldType({ "$ref": "#/components/schemas/ImageField" }) --> 'ImageField'
|
||||||
*/
|
*/
|
||||||
export const refObjectToFieldType = (
|
export const refObjectToSchemaName = (refObject: OpenAPIV3.ReferenceObject) =>
|
||||||
refObject: OpenAPIV3.ReferenceObject
|
refObject.$ref.split('/').slice(-1)[0];
|
||||||
): FieldType => {
|
|
||||||
const name = refObject.$ref.split('/').slice(-1)[0];
|
|
||||||
if (!name) {
|
|
||||||
throw `Unknown field type: ${name}`;
|
|
||||||
}
|
|
||||||
return name as FieldType;
|
|
||||||
};
|
|
||||||
|
|
||||||
const buildIntegerInputFieldTemplate = ({
|
const buildIntegerInputFieldTemplate = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
@ -88,6 +114,57 @@ const buildIntegerInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildIntegerPolymorphicInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): IntegerPolymorphicInputFieldTemplate => {
|
||||||
|
const template: IntegerPolymorphicInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'IntegerPolymorphic',
|
||||||
|
default: schemaObject.default ?? 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
if (schemaObject.multipleOf !== undefined) {
|
||||||
|
template.multipleOf = schemaObject.multipleOf;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (schemaObject.maximum !== undefined) {
|
||||||
|
template.maximum = schemaObject.maximum;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (schemaObject.exclusiveMaximum !== undefined) {
|
||||||
|
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (schemaObject.minimum !== undefined) {
|
||||||
|
template.minimum = schemaObject.minimum;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (schemaObject.exclusiveMinimum !== undefined) {
|
||||||
|
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
||||||
|
}
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
|
const buildIntegerCollectionInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): IntegerCollectionInputFieldTemplate => {
|
||||||
|
const item_default =
|
||||||
|
isNumber(schemaObject.item_default) && isInteger(schemaObject.item_default)
|
||||||
|
? schemaObject.item_default
|
||||||
|
: 0;
|
||||||
|
const template: IntegerCollectionInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'IntegerCollection',
|
||||||
|
default: schemaObject.default ?? [],
|
||||||
|
item_default,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildFloatInputFieldTemplate = ({
|
const buildFloatInputFieldTemplate = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -121,6 +198,54 @@ const buildFloatInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildFloatPolymorphicInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): FloatPolymorphicInputFieldTemplate => {
|
||||||
|
const template: FloatPolymorphicInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'FloatPolymorphic',
|
||||||
|
default: schemaObject.default ?? 0,
|
||||||
|
};
|
||||||
|
if (schemaObject.multipleOf !== undefined) {
|
||||||
|
template.multipleOf = schemaObject.multipleOf;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (schemaObject.maximum !== undefined) {
|
||||||
|
template.maximum = schemaObject.maximum;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (schemaObject.exclusiveMaximum !== undefined) {
|
||||||
|
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (schemaObject.minimum !== undefined) {
|
||||||
|
template.minimum = schemaObject.minimum;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (schemaObject.exclusiveMinimum !== undefined) {
|
||||||
|
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
||||||
|
}
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
|
const buildFloatCollectionInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): FloatCollectionInputFieldTemplate => {
|
||||||
|
const item_default = isNumber(schemaObject.item_default)
|
||||||
|
? schemaObject.item_default
|
||||||
|
: 0;
|
||||||
|
const template: FloatCollectionInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'FloatCollection',
|
||||||
|
default: schemaObject.default ?? [],
|
||||||
|
item_default,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildStringInputFieldTemplate = ({
|
const buildStringInputFieldTemplate = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -146,6 +271,48 @@ const buildStringInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildStringPolymorphicInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): StringPolymorphicInputFieldTemplate => {
|
||||||
|
const template: StringPolymorphicInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'StringPolymorphic',
|
||||||
|
default: schemaObject.default ?? '',
|
||||||
|
};
|
||||||
|
|
||||||
|
if (schemaObject.minLength !== undefined) {
|
||||||
|
template.minLength = schemaObject.minLength;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (schemaObject.maxLength !== undefined) {
|
||||||
|
template.maxLength = schemaObject.maxLength;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (schemaObject.pattern !== undefined) {
|
||||||
|
template.pattern = schemaObject.pattern;
|
||||||
|
}
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
|
const buildStringCollectionInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): StringCollectionInputFieldTemplate => {
|
||||||
|
const item_default = isString(schemaObject.item_default)
|
||||||
|
? schemaObject.item_default
|
||||||
|
: '';
|
||||||
|
const template: StringCollectionInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'StringCollection',
|
||||||
|
default: schemaObject.default ?? [],
|
||||||
|
item_default,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildBooleanInputFieldTemplate = ({
|
const buildBooleanInputFieldTemplate = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -159,6 +326,37 @@ const buildBooleanInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildBooleanPolymorphicInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): BooleanPolymorphicInputFieldTemplate => {
|
||||||
|
const template: BooleanPolymorphicInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'BooleanPolymorphic',
|
||||||
|
default: schemaObject.default ?? false,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
|
const buildBooleanCollectionInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): BooleanCollectionInputFieldTemplate => {
|
||||||
|
const item_default =
|
||||||
|
schemaObject.item_default && isBoolean(schemaObject.item_default)
|
||||||
|
? schemaObject.item_default
|
||||||
|
: false;
|
||||||
|
const template: BooleanCollectionInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'BooleanCollection',
|
||||||
|
default: schemaObject.default ?? [],
|
||||||
|
item_default,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildMainModelInputFieldTemplate = ({
|
const buildMainModelInputFieldTemplate = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -250,6 +448,19 @@ const buildImageInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildImagePolymorphicInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): ImagePolymorphicInputFieldTemplate => {
|
||||||
|
const template: ImagePolymorphicInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'ImagePolymorphic',
|
||||||
|
default: schemaObject.default ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildImageCollectionInputFieldTemplate = ({
|
const buildImageCollectionInputFieldTemplate = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -257,7 +468,8 @@ const buildImageCollectionInputFieldTemplate = ({
|
|||||||
const template: ImageCollectionInputFieldTemplate = {
|
const template: ImageCollectionInputFieldTemplate = {
|
||||||
...baseField,
|
...baseField,
|
||||||
type: 'ImageCollection',
|
type: 'ImageCollection',
|
||||||
default: schemaObject.default ?? undefined,
|
default: schemaObject.default ?? [],
|
||||||
|
item_default: (schemaObject.item_default as ImageField) ?? undefined,
|
||||||
};
|
};
|
||||||
|
|
||||||
return template;
|
return template;
|
||||||
@ -289,6 +501,33 @@ const buildLatentsInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildLatentsPolymorphicInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): LatentsPolymorphicInputFieldTemplate => {
|
||||||
|
const template: LatentsPolymorphicInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'LatentsPolymorphic',
|
||||||
|
default: schemaObject.default ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
|
const buildLatentsCollectionInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): LatentsCollectionInputFieldTemplate => {
|
||||||
|
const template: LatentsCollectionInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'LatentsCollection',
|
||||||
|
default: schemaObject.default ?? [],
|
||||||
|
item_default: (schemaObject.item_default as LatentsField) ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildConditioningInputFieldTemplate = ({
|
const buildConditioningInputFieldTemplate = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -302,6 +541,33 @@ const buildConditioningInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildConditioningPolymorphicInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): ConditioningPolymorphicInputFieldTemplate => {
|
||||||
|
const template: ConditioningPolymorphicInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'ConditioningPolymorphic',
|
||||||
|
default: schemaObject.default ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
|
const buildConditioningCollectionInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): ConditioningCollectionInputFieldTemplate => {
|
||||||
|
const template: ConditioningCollectionInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'ConditioningCollection',
|
||||||
|
default: schemaObject.default ?? [],
|
||||||
|
item_default: (schemaObject.item_default as ConditioningField) ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildUNetInputFieldTemplate = ({
|
const buildUNetInputFieldTemplate = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -355,6 +621,33 @@ const buildControlInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildControlPolymorphicInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): ControlPolymorphicInputFieldTemplate => {
|
||||||
|
const template: ControlPolymorphicInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'ControlPolymorphic',
|
||||||
|
default: schemaObject.default ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
|
const buildControlCollectionInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): ControlCollectionInputFieldTemplate => {
|
||||||
|
const template: ControlCollectionInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'ControlCollection',
|
||||||
|
default: schemaObject.default ?? [],
|
||||||
|
item_default: (schemaObject.item_default as ControlField) ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildEnumInputFieldTemplate = ({
|
const buildEnumInputFieldTemplate = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -408,6 +701,32 @@ const buildColorInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildColorPolymorphicInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): ColorPolymorphicInputFieldTemplate => {
|
||||||
|
const template: ColorPolymorphicInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'ColorPolymorphic',
|
||||||
|
default: schemaObject.default ?? { r: 127, g: 127, b: 127, a: 255 },
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
|
const buildColorCollectionInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): ColorCollectionInputFieldTemplate => {
|
||||||
|
const template: ColorCollectionInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'ColorCollection',
|
||||||
|
default: schemaObject.default ?? [],
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildSchedulerInputFieldTemplate = ({
|
const buildSchedulerInputFieldTemplate = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -421,45 +740,138 @@ const buildSchedulerInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const getFieldType = (schemaObject: InvocationFieldSchema): string => {
|
export const getFieldType = (
|
||||||
let fieldType = '';
|
schemaObject: InvocationFieldSchema
|
||||||
|
): string | undefined => {
|
||||||
const { ui_type } = schemaObject;
|
if (schemaObject?.ui_type) {
|
||||||
if (ui_type) {
|
return schemaObject.ui_type;
|
||||||
fieldType = ui_type;
|
|
||||||
} else if (!schemaObject.type) {
|
} else if (!schemaObject.type) {
|
||||||
// console.log('refObject', schemaObject);
|
|
||||||
// if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
|
// if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
|
||||||
|
|
||||||
if (schemaObject.allOf) {
|
if (schemaObject.allOf) {
|
||||||
fieldType = refObjectToFieldType(
|
const allOf = schemaObject.allOf;
|
||||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
if (allOf && allOf[0] && isRefObject(allOf[0])) {
|
||||||
schemaObject.allOf![0] as OpenAPIV3.ReferenceObject
|
return refObjectToSchemaName(allOf[0]);
|
||||||
);
|
}
|
||||||
} else if (schemaObject.anyOf) {
|
} else if (schemaObject.anyOf) {
|
||||||
fieldType = refObjectToFieldType(
|
const anyOf = schemaObject.anyOf;
|
||||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
/**
|
||||||
schemaObject.anyOf![0] as OpenAPIV3.ReferenceObject
|
* Handle Polymorphic inputs, eg string | string[]. In OpenAPI, this is:
|
||||||
);
|
* - an `anyOf` with two items
|
||||||
} else if (schemaObject.oneOf) {
|
* - one is an `ArraySchemaObject` with a single `SchemaObject or ReferenceObject` of type T in its `items`
|
||||||
fieldType = refObjectToFieldType(
|
* - the other is a `SchemaObject` or `ReferenceObject` of type T
|
||||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
*
|
||||||
schemaObject.oneOf![0] as OpenAPIV3.ReferenceObject
|
* Any other cases we ignore.
|
||||||
);
|
*/
|
||||||
|
|
||||||
|
let firstType: string | undefined;
|
||||||
|
let secondType: string | undefined;
|
||||||
|
|
||||||
|
if (isArraySchemaObject(anyOf[0])) {
|
||||||
|
// first is array, second is not
|
||||||
|
const first = anyOf[0].items;
|
||||||
|
const second = anyOf[1];
|
||||||
|
if (isRefObject(first) && isRefObject(second)) {
|
||||||
|
firstType = refObjectToSchemaName(first);
|
||||||
|
secondType = refObjectToSchemaName(second);
|
||||||
|
} else if (
|
||||||
|
isNonArraySchemaObject(first) &&
|
||||||
|
isNonArraySchemaObject(second)
|
||||||
|
) {
|
||||||
|
firstType = first.type;
|
||||||
|
secondType = second.type;
|
||||||
|
}
|
||||||
|
} else if (isArraySchemaObject(anyOf[1])) {
|
||||||
|
// first is not array, second is
|
||||||
|
const first = anyOf[0];
|
||||||
|
const second = anyOf[1].items;
|
||||||
|
if (isRefObject(first) && isRefObject(second)) {
|
||||||
|
firstType = refObjectToSchemaName(first);
|
||||||
|
secondType = refObjectToSchemaName(second);
|
||||||
|
} else if (
|
||||||
|
isNonArraySchemaObject(first) &&
|
||||||
|
isNonArraySchemaObject(second)
|
||||||
|
) {
|
||||||
|
firstType = first.type;
|
||||||
|
secondType = second.type;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (firstType === secondType && isPolymorphicItemType(firstType)) {
|
||||||
|
return SINGLE_TO_POLYMORPHIC_MAP[firstType];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else if (schemaObject.enum) {
|
} else if (schemaObject.enum) {
|
||||||
fieldType = 'enum';
|
return 'enum';
|
||||||
} else if (schemaObject.type) {
|
} else if (schemaObject.type) {
|
||||||
if (schemaObject.type === 'number') {
|
if (schemaObject.type === 'number') {
|
||||||
// floats are "number" in OpenAPI, while ints are "integer"
|
// floats are "number" in OpenAPI, while ints are "integer" - we need to distinguish them
|
||||||
fieldType = 'float';
|
return 'float';
|
||||||
|
} else if (schemaObject.type === 'array') {
|
||||||
|
const itemType = isSchemaObject(schemaObject.items)
|
||||||
|
? schemaObject.items.type
|
||||||
|
: refObjectToSchemaName(schemaObject.items);
|
||||||
|
|
||||||
|
if (isCollectionItemType(itemType)) {
|
||||||
|
return COLLECTION_MAP[itemType];
|
||||||
|
}
|
||||||
|
|
||||||
|
return;
|
||||||
} else {
|
} else {
|
||||||
fieldType = schemaObject.type;
|
return schemaObject.type;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return;
|
||||||
return fieldType;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const TEMPLATE_BUILDER_MAP = {
|
||||||
|
boolean: buildBooleanInputFieldTemplate,
|
||||||
|
BooleanCollection: buildBooleanCollectionInputFieldTemplate,
|
||||||
|
BooleanPolymorphic: buildBooleanPolymorphicInputFieldTemplate,
|
||||||
|
ClipField: buildClipInputFieldTemplate,
|
||||||
|
Collection: buildCollectionInputFieldTemplate,
|
||||||
|
CollectionItem: buildCollectionItemInputFieldTemplate,
|
||||||
|
ColorCollection: buildColorCollectionInputFieldTemplate,
|
||||||
|
ColorField: buildColorInputFieldTemplate,
|
||||||
|
ColorPolymorphic: buildColorPolymorphicInputFieldTemplate,
|
||||||
|
ConditioningCollection: buildConditioningCollectionInputFieldTemplate,
|
||||||
|
ConditioningField: buildConditioningInputFieldTemplate,
|
||||||
|
ConditioningPolymorphic: buildConditioningPolymorphicInputFieldTemplate,
|
||||||
|
ControlCollection: buildControlCollectionInputFieldTemplate,
|
||||||
|
ControlField: buildControlInputFieldTemplate,
|
||||||
|
ControlNetModelField: buildControlNetModelInputFieldTemplate,
|
||||||
|
ControlPolymorphic: buildControlPolymorphicInputFieldTemplate,
|
||||||
|
DenoiseMaskField: buildDenoiseMaskInputFieldTemplate,
|
||||||
|
enum: buildEnumInputFieldTemplate,
|
||||||
|
float: buildFloatInputFieldTemplate,
|
||||||
|
FloatCollection: buildFloatCollectionInputFieldTemplate,
|
||||||
|
FloatPolymorphic: buildFloatPolymorphicInputFieldTemplate,
|
||||||
|
ImageCollection: buildImageCollectionInputFieldTemplate,
|
||||||
|
ImageField: buildImageInputFieldTemplate,
|
||||||
|
ImagePolymorphic: buildImagePolymorphicInputFieldTemplate,
|
||||||
|
integer: buildIntegerInputFieldTemplate,
|
||||||
|
IntegerCollection: buildIntegerCollectionInputFieldTemplate,
|
||||||
|
IntegerPolymorphic: buildIntegerPolymorphicInputFieldTemplate,
|
||||||
|
LatentsCollection: buildLatentsCollectionInputFieldTemplate,
|
||||||
|
LatentsField: buildLatentsInputFieldTemplate,
|
||||||
|
LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate,
|
||||||
|
LoRAModelField: buildLoRAModelInputFieldTemplate,
|
||||||
|
MainModelField: buildMainModelInputFieldTemplate,
|
||||||
|
Scheduler: buildSchedulerInputFieldTemplate,
|
||||||
|
SDXLMainModelField: buildSDXLMainModelInputFieldTemplate,
|
||||||
|
SDXLRefinerModelField: buildRefinerModelInputFieldTemplate,
|
||||||
|
string: buildStringInputFieldTemplate,
|
||||||
|
StringCollection: buildStringCollectionInputFieldTemplate,
|
||||||
|
StringPolymorphic: buildStringPolymorphicInputFieldTemplate,
|
||||||
|
UNetField: buildUNetInputFieldTemplate,
|
||||||
|
VaeField: buildVaeInputFieldTemplate,
|
||||||
|
VaeModelField: buildVaeModelInputFieldTemplate,
|
||||||
|
};
|
||||||
|
|
||||||
|
const isTemplatedFieldType = (
|
||||||
|
fieldType: string | undefined
|
||||||
|
): fieldType is keyof typeof TEMPLATE_BUILDER_MAP =>
|
||||||
|
Boolean(fieldType && fieldType in TEMPLATE_BUILDER_MAP);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds an input field from an invocation schema property.
|
* Builds an input field from an invocation schema property.
|
||||||
* @param fieldSchema The schema object
|
* @param fieldSchema The schema object
|
||||||
@ -474,7 +886,8 @@ export const buildInputFieldTemplate = (
|
|||||||
const { input, ui_hidden, ui_component, ui_type, ui_order } = fieldSchema;
|
const { input, ui_hidden, ui_component, ui_type, ui_order } = fieldSchema;
|
||||||
|
|
||||||
const extra = {
|
const extra = {
|
||||||
input,
|
// TODO: Can we support polymorphic inputs in the UI?
|
||||||
|
input: POLYMORPHIC_TYPES.includes(fieldType) ? 'connection' : input,
|
||||||
ui_hidden,
|
ui_hidden,
|
||||||
ui_component,
|
ui_component,
|
||||||
ui_type,
|
ui_type,
|
||||||
@ -490,146 +903,12 @@ export const buildInputFieldTemplate = (
|
|||||||
...extra,
|
...extra,
|
||||||
};
|
};
|
||||||
|
|
||||||
if (fieldType === 'ImageField') {
|
if (!isTemplatedFieldType(fieldType)) {
|
||||||
return buildImageInputFieldTemplate({
|
return;
|
||||||
schemaObject: fieldSchema,
|
|
||||||
baseField,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
if (fieldType === 'ImageCollection') {
|
|
||||||
return buildImageCollectionInputFieldTemplate({
|
return TEMPLATE_BUILDER_MAP[fieldType]({
|
||||||
schemaObject: fieldSchema,
|
schemaObject: fieldSchema,
|
||||||
baseField,
|
baseField,
|
||||||
});
|
});
|
||||||
}
|
|
||||||
if (fieldType === 'DenoiseMaskField') {
|
|
||||||
return buildDenoiseMaskInputFieldTemplate({
|
|
||||||
schemaObject: fieldSchema,
|
|
||||||
baseField,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (fieldType === 'LatentsField') {
|
|
||||||
return buildLatentsInputFieldTemplate({
|
|
||||||
schemaObject: fieldSchema,
|
|
||||||
baseField,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (fieldType === 'ConditioningField') {
|
|
||||||
return buildConditioningInputFieldTemplate({
|
|
||||||
schemaObject: fieldSchema,
|
|
||||||
baseField,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (fieldType === 'UNetField') {
|
|
||||||
return buildUNetInputFieldTemplate({
|
|
||||||
schemaObject: fieldSchema,
|
|
||||||
baseField,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (fieldType === 'ClipField') {
|
|
||||||
return buildClipInputFieldTemplate({
|
|
||||||
schemaObject: fieldSchema,
|
|
||||||
baseField,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (fieldType === 'VaeField') {
|
|
||||||
return buildVaeInputFieldTemplate({ schemaObject: fieldSchema, baseField });
|
|
||||||
}
|
|
||||||
if (fieldType === 'ControlField') {
|
|
||||||
return buildControlInputFieldTemplate({
|
|
||||||
schemaObject: fieldSchema,
|
|
||||||
baseField,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (fieldType === 'MainModelField') {
|
|
||||||
return buildMainModelInputFieldTemplate({
|
|
||||||
schemaObject: fieldSchema,
|
|
||||||
baseField,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (fieldType === 'SDXLRefinerModelField') {
|
|
||||||
return buildRefinerModelInputFieldTemplate({
|
|
||||||
schemaObject: fieldSchema,
|
|
||||||
baseField,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (fieldType === 'SDXLMainModelField') {
|
|
||||||
return buildSDXLMainModelInputFieldTemplate({
|
|
||||||
schemaObject: fieldSchema,
|
|
||||||
baseField,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (fieldType === 'VaeModelField') {
|
|
||||||
return buildVaeModelInputFieldTemplate({
|
|
||||||
schemaObject: fieldSchema,
|
|
||||||
baseField,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (fieldType === 'LoRAModelField') {
|
|
||||||
return buildLoRAModelInputFieldTemplate({
|
|
||||||
schemaObject: fieldSchema,
|
|
||||||
baseField,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (fieldType === 'ControlNetModelField') {
|
|
||||||
return buildControlNetModelInputFieldTemplate({
|
|
||||||
schemaObject: fieldSchema,
|
|
||||||
baseField,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (fieldType === 'enum') {
|
|
||||||
return buildEnumInputFieldTemplate({
|
|
||||||
schemaObject: fieldSchema,
|
|
||||||
baseField,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (fieldType === 'integer') {
|
|
||||||
return buildIntegerInputFieldTemplate({
|
|
||||||
schemaObject: fieldSchema,
|
|
||||||
baseField,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (fieldType === 'float') {
|
|
||||||
return buildFloatInputFieldTemplate({
|
|
||||||
schemaObject: fieldSchema,
|
|
||||||
baseField,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (fieldType === 'string') {
|
|
||||||
return buildStringInputFieldTemplate({
|
|
||||||
schemaObject: fieldSchema,
|
|
||||||
baseField,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (fieldType === 'boolean') {
|
|
||||||
return buildBooleanInputFieldTemplate({
|
|
||||||
schemaObject: fieldSchema,
|
|
||||||
baseField,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (fieldType === 'Collection') {
|
|
||||||
return buildCollectionInputFieldTemplate({
|
|
||||||
schemaObject: fieldSchema,
|
|
||||||
baseField,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (fieldType === 'CollectionItem') {
|
|
||||||
return buildCollectionItemInputFieldTemplate({
|
|
||||||
schemaObject: fieldSchema,
|
|
||||||
baseField,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (fieldType === 'ColorField') {
|
|
||||||
return buildColorInputFieldTemplate({
|
|
||||||
schemaObject: fieldSchema,
|
|
||||||
baseField,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (fieldType === 'Scheduler') {
|
|
||||||
return buildSchedulerInputFieldTemplate({
|
|
||||||
schemaObject: fieldSchema,
|
|
||||||
baseField,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
};
|
};
|
||||||
|
@ -1,104 +1,79 @@
|
|||||||
import { InputFieldTemplate, InputFieldValue } from '../types/types';
|
import { InputFieldTemplate, InputFieldValue } from '../types/types';
|
||||||
|
|
||||||
|
const FIELD_VALUE_FALLBACK_MAP = {
|
||||||
|
'enum.number': 0,
|
||||||
|
'enum.string': '',
|
||||||
|
boolean: false,
|
||||||
|
BooleanCollection: [],
|
||||||
|
BooleanPolymorphic: false,
|
||||||
|
ClipField: undefined,
|
||||||
|
Collection: [],
|
||||||
|
CollectionItem: undefined,
|
||||||
|
ColorCollection: [],
|
||||||
|
ColorField: undefined,
|
||||||
|
ColorPolymorphic: undefined,
|
||||||
|
ConditioningCollection: [],
|
||||||
|
ConditioningField: undefined,
|
||||||
|
ConditioningPolymorphic: undefined,
|
||||||
|
ControlCollection: [],
|
||||||
|
ControlField: undefined,
|
||||||
|
ControlNetModelField: undefined,
|
||||||
|
ControlPolymorphic: undefined,
|
||||||
|
DenoiseMaskField: undefined,
|
||||||
|
float: 0,
|
||||||
|
FloatCollection: [],
|
||||||
|
FloatPolymorphic: 0,
|
||||||
|
ImageCollection: [],
|
||||||
|
ImageField: undefined,
|
||||||
|
ImagePolymorphic: undefined,
|
||||||
|
integer: 0,
|
||||||
|
IntegerCollection: [],
|
||||||
|
IntegerPolymorphic: 0,
|
||||||
|
LatentsCollection: [],
|
||||||
|
LatentsField: undefined,
|
||||||
|
LatentsPolymorphic: undefined,
|
||||||
|
LoRAModelField: undefined,
|
||||||
|
MainModelField: undefined,
|
||||||
|
ONNXModelField: undefined,
|
||||||
|
Scheduler: 'euler',
|
||||||
|
SDXLMainModelField: undefined,
|
||||||
|
SDXLRefinerModelField: undefined,
|
||||||
|
string: '',
|
||||||
|
StringCollection: [],
|
||||||
|
StringPolymorphic: '',
|
||||||
|
UNetField: undefined,
|
||||||
|
VaeField: undefined,
|
||||||
|
VaeModelField: undefined,
|
||||||
|
};
|
||||||
|
|
||||||
export const buildInputFieldValue = (
|
export const buildInputFieldValue = (
|
||||||
id: string,
|
id: string,
|
||||||
template: InputFieldTemplate
|
template: InputFieldTemplate
|
||||||
): InputFieldValue => {
|
): InputFieldValue => {
|
||||||
const fieldValue: InputFieldValue = {
|
// TODO: this should be `fieldValue: InputFieldValue`, but that introduces a TS issue I couldn't
|
||||||
|
// resolve - for some reason, it doesn't like `template.type`, which is the discriminant for both
|
||||||
|
// `InputFieldTemplate` union. It is (type-structurally) equal to the discriminant for the
|
||||||
|
// `InputFieldValue` union, but TS doesn't seem to like it...
|
||||||
|
const fieldValue = {
|
||||||
id,
|
id,
|
||||||
name: template.name,
|
name: template.name,
|
||||||
type: template.type,
|
type: template.type,
|
||||||
label: '',
|
label: '',
|
||||||
fieldKind: 'input',
|
fieldKind: 'input',
|
||||||
};
|
} as InputFieldValue;
|
||||||
|
|
||||||
if (template.type === 'string') {
|
|
||||||
fieldValue.value = template.default ?? '';
|
|
||||||
}
|
|
||||||
|
|
||||||
if (template.type === 'integer') {
|
|
||||||
fieldValue.value = template.default ?? 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (template.type === 'float') {
|
|
||||||
fieldValue.value = template.default ?? 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (template.type === 'boolean') {
|
|
||||||
fieldValue.value = template.default ?? false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (template.type === 'enum') {
|
if (template.type === 'enum') {
|
||||||
if (template.enumType === 'number') {
|
if (template.enumType === 'number') {
|
||||||
fieldValue.value = template.default ?? 0;
|
fieldValue.value =
|
||||||
|
template.default ?? FIELD_VALUE_FALLBACK_MAP['enum.number'];
|
||||||
}
|
}
|
||||||
if (template.enumType === 'string') {
|
if (template.enumType === 'string') {
|
||||||
fieldValue.value = template.default ?? '';
|
fieldValue.value =
|
||||||
|
template.default ?? FIELD_VALUE_FALLBACK_MAP['enum.string'];
|
||||||
}
|
}
|
||||||
}
|
} else {
|
||||||
|
fieldValue.value =
|
||||||
if (template.type === 'Collection') {
|
template.default ?? FIELD_VALUE_FALLBACK_MAP[template.type];
|
||||||
fieldValue.value = template.default ?? 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (template.type === 'ImageField') {
|
|
||||||
fieldValue.value = undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (template.type === 'ImageCollection') {
|
|
||||||
fieldValue.value = [];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (template.type === 'DenoiseMaskField') {
|
|
||||||
fieldValue.value = undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (template.type === 'LatentsField') {
|
|
||||||
fieldValue.value = undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (template.type === 'ConditioningField') {
|
|
||||||
fieldValue.value = undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (template.type === 'UNetField') {
|
|
||||||
fieldValue.value = undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (template.type === 'ClipField') {
|
|
||||||
fieldValue.value = undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (template.type === 'VaeField') {
|
|
||||||
fieldValue.value = undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (template.type === 'ControlField') {
|
|
||||||
fieldValue.value = undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (template.type === 'MainModelField') {
|
|
||||||
fieldValue.value = undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (template.type === 'SDXLRefinerModelField') {
|
|
||||||
fieldValue.value = undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (template.type === 'VaeModelField') {
|
|
||||||
fieldValue.value = undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (template.type === 'LoRAModelField') {
|
|
||||||
fieldValue.value = undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (template.type === 'ControlNetModelField') {
|
|
||||||
fieldValue.value = undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (template.type === 'Scheduler') {
|
|
||||||
fieldValue.value = 'euler';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return fieldValue;
|
return fieldValue;
|
||||||
|
@ -10,7 +10,8 @@ import {
|
|||||||
CANVAS_OUTPUT,
|
CANVAS_OUTPUT,
|
||||||
INPAINT_IMAGE_RESIZE_UP,
|
INPAINT_IMAGE_RESIZE_UP,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
MASK_BLUR,
|
MASK_COMBINE,
|
||||||
|
MASK_RESIZE_UP,
|
||||||
METADATA_ACCUMULATOR,
|
METADATA_ACCUMULATOR,
|
||||||
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
||||||
SDXL_CANVAS_INPAINT_GRAPH,
|
SDXL_CANVAS_INPAINT_GRAPH,
|
||||||
@ -46,6 +47,8 @@ export const addSDXLRefinerToGraph = (
|
|||||||
const { seamlessXAxis, seamlessYAxis, vaePrecision } = state.generation;
|
const { seamlessXAxis, seamlessYAxis, vaePrecision } = state.generation;
|
||||||
const { boundingBoxScaleMethod } = state.canvas;
|
const { boundingBoxScaleMethod } = state.canvas;
|
||||||
|
|
||||||
|
const fp32 = vaePrecision === 'fp32';
|
||||||
|
|
||||||
const isUsingScaledDimensions = ['auto', 'manual'].includes(
|
const isUsingScaledDimensions = ['auto', 'manual'].includes(
|
||||||
boundingBoxScaleMethod
|
boundingBoxScaleMethod
|
||||||
);
|
);
|
||||||
@ -231,7 +234,7 @@ export const addSDXLRefinerToGraph = (
|
|||||||
type: 'create_denoise_mask',
|
type: 'create_denoise_mask',
|
||||||
id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
};
|
};
|
||||||
|
|
||||||
if (isUsingScaledDimensions) {
|
if (isUsingScaledDimensions) {
|
||||||
@ -257,7 +260,7 @@ export const addSDXLRefinerToGraph = (
|
|||||||
graph.edges.push(
|
graph.edges.push(
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: MASK_BLUR,
|
node_id: isUsingScaledDimensions ? MASK_RESIZE_UP : MASK_COMBINE,
|
||||||
field: 'image',
|
field: 'image',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
|
@ -2,6 +2,7 @@ import { RootState } from 'app/store/store';
|
|||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
import { MetadataAccumulatorInvocation } from 'services/api/types';
|
import { MetadataAccumulatorInvocation } from 'services/api/types';
|
||||||
import {
|
import {
|
||||||
|
CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
||||||
CANVAS_INPAINT_GRAPH,
|
CANVAS_INPAINT_GRAPH,
|
||||||
CANVAS_OUTPAINT_GRAPH,
|
CANVAS_OUTPAINT_GRAPH,
|
||||||
@ -31,7 +32,7 @@ export const addVAEToGraph = (
|
|||||||
graph: NonNullableGraph,
|
graph: NonNullableGraph,
|
||||||
modelLoaderNodeId: string = MAIN_MODEL_LOADER
|
modelLoaderNodeId: string = MAIN_MODEL_LOADER
|
||||||
): void => {
|
): void => {
|
||||||
const { vae } = state.generation;
|
const { vae, canvasCoherenceMode } = state.generation;
|
||||||
const { boundingBoxScaleMethod } = state.canvas;
|
const { boundingBoxScaleMethod } = state.canvas;
|
||||||
const { shouldUseSDXLRefiner } = state.sdxl;
|
const { shouldUseSDXLRefiner } = state.sdxl;
|
||||||
|
|
||||||
@ -146,6 +147,20 @@ export const addVAEToGraph = (
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Handle Coherence Mode
|
||||||
|
if (canvasCoherenceMode !== 'unmasked') {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
||||||
|
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (shouldUseSDXLRefiner) {
|
if (shouldUseSDXLRefiner) {
|
||||||
|
@ -59,6 +59,8 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
shouldAutoSave,
|
shouldAutoSave,
|
||||||
} = state.canvas;
|
} = state.canvas;
|
||||||
|
|
||||||
|
const fp32 = vaePrecision === 'fp32';
|
||||||
|
|
||||||
const isUsingScaledDimensions = ['auto', 'manual'].includes(
|
const isUsingScaledDimensions = ['auto', 'manual'].includes(
|
||||||
boundingBoxScaleMethod
|
boundingBoxScaleMethod
|
||||||
);
|
);
|
||||||
@ -245,7 +247,7 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
id: LATENTS_TO_IMAGE,
|
id: LATENTS_TO_IMAGE,
|
||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
};
|
};
|
||||||
graph.nodes[CANVAS_OUTPUT] = {
|
graph.nodes[CANVAS_OUTPUT] = {
|
||||||
id: CANVAS_OUTPUT,
|
id: CANVAS_OUTPUT,
|
||||||
@ -292,7 +294,7 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
id: CANVAS_OUTPUT,
|
id: CANVAS_OUTPUT,
|
||||||
is_intermediate: !shouldAutoSave,
|
is_intermediate: !shouldAutoSave,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
};
|
};
|
||||||
|
|
||||||
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image =
|
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image =
|
||||||
|
@ -6,6 +6,7 @@ import {
|
|||||||
ImageBlurInvocation,
|
ImageBlurInvocation,
|
||||||
ImageDTO,
|
ImageDTO,
|
||||||
ImageToLatentsInvocation,
|
ImageToLatentsInvocation,
|
||||||
|
MaskEdgeInvocation,
|
||||||
NoiseInvocation,
|
NoiseInvocation,
|
||||||
RandomIntInvocation,
|
RandomIntInvocation,
|
||||||
RangeOfSizeInvocation,
|
RangeOfSizeInvocation,
|
||||||
@ -18,6 +19,8 @@ import { addVAEToGraph } from './addVAEToGraph';
|
|||||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||||
import {
|
import {
|
||||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||||
|
CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
CANVAS_COHERENCE_MASK_EDGE,
|
||||||
CANVAS_COHERENCE_NOISE,
|
CANVAS_COHERENCE_NOISE,
|
||||||
CANVAS_COHERENCE_NOISE_INCREMENT,
|
CANVAS_COHERENCE_NOISE_INCREMENT,
|
||||||
CANVAS_INPAINT_GRAPH,
|
CANVAS_INPAINT_GRAPH,
|
||||||
@ -67,6 +70,7 @@ export const buildCanvasInpaintGraph = (
|
|||||||
shouldUseCpuNoise,
|
shouldUseCpuNoise,
|
||||||
maskBlur,
|
maskBlur,
|
||||||
maskBlurMethod,
|
maskBlurMethod,
|
||||||
|
canvasCoherenceMode,
|
||||||
canvasCoherenceSteps,
|
canvasCoherenceSteps,
|
||||||
canvasCoherenceStrength,
|
canvasCoherenceStrength,
|
||||||
clipSkip,
|
clipSkip,
|
||||||
@ -89,6 +93,12 @@ export const buildCanvasInpaintGraph = (
|
|||||||
shouldAutoSave,
|
shouldAutoSave,
|
||||||
} = state.canvas;
|
} = state.canvas;
|
||||||
|
|
||||||
|
const fp32 = vaePrecision === 'fp32';
|
||||||
|
|
||||||
|
const isUsingScaledDimensions = ['auto', 'manual'].includes(
|
||||||
|
boundingBoxScaleMethod
|
||||||
|
);
|
||||||
|
|
||||||
let modelLoaderNodeId = MAIN_MODEL_LOADER;
|
let modelLoaderNodeId = MAIN_MODEL_LOADER;
|
||||||
|
|
||||||
const use_cpu = shouldUseNoiseSettings
|
const use_cpu = shouldUseNoiseSettings
|
||||||
@ -133,13 +143,7 @@ export const buildCanvasInpaintGraph = (
|
|||||||
type: 'i2l',
|
type: 'i2l',
|
||||||
id: INPAINT_IMAGE,
|
id: INPAINT_IMAGE,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
},
|
|
||||||
[INPAINT_CREATE_MASK]: {
|
|
||||||
type: 'create_denoise_mask',
|
|
||||||
id: INPAINT_CREATE_MASK,
|
|
||||||
is_intermediate: true,
|
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
|
||||||
},
|
},
|
||||||
[NOISE]: {
|
[NOISE]: {
|
||||||
type: 'noise',
|
type: 'noise',
|
||||||
@ -147,6 +151,12 @@ export const buildCanvasInpaintGraph = (
|
|||||||
use_cpu,
|
use_cpu,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
},
|
},
|
||||||
|
[INPAINT_CREATE_MASK]: {
|
||||||
|
type: 'create_denoise_mask',
|
||||||
|
id: INPAINT_CREATE_MASK,
|
||||||
|
is_intermediate: true,
|
||||||
|
fp32,
|
||||||
|
},
|
||||||
[DENOISE_LATENTS]: {
|
[DENOISE_LATENTS]: {
|
||||||
type: 'denoise_latents',
|
type: 'denoise_latents',
|
||||||
id: DENOISE_LATENTS,
|
id: DENOISE_LATENTS,
|
||||||
@ -171,7 +181,7 @@ export const buildCanvasInpaintGraph = (
|
|||||||
},
|
},
|
||||||
[CANVAS_COHERENCE_DENOISE_LATENTS]: {
|
[CANVAS_COHERENCE_DENOISE_LATENTS]: {
|
||||||
type: 'denoise_latents',
|
type: 'denoise_latents',
|
||||||
id: DENOISE_LATENTS,
|
id: CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
steps: canvasCoherenceSteps,
|
steps: canvasCoherenceSteps,
|
||||||
cfg_scale: cfg_scale,
|
cfg_scale: cfg_scale,
|
||||||
@ -183,7 +193,7 @@ export const buildCanvasInpaintGraph = (
|
|||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
id: LATENTS_TO_IMAGE,
|
id: LATENTS_TO_IMAGE,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
},
|
},
|
||||||
[CANVAS_OUTPUT]: {
|
[CANVAS_OUTPUT]: {
|
||||||
type: 'color_correct',
|
type: 'color_correct',
|
||||||
@ -418,7 +428,7 @@ export const buildCanvasInpaintGraph = (
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Handle Scale Before Processing
|
// Handle Scale Before Processing
|
||||||
if (['auto', 'manual'].includes(boundingBoxScaleMethod)) {
|
if (isUsingScaledDimensions) {
|
||||||
const scaledWidth: number = scaledBoundingBoxDimensions.width;
|
const scaledWidth: number = scaledBoundingBoxDimensions.width;
|
||||||
const scaledHeight: number = scaledBoundingBoxDimensions.height;
|
const scaledHeight: number = scaledBoundingBoxDimensions.height;
|
||||||
|
|
||||||
@ -581,6 +591,116 @@ export const buildCanvasInpaintGraph = (
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle Coherence Mode
|
||||||
|
if (canvasCoherenceMode !== 'unmasked') {
|
||||||
|
// Create Mask If Coherence Mode Is Not Full
|
||||||
|
graph.nodes[CANVAS_COHERENCE_INPAINT_CREATE_MASK] = {
|
||||||
|
type: 'create_denoise_mask',
|
||||||
|
id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
is_intermediate: true,
|
||||||
|
fp32,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Handle Image Input For Mask Creation
|
||||||
|
if (isUsingScaledDimensions) {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: INPAINT_IMAGE_RESIZE_UP,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
graph.nodes[CANVAS_COHERENCE_INPAINT_CREATE_MASK] = {
|
||||||
|
...(graph.nodes[
|
||||||
|
CANVAS_COHERENCE_INPAINT_CREATE_MASK
|
||||||
|
] as CreateDenoiseMaskInvocation),
|
||||||
|
image: canvasInitImage,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create Mask If Coherence Mode Is Mask
|
||||||
|
if (canvasCoherenceMode === 'mask') {
|
||||||
|
if (isUsingScaledDimensions) {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: MASK_RESIZE_UP,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
field: 'mask',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
graph.nodes[CANVAS_COHERENCE_INPAINT_CREATE_MASK] = {
|
||||||
|
...(graph.nodes[
|
||||||
|
CANVAS_COHERENCE_INPAINT_CREATE_MASK
|
||||||
|
] as CreateDenoiseMaskInvocation),
|
||||||
|
mask: canvasMaskImage,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create Mask Edge If Coherence Mode Is Edge
|
||||||
|
if (canvasCoherenceMode === 'edge') {
|
||||||
|
graph.nodes[CANVAS_COHERENCE_MASK_EDGE] = {
|
||||||
|
type: 'mask_edge',
|
||||||
|
id: CANVAS_COHERENCE_MASK_EDGE,
|
||||||
|
is_intermediate: true,
|
||||||
|
edge_blur: maskBlur,
|
||||||
|
edge_size: maskBlur * 2,
|
||||||
|
low_threshold: 100,
|
||||||
|
high_threshold: 200,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Handle Scaled Dimensions For Mask Edge
|
||||||
|
if (isUsingScaledDimensions) {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: MASK_RESIZE_UP,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_MASK_EDGE,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
graph.nodes[CANVAS_COHERENCE_MASK_EDGE] = {
|
||||||
|
...(graph.nodes[CANVAS_COHERENCE_MASK_EDGE] as MaskEdgeInvocation),
|
||||||
|
image: canvasMaskImage,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: CANVAS_COHERENCE_MASK_EDGE,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
field: 'mask',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Plug Denoise Mask To Coherence Denoise Latents
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
field: 'denoise_mask',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||||
|
field: 'denoise_mask',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// Handle Seed
|
// Handle Seed
|
||||||
if (shouldRandomizeSeed) {
|
if (shouldRandomizeSeed) {
|
||||||
// Random int node to generate the starting seed
|
// Random int node to generate the starting seed
|
||||||
|
@ -2,7 +2,6 @@ import { logger } from 'app/logging/logger';
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
import {
|
import {
|
||||||
ImageBlurInvocation,
|
|
||||||
ImageDTO,
|
ImageDTO,
|
||||||
ImageToLatentsInvocation,
|
ImageToLatentsInvocation,
|
||||||
InfillPatchMatchInvocation,
|
InfillPatchMatchInvocation,
|
||||||
@ -19,6 +18,8 @@ import { addVAEToGraph } from './addVAEToGraph';
|
|||||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||||
import {
|
import {
|
||||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||||
|
CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
CANVAS_COHERENCE_MASK_EDGE,
|
||||||
CANVAS_COHERENCE_NOISE,
|
CANVAS_COHERENCE_NOISE,
|
||||||
CANVAS_COHERENCE_NOISE_INCREMENT,
|
CANVAS_COHERENCE_NOISE_INCREMENT,
|
||||||
CANVAS_OUTPAINT_GRAPH,
|
CANVAS_OUTPAINT_GRAPH,
|
||||||
@ -34,7 +35,6 @@ import {
|
|||||||
ITERATE,
|
ITERATE,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
MAIN_MODEL_LOADER,
|
MAIN_MODEL_LOADER,
|
||||||
MASK_BLUR,
|
|
||||||
MASK_COMBINE,
|
MASK_COMBINE,
|
||||||
MASK_FROM_ALPHA,
|
MASK_FROM_ALPHA,
|
||||||
MASK_RESIZE_DOWN,
|
MASK_RESIZE_DOWN,
|
||||||
@ -71,10 +71,11 @@ export const buildCanvasOutpaintGraph = (
|
|||||||
shouldUseNoiseSettings,
|
shouldUseNoiseSettings,
|
||||||
shouldUseCpuNoise,
|
shouldUseCpuNoise,
|
||||||
maskBlur,
|
maskBlur,
|
||||||
maskBlurMethod,
|
canvasCoherenceMode,
|
||||||
canvasCoherenceSteps,
|
canvasCoherenceSteps,
|
||||||
canvasCoherenceStrength,
|
canvasCoherenceStrength,
|
||||||
tileSize,
|
infillTileSize,
|
||||||
|
infillPatchmatchDownscaleSize,
|
||||||
infillMethod,
|
infillMethod,
|
||||||
clipSkip,
|
clipSkip,
|
||||||
seamlessXAxis,
|
seamlessXAxis,
|
||||||
@ -96,6 +97,12 @@ export const buildCanvasOutpaintGraph = (
|
|||||||
shouldAutoSave,
|
shouldAutoSave,
|
||||||
} = state.canvas;
|
} = state.canvas;
|
||||||
|
|
||||||
|
const fp32 = vaePrecision === 'fp32';
|
||||||
|
|
||||||
|
const isUsingScaledDimensions = ['auto', 'manual'].includes(
|
||||||
|
boundingBoxScaleMethod
|
||||||
|
);
|
||||||
|
|
||||||
let modelLoaderNodeId = MAIN_MODEL_LOADER;
|
let modelLoaderNodeId = MAIN_MODEL_LOADER;
|
||||||
|
|
||||||
const use_cpu = shouldUseNoiseSettings
|
const use_cpu = shouldUseNoiseSettings
|
||||||
@ -141,18 +148,11 @@ export const buildCanvasOutpaintGraph = (
|
|||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
mask2: canvasMaskImage,
|
mask2: canvasMaskImage,
|
||||||
},
|
},
|
||||||
[MASK_BLUR]: {
|
|
||||||
type: 'img_blur',
|
|
||||||
id: MASK_BLUR,
|
|
||||||
is_intermediate: true,
|
|
||||||
radius: maskBlur,
|
|
||||||
blur_type: maskBlurMethod,
|
|
||||||
},
|
|
||||||
[INPAINT_IMAGE]: {
|
[INPAINT_IMAGE]: {
|
||||||
type: 'i2l',
|
type: 'i2l',
|
||||||
id: INPAINT_IMAGE,
|
id: INPAINT_IMAGE,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
},
|
},
|
||||||
[NOISE]: {
|
[NOISE]: {
|
||||||
type: 'noise',
|
type: 'noise',
|
||||||
@ -164,7 +164,7 @@ export const buildCanvasOutpaintGraph = (
|
|||||||
type: 'create_denoise_mask',
|
type: 'create_denoise_mask',
|
||||||
id: INPAINT_CREATE_MASK,
|
id: INPAINT_CREATE_MASK,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
},
|
},
|
||||||
[DENOISE_LATENTS]: {
|
[DENOISE_LATENTS]: {
|
||||||
type: 'denoise_latents',
|
type: 'denoise_latents',
|
||||||
@ -202,7 +202,7 @@ export const buildCanvasOutpaintGraph = (
|
|||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
id: LATENTS_TO_IMAGE,
|
id: LATENTS_TO_IMAGE,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
},
|
},
|
||||||
[CANVAS_OUTPUT]: {
|
[CANVAS_OUTPUT]: {
|
||||||
type: 'color_correct',
|
type: 'color_correct',
|
||||||
@ -333,7 +333,7 @@ export const buildCanvasOutpaintGraph = (
|
|||||||
// Create Inpaint Mask
|
// Create Inpaint Mask
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: MASK_BLUR,
|
node_id: isUsingScaledDimensions ? MASK_RESIZE_UP : MASK_COMBINE,
|
||||||
field: 'image',
|
field: 'image',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -443,6 +443,16 @@ export const buildCanvasOutpaintGraph = (
|
|||||||
field: 'latents',
|
field: 'latents',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: INPAINT_INFILL,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: INPAINT_CREATE_MASK,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
},
|
||||||
// Decode the result from Inpaint
|
// Decode the result from Inpaint
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
@ -463,6 +473,7 @@ export const buildCanvasOutpaintGraph = (
|
|||||||
type: 'infill_patchmatch',
|
type: 'infill_patchmatch',
|
||||||
id: INPAINT_INFILL,
|
id: INPAINT_INFILL,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
|
downscale: infillPatchmatchDownscaleSize,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -474,17 +485,25 @@ export const buildCanvasOutpaintGraph = (
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (infillMethod === 'cv2') {
|
||||||
|
graph.nodes[INPAINT_INFILL] = {
|
||||||
|
type: 'infill_cv2',
|
||||||
|
id: INPAINT_INFILL,
|
||||||
|
is_intermediate: true,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
if (infillMethod === 'tile') {
|
if (infillMethod === 'tile') {
|
||||||
graph.nodes[INPAINT_INFILL] = {
|
graph.nodes[INPAINT_INFILL] = {
|
||||||
type: 'infill_tile',
|
type: 'infill_tile',
|
||||||
id: INPAINT_INFILL,
|
id: INPAINT_INFILL,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
tile_size: tileSize,
|
tile_size: infillTileSize,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle Scale Before Processing
|
// Handle Scale Before Processing
|
||||||
if (['auto', 'manual'].includes(boundingBoxScaleMethod)) {
|
if (isUsingScaledDimensions) {
|
||||||
const scaledWidth: number = scaledBoundingBoxDimensions.width;
|
const scaledWidth: number = scaledBoundingBoxDimensions.width;
|
||||||
const scaledHeight: number = scaledBoundingBoxDimensions.height;
|
const scaledHeight: number = scaledBoundingBoxDimensions.height;
|
||||||
|
|
||||||
@ -546,16 +565,6 @@ export const buildCanvasOutpaintGraph = (
|
|||||||
field: 'image',
|
field: 'image',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
source: {
|
|
||||||
node_id: INPAINT_INFILL,
|
|
||||||
field: 'image',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: INPAINT_CREATE_MASK,
|
|
||||||
field: 'image',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
// Take combined mask and resize and then blur
|
// Take combined mask and resize and then blur
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
@ -567,16 +576,7 @@ export const buildCanvasOutpaintGraph = (
|
|||||||
field: 'image',
|
field: 'image',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
source: {
|
|
||||||
node_id: MASK_RESIZE_UP,
|
|
||||||
field: 'image',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: MASK_BLUR,
|
|
||||||
field: 'image',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
// Resize Results Down
|
// Resize Results Down
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
@ -658,32 +658,8 @@ export const buildCanvasOutpaintGraph = (
|
|||||||
...(graph.nodes[INPAINT_IMAGE] as ImageToLatentsInvocation),
|
...(graph.nodes[INPAINT_IMAGE] as ImageToLatentsInvocation),
|
||||||
image: canvasInitImage,
|
image: canvasInitImage,
|
||||||
};
|
};
|
||||||
graph.nodes[MASK_BLUR] = {
|
|
||||||
...(graph.nodes[MASK_BLUR] as ImageBlurInvocation),
|
|
||||||
};
|
|
||||||
|
|
||||||
graph.edges.push(
|
graph.edges.push(
|
||||||
// Take combined mask and plug it to blur
|
|
||||||
{
|
|
||||||
source: {
|
|
||||||
node_id: MASK_COMBINE,
|
|
||||||
field: 'image',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: MASK_BLUR,
|
|
||||||
field: 'image',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
source: {
|
|
||||||
node_id: INPAINT_INFILL,
|
|
||||||
field: 'image',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: INPAINT_CREATE_MASK,
|
|
||||||
field: 'image',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
// Color Correct The Inpainted Result
|
// Color Correct The Inpainted Result
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
@ -707,7 +683,7 @@ export const buildCanvasOutpaintGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: MASK_BLUR,
|
node_id: MASK_COMBINE,
|
||||||
field: 'image',
|
field: 'image',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -718,6 +694,115 @@ export const buildCanvasOutpaintGraph = (
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle Coherence Mode
|
||||||
|
if (canvasCoherenceMode !== 'unmasked') {
|
||||||
|
graph.nodes[CANVAS_COHERENCE_INPAINT_CREATE_MASK] = {
|
||||||
|
type: 'create_denoise_mask',
|
||||||
|
id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
is_intermediate: true,
|
||||||
|
fp32,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Handle Image Input For Mask Creation
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: INPAINT_INFILL,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Create Mask If Coherence Mode Is Mask
|
||||||
|
if (canvasCoherenceMode === 'mask') {
|
||||||
|
if (isUsingScaledDimensions) {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: MASK_RESIZE_UP,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
field: 'mask',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: MASK_COMBINE,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
field: 'mask',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (canvasCoherenceMode === 'edge') {
|
||||||
|
graph.nodes[CANVAS_COHERENCE_MASK_EDGE] = {
|
||||||
|
type: 'mask_edge',
|
||||||
|
id: CANVAS_COHERENCE_MASK_EDGE,
|
||||||
|
is_intermediate: true,
|
||||||
|
edge_blur: maskBlur,
|
||||||
|
edge_size: maskBlur * 2,
|
||||||
|
low_threshold: 100,
|
||||||
|
high_threshold: 200,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Handle Scaled Dimensions For Mask Edge
|
||||||
|
if (isUsingScaledDimensions) {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: MASK_RESIZE_UP,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_MASK_EDGE,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: MASK_COMBINE,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_MASK_EDGE,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: CANVAS_COHERENCE_MASK_EDGE,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
field: 'mask',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Plug Denoise Mask To Coherence Denoise Latents
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
field: 'denoise_mask',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||||
|
field: 'denoise_mask',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// Handle Seed
|
// Handle Seed
|
||||||
if (shouldRandomizeSeed) {
|
if (shouldRandomizeSeed) {
|
||||||
// Random int node to generate the starting seed
|
// Random int node to generate the starting seed
|
||||||
|
@ -67,6 +67,8 @@ export const buildCanvasSDXLImageToImageGraph = (
|
|||||||
shouldAutoSave,
|
shouldAutoSave,
|
||||||
} = state.canvas;
|
} = state.canvas;
|
||||||
|
|
||||||
|
const fp32 = vaePrecision === 'fp32';
|
||||||
|
|
||||||
const isUsingScaledDimensions = ['auto', 'manual'].includes(
|
const isUsingScaledDimensions = ['auto', 'manual'].includes(
|
||||||
boundingBoxScaleMethod
|
boundingBoxScaleMethod
|
||||||
);
|
);
|
||||||
@ -133,7 +135,7 @@ export const buildCanvasSDXLImageToImageGraph = (
|
|||||||
type: 'i2l',
|
type: 'i2l',
|
||||||
id: IMAGE_TO_LATENTS,
|
id: IMAGE_TO_LATENTS,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
},
|
},
|
||||||
[SDXL_DENOISE_LATENTS]: {
|
[SDXL_DENOISE_LATENTS]: {
|
||||||
type: 'denoise_latents',
|
type: 'denoise_latents',
|
||||||
@ -258,7 +260,7 @@ export const buildCanvasSDXLImageToImageGraph = (
|
|||||||
id: LATENTS_TO_IMAGE,
|
id: LATENTS_TO_IMAGE,
|
||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
};
|
};
|
||||||
graph.nodes[CANVAS_OUTPUT] = {
|
graph.nodes[CANVAS_OUTPUT] = {
|
||||||
id: CANVAS_OUTPUT,
|
id: CANVAS_OUTPUT,
|
||||||
@ -305,7 +307,7 @@ export const buildCanvasSDXLImageToImageGraph = (
|
|||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
id: CANVAS_OUTPUT,
|
id: CANVAS_OUTPUT,
|
||||||
is_intermediate: !shouldAutoSave,
|
is_intermediate: !shouldAutoSave,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
};
|
};
|
||||||
|
|
||||||
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image =
|
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image =
|
||||||
|
@ -6,6 +6,7 @@ import {
|
|||||||
ImageBlurInvocation,
|
ImageBlurInvocation,
|
||||||
ImageDTO,
|
ImageDTO,
|
||||||
ImageToLatentsInvocation,
|
ImageToLatentsInvocation,
|
||||||
|
MaskEdgeInvocation,
|
||||||
NoiseInvocation,
|
NoiseInvocation,
|
||||||
RandomIntInvocation,
|
RandomIntInvocation,
|
||||||
RangeOfSizeInvocation,
|
RangeOfSizeInvocation,
|
||||||
@ -19,6 +20,8 @@ import { addVAEToGraph } from './addVAEToGraph';
|
|||||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||||
import {
|
import {
|
||||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||||
|
CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
CANVAS_COHERENCE_MASK_EDGE,
|
||||||
CANVAS_COHERENCE_NOISE,
|
CANVAS_COHERENCE_NOISE,
|
||||||
CANVAS_COHERENCE_NOISE_INCREMENT,
|
CANVAS_COHERENCE_NOISE_INCREMENT,
|
||||||
CANVAS_OUTPUT,
|
CANVAS_OUTPUT,
|
||||||
@ -68,6 +71,7 @@ export const buildCanvasSDXLInpaintGraph = (
|
|||||||
shouldUseCpuNoise,
|
shouldUseCpuNoise,
|
||||||
maskBlur,
|
maskBlur,
|
||||||
maskBlurMethod,
|
maskBlurMethod,
|
||||||
|
canvasCoherenceMode,
|
||||||
canvasCoherenceSteps,
|
canvasCoherenceSteps,
|
||||||
canvasCoherenceStrength,
|
canvasCoherenceStrength,
|
||||||
seamlessXAxis,
|
seamlessXAxis,
|
||||||
@ -96,6 +100,12 @@ export const buildCanvasSDXLInpaintGraph = (
|
|||||||
shouldAutoSave,
|
shouldAutoSave,
|
||||||
} = state.canvas;
|
} = state.canvas;
|
||||||
|
|
||||||
|
const fp32 = vaePrecision === 'fp32';
|
||||||
|
|
||||||
|
const isUsingScaledDimensions = ['auto', 'manual'].includes(
|
||||||
|
boundingBoxScaleMethod
|
||||||
|
);
|
||||||
|
|
||||||
let modelLoaderNodeId = SDXL_MODEL_LOADER;
|
let modelLoaderNodeId = SDXL_MODEL_LOADER;
|
||||||
|
|
||||||
const use_cpu = shouldUseNoiseSettings
|
const use_cpu = shouldUseNoiseSettings
|
||||||
@ -137,7 +147,7 @@ export const buildCanvasSDXLInpaintGraph = (
|
|||||||
type: 'i2l',
|
type: 'i2l',
|
||||||
id: INPAINT_IMAGE,
|
id: INPAINT_IMAGE,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
},
|
},
|
||||||
[NOISE]: {
|
[NOISE]: {
|
||||||
type: 'noise',
|
type: 'noise',
|
||||||
@ -149,7 +159,7 @@ export const buildCanvasSDXLInpaintGraph = (
|
|||||||
type: 'create_denoise_mask',
|
type: 'create_denoise_mask',
|
||||||
id: INPAINT_CREATE_MASK,
|
id: INPAINT_CREATE_MASK,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
},
|
},
|
||||||
[SDXL_DENOISE_LATENTS]: {
|
[SDXL_DENOISE_LATENTS]: {
|
||||||
type: 'denoise_latents',
|
type: 'denoise_latents',
|
||||||
@ -177,7 +187,7 @@ export const buildCanvasSDXLInpaintGraph = (
|
|||||||
},
|
},
|
||||||
[CANVAS_COHERENCE_DENOISE_LATENTS]: {
|
[CANVAS_COHERENCE_DENOISE_LATENTS]: {
|
||||||
type: 'denoise_latents',
|
type: 'denoise_latents',
|
||||||
id: SDXL_DENOISE_LATENTS,
|
id: CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
steps: canvasCoherenceSteps,
|
steps: canvasCoherenceSteps,
|
||||||
cfg_scale: cfg_scale,
|
cfg_scale: cfg_scale,
|
||||||
@ -189,7 +199,7 @@ export const buildCanvasSDXLInpaintGraph = (
|
|||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
id: LATENTS_TO_IMAGE,
|
id: LATENTS_TO_IMAGE,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
},
|
},
|
||||||
[CANVAS_OUTPUT]: {
|
[CANVAS_OUTPUT]: {
|
||||||
type: 'color_correct',
|
type: 'color_correct',
|
||||||
@ -433,7 +443,7 @@ export const buildCanvasSDXLInpaintGraph = (
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Handle Scale Before Processing
|
// Handle Scale Before Processing
|
||||||
if (['auto', 'manual'].includes(boundingBoxScaleMethod)) {
|
if (isUsingScaledDimensions) {
|
||||||
const scaledWidth: number = scaledBoundingBoxDimensions.width;
|
const scaledWidth: number = scaledBoundingBoxDimensions.width;
|
||||||
const scaledHeight: number = scaledBoundingBoxDimensions.height;
|
const scaledHeight: number = scaledBoundingBoxDimensions.height;
|
||||||
|
|
||||||
@ -596,6 +606,116 @@ export const buildCanvasSDXLInpaintGraph = (
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle Coherence Mode
|
||||||
|
if (canvasCoherenceMode !== 'unmasked') {
|
||||||
|
// Create Mask If Coherence Mode Is Not Full
|
||||||
|
graph.nodes[CANVAS_COHERENCE_INPAINT_CREATE_MASK] = {
|
||||||
|
type: 'create_denoise_mask',
|
||||||
|
id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
is_intermediate: true,
|
||||||
|
fp32,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Handle Image Input For Mask Creation
|
||||||
|
if (isUsingScaledDimensions) {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: INPAINT_IMAGE_RESIZE_UP,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
graph.nodes[CANVAS_COHERENCE_INPAINT_CREATE_MASK] = {
|
||||||
|
...(graph.nodes[
|
||||||
|
CANVAS_COHERENCE_INPAINT_CREATE_MASK
|
||||||
|
] as CreateDenoiseMaskInvocation),
|
||||||
|
image: canvasInitImage,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create Mask If Coherence Mode Is Mask
|
||||||
|
if (canvasCoherenceMode === 'mask') {
|
||||||
|
if (isUsingScaledDimensions) {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: MASK_RESIZE_UP,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
field: 'mask',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
graph.nodes[CANVAS_COHERENCE_INPAINT_CREATE_MASK] = {
|
||||||
|
...(graph.nodes[
|
||||||
|
CANVAS_COHERENCE_INPAINT_CREATE_MASK
|
||||||
|
] as CreateDenoiseMaskInvocation),
|
||||||
|
mask: canvasMaskImage,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create Mask Edge If Coherence Mode Is Edge
|
||||||
|
if (canvasCoherenceMode === 'edge') {
|
||||||
|
graph.nodes[CANVAS_COHERENCE_MASK_EDGE] = {
|
||||||
|
type: 'mask_edge',
|
||||||
|
id: CANVAS_COHERENCE_MASK_EDGE,
|
||||||
|
is_intermediate: true,
|
||||||
|
edge_blur: maskBlur,
|
||||||
|
edge_size: maskBlur * 2,
|
||||||
|
low_threshold: 100,
|
||||||
|
high_threshold: 200,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Handle Scaled Dimensions For Mask Edge
|
||||||
|
if (isUsingScaledDimensions) {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: MASK_RESIZE_UP,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_MASK_EDGE,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
graph.nodes[CANVAS_COHERENCE_MASK_EDGE] = {
|
||||||
|
...(graph.nodes[CANVAS_COHERENCE_MASK_EDGE] as MaskEdgeInvocation),
|
||||||
|
image: canvasMaskImage,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: CANVAS_COHERENCE_MASK_EDGE,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
field: 'mask',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Plug Denoise Mask To Coherence Denoise Latents
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
field: 'denoise_mask',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||||
|
field: 'denoise_mask',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// Handle Seed
|
// Handle Seed
|
||||||
if (shouldRandomizeSeed) {
|
if (shouldRandomizeSeed) {
|
||||||
// Random int node to generate the starting seed
|
// Random int node to generate the starting seed
|
||||||
|
@ -2,7 +2,6 @@ import { logger } from 'app/logging/logger';
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
import {
|
import {
|
||||||
ImageBlurInvocation,
|
|
||||||
ImageDTO,
|
ImageDTO,
|
||||||
ImageToLatentsInvocation,
|
ImageToLatentsInvocation,
|
||||||
InfillPatchMatchInvocation,
|
InfillPatchMatchInvocation,
|
||||||
@ -20,6 +19,8 @@ import { addVAEToGraph } from './addVAEToGraph';
|
|||||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||||
import {
|
import {
|
||||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||||
|
CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
CANVAS_COHERENCE_MASK_EDGE,
|
||||||
CANVAS_COHERENCE_NOISE,
|
CANVAS_COHERENCE_NOISE,
|
||||||
CANVAS_COHERENCE_NOISE_INCREMENT,
|
CANVAS_COHERENCE_NOISE_INCREMENT,
|
||||||
CANVAS_OUTPUT,
|
CANVAS_OUTPUT,
|
||||||
@ -31,7 +32,6 @@ import {
|
|||||||
INPAINT_INFILL_RESIZE_DOWN,
|
INPAINT_INFILL_RESIZE_DOWN,
|
||||||
ITERATE,
|
ITERATE,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
MASK_BLUR,
|
|
||||||
MASK_COMBINE,
|
MASK_COMBINE,
|
||||||
MASK_FROM_ALPHA,
|
MASK_FROM_ALPHA,
|
||||||
MASK_RESIZE_DOWN,
|
MASK_RESIZE_DOWN,
|
||||||
@ -72,10 +72,11 @@ export const buildCanvasSDXLOutpaintGraph = (
|
|||||||
shouldUseNoiseSettings,
|
shouldUseNoiseSettings,
|
||||||
shouldUseCpuNoise,
|
shouldUseCpuNoise,
|
||||||
maskBlur,
|
maskBlur,
|
||||||
maskBlurMethod,
|
canvasCoherenceMode,
|
||||||
canvasCoherenceSteps,
|
canvasCoherenceSteps,
|
||||||
canvasCoherenceStrength,
|
canvasCoherenceStrength,
|
||||||
tileSize,
|
infillTileSize,
|
||||||
|
infillPatchmatchDownscaleSize,
|
||||||
infillMethod,
|
infillMethod,
|
||||||
seamlessXAxis,
|
seamlessXAxis,
|
||||||
seamlessYAxis,
|
seamlessYAxis,
|
||||||
@ -103,6 +104,12 @@ export const buildCanvasSDXLOutpaintGraph = (
|
|||||||
shouldAutoSave,
|
shouldAutoSave,
|
||||||
} = state.canvas;
|
} = state.canvas;
|
||||||
|
|
||||||
|
const fp32 = vaePrecision === 'fp32';
|
||||||
|
|
||||||
|
const isUsingScaledDimensions = ['auto', 'manual'].includes(
|
||||||
|
boundingBoxScaleMethod
|
||||||
|
);
|
||||||
|
|
||||||
let modelLoaderNodeId = SDXL_MODEL_LOADER;
|
let modelLoaderNodeId = SDXL_MODEL_LOADER;
|
||||||
|
|
||||||
const use_cpu = shouldUseNoiseSettings
|
const use_cpu = shouldUseNoiseSettings
|
||||||
@ -145,18 +152,11 @@ export const buildCanvasSDXLOutpaintGraph = (
|
|||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
mask2: canvasMaskImage,
|
mask2: canvasMaskImage,
|
||||||
},
|
},
|
||||||
[MASK_BLUR]: {
|
|
||||||
type: 'img_blur',
|
|
||||||
id: MASK_BLUR,
|
|
||||||
is_intermediate: true,
|
|
||||||
radius: maskBlur,
|
|
||||||
blur_type: maskBlurMethod,
|
|
||||||
},
|
|
||||||
[INPAINT_IMAGE]: {
|
[INPAINT_IMAGE]: {
|
||||||
type: 'i2l',
|
type: 'i2l',
|
||||||
id: INPAINT_IMAGE,
|
id: INPAINT_IMAGE,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
},
|
},
|
||||||
[NOISE]: {
|
[NOISE]: {
|
||||||
type: 'noise',
|
type: 'noise',
|
||||||
@ -168,7 +168,7 @@ export const buildCanvasSDXLOutpaintGraph = (
|
|||||||
type: 'create_denoise_mask',
|
type: 'create_denoise_mask',
|
||||||
id: INPAINT_CREATE_MASK,
|
id: INPAINT_CREATE_MASK,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
},
|
},
|
||||||
[SDXL_DENOISE_LATENTS]: {
|
[SDXL_DENOISE_LATENTS]: {
|
||||||
type: 'denoise_latents',
|
type: 'denoise_latents',
|
||||||
@ -208,7 +208,7 @@ export const buildCanvasSDXLOutpaintGraph = (
|
|||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
id: LATENTS_TO_IMAGE,
|
id: LATENTS_TO_IMAGE,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
},
|
},
|
||||||
[CANVAS_OUTPUT]: {
|
[CANVAS_OUTPUT]: {
|
||||||
type: 'color_correct',
|
type: 'color_correct',
|
||||||
@ -348,7 +348,7 @@ export const buildCanvasSDXLOutpaintGraph = (
|
|||||||
// Create Inpaint Mask
|
// Create Inpaint Mask
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: MASK_BLUR,
|
node_id: isUsingScaledDimensions ? MASK_RESIZE_UP : MASK_COMBINE,
|
||||||
field: 'image',
|
field: 'image',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -410,7 +410,7 @@ export const buildCanvasSDXLOutpaintGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: SDXL_MODEL_LOADER,
|
node_id: modelLoaderNodeId,
|
||||||
field: 'unet',
|
field: 'unet',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -458,6 +458,16 @@ export const buildCanvasSDXLOutpaintGraph = (
|
|||||||
field: 'latents',
|
field: 'latents',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: INPAINT_INFILL,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: INPAINT_CREATE_MASK,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
},
|
||||||
// Decode inpainted latents to image
|
// Decode inpainted latents to image
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
@ -473,12 +483,12 @@ export const buildCanvasSDXLOutpaintGraph = (
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Add Infill Nodes
|
// Add Infill Nodes
|
||||||
|
|
||||||
if (infillMethod === 'patchmatch') {
|
if (infillMethod === 'patchmatch') {
|
||||||
graph.nodes[INPAINT_INFILL] = {
|
graph.nodes[INPAINT_INFILL] = {
|
||||||
type: 'infill_patchmatch',
|
type: 'infill_patchmatch',
|
||||||
id: INPAINT_INFILL,
|
id: INPAINT_INFILL,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
|
downscale: infillPatchmatchDownscaleSize,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -490,17 +500,25 @@ export const buildCanvasSDXLOutpaintGraph = (
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (infillMethod === 'cv2') {
|
||||||
|
graph.nodes[INPAINT_INFILL] = {
|
||||||
|
type: 'infill_cv2',
|
||||||
|
id: INPAINT_INFILL,
|
||||||
|
is_intermediate: true,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
if (infillMethod === 'tile') {
|
if (infillMethod === 'tile') {
|
||||||
graph.nodes[INPAINT_INFILL] = {
|
graph.nodes[INPAINT_INFILL] = {
|
||||||
type: 'infill_tile',
|
type: 'infill_tile',
|
||||||
id: INPAINT_INFILL,
|
id: INPAINT_INFILL,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
tile_size: tileSize,
|
tile_size: infillTileSize,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle Scale Before Processing
|
// Handle Scale Before Processing
|
||||||
if (['auto', 'manual'].includes(boundingBoxScaleMethod)) {
|
if (isUsingScaledDimensions) {
|
||||||
const scaledWidth: number = scaledBoundingBoxDimensions.width;
|
const scaledWidth: number = scaledBoundingBoxDimensions.width;
|
||||||
const scaledHeight: number = scaledBoundingBoxDimensions.height;
|
const scaledHeight: number = scaledBoundingBoxDimensions.height;
|
||||||
|
|
||||||
@ -562,16 +580,7 @@ export const buildCanvasSDXLOutpaintGraph = (
|
|||||||
field: 'image',
|
field: 'image',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
source: {
|
|
||||||
node_id: INPAINT_INFILL,
|
|
||||||
field: 'image',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: INPAINT_CREATE_MASK,
|
|
||||||
field: 'image',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
// Take combined mask and resize and then blur
|
// Take combined mask and resize and then blur
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
@ -583,16 +592,7 @@ export const buildCanvasSDXLOutpaintGraph = (
|
|||||||
field: 'image',
|
field: 'image',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
source: {
|
|
||||||
node_id: MASK_RESIZE_UP,
|
|
||||||
field: 'image',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: MASK_BLUR,
|
|
||||||
field: 'image',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
// Resize Results Down
|
// Resize Results Down
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
@ -674,32 +674,8 @@ export const buildCanvasSDXLOutpaintGraph = (
|
|||||||
...(graph.nodes[INPAINT_IMAGE] as ImageToLatentsInvocation),
|
...(graph.nodes[INPAINT_IMAGE] as ImageToLatentsInvocation),
|
||||||
image: canvasInitImage,
|
image: canvasInitImage,
|
||||||
};
|
};
|
||||||
graph.nodes[MASK_BLUR] = {
|
|
||||||
...(graph.nodes[MASK_BLUR] as ImageBlurInvocation),
|
|
||||||
};
|
|
||||||
|
|
||||||
graph.edges.push(
|
graph.edges.push(
|
||||||
// Take combined mask and plug it to blur
|
|
||||||
{
|
|
||||||
source: {
|
|
||||||
node_id: MASK_COMBINE,
|
|
||||||
field: 'image',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: MASK_BLUR,
|
|
||||||
field: 'image',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
source: {
|
|
||||||
node_id: INPAINT_INFILL,
|
|
||||||
field: 'image',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: INPAINT_CREATE_MASK,
|
|
||||||
field: 'image',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
// Color Correct The Inpainted Result
|
// Color Correct The Inpainted Result
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
@ -723,7 +699,7 @@ export const buildCanvasSDXLOutpaintGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: MASK_BLUR,
|
node_id: MASK_COMBINE,
|
||||||
field: 'image',
|
field: 'image',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -734,7 +710,116 @@ export const buildCanvasSDXLOutpaintGraph = (
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle seed
|
// Handle Coherence Mode
|
||||||
|
if (canvasCoherenceMode !== 'unmasked') {
|
||||||
|
graph.nodes[CANVAS_COHERENCE_INPAINT_CREATE_MASK] = {
|
||||||
|
type: 'create_denoise_mask',
|
||||||
|
id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
is_intermediate: true,
|
||||||
|
fp32,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Handle Image Input For Mask Creation
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: INPAINT_INFILL,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Create Mask If Coherence Mode Is Mask
|
||||||
|
if (canvasCoherenceMode === 'mask') {
|
||||||
|
if (isUsingScaledDimensions) {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: MASK_RESIZE_UP,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
field: 'mask',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: MASK_COMBINE,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
field: 'mask',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (canvasCoherenceMode === 'edge') {
|
||||||
|
graph.nodes[CANVAS_COHERENCE_MASK_EDGE] = {
|
||||||
|
type: 'mask_edge',
|
||||||
|
id: CANVAS_COHERENCE_MASK_EDGE,
|
||||||
|
is_intermediate: true,
|
||||||
|
edge_blur: maskBlur,
|
||||||
|
edge_size: maskBlur * 2,
|
||||||
|
low_threshold: 100,
|
||||||
|
high_threshold: 200,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Handle Scaled Dimensions For Mask Edge
|
||||||
|
if (isUsingScaledDimensions) {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: MASK_RESIZE_UP,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_MASK_EDGE,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: MASK_COMBINE,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_MASK_EDGE,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: CANVAS_COHERENCE_MASK_EDGE,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
field: 'mask',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Plug Denoise Mask To Coherence Denoise Latents
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
|
field: 'denoise_mask',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||||
|
field: 'denoise_mask',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle Seed
|
||||||
if (shouldRandomizeSeed) {
|
if (shouldRandomizeSeed) {
|
||||||
// Random int node to generate the starting seed
|
// Random int node to generate the starting seed
|
||||||
const randomIntNode: RandomIntInvocation = {
|
const randomIntNode: RandomIntInvocation = {
|
||||||
|
@ -61,6 +61,8 @@ export const buildCanvasSDXLTextToImageGraph = (
|
|||||||
shouldAutoSave,
|
shouldAutoSave,
|
||||||
} = state.canvas;
|
} = state.canvas;
|
||||||
|
|
||||||
|
const fp32 = vaePrecision === 'fp32';
|
||||||
|
|
||||||
const isUsingScaledDimensions = ['auto', 'manual'].includes(
|
const isUsingScaledDimensions = ['auto', 'manual'].includes(
|
||||||
boundingBoxScaleMethod
|
boundingBoxScaleMethod
|
||||||
);
|
);
|
||||||
@ -252,7 +254,7 @@ export const buildCanvasSDXLTextToImageGraph = (
|
|||||||
id: LATENTS_TO_IMAGE,
|
id: LATENTS_TO_IMAGE,
|
||||||
type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i',
|
type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i',
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
};
|
};
|
||||||
|
|
||||||
graph.nodes[CANVAS_OUTPUT] = {
|
graph.nodes[CANVAS_OUTPUT] = {
|
||||||
@ -290,7 +292,7 @@ export const buildCanvasSDXLTextToImageGraph = (
|
|||||||
type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i',
|
type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i',
|
||||||
id: CANVAS_OUTPUT,
|
id: CANVAS_OUTPUT,
|
||||||
is_intermediate: !shouldAutoSave,
|
is_intermediate: !shouldAutoSave,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
};
|
};
|
||||||
|
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
|
@ -59,6 +59,8 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
shouldAutoSave,
|
shouldAutoSave,
|
||||||
} = state.canvas;
|
} = state.canvas;
|
||||||
|
|
||||||
|
const fp32 = vaePrecision === 'fp32';
|
||||||
|
|
||||||
const isUsingScaledDimensions = ['auto', 'manual'].includes(
|
const isUsingScaledDimensions = ['auto', 'manual'].includes(
|
||||||
boundingBoxScaleMethod
|
boundingBoxScaleMethod
|
||||||
);
|
);
|
||||||
@ -238,7 +240,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
id: LATENTS_TO_IMAGE,
|
id: LATENTS_TO_IMAGE,
|
||||||
type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i',
|
type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i',
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
};
|
};
|
||||||
|
|
||||||
graph.nodes[CANVAS_OUTPUT] = {
|
graph.nodes[CANVAS_OUTPUT] = {
|
||||||
@ -276,7 +278,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i',
|
type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i',
|
||||||
id: CANVAS_OUTPUT,
|
id: CANVAS_OUTPUT,
|
||||||
is_intermediate: !shouldAutoSave,
|
is_intermediate: !shouldAutoSave,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
};
|
};
|
||||||
|
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
|
@ -84,6 +84,8 @@ export const buildLinearImageToImageGraph = (
|
|||||||
throw new Error('No model found in state');
|
throw new Error('No model found in state');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const fp32 = vaePrecision === 'fp32';
|
||||||
|
|
||||||
let modelLoaderNodeId = MAIN_MODEL_LOADER;
|
let modelLoaderNodeId = MAIN_MODEL_LOADER;
|
||||||
|
|
||||||
const use_cpu = shouldUseNoiseSettings
|
const use_cpu = shouldUseNoiseSettings
|
||||||
@ -122,7 +124,7 @@ export const buildLinearImageToImageGraph = (
|
|||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
id: LATENTS_TO_IMAGE,
|
id: LATENTS_TO_IMAGE,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
},
|
},
|
||||||
[DENOISE_LATENTS]: {
|
[DENOISE_LATENTS]: {
|
||||||
type: 'denoise_latents',
|
type: 'denoise_latents',
|
||||||
@ -140,7 +142,7 @@ export const buildLinearImageToImageGraph = (
|
|||||||
// image: {
|
// image: {
|
||||||
// image_name: initialImage.image_name,
|
// image_name: initialImage.image_name,
|
||||||
// },
|
// },
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
edges: [
|
edges: [
|
||||||
|
@ -84,6 +84,8 @@ export const buildLinearSDXLImageToImageGraph = (
|
|||||||
throw new Error('No model found in state');
|
throw new Error('No model found in state');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const fp32 = vaePrecision === 'fp32';
|
||||||
|
|
||||||
// Model Loader ID
|
// Model Loader ID
|
||||||
let modelLoaderNodeId = SDXL_MODEL_LOADER;
|
let modelLoaderNodeId = SDXL_MODEL_LOADER;
|
||||||
|
|
||||||
@ -124,7 +126,7 @@ export const buildLinearSDXLImageToImageGraph = (
|
|||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
id: LATENTS_TO_IMAGE,
|
id: LATENTS_TO_IMAGE,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
},
|
},
|
||||||
[SDXL_DENOISE_LATENTS]: {
|
[SDXL_DENOISE_LATENTS]: {
|
||||||
type: 'denoise_latents',
|
type: 'denoise_latents',
|
||||||
@ -144,7 +146,7 @@ export const buildLinearSDXLImageToImageGraph = (
|
|||||||
// image: {
|
// image: {
|
||||||
// image_name: initialImage.image_name,
|
// image_name: initialImage.image_name,
|
||||||
// },
|
// },
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
edges: [
|
edges: [
|
||||||
|
@ -62,6 +62,8 @@ export const buildLinearSDXLTextToImageGraph = (
|
|||||||
throw new Error('No model found in state');
|
throw new Error('No model found in state');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const fp32 = vaePrecision === 'fp32';
|
||||||
|
|
||||||
// Construct Style Prompt
|
// Construct Style Prompt
|
||||||
const { craftedPositiveStylePrompt, craftedNegativeStylePrompt } =
|
const { craftedPositiveStylePrompt, craftedNegativeStylePrompt } =
|
||||||
craftSDXLStylePrompt(state, shouldConcatSDXLStylePrompt);
|
craftSDXLStylePrompt(state, shouldConcatSDXLStylePrompt);
|
||||||
@ -118,7 +120,7 @@ export const buildLinearSDXLTextToImageGraph = (
|
|||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
id: LATENTS_TO_IMAGE,
|
id: LATENTS_TO_IMAGE,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
edges: [
|
edges: [
|
||||||
|
@ -57,6 +57,8 @@ export const buildLinearTextToImageGraph = (
|
|||||||
throw new Error('No model found in state');
|
throw new Error('No model found in state');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const fp32 = vaePrecision === 'fp32';
|
||||||
|
|
||||||
const isUsingOnnxModel = model.model_type === 'onnx';
|
const isUsingOnnxModel = model.model_type === 'onnx';
|
||||||
|
|
||||||
let modelLoaderNodeId = isUsingOnnxModel
|
let modelLoaderNodeId = isUsingOnnxModel
|
||||||
@ -139,7 +141,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i',
|
type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i',
|
||||||
id: LATENTS_TO_IMAGE,
|
id: LATENTS_TO_IMAGE,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
edges: [
|
edges: [
|
||||||
|
@ -27,11 +27,15 @@ export const INPAINT_INFILL = 'inpaint_infill';
|
|||||||
export const INPAINT_INFILL_RESIZE_DOWN = 'inpaint_infill_resize_down';
|
export const INPAINT_INFILL_RESIZE_DOWN = 'inpaint_infill_resize_down';
|
||||||
export const INPAINT_FINAL_IMAGE = 'inpaint_final_image';
|
export const INPAINT_FINAL_IMAGE = 'inpaint_final_image';
|
||||||
export const INPAINT_CREATE_MASK = 'inpaint_create_mask';
|
export const INPAINT_CREATE_MASK = 'inpaint_create_mask';
|
||||||
|
export const INPAINT_MASK = 'inpaint_mask';
|
||||||
export const CANVAS_COHERENCE_DENOISE_LATENTS =
|
export const CANVAS_COHERENCE_DENOISE_LATENTS =
|
||||||
'canvas_coherence_denoise_latents';
|
'canvas_coherence_denoise_latents';
|
||||||
export const CANVAS_COHERENCE_NOISE = 'canvas_coherence_noise';
|
export const CANVAS_COHERENCE_NOISE = 'canvas_coherence_noise';
|
||||||
export const CANVAS_COHERENCE_NOISE_INCREMENT =
|
export const CANVAS_COHERENCE_NOISE_INCREMENT =
|
||||||
'canvas_coherence_noise_increment';
|
'canvas_coherence_noise_increment';
|
||||||
|
export const CANVAS_COHERENCE_MASK_EDGE = 'canvas_coherence_mask_edge';
|
||||||
|
export const CANVAS_COHERENCE_INPAINT_CREATE_MASK =
|
||||||
|
'canvas_coherence_inpaint_create_mask';
|
||||||
export const MASK_FROM_ALPHA = 'tomask';
|
export const MASK_FROM_ALPHA = 'tomask';
|
||||||
export const MASK_EDGE = 'mask_edge';
|
export const MASK_EDGE = 'mask_edge';
|
||||||
export const MASK_BLUR = 'mask_blur';
|
export const MASK_BLUR = 'mask_blur';
|
||||||
|
@ -73,6 +73,7 @@ export const parseSchema = (
|
|||||||
const title = schema.title.replace('Invocation', '');
|
const title = schema.title.replace('Invocation', '');
|
||||||
const tags = schema.tags ?? [];
|
const tags = schema.tags ?? [];
|
||||||
const description = schema.description ?? '';
|
const description = schema.description ?? '';
|
||||||
|
const version = schema.version ?? '';
|
||||||
|
|
||||||
const inputs = reduce(
|
const inputs = reduce(
|
||||||
schema.properties,
|
schema.properties,
|
||||||
@ -225,11 +226,12 @@ export const parseSchema = (
|
|||||||
const invocation: InvocationTemplate = {
|
const invocation: InvocationTemplate = {
|
||||||
title,
|
title,
|
||||||
type,
|
type,
|
||||||
|
version,
|
||||||
tags,
|
tags,
|
||||||
description,
|
description,
|
||||||
|
outputType,
|
||||||
inputs,
|
inputs,
|
||||||
outputs,
|
outputs,
|
||||||
outputType,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
Object.assign(invocationsAccumulator, { [type]: invocation });
|
Object.assign(invocationsAccumulator, { [type]: invocation });
|
||||||
|
@ -0,0 +1,96 @@
|
|||||||
|
import { compareVersions } from 'compare-versions';
|
||||||
|
import { cloneDeep, keyBy } from 'lodash-es';
|
||||||
|
import {
|
||||||
|
InvocationTemplate,
|
||||||
|
Workflow,
|
||||||
|
WorkflowWarning,
|
||||||
|
isWorkflowInvocationNode,
|
||||||
|
} from '../types/types';
|
||||||
|
import { parseify } from 'common/util/serialize';
|
||||||
|
|
||||||
|
export const validateWorkflow = (
|
||||||
|
workflow: Workflow,
|
||||||
|
nodeTemplates: Record<string, InvocationTemplate>
|
||||||
|
) => {
|
||||||
|
const clone = cloneDeep(workflow);
|
||||||
|
const { nodes, edges } = clone;
|
||||||
|
const errors: WorkflowWarning[] = [];
|
||||||
|
const invocationNodes = nodes.filter(isWorkflowInvocationNode);
|
||||||
|
const keyedNodes = keyBy(invocationNodes, 'id');
|
||||||
|
nodes.forEach((node) => {
|
||||||
|
if (!isWorkflowInvocationNode(node)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const nodeTemplate = nodeTemplates[node.data.type];
|
||||||
|
if (!nodeTemplate) {
|
||||||
|
errors.push({
|
||||||
|
message: `Node "${node.data.type}" skipped`,
|
||||||
|
issues: [`Node type "${node.data.type}" does not exist`],
|
||||||
|
data: node,
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
nodeTemplate.version &&
|
||||||
|
node.data.version &&
|
||||||
|
compareVersions(nodeTemplate.version, node.data.version) !== 0
|
||||||
|
) {
|
||||||
|
errors.push({
|
||||||
|
message: `Node "${node.data.type}" has mismatched version`,
|
||||||
|
issues: [
|
||||||
|
`Node "${node.data.type}" v${node.data.version} may be incompatible with installed v${nodeTemplate.version}`,
|
||||||
|
],
|
||||||
|
data: { node, nodeTemplate: parseify(nodeTemplate) },
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
edges.forEach((edge, i) => {
|
||||||
|
const sourceNode = keyedNodes[edge.source];
|
||||||
|
const targetNode = keyedNodes[edge.target];
|
||||||
|
const issues: string[] = [];
|
||||||
|
if (!sourceNode) {
|
||||||
|
issues.push(`Output node ${edge.source} does not exist`);
|
||||||
|
} else if (
|
||||||
|
edge.type === 'default' &&
|
||||||
|
!(edge.sourceHandle in sourceNode.data.outputs)
|
||||||
|
) {
|
||||||
|
issues.push(
|
||||||
|
`Output field "${edge.source}.${edge.sourceHandle}" does not exist`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (!targetNode) {
|
||||||
|
issues.push(`Input node ${edge.target} does not exist`);
|
||||||
|
} else if (
|
||||||
|
edge.type === 'default' &&
|
||||||
|
!(edge.targetHandle in targetNode.data.inputs)
|
||||||
|
) {
|
||||||
|
issues.push(
|
||||||
|
`Input field "${edge.target}.${edge.targetHandle}" does not exist`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (!nodeTemplates[sourceNode?.data.type ?? '__UNKNOWN_NODE_TYPE__']) {
|
||||||
|
issues.push(
|
||||||
|
`Source node "${edge.source}" missing template "${sourceNode?.data.type}"`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (!nodeTemplates[targetNode?.data.type ?? '__UNKNOWN_NODE_TYPE__']) {
|
||||||
|
issues.push(
|
||||||
|
`Source node "${edge.target}" missing template "${targetNode?.data.type}"`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (issues.length) {
|
||||||
|
delete edges[i];
|
||||||
|
const src = edge.type === 'default' ? edge.sourceHandle : edge.source;
|
||||||
|
const tgt = edge.type === 'default' ? edge.targetHandle : edge.target;
|
||||||
|
errors.push({
|
||||||
|
message: `Edge "${src} -> ${tgt}" skipped`,
|
||||||
|
issues,
|
||||||
|
data: edge,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return { workflow: clone, errors };
|
||||||
|
};
|
@ -0,0 +1,42 @@
|
|||||||
|
import type { RootState } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { IAISelectDataType } from 'common/components/IAIMantineSearchableSelect';
|
||||||
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
|
import { setCanvasCoherenceMode } from 'features/parameters/store/generationSlice';
|
||||||
|
import { CanvasCoherenceModeParam } from 'features/parameters/types/parameterSchemas';
|
||||||
|
|
||||||
|
import { memo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
const coherenceModeSelectData: IAISelectDataType[] = [
|
||||||
|
{ label: 'Unmasked', value: 'unmasked' },
|
||||||
|
{ label: 'Mask', value: 'mask' },
|
||||||
|
{ label: 'Mask Edge', value: 'edge' },
|
||||||
|
];
|
||||||
|
|
||||||
|
const ParamCanvasCoherenceMode = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const canvasCoherenceMode = useAppSelector(
|
||||||
|
(state: RootState) => state.generation.canvasCoherenceMode
|
||||||
|
);
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const handleCoherenceModeChange = (v: string | null) => {
|
||||||
|
if (!v) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(setCanvasCoherenceMode(v as CanvasCoherenceModeParam));
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAIMantineSelect
|
||||||
|
label={t('parameters.coherenceMode')}
|
||||||
|
data={coherenceModeSelectData}
|
||||||
|
value={canvasCoherenceMode}
|
||||||
|
onChange={handleCoherenceModeChange}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamCanvasCoherenceMode);
|
@ -3,6 +3,7 @@ import IAICollapse from 'common/components/IAICollapse';
|
|||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import SubParametersWrapper from '../../SubParametersWrapper';
|
import SubParametersWrapper from '../../SubParametersWrapper';
|
||||||
|
import ParamCanvasCoherenceMode from './CoherencePass/ParamCanvasCoherenceMode';
|
||||||
import ParamCanvasCoherenceSteps from './CoherencePass/ParamCanvasCoherenceSteps';
|
import ParamCanvasCoherenceSteps from './CoherencePass/ParamCanvasCoherenceSteps';
|
||||||
import ParamCanvasCoherenceStrength from './CoherencePass/ParamCanvasCoherenceStrength';
|
import ParamCanvasCoherenceStrength from './CoherencePass/ParamCanvasCoherenceStrength';
|
||||||
import ParamMaskBlur from './MaskAdjustment/ParamMaskBlur';
|
import ParamMaskBlur from './MaskAdjustment/ParamMaskBlur';
|
||||||
@ -14,15 +15,16 @@ const ParamCompositingSettingsCollapse = () => {
|
|||||||
return (
|
return (
|
||||||
<IAICollapse label={t('parameters.compositingSettingsHeader')}>
|
<IAICollapse label={t('parameters.compositingSettingsHeader')}>
|
||||||
<Flex sx={{ flexDirection: 'column', gap: 2 }}>
|
<Flex sx={{ flexDirection: 'column', gap: 2 }}>
|
||||||
|
<SubParametersWrapper label={t('parameters.coherencePassHeader')}>
|
||||||
|
<ParamCanvasCoherenceMode />
|
||||||
|
<ParamCanvasCoherenceSteps />
|
||||||
|
<ParamCanvasCoherenceStrength />
|
||||||
|
</SubParametersWrapper>
|
||||||
|
<Divider />
|
||||||
<SubParametersWrapper label={t('parameters.maskAdjustmentsHeader')}>
|
<SubParametersWrapper label={t('parameters.maskAdjustmentsHeader')}>
|
||||||
<ParamMaskBlur />
|
<ParamMaskBlur />
|
||||||
<ParamMaskBlurMethod />
|
<ParamMaskBlurMethod />
|
||||||
</SubParametersWrapper>
|
</SubParametersWrapper>
|
||||||
<Divider />
|
|
||||||
<SubParametersWrapper label={t('parameters.coherencePassHeader')}>
|
|
||||||
<ParamCanvasCoherenceSteps />
|
|
||||||
<ParamCanvasCoherenceStrength />
|
|
||||||
</SubParametersWrapper>
|
|
||||||
</Flex>
|
</Flex>
|
||||||
</IAICollapse>
|
</IAICollapse>
|
||||||
);
|
);
|
||||||
|
@ -5,7 +5,7 @@ import { useTranslation } from 'react-i18next';
|
|||||||
import IAICollapse from 'common/components/IAICollapse';
|
import IAICollapse from 'common/components/IAICollapse';
|
||||||
import SubParametersWrapper from '../../SubParametersWrapper';
|
import SubParametersWrapper from '../../SubParametersWrapper';
|
||||||
import ParamInfillMethod from './ParamInfillMethod';
|
import ParamInfillMethod from './ParamInfillMethod';
|
||||||
import ParamInfillTilesize from './ParamInfillTilesize';
|
import ParamInfillOptions from './ParamInfillOptions';
|
||||||
import ParamScaleBeforeProcessing from './ParamScaleBeforeProcessing';
|
import ParamScaleBeforeProcessing from './ParamScaleBeforeProcessing';
|
||||||
import ParamScaledHeight from './ParamScaledHeight';
|
import ParamScaledHeight from './ParamScaledHeight';
|
||||||
import ParamScaledWidth from './ParamScaledWidth';
|
import ParamScaledWidth from './ParamScaledWidth';
|
||||||
@ -18,7 +18,7 @@ const ParamInfillCollapse = () => {
|
|||||||
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
|
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
|
||||||
<SubParametersWrapper>
|
<SubParametersWrapper>
|
||||||
<ParamInfillMethod />
|
<ParamInfillMethod />
|
||||||
<ParamInfillTilesize />
|
<ParamInfillOptions />
|
||||||
</SubParametersWrapper>
|
</SubParametersWrapper>
|
||||||
<Divider />
|
<Divider />
|
||||||
<SubParametersWrapper>
|
<SubParametersWrapper>
|
||||||
|
@ -27,9 +27,7 @@ const ParamInfillMethod = () => {
|
|||||||
|
|
||||||
const { data: appConfigData, isLoading } = useGetAppConfigQuery();
|
const { data: appConfigData, isLoading } = useGetAppConfigQuery();
|
||||||
|
|
||||||
const infill_methods = appConfigData?.infill_methods.filter(
|
const infill_methods = appConfigData?.infill_methods;
|
||||||
(method) => method !== 'lama'
|
|
||||||
);
|
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
@ -0,0 +1,29 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||||
|
import ParamInfillPatchmatchDownscaleSize from './ParamInfillPatchmatchDownscaleSize';
|
||||||
|
import ParamInfillTilesize from './ParamInfillTilesize';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
[generationSelector],
|
||||||
|
(parameters) => {
|
||||||
|
const { infillMethod } = parameters;
|
||||||
|
|
||||||
|
return {
|
||||||
|
infillMethod,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
export default function ParamInfillOptions() {
|
||||||
|
const { infillMethod } = useAppSelector(selector);
|
||||||
|
return (
|
||||||
|
<Flex>
|
||||||
|
{infillMethod === 'tile' && <ParamInfillTilesize />}
|
||||||
|
{infillMethod === 'patchmatch' && <ParamInfillPatchmatchDownscaleSize />}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}
|
@ -0,0 +1,58 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||||
|
import { setInfillPatchmatchDownscaleSize } from 'features/parameters/store/generationSlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
[generationSelector],
|
||||||
|
(parameters) => {
|
||||||
|
const { infillPatchmatchDownscaleSize, infillMethod } = parameters;
|
||||||
|
|
||||||
|
return {
|
||||||
|
infillPatchmatchDownscaleSize,
|
||||||
|
infillMethod,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const ParamInfillPatchmatchDownscaleSize = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { infillPatchmatchDownscaleSize, infillMethod } =
|
||||||
|
useAppSelector(selector);
|
||||||
|
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const handleChange = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
dispatch(setInfillPatchmatchDownscaleSize(v));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleReset = useCallback(() => {
|
||||||
|
dispatch(setInfillPatchmatchDownscaleSize(2));
|
||||||
|
}, [dispatch]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAISlider
|
||||||
|
isDisabled={infillMethod !== 'patchmatch'}
|
||||||
|
label={t('parameters.patchmatchDownScaleSize')}
|
||||||
|
min={1}
|
||||||
|
max={10}
|
||||||
|
value={infillPatchmatchDownscaleSize}
|
||||||
|
onChange={handleChange}
|
||||||
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
withReset
|
||||||
|
handleReset={handleReset}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamInfillPatchmatchDownscaleSize);
|
@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||||
import { setTileSize } from 'features/parameters/store/generationSlice';
|
import { setInfillTileSize } from 'features/parameters/store/generationSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
|
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -11,10 +11,10 @@ import { useTranslation } from 'react-i18next';
|
|||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[generationSelector],
|
[generationSelector],
|
||||||
(parameters) => {
|
(parameters) => {
|
||||||
const { tileSize, infillMethod } = parameters;
|
const { infillTileSize, infillMethod } = parameters;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
tileSize,
|
infillTileSize,
|
||||||
infillMethod,
|
infillMethod,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
@ -23,19 +23,19 @@ const selector = createSelector(
|
|||||||
|
|
||||||
const ParamInfillTileSize = () => {
|
const ParamInfillTileSize = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { tileSize, infillMethod } = useAppSelector(selector);
|
const { infillTileSize, infillMethod } = useAppSelector(selector);
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const handleChange = useCallback(
|
const handleChange = useCallback(
|
||||||
(v: number) => {
|
(v: number) => {
|
||||||
dispatch(setTileSize(v));
|
dispatch(setInfillTileSize(v));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleReset = useCallback(() => {
|
const handleReset = useCallback(() => {
|
||||||
dispatch(setTileSize(32));
|
dispatch(setInfillTileSize(32));
|
||||||
}, [dispatch]);
|
}, [dispatch]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -45,7 +45,7 @@ const ParamInfillTileSize = () => {
|
|||||||
min={16}
|
min={16}
|
||||||
max={64}
|
max={64}
|
||||||
sliderNumberInputProps={{ max: 256 }}
|
sliderNumberInputProps={{ max: 256 }}
|
||||||
value={tileSize}
|
value={infillTileSize}
|
||||||
onChange={handleChange}
|
onChange={handleChange}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
|
@ -7,6 +7,7 @@ import { ImageDTO } from 'services/api/types';
|
|||||||
|
|
||||||
import { clipSkipMap } from '../types/constants';
|
import { clipSkipMap } from '../types/constants';
|
||||||
import {
|
import {
|
||||||
|
CanvasCoherenceModeParam,
|
||||||
CfgScaleParam,
|
CfgScaleParam,
|
||||||
HeightParam,
|
HeightParam,
|
||||||
MainModelParam,
|
MainModelParam,
|
||||||
@ -37,6 +38,7 @@ export interface GenerationState {
|
|||||||
scheduler: SchedulerParam;
|
scheduler: SchedulerParam;
|
||||||
maskBlur: number;
|
maskBlur: number;
|
||||||
maskBlurMethod: MaskBlurMethodParam;
|
maskBlurMethod: MaskBlurMethodParam;
|
||||||
|
canvasCoherenceMode: CanvasCoherenceModeParam;
|
||||||
canvasCoherenceSteps: number;
|
canvasCoherenceSteps: number;
|
||||||
canvasCoherenceStrength: StrengthParam;
|
canvasCoherenceStrength: StrengthParam;
|
||||||
seed: SeedParam;
|
seed: SeedParam;
|
||||||
@ -47,7 +49,8 @@ export interface GenerationState {
|
|||||||
shouldUseNoiseSettings: boolean;
|
shouldUseNoiseSettings: boolean;
|
||||||
steps: StepsParam;
|
steps: StepsParam;
|
||||||
threshold: number;
|
threshold: number;
|
||||||
tileSize: number;
|
infillTileSize: number;
|
||||||
|
infillPatchmatchDownscaleSize: number;
|
||||||
variationAmount: number;
|
variationAmount: number;
|
||||||
width: WidthParam;
|
width: WidthParam;
|
||||||
shouldUseSymmetry: boolean;
|
shouldUseSymmetry: boolean;
|
||||||
@ -77,6 +80,7 @@ export const initialGenerationState: GenerationState = {
|
|||||||
scheduler: 'euler',
|
scheduler: 'euler',
|
||||||
maskBlur: 16,
|
maskBlur: 16,
|
||||||
maskBlurMethod: 'box',
|
maskBlurMethod: 'box',
|
||||||
|
canvasCoherenceMode: 'edge',
|
||||||
canvasCoherenceSteps: 20,
|
canvasCoherenceSteps: 20,
|
||||||
canvasCoherenceStrength: 0.3,
|
canvasCoherenceStrength: 0.3,
|
||||||
seed: 0,
|
seed: 0,
|
||||||
@ -87,7 +91,8 @@ export const initialGenerationState: GenerationState = {
|
|||||||
shouldUseNoiseSettings: false,
|
shouldUseNoiseSettings: false,
|
||||||
steps: 50,
|
steps: 50,
|
||||||
threshold: 0,
|
threshold: 0,
|
||||||
tileSize: 32,
|
infillTileSize: 32,
|
||||||
|
infillPatchmatchDownscaleSize: 1,
|
||||||
variationAmount: 0.1,
|
variationAmount: 0.1,
|
||||||
width: 512,
|
width: 512,
|
||||||
shouldUseSymmetry: false,
|
shouldUseSymmetry: false,
|
||||||
@ -206,18 +211,30 @@ export const generationSlice = createSlice({
|
|||||||
setMaskBlurMethod: (state, action: PayloadAction<MaskBlurMethodParam>) => {
|
setMaskBlurMethod: (state, action: PayloadAction<MaskBlurMethodParam>) => {
|
||||||
state.maskBlurMethod = action.payload;
|
state.maskBlurMethod = action.payload;
|
||||||
},
|
},
|
||||||
|
setCanvasCoherenceMode: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<CanvasCoherenceModeParam>
|
||||||
|
) => {
|
||||||
|
state.canvasCoherenceMode = action.payload;
|
||||||
|
},
|
||||||
setCanvasCoherenceSteps: (state, action: PayloadAction<number>) => {
|
setCanvasCoherenceSteps: (state, action: PayloadAction<number>) => {
|
||||||
state.canvasCoherenceSteps = action.payload;
|
state.canvasCoherenceSteps = action.payload;
|
||||||
},
|
},
|
||||||
setCanvasCoherenceStrength: (state, action: PayloadAction<number>) => {
|
setCanvasCoherenceStrength: (state, action: PayloadAction<number>) => {
|
||||||
state.canvasCoherenceStrength = action.payload;
|
state.canvasCoherenceStrength = action.payload;
|
||||||
},
|
},
|
||||||
setTileSize: (state, action: PayloadAction<number>) => {
|
|
||||||
state.tileSize = action.payload;
|
|
||||||
},
|
|
||||||
setInfillMethod: (state, action: PayloadAction<string>) => {
|
setInfillMethod: (state, action: PayloadAction<string>) => {
|
||||||
state.infillMethod = action.payload;
|
state.infillMethod = action.payload;
|
||||||
},
|
},
|
||||||
|
setInfillTileSize: (state, action: PayloadAction<number>) => {
|
||||||
|
state.infillTileSize = action.payload;
|
||||||
|
},
|
||||||
|
setInfillPatchmatchDownscaleSize: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<number>
|
||||||
|
) => {
|
||||||
|
state.infillPatchmatchDownscaleSize = action.payload;
|
||||||
|
},
|
||||||
setShouldUseSymmetry: (state, action: PayloadAction<boolean>) => {
|
setShouldUseSymmetry: (state, action: PayloadAction<boolean>) => {
|
||||||
state.shouldUseSymmetry = action.payload;
|
state.shouldUseSymmetry = action.payload;
|
||||||
},
|
},
|
||||||
@ -323,6 +340,7 @@ export const {
|
|||||||
setScheduler,
|
setScheduler,
|
||||||
setMaskBlur,
|
setMaskBlur,
|
||||||
setMaskBlurMethod,
|
setMaskBlurMethod,
|
||||||
|
setCanvasCoherenceMode,
|
||||||
setCanvasCoherenceSteps,
|
setCanvasCoherenceSteps,
|
||||||
setCanvasCoherenceStrength,
|
setCanvasCoherenceStrength,
|
||||||
setSeed,
|
setSeed,
|
||||||
@ -332,7 +350,8 @@ export const {
|
|||||||
setShouldRandomizeSeed,
|
setShouldRandomizeSeed,
|
||||||
setSteps,
|
setSteps,
|
||||||
setThreshold,
|
setThreshold,
|
||||||
setTileSize,
|
setInfillTileSize,
|
||||||
|
setInfillPatchmatchDownscaleSize,
|
||||||
setVariationAmount,
|
setVariationAmount,
|
||||||
setShouldUseSymmetry,
|
setShouldUseSymmetry,
|
||||||
setHorizontalSymmetrySteps,
|
setHorizontalSymmetrySteps,
|
||||||
|
@ -418,6 +418,22 @@ export const isValidMaskBlurMethod = (
|
|||||||
val: unknown
|
val: unknown
|
||||||
): val is MaskBlurMethodParam => zMaskBlurMethod.safeParse(val).success;
|
): val is MaskBlurMethodParam => zMaskBlurMethod.safeParse(val).success;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Zod schema for a Canvas Coherence Mode method parameter
|
||||||
|
*/
|
||||||
|
export const zCanvasCoherenceMode = z.enum(['unmasked', 'mask', 'edge']);
|
||||||
|
/**
|
||||||
|
* Type alias for Canvas Coherence Mode parameter, inferred from its zod schema
|
||||||
|
*/
|
||||||
|
export type CanvasCoherenceModeParam = z.infer<typeof zCanvasCoherenceMode>;
|
||||||
|
/**
|
||||||
|
* Validates/type-guards a value as a mask blur method parameter
|
||||||
|
*/
|
||||||
|
export const isValidCoherenceModeParam = (
|
||||||
|
val: unknown
|
||||||
|
): val is CanvasCoherenceModeParam =>
|
||||||
|
zCanvasCoherenceMode.safeParse(val).success;
|
||||||
|
|
||||||
// /**
|
// /**
|
||||||
// * Zod schema for BaseModelType
|
// * Zod schema for BaseModelType
|
||||||
// */
|
// */
|
||||||
|
266
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
266
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
File diff suppressed because one or more lines are too long
@ -112,6 +112,7 @@ export type ImageScaleInvocation = s['ImageScaleInvocation'];
|
|||||||
export type InfillPatchMatchInvocation = s['InfillPatchMatchInvocation'];
|
export type InfillPatchMatchInvocation = s['InfillPatchMatchInvocation'];
|
||||||
export type InfillTileInvocation = s['InfillTileInvocation'];
|
export type InfillTileInvocation = s['InfillTileInvocation'];
|
||||||
export type CreateDenoiseMaskInvocation = s['CreateDenoiseMaskInvocation'];
|
export type CreateDenoiseMaskInvocation = s['CreateDenoiseMaskInvocation'];
|
||||||
|
export type MaskEdgeInvocation = s['MaskEdgeInvocation'];
|
||||||
export type RandomIntInvocation = s['RandomIntInvocation'];
|
export type RandomIntInvocation = s['RandomIntInvocation'];
|
||||||
export type CompelInvocation = s['CompelInvocation'];
|
export type CompelInvocation = s['CompelInvocation'];
|
||||||
export type DynamicPromptInvocation = s['DynamicPromptInvocation'];
|
export type DynamicPromptInvocation = s['DynamicPromptInvocation'];
|
||||||
|
@ -2970,6 +2970,11 @@ commondir@^1.0.1:
|
|||||||
resolved "https://registry.yarnpkg.com/commondir/-/commondir-1.0.1.tgz#ddd800da0c66127393cca5950ea968a3aaf1253b"
|
resolved "https://registry.yarnpkg.com/commondir/-/commondir-1.0.1.tgz#ddd800da0c66127393cca5950ea968a3aaf1253b"
|
||||||
integrity sha512-W9pAhw0ja1Edb5GVdIF1mjZw/ASI0AlShXM83UUGe2DVr5TdAPEA1OA8m/g8zWp9x6On7gqufY+FatDbC3MDQg==
|
integrity sha512-W9pAhw0ja1Edb5GVdIF1mjZw/ASI0AlShXM83UUGe2DVr5TdAPEA1OA8m/g8zWp9x6On7gqufY+FatDbC3MDQg==
|
||||||
|
|
||||||
|
compare-versions@^6.1.0:
|
||||||
|
version "6.1.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/compare-versions/-/compare-versions-6.1.0.tgz#3f2131e3ae93577df111dba133e6db876ffe127a"
|
||||||
|
integrity sha512-LNZQXhqUvqUTotpZ00qLSaify3b4VFD588aRr8MKFw4CMUr98ytzCW5wDH5qx/DEY5kCDXcbcRuCqL0szEf2tg==
|
||||||
|
|
||||||
compute-scroll-into-view@1.0.20:
|
compute-scroll-into-view@1.0.20:
|
||||||
version "1.0.20"
|
version "1.0.20"
|
||||||
resolved "https://registry.yarnpkg.com/compute-scroll-into-view/-/compute-scroll-into-view-1.0.20.tgz#1768b5522d1172754f5d0c9b02de3af6be506a43"
|
resolved "https://registry.yarnpkg.com/compute-scroll-into-view/-/compute-scroll-into-view-1.0.20.tgz#1768b5522d1172754f5d0c9b02de3af6be506a43"
|
||||||
|
@ -74,6 +74,7 @@ dependencies = [
|
|||||||
"rich~=13.3",
|
"rich~=13.3",
|
||||||
"safetensors==0.3.1",
|
"safetensors==0.3.1",
|
||||||
"scikit-image~=0.21.0",
|
"scikit-image~=0.21.0",
|
||||||
|
"semver~=3.0.1",
|
||||||
"send2trash",
|
"send2trash",
|
||||||
"test-tube~=0.7.5",
|
"test-tube~=0.7.5",
|
||||||
"torch~=2.0.1",
|
"torch~=2.0.1",
|
||||||
|
@ -1,3 +1,10 @@
|
|||||||
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
|
BaseInvocation,
|
||||||
|
BaseInvocationOutput,
|
||||||
|
InvalidVersionError,
|
||||||
|
invocation,
|
||||||
|
invocation_output,
|
||||||
|
)
|
||||||
from .test_nodes import (
|
from .test_nodes import (
|
||||||
ImageToImageTestInvocation,
|
ImageToImageTestInvocation,
|
||||||
TextToImageTestInvocation,
|
TextToImageTestInvocation,
|
||||||
@ -20,7 +27,7 @@ from invokeai.app.invocations.upscale import ESRGANInvocation
|
|||||||
|
|
||||||
from invokeai.app.invocations.image import ShowImageInvocation
|
from invokeai.app.invocations.image import ShowImageInvocation
|
||||||
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
||||||
from invokeai.app.invocations.primitives import IntegerInvocation
|
from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation
|
||||||
from invokeai.app.services.default_graphs import create_text_to_image
|
from invokeai.app.services.default_graphs import create_text_to_image
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -610,6 +617,79 @@ def test_graph_can_deserialize():
|
|||||||
assert g2.edges[0].destination.field == "image"
|
assert g2.edges[0].destination.field == "image"
|
||||||
|
|
||||||
|
|
||||||
|
def test_invocation_decorator():
|
||||||
|
invocation_type = "test_invocation"
|
||||||
|
title = "Test Invocation"
|
||||||
|
tags = ["first", "second", "third"]
|
||||||
|
category = "category"
|
||||||
|
version = "1.2.3"
|
||||||
|
|
||||||
|
@invocation(invocation_type, title=title, tags=tags, category=category, version=version)
|
||||||
|
class TestInvocation(BaseInvocation):
|
||||||
|
def invoke(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
schema = TestInvocation.schema()
|
||||||
|
|
||||||
|
assert schema.get("title") == title
|
||||||
|
assert schema.get("tags") == tags
|
||||||
|
assert schema.get("category") == category
|
||||||
|
assert schema.get("version") == version
|
||||||
|
assert TestInvocation(id="1").type == invocation_type # type: ignore (type is dynamically added)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invocation_version_must_be_semver():
|
||||||
|
invocation_type = "test_invocation"
|
||||||
|
valid_version = "1.0.0"
|
||||||
|
invalid_version = "not_semver"
|
||||||
|
|
||||||
|
@invocation(invocation_type, version=valid_version)
|
||||||
|
class ValidVersionInvocation(BaseInvocation):
|
||||||
|
def invoke(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
with pytest.raises(InvalidVersionError):
|
||||||
|
|
||||||
|
@invocation(invocation_type, version=invalid_version)
|
||||||
|
class InvalidVersionInvocation(BaseInvocation):
|
||||||
|
def invoke(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_invocation_output_decorator():
|
||||||
|
output_type = "test_output"
|
||||||
|
|
||||||
|
@invocation_output(output_type)
|
||||||
|
class TestOutput(BaseInvocationOutput):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert TestOutput().type == output_type # type: ignore (type is dynamically added)
|
||||||
|
|
||||||
|
|
||||||
|
def test_floats_accept_ints():
|
||||||
|
g = Graph()
|
||||||
|
n1 = IntegerInvocation(id="1", value=1)
|
||||||
|
n2 = FloatInvocation(id="2")
|
||||||
|
g.add_node(n1)
|
||||||
|
g.add_node(n2)
|
||||||
|
e = create_edge(n1.id, "value", n2.id, "value")
|
||||||
|
|
||||||
|
# Not throwing on this line is sufficient
|
||||||
|
g.add_edge(e)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ints_do_not_accept_floats():
|
||||||
|
g = Graph()
|
||||||
|
n1 = FloatInvocation(id="1", value=1.0)
|
||||||
|
n2 = IntegerInvocation(id="2")
|
||||||
|
g.add_node(n1)
|
||||||
|
g.add_node(n2)
|
||||||
|
e = create_edge(n1.id, "value", n2.id, "value")
|
||||||
|
|
||||||
|
with pytest.raises(InvalidEdgeError):
|
||||||
|
g.add_edge(e)
|
||||||
|
|
||||||
|
|
||||||
def test_graph_can_generate_schema():
|
def test_graph_can_generate_schema():
|
||||||
# Not throwing on this line is sufficient
|
# Not throwing on this line is sufficient
|
||||||
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
|
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
|
||||||
|
Loading…
x
Reference in New Issue
Block a user