diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 43cad3dcaf..911fede8fb 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -85,8 +85,8 @@ CONTROLNET_DEFAULT_MODELS = [ CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)] CONTROLNET_MODE_VALUES = Literal[tuple( ["balanced", "more_prompt", "more_control", "unbalanced"])] -# crop and fill options not ready yet -# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])] +CONTROLNET_RESIZE_VALUES = Literal[tuple( + ["just_resize", "crop_resize", "fill_resize", "just_resize_simple",])] class ControlNetModelField(BaseModel): @@ -111,7 +111,8 @@ class ControlField(BaseModel): description="When the ControlNet is last applied (% of total steps)") control_mode: CONTROLNET_MODE_VALUES = Field( default="balanced", description="The control mode to use") - # resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use") + resize_mode: CONTROLNET_RESIZE_VALUES = Field( + default="just_resize", description="The resize mode to use") @validator("control_weight") def validate_control_weight(cls, v): @@ -161,6 +162,7 @@ class ControlNetInvocation(BaseInvocation): end_step_percent: float = Field(default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)") control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used") + resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode used") # fmt: on class Config(InvocationConfig): @@ -187,6 +189,7 @@ class ControlNetInvocation(BaseInvocation): begin_step_percent=self.begin_step_percent, end_step_percent=self.end_step_percent, control_mode=self.control_mode, + resize_mode=self.resize_mode, ), ) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index cd15fe156b..b4c3454c88 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -30,6 +30,7 @@ from .compel import ConditioningField from .controlnet_image_processors import ControlField from .image import ImageOutput from .model import ModelInfo, UNetField, VaeField +from invokeai.app.util.controlnet_utils import prepare_control_image from diffusers.models.attention_processor import ( AttnProcessor2_0, @@ -288,7 +289,7 @@ class TextToLatentsInvocation(BaseInvocation): # and add in batch_size, num_images_per_prompt? # and do real check for classifier_free_guidance? # prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width) - control_image = model.prepare_control_image( + control_image = prepare_control_image( image=input_image, do_classifier_free_guidance=do_classifier_free_guidance, width=control_width_resize, @@ -298,13 +299,18 @@ class TextToLatentsInvocation(BaseInvocation): device=control_model.device, dtype=control_model.dtype, control_mode=control_info.control_mode, + resize_mode=control_info.resize_mode, ) control_item = ControlNetData( - model=control_model, image_tensor=control_image, + model=control_model, + image_tensor=control_image, weight=control_info.control_weight, begin_step_percent=control_info.begin_step_percent, end_step_percent=control_info.end_step_percent, control_mode=control_info.control_mode, + # any resizing needed should currently be happening in prepare_control_image(), + # but adding resize_mode to ControlNetData in case needed in the future + resize_mode=control_info.resize_mode, ) control_data.append(control_item) # MultiControlNetModel has been refactored out, just need list[ControlNetData] @@ -601,7 +607,7 @@ class ResizeLatentsInvocation(BaseInvocation): antialias: bool = Field( default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") - + class Config(InvocationConfig): schema_extra = { "ui": { @@ -647,7 +653,7 @@ class ScaleLatentsInvocation(BaseInvocation): antialias: bool = Field( default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") - + class Config(InvocationConfig): schema_extra = { "ui": { diff --git a/invokeai/app/util/controlnet_utils.py b/invokeai/app/util/controlnet_utils.py new file mode 100644 index 0000000000..342fa147c5 --- /dev/null +++ b/invokeai/app/util/controlnet_utils.py @@ -0,0 +1,342 @@ +import torch +import numpy as np +import cv2 +from PIL import Image +from diffusers.utils import PIL_INTERPOLATION + +from einops import rearrange +from controlnet_aux.util import HWC3, resize_image + +################################################################### +# Copy of scripts/lvminthin.py from Mikubill/sd-webui-controlnet +################################################################### +# High Quality Edge Thinning using Pure Python +# Written by Lvmin Zhangu +# 2023 April +# Stanford University +# If you use this, please Cite "High Quality Edge Thinning using Pure Python", Lvmin Zhang, In Mikubill/sd-webui-controlnet. + +lvmin_kernels_raw = [ + np.array([ + [-1, -1, -1], + [0, 1, 0], + [1, 1, 1] + ], dtype=np.int32), + np.array([ + [0, -1, -1], + [1, 1, -1], + [0, 1, 0] + ], dtype=np.int32) +] + +lvmin_kernels = [] +lvmin_kernels += [np.rot90(x, k=0, axes=(0, 1)) for x in lvmin_kernels_raw] +lvmin_kernels += [np.rot90(x, k=1, axes=(0, 1)) for x in lvmin_kernels_raw] +lvmin_kernels += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_kernels_raw] +lvmin_kernels += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_kernels_raw] + +lvmin_prunings_raw = [ + np.array([ + [-1, -1, -1], + [-1, 1, -1], + [0, 0, -1] + ], dtype=np.int32), + np.array([ + [-1, -1, -1], + [-1, 1, -1], + [-1, 0, 0] + ], dtype=np.int32) +] + +lvmin_prunings = [] +lvmin_prunings += [np.rot90(x, k=0, axes=(0, 1)) for x in lvmin_prunings_raw] +lvmin_prunings += [np.rot90(x, k=1, axes=(0, 1)) for x in lvmin_prunings_raw] +lvmin_prunings += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_prunings_raw] +lvmin_prunings += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_prunings_raw] + + +def remove_pattern(x, kernel): + objects = cv2.morphologyEx(x, cv2.MORPH_HITMISS, kernel) + objects = np.where(objects > 127) + x[objects] = 0 + return x, objects[0].shape[0] > 0 + + +def thin_one_time(x, kernels): + y = x + is_done = True + for k in kernels: + y, has_update = remove_pattern(y, k) + if has_update: + is_done = False + return y, is_done + + +def lvmin_thin(x, prunings=True): + y = x + for i in range(32): + y, is_done = thin_one_time(y, lvmin_kernels) + if is_done: + break + if prunings: + y, _ = thin_one_time(y, lvmin_prunings) + return y + + +def nake_nms(x): + f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) + f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) + f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) + f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) + y = np.zeros_like(x) + for f in [f1, f2, f3, f4]: + np.putmask(y, cv2.dilate(x, kernel=f) == x, x) + return y + + +################################################################################ +# copied from Mikubill/sd-webui-controlnet external_code.py and modified for InvokeAI +################################################################################ +# FIXME: not using yet, if used in the future will most likely require modification of preprocessors +def pixel_perfect_resolution( + image: np.ndarray, + target_H: int, + target_W: int, + resize_mode: str, +) -> int: + """ + Calculate the estimated resolution for resizing an image while preserving aspect ratio. + + The function first calculates scaling factors for height and width of the image based on the target + height and width. Then, based on the chosen resize mode, it either takes the smaller or the larger + scaling factor to estimate the new resolution. + + If the resize mode is OUTER_FIT, the function uses the smaller scaling factor, ensuring the whole image + fits within the target dimensions, potentially leaving some empty space. + + If the resize mode is not OUTER_FIT, the function uses the larger scaling factor, ensuring the target + dimensions are fully filled, potentially cropping the image. + + After calculating the estimated resolution, the function prints some debugging information. + + Args: + image (np.ndarray): A 3D numpy array representing an image. The dimensions represent [height, width, channels]. + target_H (int): The target height for the image. + target_W (int): The target width for the image. + resize_mode (ResizeMode): The mode for resizing. + + Returns: + int: The estimated resolution after resizing. + """ + raw_H, raw_W, _ = image.shape + + k0 = float(target_H) / float(raw_H) + k1 = float(target_W) / float(raw_W) + + if resize_mode == "fill_resize": + estimation = min(k0, k1) * float(min(raw_H, raw_W)) + else: # "crop_resize" or "just_resize" (or possibly "just_resize_simple"?) + estimation = max(k0, k1) * float(min(raw_H, raw_W)) + + # print(f"Pixel Perfect Computation:") + # print(f"resize_mode = {resize_mode}") + # print(f"raw_H = {raw_H}") + # print(f"raw_W = {raw_W}") + # print(f"target_H = {target_H}") + # print(f"target_W = {target_W}") + # print(f"estimation = {estimation}") + + return int(np.round(estimation)) + + +########################################################################### +# Copied from detectmap_proc method in scripts/detectmap_proc.py in Mikubill/sd-webui-controlnet +# modified for InvokeAI +########################################################################### +# def detectmap_proc(detected_map, module, resize_mode, h, w): +def np_img_resize( + np_img: np.ndarray, + resize_mode: str, + h: int, + w: int, + device: torch.device = torch.device('cpu') +): + # if 'inpaint' in module: + # np_img = np_img.astype(np.float32) + # else: + # np_img = HWC3(np_img) + np_img = HWC3(np_img) + + def safe_numpy(x): + # A very safe method to make sure that Apple/Mac works + y = x + + # below is very boring but do not change these. If you change these Apple or Mac may fail. + y = y.copy() + y = np.ascontiguousarray(y) + y = y.copy() + return y + + def get_pytorch_control(x): + # A very safe method to make sure that Apple/Mac works + y = x + + # below is very boring but do not change these. If you change these Apple or Mac may fail. + y = torch.from_numpy(y) + y = y.float() / 255.0 + y = rearrange(y, 'h w c -> 1 c h w') + y = y.clone() + # y = y.to(devices.get_device_for("controlnet")) + y = y.to(device) + y = y.clone() + return y + + def high_quality_resize(x: np.ndarray, + size): + # Written by lvmin + # Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges + inpaint_mask = None + if x.ndim == 3 and x.shape[2] == 4: + inpaint_mask = x[:, :, 3] + x = x[:, :, 0:3] + + new_size_is_smaller = (size[0] * size[1]) < (x.shape[0] * x.shape[1]) + new_size_is_bigger = (size[0] * size[1]) > (x.shape[0] * x.shape[1]) + unique_color_count = np.unique(x.reshape(-1, x.shape[2]), axis=0).shape[0] + is_one_pixel_edge = False + is_binary = False + if unique_color_count == 2: + is_binary = np.min(x) < 16 and np.max(x) > 240 + if is_binary: + xc = x + xc = cv2.erode(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1) + xc = cv2.dilate(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1) + one_pixel_edge_count = np.where(xc < x)[0].shape[0] + all_edge_count = np.where(x > 127)[0].shape[0] + is_one_pixel_edge = one_pixel_edge_count * 2 > all_edge_count + + if 2 < unique_color_count < 200: + interpolation = cv2.INTER_NEAREST + elif new_size_is_smaller: + interpolation = cv2.INTER_AREA + else: + interpolation = cv2.INTER_CUBIC # Must be CUBIC because we now use nms. NEVER CHANGE THIS + + y = cv2.resize(x, size, interpolation=interpolation) + if inpaint_mask is not None: + inpaint_mask = cv2.resize(inpaint_mask, size, interpolation=interpolation) + + if is_binary: + y = np.mean(y.astype(np.float32), axis=2).clip(0, 255).astype(np.uint8) + if is_one_pixel_edge: + y = nake_nms(y) + _, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) + y = lvmin_thin(y, prunings=new_size_is_bigger) + else: + _, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) + y = np.stack([y] * 3, axis=2) + + if inpaint_mask is not None: + inpaint_mask = (inpaint_mask > 127).astype(np.float32) * 255.0 + inpaint_mask = inpaint_mask[:, :, None].clip(0, 255).astype(np.uint8) + y = np.concatenate([y, inpaint_mask], axis=2) + + return y + + # if resize_mode == external_code.ResizeMode.RESIZE: + if resize_mode == "just_resize": # RESIZE + np_img = high_quality_resize(np_img, (w, h)) + np_img = safe_numpy(np_img) + return get_pytorch_control(np_img), np_img + + old_h, old_w, _ = np_img.shape + old_w = float(old_w) + old_h = float(old_h) + k0 = float(h) / old_h + k1 = float(w) / old_w + + safeint = lambda x: int(np.round(x)) + + # if resize_mode == external_code.ResizeMode.OUTER_FIT: + if resize_mode == "fill_resize": # OUTER_FIT + k = min(k0, k1) + borders = np.concatenate([np_img[0, :, :], np_img[-1, :, :], np_img[:, 0, :], np_img[:, -1, :]], axis=0) + high_quality_border_color = np.median(borders, axis=0).astype(np_img.dtype) + if len(high_quality_border_color) == 4: + # Inpaint hijack + high_quality_border_color[3] = 255 + high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1]) + np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k))) + new_h, new_w, _ = np_img.shape + pad_h = max(0, (h - new_h) // 2) + pad_w = max(0, (w - new_w) // 2) + high_quality_background[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = np_img + np_img = high_quality_background + np_img = safe_numpy(np_img) + return get_pytorch_control(np_img), np_img + else: # resize_mode == "crop_resize" (INNER_FIT) + k = max(k0, k1) + np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k))) + new_h, new_w, _ = np_img.shape + pad_h = max(0, (new_h - h) // 2) + pad_w = max(0, (new_w - w) // 2) + np_img = np_img[pad_h:pad_h + h, pad_w:pad_w + w] + np_img = safe_numpy(np_img) + return get_pytorch_control(np_img), np_img + +def prepare_control_image( + # image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]] + # but now should be able to assume that image is a single PIL.Image, which simplifies things + image: Image, + # FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions? + # latents_to_match_resolution, # TorchTensor of shape (batch_size, 3, height, width) + width=512, # should be 8 * latent.shape[3] + height=512, # should be 8 * latent height[2] + # batch_size=1, # currently no batching + # num_images_per_prompt=1, # currently only single image + device="cuda", + dtype=torch.float16, + do_classifier_free_guidance=True, + control_mode="balanced", + resize_mode="just_resize_simple", +): + # FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out + if (resize_mode == "just_resize_simple" or + resize_mode == "crop_resize_simple" or + resize_mode == "fill_resize_simple"): + image = image.convert("RGB") + if (resize_mode == "just_resize_simple"): + image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) + elif (resize_mode == "crop_resize_simple"): # not yet implemented + pass + elif (resize_mode == "fill_resize_simple"): # not yet implemented + pass + nimage = np.array(image) + nimage = nimage[None, :] + nimage = np.concatenate([nimage], axis=0) + # normalizing RGB values to [0,1] range (in PIL.Image they are [0-255]) + nimage = np.array(nimage).astype(np.float32) / 255.0 + nimage = nimage.transpose(0, 3, 1, 2) + timage = torch.from_numpy(nimage) + + # use fancy lvmin controlnet resizing + elif (resize_mode == "just_resize" or resize_mode == "crop_resize" or resize_mode == "fill_resize"): + nimage = np.array(image) + timage, nimage = np_img_resize( + np_img=nimage, + resize_mode=resize_mode, + h=height, + w=width, + # device=torch.device('cpu') + device=device, + ) + else: + pass + print("ERROR: invalid resize_mode ==> ", resize_mode) + exit(1) + + timage = timage.to(device=device, dtype=dtype) + cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced") + if do_classifier_free_guidance and not cfg_injection: + timage = torch.cat([timage] * 2) + return timage diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 228fbd0585..8acfb100a6 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -219,6 +219,7 @@ class ControlNetData: begin_step_percent: float = Field(default=0.0) end_step_percent: float = Field(default=1.0) control_mode: str = Field(default="balanced") + resize_mode: str = Field(default="just_resize") @dataclass @@ -653,7 +654,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if cfg_injection: # Inferred ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, - # add 0 to the unconditional batch to keep it unchanged. + # prepend zeros for unconditional batch down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples] mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample]) @@ -954,53 +955,3 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): debug_image( img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True ) - - # Copied from diffusers pipeline_stable_diffusion_controlnet.py - # Returns torch.Tensor of shape (batch_size, 3, height, width) - @staticmethod - def prepare_control_image( - image, - # FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions? - # latents, - width=512, # should be 8 * latent.shape[3] - height=512, # should be 8 * latent height[2] - batch_size=1, - num_images_per_prompt=1, - device="cuda", - dtype=torch.float16, - do_classifier_free_guidance=True, - control_mode="balanced" - ): - - if not isinstance(image, torch.Tensor): - if isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - images = [] - for image_ in image: - image_ = image_.convert("RGB") - image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) - image_ = np.array(image_) - image_ = image_[None, :] - images.append(image_) - image = images - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - - image_batch_size = image.shape[0] - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - image = image.repeat_interleave(repeat_by, dim=0) - image = image.to(device=device, dtype=dtype) - cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced") - if do_classifier_free_guidance and not cfg_injection: - image = torch.cat([image] * 2) - return image diff --git a/invokeai/frontend/web/src/common/components/IAIDndImage.tsx b/invokeai/frontend/web/src/common/components/IAIDndImage.tsx index 6082843c55..57d54e155e 100644 --- a/invokeai/frontend/web/src/common/components/IAIDndImage.tsx +++ b/invokeai/frontend/web/src/common/components/IAIDndImage.tsx @@ -17,13 +17,13 @@ import { } from 'common/components/IAIImageFallback'; import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay'; import { useImageUploadButton } from 'common/hooks/useImageUploadButton'; +import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu'; import { MouseEvent, ReactElement, SyntheticEvent, memo } from 'react'; import { FaImage, FaUndo, FaUpload } from 'react-icons/fa'; import { ImageDTO, PostUploadAction } from 'services/api/types'; import { mode } from 'theme/util/mode'; import IAIDraggable from './IAIDraggable'; import IAIDroppable from './IAIDroppable'; -import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu'; type IAIDndImageProps = { imageDTO: ImageDTO | undefined; @@ -148,7 +148,9 @@ const IAIDndImage = (props: IAIDndImageProps) => { maxH: 'full', borderRadius: 'base', shadow: isSelected ? 'selected.light' : undefined, - _dark: { shadow: isSelected ? 'selected.dark' : undefined }, + _dark: { + shadow: isSelected ? 'selected.dark' : undefined, + }, ...imageSx, }} /> @@ -183,13 +185,6 @@ const IAIDndImage = (props: IAIDndImageProps) => { )} {!imageDTO && isUploadDisabled && noContentFallback} - {!isDropDisabled && ( - - )} {imageDTO && !isDragDisabled && ( { onClick={onClick} /> )} + {!isDropDisabled && ( + + )} {onClickReset && withResetIcon && imageDTO && ( ; }; const IAIDroppable = (props: IAIDroppableProps) => { - const { dropLabel, data, disabled } = props; + const { dropLabel, data, disabled, hoverRef } = props; const dndId = useRef(uuidv4()); const { isOver, setNodeRef, active } = useDroppable({ diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx index 368b9f727c..d858e46fdb 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx @@ -24,6 +24,7 @@ import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig'; import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd'; import ParamControlNetControlMode from './parameters/ParamControlNetControlMode'; import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect'; +import ParamControlNetResizeMode from './parameters/ParamControlNetResizeMode'; type ControlNetProps = { controlNetId: string; @@ -68,7 +69,7 @@ const ControlNet = (props: ControlNetProps) => { { tooltip={isExpanded ? 'Hide Advanced' : 'Show Advanced'} aria-label={isExpanded ? 'Hide Advanced' : 'Show Advanced'} onClick={toggleIsExpanded} - variant="link" + variant="ghost" + sx={{ + _hover: { + bg: 'none', + }, + }} icon={ { /> )} - + { h: 28, w: 28, aspectRatio: '1/1', - mt: 3, }} > )} - + - + + diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetResizeMode.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetResizeMode.tsx new file mode 100644 index 0000000000..4b31ebfc64 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetResizeMode.tsx @@ -0,0 +1,62 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; +import { + ResizeModes, + controlNetResizeModeChanged, +} from 'features/controlNet/store/controlNetSlice'; +import { useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; + +type ParamControlNetResizeModeProps = { + controlNetId: string; +}; + +const RESIZE_MODE_DATA = [ + { label: 'Resize', value: 'just_resize' }, + { label: 'Crop', value: 'crop_resize' }, + { label: 'Fill', value: 'fill_resize' }, +]; + +export default function ParamControlNetResizeMode( + props: ParamControlNetResizeModeProps +) { + const { controlNetId } = props; + const dispatch = useAppDispatch(); + const selector = useMemo( + () => + createSelector( + stateSelector, + ({ controlNet }) => { + const { resizeMode, isEnabled } = + controlNet.controlNets[controlNetId]; + return { resizeMode, isEnabled }; + }, + defaultSelectorOptions + ), + [controlNetId] + ); + + const { resizeMode, isEnabled } = useAppSelector(selector); + + const { t } = useTranslation(); + + const handleResizeModeChange = useCallback( + (resizeMode: ResizeModes) => { + dispatch(controlNetResizeModeChanged({ controlNetId, resizeMode })); + }, + [controlNetId, dispatch] + ); + + return ( + + ); +} diff --git a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts index 663edfd65f..2f8668115a 100644 --- a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts +++ b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts @@ -3,6 +3,7 @@ import { RootState } from 'app/store/store'; import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas'; import { cloneDeep, forEach } from 'lodash-es'; import { imagesApi } from 'services/api/endpoints/images'; +import { components } from 'services/api/schema'; import { isAnySessionRejected } from 'services/api/thunks/session'; import { appSocketInvocationError } from 'services/events/actions'; import { controlNetImageProcessed } from './actions'; @@ -16,11 +17,13 @@ import { RequiredControlNetProcessorNode, } from './types'; -export type ControlModes = - | 'balanced' - | 'more_prompt' - | 'more_control' - | 'unbalanced'; +export type ControlModes = NonNullable< + components['schemas']['ControlNetInvocation']['control_mode'] +>; + +export type ResizeModes = NonNullable< + components['schemas']['ControlNetInvocation']['resize_mode'] +>; export const initialControlNet: Omit = { isEnabled: true, @@ -29,6 +32,7 @@ export const initialControlNet: Omit = { beginStepPct: 0, endStepPct: 1, controlMode: 'balanced', + resizeMode: 'just_resize', controlImage: null, processedControlImage: null, processorType: 'canny_image_processor', @@ -45,6 +49,7 @@ export type ControlNetConfig = { beginStepPct: number; endStepPct: number; controlMode: ControlModes; + resizeMode: ResizeModes; controlImage: string | null; processedControlImage: string | null; processorType: ControlNetProcessorType; @@ -215,6 +220,16 @@ export const controlNetSlice = createSlice({ const { controlNetId, controlMode } = action.payload; state.controlNets[controlNetId].controlMode = controlMode; }, + controlNetResizeModeChanged: ( + state, + action: PayloadAction<{ + controlNetId: string; + resizeMode: ResizeModes; + }> + ) => { + const { controlNetId, resizeMode } = action.payload; + state.controlNets[controlNetId].resizeMode = resizeMode; + }, controlNetProcessorParamsChanged: ( state, action: PayloadAction<{ @@ -342,6 +357,7 @@ export const { controlNetBeginStepPctChanged, controlNetEndStepPctChanged, controlNetControlModeChanged, + controlNetResizeModeChanged, controlNetProcessorParamsChanged, controlNetProcessorTypeChanged, controlNetReset, diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx index fa3a6b03be..3b3303f0c8 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx @@ -23,34 +23,32 @@ const BoardContextMenu = memo( dispatch(boardIdSelected(board?.board_id ?? board_id)); }, [board?.board_id, board_id, dispatch]); return ( - - - menuProps={{ size: 'sm', isLazy: true }} - menuButtonProps={{ - bg: 'transparent', - _hover: { bg: 'transparent' }, - }} - renderMenu={() => ( - - } onClickCapture={handleSelectBoard}> - Select Board - - {!board && } - {board && ( - - )} - - )} - > - {children} - - + + menuProps={{ size: 'sm', isLazy: true }} + menuButtonProps={{ + bg: 'transparent', + _hover: { bg: 'transparent' }, + }} + renderMenu={() => ( + + } onClickCapture={handleSelectBoard}> + Select Board + + {!board && } + {board && ( + + )} + + )} + > + {children} + ); } ); diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardsList.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardsList.tsx index 61b8856ff9..60be0c4ab3 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardsList.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardsList.tsx @@ -1,27 +1,21 @@ -import { - Collapse, - Flex, - Grid, - GridItem, - useDisclosure, -} from '@chakra-ui/react'; +import { ButtonGroup, Collapse, Flex, Grid, GridItem } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIIconButton from 'common/components/IAIIconButton'; +import { AnimatePresence, motion } from 'framer-motion'; import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; -import { memo, useState } from 'react'; +import { memo, useCallback, useState } from 'react'; +import { FaSearch } from 'react-icons/fa'; import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; +import { BoardDTO } from 'services/api/types'; import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus'; +import DeleteBoardModal from '../DeleteBoardModal'; import AddBoardButton from './AddBoardButton'; -import AllAssetsBoard from './AllAssetsBoard'; -import AllImagesBoard from './AllImagesBoard'; -import BatchBoard from './BatchBoard'; import BoardsSearch from './BoardsSearch'; import GalleryBoard from './GalleryBoard'; -import NoBoardBoard from './NoBoardBoard'; -import DeleteBoardModal from '../DeleteBoardModal'; -import { BoardDTO } from 'services/api/types'; +import SystemBoardButton from './SystemBoardButton'; const selector = createSelector( [stateSelector], @@ -48,7 +42,10 @@ const BoardsList = (props: Props) => { ) : boards; const [boardToDelete, setBoardToDelete] = useState(); - const [searchMode, setSearchMode] = useState(false); + const [isSearching, setIsSearching] = useState(false); + const handleClickSearchIcon = useCallback(() => { + setIsSearching((v) => !v); + }, []); return ( <> @@ -64,7 +61,54 @@ const BoardsList = (props: Props) => { }} > - + + {isSearching ? ( + + + + ) : ( + + + + + + + + )} + + } + /> { - {!searchMode && ( - <> - - - - - - - - - - {isBatchEnabled && ( - - - - )} - - )} {filteredBoards && filteredBoards.map((board) => ( diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardsSearch.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardsSearch.tsx index fffe50f6a7..f556b83d24 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardsSearch.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardsSearch.tsx @@ -10,7 +10,14 @@ import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { setBoardSearchText } from 'features/gallery/store/boardSlice'; -import { memo } from 'react'; +import { + ChangeEvent, + KeyboardEvent, + memo, + useCallback, + useEffect, + useRef, +} from 'react'; const selector = createSelector( [stateSelector], @@ -22,31 +29,60 @@ const selector = createSelector( ); type Props = { - setSearchMode: (searchMode: boolean) => void; + setIsSearching: (isSearching: boolean) => void; }; const BoardsSearch = (props: Props) => { - const { setSearchMode } = props; + const { setIsSearching } = props; const dispatch = useAppDispatch(); const { searchText } = useAppSelector(selector); + const inputRef = useRef(null); - const handleBoardSearch = (searchTerm: string) => { - setSearchMode(searchTerm.length > 0); - dispatch(setBoardSearchText(searchTerm)); - }; - const clearBoardSearch = () => { - setSearchMode(false); + const handleBoardSearch = useCallback( + (searchTerm: string) => { + dispatch(setBoardSearchText(searchTerm)); + }, + [dispatch] + ); + + const clearBoardSearch = useCallback(() => { dispatch(setBoardSearchText('')); - }; + setIsSearching(false); + }, [dispatch, setIsSearching]); + + const handleKeydown = useCallback( + (e: KeyboardEvent) => { + // exit search mode on escape + if (e.key === 'Escape') { + clearBoardSearch(); + } + }, + [clearBoardSearch] + ); + + const handleChange = useCallback( + (e: ChangeEvent) => { + handleBoardSearch(e.target.value); + }, + [handleBoardSearch] + ); + + useEffect(() => { + // focus the search box on mount + if (!inputRef.current) { + return; + } + inputRef.current.focus(); + }, []); return ( { - handleBoardSearch(e.target.value); - }} + onKeyDown={handleKeydown} + onChange={handleChange} /> {searchText && searchText.length && ( @@ -55,7 +91,8 @@ const BoardsSearch = (props: Props) => { size="xs" variant="ghost" aria-label="Clear Search" - icon={} + opacity={0.5} + icon={} /> )} diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx index 46e7cbcca8..5d76ad743c 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx @@ -6,9 +6,9 @@ import { EditableInput, EditablePreview, Flex, + Icon, Image, Text, - useColorMode, } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; import { skipToken } from '@reduxjs/toolkit/dist/query'; @@ -17,14 +17,12 @@ import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIDroppable from 'common/components/IAIDroppable'; -import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import { boardIdSelected } from 'features/gallery/store/gallerySlice'; -import { memo, useCallback, useMemo } from 'react'; -import { FaUser } from 'react-icons/fa'; +import { memo, useCallback, useMemo, useState } from 'react'; +import { FaFolder } from 'react-icons/fa'; import { useUpdateBoardMutation } from 'services/api/endpoints/boards'; import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { BoardDTO } from 'services/api/types'; -import { mode } from 'theme/util/mode'; import BoardContextMenu from '../BoardContextMenu'; const AUTO_ADD_BADGE_STYLES: ChakraProps['sx'] = { @@ -66,8 +64,9 @@ const GalleryBoard = memo( board.cover_image_name ?? skipToken ); - const { colorMode } = useColorMode(); const { board_name, board_id } = board; + const [localBoardName, setLocalBoardName] = useState(board_name); + const handleSelectBoard = useCallback(() => { dispatch(boardIdSelected(board_id)); }, [board_id, dispatch]); @@ -75,10 +74,6 @@ const GalleryBoard = memo( const [updateBoard, { isLoading: isUpdateBoardLoading }] = useUpdateBoardMutation(); - const handleUpdateBoardName = (newBoardName: string) => { - updateBoard({ board_id, changes: { board_name: newBoardName } }); - }; - const droppableData: MoveBoardDropData = useMemo( () => ({ id: board_id, @@ -88,59 +83,116 @@ const GalleryBoard = memo( [board_id] ); + const handleSubmit = useCallback( + (newBoardName: string) => { + if (!newBoardName) { + // empty strings are not allowed + setLocalBoardName(board_name); + return; + } + if (newBoardName === board_name) { + // don't updated the board name if it hasn't changed + return; + } + updateBoard({ board_id, changes: { board_name: newBoardName } }) + .unwrap() + .then((response) => { + // update local state + setLocalBoardName(response.board_name); + }) + .catch(() => { + // revert on error + setLocalBoardName(board_name); + }); + }, + [board_id, board_name, updateBoard] + ); + + const handleChange = useCallback((newBoardName: string) => { + setLocalBoardName(newBoardName); + }, []); + return ( - - + - {(ref) => ( - + + {(ref) => ( - {board.cover_image_name && coverImage?.thumbnail_url && ( - - )} - {!(board.cover_image_name && coverImage?.thumbnail_url) && ( - - )} + + {coverImage?.thumbnail_url ? ( + + ) : ( + + + + )} + + + + + + + + + Move} /> - - - { - handleUpdateBoardName(nextValue); - }} - sx={{ maxW: 'full' }} - > - - - - - - )} - + )} + + ); } diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GenericBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GenericBoard.tsx index 226100c490..fa7f944a24 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GenericBoard.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GenericBoard.tsx @@ -17,7 +17,7 @@ type GenericBoardProps = { badgeCount?: number; }; -const formatBadgeCount = (count: number) => +export const formatBadgeCount = (count: number) => Intl.NumberFormat('en-US', { notation: 'compact', maximumFractionDigits: 1, @@ -92,7 +92,7 @@ const GenericBoard = (props: GenericBoardProps) => { h: 'full', alignItems: 'center', fontWeight: isSelected ? 600 : undefined, - fontSize: 'xs', + fontSize: 'sm', color: isSelected ? 'base.900' : 'base.700', _dark: { color: isSelected ? 'base.50' : 'base.200' }, }} diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/SystemBoardButton.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/SystemBoardButton.tsx new file mode 100644 index 0000000000..b538eee9d1 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/SystemBoardButton.tsx @@ -0,0 +1,53 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIButton from 'common/components/IAIButton'; +import { boardIdSelected } from 'features/gallery/store/gallerySlice'; +import { memo, useCallback, useMemo } from 'react'; +import { useBoardName } from 'services/api/hooks/useBoardName'; + +type Props = { + board_id: 'images' | 'assets' | 'no_board'; +}; + +const SystemBoardButton = ({ board_id }: Props) => { + const dispatch = useAppDispatch(); + + const selector = useMemo( + () => + createSelector( + [stateSelector], + ({ gallery }) => { + const { selectedBoardId } = gallery; + return { isSelected: selectedBoardId === board_id }; + }, + defaultSelectorOptions + ), + [board_id] + ); + + const { isSelected } = useAppSelector(selector); + + const boardName = useBoardName(board_id); + + const handleClick = useCallback(() => { + dispatch(boardIdSelected(board_id)); + }, [board_id, dispatch]); + + return ( + + {boardName} + + ); +}; + +export default memo(SystemBoardButton); diff --git a/invokeai/frontend/web/src/features/gallery/components/GalleryBoardName.tsx b/invokeai/frontend/web/src/features/gallery/components/GalleryBoardName.tsx index 12454dd15b..27565a52aa 100644 --- a/invokeai/frontend/web/src/features/gallery/components/GalleryBoardName.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/GalleryBoardName.tsx @@ -4,8 +4,9 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { memo } from 'react'; +import { memo, useMemo } from 'react'; import { useBoardName } from 'services/api/hooks/useBoardName'; +import { useBoardTotal } from 'services/api/hooks/useBoardTotal'; const selector = createSelector( [stateSelector], @@ -26,6 +27,17 @@ const GalleryBoardName = (props: Props) => { const { isOpen, onToggle } = props; const { selectedBoardId } = useAppSelector(selector); const boardName = useBoardName(selectedBoardId); + const numOfBoardImages = useBoardTotal(selectedBoardId); + + const formattedBoardName = useMemo(() => { + if (!boardName || !numOfBoardImages) { + return ''; + } + if (boardName.length > 20) { + return `${boardName.substring(0, 20)}... (${numOfBoardImages})`; + } + return `${boardName} (${numOfBoardImages})`; + }, [boardName, numOfBoardImages]); return ( { }, }} > - {boardName} + {formattedBoardName} diff --git a/invokeai/frontend/web/src/features/gallery/components/GalleryPanel.tsx b/invokeai/frontend/web/src/features/gallery/components/GalleryPanel.tsx index 2aa44e50a1..1bbec03f3e 100644 --- a/invokeai/frontend/web/src/features/gallery/components/GalleryPanel.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/GalleryPanel.tsx @@ -109,7 +109,7 @@ const GalleryDrawer = () => { isResizable={true} isOpen={shouldShowGallery} onClose={handleCloseGallery} - minWidth={337} + minWidth={400} > diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx index dcce3a1b18..bf627b9591 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx @@ -1,4 +1,4 @@ -import { Box } from '@chakra-ui/react'; +import { Box, Flex } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd'; import { stateSelector } from 'app/store/store'; @@ -86,38 +86,31 @@ const GalleryImage = (props: HoverableImageProps) => { return ( - - {(ref) => ( - - } - // resetTooltip="Delete image" - // withResetIcon // removed bc it's too easy to accidentally delete images - /> - - )} - + + } + // resetTooltip="Delete image" + // withResetIcon // removed bc it's too easy to accidentally delete images + /> + ); }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addControlNetToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addControlNetToLinearGraph.ts index 0f882f248d..578c4371f2 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addControlNetToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addControlNetToLinearGraph.ts @@ -48,6 +48,7 @@ export const addControlNetToLinearGraph = ( beginStepPct, endStepPct, controlMode, + resizeMode, model, processorType, weight, @@ -60,6 +61,7 @@ export const addControlNetToLinearGraph = ( begin_step_percent: beginStepPct, end_step_percent: endStepPct, control_mode: controlMode, + resize_mode: resizeMode, control_model: model as ControlNetInvocation['control_model'], control_weight: weight, }; diff --git a/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx b/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx index 94195a27c1..6c683470e7 100644 --- a/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx +++ b/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx @@ -105,7 +105,7 @@ const enabledTabsSelector = createSelector( } ); -const MIN_GALLERY_WIDTH = 300; +const MIN_GALLERY_WIDTH = 350; const DEFAULT_GALLERY_PCT = 20; export const NO_GALLERY_TABS: InvokeTabName[] = ['modelManager']; diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ResizeHandle.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ResizeHandle.tsx index 7ef0b48784..57f2e89ef0 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ResizeHandle.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ResizeHandle.tsx @@ -3,7 +3,7 @@ import { memo } from 'react'; import { PanelResizeHandle } from 'react-resizable-panels'; import { mode } from 'theme/util/mode'; -type ResizeHandleProps = FlexProps & { +type ResizeHandleProps = Omit & { direction?: 'horizontal' | 'vertical'; }; diff --git a/invokeai/frontend/web/src/services/api/endpoints/images.ts b/invokeai/frontend/web/src/services/api/endpoints/images.ts index 52f410e315..5eeb86d9c5 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/images.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/images.ts @@ -169,10 +169,11 @@ export const imagesApi = api.injectEndpoints({ ], async onQueryStarted(imageDTO, { dispatch, queryFulfilled }) { /** - * Cache changes for deleteImage: - * - Remove from "All Images" - * - Remove from image's `board_id` if it has one, or "No Board" if not - * - Remove from "Batch" + * Cache changes for `deleteImage`: + * - *remove* from "All Images" / "All Assets" + * - IF it has a board: + * - THEN *remove* from it's own board + * - ELSE *remove* from "No Board" */ const { image_name, board_id, image_category } = imageDTO; @@ -181,22 +182,23 @@ export const imagesApi = api.injectEndpoints({ // That means constructing the possible query args that are serialized into the cache key... const removeFromCacheKeys: ListImagesArgs[] = []; + + // determine `categories`, i.e. do we update "All Images" or "All Assets" const categories = IMAGE_CATEGORIES.includes(image_category) ? IMAGE_CATEGORIES : ASSETS_CATEGORIES; - // All Images board (e.g. no board) + // remove from "All Images" removeFromCacheKeys.push({ categories }); - // Board specific if (board_id) { + // remove from it's own board removeFromCacheKeys.push({ board_id }); } else { - // TODO: No Board + // remove from "No Board" + removeFromCacheKeys.push({ board_id: 'none' }); } - // TODO: Batch - const patches: PatchCollection[] = []; removeFromCacheKeys.forEach((cacheKey) => { patches.push( @@ -240,32 +242,37 @@ export const imagesApi = api.injectEndpoints({ { imageDTO: oldImageDTO, changes: _changes }, { dispatch, queryFulfilled, getState } ) { - // TODO: Should we handle changes to boards via this mutation? Seems reasonable... - // let's be extra-sure we do not accidentally change categories const changes = omit(_changes, 'image_category'); /** - * Cache changes for `updateImage`: - * - Update the ImageDTO - * - Update the image in "All Images" board: - * - IF it is in the date range represented by the cache: - * - add the image IF it is not already in the cache & update the total - * - ELSE update the image IF it is already in the cache + * Cache changes for "updateImage": + * - *update* "getImageDTO" cache + * - for "All Images" || "All Assets": + * - IF it is not already in the cache + * - THEN *add* it to "All Images" / "All Assets" and update the total + * - ELSE *update* it * - IF the image has a board: - * - Update the image in it's own board - * - ELSE Update the image in the "No Board" board (TODO) + * - THEN *update* it's own board + * - ELSE *update* the "No Board" board */ const patches: PatchCollection[] = []; - const { image_name, board_id, image_category } = oldImageDTO; + const { image_name, board_id, image_category, is_intermediate } = + oldImageDTO; + + const isChangingFromIntermediate = changes.is_intermediate === false; + // do not add intermediates to gallery cache + if (is_intermediate && !isChangingFromIntermediate) { + return; + } + + // determine `categories`, i.e. do we update "All Images" or "All Assets" const categories = IMAGE_CATEGORIES.includes(image_category) ? IMAGE_CATEGORIES : ASSETS_CATEGORIES; - // TODO: No Board - - // Update `getImageDTO` cache + // update `getImageDTO` cache patches.push( dispatch( imagesApi.util.updateQueryData( @@ -281,9 +288,13 @@ export const imagesApi = api.injectEndpoints({ // Update the "All Image" or "All Assets" board const queryArgsToUpdate: ListImagesArgs[] = [{ categories }]; + // IF the image has a board: if (board_id) { - // We also need to update the user board + // THEN update it's own board queryArgsToUpdate.push({ board_id }); + } else { + // ELSE update the "No Board" board + queryArgsToUpdate.push({ board_id: 'none' }); } queryArgsToUpdate.forEach((queryArg) => { @@ -371,12 +382,12 @@ export const imagesApi = api.injectEndpoints({ return; } - // Add the image to the "All Images" / "All Assets" board - const queryArg = { - categories: IMAGE_CATEGORIES.includes(image_category) - ? IMAGE_CATEGORIES - : ASSETS_CATEGORIES, - }; + // determine `categories`, i.e. do we update "All Images" or "All Assets" + const categories = IMAGE_CATEGORIES.includes(image_category) + ? IMAGE_CATEGORIES + : ASSETS_CATEGORIES; + + const queryArg = { categories }; dispatch( imagesApi.util.updateQueryData('listImages', queryArg, (draft) => { @@ -410,16 +421,14 @@ export const imagesApi = api.injectEndpoints({ { dispatch, queryFulfilled, getState } ) { /** - * Cache changes for addImageToBoard: - * - Remove from "No Board" - * - Remove from `old_board_id` if it has one - * - Add to new `board_id` - * - IF the image's `created_at` is within the range of the board's cached images + * Cache changes for `addImageToBoard`: + * - *update* the `getImageDTO` cache + * - *remove* from "No Board" + * - IF the image has an old `board_id`: + * - THEN *remove* from it's old `board_id` + * - IF the image's `created_at` is within the range of the board's cached images * - OR the board cache has length of 0 or 1 - * - Update the `total` for each board whose cache is updated - * - Update the ImageDTO - * - * TODO: maybe total should just be updated in the boards endpoints? + * - THEN *add* it to new `board_id` */ const { image_name, board_id: old_board_id } = oldImageDTO; @@ -427,13 +436,10 @@ export const imagesApi = api.injectEndpoints({ // Figure out the `listImages` caches that we need to update const removeFromQueryArgs: ListImagesArgs[] = []; - // TODO: No Board - // TODO: Batch - - // Remove from No Board + // remove from "No Board" removeFromQueryArgs.push({ board_id: 'none' }); - // Remove from old board + // remove from old board if (old_board_id) { removeFromQueryArgs.push({ board_id: old_board_id }); } @@ -534,17 +540,15 @@ export const imagesApi = api.injectEndpoints({ { dispatch, queryFulfilled, getState } ) { /** - * Cache changes for removeImageFromBoard: - * - Add to "No Board" - * - IF the image's `created_at` is within the range of the board's cached images - * - Remove from `old_board_id` - * - Update the ImageDTO + * Cache changes for `removeImageFromBoard`: + * - *update* `getImageDTO` + * - IF the image's `created_at` is within the range of the board's cached images + * - THEN *add* to "No Board" + * - *remove* from `old_board_id` */ const { image_name, board_id: old_board_id } = imageDTO; - // TODO: Batch - const patches: PatchCollection[] = []; // Updated imageDTO with new board_id diff --git a/invokeai/frontend/web/src/services/api/hooks/useBoardName.ts b/invokeai/frontend/web/src/services/api/hooks/useBoardName.ts index d63b6e0425..cbe0ec1808 100644 --- a/invokeai/frontend/web/src/services/api/hooks/useBoardName.ts +++ b/invokeai/frontend/web/src/services/api/hooks/useBoardName.ts @@ -6,9 +6,9 @@ export const useBoardName = (board_id: BoardId | null | undefined) => { selectFromResult: ({ data }) => { let boardName = ''; if (board_id === 'images') { - boardName = 'All Images'; + boardName = 'Images'; } else if (board_id === 'assets') { - boardName = 'All Assets'; + boardName = 'Assets'; } else if (board_id === 'no_board') { boardName = 'No Board'; } else if (board_id === 'batch') { diff --git a/invokeai/frontend/web/src/services/api/hooks/useBoardTotal.ts b/invokeai/frontend/web/src/services/api/hooks/useBoardTotal.ts new file mode 100644 index 0000000000..8deccd8947 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/hooks/useBoardTotal.ts @@ -0,0 +1,53 @@ +import { skipToken } from '@reduxjs/toolkit/dist/query'; +import { + ASSETS_CATEGORIES, + BoardId, + IMAGE_CATEGORIES, + INITIAL_IMAGE_LIMIT, +} from 'features/gallery/store/gallerySlice'; +import { useMemo } from 'react'; +import { ListImagesArgs, useListImagesQuery } from '../endpoints/images'; + +const baseQueryArg: ListImagesArgs = { + offset: 0, + limit: INITIAL_IMAGE_LIMIT, + is_intermediate: false, +}; + +const imagesQueryArg: ListImagesArgs = { + categories: IMAGE_CATEGORIES, + ...baseQueryArg, +}; + +const assetsQueryArg: ListImagesArgs = { + categories: ASSETS_CATEGORIES, + ...baseQueryArg, +}; + +const noBoardQueryArg: ListImagesArgs = { + board_id: 'none', + ...baseQueryArg, +}; + +export const useBoardTotal = (board_id: BoardId | null | undefined) => { + const queryArg = useMemo(() => { + if (!board_id) { + return; + } + if (board_id === 'images') { + return imagesQueryArg; + } else if (board_id === 'assets') { + return assetsQueryArg; + } else if (board_id === 'no_board') { + return noBoardQueryArg; + } else { + return { board_id, ...baseQueryArg }; + } + }, [board_id]); + + const { total } = useListImagesQuery(queryArg ?? skipToken, { + selectFromResult: ({ currentData }) => ({ total: currentData?.total }), + }); + + return total; +}; diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts index 6a2e176ffd..3ecef092af 100644 --- a/invokeai/frontend/web/src/services/api/schema.d.ts +++ b/invokeai/frontend/web/src/services/api/schema.d.ts @@ -167,7 +167,7 @@ export type paths = { "/api/v1/images/clear-intermediates": { /** * Clear Intermediates - * @description Clears first 100 intermediates + * @description Clears all intermediates */ post: operations["clear_intermediates"]; }; @@ -800,6 +800,13 @@ export type components = { * @enum {string} */ control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced"; + /** + * Resize Mode + * @description The resize mode to use + * @default just_resize + * @enum {string} + */ + resize_mode?: "just_resize" | "crop_resize" | "fill_resize" | "just_resize_simple"; }; /** * ControlNetInvocation @@ -859,6 +866,13 @@ export type components = { * @enum {string} */ control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced"; + /** + * Resize Mode + * @description The resize mode used + * @default just_resize + * @enum {string} + */ + resize_mode?: "just_resize" | "crop_resize" | "fill_resize" | "just_resize_simple"; }; /** ControlNetModelConfig */ ControlNetModelConfig: { @@ -5324,11 +5338,11 @@ export type components = { image?: components["schemas"]["ImageField"]; }; /** - * StableDiffusion2ModelFormat + * StableDiffusion1ModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; + StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; /** * StableDiffusionXLModelFormat * @description An enumeration. @@ -5336,11 +5350,11 @@ export type components = { */ StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; /** - * StableDiffusion1ModelFormat + * StableDiffusion2ModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; + StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; }; responses: never; parameters: never; @@ -6125,7 +6139,7 @@ export type operations = { }; /** * Clear Intermediates - * @description Clears first 100 intermediates + * @description Clears all intermediates */ clear_intermediates: { responses: { diff --git a/invokeai/frontend/web/src/theme/colors/colors.ts b/invokeai/frontend/web/src/theme/colors/colors.ts index bcb2e43c0b..99260ee071 100644 --- a/invokeai/frontend/web/src/theme/colors/colors.ts +++ b/invokeai/frontend/web/src/theme/colors/colors.ts @@ -2,11 +2,16 @@ import { InvokeAIThemeColors } from 'theme/themeTypes'; import { generateColorPalette } from 'theme/util/generateColorPalette'; const BASE = { H: 220, S: 16 }; -const ACCENT = { H: 250, S: 52 }; -const WORKING = { H: 47, S: 50 }; -const WARNING = { H: 28, S: 50 }; -const OK = { H: 113, S: 50 }; -const ERROR = { H: 0, S: 50 }; +const ACCENT = { H: 250, S: 42 }; +// const ACCENT = { H: 250, S: 52 }; +const WORKING = { H: 47, S: 42 }; +// const WORKING = { H: 47, S: 50 }; +const WARNING = { H: 28, S: 42 }; +// const WARNING = { H: 28, S: 50 }; +const OK = { H: 113, S: 42 }; +// const OK = { H: 113, S: 50 }; +const ERROR = { H: 0, S: 42 }; +// const ERROR = { H: 0, S: 50 }; export const InvokeAIColors: InvokeAIThemeColors = { base: generateColorPalette(BASE.H, BASE.S), diff --git a/invokeai/frontend/web/src/theme/components/editable.ts b/invokeai/frontend/web/src/theme/components/editable.ts new file mode 100644 index 0000000000..19321e5968 --- /dev/null +++ b/invokeai/frontend/web/src/theme/components/editable.ts @@ -0,0 +1,56 @@ +import { editableAnatomy as parts } from '@chakra-ui/anatomy'; +import { + createMultiStyleConfigHelpers, + defineStyle, +} from '@chakra-ui/styled-system'; +import { mode } from '@chakra-ui/theme-tools'; + +const { definePartsStyle, defineMultiStyleConfig } = + createMultiStyleConfigHelpers(parts.keys); + +const baseStylePreview = defineStyle({ + borderRadius: 'md', + py: '1', + transitionProperty: 'common', + transitionDuration: 'normal', +}); + +const baseStyleInput = defineStyle((props) => ({ + borderRadius: 'md', + py: '1', + transitionProperty: 'common', + transitionDuration: 'normal', + width: 'full', + _focusVisible: { boxShadow: 'outline' }, + _placeholder: { opacity: 0.6 }, + '::selection': { + color: mode('accent.900', 'accent.50')(props), + bg: mode('accent.200', 'accent.400')(props), + }, +})); + +const baseStyleTextarea = defineStyle({ + borderRadius: 'md', + py: '1', + transitionProperty: 'common', + transitionDuration: 'normal', + width: 'full', + _focusVisible: { boxShadow: 'outline' }, + _placeholder: { opacity: 0.6 }, +}); + +const invokeAI = definePartsStyle((props) => ({ + preview: baseStylePreview, + input: baseStyleInput(props), + textarea: baseStyleTextarea, +})); + +export const editableTheme = defineMultiStyleConfig({ + variants: { + invokeAI, + }, + defaultProps: { + size: 'sm', + variant: 'invokeAI', + }, +}); diff --git a/invokeai/frontend/web/src/theme/theme.ts b/invokeai/frontend/web/src/theme/theme.ts index 42a5a12c3f..6f7a719e85 100644 --- a/invokeai/frontend/web/src/theme/theme.ts +++ b/invokeai/frontend/web/src/theme/theme.ts @@ -4,6 +4,7 @@ import { InvokeAIColors } from './colors/colors'; import { accordionTheme } from './components/accordion'; import { buttonTheme } from './components/button'; import { checkboxTheme } from './components/checkbox'; +import { editableTheme } from './components/editable'; import { formLabelTheme } from './components/formLabel'; import { inputTheme } from './components/input'; import { menuTheme } from './components/menu'; @@ -72,7 +73,17 @@ export const theme: ThemeOverride = { selected: { light: '0px 0px 0px 1px var(--invokeai-colors-base-150), 0px 0px 0px 4px var(--invokeai-colors-accent-400)', - dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-400)', + dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-500)', + }, + hoverSelected: { + light: + '0px 0px 0px 1px var(--invokeai-colors-base-150), 0px 0px 0px 4px var(--invokeai-colors-accent-500)', + dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-300)', + }, + hoverUnselected: { + light: + '0px 0px 0px 1px var(--invokeai-colors-base-150), 0px 0px 0px 4px var(--invokeai-colors-accent-200)', + dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-600)', }, nodeSelectedOutline: `0 0 0 2px var(--invokeai-colors-accent-450)`, }, @@ -80,6 +91,7 @@ export const theme: ThemeOverride = { components: { Button: buttonTheme, // Button and IconButton Input: inputTheme, + Editable: editableTheme, Textarea: textareaTheme, Tabs: tabsTheme, Progress: progressTheme, diff --git a/invokeai/frontend/web/src/theme/util/getInputOutlineStyles.ts b/invokeai/frontend/web/src/theme/util/getInputOutlineStyles.ts index 8cf64cbd94..ba5fc9e4c1 100644 --- a/invokeai/frontend/web/src/theme/util/getInputOutlineStyles.ts +++ b/invokeai/frontend/web/src/theme/util/getInputOutlineStyles.ts @@ -37,4 +37,7 @@ export const getInputOutlineStyles = (props: StyleFunctionProps) => ({ _placeholder: { color: mode('base.700', 'base.400')(props), }, + '::selection': { + bg: mode('accent.200', 'accent.400')(props), + }, });