feat(ui): refactor control adapters

Control adapters logic/state/ui is now generalized to hold controlnet, ip_adapter and t2i_adapter. In the future, other control adapter types can be added.

TODO:
- Limit IP adapter to 1
- Add T2I adapter to linear graphs
- Fix autoprocess
- T2I metadata saving & recall
- Improve on control adapters UI
This commit is contained in:
psychedelicious 2023-10-05 22:40:21 +11:00
parent 9c720da021
commit 9508e0c9db
70 changed files with 1860 additions and 1236 deletions

View File

@ -50,6 +50,7 @@
"close": "Close", "close": "Close",
"communityLabel": "Community", "communityLabel": "Community",
"controlNet": "Controlnet", "controlNet": "Controlnet",
"controlAdapter": "Control Adapter",
"ipAdapter": "IP Adapter", "ipAdapter": "IP Adapter",
"darkMode": "Dark Mode", "darkMode": "Dark Mode",
"discordLabel": "Discord", "discordLabel": "Discord",

View File

@ -1,5 +1,5 @@
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist'; import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
import { controlNetDenylist } from 'features/controlNet/store/controlNetDenylist'; import { controlAdaptersPersistDenylist } from 'features/controlNet/store/controlAdaptersPersistDenylist';
import { dynamicPromptsPersistDenylist } from 'features/dynamicPrompts/store/dynamicPromptsPersistDenylist'; import { dynamicPromptsPersistDenylist } from 'features/dynamicPrompts/store/dynamicPromptsPersistDenylist';
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist'; import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist'; import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
@ -20,7 +20,7 @@ const serializationDenylist: {
postprocessing: postprocessingPersistDenylist, postprocessing: postprocessingPersistDenylist,
system: systemPersistDenylist, system: systemPersistDenylist,
ui: uiPersistDenylist, ui: uiPersistDenylist,
controlNet: controlNetDenylist, controlNet: controlAdaptersPersistDenylist,
dynamicPrompts: dynamicPromptsPersistDenylist, dynamicPrompts: dynamicPromptsPersistDenylist,
}; };

View File

@ -1,5 +1,5 @@
import { initialCanvasState } from 'features/canvas/store/canvasSlice'; import { initialCanvasState } from 'features/canvas/store/canvasSlice';
import { initialControlNetState } from 'features/controlNet/store/controlNetSlice'; import { initialControlAdapterState } from 'features/controlNet/store/controlAdaptersSlice';
import { initialDynamicPromptsState } from 'features/dynamicPrompts/store/dynamicPromptsSlice'; import { initialDynamicPromptsState } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { initialGalleryState } from 'features/gallery/store/gallerySlice'; import { initialGalleryState } from 'features/gallery/store/gallerySlice';
import { initialNodesState } from 'features/nodes/store/nodesSlice'; import { initialNodesState } from 'features/nodes/store/nodesSlice';
@ -25,7 +25,7 @@ const initialStates: {
config: initialConfigState, config: initialConfigState,
ui: initialUIState, ui: initialUIState,
hotkeys: initialHotkeysState, hotkeys: initialHotkeysState,
controlNet: initialControlNetState, controlAdapters: initialControlAdapterState,
dynamicPrompts: initialDynamicPromptsState, dynamicPrompts: initialDynamicPromptsState,
sdxl: initialSDXLState, sdxl: initialSDXLState,
}; };

View File

@ -1,8 +1,5 @@
import { resetCanvas } from 'features/canvas/store/canvasSlice'; import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { import { controlAdaptersReset } from 'features/controlNet/store/controlAdaptersSlice';
controlNetReset,
ipAdapterStateReset,
} from 'features/controlNet/store/controlNetSlice';
import { getImageUsage } from 'features/deleteImageModal/store/selectors'; import { getImageUsage } from 'features/deleteImageModal/store/selectors';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice'; import { clearInitialImage } from 'features/parameters/store/generationSlice';
@ -20,8 +17,7 @@ export const addDeleteBoardAndImagesFulfilledListener = () => {
let wasInitialImageReset = false; let wasInitialImageReset = false;
let wasCanvasReset = false; let wasCanvasReset = false;
let wasNodeEditorReset = false; let wasNodeEditorReset = false;
let wasControlNetReset = false; let wereControlAdaptersReset = false;
let wasIPAdapterReset = false;
const state = getState(); const state = getState();
deleted_images.forEach((image_name) => { deleted_images.forEach((image_name) => {
@ -42,14 +38,9 @@ export const addDeleteBoardAndImagesFulfilledListener = () => {
wasNodeEditorReset = true; wasNodeEditorReset = true;
} }
if (imageUsage.isControlNetImage && !wasControlNetReset) { if (imageUsage.isControlImage && !wereControlAdaptersReset) {
dispatch(controlNetReset()); dispatch(controlAdaptersReset());
wasControlNetReset = true; wereControlAdaptersReset = true;
}
if (imageUsage.isIPAdapterImage && !wasIPAdapterReset) {
dispatch(ipAdapterStateReset());
wasIPAdapterReset = true;
} }
}); });
}, },

View File

