Merge branch 'main' into feat/nodes/cpu-noise

This commit is contained in:
blessedcoolant 2023-06-28 18:22:08 +12:00 committed by GitHub
commit 75614bbba3
8 changed files with 200 additions and 92 deletions

View File

@@ -1,10 +1,11 @@
# InvokeAI nodes for ControlNet image preprocessors
# Invocations for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import float, bool
import cv2
import numpy as np
from typing import Literal, Optional, Union, List
from typing import Literal, Optional, Union, List, Dict
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field, validator
@@ -29,8 +30,13 @@ from controlnet_aux import (
ContentShuffleDetector,
ZoeDetector,
MediapipeFaceDetector,
SamDetector,
LeresDetector,
)
from controlnet_aux.util import HWC3, ade_palette
from .image import ImageOutput, PILInvocationConfig
CONTROLNET_DEFAULT_MODELS = [
@@ -95,6 +101,9 @@ CONTROLNET_DEFAULT_MODELS = [
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
# crop and fill options not ready yet
# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image")
@@ -105,7 +114,8 @@ class ControlField(BaseModel):
description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The contorl mode to use")
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
# resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@validator("control_weight")
def abs_le_one(cls, v):
@@ -180,7 +190,7 @@ class ControlNetInvocation(BaseInvocation):
),
)
# TODO: move image processors to separate file (image_analysis.py)
class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
"""Base class for invocations that preprocess images for ControlNet"""
@@ -452,6 +462,104 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
# fmt: on
def run_processor(self, image):
# MediaPipeFaceDetector throws an error if image has alpha channel
# so convert to RGB if needed
if image.mode == 'RGBA':
image = image.convert('RGB')
mediapipe_face_processor = MediapipeFaceDetector()
processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
return processed_image
class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies leres processing to image"""
# fmt: off
type: Literal["leres_image_processor"] = "leres_image_processor"
# Inputs
thr_a: float = Field(default=0, description="Leres parameter `thr_a`")
thr_b: float = Field(default=0, description="Leres parameter `thr_b`")
boost: bool = Field(default=False, description="Whether to use boost mode")
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# fmt: on
def run_processor(self, image):
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
processed_image = leres_processor(image,
thr_a=self.thr_a,
thr_b=self.thr_b,
boost=self.boost,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution)
return processed_image
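# Aside (not part of the diff): a minimal standalone sketch of the LeRes call
# wrapped above. It assumes controlnet-aux>=0.0.6 is installed and the
# lllyasviel/Annotators weights can be downloaded; "photo.png" is a placeholder.
from PIL import Image
from controlnet_aux import LeresDetector

leres = LeresDetector.from_pretrained("lllyasviel/Annotators")
depth = leres(
    Image.open("photo.png").convert("RGB"),
    thr_a=0, thr_b=0, boost=False,
    detect_resolution=512, image_resolution=512,
)
depth.save("leres_depth.png")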
class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
# fmt: off
type: Literal["tile_image_processor"] = "tile_image_processor"
# Inputs
#res: int = Field(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
down_sampling_rate: float = Field(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
# fmt: on
# tile_resample copied from sd-webui-controlnet/scripts/processor.py
def tile_resample(self,
np_img: np.ndarray,
res=512, # never used?
down_sampling_rate=1.0,
):
np_img = HWC3(np_img)
if down_sampling_rate < 1.1:
return np_img
H, W, C = np_img.shape
H = int(float(H) / float(down_sampling_rate))
W = int(float(W) / float(down_sampling_rate))
np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
return np_img
def run_processor(self, img):
np_img = np.array(img, dtype=np.uint8)
processed_np_image = self.tile_resample(np_img,
#res=self.tile_size,
down_sampling_rate=self.down_sampling_rate
)
processed_image = Image.fromarray(processed_np_image)
return processed_image
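# Aside (not part of the diff): a quick illustration of the down-sampling math
# above. With down_sampling_rate=2.0 a 1024x1024 control image is reduced to
# 512x512 via INTER_AREA, while rates below 1.1 are treated as a no-op.
import cv2
import numpy as np

rate = 2.0
np_img = np.zeros((1024, 1024, 3), dtype=np.uint8)   # stand-in control image
if rate < 1.1:
    resampled = np_img                               # near-1 rates: unchanged
else:
    h, w = int(np_img.shape[0] / rate), int(np_img.shape[1] / rate)
    resampled = cv2.resize(np_img, (w, h), interpolation=cv2.INTER_AREA)
print(resampled.shape)                               # (512, 512, 3)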
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies segment anything processing to image"""
# fmt: off
type: Literal["segment_anything_processor"] = "segment_anything_processor"
# fmt: on
def run_processor(self, image):
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
np_img = np.array(image, dtype=np.uint8)
processed_image = segment_anything_processor(np_img)
return processed_image
class SamDetectorReproducibleColors(SamDetector):
# overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
# base class show_anns() method randomizes colors,
# which seems to also lead to non-reproducible image generation
# so using ADE20k color palette instead
def show_anns(self, anns: List[Dict]):
if len(anns) == 0:
return
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
h, w = anns[0]['segmentation'].shape
final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
palette = ade_palette()
for i, ann in enumerate(sorted_anns):
m = ann['segmentation']
img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
# doing modulo just in case number of annotated regions exceeds number of colors in palette
ann_color = palette[i % len(palette)]
img[:, :] = ann_color
final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
return np.array(final_img, dtype=np.uint8)
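The point of the override is that each region's color depends only on its area rank and the fixed ADE20K palette, not on a random draw, so the same set of masks always yields the same segmentation map. A minimal sketch of that palette-indexing idea, assuming controlnet-aux>=0.0.6 provides ade_palette(); the toy masks below are stand-ins and not part of the diff:

import numpy as np
from controlnet_aux.util import ade_palette

palette = ade_palette()
# two toy "annotations" standing in for SAM output
masks = [np.ones((4, 4), dtype=bool), np.zeros((4, 4), dtype=bool)]
anns = [{"segmentation": m, "area": int(m.sum())} for m in masks]

# same sort key and modulo indexing as show_anns() above
for i, ann in enumerate(sorted(anns, key=lambda a: a["area"], reverse=True)):
    color = palette[i % len(palette)]   # rank i always maps to the same color
    print(i, color)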

View File

@@ -133,20 +133,19 @@ class StepParamEasingInvocation(BaseInvocation):
postlist = list(num_poststeps * [self.post_end_value])
if log_diagnostics:
logger = InvokeAILogger.getLogger(name="StepParamEasing")
logger.debug("start_step: " + str(start_step))
logger.debug("end_step: " + str(end_step))
logger.debug("num_easing_steps: " + str(num_easing_steps))
logger.debug("num_presteps: " + str(num_presteps))
logger.debug("num_poststeps: " + str(num_poststeps))
logger.debug("prelist size: " + str(len(prelist)))
logger.debug("postlist size: " + str(len(postlist)))
logger.debug("prelist: " + str(prelist))
logger.debug("postlist: " + str(postlist))
context.services.logger.debug("start_step: " + str(start_step))
context.services.logger.debug("end_step: " + str(end_step))
context.services.logger.debug("num_easing_steps: " + str(num_easing_steps))
context.services.logger.debug("num_presteps: " + str(num_presteps))
context.services.logger.debug("num_poststeps: " + str(num_poststeps))
context.services.logger.debug("prelist size: " + str(len(prelist)))
context.services.logger.debug("postlist size: " + str(len(postlist)))
context.services.logger.debug("prelist: " + str(prelist))
context.services.logger.debug("postlist: " + str(postlist))
easing_class = EASING_FUNCTIONS_MAP[self.easing]
if log_diagnostics:
logger.debug("easing class: " + str(easing_class))
context.services.logger.debug("easing class: " + str(easing_class))
easing_list = list()
if self.mirror: # "expected" mirroring
# if number of steps is even, squeeze duration down to (number_of_steps)/2
@@ -156,7 +155,7 @@ class StepParamEasingInvocation(BaseInvocation):
# but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
base_easing_duration = int(np.ceil(num_easing_steps/2.0))
if log_diagnostics: logger.debug("base easing duration: " + str(base_easing_duration))
if log_diagnostics: context.services.logger.debug("base easing duration: " + str(base_easing_duration))
even_num_steps = (num_easing_steps % 2 == 0) # even number of steps
easing_function = easing_class(start=self.start_value,
end=self.end_value,
@@ -166,14 +165,14 @@ class StepParamEasingInvocation(BaseInvocation):
easing_val = easing_function.ease(step_index)
base_easing_vals.append(easing_val)
if log_diagnostics:
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
if even_num_steps:
mirror_easing_vals = list(reversed(base_easing_vals))
else:
mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
if log_diagnostics:
logger.debug("base easing vals: " + str(base_easing_vals))
logger.debug("mirror easing vals: " + str(mirror_easing_vals))
context.services.logger.debug("base easing vals: " + str(base_easing_vals))
context.services.logger.debug("mirror easing vals: " + str(mirror_easing_vals))
easing_list = base_easing_vals + mirror_easing_vals
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
@@ -206,12 +205,12 @@ class StepParamEasingInvocation(BaseInvocation):
step_val = easing_function.ease(step_index)
easing_list.append(step_val)
if log_diagnostics:
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
if log_diagnostics:
logger.debug("prelist size: " + str(len(prelist)))
logger.debug("easing_list size: " + str(len(easing_list)))
logger.debug("postlist size: " + str(len(postlist)))
context.services.logger.debug("prelist size: " + str(len(prelist)))
context.services.logger.debug("easing_list size: " + str(len(easing_list)))
context.services.logger.debug("postlist size: " + str(len(postlist)))
param_list = prelist + easing_list + postlist
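For context, param_list is assembled as pre-steps, then the eased values (mirrored when self.mirror is set), then post-steps. A minimal sketch of the mirrored construction, using a hand-rolled linear ease in place of the easing-functions classes used above, so the values are purely illustrative:

import numpy as np

num_easing_steps = 6
start_value, end_value = 0.0, 1.0

# squeeze the ease into the first half of the steps, then mirror it back down
base_duration = int(np.ceil(num_easing_steps / 2.0))
even = num_easing_steps % 2 == 0
base_vals = [start_value + (end_value - start_value) * i / max(base_duration - 1, 1)
             for i in range(base_duration)]
mirror_vals = list(reversed(base_vals)) if even else list(reversed(base_vals[:-1]))
easing_list = base_vals + mirror_vals            # [0.0, 0.5, 1.0, 1.0, 0.5, 0.0]

prelist = 2 * [easing_list[0]]                   # placeholder pre-steps
postlist = 2 * [easing_list[-1]]                 # placeholder post-steps
param_list = prelist + easing_list + postlist
print(param_list)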

View File

@@ -35,8 +35,8 @@ const ParamDynamicPromptsCollapse = () => {
withSwitch
>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<ParamDynamicPromptsMaxPrompts />
<ParamDynamicPromptsCombinatorial />
<ParamDynamicPromptsMaxPrompts />
</Flex>
</IAICollapse>
);

View File

@@ -9,17 +9,18 @@ import { stateSelector } from 'app/store/store';
const selector = createSelector(
stateSelector,
(state) => {
const { maxPrompts } = state.dynamicPrompts;
const { maxPrompts, combinatorial } = state.dynamicPrompts;
const { min, sliderMax, inputMax } =
state.config.sd.dynamicPrompts.maxPrompts;
return { maxPrompts, min, sliderMax, inputMax };
return { maxPrompts, min, sliderMax, inputMax, combinatorial };
},
defaultSelectorOptions
);
const ParamDynamicPromptsMaxPrompts = () => {
const { maxPrompts, min, sliderMax, inputMax } = useAppSelector(selector);
const { maxPrompts, min, sliderMax, inputMax, combinatorial } =
useAppSelector(selector);
const dispatch = useAppDispatch();
const handleChange = useCallback(
@@ -36,6 +37,7 @@ const ParamDynamicPromptsMaxPrompts = () => {
return (
<IAISlider
label="Max Prompts"
isDisabled={!combinatorial}
min={min}
max={sliderMax}
value={maxPrompts}

View File

@@ -37,7 +37,7 @@ export const addDynamicPromptsToGraph = (
const dynamicPromptNode: DynamicPromptInvocation = {
id: DYNAMIC_PROMPT,
type: 'dynamic_prompt',
max_prompts: maxPrompts,
max_prompts: combinatorial ? maxPrompts : iterations,
combinatorial,
prompt: positivePrompt,
};

View File

@@ -16,7 +16,8 @@ const selector = createSelector([stateSelector], (state) => {
state.config.sd.iterations;
const { iterations } = state.generation;
const { shouldUseSliders } = state.ui;
const isDisabled = state.dynamicPrompts.isEnabled;
const isDisabled =
state.dynamicPrompts.isEnabled && state.dynamicPrompts.combinatorial;
const step = state.hotkeys.shift ? fineStep : coarseStep;

View File

@@ -4,91 +4,89 @@ import { components } from './schema';
type schemas = components['schemas'];
/**
* Helper type to extract the invocation type from the schema.
* Also flags the `type` property as required.
* Extracts the schema type from the schema.
*/
type Invocation<T extends keyof schemas> = O.Required<schemas[T], 'type'>;
type S<T extends keyof components['schemas']> = components['schemas'][T];
/**
* Types from the API, re-exported from the types generated by `openapi-typescript`.
* Extracts the node type from the schema.
* Also flags the `type` property as required.
*/
type N<T extends keyof components['schemas']> = O.Required<
components['schemas'][T],
'type'
>;
// Images
export type ImageDTO = schemas['ImageDTO'];
export type BoardDTO = schemas['BoardDTO'];
export type BoardChanges = schemas['BoardChanges'];
export type ImageChanges = schemas['ImageRecordChanges'];
export type ImageCategory = schemas['ImageCategory'];
export type ResourceOrigin = schemas['ResourceOrigin'];
export type ImageField = schemas['ImageField'];
export type ImageDTO = S<'ImageDTO'>;
export type BoardDTO = S<'BoardDTO'>;
export type BoardChanges = S<'BoardChanges'>;
export type ImageChanges = S<'ImageRecordChanges'>;
export type ImageCategory = S<'ImageCategory'>;
export type ResourceOrigin = S<'ResourceOrigin'>;
export type ImageField = S<'ImageField'>;
export type OffsetPaginatedResults_BoardDTO_ =
schemas['OffsetPaginatedResults_BoardDTO_'];
S<'OffsetPaginatedResults_BoardDTO_'>;
export type OffsetPaginatedResults_ImageDTO_ =
schemas['OffsetPaginatedResults_ImageDTO_'];
S<'OffsetPaginatedResults_ImageDTO_'>;
// Models
export type ModelType = schemas['ModelType'];
export type BaseModelType = schemas['BaseModelType'];
export type PipelineModelField = schemas['PipelineModelField'];
export type ModelsList = schemas['ModelsList'];
export type ModelType = S<'ModelType'>;
export type BaseModelType = S<'BaseModelType'>;
export type PipelineModelField = S<'PipelineModelField'>;
export type ModelsList = S<'ModelsList'>;
// Graphs
export type Graph = schemas['Graph'];
export type Edge = schemas['Edge'];
export type GraphExecutionState = schemas['GraphExecutionState'];
export type Graph = S<'Graph'>;
export type Edge = S<'Edge'>;
export type GraphExecutionState = S<'GraphExecutionState'>;
// General nodes
export type CollectInvocation = Invocation<'CollectInvocation'>;
export type IterateInvocation = Invocation<'IterateInvocation'>;
export type RangeInvocation = Invocation<'RangeInvocation'>;
export type RandomRangeInvocation = Invocation<'RandomRangeInvocation'>;
export type RangeOfSizeInvocation = Invocation<'RangeOfSizeInvocation'>;
export type InpaintInvocation = Invocation<'InpaintInvocation'>;
export type ImageResizeInvocation = Invocation<'ImageResizeInvocation'>;
export type RandomIntInvocation = Invocation<'RandomIntInvocation'>;
export type CompelInvocation = Invocation<'CompelInvocation'>;
export type DynamicPromptInvocation = Invocation<'DynamicPromptInvocation'>;
export type NoiseInvocation = Invocation<'NoiseInvocation'>;
export type TextToLatentsInvocation = Invocation<'TextToLatentsInvocation'>;
export type LatentsToLatentsInvocation =
Invocation<'LatentsToLatentsInvocation'>;
export type ImageToLatentsInvocation = Invocation<'ImageToLatentsInvocation'>;
export type LatentsToImageInvocation = Invocation<'LatentsToImageInvocation'>;
export type PipelineModelLoaderInvocation =
Invocation<'PipelineModelLoaderInvocation'>;
export type CollectInvocation = N<'CollectInvocation'>;
export type IterateInvocation = N<'IterateInvocation'>;
export type RangeInvocation = N<'RangeInvocation'>;
export type RandomRangeInvocation = N<'RandomRangeInvocation'>;
export type RangeOfSizeInvocation = N<'RangeOfSizeInvocation'>;
export type InpaintInvocation = N<'InpaintInvocation'>;
export type ImageResizeInvocation = N<'ImageResizeInvocation'>;
export type RandomIntInvocation = N<'RandomIntInvocation'>;
export type CompelInvocation = N<'CompelInvocation'>;
export type DynamicPromptInvocation = N<'DynamicPromptInvocation'>;
export type NoiseInvocation = N<'NoiseInvocation'>;
export type TextToLatentsInvocation = N<'TextToLatentsInvocation'>;
export type LatentsToLatentsInvocation = N<'LatentsToLatentsInvocation'>;
export type ImageToLatentsInvocation = N<'ImageToLatentsInvocation'>;
export type LatentsToImageInvocation = N<'LatentsToImageInvocation'>;
export type PipelineModelLoaderInvocation = N<'PipelineModelLoaderInvocation'>;
// ControlNet Nodes
export type ControlNetInvocation = Invocation<'ControlNetInvocation'>;
export type CannyImageProcessorInvocation =
Invocation<'CannyImageProcessorInvocation'>;
export type ControlNetInvocation = N<'ControlNetInvocation'>;
export type CannyImageProcessorInvocation = N<'CannyImageProcessorInvocation'>;
export type ContentShuffleImageProcessorInvocation =
Invocation<'ContentShuffleImageProcessorInvocation'>;
export type HedImageProcessorInvocation =
Invocation<'HedImageProcessorInvocation'>;
N<'ContentShuffleImageProcessorInvocation'>;
export type HedImageProcessorInvocation = N<'HedImageProcessorInvocation'>;
export type LineartAnimeImageProcessorInvocation =
Invocation<'LineartAnimeImageProcessorInvocation'>;
N<'LineartAnimeImageProcessorInvocation'>;
export type LineartImageProcessorInvocation =
Invocation<'LineartImageProcessorInvocation'>;
N<'LineartImageProcessorInvocation'>;
export type MediapipeFaceProcessorInvocation =
Invocation<'MediapipeFaceProcessorInvocation'>;
N<'MediapipeFaceProcessorInvocation'>;
export type MidasDepthImageProcessorInvocation =
Invocation<'MidasDepthImageProcessorInvocation'>;
export type MlsdImageProcessorInvocation =
Invocation<'MlsdImageProcessorInvocation'>;
N<'MidasDepthImageProcessorInvocation'>;
export type MlsdImageProcessorInvocation = N<'MlsdImageProcessorInvocation'>;
export type NormalbaeImageProcessorInvocation =
Invocation<'NormalbaeImageProcessorInvocation'>;
N<'NormalbaeImageProcessorInvocation'>;
export type OpenposeImageProcessorInvocation =
Invocation<'OpenposeImageProcessorInvocation'>;
export type PidiImageProcessorInvocation =
Invocation<'PidiImageProcessorInvocation'>;
N<'OpenposeImageProcessorInvocation'>;
export type PidiImageProcessorInvocation = N<'PidiImageProcessorInvocation'>;
export type ZoeDepthImageProcessorInvocation =
Invocation<'ZoeDepthImageProcessorInvocation'>;
N<'ZoeDepthImageProcessorInvocation'>;
// Node Outputs
export type ImageOutput = schemas['ImageOutput'];
export type MaskOutput = schemas['MaskOutput'];
export type PromptOutput = schemas['PromptOutput'];
export type IterateInvocationOutput = schemas['IterateInvocationOutput'];
export type CollectInvocationOutput = schemas['CollectInvocationOutput'];
export type LatentsOutput = schemas['LatentsOutput'];
export type GraphInvocationOutput = schemas['GraphInvocationOutput'];
export type ImageOutput = S<'ImageOutput'>;
export type MaskOutput = S<'MaskOutput'>;
export type PromptOutput = S<'PromptOutput'>;
export type IterateInvocationOutput = S<'IterateInvocationOutput'>;
export type CollectInvocationOutput = S<'CollectInvocationOutput'>;
export type LatentsOutput = S<'LatentsOutput'>;
export type GraphInvocationOutput = S<'GraphInvocationOutput'>;

View File

@@ -39,7 +39,7 @@ dependencies = [
"click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel>=1.2.1",
"controlnet-aux>=0.0.4",
"controlnet-aux>=0.0.6",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"datasets",
"diffusers[torch]~=0.17.1",