From c686625076859ff5ad10f88cf65826c27f7cc989 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 24 Apr 2024 16:05:34 +1000 Subject: [PATCH] feat(ui): add 'control_layer' type --- .../util/graph/addIPAdapterToLinearGraph.ts | 27 +++++--- .../components/RPLayerIPAdapterList.tsx | 4 +- .../hooks/useRegionalControlTitle.ts | 3 +- .../store/regionalPromptsSlice.ts | 67 ++++++++++++------- .../regionalPrompts/util/getLayerBlobs.ts | 4 +- .../regionalPrompts/util/renderers.ts | 55 ++++++++------- .../ControlSettingsAccordion.tsx | 16 +++-- 7 files changed, 105 insertions(+), 71 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addIPAdapterToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addIPAdapterToLinearGraph.ts index 0a90622e04..4a287b5335 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addIPAdapterToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addIPAdapterToLinearGraph.ts @@ -2,6 +2,8 @@ import type { RootState } from 'app/store/store'; import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice'; import type { IPAdapterConfig } from 'features/controlAdapters/store/types'; import type { ImageField } from 'features/nodes/types/common'; +import { isVectorMaskLayer } from 'features/regionalPrompts/store/regionalPromptsSlice'; +import { differenceBy } from 'lodash-es'; import type { CollectInvocation, CoreMetadataInvocation, @@ -19,16 +21,21 @@ export const addIPAdapterToLinearGraph = async ( graph: NonNullableGraph, baseNodeId: string ): Promise => { - const validIPAdapters = selectValidIPAdapters(state.controlAdapters) - .filter(({ model, controlImage, isEnabled }) => { - const hasModel = Boolean(model); - const doesBaseMatch = model?.base === state.generation.model?.base; - const hasControlImage = controlImage; - return isEnabled && hasModel && doesBaseMatch && hasControlImage; - }) - .filter((ca) => !state.regionalPrompts.present.layers.some((l) => l.ipAdapterIds.includes(ca.id))); + const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(({ model, controlImage, isEnabled }) => { + const hasModel = Boolean(model); + const doesBaseMatch = model?.base === state.generation.model?.base; + const hasControlImage = controlImage; + return isEnabled && hasModel && doesBaseMatch && hasControlImage; + }); - if (validIPAdapters.length) { + const regionalIPAdapterIds = state.regionalPrompts.present.layers + .filter(isVectorMaskLayer) + .map((l) => l.ipAdapterIds) + .flat(); + + const nonRegionalIPAdapters = differenceBy(validIPAdapters, regionalIPAdapterIds, 'id'); + + if (nonRegionalIPAdapters.length) { // Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect const ipAdapterCollectNode: CollectInvocation = { id: IP_ADAPTER_COLLECT, @@ -46,7 +53,7 @@ export const addIPAdapterToLinearGraph = async ( const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = []; - for (const ipAdapter of validIPAdapters) { + for (const ipAdapter of nonRegionalIPAdapters) { if (!ipAdapter.model) { return; } diff --git a/invokeai/frontend/web/src/features/regionalPrompts/components/RPLayerIPAdapterList.tsx b/invokeai/frontend/web/src/features/regionalPrompts/components/RPLayerIPAdapterList.tsx index c5d1ca62e9..91cc2d0736 100644 --- a/invokeai/frontend/web/src/features/regionalPrompts/components/RPLayerIPAdapterList.tsx +++ b/invokeai/frontend/web/src/features/regionalPrompts/components/RPLayerIPAdapterList.tsx @@ -2,7 +2,7 @@ import { Flex } from '@invoke-ai/ui-library'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import ControlAdapterConfig from 'features/controlAdapters/components/ControlAdapterConfig'; -import { selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice'; +import { isVectorMaskLayer,selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice'; import { memo, useMemo } from 'react'; import { assert } from 'tsafe'; @@ -14,7 +14,7 @@ export const RPLayerIPAdapterList = memo(({ layerId }: Props) => { const selectIPAdapterIds = useMemo( () => createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => { - const layer = regionalPrompts.present.layers.find((l) => l.id === layerId); + const layer = regionalPrompts.present.layers.filter(isVectorMaskLayer).find((l) => l.id === layerId); assert(layer, `Layer ${layerId} not found`); return layer.ipAdapterIds; }), diff --git a/invokeai/frontend/web/src/features/regionalPrompts/hooks/useRegionalControlTitle.ts b/invokeai/frontend/web/src/features/regionalPrompts/hooks/useRegionalControlTitle.ts index 4f23804c2a..24532d2fa1 100644 --- a/invokeai/frontend/web/src/features/regionalPrompts/hooks/useRegionalControlTitle.ts +++ b/invokeai/frontend/web/src/features/regionalPrompts/hooks/useRegionalControlTitle.ts @@ -1,6 +1,6 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; -import { selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice'; +import { isVectorMaskLayer, selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice'; import { useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -9,6 +9,7 @@ const selectValidLayerCount = createSelector(selectRegionalPromptsSlice, (region return 0; } const validLayers = regionalPrompts.present.layers + .filter(isVectorMaskLayer) .filter((l) => l.isVisible) .filter((l) => { const hasTextPrompt = Boolean(l.positivePrompt || l.negativePrompt); diff --git a/invokeai/frontend/web/src/features/regionalPrompts/store/regionalPromptsSlice.ts b/invokeai/frontend/web/src/features/regionalPrompts/store/regionalPromptsSlice.ts index 1d32938868..5593f0dff6 100644 --- a/invokeai/frontend/web/src/features/regionalPrompts/store/regionalPromptsSlice.ts +++ b/invokeai/frontend/web/src/features/regionalPrompts/store/regionalPromptsSlice.ts @@ -3,6 +3,7 @@ import { createSlice, isAnyOf } from '@reduxjs/toolkit'; import type { PersistConfig, RootState } from 'app/store/store'; import { moveBackward, moveForward, moveToBack, moveToFront } from 'common/util/arrayUtils'; import { controlAdapterRemoved } from 'features/controlAdapters/store/controlAdaptersSlice'; +import type { ControlAdapterConfig } from 'features/controlAdapters/store/types'; import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas'; import type { IRect, Vector2d } from 'konva/lib/types'; import { isEqual } from 'lodash-es'; @@ -42,6 +43,11 @@ type LayerBase = { isVisible: boolean; }; +type ControlLayer = LayerBase & { + type: 'control_layer'; + controlAdapter: ControlAdapterConfig; +}; + type MaskLayerBase = LayerBase & { positivePrompt: string | null; negativePrompt: string | null; // Up to one text prompt per mask @@ -56,7 +62,7 @@ export type VectorMaskLayer = MaskLayerBase & { objects: (VectorMaskLine | VectorMaskRect)[]; }; -export type Layer = VectorMaskLayer; +export type Layer = VectorMaskLayer | ControlLayer; type RegionalPromptsState = { _version: 1; @@ -78,12 +84,24 @@ export const initialRegionalPromptsState: RegionalPromptsState = { const isLine = (obj: VectorMaskLine | VectorMaskRect): obj is VectorMaskLine => obj.type === 'vector_mask_line'; export const isVectorMaskLayer = (layer?: Layer): layer is VectorMaskLayer => layer?.type === 'vector_mask_layer'; -const resetLayer = (layer: VectorMaskLayer) => { - layer.objects = []; - layer.bbox = null; - layer.isVisible = true; - layer.needsPixelBbox = false; - layer.bboxNeedsUpdate = false; +const resetLayer = (layer: Layer) => { + if (layer.type === 'vector_mask_layer') { + layer.objects = []; + layer.bbox = null; + layer.isVisible = true; + layer.needsPixelBbox = false; + layer.bboxNeedsUpdate = false; + return; + } + + if (layer.type === 'control_layer') { + // TODO + } +}; +const getVectorMaskPreviewColor = (state: RegionalPromptsState): RgbColor => { + const vmLayers = state.layers.filter(isVectorMaskLayer); + const lastColor = vmLayers[vmLayers.length - 1]?.previewColor; + return LayerColors.next(lastColor); }; export const regionalPromptsSlice = createSlice({ @@ -93,18 +111,16 @@ export const regionalPromptsSlice = createSlice({ //#region All Layers layerAdded: { reducer: (state, action: PayloadAction) => { - const kind = action.payload; - if (action.payload === 'vector_mask_layer') { - const lastColor = state.layers[state.layers.length - 1]?.previewColor; - const previewColor = LayerColors.next(lastColor); + const type = action.payload; + if (type === 'vector_mask_layer') { const layer: VectorMaskLayer = { id: getVectorMaskLayerId(action.meta.uuid), - type: kind, + type, isVisible: true, bbox: null, bboxNeedsUpdate: false, objects: [], - previewColor, + previewColor: getVectorMaskPreviewColor(state), x: 0, y: 0, autoNegative: 'invert', @@ -117,6 +133,11 @@ export const regionalPromptsSlice = createSlice({ state.selectedLayerId = layer.id; return; } + + if (type === 'control_layer') { + // TODO + return; + } }, prepare: (payload: Layer['type']) => ({ payload, meta: { uuid: uuidv4() } }), }, @@ -196,21 +217,21 @@ export const regionalPromptsSlice = createSlice({ maskLayerPositivePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => { const { layerId, prompt } = action.payload; const layer = state.layers.find((l) => l.id === layerId); - if (layer) { + if (layer?.type === 'vector_mask_layer') { layer.positivePrompt = prompt; } }, maskLayerNegativePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => { const { layerId, prompt } = action.payload; const layer = state.layers.find((l) => l.id === layerId); - if (layer) { + if (layer?.type === 'vector_mask_layer') { layer.negativePrompt = prompt; } }, maskLayerIPAdapterAdded: { reducer: (state, action: PayloadAction) => { const layer = state.layers.find((l) => l.id === action.payload); - if (layer) { + if (layer?.type === 'vector_mask_layer') { layer.ipAdapterIds.push(action.meta.uuid); } }, @@ -219,7 +240,7 @@ export const regionalPromptsSlice = createSlice({ maskLayerPreviewColorChanged: (state, action: PayloadAction<{ layerId: string; color: RgbColor }>) => { const { layerId, color } = action.payload; const layer = state.layers.find((l) => l.id === layerId); - if (layer) { + if (layer?.type === 'vector_mask_layer') { layer.previewColor = color; } }, @@ -234,7 +255,7 @@ export const regionalPromptsSlice = createSlice({ ) => { const { layerId, points, tool } = action.payload; const layer = state.layers.find((l) => l.id === layerId); - if (layer) { + if (layer?.type === 'vector_mask_layer') { const lineId = getVectorMaskLayerLineId(layer.id, action.meta.uuid); layer.objects.push({ type: 'vector_mask_line', @@ -259,7 +280,7 @@ export const regionalPromptsSlice = createSlice({ maskLayerPointsAdded: (state, action: PayloadAction<{ layerId: string; point: [number, number] }>) => { const { layerId, point } = action.payload; const layer = state.layers.find((l) => l.id === layerId); - if (layer) { + if (layer?.type === 'vector_mask_layer') { const lastLine = layer.objects.findLast(isLine); if (!lastLine) { return; @@ -278,7 +299,7 @@ export const regionalPromptsSlice = createSlice({ return; } const layer = state.layers.find((l) => l.id === layerId); - if (layer) { + if (layer?.type === 'vector_mask_layer') { const id = getVectorMaskLayerRectId(layer.id, action.meta.uuid); layer.objects.push({ type: 'vector_mask_rect', @@ -299,7 +320,7 @@ export const regionalPromptsSlice = createSlice({ ) => { const { layerId, autoNegative } = action.payload; const layer = state.layers.find((l) => l.id === layerId); - if (layer) { + if (layer?.type === 'vector_mask_layer') { layer.autoNegative = autoNegative; } }, @@ -331,9 +352,9 @@ export const regionalPromptsSlice = createSlice({ }, extraReducers(builder) { builder.addCase(controlAdapterRemoved, (state, action) => { - for (const layer of state.layers) { + state.layers.filter(isVectorMaskLayer).forEach((layer) => { layer.ipAdapterIds = layer.ipAdapterIds.filter((id) => id !== action.payload.id); - } + }); }); }, }); diff --git a/invokeai/frontend/web/src/features/regionalPrompts/util/getLayerBlobs.ts b/invokeai/frontend/web/src/features/regionalPrompts/util/getLayerBlobs.ts index 28a11b649d..02c1ae8b60 100644 --- a/invokeai/frontend/web/src/features/regionalPrompts/util/getLayerBlobs.ts +++ b/invokeai/frontend/web/src/features/regionalPrompts/util/getLayerBlobs.ts @@ -1,7 +1,7 @@ import { getStore } from 'app/store/nanostores/store'; import openBase64ImageInTab from 'common/util/openBase64ImageInTab'; import { blobToDataURL } from 'features/canvas/util/blobToDataURL'; -import { VECTOR_MASK_LAYER_NAME } from 'features/regionalPrompts/store/regionalPromptsSlice'; +import { isVectorMaskLayer, VECTOR_MASK_LAYER_NAME } from 'features/regionalPrompts/store/regionalPromptsSlice'; import { renderers } from 'features/regionalPrompts/util/renderers'; import Konva from 'konva'; import { assert } from 'tsafe'; @@ -17,7 +17,7 @@ export const getRegionalPromptLayerBlobs = async ( preview: boolean = false ): Promise> => { const state = getStore().getState(); - const reduxLayers = state.regionalPrompts.present.layers; + const reduxLayers = state.regionalPrompts.present.layers.filter(isVectorMaskLayer); const container = document.createElement('div'); const stage = new Konva.Stage({ container, width: state.generation.width, height: state.generation.height }); renderers.renderLayers(stage, reduxLayers, 1, 'brush'); diff --git a/invokeai/frontend/web/src/features/regionalPrompts/util/renderers.ts b/invokeai/frontend/web/src/features/regionalPrompts/util/renderers.ts index 20e5f75ab7..4f008c7758 100644 --- a/invokeai/frontend/web/src/features/regionalPrompts/util/renderers.ts +++ b/invokeai/frontend/web/src/features/regionalPrompts/util/renderers.ts @@ -494,35 +494,38 @@ const renderBbox = ( } for (const reduxLayer of reduxLayers) { - const konvaLayer = stage.findOne(`#${reduxLayer.id}`); - assert(konvaLayer, `Layer ${reduxLayer.id} not found in stage`); + if (reduxLayer.type === 'vector_mask_layer') { + const konvaLayer = stage.findOne(`#${reduxLayer.id}`); + assert(konvaLayer, `Layer ${reduxLayer.id} not found in stage`); - let bbox = reduxLayer.bbox; + let bbox = reduxLayer.bbox; - // We only need to recalculate the bbox if the layer has changed and it has objects - if (reduxLayer.bboxNeedsUpdate && reduxLayer.objects.length) { - // We only need to use the pixel-perfect bounding box if the layer has eraser strokes - bbox = reduxLayer.needsPixelBbox ? getLayerBboxPixels(konvaLayer) : getLayerBboxFast(konvaLayer); - // Update the layer's bbox in the redux store - onBboxChanged(reduxLayer.id, bbox); + // We only need to recalculate the bbox if the layer has changed and it has objects + if (reduxLayer.bboxNeedsUpdate && reduxLayer.objects.length) { + // We only need to use the pixel-perfect bounding box if the layer has eraser strokes + bbox = reduxLayer.needsPixelBbox ? getLayerBboxPixels(konvaLayer) : getLayerBboxFast(konvaLayer); + // Update the layer's bbox in the redux store + onBboxChanged(reduxLayer.id, bbox); + } + + if (!bbox) { + continue; + } + + const rect = + konvaLayer.findOne(`.${LAYER_BBOX_NAME}`) ?? + createBboxRect(reduxLayer, konvaLayer, onBboxMouseDown); + + rect.setAttrs({ + visible: true, + listening: true, + x: bbox.x, + y: bbox.y, + width: bbox.width, + height: bbox.height, + stroke: reduxLayer.id === selectedLayerId ? BBOX_SELECTED_STROKE : BBOX_NOT_SELECTED_STROKE, + }); } - - if (!bbox) { - continue; - } - - const rect = - konvaLayer.findOne(`.${LAYER_BBOX_NAME}`) ?? createBboxRect(reduxLayer, konvaLayer, onBboxMouseDown); - - rect.setAttrs({ - visible: true, - listening: true, - x: bbox.x, - y: bbox.y, - width: bbox.width, - height: bbox.height, - stroke: reduxLayer.id === selectedLayerId ? BBOX_SELECTED_STROKE : BBOX_NOT_SELECTED_STROKE, - }); } }; diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/ControlSettingsAccordion/ControlSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/ControlSettingsAccordion/ControlSettingsAccordion.tsx index 36448c8909..449091536e 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/ControlSettingsAccordion/ControlSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/ControlSettingsAccordion/ControlSettingsAccordion.tsx @@ -13,7 +13,7 @@ import { selectValidIPAdapters, selectValidT2IAdapters, } from 'features/controlAdapters/store/controlAdaptersSlice'; -import { selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice'; +import { isVectorMaskLayer, selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice'; import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { Fragment, memo } from 'react'; @@ -26,15 +26,17 @@ const selector = createMemoizedSelector( const badges: string[] = []; let isError = false; - const enabledIPAdapterCount = selectAllIPAdapters(controlAdapters) - .filter((ca) => !regionalPrompts.present.layers.some((l) => l.ipAdapterIds.includes(ca.id))) + const enabledNonRegionalIPAdapterCount = selectAllIPAdapters(controlAdapters) + .filter( + (ca) => !regionalPrompts.present.layers.filter(isVectorMaskLayer).some((l) => l.ipAdapterIds.includes(ca.id)) + ) .filter((ca) => ca.isEnabled).length; const validIPAdapterCount = selectValidIPAdapters(controlAdapters).length; - if (enabledIPAdapterCount > 0) { - badges.push(`${enabledIPAdapterCount} IP`); + if (enabledNonRegionalIPAdapterCount > 0) { + badges.push(`${enabledNonRegionalIPAdapterCount} IP`); } - if (enabledIPAdapterCount > validIPAdapterCount) { + if (enabledNonRegionalIPAdapterCount > validIPAdapterCount) { isError = true; } @@ -57,7 +59,7 @@ const selector = createMemoizedSelector( } const controlAdapterIds = selectControlAdapterIds(controlAdapters).filter( - (id) => !regionalPrompts.present.layers.some((l) => l.ipAdapterIds.includes(id)) + (id) => !regionalPrompts.present.layers.filter(isVectorMaskLayer).some((l) => l.ipAdapterIds.includes(id)) ); return {