mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: node editor
squashed rebase on main after backendd refactor
This commit is contained in:
parent
d6c9bf5b38
commit
f49fc7fb55
@ -38,7 +38,7 @@ import mimetypes
|
|||||||
from .api.dependencies import ApiDependencies
|
from .api.dependencies import ApiDependencies
|
||||||
from .api.routers import sessions, models, images, boards, board_images, app_info
|
from .api.routers import sessions, models, images, boards, board_images, app_info
|
||||||
from .api.sockets import SocketIO
|
from .api.sockets import SocketIO
|
||||||
from .invocations.baseinvocation import BaseInvocation
|
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -134,6 +134,11 @@ def custom_openapi():
|
|||||||
# This could break in some cases, figure out a better way to do it
|
# This could break in some cases, figure out a better way to do it
|
||||||
output_type_titles[schema_key] = output_schema["title"]
|
output_type_titles[schema_key] = output_schema["title"]
|
||||||
|
|
||||||
|
# Add Node Editor UI helper schemas
|
||||||
|
ui_config_schemas = schema([UIConfigBase, _InputField, _OutputField], ref_prefix="#/components/schemas/")
|
||||||
|
for schema_key, output_schema in ui_config_schemas["definitions"].items():
|
||||||
|
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
||||||
|
|
||||||
# Add a reference to the output type to additionalProperties of the invoker schema
|
# Add a reference to the output type to additionalProperties of the invoker schema
|
||||||
for invoker in all_invocations:
|
for invoker in all_invocations:
|
||||||
invoker_name = invoker.__name__
|
invoker_name = invoker.__name__
|
||||||
|
@ -3,15 +3,353 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import Enum
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args, get_type_hints
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
AbstractSet,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
ClassVar,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
get_args,
|
||||||
|
get_type_hints,
|
||||||
|
)
|
||||||
|
|
||||||
from pydantic import BaseConfig, BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from pydantic.fields import Undefined
|
||||||
|
from pydantic.typing import NoArgAnyCallable
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..services.invocation_services import InvocationServices
|
from ..services.invocation_services import InvocationServices
|
||||||
|
|
||||||
|
|
||||||
|
class FieldDescriptions:
|
||||||
|
denoising_start = "When to start denoising, expressed a percentage of total steps"
|
||||||
|
denoising_end = "When to stop denoising, expressed a percentage of total steps"
|
||||||
|
cfg_scale = "Classifier-Free Guidance scale"
|
||||||
|
scheduler = "Scheduler to use during inference"
|
||||||
|
positive_cond = "Positive conditioning tensor"
|
||||||
|
negative_cond = "Negative conditioning tensor"
|
||||||
|
noise = "Noise tensor"
|
||||||
|
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
||||||
|
unet = "UNet (scheduler, LoRAs)"
|
||||||
|
vae = "VAE"
|
||||||
|
cond = "Conditioning tensor"
|
||||||
|
controlnet_model = "ControlNet model to load"
|
||||||
|
vae_model = "VAE model to load"
|
||||||
|
lora_model = "LoRA model to load"
|
||||||
|
main_model = "Main model (UNet, VAE, CLIP) to load"
|
||||||
|
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
||||||
|
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
||||||
|
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||||
|
lora_weight = "The weight at which the LoRA is applied to each model"
|
||||||
|
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
|
||||||
|
raw_prompt = "Raw prompt text (no parsing)"
|
||||||
|
sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor"
|
||||||
|
skipped_layers = "Number of layers to skip in text encoder"
|
||||||
|
seed = "Seed for random number generation"
|
||||||
|
steps = "Number of steps to run"
|
||||||
|
width = "Width of output (px)"
|
||||||
|
height = "Height of output (px)"
|
||||||
|
control = "ControlNet(s) to apply"
|
||||||
|
denoised_latents = "Denoised latents tensor"
|
||||||
|
latents = "Latents tensor"
|
||||||
|
strength = "Strength of denoising (proportional to steps)"
|
||||||
|
core_metadata = "Optional core metadata to be written to image"
|
||||||
|
interp_mode = "Interpolation mode"
|
||||||
|
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
|
||||||
|
fp32 = "Whether or not to use full float32 precision"
|
||||||
|
precision = "Precision to use"
|
||||||
|
tiled = "Processing using overlapping tiles (reduce memory consumption)"
|
||||||
|
detect_res = "Pixel resolution for detection"
|
||||||
|
image_res = "Pixel resolution for output image"
|
||||||
|
safe_mode = "Whether or not to use safe mode"
|
||||||
|
scribble_mode = "Whether or not to use scribble mode"
|
||||||
|
scale_factor = "The factor by which to scale"
|
||||||
|
num_1 = "The first number"
|
||||||
|
num_2 = "The second number"
|
||||||
|
mask = "The mask to use for the operation"
|
||||||
|
|
||||||
|
|
||||||
|
class Input(str, Enum):
|
||||||
|
"""
|
||||||
|
The type of input a field accepts.
|
||||||
|
- `Input.Direct`: The field must have its value provided directly, when the invocation and field \
|
||||||
|
are instantiated.
|
||||||
|
- `Input.Connection`: The field must have its value provided by a connection.
|
||||||
|
- `Input.Any`: The field may have its value provided either directly or by a connection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
Connection = "connection"
|
||||||
|
Direct = "direct"
|
||||||
|
Any = "any"
|
||||||
|
|
||||||
|
|
||||||
|
class UITypeHint(str, Enum):
|
||||||
|
"""
|
||||||
|
Type hints for the UI.
|
||||||
|
If a field should be provided a data type that does not exactly match the python type of the field, \
|
||||||
|
use this to provide the type that should be used instead. See the node development docs for detail \
|
||||||
|
on adding a new field type, which involves client-side changes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
Integer = "integer"
|
||||||
|
Float = "float"
|
||||||
|
Boolean = "boolean"
|
||||||
|
String = "string"
|
||||||
|
Enum = "enum"
|
||||||
|
Array = "array"
|
||||||
|
ImageField = "ImageField"
|
||||||
|
LatentsField = "LatentsField"
|
||||||
|
ConditioningField = "ConditioningField"
|
||||||
|
ControlField = "ControlField"
|
||||||
|
MainModelField = "MainModelField"
|
||||||
|
SDXLMainModelField = "SDXLMainModelField"
|
||||||
|
SDXLRefinerModelField = "SDXLRefinerModelField"
|
||||||
|
ONNXModelField = "ONNXModelField"
|
||||||
|
VaeModelField = "VaeModelField"
|
||||||
|
LoRAModelField = "LoRAModelField"
|
||||||
|
ControlNetModelField = "ControlNetModelField"
|
||||||
|
UNetField = "UNetField"
|
||||||
|
VaeField = "VaeField"
|
||||||
|
ClipField = "ClipField"
|
||||||
|
ColorField = "ColorField"
|
||||||
|
ImageCollection = "ImageCollection"
|
||||||
|
IntegerCollection = "IntegerCollection"
|
||||||
|
FloatCollection = "FloatCollection"
|
||||||
|
StringCollection = "StringCollection"
|
||||||
|
BooleanCollection = "BooleanCollection"
|
||||||
|
Collection = "Collection"
|
||||||
|
CollectionItem = "CollectionItem"
|
||||||
|
Seed = "Seed"
|
||||||
|
FilePath = "FilePath"
|
||||||
|
|
||||||
|
|
||||||
|
class UIComponent(str, Enum):
|
||||||
|
"""
|
||||||
|
The type of UI component to use for a field, used to override the default components, which are \
|
||||||
|
inferred from the field type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
None_ = "none"
|
||||||
|
Textarea = "textarea"
|
||||||
|
Slider = "slider"
|
||||||
|
|
||||||
|
|
||||||
|
class _InputField(BaseModel):
|
||||||
|
"""
|
||||||
|
*DO NOT USE*
|
||||||
|
This helper class is used to tell the client about our custom field attributes via OpenAPI
|
||||||
|
schema generation, and Typescript type generation from that schema. It serves no functional
|
||||||
|
purpose in the backend.
|
||||||
|
"""
|
||||||
|
|
||||||
|
input: Input
|
||||||
|
ui_hidden: bool
|
||||||
|
ui_type_hint: Optional[UITypeHint]
|
||||||
|
ui_component: Optional[UIComponent]
|
||||||
|
|
||||||
|
|
||||||
|
class _OutputField(BaseModel):
|
||||||
|
"""
|
||||||
|
*DO NOT USE*
|
||||||
|
This helper class is used to tell the client about our custom field attributes via OpenAPI
|
||||||
|
schema generation, and Typescript type generation from that schema. It serves no functional
|
||||||
|
purpose in the backend.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ui_hidden: bool
|
||||||
|
ui_type_hint: Optional[UITypeHint]
|
||||||
|
|
||||||
|
|
||||||
|
def InputField(
|
||||||
|
*args: Any,
|
||||||
|
default: Any = Undefined,
|
||||||
|
default_factory: Optional[NoArgAnyCallable] = None,
|
||||||
|
alias: Optional[str] = None,
|
||||||
|
title: Optional[str] = None,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
||||||
|
include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
||||||
|
const: Optional[bool] = None,
|
||||||
|
gt: Optional[float] = None,
|
||||||
|
ge: Optional[float] = None,
|
||||||
|
lt: Optional[float] = None,
|
||||||
|
le: Optional[float] = None,
|
||||||
|
multiple_of: Optional[float] = None,
|
||||||
|
allow_inf_nan: Optional[bool] = None,
|
||||||
|
max_digits: Optional[int] = None,
|
||||||
|
decimal_places: Optional[int] = None,
|
||||||
|
min_items: Optional[int] = None,
|
||||||
|
max_items: Optional[int] = None,
|
||||||
|
unique_items: Optional[bool] = None,
|
||||||
|
min_length: Optional[int] = None,
|
||||||
|
max_length: Optional[int] = None,
|
||||||
|
allow_mutation: bool = True,
|
||||||
|
regex: Optional[str] = None,
|
||||||
|
discriminator: Optional[str] = None,
|
||||||
|
repr: bool = True,
|
||||||
|
input: Input = Input.Any,
|
||||||
|
ui_type_hint: Optional[UITypeHint] = None,
|
||||||
|
ui_component: Optional[UIComponent] = None,
|
||||||
|
ui_hidden: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Creates an input field for an invocation.
|
||||||
|
|
||||||
|
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
|
||||||
|
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||||
|
|
||||||
|
:param Input input: [Input.Any] The kind of input this field requires. \
|
||||||
|
`Input.Direct` means a value must be provided on instantiation. \
|
||||||
|
`Input.Connection` means the value must be provided by a connection. \
|
||||||
|
`Input.Any` means either will do.
|
||||||
|
|
||||||
|
:param UITypeHint ui_type_hint: [None] Optionally provides an extra type hint for the UI. \
|
||||||
|
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||||
|
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||||
|
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||||
|
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||||
|
`UITypeHint.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||||
|
|
||||||
|
:param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \
|
||||||
|
The UI will always render a suitable component, but sometimes you want something different than the default. \
|
||||||
|
For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \
|
||||||
|
For this case, you could provide `UIComponent.Textarea`.
|
||||||
|
|
||||||
|
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
|
||||||
|
"""
|
||||||
|
return Field(
|
||||||
|
*args,
|
||||||
|
default=default,
|
||||||
|
default_factory=default_factory,
|
||||||
|
alias=alias,
|
||||||
|
title=title,
|
||||||
|
description=description,
|
||||||
|
exclude=exclude,
|
||||||
|
include=include,
|
||||||
|
const=const,
|
||||||
|
gt=gt,
|
||||||
|
ge=ge,
|
||||||
|
lt=lt,
|
||||||
|
le=le,
|
||||||
|
multiple_of=multiple_of,
|
||||||
|
allow_inf_nan=allow_inf_nan,
|
||||||
|
max_digits=max_digits,
|
||||||
|
decimal_places=decimal_places,
|
||||||
|
min_items=min_items,
|
||||||
|
max_items=max_items,
|
||||||
|
unique_items=unique_items,
|
||||||
|
min_length=min_length,
|
||||||
|
max_length=max_length,
|
||||||
|
allow_mutation=allow_mutation,
|
||||||
|
regex=regex,
|
||||||
|
discriminator=discriminator,
|
||||||
|
repr=repr,
|
||||||
|
input=input,
|
||||||
|
ui_type_hint=ui_type_hint,
|
||||||
|
ui_component=ui_component,
|
||||||
|
ui_hidden=ui_hidden,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def OutputField(
|
||||||
|
*args: Any,
|
||||||
|
default: Any = Undefined,
|
||||||
|
default_factory: Optional[NoArgAnyCallable] = None,
|
||||||
|
alias: Optional[str] = None,
|
||||||
|
title: Optional[str] = None,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
||||||
|
include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
||||||
|
const: Optional[bool] = None,
|
||||||
|
gt: Optional[float] = None,
|
||||||
|
ge: Optional[float] = None,
|
||||||
|
lt: Optional[float] = None,
|
||||||
|
le: Optional[float] = None,
|
||||||
|
multiple_of: Optional[float] = None,
|
||||||
|
allow_inf_nan: Optional[bool] = None,
|
||||||
|
max_digits: Optional[int] = None,
|
||||||
|
decimal_places: Optional[int] = None,
|
||||||
|
min_items: Optional[int] = None,
|
||||||
|
max_items: Optional[int] = None,
|
||||||
|
unique_items: Optional[bool] = None,
|
||||||
|
min_length: Optional[int] = None,
|
||||||
|
max_length: Optional[int] = None,
|
||||||
|
allow_mutation: bool = True,
|
||||||
|
regex: Optional[str] = None,
|
||||||
|
discriminator: Optional[str] = None,
|
||||||
|
repr: bool = True,
|
||||||
|
ui_type_hint: Optional[UITypeHint] = None,
|
||||||
|
ui_hidden: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Creates an output field for an invocation output.
|
||||||
|
|
||||||
|
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
|
||||||
|
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||||
|
|
||||||
|
:param UITypeHint ui_type_hint: [None] Optionally provides an extra type hint for the UI. \
|
||||||
|
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||||
|
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||||
|
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||||
|
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||||
|
`UITypeHint.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||||
|
|
||||||
|
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
|
||||||
|
"""
|
||||||
|
return Field(
|
||||||
|
*args,
|
||||||
|
default=default,
|
||||||
|
default_factory=default_factory,
|
||||||
|
alias=alias,
|
||||||
|
title=title,
|
||||||
|
description=description,
|
||||||
|
exclude=exclude,
|
||||||
|
include=include,
|
||||||
|
const=const,
|
||||||
|
gt=gt,
|
||||||
|
ge=ge,
|
||||||
|
lt=lt,
|
||||||
|
le=le,
|
||||||
|
multiple_of=multiple_of,
|
||||||
|
allow_inf_nan=allow_inf_nan,
|
||||||
|
max_digits=max_digits,
|
||||||
|
decimal_places=decimal_places,
|
||||||
|
min_items=min_items,
|
||||||
|
max_items=max_items,
|
||||||
|
unique_items=unique_items,
|
||||||
|
min_length=min_length,
|
||||||
|
max_length=max_length,
|
||||||
|
allow_mutation=allow_mutation,
|
||||||
|
regex=regex,
|
||||||
|
discriminator=discriminator,
|
||||||
|
repr=repr,
|
||||||
|
ui_type_hint=ui_type_hint,
|
||||||
|
ui_hidden=ui_hidden,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UIConfigBase(BaseModel):
|
||||||
|
"""
|
||||||
|
Provides additional node configuration to the UI.
|
||||||
|
This is used internally by the @tags and @title decorator logic. You probably want to use those
|
||||||
|
decorators, though you may add this class to a node definition to specify the title and tags.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tags: Optional[list[str]] = Field(default_factory=None, description="The tags to display in the UI")
|
||||||
|
title: Optional[str] = Field(default=None, description="The display name of the node")
|
||||||
|
|
||||||
|
|
||||||
class InvocationContext:
|
class InvocationContext:
|
||||||
services: InvocationServices
|
services: InvocationServices
|
||||||
graph_execution_state_id: str
|
graph_execution_state_id: str
|
||||||
@ -39,6 +377,20 @@ class BaseInvocationOutput(BaseModel):
|
|||||||
return tuple(subclasses)
|
return tuple(subclasses)
|
||||||
|
|
||||||
|
|
||||||
|
class RequiredConnectionException(Exception):
|
||||||
|
"""Raised when an field which requires a connection did not receive a value."""
|
||||||
|
|
||||||
|
def __init__(self, node_id: str, field_name: str):
|
||||||
|
super().__init__(f"Node {node_id} missing connections for field {field_name}")
|
||||||
|
|
||||||
|
|
||||||
|
class MissingInputException(Exception):
|
||||||
|
"""Raised when an field which requires some input, but did not receive a value."""
|
||||||
|
|
||||||
|
def __init__(self, node_id: str, field_name: str):
|
||||||
|
super().__init__(f"Node {node_id} missing value or connection for field {field_name}")
|
||||||
|
|
||||||
|
|
||||||
class BaseInvocation(ABC, BaseModel):
|
class BaseInvocation(ABC, BaseModel):
|
||||||
"""A node to process inputs and produce outputs.
|
"""A node to process inputs and produce outputs.
|
||||||
May use dependency injection in __init__ to receive providers.
|
May use dependency injection in __init__ to receive providers.
|
||||||
@ -76,70 +428,81 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
def get_output_type(cls):
|
def get_output_type(cls):
|
||||||
return signature(cls.invoke).return_annotation
|
return signature(cls.invoke).return_annotation
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
@staticmethod
|
||||||
|
def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||||
|
uiconfig = getattr(model_class, "UIConfig", None)
|
||||||
|
if uiconfig and hasattr(uiconfig, "title"):
|
||||||
|
schema["title"] = uiconfig.title
|
||||||
|
if uiconfig and hasattr(uiconfig, "tags"):
|
||||||
|
schema["tags"] = uiconfig.tags
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
|
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
|
||||||
"""Invoke with provided context and return outputs."""
|
"""Invoke with provided context and return outputs."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# fmt: off
|
def __init__(self, **data):
|
||||||
id: str = Field(description="The id of this node. Must be unique among all nodes.")
|
# nodes may have required fields, that can accept input from connections
|
||||||
is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
|
# on instantiation of the model, we need to exclude these from validation
|
||||||
# fmt: on
|
restore = dict()
|
||||||
|
try:
|
||||||
|
field_names = list(self.__fields__.keys())
|
||||||
|
for field_name in field_names:
|
||||||
|
# if the field is required and may get its value from a connection, exclude it from validation
|
||||||
|
field = self.__fields__[field_name]
|
||||||
|
_input = field.field_info.extra.get("input", None)
|
||||||
|
if _input in [Input.Connection, Input.Any] and field.required:
|
||||||
|
if field_name not in data:
|
||||||
|
restore[field_name] = self.__fields__.pop(field_name)
|
||||||
|
# instantiate the node, which will validate the data
|
||||||
|
super().__init__(**data)
|
||||||
|
finally:
|
||||||
|
# restore the removed fields
|
||||||
|
for field_name, field in restore.items():
|
||||||
|
self.__fields__[field_name] = field
|
||||||
|
|
||||||
|
def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
|
||||||
|
for field_name, field in self.__fields__.items():
|
||||||
|
_input = field.field_info.extra.get("input", None)
|
||||||
|
if field.required and not hasattr(self, field_name):
|
||||||
|
if _input == Input.Connection:
|
||||||
|
raise RequiredConnectionException(self.__fields__["type"].default, field_name)
|
||||||
|
elif _input == Input.Any:
|
||||||
|
raise MissingInputException(self.__fields__["type"].default, field_name)
|
||||||
|
return self.invoke(context)
|
||||||
|
|
||||||
|
id: str = InputField(description="The id of this node. Must be unique among all nodes.")
|
||||||
|
is_intermediate: bool = InputField(
|
||||||
|
default=False, description="Whether or not this node is an intermediate node.", input=Input.Direct
|
||||||
|
)
|
||||||
|
UIConfig: ClassVar[Type[UIConfigBase]]
|
||||||
|
|
||||||
|
|
||||||
# TODO: figure out a better way to provide these hints
|
T = TypeVar("T", bound=BaseInvocation)
|
||||||
# TODO: when we can upgrade to python 3.11, we can use the`NotRequired` type instead of `total=False`
|
|
||||||
class UIConfig(TypedDict, total=False):
|
|
||||||
type_hints: Dict[
|
|
||||||
str,
|
|
||||||
Literal[
|
|
||||||
"integer",
|
|
||||||
"float",
|
|
||||||
"boolean",
|
|
||||||
"string",
|
|
||||||
"enum",
|
|
||||||
"image",
|
|
||||||
"latents",
|
|
||||||
"model",
|
|
||||||
"control",
|
|
||||||
"image_collection",
|
|
||||||
"vae_model",
|
|
||||||
"lora_model",
|
|
||||||
],
|
|
||||||
]
|
|
||||||
tags: List[str]
|
|
||||||
title: str
|
|
||||||
|
|
||||||
|
|
||||||
class CustomisedSchemaExtra(TypedDict):
|
def title(title: str) -> Callable[[Type[T]], Type[T]]:
|
||||||
ui: UIConfig
|
"""Adds a title to the invocation. Use this to override the default title generation, which is based on the class name."""
|
||||||
|
|
||||||
|
def wrapper(cls: Type[T]) -> Type[T]:
|
||||||
|
uiconf_name = cls.__qualname__ + ".UIConfig"
|
||||||
|
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
||||||
|
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
|
||||||
|
cls.UIConfig.title = title
|
||||||
|
return cls
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
class InvocationConfig(BaseConfig):
|
def tags(*tags: str) -> Callable[[Type[T]], Type[T]]:
|
||||||
"""Customizes pydantic's BaseModel.Config class for use by Invocations.
|
"""Adds tags to the invocation. Use this to improve the streamline finding the invocation in the UI."""
|
||||||
|
|
||||||
Provide `schema_extra` a `ui` dict to add hints for generated UIs.
|
def wrapper(cls: Type[T]) -> Type[T]:
|
||||||
|
uiconf_name = cls.__qualname__ + ".UIConfig"
|
||||||
|
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
||||||
|
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
|
||||||
|
cls.UIConfig.tags = list(tags)
|
||||||
|
return cls
|
||||||
|
|
||||||
`tags`
|
return wrapper
|
||||||
- A list of strings, used to categorise invocations.
|
|
||||||
|
|
||||||
`type_hints`
|
|
||||||
- A dict of field types which override the types in the invocation definition.
|
|
||||||
- Each key should be the name of one of the invocation's fields.
|
|
||||||
- Each value should be one of the valid types:
|
|
||||||
- `integer`, `float`, `boolean`, `string`, `enum`, `image`, `latents`, `model`
|
|
||||||
|
|
||||||
```python
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["stable-diffusion", "image"],
|
|
||||||
"type_hints": {
|
|
||||||
"initial_image": "image",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
schema_extra: CustomisedSchemaExtra
|
|
||||||
|
@ -3,58 +3,78 @@
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import Field, validator
|
from pydantic import validator
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageField
|
from invokeai.app.models.image import ImageField
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext, UIConfig
|
from .baseinvocation import (
|
||||||
|
BaseInvocation,
|
||||||
|
BaseInvocationOutput,
|
||||||
|
InputField,
|
||||||
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
|
UITypeHint,
|
||||||
|
tags,
|
||||||
|
title,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class IntCollectionOutput(BaseInvocationOutput):
|
class IntCollectionOutput(BaseInvocationOutput):
|
||||||
"""A collection of integers"""
|
"""A collection of integers"""
|
||||||
|
|
||||||
type: Literal["int_collection"] = "int_collection"
|
type: Literal["int_collection_output"] = "int_collection_output"
|
||||||
|
|
||||||
# Outputs
|
# Outputs
|
||||||
collection: list[int] = Field(default=[], description="The int collection")
|
collection: list[int] = OutputField(
|
||||||
|
default=[], description="The int collection", ui_type_hint=UITypeHint.IntegerCollection
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FloatCollectionOutput(BaseInvocationOutput):
|
class FloatCollectionOutput(BaseInvocationOutput):
|
||||||
"""A collection of floats"""
|
"""A collection of floats"""
|
||||||
|
|
||||||
type: Literal["float_collection"] = "float_collection"
|
type: Literal["float_collection_output"] = "float_collection_output"
|
||||||
|
|
||||||
# Outputs
|
# Outputs
|
||||||
collection: list[float] = Field(default=[], description="The float collection")
|
collection: list[float] = OutputField(
|
||||||
|
default=[], description="The float collection", ui_type_hint=UITypeHint.FloatCollection
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StringCollectionOutput(BaseInvocationOutput):
|
||||||
|
"""A collection of strings"""
|
||||||
|
|
||||||
|
type: Literal["string_collection_output"] = "string_collection_output"
|
||||||
|
|
||||||
|
# Outputs
|
||||||
|
collection: list[str] = OutputField(
|
||||||
|
default=[], description="The output strings", ui_type_hint=UITypeHint.StringCollection
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageCollectionOutput(BaseInvocationOutput):
|
class ImageCollectionOutput(BaseInvocationOutput):
|
||||||
"""A collection of images"""
|
"""A collection of images"""
|
||||||
|
|
||||||
type: Literal["image_collection"] = "image_collection"
|
type: Literal["image_collection_output"] = "image_collection_output"
|
||||||
|
|
||||||
# Outputs
|
# Outputs
|
||||||
collection: list[ImageField] = Field(default=[], description="The output images")
|
collection: list[ImageField] = OutputField(
|
||||||
|
default=[], description="The output images", ui_type_hint=UITypeHint.ImageCollection
|
||||||
class Config:
|
)
|
||||||
schema_extra = {"required": ["type", "collection"]}
|
|
||||||
|
|
||||||
|
|
||||||
|
@title("Integer Range")
|
||||||
|
@tags("collection", "integer", "range")
|
||||||
class RangeInvocation(BaseInvocation):
|
class RangeInvocation(BaseInvocation):
|
||||||
"""Creates a range of numbers from start to stop with step"""
|
"""Creates a range of numbers from start to stop with step"""
|
||||||
|
|
||||||
type: Literal["range"] = "range"
|
type: Literal["range"] = "range"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
start: int = Field(default=0, description="The start of the range")
|
start: int = InputField(default=0, description="The start of the range")
|
||||||
stop: int = Field(default=10, description="The stop of the range")
|
stop: int = InputField(default=10, description="The stop of the range")
|
||||||
step: int = Field(default=1, description="The step of the range")
|
step: int = InputField(default=1, description="The step of the range")
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Range", "tags": ["range", "integer", "collection"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
@validator("stop")
|
@validator("stop")
|
||||||
def stop_gt_start(cls, v, values):
|
def stop_gt_start(cls, v, values):
|
||||||
@ -66,72 +86,56 @@ class RangeInvocation(BaseInvocation):
|
|||||||
return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
|
return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
|
||||||
|
|
||||||
|
|
||||||
|
@title("Integer Range of Size")
|
||||||
|
@tags("range", "integer", "size", "collection")
|
||||||
class RangeOfSizeInvocation(BaseInvocation):
|
class RangeOfSizeInvocation(BaseInvocation):
|
||||||
"""Creates a range from start to start + size with step"""
|
"""Creates a range from start to start + size with step"""
|
||||||
|
|
||||||
type: Literal["range_of_size"] = "range_of_size"
|
type: Literal["range_of_size"] = "range_of_size"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
start: int = Field(default=0, description="The start of the range")
|
start: int = InputField(default=0, description="The start of the range")
|
||||||
size: int = Field(default=1, description="The number of values")
|
size: int = InputField(default=1, description="The number of values")
|
||||||
step: int = Field(default=1, description="The step of the range")
|
step: int = InputField(default=1, description="The step of the range")
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Sized Range", "tags": ["range", "integer", "size", "collection"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||||
return IntCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
|
return IntCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
|
||||||
|
|
||||||
|
|
||||||
|
@title("Random Range")
|
||||||
|
@tags("range", "integer", "random", "collection")
|
||||||
class RandomRangeInvocation(BaseInvocation):
|
class RandomRangeInvocation(BaseInvocation):
|
||||||
"""Creates a collection of random numbers"""
|
"""Creates a collection of random numbers"""
|
||||||
|
|
||||||
type: Literal["random_range"] = "random_range"
|
type: Literal["random_range"] = "random_range"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
low: int = Field(default=0, description="The inclusive low value")
|
low: int = InputField(default=0, description="The inclusive low value")
|
||||||
high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
||||||
size: int = Field(default=1, description="The number of values to generate")
|
size: int = InputField(default=1, description="The number of values to generate")
|
||||||
seed: int = Field(
|
seed: int = InputField(
|
||||||
ge=0,
|
ge=0,
|
||||||
le=SEED_MAX,
|
le=SEED_MAX,
|
||||||
description="The seed for the RNG (omit for random)",
|
description="The seed for the RNG (omit for random)",
|
||||||
default_factory=get_random_seed,
|
default_factory=get_random_seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Random Range", "tags": ["range", "integer", "random", "collection"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||||
rng = np.random.default_rng(self.seed)
|
rng = np.random.default_rng(self.seed)
|
||||||
return IntCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))
|
return IntCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))
|
||||||
|
|
||||||
|
|
||||||
|
@title("Image Collection")
|
||||||
|
@tags("image", "collection")
|
||||||
class ImageCollectionInvocation(BaseInvocation):
|
class ImageCollectionInvocation(BaseInvocation):
|
||||||
"""Load a collection of images and provide it as output."""
|
"""Load a collection of images and provide it as output."""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["image_collection"] = "image_collection"
|
type: Literal["image_collection"] = "image_collection"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
images: list[ImageField] = Field(
|
images: list[ImageField] = InputField(
|
||||||
default=[], description="The image collection to load"
|
default=[], description="The image collection to load", ui_type_hint=UITypeHint.ImageCollection
|
||||||
)
|
)
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
|
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
|
||||||
return ImageCollectionOutput(collection=self.images)
|
return ImageCollectionOutput(collection=self.images)
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"type_hints": {
|
|
||||||
"title": "Image Collection",
|
|
||||||
"images": "image_collection",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
@ -1,29 +1,39 @@
|
|||||||
from typing import Literal, Optional, Union, List, Annotated
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
import re
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
from typing import List, Literal, Union
|
||||||
from .model import ClipField
|
|
||||||
|
|
||||||
from ...backend.util.devices import torch_dtype
|
|
||||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
|
||||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType, ModelPatcher
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from compel import Compel, ReturnedEmbeddingsType
|
from compel import Compel, ReturnedEmbeddingsType
|
||||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||||
from ...backend.util.devices import torch_dtype
|
from pydantic import BaseModel, Field
|
||||||
from ...backend.model_management import ModelType
|
|
||||||
from ...backend.model_management.models import ModelNotFoundException
|
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
|
||||||
|
BasicConditioningInfo,
|
||||||
|
SDXLConditioningInfo,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ...backend.model_management import ModelPatcher, ModelType
|
||||||
from ...backend.model_management.lora import ModelPatcher
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
from ...backend.stable_diffusion import InvokeAIDiffuserComponent, BasicConditioningInfo, SDXLConditioningInfo
|
from ...backend.model_management.models import ModelNotFoundException
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||||
|
from ...backend.util.devices import torch_dtype
|
||||||
|
from .baseinvocation import (
|
||||||
|
BaseInvocation,
|
||||||
|
BaseInvocationOutput,
|
||||||
|
FieldDescriptions,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
|
UIComponent,
|
||||||
|
tags,
|
||||||
|
title,
|
||||||
|
)
|
||||||
from .model import ClipField
|
from .model import ClipField
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
|
|
||||||
class ConditioningField(BaseModel):
|
class ConditioningField(BaseModel):
|
||||||
conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
|
conditioning_name: str = Field(description="The name of conditioning data")
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {"required": ["conditioning_name"]}
|
schema_extra = {"required": ["conditioning_name"]}
|
||||||
@ -47,23 +57,27 @@ class CompelOutput(BaseInvocationOutput):
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["compel_output"] = "compel_output"
|
type: Literal["compel_output"] = "compel_output"
|
||||||
|
|
||||||
conditioning: ConditioningField = Field(default=None, description="Conditioning")
|
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
@title("Compel Prompt")
|
||||||
|
@tags("prompt", "compel")
|
||||||
class CompelInvocation(BaseInvocation):
|
class CompelInvocation(BaseInvocation):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
type: Literal["compel"] = "compel"
|
type: Literal["compel"] = "compel"
|
||||||
|
|
||||||
prompt: str = Field(default="", description="Prompt")
|
prompt: str = InputField(
|
||||||
clip: ClipField = Field(None, description="Clip to use")
|
default="",
|
||||||
|
description=FieldDescriptions.compel_prompt,
|
||||||
# Schema customisation
|
ui_component=UIComponent.Textarea,
|
||||||
class Config(InvocationConfig):
|
)
|
||||||
schema_extra = {
|
clip: ClipField = InputField(
|
||||||
"ui": {"title": "Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
|
title="CLIP",
|
||||||
}
|
description=FieldDescriptions.clip,
|
||||||
|
input=Input.Connection,
|
||||||
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||||
@ -270,27 +284,23 @@ class SDXLPromptInvocationBase:
|
|||||||
return c, c_pooled, ec
|
return c, c_pooled, ec
|
||||||
|
|
||||||
|
|
||||||
|
@title("SDXL Compel Prompt")
|
||||||
|
@tags("sdxl", "compel", "prompt")
|
||||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
type: Literal["sdxl_compel_prompt"] = "sdxl_compel_prompt"
|
type: Literal["sdxl_compel_prompt"] = "sdxl_compel_prompt"
|
||||||
|
|
||||||
prompt: str = Field(default="", description="Prompt")
|
prompt: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
|
||||||
style: str = Field(default="", description="Style prompt")
|
style: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
|
||||||
original_width: int = Field(1024, description="")
|
original_width: int = InputField(default=1024, description="")
|
||||||
original_height: int = Field(1024, description="")
|
original_height: int = InputField(default=1024, description="")
|
||||||
crop_top: int = Field(0, description="")
|
crop_top: int = InputField(default=0, description="")
|
||||||
crop_left: int = Field(0, description="")
|
crop_left: int = InputField(default=0, description="")
|
||||||
target_width: int = Field(1024, description="")
|
target_width: int = InputField(default=1024, description="")
|
||||||
target_height: int = Field(1024, description="")
|
target_height: int = InputField(default=1024, description="")
|
||||||
clip: ClipField = Field(None, description="Clip to use")
|
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||||
clip2: ClipField = Field(None, description="Clip2 to use")
|
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "SDXL Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
|
|
||||||
}
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||||
@ -333,28 +343,22 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("SDXL Refiner Compel Prompt")
|
||||||
|
@tags("sdxl", "compel", "prompt")
|
||||||
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt"
|
type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt"
|
||||||
|
|
||||||
style: str = Field(default="", description="Style prompt") # TODO: ?
|
style: str = InputField(
|
||||||
original_width: int = Field(1024, description="")
|
default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea
|
||||||
original_height: int = Field(1024, description="")
|
) # TODO: ?
|
||||||
crop_top: int = Field(0, description="")
|
original_width: int = InputField(default=1024, description="")
|
||||||
crop_left: int = Field(0, description="")
|
original_height: int = InputField(default=1024, description="")
|
||||||
aesthetic_score: float = Field(6.0, description="")
|
crop_top: int = InputField(default=0, description="")
|
||||||
clip2: ClipField = Field(None, description="Clip to use")
|
crop_left: int = InputField(default=0, description="")
|
||||||
|
aesthetic_score: float = InputField(default=6.0, description=FieldDescriptions.sdxl_aesthetic)
|
||||||
# Schema customisation
|
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "SDXL Refiner Prompt (Compel)",
|
|
||||||
"tags": ["prompt", "compel"],
|
|
||||||
"type_hints": {"model": "model"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||||
@ -391,21 +395,18 @@ class ClipSkipInvocationOutput(BaseInvocationOutput):
|
|||||||
"""Clip skip node output"""
|
"""Clip skip node output"""
|
||||||
|
|
||||||
type: Literal["clip_skip_output"] = "clip_skip_output"
|
type: Literal["clip_skip_output"] = "clip_skip_output"
|
||||||
clip: ClipField = Field(None, description="Clip with skipped layers")
|
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||||
|
|
||||||
|
|
||||||
|
@title("CLIP Skip")
|
||||||
|
@tags("clipskip", "clip", "skip")
|
||||||
class ClipSkipInvocation(BaseInvocation):
|
class ClipSkipInvocation(BaseInvocation):
|
||||||
"""Skip layers in clip text_encoder model."""
|
"""Skip layers in clip text_encoder model."""
|
||||||
|
|
||||||
type: Literal["clip_skip"] = "clip_skip"
|
type: Literal["clip_skip"] = "clip_skip"
|
||||||
|
|
||||||
clip: ClipField = Field(None, description="Clip to use")
|
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
|
||||||
skipped_layers: int = Field(0, description="Number of layers to skip in text_encoder")
|
skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "CLIP Skip", "tags": ["clip", "skip"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
||||||
self.clip.skipped_layers += self.skipped_layers
|
self.clip.skipped_layers += self.skipped_layers
|
||||||
|
@ -28,77 +28,27 @@ from pydantic import BaseModel, Field, validator
|
|||||||
|
|
||||||
from ...backend.model_management import BaseModelType, ModelType
|
from ...backend.model_management import BaseModelType, ModelType
|
||||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
from .baseinvocation import (
|
||||||
from ..models.image import ImageOutput, PILInvocationConfig
|
BaseInvocation,
|
||||||
|
BaseInvocationOutput,
|
||||||
|
FieldDescriptions,
|
||||||
|
InputField,
|
||||||
|
Input,
|
||||||
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
|
UITypeHint,
|
||||||
|
tags,
|
||||||
|
title,
|
||||||
|
)
|
||||||
|
from ..models.image import ImageOutput
|
||||||
|
|
||||||
CONTROLNET_DEFAULT_MODELS = [
|
|
||||||
###########################################
|
|
||||||
# lllyasviel sd v1.5, ControlNet v1.0 models
|
|
||||||
##############################################
|
|
||||||
"lllyasviel/sd-controlnet-canny",
|
|
||||||
"lllyasviel/sd-controlnet-depth",
|
|
||||||
"lllyasviel/sd-controlnet-hed",
|
|
||||||
"lllyasviel/sd-controlnet-seg",
|
|
||||||
"lllyasviel/sd-controlnet-openpose",
|
|
||||||
"lllyasviel/sd-controlnet-scribble",
|
|
||||||
"lllyasviel/sd-controlnet-normal",
|
|
||||||
"lllyasviel/sd-controlnet-mlsd",
|
|
||||||
#############################################
|
|
||||||
# lllyasviel sd v1.5, ControlNet v1.1 models
|
|
||||||
#############################################
|
|
||||||
"lllyasviel/control_v11p_sd15_canny",
|
|
||||||
"lllyasviel/control_v11p_sd15_openpose",
|
|
||||||
"lllyasviel/control_v11p_sd15_seg",
|
|
||||||
# "lllyasviel/control_v11p_sd15_depth", # broken
|
|
||||||
"lllyasviel/control_v11f1p_sd15_depth",
|
|
||||||
"lllyasviel/control_v11p_sd15_normalbae",
|
|
||||||
"lllyasviel/control_v11p_sd15_scribble",
|
|
||||||
"lllyasviel/control_v11p_sd15_mlsd",
|
|
||||||
"lllyasviel/control_v11p_sd15_softedge",
|
|
||||||
"lllyasviel/control_v11p_sd15s2_lineart_anime",
|
|
||||||
"lllyasviel/control_v11p_sd15_lineart",
|
|
||||||
"lllyasviel/control_v11p_sd15_inpaint",
|
|
||||||
# "lllyasviel/control_v11u_sd15_tile",
|
|
||||||
# problem (temporary?) with huffingface "lllyasviel/control_v11u_sd15_tile",
|
|
||||||
# so for now replace "lllyasviel/control_v11f1e_sd15_tile",
|
|
||||||
"lllyasviel/control_v11e_sd15_shuffle",
|
|
||||||
"lllyasviel/control_v11e_sd15_ip2p",
|
|
||||||
"lllyasviel/control_v11f1e_sd15_tile",
|
|
||||||
#################################################
|
|
||||||
# thibaud sd v2.1 models (ControlNet v1.0? or v1.1?
|
|
||||||
##################################################
|
|
||||||
"thibaud/controlnet-sd21-openpose-diffusers",
|
|
||||||
"thibaud/controlnet-sd21-canny-diffusers",
|
|
||||||
"thibaud/controlnet-sd21-depth-diffusers",
|
|
||||||
"thibaud/controlnet-sd21-scribble-diffusers",
|
|
||||||
"thibaud/controlnet-sd21-hed-diffusers",
|
|
||||||
"thibaud/controlnet-sd21-zoedepth-diffusers",
|
|
||||||
"thibaud/controlnet-sd21-color-diffusers",
|
|
||||||
"thibaud/controlnet-sd21-openposev2-diffusers",
|
|
||||||
"thibaud/controlnet-sd21-lineart-diffusers",
|
|
||||||
"thibaud/controlnet-sd21-normalbae-diffusers",
|
|
||||||
"thibaud/controlnet-sd21-ade20k-diffusers",
|
|
||||||
##############################################
|
|
||||||
# ControlNetMediaPipeface, ControlNet v1.1
|
|
||||||
##############################################
|
|
||||||
# ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"], # SD 1.5
|
|
||||||
# diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg
|
|
||||||
# hacked t2l to split to model & subfolder if format is "model,subfolder"
|
|
||||||
"CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15", # SD 1.5
|
|
||||||
"CrucibleAI/ControlNetMediaPipeFace", # SD 2.1?
|
|
||||||
]
|
|
||||||
|
|
||||||
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
|
CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
|
||||||
CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
|
|
||||||
CONTROLNET_RESIZE_VALUES = Literal[
|
CONTROLNET_RESIZE_VALUES = Literal[
|
||||||
tuple(
|
"just_resize",
|
||||||
[
|
"crop_resize",
|
||||||
"just_resize",
|
"fill_resize",
|
||||||
"crop_resize",
|
"just_resize_simple",
|
||||||
"fill_resize",
|
|
||||||
"just_resize_simple",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -110,9 +60,8 @@ class ControlNetModelField(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ControlField(BaseModel):
|
class ControlField(BaseModel):
|
||||||
image: ImageField = Field(default=None, description="The control image")
|
image: ImageField = Field(description="The control image")
|
||||||
control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
|
control_model: ControlNetModelField = Field(description="The ControlNet model to use")
|
||||||
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
|
|
||||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||||
begin_step_percent: float = Field(
|
begin_step_percent: float = Field(
|
||||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||||
@ -135,60 +84,39 @@ class ControlField(BaseModel):
|
|||||||
raise ValueError("Control weights must be within -1 to 2 range")
|
raise ValueError("Control weights must be within -1 to 2 range")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
class Config:
|
|
||||||
schema_extra = {
|
|
||||||
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"],
|
|
||||||
"ui": {
|
|
||||||
"type_hints": {
|
|
||||||
"control_weight": "float",
|
|
||||||
"control_model": "controlnet_model",
|
|
||||||
# "control_weight": "number",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ControlOutput(BaseInvocationOutput):
|
class ControlOutput(BaseInvocationOutput):
|
||||||
"""node output for ControlNet info"""
|
"""node output for ControlNet info"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["control_output"] = "control_output"
|
type: Literal["control_output"] = "control_output"
|
||||||
control: ControlField = Field(default=None, description="The control info")
|
|
||||||
# fmt: on
|
# Outputs
|
||||||
|
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||||
|
|
||||||
|
|
||||||
|
@title("ControlNet")
|
||||||
|
@tags("controlnet")
|
||||||
class ControlNetInvocation(BaseInvocation):
|
class ControlNetInvocation(BaseInvocation):
|
||||||
"""Collects ControlNet info to pass to other nodes"""
|
"""Collects ControlNet info to pass to other nodes"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["controlnet"] = "controlnet"
|
type: Literal["controlnet"] = "controlnet"
|
||||||
# Inputs
|
|
||||||
image: ImageField = Field(default=None, description="The control image")
|
|
||||||
control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny",
|
|
||||||
description="control model used")
|
|
||||||
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
|
|
||||||
begin_step_percent: float = Field(default=0, ge=-1, le=2,
|
|
||||||
description="When the ControlNet is first applied (% of total steps)")
|
|
||||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
|
||||||
description="When the ControlNet is last applied (% of total steps)")
|
|
||||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
|
|
||||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode used")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
image: ImageField = InputField(description="The control image")
|
||||||
"ui": {
|
control_model: ControlNetModelField = InputField(
|
||||||
"title": "ControlNet",
|
default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
|
||||||
"tags": ["controlnet", "latents"],
|
)
|
||||||
"type_hints": {
|
control_weight: Union[float, List[float]] = InputField(
|
||||||
"model": "model",
|
default=1.0, description="The weight given to the ControlNet", ui_type_hint=UITypeHint.Float
|
||||||
"control": "control",
|
)
|
||||||
# "cfg_scale": "float",
|
begin_step_percent: float = InputField(
|
||||||
"cfg_scale": "number",
|
default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
|
||||||
"control_weight": "float",
|
)
|
||||||
},
|
end_step_percent: float = InputField(
|
||||||
},
|
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||||
}
|
)
|
||||||
|
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
|
||||||
|
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ControlOutput:
|
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||||
return ControlOutput(
|
return ControlOutput(
|
||||||
@ -204,19 +132,13 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
class ImageProcessorInvocation(BaseInvocation):
|
||||||
"""Base class for invocations that preprocess images for ControlNet"""
|
"""Base class for invocations that preprocess images for ControlNet"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["image_processor"] = "image_processor"
|
type: Literal["image_processor"] = "image_processor"
|
||||||
# Inputs
|
|
||||||
image: ImageField = Field(default=None, description="The image to process")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
image: ImageField = InputField(description="The image to process")
|
||||||
"ui": {"title": "Image Processor", "tags": ["image", "processor"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
# superclass just passes through image without processing
|
# superclass just passes through image without processing
|
||||||
@ -255,20 +177,20 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("Canny Processor")
|
||||||
|
@tags("controlnet", "canny")
|
||||||
|
class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Canny edge detection for ControlNet"""
|
"""Canny edge detection for ControlNet"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["canny_image_processor"] = "canny_image_processor"
|
type: Literal["canny_image_processor"] = "canny_image_processor"
|
||||||
# Input
|
|
||||||
low_threshold: int = Field(default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)")
|
|
||||||
high_threshold: int = Field(default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Input
|
||||||
schema_extra = {
|
low_threshold: int = InputField(
|
||||||
"ui": {"title": "Canny Processor", "tags": ["controlnet", "canny", "image", "processor"]},
|
default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
|
||||||
}
|
)
|
||||||
|
high_threshold: int = InputField(
|
||||||
|
default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)"
|
||||||
|
)
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
canny_processor = CannyDetector()
|
canny_processor = CannyDetector()
|
||||||
@ -276,23 +198,19 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("HED (softedge) Processor")
|
||||||
|
@tags("controlnet", "hed", "softedge")
|
||||||
|
class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies HED edge detection to image"""
|
"""Applies HED edge detection to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["hed_image_processor"] = "hed_image_processor"
|
type: Literal["hed_image_processor"] = "hed_image_processor"
|
||||||
# Inputs
|
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
|
||||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
|
||||||
# safe not supported in controlnet_aux v0.0.3
|
|
||||||
# safe: bool = Field(default=False, description="whether to use safe mode")
|
|
||||||
scribble: bool = Field(default=False, description="Whether to use scribble mode")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
"ui": {"title": "Softedge(HED) Processor", "tags": ["controlnet", "softedge", "hed", "image", "processor"]},
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
}
|
# safe not supported in controlnet_aux v0.0.3
|
||||||
|
# safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||||
|
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
|
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
|
||||||
@ -307,21 +225,17 @@ class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig)
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("Lineart Processor")
|
||||||
|
@tags("controlnet", "lineart")
|
||||||
|
class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies line art processing to image"""
|
"""Applies line art processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["lineart_image_processor"] = "lineart_image_processor"
|
type: Literal["lineart_image_processor"] = "lineart_image_processor"
|
||||||
# Inputs
|
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
|
||||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
|
||||||
coarse: bool = Field(default=False, description="Whether to use coarse mode")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
"ui": {"title": "Lineart Processor", "tags": ["controlnet", "lineart", "image", "processor"]},
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
}
|
coarse: bool = InputField(default=False, description="Whether to use coarse mode")
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
|
lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
@ -331,23 +245,16 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCon
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("Lineart Anime Processor")
|
||||||
|
@tags("controlnet", "lineart", "anime")
|
||||||
|
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies line art anime processing to image"""
|
"""Applies line art anime processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
|
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
|
||||||
# Inputs
|
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
|
||||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
"ui": {
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
"title": "Lineart Anime Processor",
|
|
||||||
"tags": ["controlnet", "lineart", "anime", "image", "processor"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
|
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
@ -359,21 +266,17 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocati
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("Openpose Processor")
|
||||||
|
@tags("controlnet", "openpose", "pose")
|
||||||
|
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies Openpose processing to image"""
|
"""Applies Openpose processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["openpose_image_processor"] = "openpose_image_processor"
|
type: Literal["openpose_image_processor"] = "openpose_image_processor"
|
||||||
# Inputs
|
|
||||||
hand_and_face: bool = Field(default=False, description="Whether to use hands and face mode")
|
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
|
||||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
hand_and_face: bool = InputField(default=False, description="Whether to use hands and face mode")
|
||||||
"ui": {"title": "Openpose Processor", "tags": ["controlnet", "openpose", "image", "processor"]},
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
}
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
|
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
@ -386,22 +289,18 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("Midas (Depth) Processor")
|
||||||
|
@tags("controlnet", "midas", "depth")
|
||||||
|
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies Midas depth processing to image"""
|
"""Applies Midas depth processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
|
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
|
||||||
# Inputs
|
|
||||||
a_mult: float = Field(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
|
|
||||||
bg_th: float = Field(default=0.1, ge=0, description="Midas parameter `bg_th`")
|
|
||||||
# depth_and_normal not supported in controlnet_aux v0.0.3
|
|
||||||
# depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
|
||||||
"ui": {"title": "Midas (Depth) Processor", "tags": ["controlnet", "midas", "depth", "image", "processor"]},
|
bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
|
||||||
}
|
# depth_and_normal not supported in controlnet_aux v0.0.3
|
||||||
|
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
@ -415,20 +314,16 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocation
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("Normal BAE Processor")
|
||||||
|
@tags("controlnet", "normal", "bae")
|
||||||
|
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies NormalBae processing to image"""
|
"""Applies NormalBae processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
|
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
|
||||||
# Inputs
|
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
|
||||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
"ui": {"title": "Normal BAE Processor", "tags": ["controlnet", "normal", "bae", "image", "processor"]},
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
}
|
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
|
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
@ -438,22 +333,18 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationC
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("MLSD Processor")
|
||||||
|
@tags("controlnet", "mlsd")
|
||||||
|
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies MLSD processing to image"""
|
"""Applies MLSD processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
|
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
|
||||||
# Inputs
|
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
|
||||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
|
||||||
thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_v`")
|
|
||||||
thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_d`")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
"ui": {"title": "MLSD Processor", "tags": ["controlnet", "mlsd", "image", "processor"]},
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
}
|
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
|
||||||
|
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
|
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
|
||||||
@ -467,22 +358,18 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("PIDI Processor")
|
||||||
|
@tags("controlnet", "pidi")
|
||||||
|
class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies PIDI processing to image"""
|
"""Applies PIDI processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["pidi_image_processor"] = "pidi_image_processor"
|
type: Literal["pidi_image_processor"] = "pidi_image_processor"
|
||||||
# Inputs
|
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
|
||||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
|
||||||
safe: bool = Field(default=False, description="Whether to use safe mode")
|
|
||||||
scribble: bool = Field(default=False, description="Whether to use scribble mode")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
"ui": {"title": "PIDI Processor", "tags": ["controlnet", "pidi", "image", "processor"]},
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
}
|
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||||
|
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
|
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
@ -496,26 +383,19 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("Content Shuffle Processor")
|
||||||
|
@tags("controlnet", "contentshuffle")
|
||||||
|
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies content shuffle processing to image"""
|
"""Applies content shuffle processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
|
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
|
||||||
# Inputs
|
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
|
||||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
|
||||||
h: Optional[int] = Field(default=512, ge=0, description="Content shuffle `h` parameter")
|
|
||||||
w: Optional[int] = Field(default=512, ge=0, description="Content shuffle `w` parameter")
|
|
||||||
f: Optional[int] = Field(default=256, ge=0, description="Content shuffle `f` parameter")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
"ui": {
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
"title": "Content Shuffle Processor",
|
h: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
|
||||||
"tags": ["controlnet", "contentshuffle", "image", "processor"],
|
w: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
|
||||||
},
|
f: Optional[int] = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
|
||||||
}
|
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
content_shuffle_processor = ContentShuffleDetector()
|
content_shuffle_processor = ContentShuffleDetector()
|
||||||
@ -531,17 +411,12 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvoca
|
|||||||
|
|
||||||
|
|
||||||
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
|
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
|
||||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("Zoe (Depth) Processor")
|
||||||
|
@tags("controlnet", "zoe", "depth")
|
||||||
|
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies Zoe depth processing to image"""
|
"""Applies Zoe depth processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
|
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Zoe (Depth) Processor", "tags": ["controlnet", "zoe", "depth", "image", "processor"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
@ -549,20 +424,16 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("Mediapipe Face Processor")
|
||||||
|
@tags("controlnet", "mediapipe", "face")
|
||||||
|
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies mediapipe face processing to image"""
|
"""Applies mediapipe face processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
|
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
|
||||||
# Inputs
|
|
||||||
max_faces: int = Field(default=1, ge=1, description="Maximum number of faces to detect")
|
|
||||||
min_confidence: float = Field(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
|
||||||
"ui": {"title": "Mediapipe Processor", "tags": ["controlnet", "mediapipe", "image", "processor"]},
|
min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
|
||||||
}
|
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
# MediaPipeFaceDetector throws an error if image has alpha channel
|
# MediaPipeFaceDetector throws an error if image has alpha channel
|
||||||
@ -574,23 +445,19 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("Leres (Depth) Processor")
|
||||||
|
@tags("controlnet", "leres", "depth")
|
||||||
|
class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies leres processing to image"""
|
"""Applies leres processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["leres_image_processor"] = "leres_image_processor"
|
type: Literal["leres_image_processor"] = "leres_image_processor"
|
||||||
# Inputs
|
|
||||||
thr_a: float = Field(default=0, description="Leres parameter `thr_a`")
|
|
||||||
thr_b: float = Field(default=0, description="Leres parameter `thr_b`")
|
|
||||||
boost: bool = Field(default=False, description="Whether to use boost mode")
|
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
|
||||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
thr_a: float = InputField(default=0, description="Leres parameter `thr_a`")
|
||||||
"ui": {"title": "Leres (Depth) Processor", "tags": ["controlnet", "leres", "depth", "image", "processor"]},
|
thr_b: float = InputField(default=0, description="Leres parameter `thr_b`")
|
||||||
}
|
boost: bool = InputField(default=False, description="Whether to use boost mode")
|
||||||
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
|
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
@ -605,21 +472,16 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("Tile Resample Processor")
|
||||||
# fmt: off
|
@tags("controlnet", "tile")
|
||||||
type: Literal["tile_image_processor"] = "tile_image_processor"
|
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||||
# Inputs
|
"""Tile resampler processor"""
|
||||||
#res: int = Field(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
|
|
||||||
down_sampling_rate: float = Field(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
type: Literal["tile_image_processor"] = "tile_image_processor"
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
# Inputs
|
||||||
"title": "Tile Resample Processor",
|
# res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
|
||||||
"tags": ["controlnet", "tile", "resample", "image", "processor"],
|
down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# tile_resample copied from sd-webui-controlnet/scripts/processor.py
|
# tile_resample copied from sd-webui-controlnet/scripts/processor.py
|
||||||
def tile_resample(
|
def tile_resample(
|
||||||
@ -648,20 +510,12 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("Segment Anything Processor")
|
||||||
|
@tags("controlnet", "segmentanything")
|
||||||
|
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies segment anything processing to image"""
|
"""Applies segment anything processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["segment_anything_processor"] = "segment_anything_processor"
|
type: Literal["segment_anything_processor"] = "segment_anything_processor"
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "Segment Anything Processor",
|
|
||||||
"tags": ["controlnet", "segment", "anything", "sam", "image", "processor"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
||||||
|
@ -5,40 +5,22 @@ from typing import Literal
|
|||||||
import cv2 as cv
|
import cv2 as cv
|
||||||
import numpy
|
import numpy
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
|
||||||
from .image import ImageOutput
|
from .image import ImageOutput
|
||||||
|
|
||||||
|
|
||||||
class CvInvocationConfig(BaseModel):
|
@title("OpenCV Inpaint")
|
||||||
"""Helper class to provide all OpenCV invocations with additional config"""
|
@tags("opencv", "inpaint")
|
||||||
|
class CvInpaintInvocation(BaseInvocation):
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["cv", "image"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
|
||||||
"""Simple inpaint using opencv."""
|
"""Simple inpaint using opencv."""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["cv_inpaint"] = "cv_inpaint"
|
type: Literal["cv_inpaint"] = "cv_inpaint"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The image to inpaint")
|
image: ImageField = InputField(description="The image to inpaint")
|
||||||
mask: ImageField = Field(default=None, description="The mask to use when inpainting")
|
mask: ImageField = InputField(description="The mask to use when inpainting")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "OpenCV Inpaint", "tags": ["opencv", "inpaint"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
@ -1,37 +1,30 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy
|
import numpy
|
||||||
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from invokeai.app.invocations.metadata import CoreMetadata
|
from invokeai.app.invocations.metadata import CoreMetadata
|
||||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||||
|
|
||||||
from ..models.image import ImageCategory, ImageField, ImageOutput, MaskOutput, PILInvocationConfig, ResourceOrigin
|
from ..models.image import ImageCategory, ImageField, ImageOutput, MaskOutput, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
|
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, tags, title
|
||||||
|
|
||||||
|
|
||||||
|
@title("Load Image")
|
||||||
|
@tags("image")
|
||||||
class LoadImageInvocation(BaseInvocation):
|
class LoadImageInvocation(BaseInvocation):
|
||||||
"""Load an image and provide it as output."""
|
"""Load an image and provide it as output."""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["load_image"] = "load_image"
|
type: Literal["load_image"] = "load_image"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(
|
image: ImageField = InputField(description="The image to load")
|
||||||
default=None, description="The image to load"
|
|
||||||
)
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Load Image", "tags": ["image", "load"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -43,18 +36,16 @@ class LoadImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Show Image")
|
||||||
|
@tags("image")
|
||||||
class ShowImageInvocation(BaseInvocation):
|
class ShowImageInvocation(BaseInvocation):
|
||||||
"""Displays a provided image, and passes it forward in the pipeline."""
|
"""Displays a provided image, and passes it forward in the pipeline."""
|
||||||
|
|
||||||
|
# Metadata
|
||||||
type: Literal["show_image"] = "show_image"
|
type: Literal["show_image"] = "show_image"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to show")
|
image: ImageField = InputField(description="The image to show")
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Show Image", "tags": ["image", "show"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -70,24 +61,20 @@ class ShowImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Crop Image")
|
||||||
|
@tags("image", "crop")
|
||||||
|
class ImageCropInvocation(BaseInvocation):
|
||||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["img_crop"] = "img_crop"
|
type: Literal["img_crop"] = "img_crop"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to crop")
|
image: ImageField = InputField(description="The image to crop")
|
||||||
x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
|
x: int = InputField(default=0, description="The left x coordinate of the crop rectangle")
|
||||||
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
|
y: int = InputField(default=0, description="The top y coordinate of the crop rectangle")
|
||||||
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
|
width: int = InputField(default=512, gt=0, description="The width of the crop rectangle")
|
||||||
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
|
height: int = InputField(default=512, gt=0, description="The height of the crop rectangle")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Crop Image", "tags": ["image", "crop"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -111,24 +98,23 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Paste Image")
|
||||||
|
@tags("image", "paste")
|
||||||
|
class ImagePasteInvocation(BaseInvocation):
|
||||||
"""Pastes an image into another image."""
|
"""Pastes an image into another image."""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["img_paste"] = "img_paste"
|
type: Literal["img_paste"] = "img_paste"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
base_image: Optional[ImageField] = Field(default=None, description="The base image")
|
base_image: ImageField = InputField(description="The base image")
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to paste")
|
image: ImageField = InputField(description="The image to paste")
|
||||||
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
|
mask: Optional[ImageField] = InputField(
|
||||||
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
|
default=None,
|
||||||
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
|
description="The mask to use when pasting",
|
||||||
# fmt: on
|
)
|
||||||
|
x: int = InputField(default=0, description="The left x coordinate at which to paste the image")
|
||||||
class Config(InvocationConfig):
|
y: int = InputField(default=0, description="The top y coordinate at which to paste the image")
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Paste Image", "tags": ["image", "paste"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
base_image = context.services.images.get_pil_image(self.base_image.image_name)
|
base_image = context.services.images.get_pil_image(self.base_image.image_name)
|
||||||
@ -164,21 +150,17 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Mask from Alpha")
|
||||||
|
@tags("image", "mask")
|
||||||
|
class MaskFromAlphaInvocation(BaseInvocation):
|
||||||
"""Extracts the alpha channel of an image as a mask."""
|
"""Extracts the alpha channel of an image as a mask."""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["tomask"] = "tomask"
|
type: Literal["tomask"] = "tomask"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to create the mask from")
|
image: ImageField = InputField(description="The image to create the mask from")
|
||||||
invert: bool = Field(default=False, description="Whether or not to invert the mask")
|
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Mask From Alpha", "tags": ["image", "mask", "alpha"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -203,21 +185,17 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Multiply Images")
|
||||||
|
@tags("image", "multiply")
|
||||||
|
class ImageMultiplyInvocation(BaseInvocation):
|
||||||
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["img_mul"] = "img_mul"
|
type: Literal["img_mul"] = "img_mul"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image1: Optional[ImageField] = Field(default=None, description="The first image to multiply")
|
image1: ImageField = InputField(description="The first image to multiply")
|
||||||
image2: Optional[ImageField] = Field(default=None, description="The second image to multiply")
|
image2: ImageField = InputField(description="The second image to multiply")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Multiply Images", "tags": ["image", "multiply"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image1 = context.services.images.get_pil_image(self.image1.image_name)
|
image1 = context.services.images.get_pil_image(self.image1.image_name)
|
||||||
@ -244,21 +222,17 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
||||||
|
|
||||||
|
|
||||||
class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Extract Image Channel")
|
||||||
|
@tags("image", "channel")
|
||||||
|
class ImageChannelInvocation(BaseInvocation):
|
||||||
"""Gets a channel from an image."""
|
"""Gets a channel from an image."""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["img_chan"] = "img_chan"
|
type: Literal["img_chan"] = "img_chan"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to get the channel from")
|
image: ImageField = InputField(description="The image to get the channel from")
|
||||||
channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get")
|
channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Image Channel", "tags": ["image", "channel"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -284,21 +258,17 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
||||||
|
|
||||||
|
|
||||||
class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Convert Image Mode")
|
||||||
|
@tags("image", "convert")
|
||||||
|
class ImageConvertInvocation(BaseInvocation):
|
||||||
"""Converts an image to a different mode."""
|
"""Converts an image to a different mode."""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["img_conv"] = "img_conv"
|
type: Literal["img_conv"] = "img_conv"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to convert")
|
image: ImageField = InputField(description="The image to convert")
|
||||||
mode: IMAGE_MODES = Field(default="L", description="The mode to convert to")
|
mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Convert Image", "tags": ["image", "convert"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -321,22 +291,19 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Blur Image")
|
||||||
|
@tags("image", "blur")
|
||||||
|
class ImageBlurInvocation(BaseInvocation):
|
||||||
"""Blurs an image"""
|
"""Blurs an image"""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["img_blur"] = "img_blur"
|
type: Literal["img_blur"] = "img_blur"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to blur")
|
image: ImageField = InputField(description="The image to blur")
|
||||||
radius: float = Field(default=8.0, ge=0, description="The blur radius")
|
radius: float = InputField(default=8.0, ge=0, description="The blur radius")
|
||||||
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
|
# Metadata
|
||||||
# fmt: on
|
blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur")
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Blur Image", "tags": ["image", "blur"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -382,23 +349,19 @@ PIL_RESAMPLING_MAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Resize Image")
|
||||||
|
@tags("image", "resize")
|
||||||
|
class ImageResizeInvocation(BaseInvocation):
|
||||||
"""Resizes an image to specific dimensions"""
|
"""Resizes an image to specific dimensions"""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["img_resize"] = "img_resize"
|
type: Literal["img_resize"] = "img_resize"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to resize")
|
image: ImageField = InputField(description="The image to resize")
|
||||||
width: Union[int, None] = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
width: int = InputField(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||||
height: Union[int, None] = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
height: int = InputField(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||||
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Resize Image", "tags": ["image", "resize"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -426,22 +389,22 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Scale Image")
|
||||||
|
@tags("image", "scale")
|
||||||
|
class ImageScaleInvocation(BaseInvocation):
|
||||||
"""Scales an image by a factor"""
|
"""Scales an image by a factor"""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["img_scale"] = "img_scale"
|
type: Literal["img_scale"] = "img_scale"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to scale")
|
image: ImageField = InputField(description="The image to scale")
|
||||||
scale_factor: Optional[float] = Field(default=2.0, gt=0, description="The factor by which to scale the image")
|
scale_factor: float = InputField(
|
||||||
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
default=2.0,
|
||||||
# fmt: on
|
gt=0,
|
||||||
|
description="The factor by which to scale the image",
|
||||||
class Config(InvocationConfig):
|
)
|
||||||
schema_extra = {
|
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||||
"ui": {"title": "Scale Image", "tags": ["image", "scale"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -471,22 +434,18 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Lerp Image")
|
||||||
|
@tags("image", "lerp")
|
||||||
|
class ImageLerpInvocation(BaseInvocation):
|
||||||
"""Linear interpolation of all pixels of an image"""
|
"""Linear interpolation of all pixels of an image"""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["img_lerp"] = "img_lerp"
|
type: Literal["img_lerp"] = "img_lerp"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to lerp")
|
image: ImageField = InputField(description="The image to lerp")
|
||||||
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
|
min: int = InputField(default=0, ge=0, le=255, description="The minimum output value")
|
||||||
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
|
max: int = InputField(default=255, ge=0, le=255, description="The maximum output value")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Image Linear Interpolation", "tags": ["image", "linear", "interpolation", "lerp"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -512,25 +471,18 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Inverse Lerp Image")
|
||||||
|
@tags("image", "ilerp")
|
||||||
|
class ImageInverseLerpInvocation(BaseInvocation):
|
||||||
"""Inverse linear interpolation of all pixels of an image"""
|
"""Inverse linear interpolation of all pixels of an image"""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["img_ilerp"] = "img_ilerp"
|
type: Literal["img_ilerp"] = "img_ilerp"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to lerp")
|
image: ImageField = InputField(description="The image to lerp")
|
||||||
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
|
min: int = InputField(default=0, ge=0, le=255, description="The minimum input value")
|
||||||
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
|
max: int = InputField(default=255, ge=0, le=255, description="The maximum input value")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "Image Inverse Linear Interpolation",
|
|
||||||
"tags": ["image", "linear", "interpolation", "inverse"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -556,21 +508,19 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Blur NSFW Image")
|
||||||
|
@tags("image", "nsfw")
|
||||||
|
class ImageNSFWBlurInvocation(BaseInvocation):
|
||||||
"""Add blur to NSFW-flagged images"""
|
"""Add blur to NSFW-flagged images"""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["img_nsfw"] = "img_nsfw"
|
type: Literal["img_nsfw"] = "img_nsfw"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to check")
|
image: ImageField = InputField(description="The image to check")
|
||||||
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
|
metadata: Optional[CoreMetadata] = InputField(
|
||||||
# fmt: on
|
default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
|
||||||
|
)
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Blur NSFW Images", "tags": ["image", "nsfw", "checker"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -607,22 +557,20 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
return caution.resize((caution.width // 2, caution.height // 2))
|
return caution.resize((caution.width // 2, caution.height // 2))
|
||||||
|
|
||||||
|
|
||||||
class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Add Invisible Watermark")
|
||||||
|
@tags("image", "watermark")
|
||||||
|
class ImageWatermarkInvocation(BaseInvocation):
|
||||||
"""Add an invisible watermark to an image"""
|
"""Add an invisible watermark to an image"""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["img_watermark"] = "img_watermark"
|
type: Literal["img_watermark"] = "img_watermark"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to check")
|
image: ImageField = InputField(description="The image to check")
|
||||||
text: str = Field(default='InvokeAI', description="Watermark text")
|
text: str = InputField(default="InvokeAI", description="Watermark text")
|
||||||
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
|
metadata: Optional[CoreMetadata] = InputField(
|
||||||
# fmt: on
|
default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
|
||||||
|
)
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Add Invisible Watermark", "tags": ["image", "watermark", "invisible"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -644,19 +592,21 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MaskEdgeInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Mask Edge")
|
||||||
|
@tags("image", "mask", "inpaint")
|
||||||
|
class MaskEdgeInvocation(BaseInvocation):
|
||||||
"""Applies an edge mask to an image"""
|
"""Applies an edge mask to an image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["mask_edge"] = "mask_edge"
|
type: Literal["mask_edge"] = "mask_edge"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to apply the mask to")
|
image: ImageField = InputField(description="The image to apply the mask to")
|
||||||
edge_size: int = Field(description="The size of the edge")
|
edge_size: int = InputField(description="The size of the edge")
|
||||||
edge_blur: int = Field(description="The amount of blur on the edge")
|
edge_blur: int = InputField(description="The amount of blur on the edge")
|
||||||
low_threshold: int = Field(description="First threshold for the hysteresis procedure in Canny edge detection")
|
low_threshold: int = InputField(description="First threshold for the hysteresis procedure in Canny edge detection")
|
||||||
high_threshold: int = Field(description="Second threshold for the hysteresis procedure in Canny edge detection")
|
high_threshold: int = InputField(
|
||||||
# fmt: on
|
description="Second threshold for the hysteresis procedure in Canny edge detection"
|
||||||
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||||
mask = context.services.images.get_pil_image(self.image.image_name)
|
mask = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -690,21 +640,16 @@ class MaskEdgeInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MaskCombineInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Combine Mask")
|
||||||
|
@tags("image", "mask", "multiply")
|
||||||
|
class MaskCombineInvocation(BaseInvocation):
|
||||||
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
|
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["mask_combine"] = "mask_combine"
|
type: Literal["mask_combine"] = "mask_combine"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
mask1: ImageField = Field(default=None, description="The first mask to combine")
|
mask1: ImageField = InputField(description="The first mask to combine")
|
||||||
mask2: ImageField = Field(default=None, description="The second image to combine")
|
mask2: ImageField = InputField(description="The second image to combine")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Mask Combine", "tags": ["mask", "combine"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
mask1 = context.services.images.get_pil_image(self.mask1.image_name).convert("L")
|
mask1 = context.services.images.get_pil_image(self.mask1.image_name).convert("L")
|
||||||
@ -728,7 +673,9 @@ class MaskCombineInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Color Correct")
|
||||||
|
@tags("image", "color")
|
||||||
|
class ColorCorrectInvocation(BaseInvocation):
|
||||||
"""
|
"""
|
||||||
Shifts the colors of a target image to match the reference image, optionally
|
Shifts the colors of a target image to match the reference image, optionally
|
||||||
using a mask to only color-correct certain regions of the target image.
|
using a mask to only color-correct certain regions of the target image.
|
||||||
@ -736,10 +683,11 @@ class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
type: Literal["color_correct"] = "color_correct"
|
type: Literal["color_correct"] = "color_correct"
|
||||||
|
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to color-correct")
|
# Inputs
|
||||||
reference: Optional[ImageField] = Field(default=None, description="Reference image for color-correction")
|
image: ImageField = InputField(description="The image to color-correct")
|
||||||
mask: Optional[ImageField] = Field(default=None, description="Mask to use when applying color-correction")
|
reference: ImageField = InputField(description="Reference image for color-correction")
|
||||||
mask_blur_radius: float = Field(default=8, description="Mask blur radius")
|
mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction")
|
||||||
|
mask_blur_radius: float = InputField(default=8, description="Mask blur radius")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
pil_init_mask = None
|
pil_init_mask = None
|
||||||
@ -833,16 +781,16 @@ class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Image Hue Adjustment")
|
||||||
|
@tags("image", "hue", "hsl")
|
||||||
class ImageHueAdjustmentInvocation(BaseInvocation):
|
class ImageHueAdjustmentInvocation(BaseInvocation):
|
||||||
"""Adjusts the Hue of an image."""
|
"""Adjusts the Hue of an image."""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["img_hue_adjust"] = "img_hue_adjust"
|
type: Literal["img_hue_adjust"] = "img_hue_adjust"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The image to adjust")
|
image: ImageField = InputField(description="The image to adjust")
|
||||||
hue: int = Field(default=0, description="The degrees by which to rotate the hue, 0-360")
|
hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -877,16 +825,18 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Image Luminosity Adjustment")
|
||||||
|
@tags("image", "luminosity", "hsl")
|
||||||
class ImageLuminosityAdjustmentInvocation(BaseInvocation):
|
class ImageLuminosityAdjustmentInvocation(BaseInvocation):
|
||||||
"""Adjusts the Luminosity (Value) of an image."""
|
"""Adjusts the Luminosity (Value) of an image."""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["img_luminosity_adjust"] = "img_luminosity_adjust"
|
type: Literal["img_luminosity_adjust"] = "img_luminosity_adjust"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The image to adjust")
|
image: ImageField = InputField(description="The image to adjust")
|
||||||
luminosity: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)")
|
luminosity: float = InputField(
|
||||||
# fmt: on
|
default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)"
|
||||||
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -925,16 +875,16 @@ class ImageLuminosityAdjustmentInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Image Saturation Adjustment")
|
||||||
|
@tags("image", "saturation", "hsl")
|
||||||
class ImageSaturationAdjustmentInvocation(BaseInvocation):
|
class ImageSaturationAdjustmentInvocation(BaseInvocation):
|
||||||
"""Adjusts the Saturation of an image."""
|
"""Adjusts the Saturation of an image."""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["img_saturation_adjust"] = "img_saturation_adjust"
|
type: Literal["img_saturation_adjust"] = "img_saturation_adjust"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The image to adjust")
|
image: ImageField = InputField(description="The image to adjust")
|
||||||
saturation: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
|
saturation: float = InputField(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
@ -5,18 +5,13 @@ from typing import Literal, Optional, get_args
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
import math
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from invokeai.app.invocations.image import ImageOutput
|
from invokeai.app.invocations.image import ImageOutput
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||||
|
|
||||||
from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
|
from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
|
||||||
from .baseinvocation import (
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UITypeHint, title, tags
|
||||||
BaseInvocation,
|
|
||||||
InvocationConfig,
|
|
||||||
InvocationContext,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def infill_methods() -> list[str]:
|
def infill_methods() -> list[str]:
|
||||||
@ -114,21 +109,20 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
|
|||||||
return si
|
return si
|
||||||
|
|
||||||
|
|
||||||
|
@title("Solid Color Infill")
|
||||||
|
@tags("image", "inpaint")
|
||||||
class InfillColorInvocation(BaseInvocation):
|
class InfillColorInvocation(BaseInvocation):
|
||||||
"""Infills transparent areas of an image with a solid color"""
|
"""Infills transparent areas of an image with a solid color"""
|
||||||
|
|
||||||
type: Literal["infill_rgba"] = "infill_rgba"
|
type: Literal["infill_rgba"] = "infill_rgba"
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
|
||||||
color: ColorField = Field(
|
# Inputs
|
||||||
|
image: ImageField = InputField(description="The image to infill")
|
||||||
|
color: ColorField = InputField(
|
||||||
default=ColorField(r=127, g=127, b=127, a=255),
|
default=ColorField(r=127, g=127, b=127, a=255),
|
||||||
description="The color to use to infill",
|
description="The color to use to infill",
|
||||||
)
|
)
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Color Infill", "tags": ["image", "inpaint", "color", "infill"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
@ -153,25 +147,23 @@ class InfillColorInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Tile Infill")
|
||||||
|
@tags("image", "inpaint")
|
||||||
class InfillTileInvocation(BaseInvocation):
|
class InfillTileInvocation(BaseInvocation):
|
||||||
"""Infills transparent areas of an image with tiles of the image"""
|
"""Infills transparent areas of an image with tiles of the image"""
|
||||||
|
|
||||||
type: Literal["infill_tile"] = "infill_tile"
|
type: Literal["infill_tile"] = "infill_tile"
|
||||||
|
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
# Input
|
||||||
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
|
image: ImageField = InputField(description="The image to infill")
|
||||||
seed: int = Field(
|
tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
|
||||||
|
seed: int = InputField(
|
||||||
ge=0,
|
ge=0,
|
||||||
le=SEED_MAX,
|
le=SEED_MAX,
|
||||||
description="The seed to use for tile generation (omit for random)",
|
description="The seed to use for tile generation (omit for random)",
|
||||||
default_factory=get_random_seed,
|
default_factory=get_random_seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Tile Infill", "tags": ["image", "inpaint", "tile", "infill"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
@ -194,17 +186,15 @@ class InfillTileInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("PatchMatch Infill")
|
||||||
|
@tags("image", "inpaint")
|
||||||
class InfillPatchMatchInvocation(BaseInvocation):
|
class InfillPatchMatchInvocation(BaseInvocation):
|
||||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||||
|
|
||||||
type: Literal["infill_patchmatch"] = "infill_patchmatch"
|
type: Literal["infill_patchmatch"] = "infill_patchmatch"
|
||||||
|
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
# Inputs
|
||||||
|
image: ImageField = InputField(description="The image to infill")
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Patch Match Infill", "tags": ["image", "inpaint", "patchmatch", "infill"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
@ -13,7 +13,8 @@ from diffusers.models.attention_processor import (
|
|||||||
LoRAXFormersAttnProcessor,
|
LoRAXFormersAttnProcessor,
|
||||||
XFormersAttnProcessor,
|
XFormersAttnProcessor,
|
||||||
)
|
)
|
||||||
from diffusers.schedulers import DPMSolverSDEScheduler, SchedulerMixin as Scheduler
|
from diffusers.schedulers import DPMSolverSDEScheduler
|
||||||
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
from pydantic import BaseModel, Field, validator
|
from pydantic import BaseModel, Field, validator
|
||||||
from torchvision.transforms.functional import resize as tv_resize
|
from torchvision.transforms.functional import resize as tv_resize
|
||||||
|
|
||||||
@ -23,6 +24,7 @@ from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
|||||||
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||||
|
|
||||||
from ...backend.model_management import BaseModelType, ModelPatcher
|
from ...backend.model_management import BaseModelType, ModelPatcher
|
||||||
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||||
ConditioningData,
|
ConditioningData,
|
||||||
@ -32,9 +34,20 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
|
|||||||
)
|
)
|
||||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
from ...backend.util.devices import choose_precision, choose_torch_device, torch_dtype
|
from ...backend.util.devices import choose_precision, choose_torch_device
|
||||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
from .baseinvocation import (
|
||||||
|
BaseInvocation,
|
||||||
|
BaseInvocationOutput,
|
||||||
|
FieldDescriptions,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
|
UITypeHint,
|
||||||
|
tags,
|
||||||
|
title,
|
||||||
|
)
|
||||||
from .compel import ConditioningField
|
from .compel import ConditioningField
|
||||||
from .controlnet_image_processors import ControlField
|
from .controlnet_image_processors import ControlField
|
||||||
from .image import ImageOutput
|
from .image import ImageOutput
|
||||||
@ -46,8 +59,8 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
|||||||
class LatentsField(BaseModel):
|
class LatentsField(BaseModel):
|
||||||
"""A latents field used for passing latents between invocations"""
|
"""A latents field used for passing latents between invocations"""
|
||||||
|
|
||||||
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
|
latents_name: str = Field(description="The name of the latents")
|
||||||
seed: Optional[int] = Field(description="Seed used to generate this latents")
|
seed: Optional[int] = Field(default=None, description="Seed used to generate this latents")
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {"required": ["latents_name"]}
|
schema_extra = {"required": ["latents_name"]}
|
||||||
@ -56,14 +69,14 @@ class LatentsField(BaseModel):
|
|||||||
class LatentsOutput(BaseInvocationOutput):
|
class LatentsOutput(BaseInvocationOutput):
|
||||||
"""Base class for invocations that output latents"""
|
"""Base class for invocations that output latents"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["latents_output"] = "latents_output"
|
type: Literal["latents_output"] = "latents_output"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: LatentsField = Field(default=None, description="The output latents")
|
latents: LatentsField = OutputField(
|
||||||
width: int = Field(description="The width of the latents in pixels")
|
description=FieldDescriptions.latents,
|
||||||
height: int = Field(description="The height of the latents in pixels")
|
)
|
||||||
# fmt: on
|
width: int = OutputField(description=FieldDescriptions.width)
|
||||||
|
height: int = OutputField(description=FieldDescriptions.height)
|
||||||
|
|
||||||
|
|
||||||
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int]):
|
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int]):
|
||||||
@ -111,30 +124,36 @@ def get_scheduler(
|
|||||||
return scheduler
|
return scheduler
|
||||||
|
|
||||||
|
|
||||||
|
@title("Denoise Latents")
|
||||||
|
@tags("latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l")
|
||||||
class DenoiseLatentsInvocation(BaseInvocation):
|
class DenoiseLatentsInvocation(BaseInvocation):
|
||||||
"""Denoises noisy latents to decodable images"""
|
"""Denoises noisy latents to decodable images"""
|
||||||
|
|
||||||
type: Literal["denoise_latents"] = "denoise_latents"
|
type: Literal["denoise_latents"] = "denoise_latents"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
positive_conditioning: ConditioningField = InputField(
|
||||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
|
||||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
|
||||||
cfg_scale: Union[float, List[float]] = Field(
|
|
||||||
default=7.5,
|
|
||||||
ge=1,
|
|
||||||
description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt",
|
|
||||||
)
|
)
|
||||||
denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
|
negative_conditioning: ConditioningField = InputField(
|
||||||
denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
|
description=FieldDescriptions.negative_cond, input=Input.Connection
|
||||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use")
|
)
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection)
|
||||||
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
cfg_scale: Union[float, List[float]] = InputField(
|
||||||
mask: Optional[ImageField] = Field(
|
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type_hint=UITypeHint.Float
|
||||||
None,
|
)
|
||||||
description="Mask",
|
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
|
||||||
|
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||||
|
scheduler: SAMPLER_NAME_VALUES = InputField(default="euler", description=FieldDescriptions.scheduler)
|
||||||
|
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection)
|
||||||
|
control: Union[ControlField, list[ControlField]] = InputField(
|
||||||
|
default=None, description=FieldDescriptions.control, input=Input.Connection
|
||||||
|
)
|
||||||
|
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
||||||
|
mask: Optional[ImageField] = InputField(
|
||||||
|
default=None,
|
||||||
|
description=FieldDescriptions.mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
@validator("cfg_scale")
|
@validator("cfg_scale")
|
||||||
@ -149,20 +168,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
raise ValueError("cfg_scale must be greater than 1")
|
raise ValueError("cfg_scale must be greater than 1")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "Denoise Latents",
|
|
||||||
"tags": ["denoise", "latents"],
|
|
||||||
"type_hints": {
|
|
||||||
"model": "model",
|
|
||||||
"control": "control",
|
|
||||||
"cfg_scale": "number",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||||
def dispatch_progress(
|
def dispatch_progress(
|
||||||
self,
|
self,
|
||||||
@ -474,29 +479,29 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
|
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
|
||||||
|
|
||||||
|
|
||||||
# Latent to image
|
@title("Latents to Image")
|
||||||
|
@tags("latents", "image", "vae")
|
||||||
class LatentsToImageInvocation(BaseInvocation):
|
class LatentsToImageInvocation(BaseInvocation):
|
||||||
"""Generates an image from latents."""
|
"""Generates an image from latents."""
|
||||||
|
|
||||||
type: Literal["l2i"] = "l2i"
|
type: Literal["l2i"] = "l2i"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
latents: LatentsField = InputField(
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
description=FieldDescriptions.latents,
|
||||||
tiled: bool = Field(default=False, description="Decode latents by overlaping tiles (less memory consumption)")
|
input=Input.Connection,
|
||||||
fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
|
)
|
||||||
metadata: Optional[CoreMetadata] = Field(
|
vae: VaeField = InputField(
|
||||||
default=None, description="Optional core metadata to be written to the image"
|
description=FieldDescriptions.vae,
|
||||||
|
input=Input.Connection,
|
||||||
|
)
|
||||||
|
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
||||||
|
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
|
||||||
|
metadata: CoreMetadata = InputField(
|
||||||
|
default=None,
|
||||||
|
description=FieldDescriptions.core_metadata,
|
||||||
|
ui_hidden=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "Latents To Image",
|
|
||||||
"tags": ["latents", "image"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
@ -574,24 +579,30 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
||||||
|
|
||||||
|
|
||||||
|
@title("Resize Latents")
|
||||||
|
@tags("latents", "resize")
|
||||||
class ResizeLatentsInvocation(BaseInvocation):
|
class ResizeLatentsInvocation(BaseInvocation):
|
||||||
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
||||||
|
|
||||||
type: Literal["lresize"] = "lresize"
|
type: Literal["lresize"] = "lresize"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to resize")
|
latents: LatentsField = InputField(
|
||||||
width: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
|
description=FieldDescriptions.latents,
|
||||||
height: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
|
input=Input.Connection,
|
||||||
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
|
||||||
antialias: bool = Field(
|
|
||||||
default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
|
|
||||||
)
|
)
|
||||||
|
width: int = InputField(
|
||||||
class Config(InvocationConfig):
|
ge=64,
|
||||||
schema_extra = {
|
multiple_of=8,
|
||||||
"ui": {"title": "Resize Latents", "tags": ["latents", "resize"]},
|
description=FieldDescriptions.width,
|
||||||
}
|
)
|
||||||
|
height: int = InputField(
|
||||||
|
ge=64,
|
||||||
|
multiple_of=8,
|
||||||
|
description=FieldDescriptions.width,
|
||||||
|
)
|
||||||
|
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
|
||||||
|
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
@ -616,23 +627,21 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Scale Latents")
|
||||||
|
@tags("latents", "resize")
|
||||||
class ScaleLatentsInvocation(BaseInvocation):
|
class ScaleLatentsInvocation(BaseInvocation):
|
||||||
"""Scales latents by a given factor."""
|
"""Scales latents by a given factor."""
|
||||||
|
|
||||||
type: Literal["lscale"] = "lscale"
|
type: Literal["lscale"] = "lscale"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to scale")
|
latents: LatentsField = InputField(
|
||||||
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
|
description=FieldDescriptions.latents,
|
||||||
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
input=Input.Connection,
|
||||||
antialias: bool = Field(
|
|
||||||
default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
|
|
||||||
)
|
)
|
||||||
|
scale_factor: float = InputField(gt=0, description=FieldDescriptions.scale_factor)
|
||||||
class Config(InvocationConfig):
|
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
|
||||||
schema_extra = {
|
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||||
"ui": {"title": "Scale Latents", "tags": ["latents", "scale"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
@ -658,22 +667,23 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Image to Latents")
|
||||||
|
@tags("latents", "image", "vae")
|
||||||
class ImageToLatentsInvocation(BaseInvocation):
|
class ImageToLatentsInvocation(BaseInvocation):
|
||||||
"""Encodes an image into latents."""
|
"""Encodes an image into latents."""
|
||||||
|
|
||||||
type: Literal["i2l"] = "i2l"
|
type: Literal["i2l"] = "i2l"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(description="The image to encode")
|
image: ImageField = InputField(
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
description="The image to encode",
|
||||||
tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)")
|
)
|
||||||
fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
|
vae: VaeField = InputField(
|
||||||
|
description=FieldDescriptions.vae,
|
||||||
# Schema customisation
|
input=Input.Connection,
|
||||||
class Config(InvocationConfig):
|
)
|
||||||
schema_extra = {
|
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
||||||
"ui": {"title": "Image To Latents", "tags": ["latents", "image"]},
|
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
|
||||||
}
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
|
@ -2,134 +2,104 @@
|
|||||||
|
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
|
FieldDescriptions,
|
||||||
|
InputField,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
InvocationConfig,
|
OutputField,
|
||||||
|
tags,
|
||||||
|
title,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MathInvocationConfig(BaseModel):
|
|
||||||
"""Helper class to provide all math invocations with additional config"""
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["math"],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class IntOutput(BaseInvocationOutput):
|
class IntOutput(BaseInvocationOutput):
|
||||||
"""An integer output"""
|
"""An integer output"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["int_output"] = "int_output"
|
type: Literal["int_output"] = "int_output"
|
||||||
a: int = Field(default=None, description="The output integer")
|
a: int = OutputField(default=None, description="The output integer")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
class FloatOutput(BaseInvocationOutput):
|
class FloatOutput(BaseInvocationOutput):
|
||||||
"""A float output"""
|
"""A float output"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["float_output"] = "float_output"
|
type: Literal["float_output"] = "float_output"
|
||||||
param: float = Field(default=None, description="The output float")
|
a: float = OutputField(default=None, description="The output float")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
class AddInvocation(BaseInvocation, MathInvocationConfig):
|
@title("Add Integers")
|
||||||
|
@tags("math")
|
||||||
|
class AddInvocation(BaseInvocation):
|
||||||
"""Adds two numbers"""
|
"""Adds two numbers"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["add"] = "add"
|
type: Literal["add"] = "add"
|
||||||
a: int = Field(default=0, description="The first number")
|
|
||||||
b: int = Field(default=0, description="The second number")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||||
"ui": {"title": "Add", "tags": ["math", "add"]},
|
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||||
return IntOutput(a=self.a + self.b)
|
return IntOutput(a=self.a + self.b)
|
||||||
|
|
||||||
|
|
||||||
class SubtractInvocation(BaseInvocation, MathInvocationConfig):
|
@title("Subtract Integers")
|
||||||
|
@tags("math")
|
||||||
|
class SubtractInvocation(BaseInvocation):
|
||||||
"""Subtracts two numbers"""
|
"""Subtracts two numbers"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["sub"] = "sub"
|
type: Literal["sub"] = "sub"
|
||||||
a: int = Field(default=0, description="The first number")
|
|
||||||
b: int = Field(default=0, description="The second number")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||||
"ui": {"title": "Subtract", "tags": ["math", "subtract"]},
|
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||||
return IntOutput(a=self.a - self.b)
|
return IntOutput(a=self.a - self.b)
|
||||||
|
|
||||||
|
|
||||||
class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
|
@title("Multiply Integers")
|
||||||
|
@tags("math")
|
||||||
|
class MultiplyInvocation(BaseInvocation):
|
||||||
"""Multiplies two numbers"""
|
"""Multiplies two numbers"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["mul"] = "mul"
|
type: Literal["mul"] = "mul"
|
||||||
a: int = Field(default=0, description="The first number")
|
|
||||||
b: int = Field(default=0, description="The second number")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||||
"ui": {"title": "Multiply", "tags": ["math", "multiply"]},
|
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||||
return IntOutput(a=self.a * self.b)
|
return IntOutput(a=self.a * self.b)
|
||||||
|
|
||||||
|
|
||||||
class DivideInvocation(BaseInvocation, MathInvocationConfig):
|
@title("Divide Integers")
|
||||||
|
@tags("math")
|
||||||
|
class DivideInvocation(BaseInvocation):
|
||||||
"""Divides two numbers"""
|
"""Divides two numbers"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["div"] = "div"
|
type: Literal["div"] = "div"
|
||||||
a: int = Field(default=0, description="The first number")
|
|
||||||
b: int = Field(default=0, description="The second number")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||||
"ui": {"title": "Divide", "tags": ["math", "divide"]},
|
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||||
return IntOutput(a=int(self.a / self.b))
|
return IntOutput(a=int(self.a / self.b))
|
||||||
|
|
||||||
|
|
||||||
|
@title("Random Integer")
|
||||||
|
@tags("math")
|
||||||
class RandomIntInvocation(BaseInvocation):
|
class RandomIntInvocation(BaseInvocation):
|
||||||
"""Outputs a single random integer."""
|
"""Outputs a single random integer."""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["rand_int"] = "rand_int"
|
type: Literal["rand_int"] = "rand_int"
|
||||||
low: int = Field(default=0, description="The inclusive low value")
|
|
||||||
high: int = Field(
|
|
||||||
default=np.iinfo(np.int32).max, description="The exclusive high value"
|
|
||||||
)
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
low: int = InputField(default=0, description="The inclusive low value")
|
||||||
"ui": {"title": "Random Integer", "tags": ["math", "random", "integer"]},
|
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||||
return IntOutput(a=np.random.randint(self.low, self.high))
|
return IntOutput(a=np.random.randint(self.low, self.high))
|
||||||
|
@ -1,18 +1,21 @@
|
|||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from ...version import __version__
|
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
InvocationConfig,
|
InputField,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
|
tags,
|
||||||
|
title,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||||
|
|
||||||
|
from ...version import __version__
|
||||||
|
|
||||||
|
|
||||||
class LoRAMetadataField(BaseModelExcludeNull):
|
class LoRAMetadataField(BaseModelExcludeNull):
|
||||||
"""LoRA metadata for an image generated in InvokeAI."""
|
"""LoRA metadata for an image generated in InvokeAI."""
|
||||||
@ -43,37 +46,37 @@ class CoreMetadata(BaseModelExcludeNull):
|
|||||||
model: MainModelField = Field(description="The main model used for inference")
|
model: MainModelField = Field(description="The main model used for inference")
|
||||||
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
|
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
|
||||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||||
vae: Union[VAEModelField, None] = Field(
|
vae: Optional[VAEModelField] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The VAE used for decoding, if the main model's default was not used",
|
description="The VAE used for decoding, if the main model's default was not used",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Latents-to-Latents
|
# Latents-to-Latents
|
||||||
strength: Union[float, None] = Field(
|
strength: Optional[float] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The strength used for latents-to-latents",
|
description="The strength used for latents-to-latents",
|
||||||
)
|
)
|
||||||
init_image: Union[str, None] = Field(default=None, description="The name of the initial image")
|
init_image: Optional[str] = Field(default=None, description="The name of the initial image")
|
||||||
|
|
||||||
# SDXL
|
# SDXL
|
||||||
positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter")
|
positive_style_prompt: Optional[str] = Field(default=None, description="The positive style prompt parameter")
|
||||||
negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter")
|
negative_style_prompt: Optional[str] = Field(default=None, description="The negative style prompt parameter")
|
||||||
|
|
||||||
# SDXL Refiner
|
# SDXL Refiner
|
||||||
refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used")
|
refiner_model: Optional[MainModelField] = Field(default=None, description="The SDXL Refiner model used")
|
||||||
refiner_cfg_scale: Union[float, None] = Field(
|
refiner_cfg_scale: Optional[float] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The classifier-free guidance scale parameter used for the refiner",
|
description="The classifier-free guidance scale parameter used for the refiner",
|
||||||
)
|
)
|
||||||
refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner")
|
refiner_steps: Optional[int] = Field(default=None, description="The number of steps used for the refiner")
|
||||||
refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner")
|
refiner_scheduler: Optional[str] = Field(default=None, description="The scheduler used for the refiner")
|
||||||
refiner_positive_aesthetic_store: Union[float, None] = Field(
|
refiner_positive_aesthetic_store: Optional[float] = Field(
|
||||||
default=None, description="The aesthetic score used for the refiner"
|
default=None, description="The aesthetic score used for the refiner"
|
||||||
)
|
)
|
||||||
refiner_negative_aesthetic_store: Union[float, None] = Field(
|
refiner_negative_aesthetic_store: Optional[float] = Field(
|
||||||
default=None, description="The aesthetic score used for the refiner"
|
default=None, description="The aesthetic score used for the refiner"
|
||||||
)
|
)
|
||||||
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
|
refiner_start: Optional[float] = Field(default=None, description="The start value used for refiner denoising")
|
||||||
|
|
||||||
|
|
||||||
class ImageMetadata(BaseModelExcludeNull):
|
class ImageMetadata(BaseModelExcludeNull):
|
||||||
@ -94,66 +97,83 @@ class MetadataAccumulatorOutput(BaseInvocationOutput):
|
|||||||
metadata: CoreMetadata = Field(description="The core metadata for the image")
|
metadata: CoreMetadata = Field(description="The core metadata for the image")
|
||||||
|
|
||||||
|
|
||||||
|
@title("Metadata Accumulator")
|
||||||
|
@tags("metadata")
|
||||||
class MetadataAccumulatorInvocation(BaseInvocation):
|
class MetadataAccumulatorInvocation(BaseInvocation):
|
||||||
"""Outputs a Core Metadata Object"""
|
"""Outputs a Core Metadata Object"""
|
||||||
|
|
||||||
type: Literal["metadata_accumulator"] = "metadata_accumulator"
|
type: Literal["metadata_accumulator"] = "metadata_accumulator"
|
||||||
|
|
||||||
generation_mode: str = Field(
|
generation_mode: str = InputField(
|
||||||
description="The generation mode that output this image",
|
description="The generation mode that output this image",
|
||||||
)
|
)
|
||||||
positive_prompt: str = Field(description="The positive prompt parameter")
|
positive_prompt: str = InputField(description="The positive prompt parameter")
|
||||||
negative_prompt: str = Field(description="The negative prompt parameter")
|
negative_prompt: str = InputField(description="The negative prompt parameter")
|
||||||
width: int = Field(description="The width parameter")
|
width: int = InputField(description="The width parameter")
|
||||||
height: int = Field(description="The height parameter")
|
height: int = InputField(description="The height parameter")
|
||||||
seed: int = Field(description="The seed used for noise generation")
|
seed: int = InputField(description="The seed used for noise generation")
|
||||||
rand_device: str = Field(description="The device used for random number generation")
|
rand_device: str = InputField(description="The device used for random number generation")
|
||||||
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
|
cfg_scale: float = InputField(description="The classifier-free guidance scale parameter")
|
||||||
steps: int = Field(description="The number of steps used for inference")
|
steps: int = InputField(description="The number of steps used for inference")
|
||||||
scheduler: str = Field(description="The scheduler used for inference")
|
scheduler: str = InputField(description="The scheduler used for inference")
|
||||||
clip_skip: int = Field(
|
clip_skip: int = InputField(
|
||||||
description="The number of skipped CLIP layers",
|
description="The number of skipped CLIP layers",
|
||||||
)
|
)
|
||||||
model: MainModelField = Field(description="The main model used for inference")
|
model: MainModelField = InputField(description="The main model used for inference")
|
||||||
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
|
controlnets: list[ControlField] = InputField(description="The ControlNets used for inference")
|
||||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
loras: list[LoRAMetadataField] = InputField(description="The LoRAs used for inference")
|
||||||
strength: Union[float, None] = Field(
|
strength: Optional[float] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description="The strength used for latents-to-latents",
|
description="The strength used for latents-to-latents",
|
||||||
)
|
)
|
||||||
init_image: Union[str, None] = Field(default=None, description="The name of the initial image")
|
init_image: Optional[str] = InputField(
|
||||||
vae: Union[VAEModelField, None] = Field(
|
default=None,
|
||||||
|
description="The name of the initial image",
|
||||||
|
)
|
||||||
|
vae: Optional[VAEModelField] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description="The VAE used for decoding, if the main model's default was not used",
|
description="The VAE used for decoding, if the main model's default was not used",
|
||||||
)
|
)
|
||||||
|
|
||||||
# SDXL
|
# SDXL
|
||||||
positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter")
|
positive_style_prompt: Optional[str] = InputField(
|
||||||
negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter")
|
default=None,
|
||||||
|
description="The positive style prompt parameter",
|
||||||
|
)
|
||||||
|
negative_style_prompt: Optional[str] = InputField(
|
||||||
|
default=None,
|
||||||
|
description="The negative style prompt parameter",
|
||||||
|
)
|
||||||
|
|
||||||
# SDXL Refiner
|
# SDXL Refiner
|
||||||
refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used")
|
refiner_model: Optional[MainModelField] = InputField(
|
||||||
refiner_cfg_scale: Union[float, None] = Field(
|
default=None,
|
||||||
|
description="The SDXL Refiner model used",
|
||||||
|
)
|
||||||
|
refiner_cfg_scale: Optional[float] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description="The classifier-free guidance scale parameter used for the refiner",
|
description="The classifier-free guidance scale parameter used for the refiner",
|
||||||
)
|
)
|
||||||
refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner")
|
refiner_steps: Optional[int] = InputField(
|
||||||
refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner")
|
default=None,
|
||||||
refiner_positive_aesthetic_score: Union[float, None] = Field(
|
description="The number of steps used for the refiner",
|
||||||
default=None, description="The aesthetic score used for the refiner"
|
|
||||||
)
|
)
|
||||||
refiner_negative_aesthetic_score: Union[float, None] = Field(
|
refiner_scheduler: Optional[str] = InputField(
|
||||||
default=None, description="The aesthetic score used for the refiner"
|
default=None,
|
||||||
|
description="The scheduler used for the refiner",
|
||||||
|
)
|
||||||
|
refiner_positive_aesthetic_store: Optional[float] = InputField(
|
||||||
|
default=None,
|
||||||
|
description="The aesthetic score used for the refiner",
|
||||||
|
)
|
||||||
|
refiner_negative_aesthetic_store: Optional[float] = InputField(
|
||||||
|
default=None,
|
||||||
|
description="The aesthetic score used for the refiner",
|
||||||
|
)
|
||||||
|
refiner_start: Optional[float] = InputField(
|
||||||
|
default=None,
|
||||||
|
description="The start value used for refiner denoising",
|
||||||
)
|
)
|
||||||
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "Metadata Accumulator",
|
|
||||||
"tags": ["image", "metadata", "generation"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
|
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
|
||||||
"""Collects and outputs a CoreMetadata object"""
|
"""Collects and outputs a CoreMetadata object"""
|
||||||
|
@ -4,7 +4,18 @@ from typing import List, Literal, Optional, Union
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
from .baseinvocation import (
|
||||||
|
BaseInvocation,
|
||||||
|
BaseInvocationOutput,
|
||||||
|
FieldDescriptions,
|
||||||
|
InputField,
|
||||||
|
Input,
|
||||||
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
|
UITypeHint,
|
||||||
|
tags,
|
||||||
|
title,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ModelInfo(BaseModel):
|
class ModelInfo(BaseModel):
|
||||||
@ -39,13 +50,11 @@ class VaeField(BaseModel):
|
|||||||
class ModelLoaderOutput(BaseInvocationOutput):
|
class ModelLoaderOutput(BaseInvocationOutput):
|
||||||
"""Model loader output"""
|
"""Model loader output"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["model_loader_output"] = "model_loader_output"
|
type: Literal["model_loader_output"] = "model_loader_output"
|
||||||
|
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
class MainModelField(BaseModel):
|
class MainModelField(BaseModel):
|
||||||
@ -63,24 +72,17 @@ class LoRAModelField(BaseModel):
|
|||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
|
||||||
|
@title("Main Model Loader")
|
||||||
|
@tags("model")
|
||||||
class MainModelLoaderInvocation(BaseInvocation):
|
class MainModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a main model, outputting its submodels."""
|
"""Loads a main model, outputting its submodels."""
|
||||||
|
|
||||||
type: Literal["main_model_loader"] = "main_model_loader"
|
type: Literal["main_model_loader"] = "main_model_loader"
|
||||||
|
|
||||||
model: MainModelField = Field(description="The model to load")
|
# Inputs
|
||||||
|
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
|
||||||
# TODO: precision?
|
# TODO: precision?
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "Model Loader",
|
|
||||||
"tags": ["model", "loader"],
|
|
||||||
"type_hints": {"model": "model"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||||
base_model = self.model.base_model
|
base_model = self.model.base_model
|
||||||
model_name = self.model.model_name
|
model_name = self.model.model_name
|
||||||
@ -155,22 +157,6 @@ class MainModelLoaderInvocation(BaseInvocation):
|
|||||||
loras=[],
|
loras=[],
|
||||||
skipped_layers=0,
|
skipped_layers=0,
|
||||||
),
|
),
|
||||||
clip2=ClipField(
|
|
||||||
tokenizer=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel=SubModelType.Tokenizer2,
|
|
||||||
),
|
|
||||||
text_encoder=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel=SubModelType.TextEncoder2,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
skipped_layers=0,
|
|
||||||
),
|
|
||||||
vae=VaeField(
|
vae=VaeField(
|
||||||
vae=ModelInfo(
|
vae=ModelInfo(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
@ -188,30 +174,27 @@ class LoraLoaderOutput(BaseInvocationOutput):
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["lora_loader_output"] = "lora_loader_output"
|
type: Literal["lora_loader_output"] = "lora_loader_output"
|
||||||
|
|
||||||
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
|
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||||
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
|
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
@title("LoRA Loader")
|
||||||
|
@tags("lora", "model")
|
||||||
class LoraLoaderInvocation(BaseInvocation):
|
class LoraLoaderInvocation(BaseInvocation):
|
||||||
"""Apply selected lora to unet and text_encoder."""
|
"""Apply selected lora to unet and text_encoder."""
|
||||||
|
|
||||||
type: Literal["lora_loader"] = "lora_loader"
|
type: Literal["lora_loader"] = "lora_loader"
|
||||||
|
|
||||||
lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
|
# Inputs
|
||||||
weight: float = Field(default=0.75, description="With what weight to apply lora")
|
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||||
|
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||||
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
|
unet: Optional[UNetField] = InputField(
|
||||||
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
|
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
|
||||||
|
)
|
||||||
class Config(InvocationConfig):
|
clip: Optional[ClipField] = InputField(
|
||||||
schema_extra = {
|
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP"
|
||||||
"ui": {
|
)
|
||||||
"title": "Lora Loader",
|
|
||||||
"tags": ["lora", "loader"],
|
|
||||||
"type_hints": {"lora": "lora_model"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||||
if self.lora is None:
|
if self.lora is None:
|
||||||
@ -263,37 +246,35 @@ class LoraLoaderInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||||
"""Model loader output"""
|
"""SDXL LoRA Loader Output"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output"
|
type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output"
|
||||||
|
|
||||||
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
|
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||||
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
|
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
|
||||||
clip2: Optional[ClipField] = Field(default=None, description="Tokenizer2 and text_encoder2 submodels")
|
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
@title("SDXL LoRA Loader")
|
||||||
|
@tags("sdxl", "lora", "model")
|
||||||
class SDXLLoraLoaderInvocation(BaseInvocation):
|
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||||
"""Apply selected lora to unet and text_encoder."""
|
"""Apply selected lora to unet and text_encoder."""
|
||||||
|
|
||||||
type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader"
|
type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader"
|
||||||
|
|
||||||
lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
|
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||||
weight: float = Field(default=0.75, description="With what weight to apply lora")
|
weight: float = Field(default=0.75, description=FieldDescriptions.lora_weight)
|
||||||
|
unet: Optional[UNetField] = Field(
|
||||||
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
|
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNET"
|
||||||
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
|
)
|
||||||
clip2: Optional[ClipField] = Field(description="Clip2 model for applying lora")
|
clip: Optional[ClipField] = Field(
|
||||||
|
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1"
|
||||||
class Config(InvocationConfig):
|
)
|
||||||
schema_extra = {
|
clip2: Optional[ClipField] = Field(
|
||||||
"ui": {
|
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2"
|
||||||
"title": "SDXL Lora Loader",
|
)
|
||||||
"tags": ["lora", "loader"],
|
|
||||||
"type_hints": {"lora": "lora_model"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
||||||
if self.lora is None:
|
if self.lora is None:
|
||||||
@ -369,29 +350,23 @@ class VAEModelField(BaseModel):
|
|||||||
class VaeLoaderOutput(BaseInvocationOutput):
|
class VaeLoaderOutput(BaseInvocationOutput):
|
||||||
"""Model loader output"""
|
"""Model loader output"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["vae_loader_output"] = "vae_loader_output"
|
type: Literal["vae_loader_output"] = "vae_loader_output"
|
||||||
|
|
||||||
vae: VaeField = Field(default=None, description="Vae model")
|
# Outputs
|
||||||
# fmt: on
|
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
|
|
||||||
|
|
||||||
|
@title("VAE Loader")
|
||||||
|
@tags("vae", "model")
|
||||||
class VaeLoaderInvocation(BaseInvocation):
|
class VaeLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||||
|
|
||||||
type: Literal["vae_loader"] = "vae_loader"
|
type: Literal["vae_loader"] = "vae_loader"
|
||||||
|
|
||||||
vae_model: VAEModelField = Field(description="The VAE to load")
|
# Inputs
|
||||||
|
vae_model: VAEModelField = InputField(
|
||||||
# Schema customisation
|
description=FieldDescriptions.vae_model, input=Input.Direct, ui_type_hint=UITypeHint.VaeModelField, title="VAE"
|
||||||
class Config(InvocationConfig):
|
)
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "VAE Loader",
|
|
||||||
"tags": ["vae", "loader"],
|
|
||||||
"type_hints": {"vae_model": "vae_model"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
|
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
|
||||||
base_model = self.vae_model.base_model
|
base_model = self.vae_model.base_model
|
||||||
|
@ -1,19 +1,24 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
||||||
|
|
||||||
import math
|
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import Field, validator
|
|
||||||
import torch
|
import torch
|
||||||
from invokeai.app.invocations.latent import LatentsField
|
from pydantic import validator
|
||||||
|
|
||||||
|
from invokeai.app.invocations.latent import LatentsField
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
|
|
||||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
InvocationConfig,
|
FieldDescriptions,
|
||||||
|
InputField,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
|
UITypeHint,
|
||||||
|
tags,
|
||||||
|
title,
|
||||||
)
|
)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -61,14 +66,12 @@ Nodes
|
|||||||
class NoiseOutput(BaseInvocationOutput):
|
class NoiseOutput(BaseInvocationOutput):
|
||||||
"""Invocation noise output"""
|
"""Invocation noise output"""
|
||||||
|
|
||||||
# fmt: off
|
type: Literal["noise_output"] = "noise_output"
|
||||||
type: Literal["noise_output"] = "noise_output"
|
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
noise: LatentsField = Field(default=None, description="The output noise")
|
noise: LatentsField = OutputField(default=None, description=FieldDescriptions.noise)
|
||||||
width: int = Field(description="The width of the noise in pixels")
|
width: int = OutputField(description=FieldDescriptions.width)
|
||||||
height: int = Field(description="The height of the noise in pixels")
|
height: int = OutputField(description=FieldDescriptions.height)
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
|
def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
|
||||||
@ -79,44 +82,37 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Noise")
|
||||||
|
@tags("latents", "noise")
|
||||||
class NoiseInvocation(BaseInvocation):
|
class NoiseInvocation(BaseInvocation):
|
||||||
"""Generates latent noise."""
|
"""Generates latent noise."""
|
||||||
|
|
||||||
type: Literal["noise"] = "noise"
|
type: Literal["noise"] = "noise"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
seed: int = Field(
|
seed: int = InputField(
|
||||||
ge=0,
|
ge=0,
|
||||||
le=SEED_MAX,
|
le=SEED_MAX,
|
||||||
description="The seed to use",
|
description=FieldDescriptions.seed,
|
||||||
default_factory=get_random_seed,
|
default_factory=get_random_seed,
|
||||||
)
|
)
|
||||||
width: int = Field(
|
width: int = InputField(
|
||||||
default=512,
|
default=512,
|
||||||
multiple_of=8,
|
multiple_of=8,
|
||||||
gt=0,
|
gt=0,
|
||||||
description="The width of the resulting noise",
|
description=FieldDescriptions.width,
|
||||||
)
|
)
|
||||||
height: int = Field(
|
height: int = InputField(
|
||||||
default=512,
|
default=512,
|
||||||
multiple_of=8,
|
multiple_of=8,
|
||||||
gt=0,
|
gt=0,
|
||||||
description="The height of the resulting noise",
|
description=FieldDescriptions.height,
|
||||||
)
|
)
|
||||||
use_cpu: bool = Field(
|
use_cpu: bool = InputField(
|
||||||
default=True,
|
default=True,
|
||||||
description="Use CPU for noise generation (for reproducible results across platforms)",
|
description="Use CPU for noise generation (for reproducible results across platforms)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "Noise",
|
|
||||||
"tags": ["latents", "noise"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
@validator("seed", pre=True)
|
@validator("seed", pre=True)
|
||||||
def modulo_seed(cls, v):
|
def modulo_seed(cls, v):
|
||||||
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
|
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
|
||||||
|
@ -1,37 +1,44 @@
|
|||||||
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
|
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import re
|
||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
from typing import List, Literal, Optional, Union
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
import re
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, validator
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
|
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
|
||||||
from diffusers.image_processor import VaeImageProcessor
|
from diffusers.image_processor import VaeImageProcessor
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
|
from pydantic import BaseModel, Field, validator
|
||||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
from tqdm import tqdm
|
||||||
from ...backend.model_management import ONNXModelPatcher
|
|
||||||
from ...backend.util import choose_torch_device
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
|
||||||
from .compel import ConditioningField
|
|
||||||
from .controlnet_image_processors import ControlField
|
|
||||||
from .image import ImageOutput
|
|
||||||
from .model import ModelInfo, UNetField, VaeField
|
|
||||||
|
|
||||||
from invokeai.app.invocations.metadata import CoreMetadata
|
from invokeai.app.invocations.metadata import CoreMetadata
|
||||||
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
|
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
||||||
|
|
||||||
|
from ...backend.model_management import ONNXModelPatcher
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
|
from ...backend.util import choose_torch_device
|
||||||
from tqdm import tqdm
|
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||||
from .model import ClipField
|
from .baseinvocation import (
|
||||||
from .latent import LatentsField, LatentsOutput, build_latents_output, get_scheduler, SAMPLER_NAME_VALUES
|
BaseInvocation,
|
||||||
from .compel import CompelOutput
|
BaseInvocationOutput,
|
||||||
|
FieldDescriptions,
|
||||||
|
InputField,
|
||||||
|
Input,
|
||||||
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
|
UIComponent,
|
||||||
|
UITypeHint,
|
||||||
|
tags,
|
||||||
|
title,
|
||||||
|
)
|
||||||
|
from .compel import CompelOutput, ConditioningField
|
||||||
|
from .controlnet_image_processors import ControlField
|
||||||
|
from .image import ImageOutput
|
||||||
|
from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler
|
||||||
|
from .model import ClipField, ModelInfo, UNetField, VaeField
|
||||||
|
|
||||||
ORT_TO_NP_TYPE = {
|
ORT_TO_NP_TYPE = {
|
||||||
"tensor(bool)": np.bool_,
|
"tensor(bool)": np.bool_,
|
||||||
@ -51,11 +58,13 @@ ORT_TO_NP_TYPE = {
|
|||||||
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
|
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
|
||||||
|
|
||||||
|
|
||||||
|
@title("ONNX Prompt (Raw)")
|
||||||
|
@tags("onnx", "prompt")
|
||||||
class ONNXPromptInvocation(BaseInvocation):
|
class ONNXPromptInvocation(BaseInvocation):
|
||||||
type: Literal["prompt_onnx"] = "prompt_onnx"
|
type: Literal["prompt_onnx"] = "prompt_onnx"
|
||||||
|
|
||||||
prompt: str = Field(default="", description="Prompt")
|
prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
|
||||||
clip: ClipField = Field(None, description="Clip to use")
|
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
@ -134,25 +143,48 @@ class ONNXPromptInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
# Text to image
|
# Text to image
|
||||||
|
@title("ONNX Text to Latents")
|
||||||
|
@tags("latents", "inference", "txt2img", "onnx")
|
||||||
class ONNXTextToLatentsInvocation(BaseInvocation):
|
class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||||
"""Generates latents from conditionings."""
|
"""Generates latents from conditionings."""
|
||||||
|
|
||||||
type: Literal["t2l_onnx"] = "t2l_onnx"
|
type: Literal["t2l_onnx"] = "t2l_onnx"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
# fmt: off
|
positive_conditioning: ConditioningField = InputField(
|
||||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
description=FieldDescriptions.positive_cond,
|
||||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
input=Input.Connection,
|
||||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
)
|
||||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
negative_conditioning: ConditioningField = InputField(
|
||||||
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
description=FieldDescriptions.negative_cond,
|
||||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
input=Input.Connection,
|
||||||
precision: PRECISION_VALUES = Field(default = "tensor(float16)", description="The precision to use when generating latents")
|
)
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
noise: LatentsField = InputField(
|
||||||
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
description=FieldDescriptions.noise,
|
||||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
input=Input.Connection,
|
||||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
)
|
||||||
# fmt: on
|
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||||
|
cfg_scale: Union[float, List[float]] = InputField(
|
||||||
|
default=7.5,
|
||||||
|
ge=1,
|
||||||
|
description=FieldDescriptions.cfg_scale,
|
||||||
|
ui_type_hint=UITypeHint.Float,
|
||||||
|
)
|
||||||
|
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||||
|
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct
|
||||||
|
)
|
||||||
|
precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision)
|
||||||
|
unet: UNetField = InputField(
|
||||||
|
description=FieldDescriptions.unet,
|
||||||
|
input=Input.Connection,
|
||||||
|
)
|
||||||
|
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
|
||||||
|
default=None,
|
||||||
|
description=FieldDescriptions.control,
|
||||||
|
ui_type_hint=UITypeHint.ControlField,
|
||||||
|
)
|
||||||
|
# seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||||
|
# seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||||
|
|
||||||
@validator("cfg_scale")
|
@validator("cfg_scale")
|
||||||
def ge_one(cls, v):
|
def ge_one(cls, v):
|
||||||
@ -166,20 +198,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
|||||||
raise ValueError("cfg_scale must be greater than 1")
|
raise ValueError("cfg_scale must be greater than 1")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["latents"],
|
|
||||||
"type_hints": {
|
|
||||||
"model": "model",
|
|
||||||
"control": "control",
|
|
||||||
# "cfg_scale": "float",
|
|
||||||
"cfg_scale": "number",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# based on
|
# based on
|
||||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
@ -300,26 +318,28 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
# Latent to image
|
# Latent to image
|
||||||
|
@title("ONNX Latents to Image")
|
||||||
|
@tags("latents", "image", "vae", "onnx")
|
||||||
class ONNXLatentsToImageInvocation(BaseInvocation):
|
class ONNXLatentsToImageInvocation(BaseInvocation):
|
||||||
"""Generates an image from latents."""
|
"""Generates an image from latents."""
|
||||||
|
|
||||||
type: Literal["l2i_onnx"] = "l2i_onnx"
|
type: Literal["l2i_onnx"] = "l2i_onnx"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
latents: LatentsField = InputField(
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
description=FieldDescriptions.denoised_latents,
|
||||||
metadata: Optional[CoreMetadata] = Field(
|
input=Input.Connection,
|
||||||
default=None, description="Optional core metadata to be written to the image"
|
|
||||||
)
|
)
|
||||||
# tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
|
vae: VaeField = InputField(
|
||||||
|
description=FieldDescriptions.vae,
|
||||||
# Schema customisation
|
input=Input.Connection,
|
||||||
class Config(InvocationConfig):
|
)
|
||||||
schema_extra = {
|
metadata: Optional[CoreMetadata] = InputField(
|
||||||
"ui": {
|
default=None,
|
||||||
"tags": ["latents", "image"],
|
description=FieldDescriptions.core_metadata,
|
||||||
},
|
ui_hidden=True,
|
||||||
}
|
)
|
||||||
|
# tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
@ -373,89 +393,13 @@ class ONNXModelLoaderOutput(BaseInvocationOutput):
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx"
|
type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx"
|
||||||
|
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
unet: UNetField = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||||
vae_decoder: VaeField = Field(default=None, description="Vae submodel")
|
vae_decoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Decoder")
|
||||||
vae_encoder: VaeField = Field(default=None, description="Vae submodel")
|
vae_encoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Encoder")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
class ONNXSD1ModelLoaderInvocation(BaseInvocation):
|
|
||||||
"""Loading submodels of selected model."""
|
|
||||||
|
|
||||||
type: Literal["sd1_model_loader_onnx"] = "sd1_model_loader_onnx"
|
|
||||||
|
|
||||||
model_name: str = Field(default="", description="Model to load")
|
|
||||||
# TODO: precision?
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"tags": ["model", "loader"], "type_hints": {"model_name": "model"}}, # TODO: rename to model_name?
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
|
||||||
model_name = "stable-diffusion-v1-5"
|
|
||||||
base_model = BaseModelType.StableDiffusion1
|
|
||||||
|
|
||||||
# TODO: not found exceptions
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=BaseModelType.StableDiffusion1,
|
|
||||||
model_type=ModelType.ONNX,
|
|
||||||
):
|
|
||||||
raise Exception(f"Unkown model name: {model_name}!")
|
|
||||||
|
|
||||||
return ONNXModelLoaderOutput(
|
|
||||||
unet=UNetField(
|
|
||||||
unet=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.ONNX,
|
|
||||||
submodel=SubModelType.UNet,
|
|
||||||
),
|
|
||||||
scheduler=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.ONNX,
|
|
||||||
submodel=SubModelType.Scheduler,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
),
|
|
||||||
clip=ClipField(
|
|
||||||
tokenizer=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.ONNX,
|
|
||||||
submodel=SubModelType.Tokenizer,
|
|
||||||
),
|
|
||||||
text_encoder=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.ONNX,
|
|
||||||
submodel=SubModelType.TextEncoder,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
),
|
|
||||||
vae_decoder=VaeField(
|
|
||||||
vae=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.ONNX,
|
|
||||||
submodel=SubModelType.VaeDecoder,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
vae_encoder=VaeField(
|
|
||||||
vae=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.ONNX,
|
|
||||||
submodel=SubModelType.VaeEncoder,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class OnnxModelField(BaseModel):
|
class OnnxModelField(BaseModel):
|
||||||
"""Onnx model field"""
|
"""Onnx model field"""
|
||||||
|
|
||||||
@ -464,22 +408,17 @@ class OnnxModelField(BaseModel):
|
|||||||
model_type: ModelType = Field(description="Model Type")
|
model_type: ModelType = Field(description="Model Type")
|
||||||
|
|
||||||
|
|
||||||
|
@title("ONNX Model Loader")
|
||||||
|
@tags("onnx", "model")
|
||||||
class OnnxModelLoaderInvocation(BaseInvocation):
|
class OnnxModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a main model, outputting its submodels."""
|
"""Loads a main model, outputting its submodels."""
|
||||||
|
|
||||||
type: Literal["onnx_model_loader"] = "onnx_model_loader"
|
type: Literal["onnx_model_loader"] = "onnx_model_loader"
|
||||||
|
|
||||||
model: OnnxModelField = Field(description="The model to load")
|
# Inputs
|
||||||
|
model: OnnxModelField = InputField(
|
||||||
# Schema customisation
|
description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type_hint=UITypeHint.ONNXModelField
|
||||||
class Config(InvocationConfig):
|
)
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "Onnx Model Loader",
|
|
||||||
"tags": ["model", "loader"],
|
|
||||||
"type_hints": {"model": "model"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
||||||
base_model = self.model.base_model
|
base_model = self.model.base_model
|
||||||
|
@ -1,73 +1,63 @@
|
|||||||
import io
|
import io
|
||||||
from typing import Literal, Optional, Any
|
from typing import Literal, Optional
|
||||||
|
|
||||||
# from PIL.Image import Image
|
|
||||||
import PIL.Image
|
|
||||||
from matplotlib.ticker import MaxNLocator
|
|
||||||
from matplotlib.figure import Figure
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
import numpy as np
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import PIL.Image
|
||||||
from easing_functions import (
|
from easing_functions import (
|
||||||
LinearInOut,
|
|
||||||
QuadEaseInOut,
|
|
||||||
QuadEaseIn,
|
|
||||||
QuadEaseOut,
|
|
||||||
CubicEaseInOut,
|
|
||||||
CubicEaseIn,
|
|
||||||
CubicEaseOut,
|
|
||||||
QuarticEaseInOut,
|
|
||||||
QuarticEaseIn,
|
|
||||||
QuarticEaseOut,
|
|
||||||
QuinticEaseInOut,
|
|
||||||
QuinticEaseIn,
|
|
||||||
QuinticEaseOut,
|
|
||||||
SineEaseInOut,
|
|
||||||
SineEaseIn,
|
|
||||||
SineEaseOut,
|
|
||||||
CircularEaseIn,
|
|
||||||
CircularEaseInOut,
|
|
||||||
CircularEaseOut,
|
|
||||||
ExponentialEaseInOut,
|
|
||||||
ExponentialEaseIn,
|
|
||||||
ExponentialEaseOut,
|
|
||||||
ElasticEaseIn,
|
|
||||||
ElasticEaseInOut,
|
|
||||||
ElasticEaseOut,
|
|
||||||
BackEaseIn,
|
BackEaseIn,
|
||||||
BackEaseInOut,
|
BackEaseInOut,
|
||||||
BackEaseOut,
|
BackEaseOut,
|
||||||
BounceEaseIn,
|
BounceEaseIn,
|
||||||
BounceEaseInOut,
|
BounceEaseInOut,
|
||||||
BounceEaseOut,
|
BounceEaseOut,
|
||||||
|
CircularEaseIn,
|
||||||
|
CircularEaseInOut,
|
||||||
|
CircularEaseOut,
|
||||||
|
CubicEaseIn,
|
||||||
|
CubicEaseInOut,
|
||||||
|
CubicEaseOut,
|
||||||
|
ElasticEaseIn,
|
||||||
|
ElasticEaseInOut,
|
||||||
|
ElasticEaseOut,
|
||||||
|
ExponentialEaseIn,
|
||||||
|
ExponentialEaseInOut,
|
||||||
|
ExponentialEaseOut,
|
||||||
|
LinearInOut,
|
||||||
|
QuadEaseIn,
|
||||||
|
QuadEaseInOut,
|
||||||
|
QuadEaseOut,
|
||||||
|
QuarticEaseIn,
|
||||||
|
QuarticEaseInOut,
|
||||||
|
QuarticEaseOut,
|
||||||
|
QuinticEaseIn,
|
||||||
|
QuinticEaseInOut,
|
||||||
|
QuinticEaseOut,
|
||||||
|
SineEaseIn,
|
||||||
|
SineEaseInOut,
|
||||||
|
SineEaseOut,
|
||||||
)
|
)
|
||||||
|
from matplotlib.figure import Figure
|
||||||
|
from matplotlib.ticker import MaxNLocator
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from .baseinvocation import (
|
|
||||||
BaseInvocation,
|
|
||||||
BaseInvocationOutput,
|
|
||||||
InvocationContext,
|
|
||||||
InvocationConfig,
|
|
||||||
)
|
|
||||||
from ...backend.util.logging import InvokeAILogger
|
from ...backend.util.logging import InvokeAILogger
|
||||||
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
|
||||||
from .collections import FloatCollectionOutput
|
from .collections import FloatCollectionOutput
|
||||||
|
|
||||||
|
|
||||||
|
@title("Float Range")
|
||||||
|
@tags("math", "range")
|
||||||
class FloatLinearRangeInvocation(BaseInvocation):
|
class FloatLinearRangeInvocation(BaseInvocation):
|
||||||
"""Creates a range"""
|
"""Creates a range"""
|
||||||
|
|
||||||
type: Literal["float_range"] = "float_range"
|
type: Literal["float_range"] = "float_range"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
start: float = Field(default=5, description="The first value of the range")
|
start: float = InputField(default=5, description="The first value of the range")
|
||||||
stop: float = Field(default=10, description="The last value of the range")
|
stop: float = InputField(default=10, description="The last value of the range")
|
||||||
steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)")
|
steps: int = InputField(default=30, description="number of values to interpolate over (including start and stop)")
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Linear Range (Float)", "tags": ["math", "float", "linear", "range"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||||
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
||||||
@ -108,37 +98,32 @@ EASING_FUNCTIONS_MAP = {
|
|||||||
"BounceInOut": BounceEaseInOut,
|
"BounceInOut": BounceEaseInOut,
|
||||||
}
|
}
|
||||||
|
|
||||||
EASING_FUNCTION_KEYS: Any = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
|
EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
|
||||||
|
|
||||||
|
|
||||||
# actually I think for now could just use CollectionOutput (which is list[Any]
|
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||||
|
@title("Step Param Easing")
|
||||||
|
@tags("step", "easing")
|
||||||
class StepParamEasingInvocation(BaseInvocation):
|
class StepParamEasingInvocation(BaseInvocation):
|
||||||
"""Experimental per-step parameter easing for denoising steps"""
|
"""Experimental per-step parameter easing for denoising steps"""
|
||||||
|
|
||||||
type: Literal["step_param_easing"] = "step_param_easing"
|
type: Literal["step_param_easing"] = "step_param_easing"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
# fmt: off
|
easing: EASING_FUNCTION_KEYS = InputField(default="Linear", description="The easing function to use")
|
||||||
easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use")
|
num_steps: int = InputField(default=20, description="number of denoising steps")
|
||||||
num_steps: int = Field(default=20, description="number of denoising steps")
|
start_value: float = InputField(default=0.0, description="easing starting value")
|
||||||
start_value: float = Field(default=0.0, description="easing starting value")
|
end_value: float = InputField(default=1.0, description="easing ending value")
|
||||||
end_value: float = Field(default=1.0, description="easing ending value")
|
start_step_percent: float = InputField(default=0.0, description="fraction of steps at which to start easing")
|
||||||
start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing")
|
end_step_percent: float = InputField(default=1.0, description="fraction of steps after which to end easing")
|
||||||
end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing")
|
|
||||||
# if None, then start_value is used prior to easing start
|
# if None, then start_value is used prior to easing start
|
||||||
pre_start_value: Optional[float] = Field(default=None, description="value before easing start")
|
pre_start_value: Optional[float] = InputField(default=None, description="value before easing start")
|
||||||
# if None, then end value is used prior to easing end
|
# if None, then end value is used prior to easing end
|
||||||
post_end_value: Optional[float] = Field(default=None, description="value after easing end")
|
post_end_value: Optional[float] = InputField(default=None, description="value after easing end")
|
||||||
mirror: bool = Field(default=False, description="include mirror of easing function")
|
mirror: bool = InputField(default=False, description="include mirror of easing function")
|
||||||
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
||||||
# alt_mirror: bool = Field(default=False, description="alternative mirroring by dual easing")
|
# alt_mirror: bool = InputField(default=False, description="alternative mirroring by dual easing")
|
||||||
show_easing_plot: bool = Field(default=False, description="show easing plot")
|
show_easing_plot: bool = InputField(default=False, description="show easing plot")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Param Easing By Step", "tags": ["param", "step", "easing"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||||
log_diagnostics = False
|
log_diagnostics = False
|
||||||
|
@ -2,82 +2,80 @@
|
|||||||
|
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from invokeai.app.invocations.prompt import PromptOutput
|
from invokeai.app.invocations.prompt import PromptOutput
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
from .baseinvocation import (
|
||||||
|
BaseInvocation,
|
||||||
|
BaseInvocationOutput,
|
||||||
|
InputField,
|
||||||
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
|
tags,
|
||||||
|
title,
|
||||||
|
)
|
||||||
from .math import FloatOutput, IntOutput
|
from .math import FloatOutput, IntOutput
|
||||||
|
|
||||||
# Pass-through parameter nodes - used by subgraphs
|
# Pass-through parameter nodes - used by subgraphs
|
||||||
|
|
||||||
|
|
||||||
|
@title("Integer Parameter")
|
||||||
|
@tags("integer")
|
||||||
class ParamIntInvocation(BaseInvocation):
|
class ParamIntInvocation(BaseInvocation):
|
||||||
"""An integer parameter"""
|
"""An integer parameter"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["param_int"] = "param_int"
|
type: Literal["param_int"] = "param_int"
|
||||||
a: int = Field(default=0, description="The integer value")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
a: int = InputField(default=0, description="The integer value")
|
||||||
"ui": {"tags": ["param", "integer"], "title": "Integer Parameter"},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||||
return IntOutput(a=self.a)
|
return IntOutput(a=self.a)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Float Parameter")
|
||||||
|
@tags("float")
|
||||||
class ParamFloatInvocation(BaseInvocation):
|
class ParamFloatInvocation(BaseInvocation):
|
||||||
"""A float parameter"""
|
"""A float parameter"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["param_float"] = "param_float"
|
type: Literal["param_float"] = "param_float"
|
||||||
param: float = Field(default=0.0, description="The float value")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
param: float = InputField(default=0.0, description="The float value")
|
||||||
"ui": {"tags": ["param", "float"], "title": "Float Parameter"},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||||
return FloatOutput(param=self.param)
|
return FloatOutput(a=self.param)
|
||||||
|
|
||||||
|
|
||||||
class StringOutput(BaseInvocationOutput):
|
class StringOutput(BaseInvocationOutput):
|
||||||
"""A string output"""
|
"""A string output"""
|
||||||
|
|
||||||
type: Literal["string_output"] = "string_output"
|
type: Literal["string_output"] = "string_output"
|
||||||
text: str = Field(default=None, description="The output string")
|
text: str = OutputField(description="The output string")
|
||||||
|
|
||||||
|
|
||||||
|
@title("String Parameter")
|
||||||
|
@tags("string")
|
||||||
class ParamStringInvocation(BaseInvocation):
|
class ParamStringInvocation(BaseInvocation):
|
||||||
"""A string parameter"""
|
"""A string parameter"""
|
||||||
|
|
||||||
type: Literal["param_string"] = "param_string"
|
type: Literal["param_string"] = "param_string"
|
||||||
text: str = Field(default="", description="The string value")
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
text: str = InputField(default="", description="The string value")
|
||||||
"ui": {"tags": ["param", "string"], "title": "String Parameter"},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||||
return StringOutput(text=self.text)
|
return StringOutput(text=self.text)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Prompt Parameter")
|
||||||
|
@tags("prompt")
|
||||||
class ParamPromptInvocation(BaseInvocation):
|
class ParamPromptInvocation(BaseInvocation):
|
||||||
"""A prompt input parameter"""
|
"""A prompt input parameter"""
|
||||||
|
|
||||||
type: Literal["param_prompt"] = "param_prompt"
|
type: Literal["param_prompt"] = "param_prompt"
|
||||||
prompt: str = Field(default="", description="The prompt value")
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
prompt: str = InputField(default="", description="The prompt value")
|
||||||
"ui": {"tags": ["param", "prompt"], "title": "Prompt"},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> PromptOutput:
|
def invoke(self, context: InvocationContext) -> PromptOutput:
|
||||||
return PromptOutput(prompt=self.prompt)
|
return PromptOutput(prompt=self.prompt)
|
||||||
|
@ -2,56 +2,52 @@ from os.path import exists
|
|||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import Field, validator
|
from pydantic import validator
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
from .baseinvocation import (
|
||||||
|
BaseInvocation,
|
||||||
|
BaseInvocationOutput,
|
||||||
|
InputField,
|
||||||
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
|
UIComponent,
|
||||||
|
UITypeHint,
|
||||||
|
title,
|
||||||
|
tags,
|
||||||
|
)
|
||||||
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
|
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
|
||||||
|
|
||||||
|
|
||||||
class PromptOutput(BaseInvocationOutput):
|
class PromptOutput(BaseInvocationOutput):
|
||||||
"""Base class for invocations that output a prompt"""
|
"""Base class for invocations that output a prompt"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["prompt"] = "prompt"
|
type: Literal["prompt"] = "prompt"
|
||||||
|
|
||||||
prompt: str = Field(default=None, description="The output prompt")
|
prompt: str = OutputField(description="The output prompt")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
schema_extra = {
|
|
||||||
"required": [
|
|
||||||
"type",
|
|
||||||
"prompt",
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class PromptCollectionOutput(BaseInvocationOutput):
|
class PromptCollectionOutput(BaseInvocationOutput):
|
||||||
"""Base class for invocations that output a collection of prompts"""
|
"""Base class for invocations that output a collection of prompts"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["prompt_collection_output"] = "prompt_collection_output"
|
type: Literal["prompt_collection_output"] = "prompt_collection_output"
|
||||||
|
|
||||||
prompt_collection: list[str] = Field(description="The output prompt collection")
|
prompt_collection: list[str] = OutputField(
|
||||||
count: int = Field(description="The size of the prompt collection")
|
description="The output prompt collection", ui_type_hint=UITypeHint.StringCollection
|
||||||
# fmt: on
|
)
|
||||||
|
count: int = OutputField(description="The size of the prompt collection")
|
||||||
class Config:
|
|
||||||
schema_extra = {"required": ["type", "prompt_collection", "count"]}
|
|
||||||
|
|
||||||
|
|
||||||
|
@title("Dynamic Prompt")
|
||||||
|
@tags("prompt", "collection")
|
||||||
class DynamicPromptInvocation(BaseInvocation):
|
class DynamicPromptInvocation(BaseInvocation):
|
||||||
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
|
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
|
||||||
|
|
||||||
type: Literal["dynamic_prompt"] = "dynamic_prompt"
|
type: Literal["dynamic_prompt"] = "dynamic_prompt"
|
||||||
prompt: str = Field(description="The prompt to parse with dynamicprompts")
|
|
||||||
max_prompts: int = Field(default=1, description="The number of prompts to generate")
|
|
||||||
combinatorial: bool = Field(default=False, description="Whether to use the combinatorial generator")
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
prompt: str = InputField(description="The prompt to parse with dynamicprompts", ui_component=UIComponent.Textarea)
|
||||||
"ui": {"title": "Dynamic Prompt", "tags": ["prompt", "dynamic"]},
|
max_prompts: int = InputField(default=1, description="The number of prompts to generate")
|
||||||
}
|
combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
|
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
|
||||||
if self.combinatorial:
|
if self.combinatorial:
|
||||||
@ -64,24 +60,23 @@ class DynamicPromptInvocation(BaseInvocation):
|
|||||||
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
|
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
|
||||||
|
|
||||||
|
|
||||||
|
@title("Prompts from File")
|
||||||
|
@tags("prompt", "file")
|
||||||
class PromptsFromFileInvocation(BaseInvocation):
|
class PromptsFromFileInvocation(BaseInvocation):
|
||||||
"""Loads prompts from a text file"""
|
"""Loads prompts from a text file"""
|
||||||
|
|
||||||
# fmt: off
|
type: Literal["prompt_from_file"] = "prompt_from_file"
|
||||||
type: Literal['prompt_from_file'] = 'prompt_from_file'
|
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
file_path: str = Field(description="Path to prompt text file")
|
file_path: str = InputField(description="Path to prompt text file", ui_type_hint=UITypeHint.FilePath)
|
||||||
pre_prompt: Optional[str] = Field(description="String to prepend to each prompt")
|
pre_prompt: Optional[str] = InputField(
|
||||||
post_prompt: Optional[str] = Field(description="String to append to each prompt")
|
description="String to prepend to each prompt", ui_component=UIComponent.Textarea
|
||||||
start_line: int = Field(default=1, ge=1, description="Line in the file to start start from")
|
)
|
||||||
max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)")
|
post_prompt: Optional[str] = InputField(
|
||||||
# fmt: on
|
description="String to append to each prompt", ui_component=UIComponent.Textarea
|
||||||
|
)
|
||||||
class Config(InvocationConfig):
|
start_line: int = InputField(default=1, ge=1, description="Line in the file to start start from")
|
||||||
schema_extra = {
|
max_prompts: int = InputField(default=1, ge=0, description="Max lines to read from file (0=all)")
|
||||||
"ui": {"title": "Prompts From File", "tags": ["prompt", "file"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
@validator("file_path")
|
@validator("file_path")
|
||||||
def file_path_exists(cls, v):
|
def file_path_exists(cls, v):
|
||||||
|
@ -1,55 +1,55 @@
|
|||||||
import torch
|
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from ...backend.model_management import ModelType, SubModelType
|
from ...backend.model_management import ModelType, SubModelType
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
from .baseinvocation import (
|
||||||
from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
|
BaseInvocation,
|
||||||
|
BaseInvocationOutput,
|
||||||
|
FieldDescriptions,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
|
UITypeHint,
|
||||||
|
tags,
|
||||||
|
title,
|
||||||
|
)
|
||||||
|
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
|
||||||
|
|
||||||
|
|
||||||
class SDXLModelLoaderOutput(BaseInvocationOutput):
|
class SDXLModelLoaderOutput(BaseInvocationOutput):
|
||||||
"""SDXL base model loader output"""
|
"""SDXL base model loader output"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"
|
type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"
|
||||||
|
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
|
||||||
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
||||||
"""SDXL refiner model loader output"""
|
"""SDXL refiner model loader output"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
|
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
|
||||||
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||||
# fmt: on
|
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
|
@title("SDXL Main Model Loader")
|
||||||
|
@tags("model", "sdxl")
|
||||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads an sdxl base model, outputting its submodels."""
|
"""Loads an sdxl base model, outputting its submodels."""
|
||||||
|
|
||||||
type: Literal["sdxl_model_loader"] = "sdxl_model_loader"
|
type: Literal["sdxl_model_loader"] = "sdxl_model_loader"
|
||||||
|
|
||||||
model: MainModelField = Field(description="The model to load")
|
# Inputs
|
||||||
|
model: MainModelField = InputField(
|
||||||
|
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type_hint=UITypeHint.SDXLMainModelField
|
||||||
|
)
|
||||||
# TODO: precision?
|
# TODO: precision?
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "SDXL Model Loader",
|
|
||||||
"tags": ["model", "loader", "sdxl"],
|
|
||||||
"type_hints": {"model": "model"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
|
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
|
||||||
base_model = self.model.base_model
|
base_model = self.model.base_model
|
||||||
model_name = self.model.model_name
|
model_name = self.model.model_name
|
||||||
@ -122,24 +122,21 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("SDXL Refiner Model Loader")
|
||||||
|
@tags("model", "sdxl", "refiner")
|
||||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||||
|
|
||||||
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
|
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
|
||||||
|
|
||||||
model: MainModelField = Field(description="The model to load")
|
# Inputs
|
||||||
|
model: MainModelField = InputField(
|
||||||
|
description=FieldDescriptions.sdxl_refiner_model,
|
||||||
|
input=Input.Direct,
|
||||||
|
ui_type_hint=UITypeHint.SDXLRefinerModelField,
|
||||||
|
)
|
||||||
# TODO: precision?
|
# TODO: precision?
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "SDXL Refiner Model Loader",
|
|
||||||
"tags": ["model", "loader", "sdxl_refiner"],
|
|
||||||
"type_hints": {"model": "refiner_model"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
|
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
|
||||||
base_model = self.model.base_model
|
base_model = self.model.base_model
|
||||||
model_name = self.model.model_name
|
model_name = self.model.model_name
|
||||||
|
@ -6,12 +6,11 @@ import cv2 as cv
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pydantic import Field
|
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags
|
||||||
from .image import ImageOutput
|
from .image import ImageOutput
|
||||||
|
|
||||||
# TODO: Populate this from disk?
|
# TODO: Populate this from disk?
|
||||||
@ -24,17 +23,16 @@ ESRGAN_MODELS = Literal[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@title("Upscale (RealESRGAN)")
|
||||||
|
@tags("esrgan", "upscale")
|
||||||
class ESRGANInvocation(BaseInvocation):
|
class ESRGANInvocation(BaseInvocation):
|
||||||
"""Upscales an image using RealESRGAN."""
|
"""Upscales an image using RealESRGAN."""
|
||||||
|
|
||||||
type: Literal["esrgan"] = "esrgan"
|
type: Literal["esrgan"] = "esrgan"
|
||||||
image: Union[ImageField, None] = Field(default=None, description="The input image")
|
|
||||||
model_name: ESRGAN_MODELS = Field(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
image: ImageField = InputField(description="The input image")
|
||||||
"ui": {"title": "Upscale (RealESRGAN)", "tags": ["image", "upscale", "realesrgan"]},
|
model_name: ESRGAN_MODELS = InputField(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
@ -5,14 +5,13 @@ from pydantic import BaseModel, Field
|
|||||||
from invokeai.app.util.metaenum import MetaEnum
|
from invokeai.app.util.metaenum import MetaEnum
|
||||||
from ..invocations.baseinvocation import (
|
from ..invocations.baseinvocation import (
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
InvocationConfig,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageField(BaseModel):
|
class ImageField(BaseModel):
|
||||||
"""An image field used for passing image objects between invocations"""
|
"""An image field used for passing image objects between invocations"""
|
||||||
|
|
||||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
image_name: str = Field(description="The name of the image")
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {"required": ["image_name"]}
|
schema_extra = {"required": ["image_name"]}
|
||||||
@ -36,17 +35,6 @@ class ProgressImage(BaseModel):
|
|||||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||||
|
|
||||||
|
|
||||||
class PILInvocationConfig(BaseModel):
|
|
||||||
"""Helper class to provide all PIL invocations with additional config"""
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["PIL", "image"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ImageOutput(BaseInvocationOutput):
|
class ImageOutput(BaseInvocationOutput):
|
||||||
"""Base class for invocations that output an image"""
|
"""Base class for invocations that output an image"""
|
||||||
|
|
||||||
|
@ -3,16 +3,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import itertools
|
import itertools
|
||||||
import uuid
|
import uuid
|
||||||
from typing import (
|
from typing import Annotated, Any, Literal, Optional, Union, get_args, get_origin, get_type_hints
|
||||||
Annotated,
|
|
||||||
Any,
|
|
||||||
Literal,
|
|
||||||
Optional,
|
|
||||||
Union,
|
|
||||||
get_args,
|
|
||||||
get_origin,
|
|
||||||
get_type_hints,
|
|
||||||
)
|
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from pydantic import BaseModel, root_validator, validator
|
from pydantic import BaseModel, root_validator, validator
|
||||||
@ -22,7 +13,11 @@ from ..invocations import *
|
|||||||
from ..invocations.baseinvocation import (
|
from ..invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
|
UITypeHint,
|
||||||
)
|
)
|
||||||
|
|
||||||
# in 3.10 this would be "from types import NoneType"
|
# in 3.10 this would be "from types import NoneType"
|
||||||
@ -183,15 +178,9 @@ class IterateInvocationOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
type: Literal["iterate_output"] = "iterate_output"
|
type: Literal["iterate_output"] = "iterate_output"
|
||||||
|
|
||||||
item: Any = Field(description="The item being iterated over")
|
item: Any = OutputField(
|
||||||
|
description="The item being iterated over", title="Collection Item", ui_type_hint=UITypeHint.CollectionItem
|
||||||
class Config:
|
)
|
||||||
schema_extra = {
|
|
||||||
"required": [
|
|
||||||
"type",
|
|
||||||
"item",
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Fill this out and move to invocations
|
# TODO: Fill this out and move to invocations
|
||||||
@ -200,8 +189,10 @@ class IterateInvocation(BaseInvocation):
|
|||||||
|
|
||||||
type: Literal["iterate"] = "iterate"
|
type: Literal["iterate"] = "iterate"
|
||||||
|
|
||||||
collection: list[Any] = Field(description="The list of items to iterate over", default_factory=list)
|
collection: list[Any] = InputField(
|
||||||
index: int = Field(description="The index, will be provided on executed iterators", default=0)
|
description="The list of items to iterate over", default_factory=list, ui_type_hint=UITypeHint.Collection
|
||||||
|
)
|
||||||
|
index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
|
def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
|
||||||
"""Produces the outputs as values"""
|
"""Produces the outputs as values"""
|
||||||
@ -211,15 +202,9 @@ class IterateInvocation(BaseInvocation):
|
|||||||
class CollectInvocationOutput(BaseInvocationOutput):
|
class CollectInvocationOutput(BaseInvocationOutput):
|
||||||
type: Literal["collect_output"] = "collect_output"
|
type: Literal["collect_output"] = "collect_output"
|
||||||
|
|
||||||
collection: list[Any] = Field(description="The collection of input items")
|
collection: list[Any] = OutputField(
|
||||||
|
description="The collection of input items", title="Collection", ui_type_hint=UITypeHint.Collection
|
||||||
class Config:
|
)
|
||||||
schema_extra = {
|
|
||||||
"required": [
|
|
||||||
"type",
|
|
||||||
"collection",
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class CollectInvocation(BaseInvocation):
|
class CollectInvocation(BaseInvocation):
|
||||||
@ -227,13 +212,14 @@ class CollectInvocation(BaseInvocation):
|
|||||||
|
|
||||||
type: Literal["collect"] = "collect"
|
type: Literal["collect"] = "collect"
|
||||||
|
|
||||||
item: Any = Field(
|
item: Any = InputField(
|
||||||
description="The item to collect (all inputs must be of the same type)",
|
description="The item to collect (all inputs must be of the same type)",
|
||||||
default=None,
|
ui_type_hint=UITypeHint.CollectionItem,
|
||||||
|
title="Collection Item",
|
||||||
|
input=Input.Connection,
|
||||||
)
|
)
|
||||||
collection: list[Any] = Field(
|
collection: list[Any] = InputField(
|
||||||
description="The collection, will be provided on execution",
|
description="The collection, will be provided on execution", default_factory=list, ui_hidden=True
|
||||||
default_factory=list,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
|
def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
|
||||||
|
@ -87,7 +87,10 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
# Invoke
|
# Invoke
|
||||||
try:
|
try:
|
||||||
with statistics.collect_stats(invocation, graph_execution_state.id):
|
with statistics.collect_stats(invocation, graph_execution_state.id):
|
||||||
outputs = invocation.invoke(
|
# use the internal invoke_internal(), which wraps the node's invoke() method in
|
||||||
|
# this accomodates nodes which require a value, but get it only from a
|
||||||
|
# connection
|
||||||
|
outputs = invocation.invoke_internal(
|
||||||
InvocationContext(
|
InvocationContext(
|
||||||
services=self.__invoker.services,
|
services=self.__invoker.services,
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
|
@ -49,7 +49,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
|
|
||||||
def _parse_item(self, item: str) -> T:
|
def _parse_item(self, item: str) -> T:
|
||||||
item_type = get_args(self.__orig_class__)[0]
|
item_type = get_args(self.__orig_class__)[0]
|
||||||
return parse_raw_as(item_type, item)
|
parsed = parse_raw_as(item_type, item)
|
||||||
|
return parsed
|
||||||
|
|
||||||
def set(self, item: T):
|
def set(self, item: T):
|
||||||
try:
|
try:
|
||||||
|
@ -61,6 +61,7 @@
|
|||||||
"@dagrejs/graphlib": "^2.1.13",
|
"@dagrejs/graphlib": "^2.1.13",
|
||||||
"@dnd-kit/core": "^6.0.8",
|
"@dnd-kit/core": "^6.0.8",
|
||||||
"@dnd-kit/modifiers": "^6.0.1",
|
"@dnd-kit/modifiers": "^6.0.1",
|
||||||
|
"@dnd-kit/utilities": "^3.2.1",
|
||||||
"@emotion/react": "^11.11.1",
|
"@emotion/react": "^11.11.1",
|
||||||
"@emotion/styled": "^11.11.0",
|
"@emotion/styled": "^11.11.0",
|
||||||
"@floating-ui/react-dom": "^2.0.1",
|
"@floating-ui/react-dom": "^2.0.1",
|
||||||
|
34
invokeai/frontend/web/scripts/colors.js
Normal file
34
invokeai/frontend/web/scripts/colors.js
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
export const COLORS = {
|
||||||
|
reset: '\x1b[0m',
|
||||||
|
bright: '\x1b[1m',
|
||||||
|
dim: '\x1b[2m',
|
||||||
|
underscore: '\x1b[4m',
|
||||||
|
blink: '\x1b[5m',
|
||||||
|
reverse: '\x1b[7m',
|
||||||
|
hidden: '\x1b[8m',
|
||||||
|
|
||||||
|
fg: {
|
||||||
|
black: '\x1b[30m',
|
||||||
|
red: '\x1b[31m',
|
||||||
|
green: '\x1b[32m',
|
||||||
|
yellow: '\x1b[33m',
|
||||||
|
blue: '\x1b[34m',
|
||||||
|
magenta: '\x1b[35m',
|
||||||
|
cyan: '\x1b[36m',
|
||||||
|
white: '\x1b[37m',
|
||||||
|
gray: '\x1b[90m',
|
||||||
|
crimson: '\x1b[38m',
|
||||||
|
},
|
||||||
|
bg: {
|
||||||
|
black: '\x1b[40m',
|
||||||
|
red: '\x1b[41m',
|
||||||
|
green: '\x1b[42m',
|
||||||
|
yellow: '\x1b[43m',
|
||||||
|
blue: '\x1b[44m',
|
||||||
|
magenta: '\x1b[45m',
|
||||||
|
cyan: '\x1b[46m',
|
||||||
|
white: '\x1b[47m',
|
||||||
|
gray: '\x1b[100m',
|
||||||
|
crimson: '\x1b[48m',
|
||||||
|
},
|
||||||
|
};
|
@ -1,23 +1,83 @@
|
|||||||
import fs from 'node:fs';
|
import fs from 'node:fs';
|
||||||
import openapiTS from 'openapi-typescript';
|
import openapiTS from 'openapi-typescript';
|
||||||
|
import { COLORS } from './colors.js';
|
||||||
|
|
||||||
const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json';
|
const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json';
|
||||||
const OUTPUT_FILE = 'src/services/api/schema.d.ts';
|
const OUTPUT_FILE = 'src/services/api/schema.d.ts';
|
||||||
|
|
||||||
async function main() {
|
async function main() {
|
||||||
process.stdout.write(
|
process.stdout.write(
|
||||||
`Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...`
|
`Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...\n\n`
|
||||||
);
|
);
|
||||||
const types = await openapiTS(OPENAPI_URL, {
|
const types = await openapiTS(OPENAPI_URL, {
|
||||||
exportType: true,
|
exportType: true,
|
||||||
transform: (schemaObject) => {
|
transform: (schemaObject, metadata) => {
|
||||||
if ('format' in schemaObject && schemaObject.format === 'binary') {
|
if ('format' in schemaObject && schemaObject.format === 'binary') {
|
||||||
return schemaObject.nullable ? 'Blob | null' : 'Blob';
|
return schemaObject.nullable ? 'Blob | null' : 'Blob';
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Because invocations may have required fields that accept connection input, the generated
|
||||||
|
* types may be incorrect.
|
||||||
|
*
|
||||||
|
* For example, the ImageResizeInvocation has a required `image` field, but because it accepts
|
||||||
|
* connection input, it should be optional on instantiation of the field.
|
||||||
|
*
|
||||||
|
* To handle this, the schema exposes an `input` property that can be used to determine if the
|
||||||
|
* field accepts connection input. If it does, we can make the field optional.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Check if we are generating types for an invocation
|
||||||
|
const isInvocationPath = metadata.path.match(
|
||||||
|
/^#\/components\/schemas\/\w*Invocation$/
|
||||||
|
);
|
||||||
|
|
||||||
|
const hasInvocationProperties =
|
||||||
|
schemaObject.properties &&
|
||||||
|
['id', 'is_intermediate', 'type'].every(
|
||||||
|
(prop) => prop in schemaObject.properties
|
||||||
|
);
|
||||||
|
|
||||||
|
if (isInvocationPath && hasInvocationProperties) {
|
||||||
|
// We only want to make fields optional if they are required
|
||||||
|
if (!Array.isArray(schemaObject?.required)) {
|
||||||
|
schemaObject.required = ['id', 'type'];
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
schemaObject.required.forEach((prop) => {
|
||||||
|
const acceptsConnection = ['any', 'connection'].includes(
|
||||||
|
schemaObject.properties?.[prop]?.['input']
|
||||||
|
);
|
||||||
|
|
||||||
|
if (acceptsConnection) {
|
||||||
|
// remove this prop from the required array
|
||||||
|
const invocationName = metadata.path.split('/').pop();
|
||||||
|
console.log(
|
||||||
|
`Making connectable field optional: ${COLORS.fg.green}${invocationName}.${COLORS.fg.cyan}${prop}${COLORS.reset}`
|
||||||
|
);
|
||||||
|
schemaObject.required = schemaObject.required.filter(
|
||||||
|
(r) => r !== prop
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
schemaObject.required = [
|
||||||
|
...new Set(schemaObject.required.concat(['id', 'type'])),
|
||||||
|
];
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// if (
|
||||||
|
// 'input' in schemaObject &&
|
||||||
|
// (schemaObject.input === 'any' || schemaObject.input === 'connection')
|
||||||
|
// ) {
|
||||||
|
// schemaObject.required = false;
|
||||||
|
// }
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
fs.writeFileSync(OUTPUT_FILE, types);
|
fs.writeFileSync(OUTPUT_FILE, types);
|
||||||
process.stdout.write(` OK!\r\n`);
|
process.stdout.write(`\nOK!\r\n`);
|
||||||
}
|
}
|
||||||
|
|
||||||
main();
|
main();
|
||||||
|
@ -1,8 +1,12 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { RootState } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||||
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
import {
|
||||||
|
ctrlKeyPressed,
|
||||||
|
metaKeyPressed,
|
||||||
|
shiftKeyPressed,
|
||||||
|
} from 'features/ui/store/hotkeysSlice';
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import {
|
import {
|
||||||
setActiveTab,
|
setActiveTab,
|
||||||
@ -16,11 +20,11 @@ import React, { memo } from 'react';
|
|||||||
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
|
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
|
||||||
|
|
||||||
const globalHotkeysSelector = createSelector(
|
const globalHotkeysSelector = createSelector(
|
||||||
[(state: RootState) => state.hotkeys, (state: RootState) => state.ui],
|
[stateSelector],
|
||||||
(hotkeys, ui) => {
|
({ hotkeys, ui }) => {
|
||||||
const { shift } = hotkeys;
|
const { shift, ctrl, meta } = hotkeys;
|
||||||
const { shouldPinParametersPanel, shouldPinGallery } = ui;
|
const { shouldPinParametersPanel, shouldPinGallery } = ui;
|
||||||
return { shift, shouldPinGallery, shouldPinParametersPanel };
|
return { shift, ctrl, meta, shouldPinGallery, shouldPinParametersPanel };
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
memoizeOptions: {
|
memoizeOptions: {
|
||||||
@ -37,9 +41,8 @@ const globalHotkeysSelector = createSelector(
|
|||||||
*/
|
*/
|
||||||
const GlobalHotkeys: React.FC = () => {
|
const GlobalHotkeys: React.FC = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { shift, shouldPinParametersPanel, shouldPinGallery } = useAppSelector(
|
const { shift, ctrl, meta, shouldPinParametersPanel, shouldPinGallery } =
|
||||||
globalHotkeysSelector
|
useAppSelector(globalHotkeysSelector);
|
||||||
);
|
|
||||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
@ -50,9 +53,19 @@ const GlobalHotkeys: React.FC = () => {
|
|||||||
} else {
|
} else {
|
||||||
shift && dispatch(shiftKeyPressed(false));
|
shift && dispatch(shiftKeyPressed(false));
|
||||||
}
|
}
|
||||||
|
if (isHotkeyPressed('ctrl')) {
|
||||||
|
!ctrl && dispatch(ctrlKeyPressed(true));
|
||||||
|
} else {
|
||||||
|
ctrl && dispatch(ctrlKeyPressed(false));
|
||||||
|
}
|
||||||
|
if (isHotkeyPressed('meta')) {
|
||||||
|
!meta && dispatch(metaKeyPressed(true));
|
||||||
|
} else {
|
||||||
|
meta && dispatch(metaKeyPressed(false));
|
||||||
|
}
|
||||||
},
|
},
|
||||||
{ keyup: true, keydown: true },
|
{ keyup: true, keydown: true },
|
||||||
[shift]
|
[shift, ctrl, meta]
|
||||||
);
|
);
|
||||||
|
|
||||||
useHotkeys('o', () => {
|
useHotkeys('o', () => {
|
||||||
|
@ -14,7 +14,7 @@ import { $authToken, $baseUrl, $projectId } from 'services/api/client';
|
|||||||
import { socketMiddleware } from 'services/events/middleware';
|
import { socketMiddleware } from 'services/events/middleware';
|
||||||
import Loading from '../../common/components/Loading/Loading';
|
import Loading from '../../common/components/Loading/Loading';
|
||||||
import '../../i18n';
|
import '../../i18n';
|
||||||
import ImageDndContext from './ImageDnd/ImageDndContext';
|
import AppDndContext from '../../features/dnd/components/AppDndContext';
|
||||||
|
|
||||||
const App = lazy(() => import('./App'));
|
const App = lazy(() => import('./App'));
|
||||||
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
|
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
|
||||||
@ -80,9 +80,9 @@ const InvokeAIUI = ({
|
|||||||
<Provider store={store}>
|
<Provider store={store}>
|
||||||
<React.Suspense fallback={<Loading />}>
|
<React.Suspense fallback={<Loading />}>
|
||||||
<ThemeLocaleProvider>
|
<ThemeLocaleProvider>
|
||||||
<ImageDndContext>
|
<AppDndContext>
|
||||||
<App config={config} headerComponent={headerComponent} />
|
<App config={config} headerComponent={headerComponent} />
|
||||||
</ImageDndContext>
|
</AppDndContext>
|
||||||
</ThemeLocaleProvider>
|
</ThemeLocaleProvider>
|
||||||
</React.Suspense>
|
</React.Suspense>
|
||||||
</Provider>
|
</Provider>
|
||||||
|
@ -19,7 +19,8 @@ type LoggerNamespace =
|
|||||||
| 'nodes'
|
| 'nodes'
|
||||||
| 'system'
|
| 'system'
|
||||||
| 'socketio'
|
| 'socketio'
|
||||||
| 'session';
|
| 'session'
|
||||||
|
| 'dnd';
|
||||||
|
|
||||||
export const logger = (namespace: LoggerNamespace) =>
|
export const logger = (namespace: LoggerNamespace) =>
|
||||||
$logger.get().child({ namespace });
|
$logger.get().child({ namespace });
|
||||||
|
@ -15,7 +15,7 @@ export const actionsDenylist = [
|
|||||||
'socket/socketGeneratorProgress',
|
'socket/socketGeneratorProgress',
|
||||||
'socket/appSocketGeneratorProgress',
|
'socket/appSocketGeneratorProgress',
|
||||||
// every time user presses shift
|
// every time user presses shift
|
||||||
'hotkeys/shiftKeyPressed',
|
// 'hotkeys/shiftKeyPressed',
|
||||||
// this happens after every state change
|
// this happens after every state change
|
||||||
'@@REMEMBER_PERSISTED',
|
'@@REMEMBER_PERSISTED',
|
||||||
];
|
];
|
||||||
|
@ -1,16 +1,20 @@
|
|||||||
import { createAction } from '@reduxjs/toolkit';
|
import { createAction } from '@reduxjs/toolkit';
|
||||||
import {
|
|
||||||
TypesafeDraggableData,
|
|
||||||
TypesafeDroppableData,
|
|
||||||
} from 'app/components/ImageDnd/typesafeDnd';
|
|
||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||||
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import {
|
||||||
|
TypesafeDraggableData,
|
||||||
|
TypesafeDroppableData,
|
||||||
|
} from 'features/dnd/types';
|
||||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
import {
|
||||||
|
fieldImageValueChanged,
|
||||||
|
workflowExposedFieldAdded,
|
||||||
|
} from 'features/nodes/store/nodesSlice';
|
||||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
import { startAppListening } from '../';
|
import { startAppListening } from '../';
|
||||||
|
import { parseify } from 'common/util/serialize';
|
||||||
|
|
||||||
export const dndDropped = createAction<{
|
export const dndDropped = createAction<{
|
||||||
overData: TypesafeDroppableData;
|
overData: TypesafeDroppableData;
|
||||||
@ -21,7 +25,7 @@ export const addImageDroppedListener = () => {
|
|||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: dndDropped,
|
actionCreator: dndDropped,
|
||||||
effect: async (action, { dispatch }) => {
|
effect: async (action, { dispatch }) => {
|
||||||
const log = logger('images');
|
const log = logger('dnd');
|
||||||
const { activeData, overData } = action.payload;
|
const { activeData, overData } = action.payload;
|
||||||
|
|
||||||
if (activeData.payloadType === 'IMAGE_DTO') {
|
if (activeData.payloadType === 'IMAGE_DTO') {
|
||||||
@ -31,10 +35,28 @@ export const addImageDroppedListener = () => {
|
|||||||
{ activeData, overData },
|
{ activeData, overData },
|
||||||
`Images (${activeData.payload.imageDTOs.length}) dropped`
|
`Images (${activeData.payload.imageDTOs.length}) dropped`
|
||||||
);
|
);
|
||||||
|
} else if (activeData.payloadType === 'NODE_FIELD') {
|
||||||
|
log.debug(
|
||||||
|
{ activeData: parseify(activeData), overData: parseify(overData) },
|
||||||
|
'Node field dropped'
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
log.debug({ activeData, overData }, `Unknown payload dropped`);
|
log.debug({ activeData, overData }, `Unknown payload dropped`);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
overData.actionType === 'ADD_FIELD_TO_LINEAR' &&
|
||||||
|
activeData.payloadType === 'NODE_FIELD'
|
||||||
|
) {
|
||||||
|
const { nodeId, field } = activeData.payload;
|
||||||
|
dispatch(
|
||||||
|
workflowExposedFieldAdded({
|
||||||
|
nodeId,
|
||||||
|
fieldName: field.name,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Image dropped on current image
|
* Image dropped on current image
|
||||||
*/
|
*/
|
||||||
@ -99,7 +121,7 @@ export const addImageDroppedListener = () => {
|
|||||||
) {
|
) {
|
||||||
const { fieldName, nodeId } = overData.context;
|
const { fieldName, nodeId } = overData.context;
|
||||||
dispatch(
|
dispatch(
|
||||||
fieldValueChanged({
|
fieldImageValueChanged({
|
||||||
nodeId,
|
nodeId,
|
||||||
fieldName,
|
fieldName,
|
||||||
value: activeData.payload.imageDTO,
|
value: activeData.payload.imageDTO,
|
||||||
|
@ -2,7 +2,7 @@ import { UseToastOptions } from '@chakra-ui/react';
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||||
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
||||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { omit } from 'lodash-es';
|
import { omit } from 'lodash-es';
|
||||||
@ -111,7 +111,9 @@ export const addImageUploadedFulfilledListener = () => {
|
|||||||
|
|
||||||
if (postUploadAction?.type === 'SET_NODES_IMAGE') {
|
if (postUploadAction?.type === 'SET_NODES_IMAGE') {
|
||||||
const { nodeId, fieldName } = postUploadAction;
|
const { nodeId, fieldName } = postUploadAction;
|
||||||
dispatch(fieldValueChanged({ nodeId, fieldName, value: imageDTO }));
|
dispatch(
|
||||||
|
fieldImageValueChanged({ nodeId, fieldName, value: imageDTO })
|
||||||
|
);
|
||||||
dispatch(
|
dispatch(
|
||||||
addToast({
|
addToast({
|
||||||
...DEFAULT_UPLOADED_TOAST,
|
...DEFAULT_UPLOADED_TOAST,
|
||||||
|
@ -15,12 +15,21 @@ import {
|
|||||||
setShouldUseSDXLRefiner,
|
setShouldUseSDXLRefiner,
|
||||||
} from 'features/sdxl/store/sdxlSlice';
|
} from 'features/sdxl/store/sdxlSlice';
|
||||||
import { forEach, some } from 'lodash-es';
|
import { forEach, some } from 'lodash-es';
|
||||||
import { modelsApi, vaeModelsAdapter } from 'services/api/endpoints/models';
|
import {
|
||||||
|
mainModelsAdapter,
|
||||||
|
modelsApi,
|
||||||
|
vaeModelsAdapter,
|
||||||
|
} from 'services/api/endpoints/models';
|
||||||
|
import { TypeGuardFor } from 'services/api/types';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
export const addModelsLoadedListener = () => {
|
export const addModelsLoadedListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
predicate: (state, action) =>
|
predicate: (
|
||||||
|
action
|
||||||
|
): action is TypeGuardFor<
|
||||||
|
typeof modelsApi.endpoints.getMainModels.matchFulfilled
|
||||||
|
> =>
|
||||||
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
|
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
|
||||||
!action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
!action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
||||||
effect: async (action, { getState, dispatch }) => {
|
effect: async (action, { getState, dispatch }) => {
|
||||||
@ -32,29 +41,28 @@ export const addModelsLoadedListener = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const currentModel = getState().generation.model;
|
const currentModel = getState().generation.model;
|
||||||
|
const models = mainModelsAdapter.getSelectors().selectAll(action.payload);
|
||||||
|
|
||||||
const isCurrentModelAvailable = some(
|
if (models.length === 0) {
|
||||||
action.payload.entities,
|
|
||||||
(m) =>
|
|
||||||
m?.model_name === currentModel?.model_name &&
|
|
||||||
m?.base_model === currentModel?.base_model &&
|
|
||||||
m?.model_type === currentModel?.model_type
|
|
||||||
);
|
|
||||||
|
|
||||||
if (isCurrentModelAvailable) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const firstModelId = action.payload.ids[0];
|
|
||||||
const firstModel = action.payload.entities[firstModelId];
|
|
||||||
|
|
||||||
if (!firstModel) {
|
|
||||||
// No models loaded at all
|
// No models loaded at all
|
||||||
dispatch(modelChanged(null));
|
dispatch(modelChanged(null));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = zMainOrOnnxModel.safeParse(firstModel);
|
const isCurrentModelAvailable = currentModel
|
||||||
|
? models.some(
|
||||||
|
(m) =>
|
||||||
|
m.model_name === currentModel.model_name &&
|
||||||
|
m.base_model === currentModel.base_model &&
|
||||||
|
m.model_type === currentModel.model_type
|
||||||
|
)
|
||||||
|
: false;
|
||||||
|
|
||||||
|
if (isCurrentModelAvailable) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = zMainOrOnnxModel.safeParse(models[0]);
|
||||||
|
|
||||||
if (!result.success) {
|
if (!result.success) {
|
||||||
log.error(
|
log.error(
|
||||||
@ -68,7 +76,11 @@ export const addModelsLoadedListener = () => {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
startAppListening({
|
startAppListening({
|
||||||
predicate: (state, action) =>
|
predicate: (
|
||||||
|
action
|
||||||
|
): action is TypeGuardFor<
|
||||||
|
typeof modelsApi.endpoints.getMainModels.matchFulfilled
|
||||||
|
> =>
|
||||||
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
|
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
|
||||||
action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
||||||
effect: async (action, { getState, dispatch }) => {
|
effect: async (action, { getState, dispatch }) => {
|
||||||
@ -80,30 +92,29 @@ export const addModelsLoadedListener = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const currentModel = getState().sdxl.refinerModel;
|
const currentModel = getState().sdxl.refinerModel;
|
||||||
|
const models = mainModelsAdapter.getSelectors().selectAll(action.payload);
|
||||||
|
|
||||||
const isCurrentModelAvailable = some(
|
if (models.length === 0) {
|
||||||
action.payload.entities,
|
|
||||||
(m) =>
|
|
||||||
m?.model_name === currentModel?.model_name &&
|
|
||||||
m?.base_model === currentModel?.base_model &&
|
|
||||||
m?.model_type === currentModel?.model_type
|
|
||||||
);
|
|
||||||
|
|
||||||
if (isCurrentModelAvailable) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const firstModelId = action.payload.ids[0];
|
|
||||||
const firstModel = action.payload.entities[firstModelId];
|
|
||||||
|
|
||||||
if (!firstModel) {
|
|
||||||
// No models loaded at all
|
// No models loaded at all
|
||||||
dispatch(refinerModelChanged(null));
|
dispatch(refinerModelChanged(null));
|
||||||
dispatch(setShouldUseSDXLRefiner(false));
|
dispatch(setShouldUseSDXLRefiner(false));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = zSDXLRefinerModel.safeParse(firstModel);
|
const isCurrentModelAvailable = currentModel
|
||||||
|
? models.some(
|
||||||
|
(m) =>
|
||||||
|
m.model_name === currentModel.model_name &&
|
||||||
|
m.base_model === currentModel.base_model &&
|
||||||
|
m.model_type === currentModel.model_type
|
||||||
|
)
|
||||||
|
: false;
|
||||||
|
|
||||||
|
if (isCurrentModelAvailable) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = zSDXLRefinerModel.safeParse(models[0]);
|
||||||
|
|
||||||
if (!result.success) {
|
if (!result.success) {
|
||||||
log.error(
|
log.error(
|
||||||
|
@ -13,7 +13,7 @@ export const addReceivedOpenAPISchemaListener = () => {
|
|||||||
const log = logger('system');
|
const log = logger('system');
|
||||||
const schemaJSON = action.payload;
|
const schemaJSON = action.payload;
|
||||||
|
|
||||||
log.debug({ schemaJSON }, 'Dereferenced OpenAPI schema');
|
log.debug({ schemaJSON }, 'Received OpenAPI schema');
|
||||||
|
|
||||||
const nodeTemplates = parseSchema(schemaJSON);
|
const nodeTemplates = parseSchema(schemaJSON);
|
||||||
|
|
||||||
@ -28,9 +28,12 @@ export const addReceivedOpenAPISchemaListener = () => {
|
|||||||
|
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: receivedOpenAPISchema.rejected,
|
actionCreator: receivedOpenAPISchema.rejected,
|
||||||
effect: () => {
|
effect: (action) => {
|
||||||
const log = logger('system');
|
const log = logger('system');
|
||||||
log.error('Problem dereferencing OpenAPI Schema');
|
log.error(
|
||||||
|
{ error: parseify(action.error) },
|
||||||
|
'Problem retrieving OpenAPI Schema'
|
||||||
|
);
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -19,7 +19,7 @@ import {
|
|||||||
} from 'services/events/actions';
|
} from 'services/events/actions';
|
||||||
import { startAppListening } from '../..';
|
import { startAppListening } from '../..';
|
||||||
|
|
||||||
const nodeDenylist = ['dataURL_image'];
|
const nodeDenylist = ['load_image'];
|
||||||
|
|
||||||
export const addInvocationCompleteEventListener = () => {
|
export const addInvocationCompleteEventListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
|
@ -15,7 +15,7 @@ export const addUserInvokedNodesListener = () => {
|
|||||||
const log = logger('session');
|
const log = logger('session');
|
||||||
const state = getState();
|
const state = getState();
|
||||||
|
|
||||||
const graph = buildNodesGraph(state);
|
const graph = buildNodesGraph(state.nodes);
|
||||||
dispatch(nodesGraphBuilt(graph));
|
dispatch(nodesGraphBuilt(graph));
|
||||||
log.debug({ graph: parseify(graph) }, 'Nodes graph built');
|
log.debug({ graph: parseify(graph) }, 'Nodes graph built');
|
||||||
|
|
||||||
|
@ -1,86 +1,7 @@
|
|||||||
import {
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
// CONTROLNET_MODELS,
|
|
||||||
CONTROLNET_PROCESSORS,
|
|
||||||
} from 'features/controlNet/store/constants';
|
|
||||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||||
import { O } from 'ts-toolbelt';
|
import { O } from 'ts-toolbelt';
|
||||||
|
|
||||||
// These are old types from the model management UI
|
|
||||||
|
|
||||||
// export type ModelStatus = 'active' | 'cached' | 'not loaded';
|
|
||||||
|
|
||||||
// export type Model = {
|
|
||||||
// status: ModelStatus;
|
|
||||||
// description: string;
|
|
||||||
// weights: string;
|
|
||||||
// config?: string;
|
|
||||||
// vae?: string;
|
|
||||||
// width?: number;
|
|
||||||
// height?: number;
|
|
||||||
// default?: boolean;
|
|
||||||
// format?: string;
|
|
||||||
// };
|
|
||||||
|
|
||||||
// export type DiffusersModel = {
|
|
||||||
// status: ModelStatus;
|
|
||||||
// description: string;
|
|
||||||
// repo_id?: string;
|
|
||||||
// path?: string;
|
|
||||||
// vae?: {
|
|
||||||
// repo_id?: string;
|
|
||||||
// path?: string;
|
|
||||||
// };
|
|
||||||
// format?: string;
|
|
||||||
// default?: boolean;
|
|
||||||
// };
|
|
||||||
|
|
||||||
// export type ModelList = Record<string, Model & DiffusersModel>;
|
|
||||||
|
|
||||||
// export type FoundModel = {
|
|
||||||
// name: string;
|
|
||||||
// location: string;
|
|
||||||
// };
|
|
||||||
|
|
||||||
// export type InvokeModelConfigProps = {
|
|
||||||
// name: string | undefined;
|
|
||||||
// description: string | undefined;
|
|
||||||
// config: string | undefined;
|
|
||||||
// weights: string | undefined;
|
|
||||||
// vae: string | undefined;
|
|
||||||
// width: number | undefined;
|
|
||||||
// height: number | undefined;
|
|
||||||
// default: boolean | undefined;
|
|
||||||
// format: string | undefined;
|
|
||||||
// };
|
|
||||||
|
|
||||||
// export type InvokeDiffusersModelConfigProps = {
|
|
||||||
// name: string | undefined;
|
|
||||||
// description: string | undefined;
|
|
||||||
// repo_id: string | undefined;
|
|
||||||
// path: string | undefined;
|
|
||||||
// default: boolean | undefined;
|
|
||||||
// format: string | undefined;
|
|
||||||
// vae: {
|
|
||||||
// repo_id: string | undefined;
|
|
||||||
// path: string | undefined;
|
|
||||||
// };
|
|
||||||
// };
|
|
||||||
|
|
||||||
// export type InvokeModelConversionProps = {
|
|
||||||
// model_name: string;
|
|
||||||
// save_location: string;
|
|
||||||
// custom_location: string | null;
|
|
||||||
// };
|
|
||||||
|
|
||||||
// export type InvokeModelMergingProps = {
|
|
||||||
// models_to_merge: string[];
|
|
||||||
// alpha: number;
|
|
||||||
// interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference';
|
|
||||||
// force: boolean;
|
|
||||||
// merged_model_name: string;
|
|
||||||
// model_merge_save_path: string | null;
|
|
||||||
// };
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A disable-able application feature
|
* A disable-able application feature
|
||||||
*/
|
*/
|
||||||
|
@ -6,10 +6,6 @@ import {
|
|||||||
useColorMode,
|
useColorMode,
|
||||||
useColorModeValue,
|
useColorModeValue,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import {
|
|
||||||
TypesafeDraggableData,
|
|
||||||
TypesafeDroppableData,
|
|
||||||
} from 'app/components/ImageDnd/typesafeDnd';
|
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import {
|
import {
|
||||||
IAILoadingImageFallback,
|
IAILoadingImageFallback,
|
||||||
@ -17,6 +13,10 @@ import {
|
|||||||
} from 'common/components/IAIImageFallback';
|
} from 'common/components/IAIImageFallback';
|
||||||
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
|
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
|
||||||
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
|
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
|
||||||
|
import {
|
||||||
|
TypesafeDraggableData,
|
||||||
|
TypesafeDroppableData,
|
||||||
|
} from 'features/dnd/types';
|
||||||
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
|
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
|
||||||
import {
|
import {
|
||||||
MouseEvent,
|
MouseEvent,
|
||||||
@ -157,11 +157,10 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
|||||||
<IAILoadingImageFallback image={imageDTO} />
|
<IAILoadingImageFallback image={imageDTO} />
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
width={imageDTO.width}
|
|
||||||
height={imageDTO.height}
|
|
||||||
onError={onError}
|
onError={onError}
|
||||||
draggable={false}
|
draggable={false}
|
||||||
sx={{
|
sx={{
|
||||||
|
w: imageDTO.width,
|
||||||
objectFit: 'contain',
|
objectFit: 'contain',
|
||||||
maxW: 'full',
|
maxW: 'full',
|
||||||
maxH: 'full',
|
maxH: 'full',
|
||||||
@ -213,13 +212,6 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
|||||||
onClick={onClick}
|
onClick={onClick}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{!isDropDisabled && (
|
|
||||||
<IAIDroppable
|
|
||||||
data={droppableData}
|
|
||||||
disabled={isDropDisabled}
|
|
||||||
dropLabel={dropLabel}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{onClickReset && withResetIcon && imageDTO && (
|
{onClickReset && withResetIcon && imageDTO && (
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
onClick={onClickReset}
|
onClick={onClickReset}
|
||||||
@ -244,6 +236,13 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
|||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
|
{!isDropDisabled && (
|
||||||
|
<IAIDroppable
|
||||||
|
data={droppableData}
|
||||||
|
disabled={isDropDisabled}
|
||||||
|
dropLabel={dropLabel}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
)}
|
)}
|
||||||
</ImageContextMenu>
|
</ImageContextMenu>
|
||||||
|
@ -1,22 +1,19 @@
|
|||||||
import { Box } from '@chakra-ui/react';
|
import { Box, BoxProps } from '@chakra-ui/react';
|
||||||
import {
|
import { useDraggableTypesafe } from 'features/dnd/hooks/typesafeHooks';
|
||||||
TypesafeDraggableData,
|
import { TypesafeDraggableData } from 'features/dnd/types';
|
||||||
useDraggable,
|
import { memo, useRef } from 'react';
|
||||||
} from 'app/components/ImageDnd/typesafeDnd';
|
|
||||||
import { MouseEvent, memo, useRef } from 'react';
|
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
|
||||||
type IAIDraggableProps = {
|
type IAIDraggableProps = BoxProps & {
|
||||||
disabled?: boolean;
|
disabled?: boolean;
|
||||||
data?: TypesafeDraggableData;
|
data?: TypesafeDraggableData;
|
||||||
onClick?: (event: MouseEvent<HTMLDivElement>) => void;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const IAIDraggable = (props: IAIDraggableProps) => {
|
const IAIDraggable = (props: IAIDraggableProps) => {
|
||||||
const { data, disabled, onClick } = props;
|
const { data, disabled, ...rest } = props;
|
||||||
const dndId = useRef(uuidv4());
|
const dndId = useRef(uuidv4());
|
||||||
|
|
||||||
const { attributes, listeners, setNodeRef } = useDraggable({
|
const { attributes, listeners, setNodeRef } = useDraggableTypesafe({
|
||||||
id: dndId.current,
|
id: dndId.current,
|
||||||
disabled,
|
disabled,
|
||||||
data,
|
data,
|
||||||
@ -24,7 +21,6 @@ const IAIDraggable = (props: IAIDraggableProps) => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<Box
|
<Box
|
||||||
onClick={onClick}
|
|
||||||
ref={setNodeRef}
|
ref={setNodeRef}
|
||||||
position="absolute"
|
position="absolute"
|
||||||
w="full"
|
w="full"
|
||||||
@ -33,6 +29,7 @@ const IAIDraggable = (props: IAIDraggableProps) => {
|
|||||||
insetInlineStart={0}
|
insetInlineStart={0}
|
||||||
{...attributes}
|
{...attributes}
|
||||||
{...listeners}
|
{...listeners}
|
||||||
|
{...rest}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -1,9 +1,7 @@
|
|||||||
import { Box } from '@chakra-ui/react';
|
import { Box } from '@chakra-ui/react';
|
||||||
import {
|
import { useDroppableTypesafe } from 'features/dnd/hooks/typesafeHooks';
|
||||||
TypesafeDroppableData,
|
import { TypesafeDroppableData } from 'features/dnd/types';
|
||||||
isValidDrop,
|
import { isValidDrop } from 'features/dnd/util/isValidDrop';
|
||||||
useDroppable,
|
|
||||||
} from 'app/components/ImageDnd/typesafeDnd';
|
|
||||||
import { AnimatePresence } from 'framer-motion';
|
import { AnimatePresence } from 'framer-motion';
|
||||||
import { ReactNode, memo, useRef } from 'react';
|
import { ReactNode, memo, useRef } from 'react';
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
@ -19,7 +17,7 @@ const IAIDroppable = (props: IAIDroppableProps) => {
|
|||||||
const { dropLabel, data, disabled } = props;
|
const { dropLabel, data, disabled } = props;
|
||||||
const dndId = useRef(uuidv4());
|
const dndId = useRef(uuidv4());
|
||||||
|
|
||||||
const { isOver, setNodeRef, active } = useDroppable({
|
const { isOver, setNodeRef, active } = useDroppableTypesafe({
|
||||||
id: dndId.current,
|
id: dndId.current,
|
||||||
disabled,
|
disabled,
|
||||||
data,
|
data,
|
||||||
|
@ -49,7 +49,7 @@ export const IAILoadingImageFallback = (props: Props) => {
|
|||||||
|
|
||||||
type IAINoImageFallbackProps = {
|
type IAINoImageFallbackProps = {
|
||||||
label?: string;
|
label?: string;
|
||||||
icon?: As;
|
icon?: As | null;
|
||||||
boxSize?: StyleProps['boxSize'];
|
boxSize?: StyleProps['boxSize'];
|
||||||
sx?: ChakraProps['sx'];
|
sx?: ChakraProps['sx'];
|
||||||
};
|
};
|
||||||
@ -76,7 +76,7 @@ export const IAINoContentFallback = (props: IAINoImageFallbackProps) => {
|
|||||||
...props.sx,
|
...props.sx,
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Icon as={icon} boxSize={boxSize} opacity={0.7} />
|
{icon && <Icon as={icon} boxSize={boxSize} opacity={0.7} />}
|
||||||
{props.label && <Text textAlign="center">{props.label}</Text>}
|
{props.label && <Text textAlign="center">{props.label}</Text>}
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
import {
|
import {
|
||||||
|
Flex,
|
||||||
FormControl,
|
FormControl,
|
||||||
FormControlProps,
|
FormControlProps,
|
||||||
|
FormHelperText,
|
||||||
FormLabel,
|
FormLabel,
|
||||||
FormLabelProps,
|
FormLabelProps,
|
||||||
Switch,
|
Switch,
|
||||||
SwitchProps,
|
SwitchProps,
|
||||||
|
Text,
|
||||||
Tooltip,
|
Tooltip,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
@ -15,6 +18,7 @@ export interface IAISwitchProps extends SwitchProps {
|
|||||||
formControlProps?: FormControlProps;
|
formControlProps?: FormControlProps;
|
||||||
formLabelProps?: FormLabelProps;
|
formLabelProps?: FormLabelProps;
|
||||||
tooltip?: string;
|
tooltip?: string;
|
||||||
|
helperText?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -28,6 +32,7 @@ const IAISwitch = (props: IAISwitchProps) => {
|
|||||||
formControlProps,
|
formControlProps,
|
||||||
formLabelProps,
|
formLabelProps,
|
||||||
tooltip,
|
tooltip,
|
||||||
|
helperText,
|
||||||
...rest
|
...rest
|
||||||
} = props;
|
} = props;
|
||||||
return (
|
return (
|
||||||
@ -35,25 +40,33 @@ const IAISwitch = (props: IAISwitchProps) => {
|
|||||||
<FormControl
|
<FormControl
|
||||||
isDisabled={isDisabled}
|
isDisabled={isDisabled}
|
||||||
width={width}
|
width={width}
|
||||||
display="flex"
|
|
||||||
alignItems="center"
|
alignItems="center"
|
||||||
{...formControlProps}
|
{...formControlProps}
|
||||||
>
|
>
|
||||||
{label && (
|
<Flex sx={{ flexDir: 'column', w: 'full' }}>
|
||||||
<FormLabel
|
<Flex sx={{ alignItems: 'center', w: 'full' }}>
|
||||||
my={1}
|
{label && (
|
||||||
flexGrow={1}
|
<FormLabel
|
||||||
sx={{
|
my={1}
|
||||||
cursor: isDisabled ? 'not-allowed' : 'pointer',
|
flexGrow={1}
|
||||||
...formLabelProps?.sx,
|
sx={{
|
||||||
pe: 4,
|
cursor: isDisabled ? 'not-allowed' : 'pointer',
|
||||||
}}
|
...formLabelProps?.sx,
|
||||||
{...formLabelProps}
|
pe: 4,
|
||||||
>
|
}}
|
||||||
{label}
|
{...formLabelProps}
|
||||||
</FormLabel>
|
>
|
||||||
)}
|
{label}
|
||||||
<Switch {...rest} />
|
</FormLabel>
|
||||||
|
)}
|
||||||
|
<Switch {...rest} />
|
||||||
|
</Flex>
|
||||||
|
{helperText && (
|
||||||
|
<FormHelperText>
|
||||||
|
<Text variant="subtext">{helperText}</Text>
|
||||||
|
</FormHelperText>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
);
|
);
|
||||||
|
@ -40,6 +40,44 @@ export const useChakraThemeTokens = () => {
|
|||||||
accent850,
|
accent850,
|
||||||
accent900,
|
accent900,
|
||||||
accent950,
|
accent950,
|
||||||
|
baseAlpha50,
|
||||||
|
baseAlpha100,
|
||||||
|
baseAlpha150,
|
||||||
|
baseAlpha200,
|
||||||
|
baseAlpha250,
|
||||||
|
baseAlpha300,
|
||||||
|
baseAlpha350,
|
||||||
|
baseAlpha400,
|
||||||
|
baseAlpha450,
|
||||||
|
baseAlpha500,
|
||||||
|
baseAlpha550,
|
||||||
|
baseAlpha600,
|
||||||
|
baseAlpha650,
|
||||||
|
baseAlpha700,
|
||||||
|
baseAlpha750,
|
||||||
|
baseAlpha800,
|
||||||
|
baseAlpha850,
|
||||||
|
baseAlpha900,
|
||||||
|
baseAlpha950,
|
||||||
|
accentAlpha50,
|
||||||
|
accentAlpha100,
|
||||||
|
accentAlpha150,
|
||||||
|
accentAlpha200,
|
||||||
|
accentAlpha250,
|
||||||
|
accentAlpha300,
|
||||||
|
accentAlpha350,
|
||||||
|
accentAlpha400,
|
||||||
|
accentAlpha450,
|
||||||
|
accentAlpha500,
|
||||||
|
accentAlpha550,
|
||||||
|
accentAlpha600,
|
||||||
|
accentAlpha650,
|
||||||
|
accentAlpha700,
|
||||||
|
accentAlpha750,
|
||||||
|
accentAlpha800,
|
||||||
|
accentAlpha850,
|
||||||
|
accentAlpha900,
|
||||||
|
accentAlpha950,
|
||||||
] = useToken('colors', [
|
] = useToken('colors', [
|
||||||
'base.50',
|
'base.50',
|
||||||
'base.100',
|
'base.100',
|
||||||
@ -79,6 +117,44 @@ export const useChakraThemeTokens = () => {
|
|||||||
'accent.850',
|
'accent.850',
|
||||||
'accent.900',
|
'accent.900',
|
||||||
'accent.950',
|
'accent.950',
|
||||||
|
'baseAlpha.50',
|
||||||
|
'baseAlpha.100',
|
||||||
|
'baseAlpha.150',
|
||||||
|
'baseAlpha.200',
|
||||||
|
'baseAlpha.250',
|
||||||
|
'baseAlpha.300',
|
||||||
|
'baseAlpha.350',
|
||||||
|
'baseAlpha.400',
|
||||||
|
'baseAlpha.450',
|
||||||
|
'baseAlpha.500',
|
||||||
|
'baseAlpha.550',
|
||||||
|
'baseAlpha.600',
|
||||||
|
'baseAlpha.650',
|
||||||
|
'baseAlpha.700',
|
||||||
|
'baseAlpha.750',
|
||||||
|
'baseAlpha.800',
|
||||||
|
'baseAlpha.850',
|
||||||
|
'baseAlpha.900',
|
||||||
|
'baseAlpha.950',
|
||||||
|
'accentAlpha.50',
|
||||||
|
'accentAlpha.100',
|
||||||
|
'accentAlpha.150',
|
||||||
|
'accentAlpha.200',
|
||||||
|
'accentAlpha.250',
|
||||||
|
'accentAlpha.300',
|
||||||
|
'accentAlpha.350',
|
||||||
|
'accentAlpha.400',
|
||||||
|
'accentAlpha.450',
|
||||||
|
'accentAlpha.500',
|
||||||
|
'accentAlpha.550',
|
||||||
|
'accentAlpha.600',
|
||||||
|
'accentAlpha.650',
|
||||||
|
'accentAlpha.700',
|
||||||
|
'accentAlpha.750',
|
||||||
|
'accentAlpha.800',
|
||||||
|
'accentAlpha.850',
|
||||||
|
'accentAlpha.900',
|
||||||
|
'accentAlpha.950',
|
||||||
]);
|
]);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -120,5 +196,43 @@ export const useChakraThemeTokens = () => {
|
|||||||
accent850,
|
accent850,
|
||||||
accent900,
|
accent900,
|
||||||
accent950,
|
accent950,
|
||||||
|
baseAlpha50,
|
||||||
|
baseAlpha100,
|
||||||
|
baseAlpha150,
|
||||||
|
baseAlpha200,
|
||||||
|
baseAlpha250,
|
||||||
|
baseAlpha300,
|
||||||
|
baseAlpha350,
|
||||||
|
baseAlpha400,
|
||||||
|
baseAlpha450,
|
||||||
|
baseAlpha500,
|
||||||
|
baseAlpha550,
|
||||||
|
baseAlpha600,
|
||||||
|
baseAlpha650,
|
||||||
|
baseAlpha700,
|
||||||
|
baseAlpha750,
|
||||||
|
baseAlpha800,
|
||||||
|
baseAlpha850,
|
||||||
|
baseAlpha900,
|
||||||
|
baseAlpha950,
|
||||||
|
accentAlpha50,
|
||||||
|
accentAlpha100,
|
||||||
|
accentAlpha150,
|
||||||
|
accentAlpha200,
|
||||||
|
accentAlpha250,
|
||||||
|
accentAlpha300,
|
||||||
|
accentAlpha350,
|
||||||
|
accentAlpha400,
|
||||||
|
accentAlpha450,
|
||||||
|
accentAlpha500,
|
||||||
|
accentAlpha550,
|
||||||
|
accentAlpha600,
|
||||||
|
accentAlpha650,
|
||||||
|
accentAlpha700,
|
||||||
|
accentAlpha750,
|
||||||
|
accentAlpha800,
|
||||||
|
accentAlpha850,
|
||||||
|
accentAlpha900,
|
||||||
|
accentAlpha950,
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
@ -1,4 +1,10 @@
|
|||||||
/**
|
/**
|
||||||
* Serialize an object to JSON and back to a new object
|
* Serialize an object to JSON and back to a new object
|
||||||
*/
|
*/
|
||||||
export const parseify = (obj: unknown) => JSON.parse(JSON.stringify(obj));
|
export const parseify = (obj: unknown) => {
|
||||||
|
try {
|
||||||
|
return JSON.parse(JSON.stringify(obj));
|
||||||
|
} catch {
|
||||||
|
return 'Error parsing object';
|
||||||
|
}
|
||||||
|
};
|
||||||
|
@ -4,7 +4,7 @@ import { skipToken } from '@reduxjs/toolkit/dist/query';
|
|||||||
import {
|
import {
|
||||||
TypesafeDraggableData,
|
TypesafeDraggableData,
|
||||||
TypesafeDroppableData,
|
TypesafeDroppableData,
|
||||||
} from 'app/components/ImageDnd/typesafeDnd';
|
} from 'features/dnd/types';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
@ -138,7 +138,7 @@ export type RequiredZoeDepthImageProcessorInvocation = O.Required<
|
|||||||
/**
|
/**
|
||||||
* Any ControlNet Processor node, with its parameters flagged as required
|
* Any ControlNet Processor node, with its parameters flagged as required
|
||||||
*/
|
*/
|
||||||
export type RequiredControlNetProcessorNode =
|
export type RequiredControlNetProcessorNode = O.Required<
|
||||||
| RequiredCannyImageProcessorInvocation
|
| RequiredCannyImageProcessorInvocation
|
||||||
| RequiredContentShuffleImageProcessorInvocation
|
| RequiredContentShuffleImageProcessorInvocation
|
||||||
| RequiredHedImageProcessorInvocation
|
| RequiredHedImageProcessorInvocation
|
||||||
@ -150,7 +150,9 @@ export type RequiredControlNetProcessorNode =
|
|||||||
| RequiredNormalbaeImageProcessorInvocation
|
| RequiredNormalbaeImageProcessorInvocation
|
||||||
| RequiredOpenposeImageProcessorInvocation
|
| RequiredOpenposeImageProcessorInvocation
|
||||||
| RequiredPidiImageProcessorInvocation
|
| RequiredPidiImageProcessorInvocation
|
||||||
| RequiredZoeDepthImageProcessorInvocation;
|
| RequiredZoeDepthImageProcessorInvocation,
|
||||||
|
'id'
|
||||||
|
>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Type guard for CannyImageProcessorInvocation
|
* Type guard for CannyImageProcessorInvocation
|
||||||
|
@ -3,6 +3,7 @@ import { RootState } from 'app/store/store';
|
|||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import { some } from 'lodash-es';
|
import { some } from 'lodash-es';
|
||||||
import { ImageUsage } from './types';
|
import { ImageUsage } from './types';
|
||||||
|
import { isInvocationNode } from 'features/nodes/types/types';
|
||||||
|
|
||||||
export const getImageUsage = (state: RootState, image_name: string) => {
|
export const getImageUsage = (state: RootState, image_name: string) => {
|
||||||
const { generation, canvas, nodes, controlNet } = state;
|
const { generation, canvas, nodes, controlNet } = state;
|
||||||
@ -12,11 +13,11 @@ export const getImageUsage = (state: RootState, image_name: string) => {
|
|||||||
(obj) => obj.kind === 'image' && obj.imageName === image_name
|
(obj) => obj.kind === 'image' && obj.imageName === image_name
|
||||||
);
|
);
|
||||||
|
|
||||||
const isNodesImage = nodes.nodes.some((node) => {
|
const isNodesImage = nodes.nodes.filter(isInvocationNode).some((node) => {
|
||||||
return some(
|
return some(
|
||||||
node.data.inputs,
|
node.data.inputs,
|
||||||
(input) =>
|
(input) =>
|
||||||
input.type === 'image' && input.value?.image_name === image_name
|
input.type === 'ImageField' && input.value?.image_name === image_name
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -6,23 +6,18 @@ import {
|
|||||||
useSensor,
|
useSensor,
|
||||||
useSensors,
|
useSensors,
|
||||||
} from '@dnd-kit/core';
|
} from '@dnd-kit/core';
|
||||||
import { snapCenterToCursor } from '@dnd-kit/modifiers';
|
import { logger } from 'app/logging/logger';
|
||||||
import { dndDropped } from 'app/store/middleware/listenerMiddleware/listeners/imageDropped';
|
import { dndDropped } from 'app/store/middleware/listenerMiddleware/listeners/imageDropped';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { parseify } from 'common/util/serialize';
|
||||||
import { AnimatePresence, motion } from 'framer-motion';
|
import { AnimatePresence, motion } from 'framer-motion';
|
||||||
import { PropsWithChildren, memo, useCallback, useState } from 'react';
|
import { PropsWithChildren, memo, useCallback, useState } from 'react';
|
||||||
|
import { useScaledModifer } from '../hooks/useScaledCenteredModifer';
|
||||||
|
import { DragEndEvent, DragStartEvent, TypesafeDraggableData } from '../types';
|
||||||
|
import { DndContextTypesafe } from './DndContextTypesafe';
|
||||||
import DragPreview from './DragPreview';
|
import DragPreview from './DragPreview';
|
||||||
import {
|
|
||||||
DndContext,
|
|
||||||
DragEndEvent,
|
|
||||||
DragStartEvent,
|
|
||||||
TypesafeDraggableData,
|
|
||||||
} from './typesafeDnd';
|
|
||||||
import { logger } from 'app/logging/logger';
|
|
||||||
|
|
||||||
type ImageDndContextProps = PropsWithChildren;
|
const AppDndContext = (props: PropsWithChildren) => {
|
||||||
|
|
||||||
const ImageDndContext = (props: ImageDndContextProps) => {
|
|
||||||
const [activeDragData, setActiveDragData] =
|
const [activeDragData, setActiveDragData] =
|
||||||
useState<TypesafeDraggableData | null>(null);
|
useState<TypesafeDraggableData | null>(null);
|
||||||
const log = logger('images');
|
const log = logger('images');
|
||||||
@ -31,7 +26,10 @@ const ImageDndContext = (props: ImageDndContextProps) => {
|
|||||||
|
|
||||||
const handleDragStart = useCallback(
|
const handleDragStart = useCallback(
|
||||||
(event: DragStartEvent) => {
|
(event: DragStartEvent) => {
|
||||||
log.trace({ dragData: event.active.data.current }, 'Drag started');
|
log.trace(
|
||||||
|
{ dragData: parseify(event.active.data.current) },
|
||||||
|
'Drag started'
|
||||||
|
);
|
||||||
const activeData = event.active.data.current;
|
const activeData = event.active.data.current;
|
||||||
if (!activeData) {
|
if (!activeData) {
|
||||||
return;
|
return;
|
||||||
@ -43,7 +41,10 @@ const ImageDndContext = (props: ImageDndContextProps) => {
|
|||||||
|
|
||||||
const handleDragEnd = useCallback(
|
const handleDragEnd = useCallback(
|
||||||
(event: DragEndEvent) => {
|
(event: DragEndEvent) => {
|
||||||
log.trace({ dragData: event.active.data.current }, 'Drag ended');
|
log.trace(
|
||||||
|
{ dragData: parseify(event.active.data.current) },
|
||||||
|
'Drag ended'
|
||||||
|
);
|
||||||
const overData = event.over?.data.current;
|
const overData = event.over?.data.current;
|
||||||
if (!activeDragData || !overData) {
|
if (!activeDragData || !overData) {
|
||||||
return;
|
return;
|
||||||
@ -69,15 +70,29 @@ const ImageDndContext = (props: ImageDndContextProps) => {
|
|||||||
|
|
||||||
const sensors = useSensors(mouseSensor, touchSensor);
|
const sensors = useSensors(mouseSensor, touchSensor);
|
||||||
|
|
||||||
|
const scaledModifier = useScaledModifer();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<DndContext
|
<DndContextTypesafe
|
||||||
onDragStart={handleDragStart}
|
onDragStart={handleDragStart}
|
||||||
onDragEnd={handleDragEnd}
|
onDragEnd={handleDragEnd}
|
||||||
sensors={sensors}
|
sensors={sensors}
|
||||||
collisionDetection={pointerWithin}
|
collisionDetection={pointerWithin}
|
||||||
|
autoScroll={false}
|
||||||
>
|
>
|
||||||
{props.children}
|
{props.children}
|
||||||
<DragOverlay dropAnimation={null} modifiers={[snapCenterToCursor]}>
|
<DragOverlay
|
||||||
|
dropAnimation={null}
|
||||||
|
modifiers={[scaledModifier]}
|
||||||
|
style={{
|
||||||
|
width: 'min-content',
|
||||||
|
height: 'min-content',
|
||||||
|
cursor: 'none',
|
||||||
|
userSelect: 'none',
|
||||||
|
// expand overlay to prevent cursor from going outside it and displaying
|
||||||
|
padding: '10rem',
|
||||||
|
}}
|
||||||
|
>
|
||||||
<AnimatePresence>
|
<AnimatePresence>
|
||||||
{activeDragData && (
|
{activeDragData && (
|
||||||
<motion.div
|
<motion.div
|
||||||
@ -98,8 +113,8 @@ const ImageDndContext = (props: ImageDndContextProps) => {
|
|||||||
)}
|
)}
|
||||||
</AnimatePresence>
|
</AnimatePresence>
|
||||||
</DragOverlay>
|
</DragOverlay>
|
||||||
</DndContext>
|
</DndContextTypesafe>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default memo(ImageDndContext);
|
export default memo(AppDndContext);
|
@ -0,0 +1,6 @@
|
|||||||
|
import { DndContext } from '@dnd-kit/core';
|
||||||
|
import { DndContextTypesafeProps } from '../types';
|
||||||
|
|
||||||
|
export function DndContextTypesafe(props: DndContextTypesafeProps) {
|
||||||
|
return <DndContext {...props} />;
|
||||||
|
}
|
@ -1,6 +1,6 @@
|
|||||||
import { Box, ChakraProps, Flex, Heading, Image } from '@chakra-ui/react';
|
import { Box, ChakraProps, Flex, Heading, Image, Text } from '@chakra-ui/react';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { TypesafeDraggableData } from './typesafeDnd';
|
import { TypesafeDraggableData } from '../types';
|
||||||
|
|
||||||
type OverlayDragImageProps = {
|
type OverlayDragImageProps = {
|
||||||
dragData: TypesafeDraggableData | null;
|
dragData: TypesafeDraggableData | null;
|
||||||
@ -30,19 +30,38 @@ const DragPreview = (props: OverlayDragImageProps) => {
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (props.dragData.payloadType === 'NODE_FIELD') {
|
||||||
|
const { field, fieldTemplate } = props.dragData.payload;
|
||||||
|
return (
|
||||||
|
<Box
|
||||||
|
sx={{
|
||||||
|
position: 'relative',
|
||||||
|
p: 2,
|
||||||
|
px: 3,
|
||||||
|
opacity: 0.7,
|
||||||
|
bg: 'base.300',
|
||||||
|
borderRadius: 'base',
|
||||||
|
boxShadow: 'dark-lg',
|
||||||
|
whiteSpace: 'nowrap',
|
||||||
|
fontSize: 'sm',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Text>{field.label || fieldTemplate.title}</Text>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if (props.dragData.payloadType === 'IMAGE_DTO') {
|
if (props.dragData.payloadType === 'IMAGE_DTO') {
|
||||||
const { thumbnail_url, width, height } = props.dragData.payload.imageDTO;
|
const { thumbnail_url, width, height } = props.dragData.payload.imageDTO;
|
||||||
return (
|
return (
|
||||||
<Box
|
<Box
|
||||||
sx={{
|
sx={{
|
||||||
position: 'relative',
|
position: 'relative',
|
||||||
width: '100%',
|
width: 'full',
|
||||||
height: '100%',
|
height: 'full',
|
||||||
display: 'flex',
|
display: 'flex',
|
||||||
alignItems: 'center',
|
alignItems: 'center',
|
||||||
justifyContent: 'center',
|
justifyContent: 'center',
|
||||||
userSelect: 'none',
|
|
||||||
cursor: 'none',
|
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Image
|
<Image
|
||||||
@ -62,8 +81,6 @@ const DragPreview = (props: OverlayDragImageProps) => {
|
|||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
sx={{
|
sx={{
|
||||||
cursor: 'none',
|
|
||||||
userSelect: 'none',
|
|
||||||
position: 'relative',
|
position: 'relative',
|
||||||
alignItems: 'center',
|
alignItems: 'center',
|
||||||
justifyContent: 'center',
|
justifyContent: 'center',
|
@ -0,0 +1,15 @@
|
|||||||
|
import { useDraggable, useDroppable } from '@dnd-kit/core';
|
||||||
|
import {
|
||||||
|
UseDraggableTypesafeArguments,
|
||||||
|
UseDraggableTypesafeReturnValue,
|
||||||
|
UseDroppableTypesafeArguments,
|
||||||
|
UseDroppableTypesafeReturnValue,
|
||||||
|
} from '../types';
|
||||||
|
|
||||||
|
export function useDroppableTypesafe(props: UseDroppableTypesafeArguments) {
|
||||||
|
return useDroppable(props) as UseDroppableTypesafeReturnValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useDraggableTypesafe(props: UseDraggableTypesafeArguments) {
|
||||||
|
return useDraggable(props) as UseDraggableTypesafeReturnValue;
|
||||||
|
}
|
@ -0,0 +1,50 @@
|
|||||||
|
import type { Modifier } from '@dnd-kit/core';
|
||||||
|
import { getEventCoordinates } from '@dnd-kit/utilities';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
|
import { useCallback } from 'react';
|
||||||
|
|
||||||
|
const selectZoom = createSelector(
|
||||||
|
[stateSelector, activeTabNameSelector],
|
||||||
|
({ nodes }, activeTabName) => (activeTabName === 'nodes' ? nodes.zoom : 1)
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Applies scaling to the drag transform (if on node editor tab) and centers it on cursor.
|
||||||
|
*/
|
||||||
|
export const useScaledModifer = () => {
|
||||||
|
const zoom = useAppSelector(selectZoom);
|
||||||
|
const modifier: Modifier = useCallback(
|
||||||
|
({ activatorEvent, draggingNodeRect, transform }) => {
|
||||||
|
if (draggingNodeRect && activatorEvent) {
|
||||||
|
const activatorCoordinates = getEventCoordinates(activatorEvent);
|
||||||
|
|
||||||
|
if (!activatorCoordinates) {
|
||||||
|
return transform;
|
||||||
|
}
|
||||||
|
|
||||||
|
const offsetX = activatorCoordinates.x - draggingNodeRect.left;
|
||||||
|
const offsetY = activatorCoordinates.y - draggingNodeRect.top;
|
||||||
|
|
||||||
|
const x = transform.x + offsetX - draggingNodeRect.width / 2;
|
||||||
|
const y = transform.y + offsetY - draggingNodeRect.height / 2;
|
||||||
|
const scaleX = transform.scaleX * zoom;
|
||||||
|
const scaleY = transform.scaleY * zoom;
|
||||||
|
|
||||||
|
return {
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
scaleX,
|
||||||
|
scaleY,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return transform;
|
||||||
|
},
|
||||||
|
[zoom]
|
||||||
|
);
|
||||||
|
|
||||||
|
return modifier;
|
||||||
|
};
|
@ -3,7 +3,6 @@ import {
|
|||||||
Active,
|
Active,
|
||||||
Collision,
|
Collision,
|
||||||
DndContextProps,
|
DndContextProps,
|
||||||
DndContext as OriginalDndContext,
|
|
||||||
Over,
|
Over,
|
||||||
Translate,
|
Translate,
|
||||||
UseDraggableArguments,
|
UseDraggableArguments,
|
||||||
@ -11,6 +10,10 @@ import {
|
|||||||
useDraggable as useOriginalDraggable,
|
useDraggable as useOriginalDraggable,
|
||||||
useDroppable as useOriginalDroppable,
|
useDroppable as useOriginalDroppable,
|
||||||
} from '@dnd-kit/core';
|
} from '@dnd-kit/core';
|
||||||
|
import {
|
||||||
|
InputFieldTemplate,
|
||||||
|
InputFieldValue,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
|
|
||||||
type BaseDropData = {
|
type BaseDropData = {
|
||||||
@ -62,6 +65,10 @@ export type RemoveFromBoardDropData = BaseDropData & {
|
|||||||
actionType: 'REMOVE_FROM_BOARD';
|
actionType: 'REMOVE_FROM_BOARD';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type AddFieldToLinearViewDropData = BaseDropData & {
|
||||||
|
actionType: 'ADD_FIELD_TO_LINEAR';
|
||||||
|
};
|
||||||
|
|
||||||
export type TypesafeDroppableData =
|
export type TypesafeDroppableData =
|
||||||
| CurrentImageDropData
|
| CurrentImageDropData
|
||||||
| InitialImageDropData
|
| InitialImageDropData
|
||||||
@ -71,12 +78,22 @@ export type TypesafeDroppableData =
|
|||||||
| AddToBatchDropData
|
| AddToBatchDropData
|
||||||
| NodesMultiImageDropData
|
| NodesMultiImageDropData
|
||||||
| AddToBoardDropData
|
| AddToBoardDropData
|
||||||
| RemoveFromBoardDropData;
|
| RemoveFromBoardDropData
|
||||||
|
| AddFieldToLinearViewDropData;
|
||||||
|
|
||||||
type BaseDragData = {
|
type BaseDragData = {
|
||||||
id: string;
|
id: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type NodeFieldDraggableData = BaseDragData & {
|
||||||
|
payloadType: 'NODE_FIELD';
|
||||||
|
payload: {
|
||||||
|
nodeId: string;
|
||||||
|
field: InputFieldValue;
|
||||||
|
fieldTemplate: InputFieldTemplate;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
export type ImageDraggableData = BaseDragData & {
|
export type ImageDraggableData = BaseDragData & {
|
||||||
payloadType: 'IMAGE_DTO';
|
payloadType: 'IMAGE_DTO';
|
||||||
payload: { imageDTO: ImageDTO };
|
payload: { imageDTO: ImageDTO };
|
||||||
@ -87,14 +104,17 @@ export type ImageDTOsDraggableData = BaseDragData & {
|
|||||||
payload: { imageDTOs: ImageDTO[] };
|
payload: { imageDTOs: ImageDTO[] };
|
||||||
};
|
};
|
||||||
|
|
||||||
export type TypesafeDraggableData = ImageDraggableData | ImageDTOsDraggableData;
|
export type TypesafeDraggableData =
|
||||||
|
| NodeFieldDraggableData
|
||||||
|
| ImageDraggableData
|
||||||
|
| ImageDTOsDraggableData;
|
||||||
|
|
||||||
interface UseDroppableTypesafeArguments
|
export interface UseDroppableTypesafeArguments
|
||||||
extends Omit<UseDroppableArguments, 'data'> {
|
extends Omit<UseDroppableArguments, 'data'> {
|
||||||
data?: TypesafeDroppableData;
|
data?: TypesafeDroppableData;
|
||||||
}
|
}
|
||||||
|
|
||||||
type UseDroppableTypesafeReturnValue = Omit<
|
export type UseDroppableTypesafeReturnValue = Omit<
|
||||||
ReturnType<typeof useOriginalDroppable>,
|
ReturnType<typeof useOriginalDroppable>,
|
||||||
'active' | 'over'
|
'active' | 'over'
|
||||||
> & {
|
> & {
|
||||||
@ -102,16 +122,12 @@ type UseDroppableTypesafeReturnValue = Omit<
|
|||||||
over: TypesafeOver | null;
|
over: TypesafeOver | null;
|
||||||
};
|
};
|
||||||
|
|
||||||
export function useDroppable(props: UseDroppableTypesafeArguments) {
|
export interface UseDraggableTypesafeArguments
|
||||||
return useOriginalDroppable(props) as UseDroppableTypesafeReturnValue;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface UseDraggableTypesafeArguments
|
|
||||||
extends Omit<UseDraggableArguments, 'data'> {
|
extends Omit<UseDraggableArguments, 'data'> {
|
||||||
data?: TypesafeDraggableData;
|
data?: TypesafeDraggableData;
|
||||||
}
|
}
|
||||||
|
|
||||||
type UseDraggableTypesafeReturnValue = Omit<
|
export type UseDraggableTypesafeReturnValue = Omit<
|
||||||
ReturnType<typeof useOriginalDraggable>,
|
ReturnType<typeof useOriginalDraggable>,
|
||||||
'active' | 'over'
|
'active' | 'over'
|
||||||
> & {
|
> & {
|
||||||
@ -119,102 +135,14 @@ type UseDraggableTypesafeReturnValue = Omit<
|
|||||||
over: TypesafeOver | null;
|
over: TypesafeOver | null;
|
||||||
};
|
};
|
||||||
|
|
||||||
export function useDraggable(props: UseDraggableTypesafeArguments) {
|
export interface TypesafeActive extends Omit<Active, 'data'> {
|
||||||
return useOriginalDraggable(props) as UseDraggableTypesafeReturnValue;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface TypesafeActive extends Omit<Active, 'data'> {
|
|
||||||
data: React.MutableRefObject<TypesafeDraggableData | undefined>;
|
data: React.MutableRefObject<TypesafeDraggableData | undefined>;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface TypesafeOver extends Omit<Over, 'data'> {
|
export interface TypesafeOver extends Omit<Over, 'data'> {
|
||||||
data: React.MutableRefObject<TypesafeDroppableData | undefined>;
|
data: React.MutableRefObject<TypesafeDroppableData | undefined>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const isValidDrop = (
|
|
||||||
overData: TypesafeDroppableData | undefined,
|
|
||||||
active: TypesafeActive | null
|
|
||||||
) => {
|
|
||||||
if (!overData || !active?.data.current) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
const { actionType } = overData;
|
|
||||||
const { payloadType } = active.data.current;
|
|
||||||
|
|
||||||
if (overData.id === active.data.current.id) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
switch (actionType) {
|
|
||||||
case 'SET_CURRENT_IMAGE':
|
|
||||||
return payloadType === 'IMAGE_DTO';
|
|
||||||
case 'SET_INITIAL_IMAGE':
|
|
||||||
return payloadType === 'IMAGE_DTO';
|
|
||||||
case 'SET_CONTROLNET_IMAGE':
|
|
||||||
return payloadType === 'IMAGE_DTO';
|
|
||||||
case 'SET_CANVAS_INITIAL_IMAGE':
|
|
||||||
return payloadType === 'IMAGE_DTO';
|
|
||||||
case 'SET_NODES_IMAGE':
|
|
||||||
return payloadType === 'IMAGE_DTO';
|
|
||||||
case 'SET_MULTI_NODES_IMAGE':
|
|
||||||
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
|
||||||
case 'ADD_TO_BATCH':
|
|
||||||
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
|
||||||
case 'ADD_TO_BOARD': {
|
|
||||||
// If the board is the same, don't allow the drop
|
|
||||||
|
|
||||||
// Check the payload types
|
|
||||||
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
|
||||||
if (!isPayloadValid) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the image's board is the board we are dragging onto
|
|
||||||
if (payloadType === 'IMAGE_DTO') {
|
|
||||||
const { imageDTO } = active.data.current.payload;
|
|
||||||
const currentBoard = imageDTO.board_id ?? 'none';
|
|
||||||
const destinationBoard = overData.context.boardId;
|
|
||||||
|
|
||||||
return currentBoard !== destinationBoard;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (payloadType === 'IMAGE_DTOS') {
|
|
||||||
// TODO (multi-select)
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
case 'REMOVE_FROM_BOARD': {
|
|
||||||
// If the board is the same, don't allow the drop
|
|
||||||
|
|
||||||
// Check the payload types
|
|
||||||
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
|
||||||
if (!isPayloadValid) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the image's board is the board we are dragging onto
|
|
||||||
if (payloadType === 'IMAGE_DTO') {
|
|
||||||
const { imageDTO } = active.data.current.payload;
|
|
||||||
const currentBoard = imageDTO.board_id;
|
|
||||||
|
|
||||||
return currentBoard !== 'none';
|
|
||||||
}
|
|
||||||
|
|
||||||
if (payloadType === 'IMAGE_DTOS') {
|
|
||||||
// TODO (multi-select)
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
interface DragEvent {
|
interface DragEvent {
|
||||||
activatorEvent: Event;
|
activatorEvent: Event;
|
||||||
active: TypesafeActive;
|
active: TypesafeActive;
|
||||||
@ -240,6 +168,3 @@ export interface DndContextTypesafeProps
|
|||||||
onDragEnd?(event: DragEndEvent): void;
|
onDragEnd?(event: DragEndEvent): void;
|
||||||
onDragCancel?(event: DragCancelEvent): void;
|
onDragCancel?(event: DragCancelEvent): void;
|
||||||
}
|
}
|
||||||
export function DndContext(props: DndContextTypesafeProps) {
|
|
||||||
return <OriginalDndContext {...props} />;
|
|
||||||
}
|
|
87
invokeai/frontend/web/src/features/dnd/util/isValidDrop.ts
Normal file
87
invokeai/frontend/web/src/features/dnd/util/isValidDrop.ts
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
import { TypesafeActive, TypesafeDroppableData } from '../types';
|
||||||
|
|
||||||
|
export const isValidDrop = (
|
||||||
|
overData: TypesafeDroppableData | undefined,
|
||||||
|
active: TypesafeActive | null
|
||||||
|
) => {
|
||||||
|
if (!overData || !active?.data.current) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { actionType } = overData;
|
||||||
|
const { payloadType } = active.data.current;
|
||||||
|
|
||||||
|
if (overData.id === active.data.current.id) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (actionType) {
|
||||||
|
case 'ADD_FIELD_TO_LINEAR':
|
||||||
|
return payloadType === 'NODE_FIELD';
|
||||||
|
case 'SET_CURRENT_IMAGE':
|
||||||
|
return payloadType === 'IMAGE_DTO';
|
||||||
|
case 'SET_INITIAL_IMAGE':
|
||||||
|
return payloadType === 'IMAGE_DTO';
|
||||||
|
case 'SET_CONTROLNET_IMAGE':
|
||||||
|
return payloadType === 'IMAGE_DTO';
|
||||||
|
case 'SET_CANVAS_INITIAL_IMAGE':
|
||||||
|
return payloadType === 'IMAGE_DTO';
|
||||||
|
case 'SET_NODES_IMAGE':
|
||||||
|
return payloadType === 'IMAGE_DTO';
|
||||||
|
case 'SET_MULTI_NODES_IMAGE':
|
||||||
|
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
||||||
|
case 'ADD_TO_BATCH':
|
||||||
|
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
||||||
|
case 'ADD_TO_BOARD': {
|
||||||
|
// If the board is the same, don't allow the drop
|
||||||
|
|
||||||
|
// Check the payload types
|
||||||
|
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
||||||
|
if (!isPayloadValid) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the image's board is the board we are dragging onto
|
||||||
|
if (payloadType === 'IMAGE_DTO') {
|
||||||
|
const { imageDTO } = active.data.current.payload;
|
||||||
|
const currentBoard = imageDTO.board_id ?? 'none';
|
||||||
|
const destinationBoard = overData.context.boardId;
|
||||||
|
|
||||||
|
return currentBoard !== destinationBoard;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (payloadType === 'IMAGE_DTOS') {
|
||||||
|
// TODO (multi-select)
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
case 'REMOVE_FROM_BOARD': {
|
||||||
|
// If the board is the same, don't allow the drop
|
||||||
|
|
||||||
|
// Check the payload types
|
||||||
|
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
||||||
|
if (!isPayloadValid) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the image's board is the board we are dragging onto
|
||||||
|
if (payloadType === 'IMAGE_DTO') {
|
||||||
|
const { imageDTO } = active.data.current.payload;
|
||||||
|
const currentBoard = imageDTO.board_id;
|
||||||
|
|
||||||
|
return currentBoard !== 'none';
|
||||||
|
}
|
||||||
|
|
||||||
|
if (payloadType === 'IMAGE_DTOS') {
|
||||||
|
// TODO (multi-select)
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
@ -11,7 +11,6 @@ import {
|
|||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
import { AddToBoardDropData } from 'app/components/ImageDnd/typesafeDnd';
|
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
@ -32,6 +31,7 @@ import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
|||||||
import { BoardDTO } from 'services/api/types';
|
import { BoardDTO } from 'services/api/types';
|
||||||
import AutoAddIcon from '../AutoAddIcon';
|
import AutoAddIcon from '../AutoAddIcon';
|
||||||
import BoardContextMenu from '../BoardContextMenu';
|
import BoardContextMenu from '../BoardContextMenu';
|
||||||
|
import { AddToBoardDropData } from 'features/dnd/types';
|
||||||
|
|
||||||
interface GalleryBoardProps {
|
interface GalleryBoardProps {
|
||||||
board: BoardDTO;
|
board: BoardDTO;
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import { As, Badge, Flex } from '@chakra-ui/react';
|
import { As, Badge, Flex } from '@chakra-ui/react';
|
||||||
import { TypesafeDroppableData } from 'app/components/ImageDnd/typesafeDnd';
|
|
||||||
import IAIDroppable from 'common/components/IAIDroppable';
|
import IAIDroppable from 'common/components/IAIDroppable';
|
||||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||||
|
import { TypesafeDroppableData } from 'features/dnd/types';
|
||||||
import { BoardId } from 'features/gallery/store/types';
|
import { BoardId } from 'features/gallery/store/types';
|
||||||
import { ReactNode } from 'react';
|
import { ReactNode } from 'react';
|
||||||
import BoardContextMenu from '../BoardContextMenu';
|
import BoardContextMenu from '../BoardContextMenu';
|
||||||
|
@ -1,15 +1,15 @@
|
|||||||
import { Box, Flex, Image, Text } from '@chakra-ui/react';
|
import { Box, Flex, Image, Text } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { RemoveFromBoardDropData } from 'app/components/ImageDnd/typesafeDnd';
|
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import InvokeAILogoImage from 'assets/images/logo.png';
|
import InvokeAILogoImage from 'assets/images/logo.png';
|
||||||
import IAIDroppable from 'common/components/IAIDroppable';
|
import IAIDroppable from 'common/components/IAIDroppable';
|
||||||
import SelectionOverlay from 'common/components/SelectionOverlay';
|
import SelectionOverlay from 'common/components/SelectionOverlay';
|
||||||
|
import { RemoveFromBoardDropData } from 'features/dnd/types';
|
||||||
import {
|
import {
|
||||||
boardIdSelected,
|
|
||||||
autoAddBoardIdChanged,
|
autoAddBoardIdChanged,
|
||||||
|
boardIdSelected,
|
||||||
} from 'features/gallery/store/gallerySlice';
|
} from 'features/gallery/store/gallerySlice';
|
||||||
import { memo, useCallback, useMemo, useState } from 'react';
|
import { memo, useCallback, useMemo, useState } from 'react';
|
||||||
import { useBoardName } from 'services/api/hooks/useBoardName';
|
import { useBoardName } from 'services/api/hooks/useBoardName';
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
import { Box, Flex, Image } from '@chakra-ui/react';
|
import { Box, Flex, Image } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
import {
|
|
||||||
TypesafeDraggableData,
|
|
||||||
TypesafeDroppableData,
|
|
||||||
} from 'app/components/ImageDnd/typesafeDnd';
|
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIDndImage from 'common/components/IAIDndImage';
|
import IAIDndImage from 'common/components/IAIDndImage';
|
||||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||||
|
import {
|
||||||
|
TypesafeDraggableData,
|
||||||
|
TypesafeDroppableData,
|
||||||
|
} from 'features/dnd/types';
|
||||||
import { useNextPrevImage } from 'features/gallery/hooks/useNextPrevImage';
|
import { useNextPrevImage } from 'features/gallery/hooks/useNextPrevImage';
|
||||||
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
|
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
|
||||||
import { AnimatePresence, motion } from 'framer-motion';
|
import { AnimatePresence, motion } from 'framer-motion';
|
||||||
|
@ -52,11 +52,13 @@ const ImageGalleryContent = () => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<VStack
|
<VStack
|
||||||
|
layerStyle="first"
|
||||||
sx={{
|
sx={{
|
||||||
flexDirection: 'column',
|
flexDirection: 'column',
|
||||||
h: 'full',
|
h: 'full',
|
||||||
w: 'full',
|
w: 'full',
|
||||||
borderRadius: 'base',
|
borderRadius: 'base',
|
||||||
|
p: 2,
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Box sx={{ w: 'full' }}>
|
<Box sx={{ w: 'full' }}>
|
||||||
|
@ -1,9 +1,4 @@
|
|||||||
import { Box, Flex } from '@chakra-ui/react';
|
import { Box, Flex } from '@chakra-ui/react';
|
||||||
import {
|
|
||||||
ImageDTOsDraggableData,
|
|
||||||
ImageDraggableData,
|
|
||||||
TypesafeDraggableData,
|
|
||||||
} from 'app/components/ImageDnd/typesafeDnd';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIDndImage from 'common/components/IAIDndImage';
|
import IAIDndImage from 'common/components/IAIDndImage';
|
||||||
import IAIFillSkeleton from 'common/components/IAIFillSkeleton';
|
import IAIFillSkeleton from 'common/components/IAIFillSkeleton';
|
||||||
@ -12,6 +7,11 @@ import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
|||||||
import { MouseEvent, memo, useCallback, useMemo } from 'react';
|
import { MouseEvent, memo, useCallback, useMemo } from 'react';
|
||||||
import { FaTrash } from 'react-icons/fa';
|
import { FaTrash } from 'react-icons/fa';
|
||||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||||
|
import {
|
||||||
|
ImageDTOsDraggableData,
|
||||||
|
ImageDraggableData,
|
||||||
|
TypesafeDraggableData,
|
||||||
|
} from 'features/dnd/types';
|
||||||
|
|
||||||
interface HoverableImageProps {
|
interface HoverableImageProps {
|
||||||
imageName: string;
|
imageName: string;
|
||||||
|
@ -26,7 +26,7 @@ const overlayScrollbarsConfig: UseOverlayScrollbarsParams = {
|
|||||||
options: {
|
options: {
|
||||||
scrollbars: {
|
scrollbars: {
|
||||||
visibility: 'auto',
|
visibility: 'auto',
|
||||||
autoHide: 'leave',
|
autoHide: 'scroll',
|
||||||
autoHideDelay: 1300,
|
autoHideDelay: 1300,
|
||||||
theme: 'os-theme-dark',
|
theme: 'os-theme-dark',
|
||||||
},
|
},
|
||||||
|
@ -1,26 +1,40 @@
|
|||||||
import { Box, Flex, IconButton, Tooltip } from '@chakra-ui/react';
|
import { Box, Flex, IconButton, Tooltip } from '@chakra-ui/react';
|
||||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||||
import { useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import { FaCopy } from 'react-icons/fa';
|
import { FaCopy, FaSave } from 'react-icons/fa';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
copyTooltip: string;
|
label: string;
|
||||||
jsonObject: object;
|
jsonObject: object;
|
||||||
|
fileName?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
const ImageMetadataJSON = (props: Props) => {
|
const ImageMetadataJSON = (props: Props) => {
|
||||||
const { copyTooltip, jsonObject } = props;
|
const { label, jsonObject, fileName } = props;
|
||||||
const jsonString = useMemo(
|
const jsonString = useMemo(
|
||||||
() => JSON.stringify(jsonObject, null, 2),
|
() => JSON.stringify(jsonObject, null, 2),
|
||||||
[jsonObject]
|
[jsonObject]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const handleCopy = useCallback(() => {
|
||||||
|
navigator.clipboard.writeText(jsonString);
|
||||||
|
}, [jsonString]);
|
||||||
|
|
||||||
|
const handleSave = useCallback(() => {
|
||||||
|
const blob = new Blob([jsonString]);
|
||||||
|
const a = document.createElement('a');
|
||||||
|
a.href = URL.createObjectURL(blob);
|
||||||
|
a.download = `${fileName || label}.json`;
|
||||||
|
document.body.appendChild(a);
|
||||||
|
a.click();
|
||||||
|
a.remove();
|
||||||
|
}, [jsonString, label, fileName]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
|
layerStyle="second"
|
||||||
sx={{
|
sx={{
|
||||||
borderRadius: 'base',
|
borderRadius: 'base',
|
||||||
bg: 'whiteAlpha.500',
|
|
||||||
_dark: { bg: 'blackAlpha.500' },
|
|
||||||
flexGrow: 1,
|
flexGrow: 1,
|
||||||
w: 'full',
|
w: 'full',
|
||||||
h: 'full',
|
h: 'full',
|
||||||
@ -36,6 +50,7 @@ const ImageMetadataJSON = (props: Props) => {
|
|||||||
bottom: 0,
|
bottom: 0,
|
||||||
overflow: 'auto',
|
overflow: 'auto',
|
||||||
p: 4,
|
p: 4,
|
||||||
|
fontSize: 'sm',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<OverlayScrollbarsComponent
|
<OverlayScrollbarsComponent
|
||||||
@ -44,7 +59,7 @@ const ImageMetadataJSON = (props: Props) => {
|
|||||||
options={{
|
options={{
|
||||||
scrollbars: {
|
scrollbars: {
|
||||||
visibility: 'auto',
|
visibility: 'auto',
|
||||||
autoHide: 'move',
|
autoHide: 'scroll',
|
||||||
autoHideDelay: 1300,
|
autoHideDelay: 1300,
|
||||||
theme: 'os-theme-dark',
|
theme: 'os-theme-dark',
|
||||||
},
|
},
|
||||||
@ -54,12 +69,22 @@ const ImageMetadataJSON = (props: Props) => {
|
|||||||
</OverlayScrollbarsComponent>
|
</OverlayScrollbarsComponent>
|
||||||
</Box>
|
</Box>
|
||||||
<Flex sx={{ position: 'absolute', top: 0, insetInlineEnd: 0, p: 2 }}>
|
<Flex sx={{ position: 'absolute', top: 0, insetInlineEnd: 0, p: 2 }}>
|
||||||
<Tooltip label={copyTooltip}>
|
<Tooltip label={`Save ${label} JSON`}>
|
||||||
<IconButton
|
<IconButton
|
||||||
aria-label={copyTooltip}
|
aria-label={`Save ${label} JSON`}
|
||||||
|
icon={<FaSave />}
|
||||||
|
variant="ghost"
|
||||||
|
opacity={0.7}
|
||||||
|
onClick={handleSave}
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
<Tooltip label={`Copy ${label} JSON`}>
|
||||||
|
<IconButton
|
||||||
|
aria-label={`Copy ${label} JSON`}
|
||||||
icon={<FaCopy />}
|
icon={<FaCopy />}
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
onClick={() => navigator.clipboard.writeText(jsonString)}
|
opacity={0.7}
|
||||||
|
onClick={handleCopy}
|
||||||
/>
|
/>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -10,7 +10,8 @@ import {
|
|||||||
Text,
|
Text,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
import { memo, useMemo } from 'react';
|
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||||
|
import { memo } from 'react';
|
||||||
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
|
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
import { useDebounce } from 'use-debounce';
|
import { useDebounce } from 'use-debounce';
|
||||||
@ -41,48 +42,15 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
|||||||
const metadata = currentData?.metadata;
|
const metadata = currentData?.metadata;
|
||||||
const graph = currentData?.graph;
|
const graph = currentData?.graph;
|
||||||
|
|
||||||
const tabData = useMemo(() => {
|
|
||||||
const _tabData: { label: string; data: object; copyTooltip: string }[] = [];
|
|
||||||
|
|
||||||
if (metadata) {
|
|
||||||
_tabData.push({
|
|
||||||
label: 'Core Metadata',
|
|
||||||
data: metadata,
|
|
||||||
copyTooltip: 'Copy Core Metadata JSON',
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if (image) {
|
|
||||||
_tabData.push({
|
|
||||||
label: 'Image Details',
|
|
||||||
data: image,
|
|
||||||
copyTooltip: 'Copy Image Details JSON',
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if (graph) {
|
|
||||||
_tabData.push({
|
|
||||||
label: 'Graph',
|
|
||||||
data: graph,
|
|
||||||
copyTooltip: 'Copy Graph JSON',
|
|
||||||
});
|
|
||||||
}
|
|
||||||
return _tabData;
|
|
||||||
}, [metadata, graph, image]);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
|
layerStyle="first"
|
||||||
sx={{
|
sx={{
|
||||||
padding: 4,
|
padding: 4,
|
||||||
gap: 1,
|
gap: 1,
|
||||||
flexDirection: 'column',
|
flexDirection: 'column',
|
||||||
width: 'full',
|
width: 'full',
|
||||||
height: 'full',
|
height: 'full',
|
||||||
backdropFilter: 'blur(20px)',
|
|
||||||
bg: 'baseAlpha.200',
|
|
||||||
_dark: {
|
|
||||||
bg: 'blackAlpha.600',
|
|
||||||
},
|
|
||||||
borderRadius: 'base',
|
borderRadius: 'base',
|
||||||
position: 'absolute',
|
position: 'absolute',
|
||||||
overflow: 'hidden',
|
overflow: 'hidden',
|
||||||
@ -103,32 +71,33 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
|||||||
sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }}
|
sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }}
|
||||||
>
|
>
|
||||||
<TabList>
|
<TabList>
|
||||||
{tabData.map((tab) => (
|
<Tab>Core Metadata</Tab>
|
||||||
<Tab
|
<Tab>Image Details</Tab>
|
||||||
key={tab.label}
|
<Tab>Graph</Tab>
|
||||||
sx={{
|
|
||||||
borderTopRadius: 'base',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<Text sx={{ color: 'base.700', _dark: { color: 'base.300' } }}>
|
|
||||||
{tab.label}
|
|
||||||
</Text>
|
|
||||||
</Tab>
|
|
||||||
))}
|
|
||||||
</TabList>
|
</TabList>
|
||||||
|
|
||||||
<TabPanels sx={{ w: 'full', h: 'full' }}>
|
<TabPanels>
|
||||||
{tabData.map((tab) => (
|
<TabPanel>
|
||||||
<TabPanel
|
{metadata ? (
|
||||||
key={tab.label}
|
<ImageMetadataJSON jsonObject={metadata} label="Core Metadata" />
|
||||||
sx={{ w: 'full', h: 'full', p: 0, pt: 4 }}
|
) : (
|
||||||
>
|
<IAINoContentFallback label="No core metadata found" />
|
||||||
<ImageMetadataJSON
|
)}
|
||||||
jsonObject={tab.data}
|
</TabPanel>
|
||||||
copyTooltip={tab.copyTooltip}
|
<TabPanel>
|
||||||
/>
|
{image ? (
|
||||||
</TabPanel>
|
<ImageMetadataJSON jsonObject={image} label="Image Details" />
|
||||||
))}
|
) : (
|
||||||
|
<IAINoContentFallback label="No image details found" />
|
||||||
|
)}
|
||||||
|
</TabPanel>
|
||||||
|
<TabPanel>
|
||||||
|
{graph ? (
|
||||||
|
<ImageMetadataJSON jsonObject={graph} label="Graph" />
|
||||||
|
) : (
|
||||||
|
<IAINoContentFallback label="No graph found" />
|
||||||
|
)}
|
||||||
|
</TabPanel>
|
||||||
</TabPanels>
|
</TabPanels>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -9,30 +9,40 @@ import { map } from 'lodash-es';
|
|||||||
import { forwardRef, useCallback } from 'react';
|
import { forwardRef, useCallback } from 'react';
|
||||||
import 'reactflow/dist/style.css';
|
import 'reactflow/dist/style.css';
|
||||||
import { AnyInvocationType } from 'services/events/types';
|
import { AnyInvocationType } from 'services/events/types';
|
||||||
import { useBuildInvocation } from '../hooks/useBuildInvocation';
|
import { useBuildNodeData } from '../hooks/useBuildNodeData';
|
||||||
import { nodeAdded } from '../store/nodesSlice';
|
import { nodeAdded } from '../store/nodesSlice';
|
||||||
|
|
||||||
type NodeTemplate = {
|
type NodeTemplate = {
|
||||||
label: string;
|
label: string;
|
||||||
value: string;
|
value: string;
|
||||||
description: string;
|
description: string;
|
||||||
|
tags: string[];
|
||||||
};
|
};
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector],
|
[stateSelector],
|
||||||
({ nodes }) => {
|
({ nodes }) => {
|
||||||
const data: NodeTemplate[] = map(nodes.invocationTemplates, (template) => {
|
const data: NodeTemplate[] = map(nodes.nodeTemplates, (template) => {
|
||||||
return {
|
return {
|
||||||
label: template.title,
|
label: template.title,
|
||||||
value: template.type,
|
value: template.type,
|
||||||
description: template.description,
|
description: template.description,
|
||||||
|
tags: template.tags,
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
data.push({
|
data.push({
|
||||||
label: 'Progress Image',
|
label: 'Progress Image',
|
||||||
value: 'progress_image',
|
value: 'current_image',
|
||||||
description: 'Displays the progress image in the Node Editor',
|
description: 'Displays the current image in the Node Editor',
|
||||||
|
tags: ['progress'],
|
||||||
|
});
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
label: 'Notes',
|
||||||
|
value: 'notes',
|
||||||
|
description: 'Add notes about your workflow',
|
||||||
|
tags: ['notes'],
|
||||||
});
|
});
|
||||||
|
|
||||||
return { data };
|
return { data };
|
||||||
@ -44,7 +54,7 @@ const AddNodeMenu = () => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { data } = useAppSelector(selector);
|
const { data } = useAppSelector(selector);
|
||||||
|
|
||||||
const buildInvocation = useBuildInvocation();
|
const buildInvocation = useBuildNodeData();
|
||||||
|
|
||||||
const toaster = useAppToaster();
|
const toaster = useAppToaster();
|
||||||
|
|
||||||
@ -89,11 +99,12 @@ const AddNodeMenu = () => {
|
|||||||
filter={(value, item: NodeTemplate) =>
|
filter={(value, item: NodeTemplate) =>
|
||||||
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
|
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||||
item.value.toLowerCase().includes(value.toLowerCase().trim()) ||
|
item.value.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||||
item.description.toLowerCase().includes(value.toLowerCase().trim())
|
item.description.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||||
|
item.tags.includes(value.toLowerCase().trim())
|
||||||
}
|
}
|
||||||
onChange={handleChange}
|
onChange={handleChange}
|
||||||
sx={{
|
sx={{
|
||||||
width: '18rem',
|
width: '24rem',
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -0,0 +1,61 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { ConnectionLineComponentProps, getBezierPath } from 'reactflow';
|
||||||
|
import { FIELDS, colorTokenToCssVar } from '../types/constants';
|
||||||
|
|
||||||
|
const selector = createSelector(stateSelector, ({ nodes }) => {
|
||||||
|
const { shouldAnimateEdges, currentConnectionFieldType, shouldColorEdges } =
|
||||||
|
nodes;
|
||||||
|
|
||||||
|
const stroke =
|
||||||
|
currentConnectionFieldType && shouldColorEdges
|
||||||
|
? colorTokenToCssVar(FIELDS[currentConnectionFieldType].color)
|
||||||
|
: colorTokenToCssVar('base.500');
|
||||||
|
|
||||||
|
let className = 'react-flow__custom_connection-path';
|
||||||
|
|
||||||
|
if (shouldAnimateEdges) {
|
||||||
|
className = className.concat(' animated');
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
stroke,
|
||||||
|
className,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
export const CustomConnectionLine = ({
|
||||||
|
fromX,
|
||||||
|
fromY,
|
||||||
|
fromPosition,
|
||||||
|
toX,
|
||||||
|
toY,
|
||||||
|
toPosition,
|
||||||
|
}: ConnectionLineComponentProps) => {
|
||||||
|
const { stroke, className } = useAppSelector(selector);
|
||||||
|
|
||||||
|
const pathParams = {
|
||||||
|
sourceX: fromX,
|
||||||
|
sourceY: fromY,
|
||||||
|
sourcePosition: fromPosition,
|
||||||
|
targetX: toX,
|
||||||
|
targetY: toY,
|
||||||
|
targetPosition: toPosition,
|
||||||
|
};
|
||||||
|
|
||||||
|
const [dAttr] = getBezierPath(pathParams);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<g>
|
||||||
|
<path
|
||||||
|
fill="none"
|
||||||
|
stroke={stroke}
|
||||||
|
strokeWidth={2}
|
||||||
|
className={className}
|
||||||
|
d={dAttr}
|
||||||
|
style={{ opacity: 0.8 }}
|
||||||
|
/>
|
||||||
|
</g>
|
||||||
|
);
|
||||||
|
};
|
@ -0,0 +1,183 @@
|
|||||||
|
import { Badge, Flex } from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
|
||||||
|
import { useMemo } from 'react';
|
||||||
|
import {
|
||||||
|
BaseEdge,
|
||||||
|
EdgeLabelRenderer,
|
||||||
|
EdgeProps,
|
||||||
|
getBezierPath,
|
||||||
|
} from 'reactflow';
|
||||||
|
import { FIELDS, colorTokenToCssVar } from '../types/constants';
|
||||||
|
import { isInvocationNode } from '../types/types';
|
||||||
|
|
||||||
|
const makeEdgeSelector = (
|
||||||
|
source: string,
|
||||||
|
sourceHandleId: string | null | undefined,
|
||||||
|
target: string,
|
||||||
|
targetHandleId: string | null | undefined,
|
||||||
|
selected?: boolean
|
||||||
|
) =>
|
||||||
|
createSelector(stateSelector, ({ nodes }) => {
|
||||||
|
const sourceNode = nodes.nodes.find((node) => node.id === source);
|
||||||
|
const targetNode = nodes.nodes.find((node) => node.id === target);
|
||||||
|
|
||||||
|
const isInvocationToInvocationEdge =
|
||||||
|
isInvocationNode(sourceNode) && isInvocationNode(targetNode);
|
||||||
|
|
||||||
|
const isSelected = sourceNode?.selected || targetNode?.selected || selected;
|
||||||
|
const sourceType = isInvocationToInvocationEdge
|
||||||
|
? sourceNode?.data?.outputs[sourceHandleId || '']?.type
|
||||||
|
: undefined;
|
||||||
|
|
||||||
|
const stroke =
|
||||||
|
sourceType && nodes.shouldColorEdges
|
||||||
|
? colorTokenToCssVar(FIELDS[sourceType].color)
|
||||||
|
: colorTokenToCssVar('base.500');
|
||||||
|
|
||||||
|
return {
|
||||||
|
isSelected,
|
||||||
|
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
|
||||||
|
stroke,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
const CollapsedEdge = ({
|
||||||
|
sourceX,
|
||||||
|
sourceY,
|
||||||
|
targetX,
|
||||||
|
targetY,
|
||||||
|
sourcePosition,
|
||||||
|
targetPosition,
|
||||||
|
markerEnd,
|
||||||
|
data,
|
||||||
|
selected,
|
||||||
|
source,
|
||||||
|
target,
|
||||||
|
sourceHandleId,
|
||||||
|
targetHandleId,
|
||||||
|
}: EdgeProps<{ count: number }>) => {
|
||||||
|
const selector = useMemo(
|
||||||
|
() =>
|
||||||
|
makeEdgeSelector(
|
||||||
|
source,
|
||||||
|
sourceHandleId,
|
||||||
|
target,
|
||||||
|
targetHandleId,
|
||||||
|
selected
|
||||||
|
),
|
||||||
|
[selected, source, sourceHandleId, target, targetHandleId]
|
||||||
|
);
|
||||||
|
|
||||||
|
const { isSelected, shouldAnimate } = useAppSelector(selector);
|
||||||
|
|
||||||
|
const [edgePath, labelX, labelY] = getBezierPath({
|
||||||
|
sourceX,
|
||||||
|
sourceY,
|
||||||
|
sourcePosition,
|
||||||
|
targetX,
|
||||||
|
targetY,
|
||||||
|
targetPosition,
|
||||||
|
});
|
||||||
|
|
||||||
|
const { base500 } = useChakraThemeTokens();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<BaseEdge
|
||||||
|
path={edgePath}
|
||||||
|
markerEnd={markerEnd}
|
||||||
|
style={{
|
||||||
|
strokeWidth: isSelected ? 3 : 2,
|
||||||
|
stroke: base500,
|
||||||
|
opacity: isSelected ? 0.8 : 0.5,
|
||||||
|
animation: shouldAnimate
|
||||||
|
? 'dashdraw 0.5s linear infinite'
|
||||||
|
: undefined,
|
||||||
|
strokeDasharray: shouldAnimate ? 5 : 'none',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
{data?.count && data.count > 1 && (
|
||||||
|
<EdgeLabelRenderer>
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
position: 'absolute',
|
||||||
|
transform: `translate(-50%, -50%) translate(${labelX}px,${labelY}px)`,
|
||||||
|
}}
|
||||||
|
className="nodrag nopan"
|
||||||
|
>
|
||||||
|
<Badge
|
||||||
|
variant="solid"
|
||||||
|
sx={{
|
||||||
|
bg: 'base.500',
|
||||||
|
opacity: isSelected ? 0.8 : 0.5,
|
||||||
|
boxShadow: 'base',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{data.count}
|
||||||
|
</Badge>
|
||||||
|
</Flex>
|
||||||
|
</EdgeLabelRenderer>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
const DefaultEdge = ({
|
||||||
|
sourceX,
|
||||||
|
sourceY,
|
||||||
|
targetX,
|
||||||
|
targetY,
|
||||||
|
sourcePosition,
|
||||||
|
targetPosition,
|
||||||
|
markerEnd,
|
||||||
|
selected,
|
||||||
|
source,
|
||||||
|
target,
|
||||||
|
sourceHandleId,
|
||||||
|
targetHandleId,
|
||||||
|
}: EdgeProps) => {
|
||||||
|
const selector = useMemo(
|
||||||
|
() =>
|
||||||
|
makeEdgeSelector(
|
||||||
|
source,
|
||||||
|
sourceHandleId,
|
||||||
|
target,
|
||||||
|
targetHandleId,
|
||||||
|
selected
|
||||||
|
),
|
||||||
|
[source, sourceHandleId, target, targetHandleId, selected]
|
||||||
|
);
|
||||||
|
|
||||||
|
const { isSelected, shouldAnimate, stroke } = useAppSelector(selector);
|
||||||
|
|
||||||
|
const [edgePath] = getBezierPath({
|
||||||
|
sourceX,
|
||||||
|
sourceY,
|
||||||
|
sourcePosition,
|
||||||
|
targetX,
|
||||||
|
targetY,
|
||||||
|
targetPosition,
|
||||||
|
});
|
||||||
|
|
||||||
|
return (
|
||||||
|
<BaseEdge
|
||||||
|
path={edgePath}
|
||||||
|
markerEnd={markerEnd}
|
||||||
|
style={{
|
||||||
|
strokeWidth: isSelected ? 3 : 2,
|
||||||
|
stroke,
|
||||||
|
opacity: isSelected ? 0.8 : 0.5,
|
||||||
|
animation: shouldAnimate ? 'dashdraw 0.5s linear infinite' : undefined,
|
||||||
|
strokeDasharray: shouldAnimate ? 5 : 'none',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const edgeTypes = {
|
||||||
|
collapsed: CollapsedEdge,
|
||||||
|
default: DefaultEdge,
|
||||||
|
};
|
@ -0,0 +1,9 @@
|
|||||||
|
import InvocationNode from './nodes/InvocationNode';
|
||||||
|
import CurrentImageNode from './nodes/CurrentImageNode';
|
||||||
|
import NotesNode from './nodes/NotesNode';
|
||||||
|
|
||||||
|
export const nodeTypes = {
|
||||||
|
invocation: InvocationNode,
|
||||||
|
current_image: CurrentImageNode,
|
||||||
|
notes: NotesNode,
|
||||||
|
};
|
@ -1,64 +0,0 @@
|
|||||||
import { Tooltip } from '@chakra-ui/react';
|
|
||||||
import { CSSProperties, memo } from 'react';
|
|
||||||
import { Handle, Position, Connection, HandleType } from 'reactflow';
|
|
||||||
import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY } from '../types/constants';
|
|
||||||
// import { useConnectionEventStyles } from '../hooks/useConnectionEventStyles';
|
|
||||||
import { InputFieldTemplate, OutputFieldTemplate } from '../types/types';
|
|
||||||
|
|
||||||
const handleBaseStyles: CSSProperties = {
|
|
||||||
position: 'absolute',
|
|
||||||
width: '1rem',
|
|
||||||
height: '1rem',
|
|
||||||
borderWidth: 0,
|
|
||||||
};
|
|
||||||
|
|
||||||
const inputHandleStyles: CSSProperties = {
|
|
||||||
left: '-1rem',
|
|
||||||
};
|
|
||||||
|
|
||||||
const outputHandleStyles: CSSProperties = {
|
|
||||||
right: '-0.5rem',
|
|
||||||
};
|
|
||||||
|
|
||||||
// const requiredConnectionStyles: CSSProperties = {
|
|
||||||
// boxShadow: '0 0 0.5rem 0.5rem var(--invokeai-colors-error-400)',
|
|
||||||
// };
|
|
||||||
|
|
||||||
type FieldHandleProps = {
|
|
||||||
nodeId: string;
|
|
||||||
field: InputFieldTemplate | OutputFieldTemplate;
|
|
||||||
isValidConnection: (connection: Connection) => boolean;
|
|
||||||
handleType: HandleType;
|
|
||||||
styles?: CSSProperties;
|
|
||||||
};
|
|
||||||
|
|
||||||
const FieldHandle = (props: FieldHandleProps) => {
|
|
||||||
const { field, isValidConnection, handleType, styles } = props;
|
|
||||||
const { name, type } = field;
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Tooltip
|
|
||||||
label={type}
|
|
||||||
placement={handleType === 'target' ? 'start' : 'end'}
|
|
||||||
hasArrow
|
|
||||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
|
||||||
>
|
|
||||||
<Handle
|
|
||||||
type={handleType}
|
|
||||||
id={name}
|
|
||||||
isValidConnection={isValidConnection}
|
|
||||||
position={handleType === 'target' ? Position.Left : Position.Right}
|
|
||||||
style={{
|
|
||||||
backgroundColor: FIELDS[type].colorCssVar,
|
|
||||||
...styles,
|
|
||||||
...handleBaseStyles,
|
|
||||||
...(handleType === 'target' ? inputHandleStyles : outputHandleStyles),
|
|
||||||
// ...(inputRequirement === 'always' ? requiredConnectionStyles : {}),
|
|
||||||
// ...connectionEventStyles,
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
</Tooltip>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(FieldHandle);
|
|
@ -1,8 +1,8 @@
|
|||||||
import 'reactflow/dist/style.css';
|
import { Badge, Flex, Tooltip } from '@chakra-ui/react';
|
||||||
import { Tooltip, Badge, Flex } from '@chakra-ui/react';
|
|
||||||
import { map } from 'lodash-es';
|
import { map } from 'lodash-es';
|
||||||
import { FIELDS } from '../types/constants';
|
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
|
import 'reactflow/dist/style.css';
|
||||||
|
import { FIELDS } from '../types/constants';
|
||||||
|
|
||||||
const FieldTypeLegend = () => {
|
const FieldTypeLegend = () => {
|
||||||
return (
|
return (
|
||||||
@ -10,8 +10,14 @@ const FieldTypeLegend = () => {
|
|||||||
{map(FIELDS, ({ title, description, color }, key) => (
|
{map(FIELDS, ({ title, description, color }, key) => (
|
||||||
<Tooltip key={key} label={description}>
|
<Tooltip key={key} label={description}>
|
||||||
<Badge
|
<Badge
|
||||||
colorScheme={color}
|
sx={{
|
||||||
sx={{ userSelect: 'none' }}
|
userSelect: 'none',
|
||||||
|
color:
|
||||||
|
parseInt(color.split('.')[1] ?? '0', 10) < 500
|
||||||
|
? 'base.800'
|
||||||
|
: 'base.50',
|
||||||
|
bg: color,
|
||||||
|
}}
|
||||||
textAlign="center"
|
textAlign="center"
|
||||||
>
|
>
|
||||||
{title}
|
{title}
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import { RootState } from 'app/store/store';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import {
|
import {
|
||||||
@ -7,35 +6,49 @@ import {
|
|||||||
OnConnectEnd,
|
OnConnectEnd,
|
||||||
OnConnectStart,
|
OnConnectStart,
|
||||||
OnEdgesChange,
|
OnEdgesChange,
|
||||||
|
OnEdgesDelete,
|
||||||
OnInit,
|
OnInit,
|
||||||
|
OnMove,
|
||||||
OnNodesChange,
|
OnNodesChange,
|
||||||
|
OnNodesDelete,
|
||||||
|
OnSelectionChangeFunc,
|
||||||
|
ProOptions,
|
||||||
ReactFlow,
|
ReactFlow,
|
||||||
} from 'reactflow';
|
} from 'reactflow';
|
||||||
|
import { useIsValidConnection } from '../hooks/useIsValidConnection';
|
||||||
import {
|
import {
|
||||||
connectionEnded,
|
connectionEnded,
|
||||||
connectionMade,
|
connectionMade,
|
||||||
connectionStarted,
|
connectionStarted,
|
||||||
edgesChanged,
|
edgesChanged,
|
||||||
|
edgesDeleted,
|
||||||
nodesChanged,
|
nodesChanged,
|
||||||
setEditorInstance,
|
nodesDeleted,
|
||||||
|
selectedEdgesChanged,
|
||||||
|
selectedNodesChanged,
|
||||||
|
zoomChanged,
|
||||||
} from '../store/nodesSlice';
|
} from '../store/nodesSlice';
|
||||||
import { InvocationComponent } from './InvocationComponent';
|
import { CustomConnectionLine } from './CustomConnectionLine';
|
||||||
import ProgressImageNode from './ProgressImageNode';
|
import { edgeTypes } from './CustomEdges';
|
||||||
import BottomLeftPanel from './panels/BottomLeftPanel.tsx';
|
import { nodeTypes } from './CustomNodes';
|
||||||
import MinimapPanel from './panels/MinimapPanel';
|
import BottomLeftPanel from './editorPanels/BottomLeftPanel';
|
||||||
import TopCenterPanel from './panels/TopCenterPanel';
|
import MinimapPanel from './editorPanels/MinimapPanel';
|
||||||
import TopLeftPanel from './panels/TopLeftPanel';
|
import TopCenterPanel from './editorPanels/TopCenterPanel';
|
||||||
import TopRightPanel from './panels/TopRightPanel';
|
import TopLeftPanel from './editorPanels/TopLeftPanel';
|
||||||
|
import TopRightPanel from './editorPanels/TopRightPanel';
|
||||||
|
|
||||||
const nodeTypes = {
|
// TODO: can we support reactflow? if not, we could style the attribution so it matches the app
|
||||||
invocation: InvocationComponent,
|
const proOptions: ProOptions = { hideAttribution: true };
|
||||||
progress_image: ProgressImageNode,
|
|
||||||
};
|
|
||||||
|
|
||||||
export const Flow = () => {
|
export const Flow = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const nodes = useAppSelector((state: RootState) => state.nodes.nodes);
|
const nodes = useAppSelector((state) => state.nodes.nodes);
|
||||||
const edges = useAppSelector((state: RootState) => state.nodes.edges);
|
const edges = useAppSelector((state) => state.nodes.edges);
|
||||||
|
const shouldSnapToGrid = useAppSelector(
|
||||||
|
(state) => state.nodes.shouldSnapToGrid
|
||||||
|
);
|
||||||
|
|
||||||
|
const isValidConnection = useIsValidConnection();
|
||||||
|
|
||||||
const onNodesChange: OnNodesChange = useCallback(
|
const onNodesChange: OnNodesChange = useCallback(
|
||||||
(changes) => {
|
(changes) => {
|
||||||
@ -69,10 +82,36 @@ export const Flow = () => {
|
|||||||
dispatch(connectionEnded());
|
dispatch(connectionEnded());
|
||||||
}, [dispatch]);
|
}, [dispatch]);
|
||||||
|
|
||||||
const onInit: OnInit = useCallback(
|
const onInit: OnInit = useCallback((v) => {
|
||||||
(v) => {
|
v.fitView();
|
||||||
dispatch(setEditorInstance(v));
|
}, []);
|
||||||
if (v) v.fitView();
|
|
||||||
|
const onEdgesDelete: OnEdgesDelete = useCallback(
|
||||||
|
(edges) => {
|
||||||
|
dispatch(edgesDeleted(edges));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const onNodesDelete: OnNodesDelete = useCallback(
|
||||||
|
(nodes) => {
|
||||||
|
dispatch(nodesDeleted(nodes));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleSelectionChange: OnSelectionChangeFunc = useCallback(
|
||||||
|
({ nodes, edges }) => {
|
||||||
|
dispatch(selectedNodesChanged(nodes ? nodes.map((n) => n.id) : []));
|
||||||
|
dispatch(selectedEdgesChanged(edges ? edges.map((e) => e.id) : []));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleMove: OnMove = useCallback(
|
||||||
|
(e, viewport) => {
|
||||||
|
const { zoom } = viewport;
|
||||||
|
dispatch(zoomChanged(zoom));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
@ -80,24 +119,33 @@ export const Flow = () => {
|
|||||||
return (
|
return (
|
||||||
<ReactFlow
|
<ReactFlow
|
||||||
nodeTypes={nodeTypes}
|
nodeTypes={nodeTypes}
|
||||||
|
edgeTypes={edgeTypes}
|
||||||
nodes={nodes}
|
nodes={nodes}
|
||||||
edges={edges}
|
edges={edges}
|
||||||
onNodesChange={onNodesChange}
|
onNodesChange={onNodesChange}
|
||||||
onEdgesChange={onEdgesChange}
|
onEdgesChange={onEdgesChange}
|
||||||
|
onEdgesDelete={onEdgesDelete}
|
||||||
|
onNodesDelete={onNodesDelete}
|
||||||
onConnectStart={onConnectStart}
|
onConnectStart={onConnectStart}
|
||||||
onConnect={onConnect}
|
onConnect={onConnect}
|
||||||
onConnectEnd={onConnectEnd}
|
onConnectEnd={onConnectEnd}
|
||||||
|
onMove={handleMove}
|
||||||
|
connectionLineComponent={CustomConnectionLine}
|
||||||
|
onSelectionChange={handleSelectionChange}
|
||||||
onInit={onInit}
|
onInit={onInit}
|
||||||
defaultEdgeOptions={{
|
isValidConnection={isValidConnection}
|
||||||
style: { strokeWidth: 2 },
|
minZoom={0.2}
|
||||||
}}
|
snapToGrid={shouldSnapToGrid}
|
||||||
|
snapGrid={[25, 25]}
|
||||||
|
connectionRadius={30}
|
||||||
|
proOptions={proOptions}
|
||||||
>
|
>
|
||||||
<TopLeftPanel />
|
<TopLeftPanel />
|
||||||
<TopCenterPanel />
|
<TopCenterPanel />
|
||||||
<TopRightPanel />
|
<TopRightPanel />
|
||||||
<BottomLeftPanel />
|
<BottomLeftPanel />
|
||||||
<Background />
|
|
||||||
<MinimapPanel />
|
<MinimapPanel />
|
||||||
|
<Background />
|
||||||
</ReactFlow>
|
</ReactFlow>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -1,55 +0,0 @@
|
|||||||
import { Flex, Heading, Icon, Tooltip } from '@chakra-ui/react';
|
|
||||||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/hooks/useBuildInvocation';
|
|
||||||
import { memo } from 'react';
|
|
||||||
import { FaInfoCircle } from 'react-icons/fa';
|
|
||||||
|
|
||||||
interface IAINodeHeaderProps {
|
|
||||||
nodeId?: string;
|
|
||||||
title?: string;
|
|
||||||
description?: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
const IAINodeHeader = (props: IAINodeHeaderProps) => {
|
|
||||||
const { nodeId, title, description } = props;
|
|
||||||
return (
|
|
||||||
<Flex
|
|
||||||
className={DRAG_HANDLE_CLASSNAME}
|
|
||||||
sx={{
|
|
||||||
borderTopRadius: 'md',
|
|
||||||
alignItems: 'center',
|
|
||||||
justifyContent: 'space-between',
|
|
||||||
px: 2,
|
|
||||||
py: 1,
|
|
||||||
bg: 'base.100',
|
|
||||||
_dark: { bg: 'base.900' },
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<Tooltip label={nodeId}>
|
|
||||||
<Heading
|
|
||||||
size="xs"
|
|
||||||
sx={{
|
|
||||||
fontWeight: 600,
|
|
||||||
color: 'base.900',
|
|
||||||
_dark: { color: 'base.200' },
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
{title}
|
|
||||||
</Heading>
|
|
||||||
</Tooltip>
|
|
||||||
<Tooltip label={description} placement="top" hasArrow shouldWrapChildren>
|
|
||||||
<Icon
|
|
||||||
sx={{
|
|
||||||
h: 'min-content',
|
|
||||||
color: 'base.700',
|
|
||||||
_dark: {
|
|
||||||
color: 'base.300',
|
|
||||||
},
|
|
||||||
}}
|
|
||||||
as={FaInfoCircle}
|
|
||||||
/>
|
|
||||||
</Tooltip>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(IAINodeHeader);
|
|
@ -1,149 +0,0 @@
|
|||||||
import {
|
|
||||||
Box,
|
|
||||||
Divider,
|
|
||||||
Flex,
|
|
||||||
FormControl,
|
|
||||||
FormLabel,
|
|
||||||
HStack,
|
|
||||||
Tooltip,
|
|
||||||
} from '@chakra-ui/react';
|
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
|
|
||||||
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
|
||||||
import {
|
|
||||||
InputFieldTemplate,
|
|
||||||
InputFieldValue,
|
|
||||||
InvocationTemplate,
|
|
||||||
} from 'features/nodes/types/types';
|
|
||||||
import { map } from 'lodash-es';
|
|
||||||
import { ReactNode, memo, useCallback } from 'react';
|
|
||||||
import FieldHandle from '../FieldHandle';
|
|
||||||
import InputFieldComponent from '../InputFieldComponent';
|
|
||||||
|
|
||||||
interface IAINodeInputProps {
|
|
||||||
nodeId: string;
|
|
||||||
|
|
||||||
input: InputFieldValue;
|
|
||||||
template?: InputFieldTemplate | undefined;
|
|
||||||
connected: boolean;
|
|
||||||
}
|
|
||||||
|
|
||||||
function IAINodeInput(props: IAINodeInputProps) {
|
|
||||||
const { nodeId, input, template, connected } = props;
|
|
||||||
const isValidConnection = useIsValidConnection();
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Box
|
|
||||||
className="nopan"
|
|
||||||
position="relative"
|
|
||||||
borderColor={
|
|
||||||
!template
|
|
||||||
? 'error.400'
|
|
||||||
: !connected &&
|
|
||||||
['always', 'connectionOnly'].includes(
|
|
||||||
String(template?.inputRequirement)
|
|
||||||
) &&
|
|
||||||
input.value === undefined
|
|
||||||
? 'warning.400'
|
|
||||||
: undefined
|
|
||||||
}
|
|
||||||
>
|
|
||||||
<FormControl isDisabled={!template ? true : connected} pl={2}>
|
|
||||||
{!template ? (
|
|
||||||
<HStack justifyContent="space-between" alignItems="center">
|
|
||||||
<FormLabel>Unknown input: {input.name}</FormLabel>
|
|
||||||
</HStack>
|
|
||||||
) : (
|
|
||||||
<>
|
|
||||||
<HStack justifyContent="space-between" alignItems="center">
|
|
||||||
<HStack>
|
|
||||||
<Tooltip
|
|
||||||
label={template?.description}
|
|
||||||
placement="top"
|
|
||||||
hasArrow
|
|
||||||
shouldWrapChildren
|
|
||||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
|
||||||
>
|
|
||||||
<FormLabel>{template?.title}</FormLabel>
|
|
||||||
</Tooltip>
|
|
||||||
</HStack>
|
|
||||||
<InputFieldComponent
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={input}
|
|
||||||
template={template}
|
|
||||||
/>
|
|
||||||
</HStack>
|
|
||||||
|
|
||||||
{!['never', 'directOnly'].includes(
|
|
||||||
template?.inputRequirement ?? ''
|
|
||||||
) && (
|
|
||||||
<FieldHandle
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={template}
|
|
||||||
isValidConnection={isValidConnection}
|
|
||||||
handleType="target"
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
</FormControl>
|
|
||||||
</Box>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
interface IAINodeInputsProps {
|
|
||||||
nodeId: string;
|
|
||||||
template: InvocationTemplate;
|
|
||||||
inputs: Record<string, InputFieldValue>;
|
|
||||||
}
|
|
||||||
|
|
||||||
const IAINodeInputs = (props: IAINodeInputsProps) => {
|
|
||||||
const { nodeId, template, inputs } = props;
|
|
||||||
|
|
||||||
const edges = useAppSelector((state: RootState) => state.nodes.edges);
|
|
||||||
|
|
||||||
const renderIAINodeInputs = useCallback(() => {
|
|
||||||
const IAINodeInputsToRender: ReactNode[] = [];
|
|
||||||
const inputSockets = map(inputs);
|
|
||||||
|
|
||||||
inputSockets.forEach((inputSocket, index) => {
|
|
||||||
const inputTemplate = template.inputs[inputSocket.name];
|
|
||||||
|
|
||||||
const isConnected = Boolean(
|
|
||||||
edges.filter((connectedInput) => {
|
|
||||||
return (
|
|
||||||
connectedInput.target === nodeId &&
|
|
||||||
connectedInput.targetHandle === inputSocket.name
|
|
||||||
);
|
|
||||||
}).length
|
|
||||||
);
|
|
||||||
|
|
||||||
if (index < inputSockets.length) {
|
|
||||||
IAINodeInputsToRender.push(
|
|
||||||
<Divider key={`${inputSocket.id}.divider`} />
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
IAINodeInputsToRender.push(
|
|
||||||
<IAINodeInput
|
|
||||||
key={inputSocket.id}
|
|
||||||
nodeId={nodeId}
|
|
||||||
input={inputSocket}
|
|
||||||
template={inputTemplate}
|
|
||||||
connected={isConnected}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex className="nopan" flexDir="column" gap={2} p={2}>
|
|
||||||
{IAINodeInputsToRender}
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
}, [edges, inputs, nodeId, template.inputs]);
|
|
||||||
|
|
||||||
return renderIAINodeInputs();
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(IAINodeInputs);
|
|
@ -1,97 +0,0 @@
|
|||||||
import {
|
|
||||||
InvocationTemplate,
|
|
||||||
OutputFieldTemplate,
|
|
||||||
OutputFieldValue,
|
|
||||||
} from 'features/nodes/types/types';
|
|
||||||
import { memo, ReactNode, useCallback } from 'react';
|
|
||||||
import { map } from 'lodash-es';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import { Box, Flex, FormControl, FormLabel, HStack } from '@chakra-ui/react';
|
|
||||||
import FieldHandle from '../FieldHandle';
|
|
||||||
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
|
|
||||||
|
|
||||||
interface IAINodeOutputProps {
|
|
||||||
nodeId: string;
|
|
||||||
output: OutputFieldValue;
|
|
||||||
template?: OutputFieldTemplate | undefined;
|
|
||||||
connected: boolean;
|
|
||||||
}
|
|
||||||
|
|
||||||
function IAINodeOutput(props: IAINodeOutputProps) {
|
|
||||||
const { nodeId, output, template, connected } = props;
|
|
||||||
const isValidConnection = useIsValidConnection();
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Box position="relative">
|
|
||||||
<FormControl isDisabled={!template ? true : connected} paddingRight={3}>
|
|
||||||
{!template ? (
|
|
||||||
<HStack justifyContent="space-between" alignItems="center">
|
|
||||||
<FormLabel color="error.400">
|
|
||||||
Unknown Output: {output.name}
|
|
||||||
</FormLabel>
|
|
||||||
</HStack>
|
|
||||||
) : (
|
|
||||||
<>
|
|
||||||
<FormLabel textAlign="end" padding={1}>
|
|
||||||
{template?.title}
|
|
||||||
</FormLabel>
|
|
||||||
<FieldHandle
|
|
||||||
key={output.id}
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={template}
|
|
||||||
isValidConnection={isValidConnection}
|
|
||||||
handleType="source"
|
|
||||||
/>
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
</FormControl>
|
|
||||||
</Box>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
interface IAINodeOutputsProps {
|
|
||||||
nodeId: string;
|
|
||||||
template: InvocationTemplate;
|
|
||||||
outputs: Record<string, OutputFieldValue>;
|
|
||||||
}
|
|
||||||
|
|
||||||
const IAINodeOutputs = (props: IAINodeOutputsProps) => {
|
|
||||||
const { nodeId, template, outputs } = props;
|
|
||||||
|
|
||||||
const edges = useAppSelector((state: RootState) => state.nodes.edges);
|
|
||||||
|
|
||||||
const renderIAINodeOutputs = useCallback(() => {
|
|
||||||
const IAINodeOutputsToRender: ReactNode[] = [];
|
|
||||||
const outputSockets = map(outputs);
|
|
||||||
|
|
||||||
outputSockets.forEach((outputSocket) => {
|
|
||||||
const outputTemplate = template.outputs[outputSocket.name];
|
|
||||||
|
|
||||||
const isConnected = Boolean(
|
|
||||||
edges.filter((connectedInput) => {
|
|
||||||
return (
|
|
||||||
connectedInput.source === nodeId &&
|
|
||||||
connectedInput.sourceHandle === outputSocket.name
|
|
||||||
);
|
|
||||||
}).length
|
|
||||||
);
|
|
||||||
|
|
||||||
IAINodeOutputsToRender.push(
|
|
||||||
<IAINodeOutput
|
|
||||||
key={outputSocket.id}
|
|
||||||
nodeId={nodeId}
|
|
||||||
output={outputSocket}
|
|
||||||
template={outputTemplate}
|
|
||||||
connected={isConnected}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
return <Flex flexDir="column">{IAINodeOutputsToRender}</Flex>;
|
|
||||||
}, [edges, nodeId, outputs, template.outputs]);
|
|
||||||
|
|
||||||
return renderIAINodeOutputs();
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(IAINodeOutputs);
|
|
@ -1,252 +0,0 @@
|
|||||||
import { Box } from '@chakra-ui/react';
|
|
||||||
import { memo } from 'react';
|
|
||||||
import { InputFieldTemplate, InputFieldValue } from '../types/types';
|
|
||||||
import ArrayInputFieldComponent from './fields/ArrayInputFieldComponent';
|
|
||||||
import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent';
|
|
||||||
import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
|
|
||||||
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
|
|
||||||
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
|
|
||||||
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
|
|
||||||
import ControlNetModelInputFieldComponent from './fields/ControlNetModelInputFieldComponent';
|
|
||||||
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
|
|
||||||
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
|
|
||||||
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
|
|
||||||
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
|
|
||||||
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
|
|
||||||
import LoRAModelInputFieldComponent from './fields/LoRAModelInputFieldComponent';
|
|
||||||
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
|
|
||||||
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
|
|
||||||
import StringInputFieldComponent from './fields/StringInputFieldComponent';
|
|
||||||
import UnetInputFieldComponent from './fields/UnetInputFieldComponent';
|
|
||||||
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
|
|
||||||
import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent';
|
|
||||||
import RefinerModelInputFieldComponent from './fields/RefinerModelInputFieldComponent';
|
|
||||||
|
|
||||||
type InputFieldComponentProps = {
|
|
||||||
nodeId: string;
|
|
||||||
field: InputFieldValue;
|
|
||||||
template: InputFieldTemplate;
|
|
||||||
};
|
|
||||||
|
|
||||||
// build an individual input element based on the schema
|
|
||||||
const InputFieldComponent = (props: InputFieldComponentProps) => {
|
|
||||||
const { nodeId, field, template } = props;
|
|
||||||
const { type } = field;
|
|
||||||
|
|
||||||
if (type === 'string' && template.type === 'string') {
|
|
||||||
return (
|
|
||||||
<StringInputFieldComponent
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
template={template}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'boolean' && template.type === 'boolean') {
|
|
||||||
return (
|
|
||||||
<BooleanInputFieldComponent
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
template={template}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (
|
|
||||||
(type === 'integer' && template.type === 'integer') ||
|
|
||||||
(type === 'float' && template.type === 'float')
|
|
||||||
) {
|
|
||||||
return (
|
|
||||||
<NumberInputFieldComponent
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
template={template}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'enum' && template.type === 'enum') {
|
|
||||||
return (
|
|
||||||
<EnumInputFieldComponent
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
template={template}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'image' && template.type === 'image') {
|
|
||||||
return (
|
|
||||||
<ImageInputFieldComponent
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
template={template}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'latents' && template.type === 'latents') {
|
|
||||||
return (
|
|
||||||
<LatentsInputFieldComponent
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
template={template}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'conditioning' && template.type === 'conditioning') {
|
|
||||||
return (
|
|
||||||
<ConditioningInputFieldComponent
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
template={template}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'unet' && template.type === 'unet') {
|
|
||||||
return (
|
|
||||||
<UnetInputFieldComponent
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
template={template}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'clip' && template.type === 'clip') {
|
|
||||||
return (
|
|
||||||
<ClipInputFieldComponent
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
template={template}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'vae' && template.type === 'vae') {
|
|
||||||
return (
|
|
||||||
<VaeInputFieldComponent
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
template={template}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'control' && template.type === 'control') {
|
|
||||||
return (
|
|
||||||
<ControlInputFieldComponent
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
template={template}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'model' && template.type === 'model') {
|
|
||||||
return (
|
|
||||||
<ModelInputFieldComponent
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
template={template}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'refiner_model' && template.type === 'refiner_model') {
|
|
||||||
return (
|
|
||||||
<RefinerModelInputFieldComponent
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
template={template}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'vae_model' && template.type === 'vae_model') {
|
|
||||||
return (
|
|
||||||
<VaeModelInputFieldComponent
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
template={template}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'lora_model' && template.type === 'lora_model') {
|
|
||||||
return (
|
|
||||||
<LoRAModelInputFieldComponent
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
template={template}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'controlnet_model' && template.type === 'controlnet_model') {
|
|
||||||
return (
|
|
||||||
<ControlNetModelInputFieldComponent
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
template={template}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'array' && template.type === 'array') {
|
|
||||||
return (
|
|
||||||
<ArrayInputFieldComponent
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
template={template}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'item' && template.type === 'item') {
|
|
||||||
return (
|
|
||||||
<ItemInputFieldComponent
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
template={template}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'color' && template.type === 'color') {
|
|
||||||
return (
|
|
||||||
<ColorInputFieldComponent
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
template={template}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'item' && template.type === 'item') {
|
|
||||||
return (
|
|
||||||
<ItemInputFieldComponent
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
template={template}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'image_collection' && template.type === 'image_collection') {
|
|
||||||
return (
|
|
||||||
<ImageCollectionInputFieldComponent
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
template={template}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return <Box p={2}>Unknown field type: {type}</Box>;
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(InputFieldComponent);
|
|
@ -0,0 +1,57 @@
|
|||||||
|
import { ChevronUpIcon } from '@chakra-ui/icons';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
|
import { nodeIsOpenChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import { NodeData } from 'features/nodes/types/types';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { NodeProps, useUpdateNodeInternals } from 'reactflow';
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
nodeProps: NodeProps<NodeData>;
|
||||||
|
}
|
||||||
|
|
||||||
|
const NodeCollapseButton = (props: Props) => {
|
||||||
|
const { id: nodeId, isOpen } = props.nodeProps.data;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const updateNodeInternals = useUpdateNodeInternals();
|
||||||
|
|
||||||
|
const handleClick = useCallback(() => {
|
||||||
|
dispatch(nodeIsOpenChanged({ nodeId, isOpen: !isOpen }));
|
||||||
|
updateNodeInternals(nodeId);
|
||||||
|
}, [dispatch, isOpen, nodeId, updateNodeInternals]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAIIconButton
|
||||||
|
className="nopan"
|
||||||
|
onClick={handleClick}
|
||||||
|
aria-label="Minimize"
|
||||||
|
sx={{
|
||||||
|
minW: 8,
|
||||||
|
w: 8,
|
||||||
|
h: 8,
|
||||||
|
color: 'base.500',
|
||||||
|
_dark: {
|
||||||
|
color: 'base.500',
|
||||||
|
},
|
||||||
|
_hover: {
|
||||||
|
color: 'base.700',
|
||||||
|
_dark: {
|
||||||
|
color: 'base.300',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
variant="link"
|
||||||
|
icon={
|
||||||
|
<ChevronUpIcon
|
||||||
|
sx={{
|
||||||
|
transform: isOpen ? 'rotate(0deg)' : 'rotate(180deg)',
|
||||||
|
transitionProperty: 'common',
|
||||||
|
transitionDuration: 'normal',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(NodeCollapseButton);
|
@ -0,0 +1,74 @@
|
|||||||
|
import { useColorModeValue } from '@chakra-ui/react';
|
||||||
|
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
|
||||||
|
import {
|
||||||
|
InvocationNodeData,
|
||||||
|
InvocationTemplate,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
|
import { map } from 'lodash-es';
|
||||||
|
import { CSSProperties, memo, useMemo } from 'react';
|
||||||
|
import { Handle, NodeProps, Position } from 'reactflow';
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
nodeProps: NodeProps<InvocationNodeData>;
|
||||||
|
nodeTemplate: InvocationTemplate;
|
||||||
|
}
|
||||||
|
|
||||||
|
const NodeCollapsedHandles = (props: Props) => {
|
||||||
|
const { data } = props.nodeProps;
|
||||||
|
const { base400, base600 } = useChakraThemeTokens();
|
||||||
|
const backgroundColor = useColorModeValue(base400, base600);
|
||||||
|
|
||||||
|
const dummyHandleStyles: CSSProperties = useMemo(
|
||||||
|
() => ({
|
||||||
|
borderWidth: 0,
|
||||||
|
borderRadius: '3px',
|
||||||
|
width: '1rem',
|
||||||
|
height: '1rem',
|
||||||
|
backgroundColor,
|
||||||
|
zIndex: -1,
|
||||||
|
}),
|
||||||
|
[backgroundColor]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<Handle
|
||||||
|
type="target"
|
||||||
|
id={`${data.id}-collapsed-target`}
|
||||||
|
isConnectable={false}
|
||||||
|
position={Position.Left}
|
||||||
|
style={{ ...dummyHandleStyles, left: '-0.5rem' }}
|
||||||
|
/>
|
||||||
|
{map(data.inputs, (input) => (
|
||||||
|
<Handle
|
||||||
|
key={`${data.id}-${input.name}-collapsed-input-handle`}
|
||||||
|
type="target"
|
||||||
|
id={input.name}
|
||||||
|
isValidConnection={() => false}
|
||||||
|
position={Position.Left}
|
||||||
|
style={{ visibility: 'hidden' }}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
<Handle
|
||||||
|
type="source"
|
||||||
|
id={`${data.id}-collapsed-source`}
|
||||||
|
isValidConnection={() => false}
|
||||||
|
isConnectable={false}
|
||||||
|
position={Position.Right}
|
||||||
|
style={{ ...dummyHandleStyles, right: '-0.5rem' }}
|
||||||
|
/>
|
||||||
|
{map(data.outputs, (output) => (
|
||||||
|
<Handle
|
||||||
|
key={`${data.id}-${output.name}-collapsed-output-handle`}
|
||||||
|
type="source"
|
||||||
|
id={output.name}
|
||||||
|
isValidConnection={() => false}
|
||||||
|
position={Position.Right}
|
||||||
|
style={{ visibility: 'hidden' }}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(NodeCollapsedHandles);
|
@ -0,0 +1,77 @@
|
|||||||
|
import {
|
||||||
|
Checkbox,
|
||||||
|
Flex,
|
||||||
|
FormControl,
|
||||||
|
FormLabel,
|
||||||
|
Spacer,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||||
|
import {
|
||||||
|
InvocationNodeData,
|
||||||
|
InvocationTemplate,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
|
import { some } from 'lodash-es';
|
||||||
|
import { ChangeEvent, memo, useCallback, useMemo } from 'react';
|
||||||
|
import { NodeProps } from 'reactflow';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
nodeProps: NodeProps<InvocationNodeData>;
|
||||||
|
nodeTemplate: InvocationTemplate;
|
||||||
|
};
|
||||||
|
|
||||||
|
const NodeFooter = (props: Props) => {
|
||||||
|
const { nodeProps, nodeTemplate } = props;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const hasImageOutput = useMemo(
|
||||||
|
() =>
|
||||||
|
some(nodeTemplate?.outputs, (output) =>
|
||||||
|
['ImageField', 'ImageCollection'].includes(output.type)
|
||||||
|
),
|
||||||
|
[nodeTemplate?.outputs]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleChangeIsIntermediate = useCallback(
|
||||||
|
(e: ChangeEvent<HTMLInputElement>) => {
|
||||||
|
dispatch(
|
||||||
|
fieldBooleanValueChanged({
|
||||||
|
nodeId: nodeProps.data.id,
|
||||||
|
fieldName: 'is_intermediate',
|
||||||
|
value: !e.target.checked,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[dispatch, nodeProps.data.id]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
className={DRAG_HANDLE_CLASSNAME}
|
||||||
|
layerStyle="nodeFooter"
|
||||||
|
sx={{
|
||||||
|
w: 'full',
|
||||||
|
borderBottomRadius: 'base',
|
||||||
|
px: 2,
|
||||||
|
py: 0,
|
||||||
|
h: 6,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Spacer />
|
||||||
|
{hasImageOutput && (
|
||||||
|
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
|
||||||
|
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save Output</FormLabel>
|
||||||
|
<Checkbox
|
||||||
|
className="nopan"
|
||||||
|
size="sm"
|
||||||
|
onChange={handleChangeIsIntermediate}
|
||||||
|
isChecked={!nodeProps.data.inputs['is_intermediate']?.value}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(NodeFooter);
|
@ -0,0 +1,113 @@
|
|||||||
|
import {
|
||||||
|
Flex,
|
||||||
|
FormControl,
|
||||||
|
FormLabel,
|
||||||
|
Icon,
|
||||||
|
Modal,
|
||||||
|
ModalBody,
|
||||||
|
ModalCloseButton,
|
||||||
|
ModalContent,
|
||||||
|
ModalFooter,
|
||||||
|
ModalHeader,
|
||||||
|
ModalOverlay,
|
||||||
|
Text,
|
||||||
|
Tooltip,
|
||||||
|
useDisclosure,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAITextarea from 'common/components/IAITextarea';
|
||||||
|
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||||
|
import {
|
||||||
|
InvocationNodeData,
|
||||||
|
InvocationTemplate,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
|
import { ChangeEvent, memo, useCallback } from 'react';
|
||||||
|
import { FaInfoCircle } from 'react-icons/fa';
|
||||||
|
import { NodeProps } from 'reactflow';
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
nodeProps: NodeProps<InvocationNodeData>;
|
||||||
|
nodeTemplate: InvocationTemplate;
|
||||||
|
}
|
||||||
|
|
||||||
|
const NodeNotesEdit = (props: Props) => {
|
||||||
|
const { nodeProps, nodeTemplate } = props;
|
||||||
|
const { data } = nodeProps;
|
||||||
|
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const handleNotesChanged = useCallback(
|
||||||
|
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||||
|
dispatch(nodeNotesChanged({ nodeId: data.id, notes: e.target.value }));
|
||||||
|
},
|
||||||
|
[data.id, dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<Tooltip
|
||||||
|
label={
|
||||||
|
nodeTemplate ? (
|
||||||
|
<TooltipContent nodeProps={nodeProps} nodeTemplate={nodeTemplate} />
|
||||||
|
) : undefined
|
||||||
|
}
|
||||||
|
placement="top"
|
||||||
|
shouldWrapChildren
|
||||||
|
>
|
||||||
|
<Flex
|
||||||
|
className={DRAG_HANDLE_CLASSNAME}
|
||||||
|
onClick={onOpen}
|
||||||
|
sx={{
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
w: 8,
|
||||||
|
h: 8,
|
||||||
|
cursor: 'pointer',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Icon
|
||||||
|
as={FaInfoCircle}
|
||||||
|
sx={{ boxSize: 4, w: 8, color: 'base.400' }}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
</Tooltip>
|
||||||
|
|
||||||
|
<Modal isOpen={isOpen} onClose={onClose} isCentered>
|
||||||
|
<ModalOverlay />
|
||||||
|
<ModalContent>
|
||||||
|
<ModalHeader>
|
||||||
|
{data.label || nodeTemplate?.title || 'Unknown Node'}
|
||||||
|
</ModalHeader>
|
||||||
|
<ModalCloseButton />
|
||||||
|
<ModalBody>
|
||||||
|
<FormControl>
|
||||||
|
<FormLabel>Notes</FormLabel>
|
||||||
|
<IAITextarea
|
||||||
|
value={data.notes}
|
||||||
|
onChange={handleNotesChanged}
|
||||||
|
rows={10}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
</ModalBody>
|
||||||
|
<ModalFooter />
|
||||||
|
</ModalContent>
|
||||||
|
</Modal>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(NodeNotesEdit);
|
||||||
|
|
||||||
|
type TooltipContentProps = Props;
|
||||||
|
|
||||||
|
const TooltipContent = (props: TooltipContentProps) => {
|
||||||
|
return (
|
||||||
|
<Flex sx={{ flexDir: 'column' }}>
|
||||||
|
<Text sx={{ fontWeight: 600 }}>{props.nodeTemplate?.title}</Text>
|
||||||
|
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
|
||||||
|
{props.nodeTemplate?.description}
|
||||||
|
</Text>
|
||||||
|
{props.nodeProps.data.notes && <Text>{props.nodeProps.data.notes}</Text>}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
@ -2,7 +2,10 @@ import { NODE_MIN_WIDTH } from 'features/nodes/types/constants';
|
|||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { NodeResizeControl, NodeResizerProps } from 'reactflow';
|
import { NodeResizeControl, NodeResizerProps } from 'reactflow';
|
||||||
|
|
||||||
const IAINodeResizer = (props: NodeResizerProps) => {
|
// this causes https://github.com/invoke-ai/InvokeAI/issues/4140
|
||||||
|
// not using it for now
|
||||||
|
|
||||||
|
const NodeResizer = (props: NodeResizerProps) => {
|
||||||
const { ...rest } = props;
|
const { ...rest } = props;
|
||||||
return (
|
return (
|
||||||
<NodeResizeControl
|
<NodeResizeControl
|
||||||
@ -21,4 +24,4 @@ const IAINodeResizer = (props: NodeResizerProps) => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default memo(IAINodeResizer);
|
export default memo(NodeResizer);
|
@ -0,0 +1,69 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
|
import IAIPopover from 'common/components/IAIPopover';
|
||||||
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
|
import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import { InvocationNodeData } from 'features/nodes/types/types';
|
||||||
|
import { ChangeEvent, memo, useCallback } from 'react';
|
||||||
|
import { FaBars } from 'react-icons/fa';
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
data: InvocationNodeData;
|
||||||
|
}
|
||||||
|
|
||||||
|
const NodeSettings = (props: Props) => {
|
||||||
|
const { data } = props;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const handleChangeIsIntermediate = useCallback(
|
||||||
|
(e: ChangeEvent<HTMLInputElement>) => {
|
||||||
|
dispatch(
|
||||||
|
fieldBooleanValueChanged({
|
||||||
|
nodeId: data.id,
|
||||||
|
fieldName: 'is_intermediate',
|
||||||
|
value: e.target.checked,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[data.id, dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAIPopover
|
||||||
|
isLazy={false}
|
||||||
|
triggerComponent={
|
||||||
|
<IAIIconButton
|
||||||
|
className="nopan"
|
||||||
|
aria-label="Node Settings"
|
||||||
|
variant="link"
|
||||||
|
sx={{
|
||||||
|
minW: 8,
|
||||||
|
color: 'base.500',
|
||||||
|
_dark: {
|
||||||
|
color: 'base.500',
|
||||||
|
},
|
||||||
|
_hover: {
|
||||||
|
color: 'base.700',
|
||||||
|
_dark: {
|
||||||
|
color: 'base.300',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
icon={<FaBars />}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
>
|
||||||
|
<Flex sx={{ flexDir: 'column', gap: 4, w: 64 }}>
|
||||||
|
<IAISwitch
|
||||||
|
label="Intermediate"
|
||||||
|
isChecked={Boolean(data.inputs['is_intermediate']?.value)}
|
||||||
|
onChange={handleChangeIsIntermediate}
|
||||||
|
helperText="The outputs of intermediate nodes are considered temporary objects. Intermediate images are not added to the gallery."
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
</IAIPopover>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(NodeSettings);
|
@ -0,0 +1,185 @@
|
|||||||
|
import {
|
||||||
|
Badge,
|
||||||
|
CircularProgress,
|
||||||
|
Flex,
|
||||||
|
Icon,
|
||||||
|
Image,
|
||||||
|
Text,
|
||||||
|
Tooltip,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||||
|
import {
|
||||||
|
InvocationNodeData,
|
||||||
|
NodeExecutionState,
|
||||||
|
NodeStatus,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
|
import { memo, useMemo } from 'react';
|
||||||
|
import { FaCheck, FaEllipsisH, FaExclamation } from 'react-icons/fa';
|
||||||
|
import { NodeProps } from 'reactflow';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
nodeProps: NodeProps<InvocationNodeData>;
|
||||||
|
};
|
||||||
|
|
||||||
|
const iconBoxSize = 3;
|
||||||
|
const circleStyles = {
|
||||||
|
circle: {
|
||||||
|
transitionProperty: 'none',
|
||||||
|
transitionDuration: '0s',
|
||||||
|
},
|
||||||
|
'.chakra-progress__track': { stroke: 'transparent' },
|
||||||
|
};
|
||||||
|
|
||||||
|
const NodeStatusIndicator = (props: Props) => {
|
||||||
|
const nodeId = props.nodeProps.data.id;
|
||||||
|
const selectNodeExecutionState = useMemo(
|
||||||
|
() =>
|
||||||
|
createSelector(
|
||||||
|
stateSelector,
|
||||||
|
({ nodes }) => nodes.nodeExecutionStates[nodeId]
|
||||||
|
),
|
||||||
|
[nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
const nodeExecutionState = useAppSelector(selectNodeExecutionState);
|
||||||
|
|
||||||
|
if (!nodeExecutionState) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Tooltip
|
||||||
|
label={<TooltipLabel nodeExecutionState={nodeExecutionState} />}
|
||||||
|
placement="top"
|
||||||
|
>
|
||||||
|
<Flex
|
||||||
|
className={DRAG_HANDLE_CLASSNAME}
|
||||||
|
sx={{
|
||||||
|
w: 5,
|
||||||
|
h: 'full',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'flex-end',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<StatusIcon nodeExecutionState={nodeExecutionState} />
|
||||||
|
</Flex>
|
||||||
|
</Tooltip>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(NodeStatusIndicator);
|
||||||
|
|
||||||
|
type TooltipLabelProps = {
|
||||||
|
nodeExecutionState: NodeExecutionState;
|
||||||
|
};
|
||||||
|
|
||||||
|
const TooltipLabel = ({ nodeExecutionState }: TooltipLabelProps) => {
|
||||||
|
const { status, progress, progressImage } = nodeExecutionState;
|
||||||
|
if (status === NodeStatus.PENDING) {
|
||||||
|
return <Text>Pending</Text>;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (status === NodeStatus.IN_PROGRESS) {
|
||||||
|
if (progressImage) {
|
||||||
|
return (
|
||||||
|
<Flex sx={{ pos: 'relative', pt: 1.5, pb: 0.5 }}>
|
||||||
|
<Image
|
||||||
|
src={progressImage.dataURL}
|
||||||
|
sx={{ w: 32, h: 32, borderRadius: 'base', objectFit: 'contain' }}
|
||||||
|
/>
|
||||||
|
{progress !== null && (
|
||||||
|
<Badge
|
||||||
|
variant="solid"
|
||||||
|
sx={{ pos: 'absolute', top: 2.5, insetInlineEnd: 1 }}
|
||||||
|
>
|
||||||
|
{Math.round(progress * 100)}%
|
||||||
|
</Badge>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (progress !== null) {
|
||||||
|
return <Text>In Progress ({Math.round(progress * 100)}%)</Text>;
|
||||||
|
}
|
||||||
|
|
||||||
|
return <Text>In Progress</Text>;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (status === NodeStatus.COMPLETED) {
|
||||||
|
return <Text>Completed</Text>;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (status === NodeStatus.FAILED) {
|
||||||
|
return <Text>nodeExecutionState.error</Text>;
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
};
|
||||||
|
|
||||||
|
type StatusIconProps = {
|
||||||
|
nodeExecutionState: NodeExecutionState;
|
||||||
|
};
|
||||||
|
|
||||||
|
const StatusIcon = (props: StatusIconProps) => {
|
||||||
|
const { progress, status } = props.nodeExecutionState;
|
||||||
|
if (status === NodeStatus.PENDING) {
|
||||||
|
return (
|
||||||
|
<Icon
|
||||||
|
as={FaEllipsisH}
|
||||||
|
sx={{
|
||||||
|
boxSize: iconBoxSize,
|
||||||
|
color: 'base.600',
|
||||||
|
_dark: { color: 'base.300' },
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (status === NodeStatus.IN_PROGRESS) {
|
||||||
|
return progress === null ? (
|
||||||
|
<CircularProgress
|
||||||
|
isIndeterminate
|
||||||
|
size="14px"
|
||||||
|
color="base.500"
|
||||||
|
thickness={14}
|
||||||
|
sx={circleStyles}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<CircularProgress
|
||||||
|
value={Math.round(progress * 100)}
|
||||||
|
size="14px"
|
||||||
|
color="base.500"
|
||||||
|
thickness={14}
|
||||||
|
sx={circleStyles}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (status === NodeStatus.COMPLETED) {
|
||||||
|
return (
|
||||||
|
<Icon
|
||||||
|
as={FaCheck}
|
||||||
|
sx={{
|
||||||
|
boxSize: iconBoxSize,
|
||||||
|
color: 'ok.600',
|
||||||
|
_dark: { color: 'ok.300' },
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (status === NodeStatus.FAILED) {
|
||||||
|
return (
|
||||||
|
<Icon
|
||||||
|
as={FaExclamation}
|
||||||
|
sx={{
|
||||||
|
boxSize: iconBoxSize,
|
||||||
|
color: 'error.600',
|
||||||
|
_dark: { color: 'error.300' },
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
};
|
@ -0,0 +1,123 @@
|
|||||||
|
import {
|
||||||
|
Box,
|
||||||
|
Editable,
|
||||||
|
EditableInput,
|
||||||
|
EditablePreview,
|
||||||
|
Flex,
|
||||||
|
useEditableControls,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { nodeLabelChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||||
|
import { NodeData } from 'features/nodes/types/types';
|
||||||
|
import { MouseEvent, memo, useCallback, useEffect, useState } from 'react';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
nodeData: NodeData;
|
||||||
|
title: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
const NodeTitle = (props: Props) => {
|
||||||
|
const { title } = props;
|
||||||
|
const { id: nodeId, label } = props.nodeData;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const [localTitle, setLocalTitle] = useState(label || title);
|
||||||
|
|
||||||
|
const handleSubmit = useCallback(
|
||||||
|
async (newTitle: string) => {
|
||||||
|
dispatch(nodeLabelChanged({ nodeId, label: newTitle }));
|
||||||
|
setLocalTitle(newTitle || title);
|
||||||
|
},
|
||||||
|
[nodeId, dispatch, title]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleChange = useCallback((newTitle: string) => {
|
||||||
|
setLocalTitle(newTitle);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
// Another component may change the title; sync local title with global state
|
||||||
|
setLocalTitle(label || title);
|
||||||
|
}, [label, title]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
className="nopan"
|
||||||
|
sx={{
|
||||||
|
overflow: 'hidden',
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
cursor: 'text',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Editable
|
||||||
|
as={Flex}
|
||||||
|
value={localTitle}
|
||||||
|
onChange={handleChange}
|
||||||
|
onSubmit={handleSubmit}
|
||||||
|
sx={{
|
||||||
|
alignItems: 'center',
|
||||||
|
position: 'relative',
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<EditablePreview
|
||||||
|
fontSize="sm"
|
||||||
|
sx={{
|
||||||
|
p: 0,
|
||||||
|
w: 'full',
|
||||||
|
}}
|
||||||
|
noOfLines={1}
|
||||||
|
/>
|
||||||
|
<EditableInput
|
||||||
|
fontSize="sm"
|
||||||
|
sx={{
|
||||||
|
p: 0,
|
||||||
|
_focusVisible: {
|
||||||
|
p: 0,
|
||||||
|
boxShadow: 'none',
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
<EditableControls />
|
||||||
|
</Editable>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(NodeTitle);
|
||||||
|
|
||||||
|
function EditableControls() {
|
||||||
|
const { isEditing, getEditButtonProps } = useEditableControls();
|
||||||
|
const handleDoubleClick = useCallback(
|
||||||
|
(e: MouseEvent<HTMLDivElement>) => {
|
||||||
|
const { onClick } = getEditButtonProps();
|
||||||
|
if (!onClick) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
onClick(e);
|
||||||
|
},
|
||||||
|
[getEditButtonProps]
|
||||||
|
);
|
||||||
|
|
||||||
|
if (isEditing) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Box
|
||||||
|
className={DRAG_HANDLE_CLASSNAME}
|
||||||
|
onDoubleClick={handleDoubleClick}
|
||||||
|
sx={{
|
||||||
|
position: 'absolute',
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
top: 0,
|
||||||
|
cursor: 'grab',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
@ -0,0 +1,96 @@
|
|||||||
|
import {
|
||||||
|
Box,
|
||||||
|
ChakraProps,
|
||||||
|
useColorModeValue,
|
||||||
|
useToken,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { nodeClicked } from 'features/nodes/store/nodesSlice';
|
||||||
|
import { MouseEvent, PropsWithChildren, useCallback, useMemo } from 'react';
|
||||||
|
import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from '../../types/constants';
|
||||||
|
import { NodeData } from 'features/nodes/types/types';
|
||||||
|
import { NodeProps } from 'reactflow';
|
||||||
|
|
||||||
|
const useNodeSelect = (nodeId: string) => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const selectNode = useCallback(
|
||||||
|
(e: MouseEvent<HTMLDivElement>) => {
|
||||||
|
dispatch(nodeClicked({ nodeId, ctrlOrMeta: e.ctrlKey || e.metaKey }));
|
||||||
|
},
|
||||||
|
[dispatch, nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
return selectNode;
|
||||||
|
};
|
||||||
|
|
||||||
|
type NodeWrapperProps = PropsWithChildren & {
|
||||||
|
nodeProps: NodeProps<NodeData>;
|
||||||
|
width?: NonNullable<ChakraProps['sx']>['w'];
|
||||||
|
};
|
||||||
|
|
||||||
|
const NodeWrapper = (props: NodeWrapperProps) => {
|
||||||
|
const { width, children, nodeProps } = props;
|
||||||
|
const { data, selected } = nodeProps;
|
||||||
|
const nodeId = data.id;
|
||||||
|
|
||||||
|
const [
|
||||||
|
nodeSelectedOutlineLight,
|
||||||
|
nodeSelectedOutlineDark,
|
||||||
|
shadowsXl,
|
||||||
|
shadowsBase,
|
||||||
|
] = useToken('shadows', [
|
||||||
|
'nodeSelectedOutline.light',
|
||||||
|
'nodeSelectedOutline.dark',
|
||||||
|
'shadows.xl',
|
||||||
|
'shadows.base',
|
||||||
|
]);
|
||||||
|
|
||||||
|
const selectNode = useNodeSelect(nodeId);
|
||||||
|
|
||||||
|
const shadow = useColorModeValue(
|
||||||
|
nodeSelectedOutlineLight,
|
||||||
|
nodeSelectedOutlineDark
|
||||||
|
);
|
||||||
|
|
||||||
|
const shift = useAppSelector((state) => state.hotkeys.shift);
|
||||||
|
const opacity = useAppSelector((state) => state.nodes.nodeOpacity);
|
||||||
|
const className = useMemo(
|
||||||
|
() => (shift ? DRAG_HANDLE_CLASSNAME : 'nopan'),
|
||||||
|
[shift]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Box
|
||||||
|
onClickCapture={selectNode}
|
||||||
|
className={className}
|
||||||
|
sx={{
|
||||||
|
h: 'full',
|
||||||
|
position: 'relative',
|
||||||
|
borderRadius: 'base',
|
||||||
|
w: width ?? NODE_WIDTH,
|
||||||
|
transitionProperty: 'common',
|
||||||
|
transitionDuration: '0.1s',
|
||||||
|
shadow: selected ? shadow : undefined,
|
||||||
|
opacity,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Box
|
||||||
|
sx={{
|
||||||
|
position: 'absolute',
|
||||||
|
top: 0,
|
||||||
|
insetInlineEnd: 0,
|
||||||
|
bottom: 0,
|
||||||
|
insetInlineStart: 0,
|
||||||
|
borderRadius: 'base',
|
||||||
|
pointerEvents: 'none',
|
||||||
|
shadow: `${shadowsXl}, ${shadowsBase}, ${shadowsBase}`,
|
||||||
|
zIndex: -1,
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
{children}
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default NodeWrapper;
|
@ -1,74 +0,0 @@
|
|||||||
import { Flex, Icon } from '@chakra-ui/react';
|
|
||||||
import { FaExclamationCircle } from 'react-icons/fa';
|
|
||||||
import { NodeProps } from 'reactflow';
|
|
||||||
import { InvocationValue } from '../types/types';
|
|
||||||
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { memo, useMemo } from 'react';
|
|
||||||
import { makeTemplateSelector } from '../store/util/makeTemplateSelector';
|
|
||||||
import IAINodeHeader from './IAINode/IAINodeHeader';
|
|
||||||
import IAINodeInputs from './IAINode/IAINodeInputs';
|
|
||||||
import IAINodeOutputs from './IAINode/IAINodeOutputs';
|
|
||||||
import IAINodeResizer from './IAINode/IAINodeResizer';
|
|
||||||
import NodeWrapper from './NodeWrapper';
|
|
||||||
|
|
||||||
export const InvocationComponent = memo((props: NodeProps<InvocationValue>) => {
|
|
||||||
const { id: nodeId, data, selected } = props;
|
|
||||||
const { type, inputs, outputs } = data;
|
|
||||||
|
|
||||||
const templateSelector = useMemo(() => makeTemplateSelector(type), [type]);
|
|
||||||
|
|
||||||
const template = useAppSelector(templateSelector);
|
|
||||||
|
|
||||||
if (!template) {
|
|
||||||
return (
|
|
||||||
<NodeWrapper selected={selected}>
|
|
||||||
<Flex
|
|
||||||
className="nopan"
|
|
||||||
sx={{
|
|
||||||
alignItems: 'center',
|
|
||||||
justifyContent: 'center',
|
|
||||||
cursor: 'auto',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<Icon
|
|
||||||
as={FaExclamationCircle}
|
|
||||||
sx={{
|
|
||||||
boxSize: 32,
|
|
||||||
color: 'base.600',
|
|
||||||
_dark: { color: 'base.400' },
|
|
||||||
}}
|
|
||||||
></Icon>
|
|
||||||
<IAINodeResizer />
|
|
||||||
</Flex>
|
|
||||||
</NodeWrapper>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<NodeWrapper selected={selected}>
|
|
||||||
<IAINodeHeader
|
|
||||||
nodeId={nodeId}
|
|
||||||
title={template.title}
|
|
||||||
description={template.description}
|
|
||||||
/>
|
|
||||||
<Flex
|
|
||||||
className={'nopan'}
|
|
||||||
sx={{
|
|
||||||
cursor: 'auto',
|
|
||||||
flexDirection: 'column',
|
|
||||||
borderBottomRadius: 'md',
|
|
||||||
py: 2,
|
|
||||||
bg: 'base.150',
|
|
||||||
_dark: { bg: 'base.800' },
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<IAINodeOutputs nodeId={nodeId} outputs={outputs} template={template} />
|
|
||||||
<IAINodeInputs nodeId={nodeId} inputs={inputs} template={template} />
|
|
||||||
</Flex>
|
|
||||||
<IAINodeResizer />
|
|
||||||
</NodeWrapper>
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
InvocationComponent.displayName = 'InvocationComponent';
|
|
@ -1,25 +1,45 @@
|
|||||||
import { Box } from '@chakra-ui/react';
|
import { Box } from '@chakra-ui/react';
|
||||||
import { ReactFlowProvider } from 'reactflow';
|
import ResizeHandle from 'features/ui/components/tabs/ResizeHandle';
|
||||||
|
import { memo, useState } from 'react';
|
||||||
|
import { Panel, PanelGroup } from 'react-resizable-panels';
|
||||||
import 'reactflow/dist/style.css';
|
import 'reactflow/dist/style.css';
|
||||||
|
|
||||||
import { memo } from 'react';
|
|
||||||
import { Flow } from './Flow';
|
import { Flow } from './Flow';
|
||||||
|
import NodeEditorPanelGroup from './panel/NodeEditorPanelGroup';
|
||||||
|
|
||||||
const NodeEditor = () => {
|
const NodeEditor = () => {
|
||||||
|
const [isPanelCollapsed, setIsPanelCollapsed] = useState(false);
|
||||||
return (
|
return (
|
||||||
<Box
|
<PanelGroup
|
||||||
layerStyle={'first'}
|
id="node-editor"
|
||||||
sx={{
|
autoSaveId="node-editor"
|
||||||
position: 'relative',
|
direction="horizontal"
|
||||||
width: 'full',
|
style={{ height: '100%', width: '100%' }}
|
||||||
height: 'full',
|
|
||||||
borderRadius: 'base',
|
|
||||||
}}
|
|
||||||
>
|
>
|
||||||
<ReactFlowProvider>
|
<Panel
|
||||||
<Flow />
|
id="node-editor-panel-group"
|
||||||
</ReactFlowProvider>
|
collapsible
|
||||||
</Box>
|
onCollapse={setIsPanelCollapsed}
|
||||||
|
minSize={25}
|
||||||
|
>
|
||||||
|
<NodeEditorPanelGroup />
|
||||||
|
</Panel>
|
||||||
|
<ResizeHandle
|
||||||
|
collapsedDirection={isPanelCollapsed ? 'left' : undefined}
|
||||||
|
/>
|
||||||
|
<Panel id="node-editor-content">
|
||||||
|
<Box
|
||||||
|
layerStyle={'first'}
|
||||||
|
sx={{
|
||||||
|
position: 'relative',
|
||||||
|
width: 'full',
|
||||||
|
height: 'full',
|
||||||
|
borderRadius: 'base',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Flow />
|
||||||
|
</Box>
|
||||||
|
</Panel>
|
||||||
|
</PanelGroup>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -0,0 +1,139 @@
|
|||||||
|
import {
|
||||||
|
Divider,
|
||||||
|
Flex,
|
||||||
|
Heading,
|
||||||
|
Modal,
|
||||||
|
ModalBody,
|
||||||
|
ModalCloseButton,
|
||||||
|
ModalContent,
|
||||||
|
ModalHeader,
|
||||||
|
ModalOverlay,
|
||||||
|
useDisclosure,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
|
import { ChangeEvent, useCallback } from 'react';
|
||||||
|
import { FaCog } from 'react-icons/fa';
|
||||||
|
import {
|
||||||
|
shouldAnimateEdgesChanged,
|
||||||
|
shouldColorEdgesChanged,
|
||||||
|
shouldSnapToGridChanged,
|
||||||
|
shouldValidateGraphChanged,
|
||||||
|
} from '../store/nodesSlice';
|
||||||
|
|
||||||
|
const selector = createSelector(stateSelector, ({ nodes }) => {
|
||||||
|
const {
|
||||||
|
shouldAnimateEdges,
|
||||||
|
shouldValidateGraph,
|
||||||
|
shouldSnapToGrid,
|
||||||
|
shouldColorEdges,
|
||||||
|
} = nodes;
|
||||||
|
return {
|
||||||
|
shouldAnimateEdges,
|
||||||
|
shouldValidateGraph,
|
||||||
|
shouldSnapToGrid,
|
||||||
|
shouldColorEdges,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
const NodeEditorSettings = () => {
|
||||||
|
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const {
|
||||||
|
shouldAnimateEdges,
|
||||||
|
shouldValidateGraph,
|
||||||
|
shouldSnapToGrid,
|
||||||
|
shouldColorEdges,
|
||||||
|
} = useAppSelector(selector);
|
||||||
|
|
||||||
|
const handleChangeShouldValidate = useCallback(
|
||||||
|
(e: ChangeEvent<HTMLInputElement>) => {
|
||||||
|
dispatch(shouldValidateGraphChanged(e.target.checked));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleChangeShouldAnimate = useCallback(
|
||||||
|
(e: ChangeEvent<HTMLInputElement>) => {
|
||||||
|
dispatch(shouldAnimateEdgesChanged(e.target.checked));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleChangeShouldSnap = useCallback(
|
||||||
|
(e: ChangeEvent<HTMLInputElement>) => {
|
||||||
|
dispatch(shouldSnapToGridChanged(e.target.checked));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleChangeShouldColor = useCallback(
|
||||||
|
(e: ChangeEvent<HTMLInputElement>) => {
|
||||||
|
dispatch(shouldColorEdgesChanged(e.target.checked));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<IAIIconButton
|
||||||
|
aria-label="Node Editor Settings"
|
||||||
|
icon={<FaCog />}
|
||||||
|
onClick={onOpen}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<Modal isOpen={isOpen} onClose={onClose} size="2xl" isCentered>
|
||||||
|
<ModalOverlay />
|
||||||
|
<ModalContent>
|
||||||
|
<ModalHeader>Node Editor Settings</ModalHeader>
|
||||||
|
<ModalCloseButton />
|
||||||
|
<ModalBody>
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
flexDirection: 'column',
|
||||||
|
gap: 4,
|
||||||
|
py: 4,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Heading size="sm">General</Heading>
|
||||||
|
<IAISwitch
|
||||||
|
onChange={handleChangeShouldAnimate}
|
||||||
|
isChecked={shouldAnimateEdges}
|
||||||
|
label="Animated Edges"
|
||||||
|
helperText="Animate selected edges and edges connected to selected nodes"
|
||||||
|
/>
|
||||||
|
<Divider />
|
||||||
|
<IAISwitch
|
||||||
|
isChecked={shouldSnapToGrid}
|
||||||
|
onChange={handleChangeShouldSnap}
|
||||||
|
label="Snap to Grid"
|
||||||
|
helperText="Snap nodes to grid when moved"
|
||||||
|
/>
|
||||||
|
<Divider />
|
||||||
|
<IAISwitch
|
||||||
|
isChecked={shouldColorEdges}
|
||||||
|
onChange={handleChangeShouldColor}
|
||||||
|
label="Color-Code Edges"
|
||||||
|
helperText="Color-code edges according to their connected fields"
|
||||||
|
/>
|
||||||
|
<Heading size="sm" pt={4}>
|
||||||
|
Advanced
|
||||||
|
</Heading>
|
||||||
|
<IAISwitch
|
||||||
|
isChecked={shouldValidateGraph}
|
||||||
|
onChange={handleChangeShouldValidate}
|
||||||
|
label="Validate Connections and Graph"
|
||||||
|
helperText="Prevent invalid connections from being made, and invalid graphs from being invoked"
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
</ModalBody>
|
||||||
|
</ModalContent>
|
||||||
|
</Modal>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default NodeEditorSettings;
|
@ -1,34 +1,26 @@
|
|||||||
import { Box } from '@chakra-ui/react';
|
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { memo } from 'react';
|
import ImageMetadataJSON from 'features/gallery/components/ImageMetadataViewer/ImageMetadataJSON';
|
||||||
|
import { omit } from 'lodash-es';
|
||||||
|
import { useMemo } from 'react';
|
||||||
|
import { useDebounce } from 'use-debounce';
|
||||||
import { buildNodesGraph } from '../util/graphBuilders/buildNodesGraph';
|
import { buildNodesGraph } from '../util/graphBuilders/buildNodesGraph';
|
||||||
|
|
||||||
const NodeGraphOverlay = () => {
|
const useNodesGraph = () => {
|
||||||
const state = useAppSelector((state: RootState) => state);
|
const nodes = useAppSelector((state: RootState) => state.nodes);
|
||||||
const graph = buildNodesGraph(state);
|
const [debouncedNodes] = useDebounce(nodes, 300);
|
||||||
|
const graph = useMemo(
|
||||||
return (
|
() => omit(buildNodesGraph(debouncedNodes), 'id'),
|
||||||
<Box
|
[debouncedNodes]
|
||||||
as="pre"
|
|
||||||
sx={{
|
|
||||||
fontFamily: 'monospace',
|
|
||||||
position: 'absolute',
|
|
||||||
top: 2,
|
|
||||||
right: 2,
|
|
||||||
opacity: 0.7,
|
|
||||||
p: 2,
|
|
||||||
maxHeight: 500,
|
|
||||||
maxWidth: 500,
|
|
||||||
overflowY: 'scroll',
|
|
||||||
borderRadius: 'base',
|
|
||||||
bg: 'base.200',
|
|
||||||
_dark: { bg: 'base.800' },
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
{JSON.stringify(graph, null, 2)}
|
|
||||||
</Box>
|
|
||||||
);
|
);
|
||||||
|
|
||||||
|
return graph;
|
||||||
};
|
};
|
||||||
|
|
||||||
export default memo(NodeGraphOverlay);
|
const NodeGraph = () => {
|
||||||
|
const graph = useNodesGraph();
|
||||||
|
|
||||||
|
return <ImageMetadataJSON jsonObject={graph} label="Graph" />;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default NodeGraph;
|
||||||
|
@ -0,0 +1,42 @@
|
|||||||
|
import {
|
||||||
|
Box,
|
||||||
|
Slider,
|
||||||
|
SliderFilledTrack,
|
||||||
|
SliderThumb,
|
||||||
|
SliderTrack,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { useCallback } from 'react';
|
||||||
|
import { nodeOpacityChanged } from '../store/nodesSlice';
|
||||||
|
|
||||||
|
export default function NodeOpacitySlider() {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const nodeOpacity = useAppSelector((state) => state.nodes.nodeOpacity);
|
||||||
|
|
||||||
|
const handleChange = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
dispatch(nodeOpacityChanged(v));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Box>
|
||||||
|
<Slider
|
||||||
|
aria-label="Node Opacity"
|
||||||
|
value={nodeOpacity}
|
||||||
|
min={0.5}
|
||||||
|
max={1}
|
||||||
|
step={0.01}
|
||||||
|
onChange={handleChange}
|
||||||
|
orientation="vertical"
|
||||||
|
defaultValue={30}
|
||||||
|
>
|
||||||
|
<SliderTrack>
|
||||||
|
<SliderFilledTrack />
|
||||||
|
</SliderTrack>
|
||||||
|
<SliderThumb />
|
||||||
|
</Slider>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
}
|
@ -1,36 +0,0 @@
|
|||||||
import { Box, useToken } from '@chakra-ui/react';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { PropsWithChildren } from 'react';
|
|
||||||
import { DRAG_HANDLE_CLASSNAME } from '../hooks/useBuildInvocation';
|
|
||||||
import { NODE_MIN_WIDTH } from '../types/constants';
|
|
||||||
|
|
||||||
type NodeWrapperProps = PropsWithChildren & {
|
|
||||||
selected: boolean;
|
|
||||||
};
|
|
||||||
|
|
||||||
const NodeWrapper = (props: NodeWrapperProps) => {
|
|
||||||
const [nodeSelectedOutline, nodeShadow] = useToken('shadows', [
|
|
||||||
'nodeSelectedOutline',
|
|
||||||
'dark-lg',
|
|
||||||
]);
|
|
||||||
|
|
||||||
const shift = useAppSelector((state) => state.hotkeys.shift);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Box
|
|
||||||
className={shift ? DRAG_HANDLE_CLASSNAME : 'nopan'}
|
|
||||||
sx={{
|
|
||||||
position: 'relative',
|
|
||||||
borderRadius: 'md',
|
|
||||||
minWidth: NODE_MIN_WIDTH,
|
|
||||||
shadow: props.selected
|
|
||||||
? `${nodeSelectedOutline}, ${nodeShadow}`
|
|
||||||
: `${nodeShadow}`,
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
{props.children}
|
|
||||||
</Box>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default NodeWrapper;
|
|
@ -1,73 +0,0 @@
|
|||||||
import { Flex, Image } from '@chakra-ui/react';
|
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
|
||||||
import { memo } from 'react';
|
|
||||||
import { useDispatch, useSelector } from 'react-redux';
|
|
||||||
import { NodeProps, OnResize } from 'reactflow';
|
|
||||||
import { setProgressNodeSize } from '../store/nodesSlice';
|
|
||||||
import IAINodeHeader from './IAINode/IAINodeHeader';
|
|
||||||
import IAINodeResizer from './IAINode/IAINodeResizer';
|
|
||||||
import NodeWrapper from './NodeWrapper';
|
|
||||||
|
|
||||||
const ProgressImageNode = (props: NodeProps) => {
|
|
||||||
const progressImage = useSelector(
|
|
||||||
(state: RootState) => state.system.progressImage
|
|
||||||
);
|
|
||||||
const progressNodeSize = useSelector(
|
|
||||||
(state: RootState) => state.nodes.progressNodeSize
|
|
||||||
);
|
|
||||||
const dispatch = useDispatch();
|
|
||||||
const { selected } = props;
|
|
||||||
|
|
||||||
const handleResize: OnResize = (_, newSize) => {
|
|
||||||
dispatch(setProgressNodeSize(newSize));
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
<NodeWrapper selected={selected}>
|
|
||||||
<IAINodeHeader
|
|
||||||
title="Progress Image"
|
|
||||||
description="Displays the progress image in the Node Editor"
|
|
||||||
/>
|
|
||||||
<Flex
|
|
||||||
sx={{
|
|
||||||
flexDirection: 'column',
|
|
||||||
flexShrink: 0,
|
|
||||||
borderBottomRadius: 'md',
|
|
||||||
bg: 'base.200',
|
|
||||||
_dark: { bg: 'base.800' },
|
|
||||||
width: progressNodeSize.width - 2,
|
|
||||||
height: progressNodeSize.height - 2,
|
|
||||||
minW: 250,
|
|
||||||
minH: 250,
|
|
||||||
overflow: 'hidden',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
{progressImage ? (
|
|
||||||
<Image
|
|
||||||
src={progressImage.dataURL}
|
|
||||||
sx={{
|
|
||||||
w: 'full',
|
|
||||||
h: 'full',
|
|
||||||
objectFit: 'contain',
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
) : (
|
|
||||||
<Flex
|
|
||||||
sx={{
|
|
||||||
minW: 250,
|
|
||||||
minH: 250,
|
|
||||||
width: progressNodeSize.width - 2,
|
|
||||||
height: progressNodeSize.height - 2,
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<IAINoContentFallback />
|
|
||||||
</Flex>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
<IAINodeResizer onResize={handleResize} />
|
|
||||||
</NodeWrapper>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(ProgressImageNode);
|
|
@ -2,18 +2,16 @@ import { ButtonGroup, Tooltip } from '@chakra-ui/react';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import {
|
|
||||||
FaCode,
|
|
||||||
FaExpand,
|
|
||||||
FaMinus,
|
|
||||||
FaPlus,
|
|
||||||
FaInfo,
|
|
||||||
FaMapMarkerAlt,
|
|
||||||
} from 'react-icons/fa';
|
|
||||||
import { useReactFlow } from 'reactflow';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import {
|
import {
|
||||||
shouldShowGraphOverlayChanged,
|
FaExpand,
|
||||||
|
FaInfo,
|
||||||
|
FaMapMarkerAlt,
|
||||||
|
FaMinus,
|
||||||
|
FaPlus,
|
||||||
|
} from 'react-icons/fa';
|
||||||
|
import { useReactFlow } from 'reactflow';
|
||||||
|
import {
|
||||||
shouldShowFieldTypeLegendChanged,
|
shouldShowFieldTypeLegendChanged,
|
||||||
shouldShowMinimapPanelChanged,
|
shouldShowMinimapPanelChanged,
|
||||||
} from '../store/nodesSlice';
|
} from '../store/nodesSlice';
|
||||||
@ -22,9 +20,6 @@ const ViewportControls = () => {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { zoomIn, zoomOut, fitView } = useReactFlow();
|
const { zoomIn, zoomOut, fitView } = useReactFlow();
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const shouldShowGraphOverlay = useAppSelector(
|
|
||||||
(state) => state.nodes.shouldShowGraphOverlay
|
|
||||||
);
|
|
||||||
const shouldShowFieldTypeLegend = useAppSelector(
|
const shouldShowFieldTypeLegend = useAppSelector(
|
||||||
(state) => state.nodes.shouldShowFieldTypeLegend
|
(state) => state.nodes.shouldShowFieldTypeLegend
|
||||||
);
|
);
|
||||||
@ -44,10 +39,6 @@ const ViewportControls = () => {
|
|||||||
fitView();
|
fitView();
|
||||||
}, [fitView]);
|
}, [fitView]);
|
||||||
|
|
||||||
const handleClickedToggleGraphOverlay = useCallback(() => {
|
|
||||||
dispatch(shouldShowGraphOverlayChanged(!shouldShowGraphOverlay));
|
|
||||||
}, [shouldShowGraphOverlay, dispatch]);
|
|
||||||
|
|
||||||
const handleClickedToggleFieldTypeLegend = useCallback(() => {
|
const handleClickedToggleFieldTypeLegend = useCallback(() => {
|
||||||
dispatch(shouldShowFieldTypeLegendChanged(!shouldShowFieldTypeLegend));
|
dispatch(shouldShowFieldTypeLegendChanged(!shouldShowFieldTypeLegend));
|
||||||
}, [shouldShowFieldTypeLegend, dispatch]);
|
}, [shouldShowFieldTypeLegend, dispatch]);
|
||||||
@ -79,20 +70,6 @@ const ViewportControls = () => {
|
|||||||
icon={<FaExpand />}
|
icon={<FaExpand />}
|
||||||
/>
|
/>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
<Tooltip
|
|
||||||
label={
|
|
||||||
shouldShowGraphOverlay
|
|
||||||
? t('nodes.hideGraphNodes')
|
|
||||||
: t('nodes.showGraphNodes')
|
|
||||||
}
|
|
||||||
>
|
|
||||||
<IAIIconButton
|
|
||||||
aria-label="Toggle nodes graph overlay"
|
|
||||||
isChecked={shouldShowGraphOverlay}
|
|
||||||
onClick={handleClickedToggleGraphOverlay}
|
|
||||||
icon={<FaCode />}
|
|
||||||
/>
|
|
||||||
</Tooltip>
|
|
||||||
<Tooltip
|
<Tooltip
|
||||||
label={
|
label={
|
||||||
shouldShowFieldTypeLegend
|
shouldShowFieldTypeLegend
|
||||||
|
@ -1,10 +1,15 @@
|
|||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { Panel } from 'reactflow';
|
import { Panel } from 'reactflow';
|
||||||
import ViewportControls from '../ViewportControls';
|
import ViewportControls from '../ViewportControls';
|
||||||
|
import NodeOpacitySlider from '../NodeOpacitySlider';
|
||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
|
||||||
const BottomLeftPanel = () => (
|
const BottomLeftPanel = () => (
|
||||||
<Panel position="bottom-left">
|
<Panel position="bottom-left">
|
||||||
<ViewportControls />
|
<Flex sx={{ gap: 2 }}>
|
||||||
|
<ViewportControls />
|
||||||
|
<NodeOpacitySlider />
|
||||||
|
</Flex>
|
||||||
</Panel>
|
</Panel>
|
||||||
);
|
);
|
||||||
|
|
@ -20,7 +20,7 @@ const MinimapPanel = () => {
|
|||||||
|
|
||||||
const nodeColor = useColorModeValue(
|
const nodeColor = useColorModeValue(
|
||||||
'var(--invokeai-colors-accent-300)',
|
'var(--invokeai-colors-accent-300)',
|
||||||
'var(--invokeai-colors-accent-700)'
|
'var(--invokeai-colors-accent-600)'
|
||||||
);
|
);
|
||||||
|
|
||||||
const maskColor = useColorModeValue(
|
const maskColor = useColorModeValue(
|
||||||
@ -32,10 +32,9 @@ const MinimapPanel = () => {
|
|||||||
<>
|
<>
|
||||||
{shouldShowMinimapPanel && (
|
{shouldShowMinimapPanel && (
|
||||||
<MiniMap
|
<MiniMap
|
||||||
nodeStrokeWidth={3}
|
|
||||||
pannable
|
pannable
|
||||||
zoomable
|
zoomable
|
||||||
nodeBorderRadius={30}
|
nodeBorderRadius={15}
|
||||||
style={miniMapStyle}
|
style={miniMapStyle}
|
||||||
nodeColor={nodeColor}
|
nodeColor={nodeColor}
|
||||||
maskColor={maskColor}
|
maskColor={maskColor}
|
@ -2,11 +2,10 @@ import { HStack } from '@chakra-ui/react';
|
|||||||
import CancelButton from 'features/parameters/components/ProcessButtons/CancelButton';
|
import CancelButton from 'features/parameters/components/ProcessButtons/CancelButton';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { Panel } from 'reactflow';
|
import { Panel } from 'reactflow';
|
||||||
|
import NodeEditorSettings from '../NodeEditorSettings';
|
||||||
import ClearGraphButton from '../ui/ClearGraphButton';
|
import ClearGraphButton from '../ui/ClearGraphButton';
|
||||||
import LoadGraphButton from '../ui/LoadGraphButton';
|
|
||||||
import NodeInvokeButton from '../ui/NodeInvokeButton';
|
import NodeInvokeButton from '../ui/NodeInvokeButton';
|
||||||
import ReloadSchemaButton from '../ui/ReloadSchemaButton';
|
import ReloadSchemaButton from '../ui/ReloadSchemaButton';
|
||||||
import SaveGraphButton from '../ui/SaveGraphButton';
|
|
||||||
|
|
||||||
const TopCenterPanel = () => {
|
const TopCenterPanel = () => {
|
||||||
return (
|
return (
|
||||||
@ -15,9 +14,8 @@ const TopCenterPanel = () => {
|
|||||||
<NodeInvokeButton />
|
<NodeInvokeButton />
|
||||||
<CancelButton />
|
<CancelButton />
|
||||||
<ReloadSchemaButton />
|
<ReloadSchemaButton />
|
||||||
<SaveGraphButton />
|
|
||||||
<LoadGraphButton />
|
|
||||||
<ClearGraphButton />
|
<ClearGraphButton />
|
||||||
|
<NodeEditorSettings />
|
||||||
</HStack>
|
</HStack>
|
||||||
</Panel>
|
</Panel>
|
||||||
);
|
);
|
@ -1,22 +1,16 @@
|
|||||||
import { RootState } from 'app/store/store';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { Panel } from 'reactflow';
|
import { Panel } from 'reactflow';
|
||||||
import FieldTypeLegend from '../FieldTypeLegend';
|
import FieldTypeLegend from '../FieldTypeLegend';
|
||||||
import NodeGraphOverlay from '../NodeGraphOverlay';
|
|
||||||
|
|
||||||
const TopRightPanel = () => {
|
const TopRightPanel = () => {
|
||||||
const shouldShowGraphOverlay = useAppSelector(
|
|
||||||
(state: RootState) => state.nodes.shouldShowGraphOverlay
|
|
||||||
);
|
|
||||||
const shouldShowFieldTypeLegend = useAppSelector(
|
const shouldShowFieldTypeLegend = useAppSelector(
|
||||||
(state: RootState) => state.nodes.shouldShowFieldTypeLegend
|
(state) => state.nodes.shouldShowFieldTypeLegend
|
||||||
);
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Panel position="top-right">
|
<Panel position="top-right">
|
||||||
{shouldShowFieldTypeLegend && <FieldTypeLegend />}
|
{shouldShowFieldTypeLegend && <FieldTypeLegend />}
|
||||||
{shouldShowGraphOverlay && <NodeGraphOverlay />}
|
|
||||||
</Panel>
|
</Panel>
|
||||||
);
|
);
|
||||||
};
|
};
|
@ -1,15 +0,0 @@
|
|||||||
import {
|
|
||||||
ArrayInputFieldTemplate,
|
|
||||||
ArrayInputFieldValue,
|
|
||||||
} from 'features/nodes/types/types';
|
|
||||||
import { memo } from 'react';
|
|
||||||
import { FaList } from 'react-icons/fa';
|
|
||||||
import { FieldComponentProps } from './types';
|
|
||||||
|
|
||||||
const ArrayInputFieldComponent = (
|
|
||||||
_props: FieldComponentProps<ArrayInputFieldValue, ArrayInputFieldTemplate>
|
|
||||||
) => {
|
|
||||||
return <FaList />;
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(ArrayInputFieldComponent);
|
|
@ -1,37 +0,0 @@
|
|||||||
import { Select } from '@chakra-ui/react';
|
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
|
||||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
|
||||||
import {
|
|
||||||
EnumInputFieldTemplate,
|
|
||||||
EnumInputFieldValue,
|
|
||||||
} from 'features/nodes/types/types';
|
|
||||||
import { ChangeEvent, memo } from 'react';
|
|
||||||
import { FieldComponentProps } from './types';
|
|
||||||
|
|
||||||
const EnumInputFieldComponent = (
|
|
||||||
props: FieldComponentProps<EnumInputFieldValue, EnumInputFieldTemplate>
|
|
||||||
) => {
|
|
||||||
const { nodeId, field, template } = props;
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
|
|
||||||
dispatch(
|
|
||||||
fieldValueChanged({
|
|
||||||
nodeId,
|
|
||||||
fieldName: field.name,
|
|
||||||
value: e.target.value,
|
|
||||||
})
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Select onChange={handleValueChanged} value={field.value}>
|
|
||||||
{template.options.map((option) => (
|
|
||||||
<option key={option}>{option}</option>
|
|
||||||
))}
|
|
||||||
</Select>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(EnumInputFieldComponent);
|
|
@ -0,0 +1,47 @@
|
|||||||
|
import { MenuItem, MenuList } from '@chakra-ui/react';
|
||||||
|
import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu';
|
||||||
|
import {
|
||||||
|
InputFieldTemplate,
|
||||||
|
InputFieldValue,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
|
import { MouseEvent, useCallback } from 'react';
|
||||||
|
import { menuListMotionProps } from 'theme/components/menu';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
nodeId: string;
|
||||||
|
field: InputFieldValue;
|
||||||
|
fieldTemplate: InputFieldTemplate;
|
||||||
|
children: ContextMenuProps<HTMLDivElement>['children'];
|
||||||
|
};
|
||||||
|
|
||||||
|
const FieldContextMenu = (props: Props) => {
|
||||||
|
const skipEvent = useCallback((e: MouseEvent<HTMLDivElement>) => {
|
||||||
|
e.preventDefault();
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ContextMenu<HTMLDivElement>
|
||||||
|
menuProps={{
|
||||||
|
size: 'sm',
|
||||||
|
isLazy: true,
|
||||||
|
}}
|
||||||
|
menuButtonProps={{
|
||||||
|
bg: 'transparent',
|
||||||
|
_hover: { bg: 'transparent' },
|
||||||
|
}}
|
||||||
|
renderMenu={() => (
|
||||||
|
<MenuList
|
||||||
|
sx={{ visibility: 'visible !important' }}
|
||||||
|
motionProps={menuListMotionProps}
|
||||||
|
onContextMenu={skipEvent}
|
||||||
|
>
|
||||||
|
<MenuItem>Test</MenuItem>
|
||||||
|
</MenuList>
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
{props.children}
|
||||||
|
</ContextMenu>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default FieldContextMenu;
|
@ -0,0 +1,122 @@
|
|||||||
|
import { Tooltip } from '@chakra-ui/react';
|
||||||
|
import { CSSProperties, memo, useMemo } from 'react';
|
||||||
|
import { Handle, HandleType, NodeProps, Position } from 'reactflow';
|
||||||
|
import {
|
||||||
|
FIELDS,
|
||||||
|
HANDLE_TOOLTIP_OPEN_DELAY,
|
||||||
|
colorTokenToCssVar,
|
||||||
|
} from '../../types/constants';
|
||||||
|
import {
|
||||||
|
InputFieldTemplate,
|
||||||
|
InputFieldValue,
|
||||||
|
InvocationNodeData,
|
||||||
|
InvocationTemplate,
|
||||||
|
OutputFieldTemplate,
|
||||||
|
OutputFieldValue,
|
||||||
|
} from '../../types/types';
|
||||||
|
|
||||||
|
export const handleBaseStyles: CSSProperties = {
|
||||||
|
position: 'absolute',
|
||||||
|
width: '1rem',
|
||||||
|
height: '1rem',
|
||||||
|
borderWidth: 0,
|
||||||
|
zIndex: 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
export const inputHandleStyles: CSSProperties = {
|
||||||
|
left: '-1rem',
|
||||||
|
};
|
||||||
|
|
||||||
|
export const outputHandleStyles: CSSProperties = {
|
||||||
|
right: '-0.5rem',
|
||||||
|
};
|
||||||
|
|
||||||
|
type FieldHandleProps = {
|
||||||
|
nodeProps: NodeProps<InvocationNodeData>;
|
||||||
|
nodeTemplate: InvocationTemplate;
|
||||||
|
field: InputFieldValue | OutputFieldValue;
|
||||||
|
fieldTemplate: InputFieldTemplate | OutputFieldTemplate;
|
||||||
|
handleType: HandleType;
|
||||||
|
isConnectionInProgress: boolean;
|
||||||
|
isConnectionStartField: boolean;
|
||||||
|
connectionError: string | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
const FieldHandle = (props: FieldHandleProps) => {
|
||||||
|
const {
|
||||||
|
fieldTemplate,
|
||||||
|
handleType,
|
||||||
|
isConnectionInProgress,
|
||||||
|
isConnectionStartField,
|
||||||
|
connectionError,
|
||||||
|
} = props;
|
||||||
|
const { name, type } = fieldTemplate;
|
||||||
|
const { color, title } = FIELDS[type];
|
||||||
|
|
||||||
|
const styles: CSSProperties = useMemo(() => {
|
||||||
|
const s: CSSProperties = {
|
||||||
|
backgroundColor: colorTokenToCssVar(color),
|
||||||
|
position: 'absolute',
|
||||||
|
width: '1rem',
|
||||||
|
height: '1rem',
|
||||||
|
borderWidth: 0,
|
||||||
|
zIndex: 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
if (handleType === 'target') {
|
||||||
|
s.insetInlineStart = '-1rem';
|
||||||
|
} else {
|
||||||
|
s.insetInlineEnd = '-1rem';
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isConnectionInProgress && !isConnectionStartField && connectionError) {
|
||||||
|
s.filter = 'opacity(0.4) grayscale(0.7)';
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isConnectionInProgress && connectionError) {
|
||||||
|
if (isConnectionStartField) {
|
||||||
|
s.cursor = 'grab';
|
||||||
|
} else {
|
||||||
|
s.cursor = 'not-allowed';
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
s.cursor = 'crosshair';
|
||||||
|
}
|
||||||
|
|
||||||
|
return s;
|
||||||
|
}, [
|
||||||
|
color,
|
||||||
|
connectionError,
|
||||||
|
handleType,
|
||||||
|
isConnectionInProgress,
|
||||||
|
isConnectionStartField,
|
||||||
|
]);
|
||||||
|
|
||||||
|
const tooltip = useMemo(() => {
|
||||||
|
if (isConnectionInProgress && isConnectionStartField) {
|
||||||
|
return title;
|
||||||
|
}
|
||||||
|
if (isConnectionInProgress && connectionError) {
|
||||||
|
return connectionError ?? title;
|
||||||
|
}
|
||||||
|
return title;
|
||||||
|
}, [connectionError, isConnectionInProgress, isConnectionStartField, title]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Tooltip
|
||||||
|
label={tooltip}
|
||||||
|
placement={handleType === 'target' ? 'start' : 'end'}
|
||||||
|
hasArrow
|
||||||
|
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||||
|
>
|
||||||
|
<Handle
|
||||||
|
type={handleType}
|
||||||
|
id={name}
|
||||||
|
position={handleType === 'target' ? Position.Left : Position.Right}
|
||||||
|
style={styles}
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(FieldHandle);
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user