mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into ryan/t2i-adapter
This commit is contained in:
commit
cd8c53c50d
@ -159,7 +159,7 @@ groups in `invokeia.yaml`:
|
||||
| `host` | `localhost` | Name or IP address of the network interface that the web server will listen on |
|
||||
| `port` | `9090` | Network port number that the web server will listen on |
|
||||
| `allow_origins` | `[]` | A list of host names or IP addresses that are allowed to connect to the InvokeAI API in the format `['host1','host2',...]` |
|
||||
| `allow_credentials | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
|
||||
| `allow_credentials` | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
|
||||
| `allow_methods` | `*` | List of HTTP methods ("GET", "POST") that the web server is allowed to use when accessing the API |
|
||||
| `allow_headers` | `*` | List of HTTP headers that the web server will accept when accessing the API |
|
||||
|
||||
|
@ -17,9 +17,10 @@ echo 6. Change InvokeAI startup options
|
||||
echo 7. Re-run the configure script to fix a broken install or to complete a major upgrade
|
||||
echo 8. Open the developer console
|
||||
echo 9. Update InvokeAI
|
||||
echo 10. Command-line help
|
||||
echo 10. Run the InvokeAI image database maintenance script
|
||||
echo 11. Command-line help
|
||||
echo Q - Quit
|
||||
set /P choice="Please enter 1-10, Q: [1] "
|
||||
set /P choice="Please enter 1-11, Q: [1] "
|
||||
if not defined choice set choice=1
|
||||
IF /I "%choice%" == "1" (
|
||||
echo Starting the InvokeAI browser-based UI..
|
||||
@ -58,8 +59,11 @@ IF /I "%choice%" == "1" (
|
||||
echo Running invokeai-update...
|
||||
python -m invokeai.frontend.install.invokeai_update
|
||||
) ELSE IF /I "%choice%" == "10" (
|
||||
echo Running the db maintenance script...
|
||||
python .venv\Scripts\invokeai-db-maintenance.exe
|
||||
) ELSE IF /I "%choice%" == "11" (
|
||||
echo Displaying command line help...
|
||||
python .venv\Scripts\invokeai.exe --help %*
|
||||
python .venv\Scripts\invokeai-web.exe --help %*
|
||||
pause
|
||||
exit /b
|
||||
) ELSE IF /I "%choice%" == "q" (
|
||||
|
@ -97,13 +97,13 @@ do_choice() {
|
||||
;;
|
||||
10)
|
||||
clear
|
||||
printf "Command-line help\n"
|
||||
invokeai --help
|
||||
printf "Running the db maintenance script\n"
|
||||
invokeai-db-maintenance --root ${INVOKEAI_ROOT}
|
||||
;;
|
||||
"HELP 1")
|
||||
11)
|
||||
clear
|
||||
printf "Command-line help\n"
|
||||
invokeai --help
|
||||
invokeai-web --help
|
||||
;;
|
||||
*)
|
||||
clear
|
||||
@ -125,7 +125,10 @@ do_dialog() {
|
||||
6 "Change InvokeAI startup options"
|
||||
7 "Re-run the configure script to fix a broken install or to complete a major upgrade"
|
||||
8 "Open the developer console"
|
||||
9 "Update InvokeAI")
|
||||
9 "Update InvokeAI"
|
||||
10 "Run the InvokeAI image database maintenance script"
|
||||
11 "Command-line help"
|
||||
)
|
||||
|
||||
choice=$(dialog --clear \
|
||||
--backtitle "\Zb\Zu\Z3InvokeAI" \
|
||||
@ -157,9 +160,10 @@ do_line_input() {
|
||||
printf "7: Re-run the configure script to fix a broken install\n"
|
||||
printf "8: Open the developer console\n"
|
||||
printf "9: Update InvokeAI\n"
|
||||
printf "10: Command-line help\n"
|
||||
printf "10: Run the InvokeAI image database maintenance script\n"
|
||||
printf "11: Command-line help\n"
|
||||
printf "Q: Quit\n\n"
|
||||
read -p "Please enter 1-10, Q: [1] " yn
|
||||
read -p "Please enter 1-11, Q: [1] " yn
|
||||
choice=${yn:='1'}
|
||||
do_choice $choice
|
||||
clear
|
||||
|
@ -7,6 +7,7 @@ from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
@ -113,3 +114,33 @@ async def set_log_level(
|
||||
async def clear_invocation_cache() -> None:
|
||||
"""Clears the invocation cache"""
|
||||
ApiDependencies.invoker.services.invocation_cache.clear()
|
||||
|
||||
|
||||
@app_router.put(
|
||||
"/invocation_cache/enable",
|
||||
operation_id="enable_invocation_cache",
|
||||
responses={200: {"description": "The operation was successful"}},
|
||||
)
|
||||
async def enable_invocation_cache() -> None:
|
||||
"""Clears the invocation cache"""
|
||||
ApiDependencies.invoker.services.invocation_cache.enable()
|
||||
|
||||
|
||||
@app_router.put(
|
||||
"/invocation_cache/disable",
|
||||
operation_id="disable_invocation_cache",
|
||||
responses={200: {"description": "The operation was successful"}},
|
||||
)
|
||||
async def disable_invocation_cache() -> None:
|
||||
"""Clears the invocation cache"""
|
||||
ApiDependencies.invoker.services.invocation_cache.disable()
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/invocation_cache/status",
|
||||
operation_id="get_invocation_cache_status",
|
||||
responses={200: {"model": InvocationCacheStatus}},
|
||||
)
|
||||
async def get_invocation_cache_status() -> InvocationCacheStatus:
|
||||
"""Clears the invocation cache"""
|
||||
return ApiDependencies.invoker.services.invocation_cache.get_status()
|
||||
|
@ -3,16 +3,19 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event
|
||||
from fastapi_socketio import SocketManager
|
||||
from socketio import ASGIApp, AsyncServer
|
||||
|
||||
from ..services.events import EventServiceBase
|
||||
|
||||
|
||||
class SocketIO:
|
||||
__sio: SocketManager
|
||||
__sio: AsyncServer
|
||||
__app: ASGIApp
|
||||
|
||||
def __init__(self, app: FastAPI):
|
||||
self.__sio = SocketManager(app=app)
|
||||
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
|
||||
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="socket.io")
|
||||
app.mount("/ws", self.__app)
|
||||
|
||||
self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
|
||||
self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)
|
||||
|
@ -38,7 +38,6 @@ from .baseinvocation import (
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -100,7 +99,7 @@ class ControlNetInvocation(BaseInvocation):
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
|
||||
default=1.0, description="The weight given to the ControlNet"
|
||||
)
|
||||
begin_step_percent: float = InputField(
|
||||
default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
|
||||
|
@ -58,9 +58,7 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The IP-Adapter image prompt.")
|
||||
ip_adapter_model: IPAdapterModelField = InputField(
|
||||
description="The IP-Adapter model.",
|
||||
title="IP-Adapter Model",
|
||||
input=Input.Direct,
|
||||
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
|
||||
)
|
||||
|
||||
# weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float)
|
||||
|
@ -1,12 +1,14 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from contextlib import ExitStack
|
||||
from functools import singledispatchmethod
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from diffusers import AutoencoderKL, AutoencoderTiny
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
from diffusers.models.adapter import FullAdapterXL, T2IAdapter
|
||||
@ -211,7 +213,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection, ui_order=3)
|
||||
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||
cfg_scale: Union[float, List[float]] = InputField(
|
||||
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type=UIType.Float, title="CFG Scale"
|
||||
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, title="CFG Scale"
|
||||
)
|
||||
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
@ -221,7 +223,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2)
|
||||
control: Union[ControlField, list[ControlField]] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.control,
|
||||
input=Input.Connection,
|
||||
ui_order=5,
|
||||
)
|
||||
@ -955,8 +956,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
# non_noised_latents_from_image
|
||||
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
|
||||
with torch.inference_mode():
|
||||
image_tensor_dist = vae.encode(image_tensor).latent_dist
|
||||
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
|
||||
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
|
||||
|
||||
latents = vae.config.scaling_factor * latents
|
||||
latents = latents.to(dtype=orig_dtype)
|
||||
@ -983,6 +983,18 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
@singledispatchmethod
|
||||
@staticmethod
|
||||
def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
|
||||
image_tensor_dist = vae.encode(image_tensor).latent_dist
|
||||
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
|
||||
return latents
|
||||
|
||||
@_encode_to_tensor.register
|
||||
@staticmethod
|
||||
def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
|
||||
return vae.encode(image_tensor).latents
|
||||
|
||||
|
||||
@invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents", version="1.0.0")
|
||||
class BlendLatentsInvocation(BaseInvocation):
|
||||
|
@ -42,7 +42,8 @@ class CoreMetadata(BaseModelExcludeNull):
|
||||
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
|
||||
steps: int = Field(description="The number of steps used for inference")
|
||||
scheduler: str = Field(description="The scheduler used for inference")
|
||||
clip_skip: int = Field(
|
||||
clip_skip: Optional[int] = Field(
|
||||
default=None,
|
||||
description="The number of skipped CLIP layers",
|
||||
)
|
||||
model: MainModelField = Field(description="The main model used for inference")
|
||||
@ -116,7 +117,8 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
||||
cfg_scale: float = InputField(description="The classifier-free guidance scale parameter")
|
||||
steps: int = InputField(description="The number of steps used for inference")
|
||||
scheduler: str = InputField(description="The scheduler used for inference")
|
||||
clip_skip: int = InputField(
|
||||
clip_skip: Optional[int] = Field(
|
||||
default=None,
|
||||
description="The number of skipped CLIP layers",
|
||||
)
|
||||
model: MainModelField = InputField(description="The main model used for inference")
|
||||
|
@ -166,7 +166,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
default=7.5,
|
||||
ge=1,
|
||||
description=FieldDescriptions.cfg_scale,
|
||||
ui_type=UIType.Float,
|
||||
)
|
||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct, ui_type=UIType.Scheduler
|
||||
@ -179,7 +178,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.control,
|
||||
ui_type=UIType.Control,
|
||||
)
|
||||
# seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
# seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
|
@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
|
||||
|
||||
|
||||
class InvocationCacheBase(ABC):
|
||||
@ -32,7 +33,7 @@ class InvocationCacheBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, key: Union[int, str]) -> None:
|
||||
"""Deleteds an invocation output from the cache"""
|
||||
"""Deletes an invocation output from the cache"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -44,3 +45,18 @@ class InvocationCacheBase(ABC):
|
||||
def create_key(self, invocation: BaseInvocation) -> int:
|
||||
"""Gets the key for the invocation's cache item"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def disable(self) -> None:
|
||||
"""Disables the cache, overriding the max cache size"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def enable(self) -> None:
|
||||
"""Enables the cache, letting the the max cache size take effect"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_status(self) -> InvocationCacheStatus:
|
||||
"""Returns the status of the cache"""
|
||||
pass
|
||||
|
@ -0,0 +1,9 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class InvocationCacheStatus(BaseModel):
|
||||
size: int = Field(description="The current size of the invocation cache")
|
||||
hits: int = Field(description="The number of cache hits")
|
||||
misses: int = Field(description="The number of cache misses")
|
||||
enabled: bool = Field(description="Whether the invocation cache is enabled")
|
||||
max_size: int = Field(description="The maximum size of the invocation cache")
|
@ -3,18 +3,25 @@ from typing import Optional, Union
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
|
||||
class MemoryInvocationCache(InvocationCacheBase):
|
||||
__cache: dict[Union[int, str], tuple[BaseInvocationOutput, str]]
|
||||
__max_cache_size: int
|
||||
__disabled: bool
|
||||
__hits: int
|
||||
__misses: int
|
||||
__cache_ids: Queue
|
||||
__invoker: Invoker
|
||||
|
||||
def __init__(self, max_cache_size: int = 0) -> None:
|
||||
self.__cache = dict()
|
||||
self.__max_cache_size = max_cache_size
|
||||
self.__disabled = False
|
||||
self.__hits = 0
|
||||
self.__misses = 0
|
||||
self.__cache_ids = Queue()
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
@ -25,15 +32,17 @@ class MemoryInvocationCache(InvocationCacheBase):
|
||||
self.__invoker.services.latents.on_deleted(self._delete_by_match)
|
||||
|
||||
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
|
||||
if self.__max_cache_size == 0:
|
||||
if self.__max_cache_size == 0 or self.__disabled:
|
||||
return
|
||||
|
||||
item = self.__cache.get(key, None)
|
||||
if item is not None:
|
||||
self.__hits += 1
|
||||
return item[0]
|
||||
self.__misses += 1
|
||||
|
||||
def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
|
||||
if self.__max_cache_size == 0:
|
||||
if self.__max_cache_size == 0 or self.__disabled:
|
||||
return
|
||||
|
||||
if key not in self.__cache:
|
||||
@ -47,25 +56,46 @@ class MemoryInvocationCache(InvocationCacheBase):
|
||||
pass
|
||||
|
||||
def delete(self, key: Union[int, str]) -> None:
|
||||
if self.__max_cache_size == 0:
|
||||
if self.__max_cache_size == 0 or self.__disabled:
|
||||
return
|
||||
|
||||
if key in self.__cache:
|
||||
del self.__cache[key]
|
||||
|
||||
def clear(self, *args, **kwargs) -> None:
|
||||
if self.__max_cache_size == 0:
|
||||
if self.__max_cache_size == 0 or self.__disabled:
|
||||
return
|
||||
|
||||
self.__cache.clear()
|
||||
self.__cache_ids = Queue()
|
||||
self.__misses = 0
|
||||
self.__hits = 0
|
||||
|
||||
def create_key(self, invocation: BaseInvocation) -> int:
|
||||
return hash(invocation.json(exclude={"id"}))
|
||||
|
||||
def _delete_by_match(self, to_match: str) -> None:
|
||||
def disable(self) -> None:
|
||||
if self.__max_cache_size == 0:
|
||||
return
|
||||
self.__disabled = True
|
||||
|
||||
def enable(self) -> None:
|
||||
if self.__max_cache_size == 0:
|
||||
return
|
||||
self.__disabled = False
|
||||
|
||||
def get_status(self) -> InvocationCacheStatus:
|
||||
return InvocationCacheStatus(
|
||||
hits=self.__hits,
|
||||
misses=self.__misses,
|
||||
enabled=not self.__disabled and self.__max_cache_size > 0,
|
||||
size=len(self.__cache),
|
||||
max_size=self.__max_cache_size,
|
||||
)
|
||||
|
||||
def _delete_by_match(self, to_match: str) -> None:
|
||||
if self.__max_cache_size == 0 or self.__disabled:
|
||||
return
|
||||
|
||||
keys_to_delete = set()
|
||||
for key, value_tuple in self.__cache.items():
|
||||
|
@ -92,30 +92,34 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self.__invoker.services.logger
|
||||
while not stop_event.is_set():
|
||||
poll_now_event.clear()
|
||||
try:
|
||||
# do not dequeue if there is already a session running
|
||||
if self.__queue_item is None and resume_event.is_set():
|
||||
queue_item = self.__invoker.services.session_queue.dequeue()
|
||||
|
||||
# do not dequeue if there is already a session running
|
||||
if self.__queue_item is None and resume_event.is_set():
|
||||
queue_item = self.__invoker.services.session_queue.dequeue()
|
||||
if queue_item is not None:
|
||||
self.__invoker.services.logger.debug(f"Executing queue item {queue_item.item_id}")
|
||||
self.__queue_item = queue_item
|
||||
self.__invoker.services.graph_execution_manager.set(queue_item.session)
|
||||
self.__invoker.invoke(
|
||||
session_queue_batch_id=queue_item.batch_id,
|
||||
session_queue_id=queue_item.queue_id,
|
||||
session_queue_item_id=queue_item.item_id,
|
||||
graph_execution_state=queue_item.session,
|
||||
invoke_all=True,
|
||||
)
|
||||
queue_item = None
|
||||
|
||||
if queue_item is not None:
|
||||
self.__invoker.services.logger.debug(f"Executing queue item {queue_item.item_id}")
|
||||
self.__queue_item = queue_item
|
||||
self.__invoker.services.graph_execution_manager.set(queue_item.session)
|
||||
self.__invoker.invoke(
|
||||
session_queue_batch_id=queue_item.batch_id,
|
||||
session_queue_id=queue_item.queue_id,
|
||||
session_queue_item_id=queue_item.item_id,
|
||||
graph_execution_state=queue_item.session,
|
||||
invoke_all=True,
|
||||
)
|
||||
queue_item = None
|
||||
|
||||
if queue_item is None:
|
||||
self.__invoker.services.logger.debug("Waiting for next polling interval or event")
|
||||
if queue_item is None:
|
||||
self.__invoker.services.logger.debug("Waiting for next polling interval or event")
|
||||
poll_now_event.wait(POLLING_INTERVAL)
|
||||
continue
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error(f"Error in session processor: {e}")
|
||||
poll_now_event.wait(POLLING_INTERVAL)
|
||||
continue
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error(f"Error in session processor: {e}")
|
||||
self.__invoker.services.logger.error(f"Fatal Error in session processor: {e}")
|
||||
pass
|
||||
finally:
|
||||
stop_event.clear()
|
||||
|
@ -162,15 +162,15 @@ class SessionQueueItemWithoutGraph(BaseModel):
|
||||
session_id: str = Field(
|
||||
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
|
||||
)
|
||||
field_values: Optional[list[NodeFieldValue]] = Field(
|
||||
default=None, description="The field values that were used for this queue item"
|
||||
)
|
||||
queue_id: str = Field(description="The id of the queue with which this item is associated")
|
||||
error: Optional[str] = Field(default=None, description="The error message if this queue item errored")
|
||||
created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created")
|
||||
updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated")
|
||||
started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")
|
||||
completed_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was completed")
|
||||
queue_id: str = Field(description="The id of the queue with which this item is associated")
|
||||
field_values: Optional[list[NodeFieldValue]] = Field(
|
||||
default=None, description="The field values that were used for this queue item"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
|
||||
|
@ -77,7 +77,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
if queue_item.status not in ["completed", "failed", "canceled"]:
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed")
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
|
||||
@ -86,8 +85,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
item_id = event[1]["data"]["queue_item_id"]
|
||||
error = event[1]["data"]["error"]
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
# always set to failed if have an error, even if previously the item was marked completed or canceled
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
|
||||
@ -95,8 +94,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
try:
|
||||
item_id = event[1]["data"]["queue_item_id"]
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled")
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
|
||||
if queue_item.status not in ["completed", "failed", "canceled"]:
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled")
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
|
||||
@ -354,7 +353,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return None
|
||||
queue_item = SessionQueueItem.from_dict(dict(result))
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress")
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
|
||||
return queue_item
|
||||
|
||||
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
@ -427,7 +425,9 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return self.get_queue_item(item_id)
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
|
||||
return queue_item
|
||||
|
||||
def is_empty(self, queue_id: str) -> IsEmptyResult:
|
||||
try:
|
||||
@ -565,7 +565,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
graph_execution_state_id=queue_item.session_id,
|
||||
)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
|
||||
return queue_item
|
||||
|
||||
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
|
||||
|
@ -42,4 +42,4 @@ IP-Adapters:
|
||||
- [InvokeAI/ip_adapter_plus_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_sd15)
|
||||
- [InvokeAI/ip_adapter_plus_face_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_face_sd15)
|
||||
- [InvokeAI/ip_adapter_sdxl](https://huggingface.co/InvokeAI/ip_adapter_sdxl)
|
||||
- Not yet supported: [InvokeAI/ip_adapter_sdxl_vit_h](https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h)
|
||||
- [InvokeAI/ip_adapter_sdxl_vit_h](https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h)
|
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Literal, Optional, Union
|
||||
@ -53,6 +54,7 @@ class ModelProbe(object):
|
||||
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
||||
"AutoencoderKL": ModelType.Vae,
|
||||
"AutoencoderTiny": ModelType.Vae,
|
||||
"ControlNetModel": ModelType.ControlNet,
|
||||
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
|
||||
"T2IAdapter": ModelType.T2IAdapter,
|
||||
@ -178,6 +180,7 @@ class ModelProbe(object):
|
||||
Get the model type of a hugging-face style folder.
|
||||
"""
|
||||
class_name = None
|
||||
error_hint = None
|
||||
if model:
|
||||
class_name = model.__class__.__name__
|
||||
else:
|
||||
@ -203,12 +206,18 @@ class ModelProbe(object):
|
||||
class_name = conf["architectures"][0]
|
||||
else:
|
||||
class_name = None
|
||||
else:
|
||||
error_hint = f"No model_index.json or config.json found in {folder_path}."
|
||||
|
||||
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
|
||||
return type
|
||||
else:
|
||||
error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"
|
||||
|
||||
# give up
|
||||
raise InvalidModelException(f"Unable to determine model type for {folder_path}")
|
||||
raise InvalidModelException(
|
||||
f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
|
||||
@ -467,16 +476,32 @@ class PipelineFolderProbe(FolderProbeBase):
|
||||
|
||||
class VaeFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
if self._config_looks_like_sdxl():
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif self._name_looks_like_sdxl():
|
||||
# but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
|
||||
# by a factor of 8), we can't necessarily tell them apart by config hyperparameters.
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
return BaseModelType.StableDiffusion1
|
||||
|
||||
def _config_looks_like_sdxl(self) -> bool:
|
||||
# config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
|
||||
config_file = self.folder_path / "config.json"
|
||||
if not config_file.exists():
|
||||
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
||||
with open(config_file, "r") as file:
|
||||
config = json.load(file)
|
||||
return (
|
||||
BaseModelType.StableDiffusionXL
|
||||
if config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
|
||||
else BaseModelType.StableDiffusion1
|
||||
)
|
||||
return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
|
||||
|
||||
def _name_looks_like_sdxl(self) -> bool:
|
||||
return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
|
||||
|
||||
def _guess_name(self) -> str:
|
||||
name = self.folder_path.name
|
||||
if name == "vae":
|
||||
name = self.folder_path.parent.name
|
||||
return name
|
||||
|
||||
|
||||
class TextualInversionFolderProbe(FolderProbeBase):
|
||||
|
568
invokeai/backend/util/db_maintenance.py
Normal file
568
invokeai/backend/util/db_maintenance.py
Normal file
@ -0,0 +1,568 @@
|
||||
# pylint: disable=line-too-long
|
||||
# pylint: disable=broad-exception-caught
|
||||
# pylint: disable=missing-function-docstring
|
||||
"""Script to peform db maintenance and outputs directory management."""
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import enum
|
||||
import glob
|
||||
import locale
|
||||
import os
|
||||
import shutil
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
import PIL
|
||||
import PIL.ImageOps
|
||||
import PIL.PngImagePlugin
|
||||
import yaml
|
||||
|
||||
|
||||
class ConfigMapper:
|
||||
"""Configuration loader."""
|
||||
|
||||
def __init__(self): # noqa D107
|
||||
pass
|
||||
|
||||
TIMESTAMP_STRING = datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
|
||||
|
||||
INVOKE_DIRNAME = "invokeai"
|
||||
YAML_FILENAME = "invokeai.yaml"
|
||||
DATABASE_FILENAME = "invokeai.db"
|
||||
|
||||
database_path = None
|
||||
database_backup_dir = None
|
||||
outputs_path = None
|
||||
archive_path = None
|
||||
thumbnails_path = None
|
||||
thumbnails_archive_path = None
|
||||
|
||||
def load(self):
|
||||
"""Read paths from yaml config and validate."""
|
||||
root = "."
|
||||
|
||||
if not self.__load_from_root_config(os.path.abspath(root)):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def __load_from_root_config(self, invoke_root):
|
||||
"""Validate a yaml path exists, confirm the user wants to use it and load config."""
|
||||
yaml_path = os.path.join(invoke_root, self.YAML_FILENAME)
|
||||
if os.path.exists(yaml_path):
|
||||
db_dir, outdir = self.__load_paths_from_yaml_file(yaml_path)
|
||||
|
||||
if db_dir is None or outdir is None:
|
||||
print("The invokeai.yaml file was found but is missing the db_dir and/or outdir setting!")
|
||||
return False
|
||||
|
||||
if os.path.isabs(db_dir):
|
||||
self.database_path = os.path.join(db_dir, self.DATABASE_FILENAME)
|
||||
else:
|
||||
self.database_path = os.path.join(invoke_root, db_dir, self.DATABASE_FILENAME)
|
||||
|
||||
self.database_backup_dir = os.path.join(os.path.dirname(self.database_path), "backup")
|
||||
|
||||
if os.path.isabs(outdir):
|
||||
self.outputs_path = os.path.join(outdir, "images")
|
||||
self.archive_path = os.path.join(outdir, "images-archive")
|
||||
else:
|
||||
self.outputs_path = os.path.join(invoke_root, outdir, "images")
|
||||
self.archive_path = os.path.join(invoke_root, outdir, "images-archive")
|
||||
|
||||
self.thumbnails_path = os.path.join(self.outputs_path, "thumbnails")
|
||||
self.thumbnails_archive_path = os.path.join(self.archive_path, "thumbnails")
|
||||
|
||||
db_exists = os.path.exists(self.database_path)
|
||||
outdir_exists = os.path.exists(self.outputs_path)
|
||||
|
||||
text = f"Found {self.YAML_FILENAME} file at {yaml_path}:"
|
||||
text += f"\n Database : {self.database_path} - {'Exists!' if db_exists else 'Not Found!'}"
|
||||
text += f"\n Outputs : {self.outputs_path}- {'Exists!' if outdir_exists else 'Not Found!'}"
|
||||
print(text)
|
||||
|
||||
if db_exists and outdir_exists:
|
||||
return True
|
||||
else:
|
||||
print(
|
||||
"\nOne or more paths specified in invoke.yaml do not exist. Please inspect/correct the configuration and ensure the script is run in the developer console mode (option 8) from an Invoke AI root directory."
|
||||
)
|
||||
return False
|
||||
else:
|
||||
print(
|
||||
f"Auto-discovery of configuration failed! Could not find ({yaml_path})!\n\nPlease ensure the script is run in the developer console mode (option 8) from an Invoke AI root directory."
|
||||
)
|
||||
return False
|
||||
|
||||
def __load_paths_from_yaml_file(self, yaml_path):
|
||||
"""Load an Invoke AI yaml file and get the database and outputs paths."""
|
||||
try:
|
||||
with open(yaml_path, "rt", encoding=locale.getpreferredencoding()) as file:
|
||||
yamlinfo = yaml.safe_load(file)
|
||||
db_dir = yamlinfo.get("InvokeAI", {}).get("Paths", {}).get("db_dir", None)
|
||||
outdir = yamlinfo.get("InvokeAI", {}).get("Paths", {}).get("outdir", None)
|
||||
return db_dir, outdir
|
||||
except Exception:
|
||||
print(f"Failed to load paths from yaml file! {yaml_path}!")
|
||||
return None, None
|
||||
|
||||
|
||||
class MaintenanceStats:
|
||||
"""DTO for tracking work progress."""
|
||||
|
||||
def __init__(self): # noqa D107
|
||||
pass
|
||||
|
||||
time_start = datetime.datetime.utcnow()
|
||||
count_orphaned_db_entries_cleaned = 0
|
||||
count_orphaned_disk_files_cleaned = 0
|
||||
count_orphaned_thumbnails_cleaned = 0
|
||||
count_thumbnails_regenerated = 0
|
||||
count_errors = 0
|
||||
|
||||
@staticmethod
|
||||
def get_elapsed_time_string():
|
||||
"""Get a friendly time string for the time elapsed since processing start."""
|
||||
time_now = datetime.datetime.utcnow()
|
||||
total_seconds = (time_now - MaintenanceStats.time_start).total_seconds()
|
||||
hours = int((total_seconds) / 3600)
|
||||
minutes = int(((total_seconds) % 3600) / 60)
|
||||
seconds = total_seconds % 60
|
||||
out_str = f"{hours} hour(s) -" if hours > 0 else ""
|
||||
out_str += f"{minutes} minute(s) -" if minutes > 0 else ""
|
||||
out_str += f"{seconds:.2f} second(s)"
|
||||
return out_str
|
||||
|
||||
|
||||
class DatabaseMapper:
|
||||
"""Class to abstract database functionality."""
|
||||
|
||||
def __init__(self, database_path, database_backup_dir): # noqa D107
|
||||
self.database_path = database_path
|
||||
self.database_backup_dir = database_backup_dir
|
||||
self.connection = None
|
||||
self.cursor = None
|
||||
|
||||
def backup(self, timestamp_string):
|
||||
"""Take a backup of the database."""
|
||||
if not os.path.exists(self.database_backup_dir):
|
||||
print(f"Database backup directory {self.database_backup_dir} does not exist -> creating...", end="")
|
||||
os.makedirs(self.database_backup_dir)
|
||||
print("Done!")
|
||||
database_backup_path = os.path.join(self.database_backup_dir, f"backup-{timestamp_string}-invokeai.db")
|
||||
print(f"Making DB Backup at {database_backup_path}...", end="")
|
||||
shutil.copy2(self.database_path, database_backup_path)
|
||||
print("Done!")
|
||||
|
||||
def connect(self):
|
||||
"""Open connection to the database."""
|
||||
self.connection = sqlite3.connect(self.database_path)
|
||||
self.cursor = self.connection.cursor()
|
||||
|
||||
def get_all_image_files(self):
|
||||
"""Get the full list of image file names from the database."""
|
||||
sql_get_image_by_name = "SELECT image_name FROM images"
|
||||
self.cursor.execute(sql_get_image_by_name)
|
||||
rows = self.cursor.fetchall()
|
||||
db_files = []
|
||||
for row in rows:
|
||||
db_files.append(row[0])
|
||||
return db_files
|
||||
|
||||
def remove_image_file_record(self, filename: str):
|
||||
"""Remove an image file reference from the database by filename."""
|
||||
sanitized_filename = str.replace(filename, "'", "''") # prevent injection
|
||||
sql_command = f"DELETE FROM images WHERE image_name='{sanitized_filename}'"
|
||||
self.cursor.execute(sql_command)
|
||||
self.connection.commit()
|
||||
|
||||
def does_image_exist(self, image_filename):
|
||||
"""Check database if a image name already exists and return a boolean."""
|
||||
sanitized_filename = str.replace(image_filename, "'", "''") # prevent injection
|
||||
sql_get_image_by_name = f"SELECT image_name FROM images WHERE image_name='{sanitized_filename}'"
|
||||
self.cursor.execute(sql_get_image_by_name)
|
||||
rows = self.cursor.fetchall()
|
||||
return True if len(rows) > 0 else False
|
||||
|
||||
def disconnect(self):
|
||||
"""Disconnect from the db, cleaning up connections and cursors."""
|
||||
if self.cursor is not None:
|
||||
self.cursor.close()
|
||||
if self.connection is not None:
|
||||
self.connection.close()
|
||||
|
||||
|
||||
class PhysicalFileMapper:
|
||||
"""Containing class for script functionality."""
|
||||
|
||||
def __init__(self, outputs_path, thumbnails_path, archive_path, thumbnails_archive_path): # noqa D107
|
||||
self.outputs_path = outputs_path
|
||||
self.archive_path = archive_path
|
||||
self.thumbnails_path = thumbnails_path
|
||||
self.thumbnails_archive_path = thumbnails_archive_path
|
||||
|
||||
def create_archive_directories(self):
|
||||
"""Create the directory for archiving orphaned image files."""
|
||||
if not os.path.exists(self.archive_path):
|
||||
print(f"Image archive directory ({self.archive_path}) does not exist -> creating...", end="")
|
||||
os.makedirs(self.archive_path)
|
||||
print("Created!")
|
||||
if not os.path.exists(self.thumbnails_archive_path):
|
||||
print(
|
||||
f"Image thumbnails archive directory ({self.thumbnails_archive_path}) does not exist -> creating...",
|
||||
end="",
|
||||
)
|
||||
os.makedirs(self.thumbnails_archive_path)
|
||||
print("Created!")
|
||||
|
||||
def get_image_path_for_image_name(self, image_filename): # noqa D102
|
||||
return os.path.join(self.outputs_path, image_filename)
|
||||
|
||||
def image_file_exists(self, image_filename): # noqa D102
|
||||
return os.path.exists(self.get_image_path_for_image_name(image_filename))
|
||||
|
||||
def get_thumbnail_path_for_image(self, image_filename): # noqa D102
|
||||
return os.path.join(self.thumbnails_path, os.path.splitext(image_filename)[0]) + ".webp"
|
||||
|
||||
def get_image_name_from_thumbnail_path(self, thumbnail_path): # noqa D102
|
||||
return os.path.splitext(os.path.basename(thumbnail_path))[0] + ".png"
|
||||
|
||||
def thumbnail_exists_for_filename(self, image_filename): # noqa D102
|
||||
return os.path.exists(self.get_thumbnail_path_for_image(image_filename))
|
||||
|
||||
def archive_image(self, image_filename): # noqa D102
|
||||
if self.image_file_exists(image_filename):
|
||||
image_path = self.get_image_path_for_image_name(image_filename)
|
||||
shutil.move(image_path, self.archive_path)
|
||||
|
||||
def archive_thumbnail_by_image_filename(self, image_filename): # noqa D102
|
||||
if self.thumbnail_exists_for_filename(image_filename):
|
||||
thumbnail_path = self.get_thumbnail_path_for_image(image_filename)
|
||||
shutil.move(thumbnail_path, self.thumbnails_archive_path)
|
||||
|
||||
def get_all_png_filenames_in_directory(self, directory_path): # noqa D102
|
||||
filepaths = glob.glob(directory_path + "/*.png", recursive=False)
|
||||
filenames = []
|
||||
for filepath in filepaths:
|
||||
filenames.append(os.path.basename(filepath))
|
||||
return filenames
|
||||
|
||||
def get_all_thumbnails_with_full_path(self, thumbnails_directory): # noqa D102
|
||||
return glob.glob(thumbnails_directory + "/*.webp", recursive=False)
|
||||
|
||||
def generate_thumbnail_for_image_name(self, image_filename): # noqa D102
|
||||
# create thumbnail
|
||||
file_path = self.get_image_path_for_image_name(image_filename)
|
||||
thumb_path = self.get_thumbnail_path_for_image(image_filename)
|
||||
thumb_size = 256, 256
|
||||
with PIL.Image.open(file_path) as source_image:
|
||||
source_image.thumbnail(thumb_size)
|
||||
source_image.save(thumb_path, "webp")
|
||||
|
||||
|
||||
class MaintenanceOperation(str, enum.Enum):
|
||||
"""Enum class for operations."""
|
||||
|
||||
Ask = "ask"
|
||||
CleanOrphanedDbEntries = "clean"
|
||||
CleanOrphanedDiskFiles = "archive"
|
||||
ReGenerateThumbnails = "thumbnails"
|
||||
All = "all"
|
||||
|
||||
|
||||
class InvokeAIDatabaseMaintenanceApp:
|
||||
"""Main processor class for the application."""
|
||||
|
||||
_operation: MaintenanceOperation
|
||||
_headless: bool = False
|
||||
__stats: MaintenanceStats = MaintenanceStats()
|
||||
|
||||
def __init__(self, operation: MaintenanceOperation = MaintenanceOperation.Ask):
|
||||
"""Initialize maintenance app."""
|
||||
self._operation = MaintenanceOperation(operation)
|
||||
self._headless = operation != MaintenanceOperation.Ask
|
||||
|
||||
def ask_for_operation(self) -> MaintenanceOperation:
|
||||
"""Ask user to choose the operation to perform."""
|
||||
while True:
|
||||
print()
|
||||
print("It is recommennded to run these operations as ordered below to avoid additional")
|
||||
print("work being performed that will be discarded in a subsequent step.")
|
||||
print()
|
||||
print("Select maintenance operation:")
|
||||
print()
|
||||
print("1) Clean Orphaned Database Image Entries")
|
||||
print(" Cleans entries in the database where the matching file was removed from")
|
||||
print(" the outputs directory.")
|
||||
print("2) Archive Orphaned Image Files")
|
||||
print(" Files found in the outputs directory without an entry in the database are")
|
||||
print(" moved to an archive directory.")
|
||||
print("3) Re-Generate Missing Thumbnail Files")
|
||||
print(" For files found in the outputs directory, re-generate a thumbnail if it")
|
||||
print(" not found in the thumbnails directory.")
|
||||
print()
|
||||
print("(CTRL-C to quit)")
|
||||
|
||||
try:
|
||||
input_option = int(input("Specify desired operation number (1-3): "))
|
||||
|
||||
operations = [
|
||||
MaintenanceOperation.CleanOrphanedDbEntries,
|
||||
MaintenanceOperation.CleanOrphanedDiskFiles,
|
||||
MaintenanceOperation.ReGenerateThumbnails,
|
||||
]
|
||||
return operations[input_option - 1]
|
||||
except (IndexError, ValueError):
|
||||
print("\nInvalid selection!")
|
||||
|
||||
def ask_to_continue(self) -> bool:
|
||||
"""Ask user whether they want to continue with the operation."""
|
||||
while True:
|
||||
input_choice = input("Do you wish to continue? (Y or N)? ")
|
||||
if str.lower(input_choice) == "y":
|
||||
return True
|
||||
if str.lower(input_choice) == "n":
|
||||
return False
|
||||
|
||||
def clean_orphaned_db_entries(
|
||||
self, config: ConfigMapper, file_mapper: PhysicalFileMapper, db_mapper: DatabaseMapper
|
||||
):
|
||||
"""Clean dangling database entries that no longer point to a file in outputs."""
|
||||
if self._headless:
|
||||
print(f"Removing database references to images that no longer exist in {config.outputs_path}...")
|
||||
else:
|
||||
print()
|
||||
print("===============================================================================")
|
||||
print("= Clean Orphaned Database Entries")
|
||||
print()
|
||||
print("Perform this operation if you have removed files from the outputs/images")
|
||||
print("directory but the database was never updated. You may see this as empty imaages")
|
||||
print("in the app gallery, or images that only show an enlarged version of the")
|
||||
print("thumbnail.")
|
||||
print()
|
||||
print(f"Database File Path : {config.database_path}")
|
||||
print(f"Database backup will be taken at : {config.database_backup_dir}")
|
||||
print(f"Outputs/Images Directory : {config.outputs_path}")
|
||||
print(f"Outputs/Images Archive Directory : {config.archive_path}")
|
||||
|
||||
print("\nNotes about this operation:")
|
||||
print("- This operation will find database image file entries that do not exist in the")
|
||||
print(" outputs/images dir and remove those entries from the database.")
|
||||
print("- This operation will target all image types including intermediate files.")
|
||||
print("- If a thumbnail still exists in outputs/images/thumbnails matching the")
|
||||
print(" orphaned entry, it will be moved to the archive directory.")
|
||||
print()
|
||||
|
||||
if not self.ask_to_continue():
|
||||
raise KeyboardInterrupt
|
||||
|
||||
file_mapper.create_archive_directories()
|
||||
db_mapper.backup(config.TIMESTAMP_STRING)
|
||||
db_mapper.connect()
|
||||
db_files = db_mapper.get_all_image_files()
|
||||
for db_file in db_files:
|
||||
try:
|
||||
if not file_mapper.image_file_exists(db_file):
|
||||
print(f"Found orphaned image db entry {db_file}. Cleaning ...", end="")
|
||||
db_mapper.remove_image_file_record(db_file)
|
||||
print("Cleaned!")
|
||||
if file_mapper.thumbnail_exists_for_filename(db_file):
|
||||
print("A thumbnail was found, archiving ...", end="")
|
||||
file_mapper.archive_thumbnail_by_image_filename(db_file)
|
||||
print("Archived!")
|
||||
self.__stats.count_orphaned_db_entries_cleaned += 1
|
||||
except Exception as ex:
|
||||
print("An error occurred cleaning db entry, error was:")
|
||||
print(ex)
|
||||
self.__stats.count_errors += 1
|
||||
|
||||
def clean_orphaned_disk_files(
|
||||
self, config: ConfigMapper, file_mapper: PhysicalFileMapper, db_mapper: DatabaseMapper
|
||||
):
|
||||
"""Archive image files that no longer have entries in the database."""
|
||||
if self._headless:
|
||||
print(f"Archiving orphaned image files to {config.archive_path}...")
|
||||
else:
|
||||
print()
|
||||
print("===============================================================================")
|
||||
print("= Clean Orphaned Disk Files")
|
||||
print()
|
||||
print("Perform this operation if you have files that were copied into the outputs")
|
||||
print("directory which are not referenced by the database. This can happen if you")
|
||||
print("upgraded to a version with a fresh database, but re-used the outputs directory")
|
||||
print("and now new images are mixed with the files not in the db. The script will")
|
||||
print("archive these files so you can choose to delete them or re-import using the")
|
||||
print("official import script.")
|
||||
print()
|
||||
print(f"Database File Path : {config.database_path}")
|
||||
print(f"Database backup will be taken at : {config.database_backup_dir}")
|
||||
print(f"Outputs/Images Directory : {config.outputs_path}")
|
||||
print(f"Outputs/Images Archive Directory : {config.archive_path}")
|
||||
|
||||
print("\nNotes about this operation:")
|
||||
print("- This operation will find image files not referenced by the database and move to an")
|
||||
print(" archive directory.")
|
||||
print("- This operation will target all image types including intermediate references.")
|
||||
print("- The matching thumbnail will also be archived.")
|
||||
print("- Any remaining orphaned thumbnails will also be archived.")
|
||||
|
||||
if not self.ask_to_continue():
|
||||
raise KeyboardInterrupt
|
||||
|
||||
print()
|
||||
|
||||
file_mapper.create_archive_directories()
|
||||
db_mapper.backup(config.TIMESTAMP_STRING)
|
||||
db_mapper.connect()
|
||||
phys_files = file_mapper.get_all_png_filenames_in_directory(config.outputs_path)
|
||||
for phys_file in phys_files:
|
||||
try:
|
||||
if not db_mapper.does_image_exist(phys_file):
|
||||
print(f"Found orphaned file {phys_file}, archiving...", end="")
|
||||
file_mapper.archive_image(phys_file)
|
||||
print("Archived!")
|
||||
if file_mapper.thumbnail_exists_for_filename(phys_file):
|
||||
print("Related thumbnail exists, archiving...", end="")
|
||||
file_mapper.archive_thumbnail_by_image_filename(phys_file)
|
||||
print("Archived!")
|
||||
else:
|
||||
print("No matching thumbnail existed to be cleaned.")
|
||||
self.__stats.count_orphaned_disk_files_cleaned += 1
|
||||
except Exception as ex:
|
||||
print("Error found trying to archive file or thumbnail, error was:")
|
||||
print(ex)
|
||||
self.__stats.count_errors += 1
|
||||
|
||||
thumb_filepaths = file_mapper.get_all_thumbnails_with_full_path(config.thumbnails_path)
|
||||
# archive any remaining orphaned thumbnails
|
||||
for thumb_filepath in thumb_filepaths:
|
||||
try:
|
||||
thumb_src_image_name = file_mapper.get_image_name_from_thumbnail_path(thumb_filepath)
|
||||
if not file_mapper.image_file_exists(thumb_src_image_name):
|
||||
print(f"Found orphaned thumbnail {thumb_filepath}, archiving...", end="")
|
||||
file_mapper.archive_thumbnail_by_image_filename(thumb_src_image_name)
|
||||
print("Archived!")
|
||||
self.__stats.count_orphaned_thumbnails_cleaned += 1
|
||||
except Exception as ex:
|
||||
print("Error found trying to archive thumbnail, error was:")
|
||||
print(ex)
|
||||
self.__stats.count_errors += 1
|
||||
|
||||
def regenerate_thumbnails(self, config: ConfigMapper, file_mapper: PhysicalFileMapper, *args):
|
||||
"""Create missing thumbnails for any valid general images both in the db and on disk."""
|
||||
if self._headless:
|
||||
print("Regenerating missing image thumbnails...")
|
||||
else:
|
||||
print()
|
||||
print("===============================================================================")
|
||||
print("= Regenerate Thumbnails")
|
||||
print()
|
||||
print("This operation will find files that have no matching thumbnail on disk")
|
||||
print("and regenerate those thumbnail files.")
|
||||
print("NOTE: It is STRONGLY recommended that the user first clean/archive orphaned")
|
||||
print(" disk files from the previous menu to avoid wasting time regenerating")
|
||||
print(" thumbnails for orphaned files.")
|
||||
|
||||
print()
|
||||
print(f"Outputs/Images Directory : {config.outputs_path}")
|
||||
print(f"Outputs/Images Directory : {config.thumbnails_path}")
|
||||
|
||||
print("\nNotes about this operation:")
|
||||
print("- This operation will find image files both referenced in the db and on disk")
|
||||
print(" that do not have a matching thumbnail on disk and re-generate the thumbnail")
|
||||
print(" file.")
|
||||
|
||||
if not self.ask_to_continue():
|
||||
raise KeyboardInterrupt
|
||||
|
||||
print()
|
||||
|
||||
phys_files = file_mapper.get_all_png_filenames_in_directory(config.outputs_path)
|
||||
for phys_file in phys_files:
|
||||
try:
|
||||
if not file_mapper.thumbnail_exists_for_filename(phys_file):
|
||||
print(f"Found file without thumbnail {phys_file}...Regenerating Thumbnail...", end="")
|
||||
file_mapper.generate_thumbnail_for_image_name(phys_file)
|
||||
print("Done!")
|
||||
self.__stats.count_thumbnails_regenerated += 1
|
||||
except Exception as ex:
|
||||
print("Error found trying to regenerate thumbnail, error was:")
|
||||
print(ex)
|
||||
self.__stats.count_errors += 1
|
||||
|
||||
def main(self): # noqa D107
|
||||
print("\n===============================================================================")
|
||||
print("Database and outputs Maintenance for Invoke AI 3.0.0 +")
|
||||
print("===============================================================================\n")
|
||||
|
||||
config_mapper = ConfigMapper()
|
||||
if not config_mapper.load():
|
||||
print("\nInvalid configuration...exiting.\n")
|
||||
return
|
||||
|
||||
file_mapper = PhysicalFileMapper(
|
||||
config_mapper.outputs_path,
|
||||
config_mapper.thumbnails_path,
|
||||
config_mapper.archive_path,
|
||||
config_mapper.thumbnails_archive_path,
|
||||
)
|
||||
db_mapper = DatabaseMapper(config_mapper.database_path, config_mapper.database_backup_dir)
|
||||
|
||||
op = self._operation
|
||||
operations_to_perform = []
|
||||
|
||||
if op == MaintenanceOperation.Ask:
|
||||
op = self.ask_for_operation()
|
||||
|
||||
if op in [MaintenanceOperation.CleanOrphanedDbEntries, MaintenanceOperation.All]:
|
||||
operations_to_perform.append(self.clean_orphaned_db_entries)
|
||||
if op in [MaintenanceOperation.CleanOrphanedDiskFiles, MaintenanceOperation.All]:
|
||||
operations_to_perform.append(self.clean_orphaned_disk_files)
|
||||
if op in [MaintenanceOperation.ReGenerateThumbnails, MaintenanceOperation.All]:
|
||||
operations_to_perform.append(self.regenerate_thumbnails)
|
||||
|
||||
for operation in operations_to_perform:
|
||||
operation(config_mapper, file_mapper, db_mapper)
|
||||
|
||||
print("\n===============================================================================")
|
||||
print(f"= Maintenance Complete - Elapsed Time: {MaintenanceStats.get_elapsed_time_string()}")
|
||||
print()
|
||||
print(f"Orphaned db entries cleaned : {self.__stats.count_orphaned_db_entries_cleaned}")
|
||||
print(f"Orphaned disk files archived : {self.__stats.count_orphaned_disk_files_cleaned}")
|
||||
print(f"Orphaned thumbnail files archived : {self.__stats.count_orphaned_thumbnails_cleaned}")
|
||||
print(f"Thumbnails regenerated : {self.__stats.count_thumbnails_regenerated}")
|
||||
print(f"Errors during operation : {self.__stats.count_errors}")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def main(): # noqa D107
|
||||
parser = argparse.ArgumentParser(
|
||||
description="InvokeAI image database maintenance utility",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""Operations:
|
||||
ask Choose operation from a menu [default]
|
||||
all Run all maintenance operations
|
||||
clean Clean database of dangling entries
|
||||
archive Archive orphaned image files
|
||||
thumbnails Regenerate missing image thumbnails
|
||||
""",
|
||||
)
|
||||
parser.add_argument("--root", default=".", type=Path, help="InvokeAI root directory")
|
||||
parser.add_argument(
|
||||
"--operation", default="ask", choices=[x.value for x in MaintenanceOperation], help="Operation to perform."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
os.chdir(args.root)
|
||||
app = InvokeAIDatabaseMaintenanceApp(args.operation)
|
||||
app.main()
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nUser cancelled execution.")
|
||||
except FileNotFoundError:
|
||||
print(f"Invalid root directory '{args.root}'.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -264,6 +264,22 @@
|
||||
"graphQueued": "Graph queued",
|
||||
"graphFailedToQueue": "Failed to queue graph"
|
||||
},
|
||||
"invocationCache": {
|
||||
"invocationCache": "Invocation Cache",
|
||||
"cacheSize": "Cache Size",
|
||||
"maxCacheSize": "Max Cache Size",
|
||||
"hits": "Cache Hits",
|
||||
"misses": "Cache Misses",
|
||||
"clear": "Clear",
|
||||
"clearSucceeded": "Invocation Cache Cleared",
|
||||
"clearFailed": "Problem Clearing Invocation Cache",
|
||||
"enable": "Enable",
|
||||
"enableSucceeded": "Invocation Cache Enabled",
|
||||
"enableFailed": "Problem Enabling Invocation Cache",
|
||||
"disable": "Disable",
|
||||
"disableSucceeded": "Invocation Cache Disabled",
|
||||
"disableFailed": "Problem Disabling Invocation Cache"
|
||||
},
|
||||
"gallery": {
|
||||
"allImagesLoaded": "All Images Loaded",
|
||||
"assets": "Assets",
|
||||
@ -1213,14 +1229,14 @@
|
||||
},
|
||||
"dynamicPromptsCombinatorial": {
|
||||
"heading": "Combinatorial Generation",
|
||||
"paragraph": "Generate an image for every possible combination of Dynamic Prompt until the Max Prompts is reached."
|
||||
"paragraph": "Generate an image for every possible combination of Dynamic Prompts until the Max Prompts is reached."
|
||||
},
|
||||
"infillMethod": {
|
||||
"heading": "Infill Method",
|
||||
"paragraph": "Method to infill the selected area."
|
||||
},
|
||||
"lora": {
|
||||
"heading": "LoRA",
|
||||
"heading": "LoRA Weight",
|
||||
"paragraph": "Weight of the LoRA. Higher weight will lead to larger impacts on the final image."
|
||||
},
|
||||
"noiseEnable": {
|
||||
@ -1239,21 +1255,21 @@
|
||||
"heading": "Denoising Strength",
|
||||
"paragraph": "How much noise is added to the input image. 0 will result in an identical image, while 1 will result in a completely new image."
|
||||
},
|
||||
"paramImages": {
|
||||
"heading": "Images",
|
||||
"paragraph": "Number of images that will be generated."
|
||||
"paramIterations": {
|
||||
"heading": "Iterations",
|
||||
"paragraph": "The number of images to generate. If Dynamic Prompts is enabled, each of the prompts will be generated this many times."
|
||||
},
|
||||
"paramModel": {
|
||||
"heading": "Model",
|
||||
"paragraph": "Model used for the denoising steps. Different models are trained to specialize in producing different aesthetic results and content."
|
||||
},
|
||||
"paramNegativeConditioning": {
|
||||
"heading": "Negative Prompts",
|
||||
"paragraph": "This is where you enter your negative prompts."
|
||||
"heading": "Negative Prompt",
|
||||
"paragraph": "The generation process avoids the concepts in the negative prompt. Use this to exclude qualities or objects from the output. Supports Compel syntax and embeddings."
|
||||
},
|
||||
"paramPositiveConditioning": {
|
||||
"heading": "Positive Prompts",
|
||||
"paragraph": "This is where you enter your positive prompts."
|
||||
"heading": "Positive Prompt",
|
||||
"paragraph": "Guides the generation process. You may use any words or phrases. Supports Compel and Dynamic Prompts syntaxes and embeddings."
|
||||
},
|
||||
"paramRatio": {
|
||||
"heading": "Ratio",
|
||||
|
@ -3,7 +3,7 @@ import { startAppListening } from '..';
|
||||
import { $logger } from 'app/logging/logger';
|
||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { copyBlobToClipboard } from 'features/canvas/util/copyBlobToClipboard';
|
||||
import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard';
|
||||
import { t } from 'i18next';
|
||||
|
||||
export const addCanvasCopiedToClipboardListener = () => {
|
||||
@ -15,10 +15,12 @@ export const addCanvasCopiedToClipboardListener = () => {
|
||||
.child({ namespace: 'canvasCopiedToClipboardListener' });
|
||||
const state = getState();
|
||||
|
||||
const blob = await getBaseLayerBlob(state);
|
||||
try {
|
||||
const blob = getBaseLayerBlob(state);
|
||||
|
||||
if (!blob) {
|
||||
moduleLog.error('Problem getting base layer blob');
|
||||
copyBlobToClipboard(blob);
|
||||
} catch (err) {
|
||||
moduleLog.error(String(err));
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('toast.problemCopyingCanvas'),
|
||||
@ -29,8 +31,6 @@ export const addCanvasCopiedToClipboardListener = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
copyBlobToClipboard(blob);
|
||||
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('toast.canvasCopiedClipboard'),
|
||||
|
@ -15,10 +15,11 @@ export const addCanvasDownloadedAsImageListener = () => {
|
||||
.child({ namespace: 'canvasSavedToGalleryListener' });
|
||||
const state = getState();
|
||||
|
||||
const blob = await getBaseLayerBlob(state);
|
||||
|
||||
if (!blob) {
|
||||
moduleLog.error('Problem getting base layer blob');
|
||||
let blob;
|
||||
try {
|
||||
blob = await getBaseLayerBlob(state);
|
||||
} catch (err) {
|
||||
moduleLog.error(String(err));
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('toast.problemDownloadingCanvas'),
|
||||
|
@ -14,10 +14,11 @@ export const addCanvasImageToControlNetListener = () => {
|
||||
const log = logger('canvas');
|
||||
const state = getState();
|
||||
|
||||
const blob = await getBaseLayerBlob(state);
|
||||
|
||||
if (!blob) {
|
||||
log.error('Problem getting base layer blob');
|
||||
let blob;
|
||||
try {
|
||||
blob = await getBaseLayerBlob(state);
|
||||
} catch (err) {
|
||||
log.error(String(err));
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('toast.problemSavingCanvas'),
|
||||
|
@ -13,10 +13,11 @@ export const addCanvasSavedToGalleryListener = () => {
|
||||
const log = logger('canvas');
|
||||
const state = getState();
|
||||
|
||||
const blob = await getBaseLayerBlob(state);
|
||||
|
||||
if (!blob) {
|
||||
log.error('Problem getting base layer blob');
|
||||
let blob;
|
||||
try {
|
||||
blob = await getBaseLayerBlob(state);
|
||||
} catch (err) {
|
||||
log.error(String(err));
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('toast.problemSavingCanvas'),
|
||||
|
@ -4,7 +4,9 @@ import { parseify } from 'common/util/serialize';
|
||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||
import {
|
||||
controlNetImageChanged,
|
||||
controlNetIsEnabledChanged,
|
||||
ipAdapterImageChanged,
|
||||
isIPAdapterEnabledChanged,
|
||||
} from 'features/controlNet/store/controlNetSlice';
|
||||
import {
|
||||
TypesafeDraggableData,
|
||||
@ -99,6 +101,12 @@ export const addImageDroppedListener = () => {
|
||||
controlNetId,
|
||||
})
|
||||
);
|
||||
dispatch(
|
||||
controlNetIsEnabledChanged({
|
||||
controlNetId,
|
||||
isEnabled: true,
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
@ -111,6 +119,7 @@ export const addImageDroppedListener = () => {
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
dispatch(ipAdapterImageChanged(activeData.payload.imageDTO));
|
||||
dispatch(isIPAdapterEnabledChanged(true));
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -3,7 +3,9 @@ import { logger } from 'app/logging/logger';
|
||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||
import {
|
||||
controlNetImageChanged,
|
||||
controlNetIsEnabledChanged,
|
||||
ipAdapterImageChanged,
|
||||
isIPAdapterEnabledChanged,
|
||||
} from 'features/controlNet/store/controlNetSlice';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||
@ -87,6 +89,12 @@ export const addImageUploadedFulfilledListener = () => {
|
||||
|
||||
if (postUploadAction?.type === 'SET_CONTROLNET_IMAGE') {
|
||||
const { controlNetId } = postUploadAction;
|
||||
dispatch(
|
||||
controlNetIsEnabledChanged({
|
||||
controlNetId,
|
||||
isEnabled: true,
|
||||
})
|
||||
);
|
||||
dispatch(
|
||||
controlNetImageChanged({
|
||||
controlNetId,
|
||||
@ -104,6 +112,7 @@ export const addImageUploadedFulfilledListener = () => {
|
||||
|
||||
if (postUploadAction?.type === 'SET_IP_ADAPTER_IMAGE') {
|
||||
dispatch(ipAdapterImageChanged(imageDTO));
|
||||
dispatch(isIPAdapterEnabledChanged(true));
|
||||
dispatch(
|
||||
addToast({
|
||||
...DEFAULT_UPLOADED_TOAST,
|
||||
|
@ -4,6 +4,7 @@ import { api } from 'services/api';
|
||||
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
||||
import { startAppListening } from '../..';
|
||||
import { isInitializedChanged } from 'features/system/store/systemSlice';
|
||||
|
||||
export const addSocketConnectedEventListener = () => {
|
||||
startAppListening({
|
||||
@ -13,7 +14,7 @@ export const addSocketConnectedEventListener = () => {
|
||||
|
||||
log.debug('Connected');
|
||||
|
||||
const { nodes, config } = getState();
|
||||
const { nodes, config, system } = getState();
|
||||
|
||||
const { disabledTabs } = config;
|
||||
|
||||
@ -21,7 +22,12 @@ export const addSocketConnectedEventListener = () => {
|
||||
dispatch(receivedOpenAPISchema());
|
||||
}
|
||||
|
||||
dispatch(api.util.resetApiState());
|
||||
if (system.isInitialized) {
|
||||
// only reset the query caches if this connect event is a *reconnect* event
|
||||
dispatch(api.util.resetApiState());
|
||||
} else {
|
||||
dispatch(isInitializedChanged(true));
|
||||
}
|
||||
|
||||
// pass along the socket event as an application action
|
||||
dispatch(appSocketConnected(action.payload));
|
||||
|
@ -1,5 +1,4 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { api } from 'services/api';
|
||||
import {
|
||||
appSocketDisconnected,
|
||||
socketDisconnected,
|
||||
@ -13,8 +12,6 @@ export const addSocketDisconnectedEventListener = () => {
|
||||
const log = logger('socketio');
|
||||
log.debug('Disconnected');
|
||||
|
||||
dispatch(api.util.resetApiState());
|
||||
|
||||
// pass along the socket event as an application action
|
||||
dispatch(appSocketDisconnected(action.payload));
|
||||
},
|
||||
|
@ -21,7 +21,8 @@ export type AppFeature =
|
||||
| 'multiselect'
|
||||
| 'pauseQueue'
|
||||
| 'resumeQueue'
|
||||
| 'prependQueue';
|
||||
| 'prependQueue'
|
||||
| 'invocationCache';
|
||||
|
||||
/**
|
||||
* A disable-able Stable Diffusion feature
|
||||
|
@ -0,0 +1,22 @@
|
||||
import { Box, Image } from '@chakra-ui/react';
|
||||
import InvokeAILogoImage from 'assets/images/logo.png';
|
||||
import { memo } from 'react';
|
||||
|
||||
const GreyscaleInvokeAIIcon = () => (
|
||||
<Box pos="relative" w={4} h={4}>
|
||||
<Image
|
||||
src={InvokeAILogoImage}
|
||||
alt="invoke-ai-logo"
|
||||
pos="absolute"
|
||||
top={-0.5}
|
||||
insetInlineStart={-0.5}
|
||||
w={5}
|
||||
h={5}
|
||||
minW={5}
|
||||
minH={5}
|
||||
filter="saturate(0)"
|
||||
/>
|
||||
</Box>
|
||||
);
|
||||
|
||||
export default memo(GreyscaleInvokeAIIcon);
|
@ -31,7 +31,7 @@ const IAIDroppable = (props: IAIDroppableProps) => {
|
||||
insetInlineStart={0}
|
||||
w="full"
|
||||
h="full"
|
||||
pointerEvents="none"
|
||||
pointerEvents={active ? 'auto' : 'none'}
|
||||
>
|
||||
<AnimatePresence>
|
||||
{isValidDrop(data, active) && (
|
||||
|
@ -1,58 +1,67 @@
|
||||
import {
|
||||
Box,
|
||||
Button,
|
||||
Popover,
|
||||
PopoverTrigger,
|
||||
PopoverContent,
|
||||
PopoverArrow,
|
||||
PopoverCloseButton,
|
||||
PopoverHeader,
|
||||
PopoverBody,
|
||||
PopoverProps,
|
||||
Divider,
|
||||
Flex,
|
||||
Text,
|
||||
Heading,
|
||||
Image,
|
||||
Popover,
|
||||
PopoverArrow,
|
||||
PopoverBody,
|
||||
PopoverCloseButton,
|
||||
PopoverContent,
|
||||
PopoverProps,
|
||||
PopoverTrigger,
|
||||
Portal,
|
||||
Text,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppSelector } from '../../app/store/storeHooks';
|
||||
import { ReactNode, memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useAppSelector } from '../../app/store/storeHooks';
|
||||
|
||||
interface Props extends PopoverProps {
|
||||
const OPEN_DELAY = 1500;
|
||||
|
||||
type Props = Omit<PopoverProps, 'children'> & {
|
||||
details: string;
|
||||
children: JSX.Element;
|
||||
children: ReactNode;
|
||||
image?: string;
|
||||
buttonLabel?: string;
|
||||
buttonHref?: string;
|
||||
placement?: PopoverProps['placement'];
|
||||
}
|
||||
};
|
||||
|
||||
function IAIInformationalPopover({
|
||||
const IAIInformationalPopover = ({
|
||||
details,
|
||||
image,
|
||||
buttonLabel,
|
||||
buttonHref,
|
||||
children,
|
||||
placement,
|
||||
}: Props): JSX.Element {
|
||||
const shouldDisableInformationalPopovers = useAppSelector(
|
||||
(state) => state.system.shouldDisableInformationalPopovers
|
||||
}: Props) => {
|
||||
const shouldEnableInformationalPopovers = useAppSelector(
|
||||
(state) => state.system.shouldEnableInformationalPopovers
|
||||
);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const heading = t(`popovers.${details}.heading`);
|
||||
const paragraph = t(`popovers.${details}.paragraph`);
|
||||
|
||||
if (shouldDisableInformationalPopovers) {
|
||||
return children;
|
||||
} else {
|
||||
return (
|
||||
<Popover
|
||||
placement={placement || 'top'}
|
||||
closeOnBlur={false}
|
||||
trigger="hover"
|
||||
variant="informational"
|
||||
>
|
||||
<PopoverTrigger>
|
||||
<div>{children}</div>
|
||||
</PopoverTrigger>
|
||||
if (!shouldEnableInformationalPopovers) {
|
||||
return <>{children}</>;
|
||||
}
|
||||
|
||||
return (
|
||||
<Popover
|
||||
placement={placement || 'top'}
|
||||
closeOnBlur={false}
|
||||
trigger="hover"
|
||||
variant="informational"
|
||||
openDelay={OPEN_DELAY}
|
||||
>
|
||||
<PopoverTrigger>
|
||||
<Box w="full">{children}</Box>
|
||||
</PopoverTrigger>
|
||||
<Portal>
|
||||
<PopoverContent>
|
||||
<PopoverArrow />
|
||||
<PopoverCloseButton />
|
||||
@ -83,14 +92,17 @@ function IAIInformationalPopover({
|
||||
gap: 3,
|
||||
flexDirection: 'column',
|
||||
width: '100%',
|
||||
p: 3,
|
||||
pt: heading ? 0 : 3,
|
||||
}}
|
||||
>
|
||||
{heading && <PopoverHeader>{heading}</PopoverHeader>}
|
||||
<Text sx={{ px: 3 }}>{paragraph}</Text>
|
||||
{heading && (
|
||||
<>
|
||||
<Heading size="sm">{heading}</Heading>
|
||||
<Divider />
|
||||
</>
|
||||
)}
|
||||
<Text>{paragraph}</Text>
|
||||
{buttonLabel && (
|
||||
<Flex sx={{ px: 3 }} justifyContent="flex-end">
|
||||
<Flex justifyContent="flex-end">
|
||||
<Button
|
||||
onClick={() => window.open(buttonHref)}
|
||||
size="sm"
|
||||
@ -104,9 +116,9 @@ function IAIInformationalPopover({
|
||||
</Flex>
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
}
|
||||
</Portal>
|
||||
</Popover>
|
||||
);
|
||||
};
|
||||
|
||||
export default IAIInformationalPopover;
|
||||
export default memo(IAIInformationalPopover);
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
|
||||
import { FormControl, FormLabel, Tooltip, forwardRef } from '@chakra-ui/react';
|
||||
import { MultiSelect, MultiSelectProps } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
||||
@ -11,7 +11,7 @@ type IAIMultiSelectProps = Omit<MultiSelectProps, 'label'> & {
|
||||
label?: string;
|
||||
};
|
||||
|
||||
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||
const IAIMantineMultiSelect = forwardRef((props: IAIMultiSelectProps, ref) => {
|
||||
const {
|
||||
searchable = true,
|
||||
tooltip,
|
||||
@ -47,7 +47,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||
<MultiSelect
|
||||
label={
|
||||
label ? (
|
||||
<FormControl isDisabled={disabled}>
|
||||
<FormControl ref={ref} isDisabled={disabled}>
|
||||
<FormLabel>{label}</FormLabel>
|
||||
</FormControl>
|
||||
) : undefined
|
||||
@ -63,6 +63,8 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||
/>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
IAIMantineMultiSelect.displayName = 'IAIMantineMultiSelect';
|
||||
|
||||
export default memo(IAIMantineMultiSelect);
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
|
||||
import { FormControl, FormLabel, Tooltip, forwardRef } from '@chakra-ui/react';
|
||||
import { Select, SelectProps } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
||||
@ -17,7 +17,7 @@ type IAISelectProps = Omit<SelectProps, 'label'> & {
|
||||
inputRef?: RefObject<HTMLInputElement>;
|
||||
};
|
||||
|
||||
const IAIMantineSearchableSelect = (props: IAISelectProps) => {
|
||||
const IAIMantineSearchableSelect = forwardRef((props: IAISelectProps, ref) => {
|
||||
const {
|
||||
searchable = true,
|
||||
tooltip,
|
||||
@ -74,7 +74,7 @@ const IAIMantineSearchableSelect = (props: IAISelectProps) => {
|
||||
ref={inputRef}
|
||||
label={
|
||||
label ? (
|
||||
<FormControl isDisabled={disabled}>
|
||||
<FormControl ref={ref} isDisabled={disabled}>
|
||||
<FormLabel>{label}</FormLabel>
|
||||
</FormControl>
|
||||
) : undefined
|
||||
@ -92,6 +92,8 @@ const IAIMantineSearchableSelect = (props: IAISelectProps) => {
|
||||
/>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
IAIMantineSearchableSelect.displayName = 'IAIMantineSearchableSelect';
|
||||
|
||||
export default memo(IAIMantineSearchableSelect);
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
|
||||
import { FormControl, FormLabel, Tooltip, forwardRef } from '@chakra-ui/react';
|
||||
import { Select, SelectProps } from '@mantine/core';
|
||||
import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles';
|
||||
import { RefObject, memo } from 'react';
|
||||
@ -15,7 +15,7 @@ export type IAISelectProps = Omit<SelectProps, 'label'> & {
|
||||
label?: string;
|
||||
};
|
||||
|
||||
const IAIMantineSelect = (props: IAISelectProps) => {
|
||||
const IAIMantineSelect = forwardRef((props: IAISelectProps, ref) => {
|
||||
const { tooltip, inputRef, label, disabled, required, ...rest } = props;
|
||||
|
||||
const styles = useMantineSelectStyles();
|
||||
@ -25,7 +25,7 @@ const IAIMantineSelect = (props: IAISelectProps) => {
|
||||
<Select
|
||||
label={
|
||||
label ? (
|
||||
<FormControl isRequired={required} isDisabled={disabled}>
|
||||
<FormControl ref={ref} isRequired={required} isDisabled={disabled}>
|
||||
<FormLabel>{label}</FormLabel>
|
||||
</FormControl>
|
||||
) : undefined
|
||||
@ -37,6 +37,8 @@ const IAIMantineSelect = (props: IAISelectProps) => {
|
||||
/>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
IAIMantineSelect.displayName = 'IAIMantineSelect';
|
||||
|
||||
export default memo(IAIMantineSelect);
|
||||
|
@ -13,6 +13,7 @@ import {
|
||||
NumberInputStepperProps,
|
||||
Tooltip,
|
||||
TooltipProps,
|
||||
forwardRef,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { stopPastePropagation } from 'common/util/stopPastePropagation';
|
||||
@ -50,7 +51,7 @@ interface Props extends Omit<NumberInputProps, 'onChange'> {
|
||||
/**
|
||||
* Customized Chakra FormControl + NumberInput multi-part component.
|
||||
*/
|
||||
const IAINumberInput = (props: Props) => {
|
||||
const IAINumberInput = forwardRef((props: Props, ref) => {
|
||||
const {
|
||||
label,
|
||||
isDisabled = false,
|
||||
@ -141,6 +142,7 @@ const IAINumberInput = (props: Props) => {
|
||||
return (
|
||||
<Tooltip {...tooltipProps}>
|
||||
<FormControl
|
||||
ref={ref}
|
||||
isDisabled={isDisabled}
|
||||
isInvalid={isInvalid}
|
||||
{...formControlProps}
|
||||
@ -172,6 +174,8 @@ const IAINumberInput = (props: Props) => {
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
IAINumberInput.displayName = 'IAINumberInput';
|
||||
|
||||
export default memo(IAINumberInput);
|
||||
|
@ -22,6 +22,7 @@ import {
|
||||
SliderTrackProps,
|
||||
Tooltip,
|
||||
TooltipProps,
|
||||
forwardRef,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { roundDownToMultiple } from 'common/util/roundDownToMultiple';
|
||||
@ -71,7 +72,7 @@ export type IAIFullSliderProps = {
|
||||
sliderIAIIconButtonProps?: IAIIconButtonProps;
|
||||
};
|
||||
|
||||
const IAISlider = (props: IAIFullSliderProps) => {
|
||||
const IAISlider = forwardRef((props: IAIFullSliderProps, ref) => {
|
||||
const [showTooltip, setShowTooltip] = useState(false);
|
||||
const {
|
||||
label,
|
||||
@ -187,6 +188,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
|
||||
return (
|
||||
<FormControl
|
||||
ref={ref}
|
||||
onClick={forceInputBlur}
|
||||
sx={
|
||||
isCompact
|
||||
@ -354,6 +356,8 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
</HStack>
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
IAISlider.displayName = 'IAISlider';
|
||||
|
||||
export default memo(IAISlider);
|
||||
|
@ -72,4 +72,6 @@ const IAISwitch = (props: IAISwitchProps) => {
|
||||
);
|
||||
};
|
||||
|
||||
IAISwitch.displayName = 'IAISwitch';
|
||||
|
||||
export default memo(IAISwitch);
|
||||
|
@ -9,7 +9,7 @@ export const getBaseLayerBlob = async (state: RootState) => {
|
||||
const canvasBaseLayer = getCanvasBaseLayer();
|
||||
|
||||
if (!canvasBaseLayer) {
|
||||
return;
|
||||
throw new Error('Problem getting base layer blob');
|
||||
}
|
||||
|
||||
const {
|
||||
|
@ -1,12 +1,12 @@
|
||||
import { Box, Flex } from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { FaCopy, FaTrash } from 'react-icons/fa';
|
||||
import {
|
||||
ControlNetConfig,
|
||||
controlNetDuplicated,
|
||||
controlNetRemoved,
|
||||
controlNetToggled,
|
||||
controlNetIsEnabledChanged,
|
||||
} from '../store/controlNetSlice';
|
||||
import ParamControlNetModel from './parameters/ParamControlNetModel';
|
||||
import ParamControlNetWeight from './parameters/ParamControlNetWeight';
|
||||
@ -77,9 +77,17 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
);
|
||||
}, [controlNetId, dispatch]);
|
||||
|
||||
const handleToggleIsEnabled = useCallback(() => {
|
||||
dispatch(controlNetToggled({ controlNetId }));
|
||||
}, [controlNetId, dispatch]);
|
||||
const handleToggleIsEnabled = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
dispatch(
|
||||
controlNetIsEnabledChanged({
|
||||
controlNetId,
|
||||
isEnabled: e.target.checked,
|
||||
})
|
||||
);
|
||||
},
|
||||
[controlNetId, dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
@ -106,8 +114,8 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
sx={{
|
||||
w: 'full',
|
||||
minW: 0,
|
||||
opacity: isEnabled ? 1 : 0.5,
|
||||
pointerEvents: isEnabled ? 'auto' : 'none',
|
||||
// opacity: isEnabled ? 1 : 0.5,
|
||||
// pointerEvents: isEnabled ? 'auto' : 'none',
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: '0.1s',
|
||||
}}
|
||||
|
@ -13,6 +13,7 @@ import {
|
||||
import { setHeight, setWidth } from 'features/parameters/store/generationSlice';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaRulerVertical, FaSave, FaUndo } from 'react-icons/fa';
|
||||
import {
|
||||
useAddImageToBoardMutation,
|
||||
@ -26,7 +27,6 @@ import {
|
||||
ControlNetConfig,
|
||||
controlNetImageChanged,
|
||||
} from '../store/controlNetSlice';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type Props = {
|
||||
controlNet: ControlNetConfig;
|
||||
@ -52,7 +52,6 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
|
||||
controlImage: controlImageName,
|
||||
processedControlImage: processedControlImageName,
|
||||
processorType,
|
||||
isEnabled,
|
||||
controlNetId,
|
||||
} = controlNet;
|
||||
|
||||
@ -172,15 +171,13 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
|
||||
h: isSmall ? 28 : 366, // magic no touch
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
pointerEvents: isEnabled ? 'auto' : 'none',
|
||||
opacity: isEnabled ? 1 : 0.5,
|
||||
}}
|
||||
>
|
||||
<IAIDndImage
|
||||
draggableData={draggableData}
|
||||
droppableData={droppableData}
|
||||
imageDTO={controlImage}
|
||||
isDropDisabled={shouldShowProcessedImage || !isEnabled}
|
||||
isDropDisabled={shouldShowProcessedImage}
|
||||
postUploadAction={postUploadAction}
|
||||
/>
|
||||
|
||||
@ -202,7 +199,6 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
|
||||
droppableData={droppableData}
|
||||
imageDTO={processedControlImage}
|
||||
isUploadDisabled={true}
|
||||
isDropDisabled={!isEnabled}
|
||||
/>
|
||||
</Box>
|
||||
|
||||
|
@ -3,8 +3,8 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { isIPAdapterEnableToggled } from 'features/controlNet/store/controlNetSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { isIPAdapterEnabledChanged } from 'features/controlNet/store/controlNetSlice';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selector = createSelector(
|
||||
@ -22,9 +22,12 @@ const ParamIPAdapterFeatureToggle = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleChange = useCallback(() => {
|
||||
dispatch(isIPAdapterEnableToggled());
|
||||
}, [dispatch]);
|
||||
const handleChange = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
dispatch(isIPAdapterEnabledChanged(e.target.checked));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAISwitch
|
||||
|
@ -1,7 +1,9 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
@ -16,15 +18,17 @@ import { FaUndo } from 'react-icons/fa';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { PostUploadAction } from 'services/api/types';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
({ controlNet }) => {
|
||||
const { ipAdapterInfo } = controlNet;
|
||||
return { ipAdapterInfo };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamIPAdapterImage = () => {
|
||||
const ipAdapterInfo = useAppSelector(
|
||||
(state: RootState) => state.controlNet.ipAdapterInfo
|
||||
);
|
||||
|
||||
const isIPAdapterEnabled = useAppSelector(
|
||||
(state: RootState) => state.controlNet.isIPAdapterEnabled
|
||||
);
|
||||
|
||||
const { ipAdapterInfo } = useAppSelector(selector);
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
@ -71,8 +75,6 @@ const ParamIPAdapterImage = () => {
|
||||
droppableData={droppableData}
|
||||
draggableData={draggableData}
|
||||
postUploadAction={postUploadAction}
|
||||
isUploadDisabled={!isIPAdapterEnabled}
|
||||
isDropDisabled={!isIPAdapterEnabled}
|
||||
dropLabel={t('toast.setIPAdapterImage')}
|
||||
noContentFallback={
|
||||
<IAINoContentFallback
|
||||
|
@ -11,6 +11,9 @@ import { useTranslation } from 'react-i18next';
|
||||
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const ParamIPAdapterModelSelect = () => {
|
||||
const isEnabled = useAppSelector(
|
||||
(state: RootState) => state.controlNet.isIPAdapterEnabled
|
||||
);
|
||||
const ipAdapterModel = useAppSelector(
|
||||
(state: RootState) => state.controlNet.ipAdapterInfo.model
|
||||
);
|
||||
@ -90,6 +93,7 @@ const ParamIPAdapterModelSelect = () => {
|
||||
data={data}
|
||||
onChange={handleValueChanged}
|
||||
sx={{ width: '100%' }}
|
||||
disabled={!isEnabled}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
@ -1,3 +1,4 @@
|
||||
import { Box } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
@ -26,16 +27,15 @@ const ParamControlNetFeatureToggle = () => {
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<IAIInformationalPopover details="controlNetToggle">
|
||||
<IAISwitch
|
||||
label="Enable ControlNet"
|
||||
isChecked={isEnabled}
|
||||
onChange={handleChange}
|
||||
formControlProps={{
|
||||
width: '100%',
|
||||
}}
|
||||
/>
|
||||
</IAIInformationalPopover>
|
||||
<Box width="100%">
|
||||
<IAIInformationalPopover details="controlNetToggle">
|
||||
<IAISwitch
|
||||
label="Enable ControlNet"
|
||||
isChecked={isEnabled}
|
||||
onChange={handleChange}
|
||||
/>
|
||||
</IAIInformationalPopover>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
|
@ -146,16 +146,16 @@ export const controlNetSlice = createSlice({
|
||||
const { controlNetId } = action.payload;
|
||||
delete state.controlNets[controlNetId];
|
||||
},
|
||||
controlNetToggled: (
|
||||
controlNetIsEnabledChanged: (
|
||||
state,
|
||||
action: PayloadAction<{ controlNetId: string }>
|
||||
action: PayloadAction<{ controlNetId: string; isEnabled: boolean }>
|
||||
) => {
|
||||
const { controlNetId } = action.payload;
|
||||
const { controlNetId, isEnabled } = action.payload;
|
||||
const cn = state.controlNets[controlNetId];
|
||||
if (!cn) {
|
||||
return;
|
||||
}
|
||||
cn.isEnabled = !cn.isEnabled;
|
||||
cn.isEnabled = isEnabled;
|
||||
},
|
||||
controlNetImageChanged: (
|
||||
state,
|
||||
@ -377,8 +377,8 @@ export const controlNetSlice = createSlice({
|
||||
controlNetReset: () => {
|
||||
return { ...initialControlNetState };
|
||||
},
|
||||
isIPAdapterEnableToggled: (state) => {
|
||||
state.isIPAdapterEnabled = !state.isIPAdapterEnabled;
|
||||
isIPAdapterEnabledChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.isIPAdapterEnabled = action.payload;
|
||||
},
|
||||
ipAdapterImageChanged: (state, action: PayloadAction<ImageDTO | null>) => {
|
||||
state.ipAdapterInfo.adapterImage = action.payload;
|
||||
@ -450,7 +450,7 @@ export const {
|
||||
controlNetRemoved,
|
||||
controlNetImageChanged,
|
||||
controlNetProcessedImageChanged,
|
||||
controlNetToggled,
|
||||
controlNetIsEnabledChanged,
|
||||
controlNetModelChanged,
|
||||
controlNetWeightChanged,
|
||||
controlNetBeginStepPctChanged,
|
||||
@ -461,7 +461,7 @@ export const {
|
||||
controlNetProcessorTypeChanged,
|
||||
controlNetReset,
|
||||
controlNetAutoConfigToggled,
|
||||
isIPAdapterEnableToggled,
|
||||
isIPAdapterEnabledChanged,
|
||||
ipAdapterImageChanged,
|
||||
ipAdapterWeightChanged,
|
||||
ipAdapterModelChanged,
|
||||
|
@ -2,7 +2,6 @@ import {
|
||||
DragOverlay,
|
||||
MouseSensor,
|
||||
TouchSensor,
|
||||
pointerWithin,
|
||||
useSensor,
|
||||
useSensors,
|
||||
} from '@dnd-kit/core';
|
||||
@ -14,6 +13,7 @@ import { AnimatePresence, motion } from 'framer-motion';
|
||||
import { PropsWithChildren, memo, useCallback, useState } from 'react';
|
||||
import { useScaledModifer } from '../hooks/useScaledCenteredModifer';
|
||||
import { DragEndEvent, DragStartEvent, TypesafeDraggableData } from '../types';
|
||||
import { customPointerWithin } from '../util/customPointerWithin';
|
||||
import { DndContextTypesafe } from './DndContextTypesafe';
|
||||
import DragPreview from './DragPreview';
|
||||
|
||||
@ -77,7 +77,7 @@ const AppDndContext = (props: PropsWithChildren) => {
|
||||
onDragStart={handleDragStart}
|
||||
onDragEnd={handleDragEnd}
|
||||
sensors={sensors}
|
||||
collisionDetection={pointerWithin}
|
||||
collisionDetection={customPointerWithin}
|
||||
autoScroll={false}
|
||||
>
|
||||
{props.children}
|
||||
@ -87,7 +87,7 @@ const AppDndContext = (props: PropsWithChildren) => {
|
||||
style={{
|
||||
width: 'min-content',
|
||||
height: 'min-content',
|
||||
cursor: 'none',
|
||||
cursor: 'grabbing',
|
||||
userSelect: 'none',
|
||||
// expand overlay to prevent cursor from going outside it and displaying
|
||||
padding: '10rem',
|
||||
|
@ -0,0 +1,38 @@
|
||||
import { CollisionDetection, pointerWithin } from '@dnd-kit/core';
|
||||
|
||||
/**
|
||||
* Filters out droppable elements that are overflowed, then applies the pointerWithin collision detection.
|
||||
*
|
||||
* Fixes collision detection firing on droppables that are not visible, having been scrolled out of view.
|
||||
*
|
||||
* See https://github.com/clauderic/dnd-kit/issues/1198
|
||||
*/
|
||||
export const customPointerWithin: CollisionDetection = (arg) => {
|
||||
if (!arg.pointerCoordinates) {
|
||||
// sanity check
|
||||
return [];
|
||||
}
|
||||
|
||||
// Get all elements at the pointer coordinates. This excludes elements which are overflowed,
|
||||
// so it won't include the droppable elements that are scrolled out of view.
|
||||
const targetElements = document.elementsFromPoint(
|
||||
arg.pointerCoordinates.x,
|
||||
arg.pointerCoordinates.y
|
||||
);
|
||||
|
||||
const filteredDroppableContainers = arg.droppableContainers.filter(
|
||||
(container) => {
|
||||
if (!container.node.current) {
|
||||
return false;
|
||||
}
|
||||
// Only include droppable elements that are in the list of elements at the pointer coordinates.
|
||||
return targetElements.includes(container.node.current);
|
||||
}
|
||||
);
|
||||
|
||||
// Run the provided collision detection with the filtered droppable elements.
|
||||
return pointerWithin({
|
||||
...arg,
|
||||
droppableContainers: filteredDroppableContainers,
|
||||
});
|
||||
};
|
@ -5,7 +5,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||
import { contextMenusClosed } from 'features/ui/store/uiSlice';
|
||||
import { useCallback } from 'react';
|
||||
import { MouseEvent, useCallback, useRef } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import {
|
||||
Background,
|
||||
@ -21,6 +21,7 @@ import {
|
||||
OnSelectionChangeFunc,
|
||||
ProOptions,
|
||||
ReactFlow,
|
||||
XYPosition,
|
||||
} from 'reactflow';
|
||||
import { useIsValidConnection } from '../../hooks/useIsValidConnection';
|
||||
import {
|
||||
@ -79,7 +80,8 @@ export const Flow = () => {
|
||||
const edges = useAppSelector((state) => state.nodes.edges);
|
||||
const viewport = useAppSelector((state) => state.nodes.viewport);
|
||||
const { shouldSnapToGrid, selectionMode } = useAppSelector(selector);
|
||||
|
||||
const flowWrapper = useRef<HTMLDivElement>(null);
|
||||
const cursorPosition = useRef<XYPosition>();
|
||||
const isValidConnection = useIsValidConnection();
|
||||
|
||||
const [borderRadius] = useToken('radii', ['base']);
|
||||
@ -154,6 +156,17 @@ export const Flow = () => {
|
||||
flow.fitView();
|
||||
}, []);
|
||||
|
||||
const onMouseMove = useCallback((event: MouseEvent<HTMLDivElement>) => {
|
||||
const bounds = flowWrapper.current?.getBoundingClientRect();
|
||||
if (bounds) {
|
||||
const pos = $flow.get()?.project({
|
||||
x: event.clientX - bounds.left,
|
||||
y: event.clientY - bounds.top,
|
||||
});
|
||||
cursorPosition.current = pos;
|
||||
}
|
||||
}, []);
|
||||
|
||||
useHotkeys(['Ctrl+c', 'Meta+c'], (e) => {
|
||||
e.preventDefault();
|
||||
dispatch(selectionCopied());
|
||||
@ -166,18 +179,20 @@ export const Flow = () => {
|
||||
|
||||
useHotkeys(['Ctrl+v', 'Meta+v'], (e) => {
|
||||
e.preventDefault();
|
||||
dispatch(selectionPasted());
|
||||
dispatch(selectionPasted({ cursorPosition: cursorPosition.current }));
|
||||
});
|
||||
|
||||
return (
|
||||
<ReactFlow
|
||||
id="workflow-editor"
|
||||
ref={flowWrapper}
|
||||
defaultViewport={viewport}
|
||||
nodeTypes={nodeTypes}
|
||||
edgeTypes={edgeTypes}
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
onInit={onInit}
|
||||
onMouseMove={onMouseMove}
|
||||
onNodesChange={onNodesChange}
|
||||
onEdgesChange={onEdgesChange}
|
||||
onEdgesDelete={onEdgesDelete}
|
||||
|
@ -31,7 +31,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
return <Box p={2}>Output field in input: {field?.type}</Box>;
|
||||
}
|
||||
|
||||
if (field?.type === 'string' && fieldTemplate?.type === 'string') {
|
||||
if (
|
||||
(field?.type === 'string' && fieldTemplate?.type === 'string') ||
|
||||
(field?.type === 'StringPolymorphic' &&
|
||||
fieldTemplate?.type === 'StringPolymorphic')
|
||||
) {
|
||||
return (
|
||||
<StringInputField
|
||||
nodeId={nodeId}
|
||||
@ -41,7 +45,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'boolean' && fieldTemplate?.type === 'boolean') {
|
||||
if (
|
||||
(field?.type === 'boolean' && fieldTemplate?.type === 'boolean') ||
|
||||
(field?.type === 'BooleanPolymorphic' &&
|
||||
fieldTemplate?.type === 'BooleanPolymorphic')
|
||||
) {
|
||||
return (
|
||||
<BooleanInputField
|
||||
nodeId={nodeId}
|
||||
@ -53,7 +61,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
|
||||
if (
|
||||
(field?.type === 'integer' && fieldTemplate?.type === 'integer') ||
|
||||
(field?.type === 'float' && fieldTemplate?.type === 'float')
|
||||
(field?.type === 'float' && fieldTemplate?.type === 'float') ||
|
||||
(field?.type === 'FloatPolymorphic' &&
|
||||
fieldTemplate?.type === 'FloatPolymorphic') ||
|
||||
(field?.type === 'IntegerPolymorphic' &&
|
||||
fieldTemplate?.type === 'IntegerPolymorphic')
|
||||
) {
|
||||
return (
|
||||
<NumberInputField
|
||||
@ -74,7 +86,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'ImageField' && fieldTemplate?.type === 'ImageField') {
|
||||
if (
|
||||
(field?.type === 'ImageField' && fieldTemplate?.type === 'ImageField') ||
|
||||
(field?.type === 'ImagePolymorphic' &&
|
||||
fieldTemplate?.type === 'ImagePolymorphic')
|
||||
) {
|
||||
return (
|
||||
<ImageInputField
|
||||
nodeId={nodeId}
|
||||
|
@ -4,12 +4,17 @@ import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
BooleanInputFieldTemplate,
|
||||
BooleanInputFieldValue,
|
||||
BooleanPolymorphicInputFieldTemplate,
|
||||
BooleanPolymorphicInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
|
||||
const BooleanInputFieldComponent = (
|
||||
props: FieldComponentProps<BooleanInputFieldValue, BooleanInputFieldTemplate>
|
||||
props: FieldComponentProps<
|
||||
BooleanInputFieldValue | BooleanPolymorphicInputFieldValue,
|
||||
BooleanInputFieldTemplate | BooleanPolymorphicInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
|
@ -12,6 +12,8 @@ import {
|
||||
FieldComponentProps,
|
||||
ImageInputFieldTemplate,
|
||||
ImageInputFieldValue,
|
||||
ImagePolymorphicInputFieldTemplate,
|
||||
ImagePolymorphicInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { FaUndo } from 'react-icons/fa';
|
||||
@ -19,7 +21,10 @@ import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { PostUploadAction } from 'services/api/types';
|
||||
|
||||
const ImageInputFieldComponent = (
|
||||
props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate>
|
||||
props: FieldComponentProps<
|
||||
ImageInputFieldValue | ImagePolymorphicInputFieldValue,
|
||||
ImageInputFieldTemplate | ImagePolymorphicInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
@ -12,15 +12,25 @@ import {
|
||||
FieldComponentProps,
|
||||
FloatInputFieldTemplate,
|
||||
FloatInputFieldValue,
|
||||
FloatPolymorphicInputFieldTemplate,
|
||||
FloatPolymorphicInputFieldValue,
|
||||
IntegerInputFieldTemplate,
|
||||
IntegerInputFieldValue,
|
||||
IntegerPolymorphicInputFieldTemplate,
|
||||
IntegerPolymorphicInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
const NumberInputFieldComponent = (
|
||||
props: FieldComponentProps<
|
||||
IntegerInputFieldValue | FloatInputFieldValue,
|
||||
IntegerInputFieldTemplate | FloatInputFieldTemplate
|
||||
| IntegerInputFieldValue
|
||||
| IntegerPolymorphicInputFieldValue
|
||||
| FloatInputFieldValue
|
||||
| FloatPolymorphicInputFieldValue,
|
||||
| IntegerInputFieldTemplate
|
||||
| IntegerPolymorphicInputFieldTemplate
|
||||
| FloatInputFieldTemplate
|
||||
| FloatPolymorphicInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field, fieldTemplate } = props;
|
||||
|
@ -6,11 +6,16 @@ import {
|
||||
StringInputFieldTemplate,
|
||||
StringInputFieldValue,
|
||||
FieldComponentProps,
|
||||
StringPolymorphicInputFieldValue,
|
||||
StringPolymorphicInputFieldTemplate,
|
||||
} from 'features/nodes/types/types';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
|
||||
const StringInputFieldComponent = (
|
||||
props: FieldComponentProps<StringInputFieldValue, StringInputFieldTemplate>
|
||||
props: FieldComponentProps<
|
||||
StringInputFieldValue | StringPolymorphicInputFieldValue,
|
||||
StringInputFieldTemplate | StringPolymorphicInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field, fieldTemplate } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
@ -45,6 +45,7 @@ const NodeEditorPanelGroup = () => {
|
||||
<PanelGroup
|
||||
ref={panelGroupRef}
|
||||
id="workflow-panel-group"
|
||||
autoSaveId="workflow-panel-group"
|
||||
direction="vertical"
|
||||
style={{ height: '100%', width: '100%' }}
|
||||
storage={panelStorage}
|
||||
|
@ -5,6 +5,10 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { map } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
import {
|
||||
POLYMORPHIC_TYPES,
|
||||
TYPES_WITH_INPUT_COMPONENTS,
|
||||
} from '../types/constants';
|
||||
|
||||
export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
@ -21,7 +25,12 @@ export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
|
||||
return [];
|
||||
}
|
||||
return map(nodeTemplate.inputs)
|
||||
.filter((field) => ['any', 'direct'].includes(field.input))
|
||||
.filter(
|
||||
(field) =>
|
||||
(['any', 'direct'].includes(field.input) ||
|
||||
POLYMORPHIC_TYPES.includes(field.type)) &&
|
||||
TYPES_WITH_INPUT_COMPONENTS.includes(field.type)
|
||||
)
|
||||
.filter((field) => !field.ui_hidden)
|
||||
.sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0))
|
||||
.map((field) => field.name)
|
||||
|
@ -4,6 +4,10 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { map } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import {
|
||||
POLYMORPHIC_TYPES,
|
||||
TYPES_WITH_INPUT_COMPONENTS,
|
||||
} from '../types/constants';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
|
||||
export const useConnectionInputFieldNames = (nodeId: string) => {
|
||||
@ -21,7 +25,12 @@ export const useConnectionInputFieldNames = (nodeId: string) => {
|
||||
return [];
|
||||
}
|
||||
return map(nodeTemplate.inputs)
|
||||
.filter((field) => field.input === 'connection')
|
||||
.filter(
|
||||
(field) =>
|
||||
(field.input === 'connection' &&
|
||||
!POLYMORPHIC_TYPES.includes(field.type)) ||
|
||||
!TYPES_WITH_INPUT_COMPONENTS.includes(field.type)
|
||||
)
|
||||
.filter((field) => !field.ui_hidden)
|
||||
.sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0))
|
||||
.map((field) => field.name)
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
|
||||
import { cloneDeep, forEach, isEqual, map, uniqBy } from 'lodash-es';
|
||||
import { cloneDeep, forEach, isEqual, uniqBy } from 'lodash-es';
|
||||
import {
|
||||
addEdge,
|
||||
applyEdgeChanges,
|
||||
@ -16,9 +16,9 @@ import {
|
||||
OnConnectStartParams,
|
||||
SelectionMode,
|
||||
Viewport,
|
||||
XYPosition,
|
||||
} from 'reactflow';
|
||||
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||
import { sessionCanceled, sessionInvoked } from 'services/api/thunks/session';
|
||||
import { ImageField } from 'services/api/types';
|
||||
import {
|
||||
appSocketGeneratorProgress,
|
||||
@ -722,8 +722,30 @@ const nodesSlice = createSlice({
|
||||
selectionCopied: (state) => {
|
||||
state.nodesToCopy = state.nodes.filter((n) => n.selected).map(cloneDeep);
|
||||
state.edgesToCopy = state.edges.filter((e) => e.selected).map(cloneDeep);
|
||||
|
||||
if (state.nodesToCopy.length > 0) {
|
||||
const averagePosition = { x: 0, y: 0 };
|
||||
state.nodesToCopy.forEach((e) => {
|
||||
const xOffset = 0.15 * (e.width ?? 0);
|
||||
const yOffset = 0.5 * (e.height ?? 0);
|
||||
averagePosition.x += e.position.x + xOffset;
|
||||
averagePosition.y += e.position.y + yOffset;
|
||||
});
|
||||
|
||||
averagePosition.x /= state.nodesToCopy.length;
|
||||
averagePosition.y /= state.nodesToCopy.length;
|
||||
|
||||
state.nodesToCopy.forEach((e) => {
|
||||
e.position.x -= averagePosition.x;
|
||||
e.position.y -= averagePosition.y;
|
||||
});
|
||||
}
|
||||
},
|
||||
selectionPasted: (state) => {
|
||||
selectionPasted: (
|
||||
state,
|
||||
action: PayloadAction<{ cursorPosition?: XYPosition }>
|
||||
) => {
|
||||
const { cursorPosition } = action.payload;
|
||||
const newNodes = state.nodesToCopy.map(cloneDeep);
|
||||
const oldNodeIds = newNodes.map((n) => n.data.id);
|
||||
const newEdges = state.edgesToCopy
|
||||
@ -752,8 +774,8 @@ const nodesSlice = createSlice({
|
||||
|
||||
const position = findUnoccupiedPosition(
|
||||
state.nodes,
|
||||
node.position.x,
|
||||
node.position.y
|
||||
node.position.x + (cursorPosition?.x ?? 0),
|
||||
node.position.y + (cursorPosition?.y ?? 0)
|
||||
);
|
||||
|
||||
node.position = position;
|
||||
@ -853,28 +875,10 @@ const nodesSlice = createSlice({
|
||||
node.progressImage = progress_image ?? null;
|
||||
}
|
||||
});
|
||||
builder.addCase(sessionInvoked.fulfilled, (state) => {
|
||||
forEach(state.nodeExecutionStates, (nes) => {
|
||||
nes.status = NodeStatus.PENDING;
|
||||
nes.error = null;
|
||||
nes.progress = null;
|
||||
nes.progressImage = null;
|
||||
nes.outputs = [];
|
||||
});
|
||||
});
|
||||
builder.addCase(sessionCanceled.fulfilled, (state) => {
|
||||
map(state.nodeExecutionStates, (nes) => {
|
||||
if (nes.status === NodeStatus.IN_PROGRESS) {
|
||||
nes.status = NodeStatus.PENDING;
|
||||
}
|
||||
});
|
||||
});
|
||||
builder.addCase(appSocketQueueItemStatusChanged, (state, action) => {
|
||||
if (
|
||||
['completed', 'canceled', 'failed'].includes(action.payload.data.status)
|
||||
) {
|
||||
if (['in_progress'].includes(action.payload.data.status)) {
|
||||
forEach(state.nodeExecutionStates, (nes) => {
|
||||
nes.status = NodeStatus.PENDING;
|
||||
nes.status = NodeStatus.IN_PROGRESS;
|
||||
nes.error = null;
|
||||
nes.progress = null;
|
||||
nes.progressImage = null;
|
||||
|
@ -102,6 +102,29 @@ export const POLYMORPHIC_TO_SINGLE_MAP = {
|
||||
T2IAdapterPolymorphic: 'T2IAdapterField',
|
||||
};
|
||||
|
||||
export const TYPES_WITH_INPUT_COMPONENTS = [
|
||||
'string',
|
||||
'StringPolymorphic',
|
||||
'boolean',
|
||||
'BooleanPolymorphic',
|
||||
'integer',
|
||||
'float',
|
||||
'FloatPolymorphic',
|
||||
'IntegerPolymorphic',
|
||||
'enum',
|
||||
'ImageField',
|
||||
'ImagePolymorphic',
|
||||
'MainModelField',
|
||||
'SDXLRefinerModelField',
|
||||
'VaeModelField',
|
||||
'LoRAModelField',
|
||||
'ControlNetModelField',
|
||||
'ColorField',
|
||||
'SDXLMainModelField',
|
||||
'Scheduler',
|
||||
'IPAdapterModelField',
|
||||
];
|
||||
|
||||
export const isPolymorphicItemType = (
|
||||
itemType: string | undefined
|
||||
): itemType is keyof typeof SINGLE_TO_POLYMORPHIC_MAP =>
|
||||
|
@ -224,7 +224,7 @@ export type IntegerCollectionInputFieldValue = z.infer<
|
||||
|
||||
export const zIntegerPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('IntegerPolymorphic'),
|
||||
value: z.union([z.number().int(), z.array(z.number().int())]).optional(),
|
||||
value: z.number().int().optional(),
|
||||
});
|
||||
export type IntegerPolymorphicInputFieldValue = z.infer<
|
||||
typeof zIntegerPolymorphicInputFieldValue
|
||||
@ -246,7 +246,7 @@ export type FloatCollectionInputFieldValue = z.infer<
|
||||
|
||||
export const zFloatPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('FloatPolymorphic'),
|
||||
value: z.union([z.number(), z.array(z.number())]).optional(),
|
||||
value: z.number().optional(),
|
||||
});
|
||||
export type FloatPolymorphicInputFieldValue = z.infer<
|
||||
typeof zFloatPolymorphicInputFieldValue
|
||||
@ -268,7 +268,7 @@ export type StringCollectionInputFieldValue = z.infer<
|
||||
|
||||
export const zStringPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('StringPolymorphic'),
|
||||
value: z.union([z.string(), z.array(z.string())]).optional(),
|
||||
value: z.string().optional(),
|
||||
});
|
||||
export type StringPolymorphicInputFieldValue = z.infer<
|
||||
typeof zStringPolymorphicInputFieldValue
|
||||
@ -290,7 +290,7 @@ export type BooleanCollectionInputFieldValue = z.infer<
|
||||
|
||||
export const zBooleanPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('BooleanPolymorphic'),
|
||||
value: z.union([z.boolean(), z.array(z.boolean())]).optional(),
|
||||
value: z.boolean().optional(),
|
||||
});
|
||||
export type BooleanPolymorphicInputFieldValue = z.infer<
|
||||
typeof zBooleanPolymorphicInputFieldValue
|
||||
@ -542,7 +542,7 @@ export type ImageInputFieldValue = z.infer<typeof zImageInputFieldValue>;
|
||||
|
||||
export const zImagePolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('ImagePolymorphic'),
|
||||
value: z.union([zImageField, z.array(zImageField)]).optional(),
|
||||
value: zImageField.optional(),
|
||||
});
|
||||
export type ImagePolymorphicInputFieldValue = z.infer<
|
||||
typeof zImagePolymorphicInputFieldValue
|
||||
|
@ -8,7 +8,11 @@ import {
|
||||
MetadataAccumulatorInvocation,
|
||||
} from 'services/api/types';
|
||||
import { NonNullableGraph } from '../../types/types';
|
||||
import { CONTROL_NET_COLLECT, METADATA_ACCUMULATOR } from './constants';
|
||||
import {
|
||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
CONTROL_NET_COLLECT,
|
||||
METADATA_ACCUMULATOR,
|
||||
} from './constants';
|
||||
|
||||
export const addControlNetToLinearGraph = (
|
||||
state: RootState,
|
||||
@ -100,6 +104,16 @@ export const addControlNetToLinearGraph = (
|
||||
field: 'item',
|
||||
},
|
||||
});
|
||||
|
||||
if (CANVAS_COHERENCE_DENOISE_LATENTS in graph.nodes) {
|
||||
graph.edges.push({
|
||||
source: { node_id: controlNetNode.id, field: 'control' },
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
field: 'control',
|
||||
},
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { IPAdapterInvocation } from 'services/api/types';
|
||||
import { NonNullableGraph } from '../../types/types';
|
||||
import { IP_ADAPTER } from './constants';
|
||||
import { CANVAS_COHERENCE_DENOISE_LATENTS, IP_ADAPTER } from './constants';
|
||||
|
||||
export const addIPAdapterToLinearGraph = (
|
||||
state: RootState,
|
||||
@ -55,5 +55,15 @@ export const addIPAdapterToLinearGraph = (
|
||||
field: 'ip_adapter',
|
||||
},
|
||||
});
|
||||
|
||||
if (CANVAS_COHERENCE_DENOISE_LATENTS in graph.nodes) {
|
||||
graph.edges.push({
|
||||
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
field: 'ip_adapter',
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -45,7 +45,6 @@ export const buildCanvasSDXLImageToImageGraph = (
|
||||
seed,
|
||||
steps,
|
||||
vaePrecision,
|
||||
clipSkip,
|
||||
shouldUseCpuNoise,
|
||||
seamlessXAxis,
|
||||
seamlessYAxis,
|
||||
@ -339,7 +338,6 @@ export const buildCanvasSDXLImageToImageGraph = (
|
||||
vae: undefined, // option; set in addVAEToGraph
|
||||
controlnets: [], // populated in addControlNetToLinearGraph
|
||||
loras: [], // populated in addLoRAsToGraph
|
||||
clip_skip: clipSkip,
|
||||
strength,
|
||||
init_image: initialImage.image_name,
|
||||
};
|
||||
|
@ -46,7 +46,6 @@ export const buildCanvasSDXLTextToImageGraph = (
|
||||
seed,
|
||||
steps,
|
||||
vaePrecision,
|
||||
clipSkip,
|
||||
shouldUseCpuNoise,
|
||||
seamlessXAxis,
|
||||
seamlessYAxis,
|
||||
@ -321,7 +320,6 @@ export const buildCanvasSDXLTextToImageGraph = (
|
||||
vae: undefined, // option; set in addVAEToGraph
|
||||
controlnets: [], // populated in addControlNetToLinearGraph
|
||||
loras: [], // populated in addLoRAsToGraph
|
||||
clip_skip: clipSkip,
|
||||
};
|
||||
|
||||
graph.edges.push({
|
||||
|
@ -49,7 +49,6 @@ export const buildLinearSDXLImageToImageGraph = (
|
||||
shouldFitToWidthHeight,
|
||||
width,
|
||||
height,
|
||||
clipSkip,
|
||||
shouldUseCpuNoise,
|
||||
vaePrecision,
|
||||
seamlessXAxis,
|
||||
@ -349,7 +348,6 @@ export const buildLinearSDXLImageToImageGraph = (
|
||||
vae: undefined,
|
||||
controlnets: [],
|
||||
loras: [],
|
||||
clip_skip: clipSkip,
|
||||
strength: strength,
|
||||
init_image: initialImage.imageName,
|
||||
positive_style_prompt: positiveStylePrompt,
|
||||
|
@ -38,7 +38,6 @@ export const buildLinearSDXLTextToImageGraph = (
|
||||
steps,
|
||||
width,
|
||||
height,
|
||||
clipSkip,
|
||||
shouldUseCpuNoise,
|
||||
vaePrecision,
|
||||
seamlessXAxis,
|
||||
@ -243,7 +242,6 @@ export const buildLinearSDXLTextToImageGraph = (
|
||||
vae: undefined,
|
||||
controlnets: [],
|
||||
loras: [],
|
||||
clip_skip: clipSkip,
|
||||
positive_style_prompt: positiveStylePrompt,
|
||||
negative_style_prompt: negativeStylePrompt,
|
||||
};
|
||||
|
@ -13,16 +13,16 @@ import ParamClipSkip from './ParamClipSkip';
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
(state: RootState) => {
|
||||
const { clipSkip, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise } =
|
||||
const { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise } =
|
||||
state.generation;
|
||||
|
||||
return { clipSkip, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise };
|
||||
return { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
export default function ParamAdvancedCollapse() {
|
||||
const { clipSkip, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise } =
|
||||
const { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise } =
|
||||
useAppSelector(selector);
|
||||
const { t } = useTranslation();
|
||||
const activeLabel = useMemo(() => {
|
||||
@ -34,7 +34,7 @@ export default function ParamAdvancedCollapse() {
|
||||
activeLabel.push(t('parameters.gpuNoise'));
|
||||
}
|
||||
|
||||
if (clipSkip > 0) {
|
||||
if (clipSkip > 0 && model && model.base_model !== 'sdxl') {
|
||||
activeLabel.push(
|
||||
t('parameters.clipSkipWithLayerCount', { layerCount: clipSkip })
|
||||
);
|
||||
@ -49,15 +49,19 @@ export default function ParamAdvancedCollapse() {
|
||||
}
|
||||
|
||||
return activeLabel.join(', ');
|
||||
}, [clipSkip, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise, t]);
|
||||
}, [clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise, t]);
|
||||
|
||||
return (
|
||||
<IAICollapse label={t('common.advanced')} activeLabel={activeLabel}>
|
||||
<Flex sx={{ flexDir: 'column', gap: 2 }}>
|
||||
<ParamSeamless />
|
||||
<Divider />
|
||||
<ParamClipSkip />
|
||||
<Divider pt={2} />
|
||||
{model && model?.base_model !== 'sdxl' && (
|
||||
<>
|
||||
<ParamClipSkip />
|
||||
<Divider pt={2} />
|
||||
</>
|
||||
)}
|
||||
<ParamCpuNoiseToggle />
|
||||
</Flex>
|
||||
</IAICollapse>
|
||||
|
@ -42,6 +42,10 @@ export default function ParamClipSkip() {
|
||||
return clipSkipMap[model.base_model].markers;
|
||||
}, [model]);
|
||||
|
||||
if (model?.base_model === 'sdxl') {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<IAIInformationalPopover details="clipSkip">
|
||||
<IAISlider
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { Flex, Spacer, Text } from '@chakra-ui/react';
|
||||
import { Box, Flex, Spacer, Text } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
@ -94,20 +94,22 @@ export default function ParamBoundingBoxSize() {
|
||||
}}
|
||||
>
|
||||
<Flex alignItems="center" gap={2}>
|
||||
<IAIInformationalPopover details="paramRatio">
|
||||
<Text
|
||||
sx={{
|
||||
fontSize: 'sm',
|
||||
width: 'full',
|
||||
color: 'base.700',
|
||||
_dark: {
|
||||
color: 'base.300',
|
||||
},
|
||||
}}
|
||||
>
|
||||
{t('parameters.aspectRatio')}
|
||||
</Text>
|
||||
</IAIInformationalPopover>
|
||||
<Box width="full">
|
||||
<IAIInformationalPopover details="paramRatio">
|
||||
<Text
|
||||
sx={{
|
||||
fontSize: 'sm',
|
||||
width: 'full',
|
||||
color: 'base.700',
|
||||
_dark: {
|
||||
color: 'base.300',
|
||||
},
|
||||
}}
|
||||
>
|
||||
{t('parameters.aspectRatio')}
|
||||
</Text>
|
||||
</IAIInformationalPopover>
|
||||
</Box>
|
||||
<Spacer />
|
||||
<ParamAspectRatio />
|
||||
<IAIIconButton
|
||||
|
@ -23,17 +23,18 @@ import { v4 as uuidv4 } from 'uuid';
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
({ controlNet }) => {
|
||||
const { controlNets, isEnabled, isIPAdapterEnabled } = controlNet;
|
||||
const { controlNets, isEnabled, isIPAdapterEnabled, ipAdapterInfo } =
|
||||
controlNet;
|
||||
|
||||
const validControlNets = getValidControlNets(controlNets);
|
||||
|
||||
const isIPAdapterValid = ipAdapterInfo.model && ipAdapterInfo.adapterImage;
|
||||
let activeLabel = undefined;
|
||||
|
||||
if (isEnabled && validControlNets.length > 0) {
|
||||
activeLabel = `${validControlNets.length} ControlNet`;
|
||||
}
|
||||
|
||||
if (isIPAdapterEnabled) {
|
||||
if (isIPAdapterEnabled && isIPAdapterValid) {
|
||||
if (activeLabel) {
|
||||
activeLabel = `${activeLabel}, IP Adapter`;
|
||||
} else {
|
||||
|
@ -61,7 +61,7 @@ const ParamIterations = ({ asSlider }: Props) => {
|
||||
}, [dispatch, initial]);
|
||||
|
||||
return asSlider || shouldUseSliders ? (
|
||||
<IAIInformationalPopover details="paramImages">
|
||||
<IAIInformationalPopover details="paramIterations">
|
||||
<IAISlider
|
||||
label={t('parameters.iterations')}
|
||||
step={step}
|
||||
@ -77,7 +77,7 @@ const ParamIterations = ({ asSlider }: Props) => {
|
||||
/>
|
||||
</IAIInformationalPopover>
|
||||
) : (
|
||||
<IAIInformationalPopover details="paramImages">
|
||||
<IAIInformationalPopover details="paramIterations">
|
||||
<IAINumberInput
|
||||
label={t('parameters.iterations')}
|
||||
step={step}
|
||||
|
@ -1,6 +1,7 @@
|
||||
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIInformationalPopover from 'common/components/IAIInformationalPopover';
|
||||
import IAITextarea from 'common/components/IAITextarea';
|
||||
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
|
||||
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
|
||||
@ -9,7 +10,6 @@ import { ChangeEvent, KeyboardEvent, memo, useCallback, useRef } from 'react';
|
||||
import { flushSync } from 'react-dom';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
|
||||
import IAIInformationalPopover from 'common/components/IAIInformationalPopover';
|
||||
|
||||
const ParamNegativeConditioning = () => {
|
||||
const negativePrompt = useAppSelector(
|
||||
@ -76,13 +76,16 @@ const ParamNegativeConditioning = () => {
|
||||
const isEmbeddingEnabled = useFeatureStatus('embedding').isFeatureEnabled;
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<ParamEmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={handleSelectEmbedding}
|
||||
>
|
||||
<IAIInformationalPopover details="paramNegativeConditioning">
|
||||
<IAIInformationalPopover
|
||||
placement="right"
|
||||
details="paramNegativeConditioning"
|
||||
>
|
||||
<FormControl>
|
||||
<ParamEmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={handleSelectEmbedding}
|
||||
>
|
||||
<IAITextarea
|
||||
id="negativePrompt"
|
||||
name="negativePrompt"
|
||||
@ -95,20 +98,20 @@ const ParamNegativeConditioning = () => {
|
||||
minH={16}
|
||||
{...(isEmbeddingEnabled && { onKeyDown: handleKeyDown })}
|
||||
/>
|
||||
</IAIInformationalPopover>
|
||||
</ParamEmbeddingPopover>
|
||||
{!isOpen && isEmbeddingEnabled && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
insetInlineEnd: 0,
|
||||
}}
|
||||
>
|
||||
<AddEmbeddingButton onClick={onOpen} />
|
||||
</Box>
|
||||
)}
|
||||
</FormControl>
|
||||
</ParamEmbeddingPopover>
|
||||
{!isOpen && isEmbeddingEnabled && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
insetInlineEnd: 0,
|
||||
}}
|
||||
>
|
||||
<AddEmbeddingButton onClick={onOpen} />
|
||||
</Box>
|
||||
)}
|
||||
</FormControl>
|
||||
</IAIInformationalPopover>
|
||||
);
|
||||
};
|
||||
|
||||
|
@ -104,13 +104,16 @@ const ParamPositiveConditioning = () => {
|
||||
|
||||
return (
|
||||
<Box position="relative">
|
||||
<FormControl>
|
||||
<ParamEmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={handleSelectEmbedding}
|
||||
>
|
||||
<IAIInformationalPopover details="paramPositiveConditioning">
|
||||
<IAIInformationalPopover
|
||||
placement="right"
|
||||
details="paramPositiveConditioning"
|
||||
>
|
||||
<FormControl>
|
||||
<ParamEmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={handleSelectEmbedding}
|
||||
>
|
||||
<IAITextarea
|
||||
id="prompt"
|
||||
name="prompt"
|
||||
@ -122,9 +125,9 @@ const ParamPositiveConditioning = () => {
|
||||
resize="vertical"
|
||||
minH={32}
|
||||
/>
|
||||
</IAIInformationalPopover>
|
||||
</ParamEmbeddingPopover>
|
||||
</FormControl>
|
||||
</ParamEmbeddingPopover>
|
||||
</FormControl>
|
||||
</IAIInformationalPopover>
|
||||
{!isOpen && isEmbeddingEnabled && (
|
||||
<Box
|
||||
sx={{
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { Flex, Spacer, Text } from '@chakra-ui/react';
|
||||
import { Box, Flex, Spacer, Text } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
@ -84,20 +84,21 @@ export default function ParamSize() {
|
||||
}}
|
||||
>
|
||||
<Flex alignItems="center" gap={2}>
|
||||
<IAIInformationalPopover details="paramRatio">
|
||||
<Text
|
||||
sx={{
|
||||
fontSize: 'sm',
|
||||
width: 'full',
|
||||
color: 'base.700',
|
||||
_dark: {
|
||||
color: 'base.300',
|
||||
},
|
||||
}}
|
||||
>
|
||||
{t('parameters.aspectRatio')}
|
||||
</Text>
|
||||
</IAIInformationalPopover>
|
||||
<Box width="full">
|
||||
<IAIInformationalPopover details="paramRatio">
|
||||
<Text
|
||||
sx={{
|
||||
fontSize: 'sm',
|
||||
color: 'base.700',
|
||||
_dark: {
|
||||
color: 'base.300',
|
||||
},
|
||||
}}
|
||||
>
|
||||
{t('parameters.aspectRatio')}
|
||||
</Text>
|
||||
</IAIInformationalPopover>
|
||||
</Box>
|
||||
<Spacer />
|
||||
<ParamAspectRatio />
|
||||
<IAIIconButton
|
||||
|
@ -119,8 +119,8 @@ const ParamMainModelSelect = () => {
|
||||
data={[]}
|
||||
/>
|
||||
) : (
|
||||
<IAIInformationalPopover details="paramModel" placement="bottom">
|
||||
<Flex w="100%" alignItems="center" gap={3}>
|
||||
<Flex w="100%" alignItems="center" gap={3}>
|
||||
<IAIInformationalPopover details="paramModel" placement="bottom">
|
||||
<IAIMantineSearchableSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label={t('modelManager.model')}
|
||||
@ -134,13 +134,13 @@ const ParamMainModelSelect = () => {
|
||||
onChange={handleChangeModel}
|
||||
w="100%"
|
||||
/>
|
||||
{isSyncModelEnabled && (
|
||||
<Box mt={7}>
|
||||
<SyncModelsButton iconMode />
|
||||
</Box>
|
||||
)}
|
||||
</Flex>
|
||||
</IAIInformationalPopover>
|
||||
</IAIInformationalPopover>
|
||||
{isSyncModelEnabled && (
|
||||
<Box mt={7}>
|
||||
<SyncModelsButton iconMode />
|
||||
</Box>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
|
@ -21,7 +21,7 @@ import {
|
||||
loraModelsAdapter,
|
||||
useGetLoRAModelsQuery,
|
||||
} from '../../../services/api/endpoints/models';
|
||||
import { loraRecalled } from '../../lora/store/loraSlice';
|
||||
import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice';
|
||||
import { initialImageSelected, modelSelected } from '../store/actions';
|
||||
import {
|
||||
setCfgScale,
|
||||
@ -509,6 +509,7 @@ export const useRecallParameters = () => {
|
||||
dispatch(setRefinerStart(refiner_start));
|
||||
}
|
||||
|
||||
dispatch(lorasCleared());
|
||||
loras?.forEach((lora) => {
|
||||
const result = prepareLoRAMetadataItem(lora);
|
||||
if (result.lora) {
|
||||
|
@ -12,12 +12,12 @@ type Props = {
|
||||
|
||||
const CancelCurrentQueueItemButton = ({ asIconButton, sx }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const { cancelQueueItem, isLoading, currentQueueItemId } =
|
||||
const { cancelQueueItem, isLoading, isDisabled } =
|
||||
useCancelCurrentQueueItem();
|
||||
|
||||
return (
|
||||
<QueueButton
|
||||
isDisabled={!currentQueueItemId}
|
||||
isDisabled={isDisabled}
|
||||
isLoading={isLoading}
|
||||
asIconButton={asIconButton}
|
||||
label={t('queue.cancel')}
|
||||
|
@ -0,0 +1,22 @@
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useClearInvocationCache } from '../hooks/useClearInvocationCache';
|
||||
|
||||
const ClearInvocationCacheButton = () => {
|
||||
const { t } = useTranslation();
|
||||
const { clearInvocationCache, isDisabled, isLoading } =
|
||||
useClearInvocationCache();
|
||||
|
||||
return (
|
||||
<IAIButton
|
||||
isDisabled={isDisabled}
|
||||
isLoading={isLoading}
|
||||
onClick={clearInvocationCache}
|
||||
>
|
||||
{t('invocationCache.clear')}
|
||||
</IAIButton>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ClearInvocationCacheButton);
|
@ -13,7 +13,7 @@ type Props = {
|
||||
|
||||
const ClearQueueButton = ({ asIconButton, sx }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const { clearQueue, isLoading, queueStatus } = useClearQueue();
|
||||
const { clearQueue, isLoading, isDisabled } = useClearQueue();
|
||||
|
||||
return (
|
||||
<IAIAlertDialog
|
||||
@ -22,7 +22,7 @@ const ClearQueueButton = ({ asIconButton, sx }: Props) => {
|
||||
acceptButtonText={t('queue.clear')}
|
||||
triggerComponent={
|
||||
<QueueButton
|
||||
isDisabled={!queueStatus?.queue.total}
|
||||
isDisabled={isDisabled}
|
||||
isLoading={isLoading}
|
||||
asIconButton={asIconButton}
|
||||
label={t('queue.clear')}
|
||||
|
@ -0,0 +1,55 @@
|
||||
import { ButtonGroup } from '@chakra-ui/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo';
|
||||
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
|
||||
import ClearInvocationCacheButton from './ClearInvocationCacheButton';
|
||||
import ToggleInvocationCacheButton from './ToggleInvocationCacheButton';
|
||||
import StatusStatGroup from './common/StatusStatGroup';
|
||||
import StatusStatItem from './common/StatusStatItem';
|
||||
|
||||
const InvocationCacheStatus = () => {
|
||||
const { t } = useTranslation();
|
||||
const isConnected = useAppSelector((state) => state.system.isConnected);
|
||||
const { data: queueStatus } = useGetQueueStatusQuery(undefined);
|
||||
const { data: cacheStatus } = useGetInvocationCacheStatusQuery(undefined, {
|
||||
pollingInterval:
|
||||
isConnected &&
|
||||
queueStatus?.processor.is_started &&
|
||||
queueStatus?.queue.pending > 0
|
||||
? 5000
|
||||
: 0,
|
||||
});
|
||||
|
||||
return (
|
||||
<StatusStatGroup>
|
||||
<StatusStatItem
|
||||
isDisabled={!cacheStatus?.enabled}
|
||||
label={t('invocationCache.cacheSize')}
|
||||
value={cacheStatus?.size ?? 0}
|
||||
/>
|
||||
<StatusStatItem
|
||||
isDisabled={!cacheStatus?.enabled}
|
||||
label={t('invocationCache.hits')}
|
||||
value={cacheStatus?.hits ?? 0}
|
||||
/>
|
||||
<StatusStatItem
|
||||
isDisabled={!cacheStatus?.enabled}
|
||||
label={t('invocationCache.misses')}
|
||||
value={cacheStatus?.misses ?? 0}
|
||||
/>
|
||||
<StatusStatItem
|
||||
isDisabled={!cacheStatus?.enabled}
|
||||
label={t('invocationCache.maxCacheSize')}
|
||||
value={cacheStatus?.max_size ?? 0}
|
||||
/>
|
||||
<ButtonGroup w={24} orientation="vertical" size="xs">
|
||||
<ClearInvocationCacheButton />
|
||||
<ToggleInvocationCacheButton />
|
||||
</ButtonGroup>
|
||||
</StatusStatGroup>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(InvocationCacheStatus);
|
@ -10,14 +10,14 @@ type Props = {
|
||||
|
||||
const PauseProcessorButton = ({ asIconButton }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const { pauseProcessor, isLoading, isStarted } = usePauseProcessor();
|
||||
const { pauseProcessor, isLoading, isDisabled } = usePauseProcessor();
|
||||
|
||||
return (
|
||||
<QueueButton
|
||||
asIconButton={asIconButton}
|
||||
label={t('queue.pause')}
|
||||
tooltip={t('queue.pauseTooltip')}
|
||||
isDisabled={!isStarted}
|
||||
isDisabled={isDisabled}
|
||||
isLoading={isLoading}
|
||||
icon={<FaPause />}
|
||||
onClick={pauseProcessor}
|
||||
|
@ -10,11 +10,11 @@ type Props = {
|
||||
|
||||
const PruneQueueButton = ({ asIconButton }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const { pruneQueue, isLoading, finishedCount } = usePruneQueue();
|
||||
const { pruneQueue, isLoading, finishedCount, isDisabled } = usePruneQueue();
|
||||
|
||||
return (
|
||||
<QueueButton
|
||||
isDisabled={!finishedCount}
|
||||
isDisabled={isDisabled}
|
||||
isLoading={isLoading}
|
||||
asIconButton={asIconButton}
|
||||
label={t('queue.prune')}
|
||||
|
@ -1,9 +1,10 @@
|
||||
import { ChakraProps } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useQueueBack } from '../hooks/useQueueBack';
|
||||
import EnqueueButtonTooltip from './QueueButtonTooltip';
|
||||
import QueueButton from './common/QueueButton';
|
||||
import { ChakraProps } from '@chakra-ui/react';
|
||||
import GreyscaleInvokeAIIcon from 'common/components/GreyscaleInvokeAIIcon';
|
||||
|
||||
type Props = {
|
||||
asIconButton?: boolean;
|
||||
@ -23,6 +24,7 @@ const QueueBackButton = ({ asIconButton, sx }: Props) => {
|
||||
onClick={queueBack}
|
||||
tooltip={<EnqueueButtonTooltip />}
|
||||
sx={sx}
|
||||
icon={asIconButton ? <GreyscaleInvokeAIIcon /> : undefined}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
@ -1,38 +1,39 @@
|
||||
import { Stat, StatGroup, StatLabel, StatNumber } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
|
||||
import StatusStatGroup from './common/StatusStatGroup';
|
||||
import StatusStatItem from './common/StatusStatItem';
|
||||
|
||||
const QueueStatus = () => {
|
||||
const { data: queueStatus } = useGetQueueStatusQuery();
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<StatGroup alignItems="center" justifyContent="center" w="full" h="full">
|
||||
<Stat w={24}>
|
||||
<StatLabel>{t('queue.in_progress')}</StatLabel>
|
||||
<StatNumber>{queueStatus?.queue.in_progress ?? 0}</StatNumber>
|
||||
</Stat>
|
||||
<Stat w={24}>
|
||||
<StatLabel>{t('queue.pending')}</StatLabel>
|
||||
<StatNumber>{queueStatus?.queue.pending ?? 0}</StatNumber>
|
||||
</Stat>
|
||||
<Stat w={24}>
|
||||
<StatLabel>{t('queue.completed')}</StatLabel>
|
||||
<StatNumber>{queueStatus?.queue.completed ?? 0}</StatNumber>
|
||||
</Stat>
|
||||
<Stat w={24}>
|
||||
<StatLabel>{t('queue.failed')}</StatLabel>
|
||||
<StatNumber>{queueStatus?.queue.failed ?? 0}</StatNumber>
|
||||
</Stat>
|
||||
<Stat w={24}>
|
||||
<StatLabel>{t('queue.canceled')}</StatLabel>
|
||||
<StatNumber>{queueStatus?.queue.canceled ?? 0}</StatNumber>
|
||||
</Stat>
|
||||
<Stat w={24}>
|
||||
<StatLabel>{t('queue.total')}</StatLabel>
|
||||
<StatNumber>{queueStatus?.queue.total}</StatNumber>
|
||||
</Stat>
|
||||
</StatGroup>
|
||||
<StatusStatGroup>
|
||||
<StatusStatItem
|
||||
label={t('queue.in_progress')}
|
||||
value={queueStatus?.queue.in_progress ?? 0}
|
||||
/>
|
||||
<StatusStatItem
|
||||
label={t('queue.pending')}
|
||||
value={queueStatus?.queue.pending ?? 0}
|
||||
/>
|
||||
<StatusStatItem
|
||||
label={t('queue.completed')}
|
||||
value={queueStatus?.queue.completed ?? 0}
|
||||
/>
|
||||
<StatusStatItem
|
||||
label={t('queue.failed')}
|
||||
value={queueStatus?.queue.failed ?? 0}
|
||||
/>
|
||||
<StatusStatItem
|
||||
label={t('queue.canceled')}
|
||||
value={queueStatus?.queue.canceled ?? 0}
|
||||
/>
|
||||
<StatusStatItem
|
||||
label={t('queue.total')}
|
||||
value={queueStatus?.queue.total ?? 0}
|
||||
/>
|
||||
</StatusStatGroup>
|
||||
);
|
||||
};
|
||||
|
||||
|
@ -1,16 +1,14 @@
|
||||
import { Box, ButtonGroup, Flex } from '@chakra-ui/react';
|
||||
import { Box, Flex } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import ClearQueueButton from './ClearQueueButton';
|
||||
import PauseProcessorButton from './PauseProcessorButton';
|
||||
import PruneQueueButton from './PruneQueueButton';
|
||||
import { useFeatureStatus } from '../../system/hooks/useFeatureStatus';
|
||||
import InvocationCacheStatus from './InvocationCacheStatus';
|
||||
import QueueList from './QueueList/QueueList';
|
||||
import QueueStatus from './QueueStatus';
|
||||
import ResumeProcessorButton from './ResumeProcessorButton';
|
||||
import { useFeatureStatus } from '../../system/hooks/useFeatureStatus';
|
||||
import QueueTabQueueControls from './QueueTabQueueControls';
|
||||
|
||||
const QueueTabContent = () => {
|
||||
const isPauseEnabled = useFeatureStatus('pauseQueue').isFeatureEnabled;
|
||||
const isResumeEnabled = useFeatureStatus('resumeQueue').isFeatureEnabled;
|
||||
const isInvocationCacheEnabled =
|
||||
useFeatureStatus('invocationCache').isFeatureEnabled;
|
||||
|
||||
return (
|
||||
<Flex
|
||||
@ -23,33 +21,9 @@ const QueueTabContent = () => {
|
||||
gap={2}
|
||||
>
|
||||
<Flex gap={2} w="full">
|
||||
<Flex layerStyle="second" borderRadius="base" p={2} gap={2}>
|
||||
{isPauseEnabled || isResumeEnabled ? (
|
||||
<ButtonGroup w={28} orientation="vertical" isAttached size="sm">
|
||||
{isResumeEnabled ? <ResumeProcessorButton /> : <></>}
|
||||
{isPauseEnabled ? <PauseProcessorButton /> : <></>}
|
||||
</ButtonGroup>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
<ButtonGroup w={28} orientation="vertical" isAttached size="sm">
|
||||
<PruneQueueButton />
|
||||
<ClearQueueButton />
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
<Flex
|
||||
layerStyle="second"
|
||||
borderRadius="base"
|
||||
flexDir="column"
|
||||
py={2}
|
||||
px={3}
|
||||
gap={2}
|
||||
>
|
||||
<QueueStatus />
|
||||
</Flex>
|
||||
{/* <QueueStatusCard />
|
||||
<CurrentQueueItemCard />
|
||||
<NextQueueItemCard /> */}
|
||||
<QueueTabQueueControls />
|
||||
<QueueStatus />
|
||||
{isInvocationCacheEnabled && <InvocationCacheStatus />}
|
||||
</Flex>
|
||||
<Box layerStyle="second" p={2} borderRadius="base" w="full" h="full">
|
||||
<QueueList />
|
||||
|
@ -0,0 +1,30 @@
|
||||
import { ButtonGroup, Flex } from '@chakra-ui/react';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { memo } from 'react';
|
||||
import ClearQueueButton from './ClearQueueButton';
|
||||
import PauseProcessorButton from './PauseProcessorButton';
|
||||
import PruneQueueButton from './PruneQueueButton';
|
||||
import ResumeProcessorButton from './ResumeProcessorButton';
|
||||
|
||||
const QueueTabQueueControls = () => {
|
||||
const isPauseEnabled = useFeatureStatus('pauseQueue').isFeatureEnabled;
|
||||
const isResumeEnabled = useFeatureStatus('resumeQueue').isFeatureEnabled;
|
||||
return (
|
||||
<Flex layerStyle="second" borderRadius="base" p={2} gap={2}>
|
||||
{isPauseEnabled || isResumeEnabled ? (
|
||||
<ButtonGroup w={28} orientation="vertical" isAttached size="sm">
|
||||
{isResumeEnabled ? <ResumeProcessorButton /> : <></>}
|
||||
{isPauseEnabled ? <PauseProcessorButton /> : <></>}
|
||||
</ButtonGroup>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
<ButtonGroup w={28} orientation="vertical" isAttached size="sm">
|
||||
<PruneQueueButton />
|
||||
<ClearQueueButton />
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(QueueTabQueueControls);
|
@ -10,14 +10,14 @@ type Props = {
|
||||
|
||||
const ResumeProcessorButton = ({ asIconButton }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const { resumeProcessor, isLoading, isStarted } = useResumeProcessor();
|
||||
const { resumeProcessor, isLoading, isDisabled } = useResumeProcessor();
|
||||
|
||||
return (
|
||||
<QueueButton
|
||||
asIconButton={asIconButton}
|
||||
label={t('queue.resume')}
|
||||
tooltip={t('queue.resumeTooltip')}
|
||||
isDisabled={isStarted}
|
||||
isDisabled={isDisabled}
|
||||
isLoading={isLoading}
|
||||
icon={<FaPlay />}
|
||||
onClick={resumeProcessor}
|
||||
|
@ -0,0 +1,47 @@
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo';
|
||||
import { useDisableInvocationCache } from '../hooks/useDisableInvocationCache';
|
||||
import { useEnableInvocationCache } from '../hooks/useEnableInvocationCache';
|
||||
|
||||
const ToggleInvocationCacheButton = () => {
|
||||
const { t } = useTranslation();
|
||||
const { data: cacheStatus } = useGetInvocationCacheStatusQuery();
|
||||
|
||||
const {
|
||||
enableInvocationCache,
|
||||
isDisabled: isEnableDisabled,
|
||||
isLoading: isEnableLoading,
|
||||
} = useEnableInvocationCache();
|
||||
|
||||
const {
|
||||
disableInvocationCache,
|
||||
isDisabled: isDisableDisabled,
|
||||
isLoading: isDisableLoading,
|
||||
} = useDisableInvocationCache();
|
||||
|
||||
if (cacheStatus?.enabled) {
|
||||
return (
|
||||
<IAIButton
|
||||
isDisabled={isDisableDisabled}
|
||||
isLoading={isDisableLoading}
|
||||
onClick={disableInvocationCache}
|
||||
>
|
||||
{t('invocationCache.disable')}
|
||||
</IAIButton>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<IAIButton
|
||||
isDisabled={isEnableDisabled}
|
||||
isLoading={isEnableLoading}
|
||||
onClick={enableInvocationCache}
|
||||
>
|
||||
{t('invocationCache.enable')}
|
||||
</IAIButton>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ToggleInvocationCacheButton);
|
@ -1,27 +0,0 @@
|
||||
import { ButtonGroup, ButtonGroupProps, Flex } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import ClearQueueButton from './ClearQueueButton';
|
||||
import PauseProcessorButton from './PauseProcessorButton';
|
||||
import PruneQueueButton from './PruneQueueButton';
|
||||
import ResumeProcessorButton from './ResumeProcessorButton';
|
||||
|
||||
type Props = ButtonGroupProps & {
|
||||
asIconButtons?: boolean;
|
||||
};
|
||||
|
||||
const VerticalQueueControls = ({ asIconButtons, ...rest }: Props) => {
|
||||
return (
|
||||
<Flex flexDir="column" gap={2}>
|
||||
<ButtonGroup w="full" isAttached {...rest}>
|
||||
<ResumeProcessorButton asIconButton={asIconButtons} />
|
||||
<PauseProcessorButton asIconButton={asIconButtons} />
|
||||
</ButtonGroup>
|
||||
<ButtonGroup w="full" isAttached {...rest}>
|
||||
<PruneQueueButton asIconButton={asIconButtons} />
|
||||
<ClearQueueButton asIconButton={asIconButtons} />
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(VerticalQueueControls);
|
@ -0,0 +1,22 @@
|
||||
import { StatGroup, StatGroupProps } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
|
||||
const StatusStatGroup = ({ children, ...rest }: StatGroupProps) => (
|
||||
<StatGroup
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
w="full"
|
||||
h="full"
|
||||
layerStyle="second"
|
||||
borderRadius="base"
|
||||
py={2}
|
||||
px={3}
|
||||
gap={6}
|
||||
flexWrap="nowrap"
|
||||
{...rest}
|
||||
>
|
||||
{children}
|
||||
</StatGroup>
|
||||
);
|
||||
|
||||
export default memo(StatusStatGroup);
|
@ -0,0 +1,47 @@
|
||||
import {
|
||||
ChakraProps,
|
||||
Stat,
|
||||
StatLabel,
|
||||
StatNumber,
|
||||
StatProps,
|
||||
} from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
|
||||
const sx: ChakraProps['sx'] = {
|
||||
'&[aria-disabled="true"]': {
|
||||
color: 'base.400',
|
||||
_dark: {
|
||||
color: 'base.500',
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
type Props = Omit<StatProps, 'children'> & {
|
||||
label: string;
|
||||
value: string | number;
|
||||
isDisabled?: boolean;
|
||||
};
|
||||
|
||||
const StatusStatItem = ({
|
||||
label,
|
||||
value,
|
||||
isDisabled = false,
|
||||
...rest
|
||||
}: Props) => (
|
||||
<Stat
|
||||
flexGrow={1}
|
||||
textOverflow="ellipsis"
|
||||
overflow="hidden"
|
||||
whiteSpace="nowrap"
|
||||
aria-disabled={isDisabled}
|
||||
sx={sx}
|
||||
{...rest}
|
||||
>
|
||||
<StatLabel textOverflow="ellipsis" overflow="hidden" whiteSpace="nowrap">
|
||||
{label}
|
||||
</StatLabel>
|
||||
<StatNumber>{value}</StatNumber>
|
||||
</Stat>
|
||||
);
|
||||
|
||||
export default memo(StatusStatItem);
|
@ -1,4 +1,4 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -8,6 +8,7 @@ import {
|
||||
} from 'services/api/endpoints/queue';
|
||||
|
||||
export const useCancelBatch = (batch_id: string) => {
|
||||
const isConnected = useAppSelector((state) => state.system.isConnected);
|
||||
const { isCanceled } = useGetBatchStatusQuery(
|
||||
{ batch_id },
|
||||
{
|
||||
@ -49,5 +50,5 @@ export const useCancelBatch = (batch_id: string) => {
|
||||
}
|
||||
}, [batch_id, dispatch, isCanceled, t, trigger]);
|
||||
|
||||
return { cancelBatch, isLoading, isCanceled };
|
||||
return { cancelBatch, isLoading, isCanceled, isDisabled: !isConnected };
|
||||
};
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -8,6 +8,7 @@ import {
|
||||
} from 'services/api/endpoints/queue';
|
||||
|
||||
export const useCancelCurrentQueueItem = () => {
|
||||
const isConnected = useAppSelector((state) => state.system.isConnected);
|
||||
const { data: queueStatus } = useGetQueueStatusQuery();
|
||||
const [trigger, { isLoading }] = useCancelQueueItemMutation();
|
||||
const dispatch = useAppDispatch();
|
||||
@ -38,5 +39,15 @@ export const useCancelCurrentQueueItem = () => {
|
||||
}
|
||||
}, [currentQueueItemId, dispatch, t, trigger]);
|
||||
|
||||
return { cancelQueueItem, isLoading, currentQueueItemId };
|
||||
const isDisabled = useMemo(
|
||||
() => !isConnected || !currentQueueItemId,
|
||||
[isConnected, currentQueueItemId]
|
||||
);
|
||||
|
||||
return {
|
||||
cancelQueueItem,
|
||||
isLoading,
|
||||
currentQueueItemId,
|
||||
isDisabled,
|
||||
};
|
||||
};
|
||||
|
@ -1,10 +1,11 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useCancelQueueItemMutation } from 'services/api/endpoints/queue';
|
||||
|
||||
export const useCancelQueueItem = (item_id: number) => {
|
||||
const isConnected = useAppSelector((state) => state.system.isConnected);
|
||||
const [trigger, { isLoading }] = useCancelQueueItemMutation();
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
@ -27,5 +28,5 @@ export const useCancelQueueItem = (item_id: number) => {
|
||||
}
|
||||
}, [dispatch, item_id, t, trigger]);
|
||||
|
||||
return { cancelQueueItem, isLoading };
|
||||
return { cancelQueueItem, isLoading, isDisabled: !isConnected };
|
||||
};
|
||||
|
@ -0,0 +1,48 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
useClearInvocationCacheMutation,
|
||||
useGetInvocationCacheStatusQuery,
|
||||
} from 'services/api/endpoints/appInfo';
|
||||
|
||||
export const useClearInvocationCache = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const { data: cacheStatus } = useGetInvocationCacheStatusQuery();
|
||||
const isConnected = useAppSelector((state) => state.system.isConnected);
|
||||
const [trigger, { isLoading }] = useClearInvocationCacheMutation({
|
||||
fixedCacheKey: 'clearInvocationCache',
|
||||
});
|
||||
|
||||
const isDisabled = useMemo(
|
||||
() => !cacheStatus?.size || !isConnected,
|
||||
[cacheStatus?.size, isConnected]
|
||||
);
|
||||
|
||||
const clearInvocationCache = useCallback(async () => {
|
||||
if (isDisabled) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
await trigger().unwrap();
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('invocationCache.clearSucceeded'),
|
||||
status: 'success',
|
||||
})
|
||||
);
|
||||
} catch {
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('invocationCache.clearFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
}
|
||||
}, [isDisabled, trigger, dispatch, t]);
|
||||
|
||||
return { clearInvocationCache, isLoading, cacheStatus, isDisabled };
|
||||
};
|
@ -1,6 +1,6 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { useCallback } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
useClearQueueMutation,
|
||||
@ -10,10 +10,9 @@ import { listCursorChanged, listPriorityChanged } from '../store/queueSlice';
|
||||
|
||||
export const useClearQueue = () => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { data: queueStatus } = useGetQueueStatusQuery();
|
||||
|
||||
const isConnected = useAppSelector((state) => state.system.isConnected);
|
||||
const [trigger, { isLoading }] = useClearQueueMutation({
|
||||
fixedCacheKey: 'clearQueue',
|
||||
});
|
||||
@ -43,5 +42,10 @@ export const useClearQueue = () => {
|
||||
}
|
||||
}, [queueStatus?.queue.total, trigger, dispatch, t]);
|
||||
|
||||
return { clearQueue, isLoading, queueStatus };
|
||||
const isDisabled = useMemo(
|
||||
() => !isConnected || !queueStatus?.queue.total,
|
||||
[isConnected, queueStatus?.queue.total]
|
||||
);
|
||||
|
||||
return { clearQueue, isLoading, queueStatus, isDisabled };
|
||||
};
|
||||
|
@ -0,0 +1,48 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
useDisableInvocationCacheMutation,
|
||||
useGetInvocationCacheStatusQuery,
|
||||
} from 'services/api/endpoints/appInfo';
|
||||
|
||||
export const useDisableInvocationCache = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const { data: cacheStatus } = useGetInvocationCacheStatusQuery();
|
||||
const isConnected = useAppSelector((state) => state.system.isConnected);
|
||||
const [trigger, { isLoading }] = useDisableInvocationCacheMutation({
|
||||
fixedCacheKey: 'disableInvocationCache',
|
||||
});
|
||||
|
||||
const isDisabled = useMemo(
|
||||
() => !cacheStatus?.enabled || !isConnected || cacheStatus?.max_size === 0,
|
||||
[cacheStatus?.enabled, cacheStatus?.max_size, isConnected]
|
||||
);
|
||||
|
||||
const disableInvocationCache = useCallback(async () => {
|
||||
if (isDisabled) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
await trigger().unwrap();
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('invocationCache.disableSucceeded'),
|
||||
status: 'success',
|
||||
})
|
||||
);
|
||||
} catch {
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('invocationCache.disableFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
}
|
||||
}, [isDisabled, trigger, dispatch, t]);
|
||||
|
||||
return { disableInvocationCache, isLoading, cacheStatus, isDisabled };
|
||||
};
|
@ -0,0 +1,48 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
useEnableInvocationCacheMutation,
|
||||
useGetInvocationCacheStatusQuery,
|
||||
} from 'services/api/endpoints/appInfo';
|
||||
|
||||
export const useEnableInvocationCache = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const { data: cacheStatus } = useGetInvocationCacheStatusQuery();
|
||||
const isConnected = useAppSelector((state) => state.system.isConnected);
|
||||
const [trigger, { isLoading }] = useEnableInvocationCacheMutation({
|
||||
fixedCacheKey: 'enableInvocationCache',
|
||||
});
|
||||
|
||||
const isDisabled = useMemo(
|
||||
() => cacheStatus?.enabled || !isConnected || cacheStatus?.max_size === 0,
|
||||
[cacheStatus?.enabled, cacheStatus?.max_size, isConnected]
|
||||
);
|
||||
|
||||
const enableInvocationCache = useCallback(async () => {
|
||||
if (isDisabled) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
await trigger().unwrap();
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('invocationCache.enableSucceeded'),
|
||||
status: 'success',
|
||||
})
|
||||
);
|
||||
} catch {
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('invocationCache.enableFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
}
|
||||
}, [isDisabled, trigger, dispatch, t]);
|
||||
|
||||
return { enableInvocationCache, isLoading, cacheStatus, isDisabled };
|
||||
};
|
@ -1,4 +1,4 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -10,6 +10,7 @@ import {
|
||||
export const usePauseProcessor = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const isConnected = useAppSelector((state) => state.system.isConnected);
|
||||
const { data: queueStatus } = useGetQueueStatusQuery();
|
||||
const [trigger, { isLoading }] = usePauseProcessorMutation({
|
||||
fixedCacheKey: 'pauseProcessor',
|
||||
@ -42,5 +43,10 @@ export const usePauseProcessor = () => {
|
||||
}
|
||||
}, [isStarted, trigger, dispatch, t]);
|
||||
|
||||
return { pauseProcessor, isLoading, isStarted };
|
||||
const isDisabled = useMemo(
|
||||
() => !isConnected || !isStarted,
|
||||
[isConnected, isStarted]
|
||||
);
|
||||
|
||||
return { pauseProcessor, isLoading, isStarted, isDisabled };
|
||||
};
|
||||
|
@ -1,6 +1,6 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { useCallback } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
useGetQueueStatusQuery,
|
||||
@ -11,6 +11,7 @@ import { listCursorChanged, listPriorityChanged } from '../store/queueSlice';
|
||||
export const usePruneQueue = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const isConnected = useAppSelector((state) => state.system.isConnected);
|
||||
const [trigger, { isLoading }] = usePruneQueueMutation({
|
||||
fixedCacheKey: 'pruneQueue',
|
||||
});
|
||||
@ -51,5 +52,10 @@ export const usePruneQueue = () => {
|
||||
}
|
||||
}, [finishedCount, trigger, dispatch, t]);
|
||||
|
||||
return { pruneQueue, isLoading, finishedCount };
|
||||
const isDisabled = useMemo(
|
||||
() => !isConnected || !finishedCount,
|
||||
[finishedCount, isConnected]
|
||||
);
|
||||
|
||||
return { pruneQueue, isLoading, finishedCount, isDisabled };
|
||||
};
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user