WIP control adapters in regional

This commit is contained in:
psychedelicious 2024-04-29 23:31:33 +10:00 committed by Kent Keirsey
parent e822897b1c
commit ded8267505
45 changed files with 1689 additions and 337 deletions

View File

@ -156,6 +156,7 @@
"balanced": "Balanced",
"base": "Base",
"beginEndStepPercent": "Begin / End Step Percentage",
"beginEndStepPercentShort": "Begin/End %",
"bgth": "bg_th",
"canny": "Canny",
"cannyDescription": "Canny edge detection",
@ -1531,6 +1532,10 @@
"maskPreviewColor": "Mask Preview Color",
"addPositivePrompt": "Add $t(common.positivePrompt)",
"addNegativePrompt": "Add $t(common.negativePrompt)",
"addIPAdapter": "Add $t(common.ipAdapter)"
"addIPAdapter": "Add $t(common.ipAdapter)",
"maskedGuidance": "Masked Guidance",
"maskedGuidanceLayer": "$t(regionalPrompts.maskedGuidance) $t(unifiedCanvas.layer)",
"controlNetLayer": "$t(common.controlNet) $t(unifiedCanvas.layer)",
"ipAdapterLayer": "$t(common.ipAdapter) $t(unifiedCanvas.layer)"
}
}

View File

@ -35,6 +35,7 @@ import { addInitialImageSelectedListener } from 'app/store/middleware/listenerMi
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged';
import { addRegionalControlToControlAdapterBridge } from 'app/store/middleware/listenerMiddleware/listeners/regionalControlToControlAdapterBridge';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress';
@ -157,3 +158,5 @@ addUpscaleRequestedListener(startAppListening);
addDynamicPromptsListener(startAppListening);
addSetDefaultSettingsListener(startAppListening);
addRegionalControlToControlAdapterBridge(startAppListening);

View File

@ -48,12 +48,10 @@ export const addCanvasImageToControlNetListener = (startAppListening: AppStartLi
})
).unwrap();
const { image_name } = imageDTO;
dispatch(
controlAdapterImageChanged({
id,
controlImage: image_name,
controlImage: imageDTO,
})
);
},

View File

@ -58,12 +58,10 @@ export const addCanvasMaskToControlNetListener = (startAppListening: AppStartLis
})
).unwrap();
const { image_name } = imageDTO;
dispatch(
controlAdapterImageChanged({
id,
controlImage: image_name,
controlImage: imageDTO,
})
);
},

View File

@ -91,7 +91,7 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL
dispatch(
controlAdapterProcessedImageChanged({
id,
processedControlImage: processedControlImage.image_name,
processedControlImage,
})
);
}

View File

@ -71,7 +71,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
dispatch(
controlAdapterImageChanged({
id,
controlImage: activeData.payload.imageDTO.image_name,
controlImage: activeData.payload.imageDTO,
})
);
dispatch(

View File

@ -96,7 +96,7 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
dispatch(
controlAdapterImageChanged({
id,
controlImage: imageDTO.image_name,
controlImage: imageDTO,
})
);
dispatch(

View File

@ -0,0 +1,93 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { controlAdapterAdded, controlAdapterRemoved } from 'features/controlAdapters/store/controlAdaptersSlice';
import {
controlAdapterLayerAdded,
ipAdapterLayerAdded,
layerDeleted,
maskedGuidanceLayerAdded,
maskLayerIPAdapterAdded,
maskLayerIPAdapterDeleted,
} from 'features/regionalPrompts/store/regionalPromptsSlice';
import type { Layer } from 'features/regionalPrompts/store/types';
import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid';
export const guidanceLayerAdded = createAction<Layer['type']>('regionalPrompts/guidanceLayerAdded');
export const guidanceLayerDeleted = createAction<string>('regionalPrompts/guidanceLayerDeleted');
export const allLayersDeleted = createAction('regionalPrompts/allLayersDeleted');
export const guidanceLayerIPAdapterAdded = createAction<string>('regionalPrompts/guidanceLayerIPAdapterAdded');
export const guidanceLayerIPAdapterDeleted = createAction<{ layerId: string; ipAdapterId: string }>(
'regionalPrompts/guidanceLayerIPAdapterDeleted'
);
export const addRegionalControlToControlAdapterBridge = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: guidanceLayerAdded,
effect: (action, { dispatch }) => {
const type = action.payload;
const layerId = uuidv4();
if (type === 'ip_adapter_layer') {
const ipAdapterId = uuidv4();
dispatch(controlAdapterAdded({ type: 'ip_adapter', overrides: { id: ipAdapterId } }));
dispatch(ipAdapterLayerAdded({ layerId, ipAdapterId }));
} else if (type === 'control_adapter_layer') {
const controlNetId = uuidv4();
dispatch(controlAdapterAdded({ type: 'controlnet', overrides: { id: controlNetId } }));
dispatch(controlAdapterLayerAdded({ layerId, controlNetId }));
} else if (type === 'masked_guidance_layer') {
dispatch(maskedGuidanceLayerAdded({ layerId }));
}
},
});
startAppListening({
actionCreator: guidanceLayerDeleted,
effect: (action, { getState, dispatch }) => {
const layerId = action.payload;
const state = getState();
const layer = state.regionalPrompts.present.layers.find((l) => l.id === layerId);
assert(layer, `Layer ${layerId} not found`);
if (layer.type === 'ip_adapter_layer') {
dispatch(controlAdapterRemoved({ id: layer.ipAdapterId }));
} else if (layer.type === 'control_adapter_layer') {
dispatch(controlAdapterRemoved({ id: layer.controlNetId }));
} else if (layer.type === 'masked_guidance_layer') {
for (const ipAdapterId of layer.ipAdapterIds) {
dispatch(controlAdapterRemoved({ id: ipAdapterId }));
}
}
dispatch(layerDeleted(layerId));
},
});
startAppListening({
actionCreator: allLayersDeleted,
effect: (action, { dispatch, getOriginalState }) => {
const state = getOriginalState();
for (const layer of state.regionalPrompts.present.layers) {
dispatch(guidanceLayerDeleted(layer.id));
}
},
});
startAppListening({
actionCreator: guidanceLayerIPAdapterAdded,
effect: (action, { dispatch }) => {
const layerId = action.payload;
const ipAdapterId = uuidv4();
dispatch(controlAdapterAdded({ type: 'ip_adapter', overrides: { id: ipAdapterId } }));
dispatch(maskLayerIPAdapterAdded({ layerId, ipAdapterId }));
},
});
startAppListening({
actionCreator: guidanceLayerIPAdapterDeleted,
effect: (action, { dispatch }) => {
const { layerId, ipAdapterId } = action.payload;
dispatch(controlAdapterRemoved({ id: ipAdapterId }));
dispatch(maskLayerIPAdapterDeleted({ layerId, ipAdapterId }));
},
});
};

View File

