Compare commits

...

36 Commits

Author SHA1 Message Date
86c11f9e27 make session_list api return a raw dict rather than pydantic object 2023-08-17 22:22:06 -04:00
832335998f Update 'monkeypatched' controlnet class (#4269)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [x] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [ ] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [ ] Yes
- [ ] No


## Description


## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Added/updated tests?

- [ ] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
Should be removed when added in diffusers
https://github.com/huggingface/diffusers/pull/4599
2023-08-17 15:49:54 -04:00
1102c12084 Merge branch 'main' into fix/sdxl_controlnet 2023-08-17 15:40:51 -04:00
b5cee7d20c blackify chore 2023-08-17 15:40:15 -04:00
89b82b3dc4 (feat): Add Seam Painting to Canvas (1.x, 2.x & SDXL w/ Refiner) (#4292)
## What type of PR is this? (check all applicable)

- [x] Feature

## Have you discussed this change with the InvokeAI team?
- [x] Yes
      
## Description

PR to add Seam Painting back to the Canvas.

## TODO Later

While the graph works as intended, it has become extremely large and
complex. I don't know if there's a simpler way to do this. Maybe there
is but there's soo many connections and visualizing the graph in my head
is extremely difficult. We might need to create some kind of tooling for
this. Coz it's going going to get crazier.

But well works for now.
2023-08-17 21:24:39 +12:00
8923201fdf Merge branch 'main' into seam-painting 2023-08-17 21:21:44 +12:00
226409107b Fix for Image Deletion issue 2023-08-17 17:18:11 +10:00
ae986bf873 Report RAM usage and RAM cache statistics after each generation (#4287)
## What type of PR is this? (check all applicable)

- [X] Feature

## Have you discussed this change with the InvokeAI team?
- [X] Yes

     
## Have you updated all relevant documentation?
- [X] Yes


## Description

This PR enhances the logging of performance statistics to include RAM
and model cache information. After each generation, the following will
be logged. The new information follows TOTAL GRAPH EXECUTION TIME.

```
[2023-08-15 21:55:39,010]::[InvokeAI]::INFO --> Graph stats: 2408dbec-50d0-44a3-bbc4-427037e3f7d4
[2023-08-15 21:55:39,010]::[InvokeAI]::INFO --> Node                 Calls    Seconds VRAM Used
[2023-08-15 21:55:39,010]::[InvokeAI]::INFO --> main_model_loader        1     0.004s     0.000G
[2023-08-15 21:55:39,010]::[InvokeAI]::INFO --> clip_skip                1     0.002s     0.000G
[2023-08-15 21:55:39,010]::[InvokeAI]::INFO --> compel                   2     2.706s     0.246G
[2023-08-15 21:55:39,010]::[InvokeAI]::INFO --> rand_int                 1     0.002s     0.244G
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO --> range_of_size            1     0.002s     0.244G
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO --> iterate                  1     0.002s     0.244G
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO --> metadata_accumulator     1     0.002s     0.244G
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO --> noise                    1     0.003s     0.244G
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO --> denoise_latents          1     2.429s     2.022G
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO --> l2i                      1     1.020s     1.858G
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME:    6.171s
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO --> RAM used by InvokeAI process: 4.50G (delta=0.10G)
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO --> RAM used to load models: 1.99G
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO --> VRAM in use: 0.303G
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO --> RAM cache statistics:
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO -->    Model cache hits: 2
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO -->    Model cache misses: 5
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO -->    Models cached: 5
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO -->    Models cleared from cache: 0
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO -->    Cache high water mark: 1.99/7.50G    
```

There may be a memory leak in InvokeAI. I'm seeing the process memory
usage increasing by about 100 MB with each generation as shown in the
example above.
2023-08-17 16:10:18 +12:00
daf75a1361 blackify 2023-08-16 21:47:29 -04:00
fe4b2d53ed Merge branch 'feat/collect-more-stats' of github.com:invoke-ai/InvokeAI into feat/collect-more-stats 2023-08-16 21:39:29 -04:00
c39f8b478b fix misplaced ram_used and ram_changed attributes 2023-08-16 21:39:18 -04:00
1f82d8013e Merge branch 'main' into feat/collect-more-stats 2023-08-16 18:51:17 -04:00
e373bfca54 fix several broken links in the installation index 2023-08-16 17:54:39 -04:00
2ca8611723 add +/- sign in front of RAM delta 2023-08-16 15:53:01 -04:00
b12cf315a8 Merge branch 'main' into feat/collect-more-stats 2023-08-16 09:19:33 -04:00
975586bb40 Merge branch 'main' into seam-painting 2023-08-17 01:05:42 +12:00
a7ba142ad9 feat(ui): set min zoom on nodes to 0.1 2023-08-16 23:04:36 +10:00
0d36bab6cc fix(ui): do not rerender top panel buttons 2023-08-16 23:04:36 +10:00
c2e7f62701 fix(ui): do not rerender edges 2023-08-16 23:04:36 +10:00
1f194e3688 chore(ui): lint 2023-08-16 23:04:36 +10:00
f9b8b5cff2 fix(ui): improve node rendering performance
Previously the editor was using prop-drilling node data and templates to get values deep into nodes. This ended up causing very noticeable performance degradation. For example, any text entry fields were super laggy.

Refactor the whole thing to use memoized selectors via hooks. The hooks are mostly very narrow, returning only the data needed.

Data objects are never passed down, only node id and field name - sometimes the field kind ('input' or 'output').

The end result is a *much* smoother node editor with very minimal rerenders.
2023-08-16 23:04:36 +10:00
f7c92e1eff fix(ui): disable awkward resize animation for <Flow /> 2023-08-16 23:04:36 +10:00
70b8c3dfea fix(ui): fix context menu on workflow editor
There is a tricky mouse event interaction between chakra's `useOutsideClick()` hook (used by chakra `<Menu />`) and reactflow. The hook doesn't work when you click the main reactflow area.

To get around this, I've used a dirty hack, copy-pasting the simple context menu component we use, and extending it slightly to respond to a global `contextMenusClosed` redux action.
2023-08-16 23:04:36 +10:00
43b30355e4 feat: make primitive node titles consistent 2023-08-16 23:04:36 +10:00
a93bd01353 fix bad merge 2023-08-16 08:53:07 -04:00
bb1b8ceaa8 Update invokeai/backend/model_management/model_cache.py
Co-authored-by: StAlKeR7779 <stalkek7779@yandex.ru>
2023-08-16 08:48:44 -04:00
be8edaf3fd Merge branch 'main' into feat/collect-more-stats 2023-08-16 08:48:14 -04:00
9cbaefaa81 feat: Add Seam Painting to SDXL 2023-08-16 19:46:48 +12:00
cc7c6e5d41 feat: Add Seam Painting with Scale Before 2023-08-16 19:35:03 +12:00
f2ee8a3da8 wip: Basic Seam Painting (only normal models) (no scale) 2023-08-16 17:26:23 +12:00
e98d7a52d4 feat: Add Seam Painting Options 2023-08-16 17:25:55 +12:00
21e1c0a5f0 tweaked formatting 2023-08-15 22:25:30 -04:00
f9958de6be added memory used to load models 2023-08-15 21:56:19 -04:00
ec10aca91e report RAM and RAM cache statistics 2023-08-15 21:00:30 -04:00
a4b029d03c write RAM usage and change after each generation 2023-08-15 18:21:31 -04:00
4f82273fc4 Update 'monkeypatched' controlnet class 2023-08-15 11:07:43 -04:00
81 changed files with 2370 additions and 1008 deletions

View File

@ -25,10 +25,10 @@ This method is recommended for experienced users and developers
#### [Docker Installation](040_INSTALL_DOCKER.md)
This method is recommended for those familiar with running Docker containers
### Other Installation Guides
- [PyPatchMatch](installation/060_INSTALL_PATCHMATCH.md)
- [XFormers](installation/070_INSTALL_XFORMERS.md)
- [CUDA and ROCm Drivers](installation/030_INSTALL_CUDA_AND_ROCM.md)
- [Installing New Models](installation/050_INSTALLING_MODELS.md)
- [PyPatchMatch](060_INSTALL_PATCHMATCH.md)
- [XFormers](070_INSTALL_XFORMERS.md)
- [CUDA and ROCm Drivers](030_INSTALL_CUDA_AND_ROCM.md)
- [Installing New Models](050_INSTALLING_MODELS.md)
## :fontawesome-solid-computer: Hardware Requirements

View File

@ -40,7 +40,7 @@ async def create_session(
@session_router.get(
"/",
operation_id="list_sessions",
responses={200: {"model": PaginatedResults[GraphExecutionState]}},
responses={200: {"model": PaginatedResults[dict]}},
)
async def list_sessions(
page: int = Query(default=0, description="The page of results to get"),

View File

@ -48,7 +48,7 @@ class BooleanCollectionOutput(BaseInvocationOutput):
)
@title("Boolean")
@title("Boolean Primitive")
@tags("primitives", "boolean")
class BooleanInvocation(BaseInvocation):
"""A boolean primitive value"""
@ -62,7 +62,7 @@ class BooleanInvocation(BaseInvocation):
return BooleanOutput(a=self.a)
@title("Boolean Collection")
@title("Boolean Primitive Collection")
@tags("primitives", "boolean", "collection")
class BooleanCollectionInvocation(BaseInvocation):
"""A collection of boolean primitive values"""
@ -101,7 +101,7 @@ class IntegerCollectionOutput(BaseInvocationOutput):
)
@title("Integer")
@title("Integer Primitive")
@tags("primitives", "integer")
class IntegerInvocation(BaseInvocation):
"""An integer primitive value"""
@ -115,7 +115,7 @@ class IntegerInvocation(BaseInvocation):
return IntegerOutput(a=self.a)
@title("Integer Collection")
@title("Integer Primitive Collection")
@tags("primitives", "integer", "collection")
class IntegerCollectionInvocation(BaseInvocation):
"""A collection of integer primitive values"""
@ -154,7 +154,7 @@ class FloatCollectionOutput(BaseInvocationOutput):
)
@title("Float")
@title("Float Primitive")
@tags("primitives", "float")
class FloatInvocation(BaseInvocation):
"""A float primitive value"""
@ -168,7 +168,7 @@ class FloatInvocation(BaseInvocation):
return FloatOutput(a=self.param)
@title("Float Collection")
@title("Float Primitive Collection")
@tags("primitives", "float", "collection")
class FloatCollectionInvocation(BaseInvocation):
"""A collection of float primitive values"""
@ -207,7 +207,7 @@ class StringCollectionOutput(BaseInvocationOutput):
)
@title("String")
@title("String Primitive")
@tags("primitives", "string")
class StringInvocation(BaseInvocation):
"""A string primitive value"""
@ -221,7 +221,7 @@ class StringInvocation(BaseInvocation):
return StringOutput(text=self.text)
@title("String Collection")
@title("String Primitive Collection")
@tags("primitives", "string", "collection")
class StringCollectionInvocation(BaseInvocation):
"""A collection of string primitive values"""
@ -289,7 +289,7 @@ class ImageInvocation(BaseInvocation):
)
@title("Image Collection")
@title("Image Primitive Collection")
@tags("primitives", "image", "collection")
class ImageCollectionInvocation(BaseInvocation):
"""A collection of image primitive values"""
@ -357,7 +357,7 @@ class LatentsInvocation(BaseInvocation):
return build_latents_output(self.latents.latents_name, latents)
@title("Latents Collection")
@title("Latents Primitive Collection")
@tags("primitives", "latents", "collection")
class LatentsCollectionInvocation(BaseInvocation):
"""A collection of latents tensor primitive values"""
@ -475,7 +475,7 @@ class ConditioningInvocation(BaseInvocation):
return ConditioningOutput(conditioning=self.conditioning)
@title("Conditioning Collection")
@title("Conditioning Primitive Collection")
@tags("primitives", "conditioning", "collection")
class ConditioningCollectionInvocation(BaseInvocation):
"""A collection of conditioning tensor primitive values"""

View File

@ -29,6 +29,7 @@ The abstract base class for this class is InvocationStatsServiceBase. An impleme
writes to the system log is stored in InvocationServices.performance_statistics.
"""
import psutil
import time
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager
@ -42,6 +43,11 @@ import invokeai.backend.util.logging as logger
from ..invocations.baseinvocation import BaseInvocation
from .graph import GraphExecutionState
from .item_storage import ItemStorageABC
from .model_manager_service import ModelManagerService
from invokeai.backend.model_management.model_cache import CacheStats
# size of GIG in bytes
GIG = 1073741824
class InvocationStatsServiceBase(ABC):
@ -89,6 +95,8 @@ class InvocationStatsServiceBase(ABC):
invocation_type: str,
time_used: float,
vram_used: float,
ram_used: float,
ram_changed: float,
):
"""
Add timing information on execution of a node. Usually
@ -97,6 +105,8 @@ class InvocationStatsServiceBase(ABC):
:param invocation_type: String literal type of the node
:param time_used: Time used by node's exection (sec)
:param vram_used: Maximum VRAM used during exection (GB)
:param ram_used: Current RAM available (GB)
:param ram_changed: Change in RAM usage over course of the run (GB)
"""
pass
@ -115,6 +125,9 @@ class NodeStats:
calls: int = 0
time_used: float = 0.0 # seconds
max_vram: float = 0.0 # GB
cache_hits: int = 0
cache_misses: int = 0
cache_high_watermark: int = 0
@dataclass
@ -133,31 +146,62 @@ class InvocationStatsService(InvocationStatsServiceBase):
self.graph_execution_manager = graph_execution_manager
# {graph_id => NodeLog}
self._stats: Dict[str, NodeLog] = {}
self._cache_stats: Dict[str, CacheStats] = {}
self.ram_used: float = 0.0
self.ram_changed: float = 0.0
class StatsContext:
def __init__(self, invocation: BaseInvocation, graph_id: str, collector: "InvocationStatsServiceBase"):
"""Context manager for collecting statistics."""
invocation: BaseInvocation = None
collector: "InvocationStatsServiceBase" = None
graph_id: str = None
start_time: int = 0
ram_used: int = 0
model_manager: ModelManagerService = None
def __init__(
self,
invocation: BaseInvocation,
graph_id: str,
model_manager: ModelManagerService,
collector: "InvocationStatsServiceBase",
):
"""Initialize statistics for this run."""
self.invocation = invocation
self.collector = collector
self.graph_id = graph_id
self.start_time = 0
self.ram_used = 0
self.model_manager = model_manager
def __enter__(self):
self.start_time = time.time()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
self.ram_used = psutil.Process().memory_info().rss
if self.model_manager:
self.model_manager.collect_cache_stats(self.collector._cache_stats[self.graph_id])
def __exit__(self, *args):
"""Called on exit from the context."""
ram_used = psutil.Process().memory_info().rss
self.collector.update_mem_stats(
ram_used=ram_used / GIG,
ram_changed=(ram_used - self.ram_used) / GIG,
)
self.collector.update_invocation_stats(
self.graph_id,
self.invocation.type,
time.time() - self.start_time,
torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0,
graph_id=self.graph_id,
invocation_type=self.invocation.type,
time_used=time.time() - self.start_time,
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
)
def collect_stats(
self,
invocation: BaseInvocation,
graph_execution_state_id: str,
model_manager: ModelManagerService,
) -> StatsContext:
"""
Return a context object that will capture the statistics.
@ -166,7 +210,8 @@ class InvocationStatsService(InvocationStatsServiceBase):
"""
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
self._stats[graph_execution_state_id] = NodeLog()
return self.StatsContext(invocation, graph_execution_state_id, self)
self._cache_stats[graph_execution_state_id] = CacheStats()
return self.StatsContext(invocation, graph_execution_state_id, model_manager, self)
def reset_all_stats(self):
"""Zero all statistics"""
@ -179,13 +224,36 @@ class InvocationStatsService(InvocationStatsServiceBase):
except KeyError:
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
def update_invocation_stats(self, graph_id: str, invocation_type: str, time_used: float, vram_used: float):
def update_mem_stats(
self,
ram_used: float,
ram_changed: float,
):
"""
Update the collector with RAM memory usage info.
:param ram_used: How much RAM is currently in use.
:param ram_changed: How much RAM changed since last generation.
"""
self.ram_used = ram_used
self.ram_changed = ram_changed
def update_invocation_stats(
self,
graph_id: str,
invocation_type: str,
time_used: float,
vram_used: float,
):
"""
Add timing information on execution of a node. Usually
used internally.
:param graph_id: ID of the graph that is currently executing
:param invocation_type: String literal type of the node
:param time_used: Floating point seconds used by node's exection
:param time_used: Time used by node's exection (sec)
:param vram_used: Maximum VRAM used during exection (GB)
:param ram_used: Current RAM available (GB)
:param ram_changed: Change in RAM usage over course of the run (GB)
"""
if not self._stats[graph_id].nodes.get(invocation_type):
self._stats[graph_id].nodes[invocation_type] = NodeStats()
@ -197,7 +265,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
def log_stats(self):
"""
Send the statistics to the system logger at the info level.
Stats will only be printed if when the execution of the graph
Stats will only be printed when the execution of the graph
is complete.
"""
completed = set()
@ -208,16 +276,30 @@ class InvocationStatsService(InvocationStatsServiceBase):
total_time = 0
logger.info(f"Graph stats: {graph_id}")
logger.info("Node Calls Seconds VRAM Used")
logger.info(f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}")
for node_type, stats in self._stats[graph_id].nodes.items():
logger.info(f"{node_type:<20} {stats.calls:>5} {stats.time_used:7.3f}s {stats.max_vram:4.2f}G")
logger.info(f"{node_type:>30} {stats.calls:>4} {stats.time_used:7.3f}s {stats.max_vram:4.3f}G")
total_time += stats.time_used
cache_stats = self._cache_stats[graph_id]
hwm = cache_stats.high_watermark / GIG
tot = cache_stats.cache_size / GIG
loaded = sum([v for v in cache_stats.loaded_model_sizes.values()]) / GIG
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)")
logger.info(f"RAM used to load models: {loaded:4.2f}G")
if torch.cuda.is_available():
logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9))
logger.info("VRAM in use: " + "%4.3fG" % (torch.cuda.memory_allocated() / GIG))
logger.info("RAM cache statistics:")
logger.info(f" Model cache hits: {cache_stats.hits}")
logger.info(f" Model cache misses: {cache_stats.misses}")
logger.info(f" Models cached: {cache_stats.in_cache}")
logger.info(f" Models cleared from cache: {cache_stats.cleared}")
logger.info(f" Cache high water mark: {hwm:4.2f}/{tot:4.2f}G")
completed.add(graph_id)
for graph_id in completed:
del self._stats[graph_id]
del self._cache_stats[graph_id]

View File

@ -22,6 +22,7 @@ from invokeai.backend.model_management import (
ModelNotFoundException,
)
from invokeai.backend.model_management.model_search import FindModels
from invokeai.backend.model_management.model_cache import CacheStats
import torch
from invokeai.app.models.exceptions import CanceledException
@ -276,6 +277,13 @@ class ModelManagerServiceBase(ABC):
"""
pass
@abstractmethod
def collect_cache_stats(self, cache_stats: CacheStats):
"""
Reset model cache statistics for graph with graph_id.
"""
pass
@abstractmethod
def commit(self, conf_file: Optional[Path] = None) -> None:
"""
@ -500,6 +508,12 @@ class ModelManagerService(ModelManagerServiceBase):
self.logger.debug(f"convert model {model_name}")
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
def collect_cache_stats(self, cache_stats: CacheStats):
"""
Reset model cache statistics for graph with graph_id.
"""
self.mgr.cache.stats = cache_stats
def commit(self, conf_file: Optional[Path] = None):
"""
Write current configuration out to the indicated file.

