diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx index f46c1fea4b..0626b08fd9 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx @@ -12,6 +12,7 @@ import { controlNetImageChanged, controlNetModelChanged, controlNetProcessedImageChanged, + controlNetProcessorChanged, controlNetRemoved, controlNetToggled, controlNetWeightChanged, @@ -61,17 +62,21 @@ const ControlNet = (props: ControlNetProps) => { controlImage, isControlImageProcessed, processedControlImage, + processor, } = props.controlNet; const dispatch = useAppDispatch(); - const [processorType, setProcessorType] = - useState('canny'); - - const handleProcessorTypeChanged = (type: string | null | undefined) => { - setProcessorType(type as ControlNetProcessor); - }; - - const { isOpen, onToggle } = useDisclosure(); + const handleProcessorTypeChanged = useCallback( + (processor: string | null | undefined) => { + dispatch( + controlNetProcessorChanged({ + controlNetId, + processor: processor as ControlNetProcessor, + }) + ); + }, + [controlNetId, dispatch] + ); const handleControlImageChanged = useCallback( (controlImage: ImageDTO) => { @@ -88,18 +93,6 @@ const ControlNet = (props: ControlNetProps) => { dispatch(controlNetRemoved(controlNetId)); }, [controlNetId, dispatch]); - const handleProcessedControlImageChanged = useCallback( - (processedControlImage: ImageDTO | null) => { - dispatch( - controlNetProcessedImageChanged({ - controlNetId, - processedControlImage, - }) - ); - }, - [controlNetId, dispatch] - ); - return ( { diff --git a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts index dbb45c25f1..a87b591bad 100644 --- a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts +++ b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts @@ -46,18 +46,20 @@ export const initialControlNet: Omit = { controlImage: null, isControlImageProcessed: false, processedControlImage: null, + processor: 'canny', }; export type ControlNet = { controlNetId: string; isEnabled: boolean; - model: string; + model: ControlNetModel; weight: number; beginStepPct: number; endStepPct: number; controlImage: ImageDTO | null; isControlImageProcessed: boolean; processedControlImage: ImageDTO | null; + processor: ControlNetProcessor; }; export type ControlNetState = { @@ -167,6 +169,16 @@ export const controlNetSlice = createSlice({ const { controlNetId, endStepPct } = action.payload; state.controlNets[controlNetId].endStepPct = endStepPct; }, + controlNetProcessorChanged: ( + state, + action: PayloadAction<{ + controlNetId: string; + processor: ControlNetProcessor; + }> + ) => { + const { controlNetId, processor } = action.payload; + state.controlNets[controlNetId].processor = processor; + }, }, }); @@ -183,6 +195,7 @@ export const { controlNetWeightChanged, controlNetBeginStepPctChanged, controlNetEndStepPctChanged, + controlNetProcessorChanged, } = controlNetSlice.actions; export default controlNetSlice.reducer; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildTextToImageGraph.ts index 65c205f9a4..d52310abdd 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildTextToImageGraph.ts @@ -1,6 +1,8 @@ import { RootState } from 'app/store/store'; import { + CollectInvocation, CompelInvocation, + ControlNetInvocation, Graph, IterateInvocation, LatentsToImageInvocation, @@ -10,6 +12,9 @@ import { TextToLatentsInvocation, } from 'services/api'; import { NonNullableGraph } from 'features/nodes/types/types'; +import { forEach, map, size } from 'lodash-es'; +import { ControlNetProcessorNode } from 'features/controlNet/store/types'; +import { ControlNetModel } from 'features/controlNet/store/controlNetSlice'; const POSITIVE_CONDITIONING = 'positive_conditioning'; const NEGATIVE_CONDITIONING = 'negative_conditioning'; @@ -19,7 +24,7 @@ const NOISE = 'noise'; const RANDOM_INT = 'rand_int'; const RANGE_OF_SIZE = 'range_of_size'; const ITERATE = 'iterate'; -const CONTROL_NET = 'control_net'; +const CONTROL_NET_COLLECT = 'control_net_collect'; /** * Builds the Text to Image tab graph. @@ -39,6 +44,8 @@ export const buildTextToImageGraph = (state: RootState): Graph => { shouldRandomizeSeed, } = state.generation; + const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet; + const graph: NonNullableGraph = { nodes: {}, edges: [], @@ -309,5 +316,86 @@ export const buildTextToImageGraph = (state: RootState): Graph => { }, }); } + + // Add ControlNet + if (isControlNetEnabled) { + if (size(controlNets) > 1) { + const controlNetIterateNode: CollectInvocation = { + id: CONTROL_NET_COLLECT, + type: 'collect', + }; + graph.nodes[controlNetIterateNode.id] = controlNetIterateNode; + graph.edges.push({ + source: { node_id: controlNetIterateNode.id, field: 'collection' }, + destination: { + node_id: TEXT_TO_LATENTS, + field: 'control', + }, + }); + } + + forEach(controlNets, (controlNet, index) => { + const { + controlNetId, + isEnabled, + isControlImageProcessed, + controlImage, + processedControlImage, + beginStepPct, + endStepPct, + model, + processor, + weight, + } = controlNet; + + const controlNetNode: ControlNetInvocation = { + id: `control_net_${controlNetId}`, + type: 'controlnet', + begin_step_percent: beginStepPct, + end_step_percent: endStepPct, + control_model: model as ControlNetInvocation['control_model'], + control_weight: weight, + }; + + if (processedControlImage) { + // We've already processed the image in the app, so we can just use the processed image + const { image_name, image_origin } = processedControlImage; + controlNetNode.image = { + image_name, + image_origin, + }; + } else if (controlImage) { + // The control image is preprocessed + const { image_name, image_origin } = controlImage; + controlNetNode.image = { + image_name, + image_origin, + }; + } else { + // The control image is not processed, so we need to add a preprocess node + // TODO: Add preprocess node + } + graph.nodes[controlNetNode.id] = controlNetNode; + + if (size(controlNets) > 1) { + graph.edges.push({ + source: { node_id: controlNetNode.id, field: 'control' }, + destination: { + node_id: CONTROL_NET_COLLECT, + field: 'item', + }, + }); + } else { + graph.edges.push({ + source: { node_id: controlNetNode.id, field: 'control' }, + destination: { + node_id: TEXT_TO_LATENTS, + field: 'control', + }, + }); + } + }); + } + return graph; };