refactor(ui): add control layers separate control adapter implementation (wip)

- Revise control adapter config types
- Recreate all control adapter mutations in control layers slice
- Bit of renaming along the way - typing 'RegionalGuidanceLayer' over and over again was getting tedious
This commit is contained in:
psychedelicious 2024-05-01 13:21:49 +10:00 committed by Kent Keirsey
parent 3717321480
commit 121918352a
16 changed files with 775 additions and 303 deletions

View File

@ -5,12 +5,12 @@ import { controlAdapterAdded, controlAdapterRemoved } from 'features/controlAdap
import type { ControlNetConfig, IPAdapterConfig } from 'features/controlAdapters/store/types';
import { isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
import {
controlAdapterLayerAdded,
ipAdapterLayerAdded,
caLayerAdded,
ipaLayerAdded,
layerDeleted,
maskLayerIPAdapterAdded,
maskLayerIPAdapterDeleted,
regionalGuidanceLayerAdded,
rgLayerAdded,
rgLayerIPAdapterAdded,
rgLayerIPAdapterDeleted,
} from 'features/controlLayers/store/controlLayersSlice';
import type { Layer } from 'features/controlLayers/store/types';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
@ -33,7 +33,7 @@ export const addControlLayersToControlAdapterBridge = (startAppListening: AppSta
const type = action.payload;
const layerId = uuidv4();
if (type === 'regional_guidance_layer') {
dispatch(regionalGuidanceLayerAdded({ layerId }));
dispatch(rgLayerAdded({ layerId }));
return;
}
@ -53,7 +53,7 @@ export const addControlLayersToControlAdapterBridge = (startAppListening: AppSta
overrides.model = models.find((m) => m.base === baseModel) ?? null;
}
dispatch(controlAdapterAdded({ type: 'ip_adapter', overrides }));
dispatch(ipAdapterLayerAdded({ layerId, ipAdapterId }));
dispatch(ipaLayerAdded({ layerId, ipAdapterId }));
return;
}
@ -73,7 +73,7 @@ export const addControlLayersToControlAdapterBridge = (startAppListening: AppSta
overrides.processorNode = CONTROLNET_PROCESSORS[overrides.processorType].buildDefaults(baseModel);
}
dispatch(controlAdapterAdded({ type: 'controlnet', overrides }));
dispatch(controlAdapterLayerAdded({ layerId, controlNetId }));
dispatch(caLayerAdded({ layerId, controlNetId }));
return;
}
},
@ -129,7 +129,7 @@ export const addControlLayersToControlAdapterBridge = (startAppListening: AppSta
}
dispatch(controlAdapterAdded({ type: 'ip_adapter', overrides }));
dispatch(maskLayerIPAdapterAdded({ layerId, ipAdapterId }));
dispatch(rgLayerIPAdapterAdded({ layerId, ipAdapterId }));
},
});
@ -138,7 +138,7 @@ export const addControlLayersToControlAdapterBridge = (startAppListening: AppSta
effect: (action, { dispatch }) => {
const { layerId, ipAdapterId } = action.payload;
dispatch(controlAdapterRemoved({ id: ipAdapterId }));
dispatch(maskLayerIPAdapterDeleted({ layerId, ipAdapterId }));
dispatch(rgLayerIPAdapterDeleted({ layerId, ipAdapterId }));
},
});
};

View File

