mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
1 Commits
bugfix/mak
...
fix/ui/esl
Author | SHA1 | Date | |
---|---|---|---|
a7606901f3 |
@ -25,10 +25,10 @@ This method is recommended for experienced users and developers
|
||||
#### [Docker Installation](040_INSTALL_DOCKER.md)
|
||||
This method is recommended for those familiar with running Docker containers
|
||||
### Other Installation Guides
|
||||
- [PyPatchMatch](060_INSTALL_PATCHMATCH.md)
|
||||
- [XFormers](070_INSTALL_XFORMERS.md)
|
||||
- [CUDA and ROCm Drivers](030_INSTALL_CUDA_AND_ROCM.md)
|
||||
- [Installing New Models](050_INSTALLING_MODELS.md)
|
||||
- [PyPatchMatch](installation/060_INSTALL_PATCHMATCH.md)
|
||||
- [XFormers](installation/070_INSTALL_XFORMERS.md)
|
||||
- [CUDA and ROCm Drivers](installation/030_INSTALL_CUDA_AND_ROCM.md)
|
||||
- [Installing New Models](installation/050_INSTALLING_MODELS.md)
|
||||
|
||||
## :fontawesome-solid-computer: Hardware Requirements
|
||||
|
||||
|
@ -40,7 +40,7 @@ async def create_session(
|
||||
@session_router.get(
|
||||
"/",
|
||||
operation_id="list_sessions",
|
||||
responses={200: {"model": PaginatedResults[dict]}},
|
||||
responses={200: {"model": PaginatedResults[GraphExecutionState]}},
|
||||
)
|
||||
async def list_sessions(
|
||||
page: int = Query(default=0, description="The page of results to get"),
|
||||
|
@ -48,7 +48,7 @@ class BooleanCollectionOutput(BaseInvocationOutput):
|
||||
)
|
||||
|
||||
|
||||
@title("Boolean Primitive")
|
||||
@title("Boolean")
|
||||
@tags("primitives", "boolean")
|
||||
class BooleanInvocation(BaseInvocation):
|
||||
"""A boolean primitive value"""
|
||||
@ -62,7 +62,7 @@ class BooleanInvocation(BaseInvocation):
|
||||
return BooleanOutput(a=self.a)
|
||||
|
||||
|
||||
@title("Boolean Primitive Collection")
|
||||
@title("Boolean Collection")
|
||||
@tags("primitives", "boolean", "collection")
|
||||
class BooleanCollectionInvocation(BaseInvocation):
|
||||
"""A collection of boolean primitive values"""
|
||||
@ -101,7 +101,7 @@ class IntegerCollectionOutput(BaseInvocationOutput):
|
||||
)
|
||||
|
||||
|
||||
@title("Integer Primitive")
|
||||
@title("Integer")
|
||||
@tags("primitives", "integer")
|
||||
class IntegerInvocation(BaseInvocation):
|
||||
"""An integer primitive value"""
|
||||
@ -115,7 +115,7 @@ class IntegerInvocation(BaseInvocation):
|
||||
return IntegerOutput(a=self.a)
|
||||
|
||||
|
||||
@title("Integer Primitive Collection")
|
||||
@title("Integer Collection")
|
||||
@tags("primitives", "integer", "collection")
|
||||
class IntegerCollectionInvocation(BaseInvocation):
|
||||
"""A collection of integer primitive values"""
|
||||
@ -154,7 +154,7 @@ class FloatCollectionOutput(BaseInvocationOutput):
|
||||
)
|
||||
|
||||
|
||||
@title("Float Primitive")
|
||||
@title("Float")
|
||||
@tags("primitives", "float")
|
||||
class FloatInvocation(BaseInvocation):
|
||||
"""A float primitive value"""
|
||||
@ -168,7 +168,7 @@ class FloatInvocation(BaseInvocation):
|
||||
return FloatOutput(a=self.param)
|
||||
|
||||
|
||||
@title("Float Primitive Collection")
|
||||
@title("Float Collection")
|
||||
@tags("primitives", "float", "collection")
|
||||
class FloatCollectionInvocation(BaseInvocation):
|
||||
"""A collection of float primitive values"""
|
||||
@ -207,7 +207,7 @@ class StringCollectionOutput(BaseInvocationOutput):
|
||||
)
|
||||
|
||||
|
||||
@title("String Primitive")
|
||||
@title("String")
|
||||
@tags("primitives", "string")
|
||||
class StringInvocation(BaseInvocation):
|
||||
"""A string primitive value"""
|
||||
@ -221,7 +221,7 @@ class StringInvocation(BaseInvocation):
|
||||
return StringOutput(text=self.text)
|
||||
|
||||
|
||||
@title("String Primitive Collection")
|
||||
@title("String Collection")
|
||||
@tags("primitives", "string", "collection")
|
||||
class StringCollectionInvocation(BaseInvocation):
|
||||
"""A collection of string primitive values"""
|
||||
@ -289,7 +289,7 @@ class ImageInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@title("Image Primitive Collection")
|
||||
@title("Image Collection")
|
||||
@tags("primitives", "image", "collection")
|
||||
class ImageCollectionInvocation(BaseInvocation):
|
||||
"""A collection of image primitive values"""
|
||||
@ -357,7 +357,7 @@ class LatentsInvocation(BaseInvocation):
|
||||
return build_latents_output(self.latents.latents_name, latents)
|
||||
|
||||
|
||||
@title("Latents Primitive Collection")
|
||||
@title("Latents Collection")
|
||||
@tags("primitives", "latents", "collection")
|
||||
class LatentsCollectionInvocation(BaseInvocation):
|
||||
"""A collection of latents tensor primitive values"""
|
||||
@ -475,7 +475,7 @@ class ConditioningInvocation(BaseInvocation):
|
||||
return ConditioningOutput(conditioning=self.conditioning)
|
||||
|
||||
|
||||
@title("Conditioning Primitive Collection")
|
||||
@title("Conditioning Collection")
|
||||
@tags("primitives", "conditioning", "collection")
|
||||
class ConditioningCollectionInvocation(BaseInvocation):
|
||||
"""A collection of conditioning tensor primitive values"""
|
||||
|
@ -29,7 +29,6 @@ The abstract base class for this class is InvocationStatsServiceBase. An impleme
|
||||
writes to the system log is stored in InvocationServices.performance_statistics.
|
||||
"""
|
||||
|
||||
import psutil
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import AbstractContextManager
|
||||
@ -43,11 +42,6 @@ import invokeai.backend.util.logging as logger
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from .graph import GraphExecutionState
|
||||
from .item_storage import ItemStorageABC
|
||||
from .model_manager_service import ModelManagerService
|
||||
from invokeai.backend.model_management.model_cache import CacheStats
|
||||
|
||||
# size of GIG in bytes
|
||||
GIG = 1073741824
|
||||
|
||||
|
||||
class InvocationStatsServiceBase(ABC):
|
||||
@ -95,8 +89,6 @@ class InvocationStatsServiceBase(ABC):
|
||||
invocation_type: str,
|
||||
time_used: float,
|
||||
vram_used: float,
|
||||
ram_used: float,
|
||||
ram_changed: float,
|
||||
):
|
||||
"""
|
||||
Add timing information on execution of a node. Usually
|
||||
@ -105,8 +97,6 @@ class InvocationStatsServiceBase(ABC):
|
||||
:param invocation_type: String literal type of the node
|
||||
:param time_used: Time used by node's exection (sec)
|
||||
:param vram_used: Maximum VRAM used during exection (GB)
|
||||
:param ram_used: Current RAM available (GB)
|
||||
:param ram_changed: Change in RAM usage over course of the run (GB)
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -125,9 +115,6 @@ class NodeStats:
|
||||
calls: int = 0
|
||||
time_used: float = 0.0 # seconds
|
||||
max_vram: float = 0.0 # GB
|
||||
cache_hits: int = 0
|
||||
cache_misses: int = 0
|
||||
cache_high_watermark: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -146,62 +133,31 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
# {graph_id => NodeLog}
|
||||
self._stats: Dict[str, NodeLog] = {}
|
||||
self._cache_stats: Dict[str, CacheStats] = {}
|
||||
self.ram_used: float = 0.0
|
||||
self.ram_changed: float = 0.0
|
||||
|
||||
class StatsContext:
|
||||
"""Context manager for collecting statistics."""
|
||||
|
||||
invocation: BaseInvocation = None
|
||||
collector: "InvocationStatsServiceBase" = None
|
||||
graph_id: str = None
|
||||
start_time: int = 0
|
||||
ram_used: int = 0
|
||||
model_manager: ModelManagerService = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
invocation: BaseInvocation,
|
||||
graph_id: str,
|
||||
model_manager: ModelManagerService,
|
||||
collector: "InvocationStatsServiceBase",
|
||||
):
|
||||
"""Initialize statistics for this run."""
|
||||
def __init__(self, invocation: BaseInvocation, graph_id: str, collector: "InvocationStatsServiceBase"):
|
||||
self.invocation = invocation
|
||||
self.collector = collector
|
||||
self.graph_id = graph_id
|
||||
self.start_time = 0
|
||||
self.ram_used = 0
|
||||
self.model_manager = model_manager
|
||||
|
||||
def __enter__(self):
|
||||
self.start_time = time.time()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
self.ram_used = psutil.Process().memory_info().rss
|
||||
if self.model_manager:
|
||||
self.model_manager.collect_cache_stats(self.collector._cache_stats[self.graph_id])
|
||||
|
||||
def __exit__(self, *args):
|
||||
"""Called on exit from the context."""
|
||||
ram_used = psutil.Process().memory_info().rss
|
||||
self.collector.update_mem_stats(
|
||||
ram_used=ram_used / GIG,
|
||||
ram_changed=(ram_used - self.ram_used) / GIG,
|
||||
)
|
||||
self.collector.update_invocation_stats(
|
||||
graph_id=self.graph_id,
|
||||
invocation_type=self.invocation.type,
|
||||
time_used=time.time() - self.start_time,
|
||||
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
|
||||
self.graph_id,
|
||||
self.invocation.type,
|
||||
time.time() - self.start_time,
|
||||
torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0,
|
||||
)
|
||||
|
||||
def collect_stats(
|
||||
self,
|
||||
invocation: BaseInvocation,
|
||||
graph_execution_state_id: str,
|
||||
model_manager: ModelManagerService,
|
||||
) -> StatsContext:
|
||||
"""
|
||||
Return a context object that will capture the statistics.
|
||||
@ -210,8 +166,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
"""
|
||||
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
|
||||
self._stats[graph_execution_state_id] = NodeLog()
|
||||
self._cache_stats[graph_execution_state_id] = CacheStats()
|
||||
return self.StatsContext(invocation, graph_execution_state_id, model_manager, self)
|
||||
return self.StatsContext(invocation, graph_execution_state_id, self)
|
||||
|
||||
def reset_all_stats(self):
|
||||
"""Zero all statistics"""
|
||||
@ -224,36 +179,13 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
except KeyError:
|
||||
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
|
||||
|
||||
def update_mem_stats(
|
||||
self,
|
||||
ram_used: float,
|
||||
ram_changed: float,
|
||||
):
|
||||
"""
|
||||
Update the collector with RAM memory usage info.
|
||||
|
||||
:param ram_used: How much RAM is currently in use.
|
||||
:param ram_changed: How much RAM changed since last generation.
|
||||
"""
|
||||
self.ram_used = ram_used
|
||||
self.ram_changed = ram_changed
|
||||
|
||||
def update_invocation_stats(
|
||||
self,
|
||||
graph_id: str,
|
||||
invocation_type: str,
|
||||
time_used: float,
|
||||
vram_used: float,
|
||||
):
|
||||
def update_invocation_stats(self, graph_id: str, invocation_type: str, time_used: float, vram_used: float):
|
||||
"""
|
||||
Add timing information on execution of a node. Usually
|
||||
used internally.
|
||||
:param graph_id: ID of the graph that is currently executing
|
||||
:param invocation_type: String literal type of the node
|
||||
:param time_used: Time used by node's exection (sec)
|
||||
:param vram_used: Maximum VRAM used during exection (GB)
|
||||
:param ram_used: Current RAM available (GB)
|
||||
:param ram_changed: Change in RAM usage over course of the run (GB)
|
||||
:param time_used: Floating point seconds used by node's exection
|
||||
"""
|
||||
if not self._stats[graph_id].nodes.get(invocation_type):
|
||||
self._stats[graph_id].nodes[invocation_type] = NodeStats()
|
||||
@ -265,7 +197,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
def log_stats(self):
|
||||
"""
|
||||
Send the statistics to the system logger at the info level.
|
||||
Stats will only be printed when the execution of the graph
|
||||
Stats will only be printed if when the execution of the graph
|
||||
is complete.
|
||||
"""
|
||||
completed = set()
|
||||
@ -276,30 +208,16 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
|
||||
total_time = 0
|
||||
logger.info(f"Graph stats: {graph_id}")
|
||||
logger.info(f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}")
|
||||
logger.info("Node Calls Seconds VRAM Used")
|
||||
for node_type, stats in self._stats[graph_id].nodes.items():
|
||||
logger.info(f"{node_type:>30} {stats.calls:>4} {stats.time_used:7.3f}s {stats.max_vram:4.3f}G")
|
||||
logger.info(f"{node_type:<20} {stats.calls:>5} {stats.time_used:7.3f}s {stats.max_vram:4.2f}G")
|
||||
total_time += stats.time_used
|
||||
|
||||
cache_stats = self._cache_stats[graph_id]
|
||||
hwm = cache_stats.high_watermark / GIG
|
||||
tot = cache_stats.cache_size / GIG
|
||||
loaded = sum([v for v in cache_stats.loaded_model_sizes.values()]) / GIG
|
||||
|
||||
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
|
||||
logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)")
|
||||
logger.info(f"RAM used to load models: {loaded:4.2f}G")
|
||||
if torch.cuda.is_available():
|
||||
logger.info("VRAM in use: " + "%4.3fG" % (torch.cuda.memory_allocated() / GIG))
|
||||
logger.info("RAM cache statistics:")
|
||||
logger.info(f" Model cache hits: {cache_stats.hits}")
|
||||
logger.info(f" Model cache misses: {cache_stats.misses}")
|
||||
logger.info(f" Models cached: {cache_stats.in_cache}")
|
||||
logger.info(f" Models cleared from cache: {cache_stats.cleared}")
|
||||
logger.info(f" Cache high water mark: {hwm:4.2f}/{tot:4.2f}G")
|
||||
logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9))
|
||||
|
||||
completed.add(graph_id)
|
||||
|
||||
for graph_id in completed:
|
||||
del self._stats[graph_id]
|
||||
del self._cache_stats[graph_id]
|
||||
|
@ -22,7 +22,6 @@ from invokeai.backend.model_management import (
|
||||
ModelNotFoundException,
|
||||
)
|
||||
from invokeai.backend.model_management.model_search import FindModels
|
||||
from invokeai.backend.model_management.model_cache import CacheStats
|
||||
|
||||
import torch
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
@ -277,13 +276,6 @@ class ModelManagerServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""
|
||||
Reset model cache statistics for graph with graph_id.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||
"""
|
||||
@ -508,12 +500,6 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
self.logger.debug(f"convert model {model_name}")
|
||||
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
|
||||
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""
|
||||
Reset model cache statistics for graph with graph_id.
|
||||
"""
|
||||
self.mgr.cache.stats = cache_stats
|
||||
|
||||
def commit(self, conf_file: Optional[Path] = None):
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
|
@ -86,9 +86,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
|
||||
# Invoke
|
||||
try:
|
||||
graph_id = graph_execution_state.id
|
||||
model_manager = self.__invoker.services.model_manager
|
||||
with statistics.collect_stats(invocation, graph_id, model_manager):
|
||||
with statistics.collect_stats(invocation, graph_execution_state.id):
|
||||
# use the internal invoke_internal(), which wraps the node's invoke() method in
|
||||
# this accomodates nodes which require a value, but get it only from a
|
||||
# connection
|
||||
|
@ -1,7 +1,6 @@
|
||||
import sqlite3
|
||||
import json
|
||||
from threading import Lock
|
||||
from typing import Generic, Optional, TypeVar, Union, get_args
|
||||
from typing import Generic, Optional, TypeVar, get_args
|
||||
|
||||
from pydantic import BaseModel, parse_raw_as
|
||||
|
||||
@ -100,7 +99,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
self._lock.release()
|
||||
self._on_deleted(id)
|
||||
|
||||
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[dict]:
|
||||
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
@ -109,7 +108,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
)
|
||||
result = self._cursor.fetchall()
|
||||
|
||||
items = [json.loads(r[0]) for r in result]
|
||||
items = list(map(lambda r: self._parse_item(r[0]), result))
|
||||
|
||||
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
|
||||
count = self._cursor.fetchone()[0]
|
||||
@ -118,9 +117,9 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
return PaginatedResults[dict](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
|
||||
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
|
||||
|
||||
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[dict]:
|
||||
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
@ -129,7 +128,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
)
|
||||
result = self._cursor.fetchall()
|
||||
|
||||
items = [json.loads(r[0]) for r in result]
|
||||
items = list(map(lambda r: self._parse_item(r[0]), result))
|
||||
|
||||
self._cursor.execute(
|
||||
f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
|
||||
@ -141,4 +140,4 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
return PaginatedResults[dict](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
|
||||
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
|
||||
|
@ -21,12 +21,12 @@ import os
|
||||
import sys
|
||||
import hashlib
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union, types, Optional, Type, Any
|
||||
|
||||
import torch
|
||||
|
||||
import logging
|
||||
import invokeai.backend.util.logging as logger
|
||||
from .models import BaseModelType, ModelType, SubModelType, ModelBase
|
||||
|
||||
@ -41,18 +41,6 @@ DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
|
||||
GIG = 1073741824
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats(object):
|
||||
hits: int = 0 # cache hits
|
||||
misses: int = 0 # cache misses
|
||||
high_watermark: int = 0 # amount of cache used
|
||||
in_cache: int = 0 # number of models in cache
|
||||
cleared: int = 0 # number of models cleared to make space
|
||||
cache_size: int = 0 # total size of cache
|
||||
# {submodel_key => size}
|
||||
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ModelLocker(object):
|
||||
"Forward declaration"
|
||||
pass
|
||||
@ -127,9 +115,6 @@ class ModelCache(object):
|
||||
self.sha_chunksize = sha_chunksize
|
||||
self.logger = logger
|
||||
|
||||
# used for stats collection
|
||||
self.stats = None
|
||||
|
||||
self._cached_models = dict()
|
||||
self._cache_stack = list()
|
||||
|
||||
@ -196,14 +181,13 @@ class ModelCache(object):
|
||||
model_type=model_type,
|
||||
submodel_type=submodel,
|
||||
)
|
||||
|
||||
# TODO: lock for no copies on simultaneous calls?
|
||||
cache_entry = self._cached_models.get(key, None)
|
||||
if cache_entry is None:
|
||||
self.logger.info(
|
||||
f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
|
||||
)
|
||||
if self.stats:
|
||||
self.stats.misses += 1
|
||||
|
||||
# this will remove older cached models until
|
||||
# there is sufficient room to load the requested model
|
||||
@ -217,17 +201,6 @@ class ModelCache(object):
|
||||
|
||||
cache_entry = _CacheRecord(self, model, mem_used)
|
||||
self._cached_models[key] = cache_entry
|
||||
else:
|
||||
if self.stats:
|
||||
self.stats.hits += 1
|
||||
|
||||
if self.stats:
|
||||
self.stats.cache_size = self.max_cache_size * GIG
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[key] = max(
|
||||
self.stats.loaded_model_sizes.get(key, 0), model_info.get_size(submodel)
|
||||
)
|
||||
|
||||
with suppress(Exception):
|
||||
self._cache_stack.remove(key)
|
||||
@ -307,14 +280,14 @@ class ModelCache(object):
|
||||
"""
|
||||
Given the HF repo id or path to a model on disk, returns a unique
|
||||
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
|
||||
|
||||
:param model_path: Path to model file/directory on disk.
|
||||
"""
|
||||
return self._local_model_hash(model_path)
|
||||
|
||||
def cache_size(self) -> float:
|
||||
"""Return the current size of the cache, in GB."""
|
||||
return self._cache_size() / GIG
|
||||
"Return the current size of the cache, in GB"
|
||||
current_cache_size = sum([m.size for m in self._cached_models.values()])
|
||||
return current_cache_size / GIG
|
||||
|
||||
def _has_cuda(self) -> bool:
|
||||
return self.execution_device.type == "cuda"
|
||||
@ -337,15 +310,12 @@ class ModelCache(object):
|
||||
f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}"
|
||||
)
|
||||
|
||||
def _cache_size(self) -> int:
|
||||
return sum([m.size for m in self._cached_models.values()])
|
||||
|
||||
def _make_cache_room(self, model_size):
|
||||
# calculate how much memory this model will require
|
||||
# multiplier = 2 if self.precision==torch.float32 else 1
|
||||
bytes_needed = model_size
|
||||
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
|
||||
current_size = self._cache_size()
|
||||
current_size = sum([m.size for m in self._cached_models.values()])
|
||||
|
||||
if current_size + bytes_needed > maximum_size:
|
||||
self.logger.debug(
|
||||
@ -394,8 +364,6 @@ class ModelCache(object):
|
||||
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||
)
|
||||
current_size -= cache_entry.size
|
||||
if self.stats:
|
||||
self.stats.cleared += 1
|
||||
del self._cache_stack[pos]
|
||||
del self._cached_models[model_key]
|
||||
del cache_entry
|
||||
|
@ -240,7 +240,6 @@ class InvokeAIDiffuserComponent:
|
||||
controlnet_cond=control_datum.image_tensor,
|
||||
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
|
||||
return_dict=False,
|
||||
)
|
||||
|
@ -4,15 +4,8 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.loaders import FromOriginalControlnetMixin
|
||||
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
|
||||
from diffusers.models.embeddings import (
|
||||
TextImageProjection,
|
||||
TextImageTimeEmbedding,
|
||||
TextTimeEmbedding,
|
||||
TimestepEmbedding,
|
||||
Timesteps,
|
||||
)
|
||||
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.models.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
@ -25,11 +18,10 @@ from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||
import diffusers
|
||||
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
|
||||
|
||||
# TODO: create PR to diffusers
|
||||
# Modified ControlNetModel with encoder_attention_mask argument added
|
||||
|
||||
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A ControlNet model.
|
||||
|
||||
@ -60,25 +52,12 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
||||
dimension to `cross_attention_dim`.
|
||||
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
||||
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
||||
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
||||
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
use_linear_projection (`bool`, defaults to `False`):
|
||||
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
||||
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
||||
addition_embed_type (`str`, *optional*, defaults to `None`):
|
||||
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
||||
"text". "text" will use the `TextTimeEmbedding` layer.
|
||||
num_class_embeds (`int`, *optional*, defaults to 0):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
@ -119,15 +98,10 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
addition_time_embed_dim: Optional[int] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
@ -135,7 +109,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads=64,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -163,9 +136,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
||||
|
||||
# input
|
||||
conv_in_kernel = 3
|
||||
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||
@ -175,43 +145,16 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
|
||||
# time
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
|
||||
self.time_embedding = TimestepEmbedding(
|
||||
timestep_input_dim,
|
||||
time_embed_dim,
|
||||
act_fn=act_fn,
|
||||
)
|
||||
|
||||
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
||||
encoder_hid_dim_type = "text_proj"
|
||||
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
||||
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
||||
|
||||
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
||||
raise ValueError(
|
||||
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
||||
)
|
||||
|
||||
if encoder_hid_dim_type == "text_proj":
|
||||
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
||||
elif encoder_hid_dim_type == "text_image_proj":
|
||||
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
||||
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
||||
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
||||
self.encoder_hid_proj = TextImageProjection(
|
||||
text_embed_dim=encoder_hid_dim,
|
||||
image_embed_dim=cross_attention_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
|
||||
elif encoder_hid_dim_type is not None:
|
||||
raise ValueError(
|
||||
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
||||
)
|
||||
else:
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
# class embedding
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
@ -235,29 +178,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
if addition_embed_type == "text":
|
||||
if encoder_hid_dim is not None:
|
||||
text_time_embedding_from_dim = encoder_hid_dim
|
||||
else:
|
||||
text_time_embedding_from_dim = cross_attention_dim
|
||||
|
||||
self.add_embedding = TextTimeEmbedding(
|
||||
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
||||
)
|
||||
elif addition_embed_type == "text_image":
|
||||
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
||||
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
||||
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
||||
self.add_embedding = TextImageTimeEmbedding(
|
||||
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
||||
)
|
||||
elif addition_embed_type == "text_time":
|
||||
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
||||
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||
|
||||
elif addition_embed_type is not None:
|
||||
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
||||
|
||||
# control net conditioning embedding
|
||||
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
|
||||
conditioning_embedding_channels=block_out_channels[0],
|
||||
@ -292,7 +212,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
transformer_layers_per_block=transformer_layers_per_block[i],
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
@ -329,7 +248,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
self.controlnet_mid_block = controlnet_block
|
||||
|
||||
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||
transformer_layers_per_block=transformer_layers_per_block[-1],
|
||||
in_channels=mid_block_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
@ -359,22 +277,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
|
||||
where applicable.
|
||||
"""
|
||||
transformer_layers_per_block = (
|
||||
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
|
||||
)
|
||||
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
|
||||
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
|
||||
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
|
||||
addition_time_embed_dim = (
|
||||
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
|
||||
)
|
||||
|
||||
controlnet = cls(
|
||||
encoder_hid_dim=encoder_hid_dim,
|
||||
encoder_hid_dim_type=encoder_hid_dim_type,
|
||||
addition_embed_type=addition_embed_type,
|
||||
addition_time_embed_dim=addition_time_embed_dim,
|
||||
transformer_layers_per_block=transformer_layers_per_block,
|
||||
in_channels=unet.config.in_channels,
|
||||
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
||||
freq_shift=unet.config.freq_shift,
|
||||
@ -560,7 +463,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
guess_mode: bool = False,
|
||||
@ -584,9 +486,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
added_cond_kwargs (`dict`):
|
||||
Additional conditions for the Stable Diffusion XL UNet.
|
||||
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
||||
cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
|
||||
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
||||
encoder_attention_mask (`torch.Tensor`):
|
||||
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
||||
@ -649,7 +549,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
t_emb = t_emb.to(dtype=sample.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
aug_emb = None
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
@ -661,34 +560,11 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
|
||||
if "addition_embed_type" in self.config:
|
||||
if self.config.addition_embed_type == "text":
|
||||
aug_emb = self.add_embedding(encoder_hidden_states)
|
||||
|
||||
elif self.config.addition_embed_type == "text_time":
|
||||
if "text_embeds" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
text_embeds = added_cond_kwargs.get("text_embeds")
|
||||
if "time_ids" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
time_ids = added_cond_kwargs.get("time_ids")
|
||||
time_embeds = self.add_time_proj(time_ids.flatten())
|
||||
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
||||
|
||||
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
||||
add_embeds = add_embeds.to(emb.dtype)
|
||||
aug_emb = self.add_embedding(add_embeds)
|
||||
|
||||
emb = emb + aug_emb if aug_emb is not None else emb
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
||||
|
||||
sample = sample + controlnet_cond
|
||||
|
||||
# 3. down
|
||||
|
@ -20,9 +20,16 @@ module.exports = {
|
||||
ecmaVersion: 2018,
|
||||
sourceType: 'module',
|
||||
},
|
||||
plugins: ['react', '@typescript-eslint', 'eslint-plugin-react-hooks'],
|
||||
plugins: ['react', '@typescript-eslint', 'react-hooks', 'react-memo'],
|
||||
root: true,
|
||||
rules: {
|
||||
curly: 'error',
|
||||
'react-memo/require-memo': 'error',
|
||||
'react-memo/require-usememo': 'error',
|
||||
'react/jsx-curly-brace-presence': [
|
||||
'error',
|
||||
{ props: 'never', children: 'never' },
|
||||
],
|
||||
'react-hooks/exhaustive-deps': 'error',
|
||||
'no-var': 'error',
|
||||
'brace-style': 'error',
|
||||
|
@ -76,6 +76,7 @@
|
||||
"chakra-ui-contextmenu": "^1.0.5",
|
||||
"dateformat": "^5.0.3",
|
||||
"downshift": "^7.6.0",
|
||||
"eslint-plugin-react-memo": "^0.0.3",
|
||||
"formik": "^2.4.2",
|
||||
"framer-motion": "^10.12.17",
|
||||
"fuse.js": "^6.6.2",
|
||||
|
@ -506,14 +506,10 @@
|
||||
"maskAdjustmentsHeader": "Mask Adjustments",
|
||||
"maskBlur": "Mask Blur",
|
||||
"maskBlurMethod": "Mask Blur Method",
|
||||
"seamPaintingHeader": "Seam Painting",
|
||||
"seamSize": "Seam Size",
|
||||
"seamBlur": "Seam Blur",
|
||||
"seamSteps": "Seam Steps",
|
||||
"seamStrength": "Seam Strength",
|
||||
"seamThreshold": "Seam Threshold",
|
||||
"seamLowThreshold": "Low",
|
||||
"seamHighThreshold": "High",
|
||||
"seamSteps": "Seam Steps",
|
||||
"scaleBeforeProcessing": "Scale Before Processing",
|
||||
"scaledWidth": "Scaled W",
|
||||
"scaledHeight": "Scaled H",
|
||||
|
@ -121,7 +121,7 @@ export const addRequestedMultipleImageDeletionListener = () => {
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const { imageDTOs, imagesUsage } = action.payload;
|
||||
|
||||
if (imageDTOs.length <= 1 || imagesUsage.length <= 1) {
|
||||
if (imageDTOs.length < 1 || imagesUsage.length < 1) {
|
||||
// handle singles in separate listener
|
||||
return;
|
||||
}
|
||||
|
@ -1,126 +0,0 @@
|
||||
/**
|
||||
* This is a copy-paste of https://github.com/lukasbach/chakra-ui-contextmenu with a small change.
|
||||
*
|
||||
* The reactflow background element somehow prevents the chakra `useOutsideClick()` hook from working.
|
||||
* With a menu open, clicking on the reactflow background element doesn't close the menu.
|
||||
*
|
||||
* Reactflow does provide an `onPaneClick` to handle clicks on the background element, but it is not
|
||||
* straightforward to programatically close the menu.
|
||||
*
|
||||
* As a (hopefully temporary) workaround, we will use a dirty hack:
|
||||
* - create `globalContextMenuCloseTrigger: number` in `ui` slice
|
||||
* - increment it in `onPaneClick`
|
||||
* - `useEffect()` to close the menu when `globalContextMenuCloseTrigger` changes
|
||||
*/
|
||||
|
||||
import {
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuButtonProps,
|
||||
MenuProps,
|
||||
Portal,
|
||||
PortalProps,
|
||||
useEventListener,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import * as React from 'react';
|
||||
import {
|
||||
MutableRefObject,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useRef,
|
||||
useState,
|
||||
} from 'react';
|
||||
|
||||
export interface IAIContextMenuProps<T extends HTMLElement> {
|
||||
renderMenu: () => JSX.Element | null;
|
||||
children: (ref: MutableRefObject<T | null>) => JSX.Element | null;
|
||||
menuProps?: Omit<MenuProps, 'children'> & { children?: React.ReactNode };
|
||||
portalProps?: Omit<PortalProps, 'children'> & { children?: React.ReactNode };
|
||||
menuButtonProps?: MenuButtonProps;
|
||||
}
|
||||
|
||||
export function IAIContextMenu<T extends HTMLElement = HTMLElement>(
|
||||
props: IAIContextMenuProps<T>
|
||||
) {
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
const [isRendered, setIsRendered] = useState(false);
|
||||
const [isDeferredOpen, setIsDeferredOpen] = useState(false);
|
||||
const [position, setPosition] = useState<[number, number]>([0, 0]);
|
||||
const targetRef = useRef<T>(null);
|
||||
|
||||
const globalContextMenuCloseTrigger = useAppSelector(
|
||||
(state) => state.ui.globalContextMenuCloseTrigger
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (isOpen) {
|
||||
setTimeout(() => {
|
||||
setIsRendered(true);
|
||||
setTimeout(() => {
|
||||
setIsDeferredOpen(true);
|
||||
});
|
||||
});
|
||||
} else {
|
||||
setIsDeferredOpen(false);
|
||||
const timeout = setTimeout(() => {
|
||||
setIsRendered(isOpen);
|
||||
}, 1000);
|
||||
return () => clearTimeout(timeout);
|
||||
}
|
||||
}, [isOpen]);
|
||||
|
||||
useEffect(() => {
|
||||
setIsOpen(false);
|
||||
setIsDeferredOpen(false);
|
||||
setIsRendered(false);
|
||||
}, [globalContextMenuCloseTrigger]);
|
||||
|
||||
useEventListener('contextmenu', (e) => {
|
||||
if (
|
||||
targetRef.current?.contains(e.target as HTMLElement) ||
|
||||
e.target === targetRef.current
|
||||
) {
|
||||
e.preventDefault();
|
||||
setIsOpen(true);
|
||||
setPosition([e.pageX, e.pageY]);
|
||||
} else {
|
||||
setIsOpen(false);
|
||||
}
|
||||
});
|
||||
|
||||
const onCloseHandler = useCallback(() => {
|
||||
props.menuProps?.onClose?.();
|
||||
setIsOpen(false);
|
||||
}, [props.menuProps]);
|
||||
|
||||
return (
|
||||
<>
|
||||
{props.children(targetRef)}
|
||||
{isRendered && (
|
||||
<Portal {...props.portalProps}>
|
||||
<Menu
|
||||
isOpen={isDeferredOpen}
|
||||
gutter={0}
|
||||
{...props.menuProps}
|
||||
onClose={onCloseHandler}
|
||||
>
|
||||
<MenuButton
|
||||
aria-hidden={true}
|
||||
w={1}
|
||||
h={1}
|
||||
style={{
|
||||
position: 'absolute',
|
||||
left: position[0],
|
||||
top: position[1],
|
||||
cursor: 'default',
|
||||
}}
|
||||
{...props.menuButtonProps}
|
||||
/>
|
||||
{props.renderMenu()}
|
||||
</Menu>
|
||||
</Portal>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
@ -16,7 +16,6 @@ import ImageContextMenu from 'features/gallery/components/ImageContextMenu/Image
|
||||
import {
|
||||
MouseEvent,
|
||||
ReactElement,
|
||||
ReactNode,
|
||||
SyntheticEvent,
|
||||
memo,
|
||||
useCallback,
|
||||
@ -33,17 +32,6 @@ import {
|
||||
TypesafeDroppableData,
|
||||
} from 'features/dnd/types';
|
||||
|
||||
const defaultUploadElement = (
|
||||
<Icon
|
||||
as={FaUpload}
|
||||
sx={{
|
||||
boxSize: 16,
|
||||
}}
|
||||
/>
|
||||
);
|
||||
|
||||
const defaultNoContentFallback = <IAINoContentFallback icon={FaImage} />;
|
||||
|
||||
type IAIDndImageProps = FlexProps & {
|
||||
imageDTO: ImageDTO | undefined;
|
||||
onError?: (event: SyntheticEvent<HTMLImageElement>) => void;
|
||||
@ -59,14 +47,13 @@ type IAIDndImageProps = FlexProps & {
|
||||
fitContainer?: boolean;
|
||||
droppableData?: TypesafeDroppableData;
|
||||
draggableData?: TypesafeDraggableData;
|
||||
dropLabel?: ReactNode;
|
||||
dropLabel?: string;
|
||||
isSelected?: boolean;
|
||||
thumbnail?: boolean;
|
||||
noContentFallback?: ReactElement;
|
||||
useThumbailFallback?: boolean;
|
||||
withHoverOverlay?: boolean;
|
||||
children?: JSX.Element;
|
||||
uploadElement?: ReactNode;
|
||||
};
|
||||
|
||||
const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
@ -87,8 +74,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
dropLabel,
|
||||
isSelected = false,
|
||||
thumbnail = false,
|
||||
noContentFallback = defaultNoContentFallback,
|
||||
uploadElement = defaultUploadElement,
|
||||
noContentFallback = <IAINoContentFallback icon={FaImage} />,
|
||||
useThumbailFallback,
|
||||
withHoverOverlay = false,
|
||||
children,
|
||||
@ -207,7 +193,12 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
{...getUploadButtonProps()}
|
||||
>
|
||||
<input {...getUploadInputProps()} />
|
||||
{uploadElement}
|
||||
<Icon
|
||||
as={FaUpload}
|
||||
sx={{
|
||||
boxSize: 16,
|
||||
}}
|
||||
/>
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
@ -219,7 +210,6 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
onClick={onClick}
|
||||
/>
|
||||
)}
|
||||
{children}
|
||||
{!isDropDisabled && (
|
||||
<IAIDroppable
|
||||
data={droppableData}
|
||||
@ -227,6 +217,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
dropLabel={dropLabel}
|
||||
/>
|
||||
)}
|
||||
{children}
|
||||
</Flex>
|
||||
)}
|
||||
</ImageContextMenu>
|
||||
|
@ -1,8 +1,5 @@
|
||||
import { MenuList } from '@chakra-ui/react';
|
||||
import {
|
||||
IAIContextMenu,
|
||||
IAIContextMenuProps,
|
||||
} from 'common/components/IAIContextMenu';
|
||||
import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu';
|
||||
import { MouseEvent, memo, useCallback } from 'react';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { menuListMotionProps } from 'theme/components/menu';
|
||||
@ -15,7 +12,7 @@ import MultipleSelectionMenuItems from './MultipleSelectionMenuItems';
|
||||
|
||||
type Props = {
|
||||
imageDTO: ImageDTO | undefined;
|
||||
children: IAIContextMenuProps<HTMLDivElement>['children'];
|
||||
children: ContextMenuProps<HTMLDivElement>['children'];
|
||||
};
|
||||
|
||||
const selector = createSelector(
|
||||
@ -36,7 +33,7 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => {
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<IAIContextMenu<HTMLDivElement>
|
||||
<ContextMenu<HTMLDivElement>
|
||||
menuProps={{ size: 'sm', isLazy: true }}
|
||||
menuButtonProps={{
|
||||
bg: 'transparent',
|
||||
@ -71,7 +68,7 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => {
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</IAIContextMenu>
|
||||
</ContextMenu>
|
||||
);
|
||||
};
|
||||
|
||||
|
@ -2,9 +2,8 @@ import { Badge, Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useMemo } from 'react';
|
||||
import {
|
||||
BaseEdge,
|
||||
EdgeLabelRenderer,
|
||||
@ -21,165 +20,78 @@ const makeEdgeSelector = (
|
||||
targetHandleId: string | null | undefined,
|
||||
selected?: boolean
|
||||
) =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const sourceNode = nodes.nodes.find((node) => node.id === source);
|
||||
const targetNode = nodes.nodes.find((node) => node.id === target);
|
||||
createSelector(stateSelector, ({ nodes }) => {
|
||||
const sourceNode = nodes.nodes.find((node) => node.id === source);
|
||||
const targetNode = nodes.nodes.find((node) => node.id === target);
|
||||
|
||||
const isInvocationToInvocationEdge =
|
||||
isInvocationNode(sourceNode) && isInvocationNode(targetNode);
|
||||
const isInvocationToInvocationEdge =
|
||||
isInvocationNode(sourceNode) && isInvocationNode(targetNode);
|
||||
|
||||
const isSelected =
|
||||
sourceNode?.selected || targetNode?.selected || selected;
|
||||
const sourceType = isInvocationToInvocationEdge
|
||||
? sourceNode?.data?.outputs[sourceHandleId || '']?.type
|
||||
: undefined;
|
||||
const isSelected = sourceNode?.selected || targetNode?.selected || selected;
|
||||
const sourceType = isInvocationToInvocationEdge
|
||||
? sourceNode?.data?.outputs[sourceHandleId || '']?.type
|
||||
: undefined;
|
||||
|
||||
const stroke =
|
||||
sourceType && nodes.shouldColorEdges
|
||||
? colorTokenToCssVar(FIELDS[sourceType].color)
|
||||
: colorTokenToCssVar('base.500');
|
||||
const stroke =
|
||||
sourceType && nodes.shouldColorEdges
|
||||
? colorTokenToCssVar(FIELDS[sourceType].color)
|
||||
: colorTokenToCssVar('base.500');
|
||||
|
||||
return {
|
||||
isSelected,
|
||||
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
|
||||
stroke,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
return {
|
||||
isSelected,
|
||||
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
|
||||
stroke,
|
||||
};
|
||||
});
|
||||
|
||||
const CollapsedEdge = ({
|
||||
sourceX,
|
||||
sourceY,
|
||||
targetX,
|
||||
targetY,
|
||||
sourcePosition,
|
||||
targetPosition,
|
||||
markerEnd,
|
||||
data,
|
||||
selected,
|
||||
source,
|
||||
target,
|
||||
sourceHandleId,
|
||||
targetHandleId,
|
||||
}: EdgeProps<{ count: number }>) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
makeEdgeSelector(
|
||||
source,
|
||||
sourceHandleId,
|
||||
target,
|
||||
targetHandleId,
|
||||
selected
|
||||
),
|
||||
[selected, source, sourceHandleId, target, targetHandleId]
|
||||
);
|
||||
|
||||
const CollapsedEdge = memo(
|
||||
({
|
||||
const { isSelected, shouldAnimate } = useAppSelector(selector);
|
||||
|
||||
const [edgePath, labelX, labelY] = getBezierPath({
|
||||
sourceX,
|
||||
sourceY,
|
||||
sourcePosition,
|
||||
targetX,
|
||||
targetY,
|
||||
sourcePosition,
|
||||
targetPosition,
|
||||
markerEnd,
|
||||
data,
|
||||
selected,
|
||||
source,
|
||||
target,
|
||||
sourceHandleId,
|
||||
targetHandleId,
|
||||
}: EdgeProps<{ count: number }>) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
makeEdgeSelector(
|
||||
source,
|
||||
sourceHandleId,
|
||||
target,
|
||||
targetHandleId,
|
||||
selected
|
||||
),
|
||||
[selected, source, sourceHandleId, target, targetHandleId]
|
||||
);
|
||||
});
|
||||
|
||||
const { isSelected, shouldAnimate } = useAppSelector(selector);
|
||||
const { base500 } = useChakraThemeTokens();
|
||||
|
||||
const [edgePath, labelX, labelY] = getBezierPath({
|
||||
sourceX,
|
||||
sourceY,
|
||||
sourcePosition,
|
||||
targetX,
|
||||
targetY,
|
||||
targetPosition,
|
||||
});
|
||||
|
||||
const { base500 } = useChakraThemeTokens();
|
||||
|
||||
return (
|
||||
<>
|
||||
<BaseEdge
|
||||
path={edgePath}
|
||||
markerEnd={markerEnd}
|
||||
style={{
|
||||
strokeWidth: isSelected ? 3 : 2,
|
||||
stroke: base500,
|
||||
opacity: isSelected ? 0.8 : 0.5,
|
||||
animation: shouldAnimate
|
||||
? 'dashdraw 0.5s linear infinite'
|
||||
: undefined,
|
||||
strokeDasharray: shouldAnimate ? 5 : 'none',
|
||||
}}
|
||||
/>
|
||||
{data?.count && data.count > 1 && (
|
||||
<EdgeLabelRenderer>
|
||||
<Flex
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
transform: `translate(-50%, -50%) translate(${labelX}px,${labelY}px)`,
|
||||
}}
|
||||
className="nodrag nopan"
|
||||
>
|
||||
<Badge
|
||||
variant="solid"
|
||||
sx={{
|
||||
bg: 'base.500',
|
||||
opacity: isSelected ? 0.8 : 0.5,
|
||||
boxShadow: 'base',
|
||||
}}
|
||||
>
|
||||
{data.count}
|
||||
</Badge>
|
||||
</Flex>
|
||||
</EdgeLabelRenderer>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
CollapsedEdge.displayName = 'CollapsedEdge';
|
||||
|
||||
const DefaultEdge = memo(
|
||||
({
|
||||
sourceX,
|
||||
sourceY,
|
||||
targetX,
|
||||
targetY,
|
||||
sourcePosition,
|
||||
targetPosition,
|
||||
markerEnd,
|
||||
selected,
|
||||
source,
|
||||
target,
|
||||
sourceHandleId,
|
||||
targetHandleId,
|
||||
}: EdgeProps) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
makeEdgeSelector(
|
||||
source,
|
||||
sourceHandleId,
|
||||
target,
|
||||
targetHandleId,
|
||||
selected
|
||||
),
|
||||
[source, sourceHandleId, target, targetHandleId, selected]
|
||||
);
|
||||
|
||||
const { isSelected, shouldAnimate, stroke } = useAppSelector(selector);
|
||||
|
||||
const [edgePath] = getBezierPath({
|
||||
sourceX,
|
||||
sourceY,
|
||||
sourcePosition,
|
||||
targetX,
|
||||
targetY,
|
||||
targetPosition,
|
||||
});
|
||||
|
||||
return (
|
||||
return (
|
||||
<>
|
||||
<BaseEdge
|
||||
path={edgePath}
|
||||
markerEnd={markerEnd}
|
||||
style={{
|
||||
strokeWidth: isSelected ? 3 : 2,
|
||||
stroke,
|
||||
stroke: base500,
|
||||
opacity: isSelected ? 0.8 : 0.5,
|
||||
animation: shouldAnimate
|
||||
? 'dashdraw 0.5s linear infinite'
|
||||
@ -187,11 +99,83 @@ const DefaultEdge = memo(
|
||||
strokeDasharray: shouldAnimate ? 5 : 'none',
|
||||
}}
|
||||
/>
|
||||
);
|
||||
}
|
||||
);
|
||||
{data?.count && data.count > 1 && (
|
||||
<EdgeLabelRenderer>
|
||||
<Flex
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
transform: `translate(-50%, -50%) translate(${labelX}px,${labelY}px)`,
|
||||
}}
|
||||
className="nodrag nopan"
|
||||
>
|
||||
<Badge
|
||||
variant="solid"
|
||||
sx={{
|
||||
bg: 'base.500',
|
||||
opacity: isSelected ? 0.8 : 0.5,
|
||||
boxShadow: 'base',
|
||||
}}
|
||||
>
|
||||
{data.count}
|
||||
</Badge>
|
||||
</Flex>
|
||||
</EdgeLabelRenderer>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
DefaultEdge.displayName = 'DefaultEdge';
|
||||
const DefaultEdge = ({
|
||||
sourceX,
|
||||
sourceY,
|
||||
targetX,
|
||||
targetY,
|
||||
sourcePosition,
|
||||
targetPosition,
|
||||
markerEnd,
|
||||
selected,
|
||||
source,
|
||||
target,
|
||||
sourceHandleId,
|
||||
targetHandleId,
|
||||
}: EdgeProps) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
makeEdgeSelector(
|
||||
source,
|
||||
sourceHandleId,
|
||||
target,
|
||||
targetHandleId,
|
||||
selected
|
||||
),
|
||||
[source, sourceHandleId, target, targetHandleId, selected]
|
||||
);
|
||||
|
||||
const { isSelected, shouldAnimate, stroke } = useAppSelector(selector);
|
||||
|
||||
const [edgePath] = getBezierPath({
|
||||
sourceX,
|
||||
sourceY,
|
||||
sourcePosition,
|
||||
targetX,
|
||||
targetY,
|
||||
targetPosition,
|
||||
});
|
||||
|
||||
return (
|
||||
<BaseEdge
|
||||
path={edgePath}
|
||||
markerEnd={markerEnd}
|
||||
style={{
|
||||
strokeWidth: isSelected ? 3 : 2,
|
||||
stroke,
|
||||
opacity: isSelected ? 0.8 : 0.5,
|
||||
animation: shouldAnimate ? 'dashdraw 0.5s linear infinite' : undefined,
|
||||
strokeDasharray: shouldAnimate ? 5 : 'none',
|
||||
}}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export const edgeTypes = {
|
||||
collapsed: CollapsedEdge,
|
||||
|
@ -1,6 +1,5 @@
|
||||
import { useToken } from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { contextMenusClosed } from 'features/ui/store/uiSlice';
|
||||
import { useCallback } from 'react';
|
||||
import {
|
||||
Background,
|
||||
@ -115,10 +114,6 @@ export const Flow = () => {
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handlePaneClick = useCallback(() => {
|
||||
dispatch(contextMenusClosed());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<ReactFlow
|
||||
defaultViewport={viewport}
|
||||
@ -137,13 +132,12 @@ export const Flow = () => {
|
||||
connectionLineComponent={CustomConnectionLine}
|
||||
onSelectionChange={handleSelectionChange}
|
||||
isValidConnection={isValidConnection}
|
||||
minZoom={0.1}
|
||||
minZoom={0.2}
|
||||
snapToGrid={shouldSnapToGrid}
|
||||
snapGrid={[25, 25]}
|
||||
connectionRadius={30}
|
||||
proOptions={proOptions}
|
||||
style={{ borderRadius }}
|
||||
onPaneClick={handlePaneClick}
|
||||
>
|
||||
<TopLeftPanel />
|
||||
<TopCenterPanel />
|
||||
|
@ -1,34 +1,40 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { useFieldNames, useWithFooter } from 'features/nodes/hooks/useNodeData';
|
||||
import { memo } from 'react';
|
||||
import {
|
||||
InvocationNodeData,
|
||||
InvocationTemplate,
|
||||
} from 'features/nodes/types/types';
|
||||
import { map, some } from 'lodash-es';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { NodeProps } from 'reactflow';
|
||||
import InputField from '../fields/InputField';
|
||||
import OutputField from '../fields/OutputField';
|
||||
import NodeFooter from './NodeFooter';
|
||||
import NodeFooter, { FOOTER_FIELDS } from './NodeFooter';
|
||||
import NodeHeader from './NodeHeader';
|
||||
import NodeWrapper from './NodeWrapper';
|
||||
|
||||
type Props = {
|
||||
nodeId: string;
|
||||
isOpen: boolean;
|
||||
label: string;
|
||||
type: string;
|
||||
selected: boolean;
|
||||
nodeProps: NodeProps<InvocationNodeData>;
|
||||
nodeTemplate: InvocationTemplate;
|
||||
};
|
||||
|
||||
const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
|
||||
const inputFieldNames = useFieldNames(nodeId, 'input');
|
||||
const outputFieldNames = useFieldNames(nodeId, 'output');
|
||||
const withFooter = useWithFooter(nodeId);
|
||||
const InvocationNode = ({ nodeProps, nodeTemplate }: Props) => {
|
||||
const { id: nodeId, data } = nodeProps;
|
||||
const { inputs, outputs, isOpen } = data;
|
||||
|
||||
const inputFields = useMemo(
|
||||
() => map(inputs).filter((i) => i.name !== 'is_intermediate'),
|
||||
[inputs]
|
||||
);
|
||||
const outputFields = useMemo(() => map(outputs), [outputs]);
|
||||
|
||||
const withFooter = useMemo(
|
||||
() => some(outputs, (output) => FOOTER_FIELDS.includes(output.type)),
|
||||
[outputs]
|
||||
);
|
||||
|
||||
return (
|
||||
<NodeWrapper nodeId={nodeId} selected={selected}>
|
||||
<NodeHeader
|
||||
nodeId={nodeId}
|
||||
isOpen={isOpen}
|
||||
label={label}
|
||||
selected={selected}
|
||||
type={type}
|
||||
/>
|
||||
<NodeWrapper nodeProps={nodeProps}>
|
||||
<NodeHeader nodeProps={nodeProps} nodeTemplate={nodeTemplate} />
|
||||
{isOpen && (
|
||||
<>
|
||||
<Flex
|
||||
@ -48,23 +54,27 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
|
||||
className="nopan"
|
||||
sx={{ flexDir: 'column', px: 2, w: 'full', h: 'full' }}
|
||||
>
|
||||
{outputFieldNames.map((fieldName) => (
|
||||
{outputFields.map((field) => (
|
||||
<OutputField
|
||||
key={`${nodeId}.${fieldName}.output-field`}
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
key={`${nodeId}.${field.id}.input-field`}
|
||||
nodeProps={nodeProps}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
/>
|
||||
))}
|
||||
{inputFieldNames.map((fieldName) => (
|
||||
{inputFields.map((field) => (
|
||||
<InputField
|
||||
key={`${nodeId}.${fieldName}.input-field`}
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
key={`${nodeId}.${field.id}.input-field`}
|
||||
nodeProps={nodeProps}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
</Flex>
|
||||
{withFooter && <NodeFooter nodeId={nodeId} />}
|
||||
{withFooter && (
|
||||
<NodeFooter nodeProps={nodeProps} nodeTemplate={nodeTemplate} />
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</NodeWrapper>
|
||||
|
@ -2,15 +2,16 @@ import { ChevronUpIcon } from '@chakra-ui/icons';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { nodeIsOpenChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NodeData } from 'features/nodes/types/types';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useUpdateNodeInternals } from 'reactflow';
|
||||
import { NodeProps, useUpdateNodeInternals } from 'reactflow';
|
||||
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
isOpen: boolean;
|
||||
nodeProps: NodeProps<NodeData>;
|
||||
}
|
||||
|
||||
const NodeCollapseButton = ({ nodeId, isOpen }: Props) => {
|
||||
const NodeCollapseButton = (props: Props) => {
|
||||
const { id: nodeId, isOpen } = props.nodeProps.data;
|
||||
const dispatch = useAppDispatch();
|
||||
const updateNodeInternals = useUpdateNodeInternals();
|
||||
|
||||
|
@ -1,17 +1,20 @@
|
||||
import { useColorModeValue } from '@chakra-ui/react';
|
||||
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
|
||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||
import { isInvocationNodeData } from 'features/nodes/types/types';
|
||||
import {
|
||||
InvocationNodeData,
|
||||
InvocationTemplate,
|
||||
} from 'features/nodes/types/types';
|
||||
import { map } from 'lodash-es';
|
||||
import { CSSProperties, memo, useMemo } from 'react';
|
||||
import { Handle, Position } from 'reactflow';
|
||||
import { Handle, NodeProps, Position } from 'reactflow';
|
||||
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
nodeProps: NodeProps<InvocationNodeData>;
|
||||
nodeTemplate: InvocationTemplate;
|
||||
}
|
||||
|
||||
const NodeCollapsedHandles = ({ nodeId }: Props) => {
|
||||
const data = useNodeData(nodeId);
|
||||
const NodeCollapsedHandles = (props: Props) => {
|
||||
const { data } = props.nodeProps;
|
||||
const { base400, base600 } = useChakraThemeTokens();
|
||||
const backgroundColor = useColorModeValue(base400, base600);
|
||||
|
||||
@ -27,10 +30,6 @@ const NodeCollapsedHandles = ({ nodeId }: Props) => {
|
||||
[backgroundColor]
|
||||
);
|
||||
|
||||
if (!isInvocationNodeData(data)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Handle
|
||||
@ -45,7 +44,7 @@ const NodeCollapsedHandles = ({ nodeId }: Props) => {
|
||||
key={`${data.id}-${input.name}-collapsed-input-handle`}
|
||||
type="target"
|
||||
id={input.name}
|
||||
isConnectable={false}
|
||||
isValidConnection={() => false}
|
||||
position={Position.Left}
|
||||
style={{ visibility: 'hidden' }}
|
||||
/>
|
||||
@ -53,6 +52,7 @@ const NodeCollapsedHandles = ({ nodeId }: Props) => {
|
||||
<Handle
|
||||
type="source"
|
||||
id={`${data.id}-collapsed-source`}
|
||||
isValidConnection={() => false}
|
||||
isConnectable={false}
|
||||
position={Position.Right}
|
||||
style={{ ...dummyHandleStyles, right: '-0.5rem' }}
|
||||
@ -62,7 +62,7 @@ const NodeCollapsedHandles = ({ nodeId }: Props) => {
|
||||
key={`${data.id}-${output.name}-collapsed-output-handle`}
|
||||
type="source"
|
||||
id={output.name}
|
||||
isConnectable={false}
|
||||
isValidConnection={() => false}
|
||||
position={Position.Right}
|
||||
style={{ visibility: 'hidden' }}
|
||||
/>
|
||||
|
@ -6,19 +6,49 @@ import {
|
||||
Spacer,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import {
|
||||
useHasImageOutput,
|
||||
useIsIntermediate,
|
||||
} from 'features/nodes/hooks/useNodeData';
|
||||
import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import {
|
||||
InvocationNodeData,
|
||||
InvocationTemplate,
|
||||
} from 'features/nodes/types/types';
|
||||
import { some } from 'lodash-es';
|
||||
import { ChangeEvent, memo, useCallback, useMemo } from 'react';
|
||||
import { NodeProps } from 'reactflow';
|
||||
|
||||
export const IMAGE_FIELDS = ['ImageField', 'ImageCollection'];
|
||||
export const FOOTER_FIELDS = IMAGE_FIELDS;
|
||||
|
||||
type Props = {
|
||||
nodeId: string;
|
||||
nodeProps: NodeProps<InvocationNodeData>;
|
||||
nodeTemplate: InvocationTemplate;
|
||||
};
|
||||
|
||||
const NodeFooter = ({ nodeId }: Props) => {
|
||||
const NodeFooter = (props: Props) => {
|
||||
const { nodeProps, nodeTemplate } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const hasImageOutput = useMemo(
|
||||
() =>
|
||||
some(nodeTemplate?.outputs, (output) =>
|
||||
IMAGE_FIELDS.includes(output.type)
|
||||
),
|
||||
[nodeTemplate?.outputs]
|
||||
);
|
||||
|
||||
const handleChangeIsIntermediate = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
dispatch(
|
||||
fieldBooleanValueChanged({
|
||||
nodeId: nodeProps.data.id,
|
||||
fieldName: 'is_intermediate',
|
||||
value: !e.target.checked,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, nodeProps.data.id]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
className={DRAG_HANDLE_CLASSNAME}
|
||||
@ -32,45 +62,19 @@ const NodeFooter = ({ nodeId }: Props) => {
|
||||
}}
|
||||
>
|
||||
<Spacer />
|
||||
<SaveImageCheckbox nodeId={nodeId} />
|
||||
{hasImageOutput && (
|
||||
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
|
||||
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save Output</FormLabel>
|
||||
<Checkbox
|
||||
className="nopan"
|
||||
size="sm"
|
||||
onChange={handleChangeIsIntermediate}
|
||||
isChecked={!nodeProps.data.inputs['is_intermediate']?.value}
|
||||
/>
|
||||
</FormControl>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(NodeFooter);
|
||||
|
||||
const SaveImageCheckbox = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const hasImageOutput = useHasImageOutput(nodeId);
|
||||
const is_intermediate = useIsIntermediate(nodeId);
|
||||
const handleChangeIsIntermediate = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
dispatch(
|
||||
fieldBooleanValueChanged({
|
||||
nodeId,
|
||||
fieldName: 'is_intermediate',
|
||||
value: !e.target.checked,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, nodeId]
|
||||
);
|
||||
|
||||
if (!hasImageOutput) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
|
||||
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save Output</FormLabel>
|
||||
<Checkbox
|
||||
className="nopan"
|
||||
size="sm"
|
||||
onChange={handleChangeIsIntermediate}
|
||||
isChecked={!is_intermediate}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
SaveImageCheckbox.displayName = 'SaveImageCheckbox';
|
||||
|
@ -1,5 +1,10 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import {
|
||||
InvocationNodeData,
|
||||
InvocationTemplate,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
import { NodeProps } from 'reactflow';
|
||||
import NodeCollapseButton from '../Invocation/NodeCollapseButton';
|
||||
import NodeCollapsedHandles from '../Invocation/NodeCollapsedHandles';
|
||||
import NodeNotesEdit from '../Invocation/NodeNotesEdit';
|
||||
@ -7,14 +12,14 @@ import NodeStatusIndicator from '../Invocation/NodeStatusIndicator';
|
||||
import NodeTitle from '../Invocation/NodeTitle';
|
||||
|
||||
type Props = {
|
||||
nodeId: string;
|
||||
isOpen: boolean;
|
||||
label: string;
|
||||
type: string;
|
||||
selected: boolean;
|
||||
nodeProps: NodeProps<InvocationNodeData>;
|
||||
nodeTemplate: InvocationTemplate;
|
||||
};
|
||||
|
||||
const NodeHeader = ({ nodeId, isOpen }: Props) => {
|
||||
const NodeHeader = (props: Props) => {
|
||||
const { nodeProps, nodeTemplate } = props;
|
||||
const { isOpen } = nodeProps.data;
|
||||
|
||||
return (
|
||||
<Flex
|
||||
layerStyle="nodeHeader"
|
||||
@ -30,13 +35,18 @@ const NodeHeader = ({ nodeId, isOpen }: Props) => {
|
||||
_dark: { color: 'base.200' },
|
||||
}}
|
||||
>
|
||||
<NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
|
||||
<NodeTitle nodeId={nodeId} />
|
||||
<NodeCollapseButton nodeProps={nodeProps} />
|
||||
<NodeTitle nodeData={nodeProps.data} title={nodeTemplate.title} />
|
||||
<Flex alignItems="center">
|
||||
<NodeStatusIndicator nodeId={nodeId} />
|
||||
<NodeNotesEdit nodeId={nodeId} />
|
||||
<NodeStatusIndicator nodeProps={nodeProps} />
|
||||
<NodeNotesEdit nodeProps={nodeProps} nodeTemplate={nodeTemplate} />
|
||||
</Flex>
|
||||
{!isOpen && <NodeCollapsedHandles nodeId={nodeId} />}
|
||||
{!isOpen && (
|
||||
<NodeCollapsedHandles
|
||||
nodeProps={nodeProps}
|
||||
nodeTemplate={nodeTemplate}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -16,31 +16,41 @@ import {
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAITextarea from 'common/components/IAITextarea';
|
||||
import {
|
||||
useNodeData,
|
||||
useNodeLabel,
|
||||
useNodeTemplate,
|
||||
useNodeTemplateTitle,
|
||||
} from 'features/nodes/hooks/useNodeData';
|
||||
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||
import { isInvocationNodeData } from 'features/nodes/types/types';
|
||||
import {
|
||||
InvocationNodeData,
|
||||
InvocationTemplate,
|
||||
} from 'features/nodes/types/types';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { FaInfoCircle } from 'react-icons/fa';
|
||||
import { NodeProps } from 'reactflow';
|
||||
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
nodeProps: NodeProps<InvocationNodeData>;
|
||||
nodeTemplate: InvocationTemplate;
|
||||
}
|
||||
|
||||
const NodeNotesEdit = ({ nodeId }: Props) => {
|
||||
const NodeNotesEdit = (props: Props) => {
|
||||
const { nodeProps, nodeTemplate } = props;
|
||||
const { data } = nodeProps;
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
const label = useNodeLabel(nodeId);
|
||||
const title = useNodeTemplateTitle(nodeId);
|
||||
const dispatch = useAppDispatch();
|
||||
const handleNotesChanged = useCallback(
|
||||
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
dispatch(nodeNotesChanged({ nodeId: data.id, notes: e.target.value }));
|
||||
},
|
||||
[data.id, dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Tooltip
|
||||
label={<TooltipContent nodeId={nodeId} />}
|
||||
label={
|
||||
nodeTemplate ? (
|
||||
<TooltipContent nodeProps={nodeProps} nodeTemplate={nodeTemplate} />
|
||||
) : undefined
|
||||
}
|
||||
placement="top"
|
||||
shouldWrapChildren
|
||||
>
|
||||
@ -65,10 +75,19 @@ const NodeNotesEdit = ({ nodeId }: Props) => {
|
||||
<Modal isOpen={isOpen} onClose={onClose} isCentered>
|
||||
<ModalOverlay />
|
||||
<ModalContent>
|
||||
<ModalHeader>{label || title || 'Unknown Node'}</ModalHeader>
|
||||
<ModalHeader>
|
||||
{data.label || nodeTemplate?.title || 'Unknown Node'}
|
||||
</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody>
|
||||
<NotesTextarea nodeId={nodeId} />
|
||||
<FormControl>
|
||||
<FormLabel>Notes</FormLabel>
|
||||
<IAITextarea
|
||||
value={data.notes}
|
||||
onChange={handleNotesChanged}
|
||||
rows={10}
|
||||
/>
|
||||
</FormControl>
|
||||
</ModalBody>
|
||||
<ModalFooter />
|
||||
</ModalContent>
|
||||
@ -79,49 +98,16 @@ const NodeNotesEdit = ({ nodeId }: Props) => {
|
||||
|
||||
export default memo(NodeNotesEdit);
|
||||
|
||||
const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const data = useNodeData(nodeId);
|
||||
const nodeTemplate = useNodeTemplate(nodeId);
|
||||
|
||||
if (!isInvocationNodeData(data)) {
|
||||
return <Text sx={{ fontWeight: 600 }}>Unknown Node</Text>;
|
||||
}
|
||||
type TooltipContentProps = Props;
|
||||
|
||||
const TooltipContent = (props: TooltipContentProps) => {
|
||||
return (
|
||||
<Flex sx={{ flexDir: 'column' }}>
|
||||
<Text sx={{ fontWeight: 600 }}>{nodeTemplate?.title}</Text>
|
||||
<Text sx={{ fontWeight: 600 }}>{props.nodeTemplate?.title}</Text>
|
||||
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
|
||||
{nodeTemplate?.description}
|
||||
{props.nodeTemplate?.description}
|
||||
</Text>
|
||||
{data?.notes && <Text>{data.notes}</Text>}
|
||||
{props.nodeProps.data.notes && <Text>{props.nodeProps.data.notes}</Text>}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
TooltipContent.displayName = 'TooltipContent';
|
||||
|
||||
const NotesTextarea = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const data = useNodeData(nodeId);
|
||||
const handleNotesChanged = useCallback(
|
||||
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
dispatch(nodeNotesChanged({ nodeId, notes: e.target.value }));
|
||||
},
|
||||
[dispatch, nodeId]
|
||||
);
|
||||
if (!isInvocationNodeData(data)) {
|
||||
return null;
|
||||
}
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel>Notes</FormLabel>
|
||||
<IAITextarea
|
||||
value={data?.notes}
|
||||
onChange={handleNotesChanged}
|
||||
rows={10}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
NotesTextarea.displayName = 'NodesTextarea';
|
||||
};
|
||||
|
@ -11,12 +11,17 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||
import { NodeExecutionState, NodeStatus } from 'features/nodes/types/types';
|
||||
import {
|
||||
InvocationNodeData,
|
||||
NodeExecutionState,
|
||||
NodeStatus,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { FaCheck, FaEllipsisH, FaExclamation } from 'react-icons/fa';
|
||||
import { NodeProps } from 'reactflow';
|
||||
|
||||
type Props = {
|
||||
nodeId: string;
|
||||
nodeProps: NodeProps<InvocationNodeData>;
|
||||
};
|
||||
|
||||
const iconBoxSize = 3;
|
||||
@ -28,7 +33,8 @@ const circleStyles = {
|
||||
'.chakra-progress__track': { stroke: 'transparent' },
|
||||
};
|
||||
|
||||
const NodeStatusIndicator = ({ nodeId }: Props) => {
|
||||
const NodeStatusIndicator = (props: Props) => {
|
||||
const nodeId = props.nodeProps.data.id;
|
||||
const selectNodeExecutionState = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
@ -70,7 +76,7 @@ type TooltipLabelProps = {
|
||||
nodeExecutionState: NodeExecutionState;
|
||||
};
|
||||
|
||||
const TooltipLabel = memo(({ nodeExecutionState }: TooltipLabelProps) => {
|
||||
const TooltipLabel = ({ nodeExecutionState }: TooltipLabelProps) => {
|
||||
const { status, progress, progressImage } = nodeExecutionState;
|
||||
if (status === NodeStatus.PENDING) {
|
||||
return <Text>Pending</Text>;
|
||||
@ -112,15 +118,13 @@ const TooltipLabel = memo(({ nodeExecutionState }: TooltipLabelProps) => {
|
||||
}
|
||||
|
||||
return null;
|
||||
});
|
||||
|
||||
TooltipLabel.displayName = 'TooltipLabel';
|
||||
};
|
||||
|
||||
type StatusIconProps = {
|
||||
nodeExecutionState: NodeExecutionState;
|
||||
};
|
||||
|
||||
const StatusIcon = memo((props: StatusIconProps) => {
|
||||
const StatusIcon = (props: StatusIconProps) => {
|
||||
const { progress, status } = props.nodeExecutionState;
|
||||
if (status === NodeStatus.PENDING) {
|
||||
return (
|
||||
@ -178,6 +182,4 @@ const StatusIcon = memo((props: StatusIconProps) => {
|
||||
);
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
StatusIcon.displayName = 'StatusIcon';
|
||||
};
|
||||
|
@ -7,29 +7,26 @@ import {
|
||||
useEditableControls,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import {
|
||||
useNodeLabel,
|
||||
useNodeTemplateTitle,
|
||||
} from 'features/nodes/hooks/useNodeData';
|
||||
import { nodeLabelChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||
import { NodeData } from 'features/nodes/types/types';
|
||||
import { MouseEvent, memo, useCallback, useEffect, useState } from 'react';
|
||||
|
||||
type Props = {
|
||||
nodeId: string;
|
||||
title?: string;
|
||||
nodeData: NodeData;
|
||||
title: string;
|
||||
};
|
||||
|
||||
const NodeTitle = ({ nodeId, title }: Props) => {
|
||||
const NodeTitle = (props: Props) => {
|
||||
const { title } = props;
|
||||
const { id: nodeId, label } = props.nodeData;
|
||||
const dispatch = useAppDispatch();
|
||||
const label = useNodeLabel(nodeId);
|
||||
const templateTitle = useNodeTemplateTitle(nodeId);
|
||||
const [localTitle, setLocalTitle] = useState(label || title);
|
||||
|
||||
const [localTitle, setLocalTitle] = useState('');
|
||||
const handleSubmit = useCallback(
|
||||
async (newTitle: string) => {
|
||||
dispatch(nodeLabelChanged({ nodeId, label: newTitle }));
|
||||
setLocalTitle(newTitle || title || 'Problem Setting Title');
|
||||
setLocalTitle(newTitle || title);
|
||||
},
|
||||
[nodeId, dispatch, title]
|
||||
);
|
||||
@ -40,8 +37,8 @@ const NodeTitle = ({ nodeId, title }: Props) => {
|
||||
|
||||
useEffect(() => {
|
||||
// Another component may change the title; sync local title with global state
|
||||
setLocalTitle(label || title || templateTitle || 'Problem Setting Title');
|
||||
}, [label, templateTitle, title]);
|
||||
setLocalTitle(label || title);
|
||||
}, [label, title]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
|
@ -6,14 +6,10 @@ import {
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { nodeClicked } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
MouseEvent,
|
||||
PropsWithChildren,
|
||||
memo,
|
||||
useCallback,
|
||||
useMemo,
|
||||
} from 'react';
|
||||
import { MouseEvent, PropsWithChildren, useCallback, useMemo } from 'react';
|
||||
import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from '../../types/constants';
|
||||
import { NodeData } from 'features/nodes/types/types';
|
||||
import { NodeProps } from 'reactflow';
|
||||
|
||||
const useNodeSelect = (nodeId: string) => {
|
||||
const dispatch = useAppDispatch();
|
||||
@ -29,13 +25,14 @@ const useNodeSelect = (nodeId: string) => {
|
||||
};
|
||||
|
||||
type NodeWrapperProps = PropsWithChildren & {
|
||||
nodeId: string;
|
||||
selected: boolean;
|
||||
nodeProps: NodeProps<NodeData>;
|
||||
width?: NonNullable<ChakraProps['sx']>['w'];
|
||||
};
|
||||
|
||||
const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
const { width, children, nodeId, selected } = props;
|
||||
const { width, children, nodeProps } = props;
|
||||
const { data, selected } = nodeProps;
|
||||
const nodeId = data.id;
|
||||
|
||||
const [
|
||||
nodeSelectedOutlineLight,
|
||||
@ -96,4 +93,4 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(NodeWrapper);
|
||||
export default NodeWrapper;
|
||||
|
@ -1,26 +1,20 @@
|
||||
import { Box, Flex, Text } from '@chakra-ui/react';
|
||||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||
import { InvocationNodeData } from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
import { NodeProps } from 'reactflow';
|
||||
import NodeCollapseButton from '../Invocation/NodeCollapseButton';
|
||||
import NodeWrapper from '../Invocation/NodeWrapper';
|
||||
|
||||
type Props = {
|
||||
nodeId: string;
|
||||
isOpen: boolean;
|
||||
label: string;
|
||||
type: string;
|
||||
selected: boolean;
|
||||
nodeProps: NodeProps<InvocationNodeData>;
|
||||
};
|
||||
|
||||
const UnknownNodeFallback = ({
|
||||
nodeId,
|
||||
isOpen,
|
||||
label,
|
||||
type,
|
||||
selected,
|
||||
}: Props) => {
|
||||
const UnknownNodeFallback = ({ nodeProps }: Props) => {
|
||||
const { data } = nodeProps;
|
||||
const { isOpen, label, type } = data;
|
||||
return (
|
||||
<NodeWrapper nodeId={nodeId} selected={selected}>
|
||||
<NodeWrapper nodeProps={nodeProps}>
|
||||
<Flex
|
||||
className={DRAG_HANDLE_CLASSNAME}
|
||||
layerStyle="nodeHeader"
|
||||
@ -33,7 +27,7 @@ const UnknownNodeFallback = ({
|
||||
fontSize: 'sm',
|
||||
}}
|
||||
>
|
||||
<NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
|
||||
<NodeCollapseButton nodeProps={nodeProps} />
|
||||
<Text
|
||||
sx={{
|
||||
w: 'full',
|
||||
|
@ -46,6 +46,7 @@ const NodeEditor = () => {
|
||||
<AnimatePresence>
|
||||
{isReady && (
|
||||
<motion.div
|
||||
layoutId="node-editor-flow"
|
||||
initial={{
|
||||
opacity: 0,
|
||||
}}
|
||||
@ -66,6 +67,7 @@ const NodeEditor = () => {
|
||||
<AnimatePresence>
|
||||
{!isReady && (
|
||||
<motion.div
|
||||
layoutId="node-editor-loading"
|
||||
initial={{
|
||||
opacity: 0,
|
||||
}}
|
||||
|
@ -15,7 +15,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { ChangeEvent, useCallback } from 'react';
|
||||
import { FaCog } from 'react-icons/fa';
|
||||
import {
|
||||
shouldAnimateEdgesChanged,
|
||||
@ -23,26 +23,21 @@ import {
|
||||
shouldSnapToGridChanged,
|
||||
shouldValidateGraphChanged,
|
||||
} from '../store/nodesSlice';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const {
|
||||
shouldAnimateEdges,
|
||||
shouldValidateGraph,
|
||||
shouldSnapToGrid,
|
||||
shouldColorEdges,
|
||||
} = nodes;
|
||||
return {
|
||||
shouldAnimateEdges,
|
||||
shouldValidateGraph,
|
||||
shouldSnapToGrid,
|
||||
shouldColorEdges,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
const selector = createSelector(stateSelector, ({ nodes }) => {
|
||||
const {
|
||||
shouldAnimateEdges,
|
||||
shouldValidateGraph,
|
||||
shouldSnapToGrid,
|
||||
shouldColorEdges,
|
||||
} = nodes;
|
||||
return {
|
||||
shouldAnimateEdges,
|
||||
shouldValidateGraph,
|
||||
shouldSnapToGrid,
|
||||
shouldColorEdges,
|
||||
};
|
||||
});
|
||||
|
||||
const NodeEditorSettings = () => {
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
@ -141,4 +136,4 @@ const NodeEditorSettings = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(NodeEditorSettings);
|
||||
export default NodeEditorSettings;
|
||||
|
@ -1,12 +1,19 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { CSSProperties, memo, useMemo } from 'react';
|
||||
import { Handle, HandleType, Position } from 'reactflow';
|
||||
import { Handle, HandleType, NodeProps, Position } from 'reactflow';
|
||||
import {
|
||||
FIELDS,
|
||||
HANDLE_TOOLTIP_OPEN_DELAY,
|
||||
colorTokenToCssVar,
|
||||
} from '../../types/constants';
|
||||
import { InputFieldTemplate, OutputFieldTemplate } from '../../types/types';
|
||||
import {
|
||||
InputFieldTemplate,
|
||||
InputFieldValue,
|
||||
InvocationNodeData,
|
||||
InvocationTemplate,
|
||||
OutputFieldTemplate,
|
||||
OutputFieldValue,
|
||||
} from '../../types/types';
|
||||
|
||||
export const handleBaseStyles: CSSProperties = {
|
||||
position: 'absolute',
|
||||
@ -25,6 +32,9 @@ export const outputHandleStyles: CSSProperties = {
|
||||
};
|
||||
|
||||
type FieldHandleProps = {
|
||||
nodeProps: NodeProps<InvocationNodeData>;
|
||||
nodeTemplate: InvocationTemplate;
|
||||
field: InputFieldValue | OutputFieldValue;
|
||||
fieldTemplate: InputFieldTemplate | OutputFieldTemplate;
|
||||
handleType: HandleType;
|
||||
isConnectionInProgress: boolean;
|
||||
|
@ -8,11 +8,13 @@ import {
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIDraggable from 'common/components/IAIDraggable';
|
||||
import { NodeFieldDraggableData } from 'features/dnd/types';
|
||||
import {
|
||||
useFieldData,
|
||||
useFieldTemplate,
|
||||
} from 'features/nodes/hooks/useNodeData';
|
||||
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
InputFieldTemplate,
|
||||
InputFieldValue,
|
||||
InvocationNodeData,
|
||||
InvocationTemplate,
|
||||
} from 'features/nodes/types/types';
|
||||
import {
|
||||
MouseEvent,
|
||||
memo,
|
||||
@ -23,43 +25,41 @@ import {
|
||||
} from 'react';
|
||||
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
nodeData: InvocationNodeData;
|
||||
nodeTemplate: InvocationTemplate;
|
||||
field: InputFieldValue;
|
||||
fieldTemplate: InputFieldTemplate;
|
||||
isDraggable?: boolean;
|
||||
kind: 'input' | 'output';
|
||||
}
|
||||
|
||||
const FieldTitle = (props: Props) => {
|
||||
const { nodeId, fieldName, isDraggable = false, kind } = props;
|
||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
|
||||
const field = useFieldData(nodeId, fieldName);
|
||||
|
||||
const { nodeData, field, fieldTemplate, isDraggable = false } = props;
|
||||
const { label } = field;
|
||||
const { title, input } = fieldTemplate;
|
||||
const { id: nodeId } = nodeData;
|
||||
const dispatch = useAppDispatch();
|
||||
const [localTitle, setLocalTitle] = useState(
|
||||
field?.label || fieldTemplate?.title || 'Unknown Field'
|
||||
);
|
||||
const [localTitle, setLocalTitle] = useState(label || title);
|
||||
|
||||
const draggableData: NodeFieldDraggableData | undefined = useMemo(
|
||||
() =>
|
||||
field &&
|
||||
fieldTemplate?.fieldKind === 'input' &&
|
||||
fieldTemplate?.input !== 'connection' &&
|
||||
isDraggable
|
||||
input !== 'connection' && isDraggable
|
||||
? {
|
||||
id: `${nodeId}-${fieldName}`,
|
||||
id: `${nodeId}-${field.name}`,
|
||||
payloadType: 'NODE_FIELD',
|
||||
payload: { nodeId, field, fieldTemplate },
|
||||
}
|
||||
: undefined,
|
||||
[field, fieldName, fieldTemplate, isDraggable, nodeId]
|
||||
[field, fieldTemplate, input, isDraggable, nodeId]
|
||||
);
|
||||
|
||||
const handleSubmit = useCallback(
|
||||
async (newTitle: string) => {
|
||||
dispatch(fieldLabelChanged({ nodeId, fieldName, label: newTitle }));
|
||||
setLocalTitle(newTitle || fieldTemplate?.title || 'Unknown Field');
|
||||
dispatch(
|
||||
fieldLabelChanged({ nodeId, fieldName: field.name, label: newTitle })
|
||||
);
|
||||
setLocalTitle(newTitle || title);
|
||||
},
|
||||
[dispatch, nodeId, fieldName, fieldTemplate?.title]
|
||||
[dispatch, nodeId, field.name, title]
|
||||
);
|
||||
|
||||
const handleChange = useCallback((newTitle: string) => {
|
||||
@ -68,8 +68,8 @@ const FieldTitle = (props: Props) => {
|
||||
|
||||
useEffect(() => {
|
||||
// Another component may change the title; sync local title with global state
|
||||
setLocalTitle(field?.label || fieldTemplate?.title || 'Unknown Field');
|
||||
}, [field?.label, fieldTemplate?.title]);
|
||||
setLocalTitle(label || title);
|
||||
}, [label, title]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
@ -120,7 +120,7 @@ type EditableControlsProps = {
|
||||
draggableData?: NodeFieldDraggableData;
|
||||
};
|
||||
|
||||
const EditableControls = memo((props: EditableControlsProps) => {
|
||||
function EditableControls(props: EditableControlsProps) {
|
||||
const { isEditing, getEditButtonProps } = useEditableControls();
|
||||
const handleDoubleClick = useCallback(
|
||||
(e: MouseEvent<HTMLDivElement>) => {
|
||||
@ -158,6 +158,4 @@ const EditableControls = memo((props: EditableControlsProps) => {
|
||||
cursor="text"
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
EditableControls.displayName = 'EditableControls';
|
||||
}
|
||||
|
@ -1,53 +1,38 @@
|
||||
import { Flex, Text } from '@chakra-ui/react';
|
||||
import {
|
||||
useFieldData,
|
||||
useFieldTemplate,
|
||||
} from 'features/nodes/hooks/useNodeData';
|
||||
import { FIELDS } from 'features/nodes/types/constants';
|
||||
import {
|
||||
InputFieldTemplate,
|
||||
InputFieldValue,
|
||||
InvocationNodeData,
|
||||
InvocationTemplate,
|
||||
OutputFieldTemplate,
|
||||
OutputFieldValue,
|
||||
isInputFieldTemplate,
|
||||
isInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { startCase } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
kind: 'input' | 'output';
|
||||
nodeData: InvocationNodeData;
|
||||
nodeTemplate: InvocationTemplate;
|
||||
field: InputFieldValue | OutputFieldValue;
|
||||
fieldTemplate: InputFieldTemplate | OutputFieldTemplate;
|
||||
}
|
||||
|
||||
const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
|
||||
const field = useFieldData(nodeId, fieldName);
|
||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
|
||||
const FieldTooltipContent = ({ field, fieldTemplate }: Props) => {
|
||||
const isInputTemplate = isInputFieldTemplate(fieldTemplate);
|
||||
const fieldTitle = useMemo(() => {
|
||||
if (isInputFieldValue(field)) {
|
||||
if (field.label && fieldTemplate) {
|
||||
return `${field.label} (${fieldTemplate.title})`;
|
||||
}
|
||||
|
||||
if (field.label && !fieldTemplate) {
|
||||
return field.label;
|
||||
}
|
||||
|
||||
if (!field.label && fieldTemplate) {
|
||||
return fieldTemplate.title;
|
||||
}
|
||||
|
||||
return 'Unknown Field';
|
||||
}
|
||||
}, [field, fieldTemplate]);
|
||||
|
||||
return (
|
||||
<Flex sx={{ flexDir: 'column' }}>
|
||||
<Text sx={{ fontWeight: 600 }}>{fieldTitle}</Text>
|
||||
{fieldTemplate && (
|
||||
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
|
||||
{fieldTemplate.description}
|
||||
</Text>
|
||||
)}
|
||||
{fieldTemplate && <Text>Type: {FIELDS[fieldTemplate.type].title}</Text>}
|
||||
<Text sx={{ fontWeight: 600 }}>
|
||||
{isInputFieldValue(field) && field.label
|
||||
? `${field.label} (${fieldTemplate.title})`
|
||||
: fieldTemplate.title}
|
||||
</Text>
|
||||
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
|
||||
{fieldTemplate.description}
|
||||
</Text>
|
||||
<Text>Type: {FIELDS[fieldTemplate.type].title}</Text>
|
||||
{isInputTemplate && <Text>Input: {startCase(fieldTemplate.input)}</Text>}
|
||||
</Flex>
|
||||
);
|
||||
|
@ -1,24 +1,27 @@
|
||||
import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
|
||||
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
|
||||
import {
|
||||
useDoesInputHaveValue,
|
||||
useFieldTemplate,
|
||||
} from 'features/nodes/hooks/useNodeData';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
||||
import { PropsWithChildren, memo, useMemo } from 'react';
|
||||
import {
|
||||
InputFieldValue,
|
||||
InvocationNodeData,
|
||||
InvocationTemplate,
|
||||
} from 'features/nodes/types/types';
|
||||
import { PropsWithChildren, useMemo } from 'react';
|
||||
import { NodeProps } from 'reactflow';
|
||||
import FieldHandle from './FieldHandle';
|
||||
import FieldTitle from './FieldTitle';
|
||||
import FieldTooltipContent from './FieldTooltipContent';
|
||||
import InputFieldRenderer from './InputFieldRenderer';
|
||||
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
nodeProps: NodeProps<InvocationNodeData>;
|
||||
nodeTemplate: InvocationTemplate;
|
||||
field: InputFieldValue;
|
||||
}
|
||||
|
||||
const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
|
||||
const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName);
|
||||
const InputField = (props: Props) => {
|
||||
const { nodeProps, nodeTemplate, field } = props;
|
||||
const { id: nodeId } = nodeProps.data;
|
||||
|
||||
const {
|
||||
isConnected,
|
||||
@ -26,10 +29,15 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
isConnectionStartField,
|
||||
connectionError,
|
||||
shouldDim,
|
||||
} = useConnectionState({ nodeId, fieldName, kind: 'input' });
|
||||
} = useConnectionState({ nodeId, field, kind: 'input' });
|
||||
|
||||
const fieldTemplate = useMemo(
|
||||
() => nodeTemplate.inputs[field.name],
|
||||
[field.name, nodeTemplate.inputs]
|
||||
);
|
||||
|
||||
const isMissingInput = useMemo(() => {
|
||||
if (fieldTemplate?.fieldKind !== 'input') {
|
||||
if (!fieldTemplate) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -41,18 +49,18 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!doesFieldHaveValue && !isConnected && fieldTemplate.input === 'any') {
|
||||
if (!field.value && !isConnected && fieldTemplate.input === 'any') {
|
||||
return true;
|
||||
}
|
||||
}, [fieldTemplate, isConnected, doesFieldHaveValue]);
|
||||
}, [fieldTemplate, isConnected, field.value]);
|
||||
|
||||
if (fieldTemplate?.fieldKind !== 'input') {
|
||||
if (!fieldTemplate) {
|
||||
return (
|
||||
<InputFieldWrapper shouldDim={shouldDim}>
|
||||
<FormControl
|
||||
sx={{ color: 'error.400', textAlign: 'left', fontSize: 'sm' }}
|
||||
>
|
||||
Unknown input: {fieldName}
|
||||
Unknown input: {field.name}
|
||||
</FormControl>
|
||||
</InputFieldWrapper>
|
||||
);
|
||||
@ -74,9 +82,10 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
<Tooltip
|
||||
label={
|
||||
<FieldTooltipContent
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
kind="input"
|
||||
nodeData={nodeProps.data}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
}
|
||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||
@ -86,18 +95,27 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
>
|
||||
<FormLabel sx={{ mb: 0 }}>
|
||||
<FieldTitle
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
kind="input"
|
||||
nodeData={nodeProps.data}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isDraggable
|
||||
/>
|
||||
</FormLabel>
|
||||
</Tooltip>
|
||||
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
|
||||
<InputFieldRenderer
|
||||
nodeData={nodeProps.data}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
</FormControl>
|
||||
|
||||
{fieldTemplate.input !== 'direct' && (
|
||||
<FieldHandle
|
||||
nodeProps={nodeProps}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
handleType="target"
|
||||
isConnectionInProgress={isConnectionInProgress}
|
||||
@ -115,25 +133,21 @@ type InputFieldWrapperProps = PropsWithChildren<{
|
||||
shouldDim: boolean;
|
||||
}>;
|
||||
|
||||
const InputFieldWrapper = memo(
|
||||
({ shouldDim, children }: InputFieldWrapperProps) => (
|
||||
<Flex
|
||||
className="nopan"
|
||||
sx={{
|
||||
position: 'relative',
|
||||
minH: 8,
|
||||
py: 0.5,
|
||||
alignItems: 'center',
|
||||
opacity: shouldDim ? 0.5 : 1,
|
||||
transitionProperty: 'opacity',
|
||||
transitionDuration: '0.1s',
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</Flex>
|
||||
)
|
||||
const InputFieldWrapper = ({ shouldDim, children }: InputFieldWrapperProps) => (
|
||||
<Flex
|
||||
className="nopan"
|
||||
sx={{
|
||||
position: 'relative',
|
||||
minH: 8,
|
||||
py: 0.5,
|
||||
alignItems: 'center',
|
||||
opacity: shouldDim ? 0.5 : 1,
|
||||
transitionProperty: 'opacity',
|
||||
transitionDuration: '0.1s',
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</Flex>
|
||||
);
|
||||
|
||||
InputFieldWrapper.displayName = 'InputFieldWrapper';
|
||||
|
@ -1,9 +1,11 @@
|
||||
import { Box } from '@chakra-ui/react';
|
||||
import {
|
||||
useFieldData,
|
||||
useFieldTemplate,
|
||||
} from 'features/nodes/hooks/useNodeData';
|
||||
import { memo } from 'react';
|
||||
import {
|
||||
InputFieldTemplate,
|
||||
InputFieldValue,
|
||||
InvocationNodeData,
|
||||
InvocationTemplate,
|
||||
} from '../../types/types';
|
||||
import BooleanInputField from './fieldTypes/BooleanInputField';
|
||||
import ClipInputField from './fieldTypes/ClipInputField';
|
||||
import CollectionInputField from './fieldTypes/CollectionInputField';
|
||||
@ -27,33 +29,33 @@ import VaeInputField from './fieldTypes/VaeInputField';
|
||||
import VaeModelInputField from './fieldTypes/VaeModelInputField';
|
||||
|
||||
type InputFieldProps = {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
nodeData: InvocationNodeData;
|
||||
nodeTemplate: InvocationTemplate;
|
||||
field: InputFieldValue;
|
||||
fieldTemplate: InputFieldTemplate;
|
||||
};
|
||||
|
||||
// build an individual input element based on the schema
|
||||
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
const field = useFieldData(nodeId, fieldName);
|
||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
|
||||
const InputFieldRenderer = (props: InputFieldProps) => {
|
||||
const { nodeData, nodeTemplate, field, fieldTemplate } = props;
|
||||
const { type } = field;
|
||||
|
||||
if (fieldTemplate?.fieldKind === 'output') {
|
||||
return <Box p={2}>Output field in input: {field?.type}</Box>;
|
||||
}
|
||||
|
||||
if (field?.type === 'string' && fieldTemplate?.type === 'string') {
|
||||
if (type === 'string' && fieldTemplate.type === 'string') {
|
||||
return (
|
||||
<StringInputField
|
||||
nodeId={nodeId}
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'boolean' && fieldTemplate?.type === 'boolean') {
|
||||
if (type === 'boolean' && fieldTemplate.type === 'boolean') {
|
||||
return (
|
||||
<BooleanInputField
|
||||
nodeId={nodeId}
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
@ -61,45 +63,46 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
}
|
||||
|
||||
if (
|
||||
(field?.type === 'integer' && fieldTemplate?.type === 'integer') ||
|
||||
(field?.type === 'float' && fieldTemplate?.type === 'float')
|
||||
(type === 'integer' && fieldTemplate.type === 'integer') ||
|
||||
(type === 'float' && fieldTemplate.type === 'float')
|
||||
) {
|
||||
return (
|
||||
<NumberInputField
|
||||
nodeId={nodeId}
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'enum' && fieldTemplate?.type === 'enum') {
|
||||
if (type === 'enum' && fieldTemplate.type === 'enum') {
|
||||
return (
|
||||
<EnumInputField
|
||||
nodeId={nodeId}
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'ImageField' && fieldTemplate?.type === 'ImageField') {
|
||||
if (type === 'ImageField' && fieldTemplate.type === 'ImageField') {
|
||||
return (
|
||||
<ImageInputField
|
||||
nodeId={nodeId}
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'LatentsField' &&
|
||||
fieldTemplate?.type === 'LatentsField'
|
||||
) {
|
||||
if (type === 'LatentsField' && fieldTemplate.type === 'LatentsField') {
|
||||
return (
|
||||
<LatentsInputField
|
||||
nodeId={nodeId}
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
@ -107,68 +110,68 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'ConditioningField' &&
|
||||
fieldTemplate?.type === 'ConditioningField'
|
||||
type === 'ConditioningField' &&
|
||||
fieldTemplate.type === 'ConditioningField'
|
||||
) {
|
||||
return (
|
||||
<ConditioningInputField
|
||||
nodeId={nodeId}
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'UNetField' && fieldTemplate?.type === 'UNetField') {
|
||||
if (type === 'UNetField' && fieldTemplate.type === 'UNetField') {
|
||||
return (
|
||||
<UnetInputField
|
||||
nodeId={nodeId}
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'ClipField' && fieldTemplate?.type === 'ClipField') {
|
||||
if (type === 'ClipField' && fieldTemplate.type === 'ClipField') {
|
||||
return (
|
||||
<ClipInputField
|
||||
nodeId={nodeId}
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'VaeField' && fieldTemplate?.type === 'VaeField') {
|
||||
if (type === 'VaeField' && fieldTemplate.type === 'VaeField') {
|
||||
return (
|
||||
<VaeInputField
|
||||
nodeId={nodeId}
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'ControlField' &&
|
||||
fieldTemplate?.type === 'ControlField'
|
||||
) {
|
||||
if (type === 'ControlField' && fieldTemplate.type === 'ControlField') {
|
||||
return (
|
||||
<ControlInputField
|
||||
nodeId={nodeId}
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'MainModelField' &&
|
||||
fieldTemplate?.type === 'MainModelField'
|
||||
) {
|
||||
if (type === 'MainModelField' && fieldTemplate.type === 'MainModelField') {
|
||||
return (
|
||||
<MainModelInputField
|
||||
nodeId={nodeId}
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
@ -176,38 +179,35 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'SDXLRefinerModelField' &&
|
||||
fieldTemplate?.type === 'SDXLRefinerModelField'
|
||||
type === 'SDXLRefinerModelField' &&
|
||||
fieldTemplate.type === 'SDXLRefinerModelField'
|
||||
) {
|
||||
return (
|
||||
<RefinerModelInputField
|
||||
nodeId={nodeId}
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'VaeModelField' &&
|
||||
fieldTemplate?.type === 'VaeModelField'
|
||||
) {
|
||||
if (type === 'VaeModelField' && fieldTemplate.type === 'VaeModelField') {
|
||||
return (
|
||||
<VaeModelInputField
|
||||
nodeId={nodeId}
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'LoRAModelField' &&
|
||||
fieldTemplate?.type === 'LoRAModelField'
|
||||
) {
|
||||
if (type === 'LoRAModelField' && fieldTemplate.type === 'LoRAModelField') {
|
||||
return (
|
||||
<LoRAModelInputField
|
||||
nodeId={nodeId}
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
@ -215,58 +215,57 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'ControlNetModelField' &&
|
||||
fieldTemplate?.type === 'ControlNetModelField'
|
||||
type === 'ControlNetModelField' &&
|
||||
fieldTemplate.type === 'ControlNetModelField'
|
||||
) {
|
||||
return (
|
||||
<ControlNetModelInputField
|
||||
nodeId={nodeId}
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'Collection' && fieldTemplate?.type === 'Collection') {
|
||||
if (type === 'Collection' && fieldTemplate.type === 'Collection') {
|
||||
return (
|
||||
<CollectionInputField
|
||||
nodeId={nodeId}
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'CollectionItem' &&
|
||||
fieldTemplate?.type === 'CollectionItem'
|
||||
) {
|
||||
if (type === 'CollectionItem' && fieldTemplate.type === 'CollectionItem') {
|
||||
return (
|
||||
<CollectionItemInputField
|
||||
nodeId={nodeId}
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
|
||||
if (type === 'ColorField' && fieldTemplate.type === 'ColorField') {
|
||||
return (
|
||||
<ColorInputField
|
||||
nodeId={nodeId}
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'ImageCollection' &&
|
||||
fieldTemplate?.type === 'ImageCollection'
|
||||
) {
|
||||
if (type === 'ImageCollection' && fieldTemplate.type === 'ImageCollection') {
|
||||
return (
|
||||
<ImageCollectionInputField
|
||||
nodeId={nodeId}
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
@ -274,19 +273,20 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'SDXLMainModelField' &&
|
||||
fieldTemplate?.type === 'SDXLMainModelField'
|
||||
type === 'SDXLMainModelField' &&
|
||||
fieldTemplate.type === 'SDXLMainModelField'
|
||||
) {
|
||||
return (
|
||||
<SDXLMainModelInputField
|
||||
nodeId={nodeId}
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return <Box p={2}>Unknown field type: {field?.type}</Box>;
|
||||
return <Box p={2}>Unknown field type: {type}</Box>;
|
||||
};
|
||||
|
||||
export default memo(InputFieldRenderer);
|
||||
|
@ -1,16 +1,39 @@
|
||||
import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
||||
import {
|
||||
InputFieldTemplate,
|
||||
InputFieldValue,
|
||||
InvocationNodeData,
|
||||
InvocationTemplate,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
import FieldTitle from './FieldTitle';
|
||||
import FieldTooltipContent from './FieldTooltipContent';
|
||||
import InputFieldRenderer from './InputFieldRenderer';
|
||||
|
||||
type Props = {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
nodeData: InvocationNodeData;
|
||||
nodeTemplate: InvocationTemplate;
|
||||
field: InputFieldValue;
|
||||
fieldTemplate: InputFieldTemplate;
|
||||
};
|
||||
|
||||
const LinearViewField = ({ nodeId, fieldName }: Props) => {
|
||||
const LinearViewField = ({
|
||||
nodeData,
|
||||
nodeTemplate,
|
||||
field,
|
||||
fieldTemplate,
|
||||
}: Props) => {
|
||||
// const dispatch = useAppDispatch();
|
||||
// const handleRemoveField = useCallback(() => {
|
||||
// dispatch(
|
||||
// workflowExposedFieldRemoved({
|
||||
// nodeId: nodeData.id,
|
||||
// fieldName: field.name,
|
||||
// })
|
||||
// );
|
||||
// }, [dispatch, field.name, nodeData.id]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
layerStyle="second"
|
||||
@ -25,9 +48,10 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
|
||||
<Tooltip
|
||||
label={
|
||||
<FieldTooltipContent
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
kind="input"
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
}
|
||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||
@ -42,10 +66,20 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
|
||||
mb: 0,
|
||||
}}
|
||||
>
|
||||
<FieldTitle nodeId={nodeId} fieldName={fieldName} kind="input" />
|
||||
<FieldTitle
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
</FormLabel>
|
||||
</Tooltip>
|
||||
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
|
||||
<InputFieldRenderer
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
);
|
||||
|
@ -6,19 +6,25 @@ import {
|
||||
Tooltip,
|
||||
} from '@chakra-ui/react';
|
||||
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
|
||||
import { useFieldTemplate } from 'features/nodes/hooks/useNodeData';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
||||
import { PropsWithChildren, memo } from 'react';
|
||||
import {
|
||||
InvocationNodeData,
|
||||
InvocationTemplate,
|
||||
OutputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { PropsWithChildren, useMemo } from 'react';
|
||||
import { NodeProps } from 'reactflow';
|
||||
import FieldHandle from './FieldHandle';
|
||||
import FieldTooltipContent from './FieldTooltipContent';
|
||||
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
nodeProps: NodeProps<InvocationNodeData>;
|
||||
nodeTemplate: InvocationTemplate;
|
||||
field: OutputFieldValue;
|
||||
}
|
||||
|
||||
const OutputField = ({ nodeId, fieldName }: Props) => {
|
||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'output');
|
||||
const OutputField = (props: Props) => {
|
||||
const { nodeTemplate, nodeProps, field } = props;
|
||||
|
||||
const {
|
||||
isConnected,
|
||||
@ -26,15 +32,20 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
|
||||
isConnectionStartField,
|
||||
connectionError,
|
||||
shouldDim,
|
||||
} = useConnectionState({ nodeId, fieldName, kind: 'output' });
|
||||
} = useConnectionState({ nodeId: nodeProps.data.id, field, kind: 'output' });
|
||||
|
||||
if (fieldTemplate?.fieldKind !== 'output') {
|
||||
const fieldTemplate = useMemo(
|
||||
() => nodeTemplate.outputs[field.name],
|
||||
[field.name, nodeTemplate]
|
||||
);
|
||||
|
||||
if (!fieldTemplate) {
|
||||
return (
|
||||
<OutputFieldWrapper shouldDim={shouldDim}>
|
||||
<FormControl
|
||||
sx={{ color: 'error.400', textAlign: 'right', fontSize: 'sm' }}
|
||||
>
|
||||
Unknown output: {fieldName}
|
||||
Unknown output: {field.name}
|
||||
</FormControl>
|
||||
</OutputFieldWrapper>
|
||||
);
|
||||
@ -46,9 +57,10 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
|
||||
<Tooltip
|
||||
label={
|
||||
<FieldTooltipContent
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
kind="output"
|
||||
nodeData={nodeProps.data}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
}
|
||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||
@ -63,6 +75,9 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
<FieldHandle
|
||||
nodeProps={nodeProps}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
handleType="source"
|
||||
isConnectionInProgress={isConnectionInProgress}
|
||||
@ -73,28 +88,27 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(OutputField);
|
||||
export default OutputField;
|
||||
|
||||
type OutputFieldWrapperProps = PropsWithChildren<{
|
||||
shouldDim: boolean;
|
||||
}>;
|
||||
|
||||
const OutputFieldWrapper = memo(
|
||||
({ shouldDim, children }: OutputFieldWrapperProps) => (
|
||||
<Flex
|
||||
sx={{
|
||||
position: 'relative',
|
||||
minH: 8,
|
||||
py: 0.5,
|
||||
alignItems: 'center',
|
||||
opacity: shouldDim ? 0.5 : 1,
|
||||
transitionProperty: 'opacity',
|
||||
transitionDuration: '0.1s',
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</Flex>
|
||||
)
|
||||
const OutputFieldWrapper = ({
|
||||
shouldDim,
|
||||
children,
|
||||
}: OutputFieldWrapperProps) => (
|
||||
<Flex
|
||||
sx={{
|
||||
position: 'relative',
|
||||
minH: 8,
|
||||
py: 0.5,
|
||||
alignItems: 'center',
|
||||
opacity: shouldDim ? 0.5 : 1,
|
||||
transitionProperty: 'opacity',
|
||||
transitionDuration: '0.1s',
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</Flex>
|
||||
);
|
||||
|
||||
OutputFieldWrapper.displayName = 'OutputFieldWrapper';
|
||||
|
@ -11,7 +11,8 @@ import { FieldComponentProps } from './types';
|
||||
const BooleanInputFieldComponent = (
|
||||
props: FieldComponentProps<BooleanInputFieldValue, BooleanInputFieldTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const { nodeData, field } = props;
|
||||
const nodeId = nodeData.id;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
|
@ -11,7 +11,8 @@ import { FieldComponentProps } from './types';
|
||||
const ColorInputFieldComponent = (
|
||||
props: FieldComponentProps<ColorInputFieldValue, ColorInputFieldTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const { nodeData, field } = props;
|
||||
const nodeId = nodeData.id;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
|
@ -19,7 +19,8 @@ const ControlNetModelInputFieldComponent = (
|
||||
ControlNetModelInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const { nodeData, field } = props;
|
||||
const nodeId = nodeData.id;
|
||||
const controlNetModel = field.value;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
|
@ -11,7 +11,8 @@ import { FieldComponentProps } from './types';
|
||||
const EnumInputFieldComponent = (
|
||||
props: FieldComponentProps<EnumInputFieldValue, EnumInputFieldTemplate>
|
||||
) => {
|
||||
const { nodeId, field, fieldTemplate } = props;
|
||||
const { nodeData, field, fieldTemplate } = props;
|
||||
const nodeId = nodeData.id;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
|
@ -19,7 +19,8 @@ const ImageCollectionInputFieldComponent = (
|
||||
ImageCollectionInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const { nodeData, field } = props;
|
||||
const nodeId = nodeData.id;
|
||||
|
||||
// const dispatch = useAppDispatch();
|
||||
|
||||
|
@ -21,7 +21,8 @@ import { FieldComponentProps } from './types';
|
||||
const ImageInputFieldComponent = (
|
||||
props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const { nodeData, field } = props;
|
||||
const nodeId = nodeData.id;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { currentData: imageDTO } = useGetImageDTOQuery(
|
||||
|
@ -21,7 +21,8 @@ const LoRAModelInputFieldComponent = (
|
||||
LoRAModelInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const { nodeData, field } = props;
|
||||
const nodeId = nodeData.id;
|
||||
const lora = field.value;
|
||||
const dispatch = useAppDispatch();
|
||||
const { data: loraModels } = useGetLoRAModelsQuery();
|
||||
|
@ -26,7 +26,8 @@ const MainModelInputFieldComponent = (
|
||||
MainModelInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const { nodeData, field } = props;
|
||||
const nodeId = nodeData.id;
|
||||
const dispatch = useAppDispatch();
|
||||
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
||||
|
||||
|
@ -23,7 +23,8 @@ const NumberInputFieldComponent = (
|
||||
IntegerInputFieldTemplate | FloatInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field, fieldTemplate } = props;
|
||||
const { nodeData, field, fieldTemplate } = props;
|
||||
const nodeId = nodeData.id;
|
||||
const dispatch = useAppDispatch();
|
||||
const [valueAsString, setValueAsString] = useState<string>(
|
||||
String(field.value)
|
||||
|
@ -24,7 +24,8 @@ const RefinerModelInputFieldComponent = (
|
||||
SDXLRefinerModelInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const { nodeData, field } = props;
|
||||
const nodeId = nodeData.id;
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
||||
|
@ -27,7 +27,8 @@ const ModelInputFieldComponent = (
|
||||
SDXLMainModelInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const { nodeData, field } = props;
|
||||
const nodeId = nodeData.id;
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
||||
|
@ -12,7 +12,8 @@ import { FieldComponentProps } from './types';
|
||||
const StringInputFieldComponent = (
|
||||
props: FieldComponentProps<StringInputFieldValue, StringInputFieldTemplate>
|
||||
) => {
|
||||
const { nodeId, field, fieldTemplate } = props;
|
||||
const { nodeData, field, fieldTemplate } = props;
|
||||
const nodeId = nodeData.id;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleValueChanged = useCallback(
|
||||
|
@ -20,7 +20,8 @@ const VaeModelInputFieldComponent = (
|
||||
VaeModelInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const { nodeData, field } = props;
|
||||
const nodeId = nodeData.id;
|
||||
const vae = field.value;
|
||||
const dispatch = useAppDispatch();
|
||||
const { data: vaeModels } = useGetVaeModelsQuery();
|
||||
|
@ -1,13 +1,16 @@
|
||||
import {
|
||||
InputFieldTemplate,
|
||||
InputFieldValue,
|
||||
InvocationNodeData,
|
||||
InvocationTemplate,
|
||||
} from 'features/nodes/types/types';
|
||||
|
||||
export type FieldComponentProps<
|
||||
V extends InputFieldValue,
|
||||
T extends InputFieldTemplate
|
||||
> = {
|
||||
nodeId: string;
|
||||
nodeData: InvocationNodeData;
|
||||
nodeTemplate: InvocationTemplate;
|
||||
field: V;
|
||||
fieldTemplate: T;
|
||||
};
|
||||
|
@ -55,11 +55,7 @@ const CurrentImageNode = (props: NodeProps) => {
|
||||
export default memo(CurrentImageNode);
|
||||
|
||||
const Wrapper = (props: PropsWithChildren<{ nodeProps: NodeProps }>) => (
|
||||
<NodeWrapper
|
||||
nodeId={props.nodeProps.data.id}
|
||||
selected={props.nodeProps.selected}
|
||||
width={384}
|
||||
>
|
||||
<NodeWrapper nodeProps={props.nodeProps} width={384}>
|
||||
<Flex
|
||||
className={DRAG_HANDLE_CLASSNAME}
|
||||
sx={{
|
||||
|
@ -1,6 +1,5 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { makeTemplateSelector } from 'features/nodes/store/util/makeTemplateSelector';
|
||||
import { InvocationNodeData } from 'features/nodes/types/types';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { NodeProps } from 'reactflow';
|
||||
@ -8,40 +7,18 @@ import InvocationNode from '../Invocation/InvocationNode';
|
||||
import UnknownNodeFallback from '../Invocation/UnknownNodeFallback';
|
||||
|
||||
const InvocationNodeWrapper = (props: NodeProps<InvocationNodeData>) => {
|
||||
const { data, selected } = props;
|
||||
const { id: nodeId, type, isOpen, label } = data;
|
||||
const { data } = props;
|
||||
const { type } = data;
|
||||
|
||||
const hasTemplateSelector = useMemo(
|
||||
() =>
|
||||
createSelector(stateSelector, ({ nodes }) =>
|
||||
Boolean(nodes.nodeTemplates[type])
|
||||
),
|
||||
[type]
|
||||
);
|
||||
const templateSelector = useMemo(() => makeTemplateSelector(type), [type]);
|
||||
|
||||
const nodeTemplate = useAppSelector(hasTemplateSelector);
|
||||
const nodeTemplate = useAppSelector(templateSelector);
|
||||
|
||||
if (!nodeTemplate) {
|
||||
return (
|
||||
<UnknownNodeFallback
|
||||
nodeId={nodeId}
|
||||
isOpen={isOpen}
|
||||
label={label}
|
||||
type={type}
|
||||
selected={selected}
|
||||
/>
|
||||
);
|
||||
return <UnknownNodeFallback nodeProps={props} />;
|
||||
}
|
||||
|
||||
return (
|
||||
<InvocationNode
|
||||
nodeId={nodeId}
|
||||
isOpen={isOpen}
|
||||
label={label}
|
||||
type={type}
|
||||
selected={selected}
|
||||
/>
|
||||
);
|
||||
return <InvocationNode nodeProps={props} nodeTemplate={nodeTemplate} />;
|
||||
};
|
||||
|
||||
export default memo(InvocationNodeWrapper);
|
||||
|
@ -10,7 +10,7 @@ import NodeTitle from '../Invocation/NodeTitle';
|
||||
import NodeWrapper from '../Invocation/NodeWrapper';
|
||||
|
||||
const NotesNode = (props: NodeProps<NotesNodeData>) => {
|
||||
const { id: nodeId, data, selected } = props;
|
||||
const { id: nodeId, data } = props;
|
||||
const { notes, isOpen } = data;
|
||||
const dispatch = useAppDispatch();
|
||||
const handleChange = useCallback(
|
||||
@ -21,7 +21,7 @@ const NotesNode = (props: NodeProps<NotesNodeData>) => {
|
||||
);
|
||||
|
||||
return (
|
||||
<NodeWrapper nodeId={nodeId} selected={selected}>
|
||||
<NodeWrapper nodeProps={props}>
|
||||
<Flex
|
||||
layerStyle="nodeHeader"
|
||||
sx={{
|
||||
@ -32,8 +32,8 @@ const NotesNode = (props: NodeProps<NotesNodeData>) => {
|
||||
h: 8,
|
||||
}}
|
||||
>
|
||||
<NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
|
||||
<NodeTitle nodeId={nodeId} title="Notes" />
|
||||
<NodeCollapseButton nodeProps={props} />
|
||||
<NodeTitle nodeData={props.data} title="Notes" />
|
||||
<Box minW={8} />
|
||||
</Flex>
|
||||
{isOpen && (
|
||||
|
@ -6,11 +6,39 @@ import {
|
||||
TabPanels,
|
||||
Tabs,
|
||||
} from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import ImageMetadataJSON from 'features/gallery/components/ImageMetadataViewer/ImageMetadataJSON';
|
||||
import { memo } from 'react';
|
||||
import NodeDataInspector from './NodeDataInspector';
|
||||
import NodeTemplateInspector from './NodeTemplateInspector';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const lastSelectedNodeId =
|
||||
nodes.selectedNodes[nodes.selectedNodes.length - 1];
|
||||
|
||||
const lastSelectedNode = nodes.nodes.find(
|
||||
(node) => node.id === lastSelectedNodeId
|
||||
);
|
||||
|
||||
const lastSelectedNodeTemplate = lastSelectedNode
|
||||
? nodes.nodeTemplates[lastSelectedNode.data.type]
|
||||
: undefined;
|
||||
|
||||
return {
|
||||
node: lastSelectedNode,
|
||||
template: lastSelectedNodeTemplate,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const InspectorPanel = () => {
|
||||
const { node, template } = useAppSelector(selector);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
layerStyle="first"
|
||||
@ -32,10 +60,37 @@ const InspectorPanel = () => {
|
||||
|
||||
<TabPanels>
|
||||
<TabPanel>
|
||||
<NodeTemplateInspector />
|
||||
{template ? (
|
||||
<Flex
|
||||
sx={{
|
||||
flexDir: 'column',
|
||||
alignItems: 'flex-start',
|
||||
gap: 2,
|
||||
h: 'full',
|
||||
}}
|
||||
>
|
||||
<ImageMetadataJSON
|
||||
jsonObject={template}
|
||||
label="Node Template"
|
||||
/>
|
||||
</Flex>
|
||||
) : (
|
||||
<IAINoContentFallback
|
||||
label={
|
||||
node
|
||||
? 'No template found for selected node'
|
||||
: 'No node selected'
|
||||
}
|
||||
icon={null}
|
||||
/>
|
||||
)}
|
||||
</TabPanel>
|
||||
<TabPanel>
|
||||
<NodeDataInspector />
|
||||
{node ? (
|
||||
<ImageMetadataJSON jsonObject={node.data} label="Node Data" />
|
||||
) : (
|
||||
<IAINoContentFallback label="No node selected" icon={null} />
|
||||
)}
|
||||
</TabPanel>
|
||||
</TabPanels>
|
||||
</Tabs>
|
||||
|
@ -17,20 +17,20 @@ const selector = createSelector(
|
||||
);
|
||||
|
||||
return {
|
||||
data: lastSelectedNode?.data,
|
||||
node: lastSelectedNode,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const NodeDataInspector = () => {
|
||||
const { data } = useAppSelector(selector);
|
||||
const { node } = useAppSelector(selector);
|
||||
|
||||
if (!data) {
|
||||
return <IAINoContentFallback label="No node selected" icon={null} />;
|
||||
}
|
||||
|
||||
return <ImageMetadataJSON jsonObject={data} label="Node Data" />;
|
||||
return node ? (
|
||||
<ImageMetadataJSON jsonObject={node.data} label="Node Data" />
|
||||
) : (
|
||||
<IAINoContentFallback label="No node data" icon={null} />
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(NodeDataInspector);
|
||||
|
@ -1,40 +0,0 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import ImageMetadataJSON from 'features/gallery/components/ImageMetadataViewer/ImageMetadataJSON';
|
||||
import { memo } from 'react';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const lastSelectedNodeId =
|
||||
nodes.selectedNodes[nodes.selectedNodes.length - 1];
|
||||
|
||||
const lastSelectedNode = nodes.nodes.find(
|
||||
(node) => node.id === lastSelectedNodeId
|
||||
);
|
||||
|
||||
const lastSelectedNodeTemplate = lastSelectedNode
|
||||
? nodes.nodeTemplates[lastSelectedNode.data.type]
|
||||
: undefined;
|
||||
|
||||
return {
|
||||
template: lastSelectedNodeTemplate,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const NodeTemplateInspector = () => {
|
||||
const { template } = useAppSelector(selector);
|
||||
|
||||
if (!template) {
|
||||
return <IAINoContentFallback label="No node selected" icon={null} />;
|
||||
}
|
||||
|
||||
return <ImageMetadataJSON jsonObject={template} label="Node Template" />;
|
||||
};
|
||||
|
||||
export default memo(NodeTemplateInspector);
|
@ -6,6 +6,14 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIDroppable from 'common/components/IAIDroppable';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { AddFieldToLinearViewDropData } from 'features/dnd/types';
|
||||
import {
|
||||
InputFieldTemplate,
|
||||
InputFieldValue,
|
||||
InvocationNodeData,
|
||||
InvocationTemplate,
|
||||
isInvocationNode,
|
||||
} from 'features/nodes/types/types';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo } from 'react';
|
||||
import LinearViewField from '../../fields/LinearViewField';
|
||||
import ScrollableContent from '../ScrollableContent';
|
||||
@ -13,8 +21,41 @@ import ScrollableContent from '../ScrollableContent';
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const fields: {
|
||||
nodeData: InvocationNodeData;
|
||||
nodeTemplate: InvocationTemplate;
|
||||
field: InputFieldValue;
|
||||
fieldTemplate: InputFieldTemplate;
|
||||
}[] = [];
|
||||
const { exposedFields } = nodes.workflow;
|
||||
nodes.nodes.filter(isInvocationNode).forEach((node) => {
|
||||
const nodeTemplate = nodes.nodeTemplates[node.data.type];
|
||||
if (!nodeTemplate) {
|
||||
return;
|
||||
}
|
||||
forEach(node.data.inputs, (field) => {
|
||||
if (
|
||||
!exposedFields.some(
|
||||
(f) => f.nodeId === node.id && f.fieldName === field.name
|
||||
)
|
||||
) {
|
||||
return;
|
||||
}
|
||||
const fieldTemplate = nodeTemplate.inputs[field.name];
|
||||
if (!fieldTemplate) {
|
||||
return;
|
||||
}
|
||||
fields.push({
|
||||
nodeData: node.data,
|
||||
nodeTemplate,
|
||||
field,
|
||||
fieldTemplate,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
return {
|
||||
fields: nodes.workflow.exposedFields,
|
||||
fields,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
@ -48,11 +89,13 @@ const LinearTabContent = () => {
|
||||
}}
|
||||
>
|
||||
{fields.length ? (
|
||||
fields.map(({ nodeId, fieldName }) => (
|
||||
fields.map(({ nodeData, nodeTemplate, field, fieldTemplate }) => (
|
||||
<LinearViewField
|
||||
key={`${nodeId}-${fieldName}`}
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
key={field.id}
|
||||
nodeData={nodeData}
|
||||
nodeTemplate={nodeTemplate}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
))
|
||||
) : (
|
||||
|
@ -25,9 +25,7 @@ const ClearGraphButton = () => {
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
const cancelRef = useRef<HTMLButtonElement | null>(null);
|
||||
|
||||
const nodesCount = useAppSelector(
|
||||
(state: RootState) => state.nodes.nodes.length
|
||||
);
|
||||
const nodes = useAppSelector((state: RootState) => state.nodes.nodes);
|
||||
|
||||
const handleConfirmClear = useCallback(() => {
|
||||
dispatch(nodeEditorReset());
|
||||
@ -51,7 +49,7 @@ const ClearGraphButton = () => {
|
||||
tooltip={t('nodes.clearGraph')}
|
||||
aria-label={t('nodes.clearGraph')}
|
||||
onClick={onOpen}
|
||||
isDisabled={!nodesCount}
|
||||
isDisabled={nodes.length === 0}
|
||||
/>
|
||||
|
||||
<AlertDialog
|
||||
|
@ -8,7 +8,7 @@ import IAIIconButton, {
|
||||
import { selectIsReadyNodes } from 'features/nodes/store/selectors';
|
||||
import ProgressBar from 'features/system/components/ProgressBar';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaPlay } from 'react-icons/fa';
|
||||
@ -18,7 +18,7 @@ interface InvokeButton
|
||||
iconButton?: boolean;
|
||||
}
|
||||
|
||||
const NodeInvokeButton = (props: InvokeButton) => {
|
||||
export default function NodeInvokeButton(props: InvokeButton) {
|
||||
const { iconButton = false, ...rest } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
@ -92,6 +92,4 @@ const NodeInvokeButton = (props: InvokeButton) => {
|
||||
</Box>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(NodeInvokeButton);
|
||||
}
|
||||
|
@ -1,11 +1,11 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaSyncAlt } from 'react-icons/fa';
|
||||
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||
|
||||
const ReloadSchemaButton = () => {
|
||||
export default function ReloadSchemaButton() {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
@ -21,6 +21,4 @@ const ReloadSchemaButton = () => {
|
||||
onClick={handleReloadSchema}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ReloadSchemaButton);
|
||||
}
|
||||
|
@ -2,8 +2,8 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeIsConnectionValidSelector';
|
||||
import { InputFieldValue, OutputFieldValue } from 'features/nodes/types/types';
|
||||
import { useMemo } from 'react';
|
||||
import { useFieldType } from './useNodeData';
|
||||
|
||||
const selectIsConnectionInProgress = createSelector(
|
||||
stateSelector,
|
||||
@ -12,19 +12,23 @@ const selectIsConnectionInProgress = createSelector(
|
||||
nodes.connectionStartParams !== null
|
||||
);
|
||||
|
||||
export type UseConnectionStateProps = {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
kind: 'input' | 'output';
|
||||
};
|
||||
export type UseConnectionStateProps =
|
||||
| {
|
||||
nodeId: string;
|
||||
field: InputFieldValue;
|
||||
kind: 'input';
|
||||
}
|
||||
| {
|
||||
nodeId: string;
|
||||
field: OutputFieldValue;
|
||||
kind: 'output';
|
||||
};
|
||||
|
||||
export const useConnectionState = ({
|
||||
nodeId,
|
||||
fieldName,
|
||||
field,
|
||||
kind,
|
||||
}: UseConnectionStateProps) => {
|
||||
const fieldType = useFieldType(nodeId, fieldName, kind);
|
||||
|
||||
const selectIsConnected = useMemo(
|
||||
() =>
|
||||
createSelector(stateSelector, ({ nodes }) =>
|
||||
@ -33,23 +37,23 @@ export const useConnectionState = ({
|
||||
return (
|
||||
(kind === 'input' ? edge.target : edge.source) === nodeId &&
|
||||
(kind === 'input' ? edge.targetHandle : edge.sourceHandle) ===
|
||||
fieldName
|
||||
field.name
|
||||
);
|
||||
}).length
|
||||
)
|
||||
),
|
||||
[fieldName, kind, nodeId]
|
||||
[field.name, kind, nodeId]
|
||||
);
|
||||
|
||||
const selectConnectionError = useMemo(
|
||||
() =>
|
||||
makeConnectionErrorSelector(
|
||||
nodeId,
|
||||
fieldName,
|
||||
field.name,
|
||||
kind === 'input' ? 'target' : 'source',
|
||||
fieldType
|
||||
field.type
|
||||
),
|
||||
[nodeId, fieldName, kind, fieldType]
|
||||
[nodeId, field.name, field.type, kind]
|
||||
);
|
||||
|
||||
const selectIsConnectionStartField = useMemo(
|
||||
@ -57,12 +61,12 @@ export const useConnectionState = ({
|
||||
createSelector(stateSelector, ({ nodes }) =>
|
||||
Boolean(
|
||||
nodes.connectionStartParams?.nodeId === nodeId &&
|
||||
nodes.connectionStartParams?.handleId === fieldName &&
|
||||
nodes.connectionStartParams?.handleId === field.name &&
|
||||
nodes.connectionStartParams?.handleType ===
|
||||
{ input: 'target', output: 'source' }[kind]
|
||||
)
|
||||
),
|
||||
[fieldName, kind, nodeId]
|
||||
[field.name, kind, nodeId]
|
||||
);
|
||||
|
||||
const isConnected = useAppSelector(selectIsConnected);
|
||||
|
@ -1,286 +0,0 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { map, some } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import { FOOTER_FIELDS, IMAGE_FIELDS } from '../types/constants';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
|
||||
const KIND_MAP = {
|
||||
input: 'inputs' as const,
|
||||
output: 'outputs' as const,
|
||||
};
|
||||
|
||||
export const useNodeTemplate = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
|
||||
return nodeTemplate;
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const nodeTemplate = useAppSelector(selector);
|
||||
|
||||
return nodeTemplate;
|
||||
};
|
||||
|
||||
export const useNodeData = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
return node?.data;
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const nodeData = useAppSelector(selector);
|
||||
|
||||
return nodeData;
|
||||
};
|
||||
|
||||
export const useFieldData = (nodeId: string, fieldName: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
return node?.data.inputs[fieldName];
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
||||
const fieldData = useAppSelector(selector);
|
||||
|
||||
return fieldData;
|
||||
};
|
||||
|
||||
export const useFieldType = (
|
||||
nodeId: string,
|
||||
fieldName: string,
|
||||
kind: 'input' | 'output'
|
||||
) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
return node?.data[KIND_MAP[kind]][fieldName]?.type;
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[fieldName, kind, nodeId]
|
||||
);
|
||||
|
||||
const fieldType = useAppSelector(selector);
|
||||
|
||||
return fieldType;
|
||||
};
|
||||
|
||||
export const useFieldNames = (nodeId: string, kind: 'input' | 'output') => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return [];
|
||||
}
|
||||
return map(node.data[KIND_MAP[kind]], (field) => field.name).filter(
|
||||
(fieldName) => fieldName !== 'is_intermediate'
|
||||
);
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[kind, nodeId]
|
||||
);
|
||||
|
||||
const fieldNames = useAppSelector(selector);
|
||||
return fieldNames;
|
||||
};
|
||||
|
||||
export const useWithFooter = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
return some(node.data.outputs, (output) =>
|
||||
FOOTER_FIELDS.includes(output.type)
|
||||
);
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const withFooter = useAppSelector(selector);
|
||||
return withFooter;
|
||||
};
|
||||
|
||||
export const useHasImageOutput = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
return some(node.data.outputs, (output) =>
|
||||
IMAGE_FIELDS.includes(output.type)
|
||||
);
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const hasImageOutput = useAppSelector(selector);
|
||||
return hasImageOutput;
|
||||
};
|
||||
|
||||
export const useIsIntermediate = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
return Boolean(node.data.inputs.is_intermediate?.value);
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const is_intermediate = useAppSelector(selector);
|
||||
return is_intermediate;
|
||||
};
|
||||
|
||||
export const useNodeLabel = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return node.data.label;
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const label = useAppSelector(selector);
|
||||
return label;
|
||||
};
|
||||
|
||||
export const useNodeTemplateTitle = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
const nodeTemplate = node
|
||||
? nodes.nodeTemplates[node.data.type]
|
||||
: undefined;
|
||||
|
||||
return nodeTemplate?.title;
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const title = useAppSelector(selector);
|
||||
return title;
|
||||
};
|
||||
|
||||
export const useFieldTemplate = (
|
||||
nodeId: string,
|
||||
fieldName: string,
|
||||
kind: 'input' | 'output'
|
||||
) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
|
||||
return nodeTemplate?.[KIND_MAP[kind]][fieldName];
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[fieldName, kind, nodeId]
|
||||
);
|
||||
|
||||
const fieldTemplate = useAppSelector(selector);
|
||||
|
||||
return fieldTemplate;
|
||||
};
|
||||
|
||||
export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
return Boolean(node?.data.inputs[fieldName]?.value);
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
||||
const doesFieldHaveValue = useAppSelector(selector);
|
||||
|
||||
return doesFieldHaveValue;
|
||||
};
|
@ -9,13 +9,9 @@ export const makeConnectionErrorSelector = (
|
||||
nodeId: string,
|
||||
fieldName: string,
|
||||
handleType: HandleType,
|
||||
fieldType?: FieldType
|
||||
fieldType: FieldType
|
||||
) =>
|
||||
createSelector(stateSelector, (state) => {
|
||||
if (!fieldType) {
|
||||
return 'No field type';
|
||||
}
|
||||
|
||||
const { currentConnectionFieldType, connectionStartParams, nodes, edges } =
|
||||
state.nodes;
|
||||
|
||||
|
@ -6,9 +6,6 @@ export const NODE_WIDTH = 320;
|
||||
export const NODE_MIN_WIDTH = 320;
|
||||
export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle';
|
||||
|
||||
export const IMAGE_FIELDS = ['ImageField', 'ImageCollection'];
|
||||
export const FOOTER_FIELDS = IMAGE_FIELDS;
|
||||
|
||||
export const COLLECTION_TYPES: FieldType[] = [
|
||||
'Collection',
|
||||
'IntegerCollection',
|
||||
|
@ -457,13 +457,12 @@ export type ColorInputFieldTemplate = InputFieldTemplateBase & {
|
||||
};
|
||||
|
||||
export const isInputFieldValue = (
|
||||
field?: InputFieldValue | OutputFieldValue
|
||||
): field is InputFieldValue => Boolean(field && field.fieldKind === 'input');
|
||||
field: InputFieldValue | OutputFieldValue
|
||||
): field is InputFieldValue => field.fieldKind === 'input';
|
||||
|
||||
export const isInputFieldTemplate = (
|
||||
fieldTemplate?: InputFieldTemplate | OutputFieldTemplate
|
||||
): fieldTemplate is InputFieldTemplate =>
|
||||
Boolean(fieldTemplate && fieldTemplate.fieldKind === 'input');
|
||||
fieldTemplate: InputFieldTemplate | OutputFieldTemplate
|
||||
): fieldTemplate is InputFieldTemplate => fieldTemplate.fieldKind === 'input';
|
||||
|
||||
/**
|
||||
* JANKY CUSTOMISATION OF OpenAPI SCHEMA TYPES
|
||||
@ -633,22 +632,20 @@ export type NodeData =
|
||||
|
||||
export const isInvocationNode = (
|
||||
node?: Node<NodeData>
|
||||
): node is Node<InvocationNodeData> =>
|
||||
Boolean(node && node.type === 'invocation');
|
||||
): node is Node<InvocationNodeData> => node?.type === 'invocation';
|
||||
|
||||
export const isInvocationNodeData = (
|
||||
node?: NodeData
|
||||
): node is InvocationNodeData =>
|
||||
Boolean(node && !['notes', 'current_image'].includes(node.type));
|
||||
!['notes', 'current_image'].includes(node?.type ?? '');
|
||||
|
||||
export const isNotesNode = (
|
||||
node?: Node<NodeData>
|
||||
): node is Node<NotesNodeData> => Boolean(node && node.type === 'notes');
|
||||
): node is Node<NotesNodeData> => node?.type === 'notes';
|
||||
|
||||
export const isProgressImageNode = (
|
||||
node?: Node<NodeData>
|
||||
): node is Node<CurrentImageNodeData> =>
|
||||
Boolean(node && node.type === 'current_image');
|
||||
): node is Node<CurrentImageNodeData> => node?.type === 'current_image';
|
||||
|
||||
export enum NodeStatus {
|
||||
PENDING,
|
||||
|
@ -32,7 +32,6 @@ import {
|
||||
MAIN_MODEL_LOADER,
|
||||
MASK_BLUR,
|
||||
MASK_COMBINE,
|
||||
MASK_EDGE,
|
||||
MASK_FROM_ALPHA,
|
||||
MASK_RESIZE_DOWN,
|
||||
MASK_RESIZE_UP,
|
||||
@ -41,10 +40,6 @@ import {
|
||||
POSITIVE_CONDITIONING,
|
||||
RANDOM_INT,
|
||||
RANGE_OF_SIZE,
|
||||
SEAM_FIX_DENOISE_LATENTS,
|
||||
SEAM_MASK_COMBINE,
|
||||
SEAM_MASK_RESIZE_DOWN,
|
||||
SEAM_MASK_RESIZE_UP,
|
||||
} from './constants';
|
||||
|
||||
/**
|
||||
@ -72,12 +67,6 @@ export const buildCanvasOutpaintGraph = (
|
||||
shouldUseCpuNoise,
|
||||
maskBlur,
|
||||
maskBlurMethod,
|
||||
seamSize,
|
||||
seamBlur,
|
||||
seamSteps,
|
||||
seamStrength,
|
||||
seamLowThreshold,
|
||||
seamHighThreshold,
|
||||
tileSize,
|
||||
infillMethod,
|
||||
clipSkip,
|
||||
@ -141,11 +130,6 @@ export const buildCanvasOutpaintGraph = (
|
||||
is_intermediate: true,
|
||||
mask2: canvasMaskImage,
|
||||
},
|
||||
[SEAM_MASK_COMBINE]: {
|
||||
type: 'mask_combine',
|
||||
id: MASK_COMBINE,
|
||||
is_intermediate: true,
|
||||
},
|
||||
[MASK_BLUR]: {
|
||||
type: 'img_blur',
|
||||
id: MASK_BLUR,
|
||||
@ -181,25 +165,6 @@ export const buildCanvasOutpaintGraph = (
|
||||
denoising_start: 1 - strength,
|
||||
denoising_end: 1,
|
||||
},
|
||||
[MASK_EDGE]: {
|
||||
type: 'mask_edge',
|
||||
id: MASK_EDGE,
|
||||
is_intermediate: true,
|
||||
edge_size: seamSize,
|
||||
edge_blur: seamBlur,
|
||||
low_threshold: seamLowThreshold,
|
||||
high_threshold: seamHighThreshold,
|
||||
},
|
||||
[SEAM_FIX_DENOISE_LATENTS]: {
|
||||
type: 'denoise_latents',
|
||||
id: SEAM_FIX_DENOISE_LATENTS,
|
||||
is_intermediate: true,
|
||||
steps: seamSteps,
|
||||
cfg_scale: cfg_scale,
|
||||
scheduler: scheduler,
|
||||
denoising_start: 1 - seamStrength,
|
||||
denoising_end: 1,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
@ -368,61 +333,10 @@ export const buildCanvasOutpaintGraph = (
|
||||
field: 'seed',
|
||||
},
|
||||
},
|
||||
// Seam Paint
|
||||
{
|
||||
source: {
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_FIX_DENOISE_LATENTS,
|
||||
field: 'unet',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_FIX_DENOISE_LATENTS,
|
||||
field: 'positive_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_FIX_DENOISE_LATENTS,
|
||||
field: 'negative_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NOISE,
|
||||
field: 'noise',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_FIX_DENOISE_LATENTS,
|
||||
field: 'noise',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: DENOISE_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_FIX_DENOISE_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
// Decode the result from Inpaint
|
||||
{
|
||||
source: {
|
||||
node_id: SEAM_FIX_DENOISE_LATENTS,
|
||||
node_id: DENOISE_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
@ -434,6 +348,7 @@ export const buildCanvasOutpaintGraph = (
|
||||
};
|
||||
|
||||
// Add Infill Nodes
|
||||
|
||||
if (infillMethod === 'patchmatch') {
|
||||
graph.nodes[INPAINT_INFILL] = {
|
||||
type: 'infill_patchmatch',
|
||||
@ -463,13 +378,6 @@ export const buildCanvasOutpaintGraph = (
|
||||
width: scaledWidth,
|
||||
height: scaledHeight,
|
||||
};
|
||||
graph.nodes[SEAM_MASK_RESIZE_UP] = {
|
||||
type: 'img_resize',
|
||||
id: SEAM_MASK_RESIZE_UP,
|
||||
is_intermediate: true,
|
||||
width: scaledWidth,
|
||||
height: scaledHeight,
|
||||
};
|
||||
graph.nodes[INPAINT_IMAGE_RESIZE_DOWN] = {
|
||||
type: 'img_resize',
|
||||
id: INPAINT_IMAGE_RESIZE_DOWN,
|
||||
@ -491,13 +399,6 @@ export const buildCanvasOutpaintGraph = (
|
||||
width: width,
|
||||
height: height,
|
||||
};
|
||||
graph.nodes[SEAM_MASK_RESIZE_DOWN] = {
|
||||
type: 'img_resize',
|
||||
id: SEAM_MASK_RESIZE_DOWN,
|
||||
is_intermediate: true,
|
||||
width: width,
|
||||
height: height,
|
||||
};
|
||||
|
||||
graph.nodes[NOISE] = {
|
||||
...(graph.nodes[NOISE] as NoiseInvocation),
|
||||
@ -539,57 +440,6 @@ export const buildCanvasOutpaintGraph = (
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
// Seam Paint Mask
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_FROM_ALPHA,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: MASK_EDGE,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_EDGE,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SEAM_MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_FIX_DENOISE_LATENTS,
|
||||
field: 'mask',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_BLUR,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_MASK_COMBINE,
|
||||
field: 'mask1',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SEAM_MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_MASK_COMBINE,
|
||||
field: 'mask2',
|
||||
},
|
||||
},
|
||||
// Resize Results Down
|
||||
{
|
||||
source: {
|
||||
@ -603,7 +453,7 @@ export const buildCanvasOutpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_RESIZE_UP,
|
||||
node_id: MASK_BLUR,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
@ -611,16 +461,6 @@ export const buildCanvasOutpaintGraph = (
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SEAM_MASK_COMBINE,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_MASK_RESIZE_DOWN,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: INPAINT_INFILL,
|
||||
@ -654,7 +494,7 @@ export const buildCanvasOutpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SEAM_MASK_RESIZE_DOWN,
|
||||
node_id: MASK_RESIZE_DOWN,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
@ -685,7 +525,7 @@ export const buildCanvasOutpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SEAM_MASK_RESIZE_DOWN,
|
||||
node_id: MASK_RESIZE_DOWN,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
@ -713,6 +553,7 @@ export const buildCanvasOutpaintGraph = (
|
||||
};
|
||||
graph.nodes[MASK_BLUR] = {
|
||||
...(graph.nodes[MASK_BLUR] as ImageBlurInvocation),
|
||||
image: canvasMaskImage,
|
||||
};
|
||||
|
||||
graph.edges.push(
|
||||
@ -727,47 +568,6 @@ export const buildCanvasOutpaintGraph = (
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
// Seam Paint Mask
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_FROM_ALPHA,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: MASK_EDGE,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_EDGE,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_FIX_DENOISE_LATENTS,
|
||||
field: 'mask',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_FROM_ALPHA,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_MASK_COMBINE,
|
||||
field: 'mask1',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_EDGE,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_MASK_COMBINE,
|
||||
field: 'mask2',
|
||||
},
|
||||
},
|
||||
// Color Correct The Inpainted Result
|
||||
{
|
||||
source: {
|
||||
@ -791,7 +591,7 @@ export const buildCanvasOutpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SEAM_MASK_COMBINE,
|
||||
node_id: MASK_BLUR,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
@ -822,7 +622,7 @@ export const buildCanvasOutpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SEAM_MASK_COMBINE,
|
||||
node_id: MASK_BLUR,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
|
@ -29,7 +29,6 @@ import {
|
||||
LATENTS_TO_IMAGE,
|
||||
MASK_BLUR,
|
||||
MASK_COMBINE,
|
||||
MASK_EDGE,
|
||||
MASK_FROM_ALPHA,
|
||||
MASK_RESIZE_DOWN,
|
||||
MASK_RESIZE_UP,
|
||||
@ -41,10 +40,6 @@ import {
|
||||
SDXL_CANVAS_OUTPAINT_GRAPH,
|
||||
SDXL_DENOISE_LATENTS,
|
||||
SDXL_MODEL_LOADER,
|
||||
SEAM_FIX_DENOISE_LATENTS,
|
||||
SEAM_MASK_COMBINE,
|
||||
SEAM_MASK_RESIZE_DOWN,
|
||||
SEAM_MASK_RESIZE_UP,
|
||||
} from './constants';
|
||||
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
|
||||
|
||||
@ -72,12 +67,6 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
shouldUseCpuNoise,
|
||||
maskBlur,
|
||||
maskBlurMethod,
|
||||
seamSize,
|
||||
seamBlur,
|
||||
seamSteps,
|
||||
seamStrength,
|
||||
seamLowThreshold,
|
||||
seamHighThreshold,
|
||||
tileSize,
|
||||
infillMethod,
|
||||
} = state.generation;
|
||||
@ -144,11 +133,6 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
is_intermediate: true,
|
||||
mask2: canvasMaskImage,
|
||||
},
|
||||
[SEAM_MASK_COMBINE]: {
|
||||
type: 'mask_combine',
|
||||
id: MASK_COMBINE,
|
||||
is_intermediate: true,
|
||||
},
|
||||
[MASK_BLUR]: {
|
||||
type: 'img_blur',
|
||||
id: MASK_BLUR,
|
||||
@ -186,25 +170,6 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
: 1 - strength,
|
||||
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
|
||||
},
|
||||
[MASK_EDGE]: {
|
||||
type: 'mask_edge',
|
||||
id: MASK_EDGE,
|
||||
is_intermediate: true,
|
||||
edge_size: seamSize,
|
||||
edge_blur: seamBlur,
|
||||
low_threshold: seamLowThreshold,
|
||||
high_threshold: seamHighThreshold,
|
||||
},
|
||||
[SEAM_FIX_DENOISE_LATENTS]: {
|
||||
type: 'denoise_latents',
|
||||
id: SEAM_FIX_DENOISE_LATENTS,
|
||||
is_intermediate: true,
|
||||
steps: seamSteps,
|
||||
cfg_scale: cfg_scale,
|
||||
scheduler: scheduler,
|
||||
denoising_start: 1 - seamStrength,
|
||||
denoising_end: 1,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
@ -382,61 +347,10 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
field: 'seed',
|
||||
},
|
||||
},
|
||||
// Seam Paint
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_FIX_DENOISE_LATENTS,
|
||||
field: 'unet',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_FIX_DENOISE_LATENTS,
|
||||
field: 'positive_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_FIX_DENOISE_LATENTS,
|
||||
field: 'negative_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NOISE,
|
||||
field: 'noise',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_FIX_DENOISE_LATENTS,
|
||||
field: 'noise',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_DENOISE_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_FIX_DENOISE_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
// Decode inpainted latents to image
|
||||
{
|
||||
source: {
|
||||
node_id: SEAM_FIX_DENOISE_LATENTS,
|
||||
node_id: SDXL_DENOISE_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
@ -478,13 +392,6 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
width: scaledWidth,
|
||||
height: scaledHeight,
|
||||
};
|
||||
graph.nodes[SEAM_MASK_RESIZE_UP] = {
|
||||
type: 'img_resize',
|
||||
id: SEAM_MASK_RESIZE_UP,
|
||||
is_intermediate: true,
|
||||
width: scaledWidth,
|
||||
height: scaledHeight,
|
||||
};
|
||||
graph.nodes[INPAINT_IMAGE_RESIZE_DOWN] = {
|
||||
type: 'img_resize',
|
||||
id: INPAINT_IMAGE_RESIZE_DOWN,
|
||||
@ -506,13 +413,6 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
width: width,
|
||||
height: height,
|
||||
};
|
||||
graph.nodes[SEAM_MASK_RESIZE_DOWN] = {
|
||||
type: 'img_resize',
|
||||
id: SEAM_MASK_RESIZE_DOWN,
|
||||
is_intermediate: true,
|
||||
width: width,
|
||||
height: height,
|
||||
};
|
||||
|
||||
graph.nodes[NOISE] = {
|
||||
...(graph.nodes[NOISE] as NoiseInvocation),
|
||||
@ -554,57 +454,6 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
// Seam Paint Mask
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_FROM_ALPHA,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: MASK_EDGE,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_EDGE,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SEAM_MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_FIX_DENOISE_LATENTS,
|
||||
field: 'mask',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_BLUR,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_MASK_COMBINE,
|
||||
field: 'mask1',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SEAM_MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_MASK_COMBINE,
|
||||
field: 'mask2',
|
||||
},
|
||||
},
|
||||
// Resize Results Down
|
||||
{
|
||||
source: {
|
||||
@ -618,7 +467,7 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_RESIZE_UP,
|
||||
node_id: MASK_BLUR,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
@ -626,16 +475,6 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SEAM_MASK_COMBINE,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_MASK_RESIZE_DOWN,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: INPAINT_INFILL,
|
||||
@ -669,7 +508,7 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SEAM_MASK_RESIZE_DOWN,
|
||||
node_id: MASK_RESIZE_DOWN,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
@ -700,7 +539,7 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SEAM_MASK_RESIZE_DOWN,
|
||||
node_id: MASK_RESIZE_DOWN,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
@ -728,6 +567,7 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
};
|
||||
graph.nodes[MASK_BLUR] = {
|
||||
...(graph.nodes[MASK_BLUR] as ImageBlurInvocation),
|
||||
image: canvasMaskImage,
|
||||
};
|
||||
|
||||
graph.edges.push(
|
||||
@ -742,47 +582,6 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
// Seam Paint Mask
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_FROM_ALPHA,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: MASK_EDGE,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_EDGE,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_FIX_DENOISE_LATENTS,
|
||||
field: 'mask',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_FROM_ALPHA,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_MASK_COMBINE,
|
||||
field: 'mask1',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_EDGE,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: SEAM_MASK_COMBINE,
|
||||
field: 'mask2',
|
||||
},
|
||||
},
|
||||
// Color Correct The Inpainted Result
|
||||
{
|
||||
source: {
|
||||
@ -806,7 +605,7 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SEAM_MASK_COMBINE,
|
||||
node_id: MASK_BLUR,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
@ -837,7 +636,7 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SEAM_MASK_COMBINE,
|
||||
node_id: MASK_BLUR,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
@ -870,7 +669,7 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
|
||||
// Add Refiner if enabled
|
||||
if (shouldUseSDXLRefiner) {
|
||||
addSDXLRefinerToGraph(state, graph, SEAM_FIX_DENOISE_LATENTS);
|
||||
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
}
|
||||
|
||||
// optionally add custom VAE
|
||||
|
@ -18,6 +18,8 @@ export const IMAGE_TO_LATENTS = 'image_to_latents';
|
||||
export const LATENTS_TO_LATENTS = 'latents_to_latents';
|
||||
export const RESIZE = 'resize_image';
|
||||
export const CANVAS_OUTPUT = 'canvas_output';
|
||||
export const INPAINT = 'inpaint';
|
||||
export const INPAINT_SEAM_FIX = 'inpaint_seam_fix';
|
||||
export const INPAINT_IMAGE = 'inpaint_image';
|
||||
export const SCALED_INPAINT_IMAGE = 'scaled_inpaint_image';
|
||||
export const INPAINT_IMAGE_RESIZE_UP = 'inpaint_image_resize_up';
|
||||
@ -25,14 +27,10 @@ export const INPAINT_IMAGE_RESIZE_DOWN = 'inpaint_image_resize_down';
|
||||
export const INPAINT_INFILL = 'inpaint_infill';
|
||||
export const INPAINT_INFILL_RESIZE_DOWN = 'inpaint_infill_resize_down';
|
||||
export const INPAINT_FINAL_IMAGE = 'inpaint_final_image';
|
||||
export const SEAM_FIX_DENOISE_LATENTS = 'seam_fix_denoise_latents';
|
||||
export const MASK_FROM_ALPHA = 'tomask';
|
||||
export const MASK_EDGE = 'mask_edge';
|
||||
export const MASK_BLUR = 'mask_blur';
|
||||
export const MASK_COMBINE = 'mask_combine';
|
||||
export const SEAM_MASK_COMBINE = 'seam_mask_combine';
|
||||
export const SEAM_MASK_RESIZE_UP = 'seam_mask_resize_up';
|
||||
export const SEAM_MASK_RESIZE_DOWN = 'seam_mask_resize_down';
|
||||
export const MASK_RESIZE_UP = 'mask_resize_up';
|
||||
export const MASK_RESIZE_DOWN = 'mask_resize_down';
|
||||
export const COLOR_CORRECT = 'color_correct';
|
||||
|
@ -1,36 +0,0 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { setSeamBlur } from 'features/parameters/store/generationSlice';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const ParamSeamBlur = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const seamBlur = useAppSelector(
|
||||
(state: RootState) => state.generation.seamBlur
|
||||
);
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<IAISlider
|
||||
label={t('parameters.seamBlur')}
|
||||
min={0}
|
||||
max={64}
|
||||
step={8}
|
||||
sliderNumberInputProps={{ max: 512 }}
|
||||
value={seamBlur}
|
||||
onChange={(v) => {
|
||||
dispatch(setSeamBlur(v));
|
||||
}}
|
||||
withInput
|
||||
withSliderMarks
|
||||
withReset
|
||||
handleReset={() => {
|
||||
dispatch(setSeamBlur(8));
|
||||
}}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamSeamBlur);
|
@ -1,27 +0,0 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import IAICollapse from 'common/components/IAICollapse';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import ParamSeamBlur from './ParamSeamBlur';
|
||||
import ParamSeamSize from './ParamSeamSize';
|
||||
import ParamSeamSteps from './ParamSeamSteps';
|
||||
import ParamSeamStrength from './ParamSeamStrength';
|
||||
import ParamSeamThreshold from './ParamSeamThreshold';
|
||||
|
||||
const ParamSeamPaintingCollapse = () => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<IAICollapse label={t('parameters.seamPaintingHeader')}>
|
||||
<Flex sx={{ flexDirection: 'column', gap: 2, paddingBottom: 2 }}>
|
||||
<ParamSeamSize />
|
||||
<ParamSeamBlur />
|
||||
<ParamSeamSteps />
|
||||
<ParamSeamStrength />
|
||||
<ParamSeamThreshold />
|
||||
</Flex>
|
||||
</IAICollapse>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamSeamPaintingCollapse);
|
@ -1,36 +0,0 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { setSeamSize } from 'features/parameters/store/generationSlice';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const ParamSeamSize = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const seamSize = useAppSelector(
|
||||
(state: RootState) => state.generation.seamSize
|
||||
);
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<IAISlider
|
||||
label={t('parameters.seamSize')}
|
||||
min={0}
|
||||
max={128}
|
||||
step={8}
|
||||
sliderNumberInputProps={{ max: 512 }}
|
||||
value={seamSize}
|
||||
onChange={(v) => {
|
||||
dispatch(setSeamSize(v));
|
||||
}}
|
||||
withInput
|
||||
withSliderMarks
|
||||
withReset
|
||||
handleReset={() => {
|
||||
dispatch(setSeamSize(16));
|
||||
}}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamSeamSize);
|
@ -1,36 +0,0 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { setSeamSteps } from 'features/parameters/store/generationSlice';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const ParamSeamSteps = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const seamSteps = useAppSelector(
|
||||
(state: RootState) => state.generation.seamSteps
|
||||
);
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<IAISlider
|
||||
label={t('parameters.seamSteps')}
|
||||
min={0}
|
||||
max={100}
|
||||
step={1}
|
||||
sliderNumberInputProps={{ max: 999 }}
|
||||
value={seamSteps}
|
||||
onChange={(v) => {
|
||||
dispatch(setSeamSteps(v));
|
||||
}}
|
||||
withInput
|
||||
withSliderMarks
|
||||
withReset
|
||||
handleReset={() => {
|
||||
dispatch(setSeamSteps(20));
|
||||
}}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamSeamSteps);
|
@ -1,36 +0,0 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { setSeamStrength } from 'features/parameters/store/generationSlice';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const ParamSeamStrength = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const seamStrength = useAppSelector(
|
||||
(state: RootState) => state.generation.seamStrength
|
||||
);
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<IAISlider
|
||||
label={t('parameters.seamStrength')}
|
||||
min={0}
|
||||
max={1}
|
||||
step={0.01}
|
||||
sliderNumberInputProps={{ max: 999 }}
|
||||
value={seamStrength}
|
||||
onChange={(v) => {
|
||||
dispatch(setSeamStrength(v));
|
||||
}}
|
||||
withInput
|
||||
withSliderMarks
|
||||
withReset
|
||||
handleReset={() => {
|
||||
dispatch(setSeamStrength(0.7));
|
||||
}}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamSeamStrength);
|
@ -1,121 +0,0 @@
|
||||
import {
|
||||
FormControl,
|
||||
FormLabel,
|
||||
HStack,
|
||||
RangeSlider,
|
||||
RangeSliderFilledTrack,
|
||||
RangeSliderMark,
|
||||
RangeSliderThumb,
|
||||
RangeSliderTrack,
|
||||
Tooltip,
|
||||
} from '@chakra-ui/react';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import {
|
||||
setSeamHighThreshold,
|
||||
setSeamLowThreshold,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { BiReset } from 'react-icons/bi';
|
||||
|
||||
const ParamSeamThreshold = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const seamLowThreshold = useAppSelector(
|
||||
(state: RootState) => state.generation.seamLowThreshold
|
||||
);
|
||||
|
||||
const seamHighThreshold = useAppSelector(
|
||||
(state: RootState) => state.generation.seamHighThreshold
|
||||
);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleSeamThresholdChange = useCallback(
|
||||
(v: number[]) => {
|
||||
dispatch(setSeamLowThreshold(v[0] as number));
|
||||
dispatch(setSeamHighThreshold(v[1] as number));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleSeamThresholdReset = () => {
|
||||
dispatch(setSeamLowThreshold(100));
|
||||
dispatch(setSeamHighThreshold(200));
|
||||
};
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel>{t('parameters.seamThreshold')}</FormLabel>
|
||||
<HStack w="100%" gap={4} mt={-2}>
|
||||
<RangeSlider
|
||||
aria-label={[
|
||||
t('parameters.seamLowThreshold'),
|
||||
t('parameters.seamHighThreshold'),
|
||||
]}
|
||||
value={[seamLowThreshold, seamHighThreshold]}
|
||||
min={0}
|
||||
max={255}
|
||||
step={1}
|
||||
minStepsBetweenThumbs={1}
|
||||
onChange={handleSeamThresholdChange}
|
||||
>
|
||||
<RangeSliderTrack>
|
||||
<RangeSliderFilledTrack />
|
||||
</RangeSliderTrack>
|
||||
<Tooltip label={seamLowThreshold} placement="top" hasArrow>
|
||||
<RangeSliderThumb index={0} />
|
||||
</Tooltip>
|
||||
<Tooltip label={seamHighThreshold} placement="top" hasArrow>
|
||||
<RangeSliderThumb index={1} />
|
||||
</Tooltip>
|
||||
<RangeSliderMark
|
||||
value={0}
|
||||
sx={{
|
||||
insetInlineStart: '0 !important',
|
||||
insetInlineEnd: 'unset !important',
|
||||
}}
|
||||
>
|
||||
0
|
||||
</RangeSliderMark>
|
||||
<RangeSliderMark
|
||||
value={0.392}
|
||||
sx={{
|
||||
insetInlineStart: '38.4% !important',
|
||||
transform: 'translateX(-38.4%)',
|
||||
}}
|
||||
>
|
||||
100
|
||||
</RangeSliderMark>
|
||||
<RangeSliderMark
|
||||
value={0.784}
|
||||
sx={{
|
||||
insetInlineStart: '79.8% !important',
|
||||
transform: 'translateX(-79.8%)',
|
||||
}}
|
||||
>
|
||||
200
|
||||
</RangeSliderMark>
|
||||
<RangeSliderMark
|
||||
value={1}
|
||||
sx={{
|
||||
insetInlineStart: 'unset !important',
|
||||
insetInlineEnd: '0 !important',
|
||||
}}
|
||||
>
|
||||
255
|
||||
</RangeSliderMark>
|
||||
</RangeSlider>
|
||||
<IAIIconButton
|
||||
size="sm"
|
||||
aria-label={t('accessibility.reset')}
|
||||
tooltip={t('accessibility.reset')}
|
||||
icon={<BiReset />}
|
||||
onClick={handleSeamThresholdReset}
|
||||
/>
|
||||
</HStack>
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamSeamThreshold);
|
@ -37,12 +37,6 @@ export interface GenerationState {
|
||||
scheduler: SchedulerParam;
|
||||
maskBlur: number;
|
||||
maskBlurMethod: MaskBlurMethodParam;
|
||||
seamSize: number;
|
||||
seamBlur: number;
|
||||
seamSteps: number;
|
||||
seamStrength: StrengthParam;
|
||||
seamLowThreshold: number;
|
||||
seamHighThreshold: number;
|
||||
seed: SeedParam;
|
||||
seedWeights: string;
|
||||
shouldFitToWidthHeight: boolean;
|
||||
@ -80,12 +74,6 @@ export const initialGenerationState: GenerationState = {
|
||||
scheduler: 'euler',
|
||||
maskBlur: 16,
|
||||
maskBlurMethod: 'box',
|
||||
seamSize: 16,
|
||||
seamBlur: 8,
|
||||
seamSteps: 20,
|
||||
seamStrength: 0.7,
|
||||
seamLowThreshold: 100,
|
||||
seamHighThreshold: 200,
|
||||
seed: 0,
|
||||
seedWeights: '',
|
||||
shouldFitToWidthHeight: true,
|
||||
@ -212,24 +200,6 @@ export const generationSlice = createSlice({
|
||||
setMaskBlurMethod: (state, action: PayloadAction<MaskBlurMethodParam>) => {
|
||||
state.maskBlurMethod = action.payload;
|
||||
},
|
||||
setSeamSize: (state, action: PayloadAction<number>) => {
|
||||
state.seamSize = action.payload;
|
||||
},
|
||||
setSeamBlur: (state, action: PayloadAction<number>) => {
|
||||
state.seamBlur = action.payload;
|
||||
},
|
||||
setSeamSteps: (state, action: PayloadAction<number>) => {
|
||||
state.seamSteps = action.payload;
|
||||
},
|
||||
setSeamStrength: (state, action: PayloadAction<number>) => {
|
||||
state.seamStrength = action.payload;
|
||||
},
|
||||
setSeamLowThreshold: (state, action: PayloadAction<number>) => {
|
||||
state.seamLowThreshold = action.payload;
|
||||
},
|
||||
setSeamHighThreshold: (state, action: PayloadAction<number>) => {
|
||||
state.seamHighThreshold = action.payload;
|
||||
},
|
||||
setTileSize: (state, action: PayloadAction<number>) => {
|
||||
state.tileSize = action.payload;
|
||||
},
|
||||
@ -336,12 +306,6 @@ export const {
|
||||
setScheduler,
|
||||
setMaskBlur,
|
||||
setMaskBlurMethod,
|
||||
setSeamSize,
|
||||
setSeamBlur,
|
||||
setSeamSteps,
|
||||
setSeamStrength,
|
||||
setSeamLowThreshold,
|
||||
setSeamHighThreshold,
|
||||
setSeed,
|
||||
setSeedWeights,
|
||||
setShouldFitToWidthHeight,
|
||||
|
@ -2,7 +2,6 @@ import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/Para
|
||||
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
|
||||
import ParamInfillAndScalingCollapse from 'features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse';
|
||||
import ParamMaskAdjustmentCollapse from 'features/parameters/components/Parameters/Canvas/MaskAdjustment/ParamMaskAdjustmentCollapse';
|
||||
import ParamSeamPaintingCollapse from 'features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamPaintingCollapse';
|
||||
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
|
||||
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
|
||||
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
|
||||
@ -23,7 +22,6 @@ export default function SDXLUnifiedCanvasTabParameters() {
|
||||
<ParamNoiseCollapse />
|
||||
<ParamMaskAdjustmentCollapse />
|
||||
<ParamInfillAndScalingCollapse />
|
||||
<ParamSeamPaintingCollapse />
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
@ -6,7 +6,6 @@ import ParamControlNetCollapse from 'features/parameters/components/Parameters/C
|
||||
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
|
||||
// import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
|
||||
import ParamMaskAdjustmentCollapse from 'features/parameters/components/Parameters/Canvas/MaskAdjustment/ParamMaskAdjustmentCollapse';
|
||||
import ParamSeamPaintingCollapse from 'features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamPaintingCollapse';
|
||||
import ParamPromptArea from 'features/parameters/components/Parameters/Prompt/ParamPromptArea';
|
||||
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
|
||||
import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters';
|
||||
@ -24,7 +23,6 @@ const UnifiedCanvasParameters = () => {
|
||||
<ParamSymmetryCollapse />
|
||||
<ParamMaskAdjustmentCollapse />
|
||||
<ParamInfillAndScalingCollapse />
|
||||
<ParamSeamPaintingCollapse />
|
||||
<ParamAdvancedCollapse />
|
||||
</>
|
||||
);
|
||||
|
@ -3,7 +3,4 @@ import { UIState } from './uiTypes';
|
||||
/**
|
||||
* UI slice persist denylist
|
||||
*/
|
||||
export const uiPersistDenylist: (keyof UIState)[] = [
|
||||
'shouldShowImageDetails',
|
||||
'globalContextMenuCloseTrigger',
|
||||
];
|
||||
export const uiPersistDenylist: (keyof UIState)[] = ['shouldShowImageDetails'];
|
||||
|
@ -20,7 +20,6 @@ export const initialUIState: UIState = {
|
||||
shouldShowProgressInViewer: true,
|
||||
shouldShowEmbeddingPicker: false,
|
||||
favoriteSchedulers: [],
|
||||
globalContextMenuCloseTrigger: 0,
|
||||
};
|
||||
|
||||
export const uiSlice = createSlice({
|
||||
@ -97,9 +96,6 @@ export const uiSlice = createSlice({
|
||||
toggleEmbeddingPicker: (state) => {
|
||||
state.shouldShowEmbeddingPicker = !state.shouldShowEmbeddingPicker;
|
||||
},
|
||||
contextMenusClosed: (state) => {
|
||||
state.globalContextMenuCloseTrigger += 1;
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
builder.addCase(initialImageChanged, (state) => {
|
||||
@ -126,7 +122,6 @@ export const {
|
||||
setShouldShowProgressInViewer,
|
||||
favoriteSchedulersChanged,
|
||||
toggleEmbeddingPicker,
|
||||
contextMenusClosed,
|
||||
} = uiSlice.actions;
|
||||
|
||||
export default uiSlice.reducer;
|
||||
|
@ -26,5 +26,4 @@ export interface UIState {
|
||||
shouldShowProgressInViewer: boolean;
|
||||
shouldShowEmbeddingPicker: boolean;
|
||||
favoriteSchedulers: SchedulerParam[];
|
||||
globalContextMenuCloseTrigger: number;
|
||||
}
|
||||
|
@ -573,7 +573,7 @@ export type components = {
|
||||
file: Blob;
|
||||
};
|
||||
/**
|
||||
* Boolean Primitive Collection
|
||||
* Boolean Collection
|
||||
* @description A collection of boolean primitive values
|
||||
*/
|
||||
BooleanCollectionInvocation: {
|
||||
@ -619,7 +619,7 @@ export type components = {
|
||||
collection?: (boolean)[];
|
||||
};
|
||||
/**
|
||||
* Boolean Primitive
|
||||
* Boolean
|
||||
* @description A boolean primitive value
|
||||
*/
|
||||
BooleanInvocation: {
|
||||
@ -1002,7 +1002,7 @@ export type components = {
|
||||
clip?: components["schemas"]["ClipField"];
|
||||
};
|
||||
/**
|
||||
* Conditioning Primitive Collection
|
||||
* Conditioning Collection
|
||||
* @description A collection of conditioning tensor primitive values
|
||||
*/
|
||||
ConditioningCollectionInvocation: {
|
||||
@ -1770,7 +1770,7 @@ export type components = {
|
||||
field: string;
|
||||
};
|
||||
/**
|
||||
* Float Primitive Collection
|
||||
* Float Collection
|
||||
* @description A collection of float primitive values
|
||||
*/
|
||||
FloatCollectionInvocation: {
|
||||
@ -1816,7 +1816,7 @@ export type components = {
|
||||
collection?: (number)[];
|
||||
};
|
||||
/**
|
||||
* Float Primitive
|
||||
* Float
|
||||
* @description A float primitive value
|
||||
*/
|
||||
FloatInvocation: {
|
||||
@ -2161,7 +2161,7 @@ export type components = {
|
||||
channel?: "A" | "R" | "G" | "B";
|
||||
};
|
||||
/**
|
||||
* Image Primitive Collection
|
||||
* Image Collection
|
||||
* @description A collection of image primitive values
|
||||
*/
|
||||
ImageCollectionInvocation: {
|
||||
@ -3113,7 +3113,7 @@ export type components = {
|
||||
seed?: number;
|
||||
};
|
||||
/**
|
||||
* Integer Primitive Collection
|
||||
* Integer Collection
|
||||
* @description A collection of integer primitive values
|
||||
*/
|
||||
IntegerCollectionInvocation: {
|
||||
@ -3159,7 +3159,7 @@ export type components = {
|
||||
collection?: (number)[];
|
||||
};
|
||||
/**
|
||||
* Integer Primitive
|
||||
* Integer
|
||||
* @description An integer primitive value
|
||||
*/
|
||||
IntegerInvocation: {
|
||||
@ -3256,7 +3256,7 @@ export type components = {
|
||||
item?: unknown;
|
||||
};
|
||||
/**
|
||||
* Latents Primitive Collection
|
||||
* Latents Collection
|
||||
* @description A collection of latents tensor primitive values
|
||||
*/
|
||||
LatentsCollectionInvocation: {
|
||||
@ -5786,7 +5786,7 @@ export type components = {
|
||||
show_easing_plot?: boolean;
|
||||
};
|
||||
/**
|
||||
* String Primitive Collection
|
||||
* String Collection
|
||||
* @description A collection of string primitive values
|
||||
*/
|
||||
StringCollectionInvocation: {
|
||||
@ -5832,7 +5832,7 @@ export type components = {
|
||||
collection?: (string)[];
|
||||
};
|
||||
/**
|
||||
* String Primitive
|
||||
* String
|
||||
* @description A string primitive value
|
||||
*/
|
||||
StringInvocation: {
|
||||
@ -6193,6 +6193,24 @@ export type components = {
|
||||
ui_hidden: boolean;
|
||||
ui_type?: components["schemas"]["UIType"];
|
||||
};
|
||||
/**
|
||||
* StableDiffusion2ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* ControlNetModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
ControlNetModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusionXLModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusionOnnxModelFormat
|
||||
* @description An enumeration.
|
||||
@ -6205,24 +6223,6 @@ export type components = {
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* ControlNetModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
ControlNetModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusion2ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusionXLModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
|
||||
};
|
||||
responses: never;
|
||||
parameters: never;
|
||||
|
@ -3557,6 +3557,11 @@ eslint-plugin-react-hooks@^4.6.0:
|
||||
resolved "https://registry.yarnpkg.com/eslint-plugin-react-hooks/-/eslint-plugin-react-hooks-4.6.0.tgz#4c3e697ad95b77e93f8646aaa1630c1ba607edd3"
|
||||
integrity sha512-oFc7Itz9Qxh2x4gNHStv3BqJq54ExXmfC+a1NjAta66IAN87Wu0R/QArgIS9qKzX3dXKPI9H5crl9QchNMY9+g==
|
||||
|
||||
eslint-plugin-react-memo@^0.0.3:
|
||||
version "0.0.3"
|
||||
resolved "https://registry.yarnpkg.com/eslint-plugin-react-memo/-/eslint-plugin-react-memo-0.0.3.tgz#26542aa2eeabed37f354c64c6b4eabc07051cf78"
|
||||
integrity sha512-IZzLDZJF4V84XL9+v74ypDSts/hAQtNeYFZGc3wvdX+YgIw4pkn4GiXPJ6MNUccNTPYJqr89Nnvqo3rcesEBOQ==
|
||||
|
||||
eslint-plugin-react@^7.32.2:
|
||||
version "7.32.2"
|
||||
resolved "https://registry.yarnpkg.com/eslint-plugin-react/-/eslint-plugin-react-7.32.2.tgz#e71f21c7c265ebce01bcbc9d0955170c55571f10"
|
||||
|
Reference in New Issue
Block a user