Merge branch 'main' into ryan/t2i-adapter

This commit is contained in:
Ryan Dick 2023-09-21 14:26:04 -04:00
commit cd8c53c50d
117 changed files with 2169 additions and 540 deletions

View File

@ -159,7 +159,7 @@ groups in `invokeia.yaml`:
| `host` | `localhost` | Name or IP address of the network interface that the web server will listen on |
| `port` | `9090` | Network port number that the web server will listen on |
| `allow_origins` | `[]` | A list of host names or IP addresses that are allowed to connect to the InvokeAI API in the format `['host1','host2',...]` |
| `allow_credentials | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
| `allow_credentials` | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
| `allow_methods` | `*` | List of HTTP methods ("GET", "POST") that the web server is allowed to use when accessing the API |
| `allow_headers` | `*` | List of HTTP headers that the web server will accept when accessing the API |

View File

@ -17,9 +17,10 @@ echo 6. Change InvokeAI startup options
echo 7. Re-run the configure script to fix a broken install or to complete a major upgrade
echo 8. Open the developer console
echo 9. Update InvokeAI
echo 10. Command-line help
echo 10. Run the InvokeAI image database maintenance script
echo 11. Command-line help
echo Q - Quit
set /P choice="Please enter 1-10, Q: [1] "
set /P choice="Please enter 1-11, Q: [1] "
if not defined choice set choice=1
IF /I "%choice%" == "1" (
echo Starting the InvokeAI browser-based UI..
@ -58,8 +59,11 @@ IF /I "%choice%" == "1" (
echo Running invokeai-update...
python -m invokeai.frontend.install.invokeai_update
) ELSE IF /I "%choice%" == "10" (
echo Running the db maintenance script...
python .venv\Scripts\invokeai-db-maintenance.exe
) ELSE IF /I "%choice%" == "11" (
echo Displaying command line help...
python .venv\Scripts\invokeai.exe --help %*
python .venv\Scripts\invokeai-web.exe --help %*
pause
exit /b
) ELSE IF /I "%choice%" == "q" (

View File

@ -97,13 +97,13 @@ do_choice() {
;;
10)
clear
printf "Command-line help\n"
invokeai --help
printf "Running the db maintenance script\n"
invokeai-db-maintenance --root ${INVOKEAI_ROOT}
;;
"HELP 1")
11)
clear
printf "Command-line help\n"
invokeai --help
invokeai-web --help
;;
*)
clear
@ -125,7 +125,10 @@ do_dialog() {
6 "Change InvokeAI startup options"
7 "Re-run the configure script to fix a broken install or to complete a major upgrade"
8 "Open the developer console"
9 "Update InvokeAI")
9 "Update InvokeAI"
10 "Run the InvokeAI image database maintenance script"
11 "Command-line help"
)
choice=$(dialog --clear \
--backtitle "\Zb\Zu\Z3InvokeAI" \
@ -157,9 +160,10 @@ do_line_input() {
printf "7: Re-run the configure script to fix a broken install\n"
printf "8: Open the developer console\n"
printf "9: Update InvokeAI\n"
printf "10: Command-line help\n"
printf "10: Run the InvokeAI image database maintenance script\n"
printf "11: Command-line help\n"
printf "Q: Quit\n\n"
read -p "Please enter 1-10, Q: [1] " yn
read -p "Please enter 1-11, Q: [1] " yn
choice=${yn:='1'}
do_choice $choice
clear

View File

@ -7,6 +7,7 @@ from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
from invokeai.app.invocations.upscale import ESRGAN_MODELS
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.patchmatch import PatchMatch
from invokeai.backend.image_util.safety_checker import SafetyChecker
@ -113,3 +114,33 @@ async def set_log_level(
async def clear_invocation_cache() -> None:
"""Clears the invocation cache"""
ApiDependencies.invoker.services.invocation_cache.clear()
@app_router.put(
"/invocation_cache/enable",
operation_id="enable_invocation_cache",
responses={200: {"description": "The operation was successful"}},
)
async def enable_invocation_cache() -> None:
"""Clears the invocation cache"""
ApiDependencies.invoker.services.invocation_cache.enable()
@app_router.put(
"/invocation_cache/disable",
operation_id="disable_invocation_cache",
responses={200: {"description": "The operation was successful"}},
)
async def disable_invocation_cache() -> None:
"""Clears the invocation cache"""
ApiDependencies.invoker.services.invocation_cache.disable()
@app_router.get(
"/invocation_cache/status",
operation_id="get_invocation_cache_status",
responses={200: {"model": InvocationCacheStatus}},
)
async def get_invocation_cache_status() -> InvocationCacheStatus:
"""Clears the invocation cache"""
return ApiDependencies.invoker.services.invocation_cache.get_status()

View File

@ -3,16 +3,19 @@
from fastapi import FastAPI
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event
from fastapi_socketio import SocketManager
from socketio import ASGIApp, AsyncServer
from ..services.events import EventServiceBase
class SocketIO:
__sio: SocketManager
__sio: AsyncServer
__app: ASGIApp
def __init__(self, app: FastAPI):
self.__sio = SocketManager(app=app)
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="socket.io")
app.mount("/ws", self.__app)
self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)

View File

@ -38,7 +38,6 @@ from .baseinvocation import (
InputField,
InvocationContext,
OutputField,
UIType,
invocation,
invocation_output,
)
@ -100,7 +99,7 @@ class ControlNetInvocation(BaseInvocation):
image: ImageField = InputField(description="The control image")
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
control_weight: Union[float, List[float]] = InputField(
default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
default=1.0, description="The weight given to the ControlNet"
)
begin_step_percent: float = InputField(
default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"

View File

@ -58,9 +58,7 @@ class IPAdapterInvocation(BaseInvocation):
# Inputs
image: ImageField = InputField(description="The IP-Adapter image prompt.")
ip_adapter_model: IPAdapterModelField = InputField(
description="The IP-Adapter model.",
title="IP-Adapter Model",
input=Input.Direct,
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
)
# weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float)

View File

@ -1,12 +1,14 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from contextlib import ExitStack
from functools import singledispatchmethod
from typing import List, Literal, Optional, Union
import einops
import numpy as np
import torch
import torchvision.transforms as T
from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import UNet2DConditionModel
from diffusers.models.adapter import FullAdapterXL, T2IAdapter
@ -211,7 +213,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection, ui_order=3)
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
cfg_scale: Union[float, List[float]] = InputField(
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type=UIType.Float, title="CFG Scale"
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, title="CFG Scale"
)
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
@ -221,7 +223,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2)
control: Union[ControlField, list[ControlField]] = InputField(
default=None,
description=FieldDescriptions.control,
input=Input.Connection,
ui_order=5,
)
@ -955,8 +956,7 @@ class ImageToLatentsInvocation(BaseInvocation):
# non_noised_latents_from_image
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
image_tensor_dist = vae.encode(image_tensor).latent_dist
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
latents = vae.config.scaling_factor * latents
latents = latents.to(dtype=orig_dtype)
@ -983,6 +983,18 @@ class ImageToLatentsInvocation(BaseInvocation):
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents, seed=None)
@singledispatchmethod
@staticmethod
def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
image_tensor_dist = vae.encode(image_tensor).latent_dist
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
return latents
@_encode_to_tensor.register
@staticmethod
def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
return vae.encode(image_tensor).latents
@invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents", version="1.0.0")
class BlendLatentsInvocation(BaseInvocation):

View File

@ -42,7 +42,8 @@ class CoreMetadata(BaseModelExcludeNull):
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
steps: int = Field(description="The number of steps used for inference")
scheduler: str = Field(description="The scheduler used for inference")
clip_skip: int = Field(
clip_skip: Optional[int] = Field(
default=None,
description="The number of skipped CLIP layers",
)
model: MainModelField = Field(description="The main model used for inference")
@ -116,7 +117,8 @@ class MetadataAccumulatorInvocation(BaseInvocation):
cfg_scale: float = InputField(description="The classifier-free guidance scale parameter")
steps: int = InputField(description="The number of steps used for inference")
scheduler: str = InputField(description="The scheduler used for inference")
clip_skip: int = InputField(
clip_skip: Optional[int] = Field(
default=None,
description="The number of skipped CLIP layers",
)
model: MainModelField = InputField(description="The main model used for inference")

View File

@ -166,7 +166,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
default=7.5,
ge=1,
description=FieldDescriptions.cfg_scale,
ui_type=UIType.Float,
)
scheduler: SAMPLER_NAME_VALUES = InputField(
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct, ui_type=UIType.Scheduler
@ -179,7 +178,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
default=None,
description=FieldDescriptions.control,
ui_type=UIType.Control,
)
# seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
# seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")

View File

@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
from typing import Optional, Union
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
class InvocationCacheBase(ABC):
@ -32,7 +33,7 @@ class InvocationCacheBase(ABC):
@abstractmethod
def delete(self, key: Union[int, str]) -> None:
"""Deleteds an invocation output from the cache"""
"""Deletes an invocation output from the cache"""
pass
@abstractmethod
@ -44,3 +45,18 @@ class InvocationCacheBase(ABC):
def create_key(self, invocation: BaseInvocation) -> int:
"""Gets the key for the invocation's cache item"""
pass
@abstractmethod
def disable(self) -> None:
"""Disables the cache, overriding the max cache size"""
pass
@abstractmethod
def enable(self) -> None:
"""Enables the cache, letting the the max cache size take effect"""
pass
@abstractmethod
def get_status(self) -> InvocationCacheStatus:
"""Returns the status of the cache"""
pass

View File

@ -0,0 +1,9 @@
from pydantic import BaseModel, Field
class InvocationCacheStatus(BaseModel):
size: int = Field(description="The current size of the invocation cache")
hits: int = Field(description="The number of cache hits")
misses: int = Field(description="The number of cache misses")
enabled: bool = Field(description="Whether the invocation cache is enabled")
max_size: int = Field(description="The maximum size of the invocation cache")

View File

@ -3,18 +3,25 @@ from typing import Optional, Union
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
from invokeai.app.services.invoker import Invoker
class MemoryInvocationCache(InvocationCacheBase):
__cache: dict[Union[int, str], tuple[BaseInvocationOutput, str]]
__max_cache_size: int
__disabled: bool
__hits: int
__misses: int
__cache_ids: Queue
__invoker: Invoker
def __init__(self, max_cache_size: int = 0) -> None:
self.__cache = dict()
self.__max_cache_size = max_cache_size
self.__disabled = False
self.__hits = 0
self.__misses = 0
self.__cache_ids = Queue()
def start(self, invoker: Invoker) -> None:
@ -25,15 +32,17 @@ class MemoryInvocationCache(InvocationCacheBase):
self.__invoker.services.latents.on_deleted(self._delete_by_match)
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
if self.__max_cache_size == 0:
if self.__max_cache_size == 0 or self.__disabled:
return
item = self.__cache.get(key, None)
if item is not None:
self.__hits += 1
return item[0]
self.__misses += 1
def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
if self.__max_cache_size == 0:
if self.__max_cache_size == 0 or self.__disabled:
return
if key not in self.__cache:
@ -47,25 +56,46 @@ class MemoryInvocationCache(InvocationCacheBase):
pass
def delete(self, key: Union[int, str]) -> None:
if self.__max_cache_size == 0:
if self.__max_cache_size == 0 or self.__disabled:
return
if key in self.__cache:
del self.__cache[key]
def clear(self, *args, **kwargs) -> None:
if self.__max_cache_size == 0:
if self.__max_cache_size == 0 or self.__disabled:
return
self.__cache.clear()
self.__cache_ids = Queue()
self.__misses = 0
self.__hits = 0
def create_key(self, invocation: BaseInvocation) -> int:
return hash(invocation.json(exclude={"id"}))
def _delete_by_match(self, to_match: str) -> None:
def disable(self) -> None:
if self.__max_cache_size == 0:
return
self.__disabled = True
def enable(self) -> None:
if self.__max_cache_size == 0:
return
self.__disabled = False
def get_status(self) -> InvocationCacheStatus:
return InvocationCacheStatus(
hits=self.__hits,
misses=self.__misses,
enabled=not self.__disabled and self.__max_cache_size > 0,
size=len(self.__cache),
max_size=self.__max_cache_size,
)
def _delete_by_match(self, to_match: str) -> None:
if self.__max_cache_size == 0 or self.__disabled:
return
keys_to_delete = set()
for key, value_tuple in self.__cache.items():

View File

@ -92,30 +92,34 @@ class DefaultSessionProcessor(SessionProcessorBase):
self.__invoker.services.logger
while not stop_event.is_set():
poll_now_event.clear()
try:
# do not dequeue if there is already a session running
if self.__queue_item is None and resume_event.is_set():
queue_item = self.__invoker.services.session_queue.dequeue()
# do not dequeue if there is already a session running
if self.__queue_item is None and resume_event.is_set():
queue_item = self.__invoker.services.session_queue.dequeue()
if queue_item is not None:
self.__invoker.services.logger.debug(f"Executing queue item {queue_item.item_id}")
self.__queue_item = queue_item
self.__invoker.services.graph_execution_manager.set(queue_item.session)
self.__invoker.invoke(
session_queue_batch_id=queue_item.batch_id,
session_queue_id=queue_item.queue_id,
session_queue_item_id=queue_item.item_id,
graph_execution_state=queue_item.session,
invoke_all=True,
)
queue_item = None
if queue_item is not None:
self.__invoker.services.logger.debug(f"Executing queue item {queue_item.item_id}")
self.__queue_item = queue_item
self.__invoker.services.graph_execution_manager.set(queue_item.session)
self.__invoker.invoke(
session_queue_batch_id=queue_item.batch_id,
session_queue_id=queue_item.queue_id,
session_queue_item_id=queue_item.item_id,
graph_execution_state=queue_item.session,
invoke_all=True,
)
queue_item = None
if queue_item is None:
self.__invoker.services.logger.debug("Waiting for next polling interval or event")
if queue_item is None:
self.__invoker.services.logger.debug("Waiting for next polling interval or event")
poll_now_event.wait(POLLING_INTERVAL)
continue
except Exception as e:
self.__invoker.services.logger.error(f"Error in session processor: {e}")
poll_now_event.wait(POLLING_INTERVAL)
continue
except Exception as e:
self.__invoker.services.logger.error(f"Error in session processor: {e}")
self.__invoker.services.logger.error(f"Fatal Error in session processor: {e}")
pass
finally:
stop_event.clear()

View File

@ -162,15 +162,15 @@ class SessionQueueItemWithoutGraph(BaseModel):
session_id: str = Field(
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
)
field_values: Optional[list[NodeFieldValue]] = Field(
default=None, description="The field values that were used for this queue item"
)
queue_id: str = Field(description="The id of the queue with which this item is associated")
error: Optional[str] = Field(default=None, description="The error message if this queue item errored")
created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created")
updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated")
started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")
completed_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was completed")
queue_id: str = Field(description="The id of the queue with which this item is associated")
field_values: Optional[list[NodeFieldValue]] = Field(
default=None, description="The field values that were used for this queue item"
)
@classmethod
def from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":

View File

@ -77,7 +77,6 @@ class SqliteSessionQueue(SessionQueueBase):
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["completed", "failed", "canceled"]:
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed")
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
except SessionQueueItemNotFoundError:
return
@ -86,8 +85,8 @@ class SqliteSessionQueue(SessionQueueBase):
item_id = event[1]["data"]["queue_item_id"]
error = event[1]["data"]["error"]
queue_item = self.get_queue_item(item_id)
# always set to failed if have an error, even if previously the item was marked completed or canceled
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error)
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
except SessionQueueItemNotFoundError:
return
@ -95,8 +94,8 @@ class SqliteSessionQueue(SessionQueueBase):
try:
item_id = event[1]["data"]["queue_item_id"]
queue_item = self.get_queue_item(item_id)
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled")
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
if queue_item.status not in ["completed", "failed", "canceled"]:
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled")
except SessionQueueItemNotFoundError:
return
@ -354,7 +353,6 @@ class SqliteSessionQueue(SessionQueueBase):
return None
queue_item = SessionQueueItem.from_dict(dict(result))
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress")
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
return queue_item
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
@ -427,7 +425,9 @@ class SqliteSessionQueue(SessionQueueBase):
raise
finally:
self.__lock.release()
return self.get_queue_item(item_id)
queue_item = self.get_queue_item(item_id)
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
return queue_item
def is_empty(self, queue_id: str) -> IsEmptyResult:
try:
@ -565,7 +565,6 @@ class SqliteSessionQueue(SessionQueueBase):
queue_batch_id=queue_item.batch_id,
graph_execution_state_id=queue_item.session_id,
)
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
return queue_item
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:

View File

@ -42,4 +42,4 @@ IP-Adapters:
- [InvokeAI/ip_adapter_plus_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_sd15)
- [InvokeAI/ip_adapter_plus_face_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_face_sd15)
- [InvokeAI/ip_adapter_sdxl](https://huggingface.co/InvokeAI/ip_adapter_sdxl)
- Not yet supported: [InvokeAI/ip_adapter_sdxl_vit_h](https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h)
- [InvokeAI/ip_adapter_sdxl_vit_h](https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h)

View File

@ -1,4 +1,5 @@
import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, Literal, Optional, Union
@ -53,6 +54,7 @@ class ModelProbe(object):
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main,
"AutoencoderKL": ModelType.Vae,
"AutoencoderTiny": ModelType.Vae,
"ControlNetModel": ModelType.ControlNet,
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
"T2IAdapter": ModelType.T2IAdapter,
@ -178,6 +180,7 @@ class ModelProbe(object):
Get the model type of a hugging-face style folder.
"""
class_name = None
error_hint = None
if model:
class_name = model.__class__.__name__
else:
@ -203,12 +206,18 @@ class ModelProbe(object):
class_name = conf["architectures"][0]
else:
class_name = None
else:
error_hint = f"No model_index.json or config.json found in {folder_path}."
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
return type
else:
error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"
# give up
raise InvalidModelException(f"Unable to determine model type for {folder_path}")
raise InvalidModelException(
f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
)
@classmethod
def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
@ -467,16 +476,32 @@ class PipelineFolderProbe(FolderProbeBase):
class VaeFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
if self._config_looks_like_sdxl():
return BaseModelType.StableDiffusionXL
elif self._name_looks_like_sdxl():
# but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
# by a factor of 8), we can't necessarily tell them apart by config hyperparameters.
return BaseModelType.StableDiffusionXL
else:
return BaseModelType.StableDiffusion1
def _config_looks_like_sdxl(self) -> bool:
# config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
config_file = self.folder_path / "config.json"
if not config_file.exists():
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
with open(config_file, "r") as file:
config = json.load(file)
return (
BaseModelType.StableDiffusionXL
if config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
else BaseModelType.StableDiffusion1
)
return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
def _name_looks_like_sdxl(self) -> bool:
return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
def _guess_name(self) -> str:
name = self.folder_path.name
if name == "vae":
name = self.folder_path.parent.name
return name
class TextualInversionFolderProbe(FolderProbeBase):

View File

@ -0,0 +1,568 @@
# pylint: disable=line-too-long
# pylint: disable=broad-exception-caught
# pylint: disable=missing-function-docstring
"""Script to peform db maintenance and outputs directory management."""
import argparse
import datetime
import enum
import glob
import locale
import os
import shutil
import sqlite3
from pathlib import Path
import PIL
import PIL.ImageOps
import PIL.PngImagePlugin
import yaml
class ConfigMapper:
"""Configuration loader."""
def __init__(self): # noqa D107
pass
TIMESTAMP_STRING = datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
INVOKE_DIRNAME = "invokeai"
YAML_FILENAME = "invokeai.yaml"
DATABASE_FILENAME = "invokeai.db"
database_path = None
database_backup_dir = None
outputs_path = None
archive_path = None
thumbnails_path = None
thumbnails_archive_path = None
def load(self):
"""Read paths from yaml config and validate."""
root = "."
if not self.__load_from_root_config(os.path.abspath(root)):
return False
return True
def __load_from_root_config(self, invoke_root):
"""Validate a yaml path exists, confirm the user wants to use it and load config."""
yaml_path = os.path.join(invoke_root, self.YAML_FILENAME)
if os.path.exists(yaml_path):
db_dir, outdir = self.__load_paths_from_yaml_file(yaml_path)
if db_dir is None or outdir is None:
print("The invokeai.yaml file was found but is missing the db_dir and/or outdir setting!")
return False
if os.path.isabs(db_dir):
self.database_path = os.path.join(db_dir, self.DATABASE_FILENAME)
else:
self.database_path = os.path.join(invoke_root, db_dir, self.DATABASE_FILENAME)
self.database_backup_dir = os.path.join(os.path.dirname(self.database_path), "backup")
if os.path.isabs(outdir):
self.outputs_path = os.path.join(outdir, "images")
self.archive_path = os.path.join(outdir, "images-archive")
else:
self.outputs_path = os.path.join(invoke_root, outdir, "images")
self.archive_path = os.path.join(invoke_root, outdir, "images-archive")
self.thumbnails_path = os.path.join(self.outputs_path, "thumbnails")
self.thumbnails_archive_path = os.path.join(self.archive_path, "thumbnails")
db_exists = os.path.exists(self.database_path)
outdir_exists = os.path.exists(self.outputs_path)
text = f"Found {self.YAML_FILENAME} file at {yaml_path}:"
text += f"\n Database : {self.database_path} - {'Exists!' if db_exists else 'Not Found!'}"
text += f"\n Outputs : {self.outputs_path}- {'Exists!' if outdir_exists else 'Not Found!'}"
print(text)
if db_exists and outdir_exists:
return True
else:
print(
"\nOne or more paths specified in invoke.yaml do not exist. Please inspect/correct the configuration and ensure the script is run in the developer console mode (option 8) from an Invoke AI root directory."
)
return False
else:
print(
f"Auto-discovery of configuration failed! Could not find ({yaml_path})!\n\nPlease ensure the script is run in the developer console mode (option 8) from an Invoke AI root directory."
)
return False
def __load_paths_from_yaml_file(self, yaml_path):
"""Load an Invoke AI yaml file and get the database and outputs paths."""
try:
with open(yaml_path, "rt", encoding=locale.getpreferredencoding()) as file:
yamlinfo = yaml.safe_load(file)
db_dir = yamlinfo.get("InvokeAI", {}).get("Paths", {}).get("db_dir", None)
outdir = yamlinfo.get("InvokeAI", {}).get("Paths", {}).get("outdir", None)
return db_dir, outdir
except Exception:
print(f"Failed to load paths from yaml file! {yaml_path}!")
return None, None
class MaintenanceStats:
"""DTO for tracking work progress."""
def __init__(self): # noqa D107
pass
time_start = datetime.datetime.utcnow()
count_orphaned_db_entries_cleaned = 0
count_orphaned_disk_files_cleaned = 0
count_orphaned_thumbnails_cleaned = 0
count_thumbnails_regenerated = 0
count_errors = 0
@staticmethod
def get_elapsed_time_string():
"""Get a friendly time string for the time elapsed since processing start."""
time_now = datetime.datetime.utcnow()
total_seconds = (time_now - MaintenanceStats.time_start).total_seconds()
hours = int((total_seconds) / 3600)
minutes = int(((total_seconds) % 3600) / 60)
seconds = total_seconds % 60
out_str = f"{hours} hour(s) -" if hours > 0 else ""
out_str += f"{minutes} minute(s) -" if minutes > 0 else ""
out_str += f"{seconds:.2f} second(s)"
return out_str
class DatabaseMapper:
"""Class to abstract database functionality."""
def __init__(self, database_path, database_backup_dir): # noqa D107
self.database_path = database_path
self.database_backup_dir = database_backup_dir
self.connection = None
self.cursor = None
def backup(self, timestamp_string):
"""Take a backup of the database."""
if not os.path.exists(self.database_backup_dir):
print(f"Database backup directory {self.database_backup_dir} does not exist -> creating...", end="")
os.makedirs(self.database_backup_dir)
print("Done!")
database_backup_path = os.path.join(self.database_backup_dir, f"backup-{timestamp_string}-invokeai.db")
print(f"Making DB Backup at {database_backup_path}...", end="")
shutil.copy2(self.database_path, database_backup_path)
print("Done!")
def connect(self):
"""Open connection to the database."""
self.connection = sqlite3.connect(self.database_path)
self.cursor = self.connection.cursor()
def get_all_image_files(self):
"""Get the full list of image file names from the database."""
sql_get_image_by_name = "SELECT image_name FROM images"
self.cursor.execute(sql_get_image_by_name)
rows = self.cursor.fetchall()
db_files = []
for row in rows:
db_files.append(row[0])
return db_files
def remove_image_file_record(self, filename: str):
"""Remove an image file reference from the database by filename."""
sanitized_filename = str.replace(filename, "'", "''") # prevent injection
sql_command = f"DELETE FROM images WHERE image_name='{sanitized_filename}'"
self.cursor.execute(sql_command)
self.connection.commit()
def does_image_exist(self, image_filename):
"""Check database if a image name already exists and return a boolean."""
sanitized_filename = str.replace(image_filename, "'", "''") # prevent injection
sql_get_image_by_name = f"SELECT image_name FROM images WHERE image_name='{sanitized_filename}'"
self.cursor.execute(sql_get_image_by_name)
rows = self.cursor.fetchall()
return True if len(rows) > 0 else False
def disconnect(self):
"""Disconnect from the db, cleaning up connections and cursors."""
if self.cursor is not None:
self.cursor.close()
if self.connection is not None:
self.connection.close()
class PhysicalFileMapper:
"""Containing class for script functionality."""
def __init__(self, outputs_path, thumbnails_path, archive_path, thumbnails_archive_path): # noqa D107
self.outputs_path = outputs_path
self.archive_path = archive_path
self.thumbnails_path = thumbnails_path
self.thumbnails_archive_path = thumbnails_archive_path
def create_archive_directories(self):
"""Create the directory for archiving orphaned image files."""
if not os.path.exists(self.archive_path):
print(f"Image archive directory ({self.archive_path}) does not exist -> creating...", end="")
os.makedirs(self.archive_path)
print("Created!")
if not os.path.exists(self.thumbnails_archive_path):
print(
f"Image thumbnails archive directory ({self.thumbnails_archive_path}) does not exist -> creating...",
end="",
)
os.makedirs(self.thumbnails_archive_path)
print("Created!")
def get_image_path_for_image_name(self, image_filename): # noqa D102
return os.path.join(self.outputs_path, image_filename)
def image_file_exists(self, image_filename): # noqa D102
return os.path.exists(self.get_image_path_for_image_name(image_filename))
def get_thumbnail_path_for_image(self, image_filename): # noqa D102
return os.path.join(self.thumbnails_path, os.path.splitext(image_filename)[0]) + ".webp"
def get_image_name_from_thumbnail_path(self, thumbnail_path): # noqa D102
return os.path.splitext(os.path.basename(thumbnail_path))[0] + ".png"
def thumbnail_exists_for_filename(self, image_filename): # noqa D102
return os.path.exists(self.get_thumbnail_path_for_image(image_filename))
def archive_image(self, image_filename): # noqa D102
if self.image_file_exists(image_filename):
image_path = self.get_image_path_for_image_name(image_filename)
shutil.move(image_path, self.archive_path)
def archive_thumbnail_by_image_filename(self, image_filename): # noqa D102
if self.thumbnail_exists_for_filename(image_filename):
thumbnail_path = self.get_thumbnail_path_for_image(image_filename)
shutil.move(thumbnail_path, self.thumbnails_archive_path)
def get_all_png_filenames_in_directory(self, directory_path): # noqa D102
filepaths = glob.glob(directory_path + "/*.png", recursive=False)
filenames = []
for filepath in filepaths:
filenames.append(os.path.basename(filepath))
return filenames
def get_all_thumbnails_with_full_path(self, thumbnails_directory): # noqa D102
return glob.glob(thumbnails_directory + "/*.webp", recursive=False)
def generate_thumbnail_for_image_name(self, image_filename): # noqa D102
# create thumbnail
file_path = self.get_image_path_for_image_name(image_filename)
thumb_path = self.get_thumbnail_path_for_image(image_filename)
thumb_size = 256, 256
with PIL.Image.open(file_path) as source_image:
source_image.thumbnail(thumb_size)
source_image.save(thumb_path, "webp")
class MaintenanceOperation(str, enum.Enum):
"""Enum class for operations."""
Ask = "ask"
CleanOrphanedDbEntries = "clean"
CleanOrphanedDiskFiles = "archive"
ReGenerateThumbnails = "thumbnails"
All = "all"
class InvokeAIDatabaseMaintenanceApp:
"""Main processor class for the application."""
_operation: MaintenanceOperation
_headless: bool = False
__stats: MaintenanceStats = MaintenanceStats()
def __init__(self, operation: MaintenanceOperation = MaintenanceOperation.Ask):
"""Initialize maintenance app."""
self._operation = MaintenanceOperation(operation)
self._headless = operation != MaintenanceOperation.Ask
def ask_for_operation(self) -> MaintenanceOperation:
"""Ask user to choose the operation to perform."""
while True:
print()
print("It is recommennded to run these operations as ordered below to avoid additional")
print("work being performed that will be discarded in a subsequent step.")
print()
print("Select maintenance operation:")
print()
print("1) Clean Orphaned Database Image Entries")
print(" Cleans entries in the database where the matching file was removed from")
print(" the outputs directory.")
print("2) Archive Orphaned Image Files")
print(" Files found in the outputs directory without an entry in the database are")
print(" moved to an archive directory.")
print("3) Re-Generate Missing Thumbnail Files")
print(" For files found in the outputs directory, re-generate a thumbnail if it")
print(" not found in the thumbnails directory.")
print()
print("(CTRL-C to quit)")
try:
input_option = int(input("Specify desired operation number (1-3): "))
operations = [
MaintenanceOperation.CleanOrphanedDbEntries,
MaintenanceOperation.CleanOrphanedDiskFiles,
MaintenanceOperation.ReGenerateThumbnails,
]
return operations[input_option - 1]
except (IndexError, ValueError):
print("\nInvalid selection!")
def ask_to_continue(self) -> bool:
"""Ask user whether they want to continue with the operation."""
while True:
input_choice = input("Do you wish to continue? (Y or N)? ")
if str.lower(input_choice) == "y":
return True
if str.lower(input_choice) == "n":
return False
def clean_orphaned_db_entries(
self, config: ConfigMapper, file_mapper: PhysicalFileMapper, db_mapper: DatabaseMapper
):
"""Clean dangling database entries that no longer point to a file in outputs."""
if self._headless:
print(f"Removing database references to images that no longer exist in {config.outputs_path}...")
else:
print()
print("===============================================================================")
print("= Clean Orphaned Database Entries")
print()
print("Perform this operation if you have removed files from the outputs/images")
print("directory but the database was never updated. You may see this as empty imaages")
print("in the app gallery, or images that only show an enlarged version of the")
print("thumbnail.")
print()
print(f"Database File Path : {config.database_path}")
print(f"Database backup will be taken at : {config.database_backup_dir}")
print(f"Outputs/Images Directory : {config.outputs_path}")
print(f"Outputs/Images Archive Directory : {config.archive_path}")
print("\nNotes about this operation:")
print("- This operation will find database image file entries that do not exist in the")
print(" outputs/images dir and remove those entries from the database.")
print("- This operation will target all image types including intermediate files.")
print("- If a thumbnail still exists in outputs/images/thumbnails matching the")
print(" orphaned entry, it will be moved to the archive directory.")
print()
if not self.ask_to_continue():
raise KeyboardInterrupt
file_mapper.create_archive_directories()
db_mapper.backup(config.TIMESTAMP_STRING)
db_mapper.connect()
db_files = db_mapper.get_all_image_files()
for db_file in db_files:
try:
if not file_mapper.image_file_exists(db_file):
print(f"Found orphaned image db entry {db_file}. Cleaning ...", end="")
db_mapper.remove_image_file_record(db_file)
print("Cleaned!")
if file_mapper.thumbnail_exists_for_filename(db_file):
print("A thumbnail was found, archiving ...", end="")
file_mapper.archive_thumbnail_by_image_filename(db_file)
print("Archived!")
self.__stats.count_orphaned_db_entries_cleaned += 1
except Exception as ex:
print("An error occurred cleaning db entry, error was:")
print(ex)
self.__stats.count_errors += 1
def clean_orphaned_disk_files(
self, config: ConfigMapper, file_mapper: PhysicalFileMapper, db_mapper: DatabaseMapper
):
"""Archive image files that no longer have entries in the database."""
if self._headless:
print(f"Archiving orphaned image files to {config.archive_path}...")
else:
print()
print("===============================================================================")
print("= Clean Orphaned Disk Files")
print()
print("Perform this operation if you have files that were copied into the outputs")
print("directory which are not referenced by the database. This can happen if you")
print("upgraded to a version with a fresh database, but re-used the outputs directory")
print("and now new images are mixed with the files not in the db. The script will")
print("archive these files so you can choose to delete them or re-import using the")
print("official import script.")
print()
print(f"Database File Path : {config.database_path}")
print(f"Database backup will be taken at : {config.database_backup_dir}")
print(f"Outputs/Images Directory : {config.outputs_path}")
print(f"Outputs/Images Archive Directory : {config.archive_path}")
print("\nNotes about this operation:")
print("- This operation will find image files not referenced by the database and move to an")
print(" archive directory.")
print("- This operation will target all image types including intermediate references.")
print("- The matching thumbnail will also be archived.")
print("- Any remaining orphaned thumbnails will also be archived.")
if not self.ask_to_continue():
raise KeyboardInterrupt
print()
file_mapper.create_archive_directories()
db_mapper.backup(config.TIMESTAMP_STRING)
db_mapper.connect()
phys_files = file_mapper.get_all_png_filenames_in_directory(config.outputs_path)
for phys_file in phys_files:
try:
if not db_mapper.does_image_exist(phys_file):
print(f"Found orphaned file {phys_file}, archiving...", end="")
file_mapper.archive_image(phys_file)
print("Archived!")
if file_mapper.thumbnail_exists_for_filename(phys_file):
print("Related thumbnail exists, archiving...", end="")
file_mapper.archive_thumbnail_by_image_filename(phys_file)
print("Archived!")
else:
print("No matching thumbnail existed to be cleaned.")
self.__stats.count_orphaned_disk_files_cleaned += 1
except Exception as ex:
print("Error found trying to archive file or thumbnail, error was:")
print(ex)
self.__stats.count_errors += 1
thumb_filepaths = file_mapper.get_all_thumbnails_with_full_path(config.thumbnails_path)
# archive any remaining orphaned thumbnails
for thumb_filepath in thumb_filepaths:
try:
thumb_src_image_name = file_mapper.get_image_name_from_thumbnail_path(thumb_filepath)
if not file_mapper.image_file_exists(thumb_src_image_name):
print(f"Found orphaned thumbnail {thumb_filepath}, archiving...", end="")
file_mapper.archive_thumbnail_by_image_filename(thumb_src_image_name)
print("Archived!")
self.__stats.count_orphaned_thumbnails_cleaned += 1
except Exception as ex:
print("Error found trying to archive thumbnail, error was:")
print(ex)
self.__stats.count_errors += 1
def regenerate_thumbnails(self, config: ConfigMapper, file_mapper: PhysicalFileMapper, *args):
"""Create missing thumbnails for any valid general images both in the db and on disk."""
if self._headless:
print("Regenerating missing image thumbnails...")
else:
print()
print("===============================================================================")
print("= Regenerate Thumbnails")
print()
print("This operation will find files that have no matching thumbnail on disk")
print("and regenerate those thumbnail files.")
print("NOTE: It is STRONGLY recommended that the user first clean/archive orphaned")
print(" disk files from the previous menu to avoid wasting time regenerating")
print(" thumbnails for orphaned files.")
print()
print(f"Outputs/Images Directory : {config.outputs_path}")
print(f"Outputs/Images Directory : {config.thumbnails_path}")
print("\nNotes about this operation:")
print("- This operation will find image files both referenced in the db and on disk")
print(" that do not have a matching thumbnail on disk and re-generate the thumbnail")
print(" file.")
if not self.ask_to_continue():
raise KeyboardInterrupt
print()
phys_files = file_mapper.get_all_png_filenames_in_directory(config.outputs_path)
for phys_file in phys_files:
try:
if not file_mapper.thumbnail_exists_for_filename(phys_file):
print(f"Found file without thumbnail {phys_file}...Regenerating Thumbnail...", end="")
file_mapper.generate_thumbnail_for_image_name(phys_file)
print("Done!")
self.__stats.count_thumbnails_regenerated += 1
except Exception as ex:
print("Error found trying to regenerate thumbnail, error was:")
print(ex)
self.__stats.count_errors += 1
def main(self): # noqa D107
print("\n===============================================================================")
print("Database and outputs Maintenance for Invoke AI 3.0.0 +")
print("===============================================================================\n")
config_mapper = ConfigMapper()
if not config_mapper.load():
print("\nInvalid configuration...exiting.\n")
return
file_mapper = PhysicalFileMapper(
config_mapper.outputs_path,
config_mapper.thumbnails_path,
config_mapper.archive_path,
config_mapper.thumbnails_archive_path,
)
db_mapper = DatabaseMapper(config_mapper.database_path, config_mapper.database_backup_dir)
op = self._operation
operations_to_perform = []
if op == MaintenanceOperation.Ask:
op = self.ask_for_operation()
if op in [MaintenanceOperation.CleanOrphanedDbEntries, MaintenanceOperation.All]:
operations_to_perform.append(self.clean_orphaned_db_entries)
if op in [MaintenanceOperation.CleanOrphanedDiskFiles, MaintenanceOperation.All]:
operations_to_perform.append(self.clean_orphaned_disk_files)
if op in [MaintenanceOperation.ReGenerateThumbnails, MaintenanceOperation.All]:
operations_to_perform.append(self.regenerate_thumbnails)
for operation in operations_to_perform:
operation(config_mapper, file_mapper, db_mapper)
print("\n===============================================================================")
print(f"= Maintenance Complete - Elapsed Time: {MaintenanceStats.get_elapsed_time_string()}")
print()
print(f"Orphaned db entries cleaned : {self.__stats.count_orphaned_db_entries_cleaned}")
print(f"Orphaned disk files archived : {self.__stats.count_orphaned_disk_files_cleaned}")
print(f"Orphaned thumbnail files archived : {self.__stats.count_orphaned_thumbnails_cleaned}")
print(f"Thumbnails regenerated : {self.__stats.count_thumbnails_regenerated}")
print(f"Errors during operation : {self.__stats.count_errors}")
print()
def main(): # noqa D107
parser = argparse.ArgumentParser(
description="InvokeAI image database maintenance utility",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""Operations:
ask Choose operation from a menu [default]
all Run all maintenance operations
clean Clean database of dangling entries
archive Archive orphaned image files
thumbnails Regenerate missing image thumbnails
""",
)
parser.add_argument("--root", default=".", type=Path, help="InvokeAI root directory")
parser.add_argument(
"--operation", default="ask", choices=[x.value for x in MaintenanceOperation], help="Operation to perform."
)
args = parser.parse_args()
try:
os.chdir(args.root)
app = InvokeAIDatabaseMaintenanceApp(args.operation)
app.main()
except KeyboardInterrupt:
print("\n\nUser cancelled execution.")
except FileNotFoundError:
print(f"Invalid root directory '{args.root}'.")
if __name__ == "__main__":
main()

View File

@ -264,6 +264,22 @@
"graphQueued": "Graph queued",
"graphFailedToQueue": "Failed to queue graph"
},
"invocationCache": {
"invocationCache": "Invocation Cache",
"cacheSize": "Cache Size",
"maxCacheSize": "Max Cache Size",
"hits": "Cache Hits",
"misses": "Cache Misses",
"clear": "Clear",
"clearSucceeded": "Invocation Cache Cleared",
"clearFailed": "Problem Clearing Invocation Cache",
"enable": "Enable",
"enableSucceeded": "Invocation Cache Enabled",
"enableFailed": "Problem Enabling Invocation Cache",
"disable": "Disable",
"disableSucceeded": "Invocation Cache Disabled",
"disableFailed": "Problem Disabling Invocation Cache"
},
"gallery": {
"allImagesLoaded": "All Images Loaded",
"assets": "Assets",
@ -1213,14 +1229,14 @@
},
"dynamicPromptsCombinatorial": {
"heading": "Combinatorial Generation",
"paragraph": "Generate an image for every possible combination of Dynamic Prompt until the Max Prompts is reached."
"paragraph": "Generate an image for every possible combination of Dynamic Prompts until the Max Prompts is reached."
},
"infillMethod": {
"heading": "Infill Method",
"paragraph": "Method to infill the selected area."
},
"lora": {
"heading": "LoRA",
"heading": "LoRA Weight",
"paragraph": "Weight of the LoRA. Higher weight will lead to larger impacts on the final image."
},
"noiseEnable": {
@ -1239,21 +1255,21 @@
"heading": "Denoising Strength",
"paragraph": "How much noise is added to the input image. 0 will result in an identical image, while 1 will result in a completely new image."
},
"paramImages": {
"heading": "Images",
"paragraph": "Number of images that will be generated."
"paramIterations": {
"heading": "Iterations",
"paragraph": "The number of images to generate. If Dynamic Prompts is enabled, each of the prompts will be generated this many times."
},
"paramModel": {
"heading": "Model",
"paragraph": "Model used for the denoising steps. Different models are trained to specialize in producing different aesthetic results and content."
},
"paramNegativeConditioning": {
"heading": "Negative Prompts",
"paragraph": "This is where you enter your negative prompts."
"heading": "Negative Prompt",
"paragraph": "The generation process avoids the concepts in the negative prompt. Use this to exclude qualities or objects from the output. Supports Compel syntax and embeddings."
},
"paramPositiveConditioning": {
"heading": "Positive Prompts",
"paragraph": "This is where you enter your positive prompts."
"heading": "Positive Prompt",
"paragraph": "Guides the generation process. You may use any words or phrases. Supports Compel and Dynamic Prompts syntaxes and embeddings."
},
"paramRatio": {
"heading": "Ratio",

View File

@ -3,7 +3,7 @@ import { startAppListening } from '..';
import { $logger } from 'app/logging/logger';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice';
import { copyBlobToClipboard } from 'features/canvas/util/copyBlobToClipboard';
import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard';
import { t } from 'i18next';
export const addCanvasCopiedToClipboardListener = () => {
@ -15,10 +15,12 @@ export const addCanvasCopiedToClipboardListener = () => {
.child({ namespace: 'canvasCopiedToClipboardListener' });
const state = getState();
const blob = await getBaseLayerBlob(state);
try {
const blob = getBaseLayerBlob(state);
if (!blob) {
moduleLog.error('Problem getting base layer blob');
copyBlobToClipboard(blob);
} catch (err) {
moduleLog.error(String(err));
dispatch(
addToast({
title: t('toast.problemCopyingCanvas'),
@ -29,8 +31,6 @@ export const addCanvasCopiedToClipboardListener = () => {
return;
}
copyBlobToClipboard(blob);
dispatch(
addToast({
title: t('toast.canvasCopiedClipboard'),

View File

@ -15,10 +15,11 @@ export const addCanvasDownloadedAsImageListener = () => {
.child({ namespace: 'canvasSavedToGalleryListener' });
const state = getState();
const blob = await getBaseLayerBlob(state);
if (!blob) {
moduleLog.error('Problem getting base layer blob');
let blob;
try {
blob = await getBaseLayerBlob(state);
} catch (err) {
moduleLog.error(String(err));
dispatch(
addToast({
title: t('toast.problemDownloadingCanvas'),

View File

@ -14,10 +14,11 @@ export const addCanvasImageToControlNetListener = () => {
const log = logger('canvas');
const state = getState();
const blob = await getBaseLayerBlob(state);
if (!blob) {
log.error('Problem getting base layer blob');
let blob;
try {
blob = await getBaseLayerBlob(state);
} catch (err) {
log.error(String(err));
dispatch(
addToast({
title: t('toast.problemSavingCanvas'),

View File

@ -13,10 +13,11 @@ export const addCanvasSavedToGalleryListener = () => {
const log = logger('canvas');
const state = getState();
const blob = await getBaseLayerBlob(state);
if (!blob) {
log.error('Problem getting base layer blob');
let blob;
try {
blob = await getBaseLayerBlob(state);
} catch (err) {
log.error(String(err));
dispatch(
addToast({
title: t('toast.problemSavingCanvas'),

View File

@ -4,7 +4,9 @@ import { parseify } from 'common/util/serialize';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import {
controlNetImageChanged,
controlNetIsEnabledChanged,
ipAdapterImageChanged,
isIPAdapterEnabledChanged,
} from 'features/controlNet/store/controlNetSlice';
import {
TypesafeDraggableData,
@ -99,6 +101,12 @@ export const addImageDroppedListener = () => {
controlNetId,
})
);
dispatch(
controlNetIsEnabledChanged({
controlNetId,
isEnabled: true,
})
);
return;
}
@ -111,6 +119,7 @@ export const addImageDroppedListener = () => {
activeData.payload.imageDTO
) {
dispatch(ipAdapterImageChanged(activeData.payload.imageDTO));
dispatch(isIPAdapterEnabledChanged(true));
return;
}

View File

@ -3,7 +3,9 @@ import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import {
controlNetImageChanged,
controlNetIsEnabledChanged,
ipAdapterImageChanged,
isIPAdapterEnabledChanged,
} from 'features/controlNet/store/controlNetSlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
@ -87,6 +89,12 @@ export const addImageUploadedFulfilledListener = () => {
if (postUploadAction?.type === 'SET_CONTROLNET_IMAGE') {
const { controlNetId } = postUploadAction;
dispatch(
controlNetIsEnabledChanged({
controlNetId,
isEnabled: true,
})
);
dispatch(
controlNetImageChanged({
controlNetId,
@ -104,6 +112,7 @@ export const addImageUploadedFulfilledListener = () => {
if (postUploadAction?.type === 'SET_IP_ADAPTER_IMAGE') {
dispatch(ipAdapterImageChanged(imageDTO));
dispatch(isIPAdapterEnabledChanged(true));
dispatch(
addToast({
...DEFAULT_UPLOADED_TOAST,

View File

@ -4,6 +4,7 @@ import { api } from 'services/api';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { appSocketConnected, socketConnected } from 'services/events/actions';
import { startAppListening } from '../..';
import { isInitializedChanged } from 'features/system/store/systemSlice';
export const addSocketConnectedEventListener = () => {
startAppListening({
@ -13,7 +14,7 @@ export const addSocketConnectedEventListener = () => {
log.debug('Connected');
const { nodes, config } = getState();
const { nodes, config, system } = getState();
const { disabledTabs } = config;
@ -21,7 +22,12 @@ export const addSocketConnectedEventListener = () => {
dispatch(receivedOpenAPISchema());
}
dispatch(api.util.resetApiState());
if (system.isInitialized) {
// only reset the query caches if this connect event is a *reconnect* event
dispatch(api.util.resetApiState());
} else {
dispatch(isInitializedChanged(true));
}
// pass along the socket event as an application action
dispatch(appSocketConnected(action.payload));

View File

@ -1,5 +1,4 @@
import { logger } from 'app/logging/logger';
import { api } from 'services/api';
import {
appSocketDisconnected,
socketDisconnected,
@ -13,8 +12,6 @@ export const addSocketDisconnectedEventListener = () => {
const log = logger('socketio');
log.debug('Disconnected');
dispatch(api.util.resetApiState());
// pass along the socket event as an application action
dispatch(appSocketDisconnected(action.payload));
},

View File

@ -21,7 +21,8 @@ export type AppFeature =
| 'multiselect'
| 'pauseQueue'
| 'resumeQueue'
| 'prependQueue';
| 'prependQueue'
| 'invocationCache';
/**
* A disable-able Stable Diffusion feature

View File

@ -0,0 +1,22 @@
import { Box, Image } from '@chakra-ui/react';
import InvokeAILogoImage from 'assets/images/logo.png';
import { memo } from 'react';
const GreyscaleInvokeAIIcon = () => (
<Box pos="relative" w={4} h={4}>
<Image
src={InvokeAILogoImage}
alt="invoke-ai-logo"
pos="absolute"
top={-0.5}
insetInlineStart={-0.5}
w={5}
h={5}
minW={5}
minH={5}
filter="saturate(0)"
/>
</Box>
);
export default memo(GreyscaleInvokeAIIcon);

View File

@ -31,7 +31,7 @@ const IAIDroppable = (props: IAIDroppableProps) => {
insetInlineStart={0}
w="full"
h="full"
pointerEvents="none"
pointerEvents={active ? 'auto' : 'none'}
>
<AnimatePresence>
{isValidDrop(data, active) && (

View File

@ -1,58 +1,67 @@
import {
Box,
Button,
Popover,
PopoverTrigger,
PopoverContent,
PopoverArrow,
PopoverCloseButton,
PopoverHeader,
PopoverBody,
PopoverProps,
Divider,
Flex,
Text,
Heading,
Image,
Popover,
PopoverArrow,
PopoverBody,
PopoverCloseButton,
PopoverContent,
PopoverProps,
PopoverTrigger,
Portal,
Text,
} from '@chakra-ui/react';
import { useAppSelector } from '../../app/store/storeHooks';
import { ReactNode, memo } from 'react';
import { useTranslation } from 'react-i18next';
import { useAppSelector } from '../../app/store/storeHooks';
interface Props extends PopoverProps {
const OPEN_DELAY = 1500;
type Props = Omit<PopoverProps, 'children'> & {
details: string;
children: JSX.Element;
children: ReactNode;
image?: string;
buttonLabel?: string;
buttonHref?: string;
placement?: PopoverProps['placement'];
}
};
function IAIInformationalPopover({
const IAIInformationalPopover = ({
details,
image,
buttonLabel,
buttonHref,
children,
placement,
}: Props): JSX.Element {
const shouldDisableInformationalPopovers = useAppSelector(
(state) => state.system.shouldDisableInformationalPopovers
}: Props) => {
const shouldEnableInformationalPopovers = useAppSelector(
(state) => state.system.shouldEnableInformationalPopovers
);
const { t } = useTranslation();
const heading = t(`popovers.${details}.heading`);
const paragraph = t(`popovers.${details}.paragraph`);
if (shouldDisableInformationalPopovers) {
return children;
} else {
return (
<Popover
placement={placement || 'top'}
closeOnBlur={false}
trigger="hover"
variant="informational"
>
<PopoverTrigger>
<div>{children}</div>
</PopoverTrigger>
if (!shouldEnableInformationalPopovers) {
return <>{children}</>;
}
return (
<Popover
placement={placement || 'top'}
closeOnBlur={false}
trigger="hover"
variant="informational"
openDelay={OPEN_DELAY}
>
<PopoverTrigger>
<Box w="full">{children}</Box>
</PopoverTrigger>
<Portal>
<PopoverContent>
<PopoverArrow />
<PopoverCloseButton />
@ -83,14 +92,17 @@ function IAIInformationalPopover({
gap: 3,
flexDirection: 'column',
width: '100%',
p: 3,
pt: heading ? 0 : 3,
}}
>
{heading && <PopoverHeader>{heading}</PopoverHeader>}
<Text sx={{ px: 3 }}>{paragraph}</Text>
{heading && (
<>
<Heading size="sm">{heading}</Heading>
<Divider />
</>
)}
<Text>{paragraph}</Text>
{buttonLabel && (
<Flex sx={{ px: 3 }} justifyContent="flex-end">
<Flex justifyContent="flex-end">
<Button
onClick={() => window.open(buttonHref)}
size="sm"
@ -104,9 +116,9 @@ function IAIInformationalPopover({
</Flex>
</PopoverBody>
</PopoverContent>
</Popover>
);
}
}
</Portal>
</Popover>
);
};
export default IAIInformationalPopover;
export default memo(IAIInformationalPopover);

View File

@ -1,4 +1,4 @@
import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { FormControl, FormLabel, Tooltip, forwardRef } from '@chakra-ui/react';
import { MultiSelect, MultiSelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
@ -11,7 +11,7 @@ type IAIMultiSelectProps = Omit<MultiSelectProps, 'label'> & {
label?: string;
};
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
const IAIMantineMultiSelect = forwardRef((props: IAIMultiSelectProps, ref) => {
const {
searchable = true,
tooltip,
@ -47,7 +47,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
<MultiSelect
label={
label ? (
<FormControl isDisabled={disabled}>
<FormControl ref={ref} isDisabled={disabled}>
<FormLabel>{label}</FormLabel>
</FormControl>
) : undefined
@ -63,6 +63,8 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
/>
</Tooltip>
);
};
});
IAIMantineMultiSelect.displayName = 'IAIMantineMultiSelect';
export default memo(IAIMantineMultiSelect);

View File

@ -1,4 +1,4 @@
import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { FormControl, FormLabel, Tooltip, forwardRef } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
@ -17,7 +17,7 @@ type IAISelectProps = Omit<SelectProps, 'label'> & {
inputRef?: RefObject<HTMLInputElement>;
};
const IAIMantineSearchableSelect = (props: IAISelectProps) => {
const IAIMantineSearchableSelect = forwardRef((props: IAISelectProps, ref) => {
const {
searchable = true,
tooltip,
@ -74,7 +74,7 @@ const IAIMantineSearchableSelect = (props: IAISelectProps) => {
ref={inputRef}
label={
label ? (
<FormControl isDisabled={disabled}>
<FormControl ref={ref} isDisabled={disabled}>
<FormLabel>{label}</FormLabel>
</FormControl>
) : undefined
@ -92,6 +92,8 @@ const IAIMantineSearchableSelect = (props: IAISelectProps) => {
/>
</Tooltip>
);
};
});
IAIMantineSearchableSelect.displayName = 'IAIMantineSearchableSelect';
export default memo(IAIMantineSearchableSelect);

View File

@ -1,4 +1,4 @@
import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { FormControl, FormLabel, Tooltip, forwardRef } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core';
import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles';
import { RefObject, memo } from 'react';
@ -15,7 +15,7 @@ export type IAISelectProps = Omit<SelectProps, 'label'> & {
label?: string;
};
const IAIMantineSelect = (props: IAISelectProps) => {
const IAIMantineSelect = forwardRef((props: IAISelectProps, ref) => {
const { tooltip, inputRef, label, disabled, required, ...rest } = props;
const styles = useMantineSelectStyles();
@ -25,7 +25,7 @@ const IAIMantineSelect = (props: IAISelectProps) => {
<Select
label={
label ? (
<FormControl isRequired={required} isDisabled={disabled}>
<FormControl ref={ref} isRequired={required} isDisabled={disabled}>
<FormLabel>{label}</FormLabel>
</FormControl>
) : undefined
@ -37,6 +37,8 @@ const IAIMantineSelect = (props: IAISelectProps) => {
/>
</Tooltip>
);
};
});
IAIMantineSelect.displayName = 'IAIMantineSelect';
export default memo(IAIMantineSelect);

View File

@ -13,6 +13,7 @@ import {
NumberInputStepperProps,
Tooltip,
TooltipProps,
forwardRef,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { stopPastePropagation } from 'common/util/stopPastePropagation';
@ -50,7 +51,7 @@ interface Props extends Omit<NumberInputProps, 'onChange'> {
/**
* Customized Chakra FormControl + NumberInput multi-part component.
*/
const IAINumberInput = (props: Props) => {
const IAINumberInput = forwardRef((props: Props, ref) => {
const {
label,
isDisabled = false,
@ -141,6 +142,7 @@ const IAINumberInput = (props: Props) => {
return (
<Tooltip {...tooltipProps}>
<FormControl
ref={ref}
isDisabled={isDisabled}
isInvalid={isInvalid}
{...formControlProps}
@ -172,6 +174,8 @@ const IAINumberInput = (props: Props) => {
</FormControl>
</Tooltip>
);
};
});
IAINumberInput.displayName = 'IAINumberInput';
export default memo(IAINumberInput);

View File

@ -22,6 +22,7 @@ import {
SliderTrackProps,
Tooltip,
TooltipProps,
forwardRef,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { roundDownToMultiple } from 'common/util/roundDownToMultiple';
@ -71,7 +72,7 @@ export type IAIFullSliderProps = {
sliderIAIIconButtonProps?: IAIIconButtonProps;
};
const IAISlider = (props: IAIFullSliderProps) => {
const IAISlider = forwardRef((props: IAIFullSliderProps, ref) => {
const [showTooltip, setShowTooltip] = useState(false);
const {
label,
@ -187,6 +188,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
return (
<FormControl
ref={ref}
onClick={forceInputBlur}
sx={
isCompact
@ -354,6 +356,8 @@ const IAISlider = (props: IAIFullSliderProps) => {
</HStack>
</FormControl>
);
};
});
IAISlider.displayName = 'IAISlider';
export default memo(IAISlider);

View File

@ -72,4 +72,6 @@ const IAISwitch = (props: IAISwitchProps) => {
);
};
IAISwitch.displayName = 'IAISwitch';
export default memo(IAISwitch);

View File

@ -9,7 +9,7 @@ export const getBaseLayerBlob = async (state: RootState) => {
const canvasBaseLayer = getCanvasBaseLayer();
if (!canvasBaseLayer) {
return;
throw new Error('Problem getting base layer blob');
}
const {

View File

@ -1,12 +1,12 @@
import { Box, Flex } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { memo, useCallback } from 'react';
import { ChangeEvent, memo, useCallback } from 'react';
import { FaCopy, FaTrash } from 'react-icons/fa';
import {
ControlNetConfig,
controlNetDuplicated,
controlNetRemoved,
controlNetToggled,
controlNetIsEnabledChanged,
} from '../store/controlNetSlice';
import ParamControlNetModel from './parameters/ParamControlNetModel';
import ParamControlNetWeight from './parameters/ParamControlNetWeight';
@ -77,9 +77,17 @@ const ControlNet = (props: ControlNetProps) => {
);
}, [controlNetId, dispatch]);
const handleToggleIsEnabled = useCallback(() => {
dispatch(controlNetToggled({ controlNetId }));
}, [controlNetId, dispatch]);
const handleToggleIsEnabled = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
controlNetIsEnabledChanged({
controlNetId,
isEnabled: e.target.checked,
})
);
},
[controlNetId, dispatch]
);
return (
<Flex
@ -106,8 +114,8 @@ const ControlNet = (props: ControlNetProps) => {
sx={{
w: 'full',
minW: 0,
opacity: isEnabled ? 1 : 0.5,
pointerEvents: isEnabled ? 'auto' : 'none',
// opacity: isEnabled ? 1 : 0.5,
// pointerEvents: isEnabled ? 'auto' : 'none',
transitionProperty: 'common',
transitionDuration: '0.1s',
}}

View File

@ -13,6 +13,7 @@ import {
import { setHeight, setWidth } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { FaRulerVertical, FaSave, FaUndo } from 'react-icons/fa';
import {
useAddImageToBoardMutation,
@ -26,7 +27,6 @@ import {
ControlNetConfig,
controlNetImageChanged,
} from '../store/controlNetSlice';
import { useTranslation } from 'react-i18next';
type Props = {
controlNet: ControlNetConfig;
@ -52,7 +52,6 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
controlImage: controlImageName,
processedControlImage: processedControlImageName,
processorType,
isEnabled,
controlNetId,
} = controlNet;
@ -172,15 +171,13 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
h: isSmall ? 28 : 366, // magic no touch
alignItems: 'center',
justifyContent: 'center',
pointerEvents: isEnabled ? 'auto' : 'none',
opacity: isEnabled ? 1 : 0.5,
}}
>
<IAIDndImage
draggableData={draggableData}
droppableData={droppableData}
imageDTO={controlImage}
isDropDisabled={shouldShowProcessedImage || !isEnabled}
isDropDisabled={shouldShowProcessedImage}
postUploadAction={postUploadAction}
/>
@ -202,7 +199,6 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
droppableData={droppableData}
imageDTO={processedControlImage}
isUploadDisabled={true}
isDropDisabled={!isEnabled}
/>
</Box>

View File

@ -3,8 +3,8 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch';
import { isIPAdapterEnableToggled } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
import { isIPAdapterEnabledChanged } from 'features/controlNet/store/controlNetSlice';
import { ChangeEvent, memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const selector = createSelector(
@ -22,9 +22,12 @@ const ParamIPAdapterFeatureToggle = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleChange = useCallback(() => {
dispatch(isIPAdapterEnableToggled());
}, [dispatch]);
const handleChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(isIPAdapterEnabledChanged(e.target.checked));
},
[dispatch]
);
return (
<IAISwitch

View File

@ -1,7 +1,9 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { RootState } from 'app/store/store';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
@ -16,15 +18,17 @@ import { FaUndo } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/types';
const selector = createSelector(
stateSelector,
({ controlNet }) => {
const { ipAdapterInfo } = controlNet;
return { ipAdapterInfo };
},
defaultSelectorOptions
);
const ParamIPAdapterImage = () => {
const ipAdapterInfo = useAppSelector(
(state: RootState) => state.controlNet.ipAdapterInfo
);
const isIPAdapterEnabled = useAppSelector(
(state: RootState) => state.controlNet.isIPAdapterEnabled
);
const { ipAdapterInfo } = useAppSelector(selector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
@ -71,8 +75,6 @@ const ParamIPAdapterImage = () => {
droppableData={droppableData}
draggableData={draggableData}
postUploadAction={postUploadAction}
isUploadDisabled={!isIPAdapterEnabled}
isDropDisabled={!isIPAdapterEnabled}
dropLabel={t('toast.setIPAdapterImage')}
noContentFallback={
<IAINoContentFallback

View File

@ -11,6 +11,9 @@ import { useTranslation } from 'react-i18next';
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
const ParamIPAdapterModelSelect = () => {
const isEnabled = useAppSelector(
(state: RootState) => state.controlNet.isIPAdapterEnabled
);
const ipAdapterModel = useAppSelector(
(state: RootState) => state.controlNet.ipAdapterInfo.model
);
@ -90,6 +93,7 @@ const ParamIPAdapterModelSelect = () => {
data={data}
onChange={handleValueChanged}
sx={{ width: '100%' }}
disabled={!isEnabled}
/>
);
};

View File

@ -1,3 +1,4 @@
import { Box } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@ -26,16 +27,15 @@ const ParamControlNetFeatureToggle = () => {
}, [dispatch]);
return (
<IAIInformationalPopover details="controlNetToggle">
<IAISwitch
label="Enable ControlNet"
isChecked={isEnabled}
onChange={handleChange}
formControlProps={{
width: '100%',
}}
/>
</IAIInformationalPopover>
<Box width="100%">
<IAIInformationalPopover details="controlNetToggle">
<IAISwitch
label="Enable ControlNet"
isChecked={isEnabled}
onChange={handleChange}
/>
</IAIInformationalPopover>
</Box>
);
};

View File

@ -146,16 +146,16 @@ export const controlNetSlice = createSlice({
const { controlNetId } = action.payload;
delete state.controlNets[controlNetId];
},
controlNetToggled: (
controlNetIsEnabledChanged: (
state,
action: PayloadAction<{ controlNetId: string }>
action: PayloadAction<{ controlNetId: string; isEnabled: boolean }>
) => {
const { controlNetId } = action.payload;
const { controlNetId, isEnabled } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.isEnabled = !cn.isEnabled;
cn.isEnabled = isEnabled;
},
controlNetImageChanged: (
state,
@ -377,8 +377,8 @@ export const controlNetSlice = createSlice({
controlNetReset: () => {
return { ...initialControlNetState };
},
isIPAdapterEnableToggled: (state) => {
state.isIPAdapterEnabled = !state.isIPAdapterEnabled;
isIPAdapterEnabledChanged: (state, action: PayloadAction<boolean>) => {
state.isIPAdapterEnabled = action.payload;
},
ipAdapterImageChanged: (state, action: PayloadAction<ImageDTO | null>) => {
state.ipAdapterInfo.adapterImage = action.payload;
@ -450,7 +450,7 @@ export const {
controlNetRemoved,
controlNetImageChanged,
controlNetProcessedImageChanged,
controlNetToggled,
controlNetIsEnabledChanged,
controlNetModelChanged,
controlNetWeightChanged,
controlNetBeginStepPctChanged,
@ -461,7 +461,7 @@ export const {
controlNetProcessorTypeChanged,
controlNetReset,
controlNetAutoConfigToggled,
isIPAdapterEnableToggled,
isIPAdapterEnabledChanged,
ipAdapterImageChanged,
ipAdapterWeightChanged,
ipAdapterModelChanged,

View File

@ -2,7 +2,6 @@ import {
DragOverlay,
MouseSensor,
TouchSensor,
pointerWithin,
useSensor,
useSensors,
} from '@dnd-kit/core';
@ -14,6 +13,7 @@ import { AnimatePresence, motion } from 'framer-motion';
import { PropsWithChildren, memo, useCallback, useState } from 'react';
import { useScaledModifer } from '../hooks/useScaledCenteredModifer';
import { DragEndEvent, DragStartEvent, TypesafeDraggableData } from '../types';
import { customPointerWithin } from '../util/customPointerWithin';
import { DndContextTypesafe } from './DndContextTypesafe';
import DragPreview from './DragPreview';
@ -77,7 +77,7 @@ const AppDndContext = (props: PropsWithChildren) => {
onDragStart={handleDragStart}
onDragEnd={handleDragEnd}
sensors={sensors}
collisionDetection={pointerWithin}
collisionDetection={customPointerWithin}
autoScroll={false}
>
{props.children}
@ -87,7 +87,7 @@ const AppDndContext = (props: PropsWithChildren) => {
style={{
width: 'min-content',
height: 'min-content',
cursor: 'none',
cursor: 'grabbing',
userSelect: 'none',
// expand overlay to prevent cursor from going outside it and displaying
padding: '10rem',

View File

@ -0,0 +1,38 @@
import { CollisionDetection, pointerWithin } from '@dnd-kit/core';
/**
* Filters out droppable elements that are overflowed, then applies the pointerWithin collision detection.
*
* Fixes collision detection firing on droppables that are not visible, having been scrolled out of view.
*
* See https://github.com/clauderic/dnd-kit/issues/1198
*/
export const customPointerWithin: CollisionDetection = (arg) => {
if (!arg.pointerCoordinates) {
// sanity check
return [];
}
// Get all elements at the pointer coordinates. This excludes elements which are overflowed,
// so it won't include the droppable elements that are scrolled out of view.
const targetElements = document.elementsFromPoint(
arg.pointerCoordinates.x,
arg.pointerCoordinates.y
);
const filteredDroppableContainers = arg.droppableContainers.filter(
(container) => {
if (!container.node.current) {
return false;
}
// Only include droppable elements that are in the list of elements at the pointer coordinates.
return targetElements.includes(container.node.current);
}
);
// Run the provided collision detection with the filtered droppable elements.
return pointerWithin({
...arg,
droppableContainers: filteredDroppableContainers,
});
};

View File

@ -5,7 +5,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import { contextMenusClosed } from 'features/ui/store/uiSlice';
import { useCallback } from 'react';
import { MouseEvent, useCallback, useRef } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import {
Background,
@ -21,6 +21,7 @@ import {
OnSelectionChangeFunc,
ProOptions,
ReactFlow,
XYPosition,
} from 'reactflow';
import { useIsValidConnection } from '../../hooks/useIsValidConnection';
import {
@ -79,7 +80,8 @@ export const Flow = () => {
const edges = useAppSelector((state) => state.nodes.edges);
const viewport = useAppSelector((state) => state.nodes.viewport);
const { shouldSnapToGrid, selectionMode } = useAppSelector(selector);
const flowWrapper = useRef<HTMLDivElement>(null);
const cursorPosition = useRef<XYPosition>();
const isValidConnection = useIsValidConnection();
const [borderRadius] = useToken('radii', ['base']);
@ -154,6 +156,17 @@ export const Flow = () => {
flow.fitView();
}, []);
const onMouseMove = useCallback((event: MouseEvent<HTMLDivElement>) => {
const bounds = flowWrapper.current?.getBoundingClientRect();
if (bounds) {
const pos = $flow.get()?.project({
x: event.clientX - bounds.left,
y: event.clientY - bounds.top,
});
cursorPosition.current = pos;
}
}, []);
useHotkeys(['Ctrl+c', 'Meta+c'], (e) => {
e.preventDefault();
dispatch(selectionCopied());
@ -166,18 +179,20 @@ export const Flow = () => {
useHotkeys(['Ctrl+v', 'Meta+v'], (e) => {
e.preventDefault();
dispatch(selectionPasted());
dispatch(selectionPasted({ cursorPosition: cursorPosition.current }));
});
return (
<ReactFlow
id="workflow-editor"
ref={flowWrapper}
defaultViewport={viewport}
nodeTypes={nodeTypes}
edgeTypes={edgeTypes}
nodes={nodes}
edges={edges}
onInit={onInit}
onMouseMove={onMouseMove}
onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange}
onEdgesDelete={onEdgesDelete}

View File

@ -31,7 +31,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <Box p={2}>Output field in input: {field?.type}</Box>;
}
if (field?.type === 'string' && fieldTemplate?.type === 'string') {
if (
(field?.type === 'string' && fieldTemplate?.type === 'string') ||
(field?.type === 'StringPolymorphic' &&
fieldTemplate?.type === 'StringPolymorphic')
) {
return (
<StringInputField
nodeId={nodeId}
@ -41,7 +45,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
);
}
if (field?.type === 'boolean' && fieldTemplate?.type === 'boolean') {
if (
(field?.type === 'boolean' && fieldTemplate?.type === 'boolean') ||
(field?.type === 'BooleanPolymorphic' &&
fieldTemplate?.type === 'BooleanPolymorphic')
) {
return (
<BooleanInputField
nodeId={nodeId}
@ -53,7 +61,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
if (
(field?.type === 'integer' && fieldTemplate?.type === 'integer') ||
(field?.type === 'float' && fieldTemplate?.type === 'float')
(field?.type === 'float' && fieldTemplate?.type === 'float') ||
(field?.type === 'FloatPolymorphic' &&
fieldTemplate?.type === 'FloatPolymorphic') ||
(field?.type === 'IntegerPolymorphic' &&
fieldTemplate?.type === 'IntegerPolymorphic')
) {
return (
<NumberInputField
@ -74,7 +86,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
);
}
if (field?.type === 'ImageField' && fieldTemplate?.type === 'ImageField') {
if (
(field?.type === 'ImageField' && fieldTemplate?.type === 'ImageField') ||
(field?.type === 'ImagePolymorphic' &&
fieldTemplate?.type === 'ImagePolymorphic')
) {
return (
<ImageInputField
nodeId={nodeId}

View File

@ -4,12 +4,17 @@ import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
import {
BooleanInputFieldTemplate,
BooleanInputFieldValue,
BooleanPolymorphicInputFieldTemplate,
BooleanPolymorphicInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { ChangeEvent, memo, useCallback } from 'react';
const BooleanInputFieldComponent = (
props: FieldComponentProps<BooleanInputFieldValue, BooleanInputFieldTemplate>
props: FieldComponentProps<
BooleanInputFieldValue | BooleanPolymorphicInputFieldValue,
BooleanInputFieldTemplate | BooleanPolymorphicInputFieldTemplate
>
) => {
const { nodeId, field } = props;

View File

@ -12,6 +12,8 @@ import {
FieldComponentProps,
ImageInputFieldTemplate,
ImageInputFieldValue,
ImagePolymorphicInputFieldTemplate,
ImagePolymorphicInputFieldValue,
} from 'features/nodes/types/types';
import { memo, useCallback, useMemo } from 'react';
import { FaUndo } from 'react-icons/fa';
@ -19,7 +21,10 @@ import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/types';
const ImageInputFieldComponent = (
props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate>
props: FieldComponentProps<
ImageInputFieldValue | ImagePolymorphicInputFieldValue,
ImageInputFieldTemplate | ImagePolymorphicInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();

View File

@ -12,15 +12,25 @@ import {
FieldComponentProps,
FloatInputFieldTemplate,
FloatInputFieldValue,
FloatPolymorphicInputFieldTemplate,
FloatPolymorphicInputFieldValue,
IntegerInputFieldTemplate,
IntegerInputFieldValue,
IntegerPolymorphicInputFieldTemplate,
IntegerPolymorphicInputFieldValue,
} from 'features/nodes/types/types';
import { memo, useEffect, useMemo, useState } from 'react';
const NumberInputFieldComponent = (
props: FieldComponentProps<
IntegerInputFieldValue | FloatInputFieldValue,
IntegerInputFieldTemplate | FloatInputFieldTemplate
| IntegerInputFieldValue
| IntegerPolymorphicInputFieldValue
| FloatInputFieldValue
| FloatPolymorphicInputFieldValue,
| IntegerInputFieldTemplate
| IntegerPolymorphicInputFieldTemplate
| FloatInputFieldTemplate
| FloatPolymorphicInputFieldTemplate
>
) => {
const { nodeId, field, fieldTemplate } = props;

View File

@ -6,11 +6,16 @@ import {
StringInputFieldTemplate,
StringInputFieldValue,
FieldComponentProps,
StringPolymorphicInputFieldValue,
StringPolymorphicInputFieldTemplate,
} from 'features/nodes/types/types';
import { ChangeEvent, memo, useCallback } from 'react';
const StringInputFieldComponent = (
props: FieldComponentProps<StringInputFieldValue, StringInputFieldTemplate>
props: FieldComponentProps<
StringInputFieldValue | StringPolymorphicInputFieldValue,
StringInputFieldTemplate | StringPolymorphicInputFieldTemplate
>
) => {
const { nodeId, field, fieldTemplate } = props;
const dispatch = useAppDispatch();

View File

@ -45,6 +45,7 @@ const NodeEditorPanelGroup = () => {
<PanelGroup
ref={panelGroupRef}
id="workflow-panel-group"
autoSaveId="workflow-panel-group"
direction="vertical"
style={{ height: '100%', width: '100%' }}
storage={panelStorage}

View File

@ -5,6 +5,10 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { map } from 'lodash-es';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
import {
POLYMORPHIC_TYPES,
TYPES_WITH_INPUT_COMPONENTS,
} from '../types/constants';
export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
const selector = useMemo(
@ -21,7 +25,12 @@ export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
return [];
}
return map(nodeTemplate.inputs)
.filter((field) => ['any', 'direct'].includes(field.input))
.filter(
(field) =>
(['any', 'direct'].includes(field.input) ||
POLYMORPHIC_TYPES.includes(field.type)) &&
TYPES_WITH_INPUT_COMPONENTS.includes(field.type)
)
.filter((field) => !field.ui_hidden)
.sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0))
.map((field) => field.name)

View File

@ -4,6 +4,10 @@ import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { map } from 'lodash-es';
import { useMemo } from 'react';
import {
POLYMORPHIC_TYPES,
TYPES_WITH_INPUT_COMPONENTS,
} from '../types/constants';
import { isInvocationNode } from '../types/types';
export const useConnectionInputFieldNames = (nodeId: string) => {
@ -21,7 +25,12 @@ export const useConnectionInputFieldNames = (nodeId: string) => {
return [];
}
return map(nodeTemplate.inputs)
.filter((field) => field.input === 'connection')
.filter(
(field) =>
(field.input === 'connection' &&
!POLYMORPHIC_TYPES.includes(field.type)) ||
!TYPES_WITH_INPUT_COMPONENTS.includes(field.type)
)
.filter((field) => !field.ui_hidden)
.sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0))
.map((field) => field.name)

View File

@ -1,5 +1,5 @@
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
import { cloneDeep, forEach, isEqual, map, uniqBy } from 'lodash-es';
import { cloneDeep, forEach, isEqual, uniqBy } from 'lodash-es';
import {
addEdge,
applyEdgeChanges,
@ -16,9 +16,9 @@ import {
OnConnectStartParams,
SelectionMode,
Viewport,
XYPosition,
} from 'reactflow';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { sessionCanceled, sessionInvoked } from 'services/api/thunks/session';
import { ImageField } from 'services/api/types';
import {
appSocketGeneratorProgress,
@ -722,8 +722,30 @@ const nodesSlice = createSlice({
selectionCopied: (state) => {
state.nodesToCopy = state.nodes.filter((n) => n.selected).map(cloneDeep);
state.edgesToCopy = state.edges.filter((e) => e.selected).map(cloneDeep);
if (state.nodesToCopy.length > 0) {
const averagePosition = { x: 0, y: 0 };
state.nodesToCopy.forEach((e) => {
const xOffset = 0.15 * (e.width ?? 0);
const yOffset = 0.5 * (e.height ?? 0);
averagePosition.x += e.position.x + xOffset;
averagePosition.y += e.position.y + yOffset;
});
averagePosition.x /= state.nodesToCopy.length;
averagePosition.y /= state.nodesToCopy.length;
state.nodesToCopy.forEach((e) => {
e.position.x -= averagePosition.x;
e.position.y -= averagePosition.y;
});
}
},
selectionPasted: (state) => {
selectionPasted: (
state,
action: PayloadAction<{ cursorPosition?: XYPosition }>
) => {
const { cursorPosition } = action.payload;
const newNodes = state.nodesToCopy.map(cloneDeep);
const oldNodeIds = newNodes.map((n) => n.data.id);
const newEdges = state.edgesToCopy
@ -752,8 +774,8 @@ const nodesSlice = createSlice({
const position = findUnoccupiedPosition(
state.nodes,
node.position.x,
node.position.y
node.position.x + (cursorPosition?.x ?? 0),
node.position.y + (cursorPosition?.y ?? 0)
);
node.position = position;
@ -853,28 +875,10 @@ const nodesSlice = createSlice({
node.progressImage = progress_image ?? null;
}
});
builder.addCase(sessionInvoked.fulfilled, (state) => {
forEach(state.nodeExecutionStates, (nes) => {
nes.status = NodeStatus.PENDING;
nes.error = null;
nes.progress = null;
nes.progressImage = null;
nes.outputs = [];
});
});
builder.addCase(sessionCanceled.fulfilled, (state) => {
map(state.nodeExecutionStates, (nes) => {
if (nes.status === NodeStatus.IN_PROGRESS) {
nes.status = NodeStatus.PENDING;
}
});
});
builder.addCase(appSocketQueueItemStatusChanged, (state, action) => {
if (
['completed', 'canceled', 'failed'].includes(action.payload.data.status)
) {
if (['in_progress'].includes(action.payload.data.status)) {
forEach(state.nodeExecutionStates, (nes) => {
nes.status = NodeStatus.PENDING;
nes.status = NodeStatus.IN_PROGRESS;
nes.error = null;
nes.progress = null;
nes.progressImage = null;

View File

@ -102,6 +102,29 @@ export const POLYMORPHIC_TO_SINGLE_MAP = {
T2IAdapterPolymorphic: 'T2IAdapterField',
};
export const TYPES_WITH_INPUT_COMPONENTS = [
'string',
'StringPolymorphic',
'boolean',
'BooleanPolymorphic',
'integer',
'float',
'FloatPolymorphic',
'IntegerPolymorphic',
'enum',
'ImageField',
'ImagePolymorphic',
'MainModelField',
'SDXLRefinerModelField',
'VaeModelField',
'LoRAModelField',
'ControlNetModelField',
'ColorField',
'SDXLMainModelField',
'Scheduler',
'IPAdapterModelField',
];
export const isPolymorphicItemType = (
itemType: string | undefined
): itemType is keyof typeof SINGLE_TO_POLYMORPHIC_MAP =>

View File

@ -224,7 +224,7 @@ export type IntegerCollectionInputFieldValue = z.infer<
export const zIntegerPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('IntegerPolymorphic'),
value: z.union([z.number().int(), z.array(z.number().int())]).optional(),
value: z.number().int().optional(),
});
export type IntegerPolymorphicInputFieldValue = z.infer<
typeof zIntegerPolymorphicInputFieldValue
@ -246,7 +246,7 @@ export type FloatCollectionInputFieldValue = z.infer<
export const zFloatPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('FloatPolymorphic'),
value: z.union([z.number(), z.array(z.number())]).optional(),
value: z.number().optional(),
});
export type FloatPolymorphicInputFieldValue = z.infer<
typeof zFloatPolymorphicInputFieldValue
@ -268,7 +268,7 @@ export type StringCollectionInputFieldValue = z.infer<
export const zStringPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('StringPolymorphic'),
value: z.union([z.string(), z.array(z.string())]).optional(),
value: z.string().optional(),
});
export type StringPolymorphicInputFieldValue = z.infer<
typeof zStringPolymorphicInputFieldValue
@ -290,7 +290,7 @@ export type BooleanCollectionInputFieldValue = z.infer<
export const zBooleanPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('BooleanPolymorphic'),
value: z.union([z.boolean(), z.array(z.boolean())]).optional(),
value: z.boolean().optional(),
});
export type BooleanPolymorphicInputFieldValue = z.infer<
typeof zBooleanPolymorphicInputFieldValue
@ -542,7 +542,7 @@ export type ImageInputFieldValue = z.infer<typeof zImageInputFieldValue>;
export const zImagePolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('ImagePolymorphic'),
value: z.union([zImageField, z.array(zImageField)]).optional(),
value: zImageField.optional(),
});
export type ImagePolymorphicInputFieldValue = z.infer<
typeof zImagePolymorphicInputFieldValue

View File

@ -8,7 +8,11 @@ import {
MetadataAccumulatorInvocation,
} from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
import { CONTROL_NET_COLLECT, METADATA_ACCUMULATOR } from './constants';
import {
CANVAS_COHERENCE_DENOISE_LATENTS,
CONTROL_NET_COLLECT,
METADATA_ACCUMULATOR,
} from './constants';
export const addControlNetToLinearGraph = (
state: RootState,
@ -100,6 +104,16 @@ export const addControlNetToLinearGraph = (
field: 'item',
},
});
if (CANVAS_COHERENCE_DENOISE_LATENTS in graph.nodes) {
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: CANVAS_COHERENCE_DENOISE_LATENTS,
field: 'control',
},
});
}
});
}
}

View File

@ -1,7 +1,7 @@
import { RootState } from 'app/store/store';
import { IPAdapterInvocation } from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
import { IP_ADAPTER } from './constants';
import { CANVAS_COHERENCE_DENOISE_LATENTS, IP_ADAPTER } from './constants';
export const addIPAdapterToLinearGraph = (
state: RootState,
@ -55,5 +55,15 @@ export const addIPAdapterToLinearGraph = (
field: 'ip_adapter',
},
});
if (CANVAS_COHERENCE_DENOISE_LATENTS in graph.nodes) {
graph.edges.push({
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
destination: {
node_id: CANVAS_COHERENCE_DENOISE_LATENTS,
field: 'ip_adapter',
},
});
}
}
};

View File

@ -45,7 +45,6 @@ export const buildCanvasSDXLImageToImageGraph = (
seed,
steps,
vaePrecision,
clipSkip,
shouldUseCpuNoise,
seamlessXAxis,
seamlessYAxis,
@ -339,7 +338,6 @@ export const buildCanvasSDXLImageToImageGraph = (
vae: undefined, // option; set in addVAEToGraph
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
clip_skip: clipSkip,
strength,
init_image: initialImage.image_name,
};

View File

@ -46,7 +46,6 @@ export const buildCanvasSDXLTextToImageGraph = (
seed,
steps,
vaePrecision,
clipSkip,
shouldUseCpuNoise,
seamlessXAxis,
seamlessYAxis,
@ -321,7 +320,6 @@ export const buildCanvasSDXLTextToImageGraph = (
vae: undefined, // option; set in addVAEToGraph
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
clip_skip: clipSkip,
};
graph.edges.push({

View File

@ -49,7 +49,6 @@ export const buildLinearSDXLImageToImageGraph = (
shouldFitToWidthHeight,
width,
height,
clipSkip,
shouldUseCpuNoise,
vaePrecision,
seamlessXAxis,
@ -349,7 +348,6 @@ export const buildLinearSDXLImageToImageGraph = (
vae: undefined,
controlnets: [],
loras: [],
clip_skip: clipSkip,
strength: strength,
init_image: initialImage.imageName,
positive_style_prompt: positiveStylePrompt,

View File

@ -38,7 +38,6 @@ export const buildLinearSDXLTextToImageGraph = (
steps,
width,
height,
clipSkip,
shouldUseCpuNoise,
vaePrecision,
seamlessXAxis,
@ -243,7 +242,6 @@ export const buildLinearSDXLTextToImageGraph = (
vae: undefined,
controlnets: [],
loras: [],
clip_skip: clipSkip,
positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt,
};

View File

@ -13,16 +13,16 @@ import ParamClipSkip from './ParamClipSkip';
const selector = createSelector(
stateSelector,
(state: RootState) => {
const { clipSkip, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise } =
const { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise } =
state.generation;
return { clipSkip, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise };
return { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise };
},
defaultSelectorOptions
);
export default function ParamAdvancedCollapse() {
const { clipSkip, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise } =
const { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise } =
useAppSelector(selector);
const { t } = useTranslation();
const activeLabel = useMemo(() => {
@ -34,7 +34,7 @@ export default function ParamAdvancedCollapse() {
activeLabel.push(t('parameters.gpuNoise'));
}
if (clipSkip > 0) {
if (clipSkip > 0 && model && model.base_model !== 'sdxl') {
activeLabel.push(
t('parameters.clipSkipWithLayerCount', { layerCount: clipSkip })
);
@ -49,15 +49,19 @@ export default function ParamAdvancedCollapse() {
}
return activeLabel.join(', ');
}, [clipSkip, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise, t]);
}, [clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise, t]);
return (
<IAICollapse label={t('common.advanced')} activeLabel={activeLabel}>
<Flex sx={{ flexDir: 'column', gap: 2 }}>
<ParamSeamless />
<Divider />
<ParamClipSkip />
<Divider pt={2} />
{model && model?.base_model !== 'sdxl' && (
<>
<ParamClipSkip />
<Divider pt={2} />
</>
)}
<ParamCpuNoiseToggle />
</Flex>
</IAICollapse>

View File

@ -42,6 +42,10 @@ export default function ParamClipSkip() {
return clipSkipMap[model.base_model].markers;
}, [model]);
if (model?.base_model === 'sdxl') {
return null;
}
return (
<IAIInformationalPopover details="clipSkip">
<IAISlider

View File

@ -1,4 +1,4 @@
import { Flex, Spacer, Text } from '@chakra-ui/react';
import { Box, Flex, Spacer, Text } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
@ -94,20 +94,22 @@ export default function ParamBoundingBoxSize() {
}}
>
<Flex alignItems="center" gap={2}>
<IAIInformationalPopover details="paramRatio">
<Text
sx={{
fontSize: 'sm',
width: 'full',
color: 'base.700',
_dark: {
color: 'base.300',
},
}}
>
{t('parameters.aspectRatio')}
</Text>
</IAIInformationalPopover>
<Box width="full">
<IAIInformationalPopover details="paramRatio">
<Text
sx={{
fontSize: 'sm',
width: 'full',
color: 'base.700',
_dark: {
color: 'base.300',
},
}}
>
{t('parameters.aspectRatio')}
</Text>
</IAIInformationalPopover>
</Box>
<Spacer />
<ParamAspectRatio />
<IAIIconButton

View File

@ -23,17 +23,18 @@ import { v4 as uuidv4 } from 'uuid';
const selector = createSelector(
[stateSelector],
({ controlNet }) => {
const { controlNets, isEnabled, isIPAdapterEnabled } = controlNet;
const { controlNets, isEnabled, isIPAdapterEnabled, ipAdapterInfo } =
controlNet;
const validControlNets = getValidControlNets(controlNets);
const isIPAdapterValid = ipAdapterInfo.model && ipAdapterInfo.adapterImage;
let activeLabel = undefined;
if (isEnabled && validControlNets.length > 0) {
activeLabel = `${validControlNets.length} ControlNet`;
}
if (isIPAdapterEnabled) {
if (isIPAdapterEnabled && isIPAdapterValid) {
if (activeLabel) {
activeLabel = `${activeLabel}, IP Adapter`;
} else {

View File

@ -61,7 +61,7 @@ const ParamIterations = ({ asSlider }: Props) => {
}, [dispatch, initial]);
return asSlider || shouldUseSliders ? (
<IAIInformationalPopover details="paramImages">
<IAIInformationalPopover details="paramIterations">
<IAISlider
label={t('parameters.iterations')}
step={step}
@ -77,7 +77,7 @@ const ParamIterations = ({ asSlider }: Props) => {
/>
</IAIInformationalPopover>
) : (
<IAIInformationalPopover details="paramImages">
<IAIInformationalPopover details="paramIterations">
<IAINumberInput
label={t('parameters.iterations')}
step={step}

View File

@ -1,6 +1,7 @@
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIInformationalPopover from 'common/components/IAIInformationalPopover';
import IAITextarea from 'common/components/IAITextarea';
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
@ -9,7 +10,6 @@ import { ChangeEvent, KeyboardEvent, memo, useCallback, useRef } from 'react';
import { flushSync } from 'react-dom';
import { useTranslation } from 'react-i18next';
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
import IAIInformationalPopover from 'common/components/IAIInformationalPopover';
const ParamNegativeConditioning = () => {
const negativePrompt = useAppSelector(
@ -76,13 +76,16 @@ const ParamNegativeConditioning = () => {
const isEmbeddingEnabled = useFeatureStatus('embedding').isFeatureEnabled;
return (
<FormControl>
<ParamEmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={handleSelectEmbedding}
>
<IAIInformationalPopover details="paramNegativeConditioning">
<IAIInformationalPopover
placement="right"
details="paramNegativeConditioning"
>
<FormControl>
<ParamEmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={handleSelectEmbedding}
>
<IAITextarea
id="negativePrompt"
name="negativePrompt"
@ -95,20 +98,20 @@ const ParamNegativeConditioning = () => {
minH={16}
{...(isEmbeddingEnabled && { onKeyDown: handleKeyDown })}
/>
</IAIInformationalPopover>
</ParamEmbeddingPopover>
{!isOpen && isEmbeddingEnabled && (
<Box
sx={{
position: 'absolute',
top: 0,
insetInlineEnd: 0,
}}
>
<AddEmbeddingButton onClick={onOpen} />
</Box>
)}
</FormControl>
</ParamEmbeddingPopover>
{!isOpen && isEmbeddingEnabled && (
<Box
sx={{
position: 'absolute',
top: 0,
insetInlineEnd: 0,
}}
>
<AddEmbeddingButton onClick={onOpen} />
</Box>
)}
</FormControl>
</IAIInformationalPopover>
);
};

View File

@ -104,13 +104,16 @@ const ParamPositiveConditioning = () => {
return (
<Box position="relative">
<FormControl>
<ParamEmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={handleSelectEmbedding}
>
<IAIInformationalPopover details="paramPositiveConditioning">
<IAIInformationalPopover
placement="right"
details="paramPositiveConditioning"
>
<FormControl>
<ParamEmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={handleSelectEmbedding}
>
<IAITextarea
id="prompt"
name="prompt"
@ -122,9 +125,9 @@ const ParamPositiveConditioning = () => {
resize="vertical"
minH={32}
/>
</IAIInformationalPopover>
</ParamEmbeddingPopover>
</FormControl>
</ParamEmbeddingPopover>
</FormControl>
</IAIInformationalPopover>
{!isOpen && isEmbeddingEnabled && (
<Box
sx={{

View File

@ -1,4 +1,4 @@
import { Flex, Spacer, Text } from '@chakra-ui/react';
import { Box, Flex, Spacer, Text } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
@ -84,20 +84,21 @@ export default function ParamSize() {
}}
>
<Flex alignItems="center" gap={2}>
<IAIInformationalPopover details="paramRatio">
<Text
sx={{
fontSize: 'sm',
width: 'full',
color: 'base.700',
_dark: {
color: 'base.300',
},
}}
>
{t('parameters.aspectRatio')}
</Text>
</IAIInformationalPopover>
<Box width="full">
<IAIInformationalPopover details="paramRatio">
<Text
sx={{
fontSize: 'sm',
color: 'base.700',
_dark: {
color: 'base.300',
},
}}
>
{t('parameters.aspectRatio')}
</Text>
</IAIInformationalPopover>
</Box>
<Spacer />
<ParamAspectRatio />
<IAIIconButton

View File

@ -119,8 +119,8 @@ const ParamMainModelSelect = () => {
data={[]}
/>
) : (
<IAIInformationalPopover details="paramModel" placement="bottom">
<Flex w="100%" alignItems="center" gap={3}>
<Flex w="100%" alignItems="center" gap={3}>
<IAIInformationalPopover details="paramModel" placement="bottom">
<IAIMantineSearchableSelect
tooltip={selectedModel?.description}
label={t('modelManager.model')}
@ -134,13 +134,13 @@ const ParamMainModelSelect = () => {
onChange={handleChangeModel}
w="100%"
/>
{isSyncModelEnabled && (
<Box mt={7}>
<SyncModelsButton iconMode />
</Box>
)}
</Flex>
</IAIInformationalPopover>
</IAIInformationalPopover>
{isSyncModelEnabled && (
<Box mt={7}>
<SyncModelsButton iconMode />
</Box>
)}
</Flex>
);
};

View File

@ -21,7 +21,7 @@ import {
loraModelsAdapter,
useGetLoRAModelsQuery,
} from '../../../services/api/endpoints/models';
import { loraRecalled } from '../../lora/store/loraSlice';
import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice';
import { initialImageSelected, modelSelected } from '../store/actions';
import {
setCfgScale,
@ -509,6 +509,7 @@ export const useRecallParameters = () => {
dispatch(setRefinerStart(refiner_start));
}
dispatch(lorasCleared());
loras?.forEach((lora) => {
const result = prepareLoRAMetadataItem(lora);
if (result.lora) {

View File

@ -12,12 +12,12 @@ type Props = {
const CancelCurrentQueueItemButton = ({ asIconButton, sx }: Props) => {
const { t } = useTranslation();
const { cancelQueueItem, isLoading, currentQueueItemId } =
const { cancelQueueItem, isLoading, isDisabled } =
useCancelCurrentQueueItem();
return (
<QueueButton
isDisabled={!currentQueueItemId}
isDisabled={isDisabled}
isLoading={isLoading}
asIconButton={asIconButton}
label={t('queue.cancel')}

View File

@ -0,0 +1,22 @@
import IAIButton from 'common/components/IAIButton';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { useClearInvocationCache } from '../hooks/useClearInvocationCache';
const ClearInvocationCacheButton = () => {
const { t } = useTranslation();
const { clearInvocationCache, isDisabled, isLoading } =
useClearInvocationCache();
return (
<IAIButton
isDisabled={isDisabled}
isLoading={isLoading}
onClick={clearInvocationCache}
>
{t('invocationCache.clear')}
</IAIButton>
);
};
export default memo(ClearInvocationCacheButton);

View File

@ -13,7 +13,7 @@ type Props = {
const ClearQueueButton = ({ asIconButton, sx }: Props) => {
const { t } = useTranslation();
const { clearQueue, isLoading, queueStatus } = useClearQueue();
const { clearQueue, isLoading, isDisabled } = useClearQueue();
return (
<IAIAlertDialog
@ -22,7 +22,7 @@ const ClearQueueButton = ({ asIconButton, sx }: Props) => {
acceptButtonText={t('queue.clear')}
triggerComponent={
<QueueButton
isDisabled={!queueStatus?.queue.total}
isDisabled={isDisabled}
isLoading={isLoading}
asIconButton={asIconButton}
label={t('queue.clear')}

View File

@ -0,0 +1,55 @@
import { ButtonGroup } from '@chakra-ui/react';
import { useAppSelector } from 'app/store/storeHooks';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo';
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
import ClearInvocationCacheButton from './ClearInvocationCacheButton';
import ToggleInvocationCacheButton from './ToggleInvocationCacheButton';
import StatusStatGroup from './common/StatusStatGroup';
import StatusStatItem from './common/StatusStatItem';
const InvocationCacheStatus = () => {
const { t } = useTranslation();
const isConnected = useAppSelector((state) => state.system.isConnected);
const { data: queueStatus } = useGetQueueStatusQuery(undefined);
const { data: cacheStatus } = useGetInvocationCacheStatusQuery(undefined, {
pollingInterval:
isConnected &&
queueStatus?.processor.is_started &&
queueStatus?.queue.pending > 0
? 5000
: 0,
});
return (
<StatusStatGroup>
<StatusStatItem
isDisabled={!cacheStatus?.enabled}
label={t('invocationCache.cacheSize')}
value={cacheStatus?.size ?? 0}
/>
<StatusStatItem
isDisabled={!cacheStatus?.enabled}
label={t('invocationCache.hits')}
value={cacheStatus?.hits ?? 0}
/>
<StatusStatItem
isDisabled={!cacheStatus?.enabled}
label={t('invocationCache.misses')}
value={cacheStatus?.misses ?? 0}
/>
<StatusStatItem
isDisabled={!cacheStatus?.enabled}
label={t('invocationCache.maxCacheSize')}
value={cacheStatus?.max_size ?? 0}
/>
<ButtonGroup w={24} orientation="vertical" size="xs">
<ClearInvocationCacheButton />
<ToggleInvocationCacheButton />
</ButtonGroup>
</StatusStatGroup>
);
};
export default memo(InvocationCacheStatus);

View File

@ -10,14 +10,14 @@ type Props = {
const PauseProcessorButton = ({ asIconButton }: Props) => {
const { t } = useTranslation();
const { pauseProcessor, isLoading, isStarted } = usePauseProcessor();
const { pauseProcessor, isLoading, isDisabled } = usePauseProcessor();
return (
<QueueButton
asIconButton={asIconButton}
label={t('queue.pause')}
tooltip={t('queue.pauseTooltip')}
isDisabled={!isStarted}
isDisabled={isDisabled}
isLoading={isLoading}
icon={<FaPause />}
onClick={pauseProcessor}

View File

@ -10,11 +10,11 @@ type Props = {
const PruneQueueButton = ({ asIconButton }: Props) => {
const { t } = useTranslation();
const { pruneQueue, isLoading, finishedCount } = usePruneQueue();
const { pruneQueue, isLoading, finishedCount, isDisabled } = usePruneQueue();
return (
<QueueButton
isDisabled={!finishedCount}
isDisabled={isDisabled}
isLoading={isLoading}
asIconButton={asIconButton}
label={t('queue.prune')}

View File

@ -1,9 +1,10 @@
import { ChakraProps } from '@chakra-ui/react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { useQueueBack } from '../hooks/useQueueBack';
import EnqueueButtonTooltip from './QueueButtonTooltip';
import QueueButton from './common/QueueButton';
import { ChakraProps } from '@chakra-ui/react';
import GreyscaleInvokeAIIcon from 'common/components/GreyscaleInvokeAIIcon';
type Props = {
asIconButton?: boolean;
@ -23,6 +24,7 @@ const QueueBackButton = ({ asIconButton, sx }: Props) => {
onClick={queueBack}
tooltip={<EnqueueButtonTooltip />}
sx={sx}
icon={asIconButton ? <GreyscaleInvokeAIIcon /> : undefined}
/>
);
};

View File

@ -1,38 +1,39 @@
import { Stat, StatGroup, StatLabel, StatNumber } from '@chakra-ui/react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
import StatusStatGroup from './common/StatusStatGroup';
import StatusStatItem from './common/StatusStatItem';
const QueueStatus = () => {
const { data: queueStatus } = useGetQueueStatusQuery();
const { t } = useTranslation();
return (
<StatGroup alignItems="center" justifyContent="center" w="full" h="full">
<Stat w={24}>
<StatLabel>{t('queue.in_progress')}</StatLabel>
<StatNumber>{queueStatus?.queue.in_progress ?? 0}</StatNumber>
</Stat>
<Stat w={24}>
<StatLabel>{t('queue.pending')}</StatLabel>
<StatNumber>{queueStatus?.queue.pending ?? 0}</StatNumber>
</Stat>
<Stat w={24}>
<StatLabel>{t('queue.completed')}</StatLabel>
<StatNumber>{queueStatus?.queue.completed ?? 0}</StatNumber>
</Stat>
<Stat w={24}>
<StatLabel>{t('queue.failed')}</StatLabel>
<StatNumber>{queueStatus?.queue.failed ?? 0}</StatNumber>
</Stat>
<Stat w={24}>
<StatLabel>{t('queue.canceled')}</StatLabel>
<StatNumber>{queueStatus?.queue.canceled ?? 0}</StatNumber>
</Stat>
<Stat w={24}>
<StatLabel>{t('queue.total')}</StatLabel>
<StatNumber>{queueStatus?.queue.total}</StatNumber>
</Stat>
</StatGroup>
<StatusStatGroup>
<StatusStatItem
label={t('queue.in_progress')}
value={queueStatus?.queue.in_progress ?? 0}
/>
<StatusStatItem
label={t('queue.pending')}
value={queueStatus?.queue.pending ?? 0}
/>
<StatusStatItem
label={t('queue.completed')}
value={queueStatus?.queue.completed ?? 0}
/>
<StatusStatItem
label={t('queue.failed')}
value={queueStatus?.queue.failed ?? 0}
/>
<StatusStatItem
label={t('queue.canceled')}
value={queueStatus?.queue.canceled ?? 0}
/>
<StatusStatItem
label={t('queue.total')}
value={queueStatus?.queue.total ?? 0}
/>
</StatusStatGroup>
);
};

View File

@ -1,16 +1,14 @@
import { Box, ButtonGroup, Flex } from '@chakra-ui/react';
import { Box, Flex } from '@chakra-ui/react';
import { memo } from 'react';
import ClearQueueButton from './ClearQueueButton';
import PauseProcessorButton from './PauseProcessorButton';
import PruneQueueButton from './PruneQueueButton';
import { useFeatureStatus } from '../../system/hooks/useFeatureStatus';
import InvocationCacheStatus from './InvocationCacheStatus';
import QueueList from './QueueList/QueueList';
import QueueStatus from './QueueStatus';
import ResumeProcessorButton from './ResumeProcessorButton';
import { useFeatureStatus } from '../../system/hooks/useFeatureStatus';
import QueueTabQueueControls from './QueueTabQueueControls';
const QueueTabContent = () => {
const isPauseEnabled = useFeatureStatus('pauseQueue').isFeatureEnabled;
const isResumeEnabled = useFeatureStatus('resumeQueue').isFeatureEnabled;
const isInvocationCacheEnabled =
useFeatureStatus('invocationCache').isFeatureEnabled;
return (
<Flex
@ -23,33 +21,9 @@ const QueueTabContent = () => {
gap={2}
>
<Flex gap={2} w="full">
<Flex layerStyle="second" borderRadius="base" p={2} gap={2}>
{isPauseEnabled || isResumeEnabled ? (
<ButtonGroup w={28} orientation="vertical" isAttached size="sm">
{isResumeEnabled ? <ResumeProcessorButton /> : <></>}
{isPauseEnabled ? <PauseProcessorButton /> : <></>}
</ButtonGroup>
) : (
<></>
)}
<ButtonGroup w={28} orientation="vertical" isAttached size="sm">
<PruneQueueButton />
<ClearQueueButton />
</ButtonGroup>
</Flex>
<Flex
layerStyle="second"
borderRadius="base"
flexDir="column"
py={2}
px={3}
gap={2}
>
<QueueStatus />
</Flex>
{/* <QueueStatusCard />
<CurrentQueueItemCard />
<NextQueueItemCard /> */}
<QueueTabQueueControls />
<QueueStatus />
{isInvocationCacheEnabled && <InvocationCacheStatus />}
</Flex>
<Box layerStyle="second" p={2} borderRadius="base" w="full" h="full">
<QueueList />

View File

@ -0,0 +1,30 @@
import { ButtonGroup, Flex } from '@chakra-ui/react';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import ClearQueueButton from './ClearQueueButton';
import PauseProcessorButton from './PauseProcessorButton';
import PruneQueueButton from './PruneQueueButton';
import ResumeProcessorButton from './ResumeProcessorButton';
const QueueTabQueueControls = () => {
const isPauseEnabled = useFeatureStatus('pauseQueue').isFeatureEnabled;
const isResumeEnabled = useFeatureStatus('resumeQueue').isFeatureEnabled;
return (
<Flex layerStyle="second" borderRadius="base" p={2} gap={2}>
{isPauseEnabled || isResumeEnabled ? (
<ButtonGroup w={28} orientation="vertical" isAttached size="sm">
{isResumeEnabled ? <ResumeProcessorButton /> : <></>}
{isPauseEnabled ? <PauseProcessorButton /> : <></>}
</ButtonGroup>
) : (
<></>
)}
<ButtonGroup w={28} orientation="vertical" isAttached size="sm">
<PruneQueueButton />
<ClearQueueButton />
</ButtonGroup>
</Flex>
);
};
export default memo(QueueTabQueueControls);

View File

@ -10,14 +10,14 @@ type Props = {
const ResumeProcessorButton = ({ asIconButton }: Props) => {
const { t } = useTranslation();
const { resumeProcessor, isLoading, isStarted } = useResumeProcessor();
const { resumeProcessor, isLoading, isDisabled } = useResumeProcessor();
return (
<QueueButton
asIconButton={asIconButton}
label={t('queue.resume')}
tooltip={t('queue.resumeTooltip')}
isDisabled={isStarted}
isDisabled={isDisabled}
isLoading={isLoading}
icon={<FaPlay />}
onClick={resumeProcessor}

View File

@ -0,0 +1,47 @@
import IAIButton from 'common/components/IAIButton';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo';
import { useDisableInvocationCache } from '../hooks/useDisableInvocationCache';
import { useEnableInvocationCache } from '../hooks/useEnableInvocationCache';
const ToggleInvocationCacheButton = () => {
const { t } = useTranslation();
const { data: cacheStatus } = useGetInvocationCacheStatusQuery();
const {
enableInvocationCache,
isDisabled: isEnableDisabled,
isLoading: isEnableLoading,
} = useEnableInvocationCache();
const {
disableInvocationCache,
isDisabled: isDisableDisabled,
isLoading: isDisableLoading,
} = useDisableInvocationCache();
if (cacheStatus?.enabled) {
return (
<IAIButton
isDisabled={isDisableDisabled}
isLoading={isDisableLoading}
onClick={disableInvocationCache}
>
{t('invocationCache.disable')}
</IAIButton>
);
}
return (
<IAIButton
isDisabled={isEnableDisabled}
isLoading={isEnableLoading}
onClick={enableInvocationCache}
>
{t('invocationCache.enable')}
</IAIButton>
);
};
export default memo(ToggleInvocationCacheButton);

View File

@ -1,27 +0,0 @@
import { ButtonGroup, ButtonGroupProps, Flex } from '@chakra-ui/react';
import { memo } from 'react';
import ClearQueueButton from './ClearQueueButton';
import PauseProcessorButton from './PauseProcessorButton';
import PruneQueueButton from './PruneQueueButton';
import ResumeProcessorButton from './ResumeProcessorButton';
type Props = ButtonGroupProps & {
asIconButtons?: boolean;
};
const VerticalQueueControls = ({ asIconButtons, ...rest }: Props) => {
return (
<Flex flexDir="column" gap={2}>
<ButtonGroup w="full" isAttached {...rest}>
<ResumeProcessorButton asIconButton={asIconButtons} />
<PauseProcessorButton asIconButton={asIconButtons} />
</ButtonGroup>
<ButtonGroup w="full" isAttached {...rest}>
<PruneQueueButton asIconButton={asIconButtons} />
<ClearQueueButton asIconButton={asIconButtons} />
</ButtonGroup>
</Flex>
);
};
export default memo(VerticalQueueControls);

View File

@ -0,0 +1,22 @@
import { StatGroup, StatGroupProps } from '@chakra-ui/react';
import { memo } from 'react';
const StatusStatGroup = ({ children, ...rest }: StatGroupProps) => (
<StatGroup
alignItems="center"
justifyContent="center"
w="full"
h="full"
layerStyle="second"
borderRadius="base"
py={2}
px={3}
gap={6}
flexWrap="nowrap"
{...rest}
>
{children}
</StatGroup>
);
export default memo(StatusStatGroup);

View File

@ -0,0 +1,47 @@
import {
ChakraProps,
Stat,
StatLabel,
StatNumber,
StatProps,
} from '@chakra-ui/react';
import { memo } from 'react';
const sx: ChakraProps['sx'] = {
'&[aria-disabled="true"]': {
color: 'base.400',
_dark: {
color: 'base.500',
},
},
};
type Props = Omit<StatProps, 'children'> & {
label: string;
value: string | number;
isDisabled?: boolean;
};
const StatusStatItem = ({
label,
value,
isDisabled = false,
...rest
}: Props) => (
<Stat
flexGrow={1}
textOverflow="ellipsis"
overflow="hidden"
whiteSpace="nowrap"
aria-disabled={isDisabled}
sx={sx}
{...rest}
>
<StatLabel textOverflow="ellipsis" overflow="hidden" whiteSpace="nowrap">
{label}
</StatLabel>
<StatNumber>{value}</StatNumber>
</Stat>
);
export default memo(StatusStatItem);

View File

@ -1,4 +1,4 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
@ -8,6 +8,7 @@ import {
} from 'services/api/endpoints/queue';
export const useCancelBatch = (batch_id: string) => {
const isConnected = useAppSelector((state) => state.system.isConnected);
const { isCanceled } = useGetBatchStatusQuery(
{ batch_id },
{
@ -49,5 +50,5 @@ export const useCancelBatch = (batch_id: string) => {
}
}, [batch_id, dispatch, isCanceled, t, trigger]);
return { cancelBatch, isLoading, isCanceled };
return { cancelBatch, isLoading, isCanceled, isDisabled: !isConnected };
};

View File

@ -1,4 +1,4 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@ -8,6 +8,7 @@ import {
} from 'services/api/endpoints/queue';
export const useCancelCurrentQueueItem = () => {
const isConnected = useAppSelector((state) => state.system.isConnected);
const { data: queueStatus } = useGetQueueStatusQuery();
const [trigger, { isLoading }] = useCancelQueueItemMutation();
const dispatch = useAppDispatch();
@ -38,5 +39,15 @@ export const useCancelCurrentQueueItem = () => {
}
}, [currentQueueItemId, dispatch, t, trigger]);
return { cancelQueueItem, isLoading, currentQueueItemId };
const isDisabled = useMemo(
() => !isConnected || !currentQueueItemId,
[isConnected, currentQueueItemId]
);
return {
cancelQueueItem,
isLoading,
currentQueueItemId,
isDisabled,
};
};

View File

@ -1,10 +1,11 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCancelQueueItemMutation } from 'services/api/endpoints/queue';
export const useCancelQueueItem = (item_id: number) => {
const isConnected = useAppSelector((state) => state.system.isConnected);
const [trigger, { isLoading }] = useCancelQueueItemMutation();
const dispatch = useAppDispatch();
const { t } = useTranslation();
@ -27,5 +28,5 @@ export const useCancelQueueItem = (item_id: number) => {
}
}, [dispatch, item_id, t, trigger]);
return { cancelQueueItem, isLoading };
return { cancelQueueItem, isLoading, isDisabled: !isConnected };
};

View File

@ -0,0 +1,48 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import {
useClearInvocationCacheMutation,
useGetInvocationCacheStatusQuery,
} from 'services/api/endpoints/appInfo';
export const useClearInvocationCache = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { data: cacheStatus } = useGetInvocationCacheStatusQuery();
const isConnected = useAppSelector((state) => state.system.isConnected);
const [trigger, { isLoading }] = useClearInvocationCacheMutation({
fixedCacheKey: 'clearInvocationCache',
});
const isDisabled = useMemo(
() => !cacheStatus?.size || !isConnected,
[cacheStatus?.size, isConnected]
);
const clearInvocationCache = useCallback(async () => {
if (isDisabled) {
return;
}
try {
await trigger().unwrap();
dispatch(
addToast({
title: t('invocationCache.clearSucceeded'),
status: 'success',
})
);
} catch {
dispatch(
addToast({
title: t('invocationCache.clearFailed'),
status: 'error',
})
);
}
}, [isDisabled, trigger, dispatch, t]);
return { clearInvocationCache, isLoading, cacheStatus, isDisabled };
};

View File

@ -1,6 +1,6 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback } from 'react';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import {
useClearQueueMutation,
@ -10,10 +10,9 @@ import { listCursorChanged, listPriorityChanged } from '../store/queueSlice';
export const useClearQueue = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { data: queueStatus } = useGetQueueStatusQuery();
const isConnected = useAppSelector((state) => state.system.isConnected);
const [trigger, { isLoading }] = useClearQueueMutation({
fixedCacheKey: 'clearQueue',
});
@ -43,5 +42,10 @@ export const useClearQueue = () => {
}
}, [queueStatus?.queue.total, trigger, dispatch, t]);
return { clearQueue, isLoading, queueStatus };
const isDisabled = useMemo(
() => !isConnected || !queueStatus?.queue.total,
[isConnected, queueStatus?.queue.total]
);
return { clearQueue, isLoading, queueStatus, isDisabled };
};

View File

@ -0,0 +1,48 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import {
useDisableInvocationCacheMutation,
useGetInvocationCacheStatusQuery,
} from 'services/api/endpoints/appInfo';
export const useDisableInvocationCache = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { data: cacheStatus } = useGetInvocationCacheStatusQuery();
const isConnected = useAppSelector((state) => state.system.isConnected);
const [trigger, { isLoading }] = useDisableInvocationCacheMutation({
fixedCacheKey: 'disableInvocationCache',
});
const isDisabled = useMemo(
() => !cacheStatus?.enabled || !isConnected || cacheStatus?.max_size === 0,
[cacheStatus?.enabled, cacheStatus?.max_size, isConnected]
);
const disableInvocationCache = useCallback(async () => {
if (isDisabled) {
return;
}
try {
await trigger().unwrap();
dispatch(
addToast({
title: t('invocationCache.disableSucceeded'),
status: 'success',
})
);
} catch {
dispatch(
addToast({
title: t('invocationCache.disableFailed'),
status: 'error',
})
);
}
}, [isDisabled, trigger, dispatch, t]);
return { disableInvocationCache, isLoading, cacheStatus, isDisabled };
};

View File

@ -0,0 +1,48 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import {
useEnableInvocationCacheMutation,
useGetInvocationCacheStatusQuery,
} from 'services/api/endpoints/appInfo';
export const useEnableInvocationCache = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { data: cacheStatus } = useGetInvocationCacheStatusQuery();
const isConnected = useAppSelector((state) => state.system.isConnected);
const [trigger, { isLoading }] = useEnableInvocationCacheMutation({
fixedCacheKey: 'enableInvocationCache',
});
const isDisabled = useMemo(
() => cacheStatus?.enabled || !isConnected || cacheStatus?.max_size === 0,
[cacheStatus?.enabled, cacheStatus?.max_size, isConnected]
);
const enableInvocationCache = useCallback(async () => {
if (isDisabled) {
return;
}
try {
await trigger().unwrap();
dispatch(
addToast({
title: t('invocationCache.enableSucceeded'),
status: 'success',
})
);
} catch {
dispatch(
addToast({
title: t('invocationCache.enableFailed'),
status: 'error',
})
);
}
}, [isDisabled, trigger, dispatch, t]);
return { enableInvocationCache, isLoading, cacheStatus, isDisabled };
};

View File

@ -1,4 +1,4 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@ -10,6 +10,7 @@ import {
export const usePauseProcessor = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isConnected = useAppSelector((state) => state.system.isConnected);
const { data: queueStatus } = useGetQueueStatusQuery();
const [trigger, { isLoading }] = usePauseProcessorMutation({
fixedCacheKey: 'pauseProcessor',
@ -42,5 +43,10 @@ export const usePauseProcessor = () => {
}
}, [isStarted, trigger, dispatch, t]);
return { pauseProcessor, isLoading, isStarted };
const isDisabled = useMemo(
() => !isConnected || !isStarted,
[isConnected, isStarted]
);
return { pauseProcessor, isLoading, isStarted, isDisabled };
};

View File

@ -1,6 +1,6 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback } from 'react';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import {
useGetQueueStatusQuery,
@ -11,6 +11,7 @@ import { listCursorChanged, listPriorityChanged } from '../store/queueSlice';
export const usePruneQueue = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isConnected = useAppSelector((state) => state.system.isConnected);
const [trigger, { isLoading }] = usePruneQueueMutation({
fixedCacheKey: 'pruneQueue',
});
@ -51,5 +52,10 @@ export const usePruneQueue = () => {
}
}, [finishedCount, trigger, dispatch, t]);
return { pruneQueue, isLoading, finishedCount };
const isDisabled = useMemo(
() => !isConnected || !finishedCount,
[finishedCount, isConnected]
);
return { pruneQueue, isLoading, finishedCount, isDisabled };
};

Some files were not shown because too many files have changed in this diff Show More