feat(ui): add 'control_layer' type

This commit is contained in:
psychedelicious 2024-04-24 16:05:34 +10:00 committed by Kent Keirsey
parent d861bc690e
commit c686625076
7 changed files with 105 additions and 71 deletions

View File

@ -2,6 +2,8 @@ import type { RootState } from 'app/store/store';
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { IPAdapterConfig } from 'features/controlAdapters/store/types';
import type { ImageField } from 'features/nodes/types/common';
import { isVectorMaskLayer } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { differenceBy } from 'lodash-es';
import type {
CollectInvocation,
CoreMetadataInvocation,
@ -19,16 +21,21 @@ export const addIPAdapterToLinearGraph = async (
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
const validIPAdapters = selectValidIPAdapters(state.controlAdapters)
.filter(({ model, controlImage, isEnabled }) => {
const hasModel = Boolean(model);
const doesBaseMatch = model?.base === state.generation.model?.base;
const hasControlImage = controlImage;
return isEnabled && hasModel && doesBaseMatch && hasControlImage;
})
.filter((ca) => !state.regionalPrompts.present.layers.some((l) => l.ipAdapterIds.includes(ca.id)));
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(({ model, controlImage, isEnabled }) => {
const hasModel = Boolean(model);
const doesBaseMatch = model?.base === state.generation.model?.base;
const hasControlImage = controlImage;
return isEnabled && hasModel && doesBaseMatch && hasControlImage;
});
if (validIPAdapters.length) {
const regionalIPAdapterIds = state.regionalPrompts.present.layers
.filter(isVectorMaskLayer)
.map((l) => l.ipAdapterIds)
.flat();
const nonRegionalIPAdapters = differenceBy(validIPAdapters, regionalIPAdapterIds, 'id');
if (nonRegionalIPAdapters.length) {
// Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect
const ipAdapterCollectNode: CollectInvocation = {
id: IP_ADAPTER_COLLECT,
@ -46,7 +53,7 @@ export const addIPAdapterToLinearGraph = async (
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];
for (const ipAdapter of validIPAdapters) {
for (const ipAdapter of nonRegionalIPAdapters) {
if (!ipAdapter.model) {
return;
}

View File

@ -2,7 +2,7 @@ import { Flex } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import ControlAdapterConfig from 'features/controlAdapters/components/ControlAdapterConfig';
import { selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { isVectorMaskLayer,selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { memo, useMemo } from 'react';
import { assert } from 'tsafe';
@ -14,7 +14,7 @@ export const RPLayerIPAdapterList = memo(({ layerId }: Props) => {
const selectIPAdapterIds = useMemo(
() =>
createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
const layer = regionalPrompts.present.layers.filter(isVectorMaskLayer).find((l) => l.id === layerId);
assert(layer, `Layer ${layerId} not found`);
return layer.ipAdapterIds;
}),

View File

@ -1,6 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { isVectorMaskLayer, selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@ -9,6 +9,7 @@ const selectValidLayerCount = createSelector(selectRegionalPromptsSlice, (region
return 0;
}
const validLayers = regionalPrompts.present.layers
.filter(isVectorMaskLayer)
.filter((l) => l.isVisible)
.filter((l) => {
const hasTextPrompt = Boolean(l.positivePrompt || l.negativePrompt);

View File

@ -3,6 +3,7 @@ import { createSlice, isAnyOf } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { moveBackward, moveForward, moveToBack, moveToFront } from 'common/util/arrayUtils';
import { controlAdapterRemoved } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlAdapterConfig } from 'features/controlAdapters/store/types';
import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas';
import type { IRect, Vector2d } from 'konva/lib/types';
import { isEqual } from 'lodash-es';
@ -42,6 +43,11 @@ type LayerBase = {
isVisible: boolean;
};
type ControlLayer = LayerBase & {
type: 'control_layer';
controlAdapter: ControlAdapterConfig;
};
type MaskLayerBase = LayerBase & {
positivePrompt: string | null;
negativePrompt: string | null; // Up to one text prompt per mask
@ -56,7 +62,7 @@ export type VectorMaskLayer = MaskLayerBase & {
objects: (VectorMaskLine | VectorMaskRect)[];
};
export type Layer = VectorMaskLayer;
export type Layer = VectorMaskLayer | ControlLayer;
type RegionalPromptsState = {
_version: 1;
@ -78,12 +84,24 @@ export const initialRegionalPromptsState: RegionalPromptsState = {
const isLine = (obj: VectorMaskLine | VectorMaskRect): obj is VectorMaskLine => obj.type === 'vector_mask_line';
export const isVectorMaskLayer = (layer?: Layer): layer is VectorMaskLayer => layer?.type === 'vector_mask_layer';
const resetLayer = (layer: VectorMaskLayer) => {
layer.objects = [];
layer.bbox = null;
layer.isVisible = true;
layer.needsPixelBbox = false;
layer.bboxNeedsUpdate = false;
const resetLayer = (layer: Layer) => {
if (layer.type === 'vector_mask_layer') {
layer.objects = [];
layer.bbox = null;
layer.isVisible = true;
layer.needsPixelBbox = false;
layer.bboxNeedsUpdate = false;
return;
}
if (layer.type === 'control_layer') {
// TODO
}
};
const getVectorMaskPreviewColor = (state: RegionalPromptsState): RgbColor => {
const vmLayers = state.layers.filter(isVectorMaskLayer);
const lastColor = vmLayers[vmLayers.length - 1]?.previewColor;
return LayerColors.next(lastColor);
};
export const regionalPromptsSlice = createSlice({
@ -93,18 +111,16 @@ export const regionalPromptsSlice = createSlice({
//#region All Layers
layerAdded: {
reducer: (state, action: PayloadAction<Layer['type'], string, { uuid: string }>) => {
const kind = action.payload;
if (action.payload === 'vector_mask_layer') {
const lastColor = state.layers[state.layers.length - 1]?.previewColor;
const previewColor = LayerColors.next(lastColor);
const type = action.payload;
if (type === 'vector_mask_layer') {
const layer: VectorMaskLayer = {
id: getVectorMaskLayerId(action.meta.uuid),
type: kind,
type,
isVisible: true,
bbox: null,
bboxNeedsUpdate: false,
objects: [],
previewColor,
previewColor: getVectorMaskPreviewColor(state),
x: 0,
y: 0,
autoNegative: 'invert',
@ -117,6 +133,11 @@ export const regionalPromptsSlice = createSlice({
state.selectedLayerId = layer.id;
return;
}
if (type === 'control_layer') {
// TODO
return;
}
},
prepare: (payload: Layer['type']) => ({ payload, meta: { uuid: uuidv4() } }),
},
@ -196,21 +217,21 @@ export const regionalPromptsSlice = createSlice({
maskLayerPositivePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => {
const { layerId, prompt } = action.payload;
const layer = state.layers.find((l) => l.id === layerId);
if (layer) {
if (layer?.type === 'vector_mask_layer') {
layer.positivePrompt = prompt;
}
},
maskLayerNegativePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => {
const { layerId, prompt } = action.payload;
const layer = state.layers.find((l) => l.id === layerId);
if (layer) {
if (layer?.type === 'vector_mask_layer') {
layer.negativePrompt = prompt;
}
},
maskLayerIPAdapterAdded: {
reducer: (state, action: PayloadAction<string, string, { uuid: string }>) => {
const layer = state.layers.find((l) => l.id === action.payload);
if (layer) {
if (layer?.type === 'vector_mask_layer') {
layer.ipAdapterIds.push(action.meta.uuid);
}
},
@ -219,7 +240,7 @@ export const regionalPromptsSlice = createSlice({
maskLayerPreviewColorChanged: (state, action: PayloadAction<{ layerId: string; color: RgbColor }>) => {
const { layerId, color } = action.payload;
const layer = state.layers.find((l) => l.id === layerId);
if (layer) {
if (layer?.type === 'vector_mask_layer') {
layer.previewColor = color;
}
},
@ -234,7 +255,7 @@ export const regionalPromptsSlice = createSlice({
) => {
const { layerId, points, tool } = action.payload;
const layer = state.layers.find((l) => l.id === layerId);
if (layer) {
if (layer?.type === 'vector_mask_layer') {
const lineId = getVectorMaskLayerLineId(layer.id, action.meta.uuid);
layer.objects.push({
type: 'vector_mask_line',
@ -259,7 +280,7 @@ export const regionalPromptsSlice = createSlice({
maskLayerPointsAdded: (state, action: PayloadAction<{ layerId: string; point: [number, number] }>) => {
const { layerId, point } = action.payload;
const layer = state.layers.find((l) => l.id === layerId);
if (layer) {
if (layer?.type === 'vector_mask_layer') {
const lastLine = layer.objects.findLast(isLine);
if (!lastLine) {
return;
@ -278,7 +299,7 @@ export const regionalPromptsSlice = createSlice({
return;
}
const layer = state.layers.find((l) => l.id === layerId);
if (layer) {
if (layer?.type === 'vector_mask_layer') {
const id = getVectorMaskLayerRectId(layer.id, action.meta.uuid);
layer.objects.push({
type: 'vector_mask_rect',
@ -299,7 +320,7 @@ export const regionalPromptsSlice = createSlice({
) => {
const { layerId, autoNegative } = action.payload;
const layer = state.layers.find((l) => l.id === layerId);
if (layer) {
if (layer?.type === 'vector_mask_layer') {
layer.autoNegative = autoNegative;
}
},
@ -331,9 +352,9 @@ export const regionalPromptsSlice = createSlice({
},
extraReducers(builder) {
builder.addCase(controlAdapterRemoved, (state, action) => {
for (const layer of state.layers) {
state.layers.filter(isVectorMaskLayer).forEach((layer) => {
layer.ipAdapterIds = layer.ipAdapterIds.filter((id) => id !== action.payload.id);
}
});
});
},
});

View File

@ -1,7 +1,7 @@
import { getStore } from 'app/store/nanostores/store';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
import { VECTOR_MASK_LAYER_NAME } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { isVectorMaskLayer, VECTOR_MASK_LAYER_NAME } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { renderers } from 'features/regionalPrompts/util/renderers';
import Konva from 'konva';
import { assert } from 'tsafe';
@ -17,7 +17,7 @@ export const getRegionalPromptLayerBlobs = async (
preview: boolean = false
): Promise<Record<string, Blob>> => {
const state = getStore().getState();
const reduxLayers = state.regionalPrompts.present.layers;
const reduxLayers = state.regionalPrompts.present.layers.filter(isVectorMaskLayer);
const container = document.createElement('div');
const stage = new Konva.Stage({ container, width: state.generation.width, height: state.generation.height });
renderers.renderLayers(stage, reduxLayers, 1, 'brush');

View File

@ -494,35 +494,38 @@ const renderBbox = (
}
for (const reduxLayer of reduxLayers) {
const konvaLayer = stage.findOne<Konva.Layer>(`#${reduxLayer.id}`);
assert(konvaLayer, `Layer ${reduxLayer.id} not found in stage`);
if (reduxLayer.type === 'vector_mask_layer') {
const konvaLayer = stage.findOne<Konva.Layer>(`#${reduxLayer.id}`);
assert(konvaLayer, `Layer ${reduxLayer.id} not found in stage`);
let bbox = reduxLayer.bbox;
let bbox = reduxLayer.bbox;
// We only need to recalculate the bbox if the layer has changed and it has objects
if (reduxLayer.bboxNeedsUpdate && reduxLayer.objects.length) {
// We only need to use the pixel-perfect bounding box if the layer has eraser strokes
bbox = reduxLayer.needsPixelBbox ? getLayerBboxPixels(konvaLayer) : getLayerBboxFast(konvaLayer);
// Update the layer's bbox in the redux store
onBboxChanged(reduxLayer.id, bbox);
// We only need to recalculate the bbox if the layer has changed and it has objects
if (reduxLayer.bboxNeedsUpdate && reduxLayer.objects.length) {
// We only need to use the pixel-perfect bounding box if the layer has eraser strokes
bbox = reduxLayer.needsPixelBbox ? getLayerBboxPixels(konvaLayer) : getLayerBboxFast(konvaLayer);
// Update the layer's bbox in the redux store
onBboxChanged(reduxLayer.id, bbox);
}
if (!bbox) {
continue;
}
const rect =
konvaLayer.findOne<Konva.Rect>(`.${LAYER_BBOX_NAME}`) ??
createBboxRect(reduxLayer, konvaLayer, onBboxMouseDown);
rect.setAttrs({
visible: true,
listening: true,
x: bbox.x,
y: bbox.y,
width: bbox.width,
height: bbox.height,
stroke: reduxLayer.id === selectedLayerId ? BBOX_SELECTED_STROKE : BBOX_NOT_SELECTED_STROKE,
});
}
if (!bbox) {
continue;
}
const rect =
konvaLayer.findOne<Konva.Rect>(`.${LAYER_BBOX_NAME}`) ?? createBboxRect(reduxLayer, konvaLayer, onBboxMouseDown);
rect.setAttrs({
visible: true,
listening: true,
x: bbox.x,
y: bbox.y,
width: bbox.width,
height: bbox.height,
stroke: reduxLayer.id === selectedLayerId ? BBOX_SELECTED_STROKE : BBOX_NOT_SELECTED_STROKE,
});
}
};

View File

@ -13,7 +13,7 @@ import {
selectValidIPAdapters,
selectValidT2IAdapters,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { isVectorMaskLayer, selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { Fragment, memo } from 'react';
@ -26,15 +26,17 @@ const selector = createMemoizedSelector(
const badges: string[] = [];
let isError = false;
const enabledIPAdapterCount = selectAllIPAdapters(controlAdapters)
.filter((ca) => !regionalPrompts.present.layers.some((l) => l.ipAdapterIds.includes(ca.id)))
const enabledNonRegionalIPAdapterCount = selectAllIPAdapters(controlAdapters)
.filter(
(ca) => !regionalPrompts.present.layers.filter(isVectorMaskLayer).some((l) => l.ipAdapterIds.includes(ca.id))
)
.filter((ca) => ca.isEnabled).length;
const validIPAdapterCount = selectValidIPAdapters(controlAdapters).length;
if (enabledIPAdapterCount > 0) {
badges.push(`${enabledIPAdapterCount} IP`);
if (enabledNonRegionalIPAdapterCount > 0) {
badges.push(`${enabledNonRegionalIPAdapterCount} IP`);
}
if (enabledIPAdapterCount > validIPAdapterCount) {
if (enabledNonRegionalIPAdapterCount > validIPAdapterCount) {
isError = true;
}
@ -57,7 +59,7 @@ const selector = createMemoizedSelector(
}
const controlAdapterIds = selectControlAdapterIds(controlAdapters).filter(
(id) => !regionalPrompts.present.layers.some((l) => l.ipAdapterIds.includes(id))
(id) => !regionalPrompts.present.layers.filter(isVectorMaskLayer).some((l) => l.ipAdapterIds.includes(id))
);
return {