From cf637ecaa6708e50935bba6814afc4cc0dec8143 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 23 Apr 2024 12:58:26 +1000 Subject: [PATCH] fix(ui): disabled ip adapters applied to regional control --- .../util/graph/addRegionalPromptsToGraph.ts | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addRegionalPromptsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addRegionalPromptsToGraph.ts index 01c0a8dbf8..8d7a3a6c9a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addRegionalPromptsToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addRegionalPromptsToGraph.ts @@ -15,7 +15,7 @@ import { } from 'features/nodes/util/graph/constants'; import { isVectorMaskLayer } from 'features/regionalPrompts/store/regionalPromptsSlice'; import { getRegionalPromptLayerBlobs } from 'features/regionalPrompts/util/getLayerBlobs'; -import { size, sumBy } from 'lodash-es'; +import { size } from 'lodash-es'; import { imagesApi } from 'services/api/endpoints/images'; import type { CollectInvocation, Edge, IPAdapterInvocation, NonNullableGraph, S } from 'services/api/types'; import { assert } from 'tsafe'; @@ -39,6 +39,16 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull return hasTextPrompt || hasIPAdapter; }); + const regionalIPAdapters = selectAllIPAdapters(state.controlAdapters).filter( + ({ id, model, controlImage, isEnabled }) => { + const hasModel = Boolean(model); + const doesBaseMatch = model?.base === state.generation.model?.base; + const hasControlImage = controlImage; + const isRegional = layers.some((l) => l.ipAdapterIds.includes(id)); + return isEnabled && hasModel && doesBaseMatch && hasControlImage && isRegional; + } + ); + const layerIds = layers.map((l) => l.id); const blobs = await getRegionalPromptLayerBlobs(layerIds); assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs'); @@ -105,7 +115,7 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull }, }); - if (!graph.nodes[IP_ADAPTER_COLLECT] && sumBy(layers, (l) => l.ipAdapterIds.length) > 0) { + if (!graph.nodes[IP_ADAPTER_COLLECT] && regionalIPAdapters.length > 0) { const ipAdapterCollectNode: CollectInvocation = { id: IP_ADAPTER_COLLECT, type: 'collect', @@ -284,8 +294,16 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull } for (const ipAdapterId of layer.ipAdapterIds) { - const ipAdapter = selectAllIPAdapters(state.controlAdapters).find((ca) => ca.id === ipAdapterId); - console.log(ipAdapter); + const ipAdapter = selectAllIPAdapters(state.controlAdapters) + .filter(({ id, model, controlImage, isEnabled }) => { + const hasModel = Boolean(model); + const doesBaseMatch = model?.base === state.generation.model?.base; + const hasControlImage = controlImage; + const isRegional = layers.some((l) => l.ipAdapterIds.includes(id)); + return isEnabled && hasModel && doesBaseMatch && hasControlImage && isRegional; + }) + .find((ca) => ca.id === ipAdapterId); + if (!ipAdapter?.model) { return; }