WIP control adapters in regional

This commit is contained in:
psychedelicious 2024-04-29 23:31:33 +10:00 committed by Kent Keirsey
parent e822897b1c
commit ded8267505
45 changed files with 1689 additions and 337 deletions

View File

@ -156,6 +156,7 @@
"balanced": "Balanced", "balanced": "Balanced",
"base": "Base", "base": "Base",
"beginEndStepPercent": "Begin / End Step Percentage", "beginEndStepPercent": "Begin / End Step Percentage",
"beginEndStepPercentShort": "Begin/End %",
"bgth": "bg_th", "bgth": "bg_th",
"canny": "Canny", "canny": "Canny",
"cannyDescription": "Canny edge detection", "cannyDescription": "Canny edge detection",
@ -1531,6 +1532,10 @@
"maskPreviewColor": "Mask Preview Color", "maskPreviewColor": "Mask Preview Color",
"addPositivePrompt": "Add $t(common.positivePrompt)", "addPositivePrompt": "Add $t(common.positivePrompt)",
"addNegativePrompt": "Add $t(common.negativePrompt)", "addNegativePrompt": "Add $t(common.negativePrompt)",
"addIPAdapter": "Add $t(common.ipAdapter)" "addIPAdapter": "Add $t(common.ipAdapter)",
"maskedGuidance": "Masked Guidance",
"maskedGuidanceLayer": "$t(regionalPrompts.maskedGuidance) $t(unifiedCanvas.layer)",
"controlNetLayer": "$t(common.controlNet) $t(unifiedCanvas.layer)",
"ipAdapterLayer": "$t(common.ipAdapter) $t(unifiedCanvas.layer)"
} }
} }

View File

@ -35,6 +35,7 @@ import { addInitialImageSelectedListener } from 'app/store/middleware/listenerMi
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected'; import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded'; import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged'; import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged';
import { addRegionalControlToControlAdapterBridge } from 'app/store/middleware/listenerMiddleware/listeners/regionalControlToControlAdapterBridge';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected'; import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected'; import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress'; import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress';
@ -157,3 +158,5 @@ addUpscaleRequestedListener(startAppListening);
addDynamicPromptsListener(startAppListening); addDynamicPromptsListener(startAppListening);
addSetDefaultSettingsListener(startAppListening); addSetDefaultSettingsListener(startAppListening);
addRegionalControlToControlAdapterBridge(startAppListening);

View File

@ -48,12 +48,10 @@ export const addCanvasImageToControlNetListener = (startAppListening: AppStartLi
}) })
).unwrap(); ).unwrap();
const { image_name } = imageDTO;
dispatch( dispatch(
controlAdapterImageChanged({ controlAdapterImageChanged({
id, id,
controlImage: image_name, controlImage: imageDTO,
}) })
); );
}, },

View File

@ -58,12 +58,10 @@ export const addCanvasMaskToControlNetListener = (startAppListening: AppStartLis
}) })
).unwrap(); ).unwrap();
const { image_name } = imageDTO;
dispatch( dispatch(
controlAdapterImageChanged({ controlAdapterImageChanged({
id, id,
controlImage: image_name, controlImage: imageDTO,
}) })
); );
}, },

View File

@ -91,7 +91,7 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL
dispatch( dispatch(
controlAdapterProcessedImageChanged({ controlAdapterProcessedImageChanged({
id, id,
processedControlImage: processedControlImage.image_name, processedControlImage,
}) })
); );
} }

View File

