Merge branch 'main' into nodepromptsize

This commit is contained in:
mickr777 2023-07-21 08:07:55 +10:00 committed by GitHub
commit 98b2734240
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 1274 additions and 362 deletions

View File

@ -1,9 +1,22 @@
from enum import Enum
from fastapi import Body
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
from invokeai.backend.image_util.patchmatch import PatchMatch
from invokeai.version import __version__
from ..dependencies import ApiDependencies
from invokeai.backend.util.logging import logging
class LogLevel(int, Enum):
NotSet = logging.NOTSET
Debug = logging.DEBUG
Info = logging.INFO
Warning = logging.WARNING
Error = logging.ERROR
Critical = logging.CRITICAL
app_router = APIRouter(prefix="/v1/app", tags=["app"])
@ -34,3 +47,27 @@ async def get_config() -> AppConfig:
if PatchMatch.patchmatch_available():
infill_methods.append('patchmatch')
return AppConfig(infill_methods=infill_methods)
@app_router.get(
"/logging",
operation_id="get_log_level",
responses={200: {"description" : "The operation was successful"}},
response_model = LogLevel,
)
async def get_log_level(
) -> LogLevel:
"""Returns the log level"""
return LogLevel(ApiDependencies.invoker.services.logger.level)
@app_router.post(
"/logging",
operation_id="set_log_level",
responses={200: {"description" : "The operation was successful"}},
response_model = LogLevel,
)
async def set_log_level(
level: LogLevel = Body(description="New log verbosity level"),
) -> LogLevel:
"""Sets the log verbosity level"""
ApiDependencies.invoker.services.logger.setLevel(level)
return LogLevel(ApiDependencies.invoker.services.logger.level)

View File

@ -85,8 +85,8 @@ CONTROLNET_DEFAULT_MODELS = [
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
CONTROLNET_MODE_VALUES = Literal[tuple(
["balanced", "more_prompt", "more_control", "unbalanced"])]
# crop and fill options not ready yet
# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
CONTROLNET_RESIZE_VALUES = Literal[tuple(
["just_resize", "crop_resize", "fill_resize", "just_resize_simple",])]
class ControlNetModelField(BaseModel):
@ -111,7 +111,8 @@ class ControlField(BaseModel):
description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field(
default="balanced", description="The control mode to use")
# resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(
default="just_resize", description="The resize mode to use")
@validator("control_weight")
def validate_control_weight(cls, v):
@ -161,6 +162,7 @@ class ControlNetInvocation(BaseInvocation):
end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode used")
# fmt: on
class Config(InvocationConfig):
@ -187,6 +189,7 @@ class ControlNetInvocation(BaseInvocation):
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
control_mode=self.control_mode,
resize_mode=self.resize_mode,
),
)

View File

@ -30,6 +30,7 @@ from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField
from invokeai.app.util.controlnet_utils import prepare_control_image
from diffusers.models.attention_processor import (
AttnProcessor2_0,
@ -288,7 +289,7 @@ class TextToLatentsInvocation(BaseInvocation):
# and add in batch_size, num_images_per_prompt?
# and do real check for classifier_free_guidance?
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
control_image = model.prepare_control_image(
control_image = prepare_control_image(
image=input_image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=control_width_resize,
@ -298,13 +299,18 @@ class TextToLatentsInvocation(BaseInvocation):
device=control_model.device,
dtype=control_model.dtype,
control_mode=control_info.control_mode,
resize_mode=control_info.resize_mode,
)
control_item = ControlNetData(
model=control_model, image_tensor=control_image,
model=control_model,
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
# any resizing needed should currently be happening in prepare_control_image(),
# but adding resize_mode to ControlNetData in case needed in the future
resize_mode=control_info.resize_mode,
)
control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
@ -758,7 +764,7 @@ class ImageToLatentsInvocation(BaseInvocation):
dtype=vae.dtype
) # FIXME: uses torch.randn. make reproducible!
latents = 0.18215 * latents
latents = vae.config.scaling_factor * latents
latents = latents.to(dtype=orig_dtype)
name = f"{context.graph_execution_state_id}__{self.id}"

View File

@ -6,6 +6,7 @@ from typing import List, Literal, Optional, Union
from pydantic import Field, validator
from ...backend.model_management import ModelType, SubModelType
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
@ -243,10 +244,31 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
},
}
def dispatch_progress(
self,
context: InvocationContext,
source_node_id: str,
sample,
step,
total_steps,
) -> None:
stable_diffusion_xl_step_callback(
context=context,
node=self.dict(),
source_node_id=source_node_id,
sample=sample,
step=step,
total_steps=total_steps,
)
# based on
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
latents = context.services.latents.get(self.noise.latents_name)
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
@ -341,6 +363,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
#if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
else:
@ -409,6 +432,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
#if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
@ -473,10 +497,31 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
},
}
def dispatch_progress(
self,
context: InvocationContext,
source_node_id: str,
sample,
step,
total_steps,
) -> None:
stable_diffusion_xl_step_callback(
context=context,
node=self.dict(),
source_node_id=source_node_id,
sample=sample,
step=step,
total_steps=total_steps,
)
# based on
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
latents = context.services.latents.get(self.latents.latents_name)
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
@ -579,6 +624,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
#if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
else:
@ -647,6 +693,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
#if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)

View File