@ -4,8 +4,8 @@ import { guidanceLayerIPAdapterAdded } from 'app/store/middleware/listenerMiddle
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
isRegionalGuidanceLayer,
maskLayerNegativePromptChanged,
maskLayerPositivePromptChanged,
rgLayerNegativePromptChanged,
rgLayerPositivePromptChanged,
selectControlLayersSlice,
} from 'features/controlLayers/store/controlLayersSlice';
import { useCallback, useMemo } from 'react';
@ -33,10 +33,10 @@ export const AddPromptButtons = ({ layerId }: AddPromptButtonProps) => {
);
const validActions = useAppSelector(selectValidActions);
const addPositivePrompt = useCallback(() => {
dispatch(maskLayerPositivePromptChanged({ layerId, prompt: '' }));
dispatch(rgLayerPositivePromptChanged({ layerId, prompt: '' }));
}, [dispatch, layerId]);
const addNegativePrompt = useCallback(() => {
dispatch(maskLayerNegativePromptChanged({ layerId, prompt: '' }));
dispatch(rgLayerNegativePromptChanged({ layerId, prompt: '' }));
}, [dispatch, layerId]);
const addIPAdapter = useCallback(() => {
dispatch(guidanceLayerIPAdapterAdded(layerId));

View File

@ -15,7 +15,7 @@ import {
import { useAppDispatch } from 'app/store/storeHooks';
import { stopPropagation } from 'common/util/stopPropagation';
import { useLayerOpacity } from 'features/controlLayers/hooks/layerStateHooks';
import { isFilterEnabledChanged, layerOpacityChanged } from 'features/controlLayers/store/controlLayersSlice';
import { caLayerIsFilterEnabledChanged, layerOpacityChanged } from 'features/controlLayers/store/controlLayersSlice';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
@ -40,7 +40,7 @@ const CALayerOpacity = ({ layerId }: Props) => {
);
const onChangeFilter = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(isFilterEnabledChanged({ layerId, isFilterEnabled: e.target.checked }));
dispatch(caLayerIsFilterEnabledChanged({ layerId, isFilterEnabled: e.target.checked }));
},
[dispatch, layerId]
);

View File

@ -4,8 +4,8 @@ import { guidanceLayerIPAdapterAdded } from 'app/store/middleware/listenerMiddle
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
isRegionalGuidanceLayer,
maskLayerNegativePromptChanged,
maskLayerPositivePromptChanged,
rgLayerNegativePromptChanged,
rgLayerPositivePromptChanged,
selectControlLayersSlice,
} from 'features/controlLayers/store/controlLayersSlice';
import { memo, useCallback, useMemo } from 'react';
@ -32,10 +32,10 @@ export const LayerMenuRGActions = memo(({ layerId }: Props) => {
);
const validActions = useAppSelector(selectValidActions);
const addPositivePrompt = useCallback(() => {
dispatch(maskLayerPositivePromptChanged({ layerId, prompt: '' }));
dispatch(rgLayerPositivePromptChanged({ layerId, prompt: '' }));
}, [dispatch, layerId]);
const addNegativePrompt = useCallback(() => {
dispatch(maskLayerNegativePromptChanged({ layerId, prompt: '' }));
dispatch(rgLayerNegativePromptChanged({ layerId, prompt: '' }));
}, [dispatch, layerId]);
const addIPAdapter = useCallback(() => {
dispatch(guidanceLayerIPAdapterAdded(layerId));

View File

@ -3,7 +3,7 @@ import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
isRegionalGuidanceLayer,
maskLayerAutoNegativeChanged,
rgLayerAutoNegativeChanged,
selectControlLayersSlice,
} from 'features/controlLayers/store/controlLayersSlice';
import type { ChangeEvent } from 'react';
@ -35,7 +35,7 @@ export const RGLayerAutoNegativeCheckbox = memo(({ layerId }: Props) => {
const autoNegative = useAutoNegative(layerId);
const onChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(maskLayerAutoNegativeChanged({ layerId, autoNegative: e.target.checked ? 'invert' : 'off' }));
dispatch(rgLayerAutoNegativeChanged({ layerId, autoNegative: e.target.checked ? 'invert' : 'off' }));
},
[dispatch, layerId]
);

View File

@ -6,7 +6,7 @@ import { stopPropagation } from 'common/util/stopPropagation';
import { rgbColorToString } from 'features/canvas/util/colorToString';
import {
isRegionalGuidanceLayer,
maskLayerPreviewColorChanged,
rgLayerPreviewColorChanged,
selectControlLayersSlice,
} from 'features/controlLayers/store/controlLayersSlice';
import { memo, useCallback, useMemo } from 'react';
@ -33,7 +33,7 @@ export const RGLayerColorPicker = memo(({ layerId }: Props) => {
const dispatch = useAppDispatch();
const onColorChange = useCallback(
(color: RgbColor) => {
dispatch(maskLayerPreviewColorChanged({ layerId, color }));
dispatch(rgLayerPreviewColorChanged({ layerId, color }));
},
[dispatch, layerId]
);

View File

@ -2,7 +2,7 @@ import { Box, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { RGLayerPromptDeleteButton } from 'features/controlLayers/components/RGLayer/RGLayerPromptDeleteButton';
import { useLayerNegativePrompt } from 'features/controlLayers/hooks/layerStateHooks';
import { maskLayerNegativePromptChanged } from 'features/controlLayers/store/controlLayersSlice';
import { rgLayerNegativePromptChanged } from 'features/controlLayers/store/controlLayersSlice';
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
import { PromptPopover } from 'features/prompt/PromptPopover';
@ -21,7 +21,7 @@ export const RGLayerNegativePrompt = memo(({ layerId }: Props) => {
const { t } = useTranslation();
const _onChange = useCallback(
(v: string) => {
dispatch(maskLayerNegativePromptChanged({ layerId, prompt: v }));
dispatch(rgLayerNegativePromptChanged({ layerId, prompt: v }));
},
[dispatch, layerId]
);

View File

@ -2,7 +2,7 @@ import { Box, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { RGLayerPromptDeleteButton } from 'features/controlLayers/components/RGLayer/RGLayerPromptDeleteButton';
import { useLayerPositivePrompt } from 'features/controlLayers/hooks/layerStateHooks';
import { maskLayerPositivePromptChanged } from 'features/controlLayers/store/controlLayersSlice';
import { rgLayerPositivePromptChanged } from 'features/controlLayers/store/controlLayersSlice';
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
import { PromptPopover } from 'features/prompt/PromptPopover';
@ -21,7 +21,7 @@ export const RGLayerPositivePrompt = memo(({ layerId }: Props) => {
const { t } = useTranslation();
const _onChange = useCallback(
(v: string) => {
dispatch(maskLayerPositivePromptChanged({ layerId, prompt: v }));
dispatch(rgLayerPositivePromptChanged({ layerId, prompt: v }));
},
[dispatch, layerId]
);

View File

@ -1,8 +1,8 @@
import { IconButton, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import {
maskLayerNegativePromptChanged,
maskLayerPositivePromptChanged,
rgLayerNegativePromptChanged,
rgLayerPositivePromptChanged,
} from 'features/controlLayers/store/controlLayersSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
@ -18,9 +18,9 @@ export const RGLayerPromptDeleteButton = memo(({ layerId, polarity }: Props) =>
const dispatch = useAppDispatch();
const onClick = useCallback(() => {
if (polarity === 'positive') {
dispatch(maskLayerPositivePromptChanged({ layerId, prompt: null }));
dispatch(rgLayerPositivePromptChanged({ layerId, prompt: null }));
} else {
dispatch(maskLayerNegativePromptChanged({ layerId, prompt: null }));
dispatch(rgLayerNegativePromptChanged({ layerId, prompt: null }));
}
}, [dispatch, layerId, polarity]);
return (

View File

@ -9,9 +9,9 @@ import {
$lastMouseDownPos,
$tool,
brushSizeChanged,
maskLayerLineAdded,
maskLayerPointsAdded,
maskLayerRectAdded,
rfLayerLineAdded,
rgLayerPointsAdded,
rgLayerRectAdded,
} from 'features/controlLayers/store/controlLayersSlice';
import type Konva from 'konva';
import type { KonvaEventObject } from 'konva/lib/Node';
@ -71,7 +71,7 @@ export const useMouseEvents = () => {
}
if (tool === 'brush' || tool === 'eraser') {
dispatch(
maskLayerLineAdded({
rfLayerLineAdded({
layerId: selectedLayerId,
points: [pos.x, pos.y, pos.x, pos.y],
tool,
@ -94,7 +94,7 @@ export const useMouseEvents = () => {
const tool = $tool.get();
if (pos && lastPos && selectedLayerId && tool === 'rect') {
dispatch(
maskLayerRectAdded({
rgLayerRectAdded({
layerId: selectedLayerId,
rect: {
x: Math.min(pos.x, lastPos.x),
@ -128,7 +128,7 @@ export const useMouseEvents = () => {
}
}
lastCursorPosRef.current = [pos.x, pos.y];
dispatch(maskLayerPointsAdded({ layerId: selectedLayerId, point: lastCursorPosRef.current }));
dispatch(rgLayerPointsAdded({ layerId: selectedLayerId, point: lastCursorPosRef.current }));
}
},
[dispatch, selectedLayerId, tool]
@ -149,7 +149,7 @@ export const useMouseEvents = () => {
$isMouseDown.get() &&
(tool === 'brush' || tool === 'eraser')
) {
dispatch(maskLayerPointsAdded({ layerId: selectedLayerId, point: [pos.x, pos.y] }));
dispatch(rgLayerPointsAdded({ layerId: selectedLayerId, point: [pos.x, pos.y] }));
}
$isMouseOver.set(false);
$isMouseDown.set(false);
@ -181,7 +181,7 @@ export const useMouseEvents = () => {
}
if (tool === 'brush' || tool === 'eraser') {
dispatch(
maskLayerLineAdded({
rfLayerLineAdded({
layerId: selectedLayerId,
points: [pos.x, pos.y, pos.x, pos.y],
tool,

View File

@ -3,12 +3,17 @@ import { createSlice, isAnyOf } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { moveBackward, moveForward, moveToBack, moveToFront } from 'common/util/arrayUtils';
import { deepClone } from 'common/util/deepClone';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import {
controlAdapterImageChanged,
controlAdapterProcessedImageChanged,
isAnyControlAdapterAdded,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import type {
CLIPVisionModel,
ControlMode,
ControlNetConfig,
IPAdapterConfig,
IPMethod,
ProcessorConfig,
T2IAdapterConfig,
} from 'features/controlLayers/util/controlAdapters';
import { buildControlAdapterProcessor, imageDTOToImageWithDims } from 'features/controlLayers/util/controlAdapters';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
import { initialAspectRatioState } from 'features/parameters/components/ImageSize/constants';
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
@ -20,6 +25,7 @@ import { isEqual, partition } from 'lodash-es';
import { atom } from 'nanostores';
import type { RgbColor } from 'react-colorful';
import type { UndoableOptions } from 'redux-undo';
import type { ControlNetModelConfig, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid';
@ -47,7 +53,6 @@ export const initialControlLayersState: ControlLayersState = {
positivePrompt2: '',
negativePrompt2: '',
shouldConcatPrompts: true,
initialImage: null,
size: {
width: 512,
height: 512,
@ -82,76 +87,35 @@ const getVectorMaskPreviewColor = (state: ControlLayersState): RgbColor => {
const lastColor = vmLayers[vmLayers.length - 1]?.previewColor;
return LayerColors.next(lastColor);
};
const getCALayer = (state: ControlLayersState, layerId: string): ControlAdapterLayer => {
const layer = state.layers.find((l) => l.id === layerId);
assert(isControlAdapterLayer(layer));
return layer;
};
const getIPALayer = (state: ControlLayersState, layerId: string): IPAdapterLayer => {
const layer = state.layers.find((l) => l.id === layerId);
assert(isIPAdapterLayer(layer));
return layer;
};
const getRGLayer = (state: ControlLayersState, layerId: string): RegionalGuidanceLayer => {
const layer = state.layers.find((l) => l.id === layerId);
assert(isRegionalGuidanceLayer(layer));
return layer;
};
const getRGLayerIPAdapter = (state: ControlLayersState, layerId: string, ipAdapterId: string): IPAdapterConfig => {
const layer = state.layers.find((l) => l.id === layerId);
assert(isRegionalGuidanceLayer(layer));
const ipAdapter = layer.ipAdapters.find((ipAdapter) => ipAdapter.id === ipAdapterId);
assert(ipAdapter);
return ipAdapter;
};
export const controlLayersSlice = createSlice({
name: 'controlLayers',
initialState: initialControlLayersState,
reducers: {
//#region All Layers
regionalGuidanceLayerAdded: (state, action: PayloadAction<{ layerId: string }>) => {
const { layerId } = action.payload;
const layer: RegionalGuidanceLayer = {
id: getRegionalGuidanceLayerId(layerId),
type: 'regional_guidance_layer',
isEnabled: true,
bbox: null,
bboxNeedsUpdate: false,
maskObjects: [],
previewColor: getVectorMaskPreviewColor(state),
x: 0,
y: 0,
autoNegative: 'invert',
needsPixelBbox: false,
positivePrompt: '',
negativePrompt: null,
ipAdapterIds: [],
isSelected: true,
};
state.layers.push(layer);
state.selectedLayerId = layer.id;
for (const layer of state.layers.filter(isRenderableLayer)) {
if (layer.id !== layerId) {
layer.isSelected = false;
}
}
return;
},
ipAdapterLayerAdded: (state, action: PayloadAction<{ layerId: string; ipAdapterId: string }>) => {
const { layerId, ipAdapterId } = action.payload;
const layer: IPAdapterLayer = {
id: getIPAdapterLayerId(layerId),
type: 'ip_adapter_layer',
isEnabled: true,
ipAdapterId,
};
state.layers.push(layer);
return;
},
controlAdapterLayerAdded: (state, action: PayloadAction<{ layerId: string; controlNetId: string }>) => {
const { layerId, controlNetId } = action.payload;
const layer: ControlAdapterLayer = {
id: getControlNetLayerId(layerId),
type: 'control_adapter_layer',
controlNetId,
x: 0,
y: 0,
bbox: null,
bboxNeedsUpdate: false,
isEnabled: true,
imageName: null,
opacity: 1,
isSelected: true,
isFilterEnabled: true,
};
state.layers.push(layer);
state.selectedLayerId = layer.id;
for (const layer of state.layers.filter(isRenderableLayer)) {
if (layer.id !== layerId) {
layer.isSelected = false;
}
}
return;
},
layerSelected: (state, action: PayloadAction<string>) => {
for (const layer of state.layers.filter(isRenderableLayer)) {
if (layer.id === action.payload) {
@ -245,7 +209,103 @@ export const controlLayersSlice = createSlice({
//#endregion
//#region CA Layers
isFilterEnabledChanged: (state, action: PayloadAction<{ layerId: string; isFilterEnabled: boolean }>) => {
caLayerAdded: {
reducer: (
state,
action: PayloadAction<{ layerId: string; controlAdapter: ControlNetConfig | T2IAdapterConfig }>
) => {
const { layerId, controlAdapter } = action.payload;
const layer: ControlAdapterLayer = {
id: getCALayerId(layerId),
type: 'control_adapter_layer',
x: 0,
y: 0,
bbox: null,
bboxNeedsUpdate: false,
isEnabled: true,
opacity: 1,
isSelected: true,
isFilterEnabled: true,
controlAdapter,
};
state.layers.push(layer);
state.selectedLayerId = layer.id;
for (const layer of state.layers.filter(isRenderableLayer)) {
if (layer.id !== layerId) {
layer.isSelected = false;
}
}
},
prepare: (controlAdapter: ControlNetConfig | T2IAdapterConfig) => ({
payload: { layerId: uuidv4(), controlAdapter },
}),
},
caLayerImageChanged: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => {
const { layerId, imageDTO } = action.payload;
const layer = getCALayer(state, layerId);
layer.bbox = null;
layer.bboxNeedsUpdate = true;
layer.isEnabled = true;
layer.controlAdapter.image = imageDTO ? imageDTOToImageWithDims(imageDTO) : null;
},
caLayerProcessedImageChanged: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => {
const { layerId, imageDTO } = action.payload;
const layer = getCALayer(state, layerId);
layer.bbox = null;
layer.bboxNeedsUpdate = true;
layer.isEnabled = true;
layer.controlAdapter.processedImage = imageDTO ? imageDTOToImageWithDims(imageDTO) : null;
},
caLayerModelChanged: (
state,
action: PayloadAction<{
layerId: string;
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null;
}>
) => {
const { layerId, modelConfig } = action.payload;
const layer = getCALayer(state, layerId);
if (!modelConfig) {
layer.controlAdapter.model = null;
return;
}
layer.controlAdapter.model = zModelIdentifierField.parse(modelConfig);
const candidateProcessorConfig = buildControlAdapterProcessor(modelConfig);
if (candidateProcessorConfig?.type !== layer.controlAdapter.processorConfig?.type) {
// The processor has changed. For example, the previous model was a Canny model and the new model is a Depth
// model. We need to use the new processor.
layer.controlAdapter.processedImage = null;
layer.controlAdapter.processorConfig = candidateProcessorConfig;
}
},
caLayerWeightChanged: (state, action: PayloadAction<{ layerId: string; weight: number }>) => {
const { layerId, weight } = action.payload;
const layer = getCALayer(state, layerId);
layer.controlAdapter.weight = weight;
},
caLayerBeginEndStepPctChanged: (
state,
action: PayloadAction<{ layerId: string; beginEndStepPct: [number, number] }>
) => {
const { layerId, beginEndStepPct } = action.payload;
const layer = getCALayer(state, layerId);
layer.controlAdapter.beginEndStepPct = beginEndStepPct;
},
caLayerControlModeChanged: (state, action: PayloadAction<{ layerId: string; controlMode: ControlMode }>) => {
const { layerId, controlMode } = action.payload;
const layer = getCALayer(state, layerId);
assert(layer.controlAdapter.type === 'controlnet');
layer.controlAdapter.controlMode = controlMode;
},
caLayerProcessorConfigChanged: (
state,
action: PayloadAction<{ layerId: string; processorConfig: ProcessorConfig }>
) => {
const { layerId, processorConfig } = action.payload;
const layer = getCALayer(state, layerId);
layer.controlAdapter.processorConfig = processorConfig;
},
caLayerIsFilterEnabledChanged: (state, action: PayloadAction<{ layerId: string; isFilterEnabled: boolean }>) => {
const { layerId, isFilterEnabled } = action.payload;
const layer = state.layers.filter(isControlAdapterLayer).find((l) => l.id === layerId);
if (layer) {
@ -254,121 +314,217 @@ export const controlLayersSlice = createSlice({
},
//#endregion
//#region Mask Layers
maskLayerPositivePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => {
//#region IP Adapter Layers
ipaLayerAdded: {
reducer: (state, action: PayloadAction<{ layerId: string; ipAdapter: IPAdapterConfig }>) => {
const { layerId, ipAdapter } = action.payload;
const layer: IPAdapterLayer = {
id: getIPALayerId(layerId),
type: 'ip_adapter_layer',
isEnabled: true,
ipAdapter,
};
state.layers.push(layer);
},
prepare: (ipAdapter: IPAdapterConfig) => ({ payload: { layerId: uuidv4(), ipAdapter } }),
},
ipaLayerImageChanged: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => {
const { layerId, imageDTO } = action.payload;
const layer = getIPALayer(state, layerId);
layer.ipAdapter.image = imageDTO ? imageDTOToImageWithDims(imageDTO) : null;
},
ipaLayerWeightChanged: (state, action: PayloadAction<{ layerId: string; weight: number }>) => {
const { layerId, weight } = action.payload;
const layer = getIPALayer(state, layerId);
layer.ipAdapter.weight = weight;
},
ipaLayerBeginEndStepPctChanged: (
state,
action: PayloadAction<{ layerId: string; beginEndStepPct: [number, number] }>
) => {
const { layerId, beginEndStepPct } = action.payload;
const layer = getIPALayer(state, layerId);
layer.ipAdapter.beginEndStepPct = beginEndStepPct;
},
ipaLayerMethodChanged: (state, action: PayloadAction<{ layerId: string; method: IPMethod }>) => {
const { layerId, method } = action.payload;
const layer = getIPALayer(state, layerId);
layer.ipAdapter.method = method;
},
ipaLayerCLIPVisionModelChanged: (
state,
action: PayloadAction<{ layerId: string; clipVisionModel: CLIPVisionModel }>
) => {
const { layerId, clipVisionModel } = action.payload;
const layer = getIPALayer(state, layerId);
layer.ipAdapter.clipVisionModel = clipVisionModel;
},
//#endregion
//#region RG Layers
rgLayerAdded: (state, action: PayloadAction<{ layerId: string }>) => {
const { layerId } = action.payload;
const layer: RegionalGuidanceLayer = {
id: getRGLayerId(layerId),
type: 'regional_guidance_layer',
isEnabled: true,
bbox: null,
bboxNeedsUpdate: false,
maskObjects: [],
previewColor: getVectorMaskPreviewColor(state),
x: 0,
y: 0,
autoNegative: 'invert',
needsPixelBbox: false,
positivePrompt: '',
negativePrompt: null,
ipAdapters: [],
isSelected: true,
};
state.layers.push(layer);
state.selectedLayerId = layer.id;
for (const layer of state.layers.filter(isRenderableLayer)) {
if (layer.id !== layerId) {
layer.isSelected = false;
}
}
},
rgLayerPositivePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => {
const { layerId, prompt } = action.payload;
const layer = state.layers.find((l) => l.id === layerId);
if (layer?.type === 'regional_guidance_layer') {
layer.positivePrompt = prompt;
}
const layer = getRGLayer(state, layerId);
layer.positivePrompt = prompt;
},
maskLayerNegativePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => {
rgLayerNegativePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => {
const { layerId, prompt } = action.payload;
const layer = state.layers.find((l) => l.id === layerId);
if (layer?.type === 'regional_guidance_layer') {
layer.negativePrompt = prompt;
}
const layer = getRGLayer(state, layerId);
layer.negativePrompt = prompt;
},
maskLayerIPAdapterAdded: (state, action: PayloadAction<{ layerId: string; ipAdapterId: string }>) => {
rgLayerIPAdapterAdded: (state, action: PayloadAction<{ layerId: string; ipAdapter: IPAdapterConfig }>) => {
const { layerId, ipAdapter } = action.payload;
const layer = getRGLayer(state, layerId);
layer.ipAdapters.push(ipAdapter);
},
rgLayerIPAdapterDeleted: (state, action: PayloadAction<{ layerId: string; ipAdapterId: string }>) => {
const { layerId, ipAdapterId } = action.payload;
const layer = state.layers.find((l) => l.id === layerId);
if (layer?.type === 'regional_guidance_layer') {
layer.ipAdapterIds.push(ipAdapterId);
}
const layer = getRGLayer(state, layerId);
layer.ipAdapters = layer.ipAdapters.filter((ipAdapter) => ipAdapter.id !== ipAdapterId);
},
maskLayerIPAdapterDeleted: (state, action: PayloadAction<{ layerId: string; ipAdapterId: string }>) => {
const { layerId, ipAdapterId } = action.payload;
const layer = state.layers.find((l) => l.id === layerId);
if (layer?.type === 'regional_guidance_layer') {
layer.ipAdapterIds = layer.ipAdapterIds.filter((id) => id !== ipAdapterId);
}
},
maskLayerPreviewColorChanged: (state, action: PayloadAction<{ layerId: string; color: RgbColor }>) => {
rgLayerPreviewColorChanged: (state, action: PayloadAction<{ layerId: string; color: RgbColor }>) => {
const { layerId, color } = action.payload;
const layer = state.layers.find((l) => l.id === layerId);
if (layer?.type === 'regional_guidance_layer') {
layer.previewColor = color;
}
const layer = getRGLayer(state, layerId);
layer.previewColor = color;
},
maskLayerLineAdded: {
rgLayerLineAdded: {
reducer: (
state,
action: PayloadAction<
{ layerId: string; points: [number, number, number, number]; tool: DrawingTool },
string,
{ uuid: string }
>
action: PayloadAction<{
layerId: string;
points: [number, number, number, number];
tool: DrawingTool;
lineUuid: string;
}>
) => {
const { layerId, points, tool } = action.payload;
const layer = state.layers.find((l) => l.id === layerId);
if (layer?.type === 'regional_guidance_layer') {
const lineId = getRegionalGuidanceLayerLineId(layer.id, action.meta.uuid);
layer.maskObjects.push({
type: 'vector_mask_line',
tool: tool,
id: lineId,
// Points must be offset by the layer's x and y coordinates
// TODO: Handle this in the event listener?
points: [points[0] - layer.x, points[1] - layer.y, points[2] - layer.x, points[3] - layer.y],
strokeWidth: state.brushSize,
});
layer.bboxNeedsUpdate = true;
if (!layer.needsPixelBbox && tool === 'eraser') {
layer.needsPixelBbox = true;
}
const { layerId, points, tool, lineUuid } = action.payload;
const layer = getRGLayer(state, layerId);
const lineId = getRGLayerLineId(layer.id, lineUuid);
layer.maskObjects.push({
type: 'vector_mask_line',
tool: tool,
id: lineId,
// Points must be offset by the layer's x and y coordinates
// TODO: Handle this in the event listener?
points: [points[0] - layer.x, points[1] - layer.y, points[2] - layer.x, points[3] - layer.y],
strokeWidth: state.brushSize,
});
layer.bboxNeedsUpdate = true;
if (!layer.needsPixelBbox && tool === 'eraser') {
layer.needsPixelBbox = true;
}
},
prepare: (payload: { layerId: string; points: [number, number, number, number]; tool: DrawingTool }) => ({
payload,
meta: { uuid: uuidv4() },
payload: { ...payload, lineUuid: uuidv4() },
}),
},
maskLayerPointsAdded: (state, action: PayloadAction<{ layerId: string; point: [number, number] }>) => {
rgLayerPointsAdded: (state, action: PayloadAction<{ layerId: string; point: [number, number] }>) => {
const { layerId, point } = action.payload;
const layer = state.layers.find((l) => l.id === layerId);
if (layer?.type === 'regional_guidance_layer') {
const lastLine = layer.maskObjects.findLast(isLine);
if (!lastLine) {
return;
}
// Points must be offset by the layer's x and y coordinates
// TODO: Handle this in the event listener
lastLine.points.push(point[0] - layer.x, point[1] - layer.y);
layer.bboxNeedsUpdate = true;
const layer = getRGLayer(state, layerId);
const lastLine = layer.maskObjects.findLast(isLine);
if (!lastLine) {
return;
}
// Points must be offset by the layer's x and y coordinates
// TODO: Handle this in the event listener
lastLine.points.push(point[0] - layer.x, point[1] - layer.y);
layer.bboxNeedsUpdate = true;
},
maskLayerRectAdded: {
reducer: (state, action: PayloadAction<{ layerId: string; rect: IRect }, string, { uuid: string }>) => {
const { layerId, rect } = action.payload;
rgLayerRectAdded: {
reducer: (state, action: PayloadAction<{ layerId: string; rect: IRect; rectUuid: string }>) => {
const { layerId, rect, rectUuid } = action.payload;
if (rect.height === 0 || rect.width === 0) {
// Ignore zero-area rectangles
return;
}
const layer = state.layers.find((l) => l.id === layerId);
if (layer?.type === 'regional_guidance_layer') {
const id = getMaskedGuidnaceLayerRectId(layer.id, action.meta.uuid);
layer.maskObjects.push({
type: 'vector_mask_rect',
id,
x: rect.x - layer.x,
y: rect.y - layer.y,
width: rect.width,
height: rect.height,
});
layer.bboxNeedsUpdate = true;
}
const layer = getRGLayer(state, layerId);
const id = getRGLayerRectId(layer.id, rectUuid);
layer.maskObjects.push({
type: 'vector_mask_rect',
id,
x: rect.x - layer.x,
y: rect.y - layer.y,
width: rect.width,
height: rect.height,
});
layer.bboxNeedsUpdate = true;
},
prepare: (payload: { layerId: string; rect: IRect }) => ({ payload, meta: { uuid: uuidv4() } }),
prepare: (payload: { layerId: string; rect: IRect }) => ({ payload: { ...payload, rectUuid: uuidv4() } }),
},
maskLayerAutoNegativeChanged: (
rgLayerAutoNegativeChanged: (
state,
action: PayloadAction<{ layerId: string; autoNegative: ParameterAutoNegative }>
) => {
const { layerId, autoNegative } = action.payload;
const layer = state.layers.find((l) => l.id === layerId);
if (layer?.type === 'regional_guidance_layer') {
layer.autoNegative = autoNegative;
}
const layer = getRGLayer(state, layerId);
layer.autoNegative = autoNegative;
},
rgLayerIPAdapterImageChanged: (
state,
action: PayloadAction<{ layerId: string; ipAdapterId: string; imageDTO: ImageDTO | null }>
) => {
const { layerId, ipAdapterId, imageDTO } = action.payload;
const ipAdapter = getRGLayerIPAdapter(state, layerId, ipAdapterId);
ipAdapter.image = imageDTO ? imageDTOToImageWithDims(imageDTO) : null;
},
rgLayerIPAdapterWeightChanged: (
state,
action: PayloadAction<{ layerId: string; ipAdapterId: string; weight: number }>
) => {
const { layerId, ipAdapterId, weight } = action.payload;
const ipAdapter = getRGLayerIPAdapter(state, layerId, ipAdapterId);
ipAdapter.weight = weight;
},
rgLayerIPAdapterBeginEndStepPctChanged: (
state,
action: PayloadAction<{ layerId: string; ipAdapterId: string; beginEndStepPct: [number, number] }>
) => {
const { layerId, ipAdapterId, beginEndStepPct } = action.payload;
const ipAdapter = getRGLayerIPAdapter(state, layerId, ipAdapterId);
ipAdapter.beginEndStepPct = beginEndStepPct;
},
rgLayerIPAdapterMethodChanged: (
state,
action: PayloadAction<{ layerId: string; ipAdapterId: string; method: IPMethod }>
) => {
const { layerId, ipAdapterId, method } = action.payload;
const ipAdapter = getRGLayerIPAdapter(state, layerId, ipAdapterId);
ipAdapter.method = method;
},
rgLayerIPAdapterCLIPVisionModelChanged: (
state,
action: PayloadAction<{ layerId: string; ipAdapterId: string; clipVisionModel: CLIPVisionModel }>
) => {
const { layerId, ipAdapterId, clipVisionModel } = action.payload;
const ipAdapter = getRGLayerIPAdapter(state, layerId, ipAdapterId);
ipAdapter.clipVisionModel = clipVisionModel;
},
//#endregion
@ -451,36 +607,14 @@ export const controlLayersSlice = createSlice({
state.size.height = height;
});
builder.addCase(controlAdapterImageChanged, (state, action) => {
const { id, controlImage } = action.payload;
const layer = state.layers.filter(isControlAdapterLayer).find((l) => l.controlNetId === id);
if (layer) {
layer.bbox = null;
layer.bboxNeedsUpdate = true;
layer.isEnabled = true;
layer.imageName = controlImage?.image_name ?? null;
}
});
builder.addCase(controlAdapterProcessedImageChanged, (state, action) => {
const { id, processedControlImage } = action.payload;
const layer = state.layers.filter(isControlAdapterLayer).find((l) => l.controlNetId === id);
if (layer) {
layer.bbox = null;
layer.bboxNeedsUpdate = true;
layer.isEnabled = true;
layer.imageName = processedControlImage?.image_name ?? null;
}
});
// TODO: This is a temp fix to reduce issues with T2I adapter having a different downscaling
// factor than the UNet. Hopefully we get an upstream fix in diffusers.
builder.addMatcher(isAnyControlAdapterAdded, (state, action) => {
if (action.payload.type === 't2i_adapter') {
state.size.width = roundToMultiple(state.size.width, 64);
state.size.height = roundToMultiple(state.size.height, 64);
}
});
// // TODO: This is a temp fix to reduce issues with T2I adapter having a different downscaling
// // factor than the UNet. Hopefully we get an upstream fix in diffusers.
// builder.addMatcher(isAnyControlAdapterAdded, (state, action) => {
// if (action.payload.type === 't2i_adapter') {
// state.size.width = roundToMultiple(state.size.width, 64);
// state.size.height = roundToMultiple(state.size.height, 64);
// }
// });
},
});
@ -529,22 +663,22 @@ export const {
layerVisibilityToggled,
selectedLayerReset,
selectedLayerDeleted,
regionalGuidanceLayerAdded,
ipAdapterLayerAdded,
controlAdapterLayerAdded,
rgLayerAdded: regionalGuidanceLayerAdded,
ipaLayerAdded: ipAdapterLayerAdded,
caLayerAdded: controlAdapterLayerAdded,
layerOpacityChanged,
// CA layer actions
isFilterEnabledChanged,
caLayerIsFilterEnabledChanged: isFilterEnabledChanged,
// Mask layer actions
maskLayerLineAdded,
maskLayerPointsAdded,
maskLayerRectAdded,
maskLayerNegativePromptChanged,
maskLayerPositivePromptChanged,
maskLayerIPAdapterAdded,
maskLayerIPAdapterDeleted,
maskLayerAutoNegativeChanged,
maskLayerPreviewColorChanged,
rgLayerLineAdded: maskLayerLineAdded,
rgLayerPointsAdded: maskLayerPointsAdded,
rgLayerRectAdded: maskLayerRectAdded,
rgLayerNegativePromptChanged: maskLayerNegativePromptChanged,
rgLayerPositivePromptChanged: maskLayerPositivePromptChanged,
rgLayerIPAdapterAdded: maskLayerIPAdapterAdded,
rgLayerIPAdapterDeleted: maskLayerIPAdapterDeleted,
rgLayerAutoNegativeChanged: maskLayerAutoNegativeChanged,
rgLayerPreviewColorChanged: maskLayerPreviewColorChanged,
// Base layer actions
positivePromptChanged,
negativePromptChanged,
@ -561,20 +695,6 @@ export const {
redo,
} = controlLayersSlice.actions;
export const selectAllControlAdapterIds = (controlLayers: ControlLayersState) =>
controlLayers.layers.flatMap((l) => {
if (l.type === 'control_adapter_layer') {
return [l.controlNetId];
}
if (l.type === 'ip_adapter_layer') {
return [l.ipAdapterId];
}
if (l.type === 'regional_guidance_layer') {
return l.ipAdapterIds;
}
return [];
});
export const selectControlLayersSlice = (state: RootState) => state.controlLayers;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
@ -600,24 +720,23 @@ export const BACKGROUND_RECT_ID = 'background_layer.rect';
export const NO_LAYERS_MESSAGE_LAYER_ID = 'no_layers_message';
// Names (aka classes) for Konva layers and objects
export const CONTROLNET_LAYER_NAME = 'control_adapter_layer';
export const CONTROLNET_LAYER_IMAGE_NAME = 'control_adapter_layer.image';
export const regional_guidance_layer_NAME = 'regional_guidance_layer';
export const regional_guidance_layer_LINE_NAME = 'regional_guidance_layer.line';
export const regional_guidance_layer_OBJECT_GROUP_NAME = 'regional_guidance_layer.object_group';
export const regional_guidance_layer_RECT_NAME = 'regional_guidance_layer.rect';
export const CA_LAYER_NAME = 'control_adapter_layer';
export const CA_LAYER_IMAGE_NAME = 'control_adapter_layer.image';
export const RG_LAYER_NAME = 'regional_guidance_layer';
export const RG_LAYER_LINE_NAME = 'regional_guidance_layer.line';
export const RG_LAYER_OBJECT_GROUP_NAME = 'regional_guidance_layer.object_group';
export const RG_LAYER_RECT_NAME = 'regional_guidance_layer.rect';
export const LAYER_BBOX_NAME = 'layer.bbox';
// Getters for non-singleton layer and object IDs
const getRegionalGuidanceLayerId = (layerId: string) => `${regional_guidance_layer_NAME}_${layerId}`;
const getRegionalGuidanceLayerLineId = (layerId: string, lineId: string) => `${layerId}.line_${lineId}`;
const getMaskedGuidnaceLayerRectId = (layerId: string, lineId: string) => `${layerId}.rect_${lineId}`;
export const getRegionalGuidanceLayerObjectGroupId = (layerId: string, groupId: string) =>
`${layerId}.objectGroup_${groupId}`;
const getRGLayerId = (layerId: string) => `${RG_LAYER_NAME}_${layerId}`;
const getRGLayerLineId = (layerId: string, lineId: string) => `${layerId}.line_${lineId}`;
const getRGLayerRectId = (layerId: string, lineId: string) => `${layerId}.rect_${lineId}`;
export const getRGLayerObjectGroupId = (layerId: string, groupId: string) => `${layerId}.objectGroup_${groupId}`;
export const getLayerBboxId = (layerId: string) => `${layerId}.bbox`;
const getControlNetLayerId = (layerId: string) => `control_adapter_layer_${layerId}`;
export const getControlNetLayerImageId = (layerId: string, imageName: string) => `${layerId}.image_${imageName}`;
const getIPAdapterLayerId = (layerId: string) => `ip_adapter_layer_${layerId}`;
const getCALayerId = (layerId: string) => `control_adapter_layer_${layerId}`;
export const getCALayerImageId = (layerId: string, imageName: string) => `${layerId}.image_${imageName}`;
const getIPALayerId = (layerId: string) => `ip_adapter_layer_${layerId}`;
export const controlLayersPersistConfig: PersistConfig<ControlLayersState> = {
name: controlLayersSlice.name,

View File

@ -1,3 +1,4 @@
import type { ControlNetConfig, IPAdapterConfig,T2IAdapterConfig } from 'features/controlLayers/util/controlAdapters';
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
import type {
ParameterAutoNegative,
@ -47,15 +48,14 @@ type RenderableLayerBase = LayerBase & {
export type ControlAdapterLayer = RenderableLayerBase & {
type: 'control_adapter_layer'; // technically, also t2i adapter layer
controlNetId: string;
imageName: string | null;
opacity: number;
isFilterEnabled: boolean;
controlAdapter: ControlNetConfig | T2IAdapterConfig;
};
export type IPAdapterLayer = LayerBase & {
type: 'ip_adapter_layer'; // technically, also t2i adapter layer
ipAdapterId: string;
type: 'ip_adapter_layer';
ipAdapter: IPAdapterConfig;
};
export type RegionalGuidanceLayer = RenderableLayerBase & {
@ -63,7 +63,7 @@ export type RegionalGuidanceLayer = RenderableLayerBase & {
maskObjects: (VectorMaskLine | VectorMaskRect)[];
positivePrompt: ParameterPositivePrompt | null;
negativePrompt: ParameterNegativePrompt | null; // Up to one text prompt per mask
ipAdapterIds: string[]; // Any number of image prompts
ipAdapters: IPAdapterConfig[]; // Any number of image prompts
previewColor: RgbColor;
autoNegative: ParameterAutoNegative;
needsPixelBbox: boolean; // Needs the slower pixel-based bbox calculation - set to true when an there is an eraser object
@ -83,7 +83,6 @@ export type ControlLayersState = {
positivePrompt2: ParameterPositiveStylePromptSDXL;
negativePrompt2: ParameterNegativeStylePromptSDXL;
shouldConcatPrompts: boolean;
initialImage: string | null;
size: {
width: ParameterWidth;
height: ParameterHeight;

View File

@ -1,6 +1,6 @@
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { imageDataToDataURL } from 'features/canvas/util/blobToDataURL';
import { regional_guidance_layer_OBJECT_GROUP_NAME } from 'features/controlLayers/store/controlLayersSlice';
import { RG_LAYER_OBJECT_GROUP_NAME } from 'features/controlLayers/store/controlLayersSlice';
import Konva from 'konva';
import type { Layer as KonvaLayerType } from 'konva/lib/Layer';
import type { IRect } from 'konva/lib/types';
@ -81,7 +81,7 @@ export const getLayerBboxPixels = (layer: KonvaLayerType, preview: boolean = fal
offscreenStage.add(layerClone);
for (const child of layerClone.getChildren()) {
if (child.name() === regional_guidance_layer_OBJECT_GROUP_NAME) {
if (child.name() === RG_LAYER_OBJECT_GROUP_NAME) {
// We need to cache the group to ensure it composites out eraser strokes correctly
child.opacity(1);
child.cache();

View File

@ -0,0 +1,354 @@
import { deepClone } from 'common/util/deepClone';
import type {
ParameterControlNetModel,
ParameterIPAdapterModel,
ParameterT2IAdapterModel,
} from 'features/parameters/types/parameterSchemas';
import { merge } from 'lodash-es';
import type {
BaseModelType,
CannyImageProcessorInvocation,
ColorMapImageProcessorInvocation,
ContentShuffleImageProcessorInvocation,
ControlNetInvocation,
ControlNetModelConfig,
DepthAnythingImageProcessorInvocation,
DWOpenposeImageProcessorInvocation,
HedImageProcessorInvocation,
ImageDTO,
LineartAnimeImageProcessorInvocation,
LineartImageProcessorInvocation,
MediapipeFaceProcessorInvocation,
MidasDepthImageProcessorInvocation,
MlsdImageProcessorInvocation,
NormalbaeImageProcessorInvocation,
PidiImageProcessorInvocation,
T2IAdapterModelConfig,
ZoeDepthImageProcessorInvocation,
} from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
import { z } from 'zod';
const zDepthAnythingModelSize = z.enum(['large', 'base', 'small']);
export type DepthAnythingModelSize = z.infer<typeof zDepthAnythingModelSize>;
export const isDepthAnythingModelSize = (v: unknown): v is DepthAnythingModelSize =>
zDepthAnythingModelSize.safeParse(v).success;
export type CannyProcessorConfig = Required<
Pick<CannyImageProcessorInvocation, 'type' | 'low_threshold' | 'high_threshold'>
>;
export type ColorMapProcessorConfig = Required<Pick<ColorMapImageProcessorInvocation, 'type' | 'color_map_tile_size'>>;
export type ContentShuffleProcessorConfig = Required<
Pick<ContentShuffleImageProcessorInvocation, 'type' | 'w' | 'h' | 'f'>
>;
export type DepthAnythingProcessorConfig = Required<Pick<DepthAnythingImageProcessorInvocation, 'type' | 'model_size'>>;
export type HedProcessorConfig = Required<Pick<HedImageProcessorInvocation, 'type' | 'scribble'>>;
export type LineartAnimeProcessorConfig = Required<Pick<LineartAnimeImageProcessorInvocation, 'type'>>;
export type LineartProcessorConfig = Required<Pick<LineartImageProcessorInvocation, 'type' | 'coarse'>>;
export type MediapipeFaceProcessorConfig = Required<
Pick<MediapipeFaceProcessorInvocation, 'type' | 'max_faces' | 'min_confidence'>
>;
export type MidasDepthProcessorConfig = Required<Pick<MidasDepthImageProcessorInvocation, 'type' | 'a_mult' | 'bg_th'>>;
export type MlsdProcessorConfig = Required<Pick<MlsdImageProcessorInvocation, 'type' | 'thr_v' | 'thr_d'>>;
export type NormalbaeProcessorConfig = Required<Pick<NormalbaeImageProcessorInvocation, 'type'>>;
export type DWOpenposeProcessorConfig = Required<
Pick<DWOpenposeImageProcessorInvocation, 'type' | 'draw_body' | 'draw_face' | 'draw_hands'>
>;
export type PidiProcessorConfig = Required<Pick<PidiImageProcessorInvocation, 'type' | 'safe' | 'scribble'>>;
export type ZoeDepthProcessorConfig = Required<Pick<ZoeDepthImageProcessorInvocation, 'type'>>;
export type ProcessorConfig =
| CannyProcessorConfig
| ColorMapProcessorConfig
| ContentShuffleProcessorConfig
| DepthAnythingProcessorConfig
| HedProcessorConfig
| LineartAnimeProcessorConfig
| LineartProcessorConfig
| MediapipeFaceProcessorConfig
| MidasDepthProcessorConfig
| MlsdProcessorConfig
| NormalbaeProcessorConfig
| DWOpenposeProcessorConfig
| PidiProcessorConfig
| ZoeDepthProcessorConfig;
type ImageWithDims = {
imageName: string;
width: number;
height: number;
};
type ControlAdapterBase = {
id: string;
isEnabled: boolean;
weight: number;
image: ImageWithDims | null;
processedImage: ImageWithDims | null;
processorConfig: ProcessorConfig | null;
beginEndStepPct: [number, number];
};
export type ControlMode = NonNullable<ControlNetInvocation['control_mode']>;
export type ControlNetConfig = ControlAdapterBase & {
type: 'controlnet';
model: ParameterControlNetModel | null;
controlMode: ControlMode;
};
export type T2IAdapterConfig = ControlAdapterBase & {
type: 't2i_adapter';
model: ParameterT2IAdapterModel | null;
};
export type CLIPVisionModel = 'ViT-H' | 'ViT-G';
const zIPMethod = z.enum(['full', 'style', 'composition']);
export type IPMethod = z.infer<typeof zIPMethod>;
export const isIPMethod = (v: unknown): v is IPMethod => zIPMethod.safeParse(v).success;
export type IPAdapterConfig = {
id: string;
type: 'ip_adapter';
isEnabled: boolean;
weight: number;
method: IPMethod;
image: ImageWithDims | null;
model: ParameterIPAdapterModel | null;
clipVisionModel: CLIPVisionModel;
beginEndStepPct: [number, number];
};
type ProcessorData<T extends ProcessorConfig['type']> = {
labelTKey: string;
descriptionTKey: string;
buildDefaults(baseModel?: BaseModelType): Extract<ProcessorConfig, { type: T }>;
};
type ControlNetProcessorsDict = {
[key in ProcessorConfig['type']]: ProcessorData<key>;
};
/**
* A dict of ControlNet processors, including:
* - label translation key
* - description translation key
* - a builder to create default values for the config
*
* TODO: Generate from the OpenAPI schema
*/
export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
canny_image_processor: {
labelTKey: 'controlnet.canny',
descriptionTKey: 'controlnet.cannyDescription',
buildDefaults: () => ({
id: `canny_image_processor_${uuidv4()}`,
type: 'canny_image_processor',
low_threshold: 100,
high_threshold: 200,
}),
},
color_map_image_processor: {
labelTKey: 'controlnet.colorMap',
descriptionTKey: 'controlnet.colorMapDescription',
buildDefaults: () => ({
id: `color_map_image_processor_${uuidv4()}`,
type: 'color_map_image_processor',
color_map_tile_size: 64,
}),
},
content_shuffle_image_processor: {
labelTKey: 'controlnet.contentShuffle',
descriptionTKey: 'controlnet.contentShuffleDescription',
buildDefaults: (baseModel) => ({
id: `content_shuffle_image_processor_${uuidv4()}`,
type: 'content_shuffle_image_processor',
h: baseModel === 'sdxl' ? 1024 : 512,
w: baseModel === 'sdxl' ? 1024 : 512,
f: baseModel === 'sdxl' ? 512 : 256,
}),
},
depth_anything_image_processor: {
labelTKey: 'controlnet.depthAnything',
descriptionTKey: 'controlnet.depthAnythingDescription',
buildDefaults: () => ({
id: `depth_anything_image_processor_${uuidv4()}`,
type: 'depth_anything_image_processor',
model_size: 'small',
}),
},
hed_image_processor: {
labelTKey: 'controlnet.hed',
descriptionTKey: 'controlnet.hedDescription',
buildDefaults: () => ({
id: `hed_image_processor_${uuidv4()}`,
type: 'hed_image_processor',
scribble: false,
}),
},
lineart_anime_image_processor: {
labelTKey: 'controlnet.lineartAnime',
descriptionTKey: 'controlnet.lineartAnimeDescription',
buildDefaults: () => ({
id: `lineart_anime_image_processor_${uuidv4()}`,
type: 'lineart_anime_image_processor',
}),
},
lineart_image_processor: {
labelTKey: 'controlnet.lineart',
descriptionTKey: 'controlnet.lineartDescription',
buildDefaults: () => ({
id: `lineart_image_processor_${uuidv4()}`,
type: 'lineart_image_processor',
coarse: false,
}),
},
mediapipe_face_processor: {
labelTKey: 'controlnet.mediapipeFace',
descriptionTKey: 'controlnet.mediapipeFaceDescription',
buildDefaults: () => ({
id: `mediapipe_face_processor_${uuidv4()}`,
type: 'mediapipe_face_processor',
max_faces: 1,
min_confidence: 0.5,
}),
},
midas_depth_image_processor: {
labelTKey: 'controlnet.depthMidas',
descriptionTKey: 'controlnet.depthMidasDescription',
buildDefaults: () => ({
id: `midas_depth_image_processor_${uuidv4()}`,
type: 'midas_depth_image_processor',
a_mult: 2,
bg_th: 0.1,
}),
},
mlsd_image_processor: {
labelTKey: 'controlnet.mlsd',
descriptionTKey: 'controlnet.mlsdDescription',
buildDefaults: () => ({
id: `mlsd_image_processor_${uuidv4()}`,
type: 'mlsd_image_processor',
thr_d: 0.1,
thr_v: 0.1,
}),
},
normalbae_image_processor: {
labelTKey: 'controlnet.normalBae',
descriptionTKey: 'controlnet.normalBaeDescription',
buildDefaults: () => ({
id: `normalbae_image_processor_${uuidv4()}`,
type: 'normalbae_image_processor',
}),
},
dw_openpose_image_processor: {
labelTKey: 'controlnet.dwOpenpose',
descriptionTKey: 'controlnet.dwOpenposeDescription',
buildDefaults: () => ({
id: `dw_openpose_image_processor_${uuidv4()}`,
type: 'dw_openpose_image_processor',
draw_body: true,
draw_face: false,
draw_hands: false,
}),
},
pidi_image_processor: {
labelTKey: 'controlnet.pidi',
descriptionTKey: 'controlnet.pidiDescription',
buildDefaults: () => ({
id: `pidi_image_processor_${uuidv4()}`,
type: 'pidi_image_processor',
scribble: false,
safe: false,
}),
},
zoe_depth_image_processor: {
labelTKey: 'controlnet.depthZoe',
descriptionTKey: 'controlnet.depthZoeDescription',
buildDefaults: () => ({
id: `zoe_depth_image_processor_${uuidv4()}`,
type: 'zoe_depth_image_processor',
}),
},
};
export const zProcessorType = z.enum([
'canny_image_processor',
'color_map_image_processor',
'content_shuffle_image_processor',
'depth_anything_image_processor',
'hed_image_processor',
'lineart_anime_image_processor',
'lineart_image_processor',
'mediapipe_face_processor',
'midas_depth_image_processor',
'mlsd_image_processor',
'normalbae_image_processor',
'dw_openpose_image_processor',
'pidi_image_processor',
'zoe_depth_image_processor',
]);
export type ProcessorType = z.infer<typeof zProcessorType>;
export const isControlAdapterProcessorType = (v: unknown): v is ProcessorType => zProcessorType.safeParse(v).success;
export const initialControlNet: Omit<ControlNetConfig, 'id'> = {
type: 'controlnet',
isEnabled: true,
model: null,
weight: 1,
beginEndStepPct: [0, 0],
controlMode: 'balanced',
image: null,
processedImage: null,
processorConfig: CONTROLNET_PROCESSORS.canny_image_processor.buildDefaults(),
};
export const initialT2IAdapter: Omit<T2IAdapterConfig, 'id'> = {
type: 't2i_adapter',
isEnabled: true,
model: null,
weight: 1,
beginEndStepPct: [0, 0],
image: null,
processedImage: null,
processorConfig: CONTROLNET_PROCESSORS.canny_image_processor.buildDefaults(),
};
export const initialIPAdapter: Omit<IPAdapterConfig, 'id'> = {
type: 'ip_adapter',
isEnabled: true,
image: null,
model: null,
beginEndStepPct: [0, 0],
method: 'full',
clipVisionModel: 'ViT-H',
weight: 1,
};
export const buildControlNet = (id: string, overrides?: Partial<ControlNetConfig>): ControlNetConfig => {
return merge(deepClone(initialControlNet), { id, overrides });
};
export const buildT2IAdapter = (id: string, overrides?: Partial<T2IAdapterConfig>): T2IAdapterConfig => {
return merge(deepClone(initialT2IAdapter), { id, overrides });
};
export const buildIPAdapter = (id: string, overrides?: Partial<IPAdapterConfig>): IPAdapterConfig => {
return merge(deepClone(initialIPAdapter), { id, overrides });
};
export const buildControlAdapterProcessor = (
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig
): ProcessorConfig | null => {
const defaultPreprocessor = modelConfig.default_settings?.preprocessor;
if (!isControlAdapterProcessorType(defaultPreprocessor)) {
return null;
}
const processorConfig = CONTROLNET_PROCESSORS[defaultPreprocessor].buildDefaults(modelConfig.base);
return processorConfig;
};
export const imageDTOToImageWithDims = ({ image_name, width, height }: ImageDTO): ImageWithDims => ({
imageName: image_name,
width,
height,
});

View File

@ -1,7 +1,7 @@
import { getStore } from 'app/store/nanostores/store';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
import { isRegionalGuidanceLayer, regional_guidance_layer_NAME } from 'features/controlLayers/store/controlLayersSlice';
import { isRegionalGuidanceLayer, RG_LAYER_NAME } from 'features/controlLayers/store/controlLayersSlice';
import { renderers } from 'features/controlLayers/util/renderers';
import Konva from 'konva';
import { assert } from 'tsafe';
@ -24,7 +24,7 @@ export const getRegionalPromptLayerBlobs = async (
const stage = new Konva.Stage({ container, width, height });
renderers.renderLayers(stage, reduxLayers, 1, 'brush');
const konvaLayers = stage.find<Konva.Layer>(`.${regional_guidance_layer_NAME}`);
const konvaLayers = stage.find<Konva.Layer>(`.${RG_LAYER_NAME}`);
const blobs: Record<string, Blob> = {};
// First remove all layers

View File

@ -5,20 +5,20 @@ import {
$tool,
BACKGROUND_LAYER_ID,
BACKGROUND_RECT_ID,
CONTROLNET_LAYER_IMAGE_NAME,
CONTROLNET_LAYER_NAME,
getControlNetLayerImageId,
CA_LAYER_IMAGE_NAME,
CA_LAYER_NAME,
getCALayerImageId,
getLayerBboxId,
getRegionalGuidanceLayerObjectGroupId,
getRGLayerObjectGroupId,
isControlAdapterLayer,
isRegionalGuidanceLayer,
isRenderableLayer,
LAYER_BBOX_NAME,
NO_LAYERS_MESSAGE_LAYER_ID,
regional_guidance_layer_LINE_NAME,
regional_guidance_layer_NAME,
regional_guidance_layer_OBJECT_GROUP_NAME,
regional_guidance_layer_RECT_NAME,
RG_LAYER_LINE_NAME,
RG_LAYER_NAME,
RG_LAYER_OBJECT_GROUP_NAME,
RG_LAYER_RECT_NAME,
TOOL_PREVIEW_BRUSH_BORDER_INNER_ID,
TOOL_PREVIEW_BRUSH_BORDER_OUTER_ID,
TOOL_PREVIEW_BRUSH_FILL_ID,
@ -53,10 +53,10 @@ const STAGE_BG_DATAURL =
const mapId = (object: { id: string }) => object.id;
const selectRenderableLayers = (n: Konva.Node) =>
n.name() === regional_guidance_layer_NAME || n.name() === CONTROLNET_LAYER_NAME;
n.name() === RG_LAYER_NAME || n.name() === CA_LAYER_NAME;
const selectVectorMaskObjects = (node: Konva.Node) => {
return node.name() === regional_guidance_layer_LINE_NAME || node.name() === regional_guidance_layer_RECT_NAME;
return node.name() === RG_LAYER_LINE_NAME || node.name() === RG_LAYER_RECT_NAME;
};
/**
@ -141,7 +141,7 @@ const renderToolPreview = (
isMouseOver: boolean,
brushSize: number
) => {
const layerCount = stage.find(`.${regional_guidance_layer_NAME}`).length;
const layerCount = stage.find(`.${RG_LAYER_NAME}`).length;
// Update the stage's pointer style
if (layerCount === 0) {
// We have no layers, so we should not render any tool
@ -233,7 +233,7 @@ const createRegionalGuidanceLayer = (
// This layer hasn't been added to the konva state yet
const konvaLayer = new Konva.Layer({
id: reduxLayer.id,
name: regional_guidance_layer_NAME,
name: RG_LAYER_NAME,
draggable: true,
dragDistance: 0,
});
@ -265,8 +265,8 @@ const createRegionalGuidanceLayer = (
// The object group holds all of the layer's objects (e.g. lines and rects)
const konvaObjectGroup = new Konva.Group({
id: getRegionalGuidanceLayerObjectGroupId(reduxLayer.id, uuidv4()),
name: regional_guidance_layer_OBJECT_GROUP_NAME,
id: getRGLayerObjectGroupId(reduxLayer.id, uuidv4()),
name: RG_LAYER_OBJECT_GROUP_NAME,
listening: false,
});
konvaLayer.add(konvaObjectGroup);
@ -285,7 +285,7 @@ const createVectorMaskLine = (reduxObject: VectorMaskLine, konvaGroup: Konva.Gro
const vectorMaskLine = new Konva.Line({
id: reduxObject.id,
key: reduxObject.id,
name: regional_guidance_layer_LINE_NAME,
name: RG_LAYER_LINE_NAME,
strokeWidth: reduxObject.strokeWidth,
tension: 0,
lineCap: 'round',
@ -307,7 +307,7 @@ const createVectorMaskRect = (reduxObject: VectorMaskRect, konvaGroup: Konva.Gro
const vectorMaskRect = new Konva.Rect({
id: reduxObject.id,
key: reduxObject.id,
name: regional_guidance_layer_RECT_NAME,
name: RG_LAYER_RECT_NAME,
x: reduxObject.x,
y: reduxObject.y,
width: reduxObject.width,
@ -347,7 +347,7 @@ const renderRegionalGuidanceLayer = (
// Convert the color to a string, stripping the alpha - the object group will handle opacity.
const rgbColor = rgbColorToString(reduxLayer.previewColor);
const konvaObjectGroup = konvaLayer.findOne<Konva.Group>(`.${regional_guidance_layer_OBJECT_GROUP_NAME}`);
const konvaObjectGroup = konvaLayer.findOne<Konva.Group>(`.${RG_LAYER_OBJECT_GROUP_NAME}`);
assert(konvaObjectGroup, `Object group not found for layer ${reduxLayer.id}`);
// We use caching to handle "global" layer opacity, but caching is expensive and we should only do it when required.
@ -411,7 +411,7 @@ const renderRegionalGuidanceLayer = (
const createControlNetLayer = (stage: Konva.Stage, reduxLayer: ControlAdapterLayer): Konva.Layer => {
const konvaLayer = new Konva.Layer({
id: reduxLayer.id,
name: CONTROLNET_LAYER_NAME,
name: CA_LAYER_NAME,
imageSmoothingEnabled: true,
});
stage.add(konvaLayer);
@ -420,7 +420,7 @@ const createControlNetLayer = (stage: Konva.Stage, reduxLayer: ControlAdapterLay
const createControlNetLayerImage = (konvaLayer: Konva.Layer, image: HTMLImageElement): Konva.Image => {
const konvaImage = new Konva.Image({
name: CONTROLNET_LAYER_IMAGE_NAME,
name: CA_LAYER_IMAGE_NAME,
image,
});
konvaLayer.add(konvaImage);
@ -438,11 +438,11 @@ const updateControlNetLayerImageSource = async (
const imageDTO = await req.unwrap();
req.unsubscribe();
const image = new Image();
const imageId = getControlNetLayerImageId(reduxLayer.id, imageName);
const imageId = getCALayerImageId(reduxLayer.id, imageName);
image.onload = () => {
// Find the existing image or create a new one - must find using the name, bc the id may have just changed
const konvaImage =
konvaLayer.findOne<Konva.Image>(`.${CONTROLNET_LAYER_IMAGE_NAME}`) ??
konvaLayer.findOne<Konva.Image>(`.${CA_LAYER_IMAGE_NAME}`) ??
createControlNetLayerImage(konvaLayer, image);
// Update the image's attributes
@ -457,7 +457,7 @@ const updateControlNetLayerImageSource = async (
};
image.src = imageDTO.image_url;
} else {
konvaLayer.findOne(`.${CONTROLNET_LAYER_IMAGE_NAME}`)?.destroy();
konvaLayer.findOne(`.${CA_LAYER_IMAGE_NAME}`)?.destroy();
}
};
@ -497,13 +497,13 @@ const updateControlNetLayerImageAttrs = (
const renderControlNetLayer = (stage: Konva.Stage, reduxLayer: ControlAdapterLayer) => {
const konvaLayer = stage.findOne<Konva.Layer>(`#${reduxLayer.id}`) ?? createControlNetLayer(stage, reduxLayer);
const konvaImage = konvaLayer.findOne<Konva.Image>(`.${CONTROLNET_LAYER_IMAGE_NAME}`);
const konvaImage = konvaLayer.findOne<Konva.Image>(`.${CA_LAYER_IMAGE_NAME}`);
const canvasImageSource = konvaImage?.image();
let imageSourceNeedsUpdate = false;
if (canvasImageSource instanceof HTMLImageElement) {
if (
reduxLayer.imageName &&
canvasImageSource.id !== getControlNetLayerImageId(reduxLayer.id, reduxLayer.imageName)
canvasImageSource.id !== getCALayerImageId(reduxLayer.id, reduxLayer.imageName)
) {
imageSourceNeedsUpdate = true;
} else if (!reduxLayer.imageName) {