mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
use controlnet from metadata if available (#4658)
* add control net to useRecallParams * got recall controlnets working * fix metadata viewer controlnet * fix type errors * fix controlnet metadata viewer * set control image and use correct processor type and node * clean up logs * recall processor using substring * feat(ui): enable controlNet when recalling one --------- Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
parent
3432fd72f8
commit
4a0a1c30db
@ -98,6 +98,9 @@ export const controlNetSlice = createSlice({
|
|||||||
isControlNetEnabledToggled: (state) => {
|
isControlNetEnabledToggled: (state) => {
|
||||||
state.isEnabled = !state.isEnabled;
|
state.isEnabled = !state.isEnabled;
|
||||||
},
|
},
|
||||||
|
controlNetEnabled: (state) => {
|
||||||
|
state.isEnabled = true;
|
||||||
|
},
|
||||||
controlNetAdded: (
|
controlNetAdded: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<{
|
action: PayloadAction<{
|
||||||
@ -111,6 +114,12 @@ export const controlNetSlice = createSlice({
|
|||||||
controlNetId,
|
controlNetId,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
|
controlNetRecalled: (state, action: PayloadAction<ControlNetConfig>) => {
|
||||||
|
const controlNet = action.payload;
|
||||||
|
state.controlNets[controlNet.controlNetId] = {
|
||||||
|
...controlNet,
|
||||||
|
};
|
||||||
|
},
|
||||||
controlNetDuplicated: (
|
controlNetDuplicated: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<{
|
action: PayloadAction<{
|
||||||
@ -439,7 +448,9 @@ export const controlNetSlice = createSlice({
|
|||||||
|
|
||||||
export const {
|
export const {
|
||||||
isControlNetEnabledToggled,
|
isControlNetEnabledToggled,
|
||||||
|
controlNetEnabled,
|
||||||
controlNetAdded,
|
controlNetAdded,
|
||||||
|
controlNetRecalled,
|
||||||
controlNetDuplicated,
|
controlNetDuplicated,
|
||||||
controlNetAddedFromImage,
|
controlNetAddedFromImage,
|
||||||
controlNetRemoved,
|
controlNetRemoved,
|
||||||
|
@ -1,8 +1,15 @@
|
|||||||
import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types';
|
import {
|
||||||
|
ControlNetMetadataItem,
|
||||||
|
CoreMetadata,
|
||||||
|
LoRAMetadataItem,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useMemo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { isValidLoRAModel } from '../../../parameters/types/parameterSchemas';
|
import {
|
||||||
|
isValidControlNetModel,
|
||||||
|
isValidLoRAModel,
|
||||||
|
} from '../../../parameters/types/parameterSchemas';
|
||||||
import ImageMetadataItem from './ImageMetadataItem';
|
import ImageMetadataItem from './ImageMetadataItem';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
@ -26,6 +33,7 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
recallHeight,
|
recallHeight,
|
||||||
recallStrength,
|
recallStrength,
|
||||||
recallLoRA,
|
recallLoRA,
|
||||||
|
recallControlNet,
|
||||||
} = useRecallParameters();
|
} = useRecallParameters();
|
||||||
|
|
||||||
const handleRecallPositivePrompt = useCallback(() => {
|
const handleRecallPositivePrompt = useCallback(() => {
|
||||||
@ -75,6 +83,21 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
[recallLoRA]
|
[recallLoRA]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const handleRecallControlNet = useCallback(
|
||||||
|
(controlnet: ControlNetMetadataItem) => {
|
||||||
|
recallControlNet(controlnet);
|
||||||
|
},
|
||||||
|
[recallControlNet]
|
||||||
|
);
|
||||||
|
|
||||||
|
const validControlNets: ControlNetMetadataItem[] = useMemo(() => {
|
||||||
|
return metadata?.controlnets
|
||||||
|
? metadata.controlnets.filter((controlnet) =>
|
||||||
|
isValidControlNetModel(controlnet.control_model)
|
||||||
|
)
|
||||||
|
: [];
|
||||||
|
}, [metadata?.controlnets]);
|
||||||
|
|
||||||
if (!metadata || Object.keys(metadata).length === 0) {
|
if (!metadata || Object.keys(metadata).length === 0) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
@ -180,6 +203,14 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
})}
|
})}
|
||||||
|
{validControlNets.map((controlnet, index) => (
|
||||||
|
<ImageMetadataItem
|
||||||
|
key={index}
|
||||||
|
label="ControlNet"
|
||||||
|
value={`${controlnet.control_model?.model_name} - ${controlnet.control_weight}`}
|
||||||
|
onClick={() => handleRecallControlNet(controlnet)}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -1141,6 +1141,10 @@ const zLoRAMetadataItem = z.object({
|
|||||||
|
|
||||||
export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
|
export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
|
||||||
|
|
||||||
|
const zControlNetMetadataItem = zControlField.deepPartial();
|
||||||
|
|
||||||
|
export type ControlNetMetadataItem = z.infer<typeof zControlNetMetadataItem>;
|
||||||
|
|
||||||
export const zCoreMetadata = z
|
export const zCoreMetadata = z
|
||||||
.object({
|
.object({
|
||||||
app_version: z.string().nullish().catch(null),
|
app_version: z.string().nullish().catch(null),
|
||||||
|
@ -2,7 +2,11 @@ import { createSelector } from '@reduxjs/toolkit';
|
|||||||
import { useAppToaster } from 'app/components/Toaster';
|
import { useAppToaster } from 'app/components/Toaster';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types';
|
import {
|
||||||
|
CoreMetadata,
|
||||||
|
LoRAMetadataItem,
|
||||||
|
ControlNetMetadataItem,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
import {
|
import {
|
||||||
refinerModelChanged,
|
refinerModelChanged,
|
||||||
setNegativeStylePromptSDXL,
|
setNegativeStylePromptSDXL,
|
||||||
@ -18,9 +22,18 @@ import { useCallback } from 'react';
|
|||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
import {
|
import {
|
||||||
|
controlNetModelsAdapter,
|
||||||
loraModelsAdapter,
|
loraModelsAdapter,
|
||||||
|
useGetControlNetModelsQuery,
|
||||||
useGetLoRAModelsQuery,
|
useGetLoRAModelsQuery,
|
||||||
} from '../../../services/api/endpoints/models';
|
} from '../../../services/api/endpoints/models';
|
||||||
|
import {
|
||||||
|
ControlNetConfig,
|
||||||
|
controlNetEnabled,
|
||||||
|
controlNetRecalled,
|
||||||
|
controlNetReset,
|
||||||
|
initialControlNet,
|
||||||
|
} from '../../controlNet/store/controlNetSlice';
|
||||||
import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice';
|
import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice';
|
||||||
import { initialImageSelected, modelSelected } from '../store/actions';
|
import { initialImageSelected, modelSelected } from '../store/actions';
|
||||||
import {
|
import {
|
||||||
@ -38,6 +51,7 @@ import {
|
|||||||
isValidCfgScale,
|
isValidCfgScale,
|
||||||
isValidHeight,
|
isValidHeight,
|
||||||
isValidLoRAModel,
|
isValidLoRAModel,
|
||||||
|
isValidControlNetModel,
|
||||||
isValidMainModel,
|
isValidMainModel,
|
||||||
isValidNegativePrompt,
|
isValidNegativePrompt,
|
||||||
isValidPositivePrompt,
|
isValidPositivePrompt,
|
||||||
@ -53,6 +67,11 @@ import {
|
|||||||
isValidStrength,
|
isValidStrength,
|
||||||
isValidWidth,
|
isValidWidth,
|
||||||
} from '../types/parameterSchemas';
|
} from '../types/parameterSchemas';
|
||||||
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
import {
|
||||||
|
CONTROLNET_PROCESSORS,
|
||||||
|
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
|
||||||
|
} from 'features/controlNet/store/constants';
|
||||||
|
|
||||||
const selector = createSelector(stateSelector, ({ generation }) => {
|
const selector = createSelector(stateSelector, ({ generation }) => {
|
||||||
const { model } = generation;
|
const { model } = generation;
|
||||||
@ -390,6 +409,121 @@ export const useRecallParameters = () => {
|
|||||||
[prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
|
[prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Recall ControlNet with toast
|
||||||
|
*/
|
||||||
|
|
||||||
|
const { controlnets } = useGetControlNetModelsQuery(undefined, {
|
||||||
|
selectFromResult: (result) => ({
|
||||||
|
controlnets: result.data
|
||||||
|
? controlNetModelsAdapter.getSelectors().selectAll(result.data)
|
||||||
|
: [],
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
const prepareControlNetMetadataItem = useCallback(
|
||||||
|
(controlnetMetadataItem: ControlNetMetadataItem) => {
|
||||||
|
if (!isValidControlNetModel(controlnetMetadataItem.control_model)) {
|
||||||
|
return { controlnet: null, error: 'Invalid ControlNet model' };
|
||||||
|
}
|
||||||
|
|
||||||
|
const {
|
||||||
|
image,
|
||||||
|
control_model,
|
||||||
|
control_weight,
|
||||||
|
begin_step_percent,
|
||||||
|
end_step_percent,
|
||||||
|
control_mode,
|
||||||
|
resize_mode,
|
||||||
|
} = controlnetMetadataItem;
|
||||||
|
|
||||||
|
const matchingControlNetModel = controlnets.find(
|
||||||
|
(c) =>
|
||||||
|
c.base_model === control_model.base_model &&
|
||||||
|
c.model_name === control_model.model_name
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!matchingControlNetModel) {
|
||||||
|
return { controlnet: null, error: 'ControlNet model is not installed' };
|
||||||
|
}
|
||||||
|
|
||||||
|
const isCompatibleBaseModel =
|
||||||
|
matchingControlNetModel?.base_model === model?.base_model;
|
||||||
|
|
||||||
|
if (!isCompatibleBaseModel) {
|
||||||
|
return {
|
||||||
|
controlnet: null,
|
||||||
|
error: 'ControlNet incompatible with currently-selected model',
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const controlNetId = uuidv4();
|
||||||
|
|
||||||
|
let processorType = initialControlNet.processorType;
|
||||||
|
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
||||||
|
if (matchingControlNetModel.model_name.includes(modelSubstring)) {
|
||||||
|
processorType =
|
||||||
|
CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring] ||
|
||||||
|
initialControlNet.processorType;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const processorNode = CONTROLNET_PROCESSORS[processorType].default;
|
||||||
|
|
||||||
|
const controlnet: ControlNetConfig = {
|
||||||
|
isEnabled: true,
|
||||||
|
model: matchingControlNetModel,
|
||||||
|
weight:
|
||||||
|
typeof control_weight === 'number'
|
||||||
|
? control_weight
|
||||||
|
: initialControlNet.weight,
|
||||||
|
beginStepPct: begin_step_percent || initialControlNet.beginStepPct,
|
||||||
|
endStepPct: end_step_percent || initialControlNet.endStepPct,
|
||||||
|
controlMode: control_mode || initialControlNet.controlMode,
|
||||||
|
resizeMode: resize_mode || initialControlNet.resizeMode,
|
||||||
|
controlImage: image?.image_name || null,
|
||||||
|
processedControlImage: image?.image_name || null,
|
||||||
|
processorType,
|
||||||
|
processorNode:
|
||||||
|
processorNode.type !== 'none'
|
||||||
|
? processorNode
|
||||||
|
: initialControlNet.processorNode,
|
||||||
|
shouldAutoConfig: true,
|
||||||
|
controlNetId,
|
||||||
|
};
|
||||||
|
|
||||||
|
return { controlnet, error: null };
|
||||||
|
},
|
||||||
|
[controlnets, model?.base_model]
|
||||||
|
);
|
||||||
|
|
||||||
|
const recallControlNet = useCallback(
|
||||||
|
(controlnetMetadataItem: ControlNetMetadataItem) => {
|
||||||
|
const result = prepareControlNetMetadataItem(controlnetMetadataItem);
|
||||||
|
|
||||||
|
if (!result.controlnet) {
|
||||||
|
parameterNotSetToast(result.error);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
controlNetRecalled({
|
||||||
|
...result.controlnet,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
dispatch(controlNetEnabled());
|
||||||
|
|
||||||
|
parameterSetToast();
|
||||||
|
},
|
||||||
|
[
|
||||||
|
prepareControlNetMetadataItem,
|
||||||
|
dispatch,
|
||||||
|
parameterSetToast,
|
||||||
|
parameterNotSetToast,
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Sets image as initial image with toast
|
* Sets image as initial image with toast
|
||||||
*/
|
*/
|
||||||
@ -428,6 +562,7 @@ export const useRecallParameters = () => {
|
|||||||
refiner_negative_aesthetic_score,
|
refiner_negative_aesthetic_score,
|
||||||
refiner_start,
|
refiner_start,
|
||||||
loras,
|
loras,
|
||||||
|
controlnets,
|
||||||
} = metadata;
|
} = metadata;
|
||||||
|
|
||||||
if (isValidCfgScale(cfg_scale)) {
|
if (isValidCfgScale(cfg_scale)) {
|
||||||
@ -517,6 +652,15 @@ export const useRecallParameters = () => {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
dispatch(controlNetReset());
|
||||||
|
dispatch(controlNetEnabled());
|
||||||
|
controlnets?.forEach((controlnet) => {
|
||||||
|
const result = prepareControlNetMetadataItem(controlnet);
|
||||||
|
if (result.controlnet) {
|
||||||
|
dispatch(controlNetRecalled(result.controlnet));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
allParameterSetToast();
|
allParameterSetToast();
|
||||||
},
|
},
|
||||||
[
|
[
|
||||||
@ -524,6 +668,7 @@ export const useRecallParameters = () => {
|
|||||||
allParameterSetToast,
|
allParameterSetToast,
|
||||||
dispatch,
|
dispatch,
|
||||||
prepareLoRAMetadataItem,
|
prepareLoRAMetadataItem,
|
||||||
|
prepareControlNetMetadataItem,
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -542,6 +687,7 @@ export const useRecallParameters = () => {
|
|||||||
recallHeight,
|
recallHeight,
|
||||||
recallStrength,
|
recallStrength,
|
||||||
recallLoRA,
|
recallLoRA,
|
||||||
|
recallControlNet,
|
||||||
recallAllParameters,
|
recallAllParameters,
|
||||||
sendToImageToImage,
|
sendToImageToImage,
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user