Add control_mode parameter to ControlNet (#3535)

This PR adds the "control_mode" option to ControlNet implementation. 
Possible control_mode options are: 

- balanced -- this is the default, same as previous implementation
without control_mode
- more_prompt -- pays more attention to the prompt
- more _control -- pays more attention to the ControlNet (in earlier
implementations this was called "guess_mode")
- unbalanced -- pays even more attention to the ControlNet 

balanced, more_prompt, and more_control should be nearly identical to
the equivalent options in the [auto1111 sd-webui-controlnet
extension](https://github.com/Mikubill/sd-webui-controlnet#more-control-modes-previously-called-guess-mode)

The changes to enable balanced, more_prompt, and more_control are
managed deeper in the code by two booleans, "soft_injection" and
"cfg_injection". The three control mode options in sd-webui-controlnet
map to these booleans like:
 
!soft_injection && !cfg_injection ⇒  BALANCED            
 soft_injection &&  cfg_injection ⇒  MORE_CONTROL 
 soft_injection && !cfg_injection ⇒  MORE_PROMPT   
 
The "unbalanced" option simply exposes the fourth possible combination
of these two booleans:
!soft_injection &&  cfg_injection ⇒ UNBALANCED

With "unbalanced" mode it is very easy to overdrive the controlnet
inputs. It's recommended to use a cfg_scale between 2 and 4 to mitigate
this, along with lowering controlnet weight and possibly lowering "end
step percent". With those caveats, "unbalanced" can yield interesting
results.

Example of all four modes using Canny edge detection ControlNet with
prompt "old man", identical params except for control_mode:

![Screenshot from 2023-06-11
23-53-00](https://github.com/invoke-ai/InvokeAI/assets/303100/c9e31e7f-50de-4d85-94f2-b5a4af3d067b)
Top middle:       BALANCED
Top right:          MORE_CONTROL
Bottom middle: MORE_PROMPT
Bottom right :    UNBALANCED

I kind of chose this seed because it shows pretty rough results with
BALANCED (the default), but in my opinion better results with both
MORE_CONTROL and MORE_PROMPT. And you can definitely see how MORE_PROMPT
pays more attention to the prompt, and MORE_CONTROL pays more attention
to the control image. And shows that UNBALANCED with default cfg_scale
etc is unusable.

But here are four examples from same series (same seed etc), all have
control_mode = UNBALANCED but now cfg_scale is set to 3.
![Screenshot from 2023-06-11
23-48-44](https://github.com/invoke-ai/InvokeAI/assets/303100/5a495306-2164-40aa-9cc8-ce737d7671e7)
And param differences are:
Top middle: prompt="old man", control_weight=0.3, end_step_percent=0.5
Top right: prompt="old man", control_weight=0.4, end_step_percent=1.0
Bottom middle: prompt=None, control_weight=0.3, end_step_percent=0.5
Bottom right: prompt=None, control_weight=0.4, end_step_percent=1.0

So with the right settings UNBALANCED seems useful.
This commit is contained in:
blessedcoolant 2023-06-25 16:09:26 +12:00 committed by GitHub
commit 922468b836
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 290 additions and 143 deletions

View File

@ -1,7 +1,7 @@
# InvokeAI nodes for ControlNet image preprocessors # InvokeAI nodes for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023 # initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import float from builtins import float, bool
import numpy as np import numpy as np
from typing import Literal, Optional, Union, List from typing import Literal, Optional, Union, List
@ -94,6 +94,7 @@ CONTROLNET_DEFAULT_MODELS = [
] ]
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)] CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
class ControlField(BaseModel): class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image") image: ImageField = Field(default=None, description="The control image")
@ -104,6 +105,8 @@ class ControlField(BaseModel):
description="When the ControlNet is first applied (% of total steps)") description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1, end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)") description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The contorl mode to use")
@validator("control_weight") @validator("control_weight")
def abs_le_one(cls, v): def abs_le_one(cls, v):
"""validate that all abs(values) are <=1""" """validate that all abs(values) are <=1"""
@ -144,11 +147,11 @@ class ControlNetInvocation(BaseInvocation):
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny", control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
description="control model used") description="control model used")
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet") control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
begin_step_percent: float = Field(default=0, ge=0, le=1, begin_step_percent: float = Field(default=0, ge=0, le=1,
description="When the ControlNet is first applied (% of total steps)") description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1, end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)") description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
# fmt: on # fmt: on
class Config(InvocationConfig): class Config(InvocationConfig):
@ -166,7 +169,6 @@ class ControlNetInvocation(BaseInvocation):
} }
def invoke(self, context: InvocationContext) -> ControlOutput: def invoke(self, context: InvocationContext) -> ControlOutput:
return ControlOutput( return ControlOutput(
control=ControlField( control=ControlField(
image=self.image, image=self.image,
@ -174,6 +176,7 @@ class ControlNetInvocation(BaseInvocation):
control_weight=self.control_weight, control_weight=self.control_weight,
begin_step_percent=self.begin_step_percent, begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent, end_step_percent=self.end_step_percent,
control_mode=self.control_mode,
), ),
) )

