feat(ui): auto-process for control layer CAs

This commit is contained in:
psychedelicious 2024-05-02 08:18:34 +10:00 committed by Kent Keirsey
parent 905baf2787
commit c96b98fc9e
10 changed files with 495 additions and 242 deletions

View File

@ -16,6 +16,7 @@ import { addCanvasMaskSavedToGalleryListener } from 'app/store/middleware/listen
import { addCanvasMaskToControlNetListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasMaskToControlNet'; import { addCanvasMaskToControlNetListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasMaskToControlNet';
import { addCanvasMergedListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasMerged'; import { addCanvasMergedListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasMerged';
import { addCanvasSavedToGalleryListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery'; import { addCanvasSavedToGalleryListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery';
import { addControlAdapterPreprocessor } from 'app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor';
import { addControlNetAutoProcessListener } from 'app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess'; import { addControlNetAutoProcessListener } from 'app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess';
import { addControlNetImageProcessedListener } from 'app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed'; import { addControlNetImageProcessedListener } from 'app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed';
import { addEnqueueRequestedCanvasListener } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedCanvas'; import { addEnqueueRequestedCanvasListener } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedCanvas';
@ -157,3 +158,4 @@ addUpscaleRequestedListener(startAppListening);
addDynamicPromptsListener(startAppListening); addDynamicPromptsListener(startAppListening);
addSetDefaultSettingsListener(startAppListening); addSetDefaultSettingsListener(startAppListening);
addControlAdapterPreprocessor(startAppListening);

View File

@ -0,0 +1,147 @@
import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import {
caLayerImageChanged,
caLayerIsProcessingImageChanged,
caLayerModelChanged,
caLayerProcessedImageChanged,
caLayerProcessorConfigChanged,
isControlAdapterLayer,
} from 'features/controlLayers/store/controlLayersSlice';
import { CONTROLNET_PROCESSORS } from 'features/controlLayers/util/controlAdapters';
import { isImageOutput } from 'features/nodes/types/common';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig, ImageDTO } from 'services/api/types';
import { socketInvocationComplete } from 'services/events/actions';
const matcher = isAnyOf(caLayerImageChanged, caLayerProcessorConfigChanged, caLayerModelChanged);
const DEBOUNCE_MS = 300;
const log = logger('session');
export const addControlAdapterPreprocessor = (startAppListening: AppStartListening) => {
startAppListening({
matcher,
effect: async (action, { dispatch, getState, cancelActiveListeners, delay, take }) => {
const { layerId } = action.payload;
const precheckLayer = getState()
.controlLayers.present.layers.filter(isControlAdapterLayer)
.find((l) => l.id === layerId);
// Conditions to bail
if (
// Layer doesn't exist
!precheckLayer ||
// Layer doesn't have an image
!precheckLayer.controlAdapter.image ||
// Layer doesn't have a processor config
!precheckLayer.controlAdapter.processorConfig ||
// Layer is already processing an image
precheckLayer.controlAdapter.isProcessingImage
) {
return;
}
// Cancel any in-progress instances of this listener
cancelActiveListeners();
log.trace('Control Layer CA auto-process triggered');
// Delay before starting actual work
await delay(DEBOUNCE_MS);
dispatch(caLayerIsProcessingImageChanged({ layerId, isProcessingImage: true }));
// Double-check that we are still eligible for processing
const state = getState();
const layer = state.controlLayers.present.layers.filter(isControlAdapterLayer).find((l) => l.id === layerId);
const image = layer?.controlAdapter.image;
const config = layer?.controlAdapter.processorConfig;
// If we have no image or there is no processor config, bail
if (!layer || !image || !config) {
return;
}
// @ts-expect-error: TS isn't able to narrow the typing of buildNode and `config` will error...
const processorNode = CONTROLNET_PROCESSORS[config.type].buildNode(image, config);
const enqueueBatchArg: BatchConfig = {
prepend: true,
batch: {
graph: {
nodes: {
[processorNode.id]: { ...processorNode, is_intermediate: true },
},
edges: [],
},
runs: 1,
},
};
try {
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(enqueueBatchArg, {
fixedCacheKey: 'enqueueBatch',
})
);
const enqueueResult = await req.unwrap();
req.reset();
log.debug({ enqueueResult: parseify(enqueueResult) }, t('queue.graphQueued'));
const [invocationCompleteAction] = await take(
(action): action is ReturnType<typeof socketInvocationComplete> =>
socketInvocationComplete.match(action) &&
action.payload.data.queue_batch_id === enqueueResult.batch.batch_id &&
action.payload.data.source_node_id === processorNode.id
);
// We still have to check the output type
if (isImageOutput(invocationCompleteAction.payload.data.result)) {
const { image_name } = invocationCompleteAction.payload.data.result.image;
// Wait for the ImageDTO to be received
const [{ payload }] = await take(
(action) =>
imagesApi.endpoints.getImageDTO.matchFulfilled(action) && action.payload.image_name === image_name
);
const imageDTO = payload as ImageDTO;
log.debug({ layerId, imageDTO }, 'ControlNet image processed');
// Update the processed image in the store
dispatch(
caLayerProcessedImageChanged({
layerId,
imageDTO,
})
);
dispatch(caLayerIsProcessingImageChanged({ layerId, isProcessingImage: false }));
}
} catch (error) {
console.log(error);
log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue'));
dispatch(caLayerIsProcessingImageChanged({ layerId, isProcessingImage: false }));
if (error instanceof Object) {
if ('data' in error && 'status' in error) {
if (error.status === 403) {
dispatch(caLayerImageChanged({ layerId, imageDTO: null }));
return;
}
}
}
dispatch(
addToast({
title: t('queue.graphFailedToQueue'),
status: 'error',
})
);
}
},
});
};

View File

@ -5,7 +5,7 @@ import { LayerDeleteButton } from 'features/controlLayers/components/LayerCommon
import { LayerMenu } from 'features/controlLayers/components/LayerCommon/LayerMenu'; import { LayerMenu } from 'features/controlLayers/components/LayerCommon/LayerMenu';
import { LayerTitle } from 'features/controlLayers/components/LayerCommon/LayerTitle'; import { LayerTitle } from 'features/controlLayers/components/LayerCommon/LayerTitle';
import { LayerVisibilityToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle'; import { LayerVisibilityToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle';
import { layerSelected, selectCALayer } from 'features/controlLayers/store/controlLayersSlice'; import { layerSelected, selectCALayerOrThrow } from 'features/controlLayers/store/controlLayersSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import CALayerOpacity from './CALayerOpacity'; import CALayerOpacity from './CALayerOpacity';
@ -16,7 +16,7 @@ type Props = {
export const CALayer = memo(({ layerId }: Props) => { export const CALayer = memo(({ layerId }: Props) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const isSelected = useAppSelector((s) => selectCALayer(s.controlLayers.present, layerId).isSelected); const isSelected = useAppSelector((s) => selectCALayerOrThrow(s.controlLayers.present, layerId).isSelected);
const onClickCapture = useCallback(() => { const onClickCapture = useCallback(() => {
// Must be capture so that the layer is selected before deleting/resetting/etc // Must be capture so that the layer is selected before deleting/resetting/etc
dispatch(layerSelected(layerId)); dispatch(layerSelected(layerId));

View File

@ -7,7 +7,7 @@ import {
caLayerProcessorConfigChanged, caLayerProcessorConfigChanged,
caOrIPALayerBeginEndStepPctChanged, caOrIPALayerBeginEndStepPctChanged,
caOrIPALayerWeightChanged, caOrIPALayerWeightChanged,
selectCALayer, selectCALayerOrThrow,
} from 'features/controlLayers/store/controlLayersSlice'; } from 'features/controlLayers/store/controlLayersSlice';
import type { ControlMode, ProcessorConfig } from 'features/controlLayers/util/controlAdapters'; import type { ControlMode, ProcessorConfig } from 'features/controlLayers/util/controlAdapters';
import type { CALayerImageDropData } from 'features/dnd/types'; import type { CALayerImageDropData } from 'features/dnd/types';
@ -25,7 +25,7 @@ type Props = {
export const CALayerControlAdapterWrapper = memo(({ layerId }: Props) => { export const CALayerControlAdapterWrapper = memo(({ layerId }: Props) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const controlAdapter = useAppSelector((s) => selectCALayer(s.controlLayers.present, layerId).controlAdapter); const controlAdapter = useAppSelector((s) => selectCALayerOrThrow(s.controlLayers.present, layerId).controlAdapter);
const onChangeBeginEndStepPct = useCallback( const onChangeBeginEndStepPct = useCallback(
(beginEndStepPct: [number, number]) => { (beginEndStepPct: [number, number]) => {

View File

@ -87,11 +87,8 @@ export const ControlAdapter = memo(
</Flex> </Flex>
<Flex alignItems="center" justifyContent="center" h={36} w={36} aspectRatio="1/1"> <Flex alignItems="center" justifyContent="center" h={36} w={36} aspectRatio="1/1">
<ControlAdapterImagePreview <ControlAdapterImagePreview
image={controlAdapter.image} controlAdapter={controlAdapter}
processedImage={controlAdapter.processedImage}
onChangeImage={onChangeImage} onChangeImage={onChangeImage}
hasProcessor={Boolean(controlAdapter.processorConfig)}
controlAdapterId={controlAdapter.id}
droppableData={droppableData} droppableData={droppableData}
postUploadAction={postUploadAction} postUploadAction={postUploadAction}
/> />
@ -99,7 +96,10 @@ export const ControlAdapter = memo(
</Flex> </Flex>
{isExpanded && ( {isExpanded && (
<> <>
<ControlAdapterProcessorTypeSelect config={controlAdapter.processorConfig} onChange={onChangeProcessorConfig} /> <ControlAdapterProcessorTypeSelect
config={controlAdapter.processorConfig}
onChange={onChangeProcessorConfig}
/>
<ControlAdapterProcessorConfig config={controlAdapter.processorConfig} onChange={onChangeProcessorConfig} /> <ControlAdapterProcessorConfig config={controlAdapter.processorConfig} onChange={onChangeProcessorConfig} />
</> </>
)} )}

View File

@ -1,14 +1,12 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library'; import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Spinner, useShiftModifier } from '@invoke-ai/ui-library'; import { Box, Flex, Spinner, useShiftModifier } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query'; import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImage from 'common/components/IAIDndImage';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon'; import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice'; import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
import { selectControlAdaptersSlice } from 'features/controlAdapters/store/controlAdaptersSlice';
import { heightChanged, widthChanged } from 'features/controlLayers/store/controlLayersSlice'; import { heightChanged, widthChanged } from 'features/controlLayers/store/controlLayersSlice';
import type { ImageWithDims } from 'features/controlLayers/util/controlAdapters'; import type { ControlNetConfig, T2IAdapterConfig } from 'features/controlLayers/util/controlAdapters';
import type { ImageDraggableData, TypesafeDroppableData } from 'features/dnd/types'; import type { ImageDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize'; import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice'; import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
@ -25,46 +23,29 @@ import {
import type { ImageDTO, PostUploadAction } from 'services/api/types'; import type { ImageDTO, PostUploadAction } from 'services/api/types';
type Props = { type Props = {
controlAdapterId: string; controlAdapter: ControlNetConfig | T2IAdapterConfig;
image: ImageWithDims | null;
processedImage: ImageWithDims | null;
onChangeImage: (imageDTO: ImageDTO | null) => void; onChangeImage: (imageDTO: ImageDTO | null) => void;
hasProcessor: boolean;
droppableData: TypesafeDroppableData; droppableData: TypesafeDroppableData;
postUploadAction: PostUploadAction; postUploadAction: PostUploadAction;
}; };
const selectPendingControlImages = createMemoizedSelector(
selectControlAdaptersSlice,
(controlAdapters) => controlAdapters.pendingControlImages
);
export const ControlAdapterImagePreview = memo( export const ControlAdapterImagePreview = memo(
({ ({ controlAdapter, onChangeImage, droppableData, postUploadAction }: Props) => {
image,
processedImage,
onChangeImage,
hasProcessor,
controlAdapterId,
droppableData,
postUploadAction,
}: Props) => {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId); const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId);
const isConnected = useAppSelector((s) => s.system.isConnected); const isConnected = useAppSelector((s) => s.system.isConnected);
const activeTabName = useAppSelector(activeTabNameSelector); const activeTabName = useAppSelector(activeTabNameSelector);
const optimalDimension = useAppSelector(selectOptimalDimension); const optimalDimension = useAppSelector(selectOptimalDimension);
const pendingControlImages = useAppSelector(selectPendingControlImages);
const shift = useShiftModifier(); const shift = useShiftModifier();
const [isMouseOverImage, setIsMouseOverImage] = useState(false); const [isMouseOverImage, setIsMouseOverImage] = useState(false);
const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery( const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(
image?.imageName ?? skipToken controlAdapter.image?.imageName ?? skipToken
); );
const { currentData: processedControlImage, isError: isErrorProcessedControlImage } = useGetImageDTOQuery( const { currentData: processedControlImage, isError: isErrorProcessedControlImage } = useGetImageDTOQuery(
processedImage?.imageName ?? skipToken controlAdapter.processedImage?.imageName ?? skipToken
); );
const [changeIsIntermediate] = useChangeImageIsIntermediateMutation(); const [changeIsIntermediate] = useChangeImageIsIntermediateMutation();
@ -130,19 +111,19 @@ export const ControlAdapterImagePreview = memo(
const draggableData = useMemo<ImageDraggableData | undefined>(() => { const draggableData = useMemo<ImageDraggableData | undefined>(() => {
if (controlImage) { if (controlImage) {
return { return {
id: controlAdapterId, id: controlAdapter.id,
payloadType: 'IMAGE_DTO', payloadType: 'IMAGE_DTO',
payload: { imageDTO: controlImage }, payload: { imageDTO: controlImage },
}; };
} }
}, [controlImage, controlAdapterId]); }, [controlImage, controlAdapter.id]);
const shouldShowProcessedImage = const shouldShowProcessedImage =
controlImage && controlImage &&
processedControlImage && processedControlImage &&
!isMouseOverImage && !isMouseOverImage &&
!pendingControlImages.includes(controlAdapterId) && !controlAdapter.isProcessingImage &&
hasProcessor; controlAdapter.processorConfig !== null;
useEffect(() => { useEffect(() => {
if (isConnected && (isErrorControlImage || isErrorProcessedControlImage)) { if (isConnected && (isErrorControlImage || isErrorProcessedControlImage)) {
@ -207,7 +188,7 @@ export const ControlAdapterImagePreview = memo(
/> />
</> </>
{pendingControlImages.includes(controlAdapterId) && ( {controlAdapter.isProcessingImage && (
<Flex <Flex
position="absolute" position="absolute"
top={0} top={0}

View File

@ -7,7 +7,7 @@ import {
ipaLayerImageChanged, ipaLayerImageChanged,
ipaLayerMethodChanged, ipaLayerMethodChanged,
ipaLayerModelChanged, ipaLayerModelChanged,
selectIPALayer, selectIPALayerOrThrow,
} from 'features/controlLayers/store/controlLayersSlice'; } from 'features/controlLayers/store/controlLayersSlice';
import type { CLIPVisionModel, IPMethod } from 'features/controlLayers/util/controlAdapters'; import type { CLIPVisionModel, IPMethod } from 'features/controlLayers/util/controlAdapters';
import type { IPALayerImageDropData } from 'features/dnd/types'; import type { IPALayerImageDropData } from 'features/dnd/types';
@ -20,7 +20,7 @@ type Props = {
export const IPALayerIPAdapterWrapper = memo(({ layerId }: Props) => { export const IPALayerIPAdapterWrapper = memo(({ layerId }: Props) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const ipAdapter = useAppSelector((s) => selectIPALayer(s.controlLayers.present, layerId).ipAdapter); const ipAdapter = useAppSelector((s) => selectIPALayerOrThrow(s.controlLayers.present, layerId).ipAdapter);
const onChangeBeginEndStepPct = useCallback( const onChangeBeginEndStepPct = useCallback(
(beginEndStepPct: [number, number]) => { (beginEndStepPct: [number, number]) => {

View File

@ -9,7 +9,7 @@ import {
rgLayerIPAdapterMethodChanged, rgLayerIPAdapterMethodChanged,
rgLayerIPAdapterModelChanged, rgLayerIPAdapterModelChanged,
rgLayerIPAdapterWeightChanged, rgLayerIPAdapterWeightChanged,
selectRGLayerIPAdapter, selectRGLayerIPAdapterOrThrow,
} from 'features/controlLayers/store/controlLayersSlice'; } from 'features/controlLayers/store/controlLayersSlice';
import type { CLIPVisionModel, IPMethod } from 'features/controlLayers/util/controlAdapters'; import type { CLIPVisionModel, IPMethod } from 'features/controlLayers/util/controlAdapters';
import type { RGLayerIPAdapterImageDropData } from 'features/dnd/types'; import type { RGLayerIPAdapterImageDropData } from 'features/dnd/types';
@ -28,7 +28,7 @@ export const RGLayerIPAdapterWrapper = memo(({ layerId, ipAdapterId, ipAdapterNu
const onDeleteIPAdapter = useCallback(() => { const onDeleteIPAdapter = useCallback(() => {
dispatch(rgLayerIPAdapterDeleted({ layerId, ipAdapterId })); dispatch(rgLayerIPAdapterDeleted({ layerId, ipAdapterId }));
}, [dispatch, ipAdapterId, layerId]); }, [dispatch, ipAdapterId, layerId]);
const ipAdapter = useAppSelector((s) => selectRGLayerIPAdapter(s.controlLayers.present, layerId, ipAdapterId)); const ipAdapter = useAppSelector((s) => selectRGLayerIPAdapterOrThrow(s.controlLayers.present, layerId, ipAdapterId));
const onChangeBeginEndStepPct = useCallback( const onChangeBeginEndStepPct = useCallback(
(beginEndStepPct: [number, number]) => { (beginEndStepPct: [number, number]) => {

View File

@ -78,17 +78,17 @@ const resetLayer = (layer: Layer) => {
} }
}; };
export const selectCALayer = (state: ControlLayersState, layerId: string): ControlAdapterLayer => { export const selectCALayerOrThrow = (state: ControlLayersState, layerId: string): ControlAdapterLayer => {
const layer = state.layers.find((l) => l.id === layerId); const layer = state.layers.find((l) => l.id === layerId);
assert(isControlAdapterLayer(layer)); assert(isControlAdapterLayer(layer));
return layer; return layer;
}; };
export const selectIPALayer = (state: ControlLayersState, layerId: string): IPAdapterLayer => { export const selectIPALayerOrThrow = (state: ControlLayersState, layerId: string): IPAdapterLayer => {
const layer = state.layers.find((l) => l.id === layerId); const layer = state.layers.find((l) => l.id === layerId);
assert(isIPAdapterLayer(layer)); assert(isIPAdapterLayer(layer));
return layer; return layer;
}; };
export const selectCAOrIPALayer = ( export const selectCAOrIPALayerOrThrow = (
state: ControlLayersState, state: ControlLayersState,
layerId: string layerId: string
): ControlAdapterLayer | IPAdapterLayer => { ): ControlAdapterLayer | IPAdapterLayer => {
@ -96,12 +96,12 @@ export const selectCAOrIPALayer = (
assert(isControlAdapterLayer(layer) || isIPAdapterLayer(layer)); assert(isControlAdapterLayer(layer) || isIPAdapterLayer(layer));
return layer; return layer;
}; };
export const selectRGLayer = (state: ControlLayersState, layerId: string): RegionalGuidanceLayer => { export const selectRGLayerOrThrow = (state: ControlLayersState, layerId: string): RegionalGuidanceLayer => {
const layer = state.layers.find((l) => l.id === layerId); const layer = state.layers.find((l) => l.id === layerId);
assert(isRegionalGuidanceLayer(layer)); assert(isRegionalGuidanceLayer(layer));
return layer; return layer;
}; };
export const selectRGLayerIPAdapter = ( export const selectRGLayerIPAdapterOrThrow = (
state: ControlLayersState, state: ControlLayersState,
layerId: string, layerId: string,
ipAdapterId: string ipAdapterId: string
@ -246,7 +246,7 @@ export const controlLayersSlice = createSlice({
}, },
caLayerImageChanged: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => { caLayerImageChanged: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => {
const { layerId, imageDTO } = action.payload; const { layerId, imageDTO } = action.payload;
const layer = selectCALayer(state, layerId); const layer = selectCALayerOrThrow(state, layerId);
layer.bbox = null; layer.bbox = null;
layer.bboxNeedsUpdate = true; layer.bboxNeedsUpdate = true;
layer.isEnabled = true; layer.isEnabled = true;
@ -255,7 +255,7 @@ export const controlLayersSlice = createSlice({
}, },
caLayerProcessedImageChanged: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => { caLayerProcessedImageChanged: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => {
const { layerId, imageDTO } = action.payload; const { layerId, imageDTO } = action.payload;
const layer = selectCALayer(state, layerId); const layer = selectCALayerOrThrow(state, layerId);
layer.bbox = null; layer.bbox = null;
layer.bboxNeedsUpdate = true; layer.bboxNeedsUpdate = true;
layer.isEnabled = true; layer.isEnabled = true;
@ -269,7 +269,7 @@ export const controlLayersSlice = createSlice({
}> }>
) => { ) => {
const { layerId, modelConfig } = action.payload; const { layerId, modelConfig } = action.payload;
const layer = selectCALayer(state, layerId); const layer = selectCALayerOrThrow(state, layerId);
if (!modelConfig) { if (!modelConfig) {
layer.controlAdapter.model = null; layer.controlAdapter.model = null;
return; return;
@ -285,7 +285,7 @@ export const controlLayersSlice = createSlice({
}, },
caLayerControlModeChanged: (state, action: PayloadAction<{ layerId: string; controlMode: ControlMode }>) => { caLayerControlModeChanged: (state, action: PayloadAction<{ layerId: string; controlMode: ControlMode }>) => {
const { layerId, controlMode } = action.payload; const { layerId, controlMode } = action.payload;
const layer = selectCALayer(state, layerId); const layer = selectCALayerOrThrow(state, layerId);
assert(layer.controlAdapter.type === 'controlnet'); assert(layer.controlAdapter.type === 'controlnet');
layer.controlAdapter.controlMode = controlMode; layer.controlAdapter.controlMode = controlMode;
}, },
@ -294,19 +294,27 @@ export const controlLayersSlice = createSlice({
action: PayloadAction<{ layerId: string; processorConfig: ProcessorConfig | null }> action: PayloadAction<{ layerId: string; processorConfig: ProcessorConfig | null }>
) => { ) => {
const { layerId, processorConfig } = action.payload; const { layerId, processorConfig } = action.payload;
const layer = selectCALayer(state, layerId); const layer = selectCALayerOrThrow(state, layerId);
layer.controlAdapter.processorConfig = processorConfig; layer.controlAdapter.processorConfig = processorConfig;
}, },
caLayerIsFilterEnabledChanged: (state, action: PayloadAction<{ layerId: string; isFilterEnabled: boolean }>) => { caLayerIsFilterEnabledChanged: (state, action: PayloadAction<{ layerId: string; isFilterEnabled: boolean }>) => {
const { layerId, isFilterEnabled } = action.payload; const { layerId, isFilterEnabled } = action.payload;
const layer = selectCALayer(state, layerId); const layer = selectCALayerOrThrow(state, layerId);
layer.isFilterEnabled = isFilterEnabled; layer.isFilterEnabled = isFilterEnabled;
}, },
caLayerOpacityChanged: (state, action: PayloadAction<{ layerId: string; opacity: number }>) => { caLayerOpacityChanged: (state, action: PayloadAction<{ layerId: string; opacity: number }>) => {
const { layerId, opacity } = action.payload; const { layerId, opacity } = action.payload;
const layer = selectCALayer(state, layerId); const layer = selectCALayerOrThrow(state, layerId);
layer.opacity = opacity; layer.opacity = opacity;
}, },
caLayerIsProcessingImageChanged: (
state,
action: PayloadAction<{ layerId: string; isProcessingImage: boolean }>
) => {
const { layerId, isProcessingImage } = action.payload;
const layer = selectCALayerOrThrow(state, layerId);
layer.controlAdapter.isProcessingImage = isProcessingImage;
},
//#endregion //#endregion
//#region IP Adapter Layers //#region IP Adapter Layers
@ -325,12 +333,12 @@ export const controlLayersSlice = createSlice({
}, },
ipaLayerImageChanged: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => { ipaLayerImageChanged: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => {
const { layerId, imageDTO } = action.payload; const { layerId, imageDTO } = action.payload;
const layer = selectIPALayer(state, layerId); const layer = selectIPALayerOrThrow(state, layerId);
layer.ipAdapter.image = imageDTO ? imageDTOToImageWithDims(imageDTO) : null; layer.ipAdapter.image = imageDTO ? imageDTOToImageWithDims(imageDTO) : null;
}, },
ipaLayerWeightChanged: (state, action: PayloadAction<{ layerId: string; weight: number }>) => { ipaLayerWeightChanged: (state, action: PayloadAction<{ layerId: string; weight: number }>) => {
const { layerId, weight } = action.payload; const { layerId, weight } = action.payload;
const layer = selectIPALayer(state, layerId); const layer = selectIPALayerOrThrow(state, layerId);
layer.ipAdapter.weight = weight; layer.ipAdapter.weight = weight;
}, },
ipaLayerBeginEndStepPctChanged: ( ipaLayerBeginEndStepPctChanged: (
@ -338,12 +346,12 @@ export const controlLayersSlice = createSlice({
action: PayloadAction<{ layerId: string; beginEndStepPct: [number, number] }> action: PayloadAction<{ layerId: string; beginEndStepPct: [number, number] }>
) => { ) => {
const { layerId, beginEndStepPct } = action.payload; const { layerId, beginEndStepPct } = action.payload;
const layer = selectIPALayer(state, layerId); const layer = selectIPALayerOrThrow(state, layerId);
layer.ipAdapter.beginEndStepPct = beginEndStepPct; layer.ipAdapter.beginEndStepPct = beginEndStepPct;
}, },
ipaLayerMethodChanged: (state, action: PayloadAction<{ layerId: string; method: IPMethod }>) => { ipaLayerMethodChanged: (state, action: PayloadAction<{ layerId: string; method: IPMethod }>) => {
const { layerId, method } = action.payload; const { layerId, method } = action.payload;
const layer = selectIPALayer(state, layerId); const layer = selectIPALayerOrThrow(state, layerId);
layer.ipAdapter.method = method; layer.ipAdapter.method = method;
}, },
ipaLayerModelChanged: ( ipaLayerModelChanged: (
@ -354,7 +362,7 @@ export const controlLayersSlice = createSlice({
}> }>
) => { ) => {
const { layerId, modelConfig } = action.payload; const { layerId, modelConfig } = action.payload;
const layer = selectIPALayer(state, layerId); const layer = selectIPALayerOrThrow(state, layerId);
if (!modelConfig) { if (!modelConfig) {
layer.ipAdapter.model = null; layer.ipAdapter.model = null;
return; return;
@ -366,7 +374,7 @@ export const controlLayersSlice = createSlice({
action: PayloadAction<{ layerId: string; clipVisionModel: CLIPVisionModel }> action: PayloadAction<{ layerId: string; clipVisionModel: CLIPVisionModel }>
) => { ) => {
const { layerId, clipVisionModel } = action.payload; const { layerId, clipVisionModel } = action.payload;
const layer = selectIPALayer(state, layerId); const layer = selectIPALayerOrThrow(state, layerId);
layer.ipAdapter.clipVisionModel = clipVisionModel; layer.ipAdapter.clipVisionModel = clipVisionModel;
}, },
//#endregion //#endregion
@ -374,7 +382,7 @@ export const controlLayersSlice = createSlice({
//#region CA or IPA Layers //#region CA or IPA Layers
caOrIPALayerWeightChanged: (state, action: PayloadAction<{ layerId: string; weight: number }>) => { caOrIPALayerWeightChanged: (state, action: PayloadAction<{ layerId: string; weight: number }>) => {
const { layerId, weight } = action.payload; const { layerId, weight } = action.payload;
const layer = selectCAOrIPALayer(state, layerId); const layer = selectCAOrIPALayerOrThrow(state, layerId);
if (layer.type === 'control_adapter_layer') { if (layer.type === 'control_adapter_layer') {
layer.controlAdapter.weight = weight; layer.controlAdapter.weight = weight;
} else { } else {
@ -386,7 +394,7 @@ export const controlLayersSlice = createSlice({
action: PayloadAction<{ layerId: string; beginEndStepPct: [number, number] }> action: PayloadAction<{ layerId: string; beginEndStepPct: [number, number] }>
) => { ) => {
const { layerId, beginEndStepPct } = action.payload; const { layerId, beginEndStepPct } = action.payload;
const layer = selectCAOrIPALayer(state, layerId); const layer = selectCAOrIPALayerOrThrow(state, layerId);
if (layer.type === 'control_adapter_layer') { if (layer.type === 'control_adapter_layer') {
layer.controlAdapter.beginEndStepPct = beginEndStepPct; layer.controlAdapter.beginEndStepPct = beginEndStepPct;
} else { } else {
@ -428,17 +436,17 @@ export const controlLayersSlice = createSlice({
}, },
rgLayerPositivePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => { rgLayerPositivePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => {
const { layerId, prompt } = action.payload; const { layerId, prompt } = action.payload;
const layer = selectRGLayer(state, layerId); const layer = selectRGLayerOrThrow(state, layerId);
layer.positivePrompt = prompt; layer.positivePrompt = prompt;
}, },
rgLayerNegativePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => { rgLayerNegativePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => {
const { layerId, prompt } = action.payload; const { layerId, prompt } = action.payload;
const layer = selectRGLayer(state, layerId); const layer = selectRGLayerOrThrow(state, layerId);
layer.negativePrompt = prompt; layer.negativePrompt = prompt;
}, },
rgLayerPreviewColorChanged: (state, action: PayloadAction<{ layerId: string; color: RgbColor }>) => { rgLayerPreviewColorChanged: (state, action: PayloadAction<{ layerId: string; color: RgbColor }>) => {
const { layerId, color } = action.payload; const { layerId, color } = action.payload;
const layer = selectRGLayer(state, layerId); const layer = selectRGLayerOrThrow(state, layerId);
layer.previewColor = color; layer.previewColor = color;
}, },
rgLayerLineAdded: { rgLayerLineAdded: {
@ -452,7 +460,7 @@ export const controlLayersSlice = createSlice({
}> }>
) => { ) => {
const { layerId, points, tool, lineUuid } = action.payload; const { layerId, points, tool, lineUuid } = action.payload;
const layer = selectRGLayer(state, layerId); const layer = selectRGLayerOrThrow(state, layerId);
const lineId = getRGLayerLineId(layer.id, lineUuid); const lineId = getRGLayerLineId(layer.id, lineUuid);
layer.maskObjects.push({ layer.maskObjects.push({
type: 'vector_mask_line', type: 'vector_mask_line',
@ -474,7 +482,7 @@ export const controlLayersSlice = createSlice({
}, },
rgLayerPointsAdded: (state, action: PayloadAction<{ layerId: string; point: [number, number] }>) => { rgLayerPointsAdded: (state, action: PayloadAction<{ layerId: string; point: [number, number] }>) => {
const { layerId, point } = action.payload; const { layerId, point } = action.payload;
const layer = selectRGLayer(state, layerId); const layer = selectRGLayerOrThrow(state, layerId);
const lastLine = layer.maskObjects.findLast(isLine); const lastLine = layer.maskObjects.findLast(isLine);
if (!lastLine) { if (!lastLine) {
return; return;
@ -491,7 +499,7 @@ export const controlLayersSlice = createSlice({
// Ignore zero-area rectangles // Ignore zero-area rectangles
return; return;
} }
const layer = selectRGLayer(state, layerId); const layer = selectRGLayerOrThrow(state, layerId);
const id = getRGLayerRectId(layer.id, rectUuid); const id = getRGLayerRectId(layer.id, rectUuid);
layer.maskObjects.push({ layer.maskObjects.push({
type: 'vector_mask_rect', type: 'vector_mask_rect',
@ -510,17 +518,17 @@ export const controlLayersSlice = createSlice({
action: PayloadAction<{ layerId: string; autoNegative: ParameterAutoNegative }> action: PayloadAction<{ layerId: string; autoNegative: ParameterAutoNegative }>
) => { ) => {
const { layerId, autoNegative } = action.payload; const { layerId, autoNegative } = action.payload;
const layer = selectRGLayer(state, layerId); const layer = selectRGLayerOrThrow(state, layerId);
layer.autoNegative = autoNegative; layer.autoNegative = autoNegative;
}, },
rgLayerIPAdapterAdded: (state, action: PayloadAction<{ layerId: string; ipAdapter: IPAdapterConfig }>) => { rgLayerIPAdapterAdded: (state, action: PayloadAction<{ layerId: string; ipAdapter: IPAdapterConfig }>) => {
const { layerId, ipAdapter } = action.payload; const { layerId, ipAdapter } = action.payload;
const layer = selectRGLayer(state, layerId); const layer = selectRGLayerOrThrow(state, layerId);
layer.ipAdapters.push(ipAdapter); layer.ipAdapters.push(ipAdapter);
}, },
rgLayerIPAdapterDeleted: (state, action: PayloadAction<{ layerId: string; ipAdapterId: string }>) => { rgLayerIPAdapterDeleted: (state, action: PayloadAction<{ layerId: string; ipAdapterId: string }>) => {
const { layerId, ipAdapterId } = action.payload; const { layerId, ipAdapterId } = action.payload;
const layer = selectRGLayer(state, layerId); const layer = selectRGLayerOrThrow(state, layerId);
layer.ipAdapters = layer.ipAdapters.filter((ipAdapter) => ipAdapter.id !== ipAdapterId); layer.ipAdapters = layer.ipAdapters.filter((ipAdapter) => ipAdapter.id !== ipAdapterId);
}, },
rgLayerIPAdapterImageChanged: ( rgLayerIPAdapterImageChanged: (
@ -528,7 +536,7 @@ export const controlLayersSlice = createSlice({
action: PayloadAction<{ layerId: string; ipAdapterId: string; imageDTO: ImageDTO | null }> action: PayloadAction<{ layerId: string; ipAdapterId: string; imageDTO: ImageDTO | null }>
) => { ) => {
const { layerId, ipAdapterId, imageDTO } = action.payload; const { layerId, ipAdapterId, imageDTO } = action.payload;
const ipAdapter = selectRGLayerIPAdapter(state, layerId, ipAdapterId); const ipAdapter = selectRGLayerIPAdapterOrThrow(state, layerId, ipAdapterId);
ipAdapter.image = imageDTO ? imageDTOToImageWithDims(imageDTO) : null; ipAdapter.image = imageDTO ? imageDTOToImageWithDims(imageDTO) : null;
}, },
rgLayerIPAdapterWeightChanged: ( rgLayerIPAdapterWeightChanged: (
@ -536,7 +544,7 @@ export const controlLayersSlice = createSlice({
action: PayloadAction<{ layerId: string; ipAdapterId: string; weight: number }> action: PayloadAction<{ layerId: string; ipAdapterId: string; weight: number }>
) => { ) => {
const { layerId, ipAdapterId, weight } = action.payload; const { layerId, ipAdapterId, weight } = action.payload;
const ipAdapter = selectRGLayerIPAdapter(state, layerId, ipAdapterId); const ipAdapter = selectRGLayerIPAdapterOrThrow(state, layerId, ipAdapterId);
ipAdapter.weight = weight; ipAdapter.weight = weight;
}, },
rgLayerIPAdapterBeginEndStepPctChanged: ( rgLayerIPAdapterBeginEndStepPctChanged: (
@ -544,7 +552,7 @@ export const controlLayersSlice = createSlice({
action: PayloadAction<{ layerId: string; ipAdapterId: string; beginEndStepPct: [number, number] }> action: PayloadAction<{ layerId: string; ipAdapterId: string; beginEndStepPct: [number, number] }>
) => { ) => {
const { layerId, ipAdapterId, beginEndStepPct } = action.payload; const { layerId, ipAdapterId, beginEndStepPct } = action.payload;
const ipAdapter = selectRGLayerIPAdapter(state, layerId, ipAdapterId); const ipAdapter = selectRGLayerIPAdapterOrThrow(state, layerId, ipAdapterId);
ipAdapter.beginEndStepPct = beginEndStepPct; ipAdapter.beginEndStepPct = beginEndStepPct;
}, },
rgLayerIPAdapterMethodChanged: ( rgLayerIPAdapterMethodChanged: (
@ -552,7 +560,7 @@ export const controlLayersSlice = createSlice({
action: PayloadAction<{ layerId: string; ipAdapterId: string; method: IPMethod }> action: PayloadAction<{ layerId: string; ipAdapterId: string; method: IPMethod }>
) => { ) => {
const { layerId, ipAdapterId, method } = action.payload; const { layerId, ipAdapterId, method } = action.payload;
const ipAdapter = selectRGLayerIPAdapter(state, layerId, ipAdapterId); const ipAdapter = selectRGLayerIPAdapterOrThrow(state, layerId, ipAdapterId);
ipAdapter.method = method; ipAdapter.method = method;
}, },
rgLayerIPAdapterModelChanged: ( rgLayerIPAdapterModelChanged: (
@ -564,7 +572,7 @@ export const controlLayersSlice = createSlice({
}> }>
) => { ) => {
const { layerId, ipAdapterId, modelConfig } = action.payload; const { layerId, ipAdapterId, modelConfig } = action.payload;
const ipAdapter = selectRGLayerIPAdapter(state, layerId, ipAdapterId); const ipAdapter = selectRGLayerIPAdapterOrThrow(state, layerId, ipAdapterId);
if (!modelConfig) { if (!modelConfig) {
ipAdapter.model = null; ipAdapter.model = null;
return; return;
@ -576,7 +584,7 @@ export const controlLayersSlice = createSlice({
action: PayloadAction<{ layerId: string; ipAdapterId: string; clipVisionModel: CLIPVisionModel }> action: PayloadAction<{ layerId: string; ipAdapterId: string; clipVisionModel: CLIPVisionModel }>
) => { ) => {
const { layerId, ipAdapterId, clipVisionModel } = action.payload; const { layerId, ipAdapterId, clipVisionModel } = action.payload;
const ipAdapter = selectRGLayerIPAdapter(state, layerId, ipAdapterId); const ipAdapter = selectRGLayerIPAdapterOrThrow(state, layerId, ipAdapterId);
ipAdapter.clipVisionModel = clipVisionModel; ipAdapter.clipVisionModel = clipVisionModel;
}, },
//#endregion //#endregion
@ -720,6 +728,7 @@ export const {
caLayerProcessorConfigChanged, caLayerProcessorConfigChanged,
caLayerIsFilterEnabledChanged, caLayerIsFilterEnabledChanged,
caLayerOpacityChanged, caLayerOpacityChanged,
caLayerIsProcessingImageChanged,
// IPA Layers // IPA Layers
ipaLayerAdded, ipaLayerAdded,
ipaLayerImageChanged, ipaLayerImageChanged,

View File

@ -13,6 +13,7 @@ import type {
ControlNetModelConfig, ControlNetModelConfig,
DepthAnythingImageProcessorInvocation, DepthAnythingImageProcessorInvocation,
DWOpenposeImageProcessorInvocation, DWOpenposeImageProcessorInvocation,
Graph,
HedImageProcessorInvocation, HedImageProcessorInvocation,
ImageDTO, ImageDTO,
LineartAnimeImageProcessorInvocation, LineartAnimeImageProcessorInvocation,
@ -34,27 +35,33 @@ export const isDepthAnythingModelSize = (v: unknown): v is DepthAnythingModelSiz
zDepthAnythingModelSize.safeParse(v).success; zDepthAnythingModelSize.safeParse(v).success;
export type CannyProcessorConfig = Required< export type CannyProcessorConfig = Required<
Pick<CannyImageProcessorInvocation, 'type' | 'low_threshold' | 'high_threshold'> Pick<CannyImageProcessorInvocation, 'id' | 'type' | 'low_threshold' | 'high_threshold'>
>;
export type ColorMapProcessorConfig = Required<
Pick<ColorMapImageProcessorInvocation, 'id' | 'type' | 'color_map_tile_size'>
>; >;
export type ColorMapProcessorConfig = Required<Pick<ColorMapImageProcessorInvocation, 'type' | 'color_map_tile_size'>>;
export type ContentShuffleProcessorConfig = Required< export type ContentShuffleProcessorConfig = Required<
Pick<ContentShuffleImageProcessorInvocation, 'type' | 'w' | 'h' | 'f'> Pick<ContentShuffleImageProcessorInvocation, 'id' | 'type' | 'w' | 'h' | 'f'>
>; >;
export type DepthAnythingProcessorConfig = Required<Pick<DepthAnythingImageProcessorInvocation, 'type' | 'model_size'>>; export type DepthAnythingProcessorConfig = Required<
export type HedProcessorConfig = Required<Pick<HedImageProcessorInvocation, 'type' | 'scribble'>>; Pick<DepthAnythingImageProcessorInvocation, 'id' | 'type' | 'model_size'>
export type LineartAnimeProcessorConfig = Required<Pick<LineartAnimeImageProcessorInvocation, 'type'>>; >;
export type LineartProcessorConfig = Required<Pick<LineartImageProcessorInvocation, 'type' | 'coarse'>>; export type HedProcessorConfig = Required<Pick<HedImageProcessorInvocation, 'id' | 'type' | 'scribble'>>;
export type LineartAnimeProcessorConfig = Required<Pick<LineartAnimeImageProcessorInvocation, 'id' | 'type'>>;
export type LineartProcessorConfig = Required<Pick<LineartImageProcessorInvocation, 'id' | 'type' | 'coarse'>>;
export type MediapipeFaceProcessorConfig = Required< export type MediapipeFaceProcessorConfig = Required<
Pick<MediapipeFaceProcessorInvocation, 'type' | 'max_faces' | 'min_confidence'> Pick<MediapipeFaceProcessorInvocation, 'id' | 'type' | 'max_faces' | 'min_confidence'>
>; >;
export type MidasDepthProcessorConfig = Required<Pick<MidasDepthImageProcessorInvocation, 'type' | 'a_mult' | 'bg_th'>>; export type MidasDepthProcessorConfig = Required<
export type MlsdProcessorConfig = Required<Pick<MlsdImageProcessorInvocation, 'type' | 'thr_v' | 'thr_d'>>; Pick<MidasDepthImageProcessorInvocation, 'id' | 'type' | 'a_mult' | 'bg_th'>
export type NormalbaeProcessorConfig = Required<Pick<NormalbaeImageProcessorInvocation, 'type'>>; >;
export type MlsdProcessorConfig = Required<Pick<MlsdImageProcessorInvocation, 'id' | 'type' | 'thr_v' | 'thr_d'>>;
export type NormalbaeProcessorConfig = Required<Pick<NormalbaeImageProcessorInvocation, 'id' | 'type'>>;
export type DWOpenposeProcessorConfig = Required< export type DWOpenposeProcessorConfig = Required<
Pick<DWOpenposeImageProcessorInvocation, 'type' | 'draw_body' | 'draw_face' | 'draw_hands'> Pick<DWOpenposeImageProcessorInvocation, 'id' | 'type' | 'draw_body' | 'draw_face' | 'draw_hands'>
>; >;
export type PidiProcessorConfig = Required<Pick<PidiImageProcessorInvocation, 'type' | 'safe' | 'scribble'>>; export type PidiProcessorConfig = Required<Pick<PidiImageProcessorInvocation, 'id' | 'type' | 'safe' | 'scribble'>>;
export type ZoeDepthProcessorConfig = Required<Pick<ZoeDepthImageProcessorInvocation, 'type'>>; export type ZoeDepthProcessorConfig = Required<Pick<ZoeDepthImageProcessorInvocation, 'id' | 'type'>>;
export type ProcessorConfig = export type ProcessorConfig =
| CannyProcessorConfig | CannyProcessorConfig
@ -83,6 +90,7 @@ type ControlAdapterBase = {
weight: number; weight: number;
image: ImageWithDims | null; image: ImageWithDims | null;
processedImage: ImageWithDims | null; processedImage: ImageWithDims | null;
isProcessingImage: boolean;
processorConfig: ProcessorConfig | null; processorConfig: ProcessorConfig | null;
beginEndStepPct: [number, number]; beginEndStepPct: [number, number];
}; };
@ -125,157 +133,6 @@ export type IPAdapterConfig = {
beginEndStepPct: [number, number]; beginEndStepPct: [number, number];
}; };
type ProcessorData<T extends ProcessorConfig['type']> = {
labelTKey: string;
descriptionTKey: string;
buildDefaults(baseModel?: BaseModelType): Extract<ProcessorConfig, { type: T }>;
};
type ControlNetProcessorsDict = {
[key in ProcessorConfig['type']]: ProcessorData<key>;
};
/**
* A dict of ControlNet processors, including:
* - label translation key
* - description translation key
* - a builder to create default values for the config
*
* TODO: Generate from the OpenAPI schema
*/
export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
canny_image_processor: {
labelTKey: 'controlnet.canny',
descriptionTKey: 'controlnet.cannyDescription',
buildDefaults: () => ({
id: `canny_image_processor_${uuidv4()}`,
type: 'canny_image_processor',
low_threshold: 100,
high_threshold: 200,
}),
},
color_map_image_processor: {
labelTKey: 'controlnet.colorMap',
descriptionTKey: 'controlnet.colorMapDescription',
buildDefaults: () => ({
id: `color_map_image_processor_${uuidv4()}`,
type: 'color_map_image_processor',
color_map_tile_size: 64,
}),
},
content_shuffle_image_processor: {
labelTKey: 'controlnet.contentShuffle',
descriptionTKey: 'controlnet.contentShuffleDescription',
buildDefaults: (baseModel) => ({
id: `content_shuffle_image_processor_${uuidv4()}`,
type: 'content_shuffle_image_processor',
h: baseModel === 'sdxl' ? 1024 : 512,
w: baseModel === 'sdxl' ? 1024 : 512,
f: baseModel === 'sdxl' ? 512 : 256,
}),
},
depth_anything_image_processor: {
labelTKey: 'controlnet.depthAnything',
descriptionTKey: 'controlnet.depthAnythingDescription',
buildDefaults: () => ({
id: `depth_anything_image_processor_${uuidv4()}`,
type: 'depth_anything_image_processor',
model_size: 'small',
}),
},
hed_image_processor: {
labelTKey: 'controlnet.hed',
descriptionTKey: 'controlnet.hedDescription',
buildDefaults: () => ({
id: `hed_image_processor_${uuidv4()}`,
type: 'hed_image_processor',
scribble: false,
}),
},
lineart_anime_image_processor: {
labelTKey: 'controlnet.lineartAnime',
descriptionTKey: 'controlnet.lineartAnimeDescription',
buildDefaults: () => ({
id: `lineart_anime_image_processor_${uuidv4()}`,
type: 'lineart_anime_image_processor',
}),
},
lineart_image_processor: {
labelTKey: 'controlnet.lineart',
descriptionTKey: 'controlnet.lineartDescription',
buildDefaults: () => ({
id: `lineart_image_processor_${uuidv4()}`,
type: 'lineart_image_processor',
coarse: false,
}),
},
mediapipe_face_processor: {
labelTKey: 'controlnet.mediapipeFace',
descriptionTKey: 'controlnet.mediapipeFaceDescription',
buildDefaults: () => ({
id: `mediapipe_face_processor_${uuidv4()}`,
type: 'mediapipe_face_processor',
max_faces: 1,
min_confidence: 0.5,
}),
},
midas_depth_image_processor: {
labelTKey: 'controlnet.depthMidas',
descriptionTKey: 'controlnet.depthMidasDescription',
buildDefaults: () => ({
id: `midas_depth_image_processor_${uuidv4()}`,
type: 'midas_depth_image_processor',
a_mult: 2,
bg_th: 0.1,
}),
},
mlsd_image_processor: {
labelTKey: 'controlnet.mlsd',
descriptionTKey: 'controlnet.mlsdDescription',
buildDefaults: () => ({
id: `mlsd_image_processor_${uuidv4()}`,
type: 'mlsd_image_processor',
thr_d: 0.1,
thr_v: 0.1,
}),
},
normalbae_image_processor: {
labelTKey: 'controlnet.normalBae',
descriptionTKey: 'controlnet.normalBaeDescription',
buildDefaults: () => ({
id: `normalbae_image_processor_${uuidv4()}`,
type: 'normalbae_image_processor',
}),
},
dw_openpose_image_processor: {
labelTKey: 'controlnet.dwOpenpose',
descriptionTKey: 'controlnet.dwOpenposeDescription',
buildDefaults: () => ({
id: `dw_openpose_image_processor_${uuidv4()}`,
type: 'dw_openpose_image_processor',
draw_body: true,
draw_face: false,
draw_hands: false,
}),
},
pidi_image_processor: {
labelTKey: 'controlnet.pidi',
descriptionTKey: 'controlnet.pidiDescription',
buildDefaults: () => ({
id: `pidi_image_processor_${uuidv4()}`,
type: 'pidi_image_processor',
scribble: false,
safe: false,
}),
},
zoe_depth_image_processor: {
labelTKey: 'controlnet.depthZoe',
descriptionTKey: 'controlnet.depthZoeDescription',
buildDefaults: () => ({
id: `zoe_depth_image_processor_${uuidv4()}`,
type: 'zoe_depth_image_processor',
}),
},
};
export const zProcessorType = z.enum([ export const zProcessorType = z.enum([
'canny_image_processor', 'canny_image_processor',
'color_map_image_processor', 'color_map_image_processor',
@ -295,6 +152,261 @@ export const zProcessorType = z.enum([
export type ProcessorType = z.infer<typeof zProcessorType>; export type ProcessorType = z.infer<typeof zProcessorType>;
export const isProcessorType = (v: unknown): v is ProcessorType => zProcessorType.safeParse(v).success; export const isProcessorType = (v: unknown): v is ProcessorType => zProcessorType.safeParse(v).success;
type ProcessorData<T extends ProcessorType> = {
type: T;
labelTKey: string;
descriptionTKey: string;
buildDefaults(baseModel?: BaseModelType): Extract<ProcessorConfig, { type: T }>;
buildNode(
image: ImageWithDims,
config: Extract<ProcessorConfig, { type: T }>
): Extract<Graph['nodes'][string], { type: T }>;
};
const minDim = (image: ImageWithDims): number => Math.min(image.width, image.height);
const getId = (type: ProcessorType): string => `${type}_${uuidv4()}`;
type CAProcessorsData = {
[key in ProcessorType]: ProcessorData<key>;
};
/**
* A dict of ControlNet processors, including:
* - label translation key
* - description translation key
* - a builder to create default values for the config
* - a builder to create the node for the config
*
* TODO: Generate from the OpenAPI schema
*/
export const CONTROLNET_PROCESSORS: CAProcessorsData = {
canny_image_processor: {
type: 'canny_image_processor',
labelTKey: 'controlnet.canny',
descriptionTKey: 'controlnet.cannyDescription',
buildDefaults: () => ({
id: getId('canny_image_processor'),
type: 'canny_image_processor',
low_threshold: 100,
high_threshold: 200,
}),
buildNode: (image, config) => ({
...config,
type: 'canny_image_processor',
image: { image_name: image.imageName },
detect_resolution: minDim(image),
image_resolution: minDim(image),
}),
},
color_map_image_processor: {
type: 'color_map_image_processor',
labelTKey: 'controlnet.colorMap',
descriptionTKey: 'controlnet.colorMapDescription',
buildDefaults: () => ({
id: getId('color_map_image_processor'),
type: 'color_map_image_processor',
color_map_tile_size: 64,
}),
buildNode: (image, config) => ({
...config,
type: 'color_map_image_processor',
image: { image_name: image.imageName },
}),
},
content_shuffle_image_processor: {
type: 'content_shuffle_image_processor',
labelTKey: 'controlnet.contentShuffle',
descriptionTKey: 'controlnet.contentShuffleDescription',
buildDefaults: (baseModel) => ({
id: getId('content_shuffle_image_processor'),
type: 'content_shuffle_image_processor',
h: baseModel === 'sdxl' ? 1024 : 512,
w: baseModel === 'sdxl' ? 1024 : 512,
f: baseModel === 'sdxl' ? 512 : 256,
}),
buildNode: (image, config) => ({
...config,
image: { image_name: image.imageName },
detect_resolution: minDim(image),
image_resolution: minDim(image),
}),
},
depth_anything_image_processor: {
type: 'depth_anything_image_processor',
labelTKey: 'controlnet.depthAnything',
descriptionTKey: 'controlnet.depthAnythingDescription',
buildDefaults: () => ({
id: getId('depth_anything_image_processor'),
type: 'depth_anything_image_processor',
model_size: 'small',
}),
buildNode: (image, config) => ({
...config,
image: { image_name: image.imageName },
resolution: minDim(image),
}),
},
hed_image_processor: {
type: 'hed_image_processor',
labelTKey: 'controlnet.hed',
descriptionTKey: 'controlnet.hedDescription',
buildDefaults: () => ({
id: getId('hed_image_processor'),
type: 'hed_image_processor',
scribble: false,
}),
buildNode: (image, config) => ({
...config,
image: { image_name: image.imageName },
detect_resolution: minDim(image),
image_resolution: minDim(image),
}),
},
lineart_anime_image_processor: {
type: 'lineart_anime_image_processor',
labelTKey: 'controlnet.lineartAnime',
descriptionTKey: 'controlnet.lineartAnimeDescription',
buildDefaults: () => ({
id: getId('lineart_anime_image_processor'),
type: 'lineart_anime_image_processor',
}),
buildNode: (image, config) => ({
...config,
image: { image_name: image.imageName },
detect_resolution: minDim(image),
image_resolution: minDim(image),
}),
},
lineart_image_processor: {
type: 'lineart_image_processor',
labelTKey: 'controlnet.lineart',
descriptionTKey: 'controlnet.lineartDescription',
buildDefaults: () => ({
id: getId('lineart_image_processor'),
type: 'lineart_image_processor',
coarse: false,
}),
buildNode: (image, config) => ({
...config,
image: { image_name: image.imageName },
detect_resolution: minDim(image),
image_resolution: minDim(image),
}),
},
mediapipe_face_processor: {
type: 'mediapipe_face_processor',
labelTKey: 'controlnet.mediapipeFace',
descriptionTKey: 'controlnet.mediapipeFaceDescription',
buildDefaults: () => ({
id: getId('mediapipe_face_processor'),
type: 'mediapipe_face_processor',
max_faces: 1,
min_confidence: 0.5,
}),
buildNode: (image, config) => ({
...config,
image: { image_name: image.imageName },
detect_resolution: minDim(image),
image_resolution: minDim(image),
}),
},
midas_depth_image_processor: {
type: 'midas_depth_image_processor',
labelTKey: 'controlnet.depthMidas',
descriptionTKey: 'controlnet.depthMidasDescription',
buildDefaults: () => ({
id: getId('midas_depth_image_processor'),
type: 'midas_depth_image_processor',
a_mult: 2,
bg_th: 0.1,
}),
buildNode: (image, config) => ({
...config,
image: { image_name: image.imageName },
detect_resolution: minDim(image),
image_resolution: minDim(image),
}),
},
mlsd_image_processor: {
type: 'mlsd_image_processor',
labelTKey: 'controlnet.mlsd',
descriptionTKey: 'controlnet.mlsdDescription',
buildDefaults: () => ({
id: getId('mlsd_image_processor'),
type: 'mlsd_image_processor',
thr_d: 0.1,
thr_v: 0.1,
}),
buildNode: (image, config) => ({
...config,
image: { image_name: image.imageName },
detect_resolution: minDim(image),
image_resolution: minDim(image),
}),
},
normalbae_image_processor: {
type: 'normalbae_image_processor',
labelTKey: 'controlnet.normalBae',
descriptionTKey: 'controlnet.normalBaeDescription',
buildDefaults: () => ({
id: getId('normalbae_image_processor'),
type: 'normalbae_image_processor',
}),
buildNode: (image, config) => ({
...config,
image: { image_name: image.imageName },
detect_resolution: minDim(image),
image_resolution: minDim(image),
}),
},
dw_openpose_image_processor: {
type: 'dw_openpose_image_processor',
labelTKey: 'controlnet.dwOpenpose',
descriptionTKey: 'controlnet.dwOpenposeDescription',
buildDefaults: () => ({
id: getId('dw_openpose_image_processor'),
type: 'dw_openpose_image_processor',
draw_body: true,
draw_face: false,
draw_hands: false,
}),
buildNode: (image, config) => ({
...config,
image: { image_name: image.imageName },
image_resolution: minDim(image),
}),
},
pidi_image_processor: {
type: 'pidi_image_processor',
labelTKey: 'controlnet.pidi',
descriptionTKey: 'controlnet.pidiDescription',
buildDefaults: () => ({
id: getId('pidi_image_processor'),
type: 'pidi_image_processor',
scribble: false,
safe: false,
}),
buildNode: (image, config) => ({
...config,
image: { image_name: image.imageName },
detect_resolution: minDim(image),
image_resolution: minDim(image),
}),
},
zoe_depth_image_processor: {
type: 'zoe_depth_image_processor',
labelTKey: 'controlnet.depthZoe',
descriptionTKey: 'controlnet.depthZoeDescription',
buildDefaults: () => ({
id: getId('zoe_depth_image_processor'),
type: 'zoe_depth_image_processor',
}),
buildNode: (image, config) => ({
...config,
image: { image_name: image.imageName },
}),
},
};
export const initialControlNet: Omit<ControlNetConfig, 'id'> = { export const initialControlNet: Omit<ControlNetConfig, 'id'> = {
type: 'controlnet', type: 'controlnet',
model: null, model: null,
@ -303,6 +415,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'id'> = {
controlMode: 'balanced', controlMode: 'balanced',
image: null, image: null,
processedImage: null, processedImage: null,
isProcessingImage: false,
processorConfig: CONTROLNET_PROCESSORS.canny_image_processor.buildDefaults(), processorConfig: CONTROLNET_PROCESSORS.canny_image_processor.buildDefaults(),
}; };
@ -313,6 +426,7 @@ export const initialT2IAdapter: Omit<T2IAdapterConfig, 'id'> = {
beginEndStepPct: [0, 1], beginEndStepPct: [0, 1],
image: null, image: null,
processedImage: null, processedImage: null,
isProcessingImage: false,
processorConfig: CONTROLNET_PROCESSORS.canny_image_processor.buildDefaults(), processorConfig: CONTROLNET_PROCESSORS.canny_image_processor.buildDefaults(),
}; };