feat(ui): refactor control adapters

Control adapters logic/state/ui is now generalized to hold controlnet, ip_adapter and t2i_adapter. In the future, other control adapter types can be added.

TODO:
- Limit IP adapter to 1
- Add T2I adapter to linear graphs
- Fix autoprocess
- T2I metadata saving & recall
- Improve on control adapters UI
This commit is contained in:
psychedelicious 2023-10-05 22:40:21 +11:00
parent 9c720da021
commit 9508e0c9db
70 changed files with 1860 additions and 1236 deletions

View File

@ -50,6 +50,7 @@
"close": "Close",
"communityLabel": "Community",
"controlNet": "Controlnet",
"controlAdapter": "Control Adapter",
"ipAdapter": "IP Adapter",
"darkMode": "Dark Mode",
"discordLabel": "Discord",

View File

@ -1,5 +1,5 @@
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
import { controlNetDenylist } from 'features/controlNet/store/controlNetDenylist';
import { controlAdaptersPersistDenylist } from 'features/controlNet/store/controlAdaptersPersistDenylist';
import { dynamicPromptsPersistDenylist } from 'features/dynamicPrompts/store/dynamicPromptsPersistDenylist';
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
@ -20,7 +20,7 @@ const serializationDenylist: {
postprocessing: postprocessingPersistDenylist,
system: systemPersistDenylist,
ui: uiPersistDenylist,
controlNet: controlNetDenylist,
controlNet: controlAdaptersPersistDenylist,
dynamicPrompts: dynamicPromptsPersistDenylist,
};

View File

@ -1,5 +1,5 @@
import { initialCanvasState } from 'features/canvas/store/canvasSlice';
import { initialControlNetState } from 'features/controlNet/store/controlNetSlice';
import { initialControlAdapterState } from 'features/controlNet/store/controlAdaptersSlice';
import { initialDynamicPromptsState } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { initialGalleryState } from 'features/gallery/store/gallerySlice';
import { initialNodesState } from 'features/nodes/store/nodesSlice';
@ -25,7 +25,7 @@ const initialStates: {
config: initialConfigState,
ui: initialUIState,
hotkeys: initialHotkeysState,
controlNet: initialControlNetState,
controlAdapters: initialControlAdapterState,
dynamicPrompts: initialDynamicPromptsState,
sdxl: initialSDXLState,
};

View File