View File

@ -86,7 +86,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Invoke
try:
with statistics.collect_stats(invocation, graph_execution_state.id):
graph_id = graph_execution_state.id
model_manager = self.__invoker.services.model_manager
with statistics.collect_stats(invocation, graph_id, model_manager):
# use the internal invoke_internal(), which wraps the node's invoke() method in
# this accomodates nodes which require a value, but get it only from a
# connection

View File

@ -1,6 +1,7 @@
import sqlite3
import json
from threading import Lock
from typing import Generic, Optional, TypeVar, get_args
from typing import Generic, Optional, TypeVar, Union, get_args
from pydantic import BaseModel, parse_raw_as
@ -99,7 +100,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._lock.release()
self._on_deleted(id)
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[dict]:
try:
self._lock.acquire()
self._cursor.execute(
@ -108,7 +109,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
)
result = self._cursor.fetchall()
items = list(map(lambda r: self._parse_item(r[0]), result))
items = [json.loads(r[0]) for r in result]
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
count = self._cursor.fetchone()[0]
@ -117,9 +118,9 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
pageCount = int(count / per_page) + 1
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
return PaginatedResults[dict](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[dict]:
try:
self._lock.acquire()
self._cursor.execute(
@ -128,7 +129,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
)
result = self._cursor.fetchall()
items = list(map(lambda r: self._parse_item(r[0]), result))
items = [json.loads(r[0]) for r in result]
self._cursor.execute(
f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
@ -140,4 +141,4 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
pageCount = int(count / per_page) + 1
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
return PaginatedResults[dict](items=items, page=page, pages=pageCount, per_page=per_page, total=count)

View File

@ -21,12 +21,12 @@ import os
import sys
import hashlib
from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Union, types, Optional, Type, Any
import torch
import logging
import invokeai.backend.util.logging as logger
from .models import BaseModelType, ModelType, SubModelType, ModelBase
@ -41,6 +41,18 @@ DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
GIG = 1073741824
@dataclass
class CacheStats(object):
hits: int = 0 # cache hits
misses: int = 0 # cache misses
high_watermark: int = 0 # amount of cache used
in_cache: int = 0 # number of models in cache
cleared: int = 0 # number of models cleared to make space
cache_size: int = 0 # total size of cache
# {submodel_key => size}
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
class ModelLocker(object):
"Forward declaration"
pass
@ -115,6 +127,9 @@ class ModelCache(object):
self.sha_chunksize = sha_chunksize
self.logger = logger
# used for stats collection
self.stats = None
self._cached_models = dict()
self._cache_stack = list()
@ -181,13 +196,14 @@ class ModelCache(object):
model_type=model_type,
submodel_type=submodel,
)
# TODO: lock for no copies on simultaneous calls?
cache_entry = self._cached_models.get(key, None)
if cache_entry is None:
self.logger.info(
f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
)
if self.stats:
self.stats.misses += 1
# this will remove older cached models until
# there is sufficient room to load the requested model
@ -201,6 +217,17 @@ class ModelCache(object):
cache_entry = _CacheRecord(self, model, mem_used)
self._cached_models[key] = cache_entry
else:
if self.stats:
self.stats.hits += 1
if self.stats:
self.stats.cache_size = self.max_cache_size * GIG
self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[key] = max(
self.stats.loaded_model_sizes.get(key, 0), model_info.get_size(submodel)
)
with suppress(Exception):
self._cache_stack.remove(key)
@ -280,14 +307,14 @@ class ModelCache(object):
"""
Given the HF repo id or path to a model on disk, returns a unique
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
:param model_path: Path to model file/directory on disk.
"""
return self._local_model_hash(model_path)
def cache_size(self) -> float:
"Return the current size of the cache, in GB"
current_cache_size = sum([m.size for m in self._cached_models.values()])
return current_cache_size / GIG
"""Return the current size of the cache, in GB."""
return self._cache_size() / GIG
def _has_cuda(self) -> bool:
return self.execution_device.type == "cuda"
@ -310,12 +337,15 @@ class ModelCache(object):
f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}"
)
def _cache_size(self) -> int:
return sum([m.size for m in self._cached_models.values()])
def _make_cache_room(self, model_size):
# calculate how much memory this model will require
# multiplier = 2 if self.precision==torch.float32 else 1
bytes_needed = model_size
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
current_size = sum([m.size for m in self._cached_models.values()])
current_size = self._cache_size()
if current_size + bytes_needed > maximum_size:
self.logger.debug(
@ -364,6 +394,8 @@ class ModelCache(object):
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
)
current_size -= cache_entry.size
if self.stats:
self.stats.cleared += 1
del self._cache_stack[pos]
del self._cached_models[model_key]
del cache_entry

View File

@ -240,6 +240,7 @@ class InvokeAIDiffuserComponent:
controlnet_cond=control_datum.image_tensor,
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
encoder_attention_mask=encoder_attention_mask,
added_cond_kwargs=added_cond_kwargs,
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
return_dict=False,
)

View File

