mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into lstein/logger-route
This commit is contained in:
commit
5134de7cfa
@ -85,8 +85,8 @@ CONTROLNET_DEFAULT_MODELS = [
|
||||
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
|
||||
CONTROLNET_MODE_VALUES = Literal[tuple(
|
||||
["balanced", "more_prompt", "more_control", "unbalanced"])]
|
||||
# crop and fill options not ready yet
|
||||
# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
|
||||
CONTROLNET_RESIZE_VALUES = Literal[tuple(
|
||||
["just_resize", "crop_resize", "fill_resize", "just_resize_simple",])]
|
||||
|
||||
|
||||
class ControlNetModelField(BaseModel):
|
||||
@ -111,7 +111,8 @@ class ControlField(BaseModel):
|
||||
description="When the ControlNet is last applied (% of total steps)")
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(
|
||||
default="balanced", description="The control mode to use")
|
||||
# resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(
|
||||
default="just_resize", description="The resize mode to use")
|
||||
|
||||
@validator("control_weight")
|
||||
def validate_control_weight(cls, v):
|
||||
@ -161,6 +162,7 @@ class ControlNetInvocation(BaseInvocation):
|
||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||
description="When the ControlNet is last applied (% of total steps)")
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode used")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
@ -187,6 +189,7 @@ class ControlNetInvocation(BaseInvocation):
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
control_mode=self.control_mode,
|
||||
resize_mode=self.resize_mode,
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -30,6 +30,7 @@ from .compel import ConditioningField
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .image import ImageOutput
|
||||
from .model import ModelInfo, UNetField, VaeField
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
@ -288,7 +289,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
# and add in batch_size, num_images_per_prompt?
|
||||
# and do real check for classifier_free_guidance?
|
||||
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
|
||||
control_image = model.prepare_control_image(
|
||||
control_image = prepare_control_image(
|
||||
image=input_image,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
width=control_width_resize,
|
||||
@ -298,13 +299,18 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
device=control_model.device,
|
||||
dtype=control_model.dtype,
|
||||
control_mode=control_info.control_mode,
|
||||
resize_mode=control_info.resize_mode,
|
||||
)
|
||||
control_item = ControlNetData(
|
||||
model=control_model, image_tensor=control_image,
|
||||
model=control_model,
|
||||
image_tensor=control_image,
|
||||
weight=control_info.control_weight,
|
||||
begin_step_percent=control_info.begin_step_percent,
|
||||
end_step_percent=control_info.end_step_percent,
|
||||
control_mode=control_info.control_mode,
|
||||
# any resizing needed should currently be happening in prepare_control_image(),
|
||||
# but adding resize_mode to ControlNetData in case needed in the future
|
||||
resize_mode=control_info.resize_mode,
|
||||
)
|
||||
control_data.append(control_item)
|
||||
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||
@ -601,7 +607,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
antialias: bool = Field(
|
||||
default=False,
|
||||
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
@ -647,7 +653,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
antialias: bool = Field(
|
||||
default=False,
|
||||
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
|
342
invokeai/app/util/controlnet_utils.py
Normal file
342
invokeai/app/util/controlnet_utils.py
Normal file
@ -0,0 +1,342 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image
|
||||
from diffusers.utils import PIL_INTERPOLATION
|
||||
|
||||
from einops import rearrange
|
||||
from controlnet_aux.util import HWC3, resize_image
|
||||
|
||||
###################################################################
|
||||
# Copy of scripts/lvminthin.py from Mikubill/sd-webui-controlnet
|
||||
###################################################################
|
||||
# High Quality Edge Thinning using Pure Python
|
||||
# Written by Lvmin Zhangu
|
||||
# 2023 April
|
||||
# Stanford University
|
||||
# If you use this, please Cite "High Quality Edge Thinning using Pure Python", Lvmin Zhang, In Mikubill/sd-webui-controlnet.
|
||||
|
||||
lvmin_kernels_raw = [
|
||||
np.array([
|
||||
[-1, -1, -1],
|
||||
[0, 1, 0],
|
||||
[1, 1, 1]
|
||||
], dtype=np.int32),
|
||||
np.array([
|
||||
[0, -1, -1],
|
||||
[1, 1, -1],
|
||||
[0, 1, 0]
|
||||
], dtype=np.int32)
|
||||
]
|
||||
|
||||
lvmin_kernels = []
|
||||
lvmin_kernels += [np.rot90(x, k=0, axes=(0, 1)) for x in lvmin_kernels_raw]
|
||||
lvmin_kernels += [np.rot90(x, k=1, axes=(0, 1)) for x in lvmin_kernels_raw]
|
||||
lvmin_kernels += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_kernels_raw]
|
||||
lvmin_kernels += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_kernels_raw]
|
||||
|
||||
lvmin_prunings_raw = [
|
||||
np.array([
|
||||
[-1, -1, -1],
|
||||
[-1, 1, -1],
|
||||
[0, 0, -1]
|
||||
], dtype=np.int32),
|
||||
np.array([
|
||||
[-1, -1, -1],
|
||||
[-1, 1, -1],
|
||||
[-1, 0, 0]
|
||||
], dtype=np.int32)
|
||||
]
|
||||
|
||||
lvmin_prunings = []
|
||||
lvmin_prunings += [np.rot90(x, k=0, axes=(0, 1)) for x in lvmin_prunings_raw]
|
||||
lvmin_prunings += [np.rot90(x, k=1, axes=(0, 1)) for x in lvmin_prunings_raw]
|
||||
lvmin_prunings += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_prunings_raw]
|
||||
lvmin_prunings += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_prunings_raw]
|
||||
|
||||
|
||||
def remove_pattern(x, kernel):
|
||||
objects = cv2.morphologyEx(x, cv2.MORPH_HITMISS, kernel)
|
||||
objects = np.where(objects > 127)
|
||||
x[objects] = 0
|
||||
return x, objects[0].shape[0] > 0
|
||||
|
||||
|
||||
def thin_one_time(x, kernels):
|
||||
y = x
|
||||
is_done = True
|
||||
for k in kernels:
|
||||
y, has_update = remove_pattern(y, k)
|
||||
if has_update:
|
||||
is_done = False
|
||||
return y, is_done
|
||||
|
||||
|
||||
def lvmin_thin(x, prunings=True):
|
||||
y = x
|
||||
for i in range(32):
|
||||
y, is_done = thin_one_time(y, lvmin_kernels)
|
||||
if is_done:
|
||||
break
|
||||
if prunings:
|
||||
y, _ = thin_one_time(y, lvmin_prunings)
|
||||
return y
|
||||
|
||||
|
||||
def nake_nms(x):
|
||||
f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
|
||||
f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
|
||||
f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
|
||||
f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
|
||||
y = np.zeros_like(x)
|
||||
for f in [f1, f2, f3, f4]:
|
||||
np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
|
||||
return y
|
||||
|
||||
|
||||
################################################################################
|
||||
# copied from Mikubill/sd-webui-controlnet external_code.py and modified for InvokeAI
|
||||
################################################################################
|
||||
# FIXME: not using yet, if used in the future will most likely require modification of preprocessors
|
||||
def pixel_perfect_resolution(
|
||||
image: np.ndarray,
|
||||
target_H: int,
|
||||
target_W: int,
|
||||
resize_mode: str,
|
||||
) -> int:
|
||||
"""
|
||||
Calculate the estimated resolution for resizing an image while preserving aspect ratio.
|
||||
|
||||
The function first calculates scaling factors for height and width of the image based on the target
|
||||
height and width. Then, based on the chosen resize mode, it either takes the smaller or the larger
|
||||
scaling factor to estimate the new resolution.
|
||||
|
||||
If the resize mode is OUTER_FIT, the function uses the smaller scaling factor, ensuring the whole image
|
||||
fits within the target dimensions, potentially leaving some empty space.
|
||||
|
||||
If the resize mode is not OUTER_FIT, the function uses the larger scaling factor, ensuring the target
|
||||
dimensions are fully filled, potentially cropping the image.
|
||||
|
||||
After calculating the estimated resolution, the function prints some debugging information.
|
||||
|
||||
Args:
|
||||
image (np.ndarray): A 3D numpy array representing an image. The dimensions represent [height, width, channels].
|
||||
target_H (int): The target height for the image.
|
||||
target_W (int): The target width for the image.
|
||||
resize_mode (ResizeMode): The mode for resizing.
|
||||
|
||||
Returns:
|
||||
int: The estimated resolution after resizing.
|
||||
"""
|
||||
raw_H, raw_W, _ = image.shape
|
||||
|
||||
k0 = float(target_H) / float(raw_H)
|
||||
k1 = float(target_W) / float(raw_W)
|
||||
|
||||
if resize_mode == "fill_resize":
|
||||
estimation = min(k0, k1) * float(min(raw_H, raw_W))
|
||||
else: # "crop_resize" or "just_resize" (or possibly "just_resize_simple"?)
|
||||
estimation = max(k0, k1) * float(min(raw_H, raw_W))
|
||||
|
||||
# print(f"Pixel Perfect Computation:")
|
||||
# print(f"resize_mode = {resize_mode}")
|
||||
# print(f"raw_H = {raw_H}")
|
||||
# print(f"raw_W = {raw_W}")
|
||||
# print(f"target_H = {target_H}")
|
||||
# print(f"target_W = {target_W}")
|
||||
# print(f"estimation = {estimation}")
|
||||
|
||||
return int(np.round(estimation))
|
||||
|
||||
|
||||
###########################################################################
|
||||
# Copied from detectmap_proc method in scripts/detectmap_proc.py in Mikubill/sd-webui-controlnet
|
||||
# modified for InvokeAI
|
||||
###########################################################################
|
||||
# def detectmap_proc(detected_map, module, resize_mode, h, w):
|
||||
def np_img_resize(
|
||||
np_img: np.ndarray,
|
||||
resize_mode: str,
|
||||
h: int,
|
||||
w: int,
|
||||
device: torch.device = torch.device('cpu')
|
||||
):
|
||||
# if 'inpaint' in module:
|
||||
# np_img = np_img.astype(np.float32)
|
||||
# else:
|
||||
# np_img = HWC3(np_img)
|
||||
np_img = HWC3(np_img)
|
||||
|
||||
def safe_numpy(x):
|
||||
# A very safe method to make sure that Apple/Mac works
|
||||
y = x
|
||||
|
||||
# below is very boring but do not change these. If you change these Apple or Mac may fail.
|
||||
y = y.copy()
|
||||
y = np.ascontiguousarray(y)
|
||||
y = y.copy()
|
||||
return y
|
||||
|
||||
def get_pytorch_control(x):
|
||||
# A very safe method to make sure that Apple/Mac works
|
||||
y = x
|
||||
|
||||
# below is very boring but do not change these. If you change these Apple or Mac may fail.
|
||||
y = torch.from_numpy(y)
|
||||
y = y.float() / 255.0
|
||||
y = rearrange(y, 'h w c -> 1 c h w')
|
||||
y = y.clone()
|
||||
# y = y.to(devices.get_device_for("controlnet"))
|
||||
y = y.to(device)
|
||||
y = y.clone()
|
||||
return y
|
||||
|
||||
def high_quality_resize(x: np.ndarray,
|
||||
size):
|
||||
# Written by lvmin
|
||||
# Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges
|
||||
inpaint_mask = None
|
||||
if x.ndim == 3 and x.shape[2] == 4:
|
||||
inpaint_mask = x[:, :, 3]
|
||||
x = x[:, :, 0:3]
|
||||
|
||||
new_size_is_smaller = (size[0] * size[1]) < (x.shape[0] * x.shape[1])
|
||||
new_size_is_bigger = (size[0] * size[1]) > (x.shape[0] * x.shape[1])
|
||||
unique_color_count = np.unique(x.reshape(-1, x.shape[2]), axis=0).shape[0]
|
||||
is_one_pixel_edge = False
|
||||
is_binary = False
|
||||
if unique_color_count == 2:
|
||||
is_binary = np.min(x) < 16 and np.max(x) > 240
|
||||
if is_binary:
|
||||
xc = x
|
||||
xc = cv2.erode(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
|
||||
xc = cv2.dilate(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
|
||||
one_pixel_edge_count = np.where(xc < x)[0].shape[0]
|
||||
all_edge_count = np.where(x > 127)[0].shape[0]
|
||||
is_one_pixel_edge = one_pixel_edge_count * 2 > all_edge_count
|
||||
|
||||
if 2 < unique_color_count < 200:
|
||||
interpolation = cv2.INTER_NEAREST
|
||||
elif new_size_is_smaller:
|
||||
interpolation = cv2.INTER_AREA
|
||||
else:
|
||||
interpolation = cv2.INTER_CUBIC # Must be CUBIC because we now use nms. NEVER CHANGE THIS
|
||||
|
||||
y = cv2.resize(x, size, interpolation=interpolation)
|
||||
if inpaint_mask is not None:
|
||||
inpaint_mask = cv2.resize(inpaint_mask, size, interpolation=interpolation)
|
||||
|
||||
if is_binary:
|
||||
y = np.mean(y.astype(np.float32), axis=2).clip(0, 255).astype(np.uint8)
|
||||
if is_one_pixel_edge:
|
||||
y = nake_nms(y)
|
||||
_, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
||||
y = lvmin_thin(y, prunings=new_size_is_bigger)
|
||||
else:
|
||||
_, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
||||
y = np.stack([y] * 3, axis=2)
|
||||
|
||||
if inpaint_mask is not None:
|
||||
inpaint_mask = (inpaint_mask > 127).astype(np.float32) * 255.0
|
||||
inpaint_mask = inpaint_mask[:, :, None].clip(0, 255).astype(np.uint8)
|
||||
y = np.concatenate([y, inpaint_mask], axis=2)
|
||||
|
||||
return y
|
||||
|
||||
# if resize_mode == external_code.ResizeMode.RESIZE:
|
||||
if resize_mode == "just_resize": # RESIZE
|
||||
np_img = high_quality_resize(np_img, (w, h))
|
||||
np_img = safe_numpy(np_img)
|
||||
return get_pytorch_control(np_img), np_img
|
||||
|
||||
old_h, old_w, _ = np_img.shape
|
||||
old_w = float(old_w)
|
||||
old_h = float(old_h)
|
||||
k0 = float(h) / old_h
|
||||
k1 = float(w) / old_w
|
||||
|
||||
safeint = lambda x: int(np.round(x))
|
||||
|
||||
# if resize_mode == external_code.ResizeMode.OUTER_FIT:
|
||||
if resize_mode == "fill_resize": # OUTER_FIT
|
||||
k = min(k0, k1)
|
||||
borders = np.concatenate([np_img[0, :, :], np_img[-1, :, :], np_img[:, 0, :], np_img[:, -1, :]], axis=0)
|
||||
high_quality_border_color = np.median(borders, axis=0).astype(np_img.dtype)
|
||||
if len(high_quality_border_color) == 4:
|
||||
# Inpaint hijack
|
||||
high_quality_border_color[3] = 255
|
||||
high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1])
|
||||
np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
|
||||
new_h, new_w, _ = np_img.shape
|
||||
pad_h = max(0, (h - new_h) // 2)
|
||||
pad_w = max(0, (w - new_w) // 2)
|
||||
high_quality_background[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = np_img
|
||||
np_img = high_quality_background
|
||||
np_img = safe_numpy(np_img)
|
||||
return get_pytorch_control(np_img), np_img
|
||||
else: # resize_mode == "crop_resize" (INNER_FIT)
|
||||
k = max(k0, k1)
|
||||
np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
|
||||
new_h, new_w, _ = np_img.shape
|
||||
pad_h = max(0, (new_h - h) // 2)
|
||||
pad_w = max(0, (new_w - w) // 2)
|
||||
np_img = np_img[pad_h:pad_h + h, pad_w:pad_w + w]
|
||||
np_img = safe_numpy(np_img)
|
||||
return get_pytorch_control(np_img), np_img
|
||||
|
||||
def prepare_control_image(
|
||||
# image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]]
|
||||
# but now should be able to assume that image is a single PIL.Image, which simplifies things
|
||||
image: Image,
|
||||
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
|
||||
# latents_to_match_resolution, # TorchTensor of shape (batch_size, 3, height, width)
|
||||
width=512, # should be 8 * latent.shape[3]
|
||||
height=512, # should be 8 * latent height[2]
|
||||
# batch_size=1, # currently no batching
|
||||
# num_images_per_prompt=1, # currently only single image
|
||||
device="cuda",
|
||||
dtype=torch.float16,
|
||||
do_classifier_free_guidance=True,
|
||||
control_mode="balanced",
|
||||
resize_mode="just_resize_simple",
|
||||
):
|
||||
# FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out
|
||||
if (resize_mode == "just_resize_simple" or
|
||||
resize_mode == "crop_resize_simple" or
|
||||
resize_mode == "fill_resize_simple"):
|
||||
image = image.convert("RGB")
|
||||
if (resize_mode == "just_resize_simple"):
|
||||
image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
||||
elif (resize_mode == "crop_resize_simple"): # not yet implemented
|
||||
pass
|
||||
elif (resize_mode == "fill_resize_simple"): # not yet implemented
|
||||
pass
|
||||
nimage = np.array(image)
|
||||
nimage = nimage[None, :]
|
||||
nimage = np.concatenate([nimage], axis=0)
|
||||
# normalizing RGB values to [0,1] range (in PIL.Image they are [0-255])
|
||||
nimage = np.array(nimage).astype(np.float32) / 255.0
|
||||
nimage = nimage.transpose(0, 3, 1, 2)
|
||||
timage = torch.from_numpy(nimage)
|
||||
|
||||
# use fancy lvmin controlnet resizing
|
||||
elif (resize_mode == "just_resize" or resize_mode == "crop_resize" or resize_mode == "fill_resize"):
|
||||
nimage = np.array(image)
|
||||
timage, nimage = np_img_resize(
|
||||
np_img=nimage,
|
||||
resize_mode=resize_mode,
|
||||
h=height,
|
||||
w=width,
|
||||
# device=torch.device('cpu')
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
pass
|
||||
print("ERROR: invalid resize_mode ==> ", resize_mode)
|
||||
exit(1)
|
||||
|
||||
timage = timage.to(device=device, dtype=dtype)
|
||||
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
|
||||
if do_classifier_free_guidance and not cfg_injection:
|
||||
timage = torch.cat([timage] * 2)
|
||||
return timage
|
@ -219,6 +219,7 @@ class ControlNetData:
|
||||
begin_step_percent: float = Field(default=0.0)
|
||||
end_step_percent: float = Field(default=1.0)
|
||||
control_mode: str = Field(default="balanced")
|
||||
resize_mode: str = Field(default="just_resize")
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -653,7 +654,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if cfg_injection:
|
||||
# Inferred ControlNet only for the conditional batch.
|
||||
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
||||
# add 0 to the unconditional batch to keep it unchanged.
|
||||
# prepend zeros for unconditional batch
|
||||
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
|
||||
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
|
||||
|
||||
@ -954,53 +955,3 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
debug_image(
|
||||
img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
|
||||
)
|
||||
|
||||
# Copied from diffusers pipeline_stable_diffusion_controlnet.py
|
||||
# Returns torch.Tensor of shape (batch_size, 3, height, width)
|
||||
@staticmethod
|
||||
def prepare_control_image(
|
||||
image,
|
||||
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
|
||||
# latents,
|
||||
width=512, # should be 8 * latent.shape[3]
|
||||
height=512, # should be 8 * latent height[2]
|
||||
batch_size=1,
|
||||
num_images_per_prompt=1,
|
||||
device="cuda",
|
||||
dtype=torch.float16,
|
||||
do_classifier_free_guidance=True,
|
||||
control_mode="balanced"
|
||||
):
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image[0], PIL.Image.Image):
|
||||
images = []
|
||||
for image_ in image:
|
||||
image_ = image_.convert("RGB")
|
||||
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
||||
image_ = np.array(image_)
|
||||
image_ = image_[None, :]
|
||||
images.append(image_)
|
||||
image = images
|
||||
image = np.concatenate(image, axis=0)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
elif isinstance(image[0], torch.Tensor):
|
||||
image = torch.cat(image, dim=0)
|
||||
|
||||
image_batch_size = image.shape[0]
|
||||
if image_batch_size == 1:
|
||||
repeat_by = batch_size
|
||||
else:
|
||||
# image batch size is the same as prompt batch size
|
||||
repeat_by = num_images_per_prompt
|
||||
image = image.repeat_interleave(repeat_by, dim=0)
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
|
||||
if do_classifier_free_guidance and not cfg_injection:
|
||||
image = torch.cat([image] * 2)
|
||||
return image
|
||||
|
@ -17,13 +17,13 @@ import {
|
||||
} from 'common/components/IAIImageFallback';
|
||||
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
|
||||
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
|
||||
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
|
||||
import { MouseEvent, ReactElement, SyntheticEvent, memo } from 'react';
|
||||
import { FaImage, FaUndo, FaUpload } from 'react-icons/fa';
|
||||
import { ImageDTO, PostUploadAction } from 'services/api/types';
|
||||
import { mode } from 'theme/util/mode';
|
||||
import IAIDraggable from './IAIDraggable';
|
||||
import IAIDroppable from './IAIDroppable';
|
||||
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
|
||||
|
||||
type IAIDndImageProps = {
|
||||
imageDTO: ImageDTO | undefined;
|
||||
@ -148,7 +148,9 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
maxH: 'full',
|
||||
borderRadius: 'base',
|
||||
shadow: isSelected ? 'selected.light' : undefined,
|
||||
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
|
||||
_dark: {
|
||||
shadow: isSelected ? 'selected.dark' : undefined,
|
||||
},
|
||||
...imageSx,
|
||||
}}
|
||||
/>
|
||||
@ -183,13 +185,6 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
</>
|
||||
)}
|
||||
{!imageDTO && isUploadDisabled && noContentFallback}
|
||||
{!isDropDisabled && (
|
||||
<IAIDroppable
|
||||
data={droppableData}
|
||||
disabled={isDropDisabled}
|
||||
dropLabel={dropLabel}
|
||||
/>
|
||||
)}
|
||||
{imageDTO && !isDragDisabled && (
|
||||
<IAIDraggable
|
||||
data={draggableData}
|
||||
@ -197,6 +192,13 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
onClick={onClick}
|
||||
/>
|
||||
)}
|
||||
{!isDropDisabled && (
|
||||
<IAIDroppable
|
||||
data={droppableData}
|
||||
disabled={isDropDisabled}
|
||||
dropLabel={dropLabel}
|
||||
/>
|
||||
)}
|
||||
{onClickReset && withResetIcon && imageDTO && (
|
||||
<IAIIconButton
|
||||
onClick={onClickReset}
|
||||
|
@ -13,10 +13,11 @@ type IAIDroppableProps = {
|
||||
dropLabel?: ReactNode;
|
||||
disabled?: boolean;
|
||||
data?: TypesafeDroppableData;
|
||||
hoverRef?: React.Ref<HTMLDivElement>;
|
||||
};
|
||||
|
||||
const IAIDroppable = (props: IAIDroppableProps) => {
|
||||
const { dropLabel, data, disabled } = props;
|
||||
const { dropLabel, data, disabled, hoverRef } = props;
|
||||
const dndId = useRef(uuidv4());
|
||||
|
||||
const { isOver, setNodeRef, active } = useDroppable({
|
||||
|
@ -24,6 +24,7 @@ import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
|
||||
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
|
||||
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
|
||||
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
|
||||
import ParamControlNetResizeMode from './parameters/ParamControlNetResizeMode';
|
||||
|
||||
type ControlNetProps = {
|
||||
controlNetId: string;
|
||||
@ -68,7 +69,7 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
<Flex
|
||||
sx={{
|
||||
flexDir: 'column',
|
||||
gap: 2,
|
||||
gap: 3,
|
||||
p: 3,
|
||||
borderRadius: 'base',
|
||||
position: 'relative',
|
||||
@ -117,7 +118,12 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
tooltip={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
|
||||
aria-label={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
|
||||
onClick={toggleIsExpanded}
|
||||
variant="link"
|
||||
variant="ghost"
|
||||
sx={{
|
||||
_hover: {
|
||||
bg: 'none',
|
||||
},
|
||||
}}
|
||||
icon={
|
||||
<ChevronUpIcon
|
||||
sx={{
|
||||
@ -151,7 +157,7 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
<Flex sx={{ w: 'full', flexDirection: 'column' }}>
|
||||
<Flex sx={{ w: 'full', flexDirection: 'column', gap: 3 }}>
|
||||
<Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}>
|
||||
<Flex
|
||||
sx={{
|
||||
@ -176,16 +182,16 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
h: 28,
|
||||
w: 28,
|
||||
aspectRatio: '1/1',
|
||||
mt: 3,
|
||||
}}
|
||||
>
|
||||
<ControlNetImagePreview controlNetId={controlNetId} height={28} />
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
<Box mt={2}>
|
||||
<Flex sx={{ gap: 2 }}>
|
||||
<ParamControlNetControlMode controlNetId={controlNetId} />
|
||||
</Box>
|
||||
<ParamControlNetResizeMode controlNetId={controlNetId} />
|
||||
</Flex>
|
||||
<ParamControlNetProcessorSelect controlNetId={controlNetId} />
|
||||
</Flex>
|
||||
|
||||
|
@ -0,0 +1,62 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import {
|
||||
ResizeModes,
|
||||
controlNetResizeModeChanged,
|
||||
} from 'features/controlNet/store/controlNetSlice';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type ParamControlNetResizeModeProps = {
|
||||
controlNetId: string;
|
||||
};
|
||||
|
||||
const RESIZE_MODE_DATA = [
|
||||
{ label: 'Resize', value: 'just_resize' },
|
||||
{ label: 'Crop', value: 'crop_resize' },
|
||||
{ label: 'Fill', value: 'fill_resize' },
|
||||
];
|
||||
|
||||
export default function ParamControlNetResizeMode(
|
||||
props: ParamControlNetResizeModeProps
|
||||
) {
|
||||
const { controlNetId } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ controlNet }) => {
|
||||
const { resizeMode, isEnabled } =
|
||||
controlNet.controlNets[controlNetId];
|
||||
return { resizeMode, isEnabled };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[controlNetId]
|
||||
);
|
||||
|
||||
const { resizeMode, isEnabled } = useAppSelector(selector);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleResizeModeChange = useCallback(
|
||||
(resizeMode: ResizeModes) => {
|
||||
dispatch(controlNetResizeModeChanged({ controlNetId, resizeMode }));
|
||||
},
|
||||
[controlNetId, dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
disabled={!isEnabled}
|
||||
label="Resize Mode"
|
||||
data={RESIZE_MODE_DATA}
|
||||
value={String(resizeMode)}
|
||||
onChange={handleResizeModeChange}
|
||||
/>
|
||||
);
|
||||
}
|
@ -3,6 +3,7 @@ import { RootState } from 'app/store/store';
|
||||
import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas';
|
||||
import { cloneDeep, forEach } from 'lodash-es';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { components } from 'services/api/schema';
|
||||
import { isAnySessionRejected } from 'services/api/thunks/session';
|
||||
import { appSocketInvocationError } from 'services/events/actions';
|
||||
import { controlNetImageProcessed } from './actions';
|
||||
@ -16,11 +17,13 @@ import {
|
||||
RequiredControlNetProcessorNode,
|
||||
} from './types';
|
||||
|
||||
export type ControlModes =
|
||||
| 'balanced'
|
||||
| 'more_prompt'
|
||||
| 'more_control'
|
||||
| 'unbalanced';
|
||||
export type ControlModes = NonNullable<
|
||||
components['schemas']['ControlNetInvocation']['control_mode']
|
||||
>;
|
||||
|
||||
export type ResizeModes = NonNullable<
|
||||
components['schemas']['ControlNetInvocation']['resize_mode']
|
||||
>;
|
||||
|
||||
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
||||
isEnabled: true,
|
||||
@ -29,6 +32,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
||||
beginStepPct: 0,
|
||||
endStepPct: 1,
|
||||
controlMode: 'balanced',
|
||||
resizeMode: 'just_resize',
|
||||
controlImage: null,
|
||||
processedControlImage: null,
|
||||
processorType: 'canny_image_processor',
|
||||
@ -45,6 +49,7 @@ export type ControlNetConfig = {
|
||||
beginStepPct: number;
|
||||
endStepPct: number;
|
||||
controlMode: ControlModes;
|
||||
resizeMode: ResizeModes;
|
||||
controlImage: string | null;
|
||||
processedControlImage: string | null;
|
||||
processorType: ControlNetProcessorType;
|
||||
@ -215,6 +220,16 @@ export const controlNetSlice = createSlice({
|
||||
const { controlNetId, controlMode } = action.payload;
|
||||
state.controlNets[controlNetId].controlMode = controlMode;
|
||||
},
|
||||
controlNetResizeModeChanged: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
controlNetId: string;
|
||||
resizeMode: ResizeModes;
|
||||
}>
|
||||
) => {
|
||||
const { controlNetId, resizeMode } = action.payload;
|
||||
state.controlNets[controlNetId].resizeMode = resizeMode;
|
||||
},
|
||||
controlNetProcessorParamsChanged: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
@ -342,6 +357,7 @@ export const {
|
||||
controlNetBeginStepPctChanged,
|
||||
controlNetEndStepPctChanged,
|
||||
controlNetControlModeChanged,
|
||||
controlNetResizeModeChanged,
|
||||
controlNetProcessorParamsChanged,
|
||||
controlNetProcessorTypeChanged,
|
||||
controlNetReset,
|
||||
|
@ -23,34 +23,32 @@ const BoardContextMenu = memo(
|
||||
dispatch(boardIdSelected(board?.board_id ?? board_id));
|
||||
}, [board?.board_id, board_id, dispatch]);
|
||||
return (
|
||||
<Box sx={{ touchAction: 'none', height: 'full' }}>
|
||||
<ContextMenu<HTMLDivElement>
|
||||
menuProps={{ size: 'sm', isLazy: true }}
|
||||
menuButtonProps={{
|
||||
bg: 'transparent',
|
||||
_hover: { bg: 'transparent' },
|
||||
}}
|
||||
renderMenu={() => (
|
||||
<MenuList
|
||||
sx={{ visibility: 'visible !important' }}
|
||||
motionProps={menuListMotionProps}
|
||||
>
|
||||
<MenuItem icon={<FaFolder />} onClickCapture={handleSelectBoard}>
|
||||
Select Board
|
||||
</MenuItem>
|
||||
{!board && <SystemBoardContextMenuItems board_id={board_id} />}
|
||||
{board && (
|
||||
<GalleryBoardContextMenuItems
|
||||
board={board}
|
||||
setBoardToDelete={setBoardToDelete}
|
||||
/>
|
||||
)}
|
||||
</MenuList>
|
||||
)}
|
||||
>
|
||||
{children}
|
||||
</ContextMenu>
|
||||
</Box>
|
||||
<ContextMenu<HTMLDivElement>
|
||||
menuProps={{ size: 'sm', isLazy: true }}
|
||||
menuButtonProps={{
|
||||
bg: 'transparent',
|
||||
_hover: { bg: 'transparent' },
|
||||
}}
|
||||
renderMenu={() => (
|
||||
<MenuList
|
||||
sx={{ visibility: 'visible !important' }}
|
||||
motionProps={menuListMotionProps}
|
||||
>
|
||||
<MenuItem icon={<FaFolder />} onClickCapture={handleSelectBoard}>
|
||||
Select Board
|
||||
</MenuItem>
|
||||
{!board && <SystemBoardContextMenuItems board_id={board_id} />}
|
||||
{board && (
|
||||
<GalleryBoardContextMenuItems
|
||||
board={board}
|
||||
setBoardToDelete={setBoardToDelete}
|
||||
/>
|
||||
)}
|
||||
</MenuList>
|
||||
)}
|
||||
>
|
||||
{children}
|
||||
</ContextMenu>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
@ -1,27 +1,21 @@
|
||||
import {
|
||||
Collapse,
|
||||
Flex,
|
||||
Grid,
|
||||
GridItem,
|
||||
useDisclosure,
|
||||
} from '@chakra-ui/react';
|
||||
import { ButtonGroup, Collapse, Flex, Grid, GridItem } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { AnimatePresence, motion } from 'framer-motion';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import { memo, useState } from 'react';
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import { FaSearch } from 'react-icons/fa';
|
||||
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
|
||||
import { BoardDTO } from 'services/api/types';
|
||||
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
|
||||
import DeleteBoardModal from '../DeleteBoardModal';
|
||||
import AddBoardButton from './AddBoardButton';
|
||||
import AllAssetsBoard from './AllAssetsBoard';
|
||||
import AllImagesBoard from './AllImagesBoard';
|
||||
import BatchBoard from './BatchBoard';
|
||||
import BoardsSearch from './BoardsSearch';
|
||||
import GalleryBoard from './GalleryBoard';
|
||||
import NoBoardBoard from './NoBoardBoard';
|
||||
import DeleteBoardModal from '../DeleteBoardModal';
|
||||
import { BoardDTO } from 'services/api/types';
|
||||
import SystemBoardButton from './SystemBoardButton';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
@ -48,7 +42,10 @@ const BoardsList = (props: Props) => {
|
||||
)
|
||||
: boards;
|
||||
const [boardToDelete, setBoardToDelete] = useState<BoardDTO>();
|
||||
const [searchMode, setSearchMode] = useState(false);
|
||||
const [isSearching, setIsSearching] = useState(false);
|
||||
const handleClickSearchIcon = useCallback(() => {
|
||||
setIsSearching((v) => !v);
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<>
|
||||
@ -64,7 +61,54 @@ const BoardsList = (props: Props) => {
|
||||
}}
|
||||
>
|
||||
<Flex sx={{ gap: 2, alignItems: 'center' }}>
|
||||
<BoardsSearch setSearchMode={setSearchMode} />
|
||||
<AnimatePresence mode="popLayout">
|
||||
{isSearching ? (
|
||||
<motion.div
|
||||
key="boards-search"
|
||||
initial={{
|
||||
opacity: 0,
|
||||
}}
|
||||
exit={{
|
||||
opacity: 0,
|
||||
}}
|
||||
animate={{
|
||||
opacity: 1,
|
||||
transition: { duration: 0.1 },
|
||||
}}
|
||||
style={{ width: '100%' }}
|
||||
>
|
||||
<BoardsSearch setIsSearching={setIsSearching} />
|
||||
</motion.div>
|
||||
) : (
|
||||
<motion.div
|
||||
key="system-boards-select"
|
||||
initial={{
|
||||
opacity: 0,
|
||||
}}
|
||||
exit={{
|
||||
opacity: 0,
|
||||
}}
|
||||
animate={{
|
||||
opacity: 1,
|
||||
transition: { duration: 0.1 },
|
||||
}}
|
||||
style={{ width: '100%' }}
|
||||
>
|
||||
<ButtonGroup sx={{ w: 'full', ps: 1.5 }} isAttached>
|
||||
<SystemBoardButton board_id="images" />
|
||||
<SystemBoardButton board_id="assets" />
|
||||
<SystemBoardButton board_id="no_board" />
|
||||
</ButtonGroup>
|
||||
</motion.div>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
<IAIIconButton
|
||||
aria-label="Search Boards"
|
||||
size="sm"
|
||||
isChecked={isSearching}
|
||||
onClick={handleClickSearchIcon}
|
||||
icon={<FaSearch />}
|
||||
/>
|
||||
<AddBoardButton />
|
||||
</Flex>
|
||||
<OverlayScrollbarsComponent
|
||||
@ -82,29 +126,10 @@ const BoardsList = (props: Props) => {
|
||||
<Grid
|
||||
className="list-container"
|
||||
sx={{
|
||||
gridTemplateRows: '6.5rem 6.5rem',
|
||||
gridAutoFlow: 'column dense',
|
||||
gridAutoColumns: '5rem',
|
||||
gridTemplateColumns: `repeat(auto-fill, minmax(96px, 1fr));`,
|
||||
maxH: 346,
|
||||
}}
|
||||
>
|
||||
{!searchMode && (
|
||||
<>
|
||||
<GridItem sx={{ p: 1.5 }}>
|
||||
<AllImagesBoard isSelected={selectedBoardId === 'images'} />
|
||||
</GridItem>
|
||||
<GridItem sx={{ p: 1.5 }}>
|
||||
<AllAssetsBoard isSelected={selectedBoardId === 'assets'} />
|
||||
</GridItem>
|
||||
<GridItem sx={{ p: 1.5 }}>
|
||||
<NoBoardBoard isSelected={selectedBoardId === 'no_board'} />
|
||||
</GridItem>
|
||||
{isBatchEnabled && (
|
||||
<GridItem sx={{ p: 1.5 }}>
|
||||
<BatchBoard isSelected={selectedBoardId === 'batch'} />
|
||||
</GridItem>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
{filteredBoards &&
|
||||
filteredBoards.map((board) => (
|
||||
<GridItem key={board.board_id} sx={{ p: 1.5 }}>
|
||||
|
@ -10,7 +10,14 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { setBoardSearchText } from 'features/gallery/store/boardSlice';
|
||||
import { memo } from 'react';
|
||||
import {
|
||||
ChangeEvent,
|
||||
KeyboardEvent,
|
||||
memo,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useRef,
|
||||
} from 'react';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
@ -22,31 +29,60 @@ const selector = createSelector(
|
||||
);
|
||||
|
||||
type Props = {
|
||||
setSearchMode: (searchMode: boolean) => void;
|
||||
setIsSearching: (isSearching: boolean) => void;
|
||||
};
|
||||
|
||||
const BoardsSearch = (props: Props) => {
|
||||
const { setSearchMode } = props;
|
||||
const { setIsSearching } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { searchText } = useAppSelector(selector);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const handleBoardSearch = (searchTerm: string) => {
|
||||
setSearchMode(searchTerm.length > 0);
|
||||
dispatch(setBoardSearchText(searchTerm));
|
||||
};
|
||||
const clearBoardSearch = () => {
|
||||
setSearchMode(false);
|
||||
const handleBoardSearch = useCallback(
|
||||
(searchTerm: string) => {
|
||||
dispatch(setBoardSearchText(searchTerm));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const clearBoardSearch = useCallback(() => {
|
||||
dispatch(setBoardSearchText(''));
|
||||
};
|
||||
setIsSearching(false);
|
||||
}, [dispatch, setIsSearching]);
|
||||
|
||||
const handleKeydown = useCallback(
|
||||
(e: KeyboardEvent<HTMLInputElement>) => {
|
||||
// exit search mode on escape
|
||||
if (e.key === 'Escape') {
|
||||
clearBoardSearch();
|
||||
}
|
||||
},
|
||||
[clearBoardSearch]
|
||||
);
|
||||
|
||||
const handleChange = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
handleBoardSearch(e.target.value);
|
||||
},
|
||||
[handleBoardSearch]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
// focus the search box on mount
|
||||
if (!inputRef.current) {
|
||||
return;
|
||||
}
|
||||
inputRef.current.focus();
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<InputGroup>
|
||||
<Input
|
||||
ref={inputRef}
|
||||
placeholder="Search Boards..."
|
||||
value={searchText}
|
||||
onChange={(e) => {
|
||||
handleBoardSearch(e.target.value);
|
||||
}}
|
||||
onKeyDown={handleKeydown}
|
||||
onChange={handleChange}
|
||||
/>
|
||||
{searchText && searchText.length && (
|
||||
<InputRightElement>
|
||||
@ -55,7 +91,8 @@ const BoardsSearch = (props: Props) => {
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
aria-label="Clear Search"
|
||||
icon={<CloseIcon boxSize={3} />}
|
||||
opacity={0.5}
|
||||
icon={<CloseIcon boxSize={2} />}
|
||||
/>
|
||||
</InputRightElement>
|
||||
)}
|
||||
|
@ -6,9 +6,9 @@ import {
|
||||
EditableInput,
|
||||
EditablePreview,
|
||||
Flex,
|
||||
Icon,
|
||||
Image,
|
||||
Text,
|
||||
useColorMode,
|
||||
} from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
@ -17,14 +17,12 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIDroppable from 'common/components/IAIDroppable';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { boardIdSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { FaUser } from 'react-icons/fa';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { FaFolder } from 'react-icons/fa';
|
||||
import { useUpdateBoardMutation } from 'services/api/endpoints/boards';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { BoardDTO } from 'services/api/types';
|
||||
import { mode } from 'theme/util/mode';
|
||||
import BoardContextMenu from '../BoardContextMenu';
|
||||
|
||||
const AUTO_ADD_BADGE_STYLES: ChakraProps['sx'] = {
|
||||
@ -66,8 +64,9 @@ const GalleryBoard = memo(
|
||||
board.cover_image_name ?? skipToken
|
||||
);
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const { board_name, board_id } = board;
|
||||
const [localBoardName, setLocalBoardName] = useState(board_name);
|
||||
|
||||
const handleSelectBoard = useCallback(() => {
|
||||
dispatch(boardIdSelected(board_id));
|
||||
}, [board_id, dispatch]);
|
||||
@ -75,10 +74,6 @@ const GalleryBoard = memo(
|
||||
const [updateBoard, { isLoading: isUpdateBoardLoading }] =
|
||||
useUpdateBoardMutation();
|
||||
|
||||
const handleUpdateBoardName = (newBoardName: string) => {
|
||||
updateBoard({ board_id, changes: { board_name: newBoardName } });
|
||||
};
|
||||
|
||||
const droppableData: MoveBoardDropData = useMemo(
|
||||
() => ({
|
||||
id: board_id,
|
||||
@ -88,59 +83,116 @@ const GalleryBoard = memo(
|
||||
[board_id]
|
||||
);
|
||||
|
||||
const handleSubmit = useCallback(
|
||||
(newBoardName: string) => {
|
||||
if (!newBoardName) {
|
||||
// empty strings are not allowed
|
||||
setLocalBoardName(board_name);
|
||||
return;
|
||||
}
|
||||
if (newBoardName === board_name) {
|
||||
// don't updated the board name if it hasn't changed
|
||||
return;
|
||||
}
|
||||
updateBoard({ board_id, changes: { board_name: newBoardName } })
|
||||
.unwrap()
|
||||
.then((response) => {
|
||||
// update local state
|
||||
setLocalBoardName(response.board_name);
|
||||
})
|
||||
.catch(() => {
|
||||
// revert on error
|
||||
setLocalBoardName(board_name);
|
||||
});
|
||||
},
|
||||
[board_id, board_name, updateBoard]
|
||||
);
|
||||
|
||||
const handleChange = useCallback((newBoardName: string) => {
|
||||
setLocalBoardName(newBoardName);
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Box sx={{ touchAction: 'none', height: 'full' }}>
|
||||
<BoardContextMenu
|
||||
board={board}
|
||||
board_id={board_id}
|
||||
setBoardToDelete={setBoardToDelete}
|
||||
<Box
|
||||
sx={{ w: 'full', h: 'full', touchAction: 'none', userSelect: 'none' }}
|
||||
>
|
||||
<Flex
|
||||
sx={{
|
||||
position: 'relative',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
aspectRatio: '1/1',
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
}}
|
||||
>
|
||||
{(ref) => (
|
||||
<Flex
|
||||
key={board_id}
|
||||
userSelect="none"
|
||||
ref={ref}
|
||||
sx={{
|
||||
flexDir: 'column',
|
||||
justifyContent: 'space-between',
|
||||
alignItems: 'center',
|
||||
cursor: 'pointer',
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
}}
|
||||
>
|
||||
<BoardContextMenu
|
||||
board={board}
|
||||
board_id={board_id}
|
||||
setBoardToDelete={setBoardToDelete}
|
||||
>
|
||||
{(ref) => (
|
||||
<Flex
|
||||
ref={ref}
|
||||
onClick={handleSelectBoard}
|
||||
sx={{
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
position: 'relative',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
borderRadius: 'base',
|
||||
w: 'full',
|
||||
aspectRatio: '1/1',
|
||||
overflow: 'hidden',
|
||||
shadow: isSelected ? 'selected.light' : undefined,
|
||||
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
|
||||
flexShrink: 0,
|
||||
cursor: 'pointer',
|
||||
}}
|
||||
>
|
||||
{board.cover_image_name && coverImage?.thumbnail_url && (
|
||||
<Image src={coverImage?.thumbnail_url} draggable={false} />
|
||||
)}
|
||||
{!(board.cover_image_name && coverImage?.thumbnail_url) && (
|
||||
<IAINoContentFallback
|
||||
boxSize={8}
|
||||
icon={FaUser}
|
||||
sx={{
|
||||
borderWidth: '2px',
|
||||
borderStyle: 'solid',
|
||||
borderColor: 'base.200',
|
||||
_dark: {
|
||||
borderColor: 'base.800',
|
||||
},
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
<Flex
|
||||
sx={{
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
borderRadius: 'base',
|
||||
bg: 'base.200',
|
||||
_dark: {
|
||||
bg: 'base.800',
|
||||
},
|
||||
}}
|
||||
>
|
||||
{coverImage?.thumbnail_url ? (
|
||||
<Image
|
||||
src={coverImage?.thumbnail_url}
|
||||
draggable={false}
|
||||
sx={{
|
||||
maxW: 'full',
|
||||
maxH: 'full',
|
||||
borderRadius: 'base',
|
||||
borderBottomRadius: 'lg',
|
||||
}}
|
||||
/>
|
||||
) : (
|
||||
<Flex
|
||||
sx={{
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
}}
|
||||
>
|
||||
<Icon
|
||||
boxSize={12}
|
||||
as={FaFolder}
|
||||
sx={{
|
||||
mt: -3,
|
||||
opacity: 0.7,
|
||||
color: 'base.500',
|
||||
_dark: {
|
||||
color: 'base.500',
|
||||
},
|
||||
}}
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
<Flex
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
@ -160,56 +212,86 @@ const GalleryBoard = memo(
|
||||
{board.image_count}
|
||||
</Badge>
|
||||
</Flex>
|
||||
<Box
|
||||
className="selection-box"
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
insetInlineEnd: 0,
|
||||
bottom: 0,
|
||||
insetInlineStart: 0,
|
||||
borderRadius: 'base',
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: 'common',
|
||||
shadow: isSelected ? 'selected.light' : undefined,
|
||||
_dark: {
|
||||
shadow: isSelected ? 'selected.dark' : undefined,
|
||||
},
|
||||
}}
|
||||
/>
|
||||
<Flex
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
bottom: 0,
|
||||
left: 0,
|
||||
p: 1,
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
w: 'full',
|
||||
maxW: 'full',
|
||||
borderBottomRadius: 'base',
|
||||
bg: isSelected ? 'accent.400' : 'base.500',
|
||||
color: isSelected ? 'base.50' : 'base.100',
|
||||
_dark: {
|
||||
bg: isSelected ? 'accent.500' : 'base.600',
|
||||
color: isSelected ? 'base.50' : 'base.100',
|
||||
},
|
||||
lineHeight: 'short',
|
||||
fontSize: 'xs',
|
||||
}}
|
||||
>
|
||||
<Editable
|
||||
value={localBoardName}
|
||||
isDisabled={isUpdateBoardLoading}
|
||||
submitOnBlur={true}
|
||||
onChange={handleChange}
|
||||
onSubmit={handleSubmit}
|
||||
sx={{
|
||||
w: 'full',
|
||||
}}
|
||||
>
|
||||
<EditablePreview
|
||||
sx={{
|
||||
p: 0,
|
||||
fontWeight: isSelected ? 700 : 500,
|
||||
textAlign: 'center',
|
||||
overflow: 'hidden',
|
||||
textOverflow: 'ellipsis',
|
||||
}}
|
||||
noOfLines={1}
|
||||
/>
|
||||
<EditableInput
|
||||
sx={{
|
||||
p: 0,
|
||||
_focusVisible: {
|
||||
p: 0,
|
||||
textAlign: 'center',
|
||||
// get rid of the edit border
|
||||
boxShadow: 'none',
|
||||
},
|
||||
}}
|
||||
/>
|
||||
</Editable>
|
||||
</Flex>
|
||||
|
||||
<IAIDroppable
|
||||
data={droppableData}
|
||||
dropLabel={<Text fontSize="md">Move</Text>}
|
||||
/>
|
||||
</Flex>
|
||||
|
||||
<Flex
|
||||
sx={{
|
||||
width: 'full',
|
||||
height: 'full',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
}}
|
||||
>
|
||||
<Editable
|
||||
defaultValue={board_name}
|
||||
submitOnBlur={true}
|
||||
onSubmit={(nextValue) => {
|
||||
handleUpdateBoardName(nextValue);
|
||||
}}
|
||||
sx={{ maxW: 'full' }}
|
||||
>
|
||||
<EditablePreview
|
||||
sx={{
|
||||
color: isSelected
|
||||
? mode('base.900', 'base.50')(colorMode)
|
||||
: mode('base.700', 'base.200')(colorMode),
|
||||
fontWeight: isSelected ? 600 : undefined,
|
||||
fontSize: 'xs',
|
||||
textAlign: 'center',
|
||||
p: 0,
|
||||
overflow: 'hidden',
|
||||
textOverflow: 'ellipsis',
|
||||
}}
|
||||
noOfLines={1}
|
||||
/>
|
||||
<EditableInput
|
||||
sx={{
|
||||
color: mode('base.900', 'base.50')(colorMode),
|
||||
fontSize: 'xs',
|
||||
borderColor: mode('base.500', 'base.500')(colorMode),
|
||||
p: 0,
|
||||
outline: 0,
|
||||
}}
|
||||
/>
|
||||
</Editable>
|
||||
</Flex>
|
||||
</Flex>
|
||||
)}
|
||||
</BoardContextMenu>
|
||||
)}
|
||||
</BoardContextMenu>
|
||||
</Flex>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
@ -17,7 +17,7 @@ type GenericBoardProps = {
|
||||
badgeCount?: number;
|
||||
};
|
||||
|
||||
const formatBadgeCount = (count: number) =>
|
||||
export const formatBadgeCount = (count: number) =>
|
||||
Intl.NumberFormat('en-US', {
|
||||
notation: 'compact',
|
||||
maximumFractionDigits: 1,
|
||||
@ -92,7 +92,7 @@ const GenericBoard = (props: GenericBoardProps) => {
|
||||
h: 'full',
|
||||
alignItems: 'center',
|
||||
fontWeight: isSelected ? 600 : undefined,
|
||||
fontSize: 'xs',
|
||||
fontSize: 'sm',
|
||||
color: isSelected ? 'base.900' : 'base.700',
|
||||
_dark: { color: isSelected ? 'base.50' : 'base.200' },
|
||||
}}
|
||||
|
@ -0,0 +1,53 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import { boardIdSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useBoardName } from 'services/api/hooks/useBoardName';
|
||||
|
||||
type Props = {
|
||||
board_id: 'images' | 'assets' | 'no_board';
|
||||
};
|
||||
|
||||
const SystemBoardButton = ({ board_id }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
[stateSelector],
|
||||
({ gallery }) => {
|
||||
const { selectedBoardId } = gallery;
|
||||
return { isSelected: selectedBoardId === board_id };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[board_id]
|
||||
);
|
||||
|
||||
const { isSelected } = useAppSelector(selector);
|
||||
|
||||
const boardName = useBoardName(board_id);
|
||||
|
||||
const handleClick = useCallback(() => {
|
||||
dispatch(boardIdSelected(board_id));
|
||||
}, [board_id, dispatch]);
|
||||
|
||||
return (
|
||||
<IAIButton
|
||||
onClick={handleClick}
|
||||
size="sm"
|
||||
isChecked={isSelected}
|
||||
sx={{
|
||||
flexGrow: 1,
|
||||
borderRadius: 'base',
|
||||
}}
|
||||
>
|
||||
{boardName}
|
||||
</IAIButton>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SystemBoardButton);
|
@ -4,8 +4,9 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { memo } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useBoardName } from 'services/api/hooks/useBoardName';
|
||||
import { useBoardTotal } from 'services/api/hooks/useBoardTotal';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
@ -26,6 +27,17 @@ const GalleryBoardName = (props: Props) => {
|
||||
const { isOpen, onToggle } = props;
|
||||
const { selectedBoardId } = useAppSelector(selector);
|
||||
const boardName = useBoardName(selectedBoardId);
|
||||
const numOfBoardImages = useBoardTotal(selectedBoardId);
|
||||
|
||||
const formattedBoardName = useMemo(() => {
|
||||
if (!boardName || !numOfBoardImages) {
|
||||
return '';
|
||||
}
|
||||
if (boardName.length > 20) {
|
||||
return `${boardName.substring(0, 20)}... (${numOfBoardImages})`;
|
||||
}
|
||||
return `${boardName} (${numOfBoardImages})`;
|
||||
}, [boardName, numOfBoardImages]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
@ -58,7 +70,7 @@ const GalleryBoardName = (props: Props) => {
|
||||
},
|
||||
}}
|
||||
>
|
||||
{boardName}
|
||||
{formattedBoardName}
|
||||
</Text>
|
||||
</Box>
|
||||
<Spacer />
|
||||
|
@ -109,7 +109,7 @@ const GalleryDrawer = () => {
|
||||
isResizable={true}
|
||||
isOpen={shouldShowGallery}
|
||||
onClose={handleCloseGallery}
|
||||
minWidth={337}
|
||||
minWidth={400}
|
||||
>
|
||||
<ImageGalleryContent />
|
||||
</ResizableDrawer>
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { Box } from '@chakra-ui/react';
|
||||
import { Box, Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
@ -86,38 +86,31 @@ const GalleryImage = (props: HoverableImageProps) => {
|
||||
|
||||
return (
|
||||
<Box sx={{ w: 'full', h: 'full', touchAction: 'none' }}>
|
||||
<ImageContextMenu imageDTO={imageDTO}>
|
||||
{(ref) => (
|
||||
<Box
|
||||
position="relative"
|
||||
key={imageName}
|
||||
userSelect="none"
|
||||
ref={ref}
|
||||
sx={{
|
||||
display: 'flex',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
aspectRatio: '1/1',
|
||||
}}
|
||||
>
|
||||
<IAIDndImage
|
||||
onClick={handleClick}
|
||||
imageDTO={imageDTO}
|
||||
draggableData={draggableData}
|
||||
isSelected={isSelected}
|
||||
minSize={0}
|
||||
onClickReset={handleDelete}
|
||||
imageSx={{ w: 'full', h: 'full' }}
|
||||
isDropDisabled={true}
|
||||
isUploadDisabled={true}
|
||||
thumbnail={true}
|
||||
// resetIcon={<FaTrash />}
|
||||
// resetTooltip="Delete image"
|
||||
// withResetIcon // removed bc it's too easy to accidentally delete images
|
||||
/>
|
||||
</Box>
|
||||
)}
|
||||
</ImageContextMenu>
|
||||
<Flex
|
||||
userSelect="none"
|
||||
sx={{
|
||||
position: 'relative',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
aspectRatio: '1/1',
|
||||
}}
|
||||
>
|
||||
<IAIDndImage
|
||||
onClick={handleClick}
|
||||
imageDTO={imageDTO}
|
||||
draggableData={draggableData}
|
||||
isSelected={isSelected}
|
||||
minSize={0}
|
||||
onClickReset={handleDelete}
|
||||
imageSx={{ w: 'full', h: 'full' }}
|
||||
isDropDisabled={true}
|
||||
isUploadDisabled={true}
|
||||
thumbnail={true}
|
||||
// resetIcon={<FaTrash />}
|
||||
// resetTooltip="Delete image"
|
||||
// withResetIcon // removed bc it's too easy to accidentally delete images
|
||||
/>
|
||||
</Flex>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
@ -48,6 +48,7 @@ export const addControlNetToLinearGraph = (
|
||||
beginStepPct,
|
||||
endStepPct,
|
||||
controlMode,
|
||||
resizeMode,
|
||||
model,
|
||||
processorType,
|
||||
weight,
|
||||
@ -60,6 +61,7 @@ export const addControlNetToLinearGraph = (
|
||||
begin_step_percent: beginStepPct,
|
||||
end_step_percent: endStepPct,
|
||||
control_mode: controlMode,
|
||||
resize_mode: resizeMode,
|
||||
control_model: model as ControlNetInvocation['control_model'],
|
||||
control_weight: weight,
|
||||
};
|
||||
|
@ -105,7 +105,7 @@ const enabledTabsSelector = createSelector(
|
||||
}
|
||||
);
|
||||
|
||||
const MIN_GALLERY_WIDTH = 300;
|
||||
const MIN_GALLERY_WIDTH = 350;
|
||||
const DEFAULT_GALLERY_PCT = 20;
|
||||
export const NO_GALLERY_TABS: InvokeTabName[] = ['modelManager'];
|
||||
|
||||
|
@ -3,7 +3,7 @@ import { memo } from 'react';
|
||||
import { PanelResizeHandle } from 'react-resizable-panels';
|
||||
import { mode } from 'theme/util/mode';
|
||||
|
||||
type ResizeHandleProps = FlexProps & {
|
||||
type ResizeHandleProps = Omit<FlexProps, 'direction'> & {
|
||||
direction?: 'horizontal' | 'vertical';
|
||||
};
|
||||
|
||||
|
@ -169,10 +169,11 @@ export const imagesApi = api.injectEndpoints({
|
||||
],
|
||||
async onQueryStarted(imageDTO, { dispatch, queryFulfilled }) {
|
||||
/**
|
||||
* Cache changes for deleteImage:
|
||||
* - Remove from "All Images"
|
||||
* - Remove from image's `board_id` if it has one, or "No Board" if not
|
||||
* - Remove from "Batch"
|
||||
* Cache changes for `deleteImage`:
|
||||
* - *remove* from "All Images" / "All Assets"
|
||||
* - IF it has a board:
|
||||
* - THEN *remove* from it's own board
|
||||
* - ELSE *remove* from "No Board"
|
||||
*/
|
||||
|
||||
const { image_name, board_id, image_category } = imageDTO;
|
||||
@ -181,22 +182,23 @@ export const imagesApi = api.injectEndpoints({
|
||||
// That means constructing the possible query args that are serialized into the cache key...
|
||||
|
||||
const removeFromCacheKeys: ListImagesArgs[] = [];
|
||||
|
||||
// determine `categories`, i.e. do we update "All Images" or "All Assets"
|
||||
const categories = IMAGE_CATEGORIES.includes(image_category)
|
||||
? IMAGE_CATEGORIES
|
||||
: ASSETS_CATEGORIES;
|
||||
|
||||
// All Images board (e.g. no board)
|
||||
// remove from "All Images"
|
||||
removeFromCacheKeys.push({ categories });
|
||||
|
||||
// Board specific
|
||||
if (board_id) {
|
||||
// remove from it's own board
|
||||
removeFromCacheKeys.push({ board_id });
|
||||
} else {
|
||||
// TODO: No Board
|
||||
// remove from "No Board"
|
||||
removeFromCacheKeys.push({ board_id: 'none' });
|
||||
}
|
||||
|
||||
// TODO: Batch
|
||||
|
||||
const patches: PatchCollection[] = [];
|
||||
removeFromCacheKeys.forEach((cacheKey) => {
|
||||
patches.push(
|
||||
@ -240,32 +242,37 @@ export const imagesApi = api.injectEndpoints({
|
||||
{ imageDTO: oldImageDTO, changes: _changes },
|
||||
{ dispatch, queryFulfilled, getState }
|
||||
) {
|
||||
// TODO: Should we handle changes to boards via this mutation? Seems reasonable...
|
||||
|
||||
// let's be extra-sure we do not accidentally change categories
|
||||
const changes = omit(_changes, 'image_category');
|
||||
|
||||
/**
|
||||
* Cache changes for `updateImage`:
|
||||
* - Update the ImageDTO
|
||||
* - Update the image in "All Images" board:
|
||||
* - IF it is in the date range represented by the cache:
|
||||
* - add the image IF it is not already in the cache & update the total
|
||||
* - ELSE update the image IF it is already in the cache
|
||||
* Cache changes for "updateImage":
|
||||
* - *update* "getImageDTO" cache
|
||||
* - for "All Images" || "All Assets":
|
||||
* - IF it is not already in the cache
|
||||
* - THEN *add* it to "All Images" / "All Assets" and update the total
|
||||
* - ELSE *update* it
|
||||
* - IF the image has a board:
|
||||
* - Update the image in it's own board
|
||||
* - ELSE Update the image in the "No Board" board (TODO)
|
||||
* - THEN *update* it's own board
|
||||
* - ELSE *update* the "No Board" board
|
||||
*/
|
||||
|
||||
const patches: PatchCollection[] = [];
|
||||
const { image_name, board_id, image_category } = oldImageDTO;
|
||||
const { image_name, board_id, image_category, is_intermediate } =
|
||||
oldImageDTO;
|
||||
|
||||
const isChangingFromIntermediate = changes.is_intermediate === false;
|
||||
// do not add intermediates to gallery cache
|
||||
if (is_intermediate && !isChangingFromIntermediate) {
|
||||
return;
|
||||
}
|
||||
|
||||
// determine `categories`, i.e. do we update "All Images" or "All Assets"
|
||||
const categories = IMAGE_CATEGORIES.includes(image_category)
|
||||
? IMAGE_CATEGORIES
|
||||
: ASSETS_CATEGORIES;
|
||||
|
||||
// TODO: No Board
|
||||
|
||||
// Update `getImageDTO` cache
|
||||
// update `getImageDTO` cache
|
||||
patches.push(
|
||||
dispatch(
|
||||
imagesApi.util.updateQueryData(
|
||||
@ -281,9 +288,13 @@ export const imagesApi = api.injectEndpoints({
|
||||
// Update the "All Image" or "All Assets" board
|
||||
const queryArgsToUpdate: ListImagesArgs[] = [{ categories }];
|
||||
|
||||
// IF the image has a board:
|
||||
if (board_id) {
|
||||
// We also need to update the user board
|
||||
// THEN update it's own board
|
||||
queryArgsToUpdate.push({ board_id });
|
||||
} else {
|
||||
// ELSE update the "No Board" board
|
||||
queryArgsToUpdate.push({ board_id: 'none' });
|
||||
}
|
||||
|
||||
queryArgsToUpdate.forEach((queryArg) => {
|
||||
@ -371,12 +382,12 @@ export const imagesApi = api.injectEndpoints({
|
||||
return;
|
||||
}
|
||||
|
||||
// Add the image to the "All Images" / "All Assets" board
|
||||
const queryArg = {
|
||||
categories: IMAGE_CATEGORIES.includes(image_category)
|
||||
? IMAGE_CATEGORIES
|
||||
: ASSETS_CATEGORIES,
|
||||
};
|
||||
// determine `categories`, i.e. do we update "All Images" or "All Assets"
|
||||
const categories = IMAGE_CATEGORIES.includes(image_category)
|
||||
? IMAGE_CATEGORIES
|
||||
: ASSETS_CATEGORIES;
|
||||
|
||||
const queryArg = { categories };
|
||||
|
||||
dispatch(
|
||||
imagesApi.util.updateQueryData('listImages', queryArg, (draft) => {
|
||||
@ -410,16 +421,14 @@ export const imagesApi = api.injectEndpoints({
|
||||
{ dispatch, queryFulfilled, getState }
|
||||
) {
|
||||
/**
|
||||
* Cache changes for addImageToBoard:
|
||||
* - Remove from "No Board"
|
||||
* - Remove from `old_board_id` if it has one
|
||||
* - Add to new `board_id`
|
||||
* - IF the image's `created_at` is within the range of the board's cached images
|
||||
* Cache changes for `addImageToBoard`:
|
||||
* - *update* the `getImageDTO` cache
|
||||
* - *remove* from "No Board"
|
||||
* - IF the image has an old `board_id`:
|
||||
* - THEN *remove* from it's old `board_id`
|
||||
* - IF the image's `created_at` is within the range of the board's cached images
|
||||
* - OR the board cache has length of 0 or 1
|
||||
* - Update the `total` for each board whose cache is updated
|
||||
* - Update the ImageDTO
|
||||
*
|
||||
* TODO: maybe total should just be updated in the boards endpoints?
|
||||
* - THEN *add* it to new `board_id`
|
||||
*/
|
||||
|
||||
const { image_name, board_id: old_board_id } = oldImageDTO;
|
||||
@ -427,13 +436,10 @@ export const imagesApi = api.injectEndpoints({
|
||||
// Figure out the `listImages` caches that we need to update
|
||||
const removeFromQueryArgs: ListImagesArgs[] = [];
|
||||
|
||||
// TODO: No Board
|
||||
// TODO: Batch
|
||||
|
||||
// Remove from No Board
|
||||
// remove from "No Board"
|
||||
removeFromQueryArgs.push({ board_id: 'none' });
|
||||
|
||||
// Remove from old board
|
||||
// remove from old board
|
||||
if (old_board_id) {
|
||||
removeFromQueryArgs.push({ board_id: old_board_id });
|
||||
}
|
||||
@ -534,17 +540,15 @@ export const imagesApi = api.injectEndpoints({
|
||||
{ dispatch, queryFulfilled, getState }
|
||||
) {
|
||||
/**
|
||||
* Cache changes for removeImageFromBoard:
|
||||
* - Add to "No Board"
|
||||
* - IF the image's `created_at` is within the range of the board's cached images
|
||||
* - Remove from `old_board_id`
|
||||
* - Update the ImageDTO
|
||||
* Cache changes for `removeImageFromBoard`:
|
||||
* - *update* `getImageDTO`
|
||||
* - IF the image's `created_at` is within the range of the board's cached images
|
||||
* - THEN *add* to "No Board"
|
||||
* - *remove* from `old_board_id`
|
||||
*/
|
||||
|
||||
const { image_name, board_id: old_board_id } = imageDTO;
|
||||
|
||||
// TODO: Batch
|
||||
|
||||
const patches: PatchCollection[] = [];
|
||||
|
||||
// Updated imageDTO with new board_id
|
||||
|
@ -6,9 +6,9 @@ export const useBoardName = (board_id: BoardId | null | undefined) => {
|
||||
selectFromResult: ({ data }) => {
|
||||
let boardName = '';
|
||||
if (board_id === 'images') {
|
||||
boardName = 'All Images';
|
||||
boardName = 'Images';
|
||||
} else if (board_id === 'assets') {
|
||||
boardName = 'All Assets';
|
||||
boardName = 'Assets';
|
||||
} else if (board_id === 'no_board') {
|
||||
boardName = 'No Board';
|
||||
} else if (board_id === 'batch') {
|
||||
|
@ -0,0 +1,53 @@
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import {
|
||||
ASSETS_CATEGORIES,
|
||||
BoardId,
|
||||
IMAGE_CATEGORIES,
|
||||
INITIAL_IMAGE_LIMIT,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { useMemo } from 'react';
|
||||
import { ListImagesArgs, useListImagesQuery } from '../endpoints/images';
|
||||
|
||||
const baseQueryArg: ListImagesArgs = {
|
||||
offset: 0,
|
||||
limit: INITIAL_IMAGE_LIMIT,
|
||||
is_intermediate: false,
|
||||
};
|
||||
|
||||
const imagesQueryArg: ListImagesArgs = {
|
||||
categories: IMAGE_CATEGORIES,
|
||||
...baseQueryArg,
|
||||
};
|
||||
|
||||
const assetsQueryArg: ListImagesArgs = {
|
||||
categories: ASSETS_CATEGORIES,
|
||||
...baseQueryArg,
|
||||
};
|
||||
|
||||
const noBoardQueryArg: ListImagesArgs = {
|
||||
board_id: 'none',
|
||||
...baseQueryArg,
|
||||
};
|
||||
|
||||
export const useBoardTotal = (board_id: BoardId | null | undefined) => {
|
||||
const queryArg = useMemo(() => {
|
||||
if (!board_id) {
|
||||
return;
|
||||
}
|
||||
if (board_id === 'images') {
|
||||
return imagesQueryArg;
|
||||
} else if (board_id === 'assets') {
|
||||
return assetsQueryArg;
|
||||
} else if (board_id === 'no_board') {
|
||||
return noBoardQueryArg;
|
||||
} else {
|
||||
return { board_id, ...baseQueryArg };
|
||||
}
|
||||
}, [board_id]);
|
||||
|
||||
const { total } = useListImagesQuery(queryArg ?? skipToken, {
|
||||
selectFromResult: ({ currentData }) => ({ total: currentData?.total }),
|
||||
});
|
||||
|
||||
return total;
|
||||
};
|
@ -167,7 +167,7 @@ export type paths = {
|
||||
"/api/v1/images/clear-intermediates": {
|
||||
/**
|
||||
* Clear Intermediates
|
||||
* @description Clears first 100 intermediates
|
||||
* @description Clears all intermediates
|
||||
*/
|
||||
post: operations["clear_intermediates"];
|
||||
};
|
||||
@ -800,6 +800,13 @@ export type components = {
|
||||
* @enum {string}
|
||||
*/
|
||||
control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced";
|
||||
/**
|
||||
* Resize Mode
|
||||
* @description The resize mode to use
|
||||
* @default just_resize
|
||||
* @enum {string}
|
||||
*/
|
||||
resize_mode?: "just_resize" | "crop_resize" | "fill_resize" | "just_resize_simple";
|
||||
};
|
||||
/**
|
||||
* ControlNetInvocation
|
||||
@ -859,6 +866,13 @@ export type components = {
|
||||
* @enum {string}
|
||||
*/
|
||||
control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced";
|
||||
/**
|
||||
* Resize Mode
|
||||
* @description The resize mode used
|
||||
* @default just_resize
|
||||
* @enum {string}
|
||||
*/
|
||||
resize_mode?: "just_resize" | "crop_resize" | "fill_resize" | "just_resize_simple";
|
||||
};
|
||||
/** ControlNetModelConfig */
|
||||
ControlNetModelConfig: {
|
||||
@ -5324,11 +5338,11 @@ export type components = {
|
||||
image?: components["schemas"]["ImageField"];
|
||||
};
|
||||
/**
|
||||
* StableDiffusion2ModelFormat
|
||||
* StableDiffusion1ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusionXLModelFormat
|
||||
* @description An enumeration.
|
||||
@ -5336,11 +5350,11 @@ export type components = {
|
||||
*/
|
||||
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusion1ModelFormat
|
||||
* StableDiffusion2ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||
};
|
||||
responses: never;
|
||||
parameters: never;
|
||||
@ -6125,7 +6139,7 @@ export type operations = {
|
||||
};
|
||||
/**
|
||||
* Clear Intermediates
|
||||
* @description Clears first 100 intermediates
|
||||
* @description Clears all intermediates
|
||||
*/
|
||||
clear_intermediates: {
|
||||
responses: {
|
||||
|
@ -2,11 +2,16 @@ import { InvokeAIThemeColors } from 'theme/themeTypes';
|
||||
import { generateColorPalette } from 'theme/util/generateColorPalette';
|
||||
|
||||
const BASE = { H: 220, S: 16 };
|
||||
const ACCENT = { H: 250, S: 52 };
|
||||
const WORKING = { H: 47, S: 50 };
|
||||
const WARNING = { H: 28, S: 50 };
|
||||
const OK = { H: 113, S: 50 };
|
||||
const ERROR = { H: 0, S: 50 };
|
||||
const ACCENT = { H: 250, S: 42 };
|
||||
// const ACCENT = { H: 250, S: 52 };
|
||||
const WORKING = { H: 47, S: 42 };
|
||||
// const WORKING = { H: 47, S: 50 };
|
||||
const WARNING = { H: 28, S: 42 };
|
||||
// const WARNING = { H: 28, S: 50 };
|
||||
const OK = { H: 113, S: 42 };
|
||||
// const OK = { H: 113, S: 50 };
|
||||
const ERROR = { H: 0, S: 42 };
|
||||
// const ERROR = { H: 0, S: 50 };
|
||||
|
||||
export const InvokeAIColors: InvokeAIThemeColors = {
|
||||
base: generateColorPalette(BASE.H, BASE.S),
|
||||
|
56
invokeai/frontend/web/src/theme/components/editable.ts
Normal file
56
invokeai/frontend/web/src/theme/components/editable.ts
Normal file
@ -0,0 +1,56 @@
|
||||
import { editableAnatomy as parts } from '@chakra-ui/anatomy';
|
||||
import {
|
||||
createMultiStyleConfigHelpers,
|
||||
defineStyle,
|
||||
} from '@chakra-ui/styled-system';
|
||||
import { mode } from '@chakra-ui/theme-tools';
|
||||
|
||||
const { definePartsStyle, defineMultiStyleConfig } =
|
||||
createMultiStyleConfigHelpers(parts.keys);
|
||||
|
||||
const baseStylePreview = defineStyle({
|
||||
borderRadius: 'md',
|
||||
py: '1',
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: 'normal',
|
||||
});
|
||||
|
||||
const baseStyleInput = defineStyle((props) => ({
|
||||
borderRadius: 'md',
|
||||
py: '1',
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: 'normal',
|
||||
width: 'full',
|
||||
_focusVisible: { boxShadow: 'outline' },
|
||||
_placeholder: { opacity: 0.6 },
|
||||
'::selection': {
|
||||
color: mode('accent.900', 'accent.50')(props),
|
||||
bg: mode('accent.200', 'accent.400')(props),
|
||||
},
|
||||
}));
|
||||
|
||||
const baseStyleTextarea = defineStyle({
|
||||
borderRadius: 'md',
|
||||
py: '1',
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: 'normal',
|
||||
width: 'full',
|
||||
_focusVisible: { boxShadow: 'outline' },
|
||||
_placeholder: { opacity: 0.6 },
|
||||
});
|
||||
|
||||
const invokeAI = definePartsStyle((props) => ({
|
||||
preview: baseStylePreview,
|
||||
input: baseStyleInput(props),
|
||||
textarea: baseStyleTextarea,
|
||||
}));
|
||||
|
||||
export const editableTheme = defineMultiStyleConfig({
|
||||
variants: {
|
||||
invokeAI,
|
||||
},
|
||||
defaultProps: {
|
||||
size: 'sm',
|
||||
variant: 'invokeAI',
|
||||
},
|
||||
});
|
@ -4,6 +4,7 @@ import { InvokeAIColors } from './colors/colors';
|
||||
import { accordionTheme } from './components/accordion';
|
||||
import { buttonTheme } from './components/button';
|
||||
import { checkboxTheme } from './components/checkbox';
|
||||
import { editableTheme } from './components/editable';
|
||||
import { formLabelTheme } from './components/formLabel';
|
||||
import { inputTheme } from './components/input';
|
||||
import { menuTheme } from './components/menu';
|
||||
@ -72,7 +73,17 @@ export const theme: ThemeOverride = {
|
||||
selected: {
|
||||
light:
|
||||
'0px 0px 0px 1px var(--invokeai-colors-base-150), 0px 0px 0px 4px var(--invokeai-colors-accent-400)',
|
||||
dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-400)',
|
||||
dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-500)',
|
||||
},
|
||||
hoverSelected: {
|
||||
light:
|
||||
'0px 0px 0px 1px var(--invokeai-colors-base-150), 0px 0px 0px 4px var(--invokeai-colors-accent-500)',
|
||||
dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-300)',
|
||||
},
|
||||
hoverUnselected: {
|
||||
light:
|
||||
'0px 0px 0px 1px var(--invokeai-colors-base-150), 0px 0px 0px 4px var(--invokeai-colors-accent-200)',
|
||||
dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-600)',
|
||||
},
|
||||
nodeSelectedOutline: `0 0 0 2px var(--invokeai-colors-accent-450)`,
|
||||
},
|
||||
@ -80,6 +91,7 @@ export const theme: ThemeOverride = {
|
||||
components: {
|
||||
Button: buttonTheme, // Button and IconButton
|
||||
Input: inputTheme,
|
||||
Editable: editableTheme,
|
||||
Textarea: textareaTheme,
|
||||
Tabs: tabsTheme,
|
||||
Progress: progressTheme,
|
||||
|
@ -37,4 +37,7 @@ export const getInputOutlineStyles = (props: StyleFunctionProps) => ({
|
||||
_placeholder: {
|
||||
color: mode('base.700', 'base.400')(props),
|
||||
},
|
||||
'::selection': {
|
||||
bg: mode('accent.200', 'accent.400')(props),
|
||||
},
|
||||
});
|
||||
|
Loading…
Reference in New Issue
Block a user