Rewrite controlnet to new model manager (#3665)

This commit is contained in:
Lincoln Stein 2023-07-15 08:24:06 -04:00 committed by GitHub
commit 07a2da40b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
58 changed files with 1095 additions and 625 deletions

View File

@ -9,6 +9,7 @@ from typing import Literal, Optional, Union, List, Dict
from PIL import Image from PIL import Image
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
from ...backend.model_management import BaseModelType, ModelType
from ..models.image import ImageField, ImageCategory, ResourceOrigin from ..models.image import ImageField, ImageCategory, ResourceOrigin
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
@ -105,9 +106,15 @@ CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control
# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])] # CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
class ControlNetModelField(BaseModel):
"""ControlNet model field"""
model_name: str = Field(description="Name of the ControlNet model")
base_model: BaseModelType = Field(description="Base model")
class ControlField(BaseModel): class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image") image: ImageField = Field(default=None, description="The control image")
control_model: Optional[str] = Field(default=None, description="The ControlNet model to use") control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet") # control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(default=0, ge=0, le=1, begin_step_percent: float = Field(default=0, ge=0, le=1,
@ -118,15 +125,15 @@ class ControlField(BaseModel):
# resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use") # resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@validator("control_weight") @validator("control_weight")
def abs_le_one(cls, v): def validate_control_weight(cls, v):
"""validate that all abs(values) are <=1""" """Validate that all control weights in the valid range"""
if isinstance(v, list): if isinstance(v, list):
for i in v: for i in v:
if abs(i) > 1: if i < -1 or i > 2:
raise ValueError('all abs(control_weight) must be <= 1') raise ValueError('Control weights must be within -1 to 2 range')
else: else:
if abs(v) > 1: if v < -1 or v > 2:
raise ValueError('abs(control_weight) must be <= 1') raise ValueError('Control weights must be within -1 to 2 range')
return v return v
class Config: class Config:
schema_extra = { schema_extra = {
@ -134,6 +141,7 @@ class ControlField(BaseModel):
"ui": { "ui": {
"type_hints": { "type_hints": {
"control_weight": "float", "control_weight": "float",
"control_model": "controlnet_model",
# "control_weight": "number", # "control_weight": "number",
} }
} }
@ -154,10 +162,10 @@ class ControlNetInvocation(BaseInvocation):
type: Literal["controlnet"] = "controlnet" type: Literal["controlnet"] = "controlnet"
# Inputs # Inputs
image: ImageField = Field(default=None, description="The control image") image: ImageField = Field(default=None, description="The control image")
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny", control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny",
description="control model used") description="control model used")
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet") control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
begin_step_percent: float = Field(default=0, ge=0, le=1, begin_step_percent: float = Field(default=0, ge=-1, le=2,
description="When the ControlNet is first applied (% of total steps)") description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1, end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)") description="When the ControlNet is last applied (% of total steps)")

View File

@ -1,5 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from contextlib import ExitStack
from typing import List, Literal, Optional, Union from typing import List, Literal, Optional, Union
import einops import einops
@ -11,6 +12,7 @@ from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.metadata import CoreMetadata from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models.base import ModelType
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
@ -71,16 +73,21 @@ def get_scheduler(
scheduler_name: str, scheduler_name: str,
) -> Scheduler: ) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get( scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
scheduler_name, SCHEDULER_MAP['ddim']) scheduler_name, SCHEDULER_MAP['ddim']
)
orig_scheduler_info = context.services.model_manager.get_model( orig_scheduler_info = context.services.model_manager.get_model(
**scheduler_info.dict()) **scheduler_info.dict()
)
with orig_scheduler_info as orig_scheduler: with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config scheduler_config = orig_scheduler.config
if "_backup" in scheduler_config: if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"] scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, ** scheduler_config = {
scheduler_extra_config, "_backup": scheduler_config} **scheduler_config,
**scheduler_extra_config,
"_backup": scheduler_config,
}
scheduler = scheduler_class.from_config(scheduler_config) scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py # hack copied over from generate.py
@ -137,8 +144,11 @@ class TextToLatentsInvocation(BaseInvocation):
# TODO: pass this an emitter method or something? or a session for dispatching? # TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress( def dispatch_progress(
self, context: InvocationContext, source_node_id: str, self,
intermediate_state: PipelineIntermediateState) -> None: context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
stable_diffusion_step_callback( stable_diffusion_step_callback(
context=context, context=context,
intermediate_state=intermediate_state, intermediate_state=intermediate_state,
@ -147,11 +157,16 @@ class TextToLatentsInvocation(BaseInvocation):
) )
def get_conditioning_data( def get_conditioning_data(
self, context: InvocationContext, scheduler) -> ConditioningData: self,
context: InvocationContext,
scheduler,
) -> ConditioningData:
c, extra_conditioning_info = context.services.latents.get( c, extra_conditioning_info = context.services.latents.get(
self.positive_conditioning.conditioning_name) self.positive_conditioning.conditioning_name
)
uc, _ = context.services.latents.get( uc, _ = context.services.latents.get(
self.negative_conditioning.conditioning_name) self.negative_conditioning.conditioning_name
)
conditioning_data = ConditioningData( conditioning_data = ConditioningData(
unconditioned_embeddings=uc, unconditioned_embeddings=uc,
@ -178,7 +193,10 @@ class TextToLatentsInvocation(BaseInvocation):
return conditioning_data return conditioning_data
def create_pipeline( def create_pipeline(
self, unet, scheduler) -> StableDiffusionGeneratorPipeline: self,
unet,
scheduler,
) -> StableDiffusionGeneratorPipeline:
# TODO: # TODO:
# configure_model_padding( # configure_model_padding(
# unet, # unet,
@ -213,6 +231,7 @@ class TextToLatentsInvocation(BaseInvocation):
model: StableDiffusionGeneratorPipeline, model: StableDiffusionGeneratorPipeline,
control_input: List[ControlField], control_input: List[ControlField],
latents_shape: List[int], latents_shape: List[int],
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True, do_classifier_free_guidance: bool = True,
) -> List[ControlNetData]: ) -> List[ControlNetData]:
@ -238,25 +257,19 @@ class TextToLatentsInvocation(BaseInvocation):
control_data = [] control_data = []
control_models = [] control_models = []
for control_info in control_list: for control_info in control_list:
# handle control models control_model = exit_stack.enter_context(
if ("," in control_info.control_model): context.services.model_manager.get_model(
control_model_split = control_info.control_model.split(",") model_name=control_info.control_model.model_name,
control_name = control_model_split[0] model_type=ModelType.ControlNet,
control_subfolder = control_model_split[1] base_model=control_info.control_model.base_model,
print("Using HF model subfolders") )
print(" control_name: ", control_name) )
print(" control_subfolder: ", control_subfolder)
control_model = ControlNetModel.from_pretrained(
control_name, subfolder=control_subfolder,
torch_dtype=model.unet.dtype).to(
model.device)
else:
control_model = ControlNetModel.from_pretrained(
control_info.control_model, torch_dtype=model.unet.dtype).to(model.device)
control_models.append(control_model) control_models.append(control_model)
control_image_field = control_info.image control_image_field = control_info.image
input_image = context.services.images.get_pil_image( input_image = context.services.images.get_pil_image(
control_image_field.image_name) control_image_field.image_name
)
# self.image.image_type, self.image.image_name # self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes # FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt? # and add in batch_size, num_images_per_prompt?
@ -278,7 +291,8 @@ class TextToLatentsInvocation(BaseInvocation):
weight=control_info.control_weight, weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent, begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent, end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,) control_mode=control_info.control_mode,
)
control_data.append(control_item) control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData] # MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data return control_data
@ -289,7 +303,8 @@ class TextToLatentsInvocation(BaseInvocation):
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get( graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id) context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id] source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState): def step_callback(state: PipelineIntermediateState):
@ -298,14 +313,17 @@ class TextToLatentsInvocation(BaseInvocation):
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"})) **lora.dict(exclude={"weight"})
)
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
unet_info = context.services.model_manager.get_model( unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict()) **self.unet.unet.dict()
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ )
with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet: unet_info as unet:
scheduler = get_scheduler( scheduler = get_scheduler(
@ -322,6 +340,7 @@ class TextToLatentsInvocation(BaseInvocation):
latents_shape=noise.shape, latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0)) # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True, do_classifier_free_guidance=True,
exit_stack=exit_stack,
) )
# TODO: Verify the noise is the right size # TODO: Verify the noise is the right size
@ -374,7 +393,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get( graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id) context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id] source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState): def step_callback(state: PipelineIntermediateState):
@ -383,14 +403,17 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"})) **lora.dict(exclude={"weight"})
)
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
unet_info = context.services.model_manager.get_model( unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict()) **self.unet.unet.dict()
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ )
with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet: unet_info as unet:
scheduler = get_scheduler( scheduler = get_scheduler(
@ -407,11 +430,13 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
latents_shape=noise.shape, latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0)) # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True, do_classifier_free_guidance=True,
exit_stack=exit_stack,
) )
# TODO: Verify the noise is the right size # TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like( initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=unet.device, dtype=latent.dtype) latent, device=unet.device, dtype=latent.dtype
)
timesteps, _ = pipeline.get_img2img_timesteps( timesteps, _ = pipeline.get_img2img_timesteps(
self.steps, self.steps,
@ -535,7 +560,8 @@ class ResizeLatentsInvocation(BaseInvocation):
resized_latents = torch.nn.functional.interpolate( resized_latents = torch.nn.functional.interpolate(
latents, size=(self.height // 8, self.width // 8), latents, size=(self.height // 8, self.width // 8),
mode=self.mode, antialias=self.antialias mode=self.mode, antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,) if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -569,7 +595,8 @@ class ScaleLatentsInvocation(BaseInvocation):
resized_latents = torch.nn.functional.interpolate( resized_latents = torch.nn.functional.interpolate(
latents, scale_factor=self.scale_factor, mode=self.mode, latents, scale_factor=self.scale_factor, mode=self.mode,
antialias=self.antialias antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,) if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@ -13,7 +13,11 @@ import { RootState } from 'app/store/store';
const moduleLog = log.child({ namespace: 'controlNet' }); const moduleLog = log.child({ namespace: 'controlNet' });
const predicate: AnyListenerPredicate<RootState> = (action, state) => { const predicate: AnyListenerPredicate<RootState> = (
action,
state,
prevState
) => {
const isActionMatched = const isActionMatched =
controlNetProcessorParamsChanged.match(action) || controlNetProcessorParamsChanged.match(action) ||
controlNetModelChanged.match(action) || controlNetModelChanged.match(action) ||
@ -25,6 +29,16 @@ const predicate: AnyListenerPredicate<RootState> = (action, state) => {
return false; return false;
} }
if (controlNetAutoConfigToggled.match(action)) {
// do not process if the user just disabled auto-config
if (
prevState.controlNet.controlNets[action.payload.controlNetId]
.shouldAutoConfig === true
) {
return false;
}
}
const { controlImage, processorType, shouldAutoConfig } = const { controlImage, processorType, shouldAutoConfig } =
state.controlNet.controlNets[action.payload.controlNetId]; state.controlNet.controlNets[action.payload.controlNetId];

View File

@ -10,6 +10,7 @@ import { zMainModel } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
const moduleLog = log.child({ module: 'models' }); const moduleLog = log.child({ module: 'models' });
@ -51,7 +52,14 @@ export const addModelSelectedListener = () => {
modelsCleared += 1; modelsCleared += 1;
} }
// TODO: handle incompatible controlnet; pending model manager support const { controlNets } = state.controlNet;
forEach(controlNets, (controlNet, controlNetId) => {
if (controlNet.model?.base_model !== base_model) {
dispatch(controlNetRemoved({ controlNetId }));
modelsCleared += 1;
}
});
if (modelsCleared > 0) { if (modelsCleared > 0) {
dispatch( dispatch(
addToast( addToast(

View File

@ -11,6 +11,7 @@ import {
import { forEach, some } from 'lodash-es'; import { forEach, some } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models'; import { modelsApi } from 'services/api/endpoints/models';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
const moduleLog = log.child({ module: 'models' }); const moduleLog = log.child({ module: 'models' });
@ -127,7 +128,22 @@ export const addModelsLoadedListener = () => {
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled, matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => { effect: async (action, { getState, dispatch }) => {
// ControlNet models loaded - need to remove missing ControlNets from state // ControlNet models loaded - need to remove missing ControlNets from state
// TODO: pending model manager controlnet support const controlNets = getState().controlNet.controlNets;
forEach(controlNets, (controlNet, controlNetId) => {
const isControlNetAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === controlNet?.model?.model_name &&
m?.base_model === controlNet?.model?.base_model
);
if (isControlNetAvailable) {
return;
}
dispatch(controlNetRemoved({ controlNetId }));
});
}, },
}); });
}; };

View File

@ -1,5 +1,5 @@
import { import {
CONTROLNET_MODELS, // CONTROLNET_MODELS,
CONTROLNET_PROCESSORS, CONTROLNET_PROCESSORS,
} from 'features/controlNet/store/constants'; } from 'features/controlNet/store/constants';
import { InvokeTabName } from 'features/ui/store/tabMap'; import { InvokeTabName } from 'features/ui/store/tabMap';
@ -128,7 +128,7 @@ export type AppConfig = {
canRestoreDeletedImagesFromBin: boolean; canRestoreDeletedImagesFromBin: boolean;
sd: { sd: {
defaultModel?: string; defaultModel?: string;
disabledControlNetModels: (keyof typeof CONTROLNET_MODELS)[]; disabledControlNetModels: string[];
disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[]; disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[];
iterations: { iterations: {
initial: number; initial: number;

View File

@ -170,12 +170,14 @@ const IAIDndImage = (props: IAIDndImageProps) => {
</> </>
)} )}
{!imageDTO && isUploadDisabled && noContentFallback} {!imageDTO && isUploadDisabled && noContentFallback}
{!isDropDisabled && (
<IAIDroppable <IAIDroppable
data={droppableData} data={droppableData}
disabled={isDropDisabled} disabled={isDropDisabled}
dropLabel={dropLabel} dropLabel={dropLabel}
/> />
{imageDTO && ( )}
{imageDTO && !isDragDisabled && (
<IAIDraggable <IAIDraggable
data={draggableData} data={draggableData}
disabled={isDragDisabled || !imageDTO} disabled={isDragDisabled || !imageDTO}

View File

@ -1,17 +1,25 @@
import { Tooltip } from '@chakra-ui/react'; import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { MultiSelect, MultiSelectProps } from '@mantine/core'; import { MultiSelect, MultiSelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice'; import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { useMantineMultiSelectStyles } from 'mantine-theme/hooks/useMantineMultiSelectStyles'; import { useMantineMultiSelectStyles } from 'mantine-theme/hooks/useMantineMultiSelectStyles';
import { KeyboardEvent, RefObject, memo, useCallback } from 'react'; import { KeyboardEvent, RefObject, memo, useCallback } from 'react';
type IAIMultiSelectProps = MultiSelectProps & { type IAIMultiSelectProps = Omit<MultiSelectProps, 'label'> & {
tooltip?: string; tooltip?: string;
inputRef?: RefObject<HTMLInputElement>; inputRef?: RefObject<HTMLInputElement>;
label?: string;
}; };
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
const { searchable = true, tooltip, inputRef, ...rest } = props; const {
searchable = true,
tooltip,
inputRef,
label,
disabled,
...rest
} = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const handleKeyDown = useCallback( const handleKeyDown = useCallback(
@ -37,7 +45,15 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
return ( return (
<Tooltip label={tooltip} placement="top" hasArrow isOpen={true}> <Tooltip label={tooltip} placement="top" hasArrow isOpen={true}>
<MultiSelect <MultiSelect
label={
label ? (
<FormControl isDisabled={disabled}>
<FormLabel>{label}</FormLabel>
</FormControl>
) : undefined
}
ref={inputRef} ref={inputRef}
disabled={disabled}
onKeyDown={handleKeyDown} onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp} onKeyUp={handleKeyUp}
searchable={searchable} searchable={searchable}

View File

@ -1,4 +1,4 @@
import { Tooltip } from '@chakra-ui/react'; import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core'; import { Select, SelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice'; import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
@ -11,13 +11,22 @@ export type IAISelectDataType = {
tooltip?: string; tooltip?: string;
}; };
type IAISelectProps = SelectProps & { type IAISelectProps = Omit<SelectProps, 'label'> & {
tooltip?: string; tooltip?: string;
label?: string;
inputRef?: RefObject<HTMLInputElement>; inputRef?: RefObject<HTMLInputElement>;
}; };
const IAIMantineSearchableSelect = (props: IAISelectProps) => { const IAIMantineSearchableSelect = (props: IAISelectProps) => {
const { searchable = true, tooltip, inputRef, onChange, ...rest } = props; const {
searchable = true,
tooltip,
inputRef,
onChange,
label,
disabled,
...rest
} = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [searchValue, setSearchValue] = useState(''); const [searchValue, setSearchValue] = useState('');
@ -61,6 +70,14 @@ const IAIMantineSearchableSelect = (props: IAISelectProps) => {
<Tooltip label={tooltip} placement="top" hasArrow> <Tooltip label={tooltip} placement="top" hasArrow>
<Select <Select
ref={inputRef} ref={inputRef}
label={
label ? (
<FormControl isDisabled={disabled}>
<FormLabel>{label}</FormLabel>
</FormControl>
) : undefined
}
disabled={disabled}
searchValue={searchValue} searchValue={searchValue}
onSearchChange={setSearchValue} onSearchChange={setSearchValue}
onChange={handleChange} onChange={handleChange}

View File

@ -1,4 +1,4 @@
import { Tooltip } from '@chakra-ui/react'; import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core'; import { Select, SelectProps } from '@mantine/core';
import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles'; import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles';
import { RefObject, memo } from 'react'; import { RefObject, memo } from 'react';
@ -9,19 +9,32 @@ export type IAISelectDataType = {
tooltip?: string; tooltip?: string;
}; };
type IAISelectProps = SelectProps & { type IAISelectProps = Omit<SelectProps, 'label'> & {
tooltip?: string; tooltip?: string;
inputRef?: RefObject<HTMLInputElement>; inputRef?: RefObject<HTMLInputElement>;
label?: string;
}; };
const IAIMantineSelect = (props: IAISelectProps) => { const IAIMantineSelect = (props: IAISelectProps) => {
const { tooltip, inputRef, ...rest } = props; const { tooltip, inputRef, label, disabled, ...rest } = props;
const styles = useMantineSelectStyles(); const styles = useMantineSelectStyles();
return ( return (
<Tooltip label={tooltip} placement="top" hasArrow> <Tooltip label={tooltip} placement="top" hasArrow>
<Select ref={inputRef} styles={styles} {...rest} /> <Select
label={
label ? (
<FormControl isDisabled={disabled}>
<FormLabel>{label}</FormLabel>
</FormControl>
) : undefined
}
disabled={disabled}
ref={inputRef}
styles={styles}
{...rest}
/>
</Tooltip> </Tooltip>
); );
}; };

View File

@ -43,11 +43,6 @@ import { useTranslation } from 'react-i18next';
import { BiReset } from 'react-icons/bi'; import { BiReset } from 'react-icons/bi';
import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton'; import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton';
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
mt: 1.5,
fontSize: '2xs',
};
export type IAIFullSliderProps = { export type IAIFullSliderProps = {
label?: string; label?: string;
value: number; value: number;
@ -207,7 +202,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
{...sliderFormControlProps} {...sliderFormControlProps}
> >
{label && ( {label && (
<FormLabel {...sliderFormLabelProps} mb={-1}> <FormLabel sx={withInput ? { mb: -1.5 } : {}} {...sliderFormLabelProps}>
{label} {label}
</FormLabel> </FormLabel>
)} )}
@ -233,7 +228,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{ sx={{
insetInlineStart: '0 !important', insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important', insetInlineEnd: 'unset !important',
...SLIDER_MARK_STYLES,
}} }}
{...sliderMarkProps} {...sliderMarkProps}
> >
@ -244,7 +238,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{ sx={{
insetInlineStart: 'unset !important', insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important', insetInlineEnd: '0 !important',
...SLIDER_MARK_STYLES,
}} }}
{...sliderMarkProps} {...sliderMarkProps}
> >
@ -263,7 +256,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{ sx={{
insetInlineStart: '0 !important', insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important', insetInlineEnd: 'unset !important',
...SLIDER_MARK_STYLES,
}} }}
{...sliderMarkProps} {...sliderMarkProps}
> >
@ -278,7 +270,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{ sx={{
insetInlineStart: 'unset !important', insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important', insetInlineEnd: '0 !important',
...SLIDER_MARK_STYLES,
}} }}
{...sliderMarkProps} {...sliderMarkProps}
> >
@ -291,7 +282,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
key={m} key={m}
value={m} value={m}
sx={{ sx={{
...SLIDER_MARK_STYLES, transform: 'translateX(-50%)',
}} }}
{...sliderMarkProps} {...sliderMarkProps}
> >

View File

@ -5,6 +5,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { validateSeedWeights } from 'common/util/seedWeightPairs'; import { validateSeedWeights } from 'common/util/seedWeightPairs';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { modelsApi } from '../../services/api/endpoints/models'; import { modelsApi } from '../../services/api/endpoints/models';
import { forEach } from 'lodash-es';
const readinessSelector = createSelector( const readinessSelector = createSelector(
[stateSelector, activeTabNameSelector], [stateSelector, activeTabNameSelector],
@ -52,6 +53,13 @@ const readinessSelector = createSelector(
reasonsWhyNotReady.push('Seed-Weights badly formatted.'); reasonsWhyNotReady.push('Seed-Weights badly formatted.');
} }
forEach(state.controlNet.controlNets, (controlNet, id) => {
if (!controlNet.model) {
isReady = false;
reasonsWhyNotReady.push('ControlNet ${id} has no model selected.');
}
});
// All good // All good
return { isReady, reasonsWhyNotReady }; return { isReady, reasonsWhyNotReady };
}, },

View File

@ -1,10 +1,9 @@
import { Box, ChakraProps, Flex, useColorMode } from '@chakra-ui/react'; import { Box, Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { FaCopy, FaTrash } from 'react-icons/fa'; import { FaCopy, FaTrash } from 'react-icons/fa';
import { import {
ControlNetConfig, controlNetDuplicated,
controlNetAdded,
controlNetRemoved, controlNetRemoved,
controlNetToggled, controlNetToggled,
} from '../store/controlNetSlice'; } from '../store/controlNetSlice';
@ -12,6 +11,9 @@ import ParamControlNetModel from './parameters/ParamControlNetModel';
import ParamControlNetWeight from './parameters/ParamControlNetWeight'; import ParamControlNetWeight from './parameters/ParamControlNetWeight';
import { ChevronUpIcon } from '@chakra-ui/icons'; import { ChevronUpIcon } from '@chakra-ui/icons';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { useToggle } from 'react-use'; import { useToggle } from 'react-use';
@ -22,41 +24,41 @@ import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd'; import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode'; import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect'; import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
import { mode } from 'theme/util/mode';
const expandedControlImageSx: ChakraProps['sx'] = { maxH: 96 };
type ControlNetProps = { type ControlNetProps = {
controlNet: ControlNetConfig; controlNetId: string;
}; };
const ControlNet = (props: ControlNetProps) => { const ControlNet = (props: ControlNetProps) => {
const { const { controlNetId } = props;
controlNetId,
isEnabled,
model,
weight,
beginStepPct,
endStepPct,
controlMode,
controlImage,
processedControlImage,
processorNode,
processorType,
shouldAutoConfig,
} = props.controlNet;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const selector = createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, shouldAutoConfig } =
controlNet.controlNets[controlNetId];
return { isEnabled, shouldAutoConfig };
},
defaultSelectorOptions
);
const { isEnabled, shouldAutoConfig } = useAppSelector(selector);
const [isExpanded, toggleIsExpanded] = useToggle(false); const [isExpanded, toggleIsExpanded] = useToggle(false);
const { colorMode } = useColorMode();
const handleDelete = useCallback(() => { const handleDelete = useCallback(() => {
dispatch(controlNetRemoved({ controlNetId })); dispatch(controlNetRemoved({ controlNetId }));
}, [controlNetId, dispatch]); }, [controlNetId, dispatch]);
const handleDuplicate = useCallback(() => { const handleDuplicate = useCallback(() => {
dispatch( dispatch(
controlNetAdded({ controlNetId: uuidv4(), controlNet: props.controlNet }) controlNetDuplicated({
sourceControlNetId: controlNetId,
newControlNetId: uuidv4(),
})
); );
}, [dispatch, props.controlNet]); }, [controlNetId, dispatch]);
const handleToggleIsEnabled = useCallback(() => { const handleToggleIsEnabled = useCallback(() => {
dispatch(controlNetToggled({ controlNetId })); dispatch(controlNetToggled({ controlNetId }));
@ -68,15 +70,18 @@ const ControlNet = (props: ControlNetProps) => {
flexDir: 'column', flexDir: 'column',
gap: 2, gap: 2,
p: 3, p: 3,
bg: mode('base.200', 'base.850')(colorMode),
borderRadius: 'base', borderRadius: 'base',
position: 'relative', position: 'relative',
bg: 'base.200',
_dark: {
bg: 'base.850',
},
}} }}
> >
<Flex sx={{ gap: 2 }}> <Flex sx={{ gap: 2, alignItems: 'center' }}>
<IAISwitch <IAISwitch
tooltip="Toggle" tooltip={'Toggle this ControlNet'}
aria-label="Toggle" aria-label={'Toggle this ControlNet'}
isChecked={isEnabled} isChecked={isEnabled}
onChange={handleToggleIsEnabled} onChange={handleToggleIsEnabled}
/> />
@ -90,7 +95,7 @@ const ControlNet = (props: ControlNetProps) => {
transitionDuration: '0.1s', transitionDuration: '0.1s',
}} }}
> >
<ParamControlNetModel controlNetId={controlNetId} model={model} /> <ParamControlNetModel controlNetId={controlNetId} />
</Box> </Box>
<IAIIconButton <IAIIconButton
size="sm" size="sm"
@ -109,21 +114,26 @@ const ControlNet = (props: ControlNetProps) => {
/> />
<IAIIconButton <IAIIconButton
size="sm" size="sm"
aria-label="Show All Options" tooltip={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
aria-label={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
onClick={toggleIsExpanded} onClick={toggleIsExpanded}
variant="link" variant="link"
icon={ icon={
<ChevronUpIcon <ChevronUpIcon
sx={{ sx={{
boxSize: 4, boxSize: 4,
color: mode('base.700', 'base.300')(colorMode), color: 'base.700',
transform: isExpanded ? 'rotate(0deg)' : 'rotate(180deg)', transform: isExpanded ? 'rotate(0deg)' : 'rotate(180deg)',
transitionProperty: 'common', transitionProperty: 'common',
transitionDuration: 'normal', transitionDuration: 'normal',
_dark: {
color: 'base.300',
},
}} }}
/> />
} }
/> />
{!shouldAutoConfig && ( {!shouldAutoConfig && (
<Box <Box
sx={{ sx={{
@ -131,21 +141,23 @@ const ControlNet = (props: ControlNetProps) => {
w: 1.5, w: 1.5,
h: 1.5, h: 1.5,
borderRadius: 'full', borderRadius: 'full',
bg: mode('error.700', 'error.200')(colorMode),
top: 4, top: 4,
insetInlineEnd: 4, insetInlineEnd: 4,
bg: 'accent.700',
_dark: {
bg: 'accent.400',
},
}} }}
/> />
)} )}
</Flex> </Flex>
{isEnabled && (
<>
<Flex sx={{ w: 'full', flexDirection: 'column' }}> <Flex sx={{ w: 'full', flexDirection: 'column' }}>
<Flex sx={{ gap: 4, w: 'full' }}> <Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}>
<Flex <Flex
sx={{ sx={{
flexDir: 'column', flexDir: 'column',
gap: 3, gap: 3,
h: 28,
w: 'full', w: 'full',
paddingInlineStart: 1, paddingInlineStart: 1,
paddingInlineEnd: isExpanded ? 1 : 0, paddingInlineEnd: isExpanded ? 1 : 0,
@ -153,63 +165,35 @@ const ControlNet = (props: ControlNetProps) => {
justifyContent: 'space-between', justifyContent: 'space-between',
}} }}
> >
<ParamControlNetWeight <ParamControlNetWeight controlNetId={controlNetId} />
controlNetId={controlNetId} <ParamControlNetBeginEnd controlNetId={controlNetId} />
weight={weight}
mini={!isExpanded}
/>
<ParamControlNetBeginEnd
controlNetId={controlNetId}
beginStepPct={beginStepPct}
endStepPct={endStepPct}
mini={!isExpanded}
/>
</Flex> </Flex>
{!isExpanded && ( {!isExpanded && (
<Flex <Flex
sx={{ sx={{
alignItems: 'center', alignItems: 'center',
justifyContent: 'center', justifyContent: 'center',
h: 24, h: 28,
w: 24, w: 28,
aspectRatio: '1/1', aspectRatio: '1/1',
mt: 3,
}} }}
> >
<ControlNetImagePreview <ControlNetImagePreview controlNetId={controlNetId} height={28} />
controlNet={props.controlNet}
height={24}
/>
</Flex> </Flex>
)} )}
</Flex> </Flex>
<ParamControlNetControlMode <Box mt={2}>
controlNetId={controlNetId} <ParamControlNetControlMode controlNetId={controlNetId} />
controlMode={controlMode} </Box>
/> <ParamControlNetProcessorSelect controlNetId={controlNetId} />
</Flex> </Flex>
{isExpanded && ( {isExpanded && (
<> <>
<Box mt={2}> <ControlNetImagePreview controlNetId={controlNetId} height="392px" />
<ControlNetImagePreview <ParamControlNetShouldAutoConfig controlNetId={controlNetId} />
controlNet={props.controlNet} <ControlNetProcessorComponent controlNetId={controlNetId} />
height={96}
/>
</Box>
<ParamControlNetProcessorSelect
controlNetId={controlNetId}
processorNode={processorNode}
/>
<ControlNetProcessorComponent
controlNetId={controlNetId}
processorNode={processorNode}
/>
<ParamControlNetShouldAutoConfig
controlNetId={controlNetId}
shouldAutoConfig={shouldAutoConfig}
/>
</>
)}
</> </>
)} )}
</Flex> </Flex>

View File

@ -5,42 +5,57 @@ import {
TypesafeDraggableData, TypesafeDraggableData,
TypesafeDroppableData, TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd'; } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImage from 'common/components/IAIDndImage';
import { memo, useCallback, useMemo, useState } from 'react'; import { memo, useCallback, useMemo, useState } from 'react';
import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/thunks/image'; import { PostUploadAction } from 'services/api/thunks/image';
import { import { controlNetImageChanged } from '../store/controlNetSlice';
ControlNetConfig,
controlNetImageChanged,
controlNetSelector,
} from '../store/controlNetSlice';
const selector = createSelector(
controlNetSelector,
(controlNet) => {
const { pendingControlImages } = controlNet;
return { pendingControlImages };
},
defaultSelectorOptions
);
type Props = { type Props = {
controlNet: ControlNetConfig; controlNetId: string;
height: SystemStyleObject['h']; height: SystemStyleObject['h'];
}; };
const ControlNetImagePreview = (props: Props) => { const ControlNetImagePreview = (props: Props) => {
const { height } = props; const { height, controlNetId } = props;
const {
controlNetId,
controlImage: controlImageName,
processedControlImage: processedControlImageName,
processorType,
} = props.controlNet;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { pendingControlImages } = useAppSelector(selector);
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { pendingControlImages } = controlNet;
const {
controlImage,
processedControlImage,
processorType,
isEnabled,
} = controlNet.controlNets[controlNetId];
return {
controlImageName: controlImage,
processedControlImageName: processedControlImage,
processorType,
isEnabled,
pendingControlImages,
};
},
defaultSelectorOptions
),
[controlNetId]
);
const {
controlImageName,
processedControlImageName,
processorType,
pendingControlImages,
isEnabled,
} = useAppSelector(selector);
const [isMouseOverImage, setIsMouseOverImage] = useState(false); const [isMouseOverImage, setIsMouseOverImage] = useState(false);
@ -110,13 +125,15 @@ const ControlNetImagePreview = (props: Props) => {
h: height, h: height,
alignItems: 'center', alignItems: 'center',
justifyContent: 'center', justifyContent: 'center',
pointerEvents: isEnabled ? 'auto' : 'none',
opacity: isEnabled ? 1 : 0.5,
}} }}
> >
<IAIDndImage <IAIDndImage
draggableData={draggableData} draggableData={draggableData}
droppableData={droppableData} droppableData={droppableData}
imageDTO={controlImage} imageDTO={controlImage}
isDropDisabled={shouldShowProcessedImage} isDropDisabled={shouldShowProcessedImage || !isEnabled}
onClickReset={handleResetControlImage} onClickReset={handleResetControlImage}
postUploadAction={postUploadAction} postUploadAction={postUploadAction}
resetTooltip="Reset Control Image" resetTooltip="Reset Control Image"
@ -140,6 +157,7 @@ const ControlNetImagePreview = (props: Props) => {
droppableData={droppableData} droppableData={droppableData}
imageDTO={processedControlImage} imageDTO={processedControlImage}
isUploadDisabled={true} isUploadDisabled={true}
isDropDisabled={!isEnabled}
onClickReset={handleResetControlImage} onClickReset={handleResetControlImage}
resetTooltip="Reset Control Image" resetTooltip="Reset Control Image"
withResetIcon={Boolean(controlImage)} withResetIcon={Boolean(controlImage)}

View File

@ -1,10 +1,13 @@
import { memo } from 'react'; import { createSelector } from '@reduxjs/toolkit';
import { RequiredControlNetProcessorNode } from '../store/types'; import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { memo, useMemo } from 'react';
import CannyProcessor from './processors/CannyProcessor'; import CannyProcessor from './processors/CannyProcessor';
import HedProcessor from './processors/HedProcessor';
import LineartProcessor from './processors/LineartProcessor';
import LineartAnimeProcessor from './processors/LineartAnimeProcessor';
import ContentShuffleProcessor from './processors/ContentShuffleProcessor'; import ContentShuffleProcessor from './processors/ContentShuffleProcessor';
import HedProcessor from './processors/HedProcessor';
import LineartAnimeProcessor from './processors/LineartAnimeProcessor';
import LineartProcessor from './processors/LineartProcessor';
import MediapipeFaceProcessor from './processors/MediapipeFaceProcessor'; import MediapipeFaceProcessor from './processors/MediapipeFaceProcessor';
import MidasDepthProcessor from './processors/MidasDepthProcessor'; import MidasDepthProcessor from './processors/MidasDepthProcessor';
import MlsdImageProcessor from './processors/MlsdImageProcessor'; import MlsdImageProcessor from './processors/MlsdImageProcessor';
@ -15,23 +18,45 @@ import ZoeDepthProcessor from './processors/ZoeDepthProcessor';
export type ControlNetProcessorProps = { export type ControlNetProcessorProps = {
controlNetId: string; controlNetId: string;
processorNode: RequiredControlNetProcessorNode;
}; };
const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => { const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
const { controlNetId, processorNode } = props; const { controlNetId } = props;
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, processorNode } =
controlNet.controlNets[controlNetId];
return { isEnabled, processorNode };
},
defaultSelectorOptions
),
[controlNetId]
);
const { isEnabled, processorNode } = useAppSelector(selector);
if (processorNode.type === 'canny_image_processor') { if (processorNode.type === 'canny_image_processor') {
return ( return (
<CannyProcessor <CannyProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
if (processorNode.type === 'hed_image_processor') { if (processorNode.type === 'hed_image_processor') {
return ( return (
<HedProcessor controlNetId={controlNetId} processorNode={processorNode} /> <HedProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
); );
} }
@ -40,6 +65,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<LineartProcessor <LineartProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -49,6 +75,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<ContentShuffleProcessor <ContentShuffleProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -58,6 +85,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<LineartAnimeProcessor <LineartAnimeProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -67,6 +95,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<MediapipeFaceProcessor <MediapipeFaceProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -76,6 +105,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<MidasDepthProcessor <MidasDepthProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -85,6 +115,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<MlsdImageProcessor <MlsdImageProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -94,6 +125,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<NormalBaeProcessor <NormalBaeProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -103,6 +135,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<OpenposeProcessor <OpenposeProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -112,6 +145,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<PidiProcessor <PidiProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -121,6 +155,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<ZoeDepthProcessor <ZoeDepthProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }

View File

@ -1,18 +1,36 @@
import { useAppDispatch } from 'app/store/storeHooks'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice'; import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react'; import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback, useMemo } from 'react';
type Props = { type Props = {
controlNetId: string; controlNetId: string;
shouldAutoConfig: boolean;
}; };
const ParamControlNetShouldAutoConfig = (props: Props) => { const ParamControlNetShouldAutoConfig = (props: Props) => {
const { controlNetId, shouldAutoConfig } = props; const { controlNetId } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke(); const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, shouldAutoConfig } =
controlNet.controlNets[controlNetId];
return { isEnabled, shouldAutoConfig };
},
defaultSelectorOptions
),
[controlNetId]
);
const { isEnabled, shouldAutoConfig } = useAppSelector(selector);
const isBusy = useAppSelector(selectIsBusy);
const handleShouldAutoConfigChanged = useCallback(() => { const handleShouldAutoConfigChanged = useCallback(() => {
dispatch(controlNetAutoConfigToggled({ controlNetId })); dispatch(controlNetAutoConfigToggled({ controlNetId }));
}, [controlNetId, dispatch]); }, [controlNetId, dispatch]);
@ -23,7 +41,7 @@ const ParamControlNetShouldAutoConfig = (props: Props) => {
aria-label="Auto configure processor" aria-label="Auto configure processor"
isChecked={shouldAutoConfig} isChecked={shouldAutoConfig}
onChange={handleShouldAutoConfigChanged} onChange={handleShouldAutoConfigChanged}
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
); );
}; };

View File

@ -1,5 +1,4 @@
import { import {
ChakraProps,
FormControl, FormControl,
FormLabel, FormLabel,
HStack, HStack,
@ -10,34 +9,41 @@ import {
RangeSliderTrack, RangeSliderTrack,
Tooltip, Tooltip,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { import {
controlNetBeginStepPctChanged, controlNetBeginStepPctChanged,
controlNetEndStepPctChanged, controlNetEndStepPctChanged,
} from 'features/controlNet/store/controlNetSlice'; } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
mt: 1.5,
fontSize: '2xs',
fontWeight: '500',
color: 'base.400',
};
type Props = { type Props = {
controlNetId: string; controlNetId: string;
beginStepPct: number;
endStepPct: number;
mini?: boolean;
}; };
const formatPct = (v: number) => `${Math.round(v * 100)}%`; const formatPct = (v: number) => `${Math.round(v * 100)}%`;
const ParamControlNetBeginEnd = (props: Props) => { const ParamControlNetBeginEnd = (props: Props) => {
const { controlNetId, beginStepPct, mini = false, endStepPct } = props; const { controlNetId } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { beginStepPct, endStepPct, isEnabled } =
controlNet.controlNets[controlNetId];
return { beginStepPct, endStepPct, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { beginStepPct, endStepPct, isEnabled } = useAppSelector(selector);
const handleStepPctChanged = useCallback( const handleStepPctChanged = useCallback(
(v: number[]) => { (v: number[]) => {
@ -55,7 +61,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
}, [controlNetId, dispatch]); }, [controlNetId, dispatch]);
return ( return (
<FormControl> <FormControl isDisabled={!isEnabled}>
<FormLabel>Begin / End Step Percentage</FormLabel> <FormLabel>Begin / End Step Percentage</FormLabel>
<HStack w="100%" gap={2} alignItems="center"> <HStack w="100%" gap={2} alignItems="center">
<RangeSlider <RangeSlider
@ -66,6 +72,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
max={1} max={1}
step={0.01} step={0.01}
minStepsBetweenThumbs={5} minStepsBetweenThumbs={5}
isDisabled={!isEnabled}
> >
<RangeSliderTrack> <RangeSliderTrack>
<RangeSliderFilledTrack /> <RangeSliderFilledTrack />
@ -76,14 +83,11 @@ const ParamControlNetBeginEnd = (props: Props) => {
<Tooltip label={formatPct(endStepPct)} placement="top" hasArrow> <Tooltip label={formatPct(endStepPct)} placement="top" hasArrow>
<RangeSliderThumb index={1} /> <RangeSliderThumb index={1} />
</Tooltip> </Tooltip>
{!mini && (
<>
<RangeSliderMark <RangeSliderMark
value={0} value={0}
sx={{ sx={{
insetInlineStart: '0 !important', insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important', insetInlineEnd: 'unset !important',
...SLIDER_MARK_STYLES,
}} }}
> >
0% 0%
@ -91,7 +95,8 @@ const ParamControlNetBeginEnd = (props: Props) => {
<RangeSliderMark <RangeSliderMark
value={0.5} value={0.5}
sx={{ sx={{
...SLIDER_MARK_STYLES, insetInlineStart: '50% !important',
transform: 'translateX(-50%)',
}} }}
> >
50% 50%
@ -101,13 +106,10 @@ const ParamControlNetBeginEnd = (props: Props) => {
sx={{ sx={{
insetInlineStart: 'unset !important', insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important', insetInlineEnd: '0 !important',
...SLIDER_MARK_STYLES,
}} }}
> >
100% 100%
</RangeSliderMark> </RangeSliderMark>
</>
)}
</RangeSlider> </RangeSlider>
</HStack> </HStack>
</FormControl> </FormControl>

View File

@ -1,15 +1,17 @@
import { useAppDispatch } from 'app/store/storeHooks'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { import {
ControlModes, ControlModes,
controlNetControlModeChanged, controlNetControlModeChanged,
} from 'features/controlNet/store/controlNetSlice'; } from 'features/controlNet/store/controlNetSlice';
import { useCallback } from 'react'; import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
type ParamControlNetControlModeProps = { type ParamControlNetControlModeProps = {
controlNetId: string; controlNetId: string;
controlMode: string;
}; };
const CONTROL_MODE_DATA = [ const CONTROL_MODE_DATA = [
@ -22,8 +24,23 @@ const CONTROL_MODE_DATA = [
export default function ParamControlNetControlMode( export default function ParamControlNetControlMode(
props: ParamControlNetControlModeProps props: ParamControlNetControlModeProps
) { ) {
const { controlNetId, controlMode = false } = props; const { controlNetId } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { controlMode, isEnabled } =
controlNet.controlNets[controlNetId];
return { controlMode, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { controlMode, isEnabled } = useAppSelector(selector);
const { t } = useTranslation(); const { t } = useTranslation();
@ -36,7 +53,8 @@ export default function ParamControlNetControlMode(
return ( return (
<IAIMantineSelect <IAIMantineSelect
label={t('parameters.controlNetControlMode')} disabled={!isEnabled}
label="Control Mode"
data={CONTROL_MODE_DATA} data={CONTROL_MODE_DATA}
value={String(controlMode)} value={String(controlMode)}
onChange={handleControlModeChange} onChange={handleControlModeChange}

View File

@ -1,28 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { controlNetToggled } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
type ParamControlNetIsEnabledProps = {
controlNetId: string;
isEnabled: boolean;
};
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
const { controlNetId, isEnabled } = props;
const dispatch = useAppDispatch();
const handleIsEnabledChanged = useCallback(() => {
dispatch(controlNetToggled({ controlNetId }));
}, [dispatch, controlNetId]);
return (
<IAISwitch
label="Enabled"
isChecked={isEnabled}
onChange={handleIsEnabledChanged}
/>
);
};
export default memo(ParamControlNetIsEnabled);

View File

@ -1,36 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAIFullCheckbox from 'common/components/IAIFullCheckbox';
import IAISwitch from 'common/components/IAISwitch';
import {
controlNetToggled,
isControlNetImagePreprocessedToggled,
} from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
type ParamControlNetIsEnabledProps = {
controlNetId: string;
isControlImageProcessed: boolean;
};
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
const { controlNetId, isControlImageProcessed } = props;
const dispatch = useAppDispatch();
const handleIsControlImageProcessedToggled = useCallback(() => {
dispatch(
isControlNetImagePreprocessedToggled({
controlNetId,
})
);
}, [controlNetId, dispatch]);
return (
<IAISwitch
label="Preprocess"
isChecked={isControlImageProcessed}
onChange={handleIsControlImageProcessedToggled}
/>
);
};
export default memo(ParamControlNetIsEnabled);

View File

@ -1,59 +1,118 @@
import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSearchableSelect, { import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
IAISelectDataType, import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
} from 'common/components/IAIMantineSearchableSelect'; import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import {
CONTROLNET_MODELS,
ControlNetModelName,
} from 'features/controlNet/store/constants';
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice'; import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
import { configSelector } from 'features/system/store/configSelectors'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { map } from 'lodash-es'; import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
import { memo, useCallback } from 'react'; import { selectIsBusy } from 'features/system/store/systemSelectors';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
type ParamControlNetModelProps = { type ParamControlNetModelProps = {
controlNetId: string; controlNetId: string;
model: ControlNetModelName;
}; };
const selector = createSelector(configSelector, (config) => { const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const controlNetModels: IAISelectDataType[] = map(CONTROLNET_MODELS, (m) => ({ const { controlNetId } = props;
label: m.label, const dispatch = useAppDispatch();
value: m.type, const isBusy = useAppSelector(selectIsBusy);
})).filter(
(d) => const selector = useMemo(
!config.sd.disabledControlNetModels.includes( () =>
d.value as ControlNetModelName createSelector(
) stateSelector,
({ generation, controlNet }) => {
const { model } = generation;
const controlNetModel = controlNet.controlNets[controlNetId]?.model;
const isEnabled = controlNet.controlNets[controlNetId]?.isEnabled;
return { mainModel: model, controlNetModel, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
); );
return controlNetModels; const { mainModel, controlNetModel, isEnabled } = useAppSelector(selector);
});
const ParamControlNetModel = (props: ParamControlNetModelProps) => { const { data: controlNetModels } = useGetControlNetModelsQuery();
const { controlNetId, model } = props;
const controlNetModels = useAppSelector(selector); const data = useMemo(() => {
const dispatch = useAppDispatch(); if (!controlNetModels) {
const isReady = useIsReadyToInvoke(); return [];
}
const data: SelectItem[] = [];
forEach(controlNetModels.entities, (model, id) => {
if (!model) {
return;
}
const disabled = model?.base_model !== mainModel?.base_model;
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
disabled,
tooltip: disabled
? `Incompatible base model: ${model.base_model}`
: undefined,
});
});
return data;
}, [controlNetModels, mainModel?.base_model]);
// grab the full model entity from the RTK Query cache
const selectedModel = useMemo(
() =>
controlNetModels?.entities[
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}`
] ?? null,
[
controlNetModel?.base_model,
controlNetModel?.model_name,
controlNetModels?.entities,
]
);
const handleModelChanged = useCallback( const handleModelChanged = useCallback(
(val: string | null) => { (v: string | null) => {
// TODO: do not cast if (!v) {
const model = val as ControlNetModelName; return;
dispatch(controlNetModelChanged({ controlNetId, model })); }
const newControlNetModel = modelIdToControlNetModelParam(v);
if (!newControlNetModel) {
return;
}
dispatch(
controlNetModelChanged({ controlNetId, model: newControlNetModel })
);
}, },
[controlNetId, dispatch] [controlNetId, dispatch]
); );
return ( return (
<IAIMantineSearchableSelect <IAIMantineSearchableSelect
data={controlNetModels} itemComponent={IAIMantineSelectItemWithTooltip}
value={model} data={data}
error={
!selectedModel || mainModel?.base_model !== selectedModel.base_model
}
placeholder="Select a model"
value={selectedModel?.id ?? null}
onChange={handleModelChanged} onChange={handleModelChanged}
disabled={!isReady} disabled={isBusy || !isEnabled}
tooltip={model} tooltip={selectedModel?.description}
/> />
); );
}; };

View File

@ -1,24 +1,22 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect, { import IAIMantineSearchableSelect, {
IAISelectDataType, IAISelectDataType,
} from 'common/components/IAIMantineSearchableSelect'; } from 'common/components/IAIMantineSearchableSelect';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { configSelector } from 'features/system/store/configSelectors'; import { configSelector } from 'features/system/store/configSelectors';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { memo, useCallback } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { CONTROLNET_PROCESSORS } from '../../store/constants'; import { CONTROLNET_PROCESSORS } from '../../store/constants';
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice'; import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
import { import { ControlNetProcessorType } from '../../store/types';
ControlNetProcessorNode, import { FormControl, FormLabel } from '@chakra-ui/react';
ControlNetProcessorType,
} from '../../store/types';
type ParamControlNetProcessorSelectProps = { type ParamControlNetProcessorSelectProps = {
controlNetId: string; controlNetId: string;
processorNode: ControlNetProcessorNode;
}; };
const selector = createSelector( const selector = createSelector(
@ -54,10 +52,24 @@ const selector = createSelector(
const ParamControlNetProcessorSelect = ( const ParamControlNetProcessorSelect = (
props: ParamControlNetProcessorSelectProps props: ParamControlNetProcessorSelectProps
) => { ) => {
const { controlNetId, processorNode } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke(); const { controlNetId } = props;
const processorNodeSelector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, processorNode } =
controlNet.controlNets[controlNetId];
return { isEnabled, processorNode };
},
defaultSelectorOptions
),
[controlNetId]
);
const isBusy = useAppSelector(selectIsBusy);
const controlNetProcessors = useAppSelector(selector); const controlNetProcessors = useAppSelector(selector);
const { isEnabled, processorNode } = useAppSelector(processorNodeSelector);
const handleProcessorTypeChanged = useCallback( const handleProcessorTypeChanged = useCallback(
(v: string | null) => { (v: string | null) => {
@ -77,7 +89,7 @@ const ParamControlNetProcessorSelect = (
value={processorNode.type ?? 'canny_image_processor'} value={processorNode.type ?? 'canny_image_processor'}
data={controlNetProcessors} data={controlNetProcessors}
onChange={handleProcessorTypeChanged} onChange={handleProcessorTypeChanged}
disabled={!isReady} disabled={isBusy || !isEnabled}
/> />
); );
}; };

View File

@ -1,18 +1,32 @@
import { useAppDispatch } from 'app/store/storeHooks'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { controlNetWeightChanged } from 'features/controlNet/store/controlNetSlice'; import { controlNetWeightChanged } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback, useMemo } from 'react';
type ParamControlNetWeightProps = { type ParamControlNetWeightProps = {
controlNetId: string; controlNetId: string;
weight: number;
mini?: boolean;
}; };
const ParamControlNetWeight = (props: ParamControlNetWeightProps) => { const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
const { controlNetId, weight, mini = false } = props; const { controlNetId } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { weight, isEnabled } = controlNet.controlNets[controlNetId];
return { weight, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { weight, isEnabled } = useAppSelector(selector);
const handleWeightChanged = useCallback( const handleWeightChanged = useCallback(
(weight: number) => { (weight: number) => {
dispatch(controlNetWeightChanged({ controlNetId, weight })); dispatch(controlNetWeightChanged({ controlNetId, weight }));
@ -22,15 +36,15 @@ const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
return ( return (
<IAISlider <IAISlider
isDisabled={!isEnabled}
label={'Weight'} label={'Weight'}
sliderFormLabelProps={{ pb: 2 }}
value={weight} value={weight}
onChange={handleWeightChanged} onChange={handleWeightChanged}
min={-1} min={0}
max={1} max={2}
step={0.01} step={0.01}
withSliderMarks={!mini} withSliderMarks
sliderMarks={[-1, 0, 1]} sliderMarks={[0, 1, 2]}
/> />
); );
}; };

View File

@ -1,22 +1,25 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredCannyImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredCannyImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor
.default as RequiredCannyImageProcessorInvocation;
type CannyProcessorProps = { type CannyProcessorProps = {
controlNetId: string; controlNetId: string;
processorNode: RequiredCannyImageProcessorInvocation; processorNode: RequiredCannyImageProcessorInvocation;
isEnabled: boolean;
}; };
const CannyProcessor = (props: CannyProcessorProps) => { const CannyProcessor = (props: CannyProcessorProps) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { low_threshold, high_threshold } = processorNode; const { low_threshold, high_threshold } = processorNode;
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const handleLowThresholdChanged = useCallback( const handleLowThresholdChanged = useCallback(
@ -48,7 +51,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
return ( return (
<ProcessorWrapper> <ProcessorWrapper>
<IAISlider <IAISlider
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
label="Low Threshold" label="Low Threshold"
value={low_threshold} value={low_threshold}
onChange={handleLowThresholdChanged} onChange={handleLowThresholdChanged}
@ -60,7 +63,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
withSliderMarks withSliderMarks
/> />
<IAISlider <IAISlider
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
label="High Threshold" label="High Threshold"
value={high_threshold} value={high_threshold}
onChange={handleHighThresholdChanged} onChange={handleHighThresholdChanged}

View File

@ -4,20 +4,23 @@ import { RequiredContentShuffleImageProcessorInvocation } from 'features/control
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; import { useAppSelector } from 'app/store/storeHooks';
import { selectIsBusy } from 'features/system/store/systemSelectors';
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor
.default as RequiredContentShuffleImageProcessorInvocation;
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredContentShuffleImageProcessorInvocation; processorNode: RequiredContentShuffleImageProcessorInvocation;
isEnabled: boolean;
}; };
const ContentShuffleProcessor = (props: Props) => { const ContentShuffleProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, w, h, f } = processorNode; const { image_resolution, detect_resolution, w, h, f } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -96,7 +99,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -108,7 +111,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="W" label="W"
@ -120,7 +123,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="H" label="H"
@ -132,7 +135,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="F" label="F"
@ -144,7 +147,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,25 +1,29 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredHedImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredHedImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor
.default as RequiredHedImageProcessorInvocation;
type HedProcessorProps = { type HedProcessorProps = {
controlNetId: string; controlNetId: string;
processorNode: RequiredHedImageProcessorInvocation; processorNode: RequiredHedImageProcessorInvocation;
isEnabled: boolean;
}; };
const HedPreprocessor = (props: HedProcessorProps) => { const HedPreprocessor = (props: HedProcessorProps) => {
const { const {
controlNetId, controlNetId,
processorNode: { detect_resolution, image_resolution, scribble }, processorNode: { detect_resolution, image_resolution, scribble },
isEnabled,
} = props; } = props;
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
@ -67,7 +71,7 @@ const HedPreprocessor = (props: HedProcessorProps) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -79,13 +83,13 @@ const HedPreprocessor = (props: HedProcessorProps) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISwitch <IAISwitch
label="Scribble" label="Scribble"
isChecked={scribble} isChecked={scribble}
onChange={handleScribbleChanged} onChange={handleScribbleChanged}
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredLineartAnimeImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredLineartAnimeImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor
.default as RequiredLineartAnimeImageProcessorInvocation;
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredLineartAnimeImageProcessorInvocation; processorNode: RequiredLineartAnimeImageProcessorInvocation;
isEnabled: boolean;
}; };
const LineartAnimeProcessor = (props: Props) => { const LineartAnimeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution } = processorNode; const { image_resolution, detect_resolution } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -57,7 +60,7 @@ const LineartAnimeProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -69,7 +72,7 @@ const LineartAnimeProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,24 +1,27 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredLineartImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredLineartImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor
.default as RequiredLineartImageProcessorInvocation;
type LineartProcessorProps = { type LineartProcessorProps = {
controlNetId: string; controlNetId: string;
processorNode: RequiredLineartImageProcessorInvocation; processorNode: RequiredLineartImageProcessorInvocation;
isEnabled: boolean;
}; };
const LineartProcessor = (props: LineartProcessorProps) => { const LineartProcessor = (props: LineartProcessorProps) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, coarse } = processorNode; const { image_resolution, detect_resolution, coarse } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -65,7 +68,7 @@ const LineartProcessor = (props: LineartProcessorProps) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -77,13 +80,13 @@ const LineartProcessor = (props: LineartProcessorProps) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISwitch <IAISwitch
label="Coarse" label="Coarse"
isChecked={coarse} isChecked={coarse}
onChange={handleCoarseChanged} onChange={handleCoarseChanged}
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredMediapipeFaceProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredMediapipeFaceProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor
.default as RequiredMediapipeFaceProcessorInvocation;
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredMediapipeFaceProcessorInvocation; processorNode: RequiredMediapipeFaceProcessorInvocation;
isEnabled: boolean;
}; };
const MediapipeFaceProcessor = (props: Props) => { const MediapipeFaceProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { max_faces, min_confidence } = processorNode; const { max_faces, min_confidence } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleMaxFacesChanged = useCallback( const handleMaxFacesChanged = useCallback(
(v: number) => { (v: number) => {
@ -53,7 +56,7 @@ const MediapipeFaceProcessor = (props: Props) => {
max={20} max={20}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Min Confidence" label="Min Confidence"
@ -66,7 +69,7 @@ const MediapipeFaceProcessor = (props: Props) => {
step={0.01} step={0.01}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredMidasDepthImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredMidasDepthImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor
.default as RequiredMidasDepthImageProcessorInvocation;
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredMidasDepthImageProcessorInvocation; processorNode: RequiredMidasDepthImageProcessorInvocation;
isEnabled: boolean;
}; };
const MidasDepthProcessor = (props: Props) => { const MidasDepthProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { a_mult, bg_th } = processorNode; const { a_mult, bg_th } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleAMultChanged = useCallback( const handleAMultChanged = useCallback(
(v: number) => { (v: number) => {
@ -54,7 +57,7 @@ const MidasDepthProcessor = (props: Props) => {
step={0.01} step={0.01}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="bg_th" label="bg_th"
@ -67,7 +70,7 @@ const MidasDepthProcessor = (props: Props) => {
step={0.01} step={0.01}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredMlsdImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredMlsdImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor
.default as RequiredMlsdImageProcessorInvocation;
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredMlsdImageProcessorInvocation; processorNode: RequiredMlsdImageProcessorInvocation;
isEnabled: boolean;
}; };
const MlsdImageProcessor = (props: Props) => { const MlsdImageProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode; const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -79,7 +82,7 @@ const MlsdImageProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -91,7 +94,7 @@ const MlsdImageProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="W" label="W"
@ -104,7 +107,7 @@ const MlsdImageProcessor = (props: Props) => {
step={0.01} step={0.01}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="H" label="H"
@ -117,7 +120,7 @@ const MlsdImageProcessor = (props: Props) => {
step={0.01} step={0.01}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredNormalbaeImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredNormalbaeImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor
.default as RequiredNormalbaeImageProcessorInvocation;
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredNormalbaeImageProcessorInvocation; processorNode: RequiredNormalbaeImageProcessorInvocation;
isEnabled: boolean;
}; };
const NormalBaeProcessor = (props: Props) => { const NormalBaeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution } = processorNode; const { image_resolution, detect_resolution } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -57,7 +60,7 @@ const NormalBaeProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -69,7 +72,7 @@ const NormalBaeProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,24 +1,27 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredOpenposeImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredOpenposeImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor
.default as RequiredOpenposeImageProcessorInvocation;
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredOpenposeImageProcessorInvocation; processorNode: RequiredOpenposeImageProcessorInvocation;
isEnabled: boolean;
}; };
const OpenposeProcessor = (props: Props) => { const OpenposeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, hand_and_face } = processorNode; const { image_resolution, detect_resolution, hand_and_face } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -65,7 +68,7 @@ const OpenposeProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -77,13 +80,13 @@ const OpenposeProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISwitch <IAISwitch
label="Hand and Face" label="Hand and Face"
isChecked={hand_and_face} isChecked={hand_and_face}
onChange={handleHandAndFaceChanged} onChange={handleHandAndFaceChanged}
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,24 +1,27 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredPidiImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredPidiImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor
.default as RequiredPidiImageProcessorInvocation;
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredPidiImageProcessorInvocation; processorNode: RequiredPidiImageProcessorInvocation;
isEnabled: boolean;
}; };
const PidiProcessor = (props: Props) => { const PidiProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, scribble, safe } = processorNode; const { image_resolution, detect_resolution, scribble, safe } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -72,7 +75,7 @@ const PidiProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -84,7 +87,7 @@ const PidiProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISwitch <IAISwitch
label="Scribble" label="Scribble"
@ -95,7 +98,7 @@ const PidiProcessor = (props: Props) => {
label="Safe" label="Safe"
isChecked={safe} isChecked={safe}
onChange={handleSafeChanged} onChange={handleSafeChanged}
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -4,6 +4,7 @@ import { memo } from 'react';
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredZoeDepthImageProcessorInvocation; processorNode: RequiredZoeDepthImageProcessorInvocation;
isEnabled: boolean;
}; };
const ZoeDepthProcessor = (props: Props) => { const ZoeDepthProcessor = (props: Props) => {

View File

@ -173,91 +173,17 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
}, },
}; };
type ControlNetModelsDict = Record<string, ControlNetModel>; export const CONTROLNET_MODEL_DEFAULT_PROCESSORS: {
[key: string]: ControlNetProcessorType;
type ControlNetModel = { } = {
type: string; canny: 'canny_image_processor',
label: string; mlsd: 'mlsd_image_processor',
description?: string; depth: 'midas_depth_image_processor',
defaultProcessor?: ControlNetProcessorType; bae: 'normalbae_image_processor',
lineart: 'lineart_image_processor',
lineart_anime: 'lineart_anime_image_processor',
softedge: 'hed_image_processor',
shuffle: 'content_shuffle_image_processor',
openpose: 'openpose_image_processor',
mediapipe: 'mediapipe_face_processor',
}; };
export const CONTROLNET_MODELS: ControlNetModelsDict = {
'lllyasviel/control_v11p_sd15_canny': {
type: 'lllyasviel/control_v11p_sd15_canny',
label: 'Canny',
defaultProcessor: 'canny_image_processor',
},
'lllyasviel/control_v11p_sd15_inpaint': {
type: 'lllyasviel/control_v11p_sd15_inpaint',
label: 'Inpaint',
defaultProcessor: 'none',
},
'lllyasviel/control_v11p_sd15_mlsd': {
type: 'lllyasviel/control_v11p_sd15_mlsd',
label: 'M-LSD',
defaultProcessor: 'mlsd_image_processor',
},
'lllyasviel/control_v11f1p_sd15_depth': {
type: 'lllyasviel/control_v11f1p_sd15_depth',
label: 'Depth',
defaultProcessor: 'midas_depth_image_processor',
},
'lllyasviel/control_v11p_sd15_normalbae': {
type: 'lllyasviel/control_v11p_sd15_normalbae',
label: 'Normal Map (BAE)',
defaultProcessor: 'normalbae_image_processor',
},
'lllyasviel/control_v11p_sd15_seg': {
type: 'lllyasviel/control_v11p_sd15_seg',
label: 'Segmentation',
defaultProcessor: 'none',
},
'lllyasviel/control_v11p_sd15_lineart': {
type: 'lllyasviel/control_v11p_sd15_lineart',
label: 'Lineart',
defaultProcessor: 'lineart_image_processor',
},
'lllyasviel/control_v11p_sd15s2_lineart_anime': {
type: 'lllyasviel/control_v11p_sd15s2_lineart_anime',
label: 'Lineart Anime',
defaultProcessor: 'lineart_anime_image_processor',
},
'lllyasviel/control_v11p_sd15_scribble': {
type: 'lllyasviel/control_v11p_sd15_scribble',
label: 'Scribble',
defaultProcessor: 'none',
},
'lllyasviel/control_v11p_sd15_softedge': {
type: 'lllyasviel/control_v11p_sd15_softedge',
label: 'Soft Edge',
defaultProcessor: 'hed_image_processor',
},
'lllyasviel/control_v11e_sd15_shuffle': {
type: 'lllyasviel/control_v11e_sd15_shuffle',
label: 'Content Shuffle',
defaultProcessor: 'content_shuffle_image_processor',
},
'lllyasviel/control_v11p_sd15_openpose': {
type: 'lllyasviel/control_v11p_sd15_openpose',
label: 'Openpose',
defaultProcessor: 'openpose_image_processor',
},
'lllyasviel/control_v11f1e_sd15_tile': {
type: 'lllyasviel/control_v11f1e_sd15_tile',
label: 'Tile (experimental)',
defaultProcessor: 'none',
},
'lllyasviel/control_v11e_sd15_ip2p': {
type: 'lllyasviel/control_v11e_sd15_ip2p',
label: 'Pix2Pix (experimental)',
defaultProcessor: 'none',
},
'CrucibleAI/ControlNetMediaPipeFace': {
type: 'CrucibleAI/ControlNetMediaPipeFace',
label: 'Mediapipe Face',
defaultProcessor: 'mediapipe_face_processor',
},
};
export type ControlNetModelName = keyof typeof CONTROLNET_MODELS;

View File

@ -1,22 +1,20 @@
import { PayloadAction } from '@reduxjs/toolkit'; import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { ImageDTO } from 'services/api/types'; import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas';
import { cloneDeep, forEach } from 'lodash-es';
import { imageDeleted } from 'services/api/thunks/image';
import { isAnySessionRejected } from 'services/api/thunks/session';
import { appSocketInvocationError } from 'services/events/actions';
import { controlNetImageProcessed } from './actions';
import {
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
CONTROLNET_PROCESSORS,
} from './constants';
import { import {
ControlNetProcessorType, ControlNetProcessorType,
RequiredCannyImageProcessorInvocation, RequiredCannyImageProcessorInvocation,
RequiredControlNetProcessorNode, RequiredControlNetProcessorNode,
} from './types'; } from './types';
import {
CONTROLNET_MODELS,
CONTROLNET_PROCESSORS,
ControlNetModelName,
} from './constants';
import { controlNetImageProcessed } from './actions';
import { imageDeleted, imageUrlsReceived } from 'services/api/thunks/image';
import { forEach } from 'lodash-es';
import { isAnySessionRejected } from 'services/api/thunks/session';
import { appSocketInvocationError } from 'services/events/actions';
export type ControlModes = export type ControlModes =
| 'balanced' | 'balanced'
@ -26,7 +24,7 @@ export type ControlModes =
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = { export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true, isEnabled: true,
model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type, model: null,
weight: 1, weight: 1,
beginStepPct: 0, beginStepPct: 0,
endStepPct: 1, endStepPct: 1,
@ -42,7 +40,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
export type ControlNetConfig = { export type ControlNetConfig = {
controlNetId: string; controlNetId: string;
isEnabled: boolean; isEnabled: boolean;
model: ControlNetModelName; model: ControlNetModelParam | null;
weight: number; weight: number;
beginStepPct: number; beginStepPct: number;
endStepPct: number; endStepPct: number;
@ -86,6 +84,19 @@ export const controlNetSlice = createSlice({
controlNetId, controlNetId,
}; };
}, },
controlNetDuplicated: (
state,
action: PayloadAction<{
sourceControlNetId: string;
newControlNetId: string;
}>
) => {
const { sourceControlNetId, newControlNetId } = action.payload;
const newControlnet = cloneDeep(state.controlNets[sourceControlNetId]);
newControlnet.controlNetId = newControlNetId;
state.controlNets[newControlNetId] = newControlnet;
},
controlNetAddedFromImage: ( controlNetAddedFromImage: (
state, state,
action: PayloadAction<{ controlNetId: string; controlImage: string }> action: PayloadAction<{ controlNetId: string; controlImage: string }>
@ -147,7 +158,7 @@ export const controlNetSlice = createSlice({
state, state,
action: PayloadAction<{ action: PayloadAction<{
controlNetId: string; controlNetId: string;
model: ControlNetModelName; model: ControlNetModelParam;
}> }>
) => { ) => {
const { controlNetId, model } = action.payload; const { controlNetId, model } = action.payload;
@ -155,7 +166,15 @@ export const controlNetSlice = createSlice({
state.controlNets[controlNetId].processedControlImage = null; state.controlNets[controlNetId].processedControlImage = null;
if (state.controlNets[controlNetId].shouldAutoConfig) { if (state.controlNets[controlNetId].shouldAutoConfig) {
const processorType = CONTROLNET_MODELS[model].defaultProcessor; let processorType: ControlNetProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (model.model_name.includes(modelSubstring)) {
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) { if (processorType) {
state.controlNets[controlNetId].processorType = processorType; state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[ state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
@ -241,9 +260,19 @@ export const controlNetSlice = createSlice({
if (newShouldAutoConfig) { if (newShouldAutoConfig) {
// manage the processor for the user // manage the processor for the user
const processorType = let processorType: ControlNetProcessorType | undefined = undefined;
CONTROLNET_MODELS[state.controlNets[controlNetId].model]
.defaultProcessor; for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (
state.controlNets[controlNetId].model?.model_name.includes(
modelSubstring
)
) {
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) { if (processorType) {
state.controlNets[controlNetId].processorType = processorType; state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[ state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
@ -272,7 +301,8 @@ export const controlNetSlice = createSlice({
}); });
builder.addCase(imageDeleted.pending, (state, action) => { builder.addCase(imageDeleted.pending, (state, action) => {
// Preemptively remove the image from the gallery // Preemptively remove the image from all controlnets
// TODO: doesn't the imageusage stuff do this for us?
const { image_name } = action.meta.arg; const { image_name } = action.meta.arg;
forEach(state.controlNets, (c) => { forEach(state.controlNets, (c) => {
if (c.controlImage === image_name) { if (c.controlImage === image_name) {
@ -285,21 +315,6 @@ export const controlNetSlice = createSlice({
}); });
}); });
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
// const { image_name, image_url, thumbnail_url } = action.payload;
// forEach(state.controlNets, (c) => {
// if (c.controlImage?.image_name === image_name) {
// c.controlImage.image_url = image_url;
// c.controlImage.thumbnail_url = thumbnail_url;
// }
// if (c.processedControlImage?.image_name === image_name) {
// c.processedControlImage.image_url = image_url;
// c.processedControlImage.thumbnail_url = thumbnail_url;
// }
// });
// });
builder.addCase(appSocketInvocationError, (state, action) => { builder.addCase(appSocketInvocationError, (state, action) => {
state.pendingControlImages = []; state.pendingControlImages = [];
}); });
@ -313,6 +328,7 @@ export const controlNetSlice = createSlice({
export const { export const {
isControlNetEnabledToggled, isControlNetEnabledToggled,
controlNetAdded, controlNetAdded,
controlNetDuplicated,
controlNetAddedFromImage, controlNetAddedFromImage,
controlNetRemoved, controlNetRemoved,
controlNetImageChanged, controlNetImageChanged,

View File

@ -7,6 +7,7 @@ import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
import ColorInputFieldComponent from './fields/ColorInputFieldComponent'; import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent'; import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
import ControlInputFieldComponent from './fields/ControlInputFieldComponent'; import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
import ControlNetModelInputFieldComponent from './fields/ControlNetModelInputFieldComponent';
import EnumInputFieldComponent from './fields/EnumInputFieldComponent'; import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent'; import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
import ImageInputFieldComponent from './fields/ImageInputFieldComponent'; import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
@ -174,6 +175,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
); );
} }
if (type === 'controlnet_model' && template.type === 'controlnet_model') {
return (
<ControlNetModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'array' && template.type === 'array') { if (type === 'array' && template.type === 'array') {
return ( return (
<ArrayInputFieldComponent <ArrayInputFieldComponent

View File

@ -0,0 +1,103 @@
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
ControlNetModelInputFieldTemplate,
ControlNetModelInputFieldValue,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
const ControlNetModelInputFieldComponent = (
props: FieldComponentProps<
ControlNetModelInputFieldValue,
ControlNetModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const controlNetModel = field.value;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: controlNetModels } = useGetControlNetModelsQuery();
// grab the full model entity from the RTK Query cache
const selectedModel = useMemo(
() =>
controlNetModels?.entities[
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}`
] ?? null,
[
controlNetModel?.base_model,
controlNetModel?.model_name,
controlNetModels?.entities,
]
);
const data = useMemo(() => {
if (!controlNetModels) {
return [];
}
const data: SelectItem[] = [];
forEach(controlNetModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [controlNetModels]);
const handleValueChanged = useCallback(
(v: string | null) => {
if (!v) {
return;
}
const newControlNetModel = modelIdToControlNetModelParam(v);
if (!newControlNetModel) {
return;
}
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: newControlNetModel,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<IAIMantineSelect
tooltip={selectedModel?.description}
label={
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
}
value={selectedModel?.id ?? null}
placeholder="Pick one"
error={!selectedModel}
data={data}
onChange={handleValueChanged}
/>
);
};
export default memo(ControlNetModelInputFieldComponent);

View File

@ -1,6 +1,7 @@
import { createSlice, PayloadAction } from '@reduxjs/toolkit'; import { createSlice, PayloadAction } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { import {
ControlNetModelParam,
LoRAModelParam, LoRAModelParam,
MainModelParam, MainModelParam,
VaeModelParam, VaeModelParam,
@ -81,7 +82,8 @@ const nodesSlice = createSlice({
| ImageField[] | ImageField[]
| MainModelParam | MainModelParam
| VaeModelParam | VaeModelParam
| LoRAModelParam; | LoRAModelParam
| ControlNetModelParam;
}> }>
) => { ) => {
const { nodeId, fieldName, value } = action.payload; const { nodeId, fieldName, value } = action.payload;

View File

@ -19,6 +19,8 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
model: 'model', model: 'model',
vae_model: 'vae_model', vae_model: 'vae_model',
lora_model: 'lora_model', lora_model: 'lora_model',
controlnet_model: 'controlnet_model',
ControlNetModelField: 'controlnet_model',
array: 'array', array: 'array',
item: 'item', item: 'item',
ColorField: 'color', ColorField: 'color',
@ -130,6 +132,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
title: 'LoRA', title: 'LoRA',
description: 'Models are models.', description: 'Models are models.',
}, },
controlnet_model: {
color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'),
title: 'ControlNet',
description: 'Models are models.',
},
array: { array: {
color: 'gray', color: 'gray',
colorCssVar: getColorTokenCssVariable('gray'), colorCssVar: getColorTokenCssVariable('gray'),

View File

@ -1,4 +1,5 @@
import { import {
ControlNetModelParam,
LoRAModelParam, LoRAModelParam,
MainModelParam, MainModelParam,
VaeModelParam, VaeModelParam,
@ -71,6 +72,7 @@ export type FieldType =
| 'model' | 'model'
| 'vae_model' | 'vae_model'
| 'lora_model' | 'lora_model'
| 'controlnet_model'
| 'array' | 'array'
| 'item' | 'item'
| 'color' | 'color'
@ -100,6 +102,7 @@ export type InputFieldValue =
| MainModelInputFieldValue | MainModelInputFieldValue
| VaeModelInputFieldValue | VaeModelInputFieldValue
| LoRAModelInputFieldValue | LoRAModelInputFieldValue
| ControlNetModelInputFieldValue
| ArrayInputFieldValue | ArrayInputFieldValue
| ItemInputFieldValue | ItemInputFieldValue
| ColorInputFieldValue | ColorInputFieldValue
@ -127,6 +130,7 @@ export type InputFieldTemplate =
| ModelInputFieldTemplate | ModelInputFieldTemplate
| VaeModelInputFieldTemplate | VaeModelInputFieldTemplate
| LoRAModelInputFieldTemplate | LoRAModelInputFieldTemplate
| ControlNetModelInputFieldTemplate
| ArrayInputFieldTemplate | ArrayInputFieldTemplate
| ItemInputFieldTemplate | ItemInputFieldTemplate
| ColorInputFieldTemplate | ColorInputFieldTemplate
@ -249,6 +253,11 @@ export type LoRAModelInputFieldValue = FieldValueBase & {
value?: LoRAModelParam; value?: LoRAModelParam;
}; };
export type ControlNetModelInputFieldValue = FieldValueBase & {
type: 'controlnet_model';
value?: ControlNetModelParam;
};
export type ArrayInputFieldValue = FieldValueBase & { export type ArrayInputFieldValue = FieldValueBase & {
type: 'array'; type: 'array';
value?: (string | number)[]; value?: (string | number)[];
@ -368,6 +377,11 @@ export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & {
type: 'lora_model'; type: 'lora_model';
}; };
export type ControlNetModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'controlnet_model';
};
export type ArrayInputFieldTemplate = InputFieldTemplateBase & { export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
default: []; default: [];
type: 'array'; type: 'array';

View File

@ -9,6 +9,7 @@ import {
ColorInputFieldTemplate, ColorInputFieldTemplate,
ConditioningInputFieldTemplate, ConditioningInputFieldTemplate,
ControlInputFieldTemplate, ControlInputFieldTemplate,
ControlNetModelInputFieldTemplate,
EnumInputFieldTemplate, EnumInputFieldTemplate,
FieldType, FieldType,
FloatInputFieldTemplate, FloatInputFieldTemplate,
@ -207,6 +208,21 @@ const buildLoRAModelInputFieldTemplate = ({
return template; return template;
}; };
const buildControlNetModelInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ControlNetModelInputFieldTemplate => {
const template: ControlNetModelInputFieldTemplate = {
...baseField,
type: 'controlnet_model',
inputRequirement: 'always',
inputKind: 'direct',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildImageInputFieldTemplate = ({ const buildImageInputFieldTemplate = ({
schemaObject, schemaObject,
baseField, baseField,
@ -479,6 +495,9 @@ export const buildInputFieldTemplate = (
if (['lora_model'].includes(fieldType)) { if (['lora_model'].includes(fieldType)) {
return buildLoRAModelInputFieldTemplate({ schemaObject, baseField }); return buildLoRAModelInputFieldTemplate({ schemaObject, baseField });
} }
if (['controlnet_model'].includes(fieldType)) {
return buildControlNetModelInputFieldTemplate({ schemaObject, baseField });
}
if (['enum'].includes(fieldType)) { if (['enum'].includes(fieldType)) {
return buildEnumInputFieldTemplate({ schemaObject, baseField }); return buildEnumInputFieldTemplate({ schemaObject, baseField });
} }

View File

@ -83,6 +83,10 @@ export const buildInputFieldValue = (
if (template.type === 'lora_model') { if (template.type === 'lora_model') {
fieldValue.value = undefined; fieldValue.value = undefined;
} }
if (template.type === 'controlnet_model') {
fieldValue.value = undefined;
}
} }
return fieldValue; return fieldValue;

View File

@ -8,6 +8,7 @@ import ControlNet from 'features/controlNet/components/ControlNet';
import ParamControlNetFeatureToggle from 'features/controlNet/components/parameters/ParamControlNetFeatureToggle'; import ParamControlNetFeatureToggle from 'features/controlNet/components/parameters/ParamControlNetFeatureToggle';
import { import {
controlNetAdded, controlNetAdded,
controlNetModelChanged,
controlNetSelector, controlNetSelector,
} from 'features/controlNet/store/controlNetSlice'; } from 'features/controlNet/store/controlNetSlice';
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets'; import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
@ -15,6 +16,7 @@ import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { Fragment, memo, useCallback } from 'react'; import { Fragment, memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
const selector = createSelector( const selector = createSelector(
@ -39,10 +41,23 @@ const ParamControlNetCollapse = () => {
const { controlNetsArray, activeLabel } = useAppSelector(selector); const { controlNetsArray, activeLabel } = useAppSelector(selector);
const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled; const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { firstModel } = useGetControlNetModelsQuery(undefined, {
selectFromResult: (result) => {
const firstModel = result.data?.entities[result.data?.ids[0]];
return {
firstModel,
};
},
});
const handleClickedAddControlNet = useCallback(() => { const handleClickedAddControlNet = useCallback(() => {
dispatch(controlNetAdded({ controlNetId: uuidv4() })); if (!firstModel) {
}, [dispatch]); return;
}
const controlNetId = uuidv4();
dispatch(controlNetAdded({ controlNetId }));
dispatch(controlNetModelChanged({ controlNetId, model: firstModel }));
}, [dispatch, firstModel]);
if (isControlNetDisabled) { if (isControlNetDisabled) {
return null; return null;
@ -52,15 +67,20 @@ const ParamControlNetCollapse = () => {
<IAICollapse label="ControlNet" activeLabel={activeLabel}> <IAICollapse label="ControlNet" activeLabel={activeLabel}>
<Flex sx={{ flexDir: 'column', gap: 3 }}> <Flex sx={{ flexDir: 'column', gap: 3 }}>
<ParamControlNetFeatureToggle /> <ParamControlNetFeatureToggle />
<IAIButton
isDisabled={!firstModel}
flexGrow={1}
size="sm"
onClick={handleClickedAddControlNet}
>
Add ControlNet
</IAIButton>
{controlNetsArray.map((c, i) => ( {controlNetsArray.map((c, i) => (
<Fragment key={c.controlNetId}> <Fragment key={c.controlNetId}>
{i > 0 && <Divider />} {i > 0 && <Divider />}
<ControlNet controlNet={c} /> <ControlNet controlNetId={c.controlNetId} />
</Fragment> </Fragment>
))} ))}
<IAIButton flexGrow={1} onClick={handleClickedAddControlNet}>
Add ControlNet
</IAIButton>
</Flex> </Flex>
</IAICollapse> </IAICollapse>
); );

View File

@ -37,6 +37,7 @@ const ParamVAEModelSelect = () => {
return []; return [];
} }
// add a "default" option, this means use the main model's included VAE
const data: SelectItem[] = [ const data: SelectItem[] = [
{ {
value: 'default', value: 'default',

View File

@ -17,8 +17,10 @@ import { FaPlay } from 'react-icons/fa';
const IN_PROGRESS_STYLES: ChakraProps['sx'] = { const IN_PROGRESS_STYLES: ChakraProps['sx'] = {
_disabled: { _disabled: {
bg: 'none', bg: 'none',
color: 'base.600',
cursor: 'not-allowed', cursor: 'not-allowed',
_hover: { _hover: {
color: 'base.600',
bg: 'none', bg: 'none',
}, },
}, },

View File

@ -180,6 +180,23 @@ export type LoRAModelParam = z.infer<typeof zLoRAModel>;
*/ */
export const isValidLoRAModel = (val: unknown): val is LoRAModelParam => export const isValidLoRAModel = (val: unknown): val is LoRAModelParam =>
zLoRAModel.safeParse(val).success; zLoRAModel.safeParse(val).success;
/**
* Zod schema for ControlNet models
*/
export const zControlNetModel = z.object({
model_name: z.string().min(1),
base_model: zBaseModel,
});
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type ControlNetModelParam = z.infer<typeof zLoRAModel>;
/**
* Validates/type-guards a value as a model parameter
*/
export const isValidControlNetModel = (
val: unknown
): val is ControlNetModelParam => zControlNetModel.safeParse(val).success;
/** /**
* Zod schema for l2l strength parameter * Zod schema for l2l strength parameter

View File

@ -0,0 +1,30 @@
import { log } from 'app/logging/useLogger';
import { zControlNetModel } from 'features/parameters/types/parameterSchemas';
import { ControlNetModelField } from 'services/api/types';
const moduleLog = log.child({ module: 'models' });
export const modelIdToControlNetModelParam = (
controlNetModelId: string
): ControlNetModelField | undefined => {
const [base_model, model_type, model_name] = controlNetModelId.split('/');
const result = zControlNetModel.safeParse({
base_model,
model_name,
});
if (!result.success) {
moduleLog.error(
{
controlNetModelId,
errors: result.error.format(),
},
'Failed to parse ControlNet model id'
);
return;
}
return result.data;
};

View File

@ -1,9 +1,12 @@
import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas'; import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ module: 'models' });
export const modelIdToLoRAModelParam = ( export const modelIdToLoRAModelParam = (
loraId: string loraModelId: string
): LoRAModelParam | undefined => { ): LoRAModelParam | undefined => {
const [base_model, model_type, model_name] = loraId.split('/'); const [base_model, model_type, model_name] = loraModelId.split('/');
const result = zLoRAModel.safeParse({ const result = zLoRAModel.safeParse({
base_model, base_model,
@ -11,6 +14,13 @@ export const modelIdToLoRAModelParam = (
}); });
if (!result.success) { if (!result.success) {
moduleLog.error(
{
loraModelId,
errors: result.error.format(),
},
'Failed to parse LoRA model id'
);
return; return;
} }

View File

@ -2,11 +2,14 @@ import {
MainModelParam, MainModelParam,
zMainModel, zMainModel,
} from 'features/parameters/types/parameterSchemas'; } from 'features/parameters/types/parameterSchemas';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ module: 'models' });
export const modelIdToMainModelParam = ( export const modelIdToMainModelParam = (
modelId: string mainModelId: string
): MainModelParam | undefined => { ): MainModelParam | undefined => {
const [base_model, model_type, model_name] = modelId.split('/'); const [base_model, model_type, model_name] = mainModelId.split('/');
const result = zMainModel.safeParse({ const result = zMainModel.safeParse({
base_model, base_model,
@ -14,6 +17,13 @@ export const modelIdToMainModelParam = (
}); });
if (!result.success) { if (!result.success) {
moduleLog.error(
{
mainModelId,
errors: result.error.format(),
},
'Failed to parse main model id'
);
return; return;
} }

View File

@ -1,9 +1,12 @@
import { VaeModelParam, zVaeModel } from '../types/parameterSchemas'; import { VaeModelParam, zVaeModel } from '../types/parameterSchemas';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ module: 'models' });
export const modelIdToVAEModelParam = ( export const modelIdToVAEModelParam = (
modelId: string vaeModelId: string
): VaeModelParam | undefined => { ): VaeModelParam | undefined => {
const [base_model, model_type, model_name] = modelId.split('/'); const [base_model, model_type, model_name] = vaeModelId.split('/');
const result = zVaeModel.safeParse({ const result = zVaeModel.safeParse({
base_model, base_model,
@ -11,6 +14,13 @@ export const modelIdToVAEModelParam = (
}); });
if (!result.success) { if (!result.success) {
moduleLog.error(
{
vaeModelId,
errors: result.error.format(),
},
'Failed to parse VAE model id'
);
return; return;
} }

View File

@ -19,9 +19,9 @@ const ImageToImageTabParameters = () => {
<ParamNegativeConditioning /> <ParamNegativeConditioning />
<ProcessButtons /> <ProcessButtons />
<ImageToImageTabCoreParameters /> <ImageToImageTabCoreParameters />
<ParamControlNetCollapse />
<ParamLoraCollapse /> <ParamLoraCollapse />
<ParamDynamicPromptsCollapse /> <ParamDynamicPromptsCollapse />
<ParamControlNetCollapse />
<ParamVariationCollapse /> <ParamVariationCollapse />
<ParamNoiseCollapse /> <ParamNoiseCollapse />
<ParamSymmetryCollapse /> <ParamSymmetryCollapse />

View File

@ -20,9 +20,9 @@ const TextToImageTabParameters = () => {
<ParamNegativeConditioning /> <ParamNegativeConditioning />
<ProcessButtons /> <ProcessButtons />
<TextToImageTabCoreParameters /> <TextToImageTabCoreParameters />
<ParamControlNetCollapse />
<ParamLoraCollapse /> <ParamLoraCollapse />
<ParamDynamicPromptsCollapse /> <ParamDynamicPromptsCollapse />
<ParamControlNetCollapse />
<ParamVariationCollapse /> <ParamVariationCollapse />
<ParamNoiseCollapse /> <ParamNoiseCollapse />
<ParamSymmetryCollapse /> <ParamSymmetryCollapse />

View File

@ -19,9 +19,9 @@ const UnifiedCanvasParameters = () => {
<ParamNegativeConditioning /> <ParamNegativeConditioning />
<ProcessButtons /> <ProcessButtons />
<UnifiedCanvasCoreParameters /> <UnifiedCanvasCoreParameters />
<ParamControlNetCollapse />
<ParamLoraCollapse /> <ParamLoraCollapse />
<ParamDynamicPromptsCollapse /> <ParamDynamicPromptsCollapse />
<ParamControlNetCollapse />
<ParamVariationCollapse /> <ParamVariationCollapse />
<ParamSymmetryCollapse /> <ParamSymmetryCollapse />
<ParamSeamCorrectionCollapse /> <ParamSeamCorrectionCollapse />

View File

@ -734,7 +734,7 @@ export type components = {
* Control Model * Control Model
* @description The ControlNet model to use * @description The ControlNet model to use
*/ */
control_model: string; control_model: components["schemas"]["ControlNetModelField"];
/** /**
* Control Weight * Control Weight
* @description The weight given to the ControlNet * @description The weight given to the ControlNet
@ -792,9 +792,8 @@ export type components = {
* Control Model * Control Model
* @description control model used * @description control model used
* @default lllyasviel/sd-controlnet-canny * @default lllyasviel/sd-controlnet-canny
* @enum {string}
*/ */
control_model?: "lllyasviel/sd-controlnet-canny" | "lllyasviel/sd-controlnet-depth" | "lllyasviel/sd-controlnet-hed" | "lllyasviel/sd-controlnet-seg" | "lllyasviel/sd-controlnet-openpose" | "lllyasviel/sd-controlnet-scribble" | "lllyasviel/sd-controlnet-normal" | "lllyasviel/sd-controlnet-mlsd" | "lllyasviel/control_v11p_sd15_canny" | "lllyasviel/control_v11p_sd15_openpose" | "lllyasviel/control_v11p_sd15_seg" | "lllyasviel/control_v11f1p_sd15_depth" | "lllyasviel/control_v11p_sd15_normalbae" | "lllyasviel/control_v11p_sd15_scribble" | "lllyasviel/control_v11p_sd15_mlsd" | "lllyasviel/control_v11p_sd15_softedge" | "lllyasviel/control_v11p_sd15s2_lineart_anime" | "lllyasviel/control_v11p_sd15_lineart" | "lllyasviel/control_v11p_sd15_inpaint" | "lllyasviel/control_v11e_sd15_shuffle" | "lllyasviel/control_v11e_sd15_ip2p" | "lllyasviel/control_v11f1e_sd15_tile" | "thibaud/controlnet-sd21-openpose-diffusers" | "thibaud/controlnet-sd21-canny-diffusers" | "thibaud/controlnet-sd21-depth-diffusers" | "thibaud/controlnet-sd21-scribble-diffusers" | "thibaud/controlnet-sd21-hed-diffusers" | "thibaud/controlnet-sd21-zoedepth-diffusers" | "thibaud/controlnet-sd21-color-diffusers" | "thibaud/controlnet-sd21-openposev2-diffusers" | "thibaud/controlnet-sd21-lineart-diffusers" | "thibaud/controlnet-sd21-normalbae-diffusers" | "thibaud/controlnet-sd21-ade20k-diffusers" | "CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15" | "CrucibleAI/ControlNetMediaPipeFace"; control_model?: components["schemas"]["ControlNetModelField"];
/** /**
* Control Weight * Control Weight
* @description The weight given to the ControlNet * @description The weight given to the ControlNet
@ -838,6 +837,19 @@ export type components = {
model_format: components["schemas"]["ControlNetModelFormat"]; model_format: components["schemas"]["ControlNetModelFormat"];
error?: components["schemas"]["ModelError"]; error?: components["schemas"]["ModelError"];
}; };
/**
* ControlNetModelField
* @description ControlNet model field
*/
ControlNetModelField: {
/**
* Model Name
* @description Name of the ControlNet model
*/
model_name: string;
/** @description Base model */
base_model: components["schemas"]["BaseModelType"];
};
/** /**
* ControlNetModelFormat * ControlNetModelFormat
* @description An enumeration. * @description An enumeration.
@ -1923,12 +1935,12 @@ export type components = {
* Width * Width
* @description The width to resize to (px) * @description The width to resize to (px)
*/ */
width: number; width?: number;
/** /**
* Height * Height
* @description The height to resize to (px) * @description The height to resize to (px)
*/ */
height: number; height?: number;
/** /**
* Resample Mode * Resample Mode
* @description The resampling mode * @description The resampling mode
@ -3911,13 +3923,15 @@ export type components = {
/** /**
* Width * Width
* @description The width to resize to (px) * @description The width to resize to (px)
* @default 512
*/ */
width: number; width?: number;
/** /**
* Height * Height
* @description The height to resize to (px) * @description The height to resize to (px)
* @default 512
*/ */
height: number; height?: number;
/** /**
* Mode * Mode
* @description The interpolation mode * @description The interpolation mode
@ -4605,18 +4619,18 @@ export type components = {
*/ */
image?: components["schemas"]["ImageField"]; image?: components["schemas"]["ImageField"];
}; };
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/** /**
* StableDiffusion1ModelFormat * StableDiffusion1ModelFormat
* @description An enumeration. * @description An enumeration.
* @enum {string} * @enum {string}
*/ */
StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
}; };
responses: never; responses: never;
parameters: never; parameters: never;

View File

@ -32,6 +32,8 @@ export type BaseModelType = components['schemas']['BaseModelType'];
export type MainModelField = components['schemas']['MainModelField']; export type MainModelField = components['schemas']['MainModelField'];
export type VAEModelField = components['schemas']['VAEModelField']; export type VAEModelField = components['schemas']['VAEModelField'];
export type LoRAModelField = components['schemas']['LoRAModelField']; export type LoRAModelField = components['schemas']['LoRAModelField'];
export type ControlNetModelField =
components['schemas']['ControlNetModelField'];
export type ModelsList = components['schemas']['ModelsList']; export type ModelsList = components['schemas']['ModelsList'];
export type ControlField = components['schemas']['ControlField']; export type ControlField = components['schemas']['ControlField'];

View File

@ -30,7 +30,7 @@ const invokeAIThumb = defineStyle((props) => {
const invokeAIMark = defineStyle((props) => { const invokeAIMark = defineStyle((props) => {
return { return {
fontSize: 'xs', fontSize: '2xs',
fontWeight: '500', fontWeight: '500',
color: mode('base.700', 'base.400')(props), color: mode('base.700', 'base.400')(props),
mt: 2, mt: 2,