fix(ui): fix multiple controlnets

This commit is contained in:
psychedelicious 2023-06-02 00:45:07 +10:00
parent b17f4c1650
commit 6896e69e95
3 changed files with 118 additions and 24 deletions

View File

@ -12,6 +12,7 @@ import {
controlNetImageChanged,
controlNetModelChanged,
controlNetProcessedImageChanged,
controlNetProcessorChanged,
controlNetRemoved,
controlNetToggled,
controlNetWeightChanged,
@ -61,17 +62,21 @@ const ControlNet = (props: ControlNetProps) => {
controlImage,
isControlImageProcessed,
processedControlImage,
processor,
} = props.controlNet;
const dispatch = useAppDispatch();
const [processorType, setProcessorType] =
useState<ControlNetProcessor>('canny');
const handleProcessorTypeChanged = (type: string | null | undefined) => {
setProcessorType(type as ControlNetProcessor);
};
const { isOpen, onToggle } = useDisclosure();
const handleProcessorTypeChanged = useCallback(
(processor: string | null | undefined) => {
dispatch(
controlNetProcessorChanged({
controlNetId,
processor: processor as ControlNetProcessor,
})
);
},
[controlNetId, dispatch]
);
const handleControlImageChanged = useCallback(
(controlImage: ImageDTO) => {
@ -88,18 +93,6 @@ const ControlNet = (props: ControlNetProps) => {
dispatch(controlNetRemoved(controlNetId));
}, [controlNetId, dispatch]);
const handleProcessedControlImageChanged = useCallback(
(processedControlImage: ImageDTO | null) => {
dispatch(
controlNetProcessedImageChanged({
controlNetId,
processedControlImage,
})
);
},
[controlNetId, dispatch]
);
return (
<Flex sx={{ flexDir: 'column', gap: 3 }}>
<IAISelectableImage
@ -147,14 +140,14 @@ const ControlNet = (props: ControlNetProps) => {
<IAICustomSelect
label="Processor"
items={CONTROLNET_PROCESSORS}
selectedItem={processorType}
selectedItem={processor}
setSelectedItem={handleProcessorTypeChanged}
/>
<ProcessorComponent
controlNetId={controlNetId}
controlImage={controlImage}
processedControlImage={processedControlImage}
type={processorType}
type={processor}
/>
</TabPanel>
</TabPanels>

View File

@ -46,18 +46,20 @@ export const initialControlNet: Omit<ControlNet, 'controlNetId'> = {
controlImage: null,
isControlImageProcessed: false,
processedControlImage: null,
processor: 'canny',
};
export type ControlNet = {
controlNetId: string;
isEnabled: boolean;
model: string;
model: ControlNetModel;
weight: number;
beginStepPct: number;
endStepPct: number;
controlImage: ImageDTO | null;
isControlImageProcessed: boolean;
processedControlImage: ImageDTO | null;
processor: ControlNetProcessor;
};
export type ControlNetState = {
@ -167,6 +169,16 @@ export const controlNetSlice = createSlice({
const { controlNetId, endStepPct } = action.payload;
state.controlNets[controlNetId].endStepPct = endStepPct;
},
controlNetProcessorChanged: (
state,
action: PayloadAction<{
controlNetId: string;
processor: ControlNetProcessor;
}>
) => {
const { controlNetId, processor } = action.payload;
state.controlNets[controlNetId].processor = processor;
},
},
});
@ -183,6 +195,7 @@ export const {
controlNetWeightChanged,
controlNetBeginStepPctChanged,
controlNetEndStepPctChanged,
controlNetProcessorChanged,
} = controlNetSlice.actions;
export default controlNetSlice.reducer;

View File

@ -1,6 +1,8 @@
import { RootState } from 'app/store/store';
import {
CollectInvocation,
CompelInvocation,
ControlNetInvocation,
Graph,
IterateInvocation,
LatentsToImageInvocation,
@ -10,6 +12,9 @@ import {
TextToLatentsInvocation,
} from 'services/api';
import { NonNullableGraph } from 'features/nodes/types/types';
import { forEach, map, size } from 'lodash-es';
import { ControlNetProcessorNode } from 'features/controlNet/store/types';
import { ControlNetModel } from 'features/controlNet/store/controlNetSlice';
const POSITIVE_CONDITIONING = 'positive_conditioning';
const NEGATIVE_CONDITIONING = 'negative_conditioning';
@ -19,7 +24,7 @@ const NOISE = 'noise';
const RANDOM_INT = 'rand_int';
const RANGE_OF_SIZE = 'range_of_size';
const ITERATE = 'iterate';
const CONTROL_NET = 'control_net';
const CONTROL_NET_COLLECT = 'control_net_collect';
/**
* Builds the Text to Image tab graph.
@ -39,6 +44,8 @@ export const buildTextToImageGraph = (state: RootState): Graph => {
shouldRandomizeSeed,
} = state.generation;
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
const graph: NonNullableGraph = {
nodes: {},
edges: [],
@ -309,5 +316,86 @@ export const buildTextToImageGraph = (state: RootState): Graph => {
},
});
}
// Add ControlNet
if (isControlNetEnabled) {
if (size(controlNets) > 1) {
const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT,
type: 'collect',
};
graph.nodes[controlNetIterateNode.id] = controlNetIterateNode;
graph.edges.push({
source: { node_id: controlNetIterateNode.id, field: 'collection' },
destination: {
node_id: TEXT_TO_LATENTS,
field: 'control',
},
});
}
forEach(controlNets, (controlNet, index) => {
const {
controlNetId,
isEnabled,
isControlImageProcessed,
controlImage,
processedControlImage,
beginStepPct,
endStepPct,
model,
processor,
weight,
} = controlNet;
const controlNetNode: ControlNetInvocation = {
id: `control_net_${controlNetId}`,
type: 'controlnet',
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
control_model: model as ControlNetInvocation['control_model'],
control_weight: weight,
};
if (processedControlImage) {
// We've already processed the image in the app, so we can just use the processed image
const { image_name, image_origin } = processedControlImage;
controlNetNode.image = {
image_name,
image_origin,
};
} else if (controlImage) {
// The control image is preprocessed
const { image_name, image_origin } = controlImage;
controlNetNode.image = {
image_name,
image_origin,
};
} else {
// The control image is not processed, so we need to add a preprocess node
// TODO: Add preprocess node
}
graph.nodes[controlNetNode.id] = controlNetNode;
if (size(controlNets) > 1) {
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: CONTROL_NET_COLLECT,
field: 'item',
},
});
} else {
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: TEXT_TO_LATENTS,
field: 'control',
},
});
}
});
}
return graph;
};