View File

@ -287,19 +287,14 @@ class TextToLatentsInvocation(BaseInvocation):
control_height_resize = latents_shape[2] * 8 control_height_resize = latents_shape[2] * 8
control_width_resize = latents_shape[3] * 8 control_width_resize = latents_shape[3] * 8
if control_input is None: if control_input is None:
# print("control input is None")
control_list = None control_list = None
elif isinstance(control_input, list) and len(control_input) == 0: elif isinstance(control_input, list) and len(control_input) == 0:
# print("control input is empty list")
control_list = None control_list = None
elif isinstance(control_input, ControlField): elif isinstance(control_input, ControlField):
# print("control input is ControlField")
control_list = [control_input] control_list = [control_input]
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField): elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
# print("control input is list[ControlField]")
control_list = control_input control_list = control_input
else: else:
# print("input control is unrecognized:", type(self.control))
control_list = None control_list = None
if (control_list is None): if (control_list is None):
control_data = None control_data = None
@ -341,12 +336,15 @@ class TextToLatentsInvocation(BaseInvocation):
# num_images_per_prompt=num_images_per_prompt, # num_images_per_prompt=num_images_per_prompt,
device=control_model.device, device=control_model.device,
dtype=control_model.dtype, dtype=control_model.dtype,
control_mode=control_info.control_mode,
) )
control_item = ControlNetData(model=control_model, control_item = ControlNetData(model=control_model,
image_tensor=control_image, image_tensor=control_image,
weight=control_info.control_weight, weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent, begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent) end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
)
control_data.append(control_item) control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData] # MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data return control_data

View File

