Merge branch 'main' into nodepromptsize

This commit is contained in:
mickr777 2023-07-21 08:07:55 +10:00 committed by GitHub
commit 98b2734240
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 1274 additions and 362 deletions

View File

@ -1,9 +1,22 @@
from enum import Enum
from fastapi import Body
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.backend.image_util.patchmatch import PatchMatch from invokeai.backend.image_util.patchmatch import PatchMatch
from invokeai.version import __version__ from invokeai.version import __version__
from ..dependencies import ApiDependencies
from invokeai.backend.util.logging import logging
class LogLevel(int, Enum):
NotSet = logging.NOTSET
Debug = logging.DEBUG
Info = logging.INFO
Warning = logging.WARNING
Error = logging.ERROR
Critical = logging.CRITICAL
app_router = APIRouter(prefix="/v1/app", tags=["app"]) app_router = APIRouter(prefix="/v1/app", tags=["app"])
@ -34,3 +47,27 @@ async def get_config() -> AppConfig:
if PatchMatch.patchmatch_available(): if PatchMatch.patchmatch_available():
infill_methods.append('patchmatch') infill_methods.append('patchmatch')
return AppConfig(infill_methods=infill_methods) return AppConfig(infill_methods=infill_methods)
@app_router.get(
"/logging",
operation_id="get_log_level",
responses={200: {"description" : "The operation was successful"}},
response_model = LogLevel,
)
async def get_log_level(
) -> LogLevel:
"""Returns the log level"""
return LogLevel(ApiDependencies.invoker.services.logger.level)
@app_router.post(
"/logging",
operation_id="set_log_level",
responses={200: {"description" : "The operation was successful"}},
response_model = LogLevel,
)
async def set_log_level(
level: LogLevel = Body(description="New log verbosity level"),
) -> LogLevel:
"""Sets the log verbosity level"""
ApiDependencies.invoker.services.logger.setLevel(level)
return LogLevel(ApiDependencies.invoker.services.logger.level)

View File

@ -85,8 +85,8 @@ CONTROLNET_DEFAULT_MODELS = [
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)] CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
CONTROLNET_MODE_VALUES = Literal[tuple( CONTROLNET_MODE_VALUES = Literal[tuple(
["balanced", "more_prompt", "more_control", "unbalanced"])] ["balanced", "more_prompt", "more_control", "unbalanced"])]
# crop and fill options not ready yet CONTROLNET_RESIZE_VALUES = Literal[tuple(
# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])] ["just_resize", "crop_resize", "fill_resize", "just_resize_simple",])]
class ControlNetModelField(BaseModel): class ControlNetModelField(BaseModel):
@ -111,7 +111,8 @@ class ControlField(BaseModel):
description="When the ControlNet is last applied (% of total steps)") description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field( control_mode: CONTROLNET_MODE_VALUES = Field(
default="balanced", description="The control mode to use") default="balanced", description="The control mode to use")
# resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use") resize_mode: CONTROLNET_RESIZE_VALUES = Field(
default="just_resize", description="The resize mode to use")
@validator("control_weight") @validator("control_weight")
def validate_control_weight(cls, v): def validate_control_weight(cls, v):
@ -161,6 +162,7 @@ class ControlNetInvocation(BaseInvocation):
end_step_percent: float = Field(default=1, ge=0, le=1, end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)") description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used") control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode used")
# fmt: on # fmt: on
class Config(InvocationConfig): class Config(InvocationConfig):
@ -187,6 +189,7 @@ class ControlNetInvocation(BaseInvocation):
begin_step_percent=self.begin_step_percent, begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent, end_step_percent=self.end_step_percent,
control_mode=self.control_mode, control_mode=self.control_mode,
resize_mode=self.resize_mode,
), ),
) )

View File

@ -30,6 +30,7 @@ from .compel import ConditioningField
from .controlnet_image_processors import ControlField from .controlnet_image_processors import ControlField
from .image import ImageOutput from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField from .model import ModelInfo, UNetField, VaeField
from invokeai.app.util.controlnet_utils import prepare_control_image
from diffusers.models.attention_processor import ( from diffusers.models.attention_processor import (
AttnProcessor2_0, AttnProcessor2_0,
@ -288,7 +289,7 @@ class TextToLatentsInvocation(BaseInvocation):
# and add in batch_size, num_images_per_prompt? # and add in batch_size, num_images_per_prompt?
# and do real check for classifier_free_guidance? # and do real check for classifier_free_guidance?
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width) # prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
control_image = model.prepare_control_image( control_image = prepare_control_image(
image=input_image, image=input_image,
do_classifier_free_guidance=do_classifier_free_guidance, do_classifier_free_guidance=do_classifier_free_guidance,
width=control_width_resize, width=control_width_resize,
@ -298,13 +299,18 @@ class TextToLatentsInvocation(BaseInvocation):
device=control_model.device, device=control_model.device,
dtype=control_model.dtype, dtype=control_model.dtype,
control_mode=control_info.control_mode, control_mode=control_info.control_mode,
resize_mode=control_info.resize_mode,
) )
control_item = ControlNetData( control_item = ControlNetData(
model=control_model, image_tensor=control_image, model=control_model,
image_tensor=control_image,
weight=control_info.control_weight, weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent, begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent, end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode, control_mode=control_info.control_mode,
# any resizing needed should currently be happening in prepare_control_image(),
# but adding resize_mode to ControlNetData in case needed in the future
resize_mode=control_info.resize_mode,
) )
control_data.append(control_item) control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData] # MultiControlNetModel has been refactored out, just need list[ControlNetData]
@ -601,7 +607,7 @@ class ResizeLatentsInvocation(BaseInvocation):
antialias: bool = Field( antialias: bool = Field(
default=False, default=False,
description="Whether or not to antialias (applied in bilinear and bicubic modes only)") description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {
@ -647,7 +653,7 @@ class ScaleLatentsInvocation(BaseInvocation):
antialias: bool = Field( antialias: bool = Field(
default=False, default=False,
description="Whether or not to antialias (applied in bilinear and bicubic modes only)") description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {
@ -758,7 +764,7 @@ class ImageToLatentsInvocation(BaseInvocation):
dtype=vae.dtype dtype=vae.dtype
) # FIXME: uses torch.randn. make reproducible! ) # FIXME: uses torch.randn. make reproducible!
latents = 0.18215 * latents latents = vae.config.scaling_factor * latents
latents = latents.to(dtype=orig_dtype) latents = latents.to(dtype=orig_dtype)
name = f"{context.graph_execution_state_id}__{self.id}" name = f"{context.graph_execution_state_id}__{self.id}"

View File

@ -6,6 +6,7 @@ from typing import List, Literal, Optional, Union
from pydantic import Field, validator from pydantic import Field, validator
from ...backend.model_management import ModelType, SubModelType from ...backend.model_management import ModelType, SubModelType
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
from .baseinvocation import (BaseInvocation, BaseInvocationOutput, from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext) InvocationConfig, InvocationContext)
@ -243,10 +244,31 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
}, },
} }
def dispatch_progress(
self,
context: InvocationContext,
source_node_id: str,
sample,
step,
total_steps,
) -> None:
stable_diffusion_xl_step_callback(
context=context,
node=self.dict(),
source_node_id=source_node_id,
sample=sample,
step=step,
total_steps=total_steps,
)
# based on # based on
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375 # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
latents = context.services.latents.get(self.noise.latents_name) latents = context.services.latents.get(self.noise.latents_name)
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name) positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
@ -341,6 +363,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update() progress_bar.update()
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
#if callback is not None and i % callback_steps == 0: #if callback is not None and i % callback_steps == 0:
# callback(i, t, latents) # callback(i, t, latents)
else: else:
@ -409,6 +432,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update() progress_bar.update()
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
#if callback is not None and i % callback_steps == 0: #if callback is not None and i % callback_steps == 0:
# callback(i, t, latents) # callback(i, t, latents)
@ -473,10 +497,31 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
}, },
} }
def dispatch_progress(
self,
context: InvocationContext,
source_node_id: str,
sample,
step,
total_steps,
) -> None:
stable_diffusion_xl_step_callback(
context=context,
node=self.dict(),
source_node_id=source_node_id,
sample=sample,
step=step,
total_steps=total_steps,
)
# based on # based on
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375 # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
latents = context.services.latents.get(self.latents.latents_name) latents = context.services.latents.get(self.latents.latents_name)
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name) positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
@ -579,6 +624,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update() progress_bar.update()
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
#if callback is not None and i % callback_steps == 0: #if callback is not None and i % callback_steps == 0:
# callback(i, t, latents) # callback(i, t, latents)
else: else:
@ -647,6 +693,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update() progress_bar.update()
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
#if callback is not None and i % callback_steps == 0: #if callback is not None and i % callback_steps == 0:
# callback(i, t, latents) # callback(i, t, latents)

View File