@ -1,8 +1,5 @@
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import {
controlNetReset,
ipAdapterStateReset,
} from 'features/controlNet/store/controlNetSlice';
import { controlAdaptersReset } from 'features/controlNet/store/controlAdaptersSlice';
import { getImageUsage } from 'features/deleteImageModal/store/selectors';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
@ -20,8 +17,7 @@ export const addDeleteBoardAndImagesFulfilledListener = () => {
let wasInitialImageReset = false;
let wasCanvasReset = false;
let wasNodeEditorReset = false;
let wasControlNetReset = false;
let wasIPAdapterReset = false;
let wereControlAdaptersReset = false;
const state = getState();
deleted_images.forEach((image_name) => {
@ -42,14 +38,9 @@ export const addDeleteBoardAndImagesFulfilledListener = () => {
wasNodeEditorReset = true;
}
if (imageUsage.isControlNetImage && !wasControlNetReset) {
dispatch(controlNetReset());
wasControlNetReset = true;
}
if (imageUsage.isIPAdapterImage && !wasIPAdapterReset) {
dispatch(ipAdapterStateReset());
wasIPAdapterReset = true;
if (imageUsage.isControlImage && !wereControlAdaptersReset) {
dispatch(controlAdaptersReset());
wereControlAdaptersReset = true;
}
});
},

View File

@ -1,20 +1,21 @@
import { logger } from 'app/logging/logger';
import { canvasImageToControlNet } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { controlAdapterImageChanged } from 'features/controlNet/store/controlAdaptersSlice';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';
import { canvasImageToControlAdapter } from 'features/canvas/store/actions';
export const addCanvasImageToControlNetListener = () => {
startAppListening({
actionCreator: canvasImageToControlNet,
actionCreator: canvasImageToControlAdapter,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const { id } = action.payload;
let blob;
let blob: Blob;
try {
blob = await getBaseLayerBlob(state, true);
} catch (err) {
@ -50,8 +51,8 @@ export const addCanvasImageToControlNetListener = () => {
const { image_name } = imageDTO;
dispatch(
controlNetImageChanged({
controlNetId: action.payload.controlNet.controlNetId,
controlAdapterImageChanged({
id,
controlImage: image_name,
})
);

View File

@ -1,7 +1,7 @@
import { logger } from 'app/logging/logger';
import { canvasMaskToControlNet } from 'features/canvas/store/actions';
import { canvasMaskToControlAdapter } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { controlAdapterImageChanged } from 'features/controlNet/store/controlAdaptersSlice';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
@ -9,11 +9,11 @@ import { startAppListening } from '..';
export const addCanvasMaskToControlNetListener = () => {
startAppListening({
actionCreator: canvasMaskToControlNet,
actionCreator: canvasMaskToControlAdapter,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const { id } = action.payload;
const canvasBlobsAndImageData = await getCanvasData(
state.canvas.layerState,
state.canvas.boundingBoxCoordinates,
@ -61,8 +61,8 @@ export const addCanvasMaskToControlNetListener = () => {
const { image_name } = imageDTO;
dispatch(
controlNetImageChanged({
controlNetId: action.payload.controlNet.controlNetId,
controlAdapterImageChanged({
id,
controlImage: image_name,
})
);

View File

@ -1,15 +1,24 @@
import { AnyListenerPredicate } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { controlNetImageProcessed } from 'features/controlNet/store/actions';
import { controlAdapterImageProcessed } from 'features/controlNet/store/actions';
import {
controlNetAutoConfigToggled,
controlNetImageChanged,
controlNetModelChanged,
controlNetProcessorParamsChanged,
controlNetProcessorTypeChanged,
} from 'features/controlNet/store/controlNetSlice';
controlAdapterAutoConfigToggled,
controlAdapterImageChanged,
controlAdapterModelChanged,
controlAdapterProcessorParamsChanged,
controlAdapterProcessortTypeChanged,
selectControlAdapterById,
} from 'features/controlNet/store/controlAdaptersSlice';
import { startAppListening } from '..';
import { isControlNetOrT2IAdapter } from 'features/controlNet/store/types';
type AnyControlAdapterParamChangeAction =
| ReturnType<typeof controlAdapterProcessorParamsChanged>
| ReturnType<typeof controlAdapterModelChanged>
| ReturnType<typeof controlAdapterImageChanged>
| ReturnType<typeof controlAdapterProcessortTypeChanged>
| ReturnType<typeof controlAdapterAutoConfigToggled>;
const predicate: AnyListenerPredicate<RootState> = (
action,
@ -17,35 +26,31 @@ const predicate: AnyListenerPredicate<RootState> = (
prevState
) => {
const isActionMatched =
controlNetProcessorParamsChanged.match(action) ||
controlNetModelChanged.match(action) ||
controlNetImageChanged.match(action) ||
controlNetProcessorTypeChanged.match(action) ||
controlNetAutoConfigToggled.match(action);
controlAdapterProcessorParamsChanged.match(action) ||
controlAdapterModelChanged.match(action) ||
controlAdapterImageChanged.match(action) ||
controlAdapterProcessortTypeChanged.match(action) ||
controlAdapterAutoConfigToggled.match(action);
if (!isActionMatched) {
return false;
}
if (controlNetAutoConfigToggled.match(action)) {
const { id } = action.payload;
const ca = selectControlAdapterById(prevState.controlAdapters, id);
if (!ca || !isControlNetOrT2IAdapter(ca)) {
return false;
}
if (controlAdapterAutoConfigToggled.match(action)) {
// do not process if the user just disabled auto-config
if (
prevState.controlNet.controlNets[action.payload.controlNetId]
?.shouldAutoConfig === true
) {
if (ca.shouldAutoConfig === true) {
return false;
}
}
const cn = state.controlNet.controlNets[action.payload.controlNetId];
if (!cn) {
// something is wrong, the controlNet should exist
return false;
}
const { controlImage, processorType, shouldAutoConfig } = cn;
if (controlNetModelChanged.match(action) && !shouldAutoConfig) {
const { controlImage, processorType, shouldAutoConfig } = ca;
if (controlAdapterModelChanged.match(action) && !shouldAutoConfig) {
// do not process if the action is a model change but the processor settings are dirty
return false;
}
@ -67,7 +72,7 @@ export const addControlNetAutoProcessListener = () => {
predicate,
effect: async (action, { dispatch, cancelActiveListeners, delay }) => {
const log = logger('session');
const { controlNetId } = action.payload;
const { id } = (action as AnyControlAdapterParamChangeAction).payload;
// Cancel any in-progress instances of this listener
cancelActiveListeners();
@ -75,7 +80,7 @@ export const addControlNetAutoProcessListener = () => {
// Delay before starting actual work
await delay(300);
dispatch(controlNetImageProcessed({ controlNetId }));
dispatch(controlAdapterImageProcessed({ id }));
},
});
};

View File

@ -1,11 +1,11 @@
import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import { controlNetImageProcessed } from 'features/controlNet/store/actions';
import {
clearPendingControlImages,
controlNetImageChanged,
controlNetProcessedImageChanged,
} from 'features/controlNet/store/controlNetSlice';
pendingControlImagesCleared,
controlAdapterImageChanged,
selectControlAdapterById,
controlAdapterProcessedImageChanged,
} from 'features/controlNet/store/controlAdaptersSlice';
import { SAVE_IMAGE } from 'features/nodes/util/graphBuilders/constants';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
@ -15,16 +15,18 @@ import { isImageOutput } from 'services/api/guards';
import { Graph, ImageDTO } from 'services/api/types';
import { socketInvocationComplete } from 'services/events/actions';
import { startAppListening } from '..';
import { controlAdapterImageProcessed } from 'features/controlNet/store/actions';
import { isControlNetOrT2IAdapter } from 'features/controlNet/store/types';
export const addControlNetImageProcessedListener = () => {
startAppListening({
actionCreator: controlNetImageProcessed,
actionCreator: controlAdapterImageProcessed,
effect: async (action, { dispatch, getState, take }) => {
const log = logger('session');
const { controlNetId } = action.payload;
const controlNet = getState().controlNet.controlNets[controlNetId];
const { id } = action.payload;
const ca = selectControlAdapterById(getState().controlAdapters, id);
if (!controlNet?.controlImage) {
if (!ca?.controlImage || !isControlNetOrT2IAdapter(ca)) {
log.error('Unable to process ControlNet image');
return;
}
@ -33,10 +35,10 @@ export const addControlNetImageProcessedListener = () => {
// Also we need to grab the image.
const graph: Graph = {
nodes: {
[controlNet.processorNode.id]: {
...controlNet.processorNode,
[ca.processorNode.id]: {
...ca.processorNode,
is_intermediate: true,
image: { image_name: controlNet.controlImage },
image: { image_name: ca.controlImage },
},
[SAVE_IMAGE]: {
id: SAVE_IMAGE,
@ -48,7 +50,7 @@ export const addControlNetImageProcessedListener = () => {
edges: [
{
source: {
node_id: controlNet.processorNode.id,
node_id: ca.processorNode.id,
field: 'image',
},
destination: {
@ -103,8 +105,8 @@ export const addControlNetImageProcessedListener = () => {
// Update the processed image in the store
dispatch(
controlNetProcessedImageChanged({
controlNetId,
controlAdapterProcessedImageChanged({
id,
processedControlImage: processedControlImage.image_name,
})
);
@ -126,10 +128,8 @@ export const addControlNetImageProcessedListener = () => {
duration: 15000,
})
);
dispatch(clearPendingControlImages());
dispatch(
controlNetImageChanged({ controlNetId, controlImage: null })
);
dispatch(pendingControlImagesCleared());
dispatch(controlAdapterImageChanged({ id, controlImage: null }));
return;
}
}

View File

@ -1,10 +1,10 @@
import { logger } from 'app/logging/logger';
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import {
controlNetImageChanged,
controlNetProcessedImageChanged,
ipAdapterImageChanged,
} from 'features/controlNet/store/controlNetSlice';
controlAdapterImageChanged,
controlAdapterProcessedImageChanged,
selectControlAdapterAll,
} from 'features/controlNet/store/controlAdaptersSlice';
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
@ -17,6 +17,7 @@ import { api } from 'services/api';
import { imagesApi } from 'services/api/endpoints/images';
import { imagesAdapter } from 'services/api/util';
import { startAppListening } from '..';
import { isControlNetOrT2IAdapter } from 'features/controlNet/store/types';
export const addRequestedSingleImageDeletionListener = () => {
startAppListening({
@ -90,35 +91,28 @@ export const addRequestedSingleImageDeletionListener = () => {
dispatch(clearInitialImage());
}
// reset controlNets that use the deleted images
forEach(getState().controlNet.controlNets, (controlNet) => {
// reset control adapters that use the deleted images
forEach(selectControlAdapterAll(getState().controlAdapters), (ca) => {
if (
controlNet.controlImage === imageDTO.image_name ||
controlNet.processedControlImage === imageDTO.image_name
ca.controlImage === imageDTO.image_name ||
(isControlNetOrT2IAdapter(ca) &&
ca.processedControlImage === imageDTO.image_name)
) {
dispatch(
controlNetImageChanged({
controlNetId: controlNet.controlNetId,
controlAdapterImageChanged({
id: ca.id,
controlImage: null,
})
);
dispatch(
controlNetProcessedImageChanged({
controlNetId: controlNet.controlNetId,
controlAdapterProcessedImageChanged({
id: ca.id,
processedControlImage: null,
})
);
}
});
// Remove IP Adapter Set Image if image is deleted.
if (
getState().controlNet.ipAdapterInfo.adapterImage ===
imageDTO.image_name
) {
dispatch(ipAdapterImageChanged(null));
}
// reset nodes that use the deleted images
getState().nodes.nodes.forEach((node) => {
if (!isInvocationNode(node)) {
@ -215,35 +209,28 @@ export const addRequestedMultipleImageDeletionListener = () => {
dispatch(clearInitialImage());
}
// reset controlNets that use the deleted images
forEach(getState().controlNet.controlNets, (controlNet) => {
// reset control adapters that use the deleted images
forEach(selectControlAdapterAll(getState().controlAdapters), (ca) => {
if (
controlNet.controlImage === imageDTO.image_name ||
controlNet.processedControlImage === imageDTO.image_name
ca.controlImage === imageDTO.image_name ||
(isControlNetOrT2IAdapter(ca) &&
ca.processedControlImage === imageDTO.image_name)
) {
dispatch(
controlNetImageChanged({
controlNetId: controlNet.controlNetId,
controlAdapterImageChanged({
id: ca.id,
controlImage: null,
})
);
dispatch(
controlNetProcessedImageChanged({
controlNetId: controlNet.controlNetId,
controlAdapterProcessedImageChanged({
id: ca.id,
processedControlImage: null,
})
);
}
});
// Remove IP Adapter Set Image if image is deleted.
if (
getState().controlNet.ipAdapterInfo.adapterImage ===
imageDTO.image_name
) {
dispatch(ipAdapterImageChanged(null));
}
// reset nodes that use the deleted images
getState().nodes.nodes.forEach((node) => {
if (!isInvocationNode(node)) {

View File

@ -3,11 +3,9 @@ import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import {
controlNetImageChanged,
controlNetIsEnabledChanged,
ipAdapterImageChanged,
isIPAdapterEnabledChanged,
} from 'features/controlNet/store/controlNetSlice';
controlAdapterImageChanged,
controlAdapterIsEnabledChanged,
} from 'features/controlNet/store/controlAdaptersSlice';
import {
TypesafeDraggableData,
TypesafeDroppableData,
@ -90,39 +88,26 @@ export const addImageDroppedListener = () => {
* Image dropped on ControlNet
*/
if (
overData.actionType === 'SET_CONTROLNET_IMAGE' &&
overData.actionType === 'SET_CONTROL_ADAPTER_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { controlNetId } = overData.context;
const { id } = overData.context;
dispatch(
controlNetImageChanged({
controlAdapterImageChanged({
id,
controlImage: activeData.payload.imageDTO.image_name,
controlNetId,
})
);
dispatch(
controlNetIsEnabledChanged({
controlNetId,
controlAdapterIsEnabledChanged({
id,
isEnabled: true,
})
);
return;
}
/**
* Image dropped on IP Adapter image
*/
if (
overData.actionType === 'SET_IP_ADAPTER_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
dispatch(ipAdapterImageChanged(activeData.payload.imageDTO.image_name));
dispatch(isIPAdapterEnabledChanged(true));
return;
}
/**
* Image dropped on Canvas
*/

View File

@ -18,8 +18,7 @@ export const addImageToDeleteSelectedListener = () => {
const isImageInUse =
imagesUsage.some((i) => i.isCanvasImage) ||
imagesUsage.some((i) => i.isInitialImage) ||
imagesUsage.some((i) => i.isControlNetImage) ||
imagesUsage.some((i) => i.isIPAdapterImage) ||
imagesUsage.some((i) => i.isControlImage) ||
imagesUsage.some((i) => i.isNodesImage);
if (shouldConfirmOnDelete || isImageInUse) {

View File

@ -2,11 +2,9 @@ import { UseToastOptions } from '@chakra-ui/react';
import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import {
controlNetImageChanged,
controlNetIsEnabledChanged,
ipAdapterImageChanged,
isIPAdapterEnabledChanged,
} from 'features/controlNet/store/controlNetSlice';
controlAdapterImageChanged,
controlAdapterIsEnabledChanged,
} from 'features/controlNet/store/controlAdaptersSlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { addToast } from 'features/system/store/systemSlice';
@ -87,17 +85,17 @@ export const addImageUploadedFulfilledListener = () => {
return;
}
if (postUploadAction?.type === 'SET_CONTROLNET_IMAGE') {
const { controlNetId } = postUploadAction;
if (postUploadAction?.type === 'SET_CONTROL_ADAPTER_IMAGE') {
const { id } = postUploadAction;
dispatch(
controlNetIsEnabledChanged({
controlNetId,
controlAdapterIsEnabledChanged({
id,
isEnabled: true,
})
);
dispatch(
controlNetImageChanged({
controlNetId,
controlAdapterImageChanged({
id,
controlImage: imageDTO.image_name,
})
);
@ -110,18 +108,6 @@ export const addImageUploadedFulfilledListener = () => {
return;
}
if (postUploadAction?.type === 'SET_IP_ADAPTER_IMAGE') {
dispatch(ipAdapterImageChanged(imageDTO.image_name));
dispatch(isIPAdapterEnabledChanged(true));
dispatch(
addToast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setIPAdapterImage'),
})
);
return;
}
if (postUploadAction?.type === 'SET_INITIAL_IMAGE') {
dispatch(initialImageChanged(imageDTO));
dispatch(

View File

@ -1,9 +1,9 @@
import { logger } from 'app/logging/logger';
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
import {
controlNetRemoved,
ipAdapterStateReset,
} from 'features/controlNet/store/controlNetSlice';
controlAdapterRemoved,
selectControlAdapterAll,
} from 'features/controlNet/store/controlAdaptersSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import { modelSelected } from 'features/parameters/store/actions';
import {
@ -60,24 +60,13 @@ export const addModelSelectedListener = () => {
}
// handle incompatible controlnets
const { controlNets } = state.controlNet;
forEach(controlNets, (controlNet, controlNetId) => {
if (controlNet.model?.base_model !== base_model) {
dispatch(controlNetRemoved({ controlNetId }));
selectControlAdapterAll(state.controlAdapters).forEach((ca) => {
if (ca.model?.base_model !== base_model) {
dispatch(controlAdapterRemoved({ id: ca.id }));
modelsCleared += 1;
}
});
// handle incompatible IP-Adapter
const { ipAdapterInfo } = state.controlNet;
if (
ipAdapterInfo.model &&
ipAdapterInfo.model.base_model !== base_model
) {
dispatch(ipAdapterStateReset());
modelsCleared += 1;
}
if (modelsCleared > 0) {
dispatch(
addToast(

View File

@ -1,8 +1,10 @@
import { logger } from 'app/logging/logger';
import {
controlNetRemoved,
ipAdapterModelChanged,
} from 'features/controlNet/store/controlNetSlice';
controlAdapterModelCleared,
selectAllControlNets,
selectAllIPAdapters,
selectAllT2IAdapters,
} from 'features/controlNet/store/controlAdaptersSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import {
modelChanged,
@ -19,14 +21,12 @@ import {
} from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es';
import {
ipAdapterModelsAdapter,
mainModelsAdapter,
modelsApi,
vaeModelsAdapter,
} from 'services/api/endpoints/models';
import { TypeGuardFor } from 'services/api/types';
import { startAppListening } from '..';
import { zIPAdapterModel } from 'features/nodes/types/types';
export const addModelsLoadedListener = () => {
startAppListening({
@ -221,21 +221,45 @@ export const addModelsLoadedListener = () => {
`ControlNet models loaded (${action.payload.ids.length})`
);
const controlNets = getState().controlNet.controlNets;
forEach(controlNets, (controlNet, controlNetId) => {
const isControlNetAvailable = some(
selectAllControlNets(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === controlNet?.model?.model_name &&
m?.base_model === controlNet?.model?.base_model
m?.model_name === ca?.model?.model_name &&
m?.base_model === ca?.model?.base_model
);
if (isControlNetAvailable) {
if (isModelAvailable) {
return;
}
dispatch(controlNetRemoved({ controlNetId }));
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
},
});
startAppListening({
matcher: modelsApi.endpoints.getT2IAdapterModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// ControlNet models loaded - need to remove missing ControlNets from state
const log = logger('models');
log.info(
{ models: action.payload.entities },
`ControlNet models loaded (${action.payload.ids.length})`
);
selectAllT2IAdapters(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === ca?.model?.model_name &&
m?.base_model === ca?.model?.base_model
);
if (isModelAvailable) {
return;
}
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
},
});
@ -249,38 +273,20 @@ export const addModelsLoadedListener = () => {
`IP Adapter models loaded (${action.payload.ids.length})`
);
const { model } = getState().controlNet.ipAdapterInfo;
const isModelAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === model?.model_name &&
m?.base_model === model?.base_model
);
if (isModelAvailable) {
return;
}
const firstModel = ipAdapterModelsAdapter
.getSelectors()
.selectAll(action.payload)[0];
if (!firstModel) {
dispatch(ipAdapterModelChanged(null));
}
const result = zIPAdapterModel.safeParse(firstModel);
if (!result.success) {
log.error(
{ error: result.error.format() },
'Failed to parse IP Adapter model'
selectAllIPAdapters(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === ca?.model?.model_name &&
m?.base_model === ca?.model?.base_model
);
return;
}
dispatch(ipAdapterModelChanged(result.data));
if (isModelAvailable) {
return;
}
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
},
});
startAppListening({

View File

@ -7,7 +7,7 @@ import {
} from '@reduxjs/toolkit';
import canvasReducer from 'features/canvas/store/canvasSlice';
import changeBoardModalReducer from 'features/changeBoardModal/store/slice';
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
import controlAdaptersReducer from 'features/controlNet/store/controlAdaptersSlice';
import deleteImageModalReducer from 'features/deleteImageModal/store/slice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import galleryReducer from 'features/gallery/store/gallerySlice';
@ -44,7 +44,7 @@ const allReducers = {
config: configReducer,
ui: uiReducer,
hotkeys: hotkeysReducer,
controlNet: controlNetReducer,
controlAdapters: controlAdaptersReducer,
dynamicPrompts: dynamicPromptsReducer,
deleteImageModal: deleteImageModalReducer,
changeBoardModal: changeBoardModalReducer,
@ -68,7 +68,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'postprocessing',
'system',
'ui',
'controlNet',
'controlAdapters',
'dynamicPrompts',
'lora',
'modelmanager',

View File

@ -15,7 +15,7 @@ type UseImageUploadButtonArgs = {
* @example
* const { getUploadButtonProps, getUploadInputProps, openUploader } = useImageUploadButton({
* postUploadAction: {
* type: 'SET_CONTROLNET_IMAGE',
* type: 'SET_CONTROL_ADAPTER_IMAGE',
* controlNetId: '12345',
* },
* isDisabled: getIsUploadDisabled(),

View File

@ -2,16 +2,18 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { selectControlAdapterAll } from 'features/controlNet/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlNet/store/types';
import { isInvocationNode } from 'features/nodes/types/types';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import i18n from 'i18next';
import { forEach, map } from 'lodash-es';
import { forEach } from 'lodash-es';
import { getConnectedEdges } from 'reactflow';
const selector = createSelector(
[stateSelector, activeTabNameSelector],
(
{ controlNet, generation, system, nodes, dynamicPrompts },
{ controlAdapters, generation, system, nodes, dynamicPrompts },
activeTabName
) => {
const { initialImage, model } = generation;
@ -87,30 +89,29 @@ const selector = createSelector(
reasons.push(i18n.t('parameters.invoke.noModelSelected'));
}
if (controlNet.isEnabled) {
map(controlNet.controlNets).forEach((controlNet, i) => {
if (!controlNet.isEnabled) {
return;
}
if (!controlNet.model) {
reasons.push(
i18n.t('parameters.invoke.noModelForControlNet', { index: i + 1 })
);
}
selectControlAdapterAll(controlAdapters).forEach((ca, i) => {
if (!ca.isEnabled) {
return;
}
if (!ca.model) {
reasons.push(
i18n.t('parameters.invoke.noModelForControlNet', { index: i + 1 })
);
}
if (
!controlNet.controlImage ||
(!controlNet.processedControlImage &&
controlNet.processorType !== 'none')
) {
reasons.push(
i18n.t('parameters.invoke.noControlImageForControlNet', {
index: i + 1,
})
);
}
});
}
if (
!ca.controlImage ||
(isControlNetOrT2IAdapter(ca) &&
!ca.processedControlImage &&
ca.processorType !== 'none')
) {
reasons.push(
i18n.t('parameters.invoke.noControlImageForControlNet', {
index: i + 1,
})
);
}
});
}
return { isReady: !reasons.length, reasons };

View File

@ -1,5 +1,4 @@
import { createAction } from '@reduxjs/toolkit';
import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice';
import { ImageDTO } from 'services/api/types';
export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery');
@ -22,10 +21,10 @@ export const stagingAreaImageSaved = createAction<{ imageDTO: ImageDTO }>(
'canvas/stagingAreaImageSaved'
);
export const canvasMaskToControlNet = createAction<{
controlNet: ControlNetConfig;
}>('canvas/canvasMaskToControlNet');
export const canvasMaskToControlAdapter = createAction<{ id: string }>(
'canvas/canvasMaskToControlAdapter'
);
export const canvasImageToControlNet = createAction<{
controlNet: ControlNetConfig;
}>('canvas/canvasImageToControlNet');
export const canvasImageToControlAdapter = createAction<{ id: string }>(
'canvas/canvasImageToControlAdapter'
);

View File

@ -3,11 +3,11 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { ChangeEvent, memo, useCallback } from 'react';
import { FaCopy, FaTrash } from 'react-icons/fa';
import {
ControlNetConfig,
controlNetDuplicated,
controlNetRemoved,
controlNetIsEnabledChanged,
} from '../store/controlNetSlice';
controlAdapterDuplicated,
controlAdapterIsEnabledChanged,
controlAdapterRemoved,
selectControlAdapterById,
} from '../store/controlAdaptersSlice';
import ParamControlNetModel from './parameters/ParamControlNetModel';
import ParamControlNetWeight from './parameters/ParamControlNetWeight';
@ -19,8 +19,7 @@ import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useTranslation } from 'react-i18next';
import { useToggle } from 'react-use';
import { v4 as uuidv4 } from 'uuid';
import { isControlNetOrT2IAdapter } from '../store/types';
import ControlNetImagePreview from './ControlNetImagePreview';
import ControlNetProcessorComponent from './ControlNetProcessorComponent';
import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
@ -29,14 +28,12 @@ import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
import ParamControlNetResizeMode from './parameters/ParamControlNetResizeMode';
import { useToggle } from 'react-use';
import { useControlAdapterType } from '../hooks/useControlAdapterType';
type ControlNetProps = {
controlNet: ControlNetConfig;
};
const ControlNet = (props: ControlNetProps) => {
const { controlNet } = props;
const { controlNetId } = controlNet;
const ControlNet = (props: { id: string }) => {
const { id } = props;
const controlAdapterType = useControlAdapterType(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();
@ -44,8 +41,8 @@ const ControlNet = (props: ControlNetProps) => {
const selector = createSelector(
stateSelector,
({ controlNet }) => {
const cn = controlNet.controlNets[controlNetId];
({ controlAdapters }) => {
const cn = selectControlAdapterById(controlAdapters, id);
if (!cn) {
return {
@ -54,9 +51,15 @@ const ControlNet = (props: ControlNetProps) => {
};
}
const { isEnabled, shouldAutoConfig } = cn;
const isEnabled = cn.isEnabled;
const shouldAutoConfig = isControlNetOrT2IAdapter(cn)
? cn.shouldAutoConfig
: false;
return { isEnabled, shouldAutoConfig };
return {
isEnabled,
shouldAutoConfig,
};
},
defaultSelectorOptions
);
@ -65,30 +68,29 @@ const ControlNet = (props: ControlNetProps) => {
const [isExpanded, toggleIsExpanded] = useToggle(false);
const handleDelete = useCallback(() => {
dispatch(controlNetRemoved({ controlNetId }));
}, [controlNetId, dispatch]);
dispatch(controlAdapterRemoved({ id }));
}, [id, dispatch]);
const handleDuplicate = useCallback(() => {
dispatch(
controlNetDuplicated({
sourceControlNetId: controlNetId,
newControlNetId: uuidv4(),
})
);
}, [controlNetId, dispatch]);
dispatch(controlAdapterDuplicated(id));
}, [id, dispatch]);
const handleToggleIsEnabled = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
controlNetIsEnabledChanged({
controlNetId,
controlAdapterIsEnabledChanged({
id,
isEnabled: e.target.checked,
})
);
},
[controlNetId, dispatch]
[id, dispatch]
);
if (!controlAdapterType) {
return null;
}
return (
<Flex
sx={{
@ -120,10 +122,10 @@ const ControlNet = (props: ControlNetProps) => {
transitionDuration: '0.1s',
}}
>
<ParamControlNetModel controlNet={controlNet} />
<ParamControlNetModel id={id} />
</Box>
{activeTabName === 'unifiedCanvas' && (
<ControlNetCanvasImageImports controlNet={controlNet} />
<ControlNetCanvasImageImports id={id} />
)}
<IAIIconButton
size="sm"
@ -207,8 +209,8 @@ const ControlNet = (props: ControlNetProps) => {
justifyContent: 'space-between',
}}
>
<ParamControlNetWeight controlNet={controlNet} />
<ParamControlNetBeginEnd controlNet={controlNet} />
<ParamControlNetWeight id={id} />
<ParamControlNetBeginEnd id={id} />
</Flex>
{!isExpanded && (
<Flex
@ -220,22 +222,22 @@ const ControlNet = (props: ControlNetProps) => {
aspectRatio: '1/1',
}}
>
<ControlNetImagePreview controlNet={controlNet} isSmall />
<ControlNetImagePreview id={id} isSmall />
</Flex>
)}
</Flex>
<Flex sx={{ gap: 2 }}>
<ParamControlNetControlMode controlNet={controlNet} />
<ParamControlNetResizeMode controlNet={controlNet} />
<ParamControlNetControlMode id={id} />
<ParamControlNetResizeMode id={id} />
</Flex>
<ParamControlNetProcessorSelect controlNet={controlNet} />
<ParamControlNetProcessorSelect id={id} />
</Flex>
{isExpanded && (
<>
<ControlNetImagePreview controlNet={controlNet} />
<ParamControlNetShouldAutoConfig controlNet={controlNet} />
<ControlNetProcessorComponent controlNet={controlNet} />
<ControlNetImagePreview id={id} />
<ParamControlNetShouldAutoConfig id={id} />
<ControlNetProcessorComponent id={id} />
</>
)}
</Flex>

View File

@ -23,20 +23,20 @@ import {
} from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/types';
import IAIDndImageIcon from '../../../common/components/IAIDndImageIcon';
import {
ControlNetConfig,
controlNetImageChanged,
} from '../store/controlNetSlice';
import { controlAdapterImageChanged } from '../store/controlAdaptersSlice';
import { useControlAdapterControlImage } from '../hooks/useControlAdapterControlImage';
import { useControlAdapterProcessedControlImage } from '../hooks/useControlAdapterProcessedControlImage';
import { useControlAdapterProcessorType } from '../hooks/useControlAdapterProcessorType';
type Props = {
controlNet: ControlNetConfig;
id: string;
isSmall?: boolean;
};
const selector = createSelector(
stateSelector,
({ controlNet, gallery }) => {
const { pendingControlImages } = controlNet;
({ controlAdapters, gallery }) => {
const { pendingControlImages } = controlAdapters;
const { autoAddBoardId } = gallery;
return {
@ -47,13 +47,10 @@ const selector = createSelector(
defaultSelectorOptions
);
const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
const {
controlImage: controlImageName,
processedControlImage: processedControlImageName,
processorType,
controlNetId,
} = controlNet;
const ControlNetImagePreview = ({ isSmall, id }: Props) => {
const controlImageName = useControlAdapterControlImage(id);
const processedControlImageName = useControlAdapterProcessedControlImage(id);
const processorType = useControlAdapterProcessorType(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();
@ -75,8 +72,8 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
const [addToBoard] = useAddImageToBoardMutation();
const [removeFromBoard] = useRemoveImageFromBoardMutation();
const handleResetControlImage = useCallback(() => {
dispatch(controlNetImageChanged({ controlNetId, controlImage: null }));
}, [controlNetId, dispatch]);
dispatch(controlAdapterImageChanged({ id, controlImage: null }));
}, [id, dispatch]);
const handleSaveControlImage = useCallback(async () => {
if (!processedControlImage) {
@ -133,32 +130,32 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
if (controlImage) {
return {
id: controlNetId,
id,
payloadType: 'IMAGE_DTO',
payload: { imageDTO: controlImage },
};
}
}, [controlImage, controlNetId]);
}, [controlImage, id]);
const droppableData = useMemo<TypesafeDroppableData | undefined>(
() => ({
id: controlNetId,
actionType: 'SET_CONTROLNET_IMAGE',
context: { controlNetId },
id,
actionType: 'SET_CONTROL_ADAPTER_IMAGE',
context: { id },
}),
[controlNetId]
[id]
);
const postUploadAction = useMemo<PostUploadAction>(
() => ({ type: 'SET_CONTROLNET_IMAGE', controlNetId }),
[controlNetId]
() => ({ type: 'SET_CONTROL_ADAPTER_IMAGE', id }),
[id]
);
const shouldShowProcessedImage =
controlImage &&
processedControlImage &&
!isMouseOverImage &&
!pendingControlImages.includes(controlNetId) &&
!pendingControlImages.includes(id) &&
processorType !== 'none';
return (
@ -222,7 +219,7 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
/>
</>
{pendingControlImages.includes(controlNetId) && (
{pendingControlImages.includes(id) && (
<Flex
sx={{
position: 'absolute',

View File

@ -1,26 +1,26 @@
import IAIButton from 'common/components/IAIButton';
import { memo, useCallback } from 'react';
import { ControlNetConfig } from '../store/controlNetSlice';
import { useAppDispatch } from 'app/store/storeHooks';
import { controlNetImageProcessed } from '../store/actions';
import IAIButton from 'common/components/IAIButton';
import { useIsReadyToEnqueue } from 'common/hooks/useIsReadyToEnqueue';
import { memo, useCallback } from 'react';
import { useControlAdapterControlImage } from '../hooks/useControlAdapterControlImage';
import { controlAdapterImageProcessed } from '../store/actions';
type Props = {
controlNet: ControlNetConfig;
id: string;
};
const ControlNetPreprocessButton = (props: Props) => {
const { controlNetId, controlImage } = props.controlNet;
const ControlNetPreprocessButton = ({ id }: Props) => {
const controlImage = useControlAdapterControlImage(id);
const dispatch = useAppDispatch();
const isReady = useIsReadyToEnqueue();
const handleProcess = useCallback(() => {
dispatch(
controlNetImageProcessed({
controlNetId,
controlAdapterImageProcessed({
id,
})
);
}, [controlNetId, dispatch]);
}, [id, dispatch]);
return (
<IAIButton

View File

@ -1,5 +1,4 @@
import { memo } from 'react';
import { ControlNetConfig } from '../store/controlNetSlice';
import CannyProcessor from './processors/CannyProcessor';
import ColorMapProcessor from './processors/ColorMapProcessor';
import ContentShuffleProcessor from './processors/ContentShuffleProcessor';
@ -13,18 +12,25 @@ import NormalBaeProcessor from './processors/NormalBaeProcessor';
import OpenposeProcessor from './processors/OpenposeProcessor';
import PidiProcessor from './processors/PidiProcessor';
import ZoeDepthProcessor from './processors/ZoeDepthProcessor';
import { useControlAdapterIsEnabled } from '../hooks/useControlAdapterIsEnabled';
import { useControlAdapterProcessorNode } from '../hooks/useControlAdapterProcessorNode';
export type ControlNetProcessorProps = {
controlNet: ControlNetConfig;
export type Props = {
id: string;
};
const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
const { controlNetId, isEnabled, processorNode } = props.controlNet;
const ControlNetProcessorComponent = ({ id }: Props) => {
const isEnabled = useControlAdapterIsEnabled(id);
const processorNode = useControlAdapterProcessorNode(id);
if (!processorNode) {
return null;
}
if (processorNode.type === 'canny_image_processor') {
return (
<CannyProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@ -34,7 +40,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'color_map_image_processor') {
return (
<ColorMapProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@ -44,7 +50,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'hed_image_processor') {
return (
<HedProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@ -54,7 +60,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'lineart_image_processor') {
return (
<LineartProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@ -64,7 +70,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'content_shuffle_image_processor') {
return (
<ContentShuffleProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@ -74,7 +80,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'lineart_anime_image_processor') {
return (
<LineartAnimeProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@ -84,7 +90,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'mediapipe_face_processor') {
return (
<MediapipeFaceProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@ -94,7 +100,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'midas_depth_image_processor') {
return (
<MidasDepthProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@ -104,7 +110,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'mlsd_image_processor') {
return (
<MlsdImageProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@ -114,7 +120,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'normalbae_image_processor') {
return (
<NormalBaeProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@ -124,7 +130,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'openpose_image_processor') {
return (
<OpenposeProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@ -134,7 +140,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'pidi_image_processor') {
return (
<PidiProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@ -144,7 +150,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'zoe_depth_image_processor') {
return (
<ZoeDepthProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>

View File

@ -1,24 +1,24 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import {
ControlNetConfig,
controlNetAutoConfigToggled,
} from 'features/controlNet/store/controlNetSlice';
import { controlAdapterAutoConfigToggled } from 'features/controlNet/store/controlAdaptersSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useControlAdapterIsEnabled } from '../hooks/useControlAdapterIsEnabled';
import { useControlAdapterShouldAutoConfig } from '../hooks/useControlAdapterShouldAutoConfig';
type Props = {
controlNet: ControlNetConfig;
id: string;
};
const ParamControlNetShouldAutoConfig = (props: Props) => {
const { controlNetId, isEnabled, shouldAutoConfig } = props.controlNet;
const ParamControlNetShouldAutoConfig = ({ id }: Props) => {
const isEnabled = useControlAdapterIsEnabled(id);
const shouldAutoConfig = useControlAdapterShouldAutoConfig(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleShouldAutoConfigChanged = useCallback(() => {
dispatch(controlNetAutoConfigToggled({ controlNetId }));
}, [controlNetId, dispatch]);
dispatch(controlAdapterAutoConfigToggled({ id }));
}, [id, dispatch]);
return (
<IAISwitch

View File

@ -1,16 +1,16 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { controlNetProcessorParamsChanged } from 'features/controlNet/store/controlNetSlice';
import { ControlNetProcessorNode } from 'features/controlNet/store/types';
import { controlAdapterProcessorParamsChanged } from 'features/controlNet/store/controlAdaptersSlice';
import { ControlAdapterProcessorNode } from 'features/controlNet/store/types';
import { useCallback } from 'react';
export const useProcessorNodeChanged = () => {
const dispatch = useAppDispatch();
const handleProcessorNodeChanged = useCallback(
(controlNetId: string, changes: Partial<ControlNetProcessorNode>) => {
(id: string, params: Partial<ControlAdapterProcessorNode>) => {
dispatch(
controlNetProcessorParamsChanged({
controlNetId,
changes,
controlAdapterProcessorParamsChanged({
id,
params,
})
);
},

View File

@ -2,32 +2,31 @@ import { Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import {
canvasImageToControlNet,
canvasMaskToControlNet,
canvasImageToControlAdapter,
canvasMaskToControlAdapter,
} from 'features/canvas/store/actions';
import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
import { FaImage, FaMask } from 'react-icons/fa';
import { useTranslation } from 'react-i18next';
type ControlNetCanvasImageImportsProps = {
controlNet: ControlNetConfig;
id: string;
};
const ControlNetCanvasImageImports = (
props: ControlNetCanvasImageImportsProps
) => {
const { controlNet } = props;
const { id } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleImportImageFromCanvas = useCallback(() => {
dispatch(canvasImageToControlNet({ controlNet }));
}, [controlNet, dispatch]);
dispatch(canvasImageToControlAdapter({ id }));
}, [id, dispatch]);
const handleImportMaskFromCanvas = useCallback(() => {
dispatch(canvasMaskToControlNet({ controlNet }));
}, [controlNet, dispatch]);
dispatch(canvasMaskToControlAdapter({ id }));
}, [id, dispatch]);
return (
<Flex

View File

@ -11,44 +11,49 @@ import {
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover';
import { useControlAdapterBeginEndStepPct } from 'features/controlNet/hooks/useControlAdapterBeginEndStepPct';
import { useControlAdapterIsEnabled } from 'features/controlNet/hooks/useControlAdapterIsEnabled';
import {
ControlNetConfig,
controlNetBeginStepPctChanged,
controlNetEndStepPctChanged,
} from 'features/controlNet/store/controlNetSlice';
controlAdapterBeginStepPctChanged,
controlAdapterEndStepPctChanged,
} from 'features/controlNet/store/controlAdaptersSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type Props = {
controlNet: ControlNetConfig;
id: string;
};
const formatPct = (v: number) => `${Math.round(v * 100)}%`;
const ParamControlNetBeginEnd = (props: Props) => {
const { beginStepPct, endStepPct, isEnabled, controlNetId } =
props.controlNet;
const ParamControlNetBeginEnd = ({ id }: Props) => {
const isEnabled = useControlAdapterIsEnabled(id);
const stepPcts = useControlAdapterBeginEndStepPct(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleStepPctChanged = useCallback(
(v: number[]) => {
dispatch(
controlNetBeginStepPctChanged({
controlNetId,
controlAdapterBeginStepPctChanged({
id,
beginStepPct: v[0] as number,
})
);
dispatch(
controlNetEndStepPctChanged({
controlNetId,
controlAdapterEndStepPctChanged({
id,
endStepPct: v[1] as number,
})
);
},
[controlNetId, dispatch]
[dispatch, id]
);
if (!stepPcts) {
return null;
}
return (
<IAIInformationalPopover feature="controlNetBeginEnd">
<FormControl isDisabled={!isEnabled}>
@ -56,7 +61,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
<HStack w="100%" gap={2} alignItems="center">
<RangeSlider
aria-label={['Begin Step %', 'End Step %!']}
value={[beginStepPct, endStepPct]}
value={[stepPcts.beginStepPct, stepPcts.endStepPct]}
onChange={handleStepPctChanged}
min={0}
max={1}
@ -67,10 +72,18 @@ const ParamControlNetBeginEnd = (props: Props) => {
<RangeSliderTrack>
<RangeSliderFilledTrack />
</RangeSliderTrack>
<Tooltip label={formatPct(beginStepPct)} placement="top" hasArrow>
<Tooltip
label={formatPct(stepPcts.beginStepPct)}
placement="top"
hasArrow
>
<RangeSliderThumb index={0} />
</Tooltip>
<Tooltip label={formatPct(endStepPct)} placement="top" hasArrow>
<Tooltip
label={formatPct(stepPcts.endStepPct)}
placement="top"
hasArrow
>
<RangeSliderThumb index={1} />
</Tooltip>
<RangeSliderMark

View File

@ -1,22 +1,20 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
ControlModes,
ControlNetConfig,
controlNetControlModeChanged,
} from 'features/controlNet/store/controlNetSlice';
import { useControlAdapterControlMode } from 'features/controlNet/hooks/useControlAdapterControlMode';
import { useControlAdapterIsEnabled } from 'features/controlNet/hooks/useControlAdapterIsEnabled';
import { controlAdapterControlModeChanged } from 'features/controlNet/store/controlAdaptersSlice';
import { ControlMode } from 'features/controlNet/store/types';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type ParamControlNetControlModeProps = {
controlNet: ControlNetConfig;
type Props = {
id: string;
};
export default function ParamControlNetControlMode(
props: ParamControlNetControlModeProps
) {
const { controlMode, isEnabled, controlNetId } = props.controlNet;
export default function ParamControlNetControlMode({ id }: Props) {
const isEnabled = useControlAdapterIsEnabled(id);
const controlMode = useControlAdapterControlMode(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();
@ -28,19 +26,23 @@ export default function ParamControlNetControlMode(
];
const handleControlModeChange = useCallback(
(controlMode: ControlModes) => {
dispatch(controlNetControlModeChanged({ controlNetId, controlMode }));
(controlMode: ControlMode) => {
dispatch(controlAdapterControlModeChanged({ id, controlMode }));
},
[controlNetId, dispatch]
[id, dispatch]
);
if (!controlMode) {
return null;
}
return (
<IAIInformationalPopover feature="controlNetControlMode">
<IAIMantineSelect
disabled={!isEnabled}
label={t('controlnet.controlMode')}
data={CONTROL_MODE_DATA}
value={String(controlMode)}
value={controlMode}
onChange={handleControlModeChange}
/>
</IAIInformationalPopover>

View File

@ -1,23 +1,21 @@
import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import {
ControlNetConfig,
controlNetModelChanged,
} from 'features/controlNet/store/controlNetSlice';
import { useControlAdapterIsEnabled } from 'features/controlNet/hooks/useControlAdapterIsEnabled';
import { useControlAdapterModel } from 'features/controlNet/hooks/useControlAdapterModel';
import { useControlAdapterModels } from 'features/controlNet/hooks/useControlAdapterModels';
import { useControlAdapterType } from 'features/controlNet/hooks/useControlAdapterType';
import { controlAdapterModelChanged } from 'features/controlNet/store/controlAdaptersSlice';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
type ParamControlNetModelProps = {
controlNet: ControlNetConfig;
id: string;
};
const selector = createSelector(
@ -29,23 +27,31 @@ const selector = createSelector(
defaultSelectorOptions
);
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const { controlNetId, model: controlNetModel, isEnabled } = props.controlNet;
const ParamControlNetModel = ({ id }: ParamControlNetModelProps) => {
const isEnabled = useControlAdapterIsEnabled(id);
const controlAdapterType = useControlAdapterType(id);
const model = useControlAdapterModel(id);
const dispatch = useAppDispatch();
const { mainModel } = useAppSelector(selector);
const { t } = useTranslation();
const { data: controlNetModels } = useGetControlNetModelsQuery();
const models = useControlAdapterModels(controlAdapterType);
const data = useMemo(() => {
if (!controlNetModels) {
if (!models) {
return [];
}
const data: SelectItem[] = [];
const data: {
value: string;
label: string;
group: string;
disabled: boolean;
tooltip?: string;
}[] = [];
forEach(controlNetModels.entities, (model, id) => {
models.forEach((model) => {
if (!model) {
return;
}
@ -53,7 +59,7 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const disabled = model?.base_model !== mainModel?.base_model;
data.push({
value: id,
value: model.id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
disabled,
@ -63,20 +69,24 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
});
});
data.sort((a, b) =>
// sort 'none' to the top
a.disabled ? 1 : b.disabled ? -1 : a.label.localeCompare(b.label)
);
return data;
}, [controlNetModels, mainModel?.base_model, t]);
}, [mainModel?.base_model, models, t]);
// grab the full model entity from the RTK Query cache
const selectedModel = useMemo(
() =>
controlNetModels?.entities[
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}`
] ?? null,
[
controlNetModel?.base_model,
controlNetModel?.model_name,
controlNetModels?.entities,
]
models?.find(
(m) =>
m?.id ===
`${model?.base_model}/${controlAdapterType}/${model?.model_name}`
// (m) => m?.id === `${model?.base_model}/controlnet/${model?.model_name}`
) ?? null,
[controlAdapterType, model?.base_model, model?.model_name, models]
);
const handleModelChanged = useCallback(
@ -91,13 +101,13 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
return;
}
dispatch(
controlNetModelChanged({ controlNetId, model: newControlNetModel })
);
dispatch(controlAdapterModelChanged({ id, model: newControlNetModel }));
},
[controlNetId, dispatch]
[dispatch, id]
);
console.log(model, selectedModel);
return (
<IAIMantineSearchableSelect
itemComponent={IAIMantineSelectItemWithTooltip}

View File

@ -5,19 +5,18 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSearchableSelect';
import { useControlAdapterIsEnabled } from 'features/controlNet/hooks/useControlAdapterIsEnabled';
import { useControlAdapterProcessorNode } from 'features/controlNet/hooks/useControlAdapterProcessorNode';
import { configSelector } from 'features/system/store/configSelectors';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
import { CONTROLNET_PROCESSORS } from '../../store/constants';
import {
ControlNetConfig,
controlNetProcessorTypeChanged,
} from '../../store/controlNetSlice';
import { ControlNetProcessorType } from '../../store/types';
import { useTranslation } from 'react-i18next';
import { CONTROLNET_PROCESSORS } from '../../store/constants';
import { controlAdapterProcessortTypeChanged } from '../../store/controlAdaptersSlice';
import { ControlAdapterProcessorType } from '../../store/types';
type ParamControlNetProcessorSelectProps = {
controlNet: ControlNetConfig;
type Props = {
id: string;
};
const selector = createSelector(
@ -41,7 +40,7 @@ const selector = createSelector(
.filter(
(d) =>
!config.sd.disabledControlNetProcessors.includes(
d.value as ControlNetProcessorType
d.value as ControlAdapterProcessorType
)
);
@ -50,26 +49,29 @@ const selector = createSelector(
defaultSelectorOptions
);
const ParamControlNetProcessorSelect = (
props: ParamControlNetProcessorSelectProps
) => {
const ParamControlNetProcessorSelect = ({ id }: Props) => {
const isEnabled = useControlAdapterIsEnabled(id);
const processorNode = useControlAdapterProcessorNode(id);
const dispatch = useAppDispatch();
const { controlNetId, isEnabled, processorNode } = props.controlNet;
const controlNetProcessors = useAppSelector(selector);
const { t } = useTranslation();
const handleProcessorTypeChanged = useCallback(
(v: string | null) => {
dispatch(
controlNetProcessorTypeChanged({
controlNetId,
processorType: v as ControlNetProcessorType,
controlAdapterProcessortTypeChanged({
id,
processorType: v as ControlAdapterProcessorType,
})
);
},
[controlNetId, dispatch]
[id, dispatch]
);
if (!processorNode) {
return null;
}
return (
<IAIMantineSearchableSelect
label={t('controlnet.processor')}

View File

@ -1,22 +1,20 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
ControlNetConfig,
ResizeModes,
controlNetResizeModeChanged,
} from 'features/controlNet/store/controlNetSlice';
import { useControlAdapterIsEnabled } from 'features/controlNet/hooks/useControlAdapterIsEnabled';
import { useControlAdapterResizeMode } from 'features/controlNet/hooks/useControlAdapterResizeMode';
import { controlAdapterResizeModeChanged } from 'features/controlNet/store/controlAdaptersSlice';
import { ResizeMode } from 'features/controlNet/store/types';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type ParamControlNetResizeModeProps = {
controlNet: ControlNetConfig;
type Props = {
id: string;
};
export default function ParamControlNetResizeMode(
props: ParamControlNetResizeModeProps
) {
const { resizeMode, isEnabled, controlNetId } = props.controlNet;
export default function ParamControlNetResizeMode({ id }: Props) {
const isEnabled = useControlAdapterIsEnabled(id);
const resizeMode = useControlAdapterResizeMode(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();
@ -27,19 +25,23 @@ export default function ParamControlNetResizeMode(
];
const handleResizeModeChange = useCallback(
(resizeMode: ResizeModes) => {
dispatch(controlNetResizeModeChanged({ controlNetId, resizeMode }));
(resizeMode: ResizeMode) => {
dispatch(controlAdapterResizeModeChanged({ id, resizeMode }));
},
[controlNetId, dispatch]
[id, dispatch]
);
if (!resizeMode) {
return null;
}
return (
<IAIInformationalPopover feature="controlNetResizeMode">
<IAIMantineSelect
disabled={!isEnabled}
label={t('controlnet.resizeMode')}
data={RESIZE_MODE_DATA}
value={String(resizeMode)}
value={resizeMode}
onChange={handleResizeModeChange}
/>
</IAIInformationalPopover>

View File

@ -1,28 +1,34 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover';
import IAISlider from 'common/components/IAISlider';
import {
ControlNetConfig,
controlNetWeightChanged,
} from 'features/controlNet/store/controlNetSlice';
import { useControlAdapterIsEnabled } from 'features/controlNet/hooks/useControlAdapterIsEnabled';
import { useControlAdapterWeight } from 'features/controlNet/hooks/useControlAdapterWeight';
import { controlAdapterWeightChanged } from 'features/controlNet/store/controlAdaptersSlice';
import { isNil } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type ParamControlNetWeightProps = {
controlNet: ControlNetConfig;
id: string;
};
const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
const { weight, isEnabled, controlNetId } = props.controlNet;
const ParamControlNetWeight = ({ id }: ParamControlNetWeightProps) => {
const isEnabled = useControlAdapterIsEnabled(id);
const weight = useControlAdapterWeight(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleWeightChanged = useCallback(
(weight: number) => {
dispatch(controlNetWeightChanged({ controlNetId, weight }));
dispatch(controlAdapterWeightChanged({ id, weight }));
},
[controlNetId, dispatch]
[dispatch, id]
);
if (isNil(weight)) {
// should never happen
return null;
}
return (
<IAIInformationalPopover feature="controlNetWeight">
<IAISlider

View File

@ -0,0 +1,45 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { controlAdapterAdded } from 'features/controlNet/store/controlAdaptersSlice';
import { useCallback, useMemo } from 'react';
import {
controlNetModelsAdapter,
useGetControlNetModelsQuery,
} from 'services/api/endpoints/models';
export const useAddControlNet = () => {
const dispatch = useAppDispatch();
const baseModel = useAppSelector(
(state) => state.generation.model?.base_model
);
const { data: controlNetModels } = useGetControlNetModelsQuery();
const firstControlNetModel = useMemo(
() =>
controlNetModels
? controlNetModelsAdapter
.getSelectors()
.selectAll(controlNetModels)
.filter((m) => (baseModel ? m.base_model === baseModel : true))[0]
: undefined,
[baseModel, controlNetModels]
);
const isDisabled = useMemo(
() => !firstControlNetModel,
[firstControlNetModel]
);
const addControlNet = useCallback(() => {
if (isDisabled) {
return;
}
dispatch(
controlAdapterAdded({
type: 'controlnet',
overrides: { model: firstControlNetModel },
})
);
}, [dispatch, firstControlNetModel, isDisabled]);
return {
addControlNet,
isDisabled,
};
};

View File

@ -0,0 +1,63 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import {
controlAdapterAdded,
selectAllIPAdapters,
} from 'features/controlNet/store/controlAdaptersSlice';
import { useCallback, useMemo } from 'react';
import {
ipAdapterModelsAdapter,
useGetIPAdapterModelsQuery,
} from 'services/api/endpoints/models';
const selector = createSelector(
[stateSelector],
({ controlAdapters, generation }) => {
const ipAdapterCount = selectAllIPAdapters(controlAdapters).length;
const { model } = generation;
return {
ipAdapterCount,
baseModel: model?.base_model,
};
},
defaultSelectorOptions
);
export const useAddIPAdapter = () => {
const { ipAdapterCount, baseModel } = useAppSelector(selector);
const dispatch = useAppDispatch();
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery();
const firstIPAdapterModel = useMemo(
() =>
ipAdapterModels
? ipAdapterModelsAdapter
.getSelectors()
.selectAll(ipAdapterModels)
.filter((m) => (baseModel ? m.base_model === baseModel : true))[0]
: undefined,
[baseModel, ipAdapterModels]
);
const isDisabled = useMemo(
() => !firstIPAdapterModel && ipAdapterCount === 0,
[firstIPAdapterModel, ipAdapterCount]
);
const addIPAdapter = useCallback(() => {
if (isDisabled) {
return;
}
dispatch(
controlAdapterAdded({
type: 'ip_adapter',
overrides: { model: firstIPAdapterModel },
})
);
}, [dispatch, firstIPAdapterModel, isDisabled]);
return {
addIPAdapter,
isDisabled,
};
};

View File

@ -0,0 +1,45 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { controlAdapterAdded } from 'features/controlNet/store/controlAdaptersSlice';
import { useCallback, useMemo } from 'react';
import {
t2iAdapterModelsAdapter,
useGetT2IAdapterModelsQuery,
} from 'services/api/endpoints/models';
export const useAddT2IAdapter = () => {
const dispatch = useAppDispatch();
const baseModel = useAppSelector(
(state) => state.generation.model?.base_model
);
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery();
const firstT2IAdapterModel = useMemo(
() =>
t2iAdapterModels
? t2iAdapterModelsAdapter
.getSelectors()
.selectAll(t2iAdapterModels)
.filter((m) => (baseModel ? m.base_model === baseModel : true))[0]
: undefined,
[baseModel, t2iAdapterModels]
);
const isDisabled = useMemo(
() => !firstT2IAdapterModel,
[firstT2IAdapterModel]
);
const addT2IAdapter = useCallback(() => {
if (isDisabled) {
return;
}
dispatch(
controlAdapterAdded({
type: 't2i_adapter',
overrides: { model: firstT2IAdapterModel },
})
);
}, [dispatch, firstT2IAdapterModel, isDisabled]);
return {
addT2IAdapter,
isDisabled,
};
};

View File

@ -0,0 +1,22 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
export const useControlAdapter = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) => selectControlAdapterById(controlAdapters, id),
defaultSelectorOptions
),
[id]
);
const controlAdapter = useAppSelector(selector);
return controlAdapter;
};

View File

@ -0,0 +1,30 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
export const useControlAdapterBeginEndStepPct = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) => {
const cn = selectControlAdapterById(controlAdapters, id);
return cn
? {
beginStepPct: cn.beginStepPct,
endStepPct: cn.endStepPct,
}
: undefined;
},
defaultSelectorOptions
),
[id]
);
const stepPcts = useAppSelector(selector);
return stepPcts;
};

View File

@ -0,0 +1,23 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
export const useControlAdapterControlImage = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) =>
selectControlAdapterById(controlAdapters, id)?.controlImage,
defaultSelectorOptions
),
[id]
);
const weight = useAppSelector(selector);
return weight;
};

View File

@ -0,0 +1,29 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { isControlNet } from '../store/types';
export const useControlAdapterControlMode = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) => {
const ca = selectControlAdapterById(controlAdapters, id);
if (ca && isControlNet(ca)) {
return ca.controlMode;
}
return undefined;
},
defaultSelectorOptions
),
[id]
);
const controlMode = useAppSelector(selector);
return controlMode;
};

View File

@ -0,0 +1,23 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
export const useControlAdapterIsEnabled = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) =>
selectControlAdapterById(controlAdapters, id)?.isEnabled ?? false,
defaultSelectorOptions
),
[id]
);
const isEnabled = useAppSelector(selector);
return isEnabled;
};

View File

@ -0,0 +1,23 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
export const useControlAdapterModel = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) =>
selectControlAdapterById(controlAdapters, id)?.model,
defaultSelectorOptions
),
[id]
);
const model = useAppSelector(selector);
return model;
};

View File

@ -0,0 +1,48 @@
import { useMemo } from 'react';
import {
controlNetModelsAdapter,
ipAdapterModelsAdapter,
t2iAdapterModelsAdapter,
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetT2IAdapterModelsQuery,
} from 'services/api/endpoints/models';
import { ControlAdapterType } from '../store/types';
export const useControlAdapterModels = (type?: ControlAdapterType) => {
const { data: controlNetModelsData } = useGetControlNetModelsQuery();
const controlNetModels = useMemo(
() =>
controlNetModelsData
? controlNetModelsAdapter.getSelectors().selectAll(controlNetModelsData)
: [],
[controlNetModelsData]
);
const { data: t2iAdapterModelsData } = useGetT2IAdapterModelsQuery();
const t2iAdapterModels = useMemo(
() =>
t2iAdapterModelsData
? t2iAdapterModelsAdapter.getSelectors().selectAll(t2iAdapterModelsData)
: [],
[t2iAdapterModelsData]
);
const { data: ipAdapterModelsData } = useGetIPAdapterModelsQuery();
const ipAdapterModels = useMemo(
() =>
ipAdapterModelsData
? ipAdapterModelsAdapter.getSelectors().selectAll(ipAdapterModelsData)
: [],
[ipAdapterModelsData]
);
if (type === 'controlnet') {
return controlNetModels;
}
if (type === 't2i_adapter') {
return t2iAdapterModels;
}
if (type === 'ip_adapter') {
return ipAdapterModels;
}
return;
};

View File

@ -0,0 +1,29 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { isControlNetOrT2IAdapter } from '../store/types';
export const useControlAdapterProcessedControlImage = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) => {
const ca = selectControlAdapterById(controlAdapters, id);
return ca && isControlNetOrT2IAdapter(ca)
? ca.processedControlImage
: undefined;
},
defaultSelectorOptions
),
[id]
);
const weight = useAppSelector(selector);
return weight;
};

View File

@ -0,0 +1,29 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { isControlNetOrT2IAdapter } from '../store/types';
export const useControlAdapterProcessorNode = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) => {
const ca = selectControlAdapterById(controlAdapters, id);
return ca && isControlNetOrT2IAdapter(ca)
? ca.processorNode
: undefined;
},
defaultSelectorOptions
),
[id]
);
const processorNode = useAppSelector(selector);
return processorNode;
};

View File

@ -0,0 +1,29 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { isControlNetOrT2IAdapter } from '../store/types';
export const useControlAdapterProcessorType = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) => {
const ca = selectControlAdapterById(controlAdapters, id);
return ca && isControlNetOrT2IAdapter(ca)
? ca.processorType
: undefined;
},
defaultSelectorOptions
),
[id]
);
const processorType = useAppSelector(selector);
return processorType;
};

View File

@ -0,0 +1,29 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from '../store/types';
export const useControlAdapterResizeMode = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) => {
const ca = selectControlAdapterById(controlAdapters, id);
if (ca && isControlNetOrT2IAdapter(ca)) {
return ca.resizeMode;
}
return undefined;
},
defaultSelectorOptions
),
[id]
);
const controlMode = useAppSelector(selector);
return controlMode;
};

View File

@ -0,0 +1,29 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from '../store/types';
export const useControlAdapterShouldAutoConfig = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) => {
const ca = selectControlAdapterById(controlAdapters, id);
if (ca && isControlNetOrT2IAdapter(ca)) {
return ca.shouldAutoConfig;
}
return undefined;
},
defaultSelectorOptions
),
[id]
);
const controlMode = useAppSelector(selector);
return controlMode;
};

View File

@ -0,0 +1,23 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
export const useControlAdapterType = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) =>
selectControlAdapterById(controlAdapters, id)?.type,
defaultSelectorOptions
),
[id]
);
const type = useAppSelector(selector);
return type;
};

View File

@ -0,0 +1,23 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
export const useControlAdapterWeight = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) =>
selectControlAdapterById(controlAdapters, id)?.weight,
defaultSelectorOptions
),
[id]
);
const weight = useAppSelector(selector);
return weight;
};

View File

@ -1,5 +1,5 @@
import { createAction } from '@reduxjs/toolkit';
export const controlNetImageProcessed = createAction<{
controlNetId: string;
}>('controlNet/imageProcessed');
export const controlAdapterImageProcessed = createAction<{
id: string;
}>('controlAdapters/imageProcessed');

View File

@ -1,16 +1,16 @@
import i18n from 'i18next';
import {
ControlNetProcessorType,
RequiredControlNetProcessorNode,
ControlAdapterProcessorType,
RequiredControlAdapterProcessorNode,
} from './types';
type ControlNetProcessorsDict = Record<
ControlNetProcessorType,
ControlAdapterProcessorType,
{
type: ControlNetProcessorType | 'none';
type: ControlAdapterProcessorType | 'none';
label: string;
description: string;
default: RequiredControlNetProcessorNode | { type: 'none' };
default: RequiredControlAdapterProcessorNode | { type: 'none' };
}
>;
/**
@ -240,7 +240,7 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
};
export const CONTROLNET_MODEL_DEFAULT_PROCESSORS: {
[key: string]: ControlNetProcessorType;
[key: string]: ControlAdapterProcessorType;
} = {
canny: 'canny_image_processor',
mlsd: 'mlsd_image_processor',

View File

@ -0,0 +1,8 @@
import { ControlAdaptersState } from './types';
/**
* ControlNet slice persist denylist
*/
export const controlAdaptersPersistDenylist: (keyof ControlAdaptersState)[] = [
'pendingControlImages',
];

View File

@ -0,0 +1,479 @@
import {
PayloadAction,
Update,
createEntityAdapter,
createSlice,
} from '@reduxjs/toolkit';
import {
ControlNetModelParam,
IPAdapterModelParam,
T2IAdapterModelParam,
} from 'features/parameters/types/parameterSchemas';
import { cloneDeep, merge, uniq } from 'lodash-es';
import { appSocketInvocationError } from 'services/events/actions';
import { v4 as uuidv4 } from 'uuid';
import { buildControlAdapter } from '../util/buildControlAdapter';
import { controlAdapterImageProcessed } from './actions';
import {
CONTROLNET_MODEL_DEFAULT_PROCESSORS as CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS,
CONTROLNET_PROCESSORS,
} from './constants';
import {
ControlAdapterConfig,
ControlAdapterProcessorType,
ControlAdapterType,
ControlAdaptersState,
ControlMode,
ControlNetConfig,
RequiredControlAdapterProcessorNode,
ResizeMode,
T2IAdapterConfig,
isControlNet,
isControlNetOrT2IAdapter,
isIPAdapter,
isT2IAdapter,
} from './types';
export const caAdapter = createEntityAdapter<ControlAdapterConfig>();
export const {
selectById: selectControlAdapterById,
selectAll: selectControlAdapterAll,
selectEntities: selectControlAdapterEntities,
selectIds: selectControlAdapterIds,
selectTotal: selectControlAdapterTotal,
} = caAdapter.getSelectors();
export const initialControlAdapterState: ControlAdaptersState =
caAdapter.getInitialState<{
pendingControlImages: string[];
}>({
pendingControlImages: [],
});
export const selectAllControlNets = (controlAdapters: ControlAdaptersState) =>
selectControlAdapterAll(controlAdapters).filter(isControlNet);
export const selectValidControlNets = (controlAdapters: ControlAdaptersState) =>
selectControlAdapterAll(controlAdapters)
.filter(isControlNet)
.filter(
(ca) =>
ca.isEnabled &&
ca.model &&
(Boolean(ca.processedControlImage) ||
(ca.processorType === 'none' && Boolean(ca.controlImage)))
);
export const selectAllIPAdapters = (controlAdapters: ControlAdaptersState) =>
selectControlAdapterAll(controlAdapters).filter(isIPAdapter);
export const selectValidIPAdapters = (controlAdapters: ControlAdaptersState) =>
selectControlAdapterAll(controlAdapters)
.filter(isIPAdapter)
.filter((ca) => ca.isEnabled && ca.model && Boolean(ca.controlImage));
export const selectAllT2IAdapters = (controlAdapters: ControlAdaptersState) =>
selectControlAdapterAll(controlAdapters).filter(isT2IAdapter);
export const selectValidT2IAdapters = (controlAdapters: ControlAdaptersState) =>
selectControlAdapterAll(controlAdapters)
.filter(isT2IAdapter)
.filter(
(ca) =>
ca.isEnabled &&
ca.model &&
(Boolean(ca.processedControlImage) ||
(ca.processorType === 'none' && Boolean(ca.controlImage)))
);
export const controlAdaptersSlice = createSlice({
name: 'controlAdapters',
initialState: initialControlAdapterState,
reducers: {
controlAdapterAdded: {
reducer: (
state,
action: PayloadAction<{
id: string;
type: ControlAdapterType;
overrides?: Partial<ControlAdapterConfig>;
}>
) => {
const { id, type, overrides } = action.payload;
caAdapter.addOne(state, buildControlAdapter(id, type, overrides));
},
prepare: ({
type,
overrides,
}: {
type: ControlAdapterType;
overrides?: Partial<ControlAdapterConfig>;
}) => {
return { payload: { id: uuidv4(), type, overrides } };
},
},
controlAdapterRecalled: (
state,
action: PayloadAction<ControlAdapterConfig>
) => {
const config = action.payload;
caAdapter.addOne(state, config);
},
controlAdapterDuplicated: {
reducer: (
state,
action: PayloadAction<{
id: string;
newId: string;
}>
) => {
const { id, newId } = action.payload;
const controlAdapter = selectControlAdapterById(state, id);
if (!controlAdapter) {
return;
}
const newControlAdapter = merge(cloneDeep(controlAdapter), {
id: newId,
});
caAdapter.addOne(state, newControlAdapter);
},
prepare: (id: string) => {
return { payload: { id, newId: uuidv4() } };
},
},
controlAdapterAddedFromImage: {
reducer: (
state,
action: PayloadAction<{
id: string;
type: ControlAdapterType;
controlImage: string;
}>
) => {
const { id, type, controlImage } = action.payload;
caAdapter.addOne(
state,
buildControlAdapter(id, type, { controlImage })
);
},
prepare: (payload: {
type: ControlAdapterType;
controlImage: string;
}) => {
return { payload: { ...payload, id: uuidv4() } };
},
},
controlAdapterRemoved: (state, action: PayloadAction<{ id: string }>) => {
caAdapter.removeOne(state, action.payload.id);
},
controlAdapterIsEnabledChanged: (
state,
action: PayloadAction<{ id: string; isEnabled: boolean }>
) => {
const { id, isEnabled } = action.payload;
caAdapter.updateOne(state, { id, changes: { isEnabled } });
},
controlAdapterImageChanged: (
state,
action: PayloadAction<{
id: string;
controlImage: string | null;
}>
) => {
const { id, controlImage } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn) {
return;
}
caAdapter.updateOne(state, {
id,
changes: { controlImage, processedControlImage: null },
});
if (
controlImage !== null &&
isControlNetOrT2IAdapter(cn) &&
cn.processorType !== 'none'
) {
state.pendingControlImages.push(id);
}
},
controlAdapterProcessedImageChanged: (
state,
action: PayloadAction<{
id: string;
processedControlImage: string | null;
}>
) => {
const { id, processedControlImage } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn) {
return;
}
if (!isControlNetOrT2IAdapter(cn)) {
return;
}
caAdapter.updateOne(state, {
id,
changes: {
processedControlImage,
},
});
state.pendingControlImages = state.pendingControlImages.filter(
(pendingId) => pendingId !== id
);
},
controlAdapterModelCleared: (
state,
action: PayloadAction<{ id: string }>
) => {
caAdapter.updateOne(state, {
id: action.payload.id,
changes: { model: null },
});
},
controlAdapterModelChanged: (
state,
action: PayloadAction<{
id: string;
model:
| ControlNetModelParam
| T2IAdapterModelParam
| IPAdapterModelParam;
}>
) => {
const { id, model } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn) {
return;
}
if (!isControlNetOrT2IAdapter(cn)) {
caAdapter.updateOne(state, { id, changes: { model } });
return;
}
const update: Update<ControlNetConfig | T2IAdapterConfig> = {
id,
changes: { model },
};
update.changes.processedControlImage = null;
if (cn.shouldAutoConfig) {
let processorType: ControlAdapterProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
if (model.model_name.includes(modelSubstring)) {
processorType =
CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) {
update.changes.processorType = processorType;
update.changes.processorNode = CONTROLNET_PROCESSORS[processorType]
.default as RequiredControlAdapterProcessorNode;
} else {
update.changes.processorType = 'none';
update.changes.processorNode = CONTROLNET_PROCESSORS.none
.default as RequiredControlAdapterProcessorNode;
}
}
caAdapter.updateOne(state, update);
},
controlAdapterWeightChanged: (
state,
action: PayloadAction<{ id: string; weight: number }>
) => {
const { id, weight } = action.payload;
caAdapter.updateOne(state, { id, changes: { weight } });
},
controlAdapterBeginStepPctChanged: (
state,
action: PayloadAction<{ id: string; beginStepPct: number }>
) => {
const { id, beginStepPct } = action.payload;
caAdapter.updateOne(state, { id, changes: { beginStepPct } });
},
controlAdapterEndStepPctChanged: (
state,
action: PayloadAction<{ id: string; endStepPct: number }>
) => {
const { id, endStepPct } = action.payload;
caAdapter.updateOne(state, { id, changes: { endStepPct } });
},
controlAdapterControlModeChanged: (
state,
action: PayloadAction<{ id: string; controlMode: ControlMode }>
) => {
const { id, controlMode } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn || !isControlNet(cn)) {
return;
}
caAdapter.updateOne(state, { id, changes: { controlMode } });
},
controlAdapterResizeModeChanged: (
state,
action: PayloadAction<{
id: string;
resizeMode: ResizeMode;
}>
) => {
const { id, resizeMode } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn || !isControlNetOrT2IAdapter(cn)) {
return;
}
caAdapter.updateOne(state, { id, changes: { resizeMode } });
},
controlAdapterProcessorParamsChanged: (
state,
action: PayloadAction<{
id: string;
params: Partial<RequiredControlAdapterProcessorNode>;
}>
) => {
const { id, params } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn || !isControlNetOrT2IAdapter(cn) || !cn.processorNode) {
return;
}
const processorNode = merge(cloneDeep(cn.processorNode), params);
caAdapter.updateOne(state, {
id,
changes: {
shouldAutoConfig: false,
processorNode,
},
});
},
controlAdapterProcessortTypeChanged: (
state,
action: PayloadAction<{
id: string;
processorType: ControlAdapterProcessorType;
}>
) => {
const { id, processorType } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn || !isControlNetOrT2IAdapter(cn)) {
return;
}
const processorNode = cloneDeep(
CONTROLNET_PROCESSORS[processorType].default
) as RequiredControlAdapterProcessorNode;
caAdapter.updateOne(state, {
id,
changes: {
processorType,
processedControlImage: null,
processorNode,
shouldAutoConfig: false,
},
});
},
controlAdapterAutoConfigToggled: (
state,
action: PayloadAction<{
id: string;
}>
) => {
const { id } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn || !isControlNetOrT2IAdapter(cn)) {
return;
}
const update: Update<ControlNetConfig | T2IAdapterConfig> = {
id,
changes: { shouldAutoConfig: !cn.shouldAutoConfig },
};
if (update.changes.shouldAutoConfig) {
// manage the processor for the user
let processorType: ControlAdapterProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
if (cn.model?.model_name.includes(modelSubstring)) {
processorType =
CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) {
update.changes.processorType = processorType;
update.changes.processorNode = CONTROLNET_PROCESSORS[processorType]
.default as RequiredControlAdapterProcessorNode;
} else {
update.changes.processorType = 'none';
update.changes.processorNode = CONTROLNET_PROCESSORS.none
.default as RequiredControlAdapterProcessorNode;
}
}
caAdapter.updateOne(state, update);
},
controlAdaptersReset: () => {
return cloneDeep(initialControlAdapterState);
},
pendingControlImagesCleared: (state) => {
state.pendingControlImages = [];
},
},
extraReducers: (builder) => {
builder.addCase(controlAdapterImageProcessed, (state, action) => {
const cn = selectControlAdapterById(state, action.payload.id);
if (!cn) {
return;
}
if (cn.controlImage !== null) {
state.pendingControlImages = uniq(
state.pendingControlImages.concat(action.payload.id)
);
}
});
builder.addCase(appSocketInvocationError, (state) => {
state.pendingControlImages = [];
});
},
});
export const {
controlAdapterAdded,
controlAdapterRecalled,
controlAdapterDuplicated,
controlAdapterAddedFromImage,
controlAdapterRemoved,
controlAdapterImageChanged,
controlAdapterProcessedImageChanged,
controlAdapterIsEnabledChanged,
controlAdapterModelChanged,
controlAdapterWeightChanged,
controlAdapterBeginStepPctChanged,
controlAdapterEndStepPctChanged,
controlAdapterControlModeChanged,
controlAdapterResizeModeChanged,
controlAdapterProcessorParamsChanged,
controlAdapterProcessortTypeChanged,
controlAdaptersReset,
controlAdapterAutoConfigToggled,
pendingControlImagesCleared,
controlAdapterModelCleared,
} = controlAdaptersSlice.actions;
export default controlAdaptersSlice.reducer;

View File

@ -1,8 +0,0 @@
import { ControlNetState } from './controlNetSlice';
/**
* ControlNet slice persist denylist
*/
export const controlNetDenylist: (keyof ControlNetState)[] = [
'pendingControlImages',
];

View File

@ -1,486 +0,0 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import {
ControlNetModelParam,
IPAdapterModelParam,
} from 'features/parameters/types/parameterSchemas';
import { cloneDeep, forEach } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import { components } from 'services/api/schema';
import { appSocketInvocationError } from 'services/events/actions';
import { controlNetImageProcessed } from './actions';
import {
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
CONTROLNET_PROCESSORS,
} from './constants';
import {
ControlNetProcessorType,
RequiredCannyImageProcessorInvocation,
RequiredControlNetProcessorNode,
} from './types';
export type ControlModes = NonNullable<
components['schemas']['ControlNetInvocation']['control_mode']
>;
export type ResizeModes = NonNullable<
components['schemas']['ControlNetInvocation']['resize_mode']
>;
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true,
model: null,
weight: 1,
beginStepPct: 0,
endStepPct: 1,
controlMode: 'balanced',
resizeMode: 'just_resize',
controlImage: null,
processedControlImage: null,
processorType: 'canny_image_processor',
processorNode: CONTROLNET_PROCESSORS.canny_image_processor
.default as RequiredCannyImageProcessorInvocation,
shouldAutoConfig: true,
};
export type ControlNetConfig = {
controlNetId: string;
isEnabled: boolean;
model: ControlNetModelParam | null;
weight: number;
beginStepPct: number;
endStepPct: number;
controlMode: ControlModes;
resizeMode: ResizeModes;
controlImage: string | null;
processedControlImage: string | null;
processorType: ControlNetProcessorType;
processorNode: RequiredControlNetProcessorNode;
shouldAutoConfig: boolean;
};
export type IPAdapterConfig = {
adapterImage: string | null;
model: IPAdapterModelParam | null;
weight: number;
beginStepPct: number;
endStepPct: number;
};
export type ControlNetState = {
controlNets: Record<string, ControlNetConfig>;
isEnabled: boolean;
pendingControlImages: string[];
isIPAdapterEnabled: boolean;
ipAdapterInfo: IPAdapterConfig;
};
export const initialIPAdapterState: IPAdapterConfig = {
adapterImage: null,
model: null,
weight: 1,
beginStepPct: 0,
endStepPct: 1,
};
export const initialControlNetState: ControlNetState = {
controlNets: {},
isEnabled: false,
pendingControlImages: [],
isIPAdapterEnabled: false,
ipAdapterInfo: { ...initialIPAdapterState },
};
export const controlNetSlice = createSlice({
name: 'controlNet',
initialState: initialControlNetState,
reducers: {
isControlNetEnabledToggled: (state) => {
state.isEnabled = !state.isEnabled;
},
controlNetEnabled: (state) => {
state.isEnabled = true;
},
controlNetAdded: (
state,
action: PayloadAction<{
controlNetId: string;
controlNet?: ControlNetConfig;
}>
) => {
const { controlNetId, controlNet } = action.payload;
state.controlNets[controlNetId] = {
...(controlNet ?? initialControlNet),
controlNetId,
};
},
controlNetRecalled: (state, action: PayloadAction<ControlNetConfig>) => {
const controlNet = action.payload;
state.controlNets[controlNet.controlNetId] = {
...controlNet,
};
},
controlNetDuplicated: (
state,
action: PayloadAction<{
sourceControlNetId: string;
newControlNetId: string;
}>
) => {
const { sourceControlNetId, newControlNetId } = action.payload;
const oldControlNet = state.controlNets[sourceControlNetId];
if (!oldControlNet) {
return;
}
const newControlnet = cloneDeep(oldControlNet);
newControlnet.controlNetId = newControlNetId;
state.controlNets[newControlNetId] = newControlnet;
},
controlNetAddedFromImage: (
state,
action: PayloadAction<{ controlNetId: string; controlImage: string }>
) => {
const { controlNetId, controlImage } = action.payload;
state.controlNets[controlNetId] = {
...initialControlNet,
controlNetId,
controlImage,
};
},
controlNetRemoved: (
state,
action: PayloadAction<{ controlNetId: string }>
) => {
const { controlNetId } = action.payload;
delete state.controlNets[controlNetId];
},
controlNetIsEnabledChanged: (
state,
action: PayloadAction<{ controlNetId: string; isEnabled: boolean }>
) => {
const { controlNetId, isEnabled } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.isEnabled = isEnabled;
},
controlNetImageChanged: (
state,
action: PayloadAction<{
controlNetId: string;
controlImage: string | null;
}>
) => {
const { controlNetId, controlImage } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.controlImage = controlImage;
cn.processedControlImage = null;
if (controlImage !== null && cn.processorType !== 'none') {
state.pendingControlImages.push(controlNetId);
}
},
controlNetProcessedImageChanged: (
state,
action: PayloadAction<{
controlNetId: string;
processedControlImage: string | null;
}>
) => {
const { controlNetId, processedControlImage } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.processedControlImage = processedControlImage;
state.pendingControlImages = state.pendingControlImages.filter(
(id) => id !== controlNetId
);
},
controlNetModelChanged: (
state,
action: PayloadAction<{
controlNetId: string;
model: ControlNetModelParam;
}>
) => {
const { controlNetId, model } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.model = model;
cn.processedControlImage = null;
if (cn.shouldAutoConfig) {
let processorType: ControlNetProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (model.model_name.includes(modelSubstring)) {
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) {
cn.processorType = processorType;
cn.processorNode = CONTROLNET_PROCESSORS[processorType]
.default as RequiredControlNetProcessorNode;
} else {
cn.processorType = 'none';
cn.processorNode = CONTROLNET_PROCESSORS.none
.default as RequiredControlNetProcessorNode;
}
}
},
controlNetWeightChanged: (
state,
action: PayloadAction<{ controlNetId: string; weight: number }>
) => {
const { controlNetId, weight } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.weight = weight;
},
controlNetBeginStepPctChanged: (
state,
action: PayloadAction<{ controlNetId: string; beginStepPct: number }>
) => {
const { controlNetId, beginStepPct } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.beginStepPct = beginStepPct;
},
controlNetEndStepPctChanged: (
state,
action: PayloadAction<{ controlNetId: string; endStepPct: number }>
) => {
const { controlNetId, endStepPct } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.endStepPct = endStepPct;
},
controlNetControlModeChanged: (
state,
action: PayloadAction<{ controlNetId: string; controlMode: ControlModes }>
) => {
const { controlNetId, controlMode } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.controlMode = controlMode;
},
controlNetResizeModeChanged: (
state,
action: PayloadAction<{
controlNetId: string;
resizeMode: ResizeModes;
}>
) => {
const { controlNetId, resizeMode } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.resizeMode = resizeMode;
},
controlNetProcessorParamsChanged: (
state,
action: PayloadAction<{
controlNetId: string;
changes: Omit<
Partial<RequiredControlNetProcessorNode>,
'id' | 'type' | 'is_intermediate'
>;
}>
) => {
const { controlNetId, changes } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
const processorNode = cn.processorNode;
cn.processorNode = {
...processorNode,
...changes,
};
cn.shouldAutoConfig = false;
},
controlNetProcessorTypeChanged: (
state,
action: PayloadAction<{
controlNetId: string;
processorType: ControlNetProcessorType;
}>
) => {
const { controlNetId, processorType } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.processedControlImage = null;
cn.processorType = processorType;
cn.processorNode = CONTROLNET_PROCESSORS[processorType]
.default as RequiredControlNetProcessorNode;
cn.shouldAutoConfig = false;
},
controlNetAutoConfigToggled: (
state,
action: PayloadAction<{
controlNetId: string;
}>
) => {
const { controlNetId } = action.payload;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
const newShouldAutoConfig = !cn.shouldAutoConfig;
if (newShouldAutoConfig) {
// manage the processor for the user
let processorType: ControlNetProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (cn.model?.model_name.includes(modelSubstring)) {
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) {
cn.processorType = processorType;
cn.processorNode = CONTROLNET_PROCESSORS[processorType]
.default as RequiredControlNetProcessorNode;
} else {
cn.processorType = 'none';
cn.processorNode = CONTROLNET_PROCESSORS.none
.default as RequiredControlNetProcessorNode;
}
}
cn.shouldAutoConfig = newShouldAutoConfig;
},
controlNetReset: () => {
return { ...initialControlNetState };
},
isIPAdapterEnabledChanged: (state, action: PayloadAction<boolean>) => {
state.isIPAdapterEnabled = action.payload;
},
ipAdapterRecalled: (state, action: PayloadAction<IPAdapterConfig>) => {
state.ipAdapterInfo = action.payload;
},
ipAdapterImageChanged: (state, action: PayloadAction<string | null>) => {
state.ipAdapterInfo.adapterImage = action.payload;
},
ipAdapterWeightChanged: (state, action: PayloadAction<number>) => {
state.ipAdapterInfo.weight = action.payload;
},
ipAdapterModelChanged: (
state,
action: PayloadAction<IPAdapterModelParam | null>
) => {
state.ipAdapterInfo.model = action.payload;
},
ipAdapterBeginStepPctChanged: (state, action: PayloadAction<number>) => {
state.ipAdapterInfo.beginStepPct = action.payload;
},
ipAdapterEndStepPctChanged: (state, action: PayloadAction<number>) => {
state.ipAdapterInfo.endStepPct = action.payload;
},
ipAdapterStateReset: (state) => {
state.isIPAdapterEnabled = false;
state.ipAdapterInfo = { ...initialIPAdapterState };
},
clearPendingControlImages: (state) => {
state.pendingControlImages = [];
},
},
extraReducers: (builder) => {
builder.addCase(controlNetImageProcessed, (state, action) => {
const cn = state.controlNets[action.payload.controlNetId];
if (!cn) {
return;
}
if (cn.controlImage !== null) {
state.pendingControlImages.push(action.payload.controlNetId);
}
});
builder.addCase(appSocketInvocationError, (state) => {
state.pendingControlImages = [];
});
builder.addMatcher(
imagesApi.endpoints.deleteImage.matchFulfilled,
(state, action) => {
// Preemptively remove the image from all controlnets
// TODO: doesn't the imageusage stuff do this for us?
const { image_name } = action.meta.arg.originalArgs;
forEach(state.controlNets, (c) => {
if (c.controlImage === image_name) {
c.controlImage = null;
c.processedControlImage = null;
}
if (c.processedControlImage === image_name) {
c.processedControlImage = null;
}
});
}
);
},
});
export const {
isControlNetEnabledToggled,
controlNetEnabled,
controlNetAdded,
controlNetRecalled,
controlNetDuplicated,
controlNetAddedFromImage,
controlNetRemoved,
controlNetImageChanged,
controlNetProcessedImageChanged,
controlNetIsEnabledChanged,
controlNetModelChanged,
controlNetWeightChanged,
controlNetBeginStepPctChanged,
controlNetEndStepPctChanged,
controlNetControlModeChanged,
controlNetResizeModeChanged,
controlNetProcessorParamsChanged,
controlNetProcessorTypeChanged,
controlNetReset,
controlNetAutoConfigToggled,
isIPAdapterEnabledChanged,
ipAdapterRecalled,
ipAdapterImageChanged,
ipAdapterWeightChanged,
ipAdapterModelChanged,
ipAdapterBeginStepPctChanged,
ipAdapterEndStepPctChanged,
ipAdapterStateReset,
clearPendingControlImages,
} = controlNetSlice.actions;
export default controlNetSlice.reducer;

View File

@ -1,4 +1,11 @@
import { EntityState } from '@reduxjs/toolkit';
import {
ControlNetModelParam,
IPAdapterModelParam,
T2IAdapterModelParam,
} from 'features/parameters/types/parameterSchemas';
import { isObject } from 'lodash-es';
import { components } from 'services/api/schema';
import {
CannyImageProcessorInvocation,
ColorMapImageProcessorInvocation,
@ -12,6 +19,7 @@ import {
NormalbaeImageProcessorInvocation,
OpenposeImageProcessorInvocation,
PidiImageProcessorInvocation,
T2IAdapterModelConfig,
ZoeDepthImageProcessorInvocation,
} from 'services/api/types';
import { O } from 'ts-toolbelt';
@ -19,7 +27,7 @@ import { O } from 'ts-toolbelt';
/**
* Any ControlNet processor node
*/
export type ControlNetProcessorNode =
export type ControlAdapterProcessorNode =
| CannyImageProcessorInvocation
| ColorMapImageProcessorInvocation
| ContentShuffleImageProcessorInvocation
@ -37,8 +45,8 @@ export type ControlNetProcessorNode =
/**
* Any ControlNet processor type
*/
export type ControlNetProcessorType = NonNullable<
ControlNetProcessorNode['type'] | 'none'
export type ControlAdapterProcessorType = NonNullable<
ControlAdapterProcessorNode['type'] | 'none'
>;
/**
@ -148,7 +156,7 @@ export type RequiredZoeDepthImageProcessorInvocation = O.Required<
/**
* Any ControlNet Processor node, with its parameters flagged as required
*/
export type RequiredControlNetProcessorNode = O.Required<
export type RequiredControlAdapterProcessorNode = O.Required<
| RequiredCannyImageProcessorInvocation
| RequiredColorMapImageProcessorInvocation
| RequiredContentShuffleImageProcessorInvocation
@ -356,3 +364,90 @@ export const isZoeDepthImageProcessorInvocation = (
}
return false;
};
export type ControlMode = NonNullable<
components['schemas']['ControlNetInvocation']['control_mode']
>;
export type ResizeMode = NonNullable<
components['schemas']['ControlNetInvocation']['resize_mode']
>;
export type ControlNetConfig = {
type: 'controlnet';
id: string;
isEnabled: boolean;
model: ControlNetModelParam | null;
weight: number;
beginStepPct: number;
endStepPct: number;
controlMode: ControlMode;
resizeMode: ResizeMode;
controlImage: string | null;
processedControlImage: string | null;
processorType: ControlAdapterProcessorType;
processorNode: RequiredControlAdapterProcessorNode;
shouldAutoConfig: boolean;
};
export type T2IAdapterConfig = {
type: 't2i_adapter';
id: string;
isEnabled: boolean;
model: T2IAdapterModelParam | null;
weight: number;
beginStepPct: number;
endStepPct: number;
resizeMode: ResizeMode;
controlImage: string | null;
processedControlImage: string | null;
processorType: ControlAdapterProcessorType;
processorNode: RequiredControlAdapterProcessorNode;
shouldAutoConfig: boolean;
};
export type IPAdapterConfig = {
type: 'ip_adapter';
id: string;
isEnabled: boolean;
controlImage: string | null;
model: IPAdapterModelParam | null;
weight: number;
beginStepPct: number;
endStepPct: number;
};
export type ControlAdapterConfig =
| ControlNetConfig
| IPAdapterConfig
| T2IAdapterConfig;
export type ControlAdapterType = ControlAdapterConfig['type'];
export type ControlAdaptersState = EntityState<ControlAdapterConfig> & {
pendingControlImages: string[];
};
export const isControlNet = (
controlAdapter: ControlAdapterConfig
): controlAdapter is ControlNetConfig => {
return controlAdapter.type === 'controlnet';
};
export const isIPAdapter = (
controlAdapter: ControlAdapterConfig
): controlAdapter is IPAdapterConfig => {
return controlAdapter.type === 'ip_adapter';
};
export const isT2IAdapter = (
controlAdapter: ControlAdapterConfig
): controlAdapter is T2IAdapterConfig => {
return controlAdapter.type === 't2i_adapter';
};
export const isControlNetOrT2IAdapter = (
controlAdapter: ControlAdapterConfig
): controlAdapter is ControlNetConfig | T2IAdapterConfig => {
return isControlNet(controlAdapter) || isT2IAdapter(controlAdapter);
};

View File

@ -0,0 +1,70 @@
import { cloneDeep, merge } from 'lodash-es';
import {
ControlAdapterConfig,
ControlAdapterType,
ControlNetConfig,
IPAdapterConfig,
RequiredCannyImageProcessorInvocation,
T2IAdapterConfig,
} from '../store/types';
import { CONTROLNET_PROCESSORS } from '../store/constants';
export const initialControlNet: Omit<ControlNetConfig, 'id'> = {
type: 'controlnet',
isEnabled: true,
model: null,
weight: 1,
beginStepPct: 0,
endStepPct: 1,
controlMode: 'balanced',
resizeMode: 'just_resize',
controlImage: null,
processedControlImage: null,
processorType: 'canny_image_processor',
processorNode: CONTROLNET_PROCESSORS.canny_image_processor
.default as RequiredCannyImageProcessorInvocation,
shouldAutoConfig: true,
};
export const initialT2IAdapter: Omit<T2IAdapterConfig, 'id'> = {
type: 't2i_adapter',
isEnabled: true,
model: null,
weight: 1,
beginStepPct: 0,
endStepPct: 1,
resizeMode: 'just_resize',
controlImage: null,
processedControlImage: null,
processorType: 'canny_image_processor',
processorNode: CONTROLNET_PROCESSORS.canny_image_processor
.default as RequiredCannyImageProcessorInvocation,
shouldAutoConfig: true,
};
export const initialIPAdapter: Omit<IPAdapterConfig, 'id'> = {
type: 'ip_adapter',
isEnabled: true,
controlImage: null,
model: null,
weight: 1,
beginStepPct: 0,
endStepPct: 1,
};
export const buildControlAdapter = (
id: string,
type: ControlAdapterType,
overrides: Partial<ControlAdapterConfig> = {}
): ControlAdapterConfig => {
switch (type) {
case 'controlnet':
return merge(cloneDeep(initialControlNet), { id, ...overrides });
case 't2i_adapter':
return merge(cloneDeep(initialT2IAdapter), { id, ...overrides });
case 'ip_adapter':
return merge(cloneDeep(initialIPAdapter), { id, ...overrides });
default:
throw new Error(`Unknown control adapter type: ${type}`);
}
};

View File

@ -1,15 +0,0 @@
import { filter } from 'lodash-es';
import { ControlNetConfig } from '../store/controlNetSlice';
export const getValidControlNets = (
controlNets: Record<string, ControlNetConfig>
) => {
const validControlNets = filter(
controlNets,
(c) =>
c.isEnabled &&
(Boolean(c.processedControlImage) ||
(c.processorType === 'none' && Boolean(c.controlImage)))
);
return validControlNets;
};

View File

@ -41,8 +41,7 @@ const selector = createSelector(
isInitialImage: some(allImageUsage, (i) => i.isInitialImage),
isCanvasImage: some(allImageUsage, (i) => i.isCanvasImage),
isNodesImage: some(allImageUsage, (i) => i.isNodesImage),
isControlNetImage: some(allImageUsage, (i) => i.isControlNetImage),
isIPAdapterImage: some(allImageUsage, (i) => i.isIPAdapterImage),
isControlImage: some(allImageUsage, (i) => i.isControlImage),
};
return {

View File

@ -35,12 +35,9 @@ const ImageUsageMessage = (props: Props) => {
{imageUsage.isCanvasImage && (
<ListItem>{t('common.unifiedCanvas')}</ListItem>
)}
{imageUsage.isControlNetImage && (
{imageUsage.isControlImage && (
<ListItem>{t('common.controlNet')}</ListItem>
)}
{imageUsage.isIPAdapterImage && (
<ListItem>{t('common.ipAdapter')}</ListItem>
)}
{imageUsage.isNodesImage && (
<ListItem>{t('common.nodeEditor')}</ListItem>
)}

View File

@ -4,9 +4,11 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { isInvocationNode } from 'features/nodes/types/types';
import { some } from 'lodash-es';
import { ImageUsage } from './types';
import { selectControlAdapterAll } from 'features/controlNet/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlNet/store/types';
export const getImageUsage = (state: RootState, image_name: string) => {
const { generation, canvas, nodes, controlNet } = state;
const { generation, canvas, nodes, controlAdapters } = state;
const isInitialImage = generation.initialImage?.imageName === image_name;
const isCanvasImage = canvas.layerState.objects.some(
@ -21,20 +23,17 @@ export const getImageUsage = (state: RootState, image_name: string) => {
);
});
const isControlNetImage = some(
controlNet.controlNets,
(c) =>
c.controlImage === image_name || c.processedControlImage === image_name
const isControlImage = selectControlAdapterAll(controlAdapters).some(
(ca) =>
ca.controlImage === image_name ||
(isControlNetOrT2IAdapter(ca) && ca.processedControlImage === image_name)
);
const isIPAdapterImage = controlNet.ipAdapterInfo.adapterImage === image_name;
const imageUsage: ImageUsage = {
isInitialImage,
isCanvasImage,
isNodesImage,
isControlNetImage,
isIPAdapterImage,
isControlImage,
};
return imageUsage;

View File

@ -9,6 +9,5 @@ export type ImageUsage = {
isInitialImage: boolean;
isCanvasImage: boolean;
isNodesImage: boolean;
isControlNetImage: boolean;
isIPAdapterImage: boolean;
isControlImage: boolean;
};

View File

@ -28,10 +28,10 @@ export type InitialImageDropData = BaseDropData & {
actionType: 'SET_INITIAL_IMAGE';
};
export type ControlNetDropData = BaseDropData & {
actionType: 'SET_CONTROLNET_IMAGE';
export type ControlAdapterDropData = BaseDropData & {
actionType: 'SET_CONTROL_ADAPTER_IMAGE';
context: {
controlNetId: string;
id: string;
};
};
@ -76,8 +76,7 @@ export type AddFieldToLinearViewDropData = BaseDropData & {
export type TypesafeDroppableData =
| CurrentImageDropData
| InitialImageDropData
| ControlNetDropData
| IPAdapterImageDropData
| ControlAdapterDropData
| CanvasInitialImageDropData
| NodesImageDropData
| AddToBatchDropData

View File

@ -22,9 +22,7 @@ export const isValidDrop = (
return payloadType === 'IMAGE_DTO';
case 'SET_INITIAL_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_CONTROLNET_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_IP_ADAPTER_IMAGE':
case 'SET_CONTROL_ADAPTER_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_CANVAS_INITIAL_IMAGE':
return payloadType === 'IMAGE_DTO';

View File

@ -52,8 +52,7 @@ const DeleteBoardModal = (props: Props) => {
isInitialImage: some(allImageUsage, (i) => i.isInitialImage),
isCanvasImage: some(allImageUsage, (i) => i.isCanvasImage),
isNodesImage: some(allImageUsage, (i) => i.isNodesImage),
isControlNetImage: some(allImageUsage, (i) => i.isControlNetImage),
isIPAdapterImage: some(allImageUsage, (i) => i.isIPAdapterImage),
isControlImage: some(allImageUsage, (i) => i.isControlImage),
};
return { imageUsageSummary };
}),

View File

@ -1,5 +1,5 @@
import { RootState } from 'app/store/store';
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
import { selectValidControlNets } from 'features/controlNet/store/controlAdaptersSlice';
import { omit } from 'lodash-es';
import {
CollectInvocation,
@ -19,102 +19,101 @@ export const addControlNetToLinearGraph = (
graph: NonNullableGraph,
baseNodeId: string
): void => {
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
const validControlNets = getValidControlNets(controlNets);
const validControlNets = selectValidControlNets(state.controlAdapters);
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (isControlNetEnabled && Boolean(validControlNets.length)) {
if (validControlNets.length) {
// We have multiple controlnets, add ControlNet collector
const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT,
type: 'collect',
if (validControlNets.length) {
// Even though denoise_latents' control input is polymorphic, keep it simple and always use a collect
const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT,
type: 'collect',
is_intermediate: true,
};
graph.nodes[CONTROL_NET_COLLECT] = controlNetIterateNode;
graph.edges.push({
source: { node_id: CONTROL_NET_COLLECT, field: 'collection' },
destination: {
node_id: baseNodeId,
field: 'control',
},
});
validControlNets.forEach((controlNet) => {
if (!controlNet.model) {
return;
}
const {
id,
controlImage,
processedControlImage,
beginStepPct,
endStepPct,
controlMode,
resizeMode,
model,
processorType,
weight,
} = controlNet;
const controlNetNode: ControlNetInvocation = {
id: `control_net_${id}`,
type: 'controlnet',
is_intermediate: true,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
control_mode: controlMode,
resize_mode: resizeMode,
control_model: model,
control_weight: weight,
};
graph.nodes[CONTROL_NET_COLLECT] = controlNetIterateNode;
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
controlNetNode.image = {
image_name: processedControlImage,
};
} else if (controlImage) {
// The control image is preprocessed
controlNetNode.image = {
image_name: controlImage,
};
} else {
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
return;
}
graph.nodes[controlNetNode.id] = controlNetNode as ControlNetInvocation;
if (metadataAccumulator?.controlnets) {
// metadata accumulator only needs a control field - not the whole node
// extract what we need and add to the accumulator
const controlField = omit(controlNetNode, [
'id',
'type',
]) as ControlField;
metadataAccumulator.controlnets.push(controlField);
}
graph.edges.push({
source: { node_id: CONTROL_NET_COLLECT, field: 'collection' },
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: baseNodeId,
field: 'control',
node_id: CONTROL_NET_COLLECT,
field: 'item',
},
});
validControlNets.forEach((controlNet) => {
const {
controlNetId,
controlImage,
processedControlImage,
beginStepPct,
endStepPct,
controlMode,
resizeMode,
model,
processorType,
weight,
} = controlNet;
const controlNetNode: ControlNetInvocation = {
id: `control_net_${controlNetId}`,
type: 'controlnet',
is_intermediate: true,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
control_mode: controlMode,
resize_mode: resizeMode,
control_model: model as ControlNetInvocation['control_model'],
control_weight: weight,
};
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
controlNetNode.image = {
image_name: processedControlImage,
};
} else if (controlImage) {
// The control image is preprocessed
controlNetNode.image = {
image_name: controlImage,
};
} else {
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
return;
}
graph.nodes[controlNetNode.id] = controlNetNode as ControlNetInvocation;
if (metadataAccumulator?.controlnets) {
// metadata accumulator only needs a control field - not the whole node
// extract what we need and add to the accumulator
const controlField = omit(controlNetNode, [
'id',
'type',
]) as ControlField;
metadataAccumulator.controlnets.push(controlField);
}
if (CANVAS_COHERENCE_DENOISE_LATENTS in graph.nodes) {
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: CONTROL_NET_COLLECT,
field: 'item',
node_id: CANVAS_COHERENCE_DENOISE_LATENTS,
field: 'control',
},
});
if (CANVAS_COHERENCE_DENOISE_LATENTS in graph.nodes) {
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: CANVAS_COHERENCE_DENOISE_LATENTS,
field: 'control',
},
});
}
});
}
}
});
}
};

View File

@ -9,35 +9,37 @@ import {
IP_ADAPTER,
METADATA_ACCUMULATOR,
} from './constants';
import { selectValidIPAdapters } from 'features/controlNet/store/controlAdaptersSlice';
export const addIPAdapterToLinearGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): void => {
const { isIPAdapterEnabled, ipAdapterInfo } = state.controlNet;
const validIPAdapters = selectValidIPAdapters(state.controlAdapters);
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (isIPAdapterEnabled && ipAdapterInfo.model) {
const ipAdapter = validIPAdapters[0];
// TODO: handle multiple IP adapters once backend is capable
if (ipAdapter && ipAdapter.model) {
const { weight, model, beginStepPct, endStepPct } = ipAdapter;
const ipAdapterNode: IPAdapterInvocation = {
id: IP_ADAPTER,
type: 'ip_adapter',
is_intermediate: true,
weight: ipAdapterInfo.weight,
ip_adapter_model: {
base_model: ipAdapterInfo.model?.base_model,
model_name: ipAdapterInfo.model?.model_name,
},
begin_step_percent: ipAdapterInfo.beginStepPct,
end_step_percent: ipAdapterInfo.endStepPct,
weight: weight,
ip_adapter_model: model,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
};
if (ipAdapterInfo.adapterImage) {
if (ipAdapter.controlImage) {
ipAdapterNode.image = {
image_name: ipAdapterInfo.adapterImage,
image_name: ipAdapter.controlImage,
};
} else {
return;
@ -47,15 +49,12 @@ export const addIPAdapterToLinearGraph = (
if (metadataAccumulator?.ipAdapters) {
const ipAdapterField = {
image: {
image_name: ipAdapterInfo.adapterImage,
image_name: ipAdapter.controlImage,
},
ip_adapter_model: {
base_model: ipAdapterInfo.model?.base_model,
model_name: ipAdapterInfo.model?.model_name,
},
weight: ipAdapterInfo.weight,
begin_step_percent: ipAdapterInfo.beginStepPct,
end_step_percent: ipAdapterInfo.endStepPct,
weight,
ip_adapter_model: model,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
};
metadataAccumulator.ipAdapters.push(ipAdapterField);

View File

@ -1,122 +1,100 @@
import { Divider, Flex } from '@chakra-ui/react';
import { ButtonGroup, Divider, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import IAICollapse from 'common/components/IAICollapse';
import IAIIconButton from 'common/components/IAIIconButton';
import ControlNet from 'features/controlNet/components/ControlNet';
import IPAdapterPanel from 'features/controlNet/components/ipAdapter/IPAdapterPanel';
import ParamControlNetFeatureToggle from 'features/controlNet/components/parameters/ParamControlNetFeatureToggle';
import { useAddControlNet } from 'features/controlNet/hooks/useAddControlNet';
import { useAddIPAdapter } from 'features/controlNet/hooks/useAddIPAdapter';
import { useAddT2IAdapter } from 'features/controlNet/hooks/useAddT2IAdapter';
import {
controlNetAdded,
controlNetModelChanged,
} from 'features/controlNet/store/controlNetSlice';
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
selectAllControlNets,
selectAllIPAdapters,
selectAllT2IAdapters,
} from 'features/controlNet/store/controlAdaptersSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { map } from 'lodash-es';
import { Fragment, memo, useCallback, useMemo } from 'react';
import { Fragment, memo } from 'react';
import { FaPlus } from 'react-icons/fa';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
import { v4 as uuidv4 } from 'uuid';
const selector = createSelector(
[stateSelector],
({ controlNet }) => {
const { controlNets, isEnabled, isIPAdapterEnabled, ipAdapterInfo } =
controlNet;
({ controlAdapters }) => {
const activeLabel: string[] = [];
const validControlNets = getValidControlNets(controlNets);
const isIPAdapterValid = ipAdapterInfo.model && ipAdapterInfo.adapterImage;
let activeLabel = undefined;
if (isEnabled && validControlNets.length > 0) {
activeLabel = `${validControlNets.length} ControlNet`;
const validIPAdapters = selectAllIPAdapters(controlAdapters);
const validIPAdapterCount = validIPAdapters.length;
if (validIPAdapterCount > 0) {
activeLabel.push(`${validIPAdapterCount} IP`);
}
if (isIPAdapterEnabled && isIPAdapterValid) {
if (activeLabel) {
activeLabel = `${activeLabel}, IP Adapter`;
} else {
activeLabel = 'IP Adapter';
}
const validControlNets = selectAllControlNets(controlAdapters);
const validControlNetCount = validControlNets.length;
if (validControlNetCount > 0) {
activeLabel.push(`${validControlNetCount} ControlNet`);
}
return { controlNetsArray: map(controlNets), activeLabel };
const validT2IAdapters = selectAllT2IAdapters(controlAdapters);
const validT2IAdapterCount = validT2IAdapters.length;
if (validT2IAdapterCount > 0) {
activeLabel.push(`${validT2IAdapterCount} T2I`);
}
return {
controlAdapters: [
...validIPAdapters,
...validControlNets,
...validT2IAdapters,
],
activeLabel: activeLabel.join(', '),
};
},
defaultSelectorOptions
);
const ParamControlNetCollapse = () => {
const { controlNetsArray, activeLabel } = useAppSelector(selector);
const { controlAdapters, activeLabel } = useAppSelector(selector);
const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
const dispatch = useAppDispatch();
const { data: controlnetModels } = useGetControlNetModelsQuery();
const firstModel = useMemo(() => {
if (!controlnetModels || !Object.keys(controlnetModels.entities).length) {
return undefined;
}
const firstModelId = Object.keys(controlnetModels.entities)[0];
if (!firstModelId) {
return undefined;
}
const firstModel = controlnetModels.entities[firstModelId];
return firstModel ? firstModel : undefined;
}, [controlnetModels]);
const handleClickedAddControlNet = useCallback(() => {
if (!firstModel) {
return;
}
const controlNetId = uuidv4();
dispatch(controlNetAdded({ controlNetId }));
dispatch(controlNetModelChanged({ controlNetId, model: firstModel }));
}, [dispatch, firstModel]);
if (isControlNetDisabled) {
return null;
}
const { addControlNet } = useAddControlNet();
const { addIPAdapter } = useAddIPAdapter();
const { addT2IAdapter } = useAddT2IAdapter();
return (
<IAICollapse label="Control Adapters" activeLabel={activeLabel}>
<Flex sx={{ flexDir: 'column', gap: 2 }}>
<Flex
sx={{
w: '100%',
gap: 2,
p: 2,
ps: 3,
borderRadius: 'base',
alignItems: 'center',
bg: 'base.250',
_dark: {
bg: 'base.750',
},
}}
>
<ParamControlNetFeatureToggle />
<IAIIconButton
tooltip="Add ControlNet"
aria-label="Add ControlNet"
icon={<FaPlus />}
isDisabled={!firstModel}
flexGrow={1}
size="sm"
onClick={handleClickedAddControlNet}
<ButtonGroup size="sm" w="full" justifyContent="space-between">
<IAIButton
leftIcon={<FaPlus />}
onClick={addControlNet}
data-testid="add controlnet"
/>
</Flex>
{controlNetsArray.map((c, i) => (
<Fragment key={c.controlNetId}>
>
ControlNet
</IAIButton>
<IAIButton
leftIcon={<FaPlus />}
onClick={addIPAdapter}
data-testid="add ip adapter"
>
IP Adapter
</IAIButton>
<IAIButton
leftIcon={<FaPlus />}
onClick={addT2IAdapter}
data-testid="add t2i adapter"
>
T2I Adapter
</IAIButton>
</ButtonGroup>
{controlAdapters.map((ca, i) => (
<Fragment key={ca.id}>
{i > 0 && <Divider />}
<ControlNet controlNet={c} />
<ControlNet id={ca.id} />
</Fragment>
))}
<IPAdapterPanel />
</Flex>
</IAICollapse>
);

View File

@ -30,17 +30,6 @@ import {
useGetControlNetModelsQuery,
useGetLoRAModelsQuery,
} from '../../../services/api/endpoints/models';
import {
ControlNetConfig,
IPAdapterConfig,
controlNetEnabled,
controlNetRecalled,
controlNetReset,
initialControlNet,
initialIPAdapterState,
ipAdapterRecalled,
isIPAdapterEnabledChanged,
} from '../../controlNet/store/controlNetSlice';
import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice';
import { initialImageSelected, modelSelected } from '../store/actions';
import {
@ -515,7 +504,7 @@ export const useRecallParameters = () => {
}
dispatch(
controlNetRecalled({
controlAdapterRecalled({
...result.controlnet,
})
);
@ -745,14 +734,14 @@ export const useRecallParameters = () => {
}
});
dispatch(controlNetReset());
dispatch(controlAdaptersReset());
if (controlnets?.length) {
dispatch(controlNetEnabled());
}
controlnets?.forEach((controlnet) => {
const result = prepareControlNetMetadataItem(controlnet);
if (result.controlnet) {
dispatch(controlNetRecalled(result.controlnet));
dispatch(controlAdapterRecalled(result.controlnet));
}
});

View File

@ -1,6 +1,6 @@
import { Heading, Text } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
import { controlAdaptersReset } from 'features/controlNet/store/controlAdaptersSlice';
import { useCallback, useEffect } from 'react';
import IAIButton from '../../../../common/components/IAIButton';
import {
@ -24,7 +24,7 @@ export default function SettingsClearIntermediates() {
clearIntermediates()
.unwrap()
.then((response) => {
dispatch(controlNetReset());
dispatch(controlAdaptersReset());
dispatch(resetCanvas());
dispatch(
addToast({

View File

@ -191,13 +191,9 @@ export type GraphInvocationOutput = s['GraphInvocationOutput'];
// Post-image upload actions, controls workflows when images are uploaded
export type ControlNetAction = {
type: 'SET_CONTROLNET_IMAGE';
controlNetId: string;
};
export type IPAdapterAction = {
type: 'SET_IP_ADAPTER_IMAGE';
export type ControlAdapterAction = {
type: 'SET_CONTROL_ADAPTER_IMAGE';
id: string;
};
export type InitialImageAction = {
@ -224,8 +220,7 @@ export type AddToBatchAction = {
};
export type PostUploadAction =
| ControlNetAction
| IPAdapterAction
| ControlAdapterAction
| InitialImageAction
| NodesAction
| CanvasInitialImageAction