@ -1,20 +1,21 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { canvasImageToControlNet } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; import { controlAdapterImageChanged } from 'features/controlNet/store/controlAdaptersSlice';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next'; import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { canvasImageToControlAdapter } from 'features/canvas/store/actions';
export const addCanvasImageToControlNetListener = () => { export const addCanvasImageToControlNetListener = () => {
startAppListening({ startAppListening({
actionCreator: canvasImageToControlNet, actionCreator: canvasImageToControlAdapter,
effect: async (action, { dispatch, getState }) => { effect: async (action, { dispatch, getState }) => {
const log = logger('canvas'); const log = logger('canvas');
const state = getState(); const state = getState();
const { id } = action.payload;
let blob; let blob: Blob;
try { try {
blob = await getBaseLayerBlob(state, true); blob = await getBaseLayerBlob(state, true);
} catch (err) { } catch (err) {
@ -50,8 +51,8 @@ export const addCanvasImageToControlNetListener = () => {
const { image_name } = imageDTO; const { image_name } = imageDTO;
dispatch( dispatch(
controlNetImageChanged({ controlAdapterImageChanged({
controlNetId: action.payload.controlNet.controlNetId, id,
controlImage: image_name, controlImage: image_name,
}) })
); );

View File

@ -1,7 +1,7 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { canvasMaskToControlNet } from 'features/canvas/store/actions'; import { canvasMaskToControlAdapter } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData'; import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; import { controlAdapterImageChanged } from 'features/controlNet/store/controlAdaptersSlice';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next'; import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
@ -9,11 +9,11 @@ import { startAppListening } from '..';
export const addCanvasMaskToControlNetListener = () => { export const addCanvasMaskToControlNetListener = () => {
startAppListening({ startAppListening({
actionCreator: canvasMaskToControlNet, actionCreator: canvasMaskToControlAdapter,
effect: async (action, { dispatch, getState }) => { effect: async (action, { dispatch, getState }) => {
const log = logger('canvas'); const log = logger('canvas');
const state = getState(); const state = getState();
const { id } = action.payload;
const canvasBlobsAndImageData = await getCanvasData( const canvasBlobsAndImageData = await getCanvasData(
state.canvas.layerState, state.canvas.layerState,
state.canvas.boundingBoxCoordinates, state.canvas.boundingBoxCoordinates,
@ -61,8 +61,8 @@ export const addCanvasMaskToControlNetListener = () => {
const { image_name } = imageDTO; const { image_name } = imageDTO;
dispatch( dispatch(
controlNetImageChanged({ controlAdapterImageChanged({
controlNetId: action.payload.controlNet.controlNetId, id,
controlImage: image_name, controlImage: image_name,
}) })
); );

View File

@ -1,15 +1,24 @@
import { AnyListenerPredicate } from '@reduxjs/toolkit'; import { AnyListenerPredicate } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { controlNetImageProcessed } from 'features/controlNet/store/actions'; import { controlAdapterImageProcessed } from 'features/controlNet/store/actions';
import { import {
controlNetAutoConfigToggled, controlAdapterAutoConfigToggled,
controlNetImageChanged, controlAdapterImageChanged,
controlNetModelChanged, controlAdapterModelChanged,
controlNetProcessorParamsChanged, controlAdapterProcessorParamsChanged,
controlNetProcessorTypeChanged, controlAdapterProcessortTypeChanged,
} from 'features/controlNet/store/controlNetSlice'; selectControlAdapterById,
} from 'features/controlNet/store/controlAdaptersSlice';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { isControlNetOrT2IAdapter } from 'features/controlNet/store/types';
type AnyControlAdapterParamChangeAction =
| ReturnType<typeof controlAdapterProcessorParamsChanged>
| ReturnType<typeof controlAdapterModelChanged>
| ReturnType<typeof controlAdapterImageChanged>
| ReturnType<typeof controlAdapterProcessortTypeChanged>
| ReturnType<typeof controlAdapterAutoConfigToggled>;
const predicate: AnyListenerPredicate<RootState> = ( const predicate: AnyListenerPredicate<RootState> = (
action, action,
@ -17,35 +26,31 @@ const predicate: AnyListenerPredicate<RootState> = (
prevState prevState
) => { ) => {
const isActionMatched = const isActionMatched =
controlNetProcessorParamsChanged.match(action) || controlAdapterProcessorParamsChanged.match(action) ||
controlNetModelChanged.match(action) || controlAdapterModelChanged.match(action) ||
controlNetImageChanged.match(action) || controlAdapterImageChanged.match(action) ||
controlNetProcessorTypeChanged.match(action) || controlAdapterProcessortTypeChanged.match(action) ||
controlNetAutoConfigToggled.match(action); controlAdapterAutoConfigToggled.match(action);
if (!isActionMatched) { if (!isActionMatched) {
return false; return false;
} }
if (controlNetAutoConfigToggled.match(action)) { const { id } = action.payload;
const ca = selectControlAdapterById(prevState.controlAdapters, id);
if (!ca || !isControlNetOrT2IAdapter(ca)) {
return false;
}
if (controlAdapterAutoConfigToggled.match(action)) {
// do not process if the user just disabled auto-config // do not process if the user just disabled auto-config
if ( if (ca.shouldAutoConfig === true) {
prevState.controlNet.controlNets[action.payload.controlNetId]
?.shouldAutoConfig === true
) {
return false; return false;
} }
} }
const cn = state.controlNet.controlNets[action.payload.controlNetId]; const { controlImage, processorType, shouldAutoConfig } = ca;
if (controlAdapterModelChanged.match(action) && !shouldAutoConfig) {
if (!cn) {
// something is wrong, the controlNet should exist
return false;
}
const { controlImage, processorType, shouldAutoConfig } = cn;
if (controlNetModelChanged.match(action) && !shouldAutoConfig) {
// do not process if the action is a model change but the processor settings are dirty // do not process if the action is a model change but the processor settings are dirty
return false; return false;
} }
@ -67,7 +72,7 @@ export const addControlNetAutoProcessListener = () => {
predicate, predicate,
effect: async (action, { dispatch, cancelActiveListeners, delay }) => { effect: async (action, { dispatch, cancelActiveListeners, delay }) => {
const log = logger('session'); const log = logger('session');
const { controlNetId } = action.payload; const { id } = (action as AnyControlAdapterParamChangeAction).payload;
// Cancel any in-progress instances of this listener // Cancel any in-progress instances of this listener
cancelActiveListeners(); cancelActiveListeners();
@ -75,7 +80,7 @@ export const addControlNetAutoProcessListener = () => {
// Delay before starting actual work // Delay before starting actual work
await delay(300); await delay(300);
dispatch(controlNetImageProcessed({ controlNetId })); dispatch(controlAdapterImageProcessed({ id }));
}, },
}); });
}; };

View File

@ -1,11 +1,11 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize'; import { parseify } from 'common/util/serialize';
import { controlNetImageProcessed } from 'features/controlNet/store/actions';
import { import {
clearPendingControlImages, pendingControlImagesCleared,
controlNetImageChanged, controlAdapterImageChanged,
controlNetProcessedImageChanged, selectControlAdapterById,
} from 'features/controlNet/store/controlNetSlice'; controlAdapterProcessedImageChanged,
} from 'features/controlNet/store/controlAdaptersSlice';
import { SAVE_IMAGE } from 'features/nodes/util/graphBuilders/constants'; import { SAVE_IMAGE } from 'features/nodes/util/graphBuilders/constants';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next'; import { t } from 'i18next';
@ -15,16 +15,18 @@ import { isImageOutput } from 'services/api/guards';
import { Graph, ImageDTO } from 'services/api/types'; import { Graph, ImageDTO } from 'services/api/types';
import { socketInvocationComplete } from 'services/events/actions'; import { socketInvocationComplete } from 'services/events/actions';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { controlAdapterImageProcessed } from 'features/controlNet/store/actions';
import { isControlNetOrT2IAdapter } from 'features/controlNet/store/types';
export const addControlNetImageProcessedListener = () => { export const addControlNetImageProcessedListener = () => {
startAppListening({ startAppListening({
actionCreator: controlNetImageProcessed, actionCreator: controlAdapterImageProcessed,
effect: async (action, { dispatch, getState, take }) => { effect: async (action, { dispatch, getState, take }) => {
const log = logger('session'); const log = logger('session');
const { controlNetId } = action.payload; const { id } = action.payload;
const controlNet = getState().controlNet.controlNets[controlNetId]; const ca = selectControlAdapterById(getState().controlAdapters, id);
if (!controlNet?.controlImage) { if (!ca?.controlImage || !isControlNetOrT2IAdapter(ca)) {
log.error('Unable to process ControlNet image'); log.error('Unable to process ControlNet image');
return; return;
} }
@ -33,10 +35,10 @@ export const addControlNetImageProcessedListener = () => {
// Also we need to grab the image. // Also we need to grab the image.
const graph: Graph = { const graph: Graph = {
nodes: { nodes: {
[controlNet.processorNode.id]: { [ca.processorNode.id]: {
...controlNet.processorNode, ...ca.processorNode,
is_intermediate: true, is_intermediate: true,
image: { image_name: controlNet.controlImage }, image: { image_name: ca.controlImage },
}, },
[SAVE_IMAGE]: { [SAVE_IMAGE]: {
id: SAVE_IMAGE, id: SAVE_IMAGE,
@ -48,7 +50,7 @@ export const addControlNetImageProcessedListener = () => {
edges: [ edges: [
{ {
source: { source: {
node_id: controlNet.processorNode.id, node_id: ca.processorNode.id,
field: 'image', field: 'image',
}, },
destination: { destination: {
@ -103,8 +105,8 @@ export const addControlNetImageProcessedListener = () => {
// Update the processed image in the store // Update the processed image in the store
dispatch( dispatch(
controlNetProcessedImageChanged({ controlAdapterProcessedImageChanged({
controlNetId, id,
processedControlImage: processedControlImage.image_name, processedControlImage: processedControlImage.image_name,
}) })
); );
@ -126,10 +128,8 @@ export const addControlNetImageProcessedListener = () => {
duration: 15000, duration: 15000,
}) })
); );
dispatch(clearPendingControlImages()); dispatch(pendingControlImagesCleared());
dispatch( dispatch(controlAdapterImageChanged({ id, controlImage: null }));
controlNetImageChanged({ controlNetId, controlImage: null })
);
return; return;
} }
} }

View File

@ -1,10 +1,10 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { resetCanvas } from 'features/canvas/store/canvasSlice'; import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { import {
controlNetImageChanged, controlAdapterImageChanged,
controlNetProcessedImageChanged, controlAdapterProcessedImageChanged,
ipAdapterImageChanged, selectControlAdapterAll,
} from 'features/controlNet/store/controlNetSlice'; } from 'features/controlNet/store/controlAdaptersSlice';
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions'; import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice'; import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors'; import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
@ -17,6 +17,7 @@ import { api } from 'services/api';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
import { imagesAdapter } from 'services/api/util'; import { imagesAdapter } from 'services/api/util';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { isControlNetOrT2IAdapter } from 'features/controlNet/store/types';
export const addRequestedSingleImageDeletionListener = () => { export const addRequestedSingleImageDeletionListener = () => {
startAppListening({ startAppListening({
@ -90,35 +91,28 @@ export const addRequestedSingleImageDeletionListener = () => {
dispatch(clearInitialImage()); dispatch(clearInitialImage());
} }
// reset controlNets that use the deleted images // reset control adapters that use the deleted images
forEach(getState().controlNet.controlNets, (controlNet) => { forEach(selectControlAdapterAll(getState().controlAdapters), (ca) => {
if ( if (
controlNet.controlImage === imageDTO.image_name || ca.controlImage === imageDTO.image_name ||
controlNet.processedControlImage === imageDTO.image_name (isControlNetOrT2IAdapter(ca) &&
ca.processedControlImage === imageDTO.image_name)
) { ) {
dispatch( dispatch(
controlNetImageChanged({ controlAdapterImageChanged({
controlNetId: controlNet.controlNetId, id: ca.id,
controlImage: null, controlImage: null,
}) })
); );
dispatch( dispatch(
controlNetProcessedImageChanged({ controlAdapterProcessedImageChanged({
controlNetId: controlNet.controlNetId, id: ca.id,
processedControlImage: null, processedControlImage: null,
}) })
); );
} }
}); });
// Remove IP Adapter Set Image if image is deleted.
if (
getState().controlNet.ipAdapterInfo.adapterImage ===
imageDTO.image_name
) {
dispatch(ipAdapterImageChanged(null));
}
// reset nodes that use the deleted images // reset nodes that use the deleted images
getState().nodes.nodes.forEach((node) => { getState().nodes.nodes.forEach((node) => {
if (!isInvocationNode(node)) { if (!isInvocationNode(node)) {
@ -215,35 +209,28 @@ export const addRequestedMultipleImageDeletionListener = () => {
dispatch(clearInitialImage()); dispatch(clearInitialImage());
} }
// reset controlNets that use the deleted images // reset control adapters that use the deleted images
forEach(getState().controlNet.controlNets, (controlNet) => { forEach(selectControlAdapterAll(getState().controlAdapters), (ca) => {
if ( if (
controlNet.controlImage === imageDTO.image_name || ca.controlImage === imageDTO.image_name ||
controlNet.processedControlImage === imageDTO.image_name (isControlNetOrT2IAdapter(ca) &&
ca.processedControlImage === imageDTO.image_name)
) { ) {
dispatch( dispatch(
controlNetImageChanged({ controlAdapterImageChanged({
controlNetId: controlNet.controlNetId, id: ca.id,
controlImage: null, controlImage: null,
}) })
); );
dispatch( dispatch(
controlNetProcessedImageChanged({ controlAdapterProcessedImageChanged({
controlNetId: controlNet.controlNetId, id: ca.id,
processedControlImage: null, processedControlImage: null,
}) })
); );
} }
}); });
// Remove IP Adapter Set Image if image is deleted.
if (
getState().controlNet.ipAdapterInfo.adapterImage ===
imageDTO.image_name
) {
dispatch(ipAdapterImageChanged(null));
}
// reset nodes that use the deleted images // reset nodes that use the deleted images
getState().nodes.nodes.forEach((node) => { getState().nodes.nodes.forEach((node) => {
if (!isInvocationNode(node)) { if (!isInvocationNode(node)) {

View File

@ -3,11 +3,9 @@ import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize'; import { parseify } from 'common/util/serialize';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { import {
controlNetImageChanged, controlAdapterImageChanged,
controlNetIsEnabledChanged, controlAdapterIsEnabledChanged,
ipAdapterImageChanged, } from 'features/controlNet/store/controlAdaptersSlice';
isIPAdapterEnabledChanged,
} from 'features/controlNet/store/controlNetSlice';
import { import {
TypesafeDraggableData, TypesafeDraggableData,
TypesafeDroppableData, TypesafeDroppableData,
@ -90,39 +88,26 @@ export const addImageDroppedListener = () => {
* Image dropped on ControlNet * Image dropped on ControlNet
*/ */
if ( if (
overData.actionType === 'SET_CONTROLNET_IMAGE' && overData.actionType === 'SET_CONTROL_ADAPTER_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' && activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO activeData.payload.imageDTO
) { ) {
const { controlNetId } = overData.context; const { id } = overData.context;
dispatch( dispatch(
controlNetImageChanged({ controlAdapterImageChanged({
id,
controlImage: activeData.payload.imageDTO.image_name, controlImage: activeData.payload.imageDTO.image_name,
controlNetId,
}) })
); );
dispatch( dispatch(
controlNetIsEnabledChanged({ controlAdapterIsEnabledChanged({
controlNetId, id,
isEnabled: true, isEnabled: true,
}) })
); );
return; return;
} }
/**
* Image dropped on IP Adapter image
*/
if (
overData.actionType === 'SET_IP_ADAPTER_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
dispatch(ipAdapterImageChanged(activeData.payload.imageDTO.image_name));
dispatch(isIPAdapterEnabledChanged(true));
return;
}
/** /**
* Image dropped on Canvas * Image dropped on Canvas
*/ */

View File

@ -18,8 +18,7 @@ export const addImageToDeleteSelectedListener = () => {
const isImageInUse = const isImageInUse =
imagesUsage.some((i) => i.isCanvasImage) || imagesUsage.some((i) => i.isCanvasImage) ||
imagesUsage.some((i) => i.isInitialImage) || imagesUsage.some((i) => i.isInitialImage) ||
imagesUsage.some((i) => i.isControlNetImage) || imagesUsage.some((i) => i.isControlImage) ||
imagesUsage.some((i) => i.isIPAdapterImage) ||
imagesUsage.some((i) => i.isNodesImage); imagesUsage.some((i) => i.isNodesImage);
if (shouldConfirmOnDelete || isImageInUse) { if (shouldConfirmOnDelete || isImageInUse) {

View File

@ -2,11 +2,9 @@ import { UseToastOptions } from '@chakra-ui/react';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { import {
controlNetImageChanged, controlAdapterImageChanged,
controlNetIsEnabledChanged, controlAdapterIsEnabledChanged,
ipAdapterImageChanged, } from 'features/controlNet/store/controlAdaptersSlice';
isIPAdapterEnabledChanged,
} from 'features/controlNet/store/controlNetSlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
@ -87,17 +85,17 @@ export const addImageUploadedFulfilledListener = () => {
return; return;
} }
if (postUploadAction?.type === 'SET_CONTROLNET_IMAGE') { if (postUploadAction?.type === 'SET_CONTROL_ADAPTER_IMAGE') {
const { controlNetId } = postUploadAction; const { id } = postUploadAction;
dispatch( dispatch(
controlNetIsEnabledChanged({ controlAdapterIsEnabledChanged({
controlNetId, id,
isEnabled: true, isEnabled: true,
}) })
); );
dispatch( dispatch(
controlNetImageChanged({ controlAdapterImageChanged({
controlNetId, id,
controlImage: imageDTO.image_name, controlImage: imageDTO.image_name,
}) })
); );
@ -110,18 +108,6 @@ export const addImageUploadedFulfilledListener = () => {
return; return;
} }
if (postUploadAction?.type === 'SET_IP_ADAPTER_IMAGE') {
dispatch(ipAdapterImageChanged(imageDTO.image_name));
dispatch(isIPAdapterEnabledChanged(true));
dispatch(
addToast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setIPAdapterImage'),
})
);
return;
}
if (postUploadAction?.type === 'SET_INITIAL_IMAGE') { if (postUploadAction?.type === 'SET_INITIAL_IMAGE') {
dispatch(initialImageChanged(imageDTO)); dispatch(initialImageChanged(imageDTO));
dispatch( dispatch(

View File

@ -1,9 +1,9 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice'; import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
import { import {
controlNetRemoved, controlAdapterRemoved,
ipAdapterStateReset, selectControlAdapterAll,
} from 'features/controlNet/store/controlNetSlice'; } from 'features/controlNet/store/controlAdaptersSlice';
import { loraRemoved } from 'features/lora/store/loraSlice'; import { loraRemoved } from 'features/lora/store/loraSlice';
import { modelSelected } from 'features/parameters/store/actions'; import { modelSelected } from 'features/parameters/store/actions';
import { import {
@ -60,24 +60,13 @@ export const addModelSelectedListener = () => {
} }
// handle incompatible controlnets // handle incompatible controlnets
const { controlNets } = state.controlNet; selectControlAdapterAll(state.controlAdapters).forEach((ca) => {
forEach(controlNets, (controlNet, controlNetId) => { if (ca.model?.base_model !== base_model) {
if (controlNet.model?.base_model !== base_model) { dispatch(controlAdapterRemoved({ id: ca.id }));
dispatch(controlNetRemoved({ controlNetId }));
modelsCleared += 1; modelsCleared += 1;
} }
}); });
// handle incompatible IP-Adapter
const { ipAdapterInfo } = state.controlNet;
if (
ipAdapterInfo.model &&
ipAdapterInfo.model.base_model !== base_model
) {
dispatch(ipAdapterStateReset());
modelsCleared += 1;
}
if (modelsCleared > 0) { if (modelsCleared > 0) {
dispatch( dispatch(
addToast( addToast(

View File

@ -1,8 +1,10 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { import {
controlNetRemoved, controlAdapterModelCleared,
ipAdapterModelChanged, selectAllControlNets,
} from 'features/controlNet/store/controlNetSlice'; selectAllIPAdapters,
selectAllT2IAdapters,
} from 'features/controlNet/store/controlAdaptersSlice';
import { loraRemoved } from 'features/lora/store/loraSlice'; import { loraRemoved } from 'features/lora/store/loraSlice';
import { import {
modelChanged, modelChanged,
@ -19,14 +21,12 @@ import {
} from 'features/sdxl/store/sdxlSlice'; } from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es'; import { forEach, some } from 'lodash-es';
import { import {
ipAdapterModelsAdapter,
mainModelsAdapter, mainModelsAdapter,
modelsApi, modelsApi,
vaeModelsAdapter, vaeModelsAdapter,
} from 'services/api/endpoints/models'; } from 'services/api/endpoints/models';
import { TypeGuardFor } from 'services/api/types'; import { TypeGuardFor } from 'services/api/types';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { zIPAdapterModel } from 'features/nodes/types/types';
export const addModelsLoadedListener = () => { export const addModelsLoadedListener = () => {
startAppListening({ startAppListening({
@ -221,21 +221,45 @@ export const addModelsLoadedListener = () => {
`ControlNet models loaded (${action.payload.ids.length})` `ControlNet models loaded (${action.payload.ids.length})`
); );
const controlNets = getState().controlNet.controlNets; selectAllControlNets(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(
forEach(controlNets, (controlNet, controlNetId) => {
const isControlNetAvailable = some(
action.payload.entities, action.payload.entities,
(m) => (m) =>
m?.model_name === controlNet?.model?.model_name && m?.model_name === ca?.model?.model_name &&
m?.base_model === controlNet?.model?.base_model m?.base_model === ca?.model?.base_model
); );
if (isControlNetAvailable) { if (isModelAvailable) {
return; return;
} }
dispatch(controlNetRemoved({ controlNetId })); dispatch(controlAdapterModelCleared({ id: ca.id }));
});
},
});
startAppListening({
matcher: modelsApi.endpoints.getT2IAdapterModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// ControlNet models loaded - need to remove missing ControlNets from state
const log = logger('models');
log.info(
{ models: action.payload.entities },
`ControlNet models loaded (${action.payload.ids.length})`
);
selectAllT2IAdapters(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === ca?.model?.model_name &&
m?.base_model === ca?.model?.base_model
);
if (isModelAvailable) {
return;
}
dispatch(controlAdapterModelCleared({ id: ca.id }));
}); });
}, },
}); });
@ -249,38 +273,20 @@ export const addModelsLoadedListener = () => {
`IP Adapter models loaded (${action.payload.ids.length})` `IP Adapter models loaded (${action.payload.ids.length})`
); );
const { model } = getState().controlNet.ipAdapterInfo; selectAllIPAdapters(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some( const isModelAvailable = some(
action.payload.entities, action.payload.entities,
(m) => (m) =>
m?.model_name === model?.model_name && m?.model_name === ca?.model?.model_name &&
m?.base_model === model?.base_model m?.base_model === ca?.model?.base_model
); );
if (isModelAvailable) { if (isModelAvailable) {
return; return;
} }
const firstModel = ipAdapterModelsAdapter dispatch(controlAdapterModelCleared({ id: ca.id }));
.getSelectors() });
.selectAll(action.payload)[0];
if (!firstModel) {
dispatch(ipAdapterModelChanged(null));
}
const result = zIPAdapterModel.safeParse(firstModel);
if (!result.success) {
log.error(
{ error: result.error.format() },
'Failed to parse IP Adapter model'
);
return;
}
dispatch(ipAdapterModelChanged(result.data));
}, },
}); });
startAppListening({ startAppListening({

View File

@ -7,7 +7,7 @@ import {
} from '@reduxjs/toolkit'; } from '@reduxjs/toolkit';
import canvasReducer from 'features/canvas/store/canvasSlice'; import canvasReducer from 'features/canvas/store/canvasSlice';
import changeBoardModalReducer from 'features/changeBoardModal/store/slice'; import changeBoardModalReducer from 'features/changeBoardModal/store/slice';
import controlNetReducer from 'features/controlNet/store/controlNetSlice'; import controlAdaptersReducer from 'features/controlNet/store/controlAdaptersSlice';
import deleteImageModalReducer from 'features/deleteImageModal/store/slice'; import deleteImageModalReducer from 'features/deleteImageModal/store/slice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/dynamicPromptsSlice'; import dynamicPromptsReducer from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import galleryReducer from 'features/gallery/store/gallerySlice'; import galleryReducer from 'features/gallery/store/gallerySlice';
@ -44,7 +44,7 @@ const allReducers = {
config: configReducer, config: configReducer,
ui: uiReducer, ui: uiReducer,
hotkeys: hotkeysReducer, hotkeys: hotkeysReducer,
controlNet: controlNetReducer, controlAdapters: controlAdaptersReducer,
dynamicPrompts: dynamicPromptsReducer, dynamicPrompts: dynamicPromptsReducer,
deleteImageModal: deleteImageModalReducer, deleteImageModal: deleteImageModalReducer,
changeBoardModal: changeBoardModalReducer, changeBoardModal: changeBoardModalReducer,
@ -68,7 +68,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'postprocessing', 'postprocessing',
'system', 'system',
'ui', 'ui',
'controlNet', 'controlAdapters',
'dynamicPrompts', 'dynamicPrompts',
'lora', 'lora',
'modelmanager', 'modelmanager',

View File

@ -15,7 +15,7 @@ type UseImageUploadButtonArgs = {
* @example * @example
* const { getUploadButtonProps, getUploadInputProps, openUploader } = useImageUploadButton({ * const { getUploadButtonProps, getUploadInputProps, openUploader } = useImageUploadButton({
* postUploadAction: { * postUploadAction: {
* type: 'SET_CONTROLNET_IMAGE', * type: 'SET_CONTROL_ADAPTER_IMAGE',
* controlNetId: '12345', * controlNetId: '12345',
* }, * },
* isDisabled: getIsUploadDisabled(), * isDisabled: getIsUploadDisabled(),

View File

@ -2,16 +2,18 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { selectControlAdapterAll } from 'features/controlNet/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlNet/store/types';
import { isInvocationNode } from 'features/nodes/types/types'; import { isInvocationNode } from 'features/nodes/types/types';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import i18n from 'i18next'; import i18n from 'i18next';
import { forEach, map } from 'lodash-es'; import { forEach } from 'lodash-es';
import { getConnectedEdges } from 'reactflow'; import { getConnectedEdges } from 'reactflow';
const selector = createSelector( const selector = createSelector(
[stateSelector, activeTabNameSelector], [stateSelector, activeTabNameSelector],
( (
{ controlNet, generation, system, nodes, dynamicPrompts }, { controlAdapters, generation, system, nodes, dynamicPrompts },
activeTabName activeTabName
) => { ) => {
const { initialImage, model } = generation; const { initialImage, model } = generation;
@ -87,21 +89,21 @@ const selector = createSelector(
reasons.push(i18n.t('parameters.invoke.noModelSelected')); reasons.push(i18n.t('parameters.invoke.noModelSelected'));
} }
if (controlNet.isEnabled) { selectControlAdapterAll(controlAdapters).forEach((ca, i) => {
map(controlNet.controlNets).forEach((controlNet, i) => { if (!ca.isEnabled) {
if (!controlNet.isEnabled) {
return; return;
} }
if (!controlNet.model) { if (!ca.model) {
reasons.push( reasons.push(
i18n.t('parameters.invoke.noModelForControlNet', { index: i + 1 }) i18n.t('parameters.invoke.noModelForControlNet', { index: i + 1 })
); );
} }
if ( if (
!controlNet.controlImage || !ca.controlImage ||
(!controlNet.processedControlImage && (isControlNetOrT2IAdapter(ca) &&
controlNet.processorType !== 'none') !ca.processedControlImage &&
ca.processorType !== 'none')
) { ) {
reasons.push( reasons.push(
i18n.t('parameters.invoke.noControlImageForControlNet', { i18n.t('parameters.invoke.noControlImageForControlNet', {
@ -111,7 +113,6 @@ const selector = createSelector(
} }
}); });
} }
}
return { isReady: !reasons.length, reasons }; return { isReady: !reasons.length, reasons };
}, },

View File

@ -1,5 +1,4 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery'); export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery');
@ -22,10 +21,10 @@ export const stagingAreaImageSaved = createAction<{ imageDTO: ImageDTO }>(
'canvas/stagingAreaImageSaved' 'canvas/stagingAreaImageSaved'
); );
export const canvasMaskToControlNet = createAction<{ export const canvasMaskToControlAdapter = createAction<{ id: string }>(
controlNet: ControlNetConfig; 'canvas/canvasMaskToControlAdapter'
}>('canvas/canvasMaskToControlNet'); );
export const canvasImageToControlNet = createAction<{ export const canvasImageToControlAdapter = createAction<{ id: string }>(
controlNet: ControlNetConfig; 'canvas/canvasImageToControlAdapter'
}>('canvas/canvasImageToControlNet'); );

View File

@ -3,11 +3,11 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { FaCopy, FaTrash } from 'react-icons/fa'; import { FaCopy, FaTrash } from 'react-icons/fa';
import { import {
ControlNetConfig, controlAdapterDuplicated,
controlNetDuplicated, controlAdapterIsEnabledChanged,
controlNetRemoved, controlAdapterRemoved,
controlNetIsEnabledChanged, selectControlAdapterById,
} from '../store/controlNetSlice'; } from '../store/controlAdaptersSlice';
import ParamControlNetModel from './parameters/ParamControlNetModel'; import ParamControlNetModel from './parameters/ParamControlNetModel';
import ParamControlNetWeight from './parameters/ParamControlNetWeight'; import ParamControlNetWeight from './parameters/ParamControlNetWeight';
@ -19,8 +19,7 @@ import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useToggle } from 'react-use'; import { isControlNetOrT2IAdapter } from '../store/types';
import { v4 as uuidv4 } from 'uuid';
import ControlNetImagePreview from './ControlNetImagePreview'; import ControlNetImagePreview from './ControlNetImagePreview';
import ControlNetProcessorComponent from './ControlNetProcessorComponent'; import ControlNetProcessorComponent from './ControlNetProcessorComponent';
import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig'; import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
@ -29,14 +28,12 @@ import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode'; import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect'; import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
import ParamControlNetResizeMode from './parameters/ParamControlNetResizeMode'; import ParamControlNetResizeMode from './parameters/ParamControlNetResizeMode';
import { useToggle } from 'react-use';
import { useControlAdapterType } from '../hooks/useControlAdapterType';
type ControlNetProps = { const ControlNet = (props: { id: string }) => {
controlNet: ControlNetConfig; const { id } = props;
}; const controlAdapterType = useControlAdapterType(id);
const ControlNet = (props: ControlNetProps) => {
const { controlNet } = props;
const { controlNetId } = controlNet;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
@ -44,8 +41,8 @@ const ControlNet = (props: ControlNetProps) => {
const selector = createSelector( const selector = createSelector(
stateSelector, stateSelector,
({ controlNet }) => { ({ controlAdapters }) => {
const cn = controlNet.controlNets[controlNetId]; const cn = selectControlAdapterById(controlAdapters, id);
if (!cn) { if (!cn) {
return { return {
@ -54,9 +51,15 @@ const ControlNet = (props: ControlNetProps) => {
}; };
} }
const { isEnabled, shouldAutoConfig } = cn; const isEnabled = cn.isEnabled;
const shouldAutoConfig = isControlNetOrT2IAdapter(cn)
? cn.shouldAutoConfig
: false;
return { isEnabled, shouldAutoConfig }; return {
isEnabled,
shouldAutoConfig,
};
}, },
defaultSelectorOptions defaultSelectorOptions
); );
@ -65,30 +68,29 @@ const ControlNet = (props: ControlNetProps) => {
const [isExpanded, toggleIsExpanded] = useToggle(false); const [isExpanded, toggleIsExpanded] = useToggle(false);
const handleDelete = useCallback(() => { const handleDelete = useCallback(() => {
dispatch(controlNetRemoved({ controlNetId })); dispatch(controlAdapterRemoved({ id }));
}, [controlNetId, dispatch]); }, [id, dispatch]);
const handleDuplicate = useCallback(() => { const handleDuplicate = useCallback(() => {
dispatch( dispatch(controlAdapterDuplicated(id));
controlNetDuplicated({ }, [id, dispatch]);
sourceControlNetId: controlNetId,
newControlNetId: uuidv4(),
})
);
}, [controlNetId, dispatch]);
const handleToggleIsEnabled = useCallback( const handleToggleIsEnabled = useCallback(
(e: ChangeEvent<HTMLInputElement>) => { (e: ChangeEvent<HTMLInputElement>) => {
dispatch( dispatch(
controlNetIsEnabledChanged({ controlAdapterIsEnabledChanged({
controlNetId, id,
isEnabled: e.target.checked, isEnabled: e.target.checked,
}) })
); );
}, },
[controlNetId, dispatch] [id, dispatch]
); );
if (!controlAdapterType) {
return null;
}
return ( return (
<Flex <Flex
sx={{ sx={{
@ -120,10 +122,10 @@ const ControlNet = (props: ControlNetProps) => {
transitionDuration: '0.1s', transitionDuration: '0.1s',
}} }}
> >
<ParamControlNetModel controlNet={controlNet} /> <ParamControlNetModel id={id} />
</Box> </Box>
{activeTabName === 'unifiedCanvas' && ( {activeTabName === 'unifiedCanvas' && (
<ControlNetCanvasImageImports controlNet={controlNet} /> <ControlNetCanvasImageImports id={id} />
)} )}
<IAIIconButton <IAIIconButton
size="sm" size="sm"
@ -207,8 +209,8 @@ const ControlNet = (props: ControlNetProps) => {
justifyContent: 'space-between', justifyContent: 'space-between',
}} }}
> >
<ParamControlNetWeight controlNet={controlNet} /> <ParamControlNetWeight id={id} />
<ParamControlNetBeginEnd controlNet={controlNet} /> <ParamControlNetBeginEnd id={id} />
</Flex> </Flex>
{!isExpanded && ( {!isExpanded && (
<Flex <Flex
@ -220,22 +222,22 @@ const ControlNet = (props: ControlNetProps) => {
aspectRatio: '1/1', aspectRatio: '1/1',
}} }}
> >
<ControlNetImagePreview controlNet={controlNet} isSmall /> <ControlNetImagePreview id={id} isSmall />
</Flex> </Flex>
)} )}
</Flex> </Flex>
<Flex sx={{ gap: 2 }}> <Flex sx={{ gap: 2 }}>
<ParamControlNetControlMode controlNet={controlNet} /> <ParamControlNetControlMode id={id} />
<ParamControlNetResizeMode controlNet={controlNet} /> <ParamControlNetResizeMode id={id} />
</Flex> </Flex>
<ParamControlNetProcessorSelect controlNet={controlNet} /> <ParamControlNetProcessorSelect id={id} />
</Flex> </Flex>
{isExpanded && ( {isExpanded && (
<> <>
<ControlNetImagePreview controlNet={controlNet} /> <ControlNetImagePreview id={id} />
<ParamControlNetShouldAutoConfig controlNet={controlNet} /> <ParamControlNetShouldAutoConfig id={id} />
<ControlNetProcessorComponent controlNet={controlNet} /> <ControlNetProcessorComponent id={id} />
</> </>
)} )}
</Flex> </Flex>

View File

@ -23,20 +23,20 @@ import {
} from 'services/api/endpoints/images'; } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/types'; import { PostUploadAction } from 'services/api/types';
import IAIDndImageIcon from '../../../common/components/IAIDndImageIcon'; import IAIDndImageIcon from '../../../common/components/IAIDndImageIcon';
import { import { controlAdapterImageChanged } from '../store/controlAdaptersSlice';
ControlNetConfig, import { useControlAdapterControlImage } from '../hooks/useControlAdapterControlImage';
controlNetImageChanged, import { useControlAdapterProcessedControlImage } from '../hooks/useControlAdapterProcessedControlImage';
} from '../store/controlNetSlice'; import { useControlAdapterProcessorType } from '../hooks/useControlAdapterProcessorType';
type Props = { type Props = {
controlNet: ControlNetConfig; id: string;
isSmall?: boolean; isSmall?: boolean;
}; };
const selector = createSelector( const selector = createSelector(
stateSelector, stateSelector,
({ controlNet, gallery }) => { ({ controlAdapters, gallery }) => {
const { pendingControlImages } = controlNet; const { pendingControlImages } = controlAdapters;
const { autoAddBoardId } = gallery; const { autoAddBoardId } = gallery;
return { return {
@ -47,13 +47,10 @@ const selector = createSelector(
defaultSelectorOptions defaultSelectorOptions
); );
const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => { const ControlNetImagePreview = ({ isSmall, id }: Props) => {
const { const controlImageName = useControlAdapterControlImage(id);
controlImage: controlImageName, const processedControlImageName = useControlAdapterProcessedControlImage(id);
processedControlImage: processedControlImageName, const processorType = useControlAdapterProcessorType(id);
processorType,
controlNetId,
} = controlNet;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
@ -75,8 +72,8 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
const [addToBoard] = useAddImageToBoardMutation(); const [addToBoard] = useAddImageToBoardMutation();
const [removeFromBoard] = useRemoveImageFromBoardMutation(); const [removeFromBoard] = useRemoveImageFromBoardMutation();
const handleResetControlImage = useCallback(() => { const handleResetControlImage = useCallback(() => {
dispatch(controlNetImageChanged({ controlNetId, controlImage: null })); dispatch(controlAdapterImageChanged({ id, controlImage: null }));
}, [controlNetId, dispatch]); }, [id, dispatch]);
const handleSaveControlImage = useCallback(async () => { const handleSaveControlImage = useCallback(async () => {
if (!processedControlImage) { if (!processedControlImage) {
@ -133,32 +130,32 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => { const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
if (controlImage) { if (controlImage) {
return { return {
id: controlNetId, id,
payloadType: 'IMAGE_DTO', payloadType: 'IMAGE_DTO',
payload: { imageDTO: controlImage }, payload: { imageDTO: controlImage },
}; };
} }
}, [controlImage, controlNetId]); }, [controlImage, id]);
const droppableData = useMemo<TypesafeDroppableData | undefined>( const droppableData = useMemo<TypesafeDroppableData | undefined>(
() => ({ () => ({
id: controlNetId, id,
actionType: 'SET_CONTROLNET_IMAGE', actionType: 'SET_CONTROL_ADAPTER_IMAGE',
context: { controlNetId }, context: { id },
}), }),
[controlNetId] [id]
); );
const postUploadAction = useMemo<PostUploadAction>( const postUploadAction = useMemo<PostUploadAction>(
() => ({ type: 'SET_CONTROLNET_IMAGE', controlNetId }), () => ({ type: 'SET_CONTROL_ADAPTER_IMAGE', id }),
[controlNetId] [id]
); );
const shouldShowProcessedImage = const shouldShowProcessedImage =
controlImage && controlImage &&
processedControlImage && processedControlImage &&
!isMouseOverImage && !isMouseOverImage &&
!pendingControlImages.includes(controlNetId) && !pendingControlImages.includes(id) &&
processorType !== 'none'; processorType !== 'none';
return ( return (
@ -222,7 +219,7 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
/> />
</> </>
{pendingControlImages.includes(controlNetId) && ( {pendingControlImages.includes(id) && (
<Flex <Flex
sx={{ sx={{
position: 'absolute', position: 'absolute',

View File

@ -1,26 +1,26 @@
import IAIButton from 'common/components/IAIButton';
import { memo, useCallback } from 'react';
import { ControlNetConfig } from '../store/controlNetSlice';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { controlNetImageProcessed } from '../store/actions'; import IAIButton from 'common/components/IAIButton';
import { useIsReadyToEnqueue } from 'common/hooks/useIsReadyToEnqueue'; import { useIsReadyToEnqueue } from 'common/hooks/useIsReadyToEnqueue';
import { memo, useCallback } from 'react';
import { useControlAdapterControlImage } from '../hooks/useControlAdapterControlImage';
import { controlAdapterImageProcessed } from '../store/actions';
type Props = { type Props = {
controlNet: ControlNetConfig; id: string;
}; };
const ControlNetPreprocessButton = (props: Props) => { const ControlNetPreprocessButton = ({ id }: Props) => {
const { controlNetId, controlImage } = props.controlNet; const controlImage = useControlAdapterControlImage(id);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const isReady = useIsReadyToEnqueue(); const isReady = useIsReadyToEnqueue();
const handleProcess = useCallback(() => { const handleProcess = useCallback(() => {
dispatch( dispatch(
controlNetImageProcessed({ controlAdapterImageProcessed({
controlNetId, id,
}) })
); );
}, [controlNetId, dispatch]); }, [id, dispatch]);
return ( return (
<IAIButton <IAIButton

View File

@ -1,5 +1,4 @@
import { memo } from 'react'; import { memo } from 'react';
import { ControlNetConfig } from '../store/controlNetSlice';
import CannyProcessor from './processors/CannyProcessor'; import CannyProcessor from './processors/CannyProcessor';
import ColorMapProcessor from './processors/ColorMapProcessor'; import ColorMapProcessor from './processors/ColorMapProcessor';
import ContentShuffleProcessor from './processors/ContentShuffleProcessor'; import ContentShuffleProcessor from './processors/ContentShuffleProcessor';
@ -13,18 +12,25 @@ import NormalBaeProcessor from './processors/NormalBaeProcessor';
import OpenposeProcessor from './processors/OpenposeProcessor'; import OpenposeProcessor from './processors/OpenposeProcessor';
import PidiProcessor from './processors/PidiProcessor'; import PidiProcessor from './processors/PidiProcessor';
import ZoeDepthProcessor from './processors/ZoeDepthProcessor'; import ZoeDepthProcessor from './processors/ZoeDepthProcessor';
import { useControlAdapterIsEnabled } from '../hooks/useControlAdapterIsEnabled';
import { useControlAdapterProcessorNode } from '../hooks/useControlAdapterProcessorNode';
export type ControlNetProcessorProps = { export type Props = {
controlNet: ControlNetConfig; id: string;
}; };
const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => { const ControlNetProcessorComponent = ({ id }: Props) => {
const { controlNetId, isEnabled, processorNode } = props.controlNet; const isEnabled = useControlAdapterIsEnabled(id);
const processorNode = useControlAdapterProcessorNode(id);
if (!processorNode) {
return null;
}
if (processorNode.type === 'canny_image_processor') { if (processorNode.type === 'canny_image_processor') {
return ( return (
<CannyProcessor <CannyProcessor
controlNetId={controlNetId} controlNetId={id}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled} isEnabled={isEnabled}
/> />
@ -34,7 +40,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'color_map_image_processor') { if (processorNode.type === 'color_map_image_processor') {
return ( return (
<ColorMapProcessor <ColorMapProcessor
controlNetId={controlNetId} controlNetId={id}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled} isEnabled={isEnabled}
/> />
@ -44,7 +50,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'hed_image_processor') { if (processorNode.type === 'hed_image_processor') {
return ( return (
<HedProcessor <HedProcessor
controlNetId={controlNetId} controlNetId={id}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled} isEnabled={isEnabled}
/> />
@ -54,7 +60,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'lineart_image_processor') { if (processorNode.type === 'lineart_image_processor') {
return ( return (
<LineartProcessor <LineartProcessor
controlNetId={controlNetId} controlNetId={id}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled} isEnabled={isEnabled}
/> />
@ -64,7 +70,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'content_shuffle_image_processor') { if (processorNode.type === 'content_shuffle_image_processor') {
return ( return (
<ContentShuffleProcessor <ContentShuffleProcessor
controlNetId={controlNetId} controlNetId={id}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled} isEnabled={isEnabled}
/> />
@ -74,7 +80,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'lineart_anime_image_processor') { if (processorNode.type === 'lineart_anime_image_processor') {
return ( return (
<LineartAnimeProcessor <LineartAnimeProcessor
controlNetId={controlNetId} controlNetId={id}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled} isEnabled={isEnabled}
/> />
@ -84,7 +90,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'mediapipe_face_processor') { if (processorNode.type === 'mediapipe_face_processor') {
return ( return (
<MediapipeFaceProcessor <MediapipeFaceProcessor
controlNetId={controlNetId} controlNetId={id}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled} isEnabled={isEnabled}
/> />
@ -94,7 +100,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'midas_depth_image_processor') { if (processorNode.type === 'midas_depth_image_processor') {
return ( return (
<MidasDepthProcessor <MidasDepthProcessor
controlNetId={controlNetId} controlNetId={id}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled} isEnabled={isEnabled}
/> />
@ -104,7 +110,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'mlsd_image_processor') { if (processorNode.type === 'mlsd_image_processor') {
return ( return (
<MlsdImageProcessor <MlsdImageProcessor
controlNetId={controlNetId} controlNetId={id}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled} isEnabled={isEnabled}
/> />
@ -114,7 +120,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'normalbae_image_processor') { if (processorNode.type === 'normalbae_image_processor') {
return ( return (
<NormalBaeProcessor <NormalBaeProcessor
controlNetId={controlNetId} controlNetId={id}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled} isEnabled={isEnabled}
/> />
@ -124,7 +130,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'openpose_image_processor') { if (processorNode.type === 'openpose_image_processor') {
return ( return (
<OpenposeProcessor <OpenposeProcessor
controlNetId={controlNetId} controlNetId={id}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled} isEnabled={isEnabled}
/> />
@ -134,7 +140,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'pidi_image_processor') { if (processorNode.type === 'pidi_image_processor') {
return ( return (
<PidiProcessor <PidiProcessor
controlNetId={controlNetId} controlNetId={id}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled} isEnabled={isEnabled}
/> />
@ -144,7 +150,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'zoe_depth_image_processor') { if (processorNode.type === 'zoe_depth_image_processor') {
return ( return (
<ZoeDepthProcessor <ZoeDepthProcessor
controlNetId={controlNetId} controlNetId={id}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled} isEnabled={isEnabled}
/> />

View File

@ -1,24 +1,24 @@
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { import { controlAdapterAutoConfigToggled } from 'features/controlNet/store/controlAdaptersSlice';
ControlNetConfig,
controlNetAutoConfigToggled,
} from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useControlAdapterIsEnabled } from '../hooks/useControlAdapterIsEnabled';
import { useControlAdapterShouldAutoConfig } from '../hooks/useControlAdapterShouldAutoConfig';
type Props = { type Props = {
controlNet: ControlNetConfig; id: string;
}; };
const ParamControlNetShouldAutoConfig = (props: Props) => { const ParamControlNetShouldAutoConfig = ({ id }: Props) => {
const { controlNetId, isEnabled, shouldAutoConfig } = props.controlNet; const isEnabled = useControlAdapterIsEnabled(id);
const shouldAutoConfig = useControlAdapterShouldAutoConfig(id);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const handleShouldAutoConfigChanged = useCallback(() => { const handleShouldAutoConfigChanged = useCallback(() => {
dispatch(controlNetAutoConfigToggled({ controlNetId })); dispatch(controlAdapterAutoConfigToggled({ id }));
}, [controlNetId, dispatch]); }, [id, dispatch]);
return ( return (
<IAISwitch <IAISwitch

View File

@ -1,16 +1,16 @@
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { controlNetProcessorParamsChanged } from 'features/controlNet/store/controlNetSlice'; import { controlAdapterProcessorParamsChanged } from 'features/controlNet/store/controlAdaptersSlice';
import { ControlNetProcessorNode } from 'features/controlNet/store/types'; import { ControlAdapterProcessorNode } from 'features/controlNet/store/types';
import { useCallback } from 'react'; import { useCallback } from 'react';
export const useProcessorNodeChanged = () => { export const useProcessorNodeChanged = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const handleProcessorNodeChanged = useCallback( const handleProcessorNodeChanged = useCallback(
(controlNetId: string, changes: Partial<ControlNetProcessorNode>) => { (id: string, params: Partial<ControlAdapterProcessorNode>) => {
dispatch( dispatch(
controlNetProcessorParamsChanged({ controlAdapterProcessorParamsChanged({
controlNetId, id,
changes, params,
}) })
); );
}, },

View File

@ -2,32 +2,31 @@ import { Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import { import {
canvasImageToControlNet, canvasImageToControlAdapter,
canvasMaskToControlNet, canvasMaskToControlAdapter,
} from 'features/canvas/store/actions'; } from 'features/canvas/store/actions';
import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { FaImage, FaMask } from 'react-icons/fa'; import { FaImage, FaMask } from 'react-icons/fa';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
type ControlNetCanvasImageImportsProps = { type ControlNetCanvasImageImportsProps = {
controlNet: ControlNetConfig; id: string;
}; };
const ControlNetCanvasImageImports = ( const ControlNetCanvasImageImports = (
props: ControlNetCanvasImageImportsProps props: ControlNetCanvasImageImportsProps
) => { ) => {
const { controlNet } = props; const { id } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const handleImportImageFromCanvas = useCallback(() => { const handleImportImageFromCanvas = useCallback(() => {
dispatch(canvasImageToControlNet({ controlNet })); dispatch(canvasImageToControlAdapter({ id }));
}, [controlNet, dispatch]); }, [id, dispatch]);
const handleImportMaskFromCanvas = useCallback(() => { const handleImportMaskFromCanvas = useCallback(() => {
dispatch(canvasMaskToControlNet({ controlNet })); dispatch(canvasMaskToControlAdapter({ id }));
}, [controlNet, dispatch]); }, [id, dispatch]);
return ( return (
<Flex <Flex

View File

@ -11,44 +11,49 @@ import {
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover'; import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover';
import { useControlAdapterBeginEndStepPct } from 'features/controlNet/hooks/useControlAdapterBeginEndStepPct';
import { useControlAdapterIsEnabled } from 'features/controlNet/hooks/useControlAdapterIsEnabled';
import { import {
ControlNetConfig, controlAdapterBeginStepPctChanged,
controlNetBeginStepPctChanged, controlAdapterEndStepPctChanged,
controlNetEndStepPctChanged, } from 'features/controlNet/store/controlAdaptersSlice';
} from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
type Props = { type Props = {
controlNet: ControlNetConfig; id: string;
}; };
const formatPct = (v: number) => `${Math.round(v * 100)}%`; const formatPct = (v: number) => `${Math.round(v * 100)}%`;
const ParamControlNetBeginEnd = (props: Props) => { const ParamControlNetBeginEnd = ({ id }: Props) => {
const { beginStepPct, endStepPct, isEnabled, controlNetId } = const isEnabled = useControlAdapterIsEnabled(id);
props.controlNet; const stepPcts = useControlAdapterBeginEndStepPct(id);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const handleStepPctChanged = useCallback( const handleStepPctChanged = useCallback(
(v: number[]) => { (v: number[]) => {
dispatch( dispatch(
controlNetBeginStepPctChanged({ controlAdapterBeginStepPctChanged({
controlNetId, id,
beginStepPct: v[0] as number, beginStepPct: v[0] as number,
}) })
); );
dispatch( dispatch(
controlNetEndStepPctChanged({ controlAdapterEndStepPctChanged({
controlNetId, id,
endStepPct: v[1] as number, endStepPct: v[1] as number,
}) })
); );
}, },
[controlNetId, dispatch] [dispatch, id]
); );
if (!stepPcts) {
return null;
}
return ( return (
<IAIInformationalPopover feature="controlNetBeginEnd"> <IAIInformationalPopover feature="controlNetBeginEnd">
<FormControl isDisabled={!isEnabled}> <FormControl isDisabled={!isEnabled}>
@ -56,7 +61,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
<HStack w="100%" gap={2} alignItems="center"> <HStack w="100%" gap={2} alignItems="center">
<RangeSlider <RangeSlider
aria-label={['Begin Step %', 'End Step %!']} aria-label={['Begin Step %', 'End Step %!']}
value={[beginStepPct, endStepPct]} value={[stepPcts.beginStepPct, stepPcts.endStepPct]}
onChange={handleStepPctChanged} onChange={handleStepPctChanged}
min={0} min={0}
max={1} max={1}
@ -67,10 +72,18 @@ const ParamControlNetBeginEnd = (props: Props) => {
<RangeSliderTrack> <RangeSliderTrack>
<RangeSliderFilledTrack /> <RangeSliderFilledTrack />
</RangeSliderTrack> </RangeSliderTrack>
<Tooltip label={formatPct(beginStepPct)} placement="top" hasArrow> <Tooltip
label={formatPct(stepPcts.beginStepPct)}
placement="top"
hasArrow
>
<RangeSliderThumb index={0} /> <RangeSliderThumb index={0} />
</Tooltip> </Tooltip>
<Tooltip label={formatPct(endStepPct)} placement="top" hasArrow> <Tooltip
label={formatPct(stepPcts.endStepPct)}
placement="top"
hasArrow
>
<RangeSliderThumb index={1} /> <RangeSliderThumb index={1} />
</Tooltip> </Tooltip>
<RangeSliderMark <RangeSliderMark

View File

@ -1,22 +1,20 @@
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover'; import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { import { useControlAdapterControlMode } from 'features/controlNet/hooks/useControlAdapterControlMode';
ControlModes, import { useControlAdapterIsEnabled } from 'features/controlNet/hooks/useControlAdapterIsEnabled';
ControlNetConfig, import { controlAdapterControlModeChanged } from 'features/controlNet/store/controlAdaptersSlice';
controlNetControlModeChanged, import { ControlMode } from 'features/controlNet/store/types';
} from 'features/controlNet/store/controlNetSlice';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
type ParamControlNetControlModeProps = { type Props = {
controlNet: ControlNetConfig; id: string;
}; };
export default function ParamControlNetControlMode( export default function ParamControlNetControlMode({ id }: Props) {
props: ParamControlNetControlModeProps const isEnabled = useControlAdapterIsEnabled(id);
) { const controlMode = useControlAdapterControlMode(id);
const { controlMode, isEnabled, controlNetId } = props.controlNet;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
@ -28,19 +26,23 @@ export default function ParamControlNetControlMode(
]; ];
const handleControlModeChange = useCallback( const handleControlModeChange = useCallback(
(controlMode: ControlModes) => { (controlMode: ControlMode) => {
dispatch(controlNetControlModeChanged({ controlNetId, controlMode })); dispatch(controlAdapterControlModeChanged({ id, controlMode }));
}, },
[controlNetId, dispatch] [id, dispatch]
); );
if (!controlMode) {
return null;
}
return ( return (
<IAIInformationalPopover feature="controlNetControlMode"> <IAIInformationalPopover feature="controlNetControlMode">
<IAIMantineSelect <IAIMantineSelect
disabled={!isEnabled} disabled={!isEnabled}
label={t('controlnet.controlMode')} label={t('controlnet.controlMode')}
data={CONTROL_MODE_DATA} data={CONTROL_MODE_DATA}
value={String(controlMode)} value={controlMode}
onChange={handleControlModeChange} onChange={handleControlModeChange}
/> />
</IAIInformationalPopover> </IAIInformationalPopover>

View File

@ -1,23 +1,21 @@
import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { import { useControlAdapterIsEnabled } from 'features/controlNet/hooks/useControlAdapterIsEnabled';
ControlNetConfig, import { useControlAdapterModel } from 'features/controlNet/hooks/useControlAdapterModel';
controlNetModelChanged, import { useControlAdapterModels } from 'features/controlNet/hooks/useControlAdapterModels';
} from 'features/controlNet/store/controlNetSlice'; import { useControlAdapterType } from 'features/controlNet/hooks/useControlAdapterType';
import { controlAdapterModelChanged } from 'features/controlNet/store/controlAdaptersSlice';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam'; import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
type ParamControlNetModelProps = { type ParamControlNetModelProps = {
controlNet: ControlNetConfig; id: string;
}; };
const selector = createSelector( const selector = createSelector(
@ -29,23 +27,31 @@ const selector = createSelector(
defaultSelectorOptions defaultSelectorOptions
); );
const ParamControlNetModel = (props: ParamControlNetModelProps) => { const ParamControlNetModel = ({ id }: ParamControlNetModelProps) => {
const { controlNetId, model: controlNetModel, isEnabled } = props.controlNet; const isEnabled = useControlAdapterIsEnabled(id);
const controlAdapterType = useControlAdapterType(id);
const model = useControlAdapterModel(id);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { mainModel } = useAppSelector(selector); const { mainModel } = useAppSelector(selector);
const { t } = useTranslation(); const { t } = useTranslation();
const { data: controlNetModels } = useGetControlNetModelsQuery(); const models = useControlAdapterModels(controlAdapterType);
const data = useMemo(() => { const data = useMemo(() => {
if (!controlNetModels) { if (!models) {
return []; return [];
} }
const data: SelectItem[] = []; const data: {
value: string;
label: string;
group: string;
disabled: boolean;
tooltip?: string;
}[] = [];
forEach(controlNetModels.entities, (model, id) => { models.forEach((model) => {
if (!model) { if (!model) {
return; return;
} }
@ -53,7 +59,7 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const disabled = model?.base_model !== mainModel?.base_model; const disabled = model?.base_model !== mainModel?.base_model;
data.push({ data.push({
value: id, value: model.id,
label: model.model_name, label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model], group: MODEL_TYPE_MAP[model.base_model],
disabled, disabled,
@ -63,20 +69,24 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
}); });
}); });
data.sort((a, b) =>
// sort 'none' to the top
a.disabled ? 1 : b.disabled ? -1 : a.label.localeCompare(b.label)
);
return data; return data;
}, [controlNetModels, mainModel?.base_model, t]); }, [mainModel?.base_model, models, t]);
// grab the full model entity from the RTK Query cache // grab the full model entity from the RTK Query cache
const selectedModel = useMemo( const selectedModel = useMemo(
() => () =>
controlNetModels?.entities[ models?.find(
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}` (m) =>
] ?? null, m?.id ===
[ `${model?.base_model}/${controlAdapterType}/${model?.model_name}`
controlNetModel?.base_model, // (m) => m?.id === `${model?.base_model}/controlnet/${model?.model_name}`
controlNetModel?.model_name, ) ?? null,
controlNetModels?.entities, [controlAdapterType, model?.base_model, model?.model_name, models]
]
); );
const handleModelChanged = useCallback( const handleModelChanged = useCallback(
@ -91,13 +101,13 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
return; return;
} }
dispatch( dispatch(controlAdapterModelChanged({ id, model: newControlNetModel }));
controlNetModelChanged({ controlNetId, model: newControlNetModel })
);
}, },
[controlNetId, dispatch] [dispatch, id]
); );
console.log(model, selectedModel);
return ( return (
<IAIMantineSearchableSelect <IAIMantineSearchableSelect
itemComponent={IAIMantineSelectItemWithTooltip} itemComponent={IAIMantineSelectItemWithTooltip}

View File

@ -5,19 +5,18 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect, { import IAIMantineSearchableSelect, {
IAISelectDataType, IAISelectDataType,
} from 'common/components/IAIMantineSearchableSelect'; } from 'common/components/IAIMantineSearchableSelect';
import { useControlAdapterIsEnabled } from 'features/controlNet/hooks/useControlAdapterIsEnabled';
import { useControlAdapterProcessorNode } from 'features/controlNet/hooks/useControlAdapterProcessorNode';
import { configSelector } from 'features/system/store/configSelectors'; import { configSelector } from 'features/system/store/configSelectors';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { CONTROLNET_PROCESSORS } from '../../store/constants';
import {
ControlNetConfig,
controlNetProcessorTypeChanged,
} from '../../store/controlNetSlice';
import { ControlNetProcessorType } from '../../store/types';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { CONTROLNET_PROCESSORS } from '../../store/constants';
import { controlAdapterProcessortTypeChanged } from '../../store/controlAdaptersSlice';
import { ControlAdapterProcessorType } from '../../store/types';
type ParamControlNetProcessorSelectProps = { type Props = {
controlNet: ControlNetConfig; id: string;
}; };
const selector = createSelector( const selector = createSelector(
@ -41,7 +40,7 @@ const selector = createSelector(
.filter( .filter(
(d) => (d) =>
!config.sd.disabledControlNetProcessors.includes( !config.sd.disabledControlNetProcessors.includes(
d.value as ControlNetProcessorType d.value as ControlAdapterProcessorType
) )
); );
@ -50,26 +49,29 @@ const selector = createSelector(
defaultSelectorOptions defaultSelectorOptions
); );
const ParamControlNetProcessorSelect = ( const ParamControlNetProcessorSelect = ({ id }: Props) => {
props: ParamControlNetProcessorSelectProps const isEnabled = useControlAdapterIsEnabled(id);
) => { const processorNode = useControlAdapterProcessorNode(id);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { controlNetId, isEnabled, processorNode } = props.controlNet;
const controlNetProcessors = useAppSelector(selector); const controlNetProcessors = useAppSelector(selector);
const { t } = useTranslation(); const { t } = useTranslation();
const handleProcessorTypeChanged = useCallback( const handleProcessorTypeChanged = useCallback(
(v: string | null) => { (v: string | null) => {
dispatch( dispatch(
controlNetProcessorTypeChanged({ controlAdapterProcessortTypeChanged({
controlNetId, id,
processorType: v as ControlNetProcessorType, processorType: v as ControlAdapterProcessorType,
}) })
); );
}, },
[controlNetId, dispatch] [id, dispatch]
); );
if (!processorNode) {
return null;
}
return ( return (
<IAIMantineSearchableSelect <IAIMantineSearchableSelect
label={t('controlnet.processor')} label={t('controlnet.processor')}

View File

@ -1,22 +1,20 @@
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover'; import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { import { useControlAdapterIsEnabled } from 'features/controlNet/hooks/useControlAdapterIsEnabled';
ControlNetConfig, import { useControlAdapterResizeMode } from 'features/controlNet/hooks/useControlAdapterResizeMode';
ResizeModes, import { controlAdapterResizeModeChanged } from 'features/controlNet/store/controlAdaptersSlice';
controlNetResizeModeChanged, import { ResizeMode } from 'features/controlNet/store/types';
} from 'features/controlNet/store/controlNetSlice';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
type ParamControlNetResizeModeProps = { type Props = {
controlNet: ControlNetConfig; id: string;
}; };
export default function ParamControlNetResizeMode( export default function ParamControlNetResizeMode({ id }: Props) {
props: ParamControlNetResizeModeProps const isEnabled = useControlAdapterIsEnabled(id);
) { const resizeMode = useControlAdapterResizeMode(id);
const { resizeMode, isEnabled, controlNetId } = props.controlNet;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
@ -27,19 +25,23 @@ export default function ParamControlNetResizeMode(
]; ];
const handleResizeModeChange = useCallback( const handleResizeModeChange = useCallback(
(resizeMode: ResizeModes) => { (resizeMode: ResizeMode) => {
dispatch(controlNetResizeModeChanged({ controlNetId, resizeMode })); dispatch(controlAdapterResizeModeChanged({ id, resizeMode }));
}, },
[controlNetId, dispatch] [id, dispatch]
); );
if (!resizeMode) {
return null;
}
return ( return (
<IAIInformationalPopover feature="controlNetResizeMode"> <IAIInformationalPopover feature="controlNetResizeMode">
<IAIMantineSelect <IAIMantineSelect
disabled={!isEnabled} disabled={!isEnabled}
label={t('controlnet.resizeMode')} label={t('controlnet.resizeMode')}
data={RESIZE_MODE_DATA} data={RESIZE_MODE_DATA}
value={String(resizeMode)} value={resizeMode}
onChange={handleResizeModeChange} onChange={handleResizeModeChange}
/> />
</IAIInformationalPopover> </IAIInformationalPopover>

View File

@ -1,28 +1,34 @@
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover'; import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { import { useControlAdapterIsEnabled } from 'features/controlNet/hooks/useControlAdapterIsEnabled';
ControlNetConfig, import { useControlAdapterWeight } from 'features/controlNet/hooks/useControlAdapterWeight';
controlNetWeightChanged, import { controlAdapterWeightChanged } from 'features/controlNet/store/controlAdaptersSlice';
} from 'features/controlNet/store/controlNetSlice'; import { isNil } from 'lodash-es';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
type ParamControlNetWeightProps = { type ParamControlNetWeightProps = {
controlNet: ControlNetConfig; id: string;
}; };
const ParamControlNetWeight = (props: ParamControlNetWeightProps) => { const ParamControlNetWeight = ({ id }: ParamControlNetWeightProps) => {
const { weight, isEnabled, controlNetId } = props.controlNet; const isEnabled = useControlAdapterIsEnabled(id);
const weight = useControlAdapterWeight(id);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const handleWeightChanged = useCallback( const handleWeightChanged = useCallback(
(weight: number) => { (weight: number) => {
dispatch(controlNetWeightChanged({ controlNetId, weight })); dispatch(controlAdapterWeightChanged({ id, weight }));
}, },
[controlNetId, dispatch] [dispatch, id]
); );
if (isNil(weight)) {
// should never happen
return null;
}
return ( return (
<IAIInformationalPopover feature="controlNetWeight"> <IAIInformationalPopover feature="controlNetWeight">
<IAISlider <IAISlider

View File

@ -0,0 +1,45 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { controlAdapterAdded } from 'features/controlNet/store/controlAdaptersSlice';
import { useCallback, useMemo } from 'react';
import {
controlNetModelsAdapter,
useGetControlNetModelsQuery,
} from 'services/api/endpoints/models';
export const useAddControlNet = () => {
const dispatch = useAppDispatch();
const baseModel = useAppSelector(
(state) => state.generation.model?.base_model
);
const { data: controlNetModels } = useGetControlNetModelsQuery();
const firstControlNetModel = useMemo(
() =>
controlNetModels
? controlNetModelsAdapter
.getSelectors()
.selectAll(controlNetModels)
.filter((m) => (baseModel ? m.base_model === baseModel : true))[0]
: undefined,
[baseModel, controlNetModels]
);
const isDisabled = useMemo(
() => !firstControlNetModel,
[firstControlNetModel]
);
const addControlNet = useCallback(() => {
if (isDisabled) {
return;
}
dispatch(
controlAdapterAdded({
type: 'controlnet',
overrides: { model: firstControlNetModel },
})
);
}, [dispatch, firstControlNetModel, isDisabled]);
return {
addControlNet,
isDisabled,
};
};

View File

@ -0,0 +1,63 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import {
controlAdapterAdded,
selectAllIPAdapters,
} from 'features/controlNet/store/controlAdaptersSlice';
import { useCallback, useMemo } from 'react';
import {
ipAdapterModelsAdapter,
useGetIPAdapterModelsQuery,
} from 'services/api/endpoints/models';
const selector = createSelector(
[stateSelector],
({ controlAdapters, generation }) => {
const ipAdapterCount = selectAllIPAdapters(controlAdapters).length;
const { model } = generation;
return {
ipAdapterCount,
baseModel: model?.base_model,
};
},
defaultSelectorOptions
);
export const useAddIPAdapter = () => {
const { ipAdapterCount, baseModel } = useAppSelector(selector);
const dispatch = useAppDispatch();
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery();
const firstIPAdapterModel = useMemo(
() =>
ipAdapterModels
? ipAdapterModelsAdapter
.getSelectors()
.selectAll(ipAdapterModels)
.filter((m) => (baseModel ? m.base_model === baseModel : true))[0]
: undefined,
[baseModel, ipAdapterModels]
);
const isDisabled = useMemo(
() => !firstIPAdapterModel && ipAdapterCount === 0,
[firstIPAdapterModel, ipAdapterCount]
);
const addIPAdapter = useCallback(() => {
if (isDisabled) {
return;
}
dispatch(
controlAdapterAdded({
type: 'ip_adapter',
overrides: { model: firstIPAdapterModel },
})
);
}, [dispatch, firstIPAdapterModel, isDisabled]);
return {
addIPAdapter,
isDisabled,
};
};

View File

@ -0,0 +1,45 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { controlAdapterAdded } from 'features/controlNet/store/controlAdaptersSlice';
import { useCallback, useMemo } from 'react';
import {
t2iAdapterModelsAdapter,
useGetT2IAdapterModelsQuery,
} from 'services/api/endpoints/models';
export const useAddT2IAdapter = () => {
const dispatch = useAppDispatch();
const baseModel = useAppSelector(
(state) => state.generation.model?.base_model
);
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery();
const firstT2IAdapterModel = useMemo(
() =>
t2iAdapterModels
? t2iAdapterModelsAdapter
.getSelectors()
.selectAll(t2iAdapterModels)
.filter((m) => (baseModel ? m.base_model === baseModel : true))[0]
: undefined,
[baseModel, t2iAdapterModels]
);
const isDisabled = useMemo(
() => !firstT2IAdapterModel,
[firstT2IAdapterModel]
);
const addT2IAdapter = useCallback(() => {
if (isDisabled) {
return;
}
dispatch(
controlAdapterAdded({
type: 't2i_adapter',
overrides: { model: firstT2IAdapterModel },
})
);
}, [dispatch, firstT2IAdapterModel, isDisabled]);
return {
addT2IAdapter,
isDisabled,
};
};

View File

@ -0,0 +1,22 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
export const useControlAdapter = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) => selectControlAdapterById(controlAdapters, id),
defaultSelectorOptions
),
[id]
);
const controlAdapter = useAppSelector(selector);
return controlAdapter;
};

View File

@ -0,0 +1,30 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
export const useControlAdapterBeginEndStepPct = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) => {
const cn = selectControlAdapterById(controlAdapters, id);
return cn
? {
beginStepPct: cn.beginStepPct,
endStepPct: cn.endStepPct,
}
: undefined;
},
defaultSelectorOptions
),
[id]
);
const stepPcts = useAppSelector(selector);
return stepPcts;
};

View File

@ -0,0 +1,23 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
export const useControlAdapterControlImage = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) =>
selectControlAdapterById(controlAdapters, id)?.controlImage,
defaultSelectorOptions
),
[id]
);
const weight = useAppSelector(selector);
return weight;
};

View File

@ -0,0 +1,29 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { isControlNet } from '../store/types';
export const useControlAdapterControlMode = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) => {
const ca = selectControlAdapterById(controlAdapters, id);
if (ca && isControlNet(ca)) {
return ca.controlMode;
}
return undefined;
},
defaultSelectorOptions
),
[id]
);
const controlMode = useAppSelector(selector);
return controlMode;
};

View File

@ -0,0 +1,23 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
export const useControlAdapterIsEnabled = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) =>
selectControlAdapterById(controlAdapters, id)?.isEnabled ?? false,
defaultSelectorOptions
),
[id]
);
const isEnabled = useAppSelector(selector);
return isEnabled;
};

View File

@ -0,0 +1,23 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
export const useControlAdapterModel = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) =>
selectControlAdapterById(controlAdapters, id)?.model,
defaultSelectorOptions
),
[id]
);
const model = useAppSelector(selector);
return model;
};

View File

@ -0,0 +1,48 @@
import { useMemo } from 'react';
import {
controlNetModelsAdapter,
ipAdapterModelsAdapter,
t2iAdapterModelsAdapter,
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetT2IAdapterModelsQuery,
} from 'services/api/endpoints/models';
import { ControlAdapterType } from '../store/types';
export const useControlAdapterModels = (type?: ControlAdapterType) => {
const { data: controlNetModelsData } = useGetControlNetModelsQuery();
const controlNetModels = useMemo(
() =>
controlNetModelsData
? controlNetModelsAdapter.getSelectors().selectAll(controlNetModelsData)
: [],
[controlNetModelsData]
);
const { data: t2iAdapterModelsData } = useGetT2IAdapterModelsQuery();
const t2iAdapterModels = useMemo(
() =>
t2iAdapterModelsData
? t2iAdapterModelsAdapter.getSelectors().selectAll(t2iAdapterModelsData)
: [],
[t2iAdapterModelsData]
);
const { data: ipAdapterModelsData } = useGetIPAdapterModelsQuery();
const ipAdapterModels = useMemo(
() =>
ipAdapterModelsData
? ipAdapterModelsAdapter.getSelectors().selectAll(ipAdapterModelsData)
: [],
[ipAdapterModelsData]
);
if (type === 'controlnet') {
return controlNetModels;
}
if (type === 't2i_adapter') {
return t2iAdapterModels;
}
if (type === 'ip_adapter') {
return ipAdapterModels;
}
return;
};

View File

@ -0,0 +1,29 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { isControlNetOrT2IAdapter } from '../store/types';
export const useControlAdapterProcessedControlImage = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) => {
const ca = selectControlAdapterById(controlAdapters, id);
return ca && isControlNetOrT2IAdapter(ca)
? ca.processedControlImage
: undefined;
},
defaultSelectorOptions
),
[id]
);
const weight = useAppSelector(selector);
return weight;
};

View File

@ -0,0 +1,29 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { isControlNetOrT2IAdapter } from '../store/types';
export const useControlAdapterProcessorNode = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) => {
const ca = selectControlAdapterById(controlAdapters, id);
return ca && isControlNetOrT2IAdapter(ca)
? ca.processorNode
: undefined;
},
defaultSelectorOptions
),
[id]
);
const processorNode = useAppSelector(selector);
return processorNode;
};

View File

@ -0,0 +1,29 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { isControlNetOrT2IAdapter } from '../store/types';
export const useControlAdapterProcessorType = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) => {
const ca = selectControlAdapterById(controlAdapters, id);
return ca && isControlNetOrT2IAdapter(ca)
? ca.processorType
: undefined;
},
defaultSelectorOptions
),
[id]
);
const processorType = useAppSelector(selector);
return processorType;
};

View File

@ -0,0 +1,29 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from '../store/types';
export const useControlAdapterResizeMode = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) => {
const ca = selectControlAdapterById(controlAdapters, id);
if (ca && isControlNetOrT2IAdapter(ca)) {
return ca.resizeMode;
}
return undefined;
},
defaultSelectorOptions
),
[id]
);
const controlMode = useAppSelector(selector);
return controlMode;
};

View File

@ -0,0 +1,29 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from '../store/types';
export const useControlAdapterShouldAutoConfig = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) => {
const ca = selectControlAdapterById(controlAdapters, id);
if (ca && isControlNetOrT2IAdapter(ca)) {
return ca.shouldAutoConfig;
}
return undefined;
},
defaultSelectorOptions
),
[id]
);
const controlMode = useAppSelector(selector);
return controlMode;
};

View File

@ -0,0 +1,23 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
export const useControlAdapterType = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) =>
selectControlAdapterById(controlAdapters, id)?.type,
defaultSelectorOptions
),
[id]
);
const type = useAppSelector(selector);
return type;
};

View File

@ -0,0 +1,23 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
export const useControlAdapterWeight = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) =>
selectControlAdapterById(controlAdapters, id)?.weight,
defaultSelectorOptions
),
[id]
);
const weight = useAppSelector(selector);
return weight;
};

View File

@ -1,5 +1,5 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
export const controlNetImageProcessed = createAction<{ export const controlAdapterImageProcessed = createAction<{
controlNetId: string; id: string;
}>('controlNet/imageProcessed'); }>('controlAdapters/imageProcessed');

View File

@ -1,16 +1,16 @@
import i18n from 'i18next'; import i18n from 'i18next';
import { import {
ControlNetProcessorType, ControlAdapterProcessorType,
RequiredControlNetProcessorNode, RequiredControlAdapterProcessorNode,
} from './types'; } from './types';
type ControlNetProcessorsDict = Record< type ControlNetProcessorsDict = Record<
ControlNetProcessorType, ControlAdapterProcessorType,
{ {
type: ControlNetProcessorType | 'none'; type: ControlAdapterProcessorType | 'none';
label: string; label: string;
description: string; description: string;
default: RequiredControlNetProcessorNode | { type: 'none' }; default: RequiredControlAdapterProcessorNode | { type: 'none' };
} }
>; >;
/** /**
@ -240,7 +240,7 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
}; };
export const CONTROLNET_MODEL_DEFAULT_PROCESSORS: { export const CONTROLNET_MODEL_DEFAULT_PROCESSORS: {
[key: string]: ControlNetProcessorType; [key: string]: ControlAdapterProcessorType;
} = { } = {
canny: 'canny_image_processor', canny: 'canny_image_processor',
mlsd: 'mlsd_image_processor', mlsd: 'mlsd_image_processor',

View File

@ -0,0 +1,8 @@
import { ControlAdaptersState } from './types';
/**
* ControlNet slice persist denylist
*/
export const controlAdaptersPersistDenylist: (keyof ControlAdaptersState)[] = [
'pendingControlImages',
];

View File

@ -0,0 +1,479 @@
import {
PayloadAction,
Update,
createEntityAdapter,
createSlice,
} from '@reduxjs/toolkit';
import {
ControlNetModelParam,
IPAdapterModelParam,
T2IAdapterModelParam,
} from 'features/parameters/types/parameterSchemas';
import { cloneDeep, merge, uniq } from 'lodash-es';
import { appSocketInvocationError } from 'services/events/actions';
import { v4 as uuidv4 } from 'uuid';
import { buildControlAdapter } from '../util/buildControlAdapter';
import { controlAdapterImageProcessed } from './actions';
import {
CONTROLNET_MODEL_DEFAULT_PROCESSORS as CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS,
CONTROLNET_PROCESSORS,
} from './constants';
import {
ControlAdapterConfig,
ControlAdapterProcessorType,
ControlAdapterType,
ControlAdaptersState,
ControlMode,
ControlNetConfig,
RequiredControlAdapterProcessorNode,
ResizeMode,
T2IAdapterConfig,
isControlNet,
isControlNetOrT2IAdapter,
isIPAdapter,
isT2IAdapter,
} from './types';
export const caAdapter = createEntityAdapter<ControlAdapterConfig>();
export const {
selectById: selectControlAdapterById,
selectAll: selectControlAdapterAll,
selectEntities: selectControlAdapterEntities,
selectIds: selectControlAdapterIds,
selectTotal: selectControlAdapterTotal,
} = caAdapter.getSelectors();
export const initialControlAdapterState: ControlAdaptersState =
caAdapter.getInitialState<{
pendingControlImages: string[];
}>({
pendingControlImages: [],
});
export const selectAllControlNets = (controlAdapters: ControlAdaptersState) =>
selectControlAdapterAll(controlAdapters).filter(isControlNet);
export const selectValidControlNets = (controlAdapters: ControlAdaptersState) =>
selectControlAdapterAll(controlAdapters)
.filter(isControlNet)
.filter(
(ca) =>
ca.isEnabled &&
ca.model &&
(Boolean(ca.processedControlImage) ||
(ca.processorType === 'none' && Boolean(ca.controlImage)))
);
export const selectAllIPAdapters = (controlAdapters: ControlAdaptersState) =>
selectControlAdapterAll(controlAdapters).filter(isIPAdapter);
export const selectValidIPAdapters = (controlAdapters: ControlAdaptersState) =>
selectControlAdapterAll(controlAdapters)
.filter(isIPAdapter)
.filter((ca) => ca.isEnabled && ca.model && Boolean(ca.controlImage));
export const selectAllT2IAdapters = (controlAdapters: ControlAdaptersState) =>
selectControlAdapterAll(controlAdapters).filter(isT2IAdapter);
export const selectValidT2IAdapters = (controlAdapters: ControlAdaptersState) =>
selectControlAdapterAll(controlAdapters)
.filter(isT2IAdapter)
.filter(
(ca) =>
ca.isEnabled &&
ca.model &&
(Boolean(ca.processedControlImage) ||
(ca.processorType === 'none' && Boolean(ca.controlImage)))
);
export const controlAdaptersSlice = createSlice({
name: 'controlAdapters',
initialState: initialControlAdapterState,
reducers: {
controlAdapterAdded: {
reducer: (
state,
action: PayloadAction<{
id: string;
type: ControlAdapterType;
overrides?: Partial<ControlAdapterConfig>;
}>
) => {
const { id, type, overrides } = action.payload;
caAdapter.addOne(state, buildControlAdapter(id, type, overrides));
},
prepare: ({
type,
overrides,
}: {
type: ControlAdapterType;
overrides?: Partial<ControlAdapterConfig>;
}) => {
return { payload: { id: uuidv4(), type, overrides } };
},
},
controlAdapterRecalled: (
state,
action: PayloadAction<ControlAdapterConfig>
) => {
const config = action.payload;
caAdapter.addOne(state, config);
},
controlAdapterDuplicated: {
reducer: (
state,
action: PayloadAction<{
id: string;
newId: string;
}>
) => {
const { id, newId } = action.payload;
const controlAdapter = selectControlAdapterById(state, id);
if (!controlAdapter) {
return;
}
const newControlAdapter = merge(cloneDeep(controlAdapter), {
id: newId,
});
caAdapter.addOne(state, newControlAdapter);
},
prepare: (id: string) => {
return { payload: { id, newId: uuidv4() } };
},
},
controlAdapterAddedFromImage: {
reducer: (
state,
action: PayloadAction<{
id: string;
type: ControlAdapterType;
controlImage: string;
}>
) => {
const { id, type, controlImage } = action.payload;
caAdapter.addOne(
state,
buildControlAdapter(id, type, { controlImage })
);
},
prepare: (payload: {
type: ControlAdapterType;
controlImage: string;
}) => {
return { payload: { ...payload, id: uuidv4() } };
},
},
controlAdapterRemoved: (state, action: PayloadAction<{ id: string }>) => {
caAdapter.removeOne(state, action.payload.id);
},
controlAdapterIsEnabledChanged: (
state,
action: PayloadAction<{ id: string; isEnabled: boolean }>
) => {
const { id, isEnabled } = action.payload;
caAdapter.updateOne(state, { id, changes: { isEnabled } });
},
controlAdapterImageChanged: (
state,
action: PayloadAction<{
id: string;
controlImage: string | null;
}>
) => {
const { id, controlImage } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn) {
return;
}
caAdapter.updateOne(state, {
id,
changes: { controlImage, processedControlImage: null },
});
if (
controlImage !== null &&
isControlNetOrT2IAdapter(cn) &&
cn.processorType !== 'none'
) {
state.pendingControlImages.push(id);
}
},
controlAdapterProcessedImageChanged: (
state,
action: PayloadAction<{
id: string;
processedControlImage: string | null;
}>
) => {
const { id, processedControlImage } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn) {
return;
}
if (!isControlNetOrT2IAdapter(cn)) {
return;
}
caAdapter.updateOne(state, {
id,
changes: {
processedControlImage,
},
});
state.pendingControlImages = state.pendingControlImages.filter(
(pendingId) => pendingId !== id
);
},
controlAdapterModelCleared: (
state,
action: PayloadAction<{ id: string }>
) => {
caAdapter.updateOne(state, {
id: action.payload.id,
changes: { model: null },
});
},
controlAdapterModelChanged: (
state,
action: PayloadAction<{
id: string;
model:
| ControlNetModelParam
| T2IAdapterModelParam
| IPAdapterModelParam;
}>
) => {
const { id, model } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn) {
return;
}
if (!isControlNetOrT2IAdapter(cn)) {
caAdapter.updateOne(state, { id, changes: { model } });
return;
}
const update: Update<ControlNetConfig | T2IAdapterConfig> = {
id,
changes: { model },
};
update.changes.processedControlImage = null;
if (cn.shouldAutoConfig) {
let processorType: ControlAdapterProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
if (model.model_name.includes(modelSubstring)) {
processorType =
CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) {
update.changes.processorType = processorType;
update.changes.processorNode = CONTROLNET_PROCESSORS[processorType]
.default as RequiredControlAdapterProcessorNode;
} else {
update.changes.processorType = 'none';
update.changes.processorNode = CONTROLNET_PROCESSORS.none
.default as RequiredControlAdapterProcessorNode;
}
}
caAdapter.updateOne(state, update);
},
controlAdapterWeightChanged: (
state,
action: PayloadAction<{ id: string; weight: number }>
) => {
const { id, weight } = action.payload;
caAdapter.updateOne(state, { id, changes: { weight } });
},
controlAdapterBeginStepPctChanged: (
state,
action: PayloadAction<{ id: string; beginStepPct: number }>
) => {
const { id, beginStepPct } = action.payload;
caAdapter.updateOne(state, { id, changes: { beginStepPct } });
},
controlAdapterEndStepPctChanged: (
state,
action: PayloadAction<{ id: string; endStepPct: number }>
) => {
const { id, endStepPct } = action.payload;
caAdapter.updateOne(state, { id, changes: { endStepPct } });
},
controlAdapterControlModeChanged: (
state,
action: PayloadAction<{ id: string; controlMode: ControlMode }>
) => {
const { id, controlMode } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn || !isControlNet(cn)) {
return;
}
caAdapter.updateOne(state, { id, changes: { controlMode } });
},
controlAdapterResizeModeChanged: (
state,
action: PayloadAction<{
id: string;
resizeMode: ResizeMode;
}>
) => {
const { id, resizeMode } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn || !isControlNetOrT2IAdapter(cn)) {
return;
}
caAdapter.updateOne(state, { id, changes: { resizeMode } });
},
controlAdapterProcessorParamsChanged: (
state,
action: PayloadAction<{
id: string;
params: Partial<RequiredControlAdapterProcessorNode>;
}>
) => {
const { id, params } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn || !isControlNetOrT2IAdapter(cn) || !cn.processorNode) {
return;
}
const processorNode = merge(cloneDeep(cn.processorNode), params);
caAdapter.updateOne(state, {
id,
changes: {
shouldAutoConfig: false,
processorNode,
},
});
},
controlAdapterProcessortTypeChanged: (
state,
action: PayloadAction<{
id: string;
processorType: ControlAdapterProcessorType;
}>
) => {
const { id, processorType } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn || !isControlNetOrT2IAdapter(cn)) {
return;
}
const processorNode = cloneDeep(
CONTROLNET_PROCESSORS[processorType].default
) as RequiredControlAdapterProcessorNode;
caAdapter.updateOne(state, {
id,
changes: {
processorType,
processedControlImage: null,
processorNode,
shouldAutoConfig: false,
},
});
},
controlAdapterAutoConfigToggled: (
state,
action: PayloadAction<{
id: string;
}>
) => {
const { id } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn || !isControlNetOrT2IAdapter(cn)) {
return;
}
const update: Update<ControlNetConfig | T2IAdapterConfig> = {
id,
changes: { shouldAutoConfig: !cn.shouldAutoConfig },
};
if (update.changes.shouldAutoConfig) {
// manage the processor for the user
let processorType: ControlAdapterProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
if (cn.model?.model_name.includes(modelSubstring)) {
processorType =
CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) {
update.changes.processorType = processorType;
update.changes.processorNode = CONTROLNET_PROCESSORS[processorType]
.default as RequiredControlAdapterProcessorNode;
} else {
update.changes.processorType = 'none';
update.changes.processorNode = CONTROLNET_PROCESSORS.none
.default as RequiredControlAdapterProcessorNode;
}
}
caAdapter.updateOne(state, update);
},
controlAdaptersReset: () => {
return cloneDeep(initialControlAdapterState);
},
pendingControlImagesCleared: (state) => {
state.pendingControlImages = [];
},
},
extraReducers: (builder) => {
builder.addCase(controlAdapterImageProcessed, (state, action) => {
const cn = selectControlAdapterById(state, action.payload.id);
if (!cn) {
return;
}
if (cn.controlImage !== null) {
state.pendingControlImages = uniq(
state.pendingControlImages.concat(action.payload.id)
);
}
});
builder.addCase(appSocketInvocationError, (state) => {
state.pendingControlImages = [];
});
},
});
export const {
controlAdapterAdded,
controlAdapterRecalled,
controlAdapterDuplicated,
controlAdapterAddedFromImage,
controlAdapterRemoved,
controlAdapterImageChanged,
controlAdapterProcessedImageChanged,
controlAdapterIsEnabledChanged,
controlAdapterModelChanged,
controlAdapterWeightChanged,
controlAdapterBeginStepPctChanged,
controlAdapterEndStepPctChanged,
controlAdapterControlModeChanged,
controlAdapterResizeModeChanged,
controlAdapterProcessorParamsChanged,
controlAdapterProcessortTypeChanged,
controlAdaptersReset,
controlAdapterAutoConfigToggled,
pendingControlImagesCleared,
controlAdapterModelCleared,
} = controlAdaptersSlice.actions;
export default controlAdaptersSlice.reducer;

View File

@ -1,8 +0,0 @@
import { ControlNetState } from './controlNetSlice';
/**
* ControlNet slice persist denylist
*/
export const controlNetDenylist: (keyof ControlNetState)[] = [
'pendingControlImages',
];

View File

@ -1,486 +0,0 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import {
ControlNetModelParam,
IPAdapterModelParam,
} from 'features/parameters/types/parameterSchemas';
import { cloneDeep, forEach } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import { components } from 'services/api/schema';
import { appSocketInvocationError } from 'services/events/actions';
import { controlNetImageProcessed } from './actions';
import {
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
CONTROLNET_PROCESSORS,
} from './constants';
import {
ControlNetProcessorType,
RequiredCannyImageProcessorInvocation,
RequiredControlNetProcessorNode,
} from './types';
export type ControlModes = NonNullable<
components['schemas']['ControlNetInvocation']['control_mode']
>;
export type ResizeModes = NonNullable<
components['schemas']['ControlNetInvocation']['resize_mode']
>;
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true,
model: null,
weight: 1,
beginStepPct: 0,
endStepPct: 1,
controlMode: 'balanced',
resizeMode: 'just_resize',
controlImage: null,
processedControlImage: null,
processorType: 'canny_image_processor',
processorNode: CONTROLNET_PROCESSORS.canny_image_processor
.default as RequiredCannyImageProcessorInvocation,
shouldAutoConfig: true,
};
export type ControlNetConfig = {
controlNetId: string;
isEnabled: boolean;
model: ControlNetModelParam | null;
weight: number;
beginStepPct: number;
endStepPct: number;
controlMode: ControlModes;
resizeMode: ResizeModes;
controlImage: string | null;
processedControlImage: string | null;
processorType: ControlNetProcessorType;
processorNode: RequiredControlNetProcessorNode;
shouldAutoConfig: boolean;
};
export type IPAdapterConfig = {
adapterImage: string | null;
model: IPAdapterModelParam | null;
weight: number;
beginStepPct: number;
endStepPct: number;
};
export type ControlNetState = {
controlNets: Record<string, ControlNetConfig>;
isEnabled: boolean;
pendingControlImages: string[];
isIPAdapterEnabled: boolean;
ipAdapterInfo: IPAdapterConfig;
};
export const initialIPAdapterState: IPAdapterConfig = {
adapterImage: null,
model: null,
weight: 1,
beginStepPct: 0,
endStepPct: 1,
};
export const initialControlNetState: ControlNetState = {
controlNets: {},
isEnabled: false,
pendingControlImages: [],
isIPAdapterEnabled: false,
ipAdapterInfo: { ...initialIPAdapterState },
};
export const controlNetSlice = createSlice({
name: 'controlNet',
initialState: initialControlNetState,
reducers: {
isControlNetEnabledToggled: (state) => {
state.isEnabled = !state.isEnabled;
},
controlNetEnabled: (state) => {
state.isEnabled = true;
},
controlNetAdded: (
state,
action: PayloadAction<{
controlNetId: string;
controlNet?: ControlNetConfig;
}>
) => {
const { controlNetId, controlNet } = action.payload;
state.controlNets[controlNetId] = {
...(controlNet ?? initialControlNet),
controlNetId,
};
},
controlNetRecalled: (state, action: PayloadAction<ControlNetConfig>) => {
const controlNet = action.payload;
state.controlNets[controlNet.controlNetId] = {
...controlNet,
};
},
controlNetDuplicated: (
state,
action: PayloadAction<{
sourceControlNetId: string;
newControlNetId: string;
}>
) => {
const { sourceControlNetId, newControlNetId } = action.payload;
const oldControlNet = state.controlNets[sourceControlNetId];
if (!oldControlNet) {
return;
}
const newControlnet = cloneDeep(oldControlNet);
newControlnet.controlNetId = newControlNetId;
state.controlNets[newControlNetId] = newControlnet;
},
controlNetAddedFromImage: (
state,
action: PayloadAction<{ controlNetId: string; controlImage: string }>
) => {
const { controlNetId, controlImage } = action.payload;
state.controlNets[controlNetId] = {
...initialControlNet,
controlNetId,
controlImage,
};
},
controlNetRemoved: (
state,
action: PayloadAction<{ controlNetId: string }>
) => {
const { controlNetId } = action.payload;
delete state.controlNets[controlNetId];
},
controlNetIsEnabledChanged: (
state,
action: PayloadAction<{ controlNetId: string; isEnabled: boolean }>
) => {
const { controlNetId, isEnabled } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.isEnabled = isEnabled;
},
controlNetImageChanged: (
state,
action: PayloadAction<{
controlNetId: string;
controlImage: string | null;
}>
) => {
const { controlNetId, controlImage } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.controlImage = controlImage;
cn.processedControlImage = null;
if (controlImage !== null && cn.processorType !== 'none') {
state.pendingControlImages.push(controlNetId);
}
},
controlNetProcessedImageChanged: (
state,
action: PayloadAction<{
controlNetId: string;
processedControlImage: string | null;
}>
) => {
const { controlNetId, processedControlImage } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.processedControlImage = processedControlImage;
state.pendingControlImages = state.pendingControlImages.filter(
(id) => id !== controlNetId
);
},
controlNetModelChanged: (
state,
action: PayloadAction<{
controlNetId: string;
model: ControlNetModelParam;
}>
) => {
const { controlNetId, model } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.model = model;
cn.processedControlImage = null;
if (cn.shouldAutoConfig) {
let processorType: ControlNetProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (model.model_name.includes(modelSubstring)) {
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) {
cn.processorType = processorType;
cn.processorNode = CONTROLNET_PROCESSORS[processorType]
.default as RequiredControlNetProcessorNode;
} else {
cn.processorType = 'none';
cn.processorNode = CONTROLNET_PROCESSORS.none
.default as RequiredControlNetProcessorNode;
}
}
},
controlNetWeightChanged: (
state,
action: PayloadAction<{ controlNetId: string; weight: number }>
) => {
const { controlNetId, weight } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.weight = weight;
},
controlNetBeginStepPctChanged: (
state,
action: PayloadAction<{ controlNetId: string; beginStepPct: number }>
) => {
const { controlNetId, beginStepPct } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.beginStepPct = beginStepPct;
},
controlNetEndStepPctChanged: (
state,
action: PayloadAction<{ controlNetId: string; endStepPct: number }>
) => {
const { controlNetId, endStepPct } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.endStepPct = endStepPct;
},
controlNetControlModeChanged: (
state,
action: PayloadAction<{ controlNetId: string; controlMode: ControlModes }>
) => {
const { controlNetId, controlMode } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.controlMode = controlMode;
},
controlNetResizeModeChanged: (
state,
action: PayloadAction<{
controlNetId: string;
resizeMode: ResizeModes;
}>
) => {
const { controlNetId, resizeMode } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.resizeMode = resizeMode;
},
controlNetProcessorParamsChanged: (
state,
action: PayloadAction<{
controlNetId: string;
changes: Omit<
Partial<RequiredControlNetProcessorNode>,
'id' | 'type' | 'is_intermediate'
>;
}>
) => {
const { controlNetId, changes } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
const processorNode = cn.processorNode;
cn.processorNode = {
...processorNode,
...changes,
};
cn.shouldAutoConfig = false;
},
controlNetProcessorTypeChanged: (
state,
action: PayloadAction<{
controlNetId: string;
processorType: ControlNetProcessorType;
}>
) => {
const { controlNetId, processorType } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.processedControlImage = null;
cn.processorType = processorType;
cn.processorNode = CONTROLNET_PROCESSORS[processorType]
.default as RequiredControlNetProcessorNode;
cn.shouldAutoConfig = false;
},
controlNetAutoConfigToggled: (
state,
action: PayloadAction<{
controlNetId: string;
}>
) => {
const { controlNetId } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
const newShouldAutoConfig = !cn.shouldAutoConfig;
if (newShouldAutoConfig) {
// manage the processor for the user
let processorType: ControlNetProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (cn.model?.model_name.includes(modelSubstring)) {
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) {
cn.processorType = processorType;
cn.processorNode = CONTROLNET_PROCESSORS[processorType]
.default as RequiredControlNetProcessorNode;
} else {
cn.processorType = 'none';
cn.processorNode = CONTROLNET_PROCESSORS.none
.default as RequiredControlNetProcessorNode;
}
}
cn.shouldAutoConfig = newShouldAutoConfig;
},
controlNetReset: () => {
return { ...initialControlNetState };
},
isIPAdapterEnabledChanged: (state, action: PayloadAction<boolean>) => {
state.isIPAdapterEnabled = action.payload;
},
ipAdapterRecalled: (state, action: PayloadAction<IPAdapterConfig>) => {
state.ipAdapterInfo = action.payload;
},
ipAdapterImageChanged: (state, action: PayloadAction<string | null>) => {
state.ipAdapterInfo.adapterImage = action.payload;
},
ipAdapterWeightChanged: (state, action: PayloadAction<number>) => {
state.ipAdapterInfo.weight = action.payload;
},
ipAdapterModelChanged: (
state,
action: PayloadAction<IPAdapterModelParam | null>
) => {
state.ipAdapterInfo.model = action.payload;
},
ipAdapterBeginStepPctChanged: (state, action: PayloadAction<number>) => {
state.ipAdapterInfo.beginStepPct = action.payload;
},
ipAdapterEndStepPctChanged: (state, action: PayloadAction<number>) => {
state.ipAdapterInfo.endStepPct = action.payload;
},
ipAdapterStateReset: (state) => {
state.isIPAdapterEnabled = false;
state.ipAdapterInfo = { ...initialIPAdapterState };
},
clearPendingControlImages: (state) => {
state.pendingControlImages = [];
},
},
extraReducers: (builder) => {
builder.addCase(controlNetImageProcessed, (state, action) => {
const cn = state.controlNets[action.payload.controlNetId];
if (!cn) {
return;
}
if (cn.controlImage !== null) {
state.pendingControlImages.push(action.payload.controlNetId);
}
});
builder.addCase(appSocketInvocationError, (state) => {
state.pendingControlImages = [];
});
builder.addMatcher(
imagesApi.endpoints.deleteImage.matchFulfilled,
(state, action) => {
// Preemptively remove the image from all controlnets
// TODO: doesn't the imageusage stuff do this for us?
const { image_name } = action.meta.arg.originalArgs;
forEach(state.controlNets, (c) => {
if (c.controlImage === image_name) {
c.controlImage = null;
c.processedControlImage = null;
}
if (c.processedControlImage === image_name) {
c.processedControlImage = null;
}
});
}
);
},
});
export const {
isControlNetEnabledToggled,
controlNetEnabled,
controlNetAdded,
controlNetRecalled,
controlNetDuplicated,
controlNetAddedFromImage,
controlNetRemoved,
controlNetImageChanged,
controlNetProcessedImageChanged,
controlNetIsEnabledChanged,
controlNetModelChanged,
controlNetWeightChanged,
controlNetBeginStepPctChanged,
controlNetEndStepPctChanged,
controlNetControlModeChanged,
controlNetResizeModeChanged,
controlNetProcessorParamsChanged,
controlNetProcessorTypeChanged,
controlNetReset,
controlNetAutoConfigToggled,
isIPAdapterEnabledChanged,
ipAdapterRecalled,
ipAdapterImageChanged,
ipAdapterWeightChanged,
ipAdapterModelChanged,
ipAdapterBeginStepPctChanged,
ipAdapterEndStepPctChanged,
ipAdapterStateReset,
clearPendingControlImages,
} = controlNetSlice.actions;
export default controlNetSlice.reducer;

View File

@ -1,4 +1,11 @@
import { EntityState } from '@reduxjs/toolkit';
import {
ControlNetModelParam,
IPAdapterModelParam,
T2IAdapterModelParam,
} from 'features/parameters/types/parameterSchemas';
import { isObject } from 'lodash-es'; import { isObject } from 'lodash-es';
import { components } from 'services/api/schema';
import { import {
CannyImageProcessorInvocation, CannyImageProcessorInvocation,
ColorMapImageProcessorInvocation, ColorMapImageProcessorInvocation,
@ -12,6 +19,7 @@ import {
NormalbaeImageProcessorInvocation, NormalbaeImageProcessorInvocation,
OpenposeImageProcessorInvocation, OpenposeImageProcessorInvocation,
PidiImageProcessorInvocation, PidiImageProcessorInvocation,
T2IAdapterModelConfig,
ZoeDepthImageProcessorInvocation, ZoeDepthImageProcessorInvocation,
} from 'services/api/types'; } from 'services/api/types';
import { O } from 'ts-toolbelt'; import { O } from 'ts-toolbelt';
@ -19,7 +27,7 @@ import { O } from 'ts-toolbelt';
/** /**
* Any ControlNet processor node * Any ControlNet processor node
*/ */
export type ControlNetProcessorNode = export type ControlAdapterProcessorNode =
| CannyImageProcessorInvocation | CannyImageProcessorInvocation
| ColorMapImageProcessorInvocation | ColorMapImageProcessorInvocation
| ContentShuffleImageProcessorInvocation | ContentShuffleImageProcessorInvocation
@ -37,8 +45,8 @@ export type ControlNetProcessorNode =
/** /**
* Any ControlNet processor type * Any ControlNet processor type
*/ */
export type ControlNetProcessorType = NonNullable< export type ControlAdapterProcessorType = NonNullable<
ControlNetProcessorNode['type'] | 'none' ControlAdapterProcessorNode['type'] | 'none'
>; >;
/** /**
@ -148,7 +156,7 @@ export type RequiredZoeDepthImageProcessorInvocation = O.Required<
/** /**
* Any ControlNet Processor node, with its parameters flagged as required * Any ControlNet Processor node, with its parameters flagged as required
*/ */
export type RequiredControlNetProcessorNode = O.Required< export type RequiredControlAdapterProcessorNode = O.Required<
| RequiredCannyImageProcessorInvocation | RequiredCannyImageProcessorInvocation
| RequiredColorMapImageProcessorInvocation | RequiredColorMapImageProcessorInvocation
| RequiredContentShuffleImageProcessorInvocation | RequiredContentShuffleImageProcessorInvocation
@ -356,3 +364,90 @@ export const isZoeDepthImageProcessorInvocation = (
} }
return false; return false;
}; };
export type ControlMode = NonNullable<
components['schemas']['ControlNetInvocation']['control_mode']
>;
export type ResizeMode = NonNullable<
components['schemas']['ControlNetInvocation']['resize_mode']
>;
export type ControlNetConfig = {
type: 'controlnet';
id: string;
isEnabled: boolean;
model: ControlNetModelParam | null;
weight: number;
beginStepPct: number;
endStepPct: number;
controlMode: ControlMode;
resizeMode: ResizeMode;
controlImage: string | null;
processedControlImage: string | null;
processorType: ControlAdapterProcessorType;
processorNode: RequiredControlAdapterProcessorNode;
shouldAutoConfig: boolean;
};
export type T2IAdapterConfig = {
type: 't2i_adapter';
id: string;
isEnabled: boolean;
model: T2IAdapterModelParam | null;
weight: number;
beginStepPct: number;
endStepPct: number;
resizeMode: ResizeMode;
controlImage: string | null;
processedControlImage: string | null;
processorType: ControlAdapterProcessorType;
processorNode: RequiredControlAdapterProcessorNode;
shouldAutoConfig: boolean;
};
export type IPAdapterConfig = {
type: 'ip_adapter';
id: string;
isEnabled: boolean;
controlImage: string | null;
model: IPAdapterModelParam | null;
weight: number;
beginStepPct: number;
endStepPct: number;
};
export type ControlAdapterConfig =
| ControlNetConfig
| IPAdapterConfig
| T2IAdapterConfig;
export type ControlAdapterType = ControlAdapterConfig['type'];
export type ControlAdaptersState = EntityState<ControlAdapterConfig> & {
pendingControlImages: string[];
};
export const isControlNet = (
controlAdapter: ControlAdapterConfig
): controlAdapter is ControlNetConfig => {
return controlAdapter.type === 'controlnet';
};
export const isIPAdapter = (
controlAdapter: ControlAdapterConfig
): controlAdapter is IPAdapterConfig => {
return controlAdapter.type === 'ip_adapter';
};
export const isT2IAdapter = (
controlAdapter: ControlAdapterConfig
): controlAdapter is T2IAdapterConfig => {
return controlAdapter.type === 't2i_adapter';
};
export const isControlNetOrT2IAdapter = (
controlAdapter: ControlAdapterConfig
): controlAdapter is ControlNetConfig | T2IAdapterConfig => {
return isControlNet(controlAdapter) || isT2IAdapter(controlAdapter);
};

View File

@ -0,0 +1,70 @@
import { cloneDeep, merge } from 'lodash-es';
import {
ControlAdapterConfig,
ControlAdapterType,
ControlNetConfig,
IPAdapterConfig,
RequiredCannyImageProcessorInvocation,
T2IAdapterConfig,
} from '../store/types';
import { CONTROLNET_PROCESSORS } from '../store/constants';
export const initialControlNet: Omit<ControlNetConfig, 'id'> = {
type: 'controlnet',
isEnabled: true,
model: null,
weight: 1,
beginStepPct: 0,
endStepPct: 1,
controlMode: 'balanced',
resizeMode: 'just_resize',
controlImage: null,
processedControlImage: null,
processorType: 'canny_image_processor',
processorNode: CONTROLNET_PROCESSORS.canny_image_processor
.default as RequiredCannyImageProcessorInvocation,
shouldAutoConfig: true,
};
export const initialT2IAdapter: Omit<T2IAdapterConfig, 'id'> = {
type: 't2i_adapter',
isEnabled: true,
model: null,
weight: 1,
beginStepPct: 0,
endStepPct: 1,
resizeMode: 'just_resize',
controlImage: null,
processedControlImage: null,
processorType: 'canny_image_processor',
processorNode: CONTROLNET_PROCESSORS.canny_image_processor
.default as RequiredCannyImageProcessorInvocation,
shouldAutoConfig: true,
};
export const initialIPAdapter: Omit<IPAdapterConfig, 'id'> = {
type: 'ip_adapter',
isEnabled: true,
controlImage: null,
model: null,
weight: 1,
beginStepPct: 0,
endStepPct: 1,
};
export const buildControlAdapter = (
id: string,
type: ControlAdapterType,
overrides: Partial<ControlAdapterConfig> = {}
): ControlAdapterConfig => {
switch (type) {
case 'controlnet':
return merge(cloneDeep(initialControlNet), { id, ...overrides });
case 't2i_adapter':
return merge(cloneDeep(initialT2IAdapter), { id, ...overrides });
case 'ip_adapter':
return merge(cloneDeep(initialIPAdapter), { id, ...overrides });
default:
throw new Error(`Unknown control adapter type: ${type}`);
}
};

View File

@ -1,15 +0,0 @@
import { filter } from 'lodash-es';
import { ControlNetConfig } from '../store/controlNetSlice';
export const getValidControlNets = (
controlNets: Record<string, ControlNetConfig>
) => {
const validControlNets = filter(
controlNets,
(c) =>
c.isEnabled &&
(Boolean(c.processedControlImage) ||
(c.processorType === 'none' && Boolean(c.controlImage)))
);
return validControlNets;
};

View File

@ -41,8 +41,7 @@ const selector = createSelector(
isInitialImage: some(allImageUsage, (i) => i.isInitialImage), isInitialImage: some(allImageUsage, (i) => i.isInitialImage),
isCanvasImage: some(allImageUsage, (i) => i.isCanvasImage), isCanvasImage: some(allImageUsage, (i) => i.isCanvasImage),
isNodesImage: some(allImageUsage, (i) => i.isNodesImage), isNodesImage: some(allImageUsage, (i) => i.isNodesImage),
isControlNetImage: some(allImageUsage, (i) => i.isControlNetImage), isControlImage: some(allImageUsage, (i) => i.isControlImage),
isIPAdapterImage: some(allImageUsage, (i) => i.isIPAdapterImage),
}; };
return { return {

View File

@ -35,12 +35,9 @@ const ImageUsageMessage = (props: Props) => {
{imageUsage.isCanvasImage && ( {imageUsage.isCanvasImage && (
<ListItem>{t('common.unifiedCanvas')}</ListItem> <ListItem>{t('common.unifiedCanvas')}</ListItem>
)} )}
{imageUsage.isControlNetImage && ( {imageUsage.isControlImage && (
<ListItem>{t('common.controlNet')}</ListItem> <ListItem>{t('common.controlNet')}</ListItem>
)} )}
{imageUsage.isIPAdapterImage && (
<ListItem>{t('common.ipAdapter')}</ListItem>
)}
{imageUsage.isNodesImage && ( {imageUsage.isNodesImage && (
<ListItem>{t('common.nodeEditor')}</ListItem> <ListItem>{t('common.nodeEditor')}</ListItem>
)} )}

View File

@ -4,9 +4,11 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { isInvocationNode } from 'features/nodes/types/types'; import { isInvocationNode } from 'features/nodes/types/types';
import { some } from 'lodash-es'; import { some } from 'lodash-es';
import { ImageUsage } from './types'; import { ImageUsage } from './types';
import { selectControlAdapterAll } from 'features/controlNet/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlNet/store/types';
export const getImageUsage = (state: RootState, image_name: string) => { export const getImageUsage = (state: RootState, image_name: string) => {
const { generation, canvas, nodes, controlNet } = state; const { generation, canvas, nodes, controlAdapters } = state;
const isInitialImage = generation.initialImage?.imageName === image_name; const isInitialImage = generation.initialImage?.imageName === image_name;
const isCanvasImage = canvas.layerState.objects.some( const isCanvasImage = canvas.layerState.objects.some(
@ -21,20 +23,17 @@ export const getImageUsage = (state: RootState, image_name: string) => {
); );
}); });
const isControlNetImage = some( const isControlImage = selectControlAdapterAll(controlAdapters).some(
controlNet.controlNets, (ca) =>
(c) => ca.controlImage === image_name ||
c.controlImage === image_name || c.processedControlImage === image_name (isControlNetOrT2IAdapter(ca) && ca.processedControlImage === image_name)
); );
const isIPAdapterImage = controlNet.ipAdapterInfo.adapterImage === image_name;
const imageUsage: ImageUsage = { const imageUsage: ImageUsage = {
isInitialImage, isInitialImage,
isCanvasImage, isCanvasImage,
isNodesImage, isNodesImage,
isControlNetImage, isControlImage,
isIPAdapterImage,
}; };
return imageUsage; return imageUsage;

View File

@ -9,6 +9,5 @@ export type ImageUsage = {
isInitialImage: boolean; isInitialImage: boolean;
isCanvasImage: boolean; isCanvasImage: boolean;
isNodesImage: boolean; isNodesImage: boolean;
isControlNetImage: boolean; isControlImage: boolean;
isIPAdapterImage: boolean;
}; };

View File

@ -28,10 +28,10 @@ export type InitialImageDropData = BaseDropData & {
actionType: 'SET_INITIAL_IMAGE'; actionType: 'SET_INITIAL_IMAGE';
}; };
export type ControlNetDropData = BaseDropData & { export type ControlAdapterDropData = BaseDropData & {
actionType: 'SET_CONTROLNET_IMAGE'; actionType: 'SET_CONTROL_ADAPTER_IMAGE';
context: { context: {
controlNetId: string; id: string;
}; };
}; };
@ -76,8 +76,7 @@ export type AddFieldToLinearViewDropData = BaseDropData & {
export type TypesafeDroppableData = export type TypesafeDroppableData =
| CurrentImageDropData | CurrentImageDropData
| InitialImageDropData | InitialImageDropData
| ControlNetDropData | ControlAdapterDropData
| IPAdapterImageDropData
| CanvasInitialImageDropData | CanvasInitialImageDropData
| NodesImageDropData | NodesImageDropData
| AddToBatchDropData | AddToBatchDropData

View File

@ -22,9 +22,7 @@ export const isValidDrop = (
return payloadType === 'IMAGE_DTO'; return payloadType === 'IMAGE_DTO';
case 'SET_INITIAL_IMAGE': case 'SET_INITIAL_IMAGE':
return payloadType === 'IMAGE_DTO'; return payloadType === 'IMAGE_DTO';
case 'SET_CONTROLNET_IMAGE': case 'SET_CONTROL_ADAPTER_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_IP_ADAPTER_IMAGE':
return payloadType === 'IMAGE_DTO'; return payloadType === 'IMAGE_DTO';
case 'SET_CANVAS_INITIAL_IMAGE': case 'SET_CANVAS_INITIAL_IMAGE':
return payloadType === 'IMAGE_DTO'; return payloadType === 'IMAGE_DTO';

View File

@ -52,8 +52,7 @@ const DeleteBoardModal = (props: Props) => {
isInitialImage: some(allImageUsage, (i) => i.isInitialImage), isInitialImage: some(allImageUsage, (i) => i.isInitialImage),
isCanvasImage: some(allImageUsage, (i) => i.isCanvasImage), isCanvasImage: some(allImageUsage, (i) => i.isCanvasImage),
isNodesImage: some(allImageUsage, (i) => i.isNodesImage), isNodesImage: some(allImageUsage, (i) => i.isNodesImage),
isControlNetImage: some(allImageUsage, (i) => i.isControlNetImage), isControlImage: some(allImageUsage, (i) => i.isControlImage),
isIPAdapterImage: some(allImageUsage, (i) => i.isIPAdapterImage),
}; };
return { imageUsageSummary }; return { imageUsageSummary };
}), }),

View File

@ -1,5 +1,5 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets'; import { selectValidControlNets } from 'features/controlNet/store/controlAdaptersSlice';
import { omit } from 'lodash-es'; import { omit } from 'lodash-es';
import { import {
CollectInvocation, CollectInvocation,
@ -19,17 +19,14 @@ export const addControlNetToLinearGraph = (
graph: NonNullableGraph, graph: NonNullableGraph,
baseNodeId: string baseNodeId: string
): void => { ): void => {
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet; const validControlNets = selectValidControlNets(state.controlAdapters);
const validControlNets = getValidControlNets(controlNets);
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation | MetadataAccumulatorInvocation
| undefined; | undefined;
if (isControlNetEnabled && Boolean(validControlNets.length)) {
if (validControlNets.length) { if (validControlNets.length) {
// We have multiple controlnets, add ControlNet collector // Even though denoise_latents' control input is polymorphic, keep it simple and always use a collect
const controlNetIterateNode: CollectInvocation = { const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT, id: CONTROL_NET_COLLECT,
type: 'collect', type: 'collect',
@ -45,8 +42,11 @@ export const addControlNetToLinearGraph = (
}); });
validControlNets.forEach((controlNet) => { validControlNets.forEach((controlNet) => {
if (!controlNet.model) {
return;
}
const { const {
controlNetId, id,
controlImage, controlImage,
processedControlImage, processedControlImage,
beginStepPct, beginStepPct,
@ -59,14 +59,14 @@ export const addControlNetToLinearGraph = (
} = controlNet; } = controlNet;
const controlNetNode: ControlNetInvocation = { const controlNetNode: ControlNetInvocation = {
id: `control_net_${controlNetId}`, id: `control_net_${id}`,
type: 'controlnet', type: 'controlnet',
is_intermediate: true, is_intermediate: true,
begin_step_percent: beginStepPct, begin_step_percent: beginStepPct,
end_step_percent: endStepPct, end_step_percent: endStepPct,
control_mode: controlMode, control_mode: controlMode,
resize_mode: resizeMode, resize_mode: resizeMode,
control_model: model as ControlNetInvocation['control_model'], control_model: model,
control_weight: weight, control_weight: weight,
}; };
@ -116,5 +116,4 @@ export const addControlNetToLinearGraph = (
} }
}); });
} }
}
}; };

View File

@ -9,35 +9,37 @@ import {
IP_ADAPTER, IP_ADAPTER,
METADATA_ACCUMULATOR, METADATA_ACCUMULATOR,
} from './constants'; } from './constants';
import { selectValidIPAdapters } from 'features/controlNet/store/controlAdaptersSlice';
export const addIPAdapterToLinearGraph = ( export const addIPAdapterToLinearGraph = (
state: RootState, state: RootState,
graph: NonNullableGraph, graph: NonNullableGraph,
baseNodeId: string baseNodeId: string
): void => { ): void => {
const { isIPAdapterEnabled, ipAdapterInfo } = state.controlNet; const validIPAdapters = selectValidIPAdapters(state.controlAdapters);
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation | MetadataAccumulatorInvocation
| undefined; | undefined;
if (isIPAdapterEnabled && ipAdapterInfo.model) { const ipAdapter = validIPAdapters[0];
// TODO: handle multiple IP adapters once backend is capable
if (ipAdapter && ipAdapter.model) {
const { weight, model, beginStepPct, endStepPct } = ipAdapter;
const ipAdapterNode: IPAdapterInvocation = { const ipAdapterNode: IPAdapterInvocation = {
id: IP_ADAPTER, id: IP_ADAPTER,
type: 'ip_adapter', type: 'ip_adapter',
is_intermediate: true, is_intermediate: true,
weight: ipAdapterInfo.weight, weight: weight,
ip_adapter_model: { ip_adapter_model: model,
base_model: ipAdapterInfo.model?.base_model, begin_step_percent: beginStepPct,
model_name: ipAdapterInfo.model?.model_name, end_step_percent: endStepPct,
},
begin_step_percent: ipAdapterInfo.beginStepPct,
end_step_percent: ipAdapterInfo.endStepPct,
}; };
if (ipAdapterInfo.adapterImage) { if (ipAdapter.controlImage) {
ipAdapterNode.image = { ipAdapterNode.image = {
image_name: ipAdapterInfo.adapterImage, image_name: ipAdapter.controlImage,
}; };
} else { } else {
return; return;
@ -47,15 +49,12 @@ export const addIPAdapterToLinearGraph = (
if (metadataAccumulator?.ipAdapters) { if (metadataAccumulator?.ipAdapters) {
const ipAdapterField = { const ipAdapterField = {
image: { image: {
image_name: ipAdapterInfo.adapterImage, image_name: ipAdapter.controlImage,
}, },
ip_adapter_model: { weight,
base_model: ipAdapterInfo.model?.base_model, ip_adapter_model: model,
model_name: ipAdapterInfo.model?.model_name, begin_step_percent: beginStepPct,
}, end_step_percent: endStepPct,
weight: ipAdapterInfo.weight,
begin_step_percent: ipAdapterInfo.beginStepPct,
end_step_percent: ipAdapterInfo.endStepPct,
}; };
metadataAccumulator.ipAdapters.push(ipAdapterField); metadataAccumulator.ipAdapters.push(ipAdapterField);

View File

@ -1,122 +1,100 @@
import { Divider, Flex } from '@chakra-ui/react'; import { ButtonGroup, Divider, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import IAICollapse from 'common/components/IAICollapse'; import IAICollapse from 'common/components/IAICollapse';
import IAIIconButton from 'common/components/IAIIconButton';
import ControlNet from 'features/controlNet/components/ControlNet'; import ControlNet from 'features/controlNet/components/ControlNet';
import IPAdapterPanel from 'features/controlNet/components/ipAdapter/IPAdapterPanel'; import { useAddControlNet } from 'features/controlNet/hooks/useAddControlNet';
import ParamControlNetFeatureToggle from 'features/controlNet/components/parameters/ParamControlNetFeatureToggle'; import { useAddIPAdapter } from 'features/controlNet/hooks/useAddIPAdapter';
import { useAddT2IAdapter } from 'features/controlNet/hooks/useAddT2IAdapter';
import { import {
controlNetAdded, selectAllControlNets,
controlNetModelChanged, selectAllIPAdapters,
} from 'features/controlNet/store/controlNetSlice'; selectAllT2IAdapters,
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets'; } from 'features/controlNet/store/controlAdaptersSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { map } from 'lodash-es'; import { Fragment, memo } from 'react';
import { Fragment, memo, useCallback, useMemo } from 'react';
import { FaPlus } from 'react-icons/fa'; import { FaPlus } from 'react-icons/fa';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models'; import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
import { v4 as uuidv4 } from 'uuid';
const selector = createSelector( const selector = createSelector(
[stateSelector], [stateSelector],
({ controlNet }) => { ({ controlAdapters }) => {
const { controlNets, isEnabled, isIPAdapterEnabled, ipAdapterInfo } = const activeLabel: string[] = [];
controlNet;
const validControlNets = getValidControlNets(controlNets); const validIPAdapters = selectAllIPAdapters(controlAdapters);
const isIPAdapterValid = ipAdapterInfo.model && ipAdapterInfo.adapterImage; const validIPAdapterCount = validIPAdapters.length;
let activeLabel = undefined; if (validIPAdapterCount > 0) {
activeLabel.push(`${validIPAdapterCount} IP`);
if (isEnabled && validControlNets.length > 0) {
activeLabel = `${validControlNets.length} ControlNet`;
} }
if (isIPAdapterEnabled && isIPAdapterValid) { const validControlNets = selectAllControlNets(controlAdapters);
if (activeLabel) { const validControlNetCount = validControlNets.length;
activeLabel = `${activeLabel}, IP Adapter`; if (validControlNetCount > 0) {
} else { activeLabel.push(`${validControlNetCount} ControlNet`);
activeLabel = 'IP Adapter';
}
} }
return { controlNetsArray: map(controlNets), activeLabel }; const validT2IAdapters = selectAllT2IAdapters(controlAdapters);
const validT2IAdapterCount = validT2IAdapters.length;
if (validT2IAdapterCount > 0) {
activeLabel.push(`${validT2IAdapterCount} T2I`);
}
return {
controlAdapters: [
...validIPAdapters,
...validControlNets,
...validT2IAdapters,
],
activeLabel: activeLabel.join(', '),
};
}, },
defaultSelectorOptions defaultSelectorOptions
); );
const ParamControlNetCollapse = () => { const ParamControlNetCollapse = () => {
const { controlNetsArray, activeLabel } = useAppSelector(selector); const { controlAdapters, activeLabel } = useAppSelector(selector);
const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled; const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data: controlnetModels } = useGetControlNetModelsQuery(); const { data: controlnetModels } = useGetControlNetModelsQuery();
const { addControlNet } = useAddControlNet();
const firstModel = useMemo(() => { const { addIPAdapter } = useAddIPAdapter();
if (!controlnetModels || !Object.keys(controlnetModels.entities).length) { const { addT2IAdapter } = useAddT2IAdapter();
return undefined;
}
const firstModelId = Object.keys(controlnetModels.entities)[0];
if (!firstModelId) {
return undefined;
}
const firstModel = controlnetModels.entities[firstModelId];
return firstModel ? firstModel : undefined;
}, [controlnetModels]);
const handleClickedAddControlNet = useCallback(() => {
if (!firstModel) {
return;
}
const controlNetId = uuidv4();
dispatch(controlNetAdded({ controlNetId }));
dispatch(controlNetModelChanged({ controlNetId, model: firstModel }));
}, [dispatch, firstModel]);
if (isControlNetDisabled) {
return null;
}
return ( return (
<IAICollapse label="Control Adapters" activeLabel={activeLabel}> <IAICollapse label="Control Adapters" activeLabel={activeLabel}>
<Flex sx={{ flexDir: 'column', gap: 2 }}> <Flex sx={{ flexDir: 'column', gap: 2 }}>
<Flex <ButtonGroup size="sm" w="full" justifyContent="space-between">
sx={{ <IAIButton
w: '100%', leftIcon={<FaPlus />}
gap: 2, onClick={addControlNet}
p: 2,
ps: 3,
borderRadius: 'base',
alignItems: 'center',
bg: 'base.250',
_dark: {
bg: 'base.750',
},
}}
>
<ParamControlNetFeatureToggle />
<IAIIconButton
tooltip="Add ControlNet"
aria-label="Add ControlNet"
icon={<FaPlus />}
isDisabled={!firstModel}
flexGrow={1}
size="sm"
onClick={handleClickedAddControlNet}
data-testid="add controlnet" data-testid="add controlnet"
/> >
</Flex> ControlNet
{controlNetsArray.map((c, i) => ( </IAIButton>
<Fragment key={c.controlNetId}> <IAIButton
leftIcon={<FaPlus />}
onClick={addIPAdapter}
data-testid="add ip adapter"
>
IP Adapter
</IAIButton>
<IAIButton
leftIcon={<FaPlus />}
onClick={addT2IAdapter}
data-testid="add t2i adapter"
>
T2I Adapter
</IAIButton>
</ButtonGroup>
{controlAdapters.map((ca, i) => (
<Fragment key={ca.id}>
{i > 0 && <Divider />} {i > 0 && <Divider />}
<ControlNet controlNet={c} /> <ControlNet id={ca.id} />
</Fragment> </Fragment>
))} ))}
<IPAdapterPanel />
</Flex> </Flex>
</IAICollapse> </IAICollapse>
); );

View File

@ -30,17 +30,6 @@ import {
useGetControlNetModelsQuery, useGetControlNetModelsQuery,
useGetLoRAModelsQuery, useGetLoRAModelsQuery,
} from '../../../services/api/endpoints/models'; } from '../../../services/api/endpoints/models';
import {
ControlNetConfig,
IPAdapterConfig,
controlNetEnabled,
controlNetRecalled,
controlNetReset,
initialControlNet,
initialIPAdapterState,
ipAdapterRecalled,
isIPAdapterEnabledChanged,
} from '../../controlNet/store/controlNetSlice';
import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice'; import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice';
import { initialImageSelected, modelSelected } from '../store/actions'; import { initialImageSelected, modelSelected } from '../store/actions';
import { import {
@ -515,7 +504,7 @@ export const useRecallParameters = () => {
} }
dispatch( dispatch(
controlNetRecalled({ controlAdapterRecalled({
...result.controlnet, ...result.controlnet,
}) })
); );
@ -745,14 +734,14 @@ export const useRecallParameters = () => {
} }
}); });
dispatch(controlNetReset()); dispatch(controlAdaptersReset());
if (controlnets?.length) { if (controlnets?.length) {
dispatch(controlNetEnabled()); dispatch(controlNetEnabled());
} }
controlnets?.forEach((controlnet) => { controlnets?.forEach((controlnet) => {
const result = prepareControlNetMetadataItem(controlnet); const result = prepareControlNetMetadataItem(controlnet);
if (result.controlnet) { if (result.controlnet) {
dispatch(controlNetRecalled(result.controlnet)); dispatch(controlAdapterRecalled(result.controlnet));
} }
}); });

View File

@ -1,6 +1,6 @@
import { Heading, Text } from '@chakra-ui/react'; import { Heading, Text } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { controlNetReset } from 'features/controlNet/store/controlNetSlice'; import { controlAdaptersReset } from 'features/controlNet/store/controlAdaptersSlice';
import { useCallback, useEffect } from 'react'; import { useCallback, useEffect } from 'react';
import IAIButton from '../../../../common/components/IAIButton'; import IAIButton from '../../../../common/components/IAIButton';
import { import {
@ -24,7 +24,7 @@ export default function SettingsClearIntermediates() {
clearIntermediates() clearIntermediates()
.unwrap() .unwrap()
.then((response) => { .then((response) => {
dispatch(controlNetReset()); dispatch(controlAdaptersReset());
dispatch(resetCanvas()); dispatch(resetCanvas());
dispatch( dispatch(
addToast({ addToast({

View File

@ -191,13 +191,9 @@ export type GraphInvocationOutput = s['GraphInvocationOutput'];
// Post-image upload actions, controls workflows when images are uploaded // Post-image upload actions, controls workflows when images are uploaded
export type ControlNetAction = { export type ControlAdapterAction = {
type: 'SET_CONTROLNET_IMAGE'; type: 'SET_CONTROL_ADAPTER_IMAGE';
controlNetId: string; id: string;
};
export type IPAdapterAction = {
type: 'SET_IP_ADAPTER_IMAGE';
}; };
export type InitialImageAction = { export type InitialImageAction = {
@ -224,8 +220,7 @@ export type AddToBatchAction = {
}; };
export type PostUploadAction = export type PostUploadAction =
| ControlNetAction | ControlAdapterAction
| IPAdapterAction
| InitialImageAction | InitialImageAction
| NodesAction | NodesAction
| CanvasInitialImageAction | CanvasInitialImageAction