mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): add 'control_layer' type
This commit is contained in:
parent
d861bc690e
commit
c686625076
@ -2,6 +2,8 @@ import type { RootState } from 'app/store/store';
|
||||
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import type { IPAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import type { ImageField } from 'features/nodes/types/common';
|
||||
import { isVectorMaskLayer } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { differenceBy } from 'lodash-es';
|
||||
import type {
|
||||
CollectInvocation,
|
||||
CoreMetadataInvocation,
|
||||
@ -19,16 +21,21 @@ export const addIPAdapterToLinearGraph = async (
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string
|
||||
): Promise<void> => {
|
||||
const validIPAdapters = selectValidIPAdapters(state.controlAdapters)
|
||||
.filter(({ model, controlImage, isEnabled }) => {
|
||||
const hasModel = Boolean(model);
|
||||
const doesBaseMatch = model?.base === state.generation.model?.base;
|
||||
const hasControlImage = controlImage;
|
||||
return isEnabled && hasModel && doesBaseMatch && hasControlImage;
|
||||
})
|
||||
.filter((ca) => !state.regionalPrompts.present.layers.some((l) => l.ipAdapterIds.includes(ca.id)));
|
||||
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(({ model, controlImage, isEnabled }) => {
|
||||
const hasModel = Boolean(model);
|
||||
const doesBaseMatch = model?.base === state.generation.model?.base;
|
||||
const hasControlImage = controlImage;
|
||||
return isEnabled && hasModel && doesBaseMatch && hasControlImage;
|
||||
});
|
||||
|
||||
if (validIPAdapters.length) {
|
||||
const regionalIPAdapterIds = state.regionalPrompts.present.layers
|
||||
.filter(isVectorMaskLayer)
|
||||
.map((l) => l.ipAdapterIds)
|
||||
.flat();
|
||||
|
||||
const nonRegionalIPAdapters = differenceBy(validIPAdapters, regionalIPAdapterIds, 'id');
|
||||
|
||||
if (nonRegionalIPAdapters.length) {
|
||||
// Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect
|
||||
const ipAdapterCollectNode: CollectInvocation = {
|
||||
id: IP_ADAPTER_COLLECT,
|
||||
@ -46,7 +53,7 @@ export const addIPAdapterToLinearGraph = async (
|
||||
|
||||
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];
|
||||
|
||||
for (const ipAdapter of validIPAdapters) {
|
||||
for (const ipAdapter of nonRegionalIPAdapters) {
|
||||
if (!ipAdapter.model) {
|
||||
return;
|
||||
}
|
||||
|
@ -2,7 +2,7 @@ import { Flex } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import ControlAdapterConfig from 'features/controlAdapters/components/ControlAdapterConfig';
|
||||
import { selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { isVectorMaskLayer,selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
@ -14,7 +14,7 @@ export const RPLayerIPAdapterList = memo(({ layerId }: Props) => {
|
||||
const selectIPAdapterIds = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
|
||||
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
|
||||
const layer = regionalPrompts.present.layers.filter(isVectorMaskLayer).find((l) => l.id === layerId);
|
||||
assert(layer, `Layer ${layerId} not found`);
|
||||
return layer.ipAdapterIds;
|
||||
}),
|
||||
|
@ -1,6 +1,6 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { isVectorMaskLayer, selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
@ -9,6 +9,7 @@ const selectValidLayerCount = createSelector(selectRegionalPromptsSlice, (region
|
||||
return 0;
|
||||
}
|
||||
const validLayers = regionalPrompts.present.layers
|
||||
.filter(isVectorMaskLayer)
|
||||
.filter((l) => l.isVisible)
|
||||
.filter((l) => {
|
||||
const hasTextPrompt = Boolean(l.positivePrompt || l.negativePrompt);
|
||||
|
@ -3,6 +3,7 @@ import { createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { moveBackward, moveForward, moveToBack, moveToFront } from 'common/util/arrayUtils';
|
||||
import { controlAdapterRemoved } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import type { ControlAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas';
|
||||
import type { IRect, Vector2d } from 'konva/lib/types';
|
||||
import { isEqual } from 'lodash-es';
|
||||
@ -42,6 +43,11 @@ type LayerBase = {
|
||||
isVisible: boolean;
|
||||
};
|
||||
|
||||
type ControlLayer = LayerBase & {
|
||||
type: 'control_layer';
|
||||
controlAdapter: ControlAdapterConfig;
|
||||
};
|
||||
|
||||
type MaskLayerBase = LayerBase & {
|
||||
positivePrompt: string | null;
|
||||
negativePrompt: string | null; // Up to one text prompt per mask
|
||||
@ -56,7 +62,7 @@ export type VectorMaskLayer = MaskLayerBase & {
|
||||
objects: (VectorMaskLine | VectorMaskRect)[];
|
||||
};
|
||||
|
||||
export type Layer = VectorMaskLayer;
|
||||
export type Layer = VectorMaskLayer | ControlLayer;
|
||||
|
||||
type RegionalPromptsState = {
|
||||
_version: 1;
|
||||
@ -78,12 +84,24 @@ export const initialRegionalPromptsState: RegionalPromptsState = {
|
||||
|
||||
const isLine = (obj: VectorMaskLine | VectorMaskRect): obj is VectorMaskLine => obj.type === 'vector_mask_line';
|
||||
export const isVectorMaskLayer = (layer?: Layer): layer is VectorMaskLayer => layer?.type === 'vector_mask_layer';
|
||||
const resetLayer = (layer: VectorMaskLayer) => {
|
||||
layer.objects = [];
|
||||
layer.bbox = null;
|
||||
layer.isVisible = true;
|
||||
layer.needsPixelBbox = false;
|
||||
layer.bboxNeedsUpdate = false;
|
||||
const resetLayer = (layer: Layer) => {
|
||||
if (layer.type === 'vector_mask_layer') {
|
||||
layer.objects = [];
|
||||
layer.bbox = null;
|
||||
layer.isVisible = true;
|
||||
layer.needsPixelBbox = false;
|
||||
layer.bboxNeedsUpdate = false;
|
||||
return;
|
||||
}
|
||||
|
||||
if (layer.type === 'control_layer') {
|
||||
// TODO
|
||||
}
|
||||
};
|
||||
const getVectorMaskPreviewColor = (state: RegionalPromptsState): RgbColor => {
|
||||
const vmLayers = state.layers.filter(isVectorMaskLayer);
|
||||
const lastColor = vmLayers[vmLayers.length - 1]?.previewColor;
|
||||
return LayerColors.next(lastColor);
|
||||
};
|
||||
|
||||
export const regionalPromptsSlice = createSlice({
|
||||
@ -93,18 +111,16 @@ export const regionalPromptsSlice = createSlice({
|
||||
//#region All Layers
|
||||
layerAdded: {
|
||||
reducer: (state, action: PayloadAction<Layer['type'], string, { uuid: string }>) => {
|
||||
const kind = action.payload;
|
||||
if (action.payload === 'vector_mask_layer') {
|
||||
const lastColor = state.layers[state.layers.length - 1]?.previewColor;
|
||||
const previewColor = LayerColors.next(lastColor);
|
||||
const type = action.payload;
|
||||
if (type === 'vector_mask_layer') {
|
||||
const layer: VectorMaskLayer = {
|
||||
id: getVectorMaskLayerId(action.meta.uuid),
|
||||
type: kind,
|
||||
type,
|
||||
isVisible: true,
|
||||
bbox: null,
|
||||
bboxNeedsUpdate: false,
|
||||
objects: [],
|
||||
previewColor,
|
||||
previewColor: getVectorMaskPreviewColor(state),
|
||||
x: 0,
|
||||
y: 0,
|
||||
autoNegative: 'invert',
|
||||
@ -117,6 +133,11 @@ export const regionalPromptsSlice = createSlice({
|
||||
state.selectedLayerId = layer.id;
|
||||
return;
|
||||
}
|
||||
|
||||
if (type === 'control_layer') {
|
||||
// TODO
|
||||
return;
|
||||
}
|
||||
},
|
||||
prepare: (payload: Layer['type']) => ({ payload, meta: { uuid: uuidv4() } }),
|
||||
},
|
||||
@ -196,21 +217,21 @@ export const regionalPromptsSlice = createSlice({
|
||||
maskLayerPositivePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => {
|
||||
const { layerId, prompt } = action.payload;
|
||||
const layer = state.layers.find((l) => l.id === layerId);
|
||||
if (layer) {
|
||||
if (layer?.type === 'vector_mask_layer') {
|
||||
layer.positivePrompt = prompt;
|
||||
}
|
||||
},
|
||||
maskLayerNegativePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => {
|
||||
const { layerId, prompt } = action.payload;
|
||||
const layer = state.layers.find((l) => l.id === layerId);
|
||||
if (layer) {
|
||||
if (layer?.type === 'vector_mask_layer') {
|
||||
layer.negativePrompt = prompt;
|
||||
}
|
||||
},
|
||||
maskLayerIPAdapterAdded: {
|
||||
reducer: (state, action: PayloadAction<string, string, { uuid: string }>) => {
|
||||
const layer = state.layers.find((l) => l.id === action.payload);
|
||||
if (layer) {
|
||||
if (layer?.type === 'vector_mask_layer') {
|
||||
layer.ipAdapterIds.push(action.meta.uuid);
|
||||
}
|
||||
},
|
||||
@ -219,7 +240,7 @@ export const regionalPromptsSlice = createSlice({
|
||||
maskLayerPreviewColorChanged: (state, action: PayloadAction<{ layerId: string; color: RgbColor }>) => {
|
||||
const { layerId, color } = action.payload;
|
||||
const layer = state.layers.find((l) => l.id === layerId);
|
||||
if (layer) {
|
||||
if (layer?.type === 'vector_mask_layer') {
|
||||
layer.previewColor = color;
|
||||
}
|
||||
},
|
||||
@ -234,7 +255,7 @@ export const regionalPromptsSlice = createSlice({
|
||||
) => {
|
||||
const { layerId, points, tool } = action.payload;
|
||||
const layer = state.layers.find((l) => l.id === layerId);
|
||||
if (layer) {
|
||||
if (layer?.type === 'vector_mask_layer') {
|
||||
const lineId = getVectorMaskLayerLineId(layer.id, action.meta.uuid);
|
||||
layer.objects.push({
|
||||
type: 'vector_mask_line',
|
||||
@ -259,7 +280,7 @@ export const regionalPromptsSlice = createSlice({
|
||||
maskLayerPointsAdded: (state, action: PayloadAction<{ layerId: string; point: [number, number] }>) => {
|
||||
const { layerId, point } = action.payload;
|
||||
const layer = state.layers.find((l) => l.id === layerId);
|
||||
if (layer) {
|
||||
if (layer?.type === 'vector_mask_layer') {
|
||||
const lastLine = layer.objects.findLast(isLine);
|
||||
if (!lastLine) {
|
||||
return;
|
||||
@ -278,7 +299,7 @@ export const regionalPromptsSlice = createSlice({
|
||||
return;
|
||||
}
|
||||
const layer = state.layers.find((l) => l.id === layerId);
|
||||
if (layer) {
|
||||
if (layer?.type === 'vector_mask_layer') {
|
||||
const id = getVectorMaskLayerRectId(layer.id, action.meta.uuid);
|
||||
layer.objects.push({
|
||||
type: 'vector_mask_rect',
|
||||
@ -299,7 +320,7 @@ export const regionalPromptsSlice = createSlice({
|
||||
) => {
|
||||
const { layerId, autoNegative } = action.payload;
|
||||
const layer = state.layers.find((l) => l.id === layerId);
|
||||
if (layer) {
|
||||
if (layer?.type === 'vector_mask_layer') {
|
||||
layer.autoNegative = autoNegative;
|
||||
}
|
||||
},
|
||||
@ -331,9 +352,9 @@ export const regionalPromptsSlice = createSlice({
|
||||
},
|
||||
extraReducers(builder) {
|
||||
builder.addCase(controlAdapterRemoved, (state, action) => {
|
||||
for (const layer of state.layers) {
|
||||
state.layers.filter(isVectorMaskLayer).forEach((layer) => {
|
||||
layer.ipAdapterIds = layer.ipAdapterIds.filter((id) => id !== action.payload.id);
|
||||
}
|
||||
});
|
||||
});
|
||||
},
|
||||
});
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
||||
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
|
||||
import { VECTOR_MASK_LAYER_NAME } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { isVectorMaskLayer, VECTOR_MASK_LAYER_NAME } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { renderers } from 'features/regionalPrompts/util/renderers';
|
||||
import Konva from 'konva';
|
||||
import { assert } from 'tsafe';
|
||||
@ -17,7 +17,7 @@ export const getRegionalPromptLayerBlobs = async (
|
||||
preview: boolean = false
|
||||
): Promise<Record<string, Blob>> => {
|
||||
const state = getStore().getState();
|
||||
const reduxLayers = state.regionalPrompts.present.layers;
|
||||
const reduxLayers = state.regionalPrompts.present.layers.filter(isVectorMaskLayer);
|
||||
const container = document.createElement('div');
|
||||
const stage = new Konva.Stage({ container, width: state.generation.width, height: state.generation.height });
|
||||
renderers.renderLayers(stage, reduxLayers, 1, 'brush');
|
||||
|
@ -494,35 +494,38 @@ const renderBbox = (
|
||||
}
|
||||
|
||||
for (const reduxLayer of reduxLayers) {
|
||||
const konvaLayer = stage.findOne<Konva.Layer>(`#${reduxLayer.id}`);
|
||||
assert(konvaLayer, `Layer ${reduxLayer.id} not found in stage`);
|
||||
if (reduxLayer.type === 'vector_mask_layer') {
|
||||
const konvaLayer = stage.findOne<Konva.Layer>(`#${reduxLayer.id}`);
|
||||
assert(konvaLayer, `Layer ${reduxLayer.id} not found in stage`);
|
||||
|
||||
let bbox = reduxLayer.bbox;
|
||||
let bbox = reduxLayer.bbox;
|
||||
|
||||
// We only need to recalculate the bbox if the layer has changed and it has objects
|
||||
if (reduxLayer.bboxNeedsUpdate && reduxLayer.objects.length) {
|
||||
// We only need to use the pixel-perfect bounding box if the layer has eraser strokes
|
||||
bbox = reduxLayer.needsPixelBbox ? getLayerBboxPixels(konvaLayer) : getLayerBboxFast(konvaLayer);
|
||||
// Update the layer's bbox in the redux store
|
||||
onBboxChanged(reduxLayer.id, bbox);
|
||||
// We only need to recalculate the bbox if the layer has changed and it has objects
|
||||
if (reduxLayer.bboxNeedsUpdate && reduxLayer.objects.length) {
|
||||
// We only need to use the pixel-perfect bounding box if the layer has eraser strokes
|
||||
bbox = reduxLayer.needsPixelBbox ? getLayerBboxPixels(konvaLayer) : getLayerBboxFast(konvaLayer);
|
||||
// Update the layer's bbox in the redux store
|
||||
onBboxChanged(reduxLayer.id, bbox);
|
||||
}
|
||||
|
||||
if (!bbox) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const rect =
|
||||
konvaLayer.findOne<Konva.Rect>(`.${LAYER_BBOX_NAME}`) ??
|
||||
createBboxRect(reduxLayer, konvaLayer, onBboxMouseDown);
|
||||
|
||||
rect.setAttrs({
|
||||
visible: true,
|
||||
listening: true,
|
||||
x: bbox.x,
|
||||
y: bbox.y,
|
||||
width: bbox.width,
|
||||
height: bbox.height,
|
||||
stroke: reduxLayer.id === selectedLayerId ? BBOX_SELECTED_STROKE : BBOX_NOT_SELECTED_STROKE,
|
||||
});
|
||||
}
|
||||
|
||||
if (!bbox) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const rect =
|
||||
konvaLayer.findOne<Konva.Rect>(`.${LAYER_BBOX_NAME}`) ?? createBboxRect(reduxLayer, konvaLayer, onBboxMouseDown);
|
||||
|
||||
rect.setAttrs({
|
||||
visible: true,
|
||||
listening: true,
|
||||
x: bbox.x,
|
||||
y: bbox.y,
|
||||
width: bbox.width,
|
||||
height: bbox.height,
|
||||
stroke: reduxLayer.id === selectedLayerId ? BBOX_SELECTED_STROKE : BBOX_NOT_SELECTED_STROKE,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -13,7 +13,7 @@ import {
|
||||
selectValidIPAdapters,
|
||||
selectValidT2IAdapters,
|
||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { isVectorMaskLayer, selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { Fragment, memo } from 'react';
|
||||
@ -26,15 +26,17 @@ const selector = createMemoizedSelector(
|
||||
const badges: string[] = [];
|
||||
let isError = false;
|
||||
|
||||
const enabledIPAdapterCount = selectAllIPAdapters(controlAdapters)
|
||||
.filter((ca) => !regionalPrompts.present.layers.some((l) => l.ipAdapterIds.includes(ca.id)))
|
||||
const enabledNonRegionalIPAdapterCount = selectAllIPAdapters(controlAdapters)
|
||||
.filter(
|
||||
(ca) => !regionalPrompts.present.layers.filter(isVectorMaskLayer).some((l) => l.ipAdapterIds.includes(ca.id))
|
||||
)
|
||||
.filter((ca) => ca.isEnabled).length;
|
||||
|
||||
const validIPAdapterCount = selectValidIPAdapters(controlAdapters).length;
|
||||
if (enabledIPAdapterCount > 0) {
|
||||
badges.push(`${enabledIPAdapterCount} IP`);
|
||||
if (enabledNonRegionalIPAdapterCount > 0) {
|
||||
badges.push(`${enabledNonRegionalIPAdapterCount} IP`);
|
||||
}
|
||||
if (enabledIPAdapterCount > validIPAdapterCount) {
|
||||
if (enabledNonRegionalIPAdapterCount > validIPAdapterCount) {
|
||||
isError = true;
|
||||
}
|
||||
|
||||
@ -57,7 +59,7 @@ const selector = createMemoizedSelector(
|
||||
}
|
||||
|
||||
const controlAdapterIds = selectControlAdapterIds(controlAdapters).filter(
|
||||
(id) => !regionalPrompts.present.layers.some((l) => l.ipAdapterIds.includes(id))
|
||||
(id) => !regionalPrompts.present.layers.filter(isVectorMaskLayer).some((l) => l.ipAdapterIds.includes(id))
|
||||
);
|
||||
|
||||
return {
|
||||
|
Loading…
x
Reference in New Issue
Block a user