From c96b98fc9ed75723fbd01885d5a62a1d1b287321 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 2 May 2024 08:18:34 +1000 Subject: [PATCH] feat(ui): auto-process for control layer CAs --- .../middleware/listenerMiddleware/index.ts | 2 + .../listeners/controlAdapterPreprocessor.ts | 147 ++++++ .../components/CALayer/CALayer.tsx | 4 +- .../CALayer/CALayerControlAdapterWrapper.tsx | 4 +- .../ControlAndIPAdapter/ControlAdapter.tsx | 10 +- .../ControlAdapterImagePreview.tsx | 39 +- .../IPALayer/IPALayerIPAdapterWrapper.tsx | 4 +- .../RGLayer/RGLayerIPAdapterWrapper.tsx | 4 +- .../controlLayers/store/controlLayersSlice.ts | 79 ++-- .../controlLayers/util/controlAdapters.ts | 444 +++++++++++------- 10 files changed, 495 insertions(+), 242 deletions(-) create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index cd0c1290e9..36040b5e41 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -16,6 +16,7 @@ import { addCanvasMaskSavedToGalleryListener } from 'app/store/middleware/listen import { addCanvasMaskToControlNetListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasMaskToControlNet'; import { addCanvasMergedListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasMerged'; import { addCanvasSavedToGalleryListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery'; +import { addControlAdapterPreprocessor } from 'app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor'; import { addControlNetAutoProcessListener } from 'app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess'; import { addControlNetImageProcessedListener } from 'app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed'; import { addEnqueueRequestedCanvasListener } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedCanvas'; @@ -157,3 +158,4 @@ addUpscaleRequestedListener(startAppListening); addDynamicPromptsListener(startAppListening); addSetDefaultSettingsListener(startAppListening); +addControlAdapterPreprocessor(startAppListening); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts new file mode 100644 index 0000000000..9cb7efe572 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts @@ -0,0 +1,147 @@ +import { isAnyOf } from '@reduxjs/toolkit'; +import { logger } from 'app/logging/logger'; +import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; +import { parseify } from 'common/util/serialize'; +import { + caLayerImageChanged, + caLayerIsProcessingImageChanged, + caLayerModelChanged, + caLayerProcessedImageChanged, + caLayerProcessorConfigChanged, + isControlAdapterLayer, +} from 'features/controlLayers/store/controlLayersSlice'; +import { CONTROLNET_PROCESSORS } from 'features/controlLayers/util/controlAdapters'; +import { isImageOutput } from 'features/nodes/types/common'; +import { addToast } from 'features/system/store/systemSlice'; +import { t } from 'i18next'; +import { imagesApi } from 'services/api/endpoints/images'; +import { queueApi } from 'services/api/endpoints/queue'; +import type { BatchConfig, ImageDTO } from 'services/api/types'; +import { socketInvocationComplete } from 'services/events/actions'; + +const matcher = isAnyOf(caLayerImageChanged, caLayerProcessorConfigChanged, caLayerModelChanged); + +const DEBOUNCE_MS = 300; +const log = logger('session'); + +export const addControlAdapterPreprocessor = (startAppListening: AppStartListening) => { + startAppListening({ + matcher, + effect: async (action, { dispatch, getState, cancelActiveListeners, delay, take }) => { + const { layerId } = action.payload; + const precheckLayer = getState() + .controlLayers.present.layers.filter(isControlAdapterLayer) + .find((l) => l.id === layerId); + + // Conditions to bail + if ( + // Layer doesn't exist + !precheckLayer || + // Layer doesn't have an image + !precheckLayer.controlAdapter.image || + // Layer doesn't have a processor config + !precheckLayer.controlAdapter.processorConfig || + // Layer is already processing an image + precheckLayer.controlAdapter.isProcessingImage + ) { + return; + } + + // Cancel any in-progress instances of this listener + cancelActiveListeners(); + log.trace('Control Layer CA auto-process triggered'); + + // Delay before starting actual work + await delay(DEBOUNCE_MS); + dispatch(caLayerIsProcessingImageChanged({ layerId, isProcessingImage: true })); + + // Double-check that we are still eligible for processing + const state = getState(); + const layer = state.controlLayers.present.layers.filter(isControlAdapterLayer).find((l) => l.id === layerId); + const image = layer?.controlAdapter.image; + const config = layer?.controlAdapter.processorConfig; + + // If we have no image or there is no processor config, bail + if (!layer || !image || !config) { + return; + } + + // @ts-expect-error: TS isn't able to narrow the typing of buildNode and `config` will error... + const processorNode = CONTROLNET_PROCESSORS[config.type].buildNode(image, config); + const enqueueBatchArg: BatchConfig = { + prepend: true, + batch: { + graph: { + nodes: { + [processorNode.id]: { ...processorNode, is_intermediate: true }, + }, + edges: [], + }, + runs: 1, + }, + }; + + try { + const req = dispatch( + queueApi.endpoints.enqueueBatch.initiate(enqueueBatchArg, { + fixedCacheKey: 'enqueueBatch', + }) + ); + const enqueueResult = await req.unwrap(); + req.reset(); + log.debug({ enqueueResult: parseify(enqueueResult) }, t('queue.graphQueued')); + + const [invocationCompleteAction] = await take( + (action): action is ReturnType => + socketInvocationComplete.match(action) && + action.payload.data.queue_batch_id === enqueueResult.batch.batch_id && + action.payload.data.source_node_id === processorNode.id + ); + + // We still have to check the output type + if (isImageOutput(invocationCompleteAction.payload.data.result)) { + const { image_name } = invocationCompleteAction.payload.data.result.image; + + // Wait for the ImageDTO to be received + const [{ payload }] = await take( + (action) => + imagesApi.endpoints.getImageDTO.matchFulfilled(action) && action.payload.image_name === image_name + ); + + const imageDTO = payload as ImageDTO; + + log.debug({ layerId, imageDTO }, 'ControlNet image processed'); + + // Update the processed image in the store + dispatch( + caLayerProcessedImageChanged({ + layerId, + imageDTO, + }) + ); + dispatch(caLayerIsProcessingImageChanged({ layerId, isProcessingImage: false })); + } + } catch (error) { + console.log(error); + log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue')); + dispatch(caLayerIsProcessingImageChanged({ layerId, isProcessingImage: false })); + + if (error instanceof Object) { + if ('data' in error && 'status' in error) { + if (error.status === 403) { + dispatch(caLayerImageChanged({ layerId, imageDTO: null })); + return; + } + } + } + + dispatch( + addToast({ + title: t('queue.graphFailedToQueue'), + status: 'error', + }) + ); + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CALayer/CALayer.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CALayer/CALayer.tsx index f9edf42c2f..24de817df2 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/CALayer/CALayer.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/CALayer/CALayer.tsx @@ -5,7 +5,7 @@ import { LayerDeleteButton } from 'features/controlLayers/components/LayerCommon import { LayerMenu } from 'features/controlLayers/components/LayerCommon/LayerMenu'; import { LayerTitle } from 'features/controlLayers/components/LayerCommon/LayerTitle'; import { LayerVisibilityToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle'; -import { layerSelected, selectCALayer } from 'features/controlLayers/store/controlLayersSlice'; +import { layerSelected, selectCALayerOrThrow } from 'features/controlLayers/store/controlLayersSlice'; import { memo, useCallback } from 'react'; import CALayerOpacity from './CALayerOpacity'; @@ -16,7 +16,7 @@ type Props = { export const CALayer = memo(({ layerId }: Props) => { const dispatch = useAppDispatch(); - const isSelected = useAppSelector((s) => selectCALayer(s.controlLayers.present, layerId).isSelected); + const isSelected = useAppSelector((s) => selectCALayerOrThrow(s.controlLayers.present, layerId).isSelected); const onClickCapture = useCallback(() => { // Must be capture so that the layer is selected before deleting/resetting/etc dispatch(layerSelected(layerId)); diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CALayer/CALayerControlAdapterWrapper.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CALayer/CALayerControlAdapterWrapper.tsx index 2a2edeb8d8..6793a33f69 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/CALayer/CALayerControlAdapterWrapper.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/CALayer/CALayerControlAdapterWrapper.tsx @@ -7,7 +7,7 @@ import { caLayerProcessorConfigChanged, caOrIPALayerBeginEndStepPctChanged, caOrIPALayerWeightChanged, - selectCALayer, + selectCALayerOrThrow, } from 'features/controlLayers/store/controlLayersSlice'; import type { ControlMode, ProcessorConfig } from 'features/controlLayers/util/controlAdapters'; import type { CALayerImageDropData } from 'features/dnd/types'; @@ -25,7 +25,7 @@ type Props = { export const CALayerControlAdapterWrapper = memo(({ layerId }: Props) => { const dispatch = useAppDispatch(); - const controlAdapter = useAppSelector((s) => selectCALayer(s.controlLayers.present, layerId).controlAdapter); + const controlAdapter = useAppSelector((s) => selectCALayerOrThrow(s.controlLayers.present, layerId).controlAdapter); const onChangeBeginEndStepPct = useCallback( (beginEndStepPct: [number, number]) => { diff --git a/invokeai/frontend/web/src/features/controlLayers/components/ControlAndIPAdapter/ControlAdapter.tsx b/invokeai/frontend/web/src/features/controlLayers/components/ControlAndIPAdapter/ControlAdapter.tsx index 972198cc7e..ecdfa46ef6 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/ControlAndIPAdapter/ControlAdapter.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/ControlAndIPAdapter/ControlAdapter.tsx @@ -87,11 +87,8 @@ export const ControlAdapter = memo( @@ -99,7 +96,10 @@ export const ControlAdapter = memo( {isExpanded && ( <> - + )} diff --git a/invokeai/frontend/web/src/features/controlLayers/components/ControlAndIPAdapter/ControlAdapterImagePreview.tsx b/invokeai/frontend/web/src/features/controlLayers/components/ControlAndIPAdapter/ControlAdapterImagePreview.tsx index e4f53c1c70..7def6b2b56 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/ControlAndIPAdapter/ControlAdapterImagePreview.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/ControlAndIPAdapter/ControlAdapterImagePreview.tsx @@ -1,14 +1,12 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library'; import { Box, Flex, Spinner, useShiftModifier } from '@invoke-ai/ui-library'; import { skipToken } from '@reduxjs/toolkit/query'; -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImageIcon from 'common/components/IAIDndImageIcon'; import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice'; -import { selectControlAdaptersSlice } from 'features/controlAdapters/store/controlAdaptersSlice'; import { heightChanged, widthChanged } from 'features/controlLayers/store/controlLayersSlice'; -import type { ImageWithDims } from 'features/controlLayers/util/controlAdapters'; +import type { ControlNetConfig, T2IAdapterConfig } from 'features/controlLayers/util/controlAdapters'; import type { ImageDraggableData, TypesafeDroppableData } from 'features/dnd/types'; import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize'; import { selectOptimalDimension } from 'features/parameters/store/generationSlice'; @@ -25,46 +23,29 @@ import { import type { ImageDTO, PostUploadAction } from 'services/api/types'; type Props = { - controlAdapterId: string; - image: ImageWithDims | null; - processedImage: ImageWithDims | null; + controlAdapter: ControlNetConfig | T2IAdapterConfig; onChangeImage: (imageDTO: ImageDTO | null) => void; - hasProcessor: boolean; droppableData: TypesafeDroppableData; postUploadAction: PostUploadAction; }; -const selectPendingControlImages = createMemoizedSelector( - selectControlAdaptersSlice, - (controlAdapters) => controlAdapters.pendingControlImages -); - export const ControlAdapterImagePreview = memo( - ({ - image, - processedImage, - onChangeImage, - hasProcessor, - controlAdapterId, - droppableData, - postUploadAction, - }: Props) => { + ({ controlAdapter, onChangeImage, droppableData, postUploadAction }: Props) => { const { t } = useTranslation(); const dispatch = useAppDispatch(); const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId); const isConnected = useAppSelector((s) => s.system.isConnected); const activeTabName = useAppSelector(activeTabNameSelector); const optimalDimension = useAppSelector(selectOptimalDimension); - const pendingControlImages = useAppSelector(selectPendingControlImages); const shift = useShiftModifier(); const [isMouseOverImage, setIsMouseOverImage] = useState(false); const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery( - image?.imageName ?? skipToken + controlAdapter.image?.imageName ?? skipToken ); const { currentData: processedControlImage, isError: isErrorProcessedControlImage } = useGetImageDTOQuery( - processedImage?.imageName ?? skipToken + controlAdapter.processedImage?.imageName ?? skipToken ); const [changeIsIntermediate] = useChangeImageIsIntermediateMutation(); @@ -130,19 +111,19 @@ export const ControlAdapterImagePreview = memo( const draggableData = useMemo(() => { if (controlImage) { return { - id: controlAdapterId, + id: controlAdapter.id, payloadType: 'IMAGE_DTO', payload: { imageDTO: controlImage }, }; } - }, [controlImage, controlAdapterId]); + }, [controlImage, controlAdapter.id]); const shouldShowProcessedImage = controlImage && processedControlImage && !isMouseOverImage && - !pendingControlImages.includes(controlAdapterId) && - hasProcessor; + !controlAdapter.isProcessingImage && + controlAdapter.processorConfig !== null; useEffect(() => { if (isConnected && (isErrorControlImage || isErrorProcessedControlImage)) { @@ -207,7 +188,7 @@ export const ControlAdapterImagePreview = memo( /> - {pendingControlImages.includes(controlAdapterId) && ( + {controlAdapter.isProcessingImage && ( { const dispatch = useAppDispatch(); - const ipAdapter = useAppSelector((s) => selectIPALayer(s.controlLayers.present, layerId).ipAdapter); + const ipAdapter = useAppSelector((s) => selectIPALayerOrThrow(s.controlLayers.present, layerId).ipAdapter); const onChangeBeginEndStepPct = useCallback( (beginEndStepPct: [number, number]) => { diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayerIPAdapterWrapper.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayerIPAdapterWrapper.tsx index cc8b0698a5..015cf75e4d 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayerIPAdapterWrapper.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayerIPAdapterWrapper.tsx @@ -9,7 +9,7 @@ import { rgLayerIPAdapterMethodChanged, rgLayerIPAdapterModelChanged, rgLayerIPAdapterWeightChanged, - selectRGLayerIPAdapter, + selectRGLayerIPAdapterOrThrow, } from 'features/controlLayers/store/controlLayersSlice'; import type { CLIPVisionModel, IPMethod } from 'features/controlLayers/util/controlAdapters'; import type { RGLayerIPAdapterImageDropData } from 'features/dnd/types'; @@ -28,7 +28,7 @@ export const RGLayerIPAdapterWrapper = memo(({ layerId, ipAdapterId, ipAdapterNu const onDeleteIPAdapter = useCallback(() => { dispatch(rgLayerIPAdapterDeleted({ layerId, ipAdapterId })); }, [dispatch, ipAdapterId, layerId]); - const ipAdapter = useAppSelector((s) => selectRGLayerIPAdapter(s.controlLayers.present, layerId, ipAdapterId)); + const ipAdapter = useAppSelector((s) => selectRGLayerIPAdapterOrThrow(s.controlLayers.present, layerId, ipAdapterId)); const onChangeBeginEndStepPct = useCallback( (beginEndStepPct: [number, number]) => { diff --git a/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts index 92fe9d0119..fa179a3bbd 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts @@ -78,17 +78,17 @@ const resetLayer = (layer: Layer) => { } }; -export const selectCALayer = (state: ControlLayersState, layerId: string): ControlAdapterLayer => { +export const selectCALayerOrThrow = (state: ControlLayersState, layerId: string): ControlAdapterLayer => { const layer = state.layers.find((l) => l.id === layerId); assert(isControlAdapterLayer(layer)); return layer; }; -export const selectIPALayer = (state: ControlLayersState, layerId: string): IPAdapterLayer => { +export const selectIPALayerOrThrow = (state: ControlLayersState, layerId: string): IPAdapterLayer => { const layer = state.layers.find((l) => l.id === layerId); assert(isIPAdapterLayer(layer)); return layer; }; -export const selectCAOrIPALayer = ( +export const selectCAOrIPALayerOrThrow = ( state: ControlLayersState, layerId: string ): ControlAdapterLayer | IPAdapterLayer => { @@ -96,12 +96,12 @@ export const selectCAOrIPALayer = ( assert(isControlAdapterLayer(layer) || isIPAdapterLayer(layer)); return layer; }; -export const selectRGLayer = (state: ControlLayersState, layerId: string): RegionalGuidanceLayer => { +export const selectRGLayerOrThrow = (state: ControlLayersState, layerId: string): RegionalGuidanceLayer => { const layer = state.layers.find((l) => l.id === layerId); assert(isRegionalGuidanceLayer(layer)); return layer; }; -export const selectRGLayerIPAdapter = ( +export const selectRGLayerIPAdapterOrThrow = ( state: ControlLayersState, layerId: string, ipAdapterId: string @@ -246,7 +246,7 @@ export const controlLayersSlice = createSlice({ }, caLayerImageChanged: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => { const { layerId, imageDTO } = action.payload; - const layer = selectCALayer(state, layerId); + const layer = selectCALayerOrThrow(state, layerId); layer.bbox = null; layer.bboxNeedsUpdate = true; layer.isEnabled = true; @@ -255,7 +255,7 @@ export const controlLayersSlice = createSlice({ }, caLayerProcessedImageChanged: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => { const { layerId, imageDTO } = action.payload; - const layer = selectCALayer(state, layerId); + const layer = selectCALayerOrThrow(state, layerId); layer.bbox = null; layer.bboxNeedsUpdate = true; layer.isEnabled = true; @@ -269,7 +269,7 @@ export const controlLayersSlice = createSlice({ }> ) => { const { layerId, modelConfig } = action.payload; - const layer = selectCALayer(state, layerId); + const layer = selectCALayerOrThrow(state, layerId); if (!modelConfig) { layer.controlAdapter.model = null; return; @@ -285,7 +285,7 @@ export const controlLayersSlice = createSlice({ }, caLayerControlModeChanged: (state, action: PayloadAction<{ layerId: string; controlMode: ControlMode }>) => { const { layerId, controlMode } = action.payload; - const layer = selectCALayer(state, layerId); + const layer = selectCALayerOrThrow(state, layerId); assert(layer.controlAdapter.type === 'controlnet'); layer.controlAdapter.controlMode = controlMode; }, @@ -294,19 +294,27 @@ export const controlLayersSlice = createSlice({ action: PayloadAction<{ layerId: string; processorConfig: ProcessorConfig | null }> ) => { const { layerId, processorConfig } = action.payload; - const layer = selectCALayer(state, layerId); + const layer = selectCALayerOrThrow(state, layerId); layer.controlAdapter.processorConfig = processorConfig; }, caLayerIsFilterEnabledChanged: (state, action: PayloadAction<{ layerId: string; isFilterEnabled: boolean }>) => { const { layerId, isFilterEnabled } = action.payload; - const layer = selectCALayer(state, layerId); + const layer = selectCALayerOrThrow(state, layerId); layer.isFilterEnabled = isFilterEnabled; }, caLayerOpacityChanged: (state, action: PayloadAction<{ layerId: string; opacity: number }>) => { const { layerId, opacity } = action.payload; - const layer = selectCALayer(state, layerId); + const layer = selectCALayerOrThrow(state, layerId); layer.opacity = opacity; }, + caLayerIsProcessingImageChanged: ( + state, + action: PayloadAction<{ layerId: string; isProcessingImage: boolean }> + ) => { + const { layerId, isProcessingImage } = action.payload; + const layer = selectCALayerOrThrow(state, layerId); + layer.controlAdapter.isProcessingImage = isProcessingImage; + }, //#endregion //#region IP Adapter Layers @@ -325,12 +333,12 @@ export const controlLayersSlice = createSlice({ }, ipaLayerImageChanged: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => { const { layerId, imageDTO } = action.payload; - const layer = selectIPALayer(state, layerId); + const layer = selectIPALayerOrThrow(state, layerId); layer.ipAdapter.image = imageDTO ? imageDTOToImageWithDims(imageDTO) : null; }, ipaLayerWeightChanged: (state, action: PayloadAction<{ layerId: string; weight: number }>) => { const { layerId, weight } = action.payload; - const layer = selectIPALayer(state, layerId); + const layer = selectIPALayerOrThrow(state, layerId); layer.ipAdapter.weight = weight; }, ipaLayerBeginEndStepPctChanged: ( @@ -338,12 +346,12 @@ export const controlLayersSlice = createSlice({ action: PayloadAction<{ layerId: string; beginEndStepPct: [number, number] }> ) => { const { layerId, beginEndStepPct } = action.payload; - const layer = selectIPALayer(state, layerId); + const layer = selectIPALayerOrThrow(state, layerId); layer.ipAdapter.beginEndStepPct = beginEndStepPct; }, ipaLayerMethodChanged: (state, action: PayloadAction<{ layerId: string; method: IPMethod }>) => { const { layerId, method } = action.payload; - const layer = selectIPALayer(state, layerId); + const layer = selectIPALayerOrThrow(state, layerId); layer.ipAdapter.method = method; }, ipaLayerModelChanged: ( @@ -354,7 +362,7 @@ export const controlLayersSlice = createSlice({ }> ) => { const { layerId, modelConfig } = action.payload; - const layer = selectIPALayer(state, layerId); + const layer = selectIPALayerOrThrow(state, layerId); if (!modelConfig) { layer.ipAdapter.model = null; return; @@ -366,7 +374,7 @@ export const controlLayersSlice = createSlice({ action: PayloadAction<{ layerId: string; clipVisionModel: CLIPVisionModel }> ) => { const { layerId, clipVisionModel } = action.payload; - const layer = selectIPALayer(state, layerId); + const layer = selectIPALayerOrThrow(state, layerId); layer.ipAdapter.clipVisionModel = clipVisionModel; }, //#endregion @@ -374,7 +382,7 @@ export const controlLayersSlice = createSlice({ //#region CA or IPA Layers caOrIPALayerWeightChanged: (state, action: PayloadAction<{ layerId: string; weight: number }>) => { const { layerId, weight } = action.payload; - const layer = selectCAOrIPALayer(state, layerId); + const layer = selectCAOrIPALayerOrThrow(state, layerId); if (layer.type === 'control_adapter_layer') { layer.controlAdapter.weight = weight; } else { @@ -386,7 +394,7 @@ export const controlLayersSlice = createSlice({ action: PayloadAction<{ layerId: string; beginEndStepPct: [number, number] }> ) => { const { layerId, beginEndStepPct } = action.payload; - const layer = selectCAOrIPALayer(state, layerId); + const layer = selectCAOrIPALayerOrThrow(state, layerId); if (layer.type === 'control_adapter_layer') { layer.controlAdapter.beginEndStepPct = beginEndStepPct; } else { @@ -428,17 +436,17 @@ export const controlLayersSlice = createSlice({ }, rgLayerPositivePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => { const { layerId, prompt } = action.payload; - const layer = selectRGLayer(state, layerId); + const layer = selectRGLayerOrThrow(state, layerId); layer.positivePrompt = prompt; }, rgLayerNegativePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string | null }>) => { const { layerId, prompt } = action.payload; - const layer = selectRGLayer(state, layerId); + const layer = selectRGLayerOrThrow(state, layerId); layer.negativePrompt = prompt; }, rgLayerPreviewColorChanged: (state, action: PayloadAction<{ layerId: string; color: RgbColor }>) => { const { layerId, color } = action.payload; - const layer = selectRGLayer(state, layerId); + const layer = selectRGLayerOrThrow(state, layerId); layer.previewColor = color; }, rgLayerLineAdded: { @@ -452,7 +460,7 @@ export const controlLayersSlice = createSlice({ }> ) => { const { layerId, points, tool, lineUuid } = action.payload; - const layer = selectRGLayer(state, layerId); + const layer = selectRGLayerOrThrow(state, layerId); const lineId = getRGLayerLineId(layer.id, lineUuid); layer.maskObjects.push({ type: 'vector_mask_line', @@ -474,7 +482,7 @@ export const controlLayersSlice = createSlice({ }, rgLayerPointsAdded: (state, action: PayloadAction<{ layerId: string; point: [number, number] }>) => { const { layerId, point } = action.payload; - const layer = selectRGLayer(state, layerId); + const layer = selectRGLayerOrThrow(state, layerId); const lastLine = layer.maskObjects.findLast(isLine); if (!lastLine) { return; @@ -491,7 +499,7 @@ export const controlLayersSlice = createSlice({ // Ignore zero-area rectangles return; } - const layer = selectRGLayer(state, layerId); + const layer = selectRGLayerOrThrow(state, layerId); const id = getRGLayerRectId(layer.id, rectUuid); layer.maskObjects.push({ type: 'vector_mask_rect', @@ -510,17 +518,17 @@ export const controlLayersSlice = createSlice({ action: PayloadAction<{ layerId: string; autoNegative: ParameterAutoNegative }> ) => { const { layerId, autoNegative } = action.payload; - const layer = selectRGLayer(state, layerId); + const layer = selectRGLayerOrThrow(state, layerId); layer.autoNegative = autoNegative; }, rgLayerIPAdapterAdded: (state, action: PayloadAction<{ layerId: string; ipAdapter: IPAdapterConfig }>) => { const { layerId, ipAdapter } = action.payload; - const layer = selectRGLayer(state, layerId); + const layer = selectRGLayerOrThrow(state, layerId); layer.ipAdapters.push(ipAdapter); }, rgLayerIPAdapterDeleted: (state, action: PayloadAction<{ layerId: string; ipAdapterId: string }>) => { const { layerId, ipAdapterId } = action.payload; - const layer = selectRGLayer(state, layerId); + const layer = selectRGLayerOrThrow(state, layerId); layer.ipAdapters = layer.ipAdapters.filter((ipAdapter) => ipAdapter.id !== ipAdapterId); }, rgLayerIPAdapterImageChanged: ( @@ -528,7 +536,7 @@ export const controlLayersSlice = createSlice({ action: PayloadAction<{ layerId: string; ipAdapterId: string; imageDTO: ImageDTO | null }> ) => { const { layerId, ipAdapterId, imageDTO } = action.payload; - const ipAdapter = selectRGLayerIPAdapter(state, layerId, ipAdapterId); + const ipAdapter = selectRGLayerIPAdapterOrThrow(state, layerId, ipAdapterId); ipAdapter.image = imageDTO ? imageDTOToImageWithDims(imageDTO) : null; }, rgLayerIPAdapterWeightChanged: ( @@ -536,7 +544,7 @@ export const controlLayersSlice = createSlice({ action: PayloadAction<{ layerId: string; ipAdapterId: string; weight: number }> ) => { const { layerId, ipAdapterId, weight } = action.payload; - const ipAdapter = selectRGLayerIPAdapter(state, layerId, ipAdapterId); + const ipAdapter = selectRGLayerIPAdapterOrThrow(state, layerId, ipAdapterId); ipAdapter.weight = weight; }, rgLayerIPAdapterBeginEndStepPctChanged: ( @@ -544,7 +552,7 @@ export const controlLayersSlice = createSlice({ action: PayloadAction<{ layerId: string; ipAdapterId: string; beginEndStepPct: [number, number] }> ) => { const { layerId, ipAdapterId, beginEndStepPct } = action.payload; - const ipAdapter = selectRGLayerIPAdapter(state, layerId, ipAdapterId); + const ipAdapter = selectRGLayerIPAdapterOrThrow(state, layerId, ipAdapterId); ipAdapter.beginEndStepPct = beginEndStepPct; }, rgLayerIPAdapterMethodChanged: ( @@ -552,7 +560,7 @@ export const controlLayersSlice = createSlice({ action: PayloadAction<{ layerId: string; ipAdapterId: string; method: IPMethod }> ) => { const { layerId, ipAdapterId, method } = action.payload; - const ipAdapter = selectRGLayerIPAdapter(state, layerId, ipAdapterId); + const ipAdapter = selectRGLayerIPAdapterOrThrow(state, layerId, ipAdapterId); ipAdapter.method = method; }, rgLayerIPAdapterModelChanged: ( @@ -564,7 +572,7 @@ export const controlLayersSlice = createSlice({ }> ) => { const { layerId, ipAdapterId, modelConfig } = action.payload; - const ipAdapter = selectRGLayerIPAdapter(state, layerId, ipAdapterId); + const ipAdapter = selectRGLayerIPAdapterOrThrow(state, layerId, ipAdapterId); if (!modelConfig) { ipAdapter.model = null; return; @@ -576,7 +584,7 @@ export const controlLayersSlice = createSlice({ action: PayloadAction<{ layerId: string; ipAdapterId: string; clipVisionModel: CLIPVisionModel }> ) => { const { layerId, ipAdapterId, clipVisionModel } = action.payload; - const ipAdapter = selectRGLayerIPAdapter(state, layerId, ipAdapterId); + const ipAdapter = selectRGLayerIPAdapterOrThrow(state, layerId, ipAdapterId); ipAdapter.clipVisionModel = clipVisionModel; }, //#endregion @@ -720,6 +728,7 @@ export const { caLayerProcessorConfigChanged, caLayerIsFilterEnabledChanged, caLayerOpacityChanged, + caLayerIsProcessingImageChanged, // IPA Layers ipaLayerAdded, ipaLayerImageChanged, diff --git a/invokeai/frontend/web/src/features/controlLayers/util/controlAdapters.ts b/invokeai/frontend/web/src/features/controlLayers/util/controlAdapters.ts index 0417c707e4..261cfd2f85 100644 --- a/invokeai/frontend/web/src/features/controlLayers/util/controlAdapters.ts +++ b/invokeai/frontend/web/src/features/controlLayers/util/controlAdapters.ts @@ -13,6 +13,7 @@ import type { ControlNetModelConfig, DepthAnythingImageProcessorInvocation, DWOpenposeImageProcessorInvocation, + Graph, HedImageProcessorInvocation, ImageDTO, LineartAnimeImageProcessorInvocation, @@ -34,27 +35,33 @@ export const isDepthAnythingModelSize = (v: unknown): v is DepthAnythingModelSiz zDepthAnythingModelSize.safeParse(v).success; export type CannyProcessorConfig = Required< - Pick + Pick +>; +export type ColorMapProcessorConfig = Required< + Pick >; -export type ColorMapProcessorConfig = Required>; export type ContentShuffleProcessorConfig = Required< - Pick + Pick >; -export type DepthAnythingProcessorConfig = Required>; -export type HedProcessorConfig = Required>; -export type LineartAnimeProcessorConfig = Required>; -export type LineartProcessorConfig = Required>; +export type DepthAnythingProcessorConfig = Required< + Pick +>; +export type HedProcessorConfig = Required>; +export type LineartAnimeProcessorConfig = Required>; +export type LineartProcessorConfig = Required>; export type MediapipeFaceProcessorConfig = Required< - Pick + Pick >; -export type MidasDepthProcessorConfig = Required>; -export type MlsdProcessorConfig = Required>; -export type NormalbaeProcessorConfig = Required>; +export type MidasDepthProcessorConfig = Required< + Pick +>; +export type MlsdProcessorConfig = Required>; +export type NormalbaeProcessorConfig = Required>; export type DWOpenposeProcessorConfig = Required< - Pick + Pick >; -export type PidiProcessorConfig = Required>; -export type ZoeDepthProcessorConfig = Required>; +export type PidiProcessorConfig = Required>; +export type ZoeDepthProcessorConfig = Required>; export type ProcessorConfig = | CannyProcessorConfig @@ -83,6 +90,7 @@ type ControlAdapterBase = { weight: number; image: ImageWithDims | null; processedImage: ImageWithDims | null; + isProcessingImage: boolean; processorConfig: ProcessorConfig | null; beginEndStepPct: [number, number]; }; @@ -125,157 +133,6 @@ export type IPAdapterConfig = { beginEndStepPct: [number, number]; }; -type ProcessorData = { - labelTKey: string; - descriptionTKey: string; - buildDefaults(baseModel?: BaseModelType): Extract; -}; - -type ControlNetProcessorsDict = { - [key in ProcessorConfig['type']]: ProcessorData; -}; -/** - * A dict of ControlNet processors, including: - * - label translation key - * - description translation key - * - a builder to create default values for the config - * - * TODO: Generate from the OpenAPI schema - */ -export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = { - canny_image_processor: { - labelTKey: 'controlnet.canny', - descriptionTKey: 'controlnet.cannyDescription', - buildDefaults: () => ({ - id: `canny_image_processor_${uuidv4()}`, - type: 'canny_image_processor', - low_threshold: 100, - high_threshold: 200, - }), - }, - color_map_image_processor: { - labelTKey: 'controlnet.colorMap', - descriptionTKey: 'controlnet.colorMapDescription', - buildDefaults: () => ({ - id: `color_map_image_processor_${uuidv4()}`, - type: 'color_map_image_processor', - color_map_tile_size: 64, - }), - }, - content_shuffle_image_processor: { - labelTKey: 'controlnet.contentShuffle', - descriptionTKey: 'controlnet.contentShuffleDescription', - buildDefaults: (baseModel) => ({ - id: `content_shuffle_image_processor_${uuidv4()}`, - type: 'content_shuffle_image_processor', - h: baseModel === 'sdxl' ? 1024 : 512, - w: baseModel === 'sdxl' ? 1024 : 512, - f: baseModel === 'sdxl' ? 512 : 256, - }), - }, - depth_anything_image_processor: { - labelTKey: 'controlnet.depthAnything', - descriptionTKey: 'controlnet.depthAnythingDescription', - buildDefaults: () => ({ - id: `depth_anything_image_processor_${uuidv4()}`, - type: 'depth_anything_image_processor', - model_size: 'small', - }), - }, - hed_image_processor: { - labelTKey: 'controlnet.hed', - descriptionTKey: 'controlnet.hedDescription', - buildDefaults: () => ({ - id: `hed_image_processor_${uuidv4()}`, - type: 'hed_image_processor', - scribble: false, - }), - }, - lineart_anime_image_processor: { - labelTKey: 'controlnet.lineartAnime', - descriptionTKey: 'controlnet.lineartAnimeDescription', - buildDefaults: () => ({ - id: `lineart_anime_image_processor_${uuidv4()}`, - type: 'lineart_anime_image_processor', - }), - }, - lineart_image_processor: { - labelTKey: 'controlnet.lineart', - descriptionTKey: 'controlnet.lineartDescription', - buildDefaults: () => ({ - id: `lineart_image_processor_${uuidv4()}`, - type: 'lineart_image_processor', - coarse: false, - }), - }, - mediapipe_face_processor: { - labelTKey: 'controlnet.mediapipeFace', - descriptionTKey: 'controlnet.mediapipeFaceDescription', - buildDefaults: () => ({ - id: `mediapipe_face_processor_${uuidv4()}`, - type: 'mediapipe_face_processor', - max_faces: 1, - min_confidence: 0.5, - }), - }, - midas_depth_image_processor: { - labelTKey: 'controlnet.depthMidas', - descriptionTKey: 'controlnet.depthMidasDescription', - buildDefaults: () => ({ - id: `midas_depth_image_processor_${uuidv4()}`, - type: 'midas_depth_image_processor', - a_mult: 2, - bg_th: 0.1, - }), - }, - mlsd_image_processor: { - labelTKey: 'controlnet.mlsd', - descriptionTKey: 'controlnet.mlsdDescription', - buildDefaults: () => ({ - id: `mlsd_image_processor_${uuidv4()}`, - type: 'mlsd_image_processor', - thr_d: 0.1, - thr_v: 0.1, - }), - }, - normalbae_image_processor: { - labelTKey: 'controlnet.normalBae', - descriptionTKey: 'controlnet.normalBaeDescription', - buildDefaults: () => ({ - id: `normalbae_image_processor_${uuidv4()}`, - type: 'normalbae_image_processor', - }), - }, - dw_openpose_image_processor: { - labelTKey: 'controlnet.dwOpenpose', - descriptionTKey: 'controlnet.dwOpenposeDescription', - buildDefaults: () => ({ - id: `dw_openpose_image_processor_${uuidv4()}`, - type: 'dw_openpose_image_processor', - draw_body: true, - draw_face: false, - draw_hands: false, - }), - }, - pidi_image_processor: { - labelTKey: 'controlnet.pidi', - descriptionTKey: 'controlnet.pidiDescription', - buildDefaults: () => ({ - id: `pidi_image_processor_${uuidv4()}`, - type: 'pidi_image_processor', - scribble: false, - safe: false, - }), - }, - zoe_depth_image_processor: { - labelTKey: 'controlnet.depthZoe', - descriptionTKey: 'controlnet.depthZoeDescription', - buildDefaults: () => ({ - id: `zoe_depth_image_processor_${uuidv4()}`, - type: 'zoe_depth_image_processor', - }), - }, -}; export const zProcessorType = z.enum([ 'canny_image_processor', 'color_map_image_processor', @@ -295,6 +152,261 @@ export const zProcessorType = z.enum([ export type ProcessorType = z.infer; export const isProcessorType = (v: unknown): v is ProcessorType => zProcessorType.safeParse(v).success; +type ProcessorData = { + type: T; + labelTKey: string; + descriptionTKey: string; + buildDefaults(baseModel?: BaseModelType): Extract; + buildNode( + image: ImageWithDims, + config: Extract + ): Extract; +}; + +const minDim = (image: ImageWithDims): number => Math.min(image.width, image.height); +const getId = (type: ProcessorType): string => `${type}_${uuidv4()}`; + +type CAProcessorsData = { + [key in ProcessorType]: ProcessorData; +}; +/** + * A dict of ControlNet processors, including: + * - label translation key + * - description translation key + * - a builder to create default values for the config + * - a builder to create the node for the config + * + * TODO: Generate from the OpenAPI schema + */ +export const CONTROLNET_PROCESSORS: CAProcessorsData = { + canny_image_processor: { + type: 'canny_image_processor', + labelTKey: 'controlnet.canny', + descriptionTKey: 'controlnet.cannyDescription', + buildDefaults: () => ({ + id: getId('canny_image_processor'), + type: 'canny_image_processor', + low_threshold: 100, + high_threshold: 200, + }), + buildNode: (image, config) => ({ + ...config, + type: 'canny_image_processor', + image: { image_name: image.imageName }, + detect_resolution: minDim(image), + image_resolution: minDim(image), + }), + }, + color_map_image_processor: { + type: 'color_map_image_processor', + labelTKey: 'controlnet.colorMap', + descriptionTKey: 'controlnet.colorMapDescription', + buildDefaults: () => ({ + id: getId('color_map_image_processor'), + type: 'color_map_image_processor', + color_map_tile_size: 64, + }), + buildNode: (image, config) => ({ + ...config, + type: 'color_map_image_processor', + image: { image_name: image.imageName }, + }), + }, + content_shuffle_image_processor: { + type: 'content_shuffle_image_processor', + labelTKey: 'controlnet.contentShuffle', + descriptionTKey: 'controlnet.contentShuffleDescription', + buildDefaults: (baseModel) => ({ + id: getId('content_shuffle_image_processor'), + type: 'content_shuffle_image_processor', + h: baseModel === 'sdxl' ? 1024 : 512, + w: baseModel === 'sdxl' ? 1024 : 512, + f: baseModel === 'sdxl' ? 512 : 256, + }), + buildNode: (image, config) => ({ + ...config, + image: { image_name: image.imageName }, + detect_resolution: minDim(image), + image_resolution: minDim(image), + }), + }, + depth_anything_image_processor: { + type: 'depth_anything_image_processor', + labelTKey: 'controlnet.depthAnything', + descriptionTKey: 'controlnet.depthAnythingDescription', + buildDefaults: () => ({ + id: getId('depth_anything_image_processor'), + type: 'depth_anything_image_processor', + model_size: 'small', + }), + buildNode: (image, config) => ({ + ...config, + image: { image_name: image.imageName }, + resolution: minDim(image), + }), + }, + hed_image_processor: { + type: 'hed_image_processor', + labelTKey: 'controlnet.hed', + descriptionTKey: 'controlnet.hedDescription', + buildDefaults: () => ({ + id: getId('hed_image_processor'), + type: 'hed_image_processor', + scribble: false, + }), + buildNode: (image, config) => ({ + ...config, + image: { image_name: image.imageName }, + detect_resolution: minDim(image), + image_resolution: minDim(image), + }), + }, + lineart_anime_image_processor: { + type: 'lineart_anime_image_processor', + labelTKey: 'controlnet.lineartAnime', + descriptionTKey: 'controlnet.lineartAnimeDescription', + buildDefaults: () => ({ + id: getId('lineart_anime_image_processor'), + type: 'lineart_anime_image_processor', + }), + buildNode: (image, config) => ({ + ...config, + image: { image_name: image.imageName }, + detect_resolution: minDim(image), + image_resolution: minDim(image), + }), + }, + lineart_image_processor: { + type: 'lineart_image_processor', + labelTKey: 'controlnet.lineart', + descriptionTKey: 'controlnet.lineartDescription', + buildDefaults: () => ({ + id: getId('lineart_image_processor'), + type: 'lineart_image_processor', + coarse: false, + }), + buildNode: (image, config) => ({ + ...config, + image: { image_name: image.imageName }, + detect_resolution: minDim(image), + image_resolution: minDim(image), + }), + }, + mediapipe_face_processor: { + type: 'mediapipe_face_processor', + labelTKey: 'controlnet.mediapipeFace', + descriptionTKey: 'controlnet.mediapipeFaceDescription', + buildDefaults: () => ({ + id: getId('mediapipe_face_processor'), + type: 'mediapipe_face_processor', + max_faces: 1, + min_confidence: 0.5, + }), + buildNode: (image, config) => ({ + ...config, + image: { image_name: image.imageName }, + detect_resolution: minDim(image), + image_resolution: minDim(image), + }), + }, + midas_depth_image_processor: { + type: 'midas_depth_image_processor', + labelTKey: 'controlnet.depthMidas', + descriptionTKey: 'controlnet.depthMidasDescription', + buildDefaults: () => ({ + id: getId('midas_depth_image_processor'), + type: 'midas_depth_image_processor', + a_mult: 2, + bg_th: 0.1, + }), + buildNode: (image, config) => ({ + ...config, + image: { image_name: image.imageName }, + detect_resolution: minDim(image), + image_resolution: minDim(image), + }), + }, + mlsd_image_processor: { + type: 'mlsd_image_processor', + labelTKey: 'controlnet.mlsd', + descriptionTKey: 'controlnet.mlsdDescription', + buildDefaults: () => ({ + id: getId('mlsd_image_processor'), + type: 'mlsd_image_processor', + thr_d: 0.1, + thr_v: 0.1, + }), + buildNode: (image, config) => ({ + ...config, + image: { image_name: image.imageName }, + detect_resolution: minDim(image), + image_resolution: minDim(image), + }), + }, + normalbae_image_processor: { + type: 'normalbae_image_processor', + labelTKey: 'controlnet.normalBae', + descriptionTKey: 'controlnet.normalBaeDescription', + buildDefaults: () => ({ + id: getId('normalbae_image_processor'), + type: 'normalbae_image_processor', + }), + buildNode: (image, config) => ({ + ...config, + image: { image_name: image.imageName }, + detect_resolution: minDim(image), + image_resolution: minDim(image), + }), + }, + dw_openpose_image_processor: { + type: 'dw_openpose_image_processor', + labelTKey: 'controlnet.dwOpenpose', + descriptionTKey: 'controlnet.dwOpenposeDescription', + buildDefaults: () => ({ + id: getId('dw_openpose_image_processor'), + type: 'dw_openpose_image_processor', + draw_body: true, + draw_face: false, + draw_hands: false, + }), + buildNode: (image, config) => ({ + ...config, + image: { image_name: image.imageName }, + image_resolution: minDim(image), + }), + }, + pidi_image_processor: { + type: 'pidi_image_processor', + labelTKey: 'controlnet.pidi', + descriptionTKey: 'controlnet.pidiDescription', + buildDefaults: () => ({ + id: getId('pidi_image_processor'), + type: 'pidi_image_processor', + scribble: false, + safe: false, + }), + buildNode: (image, config) => ({ + ...config, + image: { image_name: image.imageName }, + detect_resolution: minDim(image), + image_resolution: minDim(image), + }), + }, + zoe_depth_image_processor: { + type: 'zoe_depth_image_processor', + labelTKey: 'controlnet.depthZoe', + descriptionTKey: 'controlnet.depthZoeDescription', + buildDefaults: () => ({ + id: getId('zoe_depth_image_processor'), + type: 'zoe_depth_image_processor', + }), + buildNode: (image, config) => ({ + ...config, + image: { image_name: image.imageName }, + }), + }, +}; + export const initialControlNet: Omit = { type: 'controlnet', model: null, @@ -303,6 +415,7 @@ export const initialControlNet: Omit = { controlMode: 'balanced', image: null, processedImage: null, + isProcessingImage: false, processorConfig: CONTROLNET_PROCESSORS.canny_image_processor.buildDefaults(), }; @@ -313,6 +426,7 @@ export const initialT2IAdapter: Omit = { beginEndStepPct: [0, 1], image: null, processedImage: null, + isProcessingImage: false, processorConfig: CONTROLNET_PROCESSORS.canny_image_processor.buildDefaults(), };