use controlnet from metadata if available (#4658)

* add control net to useRecallParams

* got recall controlnets working

* fix metadata viewer controlnet

* fix type errors

* fix controlnet metadata viewer

* set control image and use correct processor type and node

* clean up logs

* recall processor using substring

* feat(ui): enable controlNet when recalling one

---------

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
chainchompa 2023-09-27 05:30:50 -04:00 committed by GitHub
parent 3432fd72f8
commit 4a0a1c30db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 196 additions and 4 deletions

View File

@ -98,6 +98,9 @@ export const controlNetSlice = createSlice({
isControlNetEnabledToggled: (state) => { isControlNetEnabledToggled: (state) => {
state.isEnabled = !state.isEnabled; state.isEnabled = !state.isEnabled;
}, },
controlNetEnabled: (state) => {
state.isEnabled = true;
},
controlNetAdded: ( controlNetAdded: (
state, state,
action: PayloadAction<{ action: PayloadAction<{
@ -111,6 +114,12 @@ export const controlNetSlice = createSlice({
controlNetId, controlNetId,
}; };
}, },
controlNetRecalled: (state, action: PayloadAction<ControlNetConfig>) => {
const controlNet = action.payload;
state.controlNets[controlNet.controlNetId] = {
...controlNet,
};
},
controlNetDuplicated: ( controlNetDuplicated: (
state, state,
action: PayloadAction<{ action: PayloadAction<{
@ -439,7 +448,9 @@ export const controlNetSlice = createSlice({
export const { export const {
isControlNetEnabledToggled, isControlNetEnabledToggled,
controlNetEnabled,
controlNetAdded, controlNetAdded,
controlNetRecalled,
controlNetDuplicated, controlNetDuplicated,
controlNetAddedFromImage, controlNetAddedFromImage,
controlNetRemoved, controlNetRemoved,

View File

@ -1,8 +1,15 @@
import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types'; import {
ControlNetMetadataItem,
CoreMetadata,
LoRAMetadataItem,
} from 'features/nodes/types/types';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { memo, useCallback } from 'react'; import { memo, useMemo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { isValidLoRAModel } from '../../../parameters/types/parameterSchemas'; import {
isValidControlNetModel,
isValidLoRAModel,
} from '../../../parameters/types/parameterSchemas';
import ImageMetadataItem from './ImageMetadataItem'; import ImageMetadataItem from './ImageMetadataItem';
type Props = { type Props = {
@ -26,6 +33,7 @@ const ImageMetadataActions = (props: Props) => {
recallHeight, recallHeight,
recallStrength, recallStrength,
recallLoRA, recallLoRA,
recallControlNet,
} = useRecallParameters(); } = useRecallParameters();
const handleRecallPositivePrompt = useCallback(() => { const handleRecallPositivePrompt = useCallback(() => {
@ -75,6 +83,21 @@ const ImageMetadataActions = (props: Props) => {
[recallLoRA] [recallLoRA]
); );
const handleRecallControlNet = useCallback(
(controlnet: ControlNetMetadataItem) => {
recallControlNet(controlnet);
},
[recallControlNet]
);
const validControlNets: ControlNetMetadataItem[] = useMemo(() => {
return metadata?.controlnets
? metadata.controlnets.filter((controlnet) =>
isValidControlNetModel(controlnet.control_model)
)
: [];
}, [metadata?.controlnets]);
if (!metadata || Object.keys(metadata).length === 0) { if (!metadata || Object.keys(metadata).length === 0) {
return null; return null;
} }
@ -180,6 +203,14 @@ const ImageMetadataActions = (props: Props) => {
); );
} }
})} })}
{validControlNets.map((controlnet, index) => (
<ImageMetadataItem
key={index}
label="ControlNet"
value={`${controlnet.control_model?.model_name} - ${controlnet.control_weight}`}
onClick={() => handleRecallControlNet(controlnet)}
/>
))}
</> </>
); );
}; };

View File

@ -1141,6 +1141,10 @@ const zLoRAMetadataItem = z.object({
export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>; export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
const zControlNetMetadataItem = zControlField.deepPartial();
export type ControlNetMetadataItem = z.infer<typeof zControlNetMetadataItem>;
export const zCoreMetadata = z export const zCoreMetadata = z
.object({ .object({
app_version: z.string().nullish().catch(null), app_version: z.string().nullish().catch(null),

View File

@ -2,7 +2,11 @@ import { createSelector } from '@reduxjs/toolkit';
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types'; import {
CoreMetadata,
LoRAMetadataItem,
ControlNetMetadataItem,
} from 'features/nodes/types/types';
import { import {
refinerModelChanged, refinerModelChanged,
setNegativeStylePromptSDXL, setNegativeStylePromptSDXL,
@ -18,9 +22,18 @@ import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
import { import {
controlNetModelsAdapter,
loraModelsAdapter, loraModelsAdapter,
useGetControlNetModelsQuery,
useGetLoRAModelsQuery, useGetLoRAModelsQuery,
} from '../../../services/api/endpoints/models'; } from '../../../services/api/endpoints/models';
import {
ControlNetConfig,
controlNetEnabled,
controlNetRecalled,
controlNetReset,
initialControlNet,
} from '../../controlNet/store/controlNetSlice';
import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice'; import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice';
import { initialImageSelected, modelSelected } from '../store/actions'; import { initialImageSelected, modelSelected } from '../store/actions';
import { import {
@ -38,6 +51,7 @@ import {
isValidCfgScale, isValidCfgScale,
isValidHeight, isValidHeight,
isValidLoRAModel, isValidLoRAModel,
isValidControlNetModel,
isValidMainModel, isValidMainModel,
isValidNegativePrompt, isValidNegativePrompt,
isValidPositivePrompt, isValidPositivePrompt,
@ -53,6 +67,11 @@ import {
isValidStrength, isValidStrength,
isValidWidth, isValidWidth,
} from '../types/parameterSchemas'; } from '../types/parameterSchemas';
import { v4 as uuidv4 } from 'uuid';
import {
CONTROLNET_PROCESSORS,
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
} from 'features/controlNet/store/constants';
const selector = createSelector(stateSelector, ({ generation }) => { const selector = createSelector(stateSelector, ({ generation }) => {
const { model } = generation; const { model } = generation;
@ -390,6 +409,121 @@ export const useRecallParameters = () => {
[prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] [prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall ControlNet with toast
*/
const { controlnets } = useGetControlNetModelsQuery(undefined, {
selectFromResult: (result) => ({
controlnets: result.data
? controlNetModelsAdapter.getSelectors().selectAll(result.data)
: [],
}),
});
const prepareControlNetMetadataItem = useCallback(
(controlnetMetadataItem: ControlNetMetadataItem) => {
if (!isValidControlNetModel(controlnetMetadataItem.control_model)) {
return { controlnet: null, error: 'Invalid ControlNet model' };
}
const {
image,
control_model,
control_weight,
begin_step_percent,
end_step_percent,
control_mode,
resize_mode,
} = controlnetMetadataItem;
const matchingControlNetModel = controlnets.find(
(c) =>
c.base_model === control_model.base_model &&
c.model_name === control_model.model_name
);
if (!matchingControlNetModel) {
return { controlnet: null, error: 'ControlNet model is not installed' };
}
const isCompatibleBaseModel =
matchingControlNetModel?.base_model === model?.base_model;
if (!isCompatibleBaseModel) {
return {
controlnet: null,
error: 'ControlNet incompatible with currently-selected model',
};
}
const controlNetId = uuidv4();
let processorType = initialControlNet.processorType;
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (matchingControlNetModel.model_name.includes(modelSubstring)) {
processorType =
CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring] ||
initialControlNet.processorType;
break;
}
}
const processorNode = CONTROLNET_PROCESSORS[processorType].default;
const controlnet: ControlNetConfig = {
isEnabled: true,
model: matchingControlNetModel,
weight:
typeof control_weight === 'number'
? control_weight
: initialControlNet.weight,
beginStepPct: begin_step_percent || initialControlNet.beginStepPct,
endStepPct: end_step_percent || initialControlNet.endStepPct,
controlMode: control_mode || initialControlNet.controlMode,
resizeMode: resize_mode || initialControlNet.resizeMode,
controlImage: image?.image_name || null,
processedControlImage: image?.image_name || null,
processorType,
processorNode:
processorNode.type !== 'none'
? processorNode
: initialControlNet.processorNode,
shouldAutoConfig: true,
controlNetId,
};
return { controlnet, error: null };
},
[controlnets, model?.base_model]
);
const recallControlNet = useCallback(
(controlnetMetadataItem: ControlNetMetadataItem) => {
const result = prepareControlNetMetadataItem(controlnetMetadataItem);
if (!result.controlnet) {
parameterNotSetToast(result.error);
return;
}
dispatch(
controlNetRecalled({
...result.controlnet,
})
);
dispatch(controlNetEnabled());
parameterSetToast();
},
[
prepareControlNetMetadataItem,
dispatch,
parameterSetToast,
parameterNotSetToast,
]
);
/* /*
* Sets image as initial image with toast * Sets image as initial image with toast
*/ */
@ -428,6 +562,7 @@ export const useRecallParameters = () => {
refiner_negative_aesthetic_score, refiner_negative_aesthetic_score,
refiner_start, refiner_start,
loras, loras,
controlnets,
} = metadata; } = metadata;
if (isValidCfgScale(cfg_scale)) { if (isValidCfgScale(cfg_scale)) {
@ -517,6 +652,15 @@ export const useRecallParameters = () => {
} }
}); });
dispatch(controlNetReset());
dispatch(controlNetEnabled());
controlnets?.forEach((controlnet) => {
const result = prepareControlNetMetadataItem(controlnet);
if (result.controlnet) {
dispatch(controlNetRecalled(result.controlnet));
}
});
allParameterSetToast(); allParameterSetToast();
}, },
[ [
@ -524,6 +668,7 @@ export const useRecallParameters = () => {
allParameterSetToast, allParameterSetToast,
dispatch, dispatch,
prepareLoRAMetadataItem, prepareLoRAMetadataItem,
prepareControlNetMetadataItem,
] ]
); );
@ -542,6 +687,7 @@ export const useRecallParameters = () => {
recallHeight, recallHeight,
recallStrength, recallStrength,
recallLoRA, recallLoRA,
recallControlNet,
recallAllParameters, recallAllParameters,
sendToImageToImage, sendToImageToImage,
}; };