mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into lstein/logger-route
This commit is contained in:
commit
5134de7cfa
@ -85,8 +85,8 @@ CONTROLNET_DEFAULT_MODELS = [
|
|||||||
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
|
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
|
||||||
CONTROLNET_MODE_VALUES = Literal[tuple(
|
CONTROLNET_MODE_VALUES = Literal[tuple(
|
||||||
["balanced", "more_prompt", "more_control", "unbalanced"])]
|
["balanced", "more_prompt", "more_control", "unbalanced"])]
|
||||||
# crop and fill options not ready yet
|
CONTROLNET_RESIZE_VALUES = Literal[tuple(
|
||||||
# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
|
["just_resize", "crop_resize", "fill_resize", "just_resize_simple",])]
|
||||||
|
|
||||||
|
|
||||||
class ControlNetModelField(BaseModel):
|
class ControlNetModelField(BaseModel):
|
||||||
@ -111,7 +111,8 @@ class ControlField(BaseModel):
|
|||||||
description="When the ControlNet is last applied (% of total steps)")
|
description="When the ControlNet is last applied (% of total steps)")
|
||||||
control_mode: CONTROLNET_MODE_VALUES = Field(
|
control_mode: CONTROLNET_MODE_VALUES = Field(
|
||||||
default="balanced", description="The control mode to use")
|
default="balanced", description="The control mode to use")
|
||||||
# resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
resize_mode: CONTROLNET_RESIZE_VALUES = Field(
|
||||||
|
default="just_resize", description="The resize mode to use")
|
||||||
|
|
||||||
@validator("control_weight")
|
@validator("control_weight")
|
||||||
def validate_control_weight(cls, v):
|
def validate_control_weight(cls, v):
|
||||||
@ -161,6 +162,7 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||||
description="When the ControlNet is last applied (% of total steps)")
|
description="When the ControlNet is last applied (% of total steps)")
|
||||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
|
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
|
||||||
|
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode used")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
@ -187,6 +189,7 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
begin_step_percent=self.begin_step_percent,
|
begin_step_percent=self.begin_step_percent,
|
||||||
end_step_percent=self.end_step_percent,
|
end_step_percent=self.end_step_percent,
|
||||||
control_mode=self.control_mode,
|
control_mode=self.control_mode,
|
||||||
|
resize_mode=self.resize_mode,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -30,6 +30,7 @@ from .compel import ConditioningField
|
|||||||
from .controlnet_image_processors import ControlField
|
from .controlnet_image_processors import ControlField
|
||||||
from .image import ImageOutput
|
from .image import ImageOutput
|
||||||
from .model import ModelInfo, UNetField, VaeField
|
from .model import ModelInfo, UNetField, VaeField
|
||||||
|
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||||
|
|
||||||
from diffusers.models.attention_processor import (
|
from diffusers.models.attention_processor import (
|
||||||
AttnProcessor2_0,
|
AttnProcessor2_0,
|
||||||
@ -288,7 +289,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
# and add in batch_size, num_images_per_prompt?
|
# and add in batch_size, num_images_per_prompt?
|
||||||
# and do real check for classifier_free_guidance?
|
# and do real check for classifier_free_guidance?
|
||||||
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
|
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
|
||||||
control_image = model.prepare_control_image(
|
control_image = prepare_control_image(
|
||||||
image=input_image,
|
image=input_image,
|
||||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||||
width=control_width_resize,
|
width=control_width_resize,
|
||||||
@ -298,13 +299,18 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
device=control_model.device,
|
device=control_model.device,
|
||||||
dtype=control_model.dtype,
|
dtype=control_model.dtype,
|
||||||
control_mode=control_info.control_mode,
|
control_mode=control_info.control_mode,
|
||||||
|
resize_mode=control_info.resize_mode,
|
||||||
)
|
)
|
||||||
control_item = ControlNetData(
|
control_item = ControlNetData(
|
||||||
model=control_model, image_tensor=control_image,
|
model=control_model,
|
||||||
|
image_tensor=control_image,
|
||||||
weight=control_info.control_weight,
|
weight=control_info.control_weight,
|
||||||
begin_step_percent=control_info.begin_step_percent,
|
begin_step_percent=control_info.begin_step_percent,
|
||||||
end_step_percent=control_info.end_step_percent,
|
end_step_percent=control_info.end_step_percent,
|
||||||
control_mode=control_info.control_mode,
|
control_mode=control_info.control_mode,
|
||||||
|
# any resizing needed should currently be happening in prepare_control_image(),
|
||||||
|
# but adding resize_mode to ControlNetData in case needed in the future
|
||||||
|
resize_mode=control_info.resize_mode,
|
||||||
)
|
)
|
||||||
control_data.append(control_item)
|
control_data.append(control_item)
|
||||||
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||||
|
342
invokeai/app/util/controlnet_utils.py
Normal file
342
invokeai/app/util/controlnet_utils.py
Normal file
@ -0,0 +1,342 @@
|
|||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
from PIL import Image
|
||||||
|
from diffusers.utils import PIL_INTERPOLATION
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
from controlnet_aux.util import HWC3, resize_image
|
||||||
|
|
||||||
|
###################################################################
|
||||||
|
# Copy of scripts/lvminthin.py from Mikubill/sd-webui-controlnet
|
||||||
|
###################################################################
|
||||||
|
# High Quality Edge Thinning using Pure Python
|
||||||
|
# Written by Lvmin Zhangu
|
||||||
|
# 2023 April
|
||||||
|
# Stanford University
|
||||||
|
# If you use this, please Cite "High Quality Edge Thinning using Pure Python", Lvmin Zhang, In Mikubill/sd-webui-controlnet.
|
||||||
|
|
||||||
|
lvmin_kernels_raw = [
|
||||||
|
np.array([
|
||||||
|
[-1, -1, -1],
|
||||||
|
[0, 1, 0],
|
||||||
|
[1, 1, 1]
|
||||||
|
], dtype=np.int32),
|
||||||
|
np.array([
|
||||||
|
[0, -1, -1],
|
||||||
|
[1, 1, -1],
|
||||||
|
[0, 1, 0]
|
||||||
|
], dtype=np.int32)
|
||||||
|
]
|
||||||
|
|
||||||
|
lvmin_kernels = []
|
||||||
|
lvmin_kernels += [np.rot90(x, k=0, axes=(0, 1)) for x in lvmin_kernels_raw]
|
||||||
|
lvmin_kernels += [np.rot90(x, k=1, axes=(0, 1)) for x in lvmin_kernels_raw]
|
||||||
|
lvmin_kernels += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_kernels_raw]
|
||||||
|
lvmin_kernels += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_kernels_raw]
|
||||||
|
|
||||||
|
lvmin_prunings_raw = [
|
||||||
|
np.array([
|
||||||
|
[-1, -1, -1],
|
||||||
|
[-1, 1, -1],
|
||||||
|
[0, 0, -1]
|
||||||
|
], dtype=np.int32),
|
||||||
|
np.array([
|
||||||
|
[-1, -1, -1],
|
||||||
|
[-1, 1, -1],
|
||||||
|
[-1, 0, 0]
|
||||||
|
], dtype=np.int32)
|
||||||
|
]
|
||||||
|
|
||||||
|
lvmin_prunings = []
|
||||||
|
lvmin_prunings += [np.rot90(x, k=0, axes=(0, 1)) for x in lvmin_prunings_raw]
|
||||||
|
lvmin_prunings += [np.rot90(x, k=1, axes=(0, 1)) for x in lvmin_prunings_raw]
|
||||||
|
lvmin_prunings += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_prunings_raw]
|
||||||
|
lvmin_prunings += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_prunings_raw]
|
||||||
|
|
||||||
|
|
||||||
|
def remove_pattern(x, kernel):
|
||||||
|
objects = cv2.morphologyEx(x, cv2.MORPH_HITMISS, kernel)
|
||||||
|
objects = np.where(objects > 127)
|
||||||
|
x[objects] = 0
|
||||||
|
return x, objects[0].shape[0] > 0
|
||||||
|
|
||||||
|
|
||||||
|
def thin_one_time(x, kernels):
|
||||||
|
y = x
|
||||||
|
is_done = True
|
||||||
|
for k in kernels:
|
||||||
|
y, has_update = remove_pattern(y, k)
|
||||||
|
if has_update:
|
||||||
|
is_done = False
|
||||||
|
return y, is_done
|
||||||
|
|
||||||
|
|
||||||
|
def lvmin_thin(x, prunings=True):
|
||||||
|
y = x
|
||||||
|
for i in range(32):
|
||||||
|
y, is_done = thin_one_time(y, lvmin_kernels)
|
||||||
|
if is_done:
|
||||||
|
break
|
||||||
|
if prunings:
|
||||||
|
y, _ = thin_one_time(y, lvmin_prunings)
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
def nake_nms(x):
|
||||||
|
f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
|
||||||
|
f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
|
||||||
|
f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
|
||||||
|
f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
|
||||||
|
y = np.zeros_like(x)
|
||||||
|
for f in [f1, f2, f3, f4]:
|
||||||
|
np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# copied from Mikubill/sd-webui-controlnet external_code.py and modified for InvokeAI
|
||||||
|
################################################################################
|
||||||
|
# FIXME: not using yet, if used in the future will most likely require modification of preprocessors
|
||||||
|
def pixel_perfect_resolution(
|
||||||
|
image: np.ndarray,
|
||||||
|
target_H: int,
|
||||||
|
target_W: int,
|
||||||
|
resize_mode: str,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Calculate the estimated resolution for resizing an image while preserving aspect ratio.
|
||||||
|
|
||||||
|
The function first calculates scaling factors for height and width of the image based on the target
|
||||||
|
height and width. Then, based on the chosen resize mode, it either takes the smaller or the larger
|
||||||
|
scaling factor to estimate the new resolution.
|
||||||
|
|
||||||
|
If the resize mode is OUTER_FIT, the function uses the smaller scaling factor, ensuring the whole image
|
||||||
|
fits within the target dimensions, potentially leaving some empty space.
|
||||||
|
|
||||||
|
If the resize mode is not OUTER_FIT, the function uses the larger scaling factor, ensuring the target
|
||||||
|
dimensions are fully filled, potentially cropping the image.
|
||||||
|
|
||||||
|
After calculating the estimated resolution, the function prints some debugging information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (np.ndarray): A 3D numpy array representing an image. The dimensions represent [height, width, channels].
|
||||||
|
target_H (int): The target height for the image.
|
||||||
|
target_W (int): The target width for the image.
|
||||||
|
resize_mode (ResizeMode): The mode for resizing.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The estimated resolution after resizing.
|
||||||
|
"""
|
||||||
|
raw_H, raw_W, _ = image.shape
|
||||||
|
|
||||||
|
k0 = float(target_H) / float(raw_H)
|
||||||
|
k1 = float(target_W) / float(raw_W)
|
||||||
|
|
||||||
|
if resize_mode == "fill_resize":
|
||||||
|
estimation = min(k0, k1) * float(min(raw_H, raw_W))
|
||||||
|
else: # "crop_resize" or "just_resize" (or possibly "just_resize_simple"?)
|
||||||
|
estimation = max(k0, k1) * float(min(raw_H, raw_W))
|
||||||
|
|
||||||
|
# print(f"Pixel Perfect Computation:")
|
||||||
|
# print(f"resize_mode = {resize_mode}")
|
||||||
|
# print(f"raw_H = {raw_H}")
|
||||||
|
# print(f"raw_W = {raw_W}")
|
||||||
|
# print(f"target_H = {target_H}")
|
||||||
|
# print(f"target_W = {target_W}")
|
||||||
|
# print(f"estimation = {estimation}")
|
||||||
|
|
||||||
|
return int(np.round(estimation))
|
||||||
|
|
||||||
|
|
||||||
|
###########################################################################
|
||||||
|
# Copied from detectmap_proc method in scripts/detectmap_proc.py in Mikubill/sd-webui-controlnet
|
||||||
|
# modified for InvokeAI
|
||||||
|
###########################################################################
|
||||||
|
# def detectmap_proc(detected_map, module, resize_mode, h, w):
|
||||||
|
def np_img_resize(
|
||||||
|
np_img: np.ndarray,
|
||||||
|
resize_mode: str,
|
||||||
|
h: int,
|
||||||
|
w: int,
|
||||||
|
device: torch.device = torch.device('cpu')
|
||||||
|
):
|
||||||
|
# if 'inpaint' in module:
|
||||||
|
# np_img = np_img.astype(np.float32)
|
||||||
|
# else:
|
||||||
|
# np_img = HWC3(np_img)
|
||||||
|
np_img = HWC3(np_img)
|
||||||
|
|
||||||
|
def safe_numpy(x):
|
||||||
|
# A very safe method to make sure that Apple/Mac works
|
||||||
|
y = x
|
||||||
|
|
||||||
|
# below is very boring but do not change these. If you change these Apple or Mac may fail.
|
||||||
|
y = y.copy()
|
||||||
|
y = np.ascontiguousarray(y)
|
||||||
|
y = y.copy()
|
||||||
|
return y
|
||||||
|
|
||||||
|
def get_pytorch_control(x):
|
||||||
|
# A very safe method to make sure that Apple/Mac works
|
||||||
|
y = x
|
||||||
|
|
||||||
|
# below is very boring but do not change these. If you change these Apple or Mac may fail.
|
||||||
|
y = torch.from_numpy(y)
|
||||||
|
y = y.float() / 255.0
|
||||||
|
y = rearrange(y, 'h w c -> 1 c h w')
|
||||||
|
y = y.clone()
|
||||||
|
# y = y.to(devices.get_device_for("controlnet"))
|
||||||
|
y = y.to(device)
|
||||||
|
y = y.clone()
|
||||||
|
return y
|
||||||
|
|
||||||
|
def high_quality_resize(x: np.ndarray,
|
||||||
|
size):
|
||||||
|
# Written by lvmin
|
||||||
|
# Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges
|
||||||
|
inpaint_mask = None
|
||||||
|
if x.ndim == 3 and x.shape[2] == 4:
|
||||||
|
inpaint_mask = x[:, :, 3]
|
||||||
|
x = x[:, :, 0:3]
|
||||||
|
|
||||||
|
new_size_is_smaller = (size[0] * size[1]) < (x.shape[0] * x.shape[1])
|
||||||
|
new_size_is_bigger = (size[0] * size[1]) > (x.shape[0] * x.shape[1])
|
||||||
|
unique_color_count = np.unique(x.reshape(-1, x.shape[2]), axis=0).shape[0]
|
||||||
|
is_one_pixel_edge = False
|
||||||
|
is_binary = False
|
||||||
|
if unique_color_count == 2:
|
||||||
|
is_binary = np.min(x) < 16 and np.max(x) > 240
|
||||||
|
if is_binary:
|
||||||
|
xc = x
|
||||||
|
xc = cv2.erode(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
|
||||||
|
xc = cv2.dilate(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
|
||||||
|
one_pixel_edge_count = np.where(xc < x)[0].shape[0]
|
||||||
|
all_edge_count = np.where(x > 127)[0].shape[0]
|
||||||
|
is_one_pixel_edge = one_pixel_edge_count * 2 > all_edge_count
|
||||||
|
|
||||||
|
if 2 < unique_color_count < 200:
|
||||||
|
interpolation = cv2.INTER_NEAREST
|
||||||
|
elif new_size_is_smaller:
|
||||||
|
interpolation = cv2.INTER_AREA
|
||||||
|
else:
|
||||||
|
interpolation = cv2.INTER_CUBIC # Must be CUBIC because we now use nms. NEVER CHANGE THIS
|
||||||
|
|
||||||
|
y = cv2.resize(x, size, interpolation=interpolation)
|
||||||
|
if inpaint_mask is not None:
|
||||||
|
inpaint_mask = cv2.resize(inpaint_mask, size, interpolation=interpolation)
|
||||||
|
|
||||||
|
if is_binary:
|
||||||
|
y = np.mean(y.astype(np.float32), axis=2).clip(0, 255).astype(np.uint8)
|
||||||
|
if is_one_pixel_edge:
|
||||||
|
y = nake_nms(y)
|
||||||
|
_, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
||||||
|
y = lvmin_thin(y, prunings=new_size_is_bigger)
|
||||||
|
else:
|
||||||
|
_, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
||||||
|
y = np.stack([y] * 3, axis=2)
|
||||||
|
|
||||||
|
if inpaint_mask is not None:
|
||||||
|
inpaint_mask = (inpaint_mask > 127).astype(np.float32) * 255.0
|
||||||
|
inpaint_mask = inpaint_mask[:, :, None].clip(0, 255).astype(np.uint8)
|
||||||
|
y = np.concatenate([y, inpaint_mask], axis=2)
|
||||||
|
|
||||||
|
return y
|
||||||
|
|
||||||
|
# if resize_mode == external_code.ResizeMode.RESIZE:
|
||||||
|
if resize_mode == "just_resize": # RESIZE
|
||||||
|
np_img = high_quality_resize(np_img, (w, h))
|
||||||
|
np_img = safe_numpy(np_img)
|
||||||
|
return get_pytorch_control(np_img), np_img
|
||||||
|
|
||||||
|
old_h, old_w, _ = np_img.shape
|
||||||
|
old_w = float(old_w)
|
||||||
|
old_h = float(old_h)
|
||||||
|
k0 = float(h) / old_h
|
||||||
|
k1 = float(w) / old_w
|
||||||
|
|
||||||
|
safeint = lambda x: int(np.round(x))
|
||||||
|
|
||||||
|
# if resize_mode == external_code.ResizeMode.OUTER_FIT:
|
||||||
|
if resize_mode == "fill_resize": # OUTER_FIT
|
||||||
|
k = min(k0, k1)
|
||||||
|
borders = np.concatenate([np_img[0, :, :], np_img[-1, :, :], np_img[:, 0, :], np_img[:, -1, :]], axis=0)
|
||||||
|
high_quality_border_color = np.median(borders, axis=0).astype(np_img.dtype)
|
||||||
|
if len(high_quality_border_color) == 4:
|
||||||
|
# Inpaint hijack
|
||||||
|
high_quality_border_color[3] = 255
|
||||||
|
high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1])
|
||||||
|
np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
|
||||||
|
new_h, new_w, _ = np_img.shape
|
||||||
|
pad_h = max(0, (h - new_h) // 2)
|
||||||
|
pad_w = max(0, (w - new_w) // 2)
|
||||||
|
high_quality_background[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = np_img
|
||||||
|
np_img = high_quality_background
|
||||||
|
np_img = safe_numpy(np_img)
|
||||||
|
return get_pytorch_control(np_img), np_img
|
||||||
|
else: # resize_mode == "crop_resize" (INNER_FIT)
|
||||||
|
k = max(k0, k1)
|
||||||
|
np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
|
||||||
|
new_h, new_w, _ = np_img.shape
|
||||||
|
pad_h = max(0, (new_h - h) // 2)
|
||||||
|
pad_w = max(0, (new_w - w) // 2)
|
||||||
|
np_img = np_img[pad_h:pad_h + h, pad_w:pad_w + w]
|
||||||
|
np_img = safe_numpy(np_img)
|
||||||
|
return get_pytorch_control(np_img), np_img
|
||||||
|
|
||||||
|
def prepare_control_image(
|
||||||
|
# image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]]
|
||||||
|
# but now should be able to assume that image is a single PIL.Image, which simplifies things
|
||||||
|
image: Image,
|
||||||
|
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
|
||||||
|
# latents_to_match_resolution, # TorchTensor of shape (batch_size, 3, height, width)
|
||||||
|
width=512, # should be 8 * latent.shape[3]
|
||||||
|
height=512, # should be 8 * latent height[2]
|
||||||
|
# batch_size=1, # currently no batching
|
||||||
|
# num_images_per_prompt=1, # currently only single image
|
||||||
|
device="cuda",
|
||||||
|
dtype=torch.float16,
|
||||||
|
do_classifier_free_guidance=True,
|
||||||
|
control_mode="balanced",
|
||||||
|
resize_mode="just_resize_simple",
|
||||||
|
):
|
||||||
|
# FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out
|
||||||
|
if (resize_mode == "just_resize_simple" or
|
||||||
|
resize_mode == "crop_resize_simple" or
|
||||||
|
resize_mode == "fill_resize_simple"):
|
||||||
|
image = image.convert("RGB")
|
||||||
|
if (resize_mode == "just_resize_simple"):
|
||||||
|
image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
||||||
|
elif (resize_mode == "crop_resize_simple"): # not yet implemented
|
||||||
|
pass
|
||||||
|
elif (resize_mode == "fill_resize_simple"): # not yet implemented
|
||||||
|
pass
|
||||||
|
nimage = np.array(image)
|
||||||
|
nimage = nimage[None, :]
|
||||||
|
nimage = np.concatenate([nimage], axis=0)
|
||||||
|
# normalizing RGB values to [0,1] range (in PIL.Image they are [0-255])
|
||||||
|
nimage = np.array(nimage).astype(np.float32) / 255.0
|
||||||
|
nimage = nimage.transpose(0, 3, 1, 2)
|
||||||
|
timage = torch.from_numpy(nimage)
|
||||||
|
|
||||||
|
# use fancy lvmin controlnet resizing
|
||||||
|
elif (resize_mode == "just_resize" or resize_mode == "crop_resize" or resize_mode == "fill_resize"):
|
||||||
|
nimage = np.array(image)
|
||||||
|
timage, nimage = np_img_resize(
|
||||||
|
np_img=nimage,
|
||||||
|
resize_mode=resize_mode,
|
||||||
|
h=height,
|
||||||
|
w=width,
|
||||||
|
# device=torch.device('cpu')
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
print("ERROR: invalid resize_mode ==> ", resize_mode)
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
timage = timage.to(device=device, dtype=dtype)
|
||||||
|
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
|
||||||
|
if do_classifier_free_guidance and not cfg_injection:
|
||||||
|
timage = torch.cat([timage] * 2)
|
||||||
|
return timage
|
@ -219,6 +219,7 @@ class ControlNetData:
|
|||||||
begin_step_percent: float = Field(default=0.0)
|
begin_step_percent: float = Field(default=0.0)
|
||||||
end_step_percent: float = Field(default=1.0)
|
end_step_percent: float = Field(default=1.0)
|
||||||
control_mode: str = Field(default="balanced")
|
control_mode: str = Field(default="balanced")
|
||||||
|
resize_mode: str = Field(default="just_resize")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -653,7 +654,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
if cfg_injection:
|
if cfg_injection:
|
||||||
# Inferred ControlNet only for the conditional batch.
|
# Inferred ControlNet only for the conditional batch.
|
||||||
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
||||||
# add 0 to the unconditional batch to keep it unchanged.
|
# prepend zeros for unconditional batch
|
||||||
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
|
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
|
||||||
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
|
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
|
||||||
|
|
||||||
@ -954,53 +955,3 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
debug_image(
|
debug_image(
|
||||||
img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
|
img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from diffusers pipeline_stable_diffusion_controlnet.py
|
|
||||||
# Returns torch.Tensor of shape (batch_size, 3, height, width)
|
|
||||||
@staticmethod
|
|
||||||
def prepare_control_image(
|
|
||||||
image,
|
|
||||||
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
|
|
||||||
# latents,
|
|
||||||
width=512, # should be 8 * latent.shape[3]
|
|
||||||
height=512, # should be 8 * latent height[2]
|
|
||||||
batch_size=1,
|
|
||||||
num_images_per_prompt=1,
|
|
||||||
device="cuda",
|
|
||||||
dtype=torch.float16,
|
|
||||||
do_classifier_free_guidance=True,
|
|
||||||
control_mode="balanced"
|
|
||||||
):
|
|
||||||
|
|
||||||
if not isinstance(image, torch.Tensor):
|
|
||||||
if isinstance(image, PIL.Image.Image):
|
|
||||||
image = [image]
|
|
||||||
|
|
||||||
if isinstance(image[0], PIL.Image.Image):
|
|
||||||
images = []
|
|
||||||
for image_ in image:
|
|
||||||
image_ = image_.convert("RGB")
|
|
||||||
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
|
||||||
image_ = np.array(image_)
|
|
||||||
image_ = image_[None, :]
|
|
||||||
images.append(image_)
|
|
||||||
image = images
|
|
||||||
image = np.concatenate(image, axis=0)
|
|
||||||
image = np.array(image).astype(np.float32) / 255.0
|
|
||||||
image = image.transpose(0, 3, 1, 2)
|
|
||||||
image = torch.from_numpy(image)
|
|
||||||
elif isinstance(image[0], torch.Tensor):
|
|
||||||
image = torch.cat(image, dim=0)
|
|
||||||
|
|
||||||
image_batch_size = image.shape[0]
|
|
||||||
if image_batch_size == 1:
|
|
||||||
repeat_by = batch_size
|
|
||||||
else:
|
|
||||||
# image batch size is the same as prompt batch size
|
|
||||||
repeat_by = num_images_per_prompt
|
|
||||||
image = image.repeat_interleave(repeat_by, dim=0)
|
|
||||||
image = image.to(device=device, dtype=dtype)
|
|
||||||
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
|
|
||||||
if do_classifier_free_guidance and not cfg_injection:
|
|
||||||
image = torch.cat([image] * 2)
|
|
||||||
return image
|
|
||||||
|
@ -17,13 +17,13 @@ import {
|
|||||||
} from 'common/components/IAIImageFallback';
|
} from 'common/components/IAIImageFallback';
|
||||||
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
|
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
|
||||||
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
|
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
|
||||||
|
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
|
||||||
import { MouseEvent, ReactElement, SyntheticEvent, memo } from 'react';
|
import { MouseEvent, ReactElement, SyntheticEvent, memo } from 'react';
|
||||||
import { FaImage, FaUndo, FaUpload } from 'react-icons/fa';
|
import { FaImage, FaUndo, FaUpload } from 'react-icons/fa';
|
||||||
import { ImageDTO, PostUploadAction } from 'services/api/types';
|
import { ImageDTO, PostUploadAction } from 'services/api/types';
|
||||||
import { mode } from 'theme/util/mode';
|
import { mode } from 'theme/util/mode';
|
||||||
import IAIDraggable from './IAIDraggable';
|
import IAIDraggable from './IAIDraggable';
|
||||||
import IAIDroppable from './IAIDroppable';
|
import IAIDroppable from './IAIDroppable';
|
||||||
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
|
|
||||||
|
|
||||||
type IAIDndImageProps = {
|
type IAIDndImageProps = {
|
||||||
imageDTO: ImageDTO | undefined;
|
imageDTO: ImageDTO | undefined;
|
||||||
@ -148,7 +148,9 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
|||||||
maxH: 'full',
|
maxH: 'full',
|
||||||
borderRadius: 'base',
|
borderRadius: 'base',
|
||||||
shadow: isSelected ? 'selected.light' : undefined,
|
shadow: isSelected ? 'selected.light' : undefined,
|
||||||
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
|
_dark: {
|
||||||
|
shadow: isSelected ? 'selected.dark' : undefined,
|
||||||
|
},
|
||||||
...imageSx,
|
...imageSx,
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
@ -183,13 +185,6 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
|||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
{!imageDTO && isUploadDisabled && noContentFallback}
|
{!imageDTO && isUploadDisabled && noContentFallback}
|
||||||
{!isDropDisabled && (
|
|
||||||
<IAIDroppable
|
|
||||||
data={droppableData}
|
|
||||||
disabled={isDropDisabled}
|
|
||||||
dropLabel={dropLabel}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{imageDTO && !isDragDisabled && (
|
{imageDTO && !isDragDisabled && (
|
||||||
<IAIDraggable
|
<IAIDraggable
|
||||||
data={draggableData}
|
data={draggableData}
|
||||||
@ -197,6 +192,13 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
|||||||
onClick={onClick}
|
onClick={onClick}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
|
{!isDropDisabled && (
|
||||||
|
<IAIDroppable
|
||||||
|
data={droppableData}
|
||||||
|
disabled={isDropDisabled}
|
||||||
|
dropLabel={dropLabel}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
{onClickReset && withResetIcon && imageDTO && (
|
{onClickReset && withResetIcon && imageDTO && (
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
onClick={onClickReset}
|
onClick={onClickReset}
|
||||||
|
@ -13,10 +13,11 @@ type IAIDroppableProps = {
|
|||||||
dropLabel?: ReactNode;
|
dropLabel?: ReactNode;
|
||||||
disabled?: boolean;
|
disabled?: boolean;
|
||||||
data?: TypesafeDroppableData;
|
data?: TypesafeDroppableData;
|
||||||
|
hoverRef?: React.Ref<HTMLDivElement>;
|
||||||
};
|
};
|
||||||
|
|
||||||
const IAIDroppable = (props: IAIDroppableProps) => {
|
const IAIDroppable = (props: IAIDroppableProps) => {
|
||||||
const { dropLabel, data, disabled } = props;
|
const { dropLabel, data, disabled, hoverRef } = props;
|
||||||
const dndId = useRef(uuidv4());
|
const dndId = useRef(uuidv4());
|
||||||
|
|
||||||
const { isOver, setNodeRef, active } = useDroppable({
|
const { isOver, setNodeRef, active } = useDroppable({
|
||||||
|
@ -24,6 +24,7 @@ import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
|
|||||||
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
|
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
|
||||||
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
|
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
|
||||||
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
|
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
|
||||||
|
import ParamControlNetResizeMode from './parameters/ParamControlNetResizeMode';
|
||||||
|
|
||||||
type ControlNetProps = {
|
type ControlNetProps = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
@ -68,7 +69,7 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
<Flex
|
<Flex
|
||||||
sx={{
|
sx={{
|
||||||
flexDir: 'column',
|
flexDir: 'column',
|
||||||
gap: 2,
|
gap: 3,
|
||||||
p: 3,
|
p: 3,
|
||||||
borderRadius: 'base',
|
borderRadius: 'base',
|
||||||
position: 'relative',
|
position: 'relative',
|
||||||
@ -117,7 +118,12 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
tooltip={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
|
tooltip={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
|
||||||
aria-label={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
|
aria-label={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
|
||||||
onClick={toggleIsExpanded}
|
onClick={toggleIsExpanded}
|
||||||
variant="link"
|
variant="ghost"
|
||||||
|
sx={{
|
||||||
|
_hover: {
|
||||||
|
bg: 'none',
|
||||||
|
},
|
||||||
|
}}
|
||||||
icon={
|
icon={
|
||||||
<ChevronUpIcon
|
<ChevronUpIcon
|
||||||
sx={{
|
sx={{
|
||||||
@ -151,7 +157,7 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
<Flex sx={{ w: 'full', flexDirection: 'column' }}>
|
<Flex sx={{ w: 'full', flexDirection: 'column', gap: 3 }}>
|
||||||
<Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}>
|
<Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}>
|
||||||
<Flex
|
<Flex
|
||||||
sx={{
|
sx={{
|
||||||
@ -176,16 +182,16 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
h: 28,
|
h: 28,
|
||||||
w: 28,
|
w: 28,
|
||||||
aspectRatio: '1/1',
|
aspectRatio: '1/1',
|
||||||
mt: 3,
|
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<ControlNetImagePreview controlNetId={controlNetId} height={28} />
|
<ControlNetImagePreview controlNetId={controlNetId} height={28} />
|
||||||
</Flex>
|
</Flex>
|
||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
<Box mt={2}>
|
<Flex sx={{ gap: 2 }}>
|
||||||
<ParamControlNetControlMode controlNetId={controlNetId} />
|
<ParamControlNetControlMode controlNetId={controlNetId} />
|
||||||
</Box>
|
<ParamControlNetResizeMode controlNetId={controlNetId} />
|
||||||
|
</Flex>
|
||||||
<ParamControlNetProcessorSelect controlNetId={controlNetId} />
|
<ParamControlNetProcessorSelect controlNetId={controlNetId} />
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|
||||||
|
@ -0,0 +1,62 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
|
import {
|
||||||
|
ResizeModes,
|
||||||
|
controlNetResizeModeChanged,
|
||||||
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import { useCallback, useMemo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
type ParamControlNetResizeModeProps = {
|
||||||
|
controlNetId: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
const RESIZE_MODE_DATA = [
|
||||||
|
{ label: 'Resize', value: 'just_resize' },
|
||||||
|
{ label: 'Crop', value: 'crop_resize' },
|
||||||
|
{ label: 'Fill', value: 'fill_resize' },
|
||||||
|
];
|
||||||
|
|
||||||
|
export default function ParamControlNetResizeMode(
|
||||||
|
props: ParamControlNetResizeModeProps
|
||||||
|
) {
|
||||||
|
const { controlNetId } = props;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const selector = useMemo(
|
||||||
|
() =>
|
||||||
|
createSelector(
|
||||||
|
stateSelector,
|
||||||
|
({ controlNet }) => {
|
||||||
|
const { resizeMode, isEnabled } =
|
||||||
|
controlNet.controlNets[controlNetId];
|
||||||
|
return { resizeMode, isEnabled };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
),
|
||||||
|
[controlNetId]
|
||||||
|
);
|
||||||
|
|
||||||
|
const { resizeMode, isEnabled } = useAppSelector(selector);
|
||||||
|
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const handleResizeModeChange = useCallback(
|
||||||
|
(resizeMode: ResizeModes) => {
|
||||||
|
dispatch(controlNetResizeModeChanged({ controlNetId, resizeMode }));
|
||||||
|
},
|
||||||
|
[controlNetId, dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAIMantineSelect
|
||||||
|
disabled={!isEnabled}
|
||||||
|
label="Resize Mode"
|
||||||
|
data={RESIZE_MODE_DATA}
|
||||||
|
value={String(resizeMode)}
|
||||||
|
onChange={handleResizeModeChange}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
@ -3,6 +3,7 @@ import { RootState } from 'app/store/store';
|
|||||||
import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas';
|
import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas';
|
||||||
import { cloneDeep, forEach } from 'lodash-es';
|
import { cloneDeep, forEach } from 'lodash-es';
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
|
import { components } from 'services/api/schema';
|
||||||
import { isAnySessionRejected } from 'services/api/thunks/session';
|
import { isAnySessionRejected } from 'services/api/thunks/session';
|
||||||
import { appSocketInvocationError } from 'services/events/actions';
|
import { appSocketInvocationError } from 'services/events/actions';
|
||||||
import { controlNetImageProcessed } from './actions';
|
import { controlNetImageProcessed } from './actions';
|
||||||
@ -16,11 +17,13 @@ import {
|
|||||||
RequiredControlNetProcessorNode,
|
RequiredControlNetProcessorNode,
|
||||||
} from './types';
|
} from './types';
|
||||||
|
|
||||||
export type ControlModes =
|
export type ControlModes = NonNullable<
|
||||||
| 'balanced'
|
components['schemas']['ControlNetInvocation']['control_mode']
|
||||||
| 'more_prompt'
|
>;
|
||||||
| 'more_control'
|
|
||||||
| 'unbalanced';
|
export type ResizeModes = NonNullable<
|
||||||
|
components['schemas']['ControlNetInvocation']['resize_mode']
|
||||||
|
>;
|
||||||
|
|
||||||
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
||||||
isEnabled: true,
|
isEnabled: true,
|
||||||
@ -29,6 +32,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
|||||||
beginStepPct: 0,
|
beginStepPct: 0,
|
||||||
endStepPct: 1,
|
endStepPct: 1,
|
||||||
controlMode: 'balanced',
|
controlMode: 'balanced',
|
||||||
|
resizeMode: 'just_resize',
|
||||||
controlImage: null,
|
controlImage: null,
|
||||||
processedControlImage: null,
|
processedControlImage: null,
|
||||||
processorType: 'canny_image_processor',
|
processorType: 'canny_image_processor',
|
||||||
@ -45,6 +49,7 @@ export type ControlNetConfig = {
|
|||||||
beginStepPct: number;
|
beginStepPct: number;
|
||||||
endStepPct: number;
|
endStepPct: number;
|
||||||
controlMode: ControlModes;
|
controlMode: ControlModes;
|
||||||
|
resizeMode: ResizeModes;
|
||||||
controlImage: string | null;
|
controlImage: string | null;
|
||||||
processedControlImage: string | null;
|
processedControlImage: string | null;
|
||||||
processorType: ControlNetProcessorType;
|
processorType: ControlNetProcessorType;
|
||||||
@ -215,6 +220,16 @@ export const controlNetSlice = createSlice({
|
|||||||
const { controlNetId, controlMode } = action.payload;
|
const { controlNetId, controlMode } = action.payload;
|
||||||
state.controlNets[controlNetId].controlMode = controlMode;
|
state.controlNets[controlNetId].controlMode = controlMode;
|
||||||
},
|
},
|
||||||
|
controlNetResizeModeChanged: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<{
|
||||||
|
controlNetId: string;
|
||||||
|
resizeMode: ResizeModes;
|
||||||
|
}>
|
||||||
|
) => {
|
||||||
|
const { controlNetId, resizeMode } = action.payload;
|
||||||
|
state.controlNets[controlNetId].resizeMode = resizeMode;
|
||||||
|
},
|
||||||
controlNetProcessorParamsChanged: (
|
controlNetProcessorParamsChanged: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<{
|
action: PayloadAction<{
|
||||||
@ -342,6 +357,7 @@ export const {
|
|||||||
controlNetBeginStepPctChanged,
|
controlNetBeginStepPctChanged,
|
||||||
controlNetEndStepPctChanged,
|
controlNetEndStepPctChanged,
|
||||||
controlNetControlModeChanged,
|
controlNetControlModeChanged,
|
||||||
|
controlNetResizeModeChanged,
|
||||||
controlNetProcessorParamsChanged,
|
controlNetProcessorParamsChanged,
|
||||||
controlNetProcessorTypeChanged,
|
controlNetProcessorTypeChanged,
|
||||||
controlNetReset,
|
controlNetReset,
|
||||||
|
@ -23,7 +23,6 @@ const BoardContextMenu = memo(
|
|||||||
dispatch(boardIdSelected(board?.board_id ?? board_id));
|
dispatch(boardIdSelected(board?.board_id ?? board_id));
|
||||||
}, [board?.board_id, board_id, dispatch]);
|
}, [board?.board_id, board_id, dispatch]);
|
||||||
return (
|
return (
|
||||||
<Box sx={{ touchAction: 'none', height: 'full' }}>
|
|
||||||
<ContextMenu<HTMLDivElement>
|
<ContextMenu<HTMLDivElement>
|
||||||
menuProps={{ size: 'sm', isLazy: true }}
|
menuProps={{ size: 'sm', isLazy: true }}
|
||||||
menuButtonProps={{
|
menuButtonProps={{
|
||||||
@ -50,7 +49,6 @@ const BoardContextMenu = memo(
|
|||||||
>
|
>
|
||||||
{children}
|
{children}
|
||||||
</ContextMenu>
|
</ContextMenu>
|
||||||
</Box>
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
@ -1,27 +1,21 @@
|
|||||||
import {
|
import { ButtonGroup, Collapse, Flex, Grid, GridItem } from '@chakra-ui/react';
|
||||||
Collapse,
|
|
||||||
Flex,
|
|
||||||
Grid,
|
|
||||||
GridItem,
|
|
||||||
useDisclosure,
|
|
||||||
} from '@chakra-ui/react';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
|
import { AnimatePresence, motion } from 'framer-motion';
|
||||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||||
import { memo, useState } from 'react';
|
import { memo, useCallback, useState } from 'react';
|
||||||
|
import { FaSearch } from 'react-icons/fa';
|
||||||
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
|
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
|
||||||
|
import { BoardDTO } from 'services/api/types';
|
||||||
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
|
||||||
|
import DeleteBoardModal from '../DeleteBoardModal';
|
||||||
import AddBoardButton from './AddBoardButton';
|
import AddBoardButton from './AddBoardButton';
|
||||||
import AllAssetsBoard from './AllAssetsBoard';
|
|
||||||
import AllImagesBoard from './AllImagesBoard';
|
|
||||||
import BatchBoard from './BatchBoard';
|
|
||||||
import BoardsSearch from './BoardsSearch';
|
import BoardsSearch from './BoardsSearch';
|
||||||
import GalleryBoard from './GalleryBoard';
|
import GalleryBoard from './GalleryBoard';
|
||||||
import NoBoardBoard from './NoBoardBoard';
|
import SystemBoardButton from './SystemBoardButton';
|
||||||
import DeleteBoardModal from '../DeleteBoardModal';
|
|
||||||
import { BoardDTO } from 'services/api/types';
|
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector],
|
[stateSelector],
|
||||||
@ -48,7 +42,10 @@ const BoardsList = (props: Props) => {
|
|||||||
)
|
)
|
||||||
: boards;
|
: boards;
|
||||||
const [boardToDelete, setBoardToDelete] = useState<BoardDTO>();
|
const [boardToDelete, setBoardToDelete] = useState<BoardDTO>();
|
||||||
const [searchMode, setSearchMode] = useState(false);
|
const [isSearching, setIsSearching] = useState(false);
|
||||||
|
const handleClickSearchIcon = useCallback(() => {
|
||||||
|
setIsSearching((v) => !v);
|
||||||
|
}, []);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
@ -64,7 +61,54 @@ const BoardsList = (props: Props) => {
|
|||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Flex sx={{ gap: 2, alignItems: 'center' }}>
|
<Flex sx={{ gap: 2, alignItems: 'center' }}>
|
||||||
<BoardsSearch setSearchMode={setSearchMode} />
|
<AnimatePresence mode="popLayout">
|
||||||
|
{isSearching ? (
|
||||||
|
<motion.div
|
||||||
|
key="boards-search"
|
||||||
|
initial={{
|
||||||
|
opacity: 0,
|
||||||
|
}}
|
||||||
|
exit={{
|
||||||
|
opacity: 0,
|
||||||
|
}}
|
||||||
|
animate={{
|
||||||
|
opacity: 1,
|
||||||
|
transition: { duration: 0.1 },
|
||||||
|
}}
|
||||||
|
style={{ width: '100%' }}
|
||||||
|
>
|
||||||
|
<BoardsSearch setIsSearching={setIsSearching} />
|
||||||
|
</motion.div>
|
||||||
|
) : (
|
||||||
|
<motion.div
|
||||||
|
key="system-boards-select"
|
||||||
|
initial={{
|
||||||
|
opacity: 0,
|
||||||
|
}}
|
||||||
|
exit={{
|
||||||
|
opacity: 0,
|
||||||
|
}}
|
||||||
|
animate={{
|
||||||
|
opacity: 1,
|
||||||
|
transition: { duration: 0.1 },
|
||||||
|
}}
|
||||||
|
style={{ width: '100%' }}
|
||||||
|
>
|
||||||
|
<ButtonGroup sx={{ w: 'full', ps: 1.5 }} isAttached>
|
||||||
|
<SystemBoardButton board_id="images" />
|
||||||
|
<SystemBoardButton board_id="assets" />
|
||||||
|
<SystemBoardButton board_id="no_board" />
|
||||||
|
</ButtonGroup>
|
||||||
|
</motion.div>
|
||||||
|
)}
|
||||||
|
</AnimatePresence>
|
||||||
|
<IAIIconButton
|
||||||
|
aria-label="Search Boards"
|
||||||
|
size="sm"
|
||||||
|
isChecked={isSearching}
|
||||||
|
onClick={handleClickSearchIcon}
|
||||||
|
icon={<FaSearch />}
|
||||||
|
/>
|
||||||
<AddBoardButton />
|
<AddBoardButton />
|
||||||
</Flex>
|
</Flex>
|
||||||
<OverlayScrollbarsComponent
|
<OverlayScrollbarsComponent
|
||||||
@ -82,29 +126,10 @@ const BoardsList = (props: Props) => {
|
|||||||
<Grid
|
<Grid
|
||||||
className="list-container"
|
className="list-container"
|
||||||
sx={{
|
sx={{
|
||||||
gridTemplateRows: '6.5rem 6.5rem',
|
gridTemplateColumns: `repeat(auto-fill, minmax(96px, 1fr));`,
|
||||||
gridAutoFlow: 'column dense',
|
maxH: 346,
|
||||||
gridAutoColumns: '5rem',
|
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{!searchMode && (
|
|
||||||
<>
|
|
||||||
<GridItem sx={{ p: 1.5 }}>
|
|
||||||
<AllImagesBoard isSelected={selectedBoardId === 'images'} />
|
|
||||||
</GridItem>
|
|
||||||
<GridItem sx={{ p: 1.5 }}>
|
|
||||||
<AllAssetsBoard isSelected={selectedBoardId === 'assets'} />
|
|
||||||
</GridItem>
|
|
||||||
<GridItem sx={{ p: 1.5 }}>
|
|
||||||
<NoBoardBoard isSelected={selectedBoardId === 'no_board'} />
|
|
||||||
</GridItem>
|
|
||||||
{isBatchEnabled && (
|
|
||||||
<GridItem sx={{ p: 1.5 }}>
|
|
||||||
<BatchBoard isSelected={selectedBoardId === 'batch'} />
|
|
||||||
</GridItem>
|
|
||||||
)}
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
{filteredBoards &&
|
{filteredBoards &&
|
||||||
filteredBoards.map((board) => (
|
filteredBoards.map((board) => (
|
||||||
<GridItem key={board.board_id} sx={{ p: 1.5 }}>
|
<GridItem key={board.board_id} sx={{ p: 1.5 }}>
|
||||||
|
@ -10,7 +10,14 @@ import { stateSelector } from 'app/store/store';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import { setBoardSearchText } from 'features/gallery/store/boardSlice';
|
import { setBoardSearchText } from 'features/gallery/store/boardSlice';
|
||||||
import { memo } from 'react';
|
import {
|
||||||
|
ChangeEvent,
|
||||||
|
KeyboardEvent,
|
||||||
|
memo,
|
||||||
|
useCallback,
|
||||||
|
useEffect,
|
||||||
|
useRef,
|
||||||
|
} from 'react';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector],
|
[stateSelector],
|
||||||
@ -22,31 +29,60 @@ const selector = createSelector(
|
|||||||
);
|
);
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
setSearchMode: (searchMode: boolean) => void;
|
setIsSearching: (isSearching: boolean) => void;
|
||||||
};
|
};
|
||||||
|
|
||||||
const BoardsSearch = (props: Props) => {
|
const BoardsSearch = (props: Props) => {
|
||||||
const { setSearchMode } = props;
|
const { setIsSearching } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { searchText } = useAppSelector(selector);
|
const { searchText } = useAppSelector(selector);
|
||||||
|
const inputRef = useRef<HTMLInputElement>(null);
|
||||||
|
|
||||||
const handleBoardSearch = (searchTerm: string) => {
|
const handleBoardSearch = useCallback(
|
||||||
setSearchMode(searchTerm.length > 0);
|
(searchTerm: string) => {
|
||||||
dispatch(setBoardSearchText(searchTerm));
|
dispatch(setBoardSearchText(searchTerm));
|
||||||
};
|
},
|
||||||
const clearBoardSearch = () => {
|
[dispatch]
|
||||||
setSearchMode(false);
|
);
|
||||||
|
|
||||||
|
const clearBoardSearch = useCallback(() => {
|
||||||
dispatch(setBoardSearchText(''));
|
dispatch(setBoardSearchText(''));
|
||||||
};
|
setIsSearching(false);
|
||||||
|
}, [dispatch, setIsSearching]);
|
||||||
|
|
||||||
|
const handleKeydown = useCallback(
|
||||||
|
(e: KeyboardEvent<HTMLInputElement>) => {
|
||||||
|
// exit search mode on escape
|
||||||
|
if (e.key === 'Escape') {
|
||||||
|
clearBoardSearch();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[clearBoardSearch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleChange = useCallback(
|
||||||
|
(e: ChangeEvent<HTMLInputElement>) => {
|
||||||
|
handleBoardSearch(e.target.value);
|
||||||
|
},
|
||||||
|
[handleBoardSearch]
|
||||||
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
// focus the search box on mount
|
||||||
|
if (!inputRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
inputRef.current.focus();
|
||||||
|
}, []);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<InputGroup>
|
<InputGroup>
|
||||||
<Input
|
<Input
|
||||||
|
ref={inputRef}
|
||||||
placeholder="Search Boards..."
|
placeholder="Search Boards..."
|
||||||
value={searchText}
|
value={searchText}
|
||||||
onChange={(e) => {
|
onKeyDown={handleKeydown}
|
||||||
handleBoardSearch(e.target.value);
|
onChange={handleChange}
|
||||||
}}
|
|
||||||
/>
|
/>
|
||||||
{searchText && searchText.length && (
|
{searchText && searchText.length && (
|
||||||
<InputRightElement>
|
<InputRightElement>
|
||||||
@ -55,7 +91,8 @@ const BoardsSearch = (props: Props) => {
|
|||||||
size="xs"
|
size="xs"
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
aria-label="Clear Search"
|
aria-label="Clear Search"
|
||||||
icon={<CloseIcon boxSize={3} />}
|
opacity={0.5}
|
||||||
|
icon={<CloseIcon boxSize={2} />}
|
||||||
/>
|
/>
|
||||||
</InputRightElement>
|
</InputRightElement>
|
||||||
)}
|
)}
|
||||||
|
@ -6,9 +6,9 @@ import {
|
|||||||
EditableInput,
|
EditableInput,
|
||||||
EditablePreview,
|
EditablePreview,
|
||||||
Flex,
|
Flex,
|
||||||
|
Icon,
|
||||||
Image,
|
Image,
|
||||||
Text,
|
Text,
|
||||||
useColorMode,
|
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
@ -17,14 +17,12 @@ import { stateSelector } from 'app/store/store';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAIDroppable from 'common/components/IAIDroppable';
|
import IAIDroppable from 'common/components/IAIDroppable';
|
||||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
|
||||||
import { boardIdSelected } from 'features/gallery/store/gallerySlice';
|
import { boardIdSelected } from 'features/gallery/store/gallerySlice';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo, useState } from 'react';
|
||||||
import { FaUser } from 'react-icons/fa';
|
import { FaFolder } from 'react-icons/fa';
|
||||||
import { useUpdateBoardMutation } from 'services/api/endpoints/boards';
|
import { useUpdateBoardMutation } from 'services/api/endpoints/boards';
|
||||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||||
import { BoardDTO } from 'services/api/types';
|
import { BoardDTO } from 'services/api/types';
|
||||||
import { mode } from 'theme/util/mode';
|
|
||||||
import BoardContextMenu from '../BoardContextMenu';
|
import BoardContextMenu from '../BoardContextMenu';
|
||||||
|
|
||||||
const AUTO_ADD_BADGE_STYLES: ChakraProps['sx'] = {
|
const AUTO_ADD_BADGE_STYLES: ChakraProps['sx'] = {
|
||||||
@ -66,8 +64,9 @@ const GalleryBoard = memo(
|
|||||||
board.cover_image_name ?? skipToken
|
board.cover_image_name ?? skipToken
|
||||||
);
|
);
|
||||||
|
|
||||||
const { colorMode } = useColorMode();
|
|
||||||
const { board_name, board_id } = board;
|
const { board_name, board_id } = board;
|
||||||
|
const [localBoardName, setLocalBoardName] = useState(board_name);
|
||||||
|
|
||||||
const handleSelectBoard = useCallback(() => {
|
const handleSelectBoard = useCallback(() => {
|
||||||
dispatch(boardIdSelected(board_id));
|
dispatch(boardIdSelected(board_id));
|
||||||
}, [board_id, dispatch]);
|
}, [board_id, dispatch]);
|
||||||
@ -75,10 +74,6 @@ const GalleryBoard = memo(
|
|||||||
const [updateBoard, { isLoading: isUpdateBoardLoading }] =
|
const [updateBoard, { isLoading: isUpdateBoardLoading }] =
|
||||||
useUpdateBoardMutation();
|
useUpdateBoardMutation();
|
||||||
|
|
||||||
const handleUpdateBoardName = (newBoardName: string) => {
|
|
||||||
updateBoard({ board_id, changes: { board_name: newBoardName } });
|
|
||||||
};
|
|
||||||
|
|
||||||
const droppableData: MoveBoardDropData = useMemo(
|
const droppableData: MoveBoardDropData = useMemo(
|
||||||
() => ({
|
() => ({
|
||||||
id: board_id,
|
id: board_id,
|
||||||
@ -88,8 +83,49 @@ const GalleryBoard = memo(
|
|||||||
[board_id]
|
[board_id]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const handleSubmit = useCallback(
|
||||||
|
(newBoardName: string) => {
|
||||||
|
if (!newBoardName) {
|
||||||
|
// empty strings are not allowed
|
||||||
|
setLocalBoardName(board_name);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (newBoardName === board_name) {
|
||||||
|
// don't updated the board name if it hasn't changed
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
updateBoard({ board_id, changes: { board_name: newBoardName } })
|
||||||
|
.unwrap()
|
||||||
|
.then((response) => {
|
||||||
|
// update local state
|
||||||
|
setLocalBoardName(response.board_name);
|
||||||
|
})
|
||||||
|
.catch(() => {
|
||||||
|
// revert on error
|
||||||
|
setLocalBoardName(board_name);
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[board_id, board_name, updateBoard]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleChange = useCallback((newBoardName: string) => {
|
||||||
|
setLocalBoardName(newBoardName);
|
||||||
|
}, []);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Box sx={{ touchAction: 'none', height: 'full' }}>
|
<Box
|
||||||
|
sx={{ w: 'full', h: 'full', touchAction: 'none', userSelect: 'none' }}
|
||||||
|
>
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
position: 'relative',
|
||||||
|
justifyContent: 'center',
|
||||||
|
alignItems: 'center',
|
||||||
|
aspectRatio: '1/1',
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
}}
|
||||||
|
>
|
||||||
<BoardContextMenu
|
<BoardContextMenu
|
||||||
board={board}
|
board={board}
|
||||||
board_id={board_id}
|
board_id={board_id}
|
||||||
@ -97,50 +133,66 @@ const GalleryBoard = memo(
|
|||||||
>
|
>
|
||||||
{(ref) => (
|
{(ref) => (
|
||||||
<Flex
|
<Flex
|
||||||
key={board_id}
|
|
||||||
userSelect="none"
|
|
||||||
ref={ref}
|
ref={ref}
|
||||||
sx={{
|
|
||||||
flexDir: 'column',
|
|
||||||
justifyContent: 'space-between',
|
|
||||||
alignItems: 'center',
|
|
||||||
cursor: 'pointer',
|
|
||||||
w: 'full',
|
|
||||||
h: 'full',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<Flex
|
|
||||||
onClick={handleSelectBoard}
|
onClick={handleSelectBoard}
|
||||||
sx={{
|
sx={{
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
position: 'relative',
|
position: 'relative',
|
||||||
justifyContent: 'center',
|
justifyContent: 'center',
|
||||||
alignItems: 'center',
|
alignItems: 'center',
|
||||||
borderRadius: 'base',
|
borderRadius: 'base',
|
||||||
w: 'full',
|
cursor: 'pointer',
|
||||||
aspectRatio: '1/1',
|
|
||||||
overflow: 'hidden',
|
|
||||||
shadow: isSelected ? 'selected.light' : undefined,
|
|
||||||
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
|
|
||||||
flexShrink: 0,
|
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{board.cover_image_name && coverImage?.thumbnail_url && (
|
<Flex
|
||||||
<Image src={coverImage?.thumbnail_url} draggable={false} />
|
|
||||||
)}
|
|
||||||
{!(board.cover_image_name && coverImage?.thumbnail_url) && (
|
|
||||||
<IAINoContentFallback
|
|
||||||
boxSize={8}
|
|
||||||
icon={FaUser}
|
|
||||||
sx={{
|
sx={{
|
||||||
borderWidth: '2px',
|
w: 'full',
|
||||||
borderStyle: 'solid',
|
h: 'full',
|
||||||
borderColor: 'base.200',
|
justifyContent: 'center',
|
||||||
|
alignItems: 'center',
|
||||||
|
borderRadius: 'base',
|
||||||
|
bg: 'base.200',
|
||||||
_dark: {
|
_dark: {
|
||||||
borderColor: 'base.800',
|
bg: 'base.800',
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{coverImage?.thumbnail_url ? (
|
||||||
|
<Image
|
||||||
|
src={coverImage?.thumbnail_url}
|
||||||
|
draggable={false}
|
||||||
|
sx={{
|
||||||
|
maxW: 'full',
|
||||||
|
maxH: 'full',
|
||||||
|
borderRadius: 'base',
|
||||||
|
borderBottomRadius: 'lg',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
justifyContent: 'center',
|
||||||
|
alignItems: 'center',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Icon
|
||||||
|
boxSize={12}
|
||||||
|
as={FaFolder}
|
||||||
|
sx={{
|
||||||
|
mt: -3,
|
||||||
|
opacity: 0.7,
|
||||||
|
color: 'base.500',
|
||||||
|
_dark: {
|
||||||
|
color: 'base.500',
|
||||||
},
|
},
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
|
</Flex>
|
||||||
)}
|
)}
|
||||||
|
</Flex>
|
||||||
<Flex
|
<Flex
|
||||||
sx={{
|
sx={{
|
||||||
position: 'absolute',
|
position: 'absolute',
|
||||||
@ -160,37 +212,59 @@ const GalleryBoard = memo(
|
|||||||
{board.image_count}
|
{board.image_count}
|
||||||
</Badge>
|
</Badge>
|
||||||
</Flex>
|
</Flex>
|
||||||
<IAIDroppable
|
<Box
|
||||||
data={droppableData}
|
className="selection-box"
|
||||||
dropLabel={<Text fontSize="md">Move</Text>}
|
sx={{
|
||||||
|
position: 'absolute',
|
||||||
|
top: 0,
|
||||||
|
insetInlineEnd: 0,
|
||||||
|
bottom: 0,
|
||||||
|
insetInlineStart: 0,
|
||||||
|
borderRadius: 'base',
|
||||||
|
transitionProperty: 'common',
|
||||||
|
transitionDuration: 'common',
|
||||||
|
shadow: isSelected ? 'selected.light' : undefined,
|
||||||
|
_dark: {
|
||||||
|
shadow: isSelected ? 'selected.dark' : undefined,
|
||||||
|
},
|
||||||
|
}}
|
||||||
/>
|
/>
|
||||||
</Flex>
|
|
||||||
|
|
||||||
<Flex
|
<Flex
|
||||||
sx={{
|
sx={{
|
||||||
width: 'full',
|
position: 'absolute',
|
||||||
height: 'full',
|
bottom: 0,
|
||||||
|
left: 0,
|
||||||
|
p: 1,
|
||||||
justifyContent: 'center',
|
justifyContent: 'center',
|
||||||
alignItems: 'center',
|
alignItems: 'center',
|
||||||
|
w: 'full',
|
||||||
|
maxW: 'full',
|
||||||
|
borderBottomRadius: 'base',
|
||||||
|
bg: isSelected ? 'accent.400' : 'base.500',
|
||||||
|
color: isSelected ? 'base.50' : 'base.100',
|
||||||
|
_dark: {
|
||||||
|
bg: isSelected ? 'accent.500' : 'base.600',
|
||||||
|
color: isSelected ? 'base.50' : 'base.100',
|
||||||
|
},
|
||||||
|
lineHeight: 'short',
|
||||||
|
fontSize: 'xs',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Editable
|
<Editable
|
||||||
defaultValue={board_name}
|
value={localBoardName}
|
||||||
|
isDisabled={isUpdateBoardLoading}
|
||||||
submitOnBlur={true}
|
submitOnBlur={true}
|
||||||
onSubmit={(nextValue) => {
|
onChange={handleChange}
|
||||||
handleUpdateBoardName(nextValue);
|
onSubmit={handleSubmit}
|
||||||
|
sx={{
|
||||||
|
w: 'full',
|
||||||
}}
|
}}
|
||||||
sx={{ maxW: 'full' }}
|
|
||||||
>
|
>
|
||||||
<EditablePreview
|
<EditablePreview
|
||||||
sx={{
|
sx={{
|
||||||
color: isSelected
|
|
||||||
? mode('base.900', 'base.50')(colorMode)
|
|
||||||
: mode('base.700', 'base.200')(colorMode),
|
|
||||||
fontWeight: isSelected ? 600 : undefined,
|
|
||||||
fontSize: 'xs',
|
|
||||||
textAlign: 'center',
|
|
||||||
p: 0,
|
p: 0,
|
||||||
|
fontWeight: isSelected ? 700 : 500,
|
||||||
|
textAlign: 'center',
|
||||||
overflow: 'hidden',
|
overflow: 'hidden',
|
||||||
textOverflow: 'ellipsis',
|
textOverflow: 'ellipsis',
|
||||||
}}
|
}}
|
||||||
@ -198,18 +272,26 @@ const GalleryBoard = memo(
|
|||||||
/>
|
/>
|
||||||
<EditableInput
|
<EditableInput
|
||||||
sx={{
|
sx={{
|
||||||
color: mode('base.900', 'base.50')(colorMode),
|
|
||||||
fontSize: 'xs',
|
|
||||||
borderColor: mode('base.500', 'base.500')(colorMode),
|
|
||||||
p: 0,
|
p: 0,
|
||||||
outline: 0,
|
_focusVisible: {
|
||||||
|
p: 0,
|
||||||
|
textAlign: 'center',
|
||||||
|
// get rid of the edit border
|
||||||
|
boxShadow: 'none',
|
||||||
|
},
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
</Editable>
|
</Editable>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|
||||||
|
<IAIDroppable
|
||||||
|
data={droppableData}
|
||||||
|
dropLabel={<Text fontSize="md">Move</Text>}
|
||||||
|
/>
|
||||||
</Flex>
|
</Flex>
|
||||||
)}
|
)}
|
||||||
</BoardContextMenu>
|
</BoardContextMenu>
|
||||||
|
</Flex>
|
||||||
</Box>
|
</Box>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -17,7 +17,7 @@ type GenericBoardProps = {
|
|||||||
badgeCount?: number;
|
badgeCount?: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
const formatBadgeCount = (count: number) =>
|
export const formatBadgeCount = (count: number) =>
|
||||||
Intl.NumberFormat('en-US', {
|
Intl.NumberFormat('en-US', {
|
||||||
notation: 'compact',
|
notation: 'compact',
|
||||||
maximumFractionDigits: 1,
|
maximumFractionDigits: 1,
|
||||||
@ -92,7 +92,7 @@ const GenericBoard = (props: GenericBoardProps) => {
|
|||||||
h: 'full',
|
h: 'full',
|
||||||
alignItems: 'center',
|
alignItems: 'center',
|
||||||
fontWeight: isSelected ? 600 : undefined,
|
fontWeight: isSelected ? 600 : undefined,
|
||||||
fontSize: 'xs',
|
fontSize: 'sm',
|
||||||
color: isSelected ? 'base.900' : 'base.700',
|
color: isSelected ? 'base.900' : 'base.700',
|
||||||
_dark: { color: isSelected ? 'base.50' : 'base.200' },
|
_dark: { color: isSelected ? 'base.50' : 'base.200' },
|
||||||
}}
|
}}
|
||||||
|
@ -0,0 +1,53 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
import { boardIdSelected } from 'features/gallery/store/gallerySlice';
|
||||||
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
|
import { useBoardName } from 'services/api/hooks/useBoardName';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
board_id: 'images' | 'assets' | 'no_board';
|
||||||
|
};
|
||||||
|
|
||||||
|
const SystemBoardButton = ({ board_id }: Props) => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const selector = useMemo(
|
||||||
|
() =>
|
||||||
|
createSelector(
|
||||||
|
[stateSelector],
|
||||||
|
({ gallery }) => {
|
||||||
|
const { selectedBoardId } = gallery;
|
||||||
|
return { isSelected: selectedBoardId === board_id };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
),
|
||||||
|
[board_id]
|
||||||
|
);
|
||||||
|
|
||||||
|
const { isSelected } = useAppSelector(selector);
|
||||||
|
|
||||||
|
const boardName = useBoardName(board_id);
|
||||||
|
|
||||||
|
const handleClick = useCallback(() => {
|
||||||
|
dispatch(boardIdSelected(board_id));
|
||||||
|
}, [board_id, dispatch]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAIButton
|
||||||
|
onClick={handleClick}
|
||||||
|
size="sm"
|
||||||
|
isChecked={isSelected}
|
||||||
|
sx={{
|
||||||
|
flexGrow: 1,
|
||||||
|
borderRadius: 'base',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{boardName}
|
||||||
|
</IAIButton>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(SystemBoardButton);
|
@ -4,8 +4,9 @@ import { createSelector } from '@reduxjs/toolkit';
|
|||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import { memo } from 'react';
|
import { memo, useMemo } from 'react';
|
||||||
import { useBoardName } from 'services/api/hooks/useBoardName';
|
import { useBoardName } from 'services/api/hooks/useBoardName';
|
||||||
|
import { useBoardTotal } from 'services/api/hooks/useBoardTotal';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector],
|
[stateSelector],
|
||||||
@ -26,6 +27,17 @@ const GalleryBoardName = (props: Props) => {
|
|||||||
const { isOpen, onToggle } = props;
|
const { isOpen, onToggle } = props;
|
||||||
const { selectedBoardId } = useAppSelector(selector);
|
const { selectedBoardId } = useAppSelector(selector);
|
||||||
const boardName = useBoardName(selectedBoardId);
|
const boardName = useBoardName(selectedBoardId);
|
||||||
|
const numOfBoardImages = useBoardTotal(selectedBoardId);
|
||||||
|
|
||||||
|
const formattedBoardName = useMemo(() => {
|
||||||
|
if (!boardName || !numOfBoardImages) {
|
||||||
|
return '';
|
||||||
|
}
|
||||||
|
if (boardName.length > 20) {
|
||||||
|
return `${boardName.substring(0, 20)}... (${numOfBoardImages})`;
|
||||||
|
}
|
||||||
|
return `${boardName} (${numOfBoardImages})`;
|
||||||
|
}, [boardName, numOfBoardImages]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
@ -58,7 +70,7 @@ const GalleryBoardName = (props: Props) => {
|
|||||||
},
|
},
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{boardName}
|
{formattedBoardName}
|
||||||
</Text>
|
</Text>
|
||||||
</Box>
|
</Box>
|
||||||
<Spacer />
|
<Spacer />
|
||||||
|
@ -109,7 +109,7 @@ const GalleryDrawer = () => {
|
|||||||
isResizable={true}
|
isResizable={true}
|
||||||
isOpen={shouldShowGallery}
|
isOpen={shouldShowGallery}
|
||||||
onClose={handleCloseGallery}
|
onClose={handleCloseGallery}
|
||||||
minWidth={337}
|
minWidth={400}
|
||||||
>
|
>
|
||||||
<ImageGalleryContent />
|
<ImageGalleryContent />
|
||||||
</ResizableDrawer>
|
</ResizableDrawer>
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { Box } from '@chakra-ui/react';
|
import { Box, Flex } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
|
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
@ -86,15 +86,10 @@ const GalleryImage = (props: HoverableImageProps) => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<Box sx={{ w: 'full', h: 'full', touchAction: 'none' }}>
|
<Box sx={{ w: 'full', h: 'full', touchAction: 'none' }}>
|
||||||
<ImageContextMenu imageDTO={imageDTO}>
|
<Flex
|
||||||
{(ref) => (
|
|
||||||
<Box
|
|
||||||
position="relative"
|
|
||||||
key={imageName}
|
|
||||||
userSelect="none"
|
userSelect="none"
|
||||||
ref={ref}
|
|
||||||
sx={{
|
sx={{
|
||||||
display: 'flex',
|
position: 'relative',
|
||||||
justifyContent: 'center',
|
justifyContent: 'center',
|
||||||
alignItems: 'center',
|
alignItems: 'center',
|
||||||
aspectRatio: '1/1',
|
aspectRatio: '1/1',
|
||||||
@ -115,9 +110,7 @@ const GalleryImage = (props: HoverableImageProps) => {
|
|||||||
// resetTooltip="Delete image"
|
// resetTooltip="Delete image"
|
||||||
// withResetIcon // removed bc it's too easy to accidentally delete images
|
// withResetIcon // removed bc it's too easy to accidentally delete images
|
||||||
/>
|
/>
|
||||||
</Box>
|
</Flex>
|
||||||
)}
|
|
||||||
</ImageContextMenu>
|
|
||||||
</Box>
|
</Box>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -48,6 +48,7 @@ export const addControlNetToLinearGraph = (
|
|||||||
beginStepPct,
|
beginStepPct,
|
||||||
endStepPct,
|
endStepPct,
|
||||||
controlMode,
|
controlMode,
|
||||||
|
resizeMode,
|
||||||
model,
|
model,
|
||||||
processorType,
|
processorType,
|
||||||
weight,
|
weight,
|
||||||
@ -60,6 +61,7 @@ export const addControlNetToLinearGraph = (
|
|||||||
begin_step_percent: beginStepPct,
|
begin_step_percent: beginStepPct,
|
||||||
end_step_percent: endStepPct,
|
end_step_percent: endStepPct,
|
||||||
control_mode: controlMode,
|
control_mode: controlMode,
|
||||||
|
resize_mode: resizeMode,
|
||||||
control_model: model as ControlNetInvocation['control_model'],
|
control_model: model as ControlNetInvocation['control_model'],
|
||||||
control_weight: weight,
|
control_weight: weight,
|
||||||
};
|
};
|
||||||
|
@ -105,7 +105,7 @@ const enabledTabsSelector = createSelector(
|
|||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
const MIN_GALLERY_WIDTH = 300;
|
const MIN_GALLERY_WIDTH = 350;
|
||||||
const DEFAULT_GALLERY_PCT = 20;
|
const DEFAULT_GALLERY_PCT = 20;
|
||||||
export const NO_GALLERY_TABS: InvokeTabName[] = ['modelManager'];
|
export const NO_GALLERY_TABS: InvokeTabName[] = ['modelManager'];
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ import { memo } from 'react';
|
|||||||
import { PanelResizeHandle } from 'react-resizable-panels';
|
import { PanelResizeHandle } from 'react-resizable-panels';
|
||||||
import { mode } from 'theme/util/mode';
|
import { mode } from 'theme/util/mode';
|
||||||
|
|
||||||
type ResizeHandleProps = FlexProps & {
|
type ResizeHandleProps = Omit<FlexProps, 'direction'> & {
|
||||||
direction?: 'horizontal' | 'vertical';
|
direction?: 'horizontal' | 'vertical';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -169,10 +169,11 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
],
|
],
|
||||||
async onQueryStarted(imageDTO, { dispatch, queryFulfilled }) {
|
async onQueryStarted(imageDTO, { dispatch, queryFulfilled }) {
|
||||||
/**
|
/**
|
||||||
* Cache changes for deleteImage:
|
* Cache changes for `deleteImage`:
|
||||||
* - Remove from "All Images"
|
* - *remove* from "All Images" / "All Assets"
|
||||||
* - Remove from image's `board_id` if it has one, or "No Board" if not
|
* - IF it has a board:
|
||||||
* - Remove from "Batch"
|
* - THEN *remove* from it's own board
|
||||||
|
* - ELSE *remove* from "No Board"
|
||||||
*/
|
*/
|
||||||
|
|
||||||
const { image_name, board_id, image_category } = imageDTO;
|
const { image_name, board_id, image_category } = imageDTO;
|
||||||
@ -181,22 +182,23 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
// That means constructing the possible query args that are serialized into the cache key...
|
// That means constructing the possible query args that are serialized into the cache key...
|
||||||
|
|
||||||
const removeFromCacheKeys: ListImagesArgs[] = [];
|
const removeFromCacheKeys: ListImagesArgs[] = [];
|
||||||
|
|
||||||
|
// determine `categories`, i.e. do we update "All Images" or "All Assets"
|
||||||
const categories = IMAGE_CATEGORIES.includes(image_category)
|
const categories = IMAGE_CATEGORIES.includes(image_category)
|
||||||
? IMAGE_CATEGORIES
|
? IMAGE_CATEGORIES
|
||||||
: ASSETS_CATEGORIES;
|
: ASSETS_CATEGORIES;
|
||||||
|
|
||||||
// All Images board (e.g. no board)
|
// remove from "All Images"
|
||||||
removeFromCacheKeys.push({ categories });
|
removeFromCacheKeys.push({ categories });
|
||||||
|
|
||||||
// Board specific
|
|
||||||
if (board_id) {
|
if (board_id) {
|
||||||
|
// remove from it's own board
|
||||||
removeFromCacheKeys.push({ board_id });
|
removeFromCacheKeys.push({ board_id });
|
||||||
} else {
|
} else {
|
||||||
// TODO: No Board
|
// remove from "No Board"
|
||||||
|
removeFromCacheKeys.push({ board_id: 'none' });
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Batch
|
|
||||||
|
|
||||||
const patches: PatchCollection[] = [];
|
const patches: PatchCollection[] = [];
|
||||||
removeFromCacheKeys.forEach((cacheKey) => {
|
removeFromCacheKeys.forEach((cacheKey) => {
|
||||||
patches.push(
|
patches.push(
|
||||||
@ -240,32 +242,37 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
{ imageDTO: oldImageDTO, changes: _changes },
|
{ imageDTO: oldImageDTO, changes: _changes },
|
||||||
{ dispatch, queryFulfilled, getState }
|
{ dispatch, queryFulfilled, getState }
|
||||||
) {
|
) {
|
||||||
// TODO: Should we handle changes to boards via this mutation? Seems reasonable...
|
|
||||||
|
|
||||||
// let's be extra-sure we do not accidentally change categories
|
// let's be extra-sure we do not accidentally change categories
|
||||||
const changes = omit(_changes, 'image_category');
|
const changes = omit(_changes, 'image_category');
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Cache changes for `updateImage`:
|
* Cache changes for "updateImage":
|
||||||
* - Update the ImageDTO
|
* - *update* "getImageDTO" cache
|
||||||
* - Update the image in "All Images" board:
|
* - for "All Images" || "All Assets":
|
||||||
* - IF it is in the date range represented by the cache:
|
* - IF it is not already in the cache
|
||||||
* - add the image IF it is not already in the cache & update the total
|
* - THEN *add* it to "All Images" / "All Assets" and update the total
|
||||||
* - ELSE update the image IF it is already in the cache
|
* - ELSE *update* it
|
||||||
* - IF the image has a board:
|
* - IF the image has a board:
|
||||||
* - Update the image in it's own board
|
* - THEN *update* it's own board
|
||||||
* - ELSE Update the image in the "No Board" board (TODO)
|
* - ELSE *update* the "No Board" board
|
||||||
*/
|
*/
|
||||||
|
|
||||||
const patches: PatchCollection[] = [];
|
const patches: PatchCollection[] = [];
|
||||||
const { image_name, board_id, image_category } = oldImageDTO;
|
const { image_name, board_id, image_category, is_intermediate } =
|
||||||
|
oldImageDTO;
|
||||||
|
|
||||||
|
const isChangingFromIntermediate = changes.is_intermediate === false;
|
||||||
|
// do not add intermediates to gallery cache
|
||||||
|
if (is_intermediate && !isChangingFromIntermediate) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// determine `categories`, i.e. do we update "All Images" or "All Assets"
|
||||||
const categories = IMAGE_CATEGORIES.includes(image_category)
|
const categories = IMAGE_CATEGORIES.includes(image_category)
|
||||||
? IMAGE_CATEGORIES
|
? IMAGE_CATEGORIES
|
||||||
: ASSETS_CATEGORIES;
|
: ASSETS_CATEGORIES;
|
||||||
|
|
||||||
// TODO: No Board
|
// update `getImageDTO` cache
|
||||||
|
|
||||||
// Update `getImageDTO` cache
|
|
||||||
patches.push(
|
patches.push(
|
||||||
dispatch(
|
dispatch(
|
||||||
imagesApi.util.updateQueryData(
|
imagesApi.util.updateQueryData(
|
||||||
@ -281,9 +288,13 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
// Update the "All Image" or "All Assets" board
|
// Update the "All Image" or "All Assets" board
|
||||||
const queryArgsToUpdate: ListImagesArgs[] = [{ categories }];
|
const queryArgsToUpdate: ListImagesArgs[] = [{ categories }];
|
||||||
|
|
||||||
|
// IF the image has a board:
|
||||||
if (board_id) {
|
if (board_id) {
|
||||||
// We also need to update the user board
|
// THEN update it's own board
|
||||||
queryArgsToUpdate.push({ board_id });
|
queryArgsToUpdate.push({ board_id });
|
||||||
|
} else {
|
||||||
|
// ELSE update the "No Board" board
|
||||||
|
queryArgsToUpdate.push({ board_id: 'none' });
|
||||||
}
|
}
|
||||||
|
|
||||||
queryArgsToUpdate.forEach((queryArg) => {
|
queryArgsToUpdate.forEach((queryArg) => {
|
||||||
@ -371,12 +382,12 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the image to the "All Images" / "All Assets" board
|
// determine `categories`, i.e. do we update "All Images" or "All Assets"
|
||||||
const queryArg = {
|
const categories = IMAGE_CATEGORIES.includes(image_category)
|
||||||
categories: IMAGE_CATEGORIES.includes(image_category)
|
|
||||||
? IMAGE_CATEGORIES
|
? IMAGE_CATEGORIES
|
||||||
: ASSETS_CATEGORIES,
|
: ASSETS_CATEGORIES;
|
||||||
};
|
|
||||||
|
const queryArg = { categories };
|
||||||
|
|
||||||
dispatch(
|
dispatch(
|
||||||
imagesApi.util.updateQueryData('listImages', queryArg, (draft) => {
|
imagesApi.util.updateQueryData('listImages', queryArg, (draft) => {
|
||||||
@ -410,16 +421,14 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
{ dispatch, queryFulfilled, getState }
|
{ dispatch, queryFulfilled, getState }
|
||||||
) {
|
) {
|
||||||
/**
|
/**
|
||||||
* Cache changes for addImageToBoard:
|
* Cache changes for `addImageToBoard`:
|
||||||
* - Remove from "No Board"
|
* - *update* the `getImageDTO` cache
|
||||||
* - Remove from `old_board_id` if it has one
|
* - *remove* from "No Board"
|
||||||
* - Add to new `board_id`
|
* - IF the image has an old `board_id`:
|
||||||
|
* - THEN *remove* from it's old `board_id`
|
||||||
* - IF the image's `created_at` is within the range of the board's cached images
|
* - IF the image's `created_at` is within the range of the board's cached images
|
||||||
* - OR the board cache has length of 0 or 1
|
* - OR the board cache has length of 0 or 1
|
||||||
* - Update the `total` for each board whose cache is updated
|
* - THEN *add* it to new `board_id`
|
||||||
* - Update the ImageDTO
|
|
||||||
*
|
|
||||||
* TODO: maybe total should just be updated in the boards endpoints?
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
const { image_name, board_id: old_board_id } = oldImageDTO;
|
const { image_name, board_id: old_board_id } = oldImageDTO;
|
||||||
@ -427,13 +436,10 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
// Figure out the `listImages` caches that we need to update
|
// Figure out the `listImages` caches that we need to update
|
||||||
const removeFromQueryArgs: ListImagesArgs[] = [];
|
const removeFromQueryArgs: ListImagesArgs[] = [];
|
||||||
|
|
||||||
// TODO: No Board
|
// remove from "No Board"
|
||||||
// TODO: Batch
|
|
||||||
|
|
||||||
// Remove from No Board
|
|
||||||
removeFromQueryArgs.push({ board_id: 'none' });
|
removeFromQueryArgs.push({ board_id: 'none' });
|
||||||
|
|
||||||
// Remove from old board
|
// remove from old board
|
||||||
if (old_board_id) {
|
if (old_board_id) {
|
||||||
removeFromQueryArgs.push({ board_id: old_board_id });
|
removeFromQueryArgs.push({ board_id: old_board_id });
|
||||||
}
|
}
|
||||||
@ -534,17 +540,15 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
{ dispatch, queryFulfilled, getState }
|
{ dispatch, queryFulfilled, getState }
|
||||||
) {
|
) {
|
||||||
/**
|
/**
|
||||||
* Cache changes for removeImageFromBoard:
|
* Cache changes for `removeImageFromBoard`:
|
||||||
* - Add to "No Board"
|
* - *update* `getImageDTO`
|
||||||
* - IF the image's `created_at` is within the range of the board's cached images
|
* - IF the image's `created_at` is within the range of the board's cached images
|
||||||
* - Remove from `old_board_id`
|
* - THEN *add* to "No Board"
|
||||||
* - Update the ImageDTO
|
* - *remove* from `old_board_id`
|
||||||
*/
|
*/
|
||||||
|
|
||||||
const { image_name, board_id: old_board_id } = imageDTO;
|
const { image_name, board_id: old_board_id } = imageDTO;
|
||||||
|
|
||||||
// TODO: Batch
|
|
||||||
|
|
||||||
const patches: PatchCollection[] = [];
|
const patches: PatchCollection[] = [];
|
||||||
|
|
||||||
// Updated imageDTO with new board_id
|
// Updated imageDTO with new board_id
|
||||||
|
@ -6,9 +6,9 @@ export const useBoardName = (board_id: BoardId | null | undefined) => {
|
|||||||
selectFromResult: ({ data }) => {
|
selectFromResult: ({ data }) => {
|
||||||
let boardName = '';
|
let boardName = '';
|
||||||
if (board_id === 'images') {
|
if (board_id === 'images') {
|
||||||
boardName = 'All Images';
|
boardName = 'Images';
|
||||||
} else if (board_id === 'assets') {
|
} else if (board_id === 'assets') {
|
||||||
boardName = 'All Assets';
|
boardName = 'Assets';
|
||||||
} else if (board_id === 'no_board') {
|
} else if (board_id === 'no_board') {
|
||||||
boardName = 'No Board';
|
boardName = 'No Board';
|
||||||
} else if (board_id === 'batch') {
|
} else if (board_id === 'batch') {
|
||||||
|
@ -0,0 +1,53 @@
|
|||||||
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
|
import {
|
||||||
|
ASSETS_CATEGORIES,
|
||||||
|
BoardId,
|
||||||
|
IMAGE_CATEGORIES,
|
||||||
|
INITIAL_IMAGE_LIMIT,
|
||||||
|
} from 'features/gallery/store/gallerySlice';
|
||||||
|
import { useMemo } from 'react';
|
||||||
|
import { ListImagesArgs, useListImagesQuery } from '../endpoints/images';
|
||||||
|
|
||||||
|
const baseQueryArg: ListImagesArgs = {
|
||||||
|
offset: 0,
|
||||||
|
limit: INITIAL_IMAGE_LIMIT,
|
||||||
|
is_intermediate: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
const imagesQueryArg: ListImagesArgs = {
|
||||||
|
categories: IMAGE_CATEGORIES,
|
||||||
|
...baseQueryArg,
|
||||||
|
};
|
||||||
|
|
||||||
|
const assetsQueryArg: ListImagesArgs = {
|
||||||
|
categories: ASSETS_CATEGORIES,
|
||||||
|
...baseQueryArg,
|
||||||
|
};
|
||||||
|
|
||||||
|
const noBoardQueryArg: ListImagesArgs = {
|
||||||
|
board_id: 'none',
|
||||||
|
...baseQueryArg,
|
||||||
|
};
|
||||||
|
|
||||||
|
export const useBoardTotal = (board_id: BoardId | null | undefined) => {
|
||||||
|
const queryArg = useMemo(() => {
|
||||||
|
if (!board_id) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (board_id === 'images') {
|
||||||
|
return imagesQueryArg;
|
||||||
|
} else if (board_id === 'assets') {
|
||||||
|
return assetsQueryArg;
|
||||||
|
} else if (board_id === 'no_board') {
|
||||||
|
return noBoardQueryArg;
|
||||||
|
} else {
|
||||||
|
return { board_id, ...baseQueryArg };
|
||||||
|
}
|
||||||
|
}, [board_id]);
|
||||||
|
|
||||||
|
const { total } = useListImagesQuery(queryArg ?? skipToken, {
|
||||||
|
selectFromResult: ({ currentData }) => ({ total: currentData?.total }),
|
||||||
|
});
|
||||||
|
|
||||||
|
return total;
|
||||||
|
};
|
@ -167,7 +167,7 @@ export type paths = {
|
|||||||
"/api/v1/images/clear-intermediates": {
|
"/api/v1/images/clear-intermediates": {
|
||||||
/**
|
/**
|
||||||
* Clear Intermediates
|
* Clear Intermediates
|
||||||
* @description Clears first 100 intermediates
|
* @description Clears all intermediates
|
||||||
*/
|
*/
|
||||||
post: operations["clear_intermediates"];
|
post: operations["clear_intermediates"];
|
||||||
};
|
};
|
||||||
@ -800,6 +800,13 @@ export type components = {
|
|||||||
* @enum {string}
|
* @enum {string}
|
||||||
*/
|
*/
|
||||||
control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced";
|
control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced";
|
||||||
|
/**
|
||||||
|
* Resize Mode
|
||||||
|
* @description The resize mode to use
|
||||||
|
* @default just_resize
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
resize_mode?: "just_resize" | "crop_resize" | "fill_resize" | "just_resize_simple";
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
* ControlNetInvocation
|
* ControlNetInvocation
|
||||||
@ -859,6 +866,13 @@ export type components = {
|
|||||||
* @enum {string}
|
* @enum {string}
|
||||||
*/
|
*/
|
||||||
control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced";
|
control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced";
|
||||||
|
/**
|
||||||
|
* Resize Mode
|
||||||
|
* @description The resize mode used
|
||||||
|
* @default just_resize
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
resize_mode?: "just_resize" | "crop_resize" | "fill_resize" | "just_resize_simple";
|
||||||
};
|
};
|
||||||
/** ControlNetModelConfig */
|
/** ControlNetModelConfig */
|
||||||
ControlNetModelConfig: {
|
ControlNetModelConfig: {
|
||||||
@ -5324,11 +5338,11 @@ export type components = {
|
|||||||
image?: components["schemas"]["ImageField"];
|
image?: components["schemas"]["ImageField"];
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
* StableDiffusion2ModelFormat
|
* StableDiffusion1ModelFormat
|
||||||
* @description An enumeration.
|
* @description An enumeration.
|
||||||
* @enum {string}
|
* @enum {string}
|
||||||
*/
|
*/
|
||||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||||
/**
|
/**
|
||||||
* StableDiffusionXLModelFormat
|
* StableDiffusionXLModelFormat
|
||||||
* @description An enumeration.
|
* @description An enumeration.
|
||||||
@ -5336,11 +5350,11 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
|
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
|
||||||
/**
|
/**
|
||||||
* StableDiffusion1ModelFormat
|
* StableDiffusion2ModelFormat
|
||||||
* @description An enumeration.
|
* @description An enumeration.
|
||||||
* @enum {string}
|
* @enum {string}
|
||||||
*/
|
*/
|
||||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||||
};
|
};
|
||||||
responses: never;
|
responses: never;
|
||||||
parameters: never;
|
parameters: never;
|
||||||
@ -6125,7 +6139,7 @@ export type operations = {
|
|||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
* Clear Intermediates
|
* Clear Intermediates
|
||||||
* @description Clears first 100 intermediates
|
* @description Clears all intermediates
|
||||||
*/
|
*/
|
||||||
clear_intermediates: {
|
clear_intermediates: {
|
||||||
responses: {
|
responses: {
|
||||||
|
@ -2,11 +2,16 @@ import { InvokeAIThemeColors } from 'theme/themeTypes';
|
|||||||
import { generateColorPalette } from 'theme/util/generateColorPalette';
|
import { generateColorPalette } from 'theme/util/generateColorPalette';
|
||||||
|
|
||||||
const BASE = { H: 220, S: 16 };
|
const BASE = { H: 220, S: 16 };
|
||||||
const ACCENT = { H: 250, S: 52 };
|
const ACCENT = { H: 250, S: 42 };
|
||||||
const WORKING = { H: 47, S: 50 };
|
// const ACCENT = { H: 250, S: 52 };
|
||||||
const WARNING = { H: 28, S: 50 };
|
const WORKING = { H: 47, S: 42 };
|
||||||
const OK = { H: 113, S: 50 };
|
// const WORKING = { H: 47, S: 50 };
|
||||||
const ERROR = { H: 0, S: 50 };
|
const WARNING = { H: 28, S: 42 };
|
||||||
|
// const WARNING = { H: 28, S: 50 };
|
||||||
|
const OK = { H: 113, S: 42 };
|
||||||
|
// const OK = { H: 113, S: 50 };
|
||||||
|
const ERROR = { H: 0, S: 42 };
|
||||||
|
// const ERROR = { H: 0, S: 50 };
|
||||||
|
|
||||||
export const InvokeAIColors: InvokeAIThemeColors = {
|
export const InvokeAIColors: InvokeAIThemeColors = {
|
||||||
base: generateColorPalette(BASE.H, BASE.S),
|
base: generateColorPalette(BASE.H, BASE.S),
|
||||||
|
56
invokeai/frontend/web/src/theme/components/editable.ts
Normal file
56
invokeai/frontend/web/src/theme/components/editable.ts
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
import { editableAnatomy as parts } from '@chakra-ui/anatomy';
|
||||||
|
import {
|
||||||
|
createMultiStyleConfigHelpers,
|
||||||
|
defineStyle,
|
||||||
|
} from '@chakra-ui/styled-system';
|
||||||
|
import { mode } from '@chakra-ui/theme-tools';
|
||||||
|
|
||||||
|
const { definePartsStyle, defineMultiStyleConfig } =
|
||||||
|
createMultiStyleConfigHelpers(parts.keys);
|
||||||
|
|
||||||
|
const baseStylePreview = defineStyle({
|
||||||
|
borderRadius: 'md',
|
||||||
|
py: '1',
|
||||||
|
transitionProperty: 'common',
|
||||||
|
transitionDuration: 'normal',
|
||||||
|
});
|
||||||
|
|
||||||
|
const baseStyleInput = defineStyle((props) => ({
|
||||||
|
borderRadius: 'md',
|
||||||
|
py: '1',
|
||||||
|
transitionProperty: 'common',
|
||||||
|
transitionDuration: 'normal',
|
||||||
|
width: 'full',
|
||||||
|
_focusVisible: { boxShadow: 'outline' },
|
||||||
|
_placeholder: { opacity: 0.6 },
|
||||||
|
'::selection': {
|
||||||
|
color: mode('accent.900', 'accent.50')(props),
|
||||||
|
bg: mode('accent.200', 'accent.400')(props),
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
|
const baseStyleTextarea = defineStyle({
|
||||||
|
borderRadius: 'md',
|
||||||
|
py: '1',
|
||||||
|
transitionProperty: 'common',
|
||||||
|
transitionDuration: 'normal',
|
||||||
|
width: 'full',
|
||||||
|
_focusVisible: { boxShadow: 'outline' },
|
||||||
|
_placeholder: { opacity: 0.6 },
|
||||||
|
});
|
||||||
|
|
||||||
|
const invokeAI = definePartsStyle((props) => ({
|
||||||
|
preview: baseStylePreview,
|
||||||
|
input: baseStyleInput(props),
|
||||||
|
textarea: baseStyleTextarea,
|
||||||
|
}));
|
||||||
|
|
||||||
|
export const editableTheme = defineMultiStyleConfig({
|
||||||
|
variants: {
|
||||||
|
invokeAI,
|
||||||
|
},
|
||||||
|
defaultProps: {
|
||||||
|
size: 'sm',
|
||||||
|
variant: 'invokeAI',
|
||||||
|
},
|
||||||
|
});
|
@ -4,6 +4,7 @@ import { InvokeAIColors } from './colors/colors';
|
|||||||
import { accordionTheme } from './components/accordion';
|
import { accordionTheme } from './components/accordion';
|
||||||
import { buttonTheme } from './components/button';
|
import { buttonTheme } from './components/button';
|
||||||
import { checkboxTheme } from './components/checkbox';
|
import { checkboxTheme } from './components/checkbox';
|
||||||
|
import { editableTheme } from './components/editable';
|
||||||
import { formLabelTheme } from './components/formLabel';
|
import { formLabelTheme } from './components/formLabel';
|
||||||
import { inputTheme } from './components/input';
|
import { inputTheme } from './components/input';
|
||||||
import { menuTheme } from './components/menu';
|
import { menuTheme } from './components/menu';
|
||||||
@ -72,7 +73,17 @@ export const theme: ThemeOverride = {
|
|||||||
selected: {
|
selected: {
|
||||||
light:
|
light:
|
||||||
'0px 0px 0px 1px var(--invokeai-colors-base-150), 0px 0px 0px 4px var(--invokeai-colors-accent-400)',
|
'0px 0px 0px 1px var(--invokeai-colors-base-150), 0px 0px 0px 4px var(--invokeai-colors-accent-400)',
|
||||||
dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-400)',
|
dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-500)',
|
||||||
|
},
|
||||||
|
hoverSelected: {
|
||||||
|
light:
|
||||||
|
'0px 0px 0px 1px var(--invokeai-colors-base-150), 0px 0px 0px 4px var(--invokeai-colors-accent-500)',
|
||||||
|
dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-300)',
|
||||||
|
},
|
||||||
|
hoverUnselected: {
|
||||||
|
light:
|
||||||
|
'0px 0px 0px 1px var(--invokeai-colors-base-150), 0px 0px 0px 4px var(--invokeai-colors-accent-200)',
|
||||||
|
dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-600)',
|
||||||
},
|
},
|
||||||
nodeSelectedOutline: `0 0 0 2px var(--invokeai-colors-accent-450)`,
|
nodeSelectedOutline: `0 0 0 2px var(--invokeai-colors-accent-450)`,
|
||||||
},
|
},
|
||||||
@ -80,6 +91,7 @@ export const theme: ThemeOverride = {
|
|||||||
components: {
|
components: {
|
||||||
Button: buttonTheme, // Button and IconButton
|
Button: buttonTheme, // Button and IconButton
|
||||||
Input: inputTheme,
|
Input: inputTheme,
|
||||||
|
Editable: editableTheme,
|
||||||
Textarea: textareaTheme,
|
Textarea: textareaTheme,
|
||||||
Tabs: tabsTheme,
|
Tabs: tabsTheme,
|
||||||
Progress: progressTheme,
|
Progress: progressTheme,
|
||||||
|
@ -37,4 +37,7 @@ export const getInputOutlineStyles = (props: StyleFunctionProps) => ({
|
|||||||
_placeholder: {
|
_placeholder: {
|
||||||
color: mode('base.700', 'base.400')(props),
|
color: mode('base.700', 'base.400')(props),
|
||||||
},
|
},
|
||||||
|
'::selection': {
|
||||||
|
bg: mode('accent.200', 'accent.400')(props),
|
||||||
|
},
|
||||||
});
|
});
|
||||||
|
Loading…
Reference in New Issue
Block a user