diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py
index 43cad3dcaf..911fede8fb 100644
--- a/invokeai/app/invocations/controlnet_image_processors.py
+++ b/invokeai/app/invocations/controlnet_image_processors.py
@@ -85,8 +85,8 @@ CONTROLNET_DEFAULT_MODELS = [
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
CONTROLNET_MODE_VALUES = Literal[tuple(
["balanced", "more_prompt", "more_control", "unbalanced"])]
-# crop and fill options not ready yet
-# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
+CONTROLNET_RESIZE_VALUES = Literal[tuple(
+ ["just_resize", "crop_resize", "fill_resize", "just_resize_simple",])]
class ControlNetModelField(BaseModel):
@@ -111,7 +111,8 @@ class ControlField(BaseModel):
description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field(
default="balanced", description="The control mode to use")
- # resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
+ resize_mode: CONTROLNET_RESIZE_VALUES = Field(
+ default="just_resize", description="The resize mode to use")
@validator("control_weight")
def validate_control_weight(cls, v):
@@ -161,6 +162,7 @@ class ControlNetInvocation(BaseInvocation):
end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
+ resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode used")
# fmt: on
class Config(InvocationConfig):
@@ -187,6 +189,7 @@ class ControlNetInvocation(BaseInvocation):
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
control_mode=self.control_mode,
+ resize_mode=self.resize_mode,
),
)
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index cd15fe156b..b4c3454c88 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -30,6 +30,7 @@ from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField
+from invokeai.app.util.controlnet_utils import prepare_control_image
from diffusers.models.attention_processor import (
AttnProcessor2_0,
@@ -288,7 +289,7 @@ class TextToLatentsInvocation(BaseInvocation):
# and add in batch_size, num_images_per_prompt?
# and do real check for classifier_free_guidance?
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
- control_image = model.prepare_control_image(
+ control_image = prepare_control_image(
image=input_image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=control_width_resize,
@@ -298,13 +299,18 @@ class TextToLatentsInvocation(BaseInvocation):
device=control_model.device,
dtype=control_model.dtype,
control_mode=control_info.control_mode,
+ resize_mode=control_info.resize_mode,
)
control_item = ControlNetData(
- model=control_model, image_tensor=control_image,
+ model=control_model,
+ image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
+ # any resizing needed should currently be happening in prepare_control_image(),
+ # but adding resize_mode to ControlNetData in case needed in the future
+ resize_mode=control_info.resize_mode,
)
control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
@@ -601,7 +607,7 @@ class ResizeLatentsInvocation(BaseInvocation):
antialias: bool = Field(
default=False,
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
-
+
class Config(InvocationConfig):
schema_extra = {
"ui": {
@@ -647,7 +653,7 @@ class ScaleLatentsInvocation(BaseInvocation):
antialias: bool = Field(
default=False,
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
-
+
class Config(InvocationConfig):
schema_extra = {
"ui": {
diff --git a/invokeai/app/util/controlnet_utils.py b/invokeai/app/util/controlnet_utils.py
new file mode 100644
index 0000000000..342fa147c5
--- /dev/null
+++ b/invokeai/app/util/controlnet_utils.py
@@ -0,0 +1,342 @@
+import torch
+import numpy as np
+import cv2
+from PIL import Image
+from diffusers.utils import PIL_INTERPOLATION
+
+from einops import rearrange
+from controlnet_aux.util import HWC3, resize_image
+
+###################################################################
+# Copy of scripts/lvminthin.py from Mikubill/sd-webui-controlnet
+###################################################################
+# High Quality Edge Thinning using Pure Python
+# Written by Lvmin Zhang
+# 2023 April
+# Stanford University
+# If you use this, please Cite "High Quality Edge Thinning using Pure Python", Lvmin Zhang, In Mikubill/sd-webui-controlnet.
+
+# Base 3x3 structuring elements for the cv2.MORPH_HITMISS thinning passes; in a
+# hit-or-miss kernel 1 = must be foreground, -1 = must be background, 0 = ignored.
+lvmin_kernels_raw = [
+    np.array([[-1, -1, -1],
+              [0, 1, 0],
+              [1, 1, 1]], dtype=np.int32),
+    np.array([[0, -1, -1],
+              [1, 1, -1],
+              [0, 1, 0]], dtype=np.int32),
+]
+
+# All eight thinning kernels: both base elements at each of the four quarter-turn
+# rotations, grouped by rotation (the k=0 pair first, then k=1, k=2, k=3).
+lvmin_kernels = [
+    np.rot90(raw_kernel, k=quarter_turns, axes=(0, 1))
+    for quarter_turns in range(4)
+    for raw_kernel in lvmin_kernels_raw
+]
+
+# Base 3x3 hit-or-miss elements for the pruning pass: each matches a foreground pixel
+# whose neighbours are all background except the two cells marked 0 (ignored).
+lvmin_prunings_raw = [
+    np.array([[-1, -1, -1],
+              [-1, 1, -1],
+              [0, 0, -1]], dtype=np.int32),
+    np.array([[-1, -1, -1],
+              [-1, 1, -1],
+              [-1, 0, 0]], dtype=np.int32),
+]
+
+# All eight pruning kernels: both base elements at each of the four quarter-turn
+# rotations, grouped by rotation (the k=0 pair first, then k=1, k=2, k=3).
+lvmin_prunings = [
+    np.rot90(raw_pruning, k=quarter_turns, axes=(0, 1))
+    for quarter_turns in range(4)
+    for raw_pruning in lvmin_prunings_raw
+]
+
+
+def remove_pattern(x, kernel):
+    """Zero, in place, every pixel of `x` where the hit-or-miss `kernel` matches; return `(x, matched_any)`."""
+    matches = np.where(cv2.morphologyEx(x, cv2.MORPH_HITMISS, kernel) > 127)
+    x[matches] = 0
+    return x, matches[0].shape[0] > 0
+
+
+def thin_one_time(x, kernels):
+    """Apply `remove_pattern` with each kernel in order; return `(image, is_done)`."""
+    removed_any = False
+    for kernel in kernels:
+        x, changed = remove_pattern(x, kernel)
+        # `is_done` (second return value) is True only if no kernel removed a pixel.
+        removed_any = removed_any or changed
+    return x, not removed_any
+
+
+def lvmin_thin(x, prunings=True):
+    """Thin `x` with up to 32 passes of `lvmin_kernels`, then optionally one pass of `lvmin_prunings`."""
+    for _ in range(32):  # hard cap on passes; stop early once a pass removes nothing
+        x, converged = thin_one_time(x, lvmin_kernels)
+        if converged:
+            break
+    if not prunings:
+        return x
+    return thin_one_time(x, lvmin_prunings)[0]
+
+
+def nake_nms(x):
+    """Keep each pixel of `x` that equals the dilation of `x` along any of four 3-pixel lines; zero the rest."""
+    horizontal = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
+    vertical = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
+    diagonal = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
+    anti_diagonal = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
+    lines = (horizontal, vertical, diagonal, anti_diagonal)
+    is_line_max = np.logical_or.reduce([cv2.dilate(x, kernel=line) == x for line in lines])
+    return np.where(is_line_max, x, np.zeros_like(x))
+
+
+################################################################################
+# copied from Mikubill/sd-webui-controlnet external_code.py and modified for InvokeAI
+################################################################################
+# FIXME: not using yet, if used in the future will most likely require modification of preprocessors
+def pixel_perfect_resolution(
+    image: np.ndarray,
+    target_H: int,
+    target_W: int,
+    resize_mode: str,
+) -> int:
+    """
+    Estimate the resolution for resizing an image while preserving aspect ratio.
+
+    Scaling factors for height and width are computed from the target and raw sizes,
+    one of them is chosen according to `resize_mode`, and the result is that factor
+    applied to the image's shorter side, rounded to the nearest integer.
+
+    If the resize mode is "fill_resize" (OUTER_FIT upstream), the smaller factor is
+    used, so the whole image fits within the target dimensions, potentially leaving
+    some empty space.
+
+    For any other resize mode ("crop_resize", "just_resize", "just_resize_simple"),
+    the larger factor is used, so the target dimensions are fully filled, potentially
+    cropping the image.
+
+    Args:
+        image (np.ndarray): Image whose first two axes are height and width; both
+            [height, width, channels] and 2-D [height, width] arrays are accepted.
+        target_H (int): The target height for the image.
+        target_W (int): The target width for the image.
+        resize_mode (str): The ControlNet resize mode string.
+
+    Returns:
+        int: The estimated resolution after resizing.
+
+    Raises:
+        ValueError: If `image` has fewer than two dimensions.
+        ZeroDivisionError: If the image height or width is zero.
+    """
+    if image.ndim < 2:
+        raise ValueError(f"image must have at least 2 dimensions, got shape {image.shape}")
+
+    # Only the first two axes are read, so 2-D grayscale input works too; unpacking
+    # exactly three values here used to raise ValueError for (H, W) arrays.
+    raw_H, raw_W = image.shape[:2]
+
+    k0 = float(target_H) / float(raw_H)
+    k1 = float(target_W) / float(raw_W)
+
+    # "fill_resize" must fit inside the target; every other mode must cover it.
+    pick_scale = min if resize_mode == "fill_resize" else max
+    estimation = pick_scale(k0, k1) * float(min(raw_H, raw_W))
+    return int(np.round(estimation))
+
+
+###########################################################################
+# Copied from detectmap_proc method in scripts/detectmap_proc.py in Mikubill/sd-webui-controlnet
+# modified for InvokeAI
+###########################################################################
+# def detectmap_proc(detected_map, module, resize_mode, h, w):
+def np_img_resize(
+ np_img: np.ndarray,
+ resize_mode: str,
+ h: int,
+ w: int,
+ device: torch.device = torch.device('cpu')
+):
+ """Resize control image `np_img` to height `h` and width `w` according to `resize_mode`.
+ "just_resize" stretches to exactly (w, h); "fill_resize" scales to fit and pads around the image;
+ every other value scales to cover and centre-crops. Upstream bypassed HWC3 for 'inpaint' modules.
+ Returns `(tensor, array)`: the result as a float `1 c h w` tensor divided by 255 on `device`, and as numpy."""
+ np_img = HWC3(np_img)
+
+ def safe_numpy(x):
+ # A very safe method to make sure that Apple/Mac works
+ y = x
+
+ # below is very boring but do not change these. If you change these Apple or Mac may fail.
+ y = y.copy()
+ y = np.ascontiguousarray(y)
+ y = y.copy()
+ return y
+
+ def get_pytorch_control(x):
+ # A very safe method to make sure that Apple/Mac works
+ y = x
+
+ # below is very boring but do not change these. If you change these Apple or Mac may fail.
+ y = torch.from_numpy(y)
+ y = y.float() / 255.0
+ y = rearrange(y, 'h w c -> 1 c h w')
+ y = y.clone()
+ # upstream picked the device via devices.get_device_for("controlnet"); here the caller supplies `device`
+ y = y.to(device)
+ y = y.clone()
+ return y
+
+ def high_quality_resize(x: np.ndarray,
+ size):
+ # Written by lvmin
+ # Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges
+ inpaint_mask = None
+ if x.ndim == 3 and x.shape[2] == 4:
+ inpaint_mask = x[:, :, 3]
+ x = x[:, :, 0:3]
+
+ new_size_is_smaller = (size[0] * size[1]) < (x.shape[0] * x.shape[1])
+ new_size_is_bigger = (size[0] * size[1]) > (x.shape[0] * x.shape[1])
+ unique_color_count = np.unique(x.reshape(-1, x.shape[2]), axis=0).shape[0]
+ is_one_pixel_edge = False
+ is_binary = False
+ if unique_color_count == 2:
+ is_binary = np.min(x) < 16 and np.max(x) > 240
+ if is_binary:
+ xc = x
+ xc = cv2.erode(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
+ xc = cv2.dilate(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
+ one_pixel_edge_count = np.where(xc < x)[0].shape[0]
+ all_edge_count = np.where(x > 127)[0].shape[0]
+ is_one_pixel_edge = one_pixel_edge_count * 2 > all_edge_count
+
+ if 2 < unique_color_count < 200:
+ interpolation = cv2.INTER_NEAREST
+ elif new_size_is_smaller:
+ interpolation = cv2.INTER_AREA
+ else:
+ interpolation = cv2.INTER_CUBIC # Must be CUBIC because we now use nms. NEVER CHANGE THIS
+
+ y = cv2.resize(x, size, interpolation=interpolation)
+ if inpaint_mask is not None:
+ inpaint_mask = cv2.resize(inpaint_mask, size, interpolation=interpolation)
+
+ if is_binary:
+ y = np.mean(y.astype(np.float32), axis=2).clip(0, 255).astype(np.uint8)
+ if is_one_pixel_edge:
+ y = nake_nms(y)
+ _, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+ y = lvmin_thin(y, prunings=new_size_is_bigger)
+ else:
+ _, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+ y = np.stack([y] * 3, axis=2)
+
+ if inpaint_mask is not None:
+ inpaint_mask = (inpaint_mask > 127).astype(np.float32) * 255.0
+ inpaint_mask = inpaint_mask[:, :, None].clip(0, 255).astype(np.uint8)
+ y = np.concatenate([y, inpaint_mask], axis=2)
+
+ return y
+
+ # Plain stretch to exactly (w, h); aspect ratio is not preserved (upstream: external_code.ResizeMode.RESIZE)
+ if resize_mode == "just_resize": # RESIZE
+ np_img = high_quality_resize(np_img, (w, h))
+ np_img = safe_numpy(np_img)
+ return get_pytorch_control(np_img), np_img
+
+ old_h, old_w, _ = np_img.shape
+ old_w = float(old_w)
+ old_h = float(old_h)
+ k0 = float(h) / old_h
+ k1 = float(w) / old_w
+
+ safeint = lambda x: int(np.round(x))
+
+ # Scale by the smaller factor so the whole image fits, then paste it centred on an (h, w) canvas filled with the median border colour (upstream: external_code.ResizeMode.OUTER_FIT)
+ if resize_mode == "fill_resize": # OUTER_FIT
+ k = min(k0, k1)
+ borders = np.concatenate([np_img[0, :, :], np_img[-1, :, :], np_img[:, 0, :], np_img[:, -1, :]], axis=0)
+ high_quality_border_color = np.median(borders, axis=0).astype(np_img.dtype)
+ if len(high_quality_border_color) == 4:
+ # Inpaint hijack: a 4-channel border colour gets its last channel (index 3) forced to 255
+ high_quality_border_color[3] = 255
+ high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1])
+ np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
+ new_h, new_w, _ = np_img.shape
+ pad_h = max(0, (h - new_h) // 2)
+ pad_w = max(0, (w - new_w) // 2)
+ high_quality_background[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = np_img
+ np_img = high_quality_background
+ np_img = safe_numpy(np_img)
+ return get_pytorch_control(np_img), np_img
+ else: # "crop_resize" (INNER_FIT); NOTE(review): any other resize_mode value also lands here
+ k = max(k0, k1)
+ np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
+ new_h, new_w, _ = np_img.shape
+ pad_h = max(0, (new_h - h) // 2)
+ pad_w = max(0, (new_w - w) // 2)
+ np_img = np_img[pad_h:pad_h + h, pad_w:pad_w + w]
+ np_img = safe_numpy(np_img)
+ return get_pytorch_control(np_img), np_img
+
+def prepare_control_image(
+    # image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]]
+    # but now should be able to assume that image is a single PIL.Image, which simplifies things
+    image: "Image.Image",
+    # FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
+    width=512,  # should be 8 * latent.shape[3]
+    height=512,  # should be 8 * latent.shape[2]
+    device="cuda",
+    dtype=torch.float16,
+    do_classifier_free_guidance=True,
+    control_mode="balanced",
+    resize_mode="just_resize_simple",
+):
+    """Convert a single PIL control image into the tensor consumed by a ControlNet.
+
+    Args:
+        image: The control image (a single PIL image).
+        width: Target width in pixels.
+        height: Target height in pixels.
+        device: Device the returned tensor is moved to.
+        dtype: dtype the returned tensor is cast to.
+        do_classifier_free_guidance: If true, the tensor is duplicated along the batch
+            axis unless `control_mode` is "more_control" or "unbalanced".
+        control_mode: See `do_classifier_free_guidance`.
+        resize_mode: "just_resize_simple" converts to RGB and does a plain Lanczos resize;
+            "just_resize", "crop_resize" and "fill_resize" delegate to `np_img_resize`.
+
+    Returns:
+        The image as a tensor on `device` with dtype `dtype` and a batch size of 1 or 2.
+
+    Raises:
+        NotImplementedError: For "crop_resize_simple" and "fill_resize_simple"; these used
+            to be accepted but silently returned an image that had not been resized.
+        ValueError: For any other unknown `resize_mode`; this used to call `exit(1)`,
+            which terminated the whole host process.
+    """
+    if resize_mode in ("crop_resize_simple", "fill_resize_simple"):
+        raise NotImplementedError(f"resize_mode '{resize_mode}' is not yet implemented")
+
+    if resize_mode == "just_resize_simple":
+        image = image.convert("RGB")
+        image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
+        # normalize [0-255] RGB values to the [0,1] range, then (H, W, 3) -> (1, 3, H, W)
+        nimage = np.array(image).astype(np.float32) / 255.0
+        timage = torch.from_numpy(nimage[None, :].transpose(0, 3, 1, 2))
+    elif resize_mode in ("just_resize", "crop_resize", "fill_resize"):
+        # use fancy lvmin controlnet resizing
+        timage, _ = np_img_resize(np_img=np.array(image), resize_mode=resize_mode, h=height, w=width, device=device)
+    else:
+        raise ValueError(f"invalid resize_mode: {resize_mode!r}")
+
+    timage = timage.to(device=device, dtype=dtype)
+    cfg_injection = control_mode in ("more_control", "unbalanced")
+    if do_classifier_free_guidance and not cfg_injection:
+        timage = torch.cat([timage] * 2)
+    return timage
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 228fbd0585..8acfb100a6 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -219,6 +219,7 @@ class ControlNetData:
begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0)
control_mode: str = Field(default="balanced")
+ resize_mode: str = Field(default="just_resize")
@dataclass
@@ -653,7 +654,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if cfg_injection:
# Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
- # add 0 to the unconditional batch to keep it unchanged.
+ # prepend zeros for unconditional batch
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
@@ -954,53 +955,3 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
debug_image(
img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
)
-
- # Copied from diffusers pipeline_stable_diffusion_controlnet.py
- # Returns torch.Tensor of shape (batch_size, 3, height, width)
- @staticmethod
- def prepare_control_image(
- image,
- # FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
- # latents,
- width=512, # should be 8 * latent.shape[3]
- height=512, # should be 8 * latent height[2]
- batch_size=1,
- num_images_per_prompt=1,
- device="cuda",
- dtype=torch.float16,
- do_classifier_free_guidance=True,
- control_mode="balanced"
- ):
-
- if not isinstance(image, torch.Tensor):
- if isinstance(image, PIL.Image.Image):
- image = [image]
-
- if isinstance(image[0], PIL.Image.Image):
- images = []
- for image_ in image:
- image_ = image_.convert("RGB")
- image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
- image_ = np.array(image_)
- image_ = image_[None, :]
- images.append(image_)
- image = images
- image = np.concatenate(image, axis=0)
- image = np.array(image).astype(np.float32) / 255.0
- image = image.transpose(0, 3, 1, 2)
- image = torch.from_numpy(image)
- elif isinstance(image[0], torch.Tensor):
- image = torch.cat(image, dim=0)
-
- image_batch_size = image.shape[0]
- if image_batch_size == 1:
- repeat_by = batch_size
- else:
- # image batch size is the same as prompt batch size
- repeat_by = num_images_per_prompt
- image = image.repeat_interleave(repeat_by, dim=0)
- image = image.to(device=device, dtype=dtype)
- cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
- if do_classifier_free_guidance and not cfg_injection:
- image = torch.cat([image] * 2)
- return image
diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx
index 368b9f727c..d858e46fdb 100644
--- a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx
+++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx
@@ -24,6 +24,7 @@ import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
+import ParamControlNetResizeMode from './parameters/ParamControlNetResizeMode';
type ControlNetProps = {
controlNetId: string;
@@ -68,7 +69,7 @@ const ControlNet = (props: ControlNetProps) => {
{
tooltip={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
aria-label={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
onClick={toggleIsExpanded}
- variant="link"
+ variant="ghost"
+ sx={{
+ _hover: {
+ bg: 'none',
+ },
+ }}
icon={
{
/>
)}
-
+
{
h: 28,
w: 28,
aspectRatio: '1/1',
- mt: 3,
}}
>
)}
-
+
-
+
+
diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetResizeMode.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetResizeMode.tsx
new file mode 100644
index 0000000000..4b31ebfc64
--- /dev/null
+++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetResizeMode.tsx
@@ -0,0 +1,62 @@
+import { createSelector } from '@reduxjs/toolkit';
+import { stateSelector } from 'app/store/store';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
+import IAIMantineSelect from 'common/components/IAIMantineSelect';
+import {
+ ResizeModes,
+ controlNetResizeModeChanged,
+} from 'features/controlNet/store/controlNetSlice';
+import { useCallback, useMemo } from 'react';
+import { useTranslation } from 'react-i18next';
+
+type ParamControlNetResizeModeProps = {
+ controlNetId: string;
+};
+
+const RESIZE_MODE_DATA = [
+ { label: 'Resize', value: 'just_resize' },
+ { label: 'Crop', value: 'crop_resize' },
+ { label: 'Fill', value: 'fill_resize' },
+];
+
+export default function ParamControlNetResizeMode(
+ props: ParamControlNetResizeModeProps
+) {
+ const { controlNetId } = props;
+ const dispatch = useAppDispatch();
+ const selector = useMemo(
+ () =>
+ createSelector(
+ stateSelector,
+ ({ controlNet }) => {
+ const { resizeMode, isEnabled } =
+ controlNet.controlNets[controlNetId];
+ return { resizeMode, isEnabled };
+ },
+ defaultSelectorOptions
+ ),
+ [controlNetId]
+ );
+
+ const { resizeMode, isEnabled } = useAppSelector(selector);
+
+ const { t } = useTranslation();
+
+ const handleResizeModeChange = useCallback(
+ (resizeMode: ResizeModes) => {
+ dispatch(controlNetResizeModeChanged({ controlNetId, resizeMode }));
+ },
+ [controlNetId, dispatch]
+ );
+
+ return (
+
+ );
+}
diff --git a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts
index 663edfd65f..2f8668115a 100644
--- a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts
+++ b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts
@@ -3,6 +3,7 @@ import { RootState } from 'app/store/store';
import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas';
import { cloneDeep, forEach } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
+import { components } from 'services/api/schema';
import { isAnySessionRejected } from 'services/api/thunks/session';
import { appSocketInvocationError } from 'services/events/actions';
import { controlNetImageProcessed } from './actions';
@@ -16,11 +17,13 @@ import {
RequiredControlNetProcessorNode,
} from './types';
-export type ControlModes =
- | 'balanced'
- | 'more_prompt'
- | 'more_control'
- | 'unbalanced';
+export type ControlModes = NonNullable<
+ components['schemas']['ControlNetInvocation']['control_mode']
+>;
+
+export type ResizeModes = NonNullable<
+ components['schemas']['ControlNetInvocation']['resize_mode']
+>;
export const initialControlNet: Omit = {
isEnabled: true,
@@ -29,6 +32,7 @@ export const initialControlNet: Omit = {
beginStepPct: 0,
endStepPct: 1,
controlMode: 'balanced',
+ resizeMode: 'just_resize',
controlImage: null,
processedControlImage: null,
processorType: 'canny_image_processor',
@@ -45,6 +49,7 @@ export type ControlNetConfig = {
beginStepPct: number;
endStepPct: number;
controlMode: ControlModes;
+ resizeMode: ResizeModes;
controlImage: string | null;
processedControlImage: string | null;
processorType: ControlNetProcessorType;
@@ -215,6 +220,16 @@ export const controlNetSlice = createSlice({
const { controlNetId, controlMode } = action.payload;
state.controlNets[controlNetId].controlMode = controlMode;
},
+ controlNetResizeModeChanged: (
+ state,
+ action: PayloadAction<{
+ controlNetId: string;
+ resizeMode: ResizeModes;
+ }>
+ ) => {
+ const { controlNetId, resizeMode } = action.payload;
+ state.controlNets[controlNetId].resizeMode = resizeMode;
+ },
controlNetProcessorParamsChanged: (
state,
action: PayloadAction<{
@@ -342,6 +357,7 @@ export const {
controlNetBeginStepPctChanged,
controlNetEndStepPctChanged,
controlNetControlModeChanged,
+ controlNetResizeModeChanged,
controlNetProcessorParamsChanged,
controlNetProcessorTypeChanged,
controlNetReset,
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addControlNetToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addControlNetToLinearGraph.ts
index 0f882f248d..578c4371f2 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addControlNetToLinearGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addControlNetToLinearGraph.ts
@@ -48,6 +48,7 @@ export const addControlNetToLinearGraph = (
beginStepPct,
endStepPct,
controlMode,
+ resizeMode,
model,
processorType,
weight,
@@ -60,6 +61,7 @@ export const addControlNetToLinearGraph = (
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
control_mode: controlMode,
+ resize_mode: resizeMode,
control_model: model as ControlNetInvocation['control_model'],
control_weight: weight,
};
diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts
index 6a2e176ffd..3ecef092af 100644
--- a/invokeai/frontend/web/src/services/api/schema.d.ts
+++ b/invokeai/frontend/web/src/services/api/schema.d.ts
@@ -167,7 +167,7 @@ export type paths = {
"/api/v1/images/clear-intermediates": {
/**
* Clear Intermediates
- * @description Clears first 100 intermediates
+ * @description Clears all intermediates
*/
post: operations["clear_intermediates"];
};
@@ -800,6 +800,13 @@ export type components = {
* @enum {string}
*/
control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced";
+ /**
+ * Resize Mode
+ * @description The resize mode to use
+ * @default just_resize
+ * @enum {string}
+ */
+ resize_mode?: "just_resize" | "crop_resize" | "fill_resize" | "just_resize_simple";
};
/**
* ControlNetInvocation
@@ -859,6 +866,13 @@ export type components = {
* @enum {string}
*/
control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced";
+ /**
+ * Resize Mode
+ * @description The resize mode used
+ * @default just_resize
+ * @enum {string}
+ */
+ resize_mode?: "just_resize" | "crop_resize" | "fill_resize" | "just_resize_simple";
};
/** ControlNetModelConfig */
ControlNetModelConfig: {
@@ -5324,11 +5338,11 @@ export type components = {
image?: components["schemas"]["ImageField"];
};
/**
- * StableDiffusion2ModelFormat
+ * StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
- StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
+ StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionXLModelFormat
* @description An enumeration.
@@ -5336,11 +5350,11 @@ export type components = {
*/
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
/**
- * StableDiffusion1ModelFormat
+ * StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
- StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
+ StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
};
responses: never;
parameters: never;
@@ -6125,7 +6139,7 @@ export type operations = {
};
/**
* Clear Intermediates
- * @description Clears first 100 intermediates
+ * @description Clears all intermediates
*/
clear_intermediates: {
responses: {