@ -215,10 +215,12 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
@dataclass @dataclass
class ControlNetData: class ControlNetData:
model: ControlNetModel = Field(default=None) model: ControlNetModel = Field(default=None)
image_tensor: torch.Tensor= Field(default=None) image_tensor: torch.Tensor = Field(default=None)
weight: Union[float, List[float]]= Field(default=1.0) weight: Union[float, List[float]] = Field(default=1.0)
begin_step_percent: float = Field(default=0.0) begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0) end_step_percent: float = Field(default=1.0)
control_mode: str = Field(default="balanced")
@dataclass(frozen=True) @dataclass(frozen=True)
class ConditioningData: class ConditioningData:
@ -599,48 +601,68 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# TODO: should this scaling happen here or inside self._unet_forward? # TODO: should this scaling happen here or inside self._unet_forward?
# i.e. before or after passing it to InvokeAIDiffuserComponent # i.e. before or after passing it to InvokeAIDiffuserComponent
latent_model_input = self.scheduler.scale_model_input(latents, timestep) unet_latent_input = self.scheduler.scale_model_input(latents, timestep)
# default is no controlnet, so set controlnet processing output to None # default is no controlnet, so set controlnet processing output to None
down_block_res_samples, mid_block_res_sample = None, None down_block_res_samples, mid_block_res_sample = None, None
if control_data is not None: if control_data is not None:
# FIXME: make sure guidance_scale < 1.0 is handled correctly if doing per-step guidance setting
# if conditioning_data.guidance_scale > 1.0:
if conditioning_data.guidance_scale is not None:
# expand the latents input to control model if doing classifier free guidance
# (which I think for now is always true, there is conditional elsewhere that stops execution if
# classifier_free_guidance is <= 1.0 ?)
latent_control_input = torch.cat([latent_model_input] * 2)
else:
latent_control_input = latent_model_input
# control_data should be type List[ControlNetData] # control_data should be type List[ControlNetData]
# this loop covers both ControlNet (one ControlNetData in list) # this loop covers both ControlNet (one ControlNetData in list)
# and MultiControlNet (multiple ControlNetData in list) # and MultiControlNet (multiple ControlNetData in list)
for i, control_datum in enumerate(control_data): for i, control_datum in enumerate(control_data):
# print("controlnet", i, "==>", type(control_datum)) control_mode = control_datum.control_mode
# soft_injection and cfg_injection are the two ControlNet control_mode booleans
# that are combined at higher level to make control_mode enum
# soft_injection determines whether to do per-layer re-weighting adjustment (if True)
# or default weighting (if False)
soft_injection = (control_mode == "more_prompt" or control_mode == "more_control")
# cfg_injection = determines whether to apply ControlNet to only the conditional (if True)
# or the default both conditional and unconditional (if False)
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
first_control_step = math.floor(control_datum.begin_step_percent * total_step_count) first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count) last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
# only apply controlnet if current step is within the controlnet's begin/end step range # only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step: if step_index >= first_control_step and step_index <= last_control_step:
# print("running controlnet", i, "for step", step_index)
if cfg_injection:
control_latent_input = unet_latent_input
else:
# expand the latents input to control model if doing classifier free guidance
# (which I think for now is always true, there is conditional elsewhere that stops execution if
# classifier_free_guidance is <= 1.0 ?)
control_latent_input = torch.cat([unet_latent_input] * 2)
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings])
else:
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings])
if isinstance(control_datum.weight, list): if isinstance(control_datum.weight, list):
# if controlnet has multiple weights, use the weight for the current step # if controlnet has multiple weights, use the weight for the current step
controlnet_weight = control_datum.weight[step_index] controlnet_weight = control_datum.weight[step_index]
else: else:
# if controlnet has a single weight, use it for all steps # if controlnet has a single weight, use it for all steps
controlnet_weight = control_datum.weight controlnet_weight = control_datum.weight
# controlnet(s) inference
down_samples, mid_sample = control_datum.model( down_samples, mid_sample = control_datum.model(
sample=latent_control_input, sample=control_latent_input,
timestep=timestep, timestep=timestep,
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings, encoder_hidden_states=encoder_hidden_states,
conditioning_data.text_embeddings]),
controlnet_cond=control_datum.image_tensor, controlnet_cond=control_datum.image_tensor,
conditioning_scale=controlnet_weight, conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
# cross_attention_kwargs, guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
guess_mode=False,
return_dict=False, return_dict=False,
) )
if cfg_injection:
# Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
if down_block_res_samples is None and mid_block_res_sample is None: if down_block_res_samples is None and mid_block_res_sample is None:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else: else:
@ -653,11 +675,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# predict the noise residual # predict the noise residual
noise_pred = self.invokeai_diffuser.do_diffusion_step( noise_pred = self.invokeai_diffuser.do_diffusion_step(
latent_model_input, x=unet_latent_input,
t, sigma=t,
conditioning_data.unconditioned_embeddings, unconditioning=conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings, conditioning=conditioning_data.text_embeddings,
conditioning_data.guidance_scale, unconditional_guidance_scale=conditioning_data.guidance_scale,
step_index=step_index, step_index=step_index,
total_step_count=total_step_count, total_step_count=total_step_count,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s) down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
@ -962,6 +984,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
device="cuda", device="cuda",
dtype=torch.float16, dtype=torch.float16,
do_classifier_free_guidance=True, do_classifier_free_guidance=True,
control_mode="balanced"
): ):
if not isinstance(image, torch.Tensor): if not isinstance(image, torch.Tensor):
@ -992,6 +1015,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
repeat_by = num_images_per_prompt repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0) image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype) image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance: cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
if do_classifier_free_guidance and not cfg_injection:
image = torch.cat([image] * 2) image = torch.cat([image] * 2)
return image return image

View File

@ -23,7 +23,7 @@
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"", "dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"", "dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
"build": "yarn run lint && vite build", "build": "yarn run lint && vite build",
"typegen": "npx openapi-typescript http://localhost:9090/openapi.json --output src/services/schema.d.ts -t", "typegen": "npx openapi-typescript http://localhost:9090/openapi.json --output src/services/api/schema.d.ts -t",
"preview": "vite preview", "preview": "vite preview",
"lint:madge": "madge --circular src/main.tsx", "lint:madge": "madge --circular src/main.tsx",
"lint:eslint": "eslint --max-warnings=0 .", "lint:eslint": "eslint --max-warnings=0 .",

