Compare commits

..

1 Commits

Author SHA1 Message Date
a7606901f3 feat(ui): add eslint rules
- `curly` requires conditionals to use curly braces
- `react/jsx-curly-brace-presence` requires string props to *not* have curly braces
- `react-memo/require-memo` requires function components to be wrapped in `memo`
- `react-memo/require-usememo` requires all complex props (objects, functions) to be wrapped in `useMemo` or `useCallback`
2023-08-16 12:31:33 +10:00
84 changed files with 957 additions and 2306 deletions

View File

@ -25,10 +25,10 @@ This method is recommended for experienced users and developers
#### [Docker Installation](040_INSTALL_DOCKER.md)
This method is recommended for those familiar with running Docker containers
### Other Installation Guides
- [PyPatchMatch](060_INSTALL_PATCHMATCH.md)
- [XFormers](070_INSTALL_XFORMERS.md)
- [CUDA and ROCm Drivers](030_INSTALL_CUDA_AND_ROCM.md)
- [Installing New Models](050_INSTALLING_MODELS.md)
- [PyPatchMatch](installation/060_INSTALL_PATCHMATCH.md)
- [XFormers](installation/070_INSTALL_XFORMERS.md)
- [CUDA and ROCm Drivers](installation/030_INSTALL_CUDA_AND_ROCM.md)
- [Installing New Models](installation/050_INSTALLING_MODELS.md)
## :fontawesome-solid-computer: Hardware Requirements

View File

@ -40,7 +40,7 @@ async def create_session(
@session_router.get(
"/",
operation_id="list_sessions",
responses={200: {"model": PaginatedResults[dict]}},
responses={200: {"model": PaginatedResults[GraphExecutionState]}},
)
async def list_sessions(
page: int = Query(default=0, description="The page of results to get"),

View File

@ -48,7 +48,7 @@ class BooleanCollectionOutput(BaseInvocationOutput):
)
@title("Boolean Primitive")
@title("Boolean")
@tags("primitives", "boolean")
class BooleanInvocation(BaseInvocation):
"""A boolean primitive value"""
@ -62,7 +62,7 @@ class BooleanInvocation(BaseInvocation):
return BooleanOutput(a=self.a)
@title("Boolean Primitive Collection")
@title("Boolean Collection")
@tags("primitives", "boolean", "collection")
class BooleanCollectionInvocation(BaseInvocation):
"""A collection of boolean primitive values"""
@ -101,7 +101,7 @@ class IntegerCollectionOutput(BaseInvocationOutput):
)
@title("Integer Primitive")
@title("Integer")
@tags("primitives", "integer")
class IntegerInvocation(BaseInvocation):
"""An integer primitive value"""
@ -115,7 +115,7 @@ class IntegerInvocation(BaseInvocation):
return IntegerOutput(a=self.a)
@title("Integer Primitive Collection")
@title("Integer Collection")
@tags("primitives", "integer", "collection")
class IntegerCollectionInvocation(BaseInvocation):
"""A collection of integer primitive values"""
@ -154,7 +154,7 @@ class FloatCollectionOutput(BaseInvocationOutput):
)
@title("Float Primitive")
@title("Float")
@tags("primitives", "float")
class FloatInvocation(BaseInvocation):
"""A float primitive value"""
@ -168,7 +168,7 @@ class FloatInvocation(BaseInvocation):
return FloatOutput(a=self.param)
@title("Float Primitive Collection")
@title("Float Collection")
@tags("primitives", "float", "collection")
class FloatCollectionInvocation(BaseInvocation):
"""A collection of float primitive values"""
@ -207,7 +207,7 @@ class StringCollectionOutput(BaseInvocationOutput):
)
@title("String Primitive")
@title("String")
@tags("primitives", "string")
class StringInvocation(BaseInvocation):
"""A string primitive value"""
@ -221,7 +221,7 @@ class StringInvocation(BaseInvocation):
return StringOutput(text=self.text)
@title("String Primitive Collection")
@title("String Collection")
@tags("primitives", "string", "collection")
class StringCollectionInvocation(BaseInvocation):
"""A collection of string primitive values"""
@ -289,7 +289,7 @@ class ImageInvocation(BaseInvocation):
)
@title("Image Primitive Collection")
@title("Image Collection")
@tags("primitives", "image", "collection")
class ImageCollectionInvocation(BaseInvocation):
"""A collection of image primitive values"""
@ -357,7 +357,7 @@ class LatentsInvocation(BaseInvocation):
return build_latents_output(self.latents.latents_name, latents)
@title("Latents Primitive Collection")
@title("Latents Collection")
@tags("primitives", "latents", "collection")
class LatentsCollectionInvocation(BaseInvocation):
"""A collection of latents tensor primitive values"""
@ -475,7 +475,7 @@ class ConditioningInvocation(BaseInvocation):
return ConditioningOutput(conditioning=self.conditioning)
@title("Conditioning Primitive Collection")
@title("Conditioning Collection")
@tags("primitives", "conditioning", "collection")
class ConditioningCollectionInvocation(BaseInvocation):
"""A collection of conditioning tensor primitive values"""

View File

@ -29,7 +29,6 @@ The abstract base class for this class is InvocationStatsServiceBase. An impleme
writes to the system log is stored in InvocationServices.performance_statistics.
"""
import psutil
import time
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager
@ -43,11 +42,6 @@ import invokeai.backend.util.logging as logger
from ..invocations.baseinvocation import BaseInvocation
from .graph import GraphExecutionState
from .item_storage import ItemStorageABC
from .model_manager_service import ModelManagerService
from invokeai.backend.model_management.model_cache import CacheStats
# size of GIG in bytes
GIG = 1073741824
class InvocationStatsServiceBase(ABC):
@ -95,8 +89,6 @@ class InvocationStatsServiceBase(ABC):
invocation_type: str,
time_used: float,
vram_used: float,
ram_used: float,
ram_changed: float,
):
"""
Add timing information on execution of a node. Usually
@ -105,8 +97,6 @@ class InvocationStatsServiceBase(ABC):
:param invocation_type: String literal type of the node
:param time_used: Time used by node's exection (sec)
:param vram_used: Maximum VRAM used during exection (GB)
:param ram_used: Current RAM available (GB)
:param ram_changed: Change in RAM usage over course of the run (GB)
"""
pass
@ -125,9 +115,6 @@ class NodeStats:
calls: int = 0
time_used: float = 0.0 # seconds
max_vram: float = 0.0 # GB
cache_hits: int = 0
cache_misses: int = 0
cache_high_watermark: int = 0
@dataclass
@ -146,62 +133,31 @@ class InvocationStatsService(InvocationStatsServiceBase):
self.graph_execution_manager = graph_execution_manager
# {graph_id => NodeLog}
self._stats: Dict[str, NodeLog] = {}
self._cache_stats: Dict[str, CacheStats] = {}
self.ram_used: float = 0.0
self.ram_changed: float = 0.0
class StatsContext:
"""Context manager for collecting statistics."""
invocation: BaseInvocation = None
collector: "InvocationStatsServiceBase" = None
graph_id: str = None
start_time: int = 0
ram_used: int = 0
model_manager: ModelManagerService = None
def __init__(
self,
invocation: BaseInvocation,
graph_id: str,
model_manager: ModelManagerService,
collector: "InvocationStatsServiceBase",
):
"""Initialize statistics for this run."""
def __init__(self, invocation: BaseInvocation, graph_id: str, collector: "InvocationStatsServiceBase"):
self.invocation = invocation
self.collector = collector
self.graph_id = graph_id
self.start_time = 0
self.ram_used = 0
self.model_manager = model_manager
def __enter__(self):
self.start_time = time.time()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
self.ram_used = psutil.Process().memory_info().rss
if self.model_manager:
self.model_manager.collect_cache_stats(self.collector._cache_stats[self.graph_id])
def __exit__(self, *args):
"""Called on exit from the context."""
ram_used = psutil.Process().memory_info().rss
self.collector.update_mem_stats(
ram_used=ram_used / GIG,
ram_changed=(ram_used - self.ram_used) / GIG,
)
self.collector.update_invocation_stats(
graph_id=self.graph_id,
invocation_type=self.invocation.type,
time_used=time.time() - self.start_time,
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
self.graph_id,
self.invocation.type,
time.time() - self.start_time,
torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0,
)
def collect_stats(
self,
invocation: BaseInvocation,
graph_execution_state_id: str,
model_manager: ModelManagerService,
) -> StatsContext:
"""
Return a context object that will capture the statistics.
@ -210,8 +166,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
"""
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
self._stats[graph_execution_state_id] = NodeLog()
self._cache_stats[graph_execution_state_id] = CacheStats()
return self.StatsContext(invocation, graph_execution_state_id, model_manager, self)
return self.StatsContext(invocation, graph_execution_state_id, self)
def reset_all_stats(self):
"""Zero all statistics"""
@ -224,36 +179,13 @@ class InvocationStatsService(InvocationStatsServiceBase):
except KeyError:
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
def update_mem_stats(
self,
ram_used: float,
ram_changed: float,
):
"""
Update the collector with RAM memory usage info.
:param ram_used: How much RAM is currently in use.
:param ram_changed: How much RAM changed since last generation.
"""
self.ram_used = ram_used
self.ram_changed = ram_changed
def update_invocation_stats(
self,
graph_id: str,
invocation_type: str,
time_used: float,
vram_used: float,
):
def update_invocation_stats(self, graph_id: str, invocation_type: str, time_used: float, vram_used: float):
"""
Add timing information on execution of a node. Usually
used internally.
:param graph_id: ID of the graph that is currently executing
:param invocation_type: String literal type of the node
:param time_used: Time used by node's exection (sec)
:param vram_used: Maximum VRAM used during exection (GB)
:param ram_used: Current RAM available (GB)
:param ram_changed: Change in RAM usage over course of the run (GB)
:param time_used: Floating point seconds used by node's exection
"""
if not self._stats[graph_id].nodes.get(invocation_type):
self._stats[graph_id].nodes[invocation_type] = NodeStats()
@ -265,7 +197,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
def log_stats(self):
"""
Send the statistics to the system logger at the info level.
Stats will only be printed when the execution of the graph
Stats will only be printed if when the execution of the graph
is complete.
"""
completed = set()
@ -276,30 +208,16 @@ class InvocationStatsService(InvocationStatsServiceBase):
total_time = 0
logger.info(f"Graph stats: {graph_id}")
logger.info(f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}")
logger.info("Node Calls Seconds VRAM Used")
for node_type, stats in self._stats[graph_id].nodes.items():
logger.info(f"{node_type:>30} {stats.calls:>4} {stats.time_used:7.3f}s {stats.max_vram:4.3f}G")
logger.info(f"{node_type:<20} {stats.calls:>5} {stats.time_used:7.3f}s {stats.max_vram:4.2f}G")
total_time += stats.time_used
cache_stats = self._cache_stats[graph_id]
hwm = cache_stats.high_watermark / GIG
tot = cache_stats.cache_size / GIG
loaded = sum([v for v in cache_stats.loaded_model_sizes.values()]) / GIG
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)")
logger.info(f"RAM used to load models: {loaded:4.2f}G")
if torch.cuda.is_available():
logger.info("VRAM in use: " + "%4.3fG" % (torch.cuda.memory_allocated() / GIG))
logger.info("RAM cache statistics:")
logger.info(f" Model cache hits: {cache_stats.hits}")
logger.info(f" Model cache misses: {cache_stats.misses}")
logger.info(f" Models cached: {cache_stats.in_cache}")
logger.info(f" Models cleared from cache: {cache_stats.cleared}")
logger.info(f" Cache high water mark: {hwm:4.2f}/{tot:4.2f}G")
logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9))
completed.add(graph_id)
for graph_id in completed:
del self._stats[graph_id]
del self._cache_stats[graph_id]

View File

@ -22,7 +22,6 @@ from invokeai.backend.model_management import (
ModelNotFoundException,
)
from invokeai.backend.model_management.model_search import FindModels
from invokeai.backend.model_management.model_cache import CacheStats
import torch
from invokeai.app.models.exceptions import CanceledException
@ -277,13 +276,6 @@ class ModelManagerServiceBase(ABC):
"""
pass
@abstractmethod
def collect_cache_stats(self, cache_stats: CacheStats):
"""
Reset model cache statistics for graph with graph_id.
"""
pass
@abstractmethod
def commit(self, conf_file: Optional[Path] = None) -> None:
"""
@ -508,12 +500,6 @@ class ModelManagerService(ModelManagerServiceBase):
self.logger.debug(f"convert model {model_name}")
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
def collect_cache_stats(self, cache_stats: CacheStats):
"""
Reset model cache statistics for graph with graph_id.
"""
self.mgr.cache.stats = cache_stats
def commit(self, conf_file: Optional[Path] = None):
"""
Write current configuration out to the indicated file.

View File

@ -86,9 +86,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Invoke
try:
graph_id = graph_execution_state.id
model_manager = self.__invoker.services.model_manager
with statistics.collect_stats(invocation, graph_id, model_manager):
with statistics.collect_stats(invocation, graph_execution_state.id):
# use the internal invoke_internal(), which wraps the node's invoke() method in
# this accomodates nodes which require a value, but get it only from a
# connection

View File

@ -1,7 +1,6 @@
import sqlite3
import json
from threading import Lock
from typing import Generic, Optional, TypeVar, Union, get_args
from typing import Generic, Optional, TypeVar, get_args
from pydantic import BaseModel, parse_raw_as
@ -100,7 +99,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._lock.release()
self._on_deleted(id)
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[dict]:
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
try:
self._lock.acquire()
self._cursor.execute(
@ -109,7 +108,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
)
result = self._cursor.fetchall()
items = [json.loads(r[0]) for r in result]
items = list(map(lambda r: self._parse_item(r[0]), result))
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
count = self._cursor.fetchone()[0]
@ -118,9 +117,9 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
pageCount = int(count / per_page) + 1
return PaginatedResults[dict](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[dict]:
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
try:
self._lock.acquire()
self._cursor.execute(
@ -129,7 +128,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
)
result = self._cursor.fetchall()
items = [json.loads(r[0]) for r in result]
items = list(map(lambda r: self._parse_item(r[0]), result))
self._cursor.execute(
f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
@ -141,4 +140,4 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
pageCount = int(count / per_page) + 1
return PaginatedResults[dict](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)

View File

@ -21,12 +21,12 @@ import os
import sys
import hashlib
from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Union, types, Optional, Type, Any
import torch
import logging
import invokeai.backend.util.logging as logger
from .models import BaseModelType, ModelType, SubModelType, ModelBase
@ -41,18 +41,6 @@ DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
GIG = 1073741824
@dataclass
class CacheStats(object):
hits: int = 0 # cache hits
misses: int = 0 # cache misses
high_watermark: int = 0 # amount of cache used
in_cache: int = 0 # number of models in cache
cleared: int = 0 # number of models cleared to make space
cache_size: int = 0 # total size of cache
# {submodel_key => size}
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
class ModelLocker(object):
"Forward declaration"
pass
@ -127,9 +115,6 @@ class ModelCache(object):
self.sha_chunksize = sha_chunksize
self.logger = logger
# used for stats collection
self.stats = None
self._cached_models = dict()
self._cache_stack = list()
@ -196,14 +181,13 @@ class ModelCache(object):
model_type=model_type,
submodel_type=submodel,
)
# TODO: lock for no copies on simultaneous calls?
cache_entry = self._cached_models.get(key, None)
if cache_entry is None:
self.logger.info(
f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
)
if self.stats:
self.stats.misses += 1
# this will remove older cached models until
# there is sufficient room to load the requested model
@ -217,17 +201,6 @@ class ModelCache(object):
cache_entry = _CacheRecord(self, model, mem_used)
self._cached_models[key] = cache_entry
else:
if self.stats:
self.stats.hits += 1
if self.stats:
self.stats.cache_size = self.max_cache_size * GIG
self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[key] = max(
self.stats.loaded_model_sizes.get(key, 0), model_info.get_size(submodel)
)
with suppress(Exception):
self._cache_stack.remove(key)
@ -307,14 +280,14 @@ class ModelCache(object):
"""
Given the HF repo id or path to a model on disk, returns a unique
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
:param model_path: Path to model file/directory on disk.
"""
return self._local_model_hash(model_path)
def cache_size(self) -> float:
"""Return the current size of the cache, in GB."""
return self._cache_size() / GIG
"Return the current size of the cache, in GB"
current_cache_size = sum([m.size for m in self._cached_models.values()])
return current_cache_size / GIG
def _has_cuda(self) -> bool:
return self.execution_device.type == "cuda"
@ -337,15 +310,12 @@ class ModelCache(object):
f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}"
)
def _cache_size(self) -> int:
return sum([m.size for m in self._cached_models.values()])
def _make_cache_room(self, model_size):
# calculate how much memory this model will require
# multiplier = 2 if self.precision==torch.float32 else 1
bytes_needed = model_size
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
current_size = self._cache_size()
current_size = sum([m.size for m in self._cached_models.values()])
if current_size + bytes_needed > maximum_size:
self.logger.debug(
@ -394,8 +364,6 @@ class ModelCache(object):
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
)
current_size -= cache_entry.size
if self.stats:
self.stats.cleared += 1
del self._cache_stack[pos]
del self._cached_models[model_key]
del cache_entry

