Merge branch 'main' into lstein/logger-route

This commit is contained in:
Lincoln Stein 2023-07-20 11:29:48 -04:00 committed by GitHub
commit 5134de7cfa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 1098 additions and 360 deletions

View File

@ -85,8 +85,8 @@ CONTROLNET_DEFAULT_MODELS = [
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)] CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
CONTROLNET_MODE_VALUES = Literal[tuple( CONTROLNET_MODE_VALUES = Literal[tuple(
["balanced", "more_prompt", "more_control", "unbalanced"])] ["balanced", "more_prompt", "more_control", "unbalanced"])]
# crop and fill options not ready yet CONTROLNET_RESIZE_VALUES = Literal[tuple(
# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])] ["just_resize", "crop_resize", "fill_resize", "just_resize_simple",])]
class ControlNetModelField(BaseModel): class ControlNetModelField(BaseModel):
@ -111,7 +111,8 @@ class ControlField(BaseModel):
description="When the ControlNet is last applied (% of total steps)") description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field( control_mode: CONTROLNET_MODE_VALUES = Field(
default="balanced", description="The control mode to use") default="balanced", description="The control mode to use")
# resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use") resize_mode: CONTROLNET_RESIZE_VALUES = Field(
default="just_resize", description="The resize mode to use")
@validator("control_weight") @validator("control_weight")
def validate_control_weight(cls, v): def validate_control_weight(cls, v):
@ -161,6 +162,7 @@ class ControlNetInvocation(BaseInvocation):
end_step_percent: float = Field(default=1, ge=0, le=1, end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)") description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used") control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode used")
# fmt: on # fmt: on
class Config(InvocationConfig): class Config(InvocationConfig):
@ -187,6 +189,7 @@ class ControlNetInvocation(BaseInvocation):
begin_step_percent=self.begin_step_percent, begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent, end_step_percent=self.end_step_percent,
control_mode=self.control_mode, control_mode=self.control_mode,
resize_mode=self.resize_mode,
), ),
) )

View File

@ -30,6 +30,7 @@ from .compel import ConditioningField
from .controlnet_image_processors import ControlField from .controlnet_image_processors import ControlField
from .image import ImageOutput from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField from .model import ModelInfo, UNetField, VaeField
from invokeai.app.util.controlnet_utils import prepare_control_image
from diffusers.models.attention_processor import ( from diffusers.models.attention_processor import (
AttnProcessor2_0, AttnProcessor2_0,
@ -288,7 +289,7 @@ class TextToLatentsInvocation(BaseInvocation):
# and add in batch_size, num_images_per_prompt? # and add in batch_size, num_images_per_prompt?
# and do real check for classifier_free_guidance? # and do real check for classifier_free_guidance?
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width) # prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
control_image = model.prepare_control_image( control_image = prepare_control_image(
image=input_image, image=input_image,
do_classifier_free_guidance=do_classifier_free_guidance, do_classifier_free_guidance=do_classifier_free_guidance,
width=control_width_resize, width=control_width_resize,
@ -298,13 +299,18 @@ class TextToLatentsInvocation(BaseInvocation):
device=control_model.device, device=control_model.device,
dtype=control_model.dtype, dtype=control_model.dtype,
control_mode=control_info.control_mode, control_mode=control_info.control_mode,
resize_mode=control_info.resize_mode,
) )
control_item = ControlNetData( control_item = ControlNetData(
model=control_model, image_tensor=control_image, model=control_model,
image_tensor=control_image,
weight=control_info.control_weight, weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent, begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent, end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode, control_mode=control_info.control_mode,
# any resizing needed should currently be happening in prepare_control_image(),
# but adding resize_mode to ControlNetData in case needed in the future
resize_mode=control_info.resize_mode,
) )
control_data.append(control_item) control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData] # MultiControlNetModel has been refactored out, just need list[ControlNetData]

View File