@ -71,7 +71,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
dispatch( dispatch(
controlAdapterImageChanged({ controlAdapterImageChanged({
id, id,
controlImage: activeData.payload.imageDTO.image_name, controlImage: activeData.payload.imageDTO,
}) })
); );
dispatch( dispatch(

View File

@ -96,7 +96,7 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
dispatch( dispatch(
controlAdapterImageChanged({ controlAdapterImageChanged({
id, id,
controlImage: imageDTO.image_name, controlImage: imageDTO,
}) })
); );
dispatch( dispatch(

View File

@ -0,0 +1,93 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { controlAdapterAdded, controlAdapterRemoved } from 'features/controlAdapters/store/controlAdaptersSlice';
import {
controlAdapterLayerAdded,
ipAdapterLayerAdded,
layerDeleted,
maskedGuidanceLayerAdded,
maskLayerIPAdapterAdded,
maskLayerIPAdapterDeleted,
} from 'features/regionalPrompts/store/regionalPromptsSlice';
import type { Layer } from 'features/regionalPrompts/store/types';
import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid';
export const guidanceLayerAdded = createAction<Layer['type']>('regionalPrompts/guidanceLayerAdded');
export const guidanceLayerDeleted = createAction<string>('regionalPrompts/guidanceLayerDeleted');
export const allLayersDeleted = createAction('regionalPrompts/allLayersDeleted');
export const guidanceLayerIPAdapterAdded = createAction<string>('regionalPrompts/guidanceLayerIPAdapterAdded');
export const guidanceLayerIPAdapterDeleted = createAction<{ layerId: string; ipAdapterId: string }>(
'regionalPrompts/guidanceLayerIPAdapterDeleted'
);
export const addRegionalControlToControlAdapterBridge = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: guidanceLayerAdded,
effect: (action, { dispatch }) => {
const type = action.payload;
const layerId = uuidv4();
if (type === 'ip_adapter_layer') {
const ipAdapterId = uuidv4();
dispatch(controlAdapterAdded({ type: 'ip_adapter', overrides: { id: ipAdapterId } }));
dispatch(ipAdapterLayerAdded({ layerId, ipAdapterId }));
} else if (type === 'control_adapter_layer') {
const controlNetId = uuidv4();
dispatch(controlAdapterAdded({ type: 'controlnet', overrides: { id: controlNetId } }));
dispatch(controlAdapterLayerAdded({ layerId, controlNetId }));
} else if (type === 'masked_guidance_layer') {
dispatch(maskedGuidanceLayerAdded({ layerId }));
}
},
});
startAppListening({
actionCreator: guidanceLayerDeleted,
effect: (action, { getState, dispatch }) => {
const layerId = action.payload;
const state = getState();
const layer = state.regionalPrompts.present.layers.find((l) => l.id === layerId);
assert(layer, `Layer ${layerId} not found`);
if (layer.type === 'ip_adapter_layer') {
dispatch(controlAdapterRemoved({ id: layer.ipAdapterId }));
} else if (layer.type === 'control_adapter_layer') {
dispatch(controlAdapterRemoved({ id: layer.controlNetId }));
} else if (layer.type === 'masked_guidance_layer') {
for (const ipAdapterId of layer.ipAdapterIds) {
dispatch(controlAdapterRemoved({ id: ipAdapterId }));
}
}
dispatch(layerDeleted(layerId));
},
});
startAppListening({
actionCreator: allLayersDeleted,
effect: (action, { dispatch, getOriginalState }) => {
const state = getOriginalState();
for (const layer of state.regionalPrompts.present.layers) {
dispatch(guidanceLayerDeleted(layer.id));
}
},
});
startAppListening({
actionCreator: guidanceLayerIPAdapterAdded,
effect: (action, { dispatch }) => {
const layerId = action.payload;
const ipAdapterId = uuidv4();
dispatch(controlAdapterAdded({ type: 'ip_adapter', overrides: { id: ipAdapterId } }));
dispatch(maskLayerIPAdapterAdded({ layerId, ipAdapterId }));
},
});
startAppListening({
actionCreator: guidanceLayerIPAdapterDeleted,
effect: (action, { dispatch }) => {
const { layerId, ipAdapterId } = action.payload;
dispatch(controlAdapterRemoved({ id: ipAdapterId }));
dispatch(maskLayerIPAdapterDeleted({ layerId, ipAdapterId }));
},
});
};

View File

@ -6,9 +6,8 @@ import { deepClone } from 'common/util/deepClone';
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter'; import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor'; import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
import { zModelIdentifierField } from 'features/nodes/types/common'; import { zModelIdentifierField } from 'features/nodes/types/common';
import { maskLayerIPAdapterAdded } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { merge, uniq } from 'lodash-es'; import { merge, uniq } from 'lodash-es';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types'; import type { ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { socketInvocationError } from 'services/events/actions'; import { socketInvocationError } from 'services/events/actions';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
@ -135,23 +134,46 @@ export const controlAdaptersSlice = createSlice({
const { id, isEnabled } = action.payload; const { id, isEnabled } = action.payload;
caAdapter.updateOne(state, { id, changes: { isEnabled } }); caAdapter.updateOne(state, { id, changes: { isEnabled } });
}, },
controlAdapterImageChanged: ( controlAdapterImageChanged: (state, action: PayloadAction<{ id: string; controlImage: ImageDTO | null }>) => {
state,
action: PayloadAction<{
id: string;
controlImage: string | null;
}>
) => {
const { id, controlImage } = action.payload; const { id, controlImage } = action.payload;
const ca = selectControlAdapterById(state, id); const ca = selectControlAdapterById(state, id);
if (!ca) { if (!ca) {
return; return;
} }
caAdapter.updateOne(state, { if (isControlNetOrT2IAdapter(ca)) {
id, if (controlImage) {
changes: { controlImage, processedControlImage: null }, const { image_name, width, height } = controlImage;
}); const processorNode = deepClone(ca.processorNode);
const minDim = Math.min(controlImage.width, controlImage.height);
if ('detect_resolution' in processorNode) {
processorNode.detect_resolution = minDim;
}
if ('image_resolution' in processorNode) {
processorNode.image_resolution = minDim;
}
if ('resolution' in processorNode) {
processorNode.resolution = minDim;
}
caAdapter.updateOne(state, {
id,
changes: {
processorNode,
controlImage: image_name,
controlImageDimensions: { width, height },
processedControlImage: null,
},
});
} else {
caAdapter.updateOne(state, {
id,
changes: { controlImage: null, controlImageDimensions: null, processedControlImage: null },
});
}
} else {
// ip adapter
caAdapter.updateOne(state, { id, changes: { controlImage: controlImage?.image_name ?? null } });
}
if (controlImage !== null && isControlNetOrT2IAdapter(ca) && ca.processorType !== 'none') { if (controlImage !== null && isControlNetOrT2IAdapter(ca) && ca.processorType !== 'none') {
state.pendingControlImages.push(id); state.pendingControlImages.push(id);
@ -161,7 +183,7 @@ export const controlAdaptersSlice = createSlice({
state, state,
action: PayloadAction<{ action: PayloadAction<{
id: string; id: string;
processedControlImage: string | null; processedControlImage: ImageDTO | null;
}> }>
) => { ) => {
const { id, processedControlImage } = action.payload; const { id, processedControlImage } = action.payload;
@ -174,12 +196,24 @@ export const controlAdaptersSlice = createSlice({
return; return;
} }
caAdapter.updateOne(state, { if (processedControlImage) {
id, const { image_name, width, height } = processedControlImage;
changes: { caAdapter.updateOne(state, {
processedControlImage, id,
}, changes: {
}); processedControlImage: image_name,
processedControlImageDimensions: { width, height },
},
});
} else {
caAdapter.updateOne(state, {
id,
changes: {
processedControlImage: null,
processedControlImageDimensions: null,
},
});
}
state.pendingControlImages = state.pendingControlImages.filter((pendingId) => pendingId !== id); state.pendingControlImages = state.pendingControlImages.filter((pendingId) => pendingId !== id);
}, },
@ -222,9 +256,22 @@ export const controlAdaptersSlice = createSlice({
} }
const processor = buildControlAdapterProcessor(modelConfig); const processor = buildControlAdapterProcessor(modelConfig);
update.changes.processorType = processor.processorType; if (processor.processorType !== cn.processorNode.type) {
update.changes.processorNode = processor.processorNode; update.changes.processorType = processor.processorType;
update.changes.processorNode = processor.processorNode;
if (cn.controlImageDimensions) {
const minDim = Math.min(cn.controlImageDimensions.width, cn.controlImageDimensions.height);
if ('detect_resolution' in update.changes.processorNode) {
update.changes.processorNode.detect_resolution = minDim;
}
if ('image_resolution' in update.changes.processorNode) {
update.changes.processorNode.image_resolution = minDim;
}
if ('resolution' in update.changes.processorNode) {
update.changes.processorNode.resolution = minDim;
}
}
}
caAdapter.updateOne(state, update); caAdapter.updateOne(state, update);
}, },
controlAdapterWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => { controlAdapterWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => {
@ -341,8 +388,23 @@ export const controlAdaptersSlice = createSlice({
if (update.changes.shouldAutoConfig && modelConfig) { if (update.changes.shouldAutoConfig && modelConfig) {
const processor = buildControlAdapterProcessor(modelConfig); const processor = buildControlAdapterProcessor(modelConfig);
update.changes.processorType = processor.processorType; if (processor.processorType !== cn.processorNode.type) {
update.changes.processorNode = processor.processorNode; update.changes.processorType = processor.processorType;
update.changes.processorNode = processor.processorNode;
// Copy image resolution settings, urgh
if (cn.controlImageDimensions) {
const minDim = Math.min(cn.controlImageDimensions.width, cn.controlImageDimensions.height);
if ('detect_resolution' in update.changes.processorNode) {
update.changes.processorNode.detect_resolution = minDim;
}
if ('image_resolution' in update.changes.processorNode) {
update.changes.processorNode.image_resolution = minDim;
}
if ('resolution' in update.changes.processorNode) {
update.changes.processorNode.resolution = minDim;
}
}
}
} }
caAdapter.updateOne(state, update); caAdapter.updateOne(state, update);
@ -383,10 +445,6 @@ export const controlAdaptersSlice = createSlice({
builder.addCase(socketInvocationError, (state) => { builder.addCase(socketInvocationError, (state) => {
state.pendingControlImages = []; state.pendingControlImages = [];
}); });
builder.addCase(maskLayerIPAdapterAdded, (state, action) => {
caAdapter.addOne(state, buildControlAdapter(action.meta.uuid, 'ip_adapter'));
});
}, },
}); });

View File

@ -225,7 +225,9 @@ export type ControlNetConfig = {
controlMode: ControlMode; controlMode: ControlMode;
resizeMode: ResizeMode; resizeMode: ResizeMode;
controlImage: string | null; controlImage: string | null;
controlImageDimensions: { width: number; height: number } | null;
processedControlImage: string | null; processedControlImage: string | null;
processedControlImageDimensions: { width: number; height: number } | null;
processorType: ControlAdapterProcessorType; processorType: ControlAdapterProcessorType;
processorNode: RequiredControlAdapterProcessorNode; processorNode: RequiredControlAdapterProcessorNode;
shouldAutoConfig: boolean; shouldAutoConfig: boolean;
@ -241,7 +243,9 @@ export type T2IAdapterConfig = {
endStepPct: number; endStepPct: number;
resizeMode: ResizeMode; resizeMode: ResizeMode;
controlImage: string | null; controlImage: string | null;
controlImageDimensions: { width: number; height: number } | null;
processedControlImage: string | null; processedControlImage: string | null;
processedControlImageDimensions: { width: number; height: number } | null;
processorType: ControlAdapterProcessorType; processorType: ControlAdapterProcessorType;
processorNode: RequiredControlAdapterProcessorNode; processorNode: RequiredControlAdapterProcessorNode;
shouldAutoConfig: boolean; shouldAutoConfig: boolean;

View File

@ -20,7 +20,9 @@ export const initialControlNet: Omit<ControlNetConfig, 'id'> = {
controlMode: 'balanced', controlMode: 'balanced',
resizeMode: 'just_resize', resizeMode: 'just_resize',
controlImage: null, controlImage: null,
controlImageDimensions: null,
processedControlImage: null, processedControlImage: null,
processedControlImageDimensions: null,
processorType: 'canny_image_processor', processorType: 'canny_image_processor',
processorNode: CONTROLNET_PROCESSORS.canny_image_processor.buildDefaults() as RequiredCannyImageProcessorInvocation, processorNode: CONTROLNET_PROCESSORS.canny_image_processor.buildDefaults() as RequiredCannyImageProcessorInvocation,
shouldAutoConfig: true, shouldAutoConfig: true,
@ -35,7 +37,9 @@ export const initialT2IAdapter: Omit<T2IAdapterConfig, 'id'> = {
endStepPct: 1, endStepPct: 1,
resizeMode: 'just_resize', resizeMode: 'just_resize',
controlImage: null, controlImage: null,
controlImageDimensions: null,
processedControlImage: null, processedControlImage: null,
processedControlImageDimensions: null,
processorType: 'canny_image_processor', processorType: 'canny_image_processor',
processorNode: CONTROLNET_PROCESSORS.canny_image_processor.buildDefaults() as RequiredCannyImageProcessorInvocation, processorNode: CONTROLNET_PROCESSORS.canny_image_processor.buildDefaults() as RequiredCannyImageProcessorInvocation,
shouldAutoConfig: true, shouldAutoConfig: true,

View File

@ -286,7 +286,9 @@ const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (meta
controlMode: control_mode ?? initialControlNet.controlMode, controlMode: control_mode ?? initialControlNet.controlMode,
resizeMode: resize_mode ?? initialControlNet.resizeMode, resizeMode: resize_mode ?? initialControlNet.resizeMode,
controlImage: image?.image_name ?? null, controlImage: image?.image_name ?? null,
controlImageDimensions: null,
processedControlImage: processedImage?.image_name ?? null, processedControlImage: processedImage?.image_name ?? null,
processedControlImageDimensions: null,
processorType, processorType,
processorNode, processorNode,
shouldAutoConfig: true, shouldAutoConfig: true,
@ -350,9 +352,11 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (meta
endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct, endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct,
resizeMode: resize_mode ?? initialT2IAdapter.resizeMode, resizeMode: resize_mode ?? initialT2IAdapter.resizeMode,
controlImage: image?.image_name ?? null, controlImage: image?.image_name ?? null,
controlImageDimensions: null,
processedControlImage: processedImage?.image_name ?? null, processedControlImage: processedImage?.image_name ?? null,
processorType, processedControlImageDimensions: null,
processorNode, processorNode,
processorType,
shouldAutoConfig: true, shouldAutoConfig: true,
id: uuidv4(), id: uuidv4(),
}; };

View File

@ -2,6 +2,9 @@ import type { RootState } from 'app/store/store';
import { selectValidControlNets } from 'features/controlAdapters/store/controlAdaptersSlice'; import { selectValidControlNets } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlAdapterProcessorType, ControlNetConfig } from 'features/controlAdapters/store/types'; import type { ControlAdapterProcessorType, ControlNetConfig } from 'features/controlAdapters/store/types';
import type { ImageField } from 'features/nodes/types/common'; import type { ImageField } from 'features/nodes/types/common';
import { isControlAdapterLayer } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { differenceWith, intersectionWith } from 'lodash-es';
import type { import type {
CollectInvocation, CollectInvocation,
ControlNetInvocation, ControlNetInvocation,
@ -14,11 +17,8 @@ import { assert } from 'tsafe';
import { CONTROL_NET_COLLECT } from './constants'; import { CONTROL_NET_COLLECT } from './constants';
import { upsertMetadata } from './metadata'; import { upsertMetadata } from './metadata';
export const addControlNetToLinearGraph = async ( const getControlNets = (state: RootState) => {
state: RootState, // Start with the valid controlnets
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
const validControlNets = selectValidControlNets(state.controlAdapters).filter( const validControlNets = selectValidControlNets(state.controlAdapters).filter(
({ model, processedControlImage, processorType, controlImage, isEnabled }) => { ({ model, processedControlImage, processorType, controlImage, isEnabled }) => {
const hasModel = Boolean(model); const hasModel = Boolean(model);
@ -29,9 +29,33 @@ export const addControlNetToLinearGraph = async (
} }
); );
// txt2img tab has special handling - it uses layers exclusively, while the other tabs use the older control adapters
// accordion. We need to filter the list of valid T2I adapters according to the tab.
const activeTabName = activeTabNameSelector(state);
// Collect all ControlNet ids for ControlNet layers
const layerControlNetIds = state.regionalPrompts.present.layers
.filter(isControlAdapterLayer)
.map((l) => l.controlNetId);
if (activeTabName === 'txt2img') {
// Add only the cnets that are used in control layers
return intersectionWith(validControlNets, layerControlNetIds, (a, b) => a.id === b);
} else {
// Else, we want to exclude the cnets that are used in control layers
return differenceWith(validControlNets, layerControlNetIds, (a, b) => a.id === b);
}
};
export const addControlNetToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
const controlNets = getControlNets(state);
const controlNetMetadata: CoreMetadataInvocation['controlnets'] = []; const controlNetMetadata: CoreMetadataInvocation['controlnets'] = [];
if (validControlNets.length) { if (controlNets.length) {
// Even though denoise_latents' control input is collection or scalar, keep it simple and always use a collect // Even though denoise_latents' control input is collection or scalar, keep it simple and always use a collect
const controlNetIterateNode: CollectInvocation = { const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT, id: CONTROL_NET_COLLECT,
@ -47,7 +71,7 @@ export const addControlNetToLinearGraph = async (
}, },
}); });
for (const controlNet of validControlNets) { for (const controlNet of controlNets) {
if (!controlNet.model) { if (!controlNet.model) {
return; return;
} }

View File

@ -2,8 +2,9 @@ import type { RootState } from 'app/store/store';
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice'; import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { IPAdapterConfig } from 'features/controlAdapters/store/types'; import type { IPAdapterConfig } from 'features/controlAdapters/store/types';
import type { ImageField } from 'features/nodes/types/common'; import type { ImageField } from 'features/nodes/types/common';
import { isMaskedGuidanceLayer } from 'features/regionalPrompts/store/regionalPromptsSlice'; import { isIPAdapterLayer, isMaskedGuidanceLayer } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { differenceBy } from 'lodash-es'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { differenceWith, intersectionWith } from 'lodash-es';
import type { import type {
CollectInvocation, CollectInvocation,
CoreMetadataInvocation, CoreMetadataInvocation,
@ -16,11 +17,8 @@ import { assert } from 'tsafe';
import { IP_ADAPTER_COLLECT } from './constants'; import { IP_ADAPTER_COLLECT } from './constants';
import { upsertMetadata } from './metadata'; import { upsertMetadata } from './metadata';
export const addIPAdapterToLinearGraph = async ( const getIPAdapters = (state: RootState) => {
state: RootState, // Start with the valid IP adapters
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(({ model, controlImage, isEnabled }) => { const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(({ model, controlImage, isEnabled }) => {
const hasModel = Boolean(model); const hasModel = Boolean(model);
const doesBaseMatch = model?.base === state.generation.model?.base; const doesBaseMatch = model?.base === state.generation.model?.base;
@ -28,14 +26,37 @@ export const addIPAdapterToLinearGraph = async (
return isEnabled && hasModel && doesBaseMatch && hasControlImage; return isEnabled && hasModel && doesBaseMatch && hasControlImage;
}); });
const regionalIPAdapterIds = state.regionalPrompts.present.layers // Masked IP adapters are handled in the graph helper for regional control - skip them here
const maskedIPAdapterIds = state.regionalPrompts.present.layers
.filter(isMaskedGuidanceLayer) .filter(isMaskedGuidanceLayer)
.map((l) => l.ipAdapterIds) .map((l) => l.ipAdapterIds)
.flat(); .flat();
const nonMaskedIPAdapters = differenceWith(validIPAdapters, maskedIPAdapterIds, (a, b) => a.id === b);
const nonRegionalIPAdapters = differenceBy(validIPAdapters, regionalIPAdapterIds, 'id'); // txt2img tab has special handling - it uses layers exclusively, while the other tabs use the older control adapters
// accordion. We need to filter the list of valid IP adapters according to the tab.
const activeTabName = activeTabNameSelector(state);
if (nonRegionalIPAdapters.length) { // Collect all IP Adapter ids for IP adapter layers
const layerIPAdapterIds = state.regionalPrompts.present.layers.filter(isIPAdapterLayer).map((l) => l.ipAdapterId);
if (activeTabName === 'txt2img') {
// If we are on the t2i tab, we only want to add the IP adapters that are used in unmasked IP Adapter layers
return intersectionWith(nonMaskedIPAdapters, layerIPAdapterIds, (a, b) => a.id === b);
} else {
// Else, we want to exclude the IP adapters that are used in IP Adapter layers
return differenceWith(nonMaskedIPAdapters, layerIPAdapterIds, (a, b) => a.id === b);
}
};
export const addIPAdapterToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
const ipAdapters = getIPAdapters(state);
if (ipAdapters.length) {
// Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect // Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect
const ipAdapterCollectNode: CollectInvocation = { const ipAdapterCollectNode: CollectInvocation = {
id: IP_ADAPTER_COLLECT, id: IP_ADAPTER_COLLECT,
@ -53,7 +74,7 @@ export const addIPAdapterToLinearGraph = async (
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = []; const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];
for (const ipAdapter of nonRegionalIPAdapters) { for (const ipAdapter of ipAdapters) {
if (!ipAdapter.model) { if (!ipAdapter.model) {
return; return;
} }

View File

@ -31,7 +31,7 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
// TODO: Image masks // TODO: Image masks
.filter(isMaskedGuidanceLayer) .filter(isMaskedGuidanceLayer)
// Only visible layers are rendered on the canvas // Only visible layers are rendered on the canvas
.filter((l) => l.isVisible) .filter((l) => l.isEnabled)
// Only layers with prompts get added to the graph // Only layers with prompts get added to the graph
.filter((l) => { .filter((l) => {
const hasTextPrompt = Boolean(l.positivePrompt || l.negativePrompt); const hasTextPrompt = Boolean(l.positivePrompt || l.negativePrompt);
@ -39,12 +39,15 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
return hasTextPrompt || hasIPAdapter; return hasTextPrompt || hasIPAdapter;
}); });
// Collect all IP Adapter ids for IP adapter layers
const layerIPAdapterIds = layers.flatMap((l) => l.ipAdapterIds);
const regionalIPAdapters = selectAllIPAdapters(state.controlAdapters).filter( const regionalIPAdapters = selectAllIPAdapters(state.controlAdapters).filter(
({ id, model, controlImage, isEnabled }) => { ({ id, model, controlImage, isEnabled }) => {
const hasModel = Boolean(model); const hasModel = Boolean(model);
const doesBaseMatch = model?.base === state.generation.model?.base; const doesBaseMatch = model?.base === state.generation.model?.base;
const hasControlImage = controlImage; const hasControlImage = controlImage;
const isRegional = layers.some((l) => l.ipAdapterIds.includes(id)); const isRegional = layerIPAdapterIds.includes(id);
return isEnabled && hasModel && doesBaseMatch && hasControlImage && isRegional; return isEnabled && hasModel && doesBaseMatch && hasControlImage && isRegional;
} }
); );

View File

@ -2,6 +2,9 @@ import type { RootState } from 'app/store/store';
import { selectValidT2IAdapters } from 'features/controlAdapters/store/controlAdaptersSlice'; import { selectValidT2IAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlAdapterProcessorType, T2IAdapterConfig } from 'features/controlAdapters/store/types'; import type { ControlAdapterProcessorType, T2IAdapterConfig } from 'features/controlAdapters/store/types';
import type { ImageField } from 'features/nodes/types/common'; import type { ImageField } from 'features/nodes/types/common';
import { isControlAdapterLayer } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { differenceWith, intersectionWith } from 'lodash-es';
import type { import type {
CollectInvocation, CollectInvocation,
CoreMetadataInvocation, CoreMetadataInvocation,
@ -14,11 +17,8 @@ import { assert } from 'tsafe';
import { T2I_ADAPTER_COLLECT } from './constants'; import { T2I_ADAPTER_COLLECT } from './constants';
import { upsertMetadata } from './metadata'; import { upsertMetadata } from './metadata';
export const addT2IAdaptersToLinearGraph = async ( const getT2IAdapters = (state: RootState) => {
state: RootState, // Start with the valid controlnets
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters).filter( const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters).filter(
({ model, processedControlImage, processorType, controlImage, isEnabled }) => { ({ model, processedControlImage, processorType, controlImage, isEnabled }) => {
const hasModel = Boolean(model); const hasModel = Boolean(model);
@ -29,7 +29,32 @@ export const addT2IAdaptersToLinearGraph = async (
} }
); );
if (validT2IAdapters.length) { // txt2img tab has special handling - it uses layers exclusively, while the other tabs use the older control adapters
// accordion. We need to filter the list of valid T2I adapters according to the tab.
const activeTabName = activeTabNameSelector(state);
// Collect all ids for control adapter layers
const layerControlAdapterIds = state.regionalPrompts.present.layers
.filter(isControlAdapterLayer)
.map((l) => l.controlNetId);
if (activeTabName === 'txt2img') {
// Add only the T2Is that are used in control layers
return intersectionWith(validT2IAdapters, layerControlAdapterIds, (a, b) => a.id === b);
} else {
// Else, we want to exclude the T2Is that are used in control layers
return differenceWith(validT2IAdapters, layerControlAdapterIds, (a, b) => a.id === b);
}
};
export const addT2IAdaptersToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
const t2iAdapters = getT2IAdapters(state);
if (t2iAdapters.length) {
// Even though denoise_latents' t2i adapter input is collection or scalar, keep it simple and always use a collect // Even though denoise_latents' t2i adapter input is collection or scalar, keep it simple and always use a collect
const t2iAdapterCollectNode: CollectInvocation = { const t2iAdapterCollectNode: CollectInvocation = {
id: T2I_ADAPTER_COLLECT, id: T2I_ADAPTER_COLLECT,
@ -47,7 +72,7 @@ export const addT2IAdaptersToLinearGraph = async (
const t2iAdapterMetadata: CoreMetadataInvocation['t2iAdapters'] = []; const t2iAdapterMetadata: CoreMetadataInvocation['t2iAdapters'] = [];
for (const t2iAdapter of validT2IAdapters) { for (const t2iAdapter of t2iAdapters) {
if (!t2iAdapter.model) { if (!t2iAdapter.model) {
return; return;
} }

View File

@ -1,6 +1,6 @@
import { Button } from '@invoke-ai/ui-library'; import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { layerAdded } from 'features/regionalPrompts/store/regionalPromptsSlice'; import { guidanceLayerAdded } from 'app/store/middleware/listenerMiddleware/listeners/regionalControlToControlAdapterBridge';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi'; import { PiPlusBold } from 'react-icons/pi';
@ -8,14 +8,27 @@ import { PiPlusBold } from 'react-icons/pi';
export const AddLayerButton = memo(() => { export const AddLayerButton = memo(() => {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const onClick = useCallback(() => { const addMaskedGuidanceLayer = useCallback(() => {
dispatch(layerAdded('masked_guidance_layer')); dispatch(guidanceLayerAdded('masked_guidance_layer'));
}, [dispatch]);
const addControlNetLayer = useCallback(() => {
dispatch(guidanceLayerAdded('control_adapter_layer'));
}, [dispatch]);
const addIPAdapterLayer = useCallback(() => {
dispatch(guidanceLayerAdded('ip_adapter_layer'));
}, [dispatch]); }, [dispatch]);
return ( return (
<Button onClick={onClick} leftIcon={<PiPlusBold />} variant="ghost"> <Menu>
{t('regionalPrompts.addLayer')} <MenuButton as={Button} leftIcon={<PiPlusBold />} variant="ghost">
</Button> {t('regionalPrompts.addLayer')}
</MenuButton>
<MenuList>
<MenuItem onClick={addMaskedGuidanceLayer}> {t('regionalPrompts.maskedGuidanceLayer')}</MenuItem>
<MenuItem onClick={addControlNetLayer}> {t('regionalPrompts.controlNetLayer')}</MenuItem>
<MenuItem onClick={addIPAdapterLayer}> {t('regionalPrompts.ipAdapterLayer')}</MenuItem>
</MenuList>
</Menu>
); );
}); });

View File

@ -1,9 +1,9 @@
import { Button, Flex } from '@invoke-ai/ui-library'; import { Button, Flex } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { guidanceLayerIPAdapterAdded } from 'app/store/middleware/listenerMiddleware/listeners/regionalControlToControlAdapterBridge';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { import {
isMaskedGuidanceLayer, isMaskedGuidanceLayer,
maskLayerIPAdapterAdded,
maskLayerNegativePromptChanged, maskLayerNegativePromptChanged,
maskLayerPositivePromptChanged, maskLayerPositivePromptChanged,
selectRegionalPromptsSlice, selectRegionalPromptsSlice,
@ -39,7 +39,7 @@ export const AddPromptButtons = ({ layerId }: AddPromptButtonProps) => {
dispatch(maskLayerNegativePromptChanged({ layerId, prompt: '' })); dispatch(maskLayerNegativePromptChanged({ layerId, prompt: '' }));
}, [dispatch, layerId]); }, [dispatch, layerId]);
const addIPAdapter = useCallback(() => { const addIPAdapter = useCallback(() => {
dispatch(maskLayerIPAdapterAdded(layerId)); dispatch(guidanceLayerIPAdapterAdded(layerId));
}, [dispatch, layerId]); }, [dispatch, layerId]);
return ( return (

View File

@ -23,7 +23,7 @@ export const BrushSize = memo(() => {
const brushSize = useAppSelector((s) => s.regionalPrompts.present.brushSize); const brushSize = useAppSelector((s) => s.regionalPrompts.present.brushSize);
const onChange = useCallback( const onChange = useCallback(
(v: number) => { (v: number) => {
dispatch(brushSizeChanged(v)); dispatch(brushSizeChanged(Math.round(v)));
}, },
[dispatch] [dispatch]
); );

View File

@ -0,0 +1,62 @@
import { Flex, Spacer } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import ControlAdapterLayerConfig from 'features/regionalPrompts/components/controlAdapterOverrides/ControlAdapterLayerConfig';
import { LayerTitle } from 'features/regionalPrompts/components/LayerTitle';
import { RPLayerDeleteButton } from 'features/regionalPrompts/components/RPLayerDeleteButton';
import { RPLayerVisibilityToggle } from 'features/regionalPrompts/components/RPLayerVisibilityToggle';
import {
isControlAdapterLayer,
layerSelected,
selectRegionalPromptsSlice,
} from 'features/regionalPrompts/store/regionalPromptsSlice';
import { memo, useCallback, useMemo } from 'react';
import { assert } from 'tsafe';
type Props = {
layerId: string;
};
export const ControlAdapterLayerListItem = memo(({ layerId }: Props) => {
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
assert(isControlAdapterLayer(layer), `Layer ${layerId} not found or not a ControlNet layer`);
return {
controlNetId: layer.controlNetId,
isSelected: layerId === regionalPrompts.present.selectedLayerId,
};
}),
[layerId]
);
const { controlNetId, isSelected } = useAppSelector(selector);
const onClickCapture = useCallback(() => {
// Must be capture so that the layer is selected before deleting/resetting/etc
dispatch(layerSelected(layerId));
}, [dispatch, layerId]);
return (
<Flex
gap={2}
onClickCapture={onClickCapture}
bg={isSelected ? 'base.400' : 'base.800'}
ps={2}
borderRadius="base"
pe="1px"
py="1px"
>
<Flex flexDir="column" gap={4} w="full" bg="base.850" p={3} borderRadius="base">
<Flex gap={3} alignItems="center">
<RPLayerVisibilityToggle layerId={layerId} />
<LayerTitle type="control_adapter_layer" />
<Spacer />
<RPLayerDeleteButton layerId={layerId} />
</Flex>
<ControlAdapterLayerConfig id={controlNetId} />
</Flex>
</Flex>
);
});
ControlAdapterLayerListItem.displayName = 'ControlAdapterLayerListItem';

View File

@ -1,6 +1,6 @@
import { Button } from '@invoke-ai/ui-library'; import { Button } from '@invoke-ai/ui-library';
import { allLayersDeleted } from 'app/store/middleware/listenerMiddleware/listeners/regionalControlToControlAdapterBridge';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { allLayersDeleted } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi'; import { PiTrashSimpleBold } from 'react-icons/pi';

View File

@ -0,0 +1,62 @@
import { Flex, Spacer } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import ControlAdapterLayerConfig from 'features/regionalPrompts/components/controlAdapterOverrides/ControlAdapterLayerConfig';
import { LayerTitle } from 'features/regionalPrompts/components/LayerTitle';
import { RPLayerDeleteButton } from 'features/regionalPrompts/components/RPLayerDeleteButton';
import { RPLayerVisibilityToggle } from 'features/regionalPrompts/components/RPLayerVisibilityToggle';
import {
isIPAdapterLayer,
layerSelected,
selectRegionalPromptsSlice,
} from 'features/regionalPrompts/store/regionalPromptsSlice';
import { memo, useCallback, useMemo } from 'react';
import { assert } from 'tsafe';
type Props = {
layerId: string;
};
export const IPAdapterLayerListItem = memo(({ layerId }: Props) => {
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
assert(isIPAdapterLayer(layer), `Layer ${layerId} not found or not an IP Adapter layer`);
return {
ipAdapterId: layer.ipAdapterId,
isSelected: layerId === regionalPrompts.present.selectedLayerId,
};
}),
[layerId]
);
const { ipAdapterId, isSelected } = useAppSelector(selector);
const onClickCapture = useCallback(() => {
// Must be capture so that the layer is selected before deleting/resetting/etc
dispatch(layerSelected(layerId));
}, [dispatch, layerId]);
return (
<Flex
gap={2}
onClickCapture={onClickCapture}
bg={isSelected ? 'base.400' : 'base.800'}
ps={2}
borderRadius="base"
pe="1px"
py="1px"
>
<Flex flexDir="column" gap={4} w="full" bg="base.850" p={3} borderRadius="base">
<Flex gap={3} alignItems="center">
<RPLayerVisibilityToggle layerId={layerId} />
<LayerTitle type="ip_adapter_layer" />
<Spacer />
<RPLayerDeleteButton layerId={layerId} />
</Flex>
<ControlAdapterLayerConfig id={ipAdapterId} />
</Flex>
</Flex>
);
});
IPAdapterLayerListItem.displayName = 'IPAdapterLayerListItem';

View File

@ -0,0 +1,29 @@
import { Text } from '@invoke-ai/ui-library';
import type { Layer } from 'features/regionalPrompts/store/types';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
type Props = {
type: Layer['type'];
};
export const LayerTitle = memo(({ type }: Props) => {
const { t } = useTranslation();
const title = useMemo(() => {
if (type === 'masked_guidance_layer') {
return t('regionalPrompts.maskedGuidance');
} else if (type === 'control_adapter_layer') {
return t('common.controlNet');
} else if (type === 'ip_adapter_layer') {
return t('common.ipAdapter');
}
}, [t, type]);
return (
<Text size="sm" fontWeight="semibold" pointerEvents="none" color="base.300">
{title}
</Text>
);
});
LayerTitle.displayName = 'LayerTitle';

View File

@ -2,6 +2,7 @@ import { Badge, Flex, Spacer } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { rgbColorToString } from 'features/canvas/util/colorToString'; import { rgbColorToString } from 'features/canvas/util/colorToString';
import { LayerTitle } from 'features/regionalPrompts/components/LayerTitle';
import { RPLayerColorPicker } from 'features/regionalPrompts/components/RPLayerColorPicker'; import { RPLayerColorPicker } from 'features/regionalPrompts/components/RPLayerColorPicker';
import { RPLayerDeleteButton } from 'features/regionalPrompts/components/RPLayerDeleteButton'; import { RPLayerDeleteButton } from 'features/regionalPrompts/components/RPLayerDeleteButton';
import { RPLayerIPAdapterList } from 'features/regionalPrompts/components/RPLayerIPAdapterList'; import { RPLayerIPAdapterList } from 'features/regionalPrompts/components/RPLayerIPAdapterList';
@ -25,7 +26,7 @@ type Props = {
layerId: string; layerId: string;
}; };
export const RPLayerListItem = memo(({ layerId }: Props) => { export const MaskedGuidanceLayerListItem = memo(({ layerId }: Props) => {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const selector = useMemo( const selector = useMemo(
@ -59,21 +60,21 @@ export const RPLayerListItem = memo(({ layerId }: Props) => {
borderRadius="base" borderRadius="base"
pe="1px" pe="1px"
py="1px" py="1px"
cursor="pointer"
> >
<Flex flexDir="column" gap={2} w="full" bg="base.850" p={2} borderRadius="base"> <Flex flexDir="column" w="full" bg="base.850" p={3} gap={3} borderRadius="base">
<Flex gap={3} alignItems="center"> <Flex gap={3} alignItems="center">
<RPLayerVisibilityToggle layerId={layerId} /> <RPLayerVisibilityToggle layerId={layerId} />
<RPLayerColorPicker layerId={layerId} /> <LayerTitle type="masked_guidance_layer" />
<Spacer /> <Spacer />
{autoNegative === 'invert' && ( {autoNegative === 'invert' && (
<Badge color="base.300" bg="transparent" borderWidth={1}> <Badge color="base.300" bg="transparent" borderWidth={1}>
{t('regionalPrompts.autoNegative')} {t('regionalPrompts.autoNegative')}
</Badge> </Badge>
)} )}
<RPLayerDeleteButton layerId={layerId} /> <RPLayerColorPicker layerId={layerId} />
<RPLayerSettingsPopover layerId={layerId} /> <RPLayerSettingsPopover layerId={layerId} />
<RPLayerMenu layerId={layerId} /> <RPLayerMenu layerId={layerId} />
<RPLayerDeleteButton layerId={layerId} />
</Flex> </Flex>
{!hasPositivePrompt && !hasNegativePrompt && !hasIPAdapters && <AddPromptButtons layerId={layerId} />} {!hasPositivePrompt && !hasNegativePrompt && !hasIPAdapters && <AddPromptButtons layerId={layerId} />}
{hasPositivePrompt && <RPLayerPositivePrompt layerId={layerId} />} {hasPositivePrompt && <RPLayerPositivePrompt layerId={layerId} />}
@ -84,4 +85,4 @@ export const RPLayerListItem = memo(({ layerId }: Props) => {
); );
}); });
RPLayerListItem.displayName = 'RPLayerListItem'; MaskedGuidanceLayerListItem.displayName = 'MaskedGuidanceLayerListItem';

View File

@ -29,7 +29,7 @@ const useAutoNegative = (layerId: string) => {
return autoNegative; return autoNegative;
}; };
export const RPLayerAutoNegativeCheckbox = memo(({ layerId }: Props) => { export const MaskedGuidanceLayerAutoNegativeCheckbox = memo(({ layerId }: Props) => {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const autoNegative = useAutoNegative(layerId); const autoNegative = useAutoNegative(layerId);
@ -48,4 +48,4 @@ export const RPLayerAutoNegativeCheckbox = memo(({ layerId }: Props) => {
); );
}); });
RPLayerAutoNegativeCheckbox.displayName = 'RPLayerAutoNegativeCheckbox'; MaskedGuidanceLayerAutoNegativeCheckbox.displayName = 'MaskedGuidanceLayerAutoNegativeCheckbox';

View File

@ -1,9 +1,11 @@
import { Flex } from '@invoke-ai/ui-library'; import { Divider, Flex, IconButton, Spacer, Text } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks'; import { guidanceLayerIPAdapterDeleted } from 'app/store/middleware/listenerMiddleware/listeners/regionalControlToControlAdapterBridge';
import ControlAdapterConfig from 'features/controlAdapters/components/ControlAdapterConfig'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import ControlAdapterLayerConfig from 'features/regionalPrompts/components/controlAdapterOverrides/ControlAdapterLayerConfig';
import { isMaskedGuidanceLayer, selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice'; import { isMaskedGuidanceLayer, selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { memo, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { PiTrashSimpleBold } from 'react-icons/pi';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
type Props = { type Props = {
@ -22,13 +24,55 @@ export const RPLayerIPAdapterList = memo(({ layerId }: Props) => {
); );
const ipAdapterIds = useAppSelector(selectIPAdapterIds); const ipAdapterIds = useAppSelector(selectIPAdapterIds);
if (ipAdapterIds.length === 0) {
return null;
}
return ( return (
<Flex w="full" flexDir="column" gap={2}> <>
{ipAdapterIds.map((id, index) => ( {ipAdapterIds.map((id, index) => (
<ControlAdapterConfig key={id} id={id} number={index + 1} /> <Flex flexDir="column" key={id}>
<Flex pb={3}>
<Divider />
</Flex>
<IPAdapterListItem layerId={layerId} ipAdapterId={id} ipAdapterNumber={index + 1} />
</Flex>
))} ))}
</Flex> </>
); );
}); });
RPLayerIPAdapterList.displayName = 'RPLayerIPAdapterList'; RPLayerIPAdapterList.displayName = 'RPLayerIPAdapterList';
type IPAdapterListItemProps = {
layerId: string;
ipAdapterId: string;
ipAdapterNumber: number;
};
const IPAdapterListItem = memo(({ layerId, ipAdapterId, ipAdapterNumber }: IPAdapterListItemProps) => {
const dispatch = useAppDispatch();
const onDeleteIPAdapter = useCallback(() => {
dispatch(guidanceLayerIPAdapterDeleted({ layerId, ipAdapterId }));
}, [dispatch, ipAdapterId, layerId]);
return (
<Flex flexDir="column" gap={3}>
<Flex alignItems="center" gap={3}>
<Text fontWeight="semibold" color="base.400">{`IP Adapter ${ipAdapterNumber}`}</Text>
<Spacer />
<IconButton
size="sm"
icon={<PiTrashSimpleBold />}
aria-label="Delete IP Adapter"
onClick={onDeleteIPAdapter}
variant="ghost"
colorScheme="error"
/>
</Flex>
<ControlAdapterLayerConfig id={ipAdapterId} />
</Flex>
);
});
IPAdapterListItem.displayName = 'IPAdapterListItem';

View File

@ -1,5 +1,6 @@
import { IconButton, Menu, MenuButton, MenuDivider, MenuItem, MenuList } from '@invoke-ai/ui-library'; import { IconButton, Menu, MenuButton, MenuDivider, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { guidanceLayerIPAdapterAdded } from 'app/store/middleware/listenerMiddleware/listeners/regionalControlToControlAdapterBridge';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { import {
isMaskedGuidanceLayer, isMaskedGuidanceLayer,
@ -9,7 +10,6 @@ import {
layerMovedToBack, layerMovedToBack,
layerMovedToFront, layerMovedToFront,
layerReset, layerReset,
maskLayerIPAdapterAdded,
maskLayerNegativePromptChanged, maskLayerNegativePromptChanged,
maskLayerPositivePromptChanged, maskLayerPositivePromptChanged,
selectRegionalPromptsSlice, selectRegionalPromptsSlice,
@ -59,7 +59,7 @@ export const RPLayerMenu = memo(({ layerId }: Props) => {
dispatch(maskLayerNegativePromptChanged({ layerId, prompt: '' })); dispatch(maskLayerNegativePromptChanged({ layerId, prompt: '' }));
}, [dispatch, layerId]); }, [dispatch, layerId]);
const addIPAdapter = useCallback(() => { const addIPAdapter = useCallback(() => {
dispatch(maskLayerIPAdapterAdded(layerId)); dispatch(guidanceLayerIPAdapterAdded(layerId));
}, [dispatch, layerId]); }, [dispatch, layerId]);
const moveForward = useCallback(() => { const moveForward = useCallback(() => {
dispatch(layerMovedForward(layerId)); dispatch(layerMovedForward(layerId));

View File

@ -9,7 +9,7 @@ import {
PopoverContent, PopoverContent,
PopoverTrigger, PopoverTrigger,
} from '@invoke-ai/ui-library'; } from '@invoke-ai/ui-library';
import { RPLayerAutoNegativeCheckbox } from 'features/regionalPrompts/components/RPLayerAutoNegativeCheckbox'; import { MaskedGuidanceLayerAutoNegativeCheckbox } from 'features/regionalPrompts/components/RPLayerAutoNegativeCheckbox';
import { memo } from 'react'; import { memo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { PiGearSixBold } from 'react-icons/pi'; import { PiGearSixBold } from 'react-icons/pi';
@ -41,7 +41,7 @@ const RPLayerSettingsPopover = ({ layerId }: Props) => {
<PopoverBody> <PopoverBody>
<Flex direction="column" gap={2}> <Flex direction="column" gap={2}>
<FormControlGroup formLabelProps={formLabelProps}> <FormControlGroup formLabelProps={formLabelProps}>
<RPLayerAutoNegativeCheckbox layerId={layerId} /> <MaskedGuidanceLayerAutoNegativeCheckbox layerId={layerId} />
</FormControlGroup> </FormControlGroup>
</Flex> </Flex>
</PopoverBody> </PopoverBody>

View File

@ -4,20 +4,43 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { AddLayerButton } from 'features/regionalPrompts/components/AddLayerButton'; import { AddLayerButton } from 'features/regionalPrompts/components/AddLayerButton';
import { ControlAdapterLayerListItem } from 'features/regionalPrompts/components/ControlAdapterLayerListItem';
import { DeleteAllLayersButton } from 'features/regionalPrompts/components/DeleteAllLayersButton'; import { DeleteAllLayersButton } from 'features/regionalPrompts/components/DeleteAllLayersButton';
import { RPLayerListItem } from 'features/regionalPrompts/components/RPLayerListItem'; import { IPAdapterLayerListItem } from 'features/regionalPrompts/components/IPAdapterLayerListItem';
import { isMaskedGuidanceLayer, selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice'; import { MaskedGuidanceLayerListItem } from 'features/regionalPrompts/components/MaskedGuidanceLayerListItem';
import {
isControlAdapterLayer,
isIPAdapterLayer,
isMaskedGuidanceLayer,
selectRegionalPromptsSlice,
} from 'features/regionalPrompts/store/regionalPromptsSlice';
import { memo } from 'react'; import { memo } from 'react';
const selectRPLayerIdsReversed = createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => const selectMaskedGuidanceLayerIds = createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) =>
regionalPrompts.present.layers regionalPrompts.present.layers
.filter(isMaskedGuidanceLayer) .filter(isMaskedGuidanceLayer)
.map((l) => l.id) .map((l) => l.id)
.reverse() .reverse()
); );
const selectControlNetLayerIds = createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) =>
regionalPrompts.present.layers
.filter(isControlAdapterLayer)
.map((l) => l.id)
.reverse()
);
const selectIPAdapterLayerIds = createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) =>
regionalPrompts.present.layers
.filter(isIPAdapterLayer)
.map((l) => l.id)
.reverse()
);
export const RegionalPromptsPanelContent = memo(() => { export const RegionalPromptsPanelContent = memo(() => {
const rpLayerIdsReversed = useAppSelector(selectRPLayerIdsReversed); const maskedGuidanceLayerIds = useAppSelector(selectMaskedGuidanceLayerIds);
const controlNetLayerIds = useAppSelector(selectControlNetLayerIds);
const ipAdapterLayerIds = useAppSelector(selectIPAdapterLayerIds);
return ( return (
<Flex flexDir="column" gap={4} w="full" h="full"> <Flex flexDir="column" gap={4} w="full" h="full">
<Flex justifyContent="space-around"> <Flex justifyContent="space-around">
@ -26,8 +49,14 @@ export const RegionalPromptsPanelContent = memo(() => {
</Flex> </Flex>
<ScrollableContent> <ScrollableContent>
<Flex flexDir="column" gap={4}> <Flex flexDir="column" gap={4}>
{rpLayerIdsReversed.map((id) => ( {maskedGuidanceLayerIds.map((id) => (
<RPLayerListItem key={id} layerId={id} /> <MaskedGuidanceLayerListItem key={id} layerId={id} />
))}
{controlNetLayerIds.map((id) => (
<ControlAdapterLayerListItem key={id} layerId={id} />
))}
{ipAdapterLayerIds.map((id) => (
<IPAdapterLayerListItem key={id} layerId={id} />
))} ))}
</Flex> </Flex>
</ScrollableContent> </ScrollableContent>

View File

@ -18,9 +18,8 @@ import {
import { debouncedRenderers, renderers as normalRenderers } from 'features/regionalPrompts/util/renderers'; import { debouncedRenderers, renderers as normalRenderers } from 'features/regionalPrompts/util/renderers';
import Konva from 'konva'; import Konva from 'konva';
import type { IRect } from 'konva/lib/types'; import type { IRect } from 'konva/lib/types';
import type { MutableRefObject } from 'react'; import { memo, useCallback, useLayoutEffect, useMemo, useState } from 'react';
import { memo, useCallback, useLayoutEffect, useMemo, useRef, useState } from 'react'; import { v4 as uuidv4 } from 'uuid';
import { assert } from 'tsafe';
// This will log warnings when layers > 5 - maybe use `import.meta.env.MODE === 'development'` instead? // This will log warnings when layers > 5 - maybe use `import.meta.env.MODE === 'development'` instead?
Konva.showWarnings = false; Konva.showWarnings = false;
@ -28,16 +27,14 @@ Konva.showWarnings = false;
const log = logger('regionalPrompts'); const log = logger('regionalPrompts');
const selectSelectedLayerColor = createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => { const selectSelectedLayerColor = createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
const layer = regionalPrompts.present.layers.find((l) => l.id === regionalPrompts.present.selectedLayerId); const layer = regionalPrompts.present.layers
if (!layer) { .filter(isMaskedGuidanceLayer)
return null; .find((l) => l.id === regionalPrompts.present.selectedLayerId);
} return layer?.previewColor ?? null;
assert(isMaskedGuidanceLayer(layer), `Layer ${regionalPrompts.present.selectedLayerId} is not an RP layer`);
return layer.previewColor;
}); });
const useStageRenderer = ( const useStageRenderer = (
stageRef: MutableRefObject<Konva.Stage>, stage: Konva.Stage,
container: HTMLDivElement | null, container: HTMLDivElement | null,
wrapper: HTMLDivElement | null, wrapper: HTMLDivElement | null,
asPreview: boolean asPreview: boolean
@ -79,25 +76,24 @@ const useStageRenderer = (
if (!container) { if (!container) {
return; return;
} }
const stage = stageRef.current.container(container); stage.container(container);
return () => { return () => {
log.trace('Cleaning up stage'); log.trace('Cleaning up stage');
stage.destroy(); stage.destroy();
}; };
}, [container, stageRef]); }, [container, stage]);
useLayoutEffect(() => { useLayoutEffect(() => {
log.trace('Adding stage listeners'); log.trace('Adding stage listeners');
if (asPreview) { if (asPreview) {
return; return;
} }
stageRef.current.on('mousedown', onMouseDown); stage.on('mousedown', onMouseDown);
stageRef.current.on('mouseup', onMouseUp); stage.on('mouseup', onMouseUp);
stageRef.current.on('mousemove', onMouseMove); stage.on('mousemove', onMouseMove);
stageRef.current.on('mouseenter', onMouseEnter); stage.on('mouseenter', onMouseEnter);
stageRef.current.on('mouseleave', onMouseLeave); stage.on('mouseleave', onMouseLeave);
stageRef.current.on('wheel', onMouseWheel); stage.on('wheel', onMouseWheel);
const stage = stageRef.current;
return () => { return () => {
log.trace('Cleaning up stage listeners'); log.trace('Cleaning up stage listeners');
@ -108,7 +104,7 @@ const useStageRenderer = (
stage.off('mouseleave', onMouseLeave); stage.off('mouseleave', onMouseLeave);
stage.off('wheel', onMouseWheel); stage.off('wheel', onMouseWheel);
}; };
}, [stageRef, asPreview, onMouseDown, onMouseUp, onMouseMove, onMouseEnter, onMouseLeave, onMouseWheel]); }, [stage, asPreview, onMouseDown, onMouseUp, onMouseMove, onMouseEnter, onMouseLeave, onMouseWheel]);
useLayoutEffect(() => { useLayoutEffect(() => {
log.trace('Updating stage dimensions'); log.trace('Updating stage dimensions');
@ -116,8 +112,6 @@ const useStageRenderer = (
return; return;
} }
const stage = stageRef.current;
const fitStageToContainer = () => { const fitStageToContainer = () => {
const newXScale = wrapper.offsetWidth / state.size.width; const newXScale = wrapper.offsetWidth / state.size.width;
const newYScale = wrapper.offsetHeight / state.size.height; const newYScale = wrapper.offsetHeight / state.size.height;
@ -135,7 +129,7 @@ const useStageRenderer = (
return () => { return () => {
resizeObserver.disconnect(); resizeObserver.disconnect();
}; };
}, [stageRef, state.size.width, state.size.height, wrapper]); }, [stage, state.size.width, state.size.height, wrapper]);
useLayoutEffect(() => { useLayoutEffect(() => {
log.trace('Rendering tool preview'); log.trace('Rendering tool preview');
@ -144,7 +138,7 @@ const useStageRenderer = (
return; return;
} }
renderers.renderToolPreview( renderers.renderToolPreview(
stageRef.current, stage,
tool, tool,
selectedLayerIdColor, selectedLayerIdColor,
state.globalMaskLayerOpacity, state.globalMaskLayerOpacity,
@ -155,7 +149,7 @@ const useStageRenderer = (
); );
}, [ }, [
asPreview, asPreview,
stageRef, stage,
tool, tool,
selectedLayerIdColor, selectedLayerIdColor,
state.globalMaskLayerOpacity, state.globalMaskLayerOpacity,
@ -168,8 +162,17 @@ const useStageRenderer = (
useLayoutEffect(() => { useLayoutEffect(() => {
log.trace('Rendering layers'); log.trace('Rendering layers');
renderers.renderLayers(stageRef.current, state.layers, state.globalMaskLayerOpacity, tool, onLayerPosChanged); renderers.renderLayers(stage, state.layers, state.globalMaskLayerOpacity, tool, onLayerPosChanged);
}, [stageRef, state.layers, state.globalMaskLayerOpacity, tool, onLayerPosChanged, renderers]); }, [
stage,
state.layers,
state.globalMaskLayerOpacity,
tool,
onLayerPosChanged,
renderers,
state.size.width,
state.size.height,
]);
useLayoutEffect(() => { useLayoutEffect(() => {
log.trace('Rendering bbox'); log.trace('Rendering bbox');
@ -177,8 +180,8 @@ const useStageRenderer = (
// Preview should not display bboxes // Preview should not display bboxes
return; return;
} }
renderers.renderBbox(stageRef.current, state.layers, state.selectedLayerId, tool, onBboxChanged, onBboxMouseDown); renderers.renderBbox(stage, state.layers, state.selectedLayerId, tool, onBboxChanged, onBboxMouseDown);
}, [stageRef, asPreview, state.layers, state.selectedLayerId, tool, onBboxChanged, onBboxMouseDown, renderers]); }, [stage, asPreview, state.layers, state.selectedLayerId, tool, onBboxChanged, onBboxMouseDown, renderers]);
useLayoutEffect(() => { useLayoutEffect(() => {
log.trace('Rendering background'); log.trace('Rendering background');
@ -186,13 +189,13 @@ const useStageRenderer = (
// The preview should not have a background // The preview should not have a background
return; return;
} }
renderers.renderBackground(stageRef.current, state.size.width, state.size.height); renderers.renderBackground(stage, state.size.width, state.size.height);
}, [stageRef, asPreview, state.size.width, state.size.height, renderers]); }, [stage, asPreview, state.size.width, state.size.height, renderers]);
useLayoutEffect(() => { useLayoutEffect(() => {
log.trace('Arranging layers'); log.trace('Arranging layers');
renderers.arrangeLayers(stageRef.current, layerIds); renderers.arrangeLayers(stage, layerIds);
}, [stageRef, layerIds, renderers]); }, [stage, layerIds, renderers]);
}; };
type Props = { type Props = {
@ -200,10 +203,8 @@ type Props = {
}; };
export const StageComponent = memo(({ asPreview = false }: Props) => { export const StageComponent = memo(({ asPreview = false }: Props) => {
const stageRef = useRef<Konva.Stage>( const [stage] = useState(
new Konva.Stage({ () => new Konva.Stage({ id: uuidv4(), container: document.createElement('div'), listening: !asPreview })
container: document.createElement('div'), // We will overwrite this shortly...
})
); );
const [container, setContainer] = useState<HTMLDivElement | null>(null); const [container, setContainer] = useState<HTMLDivElement | null>(null);
const [wrapper, setWrapper] = useState<HTMLDivElement | null>(null); const [wrapper, setWrapper] = useState<HTMLDivElement | null>(null);
@ -216,7 +217,7 @@ export const StageComponent = memo(({ asPreview = false }: Props) => {
setWrapper(el); setWrapper(el);
}, []); }, []);
useStageRenderer(stageRef, container, wrapper, asPreview); useStageRenderer(stage, container, wrapper, asPreview);
return ( return (
<Flex overflow="hidden" w="full" h="full"> <Flex overflow="hidden" w="full" h="full">

View File

@ -1,12 +1,7 @@
import { ButtonGroup, IconButton } from '@invoke-ai/ui-library'; import { ButtonGroup, IconButton } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react'; import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { import { $tool, selectedLayerDeleted, selectedLayerReset } from 'features/regionalPrompts/store/regionalPromptsSlice';
$tool,
layerAdded,
selectedLayerDeleted,
selectedLayerReset,
} from 'features/regionalPrompts/store/regionalPromptsSlice';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -40,11 +35,6 @@ export const ToolChooser: React.FC = () => {
}, [dispatch]); }, [dispatch]);
useHotkeys('shift+c', resetSelectedLayer); useHotkeys('shift+c', resetSelectedLayer);
const addLayer = useCallback(() => {
dispatch(layerAdded('masked_guidance_layer'));
}, [dispatch]);
useHotkeys('shift+a', addLayer);
const deleteSelectedLayer = useCallback(() => { const deleteSelectedLayer = useCallback(() => {
dispatch(selectedLayerDeleted()); dispatch(selectedLayerDeleted());
}, [dispatch]); }, [dispatch]);

View File

@ -0,0 +1,230 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Spinner } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
import { useControlAdapterControlImage } from 'features/controlAdapters/hooks/useControlAdapterControlImage';
import { useControlAdapterProcessedControlImage } from 'features/controlAdapters/hooks/useControlAdapterProcessedControlImage';
import { useControlAdapterProcessorType } from 'features/controlAdapters/hooks/useControlAdapterProcessorType';
import {
controlAdapterImageChanged,
selectControlAdaptersSlice,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { heightChanged, widthChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold, PiFloppyDiskBold, PiRulerBold } from 'react-icons/pi';
import {
useAddImageToBoardMutation,
useChangeImageIsIntermediateMutation,
useGetImageDTOQuery,
useRemoveImageFromBoardMutation,
} from 'services/api/endpoints/images';
import type { PostUploadAction } from 'services/api/types';
type Props = {
id: string;
isSmall?: boolean;
};
const selectPendingControlImages = createMemoizedSelector(
selectControlAdaptersSlice,
(controlAdapters) => controlAdapters.pendingControlImages
);
const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const controlImageName = useControlAdapterControlImage(id);
const processedControlImageName = useControlAdapterProcessedControlImage(id);
const processorType = useControlAdapterProcessorType(id);
const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId);
const isConnected = useAppSelector((s) => s.system.isConnected);
const activeTabName = useAppSelector(activeTabNameSelector);
const optimalDimension = useAppSelector(selectOptimalDimension);
const pendingControlImages = useAppSelector(selectPendingControlImages);
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(
controlImageName ?? skipToken
);
const { currentData: processedControlImage, isError: isErrorProcessedControlImage } = useGetImageDTOQuery(
processedControlImageName ?? skipToken
);
const [changeIsIntermediate] = useChangeImageIsIntermediateMutation();
const [addToBoard] = useAddImageToBoardMutation();
const [removeFromBoard] = useRemoveImageFromBoardMutation();
const handleResetControlImage = useCallback(() => {
dispatch(controlAdapterImageChanged({ id, controlImage: null }));
}, [id, dispatch]);
const handleSaveControlImage = useCallback(async () => {
if (!processedControlImage) {
return;
}
await changeIsIntermediate({
imageDTO: processedControlImage,
is_intermediate: false,
}).unwrap();
if (autoAddBoardId !== 'none') {
addToBoard({
imageDTO: processedControlImage,
board_id: autoAddBoardId,
});
} else {
removeFromBoard({ imageDTO: processedControlImage });
}
}, [processedControlImage, changeIsIntermediate, autoAddBoardId, addToBoard, removeFromBoard]);
const handleSetControlImageToDimensions = useCallback(() => {
if (!controlImage) {
return;
}
if (activeTabName === 'unifiedCanvas') {
dispatch(setBoundingBoxDimensions({ width: controlImage.width, height: controlImage.height }, optimalDimension));
} else {
const { width, height } = calculateNewSize(
controlImage.width / controlImage.height,
optimalDimension * optimalDimension
);
dispatch(widthChanged({ width: controlImage.width, updateAspectRatio: true }));
dispatch(heightChanged({ height: controlImage.height, updateAspectRatio: true }));
}
}, [controlImage, activeTabName, dispatch, optimalDimension]);
const handleMouseEnter = useCallback(() => {
setIsMouseOverImage(true);
}, []);
const handleMouseLeave = useCallback(() => {
setIsMouseOverImage(false);
}, []);
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
if (controlImage) {
return {
id,
payloadType: 'IMAGE_DTO',
payload: { imageDTO: controlImage },
};
}
}, [controlImage, id]);
const droppableData = useMemo<TypesafeDroppableData | undefined>(
() => ({
id,
actionType: 'SET_CONTROL_ADAPTER_IMAGE',
context: { id },
}),
[id]
);
const postUploadAction = useMemo<PostUploadAction>(() => ({ type: 'SET_CONTROL_ADAPTER_IMAGE', id }), [id]);
const shouldShowProcessedImage =
controlImage &&
processedControlImage &&
!isMouseOverImage &&
!pendingControlImages.includes(id) &&
processorType !== 'none';
useEffect(() => {
if (isConnected && (isErrorControlImage || isErrorProcessedControlImage)) {
handleResetControlImage();
}
}, [handleResetControlImage, isConnected, isErrorControlImage, isErrorProcessedControlImage]);
return (
<Flex
onMouseEnter={handleMouseEnter}
onMouseLeave={handleMouseLeave}
position="relative"
w="full"
h={isSmall ? 36 : 366} // magic no touch
alignItems="center"
justifyContent="center"
>
<IAIDndImage
draggableData={draggableData}
droppableData={droppableData}
imageDTO={controlImage}
isDropDisabled={shouldShowProcessedImage}
postUploadAction={postUploadAction}
/>
<Box
position="absolute"
top={0}
insetInlineStart={0}
w="full"
h="full"
opacity={shouldShowProcessedImage ? 1 : 0}
transitionProperty="common"
transitionDuration="normal"
pointerEvents="none"
>
<IAIDndImage
draggableData={draggableData}
droppableData={droppableData}
imageDTO={processedControlImage}
isUploadDisabled={true}
/>
</Box>
<>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={controlImage ? <PiArrowCounterClockwiseBold size={16} /> : undefined}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={handleSaveControlImage}
icon={controlImage ? <PiFloppyDiskBold size={16} /> : undefined}
tooltip={t('controlnet.saveControlImage')}
styleOverrides={saveControlImageStyleOverrides}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={controlImage ? <PiRulerBold size={16} /> : undefined}
tooltip={t('controlnet.setControlImageDimensions')}
styleOverrides={setControlImageDimensionsStyleOverrides}
/>
</>
{pendingControlImages.includes(id) && (
<Flex
position="absolute"
top={0}
insetInlineStart={0}
w="full"
h="full"
alignItems="center"
justifyContent="center"
opacity={0.8}
borderRadius="base"
bg="base.900"
>
<Spinner size="xl" color="base.400" />
</Flex>
)}
</Flex>
);
};
export default memo(ControlAdapterImagePreview);
const saveControlImageStyleOverrides: SystemStyleObject = { mt: 6 };
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 12 };

View File

@ -0,0 +1,70 @@
import { Box, Flex, Icon, IconButton } from '@invoke-ai/ui-library';
import ControlAdapterProcessorComponent from 'features/controlAdapters/components/ControlAdapterProcessorComponent';
import ControlAdapterShouldAutoConfig from 'features/controlAdapters/components/ControlAdapterShouldAutoConfig';
import ParamControlAdapterIPMethod from 'features/controlAdapters/components/parameters/ParamControlAdapterIPMethod';
import ParamControlAdapterProcessorSelect from 'features/controlAdapters/components/parameters/ParamControlAdapterProcessorSelect';
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCaretUpBold } from 'react-icons/pi';
import { useToggle } from 'react-use';
import ControlAdapterImagePreview from './ControlAdapterImagePreview';
import { ParamControlAdapterBeginEnd } from './ParamControlAdapterBeginEnd';
import ParamControlAdapterControlMode from './ParamControlAdapterControlMode';
import ParamControlAdapterModel from './ParamControlAdapterModel';
import ParamControlAdapterWeight from './ParamControlAdapterWeight';
const ControlAdapterLayerConfig = (props: { id: string }) => {
const { id } = props;
const controlAdapterType = useControlAdapterType(id);
const { t } = useTranslation();
const [isExpanded, toggleIsExpanded] = useToggle(false);
return (
<Flex flexDir="column" gap={4} position="relative">
<Flex gap={3} alignItems="center" w="full">
<Box minW={0} w="full" transitionProperty="common" transitionDuration="0.1s">
<ParamControlAdapterModel id={id} />{' '}
</Box>
<IconButton
size="sm"
tooltip={isExpanded ? t('controlnet.hideAdvanced') : t('controlnet.showAdvanced')}
aria-label={isExpanded ? t('controlnet.hideAdvanced') : t('controlnet.showAdvanced')}
onClick={toggleIsExpanded}
variant="ghost"
icon={
<Icon
boxSize={4}
as={PiCaretUpBold}
transform={isExpanded ? 'rotate(0deg)' : 'rotate(180deg)'}
transitionProperty="common"
transitionDuration="normal"
/>
}
/>
</Flex>
<Flex gap={4} w="full" alignItems="center">
<Flex flexDir="column" gap={3} w="full">
{controlAdapterType === 'ip_adapter' && <ParamControlAdapterIPMethod id={id} />}
{controlAdapterType === 'controlnet' && <ParamControlAdapterControlMode id={id} />}
<ParamControlAdapterWeight id={id} />
<ParamControlAdapterBeginEnd id={id} />
</Flex>
<Flex alignItems="center" justifyContent="center" h={36} w={36} aspectRatio="1/1">
<ControlAdapterImagePreview id={id} isSmall />
</Flex>
</Flex>
{isExpanded && (
<>
<ControlAdapterShouldAutoConfig id={id} />
<ParamControlAdapterProcessorSelect id={id} />
<ControlAdapterProcessorComponent id={id} />
</>
)}
</Flex>
);
};
export default memo(ControlAdapterLayerConfig);

View File

@ -0,0 +1,89 @@
import { CompositeRangeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useControlAdapterBeginEndStepPct } from 'features/controlAdapters/hooks/useControlAdapterBeginEndStepPct';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import {
controlAdapterBeginStepPctChanged,
controlAdapterEndStepPctChanged,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
type Props = {
id: string;
};
const formatPct = (v: number) => `${Math.round(v * 100)}%`;
export const ParamControlAdapterBeginEnd = memo(({ id }: Props) => {
const isEnabled = useControlAdapterIsEnabled(id);
const stepPcts = useControlAdapterBeginEndStepPct(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const onChange = useCallback(
(v: [number, number]) => {
dispatch(
controlAdapterBeginStepPctChanged({
id,
beginStepPct: v[0],
})
);
dispatch(
controlAdapterEndStepPctChanged({
id,
endStepPct: v[1],
})
);
},
[dispatch, id]
);
const onReset = useCallback(() => {
dispatch(
controlAdapterBeginStepPctChanged({
id,
beginStepPct: 0,
})
);
dispatch(
controlAdapterEndStepPctChanged({
id,
endStepPct: 1,
})
);
}, [dispatch, id]);
const value = useMemo<[number, number]>(() => [stepPcts?.beginStepPct ?? 0, stepPcts?.endStepPct ?? 1], [stepPcts]);
if (!stepPcts) {
return null;
}
return (
<FormControl isDisabled={!isEnabled} orientation="horizontal">
<InformationalPopover feature="controlNetBeginEnd">
<FormLabel m={0}>{t('controlnet.beginEndStepPercentShort')}</FormLabel>
</InformationalPopover>
<CompositeRangeSlider
aria-label={ariaLabel}
value={value}
onChange={onChange}
onReset={onReset}
min={0}
max={1}
step={0.05}
fineStep={0.01}
minStepsBetweenThumbs={1}
formatValue={formatPct}
marks
withThumbTooltip
/>
</FormControl>
);
});
ParamControlAdapterBeginEnd.displayName = 'ParamControlAdapterBeginEnd';
const ariaLabel = ['Begin Step %', 'End Step %'];

View File

@ -0,0 +1,66 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useControlAdapterControlMode } from 'features/controlAdapters/hooks/useControlAdapterControlMode';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { controlAdapterControlModeChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlMode } from 'features/controlAdapters/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
type Props = {
id: string;
};
const ParamControlAdapterControlMode = ({ id }: Props) => {
const isEnabled = useControlAdapterIsEnabled(id);
const controlMode = useControlAdapterControlMode(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const CONTROL_MODE_DATA = useMemo(
() => [
{ label: t('controlnet.balanced'), value: 'balanced' },
{ label: t('controlnet.prompt'), value: 'more_prompt' },
{ label: t('controlnet.control'), value: 'more_control' },
{ label: t('controlnet.megaControl'), value: 'unbalanced' },
],
[t]
);
const handleControlModeChange = useCallback<ComboboxOnChange>(
(v) => {
if (!v) {
return;
}
dispatch(
controlAdapterControlModeChanged({
id,
controlMode: v.value as ControlMode,
})
);
},
[id, dispatch]
);
const value = useMemo(
() => CONTROL_MODE_DATA.filter((o) => o.value === controlMode)[0],
[CONTROL_MODE_DATA, controlMode]
);
if (!controlMode) {
return null;
}
return (
<FormControl isDisabled={!isEnabled}>
<InformationalPopover feature="controlNetControlMode">
<FormLabel m={0}>{t('controlnet.control')}</FormLabel>
</InformationalPopover>
<Combobox value={value} options={CONTROL_MODE_DATA} onChange={handleControlModeChange} />
</FormControl>
);
};
export default memo(ParamControlAdapterControlMode);

View File

@ -0,0 +1,136 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useControlAdapterCLIPVisionModel } from 'features/controlAdapters/hooks/useControlAdapterCLIPVisionModel';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
import {
controlAdapterCLIPVisionModelChanged,
controlAdapterModelChanged,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import type { CLIPVisionModel } from 'features/controlAdapters/store/types';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type {
AnyModelConfig,
ControlNetModelConfig,
IPAdapterModelConfig,
T2IAdapterModelConfig,
} from 'services/api/types';
type ParamControlAdapterModelProps = {
id: string;
};
const selectMainModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
const isEnabled = useControlAdapterIsEnabled(id);
const controlAdapterType = useControlAdapterType(id);
const { modelConfig } = useControlAdapterModel(id);
const dispatch = useAppDispatch();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const currentCLIPVisionModel = useControlAdapterCLIPVisionModel(id);
const mainModel = useAppSelector(selectMainModel);
const { t } = useTranslation();
const [modelConfigs, { isLoading }] = useControlAdapterModels(controlAdapterType);
const _onChange = useCallback(
(modelConfig: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
if (!modelConfig) {
return;
}
dispatch(
controlAdapterModelChanged({
id,
modelConfig,
})
);
},
[dispatch, id]
);
const onCLIPVisionModelChange = useCallback<ComboboxOnChange>(
(v) => {
if (!v?.value) {
return;
}
dispatch(controlAdapterCLIPVisionModelChanged({ id, clipVisionModel: v.value as CLIPVisionModel }));
},
[dispatch, id]
);
const selectedModel = useMemo(
() => (modelConfig && controlAdapterType ? { ...modelConfig, model_type: controlAdapterType } : null),
[controlAdapterType, modelConfig]
);
const getIsDisabled = useCallback(
(model: AnyModelConfig): boolean => {
const isCompatible = currentBaseModel === model.base;
const hasMainModel = Boolean(currentBaseModel);
return !hasMainModel || !isCompatible;
},
[currentBaseModel]
);
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel,
getIsDisabled,
isLoading,
});
const clipVisionOptions = useMemo<ComboboxOption[]>(
() => [
{ label: 'ViT-H', value: 'ViT-H' },
{ label: 'ViT-G', value: 'ViT-G' },
],
[]
);
const clipVisionModel = useMemo(
() => clipVisionOptions.find((o) => o.value === currentCLIPVisionModel),
[clipVisionOptions, currentCLIPVisionModel]
);
return (
<Flex gap={4}>
<Tooltip label={selectedModel?.description}>
<FormControl isDisabled={!isEnabled} isInvalid={!value || mainModel?.base !== modelConfig?.base} w="full">
<Combobox
options={options}
placeholder={t('controlnet.selectModel')}
value={value}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
{modelConfig?.type === 'ip_adapter' && modelConfig.format === 'checkpoint' && (
<FormControl
isDisabled={!isEnabled}
isInvalid={!value || mainModel?.base !== modelConfig?.base}
width="max-content"
minWidth={28}
>
<Combobox
options={clipVisionOptions}
placeholder={t('controlnet.selectCLIPVisionModel')}
value={clipVisionModel}
onChange={onCLIPVisionModelChange}
/>
</FormControl>
)}
</Flex>
);
};
export default memo(ParamControlAdapterModel);

View File

@ -0,0 +1,74 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { useControlAdapterWeight } from 'features/controlAdapters/hooks/useControlAdapterWeight';
import { controlAdapterWeightChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { isNil } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type ParamControlAdapterWeightProps = {
id: string;
};
const formatValue = (v: number) => v.toFixed(2);
const ParamControlAdapterWeight = ({ id }: ParamControlAdapterWeightProps) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const isEnabled = useControlAdapterIsEnabled(id);
const weight = useControlAdapterWeight(id);
const initial = useAppSelector((s) => s.config.sd.ca.weight.initial);
const sliderMin = useAppSelector((s) => s.config.sd.ca.weight.sliderMin);
const sliderMax = useAppSelector((s) => s.config.sd.ca.weight.sliderMax);
const numberInputMin = useAppSelector((s) => s.config.sd.ca.weight.numberInputMin);
const numberInputMax = useAppSelector((s) => s.config.sd.ca.weight.numberInputMax);
const coarseStep = useAppSelector((s) => s.config.sd.ca.weight.coarseStep);
const fineStep = useAppSelector((s) => s.config.sd.ca.weight.fineStep);
const onChange = useCallback(
(weight: number) => {
dispatch(controlAdapterWeightChanged({ id, weight }));
},
[dispatch, id]
);
if (isNil(weight)) {
// should never happen
return null;
}
return (
<FormControl isDisabled={!isEnabled} orientation="horizontal">
<InformationalPopover feature="controlNetWeight">
<FormLabel m={0}>{t('controlnet.weight')}</FormLabel>
</InformationalPopover>
<CompositeSlider
value={weight}
onChange={onChange}
defaultValue={initial}
min={sliderMin}
max={sliderMax}
step={coarseStep}
fineStep={fineStep}
marks={marks}
formatValue={formatValue}
/>
<CompositeNumberInput
value={weight}
onChange={onChange}
min={numberInputMin}
max={numberInputMax}
step={coarseStep}
fineStep={fineStep}
maxW={20}
defaultValue={initial}
/>
</FormControl>
);
};
export default memo(ParamControlAdapterWeight);
const marks = [0, 1, 2];

View File

@ -1,6 +1,9 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { isMaskedGuidanceLayer, selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice'; import {
isMaskedGuidanceLayer,
selectRegionalPromptsSlice,
} from 'features/regionalPrompts/store/regionalPromptsSlice';
import { useMemo } from 'react'; import { useMemo } from 'react';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
@ -39,8 +42,8 @@ export const useLayerIsVisible = (layerId: string) => {
() => () =>
createSelector(selectRegionalPromptsSlice, (regionalPrompts) => { createSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId); const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
assert(isMaskedGuidanceLayer(layer), `Layer ${layerId} not found or not an RP layer`); assert(layer, `Layer ${layerId} not found`);
return layer.isVisible; return layer.isEnabled;
}), }),
[layerId] [layerId]
); );

View File

@ -10,7 +10,7 @@ const selectValidLayerCount = createSelector(selectRegionalPromptsSlice, (region
} }
const validLayers = regionalPrompts.present.layers const validLayers = regionalPrompts.present.layers
.filter(isMaskedGuidanceLayer) .filter(isMaskedGuidanceLayer)
.filter((l) => l.isVisible) .filter((l) => l.isEnabled)
.filter((l) => { .filter((l) => {
const hasTextPrompt = Boolean(l.positivePrompt || l.negativePrompt); const hasTextPrompt = Boolean(l.positivePrompt || l.negativePrompt);
const hasAtLeastOneImagePrompt = l.ipAdapterIds.length > 0; const hasAtLeastOneImagePrompt = l.ipAdapterIds.length > 0;

View File

@ -4,20 +4,16 @@ import type { PersistConfig, RootState } from 'app/store/store';
import { moveBackward, moveForward, moveToBack, moveToFront } from 'common/util/arrayUtils'; import { moveBackward, moveForward, moveToBack, moveToFront } from 'common/util/arrayUtils';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { roundToMultiple } from 'common/util/roundDownToMultiple'; import { roundToMultiple } from 'common/util/roundDownToMultiple';
import { controlAdapterRemoved, isAnyControlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice'; import {
controlAdapterImageChanged,
controlAdapterProcessedImageChanged,
isAnyControlAdapterAdded,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize'; import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
import { initialAspectRatioState } from 'features/parameters/components/ImageSize/constants'; import { initialAspectRatioState } from 'features/parameters/components/ImageSize/constants';
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types'; import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
import { modelChanged } from 'features/parameters/store/generationSlice'; import { modelChanged } from 'features/parameters/store/generationSlice';
import type { import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas';
ParameterAutoNegative,
ParameterHeight,
ParameterNegativePrompt,
ParameterNegativeStylePromptSDXL,
ParameterPositivePrompt,
ParameterPositiveStylePromptSDXL,
ParameterWidth,
} from 'features/parameters/types/parameterSchemas';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension'; import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import type { IRect, Vector2d } from 'konva/lib/types'; import type { IRect, Vector2d } from 'konva/lib/types';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
@ -27,81 +23,17 @@ import type { UndoableOptions } from 'redux-undo';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
type DrawingTool = 'brush' | 'eraser'; import type {
ControlAdapterLayer,
export type Tool = DrawingTool | 'move' | 'rect'; DrawingTool,
IPAdapterLayer,
export type VectorMaskLine = { Layer,
id: string; MaskedGuidanceLayer,
type: 'vector_mask_line'; RegionalPromptsState,
tool: DrawingTool; Tool,
strokeWidth: number; VectorMaskLine,
points: number[]; VectorMaskRect,
}; } from './types';
export type VectorMaskRect = {
id: string;
type: 'vector_mask_rect';
x: number;
y: number;
width: number;
height: number;
};
type LayerBase = {
id: string;
isVisible: boolean;
};
type RenderableLayerBase = LayerBase & {
x: number;
y: number;
bbox: IRect | null;
bboxNeedsUpdate: boolean;
};
type ControlAdapterLayer = RenderableLayerBase & {
type: 'controlnet_layer'; // technically, also t2i adapter layer
controlAdapterId: string;
};
type IPAdapterLayer = LayerBase & {
type: 'ip_adapter_layer'; // technically, also t2i adapter layer
ipAdapterId: string;
};
export type MaskedGuidanceLayer = RenderableLayerBase & {
type: 'masked_guidance_layer';
maskObjects: (VectorMaskLine | VectorMaskRect)[];
positivePrompt: ParameterPositivePrompt | null;
negativePrompt: ParameterNegativePrompt | null; // Up to one text prompt per mask
ipAdapterIds: string[]; // Any number of image prompts
previewColor: RgbColor;
autoNegative: ParameterAutoNegative;
needsPixelBbox: boolean; // Needs the slower pixel-based bbox calculation - set to true when an there is an eraser object
};
export type Layer = MaskedGuidanceLayer | ControlAdapterLayer | IPAdapterLayer;
type RegionalPromptsState = {
_version: 1;
selectedLayerId: string | null;
layers: Layer[];
brushSize: number;
globalMaskLayerOpacity: number;
isEnabled: boolean;
positivePrompt: ParameterPositivePrompt;
negativePrompt: ParameterNegativePrompt;
positivePrompt2: ParameterPositiveStylePromptSDXL;
negativePrompt2: ParameterNegativeStylePromptSDXL;
shouldConcatPrompts: boolean;
initialImage: string | null;
size: {
width: ParameterWidth;
height: ParameterHeight;
aspectRatio: AspectRatioState;
};
};
export const initialRegionalPromptsState: RegionalPromptsState = { export const initialRegionalPromptsState: RegionalPromptsState = {
_version: 1, _version: 1,
@ -126,19 +58,22 @@ export const initialRegionalPromptsState: RegionalPromptsState = {
const isLine = (obj: VectorMaskLine | VectorMaskRect): obj is VectorMaskLine => obj.type === 'vector_mask_line'; const isLine = (obj: VectorMaskLine | VectorMaskRect): obj is VectorMaskLine => obj.type === 'vector_mask_line';
export const isMaskedGuidanceLayer = (layer?: Layer): layer is MaskedGuidanceLayer => export const isMaskedGuidanceLayer = (layer?: Layer): layer is MaskedGuidanceLayer =>
layer?.type === 'masked_guidance_layer'; layer?.type === 'masked_guidance_layer';
export const isRenderableLayer = (layer?: Layer): layer is MaskedGuidanceLayer => export const isControlAdapterLayer = (layer?: Layer): layer is ControlAdapterLayer =>
layer?.type === 'masked_guidance_layer' || layer?.type === 'controlnet_layer'; layer?.type === 'control_adapter_layer';
export const isIPAdapterLayer = (layer?: Layer): layer is IPAdapterLayer => layer?.type === 'ip_adapter_layer';
export const isRenderableLayer = (layer?: Layer): layer is MaskedGuidanceLayer | ControlAdapterLayer =>
layer?.type === 'masked_guidance_layer' || layer?.type === 'control_adapter_layer';
const resetLayer = (layer: Layer) => { const resetLayer = (layer: Layer) => {
if (layer.type === 'masked_guidance_layer') { if (layer.type === 'masked_guidance_layer') {
layer.maskObjects = []; layer.maskObjects = [];
layer.bbox = null; layer.bbox = null;
layer.isVisible = true; layer.isEnabled = true;
layer.needsPixelBbox = false; layer.needsPixelBbox = false;
layer.bboxNeedsUpdate = false; layer.bboxNeedsUpdate = false;
return; return;
} }
if (layer.type === 'controlnet_layer') { if (layer.type === 'control_adapter_layer') {
// TODO // TODO
} }
}; };
@ -153,59 +88,71 @@ export const regionalPromptsSlice = createSlice({
initialState: initialRegionalPromptsState, initialState: initialRegionalPromptsState,
reducers: { reducers: {
//#region All Layers //#region All Layers
layerAdded: { maskedGuidanceLayerAdded: (state, action: PayloadAction<{ layerId: string }>) => {
reducer: (state, action: PayloadAction<Layer['type'], string, { uuid: string }>) => { const { layerId } = action.payload;
const type = action.payload; const layer: MaskedGuidanceLayer = {
if (type === 'masked_guidance_layer') { id: getMaskedGuidanceLayerId(layerId),
const layer: MaskedGuidanceLayer = { type: 'masked_guidance_layer',
id: getMaskedGuidanceLayerId(action.meta.uuid), isEnabled: true,
type: 'masked_guidance_layer', bbox: null,
isVisible: true, bboxNeedsUpdate: false,
bbox: null, maskObjects: [],
bboxNeedsUpdate: false, previewColor: getVectorMaskPreviewColor(state),
maskObjects: [], x: 0,
previewColor: getVectorMaskPreviewColor(state), y: 0,
x: 0, autoNegative: 'invert',
y: 0, needsPixelBbox: false,
autoNegative: 'invert', positivePrompt: '',
needsPixelBbox: false, negativePrompt: null,
positivePrompt: '', ipAdapterIds: [],
negativePrompt: null, isSelected: true,
ipAdapterIds: [], };
}; state.layers.push(layer);
state.layers.push(layer); state.selectedLayerId = layer.id;
state.selectedLayerId = layer.id; return;
return; },
} ipAdapterLayerAdded: (state, action: PayloadAction<{ layerId: string; ipAdapterId: string }>) => {
const { layerId, ipAdapterId } = action.payload;
if (type === 'controlnet_layer') { const layer: IPAdapterLayer = {
const layer: ControlAdapterLayer = { id: getIPAdapterLayerId(layerId),
id: getControlLayerId(action.meta.uuid), type: 'ip_adapter_layer',
type: 'controlnet_layer', isEnabled: true,
controlAdapterId: action.meta.uuid, ipAdapterId,
x: 0, };
y: 0, state.layers.push(layer);
bbox: null, return;
bboxNeedsUpdate: false, },
isVisible: true, controlAdapterLayerAdded: (state, action: PayloadAction<{ layerId: string; controlNetId: string }>) => {
}; const { layerId, controlNetId } = action.payload;
state.layers.push(layer); const layer: ControlAdapterLayer = {
state.selectedLayerId = layer.id; id: getControlNetLayerId(layerId),
return; type: 'control_adapter_layer',
} controlNetId,
}, x: 0,
prepare: (payload: Layer['type']) => ({ payload, meta: { uuid: uuidv4() } }), y: 0,
bbox: null,
bboxNeedsUpdate: false,
isEnabled: true,
imageName: null,
opacity: 1,
isSelected: true,
};
state.layers.push(layer);
state.selectedLayerId = layer.id;
return;
}, },
layerSelected: (state, action: PayloadAction<string>) => { layerSelected: (state, action: PayloadAction<string>) => {
const layer = state.layers.find((l) => l.id === action.payload); for (const layer of state.layers) {
if (layer) { if (isRenderableLayer(layer) && layer.id === action.payload) {
state.selectedLayerId = layer.id; layer.isSelected = true;
state.selectedLayerId = action.payload;
}
} }
}, },
layerVisibilityToggled: (state, action: PayloadAction<string>) => { layerVisibilityToggled: (state, action: PayloadAction<string>) => {
const layer = state.layers.find((l) => l.id === action.payload); const layer = state.layers.find((l) => l.id === action.payload);
if (layer) { if (layer) {
layer.isVisible = !layer.isVisible; layer.isEnabled = !layer.isEnabled;
} }
}, },
layerTranslated: (state, action: PayloadAction<{ layerId: string; x: number; y: number }>) => { layerTranslated: (state, action: PayloadAction<{ layerId: string; x: number; y: number }>) => {
@ -252,10 +199,6 @@ export const regionalPromptsSlice = createSlice({
// Because the layers are in reverse order, moving to the back is equivalent to moving to the front // Because the layers are in reverse order, moving to the back is equivalent to moving to the front
moveToFront(state.layers, cb); moveToFront(state.layers, cb);
}, },
allLayersDeleted: (state) => {
state.layers = [];
state.selectedLayerId = null;
},
selectedLayerReset: (state) => { selectedLayerReset: (state) => {
const layer = state.layers.find((l) => l.id === state.selectedLayerId); const layer = state.layers.find((l) => l.id === state.selectedLayerId);
if (layer) { if (layer) {
@ -283,14 +226,19 @@ export const regionalPromptsSlice = createSlice({
layer.negativePrompt = prompt; layer.negativePrompt = prompt;
} }
}, },
maskLayerIPAdapterAdded: { maskLayerIPAdapterAdded: (state, action: PayloadAction<{ layerId: string; ipAdapterId: string }>) => {
reducer: (state, action: PayloadAction<string, string, { uuid: string }>) => { const { layerId, ipAdapterId } = action.payload;
const layer = state.layers.find((l) => l.id === action.payload); const layer = state.layers.find((l) => l.id === layerId);
if (layer?.type === 'masked_guidance_layer') { if (layer?.type === 'masked_guidance_layer') {
layer.ipAdapterIds.push(action.meta.uuid); layer.ipAdapterIds.push(ipAdapterId);
} }
}, },
prepare: (payload: string) => ({ payload, meta: { uuid: uuidv4() } }), maskLayerIPAdapterDeleted: (state, action: PayloadAction<{ layerId: string; ipAdapterId: string }>) => {
const { layerId, ipAdapterId } = action.payload;
const layer = state.layers.find((l) => l.id === layerId);
if (layer?.type === 'masked_guidance_layer') {
layer.ipAdapterIds = layer.ipAdapterIds.filter((id) => id !== ipAdapterId);
}
}, },
maskLayerPreviewColorChanged: (state, action: PayloadAction<{ layerId: string; color: RgbColor }>) => { maskLayerPreviewColorChanged: (state, action: PayloadAction<{ layerId: string; color: RgbColor }>) => {
const { layerId, color } = action.payload; const { layerId, color } = action.payload;
@ -422,10 +370,13 @@ export const regionalPromptsSlice = createSlice({
//#region General //#region General
brushSizeChanged: (state, action: PayloadAction<number>) => { brushSizeChanged: (state, action: PayloadAction<number>) => {
state.brushSize = action.payload; state.brushSize = Math.round(action.payload);
}, },
globalMaskLayerOpacityChanged: (state, action: PayloadAction<number>) => { globalMaskLayerOpacityChanged: (state, action: PayloadAction<number>) => {
state.globalMaskLayerOpacity = action.payload; state.globalMaskLayerOpacity = action.payload;
state.layers.filter(isControlAdapterLayer).forEach((l) => {
l.opacity = action.payload;
});
}, },
isEnabledChanged: (state, action: PayloadAction<boolean>) => { isEnabledChanged: (state, action: PayloadAction<boolean>) => {
state.isEnabled = action.payload; state.isEnabled = action.payload;
@ -445,12 +396,6 @@ export const regionalPromptsSlice = createSlice({
//#endregion //#endregion
}, },
extraReducers(builder) { extraReducers(builder) {
builder.addCase(controlAdapterRemoved, (state, action) => {
state.layers.filter(isMaskedGuidanceLayer).forEach((layer) => {
layer.ipAdapterIds = layer.ipAdapterIds.filter((id) => id !== action.payload.id);
});
});
builder.addCase(modelChanged, (state, action) => { builder.addCase(modelChanged, (state, action) => {
const newModel = action.payload; const newModel = action.payload;
if (!newModel || action.meta.previousModel?.base === newModel.base) { if (!newModel || action.meta.previousModel?.base === newModel.base) {
@ -466,6 +411,28 @@ export const regionalPromptsSlice = createSlice({
state.size.height = height; state.size.height = height;
}); });
builder.addCase(controlAdapterImageChanged, (state, action) => {
const { id, controlImage } = action.payload;
const layer = state.layers.filter(isControlAdapterLayer).find((l) => l.controlNetId === id);
if (layer) {
layer.bbox = null;
layer.bboxNeedsUpdate = true;
layer.isEnabled = true;
layer.imageName = controlImage?.image_name ?? null;
}
});
builder.addCase(controlAdapterProcessedImageChanged, (state, action) => {
const { id, processedControlImage } = action.payload;
const layer = state.layers.filter(isControlAdapterLayer).find((l) => l.controlNetId === id);
if (layer) {
layer.bbox = null;
layer.bboxNeedsUpdate = true;
layer.isEnabled = true;
layer.imageName = processedControlImage?.image_name ?? null;
}
});
// TODO: This is a temp fix to reduce issues with T2I adapter having a different downscaling // TODO: This is a temp fix to reduce issues with T2I adapter having a different downscaling
// factor than the UNet. Hopefully we get an upstream fix in diffusers. // factor than the UNet. Hopefully we get an upstream fix in diffusers.
builder.addMatcher(isAnyControlAdapterAdded, (state, action) => { builder.addMatcher(isAnyControlAdapterAdded, (state, action) => {
@ -510,7 +477,6 @@ class LayerColors {
export const { export const {
// All layer actions // All layer actions
layerAdded,
layerDeleted, layerDeleted,
layerMovedBackward, layerMovedBackward,
layerMovedForward, layerMovedForward,
@ -521,9 +487,11 @@ export const {
layerTranslated, layerTranslated,
layerBboxChanged, layerBboxChanged,
layerVisibilityToggled, layerVisibilityToggled,
allLayersDeleted,
selectedLayerReset, selectedLayerReset,
selectedLayerDeleted, selectedLayerDeleted,
maskedGuidanceLayerAdded,
ipAdapterLayerAdded,
controlAdapterLayerAdded,
// Mask layer actions // Mask layer actions
maskLayerLineAdded, maskLayerLineAdded,
maskLayerPointsAdded, maskLayerPointsAdded,
@ -531,6 +499,7 @@ export const {
maskLayerNegativePromptChanged, maskLayerNegativePromptChanged,
maskLayerPositivePromptChanged, maskLayerPositivePromptChanged,
maskLayerIPAdapterAdded, maskLayerIPAdapterAdded,
maskLayerIPAdapterDeleted,
maskLayerAutoNegativeChanged, maskLayerAutoNegativeChanged,
maskLayerPreviewColorChanged, maskLayerPreviewColorChanged,
// Base layer actions // Base layer actions
@ -549,6 +518,20 @@ export const {
redo, redo,
} = regionalPromptsSlice.actions; } = regionalPromptsSlice.actions;
export const selectAllControlAdapterIds = (regionalPrompts: RegionalPromptsState) =>
regionalPrompts.layers.flatMap((l) => {
if (l.type === 'control_adapter_layer') {
return [l.controlNetId];
}
if (l.type === 'ip_adapter_layer') {
return [l.ipAdapterId];
}
if (l.type === 'masked_guidance_layer') {
return l.ipAdapterIds;
}
return [];
});
export const selectRegionalPromptsSlice = (state: RootState) => state.regionalPrompts; export const selectRegionalPromptsSlice = (state: RootState) => state.regionalPrompts;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ /* eslint-disable-next-line @typescript-eslint/no-explicit-any */
@ -571,8 +554,11 @@ export const TOOL_PREVIEW_BRUSH_BORDER_OUTER_ID = 'tool_preview_layer.brush_bord
export const TOOL_PREVIEW_RECT_ID = 'tool_preview_layer.rect'; export const TOOL_PREVIEW_RECT_ID = 'tool_preview_layer.rect';
export const BACKGROUND_LAYER_ID = 'background_layer'; export const BACKGROUND_LAYER_ID = 'background_layer';
export const BACKGROUND_RECT_ID = 'background_layer.rect'; export const BACKGROUND_RECT_ID = 'background_layer.rect';
export const CONTROLNET_LAYER_TRANSFORMER_ID = 'control_adapter_layer.transformer';
// Names (aka classes) for Konva layers and objects // Names (aka classes) for Konva layers and objects
export const CONTROLNET_LAYER_NAME = 'control_adapter_layer';
export const CONTROLNET_LAYER_IMAGE_NAME = 'control_adapter_layer.image';
export const MASKED_GUIDANCE_LAYER_NAME = 'masked_guidance_layer'; export const MASKED_GUIDANCE_LAYER_NAME = 'masked_guidance_layer';
export const MASKED_GUIDANCE_LAYER_LINE_NAME = 'masked_guidance_layer.line'; export const MASKED_GUIDANCE_LAYER_LINE_NAME = 'masked_guidance_layer.line';
export const MASKED_GUIDANCE_LAYER_OBJECT_GROUP_NAME = 'masked_guidance_layer.object_group'; export const MASKED_GUIDANCE_LAYER_OBJECT_GROUP_NAME = 'masked_guidance_layer.object_group';
@ -586,7 +572,9 @@ const getMaskedGuidnaceLayerRectId = (layerId: string, lineId: string) => `${lay
export const getMaskedGuidanceLayerObjectGroupId = (layerId: string, groupId: string) => export const getMaskedGuidanceLayerObjectGroupId = (layerId: string, groupId: string) =>
`${layerId}.objectGroup_${groupId}`; `${layerId}.objectGroup_${groupId}`;
export const getLayerBboxId = (layerId: string) => `${layerId}.bbox`; export const getLayerBboxId = (layerId: string) => `${layerId}.bbox`;
const getControlLayerId = (layerId: string) => `control_layer_${layerId}`; const getControlNetLayerId = (layerId: string) => `control_adapter_layer_${layerId}`;
export const getControlNetLayerImageId = (layerId: string, imageName: string) => `${layerId}.image_${imageName}`;
const getIPAdapterLayerId = (layerId: string) => `ip_adapter_layer_${layerId}`;
export const regionalPromptsPersistConfig: PersistConfig<RegionalPromptsState> = { export const regionalPromptsPersistConfig: PersistConfig<RegionalPromptsState> = {
name: regionalPromptsSlice.name, name: regionalPromptsSlice.name,

View File

@ -0,0 +1,91 @@
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
import type {
ParameterAutoNegative,
ParameterHeight,
ParameterNegativePrompt,
ParameterNegativeStylePromptSDXL,
ParameterPositivePrompt,
ParameterPositiveStylePromptSDXL,
ParameterWidth,
} from 'features/parameters/types/parameterSchemas';
import type { IRect } from 'konva/lib/types';
import type { RgbColor } from 'react-colorful';
export type DrawingTool = 'brush' | 'eraser';
export type Tool = DrawingTool | 'move' | 'rect';
export type VectorMaskLine = {
id: string;
type: 'vector_mask_line';
tool: DrawingTool;
strokeWidth: number;
points: number[];
};
export type VectorMaskRect = {
id: string;
type: 'vector_mask_rect';
x: number;
y: number;
width: number;
height: number;
};
export type LayerBase = {
id: string;
isEnabled: boolean;
};
export type RenderableLayerBase = LayerBase & {
x: number;
y: number;
bbox: IRect | null;
bboxNeedsUpdate: boolean;
isSelected: boolean;
};
export type ControlAdapterLayer = RenderableLayerBase & {
type: 'control_adapter_layer'; // technically, also t2i adapter layer
controlNetId: string;
imageName: string | null;
opacity: number;
};
export type IPAdapterLayer = LayerBase & {
type: 'ip_adapter_layer'; // technically, also t2i adapter layer
ipAdapterId: string;
};
export type MaskedGuidanceLayer = RenderableLayerBase & {
type: 'masked_guidance_layer';
maskObjects: (VectorMaskLine | VectorMaskRect)[];
positivePrompt: ParameterPositivePrompt | null;
negativePrompt: ParameterNegativePrompt | null; // Up to one text prompt per mask
ipAdapterIds: string[]; // Any number of image prompts
previewColor: RgbColor;
autoNegative: ParameterAutoNegative;
needsPixelBbox: boolean; // Needs the slower pixel-based bbox calculation - set to true when an there is an eraser object
};
export type Layer = MaskedGuidanceLayer | ControlAdapterLayer | IPAdapterLayer;
export type RegionalPromptsState = {
_version: 1;
selectedLayerId: string | null;
layers: Layer[];
brushSize: number;
globalMaskLayerOpacity: number;
isEnabled: boolean;
positivePrompt: ParameterPositivePrompt;
negativePrompt: ParameterNegativePrompt;
positivePrompt2: ParameterPositiveStylePromptSDXL;
negativePrompt2: ParameterNegativeStylePromptSDXL;
shouldConcatPrompts: boolean;
initialImage: string | null;
size: {
width: ParameterWidth;
height: ParameterHeight;
aspectRatio: AspectRatioState;
};
};

View File

@ -1,20 +1,18 @@
import { getStore } from 'app/store/nanostores/store'; import { getStore } from 'app/store/nanostores/store';
import { rgbaColorToString, rgbColorToString } from 'features/canvas/util/colorToString'; import { rgbaColorToString, rgbColorToString } from 'features/canvas/util/colorToString';
import { getScaledFlooredCursorPosition } from 'features/regionalPrompts/hooks/mouseEventHooks'; import { getScaledFlooredCursorPosition } from 'features/regionalPrompts/hooks/mouseEventHooks';
import type {
Layer,
MaskedGuidanceLayer,
Tool,
VectorMaskLine,
VectorMaskRect,
} from 'features/regionalPrompts/store/regionalPromptsSlice';
import { import {
$tool, $tool,
BACKGROUND_LAYER_ID, BACKGROUND_LAYER_ID,
BACKGROUND_RECT_ID, BACKGROUND_RECT_ID,
CONTROLNET_LAYER_IMAGE_NAME,
CONTROLNET_LAYER_NAME,
getControlNetLayerImageId,
getLayerBboxId, getLayerBboxId,
getMaskedGuidanceLayerObjectGroupId, getMaskedGuidanceLayerObjectGroupId,
isControlAdapterLayer,
isMaskedGuidanceLayer, isMaskedGuidanceLayer,
isRenderableLayer,
LAYER_BBOX_NAME, LAYER_BBOX_NAME,
MASKED_GUIDANCE_LAYER_LINE_NAME, MASKED_GUIDANCE_LAYER_LINE_NAME,
MASKED_GUIDANCE_LAYER_NAME, MASKED_GUIDANCE_LAYER_NAME,
@ -27,11 +25,20 @@ import {
TOOL_PREVIEW_LAYER_ID, TOOL_PREVIEW_LAYER_ID,
TOOL_PREVIEW_RECT_ID, TOOL_PREVIEW_RECT_ID,
} from 'features/regionalPrompts/store/regionalPromptsSlice'; } from 'features/regionalPrompts/store/regionalPromptsSlice';
import type {
ControlAdapterLayer,
Layer,
MaskedGuidanceLayer,
Tool,
VectorMaskLine,
VectorMaskRect,
} from 'features/regionalPrompts/store/types';
import { getLayerBboxFast, getLayerBboxPixels } from 'features/regionalPrompts/util/bbox'; import { getLayerBboxFast, getLayerBboxPixels } from 'features/regionalPrompts/util/bbox';
import Konva from 'konva'; import Konva from 'konva';
import type { IRect, Vector2d } from 'konva/lib/types'; import type { IRect, Vector2d } from 'konva/lib/types';
import { debounce } from 'lodash-es'; import { debounce } from 'lodash-es';
import type { RgbColor } from 'react-colorful'; import type { RgbColor } from 'react-colorful';
import { imagesApi } from 'services/api/endpoints/images';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
@ -53,6 +60,9 @@ const getIsSelected = (layerId?: string | null) => {
return layerId === getStore().getState().regionalPrompts.present.selectedLayerId; return layerId === getStore().getState().regionalPrompts.present.selectedLayerId;
}; };
const selectRenderableLayers = (n: Konva.Node) =>
n.name() === MASKED_GUIDANCE_LAYER_NAME || n.name() === CONTROLNET_LAYER_NAME;
const selectVectorMaskObjects = (node: Konva.Node) => { const selectVectorMaskObjects = (node: Konva.Node) => {
return node.name() === MASKED_GUIDANCE_LAYER_LINE_NAME || node.name() === MASKED_GUIDANCE_LAYER_RECT_NAME; return node.name() === MASKED_GUIDANCE_LAYER_LINE_NAME || node.name() === MASKED_GUIDANCE_LAYER_RECT_NAME;
}; };
@ -219,7 +229,7 @@ const renderToolPreview = (
* @param reduxLayer The redux layer to create the konva layer from. * @param reduxLayer The redux layer to create the konva layer from.
* @param onLayerPosChanged Callback for when the layer's position changes. * @param onLayerPosChanged Callback for when the layer's position changes.
*/ */
const createVectorMaskLayer = ( const createMaskedGuidanceLayer = (
stage: Konva.Stage, stage: Konva.Stage,
reduxLayer: MaskedGuidanceLayer, reduxLayer: MaskedGuidanceLayer,
onLayerPosChanged?: (layerId: string, x: number, y: number) => void onLayerPosChanged?: (layerId: string, x: number, y: number) => void
@ -320,7 +330,7 @@ const createVectorMaskRect = (reduxObject: VectorMaskRect, konvaGroup: Konva.Gro
* @param globalMaskLayerOpacity The opacity of the global mask layer. * @param globalMaskLayerOpacity The opacity of the global mask layer.
* @param tool The current tool. * @param tool The current tool.
*/ */
const renderVectorMaskLayer = ( const renderMaskedGuidanceLayer = (
stage: Konva.Stage, stage: Konva.Stage,
reduxLayer: MaskedGuidanceLayer, reduxLayer: MaskedGuidanceLayer,
globalMaskLayerOpacity: number, globalMaskLayerOpacity: number,
@ -328,7 +338,7 @@ const renderVectorMaskLayer = (
onLayerPosChanged?: (layerId: string, x: number, y: number) => void onLayerPosChanged?: (layerId: string, x: number, y: number) => void
): void => { ): void => {
const konvaLayer = const konvaLayer =
stage.findOne<Konva.Layer>(`#${reduxLayer.id}`) ?? createVectorMaskLayer(stage, reduxLayer, onLayerPosChanged); stage.findOne<Konva.Layer>(`#${reduxLayer.id}`) ?? createMaskedGuidanceLayer(stage, reduxLayer, onLayerPosChanged);
// Update the layer's position and listening state // Update the layer's position and listening state
konvaLayer.setAttrs({ konvaLayer.setAttrs({
@ -383,8 +393,8 @@ const renderVectorMaskLayer = (
} }
// Only update layer visibility if it has changed. // Only update layer visibility if it has changed.
if (konvaLayer.visible() !== reduxLayer.isVisible) { if (konvaLayer.visible() !== reduxLayer.isEnabled) {
konvaLayer.visible(reduxLayer.isVisible); konvaLayer.visible(reduxLayer.isEnabled);
groupNeedsCache = true; groupNeedsCache = true;
} }
@ -401,6 +411,101 @@ const renderVectorMaskLayer = (
} }
}; };
const createControlNetLayer = (stage: Konva.Stage, reduxLayer: ControlAdapterLayer): Konva.Layer => {
const konvaLayer = new Konva.Layer({
id: reduxLayer.id,
name: CONTROLNET_LAYER_NAME,
imageSmoothingEnabled: false,
});
stage.add(konvaLayer);
return konvaLayer;
};
const createControlNetLayerImage = (konvaLayer: Konva.Layer, image: HTMLImageElement): Konva.Image => {
const konvaImage = new Konva.Image({
name: CONTROLNET_LAYER_IMAGE_NAME,
image,
filters: [LightnessToAlphaFilter],
});
konvaLayer.add(konvaImage);
return konvaImage;
};
const updateControlNetLayerImageSource = async (
stage: Konva.Stage,
konvaLayer: Konva.Layer,
reduxLayer: ControlAdapterLayer
) => {
if (reduxLayer.imageName) {
const imageName = reduxLayer.imageName;
const req = getStore().dispatch(imagesApi.endpoints.getImageDTO.initiate(reduxLayer.imageName));
const imageDTO = await req.unwrap();
req.unsubscribe();
const image = new Image();
const imageId = getControlNetLayerImageId(reduxLayer.id, imageName);
image.onload = () => {
// Find the existing image or create a new one - must find using the name, bc the id may have just changed
const konvaImage =
konvaLayer.findOne<Konva.Image>(`.${CONTROLNET_LAYER_IMAGE_NAME}`) ??
createControlNetLayerImage(konvaLayer, image);
// Update the image's attributes
konvaImage.setAttrs({
id: imageId,
image,
});
updateControlNetLayerImageAttrs(stage, konvaImage, reduxLayer);
// Must cache after this to apply the filters
konvaImage.cache();
image.id = imageId;
};
image.src = imageDTO.image_url;
} else {
konvaLayer.findOne(`.${CONTROLNET_LAYER_IMAGE_NAME}`)?.destroy();
}
};
const updateControlNetLayerImageAttrs = (
stage: Konva.Stage,
konvaImage: Konva.Image,
reduxLayer: ControlAdapterLayer
) => {
konvaImage.setAttrs({
opacity: reduxLayer.opacity,
scaleX: 1,
scaleY: 1,
width: stage.width() / stage.scaleX(),
height: stage.height() / stage.scaleY(),
visible: reduxLayer.isEnabled,
});
konvaImage.cache();
};
const renderControlNetLayer = (stage: Konva.Stage, reduxLayer: ControlAdapterLayer) => {
const konvaLayer = stage.findOne<Konva.Layer>(`#${reduxLayer.id}`) ?? createControlNetLayer(stage, reduxLayer);
const konvaImage = konvaLayer.findOne<Konva.Image>(`.${CONTROLNET_LAYER_IMAGE_NAME}`);
const canvasImageSource = konvaImage?.image();
let imageSourceNeedsUpdate = false;
if (canvasImageSource instanceof HTMLImageElement) {
if (
reduxLayer.imageName &&
canvasImageSource.id !== getControlNetLayerImageId(reduxLayer.id, reduxLayer.imageName)
) {
imageSourceNeedsUpdate = true;
} else if (!reduxLayer.imageName) {
imageSourceNeedsUpdate = true;
}
} else if (!canvasImageSource) {
imageSourceNeedsUpdate = true;
}
if (imageSourceNeedsUpdate) {
updateControlNetLayerImageSource(stage, konvaLayer, reduxLayer);
} else if (konvaImage) {
updateControlNetLayerImageAttrs(stage, konvaImage, reduxLayer);
}
};
/** /**
* Renders the layers on the stage. * Renders the layers on the stage.
* @param stage The konva stage to render on. * @param stage The konva stage to render on.
@ -416,10 +521,9 @@ const renderLayers = (
tool: Tool, tool: Tool,
onLayerPosChanged?: (layerId: string, x: number, y: number) => void onLayerPosChanged?: (layerId: string, x: number, y: number) => void
) => { ) => {
const reduxLayerIds = reduxLayers.map(mapId); const reduxLayerIds = reduxLayers.filter(isRenderableLayer).map(mapId);
// Remove un-rendered layers // Remove un-rendered layers
for (const konvaLayer of stage.find<Konva.Layer>(`.${MASKED_GUIDANCE_LAYER_NAME}`)) { for (const konvaLayer of stage.find<Konva.Layer>(selectRenderableLayers)) {
if (!reduxLayerIds.includes(konvaLayer.id())) { if (!reduxLayerIds.includes(konvaLayer.id())) {
konvaLayer.destroy(); konvaLayer.destroy();
} }
@ -427,7 +531,10 @@ const renderLayers = (
for (const reduxLayer of reduxLayers) { for (const reduxLayer of reduxLayers) {
if (isMaskedGuidanceLayer(reduxLayer)) { if (isMaskedGuidanceLayer(reduxLayer)) {
renderVectorMaskLayer(stage, reduxLayer, globalMaskLayerOpacity, tool, onLayerPosChanged); renderMaskedGuidanceLayer(stage, reduxLayer, globalMaskLayerOpacity, tool, onLayerPosChanged);
}
if (isControlAdapterLayer(reduxLayer)) {
renderControlNetLayer(stage, reduxLayer);
} }
} }
}; };
@ -620,3 +727,20 @@ export const debouncedRenderers = {
renderBackground: debounce(renderBackground, DEBOUNCE_MS), renderBackground: debounce(renderBackground, DEBOUNCE_MS),
arrangeLayers: debounce(arrangeLayers, DEBOUNCE_MS), arrangeLayers: debounce(arrangeLayers, DEBOUNCE_MS),
}; };
/**
* Calculates the lightness (HSL) of a given pixel and sets the alpha channel to that value.
* This is useful for edge maps and other masks, to make the black areas transparent.
* @param imageData The image data to apply the filter to
*/
const LightnessToAlphaFilter = (imageData: ImageData) => {
const len = imageData.data.length / 4;
for (let i = 0; i < len; i++) {
const r = imageData.data[i * 4 + 0] as number;
const g = imageData.data[i * 4 + 1] as number;
const b = imageData.data[i * 4 + 2] as number;
const cMin = Math.min(r, g, b);
const cMax = Math.max(r, g, b);
imageData.data[i * 4 + 3] = (cMin + cMax) / 2;
}
};

View File

@ -13,7 +13,10 @@ import {
selectValidIPAdapters, selectValidIPAdapters,
selectValidT2IAdapters, selectValidT2IAdapters,
} from 'features/controlAdapters/store/controlAdaptersSlice'; } from 'features/controlAdapters/store/controlAdaptersSlice';
import { isMaskedGuidanceLayer, selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice'; import {
selectAllControlAdapterIds,
selectRegionalPromptsSlice,
} from 'features/regionalPrompts/store/regionalPromptsSlice';
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle'; import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { Fragment, memo } from 'react'; import { Fragment, memo } from 'react';
@ -26,10 +29,10 @@ const selector = createMemoizedSelector(
const badges: string[] = []; const badges: string[] = [];
let isError = false; let isError = false;
const regionalControlAdapterIds = selectAllControlAdapterIds(regionalPrompts.present);
const enabledNonRegionalIPAdapterCount = selectAllIPAdapters(controlAdapters) const enabledNonRegionalIPAdapterCount = selectAllIPAdapters(controlAdapters)
.filter( .filter((ca) => !regionalControlAdapterIds.includes(ca.id))
(ca) => !regionalPrompts.present.layers.filter(isMaskedGuidanceLayer).some((l) => l.ipAdapterIds.includes(ca.id))
)
.filter((ca) => ca.isEnabled).length; .filter((ca) => ca.isEnabled).length;
const validIPAdapterCount = selectValidIPAdapters(controlAdapters).length; const validIPAdapterCount = selectValidIPAdapters(controlAdapters).length;
@ -40,7 +43,9 @@ const selector = createMemoizedSelector(
isError = true; isError = true;
} }
const enabledControlNetCount = selectAllControlNets(controlAdapters).filter((ca) => ca.isEnabled).length; const enabledControlNetCount = selectAllControlNets(controlAdapters)
.filter((ca) => !regionalControlAdapterIds.includes(ca.id))
.filter((ca) => ca.isEnabled).length;
const validControlNetCount = selectValidControlNets(controlAdapters).length; const validControlNetCount = selectValidControlNets(controlAdapters).length;
if (enabledControlNetCount > 0) { if (enabledControlNetCount > 0) {
badges.push(`${enabledControlNetCount} ControlNet`); badges.push(`${enabledControlNetCount} ControlNet`);
@ -49,7 +54,9 @@ const selector = createMemoizedSelector(
isError = true; isError = true;
} }
const enabledT2IAdapterCount = selectAllT2IAdapters(controlAdapters).filter((ca) => ca.isEnabled).length; const enabledT2IAdapterCount = selectAllT2IAdapters(controlAdapters)
.filter((ca) => !regionalControlAdapterIds.includes(ca.id))
.filter((ca) => ca.isEnabled).length;
const validT2IAdapterCount = selectValidT2IAdapters(controlAdapters).length; const validT2IAdapterCount = selectValidT2IAdapters(controlAdapters).length;
if (enabledT2IAdapterCount > 0) { if (enabledT2IAdapterCount > 0) {
badges.push(`${enabledT2IAdapterCount} T2I`); badges.push(`${enabledT2IAdapterCount} T2I`);
@ -59,7 +66,7 @@ const selector = createMemoizedSelector(
} }
const controlAdapterIds = selectControlAdapterIds(controlAdapters).filter( const controlAdapterIds = selectControlAdapterIds(controlAdapters).filter(
(id) => !regionalPrompts.present.layers.filter(isMaskedGuidanceLayer).some((l) => l.ipAdapterIds.includes(id)) (id) => !regionalControlAdapterIds.includes(id)
); );
return { return {

View File

@ -48,7 +48,7 @@ const ParametersPanel = () => {
<Flex gap={2} flexDirection="column" h="full" w="full"> <Flex gap={2} flexDirection="column" h="full" w="full">
<ImageSettingsAccordion /> <ImageSettingsAccordion />
<GenerationSettingsAccordion /> <GenerationSettingsAccordion />
<ControlSettingsAccordion /> {activeTabName !== 'txt2img' && <ControlSettingsAccordion />}
{activeTabName === 'unifiedCanvas' && <CompositingSettingsAccordion />} {activeTabName === 'unifiedCanvas' && <CompositingSettingsAccordion />}
{isSDXL && <RefinerSettingsAccordion />} {isSDXL && <RefinerSettingsAccordion />}
<AdvancedSettingsAccordion /> <AdvancedSettingsAccordion />