use controlnet from metadata if available (#4658)

* add control net to useRecallParams

* got recall controlnets working

* fix metadata viewer controlnet

* fix type errors

* fix controlnet metadata viewer

* set control image and use correct processor type and node

* clean up logs

* recall processor using substring

* feat(ui): enable controlNet when recalling one

---------

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
chainchompa 2023-09-27 05:30:50 -04:00 committed by GitHub
parent 3432fd72f8
commit 4a0a1c30db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 196 additions and 4 deletions

View File

@ -98,6 +98,9 @@ export const controlNetSlice = createSlice({
isControlNetEnabledToggled: (state) => {
state.isEnabled = !state.isEnabled;
},
controlNetEnabled: (state) => {
state.isEnabled = true;
},
controlNetAdded: (
state,
action: PayloadAction<{
@ -111,6 +114,12 @@ export const controlNetSlice = createSlice({
controlNetId,
};
},
controlNetRecalled: (state, action: PayloadAction<ControlNetConfig>) => {
const controlNet = action.payload;
state.controlNets[controlNet.controlNetId] = {
...controlNet,
};
},
controlNetDuplicated: (
state,
action: PayloadAction<{
@ -439,7 +448,9 @@ export const controlNetSlice = createSlice({
export const {
isControlNetEnabledToggled,
controlNetEnabled,
controlNetAdded,
controlNetRecalled,
controlNetDuplicated,
controlNetAddedFromImage,
controlNetRemoved,

View File

@ -1,8 +1,15 @@
import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types';
import {
ControlNetMetadataItem,
CoreMetadata,
LoRAMetadataItem,
} from 'features/nodes/types/types';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { memo, useCallback } from 'react';
import { memo, useMemo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { isValidLoRAModel } from '../../../parameters/types/parameterSchemas';
import {
isValidControlNetModel,
isValidLoRAModel,
} from '../../../parameters/types/parameterSchemas';
import ImageMetadataItem from './ImageMetadataItem';
type Props = {
@ -26,6 +33,7 @@ const ImageMetadataActions = (props: Props) => {
recallHeight,
recallStrength,
recallLoRA,
recallControlNet,
} = useRecallParameters();
const handleRecallPositivePrompt = useCallback(() => {
@ -75,6 +83,21 @@ const ImageMetadataActions = (props: Props) => {
[recallLoRA]
);
const handleRecallControlNet = useCallback(
(controlnet: ControlNetMetadataItem) => {
recallControlNet(controlnet);
},
[recallControlNet]
);
const validControlNets: ControlNetMetadataItem[] = useMemo(() => {
return metadata?.controlnets
? metadata.controlnets.filter((controlnet) =>
isValidControlNetModel(controlnet.control_model)
)
: [];
}, [metadata?.controlnets]);
if (!metadata || Object.keys(metadata).length === 0) {
return null;
}
@ -180,6 +203,14 @@ const ImageMetadataActions = (props: Props) => {
);
}
})}
{validControlNets.map((controlnet, index) => (
<ImageMetadataItem
key={index}
label="ControlNet"
value={`${controlnet.control_model?.model_name} - ${controlnet.control_weight}`}
onClick={() => handleRecallControlNet(controlnet)}
/>
))}
</>
);
};

View File

@ -1141,6 +1141,10 @@ const zLoRAMetadataItem = z.object({
export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
const zControlNetMetadataItem = zControlField.deepPartial();
export type ControlNetMetadataItem = z.infer<typeof zControlNetMetadataItem>;
export const zCoreMetadata = z
.object({
app_version: z.string().nullish().catch(null),

View File

@ -2,7 +2,11 @@ import { createSelector } from '@reduxjs/toolkit';
import { useAppToaster } from 'app/components/Toaster';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types';
import {
CoreMetadata,
LoRAMetadataItem,
ControlNetMetadataItem,
} from 'features/nodes/types/types';
import {
refinerModelChanged,
setNegativeStylePromptSDXL,
@ -18,9 +22,18 @@ import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { ImageDTO } from 'services/api/types';
import {
controlNetModelsAdapter,
loraModelsAdapter,
useGetControlNetModelsQuery,
useGetLoRAModelsQuery,
} from '../../../services/api/endpoints/models';
import {
ControlNetConfig,
controlNetEnabled,
controlNetRecalled,
controlNetReset,
initialControlNet,
} from '../../controlNet/store/controlNetSlice';
import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice';
import { initialImageSelected, modelSelected } from '../store/actions';
import {
@ -38,6 +51,7 @@ import {
isValidCfgScale,
isValidHeight,
isValidLoRAModel,
isValidControlNetModel,
isValidMainModel,
isValidNegativePrompt,
isValidPositivePrompt,
@ -53,6 +67,11 @@ import {
isValidStrength,
isValidWidth,
} from '../types/parameterSchemas';
import { v4 as uuidv4 } from 'uuid';
import {
CONTROLNET_PROCESSORS,
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
} from 'features/controlNet/store/constants';
const selector = createSelector(stateSelector, ({ generation }) => {
const { model } = generation;
@ -390,6 +409,121 @@ export const useRecallParameters = () => {
[prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall ControlNet with toast
*/
const { controlnets } = useGetControlNetModelsQuery(undefined, {
selectFromResult: (result) => ({
controlnets: result.data
? controlNetModelsAdapter.getSelectors().selectAll(result.data)
: [],
}),
});
const prepareControlNetMetadataItem = useCallback(
(controlnetMetadataItem: ControlNetMetadataItem) => {
if (!isValidControlNetModel(controlnetMetadataItem.control_model)) {
return { controlnet: null, error: 'Invalid ControlNet model' };
}
const {
image,
control_model,
control_weight,
begin_step_percent,
end_step_percent,
control_mode,
resize_mode,
} = controlnetMetadataItem;
const matchingControlNetModel = controlnets.find(
(c) =>
c.base_model === control_model.base_model &&
c.model_name === control_model.model_name
);
if (!matchingControlNetModel) {
return { controlnet: null, error: 'ControlNet model is not installed' };
}
const isCompatibleBaseModel =
matchingControlNetModel?.base_model === model?.base_model;
if (!isCompatibleBaseModel) {
return {
controlnet: null,
error: 'ControlNet incompatible with currently-selected model',
};
}
const controlNetId = uuidv4();
let processorType = initialControlNet.processorType;
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (matchingControlNetModel.model_name.includes(modelSubstring)) {
processorType =
CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring] ||
initialControlNet.processorType;
break;
}
}
const processorNode = CONTROLNET_PROCESSORS[processorType].default;
const controlnet: ControlNetConfig = {
isEnabled: true,
model: matchingControlNetModel,
weight:
typeof control_weight === 'number'
? control_weight
: initialControlNet.weight,
beginStepPct: begin_step_percent || initialControlNet.beginStepPct,
endStepPct: end_step_percent || initialControlNet.endStepPct,
controlMode: control_mode || initialControlNet.controlMode,
resizeMode: resize_mode || initialControlNet.resizeMode,
controlImage: image?.image_name || null,
processedControlImage: image?.image_name || null,
processorType,
processorNode:
processorNode.type !== 'none'
? processorNode
: initialControlNet.processorNode,
shouldAutoConfig: true,
controlNetId,
};
return { controlnet, error: null };
},
[controlnets, model?.base_model]
);
const recallControlNet = useCallback(
(controlnetMetadataItem: ControlNetMetadataItem) => {
const result = prepareControlNetMetadataItem(controlnetMetadataItem);
if (!result.controlnet) {
parameterNotSetToast(result.error);
return;
}
dispatch(
controlNetRecalled({
...result.controlnet,
})
);
dispatch(controlNetEnabled());
parameterSetToast();
},
[
prepareControlNetMetadataItem,
dispatch,
parameterSetToast,
parameterNotSetToast,
]
);
/*
* Sets image as initial image with toast
*/
@ -428,6 +562,7 @@ export const useRecallParameters = () => {
refiner_negative_aesthetic_score,
refiner_start,
loras,
controlnets,
} = metadata;
if (isValidCfgScale(cfg_scale)) {
@ -517,6 +652,15 @@ export const useRecallParameters = () => {
}
});
dispatch(controlNetReset());
dispatch(controlNetEnabled());
controlnets?.forEach((controlnet) => {
const result = prepareControlNetMetadataItem(controlnet);
if (result.controlnet) {
dispatch(controlNetRecalled(result.controlnet));
}
});
allParameterSetToast();
},
[
@ -524,6 +668,7 @@ export const useRecallParameters = () => {
allParameterSetToast,
dispatch,
prepareLoRAMetadataItem,
prepareControlNetMetadataItem,
]
);
@ -542,6 +687,7 @@ export const useRecallParameters = () => {
recallHeight,
recallStrength,
recallLoRA,
recallControlNet,
recallAllParameters,
sendToImageToImage,
};