mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into bug/ip-adapter-calc-size
This commit is contained in:
commit
0fc14afcf0
@ -121,18 +121,6 @@ To be imported, an .obj must use triangulated meshes, so make sure to enable tha
|
|||||||
**Example Usage:**
|
**Example Usage:**
|
||||||

|

|
||||||
|
|
||||||
--------------------------------
|
|
||||||
### Enhance Image (simple adjustments)
|
|
||||||
|
|
||||||
**Description:** Boost or reduce color saturation, contrast, brightness, sharpness, or invert colors of any image at any stage with this simple wrapper for pillow [PIL]'s ImageEnhance module.
|
|
||||||
|
|
||||||
Color inversion is toggled with a simple switch, while each of the four enhancer modes are activated by entering a value other than 1 in each corresponding input field. Values less than 1 will reduce the corresponding property, while values greater than 1 will enhance it.
|
|
||||||
|
|
||||||
**Node Link:** https://github.com/dwringer/image-enhance-node
|
|
||||||
|
|
||||||
**Example Usage:**
|
|
||||||

|
|
||||||
|
|
||||||
--------------------------------
|
--------------------------------
|
||||||
### Generative Grammar-Based Prompt Nodes
|
### Generative Grammar-Based Prompt Nodes
|
||||||
|
|
||||||
@ -153,16 +141,26 @@ This includes 3 Nodes:
|
|||||||
|
|
||||||
**Description:** This is a pack of nodes for composing masks and images, including a simple text mask creator and both image and latent offset nodes. The offsets wrap around, so these can be used in conjunction with the Seamless node to progressively generate centered on different parts of the seamless tiling.
|
**Description:** This is a pack of nodes for composing masks and images, including a simple text mask creator and both image and latent offset nodes. The offsets wrap around, so these can be used in conjunction with the Seamless node to progressively generate centered on different parts of the seamless tiling.
|
||||||
|
|
||||||
This includes 4 Nodes:
|
This includes 14 Nodes:
|
||||||
- *Text Mask (simple 2D)* - create and position a white on black (or black on white) line of text using any font locally available to Invoke.
|
- *Adjust Image Hue Plus* - Rotate the hue of an image in one of several different color spaces.
|
||||||
|
- *Blend Latents/Noise (Masked)* - Use a mask to blend part of one latents tensor [including Noise outputs] into another. Can be used to "renoise" sections during a multi-stage [masked] denoising process.
|
||||||
|
- *Enhance Image* - Boost or reduce color saturation, contrast, brightness, sharpness, or invert colors of any image at any stage with this simple wrapper for pillow [PIL]'s ImageEnhance module.
|
||||||
|
- *Equivalent Achromatic Lightness* - Calculates image lightness accounting for Helmholtz-Kohlrausch effect based on a method described by High, Green, and Nussbaum (2023).
|
||||||
|
- *Text to Mask (Clipseg)* - Input a prompt and an image to generate a mask representing areas of the image matched by the prompt.
|
||||||
|
- *Text to Mask Advanced (Clipseg)* - Output up to four prompt masks combined with logical "and", logical "or", or as separate channels of an RGBA image.
|
||||||
|
- *Image Layer Blend* - Perform a layered blend of two images using alpha compositing. Opacity of top layer is selectable, with optional mask and several different blend modes/color spaces.
|
||||||
- *Image Compositor* - Take a subject from an image with a flat backdrop and layer it on another image using a chroma key or flood select background removal.
|
- *Image Compositor* - Take a subject from an image with a flat backdrop and layer it on another image using a chroma key or flood select background removal.
|
||||||
|
- *Image Dilate or Erode* - Dilate or expand a mask (or any image!). This is equivalent to an expand/contract operation.
|
||||||
|
- *Image Value Thresholds* - Clip an image to pure black/white beyond specified thresholds.
|
||||||
- *Offset Latents* - Offset a latents tensor in the vertical and/or horizontal dimensions, wrapping it around.
|
- *Offset Latents* - Offset a latents tensor in the vertical and/or horizontal dimensions, wrapping it around.
|
||||||
- *Offset Image* - Offset an image in the vertical and/or horizontal dimensions, wrapping it around.
|
- *Offset Image* - Offset an image in the vertical and/or horizontal dimensions, wrapping it around.
|
||||||
|
- *Shadows/Highlights/Midtones* - Extract three masks (with adjustable hard or soft thresholds) representing shadows, midtones, and highlights regions of an image.
|
||||||
|
- *Text Mask (simple 2D)* - create and position a white on black (or black on white) line of text using any font locally available to Invoke.
|
||||||
|
|
||||||
**Node Link:** https://github.com/dwringer/composition-nodes
|
**Node Link:** https://github.com/dwringer/composition-nodes
|
||||||
|
|
||||||
**Example Usage:**
|
**Nodes and Output Examples:**
|
||||||

|