@ -4,8 +4,15 @@ import torch
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalControlnetMixin
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.embeddings import (
TextImageProjection,
TextImageTimeEmbedding,
TextTimeEmbedding,
TimestepEmbedding,
Timesteps,
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import (
CrossAttnDownBlock2D,
@ -18,10 +25,11 @@ from diffusers.models.unet_2d_condition import UNet2DConditionModel
import diffusers
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
# TODO: create PR to diffusers
# Modified ControlNetModel with encoder_attention_mask argument added
class ControlNetModel(ModelMixin, ConfigMixin):
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
"""
A ControlNet model.
@ -52,12 +60,25 @@ class ControlNetModel(ModelMixin, ConfigMixin):
The epsilon to use for the normalization.
cross_attention_dim (`int`, defaults to 1280):
The dimension of the cross attention features.
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
encoder_hid_dim (`int`, *optional*, defaults to None):
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
dimension to `cross_attention_dim`.
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
The dimension of the attention heads.
use_linear_projection (`bool`, defaults to `False`):
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
addition_embed_type (`str`, *optional*, defaults to `None`):
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
"text". "text" will use the `TextTimeEmbedding` layer.
num_class_embeds (`int`, *optional*, defaults to 0):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
@ -98,10 +119,15 @@ class ControlNetModel(ModelMixin, ConfigMixin):
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
@ -109,6 +135,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
global_pool_conditions: bool = False,
addition_embed_type_num_heads=64,
):
super().__init__()
@ -136,6 +163,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
# input
conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
@ -145,16 +175,43 @@ class ControlNetModel(ModelMixin, ConfigMixin):
# time
time_embed_dim = block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
act_fn=act_fn,
)
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
encoder_hid_dim_type = "text_proj"
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
raise ValueError(
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
)
if encoder_hid_dim_type == "text_proj":
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
elif encoder_hid_dim_type == "text_image_proj":
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
self.encoder_hid_proj = TextImageProjection(
text_embed_dim=encoder_hid_dim,
image_embed_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim,
)
elif encoder_hid_dim_type is not None:
raise ValueError(
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
)
else:
self.encoder_hid_proj = None
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
@ -178,6 +235,29 @@ class ControlNetModel(ModelMixin, ConfigMixin):
else:
self.class_embedding = None
if addition_embed_type == "text":
if encoder_hid_dim is not None:
text_time_embedding_from_dim = encoder_hid_dim
else:
text_time_embedding_from_dim = cross_attention_dim
self.add_embedding = TextTimeEmbedding(
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
)
elif addition_embed_type == "text_image":
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
self.add_embedding = TextImageTimeEmbedding(
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
)
elif addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
elif addition_embed_type is not None:
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
# control net conditioning embedding
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
conditioning_embedding_channels=block_out_channels[0],
@ -212,6 +292,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
transformer_layers_per_block=transformer_layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
@ -248,6 +329,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
self.controlnet_mid_block = controlnet_block
self.mid_block = UNetMidBlock2DCrossAttn(
transformer_layers_per_block=transformer_layers_per_block[-1],
in_channels=mid_block_channel,
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
@ -277,7 +359,22 @@ class ControlNetModel(ModelMixin, ConfigMixin):
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
where applicable.
"""
transformer_layers_per_block = (
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
)
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
addition_time_embed_dim = (
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
)
controlnet = cls(
encoder_hid_dim=encoder_hid_dim,
encoder_hid_dim_type=encoder_hid_dim_type,
addition_embed_type=addition_embed_type,
addition_time_embed_dim=addition_time_embed_dim,
transformer_layers_per_block=transformer_layers_per_block,
in_channels=unet.config.in_channels,
flip_sin_to_cos=unet.config.flip_sin_to_cos,
freq_shift=unet.config.freq_shift,
@ -463,6 +560,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
guess_mode: bool = False,
@ -486,7 +584,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
added_cond_kwargs (`dict`):
Additional conditions for the Stable Diffusion XL UNet.
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
encoder_attention_mask (`torch.Tensor`):
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
@ -549,6 +649,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None
if self.class_embedding is not None:
if class_labels is None:
@ -560,11 +661,34 @@ class ControlNetModel(ModelMixin, ConfigMixin):
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
if "addition_embed_type" in self.config:
if self.config.addition_embed_type == "text":
aug_emb = self.add_embedding(encoder_hidden_states)
elif self.config.addition_embed_type == "text_time":
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
emb = emb + aug_emb if aug_emb is not None else emb
# 2. pre-process
sample = self.conv_in(sample)
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
sample = sample + controlnet_cond
# 3. down

View File

@ -506,10 +506,14 @@
"maskAdjustmentsHeader": "Mask Adjustments",
"maskBlur": "Mask Blur",
"maskBlurMethod": "Mask Blur Method",
"seamPaintingHeader": "Seam Painting",
"seamSize": "Seam Size",
"seamBlur": "Seam Blur",
"seamStrength": "Seam Strength",
"seamSteps": "Seam Steps",
"seamStrength": "Seam Strength",
"seamThreshold": "Seam Threshold",
"seamLowThreshold": "Low",
"seamHighThreshold": "High",
"scaleBeforeProcessing": "Scale Before Processing",
"scaledWidth": "Scaled W",
"scaledHeight": "Scaled H",

View File

@ -121,7 +121,7 @@ export const addRequestedMultipleImageDeletionListener = () => {
effect: async (action, { dispatch, getState }) => {
const { imageDTOs, imagesUsage } = action.payload;
if (imageDTOs.length < 1 || imagesUsage.length < 1) {
if (imageDTOs.length <= 1 || imagesUsage.length <= 1) {
// handle singles in separate listener
return;
}

View File

@ -0,0 +1,126 @@
/**
* This is a copy-paste of https://github.com/lukasbach/chakra-ui-contextmenu with a small change.
*
* The reactflow background element somehow prevents the chakra `useOutsideClick()` hook from working.
* With a menu open, clicking on the reactflow background element doesn't close the menu.
*
* Reactflow does provide an `onPaneClick` to handle clicks on the background element, but it is not
* straightforward to programatically close the menu.
*
* As a (hopefully temporary) workaround, we will use a dirty hack:
* - create `globalContextMenuCloseTrigger: number` in `ui` slice
* - increment it in `onPaneClick`
* - `useEffect()` to close the menu when `globalContextMenuCloseTrigger` changes
*/
import {
Menu,
MenuButton,
MenuButtonProps,
MenuProps,
Portal,
PortalProps,
useEventListener,
} from '@chakra-ui/react';
import { useAppSelector } from 'app/store/storeHooks';
import * as React from 'react';
import {
MutableRefObject,
useCallback,
useEffect,
useRef,
useState,
} from 'react';
export interface IAIContextMenuProps<T extends HTMLElement> {
renderMenu: () => JSX.Element | null;
children: (ref: MutableRefObject<T | null>) => JSX.Element | null;
menuProps?: Omit<MenuProps, 'children'> & { children?: React.ReactNode };
portalProps?: Omit<PortalProps, 'children'> & { children?: React.ReactNode };
menuButtonProps?: MenuButtonProps;
}
export function IAIContextMenu<T extends HTMLElement = HTMLElement>(
props: IAIContextMenuProps<T>
) {
const [isOpen, setIsOpen] = useState(false);
const [isRendered, setIsRendered] = useState(false);
const [isDeferredOpen, setIsDeferredOpen] = useState(false);
const [position, setPosition] = useState<[number, number]>([0, 0]);
const targetRef = useRef<T>(null);
const globalContextMenuCloseTrigger = useAppSelector(
(state) => state.ui.globalContextMenuCloseTrigger
);
useEffect(() => {
if (isOpen) {
setTimeout(() => {
setIsRendered(true);
setTimeout(() => {
setIsDeferredOpen(true);
});
});
} else {
setIsDeferredOpen(false);
const timeout = setTimeout(() => {
setIsRendered(isOpen);
}, 1000);
return () => clearTimeout(timeout);
}
}, [isOpen]);
useEffect(() => {
setIsOpen(false);
setIsDeferredOpen(false);
setIsRendered(false);
}, [globalContextMenuCloseTrigger]);
useEventListener('contextmenu', (e) => {
if (
targetRef.current?.contains(e.target as HTMLElement) ||
e.target === targetRef.current
) {
e.preventDefault();
setIsOpen(true);
setPosition([e.pageX, e.pageY]);
} else {
setIsOpen(false);
}
});
const onCloseHandler = useCallback(() => {
props.menuProps?.onClose?.();
setIsOpen(false);
}, [props.menuProps]);
return (
<>
{props.children(targetRef)}
{isRendered && (
<Portal {...props.portalProps}>
<Menu
isOpen={isDeferredOpen}
gutter={0}
{...props.menuProps}
onClose={onCloseHandler}
>
<MenuButton
aria-hidden={true}
w={1}
h={1}
style={{
position: 'absolute',
left: position[0],
top: position[1],
cursor: 'default',
}}
{...props.menuButtonProps}
/>
{props.renderMenu()}
</Menu>
</Portal>
)}
</>
);
}

View File

@ -16,6 +16,7 @@ import ImageContextMenu from 'features/gallery/components/ImageContextMenu/Image
import {
MouseEvent,
ReactElement,
ReactNode,
SyntheticEvent,
memo,
useCallback,
@ -32,6 +33,17 @@ import {
TypesafeDroppableData,
} from 'features/dnd/types';
const defaultUploadElement = (
<Icon
as={FaUpload}
sx={{
boxSize: 16,
}}
/>
);
const defaultNoContentFallback = <IAINoContentFallback icon={FaImage} />;
type IAIDndImageProps = FlexProps & {
imageDTO: ImageDTO | undefined;
onError?: (event: SyntheticEvent<HTMLImageElement>) => void;
@ -47,13 +59,14 @@ type IAIDndImageProps = FlexProps & {
fitContainer?: boolean;
droppableData?: TypesafeDroppableData;
draggableData?: TypesafeDraggableData;
dropLabel?: string;
dropLabel?: ReactNode;
isSelected?: boolean;
thumbnail?: boolean;
noContentFallback?: ReactElement;
useThumbailFallback?: boolean;
withHoverOverlay?: boolean;
children?: JSX.Element;
uploadElement?: ReactNode;
};
const IAIDndImage = (props: IAIDndImageProps) => {
@ -74,7 +87,8 @@ const IAIDndImage = (props: IAIDndImageProps) => {
dropLabel,
isSelected = false,
thumbnail = false,
noContentFallback = <IAINoContentFallback icon={FaImage} />,
noContentFallback = defaultNoContentFallback,
uploadElement = defaultUploadElement,
useThumbailFallback,
withHoverOverlay = false,
children,
@ -193,12 +207,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
{...getUploadButtonProps()}
>
<input {...getUploadInputProps()} />
<Icon
as={FaUpload}
sx={{
boxSize: 16,
}}
/>
{uploadElement}
</Flex>
</>
)}
@ -210,6 +219,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
onClick={onClick}
/>
)}
{children}
{!isDropDisabled && (
<IAIDroppable
data={droppableData}
@ -217,7 +227,6 @@ const IAIDndImage = (props: IAIDndImageProps) => {
dropLabel={dropLabel}
/>
)}
{children}
</Flex>
)}
</ImageContextMenu>

View File

@ -1,5 +1,8 @@
import { MenuList } from '@chakra-ui/react';
import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu';
import {
IAIContextMenu,
IAIContextMenuProps,
} from 'common/components/IAIContextMenu';
import { MouseEvent, memo, useCallback } from 'react';
import { ImageDTO } from 'services/api/types';
import { menuListMotionProps } from 'theme/components/menu';
@ -12,7 +15,7 @@ import MultipleSelectionMenuItems from './MultipleSelectionMenuItems';
type Props = {
imageDTO: ImageDTO | undefined;
children: ContextMenuProps<HTMLDivElement>['children'];
children: IAIContextMenuProps<HTMLDivElement>['children'];
};
const selector = createSelector(
@ -33,7 +36,7 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => {
}, []);
return (
<ContextMenu<HTMLDivElement>
<IAIContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }}
menuButtonProps={{
bg: 'transparent',
@ -68,7 +71,7 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => {
}}
>
{children}
</ContextMenu>
</IAIContextMenu>
);
};

View File

@ -2,8 +2,9 @@ import { Badge, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { useMemo } from 'react';
import { memo, useMemo } from 'react';
import {
BaseEdge,
EdgeLabelRenderer,
@ -20,78 +21,165 @@ const makeEdgeSelector = (
targetHandleId: string | null | undefined,
selected?: boolean
) =>
createSelector(stateSelector, ({ nodes }) => {
const sourceNode = nodes.nodes.find((node) => node.id === source);
const targetNode = nodes.nodes.find((node) => node.id === target);
createSelector(
stateSelector,
({ nodes }) => {
const sourceNode = nodes.nodes.find((node) => node.id === source);
const targetNode = nodes.nodes.find((node) => node.id === target);
const isInvocationToInvocationEdge =
isInvocationNode(sourceNode) && isInvocationNode(targetNode);
const isInvocationToInvocationEdge =
isInvocationNode(sourceNode) && isInvocationNode(targetNode);
const isSelected = sourceNode?.selected || targetNode?.selected || selected;
const sourceType = isInvocationToInvocationEdge
? sourceNode?.data?.outputs[sourceHandleId || '']?.type
: undefined;
const isSelected =
sourceNode?.selected || targetNode?.selected || selected;
const sourceType = isInvocationToInvocationEdge
? sourceNode?.data?.outputs[sourceHandleId || '']?.type
: undefined;
const stroke =
sourceType && nodes.shouldColorEdges
? colorTokenToCssVar(FIELDS[sourceType].color)
: colorTokenToCssVar('base.500');
const stroke =
sourceType && nodes.shouldColorEdges
? colorTokenToCssVar(FIELDS[sourceType].color)
: colorTokenToCssVar('base.500');
return {
isSelected,
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
stroke,
};
});
const CollapsedEdge = ({
sourceX,
sourceY,
targetX,
targetY,
sourcePosition,
targetPosition,
markerEnd,
data,
selected,
source,
target,
sourceHandleId,
targetHandleId,
}: EdgeProps<{ count: number }>) => {
const selector = useMemo(
() =>
makeEdgeSelector(
source,
sourceHandleId,
target,
targetHandleId,
selected
),
[selected, source, sourceHandleId, target, targetHandleId]
return {
isSelected,
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
stroke,
};
},
defaultSelectorOptions
);
const { isSelected, shouldAnimate } = useAppSelector(selector);
const [edgePath, labelX, labelY] = getBezierPath({
const CollapsedEdge = memo(
({
sourceX,
sourceY,
sourcePosition,
targetX,
targetY,
sourcePosition,
targetPosition,
});
markerEnd,
data,
selected,
source,
target,
sourceHandleId,
targetHandleId,
}: EdgeProps<{ count: number }>) => {
const selector = useMemo(
() =>
makeEdgeSelector(
source,
sourceHandleId,
target,
targetHandleId,
selected
),
[selected, source, sourceHandleId, target, targetHandleId]
);
const { base500 } = useChakraThemeTokens();
const { isSelected, shouldAnimate } = useAppSelector(selector);
return (
<>
const [edgePath, labelX, labelY] = getBezierPath({
sourceX,
sourceY,
sourcePosition,
targetX,
targetY,
targetPosition,
});
const { base500 } = useChakraThemeTokens();
return (
<>
<BaseEdge
path={edgePath}
markerEnd={markerEnd}
style={{
strokeWidth: isSelected ? 3 : 2,
stroke: base500,
opacity: isSelected ? 0.8 : 0.5,
animation: shouldAnimate
? 'dashdraw 0.5s linear infinite'
: undefined,
strokeDasharray: shouldAnimate ? 5 : 'none',
}}
/>
{data?.count && data.count > 1 && (
<EdgeLabelRenderer>
<Flex
sx={{
position: 'absolute',
transform: `translate(-50%, -50%) translate(${labelX}px,${labelY}px)`,
}}
className="nodrag nopan"
>
<Badge
variant="solid"
sx={{
bg: 'base.500',
opacity: isSelected ? 0.8 : 0.5,
boxShadow: 'base',
}}
>
{data.count}
</Badge>
</Flex>
</EdgeLabelRenderer>
)}
</>
);
}
);
CollapsedEdge.displayName = 'CollapsedEdge';
const DefaultEdge = memo(
({
sourceX,
sourceY,
targetX,
targetY,
sourcePosition,
targetPosition,
markerEnd,
selected,
source,
target,
sourceHandleId,
targetHandleId,
}: EdgeProps) => {
const selector = useMemo(
() =>
makeEdgeSelector(
source,
sourceHandleId,
target,
targetHandleId,
selected
),
[source, sourceHandleId, target, targetHandleId, selected]
);
const { isSelected, shouldAnimate, stroke } = useAppSelector(selector);
const [edgePath] = getBezierPath({
sourceX,
sourceY,
sourcePosition,
targetX,
targetY,
targetPosition,
});
return (
<BaseEdge
path={edgePath}
markerEnd={markerEnd}
style={{
strokeWidth: isSelected ? 3 : 2,
stroke: base500,
stroke,
opacity: isSelected ? 0.8 : 0.5,
animation: shouldAnimate
? 'dashdraw 0.5s linear infinite'
@ -99,83 +187,11 @@ const CollapsedEdge = ({
strokeDasharray: shouldAnimate ? 5 : 'none',
}}
/>
{data?.count && data.count > 1 && (
<EdgeLabelRenderer>
<Flex
sx={{
position: 'absolute',
transform: `translate(-50%, -50%) translate(${labelX}px,${labelY}px)`,
}}
className="nodrag nopan"
>
<Badge
variant="solid"
sx={{
bg: 'base.500',
opacity: isSelected ? 0.8 : 0.5,
boxShadow: 'base',
}}
>
{data.count}
</Badge>
</Flex>
</EdgeLabelRenderer>
)}
</>
);
};
);
}
);
const DefaultEdge = ({
sourceX,
sourceY,
targetX,
targetY,
sourcePosition,
targetPosition,
markerEnd,
selected,
source,
target,
sourceHandleId,
targetHandleId,
}: EdgeProps) => {
const selector = useMemo(
() =>
makeEdgeSelector(
source,
sourceHandleId,
target,
targetHandleId,
selected
),
[source, sourceHandleId, target, targetHandleId, selected]
);
const { isSelected, shouldAnimate, stroke } = useAppSelector(selector);
const [edgePath] = getBezierPath({
sourceX,
sourceY,
sourcePosition,
targetX,
targetY,
targetPosition,
});
return (
<BaseEdge
path={edgePath}
markerEnd={markerEnd}
style={{
strokeWidth: isSelected ? 3 : 2,
stroke,
opacity: isSelected ? 0.8 : 0.5,
animation: shouldAnimate ? 'dashdraw 0.5s linear infinite' : undefined,
strokeDasharray: shouldAnimate ? 5 : 'none',
}}
/>
);
};
DefaultEdge.displayName = 'DefaultEdge';
export const edgeTypes = {
collapsed: CollapsedEdge,

View File

@ -1,5 +1,6 @@
import { useToken } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { contextMenusClosed } from 'features/ui/store/uiSlice';
import { useCallback } from 'react';
import {
Background,
@ -114,6 +115,10 @@ export const Flow = () => {
[dispatch]
);
const handlePaneClick = useCallback(() => {
dispatch(contextMenusClosed());
}, [dispatch]);
return (
<ReactFlow
defaultViewport={viewport}
@ -132,12 +137,13 @@ export const Flow = () => {
connectionLineComponent={CustomConnectionLine}
onSelectionChange={handleSelectionChange}
isValidConnection={isValidConnection}
minZoom={0.2}
minZoom={0.1}
snapToGrid={shouldSnapToGrid}
snapGrid={[25, 25]}
connectionRadius={30}
proOptions={proOptions}
style={{ borderRadius }}
onPaneClick={handlePaneClick}
>
<TopLeftPanel />
<TopCenterPanel />

View File

@ -1,40 +1,34 @@
import { Flex } from '@chakra-ui/react';
import {
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { map, some } from 'lodash-es';
import { memo, useMemo } from 'react';
import { NodeProps } from 'reactflow';
import { useFieldNames, useWithFooter } from 'features/nodes/hooks/useNodeData';
import { memo } from 'react';
import InputField from '../fields/InputField';
import OutputField from '../fields/OutputField';
import NodeFooter, { FOOTER_FIELDS } from './NodeFooter';
import NodeFooter from './NodeFooter';
import NodeHeader from './NodeHeader';
import NodeWrapper from './NodeWrapper';
type Props = {
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
nodeId: string;
isOpen: boolean;
label: string;
type: string;
selected: boolean;
};
const InvocationNode = ({ nodeProps, nodeTemplate }: Props) => {
const { id: nodeId, data } = nodeProps;
const { inputs, outputs, isOpen } = data;
const inputFields = useMemo(
() => map(inputs).filter((i) => i.name !== 'is_intermediate'),
[inputs]
);
const outputFields = useMemo(() => map(outputs), [outputs]);
const withFooter = useMemo(
() => some(outputs, (output) => FOOTER_FIELDS.includes(output.type)),
[outputs]
);
const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
const inputFieldNames = useFieldNames(nodeId, 'input');
const outputFieldNames = useFieldNames(nodeId, 'output');
const withFooter = useWithFooter(nodeId);
return (
<NodeWrapper nodeProps={nodeProps}>
<NodeHeader nodeProps={nodeProps} nodeTemplate={nodeTemplate} />
<NodeWrapper nodeId={nodeId} selected={selected}>
<NodeHeader
nodeId={nodeId}
isOpen={isOpen}
label={label}
selected={selected}
type={type}
/>
{isOpen && (
<>
<Flex
@ -54,27 +48,23 @@ const InvocationNode = ({ nodeProps, nodeTemplate }: Props) => {
className="nopan"
sx={{ flexDir: 'column', px: 2, w: 'full', h: 'full' }}
>
{outputFields.map((field) => (
{outputFieldNames.map((fieldName) => (
<OutputField
key={`${nodeId}.${field.id}.input-field`}
nodeProps={nodeProps}
nodeTemplate={nodeTemplate}
field={field}
key={`${nodeId}.${fieldName}.output-field`}
nodeId={nodeId}
fieldName={fieldName}
/>
))}
{inputFields.map((field) => (
{inputFieldNames.map((fieldName) => (
<InputField
key={`${nodeId}.${field.id}.input-field`}
nodeProps={nodeProps}
nodeTemplate={nodeTemplate}
field={field}
key={`${nodeId}.${fieldName}.input-field`}
nodeId={nodeId}
fieldName={fieldName}
/>
))}
</Flex>
</Flex>
{withFooter && (
<NodeFooter nodeProps={nodeProps} nodeTemplate={nodeTemplate} />
)}
{withFooter && <NodeFooter nodeId={nodeId} />}
</>
)}
</NodeWrapper>

View File

@ -2,16 +2,15 @@ import { ChevronUpIcon } from '@chakra-ui/icons';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { nodeIsOpenChanged } from 'features/nodes/store/nodesSlice';
import { NodeData } from 'features/nodes/types/types';
import { memo, useCallback } from 'react';
import { NodeProps, useUpdateNodeInternals } from 'reactflow';
import { useUpdateNodeInternals } from 'reactflow';
interface Props {
nodeProps: NodeProps<NodeData>;
nodeId: string;
isOpen: boolean;
}
const NodeCollapseButton = (props: Props) => {
const { id: nodeId, isOpen } = props.nodeProps.data;
const NodeCollapseButton = ({ nodeId, isOpen }: Props) => {
const dispatch = useAppDispatch();
const updateNodeInternals = useUpdateNodeInternals();

View File

@ -1,20 +1,17 @@
import { useColorModeValue } from '@chakra-ui/react';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import {
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { isInvocationNodeData } from 'features/nodes/types/types';
import { map } from 'lodash-es';
import { CSSProperties, memo, useMemo } from 'react';
import { Handle, NodeProps, Position } from 'reactflow';
import { Handle, Position } from 'reactflow';
interface Props {
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
nodeId: string;
}
const NodeCollapsedHandles = (props: Props) => {
const { data } = props.nodeProps;
const NodeCollapsedHandles = ({ nodeId }: Props) => {
const data = useNodeData(nodeId);
const { base400, base600 } = useChakraThemeTokens();
const backgroundColor = useColorModeValue(base400, base600);
@ -30,6 +27,10 @@ const NodeCollapsedHandles = (props: Props) => {
[backgroundColor]
);
if (!isInvocationNodeData(data)) {
return null;
}
return (
<>
<Handle
@ -44,7 +45,7 @@ const NodeCollapsedHandles = (props: Props) => {
key={`${data.id}-${input.name}-collapsed-input-handle`}
type="target"
id={input.name}
isValidConnection={() => false}
isConnectable={false}
position={Position.Left}
style={{ visibility: 'hidden' }}
/>
@ -52,7 +53,6 @@ const NodeCollapsedHandles = (props: Props) => {
<Handle
type="source"
id={`${data.id}-collapsed-source`}
isValidConnection={() => false}
isConnectable={false}
position={Position.Right}
style={{ ...dummyHandleStyles, right: '-0.5rem' }}
@ -62,7 +62,7 @@ const NodeCollapsedHandles = (props: Props) => {
key={`${data.id}-${output.name}-collapsed-output-handle`}
type="source"
id={output.name}
isValidConnection={() => false}
isConnectable={false}
position={Position.Right}
style={{ visibility: 'hidden' }}
/>

View File

@ -6,49 +6,19 @@ import {
Spacer,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import {
useHasImageOutput,
useIsIntermediate,
} from 'features/nodes/hooks/useNodeData';
import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import {
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { some } from 'lodash-es';
import { ChangeEvent, memo, useCallback, useMemo } from 'react';
import { NodeProps } from 'reactflow';
export const IMAGE_FIELDS = ['ImageField', 'ImageCollection'];
export const FOOTER_FIELDS = IMAGE_FIELDS;
import { ChangeEvent, memo, useCallback } from 'react';
type Props = {
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
nodeId: string;
};
const NodeFooter = (props: Props) => {
const { nodeProps, nodeTemplate } = props;
const dispatch = useAppDispatch();
const hasImageOutput = useMemo(
() =>
some(nodeTemplate?.outputs, (output) =>
IMAGE_FIELDS.includes(output.type)
),
[nodeTemplate?.outputs]
);
const handleChangeIsIntermediate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldBooleanValueChanged({
nodeId: nodeProps.data.id,
fieldName: 'is_intermediate',
value: !e.target.checked,
})
);
},
[dispatch, nodeProps.data.id]
);
const NodeFooter = ({ nodeId }: Props) => {
return (
<Flex
className={DRAG_HANDLE_CLASSNAME}
@ -62,19 +32,45 @@ const NodeFooter = (props: Props) => {
}}
>
<Spacer />
{hasImageOutput && (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save Output</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChangeIsIntermediate}
isChecked={!nodeProps.data.inputs['is_intermediate']?.value}
/>
</FormControl>
)}
<SaveImageCheckbox nodeId={nodeId} />
</Flex>
);
};
export default memo(NodeFooter);
const SaveImageCheckbox = memo(({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const hasImageOutput = useHasImageOutput(nodeId);
const is_intermediate = useIsIntermediate(nodeId);
const handleChangeIsIntermediate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldBooleanValueChanged({
nodeId,
fieldName: 'is_intermediate',
value: !e.target.checked,
})
);
},
[dispatch, nodeId]
);
if (!hasImageOutput) {
return null;
}
return (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save Output</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChangeIsIntermediate}
isChecked={!is_intermediate}
/>
</FormControl>
);
});
SaveImageCheckbox.displayName = 'SaveImageCheckbox';

View File

@ -1,10 +1,5 @@
import { Flex } from '@chakra-ui/react';
import {
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { memo } from 'react';
import { NodeProps } from 'reactflow';
import NodeCollapseButton from '../Invocation/NodeCollapseButton';
import NodeCollapsedHandles from '../Invocation/NodeCollapsedHandles';
import NodeNotesEdit from '../Invocation/NodeNotesEdit';
@ -12,14 +7,14 @@ import NodeStatusIndicator from '../Invocation/NodeStatusIndicator';
import NodeTitle from '../Invocation/NodeTitle';
type Props = {
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
nodeId: string;
isOpen: boolean;
label: string;
type: string;
selected: boolean;
};
const NodeHeader = (props: Props) => {
const { nodeProps, nodeTemplate } = props;
const { isOpen } = nodeProps.data;
const NodeHeader = ({ nodeId, isOpen }: Props) => {
return (
<Flex
layerStyle="nodeHeader"
@ -35,18 +30,13 @@ const NodeHeader = (props: Props) => {
_dark: { color: 'base.200' },
}}
>
<NodeCollapseButton nodeProps={nodeProps} />
<NodeTitle nodeData={nodeProps.data} title={nodeTemplate.title} />
<NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
<NodeTitle nodeId={nodeId} />
<Flex alignItems="center">
<NodeStatusIndicator nodeProps={nodeProps} />
<NodeNotesEdit nodeProps={nodeProps} nodeTemplate={nodeTemplate} />
<NodeStatusIndicator nodeId={nodeId} />
<NodeNotesEdit nodeId={nodeId} />
</Flex>
{!isOpen && (
<NodeCollapsedHandles
nodeProps={nodeProps}
nodeTemplate={nodeTemplate}
/>
)}
{!isOpen && <NodeCollapsedHandles nodeId={nodeId} />}
</Flex>
);
};

View File

@ -16,41 +16,31 @@ import {
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAITextarea from 'common/components/IAITextarea';
import {
useNodeData,
useNodeLabel,
useNodeTemplate,
useNodeTemplateTitle,
} from 'features/nodes/hooks/useNodeData';
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import {
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { isInvocationNodeData } from 'features/nodes/types/types';
import { ChangeEvent, memo, useCallback } from 'react';
import { FaInfoCircle } from 'react-icons/fa';
import { NodeProps } from 'reactflow';
interface Props {
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
nodeId: string;
}
const NodeNotesEdit = (props: Props) => {
const { nodeProps, nodeTemplate } = props;
const { data } = nodeProps;
const NodeNotesEdit = ({ nodeId }: Props) => {
const { isOpen, onOpen, onClose } = useDisclosure();
const dispatch = useAppDispatch();
const handleNotesChanged = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(nodeNotesChanged({ nodeId: data.id, notes: e.target.value }));
},
[data.id, dispatch]
);
const label = useNodeLabel(nodeId);
const title = useNodeTemplateTitle(nodeId);
return (
<>
<Tooltip
label={
nodeTemplate ? (
<TooltipContent nodeProps={nodeProps} nodeTemplate={nodeTemplate} />
) : undefined
}
label={<TooltipContent nodeId={nodeId} />}
placement="top"
shouldWrapChildren
>
@ -75,19 +65,10 @@ const NodeNotesEdit = (props: Props) => {
<Modal isOpen={isOpen} onClose={onClose} isCentered>
<ModalOverlay />
<ModalContent>
<ModalHeader>
{data.label || nodeTemplate?.title || 'Unknown Node'}
</ModalHeader>
<ModalHeader>{label || title || 'Unknown Node'}</ModalHeader>
<ModalCloseButton />
<ModalBody>
<FormControl>
<FormLabel>Notes</FormLabel>
<IAITextarea
value={data.notes}
onChange={handleNotesChanged}
rows={10}
/>
</FormControl>
<NotesTextarea nodeId={nodeId} />
</ModalBody>
<ModalFooter />
</ModalContent>
@ -98,16 +79,49 @@ const NodeNotesEdit = (props: Props) => {
export default memo(NodeNotesEdit);
type TooltipContentProps = Props;
const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
const data = useNodeData(nodeId);
const nodeTemplate = useNodeTemplate(nodeId);
if (!isInvocationNodeData(data)) {
return <Text sx={{ fontWeight: 600 }}>Unknown Node</Text>;
}
const TooltipContent = (props: TooltipContentProps) => {
return (
<Flex sx={{ flexDir: 'column' }}>
<Text sx={{ fontWeight: 600 }}>{props.nodeTemplate?.title}</Text>
<Text sx={{ fontWeight: 600 }}>{nodeTemplate?.title}</Text>
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
{props.nodeTemplate?.description}
{nodeTemplate?.description}
</Text>
{props.nodeProps.data.notes && <Text>{props.nodeProps.data.notes}</Text>}
{data?.notes && <Text>{data.notes}</Text>}
</Flex>
);
};
});
TooltipContent.displayName = 'TooltipContent';
const NotesTextarea = memo(({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const data = useNodeData(nodeId);
const handleNotesChanged = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(nodeNotesChanged({ nodeId, notes: e.target.value }));
},
[dispatch, nodeId]
);
if (!isInvocationNodeData(data)) {
return null;
}
return (
<FormControl>
<FormLabel>Notes</FormLabel>
<IAITextarea
value={data?.notes}
onChange={handleNotesChanged}
rows={10}
/>
</FormControl>
);
});
NotesTextarea.displayName = 'NodesTextarea';

View File

@ -11,17 +11,12 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import {
InvocationNodeData,
NodeExecutionState,
NodeStatus,
} from 'features/nodes/types/types';
import { NodeExecutionState, NodeStatus } from 'features/nodes/types/types';
import { memo, useMemo } from 'react';
import { FaCheck, FaEllipsisH, FaExclamation } from 'react-icons/fa';
import { NodeProps } from 'reactflow';
type Props = {
nodeProps: NodeProps<InvocationNodeData>;
nodeId: string;
};
const iconBoxSize = 3;
@ -33,8 +28,7 @@ const circleStyles = {
'.chakra-progress__track': { stroke: 'transparent' },
};
const NodeStatusIndicator = (props: Props) => {
const nodeId = props.nodeProps.data.id;
const NodeStatusIndicator = ({ nodeId }: Props) => {
const selectNodeExecutionState = useMemo(
() =>
createSelector(
@ -76,7 +70,7 @@ type TooltipLabelProps = {
nodeExecutionState: NodeExecutionState;
};
const TooltipLabel = ({ nodeExecutionState }: TooltipLabelProps) => {
const TooltipLabel = memo(({ nodeExecutionState }: TooltipLabelProps) => {
const { status, progress, progressImage } = nodeExecutionState;
if (status === NodeStatus.PENDING) {
return <Text>Pending</Text>;
@ -118,13 +112,15 @@ const TooltipLabel = ({ nodeExecutionState }: TooltipLabelProps) => {
}
return null;
};
});
TooltipLabel.displayName = 'TooltipLabel';
type StatusIconProps = {
nodeExecutionState: NodeExecutionState;
};
const StatusIcon = (props: StatusIconProps) => {
const StatusIcon = memo((props: StatusIconProps) => {
const { progress, status } = props.nodeExecutionState;
if (status === NodeStatus.PENDING) {
return (
@ -182,4 +178,6 @@ const StatusIcon = (props: StatusIconProps) => {
);
}
return null;
};
});
StatusIcon.displayName = 'StatusIcon';

View File

@ -7,26 +7,29 @@ import {
useEditableControls,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import {
useNodeLabel,
useNodeTemplateTitle,
} from 'features/nodes/hooks/useNodeData';
import { nodeLabelChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { NodeData } from 'features/nodes/types/types';
import { MouseEvent, memo, useCallback, useEffect, useState } from 'react';
type Props = {
nodeData: NodeData;
title: string;
nodeId: string;
title?: string;
};
const NodeTitle = (props: Props) => {
const { title } = props;
const { id: nodeId, label } = props.nodeData;
const NodeTitle = ({ nodeId, title }: Props) => {
const dispatch = useAppDispatch();
const [localTitle, setLocalTitle] = useState(label || title);
const label = useNodeLabel(nodeId);
const templateTitle = useNodeTemplateTitle(nodeId);
const [localTitle, setLocalTitle] = useState('');
const handleSubmit = useCallback(
async (newTitle: string) => {
dispatch(nodeLabelChanged({ nodeId, label: newTitle }));
setLocalTitle(newTitle || title);
setLocalTitle(newTitle || title || 'Problem Setting Title');
},
[nodeId, dispatch, title]
);
@ -37,8 +40,8 @@ const NodeTitle = (props: Props) => {
useEffect(() => {
// Another component may change the title; sync local title with global state
setLocalTitle(label || title);
}, [label, title]);
setLocalTitle(label || title || templateTitle || 'Problem Setting Title');
}, [label, templateTitle, title]);
return (
<Flex

View File

@ -6,10 +6,14 @@ import {
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { nodeClicked } from 'features/nodes/store/nodesSlice';
import { MouseEvent, PropsWithChildren, useCallback, useMemo } from 'react';
import {
MouseEvent,
PropsWithChildren,
memo,
useCallback,
useMemo,
} from 'react';
import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from '../../types/constants';
import { NodeData } from 'features/nodes/types/types';
import { NodeProps } from 'reactflow';
const useNodeSelect = (nodeId: string) => {
const dispatch = useAppDispatch();
@ -25,14 +29,13 @@ const useNodeSelect = (nodeId: string) => {
};
type NodeWrapperProps = PropsWithChildren & {
nodeProps: NodeProps<NodeData>;
nodeId: string;
selected: boolean;
width?: NonNullable<ChakraProps['sx']>['w'];
};
const NodeWrapper = (props: NodeWrapperProps) => {
const { width, children, nodeProps } = props;
const { data, selected } = nodeProps;
const nodeId = data.id;
const { width, children, nodeId, selected } = props;
const [
nodeSelectedOutlineLight,
@ -93,4 +96,4 @@ const NodeWrapper = (props: NodeWrapperProps) => {
);
};
export default NodeWrapper;
export default memo(NodeWrapper);

View File

@ -1,20 +1,26 @@
import { Box, Flex, Text } from '@chakra-ui/react';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { InvocationNodeData } from 'features/nodes/types/types';
import { memo } from 'react';
import { NodeProps } from 'reactflow';
import NodeCollapseButton from '../Invocation/NodeCollapseButton';
import NodeWrapper from '../Invocation/NodeWrapper';
type Props = {
nodeProps: NodeProps<InvocationNodeData>;
nodeId: string;
isOpen: boolean;
label: string;
type: string;
selected: boolean;
};
const UnknownNodeFallback = ({ nodeProps }: Props) => {
const { data } = nodeProps;
const { isOpen, label, type } = data;
const UnknownNodeFallback = ({
nodeId,
isOpen,
label,
type,
selected,
}: Props) => {
return (
<NodeWrapper nodeProps={nodeProps}>
<NodeWrapper nodeId={nodeId} selected={selected}>
<Flex
className={DRAG_HANDLE_CLASSNAME}
layerStyle="nodeHeader"
@ -27,7 +33,7 @@ const UnknownNodeFallback = ({ nodeProps }: Props) => {
fontSize: 'sm',
}}
>
<NodeCollapseButton nodeProps={nodeProps} />
<NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
<Text
sx={{
w: 'full',

View File

@ -46,7 +46,6 @@ const NodeEditor = () => {
<AnimatePresence>
{isReady && (
<motion.div
layoutId="node-editor-flow"
initial={{
opacity: 0,
}}
@ -67,7 +66,6 @@ const NodeEditor = () => {
<AnimatePresence>
{!isReady && (
<motion.div
layoutId="node-editor-loading"
initial={{
opacity: 0,
}}

View File

@ -15,7 +15,7 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch';
import { ChangeEvent, useCallback } from 'react';
import { ChangeEvent, memo, useCallback } from 'react';
import { FaCog } from 'react-icons/fa';
import {
shouldAnimateEdgesChanged,
@ -23,21 +23,26 @@ import {
shouldSnapToGridChanged,
shouldValidateGraphChanged,
} from '../store/nodesSlice';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
const selector = createSelector(stateSelector, ({ nodes }) => {
const {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
} = nodes;
return {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
};
});
const selector = createSelector(
stateSelector,
({ nodes }) => {
const {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
} = nodes;
return {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
};
},
defaultSelectorOptions
);
const NodeEditorSettings = () => {
const { isOpen, onOpen, onClose } = useDisclosure();
@ -136,4 +141,4 @@ const NodeEditorSettings = () => {
);
};
export default NodeEditorSettings;
export default memo(NodeEditorSettings);

View File

@ -1,19 +1,12 @@
import { Tooltip } from '@chakra-ui/react';
import { CSSProperties, memo, useMemo } from 'react';
import { Handle, HandleType, NodeProps, Position } from 'reactflow';
import { Handle, HandleType, Position } from 'reactflow';
import {
FIELDS,
HANDLE_TOOLTIP_OPEN_DELAY,
colorTokenToCssVar,
} from '../../types/constants';
import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
OutputFieldTemplate,
OutputFieldValue,
} from '../../types/types';
import { InputFieldTemplate, OutputFieldTemplate } from '../../types/types';
export const handleBaseStyles: CSSProperties = {
position: 'absolute',
@ -32,9 +25,6 @@ export const outputHandleStyles: CSSProperties = {
};
type FieldHandleProps = {
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
field: InputFieldValue | OutputFieldValue;
fieldTemplate: InputFieldTemplate | OutputFieldTemplate;
handleType: HandleType;
isConnectionInProgress: boolean;

View File

@ -8,13 +8,11 @@ import {
import { useAppDispatch } from 'app/store/storeHooks';
import IAIDraggable from 'common/components/IAIDraggable';
import { NodeFieldDraggableData } from 'features/dnd/types';
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
useFieldData,
useFieldTemplate,
} from 'features/nodes/hooks/useNodeData';
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
import {
MouseEvent,
memo,
@ -25,41 +23,43 @@ import {
} from 'react';
interface Props {
nodeData: InvocationNodeData;
nodeTemplate: InvocationTemplate;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
nodeId: string;
fieldName: string;
isDraggable?: boolean;
kind: 'input' | 'output';
}
const FieldTitle = (props: Props) => {
const { nodeData, field, fieldTemplate, isDraggable = false } = props;
const { label } = field;
const { title, input } = fieldTemplate;
const { id: nodeId } = nodeData;
const { nodeId, fieldName, isDraggable = false, kind } = props;
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
const field = useFieldData(nodeId, fieldName);
const dispatch = useAppDispatch();
const [localTitle, setLocalTitle] = useState(label || title);
const [localTitle, setLocalTitle] = useState(
field?.label || fieldTemplate?.title || 'Unknown Field'
);
const draggableData: NodeFieldDraggableData | undefined = useMemo(
() =>
input !== 'connection' && isDraggable
field &&
fieldTemplate?.fieldKind === 'input' &&
fieldTemplate?.input !== 'connection' &&
isDraggable
? {
id: `${nodeId}-${field.name}`,
id: `${nodeId}-${fieldName}`,
payloadType: 'NODE_FIELD',
payload: { nodeId, field, fieldTemplate },
}
: undefined,
[field, fieldTemplate, input, isDraggable, nodeId]
[field, fieldName, fieldTemplate, isDraggable, nodeId]
);
const handleSubmit = useCallback(
async (newTitle: string) => {
dispatch(
fieldLabelChanged({ nodeId, fieldName: field.name, label: newTitle })
);
setLocalTitle(newTitle || title);
dispatch(fieldLabelChanged({ nodeId, fieldName, label: newTitle }));
setLocalTitle(newTitle || fieldTemplate?.title || 'Unknown Field');
},
[dispatch, nodeId, field.name, title]
[dispatch, nodeId, fieldName, fieldTemplate?.title]
);
const handleChange = useCallback((newTitle: string) => {
@ -68,8 +68,8 @@ const FieldTitle = (props: Props) => {
useEffect(() => {
// Another component may change the title; sync local title with global state
setLocalTitle(label || title);
}, [label, title]);
setLocalTitle(field?.label || fieldTemplate?.title || 'Unknown Field');
}, [field?.label, fieldTemplate?.title]);
return (
<Flex
@ -120,7 +120,7 @@ type EditableControlsProps = {
draggableData?: NodeFieldDraggableData;
};
function EditableControls(props: EditableControlsProps) {
const EditableControls = memo((props: EditableControlsProps) => {
const { isEditing, getEditButtonProps } = useEditableControls();
const handleDoubleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
@ -158,4 +158,6 @@ function EditableControls(props: EditableControlsProps) {
cursor="text"
/>
);
}
});
EditableControls.displayName = 'EditableControls';

View File

@ -1,38 +1,53 @@
import { Flex, Text } from '@chakra-ui/react';
import {
useFieldData,
useFieldTemplate,
} from 'features/nodes/hooks/useNodeData';
import { FIELDS } from 'features/nodes/types/constants';
import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
OutputFieldTemplate,
OutputFieldValue,
isInputFieldTemplate,
isInputFieldValue,
} from 'features/nodes/types/types';
import { startCase } from 'lodash-es';
import { useMemo } from 'react';
interface Props {
nodeData: InvocationNodeData;
nodeTemplate: InvocationTemplate;
field: InputFieldValue | OutputFieldValue;
fieldTemplate: InputFieldTemplate | OutputFieldTemplate;
nodeId: string;
fieldName: string;
kind: 'input' | 'output';
}
const FieldTooltipContent = ({ field, fieldTemplate }: Props) => {
const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
const field = useFieldData(nodeId, fieldName);
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
const isInputTemplate = isInputFieldTemplate(fieldTemplate);
const fieldTitle = useMemo(() => {
if (isInputFieldValue(field)) {
if (field.label && fieldTemplate) {
return `${field.label} (${fieldTemplate.title})`;
}
if (field.label && !fieldTemplate) {
return field.label;
}
if (!field.label && fieldTemplate) {
return fieldTemplate.title;
}
return 'Unknown Field';
}
}, [field, fieldTemplate]);
return (
<Flex sx={{ flexDir: 'column' }}>
<Text sx={{ fontWeight: 600 }}>
{isInputFieldValue(field) && field.label
? `${field.label} (${fieldTemplate.title})`
: fieldTemplate.title}
</Text>
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
{fieldTemplate.description}
</Text>
<Text>Type: {FIELDS[fieldTemplate.type].title}</Text>
<Text sx={{ fontWeight: 600 }}>{fieldTitle}</Text>
{fieldTemplate && (
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
{fieldTemplate.description}
</Text>
)}
{fieldTemplate && <Text>Type: {FIELDS[fieldTemplate.type].title}</Text>}
{isInputTemplate && <Text>Input: {startCase(fieldTemplate.input)}</Text>}
</Flex>
);

View File

@ -1,27 +1,24 @@
import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import {
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { PropsWithChildren, useMemo } from 'react';
import { NodeProps } from 'reactflow';
useDoesInputHaveValue,
useFieldTemplate,
} from 'features/nodes/hooks/useNodeData';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import { PropsWithChildren, memo, useMemo } from 'react';
import FieldHandle from './FieldHandle';
import FieldTitle from './FieldTitle';
import FieldTooltipContent from './FieldTooltipContent';
import InputFieldRenderer from './InputFieldRenderer';
interface Props {
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
field: InputFieldValue;
nodeId: string;
fieldName: string;
}
const InputField = (props: Props) => {
const { nodeProps, nodeTemplate, field } = props;
const { id: nodeId } = nodeProps.data;
const InputField = ({ nodeId, fieldName }: Props) => {
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName);
const {
isConnected,
@ -29,15 +26,10 @@ const InputField = (props: Props) => {
isConnectionStartField,
connectionError,
shouldDim,
} = useConnectionState({ nodeId, field, kind: 'input' });
const fieldTemplate = useMemo(
() => nodeTemplate.inputs[field.name],
[field.name, nodeTemplate.inputs]
);
} = useConnectionState({ nodeId, fieldName, kind: 'input' });
const isMissingInput = useMemo(() => {
if (!fieldTemplate) {
if (fieldTemplate?.fieldKind !== 'input') {
return false;
}
@ -49,18 +41,18 @@ const InputField = (props: Props) => {
return true;
}
if (!field.value && !isConnected && fieldTemplate.input === 'any') {
if (!doesFieldHaveValue && !isConnected && fieldTemplate.input === 'any') {
return true;
}
}, [fieldTemplate, isConnected, field.value]);
}, [fieldTemplate, isConnected, doesFieldHaveValue]);
if (!fieldTemplate) {
if (fieldTemplate?.fieldKind !== 'input') {
return (
<InputFieldWrapper shouldDim={shouldDim}>
<FormControl
sx={{ color: 'error.400', textAlign: 'left', fontSize: 'sm' }}
>
Unknown input: {field.name}
Unknown input: {fieldName}
</FormControl>
</InputFieldWrapper>
);
@ -82,10 +74,9 @@ const InputField = (props: Props) => {
<Tooltip
label={
<FieldTooltipContent
nodeData={nodeProps.data}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
nodeId={nodeId}
fieldName={fieldName}
kind="input"
/>
}
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
@ -95,27 +86,18 @@ const InputField = (props: Props) => {
>
<FormLabel sx={{ mb: 0 }}>
<FieldTitle
nodeData={nodeProps.data}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
nodeId={nodeId}
fieldName={fieldName}
kind="input"
isDraggable
/>
</FormLabel>
</Tooltip>
<InputFieldRenderer
nodeData={nodeProps.data}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
</FormControl>
{fieldTemplate.input !== 'direct' && (
<FieldHandle
nodeProps={nodeProps}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
handleType="target"
isConnectionInProgress={isConnectionInProgress}
@ -133,21 +115,25 @@ type InputFieldWrapperProps = PropsWithChildren<{
shouldDim: boolean;
}>;
const InputFieldWrapper = ({ shouldDim, children }: InputFieldWrapperProps) => (
<Flex
className="nopan"
sx={{
position: 'relative',
minH: 8,
py: 0.5,
alignItems: 'center',
opacity: shouldDim ? 0.5 : 1,
transitionProperty: 'opacity',
transitionDuration: '0.1s',
w: 'full',
h: 'full',
}}
>
{children}
</Flex>
const InputFieldWrapper = memo(
({ shouldDim, children }: InputFieldWrapperProps) => (
<Flex
className="nopan"
sx={{
position: 'relative',
minH: 8,
py: 0.5,
alignItems: 'center',
opacity: shouldDim ? 0.5 : 1,
transitionProperty: 'opacity',
transitionDuration: '0.1s',
w: 'full',
h: 'full',
}}
>
{children}
</Flex>
)
);
InputFieldWrapper.displayName = 'InputFieldWrapper';

View File

@ -1,11 +1,9 @@
import { Box } from '@chakra-ui/react';
import { memo } from 'react';
import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
} from '../../types/types';
useFieldData,
useFieldTemplate,
} from 'features/nodes/hooks/useNodeData';
import { memo } from 'react';
import BooleanInputField from './fieldTypes/BooleanInputField';
import ClipInputField from './fieldTypes/ClipInputField';
import CollectionInputField from './fieldTypes/CollectionInputField';
@ -29,33 +27,33 @@ import VaeInputField from './fieldTypes/VaeInputField';
import VaeModelInputField from './fieldTypes/VaeModelInputField';
type InputFieldProps = {
nodeData: InvocationNodeData;
nodeTemplate: InvocationTemplate;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
nodeId: string;
fieldName: string;
};
// build an individual input element based on the schema
const InputFieldRenderer = (props: InputFieldProps) => {
const { nodeData, nodeTemplate, field, fieldTemplate } = props;
const { type } = field;
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
const field = useFieldData(nodeId, fieldName);
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
if (type === 'string' && fieldTemplate.type === 'string') {
if (fieldTemplate?.fieldKind === 'output') {
return <Box p={2}>Output field in input: {field?.type}</Box>;
}
if (field?.type === 'string' && fieldTemplate?.type === 'string') {
return (
<StringInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'boolean' && fieldTemplate.type === 'boolean') {
if (field?.type === 'boolean' && fieldTemplate?.type === 'boolean') {
return (
<BooleanInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
@ -63,46 +61,32 @@ const InputFieldRenderer = (props: InputFieldProps) => {
}
if (
(type === 'integer' && fieldTemplate.type === 'integer') ||
(type === 'float' && fieldTemplate.type === 'float')
(field?.type === 'integer' && fieldTemplate?.type === 'integer') ||
(field?.type === 'float' && fieldTemplate?.type === 'float')
) {
return (
<NumberInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'enum' && fieldTemplate.type === 'enum') {
if (field?.type === 'enum' && fieldTemplate?.type === 'enum') {
return (
<EnumInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'ImageField' && fieldTemplate.type === 'ImageField') {
if (field?.type === 'ImageField' && fieldTemplate?.type === 'ImageField') {
return (
<ImageInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'LatentsField' && fieldTemplate.type === 'LatentsField') {
return (
<LatentsInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
@ -110,68 +94,55 @@ const InputFieldRenderer = (props: InputFieldProps) => {
}
if (
type === 'ConditioningField' &&
fieldTemplate.type === 'ConditioningField'
field?.type === 'LatentsField' &&
fieldTemplate?.type === 'LatentsField'
) {
return (
<LatentsInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'ConditioningField' &&
fieldTemplate?.type === 'ConditioningField'
) {
return (
<ConditioningInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'UNetField' && fieldTemplate.type === 'UNetField') {
if (field?.type === 'UNetField' && fieldTemplate?.type === 'UNetField') {
return (
<UnetInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'ClipField' && fieldTemplate.type === 'ClipField') {
if (field?.type === 'ClipField' && fieldTemplate?.type === 'ClipField') {
return (
<ClipInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'VaeField' && fieldTemplate.type === 'VaeField') {
if (field?.type === 'VaeField' && fieldTemplate?.type === 'VaeField') {
return (
<VaeInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'ControlField' && fieldTemplate.type === 'ControlField') {
return (
<ControlInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'MainModelField' && fieldTemplate.type === 'MainModelField') {
return (
<MainModelInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
@ -179,35 +150,38 @@ const InputFieldRenderer = (props: InputFieldProps) => {
}
if (
type === 'SDXLRefinerModelField' &&
fieldTemplate.type === 'SDXLRefinerModelField'
field?.type === 'ControlField' &&
fieldTemplate?.type === 'ControlField'
) {
return (
<ControlInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'MainModelField' &&
fieldTemplate?.type === 'MainModelField'
) {
return (
<MainModelInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'SDXLRefinerModelField' &&
fieldTemplate?.type === 'SDXLRefinerModelField'
) {
return (
<RefinerModelInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'VaeModelField' && fieldTemplate.type === 'VaeModelField') {
return (
<VaeModelInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'LoRAModelField' && fieldTemplate.type === 'LoRAModelField') {
return (
<LoRAModelInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
@ -215,57 +189,48 @@ const InputFieldRenderer = (props: InputFieldProps) => {
}
if (
type === 'ControlNetModelField' &&
fieldTemplate.type === 'ControlNetModelField'
field?.type === 'VaeModelField' &&
fieldTemplate?.type === 'VaeModelField'
) {
return (
<VaeModelInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'LoRAModelField' &&
fieldTemplate?.type === 'LoRAModelField'
) {
return (
<LoRAModelInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'ControlNetModelField' &&
fieldTemplate?.type === 'ControlNetModelField'
) {
return (
<ControlNetModelInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'Collection' && fieldTemplate.type === 'Collection') {
if (field?.type === 'Collection' && fieldTemplate?.type === 'Collection') {
return (
<CollectionInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'CollectionItem' && fieldTemplate.type === 'CollectionItem') {
return (
<CollectionItemInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'ColorField' && fieldTemplate.type === 'ColorField') {
return (
<ColorInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (type === 'ImageCollection' && fieldTemplate.type === 'ImageCollection') {
return (
<ImageCollectionInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
@ -273,20 +238,55 @@ const InputFieldRenderer = (props: InputFieldProps) => {
}
if (
type === 'SDXLMainModelField' &&
fieldTemplate.type === 'SDXLMainModelField'
field?.type === 'CollectionItem' &&
fieldTemplate?.type === 'CollectionItem'
) {
return (
<SDXLMainModelInputField
nodeData={nodeData}
nodeTemplate={nodeTemplate}
<CollectionItemInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
return <Box p={2}>Unknown field type: {type}</Box>;
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
return (
<ColorInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'ImageCollection' &&
fieldTemplate?.type === 'ImageCollection'
) {
return (
<ImageCollectionInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'SDXLMainModelField' &&
fieldTemplate?.type === 'SDXLMainModelField'
) {
return (
<SDXLMainModelInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
return <Box p={2}>Unknown field type: {field?.type}</Box>;
};
export default memo(InputFieldRenderer);

View File

@ -1,39 +1,16 @@
import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { memo } from 'react';
import FieldTitle from './FieldTitle';
import FieldTooltipContent from './FieldTooltipContent';
import InputFieldRenderer from './InputFieldRenderer';
type Props = {
nodeData: InvocationNodeData;
nodeTemplate: InvocationTemplate;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
nodeId: string;
fieldName: string;
};
const LinearViewField = ({
nodeData,
nodeTemplate,
field,
fieldTemplate,
}: Props) => {
// const dispatch = useAppDispatch();
// const handleRemoveField = useCallback(() => {
// dispatch(
// workflowExposedFieldRemoved({
// nodeId: nodeData.id,
// fieldName: field.name,
// })
// );
// }, [dispatch, field.name, nodeData.id]);
const LinearViewField = ({ nodeId, fieldName }: Props) => {
return (
<Flex
layerStyle="second"
@ -48,10 +25,9 @@ const LinearViewField = ({
<Tooltip
label={
<FieldTooltipContent
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
nodeId={nodeId}
fieldName={fieldName}
kind="input"
/>
}
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
@ -66,20 +42,10 @@ const LinearViewField = ({
mb: 0,
}}
>
<FieldTitle
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
<FieldTitle nodeId={nodeId} fieldName={fieldName} kind="input" />
</FormLabel>
</Tooltip>
<InputFieldRenderer
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
/>
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
</FormControl>
</Flex>
);

View File

@ -6,25 +6,19 @@ import {
Tooltip,
} from '@chakra-ui/react';
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
import { useFieldTemplate } from 'features/nodes/hooks/useNodeData';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import {
InvocationNodeData,
InvocationTemplate,
OutputFieldValue,
} from 'features/nodes/types/types';
import { PropsWithChildren, useMemo } from 'react';
import { NodeProps } from 'reactflow';
import { PropsWithChildren, memo } from 'react';
import FieldHandle from './FieldHandle';
import FieldTooltipContent from './FieldTooltipContent';
interface Props {
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
field: OutputFieldValue;
nodeId: string;
fieldName: string;
}
const OutputField = (props: Props) => {
const { nodeTemplate, nodeProps, field } = props;
const OutputField = ({ nodeId, fieldName }: Props) => {
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'output');
const {
isConnected,
@ -32,20 +26,15 @@ const OutputField = (props: Props) => {
isConnectionStartField,
connectionError,
shouldDim,
} = useConnectionState({ nodeId: nodeProps.data.id, field, kind: 'output' });
} = useConnectionState({ nodeId, fieldName, kind: 'output' });
const fieldTemplate = useMemo(
() => nodeTemplate.outputs[field.name],
[field.name, nodeTemplate]
);
if (!fieldTemplate) {
if (fieldTemplate?.fieldKind !== 'output') {
return (
<OutputFieldWrapper shouldDim={shouldDim}>
<FormControl
sx={{ color: 'error.400', textAlign: 'right', fontSize: 'sm' }}
>
Unknown output: {field.name}
Unknown output: {fieldName}
</FormControl>
</OutputFieldWrapper>
);
@ -57,10 +46,9 @@ const OutputField = (props: Props) => {
<Tooltip
label={
<FieldTooltipContent
nodeData={nodeProps.data}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
nodeId={nodeId}
fieldName={fieldName}
kind="output"
/>
}
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
@ -75,9 +63,6 @@ const OutputField = (props: Props) => {
</FormControl>
</Tooltip>
<FieldHandle
nodeProps={nodeProps}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
handleType="source"
isConnectionInProgress={isConnectionInProgress}
@ -88,27 +73,28 @@ const OutputField = (props: Props) => {
);
};
export default OutputField;
export default memo(OutputField);
type OutputFieldWrapperProps = PropsWithChildren<{
shouldDim: boolean;
}>;
const OutputFieldWrapper = ({
shouldDim,
children,
}: OutputFieldWrapperProps) => (
<Flex
sx={{
position: 'relative',
minH: 8,
py: 0.5,
alignItems: 'center',
opacity: shouldDim ? 0.5 : 1,
transitionProperty: 'opacity',
transitionDuration: '0.1s',
}}
>
{children}
</Flex>
const OutputFieldWrapper = memo(
({ shouldDim, children }: OutputFieldWrapperProps) => (
<Flex
sx={{
position: 'relative',
minH: 8,
py: 0.5,
alignItems: 'center',
opacity: shouldDim ? 0.5 : 1,
transitionProperty: 'opacity',
transitionDuration: '0.1s',
}}
>
{children}
</Flex>
)
);
OutputFieldWrapper.displayName = 'OutputFieldWrapper';

View File

@ -11,8 +11,7 @@ import { FieldComponentProps } from './types';
const BooleanInputFieldComponent = (
props: FieldComponentProps<BooleanInputFieldValue, BooleanInputFieldTemplate>
) => {
const { nodeData, field } = props;
const nodeId = nodeData.id;
const { nodeId, field } = props;
const dispatch = useAppDispatch();

View File

@ -11,8 +11,7 @@ import { FieldComponentProps } from './types';
const ColorInputFieldComponent = (
props: FieldComponentProps<ColorInputFieldValue, ColorInputFieldTemplate>
) => {
const { nodeData, field } = props;
const nodeId = nodeData.id;
const { nodeId, field } = props;
const dispatch = useAppDispatch();

View File

@ -19,8 +19,7 @@ const ControlNetModelInputFieldComponent = (
ControlNetModelInputFieldTemplate
>
) => {
const { nodeData, field } = props;
const nodeId = nodeData.id;
const { nodeId, field } = props;
const controlNetModel = field.value;
const dispatch = useAppDispatch();

View File

@ -11,8 +11,7 @@ import { FieldComponentProps } from './types';
const EnumInputFieldComponent = (
props: FieldComponentProps<EnumInputFieldValue, EnumInputFieldTemplate>
) => {
const { nodeData, field, fieldTemplate } = props;
const nodeId = nodeData.id;
const { nodeId, field, fieldTemplate } = props;
const dispatch = useAppDispatch();

View File

@ -19,8 +19,7 @@ const ImageCollectionInputFieldComponent = (
ImageCollectionInputFieldTemplate
>
) => {
const { nodeData, field } = props;
const nodeId = nodeData.id;
const { nodeId, field } = props;
// const dispatch = useAppDispatch();

View File

@ -21,8 +21,7 @@ import { FieldComponentProps } from './types';
const ImageInputFieldComponent = (
props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate>
) => {
const { nodeData, field } = props;
const nodeId = nodeData.id;
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { currentData: imageDTO } = useGetImageDTOQuery(

View File

@ -21,8 +21,7 @@ const LoRAModelInputFieldComponent = (
LoRAModelInputFieldTemplate
>
) => {
const { nodeData, field } = props;
const nodeId = nodeData.id;
const { nodeId, field } = props;
const lora = field.value;
const dispatch = useAppDispatch();
const { data: loraModels } = useGetLoRAModelsQuery();

View File

@ -26,8 +26,7 @@ const MainModelInputFieldComponent = (
MainModelInputFieldTemplate
>
) => {
const { nodeData, field } = props;
const nodeId = nodeData.id;
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;

View File

@ -23,8 +23,7 @@ const NumberInputFieldComponent = (
IntegerInputFieldTemplate | FloatInputFieldTemplate
>
) => {
const { nodeData, field, fieldTemplate } = props;
const nodeId = nodeData.id;
const { nodeId, field, fieldTemplate } = props;
const dispatch = useAppDispatch();
const [valueAsString, setValueAsString] = useState<string>(
String(field.value)

View File

@ -24,8 +24,7 @@ const RefinerModelInputFieldComponent = (
SDXLRefinerModelInputFieldTemplate
>
) => {
const { nodeData, field } = props;
const nodeId = nodeData.id;
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;

View File

@ -27,8 +27,7 @@ const ModelInputFieldComponent = (
SDXLMainModelInputFieldTemplate
>
) => {
const { nodeData, field } = props;
const nodeId = nodeData.id;
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;

View File

@ -12,8 +12,7 @@ import { FieldComponentProps } from './types';
const StringInputFieldComponent = (
props: FieldComponentProps<StringInputFieldValue, StringInputFieldTemplate>
) => {
const { nodeData, field, fieldTemplate } = props;
const nodeId = nodeData.id;
const { nodeId, field, fieldTemplate } = props;
const dispatch = useAppDispatch();
const handleValueChanged = useCallback(

View File

@ -20,8 +20,7 @@ const VaeModelInputFieldComponent = (
VaeModelInputFieldTemplate
>
) => {
const { nodeData, field } = props;
const nodeId = nodeData.id;
const { nodeId, field } = props;
const vae = field.value;
const dispatch = useAppDispatch();
const { data: vaeModels } = useGetVaeModelsQuery();

View File

@ -1,16 +1,13 @@
import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
export type FieldComponentProps<
V extends InputFieldValue,
T extends InputFieldTemplate
> = {
nodeData: InvocationNodeData;
nodeTemplate: InvocationTemplate;
nodeId: string;
field: V;
fieldTemplate: T;
};

View File

@ -55,7 +55,11 @@ const CurrentImageNode = (props: NodeProps) => {
export default memo(CurrentImageNode);
const Wrapper = (props: PropsWithChildren<{ nodeProps: NodeProps }>) => (
<NodeWrapper nodeProps={props.nodeProps} width={384}>
<NodeWrapper
nodeId={props.nodeProps.data.id}
selected={props.nodeProps.selected}
width={384}
>
<Flex
className={DRAG_HANDLE_CLASSNAME}
sx={{

View File

@ -1,5 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { makeTemplateSelector } from 'features/nodes/store/util/makeTemplateSelector';
import { InvocationNodeData } from 'features/nodes/types/types';
import { memo, useMemo } from 'react';
import { NodeProps } from 'reactflow';
@ -7,18 +8,40 @@ import InvocationNode from '../Invocation/InvocationNode';
import UnknownNodeFallback from '../Invocation/UnknownNodeFallback';
const InvocationNodeWrapper = (props: NodeProps<InvocationNodeData>) => {
const { data } = props;
const { type } = data;
const { data, selected } = props;
const { id: nodeId, type, isOpen, label } = data;
const templateSelector = useMemo(() => makeTemplateSelector(type), [type]);
const hasTemplateSelector = useMemo(
() =>
createSelector(stateSelector, ({ nodes }) =>
Boolean(nodes.nodeTemplates[type])
),
[type]
);
const nodeTemplate = useAppSelector(templateSelector);
const nodeTemplate = useAppSelector(hasTemplateSelector);
if (!nodeTemplate) {
return <UnknownNodeFallback nodeProps={props} />;
return (
<UnknownNodeFallback
nodeId={nodeId}
isOpen={isOpen}
label={label}
type={type}
selected={selected}
/>
);
}
return <InvocationNode nodeProps={props} nodeTemplate={nodeTemplate} />;
return (
<InvocationNode
nodeId={nodeId}
isOpen={isOpen}
label={label}
type={type}
selected={selected}
/>
);
};
export default memo(InvocationNodeWrapper);

View File

@ -10,7 +10,7 @@ import NodeTitle from '../Invocation/NodeTitle';
import NodeWrapper from '../Invocation/NodeWrapper';
const NotesNode = (props: NodeProps<NotesNodeData>) => {
const { id: nodeId, data } = props;
const { id: nodeId, data, selected } = props;
const { notes, isOpen } = data;
const dispatch = useAppDispatch();
const handleChange = useCallback(
@ -21,7 +21,7 @@ const NotesNode = (props: NodeProps<NotesNodeData>) => {
);
return (
<NodeWrapper nodeProps={props}>
<NodeWrapper nodeId={nodeId} selected={selected}>
<Flex
layerStyle="nodeHeader"
sx={{
@ -32,8 +32,8 @@ const NotesNode = (props: NodeProps<NotesNodeData>) => {
h: 8,
}}
>
<NodeCollapseButton nodeProps={props} />
<NodeTitle nodeData={props.data} title="Notes" />
<NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
<NodeTitle nodeId={nodeId} title="Notes" />
<Box minW={8} />
</Flex>
{isOpen && (

View File

@ -6,39 +6,11 @@ import {
TabPanels,
Tabs,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ImageMetadataJSON from 'features/gallery/components/ImageMetadataViewer/ImageMetadataJSON';
import { memo } from 'react';
const selector = createSelector(
stateSelector,
({ nodes }) => {
const lastSelectedNodeId =
nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find(
(node) => node.id === lastSelectedNodeId
);
const lastSelectedNodeTemplate = lastSelectedNode
? nodes.nodeTemplates[lastSelectedNode.data.type]
: undefined;
return {
node: lastSelectedNode,
template: lastSelectedNodeTemplate,
};
},
defaultSelectorOptions
);
import NodeDataInspector from './NodeDataInspector';
import NodeTemplateInspector from './NodeTemplateInspector';
const InspectorPanel = () => {
const { node, template } = useAppSelector(selector);
return (
<Flex
layerStyle="first"
@ -60,37 +32,10 @@ const InspectorPanel = () => {
<TabPanels>
<TabPanel>
{template ? (
<Flex
sx={{
flexDir: 'column',
alignItems: 'flex-start',
gap: 2,
h: 'full',
}}
>
<ImageMetadataJSON
jsonObject={template}
label="Node Template"
/>
</Flex>
) : (
<IAINoContentFallback
label={
node
? 'No template found for selected node'
: 'No node selected'
}
icon={null}
/>
)}
<NodeTemplateInspector />
</TabPanel>
<TabPanel>
{node ? (
<ImageMetadataJSON jsonObject={node.data} label="Node Data" />
) : (
<IAINoContentFallback label="No node selected" icon={null} />
)}
<NodeDataInspector />
</TabPanel>
</TabPanels>
</Tabs>

View File

@ -17,20 +17,20 @@ const selector = createSelector(
);
return {
node: lastSelectedNode,
data: lastSelectedNode?.data,
};
},
defaultSelectorOptions
);
const NodeDataInspector = () => {
const { node } = useAppSelector(selector);
const { data } = useAppSelector(selector);
return node ? (
<ImageMetadataJSON jsonObject={node.data} label="Node Data" />
) : (
<IAINoContentFallback label="No node data" icon={null} />
);
if (!data) {
return <IAINoContentFallback label="No node selected" icon={null} />;
}
return <ImageMetadataJSON jsonObject={data} label="Node Data" />;
};
export default memo(NodeDataInspector);

View File

@ -0,0 +1,40 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ImageMetadataJSON from 'features/gallery/components/ImageMetadataViewer/ImageMetadataJSON';
import { memo } from 'react';
const selector = createSelector(
stateSelector,
({ nodes }) => {
const lastSelectedNodeId =
nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find(
(node) => node.id === lastSelectedNodeId
);
const lastSelectedNodeTemplate = lastSelectedNode
? nodes.nodeTemplates[lastSelectedNode.data.type]
: undefined;
return {
template: lastSelectedNodeTemplate,
};
},
defaultSelectorOptions
);
const NodeTemplateInspector = () => {
const { template } = useAppSelector(selector);
if (!template) {
return <IAINoContentFallback label="No node selected" icon={null} />;
}
return <ImageMetadataJSON jsonObject={template} label="Node Template" />;
};
export default memo(NodeTemplateInspector);

View File

@ -6,14 +6,6 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDroppable from 'common/components/IAIDroppable';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { AddFieldToLinearViewDropData } from 'features/dnd/types';
import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
isInvocationNode,
} from 'features/nodes/types/types';
import { forEach } from 'lodash-es';
import { memo } from 'react';
import LinearViewField from '../../fields/LinearViewField';
import ScrollableContent from '../ScrollableContent';
@ -21,41 +13,8 @@ import ScrollableContent from '../ScrollableContent';
const selector = createSelector(
stateSelector,
({ nodes }) => {
const fields: {
nodeData: InvocationNodeData;
nodeTemplate: InvocationTemplate;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
}[] = [];
const { exposedFields } = nodes.workflow;
nodes.nodes.filter(isInvocationNode).forEach((node) => {
const nodeTemplate = nodes.nodeTemplates[node.data.type];
if (!nodeTemplate) {
return;
}
forEach(node.data.inputs, (field) => {
if (
!exposedFields.some(
(f) => f.nodeId === node.id && f.fieldName === field.name
)
) {
return;
}
const fieldTemplate = nodeTemplate.inputs[field.name];
if (!fieldTemplate) {
return;
}
fields.push({
nodeData: node.data,
nodeTemplate,
field,
fieldTemplate,
});
});
});
return {
fields,
fields: nodes.workflow.exposedFields,
};
},
defaultSelectorOptions
@ -89,13 +48,11 @@ const LinearTabContent = () => {
}}
>
{fields.length ? (
fields.map(({ nodeData, nodeTemplate, field, fieldTemplate }) => (
fields.map(({ nodeId, fieldName }) => (
<LinearViewField
key={field.id}
nodeData={nodeData}
nodeTemplate={nodeTemplate}
field={field}
fieldTemplate={fieldTemplate}
key={`${nodeId}-${fieldName}`}
nodeId={nodeId}
fieldName={fieldName}
/>
))
) : (

View File

@ -25,7 +25,9 @@ const ClearGraphButton = () => {
const { isOpen, onOpen, onClose } = useDisclosure();
const cancelRef = useRef<HTMLButtonElement | null>(null);
const nodes = useAppSelector((state: RootState) => state.nodes.nodes);
const nodesCount = useAppSelector(
(state: RootState) => state.nodes.nodes.length
);
const handleConfirmClear = useCallback(() => {
dispatch(nodeEditorReset());
@ -49,7 +51,7 @@ const ClearGraphButton = () => {
tooltip={t('nodes.clearGraph')}
aria-label={t('nodes.clearGraph')}
onClick={onOpen}
isDisabled={nodes.length === 0}
isDisabled={!nodesCount}
/>
<AlertDialog

View File

@ -8,7 +8,7 @@ import IAIIconButton, {
import { selectIsReadyNodes } from 'features/nodes/store/selectors';
import ProgressBar from 'features/system/components/ProgressBar';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useCallback } from 'react';
import { memo, useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaPlay } from 'react-icons/fa';
@ -18,7 +18,7 @@ interface InvokeButton
iconButton?: boolean;
}
export default function NodeInvokeButton(props: InvokeButton) {
const NodeInvokeButton = (props: InvokeButton) => {
const { iconButton = false, ...rest } = props;
const dispatch = useAppDispatch();
const activeTabName = useAppSelector(activeTabNameSelector);
@ -92,4 +92,6 @@ export default function NodeInvokeButton(props: InvokeButton) {
</Box>
</Box>
);
}
};
export default memo(NodeInvokeButton);

View File

@ -1,11 +1,11 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { useCallback } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { FaSyncAlt } from 'react-icons/fa';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
export default function ReloadSchemaButton() {
const ReloadSchemaButton = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
@ -21,4 +21,6 @@ export default function ReloadSchemaButton() {
onClick={handleReloadSchema}
/>
);
}
};
export default memo(ReloadSchemaButton);

View File

@ -2,8 +2,8 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeIsConnectionValidSelector';
import { InputFieldValue, OutputFieldValue } from 'features/nodes/types/types';
import { useMemo } from 'react';
import { useFieldType } from './useNodeData';
const selectIsConnectionInProgress = createSelector(
stateSelector,
@ -12,23 +12,19 @@ const selectIsConnectionInProgress = createSelector(
nodes.connectionStartParams !== null
);
export type UseConnectionStateProps =
| {
nodeId: string;
field: InputFieldValue;
kind: 'input';
}
| {
nodeId: string;
field: OutputFieldValue;
kind: 'output';
};
export type UseConnectionStateProps = {
nodeId: string;
fieldName: string;
kind: 'input' | 'output';
};
export const useConnectionState = ({
nodeId,
field,
fieldName,
kind,
}: UseConnectionStateProps) => {
const fieldType = useFieldType(nodeId, fieldName, kind);
const selectIsConnected = useMemo(
() =>
createSelector(stateSelector, ({ nodes }) =>
@ -37,23 +33,23 @@ export const useConnectionState = ({
return (
(kind === 'input' ? edge.target : edge.source) === nodeId &&
(kind === 'input' ? edge.targetHandle : edge.sourceHandle) ===
field.name
fieldName
);
}).length
)
),
[field.name, kind, nodeId]
[fieldName, kind, nodeId]
);
const selectConnectionError = useMemo(
() =>
makeConnectionErrorSelector(
nodeId,
field.name,
fieldName,
kind === 'input' ? 'target' : 'source',
field.type
fieldType
),
[nodeId, field.name, field.type, kind]
[nodeId, fieldName, kind, fieldType]
);
const selectIsConnectionStartField = useMemo(
@ -61,12 +57,12 @@ export const useConnectionState = ({
createSelector(stateSelector, ({ nodes }) =>
Boolean(
nodes.connectionStartParams?.nodeId === nodeId &&
nodes.connectionStartParams?.handleId === field.name &&
nodes.connectionStartParams?.handleId === fieldName &&
nodes.connectionStartParams?.handleType ===
{ input: 'target', output: 'source' }[kind]
)
),
[field.name, kind, nodeId]
[fieldName, kind, nodeId]
);
const isConnected = useAppSelector(selectIsConnected);

View File

@ -0,0 +1,286 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { map, some } from 'lodash-es';
import { useMemo } from 'react';
import { FOOTER_FIELDS, IMAGE_FIELDS } from '../types/constants';
import { isInvocationNode } from '../types/types';
const KIND_MAP = {
input: 'inputs' as const,
output: 'outputs' as const,
};
export const useNodeTemplate = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
return nodeTemplate;
},
defaultSelectorOptions
),
[nodeId]
);
const nodeTemplate = useAppSelector(selector);
return nodeTemplate;
};
export const useNodeData = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
return node?.data;
},
defaultSelectorOptions
),
[nodeId]
);
const nodeData = useAppSelector(selector);
return nodeData;
};
export const useFieldData = (nodeId: string, fieldName: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return;
}
return node?.data.inputs[fieldName];
},
defaultSelectorOptions
),
[fieldName, nodeId]
);
const fieldData = useAppSelector(selector);
return fieldData;
};
export const useFieldType = (
nodeId: string,
fieldName: string,
kind: 'input' | 'output'
) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return;
}
return node?.data[KIND_MAP[kind]][fieldName]?.type;
},
defaultSelectorOptions
),
[fieldName, kind, nodeId]
);
const fieldType = useAppSelector(selector);
return fieldType;
};
export const useFieldNames = (nodeId: string, kind: 'input' | 'output') => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return [];
}
return map(node.data[KIND_MAP[kind]], (field) => field.name).filter(
(fieldName) => fieldName !== 'is_intermediate'
);
},
defaultSelectorOptions
),
[kind, nodeId]
);
const fieldNames = useAppSelector(selector);
return fieldNames;
};
export const useWithFooter = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
return some(node.data.outputs, (output) =>
FOOTER_FIELDS.includes(output.type)
);
},
defaultSelectorOptions
),
[nodeId]
);
const withFooter = useAppSelector(selector);
return withFooter;
};
export const useHasImageOutput = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
return some(node.data.outputs, (output) =>
IMAGE_FIELDS.includes(output.type)
);
},
defaultSelectorOptions
),
[nodeId]
);
const hasImageOutput = useAppSelector(selector);
return hasImageOutput;
};
export const useIsIntermediate = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
return Boolean(node.data.inputs.is_intermediate?.value);
},
defaultSelectorOptions
),
[nodeId]
);
const is_intermediate = useAppSelector(selector);
return is_intermediate;
};
export const useNodeLabel = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
return node.data.label;
},
defaultSelectorOptions
),
[nodeId]
);
const label = useAppSelector(selector);
return label;
};
export const useNodeTemplateTitle = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
const nodeTemplate = node
? nodes.nodeTemplates[node.data.type]
: undefined;
return nodeTemplate?.title;
},
defaultSelectorOptions
),
[nodeId]
);
const title = useAppSelector(selector);
return title;
};
export const useFieldTemplate = (
nodeId: string,
fieldName: string,
kind: 'input' | 'output'
) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return;
}
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
return nodeTemplate?.[KIND_MAP[kind]][fieldName];
},
defaultSelectorOptions
),
[fieldName, kind, nodeId]
);
const fieldTemplate = useAppSelector(selector);
return fieldTemplate;
};
export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return;
}
return Boolean(node?.data.inputs[fieldName]?.value);
},
defaultSelectorOptions
),
[fieldName, nodeId]
);
const doesFieldHaveValue = useAppSelector(selector);
return doesFieldHaveValue;
};

View File

@ -9,9 +9,13 @@ export const makeConnectionErrorSelector = (
nodeId: string,
fieldName: string,
handleType: HandleType,
fieldType: FieldType
fieldType?: FieldType
) =>
createSelector(stateSelector, (state) => {
if (!fieldType) {
return 'No field type';
}
const { currentConnectionFieldType, connectionStartParams, nodes, edges } =
state.nodes;

View File

@ -6,6 +6,9 @@ export const NODE_WIDTH = 320;
export const NODE_MIN_WIDTH = 320;
export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle';
export const IMAGE_FIELDS = ['ImageField', 'ImageCollection'];
export const FOOTER_FIELDS = IMAGE_FIELDS;
export const COLLECTION_TYPES: FieldType[] = [
'Collection',
'IntegerCollection',

View File

@ -457,12 +457,13 @@ export type ColorInputFieldTemplate = InputFieldTemplateBase & {
};
export const isInputFieldValue = (
field: InputFieldValue | OutputFieldValue
): field is InputFieldValue => field.fieldKind === 'input';
field?: InputFieldValue | OutputFieldValue
): field is InputFieldValue => Boolean(field && field.fieldKind === 'input');
export const isInputFieldTemplate = (
fieldTemplate: InputFieldTemplate | OutputFieldTemplate
): fieldTemplate is InputFieldTemplate => fieldTemplate.fieldKind === 'input';
fieldTemplate?: InputFieldTemplate | OutputFieldTemplate
): fieldTemplate is InputFieldTemplate =>
Boolean(fieldTemplate && fieldTemplate.fieldKind === 'input');
/**
* JANKY CUSTOMISATION OF OpenAPI SCHEMA TYPES
@ -632,20 +633,22 @@ export type NodeData =
export const isInvocationNode = (
node?: Node<NodeData>
): node is Node<InvocationNodeData> => node?.type === 'invocation';
): node is Node<InvocationNodeData> =>
Boolean(node && node.type === 'invocation');
export const isInvocationNodeData = (
node?: NodeData
): node is InvocationNodeData =>
!['notes', 'current_image'].includes(node?.type ?? '');
Boolean(node && !['notes', 'current_image'].includes(node.type));
export const isNotesNode = (
node?: Node<NodeData>
): node is Node<NotesNodeData> => node?.type === 'notes';
): node is Node<NotesNodeData> => Boolean(node && node.type === 'notes');
export const isProgressImageNode = (
node?: Node<NodeData>
): node is Node<CurrentImageNodeData> => node?.type === 'current_image';
): node is Node<CurrentImageNodeData> =>
Boolean(node && node.type === 'current_image');
export enum NodeStatus {
PENDING,

View File

@ -32,6 +32,7 @@ import {
MAIN_MODEL_LOADER,
MASK_BLUR,
MASK_COMBINE,
MASK_EDGE,
MASK_FROM_ALPHA,
MASK_RESIZE_DOWN,
MASK_RESIZE_UP,
@ -40,6 +41,10 @@ import {
POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
SEAM_FIX_DENOISE_LATENTS,
SEAM_MASK_COMBINE,
SEAM_MASK_RESIZE_DOWN,
SEAM_MASK_RESIZE_UP,
} from './constants';
/**
@ -67,6 +72,12 @@ export const buildCanvasOutpaintGraph = (
shouldUseCpuNoise,
maskBlur,
maskBlurMethod,
seamSize,
seamBlur,
seamSteps,
seamStrength,
seamLowThreshold,
seamHighThreshold,
tileSize,
infillMethod,
clipSkip,
@ -130,6 +141,11 @@ export const buildCanvasOutpaintGraph = (
is_intermediate: true,
mask2: canvasMaskImage,
},
[SEAM_MASK_COMBINE]: {
type: 'mask_combine',
id: MASK_COMBINE,
is_intermediate: true,
},
[MASK_BLUR]: {
type: 'img_blur',
id: MASK_BLUR,
@ -165,6 +181,25 @@ export const buildCanvasOutpaintGraph = (
denoising_start: 1 - strength,
denoising_end: 1,
},
[MASK_EDGE]: {
type: 'mask_edge',
id: MASK_EDGE,
is_intermediate: true,
edge_size: seamSize,
edge_blur: seamBlur,
low_threshold: seamLowThreshold,
high_threshold: seamHighThreshold,
},
[SEAM_FIX_DENOISE_LATENTS]: {
type: 'denoise_latents',
id: SEAM_FIX_DENOISE_LATENTS,
is_intermediate: true,
steps: seamSteps,
cfg_scale: cfg_scale,
scheduler: scheduler,
denoising_start: 1 - seamStrength,
denoising_end: 1,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
@ -333,12 +368,63 @@ export const buildCanvasOutpaintGraph = (
field: 'seed',
},
},
// Decode the result from Inpaint
// Seam Paint
{
source: {
node_id: MAIN_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'unet',
},
},
{
source: {
node_id: POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'positive_conditioning',
},
},
{
source: {
node_id: NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'negative_conditioning',
},
},
{
source: {
node_id: NOISE,
field: 'noise',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'noise',
},
},
{
source: {
node_id: DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'latents',
},
},
// Decode the result from Inpaint
{
source: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
@ -348,7 +434,6 @@ export const buildCanvasOutpaintGraph = (
};
// Add Infill Nodes
if (infillMethod === 'patchmatch') {
graph.nodes[INPAINT_INFILL] = {
type: 'infill_patchmatch',
@ -378,6 +463,13 @@ export const buildCanvasOutpaintGraph = (
width: scaledWidth,
height: scaledHeight,
};
graph.nodes[SEAM_MASK_RESIZE_UP] = {
type: 'img_resize',
id: SEAM_MASK_RESIZE_UP,
is_intermediate: true,
width: scaledWidth,
height: scaledHeight,
};
graph.nodes[INPAINT_IMAGE_RESIZE_DOWN] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_DOWN,
@ -399,6 +491,13 @@ export const buildCanvasOutpaintGraph = (
width: width,
height: height,
};
graph.nodes[SEAM_MASK_RESIZE_DOWN] = {
type: 'img_resize',
id: SEAM_MASK_RESIZE_DOWN,
is_intermediate: true,
width: width,
height: height,
};
graph.nodes[NOISE] = {
...(graph.nodes[NOISE] as NoiseInvocation),
@ -440,6 +539,57 @@ export const buildCanvasOutpaintGraph = (
field: 'image',
},
},
// Seam Paint Mask
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'image',
},
destination: {
node_id: MASK_EDGE,
field: 'image',
},
},
{
source: {
node_id: MASK_EDGE,
field: 'image',
},
destination: {
node_id: SEAM_MASK_RESIZE_UP,
field: 'image',
},
},
{
source: {
node_id: SEAM_MASK_RESIZE_UP,
field: 'image',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'mask',
},
},
{
source: {
node_id: MASK_BLUR,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask1',
},
},
{
source: {
node_id: SEAM_MASK_RESIZE_UP,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask2',
},
},
// Resize Results Down
{
source: {
@ -453,7 +603,7 @@ export const buildCanvasOutpaintGraph = (
},
{
source: {
node_id: MASK_BLUR,
node_id: MASK_RESIZE_UP,
field: 'image',
},
destination: {
@ -461,6 +611,16 @@ export const buildCanvasOutpaintGraph = (
field: 'image',
},
},
{
source: {
node_id: SEAM_MASK_COMBINE,
field: 'image',
},
destination: {
node_id: SEAM_MASK_RESIZE_DOWN,
field: 'image',
},
},
{
source: {
node_id: INPAINT_INFILL,
@ -494,7 +654,7 @@ export const buildCanvasOutpaintGraph = (
},
{
source: {
node_id: MASK_RESIZE_DOWN,
node_id: SEAM_MASK_RESIZE_DOWN,
field: 'image',
},
destination: {
@ -525,7 +685,7 @@ export const buildCanvasOutpaintGraph = (
},
{
source: {
node_id: MASK_RESIZE_DOWN,
node_id: SEAM_MASK_RESIZE_DOWN,
field: 'image',
},
destination: {
@ -553,7 +713,6 @@ export const buildCanvasOutpaintGraph = (
};
graph.nodes[MASK_BLUR] = {
...(graph.nodes[MASK_BLUR] as ImageBlurInvocation),
image: canvasMaskImage,
};
graph.edges.push(
@ -568,6 +727,47 @@ export const buildCanvasOutpaintGraph = (
field: 'image',
},
},
// Seam Paint Mask
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'image',
},
destination: {
node_id: MASK_EDGE,
field: 'image',
},
},
{
source: {
node_id: MASK_EDGE,
field: 'image',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'mask',
},
},
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask1',
},
},
{
source: {
node_id: MASK_EDGE,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask2',
},
},
// Color Correct The Inpainted Result
{
source: {
@ -591,7 +791,7 @@ export const buildCanvasOutpaintGraph = (
},
{
source: {
node_id: MASK_BLUR,
node_id: SEAM_MASK_COMBINE,
field: 'image',
},
destination: {
@ -622,7 +822,7 @@ export const buildCanvasOutpaintGraph = (
},
{
source: {
node_id: MASK_BLUR,
node_id: SEAM_MASK_COMBINE,
field: 'image',
},
destination: {

View File

@ -29,6 +29,7 @@ import {
LATENTS_TO_IMAGE,
MASK_BLUR,
MASK_COMBINE,
MASK_EDGE,
MASK_FROM_ALPHA,
MASK_RESIZE_DOWN,
MASK_RESIZE_UP,
@ -40,6 +41,10 @@ import {
SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER,
SEAM_FIX_DENOISE_LATENTS,
SEAM_MASK_COMBINE,
SEAM_MASK_RESIZE_DOWN,
SEAM_MASK_RESIZE_UP,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -67,6 +72,12 @@ export const buildCanvasSDXLOutpaintGraph = (
shouldUseCpuNoise,
maskBlur,
maskBlurMethod,
seamSize,
seamBlur,
seamSteps,
seamStrength,
seamLowThreshold,
seamHighThreshold,
tileSize,
infillMethod,
} = state.generation;
@ -133,6 +144,11 @@ export const buildCanvasSDXLOutpaintGraph = (
is_intermediate: true,
mask2: canvasMaskImage,
},
[SEAM_MASK_COMBINE]: {
type: 'mask_combine',
id: MASK_COMBINE,
is_intermediate: true,
},
[MASK_BLUR]: {
type: 'img_blur',
id: MASK_BLUR,
@ -170,6 +186,25 @@ export const buildCanvasSDXLOutpaintGraph = (
: 1 - strength,
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
},
[MASK_EDGE]: {
type: 'mask_edge',
id: MASK_EDGE,
is_intermediate: true,
edge_size: seamSize,
edge_blur: seamBlur,
low_threshold: seamLowThreshold,
high_threshold: seamHighThreshold,
},
[SEAM_FIX_DENOISE_LATENTS]: {
type: 'denoise_latents',
id: SEAM_FIX_DENOISE_LATENTS,
is_intermediate: true,
steps: seamSteps,
cfg_scale: cfg_scale,
scheduler: scheduler,
denoising_start: 1 - seamStrength,
denoising_end: 1,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
@ -347,12 +382,63 @@ export const buildCanvasSDXLOutpaintGraph = (
field: 'seed',
},
},
// Decode inpainted latents to image
// Seam Paint
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'unet',
},
},
{
source: {
node_id: POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'positive_conditioning',
},
},
{
source: {
node_id: NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'negative_conditioning',
},
},
{
source: {
node_id: NOISE,
field: 'noise',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'noise',
},
},
{
source: {
node_id: SDXL_DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'latents',
},
},
// Decode inpainted latents to image
{
source: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
@ -392,6 +478,13 @@ export const buildCanvasSDXLOutpaintGraph = (
width: scaledWidth,
height: scaledHeight,
};
graph.nodes[SEAM_MASK_RESIZE_UP] = {
type: 'img_resize',
id: SEAM_MASK_RESIZE_UP,
is_intermediate: true,
width: scaledWidth,
height: scaledHeight,
};
graph.nodes[INPAINT_IMAGE_RESIZE_DOWN] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_DOWN,
@ -413,6 +506,13 @@ export const buildCanvasSDXLOutpaintGraph = (
width: width,
height: height,
};
graph.nodes[SEAM_MASK_RESIZE_DOWN] = {
type: 'img_resize',
id: SEAM_MASK_RESIZE_DOWN,
is_intermediate: true,
width: width,
height: height,
};
graph.nodes[NOISE] = {
...(graph.nodes[NOISE] as NoiseInvocation),
@ -454,6 +554,57 @@ export const buildCanvasSDXLOutpaintGraph = (
field: 'image',
},
},
// Seam Paint Mask
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'image',
},
destination: {
node_id: MASK_EDGE,
field: 'image',
},
},
{
source: {
node_id: MASK_EDGE,
field: 'image',
},
destination: {
node_id: SEAM_MASK_RESIZE_UP,
field: 'image',
},
},
{
source: {
node_id: SEAM_MASK_RESIZE_UP,
field: 'image',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'mask',
},
},
{
source: {
node_id: MASK_BLUR,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask1',
},
},
{
source: {
node_id: SEAM_MASK_RESIZE_UP,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask2',
},
},
// Resize Results Down
{
source: {
@ -467,7 +618,7 @@ export const buildCanvasSDXLOutpaintGraph = (
},
{
source: {
node_id: MASK_BLUR,
node_id: MASK_RESIZE_UP,
field: 'image',
},
destination: {
@ -475,6 +626,16 @@ export const buildCanvasSDXLOutpaintGraph = (
field: 'image',
},
},
{
source: {
node_id: SEAM_MASK_COMBINE,
field: 'image',
},
destination: {
node_id: SEAM_MASK_RESIZE_DOWN,
field: 'image',
},
},
{
source: {
node_id: INPAINT_INFILL,
@ -508,7 +669,7 @@ export const buildCanvasSDXLOutpaintGraph = (
},
{
source: {
node_id: MASK_RESIZE_DOWN,
node_id: SEAM_MASK_RESIZE_DOWN,
field: 'image',
},
destination: {
@ -539,7 +700,7 @@ export const buildCanvasSDXLOutpaintGraph = (
},
{
source: {
node_id: MASK_RESIZE_DOWN,
node_id: SEAM_MASK_RESIZE_DOWN,
field: 'image',
},
destination: {
@ -567,7 +728,6 @@ export const buildCanvasSDXLOutpaintGraph = (
};
graph.nodes[MASK_BLUR] = {
...(graph.nodes[MASK_BLUR] as ImageBlurInvocation),
image: canvasMaskImage,
};
graph.edges.push(
@ -582,6 +742,47 @@ export const buildCanvasSDXLOutpaintGraph = (
field: 'image',
},
},
// Seam Paint Mask
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'image',
},
destination: {
node_id: MASK_EDGE,
field: 'image',
},
},
{
source: {
node_id: MASK_EDGE,
field: 'image',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'mask',
},
},
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask1',
},
},
{
source: {
node_id: MASK_EDGE,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask2',
},
},
// Color Correct The Inpainted Result
{
source: {
@ -605,7 +806,7 @@ export const buildCanvasSDXLOutpaintGraph = (
},
{
source: {
node_id: MASK_BLUR,
node_id: SEAM_MASK_COMBINE,
field: 'image',
},
destination: {
@ -636,7 +837,7 @@ export const buildCanvasSDXLOutpaintGraph = (
},
{
source: {
node_id: MASK_BLUR,
node_id: SEAM_MASK_COMBINE,
field: 'image',
},
destination: {
@ -669,7 +870,7 @@ export const buildCanvasSDXLOutpaintGraph = (
// Add Refiner if enabled
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
addSDXLRefinerToGraph(state, graph, SEAM_FIX_DENOISE_LATENTS);
}
// optionally add custom VAE

View File

@ -18,8 +18,6 @@ export const IMAGE_TO_LATENTS = 'image_to_latents';
export const LATENTS_TO_LATENTS = 'latents_to_latents';
export const RESIZE = 'resize_image';
export const CANVAS_OUTPUT = 'canvas_output';
export const INPAINT = 'inpaint';
export const INPAINT_SEAM_FIX = 'inpaint_seam_fix';
export const INPAINT_IMAGE = 'inpaint_image';
export const SCALED_INPAINT_IMAGE = 'scaled_inpaint_image';
export const INPAINT_IMAGE_RESIZE_UP = 'inpaint_image_resize_up';
@ -27,10 +25,14 @@ export const INPAINT_IMAGE_RESIZE_DOWN = 'inpaint_image_resize_down';
export const INPAINT_INFILL = 'inpaint_infill';
export const INPAINT_INFILL_RESIZE_DOWN = 'inpaint_infill_resize_down';
export const INPAINT_FINAL_IMAGE = 'inpaint_final_image';
export const SEAM_FIX_DENOISE_LATENTS = 'seam_fix_denoise_latents';
export const MASK_FROM_ALPHA = 'tomask';
export const MASK_EDGE = 'mask_edge';
export const MASK_BLUR = 'mask_blur';
export const MASK_COMBINE = 'mask_combine';
export const SEAM_MASK_COMBINE = 'seam_mask_combine';
export const SEAM_MASK_RESIZE_UP = 'seam_mask_resize_up';
export const SEAM_MASK_RESIZE_DOWN = 'seam_mask_resize_down';
export const MASK_RESIZE_UP = 'mask_resize_up';
export const MASK_RESIZE_DOWN = 'mask_resize_down';
export const COLOR_CORRECT = 'color_correct';

View File

@ -0,0 +1,36 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setSeamBlur } from 'features/parameters/store/generationSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const ParamSeamBlur = () => {
const dispatch = useAppDispatch();
const seamBlur = useAppSelector(
(state: RootState) => state.generation.seamBlur
);
const { t } = useTranslation();
return (
<IAISlider
label={t('parameters.seamBlur')}
min={0}
max={64}
step={8}
sliderNumberInputProps={{ max: 512 }}
value={seamBlur}
onChange={(v) => {
dispatch(setSeamBlur(v));
}}
withInput
withSliderMarks
withReset
handleReset={() => {
dispatch(setSeamBlur(8));
}}
/>
);
};
export default memo(ParamSeamBlur);

View File

@ -0,0 +1,27 @@
import { Flex } from '@chakra-ui/react';
import IAICollapse from 'common/components/IAICollapse';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import ParamSeamBlur from './ParamSeamBlur';
import ParamSeamSize from './ParamSeamSize';
import ParamSeamSteps from './ParamSeamSteps';
import ParamSeamStrength from './ParamSeamStrength';
import ParamSeamThreshold from './ParamSeamThreshold';
const ParamSeamPaintingCollapse = () => {
const { t } = useTranslation();
return (
<IAICollapse label={t('parameters.seamPaintingHeader')}>
<Flex sx={{ flexDirection: 'column', gap: 2, paddingBottom: 2 }}>
<ParamSeamSize />
<ParamSeamBlur />
<ParamSeamSteps />
<ParamSeamStrength />
<ParamSeamThreshold />
</Flex>
</IAICollapse>
);
};
export default memo(ParamSeamPaintingCollapse);

View File

@ -0,0 +1,36 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setSeamSize } from 'features/parameters/store/generationSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const ParamSeamSize = () => {
const dispatch = useAppDispatch();
const seamSize = useAppSelector(
(state: RootState) => state.generation.seamSize
);
const { t } = useTranslation();
return (
<IAISlider
label={t('parameters.seamSize')}
min={0}
max={128}
step={8}
sliderNumberInputProps={{ max: 512 }}
value={seamSize}
onChange={(v) => {
dispatch(setSeamSize(v));
}}
withInput
withSliderMarks
withReset
handleReset={() => {
dispatch(setSeamSize(16));
}}
/>
);
};
export default memo(ParamSeamSize);

View File

@ -0,0 +1,36 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setSeamSteps } from 'features/parameters/store/generationSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const ParamSeamSteps = () => {
const dispatch = useAppDispatch();
const seamSteps = useAppSelector(
(state: RootState) => state.generation.seamSteps
);
const { t } = useTranslation();
return (
<IAISlider
label={t('parameters.seamSteps')}
min={0}
max={100}
step={1}
sliderNumberInputProps={{ max: 999 }}
value={seamSteps}
onChange={(v) => {
dispatch(setSeamSteps(v));
}}
withInput
withSliderMarks
withReset
handleReset={() => {
dispatch(setSeamSteps(20));
}}
/>
);
};
export default memo(ParamSeamSteps);

View File

@ -0,0 +1,36 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setSeamStrength } from 'features/parameters/store/generationSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const ParamSeamStrength = () => {
const dispatch = useAppDispatch();
const seamStrength = useAppSelector(
(state: RootState) => state.generation.seamStrength
);
const { t } = useTranslation();
return (
<IAISlider
label={t('parameters.seamStrength')}
min={0}
max={1}
step={0.01}
sliderNumberInputProps={{ max: 999 }}
value={seamStrength}
onChange={(v) => {
dispatch(setSeamStrength(v));
}}
withInput
withSliderMarks
withReset
handleReset={() => {
dispatch(setSeamStrength(0.7));
}}
/>
);
};
export default memo(ParamSeamStrength);

View File

@ -0,0 +1,121 @@
import {
FormControl,
FormLabel,
HStack,
RangeSlider,
RangeSliderFilledTrack,
RangeSliderMark,
RangeSliderThumb,
RangeSliderTrack,
Tooltip,
} from '@chakra-ui/react';
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import {
setSeamHighThreshold,
setSeamLowThreshold,
} from 'features/parameters/store/generationSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { BiReset } from 'react-icons/bi';
const ParamSeamThreshold = () => {
const dispatch = useAppDispatch();
const seamLowThreshold = useAppSelector(
(state: RootState) => state.generation.seamLowThreshold
);
const seamHighThreshold = useAppSelector(
(state: RootState) => state.generation.seamHighThreshold
);
const { t } = useTranslation();
const handleSeamThresholdChange = useCallback(
(v: number[]) => {
dispatch(setSeamLowThreshold(v[0] as number));
dispatch(setSeamHighThreshold(v[1] as number));
},
[dispatch]
);
const handleSeamThresholdReset = () => {
dispatch(setSeamLowThreshold(100));
dispatch(setSeamHighThreshold(200));
};
return (
<FormControl>
<FormLabel>{t('parameters.seamThreshold')}</FormLabel>
<HStack w="100%" gap={4} mt={-2}>
<RangeSlider
aria-label={[
t('parameters.seamLowThreshold'),
t('parameters.seamHighThreshold'),
]}
value={[seamLowThreshold, seamHighThreshold]}
min={0}
max={255}
step={1}
minStepsBetweenThumbs={1}
onChange={handleSeamThresholdChange}
>
<RangeSliderTrack>
<RangeSliderFilledTrack />
</RangeSliderTrack>
<Tooltip label={seamLowThreshold} placement="top" hasArrow>
<RangeSliderThumb index={0} />
</Tooltip>
<Tooltip label={seamHighThreshold} placement="top" hasArrow>
<RangeSliderThumb index={1} />
</Tooltip>
<RangeSliderMark
value={0}
sx={{
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
}}
>
0
</RangeSliderMark>
<RangeSliderMark
value={0.392}
sx={{
insetInlineStart: '38.4% !important',
transform: 'translateX(-38.4%)',
}}
>
100
</RangeSliderMark>
<RangeSliderMark
value={0.784}
sx={{
insetInlineStart: '79.8% !important',
transform: 'translateX(-79.8%)',
}}
>
200
</RangeSliderMark>
<RangeSliderMark
value={1}
sx={{
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
}}
>
255
</RangeSliderMark>
</RangeSlider>
<IAIIconButton
size="sm"
aria-label={t('accessibility.reset')}
tooltip={t('accessibility.reset')}
icon={<BiReset />}
onClick={handleSeamThresholdReset}
/>
</HStack>
</FormControl>
);
};
export default memo(ParamSeamThreshold);

View File

@ -37,6 +37,12 @@ export interface GenerationState {
scheduler: SchedulerParam;
maskBlur: number;
maskBlurMethod: MaskBlurMethodParam;
seamSize: number;
seamBlur: number;
seamSteps: number;
seamStrength: StrengthParam;
seamLowThreshold: number;
seamHighThreshold: number;
seed: SeedParam;
seedWeights: string;
shouldFitToWidthHeight: boolean;
@ -74,6 +80,12 @@ export const initialGenerationState: GenerationState = {
scheduler: 'euler',
maskBlur: 16,
maskBlurMethod: 'box',
seamSize: 16,
seamBlur: 8,
seamSteps: 20,
seamStrength: 0.7,
seamLowThreshold: 100,
seamHighThreshold: 200,
seed: 0,
seedWeights: '',
shouldFitToWidthHeight: true,
@ -200,6 +212,24 @@ export const generationSlice = createSlice({
setMaskBlurMethod: (state, action: PayloadAction<MaskBlurMethodParam>) => {
state.maskBlurMethod = action.payload;
},
setSeamSize: (state, action: PayloadAction<number>) => {
state.seamSize = action.payload;
},
setSeamBlur: (state, action: PayloadAction<number>) => {
state.seamBlur = action.payload;
},
setSeamSteps: (state, action: PayloadAction<number>) => {
state.seamSteps = action.payload;
},
setSeamStrength: (state, action: PayloadAction<number>) => {
state.seamStrength = action.payload;
},
setSeamLowThreshold: (state, action: PayloadAction<number>) => {
state.seamLowThreshold = action.payload;
},
setSeamHighThreshold: (state, action: PayloadAction<number>) => {
state.seamHighThreshold = action.payload;
},
setTileSize: (state, action: PayloadAction<number>) => {
state.tileSize = action.payload;
},
@ -306,6 +336,12 @@ export const {
setScheduler,
setMaskBlur,
setMaskBlurMethod,
setSeamSize,
setSeamBlur,
setSeamSteps,
setSeamStrength,
setSeamLowThreshold,
setSeamHighThreshold,
setSeed,
setSeedWeights,
setShouldFitToWidthHeight,

View File

@ -2,6 +2,7 @@ import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/Para
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
import ParamInfillAndScalingCollapse from 'features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse';
import ParamMaskAdjustmentCollapse from 'features/parameters/components/Parameters/Canvas/MaskAdjustment/ParamMaskAdjustmentCollapse';
import ParamSeamPaintingCollapse from 'features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamPaintingCollapse';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
@ -22,6 +23,7 @@ export default function SDXLUnifiedCanvasTabParameters() {
<ParamNoiseCollapse />
<ParamMaskAdjustmentCollapse />
<ParamInfillAndScalingCollapse />
<ParamSeamPaintingCollapse />
</>
);
}

View File

@ -6,6 +6,7 @@ import ParamControlNetCollapse from 'features/parameters/components/Parameters/C
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
// import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
import ParamMaskAdjustmentCollapse from 'features/parameters/components/Parameters/Canvas/MaskAdjustment/ParamMaskAdjustmentCollapse';
import ParamSeamPaintingCollapse from 'features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamPaintingCollapse';
import ParamPromptArea from 'features/parameters/components/Parameters/Prompt/ParamPromptArea';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters';
@ -23,6 +24,7 @@ const UnifiedCanvasParameters = () => {
<ParamSymmetryCollapse />
<ParamMaskAdjustmentCollapse />
<ParamInfillAndScalingCollapse />
<ParamSeamPaintingCollapse />
<ParamAdvancedCollapse />
</>
);

View File

@ -3,4 +3,7 @@ import { UIState } from './uiTypes';
/**
* UI slice persist denylist
*/
export const uiPersistDenylist: (keyof UIState)[] = ['shouldShowImageDetails'];
export const uiPersistDenylist: (keyof UIState)[] = [
'shouldShowImageDetails',
'globalContextMenuCloseTrigger',
];

View File

@ -20,6 +20,7 @@ export const initialUIState: UIState = {
shouldShowProgressInViewer: true,
shouldShowEmbeddingPicker: false,
favoriteSchedulers: [],
globalContextMenuCloseTrigger: 0,
};
export const uiSlice = createSlice({
@ -96,6 +97,9 @@ export const uiSlice = createSlice({
toggleEmbeddingPicker: (state) => {
state.shouldShowEmbeddingPicker = !state.shouldShowEmbeddingPicker;
},
contextMenusClosed: (state) => {
state.globalContextMenuCloseTrigger += 1;
},
},
extraReducers(builder) {
builder.addCase(initialImageChanged, (state) => {
@ -122,6 +126,7 @@ export const {
setShouldShowProgressInViewer,
favoriteSchedulersChanged,
toggleEmbeddingPicker,
contextMenusClosed,
} = uiSlice.actions;
export default uiSlice.reducer;

View File

@ -26,4 +26,5 @@ export interface UIState {
shouldShowProgressInViewer: boolean;
shouldShowEmbeddingPicker: boolean;
favoriteSchedulers: SchedulerParam[];
globalContextMenuCloseTrigger: number;
}

View File

@ -573,7 +573,7 @@ export type components = {
file: Blob;
};
/**
* Boolean Collection
* Boolean Primitive Collection
* @description A collection of boolean primitive values
*/
BooleanCollectionInvocation: {
@ -619,7 +619,7 @@ export type components = {
collection?: (boolean)[];
};
/**
* Boolean
* Boolean Primitive
* @description A boolean primitive value
*/
BooleanInvocation: {
@ -1002,7 +1002,7 @@ export type components = {
clip?: components["schemas"]["ClipField"];
};
/**
* Conditioning Collection
* Conditioning Primitive Collection
* @description A collection of conditioning tensor primitive values
*/
ConditioningCollectionInvocation: {
@ -1770,7 +1770,7 @@ export type components = {
field: string;
};
/**
* Float Collection
* Float Primitive Collection
* @description A collection of float primitive values
*/
FloatCollectionInvocation: {
@ -1816,7 +1816,7 @@ export type components = {
collection?: (number)[];
};
/**
* Float
* Float Primitive
* @description A float primitive value
*/
FloatInvocation: {
@ -2161,7 +2161,7 @@ export type components = {
channel?: "A" | "R" | "G" | "B";
};
/**
* Image Collection
* Image Primitive Collection
* @description A collection of image primitive values
*/
ImageCollectionInvocation: {
@ -3113,7 +3113,7 @@ export type components = {
seed?: number;
};
/**
* Integer Collection
* Integer Primitive Collection
* @description A collection of integer primitive values
*/
IntegerCollectionInvocation: {
@ -3159,7 +3159,7 @@ export type components = {
collection?: (number)[];
};
/**
* Integer
* Integer Primitive
* @description An integer primitive value
*/
IntegerInvocation: {
@ -3256,7 +3256,7 @@ export type components = {
item?: unknown;
};
/**
* Latents Collection
* Latents Primitive Collection
* @description A collection of latents tensor primitive values
*/
LatentsCollectionInvocation: {
@ -5786,7 +5786,7 @@ export type components = {
show_easing_plot?: boolean;
};
/**
* String Collection
* String Primitive Collection
* @description A collection of string primitive values
*/
StringCollectionInvocation: {
@ -5832,7 +5832,7 @@ export type components = {
collection?: (string)[];
};
/**
* String
* String Primitive
* @description A string primitive value
*/
StringInvocation: {
@ -6193,24 +6193,6 @@ export type components = {
ui_hidden: boolean;
ui_type?: components["schemas"]["UIType"];
};
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* ControlNetModelFormat
* @description An enumeration.
* @enum {string}
*/
ControlNetModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionXLModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionOnnxModelFormat
* @description An enumeration.
@ -6223,6 +6205,24 @@ export type components = {
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* ControlNetModelFormat
* @description An enumeration.
* @enum {string}
*/
ControlNetModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionXLModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
};
responses: never;
parameters: never;