mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into lstein/feat/config-migration
This commit is contained in:
commit
048306b417
@ -51,6 +51,7 @@ from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import BaseModelType, LoadedModel
|
||||
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
@ -185,7 +186,7 @@ class GradientMaskOutput(BaseInvocationOutput):
|
||||
title="Create Gradient Mask",
|
||||
tags=["mask", "denoise"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
version="1.1.0",
|
||||
)
|
||||
class CreateGradientMaskInvocation(BaseInvocation):
|
||||
"""Creates mask for denoising model run."""
|
||||
@ -198,6 +199,32 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
||||
minimum_denoise: float = InputField(
|
||||
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
|
||||
)
|
||||
image: Optional[ImageField] = InputField(
|
||||
default=None,
|
||||
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
|
||||
title="[OPTIONAL] Image",
|
||||
ui_order=6,
|
||||
)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE",
|
||||
default=None,
|
||||
input=Input.Connection,
|
||||
title="[OPTIONAL] UNet",
|
||||
ui_order=5,
|
||||
)
|
||||
vae: Optional[VAEField] = InputField(
|
||||
default=None,
|
||||
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
|
||||
title="[OPTIONAL] VAE",
|
||||
input=Input.Connection,
|
||||
ui_order=7,
|
||||
)
|
||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
|
||||
fp32: bool = InputField(
|
||||
default=DEFAULT_PRECISION == "float32",
|
||||
description=FieldDescriptions.fp32,
|
||||
ui_order=9,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
|
||||
@ -233,8 +260,27 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
||||
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
|
||||
expanded_image_dto = context.images.save(expanded_mask_image)
|
||||
|
||||
masked_latents_name = None
|
||||
if self.unet is not None and self.vae is not None and self.image is not None:
|
||||
# all three fields must be present at the same time
|
||||
main_model_config = context.models.get_config(self.unet.unet.key)
|
||||
assert isinstance(main_model_config, MainConfigBase)
|
||||
if main_model_config.variant is ModelVariantType.Inpaint:
|
||||
mask = blur_tensor
|
||||
vae_info: LoadedModel = context.models.load(self.vae.vae)
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = image_tensor.unsqueeze(0)
|
||||
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
|
||||
masked_latents = ImageToLatentsInvocation.vae_encode(
|
||||
vae_info, self.fp32, self.tiled, masked_image.clone()
|
||||
)
|
||||
masked_latents_name = context.tensors.save(tensor=masked_latents)
|
||||
|
||||
return GradientMaskOutput(
|
||||
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True),
|
||||
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
|
||||
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
|
||||
)
|
||||
|
||||
|
@ -1,7 +1,8 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext, invocation
|
||||
from invokeai.app.invocations.fields import InputField, TensorField, WithMetadata
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithMetadata
|
||||
from invokeai.app.invocations.primitives import MaskOutput
|
||||
|
||||
|
||||
@ -34,3 +35,86 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata):
|
||||
width=self.width,
|
||||
height=self.height,
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"alpha_mask_to_tensor",
|
||||
title="Alpha Mask to Tensor",
|
||||
tags=["conditioning"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class AlphaMaskToTensorInvocation(BaseInvocation):
|
||||
"""Convert a mask image to a tensor. Opaque regions are 1 and transparent regions are 0."""
|
||||
|
||||
image: ImageField = InputField(description="The mask image to convert.")
|
||||
invert: bool = InputField(default=False, description="Whether to invert the mask.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
mask = torch.zeros((1, image.height, image.width), dtype=torch.bool)
|
||||
if self.invert:
|
||||
mask[0] = torch.tensor(np.array(image)[:, :, 3] == 0, dtype=torch.bool)
|
||||
else:
|
||||
mask[0] = torch.tensor(np.array(image)[:, :, 3] > 0, dtype=torch.bool)
|
||||
|
||||
return MaskOutput(
|
||||
mask=TensorField(tensor_name=context.tensors.save(mask)),
|
||||
height=mask.shape[1],
|
||||
width=mask.shape[2],
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"invert_tensor_mask",
|
||||
title="Invert Tensor Mask",
|
||||
tags=["conditioning"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class InvertTensorMaskInvocation(BaseInvocation):
|
||||
"""Inverts a tensor mask."""
|
||||
|
||||
mask: TensorField = InputField(description="The tensor mask to convert.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
mask = context.tensors.load(self.mask.tensor_name)
|
||||
inverted = ~mask
|
||||
|
||||
return MaskOutput(
|
||||
mask=TensorField(tensor_name=context.tensors.save(inverted)),
|
||||
height=inverted.shape[1],
|
||||
width=inverted.shape[2],
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"image_mask_to_tensor",
|
||||
title="Image Mask to Tensor",
|
||||
tags=["conditioning"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
|
||||
"""Convert a mask image to a tensor. Converts the image to grayscale and uses thresholding at the specified value."""
|
||||
|
||||
image: ImageField = InputField(description="The mask image to convert.")
|
||||
cutoff: int = InputField(ge=0, le=255, description="Cutoff (<)", default=128)
|
||||
invert: bool = InputField(default=False, description="Whether to invert the mask.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
image = context.images.get_pil(self.image.image_name, mode="L")
|
||||
|
||||
mask = torch.zeros((1, image.height, image.width), dtype=torch.bool)
|
||||
if self.invert:
|
||||
mask[0] = torch.tensor(np.array(image)[:, :] >= self.cutoff, dtype=torch.bool)
|
||||
else:
|
||||
mask[0] = torch.tensor(np.array(image)[:, :] < self.cutoff, dtype=torch.bool)
|
||||
|
||||
return MaskOutput(
|
||||
mask=TensorField(tensor_name=context.tensors.save(mask)),
|
||||
height=mask.shape[1],
|
||||
width=mask.shape[2],
|
||||
)
|
||||
|
@ -3,7 +3,6 @@
|
||||
import locale
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import threading
|
||||
import time
|
||||
from hashlib import sha256
|
||||
@ -43,6 +42,7 @@ from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMet
|
||||
from invokeai.backend.model_manager.probe import ModelProbe
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.util import InvokeAILogger
|
||||
from invokeai.backend.util.catch_sigint import catch_sigint
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from .model_install_base import (
|
||||
@ -112,17 +112,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def start(self, invoker: Optional[Invoker] = None) -> None:
|
||||
"""Start the installer thread."""
|
||||
|
||||
# Yes, this is weird. When the installer thread is running, the
|
||||
# thread masks the ^C signal. When we receive a
|
||||
# sigINT, we stop the thread, reset sigINT, and send a new
|
||||
# sigINT to the parent process.
|
||||
def sigint_handler(signum, frame):
|
||||
self.stop()
|
||||
signal.signal(signal.SIGINT, signal.SIG_DFL)
|
||||
signal.raise_signal(signal.SIGINT)
|
||||
|
||||
signal.signal(signal.SIGINT, sigint_handler)
|
||||
|
||||
with self._lock:
|
||||
if self._running:
|
||||
raise Exception("Attempt to start the installer service twice")
|
||||
@ -132,7 +121,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# In normal use, we do not want to scan the models directory - it should never have orphaned models.
|
||||
# We should only do the scan when the flag is set (which should only be set when testing).
|
||||
if self.app_config.scan_models_on_startup:
|
||||
self._register_orphaned_models()
|
||||
with catch_sigint():
|
||||
self._register_orphaned_models()
|
||||
|
||||
# Check all models' paths and confirm they exist. A model could be missing if it was installed on a volume
|
||||
# that isn't currently mounted. In this case, we don't want to delete the model from the database, but we do
|
||||
|
@ -1,6 +1,6 @@
|
||||
import shutil
|
||||
import tempfile
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional, TypeVar
|
||||
|
||||
@ -17,12 +17,6 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeleteAllResult:
|
||||
deleted_count: int
|
||||
freed_space_bytes: float
|
||||
|
||||
|
||||
class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
||||
"""Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`.
|
||||
|
||||
@ -35,6 +29,12 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
||||
self._ephemeral = ephemeral
|
||||
self._base_output_dir = output_dir
|
||||
self._base_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if self._ephemeral:
|
||||
# Remove dangling tempdirs that might have been left over from an earlier unplanned shutdown.
|
||||
for temp_dir in filter(Path.is_dir, self._base_output_dir.glob("tmp*")):
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
# Must specify `ignore_cleanup_errors` to avoid fatal errors during cleanup on Windows
|
||||
self._tempdir = (
|
||||
tempfile.TemporaryDirectory(dir=self._base_output_dir, ignore_cleanup_errors=True) if ephemeral else None
|
||||
|
@ -301,12 +301,12 @@ class MainConfigBase(ModelConfigBase):
|
||||
default_settings: Optional[MainModelDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
)
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
|
||||
|
||||
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
|
||||
"""Model config for main checkpoint models."""
|
||||
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
|
@ -155,7 +155,7 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
StarterModel(
|
||||
name="IP Adapter",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="InvokeAI/ip_adapter_sd15",
|
||||
source="https://huggingface.co/InvokeAI/ip_adapter_sd15/resolve/main/ip-adapter_sd15.safetensors",
|
||||
description="IP-Adapter for SD 1.5 models",
|
||||
type=ModelType.IPAdapter,
|
||||
dependencies=[ip_adapter_sd_image_encoder],
|
||||
@ -163,7 +163,7 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
StarterModel(
|
||||
name="IP Adapter Plus",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="InvokeAI/ip_adapter_plus_sd15",
|
||||
source="https://huggingface.co/InvokeAI/ip_adapter_plus_sd15/resolve/main/ip-adapter-plus_sd15.safetensors",
|
||||
description="Refined IP-Adapter for SD 1.5 models",
|
||||
type=ModelType.IPAdapter,
|
||||
dependencies=[ip_adapter_sd_image_encoder],
|
||||
@ -171,7 +171,7 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
StarterModel(
|
||||
name="IP Adapter Plus Face",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="InvokeAI/ip_adapter_plus_face_sd15",
|
||||
source="https://huggingface.co/InvokeAI/ip_adapter_plus_face_sd15/resolve/main/ip-adapter-plus-face_sd15.safetensors",
|
||||
description="Refined IP-Adapter for SD 1.5 models, adapted for faces",
|
||||
type=ModelType.IPAdapter,
|
||||
dependencies=[ip_adapter_sd_image_encoder],
|
||||
@ -179,7 +179,7 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
StarterModel(
|
||||
name="IP Adapter SDXL",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="InvokeAI/ip_adapter_sdxl",
|
||||
source="https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h/resolve/main/ip-adapter_sdxl_vit-h.safetensors",
|
||||
description="IP-Adapter for SDXL models",
|
||||
type=ModelType.IPAdapter,
|
||||
dependencies=[ip_adapter_sdxl_image_encoder],
|
||||
|
29
invokeai/backend/util/catch_sigint.py
Normal file
29
invokeai/backend/util/catch_sigint.py
Normal file
@ -0,0 +1,29 @@
|
||||
"""
|
||||
This module defines a context manager `catch_sigint()` which temporarily replaces
|
||||
the sigINT handler defined by the ASGI in order to allow the user to ^C the application
|
||||
and shut it down immediately. This was implemented in order to allow the user to interrupt
|
||||
slow model hashing during startup.
|
||||
|
||||
Use like this:
|
||||
|
||||
from invokeai.backend.util.catch_sigint import catch_sigint
|
||||
with catch_sigint():
|
||||
run_some_hard_to_interrupt_process()
|
||||
"""
|
||||
|
||||
import signal
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
|
||||
|
||||
def sigint_handler(signum, frame): # type: ignore
|
||||
signal.signal(signal.SIGINT, signal.SIG_DFL)
|
||||
signal.raise_signal(signal.SIGINT)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def catch_sigint() -> Generator[None, None, None]:
|
||||
original_handler = signal.getsignal(signal.SIGINT)
|
||||
signal.signal(signal.SIGINT, sigint_handler)
|
||||
yield
|
||||
signal.signal(signal.SIGINT, original_handler)
|
@ -11,6 +11,7 @@ import { createStore } from '../src/app/store/store';
|
||||
// @ts-ignore
|
||||
import translationEN from '../public/locales/en.json';
|
||||
import { ReduxInit } from './ReduxInit';
|
||||
import { $store } from 'app/store/nanostores/store';
|
||||
|
||||
i18n.use(initReactI18next).init({
|
||||
lng: 'en',
|
||||
@ -25,6 +26,7 @@ i18n.use(initReactI18next).init({
|
||||
});
|
||||
|
||||
const store = createStore(undefined, false);
|
||||
$store.set(store);
|
||||
$baseUrl.set('http://localhost:9090');
|
||||
|
||||
const preview: Preview = {
|
||||
|
@ -25,7 +25,7 @@
|
||||
"typegen": "node scripts/typegen.js",
|
||||
"preview": "vite preview",
|
||||
"lint:knip": "knip",
|
||||
"lint:dpdm": "dpdm --no-warning --no-tree --transform --exit-code circular:1 src/main.tsx",
|
||||
"lint:dpdm": "dpdm --no-warning --no-tree --transform --exit-code circular:0 src/main.tsx",
|
||||
"lint:eslint": "eslint --max-warnings=0 .",
|
||||
"lint:prettier": "prettier --check .",
|
||||
"lint:tsc": "tsc --noEmit",
|
||||
@ -95,6 +95,7 @@
|
||||
"reactflow": "^11.10.4",
|
||||
"redux-dynamic-middlewares": "^2.2.0",
|
||||
"redux-remember": "^5.1.0",
|
||||
"redux-undo": "^1.1.0",
|
||||
"rfdc": "^1.3.1",
|
||||
"roarr": "^7.21.1",
|
||||
"serialize-error": "^11.0.3",
|
||||
|
@ -140,6 +140,9 @@ dependencies:
|
||||
redux-remember:
|
||||
specifier: ^5.1.0
|
||||
version: 5.1.0(redux@5.0.1)
|
||||
redux-undo:
|
||||
specifier: ^1.1.0
|
||||
version: 1.1.0
|
||||
rfdc:
|
||||
specifier: ^1.3.1
|
||||
version: 1.3.1
|
||||
@ -11962,6 +11965,10 @@ packages:
|
||||
redux: 5.0.1
|
||||
dev: false
|
||||
|
||||
/redux-undo@1.1.0:
|
||||
resolution: {integrity: sha512-zzLFh2qeF0MTIlzDhDLm9NtkfBqCllQJ3OCuIl5RKlG/ayHw6GUdIFdMhzMS9NnrnWdBX5u//ExMOHpfudGGOg==}
|
||||
dev: false
|
||||
|
||||
/redux@5.0.1:
|
||||
resolution: {integrity: sha512-M9/ELqF6fy8FwmkpnF0S3YKOqMyoWJ4+CS5Efg2ct3oY9daQvd/Pc71FpGZsVsbl3Cpb+IIcjBDUnnyBdQbq4w==}
|
||||
dev: false
|
||||
|
BIN
invokeai/frontend/web/public/assets/images/transparent_bg.png
Normal file
BIN
invokeai/frontend/web/public/assets/images/transparent_bg.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.7 KiB |
@ -84,6 +84,8 @@
|
||||
"direction": "Direction",
|
||||
"ipAdapter": "IP Adapter",
|
||||
"t2iAdapter": "T2I Adapter",
|
||||
"positivePrompt": "Positive Prompt",
|
||||
"negativePrompt": "Negative Prompt",
|
||||
"discordLabel": "Discord",
|
||||
"dontAskMeAgain": "Don't ask me again",
|
||||
"error": "Error",
|
||||
@ -136,7 +138,9 @@
|
||||
"red": "Red",
|
||||
"green": "Green",
|
||||
"blue": "Blue",
|
||||
"alpha": "Alpha"
|
||||
"alpha": "Alpha",
|
||||
"selected": "Selected",
|
||||
"viewer": "Viewer"
|
||||
},
|
||||
"controlnet": {
|
||||
"controlAdapter_one": "Control Adapter",
|
||||
@ -893,6 +897,7 @@
|
||||
"denoisingStrength": "Denoising Strength",
|
||||
"downloadImage": "Download Image",
|
||||
"general": "General",
|
||||
"globalSettings": "Global Settings",
|
||||
"height": "Height",
|
||||
"imageFit": "Fit Initial Image To Output Size",
|
||||
"images": "Images",
|
||||
@ -1505,5 +1510,27 @@
|
||||
},
|
||||
"app": {
|
||||
"storeNotInitialized": "Store is not initialized"
|
||||
},
|
||||
"regionalPrompts": {
|
||||
"deleteAll": "Delete All",
|
||||
"addLayer": "Add Layer",
|
||||
"moveToFront": "Move to Front",
|
||||
"moveToBack": "Move to Back",
|
||||
"moveForward": "Move Forward",
|
||||
"moveBackward": "Move Backward",
|
||||
"brushSize": "Brush Size",
|
||||
"regionalControl": "Regional Control (ALPHA)",
|
||||
"enableRegionalPrompts": "Enable $t(regionalPrompts.regionalPrompts)",
|
||||
"globalMaskOpacity": "Global Mask Opacity",
|
||||
"autoNegative": "Auto Negative",
|
||||
"toggleVisibility": "Toggle Layer Visibility",
|
||||
"deletePrompt": "Delete Prompt",
|
||||
"resetRegion": "Reset Region",
|
||||
"debugLayers": "Debug Layers",
|
||||
"rectangle": "Rectangle",
|
||||
"maskPreviewColor": "Mask Preview Color",
|
||||
"addPositivePrompt": "Add $t(common.positivePrompt)",
|
||||
"addNegativePrompt": "Add $t(common.negativePrompt)",
|
||||
"addIPAdapter": "Add $t(common.ipAdapter)"
|
||||
}
|
||||
}
|
||||
|
@ -27,7 +27,8 @@ export type LoggerNamespace =
|
||||
| 'socketio'
|
||||
| 'session'
|
||||
| 'queue'
|
||||
| 'dnd';
|
||||
| 'dnd'
|
||||
| 'regionalPrompts';
|
||||
|
||||
export const logger = (namespace: LoggerNamespace) => $logger.get().child({ namespace });
|
||||
|
||||
|
@ -21,6 +21,11 @@ import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workf
|
||||
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice';
|
||||
import { queueSlice } from 'features/queue/store/queueSlice';
|
||||
import {
|
||||
regionalPromptsPersistConfig,
|
||||
regionalPromptsSlice,
|
||||
regionalPromptsUndoableConfig,
|
||||
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { sdxlPersistConfig, sdxlSlice } from 'features/sdxl/store/sdxlSlice';
|
||||
import { configSlice } from 'features/system/store/configSlice';
|
||||
import { systemPersistConfig, systemSlice } from 'features/system/store/systemSlice';
|
||||
@ -30,6 +35,7 @@ import { defaultsDeep, keys, omit, pick } from 'lodash-es';
|
||||
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
||||
import type { SerializeFunction, UnserializeFunction } from 'redux-remember';
|
||||
import { rememberEnhancer, rememberReducer } from 'redux-remember';
|
||||
import undoable from 'redux-undo';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { api } from 'services/api';
|
||||
import { authToastMiddleware } from 'services/api/authToastMiddleware';
|
||||
@ -59,6 +65,7 @@ const allReducers = {
|
||||
[queueSlice.name]: queueSlice.reducer,
|
||||
[workflowSlice.name]: workflowSlice.reducer,
|
||||
[hrfSlice.name]: hrfSlice.reducer,
|
||||
[regionalPromptsSlice.name]: undoable(regionalPromptsSlice.reducer, regionalPromptsUndoableConfig),
|
||||
[api.reducerPath]: api.reducer,
|
||||
};
|
||||
|
||||
@ -103,6 +110,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
|
||||
[loraPersistConfig.name]: loraPersistConfig,
|
||||
[modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig,
|
||||
[hrfPersistConfig.name]: hrfPersistConfig,
|
||||
[regionalPromptsPersistConfig.name]: regionalPromptsPersistConfig,
|
||||
};
|
||||
|
||||
const unserialize: UnserializeFunction = (data, key) => {
|
||||
@ -114,6 +122,7 @@ const unserialize: UnserializeFunction = (data, key) => {
|
||||
try {
|
||||
const { initialState, migrate } = persistConfig;
|
||||
const parsed = JSON.parse(data);
|
||||
|
||||
// strip out old keys
|
||||
const stripped = pick(parsed, keys(initialState));
|
||||
// run (additive) migrations
|
||||
@ -141,7 +150,9 @@ const serialize: SerializeFunction = (data, key) => {
|
||||
if (!persistConfig) {
|
||||
throw new Error(`No persist config for slice "${key}"`);
|
||||
}
|
||||
const result = omit(data, persistConfig.persistDenylist);
|
||||
// Heuristic to determine if the slice is undoable - could just hardcode it in the persistConfig
|
||||
const isUndoable = 'present' in data && 'past' in data && 'future' in data && '_latestUnfiltered' in data;
|
||||
const result = omit(isUndoable ? data.present : data, persistConfig.persistDenylist);
|
||||
return JSON.stringify(result);
|
||||
};
|
||||
|
||||
|
@ -26,7 +26,7 @@ const sx: ChakraProps['sx'] = {
|
||||
|
||||
const colorPickerStyles: CSSProperties = { width: '100%' };
|
||||
|
||||
const numberInputWidth: ChakraProps['w'] = '4.2rem';
|
||||
const numberInputWidth: ChakraProps['w'] = '3.5rem';
|
||||
|
||||
const IAIColorPicker = (props: IAIColorPickerProps) => {
|
||||
const { color, onChange, withNumberInput, ...rest } = props;
|
||||
@ -41,7 +41,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
|
||||
{withNumberInput && (
|
||||
<Flex gap={5}>
|
||||
<FormControl gap={0}>
|
||||
<FormLabel>{t('common.red')}</FormLabel>
|
||||
<FormLabel>{t('common.red')[0]}</FormLabel>
|
||||
<CompositeNumberInput
|
||||
value={color.r}
|
||||
onChange={handleChangeR}
|
||||
@ -53,7 +53,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl gap={0}>
|
||||
<FormLabel>{t('common.green')}</FormLabel>
|
||||
<FormLabel>{t('common.green')[0]}</FormLabel>
|
||||
<CompositeNumberInput
|
||||
value={color.g}
|
||||
onChange={handleChangeG}
|
||||
@ -65,7 +65,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl gap={0}>
|
||||
<FormLabel>{t('common.blue')}</FormLabel>
|
||||
<FormLabel>{t('common.blue')[0]}</FormLabel>
|
||||
<CompositeNumberInput
|
||||
value={color.b}
|
||||
onChange={handleChangeB}
|
||||
@ -77,7 +77,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl gap={0}>
|
||||
<FormLabel>{t('common.alpha')}</FormLabel>
|
||||
<FormLabel>{t('common.alpha')[0]}</FormLabel>
|
||||
<CompositeNumberInput
|
||||
value={color.a}
|
||||
onChange={handleChangeA}
|
||||
|
@ -0,0 +1,84 @@
|
||||
import type { ChakraProps } from '@invoke-ai/ui-library';
|
||||
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { RgbColorPicker as ColorfulRgbColorPicker } from 'react-colorful';
|
||||
import type { ColorPickerBaseProps, RgbColor } from 'react-colorful/dist/types';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type RgbColorPickerProps = ColorPickerBaseProps<RgbColor> & {
|
||||
withNumberInput?: boolean;
|
||||
};
|
||||
|
||||
const colorPickerPointerStyles: NonNullable<ChakraProps['sx']> = {
|
||||
width: 6,
|
||||
height: 6,
|
||||
borderColor: 'base.100',
|
||||
};
|
||||
|
||||
const sx: ChakraProps['sx'] = {
|
||||
'.react-colorful__hue-pointer': colorPickerPointerStyles,
|
||||
'.react-colorful__saturation-pointer': colorPickerPointerStyles,
|
||||
'.react-colorful__alpha-pointer': colorPickerPointerStyles,
|
||||
gap: 5,
|
||||
flexDir: 'column',
|
||||
};
|
||||
|
||||
const colorPickerStyles: CSSProperties = { width: '100%' };
|
||||
|
||||
const numberInputWidth: ChakraProps['w'] = '3.5rem';
|
||||
|
||||
const RgbColorPicker = (props: RgbColorPickerProps) => {
|
||||
const { color, onChange, withNumberInput, ...rest } = props;
|
||||
const { t } = useTranslation();
|
||||
const handleChangeR = useCallback((r: number) => onChange({ ...color, r }), [color, onChange]);
|
||||
const handleChangeG = useCallback((g: number) => onChange({ ...color, g }), [color, onChange]);
|
||||
const handleChangeB = useCallback((b: number) => onChange({ ...color, b }), [color, onChange]);
|
||||
return (
|
||||
<Flex sx={sx}>
|
||||
<ColorfulRgbColorPicker color={color} onChange={onChange} style={colorPickerStyles} {...rest} />
|
||||
{withNumberInput && (
|
||||
<Flex gap={5}>
|
||||
<FormControl gap={0}>
|
||||
<FormLabel>{t('common.red')[0]}</FormLabel>
|
||||
<CompositeNumberInput
|
||||
value={color.r}
|
||||
onChange={handleChangeR}
|
||||
min={0}
|
||||
max={255}
|
||||
step={1}
|
||||
w={numberInputWidth}
|
||||
defaultValue={90}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl gap={0}>
|
||||
<FormLabel>{t('common.green')[0]}</FormLabel>
|
||||
<CompositeNumberInput
|
||||
value={color.g}
|
||||
onChange={handleChangeG}
|
||||
min={0}
|
||||
max={255}
|
||||
step={1}
|
||||
w={numberInputWidth}
|
||||
defaultValue={90}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl gap={0}>
|
||||
<FormLabel>{t('common.blue')[0]}</FormLabel>
|
||||
<CompositeNumberInput
|
||||
value={color.b}
|
||||
onChange={handleChangeB}
|
||||
min={0}
|
||||
max={255}
|
||||
step={1}
|
||||
w={numberInputWidth}
|
||||
defaultValue={255}
|
||||
/>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(RgbColorPicker);
|
85
invokeai/frontend/web/src/common/util/arrayUtils.test.ts
Normal file
85
invokeai/frontend/web/src/common/util/arrayUtils.test.ts
Normal file
@ -0,0 +1,85 @@
|
||||
import { moveBackward, moveForward, moveToBack, moveToFront } from 'common/util/arrayUtils';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
describe('Array Manipulation Functions', () => {
|
||||
const originalArray = ['a', 'b', 'c', 'd'];
|
||||
describe('moveForwardOne', () => {
|
||||
it('should move an item forward by one position', () => {
|
||||
const array = [...originalArray];
|
||||
const result = moveForward(array, (item) => item === 'b');
|
||||
expect(result).toEqual(['a', 'c', 'b', 'd']);
|
||||
});
|
||||
|
||||
it('should do nothing if the item is at the end', () => {
|
||||
const array = [...originalArray];
|
||||
const result = moveForward(array, (item) => item === 'd');
|
||||
expect(result).toEqual(['a', 'b', 'c', 'd']);
|
||||
});
|
||||
|
||||
it("should leave the array unchanged if the item isn't in the array", () => {
|
||||
const array = [...originalArray];
|
||||
const result = moveForward(array, (item) => item === 'z');
|
||||
expect(result).toEqual(originalArray);
|
||||
});
|
||||
});
|
||||
|
||||
describe('moveToFront', () => {
|
||||
it('should move an item to the front', () => {
|
||||
const array = [...originalArray];
|
||||
const result = moveToFront(array, (item) => item === 'c');
|
||||
expect(result).toEqual(['c', 'a', 'b', 'd']);
|
||||
});
|
||||
|
||||
it('should do nothing if the item is already at the front', () => {
|
||||
const array = [...originalArray];
|
||||
const result = moveToFront(array, (item) => item === 'a');
|
||||
expect(result).toEqual(['a', 'b', 'c', 'd']);
|
||||
});
|
||||
|
||||
it("should leave the array unchanged if the item isn't in the array", () => {
|
||||
const array = [...originalArray];
|
||||
const result = moveToFront(array, (item) => item === 'z');
|
||||
expect(result).toEqual(originalArray);
|
||||
});
|
||||
});
|
||||
|
||||
describe('moveBackwardsOne', () => {
|
||||
it('should move an item backward by one position', () => {
|
||||
const array = [...originalArray];
|
||||
const result = moveBackward(array, (item) => item === 'c');
|
||||
expect(result).toEqual(['a', 'c', 'b', 'd']);
|
||||
});
|
||||
|
||||
it('should do nothing if the item is at the beginning', () => {
|
||||
const array = [...originalArray];
|
||||
const result = moveBackward(array, (item) => item === 'a');
|
||||
expect(result).toEqual(['a', 'b', 'c', 'd']);
|
||||
});
|
||||
|
||||
it("should leave the array unchanged if the item isn't in the array", () => {
|
||||
const array = [...originalArray];
|
||||
const result = moveBackward(array, (item) => item === 'z');
|
||||
expect(result).toEqual(originalArray);
|
||||
});
|
||||
});
|
||||
|
||||
describe('moveToBack', () => {
|
||||
it('should move an item to the back', () => {
|
||||
const array = [...originalArray];
|
||||
const result = moveToBack(array, (item) => item === 'b');
|
||||
expect(result).toEqual(['a', 'c', 'd', 'b']);
|
||||
});
|
||||
|
||||
it('should do nothing if the item is already at the back', () => {
|
||||
const array = [...originalArray];
|
||||
const result = moveToBack(array, (item) => item === 'd');
|
||||
expect(result).toEqual(['a', 'b', 'c', 'd']);
|
||||
});
|
||||
|
||||
it("should leave the array unchanged if the item isn't in the array", () => {
|
||||
const array = [...originalArray];
|
||||
const result = moveToBack(array, (item) => item === 'z');
|
||||
expect(result).toEqual(originalArray);
|
||||
});
|
||||
});
|
||||
});
|
37
invokeai/frontend/web/src/common/util/arrayUtils.ts
Normal file
37
invokeai/frontend/web/src/common/util/arrayUtils.ts
Normal file
@ -0,0 +1,37 @@
|
||||
export const moveForward = <T>(array: T[], callback: (item: T) => boolean): T[] => {
|
||||
const index = array.findIndex(callback);
|
||||
if (index >= 0 && index < array.length - 1) {
|
||||
//@ts-expect-error - These indicies are safe per the previous check
|
||||
[array[index], array[index + 1]] = [array[index + 1], array[index]];
|
||||
}
|
||||
return array;
|
||||
};
|
||||
|
||||
export const moveToFront = <T>(array: T[], callback: (item: T) => boolean): T[] => {
|
||||
const index = array.findIndex(callback);
|
||||
if (index > 0) {
|
||||
const [item] = array.splice(index, 1);
|
||||
//@ts-expect-error - These indicies are safe per the previous check
|
||||
array.unshift(item);
|
||||
}
|
||||
return array;
|
||||
};
|
||||
|
||||
export const moveBackward = <T>(array: T[], callback: (item: T) => boolean): T[] => {
|
||||
const index = array.findIndex(callback);
|
||||
if (index > 0) {
|
||||
//@ts-expect-error - These indicies are safe per the previous check
|
||||
[array[index], array[index - 1]] = [array[index - 1], array[index]];
|
||||
}
|
||||
return array;
|
||||
};
|
||||
|
||||
export const moveToBack = <T>(array: T[], callback: (item: T) => boolean): T[] => {
|
||||
const index = array.findIndex(callback);
|
||||
if (index >= 0 && index < array.length - 1) {
|
||||
const [item] = array.splice(index, 1);
|
||||
//@ts-expect-error - These indicies are safe per the previous check
|
||||
array.push(item);
|
||||
}
|
||||
return array;
|
||||
};
|
@ -10,6 +10,18 @@ import { clamp } from 'lodash-es';
|
||||
import type { MutableRefObject } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
|
||||
export const calculateNewBrushSize = (brushSize: number, delta: number) => {
|
||||
// This equation was derived by fitting a curve to the desired brush sizes and deltas
|
||||
// see https://github.com/invoke-ai/InvokeAI/pull/5542#issuecomment-1915847565
|
||||
const targetDelta = Math.sign(delta) * 0.7363 * Math.pow(1.0394, brushSize);
|
||||
// This needs to be clamped to prevent the delta from getting too large
|
||||
const finalDelta = clamp(targetDelta, -20, 20);
|
||||
// The new brush size is also clamped to prevent it from getting too large or small
|
||||
const newBrushSize = clamp(brushSize + finalDelta, 1, 500);
|
||||
|
||||
return newBrushSize;
|
||||
};
|
||||
|
||||
const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const stageScale = useAppSelector((s) => s.canvas.stageScale);
|
||||
@ -36,15 +48,7 @@ const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
|
||||
}
|
||||
|
||||
if ($ctrl.get() || $meta.get()) {
|
||||
// This equation was derived by fitting a curve to the desired brush sizes and deltas
|
||||
// see https://github.com/invoke-ai/InvokeAI/pull/5542#issuecomment-1915847565
|
||||
const targetDelta = Math.sign(delta) * 0.7363 * Math.pow(1.0394, brushSize);
|
||||
// This needs to be clamped to prevent the delta from getting too large
|
||||
const finalDelta = clamp(targetDelta, -20, 20);
|
||||
// The new brush size is also clamped to prevent it from getting too large or small
|
||||
const newBrushSize = clamp(brushSize + finalDelta, 1, 500);
|
||||
|
||||
dispatch(setBrushSize(newBrushSize));
|
||||
dispatch(setBrushSize(calculateNewBrushSize(brushSize, delta)));
|
||||
} else {
|
||||
const cursorPos = stageRef.current.getPointerPosition();
|
||||
let delta = e.evt.deltaY;
|
||||
|
@ -7,3 +7,22 @@ export const blobToDataURL = (blob: Blob): Promise<string> => {
|
||||
reader.readAsDataURL(blob);
|
||||
});
|
||||
};
|
||||
|
||||
export function imageDataToDataURL(imageData: ImageData): string {
|
||||
const { width, height } = imageData;
|
||||
|
||||
// Create a canvas to transfer the ImageData to
|
||||
const canvas = document.createElement('canvas');
|
||||
canvas.width = width;
|
||||
canvas.height = height;
|
||||
|
||||
// Draw the ImageData onto the canvas
|
||||
const ctx = canvas.getContext('2d');
|
||||
if (!ctx) {
|
||||
throw new Error('Unable to get canvas context');
|
||||
}
|
||||
ctx.putImageData(imageData, 0, 0);
|
||||
|
||||
// Convert the canvas to a data URL (base64)
|
||||
return canvas.toDataURL();
|
||||
}
|
||||
|
@ -1,6 +1,11 @@
|
||||
import type { RgbaColor } from 'react-colorful';
|
||||
import type { RgbaColor, RgbColor } from 'react-colorful';
|
||||
|
||||
export const rgbaColorToString = (color: RgbaColor): string => {
|
||||
const { r, g, b, a } = color;
|
||||
return `rgba(${r}, ${g}, ${b}, ${a})`;
|
||||
};
|
||||
|
||||
export const rgbColorToString = (color: RgbColor): string => {
|
||||
const { r, g, b } = color;
|
||||
return `rgba(${r}, ${g}, ${b})`;
|
||||
};
|
||||
|
@ -6,6 +6,7 @@ import { deepClone } from 'common/util/deepClone';
|
||||
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
|
||||
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { maskLayerIPAdapterAdded } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { merge, uniq } from 'lodash-es';
|
||||
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import { socketInvocationError } from 'services/events/actions';
|
||||
@ -382,6 +383,10 @@ export const controlAdaptersSlice = createSlice({
|
||||
builder.addCase(socketInvocationError, (state) => {
|
||||
state.pendingControlImages = [];
|
||||
});
|
||||
|
||||
builder.addCase(maskLayerIPAdapterAdded, (state, action) => {
|
||||
caAdapter.addOne(state, buildControlAdapter(action.meta.uuid, 'ip_adapter'));
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { Box, Flex, IconButton, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, IconButton, Tooltip, useShiftModifier } from '@invoke-ai/ui-library';
|
||||
import { getOverlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
|
||||
import { isString } from 'lodash-es';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
@ -9,18 +9,19 @@ import { PiCopyBold, PiDownloadSimpleBold } from 'react-icons/pi';
|
||||
|
||||
type Props = {
|
||||
label: string;
|
||||
data: object | string;
|
||||
data: unknown;
|
||||
fileName?: string;
|
||||
withDownload?: boolean;
|
||||
withCopy?: boolean;
|
||||
extraCopyActions?: { label: string; getData: (data: unknown) => unknown }[];
|
||||
};
|
||||
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams('scroll', 'scroll').options;
|
||||
|
||||
const DataViewer = (props: Props) => {
|
||||
const { label, data, fileName, withDownload = true, withCopy = true } = props;
|
||||
const { label, data, fileName, withDownload = true, withCopy = true, extraCopyActions } = props;
|
||||
const dataString = useMemo(() => (isString(data) ? data : JSON.stringify(data, null, 2)), [data]);
|
||||
|
||||
const shift = useShiftModifier();
|
||||
const handleCopy = useCallback(() => {
|
||||
navigator.clipboard.writeText(dataString);
|
||||
}, [dataString]);
|
||||
@ -67,6 +68,10 @@ const DataViewer = (props: Props) => {
|
||||
/>
|
||||
</Tooltip>
|
||||
)}
|
||||
{shift &&
|
||||
extraCopyActions?.map(({ label, getData }) => (
|
||||
<ExtraCopyAction label={label} getData={getData} data={data} key={label} />
|
||||
))}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
@ -78,3 +83,27 @@ const overlayScrollbarsStyles: CSSProperties = {
|
||||
height: '100%',
|
||||
width: '100%',
|
||||
};
|
||||
|
||||
type ExtraCopyActionProps = {
|
||||
label: string;
|
||||
data: unknown;
|
||||
getData: (data: unknown) => unknown;
|
||||
};
|
||||
const ExtraCopyAction = ({ label, data, getData }: ExtraCopyActionProps) => {
|
||||
const { t } = useTranslation();
|
||||
const handleCopy = useCallback(() => {
|
||||
navigator.clipboard.writeText(JSON.stringify(getData(data), null, 2));
|
||||
}, [data, getData]);
|
||||
|
||||
return (
|
||||
<Tooltip label={`${t('gallery.copy')} ${label} JSON`}>
|
||||
<IconButton
|
||||
aria-label={`${t('gallery.copy')} ${label} JSON`}
|
||||
icon={<PiCopyBold size={16} />}
|
||||
variant="ghost"
|
||||
opacity={0.7}
|
||||
onClick={handleCopy}
|
||||
/>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { Button, Flex, Heading, Text } from '@invoke-ai/ui-library';
|
||||
import { Button, Flex, Heading, SimpleGrid, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/hooks/useControlNetOrT2IAdapterDefaultSettings';
|
||||
import { DefaultPreprocessor } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/DefaultPreprocessor';
|
||||
@ -92,13 +92,9 @@ export const ControlNetOrT2IAdapterDefaultSettings = () => {
|
||||
</Button>
|
||||
</Flex>
|
||||
|
||||
<Flex flexDir="column" gap={8}>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<DefaultPreprocessor control={control} name="preprocessor" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
<SimpleGrid columns={2} gap={8}>
|
||||
<DefaultPreprocessor control={control} name="preprocessor" />
|
||||
</SimpleGrid>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { Button, Flex, Heading, Text } from '@invoke-ai/ui-library';
|
||||
import { Button, Flex, Heading, SimpleGrid, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useMainModelDefaultSettings } from 'features/modelManagerV2/hooks/useMainModelDefaultSettings';
|
||||
import { DefaultHeight } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultHeight';
|
||||
@ -122,40 +122,16 @@ export const MainModelDefaultSettings = () => {
|
||||
</Button>
|
||||
</Flex>
|
||||
|
||||
<Flex flexDir="column" gap={8}>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<DefaultVae control={control} name="vae" />
|
||||
</Flex>
|
||||
<Flex gap={4} w="full">
|
||||
<DefaultVaePrecision control={control} name="vaePrecision" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<DefaultScheduler control={control} name="scheduler" />
|
||||
</Flex>
|
||||
<Flex gap={4} w="full">
|
||||
<DefaultSteps control={control} name="steps" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<DefaultCfgScale control={control} name="cfgScale" />
|
||||
</Flex>
|
||||
<Flex gap={4} w="full">
|
||||
<DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<DefaultWidth control={control} optimalDimension={optimalDimension} />
|
||||
</Flex>
|
||||
<Flex gap={4} w="full">
|
||||
<DefaultHeight control={control} optimalDimension={optimalDimension} />
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
<SimpleGrid columns={2} gap={8}>
|
||||
<DefaultVae control={control} name="vae" />
|
||||
<DefaultVaePrecision control={control} name="vaePrecision" />
|
||||
<DefaultScheduler control={control} name="scheduler" />
|
||||
<DefaultSteps control={control} name="steps" />
|
||||
<DefaultCfgScale control={control} name="cfgScale" />
|
||||
<DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />
|
||||
<DefaultWidth control={control} optimalDimension={optimalDimension} />
|
||||
<DefaultHeight control={control} optimalDimension={optimalDimension} />
|
||||
</SimpleGrid>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
@ -6,6 +6,7 @@ import {
|
||||
FormLabel,
|
||||
Heading,
|
||||
Input,
|
||||
SimpleGrid,
|
||||
Text,
|
||||
Textarea,
|
||||
} from '@invoke-ai/ui-library';
|
||||
@ -66,25 +67,21 @@ export const ModelEdit = ({ form }: Props) => {
|
||||
<Heading as="h3" fontSize="md" mt="4">
|
||||
{t('modelManager.modelSettings')}
|
||||
</Heading>
|
||||
<Flex gap={4}>
|
||||
<SimpleGrid columns={2} gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
||||
<BaseModelSelect control={form.control} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
{data.type === 'main' && data.format === 'checkpoint' && (
|
||||
<>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||
<ModelVariantSelect control={form.control} />
|
||||
</FormControl>
|
||||
{data.type === 'main' && data.format === 'checkpoint' && (
|
||||
<>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
|
||||
<Input {...form.register('config_path', stringFieldOptions)} />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||
<ModelVariantSelect control={form.control} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
||||
<PredictionTypeSelect control={form.control} />
|
||||
@ -93,9 +90,9 @@ export const ModelEdit = ({ form }: Props) => {
|
||||
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
||||
<Checkbox {...form.register('upcast_attention')} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</SimpleGrid>
|
||||
</Flex>
|
||||
</form>
|
||||
</Flex>
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { Box, Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, SimpleGrid, Text } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { ControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings';
|
||||
@ -24,57 +24,32 @@ export const ModelView = () => {
|
||||
return (
|
||||
<Flex flexDir="column" h="full" gap={4}>
|
||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||
<Flex flexDir="column" gap={4}>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.baseModel')} value={data.base} />
|
||||
<ModelAttrView label={t('modelManager.modelType')} value={data.type} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('common.format')} value={data.format} />
|
||||
<ModelAttrView label={t('modelManager.path')} value={data.path} />
|
||||
</Flex>
|
||||
|
||||
<SimpleGrid columns={2} gap={4}>
|
||||
<ModelAttrView label={t('modelManager.baseModel')} value={data.base} />
|
||||
<ModelAttrView label={t('modelManager.modelType')} value={data.type} />
|
||||
<ModelAttrView label={t('common.format')} value={data.format} />
|
||||
<ModelAttrView label={t('modelManager.path')} value={data.path} />
|
||||
{data.type === 'main' && <ModelAttrView label={t('modelManager.variant')} value={data.variant} />}
|
||||
{data.type === 'main' && data.format === 'diffusers' && data.repo_variant && (
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.repoVariant')} value={data.repo_variant} />
|
||||
</Flex>
|
||||
<ModelAttrView label={t('modelManager.repoVariant')} value={data.repo_variant} />
|
||||
)}
|
||||
|
||||
{data.type === 'main' && data.format === 'checkpoint' && (
|
||||
<>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.pathToConfig')} value={data.config_path} />
|
||||
<ModelAttrView label={t('modelManager.variant')} value={data.variant} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.predictionType')} value={data.prediction_type} />
|
||||
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${data.upcast_attention}`} />
|
||||
</Flex>
|
||||
<ModelAttrView label={t('modelManager.pathToConfig')} value={data.config_path} />
|
||||
<ModelAttrView label={t('modelManager.predictionType')} value={data.prediction_type} />
|
||||
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${data.upcast_attention}`} />
|
||||
</>
|
||||
)}
|
||||
|
||||
{data.type === 'ip_adapter' && data.format === 'invokeai' && (
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={data.image_encoder_model_id} />
|
||||
</Flex>
|
||||
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={data.image_encoder_model_id} />
|
||||
)}
|
||||
</Flex>
|
||||
</SimpleGrid>
|
||||
</Box>
|
||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||
{data.type === 'main' && data.base !== 'sdxl-refiner' && <MainModelDefaultSettings />}
|
||||
{(data.type === 'controlnet' || data.type === 't2i_adapter') && <ControlNetOrT2IAdapterDefaultSettings />}
|
||||
{(data.type === 'main' || data.type === 'lora') && <TriggerPhrases />}
|
||||
</Box>
|
||||
{data.type === 'main' && data.base !== 'sdxl-refiner' && (
|
||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||
<MainModelDefaultSettings />
|
||||
</Box>
|
||||
)}
|
||||
{(data.type === 'controlnet' || data.type === 't2i_adapter') && (
|
||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||
<ControlNetOrT2IAdapterDefaultSettings />
|
||||
</Box>
|
||||
)}
|
||||
{(data.type === 'main' || data.type === 'lora') && (
|
||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||
<TriggerPhrases />
|
||||
</Box>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -77,9 +77,17 @@ export const TriggerPhrases = () => {
|
||||
[updateModel, selectedModelKey, triggerPhrases]
|
||||
);
|
||||
|
||||
const onTriggerPhraseAddFormSubmit = useCallback(
|
||||
(e: React.FormEvent<HTMLFormElement>) => {
|
||||
e.preventDefault();
|
||||
addTriggerPhrase();
|
||||
},
|
||||
[addTriggerPhrase]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" w="full" gap="5">
|
||||
<form>
|
||||
<form onSubmit={onTriggerPhraseAddFormSubmit}>
|
||||
<FormControl w="full" isInvalid={Boolean(errors.length)} orientation="vertical">
|
||||
<FormLabel>{t('modelManager.triggerPhrases')}</FormLabel>
|
||||
<Flex flexDir="column" w="full">
|
||||
|
@ -23,6 +23,7 @@ export type NodesState = {
|
||||
nodeOpacity: number;
|
||||
shouldSnapToGrid: boolean;
|
||||
shouldColorEdges: boolean;
|
||||
shouldShowEdgeLabels: boolean;
|
||||
selectedNodes: string[];
|
||||
selectedEdges: string[];
|
||||
nodeExecutionStates: Record<string, NodeExecutionState>;
|
||||
@ -32,7 +33,6 @@ export type NodesState = {
|
||||
isAddNodePopoverOpen: boolean;
|
||||
addNewNodePosition: XYPosition | null;
|
||||
selectionMode: SelectionMode;
|
||||
shouldShowEdgeLabels: boolean;
|
||||
};
|
||||
|
||||
export type WorkflowMode = 'edit' | 'view';
|
||||
|
@ -19,12 +19,14 @@ export const addIPAdapterToLinearGraph = async (
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string
|
||||
): Promise<void> => {
|
||||
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(({ model, controlImage, isEnabled }) => {
|
||||
const hasModel = Boolean(model);
|
||||
const doesBaseMatch = model?.base === state.generation.model?.base;
|
||||
const hasControlImage = controlImage;
|
||||
return isEnabled && hasModel && doesBaseMatch && hasControlImage;
|
||||
});
|
||||
const validIPAdapters = selectValidIPAdapters(state.controlAdapters)
|
||||
.filter(({ model, controlImage, isEnabled }) => {
|
||||
const hasModel = Boolean(model);
|
||||
const doesBaseMatch = model?.base === state.generation.model?.base;
|
||||
const hasControlImage = controlImage;
|
||||
return isEnabled && hasModel && doesBaseMatch && hasControlImage;
|
||||
})
|
||||
.filter((ca) => !state.regionalPrompts.present.layers.some((l) => l.ipAdapterIds.includes(ca.id)));
|
||||
|
||||
if (validIPAdapters.length) {
|
||||
// Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect
|
||||
|
@ -0,0 +1,346 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { selectAllIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import {
|
||||
IP_ADAPTER_COLLECT,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NEGATIVE_CONDITIONING_COLLECT,
|
||||
POSITIVE_CONDITIONING,
|
||||
POSITIVE_CONDITIONING_COLLECT,
|
||||
PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX,
|
||||
PROMPT_REGION_MASK_TO_TENSOR_PREFIX,
|
||||
PROMPT_REGION_NEGATIVE_COND_PREFIX,
|
||||
PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX,
|
||||
PROMPT_REGION_POSITIVE_COND_PREFIX,
|
||||
} from 'features/nodes/util/graph/constants';
|
||||
import { isVectorMaskLayer } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { getRegionalPromptLayerBlobs } from 'features/regionalPrompts/util/getLayerBlobs';
|
||||
import { size } from 'lodash-es';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { CollectInvocation, Edge, IPAdapterInvocation, NonNullableGraph, S } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNullableGraph, denoiseNodeId: string) => {
|
||||
if (!state.regionalPrompts.present.isEnabled) {
|
||||
return;
|
||||
}
|
||||
const { dispatch } = getStore();
|
||||
const isSDXL = state.generation.model?.base === 'sdxl';
|
||||
const layers = state.regionalPrompts.present.layers
|
||||
// Only support vector mask layers now
|
||||
// TODO: Image masks
|
||||
.filter(isVectorMaskLayer)
|
||||
// Only visible layers are rendered on the canvas
|
||||
.filter((l) => l.isVisible)
|
||||
// Only layers with prompts get added to the graph
|
||||
.filter((l) => {
|
||||
const hasTextPrompt = Boolean(l.positivePrompt || l.negativePrompt);
|
||||
const hasIPAdapter = l.ipAdapterIds.length !== 0;
|
||||
return hasTextPrompt || hasIPAdapter;
|
||||
});
|
||||
|
||||
const regionalIPAdapters = selectAllIPAdapters(state.controlAdapters).filter(
|
||||
({ id, model, controlImage, isEnabled }) => {
|
||||
const hasModel = Boolean(model);
|
||||
const doesBaseMatch = model?.base === state.generation.model?.base;
|
||||
const hasControlImage = controlImage;
|
||||
const isRegional = layers.some((l) => l.ipAdapterIds.includes(id));
|
||||
return isEnabled && hasModel && doesBaseMatch && hasControlImage && isRegional;
|
||||
}
|
||||
);
|
||||
|
||||
const layerIds = layers.map((l) => l.id);
|
||||
const blobs = await getRegionalPromptLayerBlobs(layerIds);
|
||||
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');
|
||||
|
||||
// TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing
|
||||
// the existing conditioning nodes.
|
||||
|
||||
// With regional prompts we have multiple conditioning nodes which much be routed into collectors. Set those up
|
||||
const posCondCollectNode: CollectInvocation = {
|
||||
id: POSITIVE_CONDITIONING_COLLECT,
|
||||
type: 'collect',
|
||||
};
|
||||
graph.nodes[POSITIVE_CONDITIONING_COLLECT] = posCondCollectNode;
|
||||
const negCondCollectNode: CollectInvocation = {
|
||||
id: NEGATIVE_CONDITIONING_COLLECT,
|
||||
type: 'collect',
|
||||
};
|
||||
graph.nodes[NEGATIVE_CONDITIONING_COLLECT] = negCondCollectNode;
|
||||
|
||||
// Re-route the denoise node's OG conditioning inputs to the collect nodes
|
||||
const newEdges: Edge[] = [];
|
||||
for (const edge of graph.edges) {
|
||||
if (edge.destination.node_id === denoiseNodeId && edge.destination.field === 'positive_conditioning') {
|
||||
newEdges.push({
|
||||
source: edge.source,
|
||||
destination: {
|
||||
node_id: POSITIVE_CONDITIONING_COLLECT,
|
||||
field: 'item',
|
||||
},
|
||||
});
|
||||
} else if (edge.destination.node_id === denoiseNodeId && edge.destination.field === 'negative_conditioning') {
|
||||
newEdges.push({
|
||||
source: edge.source,
|
||||
destination: {
|
||||
node_id: NEGATIVE_CONDITIONING_COLLECT,
|
||||
field: 'item',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
newEdges.push(edge);
|
||||
}
|
||||
}
|
||||
graph.edges = newEdges;
|
||||
|
||||
// Connect collectors to the denoise nodes - must happen _after_ rerouting else you get cycles
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: POSITIVE_CONDITIONING_COLLECT,
|
||||
field: 'collection',
|
||||
},
|
||||
destination: {
|
||||
node_id: denoiseNodeId,
|
||||
field: 'positive_conditioning',
|
||||
},
|
||||
});
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: NEGATIVE_CONDITIONING_COLLECT,
|
||||
field: 'collection',
|
||||
},
|
||||
destination: {
|
||||
node_id: denoiseNodeId,
|
||||
field: 'negative_conditioning',
|
||||
},
|
||||
});
|
||||
|
||||
if (!graph.nodes[IP_ADAPTER_COLLECT] && regionalIPAdapters.length > 0) {
|
||||
const ipAdapterCollectNode: CollectInvocation = {
|
||||
id: IP_ADAPTER_COLLECT,
|
||||
type: 'collect',
|
||||
is_intermediate: true,
|
||||
};
|
||||
graph.nodes[IP_ADAPTER_COLLECT] = ipAdapterCollectNode;
|
||||
graph.edges.push({
|
||||
source: { node_id: IP_ADAPTER_COLLECT, field: 'collection' },
|
||||
destination: {
|
||||
node_id: denoiseNodeId,
|
||||
field: 'ip_adapter',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Upload the blobs to the backend, add each to graph
|
||||
// TODO: Store the uploaded image names in redux to reuse them, so long as the layer hasn't otherwise changed. This
|
||||
// would be a great perf win - not only would we skip re-uploading the same image, but we'd be able to use the node
|
||||
// cache (currently, when we re-use the same mask data, since it is a different image, the node cache is not used).
|
||||
for (const layer of layers) {
|
||||
const blob = blobs[layer.id];
|
||||
assert(blob, `Blob for layer ${layer.id} not found`);
|
||||
|
||||
const file = new File([blob], `${layer.id}_mask.png`, { type: 'image/png' });
|
||||
const req = dispatch(
|
||||
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
|
||||
);
|
||||
req.reset();
|
||||
|
||||
// TODO: This will raise on network error
|
||||
const { image_name } = await req.unwrap();
|
||||
|
||||
// The main mask-to-tensor node
|
||||
const maskToTensorNode: S['AlphaMaskToTensorInvocation'] = {
|
||||
id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${layer.id}`,
|
||||
type: 'alpha_mask_to_tensor',
|
||||
image: {
|
||||
image_name,
|
||||
},
|
||||
};
|
||||
graph.nodes[maskToTensorNode.id] = maskToTensorNode;
|
||||
|
||||
if (layer.positivePrompt) {
|
||||
// The main positive conditioning node
|
||||
const regionalPositiveCondNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL
|
||||
? {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`,
|
||||
prompt: layer.positivePrompt,
|
||||
style: layer.positivePrompt, // TODO: Should we put the positive prompt in both fields?
|
||||
}
|
||||
: {
|
||||
type: 'compel',
|
||||
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`,
|
||||
prompt: layer.positivePrompt,
|
||||
};
|
||||
graph.nodes[regionalPositiveCondNode.id] = regionalPositiveCondNode;
|
||||
|
||||
// Connect the mask to the conditioning
|
||||
graph.edges.push({
|
||||
source: { node_id: maskToTensorNode.id, field: 'mask' },
|
||||
destination: { node_id: regionalPositiveCondNode.id, field: 'mask' },
|
||||
});
|
||||
|
||||
// Connect the conditioning to the collector
|
||||
graph.edges.push({
|
||||
source: { node_id: regionalPositiveCondNode.id, field: 'conditioning' },
|
||||
destination: { node_id: posCondCollectNode.id, field: 'item' },
|
||||
});
|
||||
|
||||
// Copy the connections to the "global" positive conditioning node to the regional cond
|
||||
for (const edge of graph.edges) {
|
||||
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
|
||||
graph.edges.push({
|
||||
source: edge.source,
|
||||
destination: { node_id: regionalPositiveCondNode.id, field: edge.destination.field },
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (layer.negativePrompt) {
|
||||
// The main negative conditioning node
|
||||
const regionalNegativeCondNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL
|
||||
? {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`,
|
||||
prompt: layer.negativePrompt,
|
||||
style: layer.negativePrompt,
|
||||
}
|
||||
: {
|
||||
type: 'compel',
|
||||
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`,
|
||||
prompt: layer.negativePrompt,
|
||||
};
|
||||
graph.nodes[regionalNegativeCondNode.id] = regionalNegativeCondNode;
|
||||
|
||||
// Connect the mask to the conditioning
|
||||
graph.edges.push({
|
||||
source: { node_id: maskToTensorNode.id, field: 'mask' },
|
||||
destination: { node_id: regionalNegativeCondNode.id, field: 'mask' },
|
||||
});
|
||||
|
||||
// Connect the conditioning to the collector
|
||||
graph.edges.push({
|
||||
source: { node_id: regionalNegativeCondNode.id, field: 'conditioning' },
|
||||
destination: { node_id: negCondCollectNode.id, field: 'item' },
|
||||
});
|
||||
|
||||
// Copy the connections to the "global" negative conditioning node to the regional cond
|
||||
for (const edge of graph.edges) {
|
||||
if (edge.destination.node_id === NEGATIVE_CONDITIONING && edge.destination.field !== 'prompt') {
|
||||
graph.edges.push({
|
||||
source: edge.source,
|
||||
destination: { node_id: regionalNegativeCondNode.id, field: edge.destination.field },
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node
|
||||
if (layer.autoNegative === 'invert' && layer.positivePrompt) {
|
||||
// We re-use the mask image, but invert it when converting to tensor
|
||||
const invertTensorMaskNode: S['InvertTensorMaskInvocation'] = {
|
||||
id: `${PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX}_${layer.id}`,
|
||||
type: 'invert_tensor_mask',
|
||||
};
|
||||
graph.nodes[invertTensorMaskNode.id] = invertTensorMaskNode;
|
||||
|
||||
// Connect the OG mask image to the inverted mask-to-tensor node
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: maskToTensorNode.id,
|
||||
field: 'mask',
|
||||
},
|
||||
destination: {
|
||||
node_id: invertTensorMaskNode.id,
|
||||
field: 'mask',
|
||||
},
|
||||
});
|
||||
|
||||
// Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the
|
||||
// positive prompt
|
||||
const regionalPositiveCondInvertedNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL
|
||||
? {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`,
|
||||
prompt: layer.positivePrompt,
|
||||
style: layer.positivePrompt,
|
||||
}
|
||||
: {
|
||||
type: 'compel',
|
||||
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`,
|
||||
prompt: layer.positivePrompt,
|
||||
};
|
||||
graph.nodes[regionalPositiveCondInvertedNode.id] = regionalPositiveCondInvertedNode;
|
||||
// Connect the inverted mask to the conditioning
|
||||
graph.edges.push({
|
||||
source: { node_id: invertTensorMaskNode.id, field: 'mask' },
|
||||
destination: { node_id: regionalPositiveCondInvertedNode.id, field: 'mask' },
|
||||
});
|
||||
// Connect the conditioning to the negative collector
|
||||
graph.edges.push({
|
||||
source: { node_id: regionalPositiveCondInvertedNode.id, field: 'conditioning' },
|
||||
destination: { node_id: negCondCollectNode.id, field: 'item' },
|
||||
});
|
||||
// Copy the connections to the "global" positive conditioning node to our regional node
|
||||
for (const edge of graph.edges) {
|
||||
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
|
||||
graph.edges.push({
|
||||
source: edge.source,
|
||||
destination: { node_id: regionalPositiveCondInvertedNode.id, field: edge.destination.field },
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const ipAdapterId of layer.ipAdapterIds) {
|
||||
const ipAdapter = selectAllIPAdapters(state.controlAdapters)
|
||||
.filter(({ id, model, controlImage, isEnabled }) => {
|
||||
const hasModel = Boolean(model);
|
||||
const doesBaseMatch = model?.base === state.generation.model?.base;
|
||||
const hasControlImage = controlImage;
|
||||
const isRegional = layers.some((l) => l.ipAdapterIds.includes(id));
|
||||
return isEnabled && hasModel && doesBaseMatch && hasControlImage && isRegional;
|
||||
})
|
||||
.find((ca) => ca.id === ipAdapterId);
|
||||
|
||||
if (!ipAdapter?.model) {
|
||||
return;
|
||||
}
|
||||
const { id, weight, model, clipVisionModel, method, beginStepPct, endStepPct, controlImage } = ipAdapter;
|
||||
|
||||
assert(controlImage, 'IP Adapter image is required');
|
||||
|
||||
const ipAdapterNode: IPAdapterInvocation = {
|
||||
id: `ip_adapter_${id}`,
|
||||
type: 'ip_adapter',
|
||||
is_intermediate: true,
|
||||
weight: weight,
|
||||
method: method,
|
||||
ip_adapter_model: model,
|
||||
clip_vision_model: clipVisionModel,
|
||||
begin_step_percent: beginStepPct,
|
||||
end_step_percent: endStepPct,
|
||||
image: {
|
||||
image_name: controlImage,
|
||||
},
|
||||
};
|
||||
|
||||
graph.nodes[ipAdapterNode.id] = ipAdapterNode;
|
||||
|
||||
// Connect the mask to the conditioning
|
||||
graph.edges.push({
|
||||
source: { node_id: maskToTensorNode.id, field: 'mask' },
|
||||
destination: { node_id: ipAdapterNode.id, field: 'mask' },
|
||||
});
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
|
||||
destination: {
|
||||
node_id: IP_ADAPTER_COLLECT,
|
||||
field: 'item',
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
@ -9,6 +9,7 @@ import {
|
||||
CANVAS_TEXT_TO_IMAGE_GRAPH,
|
||||
IMAGE_TO_IMAGE_GRAPH,
|
||||
IMAGE_TO_LATENTS,
|
||||
INPAINT_CREATE_MASK,
|
||||
INPAINT_IMAGE,
|
||||
LATENTS_TO_IMAGE,
|
||||
MAIN_MODEL_LOADER,
|
||||
@ -145,6 +146,16 @@ export const addVAEToGraph = async (
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: isSeamlessEnabled ? SEAMLESS : isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
source: {
|
||||
|
@ -133,6 +133,8 @@ export const buildCanvasInpaintGraph = async (
|
||||
coherence_mode: canvasCoherenceMode,
|
||||
minimum_denoise: canvasCoherenceMinDenoise,
|
||||
edge_radius: canvasCoherenceEdgeSize,
|
||||
tiled: false,
|
||||
fp32: fp32,
|
||||
},
|
||||
[DENOISE_LATENTS]: {
|
||||
type: 'denoise_latents',
|
||||
@ -182,6 +184,16 @@ export const buildCanvasInpaintGraph = async (
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: modelLoaderNodeId,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'unet',
|
||||
},
|
||||
},
|
||||
// Connect CLIP Skip to Conditioning
|
||||
{
|
||||
source: {
|
||||
@ -331,6 +343,16 @@ export const buildCanvasInpaintGraph = async (
|
||||
field: 'mask',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: INPAINT_IMAGE_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
// Resize Down
|
||||
{
|
||||
source: {
|
||||
|
@ -157,6 +157,8 @@ export const buildCanvasOutpaintGraph = async (
|
||||
coherence_mode: canvasCoherenceMode,
|
||||
edge_radius: canvasCoherenceEdgeSize,
|
||||
minimum_denoise: canvasCoherenceMinDenoise,
|
||||
tiled: false,
|
||||
fp32: fp32,
|
||||
},
|
||||
[DENOISE_LATENTS]: {
|
||||
type: 'denoise_latents',
|
||||
@ -207,6 +209,16 @@ export const buildCanvasOutpaintGraph = async (
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: modelLoaderNodeId,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'unet',
|
||||
},
|
||||
},
|
||||
// Connect CLIP Skip to Conditioning
|
||||
{
|
||||
source: {
|
||||
@ -453,6 +465,16 @@ export const buildCanvasOutpaintGraph = async (
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: INPAINT_IMAGE_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
// Resize Results Down
|
||||
{
|
||||
source: {
|
||||
|
@ -135,6 +135,8 @@ export const buildCanvasSDXLInpaintGraph = async (
|
||||
coherence_mode: canvasCoherenceMode,
|
||||
minimum_denoise: refinerModel ? Math.max(0.2, canvasCoherenceMinDenoise) : canvasCoherenceMinDenoise,
|
||||
edge_radius: canvasCoherenceEdgeSize,
|
||||
tiled: false,
|
||||
fp32: fp32,
|
||||
},
|
||||
[SDXL_DENOISE_LATENTS]: {
|
||||
type: 'denoise_latents',
|
||||
@ -214,6 +216,16 @@ export const buildCanvasSDXLInpaintGraph = async (
|
||||
field: 'clip2',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: modelLoaderNodeId,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'unet',
|
||||
},
|
||||
},
|
||||
// Connect Everything To Inpaint Node
|
||||
{
|
||||
source: {
|
||||
@ -342,6 +354,16 @@ export const buildCanvasSDXLInpaintGraph = async (
|
||||
field: 'mask',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: INPAINT_IMAGE_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
// Resize Down
|
||||
{
|
||||
source: {
|
||||
|
@ -157,6 +157,8 @@ export const buildCanvasSDXLOutpaintGraph = async (
|
||||
coherence_mode: canvasCoherenceMode,
|
||||
edge_radius: canvasCoherenceEdgeSize,
|
||||
minimum_denoise: refinerModel ? Math.max(0.2, canvasCoherenceMinDenoise) : canvasCoherenceMinDenoise,
|
||||
tiled: false,
|
||||
fp32: fp32,
|
||||
},
|
||||
[SDXL_DENOISE_LATENTS]: {
|
||||
type: 'denoise_latents',
|
||||
@ -237,6 +239,16 @@ export const buildCanvasSDXLOutpaintGraph = async (
|
||||
field: 'clip2',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: modelLoaderNodeId,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'unet',
|
||||
},
|
||||
},
|
||||
// Connect Infill Result To Inpaint Image
|
||||
{
|
||||
source: {
|
||||
@ -451,6 +463,16 @@ export const buildCanvasSDXLOutpaintGraph = async (
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: INPAINT_IMAGE_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
// Take combined mask and resize
|
||||
{
|
||||
source: {
|
||||
|
@ -1,6 +1,7 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { addRegionalPromptsToGraph } from 'features/nodes/util/graph/addRegionalPromptsToGraph';
|
||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
@ -273,6 +274,8 @@ export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise
|
||||
|
||||
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
await addRegionalPromptsToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
// must add before watermarker!
|
||||
|
@ -1,6 +1,7 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { addRegionalPromptsToGraph } from 'features/nodes/util/graph/addRegionalPromptsToGraph';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||
|
||||
@ -255,6 +256,8 @@ export const buildLinearTextToImageGraph = async (state: RootState): Promise<Non
|
||||
|
||||
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
await addRegionalPromptsToGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// High resolution fix.
|
||||
if (state.hrf.hrfEnabled) {
|
||||
addHrfToGraph(state, graph);
|
||||
|
@ -46,6 +46,13 @@ export const SDXL_REFINER_DENOISE_LATENTS = 'sdxl_refiner_denoise_latents';
|
||||
export const SDXL_REFINER_INPAINT_CREATE_MASK = 'refiner_inpaint_create_mask';
|
||||
export const SEAMLESS = 'seamless';
|
||||
export const SDXL_REFINER_SEAMLESS = 'refiner_seamless';
|
||||
export const PROMPT_REGION_MASK_TO_TENSOR_PREFIX = 'prompt_region_mask_to_tensor';
|
||||
export const PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX = 'prompt_region_invert_tensor_mask';
|
||||
export const PROMPT_REGION_POSITIVE_COND_PREFIX = 'prompt_region_positive_cond';
|
||||
export const PROMPT_REGION_NEGATIVE_COND_PREFIX = 'prompt_region_negative_cond';
|
||||
export const PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX = 'prompt_region_positive_cond_inverted';
|
||||
export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect';
|
||||
export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect';
|
||||
|
||||
// friendly graph ids
|
||||
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';
|
||||
|
@ -0,0 +1,13 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { StageComponent } from 'features/regionalPrompts/components/StageComponent';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const AspectRatioCanvasPreview = memo(() => {
|
||||
return (
|
||||
<Flex w="full" h="full" alignItems="center" justifyContent="center" position="relative">
|
||||
<StageComponent asPreview />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
AspectRatioCanvasPreview.displayName = 'AspectRatioCanvasPreview';
|
@ -2,7 +2,7 @@ import { useSize } from '@chakra-ui/react-use-size';
|
||||
import { Flex, Icon } from '@invoke-ai/ui-library';
|
||||
import { useImageSizeContext } from 'features/parameters/components/ImageSize/ImageSizeContext';
|
||||
import { AnimatePresence, motion } from 'framer-motion';
|
||||
import { useMemo, useRef } from 'react';
|
||||
import { memo, useMemo, useRef } from 'react';
|
||||
import { PiFrameCorners } from 'react-icons/pi';
|
||||
|
||||
import {
|
||||
@ -15,7 +15,7 @@ import {
|
||||
MOTION_ICON_INITIAL,
|
||||
} from './constants';
|
||||
|
||||
export const AspectRatioPreview = () => {
|
||||
export const AspectRatioIconPreview = memo(() => {
|
||||
const ctx = useImageSizeContext();
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const containerSize = useSize(containerRef);
|
||||
@ -70,4 +70,6 @@ export const AspectRatioPreview = () => {
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
AspectRatioIconPreview.displayName = 'AspectRatioIconPreview';
|
@ -1,6 +1,5 @@
|
||||
import type { FormLabelProps } from '@invoke-ai/ui-library';
|
||||
import { Flex, FormControlGroup } from '@invoke-ai/ui-library';
|
||||
import { AspectRatioPreview } from 'features/parameters/components/ImageSize/AspectRatioPreview';
|
||||
import { AspectRatioSelect } from 'features/parameters/components/ImageSize/AspectRatioSelect';
|
||||
import type { ImageSizeContextInnerValue } from 'features/parameters/components/ImageSize/ImageSizeContext';
|
||||
import { ImageSizeContext } from 'features/parameters/components/ImageSize/ImageSizeContext';
|
||||
@ -13,10 +12,11 @@ import { memo } from 'react';
|
||||
type ImageSizeProps = ImageSizeContextInnerValue & {
|
||||
widthComponent: ReactNode;
|
||||
heightComponent: ReactNode;
|
||||
previewComponent: ReactNode;
|
||||
};
|
||||
|
||||
export const ImageSize = memo((props: ImageSizeProps) => {
|
||||
const { widthComponent, heightComponent, ...ctx } = props;
|
||||
const { widthComponent, heightComponent, previewComponent, ...ctx } = props;
|
||||
return (
|
||||
<ImageSizeContext.Provider value={ctx}>
|
||||
<Flex gap={4} alignItems="center">
|
||||
@ -33,7 +33,7 @@ export const ImageSize = memo((props: ImageSizeProps) => {
|
||||
</FormControlGroup>
|
||||
</Flex>
|
||||
<Flex w="108px" h="108px" flexShrink={0} flexGrow={0}>
|
||||
<AspectRatioPreview />
|
||||
{previewComponent}
|
||||
</Flex>
|
||||
</Flex>
|
||||
</ImageSizeContext.Provider>
|
||||
|
@ -1,7 +1,6 @@
|
||||
import type { ComboboxOption } from '@invoke-ai/ui-library';
|
||||
|
||||
import type { AspectRatioID, AspectRatioState } from './types';
|
||||
|
||||
// When the aspect ratio is between these two values, we show the icon (experimentally determined)
|
||||
export const ICON_LOW_CUTOFF = 0.23;
|
||||
export const ICON_HIGH_CUTOFF = 1 / ICON_LOW_CUTOFF;
|
||||
@ -25,7 +24,6 @@ export const ICON_CONTAINER_STYLES = {
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
};
|
||||
|
||||
export const ASPECT_RATIO_OPTIONS: ComboboxOption[] = [
|
||||
{ label: 'Free' as const, value: 'Free' },
|
||||
{ label: '16:9' as const, value: '16:9' },
|
||||
|
@ -196,3 +196,8 @@ const zLoRAWeight = z.number();
|
||||
type ParameterLoRAWeight = z.infer<typeof zLoRAWeight>;
|
||||
export const isParameterLoRAWeight = (val: unknown): val is ParameterLoRAWeight => zLoRAWeight.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region Regional Prompts AutoNegative
|
||||
const zAutoNegative = z.enum(['off', 'invert']);
|
||||
export type ParameterAutoNegative = z.infer<typeof zAutoNegative>;
|
||||
// #endregion
|
||||
|
@ -3,6 +3,7 @@ import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataView
|
||||
import { useCancelBatch } from 'features/queue/hooks/useCancelBatch';
|
||||
import { useCancelQueueItem } from 'features/queue/hooks/useCancelQueueItem';
|
||||
import { getSecondsFromTimestamps } from 'features/queue/util/getSecondsFromTimestamps';
|
||||
import { get } from 'lodash-es';
|
||||
import type { ReactNode } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -92,7 +93,15 @@ const QueueItemComponent = ({ queueItemDTO }: Props) => {
|
||||
</Flex>
|
||||
)}
|
||||
<Flex layerStyle="second" h={512} w="full" borderRadius="base" alignItems="center" justifyContent="center">
|
||||
{queueItem ? <DataViewer label="Queue Item" data={queueItem} /> : <Spinner opacity={0.5} />}
|
||||
{queueItem ? (
|
||||
<DataViewer
|
||||
label="Queue Item"
|
||||
data={queueItem}
|
||||
extraCopyActions={[{ label: 'Graph', getData: (data) => get(data, 'session.graph') }]}
|
||||
/>
|
||||
) : (
|
||||
<Spinner opacity={0.5} />
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
|
@ -0,0 +1,22 @@
|
||||
import { Button } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { layerAdded } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
|
||||
export const AddLayerButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const onClick = useCallback(() => {
|
||||
dispatch(layerAdded('vector_mask_layer'));
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Button onClick={onClick} leftIcon={<PiPlusBold />} variant="ghost">
|
||||
{t('regionalPrompts.addLayer')}
|
||||
</Button>
|
||||
);
|
||||
});
|
||||
|
||||
AddLayerButton.displayName = 'AddLayerButton';
|
@ -0,0 +1,70 @@
|
||||
import { Button, Flex } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
isVectorMaskLayer,
|
||||
maskLayerIPAdapterAdded,
|
||||
maskLayerNegativePromptChanged,
|
||||
maskLayerPositivePromptChanged,
|
||||
selectRegionalPromptsSlice,
|
||||
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
import { assert } from 'tsafe';
|
||||
type AddPromptButtonProps = {
|
||||
layerId: string;
|
||||
};
|
||||
|
||||
export const AddPromptButtons = ({ layerId }: AddPromptButtonProps) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const selectValidActions = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
|
||||
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
|
||||
assert(isVectorMaskLayer(layer), `Layer ${layerId} not found or not an RP layer`);
|
||||
return {
|
||||
canAddPositivePrompt: layer.positivePrompt === null,
|
||||
canAddNegativePrompt: layer.negativePrompt === null,
|
||||
};
|
||||
}),
|
||||
[layerId]
|
||||
);
|
||||
const validActions = useAppSelector(selectValidActions);
|
||||
const addPositivePrompt = useCallback(() => {
|
||||
dispatch(maskLayerPositivePromptChanged({ layerId, prompt: '' }));
|
||||
}, [dispatch, layerId]);
|
||||
const addNegativePrompt = useCallback(() => {
|
||||
dispatch(maskLayerNegativePromptChanged({ layerId, prompt: '' }));
|
||||
}, [dispatch, layerId]);
|
||||
const addIPAdapter = useCallback(() => {
|
||||
dispatch(maskLayerIPAdapterAdded(layerId));
|
||||
}, [dispatch, layerId]);
|
||||
|
||||
return (
|
||||
<Flex w="full" p={2} justifyContent="space-between">
|
||||
<Button
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
leftIcon={<PiPlusBold />}
|
||||
onClick={addPositivePrompt}
|
||||
isDisabled={!validActions.canAddPositivePrompt}
|
||||
>
|
||||
{t('common.positivePrompt')}
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
leftIcon={<PiPlusBold />}
|
||||
onClick={addNegativePrompt}
|
||||
isDisabled={!validActions.canAddNegativePrompt}
|
||||
>
|
||||
{t('common.negativePrompt')}
|
||||
</Button>
|
||||
<Button size="sm" variant="ghost" leftIcon={<PiPlusBold />} onClick={addIPAdapter}>
|
||||
{t('common.ipAdapter')}
|
||||
</Button>
|
||||
</Flex>
|
||||
);
|
||||
};
|
@ -0,0 +1,63 @@
|
||||
import {
|
||||
CompositeNumberInput,
|
||||
CompositeSlider,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
Popover,
|
||||
PopoverArrow,
|
||||
PopoverBody,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { brushSizeChanged, initialRegionalPromptsState } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const marks = [0, 100, 200, 300];
|
||||
const formatPx = (v: number | string) => `${v} px`;
|
||||
|
||||
export const BrushSize = memo(() => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const brushSize = useAppSelector((s) => s.regionalPrompts.present.brushSize);
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
dispatch(brushSizeChanged(v));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
return (
|
||||
<FormControl w="min-content">
|
||||
<FormLabel m={0}>{t('regionalPrompts.brushSize')}</FormLabel>
|
||||
<Popover isLazy>
|
||||
<PopoverTrigger>
|
||||
<CompositeNumberInput
|
||||
min={1}
|
||||
max={600}
|
||||
defaultValue={initialRegionalPromptsState.brushSize}
|
||||
value={brushSize}
|
||||
onChange={onChange}
|
||||
w={24}
|
||||
format={formatPx}
|
||||
/>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent w={200} py={2} px={4}>
|
||||
<PopoverArrow />
|
||||
<PopoverBody>
|
||||
<CompositeSlider
|
||||
min={1}
|
||||
max={300}
|
||||
defaultValue={initialRegionalPromptsState.brushSize}
|
||||
value={brushSize}
|
||||
onChange={onChange}
|
||||
marks={marks}
|
||||
/>
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
BrushSize.displayName = 'BrushSize';
|
@ -0,0 +1,22 @@
|
||||
import { Button } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { allLayersDeleted } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||
|
||||
export const DeleteAllLayersButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const onClick = useCallback(() => {
|
||||
dispatch(allLayersDeleted());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Button onClick={onClick} leftIcon={<PiTrashSimpleBold />} variant="ghost" colorScheme="error">
|
||||
{t('regionalPrompts.deleteAll')}
|
||||
</Button>
|
||||
);
|
||||
});
|
||||
|
||||
DeleteAllLayersButton.displayName = 'DeleteAllLayersButton';
|
@ -0,0 +1,70 @@
|
||||
import {
|
||||
CompositeNumberInput,
|
||||
CompositeSlider,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
Popover,
|
||||
PopoverArrow,
|
||||
PopoverBody,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
globalMaskLayerOpacityChanged,
|
||||
initialRegionalPromptsState,
|
||||
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const marks = [0, 25, 50, 75, 100];
|
||||
const formatPct = (v: number | string) => `${v} %`;
|
||||
|
||||
export const GlobalMaskLayerOpacity = memo(() => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const globalMaskLayerOpacity = useAppSelector((s) =>
|
||||
Math.round(s.regionalPrompts.present.globalMaskLayerOpacity * 100)
|
||||
);
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
dispatch(globalMaskLayerOpacityChanged(v / 100));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
return (
|
||||
<FormControl w="min-content">
|
||||
<FormLabel m={0}>{t('regionalPrompts.globalMaskOpacity')}</FormLabel>
|
||||
<Popover isLazy>
|
||||
<PopoverTrigger>
|
||||
<CompositeNumberInput
|
||||
min={0}
|
||||
max={100}
|
||||
step={1}
|
||||
value={globalMaskLayerOpacity}
|
||||
defaultValue={initialRegionalPromptsState.globalMaskLayerOpacity * 100}
|
||||
onChange={onChange}
|
||||
w={24}
|
||||
format={formatPct}
|
||||
/>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent w={200} py={2} px={4}>
|
||||
<PopoverArrow />
|
||||
<PopoverBody>
|
||||
<CompositeSlider
|
||||
min={0}
|
||||
max={100}
|
||||
step={1}
|
||||
value={globalMaskLayerOpacity}
|
||||
defaultValue={initialRegionalPromptsState.globalMaskLayerOpacity * 100}
|
||||
onChange={onChange}
|
||||
marks={marks}
|
||||
/>
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
GlobalMaskLayerOpacity.displayName = 'GlobalMaskLayerOpacity';
|
@ -0,0 +1,51 @@
|
||||
import { Checkbox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
isVectorMaskLayer,
|
||||
maskLayerAutoNegativeChanged,
|
||||
selectRegionalPromptsSlice,
|
||||
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
type Props = {
|
||||
layerId: string;
|
||||
};
|
||||
|
||||
const useAutoNegative = (layerId: string) => {
|
||||
const selectAutoNegative = useMemo(
|
||||
() =>
|
||||
createSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
|
||||
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
|
||||
assert(isVectorMaskLayer(layer), `Layer ${layerId} not found or not an RP layer`);
|
||||
return layer.autoNegative;
|
||||
}),
|
||||
[layerId]
|
||||
);
|
||||
const autoNegative = useAppSelector(selectAutoNegative);
|
||||
return autoNegative;
|
||||
};
|
||||
|
||||
export const RPLayerAutoNegativeCheckbox = memo(({ layerId }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const autoNegative = useAutoNegative(layerId);
|
||||
const onChange = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
dispatch(maskLayerAutoNegativeChanged({ layerId, autoNegative: e.target.checked ? 'invert' : 'off' }));
|
||||
},
|
||||
[dispatch, layerId]
|
||||
);
|
||||
|
||||
return (
|
||||
<FormControl gap={2}>
|
||||
<FormLabel m={0}>{t('regionalPrompts.autoNegative')}</FormLabel>
|
||||
<Checkbox size="md" isChecked={autoNegative === 'invert'} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
RPLayerAutoNegativeCheckbox.displayName = 'RPLayerAutoNegativeCheckbox';
|
@ -0,0 +1,67 @@
|
||||
import { Flex, Popover, PopoverBody, PopoverContent, PopoverTrigger, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import RgbColorPicker from 'common/components/RgbColorPicker';
|
||||
import { rgbColorToString } from 'features/canvas/util/colorToString';
|
||||
import {
|
||||
isVectorMaskLayer,
|
||||
maskLayerPreviewColorChanged,
|
||||
selectRegionalPromptsSlice,
|
||||
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import type { RgbColor } from 'react-colorful';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
type Props = {
|
||||
layerId: string;
|
||||
};
|
||||
|
||||
export const RPLayerColorPicker = memo(({ layerId }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const selectColor = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
|
||||
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
|
||||
assert(isVectorMaskLayer(layer), `Layer ${layerId} not found or not an vector mask layer`);
|
||||
return layer.previewColor;
|
||||
}),
|
||||
[layerId]
|
||||
);
|
||||
const color = useAppSelector(selectColor);
|
||||
const dispatch = useAppDispatch();
|
||||
const onColorChange = useCallback(
|
||||
(color: RgbColor) => {
|
||||
dispatch(maskLayerPreviewColorChanged({ layerId, color }));
|
||||
},
|
||||
[dispatch, layerId]
|
||||
);
|
||||
return (
|
||||
<Popover isLazy>
|
||||
<PopoverTrigger>
|
||||
<span>
|
||||
<Tooltip label={t('regionalPrompts.maskPreviewColor')}>
|
||||
<Flex
|
||||
as="button"
|
||||
aria-label={t('regionalPrompts.maskPreviewColor')}
|
||||
borderRadius="base"
|
||||
borderWidth={1}
|
||||
bg={rgbColorToString(color)}
|
||||
w={8}
|
||||
h={8}
|
||||
cursor="pointer"
|
||||
tabIndex={-1}
|
||||
/>
|
||||
</Tooltip>
|
||||
</span>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent>
|
||||
<PopoverBody minH={64}>
|
||||
<RgbColorPicker color={color} onChange={onColorChange} withNumberInput />
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
});
|
||||
|
||||
RPLayerColorPicker.displayName = 'RPLayerColorPicker';
|
@ -0,0 +1,28 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { layerDeleted } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||
|
||||
type Props = { layerId: string };
|
||||
|
||||
export const RPLayerDeleteButton = memo(({ layerId }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const deleteLayer = useCallback(() => {
|
||||
dispatch(layerDeleted(layerId));
|
||||
}, [dispatch, layerId]);
|
||||
return (
|
||||
<IconButton
|
||||
size="sm"
|
||||
colorScheme="error"
|
||||
aria-label={t('common.delete')}
|
||||
tooltip={t('common.delete')}
|
||||
icon={<PiTrashSimpleBold />}
|
||||
onClick={deleteLayer}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
RPLayerDeleteButton.displayName = 'RPLayerDeleteButton';
|
@ -0,0 +1,34 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import ControlAdapterConfig from 'features/controlAdapters/components/ControlAdapterConfig';
|
||||
import { selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
type Props = {
|
||||
layerId: string;
|
||||
};
|
||||
|
||||
export const RPLayerIPAdapterList = memo(({ layerId }: Props) => {
|
||||
const selectIPAdapterIds = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
|
||||
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
|
||||
assert(layer, `Layer ${layerId} not found`);
|
||||
return layer.ipAdapterIds;
|
||||
}),
|
||||
[layerId]
|
||||
);
|
||||
const ipAdapterIds = useAppSelector(selectIPAdapterIds);
|
||||
|
||||
return (
|
||||
<Flex w="full" flexDir="column" gap={2}>
|
||||
{ipAdapterIds.map((id, index) => (
|
||||
<ControlAdapterConfig key={id} id={id} number={index + 1} />
|
||||
))}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
RPLayerIPAdapterList.displayName = 'RPLayerIPAdapterList';
|
@ -0,0 +1,87 @@
|
||||
import { Badge, Flex, Spacer } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { rgbColorToString } from 'features/canvas/util/colorToString';
|
||||
import { RPLayerColorPicker } from 'features/regionalPrompts/components/RPLayerColorPicker';
|
||||
import { RPLayerDeleteButton } from 'features/regionalPrompts/components/RPLayerDeleteButton';
|
||||
import { RPLayerIPAdapterList } from 'features/regionalPrompts/components/RPLayerIPAdapterList';
|
||||
import { RPLayerMenu } from 'features/regionalPrompts/components/RPLayerMenu';
|
||||
import { RPLayerNegativePrompt } from 'features/regionalPrompts/components/RPLayerNegativePrompt';
|
||||
import { RPLayerPositivePrompt } from 'features/regionalPrompts/components/RPLayerPositivePrompt';
|
||||
import RPLayerSettingsPopover from 'features/regionalPrompts/components/RPLayerSettingsPopover';
|
||||
import { RPLayerVisibilityToggle } from 'features/regionalPrompts/components/RPLayerVisibilityToggle';
|
||||
import {
|
||||
isVectorMaskLayer,
|
||||
layerSelected,
|
||||
selectRegionalPromptsSlice,
|
||||
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { AddPromptButtons } from './AddPromptButtons';
|
||||
|
||||
type Props = {
|
||||
layerId: string;
|
||||
};
|
||||
|
||||
export const RPLayerListItem = memo(({ layerId }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
|
||||
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
|
||||
assert(isVectorMaskLayer(layer), `Layer ${layerId} not found or not an RP layer`);
|
||||
return {
|
||||
color: rgbColorToString(layer.previewColor),
|
||||
hasPositivePrompt: layer.positivePrompt !== null,
|
||||
hasNegativePrompt: layer.negativePrompt !== null,
|
||||
hasIPAdapters: layer.ipAdapterIds.length > 0,
|
||||
isSelected: layerId === regionalPrompts.present.selectedLayerId,
|
||||
autoNegative: layer.autoNegative,
|
||||
};
|
||||
}),
|
||||
[layerId]
|
||||
);
|
||||
const { autoNegative, color, hasPositivePrompt, hasNegativePrompt, hasIPAdapters, isSelected } =
|
||||
useAppSelector(selector);
|
||||
const onClickCapture = useCallback(() => {
|
||||
// Must be capture so that the layer is selected before deleting/resetting/etc
|
||||
dispatch(layerSelected(layerId));
|
||||
}, [dispatch, layerId]);
|
||||
return (
|
||||
<Flex
|
||||
gap={2}
|
||||
onClickCapture={onClickCapture}
|
||||
bg={isSelected ? color : 'base.800'}
|
||||
ps={2}
|
||||
borderRadius="base"
|
||||
pe="1px"
|
||||
py="1px"
|
||||
cursor="pointer"
|
||||
>
|
||||
<Flex flexDir="column" gap={2} w="full" bg="base.850" p={2} borderRadius="base">
|
||||
<Flex gap={3} alignItems="center">
|
||||
<RPLayerVisibilityToggle layerId={layerId} />
|
||||
<RPLayerColorPicker layerId={layerId} />
|
||||
<Spacer />
|
||||
{autoNegative === 'invert' && (
|
||||
<Badge color="base.300" bg="transparent" borderWidth={1}>
|
||||
{t('regionalPrompts.autoNegative')}
|
||||
</Badge>
|
||||
)}
|
||||
<RPLayerDeleteButton layerId={layerId} />
|
||||
<RPLayerSettingsPopover layerId={layerId} />
|
||||
<RPLayerMenu layerId={layerId} />
|
||||
</Flex>
|
||||
{!hasPositivePrompt && !hasNegativePrompt && !hasIPAdapters && <AddPromptButtons layerId={layerId} />}
|
||||
{hasPositivePrompt && <RPLayerPositivePrompt layerId={layerId} />}
|
||||
{hasNegativePrompt && <RPLayerNegativePrompt layerId={layerId} />}
|
||||
{hasIPAdapters && <RPLayerIPAdapterList layerId={layerId} />}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
RPLayerListItem.displayName = 'RPLayerListItem';
|
@ -0,0 +1,120 @@
|
||||
import { IconButton, Menu, MenuButton, MenuDivider, MenuItem, MenuList } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
isVectorMaskLayer,
|
||||
layerDeleted,
|
||||
layerMovedBackward,
|
||||
layerMovedForward,
|
||||
layerMovedToBack,
|
||||
layerMovedToFront,
|
||||
layerReset,
|
||||
maskLayerIPAdapterAdded,
|
||||
maskLayerNegativePromptChanged,
|
||||
maskLayerPositivePromptChanged,
|
||||
selectRegionalPromptsSlice,
|
||||
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
PiArrowCounterClockwiseBold,
|
||||
PiArrowDownBold,
|
||||
PiArrowLineDownBold,
|
||||
PiArrowLineUpBold,
|
||||
PiArrowUpBold,
|
||||
PiDotsThreeVerticalBold,
|
||||
PiPlusBold,
|
||||
PiTrashSimpleBold,
|
||||
} from 'react-icons/pi';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
type Props = { layerId: string };
|
||||
|
||||
export const RPLayerMenu = memo(({ layerId }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const selectValidActions = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
|
||||
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
|
||||
assert(isVectorMaskLayer(layer), `Layer ${layerId} not found or not an RP layer`);
|
||||
const layerIndex = regionalPrompts.present.layers.findIndex((l) => l.id === layerId);
|
||||
const layerCount = regionalPrompts.present.layers.length;
|
||||
return {
|
||||
canAddPositivePrompt: layer.positivePrompt === null,
|
||||
canAddNegativePrompt: layer.negativePrompt === null,
|
||||
canMoveForward: layerIndex < layerCount - 1,
|
||||
canMoveBackward: layerIndex > 0,
|
||||
canMoveToFront: layerIndex < layerCount - 1,
|
||||
canMoveToBack: layerIndex > 0,
|
||||
};
|
||||
}),
|
||||
[layerId]
|
||||
);
|
||||
const validActions = useAppSelector(selectValidActions);
|
||||
const addPositivePrompt = useCallback(() => {
|
||||
dispatch(maskLayerPositivePromptChanged({ layerId, prompt: '' }));
|
||||
}, [dispatch, layerId]);
|
||||
const addNegativePrompt = useCallback(() => {
|
||||
dispatch(maskLayerNegativePromptChanged({ layerId, prompt: '' }));
|
||||
}, [dispatch, layerId]);
|
||||
const addIPAdapter = useCallback(() => {
|
||||
dispatch(maskLayerIPAdapterAdded(layerId));
|
||||
}, [dispatch, layerId]);
|
||||
const moveForward = useCallback(() => {
|
||||
dispatch(layerMovedForward(layerId));
|
||||
}, [dispatch, layerId]);
|
||||
const moveToFront = useCallback(() => {
|
||||
dispatch(layerMovedToFront(layerId));
|
||||
}, [dispatch, layerId]);
|
||||
const moveBackward = useCallback(() => {
|
||||
dispatch(layerMovedBackward(layerId));
|
||||
}, [dispatch, layerId]);
|
||||
const moveToBack = useCallback(() => {
|
||||
dispatch(layerMovedToBack(layerId));
|
||||
}, [dispatch, layerId]);
|
||||
const resetLayer = useCallback(() => {
|
||||
dispatch(layerReset(layerId));
|
||||
}, [dispatch, layerId]);
|
||||
const deleteLayer = useCallback(() => {
|
||||
dispatch(layerDeleted(layerId));
|
||||
}, [dispatch, layerId]);
|
||||
return (
|
||||
<Menu>
|
||||
<MenuButton as={IconButton} aria-label="Layer menu" size="sm" icon={<PiDotsThreeVerticalBold />} />
|
||||
<MenuList>
|
||||
<MenuItem onClick={addPositivePrompt} isDisabled={!validActions.canAddPositivePrompt} icon={<PiPlusBold />}>
|
||||
{t('regionalPrompts.addPositivePrompt')}
|
||||
</MenuItem>
|
||||
<MenuItem onClick={addNegativePrompt} isDisabled={!validActions.canAddNegativePrompt} icon={<PiPlusBold />}>
|
||||
{t('regionalPrompts.addNegativePrompt')}
|
||||
</MenuItem>
|
||||
<MenuItem onClick={addIPAdapter} icon={<PiPlusBold />}>
|
||||
{t('regionalPrompts.addIPAdapter')}
|
||||
</MenuItem>
|
||||
<MenuDivider />
|
||||
<MenuItem onClick={moveToFront} isDisabled={!validActions.canMoveToFront} icon={<PiArrowLineUpBold />}>
|
||||
{t('regionalPrompts.moveToFront')}
|
||||
</MenuItem>
|
||||
<MenuItem onClick={moveForward} isDisabled={!validActions.canMoveForward} icon={<PiArrowUpBold />}>
|
||||
{t('regionalPrompts.moveForward')}
|
||||
</MenuItem>
|
||||
<MenuItem onClick={moveBackward} isDisabled={!validActions.canMoveBackward} icon={<PiArrowDownBold />}>
|
||||
{t('regionalPrompts.moveBackward')}
|
||||
</MenuItem>
|
||||
<MenuItem onClick={moveToBack} isDisabled={!validActions.canMoveToBack} icon={<PiArrowLineDownBold />}>
|
||||
{t('regionalPrompts.moveToBack')}
|
||||
</MenuItem>
|
||||
<MenuDivider />
|
||||
<MenuItem onClick={resetLayer} icon={<PiArrowCounterClockwiseBold />}>
|
||||
{t('accessibility.reset')}
|
||||
</MenuItem>
|
||||
<MenuItem onClick={deleteLayer} icon={<PiTrashSimpleBold />} color="error.300">
|
||||
{t('common.delete')}
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
</Menu>
|
||||
);
|
||||
});
|
||||
|
||||
RPLayerMenu.displayName = 'RPLayerMenu';
|
@ -0,0 +1,58 @@
|
||||
import { Box, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
|
||||
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
||||
import { PromptPopover } from 'features/prompt/PromptPopover';
|
||||
import { usePrompt } from 'features/prompt/usePrompt';
|
||||
import { RPLayerPromptDeleteButton } from 'features/regionalPrompts/components/RPLayerPromptDeleteButton';
|
||||
import { useLayerNegativePrompt } from 'features/regionalPrompts/hooks/layerStateHooks';
|
||||
import { maskLayerNegativePromptChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type Props = {
|
||||
layerId: string;
|
||||
};
|
||||
|
||||
export const RPLayerNegativePrompt = memo(({ layerId }: Props) => {
|
||||
const prompt = useLayerNegativePrompt(layerId);
|
||||
const dispatch = useAppDispatch();
|
||||
const textareaRef = useRef<HTMLTextAreaElement>(null);
|
||||
const { t } = useTranslation();
|
||||
const _onChange = useCallback(
|
||||
(v: string) => {
|
||||
dispatch(maskLayerNegativePromptChanged({ layerId, prompt: v }));
|
||||
},
|
||||
[dispatch, layerId]
|
||||
);
|
||||
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown } = usePrompt({
|
||||
prompt,
|
||||
textareaRef,
|
||||
onChange: _onChange,
|
||||
});
|
||||
|
||||
return (
|
||||
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
|
||||
<Box pos="relative" w="full">
|
||||
<Textarea
|
||||
id="prompt"
|
||||
name="prompt"
|
||||
ref={textareaRef}
|
||||
value={prompt}
|
||||
placeholder={t('parameters.negativePromptPlaceholder')}
|
||||
onChange={onChange}
|
||||
onKeyDown={onKeyDown}
|
||||
variant="darkFilled"
|
||||
paddingRight={30}
|
||||
fontSize="sm"
|
||||
/>
|
||||
<PromptOverlayButtonWrapper>
|
||||
<RPLayerPromptDeleteButton layerId={layerId} polarity="negative" />
|
||||
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||
</PromptOverlayButtonWrapper>
|
||||
</Box>
|
||||
</PromptPopover>
|
||||
);
|
||||
});
|
||||
|
||||
RPLayerNegativePrompt.displayName = 'RPLayerNegativePrompt';
|
@ -0,0 +1,58 @@
|
||||
import { Box, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
|
||||
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
||||
import { PromptPopover } from 'features/prompt/PromptPopover';
|
||||
import { usePrompt } from 'features/prompt/usePrompt';
|
||||
import { RPLayerPromptDeleteButton } from 'features/regionalPrompts/components/RPLayerPromptDeleteButton';
|
||||
import { useLayerPositivePrompt } from 'features/regionalPrompts/hooks/layerStateHooks';
|
||||
import { maskLayerPositivePromptChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type Props = {
|
||||
layerId: string;
|
||||
};
|
||||
|
||||
export const RPLayerPositivePrompt = memo(({ layerId }: Props) => {
|
||||
const prompt = useLayerPositivePrompt(layerId);
|
||||
const dispatch = useAppDispatch();
|
||||
const textareaRef = useRef<HTMLTextAreaElement>(null);
|
||||
const { t } = useTranslation();
|
||||
const _onChange = useCallback(
|
||||
(v: string) => {
|
||||
dispatch(maskLayerPositivePromptChanged({ layerId, prompt: v }));
|
||||
},
|
||||
[dispatch, layerId]
|
||||
);
|
||||
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown } = usePrompt({
|
||||
prompt,
|
||||
textareaRef,
|
||||
onChange: _onChange,
|
||||
});
|
||||
|
||||
return (
|
||||
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
|
||||
<Box pos="relative" w="full">
|
||||
<Textarea
|
||||
id="prompt"
|
||||
name="prompt"
|
||||
ref={textareaRef}
|
||||
value={prompt}
|
||||
placeholder={t('parameters.positivePromptPlaceholder')}
|
||||
onChange={onChange}
|
||||
onKeyDown={onKeyDown}
|
||||
variant="darkFilled"
|
||||
paddingRight={30}
|
||||
minH={28}
|
||||
/>
|
||||
<PromptOverlayButtonWrapper>
|
||||
<RPLayerPromptDeleteButton layerId={layerId} polarity="positive" />
|
||||
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||
</PromptOverlayButtonWrapper>
|
||||
</Box>
|
||||
</PromptPopover>
|
||||
);
|
||||
});
|
||||
|
||||
RPLayerPositivePrompt.displayName = 'RPLayerPositivePrompt';
|
@ -0,0 +1,38 @@
|
||||
import { IconButton, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import {
|
||||
maskLayerNegativePromptChanged,
|
||||
maskLayerPositivePromptChanged,
|
||||
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||
|
||||
type Props = {
|
||||
layerId: string;
|
||||
polarity: 'positive' | 'negative';
|
||||
};
|
||||
|
||||
export const RPLayerPromptDeleteButton = memo(({ layerId, polarity }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const onClick = useCallback(() => {
|
||||
if (polarity === 'positive') {
|
||||
dispatch(maskLayerPositivePromptChanged({ layerId, prompt: null }));
|
||||
} else {
|
||||
dispatch(maskLayerNegativePromptChanged({ layerId, prompt: null }));
|
||||
}
|
||||
}, [dispatch, layerId, polarity]);
|
||||
return (
|
||||
<Tooltip label={t('regionalPrompts.deletePrompt')}>
|
||||
<IconButton
|
||||
variant="promptOverlay"
|
||||
aria-label={t('regionalPrompts.deletePrompt')}
|
||||
icon={<PiTrashSimpleBold />}
|
||||
onClick={onClick}
|
||||
/>
|
||||
</Tooltip>
|
||||
);
|
||||
});
|
||||
|
||||
RPLayerPromptDeleteButton.displayName = 'RPLayerPromptDeleteButton';
|
@ -0,0 +1,53 @@
|
||||
import type { FormLabelProps } from '@invoke-ai/ui-library';
|
||||
import {
|
||||
Flex,
|
||||
FormControlGroup,
|
||||
IconButton,
|
||||
Popover,
|
||||
PopoverArrow,
|
||||
PopoverBody,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { RPLayerAutoNegativeCheckbox } from 'features/regionalPrompts/components/RPLayerAutoNegativeCheckbox';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiGearSixBold } from 'react-icons/pi';
|
||||
|
||||
type Props = {
|
||||
layerId: string;
|
||||
};
|
||||
|
||||
const formLabelProps: FormLabelProps = {
|
||||
flexGrow: 1,
|
||||
minW: 32,
|
||||
};
|
||||
|
||||
const RPLayerSettingsPopover = ({ layerId }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Popover isLazy>
|
||||
<PopoverTrigger>
|
||||
<IconButton
|
||||
tooltip={t('common.settingsLabel')}
|
||||
aria-label={t('common.settingsLabel')}
|
||||
size="sm"
|
||||
icon={<PiGearSixBold />}
|
||||
/>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent>
|
||||
<PopoverArrow />
|
||||
<PopoverBody>
|
||||
<Flex direction="column" gap={2}>
|
||||
<FormControlGroup formLabelProps={formLabelProps}>
|
||||
<RPLayerAutoNegativeCheckbox layerId={layerId} />
|
||||
</FormControlGroup>
|
||||
</Flex>
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(RPLayerSettingsPopover);
|
@ -0,0 +1,34 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useLayerIsVisible } from 'features/regionalPrompts/hooks/layerStateHooks';
|
||||
import { layerVisibilityToggled } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCheckBold } from 'react-icons/pi';
|
||||
|
||||
type Props = {
|
||||
layerId: string;
|
||||
};
|
||||
|
||||
export const RPLayerVisibilityToggle = memo(({ layerId }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const isVisible = useLayerIsVisible(layerId);
|
||||
const onClick = useCallback(() => {
|
||||
dispatch(layerVisibilityToggled(layerId));
|
||||
}, [dispatch, layerId]);
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
size="sm"
|
||||
aria-label={t('regionalPrompts.toggleVisibility')}
|
||||
tooltip={t('regionalPrompts.toggleVisibility')}
|
||||
variant="outline"
|
||||
icon={isVisible ? <PiCheckBold /> : undefined}
|
||||
onClick={onClick}
|
||||
colorScheme="base"
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
RPLayerVisibilityToggle.displayName = 'RPLayerVisibilityToggle';
|
@ -0,0 +1,24 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import type { Meta, StoryObj } from '@storybook/react';
|
||||
import { RegionalPromptsEditor } from 'features/regionalPrompts/components/RegionalPromptsEditor';
|
||||
|
||||
const meta: Meta<typeof RegionalPromptsEditor> = {
|
||||
title: 'Feature/RegionalPrompts',
|
||||
tags: ['autodocs'],
|
||||
component: RegionalPromptsEditor,
|
||||
};
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof RegionalPromptsEditor>;
|
||||
|
||||
const Component = () => {
|
||||
return (
|
||||
<Flex w={1500} h={1500}>
|
||||
<RegionalPromptsEditor />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export const Default: Story = {
|
||||
render: Component,
|
||||
};
|
@ -0,0 +1,24 @@
|
||||
/* eslint-disable i18next/no-literal-string */
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { RegionalPromptsToolbar } from 'features/regionalPrompts/components/RegionalPromptsToolbar';
|
||||
import { StageComponent } from 'features/regionalPrompts/components/StageComponent';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const RegionalPromptsEditor = memo(() => {
|
||||
return (
|
||||
<Flex
|
||||
position="relative"
|
||||
flexDirection="column"
|
||||
height="100%"
|
||||
width="100%"
|
||||
rowGap={4}
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
>
|
||||
<RegionalPromptsToolbar />
|
||||
<StageComponent />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
RegionalPromptsEditor.displayName = 'RegionalPromptsEditor';
|
@ -0,0 +1,38 @@
|
||||
/* eslint-disable i18next/no-literal-string */
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { AddLayerButton } from 'features/regionalPrompts/components/AddLayerButton';
|
||||
import { DeleteAllLayersButton } from 'features/regionalPrompts/components/DeleteAllLayersButton';
|
||||
import { RPLayerListItem } from 'features/regionalPrompts/components/RPLayerListItem';
|
||||
import { isVectorMaskLayer, selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { memo } from 'react';
|
||||
|
||||
const selectRPLayerIdsReversed = createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) =>
|
||||
regionalPrompts.present.layers
|
||||
.filter(isVectorMaskLayer)
|
||||
.map((l) => l.id)
|
||||
.reverse()
|
||||
);
|
||||
|
||||
export const RegionalPromptsPanelContent = memo(() => {
|
||||
const rpLayerIdsReversed = useAppSelector(selectRPLayerIdsReversed);
|
||||
return (
|
||||
<Flex flexDir="column" gap={4} w="full" h="full">
|
||||
<Flex justifyContent="space-around">
|
||||
<AddLayerButton />
|
||||
<DeleteAllLayersButton />
|
||||
</Flex>
|
||||
<ScrollableContent>
|
||||
<Flex flexDir="column" gap={4}>
|
||||
{rpLayerIdsReversed.map((id) => (
|
||||
<RPLayerListItem key={id} layerId={id} />
|
||||
))}
|
||||
</Flex>
|
||||
</ScrollableContent>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
RegionalPromptsPanelContent.displayName = 'RegionalPromptsPanelContent';
|
@ -0,0 +1,20 @@
|
||||
/* eslint-disable i18next/no-literal-string */
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { BrushSize } from 'features/regionalPrompts/components/BrushSize';
|
||||
import { GlobalMaskLayerOpacity } from 'features/regionalPrompts/components/GlobalMaskLayerOpacity';
|
||||
import { ToolChooser } from 'features/regionalPrompts/components/ToolChooser';
|
||||
import { UndoRedoButtonGroup } from 'features/regionalPrompts/components/UndoRedoButtonGroup';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const RegionalPromptsToolbar = memo(() => {
|
||||
return (
|
||||
<Flex gap={4}>
|
||||
<BrushSize />
|
||||
<GlobalMaskLayerOpacity />
|
||||
<UndoRedoButtonGroup />
|
||||
<ToolChooser />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
RegionalPromptsToolbar.displayName = 'RegionalPromptsToolbar';
|
@ -0,0 +1,232 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useMouseEvents } from 'features/regionalPrompts/hooks/mouseEventHooks';
|
||||
import {
|
||||
$cursorPosition,
|
||||
$isMouseOver,
|
||||
$lastMouseDownPos,
|
||||
$tool,
|
||||
isVectorMaskLayer,
|
||||
layerBboxChanged,
|
||||
layerSelected,
|
||||
layerTranslated,
|
||||
selectRegionalPromptsSlice,
|
||||
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { debouncedRenderers, renderers as normalRenderers } from 'features/regionalPrompts/util/renderers';
|
||||
import Konva from 'konva';
|
||||
import type { IRect } from 'konva/lib/types';
|
||||
import type { MutableRefObject } from 'react';
|
||||
import { memo, useCallback, useLayoutEffect, useMemo, useRef, useState } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
// This will log warnings when layers > 5 - maybe use `import.meta.env.MODE === 'development'` instead?
|
||||
Konva.showWarnings = false;
|
||||
|
||||
const log = logger('regionalPrompts');
|
||||
|
||||
const selectSelectedLayerColor = createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
|
||||
const layer = regionalPrompts.present.layers.find((l) => l.id === regionalPrompts.present.selectedLayerId);
|
||||
if (!layer) {
|
||||
return null;
|
||||
}
|
||||
assert(isVectorMaskLayer(layer), `Layer ${regionalPrompts.present.selectedLayerId} is not an RP layer`);
|
||||
return layer.previewColor;
|
||||
});
|
||||
|
||||
const useStageRenderer = (
|
||||
stageRef: MutableRefObject<Konva.Stage>,
|
||||
container: HTMLDivElement | null,
|
||||
wrapper: HTMLDivElement | null,
|
||||
asPreview: boolean
|
||||
) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const width = useAppSelector((s) => s.generation.width);
|
||||
const height = useAppSelector((s) => s.generation.height);
|
||||
const state = useAppSelector((s) => s.regionalPrompts.present);
|
||||
const tool = useStore($tool);
|
||||
const { onMouseDown, onMouseUp, onMouseMove, onMouseEnter, onMouseLeave, onMouseWheel } = useMouseEvents();
|
||||
const cursorPosition = useStore($cursorPosition);
|
||||
const lastMouseDownPos = useStore($lastMouseDownPos);
|
||||
const isMouseOver = useStore($isMouseOver);
|
||||
const selectedLayerIdColor = useAppSelector(selectSelectedLayerColor);
|
||||
const layerIds = useMemo(() => state.layers.map((l) => l.id), [state.layers]);
|
||||
const renderers = useMemo(() => (asPreview ? debouncedRenderers : normalRenderers), [asPreview]);
|
||||
|
||||
const onLayerPosChanged = useCallback(
|
||||
(layerId: string, x: number, y: number) => {
|
||||
dispatch(layerTranslated({ layerId, x, y }));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onBboxChanged = useCallback(
|
||||
(layerId: string, bbox: IRect | null) => {
|
||||
dispatch(layerBboxChanged({ layerId, bbox }));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onBboxMouseDown = useCallback(
|
||||
(layerId: string) => {
|
||||
dispatch(layerSelected(layerId));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
log.trace('Initializing stage');
|
||||
if (!container) {
|
||||
return;
|
||||
}
|
||||
const stage = stageRef.current.container(container);
|
||||
return () => {
|
||||
log.trace('Cleaning up stage');
|
||||
stage.destroy();
|
||||
};
|
||||
}, [container, stageRef]);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
log.trace('Adding stage listeners');
|
||||
if (asPreview) {
|
||||
return;
|
||||
}
|
||||
stageRef.current.on('mousedown', onMouseDown);
|
||||
stageRef.current.on('mouseup', onMouseUp);
|
||||
stageRef.current.on('mousemove', onMouseMove);
|
||||
stageRef.current.on('mouseenter', onMouseEnter);
|
||||
stageRef.current.on('mouseleave', onMouseLeave);
|
||||
stageRef.current.on('wheel', onMouseWheel);
|
||||
const stage = stageRef.current;
|
||||
|
||||
return () => {
|
||||
log.trace('Cleaning up stage listeners');
|
||||
stage.off('mousedown', onMouseDown);
|
||||
stage.off('mouseup', onMouseUp);
|
||||
stage.off('mousemove', onMouseMove);
|
||||
stage.off('mouseenter', onMouseEnter);
|
||||
stage.off('mouseleave', onMouseLeave);
|
||||
stage.off('wheel', onMouseWheel);
|
||||
};
|
||||
}, [stageRef, asPreview, onMouseDown, onMouseUp, onMouseMove, onMouseEnter, onMouseLeave, onMouseWheel]);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
log.trace('Updating stage dimensions');
|
||||
if (!wrapper) {
|
||||
return;
|
||||
}
|
||||
|
||||
const stage = stageRef.current;
|
||||
|
||||
const fitStageToContainer = () => {
|
||||
const newXScale = wrapper.offsetWidth / width;
|
||||
const newYScale = wrapper.offsetHeight / height;
|
||||
const newScale = Math.min(newXScale, newYScale, 1);
|
||||
stage.width(width * newScale);
|
||||
stage.height(height * newScale);
|
||||
stage.scaleX(newScale);
|
||||
stage.scaleY(newScale);
|
||||
};
|
||||
|
||||
const resizeObserver = new ResizeObserver(fitStageToContainer);
|
||||
resizeObserver.observe(wrapper);
|
||||
fitStageToContainer();
|
||||
|
||||
return () => {
|
||||
resizeObserver.disconnect();
|
||||
};
|
||||
}, [stageRef, width, height, wrapper]);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
log.trace('Rendering tool preview');
|
||||
if (asPreview) {
|
||||
// Preview should not display tool
|
||||
return;
|
||||
}
|
||||
renderers.renderToolPreview(
|
||||
stageRef.current,
|
||||
tool,
|
||||
selectedLayerIdColor,
|
||||
state.globalMaskLayerOpacity,
|
||||
cursorPosition,
|
||||
lastMouseDownPos,
|
||||
isMouseOver,
|
||||
state.brushSize
|
||||
);
|
||||
}, [
|
||||
asPreview,
|
||||
stageRef,
|
||||
tool,
|
||||
selectedLayerIdColor,
|
||||
state.globalMaskLayerOpacity,
|
||||
cursorPosition,
|
||||
lastMouseDownPos,
|
||||
isMouseOver,
|
||||
state.brushSize,
|
||||
renderers,
|
||||
]);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
log.trace('Rendering layers');
|
||||
renderers.renderLayers(stageRef.current, state.layers, state.globalMaskLayerOpacity, tool, onLayerPosChanged);
|
||||
}, [stageRef, state.layers, state.globalMaskLayerOpacity, tool, onLayerPosChanged, renderers]);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
log.trace('Rendering bbox');
|
||||
if (asPreview) {
|
||||
// Preview should not display bboxes
|
||||
return;
|
||||
}
|
||||
renderers.renderBbox(stageRef.current, state.layers, state.selectedLayerId, tool, onBboxChanged, onBboxMouseDown);
|
||||
}, [stageRef, asPreview, state.layers, state.selectedLayerId, tool, onBboxChanged, onBboxMouseDown, renderers]);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
log.trace('Rendering background');
|
||||
if (asPreview) {
|
||||
// The preview should not have a background
|
||||
return;
|
||||
}
|
||||
renderers.renderBackground(stageRef.current, width, height);
|
||||
}, [stageRef, asPreview, width, height, renderers]);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
log.trace('Arranging layers');
|
||||
renderers.arrangeLayers(stageRef.current, layerIds);
|
||||
}, [stageRef, layerIds, renderers]);
|
||||
};
|
||||
|
||||
type Props = {
|
||||
asPreview?: boolean;
|
||||
};
|
||||
|
||||
export const StageComponent = memo(({ asPreview = false }: Props) => {
|
||||
const stageRef = useRef<Konva.Stage>(
|
||||
new Konva.Stage({
|
||||
container: document.createElement('div'), // We will overwrite this shortly...
|
||||
})
|
||||
);
|
||||
const [container, setContainer] = useState<HTMLDivElement | null>(null);
|
||||
const [wrapper, setWrapper] = useState<HTMLDivElement | null>(null);
|
||||
|
||||
const containerRef = useCallback((el: HTMLDivElement | null) => {
|
||||
setContainer(el);
|
||||
}, []);
|
||||
|
||||
const wrapperRef = useCallback((el: HTMLDivElement | null) => {
|
||||
setWrapper(el);
|
||||
}, []);
|
||||
|
||||
useStageRenderer(stageRef, container, wrapper, asPreview);
|
||||
|
||||
return (
|
||||
<Flex overflow="hidden" w="full" h="full">
|
||||
<Flex ref={wrapperRef} w="full" h="full" alignItems="center" justifyContent="center">
|
||||
<Flex ref={containerRef} tabIndex={-1} bg="base.850" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
StageComponent.displayName = 'StageComponent';
|
@ -0,0 +1,89 @@
|
||||
import { ButtonGroup, IconButton } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
$tool,
|
||||
layerAdded,
|
||||
selectedLayerDeleted,
|
||||
selectedLayerReset,
|
||||
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { useCallback } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowsOutCardinalBold, PiEraserBold, PiPaintBrushBold, PiRectangleBold } from 'react-icons/pi';
|
||||
|
||||
export const ToolChooser: React.FC = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const isDisabled = useAppSelector((s) => s.regionalPrompts.present.layers.length === 0);
|
||||
const tool = useStore($tool);
|
||||
|
||||
const setToolToBrush = useCallback(() => {
|
||||
$tool.set('brush');
|
||||
}, []);
|
||||
useHotkeys('b', setToolToBrush, { enabled: !isDisabled }, [isDisabled]);
|
||||
const setToolToEraser = useCallback(() => {
|
||||
$tool.set('eraser');
|
||||
}, []);
|
||||
useHotkeys('e', setToolToEraser, { enabled: !isDisabled }, [isDisabled]);
|
||||
const setToolToRect = useCallback(() => {
|
||||
$tool.set('rect');
|
||||
}, []);
|
||||
useHotkeys('u', setToolToRect, { enabled: !isDisabled }, [isDisabled]);
|
||||
const setToolToMove = useCallback(() => {
|
||||
$tool.set('move');
|
||||
}, []);
|
||||
useHotkeys('v', setToolToMove, { enabled: !isDisabled }, [isDisabled]);
|
||||
|
||||
const resetSelectedLayer = useCallback(() => {
|
||||
dispatch(selectedLayerReset());
|
||||
}, [dispatch]);
|
||||
useHotkeys('shift+c', resetSelectedLayer);
|
||||
|
||||
const addLayer = useCallback(() => {
|
||||
dispatch(layerAdded('vector_mask_layer'));
|
||||
}, [dispatch]);
|
||||
useHotkeys('shift+a', addLayer);
|
||||
|
||||
const deleteSelectedLayer = useCallback(() => {
|
||||
dispatch(selectedLayerDeleted());
|
||||
}, [dispatch]);
|
||||
useHotkeys('shift+d', deleteSelectedLayer);
|
||||
|
||||
return (
|
||||
<ButtonGroup isAttached>
|
||||
<IconButton
|
||||
aria-label={`${t('unifiedCanvas.brush')} (B)`}
|
||||
tooltip={`${t('unifiedCanvas.brush')} (B)`}
|
||||
icon={<PiPaintBrushBold />}
|
||||
variant={tool === 'brush' ? 'solid' : 'outline'}
|
||||
onClick={setToolToBrush}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
<IconButton
|
||||
aria-label={`${t('unifiedCanvas.eraser')} (E)`}
|
||||
tooltip={`${t('unifiedCanvas.eraser')} (E)`}
|
||||
icon={<PiEraserBold />}
|
||||
variant={tool === 'eraser' ? 'solid' : 'outline'}
|
||||
onClick={setToolToEraser}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
<IconButton
|
||||
aria-label={`${t('regionalPrompts.rectangle')} (U)`}
|
||||
tooltip={`${t('regionalPrompts.rectangle')} (U)`}
|
||||
icon={<PiRectangleBold />}
|
||||
variant={tool === 'rect' ? 'solid' : 'outline'}
|
||||
onClick={setToolToRect}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
<IconButton
|
||||
aria-label={`${t('unifiedCanvas.move')} (V)`}
|
||||
tooltip={`${t('unifiedCanvas.move')} (V)`}
|
||||
icon={<PiArrowsOutCardinalBold />}
|
||||
variant={tool === 'move' ? 'solid' : 'outline'}
|
||||
onClick={setToolToMove}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
</ButtonGroup>
|
||||
);
|
||||
};
|
@ -0,0 +1,49 @@
|
||||
/* eslint-disable i18next/no-literal-string */
|
||||
import { ButtonGroup, IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { redo, undo } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowClockwiseBold, PiArrowCounterClockwiseBold } from 'react-icons/pi';
|
||||
|
||||
export const UndoRedoButtonGroup = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const mayUndo = useAppSelector((s) => s.regionalPrompts.past.length > 0);
|
||||
const handleUndo = useCallback(() => {
|
||||
dispatch(undo());
|
||||
}, [dispatch]);
|
||||
useHotkeys(['meta+z', 'ctrl+z'], handleUndo, { enabled: mayUndo, preventDefault: true }, [mayUndo, handleUndo]);
|
||||
|
||||
const mayRedo = useAppSelector((s) => s.regionalPrompts.future.length > 0);
|
||||
const handleRedo = useCallback(() => {
|
||||
dispatch(redo());
|
||||
}, [dispatch]);
|
||||
useHotkeys(['meta+shift+z', 'ctrl+shift+z'], handleRedo, { enabled: mayRedo, preventDefault: true }, [
|
||||
mayRedo,
|
||||
handleRedo,
|
||||
]);
|
||||
|
||||
return (
|
||||
<ButtonGroup>
|
||||
<IconButton
|
||||
aria-label={t('unifiedCanvas.undo')}
|
||||
tooltip={t('unifiedCanvas.undo')}
|
||||
onClick={handleUndo}
|
||||
icon={<PiArrowCounterClockwiseBold />}
|
||||
isDisabled={!mayUndo}
|
||||
/>
|
||||
<IconButton
|
||||
aria-label={t('unifiedCanvas.redo')}
|
||||
tooltip={t('unifiedCanvas.redo')}
|
||||
onClick={handleRedo}
|
||||
icon={<PiArrowClockwiseBold />}
|
||||
isDisabled={!mayRedo}
|
||||
/>
|
||||
</ButtonGroup>
|
||||
);
|
||||
});
|
||||
|
||||
UndoRedoButtonGroup.displayName = 'UndoRedoButtonGroup';
|
@ -0,0 +1,49 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { isVectorMaskLayer, selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { useMemo } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const useLayerPositivePrompt = (layerId: string) => {
|
||||
const selectLayer = useMemo(
|
||||
() =>
|
||||
createSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
|
||||
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
|
||||
assert(isVectorMaskLayer(layer), `Layer ${layerId} not found or not an RP layer`);
|
||||
assert(layer.positivePrompt !== null, `Layer ${layerId} does not have a positive prompt`);
|
||||
return layer.positivePrompt;
|
||||
}),
|
||||
[layerId]
|
||||
);
|
||||
const prompt = useAppSelector(selectLayer);
|
||||
return prompt;
|
||||
};
|
||||
|
||||
export const useLayerNegativePrompt = (layerId: string) => {
|
||||
const selectLayer = useMemo(
|
||||
() =>
|
||||
createSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
|
||||
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
|
||||
assert(isVectorMaskLayer(layer), `Layer ${layerId} not found or not an RP layer`);
|
||||
assert(layer.negativePrompt !== null, `Layer ${layerId} does not have a negative prompt`);
|
||||
return layer.negativePrompt;
|
||||
}),
|
||||
[layerId]
|
||||
);
|
||||
const prompt = useAppSelector(selectLayer);
|
||||
return prompt;
|
||||
};
|
||||
|
||||
export const useLayerIsVisible = (layerId: string) => {
|
||||
const selectLayer = useMemo(
|
||||
() =>
|
||||
createSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
|
||||
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
|
||||
assert(isVectorMaskLayer(layer), `Layer ${layerId} not found or not an RP layer`);
|
||||
return layer.isVisible;
|
||||
}),
|
||||
[layerId]
|
||||
);
|
||||
const isVisible = useAppSelector(selectLayer);
|
||||
return isVisible;
|
||||
};
|
@ -0,0 +1,217 @@
|
||||
import { $ctrl, $meta } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { calculateNewBrushSize } from 'features/canvas/hooks/useCanvasZoom';
|
||||
import {
|
||||
$cursorPosition,
|
||||
$isMouseDown,
|
||||
$isMouseOver,
|
||||
$lastMouseDownPos,
|
||||
$tool,
|
||||
brushSizeChanged,
|
||||
maskLayerLineAdded,
|
||||
maskLayerPointsAdded,
|
||||
maskLayerRectAdded,
|
||||
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import type Konva from 'konva';
|
||||
import type { KonvaEventObject } from 'konva/lib/Node';
|
||||
import type { Vector2d } from 'konva/lib/types';
|
||||
import { useCallback, useRef } from 'react';
|
||||
|
||||
const getIsFocused = (stage: Konva.Stage) => {
|
||||
return stage.container().contains(document.activeElement);
|
||||
};
|
||||
|
||||
export const getScaledFlooredCursorPosition = (stage: Konva.Stage) => {
|
||||
const pointerPosition = stage.getPointerPosition();
|
||||
const stageTransform = stage.getAbsoluteTransform().copy();
|
||||
if (!pointerPosition || !stageTransform) {
|
||||
return;
|
||||
}
|
||||
const scaledCursorPosition = stageTransform.invert().point(pointerPosition);
|
||||
return {
|
||||
x: Math.floor(scaledCursorPosition.x),
|
||||
y: Math.floor(scaledCursorPosition.y),
|
||||
};
|
||||
};
|
||||
|
||||
const syncCursorPos = (stage: Konva.Stage): Vector2d | null => {
|
||||
const pos = getScaledFlooredCursorPosition(stage);
|
||||
if (!pos) {
|
||||
return null;
|
||||
}
|
||||
$cursorPosition.set(pos);
|
||||
return pos;
|
||||
};
|
||||
|
||||
const BRUSH_SPACING = 20;
|
||||
|
||||
export const useMouseEvents = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const selectedLayerId = useAppSelector((s) => s.regionalPrompts.present.selectedLayerId);
|
||||
const tool = useStore($tool);
|
||||
const lastCursorPosRef = useRef<[number, number] | null>(null);
|
||||
const shouldInvertBrushSizeScrollDirection = useAppSelector((s) => s.canvas.shouldInvertBrushSizeScrollDirection);
|
||||
const brushSize = useAppSelector((s) => s.regionalPrompts.present.brushSize);
|
||||
|
||||
const onMouseDown = useCallback(
|
||||
(e: KonvaEventObject<MouseEvent | TouchEvent>) => {
|
||||
const stage = e.target.getStage();
|
||||
if (!stage) {
|
||||
return;
|
||||
}
|
||||
const pos = syncCursorPos(stage);
|
||||
if (!pos) {
|
||||
return;
|
||||
}
|
||||
$isMouseDown.set(true);
|
||||
$lastMouseDownPos.set(pos);
|
||||
if (!selectedLayerId) {
|
||||
return;
|
||||
}
|
||||
if (tool === 'brush' || tool === 'eraser') {
|
||||
dispatch(
|
||||
maskLayerLineAdded({
|
||||
layerId: selectedLayerId,
|
||||
points: [pos.x, pos.y, pos.x, pos.y],
|
||||
tool,
|
||||
})
|
||||
);
|
||||
}
|
||||
},
|
||||
[dispatch, selectedLayerId, tool]
|
||||
);
|
||||
|
||||
const onMouseUp = useCallback(
|
||||
(e: KonvaEventObject<MouseEvent | TouchEvent>) => {
|
||||
const stage = e.target.getStage();
|
||||
if (!stage) {
|
||||
return;
|
||||
}
|
||||
$isMouseDown.set(false);
|
||||
const pos = $cursorPosition.get();
|
||||
const lastPos = $lastMouseDownPos.get();
|
||||
const tool = $tool.get();
|
||||
if (pos && lastPos && selectedLayerId && tool === 'rect') {
|
||||
dispatch(
|
||||
maskLayerRectAdded({
|
||||
layerId: selectedLayerId,
|
||||
rect: {
|
||||
x: Math.min(pos.x, lastPos.x),
|
||||
y: Math.min(pos.y, lastPos.y),
|
||||
width: Math.abs(pos.x - lastPos.x),
|
||||
height: Math.abs(pos.y - lastPos.y),
|
||||
},
|
||||
})
|
||||
);
|
||||
}
|
||||
$lastMouseDownPos.set(null);
|
||||
},
|
||||
[dispatch, selectedLayerId]
|
||||
);
|
||||
|
||||
const onMouseMove = useCallback(
|
||||
(e: KonvaEventObject<MouseEvent | TouchEvent>) => {
|
||||
const stage = e.target.getStage();
|
||||
if (!stage) {
|
||||
return;
|
||||
}
|
||||
const pos = syncCursorPos(stage);
|
||||
if (!pos || !selectedLayerId) {
|
||||
return;
|
||||
}
|
||||
if (getIsFocused(stage) && $isMouseOver.get() && $isMouseDown.get() && (tool === 'brush' || tool === 'eraser')) {
|
||||
if (lastCursorPosRef.current) {
|
||||
// Dispatching redux events impacts perf substantially - using brush spacing keeps dispatches to a reasonable number
|
||||
if (Math.hypot(lastCursorPosRef.current[0] - pos.x, lastCursorPosRef.current[1] - pos.y) < BRUSH_SPACING) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
lastCursorPosRef.current = [pos.x, pos.y];
|
||||
dispatch(maskLayerPointsAdded({ layerId: selectedLayerId, point: lastCursorPosRef.current }));
|
||||
}
|
||||
},
|
||||
[dispatch, selectedLayerId, tool]
|
||||
);
|
||||
|
||||
const onMouseLeave = useCallback(
|
||||
(e: KonvaEventObject<MouseEvent | TouchEvent>) => {
|
||||
const stage = e.target.getStage();
|
||||
if (!stage) {
|
||||
return;
|
||||
}
|
||||
const pos = syncCursorPos(stage);
|
||||
if (
|
||||
pos &&
|
||||
selectedLayerId &&
|
||||
getIsFocused(stage) &&
|
||||
$isMouseOver.get() &&
|
||||
$isMouseDown.get() &&
|
||||
(tool === 'brush' || tool === 'eraser')
|
||||
) {
|
||||
dispatch(maskLayerPointsAdded({ layerId: selectedLayerId, point: [pos.x, pos.y] }));
|
||||
}
|
||||
$isMouseOver.set(false);
|
||||
$isMouseDown.set(false);
|
||||
$cursorPosition.set(null);
|
||||
},
|
||||
[selectedLayerId, tool, dispatch]
|
||||
);
|
||||
|
||||
const onMouseEnter = useCallback(
|
||||
(e: KonvaEventObject<MouseEvent>) => {
|
||||
const stage = e.target.getStage();
|
||||
if (!stage) {
|
||||
return;
|
||||
}
|
||||
$isMouseOver.set(true);
|
||||
const pos = syncCursorPos(stage);
|
||||
if (!pos) {
|
||||
return;
|
||||
}
|
||||
if (!getIsFocused(stage)) {
|
||||
return;
|
||||
}
|
||||
if (e.evt.buttons !== 1) {
|
||||
$isMouseDown.set(false);
|
||||
} else {
|
||||
$isMouseDown.set(true);
|
||||
if (!selectedLayerId) {
|
||||
return;
|
||||
}
|
||||
if (tool === 'brush' || tool === 'eraser') {
|
||||
dispatch(
|
||||
maskLayerLineAdded({
|
||||
layerId: selectedLayerId,
|
||||
points: [pos.x, pos.y, pos.x, pos.y],
|
||||
tool,
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
},
|
||||
[dispatch, selectedLayerId, tool]
|
||||
);
|
||||
|
||||
const onMouseWheel = useCallback(
|
||||
(e: KonvaEventObject<WheelEvent>) => {
|
||||
e.evt.preventDefault();
|
||||
|
||||
// checking for ctrl key is pressed or not,
|
||||
// so that brush size can be controlled using ctrl + scroll up/down
|
||||
|
||||
// Invert the delta if the property is set to true
|
||||
let delta = e.evt.deltaY;
|
||||
if (shouldInvertBrushSizeScrollDirection) {
|
||||
delta = -delta;
|
||||
}
|
||||
|
||||
if ($ctrl.get() || $meta.get()) {
|
||||
dispatch(brushSizeChanged(calculateNewBrushSize(brushSize, delta)));
|
||||
}
|
||||
},
|
||||
[shouldInvertBrushSizeScrollDirection, brushSize, dispatch]
|
||||
);
|
||||
|
||||
return { onMouseDown, onMouseUp, onMouseMove, onMouseEnter, onMouseLeave, onMouseWheel };
|
||||
};
|
@ -0,0 +1,30 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selectValidLayerCount = createSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
|
||||
if (!regionalPrompts.present.isEnabled) {
|
||||
return 0;
|
||||
}
|
||||
const validLayers = regionalPrompts.present.layers
|
||||
.filter((l) => l.isVisible)
|
||||
.filter((l) => {
|
||||
const hasTextPrompt = Boolean(l.positivePrompt || l.negativePrompt);
|
||||
const hasAtLeastOneImagePrompt = l.ipAdapterIds.length > 0;
|
||||
return hasTextPrompt || hasAtLeastOneImagePrompt;
|
||||
});
|
||||
|
||||
return validLayers.length;
|
||||
});
|
||||
|
||||
export const useRegionalControlTitle = () => {
|
||||
const { t } = useTranslation();
|
||||
const validLayerCount = useAppSelector(selectValidLayerCount);
|
||||
const title = useMemo(() => {
|
||||
const suffix = validLayerCount > 0 ? ` (${validLayerCount})` : '';
|
||||
return `${t('regionalPrompts.regionalControl')}${suffix}`;
|
||||
}, [t, validLayerCount]);
|
||||
return title;
|
||||
};
|
@ -0,0 +1,496 @@
|
||||
import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit';
|
||||
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { moveBackward, moveForward, moveToBack, moveToFront } from 'common/util/arrayUtils';
|
||||
import { controlAdapterRemoved } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas';
|
||||
import type { IRect, Vector2d } from 'konva/lib/types';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { atom } from 'nanostores';
|
||||
import type { RgbColor } from 'react-colorful';
|
||||
import type { UndoableOptions } from 'redux-undo';
|
||||
import { assert } from 'tsafe';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
type DrawingTool = 'brush' | 'eraser';
|
||||
|
||||
export type Tool = DrawingTool | 'move' | 'rect';
|
||||
|
||||
export type VectorMaskLine = {
|
||||
id: string;
|
||||
type: 'vector_mask_line';
|
||||
tool: DrawingTool;
|
||||
strokeWidth: number;
|
||||
points: number[];
|
||||
};
|
||||
|
||||
export type VectorMaskRect = {
|
||||
id: string;
|
||||
type: 'vector_mask_rect';
|
||||
x: number;
|
||||
y: number;
|
||||
width: number;
|
||||
height: number;
|
||||
};
|
||||
|
||||
type LayerBase = {
|
||||
id: string;
|
||||
x: number;
|
||||
y: number;
|
||||
bbox: IRect | null;
|
||||
bboxNeedsUpdate: boolean;
|
||||
isVisible: boolean;
|
||||
};
|
||||
|
||||
type MaskLayerBase = LayerBase & {
|
||||
positivePrompt: string | null;
|
||||
negativePrompt: string | null; // Up to one text prompt per mask
|
||||
ipAdapterIds: string[]; // Any number of image prompts
|
||||
previewColor: RgbColor;
|
||||
autoNegative: ParameterAutoNegative;
|
||||
needsPixelBbox: boolean; // Needs the slower pixel-based bbox calculation - set to true when an there is an eraser object
|
||||
};
|
||||
|
||||
export type VectorMaskLayer = MaskLayerBase & {
|
||||
type: 'vector_mask_layer';
|
||||
objects: (VectorMaskLine | VectorMaskRect)[];
|
||||
};
|
||||
|
||||
export type Layer = VectorMaskLayer;
|
||||
|
||||
type RegionalPromptsState = {
|
||||
_version: 1;
|
||||
selectedLayerId: string | null;
|
||||
layers: Layer[];
|
||||
brushSize: number;
|
||||
globalMaskLayerOpacity: number;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
export const initialRegionalPromptsState: RegionalPromptsState = {
|
||||
_version: 1,
|
||||
selectedLayerId: null,
|
||||
brushSize: 100,
|
||||
layers: [],
|
||||
globalMaskLayerOpacity: 0.5, // this globally changes all mask layers' opacity
|
||||
isEnabled: true,
|
||||
};
|
||||
|
||||
const isLine = (obj: VectorMaskLine | VectorMaskRect): obj is VectorMaskLine => obj.type === 'vector_mask_line';
|
||||
export const isVectorMaskLayer = (layer?: Layer): layer is VectorMaskLayer => layer?.type === 'vector_mask_layer';
|
||||
const resetLayer = (layer: VectorMaskLayer) => {
|
||||
layer.objects = [];
|
||||
layer.bbox = null;
|
||||
layer.isVisible = true;
|
||||
layer.needsPixelBbox = false;
|
||||
layer.bboxNeedsUpdate = false;
|
||||
};
|
||||
|
||||
export const regionalPromptsSlice = createSlice({
|
||||
name: 'regionalPrompts',
|
||||
initialState: initialRegionalPromptsState,
|
||||
reducers: {
|
||||
//#region All Layers
|
||||
layerAdded: {
|
||||
reducer: (state, action: PayloadAction<Layer['type'], string, { uuid: string }>) => {
|
||||
const kind = action.payload;
|
||||
if (action.payload === 'vector_mask_layer') {
|
||||
const lastColor = state.layers[state.layers.length - 1]?.previewColor;
|
||||
const previewColor = LayerColors.next(lastColor);
|
||||
const layer: VectorMaskLayer = {
|
||||
id: getVectorMaskLayerId(action.meta.uuid),
|
||||
type: kind,
|
||||
isVisible: true,
|
||||
bbox: null,
|
||||
bboxNeedsUpdate: false,
|
||||
objects: [],
|
||||
previewColor,
|
||||
x: 0,
|
||||
y: 0,
|
||||
autoNegative: 'invert',
|
||||
needsPixelBbox: false,
|
||||
positivePrompt: '',
|
||||
negativePrompt: null,
|
||||
ipAdapterIds: [],
|
||||
};
|
||||
state.layers.push(layer);
|
||||
state.selectedLayerId = layer.id;
|
||||
return;
|
||||
}
|
||||
},
|
||||
prepare: (payload: Layer['type']) => ({ payload, meta: { uuid: uuidv4() } }),
|
||||
},
|
||||
layerSelected: (state, action: PayloadAction<string>) => {
|
||||
const layer = state.layers.find((l) => l.id === action.payload);
|
||||
if (layer) {
|
||||
state.selectedLayerId = layer.id;
|
||||
}
|
||||
},
|
||||
layerVisibilityToggled: (state, action: PayloadAction<string>) => {
|
||||
const layer = state.layers.find((l) => l.id === action.payload);
|
||||
if (layer) {
|
||||
layer.isVisible = !layer.isVisible;
|
||||
}
|
||||
},
|
||||
layerTranslated: (state, action: PayloadAction<{ layerId: string; x: number; y: number }>) => {
|
||||
const { layerId, x, y } = action.payload;
|
||||
const layer = state.layers.find((l) => l.id === layerId);
|
||||
if (layer) {
|
||||
layer.x = x;
|
||||
layer.y = y;
|
||||
}
|
||||
},
|
||||
layerBboxChanged: (state, action: PayloadAction<{ layerId: string; bbox: IRect | null }>) => {
|
||||
const { layerId, bbox } = action.payload;
|
||||
const layer = state.layers.find((l) => l.id === layerId);
|
||||
if (layer) {
|
||||
layer.bbox = bbox;
|
||||
layer.bboxNeedsUpdate = false;
|
||||
}
|
||||
},
|
||||
layerReset: (state, action: PayloadAction<string>) => {
|
||||
const layer = state.layers.find((l) => l.id === action.payload);
|
||||
if (layer) {
|
||||
resetLayer(layer);
|
||||
}
|
||||
},
|
||||
layerDeleted: (state, action: PayloadAction<string>) => {
|
||||
state.layers = state.layers.filter((l) => l.id !== action.payload);
|
||||
state.selectedLayerId = state.layers[0]?.id ?? null;
|
||||
},
|
||||
layerMovedForward: (state, action: PayloadAction<string>) => {
|
||||
const cb = (l: Layer) => l.id === action.payload;
|
||||
moveForward(state.layers, cb);
|
||||
},
|
||||
layerMovedToFront: (state, action: PayloadAction<string>) => {
|
||||
const cb = (l: Layer) => l.id === action.payload;
|
||||
// Because the layers are in reverse order, moving to the front is equivalent to moving to the back
|
||||
moveToBack(state.layers, cb);
|
||||
},
|
||||
layerMovedBackward: (state, action: PayloadAction<string>) => {
|
||||
const cb = (l: Layer) => l.id === action.payload;
|
||||
moveBackward(state.layers, cb);
|
||||
},
|
||||
layerMovedToBack: (state, action: PayloadAction<string>) => {
|
||||
const cb = (l: Layer) => l.id === action.payload;
|
||||
// Because the layers are in reverse order, moving to the back is equivalent to moving to the front
|
||||
moveToFront(state.layers, cb);
|
||||
},
|
||||
allLayersDeleted: (state) => {
|
||||
state.layers = [];
|
||||
state.selectedLayerId = null;
|
||||
},
|
||||
selectedLayerReset: (state) => {
|
||||
const layer = state.layers.find((l) => l.id === state.selectedLayerId);
|
||||
if (layer) {
|
||||
resetLayer(layer);
|
||||
}
|
||||
},
|
||||
selectedLayerDeleted: (state) => {
|
||||
state.layers = state.layers.filter((l) => l.id !== state.selectedLayerId);
|
||||
state.selectedLayerId = state.layers[0]?.id ?? null;
|
||||
},
|
||||
//#endregion
|
||||
|
||||
//#region Mask Layers
|
||||
maskLayerPositivePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => {
|
||||
const { layerId, prompt } = action.payload;
|
||||
const layer = state.layers.find((l) => l.id === layerId);
|
||||
if (layer) {
|
||||
layer.positivePrompt = prompt;
|
||||
}
|
||||
},
|
||||
maskLayerNegativePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => {
|
||||
const { layerId, prompt } = action.payload;
|
||||
const layer = state.layers.find((l) => l.id === layerId);
|
||||
if (layer) {
|
||||
layer.negativePrompt = prompt;
|
||||
}
|
||||
},
|
||||
maskLayerIPAdapterAdded: {
|
||||
reducer: (state, action: PayloadAction<string, string, { uuid: string }>) => {
|
||||
const layer = state.layers.find((l) => l.id === action.payload);
|
||||
if (layer) {
|
||||
layer.ipAdapterIds.push(action.meta.uuid);
|
||||
}
|
||||
},
|
||||
prepare: (payload: string) => ({ payload, meta: { uuid: uuidv4() } }),
|
||||
},
|
||||
maskLayerPreviewColorChanged: (state, action: PayloadAction<{ layerId: string; color: RgbColor }>) => {
|
||||
const { layerId, color } = action.payload;
|
||||
const layer = state.layers.find((l) => l.id === layerId);
|
||||
if (layer) {
|
||||
layer.previewColor = color;
|
||||
}
|
||||
},
|
||||
maskLayerLineAdded: {
|
||||
reducer: (
|
||||
state,
|
||||
action: PayloadAction<
|
||||
{ layerId: string; points: [number, number, number, number]; tool: DrawingTool },
|
||||
string,
|
||||
{ uuid: string }
|
||||
>
|
||||
) => {
|
||||
const { layerId, points, tool } = action.payload;
|
||||
const layer = state.layers.find((l) => l.id === layerId);
|
||||
if (layer) {
|
||||
const lineId = getVectorMaskLayerLineId(layer.id, action.meta.uuid);
|
||||
layer.objects.push({
|
||||
type: 'vector_mask_line',
|
||||
tool: tool,
|
||||
id: lineId,
|
||||
// Points must be offset by the layer's x and y coordinates
|
||||
// TODO: Handle this in the event listener?
|
||||
points: [points[0] - layer.x, points[1] - layer.y, points[2] - layer.x, points[3] - layer.y],
|
||||
strokeWidth: state.brushSize,
|
||||
});
|
||||
layer.bboxNeedsUpdate = true;
|
||||
if (!layer.needsPixelBbox && tool === 'eraser') {
|
||||
layer.needsPixelBbox = true;
|
||||
}
|
||||
}
|
||||
},
|
||||
prepare: (payload: { layerId: string; points: [number, number, number, number]; tool: DrawingTool }) => ({
|
||||
payload,
|
||||
meta: { uuid: uuidv4() },
|
||||
}),
|
||||
},
|
||||
maskLayerPointsAdded: (state, action: PayloadAction<{ layerId: string; point: [number, number] }>) => {
|
||||
const { layerId, point } = action.payload;
|
||||
const layer = state.layers.find((l) => l.id === layerId);
|
||||
if (layer) {
|
||||
const lastLine = layer.objects.findLast(isLine);
|
||||
if (!lastLine) {
|
||||
return;
|
||||
}
|
||||
// Points must be offset by the layer's x and y coordinates
|
||||
// TODO: Handle this in the event listener
|
||||
lastLine.points.push(point[0] - layer.x, point[1] - layer.y);
|
||||
layer.bboxNeedsUpdate = true;
|
||||
}
|
||||
},
|
||||
maskLayerRectAdded: {
|
||||
reducer: (state, action: PayloadAction<{ layerId: string; rect: IRect }, string, { uuid: string }>) => {
|
||||
const { layerId, rect } = action.payload;
|
||||
if (rect.height === 0 || rect.width === 0) {
|
||||
// Ignore zero-area rectangles
|
||||
return;
|
||||
}
|
||||
const layer = state.layers.find((l) => l.id === layerId);
|
||||
if (layer) {
|
||||
const id = getVectorMaskLayerRectId(layer.id, action.meta.uuid);
|
||||
layer.objects.push({
|
||||
type: 'vector_mask_rect',
|
||||
id,
|
||||
x: rect.x - layer.x,
|
||||
y: rect.y - layer.y,
|
||||
width: rect.width,
|
||||
height: rect.height,
|
||||
});
|
||||
layer.bboxNeedsUpdate = true;
|
||||
}
|
||||
},
|
||||
prepare: (payload: { layerId: string; rect: IRect }) => ({ payload, meta: { uuid: uuidv4() } }),
|
||||
},
|
||||
maskLayerAutoNegativeChanged: (
|
||||
state,
|
||||
action: PayloadAction<{ layerId: string; autoNegative: ParameterAutoNegative }>
|
||||
) => {
|
||||
const { layerId, autoNegative } = action.payload;
|
||||
const layer = state.layers.find((l) => l.id === layerId);
|
||||
if (layer) {
|
||||
layer.autoNegative = autoNegative;
|
||||
}
|
||||
},
|
||||
//#endregion
|
||||
|
||||
//#region General
|
||||
brushSizeChanged: (state, action: PayloadAction<number>) => {
|
||||
state.brushSize = action.payload;
|
||||
},
|
||||
globalMaskLayerOpacityChanged: (state, action: PayloadAction<number>) => {
|
||||
state.globalMaskLayerOpacity = action.payload;
|
||||
},
|
||||
isEnabledChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.isEnabled = action.payload;
|
||||
},
|
||||
undo: (state) => {
|
||||
// Invalidate the bbox for all layers to prevent stale bboxes
|
||||
for (const layer of state.layers) {
|
||||
layer.bboxNeedsUpdate = true;
|
||||
}
|
||||
},
|
||||
redo: (state) => {
|
||||
// Invalidate the bbox for all layers to prevent stale bboxes
|
||||
for (const layer of state.layers) {
|
||||
layer.bboxNeedsUpdate = true;
|
||||
}
|
||||
},
|
||||
//#endregion
|
||||
},
|
||||
extraReducers(builder) {
|
||||
builder.addCase(controlAdapterRemoved, (state, action) => {
|
||||
for (const layer of state.layers) {
|
||||
layer.ipAdapterIds = layer.ipAdapterIds.filter((id) => id !== action.payload.id);
|
||||
}
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
/**
|
||||
* This class is used to cycle through a set of colors for the prompt region layers.
|
||||
*/
|
||||
class LayerColors {
|
||||
static COLORS: RgbColor[] = [
|
||||
{ r: 121, g: 157, b: 219 }, // rgb(121, 157, 219)
|
||||
{ r: 131, g: 214, b: 131 }, // rgb(131, 214, 131)
|
||||
{ r: 250, g: 225, b: 80 }, // rgb(250, 225, 80)
|
||||
{ r: 220, g: 144, b: 101 }, // rgb(220, 144, 101)
|
||||
{ r: 224, g: 117, b: 117 }, // rgb(224, 117, 117)
|
||||
{ r: 213, g: 139, b: 202 }, // rgb(213, 139, 202)
|
||||
{ r: 161, g: 120, b: 214 }, // rgb(161, 120, 214)
|
||||
];
|
||||
static i = this.COLORS.length - 1;
|
||||
/**
|
||||
* Get the next color in the sequence. If a known color is provided, the next color will be the one after it.
|
||||
*/
|
||||
static next(currentColor?: RgbColor): RgbColor {
|
||||
if (currentColor) {
|
||||
const i = this.COLORS.findIndex((c) => isEqual(c, currentColor));
|
||||
if (i !== -1) {
|
||||
this.i = i;
|
||||
}
|
||||
}
|
||||
this.i = (this.i + 1) % this.COLORS.length;
|
||||
const color = this.COLORS[this.i];
|
||||
assert(color);
|
||||
return color;
|
||||
}
|
||||
}
|
||||
|
||||
export const {
|
||||
// All layer actions
|
||||
layerAdded,
|
||||
layerDeleted,
|
||||
layerMovedBackward,
|
||||
layerMovedForward,
|
||||
layerMovedToBack,
|
||||
layerMovedToFront,
|
||||
layerReset,
|
||||
layerSelected,
|
||||
layerTranslated,
|
||||
layerBboxChanged,
|
||||
layerVisibilityToggled,
|
||||
allLayersDeleted,
|
||||
selectedLayerReset,
|
||||
selectedLayerDeleted,
|
||||
// Mask layer actions
|
||||
maskLayerLineAdded,
|
||||
maskLayerPointsAdded,
|
||||
maskLayerRectAdded,
|
||||
maskLayerNegativePromptChanged,
|
||||
maskLayerPositivePromptChanged,
|
||||
maskLayerIPAdapterAdded,
|
||||
maskLayerAutoNegativeChanged,
|
||||
maskLayerPreviewColorChanged,
|
||||
// General actions
|
||||
brushSizeChanged,
|
||||
globalMaskLayerOpacityChanged,
|
||||
undo,
|
||||
redo,
|
||||
} = regionalPromptsSlice.actions;
|
||||
|
||||
export const selectRegionalPromptsSlice = (state: RootState) => state.regionalPrompts;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrateRegionalPromptsState = (state: any): any => {
|
||||
return state;
|
||||
};
|
||||
|
||||
export const $isMouseDown = atom(false);
|
||||
export const $isMouseOver = atom(false);
|
||||
export const $lastMouseDownPos = atom<Vector2d | null>(null);
|
||||
export const $tool = atom<Tool>('brush');
|
||||
export const $cursorPosition = atom<Vector2d | null>(null);
|
||||
|
||||
// IDs for singleton Konva layers and objects
|
||||
export const TOOL_PREVIEW_LAYER_ID = 'tool_preview_layer';
|
||||
export const TOOL_PREVIEW_BRUSH_GROUP_ID = 'tool_preview_layer.brush_group';
|
||||
export const TOOL_PREVIEW_BRUSH_FILL_ID = 'tool_preview_layer.brush_fill';
|
||||
export const TOOL_PREVIEW_BRUSH_BORDER_INNER_ID = 'tool_preview_layer.brush_border_inner';
|
||||
export const TOOL_PREVIEW_BRUSH_BORDER_OUTER_ID = 'tool_preview_layer.brush_border_outer';
|
||||
export const TOOL_PREVIEW_RECT_ID = 'tool_preview_layer.rect';
|
||||
export const BACKGROUND_LAYER_ID = 'background_layer';
|
||||
export const BACKGROUND_RECT_ID = 'background_layer.rect';
|
||||
|
||||
// Names (aka classes) for Konva layers and objects
|
||||
export const VECTOR_MASK_LAYER_NAME = 'vector_mask_layer';
|
||||
export const VECTOR_MASK_LAYER_LINE_NAME = 'vector_mask_layer.line';
|
||||
export const VECTOR_MASK_LAYER_OBJECT_GROUP_NAME = 'vector_mask_layer.object_group';
|
||||
export const VECTOR_MASK_LAYER_RECT_NAME = 'vector_mask_layer.rect';
|
||||
export const LAYER_BBOX_NAME = 'layer.bbox';
|
||||
|
||||
// Getters for non-singleton layer and object IDs
|
||||
const getVectorMaskLayerId = (layerId: string) => `${VECTOR_MASK_LAYER_NAME}_${layerId}`;
|
||||
const getVectorMaskLayerLineId = (layerId: string, lineId: string) => `${layerId}.line_${lineId}`;
|
||||
const getVectorMaskLayerRectId = (layerId: string, lineId: string) => `${layerId}.rect_${lineId}`;
|
||||
export const getVectorMaskLayerObjectGroupId = (layerId: string, groupId: string) =>
|
||||
`${layerId}.objectGroup_${groupId}`;
|
||||
export const getLayerBboxId = (layerId: string) => `${layerId}.bbox`;
|
||||
|
||||
export const regionalPromptsPersistConfig: PersistConfig<RegionalPromptsState> = {
|
||||
name: regionalPromptsSlice.name,
|
||||
initialState: initialRegionalPromptsState,
|
||||
migrate: migrateRegionalPromptsState,
|
||||
persistDenylist: [],
|
||||
};
|
||||
|
||||
// These actions are _individually_ grouped together as single undoable actions
|
||||
const undoableGroupByMatcher = isAnyOf(
|
||||
layerTranslated,
|
||||
brushSizeChanged,
|
||||
globalMaskLayerOpacityChanged,
|
||||
maskLayerPositivePromptChanged,
|
||||
maskLayerNegativePromptChanged,
|
||||
maskLayerPreviewColorChanged
|
||||
);
|
||||
|
||||
// These are used to group actions into logical lines below (hate typos)
|
||||
const LINE_1 = 'LINE_1';
|
||||
const LINE_2 = 'LINE_2';
|
||||
|
||||
export const regionalPromptsUndoableConfig: UndoableOptions<RegionalPromptsState, UnknownAction> = {
|
||||
limit: 64,
|
||||
undoType: regionalPromptsSlice.actions.undo.type,
|
||||
redoType: regionalPromptsSlice.actions.redo.type,
|
||||
groupBy: (action, state, history) => {
|
||||
// Lines are started with `maskLayerLineAdded` and may have any number of subsequent `maskLayerPointsAdded` events.
|
||||
// We can use a double-buffer-esque trick to group each "logical" line as a single undoable action, without grouping
|
||||
// separate logical lines as a single undo action.
|
||||
if (maskLayerLineAdded.match(action)) {
|
||||
return history.group === LINE_1 ? LINE_2 : LINE_1;
|
||||
}
|
||||
if (maskLayerPointsAdded.match(action)) {
|
||||
if (history.group === LINE_1 || history.group === LINE_2) {
|
||||
return history.group;
|
||||
}
|
||||
}
|
||||
if (undoableGroupByMatcher(action)) {
|
||||
return action.type;
|
||||
}
|
||||
return null;
|
||||
},
|
||||
filter: (action, _state, _history) => {
|
||||
// Ignore all actions from other slices
|
||||
if (!action.type.startsWith(regionalPromptsSlice.name)) {
|
||||
return false;
|
||||
}
|
||||
// This action is triggered on state changes, including when we undo. If we do not ignore this action, when we
|
||||
// undo, this action triggers and empties the future states array. Therefore, we must ignore this action.
|
||||
if (layerBboxChanged.match(action)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
};
|
134
invokeai/frontend/web/src/features/regionalPrompts/util/bbox.ts
Normal file
134
invokeai/frontend/web/src/features/regionalPrompts/util/bbox.ts
Normal file
@ -0,0 +1,134 @@
|
||||
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
||||
import { imageDataToDataURL } from 'features/canvas/util/blobToDataURL';
|
||||
import { VECTOR_MASK_LAYER_OBJECT_GROUP_NAME } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import Konva from 'konva';
|
||||
import type { Layer as KonvaLayerType } from 'konva/lib/Layer';
|
||||
import type { IRect } from 'konva/lib/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const GET_CLIENT_RECT_CONFIG = { skipTransform: true };
|
||||
|
||||
type Extents = {
|
||||
minX: number;
|
||||
minY: number;
|
||||
maxX: number;
|
||||
maxY: number;
|
||||
};
|
||||
|
||||
/**
|
||||
* Get the bounding box of an image.
|
||||
* @param imageData The ImageData object to get the bounding box of.
|
||||
* @returns The minimum and maximum x and y values of the image's bounding box.
|
||||
*/
|
||||
const getImageDataBbox = (imageData: ImageData): Extents | null => {
|
||||
const { data, width, height } = imageData;
|
||||
let minX = width;
|
||||
let minY = height;
|
||||
let maxX = -1;
|
||||
let maxY = -1;
|
||||
let alpha = 0;
|
||||
let isEmpty = true;
|
||||
|
||||
for (let y = 0; y < height; y++) {
|
||||
for (let x = 0; x < width; x++) {
|
||||
alpha = data[(y * width + x) * 4 + 3] ?? 0;
|
||||
if (alpha > 0) {
|
||||
isEmpty = false;
|
||||
if (x < minX) {
|
||||
minX = x;
|
||||
}
|
||||
if (x > maxX) {
|
||||
maxX = x;
|
||||
}
|
||||
if (y < minY) {
|
||||
minY = y;
|
||||
}
|
||||
if (y > maxY) {
|
||||
maxY = y;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return isEmpty ? null : { minX, minY, maxX, maxY };
|
||||
};
|
||||
|
||||
/**
|
||||
* Get the bounding box of a regional prompt konva layer. This function has special handling for regional prompt layers.
|
||||
* @param layer The konva layer to get the bounding box of.
|
||||
* @param preview Whether to open a new tab displaying the rendered layer, which is used to calculate the bbox.
|
||||
*/
|
||||
export const getLayerBboxPixels = (layer: KonvaLayerType, preview: boolean = false): IRect | null => {
|
||||
// To calculate the layer's bounding box, we must first export it to a pixel array, then do some math.
|
||||
//
|
||||
// Though it is relatively fast, we can't use Konva's `getClientRect`. It programmatically determines the rect
|
||||
// by calculating the extents of individual shapes from their "vector" shape data.
|
||||
//
|
||||
// This doesn't work when some shapes are drawn with composite operations that "erase" pixels, like eraser lines.
|
||||
// These shapes' extents are still calculated as if they were solid, leading to a bounding box that is too large.
|
||||
const stage = layer.getStage();
|
||||
|
||||
// Construct and offscreen canvas on which we will do the bbox calculations.
|
||||
const offscreenStageContainer = document.createElement('div');
|
||||
const offscreenStage = new Konva.Stage({
|
||||
container: offscreenStageContainer,
|
||||
width: stage.width(),
|
||||
height: stage.height(),
|
||||
});
|
||||
|
||||
// Clone the layer and filter out unwanted children.
|
||||
const layerClone = layer.clone();
|
||||
offscreenStage.add(layerClone);
|
||||
|
||||
for (const child of layerClone.getChildren()) {
|
||||
if (child.name() === VECTOR_MASK_LAYER_OBJECT_GROUP_NAME) {
|
||||
// We need to cache the group to ensure it composites out eraser strokes correctly
|
||||
child.opacity(1);
|
||||
child.cache();
|
||||
} else {
|
||||
// Filter out unwanted children.
|
||||
child.destroy();
|
||||
}
|
||||
}
|
||||
|
||||
// Get a worst-case rect using the relatively fast `getClientRect`.
|
||||
const layerRect = layerClone.getClientRect();
|
||||
|
||||
// Capture the image data with the above rect.
|
||||
const layerImageData = offscreenStage
|
||||
.toCanvas(layerRect)
|
||||
.getContext('2d')
|
||||
?.getImageData(0, 0, layerRect.width, layerRect.height);
|
||||
assert(layerImageData, "Unable to get layer's image data");
|
||||
|
||||
if (preview) {
|
||||
openBase64ImageInTab([{ base64: imageDataToDataURL(layerImageData), caption: layer.id() }]);
|
||||
}
|
||||
|
||||
// Calculate the layer's bounding box.
|
||||
const layerBbox = getImageDataBbox(layerImageData);
|
||||
|
||||
if (!layerBbox) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Correct the bounding box to be relative to the layer's position.
|
||||
const correctedLayerBbox = {
|
||||
x: layerBbox.minX - Math.floor(stage.x()) + layerRect.x - Math.floor(layer.x()),
|
||||
y: layerBbox.minY - Math.floor(stage.y()) + layerRect.y - Math.floor(layer.y()),
|
||||
width: layerBbox.maxX - layerBbox.minX,
|
||||
height: layerBbox.maxY - layerBbox.minY,
|
||||
};
|
||||
|
||||
return correctedLayerBbox;
|
||||
};
|
||||
|
||||
export const getLayerBboxFast = (layer: KonvaLayerType): IRect | null => {
|
||||
const bbox = layer.getClientRect(GET_CLIENT_RECT_CONFIG);
|
||||
return {
|
||||
x: Math.floor(bbox.x),
|
||||
y: Math.floor(bbox.y),
|
||||
width: Math.floor(bbox.width),
|
||||
height: Math.floor(bbox.height),
|
||||
};
|
||||
};
|
@ -0,0 +1,64 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
||||
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
|
||||
import { VECTOR_MASK_LAYER_NAME } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { renderers } from 'features/regionalPrompts/util/renderers';
|
||||
import Konva from 'konva';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
/**
|
||||
* Get the blobs of all regional prompt layers. Only visible layers are returned.
|
||||
* @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
|
||||
* @param preview Whether to open a new tab displaying each layer.
|
||||
* @returns A map of layer IDs to blobs.
|
||||
*/
|
||||
export const getRegionalPromptLayerBlobs = async (
|
||||
layerIds?: string[],
|
||||
preview: boolean = false
|
||||
): Promise<Record<string, Blob>> => {
|
||||
const state = getStore().getState();
|
||||
const reduxLayers = state.regionalPrompts.present.layers;
|
||||
const container = document.createElement('div');
|
||||
const stage = new Konva.Stage({ container, width: state.generation.width, height: state.generation.height });
|
||||
renderers.renderLayers(stage, reduxLayers, 1, 'brush');
|
||||
|
||||
const konvaLayers = stage.find<Konva.Layer>(`.${VECTOR_MASK_LAYER_NAME}`);
|
||||
const blobs: Record<string, Blob> = {};
|
||||
|
||||
// First remove all layers
|
||||
for (const layer of konvaLayers) {
|
||||
layer.remove();
|
||||
}
|
||||
|
||||
// Next render each layer to a blob
|
||||
for (const layer of konvaLayers) {
|
||||
if (layerIds && !layerIds.includes(layer.id())) {
|
||||
continue;
|
||||
}
|
||||
const reduxLayer = reduxLayers.find((l) => l.id === layer.id());
|
||||
assert(reduxLayer, `Redux layer ${layer.id()} not found`);
|
||||
stage.add(layer);
|
||||
const blob = await new Promise<Blob>((resolve) => {
|
||||
stage.toBlob({
|
||||
callback: (blob) => {
|
||||
assert(blob, 'Blob is null');
|
||||
resolve(blob);
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
if (preview) {
|
||||
const base64 = await blobToDataURL(blob);
|
||||
openBase64ImageInTab([
|
||||
{
|
||||
base64,
|
||||
caption: `${reduxLayer.id}: ${reduxLayer.positivePrompt} / ${reduxLayer.negativePrompt}`,
|
||||
},
|
||||
]);
|
||||
}
|
||||
layer.remove();
|
||||
blobs[layer.id()] = blob;
|
||||
}
|
||||
|
||||
return blobs;
|
||||
};
|
@ -0,0 +1,619 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import { rgbaColorToString, rgbColorToString } from 'features/canvas/util/colorToString';
|
||||
import { getScaledFlooredCursorPosition } from 'features/regionalPrompts/hooks/mouseEventHooks';
|
||||
import type {
|
||||
Layer,
|
||||
Tool,
|
||||
VectorMaskLayer,
|
||||
VectorMaskLine,
|
||||
VectorMaskRect,
|
||||
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import {
|
||||
$tool,
|
||||
BACKGROUND_LAYER_ID,
|
||||
BACKGROUND_RECT_ID,
|
||||
getLayerBboxId,
|
||||
getVectorMaskLayerObjectGroupId,
|
||||
isVectorMaskLayer,
|
||||
LAYER_BBOX_NAME,
|
||||
TOOL_PREVIEW_BRUSH_BORDER_INNER_ID,
|
||||
TOOL_PREVIEW_BRUSH_BORDER_OUTER_ID,
|
||||
TOOL_PREVIEW_BRUSH_FILL_ID,
|
||||
TOOL_PREVIEW_BRUSH_GROUP_ID,
|
||||
TOOL_PREVIEW_LAYER_ID,
|
||||
TOOL_PREVIEW_RECT_ID,
|
||||
VECTOR_MASK_LAYER_LINE_NAME,
|
||||
VECTOR_MASK_LAYER_NAME,
|
||||
VECTOR_MASK_LAYER_OBJECT_GROUP_NAME,
|
||||
VECTOR_MASK_LAYER_RECT_NAME,
|
||||
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { getLayerBboxFast, getLayerBboxPixels } from 'features/regionalPrompts/util/bbox';
|
||||
import Konva from 'konva';
|
||||
import type { IRect, Vector2d } from 'konva/lib/types';
|
||||
import { debounce } from 'lodash-es';
|
||||
import type { RgbColor } from 'react-colorful';
|
||||
import { assert } from 'tsafe';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
const BBOX_SELECTED_STROKE = 'rgba(78, 190, 255, 1)';
|
||||
const BBOX_NOT_SELECTED_STROKE = 'rgba(255, 255, 255, 0.353)';
|
||||
const BBOX_NOT_SELECTED_MOUSEOVER_STROKE = 'rgba(255, 255, 255, 0.661)';
|
||||
const BRUSH_BORDER_INNER_COLOR = 'rgba(0,0,0,1)';
|
||||
const BRUSH_BORDER_OUTER_COLOR = 'rgba(255,255,255,0.8)';
|
||||
// This is invokeai/frontend/web/public/assets/images/transparent_bg.png as a dataURL
|
||||
const STAGE_BG_DATAURL =
|
||||
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAAEsmlUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPD94cGFja2V0IGJlZ2luPSLvu78iIGlkPSJXNU0wTXBDZWhpSHpyZVN6TlRjemtjOWQiPz4KPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iWE1QIENvcmUgNS41LjAiPgogPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4KICA8cmRmOkRlc2NyaXB0aW9uIHJkZjphYm91dD0iIgogICAgeG1sbnM6ZXhpZj0iaHR0cDovL25zLmFkb2JlLmNvbS9leGlmLzEuMC8iCiAgICB4bWxuczp0aWZmPSJodHRwOi8vbnMuYWRvYmUuY29tL3RpZmYvMS4wLyIKICAgIHhtbG5zOnBob3Rvc2hvcD0iaHR0cDovL25zLmFkb2JlLmNvbS9waG90b3Nob3AvMS4wLyIKICAgIHhtbG5zOnhtcD0iaHR0cDovL25zLmFkb2JlLmNvbS94YXAvMS4wLyIKICAgIHhtbG5zOnhtcE1NPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvbW0vIgogICAgeG1sbnM6c3RFdnQ9Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9zVHlwZS9SZXNvdXJjZUV2ZW50IyIKICAgZXhpZjpQaXhlbFhEaW1lbnNpb249IjIwIgogICBleGlmOlBpeGVsWURpbWVuc2lvbj0iMjAiCiAgIGV4aWY6Q29sb3JTcGFjZT0iMSIKICAgdGlmZjpJbWFnZVdpZHRoPSIyMCIKICAgdGlmZjpJbWFnZUxlbmd0aD0iMjAiCiAgIHRpZmY6UmVzb2x1dGlvblVuaXQ9IjIiCiAgIHRpZmY6WFJlc29sdXRpb249IjMwMC8xIgogICB0aWZmOllSZXNvbHV0aW9uPSIzMDAvMSIKICAgcGhvdG9zaG9wOkNvbG9yTW9kZT0iMyIKICAgcGhvdG9zaG9wOklDQ1Byb2ZpbGU9InNSR0IgSUVDNjE5NjYtMi4xIgogICB4bXA6TW9kaWZ5RGF0ZT0iMjAyNC0wNC0yM1QwODoyMDo0NysxMDowMCIKICAgeG1wOk1ldGFkYXRhRGF0ZT0iMjAyNC0wNC0yM1QwODoyMDo0NysxMDowMCI+CiAgIDx4bXBNTTpIaXN0b3J5PgogICAgPHJkZjpTZXE+CiAgICAgPHJkZjpsaQogICAgICBzdEV2dDphY3Rpb249InByb2R1Y2VkIgogICAgICBzdEV2dDpzb2Z0d2FyZUFnZW50PSJBZmZpbml0eSBQaG90byAxLjEwLjgiCiAgICAgIHN0RXZ0OndoZW49IjIwMjQtMDQtMjNUMDg6MjA6NDcrMTA6MDAiLz4KICAgIDwvcmRmOlNlcT4KICAgPC94bXBNTTpIaXN0b3J5PgogIDwvcmRmOkRlc2NyaXB0aW9uPgogPC9yZGY6UkRGPgo8L3g6eG1wbWV0YT4KPD94cGFja2V0IGVuZD0iciI/Pn9pdVgAAAGBaUNDUHNSR0IgSUVDNjE5NjYtMi4xAAAokXWR3yuDURjHP5uJmKghFy6WxpVpqMWNMgm1tGbKr5vt3S+1d3t73y3JrXKrKHHj1wV/AbfKtVJESq53TdywXs9rakv2nJ7zfM73nOfpnOeAPZJRVMPhAzWb18NTAffC4pK7oYiDTjpw4YgqhjYeCgWpaR8P2Kx457Vq1T73rzXHE4YCtkbhMUXT88LTwsG1vGbxrnC7ko7Ghc+F+3W5oPC9pcfKXLQ4VeYvi/VIeALsbcLuVBXHqlhJ66qwvByPmikov/exXuJMZOfnJPaId2MQZooAbmaYZAI/g4zK7MfLEAOyoka+7yd/lpzkKjJrrKOzSoo0efpFLUj1hMSk6AkZGdat/v/tq5EcHipXdwag/sU033qhYQdK26b5eWyapROoe4arbCU/dwQj76JvVzTPIbRuwsV1RYvtweUWdD1pUT36I9WJ25NJeD2DlkVw3ULTcrlnv/ucPkJkQ77qBvYPoE/Ot658AxagZ8FoS/a7AAAACXBIWXMAAC4jAAAuIwF4pT92AAAAL0lEQVQ4jWM8ffo0A25gYmKCR5YJjxxBMKp5ZGhm/P//Px7pM2fO0MrmUc0jQzMAB2EIhZC3pUYAAAAASUVORK5CYII=';
|
||||
|
||||
const mapId = (object: { id: string }) => object.id;
|
||||
|
||||
const getIsSelected = (layerId?: string | null) => {
|
||||
if (!layerId) {
|
||||
return false;
|
||||
}
|
||||
return layerId === getStore().getState().regionalPrompts.present.selectedLayerId;
|
||||
};
|
||||
|
||||
const selectVectorMaskObjects = (node: Konva.Node) => {
|
||||
return node.name() === VECTOR_MASK_LAYER_LINE_NAME || node.name() === VECTOR_MASK_LAYER_RECT_NAME;
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates the brush preview layer.
|
||||
* @param stage The konva stage to render on.
|
||||
* @returns The brush preview layer.
|
||||
*/
|
||||
const createToolPreviewLayer = (stage: Konva.Stage) => {
|
||||
// Initialize the brush preview layer & add to the stage
|
||||
const toolPreviewLayer = new Konva.Layer({ id: TOOL_PREVIEW_LAYER_ID, visible: false, listening: false });
|
||||
stage.add(toolPreviewLayer);
|
||||
|
||||
// Add handlers to show/hide the brush preview layer
|
||||
stage.on('mousemove', (e) => {
|
||||
const tool = $tool.get();
|
||||
e.target
|
||||
.getStage()
|
||||
?.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)
|
||||
?.visible(tool === 'brush' || tool === 'eraser');
|
||||
});
|
||||
stage.on('mouseleave', (e) => {
|
||||
e.target.getStage()?.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)?.visible(false);
|
||||
});
|
||||
stage.on('mouseenter', (e) => {
|
||||
const tool = $tool.get();
|
||||
e.target
|
||||
.getStage()
|
||||
?.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)
|
||||
?.visible(tool === 'brush' || tool === 'eraser');
|
||||
});
|
||||
|
||||
// Create the brush preview group & circles
|
||||
const brushPreviewGroup = new Konva.Group({ id: TOOL_PREVIEW_BRUSH_GROUP_ID });
|
||||
const brushPreviewFill = new Konva.Circle({
|
||||
id: TOOL_PREVIEW_BRUSH_FILL_ID,
|
||||
listening: false,
|
||||
strokeEnabled: false,
|
||||
});
|
||||
brushPreviewGroup.add(brushPreviewFill);
|
||||
const brushPreviewBorderInner = new Konva.Circle({
|
||||
id: TOOL_PREVIEW_BRUSH_BORDER_INNER_ID,
|
||||
listening: false,
|
||||
stroke: BRUSH_BORDER_INNER_COLOR,
|
||||
strokeWidth: 1,
|
||||
strokeEnabled: true,
|
||||
});
|
||||
brushPreviewGroup.add(brushPreviewBorderInner);
|
||||
const brushPreviewBorderOuter = new Konva.Circle({
|
||||
id: TOOL_PREVIEW_BRUSH_BORDER_OUTER_ID,
|
||||
listening: false,
|
||||
stroke: BRUSH_BORDER_OUTER_COLOR,
|
||||
strokeWidth: 1,
|
||||
strokeEnabled: true,
|
||||
});
|
||||
brushPreviewGroup.add(brushPreviewBorderOuter);
|
||||
toolPreviewLayer.add(brushPreviewGroup);
|
||||
|
||||
// Create the rect preview
|
||||
const rectPreview = new Konva.Rect({ id: TOOL_PREVIEW_RECT_ID, listening: false, stroke: 'white', strokeWidth: 1 });
|
||||
toolPreviewLayer.add(rectPreview);
|
||||
|
||||
return toolPreviewLayer;
|
||||
};
|
||||
|
||||
/**
|
||||
* Renders the brush preview for the selected tool.
|
||||
* @param stage The konva stage to render on.
|
||||
* @param tool The selected tool.
|
||||
* @param color The selected layer's color.
|
||||
* @param cursorPos The cursor position.
|
||||
* @param lastMouseDownPos The position of the last mouse down event - used for the rect tool.
|
||||
* @param brushSize The brush size.
|
||||
*/
|
||||
const renderToolPreview = (
|
||||
stage: Konva.Stage,
|
||||
tool: Tool,
|
||||
color: RgbColor | null,
|
||||
globalMaskLayerOpacity: number,
|
||||
cursorPos: Vector2d | null,
|
||||
lastMouseDownPos: Vector2d | null,
|
||||
isMouseOver: boolean,
|
||||
brushSize: number
|
||||
) => {
|
||||
const layerCount = stage.find(`.${VECTOR_MASK_LAYER_NAME}`).length;
|
||||
// Update the stage's pointer style
|
||||
if (layerCount === 0) {
|
||||
// We have no layers, so we should not render any tool
|
||||
stage.container().style.cursor = 'default';
|
||||
} else if (tool === 'move') {
|
||||
// Move tool gets a pointer
|
||||
stage.container().style.cursor = 'default';
|
||||
} else if (tool === 'rect') {
|
||||
// Move rect gets a crosshair
|
||||
stage.container().style.cursor = 'crosshair';
|
||||
} else {
|
||||
// Else we use the brush preview
|
||||
stage.container().style.cursor = 'none';
|
||||
}
|
||||
|
||||
const toolPreviewLayer = stage.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`) ?? createToolPreviewLayer(stage);
|
||||
|
||||
if (!isMouseOver || layerCount === 0) {
|
||||
// We can bail early if the mouse isn't over the stage or there are no layers
|
||||
toolPreviewLayer.visible(false);
|
||||
return;
|
||||
}
|
||||
|
||||
toolPreviewLayer.visible(true);
|
||||
|
||||
const brushPreviewGroup = stage.findOne<Konva.Group>(`#${TOOL_PREVIEW_BRUSH_GROUP_ID}`);
|
||||
assert(brushPreviewGroup, 'Brush preview group not found');
|
||||
|
||||
const rectPreview = stage.findOne<Konva.Rect>(`#${TOOL_PREVIEW_RECT_ID}`);
|
||||
assert(rectPreview, 'Rect preview not found');
|
||||
|
||||
// No need to render the brush preview if the cursor position or color is missing
|
||||
if (cursorPos && color && (tool === 'brush' || tool === 'eraser')) {
|
||||
// Update the fill circle
|
||||
const brushPreviewFill = brushPreviewGroup.findOne<Konva.Circle>(`#${TOOL_PREVIEW_BRUSH_FILL_ID}`);
|
||||
brushPreviewFill?.setAttrs({
|
||||
x: cursorPos.x,
|
||||
y: cursorPos.y,
|
||||
radius: brushSize / 2,
|
||||
fill: rgbaColorToString({ ...color, a: globalMaskLayerOpacity }),
|
||||
globalCompositeOperation: tool === 'brush' ? 'source-over' : 'destination-out',
|
||||
});
|
||||
|
||||
// Update the inner border of the brush preview
|
||||
const brushPreviewInner = toolPreviewLayer.findOne<Konva.Circle>(`#${TOOL_PREVIEW_BRUSH_BORDER_INNER_ID}`);
|
||||
brushPreviewInner?.setAttrs({ x: cursorPos.x, y: cursorPos.y, radius: brushSize / 2 });
|
||||
|
||||
// Update the outer border of the brush preview
|
||||
const brushPreviewOuter = toolPreviewLayer.findOne<Konva.Circle>(`#${TOOL_PREVIEW_BRUSH_BORDER_OUTER_ID}`);
|
||||
brushPreviewOuter?.setAttrs({
|
||||
x: cursorPos.x,
|
||||
y: cursorPos.y,
|
||||
radius: brushSize / 2 + 1,
|
||||
});
|
||||
|
||||
brushPreviewGroup.visible(true);
|
||||
} else {
|
||||
brushPreviewGroup.visible(false);
|
||||
}
|
||||
|
||||
if (cursorPos && lastMouseDownPos && tool === 'rect') {
|
||||
const rectPreview = toolPreviewLayer.findOne<Konva.Rect>(`#${TOOL_PREVIEW_RECT_ID}`);
|
||||
rectPreview?.setAttrs({
|
||||
x: Math.min(cursorPos.x, lastMouseDownPos.x),
|
||||
y: Math.min(cursorPos.y, lastMouseDownPos.y),
|
||||
width: Math.abs(cursorPos.x - lastMouseDownPos.x),
|
||||
height: Math.abs(cursorPos.y - lastMouseDownPos.y),
|
||||
});
|
||||
rectPreview?.visible(true);
|
||||
} else {
|
||||
rectPreview?.visible(false);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a vector mask layer.
|
||||
* @param stage The konva stage to attach the layer to.
|
||||
* @param reduxLayer The redux layer to create the konva layer from.
|
||||
* @param onLayerPosChanged Callback for when the layer's position changes.
|
||||
*/
|
||||
const createVectorMaskLayer = (
|
||||
stage: Konva.Stage,
|
||||
reduxLayer: VectorMaskLayer,
|
||||
onLayerPosChanged?: (layerId: string, x: number, y: number) => void
|
||||
) => {
|
||||
// This layer hasn't been added to the konva state yet
|
||||
const konvaLayer = new Konva.Layer({
|
||||
id: reduxLayer.id,
|
||||
name: VECTOR_MASK_LAYER_NAME,
|
||||
draggable: true,
|
||||
dragDistance: 0,
|
||||
});
|
||||
|
||||
// Create a `dragmove` listener for this layer
|
||||
if (onLayerPosChanged) {
|
||||
konvaLayer.on('dragend', function (e) {
|
||||
onLayerPosChanged(reduxLayer.id, Math.floor(e.target.x()), Math.floor(e.target.y()));
|
||||
});
|
||||
}
|
||||
|
||||
// The dragBoundFunc limits how far the layer can be dragged
|
||||
konvaLayer.dragBoundFunc(function (pos) {
|
||||
const cursorPos = getScaledFlooredCursorPosition(stage);
|
||||
if (!cursorPos) {
|
||||
return this.getAbsolutePosition();
|
||||
}
|
||||
// Prevent the user from dragging the layer out of the stage bounds.
|
||||
if (
|
||||
cursorPos.x < 0 ||
|
||||
cursorPos.x > stage.width() / stage.scaleX() ||
|
||||
cursorPos.y < 0 ||
|
||||
cursorPos.y > stage.height() / stage.scaleY()
|
||||
) {
|
||||
return this.getAbsolutePosition();
|
||||
}
|
||||
return pos;
|
||||
});
|
||||
|
||||
// The object group holds all of the layer's objects (e.g. lines and rects)
|
||||
const konvaObjectGroup = new Konva.Group({
|
||||
id: getVectorMaskLayerObjectGroupId(reduxLayer.id, uuidv4()),
|
||||
name: VECTOR_MASK_LAYER_OBJECT_GROUP_NAME,
|
||||
listening: false,
|
||||
});
|
||||
konvaLayer.add(konvaObjectGroup);
|
||||
|
||||
stage.add(konvaLayer);
|
||||
|
||||
return konvaLayer;
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a konva line from a redux vector mask line.
|
||||
* @param reduxObject The redux object to create the konva line from.
|
||||
* @param konvaGroup The konva group to add the line to.
|
||||
*/
|
||||
const createVectorMaskLine = (reduxObject: VectorMaskLine, konvaGroup: Konva.Group): Konva.Line => {
|
||||
const vectorMaskLine = new Konva.Line({
|
||||
id: reduxObject.id,
|
||||
key: reduxObject.id,
|
||||
name: VECTOR_MASK_LAYER_LINE_NAME,
|
||||
strokeWidth: reduxObject.strokeWidth,
|
||||
tension: 0,
|
||||
lineCap: 'round',
|
||||
lineJoin: 'round',
|
||||
shadowForStrokeEnabled: false,
|
||||
globalCompositeOperation: reduxObject.tool === 'brush' ? 'source-over' : 'destination-out',
|
||||
listening: false,
|
||||
});
|
||||
konvaGroup.add(vectorMaskLine);
|
||||
return vectorMaskLine;
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a konva rect from a redux vector mask rect.
|
||||
* @param reduxObject The redux object to create the konva rect from.
|
||||
* @param konvaGroup The konva group to add the rect to.
|
||||
*/
|
||||
const createVectorMaskRect = (reduxObject: VectorMaskRect, konvaGroup: Konva.Group): Konva.Rect => {
|
||||
const vectorMaskRect = new Konva.Rect({
|
||||
id: reduxObject.id,
|
||||
key: reduxObject.id,
|
||||
name: VECTOR_MASK_LAYER_RECT_NAME,
|
||||
x: reduxObject.x,
|
||||
y: reduxObject.y,
|
||||
width: reduxObject.width,
|
||||
height: reduxObject.height,
|
||||
listening: false,
|
||||
});
|
||||
konvaGroup.add(vectorMaskRect);
|
||||
return vectorMaskRect;
|
||||
};
|
||||
|
||||
/**
|
||||
* Renders a vector mask layer.
|
||||
* @param stage The konva stage to render on.
|
||||
* @param reduxLayer The redux vector mask layer to render.
|
||||
* @param reduxLayerIndex The index of the layer in the redux store.
|
||||
* @param globalMaskLayerOpacity The opacity of the global mask layer.
|
||||
* @param tool The current tool.
|
||||
*/
|
||||
const renderVectorMaskLayer = (
|
||||
stage: Konva.Stage,
|
||||
reduxLayer: VectorMaskLayer,
|
||||
globalMaskLayerOpacity: number,
|
||||
tool: Tool,
|
||||
onLayerPosChanged?: (layerId: string, x: number, y: number) => void
|
||||
): void => {
|
||||
const konvaLayer =
|
||||
stage.findOne<Konva.Layer>(`#${reduxLayer.id}`) ?? createVectorMaskLayer(stage, reduxLayer, onLayerPosChanged);
|
||||
|
||||
// Update the layer's position and listening state
|
||||
konvaLayer.setAttrs({
|
||||
listening: tool === 'move', // The layer only listens when using the move tool - otherwise the stage is handling mouse events
|
||||
x: Math.floor(reduxLayer.x),
|
||||
y: Math.floor(reduxLayer.y),
|
||||
});
|
||||
|
||||
// Convert the color to a string, stripping the alpha - the object group will handle opacity.
|
||||
const rgbColor = rgbColorToString(reduxLayer.previewColor);
|
||||
|
||||
const konvaObjectGroup = konvaLayer.findOne<Konva.Group>(`.${VECTOR_MASK_LAYER_OBJECT_GROUP_NAME}`);
|
||||
assert(konvaObjectGroup, `Object group not found for layer ${reduxLayer.id}`);
|
||||
|
||||
// We use caching to handle "global" layer opacity, but caching is expensive and we should only do it when required.
|
||||
let groupNeedsCache = false;
|
||||
|
||||
const objectIds = reduxLayer.objects.map(mapId);
|
||||
for (const objectNode of konvaObjectGroup.find(selectVectorMaskObjects)) {
|
||||
if (!objectIds.includes(objectNode.id())) {
|
||||
objectNode.destroy();
|
||||
groupNeedsCache = true;
|
||||
}
|
||||
}
|
||||
|
||||
for (const reduxObject of reduxLayer.objects) {
|
||||
if (reduxObject.type === 'vector_mask_line') {
|
||||
const vectorMaskLine =
|
||||
stage.findOne<Konva.Line>(`#${reduxObject.id}`) ?? createVectorMaskLine(reduxObject, konvaObjectGroup);
|
||||
|
||||
// Only update the points if they have changed. The point values are never mutated, they are only added to the
|
||||
// array, so checking the length is sufficient to determine if we need to re-cache.
|
||||
if (vectorMaskLine.points().length !== reduxObject.points.length) {
|
||||
vectorMaskLine.points(reduxObject.points);
|
||||
groupNeedsCache = true;
|
||||
}
|
||||
// Only update the color if it has changed.
|
||||
if (vectorMaskLine.stroke() !== rgbColor) {
|
||||
vectorMaskLine.stroke(rgbColor);
|
||||
groupNeedsCache = true;
|
||||
}
|
||||
} else if (reduxObject.type === 'vector_mask_rect') {
|
||||
const konvaObject =
|
||||
stage.findOne<Konva.Rect>(`#${reduxObject.id}`) ?? createVectorMaskRect(reduxObject, konvaObjectGroup);
|
||||
|
||||
// Only update the color if it has changed.
|
||||
if (konvaObject.fill() !== rgbColor) {
|
||||
konvaObject.fill(rgbColor);
|
||||
groupNeedsCache = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Only update layer visibility if it has changed.
|
||||
if (konvaLayer.visible() !== reduxLayer.isVisible) {
|
||||
konvaLayer.visible(reduxLayer.isVisible);
|
||||
groupNeedsCache = true;
|
||||
}
|
||||
|
||||
if (konvaObjectGroup.children.length === 0) {
|
||||
// No objects - clear the cache to reset the previous pixel data
|
||||
konvaObjectGroup.clearCache();
|
||||
} else if (groupNeedsCache) {
|
||||
konvaObjectGroup.cache();
|
||||
}
|
||||
|
||||
// Updating group opacity does not require re-caching
|
||||
if (konvaObjectGroup.opacity() !== globalMaskLayerOpacity) {
|
||||
konvaObjectGroup.opacity(globalMaskLayerOpacity);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Renders the layers on the stage.
|
||||
* @param stage The konva stage to render on.
|
||||
* @param reduxLayers Array of the layers from the redux store.
|
||||
* @param layerOpacity The opacity of the layer.
|
||||
* @param onLayerPosChanged Callback for when the layer's position changes. This is optional to allow for offscreen rendering.
|
||||
* @returns
|
||||
*/
|
||||
const renderLayers = (
|
||||
stage: Konva.Stage,
|
||||
reduxLayers: Layer[],
|
||||
globalMaskLayerOpacity: number,
|
||||
tool: Tool,
|
||||
onLayerPosChanged?: (layerId: string, x: number, y: number) => void
|
||||
) => {
|
||||
const reduxLayerIds = reduxLayers.map(mapId);
|
||||
|
||||
// Remove un-rendered layers
|
||||
for (const konvaLayer of stage.find<Konva.Layer>(`.${VECTOR_MASK_LAYER_NAME}`)) {
|
||||
if (!reduxLayerIds.includes(konvaLayer.id())) {
|
||||
konvaLayer.destroy();
|
||||
}
|
||||
}
|
||||
|
||||
for (const reduxLayer of reduxLayers) {
|
||||
if (isVectorMaskLayer(reduxLayer)) {
|
||||
renderVectorMaskLayer(stage, reduxLayer, globalMaskLayerOpacity, tool, onLayerPosChanged);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a bounding box rect for a layer.
|
||||
* @param reduxLayer The redux layer to create the bounding box for.
|
||||
* @param konvaLayer The konva layer to attach the bounding box to.
|
||||
* @param onBboxMouseDown Callback for when the bounding box is clicked.
|
||||
*/
|
||||
const createBboxRect = (reduxLayer: Layer, konvaLayer: Konva.Layer, onBboxMouseDown: (layerId: string) => void) => {
|
||||
const rect = new Konva.Rect({
|
||||
id: getLayerBboxId(reduxLayer.id),
|
||||
name: LAYER_BBOX_NAME,
|
||||
strokeWidth: 1,
|
||||
});
|
||||
rect.on('mousedown', function () {
|
||||
onBboxMouseDown(reduxLayer.id);
|
||||
});
|
||||
rect.on('mouseover', function (e) {
|
||||
if (getIsSelected(e.target.getLayer()?.id())) {
|
||||
this.stroke(BBOX_SELECTED_STROKE);
|
||||
} else {
|
||||
this.stroke(BBOX_NOT_SELECTED_MOUSEOVER_STROKE);
|
||||
}
|
||||
});
|
||||
rect.on('mouseout', function (e) {
|
||||
if (getIsSelected(e.target.getLayer()?.id())) {
|
||||
this.stroke(BBOX_SELECTED_STROKE);
|
||||
} else {
|
||||
this.stroke(BBOX_NOT_SELECTED_STROKE);
|
||||
}
|
||||
});
|
||||
konvaLayer.add(rect);
|
||||
return rect;
|
||||
};
|
||||
|
||||
/**
|
||||
* Renders the bounding boxes for the layers.
|
||||
* @param stage The konva stage to render on
|
||||
* @param reduxLayers An array of all redux layers to draw bboxes for
|
||||
* @param selectedLayerId The selected layer's id
|
||||
* @param tool The current tool
|
||||
* @param onBboxChanged Callback for when the bbox is changed
|
||||
* @param onBboxMouseDown Callback for when the bbox is clicked
|
||||
* @returns
|
||||
*/
|
||||
const renderBbox = (
|
||||
stage: Konva.Stage,
|
||||
reduxLayers: Layer[],
|
||||
selectedLayerId: string | null,
|
||||
tool: Tool,
|
||||
onBboxChanged: (layerId: string, bbox: IRect | null) => void,
|
||||
onBboxMouseDown: (layerId: string) => void
|
||||
) => {
|
||||
// Hide all bboxes so they don't interfere with getClientRect
|
||||
for (const bboxRect of stage.find<Konva.Rect>(`.${LAYER_BBOX_NAME}`)) {
|
||||
bboxRect.visible(false);
|
||||
bboxRect.listening(false);
|
||||
}
|
||||
// No selected layer or not using the move tool - nothing more to do here
|
||||
if (tool !== 'move') {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const reduxLayer of reduxLayers) {
|
||||
const konvaLayer = stage.findOne<Konva.Layer>(`#${reduxLayer.id}`);
|
||||
assert(konvaLayer, `Layer ${reduxLayer.id} not found in stage`);
|
||||
|
||||
let bbox = reduxLayer.bbox;
|
||||
|
||||
// We only need to recalculate the bbox if the layer has changed and it has objects
|
||||
if (reduxLayer.bboxNeedsUpdate && reduxLayer.objects.length) {
|
||||
// We only need to use the pixel-perfect bounding box if the layer has eraser strokes
|
||||
bbox = reduxLayer.needsPixelBbox ? getLayerBboxPixels(konvaLayer) : getLayerBboxFast(konvaLayer);
|
||||
// Update the layer's bbox in the redux store
|
||||
onBboxChanged(reduxLayer.id, bbox);
|
||||
}
|
||||
|
||||
if (!bbox) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const rect =
|
||||
konvaLayer.findOne<Konva.Rect>(`.${LAYER_BBOX_NAME}`) ?? createBboxRect(reduxLayer, konvaLayer, onBboxMouseDown);
|
||||
|
||||
rect.setAttrs({
|
||||
visible: true,
|
||||
listening: true,
|
||||
x: bbox.x,
|
||||
y: bbox.y,
|
||||
width: bbox.width,
|
||||
height: bbox.height,
|
||||
stroke: reduxLayer.id === selectedLayerId ? BBOX_SELECTED_STROKE : BBOX_NOT_SELECTED_STROKE,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates the background layer for the stage.
|
||||
* @param stage The konva stage to render on
|
||||
*/
|
||||
const createBackgroundLayer = (stage: Konva.Stage): Konva.Layer => {
|
||||
const layer = new Konva.Layer({
|
||||
id: BACKGROUND_LAYER_ID,
|
||||
});
|
||||
const background = new Konva.Rect({
|
||||
id: BACKGROUND_RECT_ID,
|
||||
x: stage.x(),
|
||||
y: 0,
|
||||
width: stage.width() / stage.scaleX(),
|
||||
height: stage.height() / stage.scaleY(),
|
||||
listening: false,
|
||||
opacity: 0.2,
|
||||
});
|
||||
layer.add(background);
|
||||
stage.add(layer);
|
||||
const image = new Image();
|
||||
image.onload = () => {
|
||||
background.fillPatternImage(image);
|
||||
};
|
||||
image.src = STAGE_BG_DATAURL;
|
||||
return layer;
|
||||
};
|
||||
|
||||
/**
|
||||
* Renders the background layer for the stage.
|
||||
* @param stage The konva stage to render on
|
||||
* @param width The unscaled width of the canvas
|
||||
* @param height The unscaled height of the canvas
|
||||
*/
|
||||
const renderBackground = (stage: Konva.Stage, width: number, height: number) => {
|
||||
const layer = stage.findOne<Konva.Layer>(`#${BACKGROUND_LAYER_ID}`) ?? createBackgroundLayer(stage);
|
||||
|
||||
const background = layer.findOne<Konva.Rect>(`#${BACKGROUND_RECT_ID}`);
|
||||
assert(background, 'Background rect not found');
|
||||
// ensure background rect is in the top-left of the canvas
|
||||
background.absolutePosition({ x: 0, y: 0 });
|
||||
|
||||
// set the dimensions of the background rect to match the canvas - not the stage!!!
|
||||
background.size({
|
||||
width: width / stage.scaleX(),
|
||||
height: height / stage.scaleY(),
|
||||
});
|
||||
|
||||
// Calculate the amount the stage is moved - including the effect of scaling
|
||||
const stagePos = {
|
||||
x: -stage.x() / stage.scaleX(),
|
||||
y: -stage.y() / stage.scaleY(),
|
||||
};
|
||||
|
||||
// Apply that movement to the fill pattern
|
||||
background.fillPatternOffset(stagePos);
|
||||
};
|
||||
|
||||
/**
|
||||
* Arranges all layers in the z-axis by updating their z-indices.
|
||||
* @param stage The konva stage
|
||||
* @param layerIds An array of redux layer ids, in their z-index order
|
||||
*/
|
||||
const arrangeLayers = (stage: Konva.Stage, layerIds: string[]): void => {
|
||||
let nextZIndex = 0;
|
||||
// Background is the first layer
|
||||
stage.findOne<Konva.Layer>(`#${BACKGROUND_LAYER_ID}`)?.zIndex(nextZIndex++);
|
||||
// Then arrange the redux layers in order
|
||||
for (const layerId of layerIds) {
|
||||
stage.findOne<Konva.Layer>(`#${layerId}`)?.zIndex(nextZIndex++);
|
||||
}
|
||||
// Finally, the tool preview layer is always on top
|
||||
stage.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)?.zIndex(nextZIndex++);
|
||||
};
|
||||
|
||||
export const renderers = {
|
||||
renderToolPreview,
|
||||
renderLayers,
|
||||
renderBbox,
|
||||
renderBackground,
|
||||
arrangeLayers,
|
||||
};
|
||||
|
||||
const DEBOUNCE_MS = 300;
|
||||
|
||||
export const debouncedRenderers = {
|
||||
renderToolPreview: debounce(renderToolPreview, DEBOUNCE_MS),
|
||||
renderLayers: debounce(renderLayers, DEBOUNCE_MS),
|
||||
renderBbox: debounce(renderBbox, DEBOUNCE_MS),
|
||||
renderBackground: debounce(renderBackground, DEBOUNCE_MS),
|
||||
arrangeLayers: debounce(arrangeLayers, DEBOUNCE_MS),
|
||||
};
|
@ -13,51 +13,60 @@ import {
|
||||
selectValidIPAdapters,
|
||||
selectValidT2IAdapters,
|
||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { Fragment, memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
|
||||
const selector = createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => {
|
||||
const badges: string[] = [];
|
||||
let isError = false;
|
||||
const selector = createMemoizedSelector(
|
||||
[selectControlAdaptersSlice, selectRegionalPromptsSlice],
|
||||
(controlAdapters, regionalPrompts) => {
|
||||
const badges: string[] = [];
|
||||
let isError = false;
|
||||
|
||||
const enabledIPAdapterCount = selectAllIPAdapters(controlAdapters).filter((ca) => ca.isEnabled).length;
|
||||
const validIPAdapterCount = selectValidIPAdapters(controlAdapters).length;
|
||||
if (enabledIPAdapterCount > 0) {
|
||||
badges.push(`${enabledIPAdapterCount} IP`);
|
||||
}
|
||||
if (enabledIPAdapterCount > validIPAdapterCount) {
|
||||
isError = true;
|
||||
}
|
||||
const enabledIPAdapterCount = selectAllIPAdapters(controlAdapters)
|
||||
.filter((ca) => !regionalPrompts.present.layers.some((l) => l.ipAdapterIds.includes(ca.id)))
|
||||
.filter((ca) => ca.isEnabled).length;
|
||||
|
||||
const enabledControlNetCount = selectAllControlNets(controlAdapters).filter((ca) => ca.isEnabled).length;
|
||||
const validControlNetCount = selectValidControlNets(controlAdapters).length;
|
||||
if (enabledControlNetCount > 0) {
|
||||
badges.push(`${enabledControlNetCount} ControlNet`);
|
||||
}
|
||||
if (enabledControlNetCount > validControlNetCount) {
|
||||
isError = true;
|
||||
}
|
||||
const validIPAdapterCount = selectValidIPAdapters(controlAdapters).length;
|
||||
if (enabledIPAdapterCount > 0) {
|
||||
badges.push(`${enabledIPAdapterCount} IP`);
|
||||
}
|
||||
if (enabledIPAdapterCount > validIPAdapterCount) {
|
||||
isError = true;
|
||||
}
|
||||
|
||||
const enabledT2IAdapterCount = selectAllT2IAdapters(controlAdapters).filter((ca) => ca.isEnabled).length;
|
||||
const validT2IAdapterCount = selectValidT2IAdapters(controlAdapters).length;
|
||||
if (enabledT2IAdapterCount > 0) {
|
||||
badges.push(`${enabledT2IAdapterCount} T2I`);
|
||||
}
|
||||
if (enabledT2IAdapterCount > validT2IAdapterCount) {
|
||||
isError = true;
|
||||
}
|
||||
const enabledControlNetCount = selectAllControlNets(controlAdapters).filter((ca) => ca.isEnabled).length;
|
||||
const validControlNetCount = selectValidControlNets(controlAdapters).length;
|
||||
if (enabledControlNetCount > 0) {
|
||||
badges.push(`${enabledControlNetCount} ControlNet`);
|
||||
}
|
||||
if (enabledControlNetCount > validControlNetCount) {
|
||||
isError = true;
|
||||
}
|
||||
|
||||
const controlAdapterIds = selectControlAdapterIds(controlAdapters);
|
||||
const enabledT2IAdapterCount = selectAllT2IAdapters(controlAdapters).filter((ca) => ca.isEnabled).length;
|
||||
const validT2IAdapterCount = selectValidT2IAdapters(controlAdapters).length;
|
||||
if (enabledT2IAdapterCount > 0) {
|
||||
badges.push(`${enabledT2IAdapterCount} T2I`);
|
||||
}
|
||||
if (enabledT2IAdapterCount > validT2IAdapterCount) {
|
||||
isError = true;
|
||||
}
|
||||
|
||||
return {
|
||||
controlAdapterIds,
|
||||
badges,
|
||||
isError, // TODO: Add some visual indicator that the control adapters are in an error state
|
||||
};
|
||||
});
|
||||
const controlAdapterIds = selectControlAdapterIds(controlAdapters).filter(
|
||||
(id) => !regionalPrompts.present.layers.some((l) => l.ipAdapterIds.includes(id))
|
||||
);
|
||||
|
||||
return {
|
||||
controlAdapterIds,
|
||||
badges,
|
||||
isError, // TODO: Add some visual indicator that the control adapters are in an error state
|
||||
};
|
||||
}
|
||||
);
|
||||
|
||||
export const ControlSettingsAccordion: React.FC = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
|
@ -2,6 +2,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { aspectRatioChanged, setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
|
||||
import ParamBoundingBoxHeight from 'features/parameters/components/Canvas/BoundingBox/ParamBoundingBoxHeight';
|
||||
import ParamBoundingBoxWidth from 'features/parameters/components/Canvas/BoundingBox/ParamBoundingBoxWidth';
|
||||
import { AspectRatioIconPreview } from 'features/parameters/components/ImageSize/AspectRatioIconPreview';
|
||||
import { ImageSize } from 'features/parameters/components/ImageSize/ImageSize';
|
||||
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
|
||||
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
||||
@ -41,6 +42,7 @@ export const ImageSizeCanvas = memo(() => {
|
||||
aspectRatioState={aspectRatioState}
|
||||
heightComponent={<ParamBoundingBoxHeight />}
|
||||
widthComponent={<ParamBoundingBoxWidth />}
|
||||
previewComponent={<AspectRatioIconPreview />}
|
||||
onChangeAspectRatioState={onChangeAspectRatioState}
|
||||
onChangeWidth={onChangeWidth}
|
||||
onChangeHeight={onChangeHeight}
|
||||
|
@ -1,13 +1,17 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { ParamHeight } from 'features/parameters/components/Core/ParamHeight';
|
||||
import { ParamWidth } from 'features/parameters/components/Core/ParamWidth';
|
||||
import { AspectRatioCanvasPreview } from 'features/parameters/components/ImageSize/AspectRatioCanvasPreview';
|
||||
import { AspectRatioIconPreview } from 'features/parameters/components/ImageSize/AspectRatioIconPreview';
|
||||
import { ImageSize } from 'features/parameters/components/ImageSize/ImageSize';
|
||||
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
|
||||
import { aspectRatioChanged, heightChanged, widthChanged } from 'features/parameters/store/generationSlice';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
export const ImageSizeLinear = memo(() => {
|
||||
const dispatch = useAppDispatch();
|
||||
const tab = useAppSelector(activeTabNameSelector);
|
||||
const width = useAppSelector((s) => s.generation.width);
|
||||
const height = useAppSelector((s) => s.generation.height);
|
||||
const aspectRatioState = useAppSelector((s) => s.generation.aspectRatio);
|
||||
@ -40,6 +44,7 @@ export const ImageSizeLinear = memo(() => {
|
||||
aspectRatioState={aspectRatioState}
|
||||
heightComponent={<ParamHeight />}
|
||||
widthComponent={<ParamWidth />}
|
||||
previewComponent={tab === 'txt2img' ? <AspectRatioCanvasPreview /> : <AspectRatioIconPreview />}
|
||||
onChangeAspectRatioState={onChangeAspectRatioState}
|
||||
onChangeWidth={onChangeWidth}
|
||||
onChangeHeight={onChangeHeight}
|
||||
|
@ -1,8 +1,10 @@
|
||||
import { Box, Flex } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
|
||||
import { Prompts } from 'features/parameters/components/Prompts/Prompts';
|
||||
import QueueControls from 'features/queue/components/QueueControls';
|
||||
import { RegionalPromptsPanelContent } from 'features/regionalPrompts/components/RegionalPromptsPanelContent';
|
||||
import { useRegionalControlTitle } from 'features/regionalPrompts/hooks/useRegionalControlTitle';
|
||||
import { SDXLPrompts } from 'features/sdxl/components/SDXLPrompts/SDXLPrompts';
|
||||
import { AdvancedSettingsAccordion } from 'features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion';
|
||||
import { CompositingSettingsAccordion } from 'features/settingsAccordions/components/CompositingSettingsAccordion/CompositingSettingsAccordion';
|
||||
@ -14,6 +16,7 @@ import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const overlayScrollbarsStyles: CSSProperties = {
|
||||
height: '100%',
|
||||
@ -21,7 +24,9 @@ const overlayScrollbarsStyles: CSSProperties = {
|
||||
};
|
||||
|
||||
const ParametersPanel = () => {
|
||||
const { t } = useTranslation();
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
const regionalControlTitle = useRegionalControlTitle();
|
||||
const isSDXL = useAppSelector((s) => s.generation.model?.base === 'sdxl');
|
||||
|
||||
return (
|
||||
@ -32,12 +37,28 @@ const ParametersPanel = () => {
|
||||
<OverlayScrollbarsComponent defer style={overlayScrollbarsStyles} options={overlayScrollbarsParams.options}>
|
||||
<Flex gap={2} flexDirection="column" h="full" w="full">
|
||||
{isSDXL ? <SDXLPrompts /> : <Prompts />}
|
||||
<ImageSettingsAccordion />
|
||||
<GenerationSettingsAccordion />
|
||||
<ControlSettingsAccordion />
|
||||
{activeTabName === 'unifiedCanvas' && <CompositingSettingsAccordion />}
|
||||
{isSDXL && <RefinerSettingsAccordion />}
|
||||
<AdvancedSettingsAccordion />
|
||||
<Tabs variant="line" isLazy={true} display="flex" flexDir="column" w="full" h="full">
|
||||
<TabList>
|
||||
<Tab>{t('parameters.globalSettings')}</Tab>
|
||||
<Tab>{regionalControlTitle}</Tab>
|
||||
</TabList>
|
||||
|
||||
<TabPanels w="full" h="full">
|
||||
<TabPanel>
|
||||
<Flex gap={2} flexDirection="column" h="full" w="full">
|
||||
<ImageSettingsAccordion />
|
||||
<GenerationSettingsAccordion />
|
||||
<ControlSettingsAccordion />
|
||||
{activeTabName === 'unifiedCanvas' && <CompositingSettingsAccordion />}
|
||||
{isSDXL && <RefinerSettingsAccordion />}
|
||||
<AdvancedSettingsAccordion />
|
||||
</Flex>
|
||||
</TabPanel>
|
||||
<TabPanel>
|
||||
<RegionalPromptsPanelContent />
|
||||
</TabPanel>
|
||||
</TabPanels>
|
||||
</Tabs>
|
||||
</Flex>
|
||||
</OverlayScrollbarsComponent>
|
||||
</Box>
|
||||
|
@ -1,13 +1,31 @@
|
||||
import { Box, Flex } from '@invoke-ai/ui-library';
|
||||
import { Box, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
|
||||
import CurrentImageDisplay from 'features/gallery/components/CurrentImage/CurrentImageDisplay';
|
||||
import { RegionalPromptsEditor } from 'features/regionalPrompts/components/RegionalPromptsEditor';
|
||||
import { useRegionalControlTitle } from 'features/regionalPrompts/hooks/useRegionalControlTitle';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const TextToImageTab = () => {
|
||||
const { t } = useTranslation();
|
||||
const regionalControlTitle = useRegionalControlTitle();
|
||||
|
||||
return (
|
||||
<Box layerStyle="first" position="relative" w="full" h="full" p={2} borderRadius="base">
|
||||
<Flex w="full" h="full">
|
||||
<CurrentImageDisplay />
|
||||
</Flex>
|
||||
<Box position="relative" w="full" h="full" p={2} borderRadius="base">
|
||||
<Tabs variant="line" isLazy={true} display="flex" flexDir="column" w="full" h="full">
|
||||
<TabList>
|
||||
<Tab>{t('common.viewer')}</Tab>
|
||||
<Tab>{regionalControlTitle}</Tab>
|
||||
</TabList>
|
||||
|
||||
<TabPanels w="full" h="full" minH={0} minW={0}>
|
||||
<TabPanel>
|
||||
<CurrentImageDisplay />
|
||||
</TabPanel>
|
||||
<TabPanel>
|
||||
<RegionalPromptsEditor />
|
||||
</TabPanel>
|
||||
</TabPanels>
|
||||
</Tabs>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
@ -124,7 +124,9 @@ export const usePanel = (arg: UsePanelOptions): UsePanelReturn => {
|
||||
*
|
||||
* For now, we'll just resize the panel to the min size every time the panel group is resized.
|
||||
*/
|
||||
panelHandleRef.current.resize(minSizePct);
|
||||
if (!panelHandleRef.current.isCollapsed()) {
|
||||
panelHandleRef.current.resize(minSizePct);
|
||||
}
|
||||
});
|
||||
|
||||
resizeObserver.observe(panelGroupElement);
|
||||
|
File diff suppressed because one or more lines are too long
@ -1 +1 @@
|
||||
__version__ = "4.1.0"
|
||||
__version__ = "4.2.0a2"
|
||||
|
@ -4,7 +4,7 @@ import pytest
|
||||
from torch import tensor
|
||||
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import InvalidModelConfigException
|
||||
from invokeai.backend.model_manager.config import InvalidModelConfigException, MainDiffusersConfig, ModelVariantType
|
||||
from invokeai.backend.model_manager.probe import (
|
||||
CkptType,
|
||||
ModelProbe,
|
||||
@ -78,3 +78,11 @@ def test_probe_handles_state_dict_with_integer_keys():
|
||||
}
|
||||
with pytest.raises(InvalidModelConfigException):
|
||||
ModelProbe.get_model_type_from_checkpoint(Path("embedding.pt"), state_dict_with_integer_keys)
|
||||
|
||||
|
||||
def test_probe_sd1_diffusers_inpainting(datadir: Path):
|
||||
config = ModelProbe.probe(datadir / "sd-1/main/dreamshaper-8-inpainting")
|
||||
assert isinstance(config, MainDiffusersConfig)
|
||||
assert config.base is BaseModelType.StableDiffusion1
|
||||
assert config.variant is ModelVariantType.Inpaint
|
||||
assert config.repo_variant is ModelRepoVariant.FP16
|
||||
|
@ -0,0 +1 @@
|
||||
This folder contains config files copied from [Lykon/dreamshaper-8-inpainting](https://huggingface.co/Lykon/dreamshaper-8-inpainting).
|
@ -0,0 +1,34 @@
|
||||
{
|
||||
"_class_name": "StableDiffusionInpaintPipeline",
|
||||
"_diffusers_version": "0.21.0.dev0",
|
||||
"_name_or_path": "lykon-models/dreamshaper-8-inpainting",
|
||||
"feature_extractor": [
|
||||
"transformers",
|
||||
"CLIPFeatureExtractor"
|
||||
],
|
||||
"requires_safety_checker": true,
|
||||
"safety_checker": [
|
||||
"stable_diffusion",
|
||||
"StableDiffusionSafetyChecker"
|
||||
],
|
||||
"scheduler": [
|
||||
"diffusers",
|
||||
"DEISMultistepScheduler"
|
||||
],
|
||||
"text_encoder": [
|
||||
"transformers",
|
||||
"CLIPTextModel"
|
||||
],
|
||||
"tokenizer": [
|
||||
"transformers",
|
||||
"CLIPTokenizer"
|
||||
],
|
||||
"unet": [
|
||||
"diffusers",
|
||||
"UNet2DConditionModel"
|
||||
],
|
||||
"vae": [
|
||||
"diffusers",
|
||||
"AutoencoderKL"
|
||||
]
|
||||
}
|
@ -0,0 +1,23 @@
|
||||
{
|
||||
"_class_name": "DEISMultistepScheduler",
|
||||
"_diffusers_version": "0.21.0.dev0",
|
||||
"algorithm_type": "deis",
|
||||
"beta_end": 0.012,
|
||||
"beta_schedule": "scaled_linear",
|
||||
"beta_start": 0.00085,
|
||||
"clip_sample": false,
|
||||
"dynamic_thresholding_ratio": 0.995,
|
||||
"lower_order_final": true,
|
||||
"num_train_timesteps": 1000,
|
||||
"prediction_type": "epsilon",
|
||||
"sample_max_value": 1.0,
|
||||
"set_alpha_to_one": false,
|
||||
"skip_prk_steps": true,
|
||||
"solver_order": 2,
|
||||
"solver_type": "logrho",
|
||||
"steps_offset": 1,
|
||||
"thresholding": false,
|
||||
"timestep_spacing": "leading",
|
||||
"trained_betas": null,
|
||||
"use_karras_sigmas": false
|
||||
}
|
@ -0,0 +1,66 @@
|
||||
{
|
||||
"_class_name": "UNet2DConditionModel",
|
||||
"_diffusers_version": "0.21.0.dev0",
|
||||
"_name_or_path": "/home/patrick/.cache/huggingface/hub/models--lykon-models--dreamshaper-8-inpainting/snapshots/15dcb9dec91a39ee498e3917c9ef6174b103862d/unet",
|
||||
"act_fn": "silu",
|
||||
"addition_embed_type": null,
|
||||
"addition_embed_type_num_heads": 64,
|
||||
"addition_time_embed_dim": null,
|
||||
"attention_head_dim": 8,
|
||||
"attention_type": "default",
|
||||
"block_out_channels": [
|
||||
320,
|
||||
640,
|
||||
1280,
|
||||
1280
|
||||
],
|
||||
"center_input_sample": false,
|
||||
"class_embed_type": null,
|
||||
"class_embeddings_concat": false,
|
||||
"conv_in_kernel": 3,
|
||||
"conv_out_kernel": 3,
|
||||
"cross_attention_dim": 768,
|
||||
"cross_attention_norm": null,
|
||||
"down_block_types": [
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D"
|
||||
],
|
||||
"downsample_padding": 1,
|
||||
"dual_cross_attention": false,
|
||||
"encoder_hid_dim": null,
|
||||
"encoder_hid_dim_type": null,
|
||||
"flip_sin_to_cos": true,
|
||||
"freq_shift": 0,
|
||||
"in_channels": 9,
|
||||
"layers_per_block": 2,
|
||||
"mid_block_only_cross_attention": null,
|
||||
"mid_block_scale_factor": 1,
|
||||
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
||||
"norm_eps": 1e-05,
|
||||
"norm_num_groups": 32,
|
||||
"num_attention_heads": null,
|
||||
"num_class_embeds": null,
|
||||
"only_cross_attention": false,
|
||||
"out_channels": 4,
|
||||
"projection_class_embeddings_input_dim": null,
|
||||
"resnet_out_scale_factor": 1.0,
|
||||
"resnet_skip_time_act": false,
|
||||
"resnet_time_scale_shift": "default",
|
||||
"sample_size": 64,
|
||||
"time_cond_proj_dim": null,
|
||||
"time_embedding_act_fn": null,
|
||||
"time_embedding_dim": null,
|
||||
"time_embedding_type": "positional",
|
||||
"timestep_post_act": null,
|
||||
"transformer_layers_per_block": 1,
|
||||
"up_block_types": [
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D"
|
||||
],
|
||||
"upcast_attention": null,
|
||||
"use_linear_projection": false
|
||||
}
|
@ -99,6 +99,20 @@ def test_obj_serializer_ephemeral_writes_to_tempdir(tmp_path: Path):
|
||||
assert not Path(tmp_path, obj_1_name).exists()
|
||||
|
||||
|
||||
def test_obj_serializer_ephemeral_deletes_dangling_tempdirs_on_init(tmp_path: Path):
|
||||
tempdir = tmp_path / "tmpdir"
|
||||
tempdir.mkdir()
|
||||
ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
|
||||
assert not tempdir.exists()
|
||||
|
||||
|
||||
def test_obj_serializer_does_not_delete_tempdirs_on_init(tmp_path: Path):
|
||||
tempdir = tmp_path / "tmpdir"
|
||||
tempdir.mkdir()
|
||||
ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=False)
|
||||
assert tempdir.exists()
|
||||
|
||||
|
||||
def test_obj_serializer_disk_different_types(tmp_path: Path):
|
||||
obj_serializer_1 = ObjectSerializerDisk[MockDataclass](tmp_path)
|
||||
obj_1 = MockDataclass(foo="bar")
|
||||
|
Loading…
Reference in New Issue
Block a user