mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
use controlnet from metadata if available (#4658)
* add control net to useRecallParams * got recall controlnets working * fix metadata viewer controlnet * fix type errors * fix controlnet metadata viewer * set control image and use correct processor type and node * clean up logs * recall processor using substring * feat(ui): enable controlNet when recalling one --------- Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
parent
3432fd72f8
commit
4a0a1c30db
@ -98,6 +98,9 @@ export const controlNetSlice = createSlice({
|
||||
isControlNetEnabledToggled: (state) => {
|
||||
state.isEnabled = !state.isEnabled;
|
||||
},
|
||||
controlNetEnabled: (state) => {
|
||||
state.isEnabled = true;
|
||||
},
|
||||
controlNetAdded: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
@ -111,6 +114,12 @@ export const controlNetSlice = createSlice({
|
||||
controlNetId,
|
||||
};
|
||||
},
|
||||
controlNetRecalled: (state, action: PayloadAction<ControlNetConfig>) => {
|
||||
const controlNet = action.payload;
|
||||
state.controlNets[controlNet.controlNetId] = {
|
||||
...controlNet,
|
||||
};
|
||||
},
|
||||
controlNetDuplicated: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
@ -439,7 +448,9 @@ export const controlNetSlice = createSlice({
|
||||
|
||||
export const {
|
||||
isControlNetEnabledToggled,
|
||||
controlNetEnabled,
|
||||
controlNetAdded,
|
||||
controlNetRecalled,
|
||||
controlNetDuplicated,
|
||||
controlNetAddedFromImage,
|
||||
controlNetRemoved,
|
||||
|
@ -1,8 +1,15 @@
|
||||
import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types';
|
||||
import {
|
||||
ControlNetMetadataItem,
|
||||
CoreMetadata,
|
||||
LoRAMetadataItem,
|
||||
} from 'features/nodes/types/types';
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { memo, useMemo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { isValidLoRAModel } from '../../../parameters/types/parameterSchemas';
|
||||
import {
|
||||
isValidControlNetModel,
|
||||
isValidLoRAModel,
|
||||
} from '../../../parameters/types/parameterSchemas';
|
||||
import ImageMetadataItem from './ImageMetadataItem';
|
||||
|
||||
type Props = {
|
||||
@ -26,6 +33,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
recallHeight,
|
||||
recallStrength,
|
||||
recallLoRA,
|
||||
recallControlNet,
|
||||
} = useRecallParameters();
|
||||
|
||||
const handleRecallPositivePrompt = useCallback(() => {
|
||||
@ -75,6 +83,21 @@ const ImageMetadataActions = (props: Props) => {
|
||||
[recallLoRA]
|
||||
);
|
||||
|
||||
const handleRecallControlNet = useCallback(
|
||||
(controlnet: ControlNetMetadataItem) => {
|
||||
recallControlNet(controlnet);
|
||||
},
|
||||
[recallControlNet]
|
||||
);
|
||||
|
||||
const validControlNets: ControlNetMetadataItem[] = useMemo(() => {
|
||||
return metadata?.controlnets
|
||||
? metadata.controlnets.filter((controlnet) =>
|
||||
isValidControlNetModel(controlnet.control_model)
|
||||
)
|
||||
: [];
|
||||
}, [metadata?.controlnets]);
|
||||
|
||||
if (!metadata || Object.keys(metadata).length === 0) {
|
||||
return null;
|
||||
}
|
||||
@ -180,6 +203,14 @@ const ImageMetadataActions = (props: Props) => {
|
||||
);
|
||||
}
|
||||
})}
|
||||
{validControlNets.map((controlnet, index) => (
|
||||
<ImageMetadataItem
|
||||
key={index}
|
||||
label="ControlNet"
|
||||
value={`${controlnet.control_model?.model_name} - ${controlnet.control_weight}`}
|
||||
onClick={() => handleRecallControlNet(controlnet)}
|
||||
/>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
@ -1141,6 +1141,10 @@ const zLoRAMetadataItem = z.object({
|
||||
|
||||
export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
|
||||
|
||||
const zControlNetMetadataItem = zControlField.deepPartial();
|
||||
|
||||
export type ControlNetMetadataItem = z.infer<typeof zControlNetMetadataItem>;
|
||||
|
||||
export const zCoreMetadata = z
|
||||
.object({
|
||||
app_version: z.string().nullish().catch(null),
|
||||
|
@ -2,7 +2,11 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types';
|
||||
import {
|
||||
CoreMetadata,
|
||||
LoRAMetadataItem,
|
||||
ControlNetMetadataItem,
|
||||
} from 'features/nodes/types/types';
|
||||
import {
|
||||
refinerModelChanged,
|
||||
setNegativeStylePromptSDXL,
|
||||
@ -18,9 +22,18 @@ import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import {
|
||||
controlNetModelsAdapter,
|
||||
loraModelsAdapter,
|
||||
useGetControlNetModelsQuery,
|
||||
useGetLoRAModelsQuery,
|
||||
} from '../../../services/api/endpoints/models';
|
||||
import {
|
||||
ControlNetConfig,
|
||||
controlNetEnabled,
|
||||
controlNetRecalled,
|
||||
controlNetReset,
|
||||
initialControlNet,
|
||||
} from '../../controlNet/store/controlNetSlice';
|
||||
import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice';
|
||||
import { initialImageSelected, modelSelected } from '../store/actions';
|
||||
import {
|
||||
@ -38,6 +51,7 @@ import {
|
||||
isValidCfgScale,
|
||||
isValidHeight,
|
||||
isValidLoRAModel,
|
||||
isValidControlNetModel,
|
||||
isValidMainModel,
|
||||
isValidNegativePrompt,
|
||||
isValidPositivePrompt,
|
||||
@ -53,6 +67,11 @@ import {
|
||||
isValidStrength,
|
||||
isValidWidth,
|
||||
} from '../types/parameterSchemas';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import {
|
||||
CONTROLNET_PROCESSORS,
|
||||
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
|
||||
} from 'features/controlNet/store/constants';
|
||||
|
||||
const selector = createSelector(stateSelector, ({ generation }) => {
|
||||
const { model } = generation;
|
||||
@ -390,6 +409,121 @@ export const useRecallParameters = () => {
|
||||
[prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall ControlNet with toast
|
||||
*/
|
||||
|
||||
const { controlnets } = useGetControlNetModelsQuery(undefined, {
|
||||
selectFromResult: (result) => ({
|
||||
controlnets: result.data
|
||||
? controlNetModelsAdapter.getSelectors().selectAll(result.data)
|
||||
: [],
|
||||
}),
|
||||
});
|
||||
|
||||
const prepareControlNetMetadataItem = useCallback(
|
||||
(controlnetMetadataItem: ControlNetMetadataItem) => {
|
||||
if (!isValidControlNetModel(controlnetMetadataItem.control_model)) {
|
||||
return { controlnet: null, error: 'Invalid ControlNet model' };
|
||||
}
|
||||
|
||||
const {
|
||||
image,
|
||||
control_model,
|
||||
control_weight,
|
||||
begin_step_percent,
|
||||
end_step_percent,
|
||||
control_mode,
|
||||
resize_mode,
|
||||
} = controlnetMetadataItem;
|
||||
|
||||
const matchingControlNetModel = controlnets.find(
|
||||
(c) =>
|
||||
c.base_model === control_model.base_model &&
|
||||
c.model_name === control_model.model_name
|
||||
);
|
||||
|
||||
if (!matchingControlNetModel) {
|
||||
return { controlnet: null, error: 'ControlNet model is not installed' };
|
||||
}
|
||||
|
||||
const isCompatibleBaseModel =
|
||||
matchingControlNetModel?.base_model === model?.base_model;
|
||||
|
||||
if (!isCompatibleBaseModel) {
|
||||
return {
|
||||
controlnet: null,
|
||||
error: 'ControlNet incompatible with currently-selected model',
|
||||
};
|
||||
}
|
||||
|
||||
const controlNetId = uuidv4();
|
||||
|
||||
let processorType = initialControlNet.processorType;
|
||||
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
||||
if (matchingControlNetModel.model_name.includes(modelSubstring)) {
|
||||
processorType =
|
||||
CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring] ||
|
||||
initialControlNet.processorType;
|
||||
break;
|
||||
}
|
||||
}
|
||||
const processorNode = CONTROLNET_PROCESSORS[processorType].default;
|
||||
|
||||
const controlnet: ControlNetConfig = {
|
||||
isEnabled: true,
|
||||
model: matchingControlNetModel,
|
||||
weight:
|
||||
typeof control_weight === 'number'
|
||||
? control_weight
|
||||
: initialControlNet.weight,
|
||||
beginStepPct: begin_step_percent || initialControlNet.beginStepPct,
|
||||
endStepPct: end_step_percent || initialControlNet.endStepPct,
|
||||
controlMode: control_mode || initialControlNet.controlMode,
|
||||
resizeMode: resize_mode || initialControlNet.resizeMode,
|
||||
controlImage: image?.image_name || null,
|
||||
processedControlImage: image?.image_name || null,
|
||||
processorType,
|
||||
processorNode:
|
||||
processorNode.type !== 'none'
|
||||
? processorNode
|
||||
: initialControlNet.processorNode,
|
||||
shouldAutoConfig: true,
|
||||
controlNetId,
|
||||
};
|
||||
|
||||
return { controlnet, error: null };
|
||||
},
|
||||
[controlnets, model?.base_model]
|
||||
);
|
||||
|
||||
const recallControlNet = useCallback(
|
||||
(controlnetMetadataItem: ControlNetMetadataItem) => {
|
||||
const result = prepareControlNetMetadataItem(controlnetMetadataItem);
|
||||
|
||||
if (!result.controlnet) {
|
||||
parameterNotSetToast(result.error);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
controlNetRecalled({
|
||||
...result.controlnet,
|
||||
})
|
||||
);
|
||||
|
||||
dispatch(controlNetEnabled());
|
||||
|
||||
parameterSetToast();
|
||||
},
|
||||
[
|
||||
prepareControlNetMetadataItem,
|
||||
dispatch,
|
||||
parameterSetToast,
|
||||
parameterNotSetToast,
|
||||
]
|
||||
);
|
||||
|
||||
/*
|
||||
* Sets image as initial image with toast
|
||||
*/
|
||||
@ -428,6 +562,7 @@ export const useRecallParameters = () => {
|
||||
refiner_negative_aesthetic_score,
|
||||
refiner_start,
|
||||
loras,
|
||||
controlnets,
|
||||
} = metadata;
|
||||
|
||||
if (isValidCfgScale(cfg_scale)) {
|
||||
@ -517,6 +652,15 @@ export const useRecallParameters = () => {
|
||||
}
|
||||
});
|
||||
|
||||
dispatch(controlNetReset());
|
||||
dispatch(controlNetEnabled());
|
||||
controlnets?.forEach((controlnet) => {
|
||||
const result = prepareControlNetMetadataItem(controlnet);
|
||||
if (result.controlnet) {
|
||||
dispatch(controlNetRecalled(result.controlnet));
|
||||
}
|
||||
});
|
||||
|
||||
allParameterSetToast();
|
||||
},
|
||||
[
|
||||
@ -524,6 +668,7 @@ export const useRecallParameters = () => {
|
||||
allParameterSetToast,
|
||||
dispatch,
|
||||
prepareLoRAMetadataItem,
|
||||
prepareControlNetMetadataItem,
|
||||
]
|
||||
);
|
||||
|
||||
@ -542,6 +687,7 @@ export const useRecallParameters = () => {
|
||||
recallHeight,
|
||||
recallStrength,
|
||||
recallLoRA,
|
||||
recallControlNet,
|
||||
recallAllParameters,
|
||||
sendToImageToImage,
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user