|
||||||
|
|
||||||
--------------------------------
|
--------------------------------
|
||||||
### Size Stepper Nodes
|
### Size Stepper Nodes
|
||||||
|
@ -332,6 +332,7 @@ class InvokeAiInstance:
|
|||||||
Configure the InvokeAI runtime directory
|
Configure the InvokeAI runtime directory
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
auto_install = False
|
||||||
# set sys.argv to a consistent state
|
# set sys.argv to a consistent state
|
||||||
new_argv = [sys.argv[0]]
|
new_argv = [sys.argv[0]]
|
||||||
for i in range(1, len(sys.argv)):
|
for i in range(1, len(sys.argv)):
|
||||||
@ -340,13 +341,17 @@ class InvokeAiInstance:
|
|||||||
new_argv.append(el)
|
new_argv.append(el)
|
||||||
new_argv.append(sys.argv[i + 1])
|
new_argv.append(sys.argv[i + 1])
|
||||||
elif el in ["-y", "--yes", "--yes-to-all"]:
|
elif el in ["-y", "--yes", "--yes-to-all"]:
|
||||||
new_argv.append(el)
|
auto_install = True
|
||||||
sys.argv = new_argv
|
sys.argv = new_argv
|
||||||
|
|
||||||
|
import messages
|
||||||
import requests # to catch download exceptions
|
import requests # to catch download exceptions
|
||||||
from messages import introduction
|
|
||||||
|
|
||||||
introduction()
|
auto_install = auto_install or messages.user_wants_auto_configuration()
|
||||||
|
if auto_install:
|
||||||
|
sys.argv.append("--yes")
|
||||||
|
else:
|
||||||
|
messages.introduction()
|
||||||
|
|
||||||
from invokeai.frontend.install.invokeai_configure import invokeai_configure
|
from invokeai.frontend.install.invokeai_configure import invokeai_configure
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ import os
|
|||||||
import platform
|
import platform
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from prompt_toolkit import prompt
|
from prompt_toolkit import HTML, prompt
|
||||||
from prompt_toolkit.completion import PathCompleter
|
from prompt_toolkit.completion import PathCompleter
|
||||||
from prompt_toolkit.validation import Validator
|
from prompt_toolkit.validation import Validator
|
||||||
from rich import box, print
|
from rich import box, print
|
||||||
@ -65,17 +65,50 @@ def confirm_install(dest: Path) -> bool:
|
|||||||
if dest.exists():
|
if dest.exists():
|
||||||
print(f":exclamation: Directory {dest} already exists :exclamation:")
|
print(f":exclamation: Directory {dest} already exists :exclamation:")
|
||||||
dest_confirmed = Confirm.ask(
|
dest_confirmed = Confirm.ask(
|
||||||
":stop_sign: Are you sure you want to (re)install in this location?",
|
":stop_sign: (re)install in this location?",
|
||||||
default=False,
|
default=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print(f"InvokeAI will be installed in {dest}")
|
print(f"InvokeAI will be installed in {dest}")
|
||||||
dest_confirmed = not Confirm.ask("Would you like to pick a different location?", default=False)
|
dest_confirmed = Confirm.ask("Use this location?", default=True)
|
||||||
console.line()
|
console.line()
|
||||||
|
|
||||||
return dest_confirmed
|
return dest_confirmed
|
||||||
|
|
||||||
|
|
||||||
|
def user_wants_auto_configuration() -> bool:
|
||||||
|
"""Prompt the user to choose between manual and auto configuration."""
|
||||||
|
console.rule("InvokeAI Configuration Section")
|
||||||
|
console.print(
|
||||||
|
Panel(
|
||||||
|
Group(
|
||||||
|
"\n".join(
|
||||||
|
[
|
||||||
|
"Libraries are installed and InvokeAI will now set up its root directory and configuration. Choose between:",
|
||||||
|
"",
|
||||||
|
" * AUTOMATIC configuration: install reasonable defaults and a minimal set of starter models.",
|
||||||
|
" * MANUAL configuration: manually inspect and adjust configuration options and pick from a larger set of starter models.",
|
||||||
|
"",
|
||||||
|
"Later you can fine tune your configuration by selecting option [6] 'Change InvokeAI startup options' from the invoke.bat/invoke.sh launcher script.",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
box=box.MINIMAL,
|
||||||
|
padding=(1, 1),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
choice = (
|
||||||
|
prompt(
|
||||||
|
HTML("Choose <b><a></b>utomatic or <b><m></b>anual configuration [a/m] (a): "),
|
||||||
|
validator=Validator.from_callable(
|
||||||
|
lambda n: n == "" or n.startswith(("a", "A", "m", "M")), error_message="Please select 'a' or 'm'"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
or "a"
|
||||||
|
)
|
||||||
|
return choice.lower().startswith("a")
|
||||||
|
|
||||||
|
|
||||||
def dest_path(dest=None) -> Path:
|
def dest_path(dest=None) -> Path:
|
||||||
"""
|
"""
|
||||||
Prompt the user for the destination path and create the path
|
Prompt the user for the destination path and create the path
|
||||||
|
@ -156,8 +156,6 @@ async def import_model(
|
|||||||
prediction_types = {x.value: x for x in SchedulerPredictionType}
|
prediction_types = {x.value: x for x in SchedulerPredictionType}
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
print(f"DEBUG: prediction_type = {prediction_type}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||||
items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type)
|
items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type)
|
||||||
|
@ -91,6 +91,9 @@ class FieldDescriptions:
|
|||||||
board = "The board to save the image to"
|
board = "The board to save the image to"
|
||||||
image = "The image to process"
|
image = "The image to process"
|
||||||
tile_size = "Tile size"
|
tile_size = "Tile size"
|
||||||
|
inclusive_low = "The inclusive low value"
|
||||||
|
exclusive_high = "The exclusive high value"
|
||||||
|
decimal_places = "The number of decimal places to round to"
|
||||||
|
|
||||||
|
|
||||||
class Input(str, Enum):
|
class Input(str, Enum):
|
||||||
|
@ -65,13 +65,27 @@ class DivideInvocation(BaseInvocation):
|
|||||||
class RandomIntInvocation(BaseInvocation):
|
class RandomIntInvocation(BaseInvocation):
|
||||||
"""Outputs a single random integer."""
|
"""Outputs a single random integer."""
|
||||||
|
|
||||||
low: int = InputField(default=0, description="The inclusive low value")
|
low: int = InputField(default=0, description=FieldDescriptions.inclusive_low)
|
||||||
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
high: int = InputField(default=np.iinfo(np.int32).max, description=FieldDescriptions.exclusive_high)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||||
return IntegerOutput(value=np.random.randint(self.low, self.high))
|
return IntegerOutput(value=np.random.randint(self.low, self.high))
|
||||||
|
|
||||||
|
|
||||||
|
@invocation("rand_float", title="Random Float", tags=["math", "float", "random"], category="math", version="1.0.0")
|
||||||
|
class RandomFloatInvocation(BaseInvocation):
|
||||||
|
"""Outputs a single random float"""
|
||||||
|
|
||||||
|
low: float = InputField(default=0.0, description=FieldDescriptions.inclusive_low)
|
||||||
|
high: float = InputField(default=1.0, description=FieldDescriptions.exclusive_high)
|
||||||
|
decimals: int = InputField(default=2, description=FieldDescriptions.decimal_places)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||||
|
random_float = np.random.uniform(self.low, self.high)
|
||||||
|
rounded_float = round(random_float, self.decimals)
|
||||||
|
return FloatOutput(value=rounded_float)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"float_to_int",
|
"float_to_int",
|
||||||
title="Float To Integer",
|
title="Float To Integer",
|
||||||
|
@ -241,7 +241,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
|
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
|
||||||
|
|
||||||
# CACHE
|
# CACHE
|
||||||
ram : Union[float, Literal["auto"]] = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", )
|
ram : Union[float, Literal["auto"]] = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", )
|
||||||
vram : Union[float, Literal["auto"]] = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number or 'auto')", category="Model Cache", )
|
vram : Union[float, Literal["auto"]] = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number or 'auto')", category="Model Cache", )
|
||||||
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", )
|
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", )
|
||||||
|
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
from queue import Queue
|
from collections import OrderedDict
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from threading import Lock
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||||
@ -7,22 +9,28 @@ from invokeai.app.services.invocation_cache.invocation_cache_common import Invoc
|
|||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(order=True)
|
||||||
|
class CachedItem:
|
||||||
|
invocation_output: BaseInvocationOutput = field(compare=False)
|
||||||
|
invocation_output_json: str = field(compare=False)
|
||||||
|
|
||||||
|
|
||||||
class MemoryInvocationCache(InvocationCacheBase):
|
class MemoryInvocationCache(InvocationCacheBase):
|
||||||
_cache: dict[Union[int, str], tuple[BaseInvocationOutput, str]]
|
_cache: OrderedDict[Union[int, str], CachedItem]
|
||||||
_max_cache_size: int
|
_max_cache_size: int
|
||||||
_disabled: bool
|
_disabled: bool
|
||||||
_hits: int
|
_hits: int
|
||||||
_misses: int
|
_misses: int
|
||||||
_cache_ids: Queue
|
|
||||||
_invoker: Invoker
|
_invoker: Invoker
|
||||||
|
_lock: Lock
|
||||||
|
|
||||||
def __init__(self, max_cache_size: int = 0) -> None:
|
def __init__(self, max_cache_size: int = 0) -> None:
|
||||||
self._cache = dict()
|
self._cache = OrderedDict()
|
||||||
self._max_cache_size = max_cache_size
|
self._max_cache_size = max_cache_size
|
||||||
self._disabled = False
|
self._disabled = False
|
||||||
self._hits = 0
|
self._hits = 0
|
||||||
self._misses = 0
|
self._misses = 0
|
||||||
self._cache_ids = Queue()
|
self._lock = Lock()
|
||||||
|
|
||||||
def start(self, invoker: Invoker) -> None:
|
def start(self, invoker: Invoker) -> None:
|
||||||
self._invoker = invoker
|
self._invoker = invoker
|
||||||
@ -32,80 +40,87 @@ class MemoryInvocationCache(InvocationCacheBase):
|
|||||||
self._invoker.services.latents.on_deleted(self._delete_by_match)
|
self._invoker.services.latents.on_deleted(self._delete_by_match)
|
||||||
|
|
||||||
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
|
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
|
||||||
if self._max_cache_size == 0 or self._disabled:
|
with self._lock:
|
||||||
return
|
if self._max_cache_size == 0 or self._disabled:
|
||||||
|
return None
|
||||||
item = self._cache.get(key, None)
|
item = self._cache.get(key, None)
|
||||||
if item is not None:
|
if item is not None:
|
||||||
self._hits += 1
|
self._hits += 1
|
||||||
return item[0]
|
self._cache.move_to_end(key)
|
||||||
self._misses += 1
|
return item.invocation_output
|
||||||
|
self._misses += 1
|
||||||
|
return None
|
||||||
|
|
||||||
def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
|
def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
|
||||||
if self._max_cache_size == 0 or self._disabled:
|
with self._lock:
|
||||||
return
|
if self._max_cache_size == 0 or self._disabled or key in self._cache:
|
||||||
|
return
|
||||||
|
# If the cache is full, we need to remove the least used
|
||||||
|
number_to_delete = len(self._cache) + 1 - self._max_cache_size
|
||||||
|
self._delete_oldest_access(number_to_delete)
|
||||||
|
self._cache[key] = CachedItem(invocation_output, invocation_output.json())
|
||||||
|
|
||||||
if key not in self._cache:
|
def _delete_oldest_access(self, number_to_delete: int) -> None:
|
||||||
self._cache[key] = (invocation_output, invocation_output.json())
|
number_to_delete = min(number_to_delete, len(self._cache))
|
||||||
self._cache_ids.put(key)
|
for _ in range(number_to_delete):
|
||||||
if self._cache_ids.qsize() > self._max_cache_size:
|
self._cache.popitem(last=False)
|
||||||
try:
|
|
||||||
self._cache.pop(self._cache_ids.get())
|
|
||||||
except KeyError:
|
|
||||||
# this means the cache_ids are somehow out of sync w/ the cache
|
|
||||||
pass
|
|
||||||
|
|
||||||
def delete(self, key: Union[int, str]) -> None:
|
def _delete(self, key: Union[int, str]) -> None:
|
||||||
if self._max_cache_size == 0:
|
if self._max_cache_size == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
if key in self._cache:
|
if key in self._cache:
|
||||||
del self._cache[key]
|
del self._cache[key]
|
||||||
|
|
||||||
|
def delete(self, key: Union[int, str]) -> None:
|
||||||
|
with self._lock:
|
||||||
|
return self._delete(key)
|
||||||
|
|
||||||
def clear(self, *args, **kwargs) -> None:
|
def clear(self, *args, **kwargs) -> None:
|
||||||
if self._max_cache_size == 0:
|
with self._lock:
|
||||||
return
|
if self._max_cache_size == 0:
|
||||||
|
return
|
||||||
|
self._cache.clear()
|
||||||
|
self._misses = 0
|
||||||
|
self._hits = 0
|
||||||
|
|
||||||
self._cache.clear()
|
@staticmethod
|
||||||
self._cache_ids = Queue()
|
def create_key(invocation: BaseInvocation) -> int:
|
||||||
self._misses = 0
|
|
||||||
self._hits = 0
|
|
||||||
|
|
||||||
def create_key(self, invocation: BaseInvocation) -> int:
|
|
||||||
return hash(invocation.json(exclude={"id"}))
|
return hash(invocation.json(exclude={"id"}))
|
||||||
|
|
||||||
def disable(self) -> None:
|
def disable(self) -> None:
|
||||||
if self._max_cache_size == 0:
|
with self._lock:
|
||||||
return
|
if self._max_cache_size == 0:
|
||||||
self._disabled = True
|
return
|
||||||
|
self._disabled = True
|
||||||
|
|
||||||
def enable(self) -> None:
|
def enable(self) -> None:
|
||||||
if self._max_cache_size == 0:
|
with self._lock:
|
||||||
return
|
if self._max_cache_size == 0:
|
||||||
self._disabled = False
|
return
|
||||||
|
self._disabled = False
|
||||||
|
|
||||||
def get_status(self) -> InvocationCacheStatus:
|
def get_status(self) -> InvocationCacheStatus:
|
||||||
return InvocationCacheStatus(
|
with self._lock:
|
||||||
hits=self._hits,
|
return InvocationCacheStatus(
|
||||||
misses=self._misses,
|
hits=self._hits,
|
||||||
enabled=not self._disabled and self._max_cache_size > 0,
|
misses=self._misses,
|
||||||
size=len(self._cache),
|
enabled=not self._disabled and self._max_cache_size > 0,
|
||||||
max_size=self._max_cache_size,
|
size=len(self._cache),
|
||||||
)
|
max_size=self._max_cache_size,
|
||||||
|
)
|
||||||
|
|
||||||
def _delete_by_match(self, to_match: str) -> None:
|
def _delete_by_match(self, to_match: str) -> None:
|
||||||
if self._max_cache_size == 0:
|
with self._lock:
|
||||||
return
|
if self._max_cache_size == 0:
|
||||||
|
return
|
||||||
keys_to_delete = set()
|
keys_to_delete = set()
|
||||||
for key, value_tuple in self._cache.items():
|
for key, cached_item in self._cache.items():
|
||||||
if to_match in value_tuple[1]:
|
if to_match in cached_item.invocation_output_json:
|
||||||
keys_to_delete.add(key)
|
keys_to_delete.add(key)
|
||||||
|
if not keys_to_delete:
|
||||||
if not keys_to_delete:
|
return
|
||||||
return
|
for key in keys_to_delete:
|
||||||
|
self._delete(key)
|
||||||
for key in keys_to_delete:
|
self._invoker.services.logger.debug(
|
||||||
self.delete(key)
|
f"Deleted {len(keys_to_delete)} cached invocation outputs for {to_match}"
|
||||||
|
)
|
||||||
self._invoker.services.logger.debug(f"Deleted {len(keys_to_delete)} cached invocation outputs for {to_match}")
|
|
||||||
|
@ -70,7 +70,6 @@ def get_literal_fields(field) -> list[Any]:
|
|||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
Model_dir = "models"
|
Model_dir = "models"
|
||||||
|
|
||||||
Default_config_file = config.model_conf_path
|
Default_config_file = config.model_conf_path
|
||||||
SD_Configs = config.legacy_conf_path
|
SD_Configs = config.legacy_conf_path
|
||||||
|
|
||||||
@ -458,7 +457,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
|||||||
)
|
)
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
npyscreen.TitleFixedText,
|
npyscreen.TitleFixedText,
|
||||||
name="Model RAM cache size (GB). Make this at least large enough to hold a single full model.",
|
name="Model RAM cache size (GB). Make this at least large enough to hold a single full model (2GB for SD-1, 6GB for SDXL).",
|
||||||
begin_entry_at=0,
|
begin_entry_at=0,
|
||||||
editable=False,
|
editable=False,
|
||||||
color="CONTROL",
|
color="CONTROL",
|
||||||
@ -651,8 +650,19 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
|
|||||||
return editApp.new_opts()
|
return editApp.new_opts()
|
||||||
|
|
||||||
|
|
||||||
|
def default_ramcache() -> float:
|
||||||
|
"""Run a heuristic for the default RAM cache based on installed RAM."""
|
||||||
|
|
||||||
|
# Note that on my 64 GB machine, psutil.virtual_memory().total gives 62 GB,
|
||||||
|
# So we adjust everthing down a bit.
|
||||||
|
return (
|
||||||
|
15.0 if MAX_RAM >= 60 else 7.5 if MAX_RAM >= 30 else 4 if MAX_RAM >= 14 else 2.1
|
||||||
|
) # 2.1 is just large enough for sd 1.5 ;-)
|
||||||
|
|
||||||
|
|
||||||
def default_startup_options(init_file: Path) -> Namespace:
|
def default_startup_options(init_file: Path) -> Namespace:
|
||||||
opts = InvokeAIAppConfig.get_config()
|
opts = InvokeAIAppConfig.get_config()
|
||||||
|
opts.ram = default_ramcache()
|
||||||
return opts
|
return opts
|
||||||
|
|
||||||
|
|
||||||
|
@ -58,6 +58,7 @@
|
|||||||
"githubLabel": "Github",
|
"githubLabel": "Github",
|
||||||
"hotkeysLabel": "Hotkeys",
|
"hotkeysLabel": "Hotkeys",
|
||||||
"imagePrompt": "Image Prompt",
|
"imagePrompt": "Image Prompt",
|
||||||
|
"imageFailedToLoad": "Unable to Load Image",
|
||||||
"img2img": "Image To Image",
|
"img2img": "Image To Image",
|
||||||
"langArabic": "العربية",
|
"langArabic": "العربية",
|
||||||
"langBrPortuguese": "Português do Brasil",
|
"langBrPortuguese": "Português do Brasil",
|
||||||
@ -79,7 +80,7 @@
|
|||||||
"lightMode": "Light Mode",
|
"lightMode": "Light Mode",
|
||||||
"linear": "Linear",
|
"linear": "Linear",
|
||||||
"load": "Load",
|
"load": "Load",
|
||||||
"loading": "Loading",
|
"loading": "Loading $t({{noun}})...",
|
||||||
"loadingInvokeAI": "Loading Invoke AI",
|
"loadingInvokeAI": "Loading Invoke AI",
|
||||||
"learnMore": "Learn More",
|
"learnMore": "Learn More",
|
||||||
"modelManager": "Model Manager",
|
"modelManager": "Model Manager",
|
||||||
@ -716,6 +717,7 @@
|
|||||||
"cannotConnectInputToInput": "Cannot connect input to input",
|
"cannotConnectInputToInput": "Cannot connect input to input",
|
||||||
"cannotConnectOutputToOutput": "Cannot connect output to output",
|
"cannotConnectOutputToOutput": "Cannot connect output to output",
|
||||||
"cannotConnectToSelf": "Cannot connect to self",
|
"cannotConnectToSelf": "Cannot connect to self",
|
||||||
|
"cannotDuplicateConnection": "Cannot create duplicate connections",
|
||||||
"clipField": "Clip",
|
"clipField": "Clip",
|
||||||
"clipFieldDescription": "Tokenizer and text_encoder submodels.",
|
"clipFieldDescription": "Tokenizer and text_encoder submodels.",
|
||||||
"collection": "Collection",
|
"collection": "Collection",
|
||||||
@ -1442,6 +1444,8 @@
|
|||||||
"showCanvasDebugInfo": "Show Additional Canvas Info",
|
"showCanvasDebugInfo": "Show Additional Canvas Info",
|
||||||
"showGrid": "Show Grid",
|
"showGrid": "Show Grid",
|
||||||
"showHide": "Show/Hide",
|
"showHide": "Show/Hide",
|
||||||
|
"showResultsOn": "Show Results (On)",
|
||||||
|
"showResultsOff": "Show Results (Off)",
|
||||||
"showIntermediates": "Show Intermediates",
|
"showIntermediates": "Show Intermediates",
|
||||||
"snapToGrid": "Snap to Grid",
|
"snapToGrid": "Snap to Grid",
|
||||||
"undo": "Undo"
|
"undo": "Undo"
|
||||||
|
@ -5,7 +5,7 @@ import {
|
|||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { ReactNode, memo, useEffect, useMemo } from 'react';
|
import { ReactNode, memo, useEffect, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { theme as invokeAITheme } from 'theme/theme';
|
import { TOAST_OPTIONS, theme as invokeAITheme } from 'theme/theme';
|
||||||
|
|
||||||
import '@fontsource-variable/inter';
|
import '@fontsource-variable/inter';
|
||||||
import { MantineProvider } from '@mantine/core';
|
import { MantineProvider } from '@mantine/core';
|
||||||
@ -39,7 +39,11 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<MantineProvider theme={mantineTheme}>
|
<MantineProvider theme={mantineTheme}>
|
||||||
<ChakraProvider theme={theme} colorModeManager={manager}>
|
<ChakraProvider
|
||||||
|
theme={theme}
|
||||||
|
colorModeManager={manager}
|
||||||
|
toastOptions={TOAST_OPTIONS}
|
||||||
|
>
|
||||||
{children}
|
{children}
|
||||||
</ChakraProvider>
|
</ChakraProvider>
|
||||||
</MantineProvider>
|
</MantineProvider>
|
||||||
|
@ -54,21 +54,6 @@ import { addModelSelectedListener } from './listeners/modelSelected';
|
|||||||
import { addModelsLoadedListener } from './listeners/modelsLoaded';
|
import { addModelsLoadedListener } from './listeners/modelsLoaded';
|
||||||
import { addDynamicPromptsListener } from './listeners/promptChanged';
|
import { addDynamicPromptsListener } from './listeners/promptChanged';
|
||||||
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
|
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
|
||||||
import {
|
|
||||||
addSessionCanceledFulfilledListener,
|
|
||||||
addSessionCanceledPendingListener,
|
|
||||||
addSessionCanceledRejectedListener,
|
|
||||||
} from './listeners/sessionCanceled';
|
|
||||||
import {
|
|
||||||
addSessionCreatedFulfilledListener,
|
|
||||||
addSessionCreatedPendingListener,
|
|
||||||
addSessionCreatedRejectedListener,
|
|
||||||
} from './listeners/sessionCreated';
|
|
||||||
import {
|
|
||||||
addSessionInvokedFulfilledListener,
|
|
||||||
addSessionInvokedPendingListener,
|
|
||||||
addSessionInvokedRejectedListener,
|
|
||||||
} from './listeners/sessionInvoked';
|
|
||||||
import { addSocketConnectedEventListener as addSocketConnectedListener } from './listeners/socketio/socketConnected';
|
import { addSocketConnectedEventListener as addSocketConnectedListener } from './listeners/socketio/socketConnected';
|
||||||
import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected';
|
import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected';
|
||||||
import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress';
|
import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress';
|
||||||
@ -86,6 +71,7 @@ import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSa
|
|||||||
import { addTabChangedListener } from './listeners/tabChanged';
|
import { addTabChangedListener } from './listeners/tabChanged';
|
||||||
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
|
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
|
||||||
import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
|
import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
|
||||||
|
import { addBatchEnqueuedListener } from './listeners/batchEnqueued';
|
||||||
|
|
||||||
export const listenerMiddleware = createListenerMiddleware();
|
export const listenerMiddleware = createListenerMiddleware();
|
||||||
|
|
||||||
@ -136,6 +122,7 @@ addEnqueueRequestedCanvasListener();
|
|||||||
addEnqueueRequestedNodes();
|
addEnqueueRequestedNodes();
|
||||||
addEnqueueRequestedLinear();
|
addEnqueueRequestedLinear();
|
||||||
addAnyEnqueuedListener();
|
addAnyEnqueuedListener();
|
||||||
|
addBatchEnqueuedListener();
|
||||||
|
|
||||||
// Canvas actions
|
// Canvas actions
|
||||||
addCanvasSavedToGalleryListener();
|
addCanvasSavedToGalleryListener();
|
||||||
@ -175,21 +162,6 @@ addSessionRetrievalErrorEventListener();
|
|||||||
addInvocationRetrievalErrorEventListener();
|
addInvocationRetrievalErrorEventListener();
|
||||||
addSocketQueueItemStatusChangedEventListener();
|
addSocketQueueItemStatusChangedEventListener();
|
||||||
|
|
||||||
// Session Created
|
|
||||||
addSessionCreatedPendingListener();
|
|
||||||
addSessionCreatedFulfilledListener();
|
|
||||||
addSessionCreatedRejectedListener();
|
|
||||||
|
|
||||||
// Session Invoked
|
|
||||||
addSessionInvokedPendingListener();
|
|
||||||
addSessionInvokedFulfilledListener();
|
|
||||||
addSessionInvokedRejectedListener();
|
|
||||||
|
|
||||||
// Session Canceled
|
|
||||||
addSessionCanceledPendingListener();
|
|
||||||
addSessionCanceledFulfilledListener();
|
|
||||||
addSessionCanceledRejectedListener();
|
|
||||||
|
|
||||||
// ControlNet
|
// ControlNet
|
||||||
addControlNetImageProcessedListener();
|
addControlNetImageProcessedListener();
|
||||||
addControlNetAutoProcessListener();
|
addControlNetAutoProcessListener();
|
||||||
|
@ -0,0 +1,96 @@
|
|||||||
|
import { createStandaloneToast } from '@chakra-ui/react';
|
||||||
|
import { logger } from 'app/logging/logger';
|
||||||
|
import { parseify } from 'common/util/serialize';
|
||||||
|
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
|
||||||
|
import { t } from 'i18next';
|
||||||
|
import { get, truncate, upperFirst } from 'lodash-es';
|
||||||
|
import { queueApi } from 'services/api/endpoints/queue';
|
||||||
|
import { TOAST_OPTIONS, theme } from 'theme/theme';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
|
const { toast } = createStandaloneToast({
|
||||||
|
theme: theme,
|
||||||
|
defaultOptions: TOAST_OPTIONS.defaultOptions,
|
||||||
|
});
|
||||||
|
|
||||||
|
export const addBatchEnqueuedListener = () => {
|
||||||
|
// success
|
||||||
|
startAppListening({
|
||||||
|
matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
|
||||||
|
effect: async (action) => {
|
||||||
|
const response = action.payload;
|
||||||
|
const arg = action.meta.arg.originalArgs;
|
||||||
|
logger('queue').debug(
|
||||||
|
{ enqueueResult: parseify(response) },
|
||||||
|
'Batch enqueued'
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!toast.isActive('batch-queued')) {
|
||||||
|
toast({
|
||||||
|
id: 'batch-queued',
|
||||||
|
title: t('queue.batchQueued'),
|
||||||
|
description: t('queue.batchQueuedDesc', {
|
||||||
|
item_count: response.enqueued,
|
||||||
|
direction: arg.prepend ? t('queue.front') : t('queue.back'),
|
||||||
|
}),
|
||||||
|
duration: 1000,
|
||||||
|
status: 'success',
|
||||||
|
});
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// error
|
||||||
|
startAppListening({
|
||||||
|
matcher: queueApi.endpoints.enqueueBatch.matchRejected,
|
||||||
|
effect: async (action) => {
|
||||||
|
const response = action.payload;
|
||||||
|
const arg = action.meta.arg.originalArgs;
|
||||||
|
|
||||||
|
if (!response) {
|
||||||
|
toast({
|
||||||
|
title: t('queue.batchFailedToQueue'),
|
||||||
|
status: 'error',
|
||||||
|
description: 'Unknown Error',
|
||||||
|
});
|
||||||
|
logger('queue').error(
|
||||||
|
{ batchConfig: parseify(arg), error: parseify(response) },
|
||||||
|
t('queue.batchFailedToQueue')
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = zPydanticValidationError.safeParse(response);
|
||||||
|
if (result.success) {
|
||||||
|
result.data.data.detail.map((e) => {
|
||||||
|
toast({
|
||||||
|
id: 'batch-failed-to-queue',
|
||||||
|
title: truncate(upperFirst(e.msg), { length: 128 }),
|
||||||
|
status: 'error',
|
||||||
|
description: truncate(
|
||||||
|
`Path:
|
||||||
|
${e.loc.join('.')}`,
|
||||||
|
{ length: 128 }
|
||||||
|
),
|
||||||
|
});
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
let detail = 'Unknown Error';
|
||||||
|
if (response.status === 403 && 'body' in response) {
|
||||||
|
detail = get(response, 'body.detail', 'Unknown Error');
|
||||||
|
} else if (response.status === 403 && 'error' in response) {
|
||||||
|
detail = get(response, 'error.detail', 'Unknown Error');
|
||||||
|
}
|
||||||
|
toast({
|
||||||
|
title: t('queue.batchFailedToQueue'),
|
||||||
|
status: 'error',
|
||||||
|
description: detail,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
logger('queue').error(
|
||||||
|
{ batchConfig: parseify(arg), error: parseify(response) },
|
||||||
|
t('queue.batchFailedToQueue')
|
||||||
|
);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -25,7 +25,7 @@ export const addBoardIdSelectedListener = () => {
|
|||||||
const state = getState();
|
const state = getState();
|
||||||
|
|
||||||
const board_id = boardIdSelected.match(action)
|
const board_id = boardIdSelected.match(action)
|
||||||
? action.payload
|
? action.payload.boardId
|
||||||
: state.gallery.selectedBoardId;
|
: state.gallery.selectedBoardId;
|
||||||
|
|
||||||
const galleryView = galleryViewChanged.match(action)
|
const galleryView = galleryViewChanged.match(action)
|
||||||
@ -55,7 +55,12 @@ export const addBoardIdSelectedListener = () => {
|
|||||||
|
|
||||||
if (boardImagesData) {
|
if (boardImagesData) {
|
||||||
const firstImage = imagesSelectors.selectAll(boardImagesData)[0];
|
const firstImage = imagesSelectors.selectAll(boardImagesData)[0];
|
||||||
dispatch(imageSelected(firstImage ?? null));
|
const selectedImage = imagesSelectors.selectById(
|
||||||
|
boardImagesData,
|
||||||
|
action.payload.selectedImageName
|
||||||
|
);
|
||||||
|
|
||||||
|
dispatch(imageSelected(selectedImage || firstImage || null));
|
||||||
} else {
|
} else {
|
||||||
// board has no images - deselect
|
// board has no images - deselect
|
||||||
dispatch(imageSelected(null));
|
dispatch(imageSelected(null));
|
||||||
|
@ -3,9 +3,9 @@ import { canvasImageToControlNet } from 'features/canvas/store/actions';
|
|||||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
||||||
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { t } from 'i18next';
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
import { t } from 'i18next';
|
|
||||||
|
|
||||||
export const addCanvasImageToControlNetListener = () => {
|
export const addCanvasImageToControlNetListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
@ -16,7 +16,7 @@ export const addCanvasImageToControlNetListener = () => {
|
|||||||
|
|
||||||
let blob;
|
let blob;
|
||||||
try {
|
try {
|
||||||
blob = await getBaseLayerBlob(state);
|
blob = await getBaseLayerBlob(state, true);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
log.error(String(err));
|
log.error(String(err));
|
||||||
dispatch(
|
dispatch(
|
||||||
@ -36,10 +36,10 @@ export const addCanvasImageToControlNetListener = () => {
|
|||||||
file: new File([blob], 'savedCanvas.png', {
|
file: new File([blob], 'savedCanvas.png', {
|
||||||
type: 'image/png',
|
type: 'image/png',
|
||||||
}),
|
}),
|
||||||
image_category: 'mask',
|
image_category: 'control',
|
||||||
is_intermediate: false,
|
is_intermediate: false,
|
||||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||||
crop_visible: true,
|
crop_visible: false,
|
||||||
postUploadAction: {
|
postUploadAction: {
|
||||||
type: 'TOAST',
|
type: 'TOAST',
|
||||||
toastOptions: { title: t('toast.canvasSentControlnetAssets') },
|
toastOptions: { title: t('toast.canvasSentControlnetAssets') },
|
||||||
|
@ -3,9 +3,9 @@ import { canvasMaskToControlNet } from 'features/canvas/store/actions';
|
|||||||
import { getCanvasData } from 'features/canvas/util/getCanvasData';
|
import { getCanvasData } from 'features/canvas/util/getCanvasData';
|
||||||
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { t } from 'i18next';
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
import { t } from 'i18next';
|
|
||||||
|
|
||||||
export const addCanvasMaskToControlNetListener = () => {
|
export const addCanvasMaskToControlNetListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
@ -50,7 +50,7 @@ export const addCanvasMaskToControlNetListener = () => {
|
|||||||
image_category: 'mask',
|
image_category: 'mask',
|
||||||
is_intermediate: false,
|
is_intermediate: false,
|
||||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||||
crop_visible: true,
|
crop_visible: false,
|
||||||
postUploadAction: {
|
postUploadAction: {
|
||||||
type: 'TOAST',
|
type: 'TOAST',
|
||||||
toastOptions: { title: t('toast.maskSentControlnetAssets') },
|
toastOptions: { title: t('toast.maskSentControlnetAssets') },
|
||||||
|
@ -12,8 +12,6 @@ import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGeneratio
|
|||||||
import { canvasGraphBuilt } from 'features/nodes/store/actions';
|
import { canvasGraphBuilt } from 'features/nodes/store/actions';
|
||||||
import { buildCanvasGraph } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
|
import { buildCanvasGraph } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
|
||||||
import { prepareLinearUIBatch } from 'features/nodes/util/graphBuilders/buildLinearBatchConfig';
|
import { prepareLinearUIBatch } from 'features/nodes/util/graphBuilders/buildLinearBatchConfig';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
|
||||||
import { t } from 'i18next';
|
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
import { queueApi } from 'services/api/endpoints/queue';
|
import { queueApi } from 'services/api/endpoints/queue';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
@ -140,8 +138,6 @@ export const addEnqueueRequestedCanvasListener = () => {
|
|||||||
const enqueueResult = await req.unwrap();
|
const enqueueResult = await req.unwrap();
|
||||||
req.reset();
|
req.reset();
|
||||||
|
|
||||||
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
|
|
||||||
|
|
||||||
const batchId = enqueueResult.batch.batch_id as string; // we know the is a string, backend provides it
|
const batchId = enqueueResult.batch.batch_id as string; // we know the is a string, backend provides it
|
||||||
|
|
||||||
// Prep the canvas staging area if it is not yet initialized
|
// Prep the canvas staging area if it is not yet initialized
|
||||||
@ -158,28 +154,8 @@ export const addEnqueueRequestedCanvasListener = () => {
|
|||||||
|
|
||||||
// Associate the session with the canvas session ID
|
// Associate the session with the canvas session ID
|
||||||
dispatch(canvasBatchIdAdded(batchId));
|
dispatch(canvasBatchIdAdded(batchId));
|
||||||
|
|
||||||
dispatch(
|
|
||||||
addToast({
|
|
||||||
title: t('queue.batchQueued'),
|
|
||||||
description: t('queue.batchQueuedDesc', {
|
|
||||||
item_count: enqueueResult.enqueued,
|
|
||||||
direction: prepend ? t('queue.front') : t('queue.back'),
|
|
||||||
}),
|
|
||||||
status: 'success',
|
|
||||||
})
|
|
||||||
);
|
|
||||||
} catch {
|
} catch {
|
||||||
log.error(
|
// no-op
|
||||||
{ batchConfig: parseify(batchConfig) },
|
|
||||||
t('queue.batchFailedToQueue')
|
|
||||||
);
|
|
||||||
dispatch(
|
|
||||||
addToast({
|
|
||||||
title: t('queue.batchFailedToQueue'),
|
|
||||||
status: 'error',
|
|
||||||
})
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
@ -1,13 +1,9 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
|
||||||
import { enqueueRequested } from 'app/store/actions';
|
import { enqueueRequested } from 'app/store/actions';
|
||||||
import { parseify } from 'common/util/serialize';
|
|
||||||
import { prepareLinearUIBatch } from 'features/nodes/util/graphBuilders/buildLinearBatchConfig';
|
import { prepareLinearUIBatch } from 'features/nodes/util/graphBuilders/buildLinearBatchConfig';
|
||||||
import { buildLinearImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearImageToImageGraph';
|
import { buildLinearImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearImageToImageGraph';
|
||||||
import { buildLinearSDXLImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph';
|
import { buildLinearSDXLImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph';
|
||||||
import { buildLinearSDXLTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph';
|
import { buildLinearSDXLTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph';
|
||||||
import { buildLinearTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearTextToImageGraph';
|
import { buildLinearTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearTextToImageGraph';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
|
||||||
import { t } from 'i18next';
|
|
||||||
import { queueApi } from 'services/api/endpoints/queue';
|
import { queueApi } from 'services/api/endpoints/queue';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
@ -18,7 +14,6 @@ export const addEnqueueRequestedLinear = () => {
|
|||||||
(action.payload.tabName === 'txt2img' ||
|
(action.payload.tabName === 'txt2img' ||
|
||||||
action.payload.tabName === 'img2img'),
|
action.payload.tabName === 'img2img'),
|
||||||
effect: async (action, { getState, dispatch }) => {
|
effect: async (action, { getState, dispatch }) => {
|
||||||
const log = logger('queue');
|
|
||||||
const state = getState();
|
const state = getState();
|
||||||
const model = state.generation.model;
|
const model = state.generation.model;
|
||||||
const { prepend } = action.payload;
|
const { prepend } = action.payload;
|
||||||
@ -41,38 +36,12 @@ export const addEnqueueRequestedLinear = () => {
|
|||||||
|
|
||||||
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
|
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
|
||||||
|
|
||||||
try {
|
const req = dispatch(
|
||||||
const req = dispatch(
|
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
fixedCacheKey: 'enqueueBatch',
|
||||||
fixedCacheKey: 'enqueueBatch',
|
})
|
||||||
})
|
);
|
||||||
);
|
req.reset();
|
||||||
const enqueueResult = await req.unwrap();
|
|
||||||
req.reset();
|
|
||||||
|
|
||||||
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
|
|
||||||
dispatch(
|
|
||||||
addToast({
|
|
||||||
title: t('queue.batchQueued'),
|
|
||||||
description: t('queue.batchQueuedDesc', {
|
|
||||||
item_count: enqueueResult.enqueued,
|
|
||||||
direction: prepend ? t('queue.front') : t('queue.back'),
|
|
||||||
}),
|
|
||||||
status: 'success',
|
|
||||||
})
|
|
||||||
);
|
|
||||||
} catch {
|
|
||||||
log.error(
|
|
||||||
{ batchConfig: parseify(batchConfig) },
|
|
||||||
t('queue.batchFailedToQueue')
|
|
||||||
);
|
|
||||||
dispatch(
|
|
||||||
addToast({
|
|
||||||
title: t('queue.batchFailedToQueue'),
|
|
||||||
status: 'error',
|
|
||||||
})
|
|
||||||
);
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -1,9 +1,5 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
|
||||||
import { enqueueRequested } from 'app/store/actions';
|
import { enqueueRequested } from 'app/store/actions';
|
||||||
import { parseify } from 'common/util/serialize';
|
|
||||||
import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGraph';
|
import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGraph';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
|
||||||
import { t } from 'i18next';
|
|
||||||
import { queueApi } from 'services/api/endpoints/queue';
|
import { queueApi } from 'services/api/endpoints/queue';
|
||||||
import { BatchConfig } from 'services/api/types';
|
import { BatchConfig } from 'services/api/types';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
@ -13,9 +9,7 @@ export const addEnqueueRequestedNodes = () => {
|
|||||||
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
|
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
|
||||||
enqueueRequested.match(action) && action.payload.tabName === 'nodes',
|
enqueueRequested.match(action) && action.payload.tabName === 'nodes',
|
||||||
effect: async (action, { getState, dispatch }) => {
|
effect: async (action, { getState, dispatch }) => {
|
||||||
const log = logger('queue');
|
|
||||||
const state = getState();
|
const state = getState();
|
||||||
const { prepend } = action.payload;
|
|
||||||
const graph = buildNodesGraph(state.nodes);
|
const graph = buildNodesGraph(state.nodes);
|
||||||
const batchConfig: BatchConfig = {
|
const batchConfig: BatchConfig = {
|
||||||
batch: {
|
batch: {
|
||||||
@ -25,38 +19,12 @@ export const addEnqueueRequestedNodes = () => {
|
|||||||
prepend: action.payload.prepend,
|
prepend: action.payload.prepend,
|
||||||
};
|
};
|
||||||
|
|
||||||
try {
|
const req = dispatch(
|
||||||
const req = dispatch(
|
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
fixedCacheKey: 'enqueueBatch',
|
||||||
fixedCacheKey: 'enqueueBatch',
|
})
|
||||||
})
|
);
|
||||||
);
|
req.reset();
|
||||||
const enqueueResult = await req.unwrap();
|
|
||||||
req.reset();
|
|
||||||
|
|
||||||
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
|
|
||||||
dispatch(
|
|
||||||
addToast({
|
|
||||||
title: t('queue.batchQueued'),
|
|
||||||
description: t('queue.batchQueuedDesc', {
|
|
||||||
item_count: enqueueResult.enqueued,
|
|
||||||
direction: prepend ? t('queue.front') : t('queue.back'),
|
|
||||||
}),
|
|
||||||
status: 'success',
|
|
||||||
})
|
|
||||||
);
|
|
||||||
} catch {
|
|
||||||
log.error(
|
|
||||||
{ batchConfig: parseify(batchConfig) },
|
|
||||||
'Failed to enqueue batch'
|
|
||||||
);
|
|
||||||
dispatch(
|
|
||||||
addToast({
|
|
||||||
title: t('queue.batchFailedToQueue'),
|
|
||||||
status: 'error',
|
|
||||||
})
|
|
||||||
);
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -1,44 +0,0 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
|
||||||
import { serializeError } from 'serialize-error';
|
|
||||||
import { sessionCanceled } from 'services/api/thunks/session';
|
|
||||||
import { startAppListening } from '..';
|
|
||||||
|
|
||||||
export const addSessionCanceledPendingListener = () => {
|
|
||||||
startAppListening({
|
|
||||||
actionCreator: sessionCanceled.pending,
|
|
||||||
effect: () => {
|
|
||||||
//
|
|
||||||
},
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
export const addSessionCanceledFulfilledListener = () => {
|
|
||||||
startAppListening({
|
|
||||||
actionCreator: sessionCanceled.fulfilled,
|
|
||||||
effect: (action) => {
|
|
||||||
const log = logger('session');
|
|
||||||
const { session_id } = action.meta.arg;
|
|
||||||
log.debug({ session_id }, `Session canceled (${session_id})`);
|
|
||||||
},
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
export const addSessionCanceledRejectedListener = () => {
|
|
||||||
startAppListening({
|
|
||||||
actionCreator: sessionCanceled.rejected,
|
|
||||||
effect: (action) => {
|
|
||||||
const log = logger('session');
|
|
||||||
const { session_id } = action.meta.arg;
|
|
||||||
if (action.payload) {
|
|
||||||
const { error } = action.payload;
|
|
||||||
log.error(
|
|
||||||
{
|
|
||||||
session_id,
|
|
||||||
error: serializeError(error),
|
|
||||||
},
|
|
||||||
`Problem canceling session`
|
|
||||||
);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
});
|
|
||||||
};
|
|
@ -1,45 +0,0 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
|
||||||
import { parseify } from 'common/util/serialize';
|
|
||||||
import { serializeError } from 'serialize-error';
|
|
||||||
import { sessionCreated } from 'services/api/thunks/session';
|
|
||||||
import { startAppListening } from '..';
|
|
||||||
|
|
||||||
export const addSessionCreatedPendingListener = () => {
|
|
||||||
startAppListening({
|
|
||||||
actionCreator: sessionCreated.pending,
|
|
||||||
effect: () => {
|
|
||||||
//
|
|
||||||
},
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
export const addSessionCreatedFulfilledListener = () => {
|
|
||||||
startAppListening({
|
|
||||||
actionCreator: sessionCreated.fulfilled,
|
|
||||||
effect: (action) => {
|
|
||||||
const log = logger('session');
|
|
||||||
const session = action.payload;
|
|
||||||
log.debug(
|
|
||||||
{ session: parseify(session) },
|
|
||||||
`Session created (${session.id})`
|
|
||||||
);
|
|
||||||
},
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
export const addSessionCreatedRejectedListener = () => {
|
|
||||||
startAppListening({
|
|
||||||
actionCreator: sessionCreated.rejected,
|
|
||||||
effect: (action) => {
|
|
||||||
const log = logger('session');
|
|
||||||
if (action.payload) {
|
|
||||||
const { error, status } = action.payload;
|
|
||||||
const graph = parseify(action.meta.arg);
|
|
||||||
log.error(
|
|
||||||
{ graph, status, error: serializeError(error) },
|
|
||||||
`Problem creating session`
|
|
||||||
);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
});
|
|
||||||
};
|
|
@ -1,44 +0,0 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
|
||||||
import { serializeError } from 'serialize-error';
|
|
||||||
import { sessionInvoked } from 'services/api/thunks/session';
|
|
||||||
import { startAppListening } from '..';
|
|
||||||
|
|
||||||
export const addSessionInvokedPendingListener = () => {
|
|
||||||
startAppListening({
|
|
||||||
actionCreator: sessionInvoked.pending,
|
|
||||||
effect: () => {
|
|
||||||
//
|
|
||||||
},
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
export const addSessionInvokedFulfilledListener = () => {
|
|
||||||
startAppListening({
|
|
||||||
actionCreator: sessionInvoked.fulfilled,
|
|
||||||
effect: (action) => {
|
|
||||||
const log = logger('session');
|
|
||||||
const { session_id } = action.meta.arg;
|
|
||||||
log.debug({ session_id }, `Session invoked (${session_id})`);
|
|
||||||
},
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
export const addSessionInvokedRejectedListener = () => {
|
|
||||||
startAppListening({
|
|
||||||
actionCreator: sessionInvoked.rejected,
|
|
||||||
effect: (action) => {
|
|
||||||
const log = logger('session');
|
|
||||||
const { session_id } = action.meta.arg;
|
|
||||||
if (action.payload) {
|
|
||||||
const { error } = action.payload;
|
|
||||||
log.error(
|
|
||||||
{
|
|
||||||
session_id,
|
|
||||||
error: serializeError(error),
|
|
||||||
},
|
|
||||||
`Problem invoking session`
|
|
||||||
);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
});
|
|
||||||
};
|
|
@ -81,9 +81,32 @@ export const addInvocationCompleteEventListener = () => {
|
|||||||
|
|
||||||
// If auto-switch is enabled, select the new image
|
// If auto-switch is enabled, select the new image
|
||||||
if (shouldAutoSwitch) {
|
if (shouldAutoSwitch) {
|
||||||
// if auto-add is enabled, switch the board as the image comes in
|
// if auto-add is enabled, switch the gallery view and board if needed as the image comes in
|
||||||
dispatch(galleryViewChanged('images'));
|
if (gallery.galleryView !== 'images') {
|
||||||
dispatch(boardIdSelected(imageDTO.board_id ?? 'none'));
|
dispatch(galleryViewChanged('images'));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
imageDTO.board_id &&
|
||||||
|
imageDTO.board_id !== gallery.selectedBoardId
|
||||||
|
) {
|
||||||
|
dispatch(
|
||||||
|
boardIdSelected({
|
||||||
|
boardId: imageDTO.board_id,
|
||||||
|
selectedImageName: imageDTO.image_name,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!imageDTO.board_id && gallery.selectedBoardId !== 'none') {
|
||||||
|
dispatch(
|
||||||
|
boardIdSelected({
|
||||||
|
boardId: 'none',
|
||||||
|
selectedImageName: imageDTO.image_name,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
dispatch(imageSelected(imageDTO));
|
dispatch(imageSelected(imageDTO));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -35,6 +35,7 @@ export const addSocketQueueItemStatusChangedEventListener = () => {
|
|||||||
queueApi.util.invalidateTags([
|
queueApi.util.invalidateTags([
|
||||||
'CurrentSessionQueueItem',
|
'CurrentSessionQueueItem',
|
||||||
'NextSessionQueueItem',
|
'NextSessionQueueItem',
|
||||||
|
'InvocationCacheStatus',
|
||||||
{ type: 'SessionQueueItem', id: item_id },
|
{ type: 'SessionQueueItem', id: item_id },
|
||||||
{ type: 'SessionQueueItemDTO', id: item_id },
|
{ type: 'SessionQueueItemDTO', id: item_id },
|
||||||
{ type: 'BatchStatus', id: queue_batch_id },
|
{ type: 'BatchStatus', id: queue_batch_id },
|
||||||
|
@ -1,54 +0,0 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
|
||||||
import { AppThunkDispatch } from 'app/store/store';
|
|
||||||
import { parseify } from 'common/util/serialize';
|
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
|
||||||
import { t } from 'i18next';
|
|
||||||
import { queueApi } from 'services/api/endpoints/queue';
|
|
||||||
import { BatchConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
export const enqueueBatch = async (
|
|
||||||
batchConfig: BatchConfig,
|
|
||||||
dispatch: AppThunkDispatch
|
|
||||||
) => {
|
|
||||||
const log = logger('session');
|
|
||||||
const { prepend } = batchConfig;
|
|
||||||
|
|
||||||
try {
|
|
||||||
const req = dispatch(
|
|
||||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
|
||||||
fixedCacheKey: 'enqueueBatch',
|
|
||||||
})
|
|
||||||
);
|
|
||||||
const enqueueResult = await req.unwrap();
|
|
||||||
req.reset();
|
|
||||||
|
|
||||||
dispatch(
|
|
||||||
queueApi.endpoints.resumeProcessor.initiate(undefined, {
|
|
||||||
fixedCacheKey: 'resumeProcessor',
|
|
||||||
})
|
|
||||||
);
|
|
||||||
|
|
||||||
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
|
|
||||||
dispatch(
|
|
||||||
addToast({
|
|
||||||
title: t('queue.batchQueued'),
|
|
||||||
description: t('queue.batchQueuedDesc', {
|
|
||||||
item_count: enqueueResult.enqueued,
|
|
||||||
direction: prepend ? t('queue.front') : t('queue.back'),
|
|
||||||
}),
|
|
||||||
status: 'success',
|
|
||||||
})
|
|
||||||
);
|
|
||||||
} catch {
|
|
||||||
log.error(
|
|
||||||
{ batchConfig: parseify(batchConfig) },
|
|
||||||
t('queue.batchFailedToQueue')
|
|
||||||
);
|
|
||||||
dispatch(
|
|
||||||
addToast({
|
|
||||||
title: t('queue.batchFailedToQueue'),
|
|
||||||
status: 'error',
|
|
||||||
})
|
|
||||||
);
|
|
||||||
}
|
|
||||||
};
|
|
@ -1,18 +1,9 @@
|
|||||||
import { chakra, ChakraProps } from '@chakra-ui/react';
|
import { Box, ChakraProps } from '@chakra-ui/react';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { RgbaColorPicker } from 'react-colorful';
|
import { RgbaColorPicker } from 'react-colorful';
|
||||||
import { ColorPickerBaseProps, RgbaColor } from 'react-colorful/dist/types';
|
import { ColorPickerBaseProps, RgbaColor } from 'react-colorful/dist/types';
|
||||||
|
|
||||||
type IAIColorPickerProps = Omit<ColorPickerBaseProps<RgbaColor>, 'color'> &
|
type IAIColorPickerProps = ColorPickerBaseProps<RgbaColor>;
|
||||||
ChakraProps & {
|
|
||||||
pickerColor: RgbaColor;
|
|
||||||
styleClass?: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
const ChakraRgbaColorPicker = chakra(RgbaColorPicker, {
|
|
||||||
baseStyle: { paddingInline: 4 },
|
|
||||||
shouldForwardProp: (prop) => !['pickerColor'].includes(prop),
|
|
||||||
});
|
|
||||||
|
|
||||||
const colorPickerStyles: NonNullable<ChakraProps['sx']> = {
|
const colorPickerStyles: NonNullable<ChakraProps['sx']> = {
|
||||||
width: 6,
|
width: 6,
|
||||||
@ -20,19 +11,17 @@ const colorPickerStyles: NonNullable<ChakraProps['sx']> = {
|
|||||||
borderColor: 'base.100',
|
borderColor: 'base.100',
|
||||||
};
|
};
|
||||||
|
|
||||||
const IAIColorPicker = (props: IAIColorPickerProps) => {
|
const sx = {
|
||||||
const { styleClass = '', ...rest } = props;
|
'.react-colorful__hue-pointer': colorPickerStyles,
|
||||||
|
'.react-colorful__saturation-pointer': colorPickerStyles,
|
||||||
|
'.react-colorful__alpha-pointer': colorPickerStyles,
|
||||||
|
};
|
||||||
|
|
||||||
|
const IAIColorPicker = (props: IAIColorPickerProps) => {
|
||||||
return (
|
return (
|
||||||
<ChakraRgbaColorPicker
|
<Box sx={sx}>
|
||||||
sx={{
|
<RgbaColorPicker {...props} />
|
||||||
'.react-colorful__hue-pointer': colorPickerStyles,
|
</Box>
|
||||||
'.react-colorful__saturation-pointer': colorPickerStyles,
|
|
||||||
'.react-colorful__alpha-pointer': colorPickerStyles,
|
|
||||||
}}
|
|
||||||
className={styleClass}
|
|
||||||
{...rest}
|
|
||||||
/>
|
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -81,3 +81,38 @@ export const IAINoContentFallback = (props: IAINoImageFallbackProps) => {
|
|||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
type IAINoImageFallbackWithSpinnerProps = FlexProps & {
|
||||||
|
label?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const IAINoContentFallbackWithSpinner = (
|
||||||
|
props: IAINoImageFallbackWithSpinnerProps
|
||||||
|
) => {
|
||||||
|
const { sx, ...rest } = props;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
borderRadius: 'base',
|
||||||
|
flexDir: 'column',
|
||||||
|
gap: 2,
|
||||||
|
userSelect: 'none',
|
||||||
|
opacity: 0.7,
|
||||||
|
color: 'base.700',
|
||||||
|
_dark: {
|
||||||
|
color: 'base.500',
|
||||||
|
},
|
||||||
|
...sx,
|
||||||
|
}}
|
||||||
|
{...rest}
|
||||||
|
>
|
||||||
|
<Spinner size="xl" />
|
||||||
|
{props.label && <Text textAlign="center">{props.label}</Text>}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
@ -139,6 +139,11 @@ const IAICanvas = () => {
|
|||||||
const { handleDragStart, handleDragMove, handleDragEnd } =
|
const { handleDragStart, handleDragMove, handleDragEnd } =
|
||||||
useCanvasDragMove();
|
useCanvasDragMove();
|
||||||
|
|
||||||
|
const handleContextMenu = useCallback(
|
||||||
|
(e: KonvaEventObject<MouseEvent>) => e.evt.preventDefault(),
|
||||||
|
[]
|
||||||
|
);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!containerRef.current) {
|
if (!containerRef.current) {
|
||||||
return;
|
return;
|
||||||
@ -205,9 +210,7 @@ const IAICanvas = () => {
|
|||||||
onDragStart={handleDragStart}
|
onDragStart={handleDragStart}
|
||||||
onDragMove={handleDragMove}
|
onDragMove={handleDragMove}
|
||||||
onDragEnd={handleDragEnd}
|
onDragEnd={handleDragEnd}
|
||||||
onContextMenu={(e: KonvaEventObject<MouseEvent>) =>
|
onContextMenu={handleContextMenu}
|
||||||
e.evt.preventDefault()
|
|
||||||
}
|
|
||||||
onWheel={handleWheel}
|
onWheel={handleWheel}
|
||||||
draggable={(tool === 'move' || isStaging) && !isModifyingBoundingBox}
|
draggable={(tool === 'move' || isStaging) && !isModifyingBoundingBox}
|
||||||
>
|
>
|
||||||
@ -223,7 +226,11 @@ const IAICanvas = () => {
|
|||||||
>
|
>
|
||||||
<IAICanvasObjectRenderer />
|
<IAICanvasObjectRenderer />
|
||||||
</Layer>
|
</Layer>
|
||||||
<Layer id="mask" visible={isMaskEnabled} listening={false}>
|
<Layer
|
||||||
|
id="mask"
|
||||||
|
visible={isMaskEnabled && !isStaging}
|
||||||
|
listening={false}
|
||||||
|
>
|
||||||
<IAICanvasMaskLines visible={true} listening={false} />
|
<IAICanvasMaskLines visible={true} listening={false} />
|
||||||
<IAICanvasMaskCompositer listening={false} />
|
<IAICanvasMaskCompositer listening={false} />
|
||||||
</Layer>
|
</Layer>
|
||||||
|
@ -1,26 +1,27 @@
|
|||||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
import { Image, Rect } from 'react-konva';
|
import { memo } from 'react';
|
||||||
|
import { Image } from 'react-konva';
|
||||||
|
import { $authToken } from 'services/api/client';
|
||||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||||
import useImage from 'use-image';
|
import useImage from 'use-image';
|
||||||
import { CanvasImage } from '../store/canvasTypes';
|
import { CanvasImage } from '../store/canvasTypes';
|
||||||
import { $authToken } from 'services/api/client';
|
import IAICanvasImageErrorFallback from './IAICanvasImageErrorFallback';
|
||||||
import { memo } from 'react';
|
|
||||||
|
|
||||||
type IAICanvasImageProps = {
|
type IAICanvasImageProps = {
|
||||||
canvasImage: CanvasImage;
|
canvasImage: CanvasImage;
|
||||||
};
|
};
|
||||||
const IAICanvasImage = (props: IAICanvasImageProps) => {
|
const IAICanvasImage = (props: IAICanvasImageProps) => {
|
||||||
const { width, height, x, y, imageName } = props.canvasImage;
|
const { x, y, imageName } = props.canvasImage;
|
||||||
const { currentData: imageDTO, isError } = useGetImageDTOQuery(
|
const { currentData: imageDTO, isError } = useGetImageDTOQuery(
|
||||||
imageName ?? skipToken
|
imageName ?? skipToken
|
||||||
);
|
);
|
||||||
const [image] = useImage(
|
const [image, status] = useImage(
|
||||||
imageDTO?.image_url ?? '',
|
imageDTO?.image_url ?? '',
|
||||||
$authToken.get() ? 'use-credentials' : 'anonymous'
|
$authToken.get() ? 'use-credentials' : 'anonymous'
|
||||||
);
|
);
|
||||||
|
|
||||||
if (isError) {
|
if (isError || status === 'failed') {
|
||||||
return <Rect x={x} y={y} width={width} height={height} fill="red" />;
|
return <IAICanvasImageErrorFallback canvasImage={props.canvasImage} />;
|
||||||
}
|
}
|
||||||
|
|
||||||
return <Image x={x} y={y} image={image} listening={false} />;
|
return <Image x={x} y={y} image={image} listening={false} />;
|
||||||
|
@ -0,0 +1,44 @@
|
|||||||
|
import { useColorModeValue, useToken } from '@chakra-ui/react';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { Group, Rect, Text } from 'react-konva';
|
||||||
|
import { CanvasImage } from '../store/canvasTypes';
|
||||||
|
|
||||||
|
type IAICanvasImageErrorFallbackProps = {
|
||||||
|
canvasImage: CanvasImage;
|
||||||
|
};
|
||||||
|
const IAICanvasImageErrorFallback = ({
|
||||||
|
canvasImage,
|
||||||
|
}: IAICanvasImageErrorFallbackProps) => {
|
||||||
|
const [errorColorLight, errorColorDark, fontColorLight, fontColorDark] =
|
||||||
|
useToken('colors', ['base.400', 'base.500', 'base.700', 'base.900']);
|
||||||
|
const errorColor = useColorModeValue(errorColorLight, errorColorDark);
|
||||||
|
const fontColor = useColorModeValue(fontColorLight, fontColorDark);
|
||||||
|
const { t } = useTranslation();
|
||||||
|
return (
|
||||||
|
<Group>
|
||||||
|
<Rect
|
||||||
|
x={canvasImage.x}
|
||||||
|
y={canvasImage.y}
|
||||||
|
width={canvasImage.width}
|
||||||
|
height={canvasImage.height}
|
||||||
|
fill={errorColor}
|
||||||
|
/>
|
||||||
|
<Text
|
||||||
|
x={canvasImage.x}
|
||||||
|
y={canvasImage.y}
|
||||||
|
width={canvasImage.width}
|
||||||
|
height={canvasImage.height}
|
||||||
|
align="center"
|
||||||
|
verticalAlign="middle"
|
||||||
|
fontFamily='"Inter Variable", sans-serif'
|
||||||
|
fontSize={canvasImage.width / 16}
|
||||||
|
fontStyle="600"
|
||||||
|
text={t('common.imageFailedToLoad')}
|
||||||
|
fill={fontColor}
|
||||||
|
/>
|
||||||
|
</Group>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(IAICanvasImageErrorFallback);
|
@ -3,10 +3,9 @@ import { useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||||
import { GroupConfig } from 'konva/lib/Group';
|
import { GroupConfig } from 'konva/lib/Group';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
|
import { memo } from 'react';
|
||||||
import { Group, Rect } from 'react-konva';
|
import { Group, Rect } from 'react-konva';
|
||||||
import IAICanvasImage from './IAICanvasImage';
|
import IAICanvasImage from './IAICanvasImage';
|
||||||
import { memo } from 'react';
|
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[canvasSelector],
|
[canvasSelector],
|
||||||
@ -15,11 +14,11 @@ const selector = createSelector(
|
|||||||
layerState,
|
layerState,
|
||||||
shouldShowStagingImage,
|
shouldShowStagingImage,
|
||||||
shouldShowStagingOutline,
|
shouldShowStagingOutline,
|
||||||
boundingBoxCoordinates: { x, y },
|
boundingBoxCoordinates: stageBoundingBoxCoordinates,
|
||||||
boundingBoxDimensions: { width, height },
|
boundingBoxDimensions: stageBoundingBoxDimensions,
|
||||||
} = canvas;
|
} = canvas;
|
||||||
|
|
||||||
const { selectedImageIndex, images } = layerState.stagingArea;
|
const { selectedImageIndex, images, boundingBox } = layerState.stagingArea;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
currentStagingAreaImage:
|
currentStagingAreaImage:
|
||||||
@ -30,10 +29,10 @@ const selector = createSelector(
|
|||||||
isOnLastImage: selectedImageIndex === images.length - 1,
|
isOnLastImage: selectedImageIndex === images.length - 1,
|
||||||
shouldShowStagingImage,
|
shouldShowStagingImage,
|
||||||
shouldShowStagingOutline,
|
shouldShowStagingOutline,
|
||||||
x,
|
x: boundingBox?.x ?? stageBoundingBoxCoordinates.x,
|
||||||
y,
|
y: boundingBox?.y ?? stageBoundingBoxCoordinates.y,
|
||||||
width,
|
width: boundingBox?.width ?? stageBoundingBoxDimensions.width,
|
||||||
height,
|
height: boundingBox?.height ?? stageBoundingBoxDimensions.height,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -14,6 +14,7 @@ import {
|
|||||||
|
|
||||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -23,8 +24,8 @@ import {
|
|||||||
FaCheck,
|
FaCheck,
|
||||||
FaEye,
|
FaEye,
|
||||||
FaEyeSlash,
|
FaEyeSlash,
|
||||||
FaPlus,
|
|
||||||
FaSave,
|
FaSave,
|
||||||
|
FaTimes,
|
||||||
} from 'react-icons/fa';
|
} from 'react-icons/fa';
|
||||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||||
import { stagingAreaImageSaved } from '../store/actions';
|
import { stagingAreaImageSaved } from '../store/actions';
|
||||||
@ -41,10 +42,10 @@ const selector = createSelector(
|
|||||||
} = canvas;
|
} = canvas;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
currentIndex: selectedImageIndex,
|
||||||
|
total: images.length,
|
||||||
currentStagingAreaImage:
|
currentStagingAreaImage:
|
||||||
images.length > 0 ? images[selectedImageIndex] : undefined,
|
images.length > 0 ? images[selectedImageIndex] : undefined,
|
||||||
isOnFirstImage: selectedImageIndex === 0,
|
|
||||||
isOnLastImage: selectedImageIndex === images.length - 1,
|
|
||||||
shouldShowStagingImage,
|
shouldShowStagingImage,
|
||||||
shouldShowStagingOutline,
|
shouldShowStagingOutline,
|
||||||
};
|
};
|
||||||
@ -55,10 +56,10 @@ const selector = createSelector(
|
|||||||
const IAICanvasStagingAreaToolbar = () => {
|
const IAICanvasStagingAreaToolbar = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const {
|
const {
|
||||||
isOnFirstImage,
|
|
||||||
isOnLastImage,
|
|
||||||
currentStagingAreaImage,
|
currentStagingAreaImage,
|
||||||
shouldShowStagingImage,
|
shouldShowStagingImage,
|
||||||
|
currentIndex,
|
||||||
|
total,
|
||||||
} = useAppSelector(selector);
|
} = useAppSelector(selector);
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
@ -71,39 +72,6 @@ const IAICanvasStagingAreaToolbar = () => {
|
|||||||
dispatch(setShouldShowStagingOutline(false));
|
dispatch(setShouldShowStagingOutline(false));
|
||||||
}, [dispatch]);
|
}, [dispatch]);
|
||||||
|
|
||||||
useHotkeys(
|
|
||||||
['left'],
|
|
||||||
() => {
|
|
||||||
handlePrevImage();
|
|
||||||
},
|
|
||||||
{
|
|
||||||
enabled: () => true,
|
|
||||||
preventDefault: true,
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
useHotkeys(
|
|
||||||
['right'],
|
|
||||||
() => {
|
|
||||||
handleNextImage();
|
|
||||||
},
|
|
||||||
{
|
|
||||||
enabled: () => true,
|
|
||||||
preventDefault: true,
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
useHotkeys(
|
|
||||||
['enter'],
|
|
||||||
() => {
|
|
||||||
handleAccept();
|
|
||||||
},
|
|
||||||
{
|
|
||||||
enabled: () => true,
|
|
||||||
preventDefault: true,
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const handlePrevImage = useCallback(
|
const handlePrevImage = useCallback(
|
||||||
() => dispatch(prevStagingAreaImage()),
|
() => dispatch(prevStagingAreaImage()),
|
||||||
[dispatch]
|
[dispatch]
|
||||||
@ -119,10 +87,45 @@ const IAICanvasStagingAreaToolbar = () => {
|
|||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
useHotkeys(['left'], handlePrevImage, {
|
||||||
|
enabled: () => true,
|
||||||
|
preventDefault: true,
|
||||||
|
});
|
||||||
|
|
||||||
|
useHotkeys(['right'], handleNextImage, {
|
||||||
|
enabled: () => true,
|
||||||
|
preventDefault: true,
|
||||||
|
});
|
||||||
|
|
||||||
|
useHotkeys(['enter'], () => handleAccept, {
|
||||||
|
enabled: () => true,
|
||||||
|
preventDefault: true,
|
||||||
|
});
|
||||||
|
|
||||||
const { data: imageDTO } = useGetImageDTOQuery(
|
const { data: imageDTO } = useGetImageDTOQuery(
|
||||||
currentStagingAreaImage?.imageName ?? skipToken
|
currentStagingAreaImage?.imageName ?? skipToken
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const handleToggleShouldShowStagingImage = useCallback(() => {
|
||||||
|
dispatch(setShouldShowStagingImage(!shouldShowStagingImage));
|
||||||
|
}, [dispatch, shouldShowStagingImage]);
|
||||||
|
|
||||||
|
const handleSaveToGallery = useCallback(() => {
|
||||||
|
if (!imageDTO) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
stagingAreaImageSaved({
|
||||||
|
imageDTO,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}, [dispatch, imageDTO]);
|
||||||
|
|
||||||
|
const handleDiscardStagingArea = useCallback(() => {
|
||||||
|
dispatch(discardStagedImages());
|
||||||
|
}, [dispatch]);
|
||||||
|
|
||||||
if (!currentStagingAreaImage) {
|
if (!currentStagingAreaImage) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
@ -131,11 +134,12 @@ const IAICanvasStagingAreaToolbar = () => {
|
|||||||
<Flex
|
<Flex
|
||||||
pos="absolute"
|
pos="absolute"
|
||||||
bottom={4}
|
bottom={4}
|
||||||
|
gap={2}
|
||||||
w="100%"
|
w="100%"
|
||||||
align="center"
|
align="center"
|
||||||
justify="center"
|
justify="center"
|
||||||
onMouseOver={handleMouseOver}
|
onMouseEnter={handleMouseOver}
|
||||||
onMouseOut={handleMouseOut}
|
onMouseLeave={handleMouseOut}
|
||||||
>
|
>
|
||||||
<ButtonGroup isAttached borderRadius="base" shadow="dark-lg">
|
<ButtonGroup isAttached borderRadius="base" shadow="dark-lg">
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
@ -144,16 +148,29 @@ const IAICanvasStagingAreaToolbar = () => {
|
|||||||
icon={<FaArrowLeft />}
|
icon={<FaArrowLeft />}
|
||||||
onClick={handlePrevImage}
|
onClick={handlePrevImage}
|
||||||
colorScheme="accent"
|
colorScheme="accent"
|
||||||
isDisabled={isOnFirstImage}
|
isDisabled={!shouldShowStagingImage}
|
||||||
/>
|
/>
|
||||||
|
<IAIButton
|
||||||
|
colorScheme="accent"
|
||||||
|
pointerEvents="none"
|
||||||
|
isDisabled={!shouldShowStagingImage}
|
||||||
|
sx={{
|
||||||
|
background: 'base.600',
|
||||||
|
_dark: {
|
||||||
|
background: 'base.800',
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
>{`${currentIndex + 1}/${total}`}</IAIButton>
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
tooltip={`${t('unifiedCanvas.next')} (Right)`}
|
tooltip={`${t('unifiedCanvas.next')} (Right)`}
|
||||||
aria-label={`${t('unifiedCanvas.next')} (Right)`}
|
aria-label={`${t('unifiedCanvas.next')} (Right)`}
|
||||||
icon={<FaArrowRight />}
|
icon={<FaArrowRight />}
|
||||||
onClick={handleNextImage}
|
onClick={handleNextImage}
|
||||||
colorScheme="accent"
|
colorScheme="accent"
|
||||||
isDisabled={isOnLastImage}
|
isDisabled={!shouldShowStagingImage}
|
||||||
/>
|
/>
|
||||||
|
</ButtonGroup>
|
||||||
|
<ButtonGroup isAttached borderRadius="base" shadow="dark-lg">
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
tooltip={`${t('unifiedCanvas.accept')} (Enter)`}
|
tooltip={`${t('unifiedCanvas.accept')} (Enter)`}
|
||||||
aria-label={`${t('unifiedCanvas.accept')} (Enter)`}
|
aria-label={`${t('unifiedCanvas.accept')} (Enter)`}
|
||||||
@ -162,13 +179,19 @@ const IAICanvasStagingAreaToolbar = () => {
|
|||||||
colorScheme="accent"
|
colorScheme="accent"
|
||||||
/>
|
/>
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
tooltip={t('unifiedCanvas.showHide')}
|
tooltip={
|
||||||
aria-label={t('unifiedCanvas.showHide')}
|
shouldShowStagingImage
|
||||||
|
? t('unifiedCanvas.showResultsOn')
|
||||||
|
: t('unifiedCanvas.showResultsOff')
|
||||||
|
}
|
||||||
|
aria-label={
|
||||||
|
shouldShowStagingImage
|
||||||
|
? t('unifiedCanvas.showResultsOn')
|
||||||
|
: t('unifiedCanvas.showResultsOff')
|
||||||
|
}
|
||||||
data-alert={!shouldShowStagingImage}
|
data-alert={!shouldShowStagingImage}
|
||||||
icon={shouldShowStagingImage ? <FaEye /> : <FaEyeSlash />}
|
icon={shouldShowStagingImage ? <FaEye /> : <FaEyeSlash />}
|
||||||
onClick={() =>
|
onClick={handleToggleShouldShowStagingImage}
|
||||||
dispatch(setShouldShowStagingImage(!shouldShowStagingImage))
|
|
||||||
}
|
|
||||||
colorScheme="accent"
|
colorScheme="accent"
|
||||||
/>
|
/>
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
@ -176,24 +199,14 @@ const IAICanvasStagingAreaToolbar = () => {
|
|||||||
aria-label={t('unifiedCanvas.saveToGallery')}
|
aria-label={t('unifiedCanvas.saveToGallery')}
|
||||||
isDisabled={!imageDTO || !imageDTO.is_intermediate}
|
isDisabled={!imageDTO || !imageDTO.is_intermediate}
|
||||||
icon={<FaSave />}
|
icon={<FaSave />}
|
||||||
onClick={() => {
|
onClick={handleSaveToGallery}
|
||||||
if (!imageDTO) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
dispatch(
|
|
||||||
stagingAreaImageSaved({
|
|
||||||
imageDTO,
|
|
||||||
})
|
|
||||||
);
|
|
||||||
}}
|
|
||||||
colorScheme="accent"
|
colorScheme="accent"
|
||||||
/>
|
/>
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
tooltip={t('unifiedCanvas.discardAll')}
|
tooltip={t('unifiedCanvas.discardAll')}
|
||||||
aria-label={t('unifiedCanvas.discardAll')}
|
aria-label={t('unifiedCanvas.discardAll')}
|
||||||
icon={<FaPlus style={{ transform: 'rotate(45deg)' }} />}
|
icon={<FaTimes />}
|
||||||
onClick={() => dispatch(discardStagedImages())}
|
onClick={handleDiscardStagingArea}
|
||||||
colorScheme="error"
|
colorScheme="error"
|
||||||
fontSize={20}
|
fontSize={20}
|
||||||
/>
|
/>
|
||||||
|
@ -213,45 +213,45 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => {
|
|||||||
[scaledStep]
|
[scaledStep]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleStartedTransforming = () => {
|
const handleStartedTransforming = useCallback(() => {
|
||||||
dispatch(setIsTransformingBoundingBox(true));
|
dispatch(setIsTransformingBoundingBox(true));
|
||||||
};
|
}, [dispatch]);
|
||||||
|
|
||||||
const handleEndedTransforming = () => {
|
const handleEndedTransforming = useCallback(() => {
|
||||||
dispatch(setIsTransformingBoundingBox(false));
|
dispatch(setIsTransformingBoundingBox(false));
|
||||||
dispatch(setIsMovingBoundingBox(false));
|
dispatch(setIsMovingBoundingBox(false));
|
||||||
dispatch(setIsMouseOverBoundingBox(false));
|
dispatch(setIsMouseOverBoundingBox(false));
|
||||||
setIsMouseOverBoundingBoxOutline(false);
|
setIsMouseOverBoundingBoxOutline(false);
|
||||||
};
|
}, [dispatch]);
|
||||||
|
|
||||||
const handleStartedMoving = () => {
|
const handleStartedMoving = useCallback(() => {
|
||||||
dispatch(setIsMovingBoundingBox(true));
|
dispatch(setIsMovingBoundingBox(true));
|
||||||
};
|
}, [dispatch]);
|
||||||
|
|
||||||
const handleEndedModifying = () => {
|
const handleEndedModifying = useCallback(() => {
|
||||||
dispatch(setIsTransformingBoundingBox(false));
|
dispatch(setIsTransformingBoundingBox(false));
|
||||||
dispatch(setIsMovingBoundingBox(false));
|
dispatch(setIsMovingBoundingBox(false));
|
||||||
dispatch(setIsMouseOverBoundingBox(false));
|
dispatch(setIsMouseOverBoundingBox(false));
|
||||||
setIsMouseOverBoundingBoxOutline(false);
|
setIsMouseOverBoundingBoxOutline(false);
|
||||||
};
|
}, [dispatch]);
|
||||||
|
|
||||||
const handleMouseOver = () => {
|
const handleMouseOver = useCallback(() => {
|
||||||
setIsMouseOverBoundingBoxOutline(true);
|
setIsMouseOverBoundingBoxOutline(true);
|
||||||
};
|
}, []);
|
||||||
|
|
||||||
const handleMouseOut = () => {
|
const handleMouseOut = useCallback(() => {
|
||||||
!isTransformingBoundingBox &&
|
!isTransformingBoundingBox &&
|
||||||
!isMovingBoundingBox &&
|
!isMovingBoundingBox &&
|
||||||
setIsMouseOverBoundingBoxOutline(false);
|
setIsMouseOverBoundingBoxOutline(false);
|
||||||
};
|
}, [isMovingBoundingBox, isTransformingBoundingBox]);
|
||||||
|
|
||||||
const handleMouseEnterBoundingBox = () => {
|
const handleMouseEnterBoundingBox = useCallback(() => {
|
||||||
dispatch(setIsMouseOverBoundingBox(true));
|
dispatch(setIsMouseOverBoundingBox(true));
|
||||||
};
|
}, [dispatch]);
|
||||||
|
|
||||||
const handleMouseLeaveBoundingBox = () => {
|
const handleMouseLeaveBoundingBox = useCallback(() => {
|
||||||
dispatch(setIsMouseOverBoundingBox(false));
|
dispatch(setIsMouseOverBoundingBox(false));
|
||||||
};
|
}, [dispatch]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Group {...rest}>
|
<Group {...rest}>
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { ButtonGroup, Flex } from '@chakra-ui/react';
|
import { Box, ButtonGroup, Flex } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
@ -135,11 +135,12 @@ const IAICanvasMaskOptions = () => {
|
|||||||
dispatch(setShouldPreserveMaskedArea(e.target.checked))
|
dispatch(setShouldPreserveMaskedArea(e.target.checked))
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
<IAIColorPicker
|
<Box sx={{ paddingTop: 2, paddingBottom: 2 }}>
|
||||||
sx={{ paddingTop: 2, paddingBottom: 2 }}
|
<IAIColorPicker
|
||||||
pickerColor={maskColor}
|
color={maskColor}
|
||||||
onChange={(newColor) => dispatch(setMaskColor(newColor))}
|
onChange={(newColor) => dispatch(setMaskColor(newColor))}
|
||||||
/>
|
/>
|
||||||
|
</Box>
|
||||||
<IAIButton size="sm" leftIcon={<FaSave />} onClick={handleSaveMask}>
|
<IAIButton size="sm" leftIcon={<FaSave />} onClick={handleSaveMask}>
|
||||||
Save Mask
|
Save Mask
|
||||||
</IAIButton>
|
</IAIButton>
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { ButtonGroup, Flex } from '@chakra-ui/react';
|
import { ButtonGroup, Flex, Box } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
@ -237,15 +237,18 @@ const IAICanvasToolChooserOptions = () => {
|
|||||||
sliderNumberInputProps={{ max: 500 }}
|
sliderNumberInputProps={{ max: 500 }}
|
||||||
/>
|
/>
|
||||||
</Flex>
|
</Flex>
|
||||||
<IAIColorPicker
|
<Box
|
||||||
sx={{
|
sx={{
|
||||||
width: '100%',
|
width: '100%',
|
||||||
paddingTop: 2,
|
paddingTop: 2,
|
||||||
paddingBottom: 2,
|
paddingBottom: 2,
|
||||||
}}
|
}}
|
||||||
pickerColor={brushColor}
|
>
|
||||||
onChange={(newColor) => dispatch(setBrushColor(newColor))}
|
<IAIColorPicker
|
||||||
/>
|
color={brushColor}
|
||||||
|
onChange={(newColor) => dispatch(setBrushColor(newColor))}
|
||||||
|
/>
|
||||||
|
</Box>
|
||||||
</Flex>
|
</Flex>
|
||||||
</IAIPopover>
|
</IAIPopover>
|
||||||
</ButtonGroup>
|
</ButtonGroup>
|
||||||
|
@ -6,7 +6,7 @@ export const canvasSelector = (state: RootState): CanvasState => state.canvas;
|
|||||||
|
|
||||||
export const isStagingSelector = createSelector(
|
export const isStagingSelector = createSelector(
|
||||||
[stateSelector],
|
[stateSelector],
|
||||||
({ canvas }) => canvas.layerState.stagingArea.images.length > 0
|
({ canvas }) => canvas.batchIds.length > 0
|
||||||
);
|
);
|
||||||
|
|
||||||
export const initialCanvasImageSelector = (
|
export const initialCanvasImageSelector = (
|
||||||
|
@ -8,7 +8,6 @@ import { setAspectRatio } from 'features/parameters/store/generationSlice';
|
|||||||
import { IRect, Vector2d } from 'konva/lib/types';
|
import { IRect, Vector2d } from 'konva/lib/types';
|
||||||
import { clamp, cloneDeep } from 'lodash-es';
|
import { clamp, cloneDeep } from 'lodash-es';
|
||||||
import { RgbaColor } from 'react-colorful';
|
import { RgbaColor } from 'react-colorful';
|
||||||
import { sessionCanceled } from 'services/api/thunks/session';
|
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
import calculateCoordinates from '../util/calculateCoordinates';
|
import calculateCoordinates from '../util/calculateCoordinates';
|
||||||
import calculateScale from '../util/calculateScale';
|
import calculateScale from '../util/calculateScale';
|
||||||
@ -187,7 +186,7 @@ export const canvasSlice = createSlice({
|
|||||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
state.pastLayerStates.push(cloneDeep(state.layerState));
|
||||||
|
|
||||||
state.layerState = {
|
state.layerState = {
|
||||||
...initialLayerState,
|
...cloneDeep(initialLayerState),
|
||||||
objects: [
|
objects: [
|
||||||
{
|
{
|
||||||
kind: 'image',
|
kind: 'image',
|
||||||
@ -201,6 +200,7 @@ export const canvasSlice = createSlice({
|
|||||||
],
|
],
|
||||||
};
|
};
|
||||||
state.futureLayerStates = [];
|
state.futureLayerStates = [];
|
||||||
|
state.batchIds = [];
|
||||||
|
|
||||||
const newScale = calculateScale(
|
const newScale = calculateScale(
|
||||||
stageDimensions.width,
|
stageDimensions.width,
|
||||||
@ -350,11 +350,14 @@ export const canvasSlice = createSlice({
|
|||||||
state.pastLayerStates.shift();
|
state.pastLayerStates.shift();
|
||||||
}
|
}
|
||||||
|
|
||||||
state.layerState.stagingArea = { ...initialLayerState.stagingArea };
|
state.layerState.stagingArea = cloneDeep(
|
||||||
|
cloneDeep(initialLayerState)
|
||||||
|
).stagingArea;
|
||||||
|
|
||||||
state.futureLayerStates = [];
|
state.futureLayerStates = [];
|
||||||
state.shouldShowStagingOutline = true;
|
state.shouldShowStagingOutline = true;
|
||||||
state.shouldShowStagingOutline = true;
|
state.shouldShowStagingImage = true;
|
||||||
|
state.batchIds = [];
|
||||||
},
|
},
|
||||||
addFillRect: (state) => {
|
addFillRect: (state) => {
|
||||||
const { boundingBoxCoordinates, boundingBoxDimensions, brushColor } =
|
const { boundingBoxCoordinates, boundingBoxDimensions, brushColor } =
|
||||||
@ -491,8 +494,9 @@ export const canvasSlice = createSlice({
|
|||||||
resetCanvas: (state) => {
|
resetCanvas: (state) => {
|
||||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
state.pastLayerStates.push(cloneDeep(state.layerState));
|
||||||
|
|
||||||
state.layerState = initialLayerState;
|
state.layerState = cloneDeep(initialLayerState);
|
||||||
state.futureLayerStates = [];
|
state.futureLayerStates = [];
|
||||||
|
state.batchIds = [];
|
||||||
},
|
},
|
||||||
canvasResized: (
|
canvasResized: (
|
||||||
state,
|
state,
|
||||||
@ -617,25 +621,22 @@ export const canvasSlice = createSlice({
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const currentIndex = state.layerState.stagingArea.selectedImageIndex;
|
const nextIndex = state.layerState.stagingArea.selectedImageIndex + 1;
|
||||||
const length = state.layerState.stagingArea.images.length;
|
const lastIndex = state.layerState.stagingArea.images.length - 1;
|
||||||
|
|
||||||
state.layerState.stagingArea.selectedImageIndex = Math.min(
|
state.layerState.stagingArea.selectedImageIndex =
|
||||||
currentIndex + 1,
|
nextIndex > lastIndex ? 0 : nextIndex;
|
||||||
length - 1
|
|
||||||
);
|
|
||||||
},
|
},
|
||||||
prevStagingAreaImage: (state) => {
|
prevStagingAreaImage: (state) => {
|
||||||
if (!state.layerState.stagingArea.images.length) {
|
if (!state.layerState.stagingArea.images.length) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const currentIndex = state.layerState.stagingArea.selectedImageIndex;
|
const prevIndex = state.layerState.stagingArea.selectedImageIndex - 1;
|
||||||
|
const lastIndex = state.layerState.stagingArea.images.length - 1;
|
||||||
|
|
||||||
state.layerState.stagingArea.selectedImageIndex = Math.max(
|
state.layerState.stagingArea.selectedImageIndex =
|
||||||
currentIndex - 1,
|
prevIndex < 0 ? lastIndex : prevIndex;
|
||||||
0
|
|
||||||
);
|
|
||||||
},
|
},
|
||||||
commitStagingAreaImage: (state) => {
|
commitStagingAreaImage: (state) => {
|
||||||
if (!state.layerState.stagingArea.images.length) {
|
if (!state.layerState.stagingArea.images.length) {
|
||||||
@ -657,13 +658,12 @@ export const canvasSlice = createSlice({
|
|||||||
...imageToCommit,
|
...imageToCommit,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
state.layerState.stagingArea = {
|
state.layerState.stagingArea = cloneDeep(initialLayerState).stagingArea;
|
||||||
...initialLayerState.stagingArea,
|
|
||||||
};
|
|
||||||
|
|
||||||
state.futureLayerStates = [];
|
state.futureLayerStates = [];
|
||||||
state.shouldShowStagingOutline = true;
|
state.shouldShowStagingOutline = true;
|
||||||
state.shouldShowStagingImage = true;
|
state.shouldShowStagingImage = true;
|
||||||
|
state.batchIds = [];
|
||||||
},
|
},
|
||||||
fitBoundingBoxToStage: (state) => {
|
fitBoundingBoxToStage: (state) => {
|
||||||
const {
|
const {
|
||||||
@ -786,11 +786,6 @@ export const canvasSlice = createSlice({
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
extraReducers: (builder) => {
|
extraReducers: (builder) => {
|
||||||
builder.addCase(sessionCanceled.pending, (state) => {
|
|
||||||
if (!state.layerState.stagingArea.images.length) {
|
|
||||||
state.layerState.stagingArea = initialLayerState.stagingArea;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
builder.addCase(setAspectRatio, (state, action) => {
|
builder.addCase(setAspectRatio, (state, action) => {
|
||||||
const ratio = action.payload;
|
const ratio = action.payload;
|
||||||
if (ratio) {
|
if (ratio) {
|
||||||
|
@ -1,11 +1,14 @@
|
|||||||
import { getCanvasBaseLayer } from './konvaInstanceProvider';
|
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
|
import { getCanvasBaseLayer } from './konvaInstanceProvider';
|
||||||
import { konvaNodeToBlob } from './konvaNodeToBlob';
|
import { konvaNodeToBlob } from './konvaNodeToBlob';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the canvas base layer blob, with or without bounding box according to `shouldCropToBoundingBoxOnSave`
|
* Get the canvas base layer blob, with or without bounding box according to `shouldCropToBoundingBoxOnSave`
|
||||||
*/
|
*/
|
||||||
export const getBaseLayerBlob = async (state: RootState) => {
|
export const getBaseLayerBlob = async (
|
||||||
|
state: RootState,
|
||||||
|
alwaysUseBoundingBox: boolean = false
|
||||||
|
) => {
|
||||||
const canvasBaseLayer = getCanvasBaseLayer();
|
const canvasBaseLayer = getCanvasBaseLayer();
|
||||||
|
|
||||||
if (!canvasBaseLayer) {
|
if (!canvasBaseLayer) {
|
||||||
@ -24,14 +27,15 @@ export const getBaseLayerBlob = async (state: RootState) => {
|
|||||||
|
|
||||||
const absPos = clonedBaseLayer.getAbsolutePosition();
|
const absPos = clonedBaseLayer.getAbsolutePosition();
|
||||||
|
|
||||||
const boundingBox = shouldCropToBoundingBoxOnSave
|
const boundingBox =
|
||||||
? {
|
shouldCropToBoundingBoxOnSave || alwaysUseBoundingBox
|
||||||
x: boundingBoxCoordinates.x + absPos.x,
|
? {
|
||||||
y: boundingBoxCoordinates.y + absPos.y,
|
x: boundingBoxCoordinates.x + absPos.x,
|
||||||
width: boundingBoxDimensions.width,
|
y: boundingBoxCoordinates.y + absPos.y,
|
||||||
height: boundingBoxDimensions.height,
|
width: boundingBoxDimensions.width,
|
||||||
}
|
height: boundingBoxDimensions.height,
|
||||||
: clonedBaseLayer.getClientRect();
|
}
|
||||||
|
: clonedBaseLayer.getClientRect();
|
||||||
|
|
||||||
return konvaNodeToBlob(clonedBaseLayer, boundingBox);
|
return konvaNodeToBlob(clonedBaseLayer, boundingBox);
|
||||||
};
|
};
|
||||||
|
@ -6,7 +6,6 @@ import {
|
|||||||
import { cloneDeep, forEach } from 'lodash-es';
|
import { cloneDeep, forEach } from 'lodash-es';
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
import { components } from 'services/api/schema';
|
import { components } from 'services/api/schema';
|
||||||
import { isAnySessionRejected } from 'services/api/thunks/session';
|
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
import { appSocketInvocationError } from 'services/events/actions';
|
import { appSocketInvocationError } from 'services/events/actions';
|
||||||
import { controlNetImageProcessed } from './actions';
|
import { controlNetImageProcessed } from './actions';
|
||||||
@ -99,6 +98,9 @@ export const controlNetSlice = createSlice({
|
|||||||
isControlNetEnabledToggled: (state) => {
|
isControlNetEnabledToggled: (state) => {
|
||||||
state.isEnabled = !state.isEnabled;
|
state.isEnabled = !state.isEnabled;
|
||||||
},
|
},
|
||||||
|
controlNetEnabled: (state) => {
|
||||||
|
state.isEnabled = true;
|
||||||
|
},
|
||||||
controlNetAdded: (
|
controlNetAdded: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<{
|
action: PayloadAction<{
|
||||||
@ -112,6 +114,12 @@ export const controlNetSlice = createSlice({
|
|||||||
controlNetId,
|
controlNetId,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
|
controlNetRecalled: (state, action: PayloadAction<ControlNetConfig>) => {
|
||||||
|
const controlNet = action.payload;
|
||||||
|
state.controlNets[controlNet.controlNetId] = {
|
||||||
|
...controlNet,
|
||||||
|
};
|
||||||
|
},
|
||||||
controlNetDuplicated: (
|
controlNetDuplicated: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<{
|
action: PayloadAction<{
|
||||||
@ -418,10 +426,6 @@ export const controlNetSlice = createSlice({
|
|||||||
state.pendingControlImages = [];
|
state.pendingControlImages = [];
|
||||||
});
|
});
|
||||||
|
|
||||||
builder.addMatcher(isAnySessionRejected, (state) => {
|
|
||||||
state.pendingControlImages = [];
|
|
||||||
});
|
|
||||||
|
|
||||||
builder.addMatcher(
|
builder.addMatcher(
|
||||||
imagesApi.endpoints.deleteImage.matchFulfilled,
|
imagesApi.endpoints.deleteImage.matchFulfilled,
|
||||||
(state, action) => {
|
(state, action) => {
|
||||||
@ -444,7 +448,9 @@ export const controlNetSlice = createSlice({
|
|||||||
|
|
||||||
export const {
|
export const {
|
||||||
isControlNetEnabledToggled,
|
isControlNetEnabledToggled,
|
||||||
|
controlNetEnabled,
|
||||||
controlNetAdded,
|
controlNetAdded,
|
||||||
|
controlNetRecalled,
|
||||||
controlNetDuplicated,
|
controlNetDuplicated,
|
||||||
controlNetAddedFromImage,
|
controlNetAddedFromImage,
|
||||||
controlNetRemoved,
|
controlNetRemoved,
|
||||||
|
@ -93,7 +93,7 @@ const GalleryBoard = ({
|
|||||||
const [localBoardName, setLocalBoardName] = useState(board_name);
|
const [localBoardName, setLocalBoardName] = useState(board_name);
|
||||||
|
|
||||||
const handleSelectBoard = useCallback(() => {
|
const handleSelectBoard = useCallback(() => {
|
||||||
dispatch(boardIdSelected(board_id));
|
dispatch(boardIdSelected({ boardId: board_id }));
|
||||||
if (autoAssignBoardOnClick) {
|
if (autoAssignBoardOnClick) {
|
||||||
dispatch(autoAddBoardIdChanged(board_id));
|
dispatch(autoAddBoardIdChanged(board_id));
|
||||||
}
|
}
|
||||||
|
@ -34,7 +34,7 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
|
|||||||
const { autoAddBoardId, autoAssignBoardOnClick } = useAppSelector(selector);
|
const { autoAddBoardId, autoAssignBoardOnClick } = useAppSelector(selector);
|
||||||
const boardName = useBoardName('none');
|
const boardName = useBoardName('none');
|
||||||
const handleSelectBoard = useCallback(() => {
|
const handleSelectBoard = useCallback(() => {
|
||||||
dispatch(boardIdSelected('none'));
|
dispatch(boardIdSelected({ boardId: 'none' }));
|
||||||
if (autoAssignBoardOnClick) {
|
if (autoAssignBoardOnClick) {
|
||||||
dispatch(autoAddBoardIdChanged('none'));
|
dispatch(autoAddBoardIdChanged('none'));
|
||||||
}
|
}
|
||||||
|
@ -32,7 +32,7 @@ const SystemBoardButton = ({ board_id }: Props) => {
|
|||||||
const boardName = useBoardName(board_id);
|
const boardName = useBoardName(board_id);
|
||||||
|
|
||||||
const handleClick = useCallback(() => {
|
const handleClick = useCallback(() => {
|
||||||
dispatch(boardIdSelected(board_id));
|
dispatch(boardIdSelected({ boardId: board_id }));
|
||||||
}, [board_id, dispatch]);
|
}, [board_id, dispatch]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -1,8 +1,15 @@
|
|||||||
import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types';
|
import {
|
||||||
|
ControlNetMetadataItem,
|
||||||
|
CoreMetadata,
|
||||||
|
LoRAMetadataItem,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useMemo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { isValidLoRAModel } from '../../../parameters/types/parameterSchemas';
|
import {
|
||||||
|
isValidControlNetModel,
|
||||||
|
isValidLoRAModel,
|
||||||
|
} from '../../../parameters/types/parameterSchemas';
|
||||||
import ImageMetadataItem from './ImageMetadataItem';
|
import ImageMetadataItem from './ImageMetadataItem';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
@ -26,6 +33,7 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
recallHeight,
|
recallHeight,
|
||||||
recallStrength,
|
recallStrength,
|
||||||
recallLoRA,
|
recallLoRA,
|
||||||
|
recallControlNet,
|
||||||
} = useRecallParameters();
|
} = useRecallParameters();
|
||||||
|
|
||||||
const handleRecallPositivePrompt = useCallback(() => {
|
const handleRecallPositivePrompt = useCallback(() => {
|
||||||
@ -75,6 +83,21 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
[recallLoRA]
|
[recallLoRA]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const handleRecallControlNet = useCallback(
|
||||||
|
(controlnet: ControlNetMetadataItem) => {
|
||||||
|
recallControlNet(controlnet);
|
||||||
|
},
|
||||||
|
[recallControlNet]
|
||||||
|
);
|
||||||
|
|
||||||
|
const validControlNets: ControlNetMetadataItem[] = useMemo(() => {
|
||||||
|
return metadata?.controlnets
|
||||||
|
? metadata.controlnets.filter((controlnet) =>
|
||||||
|
isValidControlNetModel(controlnet.control_model)
|
||||||
|
)
|
||||||
|
: [];
|
||||||
|
}, [metadata?.controlnets]);
|
||||||
|
|
||||||
if (!metadata || Object.keys(metadata).length === 0) {
|
if (!metadata || Object.keys(metadata).length === 0) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
@ -180,6 +203,14 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
})}
|
})}
|
||||||
|
{validControlNets.map((controlnet, index) => (
|
||||||
|
<ImageMetadataItem
|
||||||
|
key={index}
|
||||||
|
label="ControlNet"
|
||||||
|
value={`${controlnet.control_model?.model_name} - ${controlnet.control_weight}`}
|
||||||
|
onClick={() => handleRecallControlNet(controlnet)}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -35,8 +35,11 @@ export const gallerySlice = createSlice({
|
|||||||
autoAssignBoardOnClickChanged: (state, action: PayloadAction<boolean>) => {
|
autoAssignBoardOnClickChanged: (state, action: PayloadAction<boolean>) => {
|
||||||
state.autoAssignBoardOnClick = action.payload;
|
state.autoAssignBoardOnClick = action.payload;
|
||||||
},
|
},
|
||||||
boardIdSelected: (state, action: PayloadAction<BoardId>) => {
|
boardIdSelected: (
|
||||||
state.selectedBoardId = action.payload;
|
state,
|
||||||
|
action: PayloadAction<{ boardId: BoardId; selectedImageName?: string }>
|
||||||
|
) => {
|
||||||
|
state.selectedBoardId = action.payload.boardId;
|
||||||
state.galleryView = 'images';
|
state.galleryView = 'images';
|
||||||
},
|
},
|
||||||
autoAddBoardIdChanged: (state, action: PayloadAction<BoardId>) => {
|
autoAddBoardIdChanged: (state, action: PayloadAction<BoardId>) => {
|
||||||
|
@ -12,6 +12,7 @@ import {
|
|||||||
OnConnect,
|
OnConnect,
|
||||||
OnConnectEnd,
|
OnConnectEnd,
|
||||||
OnConnectStart,
|
OnConnectStart,
|
||||||
|
OnEdgeUpdateFunc,
|
||||||
OnEdgesChange,
|
OnEdgesChange,
|
||||||
OnEdgesDelete,
|
OnEdgesDelete,
|
||||||
OnInit,
|
OnInit,
|
||||||
@ -21,6 +22,7 @@ import {
|
|||||||
OnSelectionChangeFunc,
|
OnSelectionChangeFunc,
|
||||||
ProOptions,
|
ProOptions,
|
||||||
ReactFlow,
|
ReactFlow,
|
||||||
|
ReactFlowProps,
|
||||||
XYPosition,
|
XYPosition,
|
||||||
} from 'reactflow';
|
} from 'reactflow';
|
||||||
import { useIsValidConnection } from '../../hooks/useIsValidConnection';
|
import { useIsValidConnection } from '../../hooks/useIsValidConnection';
|
||||||
@ -28,6 +30,8 @@ import {
|
|||||||
connectionEnded,
|
connectionEnded,
|
||||||
connectionMade,
|
connectionMade,
|
||||||
connectionStarted,
|
connectionStarted,
|
||||||
|
edgeAdded,
|
||||||
|
edgeDeleted,
|
||||||
edgesChanged,
|
edgesChanged,
|
||||||
edgesDeleted,
|
edgesDeleted,
|
||||||
nodesChanged,
|
nodesChanged,
|
||||||
@ -167,6 +171,63 @@ export const Flow = () => {
|
|||||||
}
|
}
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
// #region Updatable Edges
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Adapted from https://reactflow.dev/docs/examples/edges/updatable-edge/
|
||||||
|
* and https://reactflow.dev/docs/examples/edges/delete-edge-on-drop/
|
||||||
|
*
|
||||||
|
* - Edges can be dragged from one handle to another.
|
||||||
|
* - If the user drags the edge away from the node and drops it, delete the edge.
|
||||||
|
* - Do not delete the edge if the cursor didn't move (resolves annoying behaviour
|
||||||
|
* where the edge is deleted if you click it accidentally).
|
||||||
|
*/
|
||||||
|
|
||||||
|
// We have a ref for cursor position, but it is the *projected* cursor position.
|
||||||
|
// Easiest to just keep track of the last mouse event for this particular feature
|
||||||
|
const edgeUpdateMouseEvent = useRef<MouseEvent>();
|
||||||
|
|
||||||
|
const onEdgeUpdateStart: NonNullable<ReactFlowProps['onEdgeUpdateStart']> =
|
||||||
|
useCallback(
|
||||||
|
(e, edge, _handleType) => {
|
||||||
|
// update mouse event
|
||||||
|
edgeUpdateMouseEvent.current = e;
|
||||||
|
// always delete the edge when starting an updated
|
||||||
|
dispatch(edgeDeleted(edge.id));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const onEdgeUpdate: OnEdgeUpdateFunc = useCallback(
|
||||||
|
(_oldEdge, newConnection) => {
|
||||||
|
// instead of updating the edge (we deleted it earlier), we instead create
|
||||||
|
// a new one.
|
||||||
|
dispatch(connectionMade(newConnection));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const onEdgeUpdateEnd: NonNullable<ReactFlowProps['onEdgeUpdateEnd']> =
|
||||||
|
useCallback(
|
||||||
|
(e, edge, _handleType) => {
|
||||||
|
// Handle the case where user begins a drag but didn't move the cursor -
|
||||||
|
// bc we deleted the edge, we need to add it back
|
||||||
|
if (
|
||||||
|
// ignore touch events
|
||||||
|
!('touches' in e) &&
|
||||||
|
edgeUpdateMouseEvent.current?.clientX === e.clientX &&
|
||||||
|
edgeUpdateMouseEvent.current?.clientY === e.clientY
|
||||||
|
) {
|
||||||
|
dispatch(edgeAdded(edge));
|
||||||
|
}
|
||||||
|
// reset mouse event
|
||||||
|
edgeUpdateMouseEvent.current = undefined;
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
// #endregion
|
||||||
|
|
||||||
useHotkeys(['Ctrl+c', 'Meta+c'], (e) => {
|
useHotkeys(['Ctrl+c', 'Meta+c'], (e) => {
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
dispatch(selectionCopied());
|
dispatch(selectionCopied());
|
||||||
@ -196,6 +257,9 @@ export const Flow = () => {
|
|||||||
onNodesChange={onNodesChange}
|
onNodesChange={onNodesChange}
|
||||||
onEdgesChange={onEdgesChange}
|
onEdgesChange={onEdgesChange}
|
||||||
onEdgesDelete={onEdgesDelete}
|
onEdgesDelete={onEdgesDelete}
|
||||||
|
onEdgeUpdate={onEdgeUpdate}
|
||||||
|
onEdgeUpdateStart={onEdgeUpdateStart}
|
||||||
|
onEdgeUpdateEnd={onEdgeUpdateEnd}
|
||||||
onNodesDelete={onNodesDelete}
|
onNodesDelete={onNodesDelete}
|
||||||
onConnectStart={onConnectStart}
|
onConnectStart={onConnectStart}
|
||||||
onConnect={onConnect}
|
onConnect={onConnect}
|
||||||
|
@ -8,6 +8,7 @@ import InvocationNodeFooter from './InvocationNodeFooter';
|
|||||||
import InvocationNodeHeader from './InvocationNodeHeader';
|
import InvocationNodeHeader from './InvocationNodeHeader';
|
||||||
import InputField from './fields/InputField';
|
import InputField from './fields/InputField';
|
||||||
import OutputField from './fields/OutputField';
|
import OutputField from './fields/OutputField';
|
||||||
|
import { useWithFooter } from 'features/nodes/hooks/useWithFooter';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
@ -20,6 +21,7 @@ type Props = {
|
|||||||
const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
|
const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
|
||||||
const inputConnectionFieldNames = useConnectionInputFieldNames(nodeId);
|
const inputConnectionFieldNames = useConnectionInputFieldNames(nodeId);
|
||||||
const inputAnyOrDirectFieldNames = useAnyOrDirectInputFieldNames(nodeId);
|
const inputAnyOrDirectFieldNames = useAnyOrDirectInputFieldNames(nodeId);
|
||||||
|
const withFooter = useWithFooter(nodeId);
|
||||||
const outputFieldNames = useOutputFieldNames(nodeId);
|
const outputFieldNames = useOutputFieldNames(nodeId);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -41,7 +43,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
|
|||||||
h: 'full',
|
h: 'full',
|
||||||
py: 2,
|
py: 2,
|
||||||
gap: 1,
|
gap: 1,
|
||||||
borderBottomRadius: 0,
|
borderBottomRadius: withFooter ? 0 : 'base',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Flex sx={{ flexDir: 'column', px: 2, w: 'full', h: 'full' }}>
|
<Flex sx={{ flexDir: 'column', px: 2, w: 'full', h: 'full' }}>
|
||||||
@ -74,7 +76,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
|
|||||||
))}
|
))}
|
||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
<InvocationNodeFooter nodeId={nodeId} />
|
{withFooter && <InvocationNodeFooter nodeId={nodeId} />}
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
</NodeWrapper>
|
</NodeWrapper>
|
||||||
|
@ -5,6 +5,7 @@ import EmbedWorkflowCheckbox from './EmbedWorkflowCheckbox';
|
|||||||
import SaveToGalleryCheckbox from './SaveToGalleryCheckbox';
|
import SaveToGalleryCheckbox from './SaveToGalleryCheckbox';
|
||||||
import UseCacheCheckbox from './UseCacheCheckbox';
|
import UseCacheCheckbox from './UseCacheCheckbox';
|
||||||
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
|
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
|
||||||
|
import { useFeatureStatus } from '../../../../../system/hooks/useFeatureStatus';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
@ -12,6 +13,7 @@ type Props = {
|
|||||||
|
|
||||||
const InvocationNodeFooter = ({ nodeId }: Props) => {
|
const InvocationNodeFooter = ({ nodeId }: Props) => {
|
||||||
const hasImageOutput = useHasImageOutput(nodeId);
|
const hasImageOutput = useHasImageOutput(nodeId);
|
||||||
|
const isCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
className={DRAG_HANDLE_CLASSNAME}
|
className={DRAG_HANDLE_CLASSNAME}
|
||||||
@ -25,7 +27,7 @@ const InvocationNodeFooter = ({ nodeId }: Props) => {
|
|||||||
justifyContent: 'space-between',
|
justifyContent: 'space-between',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<UseCacheCheckbox nodeId={nodeId} />
|
{isCacheEnabled && <UseCacheCheckbox nodeId={nodeId} />}
|
||||||
{hasImageOutput && <EmbedWorkflowCheckbox nodeId={nodeId} />}
|
{hasImageOutput && <EmbedWorkflowCheckbox nodeId={nodeId} />}
|
||||||
{hasImageOutput && <SaveToGalleryCheckbox nodeId={nodeId} />}
|
{hasImageOutput && <SaveToGalleryCheckbox nodeId={nodeId} />}
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -53,13 +53,12 @@ export const useIsValidConnection = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
edges
|
edges.find((edge) => {
|
||||||
.filter((edge) => {
|
edge.target === target &&
|
||||||
return edge.target === target && edge.targetHandle === targetHandle;
|
edge.targetHandle === targetHandle &&
|
||||||
})
|
edge.source === source &&
|
||||||
.find((edge) => {
|
edge.sourceHandle === sourceHandle;
|
||||||
edge.source === source && edge.sourceHandle === sourceHandle;
|
})
|
||||||
})
|
|
||||||
) {
|
) {
|
||||||
// We already have a connection from this source to this target
|
// We already have a connection from this source to this target
|
||||||
return false;
|
return false;
|
||||||
|
@ -1,31 +1,14 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
import { stateSelector } from 'app/store/store';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|
||||||
import { some } from 'lodash-es';
|
|
||||||
import { useMemo } from 'react';
|
import { useMemo } from 'react';
|
||||||
import { FOOTER_FIELDS } from '../types/constants';
|
import { useHasImageOutput } from './useHasImageOutput';
|
||||||
import { isInvocationNode } from '../types/types';
|
|
||||||
|
|
||||||
export const useHasImageOutputs = (nodeId: string) => {
|
export const useWithFooter = (nodeId: string) => {
|
||||||
const selector = useMemo(
|
const hasImageOutput = useHasImageOutput(nodeId);
|
||||||
() =>
|
const isCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;
|
||||||
createSelector(
|
|
||||||
stateSelector,
|
const withFooter = useMemo(
|
||||||
({ nodes }) => {
|
() => hasImageOutput || isCacheEnabled,
|
||||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
[hasImageOutput, isCacheEnabled]
|
||||||
if (!isInvocationNode(node)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return some(node.data.outputs, (output) =>
|
|
||||||
FOOTER_FIELDS.includes(output.type)
|
|
||||||
);
|
|
||||||
},
|
|
||||||
defaultSelectorOptions
|
|
||||||
),
|
|
||||||
[nodeId]
|
|
||||||
);
|
);
|
||||||
|
|
||||||
const withFooter = useAppSelector(selector);
|
|
||||||
return withFooter;
|
return withFooter;
|
||||||
};
|
};
|
||||||
|
@ -15,6 +15,7 @@ import {
|
|||||||
NodeChange,
|
NodeChange,
|
||||||
OnConnectStartParams,
|
OnConnectStartParams,
|
||||||
SelectionMode,
|
SelectionMode,
|
||||||
|
updateEdge,
|
||||||
Viewport,
|
Viewport,
|
||||||
XYPosition,
|
XYPosition,
|
||||||
} from 'reactflow';
|
} from 'reactflow';
|
||||||
@ -182,6 +183,16 @@ const nodesSlice = createSlice({
|
|||||||
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
|
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
|
||||||
state.edges = applyEdgeChanges(action.payload, state.edges);
|
state.edges = applyEdgeChanges(action.payload, state.edges);
|
||||||
},
|
},
|
||||||
|
edgeAdded: (state, action: PayloadAction<Edge>) => {
|
||||||
|
state.edges = addEdge(action.payload, state.edges);
|
||||||
|
},
|
||||||
|
edgeUpdated: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<{ oldEdge: Edge; newConnection: Connection }>
|
||||||
|
) => {
|
||||||
|
const { oldEdge, newConnection } = action.payload;
|
||||||
|
state.edges = updateEdge(oldEdge, newConnection, state.edges);
|
||||||
|
},
|
||||||
connectionStarted: (state, action: PayloadAction<OnConnectStartParams>) => {
|
connectionStarted: (state, action: PayloadAction<OnConnectStartParams>) => {
|
||||||
state.connectionStartParams = action.payload;
|
state.connectionStartParams = action.payload;
|
||||||
const { nodeId, handleId, handleType } = action.payload;
|
const { nodeId, handleId, handleType } = action.payload;
|
||||||
@ -366,6 +377,7 @@ const nodesSlice = createSlice({
|
|||||||
target: edge.target,
|
target: edge.target,
|
||||||
type: 'collapsed',
|
type: 'collapsed',
|
||||||
data: { count: 1 },
|
data: { count: 1 },
|
||||||
|
updatable: false,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -388,6 +400,7 @@ const nodesSlice = createSlice({
|
|||||||
target: edge.target,
|
target: edge.target,
|
||||||
type: 'collapsed',
|
type: 'collapsed',
|
||||||
data: { count: 1 },
|
data: { count: 1 },
|
||||||
|
updatable: false,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -400,6 +413,9 @@ const nodesSlice = createSlice({
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
edgeDeleted: (state, action: PayloadAction<string>) => {
|
||||||
|
state.edges = state.edges.filter((e) => e.id !== action.payload);
|
||||||
|
},
|
||||||
edgesDeleted: (state, action: PayloadAction<Edge[]>) => {
|
edgesDeleted: (state, action: PayloadAction<Edge[]>) => {
|
||||||
const edges = action.payload;
|
const edges = action.payload;
|
||||||
const collapsedEdges = edges.filter((e) => e.type === 'collapsed');
|
const collapsedEdges = edges.filter((e) => e.type === 'collapsed');
|
||||||
@ -890,69 +906,72 @@ const nodesSlice = createSlice({
|
|||||||
});
|
});
|
||||||
|
|
||||||
export const {
|
export const {
|
||||||
nodesChanged,
|
addNodePopoverClosed,
|
||||||
edgesChanged,
|
addNodePopoverOpened,
|
||||||
nodeAdded,
|
addNodePopoverToggled,
|
||||||
nodesDeleted,
|
connectionEnded,
|
||||||
connectionMade,
|
connectionMade,
|
||||||
connectionStarted,
|
connectionStarted,
|
||||||
connectionEnded,
|
edgeDeleted,
|
||||||
shouldShowFieldTypeLegendChanged,
|
edgesChanged,
|
||||||
shouldShowMinimapPanelChanged,
|
edgesDeleted,
|
||||||
nodeTemplatesBuilt,
|
edgeUpdated,
|
||||||
nodeEditorReset,
|
|
||||||
imageCollectionFieldValueChanged,
|
|
||||||
fieldStringValueChanged,
|
|
||||||
fieldNumberValueChanged,
|
|
||||||
fieldBoardValueChanged,
|
fieldBoardValueChanged,
|
||||||
fieldBooleanValueChanged,
|
fieldBooleanValueChanged,
|
||||||
fieldImageValueChanged,
|
|
||||||
fieldColorValueChanged,
|
fieldColorValueChanged,
|
||||||
fieldMainModelValueChanged,
|
|
||||||
fieldVaeModelValueChanged,
|
|
||||||
fieldLoRAModelValueChanged,
|
|
||||||
fieldEnumModelValueChanged,
|
|
||||||
fieldControlNetModelValueChanged,
|
fieldControlNetModelValueChanged,
|
||||||
|
fieldEnumModelValueChanged,
|
||||||
|
fieldImageValueChanged,
|
||||||
fieldIPAdapterModelValueChanged,
|
fieldIPAdapterModelValueChanged,
|
||||||
|
fieldLabelChanged,
|
||||||
|
fieldLoRAModelValueChanged,
|
||||||
|
fieldMainModelValueChanged,
|
||||||
|
fieldNumberValueChanged,
|
||||||
fieldRefinerModelValueChanged,
|
fieldRefinerModelValueChanged,
|
||||||
fieldSchedulerValueChanged,
|
fieldSchedulerValueChanged,
|
||||||
|
fieldStringValueChanged,
|
||||||
|
fieldVaeModelValueChanged,
|
||||||
|
imageCollectionFieldValueChanged,
|
||||||
|
mouseOverFieldChanged,
|
||||||
|
mouseOverNodeChanged,
|
||||||
|
nodeAdded,
|
||||||
|
nodeEditorReset,
|
||||||
|
nodeEmbedWorkflowChanged,
|
||||||
|
nodeExclusivelySelected,
|
||||||
|
nodeIsIntermediateChanged,
|
||||||
nodeIsOpenChanged,
|
nodeIsOpenChanged,
|
||||||
nodeLabelChanged,
|
nodeLabelChanged,
|
||||||
nodeNotesChanged,
|
nodeNotesChanged,
|
||||||
edgesDeleted,
|
|
||||||
shouldValidateGraphChanged,
|
|
||||||
shouldAnimateEdgesChanged,
|
|
||||||
nodeOpacityChanged,
|
nodeOpacityChanged,
|
||||||
shouldSnapToGridChanged,
|
nodesChanged,
|
||||||
shouldColorEdgesChanged,
|
nodesDeleted,
|
||||||
selectedNodesChanged,
|
nodeTemplatesBuilt,
|
||||||
selectedEdgesChanged,
|
nodeUseCacheChanged,
|
||||||
workflowNameChanged,
|
|
||||||
workflowDescriptionChanged,
|
|
||||||
workflowTagsChanged,
|
|
||||||
workflowAuthorChanged,
|
|
||||||
workflowNotesChanged,
|
|
||||||
workflowVersionChanged,
|
|
||||||
workflowContactChanged,
|
|
||||||
workflowLoaded,
|
|
||||||
notesNodeValueChanged,
|
notesNodeValueChanged,
|
||||||
|
selectedAll,
|
||||||
|
selectedEdgesChanged,
|
||||||
|
selectedNodesChanged,
|
||||||
|
selectionCopied,
|
||||||
|
selectionModeChanged,
|
||||||
|
selectionPasted,
|
||||||
|
shouldAnimateEdgesChanged,
|
||||||
|
shouldColorEdgesChanged,
|
||||||
|
shouldShowFieldTypeLegendChanged,
|
||||||
|
shouldShowMinimapPanelChanged,
|
||||||
|
shouldSnapToGridChanged,
|
||||||
|
shouldValidateGraphChanged,
|
||||||
|
viewportChanged,
|
||||||
|
workflowAuthorChanged,
|
||||||
|
workflowContactChanged,
|
||||||
|
workflowDescriptionChanged,
|
||||||
workflowExposedFieldAdded,
|
workflowExposedFieldAdded,
|
||||||
workflowExposedFieldRemoved,
|
workflowExposedFieldRemoved,
|
||||||
fieldLabelChanged,
|
workflowLoaded,
|
||||||
viewportChanged,
|
workflowNameChanged,
|
||||||
mouseOverFieldChanged,
|
workflowNotesChanged,
|
||||||
selectionCopied,
|
workflowTagsChanged,
|
||||||
selectionPasted,
|
workflowVersionChanged,
|
||||||
selectedAll,
|
edgeAdded,
|
||||||
addNodePopoverOpened,
|
|
||||||
addNodePopoverClosed,
|
|
||||||
addNodePopoverToggled,
|
|
||||||
selectionModeChanged,
|
|
||||||
nodeEmbedWorkflowChanged,
|
|
||||||
nodeIsIntermediateChanged,
|
|
||||||
mouseOverNodeChanged,
|
|
||||||
nodeExclusivelySelected,
|
|
||||||
nodeUseCacheChanged,
|
|
||||||
} = nodesSlice.actions;
|
} = nodesSlice.actions;
|
||||||
|
|
||||||
export default nodesSlice.reducer;
|
export default nodesSlice.reducer;
|
||||||
|
@ -55,9 +55,29 @@ export const makeConnectionErrorSelector = (
|
|||||||
return i18n.t('nodes.cannotConnectInputToInput');
|
return i18n.t('nodes.cannotConnectInputToInput');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// we have to figure out which is the target and which is the source
|
||||||
|
const target = handleType === 'target' ? nodeId : connectionNodeId;
|
||||||
|
const targetHandle =
|
||||||
|
handleType === 'target' ? fieldName : connectionFieldName;
|
||||||
|
const source = handleType === 'source' ? nodeId : connectionNodeId;
|
||||||
|
const sourceHandle =
|
||||||
|
handleType === 'source' ? fieldName : connectionFieldName;
|
||||||
|
|
||||||
if (
|
if (
|
||||||
edges.find((edge) => {
|
edges.find((edge) => {
|
||||||
return edge.target === nodeId && edge.targetHandle === fieldName;
|
edge.target === target &&
|
||||||
|
edge.targetHandle === targetHandle &&
|
||||||
|
edge.source === source &&
|
||||||
|
edge.sourceHandle === sourceHandle;
|
||||||
|
})
|
||||||
|
) {
|
||||||
|
// We already have a connection from this source to this target
|
||||||
|
return i18n.t('nodes.cannotDuplicateConnection');
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
edges.find((edge) => {
|
||||||
|
return edge.target === target && edge.targetHandle === targetHandle;
|
||||||
}) &&
|
}) &&
|
||||||
// except CollectionItem inputs can have multiples
|
// except CollectionItem inputs can have multiples
|
||||||
targetType !== 'CollectionItem'
|
targetType !== 'CollectionItem'
|
||||||
|
@ -1141,6 +1141,10 @@ const zLoRAMetadataItem = z.object({
|
|||||||
|
|
||||||
export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
|
export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
|
||||||
|
|
||||||
|
const zControlNetMetadataItem = zControlField.deepPartial();
|
||||||
|
|
||||||
|
export type ControlNetMetadataItem = z.infer<typeof zControlNetMetadataItem>;
|
||||||
|
|
||||||
export const zCoreMetadata = z
|
export const zCoreMetadata = z
|
||||||
.object({
|
.object({
|
||||||
app_version: z.string().nullish().catch(null),
|
app_version: z.string().nullish().catch(null),
|
||||||
@ -1222,6 +1226,7 @@ export const zInvocationNodeData = z.object({
|
|||||||
notes: z.string(),
|
notes: z.string(),
|
||||||
embedWorkflow: z.boolean(),
|
embedWorkflow: z.boolean(),
|
||||||
isIntermediate: z.boolean(),
|
isIntermediate: z.boolean(),
|
||||||
|
useCache: z.boolean().optional(),
|
||||||
version: zSemVer.optional(),
|
version: zSemVer.optional(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -32,7 +32,8 @@ export const addSDXLRefinerToGraph = (
|
|||||||
graph: NonNullableGraph,
|
graph: NonNullableGraph,
|
||||||
baseNodeId: string,
|
baseNodeId: string,
|
||||||
modelLoaderNodeId?: string,
|
modelLoaderNodeId?: string,
|
||||||
canvasInitImage?: ImageDTO
|
canvasInitImage?: ImageDTO,
|
||||||
|
canvasMaskImage?: ImageDTO
|
||||||
): void => {
|
): void => {
|
||||||
const {
|
const {
|
||||||
refinerModel,
|
refinerModel,
|
||||||
@ -257,8 +258,30 @@ export const addSDXLRefinerToGraph = (
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
graph.edges.push(
|
if (graph.id === SDXL_CANVAS_INPAINT_GRAPH) {
|
||||||
{
|
if (isUsingScaledDimensions) {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: MASK_RESIZE_UP,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||||
|
field: 'mask',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
|
||||||
|
...(graph.nodes[
|
||||||
|
SDXL_REFINER_INPAINT_CREATE_MASK
|
||||||
|
] as CreateDenoiseMaskInvocation),
|
||||||
|
mask: canvasMaskImage,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (graph.id === SDXL_CANVAS_OUTPAINT_GRAPH) {
|
||||||
|
graph.edges.push({
|
||||||
source: {
|
source: {
|
||||||
node_id: isUsingScaledDimensions ? MASK_RESIZE_UP : MASK_COMBINE,
|
node_id: isUsingScaledDimensions ? MASK_RESIZE_UP : MASK_COMBINE,
|
||||||
field: 'image',
|
field: 'image',
|
||||||
@ -267,18 +290,19 @@ export const addSDXLRefinerToGraph = (
|
|||||||
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||||
field: 'mask',
|
field: 'mask',
|
||||||
},
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||||
|
field: 'denoise_mask',
|
||||||
},
|
},
|
||||||
{
|
destination: {
|
||||||
source: {
|
node_id: SDXL_REFINER_DENOISE_LATENTS,
|
||||||
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
field: 'denoise_mask',
|
||||||
field: 'denoise_mask',
|
},
|
||||||
},
|
});
|
||||||
destination: {
|
|
||||||
node_id: SDXL_REFINER_DENOISE_LATENTS,
|
|
||||||
field: 'denoise_mask',
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
@ -663,7 +663,8 @@ export const buildCanvasSDXLInpaintGraph = (
|
|||||||
graph,
|
graph,
|
||||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||||
modelLoaderNodeId,
|
modelLoaderNodeId,
|
||||||
canvasInitImage
|
canvasInitImage,
|
||||||
|
canvasMaskImage
|
||||||
);
|
);
|
||||||
if (seamlessXAxis || seamlessYAxis) {
|
if (seamlessXAxis || seamlessYAxis) {
|
||||||
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||||
|
@ -2,7 +2,11 @@ import { createSelector } from '@reduxjs/toolkit';
|
|||||||
import { useAppToaster } from 'app/components/Toaster';
|
import { useAppToaster } from 'app/components/Toaster';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types';
|
import {
|
||||||
|
CoreMetadata,
|
||||||
|
LoRAMetadataItem,
|
||||||
|
ControlNetMetadataItem,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
import {
|
import {
|
||||||
refinerModelChanged,
|
refinerModelChanged,
|
||||||
setNegativeStylePromptSDXL,
|
setNegativeStylePromptSDXL,
|
||||||
@ -18,9 +22,18 @@ import { useCallback } from 'react';
|
|||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
import {
|
import {
|
||||||
|
controlNetModelsAdapter,
|
||||||
loraModelsAdapter,
|
loraModelsAdapter,
|
||||||
|
useGetControlNetModelsQuery,
|
||||||
useGetLoRAModelsQuery,
|
useGetLoRAModelsQuery,
|
||||||
} from '../../../services/api/endpoints/models';
|
} from '../../../services/api/endpoints/models';
|
||||||
|
import {
|
||||||
|
ControlNetConfig,
|
||||||
|
controlNetEnabled,
|
||||||
|
controlNetRecalled,
|
||||||
|
controlNetReset,
|
||||||
|
initialControlNet,
|
||||||
|
} from '../../controlNet/store/controlNetSlice';
|
||||||
import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice';
|
import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice';
|
||||||
import { initialImageSelected, modelSelected } from '../store/actions';
|
import { initialImageSelected, modelSelected } from '../store/actions';
|
||||||
import {
|
import {
|
||||||
@ -38,6 +51,7 @@ import {
|
|||||||
isValidCfgScale,
|
isValidCfgScale,
|
||||||
isValidHeight,
|
isValidHeight,
|
||||||
isValidLoRAModel,
|
isValidLoRAModel,
|
||||||
|
isValidControlNetModel,
|
||||||
isValidMainModel,
|
isValidMainModel,
|
||||||
isValidNegativePrompt,
|
isValidNegativePrompt,
|
||||||
isValidPositivePrompt,
|
isValidPositivePrompt,
|
||||||
@ -53,6 +67,11 @@ import {
|
|||||||
isValidStrength,
|
isValidStrength,
|
||||||
isValidWidth,
|
isValidWidth,
|
||||||
} from '../types/parameterSchemas';
|
} from '../types/parameterSchemas';
|
||||||
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
import {
|
||||||
|
CONTROLNET_PROCESSORS,
|
||||||
|
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
|
||||||
|
} from 'features/controlNet/store/constants';
|
||||||
|
|
||||||
const selector = createSelector(stateSelector, ({ generation }) => {
|
const selector = createSelector(stateSelector, ({ generation }) => {
|
||||||
const { model } = generation;
|
const { model } = generation;
|
||||||
@ -390,6 +409,121 @@ export const useRecallParameters = () => {
|
|||||||
[prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
|
[prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Recall ControlNet with toast
|
||||||
|
*/
|
||||||
|
|
||||||
|
const { controlnets } = useGetControlNetModelsQuery(undefined, {
|
||||||
|
selectFromResult: (result) => ({
|
||||||
|
controlnets: result.data
|
||||||
|
? controlNetModelsAdapter.getSelectors().selectAll(result.data)
|
||||||
|
: [],
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
const prepareControlNetMetadataItem = useCallback(
|
||||||
|
(controlnetMetadataItem: ControlNetMetadataItem) => {
|
||||||
|
if (!isValidControlNetModel(controlnetMetadataItem.control_model)) {
|
||||||
|
return { controlnet: null, error: 'Invalid ControlNet model' };
|
||||||
|
}
|
||||||
|
|
||||||
|
const {
|
||||||
|
image,
|
||||||
|
control_model,
|
||||||
|
control_weight,
|
||||||
|
begin_step_percent,
|
||||||
|
end_step_percent,
|
||||||
|
control_mode,
|
||||||
|
resize_mode,
|
||||||
|
} = controlnetMetadataItem;
|
||||||
|
|
||||||
|
const matchingControlNetModel = controlnets.find(
|
||||||
|
(c) =>
|
||||||
|
c.base_model === control_model.base_model &&
|
||||||
|
c.model_name === control_model.model_name
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!matchingControlNetModel) {
|
||||||
|
return { controlnet: null, error: 'ControlNet model is not installed' };
|
||||||
|
}
|
||||||
|
|
||||||
|
const isCompatibleBaseModel =
|
||||||
|
matchingControlNetModel?.base_model === model?.base_model;
|
||||||
|
|
||||||
|
if (!isCompatibleBaseModel) {
|
||||||
|
return {
|
||||||
|
controlnet: null,
|
||||||
|
error: 'ControlNet incompatible with currently-selected model',
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const controlNetId = uuidv4();
|
||||||
|
|
||||||
|
let processorType = initialControlNet.processorType;
|
||||||
|
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
||||||
|
if (matchingControlNetModel.model_name.includes(modelSubstring)) {
|
||||||
|
processorType =
|
||||||
|
CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring] ||
|
||||||
|
initialControlNet.processorType;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const processorNode = CONTROLNET_PROCESSORS[processorType].default;
|
||||||
|
|
||||||
|
const controlnet: ControlNetConfig = {
|
||||||
|
isEnabled: true,
|
||||||
|
model: matchingControlNetModel,
|
||||||
|
weight:
|
||||||
|
typeof control_weight === 'number'
|
||||||
|
? control_weight
|
||||||
|
: initialControlNet.weight,
|
||||||
|
beginStepPct: begin_step_percent || initialControlNet.beginStepPct,
|
||||||
|
endStepPct: end_step_percent || initialControlNet.endStepPct,
|
||||||
|
controlMode: control_mode || initialControlNet.controlMode,
|
||||||
|
resizeMode: resize_mode || initialControlNet.resizeMode,
|
||||||
|
controlImage: image?.image_name || null,
|
||||||
|
processedControlImage: image?.image_name || null,
|
||||||
|
processorType,
|
||||||
|
processorNode:
|
||||||
|
processorNode.type !== 'none'
|
||||||
|
? processorNode
|
||||||
|
: initialControlNet.processorNode,
|
||||||
|
shouldAutoConfig: true,
|
||||||
|
controlNetId,
|
||||||
|
};
|
||||||
|
|
||||||
|
return { controlnet, error: null };
|
||||||
|
},
|
||||||
|
[controlnets, model?.base_model]
|
||||||
|
);
|
||||||
|
|
||||||
|
const recallControlNet = useCallback(
|
||||||
|
(controlnetMetadataItem: ControlNetMetadataItem) => {
|
||||||
|
const result = prepareControlNetMetadataItem(controlnetMetadataItem);
|
||||||
|
|
||||||
|
if (!result.controlnet) {
|
||||||
|
parameterNotSetToast(result.error);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
controlNetRecalled({
|
||||||
|
...result.controlnet,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
dispatch(controlNetEnabled());
|
||||||
|
|
||||||
|
parameterSetToast();
|
||||||
|
},
|
||||||
|
[
|
||||||
|
prepareControlNetMetadataItem,
|
||||||
|
dispatch,
|
||||||
|
parameterSetToast,
|
||||||
|
parameterNotSetToast,
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Sets image as initial image with toast
|
* Sets image as initial image with toast
|
||||||
*/
|
*/
|
||||||
@ -428,6 +562,7 @@ export const useRecallParameters = () => {
|
|||||||
refiner_negative_aesthetic_score,
|
refiner_negative_aesthetic_score,
|
||||||
refiner_start,
|
refiner_start,
|
||||||
loras,
|
loras,
|
||||||
|
controlnets,
|
||||||
} = metadata;
|
} = metadata;
|
||||||
|
|
||||||
if (isValidCfgScale(cfg_scale)) {
|
if (isValidCfgScale(cfg_scale)) {
|
||||||
@ -517,6 +652,15 @@ export const useRecallParameters = () => {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
dispatch(controlNetReset());
|
||||||
|
dispatch(controlNetEnabled());
|
||||||
|
controlnets?.forEach((controlnet) => {
|
||||||
|
const result = prepareControlNetMetadataItem(controlnet);
|
||||||
|
if (result.controlnet) {
|
||||||
|
dispatch(controlNetRecalled(result.controlnet));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
allParameterSetToast();
|
allParameterSetToast();
|
||||||
},
|
},
|
||||||
[
|
[
|
||||||
@ -524,6 +668,7 @@ export const useRecallParameters = () => {
|
|||||||
allParameterSetToast,
|
allParameterSetToast,
|
||||||
dispatch,
|
dispatch,
|
||||||
prepareLoRAMetadataItem,
|
prepareLoRAMetadataItem,
|
||||||
|
prepareControlNetMetadataItem,
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -542,6 +687,7 @@ export const useRecallParameters = () => {
|
|||||||
recallHeight,
|
recallHeight,
|
||||||
recallStrength,
|
recallStrength,
|
||||||
recallLoRA,
|
recallLoRA,
|
||||||
|
recallControlNet,
|
||||||
recallAllParameters,
|
recallAllParameters,
|
||||||
sendToImageToImage,
|
sendToImageToImage,
|
||||||
};
|
};
|
||||||
|
@ -1,9 +1,7 @@
|
|||||||
import { ButtonGroup } from '@chakra-ui/react';
|
import { ButtonGroup } from '@chakra-ui/react';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo';
|
import { useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo';
|
||||||
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
|
|
||||||
import ClearInvocationCacheButton from './ClearInvocationCacheButton';
|
import ClearInvocationCacheButton from './ClearInvocationCacheButton';
|
||||||
import ToggleInvocationCacheButton from './ToggleInvocationCacheButton';
|
import ToggleInvocationCacheButton from './ToggleInvocationCacheButton';
|
||||||
import StatusStatGroup from './common/StatusStatGroup';
|
import StatusStatGroup from './common/StatusStatGroup';
|
||||||
@ -11,16 +9,7 @@ import StatusStatItem from './common/StatusStatItem';
|
|||||||
|
|
||||||
const InvocationCacheStatus = () => {
|
const InvocationCacheStatus = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const isConnected = useAppSelector((state) => state.system.isConnected);
|
const { data: cacheStatus } = useGetInvocationCacheStatusQuery(undefined);
|
||||||
const { data: queueStatus } = useGetQueueStatusQuery(undefined);
|
|
||||||
const { data: cacheStatus } = useGetInvocationCacheStatusQuery(undefined, {
|
|
||||||
pollingInterval:
|
|
||||||
isConnected &&
|
|
||||||
queueStatus?.processor.is_started &&
|
|
||||||
queueStatus?.queue.pending > 0
|
|
||||||
? 5000
|
|
||||||
: 0,
|
|
||||||
});
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<StatusStatGroup>
|
<StatusStatGroup>
|
||||||
|
@ -1,46 +1,37 @@
|
|||||||
import { Flex, Skeleton, Text } from '@chakra-ui/react';
|
import { Flex, Skeleton } from '@chakra-ui/react';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { COLUMN_WIDTHS } from './constants';
|
import { COLUMN_WIDTHS } from './constants';
|
||||||
|
|
||||||
const QueueItemSkeleton = () => {
|
const QueueItemSkeleton = () => {
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex alignItems="center" p={1.5} gap={4} minH={9} h="full" w="full">
|
||||||
alignItems="center"
|
|
||||||
gap={4}
|
|
||||||
p={1}
|
|
||||||
pb={2}
|
|
||||||
textTransform="uppercase"
|
|
||||||
fontWeight={700}
|
|
||||||
fontSize="xs"
|
|
||||||
letterSpacing={1}
|
|
||||||
>
|
|
||||||
<Flex
|
<Flex
|
||||||
w={COLUMN_WIDTHS.number}
|
w={COLUMN_WIDTHS.number}
|
||||||
justifyContent="flex-end"
|
justifyContent="flex-end"
|
||||||
alignItems="center"
|
alignItems="center"
|
||||||
>
|
>
|
||||||
<Skeleton width="20px">
|
<Skeleton w="full" h="full">
|
||||||
<Text variant="subtext"> </Text>
|
|
||||||
</Skeleton>
|
</Skeleton>
|
||||||
</Flex>
|
</Flex>
|
||||||
<Flex ps={0.5} w={COLUMN_WIDTHS.statusBadge} alignItems="center">
|
<Flex w={COLUMN_WIDTHS.statusBadge} alignItems="center">
|
||||||
<Skeleton width="100%">
|
<Skeleton w="full" h="full">
|
||||||
<Text variant="subtext"> </Text>
|
|
||||||
</Skeleton>
|
</Skeleton>
|
||||||
</Flex>
|
</Flex>
|
||||||
<Flex ps={0.5} w={COLUMN_WIDTHS.time} alignItems="center">
|
<Flex w={COLUMN_WIDTHS.time} alignItems="center">
|
||||||
<Skeleton width="100%">
|
<Skeleton w="full" h="full">
|
||||||
<Text variant="subtext"> </Text>
|
|
||||||
</Skeleton>
|
</Skeleton>
|
||||||
</Flex>
|
</Flex>
|
||||||
<Flex ps={0.5} w={COLUMN_WIDTHS.batchId} alignItems="center">
|
<Flex w={COLUMN_WIDTHS.batchId} alignItems="center">
|
||||||
<Skeleton width="100%">
|
<Skeleton w="full" h="full">
|
||||||
<Text variant="subtext"> </Text>
|
|
||||||
</Skeleton>
|
</Skeleton>
|
||||||
</Flex>
|
</Flex>
|
||||||
<Flex ps={0.5} w={COLUMN_WIDTHS.fieldValues} alignItems="center" flex="1">
|
<Flex w={COLUMN_WIDTHS.fieldValues} alignItems="center" flexGrow={1}>
|
||||||
<Skeleton width="100%">
|
<Skeleton w="full" h="full">
|
||||||
<Text variant="subtext"> </Text>
|
|
||||||
</Skeleton>
|
</Skeleton>
|
||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -3,6 +3,7 @@ import { createSelector } from '@reduxjs/toolkit';
|
|||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import { IAINoContentFallbackWithSpinner } from 'common/components/IAIImageFallback';
|
||||||
import {
|
import {
|
||||||
listCursorChanged,
|
listCursorChanged,
|
||||||
listPriorityChanged,
|
listPriorityChanged,
|
||||||
@ -23,7 +24,6 @@ import QueueItemComponent from './QueueItemComponent';
|
|||||||
import QueueListComponent from './QueueListComponent';
|
import QueueListComponent from './QueueListComponent';
|
||||||
import QueueListHeader from './QueueListHeader';
|
import QueueListHeader from './QueueListHeader';
|
||||||
import { ListContext } from './types';
|
import { ListContext } from './types';
|
||||||
import QueueItemSkeleton from './QueueItemSkeleton';
|
|
||||||
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
type TableVirtuosoScrollerRef = (ref: HTMLElement | Window | null) => any;
|
type TableVirtuosoScrollerRef = (ref: HTMLElement | Window | null) => any;
|
||||||
@ -126,54 +126,40 @@ const QueueList = () => {
|
|||||||
[openQueueItems, toggleQueueItem]
|
[openQueueItems, toggleQueueItem]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
if (isLoading) {
|
||||||
|
return <IAINoContentFallbackWithSpinner />;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!queueItems.length) {
|
||||||
|
return (
|
||||||
|
<Flex w="full" h="full" alignItems="center" justifyContent="center">
|
||||||
|
<Heading color="base.400" _dark={{ color: 'base.500' }}>
|
||||||
|
{t('queue.queueEmpty')}
|
||||||
|
</Heading>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex w="full" h="full" flexDir="column">
|
<Flex w="full" h="full" flexDir="column">
|
||||||
{isLoading ? (
|
<QueueListHeader />
|
||||||
<>
|
<Flex
|
||||||
<QueueListHeader />
|
ref={rootRef}
|
||||||
<QueueItemSkeleton />
|
w="full"
|
||||||
<QueueItemSkeleton />
|
h="full"
|
||||||
<QueueItemSkeleton />
|
alignItems="center"
|
||||||
<QueueItemSkeleton />
|
justifyContent="center"
|
||||||
<QueueItemSkeleton />
|
>
|
||||||
<QueueItemSkeleton />
|
<Virtuoso<SessionQueueItemDTO, ListContext>
|
||||||
<QueueItemSkeleton />
|
data={queueItems}
|
||||||
<QueueItemSkeleton />
|
endReached={handleLoadMore}
|
||||||
<QueueItemSkeleton />
|
scrollerRef={setScroller as TableVirtuosoScrollerRef}
|
||||||
<QueueItemSkeleton />
|
itemContent={itemContent}
|
||||||
</>
|
computeItemKey={computeItemKey}
|
||||||
) : (
|
components={components}
|
||||||
<>
|
context={context}
|
||||||
{queueItems.length ? (
|
/>
|
||||||
<>
|
</Flex>
|
||||||
<QueueListHeader />
|
|
||||||
<Flex
|
|
||||||
ref={rootRef}
|
|
||||||
w="full"
|
|
||||||
h="full"
|
|
||||||
alignItems="center"
|
|
||||||
justifyContent="center"
|
|
||||||
>
|
|
||||||
<Virtuoso<SessionQueueItemDTO, ListContext>
|
|
||||||
data={queueItems}
|
|
||||||
endReached={handleLoadMore}
|
|
||||||
scrollerRef={setScroller as TableVirtuosoScrollerRef}
|
|
||||||
itemContent={itemContent}
|
|
||||||
computeItemKey={computeItemKey}
|
|
||||||
components={components}
|
|
||||||
context={context}
|
|
||||||
/>
|
|
||||||
</Flex>
|
|
||||||
</>
|
|
||||||
) : (
|
|
||||||
<Flex w="full" h="full" alignItems="center" justifyContent="center">
|
|
||||||
<Heading color="base.400" _dark={{ color: 'base.500' }}>
|
|
||||||
{t('queue.queueEmpty')}
|
|
||||||
</Heading>
|
|
||||||
</Flex>
|
|
||||||
)}
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { isNil } from 'lodash-es';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import {
|
import {
|
||||||
@ -40,7 +41,7 @@ export const useCancelCurrentQueueItem = () => {
|
|||||||
}, [currentQueueItemId, dispatch, t, trigger]);
|
}, [currentQueueItemId, dispatch, t, trigger]);
|
||||||
|
|
||||||
const isDisabled = useMemo(
|
const isDisabled = useMemo(
|
||||||
() => !isConnected || !currentQueueItemId,
|
() => !isConnected || isNil(currentQueueItemId),
|
||||||
[isConnected, currentQueueItemId]
|
[isConnected, currentQueueItemId]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
import { UseToastOptions } from '@chakra-ui/react';
|
import { UseToastOptions } from '@chakra-ui/react';
|
||||||
import { PayloadAction, createSlice, isAnyOf } from '@reduxjs/toolkit';
|
import { PayloadAction, createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||||
import { t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
import { get, startCase, truncate, upperFirst } from 'lodash-es';
|
import { startCase } from 'lodash-es';
|
||||||
import { LogLevelName } from 'roarr';
|
import { LogLevelName } from 'roarr';
|
||||||
import { isAnySessionRejected } from 'services/api/thunks/session';
|
|
||||||
import {
|
import {
|
||||||
appSocketConnected,
|
appSocketConnected,
|
||||||
appSocketDisconnected,
|
appSocketDisconnected,
|
||||||
@ -20,8 +19,7 @@ import {
|
|||||||
} from 'services/events/actions';
|
} from 'services/events/actions';
|
||||||
import { calculateStepPercentage } from '../util/calculateStepPercentage';
|
import { calculateStepPercentage } from '../util/calculateStepPercentage';
|
||||||
import { makeToast } from '../util/makeToast';
|
import { makeToast } from '../util/makeToast';
|
||||||
import { SystemState, LANGUAGES } from './types';
|
import { LANGUAGES, SystemState } from './types';
|
||||||
import { zPydanticValidationError } from './zodSchemas';
|
|
||||||
|
|
||||||
export const initialSystemState: SystemState = {
|
export const initialSystemState: SystemState = {
|
||||||
isInitialized: false,
|
isInitialized: false,
|
||||||
@ -175,50 +173,6 @@ export const systemSlice = createSlice({
|
|||||||
|
|
||||||
// *** Matchers - must be after all cases ***
|
// *** Matchers - must be after all cases ***
|
||||||
|
|
||||||
/**
|
|
||||||
* Session Invoked - REJECTED
|
|
||||||
* Session Created - REJECTED
|
|
||||||
*/
|
|
||||||
builder.addMatcher(isAnySessionRejected, (state, action) => {
|
|
||||||
let errorDescription = undefined;
|
|
||||||
const duration = 5000;
|
|
||||||
|
|
||||||
if (action.payload?.status === 422) {
|
|
||||||
const result = zPydanticValidationError.safeParse(action.payload);
|
|
||||||
if (result.success) {
|
|
||||||
result.data.error.detail.map((e) => {
|
|
||||||
state.toastQueue.push(
|
|
||||||
makeToast({
|
|
||||||
title: truncate(upperFirst(e.msg), { length: 128 }),
|
|
||||||
status: 'error',
|
|
||||||
description: truncate(
|
|
||||||
`Path:
|
|
||||||
${e.loc.join('.')}`,
|
|
||||||
{ length: 128 }
|
|
||||||
),
|
|
||||||
duration,
|
|
||||||
})
|
|
||||||
);
|
|
||||||
});
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
} else if (action.payload?.error) {
|
|
||||||
errorDescription = action.payload?.error;
|
|
||||||
}
|
|
||||||
|
|
||||||
state.toastQueue.push(
|
|
||||||
makeToast({
|
|
||||||
title: t('toast.serverError'),
|
|
||||||
status: 'error',
|
|
||||||
description: truncate(
|
|
||||||
get(errorDescription, 'detail', 'Unknown Error'),
|
|
||||||
{ length: 128 }
|
|
||||||
),
|
|
||||||
duration,
|
|
||||||
})
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Any server error
|
* Any server error
|
||||||
*/
|
*/
|
||||||
|
@ -2,7 +2,7 @@ import { z } from 'zod';
|
|||||||
|
|
||||||
export const zPydanticValidationError = z.object({
|
export const zPydanticValidationError = z.object({
|
||||||
status: z.literal(422),
|
status: z.literal(422),
|
||||||
error: z.object({
|
data: z.object({
|
||||||
detail: z.array(
|
detail: z.array(
|
||||||
z.object({
|
z.object({
|
||||||
loc: z.array(z.string()),
|
loc: z.array(z.string()),
|
||||||
|
@ -14,7 +14,7 @@ import { stateSelector } from 'app/store/store';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent';
|
import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent';
|
||||||
import NodeEditorPanelGroup from 'features/nodes/components/sidePanel/NodeEditorPanelGroup';
|
import NodeEditorPanelGroup from 'features/nodes/components/sidePanel/NodeEditorPanelGroup';
|
||||||
import { InvokeTabName, tabMap } from 'features/ui/store/tabMap';
|
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||||
import { ResourceKey } from 'i18next';
|
import { ResourceKey } from 'i18next';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
@ -110,7 +110,7 @@ export const NO_GALLERY_TABS: InvokeTabName[] = ['modelManager', 'queue'];
|
|||||||
export const NO_SIDE_PANEL_TABS: InvokeTabName[] = ['modelManager', 'queue'];
|
export const NO_SIDE_PANEL_TABS: InvokeTabName[] = ['modelManager', 'queue'];
|
||||||
|
|
||||||
const InvokeTabs = () => {
|
const InvokeTabs = () => {
|
||||||
const activeTab = useAppSelector(activeTabIndexSelector);
|
const activeTabIndex = useAppSelector(activeTabIndexSelector);
|
||||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||||
const enabledTabs = useAppSelector(enabledTabsSelector);
|
const enabledTabs = useAppSelector(enabledTabsSelector);
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
@ -150,13 +150,13 @@ const InvokeTabs = () => {
|
|||||||
|
|
||||||
const handleTabChange = useCallback(
|
const handleTabChange = useCallback(
|
||||||
(index: number) => {
|
(index: number) => {
|
||||||
const activeTabName = tabMap[index];
|
const tab = enabledTabs[index];
|
||||||
if (!activeTabName) {
|
if (!tab) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
dispatch(setActiveTab(activeTabName));
|
dispatch(setActiveTab(tab.id));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch, enabledTabs]
|
||||||
);
|
);
|
||||||
|
|
||||||
const {
|
const {
|
||||||
@ -216,8 +216,8 @@ const InvokeTabs = () => {
|
|||||||
return (
|
return (
|
||||||
<Tabs
|
<Tabs
|
||||||
variant="appTabs"
|
variant="appTabs"
|
||||||
defaultIndex={activeTab}
|
defaultIndex={activeTabIndex}
|
||||||
index={activeTab}
|
index={activeTabIndex}
|
||||||
onChange={handleTabChange}
|
onChange={handleTabChange}
|
||||||
sx={{
|
sx={{
|
||||||
flexGrow: 1,
|
flexGrow: 1,
|
||||||
|
@ -95,26 +95,32 @@ export default function UnifiedCanvasColorPicker() {
|
|||||||
>
|
>
|
||||||
<Flex minWidth={60} direction="column" gap={4} width="100%">
|
<Flex minWidth={60} direction="column" gap={4} width="100%">
|
||||||
{layer === 'base' && (
|
{layer === 'base' && (
|
||||||
<IAIColorPicker
|
<Box
|
||||||
sx={{
|
sx={{
|
||||||
width: '100%',
|
width: '100%',
|
||||||
paddingTop: 2,
|
paddingTop: 2,
|
||||||
paddingBottom: 2,
|
paddingBottom: 2,
|
||||||
}}
|
}}
|
||||||
pickerColor={brushColor}
|
>
|
||||||
onChange={(newColor) => dispatch(setBrushColor(newColor))}
|
<IAIColorPicker
|
||||||
/>
|
color={brushColor}
|
||||||
|
onChange={(newColor) => dispatch(setBrushColor(newColor))}
|
||||||
|
/>
|
||||||
|
</Box>
|
||||||
)}
|
)}
|
||||||
{layer === 'mask' && (
|
{layer === 'mask' && (
|
||||||
<IAIColorPicker
|
<Box
|
||||||
sx={{
|
sx={{
|
||||||
width: '100%',
|
width: '100%',
|
||||||
paddingTop: 2,
|
paddingTop: 2,
|
||||||
paddingBottom: 2,
|
paddingBottom: 2,
|
||||||
}}
|
}}
|
||||||
pickerColor={maskColor}
|
>
|
||||||
onChange={(newColor) => dispatch(setMaskColor(newColor))}
|
<IAIColorPicker
|
||||||
/>
|
color={maskColor}
|
||||||
|
onChange={(newColor) => dispatch(setMaskColor(newColor))}
|
||||||
|
/>
|
||||||
|
</Box>
|
||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
</IAIPopover>
|
</IAIPopover>
|
||||||
|
@ -1,13 +0,0 @@
|
|||||||
import { InvokeTabName, tabMap } from './tabMap';
|
|
||||||
import { UIState } from './uiTypes';
|
|
||||||
|
|
||||||
export const setActiveTabReducer = (
|
|
||||||
state: UIState,
|
|
||||||
newActiveTab: number | InvokeTabName
|
|
||||||
) => {
|
|
||||||
if (typeof newActiveTab === 'number') {
|
|
||||||
state.activeTab = newActiveTab;
|
|
||||||
} else {
|
|
||||||
state.activeTab = tabMap.indexOf(newActiveTab);
|
|
||||||
}
|
|
||||||
};
|
|
@ -1,27 +1,23 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual, isString } from 'lodash-es';
|
||||||
|
import { tabMap } from './tabMap';
|
||||||
import { InvokeTabName, tabMap } from './tabMap';
|
|
||||||
import { UIState } from './uiTypes';
|
|
||||||
|
|
||||||
export const activeTabNameSelector = createSelector(
|
export const activeTabNameSelector = createSelector(
|
||||||
(state: RootState) => state.ui,
|
(state: RootState) => state,
|
||||||
(ui: UIState) => tabMap[ui.activeTab] as InvokeTabName,
|
/**
|
||||||
{
|
* Previously `activeTab` was an integer, but now it's a string.
|
||||||
memoizeOptions: {
|
* Default to first tab in case user has integer.
|
||||||
equalityCheck: isEqual,
|
*/
|
||||||
},
|
({ ui }) => (isString(ui.activeTab) ? ui.activeTab : 'txt2img')
|
||||||
}
|
|
||||||
);
|
);
|
||||||
|
|
||||||
export const activeTabIndexSelector = createSelector(
|
export const activeTabIndexSelector = createSelector(
|
||||||
(state: RootState) => state.ui,
|
(state: RootState) => state,
|
||||||
(ui: UIState) => ui.activeTab,
|
({ ui, config }) => {
|
||||||
{
|
const tabs = tabMap.filter((t) => !config.disabledTabs.includes(t));
|
||||||
memoizeOptions: {
|
const idx = tabs.indexOf(ui.activeTab);
|
||||||
equalityCheck: isEqual,
|
return idx === -1 ? 0 : idx;
|
||||||
},
|
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -2,12 +2,11 @@ import type { PayloadAction } from '@reduxjs/toolkit';
|
|||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||||
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
|
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
|
||||||
import { setActiveTabReducer } from './extraReducers';
|
|
||||||
import { InvokeTabName } from './tabMap';
|
import { InvokeTabName } from './tabMap';
|
||||||
import { UIState } from './uiTypes';
|
import { UIState } from './uiTypes';
|
||||||
|
|
||||||
export const initialUIState: UIState = {
|
export const initialUIState: UIState = {
|
||||||
activeTab: 0,
|
activeTab: 'txt2img',
|
||||||
shouldShowImageDetails: false,
|
shouldShowImageDetails: false,
|
||||||
shouldUseCanvasBetaLayout: false,
|
shouldUseCanvasBetaLayout: false,
|
||||||
shouldShowExistingModelsInSearch: false,
|
shouldShowExistingModelsInSearch: false,
|
||||||
@ -26,7 +25,7 @@ export const uiSlice = createSlice({
|
|||||||
initialState: initialUIState,
|
initialState: initialUIState,
|
||||||
reducers: {
|
reducers: {
|
||||||
setActiveTab: (state, action: PayloadAction<InvokeTabName>) => {
|
setActiveTab: (state, action: PayloadAction<InvokeTabName>) => {
|
||||||
setActiveTabReducer(state, action.payload);
|
state.activeTab = action.payload;
|
||||||
},
|
},
|
||||||
setShouldShowImageDetails: (state, action: PayloadAction<boolean>) => {
|
setShouldShowImageDetails: (state, action: PayloadAction<boolean>) => {
|
||||||
state.shouldShowImageDetails = action.payload;
|
state.shouldShowImageDetails = action.payload;
|
||||||
@ -73,7 +72,7 @@ export const uiSlice = createSlice({
|
|||||||
},
|
},
|
||||||
extraReducers(builder) {
|
extraReducers(builder) {
|
||||||
builder.addCase(initialImageChanged, (state) => {
|
builder.addCase(initialImageChanged, (state) => {
|
||||||
setActiveTabReducer(state, 'img2img');
|
state.activeTab = 'img2img';
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
|
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
|
||||||
|
import { InvokeTabName } from './tabMap';
|
||||||
|
|
||||||
export type Coordinates = {
|
export type Coordinates = {
|
||||||
x: number;
|
x: number;
|
||||||
@ -13,7 +14,7 @@ export type Dimensions = {
|
|||||||
export type Rect = Coordinates & Dimensions;
|
export type Rect = Coordinates & Dimensions;
|
||||||
|
|
||||||
export interface UIState {
|
export interface UIState {
|
||||||
activeTab: number;
|
activeTab: InvokeTabName;
|
||||||
shouldShowImageDetails: boolean;
|
shouldShowImageDetails: boolean;
|
||||||
shouldUseCanvasBetaLayout: boolean;
|
shouldUseCanvasBetaLayout: boolean;
|
||||||
shouldShowExistingModelsInSearch: boolean;
|
shouldShowExistingModelsInSearch: boolean;
|
||||||
|
@ -1,184 +0,0 @@
|
|||||||
import { createAsyncThunk, isAnyOf } from '@reduxjs/toolkit';
|
|
||||||
import { $queueId } from 'features/queue/store/queueNanoStore';
|
|
||||||
import { isObject } from 'lodash-es';
|
|
||||||
import { $client } from 'services/api/client';
|
|
||||||
import { paths } from 'services/api/schema';
|
|
||||||
import { O } from 'ts-toolbelt';
|
|
||||||
|
|
||||||
type CreateSessionArg = {
|
|
||||||
graph: NonNullable<
|
|
||||||
paths['/api/v1/sessions/']['post']['requestBody']
|
|
||||||
>['content']['application/json'];
|
|
||||||
};
|
|
||||||
|
|
||||||
type CreateSessionResponse = O.Required<
|
|
||||||
NonNullable<
|
|
||||||
paths['/api/v1/sessions/']['post']['requestBody']
|
|
||||||
>['content']['application/json'],
|
|
||||||
'id'
|
|
||||||
>;
|
|
||||||
|
|
||||||
type CreateSessionThunkConfig = {
|
|
||||||
rejectValue: { arg: CreateSessionArg; status: number; error: unknown };
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* `SessionsService.createSession()` thunk
|
|
||||||
*/
|
|
||||||
export const sessionCreated = createAsyncThunk<
|
|
||||||
CreateSessionResponse,
|
|
||||||
CreateSessionArg,
|
|
||||||
CreateSessionThunkConfig
|
|
||||||
>('api/sessionCreated', async (arg, { rejectWithValue }) => {
|
|
||||||
const { graph } = arg;
|
|
||||||
const { POST } = $client.get();
|
|
||||||
const { data, error, response } = await POST('/api/v1/sessions/', {
|
|
||||||
body: graph,
|
|
||||||
params: { query: { queue_id: $queueId.get() } },
|
|
||||||
});
|
|
||||||
|
|
||||||
if (error) {
|
|
||||||
return rejectWithValue({ arg, status: response.status, error });
|
|
||||||
}
|
|
||||||
|
|
||||||
return data;
|
|
||||||
});
|
|
||||||
|
|
||||||
type InvokedSessionArg = {
|
|
||||||
session_id: paths['/api/v1/sessions/{session_id}/invoke']['put']['parameters']['path']['session_id'];
|
|
||||||
};
|
|
||||||
|
|
||||||
type InvokedSessionResponse =
|
|
||||||
paths['/api/v1/sessions/{session_id}/invoke']['put']['responses']['200']['content']['application/json'];
|
|
||||||
|
|
||||||
type InvokedSessionThunkConfig = {
|
|
||||||
rejectValue: {
|
|
||||||
arg: InvokedSessionArg;
|
|
||||||
error: unknown;
|
|
||||||
status: number;
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
const isErrorWithStatus = (error: unknown): error is { status: number } =>
|
|
||||||
isObject(error) && 'status' in error;
|
|
||||||
|
|
||||||
const isErrorWithDetail = (error: unknown): error is { detail: string } =>
|
|
||||||
isObject(error) && 'detail' in error;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* `SessionsService.invokeSession()` thunk
|
|
||||||
*/
|
|
||||||
export const sessionInvoked = createAsyncThunk<
|
|
||||||
InvokedSessionResponse,
|
|
||||||
InvokedSessionArg,
|
|
||||||
InvokedSessionThunkConfig
|
|
||||||
>('api/sessionInvoked', async (arg, { rejectWithValue }) => {
|
|
||||||
const { session_id } = arg;
|
|
||||||
const { PUT } = $client.get();
|
|
||||||
const { error, response } = await PUT(
|
|
||||||
'/api/v1/sessions/{session_id}/invoke',
|
|
||||||
{
|
|
||||||
params: {
|
|
||||||
query: { queue_id: $queueId.get(), all: true },
|
|
||||||
path: { session_id },
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
if (error) {
|
|
||||||
if (isErrorWithStatus(error) && error.status === 403) {
|
|
||||||
return rejectWithValue({
|
|
||||||
arg,
|
|
||||||
status: response.status,
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
||||||
error: (error as any).body.detail,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (isErrorWithDetail(error) && response.status === 403) {
|
|
||||||
return rejectWithValue({
|
|
||||||
arg,
|
|
||||||
status: response.status,
|
|
||||||
error: error.detail,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (error) {
|
|
||||||
return rejectWithValue({ arg, status: response.status, error });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
type CancelSessionArg =
|
|
||||||
paths['/api/v1/sessions/{session_id}/invoke']['delete']['parameters']['path'];
|
|
||||||
|
|
||||||
type CancelSessionResponse =
|
|
||||||
paths['/api/v1/sessions/{session_id}/invoke']['delete']['responses']['200']['content']['application/json'];
|
|
||||||
|
|
||||||
type CancelSessionThunkConfig = {
|
|
||||||
rejectValue: {
|
|
||||||
arg: CancelSessionArg;
|
|
||||||
error: unknown;
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* `SessionsService.cancelSession()` thunk
|
|
||||||
*/
|
|
||||||
export const sessionCanceled = createAsyncThunk<
|
|
||||||
CancelSessionResponse,
|
|
||||||
CancelSessionArg,
|
|
||||||
CancelSessionThunkConfig
|
|
||||||
>('api/sessionCanceled', async (arg, { rejectWithValue }) => {
|
|
||||||
const { session_id } = arg;
|
|
||||||
const { DELETE } = $client.get();
|
|
||||||
const { data, error } = await DELETE('/api/v1/sessions/{session_id}/invoke', {
|
|
||||||
params: {
|
|
||||||
path: { session_id },
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
if (error) {
|
|
||||||
return rejectWithValue({ arg, error });
|
|
||||||
}
|
|
||||||
|
|
||||||
return data;
|
|
||||||
});
|
|
||||||
|
|
||||||
type ListSessionsArg = {
|
|
||||||
params: paths['/api/v1/sessions/']['get']['parameters'];
|
|
||||||
};
|
|
||||||
|
|
||||||
type ListSessionsResponse =
|
|
||||||
paths['/api/v1/sessions/']['get']['responses']['200']['content']['application/json'];
|
|
||||||
|
|
||||||
type ListSessionsThunkConfig = {
|
|
||||||
rejectValue: {
|
|
||||||
arg: ListSessionsArg;
|
|
||||||
error: unknown;
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* `SessionsService.listSessions()` thunk
|
|
||||||
*/
|
|
||||||
export const listedSessions = createAsyncThunk<
|
|
||||||
ListSessionsResponse,
|
|
||||||
ListSessionsArg,
|
|
||||||
ListSessionsThunkConfig
|
|
||||||
>('api/listSessions', async (arg, { rejectWithValue }) => {
|
|
||||||
const { params } = arg;
|
|
||||||
const { GET } = $client.get();
|
|
||||||
const { data, error } = await GET('/api/v1/sessions/', {
|
|
||||||
params,
|
|
||||||
});
|
|
||||||
|
|
||||||
if (error) {
|
|
||||||
return rejectWithValue({ arg, error });
|
|
||||||
}
|
|
||||||
|
|
||||||
return data;
|
|
||||||
});
|
|
||||||
|
|
||||||
export const isAnySessionRejected = isAnyOf(
|
|
||||||
sessionCreated.rejected,
|
|
||||||
sessionInvoked.rejected
|
|
||||||
);
|
|
@ -1,5 +1,4 @@
|
|||||||
import { ThemeOverride } from '@chakra-ui/react';
|
import { ThemeOverride, ToastProviderProps } from '@chakra-ui/react';
|
||||||
|
|
||||||
import { InvokeAIColors } from './colors/colors';
|
import { InvokeAIColors } from './colors/colors';
|
||||||
import { accordionTheme } from './components/accordion';
|
import { accordionTheme } from './components/accordion';
|
||||||
import { buttonTheme } from './components/button';
|
import { buttonTheme } from './components/button';
|
||||||
@ -149,3 +148,7 @@ export const theme: ThemeOverride = {
|
|||||||
Tooltip: tooltipTheme,
|
Tooltip: tooltipTheme,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const TOAST_OPTIONS: ToastProviderProps = {
|
||||||
|
defaultOptions: { isClosable: true },
|
||||||
|
};
|
||||||
|
Loading…
x
Reference in New Issue
Block a user