remove frontend constants, use backend response for controlnet models. add disabled state if base model is not compatible. clear control net model if main base model changes. add logic to guess processor and move it up in UI

This commit is contained in:
Mary Hipp 2023-07-11 16:18:38 -04:00 committed by psychedelicious
parent 5ac114576f
commit 76dc47e88d
5 changed files with 97 additions and 137 deletions

View File

@ -1,5 +1,5 @@
import {
CONTROLNET_MODELS,
// CONTROLNET_MODELS,
CONTROLNET_PROCESSORS,
} from 'features/controlNet/store/constants';
import { InvokeTabName } from 'features/ui/store/tabMap';
@ -128,7 +128,7 @@ export type AppConfig = {
canRestoreDeletedImagesFromBin: boolean;
sd: {
defaultModel?: string;
disabledControlNetModels: (keyof typeof CONTROLNET_MODELS)[];
disabledControlNetModels: string[];
disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[];
iterations: {
initial: number;

View File

@ -124,6 +124,7 @@ const ControlNet = (props: ControlNetProps) => {
/>
}
/>
{!shouldAutoConfig && (
<Box
sx={{
@ -138,6 +139,16 @@ const ControlNet = (props: ControlNetProps) => {
/>
)}
</Flex>
<Flex alignItems="flex-end" gap="2">
<ParamControlNetProcessorSelect
controlNetId={controlNetId}
processorNode={processorNode}
/>
<ParamControlNetShouldAutoConfig
controlNetId={controlNetId}
shouldAutoConfig={shouldAutoConfig}
/>
</Flex>
{isEnabled && (
<>
<Flex sx={{ w: 'full', flexDirection: 'column' }}>
@ -196,18 +207,10 @@ const ControlNet = (props: ControlNetProps) => {
height={96}
/>
</Box>
<ParamControlNetProcessorSelect
controlNetId={controlNetId}
processorNode={processorNode}
/>
<ControlNetProcessorComponent
controlNetId={controlNetId}
processorNode={processorNode}
/>
<ParamControlNetShouldAutoConfig
controlNetId={controlNetId}
shouldAutoConfig={shouldAutoConfig}
/>
</>
)}
</>

View File

@ -1,55 +1,71 @@
import { createSelector } from '@reduxjs/toolkit';
import { SelectItem } from '@mantine/core';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSearchableSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import {
CONTROLNET_MODELS,
ControlNetModelName,
} from 'features/controlNet/store/constants';
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
import { configSelector } from 'features/system/store/configSelectors';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
type ParamControlNetModelProps = {
controlNetId: string;
model: ControlNetModelName;
model: string;
};
const selector = createSelector(configSelector, (config) => {
const controlNetModels: IAISelectDataType[] = map(CONTROLNET_MODELS, (m) => ({
label: m.label,
value: m.type,
})).filter(
(d) =>
!config.sd.disabledControlNetModels.includes(
d.value as ControlNetModelName
)
);
return controlNetModels;
});
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const { controlNetId, model } = props;
const controlNetModels = useAppSelector(selector);
const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke();
const currentMainModel = useAppSelector(
(state: RootState) => state.generation.model
);
const { data: controlNetModels } = useGetControlNetModelsQuery();
const handleModelChanged = useCallback(
(val: string | null) => {
// TODO: do not cast
const model = val as ControlNetModelName;
dispatch(controlNetModelChanged({ controlNetId, model }));
if (!val) return;
dispatch(controlNetModelChanged({ controlNetId, model: val }));
},
[controlNetId, dispatch]
);
const data = useMemo(() => {
if (!controlNetModels) {
return [];
}
const data: SelectItem[] = [];
forEach(controlNetModels.entities, (model, id) => {
if (!model) {
return;
}
const disabled = currentMainModel?.base_model !== model.base_model;
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
disabled,
tooltip: disabled
? `Incompatible base model: ${model.base_model}`
: undefined,
});
});
return data;
}, [controlNetModels, currentMainModel?.base_model]);
return (
<IAIMantineSearchableSelect
data={controlNetModels}
itemComponent={IAIMantineSelectItemWithTooltip}
data={data}
value={model}
onChange={handleModelChanged}
disabled={!isReady}

View File

@ -173,91 +173,17 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
},
};
type ControlNetModelsDict = Record<string, ControlNetModel>;
type ControlNetModel = {
type: string;
label: string;
description?: string;
defaultProcessor?: ControlNetProcessorType;
export const CONTROLNET_MODEL_DEFAULT_PROCESSORS: {
[key: string]: ControlNetProcessorType;
} = {
canny: 'canny_image_processor',
mlsd: 'mlsd_image_processor',
depth: 'midas_depth_image_processor',
bae: 'normalbae_image_processor',
lineart: 'lineart_image_processor',
lineart_anime: 'lineart_anime_image_processor',
softedge: 'hed_image_processor',
shuffle: 'content_shuffle_image_processor',
openpose: 'openpose_image_processor',
mediapipe: 'mediapipe_face_processor',
};
export const CONTROLNET_MODELS: ControlNetModelsDict = {
'lllyasviel/control_v11p_sd15_canny': {
type: 'lllyasviel/control_v11p_sd15_canny',
label: 'Canny',
defaultProcessor: 'canny_image_processor',
},
'lllyasviel/control_v11p_sd15_inpaint': {
type: 'lllyasviel/control_v11p_sd15_inpaint',
label: 'Inpaint',
defaultProcessor: 'none',
},
'lllyasviel/control_v11p_sd15_mlsd': {
type: 'lllyasviel/control_v11p_sd15_mlsd',
label: 'M-LSD',
defaultProcessor: 'mlsd_image_processor',
},
'lllyasviel/control_v11f1p_sd15_depth': {
type: 'lllyasviel/control_v11f1p_sd15_depth',
label: 'Depth',
defaultProcessor: 'midas_depth_image_processor',
},
'lllyasviel/control_v11p_sd15_normalbae': {
type: 'lllyasviel/control_v11p_sd15_normalbae',
label: 'Normal Map (BAE)',
defaultProcessor: 'normalbae_image_processor',
},
'lllyasviel/control_v11p_sd15_seg': {
type: 'lllyasviel/control_v11p_sd15_seg',
label: 'Segmentation',
defaultProcessor: 'none',
},
'lllyasviel/control_v11p_sd15_lineart': {
type: 'lllyasviel/control_v11p_sd15_lineart',
label: 'Lineart',
defaultProcessor: 'lineart_image_processor',
},
'lllyasviel/control_v11p_sd15s2_lineart_anime': {
type: 'lllyasviel/control_v11p_sd15s2_lineart_anime',
label: 'Lineart Anime',
defaultProcessor: 'lineart_anime_image_processor',
},
'lllyasviel/control_v11p_sd15_scribble': {
type: 'lllyasviel/control_v11p_sd15_scribble',
label: 'Scribble',
defaultProcessor: 'none',
},
'lllyasviel/control_v11p_sd15_softedge': {
type: 'lllyasviel/control_v11p_sd15_softedge',
label: 'Soft Edge',
defaultProcessor: 'hed_image_processor',
},
'lllyasviel/control_v11e_sd15_shuffle': {
type: 'lllyasviel/control_v11e_sd15_shuffle',
label: 'Content Shuffle',
defaultProcessor: 'content_shuffle_image_processor',
},
'lllyasviel/control_v11p_sd15_openpose': {
type: 'lllyasviel/control_v11p_sd15_openpose',
label: 'Openpose',
defaultProcessor: 'openpose_image_processor',
},
'lllyasviel/control_v11f1e_sd15_tile': {
type: 'lllyasviel/control_v11f1e_sd15_tile',
label: 'Tile (experimental)',
defaultProcessor: 'none',
},
'lllyasviel/control_v11e_sd15_ip2p': {
type: 'lllyasviel/control_v11e_sd15_ip2p',
label: 'Pix2Pix (experimental)',
defaultProcessor: 'none',
},
'CrucibleAI/ControlNetMediaPipeFace': {
type: 'CrucibleAI/ControlNetMediaPipeFace',
label: 'Mediapipe Face',
defaultProcessor: 'mediapipe_face_processor',
},
};
export type ControlNetModelName = keyof typeof CONTROLNET_MODELS;

View File

@ -8,9 +8,10 @@ import {
RequiredControlNetProcessorNode,
} from './types';
import {
CONTROLNET_MODELS,
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
// CONTROLNET_MODELS,
CONTROLNET_PROCESSORS,
ControlNetModelName,
// ControlNetModelName,
} from './constants';
import { controlNetImageProcessed } from './actions';
import { imageDeleted, imageUrlsReceived } from 'services/api/thunks/image';
@ -26,7 +27,7 @@ export type ControlModes =
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true,
model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type,
model: '',
weight: 1,
beginStepPct: 0,
endStepPct: 1,
@ -42,7 +43,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
export type ControlNetConfig = {
controlNetId: string;
isEnabled: boolean;
model: ControlNetModelName;
model: string;
weight: number;
beginStepPct: number;
endStepPct: number;
@ -147,7 +148,7 @@ export const controlNetSlice = createSlice({
state,
action: PayloadAction<{
controlNetId: string;
model: ControlNetModelName;
model: string;
}>
) => {
const { controlNetId, model } = action.payload;
@ -155,7 +156,15 @@ export const controlNetSlice = createSlice({
state.controlNets[controlNetId].processedControlImage = null;
if (state.controlNets[controlNetId].shouldAutoConfig) {
const processorType = CONTROLNET_MODELS[model].defaultProcessor;
let processorType: ControlNetProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (model.includes(modelSubstring)) {
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) {
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
@ -241,9 +250,15 @@ export const controlNetSlice = createSlice({
if (newShouldAutoConfig) {
// manage the processor for the user
const processorType =
CONTROLNET_MODELS[state.controlNets[controlNetId].model]
.defaultProcessor;
let processorType: ControlNetProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (state.controlNets[controlNetId].model.includes(modelSubstring)) {
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) {
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[