mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into lstein/feat/config-migration
This commit is contained in:
commit
048306b417
@ -51,6 +51,7 @@ from invokeai.app.util.controlnet_utils import prepare_control_image
|
|||||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||||
from invokeai.backend.lora import LoRAModelRaw
|
from invokeai.backend.lora import LoRAModelRaw
|
||||||
from invokeai.backend.model_manager import BaseModelType, LoadedModel
|
from invokeai.backend.model_manager import BaseModelType, LoadedModel
|
||||||
|
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
|
||||||
from invokeai.backend.model_patcher import ModelPatcher
|
from invokeai.backend.model_patcher import ModelPatcher
|
||||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
@ -185,7 +186,7 @@ class GradientMaskOutput(BaseInvocationOutput):
|
|||||||
title="Create Gradient Mask",
|
title="Create Gradient Mask",
|
||||||
tags=["mask", "denoise"],
|
tags=["mask", "denoise"],
|
||||||
category="latents",
|
category="latents",
|
||||||
version="1.0.0",
|
version="1.1.0",
|
||||||
)
|
)
|
||||||
class CreateGradientMaskInvocation(BaseInvocation):
|
class CreateGradientMaskInvocation(BaseInvocation):
|
||||||
"""Creates mask for denoising model run."""
|
"""Creates mask for denoising model run."""
|
||||||
@ -198,6 +199,32 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
|||||||
minimum_denoise: float = InputField(
|
minimum_denoise: float = InputField(
|
||||||
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
|
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
|
||||||
)
|
)
|
||||||
|
image: Optional[ImageField] = InputField(
|
||||||
|
default=None,
|
||||||
|
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
|
||||||
|
title="[OPTIONAL] Image",
|
||||||
|
ui_order=6,
|
||||||
|
)
|
||||||
|
unet: Optional[UNetField] = InputField(
|
||||||
|
description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE",
|
||||||
|
default=None,
|
||||||
|
input=Input.Connection,
|
||||||
|
title="[OPTIONAL] UNet",
|
||||||
|
ui_order=5,
|
||||||
|
)
|
||||||
|
vae: Optional[VAEField] = InputField(
|
||||||
|
default=None,
|
||||||
|
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
|
||||||
|
title="[OPTIONAL] VAE",
|
||||||
|
input=Input.Connection,
|
||||||
|
ui_order=7,
|
||||||
|
)
|
||||||
|
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
|
||||||
|
fp32: bool = InputField(
|
||||||
|
default=DEFAULT_PRECISION == "float32",
|
||||||
|
description=FieldDescriptions.fp32,
|
||||||
|
ui_order=9,
|
||||||
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
|
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
|
||||||
@ -233,8 +260,27 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
|||||||
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
|
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
|
||||||
expanded_image_dto = context.images.save(expanded_mask_image)
|
expanded_image_dto = context.images.save(expanded_mask_image)
|
||||||
|
|
||||||
|
masked_latents_name = None
|
||||||
|
if self.unet is not None and self.vae is not None and self.image is not None:
|
||||||
|
# all three fields must be present at the same time
|
||||||
|
main_model_config = context.models.get_config(self.unet.unet.key)
|
||||||
|
assert isinstance(main_model_config, MainConfigBase)
|
||||||
|
if main_model_config.variant is ModelVariantType.Inpaint:
|
||||||
|
mask = blur_tensor
|
||||||
|
vae_info: LoadedModel = context.models.load(self.vae.vae)
|
||||||
|
image = context.images.get_pil(self.image.image_name)
|
||||||
|
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||||
|
if image_tensor.dim() == 3:
|
||||||
|
image_tensor = image_tensor.unsqueeze(0)
|
||||||
|
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||||
|
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
|
||||||
|
masked_latents = ImageToLatentsInvocation.vae_encode(
|
||||||
|
vae_info, self.fp32, self.tiled, masked_image.clone()
|
||||||
|
)
|
||||||
|
masked_latents_name = context.tensors.save(tensor=masked_latents)
|
||||||
|
|
||||||
return GradientMaskOutput(
|
return GradientMaskOutput(
|
||||||
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True),
|
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
|
||||||
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
|
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext, invocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
|
||||||
from invokeai.app.invocations.fields import InputField, TensorField, WithMetadata
|
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithMetadata
|
||||||
from invokeai.app.invocations.primitives import MaskOutput
|
from invokeai.app.invocations.primitives import MaskOutput
|
||||||
|
|
||||||
|
|
||||||
@ -34,3 +35,86 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata):
|
|||||||
width=self.width,
|
width=self.width,
|
||||||
height=self.height,
|
height=self.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation(
|
||||||
|
"alpha_mask_to_tensor",
|
||||||
|
title="Alpha Mask to Tensor",
|
||||||
|
tags=["conditioning"],
|
||||||
|
category="conditioning",
|
||||||
|
version="1.0.0",
|
||||||
|
classification=Classification.Beta,
|
||||||
|
)
|
||||||
|
class AlphaMaskToTensorInvocation(BaseInvocation):
|
||||||
|
"""Convert a mask image to a tensor. Opaque regions are 1 and transparent regions are 0."""
|
||||||
|
|
||||||
|
image: ImageField = InputField(description="The mask image to convert.")
|
||||||
|
invert: bool = InputField(default=False, description="Whether to invert the mask.")
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||||
|
image = context.images.get_pil(self.image.image_name)
|
||||||
|
mask = torch.zeros((1, image.height, image.width), dtype=torch.bool)
|
||||||
|
if self.invert:
|
||||||
|
mask[0] = torch.tensor(np.array(image)[:, :, 3] == 0, dtype=torch.bool)
|
||||||
|
else:
|
||||||
|
mask[0] = torch.tensor(np.array(image)[:, :, 3] > 0, dtype=torch.bool)
|
||||||
|
|
||||||
|
return MaskOutput(
|
||||||
|
mask=TensorField(tensor_name=context.tensors.save(mask)),
|
||||||
|
height=mask.shape[1],
|
||||||
|
width=mask.shape[2],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation(
|
||||||
|
"invert_tensor_mask",
|
||||||
|
title="Invert Tensor Mask",
|
||||||
|
tags=["conditioning"],
|
||||||
|
category="conditioning",
|
||||||
|
version="1.0.0",
|
||||||
|
classification=Classification.Beta,
|
||||||
|
)
|
||||||
|
class InvertTensorMaskInvocation(BaseInvocation):
|
||||||
|
"""Inverts a tensor mask."""
|
||||||
|
|
||||||
|
mask: TensorField = InputField(description="The tensor mask to convert.")
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||||
|
mask = context.tensors.load(self.mask.tensor_name)
|
||||||
|
inverted = ~mask
|
||||||
|
|
||||||
|
return MaskOutput(
|
||||||
|
mask=TensorField(tensor_name=context.tensors.save(inverted)),
|
||||||
|
height=inverted.shape[1],
|
||||||
|
width=inverted.shape[2],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation(
|
||||||
|
"image_mask_to_tensor",
|
||||||
|
title="Image Mask to Tensor",
|
||||||
|
tags=["conditioning"],
|
||||||
|
category="conditioning",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
|
class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
|
||||||
|
"""Convert a mask image to a tensor. Converts the image to grayscale and uses thresholding at the specified value."""
|
||||||
|
|
||||||
|
image: ImageField = InputField(description="The mask image to convert.")
|
||||||
|
cutoff: int = InputField(ge=0, le=255, description="Cutoff (<)", default=128)
|
||||||
|
invert: bool = InputField(default=False, description="Whether to invert the mask.")
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||||
|
image = context.images.get_pil(self.image.image_name, mode="L")
|
||||||
|
|
||||||
|
mask = torch.zeros((1, image.height, image.width), dtype=torch.bool)
|
||||||
|
if self.invert:
|
||||||
|
mask[0] = torch.tensor(np.array(image)[:, :] >= self.cutoff, dtype=torch.bool)
|
||||||
|
else:
|
||||||
|
mask[0] = torch.tensor(np.array(image)[:, :] < self.cutoff, dtype=torch.bool)
|
||||||
|
|
||||||
|
return MaskOutput(
|
||||||
|
mask=TensorField(tensor_name=context.tensors.save(mask)),
|
||||||
|
height=mask.shape[1],
|
||||||
|
width=mask.shape[2],
|
||||||
|
)
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
import locale
|
import locale
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import signal
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
@ -43,6 +42,7 @@ from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMet
|
|||||||
from invokeai.backend.model_manager.probe import ModelProbe
|
from invokeai.backend.model_manager.probe import ModelProbe
|
||||||
from invokeai.backend.model_manager.search import ModelSearch
|
from invokeai.backend.model_manager.search import ModelSearch
|
||||||
from invokeai.backend.util import InvokeAILogger
|
from invokeai.backend.util import InvokeAILogger
|
||||||
|
from invokeai.backend.util.catch_sigint import catch_sigint
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
from .model_install_base import (
|
from .model_install_base import (
|
||||||
@ -112,17 +112,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
def start(self, invoker: Optional[Invoker] = None) -> None:
|
def start(self, invoker: Optional[Invoker] = None) -> None:
|
||||||
"""Start the installer thread."""
|
"""Start the installer thread."""
|
||||||
|
|
||||||
# Yes, this is weird. When the installer thread is running, the
|
|
||||||
# thread masks the ^C signal. When we receive a
|
|
||||||
# sigINT, we stop the thread, reset sigINT, and send a new
|
|
||||||
# sigINT to the parent process.
|
|
||||||
def sigint_handler(signum, frame):
|
|
||||||
self.stop()
|
|
||||||
signal.signal(signal.SIGINT, signal.SIG_DFL)
|
|
||||||
signal.raise_signal(signal.SIGINT)
|
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, sigint_handler)
|
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self._running:
|
if self._running:
|
||||||
raise Exception("Attempt to start the installer service twice")
|
raise Exception("Attempt to start the installer service twice")
|
||||||
@ -132,7 +121,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
# In normal use, we do not want to scan the models directory - it should never have orphaned models.
|
# In normal use, we do not want to scan the models directory - it should never have orphaned models.
|
||||||
# We should only do the scan when the flag is set (which should only be set when testing).
|
# We should only do the scan when the flag is set (which should only be set when testing).
|
||||||
if self.app_config.scan_models_on_startup:
|
if self.app_config.scan_models_on_startup:
|
||||||
self._register_orphaned_models()
|
with catch_sigint():
|
||||||
|
self._register_orphaned_models()
|
||||||
|
|
||||||
# Check all models' paths and confirm they exist. A model could be missing if it was installed on a volume
|
# Check all models' paths and confirm they exist. A model could be missing if it was installed on a volume
|
||||||
# that isn't currently mounted. In this case, we don't want to delete the model from the database, but we do
|
# that isn't currently mounted. In this case, we don't want to delete the model from the database, but we do
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
import typing
|
import typing
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Optional, TypeVar
|
from typing import TYPE_CHECKING, Optional, TypeVar
|
||||||
|
|
||||||
@ -17,12 +17,6 @@ if TYPE_CHECKING:
|
|||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class DeleteAllResult:
|
|
||||||
deleted_count: int
|
|
||||||
freed_space_bytes: float
|
|
||||||
|
|
||||||
|
|
||||||
class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
||||||
"""Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`.
|
"""Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`.
|
||||||
|
|
||||||
@ -35,6 +29,12 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
|||||||
self._ephemeral = ephemeral
|
self._ephemeral = ephemeral
|
||||||
self._base_output_dir = output_dir
|
self._base_output_dir = output_dir
|
||||||
self._base_output_dir.mkdir(parents=True, exist_ok=True)
|
self._base_output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
if self._ephemeral:
|
||||||
|
# Remove dangling tempdirs that might have been left over from an earlier unplanned shutdown.
|
||||||
|
for temp_dir in filter(Path.is_dir, self._base_output_dir.glob("tmp*")):
|
||||||
|
shutil.rmtree(temp_dir)
|
||||||
|
|
||||||
# Must specify `ignore_cleanup_errors` to avoid fatal errors during cleanup on Windows
|
# Must specify `ignore_cleanup_errors` to avoid fatal errors during cleanup on Windows
|
||||||
self._tempdir = (
|
self._tempdir = (
|
||||||
tempfile.TemporaryDirectory(dir=self._base_output_dir, ignore_cleanup_errors=True) if ephemeral else None
|
tempfile.TemporaryDirectory(dir=self._base_output_dir, ignore_cleanup_errors=True) if ephemeral else None
|
||||||
|
@ -301,12 +301,12 @@ class MainConfigBase(ModelConfigBase):
|
|||||||
default_settings: Optional[MainModelDefaultSettings] = Field(
|
default_settings: Optional[MainModelDefaultSettings] = Field(
|
||||||
description="Default settings for this model", default=None
|
description="Default settings for this model", default=None
|
||||||
)
|
)
|
||||||
|
variant: ModelVariantType = ModelVariantType.Normal
|
||||||
|
|
||||||
|
|
||||||
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
|
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
|
||||||
"""Model config for main checkpoint models."""
|
"""Model config for main checkpoint models."""
|
||||||
|
|
||||||
variant: ModelVariantType = ModelVariantType.Normal
|
|
||||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||||
upcast_attention: bool = False
|
upcast_attention: bool = False
|
||||||
|
|
||||||
|
@ -155,7 +155,7 @@ STARTER_MODELS: list[StarterModel] = [
|
|||||||
StarterModel(
|
StarterModel(
|
||||||
name="IP Adapter",
|
name="IP Adapter",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="InvokeAI/ip_adapter_sd15",
|
source="https://huggingface.co/InvokeAI/ip_adapter_sd15/resolve/main/ip-adapter_sd15.safetensors",
|
||||||
description="IP-Adapter for SD 1.5 models",
|
description="IP-Adapter for SD 1.5 models",
|
||||||
type=ModelType.IPAdapter,
|
type=ModelType.IPAdapter,
|
||||||
dependencies=[ip_adapter_sd_image_encoder],
|
dependencies=[ip_adapter_sd_image_encoder],
|
||||||
@ -163,7 +163,7 @@ STARTER_MODELS: list[StarterModel] = [
|
|||||||
StarterModel(
|
StarterModel(
|
||||||
name="IP Adapter Plus",
|
name="IP Adapter Plus",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="InvokeAI/ip_adapter_plus_sd15",
|
source="https://huggingface.co/InvokeAI/ip_adapter_plus_sd15/resolve/main/ip-adapter-plus_sd15.safetensors",
|
||||||
description="Refined IP-Adapter for SD 1.5 models",
|
description="Refined IP-Adapter for SD 1.5 models",
|
||||||
type=ModelType.IPAdapter,
|
type=ModelType.IPAdapter,
|
||||||
dependencies=[ip_adapter_sd_image_encoder],
|
dependencies=[ip_adapter_sd_image_encoder],
|
||||||
@ -171,7 +171,7 @@ STARTER_MODELS: list[StarterModel] = [
|
|||||||
StarterModel(
|
StarterModel(
|
||||||
name="IP Adapter Plus Face",
|
name="IP Adapter Plus Face",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="InvokeAI/ip_adapter_plus_face_sd15",
|
source="https://huggingface.co/InvokeAI/ip_adapter_plus_face_sd15/resolve/main/ip-adapter-plus-face_sd15.safetensors",
|
||||||
description="Refined IP-Adapter for SD 1.5 models, adapted for faces",
|
description="Refined IP-Adapter for SD 1.5 models, adapted for faces",
|
||||||
type=ModelType.IPAdapter,
|
type=ModelType.IPAdapter,
|
||||||
dependencies=[ip_adapter_sd_image_encoder],
|
dependencies=[ip_adapter_sd_image_encoder],
|
||||||
@ -179,7 +179,7 @@ STARTER_MODELS: list[StarterModel] = [
|
|||||||
StarterModel(
|
StarterModel(
|
||||||
name="IP Adapter SDXL",
|
name="IP Adapter SDXL",
|
||||||
base=BaseModelType.StableDiffusionXL,
|
base=BaseModelType.StableDiffusionXL,
|
||||||
source="InvokeAI/ip_adapter_sdxl",
|
source="https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h/resolve/main/ip-adapter_sdxl_vit-h.safetensors",
|
||||||
description="IP-Adapter for SDXL models",
|
description="IP-Adapter for SDXL models",
|
||||||
type=ModelType.IPAdapter,
|
type=ModelType.IPAdapter,
|
||||||
dependencies=[ip_adapter_sdxl_image_encoder],
|
dependencies=[ip_adapter_sdxl_image_encoder],
|
||||||
|
29
invokeai/backend/util/catch_sigint.py
Normal file
29
invokeai/backend/util/catch_sigint.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
"""
|
||||||
|
This module defines a context manager `catch_sigint()` which temporarily replaces
|
||||||
|
the sigINT handler defined by the ASGI in order to allow the user to ^C the application
|
||||||
|
and shut it down immediately. This was implemented in order to allow the user to interrupt
|
||||||
|
slow model hashing during startup.
|
||||||
|
|
||||||
|
Use like this:
|
||||||
|
|
||||||
|
from invokeai.backend.util.catch_sigint import catch_sigint
|
||||||
|
with catch_sigint():
|
||||||
|
run_some_hard_to_interrupt_process()
|
||||||
|
"""
|
||||||
|
|
||||||
|
import signal
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
|
||||||
|
def sigint_handler(signum, frame): # type: ignore
|
||||||
|
signal.signal(signal.SIGINT, signal.SIG_DFL)
|
||||||
|
signal.raise_signal(signal.SIGINT)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def catch_sigint() -> Generator[None, None, None]:
|
||||||
|
original_handler = signal.getsignal(signal.SIGINT)
|
||||||
|
signal.signal(signal.SIGINT, sigint_handler)
|
||||||
|
yield
|
||||||
|
signal.signal(signal.SIGINT, original_handler)
|
@ -11,6 +11,7 @@ import { createStore } from '../src/app/store/store';
|
|||||||
// @ts-ignore
|
// @ts-ignore
|
||||||
import translationEN from '../public/locales/en.json';
|
import translationEN from '../public/locales/en.json';
|
||||||
import { ReduxInit } from './ReduxInit';
|
import { ReduxInit } from './ReduxInit';
|
||||||
|
import { $store } from 'app/store/nanostores/store';
|
||||||
|
|
||||||
i18n.use(initReactI18next).init({
|
i18n.use(initReactI18next).init({
|
||||||
lng: 'en',
|
lng: 'en',
|
||||||
@ -25,6 +26,7 @@ i18n.use(initReactI18next).init({
|
|||||||
});
|
});
|
||||||
|
|
||||||
const store = createStore(undefined, false);
|
const store = createStore(undefined, false);
|
||||||
|
$store.set(store);
|
||||||
$baseUrl.set('http://localhost:9090');
|
$baseUrl.set('http://localhost:9090');
|
||||||
|
|
||||||
const preview: Preview = {
|
const preview: Preview = {
|
||||||
|
@ -25,7 +25,7 @@
|
|||||||
"typegen": "node scripts/typegen.js",
|
"typegen": "node scripts/typegen.js",
|
||||||
"preview": "vite preview",
|
"preview": "vite preview",
|
||||||
"lint:knip": "knip",
|
"lint:knip": "knip",
|
||||||
"lint:dpdm": "dpdm --no-warning --no-tree --transform --exit-code circular:1 src/main.tsx",
|
"lint:dpdm": "dpdm --no-warning --no-tree --transform --exit-code circular:0 src/main.tsx",
|
||||||
"lint:eslint": "eslint --max-warnings=0 .",
|
"lint:eslint": "eslint --max-warnings=0 .",
|
||||||
"lint:prettier": "prettier --check .",
|
"lint:prettier": "prettier --check .",
|
||||||
"lint:tsc": "tsc --noEmit",
|
"lint:tsc": "tsc --noEmit",
|
||||||
@ -95,6 +95,7 @@
|
|||||||
"reactflow": "^11.10.4",
|
"reactflow": "^11.10.4",
|
||||||
"redux-dynamic-middlewares": "^2.2.0",
|
"redux-dynamic-middlewares": "^2.2.0",
|
||||||
"redux-remember": "^5.1.0",
|
"redux-remember": "^5.1.0",
|
||||||
|
"redux-undo": "^1.1.0",
|
||||||
"rfdc": "^1.3.1",
|
"rfdc": "^1.3.1",
|
||||||
"roarr": "^7.21.1",
|
"roarr": "^7.21.1",
|
||||||
"serialize-error": "^11.0.3",
|
"serialize-error": "^11.0.3",
|
||||||
|
7
invokeai/frontend/web/pnpm-lock.yaml
generated
7
invokeai/frontend/web/pnpm-lock.yaml
generated
@ -140,6 +140,9 @@ dependencies:
|
|||||||
redux-remember:
|
redux-remember:
|
||||||
specifier: ^5.1.0
|
specifier: ^5.1.0
|
||||||
version: 5.1.0(redux@5.0.1)
|
version: 5.1.0(redux@5.0.1)
|
||||||
|
redux-undo:
|
||||||
|
specifier: ^1.1.0
|
||||||
|
version: 1.1.0
|
||||||
rfdc:
|
rfdc:
|
||||||
specifier: ^1.3.1
|
specifier: ^1.3.1
|
||||||
version: 1.3.1
|
version: 1.3.1
|
||||||
@ -11962,6 +11965,10 @@ packages:
|
|||||||
redux: 5.0.1
|
redux: 5.0.1
|
||||||
dev: false
|
dev: false
|
||||||
|
|
||||||
|
/redux-undo@1.1.0:
|
||||||
|
resolution: {integrity: sha512-zzLFh2qeF0MTIlzDhDLm9NtkfBqCllQJ3OCuIl5RKlG/ayHw6GUdIFdMhzMS9NnrnWdBX5u//ExMOHpfudGGOg==}
|
||||||
|
dev: false
|
||||||
|
|
||||||
/redux@5.0.1:
|
/redux@5.0.1:
|
||||||
resolution: {integrity: sha512-M9/ELqF6fy8FwmkpnF0S3YKOqMyoWJ4+CS5Efg2ct3oY9daQvd/Pc71FpGZsVsbl3Cpb+IIcjBDUnnyBdQbq4w==}
|
resolution: {integrity: sha512-M9/ELqF6fy8FwmkpnF0S3YKOqMyoWJ4+CS5Efg2ct3oY9daQvd/Pc71FpGZsVsbl3Cpb+IIcjBDUnnyBdQbq4w==}
|
||||||
dev: false
|
dev: false
|
||||||
|
BIN
invokeai/frontend/web/public/assets/images/transparent_bg.png
Normal file
BIN
invokeai/frontend/web/public/assets/images/transparent_bg.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.7 KiB |
@ -84,6 +84,8 @@
|
|||||||
"direction": "Direction",
|
"direction": "Direction",
|
||||||
"ipAdapter": "IP Adapter",
|
"ipAdapter": "IP Adapter",
|
||||||
"t2iAdapter": "T2I Adapter",
|
"t2iAdapter": "T2I Adapter",
|
||||||
|
"positivePrompt": "Positive Prompt",
|
||||||
|
"negativePrompt": "Negative Prompt",
|
||||||
"discordLabel": "Discord",
|
"discordLabel": "Discord",
|
||||||
"dontAskMeAgain": "Don't ask me again",
|
"dontAskMeAgain": "Don't ask me again",
|
||||||
"error": "Error",
|
"error": "Error",
|
||||||
@ -136,7 +138,9 @@
|
|||||||
"red": "Red",
|
"red": "Red",
|
||||||
"green": "Green",
|
"green": "Green",
|
||||||
"blue": "Blue",
|
"blue": "Blue",
|
||||||
"alpha": "Alpha"
|
"alpha": "Alpha",
|
||||||
|
"selected": "Selected",
|
||||||
|
"viewer": "Viewer"
|
||||||
},
|
},
|
||||||
"controlnet": {
|
"controlnet": {
|
||||||
"controlAdapter_one": "Control Adapter",
|
"controlAdapter_one": "Control Adapter",
|
||||||
@ -893,6 +897,7 @@
|
|||||||
"denoisingStrength": "Denoising Strength",
|
"denoisingStrength": "Denoising Strength",
|
||||||
"downloadImage": "Download Image",
|
"downloadImage": "Download Image",
|
||||||
"general": "General",
|
"general": "General",
|
||||||
|
"globalSettings": "Global Settings",
|
||||||
"height": "Height",
|
"height": "Height",
|
||||||
"imageFit": "Fit Initial Image To Output Size",
|
"imageFit": "Fit Initial Image To Output Size",
|
||||||
"images": "Images",
|
"images": "Images",
|
||||||
@ -1505,5 +1510,27 @@
|
|||||||
},
|
},
|
||||||
"app": {
|
"app": {
|
||||||
"storeNotInitialized": "Store is not initialized"
|
"storeNotInitialized": "Store is not initialized"
|
||||||
|
},
|
||||||
|
"regionalPrompts": {
|
||||||
|
"deleteAll": "Delete All",
|
||||||
|
"addLayer": "Add Layer",
|
||||||
|
"moveToFront": "Move to Front",
|
||||||
|
"moveToBack": "Move to Back",
|
||||||
|
"moveForward": "Move Forward",
|
||||||
|
"moveBackward": "Move Backward",
|
||||||
|
"brushSize": "Brush Size",
|
||||||
|
"regionalControl": "Regional Control (ALPHA)",
|
||||||
|
"enableRegionalPrompts": "Enable $t(regionalPrompts.regionalPrompts)",
|
||||||
|
"globalMaskOpacity": "Global Mask Opacity",
|
||||||
|
"autoNegative": "Auto Negative",
|
||||||
|
"toggleVisibility": "Toggle Layer Visibility",
|
||||||
|
"deletePrompt": "Delete Prompt",
|
||||||
|
"resetRegion": "Reset Region",
|
||||||
|
"debugLayers": "Debug Layers",
|
||||||
|
"rectangle": "Rectangle",
|
||||||
|
"maskPreviewColor": "Mask Preview Color",
|
||||||
|
"addPositivePrompt": "Add $t(common.positivePrompt)",
|
||||||
|
"addNegativePrompt": "Add $t(common.negativePrompt)",
|
||||||
|
"addIPAdapter": "Add $t(common.ipAdapter)"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -27,7 +27,8 @@ export type LoggerNamespace =
|
|||||||
| 'socketio'
|
| 'socketio'
|
||||||
| 'session'
|
| 'session'
|
||||||
| 'queue'
|
| 'queue'
|
||||||
| 'dnd';
|
| 'dnd'
|
||||||
|
| 'regionalPrompts';
|
||||||
|
|
||||||
export const logger = (namespace: LoggerNamespace) => $logger.get().child({ namespace });
|
export const logger = (namespace: LoggerNamespace) => $logger.get().child({ namespace });
|
||||||
|
|
||||||
|
@ -21,6 +21,11 @@ import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workf
|
|||||||
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
|
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
|
||||||
import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice';
|
import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice';
|
||||||
import { queueSlice } from 'features/queue/store/queueSlice';
|
import { queueSlice } from 'features/queue/store/queueSlice';
|
||||||
|
import {
|
||||||
|
regionalPromptsPersistConfig,
|
||||||
|
regionalPromptsSlice,
|
||||||
|
regionalPromptsUndoableConfig,
|
||||||
|
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
import { sdxlPersistConfig, sdxlSlice } from 'features/sdxl/store/sdxlSlice';
|
import { sdxlPersistConfig, sdxlSlice } from 'features/sdxl/store/sdxlSlice';
|
||||||
import { configSlice } from 'features/system/store/configSlice';
|
import { configSlice } from 'features/system/store/configSlice';
|
||||||
import { systemPersistConfig, systemSlice } from 'features/system/store/systemSlice';
|
import { systemPersistConfig, systemSlice } from 'features/system/store/systemSlice';
|
||||||
@ -30,6 +35,7 @@ import { defaultsDeep, keys, omit, pick } from 'lodash-es';
|
|||||||
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
||||||
import type { SerializeFunction, UnserializeFunction } from 'redux-remember';
|
import type { SerializeFunction, UnserializeFunction } from 'redux-remember';
|
||||||
import { rememberEnhancer, rememberReducer } from 'redux-remember';
|
import { rememberEnhancer, rememberReducer } from 'redux-remember';
|
||||||
|
import undoable from 'redux-undo';
|
||||||
import { serializeError } from 'serialize-error';
|
import { serializeError } from 'serialize-error';
|
||||||
import { api } from 'services/api';
|
import { api } from 'services/api';
|
||||||
import { authToastMiddleware } from 'services/api/authToastMiddleware';
|
import { authToastMiddleware } from 'services/api/authToastMiddleware';
|
||||||
@ -59,6 +65,7 @@ const allReducers = {
|
|||||||
[queueSlice.name]: queueSlice.reducer,
|
[queueSlice.name]: queueSlice.reducer,
|
||||||
[workflowSlice.name]: workflowSlice.reducer,
|
[workflowSlice.name]: workflowSlice.reducer,
|
||||||
[hrfSlice.name]: hrfSlice.reducer,
|
[hrfSlice.name]: hrfSlice.reducer,
|
||||||
|
[regionalPromptsSlice.name]: undoable(regionalPromptsSlice.reducer, regionalPromptsUndoableConfig),
|
||||||
[api.reducerPath]: api.reducer,
|
[api.reducerPath]: api.reducer,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -103,6 +110,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
|
|||||||
[loraPersistConfig.name]: loraPersistConfig,
|
[loraPersistConfig.name]: loraPersistConfig,
|
||||||
[modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig,
|
[modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig,
|
||||||
[hrfPersistConfig.name]: hrfPersistConfig,
|
[hrfPersistConfig.name]: hrfPersistConfig,
|
||||||
|
[regionalPromptsPersistConfig.name]: regionalPromptsPersistConfig,
|
||||||
};
|
};
|
||||||
|
|
||||||
const unserialize: UnserializeFunction = (data, key) => {
|
const unserialize: UnserializeFunction = (data, key) => {
|
||||||
@ -114,6 +122,7 @@ const unserialize: UnserializeFunction = (data, key) => {
|
|||||||
try {
|
try {
|
||||||
const { initialState, migrate } = persistConfig;
|
const { initialState, migrate } = persistConfig;
|
||||||
const parsed = JSON.parse(data);
|
const parsed = JSON.parse(data);
|
||||||
|
|
||||||
// strip out old keys
|
// strip out old keys
|
||||||
const stripped = pick(parsed, keys(initialState));
|
const stripped = pick(parsed, keys(initialState));
|
||||||
// run (additive) migrations
|
// run (additive) migrations
|
||||||
@ -141,7 +150,9 @@ const serialize: SerializeFunction = (data, key) => {
|
|||||||
if (!persistConfig) {
|
if (!persistConfig) {
|
||||||
throw new Error(`No persist config for slice "${key}"`);
|
throw new Error(`No persist config for slice "${key}"`);
|
||||||
}
|
}
|
||||||
const result = omit(data, persistConfig.persistDenylist);
|
// Heuristic to determine if the slice is undoable - could just hardcode it in the persistConfig
|
||||||
|
const isUndoable = 'present' in data && 'past' in data && 'future' in data && '_latestUnfiltered' in data;
|
||||||
|
const result = omit(isUndoable ? data.present : data, persistConfig.persistDenylist);
|
||||||
return JSON.stringify(result);
|
return JSON.stringify(result);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ const sx: ChakraProps['sx'] = {
|
|||||||
|
|
||||||
const colorPickerStyles: CSSProperties = { width: '100%' };
|
const colorPickerStyles: CSSProperties = { width: '100%' };
|
||||||
|
|
||||||
const numberInputWidth: ChakraProps['w'] = '4.2rem';
|
const numberInputWidth: ChakraProps['w'] = '3.5rem';
|
||||||
|
|
||||||
const IAIColorPicker = (props: IAIColorPickerProps) => {
|
const IAIColorPicker = (props: IAIColorPickerProps) => {
|
||||||
const { color, onChange, withNumberInput, ...rest } = props;
|
const { color, onChange, withNumberInput, ...rest } = props;
|
||||||
@ -41,7 +41,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
|
|||||||
{withNumberInput && (
|
{withNumberInput && (
|
||||||
<Flex gap={5}>
|
<Flex gap={5}>
|
||||||
<FormControl gap={0}>
|
<FormControl gap={0}>
|
||||||
<FormLabel>{t('common.red')}</FormLabel>
|
<FormLabel>{t('common.red')[0]}</FormLabel>
|
||||||
<CompositeNumberInput
|
<CompositeNumberInput
|
||||||
value={color.r}
|
value={color.r}
|
||||||
onChange={handleChangeR}
|
onChange={handleChangeR}
|
||||||
@ -53,7 +53,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
|
|||||||
/>
|
/>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
<FormControl gap={0}>
|
<FormControl gap={0}>
|
||||||
<FormLabel>{t('common.green')}</FormLabel>
|
<FormLabel>{t('common.green')[0]}</FormLabel>
|
||||||
<CompositeNumberInput
|
<CompositeNumberInput
|
||||||
value={color.g}
|
value={color.g}
|
||||||
onChange={handleChangeG}
|
onChange={handleChangeG}
|
||||||
@ -65,7 +65,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
|
|||||||
/>
|
/>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
<FormControl gap={0}>
|
<FormControl gap={0}>
|
||||||
<FormLabel>{t('common.blue')}</FormLabel>
|
<FormLabel>{t('common.blue')[0]}</FormLabel>
|
||||||
<CompositeNumberInput
|
<CompositeNumberInput
|
||||||
value={color.b}
|
value={color.b}
|
||||||
onChange={handleChangeB}
|
onChange={handleChangeB}
|
||||||
@ -77,7 +77,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
|
|||||||
/>
|
/>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
<FormControl gap={0}>
|
<FormControl gap={0}>
|
||||||
<FormLabel>{t('common.alpha')}</FormLabel>
|
<FormLabel>{t('common.alpha')[0]}</FormLabel>
|
||||||
<CompositeNumberInput
|
<CompositeNumberInput
|
||||||
value={color.a}
|
value={color.a}
|
||||||
onChange={handleChangeA}
|
onChange={handleChangeA}
|
||||||
|
@ -0,0 +1,84 @@
|
|||||||
|
import type { ChakraProps } from '@invoke-ai/ui-library';
|
||||||
|
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||||
|
import type { CSSProperties } from 'react';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { RgbColorPicker as ColorfulRgbColorPicker } from 'react-colorful';
|
||||||
|
import type { ColorPickerBaseProps, RgbColor } from 'react-colorful/dist/types';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
type RgbColorPickerProps = ColorPickerBaseProps<RgbColor> & {
|
||||||
|
withNumberInput?: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
const colorPickerPointerStyles: NonNullable<ChakraProps['sx']> = {
|
||||||
|
width: 6,
|
||||||
|
height: 6,
|
||||||
|
borderColor: 'base.100',
|
||||||
|
};
|
||||||
|
|
||||||
|
const sx: ChakraProps['sx'] = {
|
||||||
|
'.react-colorful__hue-pointer': colorPickerPointerStyles,
|
||||||
|
'.react-colorful__saturation-pointer': colorPickerPointerStyles,
|
||||||
|
'.react-colorful__alpha-pointer': colorPickerPointerStyles,
|
||||||
|
gap: 5,
|
||||||
|
flexDir: 'column',
|
||||||
|
};
|
||||||
|
|
||||||
|
const colorPickerStyles: CSSProperties = { width: '100%' };
|
||||||
|
|
||||||
|
const numberInputWidth: ChakraProps['w'] = '3.5rem';
|
||||||
|
|
||||||
|
const RgbColorPicker = (props: RgbColorPickerProps) => {
|
||||||
|
const { color, onChange, withNumberInput, ...rest } = props;
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const handleChangeR = useCallback((r: number) => onChange({ ...color, r }), [color, onChange]);
|
||||||
|
const handleChangeG = useCallback((g: number) => onChange({ ...color, g }), [color, onChange]);
|
||||||
|
const handleChangeB = useCallback((b: number) => onChange({ ...color, b }), [color, onChange]);
|
||||||
|
return (
|
||||||
|
<Flex sx={sx}>
|
||||||
|
<ColorfulRgbColorPicker color={color} onChange={onChange} style={colorPickerStyles} {...rest} />
|
||||||
|
{withNumberInput && (
|
||||||
|
<Flex gap={5}>
|
||||||
|
<FormControl gap={0}>
|
||||||
|
<FormLabel>{t('common.red')[0]}</FormLabel>
|
||||||
|
<CompositeNumberInput
|
||||||
|
value={color.r}
|
||||||
|
onChange={handleChangeR}
|
||||||
|
min={0}
|
||||||
|
max={255}
|
||||||
|
step={1}
|
||||||
|
w={numberInputWidth}
|
||||||
|
defaultValue={90}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
<FormControl gap={0}>
|
||||||
|
<FormLabel>{t('common.green')[0]}</FormLabel>
|
||||||
|
<CompositeNumberInput
|
||||||
|
value={color.g}
|
||||||
|
onChange={handleChangeG}
|
||||||
|
min={0}
|
||||||
|
max={255}
|
||||||
|
step={1}
|
||||||
|
w={numberInputWidth}
|
||||||
|
defaultValue={90}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
<FormControl gap={0}>
|
||||||
|
<FormLabel>{t('common.blue')[0]}</FormLabel>
|
||||||
|
<CompositeNumberInput
|
||||||
|
value={color.b}
|
||||||
|
onChange={handleChangeB}
|
||||||
|
min={0}
|
||||||
|
max={255}
|
||||||
|
step={1}
|
||||||
|
w={numberInputWidth}
|
||||||
|
defaultValue={255}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(RgbColorPicker);
|
85
invokeai/frontend/web/src/common/util/arrayUtils.test.ts
Normal file
85
invokeai/frontend/web/src/common/util/arrayUtils.test.ts
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
import { moveBackward, moveForward, moveToBack, moveToFront } from 'common/util/arrayUtils';
|
||||||
|
import { describe, expect, it } from 'vitest';
|
||||||
|
|
||||||
|
describe('Array Manipulation Functions', () => {
|
||||||
|
const originalArray = ['a', 'b', 'c', 'd'];
|
||||||
|
describe('moveForwardOne', () => {
|
||||||
|
it('should move an item forward by one position', () => {
|
||||||
|
const array = [...originalArray];
|
||||||
|
const result = moveForward(array, (item) => item === 'b');
|
||||||
|
expect(result).toEqual(['a', 'c', 'b', 'd']);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should do nothing if the item is at the end', () => {
|
||||||
|
const array = [...originalArray];
|
||||||
|
const result = moveForward(array, (item) => item === 'd');
|
||||||
|
expect(result).toEqual(['a', 'b', 'c', 'd']);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should leave the array unchanged if the item isn't in the array", () => {
|
||||||
|
const array = [...originalArray];
|
||||||
|
const result = moveForward(array, (item) => item === 'z');
|
||||||
|
expect(result).toEqual(originalArray);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('moveToFront', () => {
|
||||||
|
it('should move an item to the front', () => {
|
||||||
|
const array = [...originalArray];
|
||||||
|
const result = moveToFront(array, (item) => item === 'c');
|
||||||
|
expect(result).toEqual(['c', 'a', 'b', 'd']);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should do nothing if the item is already at the front', () => {
|
||||||
|
const array = [...originalArray];
|
||||||
|
const result = moveToFront(array, (item) => item === 'a');
|
||||||
|
expect(result).toEqual(['a', 'b', 'c', 'd']);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should leave the array unchanged if the item isn't in the array", () => {
|
||||||
|
const array = [...originalArray];
|
||||||
|
const result = moveToFront(array, (item) => item === 'z');
|
||||||
|
expect(result).toEqual(originalArray);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('moveBackwardsOne', () => {
|
||||||
|
it('should move an item backward by one position', () => {
|
||||||
|
const array = [...originalArray];
|
||||||
|
const result = moveBackward(array, (item) => item === 'c');
|
||||||
|
expect(result).toEqual(['a', 'c', 'b', 'd']);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should do nothing if the item is at the beginning', () => {
|
||||||
|
const array = [...originalArray];
|
||||||
|
const result = moveBackward(array, (item) => item === 'a');
|
||||||
|
expect(result).toEqual(['a', 'b', 'c', 'd']);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should leave the array unchanged if the item isn't in the array", () => {
|
||||||
|
const array = [...originalArray];
|
||||||
|
const result = moveBackward(array, (item) => item === 'z');
|
||||||
|
expect(result).toEqual(originalArray);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('moveToBack', () => {
|
||||||
|
it('should move an item to the back', () => {
|
||||||
|
const array = [...originalArray];
|
||||||
|
const result = moveToBack(array, (item) => item === 'b');
|
||||||
|
expect(result).toEqual(['a', 'c', 'd', 'b']);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should do nothing if the item is already at the back', () => {
|
||||||
|
const array = [...originalArray];
|
||||||
|
const result = moveToBack(array, (item) => item === 'd');
|
||||||
|
expect(result).toEqual(['a', 'b', 'c', 'd']);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should leave the array unchanged if the item isn't in the array", () => {
|
||||||
|
const array = [...originalArray];
|
||||||
|
const result = moveToBack(array, (item) => item === 'z');
|
||||||
|
expect(result).toEqual(originalArray);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
37
invokeai/frontend/web/src/common/util/arrayUtils.ts
Normal file
37
invokeai/frontend/web/src/common/util/arrayUtils.ts
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
export const moveForward = <T>(array: T[], callback: (item: T) => boolean): T[] => {
|
||||||
|
const index = array.findIndex(callback);
|
||||||
|
if (index >= 0 && index < array.length - 1) {
|
||||||
|
//@ts-expect-error - These indicies are safe per the previous check
|
||||||
|
[array[index], array[index + 1]] = [array[index + 1], array[index]];
|
||||||
|
}
|
||||||
|
return array;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const moveToFront = <T>(array: T[], callback: (item: T) => boolean): T[] => {
|
||||||
|
const index = array.findIndex(callback);
|
||||||
|
if (index > 0) {
|
||||||
|
const [item] = array.splice(index, 1);
|
||||||
|
//@ts-expect-error - These indicies are safe per the previous check
|
||||||
|
array.unshift(item);
|
||||||
|
}
|
||||||
|
return array;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const moveBackward = <T>(array: T[], callback: (item: T) => boolean): T[] => {
|
||||||
|
const index = array.findIndex(callback);
|
||||||
|
if (index > 0) {
|
||||||
|
//@ts-expect-error - These indicies are safe per the previous check
|
||||||
|
[array[index], array[index - 1]] = [array[index - 1], array[index]];
|
||||||
|
}
|
||||||
|
return array;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const moveToBack = <T>(array: T[], callback: (item: T) => boolean): T[] => {
|
||||||
|
const index = array.findIndex(callback);
|
||||||
|
if (index >= 0 && index < array.length - 1) {
|
||||||
|
const [item] = array.splice(index, 1);
|
||||||
|
//@ts-expect-error - These indicies are safe per the previous check
|
||||||
|
array.push(item);
|
||||||
|
}
|
||||||
|
return array;
|
||||||
|
};
|
@ -10,6 +10,18 @@ import { clamp } from 'lodash-es';
|
|||||||
import type { MutableRefObject } from 'react';
|
import type { MutableRefObject } from 'react';
|
||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
|
|
||||||
|
export const calculateNewBrushSize = (brushSize: number, delta: number) => {
|
||||||
|
// This equation was derived by fitting a curve to the desired brush sizes and deltas
|
||||||
|
// see https://github.com/invoke-ai/InvokeAI/pull/5542#issuecomment-1915847565
|
||||||
|
const targetDelta = Math.sign(delta) * 0.7363 * Math.pow(1.0394, brushSize);
|
||||||
|
// This needs to be clamped to prevent the delta from getting too large
|
||||||
|
const finalDelta = clamp(targetDelta, -20, 20);
|
||||||
|
// The new brush size is also clamped to prevent it from getting too large or small
|
||||||
|
const newBrushSize = clamp(brushSize + finalDelta, 1, 500);
|
||||||
|
|
||||||
|
return newBrushSize;
|
||||||
|
};
|
||||||
|
|
||||||
const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
|
const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const stageScale = useAppSelector((s) => s.canvas.stageScale);
|
const stageScale = useAppSelector((s) => s.canvas.stageScale);
|
||||||
@ -36,15 +48,7 @@ const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if ($ctrl.get() || $meta.get()) {
|
if ($ctrl.get() || $meta.get()) {
|
||||||
// This equation was derived by fitting a curve to the desired brush sizes and deltas
|
dispatch(setBrushSize(calculateNewBrushSize(brushSize, delta)));
|
||||||
// see https://github.com/invoke-ai/InvokeAI/pull/5542#issuecomment-1915847565
|
|
||||||
const targetDelta = Math.sign(delta) * 0.7363 * Math.pow(1.0394, brushSize);
|
|
||||||
// This needs to be clamped to prevent the delta from getting too large
|
|
||||||
const finalDelta = clamp(targetDelta, -20, 20);
|
|
||||||
// The new brush size is also clamped to prevent it from getting too large or small
|
|
||||||
const newBrushSize = clamp(brushSize + finalDelta, 1, 500);
|
|
||||||
|
|
||||||
dispatch(setBrushSize(newBrushSize));
|
|
||||||
} else {
|
} else {
|
||||||
const cursorPos = stageRef.current.getPointerPosition();
|
const cursorPos = stageRef.current.getPointerPosition();
|
||||||
let delta = e.evt.deltaY;
|
let delta = e.evt.deltaY;
|
||||||
|
@ -7,3 +7,22 @@ export const blobToDataURL = (blob: Blob): Promise<string> => {
|
|||||||
reader.readAsDataURL(blob);
|
reader.readAsDataURL(blob);
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export function imageDataToDataURL(imageData: ImageData): string {
|
||||||
|
const { width, height } = imageData;
|
||||||
|
|
||||||
|
// Create a canvas to transfer the ImageData to
|
||||||
|
const canvas = document.createElement('canvas');
|
||||||
|
canvas.width = width;
|
||||||
|
canvas.height = height;
|
||||||
|
|
||||||
|
// Draw the ImageData onto the canvas
|
||||||
|
const ctx = canvas.getContext('2d');
|
||||||
|
if (!ctx) {
|
||||||
|
throw new Error('Unable to get canvas context');
|
||||||
|
}
|
||||||
|
ctx.putImageData(imageData, 0, 0);
|
||||||
|
|
||||||
|
// Convert the canvas to a data URL (base64)
|
||||||
|
return canvas.toDataURL();
|
||||||
|
}
|
||||||
|
@ -1,6 +1,11 @@
|
|||||||
import type { RgbaColor } from 'react-colorful';
|
import type { RgbaColor, RgbColor } from 'react-colorful';
|
||||||
|
|
||||||
export const rgbaColorToString = (color: RgbaColor): string => {
|
export const rgbaColorToString = (color: RgbaColor): string => {
|
||||||
const { r, g, b, a } = color;
|
const { r, g, b, a } = color;
|
||||||
return `rgba(${r}, ${g}, ${b}, ${a})`;
|
return `rgba(${r}, ${g}, ${b}, ${a})`;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const rgbColorToString = (color: RgbColor): string => {
|
||||||
|
const { r, g, b } = color;
|
||||||
|
return `rgba(${r}, ${g}, ${b})`;
|
||||||
|
};
|
||||||
|
@ -6,6 +6,7 @@ import { deepClone } from 'common/util/deepClone';
|
|||||||
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
|
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
|
||||||
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
|
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
|
||||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||||
|
import { maskLayerIPAdapterAdded } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
import { merge, uniq } from 'lodash-es';
|
import { merge, uniq } from 'lodash-es';
|
||||||
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||||
import { socketInvocationError } from 'services/events/actions';
|
import { socketInvocationError } from 'services/events/actions';
|
||||||
@ -382,6 +383,10 @@ export const controlAdaptersSlice = createSlice({
|
|||||||
builder.addCase(socketInvocationError, (state) => {
|
builder.addCase(socketInvocationError, (state) => {
|
||||||
state.pendingControlImages = [];
|
state.pendingControlImages = [];
|
||||||
});
|
});
|
||||||
|
|
||||||
|
builder.addCase(maskLayerIPAdapterAdded, (state, action) => {
|
||||||
|
caAdapter.addOne(state, buildControlAdapter(action.meta.uuid, 'ip_adapter'));
|
||||||
|
});
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { Box, Flex, IconButton, Tooltip } from '@invoke-ai/ui-library';
|
import { Box, Flex, IconButton, Tooltip, useShiftModifier } from '@invoke-ai/ui-library';
|
||||||
import { getOverlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
|
import { getOverlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
|
||||||
import { isString } from 'lodash-es';
|
import { isString } from 'lodash-es';
|
||||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||||
@ -9,18 +9,19 @@ import { PiCopyBold, PiDownloadSimpleBold } from 'react-icons/pi';
|
|||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
label: string;
|
label: string;
|
||||||
data: object | string;
|
data: unknown;
|
||||||
fileName?: string;
|
fileName?: string;
|
||||||
withDownload?: boolean;
|
withDownload?: boolean;
|
||||||
withCopy?: boolean;
|
withCopy?: boolean;
|
||||||
|
extraCopyActions?: { label: string; getData: (data: unknown) => unknown }[];
|
||||||
};
|
};
|
||||||
|
|
||||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams('scroll', 'scroll').options;
|
const overlayscrollbarsOptions = getOverlayScrollbarsParams('scroll', 'scroll').options;
|
||||||
|
|
||||||
const DataViewer = (props: Props) => {
|
const DataViewer = (props: Props) => {
|
||||||
const { label, data, fileName, withDownload = true, withCopy = true } = props;
|
const { label, data, fileName, withDownload = true, withCopy = true, extraCopyActions } = props;
|
||||||
const dataString = useMemo(() => (isString(data) ? data : JSON.stringify(data, null, 2)), [data]);
|
const dataString = useMemo(() => (isString(data) ? data : JSON.stringify(data, null, 2)), [data]);
|
||||||
|
const shift = useShiftModifier();
|
||||||
const handleCopy = useCallback(() => {
|
const handleCopy = useCallback(() => {
|
||||||
navigator.clipboard.writeText(dataString);
|
navigator.clipboard.writeText(dataString);
|
||||||
}, [dataString]);
|
}, [dataString]);
|
||||||
@ -67,6 +68,10 @@ const DataViewer = (props: Props) => {
|
|||||||
/>
|
/>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
)}
|
)}
|
||||||
|
{shift &&
|
||||||
|
extraCopyActions?.map(({ label, getData }) => (
|
||||||
|
<ExtraCopyAction label={label} getData={getData} data={data} key={label} />
|
||||||
|
))}
|
||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
@ -78,3 +83,27 @@ const overlayScrollbarsStyles: CSSProperties = {
|
|||||||
height: '100%',
|
height: '100%',
|
||||||
width: '100%',
|
width: '100%',
|
||||||
};
|
};
|
||||||
|
|
||||||
|
type ExtraCopyActionProps = {
|
||||||
|
label: string;
|
||||||
|
data: unknown;
|
||||||
|
getData: (data: unknown) => unknown;
|
||||||
|
};
|
||||||
|
const ExtraCopyAction = ({ label, data, getData }: ExtraCopyActionProps) => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const handleCopy = useCallback(() => {
|
||||||
|
navigator.clipboard.writeText(JSON.stringify(getData(data), null, 2));
|
||||||
|
}, [data, getData]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Tooltip label={`${t('gallery.copy')} ${label} JSON`}>
|
||||||
|
<IconButton
|
||||||
|
aria-label={`${t('gallery.copy')} ${label} JSON`}
|
||||||
|
icon={<PiCopyBold size={16} />}
|
||||||
|
variant="ghost"
|
||||||
|
opacity={0.7}
|
||||||
|
onClick={handleCopy}
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { Button, Flex, Heading, Text } from '@invoke-ai/ui-library';
|
import { Button, Flex, Heading, SimpleGrid, Text } from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { useControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/hooks/useControlNetOrT2IAdapterDefaultSettings';
|
import { useControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/hooks/useControlNetOrT2IAdapterDefaultSettings';
|
||||||
import { DefaultPreprocessor } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/DefaultPreprocessor';
|
import { DefaultPreprocessor } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/DefaultPreprocessor';
|
||||||
@ -92,13 +92,9 @@ export const ControlNetOrT2IAdapterDefaultSettings = () => {
|
|||||||
</Button>
|
</Button>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|
||||||
<Flex flexDir="column" gap={8}>
|
<SimpleGrid columns={2} gap={8}>
|
||||||
<Flex gap={8}>
|
<DefaultPreprocessor control={control} name="preprocessor" />
|
||||||
<Flex gap={4} w="full">
|
</SimpleGrid>
|
||||||
<DefaultPreprocessor control={control} name="preprocessor" />
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { Button, Flex, Heading, Text } from '@invoke-ai/ui-library';
|
import { Button, Flex, Heading, SimpleGrid, Text } from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { useMainModelDefaultSettings } from 'features/modelManagerV2/hooks/useMainModelDefaultSettings';
|
import { useMainModelDefaultSettings } from 'features/modelManagerV2/hooks/useMainModelDefaultSettings';
|
||||||
import { DefaultHeight } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultHeight';
|
import { DefaultHeight } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultHeight';
|
||||||
@ -122,40 +122,16 @@ export const MainModelDefaultSettings = () => {
|
|||||||
</Button>
|
</Button>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|
||||||
<Flex flexDir="column" gap={8}>
|
<SimpleGrid columns={2} gap={8}>
|
||||||
<Flex gap={8}>
|
<DefaultVae control={control} name="vae" />
|
||||||
<Flex gap={4} w="full">
|
<DefaultVaePrecision control={control} name="vaePrecision" />
|
||||||
<DefaultVae control={control} name="vae" />
|
<DefaultScheduler control={control} name="scheduler" />
|
||||||
</Flex>
|
<DefaultSteps control={control} name="steps" />
|
||||||
<Flex gap={4} w="full">
|
<DefaultCfgScale control={control} name="cfgScale" />
|
||||||
<DefaultVaePrecision control={control} name="vaePrecision" />
|
<DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />
|
||||||
</Flex>
|
<DefaultWidth control={control} optimalDimension={optimalDimension} />
|
||||||
</Flex>
|
<DefaultHeight control={control} optimalDimension={optimalDimension} />
|
||||||
<Flex gap={8}>
|
</SimpleGrid>
|
||||||
<Flex gap={4} w="full">
|
|
||||||
<DefaultScheduler control={control} name="scheduler" />
|
|
||||||
</Flex>
|
|
||||||
<Flex gap={4} w="full">
|
|
||||||
<DefaultSteps control={control} name="steps" />
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
<Flex gap={8}>
|
|
||||||
<Flex gap={4} w="full">
|
|
||||||
<DefaultCfgScale control={control} name="cfgScale" />
|
|
||||||
</Flex>
|
|
||||||
<Flex gap={4} w="full">
|
|
||||||
<DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
<Flex gap={8}>
|
|
||||||
<Flex gap={4} w="full">
|
|
||||||
<DefaultWidth control={control} optimalDimension={optimalDimension} />
|
|
||||||
</Flex>
|
|
||||||
<Flex gap={4} w="full">
|
|
||||||
<DefaultHeight control={control} optimalDimension={optimalDimension} />
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -6,6 +6,7 @@ import {
|
|||||||
FormLabel,
|
FormLabel,
|
||||||
Heading,
|
Heading,
|
||||||
Input,
|
Input,
|
||||||
|
SimpleGrid,
|
||||||
Text,
|
Text,
|
||||||
Textarea,
|
Textarea,
|
||||||
} from '@invoke-ai/ui-library';
|
} from '@invoke-ai/ui-library';
|
||||||
@ -66,25 +67,21 @@ export const ModelEdit = ({ form }: Props) => {
|
|||||||
<Heading as="h3" fontSize="md" mt="4">
|
<Heading as="h3" fontSize="md" mt="4">
|
||||||
{t('modelManager.modelSettings')}
|
{t('modelManager.modelSettings')}
|
||||||
</Heading>
|
</Heading>
|
||||||
<Flex gap={4}>
|
<SimpleGrid columns={2} gap={4}>
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
||||||
<BaseModelSelect control={form.control} />
|
<BaseModelSelect control={form.control} />
|
||||||
</FormControl>
|
</FormControl>
|
||||||
</Flex>
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
{data.type === 'main' && data.format === 'checkpoint' && (
|
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||||
<>
|
<ModelVariantSelect control={form.control} />
|
||||||
<Flex gap={4}>
|
</FormControl>
|
||||||
|
{data.type === 'main' && data.format === 'checkpoint' && (
|
||||||
|
<>
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
|
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
|
||||||
<Input {...form.register('config_path', stringFieldOptions)} />
|
<Input {...form.register('config_path', stringFieldOptions)} />
|
||||||
</FormControl>
|
</FormControl>
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
|
||||||
<ModelVariantSelect control={form.control} />
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
<Flex gap={4}>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
||||||
<PredictionTypeSelect control={form.control} />
|
<PredictionTypeSelect control={form.control} />
|
||||||
@ -93,9 +90,9 @@ export const ModelEdit = ({ form }: Props) => {
|
|||||||
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
||||||
<Checkbox {...form.register('upcast_attention')} />
|
<Checkbox {...form.register('upcast_attention')} />
|
||||||
</FormControl>
|
</FormControl>
|
||||||
</Flex>
|
</>
|
||||||
</>
|
)}
|
||||||
)}
|
</SimpleGrid>
|
||||||
</Flex>
|
</Flex>
|
||||||
</form>
|
</form>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { Box, Flex, Text } from '@invoke-ai/ui-library';
|
import { Box, Flex, SimpleGrid, Text } from '@invoke-ai/ui-library';
|
||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
import { skipToken } from '@reduxjs/toolkit/query';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { ControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings';
|
import { ControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings';
|
||||||
@ -24,57 +24,32 @@ export const ModelView = () => {
|
|||||||
return (
|
return (
|
||||||
<Flex flexDir="column" h="full" gap={4}>
|
<Flex flexDir="column" h="full" gap={4}>
|
||||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||||
<Flex flexDir="column" gap={4}>
|
<SimpleGrid columns={2} gap={4}>
|
||||||
<Flex gap={2}>
|
<ModelAttrView label={t('modelManager.baseModel')} value={data.base} />
|
||||||
<ModelAttrView label={t('modelManager.baseModel')} value={data.base} />
|
<ModelAttrView label={t('modelManager.modelType')} value={data.type} />
|
||||||
<ModelAttrView label={t('modelManager.modelType')} value={data.type} />
|
<ModelAttrView label={t('common.format')} value={data.format} />
|
||||||
</Flex>
|
<ModelAttrView label={t('modelManager.path')} value={data.path} />
|
||||||
<Flex gap={2}>
|
{data.type === 'main' && <ModelAttrView label={t('modelManager.variant')} value={data.variant} />}
|
||||||
<ModelAttrView label={t('common.format')} value={data.format} />
|
|
||||||
<ModelAttrView label={t('modelManager.path')} value={data.path} />
|
|
||||||
</Flex>
|
|
||||||
|
|
||||||
{data.type === 'main' && data.format === 'diffusers' && data.repo_variant && (
|
{data.type === 'main' && data.format === 'diffusers' && data.repo_variant && (
|
||||||
<Flex gap={2}>
|
<ModelAttrView label={t('modelManager.repoVariant')} value={data.repo_variant} />
|
||||||
<ModelAttrView label={t('modelManager.repoVariant')} value={data.repo_variant} />
|
|
||||||
</Flex>
|
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{data.type === 'main' && data.format === 'checkpoint' && (
|
{data.type === 'main' && data.format === 'checkpoint' && (
|
||||||
<>
|
<>
|
||||||
<Flex gap={2}>
|
<ModelAttrView label={t('modelManager.pathToConfig')} value={data.config_path} />
|
||||||
<ModelAttrView label={t('modelManager.pathToConfig')} value={data.config_path} />
|
<ModelAttrView label={t('modelManager.predictionType')} value={data.prediction_type} />
|
||||||
<ModelAttrView label={t('modelManager.variant')} value={data.variant} />
|
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${data.upcast_attention}`} />
|
||||||
</Flex>
|
|
||||||
<Flex gap={2}>
|
|
||||||
<ModelAttrView label={t('modelManager.predictionType')} value={data.prediction_type} />
|
|
||||||
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${data.upcast_attention}`} />
|
|
||||||
</Flex>
|
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{data.type === 'ip_adapter' && data.format === 'invokeai' && (
|
{data.type === 'ip_adapter' && data.format === 'invokeai' && (
|
||||||
<Flex gap={2}>
|
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={data.image_encoder_model_id} />
|
||||||
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={data.image_encoder_model_id} />
|
|
||||||
</Flex>
|
|
||||||
)}
|
)}
|
||||||
</Flex>
|
</SimpleGrid>
|
||||||
|
</Box>
|
||||||
|
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||||
|
{data.type === 'main' && data.base !== 'sdxl-refiner' && <MainModelDefaultSettings />}
|
||||||
|
{(data.type === 'controlnet' || data.type === 't2i_adapter') && <ControlNetOrT2IAdapterDefaultSettings />}
|
||||||
|
{(data.type === 'main' || data.type === 'lora') && <TriggerPhrases />}
|
||||||
</Box>
|
</Box>
|
||||||
{data.type === 'main' && data.base !== 'sdxl-refiner' && (
|
|
||||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
|
||||||
<MainModelDefaultSettings />
|
|
||||||
</Box>
|
|
||||||
)}
|
|
||||||
{(data.type === 'controlnet' || data.type === 't2i_adapter') && (
|
|
||||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
|
||||||
<ControlNetOrT2IAdapterDefaultSettings />
|
|
||||||
</Box>
|
|
||||||
)}
|
|
||||||
{(data.type === 'main' || data.type === 'lora') && (
|
|
||||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
|
||||||
<TriggerPhrases />
|
|
||||||
</Box>
|
|
||||||
)}
|
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -77,9 +77,17 @@ export const TriggerPhrases = () => {
|
|||||||
[updateModel, selectedModelKey, triggerPhrases]
|
[updateModel, selectedModelKey, triggerPhrases]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const onTriggerPhraseAddFormSubmit = useCallback(
|
||||||
|
(e: React.FormEvent<HTMLFormElement>) => {
|
||||||
|
e.preventDefault();
|
||||||
|
addTriggerPhrase();
|
||||||
|
},
|
||||||
|
[addTriggerPhrase]
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex flexDir="column" w="full" gap="5">
|
<Flex flexDir="column" w="full" gap="5">
|
||||||
<form>
|
<form onSubmit={onTriggerPhraseAddFormSubmit}>
|
||||||
<FormControl w="full" isInvalid={Boolean(errors.length)} orientation="vertical">
|
<FormControl w="full" isInvalid={Boolean(errors.length)} orientation="vertical">
|
||||||
<FormLabel>{t('modelManager.triggerPhrases')}</FormLabel>
|
<FormLabel>{t('modelManager.triggerPhrases')}</FormLabel>
|
||||||
<Flex flexDir="column" w="full">
|
<Flex flexDir="column" w="full">
|
||||||
|
@ -23,6 +23,7 @@ export type NodesState = {
|
|||||||
nodeOpacity: number;
|
nodeOpacity: number;
|
||||||
shouldSnapToGrid: boolean;
|
shouldSnapToGrid: boolean;
|
||||||
shouldColorEdges: boolean;
|
shouldColorEdges: boolean;
|
||||||
|
shouldShowEdgeLabels: boolean;
|
||||||
selectedNodes: string[];
|
selectedNodes: string[];
|
||||||
selectedEdges: string[];
|
selectedEdges: string[];
|
||||||
nodeExecutionStates: Record<string, NodeExecutionState>;
|
nodeExecutionStates: Record<string, NodeExecutionState>;
|
||||||
@ -32,7 +33,6 @@ export type NodesState = {
|
|||||||
isAddNodePopoverOpen: boolean;
|
isAddNodePopoverOpen: boolean;
|
||||||
addNewNodePosition: XYPosition | null;
|
addNewNodePosition: XYPosition | null;
|
||||||
selectionMode: SelectionMode;
|
selectionMode: SelectionMode;
|
||||||
shouldShowEdgeLabels: boolean;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
export type WorkflowMode = 'edit' | 'view';
|
export type WorkflowMode = 'edit' | 'view';
|
||||||
|
@ -19,12 +19,14 @@ export const addIPAdapterToLinearGraph = async (
|
|||||||
graph: NonNullableGraph,
|
graph: NonNullableGraph,
|
||||||
baseNodeId: string
|
baseNodeId: string
|
||||||
): Promise<void> => {
|
): Promise<void> => {
|
||||||
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(({ model, controlImage, isEnabled }) => {
|
const validIPAdapters = selectValidIPAdapters(state.controlAdapters)
|
||||||
const hasModel = Boolean(model);
|
.filter(({ model, controlImage, isEnabled }) => {
|
||||||
const doesBaseMatch = model?.base === state.generation.model?.base;
|
const hasModel = Boolean(model);
|
||||||
const hasControlImage = controlImage;
|
const doesBaseMatch = model?.base === state.generation.model?.base;
|
||||||
return isEnabled && hasModel && doesBaseMatch && hasControlImage;
|
const hasControlImage = controlImage;
|
||||||
});
|
return isEnabled && hasModel && doesBaseMatch && hasControlImage;
|
||||||
|
})
|
||||||
|
.filter((ca) => !state.regionalPrompts.present.layers.some((l) => l.ipAdapterIds.includes(ca.id)));
|
||||||
|
|
||||||
if (validIPAdapters.length) {
|
if (validIPAdapters.length) {
|
||||||
// Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect
|
// Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect
|
||||||
|
@ -0,0 +1,346 @@
|
|||||||
|
import { getStore } from 'app/store/nanostores/store';
|
||||||
|
import type { RootState } from 'app/store/store';
|
||||||
|
import { selectAllIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
|
import {
|
||||||
|
IP_ADAPTER_COLLECT,
|
||||||
|
NEGATIVE_CONDITIONING,
|
||||||
|
NEGATIVE_CONDITIONING_COLLECT,
|
||||||
|
POSITIVE_CONDITIONING,
|
||||||
|
POSITIVE_CONDITIONING_COLLECT,
|
||||||
|
PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX,
|
||||||
|
PROMPT_REGION_MASK_TO_TENSOR_PREFIX,
|
||||||
|
PROMPT_REGION_NEGATIVE_COND_PREFIX,
|
||||||
|
PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX,
|
||||||
|
PROMPT_REGION_POSITIVE_COND_PREFIX,
|
||||||
|
} from 'features/nodes/util/graph/constants';
|
||||||
|
import { isVectorMaskLayer } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import { getRegionalPromptLayerBlobs } from 'features/regionalPrompts/util/getLayerBlobs';
|
||||||
|
import { size } from 'lodash-es';
|
||||||
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
|
import type { CollectInvocation, Edge, IPAdapterInvocation, NonNullableGraph, S } from 'services/api/types';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
|
export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNullableGraph, denoiseNodeId: string) => {
|
||||||
|
if (!state.regionalPrompts.present.isEnabled) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const { dispatch } = getStore();
|
||||||
|
const isSDXL = state.generation.model?.base === 'sdxl';
|
||||||
|
const layers = state.regionalPrompts.present.layers
|
||||||
|
// Only support vector mask layers now
|
||||||
|
// TODO: Image masks
|
||||||
|
.filter(isVectorMaskLayer)
|
||||||
|
// Only visible layers are rendered on the canvas
|
||||||
|
.filter((l) => l.isVisible)
|
||||||
|
// Only layers with prompts get added to the graph
|
||||||
|
.filter((l) => {
|
||||||
|
const hasTextPrompt = Boolean(l.positivePrompt || l.negativePrompt);
|
||||||
|
const hasIPAdapter = l.ipAdapterIds.length !== 0;
|
||||||
|
return hasTextPrompt || hasIPAdapter;
|
||||||
|
});
|
||||||
|
|
||||||
|
const regionalIPAdapters = selectAllIPAdapters(state.controlAdapters).filter(
|
||||||
|
({ id, model, controlImage, isEnabled }) => {
|
||||||
|
const hasModel = Boolean(model);
|
||||||
|
const doesBaseMatch = model?.base === state.generation.model?.base;
|
||||||
|
const hasControlImage = controlImage;
|
||||||
|
const isRegional = layers.some((l) => l.ipAdapterIds.includes(id));
|
||||||
|
return isEnabled && hasModel && doesBaseMatch && hasControlImage && isRegional;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const layerIds = layers.map((l) => l.id);
|
||||||
|
const blobs = await getRegionalPromptLayerBlobs(layerIds);
|
||||||
|
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');
|
||||||
|
|
||||||
|
// TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing
|
||||||
|
// the existing conditioning nodes.
|
||||||
|
|
||||||
|
// With regional prompts we have multiple conditioning nodes which much be routed into collectors. Set those up
|
||||||
|
const posCondCollectNode: CollectInvocation = {
|
||||||
|
id: POSITIVE_CONDITIONING_COLLECT,
|
||||||
|
type: 'collect',
|
||||||
|
};
|
||||||
|
graph.nodes[POSITIVE_CONDITIONING_COLLECT] = posCondCollectNode;
|
||||||
|
const negCondCollectNode: CollectInvocation = {
|
||||||
|
id: NEGATIVE_CONDITIONING_COLLECT,
|
||||||
|
type: 'collect',
|
||||||
|
};
|
||||||
|
graph.nodes[NEGATIVE_CONDITIONING_COLLECT] = negCondCollectNode;
|
||||||
|
|
||||||
|
// Re-route the denoise node's OG conditioning inputs to the collect nodes
|
||||||
|
const newEdges: Edge[] = [];
|
||||||
|
for (const edge of graph.edges) {
|
||||||
|
if (edge.destination.node_id === denoiseNodeId && edge.destination.field === 'positive_conditioning') {
|
||||||
|
newEdges.push({
|
||||||
|
source: edge.source,
|
||||||
|
destination: {
|
||||||
|
node_id: POSITIVE_CONDITIONING_COLLECT,
|
||||||
|
field: 'item',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else if (edge.destination.node_id === denoiseNodeId && edge.destination.field === 'negative_conditioning') {
|
||||||
|
newEdges.push({
|
||||||
|
source: edge.source,
|
||||||
|
destination: {
|
||||||
|
node_id: NEGATIVE_CONDITIONING_COLLECT,
|
||||||
|
field: 'item',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
newEdges.push(edge);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
graph.edges = newEdges;
|
||||||
|
|
||||||
|
// Connect collectors to the denoise nodes - must happen _after_ rerouting else you get cycles
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: POSITIVE_CONDITIONING_COLLECT,
|
||||||
|
field: 'collection',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: denoiseNodeId,
|
||||||
|
field: 'positive_conditioning',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: NEGATIVE_CONDITIONING_COLLECT,
|
||||||
|
field: 'collection',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: denoiseNodeId,
|
||||||
|
field: 'negative_conditioning',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!graph.nodes[IP_ADAPTER_COLLECT] && regionalIPAdapters.length > 0) {
|
||||||
|
const ipAdapterCollectNode: CollectInvocation = {
|
||||||
|
id: IP_ADAPTER_COLLECT,
|
||||||
|
type: 'collect',
|
||||||
|
is_intermediate: true,
|
||||||
|
};
|
||||||
|
graph.nodes[IP_ADAPTER_COLLECT] = ipAdapterCollectNode;
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: IP_ADAPTER_COLLECT, field: 'collection' },
|
||||||
|
destination: {
|
||||||
|
node_id: denoiseNodeId,
|
||||||
|
field: 'ip_adapter',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upload the blobs to the backend, add each to graph
|
||||||
|
// TODO: Store the uploaded image names in redux to reuse them, so long as the layer hasn't otherwise changed. This
|
||||||
|
// would be a great perf win - not only would we skip re-uploading the same image, but we'd be able to use the node
|
||||||
|
// cache (currently, when we re-use the same mask data, since it is a different image, the node cache is not used).
|
||||||
|
for (const layer of layers) {
|
||||||
|
const blob = blobs[layer.id];
|
||||||
|
assert(blob, `Blob for layer ${layer.id} not found`);
|
||||||
|
|
||||||
|
const file = new File([blob], `${layer.id}_mask.png`, { type: 'image/png' });
|
||||||
|
const req = dispatch(
|
||||||
|
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
|
||||||
|
);
|
||||||
|
req.reset();
|
||||||
|
|
||||||
|
// TODO: This will raise on network error
|
||||||
|
const { image_name } = await req.unwrap();
|
||||||
|
|
||||||
|
// The main mask-to-tensor node
|
||||||
|
const maskToTensorNode: S['AlphaMaskToTensorInvocation'] = {
|
||||||
|
id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${layer.id}`,
|
||||||
|
type: 'alpha_mask_to_tensor',
|
||||||
|
image: {
|
||||||
|
image_name,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
graph.nodes[maskToTensorNode.id] = maskToTensorNode;
|
||||||
|
|
||||||
|
if (layer.positivePrompt) {
|
||||||
|
// The main positive conditioning node
|
||||||
|
const regionalPositiveCondNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL
|
||||||
|
? {
|
||||||
|
type: 'sdxl_compel_prompt',
|
||||||
|
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`,
|
||||||
|
prompt: layer.positivePrompt,
|
||||||
|
style: layer.positivePrompt, // TODO: Should we put the positive prompt in both fields?
|
||||||
|
}
|
||||||
|
: {
|
||||||
|
type: 'compel',
|
||||||
|
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`,
|
||||||
|
prompt: layer.positivePrompt,
|
||||||
|
};
|
||||||
|
graph.nodes[regionalPositiveCondNode.id] = regionalPositiveCondNode;
|
||||||
|
|
||||||
|
// Connect the mask to the conditioning
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: maskToTensorNode.id, field: 'mask' },
|
||||||
|
destination: { node_id: regionalPositiveCondNode.id, field: 'mask' },
|
||||||
|
});
|
||||||
|
|
||||||
|
// Connect the conditioning to the collector
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: regionalPositiveCondNode.id, field: 'conditioning' },
|
||||||
|
destination: { node_id: posCondCollectNode.id, field: 'item' },
|
||||||
|
});
|
||||||
|
|
||||||
|
// Copy the connections to the "global" positive conditioning node to the regional cond
|
||||||
|
for (const edge of graph.edges) {
|
||||||
|
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
|
||||||
|
graph.edges.push({
|
||||||
|
source: edge.source,
|
||||||
|
destination: { node_id: regionalPositiveCondNode.id, field: edge.destination.field },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (layer.negativePrompt) {
|
||||||
|
// The main negative conditioning node
|
||||||
|
const regionalNegativeCondNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL
|
||||||
|
? {
|
||||||
|
type: 'sdxl_compel_prompt',
|
||||||
|
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`,
|
||||||
|
prompt: layer.negativePrompt,
|
||||||
|
style: layer.negativePrompt,
|
||||||
|
}
|
||||||
|
: {
|
||||||
|
type: 'compel',
|
||||||
|
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`,
|
||||||
|
prompt: layer.negativePrompt,
|
||||||
|
};
|
||||||
|
graph.nodes[regionalNegativeCondNode.id] = regionalNegativeCondNode;
|
||||||
|
|
||||||
|
// Connect the mask to the conditioning
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: maskToTensorNode.id, field: 'mask' },
|
||||||
|
destination: { node_id: regionalNegativeCondNode.id, field: 'mask' },
|
||||||
|
});
|
||||||
|
|
||||||
|
// Connect the conditioning to the collector
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: regionalNegativeCondNode.id, field: 'conditioning' },
|
||||||
|
destination: { node_id: negCondCollectNode.id, field: 'item' },
|
||||||
|
});
|
||||||
|
|
||||||
|
// Copy the connections to the "global" negative conditioning node to the regional cond
|
||||||
|
for (const edge of graph.edges) {
|
||||||
|
if (edge.destination.node_id === NEGATIVE_CONDITIONING && edge.destination.field !== 'prompt') {
|
||||||
|
graph.edges.push({
|
||||||
|
source: edge.source,
|
||||||
|
destination: { node_id: regionalNegativeCondNode.id, field: edge.destination.field },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node
|
||||||
|
if (layer.autoNegative === 'invert' && layer.positivePrompt) {
|
||||||
|
// We re-use the mask image, but invert it when converting to tensor
|
||||||
|
const invertTensorMaskNode: S['InvertTensorMaskInvocation'] = {
|
||||||
|
id: `${PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX}_${layer.id}`,
|
||||||
|
type: 'invert_tensor_mask',
|
||||||
|
};
|
||||||
|
graph.nodes[invertTensorMaskNode.id] = invertTensorMaskNode;
|
||||||
|
|
||||||
|
// Connect the OG mask image to the inverted mask-to-tensor node
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: maskToTensorNode.id,
|
||||||
|
field: 'mask',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: invertTensorMaskNode.id,
|
||||||
|
field: 'mask',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the
|
||||||
|
// positive prompt
|
||||||
|
const regionalPositiveCondInvertedNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL
|
||||||
|
? {
|
||||||
|
type: 'sdxl_compel_prompt',
|
||||||
|
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`,
|
||||||
|
prompt: layer.positivePrompt,
|
||||||
|
style: layer.positivePrompt,
|
||||||
|
}
|
||||||
|
: {
|
||||||
|
type: 'compel',
|
||||||
|
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`,
|
||||||
|
prompt: layer.positivePrompt,
|
||||||
|
};
|
||||||
|
graph.nodes[regionalPositiveCondInvertedNode.id] = regionalPositiveCondInvertedNode;
|
||||||
|
// Connect the inverted mask to the conditioning
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: invertTensorMaskNode.id, field: 'mask' },
|
||||||
|
destination: { node_id: regionalPositiveCondInvertedNode.id, field: 'mask' },
|
||||||
|
});
|
||||||
|
// Connect the conditioning to the negative collector
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: regionalPositiveCondInvertedNode.id, field: 'conditioning' },
|
||||||
|
destination: { node_id: negCondCollectNode.id, field: 'item' },
|
||||||
|
});
|
||||||
|
// Copy the connections to the "global" positive conditioning node to our regional node
|
||||||
|
for (const edge of graph.edges) {
|
||||||
|
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
|
||||||
|
graph.edges.push({
|
||||||
|
source: edge.source,
|
||||||
|
destination: { node_id: regionalPositiveCondInvertedNode.id, field: edge.destination.field },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const ipAdapterId of layer.ipAdapterIds) {
|
||||||
|
const ipAdapter = selectAllIPAdapters(state.controlAdapters)
|
||||||
|
.filter(({ id, model, controlImage, isEnabled }) => {
|
||||||
|
const hasModel = Boolean(model);
|
||||||
|
const doesBaseMatch = model?.base === state.generation.model?.base;
|
||||||
|
const hasControlImage = controlImage;
|
||||||
|
const isRegional = layers.some((l) => l.ipAdapterIds.includes(id));
|
||||||
|
return isEnabled && hasModel && doesBaseMatch && hasControlImage && isRegional;
|
||||||
|
})
|
||||||
|
.find((ca) => ca.id === ipAdapterId);
|
||||||
|
|
||||||
|
if (!ipAdapter?.model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const { id, weight, model, clipVisionModel, method, beginStepPct, endStepPct, controlImage } = ipAdapter;
|
||||||
|
|
||||||
|
assert(controlImage, 'IP Adapter image is required');
|
||||||
|
|
||||||
|
const ipAdapterNode: IPAdapterInvocation = {
|
||||||
|
id: `ip_adapter_${id}`,
|
||||||
|
type: 'ip_adapter',
|
||||||
|
is_intermediate: true,
|
||||||
|
weight: weight,
|
||||||
|
method: method,
|
||||||
|
ip_adapter_model: model,
|
||||||
|
clip_vision_model: clipVisionModel,
|
||||||
|
begin_step_percent: beginStepPct,
|
||||||
|
end_step_percent: endStepPct,
|
||||||
|
image: {
|
||||||
|
image_name: controlImage,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
graph.nodes[ipAdapterNode.id] = ipAdapterNode;
|
||||||
|
|
||||||
|
// Connect the mask to the conditioning
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: maskToTensorNode.id, field: 'mask' },
|
||||||
|
destination: { node_id: ipAdapterNode.id, field: 'mask' },
|
||||||
|
});
|
||||||
|
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
|
||||||
|
destination: {
|
||||||
|
node_id: IP_ADAPTER_COLLECT,
|
||||||
|
field: 'item',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
@ -9,6 +9,7 @@ import {
|
|||||||
CANVAS_TEXT_TO_IMAGE_GRAPH,
|
CANVAS_TEXT_TO_IMAGE_GRAPH,
|
||||||
IMAGE_TO_IMAGE_GRAPH,
|
IMAGE_TO_IMAGE_GRAPH,
|
||||||
IMAGE_TO_LATENTS,
|
IMAGE_TO_LATENTS,
|
||||||
|
INPAINT_CREATE_MASK,
|
||||||
INPAINT_IMAGE,
|
INPAINT_IMAGE,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
MAIN_MODEL_LOADER,
|
MAIN_MODEL_LOADER,
|
||||||
@ -145,6 +146,16 @@ export const addVAEToGraph = async (
|
|||||||
field: 'vae',
|
field: 'vae',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: isSeamlessEnabled ? SEAMLESS : isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: INPAINT_CREATE_MASK,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
|
@ -133,6 +133,8 @@ export const buildCanvasInpaintGraph = async (
|
|||||||
coherence_mode: canvasCoherenceMode,
|
coherence_mode: canvasCoherenceMode,
|
||||||
minimum_denoise: canvasCoherenceMinDenoise,
|
minimum_denoise: canvasCoherenceMinDenoise,
|
||||||
edge_radius: canvasCoherenceEdgeSize,
|
edge_radius: canvasCoherenceEdgeSize,
|
||||||
|
tiled: false,
|
||||||
|
fp32: fp32,
|
||||||
},
|
},
|
||||||
[DENOISE_LATENTS]: {
|
[DENOISE_LATENTS]: {
|
||||||
type: 'denoise_latents',
|
type: 'denoise_latents',
|
||||||
@ -182,6 +184,16 @@ export const buildCanvasInpaintGraph = async (
|
|||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: modelLoaderNodeId,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: INPAINT_CREATE_MASK,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
},
|
||||||
// Connect CLIP Skip to Conditioning
|
// Connect CLIP Skip to Conditioning
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
@ -331,6 +343,16 @@ export const buildCanvasInpaintGraph = async (
|
|||||||
field: 'mask',
|
field: 'mask',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: INPAINT_IMAGE_RESIZE_UP,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: INPAINT_CREATE_MASK,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
},
|
||||||
// Resize Down
|
// Resize Down
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
|
@ -157,6 +157,8 @@ export const buildCanvasOutpaintGraph = async (
|
|||||||
coherence_mode: canvasCoherenceMode,
|
coherence_mode: canvasCoherenceMode,
|
||||||
edge_radius: canvasCoherenceEdgeSize,
|
edge_radius: canvasCoherenceEdgeSize,
|
||||||
minimum_denoise: canvasCoherenceMinDenoise,
|
minimum_denoise: canvasCoherenceMinDenoise,
|
||||||
|
tiled: false,
|
||||||
|
fp32: fp32,
|
||||||
},
|
},
|
||||||
[DENOISE_LATENTS]: {
|
[DENOISE_LATENTS]: {
|
||||||
type: 'denoise_latents',
|
type: 'denoise_latents',
|
||||||
@ -207,6 +209,16 @@ export const buildCanvasOutpaintGraph = async (
|
|||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: modelLoaderNodeId,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: INPAINT_CREATE_MASK,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
},
|
||||||
// Connect CLIP Skip to Conditioning
|
// Connect CLIP Skip to Conditioning
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
@ -453,6 +465,16 @@ export const buildCanvasOutpaintGraph = async (
|
|||||||
field: 'image',
|
field: 'image',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: INPAINT_IMAGE_RESIZE_UP,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: INPAINT_CREATE_MASK,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
},
|
||||||
// Resize Results Down
|
// Resize Results Down
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
|
@ -135,6 +135,8 @@ export const buildCanvasSDXLInpaintGraph = async (
|
|||||||
coherence_mode: canvasCoherenceMode,
|
coherence_mode: canvasCoherenceMode,
|
||||||
minimum_denoise: refinerModel ? Math.max(0.2, canvasCoherenceMinDenoise) : canvasCoherenceMinDenoise,
|
minimum_denoise: refinerModel ? Math.max(0.2, canvasCoherenceMinDenoise) : canvasCoherenceMinDenoise,
|
||||||
edge_radius: canvasCoherenceEdgeSize,
|
edge_radius: canvasCoherenceEdgeSize,
|
||||||
|
tiled: false,
|
||||||
|
fp32: fp32,
|
||||||
},
|
},
|
||||||
[SDXL_DENOISE_LATENTS]: {
|
[SDXL_DENOISE_LATENTS]: {
|
||||||
type: 'denoise_latents',
|
type: 'denoise_latents',
|
||||||
@ -214,6 +216,16 @@ export const buildCanvasSDXLInpaintGraph = async (
|
|||||||
field: 'clip2',
|
field: 'clip2',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: modelLoaderNodeId,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: INPAINT_CREATE_MASK,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
},
|
||||||
// Connect Everything To Inpaint Node
|
// Connect Everything To Inpaint Node
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
@ -342,6 +354,16 @@ export const buildCanvasSDXLInpaintGraph = async (
|
|||||||
field: 'mask',
|
field: 'mask',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: INPAINT_IMAGE_RESIZE_UP,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: INPAINT_CREATE_MASK,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
},
|
||||||
// Resize Down
|
// Resize Down
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
|
@ -157,6 +157,8 @@ export const buildCanvasSDXLOutpaintGraph = async (
|
|||||||
coherence_mode: canvasCoherenceMode,
|
coherence_mode: canvasCoherenceMode,
|
||||||
edge_radius: canvasCoherenceEdgeSize,
|
edge_radius: canvasCoherenceEdgeSize,
|
||||||
minimum_denoise: refinerModel ? Math.max(0.2, canvasCoherenceMinDenoise) : canvasCoherenceMinDenoise,
|
minimum_denoise: refinerModel ? Math.max(0.2, canvasCoherenceMinDenoise) : canvasCoherenceMinDenoise,
|
||||||
|
tiled: false,
|
||||||
|
fp32: fp32,
|
||||||
},
|
},
|
||||||
[SDXL_DENOISE_LATENTS]: {
|
[SDXL_DENOISE_LATENTS]: {
|
||||||
type: 'denoise_latents',
|
type: 'denoise_latents',
|
||||||
@ -237,6 +239,16 @@ export const buildCanvasSDXLOutpaintGraph = async (
|
|||||||
field: 'clip2',
|
field: 'clip2',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: modelLoaderNodeId,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: INPAINT_CREATE_MASK,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
},
|
||||||
// Connect Infill Result To Inpaint Image
|
// Connect Infill Result To Inpaint Image
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
@ -451,6 +463,16 @@ export const buildCanvasSDXLOutpaintGraph = async (
|
|||||||
field: 'image',
|
field: 'image',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: INPAINT_IMAGE_RESIZE_UP,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: INPAINT_CREATE_MASK,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
},
|
||||||
// Take combined mask and resize
|
// Take combined mask and resize
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
|
import { addRegionalPromptsToGraph } from 'features/nodes/util/graph/addRegionalPromptsToGraph';
|
||||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||||
|
|
||||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||||
@ -273,6 +274,8 @@ export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise
|
|||||||
|
|
||||||
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
|
|
||||||
|
await addRegionalPromptsToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
|
|
||||||
// NSFW & watermark - must be last thing added to graph
|
// NSFW & watermark - must be last thing added to graph
|
||||||
if (state.system.shouldUseNSFWChecker) {
|
if (state.system.shouldUseNSFWChecker) {
|
||||||
// must add before watermarker!
|
// must add before watermarker!
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
|
import { addRegionalPromptsToGraph } from 'features/nodes/util/graph/addRegionalPromptsToGraph';
|
||||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||||
|
|
||||||
@ -255,6 +256,8 @@ export const buildLinearTextToImageGraph = async (state: RootState): Promise<Non
|
|||||||
|
|
||||||
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||||
|
|
||||||
|
await addRegionalPromptsToGraph(state, graph, DENOISE_LATENTS);
|
||||||
|
|
||||||
// High resolution fix.
|
// High resolution fix.
|
||||||
if (state.hrf.hrfEnabled) {
|
if (state.hrf.hrfEnabled) {
|
||||||
addHrfToGraph(state, graph);
|
addHrfToGraph(state, graph);
|
||||||
|
@ -46,6 +46,13 @@ export const SDXL_REFINER_DENOISE_LATENTS = 'sdxl_refiner_denoise_latents';
|
|||||||
export const SDXL_REFINER_INPAINT_CREATE_MASK = 'refiner_inpaint_create_mask';
|
export const SDXL_REFINER_INPAINT_CREATE_MASK = 'refiner_inpaint_create_mask';
|
||||||
export const SEAMLESS = 'seamless';
|
export const SEAMLESS = 'seamless';
|
||||||
export const SDXL_REFINER_SEAMLESS = 'refiner_seamless';
|
export const SDXL_REFINER_SEAMLESS = 'refiner_seamless';
|
||||||
|
export const PROMPT_REGION_MASK_TO_TENSOR_PREFIX = 'prompt_region_mask_to_tensor';
|
||||||
|
export const PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX = 'prompt_region_invert_tensor_mask';
|
||||||
|
export const PROMPT_REGION_POSITIVE_COND_PREFIX = 'prompt_region_positive_cond';
|
||||||
|
export const PROMPT_REGION_NEGATIVE_COND_PREFIX = 'prompt_region_negative_cond';
|
||||||
|
export const PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX = 'prompt_region_positive_cond_inverted';
|
||||||
|
export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect';
|
||||||
|
export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect';
|
||||||
|
|
||||||
// friendly graph ids
|
// friendly graph ids
|
||||||
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';
|
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';
|
||||||
|
@ -0,0 +1,13 @@
|
|||||||
|
import { Flex } from '@invoke-ai/ui-library';
|
||||||
|
import { StageComponent } from 'features/regionalPrompts/components/StageComponent';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
|
export const AspectRatioCanvasPreview = memo(() => {
|
||||||
|
return (
|
||||||
|
<Flex w="full" h="full" alignItems="center" justifyContent="center" position="relative">
|
||||||
|
<StageComponent asPreview />
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
AspectRatioCanvasPreview.displayName = 'AspectRatioCanvasPreview';
|
@ -2,7 +2,7 @@ import { useSize } from '@chakra-ui/react-use-size';
|
|||||||
import { Flex, Icon } from '@invoke-ai/ui-library';
|
import { Flex, Icon } from '@invoke-ai/ui-library';
|
||||||
import { useImageSizeContext } from 'features/parameters/components/ImageSize/ImageSizeContext';
|
import { useImageSizeContext } from 'features/parameters/components/ImageSize/ImageSizeContext';
|
||||||
import { AnimatePresence, motion } from 'framer-motion';
|
import { AnimatePresence, motion } from 'framer-motion';
|
||||||
import { useMemo, useRef } from 'react';
|
import { memo, useMemo, useRef } from 'react';
|
||||||
import { PiFrameCorners } from 'react-icons/pi';
|
import { PiFrameCorners } from 'react-icons/pi';
|
||||||
|
|
||||||
import {
|
import {
|
||||||
@ -15,7 +15,7 @@ import {
|
|||||||
MOTION_ICON_INITIAL,
|
MOTION_ICON_INITIAL,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
|
|
||||||
export const AspectRatioPreview = () => {
|
export const AspectRatioIconPreview = memo(() => {
|
||||||
const ctx = useImageSizeContext();
|
const ctx = useImageSizeContext();
|
||||||
const containerRef = useRef<HTMLDivElement>(null);
|
const containerRef = useRef<HTMLDivElement>(null);
|
||||||
const containerSize = useSize(containerRef);
|
const containerSize = useSize(containerRef);
|
||||||
@ -70,4 +70,6 @@ export const AspectRatioPreview = () => {
|
|||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
AspectRatioIconPreview.displayName = 'AspectRatioIconPreview';
|
@ -1,6 +1,5 @@
|
|||||||
import type { FormLabelProps } from '@invoke-ai/ui-library';
|
import type { FormLabelProps } from '@invoke-ai/ui-library';
|
||||||
import { Flex, FormControlGroup } from '@invoke-ai/ui-library';
|
import { Flex, FormControlGroup } from '@invoke-ai/ui-library';
|
||||||
import { AspectRatioPreview } from 'features/parameters/components/ImageSize/AspectRatioPreview';
|
|
||||||
import { AspectRatioSelect } from 'features/parameters/components/ImageSize/AspectRatioSelect';
|
import { AspectRatioSelect } from 'features/parameters/components/ImageSize/AspectRatioSelect';
|
||||||
import type { ImageSizeContextInnerValue } from 'features/parameters/components/ImageSize/ImageSizeContext';
|
import type { ImageSizeContextInnerValue } from 'features/parameters/components/ImageSize/ImageSizeContext';
|
||||||
import { ImageSizeContext } from 'features/parameters/components/ImageSize/ImageSizeContext';
|
import { ImageSizeContext } from 'features/parameters/components/ImageSize/ImageSizeContext';
|
||||||
@ -13,10 +12,11 @@ import { memo } from 'react';
|
|||||||
type ImageSizeProps = ImageSizeContextInnerValue & {
|
type ImageSizeProps = ImageSizeContextInnerValue & {
|
||||||
widthComponent: ReactNode;
|
widthComponent: ReactNode;
|
||||||
heightComponent: ReactNode;
|
heightComponent: ReactNode;
|
||||||
|
previewComponent: ReactNode;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const ImageSize = memo((props: ImageSizeProps) => {
|
export const ImageSize = memo((props: ImageSizeProps) => {
|
||||||
const { widthComponent, heightComponent, ...ctx } = props;
|
const { widthComponent, heightComponent, previewComponent, ...ctx } = props;
|
||||||
return (
|
return (
|
||||||
<ImageSizeContext.Provider value={ctx}>
|
<ImageSizeContext.Provider value={ctx}>
|
||||||
<Flex gap={4} alignItems="center">
|
<Flex gap={4} alignItems="center">
|
||||||
@ -33,7 +33,7 @@ export const ImageSize = memo((props: ImageSizeProps) => {
|
|||||||
</FormControlGroup>
|
</FormControlGroup>
|
||||||
</Flex>
|
</Flex>
|
||||||
<Flex w="108px" h="108px" flexShrink={0} flexGrow={0}>
|
<Flex w="108px" h="108px" flexShrink={0} flexGrow={0}>
|
||||||
<AspectRatioPreview />
|
{previewComponent}
|
||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
</ImageSizeContext.Provider>
|
</ImageSizeContext.Provider>
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import type { ComboboxOption } from '@invoke-ai/ui-library';
|
import type { ComboboxOption } from '@invoke-ai/ui-library';
|
||||||
|
|
||||||
import type { AspectRatioID, AspectRatioState } from './types';
|
import type { AspectRatioID, AspectRatioState } from './types';
|
||||||
|
|
||||||
// When the aspect ratio is between these two values, we show the icon (experimentally determined)
|
// When the aspect ratio is between these two values, we show the icon (experimentally determined)
|
||||||
export const ICON_LOW_CUTOFF = 0.23;
|
export const ICON_LOW_CUTOFF = 0.23;
|
||||||
export const ICON_HIGH_CUTOFF = 1 / ICON_LOW_CUTOFF;
|
export const ICON_HIGH_CUTOFF = 1 / ICON_LOW_CUTOFF;
|
||||||
@ -25,7 +24,6 @@ export const ICON_CONTAINER_STYLES = {
|
|||||||
alignItems: 'center',
|
alignItems: 'center',
|
||||||
justifyContent: 'center',
|
justifyContent: 'center',
|
||||||
};
|
};
|
||||||
|
|
||||||
export const ASPECT_RATIO_OPTIONS: ComboboxOption[] = [
|
export const ASPECT_RATIO_OPTIONS: ComboboxOption[] = [
|
||||||
{ label: 'Free' as const, value: 'Free' },
|
{ label: 'Free' as const, value: 'Free' },
|
||||||
{ label: '16:9' as const, value: '16:9' },
|
{ label: '16:9' as const, value: '16:9' },
|
||||||
|
@ -196,3 +196,8 @@ const zLoRAWeight = z.number();
|
|||||||
type ParameterLoRAWeight = z.infer<typeof zLoRAWeight>;
|
type ParameterLoRAWeight = z.infer<typeof zLoRAWeight>;
|
||||||
export const isParameterLoRAWeight = (val: unknown): val is ParameterLoRAWeight => zLoRAWeight.safeParse(val).success;
|
export const isParameterLoRAWeight = (val: unknown): val is ParameterLoRAWeight => zLoRAWeight.safeParse(val).success;
|
||||||
// #endregion
|
// #endregion
|
||||||
|
|
||||||
|
// #region Regional Prompts AutoNegative
|
||||||
|
const zAutoNegative = z.enum(['off', 'invert']);
|
||||||
|
export type ParameterAutoNegative = z.infer<typeof zAutoNegative>;
|
||||||
|
// #endregion
|
||||||
|
@ -3,6 +3,7 @@ import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataView
|
|||||||
import { useCancelBatch } from 'features/queue/hooks/useCancelBatch';
|
import { useCancelBatch } from 'features/queue/hooks/useCancelBatch';
|
||||||
import { useCancelQueueItem } from 'features/queue/hooks/useCancelQueueItem';
|
import { useCancelQueueItem } from 'features/queue/hooks/useCancelQueueItem';
|
||||||
import { getSecondsFromTimestamps } from 'features/queue/util/getSecondsFromTimestamps';
|
import { getSecondsFromTimestamps } from 'features/queue/util/getSecondsFromTimestamps';
|
||||||
|
import { get } from 'lodash-es';
|
||||||
import type { ReactNode } from 'react';
|
import type { ReactNode } from 'react';
|
||||||
import { memo, useMemo } from 'react';
|
import { memo, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -92,7 +93,15 @@ const QueueItemComponent = ({ queueItemDTO }: Props) => {
|
|||||||
</Flex>
|
</Flex>
|
||||||
)}
|
)}
|
||||||
<Flex layerStyle="second" h={512} w="full" borderRadius="base" alignItems="center" justifyContent="center">
|
<Flex layerStyle="second" h={512} w="full" borderRadius="base" alignItems="center" justifyContent="center">
|
||||||
{queueItem ? <DataViewer label="Queue Item" data={queueItem} /> : <Spinner opacity={0.5} />}
|
{queueItem ? (
|
||||||
|
<DataViewer
|
||||||
|
label="Queue Item"
|
||||||
|
data={queueItem}
|
||||||
|
extraCopyActions={[{ label: 'Graph', getData: (data) => get(data, 'session.graph') }]}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<Spinner opacity={0.5} />
|
||||||
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
|
@ -0,0 +1,22 @@
|
|||||||
|
import { Button } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { layerAdded } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { PiPlusBold } from 'react-icons/pi';
|
||||||
|
|
||||||
|
export const AddLayerButton = memo(() => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const onClick = useCallback(() => {
|
||||||
|
dispatch(layerAdded('vector_mask_layer'));
|
||||||
|
}, [dispatch]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Button onClick={onClick} leftIcon={<PiPlusBold />} variant="ghost">
|
||||||
|
{t('regionalPrompts.addLayer')}
|
||||||
|
</Button>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
AddLayerButton.displayName = 'AddLayerButton';
|
@ -0,0 +1,70 @@
|
|||||||
|
import { Button, Flex } from '@invoke-ai/ui-library';
|
||||||
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import {
|
||||||
|
isVectorMaskLayer,
|
||||||
|
maskLayerIPAdapterAdded,
|
||||||
|
maskLayerNegativePromptChanged,
|
||||||
|
maskLayerPositivePromptChanged,
|
||||||
|
selectRegionalPromptsSlice,
|
||||||
|
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import { useCallback, useMemo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { PiPlusBold } from 'react-icons/pi';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
type AddPromptButtonProps = {
|
||||||
|
layerId: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const AddPromptButtons = ({ layerId }: AddPromptButtonProps) => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const selectValidActions = useMemo(
|
||||||
|
() =>
|
||||||
|
createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
|
||||||
|
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
|
||||||
|
assert(isVectorMaskLayer(layer), `Layer ${layerId} not found or not an RP layer`);
|
||||||
|
return {
|
||||||
|
canAddPositivePrompt: layer.positivePrompt === null,
|
||||||
|
canAddNegativePrompt: layer.negativePrompt === null,
|
||||||
|
};
|
||||||
|
}),
|
||||||
|
[layerId]
|
||||||
|
);
|
||||||
|
const validActions = useAppSelector(selectValidActions);
|
||||||
|
const addPositivePrompt = useCallback(() => {
|
||||||
|
dispatch(maskLayerPositivePromptChanged({ layerId, prompt: '' }));
|
||||||
|
}, [dispatch, layerId]);
|
||||||
|
const addNegativePrompt = useCallback(() => {
|
||||||
|
dispatch(maskLayerNegativePromptChanged({ layerId, prompt: '' }));
|
||||||
|
}, [dispatch, layerId]);
|
||||||
|
const addIPAdapter = useCallback(() => {
|
||||||
|
dispatch(maskLayerIPAdapterAdded(layerId));
|
||||||
|
}, [dispatch, layerId]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex w="full" p={2} justifyContent="space-between">
|
||||||
|
<Button
|
||||||
|
size="sm"
|
||||||
|
variant="ghost"
|
||||||
|
leftIcon={<PiPlusBold />}
|
||||||
|
onClick={addPositivePrompt}
|
||||||
|
isDisabled={!validActions.canAddPositivePrompt}
|
||||||
|
>
|
||||||
|
{t('common.positivePrompt')}
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
size="sm"
|
||||||
|
variant="ghost"
|
||||||
|
leftIcon={<PiPlusBold />}
|
||||||
|
onClick={addNegativePrompt}
|
||||||
|
isDisabled={!validActions.canAddNegativePrompt}
|
||||||
|
>
|
||||||
|
{t('common.negativePrompt')}
|
||||||
|
</Button>
|
||||||
|
<Button size="sm" variant="ghost" leftIcon={<PiPlusBold />} onClick={addIPAdapter}>
|
||||||
|
{t('common.ipAdapter')}
|
||||||
|
</Button>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
@ -0,0 +1,63 @@
|
|||||||
|
import {
|
||||||
|
CompositeNumberInput,
|
||||||
|
CompositeSlider,
|
||||||
|
FormControl,
|
||||||
|
FormLabel,
|
||||||
|
Popover,
|
||||||
|
PopoverArrow,
|
||||||
|
PopoverBody,
|
||||||
|
PopoverContent,
|
||||||
|
PopoverTrigger,
|
||||||
|
} from '@invoke-ai/ui-library';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { brushSizeChanged, initialRegionalPromptsState } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
const marks = [0, 100, 200, 300];
|
||||||
|
const formatPx = (v: number | string) => `${v} px`;
|
||||||
|
|
||||||
|
export const BrushSize = memo(() => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const brushSize = useAppSelector((s) => s.regionalPrompts.present.brushSize);
|
||||||
|
const onChange = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
dispatch(brushSizeChanged(v));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
return (
|
||||||
|
<FormControl w="min-content">
|
||||||
|
<FormLabel m={0}>{t('regionalPrompts.brushSize')}</FormLabel>
|
||||||
|
<Popover isLazy>
|
||||||
|
<PopoverTrigger>
|
||||||
|
<CompositeNumberInput
|
||||||
|
min={1}
|
||||||
|
max={600}
|
||||||
|
defaultValue={initialRegionalPromptsState.brushSize}
|
||||||
|
value={brushSize}
|
||||||
|
onChange={onChange}
|
||||||
|
w={24}
|
||||||
|
format={formatPx}
|
||||||
|
/>
|
||||||
|
</PopoverTrigger>
|
||||||
|
<PopoverContent w={200} py={2} px={4}>
|
||||||
|
<PopoverArrow />
|
||||||
|
<PopoverBody>
|
||||||
|
<CompositeSlider
|
||||||
|
min={1}
|
||||||
|
max={300}
|
||||||
|
defaultValue={initialRegionalPromptsState.brushSize}
|
||||||
|
value={brushSize}
|
||||||
|
onChange={onChange}
|
||||||
|
marks={marks}
|
||||||
|
/>
|
||||||
|
</PopoverBody>
|
||||||
|
</PopoverContent>
|
||||||
|
</Popover>
|
||||||
|
</FormControl>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
BrushSize.displayName = 'BrushSize';
|
@ -0,0 +1,22 @@
|
|||||||
|
import { Button } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { allLayersDeleted } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||||
|
|
||||||
|
export const DeleteAllLayersButton = memo(() => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const onClick = useCallback(() => {
|
||||||
|
dispatch(allLayersDeleted());
|
||||||
|
}, [dispatch]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Button onClick={onClick} leftIcon={<PiTrashSimpleBold />} variant="ghost" colorScheme="error">
|
||||||
|
{t('regionalPrompts.deleteAll')}
|
||||||
|
</Button>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
DeleteAllLayersButton.displayName = 'DeleteAllLayersButton';
|
@ -0,0 +1,70 @@
|
|||||||
|
import {
|
||||||
|
CompositeNumberInput,
|
||||||
|
CompositeSlider,
|
||||||
|
FormControl,
|
||||||
|
FormLabel,
|
||||||
|
Popover,
|
||||||
|
PopoverArrow,
|
||||||
|
PopoverBody,
|
||||||
|
PopoverContent,
|
||||||
|
PopoverTrigger,
|
||||||
|
} from '@invoke-ai/ui-library';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import {
|
||||||
|
globalMaskLayerOpacityChanged,
|
||||||
|
initialRegionalPromptsState,
|
||||||
|
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
const marks = [0, 25, 50, 75, 100];
|
||||||
|
const formatPct = (v: number | string) => `${v} %`;
|
||||||
|
|
||||||
|
export const GlobalMaskLayerOpacity = memo(() => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const globalMaskLayerOpacity = useAppSelector((s) =>
|
||||||
|
Math.round(s.regionalPrompts.present.globalMaskLayerOpacity * 100)
|
||||||
|
);
|
||||||
|
const onChange = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
dispatch(globalMaskLayerOpacityChanged(v / 100));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
return (
|
||||||
|
<FormControl w="min-content">
|
||||||
|
<FormLabel m={0}>{t('regionalPrompts.globalMaskOpacity')}</FormLabel>
|
||||||
|
<Popover isLazy>
|
||||||
|
<PopoverTrigger>
|
||||||
|
<CompositeNumberInput
|
||||||
|
min={0}
|
||||||
|
max={100}
|
||||||
|
step={1}
|
||||||
|
value={globalMaskLayerOpacity}
|
||||||
|
defaultValue={initialRegionalPromptsState.globalMaskLayerOpacity * 100}
|
||||||
|
onChange={onChange}
|
||||||
|
w={24}
|
||||||
|
format={formatPct}
|
||||||
|
/>
|
||||||
|
</PopoverTrigger>
|
||||||
|
<PopoverContent w={200} py={2} px={4}>
|
||||||
|
<PopoverArrow />
|
||||||
|
<PopoverBody>
|
||||||
|
<CompositeSlider
|
||||||
|
min={0}
|
||||||
|
max={100}
|
||||||
|
step={1}
|
||||||
|
value={globalMaskLayerOpacity}
|
||||||
|
defaultValue={initialRegionalPromptsState.globalMaskLayerOpacity * 100}
|
||||||
|
onChange={onChange}
|
||||||
|
marks={marks}
|
||||||
|
/>
|
||||||
|
</PopoverBody>
|
||||||
|
</PopoverContent>
|
||||||
|
</Popover>
|
||||||
|
</FormControl>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
GlobalMaskLayerOpacity.displayName = 'GlobalMaskLayerOpacity';
|
@ -0,0 +1,51 @@
|
|||||||
|
import { Checkbox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import {
|
||||||
|
isVectorMaskLayer,
|
||||||
|
maskLayerAutoNegativeChanged,
|
||||||
|
selectRegionalPromptsSlice,
|
||||||
|
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import type { ChangeEvent } from 'react';
|
||||||
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
layerId: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
const useAutoNegative = (layerId: string) => {
|
||||||
|
const selectAutoNegative = useMemo(
|
||||||
|
() =>
|
||||||
|
createSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
|
||||||
|
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
|
||||||
|
assert(isVectorMaskLayer(layer), `Layer ${layerId} not found or not an RP layer`);
|
||||||
|
return layer.autoNegative;
|
||||||
|
}),
|
||||||
|
[layerId]
|
||||||
|
);
|
||||||
|
const autoNegative = useAppSelector(selectAutoNegative);
|
||||||
|
return autoNegative;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const RPLayerAutoNegativeCheckbox = memo(({ layerId }: Props) => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const autoNegative = useAutoNegative(layerId);
|
||||||
|
const onChange = useCallback(
|
||||||
|
(e: ChangeEvent<HTMLInputElement>) => {
|
||||||
|
dispatch(maskLayerAutoNegativeChanged({ layerId, autoNegative: e.target.checked ? 'invert' : 'off' }));
|
||||||
|
},
|
||||||
|
[dispatch, layerId]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<FormControl gap={2}>
|
||||||
|
<FormLabel m={0}>{t('regionalPrompts.autoNegative')}</FormLabel>
|
||||||
|
<Checkbox size="md" isChecked={autoNegative === 'invert'} onChange={onChange} />
|
||||||
|
</FormControl>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
RPLayerAutoNegativeCheckbox.displayName = 'RPLayerAutoNegativeCheckbox';
|
@ -0,0 +1,67 @@
|
|||||||
|
import { Flex, Popover, PopoverBody, PopoverContent, PopoverTrigger, Tooltip } from '@invoke-ai/ui-library';
|
||||||
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import RgbColorPicker from 'common/components/RgbColorPicker';
|
||||||
|
import { rgbColorToString } from 'features/canvas/util/colorToString';
|
||||||
|
import {
|
||||||
|
isVectorMaskLayer,
|
||||||
|
maskLayerPreviewColorChanged,
|
||||||
|
selectRegionalPromptsSlice,
|
||||||
|
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
|
import type { RgbColor } from 'react-colorful';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
layerId: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const RPLayerColorPicker = memo(({ layerId }: Props) => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const selectColor = useMemo(
|
||||||
|
() =>
|
||||||
|
createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
|
||||||
|
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
|
||||||
|
assert(isVectorMaskLayer(layer), `Layer ${layerId} not found or not an vector mask layer`);
|
||||||
|
return layer.previewColor;
|
||||||
|
}),
|
||||||
|
[layerId]
|
||||||
|
);
|
||||||
|
const color = useAppSelector(selectColor);
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const onColorChange = useCallback(
|
||||||
|
(color: RgbColor) => {
|
||||||
|
dispatch(maskLayerPreviewColorChanged({ layerId, color }));
|
||||||
|
},
|
||||||
|
[dispatch, layerId]
|
||||||
|
);
|
||||||
|
return (
|
||||||
|
<Popover isLazy>
|
||||||
|
<PopoverTrigger>
|
||||||
|
<span>
|
||||||
|
<Tooltip label={t('regionalPrompts.maskPreviewColor')}>
|
||||||
|
<Flex
|
||||||
|
as="button"
|
||||||
|
aria-label={t('regionalPrompts.maskPreviewColor')}
|
||||||
|
borderRadius="base"
|
||||||
|
borderWidth={1}
|
||||||
|
bg={rgbColorToString(color)}
|
||||||
|
w={8}
|
||||||
|
h={8}
|
||||||
|
cursor="pointer"
|
||||||
|
tabIndex={-1}
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
</span>
|
||||||
|
</PopoverTrigger>
|
||||||
|
<PopoverContent>
|
||||||
|
<PopoverBody minH={64}>
|
||||||
|
<RgbColorPicker color={color} onChange={onColorChange} withNumberInput />
|
||||||
|
</PopoverBody>
|
||||||
|
</PopoverContent>
|
||||||
|
</Popover>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
RPLayerColorPicker.displayName = 'RPLayerColorPicker';
|
@ -0,0 +1,28 @@
|
|||||||
|
import { IconButton } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { layerDeleted } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||||
|
|
||||||
|
type Props = { layerId: string };
|
||||||
|
|
||||||
|
export const RPLayerDeleteButton = memo(({ layerId }: Props) => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const deleteLayer = useCallback(() => {
|
||||||
|
dispatch(layerDeleted(layerId));
|
||||||
|
}, [dispatch, layerId]);
|
||||||
|
return (
|
||||||
|
<IconButton
|
||||||
|
size="sm"
|
||||||
|
colorScheme="error"
|
||||||
|
aria-label={t('common.delete')}
|
||||||
|
tooltip={t('common.delete')}
|
||||||
|
icon={<PiTrashSimpleBold />}
|
||||||
|
onClick={deleteLayer}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
RPLayerDeleteButton.displayName = 'RPLayerDeleteButton';
|
@ -0,0 +1,34 @@
|
|||||||
|
import { Flex } from '@invoke-ai/ui-library';
|
||||||
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import ControlAdapterConfig from 'features/controlAdapters/components/ControlAdapterConfig';
|
||||||
|
import { selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import { memo, useMemo } from 'react';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
layerId: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const RPLayerIPAdapterList = memo(({ layerId }: Props) => {
|
||||||
|
const selectIPAdapterIds = useMemo(
|
||||||
|
() =>
|
||||||
|
createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
|
||||||
|
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
|
||||||
|
assert(layer, `Layer ${layerId} not found`);
|
||||||
|
return layer.ipAdapterIds;
|
||||||
|
}),
|
||||||
|
[layerId]
|
||||||
|
);
|
||||||
|
const ipAdapterIds = useAppSelector(selectIPAdapterIds);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex w="full" flexDir="column" gap={2}>
|
||||||
|
{ipAdapterIds.map((id, index) => (
|
||||||
|
<ControlAdapterConfig key={id} id={id} number={index + 1} />
|
||||||
|
))}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
RPLayerIPAdapterList.displayName = 'RPLayerIPAdapterList';
|
@ -0,0 +1,87 @@
|
|||||||
|
import { Badge, Flex, Spacer } from '@invoke-ai/ui-library';
|
||||||
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { rgbColorToString } from 'features/canvas/util/colorToString';
|
||||||
|
import { RPLayerColorPicker } from 'features/regionalPrompts/components/RPLayerColorPicker';
|
||||||
|
import { RPLayerDeleteButton } from 'features/regionalPrompts/components/RPLayerDeleteButton';
|
||||||
|
import { RPLayerIPAdapterList } from 'features/regionalPrompts/components/RPLayerIPAdapterList';
|
||||||
|
import { RPLayerMenu } from 'features/regionalPrompts/components/RPLayerMenu';
|
||||||
|
import { RPLayerNegativePrompt } from 'features/regionalPrompts/components/RPLayerNegativePrompt';
|
||||||
|
import { RPLayerPositivePrompt } from 'features/regionalPrompts/components/RPLayerPositivePrompt';
|
||||||
|
import RPLayerSettingsPopover from 'features/regionalPrompts/components/RPLayerSettingsPopover';
|
||||||
|
import { RPLayerVisibilityToggle } from 'features/regionalPrompts/components/RPLayerVisibilityToggle';
|
||||||
|
import {
|
||||||
|
isVectorMaskLayer,
|
||||||
|
layerSelected,
|
||||||
|
selectRegionalPromptsSlice,
|
||||||
|
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
|
import { AddPromptButtons } from './AddPromptButtons';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
layerId: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const RPLayerListItem = memo(({ layerId }: Props) => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const selector = useMemo(
|
||||||
|
() =>
|
||||||
|
createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
|
||||||
|
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
|
||||||
|
assert(isVectorMaskLayer(layer), `Layer ${layerId} not found or not an RP layer`);
|
||||||
|
return {
|
||||||
|
color: rgbColorToString(layer.previewColor),
|
||||||
|
hasPositivePrompt: layer.positivePrompt !== null,
|
||||||
|
hasNegativePrompt: layer.negativePrompt !== null,
|
||||||
|
hasIPAdapters: layer.ipAdapterIds.length > 0,
|
||||||
|
isSelected: layerId === regionalPrompts.present.selectedLayerId,
|
||||||
|
autoNegative: layer.autoNegative,
|
||||||
|
};
|
||||||
|
}),
|
||||||
|
[layerId]
|
||||||
|
);
|
||||||
|
const { autoNegative, color, hasPositivePrompt, hasNegativePrompt, hasIPAdapters, isSelected } =
|
||||||
|
useAppSelector(selector);
|
||||||
|
const onClickCapture = useCallback(() => {
|
||||||
|
// Must be capture so that the layer is selected before deleting/resetting/etc
|
||||||
|
dispatch(layerSelected(layerId));
|
||||||
|
}, [dispatch, layerId]);
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
gap={2}
|
||||||
|
onClickCapture={onClickCapture}
|
||||||
|
bg={isSelected ? color : 'base.800'}
|
||||||
|
ps={2}
|
||||||
|
borderRadius="base"
|
||||||
|
pe="1px"
|
||||||
|
py="1px"
|
||||||
|
cursor="pointer"
|
||||||
|
>
|
||||||
|
<Flex flexDir="column" gap={2} w="full" bg="base.850" p={2} borderRadius="base">
|
||||||
|
<Flex gap={3} alignItems="center">
|
||||||
|
<RPLayerVisibilityToggle layerId={layerId} />
|
||||||
|
<RPLayerColorPicker layerId={layerId} />
|
||||||
|
<Spacer />
|
||||||
|
{autoNegative === 'invert' && (
|
||||||
|
<Badge color="base.300" bg="transparent" borderWidth={1}>
|
||||||
|
{t('regionalPrompts.autoNegative')}
|
||||||
|
</Badge>
|
||||||
|
)}
|
||||||
|
<RPLayerDeleteButton layerId={layerId} />
|
||||||
|
<RPLayerSettingsPopover layerId={layerId} />
|
||||||
|
<RPLayerMenu layerId={layerId} />
|
||||||
|
</Flex>
|
||||||
|
{!hasPositivePrompt && !hasNegativePrompt && !hasIPAdapters && <AddPromptButtons layerId={layerId} />}
|
||||||
|
{hasPositivePrompt && <RPLayerPositivePrompt layerId={layerId} />}
|
||||||
|
{hasNegativePrompt && <RPLayerNegativePrompt layerId={layerId} />}
|
||||||
|
{hasIPAdapters && <RPLayerIPAdapterList layerId={layerId} />}
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
RPLayerListItem.displayName = 'RPLayerListItem';
|
@ -0,0 +1,120 @@
|
|||||||
|
import { IconButton, Menu, MenuButton, MenuDivider, MenuItem, MenuList } from '@invoke-ai/ui-library';
|
||||||
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import {
|
||||||
|
isVectorMaskLayer,
|
||||||
|
layerDeleted,
|
||||||
|
layerMovedBackward,
|
||||||
|
layerMovedForward,
|
||||||
|
layerMovedToBack,
|
||||||
|
layerMovedToFront,
|
||||||
|
layerReset,
|
||||||
|
maskLayerIPAdapterAdded,
|
||||||
|
maskLayerNegativePromptChanged,
|
||||||
|
maskLayerPositivePromptChanged,
|
||||||
|
selectRegionalPromptsSlice,
|
||||||
|
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import {
|
||||||
|
PiArrowCounterClockwiseBold,
|
||||||
|
PiArrowDownBold,
|
||||||
|
PiArrowLineDownBold,
|
||||||
|
PiArrowLineUpBold,
|
||||||
|
PiArrowUpBold,
|
||||||
|
PiDotsThreeVerticalBold,
|
||||||
|
PiPlusBold,
|
||||||
|
PiTrashSimpleBold,
|
||||||
|
} from 'react-icons/pi';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
|
type Props = { layerId: string };
|
||||||
|
|
||||||
|
export const RPLayerMenu = memo(({ layerId }: Props) => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const selectValidActions = useMemo(
|
||||||
|
() =>
|
||||||
|
createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
|
||||||
|
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
|
||||||
|
assert(isVectorMaskLayer(layer), `Layer ${layerId} not found or not an RP layer`);
|
||||||
|
const layerIndex = regionalPrompts.present.layers.findIndex((l) => l.id === layerId);
|
||||||
|
const layerCount = regionalPrompts.present.layers.length;
|
||||||
|
return {
|
||||||
|
canAddPositivePrompt: layer.positivePrompt === null,
|
||||||
|
canAddNegativePrompt: layer.negativePrompt === null,
|
||||||
|
canMoveForward: layerIndex < layerCount - 1,
|
||||||
|
canMoveBackward: layerIndex > 0,
|
||||||
|
canMoveToFront: layerIndex < layerCount - 1,
|
||||||
|
canMoveToBack: layerIndex > 0,
|
||||||
|
};
|
||||||
|
}),
|
||||||
|
[layerId]
|
||||||
|
);
|
||||||
|
const validActions = useAppSelector(selectValidActions);
|
||||||
|
const addPositivePrompt = useCallback(() => {
|
||||||
|
dispatch(maskLayerPositivePromptChanged({ layerId, prompt: '' }));
|
||||||
|
}, [dispatch, layerId]);
|
||||||
|
const addNegativePrompt = useCallback(() => {
|
||||||
|
dispatch(maskLayerNegativePromptChanged({ layerId, prompt: '' }));
|
||||||
|
}, [dispatch, layerId]);
|
||||||
|
const addIPAdapter = useCallback(() => {
|
||||||
|
dispatch(maskLayerIPAdapterAdded(layerId));
|
||||||
|
}, [dispatch, layerId]);
|
||||||
|
const moveForward = useCallback(() => {
|
||||||
|
dispatch(layerMovedForward(layerId));
|
||||||
|
}, [dispatch, layerId]);
|
||||||
|
const moveToFront = useCallback(() => {
|
||||||
|
dispatch(layerMovedToFront(layerId));
|
||||||
|
}, [dispatch, layerId]);
|
||||||
|
const moveBackward = useCallback(() => {
|
||||||
|
dispatch(layerMovedBackward(layerId));
|
||||||
|
}, [dispatch, layerId]);
|
||||||
|
const moveToBack = useCallback(() => {
|
||||||
|
dispatch(layerMovedToBack(layerId));
|
||||||
|
}, [dispatch, layerId]);
|
||||||
|
const resetLayer = useCallback(() => {
|
||||||
|
dispatch(layerReset(layerId));
|
||||||
|
}, [dispatch, layerId]);
|
||||||
|
const deleteLayer = useCallback(() => {
|
||||||
|
dispatch(layerDeleted(layerId));
|
||||||
|
}, [dispatch, layerId]);
|
||||||
|
return (
|
||||||
|
<Menu>
|
||||||
|
<MenuButton as={IconButton} aria-label="Layer menu" size="sm" icon={<PiDotsThreeVerticalBold />} />
|
||||||
|
<MenuList>
|
||||||
|
<MenuItem onClick={addPositivePrompt} isDisabled={!validActions.canAddPositivePrompt} icon={<PiPlusBold />}>
|
||||||
|
{t('regionalPrompts.addPositivePrompt')}
|
||||||
|
</MenuItem>
|
||||||
|
<MenuItem onClick={addNegativePrompt} isDisabled={!validActions.canAddNegativePrompt} icon={<PiPlusBold />}>
|
||||||
|
{t('regionalPrompts.addNegativePrompt')}
|
||||||
|
</MenuItem>
|
||||||
|
<MenuItem onClick={addIPAdapter} icon={<PiPlusBold />}>
|
||||||
|
{t('regionalPrompts.addIPAdapter')}
|
||||||
|
</MenuItem>
|
||||||
|
<MenuDivider />
|
||||||
|
<MenuItem onClick={moveToFront} isDisabled={!validActions.canMoveToFront} icon={<PiArrowLineUpBold />}>
|
||||||
|
{t('regionalPrompts.moveToFront')}
|
||||||
|
</MenuItem>
|
||||||
|
<MenuItem onClick={moveForward} isDisabled={!validActions.canMoveForward} icon={<PiArrowUpBold />}>
|
||||||
|
{t('regionalPrompts.moveForward')}
|
||||||
|
</MenuItem>
|
||||||
|
<MenuItem onClick={moveBackward} isDisabled={!validActions.canMoveBackward} icon={<PiArrowDownBold />}>
|
||||||
|
{t('regionalPrompts.moveBackward')}
|
||||||
|
</MenuItem>
|
||||||
|
<MenuItem onClick={moveToBack} isDisabled={!validActions.canMoveToBack} icon={<PiArrowLineDownBold />}>
|
||||||
|
{t('regionalPrompts.moveToBack')}
|
||||||
|
</MenuItem>
|
||||||
|
<MenuDivider />
|
||||||
|
<MenuItem onClick={resetLayer} icon={<PiArrowCounterClockwiseBold />}>
|
||||||
|
{t('accessibility.reset')}
|
||||||
|
</MenuItem>
|
||||||
|
<MenuItem onClick={deleteLayer} icon={<PiTrashSimpleBold />} color="error.300">
|
||||||
|
{t('common.delete')}
|
||||||
|
</MenuItem>
|
||||||
|
</MenuList>
|
||||||
|
</Menu>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
RPLayerMenu.displayName = 'RPLayerMenu';
|
@ -0,0 +1,58 @@
|
|||||||
|
import { Box, Textarea } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
|
||||||
|
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
||||||
|
import { PromptPopover } from 'features/prompt/PromptPopover';
|
||||||
|
import { usePrompt } from 'features/prompt/usePrompt';
|
||||||
|
import { RPLayerPromptDeleteButton } from 'features/regionalPrompts/components/RPLayerPromptDeleteButton';
|
||||||
|
import { useLayerNegativePrompt } from 'features/regionalPrompts/hooks/layerStateHooks';
|
||||||
|
import { maskLayerNegativePromptChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import { memo, useCallback, useRef } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
layerId: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const RPLayerNegativePrompt = memo(({ layerId }: Props) => {
|
||||||
|
const prompt = useLayerNegativePrompt(layerId);
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const textareaRef = useRef<HTMLTextAreaElement>(null);
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const _onChange = useCallback(
|
||||||
|
(v: string) => {
|
||||||
|
dispatch(maskLayerNegativePromptChanged({ layerId, prompt: v }));
|
||||||
|
},
|
||||||
|
[dispatch, layerId]
|
||||||
|
);
|
||||||
|
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown } = usePrompt({
|
||||||
|
prompt,
|
||||||
|
textareaRef,
|
||||||
|
onChange: _onChange,
|
||||||
|
});
|
||||||
|
|
||||||
|
return (
|
||||||
|
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
|
||||||
|
<Box pos="relative" w="full">
|
||||||
|
<Textarea
|
||||||
|
id="prompt"
|
||||||
|
name="prompt"
|
||||||
|
ref={textareaRef}
|
||||||
|
value={prompt}
|
||||||
|
placeholder={t('parameters.negativePromptPlaceholder')}
|
||||||
|
onChange={onChange}
|
||||||
|
onKeyDown={onKeyDown}
|
||||||
|
variant="darkFilled"
|
||||||
|
paddingRight={30}
|
||||||
|
fontSize="sm"
|
||||||
|
/>
|
||||||
|
<PromptOverlayButtonWrapper>
|
||||||
|
<RPLayerPromptDeleteButton layerId={layerId} polarity="negative" />
|
||||||
|
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||||
|
</PromptOverlayButtonWrapper>
|
||||||
|
</Box>
|
||||||
|
</PromptPopover>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
RPLayerNegativePrompt.displayName = 'RPLayerNegativePrompt';
|
@ -0,0 +1,58 @@
|
|||||||
|
import { Box, Textarea } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
|
||||||
|
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
||||||
|
import { PromptPopover } from 'features/prompt/PromptPopover';
|
||||||
|
import { usePrompt } from 'features/prompt/usePrompt';
|
||||||
|
import { RPLayerPromptDeleteButton } from 'features/regionalPrompts/components/RPLayerPromptDeleteButton';
|
||||||
|
import { useLayerPositivePrompt } from 'features/regionalPrompts/hooks/layerStateHooks';
|
||||||
|
import { maskLayerPositivePromptChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import { memo, useCallback, useRef } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
layerId: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const RPLayerPositivePrompt = memo(({ layerId }: Props) => {
|
||||||
|
const prompt = useLayerPositivePrompt(layerId);
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const textareaRef = useRef<HTMLTextAreaElement>(null);
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const _onChange = useCallback(
|
||||||
|
(v: string) => {
|
||||||
|
dispatch(maskLayerPositivePromptChanged({ layerId, prompt: v }));
|
||||||
|
},
|
||||||
|
[dispatch, layerId]
|
||||||
|
);
|
||||||
|
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown } = usePrompt({
|
||||||
|
prompt,
|
||||||
|
textareaRef,
|
||||||
|
onChange: _onChange,
|
||||||
|
});
|
||||||
|
|
||||||
|
return (
|
||||||
|
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
|
||||||
|
<Box pos="relative" w="full">
|
||||||
|
<Textarea
|
||||||
|
id="prompt"
|
||||||
|
name="prompt"
|
||||||
|
ref={textareaRef}
|
||||||
|
value={prompt}
|
||||||
|
placeholder={t('parameters.positivePromptPlaceholder')}
|
||||||
|
onChange={onChange}
|
||||||
|
onKeyDown={onKeyDown}
|
||||||
|
variant="darkFilled"
|
||||||
|
paddingRight={30}
|
||||||
|
minH={28}
|
||||||
|
/>
|
||||||
|
<PromptOverlayButtonWrapper>
|
||||||
|
<RPLayerPromptDeleteButton layerId={layerId} polarity="positive" />
|
||||||
|
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||||
|
</PromptOverlayButtonWrapper>
|
||||||
|
</Box>
|
||||||
|
</PromptPopover>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
RPLayerPositivePrompt.displayName = 'RPLayerPositivePrompt';
|
@ -0,0 +1,38 @@
|
|||||||
|
import { IconButton, Tooltip } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import {
|
||||||
|
maskLayerNegativePromptChanged,
|
||||||
|
maskLayerPositivePromptChanged,
|
||||||
|
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
layerId: string;
|
||||||
|
polarity: 'positive' | 'negative';
|
||||||
|
};
|
||||||
|
|
||||||
|
export const RPLayerPromptDeleteButton = memo(({ layerId, polarity }: Props) => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const onClick = useCallback(() => {
|
||||||
|
if (polarity === 'positive') {
|
||||||
|
dispatch(maskLayerPositivePromptChanged({ layerId, prompt: null }));
|
||||||
|
} else {
|
||||||
|
dispatch(maskLayerNegativePromptChanged({ layerId, prompt: null }));
|
||||||
|
}
|
||||||
|
}, [dispatch, layerId, polarity]);
|
||||||
|
return (
|
||||||
|
<Tooltip label={t('regionalPrompts.deletePrompt')}>
|
||||||
|
<IconButton
|
||||||
|
variant="promptOverlay"
|
||||||
|
aria-label={t('regionalPrompts.deletePrompt')}
|
||||||
|
icon={<PiTrashSimpleBold />}
|
||||||
|
onClick={onClick}
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
RPLayerPromptDeleteButton.displayName = 'RPLayerPromptDeleteButton';
|
@ -0,0 +1,53 @@
|
|||||||
|
import type { FormLabelProps } from '@invoke-ai/ui-library';
|
||||||
|
import {
|
||||||
|
Flex,
|
||||||
|
FormControlGroup,
|
||||||
|
IconButton,
|
||||||
|
Popover,
|
||||||
|
PopoverArrow,
|
||||||
|
PopoverBody,
|
||||||
|
PopoverContent,
|
||||||
|
PopoverTrigger,
|
||||||
|
} from '@invoke-ai/ui-library';
|
||||||
|
import { RPLayerAutoNegativeCheckbox } from 'features/regionalPrompts/components/RPLayerAutoNegativeCheckbox';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { PiGearSixBold } from 'react-icons/pi';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
layerId: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
const formLabelProps: FormLabelProps = {
|
||||||
|
flexGrow: 1,
|
||||||
|
minW: 32,
|
||||||
|
};
|
||||||
|
|
||||||
|
const RPLayerSettingsPopover = ({ layerId }: Props) => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Popover isLazy>
|
||||||
|
<PopoverTrigger>
|
||||||
|
<IconButton
|
||||||
|
tooltip={t('common.settingsLabel')}
|
||||||
|
aria-label={t('common.settingsLabel')}
|
||||||
|
size="sm"
|
||||||
|
icon={<PiGearSixBold />}
|
||||||
|
/>
|
||||||
|
</PopoverTrigger>
|
||||||
|
<PopoverContent>
|
||||||
|
<PopoverArrow />
|
||||||
|
<PopoverBody>
|
||||||
|
<Flex direction="column" gap={2}>
|
||||||
|
<FormControlGroup formLabelProps={formLabelProps}>
|
||||||
|
<RPLayerAutoNegativeCheckbox layerId={layerId} />
|
||||||
|
</FormControlGroup>
|
||||||
|
</Flex>
|
||||||
|
</PopoverBody>
|
||||||
|
</PopoverContent>
|
||||||
|
</Popover>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(RPLayerSettingsPopover);
|
@ -0,0 +1,34 @@
|
|||||||
|
import { IconButton } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { useLayerIsVisible } from 'features/regionalPrompts/hooks/layerStateHooks';
|
||||||
|
import { layerVisibilityToggled } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { PiCheckBold } from 'react-icons/pi';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
layerId: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const RPLayerVisibilityToggle = memo(({ layerId }: Props) => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const isVisible = useLayerIsVisible(layerId);
|
||||||
|
const onClick = useCallback(() => {
|
||||||
|
dispatch(layerVisibilityToggled(layerId));
|
||||||
|
}, [dispatch, layerId]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IconButton
|
||||||
|
size="sm"
|
||||||
|
aria-label={t('regionalPrompts.toggleVisibility')}
|
||||||
|
tooltip={t('regionalPrompts.toggleVisibility')}
|
||||||
|
variant="outline"
|
||||||
|
icon={isVisible ? <PiCheckBold /> : undefined}
|
||||||
|
onClick={onClick}
|
||||||
|
colorScheme="base"
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
RPLayerVisibilityToggle.displayName = 'RPLayerVisibilityToggle';
|
@ -0,0 +1,24 @@
|
|||||||
|
import { Flex } from '@invoke-ai/ui-library';
|
||||||
|
import type { Meta, StoryObj } from '@storybook/react';
|
||||||
|
import { RegionalPromptsEditor } from 'features/regionalPrompts/components/RegionalPromptsEditor';
|
||||||
|
|
||||||
|
const meta: Meta<typeof RegionalPromptsEditor> = {
|
||||||
|
title: 'Feature/RegionalPrompts',
|
||||||
|
tags: ['autodocs'],
|
||||||
|
component: RegionalPromptsEditor,
|
||||||
|
};
|
||||||
|
|
||||||
|
export default meta;
|
||||||
|
type Story = StoryObj<typeof RegionalPromptsEditor>;
|
||||||
|
|
||||||
|
const Component = () => {
|
||||||
|
return (
|
||||||
|
<Flex w={1500} h={1500}>
|
||||||
|
<RegionalPromptsEditor />
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const Default: Story = {
|
||||||
|
render: Component,
|
||||||
|
};
|
@ -0,0 +1,24 @@
|
|||||||
|
/* eslint-disable i18next/no-literal-string */
|
||||||
|
import { Flex } from '@invoke-ai/ui-library';
|
||||||
|
import { RegionalPromptsToolbar } from 'features/regionalPrompts/components/RegionalPromptsToolbar';
|
||||||
|
import { StageComponent } from 'features/regionalPrompts/components/StageComponent';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
|
export const RegionalPromptsEditor = memo(() => {
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
position="relative"
|
||||||
|
flexDirection="column"
|
||||||
|
height="100%"
|
||||||
|
width="100%"
|
||||||
|
rowGap={4}
|
||||||
|
alignItems="center"
|
||||||
|
justifyContent="center"
|
||||||
|
>
|
||||||
|
<RegionalPromptsToolbar />
|
||||||
|
<StageComponent />
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
RegionalPromptsEditor.displayName = 'RegionalPromptsEditor';
|
@ -0,0 +1,38 @@
|
|||||||
|
/* eslint-disable i18next/no-literal-string */
|
||||||
|
import { Flex } from '@invoke-ai/ui-library';
|
||||||
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||||
|
import { AddLayerButton } from 'features/regionalPrompts/components/AddLayerButton';
|
||||||
|
import { DeleteAllLayersButton } from 'features/regionalPrompts/components/DeleteAllLayersButton';
|
||||||
|
import { RPLayerListItem } from 'features/regionalPrompts/components/RPLayerListItem';
|
||||||
|
import { isVectorMaskLayer, selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
|
const selectRPLayerIdsReversed = createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) =>
|
||||||
|
regionalPrompts.present.layers
|
||||||
|
.filter(isVectorMaskLayer)
|
||||||
|
.map((l) => l.id)
|
||||||
|
.reverse()
|
||||||
|
);
|
||||||
|
|
||||||
|
export const RegionalPromptsPanelContent = memo(() => {
|
||||||
|
const rpLayerIdsReversed = useAppSelector(selectRPLayerIdsReversed);
|
||||||
|
return (
|
||||||
|
<Flex flexDir="column" gap={4} w="full" h="full">
|
||||||
|
<Flex justifyContent="space-around">
|
||||||
|
<AddLayerButton />
|
||||||
|
<DeleteAllLayersButton />
|
||||||
|
</Flex>
|
||||||
|
<ScrollableContent>
|
||||||
|
<Flex flexDir="column" gap={4}>
|
||||||
|
{rpLayerIdsReversed.map((id) => (
|
||||||
|
<RPLayerListItem key={id} layerId={id} />
|
||||||
|
))}
|
||||||
|
</Flex>
|
||||||
|
</ScrollableContent>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
RegionalPromptsPanelContent.displayName = 'RegionalPromptsPanelContent';
|
@ -0,0 +1,20 @@
|
|||||||
|
/* eslint-disable i18next/no-literal-string */
|
||||||
|
import { Flex } from '@invoke-ai/ui-library';
|
||||||
|
import { BrushSize } from 'features/regionalPrompts/components/BrushSize';
|
||||||
|
import { GlobalMaskLayerOpacity } from 'features/regionalPrompts/components/GlobalMaskLayerOpacity';
|
||||||
|
import { ToolChooser } from 'features/regionalPrompts/components/ToolChooser';
|
||||||
|
import { UndoRedoButtonGroup } from 'features/regionalPrompts/components/UndoRedoButtonGroup';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
|
export const RegionalPromptsToolbar = memo(() => {
|
||||||
|
return (
|
||||||
|
<Flex gap={4}>
|
||||||
|
<BrushSize />
|
||||||
|
<GlobalMaskLayerOpacity />
|
||||||
|
<UndoRedoButtonGroup />
|
||||||
|
<ToolChooser />
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
RegionalPromptsToolbar.displayName = 'RegionalPromptsToolbar';
|
@ -0,0 +1,232 @@
|
|||||||
|
import { Flex } from '@invoke-ai/ui-library';
|
||||||
|
import { useStore } from '@nanostores/react';
|
||||||
|
import { logger } from 'app/logging/logger';
|
||||||
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { useMouseEvents } from 'features/regionalPrompts/hooks/mouseEventHooks';
|
||||||
|
import {
|
||||||
|
$cursorPosition,
|
||||||
|
$isMouseOver,
|
||||||
|
$lastMouseDownPos,
|
||||||
|
$tool,
|
||||||
|
isVectorMaskLayer,
|
||||||
|
layerBboxChanged,
|
||||||
|
layerSelected,
|
||||||
|
layerTranslated,
|
||||||
|
selectRegionalPromptsSlice,
|
||||||
|
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import { debouncedRenderers, renderers as normalRenderers } from 'features/regionalPrompts/util/renderers';
|
||||||
|
import Konva from 'konva';
|
||||||
|
import type { IRect } from 'konva/lib/types';
|
||||||
|
import type { MutableRefObject } from 'react';
|
||||||
|
import { memo, useCallback, useLayoutEffect, useMemo, useRef, useState } from 'react';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
|
// This will log warnings when layers > 5 - maybe use `import.meta.env.MODE === 'development'` instead?
|
||||||
|
Konva.showWarnings = false;
|
||||||
|
|
||||||
|
const log = logger('regionalPrompts');
|
||||||
|
|
||||||
|
const selectSelectedLayerColor = createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
|
||||||
|
const layer = regionalPrompts.present.layers.find((l) => l.id === regionalPrompts.present.selectedLayerId);
|
||||||
|
if (!layer) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
assert(isVectorMaskLayer(layer), `Layer ${regionalPrompts.present.selectedLayerId} is not an RP layer`);
|
||||||
|
return layer.previewColor;
|
||||||
|
});
|
||||||
|
|
||||||
|
const useStageRenderer = (
|
||||||
|
stageRef: MutableRefObject<Konva.Stage>,
|
||||||
|
container: HTMLDivElement | null,
|
||||||
|
wrapper: HTMLDivElement | null,
|
||||||
|
asPreview: boolean
|
||||||
|
) => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const width = useAppSelector((s) => s.generation.width);
|
||||||
|
const height = useAppSelector((s) => s.generation.height);
|
||||||
|
const state = useAppSelector((s) => s.regionalPrompts.present);
|
||||||
|
const tool = useStore($tool);
|
||||||
|
const { onMouseDown, onMouseUp, onMouseMove, onMouseEnter, onMouseLeave, onMouseWheel } = useMouseEvents();
|
||||||
|
const cursorPosition = useStore($cursorPosition);
|
||||||
|
const lastMouseDownPos = useStore($lastMouseDownPos);
|
||||||
|
const isMouseOver = useStore($isMouseOver);
|
||||||
|
const selectedLayerIdColor = useAppSelector(selectSelectedLayerColor);
|
||||||
|
const layerIds = useMemo(() => state.layers.map((l) => l.id), [state.layers]);
|
||||||
|
const renderers = useMemo(() => (asPreview ? debouncedRenderers : normalRenderers), [asPreview]);
|
||||||
|
|
||||||
|
const onLayerPosChanged = useCallback(
|
||||||
|
(layerId: string, x: number, y: number) => {
|
||||||
|
dispatch(layerTranslated({ layerId, x, y }));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const onBboxChanged = useCallback(
|
||||||
|
(layerId: string, bbox: IRect | null) => {
|
||||||
|
dispatch(layerBboxChanged({ layerId, bbox }));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const onBboxMouseDown = useCallback(
|
||||||
|
(layerId: string) => {
|
||||||
|
dispatch(layerSelected(layerId));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
useLayoutEffect(() => {
|
||||||
|
log.trace('Initializing stage');
|
||||||
|
if (!container) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const stage = stageRef.current.container(container);
|
||||||
|
return () => {
|
||||||
|
log.trace('Cleaning up stage');
|
||||||
|
stage.destroy();
|
||||||
|
};
|
||||||
|
}, [container, stageRef]);
|
||||||
|
|
||||||
|
useLayoutEffect(() => {
|
||||||
|
log.trace('Adding stage listeners');
|
||||||
|
if (asPreview) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
stageRef.current.on('mousedown', onMouseDown);
|
||||||
|
stageRef.current.on('mouseup', onMouseUp);
|
||||||
|
stageRef.current.on('mousemove', onMouseMove);
|
||||||
|
stageRef.current.on('mouseenter', onMouseEnter);
|
||||||
|
stageRef.current.on('mouseleave', onMouseLeave);
|
||||||
|
stageRef.current.on('wheel', onMouseWheel);
|
||||||
|
const stage = stageRef.current;
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
log.trace('Cleaning up stage listeners');
|
||||||
|
stage.off('mousedown', onMouseDown);
|
||||||
|
stage.off('mouseup', onMouseUp);
|
||||||
|
stage.off('mousemove', onMouseMove);
|
||||||
|
stage.off('mouseenter', onMouseEnter);
|
||||||
|
stage.off('mouseleave', onMouseLeave);
|
||||||
|
stage.off('wheel', onMouseWheel);
|
||||||
|
};
|
||||||
|
}, [stageRef, asPreview, onMouseDown, onMouseUp, onMouseMove, onMouseEnter, onMouseLeave, onMouseWheel]);
|
||||||
|
|
||||||
|
useLayoutEffect(() => {
|
||||||
|
log.trace('Updating stage dimensions');
|
||||||
|
if (!wrapper) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const stage = stageRef.current;
|
||||||
|
|
||||||
|
const fitStageToContainer = () => {
|
||||||
|
const newXScale = wrapper.offsetWidth / width;
|
||||||
|
const newYScale = wrapper.offsetHeight / height;
|
||||||
|
const newScale = Math.min(newXScale, newYScale, 1);
|
||||||
|
stage.width(width * newScale);
|
||||||
|
stage.height(height * newScale);
|
||||||
|
stage.scaleX(newScale);
|
||||||
|
stage.scaleY(newScale);
|
||||||
|
};
|
||||||
|
|
||||||
|
const resizeObserver = new ResizeObserver(fitStageToContainer);
|
||||||
|
resizeObserver.observe(wrapper);
|
||||||
|
fitStageToContainer();
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
resizeObserver.disconnect();
|
||||||
|
};
|
||||||
|
}, [stageRef, width, height, wrapper]);
|
||||||
|
|
||||||
|
useLayoutEffect(() => {
|
||||||
|
log.trace('Rendering tool preview');
|
||||||
|
if (asPreview) {
|
||||||
|
// Preview should not display tool
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
renderers.renderToolPreview(
|
||||||
|
stageRef.current,
|
||||||
|
tool,
|
||||||
|
selectedLayerIdColor,
|
||||||
|
state.globalMaskLayerOpacity,
|
||||||
|
cursorPosition,
|
||||||
|
lastMouseDownPos,
|
||||||
|
isMouseOver,
|
||||||
|
state.brushSize
|
||||||
|
);
|
||||||
|
}, [
|
||||||
|
asPreview,
|
||||||
|
stageRef,
|
||||||
|
tool,
|
||||||
|
selectedLayerIdColor,
|
||||||
|
state.globalMaskLayerOpacity,
|
||||||
|
cursorPosition,
|
||||||
|
lastMouseDownPos,
|
||||||
|
isMouseOver,
|
||||||
|
state.brushSize,
|
||||||
|
renderers,
|
||||||
|
]);
|
||||||
|
|
||||||
|
useLayoutEffect(() => {
|
||||||
|
log.trace('Rendering layers');
|
||||||
|
renderers.renderLayers(stageRef.current, state.layers, state.globalMaskLayerOpacity, tool, onLayerPosChanged);
|
||||||
|
}, [stageRef, state.layers, state.globalMaskLayerOpacity, tool, onLayerPosChanged, renderers]);
|
||||||
|
|
||||||
|
useLayoutEffect(() => {
|
||||||
|
log.trace('Rendering bbox');
|
||||||
|
if (asPreview) {
|
||||||
|
// Preview should not display bboxes
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
renderers.renderBbox(stageRef.current, state.layers, state.selectedLayerId, tool, onBboxChanged, onBboxMouseDown);
|
||||||
|
}, [stageRef, asPreview, state.layers, state.selectedLayerId, tool, onBboxChanged, onBboxMouseDown, renderers]);
|
||||||
|
|
||||||
|
useLayoutEffect(() => {
|
||||||
|
log.trace('Rendering background');
|
||||||
|
if (asPreview) {
|
||||||
|
// The preview should not have a background
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
renderers.renderBackground(stageRef.current, width, height);
|
||||||
|
}, [stageRef, asPreview, width, height, renderers]);
|
||||||
|
|
||||||
|
useLayoutEffect(() => {
|
||||||
|
log.trace('Arranging layers');
|
||||||
|
renderers.arrangeLayers(stageRef.current, layerIds);
|
||||||
|
}, [stageRef, layerIds, renderers]);
|
||||||
|
};
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
asPreview?: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const StageComponent = memo(({ asPreview = false }: Props) => {
|
||||||
|
const stageRef = useRef<Konva.Stage>(
|
||||||
|
new Konva.Stage({
|
||||||
|
container: document.createElement('div'), // We will overwrite this shortly...
|
||||||
|
})
|
||||||
|
);
|
||||||
|
const [container, setContainer] = useState<HTMLDivElement | null>(null);
|
||||||
|
const [wrapper, setWrapper] = useState<HTMLDivElement | null>(null);
|
||||||
|
|
||||||
|
const containerRef = useCallback((el: HTMLDivElement | null) => {
|
||||||
|
setContainer(el);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const wrapperRef = useCallback((el: HTMLDivElement | null) => {
|
||||||
|
setWrapper(el);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
useStageRenderer(stageRef, container, wrapper, asPreview);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex overflow="hidden" w="full" h="full">
|
||||||
|
<Flex ref={wrapperRef} w="full" h="full" alignItems="center" justifyContent="center">
|
||||||
|
<Flex ref={containerRef} tabIndex={-1} bg="base.850" />
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
StageComponent.displayName = 'StageComponent';
|
@ -0,0 +1,89 @@
|
|||||||
|
import { ButtonGroup, IconButton } from '@invoke-ai/ui-library';
|
||||||
|
import { useStore } from '@nanostores/react';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import {
|
||||||
|
$tool,
|
||||||
|
layerAdded,
|
||||||
|
selectedLayerDeleted,
|
||||||
|
selectedLayerReset,
|
||||||
|
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import { useCallback } from 'react';
|
||||||
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { PiArrowsOutCardinalBold, PiEraserBold, PiPaintBrushBold, PiRectangleBold } from 'react-icons/pi';
|
||||||
|
|
||||||
|
export const ToolChooser: React.FC = () => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const isDisabled = useAppSelector((s) => s.regionalPrompts.present.layers.length === 0);
|
||||||
|
const tool = useStore($tool);
|
||||||
|
|
||||||
|
const setToolToBrush = useCallback(() => {
|
||||||
|
$tool.set('brush');
|
||||||
|
}, []);
|
||||||
|
useHotkeys('b', setToolToBrush, { enabled: !isDisabled }, [isDisabled]);
|
||||||
|
const setToolToEraser = useCallback(() => {
|
||||||
|
$tool.set('eraser');
|
||||||
|
}, []);
|
||||||
|
useHotkeys('e', setToolToEraser, { enabled: !isDisabled }, [isDisabled]);
|
||||||
|
const setToolToRect = useCallback(() => {
|
||||||
|
$tool.set('rect');
|
||||||
|
}, []);
|
||||||
|
useHotkeys('u', setToolToRect, { enabled: !isDisabled }, [isDisabled]);
|
||||||
|
const setToolToMove = useCallback(() => {
|
||||||
|
$tool.set('move');
|
||||||
|
}, []);
|
||||||
|
useHotkeys('v', setToolToMove, { enabled: !isDisabled }, [isDisabled]);
|
||||||
|
|
||||||
|
const resetSelectedLayer = useCallback(() => {
|
||||||
|
dispatch(selectedLayerReset());
|
||||||
|
}, [dispatch]);
|
||||||
|
useHotkeys('shift+c', resetSelectedLayer);
|
||||||
|
|
||||||
|
const addLayer = useCallback(() => {
|
||||||
|
dispatch(layerAdded('vector_mask_layer'));
|
||||||
|
}, [dispatch]);
|
||||||
|
useHotkeys('shift+a', addLayer);
|
||||||
|
|
||||||
|
const deleteSelectedLayer = useCallback(() => {
|
||||||
|
dispatch(selectedLayerDeleted());
|
||||||
|
}, [dispatch]);
|
||||||
|
useHotkeys('shift+d', deleteSelectedLayer);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ButtonGroup isAttached>
|
||||||
|
<IconButton
|
||||||
|
aria-label={`${t('unifiedCanvas.brush')} (B)`}
|
||||||
|
tooltip={`${t('unifiedCanvas.brush')} (B)`}
|
||||||
|
icon={<PiPaintBrushBold />}
|
||||||
|
variant={tool === 'brush' ? 'solid' : 'outline'}
|
||||||
|
onClick={setToolToBrush}
|
||||||
|
isDisabled={isDisabled}
|
||||||
|
/>
|
||||||
|
<IconButton
|
||||||
|
aria-label={`${t('unifiedCanvas.eraser')} (E)`}
|
||||||
|
tooltip={`${t('unifiedCanvas.eraser')} (E)`}
|
||||||
|
icon={<PiEraserBold />}
|
||||||
|
variant={tool === 'eraser' ? 'solid' : 'outline'}
|
||||||
|
onClick={setToolToEraser}
|
||||||
|
isDisabled={isDisabled}
|
||||||
|
/>
|
||||||
|
<IconButton
|
||||||
|
aria-label={`${t('regionalPrompts.rectangle')} (U)`}
|
||||||
|
tooltip={`${t('regionalPrompts.rectangle')} (U)`}
|
||||||
|
icon={<PiRectangleBold />}
|
||||||
|
variant={tool === 'rect' ? 'solid' : 'outline'}
|
||||||
|
onClick={setToolToRect}
|
||||||
|
isDisabled={isDisabled}
|
||||||
|
/>
|
||||||
|
<IconButton
|
||||||
|
aria-label={`${t('unifiedCanvas.move')} (V)`}
|
||||||
|
tooltip={`${t('unifiedCanvas.move')} (V)`}
|
||||||
|
icon={<PiArrowsOutCardinalBold />}
|
||||||
|
variant={tool === 'move' ? 'solid' : 'outline'}
|
||||||
|
onClick={setToolToMove}
|
||||||
|
isDisabled={isDisabled}
|
||||||
|
/>
|
||||||
|
</ButtonGroup>
|
||||||
|
);
|
||||||
|
};
|
@ -0,0 +1,49 @@
|
|||||||
|
/* eslint-disable i18next/no-literal-string */
|
||||||
|
import { ButtonGroup, IconButton } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { redo, undo } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { PiArrowClockwiseBold, PiArrowCounterClockwiseBold } from 'react-icons/pi';
|
||||||
|
|
||||||
|
export const UndoRedoButtonGroup = memo(() => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const mayUndo = useAppSelector((s) => s.regionalPrompts.past.length > 0);
|
||||||
|
const handleUndo = useCallback(() => {
|
||||||
|
dispatch(undo());
|
||||||
|
}, [dispatch]);
|
||||||
|
useHotkeys(['meta+z', 'ctrl+z'], handleUndo, { enabled: mayUndo, preventDefault: true }, [mayUndo, handleUndo]);
|
||||||
|
|
||||||
|
const mayRedo = useAppSelector((s) => s.regionalPrompts.future.length > 0);
|
||||||
|
const handleRedo = useCallback(() => {
|
||||||
|
dispatch(redo());
|
||||||
|
}, [dispatch]);
|
||||||
|
useHotkeys(['meta+shift+z', 'ctrl+shift+z'], handleRedo, { enabled: mayRedo, preventDefault: true }, [
|
||||||
|
mayRedo,
|
||||||
|
handleRedo,
|
||||||
|
]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ButtonGroup>
|
||||||
|
<IconButton
|
||||||
|
aria-label={t('unifiedCanvas.undo')}
|
||||||
|
tooltip={t('unifiedCanvas.undo')}
|
||||||
|
onClick={handleUndo}
|
||||||
|
icon={<PiArrowCounterClockwiseBold />}
|
||||||
|
isDisabled={!mayUndo}
|
||||||
|
/>
|
||||||
|
<IconButton
|
||||||
|
aria-label={t('unifiedCanvas.redo')}
|
||||||
|
tooltip={t('unifiedCanvas.redo')}
|
||||||
|
onClick={handleRedo}
|
||||||
|
icon={<PiArrowClockwiseBold />}
|
||||||
|
isDisabled={!mayRedo}
|
||||||
|
/>
|
||||||
|
</ButtonGroup>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
UndoRedoButtonGroup.displayName = 'UndoRedoButtonGroup';
|
@ -0,0 +1,49 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { isVectorMaskLayer, selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import { useMemo } from 'react';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
|
export const useLayerPositivePrompt = (layerId: string) => {
|
||||||
|
const selectLayer = useMemo(
|
||||||
|
() =>
|
||||||
|
createSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
|
||||||
|
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
|
||||||
|
assert(isVectorMaskLayer(layer), `Layer ${layerId} not found or not an RP layer`);
|
||||||
|
assert(layer.positivePrompt !== null, `Layer ${layerId} does not have a positive prompt`);
|
||||||
|
return layer.positivePrompt;
|
||||||
|
}),
|
||||||
|
[layerId]
|
||||||
|
);
|
||||||
|
const prompt = useAppSelector(selectLayer);
|
||||||
|
return prompt;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const useLayerNegativePrompt = (layerId: string) => {
|
||||||
|
const selectLayer = useMemo(
|
||||||
|
() =>
|
||||||
|
createSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
|
||||||
|
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
|
||||||
|
assert(isVectorMaskLayer(layer), `Layer ${layerId} not found or not an RP layer`);
|
||||||
|
assert(layer.negativePrompt !== null, `Layer ${layerId} does not have a negative prompt`);
|
||||||
|
return layer.negativePrompt;
|
||||||
|
}),
|
||||||
|
[layerId]
|
||||||
|
);
|
||||||
|
const prompt = useAppSelector(selectLayer);
|
||||||
|
return prompt;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const useLayerIsVisible = (layerId: string) => {
|
||||||
|
const selectLayer = useMemo(
|
||||||
|
() =>
|
||||||
|
createSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
|
||||||
|
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
|
||||||
|
assert(isVectorMaskLayer(layer), `Layer ${layerId} not found or not an RP layer`);
|
||||||
|
return layer.isVisible;
|
||||||
|
}),
|
||||||
|
[layerId]
|
||||||
|
);
|
||||||
|
const isVisible = useAppSelector(selectLayer);
|
||||||
|
return isVisible;
|
||||||
|
};
|
@ -0,0 +1,217 @@
|
|||||||
|
import { $ctrl, $meta } from '@invoke-ai/ui-library';
|
||||||
|
import { useStore } from '@nanostores/react';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { calculateNewBrushSize } from 'features/canvas/hooks/useCanvasZoom';
|
||||||
|
import {
|
||||||
|
$cursorPosition,
|
||||||
|
$isMouseDown,
|
||||||
|
$isMouseOver,
|
||||||
|
$lastMouseDownPos,
|
||||||
|
$tool,
|
||||||
|
brushSizeChanged,
|
||||||
|
maskLayerLineAdded,
|
||||||
|
maskLayerPointsAdded,
|
||||||
|
maskLayerRectAdded,
|
||||||
|
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import type Konva from 'konva';
|
||||||
|
import type { KonvaEventObject } from 'konva/lib/Node';
|
||||||
|
import type { Vector2d } from 'konva/lib/types';
|
||||||
|
import { useCallback, useRef } from 'react';
|
||||||
|
|
||||||
|
const getIsFocused = (stage: Konva.Stage) => {
|
||||||
|
return stage.container().contains(document.activeElement);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const getScaledFlooredCursorPosition = (stage: Konva.Stage) => {
|
||||||
|
const pointerPosition = stage.getPointerPosition();
|
||||||
|
const stageTransform = stage.getAbsoluteTransform().copy();
|
||||||
|
if (!pointerPosition || !stageTransform) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const scaledCursorPosition = stageTransform.invert().point(pointerPosition);
|
||||||
|
return {
|
||||||
|
x: Math.floor(scaledCursorPosition.x),
|
||||||
|
y: Math.floor(scaledCursorPosition.y),
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
const syncCursorPos = (stage: Konva.Stage): Vector2d | null => {
|
||||||
|
const pos = getScaledFlooredCursorPosition(stage);
|
||||||
|
if (!pos) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
$cursorPosition.set(pos);
|
||||||
|
return pos;
|
||||||
|
};
|
||||||
|
|
||||||
|
const BRUSH_SPACING = 20;
|
||||||
|
|
||||||
|
export const useMouseEvents = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const selectedLayerId = useAppSelector((s) => s.regionalPrompts.present.selectedLayerId);
|
||||||
|
const tool = useStore($tool);
|
||||||
|
const lastCursorPosRef = useRef<[number, number] | null>(null);
|
||||||
|
const shouldInvertBrushSizeScrollDirection = useAppSelector((s) => s.canvas.shouldInvertBrushSizeScrollDirection);
|
||||||
|
const brushSize = useAppSelector((s) => s.regionalPrompts.present.brushSize);
|
||||||
|
|
||||||
|
const onMouseDown = useCallback(
|
||||||
|
(e: KonvaEventObject<MouseEvent | TouchEvent>) => {
|
||||||
|
const stage = e.target.getStage();
|
||||||
|
if (!stage) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const pos = syncCursorPos(stage);
|
||||||
|
if (!pos) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
$isMouseDown.set(true);
|
||||||
|
$lastMouseDownPos.set(pos);
|
||||||
|
if (!selectedLayerId) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (tool === 'brush' || tool === 'eraser') {
|
||||||
|
dispatch(
|
||||||
|
maskLayerLineAdded({
|
||||||
|
layerId: selectedLayerId,
|
||||||
|
points: [pos.x, pos.y, pos.x, pos.y],
|
||||||
|
tool,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[dispatch, selectedLayerId, tool]
|
||||||
|
);
|
||||||
|
|
||||||
|
const onMouseUp = useCallback(
|
||||||
|
(e: KonvaEventObject<MouseEvent | TouchEvent>) => {
|
||||||
|
const stage = e.target.getStage();
|
||||||
|
if (!stage) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
$isMouseDown.set(false);
|
||||||
|
const pos = $cursorPosition.get();
|
||||||
|
const lastPos = $lastMouseDownPos.get();
|
||||||
|
const tool = $tool.get();
|
||||||
|
if (pos && lastPos && selectedLayerId && tool === 'rect') {
|
||||||
|
dispatch(
|
||||||
|
maskLayerRectAdded({
|
||||||
|
layerId: selectedLayerId,
|
||||||
|
rect: {
|
||||||
|
x: Math.min(pos.x, lastPos.x),
|
||||||
|
y: Math.min(pos.y, lastPos.y),
|
||||||
|
width: Math.abs(pos.x - lastPos.x),
|
||||||
|
height: Math.abs(pos.y - lastPos.y),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
$lastMouseDownPos.set(null);
|
||||||
|
},
|
||||||
|
[dispatch, selectedLayerId]
|
||||||
|
);
|
||||||
|
|
||||||
|
const onMouseMove = useCallback(
|
||||||
|
(e: KonvaEventObject<MouseEvent | TouchEvent>) => {
|
||||||
|
const stage = e.target.getStage();
|
||||||
|
if (!stage) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const pos = syncCursorPos(stage);
|
||||||
|
if (!pos || !selectedLayerId) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (getIsFocused(stage) && $isMouseOver.get() && $isMouseDown.get() && (tool === 'brush' || tool === 'eraser')) {
|
||||||
|
if (lastCursorPosRef.current) {
|
||||||
|
// Dispatching redux events impacts perf substantially - using brush spacing keeps dispatches to a reasonable number
|
||||||
|
if (Math.hypot(lastCursorPosRef.current[0] - pos.x, lastCursorPosRef.current[1] - pos.y) < BRUSH_SPACING) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
lastCursorPosRef.current = [pos.x, pos.y];
|
||||||
|
dispatch(maskLayerPointsAdded({ layerId: selectedLayerId, point: lastCursorPosRef.current }));
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[dispatch, selectedLayerId, tool]
|
||||||
|
);
|
||||||
|
|
||||||
|
const onMouseLeave = useCallback(
|
||||||
|
(e: KonvaEventObject<MouseEvent | TouchEvent>) => {
|
||||||
|
const stage = e.target.getStage();
|
||||||
|
if (!stage) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const pos = syncCursorPos(stage);
|
||||||
|
if (
|
||||||
|
pos &&
|
||||||
|
selectedLayerId &&
|
||||||
|
getIsFocused(stage) &&
|
||||||
|
$isMouseOver.get() &&
|
||||||
|
$isMouseDown.get() &&
|
||||||
|
(tool === 'brush' || tool === 'eraser')
|
||||||
|
) {
|
||||||
|
dispatch(maskLayerPointsAdded({ layerId: selectedLayerId, point: [pos.x, pos.y] }));
|
||||||
|
}
|
||||||
|
$isMouseOver.set(false);
|
||||||
|
$isMouseDown.set(false);
|
||||||
|
$cursorPosition.set(null);
|
||||||
|
},
|
||||||
|
[selectedLayerId, tool, dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const onMouseEnter = useCallback(
|
||||||
|
(e: KonvaEventObject<MouseEvent>) => {
|
||||||
|
const stage = e.target.getStage();
|
||||||
|
if (!stage) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
$isMouseOver.set(true);
|
||||||
|
const pos = syncCursorPos(stage);
|
||||||
|
if (!pos) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!getIsFocused(stage)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (e.evt.buttons !== 1) {
|
||||||
|
$isMouseDown.set(false);
|
||||||
|
} else {
|
||||||
|
$isMouseDown.set(true);
|
||||||
|
if (!selectedLayerId) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (tool === 'brush' || tool === 'eraser') {
|
||||||
|
dispatch(
|
||||||
|
maskLayerLineAdded({
|
||||||
|
layerId: selectedLayerId,
|
||||||
|
points: [pos.x, pos.y, pos.x, pos.y],
|
||||||
|
tool,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[dispatch, selectedLayerId, tool]
|
||||||
|
);
|
||||||
|
|
||||||
|
const onMouseWheel = useCallback(
|
||||||
|
(e: KonvaEventObject<WheelEvent>) => {
|
||||||
|
e.evt.preventDefault();
|
||||||
|
|
||||||
|
// checking for ctrl key is pressed or not,
|
||||||
|
// so that brush size can be controlled using ctrl + scroll up/down
|
||||||
|
|
||||||
|
// Invert the delta if the property is set to true
|
||||||
|
let delta = e.evt.deltaY;
|
||||||
|
if (shouldInvertBrushSizeScrollDirection) {
|
||||||
|
delta = -delta;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ($ctrl.get() || $meta.get()) {
|
||||||
|
dispatch(brushSizeChanged(calculateNewBrushSize(brushSize, delta)));
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[shouldInvertBrushSizeScrollDirection, brushSize, dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
return { onMouseDown, onMouseUp, onMouseMove, onMouseEnter, onMouseLeave, onMouseWheel };
|
||||||
|
};
|
@ -0,0 +1,30 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import { useMemo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
const selectValidLayerCount = createSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
|
||||||
|
if (!regionalPrompts.present.isEnabled) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
const validLayers = regionalPrompts.present.layers
|
||||||
|
.filter((l) => l.isVisible)
|
||||||
|
.filter((l) => {
|
||||||
|
const hasTextPrompt = Boolean(l.positivePrompt || l.negativePrompt);
|
||||||
|
const hasAtLeastOneImagePrompt = l.ipAdapterIds.length > 0;
|
||||||
|
return hasTextPrompt || hasAtLeastOneImagePrompt;
|
||||||
|
});
|
||||||
|
|
||||||
|
return validLayers.length;
|
||||||
|
});
|
||||||
|
|
||||||
|
export const useRegionalControlTitle = () => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const validLayerCount = useAppSelector(selectValidLayerCount);
|
||||||
|
const title = useMemo(() => {
|
||||||
|
const suffix = validLayerCount > 0 ? ` (${validLayerCount})` : '';
|
||||||
|
return `${t('regionalPrompts.regionalControl')}${suffix}`;
|
||||||
|
}, [t, validLayerCount]);
|
||||||
|
return title;
|
||||||
|
};
|
@ -0,0 +1,496 @@
|
|||||||
|
import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit';
|
||||||
|
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||||
|
import type { PersistConfig, RootState } from 'app/store/store';
|
||||||
|
import { moveBackward, moveForward, moveToBack, moveToFront } from 'common/util/arrayUtils';
|
||||||
|
import { controlAdapterRemoved } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
|
import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas';
|
||||||
|
import type { IRect, Vector2d } from 'konva/lib/types';
|
||||||
|
import { isEqual } from 'lodash-es';
|
||||||
|
import { atom } from 'nanostores';
|
||||||
|
import type { RgbColor } from 'react-colorful';
|
||||||
|
import type { UndoableOptions } from 'redux-undo';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
|
||||||
|
type DrawingTool = 'brush' | 'eraser';
|
||||||
|
|
||||||
|
export type Tool = DrawingTool | 'move' | 'rect';
|
||||||
|
|
||||||
|
export type VectorMaskLine = {
|
||||||
|
id: string;
|
||||||
|
type: 'vector_mask_line';
|
||||||
|
tool: DrawingTool;
|
||||||
|
strokeWidth: number;
|
||||||
|
points: number[];
|
||||||
|
};
|
||||||
|
|
||||||
|
export type VectorMaskRect = {
|
||||||
|
id: string;
|
||||||
|
type: 'vector_mask_rect';
|
||||||
|
x: number;
|
||||||
|
y: number;
|
||||||
|
width: number;
|
||||||
|
height: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
type LayerBase = {
|
||||||
|
id: string;
|
||||||
|
x: number;
|
||||||
|
y: number;
|
||||||
|
bbox: IRect | null;
|
||||||
|
bboxNeedsUpdate: boolean;
|
||||||
|
isVisible: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
type MaskLayerBase = LayerBase & {
|
||||||
|
positivePrompt: string | null;
|
||||||
|
negativePrompt: string | null; // Up to one text prompt per mask
|
||||||
|
ipAdapterIds: string[]; // Any number of image prompts
|
||||||
|
previewColor: RgbColor;
|
||||||
|
autoNegative: ParameterAutoNegative;
|
||||||
|
needsPixelBbox: boolean; // Needs the slower pixel-based bbox calculation - set to true when an there is an eraser object
|
||||||
|
};
|
||||||
|
|
||||||
|
export type VectorMaskLayer = MaskLayerBase & {
|
||||||
|
type: 'vector_mask_layer';
|
||||||
|
objects: (VectorMaskLine | VectorMaskRect)[];
|
||||||
|
};
|
||||||
|
|
||||||
|
export type Layer = VectorMaskLayer;
|
||||||
|
|
||||||
|
type RegionalPromptsState = {
|
||||||
|
_version: 1;
|
||||||
|
selectedLayerId: string | null;
|
||||||
|
layers: Layer[];
|
||||||
|
brushSize: number;
|
||||||
|
globalMaskLayerOpacity: number;
|
||||||
|
isEnabled: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const initialRegionalPromptsState: RegionalPromptsState = {
|
||||||
|
_version: 1,
|
||||||
|
selectedLayerId: null,
|
||||||
|
brushSize: 100,
|
||||||
|
layers: [],
|
||||||
|
globalMaskLayerOpacity: 0.5, // this globally changes all mask layers' opacity
|
||||||
|
isEnabled: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
const isLine = (obj: VectorMaskLine | VectorMaskRect): obj is VectorMaskLine => obj.type === 'vector_mask_line';
|
||||||
|
export const isVectorMaskLayer = (layer?: Layer): layer is VectorMaskLayer => layer?.type === 'vector_mask_layer';
|
||||||
|
const resetLayer = (layer: VectorMaskLayer) => {
|
||||||
|
layer.objects = [];
|
||||||
|
layer.bbox = null;
|
||||||
|
layer.isVisible = true;
|
||||||
|
layer.needsPixelBbox = false;
|
||||||
|
layer.bboxNeedsUpdate = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const regionalPromptsSlice = createSlice({
|
||||||
|
name: 'regionalPrompts',
|
||||||
|
initialState: initialRegionalPromptsState,
|
||||||
|
reducers: {
|
||||||
|
//#region All Layers
|
||||||
|
layerAdded: {
|
||||||
|
reducer: (state, action: PayloadAction<Layer['type'], string, { uuid: string }>) => {
|
||||||
|
const kind = action.payload;
|
||||||
|
if (action.payload === 'vector_mask_layer') {
|
||||||
|
const lastColor = state.layers[state.layers.length - 1]?.previewColor;
|
||||||
|
const previewColor = LayerColors.next(lastColor);
|
||||||
|
const layer: VectorMaskLayer = {
|
||||||
|
id: getVectorMaskLayerId(action.meta.uuid),
|
||||||
|
type: kind,
|
||||||
|
isVisible: true,
|
||||||
|
bbox: null,
|
||||||
|
bboxNeedsUpdate: false,
|
||||||
|
objects: [],
|
||||||
|
previewColor,
|
||||||
|
x: 0,
|
||||||
|
y: 0,
|
||||||
|
autoNegative: 'invert',
|
||||||
|
needsPixelBbox: false,
|
||||||
|
positivePrompt: '',
|
||||||
|
negativePrompt: null,
|
||||||
|
ipAdapterIds: [],
|
||||||
|
};
|
||||||
|
state.layers.push(layer);
|
||||||
|
state.selectedLayerId = layer.id;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
prepare: (payload: Layer['type']) => ({ payload, meta: { uuid: uuidv4() } }),
|
||||||
|
},
|
||||||
|
layerSelected: (state, action: PayloadAction<string>) => {
|
||||||
|
const layer = state.layers.find((l) => l.id === action.payload);
|
||||||
|
if (layer) {
|
||||||
|
state.selectedLayerId = layer.id;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
layerVisibilityToggled: (state, action: PayloadAction<string>) => {
|
||||||
|
const layer = state.layers.find((l) => l.id === action.payload);
|
||||||
|
if (layer) {
|
||||||
|
layer.isVisible = !layer.isVisible;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
layerTranslated: (state, action: PayloadAction<{ layerId: string; x: number; y: number }>) => {
|
||||||
|
const { layerId, x, y } = action.payload;
|
||||||
|
const layer = state.layers.find((l) => l.id === layerId);
|
||||||
|
if (layer) {
|
||||||
|
layer.x = x;
|
||||||
|
layer.y = y;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
layerBboxChanged: (state, action: PayloadAction<{ layerId: string; bbox: IRect | null }>) => {
|
||||||
|
const { layerId, bbox } = action.payload;
|
||||||
|
const layer = state.layers.find((l) => l.id === layerId);
|
||||||
|
if (layer) {
|
||||||
|
layer.bbox = bbox;
|
||||||
|
layer.bboxNeedsUpdate = false;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
layerReset: (state, action: PayloadAction<string>) => {
|
||||||
|
const layer = state.layers.find((l) => l.id === action.payload);
|
||||||
|
if (layer) {
|
||||||
|
resetLayer(layer);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
layerDeleted: (state, action: PayloadAction<string>) => {
|
||||||
|
state.layers = state.layers.filter((l) => l.id !== action.payload);
|
||||||
|
state.selectedLayerId = state.layers[0]?.id ?? null;
|
||||||
|
},
|
||||||
|
layerMovedForward: (state, action: PayloadAction<string>) => {
|
||||||
|
const cb = (l: Layer) => l.id === action.payload;
|
||||||
|
moveForward(state.layers, cb);
|
||||||
|
},
|
||||||
|
layerMovedToFront: (state, action: PayloadAction<string>) => {
|
||||||
|
const cb = (l: Layer) => l.id === action.payload;
|
||||||
|
// Because the layers are in reverse order, moving to the front is equivalent to moving to the back
|
||||||
|
moveToBack(state.layers, cb);
|
||||||
|
},
|
||||||
|
layerMovedBackward: (state, action: PayloadAction<string>) => {
|
||||||
|
const cb = (l: Layer) => l.id === action.payload;
|
||||||
|
moveBackward(state.layers, cb);
|
||||||
|
},
|
||||||
|
layerMovedToBack: (state, action: PayloadAction<string>) => {
|
||||||
|
const cb = (l: Layer) => l.id === action.payload;
|
||||||
|
// Because the layers are in reverse order, moving to the back is equivalent to moving to the front
|
||||||
|
moveToFront(state.layers, cb);
|
||||||
|
},
|
||||||
|
allLayersDeleted: (state) => {
|
||||||
|
state.layers = [];
|
||||||
|
state.selectedLayerId = null;
|
||||||
|
},
|
||||||
|
selectedLayerReset: (state) => {
|
||||||
|
const layer = state.layers.find((l) => l.id === state.selectedLayerId);
|
||||||
|
if (layer) {
|
||||||
|
resetLayer(layer);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
selectedLayerDeleted: (state) => {
|
||||||
|
state.layers = state.layers.filter((l) => l.id !== state.selectedLayerId);
|
||||||
|
state.selectedLayerId = state.layers[0]?.id ?? null;
|
||||||
|
},
|
||||||
|
//#endregion
|
||||||
|
|
||||||
|
//#region Mask Layers
|
||||||
|
maskLayerPositivePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => {
|
||||||
|
const { layerId, prompt } = action.payload;
|
||||||
|
const layer = state.layers.find((l) => l.id === layerId);
|
||||||
|
if (layer) {
|
||||||
|
layer.positivePrompt = prompt;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
maskLayerNegativePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => {
|
||||||
|
const { layerId, prompt } = action.payload;
|
||||||
|
const layer = state.layers.find((l) => l.id === layerId);
|
||||||
|
if (layer) {
|
||||||
|
layer.negativePrompt = prompt;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
maskLayerIPAdapterAdded: {
|
||||||
|
reducer: (state, action: PayloadAction<string, string, { uuid: string }>) => {
|
||||||
|
const layer = state.layers.find((l) => l.id === action.payload);
|
||||||
|
if (layer) {
|
||||||
|
layer.ipAdapterIds.push(action.meta.uuid);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
prepare: (payload: string) => ({ payload, meta: { uuid: uuidv4() } }),
|
||||||
|
},
|
||||||
|
maskLayerPreviewColorChanged: (state, action: PayloadAction<{ layerId: string; color: RgbColor }>) => {
|
||||||
|
const { layerId, color } = action.payload;
|
||||||
|
const layer = state.layers.find((l) => l.id === layerId);
|
||||||
|
if (layer) {
|
||||||
|
layer.previewColor = color;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
maskLayerLineAdded: {
|
||||||
|
reducer: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<
|
||||||
|
{ layerId: string; points: [number, number, number, number]; tool: DrawingTool },
|
||||||
|
string,
|
||||||
|
{ uuid: string }
|
||||||
|
>
|
||||||
|
) => {
|
||||||
|
const { layerId, points, tool } = action.payload;
|
||||||
|
const layer = state.layers.find((l) => l.id === layerId);
|
||||||
|
if (layer) {
|
||||||
|
const lineId = getVectorMaskLayerLineId(layer.id, action.meta.uuid);
|
||||||
|
layer.objects.push({
|
||||||
|
type: 'vector_mask_line',
|
||||||
|
tool: tool,
|
||||||
|
id: lineId,
|
||||||
|
// Points must be offset by the layer's x and y coordinates
|
||||||
|
// TODO: Handle this in the event listener?
|
||||||
|
points: [points[0] - layer.x, points[1] - layer.y, points[2] - layer.x, points[3] - layer.y],
|
||||||
|
strokeWidth: state.brushSize,
|
||||||
|
});
|
||||||
|
layer.bboxNeedsUpdate = true;
|
||||||
|
if (!layer.needsPixelBbox && tool === 'eraser') {
|
||||||
|
layer.needsPixelBbox = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
prepare: (payload: { layerId: string; points: [number, number, number, number]; tool: DrawingTool }) => ({
|
||||||
|
payload,
|
||||||
|
meta: { uuid: uuidv4() },
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
maskLayerPointsAdded: (state, action: PayloadAction<{ layerId: string; point: [number, number] }>) => {
|
||||||
|
const { layerId, point } = action.payload;
|
||||||
|
const layer = state.layers.find((l) => l.id === layerId);
|
||||||
|
if (layer) {
|
||||||
|
const lastLine = layer.objects.findLast(isLine);
|
||||||
|
if (!lastLine) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Points must be offset by the layer's x and y coordinates
|
||||||
|
// TODO: Handle this in the event listener
|
||||||
|
lastLine.points.push(point[0] - layer.x, point[1] - layer.y);
|
||||||
|
layer.bboxNeedsUpdate = true;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
maskLayerRectAdded: {
|
||||||
|
reducer: (state, action: PayloadAction<{ layerId: string; rect: IRect }, string, { uuid: string }>) => {
|
||||||
|
const { layerId, rect } = action.payload;
|
||||||
|
if (rect.height === 0 || rect.width === 0) {
|
||||||
|
// Ignore zero-area rectangles
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const layer = state.layers.find((l) => l.id === layerId);
|
||||||
|
if (layer) {
|
||||||
|
const id = getVectorMaskLayerRectId(layer.id, action.meta.uuid);
|
||||||
|
layer.objects.push({
|
||||||
|
type: 'vector_mask_rect',
|
||||||
|
id,
|
||||||
|
x: rect.x - layer.x,
|
||||||
|
y: rect.y - layer.y,
|
||||||
|
width: rect.width,
|
||||||
|
height: rect.height,
|
||||||
|
});
|
||||||
|
layer.bboxNeedsUpdate = true;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
prepare: (payload: { layerId: string; rect: IRect }) => ({ payload, meta: { uuid: uuidv4() } }),
|
||||||
|
},
|
||||||
|
maskLayerAutoNegativeChanged: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<{ layerId: string; autoNegative: ParameterAutoNegative }>
|
||||||
|
) => {
|
||||||
|
const { layerId, autoNegative } = action.payload;
|
||||||
|
const layer = state.layers.find((l) => l.id === layerId);
|
||||||
|
if (layer) {
|
||||||
|
layer.autoNegative = autoNegative;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
//#endregion
|
||||||
|
|
||||||
|
//#region General
|
||||||
|
brushSizeChanged: (state, action: PayloadAction<number>) => {
|
||||||
|
state.brushSize = action.payload;
|
||||||
|
},
|
||||||
|
globalMaskLayerOpacityChanged: (state, action: PayloadAction<number>) => {
|
||||||
|
state.globalMaskLayerOpacity = action.payload;
|
||||||
|
},
|
||||||
|
isEnabledChanged: (state, action: PayloadAction<boolean>) => {
|
||||||
|
state.isEnabled = action.payload;
|
||||||
|
},
|
||||||
|
undo: (state) => {
|
||||||
|
// Invalidate the bbox for all layers to prevent stale bboxes
|
||||||
|
for (const layer of state.layers) {
|
||||||
|
layer.bboxNeedsUpdate = true;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
redo: (state) => {
|
||||||
|
// Invalidate the bbox for all layers to prevent stale bboxes
|
||||||
|
for (const layer of state.layers) {
|
||||||
|
layer.bboxNeedsUpdate = true;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
//#endregion
|
||||||
|
},
|
||||||
|
extraReducers(builder) {
|
||||||
|
builder.addCase(controlAdapterRemoved, (state, action) => {
|
||||||
|
for (const layer of state.layers) {
|
||||||
|
layer.ipAdapterIds = layer.ipAdapterIds.filter((id) => id !== action.payload.id);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This class is used to cycle through a set of colors for the prompt region layers.
|
||||||
|
*/
|
||||||
|
class LayerColors {
|
||||||
|
static COLORS: RgbColor[] = [
|
||||||
|
{ r: 121, g: 157, b: 219 }, // rgb(121, 157, 219)
|
||||||
|
{ r: 131, g: 214, b: 131 }, // rgb(131, 214, 131)
|
||||||
|
{ r: 250, g: 225, b: 80 }, // rgb(250, 225, 80)
|
||||||
|
{ r: 220, g: 144, b: 101 }, // rgb(220, 144, 101)
|
||||||
|
{ r: 224, g: 117, b: 117 }, // rgb(224, 117, 117)
|
||||||
|
{ r: 213, g: 139, b: 202 }, // rgb(213, 139, 202)
|
||||||
|
{ r: 161, g: 120, b: 214 }, // rgb(161, 120, 214)
|
||||||
|
];
|
||||||
|
static i = this.COLORS.length - 1;
|
||||||
|
/**
|
||||||
|
* Get the next color in the sequence. If a known color is provided, the next color will be the one after it.
|
||||||
|
*/
|
||||||
|
static next(currentColor?: RgbColor): RgbColor {
|
||||||
|
if (currentColor) {
|
||||||
|
const i = this.COLORS.findIndex((c) => isEqual(c, currentColor));
|
||||||
|
if (i !== -1) {
|
||||||
|
this.i = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
this.i = (this.i + 1) % this.COLORS.length;
|
||||||
|
const color = this.COLORS[this.i];
|
||||||
|
assert(color);
|
||||||
|
return color;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export const {
|
||||||
|
// All layer actions
|
||||||
|
layerAdded,
|
||||||
|
layerDeleted,
|
||||||
|
layerMovedBackward,
|
||||||
|
layerMovedForward,
|
||||||
|
layerMovedToBack,
|
||||||
|
layerMovedToFront,
|
||||||
|
layerReset,
|
||||||
|
layerSelected,
|
||||||
|
layerTranslated,
|
||||||
|
layerBboxChanged,
|
||||||
|
layerVisibilityToggled,
|
||||||
|
allLayersDeleted,
|
||||||
|
selectedLayerReset,
|
||||||
|
selectedLayerDeleted,
|
||||||
|
// Mask layer actions
|
||||||
|
maskLayerLineAdded,
|
||||||
|
maskLayerPointsAdded,
|
||||||
|
maskLayerRectAdded,
|
||||||
|
maskLayerNegativePromptChanged,
|
||||||
|
maskLayerPositivePromptChanged,
|
||||||
|
maskLayerIPAdapterAdded,
|
||||||
|
maskLayerAutoNegativeChanged,
|
||||||
|
maskLayerPreviewColorChanged,
|
||||||
|
// General actions
|
||||||
|
brushSizeChanged,
|
||||||
|
globalMaskLayerOpacityChanged,
|
||||||
|
undo,
|
||||||
|
redo,
|
||||||
|
} = regionalPromptsSlice.actions;
|
||||||
|
|
||||||
|
export const selectRegionalPromptsSlice = (state: RootState) => state.regionalPrompts;
|
||||||
|
|
||||||
|
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||||
|
const migrateRegionalPromptsState = (state: any): any => {
|
||||||
|
return state;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const $isMouseDown = atom(false);
|
||||||
|
export const $isMouseOver = atom(false);
|
||||||
|
export const $lastMouseDownPos = atom<Vector2d | null>(null);
|
||||||
|
export const $tool = atom<Tool>('brush');
|
||||||
|
export const $cursorPosition = atom<Vector2d | null>(null);
|
||||||
|
|
||||||
|
// IDs for singleton Konva layers and objects
|
||||||
|
export const TOOL_PREVIEW_LAYER_ID = 'tool_preview_layer';
|
||||||
|
export const TOOL_PREVIEW_BRUSH_GROUP_ID = 'tool_preview_layer.brush_group';
|
||||||
|
export const TOOL_PREVIEW_BRUSH_FILL_ID = 'tool_preview_layer.brush_fill';
|
||||||
|
export const TOOL_PREVIEW_BRUSH_BORDER_INNER_ID = 'tool_preview_layer.brush_border_inner';
|
||||||
|
export const TOOL_PREVIEW_BRUSH_BORDER_OUTER_ID = 'tool_preview_layer.brush_border_outer';
|
||||||
|
export const TOOL_PREVIEW_RECT_ID = 'tool_preview_layer.rect';
|
||||||
|
export const BACKGROUND_LAYER_ID = 'background_layer';
|
||||||
|
export const BACKGROUND_RECT_ID = 'background_layer.rect';
|
||||||
|
|
||||||
|
// Names (aka classes) for Konva layers and objects
|
||||||
|
export const VECTOR_MASK_LAYER_NAME = 'vector_mask_layer';
|
||||||
|
export const VECTOR_MASK_LAYER_LINE_NAME = 'vector_mask_layer.line';
|
||||||
|
export const VECTOR_MASK_LAYER_OBJECT_GROUP_NAME = 'vector_mask_layer.object_group';
|
||||||
|
export const VECTOR_MASK_LAYER_RECT_NAME = 'vector_mask_layer.rect';
|
||||||
|
export const LAYER_BBOX_NAME = 'layer.bbox';
|
||||||
|
|
||||||
|
// Getters for non-singleton layer and object IDs
|
||||||
|
const getVectorMaskLayerId = (layerId: string) => `${VECTOR_MASK_LAYER_NAME}_${layerId}`;
|
||||||
|
const getVectorMaskLayerLineId = (layerId: string, lineId: string) => `${layerId}.line_${lineId}`;
|
||||||
|
const getVectorMaskLayerRectId = (layerId: string, lineId: string) => `${layerId}.rect_${lineId}`;
|
||||||
|
export const getVectorMaskLayerObjectGroupId = (layerId: string, groupId: string) =>
|
||||||
|
`${layerId}.objectGroup_${groupId}`;
|
||||||
|
export const getLayerBboxId = (layerId: string) => `${layerId}.bbox`;
|
||||||
|
|
||||||
|
export const regionalPromptsPersistConfig: PersistConfig<RegionalPromptsState> = {
|
||||||
|
name: regionalPromptsSlice.name,
|
||||||
|
initialState: initialRegionalPromptsState,
|
||||||
|
migrate: migrateRegionalPromptsState,
|
||||||
|
persistDenylist: [],
|
||||||
|
};
|
||||||
|
|
||||||
|
// These actions are _individually_ grouped together as single undoable actions
|
||||||
|
const undoableGroupByMatcher = isAnyOf(
|
||||||
|
layerTranslated,
|
||||||
|
brushSizeChanged,
|
||||||
|
globalMaskLayerOpacityChanged,
|
||||||
|
maskLayerPositivePromptChanged,
|
||||||
|
maskLayerNegativePromptChanged,
|
||||||
|
maskLayerPreviewColorChanged
|
||||||
|
);
|
||||||
|
|
||||||
|
// These are used to group actions into logical lines below (hate typos)
|
||||||
|
const LINE_1 = 'LINE_1';
|
||||||
|
const LINE_2 = 'LINE_2';
|
||||||
|
|
||||||
|
export const regionalPromptsUndoableConfig: UndoableOptions<RegionalPromptsState, UnknownAction> = {
|
||||||
|
limit: 64,
|
||||||
|
undoType: regionalPromptsSlice.actions.undo.type,
|
||||||
|
redoType: regionalPromptsSlice.actions.redo.type,
|
||||||
|
groupBy: (action, state, history) => {
|
||||||
|
// Lines are started with `maskLayerLineAdded` and may have any number of subsequent `maskLayerPointsAdded` events.
|
||||||
|
// We can use a double-buffer-esque trick to group each "logical" line as a single undoable action, without grouping
|
||||||
|
// separate logical lines as a single undo action.
|
||||||
|
if (maskLayerLineAdded.match(action)) {
|
||||||
|
return history.group === LINE_1 ? LINE_2 : LINE_1;
|
||||||
|
}
|
||||||
|
if (maskLayerPointsAdded.match(action)) {
|
||||||
|
if (history.group === LINE_1 || history.group === LINE_2) {
|
||||||
|
return history.group;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (undoableGroupByMatcher(action)) {
|
||||||
|
return action.type;
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
},
|
||||||
|
filter: (action, _state, _history) => {
|
||||||
|
// Ignore all actions from other slices
|
||||||
|
if (!action.type.startsWith(regionalPromptsSlice.name)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// This action is triggered on state changes, including when we undo. If we do not ignore this action, when we
|
||||||
|
// undo, this action triggers and empties the future states array. Therefore, we must ignore this action.
|
||||||
|
if (layerBboxChanged.match(action)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
},
|
||||||
|
};
|
134
invokeai/frontend/web/src/features/regionalPrompts/util/bbox.ts
Normal file
134
invokeai/frontend/web/src/features/regionalPrompts/util/bbox.ts
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
||||||
|
import { imageDataToDataURL } from 'features/canvas/util/blobToDataURL';
|
||||||
|
import { VECTOR_MASK_LAYER_OBJECT_GROUP_NAME } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import Konva from 'konva';
|
||||||
|
import type { Layer as KonvaLayerType } from 'konva/lib/Layer';
|
||||||
|
import type { IRect } from 'konva/lib/types';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
|
const GET_CLIENT_RECT_CONFIG = { skipTransform: true };
|
||||||
|
|
||||||
|
type Extents = {
|
||||||
|
minX: number;
|
||||||
|
minY: number;
|
||||||
|
maxX: number;
|
||||||
|
maxY: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the bounding box of an image.
|
||||||
|
* @param imageData The ImageData object to get the bounding box of.
|
||||||
|
* @returns The minimum and maximum x and y values of the image's bounding box.
|
||||||
|
*/
|
||||||
|
const getImageDataBbox = (imageData: ImageData): Extents | null => {
|
||||||
|
const { data, width, height } = imageData;
|
||||||
|
let minX = width;
|
||||||
|
let minY = height;
|
||||||
|
let maxX = -1;
|
||||||
|
let maxY = -1;
|
||||||
|
let alpha = 0;
|
||||||
|
let isEmpty = true;
|
||||||
|
|
||||||
|
for (let y = 0; y < height; y++) {
|
||||||
|
for (let x = 0; x < width; x++) {
|
||||||
|
alpha = data[(y * width + x) * 4 + 3] ?? 0;
|
||||||
|
if (alpha > 0) {
|
||||||
|
isEmpty = false;
|
||||||
|
if (x < minX) {
|
||||||
|
minX = x;
|
||||||
|
}
|
||||||
|
if (x > maxX) {
|
||||||
|
maxX = x;
|
||||||
|
}
|
||||||
|
if (y < minY) {
|
||||||
|
minY = y;
|
||||||
|
}
|
||||||
|
if (y > maxY) {
|
||||||
|
maxY = y;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return isEmpty ? null : { minX, minY, maxX, maxY };
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the bounding box of a regional prompt konva layer. This function has special handling for regional prompt layers.
|
||||||
|
* @param layer The konva layer to get the bounding box of.
|
||||||
|
* @param preview Whether to open a new tab displaying the rendered layer, which is used to calculate the bbox.
|
||||||
|
*/
|
||||||
|
export const getLayerBboxPixels = (layer: KonvaLayerType, preview: boolean = false): IRect | null => {
|
||||||
|
// To calculate the layer's bounding box, we must first export it to a pixel array, then do some math.
|
||||||
|
//
|
||||||
|
// Though it is relatively fast, we can't use Konva's `getClientRect`. It programmatically determines the rect
|
||||||
|
// by calculating the extents of individual shapes from their "vector" shape data.
|
||||||
|
//
|
||||||
|
// This doesn't work when some shapes are drawn with composite operations that "erase" pixels, like eraser lines.
|
||||||
|
// These shapes' extents are still calculated as if they were solid, leading to a bounding box that is too large.
|
||||||
|
const stage = layer.getStage();
|
||||||
|
|
||||||
|
// Construct and offscreen canvas on which we will do the bbox calculations.
|
||||||
|
const offscreenStageContainer = document.createElement('div');
|
||||||
|
const offscreenStage = new Konva.Stage({
|
||||||
|
container: offscreenStageContainer,
|
||||||
|
width: stage.width(),
|
||||||
|
height: stage.height(),
|
||||||
|
});
|
||||||
|
|
||||||
|
// Clone the layer and filter out unwanted children.
|
||||||
|
const layerClone = layer.clone();
|
||||||
|
offscreenStage.add(layerClone);
|
||||||
|
|
||||||
|
for (const child of layerClone.getChildren()) {
|
||||||
|
if (child.name() === VECTOR_MASK_LAYER_OBJECT_GROUP_NAME) {
|
||||||
|
// We need to cache the group to ensure it composites out eraser strokes correctly
|
||||||
|
child.opacity(1);
|
||||||
|
child.cache();
|
||||||
|
} else {
|
||||||
|
// Filter out unwanted children.
|
||||||
|
child.destroy();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get a worst-case rect using the relatively fast `getClientRect`.
|
||||||
|
const layerRect = layerClone.getClientRect();
|
||||||
|
|
||||||
|
// Capture the image data with the above rect.
|
||||||
|
const layerImageData = offscreenStage
|
||||||
|
.toCanvas(layerRect)
|
||||||
|
.getContext('2d')
|
||||||
|
?.getImageData(0, 0, layerRect.width, layerRect.height);
|
||||||
|
assert(layerImageData, "Unable to get layer's image data");
|
||||||
|
|
||||||
|
if (preview) {
|
||||||
|
openBase64ImageInTab([{ base64: imageDataToDataURL(layerImageData), caption: layer.id() }]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the layer's bounding box.
|
||||||
|
const layerBbox = getImageDataBbox(layerImageData);
|
||||||
|
|
||||||
|
if (!layerBbox) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Correct the bounding box to be relative to the layer's position.
|
||||||
|
const correctedLayerBbox = {
|
||||||
|
x: layerBbox.minX - Math.floor(stage.x()) + layerRect.x - Math.floor(layer.x()),
|
||||||
|
y: layerBbox.minY - Math.floor(stage.y()) + layerRect.y - Math.floor(layer.y()),
|
||||||
|
width: layerBbox.maxX - layerBbox.minX,
|
||||||
|
height: layerBbox.maxY - layerBbox.minY,
|
||||||
|
};
|
||||||
|
|
||||||
|
return correctedLayerBbox;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const getLayerBboxFast = (layer: KonvaLayerType): IRect | null => {
|
||||||
|
const bbox = layer.getClientRect(GET_CLIENT_RECT_CONFIG);
|
||||||
|
return {
|
||||||
|
x: Math.floor(bbox.x),
|
||||||
|
y: Math.floor(bbox.y),
|
||||||
|
width: Math.floor(bbox.width),
|
||||||
|
height: Math.floor(bbox.height),
|
||||||
|
};
|
||||||
|
};
|
@ -0,0 +1,64 @@
|
|||||||
|
import { getStore } from 'app/store/nanostores/store';
|
||||||
|
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
||||||
|
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
|
||||||
|
import { VECTOR_MASK_LAYER_NAME } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import { renderers } from 'features/regionalPrompts/util/renderers';
|
||||||
|
import Konva from 'konva';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the blobs of all regional prompt layers. Only visible layers are returned.
|
||||||
|
* @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
|
||||||
|
* @param preview Whether to open a new tab displaying each layer.
|
||||||
|
* @returns A map of layer IDs to blobs.
|
||||||
|
*/
|
||||||
|
export const getRegionalPromptLayerBlobs = async (
|
||||||
|
layerIds?: string[],
|
||||||
|
preview: boolean = false
|
||||||
|
): Promise<Record<string, Blob>> => {
|
||||||
|
const state = getStore().getState();
|
||||||
|
const reduxLayers = state.regionalPrompts.present.layers;
|
||||||
|
const container = document.createElement('div');
|
||||||
|
const stage = new Konva.Stage({ container, width: state.generation.width, height: state.generation.height });
|
||||||
|
renderers.renderLayers(stage, reduxLayers, 1, 'brush');
|
||||||
|
|
||||||
|
const konvaLayers = stage.find<Konva.Layer>(`.${VECTOR_MASK_LAYER_NAME}`);
|
||||||
|
const blobs: Record<string, Blob> = {};
|
||||||
|
|
||||||
|
// First remove all layers
|
||||||
|
for (const layer of konvaLayers) {
|
||||||
|
layer.remove();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next render each layer to a blob
|
||||||
|
for (const layer of konvaLayers) {
|
||||||
|
if (layerIds && !layerIds.includes(layer.id())) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const reduxLayer = reduxLayers.find((l) => l.id === layer.id());
|
||||||
|
assert(reduxLayer, `Redux layer ${layer.id()} not found`);
|
||||||
|
stage.add(layer);
|
||||||
|
const blob = await new Promise<Blob>((resolve) => {
|
||||||
|
stage.toBlob({
|
||||||
|
callback: (blob) => {
|
||||||
|
assert(blob, 'Blob is null');
|
||||||
|
resolve(blob);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
if (preview) {
|
||||||
|
const base64 = await blobToDataURL(blob);
|
||||||
|
openBase64ImageInTab([
|
||||||
|
{
|
||||||
|
base64,
|
||||||
|
caption: `${reduxLayer.id}: ${reduxLayer.positivePrompt} / ${reduxLayer.negativePrompt}`,
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
layer.remove();
|
||||||
|
blobs[layer.id()] = blob;
|
||||||
|
}
|
||||||
|
|
||||||
|
return blobs;
|
||||||
|
};
|
@ -0,0 +1,619 @@
|
|||||||
|
import { getStore } from 'app/store/nanostores/store';
|
||||||
|
import { rgbaColorToString, rgbColorToString } from 'features/canvas/util/colorToString';
|
||||||
|
import { getScaledFlooredCursorPosition } from 'features/regionalPrompts/hooks/mouseEventHooks';
|
||||||
|
import type {
|
||||||
|
Layer,
|
||||||
|
Tool,
|
||||||
|
VectorMaskLayer,
|
||||||
|
VectorMaskLine,
|
||||||
|
VectorMaskRect,
|
||||||
|
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import {
|
||||||
|
$tool,
|
||||||
|
BACKGROUND_LAYER_ID,
|
||||||
|
BACKGROUND_RECT_ID,
|
||||||
|
getLayerBboxId,
|
||||||
|
getVectorMaskLayerObjectGroupId,
|
||||||
|
isVectorMaskLayer,
|
||||||
|
LAYER_BBOX_NAME,
|
||||||
|
TOOL_PREVIEW_BRUSH_BORDER_INNER_ID,
|
||||||
|
TOOL_PREVIEW_BRUSH_BORDER_OUTER_ID,
|
||||||
|
TOOL_PREVIEW_BRUSH_FILL_ID,
|
||||||
|
TOOL_PREVIEW_BRUSH_GROUP_ID,
|
||||||
|
TOOL_PREVIEW_LAYER_ID,
|
||||||
|
TOOL_PREVIEW_RECT_ID,
|
||||||
|
VECTOR_MASK_LAYER_LINE_NAME,
|
||||||
|
VECTOR_MASK_LAYER_NAME,
|
||||||
|
VECTOR_MASK_LAYER_OBJECT_GROUP_NAME,
|
||||||
|
VECTOR_MASK_LAYER_RECT_NAME,
|
||||||
|
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import { getLayerBboxFast, getLayerBboxPixels } from 'features/regionalPrompts/util/bbox';
|
||||||
|
import Konva from 'konva';
|
||||||
|
import type { IRect, Vector2d } from 'konva/lib/types';
|
||||||
|
import { debounce } from 'lodash-es';
|
||||||
|
import type { RgbColor } from 'react-colorful';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
|
||||||
|
const BBOX_SELECTED_STROKE = 'rgba(78, 190, 255, 1)';
|
||||||
|
const BBOX_NOT_SELECTED_STROKE = 'rgba(255, 255, 255, 0.353)';
|
||||||
|
const BBOX_NOT_SELECTED_MOUSEOVER_STROKE = 'rgba(255, 255, 255, 0.661)';
|
||||||
|
const BRUSH_BORDER_INNER_COLOR = 'rgba(0,0,0,1)';
|
||||||
|
const BRUSH_BORDER_OUTER_COLOR = 'rgba(255,255,255,0.8)';
|
||||||
|
// This is invokeai/frontend/web/public/assets/images/transparent_bg.png as a dataURL
|
||||||
|
const STAGE_BG_DATAURL =
|
||||||
|
'';
|
||||||
|
|
||||||
|
const mapId = (object: { id: string }) => object.id;
|
||||||
|
|
||||||
|
const getIsSelected = (layerId?: string | null) => {
|
||||||
|
if (!layerId) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return layerId === getStore().getState().regionalPrompts.present.selectedLayerId;
|
||||||
|
};
|
||||||
|
|
||||||
|
const selectVectorMaskObjects = (node: Konva.Node) => {
|
||||||
|
return node.name() === VECTOR_MASK_LAYER_LINE_NAME || node.name() === VECTOR_MASK_LAYER_RECT_NAME;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates the brush preview layer.
|
||||||
|
* @param stage The konva stage to render on.
|
||||||
|
* @returns The brush preview layer.
|
||||||
|
*/
|
||||||
|
const createToolPreviewLayer = (stage: Konva.Stage) => {
|
||||||
|
// Initialize the brush preview layer & add to the stage
|
||||||
|
const toolPreviewLayer = new Konva.Layer({ id: TOOL_PREVIEW_LAYER_ID, visible: false, listening: false });
|
||||||
|
stage.add(toolPreviewLayer);
|
||||||
|
|
||||||
|
// Add handlers to show/hide the brush preview layer
|
||||||
|
stage.on('mousemove', (e) => {
|
||||||
|
const tool = $tool.get();
|
||||||
|
e.target
|
||||||
|
.getStage()
|
||||||
|
?.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)
|
||||||
|
?.visible(tool === 'brush' || tool === 'eraser');
|
||||||
|
});
|
||||||
|
stage.on('mouseleave', (e) => {
|
||||||
|
e.target.getStage()?.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)?.visible(false);
|
||||||
|
});
|
||||||
|
stage.on('mouseenter', (e) => {
|
||||||
|
const tool = $tool.get();
|
||||||
|
e.target
|
||||||
|
.getStage()
|
||||||
|
?.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)
|
||||||
|
?.visible(tool === 'brush' || tool === 'eraser');
|
||||||
|
});
|
||||||
|
|
||||||
|
// Create the brush preview group & circles
|
||||||
|
const brushPreviewGroup = new Konva.Group({ id: TOOL_PREVIEW_BRUSH_GROUP_ID });
|
||||||
|
const brushPreviewFill = new Konva.Circle({
|
||||||
|
id: TOOL_PREVIEW_BRUSH_FILL_ID,
|
||||||
|
listening: false,
|
||||||
|
strokeEnabled: false,
|
||||||
|
});
|
||||||
|
brushPreviewGroup.add(brushPreviewFill);
|
||||||
|
const brushPreviewBorderInner = new Konva.Circle({
|
||||||
|
id: TOOL_PREVIEW_BRUSH_BORDER_INNER_ID,
|
||||||
|
listening: false,
|
||||||
|
stroke: BRUSH_BORDER_INNER_COLOR,
|
||||||
|
strokeWidth: 1,
|
||||||
|
strokeEnabled: true,
|
||||||
|
});
|
||||||
|
brushPreviewGroup.add(brushPreviewBorderInner);
|
||||||
|
const brushPreviewBorderOuter = new Konva.Circle({
|
||||||
|
id: TOOL_PREVIEW_BRUSH_BORDER_OUTER_ID,
|
||||||
|
listening: false,
|
||||||
|
stroke: BRUSH_BORDER_OUTER_COLOR,
|
||||||
|
strokeWidth: 1,
|
||||||
|
strokeEnabled: true,
|
||||||
|
});
|
||||||
|
brushPreviewGroup.add(brushPreviewBorderOuter);
|
||||||
|
toolPreviewLayer.add(brushPreviewGroup);
|
||||||
|
|
||||||
|
// Create the rect preview
|
||||||
|
const rectPreview = new Konva.Rect({ id: TOOL_PREVIEW_RECT_ID, listening: false, stroke: 'white', strokeWidth: 1 });
|
||||||
|
toolPreviewLayer.add(rectPreview);
|
||||||
|
|
||||||
|
return toolPreviewLayer;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Renders the brush preview for the selected tool.
|
||||||
|
* @param stage The konva stage to render on.
|
||||||
|
* @param tool The selected tool.
|
||||||
|
* @param color The selected layer's color.
|
||||||
|
* @param cursorPos The cursor position.
|
||||||
|
* @param lastMouseDownPos The position of the last mouse down event - used for the rect tool.
|
||||||
|
* @param brushSize The brush size.
|
||||||
|
*/
|
||||||
|
const renderToolPreview = (
|
||||||
|
stage: Konva.Stage,
|
||||||
|
tool: Tool,
|
||||||
|
color: RgbColor | null,
|
||||||
|
globalMaskLayerOpacity: number,
|
||||||
|
cursorPos: Vector2d | null,
|
||||||
|
lastMouseDownPos: Vector2d | null,
|
||||||
|
isMouseOver: boolean,
|
||||||
|
brushSize: number
|
||||||
|
) => {
|
||||||
|
const layerCount = stage.find(`.${VECTOR_MASK_LAYER_NAME}`).length;
|
||||||
|
// Update the stage's pointer style
|
||||||
|
if (layerCount === 0) {
|
||||||
|
// We have no layers, so we should not render any tool
|
||||||
|
stage.container().style.cursor = 'default';
|
||||||
|
} else if (tool === 'move') {
|
||||||
|
// Move tool gets a pointer
|
||||||
|
stage.container().style.cursor = 'default';
|
||||||
|
} else if (tool === 'rect') {
|
||||||
|
// Move rect gets a crosshair
|
||||||
|
stage.container().style.cursor = 'crosshair';
|
||||||
|
} else {
|
||||||
|
// Else we use the brush preview
|
||||||
|
stage.container().style.cursor = 'none';
|
||||||
|
}
|
||||||
|
|
||||||
|
const toolPreviewLayer = stage.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`) ?? createToolPreviewLayer(stage);
|
||||||
|
|
||||||
|
if (!isMouseOver || layerCount === 0) {
|
||||||
|
// We can bail early if the mouse isn't over the stage or there are no layers
|
||||||
|
toolPreviewLayer.visible(false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
toolPreviewLayer.visible(true);
|
||||||
|
|
||||||
|
const brushPreviewGroup = stage.findOne<Konva.Group>(`#${TOOL_PREVIEW_BRUSH_GROUP_ID}`);
|
||||||
|
assert(brushPreviewGroup, 'Brush preview group not found');
|
||||||
|
|
||||||
|
const rectPreview = stage.findOne<Konva.Rect>(`#${TOOL_PREVIEW_RECT_ID}`);
|
||||||
|
assert(rectPreview, 'Rect preview not found');
|
||||||
|
|
||||||
|
// No need to render the brush preview if the cursor position or color is missing
|
||||||
|
if (cursorPos && color && (tool === 'brush' || tool === 'eraser')) {
|
||||||
|
// Update the fill circle
|
||||||
|
const brushPreviewFill = brushPreviewGroup.findOne<Konva.Circle>(`#${TOOL_PREVIEW_BRUSH_FILL_ID}`);
|
||||||
|
brushPreviewFill?.setAttrs({
|
||||||
|
x: cursorPos.x,
|
||||||
|
y: cursorPos.y,
|
||||||
|
radius: brushSize / 2,
|
||||||
|
fill: rgbaColorToString({ ...color, a: globalMaskLayerOpacity }),
|
||||||
|
globalCompositeOperation: tool === 'brush' ? 'source-over' : 'destination-out',
|
||||||
|
});
|
||||||
|
|
||||||
|
// Update the inner border of the brush preview
|
||||||
|
const brushPreviewInner = toolPreviewLayer.findOne<Konva.Circle>(`#${TOOL_PREVIEW_BRUSH_BORDER_INNER_ID}`);
|
||||||
|
brushPreviewInner?.setAttrs({ x: cursorPos.x, y: cursorPos.y, radius: brushSize / 2 });
|
||||||
|
|
||||||
|
// Update the outer border of the brush preview
|
||||||
|
const brushPreviewOuter = toolPreviewLayer.findOne<Konva.Circle>(`#${TOOL_PREVIEW_BRUSH_BORDER_OUTER_ID}`);
|
||||||
|
brushPreviewOuter?.setAttrs({
|
||||||
|
x: cursorPos.x,
|
||||||
|
y: cursorPos.y,
|
||||||
|
radius: brushSize / 2 + 1,
|
||||||
|
});
|
||||||
|
|
||||||
|
brushPreviewGroup.visible(true);
|
||||||
|
} else {
|
||||||
|
brushPreviewGroup.visible(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cursorPos && lastMouseDownPos && tool === 'rect') {
|
||||||
|
const rectPreview = toolPreviewLayer.findOne<Konva.Rect>(`#${TOOL_PREVIEW_RECT_ID}`);
|
||||||
|
rectPreview?.setAttrs({
|
||||||
|
x: Math.min(cursorPos.x, lastMouseDownPos.x),
|
||||||
|
y: Math.min(cursorPos.y, lastMouseDownPos.y),
|
||||||
|
width: Math.abs(cursorPos.x - lastMouseDownPos.x),
|
||||||
|
height: Math.abs(cursorPos.y - lastMouseDownPos.y),
|
||||||
|
});
|
||||||
|
rectPreview?.visible(true);
|
||||||
|
} else {
|
||||||
|
rectPreview?.visible(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a vector mask layer.
|
||||||
|
* @param stage The konva stage to attach the layer to.
|
||||||
|
* @param reduxLayer The redux layer to create the konva layer from.
|
||||||
|
* @param onLayerPosChanged Callback for when the layer's position changes.
|
||||||
|
*/
|
||||||
|
const createVectorMaskLayer = (
|
||||||
|
stage: Konva.Stage,
|
||||||
|
reduxLayer: VectorMaskLayer,
|
||||||
|
onLayerPosChanged?: (layerId: string, x: number, y: number) => void
|
||||||
|
) => {
|
||||||
|
// This layer hasn't been added to the konva state yet
|
||||||
|
const konvaLayer = new Konva.Layer({
|
||||||
|
id: reduxLayer.id,
|
||||||
|
name: VECTOR_MASK_LAYER_NAME,
|
||||||
|
draggable: true,
|
||||||
|
dragDistance: 0,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Create a `dragmove` listener for this layer
|
||||||
|
if (onLayerPosChanged) {
|
||||||
|
konvaLayer.on('dragend', function (e) {
|
||||||
|
onLayerPosChanged(reduxLayer.id, Math.floor(e.target.x()), Math.floor(e.target.y()));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// The dragBoundFunc limits how far the layer can be dragged
|
||||||
|
konvaLayer.dragBoundFunc(function (pos) {
|
||||||
|
const cursorPos = getScaledFlooredCursorPosition(stage);
|
||||||
|
if (!cursorPos) {
|
||||||
|
return this.getAbsolutePosition();
|
||||||
|
}
|
||||||
|
// Prevent the user from dragging the layer out of the stage bounds.
|
||||||
|
if (
|
||||||
|
cursorPos.x < 0 ||
|
||||||
|
cursorPos.x > stage.width() / stage.scaleX() ||
|
||||||
|
cursorPos.y < 0 ||
|
||||||
|
cursorPos.y > stage.height() / stage.scaleY()
|
||||||
|
) {
|
||||||
|
return this.getAbsolutePosition();
|
||||||
|
}
|
||||||
|
return pos;
|
||||||
|
});
|
||||||
|
|
||||||
|
// The object group holds all of the layer's objects (e.g. lines and rects)
|
||||||
|
const konvaObjectGroup = new Konva.Group({
|
||||||
|
id: getVectorMaskLayerObjectGroupId(reduxLayer.id, uuidv4()),
|
||||||
|
name: VECTOR_MASK_LAYER_OBJECT_GROUP_NAME,
|
||||||
|
listening: false,
|
||||||
|
});
|
||||||
|
konvaLayer.add(konvaObjectGroup);
|
||||||
|
|
||||||
|
stage.add(konvaLayer);
|
||||||
|
|
||||||
|
return konvaLayer;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a konva line from a redux vector mask line.
|
||||||
|
* @param reduxObject The redux object to create the konva line from.
|
||||||
|
* @param konvaGroup The konva group to add the line to.
|
||||||
|
*/
|
||||||
|
const createVectorMaskLine = (reduxObject: VectorMaskLine, konvaGroup: Konva.Group): Konva.Line => {
|
||||||
|
const vectorMaskLine = new Konva.Line({
|
||||||
|
id: reduxObject.id,
|
||||||
|
key: reduxObject.id,
|
||||||
|
name: VECTOR_MASK_LAYER_LINE_NAME,
|
||||||
|
strokeWidth: reduxObject.strokeWidth,
|
||||||
|
tension: 0,
|
||||||
|
lineCap: 'round',
|
||||||
|
lineJoin: 'round',
|
||||||
|
shadowForStrokeEnabled: false,
|
||||||
|
globalCompositeOperation: reduxObject.tool === 'brush' ? 'source-over' : 'destination-out',
|
||||||
|
listening: false,
|
||||||
|
});
|
||||||
|
konvaGroup.add(vectorMaskLine);
|
||||||
|
return vectorMaskLine;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a konva rect from a redux vector mask rect.
|
||||||
|
* @param reduxObject The redux object to create the konva rect from.
|
||||||
|
* @param konvaGroup The konva group to add the rect to.
|
||||||
|
*/
|
||||||
|
const createVectorMaskRect = (reduxObject: VectorMaskRect, konvaGroup: Konva.Group): Konva.Rect => {
|
||||||
|
const vectorMaskRect = new Konva.Rect({
|
||||||
|
id: reduxObject.id,
|
||||||
|
key: reduxObject.id,
|
||||||
|
name: VECTOR_MASK_LAYER_RECT_NAME,
|
||||||
|
x: reduxObject.x,
|
||||||
|
y: reduxObject.y,
|
||||||
|
width: reduxObject.width,
|
||||||
|
height: reduxObject.height,
|
||||||
|
listening: false,
|
||||||
|
});
|
||||||
|
konvaGroup.add(vectorMaskRect);
|
||||||
|
return vectorMaskRect;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Renders a vector mask layer.
|
||||||
|
* @param stage The konva stage to render on.
|
||||||
|
* @param reduxLayer The redux vector mask layer to render.
|
||||||
|
* @param reduxLayerIndex The index of the layer in the redux store.
|
||||||
|
* @param globalMaskLayerOpacity The opacity of the global mask layer.
|
||||||
|
* @param tool The current tool.
|
||||||
|
*/
|
||||||
|
const renderVectorMaskLayer = (
|
||||||
|
stage: Konva.Stage,
|
||||||
|
reduxLayer: VectorMaskLayer,
|
||||||
|
globalMaskLayerOpacity: number,
|
||||||
|
tool: Tool,
|
||||||
|
onLayerPosChanged?: (layerId: string, x: number, y: number) => void
|
||||||
|
): void => {
|
||||||
|
const konvaLayer =
|
||||||
|
stage.findOne<Konva.Layer>(`#${reduxLayer.id}`) ?? createVectorMaskLayer(stage, reduxLayer, onLayerPosChanged);
|
||||||
|
|
||||||
|
// Update the layer's position and listening state
|
||||||
|
konvaLayer.setAttrs({
|
||||||
|
listening: tool === 'move', // The layer only listens when using the move tool - otherwise the stage is handling mouse events
|
||||||
|
x: Math.floor(reduxLayer.x),
|
||||||
|
y: Math.floor(reduxLayer.y),
|
||||||
|
});
|
||||||
|
|
||||||
|
// Convert the color to a string, stripping the alpha - the object group will handle opacity.
|
||||||
|
const rgbColor = rgbColorToString(reduxLayer.previewColor);
|
||||||
|
|
||||||
|
const konvaObjectGroup = konvaLayer.findOne<Konva.Group>(`.${VECTOR_MASK_LAYER_OBJECT_GROUP_NAME}`);
|
||||||
|
assert(konvaObjectGroup, `Object group not found for layer ${reduxLayer.id}`);
|
||||||
|
|
||||||
|
// We use caching to handle "global" layer opacity, but caching is expensive and we should only do it when required.
|
||||||
|
let groupNeedsCache = false;
|
||||||
|
|
||||||
|
const objectIds = reduxLayer.objects.map(mapId);
|
||||||
|
for (const objectNode of konvaObjectGroup.find(selectVectorMaskObjects)) {
|
||||||
|
if (!objectIds.includes(objectNode.id())) {
|
||||||
|
objectNode.destroy();
|
||||||
|
groupNeedsCache = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const reduxObject of reduxLayer.objects) {
|
||||||
|
if (reduxObject.type === 'vector_mask_line') {
|
||||||
|
const vectorMaskLine =
|
||||||
|
stage.findOne<Konva.Line>(`#${reduxObject.id}`) ?? createVectorMaskLine(reduxObject, konvaObjectGroup);
|
||||||
|
|
||||||
|
// Only update the points if they have changed. The point values are never mutated, they are only added to the
|
||||||
|
// array, so checking the length is sufficient to determine if we need to re-cache.
|
||||||
|
if (vectorMaskLine.points().length !== reduxObject.points.length) {
|
||||||
|
vectorMaskLine.points(reduxObject.points);
|
||||||
|
groupNeedsCache = true;
|
||||||
|
}
|
||||||
|
// Only update the color if it has changed.
|
||||||
|
if (vectorMaskLine.stroke() !== rgbColor) {
|
||||||
|
vectorMaskLine.stroke(rgbColor);
|
||||||
|
groupNeedsCache = true;
|
||||||
|
}
|
||||||
|
} else if (reduxObject.type === 'vector_mask_rect') {
|
||||||
|
const konvaObject =
|
||||||
|
stage.findOne<Konva.Rect>(`#${reduxObject.id}`) ?? createVectorMaskRect(reduxObject, konvaObjectGroup);
|
||||||
|
|
||||||
|
// Only update the color if it has changed.
|
||||||
|
if (konvaObject.fill() !== rgbColor) {
|
||||||
|
konvaObject.fill(rgbColor);
|
||||||
|
groupNeedsCache = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only update layer visibility if it has changed.
|
||||||
|
if (konvaLayer.visible() !== reduxLayer.isVisible) {
|
||||||
|
konvaLayer.visible(reduxLayer.isVisible);
|
||||||
|
groupNeedsCache = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (konvaObjectGroup.children.length === 0) {
|
||||||
|
// No objects - clear the cache to reset the previous pixel data
|
||||||
|
konvaObjectGroup.clearCache();
|
||||||
|
} else if (groupNeedsCache) {
|
||||||
|
konvaObjectGroup.cache();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Updating group opacity does not require re-caching
|
||||||
|
if (konvaObjectGroup.opacity() !== globalMaskLayerOpacity) {
|
||||||
|
konvaObjectGroup.opacity(globalMaskLayerOpacity);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Renders the layers on the stage.
|
||||||
|
* @param stage The konva stage to render on.
|
||||||
|
* @param reduxLayers Array of the layers from the redux store.
|
||||||
|
* @param layerOpacity The opacity of the layer.
|
||||||
|
* @param onLayerPosChanged Callback for when the layer's position changes. This is optional to allow for offscreen rendering.
|
||||||
|
* @returns
|
||||||
|
*/
|
||||||
|
const renderLayers = (
|
||||||
|
stage: Konva.Stage,
|
||||||
|
reduxLayers: Layer[],
|
||||||
|
globalMaskLayerOpacity: number,
|
||||||
|
tool: Tool,
|
||||||
|
onLayerPosChanged?: (layerId: string, x: number, y: number) => void
|
||||||
|
) => {
|
||||||
|
const reduxLayerIds = reduxLayers.map(mapId);
|
||||||
|
|
||||||
|
// Remove un-rendered layers
|
||||||
|
for (const konvaLayer of stage.find<Konva.Layer>(`.${VECTOR_MASK_LAYER_NAME}`)) {
|
||||||
|
if (!reduxLayerIds.includes(konvaLayer.id())) {
|
||||||
|
konvaLayer.destroy();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const reduxLayer of reduxLayers) {
|
||||||
|
if (isVectorMaskLayer(reduxLayer)) {
|
||||||
|
renderVectorMaskLayer(stage, reduxLayer, globalMaskLayerOpacity, tool, onLayerPosChanged);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a bounding box rect for a layer.
|
||||||
|
* @param reduxLayer The redux layer to create the bounding box for.
|
||||||
|
* @param konvaLayer The konva layer to attach the bounding box to.
|
||||||
|
* @param onBboxMouseDown Callback for when the bounding box is clicked.
|
||||||
|
*/
|
||||||
|
const createBboxRect = (reduxLayer: Layer, konvaLayer: Konva.Layer, onBboxMouseDown: (layerId: string) => void) => {
|
||||||
|
const rect = new Konva.Rect({
|
||||||
|
id: getLayerBboxId(reduxLayer.id),
|
||||||
|
name: LAYER_BBOX_NAME,
|
||||||
|
strokeWidth: 1,
|
||||||
|
});
|
||||||
|
rect.on('mousedown', function () {
|
||||||
|
onBboxMouseDown(reduxLayer.id);
|
||||||
|
});
|
||||||
|
rect.on('mouseover', function (e) {
|
||||||
|
if (getIsSelected(e.target.getLayer()?.id())) {
|
||||||
|
this.stroke(BBOX_SELECTED_STROKE);
|
||||||
|
} else {
|
||||||
|
this.stroke(BBOX_NOT_SELECTED_MOUSEOVER_STROKE);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
rect.on('mouseout', function (e) {
|
||||||
|
if (getIsSelected(e.target.getLayer()?.id())) {
|
||||||
|
this.stroke(BBOX_SELECTED_STROKE);
|
||||||
|
} else {
|
||||||
|
this.stroke(BBOX_NOT_SELECTED_STROKE);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
konvaLayer.add(rect);
|
||||||
|
return rect;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Renders the bounding boxes for the layers.
|
||||||
|
* @param stage The konva stage to render on
|
||||||
|
* @param reduxLayers An array of all redux layers to draw bboxes for
|
||||||
|
* @param selectedLayerId The selected layer's id
|
||||||
|
* @param tool The current tool
|
||||||
|
* @param onBboxChanged Callback for when the bbox is changed
|
||||||
|
* @param onBboxMouseDown Callback for when the bbox is clicked
|
||||||
|
* @returns
|
||||||
|
*/
|
||||||
|
const renderBbox = (
|
||||||
|
stage: Konva.Stage,
|
||||||
|
reduxLayers: Layer[],
|
||||||
|
selectedLayerId: string | null,
|
||||||
|
tool: Tool,
|
||||||
|
onBboxChanged: (layerId: string, bbox: IRect | null) => void,
|
||||||
|
onBboxMouseDown: (layerId: string) => void
|
||||||
|
) => {
|
||||||
|
// Hide all bboxes so they don't interfere with getClientRect
|
||||||
|
for (const bboxRect of stage.find<Konva.Rect>(`.${LAYER_BBOX_NAME}`)) {
|
||||||
|
bboxRect.visible(false);
|
||||||
|
bboxRect.listening(false);
|
||||||
|
}
|
||||||
|
// No selected layer or not using the move tool - nothing more to do here
|
||||||
|
if (tool !== 'move') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const reduxLayer of reduxLayers) {
|
||||||
|
const konvaLayer = stage.findOne<Konva.Layer>(`#${reduxLayer.id}`);
|
||||||
|
assert(konvaLayer, `Layer ${reduxLayer.id} not found in stage`);
|
||||||
|
|
||||||
|
let bbox = reduxLayer.bbox;
|
||||||
|
|
||||||
|
// We only need to recalculate the bbox if the layer has changed and it has objects
|
||||||
|
if (reduxLayer.bboxNeedsUpdate && reduxLayer.objects.length) {
|
||||||
|
// We only need to use the pixel-perfect bounding box if the layer has eraser strokes
|
||||||
|
bbox = reduxLayer.needsPixelBbox ? getLayerBboxPixels(konvaLayer) : getLayerBboxFast(konvaLayer);
|
||||||
|
// Update the layer's bbox in the redux store
|
||||||
|
onBboxChanged(reduxLayer.id, bbox);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!bbox) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const rect =
|
||||||
|
konvaLayer.findOne<Konva.Rect>(`.${LAYER_BBOX_NAME}`) ?? createBboxRect(reduxLayer, konvaLayer, onBboxMouseDown);
|
||||||
|
|
||||||
|
rect.setAttrs({
|
||||||
|
visible: true,
|
||||||
|
listening: true,
|
||||||
|
x: bbox.x,
|
||||||
|
y: bbox.y,
|
||||||
|
width: bbox.width,
|
||||||
|
height: bbox.height,
|
||||||
|
stroke: reduxLayer.id === selectedLayerId ? BBOX_SELECTED_STROKE : BBOX_NOT_SELECTED_STROKE,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates the background layer for the stage.
|
||||||
|
* @param stage The konva stage to render on
|
||||||
|
*/
|
||||||
|
const createBackgroundLayer = (stage: Konva.Stage): Konva.Layer => {
|
||||||
|
const layer = new Konva.Layer({
|
||||||
|
id: BACKGROUND_LAYER_ID,
|
||||||
|
});
|
||||||
|
const background = new Konva.Rect({
|
||||||
|
id: BACKGROUND_RECT_ID,
|
||||||
|
x: stage.x(),
|
||||||
|
y: 0,
|
||||||
|
width: stage.width() / stage.scaleX(),
|
||||||
|
height: stage.height() / stage.scaleY(),
|
||||||
|
listening: false,
|
||||||
|
opacity: 0.2,
|
||||||
|
});
|
||||||
|
layer.add(background);
|
||||||
|
stage.add(layer);
|
||||||
|
const image = new Image();
|
||||||
|
image.onload = () => {
|
||||||
|
background.fillPatternImage(image);
|
||||||
|
};
|
||||||
|
image.src = STAGE_BG_DATAURL;
|
||||||
|
return layer;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Renders the background layer for the stage.
|
||||||
|
* @param stage The konva stage to render on
|
||||||
|
* @param width The unscaled width of the canvas
|
||||||
|
* @param height The unscaled height of the canvas
|
||||||
|
*/
|
||||||
|
const renderBackground = (stage: Konva.Stage, width: number, height: number) => {
|
||||||
|
const layer = stage.findOne<Konva.Layer>(`#${BACKGROUND_LAYER_ID}`) ?? createBackgroundLayer(stage);
|
||||||
|
|
||||||
|
const background = layer.findOne<Konva.Rect>(`#${BACKGROUND_RECT_ID}`);
|
||||||
|
assert(background, 'Background rect not found');
|
||||||
|
// ensure background rect is in the top-left of the canvas
|
||||||
|
background.absolutePosition({ x: 0, y: 0 });
|
||||||
|
|
||||||
|
// set the dimensions of the background rect to match the canvas - not the stage!!!
|
||||||
|
background.size({
|
||||||
|
width: width / stage.scaleX(),
|
||||||
|
height: height / stage.scaleY(),
|
||||||
|
});
|
||||||
|
|
||||||
|
// Calculate the amount the stage is moved - including the effect of scaling
|
||||||
|
const stagePos = {
|
||||||
|
x: -stage.x() / stage.scaleX(),
|
||||||
|
y: -stage.y() / stage.scaleY(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Apply that movement to the fill pattern
|
||||||
|
background.fillPatternOffset(stagePos);
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Arranges all layers in the z-axis by updating their z-indices.
|
||||||
|
* @param stage The konva stage
|
||||||
|
* @param layerIds An array of redux layer ids, in their z-index order
|
||||||
|
*/
|
||||||
|
const arrangeLayers = (stage: Konva.Stage, layerIds: string[]): void => {
|
||||||
|
let nextZIndex = 0;
|
||||||
|
// Background is the first layer
|
||||||
|
stage.findOne<Konva.Layer>(`#${BACKGROUND_LAYER_ID}`)?.zIndex(nextZIndex++);
|
||||||
|
// Then arrange the redux layers in order
|
||||||
|
for (const layerId of layerIds) {
|
||||||
|
stage.findOne<Konva.Layer>(`#${layerId}`)?.zIndex(nextZIndex++);
|
||||||
|
}
|
||||||
|
// Finally, the tool preview layer is always on top
|
||||||
|
stage.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)?.zIndex(nextZIndex++);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const renderers = {
|
||||||
|
renderToolPreview,
|
||||||
|
renderLayers,
|
||||||
|
renderBbox,
|
||||||
|
renderBackground,
|
||||||
|
arrangeLayers,
|
||||||
|
};
|
||||||
|
|
||||||
|
const DEBOUNCE_MS = 300;
|
||||||
|
|
||||||
|
export const debouncedRenderers = {
|
||||||
|
renderToolPreview: debounce(renderToolPreview, DEBOUNCE_MS),
|
||||||
|
renderLayers: debounce(renderLayers, DEBOUNCE_MS),
|
||||||
|
renderBbox: debounce(renderBbox, DEBOUNCE_MS),
|
||||||
|
renderBackground: debounce(renderBackground, DEBOUNCE_MS),
|
||||||
|
arrangeLayers: debounce(arrangeLayers, DEBOUNCE_MS),
|
||||||
|
};
|
@ -13,51 +13,60 @@ import {
|
|||||||
selectValidIPAdapters,
|
selectValidIPAdapters,
|
||||||
selectValidT2IAdapters,
|
selectValidT2IAdapters,
|
||||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
|
import { selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
import { Fragment, memo } from 'react';
|
import { Fragment, memo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { PiPlusBold } from 'react-icons/pi';
|
import { PiPlusBold } from 'react-icons/pi';
|
||||||
|
|
||||||
const selector = createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => {
|
const selector = createMemoizedSelector(
|
||||||
const badges: string[] = [];
|
[selectControlAdaptersSlice, selectRegionalPromptsSlice],
|
||||||
let isError = false;
|
(controlAdapters, regionalPrompts) => {
|
||||||
|
const badges: string[] = [];
|
||||||
|
let isError = false;
|
||||||
|
|
||||||
const enabledIPAdapterCount = selectAllIPAdapters(controlAdapters).filter((ca) => ca.isEnabled).length;
|
const enabledIPAdapterCount = selectAllIPAdapters(controlAdapters)
|
||||||
const validIPAdapterCount = selectValidIPAdapters(controlAdapters).length;
|
.filter((ca) => !regionalPrompts.present.layers.some((l) => l.ipAdapterIds.includes(ca.id)))
|
||||||
if (enabledIPAdapterCount > 0) {
|
.filter((ca) => ca.isEnabled).length;
|
||||||
badges.push(`${enabledIPAdapterCount} IP`);
|
|
||||||
}
|
|
||||||
if (enabledIPAdapterCount > validIPAdapterCount) {
|
|
||||||
isError = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
const enabledControlNetCount = selectAllControlNets(controlAdapters).filter((ca) => ca.isEnabled).length;
|
const validIPAdapterCount = selectValidIPAdapters(controlAdapters).length;
|
||||||
const validControlNetCount = selectValidControlNets(controlAdapters).length;
|
if (enabledIPAdapterCount > 0) {
|
||||||
if (enabledControlNetCount > 0) {
|
badges.push(`${enabledIPAdapterCount} IP`);
|
||||||
badges.push(`${enabledControlNetCount} ControlNet`);
|
}
|
||||||
}
|
if (enabledIPAdapterCount > validIPAdapterCount) {
|
||||||
if (enabledControlNetCount > validControlNetCount) {
|
isError = true;
|
||||||
isError = true;
|
}
|
||||||
}
|
|
||||||
|
|
||||||
const enabledT2IAdapterCount = selectAllT2IAdapters(controlAdapters).filter((ca) => ca.isEnabled).length;
|
const enabledControlNetCount = selectAllControlNets(controlAdapters).filter((ca) => ca.isEnabled).length;
|
||||||
const validT2IAdapterCount = selectValidT2IAdapters(controlAdapters).length;
|
const validControlNetCount = selectValidControlNets(controlAdapters).length;
|
||||||
if (enabledT2IAdapterCount > 0) {
|
if (enabledControlNetCount > 0) {
|
||||||
badges.push(`${enabledT2IAdapterCount} T2I`);
|
badges.push(`${enabledControlNetCount} ControlNet`);
|
||||||
}
|
}
|
||||||
if (enabledT2IAdapterCount > validT2IAdapterCount) {
|
if (enabledControlNetCount > validControlNetCount) {
|
||||||
isError = true;
|
isError = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
const controlAdapterIds = selectControlAdapterIds(controlAdapters);
|
const enabledT2IAdapterCount = selectAllT2IAdapters(controlAdapters).filter((ca) => ca.isEnabled).length;
|
||||||
|
const validT2IAdapterCount = selectValidT2IAdapters(controlAdapters).length;
|
||||||
|
if (enabledT2IAdapterCount > 0) {
|
||||||
|
badges.push(`${enabledT2IAdapterCount} T2I`);
|
||||||
|
}
|
||||||
|
if (enabledT2IAdapterCount > validT2IAdapterCount) {
|
||||||
|
isError = true;
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
const controlAdapterIds = selectControlAdapterIds(controlAdapters).filter(
|
||||||
controlAdapterIds,
|
(id) => !regionalPrompts.present.layers.some((l) => l.ipAdapterIds.includes(id))
|
||||||
badges,
|
);
|
||||||
isError, // TODO: Add some visual indicator that the control adapters are in an error state
|
|
||||||
};
|
return {
|
||||||
});
|
controlAdapterIds,
|
||||||
|
badges,
|
||||||
|
isError, // TODO: Add some visual indicator that the control adapters are in an error state
|
||||||
|
};
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
export const ControlSettingsAccordion: React.FC = memo(() => {
|
export const ControlSettingsAccordion: React.FC = memo(() => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
@ -2,6 +2,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { aspectRatioChanged, setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
|
import { aspectRatioChanged, setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
|
||||||
import ParamBoundingBoxHeight from 'features/parameters/components/Canvas/BoundingBox/ParamBoundingBoxHeight';
|
import ParamBoundingBoxHeight from 'features/parameters/components/Canvas/BoundingBox/ParamBoundingBoxHeight';
|
||||||
import ParamBoundingBoxWidth from 'features/parameters/components/Canvas/BoundingBox/ParamBoundingBoxWidth';
|
import ParamBoundingBoxWidth from 'features/parameters/components/Canvas/BoundingBox/ParamBoundingBoxWidth';
|
||||||
|
import { AspectRatioIconPreview } from 'features/parameters/components/ImageSize/AspectRatioIconPreview';
|
||||||
import { ImageSize } from 'features/parameters/components/ImageSize/ImageSize';
|
import { ImageSize } from 'features/parameters/components/ImageSize/ImageSize';
|
||||||
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
|
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
|
||||||
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
||||||
@ -41,6 +42,7 @@ export const ImageSizeCanvas = memo(() => {
|
|||||||
aspectRatioState={aspectRatioState}
|
aspectRatioState={aspectRatioState}
|
||||||
heightComponent={<ParamBoundingBoxHeight />}
|
heightComponent={<ParamBoundingBoxHeight />}
|
||||||
widthComponent={<ParamBoundingBoxWidth />}
|
widthComponent={<ParamBoundingBoxWidth />}
|
||||||
|
previewComponent={<AspectRatioIconPreview />}
|
||||||
onChangeAspectRatioState={onChangeAspectRatioState}
|
onChangeAspectRatioState={onChangeAspectRatioState}
|
||||||
onChangeWidth={onChangeWidth}
|
onChangeWidth={onChangeWidth}
|
||||||
onChangeHeight={onChangeHeight}
|
onChangeHeight={onChangeHeight}
|
||||||
|
@ -1,13 +1,17 @@
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { ParamHeight } from 'features/parameters/components/Core/ParamHeight';
|
import { ParamHeight } from 'features/parameters/components/Core/ParamHeight';
|
||||||
import { ParamWidth } from 'features/parameters/components/Core/ParamWidth';
|
import { ParamWidth } from 'features/parameters/components/Core/ParamWidth';
|
||||||
|
import { AspectRatioCanvasPreview } from 'features/parameters/components/ImageSize/AspectRatioCanvasPreview';
|
||||||
|
import { AspectRatioIconPreview } from 'features/parameters/components/ImageSize/AspectRatioIconPreview';
|
||||||
import { ImageSize } from 'features/parameters/components/ImageSize/ImageSize';
|
import { ImageSize } from 'features/parameters/components/ImageSize/ImageSize';
|
||||||
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
|
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
|
||||||
import { aspectRatioChanged, heightChanged, widthChanged } from 'features/parameters/store/generationSlice';
|
import { aspectRatioChanged, heightChanged, widthChanged } from 'features/parameters/store/generationSlice';
|
||||||
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
|
|
||||||
export const ImageSizeLinear = memo(() => {
|
export const ImageSizeLinear = memo(() => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
const tab = useAppSelector(activeTabNameSelector);
|
||||||
const width = useAppSelector((s) => s.generation.width);
|
const width = useAppSelector((s) => s.generation.width);
|
||||||
const height = useAppSelector((s) => s.generation.height);
|
const height = useAppSelector((s) => s.generation.height);
|
||||||
const aspectRatioState = useAppSelector((s) => s.generation.aspectRatio);
|
const aspectRatioState = useAppSelector((s) => s.generation.aspectRatio);
|
||||||
@ -40,6 +44,7 @@ export const ImageSizeLinear = memo(() => {
|
|||||||
aspectRatioState={aspectRatioState}
|
aspectRatioState={aspectRatioState}
|
||||||
heightComponent={<ParamHeight />}
|
heightComponent={<ParamHeight />}
|
||||||
widthComponent={<ParamWidth />}
|
widthComponent={<ParamWidth />}
|
||||||
|
previewComponent={tab === 'txt2img' ? <AspectRatioCanvasPreview /> : <AspectRatioIconPreview />}
|
||||||
onChangeAspectRatioState={onChangeAspectRatioState}
|
onChangeAspectRatioState={onChangeAspectRatioState}
|
||||||
onChangeWidth={onChangeWidth}
|
onChangeWidth={onChangeWidth}
|
||||||
onChangeHeight={onChangeHeight}
|
onChangeHeight={onChangeHeight}
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
import { Box, Flex } from '@invoke-ai/ui-library';
|
import { Box, Flex, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
|
import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
|
||||||
import { Prompts } from 'features/parameters/components/Prompts/Prompts';
|
import { Prompts } from 'features/parameters/components/Prompts/Prompts';
|
||||||
import QueueControls from 'features/queue/components/QueueControls';
|
import QueueControls from 'features/queue/components/QueueControls';
|
||||||
|
import { RegionalPromptsPanelContent } from 'features/regionalPrompts/components/RegionalPromptsPanelContent';
|
||||||
|
import { useRegionalControlTitle } from 'features/regionalPrompts/hooks/useRegionalControlTitle';
|
||||||
import { SDXLPrompts } from 'features/sdxl/components/SDXLPrompts/SDXLPrompts';
|
import { SDXLPrompts } from 'features/sdxl/components/SDXLPrompts/SDXLPrompts';
|
||||||
import { AdvancedSettingsAccordion } from 'features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion';
|
import { AdvancedSettingsAccordion } from 'features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion';
|
||||||
import { CompositingSettingsAccordion } from 'features/settingsAccordions/components/CompositingSettingsAccordion/CompositingSettingsAccordion';
|
import { CompositingSettingsAccordion } from 'features/settingsAccordions/components/CompositingSettingsAccordion/CompositingSettingsAccordion';
|
||||||
@ -14,6 +16,7 @@ import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
|||||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||||
import type { CSSProperties } from 'react';
|
import type { CSSProperties } from 'react';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
const overlayScrollbarsStyles: CSSProperties = {
|
const overlayScrollbarsStyles: CSSProperties = {
|
||||||
height: '100%',
|
height: '100%',
|
||||||
@ -21,7 +24,9 @@ const overlayScrollbarsStyles: CSSProperties = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const ParametersPanel = () => {
|
const ParametersPanel = () => {
|
||||||
|
const { t } = useTranslation();
|
||||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||||
|
const regionalControlTitle = useRegionalControlTitle();
|
||||||
const isSDXL = useAppSelector((s) => s.generation.model?.base === 'sdxl');
|
const isSDXL = useAppSelector((s) => s.generation.model?.base === 'sdxl');
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -32,12 +37,28 @@ const ParametersPanel = () => {
|
|||||||
<OverlayScrollbarsComponent defer style={overlayScrollbarsStyles} options={overlayScrollbarsParams.options}>
|
<OverlayScrollbarsComponent defer style={overlayScrollbarsStyles} options={overlayScrollbarsParams.options}>
|
||||||
<Flex gap={2} flexDirection="column" h="full" w="full">
|
<Flex gap={2} flexDirection="column" h="full" w="full">
|
||||||
{isSDXL ? <SDXLPrompts /> : <Prompts />}
|
{isSDXL ? <SDXLPrompts /> : <Prompts />}
|
||||||
<ImageSettingsAccordion />
|
<Tabs variant="line" isLazy={true} display="flex" flexDir="column" w="full" h="full">
|
||||||
<GenerationSettingsAccordion />
|
<TabList>
|
||||||
<ControlSettingsAccordion />
|
<Tab>{t('parameters.globalSettings')}</Tab>
|
||||||
{activeTabName === 'unifiedCanvas' && <CompositingSettingsAccordion />}
|
<Tab>{regionalControlTitle}</Tab>
|
||||||
{isSDXL && <RefinerSettingsAccordion />}
|
</TabList>
|
||||||
<AdvancedSettingsAccordion />
|
|
||||||
|
<TabPanels w="full" h="full">
|
||||||
|
<TabPanel>
|
||||||
|
<Flex gap={2} flexDirection="column" h="full" w="full">
|
||||||
|
<ImageSettingsAccordion />
|
||||||
|
<GenerationSettingsAccordion />
|
||||||
|
<ControlSettingsAccordion />
|
||||||
|
{activeTabName === 'unifiedCanvas' && <CompositingSettingsAccordion />}
|
||||||
|
{isSDXL && <RefinerSettingsAccordion />}
|
||||||
|
<AdvancedSettingsAccordion />
|
||||||
|
</Flex>
|
||||||
|
</TabPanel>
|
||||||
|
<TabPanel>
|
||||||
|
<RegionalPromptsPanelContent />
|
||||||
|
</TabPanel>
|
||||||
|
</TabPanels>
|
||||||
|
</Tabs>
|
||||||
</Flex>
|
</Flex>
|
||||||
</OverlayScrollbarsComponent>
|
</OverlayScrollbarsComponent>
|
||||||
</Box>
|
</Box>
|
||||||
|
@ -1,13 +1,31 @@
|
|||||||
import { Box, Flex } from '@invoke-ai/ui-library';
|
import { Box, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
|
||||||
import CurrentImageDisplay from 'features/gallery/components/CurrentImage/CurrentImageDisplay';
|
import CurrentImageDisplay from 'features/gallery/components/CurrentImage/CurrentImageDisplay';
|
||||||
|
import { RegionalPromptsEditor } from 'features/regionalPrompts/components/RegionalPromptsEditor';
|
||||||
|
import { useRegionalControlTitle } from 'features/regionalPrompts/hooks/useRegionalControlTitle';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
const TextToImageTab = () => {
|
const TextToImageTab = () => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const regionalControlTitle = useRegionalControlTitle();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Box layerStyle="first" position="relative" w="full" h="full" p={2} borderRadius="base">
|
<Box position="relative" w="full" h="full" p={2} borderRadius="base">
|
||||||
<Flex w="full" h="full">
|
<Tabs variant="line" isLazy={true} display="flex" flexDir="column" w="full" h="full">
|
||||||
<CurrentImageDisplay />
|
<TabList>
|
||||||
</Flex>
|
<Tab>{t('common.viewer')}</Tab>
|
||||||
|
<Tab>{regionalControlTitle}</Tab>
|
||||||
|
</TabList>
|
||||||
|
|
||||||
|
<TabPanels w="full" h="full" minH={0} minW={0}>
|
||||||
|
<TabPanel>
|
||||||
|
<CurrentImageDisplay />
|
||||||
|
</TabPanel>
|
||||||
|
<TabPanel>
|
||||||
|
<RegionalPromptsEditor />
|
||||||
|
</TabPanel>
|
||||||
|
</TabPanels>
|
||||||
|
</Tabs>
|
||||||
</Box>
|
</Box>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -124,7 +124,9 @@ export const usePanel = (arg: UsePanelOptions): UsePanelReturn => {
|
|||||||
*
|
*
|
||||||
* For now, we'll just resize the panel to the min size every time the panel group is resized.
|
* For now, we'll just resize the panel to the min size every time the panel group is resized.
|
||||||
*/
|
*/
|
||||||
panelHandleRef.current.resize(minSizePct);
|
if (!panelHandleRef.current.isCollapsed()) {
|
||||||
|
panelHandleRef.current.resize(minSizePct);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
resizeObserver.observe(panelGroupElement);
|
resizeObserver.observe(panelGroupElement);
|
||||||
|
File diff suppressed because one or more lines are too long
@ -1 +1 @@
|
|||||||
__version__ = "4.1.0"
|
__version__ = "4.2.0a2"
|
||||||
|
@ -4,7 +4,7 @@ import pytest
|
|||||||
from torch import tensor
|
from torch import tensor
|
||||||
|
|
||||||
from invokeai.backend.model_manager import BaseModelType, ModelRepoVariant
|
from invokeai.backend.model_manager import BaseModelType, ModelRepoVariant
|
||||||
from invokeai.backend.model_manager.config import InvalidModelConfigException
|
from invokeai.backend.model_manager.config import InvalidModelConfigException, MainDiffusersConfig, ModelVariantType
|
||||||
from invokeai.backend.model_manager.probe import (
|
from invokeai.backend.model_manager.probe import (
|
||||||
CkptType,
|
CkptType,
|
||||||
ModelProbe,
|
ModelProbe,
|
||||||
@ -78,3 +78,11 @@ def test_probe_handles_state_dict_with_integer_keys():
|
|||||||
}
|
}
|
||||||
with pytest.raises(InvalidModelConfigException):
|
with pytest.raises(InvalidModelConfigException):
|
||||||
ModelProbe.get_model_type_from_checkpoint(Path("embedding.pt"), state_dict_with_integer_keys)
|
ModelProbe.get_model_type_from_checkpoint(Path("embedding.pt"), state_dict_with_integer_keys)
|
||||||
|
|
||||||
|
|
||||||
|
def test_probe_sd1_diffusers_inpainting(datadir: Path):
|
||||||
|
config = ModelProbe.probe(datadir / "sd-1/main/dreamshaper-8-inpainting")
|
||||||
|
assert isinstance(config, MainDiffusersConfig)
|
||||||
|
assert config.base is BaseModelType.StableDiffusion1
|
||||||
|
assert config.variant is ModelVariantType.Inpaint
|
||||||
|
assert config.repo_variant is ModelRepoVariant.FP16
|
||||||
|
@ -0,0 +1 @@
|
|||||||
|
This folder contains config files copied from [Lykon/dreamshaper-8-inpainting](https://huggingface.co/Lykon/dreamshaper-8-inpainting).
|
@ -0,0 +1,34 @@
|
|||||||
|
{
|
||||||
|
"_class_name": "StableDiffusionInpaintPipeline",
|
||||||
|
"_diffusers_version": "0.21.0.dev0",
|
||||||
|
"_name_or_path": "lykon-models/dreamshaper-8-inpainting",
|
||||||
|
"feature_extractor": [
|
||||||
|
"transformers",
|
||||||
|
"CLIPFeatureExtractor"
|
||||||
|
],
|
||||||
|
"requires_safety_checker": true,
|
||||||
|
"safety_checker": [
|
||||||
|
"stable_diffusion",
|
||||||
|
"StableDiffusionSafetyChecker"
|
||||||
|
],
|
||||||
|
"scheduler": [
|
||||||
|
"diffusers",
|
||||||
|
"DEISMultistepScheduler"
|
||||||
|
],
|
||||||
|
"text_encoder": [
|
||||||
|
"transformers",
|
||||||
|
"CLIPTextModel"
|
||||||
|
],
|
||||||
|
"tokenizer": [
|
||||||
|
"transformers",
|
||||||
|
"CLIPTokenizer"
|
||||||
|
],
|
||||||
|
"unet": [
|
||||||
|
"diffusers",
|
||||||
|
"UNet2DConditionModel"
|
||||||
|
],
|
||||||
|
"vae": [
|
||||||
|
"diffusers",
|
||||||
|
"AutoencoderKL"
|
||||||
|
]
|
||||||
|
}
|
@ -0,0 +1,23 @@
|
|||||||
|
{
|
||||||
|
"_class_name": "DEISMultistepScheduler",
|
||||||
|
"_diffusers_version": "0.21.0.dev0",
|
||||||
|
"algorithm_type": "deis",
|
||||||
|
"beta_end": 0.012,
|
||||||
|
"beta_schedule": "scaled_linear",
|
||||||
|
"beta_start": 0.00085,
|
||||||
|
"clip_sample": false,
|
||||||
|
"dynamic_thresholding_ratio": 0.995,
|
||||||
|
"lower_order_final": true,
|
||||||
|
"num_train_timesteps": 1000,
|
||||||
|
"prediction_type": "epsilon",
|
||||||
|
"sample_max_value": 1.0,
|
||||||
|
"set_alpha_to_one": false,
|
||||||
|
"skip_prk_steps": true,
|
||||||
|
"solver_order": 2,
|
||||||
|
"solver_type": "logrho",
|
||||||
|
"steps_offset": 1,
|
||||||
|
"thresholding": false,
|
||||||
|
"timestep_spacing": "leading",
|
||||||
|
"trained_betas": null,
|
||||||
|
"use_karras_sigmas": false
|
||||||
|
}
|
@ -0,0 +1,66 @@
|
|||||||
|
{
|
||||||
|
"_class_name": "UNet2DConditionModel",
|
||||||
|
"_diffusers_version": "0.21.0.dev0",
|
||||||
|
"_name_or_path": "/home/patrick/.cache/huggingface/hub/models--lykon-models--dreamshaper-8-inpainting/snapshots/15dcb9dec91a39ee498e3917c9ef6174b103862d/unet",
|
||||||
|
"act_fn": "silu",
|
||||||
|
"addition_embed_type": null,
|
||||||
|
"addition_embed_type_num_heads": 64,
|
||||||
|
"addition_time_embed_dim": null,
|
||||||
|
"attention_head_dim": 8,
|
||||||
|
"attention_type": "default",
|
||||||
|
"block_out_channels": [
|
||||||
|
320,
|
||||||
|
640,
|
||||||
|
1280,
|
||||||
|
1280
|
||||||
|
],
|
||||||
|
"center_input_sample": false,
|
||||||
|
"class_embed_type": null,
|
||||||
|
"class_embeddings_concat": false,
|
||||||
|
"conv_in_kernel": 3,
|
||||||
|
"conv_out_kernel": 3,
|
||||||
|
"cross_attention_dim": 768,
|
||||||
|
"cross_attention_norm": null,
|
||||||
|
"down_block_types": [
|
||||||
|
"CrossAttnDownBlock2D",
|
||||||
|
"CrossAttnDownBlock2D",
|
||||||
|
"CrossAttnDownBlock2D",
|
||||||
|
"DownBlock2D"
|
||||||
|
],
|
||||||
|
"downsample_padding": 1,
|
||||||
|
"dual_cross_attention": false,
|
||||||
|
"encoder_hid_dim": null,
|
||||||
|
"encoder_hid_dim_type": null,
|
||||||
|
"flip_sin_to_cos": true,
|
||||||
|
"freq_shift": 0,
|
||||||
|
"in_channels": 9,
|
||||||
|
"layers_per_block": 2,
|
||||||
|
"mid_block_only_cross_attention": null,
|
||||||
|
"mid_block_scale_factor": 1,
|
||||||
|
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
||||||
|
"norm_eps": 1e-05,
|
||||||
|
"norm_num_groups": 32,
|
||||||
|
"num_attention_heads": null,
|
||||||
|
"num_class_embeds": null,
|
||||||
|
"only_cross_attention": false,
|
||||||
|
"out_channels": 4,
|
||||||
|
"projection_class_embeddings_input_dim": null,
|
||||||
|
"resnet_out_scale_factor": 1.0,
|
||||||
|
"resnet_skip_time_act": false,
|
||||||
|
"resnet_time_scale_shift": "default",
|
||||||
|
"sample_size": 64,
|
||||||
|
"time_cond_proj_dim": null,
|
||||||
|
"time_embedding_act_fn": null,
|
||||||
|
"time_embedding_dim": null,
|
||||||
|
"time_embedding_type": "positional",
|
||||||
|
"timestep_post_act": null,
|
||||||
|
"transformer_layers_per_block": 1,
|
||||||
|
"up_block_types": [
|
||||||
|
"UpBlock2D",
|
||||||
|
"CrossAttnUpBlock2D",
|
||||||
|
"CrossAttnUpBlock2D",
|
||||||
|
"CrossAttnUpBlock2D"
|
||||||
|
],
|
||||||
|
"upcast_attention": null,
|
||||||
|
"use_linear_projection": false
|
||||||
|
}
|
@ -99,6 +99,20 @@ def test_obj_serializer_ephemeral_writes_to_tempdir(tmp_path: Path):
|
|||||||
assert not Path(tmp_path, obj_1_name).exists()
|
assert not Path(tmp_path, obj_1_name).exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_obj_serializer_ephemeral_deletes_dangling_tempdirs_on_init(tmp_path: Path):
|
||||||
|
tempdir = tmp_path / "tmpdir"
|
||||||
|
tempdir.mkdir()
|
||||||
|
ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
|
||||||
|
assert not tempdir.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_obj_serializer_does_not_delete_tempdirs_on_init(tmp_path: Path):
|
||||||
|
tempdir = tmp_path / "tmpdir"
|
||||||
|
tempdir.mkdir()
|
||||||
|
ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=False)
|
||||||
|
assert tempdir.exists()
|
||||||
|
|
||||||
|
|
||||||
def test_obj_serializer_disk_different_types(tmp_path: Path):
|
def test_obj_serializer_disk_different_types(tmp_path: Path):
|
||||||
obj_serializer_1 = ObjectSerializerDisk[MockDataclass](tmp_path)
|
obj_serializer_1 = ObjectSerializerDisk[MockDataclass](tmp_path)
|
||||||
obj_1 = MockDataclass(foo="bar")
|
obj_1 = MockDataclass(foo="bar")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user