fix(ui): disabled ip adapters applied to regional control

This commit is contained in:
psychedelicious 2024-04-23 12:58:26 +10:00
parent fca718bdd3
commit cf637ecaa6

View File

@ -15,7 +15,7 @@ import {
} from 'features/nodes/util/graph/constants';
import { isVectorMaskLayer } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { getRegionalPromptLayerBlobs } from 'features/regionalPrompts/util/getLayerBlobs';
import { size, sumBy } from 'lodash-es';
import { size } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import type { CollectInvocation, Edge, IPAdapterInvocation, NonNullableGraph, S } from 'services/api/types';
import { assert } from 'tsafe';
@ -39,6 +39,16 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
return hasTextPrompt || hasIPAdapter;
});
const regionalIPAdapters = selectAllIPAdapters(state.controlAdapters).filter(
({ id, model, controlImage, isEnabled }) => {
const hasModel = Boolean(model);
const doesBaseMatch = model?.base === state.generation.model?.base;
const hasControlImage = controlImage;
const isRegional = layers.some((l) => l.ipAdapterIds.includes(id));
return isEnabled && hasModel && doesBaseMatch && hasControlImage && isRegional;
}
);
const layerIds = layers.map((l) => l.id);
const blobs = await getRegionalPromptLayerBlobs(layerIds);
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');
@ -105,7 +115,7 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
},
});
if (!graph.nodes[IP_ADAPTER_COLLECT] && sumBy(layers, (l) => l.ipAdapterIds.length) > 0) {
if (!graph.nodes[IP_ADAPTER_COLLECT] && regionalIPAdapters.length > 0) {
const ipAdapterCollectNode: CollectInvocation = {
id: IP_ADAPTER_COLLECT,
type: 'collect',
@ -284,8 +294,16 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
}
for (const ipAdapterId of layer.ipAdapterIds) {
const ipAdapter = selectAllIPAdapters(state.controlAdapters).find((ca) => ca.id === ipAdapterId);
console.log(ipAdapter);
const ipAdapter = selectAllIPAdapters(state.controlAdapters)
.filter(({ id, model, controlImage, isEnabled }) => {
const hasModel = Boolean(model);
const doesBaseMatch = model?.base === state.generation.model?.base;
const hasControlImage = controlImage;
const isRegional = layers.some((l) => l.ipAdapterIds.includes(id));
return isEnabled && hasModel && doesBaseMatch && hasControlImage && isRegional;
})
.find((ca) => ca.id === ipAdapterId);
if (!ipAdapter?.model) {
return;
}