View File

@ -240,7 +240,6 @@ class InvokeAIDiffuserComponent:
controlnet_cond=control_datum.image_tensor,
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
encoder_attention_mask=encoder_attention_mask,
added_cond_kwargs=added_cond_kwargs,
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
return_dict=False,
)

View File

@ -4,15 +4,8 @@ import torch
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalControlnetMixin
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.embeddings import (
TextImageProjection,
TextImageTimeEmbedding,
TextTimeEmbedding,
TimestepEmbedding,
Timesteps,
)
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import (
CrossAttnDownBlock2D,
@ -25,11 +18,10 @@ from diffusers.models.unet_2d_condition import UNet2DConditionModel
import diffusers
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
# TODO: create PR to diffusers
# Modified ControlNetModel with encoder_attention_mask argument added
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
class ControlNetModel(ModelMixin, ConfigMixin):
"""
A ControlNet model.
@ -60,25 +52,12 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
The epsilon to use for the normalization.
cross_attention_dim (`int`, defaults to 1280):
The dimension of the cross attention features.
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
encoder_hid_dim (`int`, *optional*, defaults to None):
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
dimension to `cross_attention_dim`.
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
The dimension of the attention heads.
use_linear_projection (`bool`, defaults to `False`):
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
addition_embed_type (`str`, *optional*, defaults to `None`):
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
"text". "text" will use the `TextTimeEmbedding` layer.
num_class_embeds (`int`, *optional*, defaults to 0):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
@ -119,15 +98,10 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
@ -135,7 +109,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
global_pool_conditions: bool = False,
addition_embed_type_num_heads=64,
):
super().__init__()
@ -163,9 +136,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
# input
conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
@ -175,43 +145,16 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
# time
time_embed_dim = block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
act_fn=act_fn,
)
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
encoder_hid_dim_type = "text_proj"
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
raise ValueError(
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
)
if encoder_hid_dim_type == "text_proj":
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
elif encoder_hid_dim_type == "text_image_proj":
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
self.encoder_hid_proj = TextImageProjection(
text_embed_dim=encoder_hid_dim,
image_embed_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim,
)
elif encoder_hid_dim_type is not None:
raise ValueError(
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
)
else:
self.encoder_hid_proj = None
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
@ -235,29 +178,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
else:
self.class_embedding = None
if addition_embed_type == "text":
if encoder_hid_dim is not None:
text_time_embedding_from_dim = encoder_hid_dim
else:
text_time_embedding_from_dim = cross_attention_dim
self.add_embedding = TextTimeEmbedding(
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
)
elif addition_embed_type == "text_image":
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
self.add_embedding = TextImageTimeEmbedding(
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
)
elif addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
elif addition_embed_type is not None:
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
# control net conditioning embedding
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
conditioning_embedding_channels=block_out_channels[0],
@ -292,7 +212,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
transformer_layers_per_block=transformer_layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
@ -329,7 +248,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
self.controlnet_mid_block = controlnet_block
self.mid_block = UNetMidBlock2DCrossAttn(
transformer_layers_per_block=transformer_layers_per_block[-1],
in_channels=mid_block_channel,
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
@ -359,22 +277,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
where applicable.
"""
transformer_layers_per_block = (
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
)
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
addition_time_embed_dim = (
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
)
controlnet = cls(
encoder_hid_dim=encoder_hid_dim,
encoder_hid_dim_type=encoder_hid_dim_type,
addition_embed_type=addition_embed_type,
addition_time_embed_dim=addition_time_embed_dim,
transformer_layers_per_block=transformer_layers_per_block,
in_channels=unet.config.in_channels,
flip_sin_to_cos=unet.config.flip_sin_to_cos,
freq_shift=unet.config.freq_shift,
@ -560,7 +463,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
guess_mode: bool = False,
@ -584,9 +486,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
added_cond_kwargs (`dict`):
Additional conditions for the Stable Diffusion XL UNet.
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
encoder_attention_mask (`torch.Tensor`):
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
@ -649,7 +549,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None
if self.class_embedding is not None:
if class_labels is None:
@ -661,34 +560,11 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
if "addition_embed_type" in self.config:
if self.config.addition_embed_type == "text":
aug_emb = self.add_embedding(encoder_hidden_states)
elif self.config.addition_embed_type == "text_time":
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
emb = emb + aug_emb if aug_emb is not None else emb
# 2. pre-process
sample = self.conv_in(sample)
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
sample = sample + controlnet_cond
# 3. down

View File

@ -20,9 +20,16 @@ module.exports = {
ecmaVersion: 2018,
sourceType: 'module',
},
plugins: ['react', '@typescript-eslint', 'eslint-plugin-react-hooks'],
plugins: ['react', '@typescript-eslint', 'react-hooks', 'react-memo'],
root: true,
rules: {
curly: 'error',
'react-memo/require-memo': 'error',
'react-memo/require-usememo': 'error',
'react/jsx-curly-brace-presence': [
'error',
{ props: 'never', children: 'never' },
],
'react-hooks/exhaustive-deps': 'error',
'no-var': 'error',
'brace-style': 'error',

View File

@ -76,6 +76,7 @@
"chakra-ui-contextmenu": "^1.0.5",
"dateformat": "^5.0.3",
"downshift": "^7.6.0",
"eslint-plugin-react-memo": "^0.0.3",
"formik": "^2.4.2",
"framer-motion": "^10.12.17",
"fuse.js": "^6.6.2",

View File

@ -506,14 +506,10 @@
"maskAdjustmentsHeader": "Mask Adjustments",
"maskBlur": "Mask Blur",
"maskBlurMethod": "Mask Blur Method",
"seamPaintingHeader": "Seam Painting",
"seamSize": "Seam Size",
"seamBlur": "Seam Blur",
"seamSteps": "Seam Steps",
"seamStrength": "Seam Strength",
"seamThreshold": "Seam Threshold",
"seamLowThreshold": "Low",
"seamHighThreshold": "High",
"seamSteps": "Seam Steps",
"scaleBeforeProcessing": "Scale Before Processing",
"scaledWidth": "Scaled W",
"scaledHeight": "Scaled H",

View File

@ -121,7 +121,7 @@ export const addRequestedMultipleImageDeletionListener = () => {
effect: async (action, { dispatch, getState }) => {
const { imageDTOs, imagesUsage } = action.payload;
if (imageDTOs.length <= 1 || imagesUsage.length <= 1) {
if (imageDTOs.length < 1 || imagesUsage.length < 1) {
// handle singles in separate listener
return;
}

View File

@ -1,126 +0,0 @@
/**
* This is a copy-paste of https://github.com/lukasbach/chakra-ui-contextmenu with a small change.
*
* The reactflow background element somehow prevents the chakra `useOutsideClick()` hook from working.
* With a menu open, clicking on the reactflow background element doesn't close the menu.
*
* Reactflow does provide an `onPaneClick` to handle clicks on the background element, but it is not
* straightforward to programatically close the menu.
*
* As a (hopefully temporary) workaround, we will use a dirty hack:
* - create `globalContextMenuCloseTrigger: number` in `ui` slice
* - increment it in `onPaneClick`
* - `useEffect()` to close the menu when `globalContextMenuCloseTrigger` changes
*/
import {
Menu,
MenuButton,
MenuButtonProps,
MenuProps,
Portal,
PortalProps,
useEventListener,
} from '@chakra-ui/react';
import { useAppSelector } from 'app/store/storeHooks';
import * as React from 'react';
import {
MutableRefObject,
useCallback,
useEffect,
useRef,
useState,
} from 'react';
export interface IAIContextMenuProps<T extends HTMLElement> {
renderMenu: () => JSX.Element | null;
children: (ref: MutableRefObject<T | null>) => JSX.Element | null;
menuProps?: Omit<MenuProps, 'children'> & { children?: React.ReactNode };
portalProps?: Omit<PortalProps, 'children'> & { children?: React.ReactNode };
menuButtonProps?: MenuButtonProps;
}
export function IAIContextMenu<T extends HTMLElement = HTMLElement>(
props: IAIContextMenuProps<T>
) {
const [isOpen, setIsOpen] = useState(false);
const [isRendered, setIsRendered] = useState(false);
const [isDeferredOpen, setIsDeferredOpen] = useState(false);
const [position, setPosition] = useState<[number, number]>([0, 0]);
const targetRef = useRef<T>(null);
const globalContextMenuCloseTrigger = useAppSelector(
(state) => state.ui.globalContextMenuCloseTrigger
);
useEffect(() => {
if (isOpen) {
setTimeout(() => {
setIsRendered(true);
setTimeout(() => {
setIsDeferredOpen(true);
});
});
} else {
setIsDeferredOpen(false);
const timeout = setTimeout(() => {
setIsRendered(isOpen);
}, 1000);
return () => clearTimeout(timeout);
}
}, [isOpen]);
useEffect(() => {
setIsOpen(false);
setIsDeferredOpen(false);
setIsRendered(false);
}, [globalContextMenuCloseTrigger]);
useEventListener('contextmenu', (e) => {
if (
targetRef.current?.contains(e.target as HTMLElement) ||
e.target === targetRef.current
) {
e.preventDefault();
setIsOpen(true);
setPosition([e.pageX, e.pageY]);
} else {
setIsOpen(false);
}
});
const onCloseHandler = useCallback(() => {
props.menuProps?.onClose?.();
setIsOpen(false);
}, [props.menuProps]);
return (
<>
{props.children(targetRef)}
{isRendered && (
<Portal {...props.portalProps}>
<Menu
isOpen={isDeferredOpen}
gutter={0}
{...props.menuProps}
onClose={onCloseHandler}
>
<MenuButton
aria-hidden={true}
w={1}
h={1}
style={{
position: 'absolute',
left: position[0],
top: position[1],
cursor: 'default',
}}
{...props.menuButtonProps}
/>
{props.renderMenu()}
</Menu>
</Portal>
)}
</>
);
}

View File

@ -16,7 +16,6 @@ import ImageContextMenu from 'features/gallery/components/ImageContextMenu/Image
import {
MouseEvent,
ReactElement,
ReactNode,
SyntheticEvent,
memo,
useCallback,
@ -33,17 +32,6 @@ import {
TypesafeDroppableData,
} from 'features/dnd/types';
const defaultUploadElement = (
<Icon
as={FaUpload}
sx={{
boxSize: 16,
}}
/>
);
const defaultNoContentFallback = <IAINoContentFallback icon={FaImage} />;
type IAIDndImageProps = FlexProps & {
imageDTO: ImageDTO | undefined;
onError?: (event: SyntheticEvent<HTMLImageElement>) => void;
@ -59,14 +47,13 @@ type IAIDndImageProps = FlexProps & {
fitContainer?: boolean;
droppableData?: TypesafeDroppableData;
draggableData?: TypesafeDraggableData;
dropLabel?: ReactNode;
dropLabel?: string;
isSelected?: boolean;
thumbnail?: boolean;
noContentFallback?: ReactElement;
useThumbailFallback?: boolean;
withHoverOverlay?: boolean;
children?: JSX.Element;
uploadElement?: ReactNode;
};
const IAIDndImage = (props: IAIDndImageProps) => {
@ -87,8 +74,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
dropLabel,
isSelected = false,
thumbnail = false,
noContentFallback = defaultNoContentFallback,
uploadElement = defaultUploadElement,
noContentFallback = <IAINoContentFallback icon={FaImage} />,
useThumbailFallback,
withHoverOverlay = false,
children,
@ -207,7 +193,12 @@ const IAIDndImage = (props: IAIDndImageProps) => {
{...getUploadButtonProps()}
>
<input {...getUploadInputProps()} />
{uploadElement}
<Icon
as={FaUpload}
sx={{
boxSize: 16,
}}
/>
</Flex>
</>
)}
@ -219,7 +210,6 @@ const IAIDndImage = (props: IAIDndImageProps) => {
onClick={onClick}
/>
)}
{children}
{!isDropDisabled && (
<IAIDroppable
data={droppableData}
@ -227,6 +217,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
dropLabel={dropLabel}
/>
)}
{children}
</Flex>
)}
</ImageContextMenu>

View File

@ -1,8 +1,5 @@
import { MenuList } from '@chakra-ui/react';
import {
IAIContextMenu,
IAIContextMenuProps,
} from 'common/components/IAIContextMenu';
import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu';
import { MouseEvent, memo, useCallback } from 'react';
import { ImageDTO } from 'services/api/types';
import { menuListMotionProps } from 'theme/components/menu';
@ -15,7 +12,7 @@ import MultipleSelectionMenuItems from './MultipleSelectionMenuItems';
type Props = {
imageDTO: ImageDTO | undefined;
children: IAIContextMenuProps<HTMLDivElement>['children'];
children: ContextMenuProps<HTMLDivElement>['children'];
};
const selector = createSelector(
@ -36,7 +33,7 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => {
}, []);
return (
<IAIContextMenu<HTMLDivElement>
<ContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }}
menuButtonProps={{
bg: 'transparent',
@ -71,7 +68,7 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => {
}}
>
{children}
</IAIContextMenu>
</ContextMenu>
);
};

View File

@ -2,9 +2,8 @@ import { Badge, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { memo, useMemo } from 'react';
import { useMemo } from 'react';
import {
BaseEdge,
EdgeLabelRenderer,
@ -21,165 +20,78 @@ const makeEdgeSelector = (
targetHandleId: string | null | undefined,
selected?: boolean
) =>
createSelector(
stateSelector,
({ nodes }) => {
const sourceNode = nodes.nodes.find((node) => node.id === source);
const targetNode = nodes.nodes.find((node) => node.id === target);
createSelector(stateSelector, ({ nodes }) => {
const sourceNode = nodes.nodes.find((node) => node.id === source);
const targetNode = nodes.nodes.find((node) => node.id === target);
const isInvocationToInvocationEdge =
isInvocationNode(sourceNode) && isInvocationNode(targetNode);
const isInvocationToInvocationEdge =
isInvocationNode(sourceNode) && isInvocationNode(targetNode);
const isSelected =
sourceNode?.selected || targetNode?.selected || selected;
const sourceType = isInvocationToInvocationEdge
? sourceNode?.data?.outputs[sourceHandleId || '']?.type
: undefined;
const isSelected = sourceNode?.selected || targetNode?.selected || selected;
const sourceType = isInvocationToInvocationEdge
? sourceNode?.data?.outputs[sourceHandleId || '']?.type
: undefined;
const stroke =
sourceType && nodes.shouldColorEdges
? colorTokenToCssVar(FIELDS[sourceType].color)
: colorTokenToCssVar('base.500');
const stroke =
sourceType && nodes.shouldColorEdges
? colorTokenToCssVar(FIELDS[sourceType].color)
: colorTokenToCssVar('base.500');
return {
isSelected,
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
stroke,
};
},
defaultSelectorOptions
return {
isSelected,
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
stroke,
};
});
const CollapsedEdge = ({
sourceX,
sourceY,
targetX,
targetY,
sourcePosition,
targetPosition,
markerEnd,
data,
selected,
source,
target,
sourceHandleId,
targetHandleId,
}: EdgeProps<{ count: number }>) => {
const selector = useMemo(
() =>
makeEdgeSelector(
source,
sourceHandleId,
target,
targetHandleId,
selected
),
[selected, source, sourceHandleId, target, targetHandleId]
);
const CollapsedEdge = memo(
({
const { isSelected, shouldAnimate } = useAppSelector(selector);
const [edgePath, labelX, labelY] = getBezierPath({
sourceX,
sourceY,
sourcePosition,
targetX,
targetY,
sourcePosition,
targetPosition,
markerEnd,
data,
selected,
source,
target,
sourceHandleId,
targetHandleId,
}: EdgeProps<{ count: number }>) => {
const selector = useMemo(
() =>
makeEdgeSelector(
source,
sourceHandleId,
target,
targetHandleId,
selected
),
[selected, source, sourceHandleId, target, targetHandleId]
);
});
const { isSelected, shouldAnimate } = useAppSelector(selector);
const { base500 } = useChakraThemeTokens();
const [edgePath, labelX, labelY] = getBezierPath({
sourceX,
sourceY,
sourcePosition,
targetX,
targetY,
targetPosition,
});
const { base500 } = useChakraThemeTokens();
return (
<>
<BaseEdge
path={edgePath}
markerEnd={markerEnd}
style={{
strokeWidth: isSelected ? 3 : 2,
stroke: base500,
opacity: isSelected ? 0.8 : 0.5,
animation: shouldAnimate
? 'dashdraw 0.5s linear infinite'
: undefined,
strokeDasharray: shouldAnimate ? 5 : 'none',
}}
/>
{data?.count && data.count > 1 && (
<EdgeLabelRenderer>
<Flex
sx={{
position: 'absolute',
transform: `translate(-50%, -50%) translate(${labelX}px,${labelY}px)`,
}}
className="nodrag nopan"
>
<Badge
variant="solid"
sx={{
bg: 'base.500',
opacity: isSelected ? 0.8 : 0.5,
boxShadow: 'base',
}}
>
{data.count}
</Badge>
</Flex>
</EdgeLabelRenderer>
)}
</>
);
}
);
CollapsedEdge.displayName = 'CollapsedEdge';
const DefaultEdge = memo(
({
sourceX,
sourceY,
targetX,
targetY,
sourcePosition,
targetPosition,
markerEnd,
selected,
source,
target,
sourceHandleId,
targetHandleId,
}: EdgeProps) => {
const selector = useMemo(
() =>
makeEdgeSelector(
source,
sourceHandleId,
target,
targetHandleId,
selected
),
[source, sourceHandleId, target, targetHandleId, selected]
);
const { isSelected, shouldAnimate, stroke } = useAppSelector(selector);
const [edgePath] = getBezierPath({
sourceX,
sourceY,
sourcePosition,
targetX,
targetY,
targetPosition,
});
return (
return (
<>
<BaseEdge
path={edgePath}
markerEnd={markerEnd}
style={{
strokeWidth: isSelected ? 3 : 2,
stroke,
stroke: base500,
opacity: isSelected ? 0.8 : 0.5,
animation: shouldAnimate
? 'dashdraw 0.5s linear infinite'
@ -187,11 +99,83 @@ const DefaultEdge = memo(
strokeDasharray: shouldAnimate ? 5 : 'none',
}}
/>
);
}
);
{data?.count && data.count > 1 && (
<EdgeLabelRenderer>
<Flex
sx={{
position: 'absolute',
transform: `translate(-50%, -50%) translate(${labelX}px,${labelY}px)`,
}}
className="nodrag nopan"
>
<Badge
variant="solid"
sx={{
bg: 'base.500',
opacity: isSelected ? 0.8 : 0.5,
boxShadow: 'base',
}}
>
{data.count}
</Badge>
</Flex>
</EdgeLabelRenderer>
)}
</>
);
};
DefaultEdge.displayName = 'DefaultEdge';
const DefaultEdge = ({
sourceX,
sourceY,
targetX,
targetY,
sourcePosition,
targetPosition,
markerEnd,
selected,
source,
target,
sourceHandleId,
targetHandleId,
}: EdgeProps) => {
const selector = useMemo(
() =>
makeEdgeSelector(
source,
sourceHandleId,
target,
targetHandleId,
selected
),
[source, sourceHandleId, target, targetHandleId, selected]
);
const { isSelected, shouldAnimate, stroke } = useAppSelector(selector);
const [edgePath] = getBezierPath({
sourceX,
sourceY,
sourcePosition,
targetX,
targetY,
targetPosition,
});
return (
<BaseEdge
path={edgePath}
markerEnd={markerEnd}
style={{
strokeWidth: isSelected ? 3 : 2,
stroke,
opacity: isSelected ? 0.8 : 0.5,
animation: shouldAnimate ? 'dashdraw 0.5s linear infinite' : undefined,
strokeDasharray: shouldAnimate ? 5 : 'none',
}}
/>
);
};
export const edgeTypes = {
collapsed: CollapsedEdge,

View File

@ -1,6 +1,5 @@
import { useToken } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { contextMenusClosed } from 'features/ui/store/uiSlice';
import { useCallback } from 'react';
import {
Background,
@ -115,10 +114,6 @@ export const Flow = () => {
[dispatch]
);
const handlePaneClick = useCallback(() => {
dispatch(contextMenusClosed());
}, [dispatch]);
return (
<ReactFlow
defaultViewport={viewport}
@ -137,13 +132,12 @@ export const Flow = () => {
connectionLineComponent={CustomConnectionLine}
onSelectionChange={handleSelectionChange}
isValidConnection={isValidConnection}
minZoom={0.1}
minZoom={0.2}
snapToGrid={shouldSnapToGrid}
snapGrid={[25, 25]}
connectionRadius={30}
proOptions={proOptions}
style={{ borderRadius }}
onPaneClick={handlePaneClick}
>
<TopLeftPanel />
<TopCenterPanel />

View File

@ -1,34 +1,40 @@
import { Flex } from '@chakra-ui/react';
import { useFieldNames, useWithFooter } from 'features/nodes/hooks/useNodeData';
import { memo } from 'react';
import {
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { map, some } from 'lodash-es';
import { memo, useMemo } from 'react';
import { NodeProps } from 'reactflow';
import InputField from '../fields/InputField';
import OutputField from '../fields/OutputField';
import NodeFooter from './NodeFooter';
import NodeFooter, { FOOTER_FIELDS } from './NodeFooter';
import NodeHeader from './NodeHeader';
import NodeWrapper from './NodeWrapper';
type Props = {
nodeId: string;
isOpen: boolean;
label: string;
type: string;
selected: boolean;
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
};
const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
const inputFieldNames = useFieldNames(nodeId, 'input');
const outputFieldNames = useFieldNames(nodeId, 'output');
const withFooter = useWithFooter(nodeId);
const InvocationNode = ({ nodeProps, nodeTemplate }: Props) => {
const { id: nodeId, data } = nodeProps;
const { inputs, outputs, isOpen } = data;
const inputFields = useMemo(
() => map(inputs).filter((i) => i.name !== 'is_intermediate'),
[inputs]
);
const outputFields = useMemo(() => map(outputs), [outputs]);
const withFooter = useMemo(
() => some(outputs, (output) => FOOTER_FIELDS.includes(output.type)),
[outputs]
);
return (
<NodeWrapper nodeId={nodeId} selected={selected}>
<NodeHeader
nodeId={nodeId}
isOpen={isOpen}
label={label}
selected={selected}
type={type}
/>
<NodeWrapper nodeProps={nodeProps}>
<NodeHeader nodeProps={nodeProps} nodeTemplate={nodeTemplate} />
{isOpen && (
<>
<Flex
@ -48,23 +54,27 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
className="nopan"
sx={{ flexDir: 'column', px: 2, w: 'full', h: 'full' }}
>
{outputFieldNames.map((fieldName) => (
{outputFields.map((field) => (
<OutputField
key={`${nodeId}.${fieldName}.output-field`}
nodeId={nodeId}
fieldName={fieldName}
key={`${nodeId}.${field.id}.input-field`}
nodeProps={nodeProps}
nodeTemplate={nodeTemplate}
field={field}
/>
))}
{inputFieldNames.map((fieldName) => (
{inputFields.map((field) => (
<InputField
key={`${nodeId}.${fieldName}.input-field`}
nodeId={nodeId}
fieldName={fieldName}
key={`${nodeId}.${field.id}.input-field`}
nodeProps={nodeProps}
nodeTemplate={nodeTemplate}
field={field}
/>
))}
</Flex>
</Flex>
{withFooter && <NodeFooter nodeId={nodeId} />}
{withFooter && (
<NodeFooter nodeProps={nodeProps} nodeTemplate={nodeTemplate} />
)}
</>
)}
</NodeWrapper>

View File

@ -2,15 +2,16 @@ import { ChevronUpIcon } from '@chakra-ui/icons';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { nodeIsOpenChanged } from 'features/nodes/store/nodesSlice';
import { NodeData } from 'features/nodes/types/types';
import { memo, useCallback } from 'react';
import { useUpdateNodeInternals } from 'reactflow';
import { NodeProps, useUpdateNodeInternals } from 'reactflow';
interface Props {
nodeId: string;
isOpen: boolean;
nodeProps: NodeProps<NodeData>;
}
const NodeCollapseButton = ({ nodeId, isOpen }: Props) => {
const NodeCollapseButton = (props: Props) => {
const { id: nodeId, isOpen } = props.nodeProps.data;
const dispatch = useAppDispatch();
const updateNodeInternals = useUpdateNodeInternals();

View File

@ -1,17 +1,20 @@
import { useColorModeValue } from '@chakra-ui/react';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { isInvocationNodeData } from 'features/nodes/types/types';
import {
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { map } from 'lodash-es';
import { CSSProperties, memo, useMemo } from 'react';
import { Handle, Position } from 'reactflow';
import { Handle, NodeProps, Position } from 'reactflow';
interface Props {
nodeId: string;
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
}
const NodeCollapsedHandles = ({ nodeId }: Props) => {
const data = useNodeData(nodeId);
const NodeCollapsedHandles = (props: Props) => {
const { data } = props.nodeProps;
const { base400, base600 } = useChakraThemeTokens();
const backgroundColor = useColorModeValue(base400, base600);
@ -27,10 +30,6 @@ const NodeCollapsedHandles = ({ nodeId }: Props) => {
[backgroundColor]
);
if (!isInvocationNodeData(data)) {
return null;
}
return (
<>
<Handle
@ -45,7 +44,7 @@ const NodeCollapsedHandles = ({ nodeId }: Props) => {
key={`${data.id}-${input.name}-collapsed-input-handle`}
type="target"
id={input.name}
isConnectable={false}
isValidConnection={() => false}
position={Position.Left}
style={{ visibility: 'hidden' }}
/>
@ -53,6 +52,7 @@ const NodeCollapsedHandles = ({ nodeId }: Props) => {
<Handle
type="source"
id={`${data.id}-collapsed-source`}
isValidConnection={() => false}
isConnectable={false}
position={Position.Right}
style={{ ...dummyHandleStyles, right: '-0.5rem' }}
@ -62,7 +62,7 @@ const NodeCollapsedHandles = ({ nodeId }: Props) => {
key={`${data.id}-${output.name}-collapsed-output-handle`}
type="source"
id={output.name}
isConnectable={false}
isValidConnection={() => false}
position={Position.Right}
style={{ visibility: 'hidden' }}
/>

View File

@ -6,19 +6,49 @@ import {
Spacer,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import {
useHasImageOutput,
useIsIntermediate,
} from 'features/nodes/hooks/useNodeData';
import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { ChangeEvent, memo, useCallback } from 'react';
import {
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { some } from 'lodash-es';
import { ChangeEvent, memo, useCallback, useMemo } from 'react';
import { NodeProps } from 'reactflow';
export const IMAGE_FIELDS = ['ImageField', 'ImageCollection'];
export const FOOTER_FIELDS = IMAGE_FIELDS;
type Props = {
nodeId: string;
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
};
const NodeFooter = ({ nodeId }: Props) => {
const NodeFooter = (props: Props) => {
const { nodeProps, nodeTemplate } = props;
const dispatch = useAppDispatch();
const hasImageOutput = useMemo(
() =>
some(nodeTemplate?.outputs, (output) =>
IMAGE_FIELDS.includes(output.type)
),
[nodeTemplate?.outputs]
);
const handleChangeIsIntermediate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldBooleanValueChanged({
nodeId: nodeProps.data.id,
fieldName: 'is_intermediate',
value: !e.target.checked,
})
);
},
[dispatch, nodeProps.data.id]
);
return (
<Flex
className={DRAG_HANDLE_CLASSNAME}
@ -32,45 +62,19 @@ const NodeFooter = ({ nodeId }: Props) => {
}}
>
<Spacer />
<SaveImageCheckbox nodeId={nodeId} />
{hasImageOutput && (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save Output</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChangeIsIntermediate}
isChecked={!nodeProps.data.inputs['is_intermediate']?.value}
/>
</FormControl>
)}
</Flex>
);
};
export default memo(NodeFooter);
const SaveImageCheckbox = memo(({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const hasImageOutput = useHasImageOutput(nodeId);
const is_intermediate = useIsIntermediate(nodeId);
const handleChangeIsIntermediate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldBooleanValueChanged({
nodeId,
fieldName: 'is_intermediate',
value: !e.target.checked,
})
);
},
[dispatch, nodeId]
);
if (!hasImageOutput) {
return null;
}
return (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save Output</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChangeIsIntermediate}
isChecked={!is_intermediate}
/>
</FormControl>
);
});
SaveImageCheckbox.displayName = 'SaveImageCheckbox';

View File

@ -1,5 +1,10 @@
import { Flex } from '@chakra-ui/react';
import {
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { memo } from 'react';
import { NodeProps } from 'reactflow';
import NodeCollapseButton from '../Invocation/NodeCollapseButton';
import NodeCollapsedHandles from '../Invocation/NodeCollapsedHandles';
import NodeNotesEdit from '../Invocation/NodeNotesEdit';
@ -7,14 +12,14 @@ import NodeStatusIndicator from '../Invocation/NodeStatusIndicator';
import NodeTitle from '../Invocation/NodeTitle';
type Props = {
nodeId: string;
isOpen: boolean;
label: string;
type: string;
selected: boolean;
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
};
const NodeHeader = ({ nodeId, isOpen }: Props) => {
const NodeHeader = (props: Props) => {
const { nodeProps, nodeTemplate } = props;
const { isOpen } = nodeProps.data;
return (
<Flex
layerStyle="nodeHeader"
@ -30,13 +35,18 @@ const NodeHeader = ({ nodeId, isOpen }: Props) => {
_dark: { color: 'base.200' },
}}
>
<NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
<NodeTitle nodeId={nodeId} />
<NodeCollapseButton nodeProps={nodeProps} />
<NodeTitle nodeData={nodeProps.data} title={nodeTemplate.title} />
<Flex alignItems="center">
<NodeStatusIndicator nodeId={nodeId} />
<NodeNotesEdit nodeId={nodeId} />
<NodeStatusIndicator nodeProps={nodeProps} />
<NodeNotesEdit nodeProps={nodeProps} nodeTemplate={nodeTemplate} />
</Flex>
{!isOpen && <NodeCollapsedHandles nodeId={nodeId} />}
{!isOpen && (
<NodeCollapsedHandles
nodeProps={nodeProps}
nodeTemplate={nodeTemplate}
/>
)}
</Flex>
);
};

View File

@ -16,31 +16,41 @@ import {
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAITextarea from 'common/components/IAITextarea';
import {
useNodeData,
useNodeLabel,
useNodeTemplate,
useNodeTemplateTitle,
} from 'features/nodes/hooks/useNodeData';
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { isInvocationNodeData } from 'features/nodes/types/types';
import {
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { ChangeEvent, memo, useCallback } from 'react';
import { FaInfoCircle } from 'react-icons/fa';
import { NodeProps } from 'reactflow';
interface Props {
nodeId: string;
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
}
const NodeNotesEdit = ({ nodeId }: Props) => {
const NodeNotesEdit = (props: Props) => {
const { nodeProps, nodeTemplate } = props;
const { data } = nodeProps;
const { isOpen, onOpen, onClose } = useDisclosure();
const label = useNodeLabel(nodeId);
const title = useNodeTemplateTitle(nodeId);
const dispatch = useAppDispatch();
const handleNotesChanged = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(nodeNotesChanged({ nodeId: data.id, notes: e.target.value }));
},
[data.id, dispatch]
);
return (
<>
<Tooltip
label={<TooltipContent nodeId={nodeId} />}
label={
nodeTemplate ? (
<TooltipContent nodeProps={nodeProps} nodeTemplate={nodeTemplate} />
) : undefined
}
placement="top"
shouldWrapChildren
>
@ -65,10 +75,19 @@ const NodeNotesEdit = ({ nodeId }: Props) => {
<Modal isOpen={isOpen} onClose={onClose} isCentered>
<ModalOverlay />
<ModalContent>
<ModalHeader>{label || title || 'Unknown Node'}</ModalHeader>
<ModalHeader>
{data.label || nodeTemplate?.title || 'Unknown Node'}
</ModalHeader>
<ModalCloseButton />
<ModalBody>
<NotesTextarea nodeId={nodeId} />
<FormControl>
<FormLabel>Notes</FormLabel>
<IAITextarea
value={data.notes}
onChange={handleNotesChanged}
rows={10}
/>
</FormControl>
</ModalBody>
<ModalFooter />
</ModalContent>
@ -79,49 +98,16 @@ const NodeNotesEdit = ({ nodeId }: Props) => {
export default memo(NodeNotesEdit);
const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
const data = useNodeData(nodeId);
const nodeTemplate = useNodeTemplate(nodeId);
if (!isInvocationNodeData(data)) {
return <Text sx={{ fontWeight: 600 }}>Unknown Node</Text>;
}
type TooltipContentProps = Props;
const TooltipContent = (props: TooltipContentProps) => {
return (
<Flex sx={{ flexDir: 'column' }}>
<Text sx={{ fontWeight: 600 }}>{nodeTemplate?.title}</Text>
<Text sx={{ fontWeight: 600 }}>{props.nodeTemplate?.title}</Text>
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
{nodeTemplate?.description}
{props.nodeTemplate?.description}
</Text>
{data?.notes && <Text>{data.notes}</Text>}
{props.nodeProps.data.notes && <Text>{props.nodeProps.data.notes}</Text>}
</Flex>
);
});
TooltipContent.displayName = 'TooltipContent';
const NotesTextarea = memo(({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const data = useNodeData(nodeId);
const handleNotesChanged = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(nodeNotesChanged({ nodeId, notes: e.target.value }));
},
[dispatch, nodeId]
);
if (!isInvocationNodeData(data)) {
return null;
}
return (
<FormControl>
<FormLabel>Notes</FormLabel>
<IAITextarea
value={data?.notes}
onChange={handleNotesChanged}
rows={10}
/>
</FormControl>
);
});
NotesTextarea.displayName = 'NodesTextarea';
};

View File

@ -11,12 +11,17 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { NodeExecutionState, NodeStatus } from 'features/nodes/types/types';
import {
InvocationNodeData,
NodeExecutionState,
NodeStatus,
} from 'features/nodes/types/types';
import { memo, useMemo } from 'react';
import { FaCheck, FaEllipsisH, FaExclamation } from 'react-icons/fa';
import { NodeProps } from 'reactflow';
type Props = {
nodeId: string;
nodeProps: NodeProps<InvocationNodeData>;
};
const iconBoxSize = 3;
@ -28,7 +33,8 @@ const circleStyles = {
'.chakra-progress__track': { stroke: 'transparent' },
};
const NodeStatusIndicator = ({ nodeId }: Props) => {
const NodeStatusIndicator = (props: Props) => {
const nodeId = props.nodeProps.data.id;
const selectNodeExecutionState = useMemo(
() =>
createSelector(
@ -70,7 +76,7 @@ type TooltipLabelProps = {
nodeExecutionState: NodeExecutionState;
};
const TooltipLabel = memo(({ nodeExecutionState }: TooltipLabelProps) => {
const TooltipLabel = ({ nodeExecutionState }: TooltipLabelProps) => {
const { status, progress, progressImage } = nodeExecutionState;
if (status === NodeStatus.PENDING) {
return <Text>Pending</Text>;
@ -112,15 +118,13 @@ const TooltipLabel = memo(({ nodeExecutionState }: TooltipLabelProps) => {
}
return null;
});
TooltipLabel.displayName = 'TooltipLabel';
};
type StatusIconProps = {
nodeExecutionState: NodeExecutionState;
};
const StatusIcon = memo((props: StatusIconProps) => {
const StatusIcon = (props: StatusIconProps) => {
const { progress, status } = props.nodeExecutionState;
if (status === NodeStatus.PENDING) {
return (
@ -178,6 +182,4 @@ const StatusIcon = memo((props: StatusIconProps) => {
);
}
return null;
});
StatusIcon.displayName = 'StatusIcon';
};

View File

@ -7,29 +7,26 @@ import {
useEditableControls,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import {
useNodeLabel,
useNodeTemplateTitle,
} from 'features/nodes/hooks/useNodeData';
import { nodeLabelChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { NodeData } from 'features/nodes/types/types';
import { MouseEvent, memo, useCallback, useEffect, useState } from 'react';
type Props = {
nodeId: string;
title?: string;
nodeData: NodeData;
title: string;
};
const NodeTitle = ({ nodeId, title }: Props) => {
const NodeTitle = (props: Props) => {
const { title } = props;
const { id: nodeId, label } = props.nodeData;
const dispatch = useAppDispatch();
const label = useNodeLabel(nodeId);
const templateTitle = useNodeTemplateTitle(nodeId);
const [localTitle, setLocalTitle] = useState(label || title);
const [localTitle, setLocalTitle] = useState('');
const handleSubmit = useCallback(
async (newTitle: string) => {
dispatch(nodeLabelChanged({ nodeId, label: newTitle }));
setLocalTitle(newTitle || title || 'Problem Setting Title');
setLocalTitle(newTitle || title);
},
[nodeId, dispatch, title]
);
@ -40,8 +37,8 @@ const NodeTitle = ({ nodeId, title }: Props) => {
useEffect(() => {
// Another component may change the title; sync local title with global state
setLocalTitle(label || title || templateTitle || 'Problem Setting Title');
}, [label, templateTitle, title]);
setLocalTitle(label || title);
}, [label, title]);
return (
<Flex

View File

@ -6,14 +6,10 @@ import {
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { nodeClicked } from 'features/nodes/store/nodesSlice';
import {
MouseEvent,
PropsWithChildren,
memo,
useCallback,
useMemo,
} from 'react';
import { MouseEvent, PropsWithChildren, useCallback, useMemo } from 'react';
import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from '../../types/constants';
import { NodeData } from 'features/nodes/types/types';
import { NodeProps } from 'reactflow';
const useNodeSelect = (nodeId: string) => {
const dispatch = useAppDispatch();
@ -29,13 +25,14 @@ const useNodeSelect = (nodeId: string) => {
};
type NodeWrapperProps = PropsWithChildren & {
nodeId: string;
selected: boolean;
nodeProps: NodeProps<NodeData>;
width?: NonNullable<ChakraProps['sx']>['w'];
};
const NodeWrapper = (props: NodeWrapperProps) => {
const { width, children, nodeId, selected } = props;
const { width, children, nodeProps } = props;
const { data, selected } = nodeProps;
const nodeId = data.id;
const [
nodeSelectedOutlineLight,
@ -96,4 +93,4 @@ const NodeWrapper = (props: NodeWrapperProps) => {
);
};
export default memo(NodeWrapper);
export default NodeWrapper;

View File

@ -1,26 +1,20 @@
import { Box, Flex, Text } from '@chakra-ui/react';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { InvocationNodeData } from 'features/nodes/types/types';
import { memo } from 'react';
import { NodeProps } from 'reactflow';
import NodeCollapseButton from '../Invocation/NodeCollapseButton';
import NodeWrapper from '../Invocation/NodeWrapper';
type Props = {
nodeId: string;
isOpen: boolean;
label: string;
type: string;
selected: boolean;
nodeProps: NodeProps<InvocationNodeData>;
};
const UnknownNodeFallback = ({
nodeId,
isOpen,
label,
type,
selected,
}: Props) => {
const UnknownNodeFallback = ({ nodeProps }: Props) => {
const { data } = nodeProps;
const { isOpen, label, type } = data;
return (
<NodeWrapper nodeId={nodeId} selected={selected}>
<NodeWrapper nodeProps={nodeProps}>
<Flex
className={DRAG_HANDLE_CLASSNAME}
layerStyle="nodeHeader"
@ -33,7 +27,7 @@ const UnknownNodeFallback = ({
fontSize: 'sm',
}}
>
<NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
<NodeCollapseButton nodeProps={nodeProps} />
<Text
sx={{
w: 'full',

View File

@ -46,6 +46,7 @@ const NodeEditor = () => {
<AnimatePresence>
{isReady && (
<motion.div
layoutId="node-editor-flow"
initial={{
opacity: 0,
}}
@ -66,6 +67,7 @@ const NodeEditor = () => {
<AnimatePresence>
{!isReady && (
<motion.div
layoutId="node-editor-loading"
initial={{
opacity: 0,
}}

View File

@ -15,7 +15,7 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch';
import { ChangeEvent, memo, useCallback } from 'react';
import { ChangeEvent, useCallback } from 'react';
import { FaCog } from 'react-icons/fa';
import {
shouldAnimateEdgesChanged,
@ -23,26 +23,21 @@ import {
shouldSnapToGridChanged,
shouldValidateGraphChanged,
} from '../store/nodesSlice';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
const selector = createSelector(
stateSelector,
({ nodes }) => {
const {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
} = nodes;
return {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
};
},
defaultSelectorOptions
);
const selector = createSelector(stateSelector, ({ nodes }) => {
const {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
} = nodes;
return {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
};
});
const NodeEditorSettings = () => {
const { isOpen, onOpen, onClose } = useDisclosure();
@ -141,4 +136,4 @@ const NodeEditorSettings = () => {
);
};
export default memo(NodeEditorSettings);
export default NodeEditorSettings;

View File

@ -1,12 +1,19 @@
import { Tooltip } from '@chakra-ui/react';
import { CSSProperties, memo, useMemo } from 'react';
import { Handle, HandleType, Position } from 'reactflow';
import { Handle, HandleType, NodeProps, Position } from 'reactflow';
import {
FIELDS,
HANDLE_TOOLTIP_OPEN_DELAY,
colorTokenToCssVar,
} from '../../types/constants';
import { InputFieldTemplate, OutputFieldTemplate } from '../../types/types';
import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
OutputFieldTemplate,
OutputFieldValue,
} from '../../types/types';
export const handleBaseStyles: CSSProperties = {
position: 'absolute',
@ -25,6 +32,9 @@ export const outputHandleStyles: CSSProperties = {
};
type FieldHandleProps = {
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
field: InputFieldValue | OutputFieldValue;
fieldTemplate: InputFieldTemplate | OutputFieldTemplate;
handleType: HandleType;
isConnectionInProgress: boolean;

View File

@ -8,11 +8,13 @@ import {
import { useAppDispatch } from 'app/store/storeHooks';
import IAIDraggable from 'common/components/IAIDraggable';
import { NodeFieldDraggableData } from 'features/dnd/types';
import {
useFieldData,
useFieldTemplate,
} from 'features/nodes/hooks/useNodeData';
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import {
MouseEvent,
memo,
@ -23,43 +25,41 @@ import {
} from 'react';
interface Props {
nodeId: string;
fieldName: string;
nodeData: InvocationNodeData;
nodeTemplate: InvocationTemplate;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
isDraggable?: boolean;
kind: 'input' | 'output';
}
const FieldTitle = (props: Props) => {
const { nodeId, fieldName, isDraggable = false, kind } = props;
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
const field = useFieldData(nodeId, fieldName);
const { nodeData, field, fieldTemplate, isDraggable = false } = props;
const { label } = field;
const { title, input } = fieldTemplate;
const { id: nodeId } = nodeData;
const dispatch = useAppDispatch();
const [localTitle, setLocalTitle] = useState(
field?.label || fieldTemplate?.title || 'Unknown Field'
);
const [localTitle, setLocalTitle] = useState(label || title);
const draggableData: NodeFieldDraggableData | undefined = useMemo(
() =>
field &&
fieldTemplate?.fieldKind === 'input' &&
fieldTemplate?.input !== 'connection' &&
isDraggable
input !== 'connection' && isDraggable
? {
id: `${nodeId}-${fieldName}`,
id: `${nodeId}-${field.name}`,
payloadType: 'NODE_FIELD',
payload: { nodeId, field, fieldTemplate },
}
: undefined,
[field, fieldName, fieldTemplate, isDraggable, nodeId]
[field, fieldTemplate, input, isDraggable, nodeId]
);
const handleSubmit = useCallback(
async (newTitle: string) => {
dispatch(fieldLabelChanged({ nodeId, fieldName, label: newTitle }));
setLocalTitle(newTitle || fieldTemplate?.title || 'Unknown Field');
dispatch(
fieldLabelChanged({ nodeId, fieldName: field.name, label: newTitle })
);
setLocalTitle(newTitle || title);
},
[dispatch, nodeId, fieldName, fieldTemplate?.title]
[dispatch, nodeId, field.name, title]
);
const handleChange = useCallback((newTitle: string) => {
@ -68,8 +68,8 @@ const FieldTitle = (props: Props) => {
useEffect(() => {
// Another component may change the title; sync local title with global state
setLocalTitle(field?.label || fieldTemplate?.title || 'Unknown Field');
}, [field?.label, fieldTemplate?.title]);
setLocalTitle(label || title);
}, [label, title]);
return (
<Flex
@ -120,7 +120,7 @@ type EditableControlsProps = {
draggableData?: NodeFieldDraggableData;
};
const EditableControls = memo((props: EditableControlsProps) => {
function EditableControls(props: EditableControlsProps) {
const { isEditing, getEditButtonProps } = useEditableControls();
const handleDoubleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
@ -158,6 +158,4 @@ const EditableControls = memo((props: EditableControlsProps) => {
cursor="text"
/>
);
});
EditableControls.displayName = 'EditableControls';
}

View File

@ -1,53 +1,38 @@
import { Flex, Text } from '@chakra-ui/react';
import {
useFieldData,
useFieldTemplate,
} from 'features/nodes/hooks/useNodeData';
import { FIELDS } from 'features/nodes/types/constants';
import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
OutputFieldTemplate,
OutputFieldValue,
isInputFieldTemplate,
isInputFieldValue,
} from 'features/nodes/types/types';
import { startCase } from 'lodash-es';
import { useMemo } from 'react';
interface Props {
nodeId: string;
fieldName: string;
kind: 'input' | 'output';
nodeData: InvocationNodeData;
nodeTemplate: InvocationTemplate;
field: InputFieldValue | OutputFieldValue;
fieldTemplate: InputFieldTemplate | OutputFieldTemplate;
}
const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
const field = useFieldData(nodeId, fieldName);
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
const FieldTooltipContent = ({ field, fieldTemplate }: Props) => {
const isInputTemplate = isInputFieldTemplate(fieldTemplate);
const fieldTitle = useMemo(() => {
if (isInputFieldValue(field)) {
if (field.label && fieldTemplate) {
return `${field.label} (${fieldTemplate.title})`;
}
if (field.label && !fieldTemplate) {
return field.label;
}
if (!field.label && fieldTemplate) {
return fieldTemplate.title;
}
return 'Unknown Field';
}
}, [field, fieldTemplate]);
return (
<Flex sx={{ flexDir: 'column' }}>
<Text sx={{ fontWeight: 600 }}>{fieldTitle}</Text>
{fieldTemplate && (
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
{fieldTemplate.description}
</Text>
)}
{fieldTemplate && <Text>Type: {FIELDS[fieldTemplate.type].title}</Text>}
<Text sx={{ fontWeight: 600 }}>
{isInputFieldValue(field) && field.label
? `${field.label} (${fieldTemplate.title})`
: fieldTemplate.title}
</Text>
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
{fieldTemplate.description}
</Text>
<Text>Type: {FIELDS[fieldTemplate.type].title}</Text>
{isInputTemplate && <Text>Input: {startCase(fieldTemplate.input)}</Text>}
</Flex>
);

View File

@ -1,24 +1,27 @@
import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
import {
useDoesInputHaveValue,
useFieldTemplate,
} from 'features/nodes/hooks/useNodeData';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import { PropsWithChildren, memo, useMemo } from 'react';
import {
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { PropsWithChildren, useMemo } from 'react';
import { NodeProps } from 'reactflow';
import FieldHandle from './FieldHandle';
import FieldTitle from './FieldTitle';
import FieldTooltipContent from './FieldTooltipContent';
import InputFieldRenderer from './InputFieldRenderer';
interface Props {
nodeId: string;
fieldName: string;
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
field: InputFieldValue;
}
const InputField = ({ nodeId, fieldName }: Props) => {
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName);
const InputField = (props: Props) => {
const { nodeProps, nodeTemplate, field } = props;
const { id: nodeId } = nodeProps.data;
const {
isConnected,
@ -26,10 +29,15 @@ const InputField = ({ nodeId, fieldName }: Props) => {
isConnectionStartField,
connectionError,
shouldDim,
} = useConnectionState({ nodeId, fieldName, kind: 'input' });
} = useConnectionState({ nodeId, field, kind: 'input' });
const fieldTemplate = useMemo(
() => nodeTemplate.inputs[field.name],
[field.name, nodeTemplate.inputs]
);
const isMissingInput = useMemo(() => {
if (fieldTemplate?.fieldKind !== 'input') {
if (!fieldTemplate) {
return false;
}
@ -41,18 +49,18 @@ const InputField = ({ nodeId, fieldName }: Props) => {
return true;
}
if (!doesFieldHaveValue && !isConnected && fieldTemplate.input === 'any') {
if (!field.value && !isConnected && fieldTemplate.input === 'any') {
return true;
}
}, [fieldTemplate, isConnected, doesFieldHaveValue]);
}, [fieldTemplate, isConnected, field.value]);
if (fieldTemplate?.fieldKind !== 'input') {
if (!fieldTemplate) {
return (
<InputFieldWrapper shouldDim={shouldDim}>
<FormControl
sx={{ color: 'error.400', textAlign: 'left', fontSize: 'sm' }}
>
Unknown input: {fieldName}
Unknown input: {field.name}
</FormControl>
</InputFieldWrapper>
);
@ -74,9 +82,10 @@ const InputField = ({ nodeId, fieldName }: Props) => {
<Tooltip
label={
<FieldTooltipContent
nodeId={nodeId}
fieldName={fieldName}
kind="input"
nodeData={nodeProps.data}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
}
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
@ -86,18 +95,27 @@ const InputField = ({ nodeId, fieldName }: Props) => {
>
<FormLabel sx={{ mb: 0 }}>
<FieldTitle
nodeId={nodeId}
fieldName={fieldName}
kind="input"
nodeData={nodeProps.data}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
isDraggable
/>
</FormLabel>
</Tooltip>
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
<InputFieldRenderer
nodeData={nodeProps.data}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
</FormControl>
{fieldTemplate.input !== 'direct' && (
<FieldHandle
nodeProps={nodeProps}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
handleType="target"
isConnectionInProgress={isConnectionInProgress}
@ -115,25 +133,21 @@ type InputFieldWrapperProps = PropsWithChildren<{
shouldDim: boolean;
}>;
const InputFieldWrapper = memo(
({ shouldDim, children }: InputFieldWrapperProps) => (
<Flex
className="nopan"
sx={{
position: 'relative',
minH: 8,
py: 0.5,
alignItems: 'center',
opacity: shouldDim ? 0.5 : 1,
transitionProperty: 'opacity',
transitionDuration: '0.1s',
w: 'full',
h: 'full',
}}
>
{children}
</Flex>
)
const InputFieldWrapper = ({ shouldDim, children }: InputFieldWrapperProps) => (
<Flex
className="nopan"
sx={{
position: 'relative',
minH: 8,
py: 0.5,
alignItems: 'center',
opacity: shouldDim ? 0.5 : 1,
transitionProperty: 'opacity',
transitionDuration: '0.1s',
w: 'full',
h: 'full',
}}
>
{children}
</Flex>
);
InputFieldWrapper.displayName = 'InputFieldWrapper';

View File

@ -1,9 +1,11 @@
import { Box } from '@chakra-ui/react';
import {
useFieldData,
useFieldTemplate,
} from 'features/nodes/hooks/useNodeData';
import { memo } from 'react';
import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
} from '../../types/types';
import BooleanInputField from './fieldTypes/BooleanInputField';
import ClipInputField from './fieldTypes/ClipInputField';
import CollectionInputField from './fieldTypes/CollectionInputField';
@ -27,33 +29,33 @@ import VaeInputField from './fieldTypes/VaeInputField';
import VaeModelInputField from './fieldTypes/VaeModelInputField';
type InputFieldProps = {
nodeId: string;
fieldName: string;
nodeData: InvocationNodeData;
nodeTemplate: InvocationTemplate;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
};
// build an individual input element based on the schema
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
const field = useFieldData(nodeId, fieldName);
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
const InputFieldRenderer = (props: InputFieldProps) => {
const { nodeData, nodeTemplate, field, fieldTemplate } = props;
const { type } = field;
if (fieldTemplate?.fieldKind === 'output') {
return <Box p={2}>Output field in input: {field?.type}</Box>;
}
if (field?.type === 'string' && fieldTemplate?.type === 'string') {
if (type === 'string' && fieldTemplate.type === 'string') {
return (
<StringInputField
nodeId={nodeId}
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'boolean' && fieldTemplate?.type === 'boolean') {
if (type === 'boolean' && fieldTemplate.type === 'boolean') {
return (
<BooleanInputField
nodeId={nodeId}
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
@ -61,45 +63,46 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
}
if (
(field?.type === 'integer' && fieldTemplate?.type === 'integer') ||
(field?.type === 'float' && fieldTemplate?.type === 'float')
(type === 'integer' && fieldTemplate.type === 'integer') ||
(type === 'float' && fieldTemplate.type === 'float')
) {
return (
<NumberInputField
nodeId={nodeId}
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'enum' && fieldTemplate?.type === 'enum') {
if (type === 'enum' && fieldTemplate.type === 'enum') {
return (
<EnumInputField
nodeId={nodeId}
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'ImageField' && fieldTemplate?.type === 'ImageField') {
if (type === 'ImageField' && fieldTemplate.type === 'ImageField') {
return (
<ImageInputField
nodeId={nodeId}
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'LatentsField' &&
fieldTemplate?.type === 'LatentsField'
) {
if (type === 'LatentsField' && fieldTemplate.type === 'LatentsField') {
return (
<LatentsInputField
nodeId={nodeId}
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
@ -107,68 +110,68 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
}
if (
field?.type === 'ConditioningField' &&
fieldTemplate?.type === 'ConditioningField'
type === 'ConditioningField' &&
fieldTemplate.type === 'ConditioningField'
) {
return (
<ConditioningInputField
nodeId={nodeId}
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'UNetField' && fieldTemplate?.type === 'UNetField') {
if (type === 'UNetField' && fieldTemplate.type === 'UNetField') {
return (
<UnetInputField
nodeId={nodeId}
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'ClipField' && fieldTemplate?.type === 'ClipField') {
if (type === 'ClipField' && fieldTemplate.type === 'ClipField') {
return (
<ClipInputField
nodeId={nodeId}
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'VaeField' && fieldTemplate?.type === 'VaeField') {
if (type === 'VaeField' && fieldTemplate.type === 'VaeField') {
return (
<VaeInputField
nodeId={nodeId}
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'ControlField' &&
fieldTemplate?.type === 'ControlField'
) {
if (type === 'ControlField' && fieldTemplate.type === 'ControlField') {
return (
<ControlInputField
nodeId={nodeId}
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'MainModelField' &&
fieldTemplate?.type === 'MainModelField'
) {
if (type === 'MainModelField' && fieldTemplate.type === 'MainModelField') {
return (
<MainModelInputField
nodeId={nodeId}
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
@ -176,38 +179,35 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
}
if (
field?.type === 'SDXLRefinerModelField' &&
fieldTemplate?.type === 'SDXLRefinerModelField'
type === 'SDXLRefinerModelField' &&
fieldTemplate.type === 'SDXLRefinerModelField'
) {
return (
<RefinerModelInputField
nodeId={nodeId}
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'VaeModelField' &&
fieldTemplate?.type === 'VaeModelField'
) {
if (type === 'VaeModelField' && fieldTemplate.type === 'VaeModelField') {
return (
<VaeModelInputField
nodeId={nodeId}
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'LoRAModelField' &&
fieldTemplate?.type === 'LoRAModelField'
) {
if (type === 'LoRAModelField' && fieldTemplate.type === 'LoRAModelField') {
return (
<LoRAModelInputField
nodeId={nodeId}
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
@ -215,58 +215,57 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
}
if (
field?.type === 'ControlNetModelField' &&
fieldTemplate?.type === 'ControlNetModelField'
type === 'ControlNetModelField' &&
fieldTemplate.type === 'ControlNetModelField'
) {
return (
<ControlNetModelInputField
nodeId={nodeId}
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'Collection' && fieldTemplate?.type === 'Collection') {
if (type === 'Collection' && fieldTemplate.type === 'Collection') {
return (
<CollectionInputField
nodeId={nodeId}
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'CollectionItem' &&
fieldTemplate?.type === 'CollectionItem'
) {
if (type === 'CollectionItem' && fieldTemplate.type === 'CollectionItem') {
return (
<CollectionItemInputField
nodeId={nodeId}
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
if (type === 'ColorField' && fieldTemplate.type === 'ColorField') {
return (
<ColorInputField
nodeId={nodeId}
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'ImageCollection' &&
fieldTemplate?.type === 'ImageCollection'
) {
if (type === 'ImageCollection' && fieldTemplate.type === 'ImageCollection') {
return (
<ImageCollectionInputField
nodeId={nodeId}
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
@ -274,19 +273,20 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
}
if (
field?.type === 'SDXLMainModelField' &&
fieldTemplate?.type === 'SDXLMainModelField'
type === 'SDXLMainModelField' &&
fieldTemplate.type === 'SDXLMainModelField'
) {
return (
<SDXLMainModelInputField
nodeId={nodeId}
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
return <Box p={2}>Unknown field type: {field?.type}</Box>;
return <Box p={2}>Unknown field type: {type}</Box>;
};
export default memo(InputFieldRenderer);

View File

@ -1,16 +1,39 @@
import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { memo } from 'react';
import FieldTitle from './FieldTitle';
import FieldTooltipContent from './FieldTooltipContent';
import InputFieldRenderer from './InputFieldRenderer';
type Props = {
nodeId: string;
fieldName: string;
nodeData: InvocationNodeData;
nodeTemplate: InvocationTemplate;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
};
const LinearViewField = ({ nodeId, fieldName }: Props) => {
const LinearViewField = ({
nodeData,
nodeTemplate,
field,
fieldTemplate,
}: Props) => {
// const dispatch = useAppDispatch();
// const handleRemoveField = useCallback(() => {
// dispatch(
// workflowExposedFieldRemoved({
// nodeId: nodeData.id,
// fieldName: field.name,
// })
// );
// }, [dispatch, field.name, nodeData.id]);
return (
<Flex
layerStyle="second"
@ -25,9 +48,10 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
<Tooltip
label={
<FieldTooltipContent
nodeId={nodeId}
fieldName={fieldName}
kind="input"
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
}
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
@ -42,10 +66,20 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
mb: 0,
}}
>
<FieldTitle nodeId={nodeId} fieldName={fieldName} kind="input" />
<FieldTitle
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
</FormLabel>
</Tooltip>
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
<InputFieldRenderer
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
</FormControl>
</Flex>
);

View File

@ -6,19 +6,25 @@ import {
Tooltip,
} from '@chakra-ui/react';
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
import { useFieldTemplate } from 'features/nodes/hooks/useNodeData';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import { PropsWithChildren, memo } from 'react';
import {
InvocationNodeData,
InvocationTemplate,
OutputFieldValue,
} from 'features/nodes/types/types';
import { PropsWithChildren, useMemo } from 'react';
import { NodeProps } from 'reactflow';
import FieldHandle from './FieldHandle';
import FieldTooltipContent from './FieldTooltipContent';
interface Props {
nodeId: string;
fieldName: string;
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
field: OutputFieldValue;
}
const OutputField = ({ nodeId, fieldName }: Props) => {
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'output');
const OutputField = (props: Props) => {
const { nodeTemplate, nodeProps, field } = props;
const {
isConnected,
@ -26,15 +32,20 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
isConnectionStartField,
connectionError,
shouldDim,
} = useConnectionState({ nodeId, fieldName, kind: 'output' });
} = useConnectionState({ nodeId: nodeProps.data.id, field, kind: 'output' });
if (fieldTemplate?.fieldKind !== 'output') {
const fieldTemplate = useMemo(
() => nodeTemplate.outputs[field.name],
[field.name, nodeTemplate]
);
if (!fieldTemplate) {
return (
<OutputFieldWrapper shouldDim={shouldDim}>
<FormControl
sx={{ color: 'error.400', textAlign: 'right', fontSize: 'sm' }}
>
Unknown output: {fieldName}
Unknown output: {field.name}
</FormControl>
</OutputFieldWrapper>
);
@ -46,9 +57,10 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
<Tooltip
label={
<FieldTooltipContent
nodeId={nodeId}
fieldName={fieldName}
kind="output"
nodeData={nodeProps.data}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
}
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
@ -63,6 +75,9 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
</FormControl>
</Tooltip>
<FieldHandle
nodeProps={nodeProps}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
handleType="source"
isConnectionInProgress={isConnectionInProgress}
@ -73,28 +88,27 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
);
};
export default memo(OutputField);
export default OutputField;
type OutputFieldWrapperProps = PropsWithChildren<{
shouldDim: boolean;
}>;
const OutputFieldWrapper = memo(
({ shouldDim, children }: OutputFieldWrapperProps) => (
<Flex
sx={{
position: 'relative',
minH: 8,
py: 0.5,
alignItems: 'center',
opacity: shouldDim ? 0.5 : 1,
transitionProperty: 'opacity',
transitionDuration: '0.1s',
}}
>
{children}
</Flex>
)
const OutputFieldWrapper = ({
shouldDim,
children,
}: OutputFieldWrapperProps) => (
<Flex
sx={{
position: 'relative',
minH: 8,
py: 0.5,
alignItems: 'center',
opacity: shouldDim ? 0.5 : 1,
transitionProperty: 'opacity',
transitionDuration: '0.1s',
}}
>
{children}
</Flex>
);
OutputFieldWrapper.displayName = 'OutputFieldWrapper';

View File

@ -11,7 +11,8 @@ import { FieldComponentProps } from './types';
const BooleanInputFieldComponent = (
props: FieldComponentProps<BooleanInputFieldValue, BooleanInputFieldTemplate>
) => {
const { nodeId, field } = props;
const { nodeData, field } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch();

View File

@ -11,7 +11,8 @@ import { FieldComponentProps } from './types';
const ColorInputFieldComponent = (
props: FieldComponentProps<ColorInputFieldValue, ColorInputFieldTemplate>
) => {
const { nodeId, field } = props;
const { nodeData, field } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch();

View File

@ -19,7 +19,8 @@ const ControlNetModelInputFieldComponent = (
ControlNetModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const { nodeData, field } = props;
const nodeId = nodeData.id;
const controlNetModel = field.value;
const dispatch = useAppDispatch();

View File

@ -11,7 +11,8 @@ import { FieldComponentProps } from './types';
const EnumInputFieldComponent = (
props: FieldComponentProps<EnumInputFieldValue, EnumInputFieldTemplate>
) => {
const { nodeId, field, fieldTemplate } = props;
const { nodeData, field, fieldTemplate } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch();

View File

@ -19,7 +19,8 @@ const ImageCollectionInputFieldComponent = (
ImageCollectionInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const { nodeData, field } = props;
const nodeId = nodeData.id;
// const dispatch = useAppDispatch();

View File

@ -21,7 +21,8 @@ import { FieldComponentProps } from './types';
const ImageInputFieldComponent = (
props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate>
) => {
const { nodeId, field } = props;
const { nodeData, field } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch();
const { currentData: imageDTO } = useGetImageDTOQuery(

View File

@ -21,7 +21,8 @@ const LoRAModelInputFieldComponent = (
LoRAModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const { nodeData, field } = props;
const nodeId = nodeData.id;
const lora = field.value;
const dispatch = useAppDispatch();
const { data: loraModels } = useGetLoRAModelsQuery();

View File

@ -26,7 +26,8 @@ const MainModelInputFieldComponent = (
MainModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const { nodeData, field } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;

View File

@ -23,7 +23,8 @@ const NumberInputFieldComponent = (
IntegerInputFieldTemplate | FloatInputFieldTemplate
>
) => {
const { nodeId, field, fieldTemplate } = props;
const { nodeData, field, fieldTemplate } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch();
const [valueAsString, setValueAsString] = useState<string>(
String(field.value)

View File

@ -24,7 +24,8 @@ const RefinerModelInputFieldComponent = (
SDXLRefinerModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const { nodeData, field } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;

View File

@ -27,7 +27,8 @@ const ModelInputFieldComponent = (
SDXLMainModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const { nodeData, field } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;

View File

@ -12,7 +12,8 @@ import { FieldComponentProps } from './types';
const StringInputFieldComponent = (
props: FieldComponentProps<StringInputFieldValue, StringInputFieldTemplate>
) => {
const { nodeId, field, fieldTemplate } = props;
const { nodeData, field, fieldTemplate } = props;
const nodeId = nodeData.id;
const dispatch = useAppDispatch();
const handleValueChanged = useCallback(

View File

@ -20,7 +20,8 @@ const VaeModelInputFieldComponent = (
VaeModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const { nodeData, field } = props;
const nodeId = nodeData.id;
const vae = field.value;
const dispatch = useAppDispatch();
const { data: vaeModels } = useGetVaeModelsQuery();

View File

@ -1,13 +1,16 @@
import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
export type FieldComponentProps<
V extends InputFieldValue,
T extends InputFieldTemplate
> = {
nodeId: string;
nodeData: InvocationNodeData;
nodeTemplate: InvocationTemplate;
field: V;
fieldTemplate: T;
};

View File

@ -55,11 +55,7 @@ const CurrentImageNode = (props: NodeProps) => {
export default memo(CurrentImageNode);
const Wrapper = (props: PropsWithChildren<{ nodeProps: NodeProps }>) => (
<NodeWrapper
nodeId={props.nodeProps.data.id}
selected={props.nodeProps.selected}
width={384}
>
<NodeWrapper nodeProps={props.nodeProps} width={384}>
<Flex
className={DRAG_HANDLE_CLASSNAME}
sx={{

View File

@ -1,6 +1,5 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { makeTemplateSelector } from 'features/nodes/store/util/makeTemplateSelector';
import { InvocationNodeData } from 'features/nodes/types/types';
import { memo, useMemo } from 'react';
import { NodeProps } from 'reactflow';
@ -8,40 +7,18 @@ import InvocationNode from '../Invocation/InvocationNode';
import UnknownNodeFallback from '../Invocation/UnknownNodeFallback';
const InvocationNodeWrapper = (props: NodeProps<InvocationNodeData>) => {
const { data, selected } = props;
const { id: nodeId, type, isOpen, label } = data;
const { data } = props;
const { type } = data;
const hasTemplateSelector = useMemo(
() =>
createSelector(stateSelector, ({ nodes }) =>
Boolean(nodes.nodeTemplates[type])
),
[type]
);
const templateSelector = useMemo(() => makeTemplateSelector(type), [type]);
const nodeTemplate = useAppSelector(hasTemplateSelector);
const nodeTemplate = useAppSelector(templateSelector);
if (!nodeTemplate) {
return (
<UnknownNodeFallback
nodeId={nodeId}
isOpen={isOpen}
label={label}
type={type}
selected={selected}
/>
);
return <UnknownNodeFallback nodeProps={props} />;
}
return (
<InvocationNode
nodeId={nodeId}
isOpen={isOpen}
label={label}
type={type}
selected={selected}
/>
);
return <InvocationNode nodeProps={props} nodeTemplate={nodeTemplate} />;
};
export default memo(InvocationNodeWrapper);

View File

@ -10,7 +10,7 @@ import NodeTitle from '../Invocation/NodeTitle';
import NodeWrapper from '../Invocation/NodeWrapper';
const NotesNode = (props: NodeProps<NotesNodeData>) => {
const { id: nodeId, data, selected } = props;
const { id: nodeId, data } = props;
const { notes, isOpen } = data;
const dispatch = useAppDispatch();
const handleChange = useCallback(
@ -21,7 +21,7 @@ const NotesNode = (props: NodeProps<NotesNodeData>) => {
);
return (
<NodeWrapper nodeId={nodeId} selected={selected}>
<NodeWrapper nodeProps={props}>
<Flex
layerStyle="nodeHeader"
sx={{
@ -32,8 +32,8 @@ const NotesNode = (props: NodeProps<NotesNodeData>) => {
h: 8,
}}
>
<NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
<NodeTitle nodeId={nodeId} title="Notes" />
<NodeCollapseButton nodeProps={props} />
<NodeTitle nodeData={props.data} title="Notes" />
<Box minW={8} />
</Flex>
{isOpen && (

View File

@ -6,11 +6,39 @@ import {
TabPanels,
Tabs,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ImageMetadataJSON from 'features/gallery/components/ImageMetadataViewer/ImageMetadataJSON';
import { memo } from 'react';
import NodeDataInspector from './NodeDataInspector';
import NodeTemplateInspector from './NodeTemplateInspector';
const selector = createSelector(
stateSelector,
({ nodes }) => {
const lastSelectedNodeId =
nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find(
(node) => node.id === lastSelectedNodeId
);
const lastSelectedNodeTemplate = lastSelectedNode
? nodes.nodeTemplates[lastSelectedNode.data.type]
: undefined;
return {
node: lastSelectedNode,
template: lastSelectedNodeTemplate,
};
},
defaultSelectorOptions
);
const InspectorPanel = () => {
const { node, template } = useAppSelector(selector);
return (
<Flex
layerStyle="first"
@ -32,10 +60,37 @@ const InspectorPanel = () => {
<TabPanels>
<TabPanel>
<NodeTemplateInspector />
{template ? (
<Flex
sx={{
flexDir: 'column',
alignItems: 'flex-start',
gap: 2,
h: 'full',
}}
>
<ImageMetadataJSON
jsonObject={template}
label="Node Template"
/>
</Flex>
) : (
<IAINoContentFallback
label={
node
? 'No template found for selected node'
: 'No node selected'
}
icon={null}
/>
)}
</TabPanel>
<TabPanel>
<NodeDataInspector />
{node ? (
<ImageMetadataJSON jsonObject={node.data} label="Node Data" />
) : (
<IAINoContentFallback label="No node selected" icon={null} />
)}
</TabPanel>
</TabPanels>
</Tabs>

View File

@ -17,20 +17,20 @@ const selector = createSelector(
);
return {
data: lastSelectedNode?.data,
node: lastSelectedNode,
};
},
defaultSelectorOptions
);
const NodeDataInspector = () => {
const { data } = useAppSelector(selector);
const { node } = useAppSelector(selector);
if (!data) {
return <IAINoContentFallback label="No node selected" icon={null} />;
}
return <ImageMetadataJSON jsonObject={data} label="Node Data" />;
return node ? (
<ImageMetadataJSON jsonObject={node.data} label="Node Data" />
) : (
<IAINoContentFallback label="No node data" icon={null} />
);
};
export default memo(NodeDataInspector);

View File

@ -1,40 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ImageMetadataJSON from 'features/gallery/components/ImageMetadataViewer/ImageMetadataJSON';
import { memo } from 'react';
const selector = createSelector(
stateSelector,
({ nodes }) => {
const lastSelectedNodeId =
nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find(
(node) => node.id === lastSelectedNodeId
);
const lastSelectedNodeTemplate = lastSelectedNode
? nodes.nodeTemplates[lastSelectedNode.data.type]
: undefined;
return {
template: lastSelectedNodeTemplate,
};
},
defaultSelectorOptions
);
const NodeTemplateInspector = () => {
const { template } = useAppSelector(selector);
if (!template) {
return <IAINoContentFallback label="No node selected" icon={null} />;
}
return <ImageMetadataJSON jsonObject={template} label="Node Template" />;
};
export default memo(NodeTemplateInspector);

View File

@ -6,6 +6,14 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDroppable from 'common/components/IAIDroppable';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { AddFieldToLinearViewDropData } from 'features/dnd/types';
import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
isInvocationNode,
} from 'features/nodes/types/types';
import { forEach } from 'lodash-es';
import { memo } from 'react';
import LinearViewField from '../../fields/LinearViewField';
import ScrollableContent from '../ScrollableContent';
@ -13,8 +21,41 @@ import ScrollableContent from '../ScrollableContent';
const selector = createSelector(
stateSelector,
({ nodes }) => {
const fields: {
nodeData: InvocationNodeData;
nodeTemplate: InvocationTemplate;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
}[] = [];
const { exposedFields } = nodes.workflow;
nodes.nodes.filter(isInvocationNode).forEach((node) => {
const nodeTemplate = nodes.nodeTemplates[node.data.type];
if (!nodeTemplate) {
return;
}
forEach(node.data.inputs, (field) => {
if (
!exposedFields.some(
(f) => f.nodeId === node.id && f.fieldName === field.name
)
) {
return;
}
const fieldTemplate = nodeTemplate.inputs[field.name];
if (!fieldTemplate) {
return;
}
fields.push({
nodeData: node.data,
nodeTemplate,
field,
fieldTemplate,
});
});
});
return {
fields: nodes.workflow.exposedFields,
fields,
};
},
defaultSelectorOptions
@ -48,11 +89,13 @@ const LinearTabContent = () => {
}}
>
{fields.length ? (
fields.map(({ nodeId, fieldName }) => (
fields.map(({ nodeData, nodeTemplate, field, fieldTemplate }) => (
<LinearViewField
key={`${nodeId}-${fieldName}`}
nodeId={nodeId}
fieldName={fieldName}
key={field.id}
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
))
) : (

View File

@ -25,9 +25,7 @@ const ClearGraphButton = () => {
const { isOpen, onOpen, onClose } = useDisclosure();
const cancelRef = useRef<HTMLButtonElement | null>(null);
const nodesCount = useAppSelector(
(state: RootState) => state.nodes.nodes.length
);
const nodes = useAppSelector((state: RootState) => state.nodes.nodes);
const handleConfirmClear = useCallback(() => {
dispatch(nodeEditorReset());
@ -51,7 +49,7 @@ const ClearGraphButton = () => {
tooltip={t('nodes.clearGraph')}
aria-label={t('nodes.clearGraph')}
onClick={onOpen}
isDisabled={!nodesCount}
isDisabled={nodes.length === 0}
/>
<AlertDialog

View File

@ -8,7 +8,7 @@ import IAIIconButton, {
import { selectIsReadyNodes } from 'features/nodes/store/selectors';
import ProgressBar from 'features/system/components/ProgressBar';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback } from 'react';
import { useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaPlay } from 'react-icons/fa';
@ -18,7 +18,7 @@ interface InvokeButton
iconButton?: boolean;
}
const NodeInvokeButton = (props: InvokeButton) => {
export default function NodeInvokeButton(props: InvokeButton) {
const { iconButton = false, ...rest } = props;
const dispatch = useAppDispatch();
const activeTabName = useAppSelector(activeTabNameSelector);
@ -92,6 +92,4 @@ const NodeInvokeButton = (props: InvokeButton) => {
</Box>
</Box>
);
};
export default memo(NodeInvokeButton);
}

View File

@ -1,11 +1,11 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { memo, useCallback } from 'react';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { FaSyncAlt } from 'react-icons/fa';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
const ReloadSchemaButton = () => {
export default function ReloadSchemaButton() {
const { t } = useTranslation();
const dispatch = useAppDispatch();
@ -21,6 +21,4 @@ const ReloadSchemaButton = () => {
onClick={handleReloadSchema}
/>
);
};
export default memo(ReloadSchemaButton);
}

View File

@ -2,8 +2,8 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeIsConnectionValidSelector';
import { InputFieldValue, OutputFieldValue } from 'features/nodes/types/types';
import { useMemo } from 'react';
import { useFieldType } from './useNodeData';
const selectIsConnectionInProgress = createSelector(
stateSelector,
@ -12,19 +12,23 @@ const selectIsConnectionInProgress = createSelector(
nodes.connectionStartParams !== null
);
export type UseConnectionStateProps = {
nodeId: string;
fieldName: string;
kind: 'input' | 'output';
};
export type UseConnectionStateProps =
| {
nodeId: string;
field: InputFieldValue;
kind: 'input';
}
| {
nodeId: string;
field: OutputFieldValue;
kind: 'output';
};
export const useConnectionState = ({
nodeId,
fieldName,
field,
kind,
}: UseConnectionStateProps) => {
const fieldType = useFieldType(nodeId, fieldName, kind);
const selectIsConnected = useMemo(
() =>
createSelector(stateSelector, ({ nodes }) =>
@ -33,23 +37,23 @@ export const useConnectionState = ({
return (
(kind === 'input' ? edge.target : edge.source) === nodeId &&
(kind === 'input' ? edge.targetHandle : edge.sourceHandle) ===
fieldName
field.name
);
}).length
)
),
[fieldName, kind, nodeId]
[field.name, kind, nodeId]
);
const selectConnectionError = useMemo(
() =>
makeConnectionErrorSelector(
nodeId,
fieldName,
field.name,
kind === 'input' ? 'target' : 'source',
fieldType
field.type
),
[nodeId, fieldName, kind, fieldType]
[nodeId, field.name, field.type, kind]
);
const selectIsConnectionStartField = useMemo(
@ -57,12 +61,12 @@ export const useConnectionState = ({
createSelector(stateSelector, ({ nodes }) =>
Boolean(
nodes.connectionStartParams?.nodeId === nodeId &&
nodes.connectionStartParams?.handleId === fieldName &&
nodes.connectionStartParams?.handleId === field.name &&
nodes.connectionStartParams?.handleType ===
{ input: 'target', output: 'source' }[kind]
)
),
[fieldName, kind, nodeId]
[field.name, kind, nodeId]
);
const isConnected = useAppSelector(selectIsConnected);

View File

@ -1,286 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { map, some } from 'lodash-es';
import { useMemo } from 'react';
import { FOOTER_FIELDS, IMAGE_FIELDS } from '../types/constants';
import { isInvocationNode } from '../types/types';
const KIND_MAP = {
input: 'inputs' as const,
output: 'outputs' as const,
};
export const useNodeTemplate = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
return nodeTemplate;
},
defaultSelectorOptions
),
[nodeId]
);
const nodeTemplate = useAppSelector(selector);
return nodeTemplate;
};
export const useNodeData = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
return node?.data;
},
defaultSelectorOptions
),
[nodeId]
);
const nodeData = useAppSelector(selector);
return nodeData;
};
export const useFieldData = (nodeId: string, fieldName: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return;
}
return node?.data.inputs[fieldName];
},
defaultSelectorOptions
),
[fieldName, nodeId]
);
const fieldData = useAppSelector(selector);
return fieldData;
};
export const useFieldType = (
nodeId: string,
fieldName: string,
kind: 'input' | 'output'
) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return;
}
return node?.data[KIND_MAP[kind]][fieldName]?.type;
},
defaultSelectorOptions
),
[fieldName, kind, nodeId]
);
const fieldType = useAppSelector(selector);
return fieldType;
};
export const useFieldNames = (nodeId: string, kind: 'input' | 'output') => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return [];
}
return map(node.data[KIND_MAP[kind]], (field) => field.name).filter(
(fieldName) => fieldName !== 'is_intermediate'
);
},
defaultSelectorOptions
),
[kind, nodeId]
);
const fieldNames = useAppSelector(selector);
return fieldNames;
};
export const useWithFooter = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
return some(node.data.outputs, (output) =>
FOOTER_FIELDS.includes(output.type)
);
},
defaultSelectorOptions
),
[nodeId]
);
const withFooter = useAppSelector(selector);
return withFooter;
};
export const useHasImageOutput = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
return some(node.data.outputs, (output) =>
IMAGE_FIELDS.includes(output.type)
);
},
defaultSelectorOptions
),
[nodeId]
);
const hasImageOutput = useAppSelector(selector);
return hasImageOutput;
};
export const useIsIntermediate = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
return Boolean(node.data.inputs.is_intermediate?.value);
},
defaultSelectorOptions
),
[nodeId]
);
const is_intermediate = useAppSelector(selector);
return is_intermediate;
};
export const useNodeLabel = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
return node.data.label;
},
defaultSelectorOptions
),
[nodeId]
);
const label = useAppSelector(selector);
return label;
};
export const useNodeTemplateTitle = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
const nodeTemplate = node
? nodes.nodeTemplates[node.data.type]
: undefined;
return nodeTemplate?.title;
},
defaultSelectorOptions
),
[nodeId]
);
const title = useAppSelector(selector);
return title;
};
export const useFieldTemplate = (
nodeId: string,
fieldName: string,
kind: 'input' | 'output'
) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return;
}
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
return nodeTemplate?.[KIND_MAP[kind]][fieldName];
},
defaultSelectorOptions
),
[fieldName, kind, nodeId]
);
const fieldTemplate = useAppSelector(selector);
return fieldTemplate;
};
export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return;
}
return Boolean(node?.data.inputs[fieldName]?.value);
},
defaultSelectorOptions
),
[fieldName, nodeId]
);
const doesFieldHaveValue = useAppSelector(selector);
return doesFieldHaveValue;
};

View File

@ -9,13 +9,9 @@ export const makeConnectionErrorSelector = (
nodeId: string,
fieldName: string,
handleType: HandleType,
fieldType?: FieldType
fieldType: FieldType
) =>
createSelector(stateSelector, (state) => {
if (!fieldType) {
return 'No field type';
}
const { currentConnectionFieldType, connectionStartParams, nodes, edges } =
state.nodes;

View File

@ -6,9 +6,6 @@ export const NODE_WIDTH = 320;
export const NODE_MIN_WIDTH = 320;
export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle';
export const IMAGE_FIELDS = ['ImageField', 'ImageCollection'];
export const FOOTER_FIELDS = IMAGE_FIELDS;
export const COLLECTION_TYPES: FieldType[] = [
'Collection',
'IntegerCollection',

View File

@ -457,13 +457,12 @@ export type ColorInputFieldTemplate = InputFieldTemplateBase & {
};
export const isInputFieldValue = (
field?: InputFieldValue | OutputFieldValue
): field is InputFieldValue => Boolean(field && field.fieldKind === 'input');
field: InputFieldValue | OutputFieldValue
): field is InputFieldValue => field.fieldKind === 'input';
export const isInputFieldTemplate = (
fieldTemplate?: InputFieldTemplate | OutputFieldTemplate
): fieldTemplate is InputFieldTemplate =>
Boolean(fieldTemplate && fieldTemplate.fieldKind === 'input');
fieldTemplate: InputFieldTemplate | OutputFieldTemplate
): fieldTemplate is InputFieldTemplate => fieldTemplate.fieldKind === 'input';
/**
* JANKY CUSTOMISATION OF OpenAPI SCHEMA TYPES
@ -633,22 +632,20 @@ export type NodeData =
export const isInvocationNode = (
node?: Node<NodeData>
): node is Node<InvocationNodeData> =>
Boolean(node && node.type === 'invocation');
): node is Node<InvocationNodeData> => node?.type === 'invocation';
export const isInvocationNodeData = (
node?: NodeData
): node is InvocationNodeData =>
Boolean(node && !['notes', 'current_image'].includes(node.type));
!['notes', 'current_image'].includes(node?.type ?? '');
export const isNotesNode = (
node?: Node<NodeData>
): node is Node<NotesNodeData> => Boolean(node && node.type === 'notes');
): node is Node<NotesNodeData> => node?.type === 'notes';
export const isProgressImageNode = (
node?: Node<NodeData>
): node is Node<CurrentImageNodeData> =>
Boolean(node && node.type === 'current_image');
): node is Node<CurrentImageNodeData> => node?.type === 'current_image';
export enum NodeStatus {
PENDING,

View File

@ -32,7 +32,6 @@ import {
MAIN_MODEL_LOADER,
MASK_BLUR,
MASK_COMBINE,
MASK_EDGE,
MASK_FROM_ALPHA,
MASK_RESIZE_DOWN,
MASK_RESIZE_UP,
@ -41,10 +40,6 @@ import {
POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
SEAM_FIX_DENOISE_LATENTS,
SEAM_MASK_COMBINE,
SEAM_MASK_RESIZE_DOWN,
SEAM_MASK_RESIZE_UP,
} from './constants';
/**
@ -72,12 +67,6 @@ export const buildCanvasOutpaintGraph = (
shouldUseCpuNoise,
maskBlur,
maskBlurMethod,
seamSize,
seamBlur,
seamSteps,
seamStrength,
seamLowThreshold,
seamHighThreshold,
tileSize,
infillMethod,
clipSkip,
@ -141,11 +130,6 @@ export const buildCanvasOutpaintGraph = (
is_intermediate: true,
mask2: canvasMaskImage,
},
[SEAM_MASK_COMBINE]: {
type: 'mask_combine',
id: MASK_COMBINE,
is_intermediate: true,
},
[MASK_BLUR]: {
type: 'img_blur',
id: MASK_BLUR,
@ -181,25 +165,6 @@ export const buildCanvasOutpaintGraph = (
denoising_start: 1 - strength,
denoising_end: 1,
},
[MASK_EDGE]: {
type: 'mask_edge',
id: MASK_EDGE,
is_intermediate: true,
edge_size: seamSize,
edge_blur: seamBlur,
low_threshold: seamLowThreshold,
high_threshold: seamHighThreshold,
},
[SEAM_FIX_DENOISE_LATENTS]: {
type: 'denoise_latents',
id: SEAM_FIX_DENOISE_LATENTS,
is_intermediate: true,
steps: seamSteps,
cfg_scale: cfg_scale,
scheduler: scheduler,
denoising_start: 1 - seamStrength,
denoising_end: 1,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
@ -368,61 +333,10 @@ export const buildCanvasOutpaintGraph = (
field: 'seed',
},
},
// Seam Paint
{
source: {
node_id: MAIN_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'unet',
},
},
{
source: {
node_id: POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'positive_conditioning',
},
},
{
source: {
node_id: NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'negative_conditioning',
},
},
{
source: {
node_id: NOISE,
field: 'noise',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'noise',
},
},
{
source: {
node_id: DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'latents',
},
},
// Decode the result from Inpaint
{
source: {
node_id: SEAM_FIX_DENOISE_LATENTS,
node_id: DENOISE_LATENTS,
field: 'latents',
},
destination: {
@ -434,6 +348,7 @@ export const buildCanvasOutpaintGraph = (
};
// Add Infill Nodes
if (infillMethod === 'patchmatch') {
graph.nodes[INPAINT_INFILL] = {
type: 'infill_patchmatch',
@ -463,13 +378,6 @@ export const buildCanvasOutpaintGraph = (
width: scaledWidth,
height: scaledHeight,
};
graph.nodes[SEAM_MASK_RESIZE_UP] = {
type: 'img_resize',
id: SEAM_MASK_RESIZE_UP,
is_intermediate: true,
width: scaledWidth,
height: scaledHeight,
};
graph.nodes[INPAINT_IMAGE_RESIZE_DOWN] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_DOWN,
@ -491,13 +399,6 @@ export const buildCanvasOutpaintGraph = (
width: width,
height: height,
};
graph.nodes[SEAM_MASK_RESIZE_DOWN] = {
type: 'img_resize',
id: SEAM_MASK_RESIZE_DOWN,
is_intermediate: true,
width: width,
height: height,
};
graph.nodes[NOISE] = {
...(graph.nodes[NOISE] as NoiseInvocation),
@ -539,57 +440,6 @@ export const buildCanvasOutpaintGraph = (
field: 'image',
},
},
// Seam Paint Mask
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'image',
},
destination: {
node_id: MASK_EDGE,
field: 'image',
},
},
{
source: {
node_id: MASK_EDGE,
field: 'image',
},
destination: {
node_id: SEAM_MASK_RESIZE_UP,
field: 'image',
},
},
{
source: {
node_id: SEAM_MASK_RESIZE_UP,
field: 'image',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'mask',
},
},
{
source: {
node_id: MASK_BLUR,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask1',
},
},
{
source: {
node_id: SEAM_MASK_RESIZE_UP,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask2',
},
},
// Resize Results Down
{
source: {
@ -603,7 +453,7 @@ export const buildCanvasOutpaintGraph = (
},
{
source: {
node_id: MASK_RESIZE_UP,
node_id: MASK_BLUR,
field: 'image',
},
destination: {
@ -611,16 +461,6 @@ export const buildCanvasOutpaintGraph = (
field: 'image',
},
},
{
source: {
node_id: SEAM_MASK_COMBINE,
field: 'image',
},
destination: {
node_id: SEAM_MASK_RESIZE_DOWN,
field: 'image',
},
},
{
source: {
node_id: INPAINT_INFILL,
@ -654,7 +494,7 @@ export const buildCanvasOutpaintGraph = (
},
{
source: {
node_id: SEAM_MASK_RESIZE_DOWN,
node_id: MASK_RESIZE_DOWN,
field: 'image',
},
destination: {
@ -685,7 +525,7 @@ export const buildCanvasOutpaintGraph = (
},
{
source: {
node_id: SEAM_MASK_RESIZE_DOWN,
node_id: MASK_RESIZE_DOWN,
field: 'image',
},
destination: {
@ -713,6 +553,7 @@ export const buildCanvasOutpaintGraph = (
};
graph.nodes[MASK_BLUR] = {
...(graph.nodes[MASK_BLUR] as ImageBlurInvocation),
image: canvasMaskImage,
};
graph.edges.push(
@ -727,47 +568,6 @@ export const buildCanvasOutpaintGraph = (
field: 'image',
},
},
// Seam Paint Mask
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'image',
},
destination: {
node_id: MASK_EDGE,
field: 'image',
},
},
{
source: {
node_id: MASK_EDGE,
field: 'image',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'mask',
},
},
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask1',
},
},
{
source: {
node_id: MASK_EDGE,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask2',
},
},
// Color Correct The Inpainted Result
{
source: {
@ -791,7 +591,7 @@ export const buildCanvasOutpaintGraph = (
},
{
source: {
node_id: SEAM_MASK_COMBINE,
node_id: MASK_BLUR,
field: 'image',
},
destination: {
@ -822,7 +622,7 @@ export const buildCanvasOutpaintGraph = (
},
{
source: {
node_id: SEAM_MASK_COMBINE,
node_id: MASK_BLUR,
field: 'image',
},
destination: {

View File

@ -29,7 +29,6 @@ import {
LATENTS_TO_IMAGE,
MASK_BLUR,
MASK_COMBINE,
MASK_EDGE,
MASK_FROM_ALPHA,
MASK_RESIZE_DOWN,
MASK_RESIZE_UP,
@ -41,10 +40,6 @@ import {
SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER,
SEAM_FIX_DENOISE_LATENTS,
SEAM_MASK_COMBINE,
SEAM_MASK_RESIZE_DOWN,
SEAM_MASK_RESIZE_UP,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -72,12 +67,6 @@ export const buildCanvasSDXLOutpaintGraph = (
shouldUseCpuNoise,
maskBlur,
maskBlurMethod,
seamSize,
seamBlur,
seamSteps,
seamStrength,
seamLowThreshold,
seamHighThreshold,
tileSize,
infillMethod,
} = state.generation;
@ -144,11 +133,6 @@ export const buildCanvasSDXLOutpaintGraph = (
is_intermediate: true,
mask2: canvasMaskImage,
},
[SEAM_MASK_COMBINE]: {
type: 'mask_combine',
id: MASK_COMBINE,
is_intermediate: true,
},
[MASK_BLUR]: {
type: 'img_blur',
id: MASK_BLUR,
@ -186,25 +170,6 @@ export const buildCanvasSDXLOutpaintGraph = (
: 1 - strength,
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
},
[MASK_EDGE]: {
type: 'mask_edge',
id: MASK_EDGE,
is_intermediate: true,
edge_size: seamSize,
edge_blur: seamBlur,
low_threshold: seamLowThreshold,
high_threshold: seamHighThreshold,
},
[SEAM_FIX_DENOISE_LATENTS]: {
type: 'denoise_latents',
id: SEAM_FIX_DENOISE_LATENTS,
is_intermediate: true,
steps: seamSteps,
cfg_scale: cfg_scale,
scheduler: scheduler,
denoising_start: 1 - seamStrength,
denoising_end: 1,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
@ -382,61 +347,10 @@ export const buildCanvasSDXLOutpaintGraph = (
field: 'seed',
},
},
// Seam Paint
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'unet',
},
},
{
source: {
node_id: POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'positive_conditioning',
},
},
{
source: {
node_id: NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'negative_conditioning',
},
},
{
source: {
node_id: NOISE,
field: 'noise',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'noise',
},
},
{
source: {
node_id: SDXL_DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'latents',
},
},
// Decode inpainted latents to image
{
source: {
node_id: SEAM_FIX_DENOISE_LATENTS,
node_id: SDXL_DENOISE_LATENTS,
field: 'latents',
},
destination: {
@ -478,13 +392,6 @@ export const buildCanvasSDXLOutpaintGraph = (
width: scaledWidth,
height: scaledHeight,
};
graph.nodes[SEAM_MASK_RESIZE_UP] = {
type: 'img_resize',
id: SEAM_MASK_RESIZE_UP,
is_intermediate: true,
width: scaledWidth,
height: scaledHeight,
};
graph.nodes[INPAINT_IMAGE_RESIZE_DOWN] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_DOWN,
@ -506,13 +413,6 @@ export const buildCanvasSDXLOutpaintGraph = (
width: width,
height: height,
};
graph.nodes[SEAM_MASK_RESIZE_DOWN] = {
type: 'img_resize',
id: SEAM_MASK_RESIZE_DOWN,
is_intermediate: true,
width: width,
height: height,
};
graph.nodes[NOISE] = {
...(graph.nodes[NOISE] as NoiseInvocation),
@ -554,57 +454,6 @@ export const buildCanvasSDXLOutpaintGraph = (
field: 'image',
},
},
// Seam Paint Mask
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'image',
},
destination: {
node_id: MASK_EDGE,
field: 'image',
},
},
{
source: {
node_id: MASK_EDGE,
field: 'image',
},
destination: {
node_id: SEAM_MASK_RESIZE_UP,
field: 'image',
},
},
{
source: {
node_id: SEAM_MASK_RESIZE_UP,
field: 'image',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'mask',
},
},
{
source: {
node_id: MASK_BLUR,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask1',
},
},
{
source: {
node_id: SEAM_MASK_RESIZE_UP,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask2',
},
},
// Resize Results Down
{
source: {
@ -618,7 +467,7 @@ export const buildCanvasSDXLOutpaintGraph = (
},
{
source: {
node_id: MASK_RESIZE_UP,
node_id: MASK_BLUR,
field: 'image',
},
destination: {
@ -626,16 +475,6 @@ export const buildCanvasSDXLOutpaintGraph = (
field: 'image',
},
},
{
source: {
node_id: SEAM_MASK_COMBINE,
field: 'image',
},
destination: {
node_id: SEAM_MASK_RESIZE_DOWN,
field: 'image',
},
},
{
source: {
node_id: INPAINT_INFILL,
@ -669,7 +508,7 @@ export const buildCanvasSDXLOutpaintGraph = (
},
{
source: {
node_id: SEAM_MASK_RESIZE_DOWN,
node_id: MASK_RESIZE_DOWN,
field: 'image',
},
destination: {
@ -700,7 +539,7 @@ export const buildCanvasSDXLOutpaintGraph = (
},
{
source: {
node_id: SEAM_MASK_RESIZE_DOWN,
node_id: MASK_RESIZE_DOWN,
field: 'image',
},
destination: {
@ -728,6 +567,7 @@ export const buildCanvasSDXLOutpaintGraph = (
};
graph.nodes[MASK_BLUR] = {
...(graph.nodes[MASK_BLUR] as ImageBlurInvocation),
image: canvasMaskImage,
};
graph.edges.push(
@ -742,47 +582,6 @@ export const buildCanvasSDXLOutpaintGraph = (
field: 'image',
},
},
// Seam Paint Mask
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'image',
},
destination: {
node_id: MASK_EDGE,
field: 'image',
},
},
{
source: {
node_id: MASK_EDGE,
field: 'image',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'mask',
},
},
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask1',
},
},
{
source: {
node_id: MASK_EDGE,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask2',
},
},
// Color Correct The Inpainted Result
{
source: {
@ -806,7 +605,7 @@ export const buildCanvasSDXLOutpaintGraph = (
},
{
source: {
node_id: SEAM_MASK_COMBINE,
node_id: MASK_BLUR,
field: 'image',
},
destination: {
@ -837,7 +636,7 @@ export const buildCanvasSDXLOutpaintGraph = (
},
{
source: {
node_id: SEAM_MASK_COMBINE,
node_id: MASK_BLUR,
field: 'image',
},
destination: {
@ -870,7 +669,7 @@ export const buildCanvasSDXLOutpaintGraph = (
// Add Refiner if enabled
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SEAM_FIX_DENOISE_LATENTS);
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
}
// optionally add custom VAE

View File

@ -18,6 +18,8 @@ export const IMAGE_TO_LATENTS = 'image_to_latents';
export const LATENTS_TO_LATENTS = 'latents_to_latents';
export const RESIZE = 'resize_image';
export const CANVAS_OUTPUT = 'canvas_output';
export const INPAINT = 'inpaint';
export const INPAINT_SEAM_FIX = 'inpaint_seam_fix';
export const INPAINT_IMAGE = 'inpaint_image';
export const SCALED_INPAINT_IMAGE = 'scaled_inpaint_image';
export const INPAINT_IMAGE_RESIZE_UP = 'inpaint_image_resize_up';
@ -25,14 +27,10 @@ export const INPAINT_IMAGE_RESIZE_DOWN = 'inpaint_image_resize_down';
export const INPAINT_INFILL = 'inpaint_infill';
export const INPAINT_INFILL_RESIZE_DOWN = 'inpaint_infill_resize_down';
export const INPAINT_FINAL_IMAGE = 'inpaint_final_image';
export const SEAM_FIX_DENOISE_LATENTS = 'seam_fix_denoise_latents';
export const MASK_FROM_ALPHA = 'tomask';
export const MASK_EDGE = 'mask_edge';
export const MASK_BLUR = 'mask_blur';
export const MASK_COMBINE = 'mask_combine';
export const SEAM_MASK_COMBINE = 'seam_mask_combine';
export const SEAM_MASK_RESIZE_UP = 'seam_mask_resize_up';
export const SEAM_MASK_RESIZE_DOWN = 'seam_mask_resize_down';
export const MASK_RESIZE_UP = 'mask_resize_up';
export const MASK_RESIZE_DOWN = 'mask_resize_down';
export const COLOR_CORRECT = 'color_correct';

View File

@ -1,36 +0,0 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setSeamBlur } from 'features/parameters/store/generationSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const ParamSeamBlur = () => {
const dispatch = useAppDispatch();
const seamBlur = useAppSelector(
(state: RootState) => state.generation.seamBlur
);
const { t } = useTranslation();
return (
<IAISlider
label={t('parameters.seamBlur')}
min={0}
max={64}
step={8}
sliderNumberInputProps={{ max: 512 }}
value={seamBlur}
onChange={(v) => {
dispatch(setSeamBlur(v));
}}
withInput
withSliderMarks
withReset
handleReset={() => {
dispatch(setSeamBlur(8));
}}
/>
);
};
export default memo(ParamSeamBlur);

View File

@ -1,27 +0,0 @@
import { Flex } from '@chakra-ui/react';
import IAICollapse from 'common/components/IAICollapse';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import ParamSeamBlur from './ParamSeamBlur';
import ParamSeamSize from './ParamSeamSize';
import ParamSeamSteps from './ParamSeamSteps';
import ParamSeamStrength from './ParamSeamStrength';
import ParamSeamThreshold from './ParamSeamThreshold';
const ParamSeamPaintingCollapse = () => {
const { t } = useTranslation();
return (
<IAICollapse label={t('parameters.seamPaintingHeader')}>
<Flex sx={{ flexDirection: 'column', gap: 2, paddingBottom: 2 }}>
<ParamSeamSize />
<ParamSeamBlur />
<ParamSeamSteps />
<ParamSeamStrength />
<ParamSeamThreshold />
</Flex>
</IAICollapse>
);
};
export default memo(ParamSeamPaintingCollapse);

View File

@ -1,36 +0,0 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setSeamSize } from 'features/parameters/store/generationSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const ParamSeamSize = () => {
const dispatch = useAppDispatch();
const seamSize = useAppSelector(
(state: RootState) => state.generation.seamSize
);
const { t } = useTranslation();
return (
<IAISlider
label={t('parameters.seamSize')}
min={0}
max={128}
step={8}
sliderNumberInputProps={{ max: 512 }}
value={seamSize}
onChange={(v) => {
dispatch(setSeamSize(v));
}}
withInput
withSliderMarks
withReset
handleReset={() => {
dispatch(setSeamSize(16));
}}
/>
);
};
export default memo(ParamSeamSize);

View File

@ -1,36 +0,0 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setSeamSteps } from 'features/parameters/store/generationSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const ParamSeamSteps = () => {
const dispatch = useAppDispatch();
const seamSteps = useAppSelector(
(state: RootState) => state.generation.seamSteps
);
const { t } = useTranslation();
return (
<IAISlider
label={t('parameters.seamSteps')}
min={0}
max={100}
step={1}
sliderNumberInputProps={{ max: 999 }}
value={seamSteps}
onChange={(v) => {
dispatch(setSeamSteps(v));
}}
withInput
withSliderMarks
withReset
handleReset={() => {
dispatch(setSeamSteps(20));
}}
/>
);
};
export default memo(ParamSeamSteps);

View File

@ -1,36 +0,0 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setSeamStrength } from 'features/parameters/store/generationSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const ParamSeamStrength = () => {
const dispatch = useAppDispatch();
const seamStrength = useAppSelector(
(state: RootState) => state.generation.seamStrength
);
const { t } = useTranslation();
return (
<IAISlider
label={t('parameters.seamStrength')}
min={0}
max={1}
step={0.01}
sliderNumberInputProps={{ max: 999 }}
value={seamStrength}
onChange={(v) => {
dispatch(setSeamStrength(v));
}}
withInput
withSliderMarks
withReset
handleReset={() => {
dispatch(setSeamStrength(0.7));
}}
/>
);
};
export default memo(ParamSeamStrength);

View File

@ -1,121 +0,0 @@
import {
FormControl,
FormLabel,
HStack,
RangeSlider,
RangeSliderFilledTrack,
RangeSliderMark,
RangeSliderThumb,
RangeSliderTrack,
Tooltip,
} from '@chakra-ui/react';
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import {
setSeamHighThreshold,
setSeamLowThreshold,
} from 'features/parameters/store/generationSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { BiReset } from 'react-icons/bi';
const ParamSeamThreshold = () => {
const dispatch = useAppDispatch();
const seamLowThreshold = useAppSelector(
(state: RootState) => state.generation.seamLowThreshold
);
const seamHighThreshold = useAppSelector(
(state: RootState) => state.generation.seamHighThreshold
);
const { t } = useTranslation();
const handleSeamThresholdChange = useCallback(
(v: number[]) => {
dispatch(setSeamLowThreshold(v[0] as number));
dispatch(setSeamHighThreshold(v[1] as number));
},
[dispatch]
);
const handleSeamThresholdReset = () => {
dispatch(setSeamLowThreshold(100));
dispatch(setSeamHighThreshold(200));
};
return (
<FormControl>
<FormLabel>{t('parameters.seamThreshold')}</FormLabel>
<HStack w="100%" gap={4} mt={-2}>
<RangeSlider
aria-label={[
t('parameters.seamLowThreshold'),
t('parameters.seamHighThreshold'),
]}
value={[seamLowThreshold, seamHighThreshold]}
min={0}
max={255}
step={1}
minStepsBetweenThumbs={1}
onChange={handleSeamThresholdChange}
>
<RangeSliderTrack>
<RangeSliderFilledTrack />
</RangeSliderTrack>
<Tooltip label={seamLowThreshold} placement="top" hasArrow>
<RangeSliderThumb index={0} />
</Tooltip>
<Tooltip label={seamHighThreshold} placement="top" hasArrow>
<RangeSliderThumb index={1} />
</Tooltip>
<RangeSliderMark
value={0}
sx={{
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
}}
>
0
</RangeSliderMark>
<RangeSliderMark
value={0.392}
sx={{
insetInlineStart: '38.4% !important',
transform: 'translateX(-38.4%)',
}}
>
100
</RangeSliderMark>
<RangeSliderMark
value={0.784}
sx={{
insetInlineStart: '79.8% !important',
transform: 'translateX(-79.8%)',
}}
>
200
</RangeSliderMark>
<RangeSliderMark
value={1}
sx={{
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
}}
>
255
</RangeSliderMark>
</RangeSlider>
<IAIIconButton
size="sm"
aria-label={t('accessibility.reset')}
tooltip={t('accessibility.reset')}
icon={<BiReset />}
onClick={handleSeamThresholdReset}
/>
</HStack>
</FormControl>
);
};
export default memo(ParamSeamThreshold);

View File

@ -37,12 +37,6 @@ export interface GenerationState {
scheduler: SchedulerParam;
maskBlur: number;
maskBlurMethod: MaskBlurMethodParam;
seamSize: number;
seamBlur: number;
seamSteps: number;
seamStrength: StrengthParam;
seamLowThreshold: number;
seamHighThreshold: number;
seed: SeedParam;
seedWeights: string;
shouldFitToWidthHeight: boolean;
@ -80,12 +74,6 @@ export const initialGenerationState: GenerationState = {
scheduler: 'euler',
maskBlur: 16,
maskBlurMethod: 'box',
seamSize: 16,
seamBlur: 8,
seamSteps: 20,
seamStrength: 0.7,
seamLowThreshold: 100,
seamHighThreshold: 200,
seed: 0,
seedWeights: '',
shouldFitToWidthHeight: true,
@ -212,24 +200,6 @@ export const generationSlice = createSlice({
setMaskBlurMethod: (state, action: PayloadAction<MaskBlurMethodParam>) => {
state.maskBlurMethod = action.payload;
},
setSeamSize: (state, action: PayloadAction<number>) => {
state.seamSize = action.payload;
},
setSeamBlur: (state, action: PayloadAction<number>) => {
state.seamBlur = action.payload;
},
setSeamSteps: (state, action: PayloadAction<number>) => {
state.seamSteps = action.payload;
},
setSeamStrength: (state, action: PayloadAction<number>) => {
state.seamStrength = action.payload;
},
setSeamLowThreshold: (state, action: PayloadAction<number>) => {
state.seamLowThreshold = action.payload;
},
setSeamHighThreshold: (state, action: PayloadAction<number>) => {
state.seamHighThreshold = action.payload;
},
setTileSize: (state, action: PayloadAction<number>) => {
state.tileSize = action.payload;
},
@ -336,12 +306,6 @@ export const {
setScheduler,
setMaskBlur,
setMaskBlurMethod,
setSeamSize,
setSeamBlur,
setSeamSteps,
setSeamStrength,
setSeamLowThreshold,
setSeamHighThreshold,
setSeed,
setSeedWeights,
setShouldFitToWidthHeight,

View File

@ -2,7 +2,6 @@ import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/Para
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
import ParamInfillAndScalingCollapse from 'features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse';
import ParamMaskAdjustmentCollapse from 'features/parameters/components/Parameters/Canvas/MaskAdjustment/ParamMaskAdjustmentCollapse';
import ParamSeamPaintingCollapse from 'features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamPaintingCollapse';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
@ -23,7 +22,6 @@ export default function SDXLUnifiedCanvasTabParameters() {
<ParamNoiseCollapse />
<ParamMaskAdjustmentCollapse />
<ParamInfillAndScalingCollapse />
<ParamSeamPaintingCollapse />
</>
);
}

View File

@ -6,7 +6,6 @@ import ParamControlNetCollapse from 'features/parameters/components/Parameters/C
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
// import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
import ParamMaskAdjustmentCollapse from 'features/parameters/components/Parameters/Canvas/MaskAdjustment/ParamMaskAdjustmentCollapse';
import ParamSeamPaintingCollapse from 'features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamPaintingCollapse';
import ParamPromptArea from 'features/parameters/components/Parameters/Prompt/ParamPromptArea';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters';
@ -24,7 +23,6 @@ const UnifiedCanvasParameters = () => {
<ParamSymmetryCollapse />
<ParamMaskAdjustmentCollapse />
<ParamInfillAndScalingCollapse />
<ParamSeamPaintingCollapse />
<ParamAdvancedCollapse />
</>
);

View File

@ -3,7 +3,4 @@ import { UIState } from './uiTypes';
/**
* UI slice persist denylist
*/
export const uiPersistDenylist: (keyof UIState)[] = [
'shouldShowImageDetails',
'globalContextMenuCloseTrigger',
];
export const uiPersistDenylist: (keyof UIState)[] = ['shouldShowImageDetails'];

View File

@ -20,7 +20,6 @@ export const initialUIState: UIState = {
shouldShowProgressInViewer: true,
shouldShowEmbeddingPicker: false,
favoriteSchedulers: [],
globalContextMenuCloseTrigger: 0,
};
export const uiSlice = createSlice({
@ -97,9 +96,6 @@ export const uiSlice = createSlice({
toggleEmbeddingPicker: (state) => {
state.shouldShowEmbeddingPicker = !state.shouldShowEmbeddingPicker;
},
contextMenusClosed: (state) => {
state.globalContextMenuCloseTrigger += 1;
},
},
extraReducers(builder) {
builder.addCase(initialImageChanged, (state) => {
@ -126,7 +122,6 @@ export const {
setShouldShowProgressInViewer,
favoriteSchedulersChanged,
toggleEmbeddingPicker,
contextMenusClosed,
} = uiSlice.actions;
export default uiSlice.reducer;

View File

@ -26,5 +26,4 @@ export interface UIState {
shouldShowProgressInViewer: boolean;
shouldShowEmbeddingPicker: boolean;
favoriteSchedulers: SchedulerParam[];
globalContextMenuCloseTrigger: number;
}

View File

@ -573,7 +573,7 @@ export type components = {
file: Blob;
};
/**
* Boolean Primitive Collection
* Boolean Collection
* @description A collection of boolean primitive values
*/
BooleanCollectionInvocation: {
@ -619,7 +619,7 @@ export type components = {
collection?: (boolean)[];
};
/**
* Boolean Primitive
* Boolean
* @description A boolean primitive value
*/
BooleanInvocation: {
@ -1002,7 +1002,7 @@ export type components = {
clip?: components["schemas"]["ClipField"];
};
/**
* Conditioning Primitive Collection
* Conditioning Collection
* @description A collection of conditioning tensor primitive values
*/
ConditioningCollectionInvocation: {
@ -1770,7 +1770,7 @@ export type components = {
field: string;
};
/**
* Float Primitive Collection
* Float Collection
* @description A collection of float primitive values
*/
FloatCollectionInvocation: {
@ -1816,7 +1816,7 @@ export type components = {
collection?: (number)[];
};
/**
* Float Primitive
* Float
* @description A float primitive value
*/
FloatInvocation: {
@ -2161,7 +2161,7 @@ export type components = {
channel?: "A" | "R" | "G" | "B";
};
/**
* Image Primitive Collection
* Image Collection
* @description A collection of image primitive values
*/
ImageCollectionInvocation: {
@ -3113,7 +3113,7 @@ export type components = {
seed?: number;
};
/**
* Integer Primitive Collection
* Integer Collection
* @description A collection of integer primitive values
*/
IntegerCollectionInvocation: {
@ -3159,7 +3159,7 @@ export type components = {
collection?: (number)[];
};
/**
* Integer Primitive
* Integer
* @description An integer primitive value
*/
IntegerInvocation: {
@ -3256,7 +3256,7 @@ export type components = {
item?: unknown;
};
/**
* Latents Primitive Collection
* Latents Collection
* @description A collection of latents tensor primitive values
*/
LatentsCollectionInvocation: {
@ -5786,7 +5786,7 @@ export type components = {
show_easing_plot?: boolean;
};
/**
* String Primitive Collection
* String Collection
* @description A collection of string primitive values
*/
StringCollectionInvocation: {
@ -5832,7 +5832,7 @@ export type components = {
collection?: (string)[];
};
/**
* String Primitive
* String
* @description A string primitive value
*/
StringInvocation: {
@ -6193,6 +6193,24 @@ export type components = {
ui_hidden: boolean;
ui_type?: components["schemas"]["UIType"];
};
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* ControlNetModelFormat
* @description An enumeration.
* @enum {string}
*/
ControlNetModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionXLModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionOnnxModelFormat
* @description An enumeration.
@ -6205,24 +6223,6 @@ export type components = {
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* ControlNetModelFormat
* @description An enumeration.
* @enum {string}
*/
ControlNetModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionXLModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
};
responses: never;
parameters: never;

View File

@ -3557,6 +3557,11 @@ eslint-plugin-react-hooks@^4.6.0:
resolved "https://registry.yarnpkg.com/eslint-plugin-react-hooks/-/eslint-plugin-react-hooks-4.6.0.tgz#4c3e697ad95b77e93f8646aaa1630c1ba607edd3"
integrity sha512-oFc7Itz9Qxh2x4gNHStv3BqJq54ExXmfC+a1NjAta66IAN87Wu0R/QArgIS9qKzX3dXKPI9H5crl9QchNMY9+g==
eslint-plugin-react-memo@^0.0.3:
version "0.0.3"
resolved "https://registry.yarnpkg.com/eslint-plugin-react-memo/-/eslint-plugin-react-memo-0.0.3.tgz#26542aa2eeabed37f354c64c6b4eabc07051cf78"
integrity sha512-IZzLDZJF4V84XL9+v74ypDSts/hAQtNeYFZGc3wvdX+YgIw4pkn4GiXPJ6MNUccNTPYJqr89Nnvqo3rcesEBOQ==
eslint-plugin-react@^7.32.2:
version "7.32.2"
resolved "https://registry.yarnpkg.com/eslint-plugin-react/-/eslint-plugin-react-7.32.2.tgz#e71f21c7c265ebce01bcbc9d0955170c55571f10"