mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge remote-tracking branch 'origin' into maryhipp/gallery-ui-updates
This commit is contained in:
commit
f42ef55b2f
@ -1,9 +1,22 @@
|
||||
from enum import Enum
|
||||
from fastapi import Body
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
from invokeai.version import __version__
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
from invokeai.backend.util.logging import logging
|
||||
|
||||
class LogLevel(int, Enum):
|
||||
NotSet = logging.NOTSET
|
||||
Debug = logging.DEBUG
|
||||
Info = logging.INFO
|
||||
Warning = logging.WARNING
|
||||
Error = logging.ERROR
|
||||
Critical = logging.CRITICAL
|
||||
|
||||
app_router = APIRouter(prefix="/v1/app", tags=["app"])
|
||||
|
||||
|
||||
@ -34,3 +47,27 @@ async def get_config() -> AppConfig:
|
||||
if PatchMatch.patchmatch_available():
|
||||
infill_methods.append('patchmatch')
|
||||
return AppConfig(infill_methods=infill_methods)
|
||||
|
||||
@app_router.get(
|
||||
"/logging",
|
||||
operation_id="get_log_level",
|
||||
responses={200: {"description" : "The operation was successful"}},
|
||||
response_model = LogLevel,
|
||||
)
|
||||
async def get_log_level(
|
||||
) -> LogLevel:
|
||||
"""Returns the log level"""
|
||||
return LogLevel(ApiDependencies.invoker.services.logger.level)
|
||||
|
||||
@app_router.post(
|
||||
"/logging",
|
||||
operation_id="set_log_level",
|
||||
responses={200: {"description" : "The operation was successful"}},
|
||||
response_model = LogLevel,
|
||||
)
|
||||
async def set_log_level(
|
||||
level: LogLevel = Body(description="New log verbosity level"),
|
||||
) -> LogLevel:
|
||||
"""Sets the log verbosity level"""
|
||||
ApiDependencies.invoker.services.logger.setLevel(level)
|
||||
return LogLevel(ApiDependencies.invoker.services.logger.level)
|
||||
|
@ -85,8 +85,8 @@ CONTROLNET_DEFAULT_MODELS = [
|
||||
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
|
||||
CONTROLNET_MODE_VALUES = Literal[tuple(
|
||||
["balanced", "more_prompt", "more_control", "unbalanced"])]
|
||||
# crop and fill options not ready yet
|
||||
# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
|
||||
CONTROLNET_RESIZE_VALUES = Literal[tuple(
|
||||
["just_resize", "crop_resize", "fill_resize", "just_resize_simple",])]
|
||||
|
||||
|
||||
class ControlNetModelField(BaseModel):
|
||||
@ -111,7 +111,8 @@ class ControlField(BaseModel):
|
||||
description="When the ControlNet is last applied (% of total steps)")
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(
|
||||
default="balanced", description="The control mode to use")
|
||||
# resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(
|
||||
default="just_resize", description="The resize mode to use")
|
||||
|
||||
@validator("control_weight")
|
||||
def validate_control_weight(cls, v):
|
||||
@ -161,6 +162,7 @@ class ControlNetInvocation(BaseInvocation):
|
||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||
description="When the ControlNet is last applied (% of total steps)")
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode used")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
@ -187,6 +189,7 @@ class ControlNetInvocation(BaseInvocation):
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
control_mode=self.control_mode,
|
||||
resize_mode=self.resize_mode,
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -30,6 +30,7 @@ from .compel import ConditioningField
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .image import ImageOutput
|
||||
from .model import ModelInfo, UNetField, VaeField
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
@ -288,7 +289,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
# and add in batch_size, num_images_per_prompt?
|
||||
# and do real check for classifier_free_guidance?
|
||||
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
|
||||
control_image = model.prepare_control_image(
|
||||
control_image = prepare_control_image(
|
||||
image=input_image,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
width=control_width_resize,
|
||||
@ -298,13 +299,18 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
device=control_model.device,
|
||||
dtype=control_model.dtype,
|
||||
control_mode=control_info.control_mode,
|
||||
resize_mode=control_info.resize_mode,
|
||||
)
|
||||
control_item = ControlNetData(
|
||||
model=control_model, image_tensor=control_image,
|
||||
model=control_model,
|
||||
image_tensor=control_image,
|
||||
weight=control_info.control_weight,
|
||||
begin_step_percent=control_info.begin_step_percent,
|
||||
end_step_percent=control_info.end_step_percent,
|
||||
control_mode=control_info.control_mode,
|
||||
# any resizing needed should currently be happening in prepare_control_image(),
|
||||
# but adding resize_mode to ControlNetData in case needed in the future
|
||||
resize_mode=control_info.resize_mode,
|
||||
)
|
||||
control_data.append(control_item)
|
||||
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||
@ -601,7 +607,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
antialias: bool = Field(
|
||||
default=False,
|
||||
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
@ -647,7 +653,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
antialias: bool = Field(
|
||||
default=False,
|
||||
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
@ -758,7 +764,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
dtype=vae.dtype
|
||||
) # FIXME: uses torch.randn. make reproducible!
|
||||
|
||||
latents = 0.18215 * latents
|
||||
latents = vae.config.scaling_factor * latents
|
||||
latents = latents.to(dtype=orig_dtype)
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
|
@ -6,6 +6,7 @@ from typing import List, Literal, Optional, Union
|
||||
from pydantic import Field, validator
|
||||
|
||||
from ...backend.model_management import ModelType, SubModelType
|
||||
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
|
||||
@ -243,10 +244,31 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
},
|
||||
}
|
||||
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
sample,
|
||||
step,
|
||||
total_steps,
|
||||
) -> None:
|
||||
stable_diffusion_xl_step_callback(
|
||||
context=context,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
sample=sample,
|
||||
step=step,
|
||||
total_steps=total_steps,
|
||||
)
|
||||
|
||||
# based on
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
latents = context.services.latents.get(self.noise.latents_name)
|
||||
|
||||
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
@ -341,6 +363,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
||||
#if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
else:
|
||||
@ -409,6 +432,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
||||
#if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
|
||||
@ -473,10 +497,31 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
},
|
||||
}
|
||||
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
sample,
|
||||
step,
|
||||
total_steps,
|
||||
) -> None:
|
||||
stable_diffusion_xl_step_callback(
|
||||
context=context,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
sample=sample,
|
||||
step=step,
|
||||
total_steps=total_steps,
|
||||
)
|
||||
|
||||
# based on
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
@ -579,6 +624,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
||||
#if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
else:
|
||||
@ -647,6 +693,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
||||
#if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
|
||||
|
342
invokeai/app/util/controlnet_utils.py
Normal file
342
invokeai/app/util/controlnet_utils.py
Normal file
@ -0,0 +1,342 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image
|
||||
from diffusers.utils import PIL_INTERPOLATION
|
||||
|
||||
from einops import rearrange
|
||||
from controlnet_aux.util import HWC3, resize_image
|
||||
|
||||
###################################################################
|
||||
# Copy of scripts/lvminthin.py from Mikubill/sd-webui-controlnet
|
||||
###################################################################
|
||||
# High Quality Edge Thinning using Pure Python
|
||||
# Written by Lvmin Zhangu
|
||||
# 2023 April
|
||||
# Stanford University
|
||||
# If you use this, please Cite "High Quality Edge Thinning using Pure Python", Lvmin Zhang, In Mikubill/sd-webui-controlnet.
|
||||
|
||||
lvmin_kernels_raw = [
|
||||
np.array([
|
||||
[-1, -1, -1],
|
||||
[0, 1, 0],
|
||||
[1, 1, 1]
|
||||
], dtype=np.int32),
|
||||
np.array([
|
||||
[0, -1, -1],
|
||||
[1, 1, -1],
|
||||
[0, 1, 0]
|
||||
], dtype=np.int32)
|
||||
]
|
||||
|
||||
lvmin_kernels = []
|
||||
lvmin_kernels += [np.rot90(x, k=0, axes=(0, 1)) for x in lvmin_kernels_raw]
|
||||
lvmin_kernels += [np.rot90(x, k=1, axes=(0, 1)) for x in lvmin_kernels_raw]
|
||||
lvmin_kernels += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_kernels_raw]
|
||||
lvmin_kernels += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_kernels_raw]
|
||||
|
||||
lvmin_prunings_raw = [
|
||||
np.array([
|
||||
[-1, -1, -1],
|
||||
[-1, 1, -1],
|
||||
[0, 0, -1]
|
||||
], dtype=np.int32),
|
||||
np.array([
|
||||
[-1, -1, -1],
|
||||
[-1, 1, -1],
|
||||
[-1, 0, 0]
|
||||
], dtype=np.int32)
|
||||
]
|
||||
|
||||
lvmin_prunings = []
|
||||
lvmin_prunings += [np.rot90(x, k=0, axes=(0, 1)) for x in lvmin_prunings_raw]
|
||||
lvmin_prunings += [np.rot90(x, k=1, axes=(0, 1)) for x in lvmin_prunings_raw]
|
||||
lvmin_prunings += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_prunings_raw]
|
||||
lvmin_prunings += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_prunings_raw]
|
||||
|
||||
|
||||
def remove_pattern(x, kernel):
|
||||
objects = cv2.morphologyEx(x, cv2.MORPH_HITMISS, kernel)
|
||||
objects = np.where(objects > 127)
|
||||
x[objects] = 0
|
||||
return x, objects[0].shape[0] > 0
|
||||
|
||||
|
||||
def thin_one_time(x, kernels):
|
||||
y = x
|
||||
is_done = True
|
||||
for k in kernels:
|
||||
y, has_update = remove_pattern(y, k)
|
||||
if has_update:
|
||||
is_done = False
|
||||
return y, is_done
|
||||
|
||||
|
||||
def lvmin_thin(x, prunings=True):
|
||||
y = x
|
||||
for i in range(32):
|
||||
y, is_done = thin_one_time(y, lvmin_kernels)
|
||||
if is_done:
|
||||
break
|
||||
if prunings:
|
||||
y, _ = thin_one_time(y, lvmin_prunings)
|
||||
return y
|
||||
|
||||
|
||||
def nake_nms(x):
|
||||
f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
|
||||
f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
|
||||
f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
|
||||
f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
|
||||
y = np.zeros_like(x)
|
||||
for f in [f1, f2, f3, f4]:
|
||||
np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
|
||||
return y
|
||||
|
||||
|
||||
################################################################################
|
||||
# copied from Mikubill/sd-webui-controlnet external_code.py and modified for InvokeAI
|
||||
################################################################################
|
||||
# FIXME: not using yet, if used in the future will most likely require modification of preprocessors
|
||||
def pixel_perfect_resolution(
|
||||
image: np.ndarray,
|
||||
target_H: int,
|
||||
target_W: int,
|
||||
resize_mode: str,
|
||||
) -> int:
|
||||
"""
|
||||
Calculate the estimated resolution for resizing an image while preserving aspect ratio.
|
||||
|
||||
The function first calculates scaling factors for height and width of the image based on the target
|
||||
height and width. Then, based on the chosen resize mode, it either takes the smaller or the larger
|
||||
scaling factor to estimate the new resolution.
|
||||
|
||||
If the resize mode is OUTER_FIT, the function uses the smaller scaling factor, ensuring the whole image
|
||||
fits within the target dimensions, potentially leaving some empty space.
|
||||
|
||||
If the resize mode is not OUTER_FIT, the function uses the larger scaling factor, ensuring the target
|
||||
dimensions are fully filled, potentially cropping the image.
|
||||
|
||||
After calculating the estimated resolution, the function prints some debugging information.
|
||||
|
||||
Args:
|
||||
image (np.ndarray): A 3D numpy array representing an image. The dimensions represent [height, width, channels].
|
||||
target_H (int): The target height for the image.
|
||||
target_W (int): The target width for the image.
|
||||
resize_mode (ResizeMode): The mode for resizing.
|
||||
|
||||
Returns:
|
||||
int: The estimated resolution after resizing.
|
||||
"""
|
||||
raw_H, raw_W, _ = image.shape
|
||||
|
||||
k0 = float(target_H) / float(raw_H)
|
||||
k1 = float(target_W) / float(raw_W)
|
||||
|
||||
if resize_mode == "fill_resize":
|
||||
estimation = min(k0, k1) * float(min(raw_H, raw_W))
|
||||
else: # "crop_resize" or "just_resize" (or possibly "just_resize_simple"?)
|
||||
estimation = max(k0, k1) * float(min(raw_H, raw_W))
|
||||
|
||||
# print(f"Pixel Perfect Computation:")
|
||||
# print(f"resize_mode = {resize_mode}")
|
||||
# print(f"raw_H = {raw_H}")
|
||||
# print(f"raw_W = {raw_W}")
|
||||
# print(f"target_H = {target_H}")
|
||||
# print(f"target_W = {target_W}")
|
||||
# print(f"estimation = {estimation}")
|
||||
|
||||
return int(np.round(estimation))
|
||||
|
||||
|
||||
###########################################################################
|
||||
# Copied from detectmap_proc method in scripts/detectmap_proc.py in Mikubill/sd-webui-controlnet
|
||||
# modified for InvokeAI
|
||||
###########################################################################
|
||||
# def detectmap_proc(detected_map, module, resize_mode, h, w):
|
||||
def np_img_resize(
|
||||
np_img: np.ndarray,
|
||||
resize_mode: str,
|
||||
h: int,
|
||||
w: int,
|
||||
device: torch.device = torch.device('cpu')
|
||||
):
|
||||
# if 'inpaint' in module:
|
||||
# np_img = np_img.astype(np.float32)
|
||||
# else:
|
||||
# np_img = HWC3(np_img)
|
||||
np_img = HWC3(np_img)
|
||||
|
||||
def safe_numpy(x):
|
||||
# A very safe method to make sure that Apple/Mac works
|
||||
y = x
|
||||
|
||||
# below is very boring but do not change these. If you change these Apple or Mac may fail.
|
||||
y = y.copy()
|
||||
y = np.ascontiguousarray(y)
|
||||
y = y.copy()
|
||||
return y
|
||||
|
||||
def get_pytorch_control(x):
|
||||
# A very safe method to make sure that Apple/Mac works
|
||||
y = x
|
||||
|
||||
# below is very boring but do not change these. If you change these Apple or Mac may fail.
|
||||
y = torch.from_numpy(y)
|
||||
y = y.float() / 255.0
|
||||
y = rearrange(y, 'h w c -> 1 c h w')
|
||||
y = y.clone()
|
||||
# y = y.to(devices.get_device_for("controlnet"))
|
||||
y = y.to(device)
|
||||
y = y.clone()
|
||||
return y
|
||||
|
||||
def high_quality_resize(x: np.ndarray,
|
||||
size):
|
||||
# Written by lvmin
|
||||
# Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges
|
||||
inpaint_mask = None
|
||||
if x.ndim == 3 and x.shape[2] == 4:
|
||||
inpaint_mask = x[:, :, 3]
|
||||
x = x[:, :, 0:3]
|
||||
|
||||
new_size_is_smaller = (size[0] * size[1]) < (x.shape[0] * x.shape[1])
|
||||
new_size_is_bigger = (size[0] * size[1]) > (x.shape[0] * x.shape[1])
|
||||
unique_color_count = np.unique(x.reshape(-1, x.shape[2]), axis=0).shape[0]
|
||||
is_one_pixel_edge = False
|
||||
is_binary = False
|
||||
if unique_color_count == 2:
|
||||
is_binary = np.min(x) < 16 and np.max(x) > 240
|
||||
if is_binary:
|
||||
xc = x
|
||||
xc = cv2.erode(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
|
||||
xc = cv2.dilate(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
|
||||
one_pixel_edge_count = np.where(xc < x)[0].shape[0]
|
||||
all_edge_count = np.where(x > 127)[0].shape[0]
|
||||
is_one_pixel_edge = one_pixel_edge_count * 2 > all_edge_count
|
||||
|
||||
if 2 < unique_color_count < 200:
|
||||
interpolation = cv2.INTER_NEAREST
|
||||
elif new_size_is_smaller:
|
||||
interpolation = cv2.INTER_AREA
|
||||
else:
|
||||
interpolation = cv2.INTER_CUBIC # Must be CUBIC because we now use nms. NEVER CHANGE THIS
|
||||
|
||||
y = cv2.resize(x, size, interpolation=interpolation)
|
||||
if inpaint_mask is not None:
|
||||
inpaint_mask = cv2.resize(inpaint_mask, size, interpolation=interpolation)
|
||||
|
||||
if is_binary:
|
||||
y = np.mean(y.astype(np.float32), axis=2).clip(0, 255).astype(np.uint8)
|
||||
if is_one_pixel_edge:
|
||||
y = nake_nms(y)
|
||||
_, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
||||
y = lvmin_thin(y, prunings=new_size_is_bigger)
|
||||
else:
|
||||
_, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
||||
y = np.stack([y] * 3, axis=2)
|
||||
|
||||
if inpaint_mask is not None:
|
||||
inpaint_mask = (inpaint_mask > 127).astype(np.float32) * 255.0
|
||||
inpaint_mask = inpaint_mask[:, :, None].clip(0, 255).astype(np.uint8)
|
||||
y = np.concatenate([y, inpaint_mask], axis=2)
|
||||
|
||||
return y
|
||||
|
||||
# if resize_mode == external_code.ResizeMode.RESIZE:
|
||||
if resize_mode == "just_resize": # RESIZE
|
||||
np_img = high_quality_resize(np_img, (w, h))
|
||||
np_img = safe_numpy(np_img)
|
||||
return get_pytorch_control(np_img), np_img
|
||||
|
||||
old_h, old_w, _ = np_img.shape
|
||||
old_w = float(old_w)
|
||||
old_h = float(old_h)
|
||||
k0 = float(h) / old_h
|
||||
k1 = float(w) / old_w
|
||||
|
||||
safeint = lambda x: int(np.round(x))
|
||||
|
||||
# if resize_mode == external_code.ResizeMode.OUTER_FIT:
|
||||
if resize_mode == "fill_resize": # OUTER_FIT
|
||||
k = min(k0, k1)
|
||||
borders = np.concatenate([np_img[0, :, :], np_img[-1, :, :], np_img[:, 0, :], np_img[:, -1, :]], axis=0)
|
||||
high_quality_border_color = np.median(borders, axis=0).astype(np_img.dtype)
|
||||
if len(high_quality_border_color) == 4:
|
||||
# Inpaint hijack
|
||||
high_quality_border_color[3] = 255
|
||||
high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1])
|
||||
np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
|
||||
new_h, new_w, _ = np_img.shape
|
||||
pad_h = max(0, (h - new_h) // 2)
|
||||
pad_w = max(0, (w - new_w) // 2)
|
||||
high_quality_background[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = np_img
|
||||
np_img = high_quality_background
|
||||
np_img = safe_numpy(np_img)
|
||||
return get_pytorch_control(np_img), np_img
|
||||
else: # resize_mode == "crop_resize" (INNER_FIT)
|
||||
k = max(k0, k1)
|
||||
np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
|
||||
new_h, new_w, _ = np_img.shape
|
||||
pad_h = max(0, (new_h - h) // 2)
|
||||
pad_w = max(0, (new_w - w) // 2)
|
||||
np_img = np_img[pad_h:pad_h + h, pad_w:pad_w + w]
|
||||
np_img = safe_numpy(np_img)
|
||||
return get_pytorch_control(np_img), np_img
|
||||
|
||||
def prepare_control_image(
|
||||
# image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]]
|
||||
# but now should be able to assume that image is a single PIL.Image, which simplifies things
|
||||
image: Image,
|
||||
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
|
||||
# latents_to_match_resolution, # TorchTensor of shape (batch_size, 3, height, width)
|
||||
width=512, # should be 8 * latent.shape[3]
|
||||
height=512, # should be 8 * latent height[2]
|
||||
# batch_size=1, # currently no batching
|
||||
# num_images_per_prompt=1, # currently only single image
|
||||
device="cuda",
|
||||
dtype=torch.float16,
|
||||
do_classifier_free_guidance=True,
|
||||
control_mode="balanced",
|
||||
resize_mode="just_resize_simple",
|
||||
):
|
||||
# FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out
|
||||
if (resize_mode == "just_resize_simple" or
|
||||
resize_mode == "crop_resize_simple" or
|
||||
resize_mode == "fill_resize_simple"):
|
||||
image = image.convert("RGB")
|
||||
if (resize_mode == "just_resize_simple"):
|
||||
image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
||||
elif (resize_mode == "crop_resize_simple"): # not yet implemented
|
||||
pass
|
||||
elif (resize_mode == "fill_resize_simple"): # not yet implemented
|
||||
pass
|
||||
nimage = np.array(image)
|
||||
nimage = nimage[None, :]
|
||||
nimage = np.concatenate([nimage], axis=0)
|
||||
# normalizing RGB values to [0,1] range (in PIL.Image they are [0-255])
|
||||
nimage = np.array(nimage).astype(np.float32) / 255.0
|
||||
nimage = nimage.transpose(0, 3, 1, 2)
|
||||
timage = torch.from_numpy(nimage)
|
||||
|
||||
# use fancy lvmin controlnet resizing
|
||||
elif (resize_mode == "just_resize" or resize_mode == "crop_resize" or resize_mode == "fill_resize"):
|
||||
nimage = np.array(image)
|
||||
timage, nimage = np_img_resize(
|
||||
np_img=nimage,
|
||||
resize_mode=resize_mode,
|
||||
h=height,
|
||||
w=width,
|
||||
# device=torch.device('cpu')
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
pass
|
||||
print("ERROR: invalid resize_mode ==> ", resize_mode)
|
||||
exit(1)
|
||||
|
||||
timage = timage.to(device=device, dtype=dtype)
|
||||
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
|
||||
if do_classifier_free_guidance and not cfg_injection:
|
||||
timage = torch.cat([timage] * 2)
|
||||
return timage
|
@ -1,9 +1,30 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
from ...backend.generator.base import Generator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
|
||||
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix = None):
|
||||
latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors
|
||||
|
||||
if smooth_matrix is not None:
|
||||
latent_image = latent_image.unsqueeze(0).permute(3, 0, 1, 2)
|
||||
latent_image = torch.nn.functional.conv2d(latent_image, smooth_matrix.reshape((1,1,3,3)), padding=1)
|
||||
latent_image = latent_image.permute(1, 2, 3, 0).squeeze(0)
|
||||
|
||||
latents_ubyte = (
|
||||
((latent_image + 1) / 2)
|
||||
.clamp(0, 1) # change scale from -1..1 to 0..1
|
||||
.mul(0xFF) # to 0..255
|
||||
.byte()
|
||||
).cpu()
|
||||
|
||||
return Image.fromarray(latents_ubyte.numpy())
|
||||
|
||||
|
||||
def stable_diffusion_step_callback(
|
||||
@ -37,7 +58,24 @@ def stable_diffusion_step_callback(
|
||||
# step = intermediate_state.step
|
||||
|
||||
# TODO: only output a preview image when requested
|
||||
image = Generator.sample_to_lowres_estimated_image(sample)
|
||||
|
||||
# origingally adapted from code by @erucipe and @keturn here:
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||
|
||||
# these updated numbers for v1.5 are from @torridgristle
|
||||
v1_5_latent_rgb_factors = torch.tensor(
|
||||
[
|
||||
# R G B
|
||||
[0.3444, 0.1385, 0.0670], # L1
|
||||
[0.1247, 0.4027, 0.1494], # L2
|
||||
[-0.3192, 0.2513, 0.2103], # L3
|
||||
[-0.1307, -0.1874, -0.7445], # L4
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
|
||||
|
||||
(width, height) = image.size
|
||||
width *= 8
|
||||
@ -53,3 +91,56 @@ def stable_diffusion_step_callback(
|
||||
step=intermediate_state.step,
|
||||
total_steps=node["steps"],
|
||||
)
|
||||
|
||||
def stable_diffusion_xl_step_callback(
|
||||
context: InvocationContext,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
sample,
|
||||
step,
|
||||
total_steps,
|
||||
):
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
raise CanceledException
|
||||
|
||||
sdxl_latent_rgb_factors = torch.tensor(
|
||||
[
|
||||
# R G B
|
||||
[ 0.3816, 0.4930, 0.5320],
|
||||
[-0.3753, 0.1631, 0.1739],
|
||||
[ 0.1770, 0.3588, -0.2048],
|
||||
[-0.4350, -0.2644, -0.4289],
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
sdxl_smooth_matrix = torch.tensor(
|
||||
[
|
||||
#[ 0.0478, 0.1285, 0.0478],
|
||||
#[ 0.1285, 0.2948, 0.1285],
|
||||
#[ 0.0478, 0.1285, 0.0478],
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
[0.0964, 0.4711, 0.0964],
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
|
||||
|
||||
(width, height) = image.size
|
||||
width *= 8
|
||||
height *= 8
|
||||
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
context.services.events.emit_generator_progress(
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
|
||||
step=step,
|
||||
total_steps=total_steps,
|
||||
)
|
@ -219,6 +219,7 @@ class ControlNetData:
|
||||
begin_step_percent: float = Field(default=0.0)
|
||||
end_step_percent: float = Field(default=1.0)
|
||||
control_mode: str = Field(default="balanced")
|
||||
resize_mode: str = Field(default="just_resize")
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -653,7 +654,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if cfg_injection:
|
||||
# Inferred ControlNet only for the conditional batch.
|
||||
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
||||
# add 0 to the unconditional batch to keep it unchanged.
|
||||
# prepend zeros for unconditional batch
|
||||
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
|
||||
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
|
||||
|
||||
@ -954,53 +955,3 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
debug_image(
|
||||
img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
|
||||
)
|
||||
|
||||
# Copied from diffusers pipeline_stable_diffusion_controlnet.py
|
||||
# Returns torch.Tensor of shape (batch_size, 3, height, width)
|
||||
@staticmethod
|
||||
def prepare_control_image(
|
||||
image,
|
||||
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
|
||||
# latents,
|
||||
width=512, # should be 8 * latent.shape[3]
|
||||
height=512, # should be 8 * latent height[2]
|
||||
batch_size=1,
|
||||
num_images_per_prompt=1,
|
||||
device="cuda",
|
||||
dtype=torch.float16,
|
||||
do_classifier_free_guidance=True,
|
||||
control_mode="balanced"
|
||||
):
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image[0], PIL.Image.Image):
|
||||
images = []
|
||||
for image_ in image:
|
||||
image_ = image_.convert("RGB")
|
||||
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
||||
image_ = np.array(image_)
|
||||
image_ = image_[None, :]
|
||||
images.append(image_)
|
||||
image = images
|
||||
image = np.concatenate(image, axis=0)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
elif isinstance(image[0], torch.Tensor):
|
||||
image = torch.cat(image, dim=0)
|
||||
|
||||
image_batch_size = image.shape[0]
|
||||
if image_batch_size == 1:
|
||||
repeat_by = batch_size
|
||||
else:
|
||||
# image batch size is the same as prompt batch size
|
||||
repeat_by = num_images_per_prompt
|
||||
image = image.repeat_interleave(repeat_by, dim=0)
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
|
||||
if do_classifier_free_guidance and not cfg_injection:
|
||||
image = torch.cat([image] * 2)
|
||||
return image
|
||||
|
@ -24,6 +24,7 @@ import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
|
||||
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
|
||||
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
|
||||
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
|
||||
import ParamControlNetResizeMode from './parameters/ParamControlNetResizeMode';
|
||||
|
||||
type ControlNetProps = {
|
||||
controlNetId: string;
|
||||
@ -68,7 +69,7 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
<Flex
|
||||
sx={{
|
||||
flexDir: 'column',
|
||||
gap: 2,
|
||||
gap: 3,
|
||||
p: 3,
|
||||
borderRadius: 'base',
|
||||
position: 'relative',
|
||||
@ -117,7 +118,12 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
tooltip={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
|
||||
aria-label={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
|
||||
onClick={toggleIsExpanded}
|
||||
variant="link"
|
||||
variant="ghost"
|
||||
sx={{
|
||||
_hover: {
|
||||
bg: 'none',
|
||||
},
|
||||
}}
|
||||
icon={
|
||||
<ChevronUpIcon
|
||||
sx={{
|
||||
@ -151,7 +157,7 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
<Flex sx={{ w: 'full', flexDirection: 'column' }}>
|
||||
<Flex sx={{ w: 'full', flexDirection: 'column', gap: 3 }}>
|
||||
<Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}>
|
||||
<Flex
|
||||
sx={{
|
||||
@ -176,16 +182,16 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
h: 28,
|
||||
w: 28,
|
||||
aspectRatio: '1/1',
|
||||
mt: 3,
|
||||
}}
|
||||
>
|
||||
<ControlNetImagePreview controlNetId={controlNetId} height={28} />
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
<Box mt={2}>
|
||||
<Flex sx={{ gap: 2 }}>
|
||||
<ParamControlNetControlMode controlNetId={controlNetId} />
|
||||
</Box>
|
||||
<ParamControlNetResizeMode controlNetId={controlNetId} />
|
||||
</Flex>
|
||||
<ParamControlNetProcessorSelect controlNetId={controlNetId} />
|
||||
</Flex>
|
||||
|
||||
|
@ -0,0 +1,62 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import {
|
||||
ResizeModes,
|
||||
controlNetResizeModeChanged,
|
||||
} from 'features/controlNet/store/controlNetSlice';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type ParamControlNetResizeModeProps = {
|
||||
controlNetId: string;
|
||||
};
|
||||
|
||||
const RESIZE_MODE_DATA = [
|
||||
{ label: 'Resize', value: 'just_resize' },
|
||||
{ label: 'Crop', value: 'crop_resize' },
|
||||
{ label: 'Fill', value: 'fill_resize' },
|
||||
];
|
||||
|
||||
export default function ParamControlNetResizeMode(
|
||||
props: ParamControlNetResizeModeProps
|
||||
) {
|
||||
const { controlNetId } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ controlNet }) => {
|
||||
const { resizeMode, isEnabled } =
|
||||
controlNet.controlNets[controlNetId];
|
||||
return { resizeMode, isEnabled };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[controlNetId]
|
||||
);
|
||||
|
||||
const { resizeMode, isEnabled } = useAppSelector(selector);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleResizeModeChange = useCallback(
|
||||
(resizeMode: ResizeModes) => {
|
||||
dispatch(controlNetResizeModeChanged({ controlNetId, resizeMode }));
|
||||
},
|
||||
[controlNetId, dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
disabled={!isEnabled}
|
||||
label="Resize Mode"
|
||||
data={RESIZE_MODE_DATA}
|
||||
value={String(resizeMode)}
|
||||
onChange={handleResizeModeChange}
|
||||
/>
|
||||
);
|
||||
}
|
@ -3,6 +3,7 @@ import { RootState } from 'app/store/store';
|
||||
import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas';
|
||||
import { cloneDeep, forEach } from 'lodash-es';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { components } from 'services/api/schema';
|
||||
import { isAnySessionRejected } from 'services/api/thunks/session';
|
||||
import { appSocketInvocationError } from 'services/events/actions';
|
||||
import { controlNetImageProcessed } from './actions';
|
||||
@ -16,11 +17,13 @@ import {
|
||||
RequiredControlNetProcessorNode,
|
||||
} from './types';
|
||||
|
||||
export type ControlModes =
|
||||
| 'balanced'
|
||||
| 'more_prompt'
|
||||
| 'more_control'
|
||||
| 'unbalanced';
|
||||
export type ControlModes = NonNullable<
|
||||
components['schemas']['ControlNetInvocation']['control_mode']
|
||||
>;
|
||||
|
||||
export type ResizeModes = NonNullable<
|
||||
components['schemas']['ControlNetInvocation']['resize_mode']
|
||||
>;
|
||||
|
||||
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
||||
isEnabled: true,
|
||||
@ -29,6 +32,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
||||
beginStepPct: 0,
|
||||
endStepPct: 1,
|
||||
controlMode: 'balanced',
|
||||
resizeMode: 'just_resize',
|
||||
controlImage: null,
|
||||
processedControlImage: null,
|
||||
processorType: 'canny_image_processor',
|
||||
@ -45,6 +49,7 @@ export type ControlNetConfig = {
|
||||
beginStepPct: number;
|
||||
endStepPct: number;
|
||||
controlMode: ControlModes;
|
||||
resizeMode: ResizeModes;
|
||||
controlImage: string | null;
|
||||
processedControlImage: string | null;
|
||||
processorType: ControlNetProcessorType;
|
||||
@ -215,6 +220,16 @@ export const controlNetSlice = createSlice({
|
||||
const { controlNetId, controlMode } = action.payload;
|
||||
state.controlNets[controlNetId].controlMode = controlMode;
|
||||
},
|
||||
controlNetResizeModeChanged: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
controlNetId: string;
|
||||
resizeMode: ResizeModes;
|
||||
}>
|
||||
) => {
|
||||
const { controlNetId, resizeMode } = action.payload;
|
||||
state.controlNets[controlNetId].resizeMode = resizeMode;
|
||||
},
|
||||
controlNetProcessorParamsChanged: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
@ -342,6 +357,7 @@ export const {
|
||||
controlNetBeginStepPctChanged,
|
||||
controlNetEndStepPctChanged,
|
||||
controlNetControlModeChanged,
|
||||
controlNetResizeModeChanged,
|
||||
controlNetProcessorParamsChanged,
|
||||
controlNetProcessorTypeChanged,
|
||||
controlNetReset,
|
||||
|
@ -30,9 +30,8 @@ const GalleryBoardName = (props: Props) => {
|
||||
const numOfBoardImages = useBoardTotal(selectedBoardId);
|
||||
|
||||
const formattedBoardName = useMemo(() => {
|
||||
if (!boardName || !numOfBoardImages) {
|
||||
return '';
|
||||
}
|
||||
if (!boardName) return '';
|
||||
if (boardName && !numOfBoardImages) return boardName;
|
||||
if (boardName.length > 20) {
|
||||
return `${boardName.substring(0, 20)}... (${numOfBoardImages})`;
|
||||
}
|
||||
|
@ -48,6 +48,7 @@ export const addControlNetToLinearGraph = (
|
||||
beginStepPct,
|
||||
endStepPct,
|
||||
controlMode,
|
||||
resizeMode,
|
||||
model,
|
||||
processorType,
|
||||
weight,
|
||||
@ -60,6 +61,7 @@ export const addControlNetToLinearGraph = (
|
||||
begin_step_percent: beginStepPct,
|
||||
end_step_percent: endStepPct,
|
||||
control_mode: controlMode,
|
||||
resize_mode: resizeMode,
|
||||
control_model: model as ControlNetInvocation['control_model'],
|
||||
control_weight: weight,
|
||||
};
|
||||
|
@ -169,10 +169,11 @@ export const imagesApi = api.injectEndpoints({
|
||||
],
|
||||
async onQueryStarted(imageDTO, { dispatch, queryFulfilled }) {
|
||||
/**
|
||||
* Cache changes for deleteImage:
|
||||
* - Remove from "All Images"
|
||||
* - Remove from image's `board_id` if it has one, or "No Board" if not
|
||||
* - Remove from "Batch"
|
||||
* Cache changes for `deleteImage`:
|
||||
* - *remove* from "All Images" / "All Assets"
|
||||
* - IF it has a board:
|
||||
* - THEN *remove* from it's own board
|
||||
* - ELSE *remove* from "No Board"
|
||||
*/
|
||||
|
||||
const { image_name, board_id, image_category } = imageDTO;
|
||||
@ -181,22 +182,23 @@ export const imagesApi = api.injectEndpoints({
|
||||
// That means constructing the possible query args that are serialized into the cache key...
|
||||
|
||||
const removeFromCacheKeys: ListImagesArgs[] = [];
|
||||
|
||||
// determine `categories`, i.e. do we update "All Images" or "All Assets"
|
||||
const categories = IMAGE_CATEGORIES.includes(image_category)
|
||||
? IMAGE_CATEGORIES
|
||||
: ASSETS_CATEGORIES;
|
||||
|
||||
// All Images board (e.g. no board)
|
||||
// remove from "All Images"
|
||||
removeFromCacheKeys.push({ categories });
|
||||
|
||||
// Board specific
|
||||
if (board_id) {
|
||||
// remove from it's own board
|
||||
removeFromCacheKeys.push({ board_id });
|
||||
} else {
|
||||
// TODO: No Board
|
||||
// remove from "No Board"
|
||||
removeFromCacheKeys.push({ board_id: 'none' });
|
||||
}
|
||||
|
||||
// TODO: Batch
|
||||
|
||||
const patches: PatchCollection[] = [];
|
||||
removeFromCacheKeys.forEach((cacheKey) => {
|
||||
patches.push(
|
||||
@ -240,32 +242,37 @@ export const imagesApi = api.injectEndpoints({
|
||||
{ imageDTO: oldImageDTO, changes: _changes },
|
||||
{ dispatch, queryFulfilled, getState }
|
||||
) {
|
||||
// TODO: Should we handle changes to boards via this mutation? Seems reasonable...
|
||||
|
||||
// let's be extra-sure we do not accidentally change categories
|
||||
const changes = omit(_changes, 'image_category');
|
||||
|
||||
/**
|
||||
* Cache changes for `updateImage`:
|
||||
* - Update the ImageDTO
|
||||
* - Update the image in "All Images" board:
|
||||
* - IF it is in the date range represented by the cache:
|
||||
* - add the image IF it is not already in the cache & update the total
|
||||
* - ELSE update the image IF it is already in the cache
|
||||
* Cache changes for "updateImage":
|
||||
* - *update* "getImageDTO" cache
|
||||
* - for "All Images" || "All Assets":
|
||||
* - IF it is not already in the cache
|
||||
* - THEN *add* it to "All Images" / "All Assets" and update the total
|
||||
* - ELSE *update* it
|
||||
* - IF the image has a board:
|
||||
* - Update the image in it's own board
|
||||
* - ELSE Update the image in the "No Board" board (TODO)
|
||||
* - THEN *update* it's own board
|
||||
* - ELSE *update* the "No Board" board
|
||||
*/
|
||||
|
||||
const patches: PatchCollection[] = [];
|
||||
const { image_name, board_id, image_category } = oldImageDTO;
|
||||
const { image_name, board_id, image_category, is_intermediate } =
|
||||
oldImageDTO;
|
||||
|
||||
const isChangingFromIntermediate = changes.is_intermediate === false;
|
||||
// do not add intermediates to gallery cache
|
||||
if (is_intermediate && !isChangingFromIntermediate) {
|
||||
return;
|
||||
}
|
||||
|
||||
// determine `categories`, i.e. do we update "All Images" or "All Assets"
|
||||
const categories = IMAGE_CATEGORIES.includes(image_category)
|
||||
? IMAGE_CATEGORIES
|
||||
: ASSETS_CATEGORIES;
|
||||
|
||||
// TODO: No Board
|
||||
|
||||
// Update `getImageDTO` cache
|
||||
// update `getImageDTO` cache
|
||||
patches.push(
|
||||
dispatch(
|
||||
imagesApi.util.updateQueryData(
|
||||
@ -281,9 +288,13 @@ export const imagesApi = api.injectEndpoints({
|
||||
// Update the "All Image" or "All Assets" board
|
||||
const queryArgsToUpdate: ListImagesArgs[] = [{ categories }];
|
||||
|
||||
// IF the image has a board:
|
||||
if (board_id) {
|
||||
// We also need to update the user board
|
||||
// THEN update it's own board
|
||||
queryArgsToUpdate.push({ board_id });
|
||||
} else {
|
||||
// ELSE update the "No Board" board
|
||||
queryArgsToUpdate.push({ board_id: 'none' });
|
||||
}
|
||||
|
||||
queryArgsToUpdate.forEach((queryArg) => {
|
||||
@ -371,12 +382,12 @@ export const imagesApi = api.injectEndpoints({
|
||||
return;
|
||||
}
|
||||
|
||||
// Add the image to the "All Images" / "All Assets" board
|
||||
const queryArg = {
|
||||
categories: IMAGE_CATEGORIES.includes(image_category)
|
||||
? IMAGE_CATEGORIES
|
||||
: ASSETS_CATEGORIES,
|
||||
};
|
||||
// determine `categories`, i.e. do we update "All Images" or "All Assets"
|
||||
const categories = IMAGE_CATEGORIES.includes(image_category)
|
||||
? IMAGE_CATEGORIES
|
||||
: ASSETS_CATEGORIES;
|
||||
|
||||
const queryArg = { categories };
|
||||
|
||||
dispatch(
|
||||
imagesApi.util.updateQueryData('listImages', queryArg, (draft) => {
|
||||
@ -410,16 +421,14 @@ export const imagesApi = api.injectEndpoints({
|
||||
{ dispatch, queryFulfilled, getState }
|
||||
) {
|
||||
/**
|
||||
* Cache changes for addImageToBoard:
|
||||
* - Remove from "No Board"
|
||||
* - Remove from `old_board_id` if it has one
|
||||
* - Add to new `board_id`
|
||||
* - IF the image's `created_at` is within the range of the board's cached images
|
||||
* Cache changes for `addImageToBoard`:
|
||||
* - *update* the `getImageDTO` cache
|
||||
* - *remove* from "No Board"
|
||||
* - IF the image has an old `board_id`:
|
||||
* - THEN *remove* from it's old `board_id`
|
||||
* - IF the image's `created_at` is within the range of the board's cached images
|
||||
* - OR the board cache has length of 0 or 1
|
||||
* - Update the `total` for each board whose cache is updated
|
||||
* - Update the ImageDTO
|
||||
*
|
||||
* TODO: maybe total should just be updated in the boards endpoints?
|
||||
* - THEN *add* it to new `board_id`
|
||||
*/
|
||||
|
||||
const { image_name, board_id: old_board_id } = oldImageDTO;
|
||||
@ -427,13 +436,10 @@ export const imagesApi = api.injectEndpoints({
|
||||
// Figure out the `listImages` caches that we need to update
|
||||
const removeFromQueryArgs: ListImagesArgs[] = [];
|
||||
|
||||
// TODO: No Board
|
||||
// TODO: Batch
|
||||
|
||||
// Remove from No Board
|
||||
// remove from "No Board"
|
||||
removeFromQueryArgs.push({ board_id: 'none' });
|
||||
|
||||
// Remove from old board
|
||||
// remove from old board
|
||||
if (old_board_id) {
|
||||
removeFromQueryArgs.push({ board_id: old_board_id });
|
||||
}
|
||||
@ -534,17 +540,15 @@ export const imagesApi = api.injectEndpoints({
|
||||
{ dispatch, queryFulfilled, getState }
|
||||
) {
|
||||
/**
|
||||
* Cache changes for removeImageFromBoard:
|
||||
* - Add to "No Board"
|
||||
* - IF the image's `created_at` is within the range of the board's cached images
|
||||
* - Remove from `old_board_id`
|
||||
* - Update the ImageDTO
|
||||
* Cache changes for `removeImageFromBoard`:
|
||||
* - *update* `getImageDTO`
|
||||
* - IF the image's `created_at` is within the range of the board's cached images
|
||||
* - THEN *add* to "No Board"
|
||||
* - *remove* from `old_board_id`
|
||||
*/
|
||||
|
||||
const { image_name, board_id: old_board_id } = imageDTO;
|
||||
|
||||
// TODO: Batch
|
||||
|
||||
const patches: PatchCollection[] = [];
|
||||
|
||||
// Updated imageDTO with new board_id
|
||||
|
@ -167,7 +167,7 @@ export type paths = {
|
||||
"/api/v1/images/clear-intermediates": {
|
||||
/**
|
||||
* Clear Intermediates
|
||||
* @description Clears first 100 intermediates
|
||||
* @description Clears all intermediates
|
||||
*/
|
||||
post: operations["clear_intermediates"];
|
||||
};
|
||||
@ -800,6 +800,13 @@ export type components = {
|
||||
* @enum {string}
|
||||
*/
|
||||
control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced";
|
||||
/**
|
||||
* Resize Mode
|
||||
* @description The resize mode to use
|
||||
* @default just_resize
|
||||
* @enum {string}
|
||||
*/
|
||||
resize_mode?: "just_resize" | "crop_resize" | "fill_resize" | "just_resize_simple";
|
||||
};
|
||||
/**
|
||||
* ControlNetInvocation
|
||||
@ -859,6 +866,13 @@ export type components = {
|
||||
* @enum {string}
|
||||
*/
|
||||
control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced";
|
||||
/**
|
||||
* Resize Mode
|
||||
* @description The resize mode used
|
||||
* @default just_resize
|
||||
* @enum {string}
|
||||
*/
|
||||
resize_mode?: "just_resize" | "crop_resize" | "fill_resize" | "just_resize_simple";
|
||||
};
|
||||
/** ControlNetModelConfig */
|
||||
ControlNetModelConfig: {
|
||||
@ -5324,11 +5338,11 @@ export type components = {
|
||||
image?: components["schemas"]["ImageField"];
|
||||
};
|
||||
/**
|
||||
* StableDiffusion2ModelFormat
|
||||
* StableDiffusion1ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusionXLModelFormat
|
||||
* @description An enumeration.
|
||||
@ -5336,11 +5350,11 @@ export type components = {
|
||||
*/
|
||||
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusion1ModelFormat
|
||||
* StableDiffusion2ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||
};
|
||||
responses: never;
|
||||
parameters: never;
|
||||
@ -6125,7 +6139,7 @@ export type operations = {
|
||||
};
|
||||
/**
|
||||
* Clear Intermediates
|
||||
* @description Clears first 100 intermediates
|
||||
* @description Clears all intermediates
|
||||
*/
|
||||
clear_intermediates: {
|
||||
responses: {
|
||||
|
Loading…
Reference in New Issue
Block a user