feat(ui): update graphs for multi ip adapter

This commit is contained in:
psychedelicious 2023-10-07 17:30:08 +11:00
parent ed82bf6bb8
commit 35374ec531
3 changed files with 76 additions and 59 deletions

View File

@ -87,18 +87,19 @@ export const selectValidT2IAdapters = (controlAdapters: ControlAdaptersState) =>
(ca.processorType === 'none' && Boolean(ca.controlImage)))
);
const disableAllIPAdapters = (
state: ControlAdaptersState,
exclude?: string
) => {
const updates: Update<ControlAdapterConfig>[] = selectAllIPAdapters(state)
.filter((ca) => ca.id !== exclude)
.map((ca) => ({
id: ca.id,
changes: { isEnabled: false },
}));
caAdapter.updateMany(state, updates);
};
// TODO: I think we can safely remove this?
// const disableAllIPAdapters = (
// state: ControlAdaptersState,
// exclude?: string
// ) => {
// const updates: Update<ControlAdapterConfig>[] = selectAllIPAdapters(state)
// .filter((ca) => ca.id !== exclude)
// .map((ca) => ({
// id: ca.id,
// changes: { isEnabled: false },
// }));
// caAdapter.updateMany(state, updates);
// };
const disableAllControlNets = (
state: ControlAdaptersState,
@ -131,10 +132,6 @@ const disableIncompatibleControlAdapters = (
type: ControlAdapterType,
exclude?: string
) => {
if (type === 'ip_adapter') {
// we can only have a single active IP Adapter, if we are enabling this one, disable others
disableAllIPAdapters(state, exclude);
}
if (type === 'controlnet') {
// we cannot do controlnet + t2i adapter, if we are enabled a controlnet, disable all t2is
disableAllT2IAdapters(state, exclude);

View File

@ -1,15 +1,16 @@
import { RootState } from 'app/store/store';
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import {
CollectInvocation,
IPAdapterInvocation,
MetadataAccumulatorInvocation,
} from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
import {
CANVAS_COHERENCE_DENOISE_LATENTS,
IP_ADAPTER,
IP_ADAPTER_COLLECT,
METADATA_ACCUMULATOR,
} from './constants';
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
export const addIPAdapterToLinearGraph = (
state: RootState,
@ -22,60 +23,78 @@ export const addIPAdapterToLinearGraph = (
| MetadataAccumulatorInvocation
| undefined;
const ipAdapter = validIPAdapters[0];
// TODO: handle multiple IP adapters once backend is capable
if (ipAdapter && ipAdapter.model) {
const { weight, model, beginStepPct, endStepPct } = ipAdapter;
const ipAdapterNode: IPAdapterInvocation = {
id: IP_ADAPTER,
type: 'ip_adapter',
if (validIPAdapters.length) {
// Even though denoise_latents' control input is polymorphic, keep it simple and always use a collect
const ipAdapterCollectNode: CollectInvocation = {
id: IP_ADAPTER_COLLECT,
type: 'collect',
is_intermediate: true,
weight: weight,
ip_adapter_model: model,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
};
if (ipAdapter.controlImage) {
ipAdapterNode.image = {
image_name: ipAdapter.controlImage,
};
} else {
return;
}
graph.nodes[ipAdapterNode.id] = ipAdapterNode as IPAdapterInvocation;
if (metadataAccumulator?.ipAdapters) {
const ipAdapterField = {
image: {
image_name: ipAdapter.controlImage,
},
weight,
ip_adapter_model: model,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
};
metadataAccumulator.ipAdapters.push(ipAdapterField);
}
graph.nodes[ipAdapterCollectNode.id] = ipAdapterCollectNode;
graph.edges.push({
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
source: { node_id: ipAdapterCollectNode.id, field: 'collection' },
destination: {
node_id: baseNodeId,
field: 'ip_adapter',
},
});
if (CANVAS_COHERENCE_DENOISE_LATENTS in graph.nodes) {
validIPAdapters.forEach((ipAdapter) => {
if (!ipAdapter.model) {
return;
}
const { id, weight, model, beginStepPct, endStepPct } = ipAdapter;
const ipAdapterNode: IPAdapterInvocation = {
id: `ip_adapter_${id}`,
type: 'ip_adapter',
is_intermediate: true,
weight: weight,
ip_adapter_model: model,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
};
if (ipAdapter.controlImage) {
ipAdapterNode.image = {
image_name: ipAdapter.controlImage,
};
} else {
return;
}
graph.nodes[ipAdapterNode.id] = ipAdapterNode as IPAdapterInvocation;
if (metadataAccumulator?.ipAdapters) {
const ipAdapterField = {
image: {
image_name: ipAdapter.controlImage,
},
weight,
ip_adapter_model: model,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
};
metadataAccumulator.ipAdapters.push(ipAdapterField);
}
graph.edges.push({
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
destination: {
node_id: CANVAS_COHERENCE_DENOISE_LATENTS,
field: 'ip_adapter',
node_id: ipAdapterCollectNode.id,
field: 'item',
},
});
}
if (CANVAS_COHERENCE_DENOISE_LATENTS in graph.nodes) {
graph.edges.push({
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
destination: {
node_id: CANVAS_COHERENCE_DENOISE_LATENTS,
field: 'ip_adapter',
},
});
}
});
}
};

View File

@ -46,6 +46,7 @@ export const MASK_RESIZE_DOWN = 'mask_resize_down';
export const COLOR_CORRECT = 'color_correct';
export const PASTE_IMAGE = 'img_paste';
export const CONTROL_NET_COLLECT = 'control_net_collect';
export const IP_ADAPTER_COLLECT = 'ip_adapter_collect';
export const T2I_ADAPTER_COLLECT = 't2i_adapter_collect';
export const IP_ADAPTER = 'ip_adapter';
export const DYNAMIC_PROMPT = 'dynamic_prompt';