feat(ui): add 'control_layer' type

This commit is contained in:
psychedelicious 2024-04-24 16:05:34 +10:00 committed by Kent Keirsey
parent d861bc690e
commit c686625076
7 changed files with 105 additions and 71 deletions

View File

@ -2,6 +2,8 @@ import type { RootState } from 'app/store/store';
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice'; import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { IPAdapterConfig } from 'features/controlAdapters/store/types'; import type { IPAdapterConfig } from 'features/controlAdapters/store/types';
import type { ImageField } from 'features/nodes/types/common'; import type { ImageField } from 'features/nodes/types/common';
import { isVectorMaskLayer } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { differenceBy } from 'lodash-es';
import type { import type {
CollectInvocation, CollectInvocation,
CoreMetadataInvocation, CoreMetadataInvocation,
@ -19,16 +21,21 @@ export const addIPAdapterToLinearGraph = async (
graph: NonNullableGraph, graph: NonNullableGraph,
baseNodeId: string baseNodeId: string
): Promise<void> => { ): Promise<void> => {
const validIPAdapters = selectValidIPAdapters(state.controlAdapters) const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(({ model, controlImage, isEnabled }) => {
.filter(({ model, controlImage, isEnabled }) => { const hasModel = Boolean(model);
const hasModel = Boolean(model); const doesBaseMatch = model?.base === state.generation.model?.base;
const doesBaseMatch = model?.base === state.generation.model?.base; const hasControlImage = controlImage;
const hasControlImage = controlImage; return isEnabled && hasModel && doesBaseMatch && hasControlImage;
return isEnabled && hasModel && doesBaseMatch && hasControlImage; });
})
.filter((ca) => !state.regionalPrompts.present.layers.some((l) => l.ipAdapterIds.includes(ca.id)));
if (validIPAdapters.length) { const regionalIPAdapterIds = state.regionalPrompts.present.layers
.filter(isVectorMaskLayer)
.map((l) => l.ipAdapterIds)
.flat();
const nonRegionalIPAdapters = differenceBy(validIPAdapters, regionalIPAdapterIds, 'id');
if (nonRegionalIPAdapters.length) {
// Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect // Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect
const ipAdapterCollectNode: CollectInvocation = { const ipAdapterCollectNode: CollectInvocation = {
id: IP_ADAPTER_COLLECT, id: IP_ADAPTER_COLLECT,
@ -46,7 +53,7 @@ export const addIPAdapterToLinearGraph = async (
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = []; const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];
for (const ipAdapter of validIPAdapters) { for (const ipAdapter of nonRegionalIPAdapters) {
if (!ipAdapter.model) { if (!ipAdapter.model) {
return; return;
} }

View File

@ -2,7 +2,7 @@ import { Flex } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import ControlAdapterConfig from 'features/controlAdapters/components/ControlAdapterConfig'; import ControlAdapterConfig from 'features/controlAdapters/components/ControlAdapterConfig';
import { selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice'; import { isVectorMaskLayer,selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { memo, useMemo } from 'react'; import { memo, useMemo } from 'react';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
@ -14,7 +14,7 @@ export const RPLayerIPAdapterList = memo(({ layerId }: Props) => {
const selectIPAdapterIds = useMemo( const selectIPAdapterIds = useMemo(
() => () =>
createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => { createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId); const layer = regionalPrompts.present.layers.filter(isVectorMaskLayer).find((l) => l.id === layerId);
assert(layer, `Layer ${layerId} not found`); assert(layer, `Layer ${layerId} not found`);
return layer.ipAdapterIds; return layer.ipAdapterIds;
}), }),

View File

@ -1,6 +1,6 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice'; import { isVectorMaskLayer, selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { useMemo } from 'react'; import { useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -9,6 +9,7 @@ const selectValidLayerCount = createSelector(selectRegionalPromptsSlice, (region
return 0; return 0;
} }
const validLayers = regionalPrompts.present.layers const validLayers = regionalPrompts.present.layers
.filter(isVectorMaskLayer)
.filter((l) => l.isVisible) .filter((l) => l.isVisible)
.filter((l) => { .filter((l) => {
const hasTextPrompt = Boolean(l.positivePrompt || l.negativePrompt); const hasTextPrompt = Boolean(l.positivePrompt || l.negativePrompt);

View File

@ -3,6 +3,7 @@ import { createSlice, isAnyOf } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { PersistConfig, RootState } from 'app/store/store';
import { moveBackward, moveForward, moveToBack, moveToFront } from 'common/util/arrayUtils'; import { moveBackward, moveForward, moveToBack, moveToFront } from 'common/util/arrayUtils';
import { controlAdapterRemoved } from 'features/controlAdapters/store/controlAdaptersSlice'; import { controlAdapterRemoved } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlAdapterConfig } from 'features/controlAdapters/store/types';
import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas'; import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas';
import type { IRect, Vector2d } from 'konva/lib/types'; import type { IRect, Vector2d } from 'konva/lib/types';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
@ -42,6 +43,11 @@ type LayerBase = {
isVisible: boolean; isVisible: boolean;
}; };
type ControlLayer = LayerBase & {
type: 'control_layer';
controlAdapter: ControlAdapterConfig;
};
type MaskLayerBase = LayerBase & { type MaskLayerBase = LayerBase & {
positivePrompt: string | null; positivePrompt: string | null;
negativePrompt: string | null; // Up to one text prompt per mask negativePrompt: string | null; // Up to one text prompt per mask
@ -56,7 +62,7 @@ export type VectorMaskLayer = MaskLayerBase & {
objects: (VectorMaskLine | VectorMaskRect)[]; objects: (VectorMaskLine | VectorMaskRect)[];
}; };
export type Layer = VectorMaskLayer; export type Layer = VectorMaskLayer | ControlLayer;
type RegionalPromptsState = { type RegionalPromptsState = {
_version: 1; _version: 1;
@ -78,12 +84,24 @@ export const initialRegionalPromptsState: RegionalPromptsState = {
const isLine = (obj: VectorMaskLine | VectorMaskRect): obj is VectorMaskLine => obj.type === 'vector_mask_line'; const isLine = (obj: VectorMaskLine | VectorMaskRect): obj is VectorMaskLine => obj.type === 'vector_mask_line';
export const isVectorMaskLayer = (layer?: Layer): layer is VectorMaskLayer => layer?.type === 'vector_mask_layer'; export const isVectorMaskLayer = (layer?: Layer): layer is VectorMaskLayer => layer?.type === 'vector_mask_layer';
const resetLayer = (layer: VectorMaskLayer) => { const resetLayer = (layer: Layer) => {
layer.objects = []; if (layer.type === 'vector_mask_layer') {
layer.bbox = null; layer.objects = [];
layer.isVisible = true; layer.bbox = null;
layer.needsPixelBbox = false; layer.isVisible = true;
layer.bboxNeedsUpdate = false; layer.needsPixelBbox = false;
layer.bboxNeedsUpdate = false;
return;
}
if (layer.type === 'control_layer') {
// TODO
}
};
const getVectorMaskPreviewColor = (state: RegionalPromptsState): RgbColor => {
const vmLayers = state.layers.filter(isVectorMaskLayer);
const lastColor = vmLayers[vmLayers.length - 1]?.previewColor;
return LayerColors.next(lastColor);
}; };
export const regionalPromptsSlice = createSlice({ export const regionalPromptsSlice = createSlice({
@ -93,18 +111,16 @@ export const regionalPromptsSlice = createSlice({
//#region All Layers //#region All Layers
layerAdded: { layerAdded: {
reducer: (state, action: PayloadAction<Layer['type'], string, { uuid: string }>) => { reducer: (state, action: PayloadAction<Layer['type'], string, { uuid: string }>) => {
const kind = action.payload; const type = action.payload;
if (action.payload === 'vector_mask_layer') { if (type === 'vector_mask_layer') {
const lastColor = state.layers[state.layers.length - 1]?.previewColor;
const previewColor = LayerColors.next(lastColor);
const layer: VectorMaskLayer = { const layer: VectorMaskLayer = {
id: getVectorMaskLayerId(action.meta.uuid), id: getVectorMaskLayerId(action.meta.uuid),
type: kind, type,
isVisible: true, isVisible: true,
bbox: null, bbox: null,
bboxNeedsUpdate: false, bboxNeedsUpdate: false,
objects: [], objects: [],
previewColor, previewColor: getVectorMaskPreviewColor(state),
x: 0, x: 0,
y: 0, y: 0,
autoNegative: 'invert', autoNegative: 'invert',
@ -117,6 +133,11 @@ export const regionalPromptsSlice = createSlice({
state.selectedLayerId = layer.id; state.selectedLayerId = layer.id;
return; return;
} }
if (type === 'control_layer') {
// TODO
return;
}
}, },
prepare: (payload: Layer['type']) => ({ payload, meta: { uuid: uuidv4() } }), prepare: (payload: Layer['type']) => ({ payload, meta: { uuid: uuidv4() } }),
}, },
@ -196,21 +217,21 @@ export const regionalPromptsSlice = createSlice({
maskLayerPositivePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => { maskLayerPositivePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => {
const { layerId, prompt } = action.payload; const { layerId, prompt } = action.payload;
const layer = state.layers.find((l) => l.id === layerId); const layer = state.layers.find((l) => l.id === layerId);
if (layer) { if (layer?.type === 'vector_mask_layer') {
layer.positivePrompt = prompt; layer.positivePrompt = prompt;
} }
}, },
maskLayerNegativePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => { maskLayerNegativePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => {
const { layerId, prompt } = action.payload; const { layerId, prompt } = action.payload;
const layer = state.layers.find((l) => l.id === layerId); const layer = state.layers.find((l) => l.id === layerId);
if (layer) { if (layer?.type === 'vector_mask_layer') {
layer.negativePrompt = prompt; layer.negativePrompt = prompt;
} }
}, },
maskLayerIPAdapterAdded: { maskLayerIPAdapterAdded: {
reducer: (state, action: PayloadAction<string, string, { uuid: string }>) => { reducer: (state, action: PayloadAction<string, string, { uuid: string }>) => {
const layer = state.layers.find((l) => l.id === action.payload); const layer = state.layers.find((l) => l.id === action.payload);
if (layer) { if (layer?.type === 'vector_mask_layer') {
layer.ipAdapterIds.push(action.meta.uuid); layer.ipAdapterIds.push(action.meta.uuid);
} }
}, },
@ -219,7 +240,7 @@ export const regionalPromptsSlice = createSlice({
maskLayerPreviewColorChanged: (state, action: PayloadAction<{ layerId: string; color: RgbColor }>) => { maskLayerPreviewColorChanged: (state, action: PayloadAction<{ layerId: string; color: RgbColor }>) => {
const { layerId, color } = action.payload; const { layerId, color } = action.payload;
const layer = state.layers.find((l) => l.id === layerId); const layer = state.layers.find((l) => l.id === layerId);
if (layer) { if (layer?.type === 'vector_mask_layer') {
layer.previewColor = color; layer.previewColor = color;
} }
}, },
@ -234,7 +255,7 @@ export const regionalPromptsSlice = createSlice({
) => { ) => {
const { layerId, points, tool } = action.payload; const { layerId, points, tool } = action.payload;
const layer = state.layers.find((l) => l.id === layerId); const layer = state.layers.find((l) => l.id === layerId);
if (layer) { if (layer?.type === 'vector_mask_layer') {
const lineId = getVectorMaskLayerLineId(layer.id, action.meta.uuid); const lineId = getVectorMaskLayerLineId(layer.id, action.meta.uuid);
layer.objects.push({ layer.objects.push({
type: 'vector_mask_line', type: 'vector_mask_line',
@ -259,7 +280,7 @@ export const regionalPromptsSlice = createSlice({
maskLayerPointsAdded: (state, action: PayloadAction<{ layerId: string; point: [number, number] }>) => { maskLayerPointsAdded: (state, action: PayloadAction<{ layerId: string; point: [number, number] }>) => {
const { layerId, point } = action.payload; const { layerId, point } = action.payload;
const layer = state.layers.find((l) => l.id === layerId); const layer = state.layers.find((l) => l.id === layerId);
if (layer) { if (layer?.type === 'vector_mask_layer') {
const lastLine = layer.objects.findLast(isLine); const lastLine = layer.objects.findLast(isLine);
if (!lastLine) { if (!lastLine) {
return; return;
@ -278,7 +299,7 @@ export const regionalPromptsSlice = createSlice({
return; return;
} }
const layer = state.layers.find((l) => l.id === layerId); const layer = state.layers.find((l) => l.id === layerId);
if (layer) { if (layer?.type === 'vector_mask_layer') {
const id = getVectorMaskLayerRectId(layer.id, action.meta.uuid); const id = getVectorMaskLayerRectId(layer.id, action.meta.uuid);
layer.objects.push({ layer.objects.push({
type: 'vector_mask_rect', type: 'vector_mask_rect',
@ -299,7 +320,7 @@ export const regionalPromptsSlice = createSlice({
) => { ) => {
const { layerId, autoNegative } = action.payload; const { layerId, autoNegative } = action.payload;
const layer = state.layers.find((l) => l.id === layerId); const layer = state.layers.find((l) => l.id === layerId);
if (layer) { if (layer?.type === 'vector_mask_layer') {
layer.autoNegative = autoNegative; layer.autoNegative = autoNegative;
} }
}, },
@ -331,9 +352,9 @@ export const regionalPromptsSlice = createSlice({
}, },
extraReducers(builder) { extraReducers(builder) {
builder.addCase(controlAdapterRemoved, (state, action) => { builder.addCase(controlAdapterRemoved, (state, action) => {
for (const layer of state.layers) { state.layers.filter(isVectorMaskLayer).forEach((layer) => {
layer.ipAdapterIds = layer.ipAdapterIds.filter((id) => id !== action.payload.id); layer.ipAdapterIds = layer.ipAdapterIds.filter((id) => id !== action.payload.id);
} });
}); });
}, },
}); });

View File

@ -1,7 +1,7 @@
import { getStore } from 'app/store/nanostores/store'; import { getStore } from 'app/store/nanostores/store';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab'; import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { blobToDataURL } from 'features/canvas/util/blobToDataURL'; import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
import { VECTOR_MASK_LAYER_NAME } from 'features/regionalPrompts/store/regionalPromptsSlice'; import { isVectorMaskLayer, VECTOR_MASK_LAYER_NAME } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { renderers } from 'features/regionalPrompts/util/renderers'; import { renderers } from 'features/regionalPrompts/util/renderers';
import Konva from 'konva'; import Konva from 'konva';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
@ -17,7 +17,7 @@ export const getRegionalPromptLayerBlobs = async (
preview: boolean = false preview: boolean = false
): Promise<Record<string, Blob>> => { ): Promise<Record<string, Blob>> => {
const state = getStore().getState(); const state = getStore().getState();
const reduxLayers = state.regionalPrompts.present.layers; const reduxLayers = state.regionalPrompts.present.layers.filter(isVectorMaskLayer);
const container = document.createElement('div'); const container = document.createElement('div');
const stage = new Konva.Stage({ container, width: state.generation.width, height: state.generation.height }); const stage = new Konva.Stage({ container, width: state.generation.width, height: state.generation.height });
renderers.renderLayers(stage, reduxLayers, 1, 'brush'); renderers.renderLayers(stage, reduxLayers, 1, 'brush');

View File

@ -494,35 +494,38 @@ const renderBbox = (
} }
for (const reduxLayer of reduxLayers) { for (const reduxLayer of reduxLayers) {
const konvaLayer = stage.findOne<Konva.Layer>(`#${reduxLayer.id}`); if (reduxLayer.type === 'vector_mask_layer') {
assert(konvaLayer, `Layer ${reduxLayer.id} not found in stage`); const konvaLayer = stage.findOne<Konva.Layer>(`#${reduxLayer.id}`);
assert(konvaLayer, `Layer ${reduxLayer.id} not found in stage`);
let bbox = reduxLayer.bbox; let bbox = reduxLayer.bbox;
// We only need to recalculate the bbox if the layer has changed and it has objects // We only need to recalculate the bbox if the layer has changed and it has objects
if (reduxLayer.bboxNeedsUpdate && reduxLayer.objects.length) { if (reduxLayer.bboxNeedsUpdate && reduxLayer.objects.length) {
// We only need to use the pixel-perfect bounding box if the layer has eraser strokes // We only need to use the pixel-perfect bounding box if the layer has eraser strokes
bbox = reduxLayer.needsPixelBbox ? getLayerBboxPixels(konvaLayer) : getLayerBboxFast(konvaLayer); bbox = reduxLayer.needsPixelBbox ? getLayerBboxPixels(konvaLayer) : getLayerBboxFast(konvaLayer);
// Update the layer's bbox in the redux store // Update the layer's bbox in the redux store
onBboxChanged(reduxLayer.id, bbox); onBboxChanged(reduxLayer.id, bbox);
}
if (!bbox) {
continue;
}
const rect =
konvaLayer.findOne<Konva.Rect>(`.${LAYER_BBOX_NAME}`) ??
createBboxRect(reduxLayer, konvaLayer, onBboxMouseDown);
rect.setAttrs({
visible: true,
listening: true,
x: bbox.x,
y: bbox.y,
width: bbox.width,
height: bbox.height,
stroke: reduxLayer.id === selectedLayerId ? BBOX_SELECTED_STROKE : BBOX_NOT_SELECTED_STROKE,
});
} }
if (!bbox) {
continue;
}
const rect =
konvaLayer.findOne<Konva.Rect>(`.${LAYER_BBOX_NAME}`) ?? createBboxRect(reduxLayer, konvaLayer, onBboxMouseDown);
rect.setAttrs({
visible: true,
listening: true,
x: bbox.x,
y: bbox.y,
width: bbox.width,
height: bbox.height,
stroke: reduxLayer.id === selectedLayerId ? BBOX_SELECTED_STROKE : BBOX_NOT_SELECTED_STROKE,
});
} }
}; };

View File

@ -13,7 +13,7 @@ import {
selectValidIPAdapters, selectValidIPAdapters,
selectValidT2IAdapters, selectValidT2IAdapters,
} from 'features/controlAdapters/store/controlAdaptersSlice'; } from 'features/controlAdapters/store/controlAdaptersSlice';
import { selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice'; import { isVectorMaskLayer, selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle'; import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { Fragment, memo } from 'react'; import { Fragment, memo } from 'react';
@ -26,15 +26,17 @@ const selector = createMemoizedSelector(
const badges: string[] = []; const badges: string[] = [];
let isError = false; let isError = false;
const enabledIPAdapterCount = selectAllIPAdapters(controlAdapters) const enabledNonRegionalIPAdapterCount = selectAllIPAdapters(controlAdapters)
.filter((ca) => !regionalPrompts.present.layers.some((l) => l.ipAdapterIds.includes(ca.id))) .filter(
(ca) => !regionalPrompts.present.layers.filter(isVectorMaskLayer).some((l) => l.ipAdapterIds.includes(ca.id))
)
.filter((ca) => ca.isEnabled).length; .filter((ca) => ca.isEnabled).length;
const validIPAdapterCount = selectValidIPAdapters(controlAdapters).length; const validIPAdapterCount = selectValidIPAdapters(controlAdapters).length;
if (enabledIPAdapterCount > 0) { if (enabledNonRegionalIPAdapterCount > 0) {
badges.push(`${enabledIPAdapterCount} IP`); badges.push(`${enabledNonRegionalIPAdapterCount} IP`);
} }
if (enabledIPAdapterCount > validIPAdapterCount) { if (enabledNonRegionalIPAdapterCount > validIPAdapterCount) {
isError = true; isError = true;
} }
@ -57,7 +59,7 @@ const selector = createMemoizedSelector(
} }
const controlAdapterIds = selectControlAdapterIds(controlAdapters).filter( const controlAdapterIds = selectControlAdapterIds(controlAdapters).filter(
(id) => !regionalPrompts.present.layers.some((l) => l.ipAdapterIds.includes(id)) (id) => !regionalPrompts.present.layers.filter(isVectorMaskLayer).some((l) => l.ipAdapterIds.includes(id))
); );
return { return {