mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): refactor control adapters
Control adapters logic/state/ui is now generalized to hold controlnet, ip_adapter and t2i_adapter. In the future, other control adapter types can be added. TODO: - Limit IP adapter to 1 - Add T2I adapter to linear graphs - Fix autoprocess - T2I metadata saving & recall - Improve on control adapters UI
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
|
||||
import { selectValidControlNets } from 'features/controlNet/store/controlAdaptersSlice';
|
||||
import { omit } from 'lodash-es';
|
||||
import {
|
||||
CollectInvocation,
|
||||
@ -19,102 +19,101 @@ export const addControlNetToLinearGraph = (
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string
|
||||
): void => {
|
||||
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
|
||||
|
||||
const validControlNets = getValidControlNets(controlNets);
|
||||
const validControlNets = selectValidControlNets(state.controlAdapters);
|
||||
|
||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||
| MetadataAccumulatorInvocation
|
||||
| undefined;
|
||||
|
||||
if (isControlNetEnabled && Boolean(validControlNets.length)) {
|
||||
if (validControlNets.length) {
|
||||
// We have multiple controlnets, add ControlNet collector
|
||||
const controlNetIterateNode: CollectInvocation = {
|
||||
id: CONTROL_NET_COLLECT,
|
||||
type: 'collect',
|
||||
if (validControlNets.length) {
|
||||
// Even though denoise_latents' control input is polymorphic, keep it simple and always use a collect
|
||||
const controlNetIterateNode: CollectInvocation = {
|
||||
id: CONTROL_NET_COLLECT,
|
||||
type: 'collect',
|
||||
is_intermediate: true,
|
||||
};
|
||||
graph.nodes[CONTROL_NET_COLLECT] = controlNetIterateNode;
|
||||
graph.edges.push({
|
||||
source: { node_id: CONTROL_NET_COLLECT, field: 'collection' },
|
||||
destination: {
|
||||
node_id: baseNodeId,
|
||||
field: 'control',
|
||||
},
|
||||
});
|
||||
|
||||
validControlNets.forEach((controlNet) => {
|
||||
if (!controlNet.model) {
|
||||
return;
|
||||
}
|
||||
const {
|
||||
id,
|
||||
controlImage,
|
||||
processedControlImage,
|
||||
beginStepPct,
|
||||
endStepPct,
|
||||
controlMode,
|
||||
resizeMode,
|
||||
model,
|
||||
processorType,
|
||||
weight,
|
||||
} = controlNet;
|
||||
|
||||
const controlNetNode: ControlNetInvocation = {
|
||||
id: `control_net_${id}`,
|
||||
type: 'controlnet',
|
||||
is_intermediate: true,
|
||||
begin_step_percent: beginStepPct,
|
||||
end_step_percent: endStepPct,
|
||||
control_mode: controlMode,
|
||||
resize_mode: resizeMode,
|
||||
control_model: model,
|
||||
control_weight: weight,
|
||||
};
|
||||
graph.nodes[CONTROL_NET_COLLECT] = controlNetIterateNode;
|
||||
|
||||
if (processedControlImage && processorType !== 'none') {
|
||||
// We've already processed the image in the app, so we can just use the processed image
|
||||
controlNetNode.image = {
|
||||
image_name: processedControlImage,
|
||||
};
|
||||
} else if (controlImage) {
|
||||
// The control image is preprocessed
|
||||
controlNetNode.image = {
|
||||
image_name: controlImage,
|
||||
};
|
||||
} else {
|
||||
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
|
||||
return;
|
||||
}
|
||||
|
||||
graph.nodes[controlNetNode.id] = controlNetNode as ControlNetInvocation;
|
||||
|
||||
if (metadataAccumulator?.controlnets) {
|
||||
// metadata accumulator only needs a control field - not the whole node
|
||||
// extract what we need and add to the accumulator
|
||||
const controlField = omit(controlNetNode, [
|
||||
'id',
|
||||
'type',
|
||||
]) as ControlField;
|
||||
metadataAccumulator.controlnets.push(controlField);
|
||||
}
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: CONTROL_NET_COLLECT, field: 'collection' },
|
||||
source: { node_id: controlNetNode.id, field: 'control' },
|
||||
destination: {
|
||||
node_id: baseNodeId,
|
||||
field: 'control',
|
||||
node_id: CONTROL_NET_COLLECT,
|
||||
field: 'item',
|
||||
},
|
||||
});
|
||||
|
||||
validControlNets.forEach((controlNet) => {
|
||||
const {
|
||||
controlNetId,
|
||||
controlImage,
|
||||
processedControlImage,
|
||||
beginStepPct,
|
||||
endStepPct,
|
||||
controlMode,
|
||||
resizeMode,
|
||||
model,
|
||||
processorType,
|
||||
weight,
|
||||
} = controlNet;
|
||||
|
||||
const controlNetNode: ControlNetInvocation = {
|
||||
id: `control_net_${controlNetId}`,
|
||||
type: 'controlnet',
|
||||
is_intermediate: true,
|
||||
begin_step_percent: beginStepPct,
|
||||
end_step_percent: endStepPct,
|
||||
control_mode: controlMode,
|
||||
resize_mode: resizeMode,
|
||||
control_model: model as ControlNetInvocation['control_model'],
|
||||
control_weight: weight,
|
||||
};
|
||||
|
||||
if (processedControlImage && processorType !== 'none') {
|
||||
// We've already processed the image in the app, so we can just use the processed image
|
||||
controlNetNode.image = {
|
||||
image_name: processedControlImage,
|
||||
};
|
||||
} else if (controlImage) {
|
||||
// The control image is preprocessed
|
||||
controlNetNode.image = {
|
||||
image_name: controlImage,
|
||||
};
|
||||
} else {
|
||||
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
|
||||
return;
|
||||
}
|
||||
|
||||
graph.nodes[controlNetNode.id] = controlNetNode as ControlNetInvocation;
|
||||
|
||||
if (metadataAccumulator?.controlnets) {
|
||||
// metadata accumulator only needs a control field - not the whole node
|
||||
// extract what we need and add to the accumulator
|
||||
const controlField = omit(controlNetNode, [
|
||||
'id',
|
||||
'type',
|
||||
]) as ControlField;
|
||||
metadataAccumulator.controlnets.push(controlField);
|
||||
}
|
||||
|
||||
if (CANVAS_COHERENCE_DENOISE_LATENTS in graph.nodes) {
|
||||
graph.edges.push({
|
||||
source: { node_id: controlNetNode.id, field: 'control' },
|
||||
destination: {
|
||||
node_id: CONTROL_NET_COLLECT,
|
||||
field: 'item',
|
||||
node_id: CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
field: 'control',
|
||||
},
|
||||
});
|
||||
|
||||
if (CANVAS_COHERENCE_DENOISE_LATENTS in graph.nodes) {
|
||||
graph.edges.push({
|
||||
source: { node_id: controlNetNode.id, field: 'control' },
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
field: 'control',
|
||||
},
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
@ -9,35 +9,37 @@ import {
|
||||
IP_ADAPTER,
|
||||
METADATA_ACCUMULATOR,
|
||||
} from './constants';
|
||||
import { selectValidIPAdapters } from 'features/controlNet/store/controlAdaptersSlice';
|
||||
|
||||
export const addIPAdapterToLinearGraph = (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string
|
||||
): void => {
|
||||
const { isIPAdapterEnabled, ipAdapterInfo } = state.controlNet;
|
||||
const validIPAdapters = selectValidIPAdapters(state.controlAdapters);
|
||||
|
||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||
| MetadataAccumulatorInvocation
|
||||
| undefined;
|
||||
|
||||
if (isIPAdapterEnabled && ipAdapterInfo.model) {
|
||||
const ipAdapter = validIPAdapters[0];
|
||||
|
||||
// TODO: handle multiple IP adapters once backend is capable
|
||||
if (ipAdapter && ipAdapter.model) {
|
||||
const { weight, model, beginStepPct, endStepPct } = ipAdapter;
|
||||
const ipAdapterNode: IPAdapterInvocation = {
|
||||
id: IP_ADAPTER,
|
||||
type: 'ip_adapter',
|
||||
is_intermediate: true,
|
||||
weight: ipAdapterInfo.weight,
|
||||
ip_adapter_model: {
|
||||
base_model: ipAdapterInfo.model?.base_model,
|
||||
model_name: ipAdapterInfo.model?.model_name,
|
||||
},
|
||||
begin_step_percent: ipAdapterInfo.beginStepPct,
|
||||
end_step_percent: ipAdapterInfo.endStepPct,
|
||||
weight: weight,
|
||||
ip_adapter_model: model,
|
||||
begin_step_percent: beginStepPct,
|
||||
end_step_percent: endStepPct,
|
||||
};
|
||||
|
||||
if (ipAdapterInfo.adapterImage) {
|
||||
if (ipAdapter.controlImage) {
|
||||
ipAdapterNode.image = {
|
||||
image_name: ipAdapterInfo.adapterImage,
|
||||
image_name: ipAdapter.controlImage,
|
||||
};
|
||||
} else {
|
||||
return;
|
||||
@ -47,15 +49,12 @@ export const addIPAdapterToLinearGraph = (
|
||||
if (metadataAccumulator?.ipAdapters) {
|
||||
const ipAdapterField = {
|
||||
image: {
|
||||
image_name: ipAdapterInfo.adapterImage,
|
||||
image_name: ipAdapter.controlImage,
|
||||
},
|
||||
ip_adapter_model: {
|
||||
base_model: ipAdapterInfo.model?.base_model,
|
||||
model_name: ipAdapterInfo.model?.model_name,
|
||||
},
|
||||
weight: ipAdapterInfo.weight,
|
||||
begin_step_percent: ipAdapterInfo.beginStepPct,
|
||||
end_step_percent: ipAdapterInfo.endStepPct,
|
||||
weight,
|
||||
ip_adapter_model: model,
|
||||
begin_step_percent: beginStepPct,
|
||||
end_step_percent: endStepPct,
|
||||
};
|
||||
|
||||
metadataAccumulator.ipAdapters.push(ipAdapterField);
|
||||
|
Reference in New Issue
Block a user