@ -0,0 +1,342 @@
import torch
import numpy as np
import cv2
from PIL import Image
from diffusers.utils import PIL_INTERPOLATION
from einops import rearrange
from controlnet_aux.util import HWC3, resize_image
###################################################################
# Copy of scripts/lvminthin.py from Mikubill/sd-webui-controlnet
###################################################################
# High Quality Edge Thinning using Pure Python
# Written by Lvmin Zhangu
# 2023 April
# Stanford University
# If you use this, please Cite "High Quality Edge Thinning using Pure Python", Lvmin Zhang, In Mikubill/sd-webui-controlnet.
lvmin_kernels_raw = [
np.array([
[-1, -1, -1],
[0, 1, 0],
[1, 1, 1]
], dtype=np.int32),
np.array([
[0, -1, -1],
[1, 1, -1],
[0, 1, 0]
], dtype=np.int32)
]
lvmin_kernels = []
lvmin_kernels += [np.rot90(x, k=0, axes=(0, 1)) for x in lvmin_kernels_raw]
lvmin_kernels += [np.rot90(x, k=1, axes=(0, 1)) for x in lvmin_kernels_raw]
lvmin_kernels += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_kernels_raw]
lvmin_kernels += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_kernels_raw]
lvmin_prunings_raw = [
np.array([
[-1, -1, -1],
[-1, 1, -1],
[0, 0, -1]
], dtype=np.int32),
np.array([
[-1, -1, -1],
[-1, 1, -1],
[-1, 0, 0]
], dtype=np.int32)
]
lvmin_prunings = []
lvmin_prunings += [np.rot90(x, k=0, axes=(0, 1)) for x in lvmin_prunings_raw]
lvmin_prunings += [np.rot90(x, k=1, axes=(0, 1)) for x in lvmin_prunings_raw]
lvmin_prunings += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_prunings_raw]
lvmin_prunings += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_prunings_raw]
def remove_pattern(x, kernel):
objects = cv2.morphologyEx(x, cv2.MORPH_HITMISS, kernel)
objects = np.where(objects > 127)
x[objects] = 0
return x, objects[0].shape[0] > 0
def thin_one_time(x, kernels):
y = x
is_done = True
for k in kernels:
y, has_update = remove_pattern(y, k)
if has_update:
is_done = False
return y, is_done
def lvmin_thin(x, prunings=True):
y = x
for i in range(32):
y, is_done = thin_one_time(y, lvmin_kernels)
if is_done:
break
if prunings:
y, _ = thin_one_time(y, lvmin_prunings)
return y
def nake_nms(x):
f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
y = np.zeros_like(x)
for f in [f1, f2, f3, f4]:
np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
return y
################################################################################
# copied from Mikubill/sd-webui-controlnet external_code.py and modified for InvokeAI
################################################################################
# FIXME: not using yet, if used in the future will most likely require modification of preprocessors
def pixel_perfect_resolution(
image: np.ndarray,
target_H: int,
target_W: int,
resize_mode: str,
) -> int:
"""
Calculate the estimated resolution for resizing an image while preserving aspect ratio.
The function first calculates scaling factors for height and width of the image based on the target
height and width. Then, based on the chosen resize mode, it either takes the smaller or the larger
scaling factor to estimate the new resolution.
If the resize mode is OUTER_FIT, the function uses the smaller scaling factor, ensuring the whole image
fits within the target dimensions, potentially leaving some empty space.
If the resize mode is not OUTER_FIT, the function uses the larger scaling factor, ensuring the target
dimensions are fully filled, potentially cropping the image.
After calculating the estimated resolution, the function prints some debugging information.
Args:
image (np.ndarray): A 3D numpy array representing an image. The dimensions represent [height, width, channels].
target_H (int): The target height for the image.
target_W (int): The target width for the image.
resize_mode (ResizeMode): The mode for resizing.
Returns:
int: The estimated resolution after resizing.
"""
raw_H, raw_W, _ = image.shape
k0 = float(target_H) / float(raw_H)
k1 = float(target_W) / float(raw_W)
if resize_mode == "fill_resize":
estimation = min(k0, k1) * float(min(raw_H, raw_W))
else: # "crop_resize" or "just_resize" (or possibly "just_resize_simple"?)
estimation = max(k0, k1) * float(min(raw_H, raw_W))
# print(f"Pixel Perfect Computation:")
# print(f"resize_mode = {resize_mode}")
# print(f"raw_H = {raw_H}")
# print(f"raw_W = {raw_W}")
# print(f"target_H = {target_H}")
# print(f"target_W = {target_W}")
# print(f"estimation = {estimation}")
return int(np.round(estimation))
###########################################################################
# Copied from detectmap_proc method in scripts/detectmap_proc.py in Mikubill/sd-webui-controlnet
# modified for InvokeAI
###########################################################################
# def detectmap_proc(detected_map, module, resize_mode, h, w):
def np_img_resize(
np_img: np.ndarray,
resize_mode: str,
h: int,
w: int,
device: torch.device = torch.device('cpu')
):
# if 'inpaint' in module:
# np_img = np_img.astype(np.float32)
# else:
# np_img = HWC3(np_img)
np_img = HWC3(np_img)
def safe_numpy(x):
# A very safe method to make sure that Apple/Mac works
y = x
# below is very boring but do not change these. If you change these Apple or Mac may fail.
y = y.copy()
y = np.ascontiguousarray(y)
y = y.copy()
return y
def get_pytorch_control(x):
# A very safe method to make sure that Apple/Mac works
y = x
# below is very boring but do not change these. If you change these Apple or Mac may fail.
y = torch.from_numpy(y)
y = y.float() / 255.0
y = rearrange(y, 'h w c -> 1 c h w')
y = y.clone()
# y = y.to(devices.get_device_for("controlnet"))
y = y.to(device)
y = y.clone()
return y
def high_quality_resize(x: np.ndarray,
size):
# Written by lvmin
# Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges
inpaint_mask = None
if x.ndim == 3 and x.shape[2] == 4:
inpaint_mask = x[:, :, 3]
x = x[:, :, 0:3]
new_size_is_smaller = (size[0] * size[1]) < (x.shape[0] * x.shape[1])
new_size_is_bigger = (size[0] * size[1]) > (x.shape[0] * x.shape[1])
unique_color_count = np.unique(x.reshape(-1, x.shape[2]), axis=0).shape[0]
is_one_pixel_edge = False
is_binary = False
if unique_color_count == 2:
is_binary = np.min(x) < 16 and np.max(x) > 240
if is_binary:
xc = x
xc = cv2.erode(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
xc = cv2.dilate(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
one_pixel_edge_count = np.where(xc < x)[0].shape[0]
all_edge_count = np.where(x > 127)[0].shape[0]
is_one_pixel_edge = one_pixel_edge_count * 2 > all_edge_count
if 2 < unique_color_count < 200:
interpolation = cv2.INTER_NEAREST
elif new_size_is_smaller:
interpolation = cv2.INTER_AREA
else:
interpolation = cv2.INTER_CUBIC # Must be CUBIC because we now use nms. NEVER CHANGE THIS
y = cv2.resize(x, size, interpolation=interpolation)
if inpaint_mask is not None:
inpaint_mask = cv2.resize(inpaint_mask, size, interpolation=interpolation)
if is_binary:
y = np.mean(y.astype(np.float32), axis=2).clip(0, 255).astype(np.uint8)
if is_one_pixel_edge:
y = nake_nms(y)
_, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
y = lvmin_thin(y, prunings=new_size_is_bigger)
else:
_, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
y = np.stack([y] * 3, axis=2)
if inpaint_mask is not None:
inpaint_mask = (inpaint_mask > 127).astype(np.float32) * 255.0
inpaint_mask = inpaint_mask[:, :, None].clip(0, 255).astype(np.uint8)
y = np.concatenate([y, inpaint_mask], axis=2)
return y
# if resize_mode == external_code.ResizeMode.RESIZE:
if resize_mode == "just_resize": # RESIZE
np_img = high_quality_resize(np_img, (w, h))
np_img = safe_numpy(np_img)
return get_pytorch_control(np_img), np_img
old_h, old_w, _ = np_img.shape
old_w = float(old_w)
old_h = float(old_h)
k0 = float(h) / old_h
k1 = float(w) / old_w
safeint = lambda x: int(np.round(x))
# if resize_mode == external_code.ResizeMode.OUTER_FIT:
if resize_mode == "fill_resize": # OUTER_FIT
k = min(k0, k1)
borders = np.concatenate([np_img[0, :, :], np_img[-1, :, :], np_img[:, 0, :], np_img[:, -1, :]], axis=0)
high_quality_border_color = np.median(borders, axis=0).astype(np_img.dtype)
if len(high_quality_border_color) == 4:
# Inpaint hijack
high_quality_border_color[3] = 255
high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1])
np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
new_h, new_w, _ = np_img.shape
pad_h = max(0, (h - new_h) // 2)
pad_w = max(0, (w - new_w) // 2)
high_quality_background[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = np_img
np_img = high_quality_background
np_img = safe_numpy(np_img)
return get_pytorch_control(np_img), np_img
else: # resize_mode == "crop_resize" (INNER_FIT)
k = max(k0, k1)
np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
new_h, new_w, _ = np_img.shape
pad_h = max(0, (new_h - h) // 2)
pad_w = max(0, (new_w - w) // 2)
np_img = np_img[pad_h:pad_h + h, pad_w:pad_w + w]
np_img = safe_numpy(np_img)
return get_pytorch_control(np_img), np_img
def prepare_control_image(
# image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]]
# but now should be able to assume that image is a single PIL.Image, which simplifies things
image: Image,
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
# latents_to_match_resolution, # TorchTensor of shape (batch_size, 3, height, width)
width=512, # should be 8 * latent.shape[3]
height=512, # should be 8 * latent height[2]
# batch_size=1, # currently no batching
# num_images_per_prompt=1, # currently only single image
device="cuda",
dtype=torch.float16,
do_classifier_free_guidance=True,
control_mode="balanced",
resize_mode="just_resize_simple",
):
# FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out
if (resize_mode == "just_resize_simple" or
resize_mode == "crop_resize_simple" or
resize_mode == "fill_resize_simple"):
image = image.convert("RGB")
if (resize_mode == "just_resize_simple"):
image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
elif (resize_mode == "crop_resize_simple"): # not yet implemented
pass
elif (resize_mode == "fill_resize_simple"): # not yet implemented
pass
nimage = np.array(image)
nimage = nimage[None, :]
nimage = np.concatenate([nimage], axis=0)
# normalizing RGB values to [0,1] range (in PIL.Image they are [0-255])
nimage = np.array(nimage).astype(np.float32) / 255.0
nimage = nimage.transpose(0, 3, 1, 2)
timage = torch.from_numpy(nimage)
# use fancy lvmin controlnet resizing
elif (resize_mode == "just_resize" or resize_mode == "crop_resize" or resize_mode == "fill_resize"):
nimage = np.array(image)
timage, nimage = np_img_resize(
np_img=nimage,
resize_mode=resize_mode,
h=height,
w=width,
# device=torch.device('cpu')
device=device,
)
else:
pass
print("ERROR: invalid resize_mode ==> ", resize_mode)
exit(1)
timage = timage.to(device=device, dtype=dtype)
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
if do_classifier_free_guidance and not cfg_injection:
timage = torch.cat([timage] * 2)
return timage

View File

@ -219,6 +219,7 @@ class ControlNetData:
begin_step_percent: float = Field(default=0.0) begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0) end_step_percent: float = Field(default=1.0)
control_mode: str = Field(default="balanced") control_mode: str = Field(default="balanced")
resize_mode: str = Field(default="just_resize")
@dataclass @dataclass
@ -653,7 +654,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if cfg_injection: if cfg_injection:
# Inferred ControlNet only for the conditional batch. # Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches, # To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged. # prepend zeros for unconditional batch
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples] down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample]) mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
@ -954,53 +955,3 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
debug_image( debug_image(
img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
) )
# Copied from diffusers pipeline_stable_diffusion_controlnet.py
# Returns torch.Tensor of shape (batch_size, 3, height, width)
@staticmethod
def prepare_control_image(
image,
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
# latents,
width=512, # should be 8 * latent.shape[3]
height=512, # should be 8 * latent height[2]
batch_size=1,
num_images_per_prompt=1,
device="cuda",
dtype=torch.float16,
do_classifier_free_guidance=True,
control_mode="balanced"
):
if not isinstance(image, torch.Tensor):
if isinstance(image, PIL.Image.Image):
image = [image]
if isinstance(image[0], PIL.Image.Image):
images = []
for image_ in image:
image_ = image_.convert("RGB")
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
image_ = np.array(image_)
image_ = image_[None, :]
images.append(image_)
image = images
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
if do_classifier_free_guidance and not cfg_injection:
image = torch.cat([image] * 2)
return image

View File

@ -17,13 +17,13 @@ import {
} from 'common/components/IAIImageFallback'; } from 'common/components/IAIImageFallback';
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay'; import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton'; import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import { MouseEvent, ReactElement, SyntheticEvent, memo } from 'react'; import { MouseEvent, ReactElement, SyntheticEvent, memo } from 'react';
import { FaImage, FaUndo, FaUpload } from 'react-icons/fa'; import { FaImage, FaUndo, FaUpload } from 'react-icons/fa';
import { ImageDTO, PostUploadAction } from 'services/api/types'; import { ImageDTO, PostUploadAction } from 'services/api/types';
import { mode } from 'theme/util/mode'; import { mode } from 'theme/util/mode';
import IAIDraggable from './IAIDraggable'; import IAIDraggable from './IAIDraggable';
import IAIDroppable from './IAIDroppable'; import IAIDroppable from './IAIDroppable';
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
type IAIDndImageProps = { type IAIDndImageProps = {
imageDTO: ImageDTO | undefined; imageDTO: ImageDTO | undefined;
@ -148,7 +148,9 @@ const IAIDndImage = (props: IAIDndImageProps) => {
maxH: 'full', maxH: 'full',
borderRadius: 'base', borderRadius: 'base',
shadow: isSelected ? 'selected.light' : undefined, shadow: isSelected ? 'selected.light' : undefined,
_dark: { shadow: isSelected ? 'selected.dark' : undefined }, _dark: {
shadow: isSelected ? 'selected.dark' : undefined,
},
...imageSx, ...imageSx,
}} }}
/> />
@ -183,13 +185,6 @@ const IAIDndImage = (props: IAIDndImageProps) => {
</> </>
)} )}
{!imageDTO && isUploadDisabled && noContentFallback} {!imageDTO && isUploadDisabled && noContentFallback}
{!isDropDisabled && (
<IAIDroppable
data={droppableData}
disabled={isDropDisabled}
dropLabel={dropLabel}
/>
)}
{imageDTO && !isDragDisabled && ( {imageDTO && !isDragDisabled && (
<IAIDraggable <IAIDraggable
data={draggableData} data={draggableData}
@ -197,6 +192,13 @@ const IAIDndImage = (props: IAIDndImageProps) => {
onClick={onClick} onClick={onClick}
/> />
)} )}
{!isDropDisabled && (
<IAIDroppable
data={droppableData}
disabled={isDropDisabled}
dropLabel={dropLabel}
/>
)}
{onClickReset && withResetIcon && imageDTO && ( {onClickReset && withResetIcon && imageDTO && (
<IAIIconButton <IAIIconButton
onClick={onClickReset} onClick={onClickReset}

View File

@ -13,10 +13,11 @@ type IAIDroppableProps = {
dropLabel?: ReactNode; dropLabel?: ReactNode;
disabled?: boolean; disabled?: boolean;
data?: TypesafeDroppableData; data?: TypesafeDroppableData;
hoverRef?: React.Ref<HTMLDivElement>;
}; };
const IAIDroppable = (props: IAIDroppableProps) => { const IAIDroppable = (props: IAIDroppableProps) => {
const { dropLabel, data, disabled } = props; const { dropLabel, data, disabled, hoverRef } = props;
const dndId = useRef(uuidv4()); const dndId = useRef(uuidv4());
const { isOver, setNodeRef, active } = useDroppable({ const { isOver, setNodeRef, active } = useDroppable({

View File

@ -24,6 +24,7 @@ import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd'; import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode'; import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect'; import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
import ParamControlNetResizeMode from './parameters/ParamControlNetResizeMode';
type ControlNetProps = { type ControlNetProps = {
controlNetId: string; controlNetId: string;
@ -68,7 +69,7 @@ const ControlNet = (props: ControlNetProps) => {
<Flex <Flex
sx={{ sx={{
flexDir: 'column', flexDir: 'column',
gap: 2, gap: 3,
p: 3, p: 3,
borderRadius: 'base', borderRadius: 'base',
position: 'relative', position: 'relative',
@ -117,7 +118,12 @@ const ControlNet = (props: ControlNetProps) => {
tooltip={isExpanded ? 'Hide Advanced' : 'Show Advanced'} tooltip={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
aria-label={isExpanded ? 'Hide Advanced' : 'Show Advanced'} aria-label={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
onClick={toggleIsExpanded} onClick={toggleIsExpanded}
variant="link" variant="ghost"
sx={{
_hover: {
bg: 'none',
},
}}
icon={ icon={
<ChevronUpIcon <ChevronUpIcon
sx={{ sx={{
@ -151,7 +157,7 @@ const ControlNet = (props: ControlNetProps) => {
/> />
)} )}
</Flex> </Flex>
<Flex sx={{ w: 'full', flexDirection: 'column' }}> <Flex sx={{ w: 'full', flexDirection: 'column', gap: 3 }}>
<Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}> <Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}>
<Flex <Flex
sx={{ sx={{
@ -176,16 +182,16 @@ const ControlNet = (props: ControlNetProps) => {
h: 28, h: 28,
w: 28, w: 28,
aspectRatio: '1/1', aspectRatio: '1/1',
mt: 3,
}} }}
> >
<ControlNetImagePreview controlNetId={controlNetId} height={28} /> <ControlNetImagePreview controlNetId={controlNetId} height={28} />
</Flex> </Flex>
)} )}
</Flex> </Flex>
<Box mt={2}> <Flex sx={{ gap: 2 }}>
<ParamControlNetControlMode controlNetId={controlNetId} /> <ParamControlNetControlMode controlNetId={controlNetId} />
</Box> <ParamControlNetResizeMode controlNetId={controlNetId} />
</Flex>
<ParamControlNetProcessorSelect controlNetId={controlNetId} /> <ParamControlNetProcessorSelect controlNetId={controlNetId} />
</Flex> </Flex>

View File

@ -0,0 +1,62 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
ResizeModes,
controlNetResizeModeChanged,
} from 'features/controlNet/store/controlNetSlice';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
type ParamControlNetResizeModeProps = {
controlNetId: string;
};
const RESIZE_MODE_DATA = [
{ label: 'Resize', value: 'just_resize' },
{ label: 'Crop', value: 'crop_resize' },
{ label: 'Fill', value: 'fill_resize' },
];
export default function ParamControlNetResizeMode(
props: ParamControlNetResizeModeProps
) {
const { controlNetId } = props;
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { resizeMode, isEnabled } =
controlNet.controlNets[controlNetId];
return { resizeMode, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { resizeMode, isEnabled } = useAppSelector(selector);
const { t } = useTranslation();
const handleResizeModeChange = useCallback(
(resizeMode: ResizeModes) => {
dispatch(controlNetResizeModeChanged({ controlNetId, resizeMode }));
},
[controlNetId, dispatch]
);
return (
<IAIMantineSelect
disabled={!isEnabled}
label="Resize Mode"
data={RESIZE_MODE_DATA}
value={String(resizeMode)}
onChange={handleResizeModeChange}
/>
);
}

View File

@ -3,6 +3,7 @@ import { RootState } from 'app/store/store';
import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas'; import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas';
import { cloneDeep, forEach } from 'lodash-es'; import { cloneDeep, forEach } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
import { components } from 'services/api/schema';
import { isAnySessionRejected } from 'services/api/thunks/session'; import { isAnySessionRejected } from 'services/api/thunks/session';
import { appSocketInvocationError } from 'services/events/actions'; import { appSocketInvocationError } from 'services/events/actions';
import { controlNetImageProcessed } from './actions'; import { controlNetImageProcessed } from './actions';
@ -16,11 +17,13 @@ import {
RequiredControlNetProcessorNode, RequiredControlNetProcessorNode,
} from './types'; } from './types';
export type ControlModes = export type ControlModes = NonNullable<
| 'balanced' components['schemas']['ControlNetInvocation']['control_mode']
| 'more_prompt' >;
| 'more_control'
| 'unbalanced'; export type ResizeModes = NonNullable<
components['schemas']['ControlNetInvocation']['resize_mode']
>;
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = { export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true, isEnabled: true,
@ -29,6 +32,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
beginStepPct: 0, beginStepPct: 0,
endStepPct: 1, endStepPct: 1,
controlMode: 'balanced', controlMode: 'balanced',
resizeMode: 'just_resize',
controlImage: null, controlImage: null,
processedControlImage: null, processedControlImage: null,
processorType: 'canny_image_processor', processorType: 'canny_image_processor',
@ -45,6 +49,7 @@ export type ControlNetConfig = {
beginStepPct: number; beginStepPct: number;
endStepPct: number; endStepPct: number;
controlMode: ControlModes; controlMode: ControlModes;
resizeMode: ResizeModes;
controlImage: string | null; controlImage: string | null;
processedControlImage: string | null; processedControlImage: string | null;
processorType: ControlNetProcessorType; processorType: ControlNetProcessorType;
@ -215,6 +220,16 @@ export const controlNetSlice = createSlice({
const { controlNetId, controlMode } = action.payload; const { controlNetId, controlMode } = action.payload;
state.controlNets[controlNetId].controlMode = controlMode; state.controlNets[controlNetId].controlMode = controlMode;
}, },
controlNetResizeModeChanged: (
state,
action: PayloadAction<{
controlNetId: string;
resizeMode: ResizeModes;
}>
) => {
const { controlNetId, resizeMode } = action.payload;
state.controlNets[controlNetId].resizeMode = resizeMode;
},
controlNetProcessorParamsChanged: ( controlNetProcessorParamsChanged: (
state, state,
action: PayloadAction<{ action: PayloadAction<{
@ -342,6 +357,7 @@ export const {
controlNetBeginStepPctChanged, controlNetBeginStepPctChanged,
controlNetEndStepPctChanged, controlNetEndStepPctChanged,
controlNetControlModeChanged, controlNetControlModeChanged,
controlNetResizeModeChanged,
controlNetProcessorParamsChanged, controlNetProcessorParamsChanged,
controlNetProcessorTypeChanged, controlNetProcessorTypeChanged,
controlNetReset, controlNetReset,

View File

@ -23,7 +23,6 @@ const BoardContextMenu = memo(
dispatch(boardIdSelected(board?.board_id ?? board_id)); dispatch(boardIdSelected(board?.board_id ?? board_id));
}, [board?.board_id, board_id, dispatch]); }, [board?.board_id, board_id, dispatch]);
return ( return (
<Box sx={{ touchAction: 'none', height: 'full' }}>
<ContextMenu<HTMLDivElement> <ContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }} menuProps={{ size: 'sm', isLazy: true }}
menuButtonProps={{ menuButtonProps={{
@ -50,7 +49,6 @@ const BoardContextMenu = memo(
> >
{children} {children}
</ContextMenu> </ContextMenu>
</Box>
); );
} }
); );

View File

@ -1,27 +1,21 @@
import { import { ButtonGroup, Collapse, Flex, Grid, GridItem } from '@chakra-ui/react';
Collapse,
Flex,
Grid,
GridItem,
useDisclosure,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIIconButton from 'common/components/IAIIconButton';
import { AnimatePresence, motion } from 'framer-motion';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useState } from 'react'; import { memo, useCallback, useState } from 'react';
import { FaSearch } from 'react-icons/fa';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
import { BoardDTO } from 'services/api/types';
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus'; import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
import DeleteBoardModal from '../DeleteBoardModal';
import AddBoardButton from './AddBoardButton'; import AddBoardButton from './AddBoardButton';
import AllAssetsBoard from './AllAssetsBoard';
import AllImagesBoard from './AllImagesBoard';
import BatchBoard from './BatchBoard';
import BoardsSearch from './BoardsSearch'; import BoardsSearch from './BoardsSearch';
import GalleryBoard from './GalleryBoard'; import GalleryBoard from './GalleryBoard';
import NoBoardBoard from './NoBoardBoard'; import SystemBoardButton from './SystemBoardButton';
import DeleteBoardModal from '../DeleteBoardModal';
import { BoardDTO } from 'services/api/types';
const selector = createSelector( const selector = createSelector(
[stateSelector], [stateSelector],
@ -48,7 +42,10 @@ const BoardsList = (props: Props) => {
) )
: boards; : boards;
const [boardToDelete, setBoardToDelete] = useState<BoardDTO>(); const [boardToDelete, setBoardToDelete] = useState<BoardDTO>();
const [searchMode, setSearchMode] = useState(false); const [isSearching, setIsSearching] = useState(false);
const handleClickSearchIcon = useCallback(() => {
setIsSearching((v) => !v);
}, []);
return ( return (
<> <>
@ -64,7 +61,54 @@ const BoardsList = (props: Props) => {
}} }}
> >
<Flex sx={{ gap: 2, alignItems: 'center' }}> <Flex sx={{ gap: 2, alignItems: 'center' }}>
<BoardsSearch setSearchMode={setSearchMode} /> <AnimatePresence mode="popLayout">
{isSearching ? (
<motion.div
key="boards-search"
initial={{
opacity: 0,
}}
exit={{
opacity: 0,
}}
animate={{
opacity: 1,
transition: { duration: 0.1 },
}}
style={{ width: '100%' }}
>
<BoardsSearch setIsSearching={setIsSearching} />
</motion.div>
) : (
<motion.div
key="system-boards-select"
initial={{
opacity: 0,
}}
exit={{
opacity: 0,
}}
animate={{
opacity: 1,
transition: { duration: 0.1 },
}}
style={{ width: '100%' }}
>
<ButtonGroup sx={{ w: 'full', ps: 1.5 }} isAttached>
<SystemBoardButton board_id="images" />
<SystemBoardButton board_id="assets" />
<SystemBoardButton board_id="no_board" />
</ButtonGroup>
</motion.div>
)}
</AnimatePresence>
<IAIIconButton
aria-label="Search Boards"
size="sm"
isChecked={isSearching}
onClick={handleClickSearchIcon}
icon={<FaSearch />}
/>
<AddBoardButton /> <AddBoardButton />
</Flex> </Flex>
<OverlayScrollbarsComponent <OverlayScrollbarsComponent
@ -82,29 +126,10 @@ const BoardsList = (props: Props) => {
<Grid <Grid
className="list-container" className="list-container"
sx={{ sx={{
gridTemplateRows: '6.5rem 6.5rem', gridTemplateColumns: `repeat(auto-fill, minmax(96px, 1fr));`,
gridAutoFlow: 'column dense', maxH: 346,
gridAutoColumns: '5rem',
}} }}
> >
{!searchMode && (
<>
<GridItem sx={{ p: 1.5 }}>
<AllImagesBoard isSelected={selectedBoardId === 'images'} />
</GridItem>
<GridItem sx={{ p: 1.5 }}>
<AllAssetsBoard isSelected={selectedBoardId === 'assets'} />
</GridItem>
<GridItem sx={{ p: 1.5 }}>
<NoBoardBoard isSelected={selectedBoardId === 'no_board'} />
</GridItem>
{isBatchEnabled && (
<GridItem sx={{ p: 1.5 }}>
<BatchBoard isSelected={selectedBoardId === 'batch'} />
</GridItem>
)}
</>
)}
{filteredBoards && {filteredBoards &&
filteredBoards.map((board) => ( filteredBoards.map((board) => (
<GridItem key={board.board_id} sx={{ p: 1.5 }}> <GridItem key={board.board_id} sx={{ p: 1.5 }}>

View File

@ -10,7 +10,14 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { setBoardSearchText } from 'features/gallery/store/boardSlice'; import { setBoardSearchText } from 'features/gallery/store/boardSlice';
import { memo } from 'react'; import {
ChangeEvent,
KeyboardEvent,
memo,
useCallback,
useEffect,
useRef,
} from 'react';
const selector = createSelector( const selector = createSelector(
[stateSelector], [stateSelector],
@ -22,31 +29,60 @@ const selector = createSelector(
); );
type Props = { type Props = {
setSearchMode: (searchMode: boolean) => void; setIsSearching: (isSearching: boolean) => void;
}; };
const BoardsSearch = (props: Props) => { const BoardsSearch = (props: Props) => {
const { setSearchMode } = props; const { setIsSearching } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { searchText } = useAppSelector(selector); const { searchText } = useAppSelector(selector);
const inputRef = useRef<HTMLInputElement>(null);
const handleBoardSearch = (searchTerm: string) => { const handleBoardSearch = useCallback(
setSearchMode(searchTerm.length > 0); (searchTerm: string) => {
dispatch(setBoardSearchText(searchTerm)); dispatch(setBoardSearchText(searchTerm));
}; },
const clearBoardSearch = () => { [dispatch]
setSearchMode(false); );
const clearBoardSearch = useCallback(() => {
dispatch(setBoardSearchText('')); dispatch(setBoardSearchText(''));
}; setIsSearching(false);
}, [dispatch, setIsSearching]);
const handleKeydown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
// exit search mode on escape
if (e.key === 'Escape') {
clearBoardSearch();
}
},
[clearBoardSearch]
);
const handleChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
handleBoardSearch(e.target.value);
},
[handleBoardSearch]
);
useEffect(() => {
// focus the search box on mount
if (!inputRef.current) {
return;
}
inputRef.current.focus();
}, []);
return ( return (
<InputGroup> <InputGroup>
<Input <Input
ref={inputRef}
placeholder="Search Boards..." placeholder="Search Boards..."
value={searchText} value={searchText}
onChange={(e) => { onKeyDown={handleKeydown}
handleBoardSearch(e.target.value); onChange={handleChange}
}}
/> />
{searchText && searchText.length && ( {searchText && searchText.length && (
<InputRightElement> <InputRightElement>
@ -55,7 +91,8 @@ const BoardsSearch = (props: Props) => {
size="xs" size="xs"
variant="ghost" variant="ghost"
aria-label="Clear Search" aria-label="Clear Search"
icon={<CloseIcon boxSize={3} />} opacity={0.5}
icon={<CloseIcon boxSize={2} />}
/> />
</InputRightElement> </InputRightElement>
)} )}

View File

@ -6,9 +6,9 @@ import {
EditableInput, EditableInput,
EditablePreview, EditablePreview,
Flex, Flex,
Icon,
Image, Image,
Text, Text,
useColorMode,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/dist/query'; import { skipToken } from '@reduxjs/toolkit/dist/query';
@ -17,14 +17,12 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDroppable from 'common/components/IAIDroppable'; import IAIDroppable from 'common/components/IAIDroppable';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { boardIdSelected } from 'features/gallery/store/gallerySlice'; import { boardIdSelected } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo, useState } from 'react';
import { FaUser } from 'react-icons/fa'; import { FaFolder } from 'react-icons/fa';
import { useUpdateBoardMutation } from 'services/api/endpoints/boards'; import { useUpdateBoardMutation } from 'services/api/endpoints/boards';
import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { BoardDTO } from 'services/api/types'; import { BoardDTO } from 'services/api/types';
import { mode } from 'theme/util/mode';
import BoardContextMenu from '../BoardContextMenu'; import BoardContextMenu from '../BoardContextMenu';
const AUTO_ADD_BADGE_STYLES: ChakraProps['sx'] = { const AUTO_ADD_BADGE_STYLES: ChakraProps['sx'] = {
@ -66,8 +64,9 @@ const GalleryBoard = memo(
board.cover_image_name ?? skipToken board.cover_image_name ?? skipToken
); );
const { colorMode } = useColorMode();
const { board_name, board_id } = board; const { board_name, board_id } = board;
const [localBoardName, setLocalBoardName] = useState(board_name);
const handleSelectBoard = useCallback(() => { const handleSelectBoard = useCallback(() => {
dispatch(boardIdSelected(board_id)); dispatch(boardIdSelected(board_id));
}, [board_id, dispatch]); }, [board_id, dispatch]);
@ -75,10 +74,6 @@ const GalleryBoard = memo(
const [updateBoard, { isLoading: isUpdateBoardLoading }] = const [updateBoard, { isLoading: isUpdateBoardLoading }] =
useUpdateBoardMutation(); useUpdateBoardMutation();
const handleUpdateBoardName = (newBoardName: string) => {
updateBoard({ board_id, changes: { board_name: newBoardName } });
};
const droppableData: MoveBoardDropData = useMemo( const droppableData: MoveBoardDropData = useMemo(
() => ({ () => ({
id: board_id, id: board_id,
@ -88,8 +83,49 @@ const GalleryBoard = memo(
[board_id] [board_id]
); );
const handleSubmit = useCallback(
(newBoardName: string) => {
if (!newBoardName) {
// empty strings are not allowed
setLocalBoardName(board_name);
return;
}
if (newBoardName === board_name) {
// don't updated the board name if it hasn't changed
return;
}
updateBoard({ board_id, changes: { board_name: newBoardName } })
.unwrap()
.then((response) => {
// update local state
setLocalBoardName(response.board_name);
})
.catch(() => {
// revert on error
setLocalBoardName(board_name);
});
},
[board_id, board_name, updateBoard]
);
const handleChange = useCallback((newBoardName: string) => {
setLocalBoardName(newBoardName);
}, []);
return ( return (
<Box sx={{ touchAction: 'none', height: 'full' }}> <Box
sx={{ w: 'full', h: 'full', touchAction: 'none', userSelect: 'none' }}
>
<Flex
sx={{
position: 'relative',
justifyContent: 'center',
alignItems: 'center',
aspectRatio: '1/1',
w: 'full',
h: 'full',
}}
>
<BoardContextMenu <BoardContextMenu
board={board} board={board}
board_id={board_id} board_id={board_id}
@ -97,50 +133,66 @@ const GalleryBoard = memo(
> >
{(ref) => ( {(ref) => (
<Flex <Flex
key={board_id}
userSelect="none"
ref={ref} ref={ref}
sx={{
flexDir: 'column',
justifyContent: 'space-between',
alignItems: 'center',
cursor: 'pointer',
w: 'full',
h: 'full',
}}
>
<Flex
onClick={handleSelectBoard} onClick={handleSelectBoard}
sx={{ sx={{
w: 'full',
h: 'full',
position: 'relative', position: 'relative',
justifyContent: 'center', justifyContent: 'center',
alignItems: 'center', alignItems: 'center',
borderRadius: 'base', borderRadius: 'base',
w: 'full', cursor: 'pointer',
aspectRatio: '1/1',
overflow: 'hidden',
shadow: isSelected ? 'selected.light' : undefined,
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
flexShrink: 0,
}} }}
> >
{board.cover_image_name && coverImage?.thumbnail_url && ( <Flex
<Image src={coverImage?.thumbnail_url} draggable={false} />
)}
{!(board.cover_image_name && coverImage?.thumbnail_url) && (
<IAINoContentFallback
boxSize={8}
icon={FaUser}
sx={{ sx={{
borderWidth: '2px', w: 'full',
borderStyle: 'solid', h: 'full',
borderColor: 'base.200', justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
bg: 'base.200',
_dark: { _dark: {
borderColor: 'base.800', bg: 'base.800',
},
}}
>
{coverImage?.thumbnail_url ? (
<Image
src={coverImage?.thumbnail_url}
draggable={false}
sx={{
maxW: 'full',
maxH: 'full',
borderRadius: 'base',
borderBottomRadius: 'lg',
}}
/>
) : (
<Flex
sx={{
w: 'full',
h: 'full',
justifyContent: 'center',
alignItems: 'center',
}}
>
<Icon
boxSize={12}
as={FaFolder}
sx={{
mt: -3,
opacity: 0.7,
color: 'base.500',
_dark: {
color: 'base.500',
}, },
}} }}
/> />
</Flex>
)} )}
</Flex>
<Flex <Flex
sx={{ sx={{
position: 'absolute', position: 'absolute',
@ -160,37 +212,59 @@ const GalleryBoard = memo(
{board.image_count} {board.image_count}
</Badge> </Badge>
</Flex> </Flex>
<IAIDroppable <Box
data={droppableData} className="selection-box"
dropLabel={<Text fontSize="md">Move</Text>} sx={{
position: 'absolute',
top: 0,
insetInlineEnd: 0,
bottom: 0,
insetInlineStart: 0,
borderRadius: 'base',
transitionProperty: 'common',
transitionDuration: 'common',
shadow: isSelected ? 'selected.light' : undefined,
_dark: {
shadow: isSelected ? 'selected.dark' : undefined,
},
}}
/> />
</Flex>
<Flex <Flex
sx={{ sx={{
width: 'full', position: 'absolute',
height: 'full', bottom: 0,
left: 0,
p: 1,
justifyContent: 'center', justifyContent: 'center',
alignItems: 'center', alignItems: 'center',
w: 'full',
maxW: 'full',
borderBottomRadius: 'base',
bg: isSelected ? 'accent.400' : 'base.500',
color: isSelected ? 'base.50' : 'base.100',
_dark: {
bg: isSelected ? 'accent.500' : 'base.600',
color: isSelected ? 'base.50' : 'base.100',
},
lineHeight: 'short',
fontSize: 'xs',
}} }}
> >
<Editable <Editable
defaultValue={board_name} value={localBoardName}
isDisabled={isUpdateBoardLoading}
submitOnBlur={true} submitOnBlur={true}
onSubmit={(nextValue) => { onChange={handleChange}
handleUpdateBoardName(nextValue); onSubmit={handleSubmit}
sx={{
w: 'full',
}} }}
sx={{ maxW: 'full' }}
> >
<EditablePreview <EditablePreview
sx={{ sx={{
color: isSelected
? mode('base.900', 'base.50')(colorMode)
: mode('base.700', 'base.200')(colorMode),
fontWeight: isSelected ? 600 : undefined,
fontSize: 'xs',
textAlign: 'center',
p: 0, p: 0,
fontWeight: isSelected ? 700 : 500,
textAlign: 'center',
overflow: 'hidden', overflow: 'hidden',
textOverflow: 'ellipsis', textOverflow: 'ellipsis',
}} }}
@ -198,18 +272,26 @@ const GalleryBoard = memo(
/> />
<EditableInput <EditableInput
sx={{ sx={{
color: mode('base.900', 'base.50')(colorMode),
fontSize: 'xs',
borderColor: mode('base.500', 'base.500')(colorMode),
p: 0, p: 0,
outline: 0, _focusVisible: {
p: 0,
textAlign: 'center',
// get rid of the edit border
boxShadow: 'none',
},
}} }}
/> />
</Editable> </Editable>
</Flex> </Flex>
<IAIDroppable
data={droppableData}
dropLabel={<Text fontSize="md">Move</Text>}
/>
</Flex> </Flex>
)} )}
</BoardContextMenu> </BoardContextMenu>
</Flex>
</Box> </Box>
); );
} }

View File

@ -17,7 +17,7 @@ type GenericBoardProps = {
badgeCount?: number; badgeCount?: number;
}; };
const formatBadgeCount = (count: number) => export const formatBadgeCount = (count: number) =>
Intl.NumberFormat('en-US', { Intl.NumberFormat('en-US', {
notation: 'compact', notation: 'compact',
maximumFractionDigits: 1, maximumFractionDigits: 1,
@ -92,7 +92,7 @@ const GenericBoard = (props: GenericBoardProps) => {
h: 'full', h: 'full',
alignItems: 'center', alignItems: 'center',
fontWeight: isSelected ? 600 : undefined, fontWeight: isSelected ? 600 : undefined,
fontSize: 'xs', fontSize: 'sm',
color: isSelected ? 'base.900' : 'base.700', color: isSelected ? 'base.900' : 'base.700',
_dark: { color: isSelected ? 'base.50' : 'base.200' }, _dark: { color: isSelected ? 'base.50' : 'base.200' },
}} }}

View File

@ -0,0 +1,53 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import { boardIdSelected } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo } from 'react';
import { useBoardName } from 'services/api/hooks/useBoardName';
type Props = {
board_id: 'images' | 'assets' | 'no_board';
};
const SystemBoardButton = ({ board_id }: Props) => {
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
[stateSelector],
({ gallery }) => {
const { selectedBoardId } = gallery;
return { isSelected: selectedBoardId === board_id };
},
defaultSelectorOptions
),
[board_id]
);
const { isSelected } = useAppSelector(selector);
const boardName = useBoardName(board_id);
const handleClick = useCallback(() => {
dispatch(boardIdSelected(board_id));
}, [board_id, dispatch]);
return (
<IAIButton
onClick={handleClick}
size="sm"
isChecked={isSelected}
sx={{
flexGrow: 1,
borderRadius: 'base',
}}
>
{boardName}
</IAIButton>
);
};
export default memo(SystemBoardButton);

View File

@ -4,8 +4,9 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { memo } from 'react'; import { memo, useMemo } from 'react';
import { useBoardName } from 'services/api/hooks/useBoardName'; import { useBoardName } from 'services/api/hooks/useBoardName';
import { useBoardTotal } from 'services/api/hooks/useBoardTotal';
const selector = createSelector( const selector = createSelector(
[stateSelector], [stateSelector],
@ -26,6 +27,17 @@ const GalleryBoardName = (props: Props) => {
const { isOpen, onToggle } = props; const { isOpen, onToggle } = props;
const { selectedBoardId } = useAppSelector(selector); const { selectedBoardId } = useAppSelector(selector);
const boardName = useBoardName(selectedBoardId); const boardName = useBoardName(selectedBoardId);
const numOfBoardImages = useBoardTotal(selectedBoardId);
const formattedBoardName = useMemo(() => {
if (!boardName || !numOfBoardImages) {
return '';
}
if (boardName.length > 20) {
return `${boardName.substring(0, 20)}... (${numOfBoardImages})`;
}
return `${boardName} (${numOfBoardImages})`;
}, [boardName, numOfBoardImages]);
return ( return (
<Flex <Flex
@ -58,7 +70,7 @@ const GalleryBoardName = (props: Props) => {
}, },
}} }}
> >
{boardName} {formattedBoardName}
</Text> </Text>
</Box> </Box>
<Spacer /> <Spacer />

View File

@ -109,7 +109,7 @@ const GalleryDrawer = () => {
isResizable={true} isResizable={true}
isOpen={shouldShowGallery} isOpen={shouldShowGallery}
onClose={handleCloseGallery} onClose={handleCloseGallery}
minWidth={337} minWidth={400}
> >
<ImageGalleryContent /> <ImageGalleryContent />
</ResizableDrawer> </ResizableDrawer>

View File

@ -1,4 +1,4 @@
import { Box } from '@chakra-ui/react'; import { Box, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd'; import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
@ -86,15 +86,10 @@ const GalleryImage = (props: HoverableImageProps) => {
return ( return (
<Box sx={{ w: 'full', h: 'full', touchAction: 'none' }}> <Box sx={{ w: 'full', h: 'full', touchAction: 'none' }}>
<ImageContextMenu imageDTO={imageDTO}> <Flex
{(ref) => (
<Box
position="relative"
key={imageName}
userSelect="none" userSelect="none"
ref={ref}
sx={{ sx={{
display: 'flex', position: 'relative',
justifyContent: 'center', justifyContent: 'center',
alignItems: 'center', alignItems: 'center',
aspectRatio: '1/1', aspectRatio: '1/1',
@ -115,9 +110,7 @@ const GalleryImage = (props: HoverableImageProps) => {
// resetTooltip="Delete image" // resetTooltip="Delete image"
// withResetIcon // removed bc it's too easy to accidentally delete images // withResetIcon // removed bc it's too easy to accidentally delete images
/> />
</Box> </Flex>
)}
</ImageContextMenu>
</Box> </Box>
); );
}; };

View File

@ -48,6 +48,7 @@ export const addControlNetToLinearGraph = (
beginStepPct, beginStepPct,
endStepPct, endStepPct,
controlMode, controlMode,
resizeMode,
model, model,
processorType, processorType,
weight, weight,
@ -60,6 +61,7 @@ export const addControlNetToLinearGraph = (
begin_step_percent: beginStepPct, begin_step_percent: beginStepPct,
end_step_percent: endStepPct, end_step_percent: endStepPct,
control_mode: controlMode, control_mode: controlMode,
resize_mode: resizeMode,
control_model: model as ControlNetInvocation['control_model'], control_model: model as ControlNetInvocation['control_model'],
control_weight: weight, control_weight: weight,
}; };

View File

@ -105,7 +105,7 @@ const enabledTabsSelector = createSelector(
} }
); );
const MIN_GALLERY_WIDTH = 300; const MIN_GALLERY_WIDTH = 350;
const DEFAULT_GALLERY_PCT = 20; const DEFAULT_GALLERY_PCT = 20;
export const NO_GALLERY_TABS: InvokeTabName[] = ['modelManager']; export const NO_GALLERY_TABS: InvokeTabName[] = ['modelManager'];

View File

@ -3,7 +3,7 @@ import { memo } from 'react';
import { PanelResizeHandle } from 'react-resizable-panels'; import { PanelResizeHandle } from 'react-resizable-panels';
import { mode } from 'theme/util/mode'; import { mode } from 'theme/util/mode';
type ResizeHandleProps = FlexProps & { type ResizeHandleProps = Omit<FlexProps, 'direction'> & {
direction?: 'horizontal' | 'vertical'; direction?: 'horizontal' | 'vertical';
}; };

View File

@ -169,10 +169,11 @@ export const imagesApi = api.injectEndpoints({
], ],
async onQueryStarted(imageDTO, { dispatch, queryFulfilled }) { async onQueryStarted(imageDTO, { dispatch, queryFulfilled }) {
/** /**
* Cache changes for deleteImage: * Cache changes for `deleteImage`:
* - Remove from "All Images" * - *remove* from "All Images" / "All Assets"
* - Remove from image's `board_id` if it has one, or "No Board" if not * - IF it has a board:
* - Remove from "Batch" * - THEN *remove* from it's own board
* - ELSE *remove* from "No Board"
*/ */
const { image_name, board_id, image_category } = imageDTO; const { image_name, board_id, image_category } = imageDTO;
@ -181,22 +182,23 @@ export const imagesApi = api.injectEndpoints({
// That means constructing the possible query args that are serialized into the cache key... // That means constructing the possible query args that are serialized into the cache key...
const removeFromCacheKeys: ListImagesArgs[] = []; const removeFromCacheKeys: ListImagesArgs[] = [];
// determine `categories`, i.e. do we update "All Images" or "All Assets"
const categories = IMAGE_CATEGORIES.includes(image_category) const categories = IMAGE_CATEGORIES.includes(image_category)
? IMAGE_CATEGORIES ? IMAGE_CATEGORIES
: ASSETS_CATEGORIES; : ASSETS_CATEGORIES;
// All Images board (e.g. no board) // remove from "All Images"
removeFromCacheKeys.push({ categories }); removeFromCacheKeys.push({ categories });
// Board specific
if (board_id) { if (board_id) {
// remove from it's own board
removeFromCacheKeys.push({ board_id }); removeFromCacheKeys.push({ board_id });
} else { } else {
// TODO: No Board // remove from "No Board"
removeFromCacheKeys.push({ board_id: 'none' });
} }
// TODO: Batch
const patches: PatchCollection[] = []; const patches: PatchCollection[] = [];
removeFromCacheKeys.forEach((cacheKey) => { removeFromCacheKeys.forEach((cacheKey) => {
patches.push( patches.push(
@ -240,32 +242,37 @@ export const imagesApi = api.injectEndpoints({
{ imageDTO: oldImageDTO, changes: _changes }, { imageDTO: oldImageDTO, changes: _changes },
{ dispatch, queryFulfilled, getState } { dispatch, queryFulfilled, getState }
) { ) {
// TODO: Should we handle changes to boards via this mutation? Seems reasonable...
// let's be extra-sure we do not accidentally change categories // let's be extra-sure we do not accidentally change categories
const changes = omit(_changes, 'image_category'); const changes = omit(_changes, 'image_category');
/** /**
* Cache changes for `updateImage`: * Cache changes for "updateImage":
* - Update the ImageDTO * - *update* "getImageDTO" cache
* - Update the image in "All Images" board: * - for "All Images" || "All Assets":
* - IF it is in the date range represented by the cache: * - IF it is not already in the cache
* - add the image IF it is not already in the cache & update the total * - THEN *add* it to "All Images" / "All Assets" and update the total
* - ELSE update the image IF it is already in the cache * - ELSE *update* it
* - IF the image has a board: * - IF the image has a board:
* - Update the image in it's own board * - THEN *update* it's own board
* - ELSE Update the image in the "No Board" board (TODO) * - ELSE *update* the "No Board" board
*/ */
const patches: PatchCollection[] = []; const patches: PatchCollection[] = [];
const { image_name, board_id, image_category } = oldImageDTO; const { image_name, board_id, image_category, is_intermediate } =
oldImageDTO;
const isChangingFromIntermediate = changes.is_intermediate === false;
// do not add intermediates to gallery cache
if (is_intermediate && !isChangingFromIntermediate) {
return;
}
// determine `categories`, i.e. do we update "All Images" or "All Assets"
const categories = IMAGE_CATEGORIES.includes(image_category) const categories = IMAGE_CATEGORIES.includes(image_category)
? IMAGE_CATEGORIES ? IMAGE_CATEGORIES
: ASSETS_CATEGORIES; : ASSETS_CATEGORIES;
// TODO: No Board // update `getImageDTO` cache
// Update `getImageDTO` cache
patches.push( patches.push(
dispatch( dispatch(
imagesApi.util.updateQueryData( imagesApi.util.updateQueryData(
@ -281,9 +288,13 @@ export const imagesApi = api.injectEndpoints({
// Update the "All Image" or "All Assets" board // Update the "All Image" or "All Assets" board
const queryArgsToUpdate: ListImagesArgs[] = [{ categories }]; const queryArgsToUpdate: ListImagesArgs[] = [{ categories }];
// IF the image has a board:
if (board_id) { if (board_id) {
// We also need to update the user board // THEN update it's own board
queryArgsToUpdate.push({ board_id }); queryArgsToUpdate.push({ board_id });
} else {
// ELSE update the "No Board" board
queryArgsToUpdate.push({ board_id: 'none' });
} }
queryArgsToUpdate.forEach((queryArg) => { queryArgsToUpdate.forEach((queryArg) => {
@ -371,12 +382,12 @@ export const imagesApi = api.injectEndpoints({
return; return;
} }
// Add the image to the "All Images" / "All Assets" board // determine `categories`, i.e. do we update "All Images" or "All Assets"
const queryArg = { const categories = IMAGE_CATEGORIES.includes(image_category)
categories: IMAGE_CATEGORIES.includes(image_category)
? IMAGE_CATEGORIES ? IMAGE_CATEGORIES
: ASSETS_CATEGORIES, : ASSETS_CATEGORIES;
};
const queryArg = { categories };
dispatch( dispatch(
imagesApi.util.updateQueryData('listImages', queryArg, (draft) => { imagesApi.util.updateQueryData('listImages', queryArg, (draft) => {
@ -410,16 +421,14 @@ export const imagesApi = api.injectEndpoints({
{ dispatch, queryFulfilled, getState } { dispatch, queryFulfilled, getState }
) { ) {
/** /**
* Cache changes for addImageToBoard: * Cache changes for `addImageToBoard`:
* - Remove from "No Board" * - *update* the `getImageDTO` cache
* - Remove from `old_board_id` if it has one * - *remove* from "No Board"
* - Add to new `board_id` * - IF the image has an old `board_id`:
* - THEN *remove* from it's old `board_id`
* - IF the image's `created_at` is within the range of the board's cached images * - IF the image's `created_at` is within the range of the board's cached images
* - OR the board cache has length of 0 or 1 * - OR the board cache has length of 0 or 1
* - Update the `total` for each board whose cache is updated * - THEN *add* it to new `board_id`
* - Update the ImageDTO
*
* TODO: maybe total should just be updated in the boards endpoints?
*/ */
const { image_name, board_id: old_board_id } = oldImageDTO; const { image_name, board_id: old_board_id } = oldImageDTO;
@ -427,13 +436,10 @@ export const imagesApi = api.injectEndpoints({
// Figure out the `listImages` caches that we need to update // Figure out the `listImages` caches that we need to update
const removeFromQueryArgs: ListImagesArgs[] = []; const removeFromQueryArgs: ListImagesArgs[] = [];
// TODO: No Board // remove from "No Board"
// TODO: Batch
// Remove from No Board
removeFromQueryArgs.push({ board_id: 'none' }); removeFromQueryArgs.push({ board_id: 'none' });
// Remove from old board // remove from old board
if (old_board_id) { if (old_board_id) {
removeFromQueryArgs.push({ board_id: old_board_id }); removeFromQueryArgs.push({ board_id: old_board_id });
} }
@ -534,17 +540,15 @@ export const imagesApi = api.injectEndpoints({
{ dispatch, queryFulfilled, getState } { dispatch, queryFulfilled, getState }
) { ) {
/** /**
* Cache changes for removeImageFromBoard: * Cache changes for `removeImageFromBoard`:
* - Add to "No Board" * - *update* `getImageDTO`
* - IF the image's `created_at` is within the range of the board's cached images * - IF the image's `created_at` is within the range of the board's cached images
* - Remove from `old_board_id` * - THEN *add* to "No Board"
* - Update the ImageDTO * - *remove* from `old_board_id`
*/ */
const { image_name, board_id: old_board_id } = imageDTO; const { image_name, board_id: old_board_id } = imageDTO;
// TODO: Batch
const patches: PatchCollection[] = []; const patches: PatchCollection[] = [];
// Updated imageDTO with new board_id // Updated imageDTO with new board_id

View File

@ -6,9 +6,9 @@ export const useBoardName = (board_id: BoardId | null | undefined) => {
selectFromResult: ({ data }) => { selectFromResult: ({ data }) => {
let boardName = ''; let boardName = '';
if (board_id === 'images') { if (board_id === 'images') {
boardName = 'All Images'; boardName = 'Images';
} else if (board_id === 'assets') { } else if (board_id === 'assets') {
boardName = 'All Assets'; boardName = 'Assets';
} else if (board_id === 'no_board') { } else if (board_id === 'no_board') {
boardName = 'No Board'; boardName = 'No Board';
} else if (board_id === 'batch') { } else if (board_id === 'batch') {

View File

@ -0,0 +1,53 @@
import { skipToken } from '@reduxjs/toolkit/dist/query';
import {
ASSETS_CATEGORIES,
BoardId,
IMAGE_CATEGORIES,
INITIAL_IMAGE_LIMIT,
} from 'features/gallery/store/gallerySlice';
import { useMemo } from 'react';
import { ListImagesArgs, useListImagesQuery } from '../endpoints/images';
const baseQueryArg: ListImagesArgs = {
offset: 0,
limit: INITIAL_IMAGE_LIMIT,
is_intermediate: false,
};
const imagesQueryArg: ListImagesArgs = {
categories: IMAGE_CATEGORIES,
...baseQueryArg,
};
const assetsQueryArg: ListImagesArgs = {
categories: ASSETS_CATEGORIES,
...baseQueryArg,
};
const noBoardQueryArg: ListImagesArgs = {
board_id: 'none',
...baseQueryArg,
};
export const useBoardTotal = (board_id: BoardId | null | undefined) => {
const queryArg = useMemo(() => {
if (!board_id) {
return;
}
if (board_id === 'images') {
return imagesQueryArg;
} else if (board_id === 'assets') {
return assetsQueryArg;
} else if (board_id === 'no_board') {
return noBoardQueryArg;
} else {
return { board_id, ...baseQueryArg };
}
}, [board_id]);
const { total } = useListImagesQuery(queryArg ?? skipToken, {
selectFromResult: ({ currentData }) => ({ total: currentData?.total }),
});
return total;
};

View File

@ -167,7 +167,7 @@ export type paths = {
"/api/v1/images/clear-intermediates": { "/api/v1/images/clear-intermediates": {
/** /**
* Clear Intermediates * Clear Intermediates
* @description Clears first 100 intermediates * @description Clears all intermediates
*/ */
post: operations["clear_intermediates"]; post: operations["clear_intermediates"];
}; };
@ -800,6 +800,13 @@ export type components = {
* @enum {string} * @enum {string}
*/ */
control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced"; control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced";
/**
* Resize Mode
* @description The resize mode to use
* @default just_resize
* @enum {string}
*/
resize_mode?: "just_resize" | "crop_resize" | "fill_resize" | "just_resize_simple";
}; };
/** /**
* ControlNetInvocation * ControlNetInvocation
@ -859,6 +866,13 @@ export type components = {
* @enum {string} * @enum {string}
*/ */
control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced"; control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced";
/**
* Resize Mode
* @description The resize mode used
* @default just_resize
* @enum {string}
*/
resize_mode?: "just_resize" | "crop_resize" | "fill_resize" | "just_resize_simple";
}; };
/** ControlNetModelConfig */ /** ControlNetModelConfig */
ControlNetModelConfig: { ControlNetModelConfig: {
@ -5324,11 +5338,11 @@ export type components = {
image?: components["schemas"]["ImageField"]; image?: components["schemas"]["ImageField"];
}; };
/** /**
* StableDiffusion2ModelFormat * StableDiffusion1ModelFormat
* @description An enumeration. * @description An enumeration.
* @enum {string} * @enum {string}
*/ */
StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/** /**
* StableDiffusionXLModelFormat * StableDiffusionXLModelFormat
* @description An enumeration. * @description An enumeration.
@ -5336,11 +5350,11 @@ export type components = {
*/ */
StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
/** /**
* StableDiffusion1ModelFormat * StableDiffusion2ModelFormat
* @description An enumeration. * @description An enumeration.
* @enum {string} * @enum {string}
*/ */
StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
}; };
responses: never; responses: never;
parameters: never; parameters: never;
@ -6125,7 +6139,7 @@ export type operations = {
}; };
/** /**
* Clear Intermediates * Clear Intermediates
* @description Clears first 100 intermediates * @description Clears all intermediates
*/ */
clear_intermediates: { clear_intermediates: {
responses: { responses: {

View File

@ -2,11 +2,16 @@ import { InvokeAIThemeColors } from 'theme/themeTypes';
import { generateColorPalette } from 'theme/util/generateColorPalette'; import { generateColorPalette } from 'theme/util/generateColorPalette';
const BASE = { H: 220, S: 16 }; const BASE = { H: 220, S: 16 };
const ACCENT = { H: 250, S: 52 }; const ACCENT = { H: 250, S: 42 };
const WORKING = { H: 47, S: 50 }; // const ACCENT = { H: 250, S: 52 };
const WARNING = { H: 28, S: 50 }; const WORKING = { H: 47, S: 42 };
const OK = { H: 113, S: 50 }; // const WORKING = { H: 47, S: 50 };
const ERROR = { H: 0, S: 50 }; const WARNING = { H: 28, S: 42 };
// const WARNING = { H: 28, S: 50 };
const OK = { H: 113, S: 42 };
// const OK = { H: 113, S: 50 };
const ERROR = { H: 0, S: 42 };
// const ERROR = { H: 0, S: 50 };
export const InvokeAIColors: InvokeAIThemeColors = { export const InvokeAIColors: InvokeAIThemeColors = {
base: generateColorPalette(BASE.H, BASE.S), base: generateColorPalette(BASE.H, BASE.S),

View File

@ -0,0 +1,56 @@
import { editableAnatomy as parts } from '@chakra-ui/anatomy';
import {
createMultiStyleConfigHelpers,
defineStyle,
} from '@chakra-ui/styled-system';
import { mode } from '@chakra-ui/theme-tools';
const { definePartsStyle, defineMultiStyleConfig } =
createMultiStyleConfigHelpers(parts.keys);
const baseStylePreview = defineStyle({
borderRadius: 'md',
py: '1',
transitionProperty: 'common',
transitionDuration: 'normal',
});
const baseStyleInput = defineStyle((props) => ({
borderRadius: 'md',
py: '1',
transitionProperty: 'common',
transitionDuration: 'normal',
width: 'full',
_focusVisible: { boxShadow: 'outline' },
_placeholder: { opacity: 0.6 },
'::selection': {
color: mode('accent.900', 'accent.50')(props),
bg: mode('accent.200', 'accent.400')(props),
},
}));
const baseStyleTextarea = defineStyle({
borderRadius: 'md',
py: '1',
transitionProperty: 'common',
transitionDuration: 'normal',
width: 'full',
_focusVisible: { boxShadow: 'outline' },
_placeholder: { opacity: 0.6 },
});
const invokeAI = definePartsStyle((props) => ({
preview: baseStylePreview,
input: baseStyleInput(props),
textarea: baseStyleTextarea,
}));
export const editableTheme = defineMultiStyleConfig({
variants: {
invokeAI,
},
defaultProps: {
size: 'sm',
variant: 'invokeAI',
},
});

View File

@ -4,6 +4,7 @@ import { InvokeAIColors } from './colors/colors';
import { accordionTheme } from './components/accordion'; import { accordionTheme } from './components/accordion';
import { buttonTheme } from './components/button'; import { buttonTheme } from './components/button';
import { checkboxTheme } from './components/checkbox'; import { checkboxTheme } from './components/checkbox';
import { editableTheme } from './components/editable';
import { formLabelTheme } from './components/formLabel'; import { formLabelTheme } from './components/formLabel';
import { inputTheme } from './components/input'; import { inputTheme } from './components/input';
import { menuTheme } from './components/menu'; import { menuTheme } from './components/menu';
@ -72,7 +73,17 @@ export const theme: ThemeOverride = {
selected: { selected: {
light: light:
'0px 0px 0px 1px var(--invokeai-colors-base-150), 0px 0px 0px 4px var(--invokeai-colors-accent-400)', '0px 0px 0px 1px var(--invokeai-colors-base-150), 0px 0px 0px 4px var(--invokeai-colors-accent-400)',
dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-400)', dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-500)',
},
hoverSelected: {
light:
'0px 0px 0px 1px var(--invokeai-colors-base-150), 0px 0px 0px 4px var(--invokeai-colors-accent-500)',
dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-300)',
},
hoverUnselected: {
light:
'0px 0px 0px 1px var(--invokeai-colors-base-150), 0px 0px 0px 4px var(--invokeai-colors-accent-200)',
dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-600)',
}, },
nodeSelectedOutline: `0 0 0 2px var(--invokeai-colors-accent-450)`, nodeSelectedOutline: `0 0 0 2px var(--invokeai-colors-accent-450)`,
}, },
@ -80,6 +91,7 @@ export const theme: ThemeOverride = {
components: { components: {
Button: buttonTheme, // Button and IconButton Button: buttonTheme, // Button and IconButton
Input: inputTheme, Input: inputTheme,
Editable: editableTheme,
Textarea: textareaTheme, Textarea: textareaTheme,
Tabs: tabsTheme, Tabs: tabsTheme,
Progress: progressTheme, Progress: progressTheme,

View File

@ -37,4 +37,7 @@ export const getInputOutlineStyles = (props: StyleFunctionProps) => ({
_placeholder: { _placeholder: {
color: mode('base.700', 'base.400')(props), color: mode('base.700', 'base.400')(props),
}, },
'::selection': {
bg: mode('accent.200', 'accent.400')(props),
},
}); });