Rewrite controlnet to new model manager (#3665)

This commit is contained in:
Lincoln Stein 2023-07-15 08:24:06 -04:00 committed by GitHub
commit 07a2da40b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
58 changed files with 1095 additions and 625 deletions

View File

@ -9,6 +9,7 @@ from typing import Literal, Optional, Union, List, Dict
from PIL import Image
from pydantic import BaseModel, Field, validator
from ...backend.model_management import BaseModelType, ModelType
from ..models.image import ImageField, ImageCategory, ResourceOrigin
from .baseinvocation import (
BaseInvocation,
@ -105,9 +106,15 @@ CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control
# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
class ControlNetModelField(BaseModel):
"""ControlNet model field"""
model_name: str = Field(description="Name of the ControlNet model")
base_model: BaseModelType = Field(description="Base model")
class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image")
control_model: Optional[str] = Field(default=None, description="The ControlNet model to use")
control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(default=0, ge=0, le=1,
@ -118,15 +125,15 @@ class ControlField(BaseModel):
# resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@validator("control_weight")
def abs_le_one(cls, v):
"""validate that all abs(values) are <=1"""
def validate_control_weight(cls, v):
"""Validate that all control weights in the valid range"""
if isinstance(v, list):
for i in v:
if abs(i) > 1:
raise ValueError('all abs(control_weight) must be <= 1')
if i < -1 or i > 2:
raise ValueError('Control weights must be within -1 to 2 range')
else:
if abs(v) > 1:
raise ValueError('abs(control_weight) must be <= 1')
if v < -1 or v > 2:
raise ValueError('Control weights must be within -1 to 2 range')
return v
class Config:
schema_extra = {
@ -134,6 +141,7 @@ class ControlField(BaseModel):
"ui": {
"type_hints": {
"control_weight": "float",
"control_model": "controlnet_model",
# "control_weight": "number",
}
}
@ -154,10 +162,10 @@ class ControlNetInvocation(BaseInvocation):
type: Literal["controlnet"] = "controlnet"
# Inputs
image: ImageField = Field(default=None, description="The control image")
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny",
description="control model used")
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
begin_step_percent: float = Field(default=0, ge=0, le=1,
begin_step_percent: float = Field(default=0, ge=-1, le=2,
description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)")

View File

@ -1,5 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from contextlib import ExitStack
from typing import List, Literal, Optional, Union
import einops
@ -11,6 +12,7 @@ from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models.base import ModelType
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
@ -71,16 +73,21 @@ def get_scheduler(
scheduler_name: str,
) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
scheduler_name, SCHEDULER_MAP['ddim'])
scheduler_name, SCHEDULER_MAP['ddim']
)
orig_scheduler_info = context.services.model_manager.get_model(
**scheduler_info.dict())
**scheduler_info.dict()
)
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, **
scheduler_extra_config, "_backup": scheduler_config}
scheduler_config = {
**scheduler_config,
**scheduler_extra_config,
"_backup": scheduler_config,
}
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py
@ -137,8 +144,11 @@ class TextToLatentsInvocation(BaseInvocation):
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, source_node_id: str,
intermediate_state: PipelineIntermediateState) -> None:
self,
context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
@ -147,11 +157,16 @@ class TextToLatentsInvocation(BaseInvocation):
)
def get_conditioning_data(
self, context: InvocationContext, scheduler) -> ConditioningData:
self,
context: InvocationContext,
scheduler,
) -> ConditioningData:
c, extra_conditioning_info = context.services.latents.get(
self.positive_conditioning.conditioning_name)
self.positive_conditioning.conditioning_name
)
uc, _ = context.services.latents.get(
self.negative_conditioning.conditioning_name)
self.negative_conditioning.conditioning_name
)
conditioning_data = ConditioningData(
unconditioned_embeddings=uc,
@ -178,7 +193,10 @@ class TextToLatentsInvocation(BaseInvocation):
return conditioning_data
def create_pipeline(
self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
self,
unet,
scheduler,
) -> StableDiffusionGeneratorPipeline:
# TODO:
# configure_model_padding(
# unet,
@ -213,6 +231,7 @@ class TextToLatentsInvocation(BaseInvocation):
model: StableDiffusionGeneratorPipeline,
control_input: List[ControlField],
latents_shape: List[int],
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True,
) -> List[ControlNetData]:
@ -238,25 +257,19 @@ class TextToLatentsInvocation(BaseInvocation):
control_data = []
control_models = []
for control_info in control_list:
# handle control models
if ("," in control_info.control_model):
control_model_split = control_info.control_model.split(",")
control_name = control_model_split[0]
control_subfolder = control_model_split[1]
print("Using HF model subfolders")
print(" control_name: ", control_name)
print(" control_subfolder: ", control_subfolder)
control_model = ControlNetModel.from_pretrained(
control_name, subfolder=control_subfolder,
torch_dtype=model.unet.dtype).to(
model.device)
else:
control_model = ControlNetModel.from_pretrained(
control_info.control_model, torch_dtype=model.unet.dtype).to(model.device)
control_model = exit_stack.enter_context(
context.services.model_manager.get_model(
model_name=control_info.control_model.model_name,
model_type=ModelType.ControlNet,
base_model=control_info.control_model.base_model,
)
)
control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(
control_image_field.image_name)
control_image_field.image_name
)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
@ -278,7 +291,8 @@ class TextToLatentsInvocation(BaseInvocation):
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,)
control_mode=control_info.control_mode,
)
control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data
@ -289,7 +303,8 @@ class TextToLatentsInvocation(BaseInvocation):
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id)
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState):
@ -298,14 +313,17 @@ class TextToLatentsInvocation(BaseInvocation):
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
**lora.dict(exclude={"weight"})
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict())
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
**self.unet.unet.dict()
)
with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:
scheduler = get_scheduler(
@ -322,6 +340,7 @@ class TextToLatentsInvocation(BaseInvocation):
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
# TODO: Verify the noise is the right size
@ -374,7 +393,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id)
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState):
@ -383,14 +403,17 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
**lora.dict(exclude={"weight"})
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict())
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
**self.unet.unet.dict()
)
with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:
scheduler = get_scheduler(
@ -407,11 +430,13 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
# TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=unet.device, dtype=latent.dtype)
latent, device=unet.device, dtype=latent.dtype
)
timesteps, _ = pipeline.get_img2img_timesteps(
self.steps,
@ -535,7 +560,8 @@ class ResizeLatentsInvocation(BaseInvocation):
resized_latents = torch.nn.functional.interpolate(
latents, size=(self.height // 8, self.width // 8),
mode=self.mode, antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,)
if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
@ -569,7 +595,8 @@ class ScaleLatentsInvocation(BaseInvocation):
resized_latents = torch.nn.functional.interpolate(
latents, scale_factor=self.scale_factor, mode=self.mode,
antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,)
if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()

View File

@ -13,7 +13,11 @@ import { RootState } from 'app/store/store';
const moduleLog = log.child({ namespace: 'controlNet' });
const predicate: AnyListenerPredicate<RootState> = (action, state) => {
const predicate: AnyListenerPredicate<RootState> = (
action,
state,
prevState
) => {
const isActionMatched =
controlNetProcessorParamsChanged.match(action) ||
controlNetModelChanged.match(action) ||
@ -25,6 +29,16 @@ const predicate: AnyListenerPredicate<RootState> = (action, state) => {
return false;
}
if (controlNetAutoConfigToggled.match(action)) {
// do not process if the user just disabled auto-config
if (
prevState.controlNet.controlNets[action.payload.controlNetId]
.shouldAutoConfig === true
) {
return false;
}
}
const { controlImage, processorType, shouldAutoConfig } =
state.controlNet.controlNets[action.payload.controlNetId];

View File

@ -10,6 +10,7 @@ import { zMainModel } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { forEach } from 'lodash-es';
import { startAppListening } from '..';
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
const moduleLog = log.child({ module: 'models' });
@ -51,7 +52,14 @@ export const addModelSelectedListener = () => {
modelsCleared += 1;
}
// TODO: handle incompatible controlnet; pending model manager support
const { controlNets } = state.controlNet;
forEach(controlNets, (controlNet, controlNetId) => {
if (controlNet.model?.base_model !== base_model) {
dispatch(controlNetRemoved({ controlNetId }));
modelsCleared += 1;
}
});
if (modelsCleared > 0) {
dispatch(
addToast(

View File

@ -11,6 +11,7 @@ import {
import { forEach, some } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models';
import { startAppListening } from '..';
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
const moduleLog = log.child({ module: 'models' });
@ -127,7 +128,22 @@ export const addModelsLoadedListener = () => {
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// ControlNet models loaded - need to remove missing ControlNets from state
// TODO: pending model manager controlnet support
const controlNets = getState().controlNet.controlNets;
forEach(controlNets, (controlNet, controlNetId) => {
const isControlNetAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === controlNet?.model?.model_name &&
m?.base_model === controlNet?.model?.base_model
);
if (isControlNetAvailable) {
return;
}
dispatch(controlNetRemoved({ controlNetId }));
});
},
});
};

View File

@ -1,5 +1,5 @@
import {
CONTROLNET_MODELS,
// CONTROLNET_MODELS,
CONTROLNET_PROCESSORS,
} from 'features/controlNet/store/constants';
import { InvokeTabName } from 'features/ui/store/tabMap';
@ -128,7 +128,7 @@ export type AppConfig = {
canRestoreDeletedImagesFromBin: boolean;
sd: {
defaultModel?: string;
disabledControlNetModels: (keyof typeof CONTROLNET_MODELS)[];
disabledControlNetModels: string[];
disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[];
iterations: {
initial: number;

View File

@ -170,12 +170,14 @@ const IAIDndImage = (props: IAIDndImageProps) => {
</>
)}
{!imageDTO && isUploadDisabled && noContentFallback}
{!isDropDisabled && (
<IAIDroppable
data={droppableData}
disabled={isDropDisabled}
dropLabel={dropLabel}
/>
{imageDTO && (
)}
{imageDTO && !isDragDisabled && (
<IAIDraggable
data={draggableData}
disabled={isDragDisabled || !imageDTO}

View File

@ -1,17 +1,25 @@
import { Tooltip } from '@chakra-ui/react';
import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { MultiSelect, MultiSelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { useMantineMultiSelectStyles } from 'mantine-theme/hooks/useMantineMultiSelectStyles';
import { KeyboardEvent, RefObject, memo, useCallback } from 'react';
type IAIMultiSelectProps = MultiSelectProps & {
type IAIMultiSelectProps = Omit<MultiSelectProps, 'label'> & {
tooltip?: string;
inputRef?: RefObject<HTMLInputElement>;
label?: string;
};
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
const { searchable = true, tooltip, inputRef, ...rest } = props;
const {
searchable = true,
tooltip,
inputRef,
label,
disabled,
...rest
} = props;
const dispatch = useAppDispatch();
const handleKeyDown = useCallback(
@ -37,7 +45,15 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
return (
<Tooltip label={tooltip} placement="top" hasArrow isOpen={true}>
<MultiSelect
label={
label ? (
<FormControl isDisabled={disabled}>
<FormLabel>{label}</FormLabel>
</FormControl>
) : undefined
}
ref={inputRef}
disabled={disabled}
onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp}
searchable={searchable}

View File

@ -1,4 +1,4 @@
import { Tooltip } from '@chakra-ui/react';
import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
@ -11,13 +11,22 @@ export type IAISelectDataType = {
tooltip?: string;
};
type IAISelectProps = SelectProps & {
type IAISelectProps = Omit<SelectProps, 'label'> & {
tooltip?: string;
label?: string;
inputRef?: RefObject<HTMLInputElement>;
};
const IAIMantineSearchableSelect = (props: IAISelectProps) => {
const { searchable = true, tooltip, inputRef, onChange, ...rest } = props;
const {
searchable = true,
tooltip,
inputRef,
onChange,
label,
disabled,
...rest
} = props;
const dispatch = useAppDispatch();
const [searchValue, setSearchValue] = useState('');
@ -61,6 +70,14 @@ const IAIMantineSearchableSelect = (props: IAISelectProps) => {
<Tooltip label={tooltip} placement="top" hasArrow>
<Select
ref={inputRef}
label={
label ? (
<FormControl isDisabled={disabled}>
<FormLabel>{label}</FormLabel>
</FormControl>
) : undefined
}
disabled={disabled}
searchValue={searchValue}
onSearchChange={setSearchValue}
onChange={handleChange}

View File

@ -1,4 +1,4 @@
import { Tooltip } from '@chakra-ui/react';
import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core';
import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles';
import { RefObject, memo } from 'react';
@ -9,19 +9,32 @@ export type IAISelectDataType = {
tooltip?: string;
};
type IAISelectProps = SelectProps & {
type IAISelectProps = Omit<SelectProps, 'label'> & {
tooltip?: string;
inputRef?: RefObject<HTMLInputElement>;
label?: string;
};
const IAIMantineSelect = (props: IAISelectProps) => {
const { tooltip, inputRef, ...rest } = props;
const { tooltip, inputRef, label, disabled, ...rest } = props;
const styles = useMantineSelectStyles();
return (
<Tooltip label={tooltip} placement="top" hasArrow>
<Select ref={inputRef} styles={styles} {...rest} />
<Select
label={
label ? (
<FormControl isDisabled={disabled}>
<FormLabel>{label}</FormLabel>
</FormControl>
) : undefined
}
disabled={disabled}
ref={inputRef}
styles={styles}
{...rest}
/>
</Tooltip>
);
};

View File

@ -43,11 +43,6 @@ import { useTranslation } from 'react-i18next';
import { BiReset } from 'react-icons/bi';
import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton';
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
mt: 1.5,
fontSize: '2xs',
};
export type IAIFullSliderProps = {
label?: string;
value: number;
@ -207,7 +202,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
{...sliderFormControlProps}
>
{label && (
<FormLabel {...sliderFormLabelProps} mb={-1}>
<FormLabel sx={withInput ? { mb: -1.5 } : {}} {...sliderFormLabelProps}>
{label}
</FormLabel>
)}
@ -233,7 +228,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
...SLIDER_MARK_STYLES,
}}
{...sliderMarkProps}
>
@ -244,7 +238,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
...SLIDER_MARK_STYLES,
}}
{...sliderMarkProps}
>
@ -263,7 +256,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
...SLIDER_MARK_STYLES,
}}
{...sliderMarkProps}
>
@ -278,7 +270,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
...SLIDER_MARK_STYLES,
}}
{...sliderMarkProps}
>
@ -291,7 +282,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
key={m}
value={m}
sx={{
...SLIDER_MARK_STYLES,
transform: 'translateX(-50%)',
}}
{...sliderMarkProps}
>

View File

@ -5,6 +5,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { validateSeedWeights } from 'common/util/seedWeightPairs';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { modelsApi } from '../../services/api/endpoints/models';
import { forEach } from 'lodash-es';
const readinessSelector = createSelector(
[stateSelector, activeTabNameSelector],
@ -52,6 +53,13 @@ const readinessSelector = createSelector(
reasonsWhyNotReady.push('Seed-Weights badly formatted.');
}
forEach(state.controlNet.controlNets, (controlNet, id) => {
if (!controlNet.model) {
isReady = false;
reasonsWhyNotReady.push('ControlNet ${id} has no model selected.');
}
});
// All good
return { isReady, reasonsWhyNotReady };
},

View File

@ -1,10 +1,9 @@
import { Box, ChakraProps, Flex, useColorMode } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { Box, Flex } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { memo, useCallback } from 'react';
import { FaCopy, FaTrash } from 'react-icons/fa';
import {
ControlNetConfig,
controlNetAdded,
controlNetDuplicated,
controlNetRemoved,
controlNetToggled,
} from '../store/controlNetSlice';
@ -12,6 +11,9 @@ import ParamControlNetModel from './parameters/ParamControlNetModel';
import ParamControlNetWeight from './parameters/ParamControlNetWeight';
import { ChevronUpIcon } from '@chakra-ui/icons';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch';
import { useToggle } from 'react-use';
@ -22,41 +24,41 @@ import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
import { mode } from 'theme/util/mode';
const expandedControlImageSx: ChakraProps['sx'] = { maxH: 96 };
type ControlNetProps = {
controlNet: ControlNetConfig;
controlNetId: string;
};
const ControlNet = (props: ControlNetProps) => {
const {
controlNetId,
isEnabled,
model,
weight,
beginStepPct,
endStepPct,
controlMode,
controlImage,
processedControlImage,
processorNode,
processorType,
shouldAutoConfig,
} = props.controlNet;
const { controlNetId } = props;
const dispatch = useAppDispatch();
const selector = createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, shouldAutoConfig } =
controlNet.controlNets[controlNetId];
return { isEnabled, shouldAutoConfig };
},
defaultSelectorOptions
);
const { isEnabled, shouldAutoConfig } = useAppSelector(selector);
const [isExpanded, toggleIsExpanded] = useToggle(false);
const { colorMode } = useColorMode();
const handleDelete = useCallback(() => {
dispatch(controlNetRemoved({ controlNetId }));
}, [controlNetId, dispatch]);
const handleDuplicate = useCallback(() => {
dispatch(
controlNetAdded({ controlNetId: uuidv4(), controlNet: props.controlNet })
controlNetDuplicated({
sourceControlNetId: controlNetId,
newControlNetId: uuidv4(),
})
);
}, [dispatch, props.controlNet]);
}, [controlNetId, dispatch]);
const handleToggleIsEnabled = useCallback(() => {
dispatch(controlNetToggled({ controlNetId }));
@ -68,15 +70,18 @@ const ControlNet = (props: ControlNetProps) => {
flexDir: 'column',
gap: 2,
p: 3,
bg: mode('base.200', 'base.850')(colorMode),
borderRadius: 'base',
position: 'relative',
bg: 'base.200',
_dark: {
bg: 'base.850',
},
}}
>
<Flex sx={{ gap: 2 }}>
<Flex sx={{ gap: 2, alignItems: 'center' }}>
<IAISwitch
tooltip="Toggle"
aria-label="Toggle"
tooltip={'Toggle this ControlNet'}
aria-label={'Toggle this ControlNet'}
isChecked={isEnabled}
onChange={handleToggleIsEnabled}
/>
@ -90,7 +95,7 @@ const ControlNet = (props: ControlNetProps) => {
transitionDuration: '0.1s',
}}
>
<ParamControlNetModel controlNetId={controlNetId} model={model} />
<ParamControlNetModel controlNetId={controlNetId} />
</Box>
<IAIIconButton
size="sm"
@ -109,21 +114,26 @@ const ControlNet = (props: ControlNetProps) => {
/>
<IAIIconButton
size="sm"
aria-label="Show All Options"
tooltip={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
aria-label={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
onClick={toggleIsExpanded}
variant="link"
icon={
<ChevronUpIcon
sx={{
boxSize: 4,
color: mode('base.700', 'base.300')(colorMode),
color: 'base.700',
transform: isExpanded ? 'rotate(0deg)' : 'rotate(180deg)',
transitionProperty: 'common',
transitionDuration: 'normal',
_dark: {
color: 'base.300',
},
}}
/>
}
/>
{!shouldAutoConfig && (
<Box
sx={{
@ -131,21 +141,23 @@ const ControlNet = (props: ControlNetProps) => {
w: 1.5,
h: 1.5,
borderRadius: 'full',
bg: mode('error.700', 'error.200')(colorMode),
top: 4,
insetInlineEnd: 4,
bg: 'accent.700',
_dark: {
bg: 'accent.400',
},
}}
/>
)}
</Flex>
{isEnabled && (
<>
<Flex sx={{ w: 'full', flexDirection: 'column' }}>
<Flex sx={{ gap: 4, w: 'full' }}>
<Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}>
<Flex
sx={{
flexDir: 'column',
gap: 3,
h: 28,
w: 'full',
paddingInlineStart: 1,
paddingInlineEnd: isExpanded ? 1 : 0,
@ -153,63 +165,35 @@ const ControlNet = (props: ControlNetProps) => {
justifyContent: 'space-between',
}}
>
<ParamControlNetWeight
controlNetId={controlNetId}
weight={weight}
mini={!isExpanded}
/>
<ParamControlNetBeginEnd
controlNetId={controlNetId}
beginStepPct={beginStepPct}
endStepPct={endStepPct}
mini={!isExpanded}
/>
<ParamControlNetWeight controlNetId={controlNetId} />
<ParamControlNetBeginEnd controlNetId={controlNetId} />
</Flex>
{!isExpanded && (
<Flex
sx={{
alignItems: 'center',
justifyContent: 'center',
h: 24,
w: 24,
h: 28,
w: 28,
aspectRatio: '1/1',
mt: 3,
}}
>
<ControlNetImagePreview
controlNet={props.controlNet}
height={24}
/>
<ControlNetImagePreview controlNetId={controlNetId} height={28} />
</Flex>
)}
</Flex>
<ParamControlNetControlMode
controlNetId={controlNetId}
controlMode={controlMode}
/>
<Box mt={2}>
<ParamControlNetControlMode controlNetId={controlNetId} />
</Box>
<ParamControlNetProcessorSelect controlNetId={controlNetId} />
</Flex>
{isExpanded && (
<>
<Box mt={2}>
<ControlNetImagePreview
controlNet={props.controlNet}
height={96}
/>
</Box>
<ParamControlNetProcessorSelect
controlNetId={controlNetId}
processorNode={processorNode}
/>
<ControlNetProcessorComponent
controlNetId={controlNetId}
processorNode={processorNode}
/>
<ParamControlNetShouldAutoConfig
controlNetId={controlNetId}
shouldAutoConfig={shouldAutoConfig}
/>
</>
)}
<ControlNetImagePreview controlNetId={controlNetId} height="392px" />
<ParamControlNetShouldAutoConfig controlNetId={controlNetId} />
<ControlNetProcessorComponent controlNetId={controlNetId} />
</>
)}
</Flex>

View File

@ -5,42 +5,57 @@ import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import { memo, useCallback, useMemo, useState } from 'react';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/thunks/image';
import {
ControlNetConfig,
controlNetImageChanged,
controlNetSelector,
} from '../store/controlNetSlice';
const selector = createSelector(
controlNetSelector,
(controlNet) => {
const { pendingControlImages } = controlNet;
return { pendingControlImages };
},
defaultSelectorOptions
);
import { controlNetImageChanged } from '../store/controlNetSlice';
type Props = {
controlNet: ControlNetConfig;
controlNetId: string;
height: SystemStyleObject['h'];
};
const ControlNetImagePreview = (props: Props) => {
const { height } = props;
const {
controlNetId,
controlImage: controlImageName,
processedControlImage: processedControlImageName,
processorType,
} = props.controlNet;
const { height, controlNetId } = props;
const dispatch = useAppDispatch();
const { pendingControlImages } = useAppSelector(selector);
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { pendingControlImages } = controlNet;
const {
controlImage,
processedControlImage,
processorType,
isEnabled,
} = controlNet.controlNets[controlNetId];
return {
controlImageName: controlImage,
processedControlImageName: processedControlImage,
processorType,
isEnabled,
pendingControlImages,
};
},
defaultSelectorOptions
),
[controlNetId]
);
const {
controlImageName,
processedControlImageName,
processorType,
pendingControlImages,
isEnabled,
} = useAppSelector(selector);
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
@ -110,13 +125,15 @@ const ControlNetImagePreview = (props: Props) => {
h: height,
alignItems: 'center',
justifyContent: 'center',
pointerEvents: isEnabled ? 'auto' : 'none',
opacity: isEnabled ? 1 : 0.5,
}}
>
<IAIDndImage
draggableData={draggableData}
droppableData={droppableData}
imageDTO={controlImage}
isDropDisabled={shouldShowProcessedImage}
isDropDisabled={shouldShowProcessedImage || !isEnabled}
onClickReset={handleResetControlImage}
postUploadAction={postUploadAction}
resetTooltip="Reset Control Image"
@ -140,6 +157,7 @@ const ControlNetImagePreview = (props: Props) => {
droppableData={droppableData}
imageDTO={processedControlImage}
isUploadDisabled={true}
isDropDisabled={!isEnabled}
onClickReset={handleResetControlImage}
resetTooltip="Reset Control Image"
withResetIcon={Boolean(controlImage)}

View File

@ -1,10 +1,13 @@
import { memo } from 'react';
import { RequiredControlNetProcessorNode } from '../store/types';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { memo, useMemo } from 'react';
import CannyProcessor from './processors/CannyProcessor';
import HedProcessor from './processors/HedProcessor';
import LineartProcessor from './processors/LineartProcessor';
import LineartAnimeProcessor from './processors/LineartAnimeProcessor';
import ContentShuffleProcessor from './processors/ContentShuffleProcessor';
import HedProcessor from './processors/HedProcessor';
import LineartAnimeProcessor from './processors/LineartAnimeProcessor';
import LineartProcessor from './processors/LineartProcessor';
import MediapipeFaceProcessor from './processors/MediapipeFaceProcessor';
import MidasDepthProcessor from './processors/MidasDepthProcessor';
import MlsdImageProcessor from './processors/MlsdImageProcessor';
@ -15,23 +18,45 @@ import ZoeDepthProcessor from './processors/ZoeDepthProcessor';
export type ControlNetProcessorProps = {
controlNetId: string;
processorNode: RequiredControlNetProcessorNode;
};
const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
const { controlNetId, processorNode } = props;
const { controlNetId } = props;
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, processorNode } =
controlNet.controlNets[controlNetId];
return { isEnabled, processorNode };
},
defaultSelectorOptions
),
[controlNetId]
);
const { isEnabled, processorNode } = useAppSelector(selector);
if (processorNode.type === 'canny_image_processor') {
return (
<CannyProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
if (processorNode.type === 'hed_image_processor') {
return (
<HedProcessor controlNetId={controlNetId} processorNode={processorNode} />
<HedProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@ -40,6 +65,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<LineartProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@ -49,6 +75,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<ContentShuffleProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@ -58,6 +85,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<LineartAnimeProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@ -67,6 +95,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<MediapipeFaceProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@ -76,6 +105,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<MidasDepthProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@ -85,6 +115,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<MlsdImageProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@ -94,6 +125,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<NormalBaeProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@ -103,6 +135,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<OpenposeProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@ -112,6 +145,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<PidiProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@ -121,6 +155,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<ZoeDepthProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}

View File

@ -1,18 +1,36 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback, useMemo } from 'react';
type Props = {
controlNetId: string;
shouldAutoConfig: boolean;
};
const ParamControlNetShouldAutoConfig = (props: Props) => {
const { controlNetId, shouldAutoConfig } = props;
const { controlNetId } = props;
const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, shouldAutoConfig } =
controlNet.controlNets[controlNetId];
return { isEnabled, shouldAutoConfig };
},
defaultSelectorOptions
),
[controlNetId]
);
const { isEnabled, shouldAutoConfig } = useAppSelector(selector);
const isBusy = useAppSelector(selectIsBusy);
const handleShouldAutoConfigChanged = useCallback(() => {
dispatch(controlNetAutoConfigToggled({ controlNetId }));
}, [controlNetId, dispatch]);
@ -23,7 +41,7 @@ const ParamControlNetShouldAutoConfig = (props: Props) => {
aria-label="Auto configure processor"
isChecked={shouldAutoConfig}
onChange={handleShouldAutoConfigChanged}
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
);
};

View File

@ -1,5 +1,4 @@
import {
ChakraProps,
FormControl,
FormLabel,
HStack,
@ -10,34 +9,41 @@ import {
RangeSliderTrack,
Tooltip,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import {
controlNetBeginStepPctChanged,
controlNetEndStepPctChanged,
} from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
mt: 1.5,
fontSize: '2xs',
fontWeight: '500',
color: 'base.400',
};
import { memo, useCallback, useMemo } from 'react';
type Props = {
controlNetId: string;
beginStepPct: number;
endStepPct: number;
mini?: boolean;
};
const formatPct = (v: number) => `${Math.round(v * 100)}%`;
const ParamControlNetBeginEnd = (props: Props) => {
const { controlNetId, beginStepPct, mini = false, endStepPct } = props;
const { controlNetId } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { beginStepPct, endStepPct, isEnabled } =
controlNet.controlNets[controlNetId];
return { beginStepPct, endStepPct, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { beginStepPct, endStepPct, isEnabled } = useAppSelector(selector);
const handleStepPctChanged = useCallback(
(v: number[]) => {
@ -55,7 +61,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
}, [controlNetId, dispatch]);
return (
<FormControl>
<FormControl isDisabled={!isEnabled}>
<FormLabel>Begin / End Step Percentage</FormLabel>
<HStack w="100%" gap={2} alignItems="center">
<RangeSlider
@ -66,6 +72,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
max={1}
step={0.01}
minStepsBetweenThumbs={5}
isDisabled={!isEnabled}
>
<RangeSliderTrack>
<RangeSliderFilledTrack />
@ -76,14 +83,11 @@ const ParamControlNetBeginEnd = (props: Props) => {
<Tooltip label={formatPct(endStepPct)} placement="top" hasArrow>
<RangeSliderThumb index={1} />
</Tooltip>
{!mini && (
<>
<RangeSliderMark
value={0}
sx={{
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
...SLIDER_MARK_STYLES,
}}
>
0%
@ -91,7 +95,8 @@ const ParamControlNetBeginEnd = (props: Props) => {
<RangeSliderMark
value={0.5}
sx={{
...SLIDER_MARK_STYLES,
insetInlineStart: '50% !important',
transform: 'translateX(-50%)',
}}
>
50%
@ -101,13 +106,10 @@ const ParamControlNetBeginEnd = (props: Props) => {
sx={{
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
...SLIDER_MARK_STYLES,
}}
>
100%
</RangeSliderMark>
</>
)}
</RangeSlider>
</HStack>
</FormControl>

View File

@ -1,15 +1,17 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
ControlModes,
controlNetControlModeChanged,
} from 'features/controlNet/store/controlNetSlice';
import { useCallback } from 'react';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
type ParamControlNetControlModeProps = {
controlNetId: string;
controlMode: string;
};
const CONTROL_MODE_DATA = [
@ -22,8 +24,23 @@ const CONTROL_MODE_DATA = [
export default function ParamControlNetControlMode(
props: ParamControlNetControlModeProps
) {
const { controlNetId, controlMode = false } = props;
const { controlNetId } = props;
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { controlMode, isEnabled } =
controlNet.controlNets[controlNetId];
return { controlMode, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { controlMode, isEnabled } = useAppSelector(selector);
const { t } = useTranslation();
@ -36,7 +53,8 @@ export default function ParamControlNetControlMode(
return (
<IAIMantineSelect
label={t('parameters.controlNetControlMode')}
disabled={!isEnabled}
label="Control Mode"
data={CONTROL_MODE_DATA}
value={String(controlMode)}
onChange={handleControlModeChange}

View File

@ -1,28 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { controlNetToggled } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
type ParamControlNetIsEnabledProps = {
controlNetId: string;
isEnabled: boolean;
};
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
const { controlNetId, isEnabled } = props;
const dispatch = useAppDispatch();
const handleIsEnabledChanged = useCallback(() => {
dispatch(controlNetToggled({ controlNetId }));
}, [dispatch, controlNetId]);
return (
<IAISwitch
label="Enabled"
isChecked={isEnabled}
onChange={handleIsEnabledChanged}
/>
);
};
export default memo(ParamControlNetIsEnabled);

View File

@ -1,36 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAIFullCheckbox from 'common/components/IAIFullCheckbox';
import IAISwitch from 'common/components/IAISwitch';
import {
controlNetToggled,
isControlNetImagePreprocessedToggled,
} from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
type ParamControlNetIsEnabledProps = {
controlNetId: string;
isControlImageProcessed: boolean;
};
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
const { controlNetId, isControlImageProcessed } = props;
const dispatch = useAppDispatch();
const handleIsControlImageProcessedToggled = useCallback(() => {
dispatch(
isControlNetImagePreprocessedToggled({
controlNetId,
})
);
}, [controlNetId, dispatch]);
return (
<IAISwitch
label="Preprocess"
isChecked={isControlImageProcessed}
onChange={handleIsControlImageProcessedToggled}
/>
);
};
export default memo(ParamControlNetIsEnabled);

View File

@ -1,59 +1,118 @@
import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSearchableSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSearchableSelect';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import {
CONTROLNET_MODELS,
ControlNetModelName,
} from 'features/controlNet/store/constants';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
import { configSelector } from 'features/system/store/configSelectors';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
type ParamControlNetModelProps = {
controlNetId: string;
model: ControlNetModelName;
};
const selector = createSelector(configSelector, (config) => {
const controlNetModels: IAISelectDataType[] = map(CONTROLNET_MODELS, (m) => ({
label: m.label,
value: m.type,
})).filter(
(d) =>
!config.sd.disabledControlNetModels.includes(
d.value as ControlNetModelName
)
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const { controlNetId } = props;
const dispatch = useAppDispatch();
const isBusy = useAppSelector(selectIsBusy);
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ generation, controlNet }) => {
const { model } = generation;
const controlNetModel = controlNet.controlNets[controlNetId]?.model;
const isEnabled = controlNet.controlNets[controlNetId]?.isEnabled;
return { mainModel: model, controlNetModel, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
return controlNetModels;
});
const { mainModel, controlNetModel, isEnabled } = useAppSelector(selector);
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const { controlNetId, model } = props;
const controlNetModels = useAppSelector(selector);
const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke();
const { data: controlNetModels } = useGetControlNetModelsQuery();
const data = useMemo(() => {
if (!controlNetModels) {
return [];
}
const data: SelectItem[] = [];
forEach(controlNetModels.entities, (model, id) => {
if (!model) {
return;
}
const disabled = model?.base_model !== mainModel?.base_model;
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
disabled,
tooltip: disabled
? `Incompatible base model: ${model.base_model}`
: undefined,
});
});
return data;
}, [controlNetModels, mainModel?.base_model]);
// grab the full model entity from the RTK Query cache
const selectedModel = useMemo(
() =>
controlNetModels?.entities[
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}`
] ?? null,
[
controlNetModel?.base_model,
controlNetModel?.model_name,
controlNetModels?.entities,
]
);
const handleModelChanged = useCallback(
(val: string | null) => {
// TODO: do not cast
const model = val as ControlNetModelName;
dispatch(controlNetModelChanged({ controlNetId, model }));
(v: string | null) => {
if (!v) {
return;
}
const newControlNetModel = modelIdToControlNetModelParam(v);
if (!newControlNetModel) {
return;
}
dispatch(
controlNetModelChanged({ controlNetId, model: newControlNetModel })
);
},
[controlNetId, dispatch]
);
return (
<IAIMantineSearchableSelect
data={controlNetModels}
value={model}
itemComponent={IAIMantineSelectItemWithTooltip}
data={data}
error={
!selectedModel || mainModel?.base_model !== selectedModel.base_model
}
placeholder="Select a model"
value={selectedModel?.id ?? null}
onChange={handleModelChanged}
disabled={!isReady}
tooltip={model}
disabled={isBusy || !isEnabled}
tooltip={selectedModel?.description}
/>
);
};

View File

@ -1,24 +1,22 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSearchableSelect';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { configSelector } from 'features/system/store/configSelectors';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { CONTROLNET_PROCESSORS } from '../../store/constants';
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
import {
ControlNetProcessorNode,
ControlNetProcessorType,
} from '../../store/types';
import { ControlNetProcessorType } from '../../store/types';
import { FormControl, FormLabel } from '@chakra-ui/react';
type ParamControlNetProcessorSelectProps = {
controlNetId: string;
processorNode: ControlNetProcessorNode;
};
const selector = createSelector(
@ -54,10 +52,24 @@ const selector = createSelector(
const ParamControlNetProcessorSelect = (
props: ParamControlNetProcessorSelectProps
) => {
const { controlNetId, processorNode } = props;
const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke();
const { controlNetId } = props;
const processorNodeSelector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, processorNode } =
controlNet.controlNets[controlNetId];
return { isEnabled, processorNode };
},
defaultSelectorOptions
),
[controlNetId]
);
const isBusy = useAppSelector(selectIsBusy);
const controlNetProcessors = useAppSelector(selector);
const { isEnabled, processorNode } = useAppSelector(processorNodeSelector);
const handleProcessorTypeChanged = useCallback(
(v: string | null) => {
@ -77,7 +89,7 @@ const ParamControlNetProcessorSelect = (
value={processorNode.type ?? 'canny_image_processor'}
data={controlNetProcessors}
onChange={handleProcessorTypeChanged}
disabled={!isReady}
disabled={isBusy || !isEnabled}
/>
);
};

View File

@ -1,18 +1,32 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
import { controlNetWeightChanged } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
import { memo, useCallback, useMemo } from 'react';
type ParamControlNetWeightProps = {
controlNetId: string;
weight: number;
mini?: boolean;
};
const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
const { controlNetId, weight, mini = false } = props;
const { controlNetId } = props;
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { weight, isEnabled } = controlNet.controlNets[controlNetId];
return { weight, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { weight, isEnabled } = useAppSelector(selector);
const handleWeightChanged = useCallback(
(weight: number) => {
dispatch(controlNetWeightChanged({ controlNetId, weight }));
@ -22,15 +36,15 @@ const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
return (
<IAISlider
isDisabled={!isEnabled}
label={'Weight'}
sliderFormLabelProps={{ pb: 2 }}
value={weight}
onChange={handleWeightChanged}
min={-1}
max={1}
min={0}
max={2}
step={0.01}
withSliderMarks={!mini}
sliderMarks={[-1, 0, 1]}
withSliderMarks
sliderMarks={[0, 1, 2]}
/>
);
};

View File

@ -1,22 +1,25 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredCannyImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor
.default as RequiredCannyImageProcessorInvocation;
type CannyProcessorProps = {
controlNetId: string;
processorNode: RequiredCannyImageProcessorInvocation;
isEnabled: boolean;
};
const CannyProcessor = (props: CannyProcessorProps) => {
const { controlNetId, processorNode } = props;
const { controlNetId, processorNode, isEnabled } = props;
const { low_threshold, high_threshold } = processorNode;
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const processorChanged = useProcessorNodeChanged();
const handleLowThresholdChanged = useCallback(
@ -48,7 +51,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
return (
<ProcessorWrapper>
<IAISlider
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
label="Low Threshold"
value={low_threshold}
onChange={handleLowThresholdChanged}
@ -60,7 +63,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
withSliderMarks
/>
<IAISlider
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
label="High Threshold"
value={high_threshold}
onChange={handleHighThresholdChanged}

View File

@ -4,20 +4,23 @@ import { RequiredContentShuffleImageProcessorInvocation } from 'features/control
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { useAppSelector } from 'app/store/storeHooks';
import { selectIsBusy } from 'features/system/store/systemSelectors';
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor
.default as RequiredContentShuffleImageProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredContentShuffleImageProcessorInvocation;
isEnabled: boolean;
};
const ContentShuffleProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, w, h, f } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@ -96,7 +99,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Image Resolution"
@ -108,7 +111,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="W"
@ -120,7 +123,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="H"
@ -132,7 +135,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="F"
@ -144,7 +147,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@ -1,25 +1,29 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredHedImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor
.default as RequiredHedImageProcessorInvocation;
type HedProcessorProps = {
controlNetId: string;
processorNode: RequiredHedImageProcessorInvocation;
isEnabled: boolean;
};
const HedPreprocessor = (props: HedProcessorProps) => {
const {
controlNetId,
processorNode: { detect_resolution, image_resolution, scribble },
isEnabled,
} = props;
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback(
@ -67,7 +71,7 @@ const HedPreprocessor = (props: HedProcessorProps) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Image Resolution"
@ -79,13 +83,13 @@ const HedPreprocessor = (props: HedProcessorProps) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISwitch
label="Scribble"
isChecked={scribble}
onChange={handleScribbleChanged}
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredLineartAnimeImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor
.default as RequiredLineartAnimeImageProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredLineartAnimeImageProcessorInvocation;
isEnabled: boolean;
};
const LineartAnimeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@ -57,7 +60,7 @@ const LineartAnimeProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Image Resolution"
@ -69,7 +72,7 @@ const LineartAnimeProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@ -1,24 +1,27 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredLineartImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor
.default as RequiredLineartImageProcessorInvocation;
type LineartProcessorProps = {
controlNetId: string;
processorNode: RequiredLineartImageProcessorInvocation;
isEnabled: boolean;
};
const LineartProcessor = (props: LineartProcessorProps) => {
const { controlNetId, processorNode } = props;
const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, coarse } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@ -65,7 +68,7 @@ const LineartProcessor = (props: LineartProcessorProps) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Image Resolution"
@ -77,13 +80,13 @@ const LineartProcessor = (props: LineartProcessorProps) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISwitch
label="Coarse"
isChecked={coarse}
onChange={handleCoarseChanged}
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredMediapipeFaceProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor
.default as RequiredMediapipeFaceProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredMediapipeFaceProcessorInvocation;
isEnabled: boolean;
};
const MediapipeFaceProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { controlNetId, processorNode, isEnabled } = props;
const { max_faces, min_confidence } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const handleMaxFacesChanged = useCallback(
(v: number) => {
@ -53,7 +56,7 @@ const MediapipeFaceProcessor = (props: Props) => {
max={20}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Min Confidence"
@ -66,7 +69,7 @@ const MediapipeFaceProcessor = (props: Props) => {
step={0.01}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredMidasDepthImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor
.default as RequiredMidasDepthImageProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredMidasDepthImageProcessorInvocation;
isEnabled: boolean;
};
const MidasDepthProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { controlNetId, processorNode, isEnabled } = props;
const { a_mult, bg_th } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const handleAMultChanged = useCallback(
(v: number) => {
@ -54,7 +57,7 @@ const MidasDepthProcessor = (props: Props) => {
step={0.01}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="bg_th"
@ -67,7 +70,7 @@ const MidasDepthProcessor = (props: Props) => {
step={0.01}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredMlsdImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor
.default as RequiredMlsdImageProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredMlsdImageProcessorInvocation;
isEnabled: boolean;
};
const MlsdImageProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@ -79,7 +82,7 @@ const MlsdImageProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Image Resolution"
@ -91,7 +94,7 @@ const MlsdImageProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="W"
@ -104,7 +107,7 @@ const MlsdImageProcessor = (props: Props) => {
step={0.01}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="H"
@ -117,7 +120,7 @@ const MlsdImageProcessor = (props: Props) => {
step={0.01}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredNormalbaeImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor
.default as RequiredNormalbaeImageProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredNormalbaeImageProcessorInvocation;
isEnabled: boolean;
};
const NormalBaeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@ -57,7 +60,7 @@ const NormalBaeProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Image Resolution"
@ -69,7 +72,7 @@ const NormalBaeProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@ -1,24 +1,27 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredOpenposeImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor
.default as RequiredOpenposeImageProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredOpenposeImageProcessorInvocation;
isEnabled: boolean;
};
const OpenposeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, hand_and_face } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@ -65,7 +68,7 @@ const OpenposeProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Image Resolution"
@ -77,13 +80,13 @@ const OpenposeProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISwitch
label="Hand and Face"
isChecked={hand_and_face}
onChange={handleHandAndFaceChanged}
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@ -1,24 +1,27 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredPidiImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor
.default as RequiredPidiImageProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredPidiImageProcessorInvocation;
isEnabled: boolean;
};
const PidiProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, scribble, safe } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@ -72,7 +75,7 @@ const PidiProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Image Resolution"
@ -84,7 +87,7 @@ const PidiProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISwitch
label="Scribble"
@ -95,7 +98,7 @@ const PidiProcessor = (props: Props) => {
label="Safe"
isChecked={safe}
onChange={handleSafeChanged}
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@ -4,6 +4,7 @@ import { memo } from 'react';
type Props = {
controlNetId: string;
processorNode: RequiredZoeDepthImageProcessorInvocation;
isEnabled: boolean;
};
const ZoeDepthProcessor = (props: Props) => {

View File

@ -173,91 +173,17 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
},
};
type ControlNetModelsDict = Record<string, ControlNetModel>;
type ControlNetModel = {
type: string;
label: string;
description?: string;
defaultProcessor?: ControlNetProcessorType;
export const CONTROLNET_MODEL_DEFAULT_PROCESSORS: {
[key: string]: ControlNetProcessorType;
} = {
canny: 'canny_image_processor',
mlsd: 'mlsd_image_processor',
depth: 'midas_depth_image_processor',
bae: 'normalbae_image_processor',
lineart: 'lineart_image_processor',
lineart_anime: 'lineart_anime_image_processor',
softedge: 'hed_image_processor',
shuffle: 'content_shuffle_image_processor',
openpose: 'openpose_image_processor',
mediapipe: 'mediapipe_face_processor',
};
export const CONTROLNET_MODELS: ControlNetModelsDict = {
'lllyasviel/control_v11p_sd15_canny': {
type: 'lllyasviel/control_v11p_sd15_canny',
label: 'Canny',
defaultProcessor: 'canny_image_processor',
},
'lllyasviel/control_v11p_sd15_inpaint': {
type: 'lllyasviel/control_v11p_sd15_inpaint',
label: 'Inpaint',
defaultProcessor: 'none',
},
'lllyasviel/control_v11p_sd15_mlsd': {
type: 'lllyasviel/control_v11p_sd15_mlsd',
label: 'M-LSD',
defaultProcessor: 'mlsd_image_processor',
},
'lllyasviel/control_v11f1p_sd15_depth': {
type: 'lllyasviel/control_v11f1p_sd15_depth',
label: 'Depth',
defaultProcessor: 'midas_depth_image_processor',
},
'lllyasviel/control_v11p_sd15_normalbae': {
type: 'lllyasviel/control_v11p_sd15_normalbae',
label: 'Normal Map (BAE)',
defaultProcessor: 'normalbae_image_processor',
},
'lllyasviel/control_v11p_sd15_seg': {
type: 'lllyasviel/control_v11p_sd15_seg',
label: 'Segmentation',
defaultProcessor: 'none',
},
'lllyasviel/control_v11p_sd15_lineart': {
type: 'lllyasviel/control_v11p_sd15_lineart',
label: 'Lineart',
defaultProcessor: 'lineart_image_processor',
},
'lllyasviel/control_v11p_sd15s2_lineart_anime': {
type: 'lllyasviel/control_v11p_sd15s2_lineart_anime',
label: 'Lineart Anime',
defaultProcessor: 'lineart_anime_image_processor',
},
'lllyasviel/control_v11p_sd15_scribble': {
type: 'lllyasviel/control_v11p_sd15_scribble',
label: 'Scribble',
defaultProcessor: 'none',
},
'lllyasviel/control_v11p_sd15_softedge': {
type: 'lllyasviel/control_v11p_sd15_softedge',
label: 'Soft Edge',
defaultProcessor: 'hed_image_processor',
},
'lllyasviel/control_v11e_sd15_shuffle': {
type: 'lllyasviel/control_v11e_sd15_shuffle',
label: 'Content Shuffle',
defaultProcessor: 'content_shuffle_image_processor',
},
'lllyasviel/control_v11p_sd15_openpose': {
type: 'lllyasviel/control_v11p_sd15_openpose',
label: 'Openpose',
defaultProcessor: 'openpose_image_processor',
},
'lllyasviel/control_v11f1e_sd15_tile': {
type: 'lllyasviel/control_v11f1e_sd15_tile',
label: 'Tile (experimental)',
defaultProcessor: 'none',
},
'lllyasviel/control_v11e_sd15_ip2p': {
type: 'lllyasviel/control_v11e_sd15_ip2p',
label: 'Pix2Pix (experimental)',
defaultProcessor: 'none',
},
'CrucibleAI/ControlNetMediaPipeFace': {
type: 'CrucibleAI/ControlNetMediaPipeFace',
label: 'Mediapipe Face',
defaultProcessor: 'mediapipe_face_processor',
},
};
export type ControlNetModelName = keyof typeof CONTROLNET_MODELS;

View File

@ -1,22 +1,20 @@
import { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { ImageDTO } from 'services/api/types';
import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas';
import { cloneDeep, forEach } from 'lodash-es';
import { imageDeleted } from 'services/api/thunks/image';
import { isAnySessionRejected } from 'services/api/thunks/session';
import { appSocketInvocationError } from 'services/events/actions';
import { controlNetImageProcessed } from './actions';
import {
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
CONTROLNET_PROCESSORS,
} from './constants';
import {
ControlNetProcessorType,
RequiredCannyImageProcessorInvocation,
RequiredControlNetProcessorNode,
} from './types';
import {
CONTROLNET_MODELS,
CONTROLNET_PROCESSORS,
ControlNetModelName,
} from './constants';
import { controlNetImageProcessed } from './actions';
import { imageDeleted, imageUrlsReceived } from 'services/api/thunks/image';
import { forEach } from 'lodash-es';
import { isAnySessionRejected } from 'services/api/thunks/session';
import { appSocketInvocationError } from 'services/events/actions';
export type ControlModes =
| 'balanced'
@ -26,7 +24,7 @@ export type ControlModes =
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true,
model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type,
model: null,
weight: 1,
beginStepPct: 0,
endStepPct: 1,
@ -42,7 +40,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
export type ControlNetConfig = {
controlNetId: string;
isEnabled: boolean;
model: ControlNetModelName;
model: ControlNetModelParam | null;
weight: number;
beginStepPct: number;
endStepPct: number;
@ -86,6 +84,19 @@ export const controlNetSlice = createSlice({
controlNetId,
};
},
controlNetDuplicated: (
state,
action: PayloadAction<{
sourceControlNetId: string;
newControlNetId: string;
}>
) => {
const { sourceControlNetId, newControlNetId } = action.payload;
const newControlnet = cloneDeep(state.controlNets[sourceControlNetId]);
newControlnet.controlNetId = newControlNetId;
state.controlNets[newControlNetId] = newControlnet;
},
controlNetAddedFromImage: (
state,
action: PayloadAction<{ controlNetId: string; controlImage: string }>
@ -147,7 +158,7 @@ export const controlNetSlice = createSlice({
state,
action: PayloadAction<{
controlNetId: string;
model: ControlNetModelName;
model: ControlNetModelParam;
}>
) => {
const { controlNetId, model } = action.payload;
@ -155,7 +166,15 @@ export const controlNetSlice = createSlice({
state.controlNets[controlNetId].processedControlImage = null;
if (state.controlNets[controlNetId].shouldAutoConfig) {
const processorType = CONTROLNET_MODELS[model].defaultProcessor;
let processorType: ControlNetProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (model.model_name.includes(modelSubstring)) {
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) {
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
@ -241,9 +260,19 @@ export const controlNetSlice = createSlice({
if (newShouldAutoConfig) {
// manage the processor for the user
const processorType =
CONTROLNET_MODELS[state.controlNets[controlNetId].model]
.defaultProcessor;
let processorType: ControlNetProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (
state.controlNets[controlNetId].model?.model_name.includes(
modelSubstring
)
) {
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) {
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
@ -272,7 +301,8 @@ export const controlNetSlice = createSlice({
});
builder.addCase(imageDeleted.pending, (state, action) => {
// Preemptively remove the image from the gallery
// Preemptively remove the image from all controlnets
// TODO: doesn't the imageusage stuff do this for us?
const { image_name } = action.meta.arg;
forEach(state.controlNets, (c) => {
if (c.controlImage === image_name) {
@ -285,21 +315,6 @@ export const controlNetSlice = createSlice({
});
});
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
// const { image_name, image_url, thumbnail_url } = action.payload;
// forEach(state.controlNets, (c) => {
// if (c.controlImage?.image_name === image_name) {
// c.controlImage.image_url = image_url;
// c.controlImage.thumbnail_url = thumbnail_url;
// }
// if (c.processedControlImage?.image_name === image_name) {
// c.processedControlImage.image_url = image_url;
// c.processedControlImage.thumbnail_url = thumbnail_url;
// }
// });
// });
builder.addCase(appSocketInvocationError, (state, action) => {
state.pendingControlImages = [];
});
@ -313,6 +328,7 @@ export const controlNetSlice = createSlice({
export const {
isControlNetEnabledToggled,
controlNetAdded,
controlNetDuplicated,
controlNetAddedFromImage,
controlNetRemoved,
controlNetImageChanged,

View File

@ -7,6 +7,7 @@ import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
import ControlNetModelInputFieldComponent from './fields/ControlNetModelInputFieldComponent';
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
@ -174,6 +175,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
);
}
if (type === 'controlnet_model' && template.type === 'controlnet_model') {
return (
<ControlNetModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'array' && template.type === 'array') {
return (
<ArrayInputFieldComponent

View File

@ -0,0 +1,103 @@
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
ControlNetModelInputFieldTemplate,
ControlNetModelInputFieldValue,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
const ControlNetModelInputFieldComponent = (
props: FieldComponentProps<
ControlNetModelInputFieldValue,
ControlNetModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const controlNetModel = field.value;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: controlNetModels } = useGetControlNetModelsQuery();
// grab the full model entity from the RTK Query cache
const selectedModel = useMemo(
() =>
controlNetModels?.entities[
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}`
] ?? null,
[
controlNetModel?.base_model,
controlNetModel?.model_name,
controlNetModels?.entities,
]
);
const data = useMemo(() => {
if (!controlNetModels) {
return [];
}
const data: SelectItem[] = [];
forEach(controlNetModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [controlNetModels]);
const handleValueChanged = useCallback(
(v: string | null) => {
if (!v) {
return;
}
const newControlNetModel = modelIdToControlNetModelParam(v);
if (!newControlNetModel) {
return;
}
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: newControlNetModel,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<IAIMantineSelect
tooltip={selectedModel?.description}
label={
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
}
value={selectedModel?.id ?? null}
placeholder="Pick one"
error={!selectedModel}
data={data}
onChange={handleValueChanged}
/>
);
};
export default memo(ControlNetModelInputFieldComponent);

View File

@ -1,6 +1,7 @@
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import {
ControlNetModelParam,
LoRAModelParam,
MainModelParam,
VaeModelParam,
@ -81,7 +82,8 @@ const nodesSlice = createSlice({
| ImageField[]
| MainModelParam
| VaeModelParam
| LoRAModelParam;
| LoRAModelParam
| ControlNetModelParam;
}>
) => {
const { nodeId, fieldName, value } = action.payload;

View File

@ -19,6 +19,8 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
model: 'model',
vae_model: 'vae_model',
lora_model: 'lora_model',
controlnet_model: 'controlnet_model',
ControlNetModelField: 'controlnet_model',
array: 'array',
item: 'item',
ColorField: 'color',
@ -130,6 +132,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
title: 'LoRA',
description: 'Models are models.',
},
controlnet_model: {
color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'),
title: 'ControlNet',
description: 'Models are models.',
},
array: {
color: 'gray',
colorCssVar: getColorTokenCssVariable('gray'),

View File

@ -1,4 +1,5 @@
import {
ControlNetModelParam,
LoRAModelParam,
MainModelParam,
VaeModelParam,
@ -71,6 +72,7 @@ export type FieldType =
| 'model'
| 'vae_model'
| 'lora_model'
| 'controlnet_model'
| 'array'
| 'item'
| 'color'
@ -100,6 +102,7 @@ export type InputFieldValue =
| MainModelInputFieldValue
| VaeModelInputFieldValue
| LoRAModelInputFieldValue
| ControlNetModelInputFieldValue
| ArrayInputFieldValue
| ItemInputFieldValue
| ColorInputFieldValue
@ -127,6 +130,7 @@ export type InputFieldTemplate =
| ModelInputFieldTemplate
| VaeModelInputFieldTemplate
| LoRAModelInputFieldTemplate
| ControlNetModelInputFieldTemplate
| ArrayInputFieldTemplate
| ItemInputFieldTemplate
| ColorInputFieldTemplate
@ -249,6 +253,11 @@ export type LoRAModelInputFieldValue = FieldValueBase & {
value?: LoRAModelParam;
};
export type ControlNetModelInputFieldValue = FieldValueBase & {
type: 'controlnet_model';
value?: ControlNetModelParam;
};
export type ArrayInputFieldValue = FieldValueBase & {
type: 'array';
value?: (string | number)[];
@ -368,6 +377,11 @@ export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & {
type: 'lora_model';
};
export type ControlNetModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'controlnet_model';
};
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
default: [];
type: 'array';

View File

@ -9,6 +9,7 @@ import {
ColorInputFieldTemplate,
ConditioningInputFieldTemplate,
ControlInputFieldTemplate,
ControlNetModelInputFieldTemplate,
EnumInputFieldTemplate,
FieldType,
FloatInputFieldTemplate,
@ -207,6 +208,21 @@ const buildLoRAModelInputFieldTemplate = ({
return template;
};
const buildControlNetModelInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ControlNetModelInputFieldTemplate => {
const template: ControlNetModelInputFieldTemplate = {
...baseField,
type: 'controlnet_model',
inputRequirement: 'always',
inputKind: 'direct',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildImageInputFieldTemplate = ({
schemaObject,
baseField,
@ -479,6 +495,9 @@ export const buildInputFieldTemplate = (
if (['lora_model'].includes(fieldType)) {
return buildLoRAModelInputFieldTemplate({ schemaObject, baseField });
}
if (['controlnet_model'].includes(fieldType)) {
return buildControlNetModelInputFieldTemplate({ schemaObject, baseField });
}
if (['enum'].includes(fieldType)) {
return buildEnumInputFieldTemplate({ schemaObject, baseField });
}

View File

@ -83,6 +83,10 @@ export const buildInputFieldValue = (
if (template.type === 'lora_model') {
fieldValue.value = undefined;
}
if (template.type === 'controlnet_model') {
fieldValue.value = undefined;
}
}
return fieldValue;

View File

@ -8,6 +8,7 @@ import ControlNet from 'features/controlNet/components/ControlNet';
import ParamControlNetFeatureToggle from 'features/controlNet/components/parameters/ParamControlNetFeatureToggle';
import {
controlNetAdded,
controlNetModelChanged,
controlNetSelector,
} from 'features/controlNet/store/controlNetSlice';
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
@ -15,6 +16,7 @@ import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { map } from 'lodash-es';
import { Fragment, memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
import { v4 as uuidv4 } from 'uuid';
const selector = createSelector(
@ -39,10 +41,23 @@ const ParamControlNetCollapse = () => {
const { controlNetsArray, activeLabel } = useAppSelector(selector);
const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
const dispatch = useAppDispatch();
const { firstModel } = useGetControlNetModelsQuery(undefined, {
selectFromResult: (result) => {
const firstModel = result.data?.entities[result.data?.ids[0]];
return {
firstModel,
};
},
});
const handleClickedAddControlNet = useCallback(() => {
dispatch(controlNetAdded({ controlNetId: uuidv4() }));
}, [dispatch]);
if (!firstModel) {
return;
}
const controlNetId = uuidv4();
dispatch(controlNetAdded({ controlNetId }));
dispatch(controlNetModelChanged({ controlNetId, model: firstModel }));
}, [dispatch, firstModel]);
if (isControlNetDisabled) {
return null;
@ -52,15 +67,20 @@ const ParamControlNetCollapse = () => {
<IAICollapse label="ControlNet" activeLabel={activeLabel}>
<Flex sx={{ flexDir: 'column', gap: 3 }}>
<ParamControlNetFeatureToggle />
<IAIButton
isDisabled={!firstModel}
flexGrow={1}
size="sm"
onClick={handleClickedAddControlNet}
>
Add ControlNet
</IAIButton>
{controlNetsArray.map((c, i) => (
<Fragment key={c.controlNetId}>
{i > 0 && <Divider />}
<ControlNet controlNet={c} />
<ControlNet controlNetId={c.controlNetId} />
</Fragment>
))}
<IAIButton flexGrow={1} onClick={handleClickedAddControlNet}>
Add ControlNet
</IAIButton>
</Flex>
</IAICollapse>
);

View File

@ -37,6 +37,7 @@ const ParamVAEModelSelect = () => {
return [];
}
// add a "default" option, this means use the main model's included VAE
const data: SelectItem[] = [
{
value: 'default',

View File

@ -17,8 +17,10 @@ import { FaPlay } from 'react-icons/fa';
const IN_PROGRESS_STYLES: ChakraProps['sx'] = {
_disabled: {
bg: 'none',
color: 'base.600',
cursor: 'not-allowed',
_hover: {
color: 'base.600',
bg: 'none',
},
},

View File

@ -180,6 +180,23 @@ export type LoRAModelParam = z.infer<typeof zLoRAModel>;
*/
export const isValidLoRAModel = (val: unknown): val is LoRAModelParam =>
zLoRAModel.safeParse(val).success;
/**
* Zod schema for ControlNet models
*/
export const zControlNetModel = z.object({
model_name: z.string().min(1),
base_model: zBaseModel,
});
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type ControlNetModelParam = z.infer<typeof zLoRAModel>;
/**
* Validates/type-guards a value as a model parameter
*/
export const isValidControlNetModel = (
val: unknown
): val is ControlNetModelParam => zControlNetModel.safeParse(val).success;
/**
* Zod schema for l2l strength parameter

View File

@ -0,0 +1,30 @@
import { log } from 'app/logging/useLogger';
import { zControlNetModel } from 'features/parameters/types/parameterSchemas';
import { ControlNetModelField } from 'services/api/types';
const moduleLog = log.child({ module: 'models' });
export const modelIdToControlNetModelParam = (
controlNetModelId: string
): ControlNetModelField | undefined => {
const [base_model, model_type, model_name] = controlNetModelId.split('/');
const result = zControlNetModel.safeParse({
base_model,
model_name,
});
if (!result.success) {
moduleLog.error(
{
controlNetModelId,
errors: result.error.format(),
},
'Failed to parse ControlNet model id'
);
return;
}
return result.data;
};

View File

@ -1,9 +1,12 @@
import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ module: 'models' });
export const modelIdToLoRAModelParam = (
loraId: string
loraModelId: string
): LoRAModelParam | undefined => {
const [base_model, model_type, model_name] = loraId.split('/');
const [base_model, model_type, model_name] = loraModelId.split('/');
const result = zLoRAModel.safeParse({
base_model,
@ -11,6 +14,13 @@ export const modelIdToLoRAModelParam = (
});
if (!result.success) {
moduleLog.error(
{
loraModelId,
errors: result.error.format(),
},
'Failed to parse LoRA model id'
);
return;
}

View File

@ -2,11 +2,14 @@ import {
MainModelParam,
zMainModel,
} from 'features/parameters/types/parameterSchemas';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ module: 'models' });
export const modelIdToMainModelParam = (
modelId: string
mainModelId: string
): MainModelParam | undefined => {
const [base_model, model_type, model_name] = modelId.split('/');
const [base_model, model_type, model_name] = mainModelId.split('/');
const result = zMainModel.safeParse({
base_model,
@ -14,6 +17,13 @@ export const modelIdToMainModelParam = (
});
if (!result.success) {
moduleLog.error(
{
mainModelId,
errors: result.error.format(),
},
'Failed to parse main model id'
);
return;
}

View File

@ -1,9 +1,12 @@
import { VaeModelParam, zVaeModel } from '../types/parameterSchemas';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ module: 'models' });
export const modelIdToVAEModelParam = (
modelId: string
vaeModelId: string
): VaeModelParam | undefined => {
const [base_model, model_type, model_name] = modelId.split('/');
const [base_model, model_type, model_name] = vaeModelId.split('/');
const result = zVaeModel.safeParse({
base_model,
@ -11,6 +14,13 @@ export const modelIdToVAEModelParam = (
});
if (!result.success) {
moduleLog.error(
{
vaeModelId,
errors: result.error.format(),
},
'Failed to parse VAE model id'
);
return;
}

View File

@ -19,9 +19,9 @@ const ImageToImageTabParameters = () => {
<ParamNegativeConditioning />
<ProcessButtons />
<ImageToImageTabCoreParameters />
<ParamControlNetCollapse />
<ParamLoraCollapse />
<ParamDynamicPromptsCollapse />
<ParamControlNetCollapse />
<ParamVariationCollapse />
<ParamNoiseCollapse />
<ParamSymmetryCollapse />

View File

@ -20,9 +20,9 @@ const TextToImageTabParameters = () => {
<ParamNegativeConditioning />
<ProcessButtons />
<TextToImageTabCoreParameters />
<ParamControlNetCollapse />
<ParamLoraCollapse />
<ParamDynamicPromptsCollapse />
<ParamControlNetCollapse />
<ParamVariationCollapse />
<ParamNoiseCollapse />
<ParamSymmetryCollapse />

View File

@ -19,9 +19,9 @@ const UnifiedCanvasParameters = () => {
<ParamNegativeConditioning />
<ProcessButtons />
<UnifiedCanvasCoreParameters />
<ParamControlNetCollapse />
<ParamLoraCollapse />
<ParamDynamicPromptsCollapse />
<ParamControlNetCollapse />
<ParamVariationCollapse />
<ParamSymmetryCollapse />
<ParamSeamCorrectionCollapse />

View File

@ -734,7 +734,7 @@ export type components = {
* Control Model
* @description The ControlNet model to use
*/
control_model: string;
control_model: components["schemas"]["ControlNetModelField"];
/**
* Control Weight
* @description The weight given to the ControlNet
@ -792,9 +792,8 @@ export type components = {
* Control Model
* @description control model used
* @default lllyasviel/sd-controlnet-canny
* @enum {string}
*/
control_model?: "lllyasviel/sd-controlnet-canny" | "lllyasviel/sd-controlnet-depth" | "lllyasviel/sd-controlnet-hed" | "lllyasviel/sd-controlnet-seg" | "lllyasviel/sd-controlnet-openpose" | "lllyasviel/sd-controlnet-scribble" | "lllyasviel/sd-controlnet-normal" | "lllyasviel/sd-controlnet-mlsd" | "lllyasviel/control_v11p_sd15_canny" | "lllyasviel/control_v11p_sd15_openpose" | "lllyasviel/control_v11p_sd15_seg" | "lllyasviel/control_v11f1p_sd15_depth" | "lllyasviel/control_v11p_sd15_normalbae" | "lllyasviel/control_v11p_sd15_scribble" | "lllyasviel/control_v11p_sd15_mlsd" | "lllyasviel/control_v11p_sd15_softedge" | "lllyasviel/control_v11p_sd15s2_lineart_anime" | "lllyasviel/control_v11p_sd15_lineart" | "lllyasviel/control_v11p_sd15_inpaint" | "lllyasviel/control_v11e_sd15_shuffle" | "lllyasviel/control_v11e_sd15_ip2p" | "lllyasviel/control_v11f1e_sd15_tile" | "thibaud/controlnet-sd21-openpose-diffusers" | "thibaud/controlnet-sd21-canny-diffusers" | "thibaud/controlnet-sd21-depth-diffusers" | "thibaud/controlnet-sd21-scribble-diffusers" | "thibaud/controlnet-sd21-hed-diffusers" | "thibaud/controlnet-sd21-zoedepth-diffusers" | "thibaud/controlnet-sd21-color-diffusers" | "thibaud/controlnet-sd21-openposev2-diffusers" | "thibaud/controlnet-sd21-lineart-diffusers" | "thibaud/controlnet-sd21-normalbae-diffusers" | "thibaud/controlnet-sd21-ade20k-diffusers" | "CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15" | "CrucibleAI/ControlNetMediaPipeFace";
control_model?: components["schemas"]["ControlNetModelField"];
/**
* Control Weight
* @description The weight given to the ControlNet
@ -838,6 +837,19 @@ export type components = {
model_format: components["schemas"]["ControlNetModelFormat"];
error?: components["schemas"]["ModelError"];
};
/**
* ControlNetModelField
* @description ControlNet model field
*/
ControlNetModelField: {
/**
* Model Name
* @description Name of the ControlNet model
*/
model_name: string;
/** @description Base model */
base_model: components["schemas"]["BaseModelType"];
};
/**
* ControlNetModelFormat
* @description An enumeration.
@ -1923,12 +1935,12 @@ export type components = {
* Width
* @description The width to resize to (px)
*/
width: number;
width?: number;
/**
* Height
* @description The height to resize to (px)
*/
height: number;
height?: number;
/**
* Resample Mode
* @description The resampling mode
@ -3911,13 +3923,15 @@ export type components = {
/**
* Width
* @description The width to resize to (px)
* @default 512
*/
width: number;
width?: number;
/**
* Height
* @description The height to resize to (px)
* @default 512
*/
height: number;
height?: number;
/**
* Mode
* @description The interpolation mode
@ -4605,18 +4619,18 @@ export type components = {
*/
image?: components["schemas"]["ImageField"];
};
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
};
responses: never;
parameters: never;

View File

@ -32,6 +32,8 @@ export type BaseModelType = components['schemas']['BaseModelType'];
export type MainModelField = components['schemas']['MainModelField'];
export type VAEModelField = components['schemas']['VAEModelField'];
export type LoRAModelField = components['schemas']['LoRAModelField'];
export type ControlNetModelField =
components['schemas']['ControlNetModelField'];
export type ModelsList = components['schemas']['ModelsList'];
export type ControlField = components['schemas']['ControlField'];

View File

@ -30,7 +30,7 @@ const invokeAIThumb = defineStyle((props) => {
const invokeAIMark = defineStyle((props) => {
return {
fontSize: 'xs',
fontSize: '2xs',
fontWeight: '500',
color: mode('base.700', 'base.400')(props),
mt: 2,