View File

@ -524,7 +524,8 @@
"initialImage": "Initial Image", "initialImage": "Initial Image",
"showOptionsPanel": "Show Options Panel", "showOptionsPanel": "Show Options Panel",
"hidePreview": "Hide Preview", "hidePreview": "Hide Preview",
"showPreview": "Show Preview" "showPreview": "Show Preview",
"controlNetControlMode": "Control Mode"
}, },
"settings": { "settings": {
"models": "Models", "models": "Models",

View File

@ -1,26 +1,27 @@
import { Box, ChakraProps, Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { FaCopy, FaTrash } from 'react-icons/fa';
import { import {
ControlNetConfig, ControlNetConfig,
controlNetAdded, controlNetAdded,
controlNetRemoved, controlNetRemoved,
controlNetToggled, controlNetToggled,
} from '../store/controlNetSlice'; } from '../store/controlNetSlice';
import { useAppDispatch } from 'app/store/storeHooks';
import ParamControlNetModel from './parameters/ParamControlNetModel'; import ParamControlNetModel from './parameters/ParamControlNetModel';
import ParamControlNetWeight from './parameters/ParamControlNetWeight'; import ParamControlNetWeight from './parameters/ParamControlNetWeight';
import { Flex, Box, ChakraProps } from '@chakra-ui/react';
import { FaCopy, FaTrash } from 'react-icons/fa';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ControlNetImagePreview from './ControlNetImagePreview';
import IAIIconButton from 'common/components/IAIIconButton';
import { v4 as uuidv4 } from 'uuid';
import { useToggle } from 'react-use';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
import ControlNetProcessorComponent from './ControlNetProcessorComponent';
import IAISwitch from 'common/components/IAISwitch';
import { ChevronUpIcon } from '@chakra-ui/icons'; import { ChevronUpIcon } from '@chakra-ui/icons';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch';
import { useToggle } from 'react-use';
import { v4 as uuidv4 } from 'uuid';
import ControlNetImagePreview from './ControlNetImagePreview';
import ControlNetProcessorComponent from './ControlNetProcessorComponent';
import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig'; import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
const expandedControlImageSx: ChakraProps['sx'] = { maxH: 96 }; const expandedControlImageSx: ChakraProps['sx'] = { maxH: 96 };
@ -36,6 +37,7 @@ const ControlNet = (props: ControlNetProps) => {
weight, weight,
beginStepPct, beginStepPct,
endStepPct, endStepPct,
controlMode,
controlImage, controlImage,
processedControlImage, processedControlImage,
processorNode, processorNode,
@ -137,13 +139,13 @@ const ControlNet = (props: ControlNetProps) => {
</Flex> </Flex>
{isEnabled && ( {isEnabled && (
<> <>
<Flex sx={{ w: 'full', flexDirection: 'column' }}>
<Flex sx={{ gap: 4, w: 'full' }}> <Flex sx={{ gap: 4, w: 'full' }}>
<Flex <Flex
sx={{ sx={{
flexDir: 'column', flexDir: 'column',
gap: 2, gap: 3,
w: 'full', w: 'full',
h: isExpanded ? 28 : 24,
paddingInlineStart: 1, paddingInlineStart: 1,
paddingInlineEnd: isExpanded ? 1 : 0, paddingInlineEnd: isExpanded ? 1 : 0,
pb: 2, pb: 2,
@ -172,13 +174,16 @@ const ControlNet = (props: ControlNetProps) => {
aspectRatio: '1/1', aspectRatio: '1/1',
}} }}
> >
<ControlNetImagePreview <ControlNetImagePreview controlNet={props.controlNet} />
controlNet={props.controlNet}
height={24}
/>
</Flex> </Flex>
)} )}
</Flex> </Flex>
<ParamControlNetControlMode
controlNetId={controlNetId}
controlMode={controlMode}
/>
</Flex>
{isExpanded && ( {isExpanded && (
<> <>
<Box mt={2}> <Box mt={2}>

View File

@ -0,0 +1,45 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
ControlModes,
controlNetControlModeChanged,
} from 'features/controlNet/store/controlNetSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type ParamControlNetControlModeProps = {
controlNetId: string;
controlMode: string;
};
const CONTROL_MODE_DATA = [
{ label: 'Balanced', value: 'balanced' },
{ label: 'Prompt', value: 'more_prompt' },
{ label: 'Control', value: 'more_control' },
{ label: 'Mega Control', value: 'unbalanced' },
];
export default function ParamControlNetControlMode(
props: ParamControlNetControlModeProps
) {
const { controlNetId, controlMode = false } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleControlModeChange = useCallback(
(controlMode: ControlModes) => {
dispatch(controlNetControlModeChanged({ controlNetId, controlMode }));
},
[controlNetId, dispatch]
);
return (
<IAIMantineSelect
label={t('parameters.controlNetControlMode')}
data={CONTROL_MODE_DATA}
value={String(controlMode)}
onChange={handleControlModeChange}
/>
);
}

View File

@ -1,6 +1,5 @@
import { import {
ControlNetProcessorType, ControlNetProcessorType,
RequiredCannyImageProcessorInvocation,
RequiredControlNetProcessorNode, RequiredControlNetProcessorNode,
} from './types'; } from './types';
@ -23,7 +22,7 @@ type ControlNetProcessorsDict = Record<
* *
* TODO: Generate from the OpenAPI schema * TODO: Generate from the OpenAPI schema
*/ */
export const CONTROLNET_PROCESSORS = { export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
none: { none: {
type: 'none', type: 'none',
label: 'none', label: 'none',
@ -174,6 +173,8 @@ export const CONTROLNET_PROCESSORS = {
}, },
}; };
type ControlNetModelsDict = Record<string, ControlNetModel>;
type ControlNetModel = { type ControlNetModel = {
type: string; type: string;
label: string; label: string;
@ -181,7 +182,7 @@ type ControlNetModel = {
defaultProcessor?: ControlNetProcessorType; defaultProcessor?: ControlNetProcessorType;
}; };
export const CONTROLNET_MODELS = { export const CONTROLNET_MODELS: ControlNetModelsDict = {
'lllyasviel/control_v11p_sd15_canny': { 'lllyasviel/control_v11p_sd15_canny': {
type: 'lllyasviel/control_v11p_sd15_canny', type: 'lllyasviel/control_v11p_sd15_canny',
label: 'Canny', label: 'Canny',
@ -190,6 +191,7 @@ export const CONTROLNET_MODELS = {
'lllyasviel/control_v11p_sd15_inpaint': { 'lllyasviel/control_v11p_sd15_inpaint': {
type: 'lllyasviel/control_v11p_sd15_inpaint', type: 'lllyasviel/control_v11p_sd15_inpaint',
label: 'Inpaint', label: 'Inpaint',
defaultProcessor: 'none',
}, },
'lllyasviel/control_v11p_sd15_mlsd': { 'lllyasviel/control_v11p_sd15_mlsd': {
type: 'lllyasviel/control_v11p_sd15_mlsd', type: 'lllyasviel/control_v11p_sd15_mlsd',
@ -209,6 +211,7 @@ export const CONTROLNET_MODELS = {
'lllyasviel/control_v11p_sd15_seg': { 'lllyasviel/control_v11p_sd15_seg': {
type: 'lllyasviel/control_v11p_sd15_seg', type: 'lllyasviel/control_v11p_sd15_seg',
label: 'Segmentation', label: 'Segmentation',
defaultProcessor: 'none',
}, },
'lllyasviel/control_v11p_sd15_lineart': { 'lllyasviel/control_v11p_sd15_lineart': {
type: 'lllyasviel/control_v11p_sd15_lineart', type: 'lllyasviel/control_v11p_sd15_lineart',
@ -223,6 +226,7 @@ export const CONTROLNET_MODELS = {
'lllyasviel/control_v11p_sd15_scribble': { 'lllyasviel/control_v11p_sd15_scribble': {
type: 'lllyasviel/control_v11p_sd15_scribble', type: 'lllyasviel/control_v11p_sd15_scribble',
label: 'Scribble', label: 'Scribble',
defaultProcessor: 'none',
}, },
'lllyasviel/control_v11p_sd15_softedge': { 'lllyasviel/control_v11p_sd15_softedge': {
type: 'lllyasviel/control_v11p_sd15_softedge', type: 'lllyasviel/control_v11p_sd15_softedge',
@ -242,10 +246,12 @@ export const CONTROLNET_MODELS = {
'lllyasviel/control_v11f1e_sd15_tile': { 'lllyasviel/control_v11f1e_sd15_tile': {
type: 'lllyasviel/control_v11f1e_sd15_tile', type: 'lllyasviel/control_v11f1e_sd15_tile',
label: 'Tile (experimental)', label: 'Tile (experimental)',
defaultProcessor: 'none',
}, },
'lllyasviel/control_v11e_sd15_ip2p': { 'lllyasviel/control_v11e_sd15_ip2p': {
type: 'lllyasviel/control_v11e_sd15_ip2p', type: 'lllyasviel/control_v11e_sd15_ip2p',
label: 'Pix2Pix (experimental)', label: 'Pix2Pix (experimental)',
defaultProcessor: 'none',
}, },
'CrucibleAI/ControlNetMediaPipeFace': { 'CrucibleAI/ControlNetMediaPipeFace': {
type: 'CrucibleAI/ControlNetMediaPipeFace', type: 'CrucibleAI/ControlNetMediaPipeFace',

View File

@ -18,12 +18,19 @@ import { forEach } from 'lodash-es';
import { isAnySessionRejected } from 'services/api/thunks/session'; import { isAnySessionRejected } from 'services/api/thunks/session';
import { appSocketInvocationError } from 'services/events/actions'; import { appSocketInvocationError } from 'services/events/actions';
export type ControlModes =
| 'balanced'
| 'more_prompt'
| 'more_control'
| 'unbalanced';
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = { export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true, isEnabled: true,
model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type, model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type,
weight: 1, weight: 1,
beginStepPct: 0, beginStepPct: 0,
endStepPct: 1, endStepPct: 1,
controlMode: 'balanced',
controlImage: null, controlImage: null,
processedControlImage: null, processedControlImage: null,
processorType: 'canny_image_processor', processorType: 'canny_image_processor',
@ -39,6 +46,7 @@ export type ControlNetConfig = {
weight: number; weight: number;
beginStepPct: number; beginStepPct: number;
endStepPct: number; endStepPct: number;
controlMode: ControlModes;
controlImage: string | null; controlImage: string | null;
processedControlImage: string | null; processedControlImage: string | null;
processorType: ControlNetProcessorType; processorType: ControlNetProcessorType;
@ -181,6 +189,13 @@ export const controlNetSlice = createSlice({
const { controlNetId, endStepPct } = action.payload; const { controlNetId, endStepPct } = action.payload;
state.controlNets[controlNetId].endStepPct = endStepPct; state.controlNets[controlNetId].endStepPct = endStepPct;
}, },
controlNetControlModeChanged: (
state,
action: PayloadAction<{ controlNetId: string; controlMode: ControlModes }>
) => {
const { controlNetId, controlMode } = action.payload;
state.controlNets[controlNetId].controlMode = controlMode;
},
controlNetProcessorParamsChanged: ( controlNetProcessorParamsChanged: (
state, state,
action: PayloadAction<{ action: PayloadAction<{
@ -307,6 +322,7 @@ export const {
controlNetWeightChanged, controlNetWeightChanged,
controlNetBeginStepPctChanged, controlNetBeginStepPctChanged,
controlNetEndStepPctChanged, controlNetEndStepPctChanged,
controlNetControlModeChanged,
controlNetProcessorParamsChanged, controlNetProcessorParamsChanged,
controlNetProcessorTypeChanged, controlNetProcessorTypeChanged,
controlNetReset, controlNetReset,

View File

@ -44,6 +44,7 @@ export const addControlNetToLinearGraph = (
processedControlImage, processedControlImage,
beginStepPct, beginStepPct,
endStepPct, endStepPct,
controlMode,
model, model,
processorType, processorType,
weight, weight,
@ -59,6 +60,7 @@ export const addControlNetToLinearGraph = (
type: 'controlnet', type: 'controlnet',
begin_step_percent: beginStepPct, begin_step_percent: beginStepPct,
end_step_percent: endStepPct, end_step_percent: endStepPct,
control_mode: controlMode,
control_model: model as ControlNetInvocation['control_model'], control_model: model as ControlNetInvocation['control_model'],
control_weight: weight, control_weight: weight,
}; };

View File

@ -0,0 +1,36 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ImageField } from './ImageField';
/**
* Applies HED edge detection to image
*/
export type HedImageProcessorInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'hed_image_processor';
/**
* The image to process
*/
image?: ImageField;
/**
* The pixel resolution for detection
*/
detect_resolution?: number;
/**
* The pixel resolution for the output image
*/
image_resolution?: number;
/**
* Whether to use scribble mode
*/
scribble?: boolean;
};

View File

@ -648,6 +648,13 @@ export type components = {
* @default 1 * @default 1
*/ */
end_step_percent: number; end_step_percent: number;
/**
* Control Mode
* @description The contorl mode to use
* @default balanced
* @enum {string}
*/
control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced";
}; };
/** /**
* ControlNetInvocation * ControlNetInvocation
@ -701,6 +708,13 @@ export type components = {
* @default 1 * @default 1
*/ */
end_step_percent?: number; end_step_percent?: number;
/**
* Control Mode
* @description The control mode used
* @default balanced
* @enum {string}
*/
control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced";
}; };
/** ControlNetModelConfig */ /** ControlNetModelConfig */
ControlNetModelConfig: { ControlNetModelConfig: {
@ -2903,7 +2917,7 @@ export type components = {
/** ModelsList */ /** ModelsList */
ModelsList: { ModelsList: {
/** Models */ /** Models */
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[]; models: (components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[];
}; };
/** /**
* MultiplyInvocation * MultiplyInvocation
@ -4163,18 +4177,18 @@ export type components = {
*/ */
image?: components["schemas"]["ImageField"]; image?: components["schemas"]["ImageField"];
}; };
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/** /**
* StableDiffusion2ModelFormat * StableDiffusion2ModelFormat
* @description An enumeration. * @description An enumeration.
* @enum {string} * @enum {string}
*/ */
StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
}; };
responses: never; responses: never;
parameters: never; parameters: never;

View File

@ -1,81 +1,78 @@
import { components } from './schema'; import { components } from './schema';
type schemas = components['schemas'];
/** /**
* Types from the API, re-exported from the types generated by `openapi-typescript`. * Types from the API, re-exported from the types generated by `openapi-typescript`.
*/ */
// Images // Images
export type ImageDTO = components['schemas']['ImageDTO']; export type ImageDTO = schemas['ImageDTO'];
export type BoardDTO = components['schemas']['BoardDTO']; export type BoardDTO = schemas['BoardDTO'];
export type BoardChanges = components['schemas']['BoardChanges']; export type BoardChanges = schemas['BoardChanges'];
export type ImageChanges = components['schemas']['ImageRecordChanges']; export type ImageChanges = schemas['ImageRecordChanges'];
export type ImageCategory = components['schemas']['ImageCategory']; export type ImageCategory = schemas['ImageCategory'];
export type ResourceOrigin = components['schemas']['ResourceOrigin']; export type ResourceOrigin = schemas['ResourceOrigin'];
export type ImageField = components['schemas']['ImageField']; export type ImageField = schemas['ImageField'];
export type OffsetPaginatedResults_BoardDTO_ = export type OffsetPaginatedResults_BoardDTO_ =
components['schemas']['OffsetPaginatedResults_BoardDTO_']; schemas['OffsetPaginatedResults_BoardDTO_'];
export type OffsetPaginatedResults_ImageDTO_ = export type OffsetPaginatedResults_ImageDTO_ =
components['schemas']['OffsetPaginatedResults_ImageDTO_']; schemas['OffsetPaginatedResults_ImageDTO_'];
// Models // Models
export type ModelType = components['schemas']['ModelType']; export type ModelType = schemas['ModelType'];
export type BaseModelType = components['schemas']['BaseModelType']; export type BaseModelType = schemas['BaseModelType'];
export type PipelineModelField = components['schemas']['PipelineModelField']; export type PipelineModelField = schemas['PipelineModelField'];
export type ModelsList = components['schemas']['ModelsList']; export type ModelsList = schemas['ModelsList'];
// Graphs // Graphs
export type Graph = components['schemas']['Graph']; export type Graph = schemas['Graph'];
export type Edge = components['schemas']['Edge']; export type Edge = schemas['Edge'];
export type GraphExecutionState = components['schemas']['GraphExecutionState']; export type GraphExecutionState = schemas['GraphExecutionState'];
// General nodes // General nodes
export type CollectInvocation = components['schemas']['CollectInvocation']; export type CollectInvocation = schemas['CollectInvocation'];
export type IterateInvocation = components['schemas']['IterateInvocation']; export type IterateInvocation = schemas['IterateInvocation'];
export type RangeInvocation = components['schemas']['RangeInvocation']; export type RangeInvocation = schemas['RangeInvocation'];
export type RandomRangeInvocation = export type RandomRangeInvocation = schemas['RandomRangeInvocation'];
components['schemas']['RandomRangeInvocation']; export type RangeOfSizeInvocation = schemas['RangeOfSizeInvocation'];
export type RangeOfSizeInvocation = export type InpaintInvocation = schemas['InpaintInvocation'];
components['schemas']['RangeOfSizeInvocation']; export type ImageResizeInvocation = schemas['ImageResizeInvocation'];
export type InpaintInvocation = components['schemas']['InpaintInvocation']; export type RandomIntInvocation = schemas['RandomIntInvocation'];
export type ImageResizeInvocation = export type CompelInvocation = schemas['CompelInvocation'];
components['schemas']['ImageResizeInvocation'];
export type RandomIntInvocation = components['schemas']['RandomIntInvocation'];
export type CompelInvocation = components['schemas']['CompelInvocation'];
// ControlNet Nodes // ControlNet Nodes
export type ControlNetInvocation = schemas['ControlNetInvocation'];
export type CannyImageProcessorInvocation = export type CannyImageProcessorInvocation =
components['schemas']['CannyImageProcessorInvocation']; schemas['CannyImageProcessorInvocation'];
export type ContentShuffleImageProcessorInvocation = export type ContentShuffleImageProcessorInvocation =
components['schemas']['ContentShuffleImageProcessorInvocation']; schemas['ContentShuffleImageProcessorInvocation'];
export type HedImageProcessorInvocation = export type HedImageProcessorInvocation =
components['schemas']['HedImageProcessorInvocation']; schemas['HedImageProcessorInvocation'];
export type LineartAnimeImageProcessorInvocation = export type LineartAnimeImageProcessorInvocation =
components['schemas']['LineartAnimeImageProcessorInvocation']; schemas['LineartAnimeImageProcessorInvocation'];
export type LineartImageProcessorInvocation = export type LineartImageProcessorInvocation =
components['schemas']['LineartImageProcessorInvocation']; schemas['LineartImageProcessorInvocation'];
export type MediapipeFaceProcessorInvocation = export type MediapipeFaceProcessorInvocation =
components['schemas']['MediapipeFaceProcessorInvocation']; schemas['MediapipeFaceProcessorInvocation'];
export type MidasDepthImageProcessorInvocation = export type MidasDepthImageProcessorInvocation =
components['schemas']['MidasDepthImageProcessorInvocation']; schemas['MidasDepthImageProcessorInvocation'];
export type MlsdImageProcessorInvocation = export type MlsdImageProcessorInvocation =
components['schemas']['MlsdImageProcessorInvocation']; schemas['MlsdImageProcessorInvocation'];
export type NormalbaeImageProcessorInvocation = export type NormalbaeImageProcessorInvocation =
components['schemas']['NormalbaeImageProcessorInvocation']; schemas['NormalbaeImageProcessorInvocation'];
export type OpenposeImageProcessorInvocation = export type OpenposeImageProcessorInvocation =
components['schemas']['OpenposeImageProcessorInvocation']; schemas['OpenposeImageProcessorInvocation'];
export type PidiImageProcessorInvocation = export type PidiImageProcessorInvocation =
components['schemas']['PidiImageProcessorInvocation']; schemas['PidiImageProcessorInvocation'];
export type ZoeDepthImageProcessorInvocation = export type ZoeDepthImageProcessorInvocation =
components['schemas']['ZoeDepthImageProcessorInvocation']; schemas['ZoeDepthImageProcessorInvocation'];
// Node Outputs // Node Outputs
export type ImageOutput = components['schemas']['ImageOutput']; export type ImageOutput = schemas['ImageOutput'];
export type MaskOutput = components['schemas']['MaskOutput']; export type MaskOutput = schemas['MaskOutput'];
export type PromptOutput = components['schemas']['PromptOutput']; export type PromptOutput = schemas['PromptOutput'];
export type IterateInvocationOutput = export type IterateInvocationOutput = schemas['IterateInvocationOutput'];
components['schemas']['IterateInvocationOutput']; export type CollectInvocationOutput = schemas['CollectInvocationOutput'];
export type CollectInvocationOutput = export type LatentsOutput = schemas['LatentsOutput'];
components['schemas']['CollectInvocationOutput']; export type GraphInvocationOutput = schemas['GraphInvocationOutput'];
export type LatentsOutput = components['schemas']['LatentsOutput'];
export type GraphInvocationOutput =
components['schemas']['GraphInvocationOutput'];