@ -0,0 +1,342 @@
import torch
import numpy as np
import cv2
from PIL import Image
from diffusers.utils import PIL_INTERPOLATION
from einops import rearrange
from controlnet_aux.util import HWC3, resize_image
###################################################################
# Copy of scripts/lvminthin.py from Mikubill/sd-webui-controlnet
###################################################################
# High Quality Edge Thinning using Pure Python
# Written by Lvmin Zhangu
# 2023 April
# Stanford University
# If you use this, please Cite "High Quality Edge Thinning using Pure Python", Lvmin Zhang, In Mikubill/sd-webui-controlnet.
lvmin_kernels_raw = [
np.array([
[-1, -1, -1],
[0, 1, 0],
[1, 1, 1]
], dtype=np.int32),
np.array([
[0, -1, -1],
[1, 1, -1],
[0, 1, 0]
], dtype=np.int32)
]
lvmin_kernels = []
lvmin_kernels += [np.rot90(x, k=0, axes=(0, 1)) for x in lvmin_kernels_raw]
lvmin_kernels += [np.rot90(x, k=1, axes=(0, 1)) for x in lvmin_kernels_raw]
lvmin_kernels += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_kernels_raw]
lvmin_kernels += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_kernels_raw]
lvmin_prunings_raw = [
np.array([
[-1, -1, -1],
[-1, 1, -1],
[0, 0, -1]
], dtype=np.int32),
np.array([
[-1, -1, -1],
[-1, 1, -1],
[-1, 0, 0]
], dtype=np.int32)
]
lvmin_prunings = []
lvmin_prunings += [np.rot90(x, k=0, axes=(0, 1)) for x in lvmin_prunings_raw]
lvmin_prunings += [np.rot90(x, k=1, axes=(0, 1)) for x in lvmin_prunings_raw]
lvmin_prunings += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_prunings_raw]
lvmin_prunings += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_prunings_raw]
def remove_pattern(x, kernel):
objects = cv2.morphologyEx(x, cv2.MORPH_HITMISS, kernel)
objects = np.where(objects > 127)
x[objects] = 0
return x, objects[0].shape[0] > 0
def thin_one_time(x, kernels):
y = x
is_done = True
for k in kernels:
y, has_update = remove_pattern(y, k)
if has_update:
is_done = False
return y, is_done
def lvmin_thin(x, prunings=True):
y = x
for i in range(32):
y, is_done = thin_one_time(y, lvmin_kernels)
if is_done:
break
if prunings:
y, _ = thin_one_time(y, lvmin_prunings)
return y
def nake_nms(x):
f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
y = np.zeros_like(x)
for f in [f1, f2, f3, f4]:
np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
return y
################################################################################
# copied from Mikubill/sd-webui-controlnet external_code.py and modified for InvokeAI
################################################################################
# FIXME: not using yet, if used in the future will most likely require modification of preprocessors
def pixel_perfect_resolution(
image: np.ndarray,
target_H: int,
target_W: int,
resize_mode: str,
) -> int:
"""
Calculate the estimated resolution for resizing an image while preserving aspect ratio.
The function first calculates scaling factors for height and width of the image based on the target
height and width. Then, based on the chosen resize mode, it either takes the smaller or the larger
scaling factor to estimate the new resolution.
If the resize mode is OUTER_FIT, the function uses the smaller scaling factor, ensuring the whole image
fits within the target dimensions, potentially leaving some empty space.
If the resize mode is not OUTER_FIT, the function uses the larger scaling factor, ensuring the target
dimensions are fully filled, potentially cropping the image.
After calculating the estimated resolution, the function prints some debugging information.
Args:
image (np.ndarray): A 3D numpy array representing an image. The dimensions represent [height, width, channels].
target_H (int): The target height for the image.
target_W (int): The target width for the image.
resize_mode (ResizeMode): The mode for resizing.
Returns:
int: The estimated resolution after resizing.
"""
raw_H, raw_W, _ = image.shape
k0 = float(target_H) / float(raw_H)
k1 = float(target_W) / float(raw_W)
if resize_mode == "fill_resize":
estimation = min(k0, k1) * float(min(raw_H, raw_W))
else: # "crop_resize" or "just_resize" (or possibly "just_resize_simple"?)
estimation = max(k0, k1) * float(min(raw_H, raw_W))
# print(f"Pixel Perfect Computation:")
# print(f"resize_mode = {resize_mode}")
# print(f"raw_H = {raw_H}")
# print(f"raw_W = {raw_W}")
# print(f"target_H = {target_H}")
# print(f"target_W = {target_W}")
# print(f"estimation = {estimation}")
return int(np.round(estimation))
###########################################################################
# Copied from detectmap_proc method in scripts/detectmap_proc.py in Mikubill/sd-webui-controlnet
# modified for InvokeAI
###########################################################################
# def detectmap_proc(detected_map, module, resize_mode, h, w):
def np_img_resize(
np_img: np.ndarray,
resize_mode: str,
h: int,
w: int,
device: torch.device = torch.device('cpu')
):
# if 'inpaint' in module:
# np_img = np_img.astype(np.float32)
# else:
# np_img = HWC3(np_img)
np_img = HWC3(np_img)
def safe_numpy(x):
# A very safe method to make sure that Apple/Mac works
y = x
# below is very boring but do not change these. If you change these Apple or Mac may fail.
y = y.copy()
y = np.ascontiguousarray(y)
y = y.copy()
return y
def get_pytorch_control(x):
# A very safe method to make sure that Apple/Mac works
y = x
# below is very boring but do not change these. If you change these Apple or Mac may fail.
y = torch.from_numpy(y)
y = y.float() / 255.0
y = rearrange(y, 'h w c -> 1 c h w')
y = y.clone()
# y = y.to(devices.get_device_for("controlnet"))
y = y.to(device)
y = y.clone()
return y
def high_quality_resize(x: np.ndarray,
size):
# Written by lvmin
# Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges
inpaint_mask = None
if x.ndim == 3 and x.shape[2] == 4:
inpaint_mask = x[:, :, 3]
x = x[:, :, 0:3]
new_size_is_smaller = (size[0] * size[1]) < (x.shape[0] * x.shape[1])
new_size_is_bigger = (size[0] * size[1]) > (x.shape[0] * x.shape[1])
unique_color_count = np.unique(x.reshape(-1, x.shape[2]), axis=0).shape[0]
is_one_pixel_edge = False
is_binary = False
if unique_color_count == 2:
is_binary = np.min(x) < 16 and np.max(x) > 240
if is_binary:
xc = x
xc = cv2.erode(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
xc = cv2.dilate(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
one_pixel_edge_count = np.where(xc < x)[0].shape[0]
all_edge_count = np.where(x > 127)[0].shape[0]
is_one_pixel_edge = one_pixel_edge_count * 2 > all_edge_count
if 2 < unique_color_count < 200:
interpolation = cv2.INTER_NEAREST
elif new_size_is_smaller:
interpolation = cv2.INTER_AREA
else:
interpolation = cv2.INTER_CUBIC # Must be CUBIC because we now use nms. NEVER CHANGE THIS
y = cv2.resize(x, size, interpolation=interpolation)
if inpaint_mask is not None:
inpaint_mask = cv2.resize(inpaint_mask, size, interpolation=interpolation)
if is_binary:
y = np.mean(y.astype(np.float32), axis=2).clip(0, 255).astype(np.uint8)
if is_one_pixel_edge:
y = nake_nms(y)
_, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
y = lvmin_thin(y, prunings=new_size_is_bigger)
else:
_, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
y = np.stack([y] * 3, axis=2)
if inpaint_mask is not None:
inpaint_mask = (inpaint_mask > 127).astype(np.float32) * 255.0
inpaint_mask = inpaint_mask[:, :, None].clip(0, 255).astype(np.uint8)
y = np.concatenate([y, inpaint_mask], axis=2)
return y
# if resize_mode == external_code.ResizeMode.RESIZE:
if resize_mode == "just_resize": # RESIZE
np_img = high_quality_resize(np_img, (w, h))
np_img = safe_numpy(np_img)
return get_pytorch_control(np_img), np_img
old_h, old_w, _ = np_img.shape
old_w = float(old_w)
old_h = float(old_h)
k0 = float(h) / old_h
k1 = float(w) / old_w
safeint = lambda x: int(np.round(x))
# if resize_mode == external_code.ResizeMode.OUTER_FIT:
if resize_mode == "fill_resize": # OUTER_FIT
k = min(k0, k1)
borders = np.concatenate([np_img[0, :, :], np_img[-1, :, :], np_img[:, 0, :], np_img[:, -1, :]], axis=0)
high_quality_border_color = np.median(borders, axis=0).astype(np_img.dtype)
if len(high_quality_border_color) == 4:
# Inpaint hijack
high_quality_border_color[3] = 255
high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1])
np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
new_h, new_w, _ = np_img.shape
pad_h = max(0, (h - new_h) // 2)
pad_w = max(0, (w - new_w) // 2)
high_quality_background[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = np_img
np_img = high_quality_background
np_img = safe_numpy(np_img)
return get_pytorch_control(np_img), np_img
else: # resize_mode == "crop_resize" (INNER_FIT)
k = max(k0, k1)
np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
new_h, new_w, _ = np_img.shape
pad_h = max(0, (new_h - h) // 2)
pad_w = max(0, (new_w - w) // 2)
np_img = np_img[pad_h:pad_h + h, pad_w:pad_w + w]
np_img = safe_numpy(np_img)
return get_pytorch_control(np_img), np_img
def prepare_control_image(
# image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]]
# but now should be able to assume that image is a single PIL.Image, which simplifies things
image: Image,
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
# latents_to_match_resolution, # TorchTensor of shape (batch_size, 3, height, width)
width=512, # should be 8 * latent.shape[3]
height=512, # should be 8 * latent height[2]
# batch_size=1, # currently no batching
# num_images_per_prompt=1, # currently only single image
device="cuda",
dtype=torch.float16,
do_classifier_free_guidance=True,
control_mode="balanced",
resize_mode="just_resize_simple",
):
# FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out
if (resize_mode == "just_resize_simple" or
resize_mode == "crop_resize_simple" or
resize_mode == "fill_resize_simple"):
image = image.convert("RGB")
if (resize_mode == "just_resize_simple"):
image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
elif (resize_mode == "crop_resize_simple"): # not yet implemented
pass
elif (resize_mode == "fill_resize_simple"): # not yet implemented
pass
nimage = np.array(image)
nimage = nimage[None, :]
nimage = np.concatenate([nimage], axis=0)
# normalizing RGB values to [0,1] range (in PIL.Image they are [0-255])
nimage = np.array(nimage).astype(np.float32) / 255.0
nimage = nimage.transpose(0, 3, 1, 2)
timage = torch.from_numpy(nimage)
# use fancy lvmin controlnet resizing
elif (resize_mode == "just_resize" or resize_mode == "crop_resize" or resize_mode == "fill_resize"):
nimage = np.array(image)
timage, nimage = np_img_resize(
np_img=nimage,
resize_mode=resize_mode,
h=height,
w=width,
# device=torch.device('cpu')
device=device,
)
else:
pass
print("ERROR: invalid resize_mode ==> ", resize_mode)
exit(1)
timage = timage.to(device=device, dtype=dtype)
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
if do_classifier_free_guidance and not cfg_injection:
timage = torch.cat([timage] * 2)
return timage

View File

@ -1,9 +1,30 @@
import torch
from PIL import Image
from invokeai.app.models.exceptions import CanceledException
from invokeai.app.models.image import ProgressImage
from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
from ...backend.generator.base import Generator
from ...backend.stable_diffusion import PipelineIntermediateState
from invokeai.app.services.config import InvokeAIAppConfig
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix = None):
latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors
if smooth_matrix is not None:
latent_image = latent_image.unsqueeze(0).permute(3, 0, 1, 2)
latent_image = torch.nn.functional.conv2d(latent_image, smooth_matrix.reshape((1,1,3,3)), padding=1)
latent_image = latent_image.permute(1, 2, 3, 0).squeeze(0)
latents_ubyte = (
((latent_image + 1) / 2)
.clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
.byte()
).cpu()
return Image.fromarray(latents_ubyte.numpy())
def stable_diffusion_step_callback(
@ -37,7 +58,24 @@ def stable_diffusion_step_callback(
# step = intermediate_state.step
# TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample)
# origingally adapted from code by @erucipe and @keturn here:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
# these updated numbers for v1.5 are from @torridgristle
v1_5_latent_rgb_factors = torch.tensor(
[
# R G B
[0.3444, 0.1385, 0.0670], # L1
[0.1247, 0.4027, 0.1494], # L2
[-0.3192, 0.2513, 0.2103], # L3
[-0.1307, -0.1874, -0.7445], # L4
],
dtype=sample.dtype,
device=sample.device,
)
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
(width, height) = image.size
width *= 8
@ -53,3 +91,56 @@ def stable_diffusion_step_callback(
step=intermediate_state.step,
total_steps=node["steps"],
)
def stable_diffusion_xl_step_callback(
context: InvocationContext,
node: dict,
source_node_id: str,
sample,
step,
total_steps,
):
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException
sdxl_latent_rgb_factors = torch.tensor(
[
# R G B
[ 0.3816, 0.4930, 0.5320],
[-0.3753, 0.1631, 0.1739],
[ 0.1770, 0.3588, -0.2048],
[-0.4350, -0.2644, -0.4289],
],
dtype=sample.dtype,
device=sample.device,
)
sdxl_smooth_matrix = torch.tensor(
[
#[ 0.0478, 0.1285, 0.0478],
#[ 0.1285, 0.2948, 0.1285],
#[ 0.0478, 0.1285, 0.0478],
[0.0358, 0.0964, 0.0358],
[0.0964, 0.4711, 0.0964],
[0.0358, 0.0964, 0.0358],
],
dtype=sample.dtype,
device=sample.device,
)
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
(width, height) = image.size
width *= 8
height *= 8
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
graph_execution_state_id=context.graph_execution_state_id,
node=node,
source_node_id=source_node_id,
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
step=step,
total_steps=total_steps,
)

View File

@ -219,6 +219,7 @@ class ControlNetData:
begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0)
control_mode: str = Field(default="balanced")
resize_mode: str = Field(default="just_resize")
@dataclass
@ -653,7 +654,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if cfg_injection:
# Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
# prepend zeros for unconditional batch
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
@ -954,53 +955,3 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
debug_image(
img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
)
# Copied from diffusers pipeline_stable_diffusion_controlnet.py
# Returns torch.Tensor of shape (batch_size, 3, height, width)
@staticmethod
def prepare_control_image(
image,
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
# latents,
width=512, # should be 8 * latent.shape[3]
height=512, # should be 8 * latent height[2]
batch_size=1,
num_images_per_prompt=1,
device="cuda",
dtype=torch.float16,
do_classifier_free_guidance=True,
control_mode="balanced"
):
if not isinstance(image, torch.Tensor):
if isinstance(image, PIL.Image.Image):
image = [image]
if isinstance(image[0], PIL.Image.Image):
images = []
for image_ in image:
image_ = image_.convert("RGB")
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
image_ = np.array(image_)
image_ = image_[None, :]
images.append(image_)
image = images
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
if do_classifier_free_guidance and not cfg_injection:
image = torch.cat([image] * 2)
return image

View File

@ -17,13 +17,13 @@ import {
} from 'common/components/IAIImageFallback';
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import { MouseEvent, ReactElement, SyntheticEvent, memo } from 'react';
import { FaImage, FaUndo, FaUpload } from 'react-icons/fa';
import { ImageDTO, PostUploadAction } from 'services/api/types';
import { mode } from 'theme/util/mode';
import IAIDraggable from './IAIDraggable';
import IAIDroppable from './IAIDroppable';
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
type IAIDndImageProps = {
imageDTO: ImageDTO | undefined;
@ -148,7 +148,9 @@ const IAIDndImage = (props: IAIDndImageProps) => {
maxH: 'full',
borderRadius: 'base',
shadow: isSelected ? 'selected.light' : undefined,
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
_dark: {
shadow: isSelected ? 'selected.dark' : undefined,
},
...imageSx,
}}
/>
@ -183,13 +185,6 @@ const IAIDndImage = (props: IAIDndImageProps) => {
</>
)}
{!imageDTO && isUploadDisabled && noContentFallback}
{!isDropDisabled && (
<IAIDroppable
data={droppableData}
disabled={isDropDisabled}
dropLabel={dropLabel}
/>
)}
{imageDTO && !isDragDisabled && (
<IAIDraggable
data={draggableData}
@ -197,6 +192,13 @@ const IAIDndImage = (props: IAIDndImageProps) => {
onClick={onClick}
/>
)}
{!isDropDisabled && (
<IAIDroppable
data={droppableData}
disabled={isDropDisabled}
dropLabel={dropLabel}
/>
)}
{onClickReset && withResetIcon && imageDTO && (
<IAIIconButton
onClick={onClickReset}

View File

@ -13,10 +13,11 @@ type IAIDroppableProps = {
dropLabel?: ReactNode;
disabled?: boolean;
data?: TypesafeDroppableData;
hoverRef?: React.Ref<HTMLDivElement>;
};
const IAIDroppable = (props: IAIDroppableProps) => {
const { dropLabel, data, disabled } = props;
const { dropLabel, data, disabled, hoverRef } = props;
const dndId = useRef(uuidv4());
const { isOver, setNodeRef, active } = useDroppable({

View File

@ -24,6 +24,7 @@ import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
import ParamControlNetResizeMode from './parameters/ParamControlNetResizeMode';
type ControlNetProps = {
controlNetId: string;
@ -68,7 +69,7 @@ const ControlNet = (props: ControlNetProps) => {
<Flex
sx={{
flexDir: 'column',
gap: 2,
gap: 3,
p: 3,
borderRadius: 'base',
position: 'relative',
@ -117,7 +118,12 @@ const ControlNet = (props: ControlNetProps) => {
tooltip={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
aria-label={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
onClick={toggleIsExpanded}
variant="link"
variant="ghost"
sx={{
_hover: {
bg: 'none',
},
}}
icon={
<ChevronUpIcon
sx={{
@ -151,7 +157,7 @@ const ControlNet = (props: ControlNetProps) => {
/>
)}
</Flex>
<Flex sx={{ w: 'full', flexDirection: 'column' }}>
<Flex sx={{ w: 'full', flexDirection: 'column', gap: 3 }}>
<Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}>
<Flex
sx={{
@ -176,16 +182,16 @@ const ControlNet = (props: ControlNetProps) => {
h: 28,
w: 28,
aspectRatio: '1/1',
mt: 3,
}}
>
<ControlNetImagePreview controlNetId={controlNetId} height={28} />
</Flex>
)}
</Flex>
<Box mt={2}>
<Flex sx={{ gap: 2 }}>
<ParamControlNetControlMode controlNetId={controlNetId} />
</Box>
<ParamControlNetResizeMode controlNetId={controlNetId} />
</Flex>
<ParamControlNetProcessorSelect controlNetId={controlNetId} />
</Flex>

View File

@ -0,0 +1,62 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
ResizeModes,
controlNetResizeModeChanged,
} from 'features/controlNet/store/controlNetSlice';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
type ParamControlNetResizeModeProps = {
controlNetId: string;
};
const RESIZE_MODE_DATA = [
{ label: 'Resize', value: 'just_resize' },
{ label: 'Crop', value: 'crop_resize' },
{ label: 'Fill', value: 'fill_resize' },
];
export default function ParamControlNetResizeMode(
props: ParamControlNetResizeModeProps
) {
const { controlNetId } = props;
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { resizeMode, isEnabled } =
controlNet.controlNets[controlNetId];
return { resizeMode, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { resizeMode, isEnabled } = useAppSelector(selector);
const { t } = useTranslation();
const handleResizeModeChange = useCallback(
(resizeMode: ResizeModes) => {
dispatch(controlNetResizeModeChanged({ controlNetId, resizeMode }));
},
[controlNetId, dispatch]
);
return (
<IAIMantineSelect
disabled={!isEnabled}
label="Resize Mode"
data={RESIZE_MODE_DATA}
value={String(resizeMode)}
onChange={handleResizeModeChange}
/>
);
}

View File

@ -3,6 +3,7 @@ import { RootState } from 'app/store/store';
import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas';
import { cloneDeep, forEach } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import { components } from 'services/api/schema';
import { isAnySessionRejected } from 'services/api/thunks/session';
import { appSocketInvocationError } from 'services/events/actions';
import { controlNetImageProcessed } from './actions';
@ -16,11 +17,13 @@ import {
RequiredControlNetProcessorNode,
} from './types';
export type ControlModes =
| 'balanced'
| 'more_prompt'
| 'more_control'
| 'unbalanced';
export type ControlModes = NonNullable<
components['schemas']['ControlNetInvocation']['control_mode']
>;
export type ResizeModes = NonNullable<
components['schemas']['ControlNetInvocation']['resize_mode']
>;
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true,
@ -29,6 +32,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
beginStepPct: 0,
endStepPct: 1,
controlMode: 'balanced',
resizeMode: 'just_resize',
controlImage: null,
processedControlImage: null,
processorType: 'canny_image_processor',
@ -45,6 +49,7 @@ export type ControlNetConfig = {
beginStepPct: number;
endStepPct: number;
controlMode: ControlModes;
resizeMode: ResizeModes;
controlImage: string | null;
processedControlImage: string | null;
processorType: ControlNetProcessorType;
@ -215,6 +220,16 @@ export const controlNetSlice = createSlice({
const { controlNetId, controlMode } = action.payload;
state.controlNets[controlNetId].controlMode = controlMode;
},
controlNetResizeModeChanged: (
state,
action: PayloadAction<{
controlNetId: string;
resizeMode: ResizeModes;
}>
) => {
const { controlNetId, resizeMode } = action.payload;
state.controlNets[controlNetId].resizeMode = resizeMode;
},
controlNetProcessorParamsChanged: (
state,
action: PayloadAction<{
@ -342,6 +357,7 @@ export const {
controlNetBeginStepPctChanged,
controlNetEndStepPctChanged,
controlNetControlModeChanged,
controlNetResizeModeChanged,
controlNetProcessorParamsChanged,
controlNetProcessorTypeChanged,
controlNetReset,

View File

@ -23,7 +23,6 @@ const BoardContextMenu = memo(
dispatch(boardIdSelected(board?.board_id ?? board_id));
}, [board?.board_id, board_id, dispatch]);
return (
<Box sx={{ touchAction: 'none', height: 'full' }}>
<ContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }}
menuButtonProps={{
@ -50,7 +49,6 @@ const BoardContextMenu = memo(
>
{children}
</ContextMenu>
</Box>
);
}
);

View File

@ -1,27 +1,21 @@
import {
Collapse,
Flex,
Grid,
GridItem,
useDisclosure,
} from '@chakra-ui/react';
import { ButtonGroup, Collapse, Flex, Grid, GridItem } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIIconButton from 'common/components/IAIIconButton';
import { AnimatePresence, motion } from 'framer-motion';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useState } from 'react';
import { memo, useCallback, useState } from 'react';
import { FaSearch } from 'react-icons/fa';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
import { BoardDTO } from 'services/api/types';
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
import DeleteBoardModal from '../DeleteBoardModal';
import AddBoardButton from './AddBoardButton';
import AllAssetsBoard from './AllAssetsBoard';
import AllImagesBoard from './AllImagesBoard';
import BatchBoard from './BatchBoard';
import BoardsSearch from './BoardsSearch';
import GalleryBoard from './GalleryBoard';
import NoBoardBoard from './NoBoardBoard';
import DeleteBoardModal from '../DeleteBoardModal';
import { BoardDTO } from 'services/api/types';
import SystemBoardButton from './SystemBoardButton';
const selector = createSelector(
[stateSelector],
@ -48,7 +42,10 @@ const BoardsList = (props: Props) => {
)
: boards;
const [boardToDelete, setBoardToDelete] = useState<BoardDTO>();
const [searchMode, setSearchMode] = useState(false);
const [isSearching, setIsSearching] = useState(false);
const handleClickSearchIcon = useCallback(() => {
setIsSearching((v) => !v);
}, []);
return (
<>
@ -64,7 +61,54 @@ const BoardsList = (props: Props) => {
}}
>
<Flex sx={{ gap: 2, alignItems: 'center' }}>
<BoardsSearch setSearchMode={setSearchMode} />
<AnimatePresence mode="popLayout">
{isSearching ? (
<motion.div
key="boards-search"
initial={{
opacity: 0,
}}
exit={{
opacity: 0,
}}
animate={{
opacity: 1,
transition: { duration: 0.1 },
}}
style={{ width: '100%' }}
>
<BoardsSearch setIsSearching={setIsSearching} />
</motion.div>
) : (
<motion.div
key="system-boards-select"
initial={{
opacity: 0,
}}
exit={{
opacity: 0,
}}
animate={{
opacity: 1,
transition: { duration: 0.1 },
}}
style={{ width: '100%' }}
>
<ButtonGroup sx={{ w: 'full', ps: 1.5 }} isAttached>
<SystemBoardButton board_id="images" />
<SystemBoardButton board_id="assets" />
<SystemBoardButton board_id="no_board" />
</ButtonGroup>
</motion.div>
)}
</AnimatePresence>
<IAIIconButton
aria-label="Search Boards"
size="sm"
isChecked={isSearching}
onClick={handleClickSearchIcon}
icon={<FaSearch />}
/>
<AddBoardButton />
</Flex>
<OverlayScrollbarsComponent
@ -82,29 +126,10 @@ const BoardsList = (props: Props) => {
<Grid
className="list-container"
sx={{
gridTemplateRows: '6.5rem 6.5rem',
gridAutoFlow: 'column dense',
gridAutoColumns: '5rem',
gridTemplateColumns: `repeat(auto-fill, minmax(96px, 1fr));`,
maxH: 346,
}}
>
{!searchMode && (
<>
<GridItem sx={{ p: 1.5 }}>
<AllImagesBoard isSelected={selectedBoardId === 'images'} />
</GridItem>
<GridItem sx={{ p: 1.5 }}>
<AllAssetsBoard isSelected={selectedBoardId === 'assets'} />
</GridItem>
<GridItem sx={{ p: 1.5 }}>
<NoBoardBoard isSelected={selectedBoardId === 'no_board'} />
</GridItem>
{isBatchEnabled && (
<GridItem sx={{ p: 1.5 }}>
<BatchBoard isSelected={selectedBoardId === 'batch'} />
</GridItem>
)}
</>
)}
{filteredBoards &&
filteredBoards.map((board) => (
<GridItem key={board.board_id} sx={{ p: 1.5 }}>

View File

@ -10,7 +10,14 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { setBoardSearchText } from 'features/gallery/store/boardSlice';
import { memo } from 'react';
import {
ChangeEvent,
KeyboardEvent,
memo,
useCallback,
useEffect,
useRef,
} from 'react';
const selector = createSelector(
[stateSelector],
@ -22,31 +29,60 @@ const selector = createSelector(
);
type Props = {
setSearchMode: (searchMode: boolean) => void;
setIsSearching: (isSearching: boolean) => void;
};
const BoardsSearch = (props: Props) => {
const { setSearchMode } = props;
const { setIsSearching } = props;
const dispatch = useAppDispatch();
const { searchText } = useAppSelector(selector);
const inputRef = useRef<HTMLInputElement>(null);
const handleBoardSearch = (searchTerm: string) => {
setSearchMode(searchTerm.length > 0);
const handleBoardSearch = useCallback(
(searchTerm: string) => {
dispatch(setBoardSearchText(searchTerm));
};
const clearBoardSearch = () => {
setSearchMode(false);
},
[dispatch]
);
const clearBoardSearch = useCallback(() => {
dispatch(setBoardSearchText(''));
};
setIsSearching(false);
}, [dispatch, setIsSearching]);
const handleKeydown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
// exit search mode on escape
if (e.key === 'Escape') {
clearBoardSearch();
}
},
[clearBoardSearch]
);
const handleChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
handleBoardSearch(e.target.value);
},
[handleBoardSearch]
);
useEffect(() => {
// focus the search box on mount
if (!inputRef.current) {
return;
}
inputRef.current.focus();
}, []);
return (
<InputGroup>
<Input
ref={inputRef}
placeholder="Search Boards..."
value={searchText}
onChange={(e) => {
handleBoardSearch(e.target.value);
}}
onKeyDown={handleKeydown}
onChange={handleChange}
/>
{searchText && searchText.length && (
<InputRightElement>
@ -55,7 +91,8 @@ const BoardsSearch = (props: Props) => {
size="xs"
variant="ghost"
aria-label="Clear Search"
icon={<CloseIcon boxSize={3} />}
opacity={0.5}
icon={<CloseIcon boxSize={2} />}
/>
</InputRightElement>
)}

View File

@ -6,9 +6,9 @@ import {
EditableInput,
EditablePreview,
Flex,
Icon,
Image,
Text,
useColorMode,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/dist/query';
@ -17,14 +17,12 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDroppable from 'common/components/IAIDroppable';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { boardIdSelected } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo } from 'react';
import { FaUser } from 'react-icons/fa';
import { memo, useCallback, useMemo, useState } from 'react';
import { FaFolder } from 'react-icons/fa';
import { useUpdateBoardMutation } from 'services/api/endpoints/boards';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { BoardDTO } from 'services/api/types';
import { mode } from 'theme/util/mode';
import BoardContextMenu from '../BoardContextMenu';
const AUTO_ADD_BADGE_STYLES: ChakraProps['sx'] = {
@ -66,8 +64,9 @@ const GalleryBoard = memo(
board.cover_image_name ?? skipToken
);
const { colorMode } = useColorMode();
const { board_name, board_id } = board;
const [localBoardName, setLocalBoardName] = useState(board_name);
const handleSelectBoard = useCallback(() => {
dispatch(boardIdSelected(board_id));
}, [board_id, dispatch]);
@ -75,10 +74,6 @@ const GalleryBoard = memo(
const [updateBoard, { isLoading: isUpdateBoardLoading }] =
useUpdateBoardMutation();
const handleUpdateBoardName = (newBoardName: string) => {
updateBoard({ board_id, changes: { board_name: newBoardName } });
};
const droppableData: MoveBoardDropData = useMemo(
() => ({
id: board_id,
@ -88,8 +83,49 @@ const GalleryBoard = memo(
[board_id]
);
const handleSubmit = useCallback(
(newBoardName: string) => {
if (!newBoardName) {
// empty strings are not allowed
setLocalBoardName(board_name);
return;
}
if (newBoardName === board_name) {
// don't updated the board name if it hasn't changed
return;
}
updateBoard({ board_id, changes: { board_name: newBoardName } })
.unwrap()
.then((response) => {
// update local state
setLocalBoardName(response.board_name);
})
.catch(() => {
// revert on error
setLocalBoardName(board_name);
});
},
[board_id, board_name, updateBoard]
);
const handleChange = useCallback((newBoardName: string) => {
setLocalBoardName(newBoardName);
}, []);
return (
<Box sx={{ touchAction: 'none', height: 'full' }}>
<Box
sx={{ w: 'full', h: 'full', touchAction: 'none', userSelect: 'none' }}
>
<Flex
sx={{
position: 'relative',
justifyContent: 'center',
alignItems: 'center',
aspectRatio: '1/1',
w: 'full',
h: 'full',
}}
>
<BoardContextMenu
board={board}
board_id={board_id}
@ -97,50 +133,66 @@ const GalleryBoard = memo(
>
{(ref) => (
<Flex
key={board_id}
userSelect="none"
ref={ref}
sx={{
flexDir: 'column',
justifyContent: 'space-between',
alignItems: 'center',
cursor: 'pointer',
w: 'full',
h: 'full',
}}
>
<Flex
onClick={handleSelectBoard}
sx={{
w: 'full',
h: 'full',
position: 'relative',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
w: 'full',
aspectRatio: '1/1',
overflow: 'hidden',
shadow: isSelected ? 'selected.light' : undefined,
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
flexShrink: 0,
cursor: 'pointer',
}}
>
{board.cover_image_name && coverImage?.thumbnail_url && (
<Image src={coverImage?.thumbnail_url} draggable={false} />
)}
{!(board.cover_image_name && coverImage?.thumbnail_url) && (
<IAINoContentFallback
boxSize={8}
icon={FaUser}
<Flex
sx={{
borderWidth: '2px',
borderStyle: 'solid',
borderColor: 'base.200',
w: 'full',
h: 'full',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
bg: 'base.200',
_dark: {
borderColor: 'base.800',
bg: 'base.800',
},
}}
>
{coverImage?.thumbnail_url ? (
<Image
src={coverImage?.thumbnail_url}
draggable={false}
sx={{
maxW: 'full',
maxH: 'full',
borderRadius: 'base',
borderBottomRadius: 'lg',
}}
/>
) : (
<Flex
sx={{
w: 'full',
h: 'full',
justifyContent: 'center',
alignItems: 'center',
}}
>
<Icon
boxSize={12}
as={FaFolder}
sx={{
mt: -3,
opacity: 0.7,
color: 'base.500',
_dark: {
color: 'base.500',
},
}}
/>
</Flex>
)}
</Flex>
<Flex
sx={{
position: 'absolute',
@ -160,37 +212,59 @@ const GalleryBoard = memo(
{board.image_count}
</Badge>
</Flex>
<IAIDroppable
data={droppableData}
dropLabel={<Text fontSize="md">Move</Text>}
<Box
className="selection-box"
sx={{
position: 'absolute',
top: 0,
insetInlineEnd: 0,
bottom: 0,
insetInlineStart: 0,
borderRadius: 'base',
transitionProperty: 'common',
transitionDuration: 'common',
shadow: isSelected ? 'selected.light' : undefined,
_dark: {
shadow: isSelected ? 'selected.dark' : undefined,
},
}}
/>
</Flex>
<Flex
sx={{
width: 'full',
height: 'full',
position: 'absolute',
bottom: 0,
left: 0,
p: 1,
justifyContent: 'center',
alignItems: 'center',
w: 'full',
maxW: 'full',
borderBottomRadius: 'base',
bg: isSelected ? 'accent.400' : 'base.500',
color: isSelected ? 'base.50' : 'base.100',
_dark: {
bg: isSelected ? 'accent.500' : 'base.600',
color: isSelected ? 'base.50' : 'base.100',
},
lineHeight: 'short',
fontSize: 'xs',
}}
>
<Editable
defaultValue={board_name}
value={localBoardName}
isDisabled={isUpdateBoardLoading}
submitOnBlur={true}
onSubmit={(nextValue) => {
handleUpdateBoardName(nextValue);
onChange={handleChange}
onSubmit={handleSubmit}
sx={{
w: 'full',
}}
sx={{ maxW: 'full' }}
>
<EditablePreview
sx={{
color: isSelected
? mode('base.900', 'base.50')(colorMode)
: mode('base.700', 'base.200')(colorMode),
fontWeight: isSelected ? 600 : undefined,
fontSize: 'xs',
textAlign: 'center',
p: 0,
fontWeight: isSelected ? 700 : 500,
textAlign: 'center',
overflow: 'hidden',
textOverflow: 'ellipsis',
}}
@ -198,18 +272,26 @@ const GalleryBoard = memo(
/>
<EditableInput
sx={{
color: mode('base.900', 'base.50')(colorMode),
fontSize: 'xs',
borderColor: mode('base.500', 'base.500')(colorMode),
p: 0,
outline: 0,
_focusVisible: {
p: 0,
textAlign: 'center',
// get rid of the edit border
boxShadow: 'none',
},
}}
/>
</Editable>
</Flex>
<IAIDroppable
data={droppableData}
dropLabel={<Text fontSize="md">Move</Text>}
/>
</Flex>
)}
</BoardContextMenu>
</Flex>
</Box>
);
}

View File

@ -17,7 +17,7 @@ type GenericBoardProps = {
badgeCount?: number;
};
const formatBadgeCount = (count: number) =>
export const formatBadgeCount = (count: number) =>
Intl.NumberFormat('en-US', {
notation: 'compact',
maximumFractionDigits: 1,
@ -92,7 +92,7 @@ const GenericBoard = (props: GenericBoardProps) => {
h: 'full',
alignItems: 'center',
fontWeight: isSelected ? 600 : undefined,
fontSize: 'xs',
fontSize: 'sm',
color: isSelected ? 'base.900' : 'base.700',
_dark: { color: isSelected ? 'base.50' : 'base.200' },
}}

View File

@ -0,0 +1,53 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import { boardIdSelected } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo } from 'react';
import { useBoardName } from 'services/api/hooks/useBoardName';
type Props = {
board_id: 'images' | 'assets' | 'no_board';
};
const SystemBoardButton = ({ board_id }: Props) => {
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
[stateSelector],
({ gallery }) => {
const { selectedBoardId } = gallery;
return { isSelected: selectedBoardId === board_id };
},
defaultSelectorOptions
),
[board_id]
);
const { isSelected } = useAppSelector(selector);
const boardName = useBoardName(board_id);
const handleClick = useCallback(() => {
dispatch(boardIdSelected(board_id));
}, [board_id, dispatch]);
return (
<IAIButton
onClick={handleClick}
size="sm"
isChecked={isSelected}
sx={{
flexGrow: 1,
borderRadius: 'base',
}}
>
{boardName}
</IAIButton>
);
};
export default memo(SystemBoardButton);

View File

@ -4,8 +4,9 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { memo } from 'react';
import { memo, useMemo } from 'react';
import { useBoardName } from 'services/api/hooks/useBoardName';
import { useBoardTotal } from 'services/api/hooks/useBoardTotal';
const selector = createSelector(
[stateSelector],
@ -26,6 +27,16 @@ const GalleryBoardName = (props: Props) => {
const { isOpen, onToggle } = props;
const { selectedBoardId } = useAppSelector(selector);
const boardName = useBoardName(selectedBoardId);
const numOfBoardImages = useBoardTotal(selectedBoardId);
const formattedBoardName = useMemo(() => {
if (!boardName) return '';
if (boardName && !numOfBoardImages) return boardName;
if (boardName.length > 20) {
return `${boardName.substring(0, 20)}... (${numOfBoardImages})`;
}
return `${boardName} (${numOfBoardImages})`;
}, [boardName, numOfBoardImages]);
return (
<Flex
@ -58,7 +69,7 @@ const GalleryBoardName = (props: Props) => {
},
}}
>
{boardName}
{formattedBoardName}
</Text>
</Box>
<Spacer />

View File

@ -109,7 +109,7 @@ const GalleryDrawer = () => {
isResizable={true}
isOpen={shouldShowGallery}
onClose={handleCloseGallery}
minWidth={337}
minWidth={400}
>
<ImageGalleryContent />
</ResizableDrawer>

View File

@ -1,4 +1,4 @@
import { Box } from '@chakra-ui/react';
import { Box, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
@ -86,15 +86,10 @@ const GalleryImage = (props: HoverableImageProps) => {
return (
<Box sx={{ w: 'full', h: 'full', touchAction: 'none' }}>
<ImageContextMenu imageDTO={imageDTO}>
{(ref) => (
<Box
position="relative"
key={imageName}
<Flex
userSelect="none"
ref={ref}
sx={{
display: 'flex',
position: 'relative',
justifyContent: 'center',
alignItems: 'center',
aspectRatio: '1/1',
@ -115,9 +110,7 @@ const GalleryImage = (props: HoverableImageProps) => {
// resetTooltip="Delete image"
// withResetIcon // removed bc it's too easy to accidentally delete images
/>
</Box>
)}
</ImageContextMenu>
</Flex>
</Box>
);
};

View File

@ -48,6 +48,7 @@ export const addControlNetToLinearGraph = (
beginStepPct,
endStepPct,
controlMode,
resizeMode,
model,
processorType,
weight,
@ -60,6 +61,7 @@ export const addControlNetToLinearGraph = (
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
control_mode: controlMode,
resize_mode: resizeMode,
control_model: model as ControlNetInvocation['control_model'],
control_weight: weight,
};

View File

@ -105,7 +105,7 @@ const enabledTabsSelector = createSelector(
}
);
const MIN_GALLERY_WIDTH = 300;
const MIN_GALLERY_WIDTH = 350;
const DEFAULT_GALLERY_PCT = 20;
export const NO_GALLERY_TABS: InvokeTabName[] = ['modelManager'];

View File

@ -3,7 +3,7 @@ import { memo } from 'react';
import { PanelResizeHandle } from 'react-resizable-panels';
import { mode } from 'theme/util/mode';
type ResizeHandleProps = FlexProps & {
type ResizeHandleProps = Omit<FlexProps, 'direction'> & {
direction?: 'horizontal' | 'vertical';
};

View File

@ -169,10 +169,11 @@ export const imagesApi = api.injectEndpoints({
],
async onQueryStarted(imageDTO, { dispatch, queryFulfilled }) {
/**
* Cache changes for deleteImage:
* - Remove from "All Images"
* - Remove from image's `board_id` if it has one, or "No Board" if not
* - Remove from "Batch"
* Cache changes for `deleteImage`:
* - *remove* from "All Images" / "All Assets"
* - IF it has a board:
* - THEN *remove* from it's own board
* - ELSE *remove* from "No Board"
*/
const { image_name, board_id, image_category } = imageDTO;
@ -181,22 +182,23 @@ export const imagesApi = api.injectEndpoints({
// That means constructing the possible query args that are serialized into the cache key...
const removeFromCacheKeys: ListImagesArgs[] = [];
// determine `categories`, i.e. do we update "All Images" or "All Assets"
const categories = IMAGE_CATEGORIES.includes(image_category)
? IMAGE_CATEGORIES
: ASSETS_CATEGORIES;
// All Images board (e.g. no board)
// remove from "All Images"
removeFromCacheKeys.push({ categories });
// Board specific
if (board_id) {
// remove from it's own board
removeFromCacheKeys.push({ board_id });
} else {
// TODO: No Board
// remove from "No Board"
removeFromCacheKeys.push({ board_id: 'none' });
}
// TODO: Batch
const patches: PatchCollection[] = [];
removeFromCacheKeys.forEach((cacheKey) => {
patches.push(
@ -240,32 +242,37 @@ export const imagesApi = api.injectEndpoints({
{ imageDTO: oldImageDTO, changes: _changes },
{ dispatch, queryFulfilled, getState }
) {
// TODO: Should we handle changes to boards via this mutation? Seems reasonable...
// let's be extra-sure we do not accidentally change categories
const changes = omit(_changes, 'image_category');
/**
* Cache changes for `updateImage`:
* - Update the ImageDTO
* - Update the image in "All Images" board:
* - IF it is in the date range represented by the cache:
* - add the image IF it is not already in the cache & update the total
* - ELSE update the image IF it is already in the cache
* Cache changes for "updateImage":
* - *update* "getImageDTO" cache
* - for "All Images" || "All Assets":
* - IF it is not already in the cache
* - THEN *add* it to "All Images" / "All Assets" and update the total
* - ELSE *update* it
* - IF the image has a board:
* - Update the image in it's own board
* - ELSE Update the image in the "No Board" board (TODO)
* - THEN *update* it's own board
* - ELSE *update* the "No Board" board
*/
const patches: PatchCollection[] = [];
const { image_name, board_id, image_category } = oldImageDTO;
const { image_name, board_id, image_category, is_intermediate } =
oldImageDTO;
const isChangingFromIntermediate = changes.is_intermediate === false;
// do not add intermediates to gallery cache
if (is_intermediate && !isChangingFromIntermediate) {
return;
}
// determine `categories`, i.e. do we update "All Images" or "All Assets"
const categories = IMAGE_CATEGORIES.includes(image_category)
? IMAGE_CATEGORIES
: ASSETS_CATEGORIES;
// TODO: No Board
// Update `getImageDTO` cache
// update `getImageDTO` cache
patches.push(
dispatch(
imagesApi.util.updateQueryData(
@ -281,9 +288,13 @@ export const imagesApi = api.injectEndpoints({
// Update the "All Image" or "All Assets" board
const queryArgsToUpdate: ListImagesArgs[] = [{ categories }];
// IF the image has a board:
if (board_id) {
// We also need to update the user board
// THEN update it's own board
queryArgsToUpdate.push({ board_id });
} else {
// ELSE update the "No Board" board
queryArgsToUpdate.push({ board_id: 'none' });
}
queryArgsToUpdate.forEach((queryArg) => {
@ -371,12 +382,12 @@ export const imagesApi = api.injectEndpoints({
return;
}
// Add the image to the "All Images" / "All Assets" board
const queryArg = {
categories: IMAGE_CATEGORIES.includes(image_category)
// determine `categories`, i.e. do we update "All Images" or "All Assets"
const categories = IMAGE_CATEGORIES.includes(image_category)
? IMAGE_CATEGORIES
: ASSETS_CATEGORIES,
};
: ASSETS_CATEGORIES;
const queryArg = { categories };
dispatch(
imagesApi.util.updateQueryData('listImages', queryArg, (draft) => {
@ -410,16 +421,14 @@ export const imagesApi = api.injectEndpoints({
{ dispatch, queryFulfilled, getState }
) {
/**
* Cache changes for addImageToBoard:
* - Remove from "No Board"
* - Remove from `old_board_id` if it has one
* - Add to new `board_id`
* Cache changes for `addImageToBoard`:
* - *update* the `getImageDTO` cache
* - *remove* from "No Board"
* - IF the image has an old `board_id`:
* - THEN *remove* from it's old `board_id`
* - IF the image's `created_at` is within the range of the board's cached images
* - OR the board cache has length of 0 or 1
* - Update the `total` for each board whose cache is updated
* - Update the ImageDTO
*
* TODO: maybe total should just be updated in the boards endpoints?
* - THEN *add* it to new `board_id`
*/
const { image_name, board_id: old_board_id } = oldImageDTO;
@ -427,13 +436,10 @@ export const imagesApi = api.injectEndpoints({
// Figure out the `listImages` caches that we need to update
const removeFromQueryArgs: ListImagesArgs[] = [];
// TODO: No Board
// TODO: Batch
// Remove from No Board
// remove from "No Board"
removeFromQueryArgs.push({ board_id: 'none' });
// Remove from old board
// remove from old board
if (old_board_id) {
removeFromQueryArgs.push({ board_id: old_board_id });
}
@ -534,17 +540,15 @@ export const imagesApi = api.injectEndpoints({
{ dispatch, queryFulfilled, getState }
) {
/**
* Cache changes for removeImageFromBoard:
* - Add to "No Board"
* Cache changes for `removeImageFromBoard`:
* - *update* `getImageDTO`
* - IF the image's `created_at` is within the range of the board's cached images
* - Remove from `old_board_id`
* - Update the ImageDTO
* - THEN *add* to "No Board"
* - *remove* from `old_board_id`
*/
const { image_name, board_id: old_board_id } = imageDTO;
// TODO: Batch
const patches: PatchCollection[] = [];
// Updated imageDTO with new board_id

View File

@ -6,9 +6,9 @@ export const useBoardName = (board_id: BoardId | null | undefined) => {
selectFromResult: ({ data }) => {
let boardName = '';
if (board_id === 'images') {
boardName = 'All Images';
boardName = 'Images';
} else if (board_id === 'assets') {
boardName = 'All Assets';
boardName = 'Assets';
} else if (board_id === 'no_board') {
boardName = 'No Board';
} else if (board_id === 'batch') {

View File

@ -0,0 +1,53 @@
import { skipToken } from '@reduxjs/toolkit/dist/query';
import {
ASSETS_CATEGORIES,
BoardId,
IMAGE_CATEGORIES,
INITIAL_IMAGE_LIMIT,
} from 'features/gallery/store/gallerySlice';
import { useMemo } from 'react';
import { ListImagesArgs, useListImagesQuery } from '../endpoints/images';
const baseQueryArg: ListImagesArgs = {
offset: 0,
limit: INITIAL_IMAGE_LIMIT,
is_intermediate: false,
};
const imagesQueryArg: ListImagesArgs = {
categories: IMAGE_CATEGORIES,
...baseQueryArg,
};
const assetsQueryArg: ListImagesArgs = {
categories: ASSETS_CATEGORIES,
...baseQueryArg,
};
const noBoardQueryArg: ListImagesArgs = {
board_id: 'none',
...baseQueryArg,
};
export const useBoardTotal = (board_id: BoardId | null | undefined) => {
const queryArg = useMemo(() => {
if (!board_id) {
return;
}
if (board_id === 'images') {
return imagesQueryArg;
} else if (board_id === 'assets') {
return assetsQueryArg;
} else if (board_id === 'no_board') {
return noBoardQueryArg;
} else {
return { board_id, ...baseQueryArg };
}
}, [board_id]);
const { total } = useListImagesQuery(queryArg ?? skipToken, {
selectFromResult: ({ currentData }) => ({ total: currentData?.total }),
});
return total;
};

View File

@ -167,7 +167,7 @@ export type paths = {
"/api/v1/images/clear-intermediates": {
/**
* Clear Intermediates
* @description Clears first 100 intermediates
* @description Clears all intermediates
*/
post: operations["clear_intermediates"];
};
@ -800,6 +800,13 @@ export type components = {
* @enum {string}
*/
control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced";
/**
* Resize Mode
* @description The resize mode to use
* @default just_resize
* @enum {string}
*/
resize_mode?: "just_resize" | "crop_resize" | "fill_resize" | "just_resize_simple";
};
/**
* ControlNetInvocation
@ -859,6 +866,13 @@ export type components = {
* @enum {string}
*/
control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced";
/**
* Resize Mode
* @description The resize mode used
* @default just_resize
* @enum {string}
*/
resize_mode?: "just_resize" | "crop_resize" | "fill_resize" | "just_resize_simple";
};
/** ControlNetModelConfig */
ControlNetModelConfig: {
@ -5324,11 +5338,11 @@ export type components = {
image?: components["schemas"]["ImageField"];
};
/**
* StableDiffusion2ModelFormat
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionXLModelFormat
* @description An enumeration.
@ -5336,11 +5350,11 @@ export type components = {
*/
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
};
responses: never;
parameters: never;
@ -6125,7 +6139,7 @@ export type operations = {
};
/**
* Clear Intermediates
* @description Clears first 100 intermediates
* @description Clears all intermediates
*/
clear_intermediates: {
responses: {

View File

@ -2,11 +2,16 @@ import { InvokeAIThemeColors } from 'theme/themeTypes';
import { generateColorPalette } from 'theme/util/generateColorPalette';
const BASE = { H: 220, S: 16 };
const ACCENT = { H: 250, S: 52 };
const WORKING = { H: 47, S: 50 };
const WARNING = { H: 28, S: 50 };
const OK = { H: 113, S: 50 };
const ERROR = { H: 0, S: 50 };
const ACCENT = { H: 250, S: 42 };
// const ACCENT = { H: 250, S: 52 };
const WORKING = { H: 47, S: 42 };
// const WORKING = { H: 47, S: 50 };
const WARNING = { H: 28, S: 42 };
// const WARNING = { H: 28, S: 50 };
const OK = { H: 113, S: 42 };
// const OK = { H: 113, S: 50 };
const ERROR = { H: 0, S: 42 };
// const ERROR = { H: 0, S: 50 };
export const InvokeAIColors: InvokeAIThemeColors = {
base: generateColorPalette(BASE.H, BASE.S),

View File

@ -0,0 +1,56 @@
import { editableAnatomy as parts } from '@chakra-ui/anatomy';
import {
createMultiStyleConfigHelpers,
defineStyle,
} from '@chakra-ui/styled-system';
import { mode } from '@chakra-ui/theme-tools';
const { definePartsStyle, defineMultiStyleConfig } =
createMultiStyleConfigHelpers(parts.keys);
const baseStylePreview = defineStyle({
borderRadius: 'md',
py: '1',
transitionProperty: 'common',
transitionDuration: 'normal',
});
const baseStyleInput = defineStyle((props) => ({
borderRadius: 'md',
py: '1',
transitionProperty: 'common',
transitionDuration: 'normal',
width: 'full',
_focusVisible: { boxShadow: 'outline' },
_placeholder: { opacity: 0.6 },
'::selection': {
color: mode('accent.900', 'accent.50')(props),
bg: mode('accent.200', 'accent.400')(props),
},
}));
const baseStyleTextarea = defineStyle({
borderRadius: 'md',
py: '1',
transitionProperty: 'common',
transitionDuration: 'normal',
width: 'full',
_focusVisible: { boxShadow: 'outline' },
_placeholder: { opacity: 0.6 },
});
const invokeAI = definePartsStyle((props) => ({
preview: baseStylePreview,
input: baseStyleInput(props),
textarea: baseStyleTextarea,
}));
export const editableTheme = defineMultiStyleConfig({
variants: {
invokeAI,
},
defaultProps: {
size: 'sm',
variant: 'invokeAI',
},
});

View File

@ -4,6 +4,7 @@ import { InvokeAIColors } from './colors/colors';
import { accordionTheme } from './components/accordion';
import { buttonTheme } from './components/button';
import { checkboxTheme } from './components/checkbox';
import { editableTheme } from './components/editable';
import { formLabelTheme } from './components/formLabel';
import { inputTheme } from './components/input';
import { menuTheme } from './components/menu';
@ -72,7 +73,17 @@ export const theme: ThemeOverride = {
selected: {
light:
'0px 0px 0px 1px var(--invokeai-colors-base-150), 0px 0px 0px 4px var(--invokeai-colors-accent-400)',
dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-400)',
dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-500)',
},
hoverSelected: {
light:
'0px 0px 0px 1px var(--invokeai-colors-base-150), 0px 0px 0px 4px var(--invokeai-colors-accent-500)',
dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-300)',
},
hoverUnselected: {
light:
'0px 0px 0px 1px var(--invokeai-colors-base-150), 0px 0px 0px 4px var(--invokeai-colors-accent-200)',
dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-600)',
},
nodeSelectedOutline: `0 0 0 2px var(--invokeai-colors-accent-450)`,
},
@ -80,6 +91,7 @@ export const theme: ThemeOverride = {
components: {
Button: buttonTheme, // Button and IconButton
Input: inputTheme,
Editable: editableTheme,
Textarea: textareaTheme,
Tabs: tabsTheme,
Progress: progressTheme,

View File

@ -37,4 +37,7 @@ export const getInputOutlineStyles = (props: StyleFunctionProps) => ({
_placeholder: {
color: mode('base.700', 'base.400')(props),
},
'::selection': {
bg: mode('accent.200', 'accent.400')(props),
},
});