diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py
index baf558ac24..8cfe35598d 100644
--- a/invokeai/app/invocations/controlnet_image_processors.py
+++ b/invokeai/app/invocations/controlnet_image_processors.py
@@ -1,10 +1,11 @@
-# InvokeAI nodes for ControlNet image preprocessors
+# Invocations for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import float, bool
+import cv2
import numpy as np
-from typing import Literal, Optional, Union, List
+from typing import Literal, Optional, Union, List, Dict
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field, validator
@@ -29,8 +30,13 @@ from controlnet_aux import (
ContentShuffleDetector,
ZoeDetector,
MediapipeFaceDetector,
+    SamDetector,
+    LeresDetector,
)
+from controlnet_aux.util import HWC3, ade_palette
+
+
from .image import ImageOutput, PILInvocationConfig
CONTROLNET_DEFAULT_MODELS = [
@@ -95,6 +101,9 @@ CONTROLNET_DEFAULT_MODELS = [
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
+# crop and fill options not ready yet
+# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
+
class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image")
@@ -105,7 +114,8 @@ class ControlField(BaseModel):
description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)")
-    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The contorl mode to use")
+    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
+    # resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@validator("control_weight")
def abs_le_one(cls, v):
@@ -180,7 +190,7 @@ class ControlNetInvocation(BaseInvocation):
),
)
-# TODO: move image processors to separate file (image_analysis.py
+
class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
"""Base class for invocations that preprocess images for ControlNet"""
@@ -452,6 +462,104 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
# fmt: on
def run_processor(self, image):
+        # MediaPipeFaceDetector throws an error if image has alpha channel
+        # so convert to RGB if needed
+        if image.mode == 'RGBA':
+            image = image.convert('RGB')
mediapipe_face_processor = MediapipeFaceDetector()
processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
return processed_image
+
+class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+    """Applies leres processing to image"""
+    # fmt: off
+    type: Literal["leres_image_processor"] = "leres_image_processor"
+    # Inputs
+    thr_a: float = Field(default=0, description="Leres parameter `thr_a`")
+    thr_b: float = Field(default=0, description="Leres parameter `thr_b`")
+    boost: bool = Field(default=False, description="Whether to use boost mode")
+    detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
+    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
+    # fmt: on
+
+    def run_processor(self, image):
+        leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
+        processed_image = leres_processor(image,
+                                          thr_a=self.thr_a,
+                                          thr_b=self.thr_b,
+                                          boost=self.boost,
+                                          detect_resolution=self.detect_resolution,
+                                          image_resolution=self.image_resolution)
+        return processed_image
+
+
+class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+
+    # fmt: off
+    type: Literal["tile_image_processor"] = "tile_image_processor"
+    # Inputs
+    #res: int = Field(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
+    down_sampling_rate: float = Field(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
+    # fmt: on
+
+    # tile_resample copied from sd-webui-controlnet/scripts/processor.py
+    def tile_resample(self,
+                      np_img: np.ndarray,
+                      res=512, # never used?
+                      down_sampling_rate=1.0,
+                      ):
+        np_img = HWC3(np_img)
+        if down_sampling_rate < 1.1:
+            return np_img
+        H, W, C = np_img.shape
+        H = int(float(H) / float(down_sampling_rate))
+        W = int(float(W) / float(down_sampling_rate))
+        np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
+        return np_img
+
+    def run_processor(self, img):
+        np_img = np.array(img, dtype=np.uint8)
+        processed_np_image = self.tile_resample(np_img,
+                                                #res=self.tile_size,
+                                                down_sampling_rate=self.down_sampling_rate
+                                                )
+        processed_image = Image.fromarray(processed_np_image)
+        return processed_image
+
+
+
+
+class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+    """Applies segment anything processing to image"""
+    # fmt: off
+    type: Literal["segment_anything_processor"] = "segment_anything_processor"
+    # fmt: on
+
+    def run_processor(self, image):
+        # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
+        segment_anything_processor = SamDetectorReproducibleColors.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
+        np_img = np.array(image, dtype=np.uint8)
+        processed_image = segment_anything_processor(np_img)
+        return processed_image
+
+class SamDetectorReproducibleColors(SamDetector):
+
+    # overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
+    # base class show_anns() method randomizes colors,
+    # which seems to also lead to non-reproducible image generation
+    # so using ADE20k color palette instead
+    def show_anns(self, anns: List[Dict]):
+        if len(anns) == 0:
+            return
+        sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
+        h, w = anns[0]['segmentation'].shape
+        final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
+        palette = ade_palette()
+        for i, ann in enumerate(sorted_anns):
+            m = ann['segmentation']
+            img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
+            # doing modulo just in case number of annotated regions exceeds number of colors in palette
+            ann_color = palette[i % len(palette)]
+            img[:, :] = ann_color
+            final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
+        return np.array(final_img, dtype=np.uint8)
diff --git a/invokeai/app/invocations/param_easing.py b/invokeai/app/invocations/param_easing.py
index 1ff6261b88..e79763a35e 100644
--- a/invokeai/app/invocations/param_easing.py
+++ b/invokeai/app/invocations/param_easing.py
@@ -133,20 +133,19 @@ class StepParamEasingInvocation(BaseInvocation):
postlist = list(num_poststeps * [self.post_end_value])
if log_diagnostics:
-            logger = InvokeAILogger.getLogger(name="StepParamEasing")
-            logger.debug("start_step: " + str(start_step))
-            logger.debug("end_step: " + str(end_step))
-            logger.debug("num_easing_steps: " + str(num_easing_steps))
-            logger.debug("num_presteps: " + str(num_presteps))
-            logger.debug("num_poststeps: " + str(num_poststeps))
-            logger.debug("prelist size: " + str(len(prelist)))
-            logger.debug("postlist size: " + str(len(postlist)))
-            logger.debug("prelist: " + str(prelist))
-            logger.debug("postlist: " + str(postlist))
+            context.services.logger.debug("start_step: " + str(start_step))
+            context.services.logger.debug("end_step: " + str(end_step))
+            context.services.logger.debug("num_easing_steps: " + str(num_easing_steps))
+            context.services.logger.debug("num_presteps: " + str(num_presteps))
+            context.services.logger.debug("num_poststeps: " + str(num_poststeps))
+            context.services.logger.debug("prelist size: " + str(len(prelist)))
+            context.services.logger.debug("postlist size: " + str(len(postlist)))
+            context.services.logger.debug("prelist: " + str(prelist))
+            context.services.logger.debug("postlist: " + str(postlist))
easing_class = EASING_FUNCTIONS_MAP[self.easing]
if log_diagnostics:
-            logger.debug("easing class: " + str(easing_class))
+            context.services.logger.debug("easing class: " + str(easing_class))
easing_list = list()
if self.mirror: # "expected" mirroring
# if number of steps is even, squeeze duration down to (number_of_steps)/2
@@ -156,7 +155,7 @@ class StepParamEasingInvocation(BaseInvocation):
# but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
base_easing_duration = int(np.ceil(num_easing_steps/2.0))
-            if log_diagnostics: logger.debug("base easing duration: " + str(base_easing_duration))
+            if log_diagnostics: context.services.logger.debug("base easing duration: " + str(base_easing_duration))
even_num_steps = (num_easing_steps % 2 == 0) # even number of steps
easing_function = easing_class(start=self.start_value,
end=self.end_value,
@@ -166,14 +165,14 @@ class StepParamEasingInvocation(BaseInvocation):
easing_val = easing_function.ease(step_index)
base_easing_vals.append(easing_val)
if log_diagnostics:
-                    logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
+                    context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
if even_num_steps:
mirror_easing_vals = list(reversed(base_easing_vals))
else:
mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
if log_diagnostics:
-                logger.debug("base easing vals: " + str(base_easing_vals))
-                logger.debug("mirror easing vals: " + str(mirror_easing_vals))
+                context.services.logger.debug("base easing vals: " + str(base_easing_vals))
+                context.services.logger.debug("mirror easing vals: " + str(mirror_easing_vals))
easing_list = base_easing_vals + mirror_easing_vals
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
@@ -206,12 +205,12 @@ class StepParamEasingInvocation(BaseInvocation):
step_val = easing_function.ease(step_index)
easing_list.append(step_val)
if log_diagnostics:
-                    logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
+                    context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
if log_diagnostics:
-            logger.debug("prelist size: " + str(len(prelist)))
-            logger.debug("easing_list size: " + str(len(easing_list)))
-            logger.debug("postlist size: " + str(len(postlist)))
+            context.services.logger.debug("prelist size: " + str(len(prelist)))
+            context.services.logger.debug("easing_list size: " + str(len(easing_list)))
+            context.services.logger.debug("postlist size: " + str(len(postlist)))
param_list = prelist + easing_list + postlist
diff --git a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCollapse.tsx b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCollapse.tsx
index eeaf1b81ec..1aefecf3e6 100644
--- a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCollapse.tsx
+++ b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCollapse.tsx
@@ -35,8 +35,8 @@ const ParamDynamicPromptsCollapse = () => {
withSwitch
>
-
+
);
diff --git a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx
index ab56abaa35..19f02ae3e5 100644
--- a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx
+++ b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx
@@ -9,17 +9,18 @@ import { stateSelector } from 'app/store/store';
const selector = createSelector(
stateSelector,
(state) => {
- const { maxPrompts } = state.dynamicPrompts;
+ const { maxPrompts, combinatorial } = state.dynamicPrompts;
const { min, sliderMax, inputMax } =
state.config.sd.dynamicPrompts.maxPrompts;
- return { maxPrompts, min, sliderMax, inputMax };
+ return { maxPrompts, min, sliderMax, inputMax, combinatorial };
},
defaultSelectorOptions
);
const ParamDynamicPromptsMaxPrompts = () => {
- const { maxPrompts, min, sliderMax, inputMax } = useAppSelector(selector);
+ const { maxPrompts, min, sliderMax, inputMax, combinatorial } =
+ useAppSelector(selector);
const dispatch = useAppDispatch();
const handleChange = useCallback(
@@ -36,6 +37,7 @@ const ParamDynamicPromptsMaxPrompts = () => {
return (
{
state.config.sd.iterations;
const { iterations } = state.generation;
const { shouldUseSliders } = state.ui;
- const isDisabled = state.dynamicPrompts.isEnabled;
+ const isDisabled =
+ state.dynamicPrompts.isEnabled && state.dynamicPrompts.combinatorial;
const step = state.hotkeys.shift ? fineStep : coarseStep;
diff --git a/invokeai/frontend/web/src/services/api/types.d.ts b/invokeai/frontend/web/src/services/api/types.d.ts
index a995d9c298..2a2f90f434 100644
--- a/invokeai/frontend/web/src/services/api/types.d.ts
+++ b/invokeai/frontend/web/src/services/api/types.d.ts
@@ -4,91 +4,89 @@ import { components } from './schema';
type schemas = components['schemas'];
/**
- * Helper type to extract the invocation type from the schema.
- * Also flags the `type` property as required.
+ * Extracts the schema type from the schema.
*/
-type Invocation<T extends keyof schemas> = O.Required<schemas[T], 'type'>;
+type S<T extends keyof components['schemas']> = components['schemas'][T];
/**
- * Types from the API, re-exported from the types generated by `openapi-typescript`.
+ * Extracts the node type from the schema.
+ * Also flags the `type` property as required.
*/
+type N<T extends keyof components['schemas']> = O.Required<
+  components['schemas'][T],
+  'type'
+>;
// Images
-export type ImageDTO = schemas['ImageDTO'];
-export type BoardDTO = schemas['BoardDTO'];
-export type BoardChanges = schemas['BoardChanges'];
-export type ImageChanges = schemas['ImageRecordChanges'];
-export type ImageCategory = schemas['ImageCategory'];
-export type ResourceOrigin = schemas['ResourceOrigin'];
-export type ImageField = schemas['ImageField'];
+export type ImageDTO = S<'ImageDTO'>;
+export type BoardDTO = S<'BoardDTO'>;
+export type BoardChanges = S<'BoardChanges'>;
+export type ImageChanges = S<'ImageRecordChanges'>;
+export type ImageCategory = S<'ImageCategory'>;
+export type ResourceOrigin = S<'ResourceOrigin'>;
+export type ImageField = S<'ImageField'>;
export type OffsetPaginatedResults_BoardDTO_ =
- schemas['OffsetPaginatedResults_BoardDTO_'];
+ S<'OffsetPaginatedResults_BoardDTO_'>;
export type OffsetPaginatedResults_ImageDTO_ =
- schemas['OffsetPaginatedResults_ImageDTO_'];
+ S<'OffsetPaginatedResults_ImageDTO_'>;
// Models
-export type ModelType = schemas['ModelType'];
-export type BaseModelType = schemas['BaseModelType'];
-export type PipelineModelField = schemas['PipelineModelField'];
-export type ModelsList = schemas['ModelsList'];
+export type ModelType = S<'ModelType'>;
+export type BaseModelType = S<'BaseModelType'>;
+export type PipelineModelField = S<'PipelineModelField'>;
+export type ModelsList = S<'ModelsList'>;
// Graphs
-export type Graph = schemas['Graph'];
-export type Edge = schemas['Edge'];
-export type GraphExecutionState = schemas['GraphExecutionState'];
+export type Graph = S<'Graph'>;
+export type Edge = S<'Edge'>;
+export type GraphExecutionState = S<'GraphExecutionState'>;
// General nodes
-export type CollectInvocation = Invocation<'CollectInvocation'>;
-export type IterateInvocation = Invocation<'IterateInvocation'>;
-export type RangeInvocation = Invocation<'RangeInvocation'>;
-export type RandomRangeInvocation = Invocation<'RandomRangeInvocation'>;
-export type RangeOfSizeInvocation = Invocation<'RangeOfSizeInvocation'>;
-export type InpaintInvocation = Invocation<'InpaintInvocation'>;
-export type ImageResizeInvocation = Invocation<'ImageResizeInvocation'>;
-export type RandomIntInvocation = Invocation<'RandomIntInvocation'>;
-export type CompelInvocation = Invocation<'CompelInvocation'>;
-export type DynamicPromptInvocation = Invocation<'DynamicPromptInvocation'>;
-export type NoiseInvocation = Invocation<'NoiseInvocation'>;
-export type TextToLatentsInvocation = Invocation<'TextToLatentsInvocation'>;
-export type LatentsToLatentsInvocation =
- Invocation<'LatentsToLatentsInvocation'>;
-export type ImageToLatentsInvocation = Invocation<'ImageToLatentsInvocation'>;
-export type LatentsToImageInvocation = Invocation<'LatentsToImageInvocation'>;
-export type PipelineModelLoaderInvocation =
- Invocation<'PipelineModelLoaderInvocation'>;
+export type CollectInvocation = N<'CollectInvocation'>;
+export type IterateInvocation = N<'IterateInvocation'>;
+export type RangeInvocation = N<'RangeInvocation'>;
+export type RandomRangeInvocation = N<'RandomRangeInvocation'>;
+export type RangeOfSizeInvocation = N<'RangeOfSizeInvocation'>;
+export type InpaintInvocation = N<'InpaintInvocation'>;
+export type ImageResizeInvocation = N<'ImageResizeInvocation'>;
+export type RandomIntInvocation = N<'RandomIntInvocation'>;
+export type CompelInvocation = N<'CompelInvocation'>;
+export type DynamicPromptInvocation = N<'DynamicPromptInvocation'>;
+export type NoiseInvocation = N<'NoiseInvocation'>;
+export type TextToLatentsInvocation = N<'TextToLatentsInvocation'>;
+export type LatentsToLatentsInvocation = N<'LatentsToLatentsInvocation'>;
+export type ImageToLatentsInvocation = N<'ImageToLatentsInvocation'>;
+export type LatentsToImageInvocation = N<'LatentsToImageInvocation'>;
+export type PipelineModelLoaderInvocation = N<'PipelineModelLoaderInvocation'>;
// ControlNet Nodes
-export type ControlNetInvocation = Invocation<'ControlNetInvocation'>;
-export type CannyImageProcessorInvocation =
- Invocation<'CannyImageProcessorInvocation'>;
+export type ControlNetInvocation = N<'ControlNetInvocation'>;
+export type CannyImageProcessorInvocation = N<'CannyImageProcessorInvocation'>;
export type ContentShuffleImageProcessorInvocation =
- Invocation<'ContentShuffleImageProcessorInvocation'>;
-export type HedImageProcessorInvocation =
- Invocation<'HedImageProcessorInvocation'>;
+ N<'ContentShuffleImageProcessorInvocation'>;
+export type HedImageProcessorInvocation = N<'HedImageProcessorInvocation'>;
export type LineartAnimeImageProcessorInvocation =
- Invocation<'LineartAnimeImageProcessorInvocation'>;
+ N<'LineartAnimeImageProcessorInvocation'>;
export type LineartImageProcessorInvocation =
- Invocation<'LineartImageProcessorInvocation'>;
+ N<'LineartImageProcessorInvocation'>;
export type MediapipeFaceProcessorInvocation =
- Invocation<'MediapipeFaceProcessorInvocation'>;
+ N<'MediapipeFaceProcessorInvocation'>;
export type MidasDepthImageProcessorInvocation =
- Invocation<'MidasDepthImageProcessorInvocation'>;
-export type MlsdImageProcessorInvocation =
- Invocation<'MlsdImageProcessorInvocation'>;
+ N<'MidasDepthImageProcessorInvocation'>;
+export type MlsdImageProcessorInvocation = N<'MlsdImageProcessorInvocation'>;
export type NormalbaeImageProcessorInvocation =
- Invocation<'NormalbaeImageProcessorInvocation'>;
+ N<'NormalbaeImageProcessorInvocation'>;
export type OpenposeImageProcessorInvocation =
- Invocation<'OpenposeImageProcessorInvocation'>;
-export type PidiImageProcessorInvocation =
- Invocation<'PidiImageProcessorInvocation'>;
+ N<'OpenposeImageProcessorInvocation'>;
+export type PidiImageProcessorInvocation = N<'PidiImageProcessorInvocation'>;
export type ZoeDepthImageProcessorInvocation =
- Invocation<'ZoeDepthImageProcessorInvocation'>;
+ N<'ZoeDepthImageProcessorInvocation'>;
// Node Outputs
-export type ImageOutput = schemas['ImageOutput'];
-export type MaskOutput = schemas['MaskOutput'];
-export type PromptOutput = schemas['PromptOutput'];
-export type IterateInvocationOutput = schemas['IterateInvocationOutput'];
-export type CollectInvocationOutput = schemas['CollectInvocationOutput'];
-export type LatentsOutput = schemas['LatentsOutput'];
-export type GraphInvocationOutput = schemas['GraphInvocationOutput'];
+export type ImageOutput = S<'ImageOutput'>;
+export type MaskOutput = S<'MaskOutput'>;
+export type PromptOutput = S<'PromptOutput'>;
+export type IterateInvocationOutput = S<'IterateInvocationOutput'>;
+export type CollectInvocationOutput = S<'CollectInvocationOutput'>;
+export type LatentsOutput = S<'LatentsOutput'>;
+export type GraphInvocationOutput = S<'GraphInvocationOutput'>;
diff --git a/pyproject.toml b/pyproject.toml
index 03396312ac..6e5b8f4e22 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,7 +39,7 @@ dependencies = [
"click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel>=1.2.1",
- "controlnet-aux>=0.0.4",
+ "controlnet-aux>=0.0.6",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"datasets",
"diffusers[torch]~=0.17.1",