@ -6,9 +6,8 @@ import { deepClone } from 'common/util/deepClone';
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { maskLayerIPAdapterAdded } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { merge, uniq } from 'lodash-es';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import type { ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { socketInvocationError } from 'services/events/actions';
import { v4 as uuidv4 } from 'uuid';
@ -135,23 +134,46 @@ export const controlAdaptersSlice = createSlice({
const { id, isEnabled } = action.payload;
caAdapter.updateOne(state, { id, changes: { isEnabled } });
},
controlAdapterImageChanged: (
state,
action: PayloadAction<{
id: string;
controlImage: string | null;
}>
) => {
controlAdapterImageChanged: (state, action: PayloadAction<{ id: string; controlImage: ImageDTO | null }>) => {
const { id, controlImage } = action.payload;
const ca = selectControlAdapterById(state, id);
if (!ca) {
return;
}
caAdapter.updateOne(state, {
id,
changes: { controlImage, processedControlImage: null },
});
if (isControlNetOrT2IAdapter(ca)) {
if (controlImage) {
const { image_name, width, height } = controlImage;
const processorNode = deepClone(ca.processorNode);
const minDim = Math.min(controlImage.width, controlImage.height);
if ('detect_resolution' in processorNode) {
processorNode.detect_resolution = minDim;
}
if ('image_resolution' in processorNode) {
processorNode.image_resolution = minDim;
}
if ('resolution' in processorNode) {
processorNode.resolution = minDim;
}
caAdapter.updateOne(state, {
id,
changes: {
processorNode,
controlImage: image_name,
controlImageDimensions: { width, height },
processedControlImage: null,
},
});
} else {
caAdapter.updateOne(state, {
id,
changes: { controlImage: null, controlImageDimensions: null, processedControlImage: null },
});
}
} else {
// ip adapter
caAdapter.updateOne(state, { id, changes: { controlImage: controlImage?.image_name ?? null } });
}
if (controlImage !== null && isControlNetOrT2IAdapter(ca) && ca.processorType !== 'none') {
state.pendingControlImages.push(id);
@ -161,7 +183,7 @@ export const controlAdaptersSlice = createSlice({
state,
action: PayloadAction<{
id: string;
processedControlImage: string | null;
processedControlImage: ImageDTO | null;
}>
) => {
const { id, processedControlImage } = action.payload;
@ -174,12 +196,24 @@ export const controlAdaptersSlice = createSlice({
return;
}
caAdapter.updateOne(state, {
id,
changes: {
processedControlImage,
},
});
if (processedControlImage) {
const { image_name, width, height } = processedControlImage;
caAdapter.updateOne(state, {
id,
changes: {
processedControlImage: image_name,
processedControlImageDimensions: { width, height },
},
});
} else {
caAdapter.updateOne(state, {
id,
changes: {
processedControlImage: null,
processedControlImageDimensions: null,
},
});
}
state.pendingControlImages = state.pendingControlImages.filter((pendingId) => pendingId !== id);
},
@ -222,9 +256,22 @@ export const controlAdaptersSlice = createSlice({
}
const processor = buildControlAdapterProcessor(modelConfig);
update.changes.processorType = processor.processorType;
update.changes.processorNode = processor.processorNode;
if (processor.processorType !== cn.processorNode.type) {
update.changes.processorType = processor.processorType;
update.changes.processorNode = processor.processorNode;
if (cn.controlImageDimensions) {
const minDim = Math.min(cn.controlImageDimensions.width, cn.controlImageDimensions.height);
if ('detect_resolution' in update.changes.processorNode) {
update.changes.processorNode.detect_resolution = minDim;
}
if ('image_resolution' in update.changes.processorNode) {
update.changes.processorNode.image_resolution = minDim;
}
if ('resolution' in update.changes.processorNode) {
update.changes.processorNode.resolution = minDim;
}
}
}
caAdapter.updateOne(state, update);
},
controlAdapterWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => {
@ -341,8 +388,23 @@ export const controlAdaptersSlice = createSlice({
if (update.changes.shouldAutoConfig && modelConfig) {
const processor = buildControlAdapterProcessor(modelConfig);
update.changes.processorType = processor.processorType;
update.changes.processorNode = processor.processorNode;
if (processor.processorType !== cn.processorNode.type) {
update.changes.processorType = processor.processorType;
update.changes.processorNode = processor.processorNode;
// Copy image resolution settings, urgh
if (cn.controlImageDimensions) {
const minDim = Math.min(cn.controlImageDimensions.width, cn.controlImageDimensions.height);
if ('detect_resolution' in update.changes.processorNode) {
update.changes.processorNode.detect_resolution = minDim;
}
if ('image_resolution' in update.changes.processorNode) {
update.changes.processorNode.image_resolution = minDim;
}
if ('resolution' in update.changes.processorNode) {
update.changes.processorNode.resolution = minDim;
}
}
}
}
caAdapter.updateOne(state, update);
@ -383,10 +445,6 @@ export const controlAdaptersSlice = createSlice({
builder.addCase(socketInvocationError, (state) => {
state.pendingControlImages = [];
});
builder.addCase(maskLayerIPAdapterAdded, (state, action) => {
caAdapter.addOne(state, buildControlAdapter(action.meta.uuid, 'ip_adapter'));
});
},
});

View File

@ -225,7 +225,9 @@ export type ControlNetConfig = {
controlMode: ControlMode;
resizeMode: ResizeMode;
controlImage: string | null;
controlImageDimensions: { width: number; height: number } | null;
processedControlImage: string | null;
processedControlImageDimensions: { width: number; height: number } | null;
processorType: ControlAdapterProcessorType;
processorNode: RequiredControlAdapterProcessorNode;
shouldAutoConfig: boolean;
@ -241,7 +243,9 @@ export type T2IAdapterConfig = {
endStepPct: number;
resizeMode: ResizeMode;
controlImage: string | null;
controlImageDimensions: { width: number; height: number } | null;
processedControlImage: string | null;
processedControlImageDimensions: { width: number; height: number } | null;
processorType: ControlAdapterProcessorType;
processorNode: RequiredControlAdapterProcessorNode;
shouldAutoConfig: boolean;

View File

@ -20,7 +20,9 @@ export const initialControlNet: Omit<ControlNetConfig, 'id'> = {
controlMode: 'balanced',
resizeMode: 'just_resize',
controlImage: null,
controlImageDimensions: null,
processedControlImage: null,
processedControlImageDimensions: null,
processorType: 'canny_image_processor',
processorNode: CONTROLNET_PROCESSORS.canny_image_processor.buildDefaults() as RequiredCannyImageProcessorInvocation,
shouldAutoConfig: true,
@ -35,7 +37,9 @@ export const initialT2IAdapter: Omit<T2IAdapterConfig, 'id'> = {
endStepPct: 1,
resizeMode: 'just_resize',
controlImage: null,
controlImageDimensions: null,
processedControlImage: null,
processedControlImageDimensions: null,
processorType: 'canny_image_processor',
processorNode: CONTROLNET_PROCESSORS.canny_image_processor.buildDefaults() as RequiredCannyImageProcessorInvocation,
shouldAutoConfig: true,

View File

@ -286,7 +286,9 @@ const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (meta
controlMode: control_mode ?? initialControlNet.controlMode,
resizeMode: resize_mode ?? initialControlNet.resizeMode,
controlImage: image?.image_name ?? null,
controlImageDimensions: null,
processedControlImage: processedImage?.image_name ?? null,
processedControlImageDimensions: null,
processorType,
processorNode,
shouldAutoConfig: true,
@ -350,9 +352,11 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (meta
endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct,
resizeMode: resize_mode ?? initialT2IAdapter.resizeMode,
controlImage: image?.image_name ?? null,
controlImageDimensions: null,
processedControlImage: processedImage?.image_name ?? null,
processorType,
processedControlImageDimensions: null,
processorNode,
processorType,
shouldAutoConfig: true,
id: uuidv4(),
};

View File

@ -2,6 +2,9 @@ import type { RootState } from 'app/store/store';
import { selectValidControlNets } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlAdapterProcessorType, ControlNetConfig } from 'features/controlAdapters/store/types';
import type { ImageField } from 'features/nodes/types/common';
import { isControlAdapterLayer } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { differenceWith, intersectionWith } from 'lodash-es';
import type {
CollectInvocation,
ControlNetInvocation,
@ -14,11 +17,8 @@ import { assert } from 'tsafe';
import { CONTROL_NET_COLLECT } from './constants';
import { upsertMetadata } from './metadata';
export const addControlNetToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
const getControlNets = (state: RootState) => {
// Start with the valid controlnets
const validControlNets = selectValidControlNets(state.controlAdapters).filter(
({ model, processedControlImage, processorType, controlImage, isEnabled }) => {
const hasModel = Boolean(model);
@ -29,9 +29,33 @@ export const addControlNetToLinearGraph = async (
}
);
// txt2img tab has special handling - it uses layers exclusively, while the other tabs use the older control adapters
// accordion. We need to filter the list of valid T2I adapters according to the tab.
const activeTabName = activeTabNameSelector(state);
// Collect all ControlNet ids for ControlNet layers
const layerControlNetIds = state.regionalPrompts.present.layers
.filter(isControlAdapterLayer)
.map((l) => l.controlNetId);
if (activeTabName === 'txt2img') {
// Add only the cnets that are used in control layers
return intersectionWith(validControlNets, layerControlNetIds, (a, b) => a.id === b);
} else {
// Else, we want to exclude the cnets that are used in control layers
return differenceWith(validControlNets, layerControlNetIds, (a, b) => a.id === b);
}
};
export const addControlNetToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
const controlNets = getControlNets(state);
const controlNetMetadata: CoreMetadataInvocation['controlnets'] = [];
if (validControlNets.length) {
if (controlNets.length) {
// Even though denoise_latents' control input is collection or scalar, keep it simple and always use a collect
const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT,
@ -47,7 +71,7 @@ export const addControlNetToLinearGraph = async (
},
});
for (const controlNet of validControlNets) {
for (const controlNet of controlNets) {
if (!controlNet.model) {
return;
}

View File

@ -2,8 +2,9 @@ import type { RootState } from 'app/store/store';
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { IPAdapterConfig } from 'features/controlAdapters/store/types';
import type { ImageField } from 'features/nodes/types/common';
import { isMaskedGuidanceLayer } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { differenceBy } from 'lodash-es';
import { isIPAdapterLayer, isMaskedGuidanceLayer } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { differenceWith, intersectionWith } from 'lodash-es';
import type {
CollectInvocation,
CoreMetadataInvocation,
@ -16,11 +17,8 @@ import { assert } from 'tsafe';
import { IP_ADAPTER_COLLECT } from './constants';
import { upsertMetadata } from './metadata';
export const addIPAdapterToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
const getIPAdapters = (state: RootState) => {
// Start with the valid IP adapters
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(({ model, controlImage, isEnabled }) => {
const hasModel = Boolean(model);
const doesBaseMatch = model?.base === state.generation.model?.base;
@ -28,14 +26,37 @@ export const addIPAdapterToLinearGraph = async (
return isEnabled && hasModel && doesBaseMatch && hasControlImage;
});
const regionalIPAdapterIds = state.regionalPrompts.present.layers
// Masked IP adapters are handled in the graph helper for regional control - skip them here
const maskedIPAdapterIds = state.regionalPrompts.present.layers
.filter(isMaskedGuidanceLayer)
.map((l) => l.ipAdapterIds)
.flat();
const nonMaskedIPAdapters = differenceWith(validIPAdapters, maskedIPAdapterIds, (a, b) => a.id === b);
const nonRegionalIPAdapters = differenceBy(validIPAdapters, regionalIPAdapterIds, 'id');
// txt2img tab has special handling - it uses layers exclusively, while the other tabs use the older control adapters
// accordion. We need to filter the list of valid IP adapters according to the tab.
const activeTabName = activeTabNameSelector(state);
if (nonRegionalIPAdapters.length) {
// Collect all IP Adapter ids for IP adapter layers
const layerIPAdapterIds = state.regionalPrompts.present.layers.filter(isIPAdapterLayer).map((l) => l.ipAdapterId);
if (activeTabName === 'txt2img') {
// If we are on the t2i tab, we only want to add the IP adapters that are used in unmasked IP Adapter layers
return intersectionWith(nonMaskedIPAdapters, layerIPAdapterIds, (a, b) => a.id === b);
} else {
// Else, we want to exclude the IP adapters that are used in IP Adapter layers
return differenceWith(nonMaskedIPAdapters, layerIPAdapterIds, (a, b) => a.id === b);
}
};
export const addIPAdapterToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
const ipAdapters = getIPAdapters(state);
if (ipAdapters.length) {
// Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect
const ipAdapterCollectNode: CollectInvocation = {
id: IP_ADAPTER_COLLECT,
@ -53,7 +74,7 @@ export const addIPAdapterToLinearGraph = async (
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];
for (const ipAdapter of nonRegionalIPAdapters) {
for (const ipAdapter of ipAdapters) {
if (!ipAdapter.model) {
return;
}

View File

@ -31,7 +31,7 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
// TODO: Image masks
.filter(isMaskedGuidanceLayer)
// Only visible layers are rendered on the canvas
.filter((l) => l.isVisible)
.filter((l) => l.isEnabled)
// Only layers with prompts get added to the graph
.filter((l) => {
const hasTextPrompt = Boolean(l.positivePrompt || l.negativePrompt);
@ -39,12 +39,15 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
return hasTextPrompt || hasIPAdapter;
});
// Collect all IP Adapter ids for IP adapter layers
const layerIPAdapterIds = layers.flatMap((l) => l.ipAdapterIds);
const regionalIPAdapters = selectAllIPAdapters(state.controlAdapters).filter(
({ id, model, controlImage, isEnabled }) => {
const hasModel = Boolean(model);
const doesBaseMatch = model?.base === state.generation.model?.base;
const hasControlImage = controlImage;
const isRegional = layers.some((l) => l.ipAdapterIds.includes(id));
const isRegional = layerIPAdapterIds.includes(id);
return isEnabled && hasModel && doesBaseMatch && hasControlImage && isRegional;
}
);

View File

@ -2,6 +2,9 @@ import type { RootState } from 'app/store/store';
import { selectValidT2IAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlAdapterProcessorType, T2IAdapterConfig } from 'features/controlAdapters/store/types';
import type { ImageField } from 'features/nodes/types/common';
import { isControlAdapterLayer } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { differenceWith, intersectionWith } from 'lodash-es';
import type {
CollectInvocation,
CoreMetadataInvocation,
@ -14,11 +17,8 @@ import { assert } from 'tsafe';
import { T2I_ADAPTER_COLLECT } from './constants';
import { upsertMetadata } from './metadata';
export const addT2IAdaptersToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
const getT2IAdapters = (state: RootState) => {
// Start with the valid controlnets
const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters).filter(
({ model, processedControlImage, processorType, controlImage, isEnabled }) => {
const hasModel = Boolean(model);
@ -29,7 +29,32 @@ export const addT2IAdaptersToLinearGraph = async (
}
);
if (validT2IAdapters.length) {
// txt2img tab has special handling - it uses layers exclusively, while the other tabs use the older control adapters
// accordion. We need to filter the list of valid T2I adapters according to the tab.
const activeTabName = activeTabNameSelector(state);
// Collect all ids for control adapter layers
const layerControlAdapterIds = state.regionalPrompts.present.layers
.filter(isControlAdapterLayer)
.map((l) => l.controlNetId);
if (activeTabName === 'txt2img') {
// Add only the T2Is that are used in control layers
return intersectionWith(validT2IAdapters, layerControlAdapterIds, (a, b) => a.id === b);
} else {
// Else, we want to exclude the T2Is that are used in control layers
return differenceWith(validT2IAdapters, layerControlAdapterIds, (a, b) => a.id === b);
}
};
export const addT2IAdaptersToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
const t2iAdapters = getT2IAdapters(state);
if (t2iAdapters.length) {
// Even though denoise_latents' t2i adapter input is collection or scalar, keep it simple and always use a collect
const t2iAdapterCollectNode: CollectInvocation = {
id: T2I_ADAPTER_COLLECT,
@ -47,7 +72,7 @@ export const addT2IAdaptersToLinearGraph = async (
const t2iAdapterMetadata: CoreMetadataInvocation['t2iAdapters'] = [];
for (const t2iAdapter of validT2IAdapters) {
for (const t2iAdapter of t2iAdapters) {
if (!t2iAdapter.model) {
return;
}

View File

@ -1,6 +1,6 @@
import { Button } from '@invoke-ai/ui-library';
import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { layerAdded } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { guidanceLayerAdded } from 'app/store/middleware/listenerMiddleware/listeners/regionalControlToControlAdapterBridge';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
@ -8,14 +8,27 @@ import { PiPlusBold } from 'react-icons/pi';
export const AddLayerButton = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const onClick = useCallback(() => {
dispatch(layerAdded('masked_guidance_layer'));
const addMaskedGuidanceLayer = useCallback(() => {
dispatch(guidanceLayerAdded('masked_guidance_layer'));
}, [dispatch]);
const addControlNetLayer = useCallback(() => {
dispatch(guidanceLayerAdded('control_adapter_layer'));
}, [dispatch]);
const addIPAdapterLayer = useCallback(() => {
dispatch(guidanceLayerAdded('ip_adapter_layer'));
}, [dispatch]);
return (
<Button onClick={onClick} leftIcon={<PiPlusBold />} variant="ghost">
{t('regionalPrompts.addLayer')}
</Button>
<Menu>
<MenuButton as={Button} leftIcon={<PiPlusBold />} variant="ghost">
{t('regionalPrompts.addLayer')}
</MenuButton>
<MenuList>
<MenuItem onClick={addMaskedGuidanceLayer}> {t('regionalPrompts.maskedGuidanceLayer')}</MenuItem>
<MenuItem onClick={addControlNetLayer}> {t('regionalPrompts.controlNetLayer')}</MenuItem>
<MenuItem onClick={addIPAdapterLayer}> {t('regionalPrompts.ipAdapterLayer')}</MenuItem>
</MenuList>
</Menu>
);
});

View File

@ -1,9 +1,9 @@
import { Button, Flex } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { guidanceLayerIPAdapterAdded } from 'app/store/middleware/listenerMiddleware/listeners/regionalControlToControlAdapterBridge';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
isMaskedGuidanceLayer,
maskLayerIPAdapterAdded,
maskLayerNegativePromptChanged,
maskLayerPositivePromptChanged,
selectRegionalPromptsSlice,
@ -39,7 +39,7 @@ export const AddPromptButtons = ({ layerId }: AddPromptButtonProps) => {
dispatch(maskLayerNegativePromptChanged({ layerId, prompt: '' }));
}, [dispatch, layerId]);
const addIPAdapter = useCallback(() => {
dispatch(maskLayerIPAdapterAdded(layerId));
dispatch(guidanceLayerIPAdapterAdded(layerId));
}, [dispatch, layerId]);
return (

View File

@ -23,7 +23,7 @@ export const BrushSize = memo(() => {
const brushSize = useAppSelector((s) => s.regionalPrompts.present.brushSize);
const onChange = useCallback(
(v: number) => {
dispatch(brushSizeChanged(v));
dispatch(brushSizeChanged(Math.round(v)));
},
[dispatch]
);

View File

@ -0,0 +1,62 @@
import { Flex, Spacer } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import ControlAdapterLayerConfig from 'features/regionalPrompts/components/controlAdapterOverrides/ControlAdapterLayerConfig';
import { LayerTitle } from 'features/regionalPrompts/components/LayerTitle';
import { RPLayerDeleteButton } from 'features/regionalPrompts/components/RPLayerDeleteButton';
import { RPLayerVisibilityToggle } from 'features/regionalPrompts/components/RPLayerVisibilityToggle';
import {
isControlAdapterLayer,
layerSelected,
selectRegionalPromptsSlice,
} from 'features/regionalPrompts/store/regionalPromptsSlice';
import { memo, useCallback, useMemo } from 'react';
import { assert } from 'tsafe';
type Props = {
layerId: string;
};
export const ControlAdapterLayerListItem = memo(({ layerId }: Props) => {
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
assert(isControlAdapterLayer(layer), `Layer ${layerId} not found or not a ControlNet layer`);
return {
controlNetId: layer.controlNetId,
isSelected: layerId === regionalPrompts.present.selectedLayerId,
};
}),
[layerId]
);
const { controlNetId, isSelected } = useAppSelector(selector);
const onClickCapture = useCallback(() => {
// Must be capture so that the layer is selected before deleting/resetting/etc
dispatch(layerSelected(layerId));
}, [dispatch, layerId]);
return (
<Flex
gap={2}
onClickCapture={onClickCapture}
bg={isSelected ? 'base.400' : 'base.800'}
ps={2}
borderRadius="base"
pe="1px"
py="1px"
>
<Flex flexDir="column" gap={4} w="full" bg="base.850" p={3} borderRadius="base">
<Flex gap={3} alignItems="center">
<RPLayerVisibilityToggle layerId={layerId} />
<LayerTitle type="control_adapter_layer" />
<Spacer />
<RPLayerDeleteButton layerId={layerId} />
</Flex>
<ControlAdapterLayerConfig id={controlNetId} />
</Flex>
</Flex>
);
});
ControlAdapterLayerListItem.displayName = 'ControlAdapterLayerListItem';

View File

@ -1,6 +1,6 @@
import { Button } from '@invoke-ai/ui-library';
import { allLayersDeleted } from 'app/store/middleware/listenerMiddleware/listeners/regionalControlToControlAdapterBridge';
import { useAppDispatch } from 'app/store/storeHooks';
import { allLayersDeleted } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';

View File

@ -0,0 +1,62 @@
import { Flex, Spacer } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import ControlAdapterLayerConfig from 'features/regionalPrompts/components/controlAdapterOverrides/ControlAdapterLayerConfig';
import { LayerTitle } from 'features/regionalPrompts/components/LayerTitle';
import { RPLayerDeleteButton } from 'features/regionalPrompts/components/RPLayerDeleteButton';
import { RPLayerVisibilityToggle } from 'features/regionalPrompts/components/RPLayerVisibilityToggle';
import {
isIPAdapterLayer,
layerSelected,
selectRegionalPromptsSlice,
} from 'features/regionalPrompts/store/regionalPromptsSlice';
import { memo, useCallback, useMemo } from 'react';
import { assert } from 'tsafe';
type Props = {
layerId: string;
};
export const IPAdapterLayerListItem = memo(({ layerId }: Props) => {
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
assert(isIPAdapterLayer(layer), `Layer ${layerId} not found or not an IP Adapter layer`);
return {
ipAdapterId: layer.ipAdapterId,
isSelected: layerId === regionalPrompts.present.selectedLayerId,
};
}),
[layerId]
);
const { ipAdapterId, isSelected } = useAppSelector(selector);
const onClickCapture = useCallback(() => {
// Must be capture so that the layer is selected before deleting/resetting/etc
dispatch(layerSelected(layerId));
}, [dispatch, layerId]);
return (
<Flex
gap={2}
onClickCapture={onClickCapture}
bg={isSelected ? 'base.400' : 'base.800'}
ps={2}
borderRadius="base"
pe="1px"
py="1px"
>
<Flex flexDir="column" gap={4} w="full" bg="base.850" p={3} borderRadius="base">
<Flex gap={3} alignItems="center">
<RPLayerVisibilityToggle layerId={layerId} />
<LayerTitle type="ip_adapter_layer" />
<Spacer />
<RPLayerDeleteButton layerId={layerId} />
</Flex>
<ControlAdapterLayerConfig id={ipAdapterId} />
</Flex>
</Flex>
);
});
IPAdapterLayerListItem.displayName = 'IPAdapterLayerListItem';

View File

@ -0,0 +1,29 @@
import { Text } from '@invoke-ai/ui-library';
import type { Layer } from 'features/regionalPrompts/store/types';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
type Props = {
type: Layer['type'];
};
export const LayerTitle = memo(({ type }: Props) => {
const { t } = useTranslation();
const title = useMemo(() => {
if (type === 'masked_guidance_layer') {
return t('regionalPrompts.maskedGuidance');
} else if (type === 'control_adapter_layer') {
return t('common.controlNet');
} else if (type === 'ip_adapter_layer') {
return t('common.ipAdapter');
}
}, [t, type]);
return (
<Text size="sm" fontWeight="semibold" pointerEvents="none" color="base.300">
{title}
</Text>
);
});
LayerTitle.displayName = 'LayerTitle';

View File

@ -2,6 +2,7 @@ import { Badge, Flex, Spacer } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { rgbColorToString } from 'features/canvas/util/colorToString';
import { LayerTitle } from 'features/regionalPrompts/components/LayerTitle';
import { RPLayerColorPicker } from 'features/regionalPrompts/components/RPLayerColorPicker';
import { RPLayerDeleteButton } from 'features/regionalPrompts/components/RPLayerDeleteButton';
import { RPLayerIPAdapterList } from 'features/regionalPrompts/components/RPLayerIPAdapterList';
@ -25,7 +26,7 @@ type Props = {
layerId: string;
};
export const RPLayerListItem = memo(({ layerId }: Props) => {
export const MaskedGuidanceLayerListItem = memo(({ layerId }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const selector = useMemo(
@ -59,21 +60,21 @@ export const RPLayerListItem = memo(({ layerId }: Props) => {
borderRadius="base"
pe="1px"
py="1px"
cursor="pointer"
>
<Flex flexDir="column" gap={2} w="full" bg="base.850" p={2} borderRadius="base">
<Flex flexDir="column" w="full" bg="base.850" p={3} gap={3} borderRadius="base">
<Flex gap={3} alignItems="center">
<RPLayerVisibilityToggle layerId={layerId} />
<RPLayerColorPicker layerId={layerId} />
<LayerTitle type="masked_guidance_layer" />
<Spacer />
{autoNegative === 'invert' && (
<Badge color="base.300" bg="transparent" borderWidth={1}>
{t('regionalPrompts.autoNegative')}
</Badge>
)}
<RPLayerDeleteButton layerId={layerId} />
<RPLayerColorPicker layerId={layerId} />
<RPLayerSettingsPopover layerId={layerId} />
<RPLayerMenu layerId={layerId} />
<RPLayerDeleteButton layerId={layerId} />
</Flex>
{!hasPositivePrompt && !hasNegativePrompt && !hasIPAdapters && <AddPromptButtons layerId={layerId} />}
{hasPositivePrompt && <RPLayerPositivePrompt layerId={layerId} />}
@ -84,4 +85,4 @@ export const RPLayerListItem = memo(({ layerId }: Props) => {
);
});
RPLayerListItem.displayName = 'RPLayerListItem';
MaskedGuidanceLayerListItem.displayName = 'MaskedGuidanceLayerListItem';

View File

@ -29,7 +29,7 @@ const useAutoNegative = (layerId: string) => {
return autoNegative;
};
export const RPLayerAutoNegativeCheckbox = memo(({ layerId }: Props) => {
export const MaskedGuidanceLayerAutoNegativeCheckbox = memo(({ layerId }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const autoNegative = useAutoNegative(layerId);
@ -48,4 +48,4 @@ export const RPLayerAutoNegativeCheckbox = memo(({ layerId }: Props) => {
);
});
RPLayerAutoNegativeCheckbox.displayName = 'RPLayerAutoNegativeCheckbox';
MaskedGuidanceLayerAutoNegativeCheckbox.displayName = 'MaskedGuidanceLayerAutoNegativeCheckbox';

View File

@ -1,9 +1,11 @@
import { Flex } from '@invoke-ai/ui-library';
import { Divider, Flex, IconButton, Spacer, Text } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import ControlAdapterConfig from 'features/controlAdapters/components/ControlAdapterConfig';
import { guidanceLayerIPAdapterDeleted } from 'app/store/middleware/listenerMiddleware/listeners/regionalControlToControlAdapterBridge';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import ControlAdapterLayerConfig from 'features/regionalPrompts/components/controlAdapterOverrides/ControlAdapterLayerConfig';
import { isMaskedGuidanceLayer, selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { memo, useMemo } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { PiTrashSimpleBold } from 'react-icons/pi';
import { assert } from 'tsafe';
type Props = {
@ -22,13 +24,55 @@ export const RPLayerIPAdapterList = memo(({ layerId }: Props) => {
);
const ipAdapterIds = useAppSelector(selectIPAdapterIds);
if (ipAdapterIds.length === 0) {
return null;
}
return (
<Flex w="full" flexDir="column" gap={2}>
<>
{ipAdapterIds.map((id, index) => (
<ControlAdapterConfig key={id} id={id} number={index + 1} />
<Flex flexDir="column" key={id}>
<Flex pb={3}>
<Divider />
</Flex>
<IPAdapterListItem layerId={layerId} ipAdapterId={id} ipAdapterNumber={index + 1} />
</Flex>
))}
</Flex>
</>
);
});
RPLayerIPAdapterList.displayName = 'RPLayerIPAdapterList';
type IPAdapterListItemProps = {
layerId: string;
ipAdapterId: string;
ipAdapterNumber: number;
};
const IPAdapterListItem = memo(({ layerId, ipAdapterId, ipAdapterNumber }: IPAdapterListItemProps) => {
const dispatch = useAppDispatch();
const onDeleteIPAdapter = useCallback(() => {
dispatch(guidanceLayerIPAdapterDeleted({ layerId, ipAdapterId }));
}, [dispatch, ipAdapterId, layerId]);
return (
<Flex flexDir="column" gap={3}>
<Flex alignItems="center" gap={3}>
<Text fontWeight="semibold" color="base.400">{`IP Adapter ${ipAdapterNumber}`}</Text>
<Spacer />
<IconButton
size="sm"
icon={<PiTrashSimpleBold />}
aria-label="Delete IP Adapter"
onClick={onDeleteIPAdapter}
variant="ghost"
colorScheme="error"
/>
</Flex>
<ControlAdapterLayerConfig id={ipAdapterId} />
</Flex>
);
});
IPAdapterListItem.displayName = 'IPAdapterListItem';

View File

@ -1,5 +1,6 @@
import { IconButton, Menu, MenuButton, MenuDivider, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { guidanceLayerIPAdapterAdded } from 'app/store/middleware/listenerMiddleware/listeners/regionalControlToControlAdapterBridge';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
isMaskedGuidanceLayer,
@ -9,7 +10,6 @@ import {
layerMovedToBack,
layerMovedToFront,
layerReset,
maskLayerIPAdapterAdded,
maskLayerNegativePromptChanged,
maskLayerPositivePromptChanged,
selectRegionalPromptsSlice,
@ -59,7 +59,7 @@ export const RPLayerMenu = memo(({ layerId }: Props) => {
dispatch(maskLayerNegativePromptChanged({ layerId, prompt: '' }));
}, [dispatch, layerId]);
const addIPAdapter = useCallback(() => {
dispatch(maskLayerIPAdapterAdded(layerId));
dispatch(guidanceLayerIPAdapterAdded(layerId));
}, [dispatch, layerId]);
const moveForward = useCallback(() => {
dispatch(layerMovedForward(layerId));

View File

@ -9,7 +9,7 @@ import {
PopoverContent,
PopoverTrigger,
} from '@invoke-ai/ui-library';
import { RPLayerAutoNegativeCheckbox } from 'features/regionalPrompts/components/RPLayerAutoNegativeCheckbox';
import { MaskedGuidanceLayerAutoNegativeCheckbox } from 'features/regionalPrompts/components/RPLayerAutoNegativeCheckbox';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiGearSixBold } from 'react-icons/pi';
@ -41,7 +41,7 @@ const RPLayerSettingsPopover = ({ layerId }: Props) => {
<PopoverBody>
<Flex direction="column" gap={2}>
<FormControlGroup formLabelProps={formLabelProps}>
<RPLayerAutoNegativeCheckbox layerId={layerId} />
<MaskedGuidanceLayerAutoNegativeCheckbox layerId={layerId} />
</FormControlGroup>
</Flex>
</PopoverBody>

View File

@ -4,20 +4,43 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { AddLayerButton } from 'features/regionalPrompts/components/AddLayerButton';
import { ControlAdapterLayerListItem } from 'features/regionalPrompts/components/ControlAdapterLayerListItem';
import { DeleteAllLayersButton } from 'features/regionalPrompts/components/DeleteAllLayersButton';
import { RPLayerListItem } from 'features/regionalPrompts/components/RPLayerListItem';
import { isMaskedGuidanceLayer, selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { IPAdapterLayerListItem } from 'features/regionalPrompts/components/IPAdapterLayerListItem';
import { MaskedGuidanceLayerListItem } from 'features/regionalPrompts/components/MaskedGuidanceLayerListItem';
import {
isControlAdapterLayer,
isIPAdapterLayer,
isMaskedGuidanceLayer,
selectRegionalPromptsSlice,
} from 'features/regionalPrompts/store/regionalPromptsSlice';
import { memo } from 'react';
const selectRPLayerIdsReversed = createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) =>
const selectMaskedGuidanceLayerIds = createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) =>
regionalPrompts.present.layers
.filter(isMaskedGuidanceLayer)
.map((l) => l.id)
.reverse()
);
const selectControlNetLayerIds = createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) =>
regionalPrompts.present.layers
.filter(isControlAdapterLayer)
.map((l) => l.id)
.reverse()
);
const selectIPAdapterLayerIds = createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) =>
regionalPrompts.present.layers
.filter(isIPAdapterLayer)
.map((l) => l.id)
.reverse()
);
export const RegionalPromptsPanelContent = memo(() => {
const rpLayerIdsReversed = useAppSelector(selectRPLayerIdsReversed);
const maskedGuidanceLayerIds = useAppSelector(selectMaskedGuidanceLayerIds);
const controlNetLayerIds = useAppSelector(selectControlNetLayerIds);
const ipAdapterLayerIds = useAppSelector(selectIPAdapterLayerIds);
return (
<Flex flexDir="column" gap={4} w="full" h="full">
<Flex justifyContent="space-around">
@ -26,8 +49,14 @@ export const RegionalPromptsPanelContent = memo(() => {
</Flex>
<ScrollableContent>
<Flex flexDir="column" gap={4}>
{rpLayerIdsReversed.map((id) => (
<RPLayerListItem key={id} layerId={id} />
{maskedGuidanceLayerIds.map((id) => (
<MaskedGuidanceLayerListItem key={id} layerId={id} />
))}
{controlNetLayerIds.map((id) => (
<ControlAdapterLayerListItem key={id} layerId={id} />
))}
{ipAdapterLayerIds.map((id) => (
<IPAdapterLayerListItem key={id} layerId={id} />
))}
</Flex>
</ScrollableContent>

View File

@ -18,9 +18,8 @@ import {
import { debouncedRenderers, renderers as normalRenderers } from 'features/regionalPrompts/util/renderers';
import Konva from 'konva';
import type { IRect } from 'konva/lib/types';
import type { MutableRefObject } from 'react';
import { memo, useCallback, useLayoutEffect, useMemo, useRef, useState } from 'react';
import { assert } from 'tsafe';
import { memo, useCallback, useLayoutEffect, useMemo, useState } from 'react';
import { v4 as uuidv4 } from 'uuid';
// This will log warnings when layers > 5 - maybe use `import.meta.env.MODE === 'development'` instead?
Konva.showWarnings = false;
@ -28,16 +27,14 @@ Konva.showWarnings = false;
const log = logger('regionalPrompts');
const selectSelectedLayerColor = createMemoizedSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
const layer = regionalPrompts.present.layers.find((l) => l.id === regionalPrompts.present.selectedLayerId);
if (!layer) {
return null;
}
assert(isMaskedGuidanceLayer(layer), `Layer ${regionalPrompts.present.selectedLayerId} is not an RP layer`);
return layer.previewColor;
const layer = regionalPrompts.present.layers
.filter(isMaskedGuidanceLayer)
.find((l) => l.id === regionalPrompts.present.selectedLayerId);
return layer?.previewColor ?? null;
});
const useStageRenderer = (
stageRef: MutableRefObject<Konva.Stage>,
stage: Konva.Stage,
container: HTMLDivElement | null,
wrapper: HTMLDivElement | null,
asPreview: boolean
@ -79,25 +76,24 @@ const useStageRenderer = (
if (!container) {
return;
}
const stage = stageRef.current.container(container);
stage.container(container);
return () => {
log.trace('Cleaning up stage');
stage.destroy();
};
}, [container, stageRef]);
}, [container, stage]);
useLayoutEffect(() => {
log.trace('Adding stage listeners');
if (asPreview) {
return;
}
stageRef.current.on('mousedown', onMouseDown);
stageRef.current.on('mouseup', onMouseUp);
stageRef.current.on('mousemove', onMouseMove);
stageRef.current.on('mouseenter', onMouseEnter);
stageRef.current.on('mouseleave', onMouseLeave);
stageRef.current.on('wheel', onMouseWheel);
const stage = stageRef.current;
stage.on('mousedown', onMouseDown);
stage.on('mouseup', onMouseUp);
stage.on('mousemove', onMouseMove);
stage.on('mouseenter', onMouseEnter);
stage.on('mouseleave', onMouseLeave);
stage.on('wheel', onMouseWheel);
return () => {
log.trace('Cleaning up stage listeners');
@ -108,7 +104,7 @@ const useStageRenderer = (
stage.off('mouseleave', onMouseLeave);
stage.off('wheel', onMouseWheel);
};
}, [stageRef, asPreview, onMouseDown, onMouseUp, onMouseMove, onMouseEnter, onMouseLeave, onMouseWheel]);
}, [stage, asPreview, onMouseDown, onMouseUp, onMouseMove, onMouseEnter, onMouseLeave, onMouseWheel]);
useLayoutEffect(() => {
log.trace('Updating stage dimensions');
@ -116,8 +112,6 @@ const useStageRenderer = (
return;
}
const stage = stageRef.current;
const fitStageToContainer = () => {
const newXScale = wrapper.offsetWidth / state.size.width;
const newYScale = wrapper.offsetHeight / state.size.height;
@ -135,7 +129,7 @@ const useStageRenderer = (
return () => {
resizeObserver.disconnect();
};
}, [stageRef, state.size.width, state.size.height, wrapper]);
}, [stage, state.size.width, state.size.height, wrapper]);
useLayoutEffect(() => {
log.trace('Rendering tool preview');
@ -144,7 +138,7 @@ const useStageRenderer = (
return;
}
renderers.renderToolPreview(
stageRef.current,
stage,
tool,
selectedLayerIdColor,
state.globalMaskLayerOpacity,
@ -155,7 +149,7 @@ const useStageRenderer = (
);
}, [
asPreview,
stageRef,
stage,
tool,
selectedLayerIdColor,
state.globalMaskLayerOpacity,
@ -168,8 +162,17 @@ const useStageRenderer = (
useLayoutEffect(() => {
log.trace('Rendering layers');
renderers.renderLayers(stageRef.current, state.layers, state.globalMaskLayerOpacity, tool, onLayerPosChanged);
}, [stageRef, state.layers, state.globalMaskLayerOpacity, tool, onLayerPosChanged, renderers]);
renderers.renderLayers(stage, state.layers, state.globalMaskLayerOpacity, tool, onLayerPosChanged);
}, [
stage,
state.layers,
state.globalMaskLayerOpacity,
tool,
onLayerPosChanged,
renderers,
state.size.width,
state.size.height,
]);
useLayoutEffect(() => {
log.trace('Rendering bbox');
@ -177,8 +180,8 @@ const useStageRenderer = (
// Preview should not display bboxes
return;
}
renderers.renderBbox(stageRef.current, state.layers, state.selectedLayerId, tool, onBboxChanged, onBboxMouseDown);
}, [stageRef, asPreview, state.layers, state.selectedLayerId, tool, onBboxChanged, onBboxMouseDown, renderers]);
renderers.renderBbox(stage, state.layers, state.selectedLayerId, tool, onBboxChanged, onBboxMouseDown);
}, [stage, asPreview, state.layers, state.selectedLayerId, tool, onBboxChanged, onBboxMouseDown, renderers]);
useLayoutEffect(() => {
log.trace('Rendering background');
@ -186,13 +189,13 @@ const useStageRenderer = (
// The preview should not have a background
return;
}
renderers.renderBackground(stageRef.current, state.size.width, state.size.height);
}, [stageRef, asPreview, state.size.width, state.size.height, renderers]);
renderers.renderBackground(stage, state.size.width, state.size.height);
}, [stage, asPreview, state.size.width, state.size.height, renderers]);
useLayoutEffect(() => {
log.trace('Arranging layers');
renderers.arrangeLayers(stageRef.current, layerIds);
}, [stageRef, layerIds, renderers]);
renderers.arrangeLayers(stage, layerIds);
}, [stage, layerIds, renderers]);
};
type Props = {
@ -200,10 +203,8 @@ type Props = {
};
export const StageComponent = memo(({ asPreview = false }: Props) => {
const stageRef = useRef<Konva.Stage>(
new Konva.Stage({
container: document.createElement('div'), // We will overwrite this shortly...
})
const [stage] = useState(
() => new Konva.Stage({ id: uuidv4(), container: document.createElement('div'), listening: !asPreview })
);
const [container, setContainer] = useState<HTMLDivElement | null>(null);
const [wrapper, setWrapper] = useState<HTMLDivElement | null>(null);
@ -216,7 +217,7 @@ export const StageComponent = memo(({ asPreview = false }: Props) => {
setWrapper(el);
}, []);
useStageRenderer(stageRef, container, wrapper, asPreview);
useStageRenderer(stage, container, wrapper, asPreview);
return (
<Flex overflow="hidden" w="full" h="full">

View File

@ -1,12 +1,7 @@
import { ButtonGroup, IconButton } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
$tool,
layerAdded,
selectedLayerDeleted,
selectedLayerReset,
} from 'features/regionalPrompts/store/regionalPromptsSlice';
import { $tool, selectedLayerDeleted, selectedLayerReset } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
@ -40,11 +35,6 @@ export const ToolChooser: React.FC = () => {
}, [dispatch]);
useHotkeys('shift+c', resetSelectedLayer);
const addLayer = useCallback(() => {
dispatch(layerAdded('masked_guidance_layer'));
}, [dispatch]);
useHotkeys('shift+a', addLayer);
const deleteSelectedLayer = useCallback(() => {
dispatch(selectedLayerDeleted());
}, [dispatch]);

View File

@ -0,0 +1,230 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Spinner } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
import { useControlAdapterControlImage } from 'features/controlAdapters/hooks/useControlAdapterControlImage';
import { useControlAdapterProcessedControlImage } from 'features/controlAdapters/hooks/useControlAdapterProcessedControlImage';
import { useControlAdapterProcessorType } from 'features/controlAdapters/hooks/useControlAdapterProcessorType';
import {
controlAdapterImageChanged,
selectControlAdaptersSlice,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { heightChanged, widthChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold, PiFloppyDiskBold, PiRulerBold } from 'react-icons/pi';
import {
useAddImageToBoardMutation,
useChangeImageIsIntermediateMutation,
useGetImageDTOQuery,
useRemoveImageFromBoardMutation,
} from 'services/api/endpoints/images';
import type { PostUploadAction } from 'services/api/types';
type Props = {
id: string;
isSmall?: boolean;
};
const selectPendingControlImages = createMemoizedSelector(
selectControlAdaptersSlice,
(controlAdapters) => controlAdapters.pendingControlImages
);
const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const controlImageName = useControlAdapterControlImage(id);
const processedControlImageName = useControlAdapterProcessedControlImage(id);
const processorType = useControlAdapterProcessorType(id);
const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId);
const isConnected = useAppSelector((s) => s.system.isConnected);
const activeTabName = useAppSelector(activeTabNameSelector);
const optimalDimension = useAppSelector(selectOptimalDimension);
const pendingControlImages = useAppSelector(selectPendingControlImages);
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(
controlImageName ?? skipToken
);
const { currentData: processedControlImage, isError: isErrorProcessedControlImage } = useGetImageDTOQuery(
processedControlImageName ?? skipToken
);
const [changeIsIntermediate] = useChangeImageIsIntermediateMutation();
const [addToBoard] = useAddImageToBoardMutation();
const [removeFromBoard] = useRemoveImageFromBoardMutation();
const handleResetControlImage = useCallback(() => {
dispatch(controlAdapterImageChanged({ id, controlImage: null }));
}, [id, dispatch]);
const handleSaveControlImage = useCallback(async () => {
if (!processedControlImage) {
return;
}
await changeIsIntermediate({
imageDTO: processedControlImage,
is_intermediate: false,
}).unwrap();
if (autoAddBoardId !== 'none') {
addToBoard({
imageDTO: processedControlImage,
board_id: autoAddBoardId,
});
} else {
removeFromBoard({ imageDTO: processedControlImage });
}
}, [processedControlImage, changeIsIntermediate, autoAddBoardId, addToBoard, removeFromBoard]);
const handleSetControlImageToDimensions = useCallback(() => {
if (!controlImage) {
return;
}
if (activeTabName === 'unifiedCanvas') {
dispatch(setBoundingBoxDimensions({ width: controlImage.width, height: controlImage.height }, optimalDimension));
} else {
const { width, height } = calculateNewSize(
controlImage.width / controlImage.height,
optimalDimension * optimalDimension
);
dispatch(widthChanged({ width: controlImage.width, updateAspectRatio: true }));
dispatch(heightChanged({ height: controlImage.height, updateAspectRatio: true }));
}
}, [controlImage, activeTabName, dispatch, optimalDimension]);
const handleMouseEnter = useCallback(() => {
setIsMouseOverImage(true);
}, []);
const handleMouseLeave = useCallback(() => {
setIsMouseOverImage(false);
}, []);
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
if (controlImage) {
return {
id,
payloadType: 'IMAGE_DTO',
payload: { imageDTO: controlImage },
};
}
}, [controlImage, id]);
const droppableData = useMemo<TypesafeDroppableData | undefined>(
() => ({
id,
actionType: 'SET_CONTROL_ADAPTER_IMAGE',
context: { id },
}),
[id]
);
const postUploadAction = useMemo<PostUploadAction>(() => ({ type: 'SET_CONTROL_ADAPTER_IMAGE', id }), [id]);
const shouldShowProcessedImage =
controlImage &&
processedControlImage &&
!isMouseOverImage &&
!pendingControlImages.includes(id) &&
processorType !== 'none';
useEffect(() => {
if (isConnected && (isErrorControlImage || isErrorProcessedControlImage)) {
handleResetControlImage();
}
}, [handleResetControlImage, isConnected, isErrorControlImage, isErrorProcessedControlImage]);
return (
<Flex
onMouseEnter={handleMouseEnter}
onMouseLeave={handleMouseLeave}
position="relative"
w="full"
h={isSmall ? 36 : 366} // magic no touch
alignItems="center"
justifyContent="center"
>
<IAIDndImage
draggableData={draggableData}
droppableData={droppableData}
imageDTO={controlImage}
isDropDisabled={shouldShowProcessedImage}
postUploadAction={postUploadAction}
/>
<Box
position="absolute"
top={0}
insetInlineStart={0}
w="full"
h="full"
opacity={shouldShowProcessedImage ? 1 : 0}
transitionProperty="common"
transitionDuration="normal"
pointerEvents="none"
>
<IAIDndImage
draggableData={draggableData}
droppableData={droppableData}
imageDTO={processedControlImage}
isUploadDisabled={true}
/>
</Box>
<>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={controlImage ? <PiArrowCounterClockwiseBold size={16} /> : undefined}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={handleSaveControlImage}
icon={controlImage ? <PiFloppyDiskBold size={16} /> : undefined}
tooltip={t('controlnet.saveControlImage')}
styleOverrides={saveControlImageStyleOverrides}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={controlImage ? <PiRulerBold size={16} /> : undefined}
tooltip={t('controlnet.setControlImageDimensions')}
styleOverrides={setControlImageDimensionsStyleOverrides}
/>
</>
{pendingControlImages.includes(id) && (
<Flex
position="absolute"
top={0}
insetInlineStart={0}
w="full"
h="full"
alignItems="center"
justifyContent="center"
opacity={0.8}
borderRadius="base"
bg="base.900"
>
<Spinner size="xl" color="base.400" />
</Flex>
)}
</Flex>
);
};
export default memo(ControlAdapterImagePreview);
const saveControlImageStyleOverrides: SystemStyleObject = { mt: 6 };
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 12 };

View File

@ -0,0 +1,70 @@
import { Box, Flex, Icon, IconButton } from '@invoke-ai/ui-library';
import ControlAdapterProcessorComponent from 'features/controlAdapters/components/ControlAdapterProcessorComponent';
import ControlAdapterShouldAutoConfig from 'features/controlAdapters/components/ControlAdapterShouldAutoConfig';
import ParamControlAdapterIPMethod from 'features/controlAdapters/components/parameters/ParamControlAdapterIPMethod';
import ParamControlAdapterProcessorSelect from 'features/controlAdapters/components/parameters/ParamControlAdapterProcessorSelect';
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCaretUpBold } from 'react-icons/pi';
import { useToggle } from 'react-use';
import ControlAdapterImagePreview from './ControlAdapterImagePreview';
import { ParamControlAdapterBeginEnd } from './ParamControlAdapterBeginEnd';
import ParamControlAdapterControlMode from './ParamControlAdapterControlMode';
import ParamControlAdapterModel from './ParamControlAdapterModel';
import ParamControlAdapterWeight from './ParamControlAdapterWeight';
const ControlAdapterLayerConfig = (props: { id: string }) => {
const { id } = props;
const controlAdapterType = useControlAdapterType(id);
const { t } = useTranslation();
const [isExpanded, toggleIsExpanded] = useToggle(false);
return (
<Flex flexDir="column" gap={4} position="relative">
<Flex gap={3} alignItems="center" w="full">
<Box minW={0} w="full" transitionProperty="common" transitionDuration="0.1s">
<ParamControlAdapterModel id={id} />{' '}
</Box>
<IconButton
size="sm"
tooltip={isExpanded ? t('controlnet.hideAdvanced') : t('controlnet.showAdvanced')}
aria-label={isExpanded ? t('controlnet.hideAdvanced') : t('controlnet.showAdvanced')}
onClick={toggleIsExpanded}
variant="ghost"
icon={
<Icon
boxSize={4}
as={PiCaretUpBold}
transform={isExpanded ? 'rotate(0deg)' : 'rotate(180deg)'}
transitionProperty="common"
transitionDuration="normal"
/>
}
/>
</Flex>
<Flex gap={4} w="full" alignItems="center">
<Flex flexDir="column" gap={3} w="full">
{controlAdapterType === 'ip_adapter' && <ParamControlAdapterIPMethod id={id} />}
{controlAdapterType === 'controlnet' && <ParamControlAdapterControlMode id={id} />}
<ParamControlAdapterWeight id={id} />
<ParamControlAdapterBeginEnd id={id} />
</Flex>
<Flex alignItems="center" justifyContent="center" h={36} w={36} aspectRatio="1/1">
<ControlAdapterImagePreview id={id} isSmall />
</Flex>
</Flex>
{isExpanded && (
<>
<ControlAdapterShouldAutoConfig id={id} />
<ParamControlAdapterProcessorSelect id={id} />
<ControlAdapterProcessorComponent id={id} />
</>
)}
</Flex>
);
};
export default memo(ControlAdapterLayerConfig);

View File

@ -0,0 +1,89 @@
import { CompositeRangeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useControlAdapterBeginEndStepPct } from 'features/controlAdapters/hooks/useControlAdapterBeginEndStepPct';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import {
controlAdapterBeginStepPctChanged,
controlAdapterEndStepPctChanged,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
type Props = {
id: string;
};
const formatPct = (v: number) => `${Math.round(v * 100)}%`;
export const ParamControlAdapterBeginEnd = memo(({ id }: Props) => {
const isEnabled = useControlAdapterIsEnabled(id);
const stepPcts = useControlAdapterBeginEndStepPct(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const onChange = useCallback(
(v: [number, number]) => {
dispatch(
controlAdapterBeginStepPctChanged({
id,
beginStepPct: v[0],
})
);
dispatch(
controlAdapterEndStepPctChanged({
id,
endStepPct: v[1],
})
);
},
[dispatch, id]
);
const onReset = useCallback(() => {
dispatch(
controlAdapterBeginStepPctChanged({
id,
beginStepPct: 0,
})
);
dispatch(
controlAdapterEndStepPctChanged({
id,
endStepPct: 1,
})
);
}, [dispatch, id]);
const value = useMemo<[number, number]>(() => [stepPcts?.beginStepPct ?? 0, stepPcts?.endStepPct ?? 1], [stepPcts]);
if (!stepPcts) {
return null;
}
return (
<FormControl isDisabled={!isEnabled} orientation="horizontal">
<InformationalPopover feature="controlNetBeginEnd">
<FormLabel m={0}>{t('controlnet.beginEndStepPercentShort')}</FormLabel>
</InformationalPopover>
<CompositeRangeSlider
aria-label={ariaLabel}
value={value}
onChange={onChange}
onReset={onReset}
min={0}
max={1}
step={0.05}
fineStep={0.01}
minStepsBetweenThumbs={1}
formatValue={formatPct}
marks
withThumbTooltip
/>
</FormControl>
);
});
ParamControlAdapterBeginEnd.displayName = 'ParamControlAdapterBeginEnd';
const ariaLabel = ['Begin Step %', 'End Step %'];

View File

@ -0,0 +1,66 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useControlAdapterControlMode } from 'features/controlAdapters/hooks/useControlAdapterControlMode';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { controlAdapterControlModeChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlMode } from 'features/controlAdapters/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
type Props = {
id: string;
};
const ParamControlAdapterControlMode = ({ id }: Props) => {
const isEnabled = useControlAdapterIsEnabled(id);
const controlMode = useControlAdapterControlMode(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const CONTROL_MODE_DATA = useMemo(
() => [
{ label: t('controlnet.balanced'), value: 'balanced' },
{ label: t('controlnet.prompt'), value: 'more_prompt' },
{ label: t('controlnet.control'), value: 'more_control' },
{ label: t('controlnet.megaControl'), value: 'unbalanced' },
],
[t]
);
const handleControlModeChange = useCallback<ComboboxOnChange>(
(v) => {
if (!v) {
return;
}
dispatch(
controlAdapterControlModeChanged({
id,
controlMode: v.value as ControlMode,
})
);
},
[id, dispatch]
);
const value = useMemo(
() => CONTROL_MODE_DATA.filter((o) => o.value === controlMode)[0],
[CONTROL_MODE_DATA, controlMode]
);
if (!controlMode) {
return null;
}
return (
<FormControl isDisabled={!isEnabled}>
<InformationalPopover feature="controlNetControlMode">
<FormLabel m={0}>{t('controlnet.control')}</FormLabel>
</InformationalPopover>
<Combobox value={value} options={CONTROL_MODE_DATA} onChange={handleControlModeChange} />
</FormControl>
);
};
export default memo(ParamControlAdapterControlMode);

View File

@ -0,0 +1,136 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useControlAdapterCLIPVisionModel } from 'features/controlAdapters/hooks/useControlAdapterCLIPVisionModel';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
import {
controlAdapterCLIPVisionModelChanged,
controlAdapterModelChanged,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import type { CLIPVisionModel } from 'features/controlAdapters/store/types';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type {
AnyModelConfig,
ControlNetModelConfig,
IPAdapterModelConfig,
T2IAdapterModelConfig,
} from 'services/api/types';
type ParamControlAdapterModelProps = {
id: string;
};
const selectMainModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
const isEnabled = useControlAdapterIsEnabled(id);
const controlAdapterType = useControlAdapterType(id);
const { modelConfig } = useControlAdapterModel(id);
const dispatch = useAppDispatch();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const currentCLIPVisionModel = useControlAdapterCLIPVisionModel(id);
const mainModel = useAppSelector(selectMainModel);
const { t } = useTranslation();
const [modelConfigs, { isLoading }] = useControlAdapterModels(controlAdapterType);
const _onChange = useCallback(
(modelConfig: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
if (!modelConfig) {
return;
}
dispatch(
controlAdapterModelChanged({
id,
modelConfig,
})
);
},
[dispatch, id]
);
const onCLIPVisionModelChange = useCallback<ComboboxOnChange>(
(v) => {
if (!v?.value) {
return;
}
dispatch(controlAdapterCLIPVisionModelChanged({ id, clipVisionModel: v.value as CLIPVisionModel }));
},
[dispatch, id]
);
const selectedModel = useMemo(
() => (modelConfig && controlAdapterType ? { ...modelConfig, model_type: controlAdapterType } : null),
[controlAdapterType, modelConfig]
);
const getIsDisabled = useCallback(
(model: AnyModelConfig): boolean => {
const isCompatible = currentBaseModel === model.base;
const hasMainModel = Boolean(currentBaseModel);
return !hasMainModel || !isCompatible;
},
[currentBaseModel]
);
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel,
getIsDisabled,
isLoading,
});
const clipVisionOptions = useMemo<ComboboxOption[]>(
() => [
{ label: 'ViT-H', value: 'ViT-H' },
{ label: 'ViT-G', value: 'ViT-G' },
],
[]
);
const clipVisionModel = useMemo(
() => clipVisionOptions.find((o) => o.value === currentCLIPVisionModel),
[clipVisionOptions, currentCLIPVisionModel]
);
return (
<Flex gap={4}>
<Tooltip label={selectedModel?.description}>
<FormControl isDisabled={!isEnabled} isInvalid={!value || mainModel?.base !== modelConfig?.base} w="full">
<Combobox
options={options}
placeholder={t('controlnet.selectModel')}
value={value}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
{modelConfig?.type === 'ip_adapter' && modelConfig.format === 'checkpoint' && (
<FormControl
isDisabled={!isEnabled}
isInvalid={!value || mainModel?.base !== modelConfig?.base}
width="max-content"
minWidth={28}
>
<Combobox
options={clipVisionOptions}
placeholder={t('controlnet.selectCLIPVisionModel')}
value={clipVisionModel}
onChange={onCLIPVisionModelChange}
/>
</FormControl>
)}
</Flex>
);
};
export default memo(ParamControlAdapterModel);

View File

@ -0,0 +1,74 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { useControlAdapterWeight } from 'features/controlAdapters/hooks/useControlAdapterWeight';
import { controlAdapterWeightChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { isNil } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type ParamControlAdapterWeightProps = {
id: string;
};
const formatValue = (v: number) => v.toFixed(2);
const ParamControlAdapterWeight = ({ id }: ParamControlAdapterWeightProps) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const isEnabled = useControlAdapterIsEnabled(id);
const weight = useControlAdapterWeight(id);
const initial = useAppSelector((s) => s.config.sd.ca.weight.initial);
const sliderMin = useAppSelector((s) => s.config.sd.ca.weight.sliderMin);
const sliderMax = useAppSelector((s) => s.config.sd.ca.weight.sliderMax);
const numberInputMin = useAppSelector((s) => s.config.sd.ca.weight.numberInputMin);
const numberInputMax = useAppSelector((s) => s.config.sd.ca.weight.numberInputMax);
const coarseStep = useAppSelector((s) => s.config.sd.ca.weight.coarseStep);
const fineStep = useAppSelector((s) => s.config.sd.ca.weight.fineStep);
const onChange = useCallback(
(weight: number) => {
dispatch(controlAdapterWeightChanged({ id, weight }));
},
[dispatch, id]
);
if (isNil(weight)) {
// should never happen
return null;
}
return (
<FormControl isDisabled={!isEnabled} orientation="horizontal">
<InformationalPopover feature="controlNetWeight">
<FormLabel m={0}>{t('controlnet.weight')}</FormLabel>
</InformationalPopover>
<CompositeSlider
value={weight}
onChange={onChange}
defaultValue={initial}
min={sliderMin}
max={sliderMax}
step={coarseStep}
fineStep={fineStep}
marks={marks}
formatValue={formatValue}
/>
<CompositeNumberInput
value={weight}
onChange={onChange}
min={numberInputMin}
max={numberInputMax}
step={coarseStep}
fineStep={fineStep}
maxW={20}
defaultValue={initial}
/>
</FormControl>
);
};
export default memo(ParamControlAdapterWeight);
const marks = [0, 1, 2];

View File

@ -1,6 +1,9 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { isMaskedGuidanceLayer, selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
import {
isMaskedGuidanceLayer,
selectRegionalPromptsSlice,
} from 'features/regionalPrompts/store/regionalPromptsSlice';
import { useMemo } from 'react';
import { assert } from 'tsafe';
@ -39,8 +42,8 @@ export const useLayerIsVisible = (layerId: string) => {
() =>
createSelector(selectRegionalPromptsSlice, (regionalPrompts) => {
const layer = regionalPrompts.present.layers.find((l) => l.id === layerId);
assert(isMaskedGuidanceLayer(layer), `Layer ${layerId} not found or not an RP layer`);
return layer.isVisible;
assert(layer, `Layer ${layerId} not found`);
return layer.isEnabled;
}),
[layerId]
);

View File

@ -10,7 +10,7 @@ const selectValidLayerCount = createSelector(selectRegionalPromptsSlice, (region
}
const validLayers = regionalPrompts.present.layers
.filter(isMaskedGuidanceLayer)
.filter((l) => l.isVisible)
.filter((l) => l.isEnabled)
.filter((l) => {
const hasTextPrompt = Boolean(l.positivePrompt || l.negativePrompt);
const hasAtLeastOneImagePrompt = l.ipAdapterIds.length > 0;

View File

@ -4,20 +4,16 @@ import type { PersistConfig, RootState } from 'app/store/store';
import { moveBackward, moveForward, moveToBack, moveToFront } from 'common/util/arrayUtils';
import { deepClone } from 'common/util/deepClone';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import { controlAdapterRemoved, isAnyControlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice';
import {
controlAdapterImageChanged,
controlAdapterProcessedImageChanged,
isAnyControlAdapterAdded,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
import { initialAspectRatioState } from 'features/parameters/components/ImageSize/constants';
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
import { modelChanged } from 'features/parameters/store/generationSlice';
import type {
ParameterAutoNegative,
ParameterHeight,
ParameterNegativePrompt,
ParameterNegativeStylePromptSDXL,
ParameterPositivePrompt,
ParameterPositiveStylePromptSDXL,
ParameterWidth,
} from 'features/parameters/types/parameterSchemas';
import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import type { IRect, Vector2d } from 'konva/lib/types';
import { isEqual } from 'lodash-es';
@ -27,81 +23,17 @@ import type { UndoableOptions } from 'redux-undo';
import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid';
type DrawingTool = 'brush' | 'eraser';
export type Tool = DrawingTool | 'move' | 'rect';
export type VectorMaskLine = {
id: string;
type: 'vector_mask_line';
tool: DrawingTool;
strokeWidth: number;
points: number[];
};
export type VectorMaskRect = {
id: string;
type: 'vector_mask_rect';
x: number;
y: number;
width: number;
height: number;
};
type LayerBase = {
id: string;
isVisible: boolean;
};
type RenderableLayerBase = LayerBase & {
x: number;
y: number;
bbox: IRect | null;
bboxNeedsUpdate: boolean;
};
type ControlAdapterLayer = RenderableLayerBase & {
type: 'controlnet_layer'; // technically, also t2i adapter layer
controlAdapterId: string;
};
type IPAdapterLayer = LayerBase & {
type: 'ip_adapter_layer'; // technically, also t2i adapter layer
ipAdapterId: string;
};
export type MaskedGuidanceLayer = RenderableLayerBase & {
type: 'masked_guidance_layer';
maskObjects: (VectorMaskLine | VectorMaskRect)[];
positivePrompt: ParameterPositivePrompt | null;
negativePrompt: ParameterNegativePrompt | null; // Up to one text prompt per mask
ipAdapterIds: string[]; // Any number of image prompts
previewColor: RgbColor;
autoNegative: ParameterAutoNegative;
needsPixelBbox: boolean; // Needs the slower pixel-based bbox calculation - set to true when an there is an eraser object
};
export type Layer = MaskedGuidanceLayer | ControlAdapterLayer | IPAdapterLayer;
type RegionalPromptsState = {
_version: 1;
selectedLayerId: string | null;
layers: Layer[];
brushSize: number;
globalMaskLayerOpacity: number;
isEnabled: boolean;
positivePrompt: ParameterPositivePrompt;
negativePrompt: ParameterNegativePrompt;
positivePrompt2: ParameterPositiveStylePromptSDXL;
negativePrompt2: ParameterNegativeStylePromptSDXL;
shouldConcatPrompts: boolean;
initialImage: string | null;
size: {
width: ParameterWidth;
height: ParameterHeight;
aspectRatio: AspectRatioState;
};
};
import type {
ControlAdapterLayer,
DrawingTool,
IPAdapterLayer,
Layer,
MaskedGuidanceLayer,
RegionalPromptsState,
Tool,
VectorMaskLine,
VectorMaskRect,
} from './types';
export const initialRegionalPromptsState: RegionalPromptsState = {
_version: 1,
@ -126,19 +58,22 @@ export const initialRegionalPromptsState: RegionalPromptsState = {
const isLine = (obj: VectorMaskLine | VectorMaskRect): obj is VectorMaskLine => obj.type === 'vector_mask_line';
export const isMaskedGuidanceLayer = (layer?: Layer): layer is MaskedGuidanceLayer =>
layer?.type === 'masked_guidance_layer';
export const isRenderableLayer = (layer?: Layer): layer is MaskedGuidanceLayer =>
layer?.type === 'masked_guidance_layer' || layer?.type === 'controlnet_layer';
export const isControlAdapterLayer = (layer?: Layer): layer is ControlAdapterLayer =>
layer?.type === 'control_adapter_layer';
export const isIPAdapterLayer = (layer?: Layer): layer is IPAdapterLayer => layer?.type === 'ip_adapter_layer';
export const isRenderableLayer = (layer?: Layer): layer is MaskedGuidanceLayer | ControlAdapterLayer =>
layer?.type === 'masked_guidance_layer' || layer?.type === 'control_adapter_layer';
const resetLayer = (layer: Layer) => {
if (layer.type === 'masked_guidance_layer') {
layer.maskObjects = [];
layer.bbox = null;
layer.isVisible = true;
layer.isEnabled = true;
layer.needsPixelBbox = false;
layer.bboxNeedsUpdate = false;
return;
}
if (layer.type === 'controlnet_layer') {
if (layer.type === 'control_adapter_layer') {
// TODO
}
};
@ -153,59 +88,71 @@ export const regionalPromptsSlice = createSlice({
initialState: initialRegionalPromptsState,
reducers: {
//#region All Layers
layerAdded: {
reducer: (state, action: PayloadAction<Layer['type'], string, { uuid: string }>) => {
const type = action.payload;
if (type === 'masked_guidance_layer') {
const layer: MaskedGuidanceLayer = {
id: getMaskedGuidanceLayerId(action.meta.uuid),
type: 'masked_guidance_layer',
isVisible: true,
bbox: null,
bboxNeedsUpdate: false,
maskObjects: [],
previewColor: getVectorMaskPreviewColor(state),
x: 0,
y: 0,
autoNegative: 'invert',
needsPixelBbox: false,
positivePrompt: '',
negativePrompt: null,
ipAdapterIds: [],
};
state.layers.push(layer);
state.selectedLayerId = layer.id;
return;
}
if (type === 'controlnet_layer') {
const layer: ControlAdapterLayer = {
id: getControlLayerId(action.meta.uuid),
type: 'controlnet_layer',
controlAdapterId: action.meta.uuid,
x: 0,
y: 0,
bbox: null,
bboxNeedsUpdate: false,
isVisible: true,
};
state.layers.push(layer);
state.selectedLayerId = layer.id;
return;
}
},
prepare: (payload: Layer['type']) => ({ payload, meta: { uuid: uuidv4() } }),
maskedGuidanceLayerAdded: (state, action: PayloadAction<{ layerId: string }>) => {
const { layerId } = action.payload;
const layer: MaskedGuidanceLayer = {
id: getMaskedGuidanceLayerId(layerId),
type: 'masked_guidance_layer',
isEnabled: true,
bbox: null,
bboxNeedsUpdate: false,
maskObjects: [],
previewColor: getVectorMaskPreviewColor(state),
x: 0,
y: 0,
autoNegative: 'invert',
needsPixelBbox: false,
positivePrompt: '',
negativePrompt: null,
ipAdapterIds: [],
isSelected: true,
};
state.layers.push(layer);
state.selectedLayerId = layer.id;
return;
},
ipAdapterLayerAdded: (state, action: PayloadAction<{ layerId: string; ipAdapterId: string }>) => {
const { layerId, ipAdapterId } = action.payload;
const layer: IPAdapterLayer = {
id: getIPAdapterLayerId(layerId),
type: 'ip_adapter_layer',
isEnabled: true,
ipAdapterId,
};
state.layers.push(layer);
return;
},
controlAdapterLayerAdded: (state, action: PayloadAction<{ layerId: string; controlNetId: string }>) => {
const { layerId, controlNetId } = action.payload;
const layer: ControlAdapterLayer = {
id: getControlNetLayerId(layerId),
type: 'control_adapter_layer',
controlNetId,
x: 0,
y: 0,
bbox: null,
bboxNeedsUpdate: false,
isEnabled: true,
imageName: null,
opacity: 1,
isSelected: true,
};
state.layers.push(layer);
state.selectedLayerId = layer.id;
return;
},
layerSelected: (state, action: PayloadAction<string>) => {
const layer = state.layers.find((l) => l.id === action.payload);
if (layer) {
state.selectedLayerId = layer.id;
for (const layer of state.layers) {
if (isRenderableLayer(layer) && layer.id === action.payload) {
layer.isSelected = true;
state.selectedLayerId = action.payload;
}
}
},
layerVisibilityToggled: (state, action: PayloadAction<string>) => {
const layer = state.layers.find((l) => l.id === action.payload);
if (layer) {
layer.isVisible = !layer.isVisible;
layer.isEnabled = !layer.isEnabled;
}
},
layerTranslated: (state, action: PayloadAction<{ layerId: string; x: number; y: number }>) => {
@ -252,10 +199,6 @@ export const regionalPromptsSlice = createSlice({
// Because the layers are in reverse order, moving to the back is equivalent to moving to the front
moveToFront(state.layers, cb);
},
allLayersDeleted: (state) => {
state.layers = [];
state.selectedLayerId = null;
},
selectedLayerReset: (state) => {
const layer = state.layers.find((l) => l.id === state.selectedLayerId);
if (layer) {
@ -283,14 +226,19 @@ export const regionalPromptsSlice = createSlice({
layer.negativePrompt = prompt;
}
},
maskLayerIPAdapterAdded: {
reducer: (state, action: PayloadAction<string, string, { uuid: string }>) => {
const layer = state.layers.find((l) => l.id === action.payload);
if (layer?.type === 'masked_guidance_layer') {
layer.ipAdapterIds.push(action.meta.uuid);
}
},
prepare: (payload: string) => ({ payload, meta: { uuid: uuidv4() } }),
maskLayerIPAdapterAdded: (state, action: PayloadAction<{ layerId: string; ipAdapterId: string }>) => {
const { layerId, ipAdapterId } = action.payload;
const layer = state.layers.find((l) => l.id === layerId);
if (layer?.type === 'masked_guidance_layer') {
layer.ipAdapterIds.push(ipAdapterId);
}
},
maskLayerIPAdapterDeleted: (state, action: PayloadAction<{ layerId: string; ipAdapterId: string }>) => {
const { layerId, ipAdapterId } = action.payload;
const layer = state.layers.find((l) => l.id === layerId);
if (layer?.type === 'masked_guidance_layer') {
layer.ipAdapterIds = layer.ipAdapterIds.filter((id) => id !== ipAdapterId);
}
},
maskLayerPreviewColorChanged: (state, action: PayloadAction<{ layerId: string; color: RgbColor }>) => {
const { layerId, color } = action.payload;
@ -422,10 +370,13 @@ export const regionalPromptsSlice = createSlice({
//#region General
brushSizeChanged: (state, action: PayloadAction<number>) => {
state.brushSize = action.payload;
state.brushSize = Math.round(action.payload);
},
globalMaskLayerOpacityChanged: (state, action: PayloadAction<number>) => {
state.globalMaskLayerOpacity = action.payload;
state.layers.filter(isControlAdapterLayer).forEach((l) => {
l.opacity = action.payload;
});
},
isEnabledChanged: (state, action: PayloadAction<boolean>) => {
state.isEnabled = action.payload;
@ -445,12 +396,6 @@ export const regionalPromptsSlice = createSlice({
//#endregion
},
extraReducers(builder) {
builder.addCase(controlAdapterRemoved, (state, action) => {
state.layers.filter(isMaskedGuidanceLayer).forEach((layer) => {
layer.ipAdapterIds = layer.ipAdapterIds.filter((id) => id !== action.payload.id);
});
});
builder.addCase(modelChanged, (state, action) => {
const newModel = action.payload;
if (!newModel || action.meta.previousModel?.base === newModel.base) {
@ -466,6 +411,28 @@ export const regionalPromptsSlice = createSlice({
state.size.height = height;
});
builder.addCase(controlAdapterImageChanged, (state, action) => {
const { id, controlImage } = action.payload;
const layer = state.layers.filter(isControlAdapterLayer).find((l) => l.controlNetId === id);
if (layer) {
layer.bbox = null;
layer.bboxNeedsUpdate = true;
layer.isEnabled = true;
layer.imageName = controlImage?.image_name ?? null;
}
});
builder.addCase(controlAdapterProcessedImageChanged, (state, action) => {
const { id, processedControlImage } = action.payload;
const layer = state.layers.filter(isControlAdapterLayer).find((l) => l.controlNetId === id);
if (layer) {
layer.bbox = null;
layer.bboxNeedsUpdate = true;
layer.isEnabled = true;
layer.imageName = processedControlImage?.image_name ?? null;
}
});
// TODO: This is a temp fix to reduce issues with T2I adapter having a different downscaling
// factor than the UNet. Hopefully we get an upstream fix in diffusers.
builder.addMatcher(isAnyControlAdapterAdded, (state, action) => {
@ -510,7 +477,6 @@ class LayerColors {
export const {
// All layer actions
layerAdded,
layerDeleted,
layerMovedBackward,
layerMovedForward,
@ -521,9 +487,11 @@ export const {
layerTranslated,
layerBboxChanged,
layerVisibilityToggled,
allLayersDeleted,
selectedLayerReset,
selectedLayerDeleted,
maskedGuidanceLayerAdded,
ipAdapterLayerAdded,
controlAdapterLayerAdded,
// Mask layer actions
maskLayerLineAdded,
maskLayerPointsAdded,
@ -531,6 +499,7 @@ export const {
maskLayerNegativePromptChanged,
maskLayerPositivePromptChanged,
maskLayerIPAdapterAdded,
maskLayerIPAdapterDeleted,
maskLayerAutoNegativeChanged,
maskLayerPreviewColorChanged,
// Base layer actions
@ -549,6 +518,20 @@ export const {
redo,
} = regionalPromptsSlice.actions;
export const selectAllControlAdapterIds = (regionalPrompts: RegionalPromptsState) =>
regionalPrompts.layers.flatMap((l) => {
if (l.type === 'control_adapter_layer') {
return [l.controlNetId];
}
if (l.type === 'ip_adapter_layer') {
return [l.ipAdapterId];
}
if (l.type === 'masked_guidance_layer') {
return l.ipAdapterIds;
}
return [];
});
export const selectRegionalPromptsSlice = (state: RootState) => state.regionalPrompts;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
@ -571,8 +554,11 @@ export const TOOL_PREVIEW_BRUSH_BORDER_OUTER_ID = 'tool_preview_layer.brush_bord
export const TOOL_PREVIEW_RECT_ID = 'tool_preview_layer.rect';
export const BACKGROUND_LAYER_ID = 'background_layer';
export const BACKGROUND_RECT_ID = 'background_layer.rect';
export const CONTROLNET_LAYER_TRANSFORMER_ID = 'control_adapter_layer.transformer';
// Names (aka classes) for Konva layers and objects
export const CONTROLNET_LAYER_NAME = 'control_adapter_layer';
export const CONTROLNET_LAYER_IMAGE_NAME = 'control_adapter_layer.image';
export const MASKED_GUIDANCE_LAYER_NAME = 'masked_guidance_layer';
export const MASKED_GUIDANCE_LAYER_LINE_NAME = 'masked_guidance_layer.line';
export const MASKED_GUIDANCE_LAYER_OBJECT_GROUP_NAME = 'masked_guidance_layer.object_group';
@ -586,7 +572,9 @@ const getMaskedGuidnaceLayerRectId = (layerId: string, lineId: string) => `${lay
export const getMaskedGuidanceLayerObjectGroupId = (layerId: string, groupId: string) =>
`${layerId}.objectGroup_${groupId}`;
export const getLayerBboxId = (layerId: string) => `${layerId}.bbox`;
const getControlLayerId = (layerId: string) => `control_layer_${layerId}`;
const getControlNetLayerId = (layerId: string) => `control_adapter_layer_${layerId}`;
export const getControlNetLayerImageId = (layerId: string, imageName: string) => `${layerId}.image_${imageName}`;
const getIPAdapterLayerId = (layerId: string) => `ip_adapter_layer_${layerId}`;
export const regionalPromptsPersistConfig: PersistConfig<RegionalPromptsState> = {
name: regionalPromptsSlice.name,

View File

@ -0,0 +1,91 @@
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
import type {
ParameterAutoNegative,
ParameterHeight,
ParameterNegativePrompt,
ParameterNegativeStylePromptSDXL,
ParameterPositivePrompt,
ParameterPositiveStylePromptSDXL,
ParameterWidth,
} from 'features/parameters/types/parameterSchemas';
import type { IRect } from 'konva/lib/types';
import type { RgbColor } from 'react-colorful';
export type DrawingTool = 'brush' | 'eraser';
export type Tool = DrawingTool | 'move' | 'rect';
export type VectorMaskLine = {
id: string;
type: 'vector_mask_line';
tool: DrawingTool;
strokeWidth: number;
points: number[];
};
export type VectorMaskRect = {
id: string;
type: 'vector_mask_rect';
x: number;
y: number;
width: number;
height: number;
};
export type LayerBase = {
id: string;
isEnabled: boolean;
};
export type RenderableLayerBase = LayerBase & {
x: number;
y: number;
bbox: IRect | null;
bboxNeedsUpdate: boolean;
isSelected: boolean;
};
export type ControlAdapterLayer = RenderableLayerBase & {
type: 'control_adapter_layer'; // technically, also t2i adapter layer
controlNetId: string;
imageName: string | null;
opacity: number;
};
export type IPAdapterLayer = LayerBase & {
type: 'ip_adapter_layer'; // technically, also t2i adapter layer
ipAdapterId: string;
};
export type MaskedGuidanceLayer = RenderableLayerBase & {
type: 'masked_guidance_layer';
maskObjects: (VectorMaskLine | VectorMaskRect)[];
positivePrompt: ParameterPositivePrompt | null;
negativePrompt: ParameterNegativePrompt | null; // Up to one text prompt per mask
ipAdapterIds: string[]; // Any number of image prompts
previewColor: RgbColor;
autoNegative: ParameterAutoNegative;
needsPixelBbox: boolean; // Needs the slower pixel-based bbox calculation - set to true when an there is an eraser object
};
export type Layer = MaskedGuidanceLayer | ControlAdapterLayer | IPAdapterLayer;
export type RegionalPromptsState = {
_version: 1;
selectedLayerId: string | null;
layers: Layer[];
brushSize: number;
globalMaskLayerOpacity: number;
isEnabled: boolean;
positivePrompt: ParameterPositivePrompt;
negativePrompt: ParameterNegativePrompt;
positivePrompt2: ParameterPositiveStylePromptSDXL;
negativePrompt2: ParameterNegativeStylePromptSDXL;
shouldConcatPrompts: boolean;
initialImage: string | null;
size: {
width: ParameterWidth;
height: ParameterHeight;
aspectRatio: AspectRatioState;
};
};

View File

@ -1,20 +1,18 @@
import { getStore } from 'app/store/nanostores/store';
import { rgbaColorToString, rgbColorToString } from 'features/canvas/util/colorToString';
import { getScaledFlooredCursorPosition } from 'features/regionalPrompts/hooks/mouseEventHooks';
import type {
Layer,
MaskedGuidanceLayer,
Tool,
VectorMaskLine,
VectorMaskRect,
} from 'features/regionalPrompts/store/regionalPromptsSlice';
import {
$tool,
BACKGROUND_LAYER_ID,
BACKGROUND_RECT_ID,
CONTROLNET_LAYER_IMAGE_NAME,
CONTROLNET_LAYER_NAME,
getControlNetLayerImageId,
getLayerBboxId,
getMaskedGuidanceLayerObjectGroupId,
isControlAdapterLayer,
isMaskedGuidanceLayer,
isRenderableLayer,
LAYER_BBOX_NAME,
MASKED_GUIDANCE_LAYER_LINE_NAME,
MASKED_GUIDANCE_LAYER_NAME,
@ -27,11 +25,20 @@ import {
TOOL_PREVIEW_LAYER_ID,
TOOL_PREVIEW_RECT_ID,
} from 'features/regionalPrompts/store/regionalPromptsSlice';
import type {
ControlAdapterLayer,
Layer,
MaskedGuidanceLayer,
Tool,
VectorMaskLine,
VectorMaskRect,
} from 'features/regionalPrompts/store/types';
import { getLayerBboxFast, getLayerBboxPixels } from 'features/regionalPrompts/util/bbox';
import Konva from 'konva';
import type { IRect, Vector2d } from 'konva/lib/types';
import { debounce } from 'lodash-es';
import type { RgbColor } from 'react-colorful';
import { imagesApi } from 'services/api/endpoints/images';
import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid';
@ -53,6 +60,9 @@ const getIsSelected = (layerId?: string | null) => {
return layerId === getStore().getState().regionalPrompts.present.selectedLayerId;
};
const selectRenderableLayers = (n: Konva.Node) =>
n.name() === MASKED_GUIDANCE_LAYER_NAME || n.name() === CONTROLNET_LAYER_NAME;
const selectVectorMaskObjects = (node: Konva.Node) => {
return node.name() === MASKED_GUIDANCE_LAYER_LINE_NAME || node.name() === MASKED_GUIDANCE_LAYER_RECT_NAME;
};
@ -219,7 +229,7 @@ const renderToolPreview = (
* @param reduxLayer The redux layer to create the konva layer from.
* @param onLayerPosChanged Callback for when the layer's position changes.
*/
const createVectorMaskLayer = (
const createMaskedGuidanceLayer = (
stage: Konva.Stage,
reduxLayer: MaskedGuidanceLayer,
onLayerPosChanged?: (layerId: string, x: number, y: number) => void
@ -320,7 +330,7 @@ const createVectorMaskRect = (reduxObject: VectorMaskRect, konvaGroup: Konva.Gro
* @param globalMaskLayerOpacity The opacity of the global mask layer.
* @param tool The current tool.
*/
const renderVectorMaskLayer = (
const renderMaskedGuidanceLayer = (
stage: Konva.Stage,
reduxLayer: MaskedGuidanceLayer,
globalMaskLayerOpacity: number,
@ -328,7 +338,7 @@ const renderVectorMaskLayer = (
onLayerPosChanged?: (layerId: string, x: number, y: number) => void
): void => {
const konvaLayer =
stage.findOne<Konva.Layer>(`#${reduxLayer.id}`) ?? createVectorMaskLayer(stage, reduxLayer, onLayerPosChanged);
stage.findOne<Konva.Layer>(`#${reduxLayer.id}`) ?? createMaskedGuidanceLayer(stage, reduxLayer, onLayerPosChanged);
// Update the layer's position and listening state
konvaLayer.setAttrs({
@ -383,8 +393,8 @@ const renderVectorMaskLayer = (
}
// Only update layer visibility if it has changed.
if (konvaLayer.visible() !== reduxLayer.isVisible) {
konvaLayer.visible(reduxLayer.isVisible);
if (konvaLayer.visible() !== reduxLayer.isEnabled) {
konvaLayer.visible(reduxLayer.isEnabled);
groupNeedsCache = true;
}
@ -401,6 +411,101 @@ const renderVectorMaskLayer = (
}
};
const createControlNetLayer = (stage: Konva.Stage, reduxLayer: ControlAdapterLayer): Konva.Layer => {
const konvaLayer = new Konva.Layer({
id: reduxLayer.id,
name: CONTROLNET_LAYER_NAME,
imageSmoothingEnabled: false,
});
stage.add(konvaLayer);
return konvaLayer;
};
const createControlNetLayerImage = (konvaLayer: Konva.Layer, image: HTMLImageElement): Konva.Image => {
const konvaImage = new Konva.Image({
name: CONTROLNET_LAYER_IMAGE_NAME,
image,
filters: [LightnessToAlphaFilter],
});
konvaLayer.add(konvaImage);
return konvaImage;
};
const updateControlNetLayerImageSource = async (
stage: Konva.Stage,
konvaLayer: Konva.Layer,
reduxLayer: ControlAdapterLayer
) => {
if (reduxLayer.imageName) {
const imageName = reduxLayer.imageName;
const req = getStore().dispatch(imagesApi.endpoints.getImageDTO.initiate(reduxLayer.imageName));
const imageDTO = await req.unwrap();
req.unsubscribe();
const image = new Image();
const imageId = getControlNetLayerImageId(reduxLayer.id, imageName);
image.onload = () => {
// Find the existing image or create a new one - must find using the name, bc the id may have just changed
const konvaImage =
konvaLayer.findOne<Konva.Image>(`.${CONTROLNET_LAYER_IMAGE_NAME}`) ??
createControlNetLayerImage(konvaLayer, image);
// Update the image's attributes
konvaImage.setAttrs({
id: imageId,
image,
});
updateControlNetLayerImageAttrs(stage, konvaImage, reduxLayer);
// Must cache after this to apply the filters
konvaImage.cache();
image.id = imageId;
};
image.src = imageDTO.image_url;
} else {
konvaLayer.findOne(`.${CONTROLNET_LAYER_IMAGE_NAME}`)?.destroy();
}
};
const updateControlNetLayerImageAttrs = (
stage: Konva.Stage,
konvaImage: Konva.Image,
reduxLayer: ControlAdapterLayer
) => {
konvaImage.setAttrs({
opacity: reduxLayer.opacity,
scaleX: 1,
scaleY: 1,
width: stage.width() / stage.scaleX(),
height: stage.height() / stage.scaleY(),
visible: reduxLayer.isEnabled,
});
konvaImage.cache();
};
const renderControlNetLayer = (stage: Konva.Stage, reduxLayer: ControlAdapterLayer) => {
const konvaLayer = stage.findOne<Konva.Layer>(`#${reduxLayer.id}`) ?? createControlNetLayer(stage, reduxLayer);
const konvaImage = konvaLayer.findOne<Konva.Image>(`.${CONTROLNET_LAYER_IMAGE_NAME}`);
const canvasImageSource = konvaImage?.image();
let imageSourceNeedsUpdate = false;
if (canvasImageSource instanceof HTMLImageElement) {
if (
reduxLayer.imageName &&
canvasImageSource.id !== getControlNetLayerImageId(reduxLayer.id, reduxLayer.imageName)
) {
imageSourceNeedsUpdate = true;
} else if (!reduxLayer.imageName) {
imageSourceNeedsUpdate = true;
}
} else if (!canvasImageSource) {
imageSourceNeedsUpdate = true;
}
if (imageSourceNeedsUpdate) {
updateControlNetLayerImageSource(stage, konvaLayer, reduxLayer);
} else if (konvaImage) {
updateControlNetLayerImageAttrs(stage, konvaImage, reduxLayer);
}
};
/**
* Renders the layers on the stage.
* @param stage The konva stage to render on.
@ -416,10 +521,9 @@ const renderLayers = (
tool: Tool,
onLayerPosChanged?: (layerId: string, x: number, y: number) => void
) => {
const reduxLayerIds = reduxLayers.map(mapId);
const reduxLayerIds = reduxLayers.filter(isRenderableLayer).map(mapId);
// Remove un-rendered layers
for (const konvaLayer of stage.find<Konva.Layer>(`.${MASKED_GUIDANCE_LAYER_NAME}`)) {
for (const konvaLayer of stage.find<Konva.Layer>(selectRenderableLayers)) {
if (!reduxLayerIds.includes(konvaLayer.id())) {
konvaLayer.destroy();
}
@ -427,7 +531,10 @@ const renderLayers = (
for (const reduxLayer of reduxLayers) {
if (isMaskedGuidanceLayer(reduxLayer)) {
renderVectorMaskLayer(stage, reduxLayer, globalMaskLayerOpacity, tool, onLayerPosChanged);
renderMaskedGuidanceLayer(stage, reduxLayer, globalMaskLayerOpacity, tool, onLayerPosChanged);
}
if (isControlAdapterLayer(reduxLayer)) {
renderControlNetLayer(stage, reduxLayer);
}
}
};
@ -620,3 +727,20 @@ export const debouncedRenderers = {
renderBackground: debounce(renderBackground, DEBOUNCE_MS),
arrangeLayers: debounce(arrangeLayers, DEBOUNCE_MS),
};
/**
* Calculates the lightness (HSL) of a given pixel and sets the alpha channel to that value.
* This is useful for edge maps and other masks, to make the black areas transparent.
* @param imageData The image data to apply the filter to
*/
const LightnessToAlphaFilter = (imageData: ImageData) => {
const len = imageData.data.length / 4;
for (let i = 0; i < len; i++) {
const r = imageData.data[i * 4 + 0] as number;
const g = imageData.data[i * 4 + 1] as number;
const b = imageData.data[i * 4 + 2] as number;
const cMin = Math.min(r, g, b);
const cMax = Math.max(r, g, b);
imageData.data[i * 4 + 3] = (cMin + cMax) / 2;
}
};

View File

@ -13,7 +13,10 @@ import {
selectValidIPAdapters,
selectValidT2IAdapters,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { isMaskedGuidanceLayer, selectRegionalPromptsSlice } from 'features/regionalPrompts/store/regionalPromptsSlice';
import {
selectAllControlAdapterIds,
selectRegionalPromptsSlice,
} from 'features/regionalPrompts/store/regionalPromptsSlice';
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { Fragment, memo } from 'react';
@ -26,10 +29,10 @@ const selector = createMemoizedSelector(
const badges: string[] = [];
let isError = false;
const regionalControlAdapterIds = selectAllControlAdapterIds(regionalPrompts.present);
const enabledNonRegionalIPAdapterCount = selectAllIPAdapters(controlAdapters)
.filter(
(ca) => !regionalPrompts.present.layers.filter(isMaskedGuidanceLayer).some((l) => l.ipAdapterIds.includes(ca.id))
)
.filter((ca) => !regionalControlAdapterIds.includes(ca.id))
.filter((ca) => ca.isEnabled).length;
const validIPAdapterCount = selectValidIPAdapters(controlAdapters).length;
@ -40,7 +43,9 @@ const selector = createMemoizedSelector(
isError = true;
}
const enabledControlNetCount = selectAllControlNets(controlAdapters).filter((ca) => ca.isEnabled).length;
const enabledControlNetCount = selectAllControlNets(controlAdapters)
.filter((ca) => !regionalControlAdapterIds.includes(ca.id))
.filter((ca) => ca.isEnabled).length;
const validControlNetCount = selectValidControlNets(controlAdapters).length;
if (enabledControlNetCount > 0) {
badges.push(`${enabledControlNetCount} ControlNet`);
@ -49,7 +54,9 @@ const selector = createMemoizedSelector(
isError = true;
}
const enabledT2IAdapterCount = selectAllT2IAdapters(controlAdapters).filter((ca) => ca.isEnabled).length;
const enabledT2IAdapterCount = selectAllT2IAdapters(controlAdapters)
.filter((ca) => !regionalControlAdapterIds.includes(ca.id))
.filter((ca) => ca.isEnabled).length;
const validT2IAdapterCount = selectValidT2IAdapters(controlAdapters).length;
if (enabledT2IAdapterCount > 0) {
badges.push(`${enabledT2IAdapterCount} T2I`);
@ -59,7 +66,7 @@ const selector = createMemoizedSelector(
}
const controlAdapterIds = selectControlAdapterIds(controlAdapters).filter(
(id) => !regionalPrompts.present.layers.filter(isMaskedGuidanceLayer).some((l) => l.ipAdapterIds.includes(id))
(id) => !regionalControlAdapterIds.includes(id)
);
return {

View File

@ -48,7 +48,7 @@ const ParametersPanel = () => {
<Flex gap={2} flexDirection="column" h="full" w="full">
<ImageSettingsAccordion />
<GenerationSettingsAccordion />
<ControlSettingsAccordion />
{activeTabName !== 'txt2img' && <ControlSettingsAccordion />}
{activeTabName === 'unifiedCanvas' && <CompositingSettingsAccordion />}
{isSDXL && <RefinerSettingsAccordion />}
<AdvancedSettingsAccordion />