feat(ui): update graphs for multi ip adapter

This commit is contained in:
psychedelicious 2023-10-07 17:30:08 +11:00
parent ed82bf6bb8
commit 35374ec531
3 changed files with 76 additions and 59 deletions

View File

@ -87,18 +87,19 @@ export const selectValidT2IAdapters = (controlAdapters: ControlAdaptersState) =>
(ca.processorType === 'none' && Boolean(ca.controlImage))) (ca.processorType === 'none' && Boolean(ca.controlImage)))
); );
const disableAllIPAdapters = ( // TODO: I think we can safely remove this?
state: ControlAdaptersState, // const disableAllIPAdapters = (
exclude?: string // state: ControlAdaptersState,
) => { // exclude?: string
const updates: Update<ControlAdapterConfig>[] = selectAllIPAdapters(state) // ) => {
.filter((ca) => ca.id !== exclude) // const updates: Update<ControlAdapterConfig>[] = selectAllIPAdapters(state)
.map((ca) => ({ // .filter((ca) => ca.id !== exclude)
id: ca.id, // .map((ca) => ({
changes: { isEnabled: false }, // id: ca.id,
})); // changes: { isEnabled: false },
caAdapter.updateMany(state, updates); // }));
}; // caAdapter.updateMany(state, updates);
// };
const disableAllControlNets = ( const disableAllControlNets = (
state: ControlAdaptersState, state: ControlAdaptersState,
@ -131,10 +132,6 @@ const disableIncompatibleControlAdapters = (
type: ControlAdapterType, type: ControlAdapterType,
exclude?: string exclude?: string
) => { ) => {
if (type === 'ip_adapter') {
// we can only have a single active IP Adapter, if we are enabling this one, disable others
disableAllIPAdapters(state, exclude);
}
if (type === 'controlnet') { if (type === 'controlnet') {
// we cannot do controlnet + t2i adapter, if we are enabled a controlnet, disable all t2is // we cannot do controlnet + t2i adapter, if we are enabled a controlnet, disable all t2is
disableAllT2IAdapters(state, exclude); disableAllT2IAdapters(state, exclude);

View File

@ -1,15 +1,16 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import { import {
CollectInvocation,
IPAdapterInvocation, IPAdapterInvocation,
MetadataAccumulatorInvocation, MetadataAccumulatorInvocation,
} from 'services/api/types'; } from 'services/api/types';
import { NonNullableGraph } from '../../types/types'; import { NonNullableGraph } from '../../types/types';
import { import {
CANVAS_COHERENCE_DENOISE_LATENTS, CANVAS_COHERENCE_DENOISE_LATENTS,
IP_ADAPTER, IP_ADAPTER_COLLECT,
METADATA_ACCUMULATOR, METADATA_ACCUMULATOR,
} from './constants'; } from './constants';
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
export const addIPAdapterToLinearGraph = ( export const addIPAdapterToLinearGraph = (
state: RootState, state: RootState,
@ -22,13 +23,29 @@ export const addIPAdapterToLinearGraph = (
| MetadataAccumulatorInvocation | MetadataAccumulatorInvocation
| undefined; | undefined;
const ipAdapter = validIPAdapters[0]; if (validIPAdapters.length) {
// Even though denoise_latents' control input is polymorphic, keep it simple and always use a collect
const ipAdapterCollectNode: CollectInvocation = {
id: IP_ADAPTER_COLLECT,
type: 'collect',
is_intermediate: true,
};
graph.nodes[ipAdapterCollectNode.id] = ipAdapterCollectNode;
graph.edges.push({
source: { node_id: ipAdapterCollectNode.id, field: 'collection' },
destination: {
node_id: baseNodeId,
field: 'ip_adapter',
},
});
// TODO: handle multiple IP adapters once backend is capable validIPAdapters.forEach((ipAdapter) => {
if (ipAdapter && ipAdapter.model) { if (!ipAdapter.model) {
const { weight, model, beginStepPct, endStepPct } = ipAdapter; return;
}
const { id, weight, model, beginStepPct, endStepPct } = ipAdapter;
const ipAdapterNode: IPAdapterInvocation = { const ipAdapterNode: IPAdapterInvocation = {
id: IP_ADAPTER, id: `ip_adapter_${id}`,
type: 'ip_adapter', type: 'ip_adapter',
is_intermediate: true, is_intermediate: true,
weight: weight, weight: weight,
@ -46,6 +63,7 @@ export const addIPAdapterToLinearGraph = (
} }
graph.nodes[ipAdapterNode.id] = ipAdapterNode as IPAdapterInvocation; graph.nodes[ipAdapterNode.id] = ipAdapterNode as IPAdapterInvocation;
if (metadataAccumulator?.ipAdapters) { if (metadataAccumulator?.ipAdapters) {
const ipAdapterField = { const ipAdapterField = {
image: { image: {
@ -63,8 +81,8 @@ export const addIPAdapterToLinearGraph = (
graph.edges.push({ graph.edges.push({
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' }, source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
destination: { destination: {
node_id: baseNodeId, node_id: ipAdapterCollectNode.id,
field: 'ip_adapter', field: 'item',
}, },
}); });
@ -77,5 +95,6 @@ export const addIPAdapterToLinearGraph = (
}, },
}); });
} }
});
} }
}; };

View File

@ -46,6 +46,7 @@ export const MASK_RESIZE_DOWN = 'mask_resize_down';
export const COLOR_CORRECT = 'color_correct'; export const COLOR_CORRECT = 'color_correct';
export const PASTE_IMAGE = 'img_paste'; export const PASTE_IMAGE = 'img_paste';
export const CONTROL_NET_COLLECT = 'control_net_collect'; export const CONTROL_NET_COLLECT = 'control_net_collect';
export const IP_ADAPTER_COLLECT = 'ip_adapter_collect';
export const T2I_ADAPTER_COLLECT = 't2i_adapter_collect'; export const T2I_ADAPTER_COLLECT = 't2i_adapter_collect';
export const IP_ADAPTER = 'ip_adapter'; export const IP_ADAPTER = 'ip_adapter';
export const DYNAMIC_PROMPT = 'dynamic_prompt'; export const DYNAMIC_PROMPT = 'dynamic_prompt';