mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
remove frontend constants, use backend response for controlnet models. add disabled state if base model is not compatible. clear control net model if main base model changes. add logic to guess processor and move it up in UI
This commit is contained in:
parent
5ac114576f
commit
76dc47e88d
@ -1,5 +1,5 @@
|
|||||||
import {
|
import {
|
||||||
CONTROLNET_MODELS,
|
// CONTROLNET_MODELS,
|
||||||
CONTROLNET_PROCESSORS,
|
CONTROLNET_PROCESSORS,
|
||||||
} from 'features/controlNet/store/constants';
|
} from 'features/controlNet/store/constants';
|
||||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||||
@ -128,7 +128,7 @@ export type AppConfig = {
|
|||||||
canRestoreDeletedImagesFromBin: boolean;
|
canRestoreDeletedImagesFromBin: boolean;
|
||||||
sd: {
|
sd: {
|
||||||
defaultModel?: string;
|
defaultModel?: string;
|
||||||
disabledControlNetModels: (keyof typeof CONTROLNET_MODELS)[];
|
disabledControlNetModels: string[];
|
||||||
disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[];
|
disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[];
|
||||||
iterations: {
|
iterations: {
|
||||||
initial: number;
|
initial: number;
|
||||||
|
@ -124,6 +124,7 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
/>
|
/>
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
{!shouldAutoConfig && (
|
{!shouldAutoConfig && (
|
||||||
<Box
|
<Box
|
||||||
sx={{
|
sx={{
|
||||||
@ -138,6 +139,16 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
|
<Flex alignItems="flex-end" gap="2">
|
||||||
|
<ParamControlNetProcessorSelect
|
||||||
|
controlNetId={controlNetId}
|
||||||
|
processorNode={processorNode}
|
||||||
|
/>
|
||||||
|
<ParamControlNetShouldAutoConfig
|
||||||
|
controlNetId={controlNetId}
|
||||||
|
shouldAutoConfig={shouldAutoConfig}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
{isEnabled && (
|
{isEnabled && (
|
||||||
<>
|
<>
|
||||||
<Flex sx={{ w: 'full', flexDirection: 'column' }}>
|
<Flex sx={{ w: 'full', flexDirection: 'column' }}>
|
||||||
@ -196,18 +207,10 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
height={96}
|
height={96}
|
||||||
/>
|
/>
|
||||||
</Box>
|
</Box>
|
||||||
<ParamControlNetProcessorSelect
|
|
||||||
controlNetId={controlNetId}
|
|
||||||
processorNode={processorNode}
|
|
||||||
/>
|
|
||||||
<ControlNetProcessorComponent
|
<ControlNetProcessorComponent
|
||||||
controlNetId={controlNetId}
|
controlNetId={controlNetId}
|
||||||
processorNode={processorNode}
|
processorNode={processorNode}
|
||||||
/>
|
/>
|
||||||
<ParamControlNetShouldAutoConfig
|
|
||||||
controlNetId={controlNetId}
|
|
||||||
shouldAutoConfig={shouldAutoConfig}
|
|
||||||
/>
|
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
</>
|
</>
|
||||||
|
@ -1,55 +1,71 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { SelectItem } from '@mantine/core';
|
||||||
|
import { RootState } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIMantineSearchableSelect, {
|
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||||
IAISelectDataType,
|
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||||
} from 'common/components/IAIMantineSearchableSelect';
|
|
||||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||||
import {
|
|
||||||
CONTROLNET_MODELS,
|
|
||||||
ControlNetModelName,
|
|
||||||
} from 'features/controlNet/store/constants';
|
|
||||||
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
|
||||||
import { configSelector } from 'features/system/store/configSelectors';
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
import { map } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
|
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
type ParamControlNetModelProps = {
|
type ParamControlNetModelProps = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
model: ControlNetModelName;
|
model: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
const selector = createSelector(configSelector, (config) => {
|
|
||||||
const controlNetModels: IAISelectDataType[] = map(CONTROLNET_MODELS, (m) => ({
|
|
||||||
label: m.label,
|
|
||||||
value: m.type,
|
|
||||||
})).filter(
|
|
||||||
(d) =>
|
|
||||||
!config.sd.disabledControlNetModels.includes(
|
|
||||||
d.value as ControlNetModelName
|
|
||||||
)
|
|
||||||
);
|
|
||||||
|
|
||||||
return controlNetModels;
|
|
||||||
});
|
|
||||||
|
|
||||||
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
||||||
const { controlNetId, model } = props;
|
const { controlNetId, model } = props;
|
||||||
const controlNetModels = useAppSelector(selector);
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const isReady = useIsReadyToInvoke();
|
const isReady = useIsReadyToInvoke();
|
||||||
|
|
||||||
|
const currentMainModel = useAppSelector(
|
||||||
|
(state: RootState) => state.generation.model
|
||||||
|
);
|
||||||
|
|
||||||
|
const { data: controlNetModels } = useGetControlNetModelsQuery();
|
||||||
|
|
||||||
const handleModelChanged = useCallback(
|
const handleModelChanged = useCallback(
|
||||||
(val: string | null) => {
|
(val: string | null) => {
|
||||||
// TODO: do not cast
|
if (!val) return;
|
||||||
const model = val as ControlNetModelName;
|
dispatch(controlNetModelChanged({ controlNetId, model: val }));
|
||||||
dispatch(controlNetModelChanged({ controlNetId, model }));
|
|
||||||
},
|
},
|
||||||
[controlNetId, dispatch]
|
[controlNetId, dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const data = useMemo(() => {
|
||||||
|
if (!controlNetModels) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
|
forEach(controlNetModels.entities, (model, id) => {
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const disabled = currentMainModel?.base_model !== model.base_model;
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
value: id,
|
||||||
|
label: model.model_name,
|
||||||
|
group: MODEL_TYPE_MAP[model.base_model],
|
||||||
|
disabled,
|
||||||
|
tooltip: disabled
|
||||||
|
? `Incompatible base model: ${model.base_model}`
|
||||||
|
: undefined,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [controlNetModels, currentMainModel?.base_model]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAIMantineSearchableSelect
|
<IAIMantineSearchableSelect
|
||||||
data={controlNetModels}
|
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||||
|
data={data}
|
||||||
value={model}
|
value={model}
|
||||||
onChange={handleModelChanged}
|
onChange={handleModelChanged}
|
||||||
disabled={!isReady}
|
disabled={!isReady}
|
||||||
|
@ -173,91 +173,17 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
type ControlNetModelsDict = Record<string, ControlNetModel>;
|
export const CONTROLNET_MODEL_DEFAULT_PROCESSORS: {
|
||||||
|
[key: string]: ControlNetProcessorType;
|
||||||
type ControlNetModel = {
|
} = {
|
||||||
type: string;
|
canny: 'canny_image_processor',
|
||||||
label: string;
|
mlsd: 'mlsd_image_processor',
|
||||||
description?: string;
|
depth: 'midas_depth_image_processor',
|
||||||
defaultProcessor?: ControlNetProcessorType;
|
bae: 'normalbae_image_processor',
|
||||||
|
lineart: 'lineart_image_processor',
|
||||||
|
lineart_anime: 'lineart_anime_image_processor',
|
||||||
|
softedge: 'hed_image_processor',
|
||||||
|
shuffle: 'content_shuffle_image_processor',
|
||||||
|
openpose: 'openpose_image_processor',
|
||||||
|
mediapipe: 'mediapipe_face_processor',
|
||||||
};
|
};
|
||||||
|
|
||||||
export const CONTROLNET_MODELS: ControlNetModelsDict = {
|
|
||||||
'lllyasviel/control_v11p_sd15_canny': {
|
|
||||||
type: 'lllyasviel/control_v11p_sd15_canny',
|
|
||||||
label: 'Canny',
|
|
||||||
defaultProcessor: 'canny_image_processor',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11p_sd15_inpaint': {
|
|
||||||
type: 'lllyasviel/control_v11p_sd15_inpaint',
|
|
||||||
label: 'Inpaint',
|
|
||||||
defaultProcessor: 'none',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11p_sd15_mlsd': {
|
|
||||||
type: 'lllyasviel/control_v11p_sd15_mlsd',
|
|
||||||
label: 'M-LSD',
|
|
||||||
defaultProcessor: 'mlsd_image_processor',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11f1p_sd15_depth': {
|
|
||||||
type: 'lllyasviel/control_v11f1p_sd15_depth',
|
|
||||||
label: 'Depth',
|
|
||||||
defaultProcessor: 'midas_depth_image_processor',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11p_sd15_normalbae': {
|
|
||||||
type: 'lllyasviel/control_v11p_sd15_normalbae',
|
|
||||||
label: 'Normal Map (BAE)',
|
|
||||||
defaultProcessor: 'normalbae_image_processor',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11p_sd15_seg': {
|
|
||||||
type: 'lllyasviel/control_v11p_sd15_seg',
|
|
||||||
label: 'Segmentation',
|
|
||||||
defaultProcessor: 'none',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11p_sd15_lineart': {
|
|
||||||
type: 'lllyasviel/control_v11p_sd15_lineart',
|
|
||||||
label: 'Lineart',
|
|
||||||
defaultProcessor: 'lineart_image_processor',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11p_sd15s2_lineart_anime': {
|
|
||||||
type: 'lllyasviel/control_v11p_sd15s2_lineart_anime',
|
|
||||||
label: 'Lineart Anime',
|
|
||||||
defaultProcessor: 'lineart_anime_image_processor',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11p_sd15_scribble': {
|
|
||||||
type: 'lllyasviel/control_v11p_sd15_scribble',
|
|
||||||
label: 'Scribble',
|
|
||||||
defaultProcessor: 'none',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11p_sd15_softedge': {
|
|
||||||
type: 'lllyasviel/control_v11p_sd15_softedge',
|
|
||||||
label: 'Soft Edge',
|
|
||||||
defaultProcessor: 'hed_image_processor',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11e_sd15_shuffle': {
|
|
||||||
type: 'lllyasviel/control_v11e_sd15_shuffle',
|
|
||||||
label: 'Content Shuffle',
|
|
||||||
defaultProcessor: 'content_shuffle_image_processor',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11p_sd15_openpose': {
|
|
||||||
type: 'lllyasviel/control_v11p_sd15_openpose',
|
|
||||||
label: 'Openpose',
|
|
||||||
defaultProcessor: 'openpose_image_processor',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11f1e_sd15_tile': {
|
|
||||||
type: 'lllyasviel/control_v11f1e_sd15_tile',
|
|
||||||
label: 'Tile (experimental)',
|
|
||||||
defaultProcessor: 'none',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11e_sd15_ip2p': {
|
|
||||||
type: 'lllyasviel/control_v11e_sd15_ip2p',
|
|
||||||
label: 'Pix2Pix (experimental)',
|
|
||||||
defaultProcessor: 'none',
|
|
||||||
},
|
|
||||||
'CrucibleAI/ControlNetMediaPipeFace': {
|
|
||||||
type: 'CrucibleAI/ControlNetMediaPipeFace',
|
|
||||||
label: 'Mediapipe Face',
|
|
||||||
defaultProcessor: 'mediapipe_face_processor',
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
export type ControlNetModelName = keyof typeof CONTROLNET_MODELS;
|
|
||||||
|
@ -8,9 +8,10 @@ import {
|
|||||||
RequiredControlNetProcessorNode,
|
RequiredControlNetProcessorNode,
|
||||||
} from './types';
|
} from './types';
|
||||||
import {
|
import {
|
||||||
CONTROLNET_MODELS,
|
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
|
||||||
|
// CONTROLNET_MODELS,
|
||||||
CONTROLNET_PROCESSORS,
|
CONTROLNET_PROCESSORS,
|
||||||
ControlNetModelName,
|
// ControlNetModelName,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { controlNetImageProcessed } from './actions';
|
import { controlNetImageProcessed } from './actions';
|
||||||
import { imageDeleted, imageUrlsReceived } from 'services/api/thunks/image';
|
import { imageDeleted, imageUrlsReceived } from 'services/api/thunks/image';
|
||||||
@ -26,7 +27,7 @@ export type ControlModes =
|
|||||||
|
|
||||||
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
||||||
isEnabled: true,
|
isEnabled: true,
|
||||||
model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type,
|
model: '',
|
||||||
weight: 1,
|
weight: 1,
|
||||||
beginStepPct: 0,
|
beginStepPct: 0,
|
||||||
endStepPct: 1,
|
endStepPct: 1,
|
||||||
@ -42,7 +43,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
|||||||
export type ControlNetConfig = {
|
export type ControlNetConfig = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
isEnabled: boolean;
|
isEnabled: boolean;
|
||||||
model: ControlNetModelName;
|
model: string;
|
||||||
weight: number;
|
weight: number;
|
||||||
beginStepPct: number;
|
beginStepPct: number;
|
||||||
endStepPct: number;
|
endStepPct: number;
|
||||||
@ -147,7 +148,7 @@ export const controlNetSlice = createSlice({
|
|||||||
state,
|
state,
|
||||||
action: PayloadAction<{
|
action: PayloadAction<{
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
model: ControlNetModelName;
|
model: string;
|
||||||
}>
|
}>
|
||||||
) => {
|
) => {
|
||||||
const { controlNetId, model } = action.payload;
|
const { controlNetId, model } = action.payload;
|
||||||
@ -155,7 +156,15 @@ export const controlNetSlice = createSlice({
|
|||||||
state.controlNets[controlNetId].processedControlImage = null;
|
state.controlNets[controlNetId].processedControlImage = null;
|
||||||
|
|
||||||
if (state.controlNets[controlNetId].shouldAutoConfig) {
|
if (state.controlNets[controlNetId].shouldAutoConfig) {
|
||||||
const processorType = CONTROLNET_MODELS[model].defaultProcessor;
|
let processorType: ControlNetProcessorType | undefined = undefined;
|
||||||
|
|
||||||
|
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
||||||
|
if (model.includes(modelSubstring)) {
|
||||||
|
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (processorType) {
|
if (processorType) {
|
||||||
state.controlNets[controlNetId].processorType = processorType;
|
state.controlNets[controlNetId].processorType = processorType;
|
||||||
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
|
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
|
||||||
@ -241,9 +250,15 @@ export const controlNetSlice = createSlice({
|
|||||||
|
|
||||||
if (newShouldAutoConfig) {
|
if (newShouldAutoConfig) {
|
||||||
// manage the processor for the user
|
// manage the processor for the user
|
||||||
const processorType =
|
let processorType: ControlNetProcessorType | undefined = undefined;
|
||||||
CONTROLNET_MODELS[state.controlNets[controlNetId].model]
|
|
||||||
.defaultProcessor;
|
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
||||||
|
if (state.controlNets[controlNetId].model.includes(modelSubstring)) {
|
||||||
|
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (processorType) {
|
if (processorType) {
|
||||||
state.controlNets[controlNetId].processorType = processorType;
|
state.controlNets[controlNetId].processorType = processorType;
|
||||||
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
|
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
|
||||||
|
Loading…
Reference in New Issue
Block a user