mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): fix multiple controlnets
This commit is contained in:
parent
b17f4c1650
commit
6896e69e95
@ -12,6 +12,7 @@ import {
|
|||||||
controlNetImageChanged,
|
controlNetImageChanged,
|
||||||
controlNetModelChanged,
|
controlNetModelChanged,
|
||||||
controlNetProcessedImageChanged,
|
controlNetProcessedImageChanged,
|
||||||
|
controlNetProcessorChanged,
|
||||||
controlNetRemoved,
|
controlNetRemoved,
|
||||||
controlNetToggled,
|
controlNetToggled,
|
||||||
controlNetWeightChanged,
|
controlNetWeightChanged,
|
||||||
@ -61,17 +62,21 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
controlImage,
|
controlImage,
|
||||||
isControlImageProcessed,
|
isControlImageProcessed,
|
||||||
processedControlImage,
|
processedControlImage,
|
||||||
|
processor,
|
||||||
} = props.controlNet;
|
} = props.controlNet;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const [processorType, setProcessorType] =
|
const handleProcessorTypeChanged = useCallback(
|
||||||
useState<ControlNetProcessor>('canny');
|
(processor: string | null | undefined) => {
|
||||||
|
dispatch(
|
||||||
const handleProcessorTypeChanged = (type: string | null | undefined) => {
|
controlNetProcessorChanged({
|
||||||
setProcessorType(type as ControlNetProcessor);
|
controlNetId,
|
||||||
};
|
processor: processor as ControlNetProcessor,
|
||||||
|
})
|
||||||
const { isOpen, onToggle } = useDisclosure();
|
);
|
||||||
|
},
|
||||||
|
[controlNetId, dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
const handleControlImageChanged = useCallback(
|
const handleControlImageChanged = useCallback(
|
||||||
(controlImage: ImageDTO) => {
|
(controlImage: ImageDTO) => {
|
||||||
@ -88,18 +93,6 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
dispatch(controlNetRemoved(controlNetId));
|
dispatch(controlNetRemoved(controlNetId));
|
||||||
}, [controlNetId, dispatch]);
|
}, [controlNetId, dispatch]);
|
||||||
|
|
||||||
const handleProcessedControlImageChanged = useCallback(
|
|
||||||
(processedControlImage: ImageDTO | null) => {
|
|
||||||
dispatch(
|
|
||||||
controlNetProcessedImageChanged({
|
|
||||||
controlNetId,
|
|
||||||
processedControlImage,
|
|
||||||
})
|
|
||||||
);
|
|
||||||
},
|
|
||||||
[controlNetId, dispatch]
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex sx={{ flexDir: 'column', gap: 3 }}>
|
<Flex sx={{ flexDir: 'column', gap: 3 }}>
|
||||||
<IAISelectableImage
|
<IAISelectableImage
|
||||||
@ -147,14 +140,14 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
<IAICustomSelect
|
<IAICustomSelect
|
||||||
label="Processor"
|
label="Processor"
|
||||||
items={CONTROLNET_PROCESSORS}
|
items={CONTROLNET_PROCESSORS}
|
||||||
selectedItem={processorType}
|
selectedItem={processor}
|
||||||
setSelectedItem={handleProcessorTypeChanged}
|
setSelectedItem={handleProcessorTypeChanged}
|
||||||
/>
|
/>
|
||||||
<ProcessorComponent
|
<ProcessorComponent
|
||||||
controlNetId={controlNetId}
|
controlNetId={controlNetId}
|
||||||
controlImage={controlImage}
|
controlImage={controlImage}
|
||||||
processedControlImage={processedControlImage}
|
processedControlImage={processedControlImage}
|
||||||
type={processorType}
|
type={processor}
|
||||||
/>
|
/>
|
||||||
</TabPanel>
|
</TabPanel>
|
||||||
</TabPanels>
|
</TabPanels>
|
||||||
|
@ -46,18 +46,20 @@ export const initialControlNet: Omit<ControlNet, 'controlNetId'> = {
|
|||||||
controlImage: null,
|
controlImage: null,
|
||||||
isControlImageProcessed: false,
|
isControlImageProcessed: false,
|
||||||
processedControlImage: null,
|
processedControlImage: null,
|
||||||
|
processor: 'canny',
|
||||||
};
|
};
|
||||||
|
|
||||||
export type ControlNet = {
|
export type ControlNet = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
isEnabled: boolean;
|
isEnabled: boolean;
|
||||||
model: string;
|
model: ControlNetModel;
|
||||||
weight: number;
|
weight: number;
|
||||||
beginStepPct: number;
|
beginStepPct: number;
|
||||||
endStepPct: number;
|
endStepPct: number;
|
||||||
controlImage: ImageDTO | null;
|
controlImage: ImageDTO | null;
|
||||||
isControlImageProcessed: boolean;
|
isControlImageProcessed: boolean;
|
||||||
processedControlImage: ImageDTO | null;
|
processedControlImage: ImageDTO | null;
|
||||||
|
processor: ControlNetProcessor;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type ControlNetState = {
|
export type ControlNetState = {
|
||||||
@ -167,6 +169,16 @@ export const controlNetSlice = createSlice({
|
|||||||
const { controlNetId, endStepPct } = action.payload;
|
const { controlNetId, endStepPct } = action.payload;
|
||||||
state.controlNets[controlNetId].endStepPct = endStepPct;
|
state.controlNets[controlNetId].endStepPct = endStepPct;
|
||||||
},
|
},
|
||||||
|
controlNetProcessorChanged: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<{
|
||||||
|
controlNetId: string;
|
||||||
|
processor: ControlNetProcessor;
|
||||||
|
}>
|
||||||
|
) => {
|
||||||
|
const { controlNetId, processor } = action.payload;
|
||||||
|
state.controlNets[controlNetId].processor = processor;
|
||||||
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -183,6 +195,7 @@ export const {
|
|||||||
controlNetWeightChanged,
|
controlNetWeightChanged,
|
||||||
controlNetBeginStepPctChanged,
|
controlNetBeginStepPctChanged,
|
||||||
controlNetEndStepPctChanged,
|
controlNetEndStepPctChanged,
|
||||||
|
controlNetProcessorChanged,
|
||||||
} = controlNetSlice.actions;
|
} = controlNetSlice.actions;
|
||||||
|
|
||||||
export default controlNetSlice.reducer;
|
export default controlNetSlice.reducer;
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import {
|
import {
|
||||||
|
CollectInvocation,
|
||||||
CompelInvocation,
|
CompelInvocation,
|
||||||
|
ControlNetInvocation,
|
||||||
Graph,
|
Graph,
|
||||||
IterateInvocation,
|
IterateInvocation,
|
||||||
LatentsToImageInvocation,
|
LatentsToImageInvocation,
|
||||||
@ -10,6 +12,9 @@ import {
|
|||||||
TextToLatentsInvocation,
|
TextToLatentsInvocation,
|
||||||
} from 'services/api';
|
} from 'services/api';
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
|
import { forEach, map, size } from 'lodash-es';
|
||||||
|
import { ControlNetProcessorNode } from 'features/controlNet/store/types';
|
||||||
|
import { ControlNetModel } from 'features/controlNet/store/controlNetSlice';
|
||||||
|
|
||||||
const POSITIVE_CONDITIONING = 'positive_conditioning';
|
const POSITIVE_CONDITIONING = 'positive_conditioning';
|
||||||
const NEGATIVE_CONDITIONING = 'negative_conditioning';
|
const NEGATIVE_CONDITIONING = 'negative_conditioning';
|
||||||
@ -19,7 +24,7 @@ const NOISE = 'noise';
|
|||||||
const RANDOM_INT = 'rand_int';
|
const RANDOM_INT = 'rand_int';
|
||||||
const RANGE_OF_SIZE = 'range_of_size';
|
const RANGE_OF_SIZE = 'range_of_size';
|
||||||
const ITERATE = 'iterate';
|
const ITERATE = 'iterate';
|
||||||
const CONTROL_NET = 'control_net';
|
const CONTROL_NET_COLLECT = 'control_net_collect';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds the Text to Image tab graph.
|
* Builds the Text to Image tab graph.
|
||||||
@ -39,6 +44,8 @@ export const buildTextToImageGraph = (state: RootState): Graph => {
|
|||||||
shouldRandomizeSeed,
|
shouldRandomizeSeed,
|
||||||
} = state.generation;
|
} = state.generation;
|
||||||
|
|
||||||
|
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
|
||||||
|
|
||||||
const graph: NonNullableGraph = {
|
const graph: NonNullableGraph = {
|
||||||
nodes: {},
|
nodes: {},
|
||||||
edges: [],
|
edges: [],
|
||||||
@ -309,5 +316,86 @@ export const buildTextToImageGraph = (state: RootState): Graph => {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add ControlNet
|
||||||
|
if (isControlNetEnabled) {
|
||||||
|
if (size(controlNets) > 1) {
|
||||||
|
const controlNetIterateNode: CollectInvocation = {
|
||||||
|
id: CONTROL_NET_COLLECT,
|
||||||
|
type: 'collect',
|
||||||
|
};
|
||||||
|
graph.nodes[controlNetIterateNode.id] = controlNetIterateNode;
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: controlNetIterateNode.id, field: 'collection' },
|
||||||
|
destination: {
|
||||||
|
node_id: TEXT_TO_LATENTS,
|
||||||
|
field: 'control',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
forEach(controlNets, (controlNet, index) => {
|
||||||
|
const {
|
||||||
|
controlNetId,
|
||||||
|
isEnabled,
|
||||||
|
isControlImageProcessed,
|
||||||
|
controlImage,
|
||||||
|
processedControlImage,
|
||||||
|
beginStepPct,
|
||||||
|
endStepPct,
|
||||||
|
model,
|
||||||
|
processor,
|
||||||
|
weight,
|
||||||
|
} = controlNet;
|
||||||
|
|
||||||
|
const controlNetNode: ControlNetInvocation = {
|
||||||
|
id: `control_net_${controlNetId}`,
|
||||||
|
type: 'controlnet',
|
||||||
|
begin_step_percent: beginStepPct,
|
||||||
|
end_step_percent: endStepPct,
|
||||||
|
control_model: model as ControlNetInvocation['control_model'],
|
||||||
|
control_weight: weight,
|
||||||
|
};
|
||||||
|
|
||||||
|
if (processedControlImage) {
|
||||||
|
// We've already processed the image in the app, so we can just use the processed image
|
||||||
|
const { image_name, image_origin } = processedControlImage;
|
||||||
|
controlNetNode.image = {
|
||||||
|
image_name,
|
||||||
|
image_origin,
|
||||||
|
};
|
||||||
|
} else if (controlImage) {
|
||||||
|
// The control image is preprocessed
|
||||||
|
const { image_name, image_origin } = controlImage;
|
||||||
|
controlNetNode.image = {
|
||||||
|
image_name,
|
||||||
|
image_origin,
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
// The control image is not processed, so we need to add a preprocess node
|
||||||
|
// TODO: Add preprocess node
|
||||||
|
}
|
||||||
|
graph.nodes[controlNetNode.id] = controlNetNode;
|
||||||
|
|
||||||
|
if (size(controlNets) > 1) {
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: controlNetNode.id, field: 'control' },
|
||||||
|
destination: {
|
||||||
|
node_id: CONTROL_NET_COLLECT,
|
||||||
|
field: 'item',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: controlNetNode.id, field: 'control' },
|
||||||
|
destination: {
|
||||||
|
node_id: TEXT_TO_LATENTS,
|
||||||
|
field: 'control',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
return graph;
|
return graph;
|
||||||
};
|
};
|
||||||
|
Loading…
x
Reference in New Issue
Block a user