@ -0,0 +1,342 @@
import torch
import numpy as np
import cv2
from PIL import Image
from diffusers.utils import PIL_INTERPOLATION
from einops import rearrange
from controlnet_aux.util import HWC3, resize_image
###################################################################
# Copy of scripts/lvminthin.py from Mikubill/sd-webui-controlnet
###################################################################
# High Quality Edge Thinning using Pure Python
# Written by Lvmin Zhangu
# 2023 April
# Stanford University
# If you use this, please Cite "High Quality Edge Thinning using Pure Python", Lvmin Zhang, In Mikubill/sd-webui-controlnet.
lvmin_kernels_raw = [
np.array([
[-1, -1, -1],
[0, 1, 0],
[1, 1, 1]
], dtype=np.int32),
np.array([
[0, -1, -1],
[1, 1, -1],
[0, 1, 0]
], dtype=np.int32)
]
lvmin_kernels = []
lvmin_kernels += [np.rot90(x, k=0, axes=(0, 1)) for x in lvmin_kernels_raw]
lvmin_kernels += [np.rot90(x, k=1, axes=(0, 1)) for x in lvmin_kernels_raw]
lvmin_kernels += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_kernels_raw]
lvmin_kernels += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_kernels_raw]
lvmin_prunings_raw = [
np.array([
[-1, -1, -1],
[-1, 1, -1],
[0, 0, -1]
], dtype=np.int32),
np.array([
[-1, -1, -1],
[-1, 1, -1],
[-1, 0, 0]
], dtype=np.int32)
]
lvmin_prunings = []
lvmin_prunings += [np.rot90(x, k=0, axes=(0, 1)) for x in lvmin_prunings_raw]
lvmin_prunings += [np.rot90(x, k=1, axes=(0, 1)) for x in lvmin_prunings_raw]
lvmin_prunings += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_prunings_raw]
lvmin_prunings += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_prunings_raw]
def remove_pattern(x, kernel):
objects = cv2.morphologyEx(x, cv2.MORPH_HITMISS, kernel)
objects = np.where(objects > 127)
x[objects] = 0
return x, objects[0].shape[0] > 0
def thin_one_time(x, kernels):
y = x
is_done = True
for k in kernels:
y, has_update = remove_pattern(y, k)
if has_update:
is_done = False
return y, is_done
def lvmin_thin(x, prunings=True):
y = x
for i in range(32):
y, is_done = thin_one_time(y, lvmin_kernels)
if is_done:
break
if prunings:
y, _ = thin_one_time(y, lvmin_prunings)
return y
def nake_nms(x):
f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
y = np.zeros_like(x)
for f in [f1, f2, f3, f4]:
np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
return y
################################################################################
# copied from Mikubill/sd-webui-controlnet external_code.py and modified for InvokeAI
################################################################################
# FIXME: not using yet, if used in the future will most likely require modification of preprocessors
def pixel_perfect_resolution(
image: np.ndarray,
target_H: int,
target_W: int,
resize_mode: str,
) -> int:
"""
Calculate the estimated resolution for resizing an image while preserving aspect ratio.
The function first calculates scaling factors for height and width of the image based on the target
height and width. Then, based on the chosen resize mode, it either takes the smaller or the larger
scaling factor to estimate the new resolution.
If the resize mode is OUTER_FIT, the function uses the smaller scaling factor, ensuring the whole image
fits within the target dimensions, potentially leaving some empty space.
If the resize mode is not OUTER_FIT, the function uses the larger scaling factor, ensuring the target
dimensions are fully filled, potentially cropping the image.
After calculating the estimated resolution, the function prints some debugging information.
Args:
image (np.ndarray): A 3D numpy array representing an image. The dimensions represent [height, width, channels].
target_H (int): The target height for the image.
target_W (int): The target width for the image.
resize_mode (ResizeMode): The mode for resizing.
Returns:
int: The estimated resolution after resizing.
"""
raw_H, raw_W, _ = image.shape
k0 = float(target_H) / float(raw_H)
k1 = float(target_W) / float(raw_W)
if resize_mode == "fill_resize":
estimation = min(k0, k1) * float(min(raw_H, raw_W))
else: # "crop_resize" or "just_resize" (or possibly "just_resize_simple"?)
estimation = max(k0, k1) * float(min(raw_H, raw_W))
# print(f"Pixel Perfect Computation:")
# print(f"resize_mode = {resize_mode}")
# print(f"raw_H = {raw_H}")
# print(f"raw_W = {raw_W}")
# print(f"target_H = {target_H}")
# print(f"target_W = {target_W}")
# print(f"estimation = {estimation}")
return int(np.round(estimation))
###########################################################################
# Copied from detectmap_proc method in scripts/detectmap_proc.py in Mikubill/sd-webui-controlnet
# modified for InvokeAI
###########################################################################
# def detectmap_proc(detected_map, module, resize_mode, h, w):
def np_img_resize(
np_img: np.ndarray,
resize_mode: str,
h: int,
w: int,
device: torch.device = torch.device('cpu')
):
# if 'inpaint' in module:
# np_img = np_img.astype(np.float32)
# else:
# np_img = HWC3(np_img)
np_img = HWC3(np_img)
def safe_numpy(x):
# A very safe method to make sure that Apple/Mac works
y = x
# below is very boring but do not change these. If you change these Apple or Mac may fail.
y = y.copy()
y = np.ascontiguousarray(y)
y = y.copy()
return y
def get_pytorch_control(x):
# A very safe method to make sure that Apple/Mac works
y = x
# below is very boring but do not change these. If you change these Apple or Mac may fail.
y = torch.from_numpy(y)
y = y.float() / 255.0
y = rearrange(y, 'h w c -> 1 c h w')
y = y.clone()
# y = y.to(devices.get_device_for("controlnet"))
y = y.to(device)
y = y.clone()
return y
def high_quality_resize(x: np.ndarray,
size):
# Written by lvmin
# Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges
inpaint_mask = None
if x.ndim == 3 and x.shape[2] == 4:
inpaint_mask = x[:, :, 3]
x = x[:, :, 0:3]
new_size_is_smaller = (size[0] * size[1]) < (x.shape[0] * x.shape[1])
new_size_is_bigger = (size[0] * size[1]) > (x.shape[0] * x.shape[1])
unique_color_count = np.unique(x.reshape(-1, x.shape[2]), axis=0).shape[0]
is_one_pixel_edge = False
is_binary = False
if unique_color_count == 2:
is_binary = np.min(x) < 16 and np.max(x) > 240
if is_binary:
xc = x
xc = cv2.erode(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
xc = cv2.dilate(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
one_pixel_edge_count = np.where(xc < x)[0].shape[0]
all_edge_count = np.where(x > 127)[0].shape[0]
is_one_pixel_edge = one_pixel_edge_count * 2 > all_edge_count
if 2 < unique_color_count < 200:
interpolation = cv2.INTER_NEAREST
elif new_size_is_smaller:
interpolation = cv2.INTER_AREA
else:
interpolation = cv2.INTER_CUBIC # Must be CUBIC because we now use nms. NEVER CHANGE THIS
y = cv2.resize(x, size, interpolation=interpolation)
if inpaint_mask is not None:
inpaint_mask = cv2.resize(inpaint_mask, size, interpolation=interpolation)
if is_binary:
y = np.mean(y.astype(np.float32), axis=2).clip(0, 255).astype(np.uint8)
if is_one_pixel_edge:
y = nake_nms(y)
_, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
y = lvmin_thin(y, prunings=new_size_is_bigger)
else:
_, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
y = np.stack([y] * 3, axis=2)
if inpaint_mask is not None:
inpaint_mask = (inpaint_mask > 127).astype(np.float32) * 255.0
inpaint_mask = inpaint_mask[:, :, None].clip(0, 255).astype(np.uint8)
y = np.concatenate([y, inpaint_mask], axis=2)
return y
# if resize_mode == external_code.ResizeMode.RESIZE:
if resize_mode == "just_resize": # RESIZE
np_img = high_quality_resize(np_img, (w, h))
np_img = safe_numpy(np_img)
return get_pytorch_control(np_img), np_img
old_h, old_w, _ = np_img.shape
old_w = float(old_w)
old_h = float(old_h)
k0 = float(h) / old_h
k1 = float(w) / old_w
safeint = lambda x: int(np.round(x))
# if resize_mode == external_code.ResizeMode.OUTER_FIT:
if resize_mode == "fill_resize": # OUTER_FIT
k = min(k0, k1)
borders = np.concatenate([np_img[0, :, :], np_img[-1, :, :], np_img[:, 0, :], np_img[:, -1, :]], axis=0)
high_quality_border_color = np.median(borders, axis=0).astype(np_img.dtype)
if len(high_quality_border_color) == 4:
# Inpaint hijack
high_quality_border_color[3] = 255
high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1])
np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
new_h, new_w, _ = np_img.shape
pad_h = max(0, (h - new_h) // 2)
pad_w = max(0, (w - new_w) // 2)
high_quality_background[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = np_img
np_img = high_quality_background
np_img = safe_numpy(np_img)
return get_pytorch_control(np_img), np_img
else: # resize_mode == "crop_resize" (INNER_FIT)
k = max(k0, k1)
np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
new_h, new_w, _ = np_img.shape
pad_h = max(0, (new_h - h) // 2)
pad_w = max(0, (new_w - w) // 2)
np_img = np_img[pad_h:pad_h + h, pad_w:pad_w + w]
np_img = safe_numpy(np_img)
return get_pytorch_control(np_img), np_img
def prepare_control_image(
# image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]]
# but now should be able to assume that image is a single PIL.Image, which simplifies things
image: Image,
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
# latents_to_match_resolution, # TorchTensor of shape (batch_size, 3, height, width)
width=512, # should be 8 * latent.shape[3]
height=512, # should be 8 * latent height[2]
# batch_size=1, # currently no batching
# num_images_per_prompt=1, # currently only single image
device="cuda",
dtype=torch.float16,
do_classifier_free_guidance=True,
control_mode="balanced",
resize_mode="just_resize_simple",
):
# FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out
if (resize_mode == "just_resize_simple" or
resize_mode == "crop_resize_simple" or
resize_mode == "fill_resize_simple"):
image = image.convert("RGB")
if (resize_mode == "just_resize_simple"):
image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
elif (resize_mode == "crop_resize_simple"): # not yet implemented
pass
elif (resize_mode == "fill_resize_simple"): # not yet implemented
pass
nimage = np.array(image)
nimage = nimage[None, :]
nimage = np.concatenate([nimage], axis=0)
# normalizing RGB values to [0,1] range (in PIL.Image they are [0-255])
nimage = np.array(nimage).astype(np.float32) / 255.0
nimage = nimage.transpose(0, 3, 1, 2)
timage = torch.from_numpy(nimage)
# use fancy lvmin controlnet resizing
elif (resize_mode == "just_resize" or resize_mode == "crop_resize" or resize_mode == "fill_resize"):
nimage = np.array(image)
timage, nimage = np_img_resize(
np_img=nimage,
resize_mode=resize_mode,
h=height,
w=width,
# device=torch.device('cpu')
device=device,
)
else:
pass
print("ERROR: invalid resize_mode ==> ", resize_mode)
exit(1)
timage = timage.to(device=device, dtype=dtype)
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
if do_classifier_free_guidance and not cfg_injection:
timage = torch.cat([timage] * 2)
return timage

View File

@ -1,9 +1,30 @@
import torch
from PIL import Image
from invokeai.app.models.exceptions import CanceledException from invokeai.app.models.exceptions import CanceledException
from invokeai.app.models.image import ProgressImage from invokeai.app.models.image import ProgressImage
from ..invocations.baseinvocation import InvocationContext from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL from ...backend.util.util import image_to_dataURL
from ...backend.generator.base import Generator from ...backend.generator.base import Generator
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from invokeai.app.services.config import InvokeAIAppConfig
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix = None):
latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors
if smooth_matrix is not None:
latent_image = latent_image.unsqueeze(0).permute(3, 0, 1, 2)
latent_image = torch.nn.functional.conv2d(latent_image, smooth_matrix.reshape((1,1,3,3)), padding=1)
latent_image = latent_image.permute(1, 2, 3, 0).squeeze(0)
latents_ubyte = (
((latent_image + 1) / 2)
.clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
.byte()
).cpu()
return Image.fromarray(latents_ubyte.numpy())
def stable_diffusion_step_callback( def stable_diffusion_step_callback(
@ -37,7 +58,24 @@ def stable_diffusion_step_callback(
# step = intermediate_state.step # step = intermediate_state.step
# TODO: only output a preview image when requested # TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample)
# origingally adapted from code by @erucipe and @keturn here:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
# these updated numbers for v1.5 are from @torridgristle
v1_5_latent_rgb_factors = torch.tensor(
[
# R G B
[0.3444, 0.1385, 0.0670], # L1
[0.1247, 0.4027, 0.1494], # L2
[-0.3192, 0.2513, 0.2103], # L3
[-0.1307, -0.1874, -0.7445], # L4
],
dtype=sample.dtype,
device=sample.device,
)
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
(width, height) = image.size (width, height) = image.size
width *= 8 width *= 8
@ -53,3 +91,56 @@ def stable_diffusion_step_callback(
step=intermediate_state.step, step=intermediate_state.step,
total_steps=node["steps"], total_steps=node["steps"],
) )
def stable_diffusion_xl_step_callback(
context: InvocationContext,
node: dict,
source_node_id: str,
sample,
step,
total_steps,
):
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException
sdxl_latent_rgb_factors = torch.tensor(
[
# R G B
[ 0.3816, 0.4930, 0.5320],
[-0.3753, 0.1631, 0.1739],
[ 0.1770, 0.3588, -0.2048],
[-0.4350, -0.2644, -0.4289],
],
dtype=sample.dtype,
device=sample.device,
)
sdxl_smooth_matrix = torch.tensor(
[
#[ 0.0478, 0.1285, 0.0478],
#[ 0.1285, 0.2948, 0.1285],
#[ 0.0478, 0.1285, 0.0478],
[0.0358, 0.0964, 0.0358],
[0.0964, 0.4711, 0.0964],
[0.0358, 0.0964, 0.0358],
],
dtype=sample.dtype,
device=sample.device,
)
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
(width, height) = image.size
width *= 8
height *= 8
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
graph_execution_state_id=context.graph_execution_state_id,
node=node,
source_node_id=source_node_id,
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
step=step,
total_steps=total_steps,
)

View File

@ -219,6 +219,7 @@ class ControlNetData:
begin_step_percent: float = Field(default=0.0) begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0) end_step_percent: float = Field(default=1.0)
control_mode: str = Field(default="balanced") control_mode: str = Field(default="balanced")
resize_mode: str = Field(default="just_resize")
@dataclass @dataclass
@ -653,7 +654,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if cfg_injection: if cfg_injection:
# Inferred ControlNet only for the conditional batch. # Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches, # To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged. # prepend zeros for unconditional batch
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples] down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample]) mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
@ -954,53 +955,3 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
debug_image( debug_image(
img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
) )
# Copied from diffusers pipeline_stable_diffusion_controlnet.py
# Returns torch.Tensor of shape (batch_size, 3, height, width)
@staticmethod
def prepare_control_image(
image,
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
# latents,
width=512, # should be 8 * latent.shape[3]
height=512, # should be 8 * latent height[2]
batch_size=1,
num_images_per_prompt=1,
device="cuda",
dtype=torch.float16,
do_classifier_free_guidance=True,
control_mode="balanced"
):
if not isinstance(image, torch.Tensor):
if isinstance(image, PIL.Image.Image):
image = [image]
if isinstance(image[0], PIL.Image.Image):
images = []
for image_ in image:
image_ = image_.convert("RGB")
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
image_ = np.array(image_)
image_ = image_[None, :]
images.append(image_)
image = images
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
if do_classifier_free_guidance and not cfg_injection:
image = torch.cat([image] * 2)
return image

View File

@ -17,13 +17,13 @@ import {
} from 'common/components/IAIImageFallback'; } from 'common/components/IAIImageFallback';
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay'; import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton'; import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import { MouseEvent, ReactElement, SyntheticEvent, memo } from 'react'; import { MouseEvent, ReactElement, SyntheticEvent, memo } from 'react';
import { FaImage, FaUndo, FaUpload } from 'react-icons/fa'; import { FaImage, FaUndo, FaUpload } from 'react-icons/fa';
import { ImageDTO, PostUploadAction } from 'services/api/types'; import { ImageDTO, PostUploadAction } from 'services/api/types';
import { mode } from 'theme/util/mode'; import { mode } from 'theme/util/mode';
import IAIDraggable from './IAIDraggable'; import IAIDraggable from './IAIDraggable';
import IAIDroppable from './IAIDroppable'; import IAIDroppable from './IAIDroppable';
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
type IAIDndImageProps = { type IAIDndImageProps = {
imageDTO: ImageDTO | undefined; imageDTO: ImageDTO | undefined;
@ -148,7 +148,9 @@ const IAIDndImage = (props: IAIDndImageProps) => {
maxH: 'full', maxH: 'full',
borderRadius: 'base', borderRadius: 'base',
shadow: isSelected ? 'selected.light' : undefined, shadow: isSelected ? 'selected.light' : undefined,
_dark: { shadow: isSelected ? 'selected.dark' : undefined }, _dark: {
shadow: isSelected ? 'selected.dark' : undefined,
},
...imageSx, ...imageSx,
}} }}
/> />
@ -183,13 +185,6 @@ const IAIDndImage = (props: IAIDndImageProps) => {
</> </>
)} )}
{!imageDTO && isUploadDisabled && noContentFallback} {!imageDTO && isUploadDisabled && noContentFallback}
{!isDropDisabled && (
<IAIDroppable
data={droppableData}
disabled={isDropDisabled}
dropLabel={dropLabel}
/>
)}
{imageDTO && !isDragDisabled && ( {imageDTO && !isDragDisabled && (
<IAIDraggable <IAIDraggable
data={draggableData} data={draggableData}
@ -197,6 +192,13 @@ const IAIDndImage = (props: IAIDndImageProps) => {
onClick={onClick} onClick={onClick}
/> />
)} )}
{!isDropDisabled && (
<IAIDroppable
data={droppableData}
disabled={isDropDisabled}
dropLabel={dropLabel}
/>
)}
{onClickReset && withResetIcon && imageDTO && ( {onClickReset && withResetIcon && imageDTO && (
<IAIIconButton <IAIIconButton
onClick={onClickReset} onClick={onClickReset}

View File

@ -13,10 +13,11 @@ type IAIDroppableProps = {
dropLabel?: ReactNode; dropLabel?: ReactNode;
disabled?: boolean; disabled?: boolean;
data?: TypesafeDroppableData; data?: TypesafeDroppableData;
hoverRef?: React.Ref<HTMLDivElement>;
}; };
const IAIDroppable = (props: IAIDroppableProps) => { const IAIDroppable = (props: IAIDroppableProps) => {
const { dropLabel, data, disabled } = props; const { dropLabel, data, disabled, hoverRef } = props;
const dndId = useRef(uuidv4()); const dndId = useRef(uuidv4());
const { isOver, setNodeRef, active } = useDroppable({ const { isOver, setNodeRef, active } = useDroppable({

View File

@ -24,6 +24,7 @@ import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd'; import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode'; import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect'; import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
import ParamControlNetResizeMode from './parameters/ParamControlNetResizeMode';
type ControlNetProps = { type ControlNetProps = {
controlNetId: string; controlNetId: string;
@ -68,7 +69,7 @@ const ControlNet = (props: ControlNetProps) => {
<Flex <Flex
sx={{ sx={{
flexDir: 'column', flexDir: 'column',
gap: 2, gap: 3,
p: 3, p: 3,
borderRadius: 'base', borderRadius: 'base',
position: 'relative', position: 'relative',
@ -117,7 +118,12 @@ const ControlNet = (props: ControlNetProps) => {
tooltip={isExpanded ? 'Hide Advanced' : 'Show Advanced'} tooltip={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
aria-label={isExpanded ? 'Hide Advanced' : 'Show Advanced'} aria-label={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
onClick={toggleIsExpanded} onClick={toggleIsExpanded}
variant="link" variant="ghost"
sx={{
_hover: {
bg: 'none',
},
}}
icon={ icon={
<ChevronUpIcon <ChevronUpIcon
sx={{ sx={{
@ -151,7 +157,7 @@ const ControlNet = (props: ControlNetProps) => {
/> />
)} )}
</Flex> </Flex>
<Flex sx={{ w: 'full', flexDirection: 'column' }}> <Flex sx={{ w: 'full', flexDirection: 'column', gap: 3 }}>
<Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}> <Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}>
<Flex <Flex
sx={{ sx={{
@ -176,16 +182,16 @@ const ControlNet = (props: ControlNetProps) => {
h: 28, h: 28,
w: 28, w: 28,
aspectRatio: '1/1', aspectRatio: '1/1',
mt: 3,
}} }}
> >
<ControlNetImagePreview controlNetId={controlNetId} height={28} /> <ControlNetImagePreview controlNetId={controlNetId} height={28} />
</Flex> </Flex>
)} )}
</Flex> </Flex>
<Box mt={2}> <Flex sx={{ gap: 2 }}>
<ParamControlNetControlMode controlNetId={controlNetId} /> <ParamControlNetControlMode controlNetId={controlNetId} />
</Box> <ParamControlNetResizeMode controlNetId={controlNetId} />
</Flex>
<ParamControlNetProcessorSelect controlNetId={controlNetId} /> <ParamControlNetProcessorSelect controlNetId={controlNetId} />
</Flex> </Flex>

View File

@ -0,0 +1,62 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
ResizeModes,
controlNetResizeModeChanged,
} from 'features/controlNet/store/controlNetSlice';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
type ParamControlNetResizeModeProps = {
controlNetId: string;
};
const RESIZE_MODE_DATA = [
{ label: 'Resize', value: 'just_resize' },
{ label: 'Crop', value: 'crop_resize' },
{ label: 'Fill', value: 'fill_resize' },
];
export default function ParamControlNetResizeMode(
props: ParamControlNetResizeModeProps
) {
const { controlNetId } = props;
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { resizeMode, isEnabled } =
controlNet.controlNets[controlNetId];
return { resizeMode, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { resizeMode, isEnabled } = useAppSelector(selector);
const { t } = useTranslation();
const handleResizeModeChange = useCallback(
(resizeMode: ResizeModes) => {
dispatch(controlNetResizeModeChanged({ controlNetId, resizeMode }));
},
[controlNetId, dispatch]
);
return (
<IAIMantineSelect
disabled={!isEnabled}
label="Resize Mode"
data={RESIZE_MODE_DATA}
value={String(resizeMode)}
onChange={handleResizeModeChange}
/>
);
}

View File

@ -3,6 +3,7 @@ import { RootState } from 'app/store/store';
import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas'; import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas';
import { cloneDeep, forEach } from 'lodash-es'; import { cloneDeep, forEach } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
import { components } from 'services/api/schema';
import { isAnySessionRejected } from 'services/api/thunks/session'; import { isAnySessionRejected } from 'services/api/thunks/session';
import { appSocketInvocationError } from 'services/events/actions'; import { appSocketInvocationError } from 'services/events/actions';
import { controlNetImageProcessed } from './actions'; import { controlNetImageProcessed } from './actions';
@ -16,11 +17,13 @@ import {
RequiredControlNetProcessorNode, RequiredControlNetProcessorNode,
} from './types'; } from './types';
export type ControlModes = export type ControlModes = NonNullable<
| 'balanced' components['schemas']['ControlNetInvocation']['control_mode']
| 'more_prompt' >;
| 'more_control'
| 'unbalanced'; export type ResizeModes = NonNullable<
components['schemas']['ControlNetInvocation']['resize_mode']
>;
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = { export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true, isEnabled: true,
@ -29,6 +32,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
beginStepPct: 0, beginStepPct: 0,
endStepPct: 1, endStepPct: 1,
controlMode: 'balanced', controlMode: 'balanced',
resizeMode: 'just_resize',
controlImage: null, controlImage: null,
processedControlImage: null, processedControlImage: null,
processorType: 'canny_image_processor', processorType: 'canny_image_processor',
@ -45,6 +49,7 @@ export type ControlNetConfig = {
beginStepPct: number; beginStepPct: number;
endStepPct: number; endStepPct: number;
controlMode: ControlModes; controlMode: ControlModes;
resizeMode: ResizeModes;
controlImage: string | null; controlImage: string | null;
processedControlImage: string | null; processedControlImage: string | null;
processorType: ControlNetProcessorType; processorType: ControlNetProcessorType;
@ -215,6 +220,16 @@ export const controlNetSlice = createSlice({
const { controlNetId, controlMode } = action.payload; const { controlNetId, controlMode } = action.payload;
state.controlNets[controlNetId].controlMode = controlMode; state.controlNets[controlNetId].controlMode = controlMode;
}, },
controlNetResizeModeChanged: (
state,
action: PayloadAction<{
controlNetId: string;
resizeMode: ResizeModes;
}>
) => {
const { controlNetId, resizeMode } = action.payload;
state.controlNets[controlNetId].resizeMode = resizeMode;
},
controlNetProcessorParamsChanged: ( controlNetProcessorParamsChanged: (
state, state,
action: PayloadAction<{ action: PayloadAction<{
@ -342,6 +357,7 @@ export const {
controlNetBeginStepPctChanged, controlNetBeginStepPctChanged,
controlNetEndStepPctChanged, controlNetEndStepPctChanged,
controlNetControlModeChanged, controlNetControlModeChanged,
controlNetResizeModeChanged,
controlNetProcessorParamsChanged, controlNetProcessorParamsChanged,
controlNetProcessorTypeChanged, controlNetProcessorTypeChanged,
controlNetReset, controlNetReset,

View File

@ -23,34 +23,32 @@ const BoardContextMenu = memo(
dispatch(boardIdSelected(board?.board_id ?? board_id)); dispatch(boardIdSelected(board?.board_id ?? board_id));
}, [board?.board_id, board_id, dispatch]); }, [board?.board_id, board_id, dispatch]);
return ( return (
<Box sx={{ touchAction: 'none', height: 'full' }}> <ContextMenu<HTMLDivElement>
<ContextMenu<HTMLDivElement> menuProps={{ size: 'sm', isLazy: true }}
menuProps={{ size: 'sm', isLazy: true }} menuButtonProps={{
menuButtonProps={{ bg: 'transparent',
bg: 'transparent', _hover: { bg: 'transparent' },
_hover: { bg: 'transparent' }, }}
}} renderMenu={() => (
renderMenu={() => ( <MenuList
<MenuList sx={{ visibility: 'visible !important' }}
sx={{ visibility: 'visible !important' }} motionProps={menuListMotionProps}
motionProps={menuListMotionProps} >
> <MenuItem icon={<FaFolder />} onClickCapture={handleSelectBoard}>
<MenuItem icon={<FaFolder />} onClickCapture={handleSelectBoard}> Select Board
Select Board </MenuItem>
</MenuItem> {!board && <SystemBoardContextMenuItems board_id={board_id} />}
{!board && <SystemBoardContextMenuItems board_id={board_id} />} {board && (
{board && ( <GalleryBoardContextMenuItems
<GalleryBoardContextMenuItems board={board}
board={board} setBoardToDelete={setBoardToDelete}
setBoardToDelete={setBoardToDelete} />
/> )}
)} </MenuList>
</MenuList> )}
)} >
> {children}
{children} </ContextMenu>
</ContextMenu>
</Box>
); );
} }
); );

View File

@ -1,27 +1,21 @@
import { import { ButtonGroup, Collapse, Flex, Grid, GridItem } from '@chakra-ui/react';
Collapse,
Flex,
Grid,
GridItem,
useDisclosure,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIIconButton from 'common/components/IAIIconButton';
import { AnimatePresence, motion } from 'framer-motion';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useState } from 'react'; import { memo, useCallback, useState } from 'react';
import { FaSearch } from 'react-icons/fa';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
import { BoardDTO } from 'services/api/types';
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus'; import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
import DeleteBoardModal from '../DeleteBoardModal';
import AddBoardButton from './AddBoardButton'; import AddBoardButton from './AddBoardButton';
import AllAssetsBoard from './AllAssetsBoard';
import AllImagesBoard from './AllImagesBoard';
import BatchBoard from './BatchBoard';
import BoardsSearch from './BoardsSearch'; import BoardsSearch from './BoardsSearch';
import GalleryBoard from './GalleryBoard'; import GalleryBoard from './GalleryBoard';
import NoBoardBoard from './NoBoardBoard'; import SystemBoardButton from './SystemBoardButton';
import DeleteBoardModal from '../DeleteBoardModal';
import { BoardDTO } from 'services/api/types';
const selector = createSelector( const selector = createSelector(
[stateSelector], [stateSelector],
@ -48,7 +42,10 @@ const BoardsList = (props: Props) => {
) )
: boards; : boards;
const [boardToDelete, setBoardToDelete] = useState<BoardDTO>(); const [boardToDelete, setBoardToDelete] = useState<BoardDTO>();
const [searchMode, setSearchMode] = useState(false); const [isSearching, setIsSearching] = useState(false);
const handleClickSearchIcon = useCallback(() => {
setIsSearching((v) => !v);
}, []);
return ( return (
<> <>
@ -64,7 +61,54 @@ const BoardsList = (props: Props) => {
}} }}
> >
<Flex sx={{ gap: 2, alignItems: 'center' }}> <Flex sx={{ gap: 2, alignItems: 'center' }}>
<BoardsSearch setSearchMode={setSearchMode} /> <AnimatePresence mode="popLayout">
{isSearching ? (
<motion.div
key="boards-search"
initial={{
opacity: 0,
}}
exit={{
opacity: 0,
}}
animate={{
opacity: 1,
transition: { duration: 0.1 },
}}
style={{ width: '100%' }}
>
<BoardsSearch setIsSearching={setIsSearching} />
</motion.div>
) : (
<motion.div
key="system-boards-select"
initial={{
opacity: 0,
}}
exit={{
opacity: 0,
}}
animate={{
opacity: 1,
transition: { duration: 0.1 },
}}
style={{ width: '100%' }}
>
<ButtonGroup sx={{ w: 'full', ps: 1.5 }} isAttached>
<SystemBoardButton board_id="images" />
<SystemBoardButton board_id="assets" />
<SystemBoardButton board_id="no_board" />
</ButtonGroup>
</motion.div>
)}
</AnimatePresence>
<IAIIconButton
aria-label="Search Boards"
size="sm"
isChecked={isSearching}
onClick={handleClickSearchIcon}
icon={<FaSearch />}
/>
<AddBoardButton /> <AddBoardButton />
</Flex> </Flex>
<OverlayScrollbarsComponent <OverlayScrollbarsComponent
@ -82,29 +126,10 @@ const BoardsList = (props: Props) => {
<Grid <Grid
className="list-container" className="list-container"
sx={{ sx={{
gridTemplateRows: '6.5rem 6.5rem', gridTemplateColumns: `repeat(auto-fill, minmax(96px, 1fr));`,
gridAutoFlow: 'column dense', maxH: 346,
gridAutoColumns: '5rem',
}} }}
> >
{!searchMode && (
<>
<GridItem sx={{ p: 1.5 }}>
<AllImagesBoard isSelected={selectedBoardId === 'images'} />
</GridItem>
<GridItem sx={{ p: 1.5 }}>
<AllAssetsBoard isSelected={selectedBoardId === 'assets'} />
</GridItem>
<GridItem sx={{ p: 1.5 }}>
<NoBoardBoard isSelected={selectedBoardId === 'no_board'} />
</GridItem>
{isBatchEnabled && (
<GridItem sx={{ p: 1.5 }}>
<BatchBoard isSelected={selectedBoardId === 'batch'} />
</GridItem>
)}
</>
)}
{filteredBoards && {filteredBoards &&
filteredBoards.map((board) => ( filteredBoards.map((board) => (
<GridItem key={board.board_id} sx={{ p: 1.5 }}> <GridItem key={board.board_id} sx={{ p: 1.5 }}>

View File

@ -10,7 +10,14 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { setBoardSearchText } from 'features/gallery/store/boardSlice'; import { setBoardSearchText } from 'features/gallery/store/boardSlice';
import { memo } from 'react'; import {
ChangeEvent,
KeyboardEvent,
memo,
useCallback,
useEffect,
useRef,
} from 'react';
const selector = createSelector( const selector = createSelector(
[stateSelector], [stateSelector],
@ -22,31 +29,60 @@ const selector = createSelector(
); );
type Props = { type Props = {
setSearchMode: (searchMode: boolean) => void; setIsSearching: (isSearching: boolean) => void;
}; };
const BoardsSearch = (props: Props) => { const BoardsSearch = (props: Props) => {
const { setSearchMode } = props; const { setIsSearching } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { searchText } = useAppSelector(selector); const { searchText } = useAppSelector(selector);
const inputRef = useRef<HTMLInputElement>(null);
const handleBoardSearch = (searchTerm: string) => { const handleBoardSearch = useCallback(
setSearchMode(searchTerm.length > 0); (searchTerm: string) => {
dispatch(setBoardSearchText(searchTerm)); dispatch(setBoardSearchText(searchTerm));
}; },
const clearBoardSearch = () => { [dispatch]
setSearchMode(false); );
const clearBoardSearch = useCallback(() => {
dispatch(setBoardSearchText('')); dispatch(setBoardSearchText(''));
}; setIsSearching(false);
}, [dispatch, setIsSearching]);
const handleKeydown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
// exit search mode on escape
if (e.key === 'Escape') {
clearBoardSearch();
}
},
[clearBoardSearch]
);
const handleChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
handleBoardSearch(e.target.value);
},
[handleBoardSearch]
);
useEffect(() => {
// focus the search box on mount
if (!inputRef.current) {
return;
}
inputRef.current.focus();
}, []);
return ( return (
<InputGroup> <InputGroup>
<Input <Input
ref={inputRef}
placeholder="Search Boards..." placeholder="Search Boards..."
value={searchText} value={searchText}
onChange={(e) => { onKeyDown={handleKeydown}
handleBoardSearch(e.target.value); onChange={handleChange}
}}
/> />
{searchText && searchText.length && ( {searchText && searchText.length && (
<InputRightElement> <InputRightElement>
@ -55,7 +91,8 @@ const BoardsSearch = (props: Props) => {
size="xs" size="xs"
variant="ghost" variant="ghost"
aria-label="Clear Search" aria-label="Clear Search"
icon={<CloseIcon boxSize={3} />} opacity={0.5}
icon={<CloseIcon boxSize={2} />}
/> />
</InputRightElement> </InputRightElement>
)} )}

View File

@ -6,9 +6,9 @@ import {
EditableInput, EditableInput,
EditablePreview, EditablePreview,
Flex, Flex,
Icon,
Image, Image,
Text, Text,
useColorMode,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/dist/query'; import { skipToken } from '@reduxjs/toolkit/dist/query';
@ -17,14 +17,12 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDroppable from 'common/components/IAIDroppable'; import IAIDroppable from 'common/components/IAIDroppable';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { boardIdSelected } from 'features/gallery/store/gallerySlice'; import { boardIdSelected } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo, useState } from 'react';
import { FaUser } from 'react-icons/fa'; import { FaFolder } from 'react-icons/fa';
import { useUpdateBoardMutation } from 'services/api/endpoints/boards'; import { useUpdateBoardMutation } from 'services/api/endpoints/boards';
import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { BoardDTO } from 'services/api/types'; import { BoardDTO } from 'services/api/types';
import { mode } from 'theme/util/mode';
import BoardContextMenu from '../BoardContextMenu'; import BoardContextMenu from '../BoardContextMenu';
const AUTO_ADD_BADGE_STYLES: ChakraProps['sx'] = { const AUTO_ADD_BADGE_STYLES: ChakraProps['sx'] = {
@ -66,8 +64,9 @@ const GalleryBoard = memo(
board.cover_image_name ?? skipToken board.cover_image_name ?? skipToken
); );
const { colorMode } = useColorMode();
const { board_name, board_id } = board; const { board_name, board_id } = board;
const [localBoardName, setLocalBoardName] = useState(board_name);
const handleSelectBoard = useCallback(() => { const handleSelectBoard = useCallback(() => {
dispatch(boardIdSelected(board_id)); dispatch(boardIdSelected(board_id));
}, [board_id, dispatch]); }, [board_id, dispatch]);
@ -75,10 +74,6 @@ const GalleryBoard = memo(
const [updateBoard, { isLoading: isUpdateBoardLoading }] = const [updateBoard, { isLoading: isUpdateBoardLoading }] =
useUpdateBoardMutation(); useUpdateBoardMutation();
const handleUpdateBoardName = (newBoardName: string) => {
updateBoard({ board_id, changes: { board_name: newBoardName } });
};
const droppableData: MoveBoardDropData = useMemo( const droppableData: MoveBoardDropData = useMemo(
() => ({ () => ({
id: board_id, id: board_id,
@ -88,59 +83,116 @@ const GalleryBoard = memo(
[board_id] [board_id]
); );
const handleSubmit = useCallback(
(newBoardName: string) => {
if (!newBoardName) {
// empty strings are not allowed
setLocalBoardName(board_name);
return;
}
if (newBoardName === board_name) {
// don't updated the board name if it hasn't changed
return;
}
updateBoard({ board_id, changes: { board_name: newBoardName } })
.unwrap()
.then((response) => {
// update local state
setLocalBoardName(response.board_name);
})
.catch(() => {
// revert on error
setLocalBoardName(board_name);
});
},
[board_id, board_name, updateBoard]
);
const handleChange = useCallback((newBoardName: string) => {
setLocalBoardName(newBoardName);
}, []);
return ( return (
<Box sx={{ touchAction: 'none', height: 'full' }}> <Box
<BoardContextMenu sx={{ w: 'full', h: 'full', touchAction: 'none', userSelect: 'none' }}
board={board} >
board_id={board_id} <Flex
setBoardToDelete={setBoardToDelete} sx={{
position: 'relative',
justifyContent: 'center',
alignItems: 'center',
aspectRatio: '1/1',
w: 'full',
h: 'full',
}}
> >
{(ref) => ( <BoardContextMenu
<Flex board={board}
key={board_id} board_id={board_id}
userSelect="none" setBoardToDelete={setBoardToDelete}
ref={ref} >
sx={{ {(ref) => (
flexDir: 'column',
justifyContent: 'space-between',
alignItems: 'center',
cursor: 'pointer',
w: 'full',
h: 'full',
}}
>
<Flex <Flex
ref={ref}
onClick={handleSelectBoard} onClick={handleSelectBoard}
sx={{ sx={{
w: 'full',
h: 'full',
position: 'relative', position: 'relative',
justifyContent: 'center', justifyContent: 'center',
alignItems: 'center', alignItems: 'center',
borderRadius: 'base', borderRadius: 'base',
w: 'full', cursor: 'pointer',
aspectRatio: '1/1',
overflow: 'hidden',
shadow: isSelected ? 'selected.light' : undefined,
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
flexShrink: 0,
}} }}
> >
{board.cover_image_name && coverImage?.thumbnail_url && ( <Flex
<Image src={coverImage?.thumbnail_url} draggable={false} /> sx={{
)} w: 'full',
{!(board.cover_image_name && coverImage?.thumbnail_url) && ( h: 'full',
<IAINoContentFallback justifyContent: 'center',
boxSize={8} alignItems: 'center',
icon={FaUser} borderRadius: 'base',
sx={{ bg: 'base.200',
borderWidth: '2px', _dark: {
borderStyle: 'solid', bg: 'base.800',
borderColor: 'base.200', },
_dark: { }}
borderColor: 'base.800', >
}, {coverImage?.thumbnail_url ? (
}} <Image
/> src={coverImage?.thumbnail_url}
)} draggable={false}
sx={{
maxW: 'full',
maxH: 'full',
borderRadius: 'base',
borderBottomRadius: 'lg',
}}
/>
) : (
<Flex
sx={{
w: 'full',
h: 'full',
justifyContent: 'center',
alignItems: 'center',
}}
>
<Icon
boxSize={12}
as={FaFolder}
sx={{
mt: -3,
opacity: 0.7,
color: 'base.500',
_dark: {
color: 'base.500',
},
}}
/>
</Flex>
)}
</Flex>
<Flex <Flex
sx={{ sx={{
position: 'absolute', position: 'absolute',
@ -160,56 +212,86 @@ const GalleryBoard = memo(
{board.image_count} {board.image_count}
</Badge> </Badge>
</Flex> </Flex>
<Box
className="selection-box"
sx={{
position: 'absolute',
top: 0,
insetInlineEnd: 0,
bottom: 0,
insetInlineStart: 0,
borderRadius: 'base',
transitionProperty: 'common',
transitionDuration: 'common',
shadow: isSelected ? 'selected.light' : undefined,
_dark: {
shadow: isSelected ? 'selected.dark' : undefined,
},
}}
/>
<Flex
sx={{
position: 'absolute',
bottom: 0,
left: 0,
p: 1,
justifyContent: 'center',
alignItems: 'center',
w: 'full',
maxW: 'full',
borderBottomRadius: 'base',
bg: isSelected ? 'accent.400' : 'base.500',
color: isSelected ? 'base.50' : 'base.100',
_dark: {
bg: isSelected ? 'accent.500' : 'base.600',
color: isSelected ? 'base.50' : 'base.100',
},
lineHeight: 'short',
fontSize: 'xs',
}}
>
<Editable
value={localBoardName}
isDisabled={isUpdateBoardLoading}
submitOnBlur={true}
onChange={handleChange}
onSubmit={handleSubmit}
sx={{
w: 'full',
}}
>
<EditablePreview
sx={{
p: 0,
fontWeight: isSelected ? 700 : 500,
textAlign: 'center',
overflow: 'hidden',
textOverflow: 'ellipsis',
}}
noOfLines={1}
/>
<EditableInput
sx={{
p: 0,
_focusVisible: {
p: 0,
textAlign: 'center',
// get rid of the edit border
boxShadow: 'none',
},
}}
/>
</Editable>
</Flex>
<IAIDroppable <IAIDroppable
data={droppableData} data={droppableData}
dropLabel={<Text fontSize="md">Move</Text>} dropLabel={<Text fontSize="md">Move</Text>}
/> />
</Flex> </Flex>
)}
<Flex </BoardContextMenu>
sx={{ </Flex>
width: 'full',
height: 'full',
justifyContent: 'center',
alignItems: 'center',
}}
>
<Editable
defaultValue={board_name}
submitOnBlur={true}
onSubmit={(nextValue) => {
handleUpdateBoardName(nextValue);
}}
sx={{ maxW: 'full' }}
>
<EditablePreview
sx={{
color: isSelected
? mode('base.900', 'base.50')(colorMode)
: mode('base.700', 'base.200')(colorMode),
fontWeight: isSelected ? 600 : undefined,
fontSize: 'xs',
textAlign: 'center',
p: 0,
overflow: 'hidden',
textOverflow: 'ellipsis',
}}
noOfLines={1}
/>
<EditableInput
sx={{
color: mode('base.900', 'base.50')(colorMode),
fontSize: 'xs',
borderColor: mode('base.500', 'base.500')(colorMode),
p: 0,
outline: 0,
}}
/>
</Editable>
</Flex>
</Flex>
)}
</BoardContextMenu>
</Box> </Box>
); );
} }

View File

@ -17,7 +17,7 @@ type GenericBoardProps = {
badgeCount?: number; badgeCount?: number;
}; };
const formatBadgeCount = (count: number) => export const formatBadgeCount = (count: number) =>
Intl.NumberFormat('en-US', { Intl.NumberFormat('en-US', {
notation: 'compact', notation: 'compact',
maximumFractionDigits: 1, maximumFractionDigits: 1,
@ -92,7 +92,7 @@ const GenericBoard = (props: GenericBoardProps) => {
h: 'full', h: 'full',
alignItems: 'center', alignItems: 'center',
fontWeight: isSelected ? 600 : undefined, fontWeight: isSelected ? 600 : undefined,
fontSize: 'xs', fontSize: 'sm',
color: isSelected ? 'base.900' : 'base.700', color: isSelected ? 'base.900' : 'base.700',
_dark: { color: isSelected ? 'base.50' : 'base.200' }, _dark: { color: isSelected ? 'base.50' : 'base.200' },
}} }}

View File

@ -0,0 +1,53 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import { boardIdSelected } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo } from 'react';
import { useBoardName } from 'services/api/hooks/useBoardName';
type Props = {
board_id: 'images' | 'assets' | 'no_board';
};
const SystemBoardButton = ({ board_id }: Props) => {
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
[stateSelector],
({ gallery }) => {
const { selectedBoardId } = gallery;
return { isSelected: selectedBoardId === board_id };
},
defaultSelectorOptions
),
[board_id]
);
const { isSelected } = useAppSelector(selector);
const boardName = useBoardName(board_id);
const handleClick = useCallback(() => {
dispatch(boardIdSelected(board_id));
}, [board_id, dispatch]);
return (
<IAIButton
onClick={handleClick}
size="sm"
isChecked={isSelected}
sx={{
flexGrow: 1,
borderRadius: 'base',
}}
>
{boardName}
</IAIButton>
);
};
export default memo(SystemBoardButton);

View File

@ -4,8 +4,9 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { memo } from 'react'; import { memo, useMemo } from 'react';
import { useBoardName } from 'services/api/hooks/useBoardName'; import { useBoardName } from 'services/api/hooks/useBoardName';
import { useBoardTotal } from 'services/api/hooks/useBoardTotal';
const selector = createSelector( const selector = createSelector(
[stateSelector], [stateSelector],
@ -26,6 +27,16 @@ const GalleryBoardName = (props: Props) => {
const { isOpen, onToggle } = props; const { isOpen, onToggle } = props;
const { selectedBoardId } = useAppSelector(selector); const { selectedBoardId } = useAppSelector(selector);
const boardName = useBoardName(selectedBoardId); const boardName = useBoardName(selectedBoardId);
const numOfBoardImages = useBoardTotal(selectedBoardId);
const formattedBoardName = useMemo(() => {
if (!boardName) return '';
if (boardName && !numOfBoardImages) return boardName;
if (boardName.length > 20) {
return `${boardName.substring(0, 20)}... (${numOfBoardImages})`;
}
return `${boardName} (${numOfBoardImages})`;
}, [boardName, numOfBoardImages]);
return ( return (
<Flex <Flex
@ -58,7 +69,7 @@ const GalleryBoardName = (props: Props) => {
}, },
}} }}
> >
{boardName} {formattedBoardName}
</Text> </Text>
</Box> </Box>
<Spacer /> <Spacer />

View File

@ -109,7 +109,7 @@ const GalleryDrawer = () => {
isResizable={true} isResizable={true}
isOpen={shouldShowGallery} isOpen={shouldShowGallery}
onClose={handleCloseGallery} onClose={handleCloseGallery}
minWidth={337} minWidth={400}
> >
<ImageGalleryContent /> <ImageGalleryContent />
</ResizableDrawer> </ResizableDrawer>

View File

@ -1,4 +1,4 @@
import { Box } from '@chakra-ui/react'; import { Box, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd'; import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
@ -86,38 +86,31 @@ const GalleryImage = (props: HoverableImageProps) => {
return ( return (
<Box sx={{ w: 'full', h: 'full', touchAction: 'none' }}> <Box sx={{ w: 'full', h: 'full', touchAction: 'none' }}>
<ImageContextMenu imageDTO={imageDTO}> <Flex
{(ref) => ( userSelect="none"
<Box sx={{
position="relative" position: 'relative',
key={imageName} justifyContent: 'center',
userSelect="none" alignItems: 'center',
ref={ref} aspectRatio: '1/1',
sx={{ }}
display: 'flex', >
justifyContent: 'center', <IAIDndImage
alignItems: 'center', onClick={handleClick}
aspectRatio: '1/1', imageDTO={imageDTO}
}} draggableData={draggableData}
> isSelected={isSelected}
<IAIDndImage minSize={0}
onClick={handleClick} onClickReset={handleDelete}
imageDTO={imageDTO} imageSx={{ w: 'full', h: 'full' }}
draggableData={draggableData} isDropDisabled={true}
isSelected={isSelected} isUploadDisabled={true}
minSize={0} thumbnail={true}
onClickReset={handleDelete} // resetIcon={<FaTrash />}
imageSx={{ w: 'full', h: 'full' }} // resetTooltip="Delete image"
isDropDisabled={true} // withResetIcon // removed bc it's too easy to accidentally delete images
isUploadDisabled={true} />
thumbnail={true} </Flex>
// resetIcon={<FaTrash />}
// resetTooltip="Delete image"
// withResetIcon // removed bc it's too easy to accidentally delete images
/>
</Box>
)}
</ImageContextMenu>
</Box> </Box>
); );
}; };

View File

@ -48,6 +48,7 @@ export const addControlNetToLinearGraph = (
beginStepPct, beginStepPct,
endStepPct, endStepPct,
controlMode, controlMode,
resizeMode,
model, model,
processorType, processorType,
weight, weight,
@ -60,6 +61,7 @@ export const addControlNetToLinearGraph = (
begin_step_percent: beginStepPct, begin_step_percent: beginStepPct,
end_step_percent: endStepPct, end_step_percent: endStepPct,
control_mode: controlMode, control_mode: controlMode,
resize_mode: resizeMode,
control_model: model as ControlNetInvocation['control_model'], control_model: model as ControlNetInvocation['control_model'],
control_weight: weight, control_weight: weight,
}; };

View File

@ -105,7 +105,7 @@ const enabledTabsSelector = createSelector(
} }
); );
const MIN_GALLERY_WIDTH = 300; const MIN_GALLERY_WIDTH = 350;
const DEFAULT_GALLERY_PCT = 20; const DEFAULT_GALLERY_PCT = 20;
export const NO_GALLERY_TABS: InvokeTabName[] = ['modelManager']; export const NO_GALLERY_TABS: InvokeTabName[] = ['modelManager'];

View File

@ -3,7 +3,7 @@ import { memo } from 'react';
import { PanelResizeHandle } from 'react-resizable-panels'; import { PanelResizeHandle } from 'react-resizable-panels';
import { mode } from 'theme/util/mode'; import { mode } from 'theme/util/mode';
type ResizeHandleProps = FlexProps & { type ResizeHandleProps = Omit<FlexProps, 'direction'> & {
direction?: 'horizontal' | 'vertical'; direction?: 'horizontal' | 'vertical';
}; };

View File

@ -169,10 +169,11 @@ export const imagesApi = api.injectEndpoints({
], ],
async onQueryStarted(imageDTO, { dispatch, queryFulfilled }) { async onQueryStarted(imageDTO, { dispatch, queryFulfilled }) {
/** /**
* Cache changes for deleteImage: * Cache changes for `deleteImage`:
* - Remove from "All Images" * - *remove* from "All Images" / "All Assets"
* - Remove from image's `board_id` if it has one, or "No Board" if not * - IF it has a board:
* - Remove from "Batch" * - THEN *remove* from it's own board
* - ELSE *remove* from "No Board"
*/ */
const { image_name, board_id, image_category } = imageDTO; const { image_name, board_id, image_category } = imageDTO;
@ -181,22 +182,23 @@ export const imagesApi = api.injectEndpoints({
// That means constructing the possible query args that are serialized into the cache key... // That means constructing the possible query args that are serialized into the cache key...
const removeFromCacheKeys: ListImagesArgs[] = []; const removeFromCacheKeys: ListImagesArgs[] = [];
// determine `categories`, i.e. do we update "All Images" or "All Assets"
const categories = IMAGE_CATEGORIES.includes(image_category) const categories = IMAGE_CATEGORIES.includes(image_category)
? IMAGE_CATEGORIES ? IMAGE_CATEGORIES
: ASSETS_CATEGORIES; : ASSETS_CATEGORIES;
// All Images board (e.g. no board) // remove from "All Images"
removeFromCacheKeys.push({ categories }); removeFromCacheKeys.push({ categories });
// Board specific
if (board_id) { if (board_id) {
// remove from it's own board
removeFromCacheKeys.push({ board_id }); removeFromCacheKeys.push({ board_id });
} else { } else {
// TODO: No Board // remove from "No Board"
removeFromCacheKeys.push({ board_id: 'none' });
} }
// TODO: Batch
const patches: PatchCollection[] = []; const patches: PatchCollection[] = [];
removeFromCacheKeys.forEach((cacheKey) => { removeFromCacheKeys.forEach((cacheKey) => {
patches.push( patches.push(
@ -240,32 +242,37 @@ export const imagesApi = api.injectEndpoints({
{ imageDTO: oldImageDTO, changes: _changes }, { imageDTO: oldImageDTO, changes: _changes },
{ dispatch, queryFulfilled, getState } { dispatch, queryFulfilled, getState }
) { ) {
// TODO: Should we handle changes to boards via this mutation? Seems reasonable...
// let's be extra-sure we do not accidentally change categories // let's be extra-sure we do not accidentally change categories
const changes = omit(_changes, 'image_category'); const changes = omit(_changes, 'image_category');
/** /**
* Cache changes for `updateImage`: * Cache changes for "updateImage":
* - Update the ImageDTO * - *update* "getImageDTO" cache
* - Update the image in "All Images" board: * - for "All Images" || "All Assets":
* - IF it is in the date range represented by the cache: * - IF it is not already in the cache
* - add the image IF it is not already in the cache & update the total * - THEN *add* it to "All Images" / "All Assets" and update the total
* - ELSE update the image IF it is already in the cache * - ELSE *update* it
* - IF the image has a board: * - IF the image has a board:
* - Update the image in it's own board * - THEN *update* it's own board
* - ELSE Update the image in the "No Board" board (TODO) * - ELSE *update* the "No Board" board
*/ */
const patches: PatchCollection[] = []; const patches: PatchCollection[] = [];
const { image_name, board_id, image_category } = oldImageDTO; const { image_name, board_id, image_category, is_intermediate } =
oldImageDTO;
const isChangingFromIntermediate = changes.is_intermediate === false;
// do not add intermediates to gallery cache
if (is_intermediate && !isChangingFromIntermediate) {
return;
}
// determine `categories`, i.e. do we update "All Images" or "All Assets"
const categories = IMAGE_CATEGORIES.includes(image_category) const categories = IMAGE_CATEGORIES.includes(image_category)
? IMAGE_CATEGORIES ? IMAGE_CATEGORIES
: ASSETS_CATEGORIES; : ASSETS_CATEGORIES;
// TODO: No Board // update `getImageDTO` cache
// Update `getImageDTO` cache
patches.push( patches.push(
dispatch( dispatch(
imagesApi.util.updateQueryData( imagesApi.util.updateQueryData(
@ -281,9 +288,13 @@ export const imagesApi = api.injectEndpoints({
// Update the "All Image" or "All Assets" board // Update the "All Image" or "All Assets" board
const queryArgsToUpdate: ListImagesArgs[] = [{ categories }]; const queryArgsToUpdate: ListImagesArgs[] = [{ categories }];
// IF the image has a board:
if (board_id) { if (board_id) {
// We also need to update the user board // THEN update it's own board
queryArgsToUpdate.push({ board_id }); queryArgsToUpdate.push({ board_id });
} else {
// ELSE update the "No Board" board
queryArgsToUpdate.push({ board_id: 'none' });
} }
queryArgsToUpdate.forEach((queryArg) => { queryArgsToUpdate.forEach((queryArg) => {
@ -371,12 +382,12 @@ export const imagesApi = api.injectEndpoints({
return; return;
} }
// Add the image to the "All Images" / "All Assets" board // determine `categories`, i.e. do we update "All Images" or "All Assets"
const queryArg = { const categories = IMAGE_CATEGORIES.includes(image_category)
categories: IMAGE_CATEGORIES.includes(image_category) ? IMAGE_CATEGORIES
? IMAGE_CATEGORIES : ASSETS_CATEGORIES;
: ASSETS_CATEGORIES,
}; const queryArg = { categories };
dispatch( dispatch(
imagesApi.util.updateQueryData('listImages', queryArg, (draft) => { imagesApi.util.updateQueryData('listImages', queryArg, (draft) => {
@ -410,16 +421,14 @@ export const imagesApi = api.injectEndpoints({
{ dispatch, queryFulfilled, getState } { dispatch, queryFulfilled, getState }
) { ) {
/** /**
* Cache changes for addImageToBoard: * Cache changes for `addImageToBoard`:
* - Remove from "No Board" * - *update* the `getImageDTO` cache
* - Remove from `old_board_id` if it has one * - *remove* from "No Board"
* - Add to new `board_id` * - IF the image has an old `board_id`:
* - IF the image's `created_at` is within the range of the board's cached images * - THEN *remove* from it's old `board_id`
* - IF the image's `created_at` is within the range of the board's cached images
* - OR the board cache has length of 0 or 1 * - OR the board cache has length of 0 or 1
* - Update the `total` for each board whose cache is updated * - THEN *add* it to new `board_id`
* - Update the ImageDTO
*
* TODO: maybe total should just be updated in the boards endpoints?
*/ */
const { image_name, board_id: old_board_id } = oldImageDTO; const { image_name, board_id: old_board_id } = oldImageDTO;
@ -427,13 +436,10 @@ export const imagesApi = api.injectEndpoints({
// Figure out the `listImages` caches that we need to update // Figure out the `listImages` caches that we need to update
const removeFromQueryArgs: ListImagesArgs[] = []; const removeFromQueryArgs: ListImagesArgs[] = [];
// TODO: No Board // remove from "No Board"
// TODO: Batch
// Remove from No Board
removeFromQueryArgs.push({ board_id: 'none' }); removeFromQueryArgs.push({ board_id: 'none' });
// Remove from old board // remove from old board
if (old_board_id) { if (old_board_id) {
removeFromQueryArgs.push({ board_id: old_board_id }); removeFromQueryArgs.push({ board_id: old_board_id });
} }
@ -534,17 +540,15 @@ export const imagesApi = api.injectEndpoints({
{ dispatch, queryFulfilled, getState } { dispatch, queryFulfilled, getState }
) { ) {
/** /**
* Cache changes for removeImageFromBoard: * Cache changes for `removeImageFromBoard`:
* - Add to "No Board" * - *update* `getImageDTO`
* - IF the image's `created_at` is within the range of the board's cached images * - IF the image's `created_at` is within the range of the board's cached images
* - Remove from `old_board_id` * - THEN *add* to "No Board"
* - Update the ImageDTO * - *remove* from `old_board_id`
*/ */
const { image_name, board_id: old_board_id } = imageDTO; const { image_name, board_id: old_board_id } = imageDTO;
// TODO: Batch
const patches: PatchCollection[] = []; const patches: PatchCollection[] = [];
// Updated imageDTO with new board_id // Updated imageDTO with new board_id

View File

@ -6,9 +6,9 @@ export const useBoardName = (board_id: BoardId | null | undefined) => {
selectFromResult: ({ data }) => { selectFromResult: ({ data }) => {
let boardName = ''; let boardName = '';
if (board_id === 'images') { if (board_id === 'images') {
boardName = 'All Images'; boardName = 'Images';
} else if (board_id === 'assets') { } else if (board_id === 'assets') {
boardName = 'All Assets'; boardName = 'Assets';
} else if (board_id === 'no_board') { } else if (board_id === 'no_board') {
boardName = 'No Board'; boardName = 'No Board';
} else if (board_id === 'batch') { } else if (board_id === 'batch') {

View File

@ -0,0 +1,53 @@
import { skipToken } from '@reduxjs/toolkit/dist/query';
import {
ASSETS_CATEGORIES,
BoardId,
IMAGE_CATEGORIES,
INITIAL_IMAGE_LIMIT,
} from 'features/gallery/store/gallerySlice';
import { useMemo } from 'react';
import { ListImagesArgs, useListImagesQuery } from '../endpoints/images';
const baseQueryArg: ListImagesArgs = {
offset: 0,
limit: INITIAL_IMAGE_LIMIT,
is_intermediate: false,
};
const imagesQueryArg: ListImagesArgs = {
categories: IMAGE_CATEGORIES,
...baseQueryArg,
};
const assetsQueryArg: ListImagesArgs = {
categories: ASSETS_CATEGORIES,
...baseQueryArg,
};
const noBoardQueryArg: ListImagesArgs = {
board_id: 'none',
...baseQueryArg,
};
export const useBoardTotal = (board_id: BoardId | null | undefined) => {
const queryArg = useMemo(() => {
if (!board_id) {
return;
}
if (board_id === 'images') {
return imagesQueryArg;
} else if (board_id === 'assets') {
return assetsQueryArg;
} else if (board_id === 'no_board') {
return noBoardQueryArg;
} else {
return { board_id, ...baseQueryArg };
}
}, [board_id]);
const { total } = useListImagesQuery(queryArg ?? skipToken, {
selectFromResult: ({ currentData }) => ({ total: currentData?.total }),
});
return total;
};

View File

@ -167,7 +167,7 @@ export type paths = {
"/api/v1/images/clear-intermediates": { "/api/v1/images/clear-intermediates": {
/** /**
* Clear Intermediates * Clear Intermediates
* @description Clears first 100 intermediates * @description Clears all intermediates
*/ */
post: operations["clear_intermediates"]; post: operations["clear_intermediates"];
}; };
@ -800,6 +800,13 @@ export type components = {
* @enum {string} * @enum {string}
*/ */
control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced"; control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced";
/**
* Resize Mode
* @description The resize mode to use
* @default just_resize
* @enum {string}
*/
resize_mode?: "just_resize" | "crop_resize" | "fill_resize" | "just_resize_simple";
}; };
/** /**
* ControlNetInvocation * ControlNetInvocation
@ -859,6 +866,13 @@ export type components = {
* @enum {string} * @enum {string}
*/ */
control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced"; control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced";
/**
* Resize Mode
* @description The resize mode used
* @default just_resize
* @enum {string}
*/
resize_mode?: "just_resize" | "crop_resize" | "fill_resize" | "just_resize_simple";
}; };
/** ControlNetModelConfig */ /** ControlNetModelConfig */
ControlNetModelConfig: { ControlNetModelConfig: {
@ -5324,11 +5338,11 @@ export type components = {
image?: components["schemas"]["ImageField"]; image?: components["schemas"]["ImageField"];
}; };
/** /**
* StableDiffusion2ModelFormat * StableDiffusion1ModelFormat
* @description An enumeration. * @description An enumeration.
* @enum {string} * @enum {string}
*/ */
StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/** /**
* StableDiffusionXLModelFormat * StableDiffusionXLModelFormat
* @description An enumeration. * @description An enumeration.
@ -5336,11 +5350,11 @@ export type components = {
*/ */
StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
/** /**
* StableDiffusion1ModelFormat * StableDiffusion2ModelFormat
* @description An enumeration. * @description An enumeration.
* @enum {string} * @enum {string}
*/ */
StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
}; };
responses: never; responses: never;
parameters: never; parameters: never;
@ -6125,7 +6139,7 @@ export type operations = {
}; };
/** /**
* Clear Intermediates * Clear Intermediates
* @description Clears first 100 intermediates * @description Clears all intermediates
*/ */
clear_intermediates: { clear_intermediates: {
responses: { responses: {

View File

@ -2,11 +2,16 @@ import { InvokeAIThemeColors } from 'theme/themeTypes';
import { generateColorPalette } from 'theme/util/generateColorPalette'; import { generateColorPalette } from 'theme/util/generateColorPalette';
const BASE = { H: 220, S: 16 }; const BASE = { H: 220, S: 16 };
const ACCENT = { H: 250, S: 52 }; const ACCENT = { H: 250, S: 42 };
const WORKING = { H: 47, S: 50 }; // const ACCENT = { H: 250, S: 52 };
const WARNING = { H: 28, S: 50 }; const WORKING = { H: 47, S: 42 };
const OK = { H: 113, S: 50 }; // const WORKING = { H: 47, S: 50 };
const ERROR = { H: 0, S: 50 }; const WARNING = { H: 28, S: 42 };
// const WARNING = { H: 28, S: 50 };
const OK = { H: 113, S: 42 };
// const OK = { H: 113, S: 50 };
const ERROR = { H: 0, S: 42 };
// const ERROR = { H: 0, S: 50 };
export const InvokeAIColors: InvokeAIThemeColors = { export const InvokeAIColors: InvokeAIThemeColors = {
base: generateColorPalette(BASE.H, BASE.S), base: generateColorPalette(BASE.H, BASE.S),

View File

@ -0,0 +1,56 @@
import { editableAnatomy as parts } from '@chakra-ui/anatomy';
import {
createMultiStyleConfigHelpers,
defineStyle,
} from '@chakra-ui/styled-system';
import { mode } from '@chakra-ui/theme-tools';
const { definePartsStyle, defineMultiStyleConfig } =
createMultiStyleConfigHelpers(parts.keys);
const baseStylePreview = defineStyle({
borderRadius: 'md',
py: '1',
transitionProperty: 'common',
transitionDuration: 'normal',
});
const baseStyleInput = defineStyle((props) => ({
borderRadius: 'md',
py: '1',
transitionProperty: 'common',
transitionDuration: 'normal',
width: 'full',
_focusVisible: { boxShadow: 'outline' },
_placeholder: { opacity: 0.6 },
'::selection': {
color: mode('accent.900', 'accent.50')(props),
bg: mode('accent.200', 'accent.400')(props),
},
}));
const baseStyleTextarea = defineStyle({
borderRadius: 'md',
py: '1',
transitionProperty: 'common',
transitionDuration: 'normal',
width: 'full',
_focusVisible: { boxShadow: 'outline' },
_placeholder: { opacity: 0.6 },
});
const invokeAI = definePartsStyle((props) => ({
preview: baseStylePreview,
input: baseStyleInput(props),
textarea: baseStyleTextarea,
}));
export const editableTheme = defineMultiStyleConfig({
variants: {
invokeAI,
},
defaultProps: {
size: 'sm',
variant: 'invokeAI',
},
});

View File

@ -4,6 +4,7 @@ import { InvokeAIColors } from './colors/colors';
import { accordionTheme } from './components/accordion'; import { accordionTheme } from './components/accordion';
import { buttonTheme } from './components/button'; import { buttonTheme } from './components/button';
import { checkboxTheme } from './components/checkbox'; import { checkboxTheme } from './components/checkbox';
import { editableTheme } from './components/editable';
import { formLabelTheme } from './components/formLabel'; import { formLabelTheme } from './components/formLabel';
import { inputTheme } from './components/input'; import { inputTheme } from './components/input';
import { menuTheme } from './components/menu'; import { menuTheme } from './components/menu';
@ -72,7 +73,17 @@ export const theme: ThemeOverride = {
selected: { selected: {
light: light:
'0px 0px 0px 1px var(--invokeai-colors-base-150), 0px 0px 0px 4px var(--invokeai-colors-accent-400)', '0px 0px 0px 1px var(--invokeai-colors-base-150), 0px 0px 0px 4px var(--invokeai-colors-accent-400)',
dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-400)', dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-500)',
},
hoverSelected: {
light:
'0px 0px 0px 1px var(--invokeai-colors-base-150), 0px 0px 0px 4px var(--invokeai-colors-accent-500)',
dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-300)',
},
hoverUnselected: {
light:
'0px 0px 0px 1px var(--invokeai-colors-base-150), 0px 0px 0px 4px var(--invokeai-colors-accent-200)',
dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-600)',
}, },
nodeSelectedOutline: `0 0 0 2px var(--invokeai-colors-accent-450)`, nodeSelectedOutline: `0 0 0 2px var(--invokeai-colors-accent-450)`,
}, },
@ -80,6 +91,7 @@ export const theme: ThemeOverride = {
components: { components: {
Button: buttonTheme, // Button and IconButton Button: buttonTheme, // Button and IconButton
Input: inputTheme, Input: inputTheme,
Editable: editableTheme,
Textarea: textareaTheme, Textarea: textareaTheme,
Tabs: tabsTheme, Tabs: tabsTheme,
Progress: progressTheme, Progress: progressTheme,

View File

@ -37,4 +37,7 @@ export const getInputOutlineStyles = (props: StyleFunctionProps) => ({
_placeholder: { _placeholder: {
color: mode('base.700', 'base.400')(props), color: mode('base.700', 'base.400')(props),
}, },
'::selection': {
bg: mode('accent.200', 'accent.400')(props),
},
}); });