mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): fix multiple controlnets
This commit is contained in:
parent
b17f4c1650
commit
6896e69e95
@ -12,6 +12,7 @@ import {
|
||||
controlNetImageChanged,
|
||||
controlNetModelChanged,
|
||||
controlNetProcessedImageChanged,
|
||||
controlNetProcessorChanged,
|
||||
controlNetRemoved,
|
||||
controlNetToggled,
|
||||
controlNetWeightChanged,
|
||||
@ -61,17 +62,21 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
controlImage,
|
||||
isControlImageProcessed,
|
||||
processedControlImage,
|
||||
processor,
|
||||
} = props.controlNet;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [processorType, setProcessorType] =
|
||||
useState<ControlNetProcessor>('canny');
|
||||
|
||||
const handleProcessorTypeChanged = (type: string | null | undefined) => {
|
||||
setProcessorType(type as ControlNetProcessor);
|
||||
};
|
||||
|
||||
const { isOpen, onToggle } = useDisclosure();
|
||||
const handleProcessorTypeChanged = useCallback(
|
||||
(processor: string | null | undefined) => {
|
||||
dispatch(
|
||||
controlNetProcessorChanged({
|
||||
controlNetId,
|
||||
processor: processor as ControlNetProcessor,
|
||||
})
|
||||
);
|
||||
},
|
||||
[controlNetId, dispatch]
|
||||
);
|
||||
|
||||
const handleControlImageChanged = useCallback(
|
||||
(controlImage: ImageDTO) => {
|
||||
@ -88,18 +93,6 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
dispatch(controlNetRemoved(controlNetId));
|
||||
}, [controlNetId, dispatch]);
|
||||
|
||||
const handleProcessedControlImageChanged = useCallback(
|
||||
(processedControlImage: ImageDTO | null) => {
|
||||
dispatch(
|
||||
controlNetProcessedImageChanged({
|
||||
controlNetId,
|
||||
processedControlImage,
|
||||
})
|
||||
);
|
||||
},
|
||||
[controlNetId, dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex sx={{ flexDir: 'column', gap: 3 }}>
|
||||
<IAISelectableImage
|
||||
@ -147,14 +140,14 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
<IAICustomSelect
|
||||
label="Processor"
|
||||
items={CONTROLNET_PROCESSORS}
|
||||
selectedItem={processorType}
|
||||
selectedItem={processor}
|
||||
setSelectedItem={handleProcessorTypeChanged}
|
||||
/>
|
||||
<ProcessorComponent
|
||||
controlNetId={controlNetId}
|
||||
controlImage={controlImage}
|
||||
processedControlImage={processedControlImage}
|
||||
type={processorType}
|
||||
type={processor}
|
||||
/>
|
||||
</TabPanel>
|
||||
</TabPanels>
|
||||
|
@ -46,18 +46,20 @@ export const initialControlNet: Omit<ControlNet, 'controlNetId'> = {
|
||||
controlImage: null,
|
||||
isControlImageProcessed: false,
|
||||
processedControlImage: null,
|
||||
processor: 'canny',
|
||||
};
|
||||
|
||||
export type ControlNet = {
|
||||
controlNetId: string;
|
||||
isEnabled: boolean;
|
||||
model: string;
|
||||
model: ControlNetModel;
|
||||
weight: number;
|
||||
beginStepPct: number;
|
||||
endStepPct: number;
|
||||
controlImage: ImageDTO | null;
|
||||
isControlImageProcessed: boolean;
|
||||
processedControlImage: ImageDTO | null;
|
||||
processor: ControlNetProcessor;
|
||||
};
|
||||
|
||||
export type ControlNetState = {
|
||||
@ -167,6 +169,16 @@ export const controlNetSlice = createSlice({
|
||||
const { controlNetId, endStepPct } = action.payload;
|
||||
state.controlNets[controlNetId].endStepPct = endStepPct;
|
||||
},
|
||||
controlNetProcessorChanged: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
controlNetId: string;
|
||||
processor: ControlNetProcessor;
|
||||
}>
|
||||
) => {
|
||||
const { controlNetId, processor } = action.payload;
|
||||
state.controlNets[controlNetId].processor = processor;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
@ -183,6 +195,7 @@ export const {
|
||||
controlNetWeightChanged,
|
||||
controlNetBeginStepPctChanged,
|
||||
controlNetEndStepPctChanged,
|
||||
controlNetProcessorChanged,
|
||||
} = controlNetSlice.actions;
|
||||
|
||||
export default controlNetSlice.reducer;
|
||||
|
@ -1,6 +1,8 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
CollectInvocation,
|
||||
CompelInvocation,
|
||||
ControlNetInvocation,
|
||||
Graph,
|
||||
IterateInvocation,
|
||||
LatentsToImageInvocation,
|
||||
@ -10,6 +12,9 @@ import {
|
||||
TextToLatentsInvocation,
|
||||
} from 'services/api';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { forEach, map, size } from 'lodash-es';
|
||||
import { ControlNetProcessorNode } from 'features/controlNet/store/types';
|
||||
import { ControlNetModel } from 'features/controlNet/store/controlNetSlice';
|
||||
|
||||
const POSITIVE_CONDITIONING = 'positive_conditioning';
|
||||
const NEGATIVE_CONDITIONING = 'negative_conditioning';
|
||||
@ -19,7 +24,7 @@ const NOISE = 'noise';
|
||||
const RANDOM_INT = 'rand_int';
|
||||
const RANGE_OF_SIZE = 'range_of_size';
|
||||
const ITERATE = 'iterate';
|
||||
const CONTROL_NET = 'control_net';
|
||||
const CONTROL_NET_COLLECT = 'control_net_collect';
|
||||
|
||||
/**
|
||||
* Builds the Text to Image tab graph.
|
||||
@ -39,6 +44,8 @@ export const buildTextToImageGraph = (state: RootState): Graph => {
|
||||
shouldRandomizeSeed,
|
||||
} = state.generation;
|
||||
|
||||
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
|
||||
|
||||
const graph: NonNullableGraph = {
|
||||
nodes: {},
|
||||
edges: [],
|
||||
@ -309,5 +316,86 @@ export const buildTextToImageGraph = (state: RootState): Graph => {
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Add ControlNet
|
||||
if (isControlNetEnabled) {
|
||||
if (size(controlNets) > 1) {
|
||||
const controlNetIterateNode: CollectInvocation = {
|
||||
id: CONTROL_NET_COLLECT,
|
||||
type: 'collect',
|
||||
};
|
||||
graph.nodes[controlNetIterateNode.id] = controlNetIterateNode;
|
||||
graph.edges.push({
|
||||
source: { node_id: controlNetIterateNode.id, field: 'collection' },
|
||||
destination: {
|
||||
node_id: TEXT_TO_LATENTS,
|
||||
field: 'control',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
forEach(controlNets, (controlNet, index) => {
|
||||
const {
|
||||
controlNetId,
|
||||
isEnabled,
|
||||
isControlImageProcessed,
|
||||
controlImage,
|
||||
processedControlImage,
|
||||
beginStepPct,
|
||||
endStepPct,
|
||||
model,
|
||||
processor,
|
||||
weight,
|
||||
} = controlNet;
|
||||
|
||||
const controlNetNode: ControlNetInvocation = {
|
||||
id: `control_net_${controlNetId}`,
|
||||
type: 'controlnet',
|
||||
begin_step_percent: beginStepPct,
|
||||
end_step_percent: endStepPct,
|
||||
control_model: model as ControlNetInvocation['control_model'],
|
||||
control_weight: weight,
|
||||
};
|
||||
|
||||
if (processedControlImage) {
|
||||
// We've already processed the image in the app, so we can just use the processed image
|
||||
const { image_name, image_origin } = processedControlImage;
|
||||
controlNetNode.image = {
|
||||
image_name,
|
||||
image_origin,
|
||||
};
|
||||
} else if (controlImage) {
|
||||
// The control image is preprocessed
|
||||
const { image_name, image_origin } = controlImage;
|
||||
controlNetNode.image = {
|
||||
image_name,
|
||||
image_origin,
|
||||
};
|
||||
} else {
|
||||
// The control image is not processed, so we need to add a preprocess node
|
||||
// TODO: Add preprocess node
|
||||
}
|
||||
graph.nodes[controlNetNode.id] = controlNetNode;
|
||||
|
||||
if (size(controlNets) > 1) {
|
||||
graph.edges.push({
|
||||
source: { node_id: controlNetNode.id, field: 'control' },
|
||||
destination: {
|
||||
node_id: CONTROL_NET_COLLECT,
|
||||
field: 'item',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
graph.edges.push({
|
||||
source: { node_id: controlNetNode.id, field: 'control' },
|
||||
destination: {
|
||||
node_id: TEXT_TO_LATENTS,
|
||||
field: 'control',
|
||||
},
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return graph;
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user