feat(ui): refactor control adapters

Control adapters logic/state/ui is now generalized to hold controlnet, ip_adapter and t2i_adapter. In the future, other control adapter types can be added.

TODO:
- Limit IP adapter to 1
- Add T2I adapter to linear graphs
- Fix autoprocess
- T2I metadata saving & recall
- Improve on control adapters UI
This commit is contained in:
psychedelicious
2023-10-05 22:40:21 +11:00
parent 9c720da021
commit 9508e0c9db
70 changed files with 1860 additions and 1236 deletions

View File

@ -1,5 +1,5 @@
import { RootState } from 'app/store/store';
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
import { selectValidControlNets } from 'features/controlNet/store/controlAdaptersSlice';
import { omit } from 'lodash-es';
import {
CollectInvocation,
@ -19,102 +19,101 @@ export const addControlNetToLinearGraph = (
graph: NonNullableGraph,
baseNodeId: string
): void => {
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
const validControlNets = getValidControlNets(controlNets);
const validControlNets = selectValidControlNets(state.controlAdapters);
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (isControlNetEnabled && Boolean(validControlNets.length)) {
if (validControlNets.length) {
// We have multiple controlnets, add ControlNet collector
const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT,
type: 'collect',
if (validControlNets.length) {
// Even though denoise_latents' control input is polymorphic, keep it simple and always use a collect
const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT,
type: 'collect',
is_intermediate: true,
};
graph.nodes[CONTROL_NET_COLLECT] = controlNetIterateNode;
graph.edges.push({
source: { node_id: CONTROL_NET_COLLECT, field: 'collection' },
destination: {
node_id: baseNodeId,
field: 'control',
},
});
validControlNets.forEach((controlNet) => {
if (!controlNet.model) {
return;
}
const {
id,
controlImage,
processedControlImage,
beginStepPct,
endStepPct,
controlMode,
resizeMode,
model,
processorType,
weight,
} = controlNet;
const controlNetNode: ControlNetInvocation = {
id: `control_net_${id}`,
type: 'controlnet',
is_intermediate: true,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
control_mode: controlMode,
resize_mode: resizeMode,
control_model: model,
control_weight: weight,
};
graph.nodes[CONTROL_NET_COLLECT] = controlNetIterateNode;
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
controlNetNode.image = {
image_name: processedControlImage,
};
} else if (controlImage) {
// The control image is preprocessed
controlNetNode.image = {
image_name: controlImage,
};
} else {
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
return;
}
graph.nodes[controlNetNode.id] = controlNetNode as ControlNetInvocation;
if (metadataAccumulator?.controlnets) {
// metadata accumulator only needs a control field - not the whole node
// extract what we need and add to the accumulator
const controlField = omit(controlNetNode, [
'id',
'type',
]) as ControlField;
metadataAccumulator.controlnets.push(controlField);
}
graph.edges.push({
source: { node_id: CONTROL_NET_COLLECT, field: 'collection' },
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: baseNodeId,
field: 'control',
node_id: CONTROL_NET_COLLECT,
field: 'item',
},
});
validControlNets.forEach((controlNet) => {
const {
controlNetId,
controlImage,
processedControlImage,
beginStepPct,
endStepPct,
controlMode,
resizeMode,
model,
processorType,
weight,
} = controlNet;
const controlNetNode: ControlNetInvocation = {
id: `control_net_${controlNetId}`,
type: 'controlnet',
is_intermediate: true,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
control_mode: controlMode,
resize_mode: resizeMode,
control_model: model as ControlNetInvocation['control_model'],
control_weight: weight,
};
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
controlNetNode.image = {
image_name: processedControlImage,
};
} else if (controlImage) {
// The control image is preprocessed
controlNetNode.image = {
image_name: controlImage,
};
} else {
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
return;
}
graph.nodes[controlNetNode.id] = controlNetNode as ControlNetInvocation;
if (metadataAccumulator?.controlnets) {
// metadata accumulator only needs a control field - not the whole node
// extract what we need and add to the accumulator
const controlField = omit(controlNetNode, [
'id',
'type',
]) as ControlField;
metadataAccumulator.controlnets.push(controlField);
}
if (CANVAS_COHERENCE_DENOISE_LATENTS in graph.nodes) {
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: CONTROL_NET_COLLECT,
field: 'item',
node_id: CANVAS_COHERENCE_DENOISE_LATENTS,
field: 'control',
},
});
if (CANVAS_COHERENCE_DENOISE_LATENTS in graph.nodes) {
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: CANVAS_COHERENCE_DENOISE_LATENTS,
field: 'control',
},
});
}
});
}
}
});
}
};

View File

@ -9,35 +9,37 @@ import {
IP_ADAPTER,
METADATA_ACCUMULATOR,
} from './constants';
import { selectValidIPAdapters } from 'features/controlNet/store/controlAdaptersSlice';
export const addIPAdapterToLinearGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): void => {
const { isIPAdapterEnabled, ipAdapterInfo } = state.controlNet;
const validIPAdapters = selectValidIPAdapters(state.controlAdapters);
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (isIPAdapterEnabled && ipAdapterInfo.model) {
const ipAdapter = validIPAdapters[0];
// TODO: handle multiple IP adapters once backend is capable
if (ipAdapter && ipAdapter.model) {
const { weight, model, beginStepPct, endStepPct } = ipAdapter;
const ipAdapterNode: IPAdapterInvocation = {
id: IP_ADAPTER,
type: 'ip_adapter',
is_intermediate: true,
weight: ipAdapterInfo.weight,
ip_adapter_model: {
base_model: ipAdapterInfo.model?.base_model,
model_name: ipAdapterInfo.model?.model_name,
},
begin_step_percent: ipAdapterInfo.beginStepPct,
end_step_percent: ipAdapterInfo.endStepPct,
weight: weight,
ip_adapter_model: model,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
};
if (ipAdapterInfo.adapterImage) {
if (ipAdapter.controlImage) {
ipAdapterNode.image = {
image_name: ipAdapterInfo.adapterImage,
image_name: ipAdapter.controlImage,
};
} else {
return;
@ -47,15 +49,12 @@ export const addIPAdapterToLinearGraph = (
if (metadataAccumulator?.ipAdapters) {
const ipAdapterField = {
image: {
image_name: ipAdapterInfo.adapterImage,
image_name: ipAdapter.controlImage,
},
ip_adapter_model: {
base_model: ipAdapterInfo.model?.base_model,
model_name: ipAdapterInfo.model?.model_name,
},
weight: ipAdapterInfo.weight,
begin_step_percent: ipAdapterInfo.beginStepPct,
end_step_percent: ipAdapterInfo.endStepPct,
weight,
ip_adapter_model: model,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
};
metadataAccumulator.ipAdapters.push(ipAdapterField);