diff --git a/invokeai/frontend/web/src/features/controlLayers/store/bboxReducers.ts b/invokeai/frontend/web/src/features/controlLayers/store/bboxReducers.ts deleted file mode 100644 index af328d944d..0000000000 --- a/invokeai/frontend/web/src/features/controlLayers/store/bboxReducers.ts +++ /dev/null @@ -1,119 +0,0 @@ -import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit'; -import { deepClone } from 'common/util/deepClone'; -import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple'; -import type { BoundingBoxScaleMethod, CanvasState, Dimensions } from 'features/controlLayers/store/types'; -import { getScaledBoundingBoxDimensions } from 'features/controlLayers/util/getScaledBoundingBoxDimensions'; -import { calculateNewSize } from 'features/parameters/components/DocumentSize/calculateNewSize'; -import { ASPECT_RATIO_MAP, initialAspectRatioState } from 'features/parameters/components/DocumentSize/constants'; -import type { AspectRatioID } from 'features/parameters/components/DocumentSize/types'; -import type { IRect } from 'konva/lib/types'; - -const syncScaledSize = (state: CanvasState) => { - if (state.bbox.scaleMethod === 'auto') { - const { width, height } = state.bbox.rect; - state.bbox.scaledSize = getScaledBoundingBoxDimensions({ width, height }, state.bbox.optimalDimension); - } -}; - -export const bboxReducers = { - bboxScaledSizeChanged: (state, action: PayloadAction>) => { - state.bbox.scaledSize = { ...state.bbox.scaledSize, ...action.payload }; - }, - bboxScaleMethodChanged: (state, action: PayloadAction) => { - state.bbox.scaleMethod = action.payload; - syncScaledSize(state); - }, - bboxChanged: (state, action: PayloadAction) => { - state.bbox.rect = action.payload; - syncScaledSize(state); - }, - bboxWidthChanged: (state, action: PayloadAction<{ width: number; updateAspectRatio?: boolean; clamp?: boolean }>) => { - const { width, updateAspectRatio, clamp } = action.payload; - state.bbox.rect.width = clamp ? Math.max(roundDownToMultiple(width, 8), 64) : width; - - if (state.bbox.aspectRatio.isLocked) { - state.bbox.rect.height = roundToMultiple(state.bbox.rect.width / state.bbox.aspectRatio.value, 8); - } - - if (updateAspectRatio || !state.bbox.aspectRatio.isLocked) { - state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height; - state.bbox.aspectRatio.id = 'Free'; - state.bbox.aspectRatio.isLocked = false; - } - - syncScaledSize(state); - }, - bboxHeightChanged: ( - state, - action: PayloadAction<{ height: number; updateAspectRatio?: boolean; clamp?: boolean }> - ) => { - const { height, updateAspectRatio, clamp } = action.payload; - - state.bbox.rect.height = clamp ? Math.max(roundDownToMultiple(height, 8), 64) : height; - - if (state.bbox.aspectRatio.isLocked) { - state.bbox.rect.width = roundToMultiple(state.bbox.rect.height * state.bbox.aspectRatio.value, 8); - } - - if (updateAspectRatio || !state.bbox.aspectRatio.isLocked) { - state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height; - state.bbox.aspectRatio.id = 'Free'; - state.bbox.aspectRatio.isLocked = false; - } - - syncScaledSize(state); - }, - bboxAspectRatioLockToggled: (state) => { - state.bbox.aspectRatio.isLocked = !state.bbox.aspectRatio.isLocked; - }, - bboxAspectRatioIdChanged: (state, action: PayloadAction<{ id: AspectRatioID }>) => { - const { id } = action.payload; - state.bbox.aspectRatio.id = id; - if (id === 'Free') { - state.bbox.aspectRatio.isLocked = false; - } else { - state.bbox.aspectRatio.isLocked = true; - state.bbox.aspectRatio.value = ASPECT_RATIO_MAP[id].ratio; - const { width, height } = calculateNewSize( - state.bbox.aspectRatio.value, - state.bbox.rect.width * state.bbox.rect.height - ); - state.bbox.rect.width = width; - state.bbox.rect.height = height; - } - - syncScaledSize(state); - }, - bboxDimensionsSwapped: (state) => { - state.bbox.aspectRatio.value = 1 / state.bbox.aspectRatio.value; - if (state.bbox.aspectRatio.id === 'Free') { - const newWidth = state.bbox.rect.height; - const newHeight = state.bbox.rect.width; - state.bbox.rect.width = newWidth; - state.bbox.rect.height = newHeight; - } else { - const { width, height } = calculateNewSize( - state.bbox.aspectRatio.value, - state.bbox.rect.width * state.bbox.rect.height - ); - state.bbox.rect.width = width; - state.bbox.rect.height = height; - state.bbox.aspectRatio.id = ASPECT_RATIO_MAP[state.bbox.aspectRatio.id].inverseID; - } - - syncScaledSize(state); - }, - bboxSizeOptimized: (state) => { - if (state.bbox.aspectRatio.isLocked) { - const { width, height } = calculateNewSize(state.bbox.aspectRatio.value, state.bbox.optimalDimension ** 2); - state.bbox.rect.width = width; - state.bbox.rect.height = height; - } else { - state.bbox.aspectRatio = deepClone(initialAspectRatioState); - state.bbox.rect.width = state.bbox.optimalDimension; - state.bbox.rect.height = state.bbox.optimalDimension; - } - - syncScaledSize(state); - }, -} satisfies SliceCaseReducers; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts index 189f6e32c3..9a466c70ea 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts @@ -3,34 +3,83 @@ import { createSlice } from '@reduxjs/toolkit'; import type { PersistConfig } from 'app/store/store'; import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/util/arrayUtils'; import { deepClone } from 'common/util/deepClone'; +import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple'; import { getPrefixedId } from 'features/controlLayers/konva/util'; -import { bboxReducers } from 'features/controlLayers/store/bboxReducers'; -import { controlLayersReducers } from 'features/controlLayers/store/controlLayersReducers'; -import { inpaintMaskReducers } from 'features/controlLayers/store/inpaintMaskReducers'; -import { ipAdaptersReducers } from 'features/controlLayers/store/ipAdaptersReducers'; import { modelChanged } from 'features/controlLayers/store/paramsSlice'; -import { rasterLayersReducers } from 'features/controlLayers/store/rasterLayersReducers'; -import { regionsReducers } from 'features/controlLayers/store/regionsReducers'; -import { selectAllEntities, selectAllEntitiesOfType, selectEntity } from 'features/controlLayers/store/selectors'; +import { + selectAllEntities, + selectAllEntitiesOfType, + selectEntity, + selectRegionalGuidanceIPAdapter, +} from 'features/controlLayers/store/selectors'; +import type { + CanvasInpaintMaskState, + FillStyle, + RegionalGuidanceIPAdapterConfig, + RgbColor, +} from 'features/controlLayers/store/types'; import { getScaledBoundingBoxDimensions } from 'features/controlLayers/util/getScaledBoundingBoxDimensions'; import { simplifyFlatNumbersArray } from 'features/controlLayers/util/simplify'; +import { zModelIdentifierField } from 'features/nodes/types/common'; import { calculateNewSize } from 'features/parameters/components/DocumentSize/calculateNewSize'; -import { initialAspectRatioState } from 'features/parameters/components/DocumentSize/constants'; +import { ASPECT_RATIO_MAP, initialAspectRatioState } from 'features/parameters/components/DocumentSize/constants'; +import type { AspectRatioID } from 'features/parameters/components/DocumentSize/types'; import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension'; -import { pick } from 'lodash-es'; +import type { IRect } from 'konva/lib/types'; +import { isEqual, merge, omit } from 'lodash-es'; +import type { ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types'; import { assert } from 'tsafe'; import type { + BoundingBoxScaleMethod, + CanvasControlLayerState, CanvasEntityIdentifier, + CanvasIPAdapterState, + CanvasRasterLayerState, + CanvasRegionalGuidanceState, CanvasState, + CLIPVisionModelV2, + ControlModeV2, + ControlNetConfig, + Dimensions, EntityBrushLineAddedPayload, EntityEraserLineAddedPayload, EntityIdentifierPayload, EntityMovedPayload, EntityRasterizedPayload, EntityRectAddedPayload, + IPMethodV2, + T2IAdapterConfig, } from './types'; -import { getEntityIdentifier, isDrawableEntity } from './types'; +import { + getEntityIdentifier, + imageDTOToImageWithDims, + initialControlNet, + initialIPAdapter, + isDrawableEntity, +} from './types'; + +const DEFAULT_MASK_COLORS: RgbColor[] = [ + { r: 121, g: 157, b: 219 }, // rgb(121, 157, 219) + { r: 131, g: 214, b: 131 }, // rgb(131, 214, 131) + { r: 250, g: 225, b: 80 }, // rgb(250, 225, 80) + { r: 220, g: 144, b: 101 }, // rgb(220, 144, 101) + { r: 224, g: 117, b: 117 }, // rgb(224, 117, 117) + { r: 213, g: 139, b: 202 }, // rgb(213, 139, 202) + { r: 161, g: 120, b: 214 }, // rgb(161, 120, 214) +]; + +const getRGMaskFill = (state: CanvasState): RgbColor => { + const lastFill = state.regions.entities.slice(-1)[0]?.fill.color; + let i = DEFAULT_MASK_COLORS.findIndex((c) => isEqual(c, lastFill)); + if (i === -1) { + i = 0; + } + i = (i + 1) % DEFAULT_MASK_COLORS.length; + const fill = DEFAULT_MASK_COLORS[i]; + assert(fill, 'This should never happen'); + return fill; +}; const initialState: CanvasState = { _version: 3, @@ -69,12 +118,667 @@ export const canvasSlice = createSlice({ initialState, reducers: { // undoable canvas state - ...rasterLayersReducers, - ...controlLayersReducers, - ...ipAdaptersReducers, - ...regionsReducers, - ...inpaintMaskReducers, - ...bboxReducers, + //#region Raster layers + rasterLayerAdded: { + reducer: ( + state, + action: PayloadAction<{ id: string; overrides?: Partial; isSelected?: boolean }> + ) => { + const { id, overrides, isSelected } = action.payload; + const entity: CanvasRasterLayerState = { + id, + name: null, + type: 'raster_layer', + isEnabled: true, + objects: [], + opacity: 1, + position: { x: 0, y: 0 }, + }; + merge(entity, overrides); + state.rasterLayers.entities.push(entity); + if (isSelected) { + state.selectedEntityIdentifier = getEntityIdentifier(entity); + } + }, + prepare: (payload: { overrides?: Partial; isSelected?: boolean }) => ({ + payload: { ...payload, id: getPrefixedId('raster_layer') }, + }), + }, + rasterLayerRecalled: (state, action: PayloadAction<{ data: CanvasRasterLayerState }>) => { + const { data } = action.payload; + state.rasterLayers.entities.push(data); + state.selectedEntityIdentifier = getEntityIdentifier(data); + }, + rasterLayerConvertedToControlLayer: { + reducer: (state, action: PayloadAction>) => { + const { entityIdentifier, newId } = action.payload; + const layer = selectEntity(state, entityIdentifier); + if (!layer) { + return; + } + + // Convert the raster layer to control layer + const controlLayerState: CanvasControlLayerState = { + ...deepClone(layer), + id: newId, + type: 'control_layer', + controlAdapter: deepClone(initialControlNet), + withTransparencyEffect: true, + }; + + // Remove the raster layer + state.rasterLayers.entities = state.rasterLayers.entities.filter((layer) => layer.id !== entityIdentifier.id); + + // Add the converted control layer + state.controlLayers.entities.push(controlLayerState); + + state.selectedEntityIdentifier = { type: controlLayerState.type, id: controlLayerState.id }; + }, + prepare: (payload: EntityIdentifierPayload) => ({ + payload: { ...payload, newId: getPrefixedId('control_layer') }, + }), + }, + //#region Control layers + controlLayerAdded: { + reducer: ( + state, + action: PayloadAction<{ id: string; overrides?: Partial; isSelected?: boolean }> + ) => { + const { id, overrides, isSelected } = action.payload; + const entity: CanvasControlLayerState = { + id, + name: null, + type: 'control_layer', + isEnabled: true, + withTransparencyEffect: true, + objects: [], + opacity: 1, + position: { x: 0, y: 0 }, + controlAdapter: deepClone(initialControlNet), + }; + merge(entity, overrides); + state.controlLayers.entities.push(entity); + if (isSelected) { + state.selectedEntityIdentifier = getEntityIdentifier(entity); + } + }, + prepare: (payload: { overrides?: Partial; isSelected?: boolean }) => ({ + payload: { ...payload, id: getPrefixedId('control_layer') }, + }), + }, + controlLayerRecalled: (state, action: PayloadAction<{ data: CanvasControlLayerState }>) => { + const { data } = action.payload; + state.controlLayers.entities.push(data); + state.selectedEntityIdentifier = { type: 'control_layer', id: data.id }; + }, + controlLayerConvertedToRasterLayer: { + reducer: (state, action: PayloadAction>) => { + const { entityIdentifier, newId } = action.payload; + const layer = selectEntity(state, entityIdentifier); + if (!layer) { + return; + } + + // Convert the raster layer to control layer + const rasterLayerState: CanvasRasterLayerState = { + ...omit(deepClone(layer), ['type', 'controlAdapter', 'withTransparencyEffect']), + id: newId, + type: 'raster_layer', + }; + + // Remove the control layer + state.controlLayers.entities = state.controlLayers.entities.filter((layer) => layer.id !== entityIdentifier.id); + + // Add the new raster layer + state.rasterLayers.entities.push(rasterLayerState); + + state.selectedEntityIdentifier = { type: rasterLayerState.type, id: rasterLayerState.id }; + }, + prepare: (payload: EntityIdentifierPayload) => ({ + payload: { ...payload, newId: getPrefixedId('raster_layer') }, + }), + }, + controlLayerModelChanged: ( + state, + action: PayloadAction< + EntityIdentifierPayload< + { + modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null; + }, + 'control_layer' + > + > + ) => { + const { entityIdentifier, modelConfig } = action.payload; + const layer = selectEntity(state, entityIdentifier); + if (!layer || !layer.controlAdapter) { + return; + } + if (!modelConfig) { + layer.controlAdapter.model = null; + return; + } + layer.controlAdapter.model = zModelIdentifierField.parse(modelConfig); + + // We may need to convert the CA to match the model + if (layer.controlAdapter.type === 't2i_adapter' && layer.controlAdapter.model.type === 'controlnet') { + // Converting from T2I Adapter to ControlNet - add `controlMode` + const controlNetConfig: ControlNetConfig = { + ...layer.controlAdapter, + type: 'controlnet', + controlMode: 'balanced', + }; + layer.controlAdapter = controlNetConfig; + } else if (layer.controlAdapter.type === 'controlnet' && layer.controlAdapter.model.type === 't2i_adapter') { + // Converting from ControlNet to T2I Adapter - remove `controlMode` + const { controlMode: _, ...rest } = layer.controlAdapter; + const t2iAdapterConfig: T2IAdapterConfig = { ...rest, type: 't2i_adapter' }; + layer.controlAdapter = t2iAdapterConfig; + } + }, + controlLayerControlModeChanged: ( + state, + action: PayloadAction> + ) => { + const { entityIdentifier, controlMode } = action.payload; + const layer = selectEntity(state, entityIdentifier); + if (!layer || !layer.controlAdapter || layer.controlAdapter.type !== 'controlnet') { + return; + } + layer.controlAdapter.controlMode = controlMode; + }, + controlLayerWeightChanged: ( + state, + action: PayloadAction> + ) => { + const { entityIdentifier, weight } = action.payload; + const layer = selectEntity(state, entityIdentifier); + if (!layer || !layer.controlAdapter) { + return; + } + layer.controlAdapter.weight = weight; + }, + controlLayerBeginEndStepPctChanged: ( + state, + action: PayloadAction> + ) => { + const { entityIdentifier, beginEndStepPct } = action.payload; + const layer = selectEntity(state, entityIdentifier); + if (!layer || !layer.controlAdapter) { + return; + } + layer.controlAdapter.beginEndStepPct = beginEndStepPct; + }, + controlLayerWithTransparencyEffectToggled: ( + state, + action: PayloadAction> + ) => { + const { entityIdentifier } = action.payload; + const layer = selectEntity(state, entityIdentifier); + if (!layer) { + return; + } + layer.withTransparencyEffect = !layer.withTransparencyEffect; + }, + //#region IP Adapters + ipaAdded: { + reducer: ( + state, + action: PayloadAction<{ id: string; overrides?: Partial; isSelected?: boolean }> + ) => { + const { id, overrides, isSelected } = action.payload; + const entity: CanvasIPAdapterState = { + id, + type: 'ip_adapter', + name: null, + isEnabled: true, + ipAdapter: deepClone(initialIPAdapter), + }; + merge(entity, overrides); + state.ipAdapters.entities.push(entity); + if (isSelected) { + state.selectedEntityIdentifier = getEntityIdentifier(entity); + } + }, + prepare: (payload?: { overrides?: Partial; isSelected?: boolean }) => ({ + payload: { ...payload, id: getPrefixedId('ip_adapter') }, + }), + }, + ipaRecalled: (state, action: PayloadAction<{ data: CanvasIPAdapterState }>) => { + const { data } = action.payload; + state.ipAdapters.entities.push(data); + state.selectedEntityIdentifier = { type: 'ip_adapter', id: data.id }; + }, + ipaImageChanged: ( + state, + action: PayloadAction> + ) => { + const { entityIdentifier, imageDTO } = action.payload; + const entity = selectEntity(state, entityIdentifier); + if (!entity) { + return; + } + entity.ipAdapter.image = imageDTO ? imageDTOToImageWithDims(imageDTO) : null; + }, + ipaMethodChanged: (state, action: PayloadAction>) => { + const { entityIdentifier, method } = action.payload; + const entity = selectEntity(state, entityIdentifier); + if (!entity) { + return; + } + entity.ipAdapter.method = method; + }, + ipaModelChanged: ( + state, + action: PayloadAction> + ) => { + const { entityIdentifier, modelConfig } = action.payload; + const entity = selectEntity(state, entityIdentifier); + if (!entity) { + return; + } + entity.ipAdapter.model = modelConfig ? zModelIdentifierField.parse(modelConfig) : null; + }, + ipaCLIPVisionModelChanged: ( + state, + action: PayloadAction> + ) => { + const { entityIdentifier, clipVisionModel } = action.payload; + const entity = selectEntity(state, entityIdentifier); + if (!entity) { + return; + } + entity.ipAdapter.clipVisionModel = clipVisionModel; + }, + ipaWeightChanged: (state, action: PayloadAction>) => { + const { entityIdentifier, weight } = action.payload; + const entity = selectEntity(state, entityIdentifier); + if (!entity) { + return; + } + entity.ipAdapter.weight = weight; + }, + ipaBeginEndStepPctChanged: ( + state, + action: PayloadAction> + ) => { + const { entityIdentifier, beginEndStepPct } = action.payload; + const entity = selectEntity(state, entityIdentifier); + if (!entity) { + return; + } + entity.ipAdapter.beginEndStepPct = beginEndStepPct; + }, + //#region Regional Guidance + rgAdded: { + reducer: ( + state, + action: PayloadAction<{ id: string; overrides?: Partial; isSelected?: boolean }> + ) => { + const { id, overrides, isSelected } = action.payload; + const entity: CanvasRegionalGuidanceState = { + id, + name: null, + type: 'regional_guidance', + isEnabled: true, + objects: [], + fill: { + style: 'solid', + color: getRGMaskFill(state), + }, + opacity: 0.5, + position: { x: 0, y: 0 }, + autoNegative: true, + positivePrompt: '', + negativePrompt: null, + ipAdapters: [], + }; + merge(entity, overrides); + state.regions.entities.push(entity); + if (isSelected) { + state.selectedEntityIdentifier = getEntityIdentifier(entity); + } + }, + prepare: (payload?: { overrides?: Partial; isSelected?: boolean }) => ({ + payload: { ...payload, id: getPrefixedId('regional_guidance') }, + }), + }, + rgRecalled: (state, action: PayloadAction<{ data: CanvasRegionalGuidanceState }>) => { + const { data } = action.payload; + state.regions.entities.push(data); + state.selectedEntityIdentifier = { type: 'regional_guidance', id: data.id }; + }, + rgPositivePromptChanged: ( + state, + action: PayloadAction> + ) => { + const { entityIdentifier, prompt } = action.payload; + const entity = selectEntity(state, entityIdentifier); + if (!entity) { + return; + } + entity.positivePrompt = prompt; + }, + rgNegativePromptChanged: ( + state, + action: PayloadAction> + ) => { + const { entityIdentifier, prompt } = action.payload; + const entity = selectEntity(state, entityIdentifier); + if (!entity) { + return; + } + entity.negativePrompt = prompt; + }, + rgFillColorChanged: ( + state, + action: PayloadAction> + ) => { + const { entityIdentifier, color } = action.payload; + const entity = selectEntity(state, entityIdentifier); + if (!entity) { + return; + } + entity.fill.color = color; + }, + rgFillStyleChanged: ( + state, + action: PayloadAction> + ) => { + const { entityIdentifier, style } = action.payload; + const entity = selectEntity(state, entityIdentifier); + if (!entity) { + return; + } + entity.fill.style = style; + }, + + rgAutoNegativeToggled: (state, action: PayloadAction>) => { + const { entityIdentifier } = action.payload; + const rg = selectEntity(state, entityIdentifier); + if (!rg) { + return; + } + rg.autoNegative = !rg.autoNegative; + }, + rgIPAdapterAdded: { + reducer: ( + state, + action: PayloadAction< + EntityIdentifierPayload< + { ipAdapterId: string; overrides?: Partial }, + 'regional_guidance' + > + > + ) => { + const { entityIdentifier, overrides, ipAdapterId } = action.payload; + const entity = selectEntity(state, entityIdentifier); + if (!entity) { + return; + } + const ipAdapter = { ...deepClone(initialIPAdapter), id: ipAdapterId }; + merge(ipAdapter, overrides); + entity.ipAdapters.push(ipAdapter); + }, + prepare: ( + payload: EntityIdentifierPayload<{ overrides?: Partial }, 'regional_guidance'> + ) => ({ + payload: { ...payload, ipAdapterId: getPrefixedId('regional_guidance_ip_adapter') }, + }), + }, + rgIPAdapterDeleted: ( + state, + action: PayloadAction> + ) => { + const { entityIdentifier, ipAdapterId } = action.payload; + const entity = selectEntity(state, entityIdentifier); + if (!entity) { + return; + } + entity.ipAdapters = entity.ipAdapters.filter((ipAdapter) => ipAdapter.id !== ipAdapterId); + }, + rgIPAdapterImageChanged: ( + state, + action: PayloadAction< + EntityIdentifierPayload<{ ipAdapterId: string; imageDTO: ImageDTO | null }, 'regional_guidance'> + > + ) => { + const { entityIdentifier, ipAdapterId, imageDTO } = action.payload; + const ipAdapter = selectRegionalGuidanceIPAdapter(state, entityIdentifier, ipAdapterId); + if (!ipAdapter) { + return; + } + ipAdapter.image = imageDTO ? imageDTOToImageWithDims(imageDTO) : null; + }, + rgIPAdapterWeightChanged: ( + state, + action: PayloadAction> + ) => { + const { entityIdentifier, ipAdapterId, weight } = action.payload; + const ipAdapter = selectRegionalGuidanceIPAdapter(state, entityIdentifier, ipAdapterId); + if (!ipAdapter) { + return; + } + ipAdapter.weight = weight; + }, + rgIPAdapterBeginEndStepPctChanged: ( + state, + action: PayloadAction< + EntityIdentifierPayload<{ ipAdapterId: string; beginEndStepPct: [number, number] }, 'regional_guidance'> + > + ) => { + const { entityIdentifier, ipAdapterId, beginEndStepPct } = action.payload; + const ipAdapter = selectRegionalGuidanceIPAdapter(state, entityIdentifier, ipAdapterId); + if (!ipAdapter) { + return; + } + ipAdapter.beginEndStepPct = beginEndStepPct; + }, + rgIPAdapterMethodChanged: ( + state, + action: PayloadAction> + ) => { + const { entityIdentifier, ipAdapterId, method } = action.payload; + const ipAdapter = selectRegionalGuidanceIPAdapter(state, entityIdentifier, ipAdapterId); + if (!ipAdapter) { + return; + } + ipAdapter.method = method; + }, + rgIPAdapterModelChanged: ( + state, + action: PayloadAction< + EntityIdentifierPayload< + { + ipAdapterId: string; + modelConfig: IPAdapterModelConfig | null; + }, + 'regional_guidance' + > + > + ) => { + const { entityIdentifier, ipAdapterId, modelConfig } = action.payload; + const ipAdapter = selectRegionalGuidanceIPAdapter(state, entityIdentifier, ipAdapterId); + if (!ipAdapter) { + return; + } + ipAdapter.model = modelConfig ? zModelIdentifierField.parse(modelConfig) : null; + }, + rgIPAdapterCLIPVisionModelChanged: ( + state, + action: PayloadAction< + EntityIdentifierPayload<{ ipAdapterId: string; clipVisionModel: CLIPVisionModelV2 }, 'regional_guidance'> + > + ) => { + const { entityIdentifier, ipAdapterId, clipVisionModel } = action.payload; + const ipAdapter = selectRegionalGuidanceIPAdapter(state, entityIdentifier, ipAdapterId); + if (!ipAdapter) { + return; + } + ipAdapter.clipVisionModel = clipVisionModel; + }, + //#region Inpaint mask + inpaintMaskAdded: { + reducer: ( + state, + action: PayloadAction<{ id: string; overrides?: Partial; isSelected?: boolean }> + ) => { + const { id, overrides, isSelected } = action.payload; + const entity: CanvasInpaintMaskState = { + id, + name: null, + type: 'inpaint_mask', + isEnabled: true, + objects: [], + opacity: 1, + position: { x: 0, y: 0 }, + fill: { + style: 'diagonal', + color: { r: 255, g: 122, b: 0 }, // some orange color + }, + }; + merge(entity, overrides); + state.inpaintMasks.entities.push(entity); + if (isSelected) { + state.selectedEntityIdentifier = getEntityIdentifier(entity); + } + }, + prepare: (payload?: { overrides?: Partial; isSelected?: boolean }) => ({ + payload: { ...payload, id: getPrefixedId('inpaint_mask') }, + }), + }, + inpaintMaskRecalled: (state, action: PayloadAction<{ data: CanvasInpaintMaskState }>) => { + const { data } = action.payload; + state.inpaintMasks.entities = [data]; + state.selectedEntityIdentifier = { type: 'inpaint_mask', id: data.id }; + }, + inpaintMaskFillColorChanged: ( + state, + action: PayloadAction> + ) => { + const { color, entityIdentifier } = action.payload; + const entity = selectEntity(state, entityIdentifier); + if (!entity) { + return; + } + entity.fill.color = color; + }, + inpaintMaskFillStyleChanged: ( + state, + action: PayloadAction> + ) => { + const { style, entityIdentifier } = action.payload; + const entity = selectEntity(state, entityIdentifier); + if (!entity) { + return; + } + entity.fill.style = style; + }, + //#region BBox + bboxScaledSizeChanged: (state, action: PayloadAction>) => { + state.bbox.scaledSize = { ...state.bbox.scaledSize, ...action.payload }; + }, + bboxScaleMethodChanged: (state, action: PayloadAction) => { + state.bbox.scaleMethod = action.payload; + syncScaledSize(state); + }, + bboxChanged: (state, action: PayloadAction) => { + state.bbox.rect = action.payload; + syncScaledSize(state); + }, + bboxWidthChanged: ( + state, + action: PayloadAction<{ width: number; updateAspectRatio?: boolean; clamp?: boolean }> + ) => { + const { width, updateAspectRatio, clamp } = action.payload; + state.bbox.rect.width = clamp ? Math.max(roundDownToMultiple(width, 8), 64) : width; + + if (state.bbox.aspectRatio.isLocked) { + state.bbox.rect.height = roundToMultiple(state.bbox.rect.width / state.bbox.aspectRatio.value, 8); + } + + if (updateAspectRatio || !state.bbox.aspectRatio.isLocked) { + state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height; + state.bbox.aspectRatio.id = 'Free'; + state.bbox.aspectRatio.isLocked = false; + } + + syncScaledSize(state); + }, + bboxHeightChanged: ( + state, + action: PayloadAction<{ height: number; updateAspectRatio?: boolean; clamp?: boolean }> + ) => { + const { height, updateAspectRatio, clamp } = action.payload; + + state.bbox.rect.height = clamp ? Math.max(roundDownToMultiple(height, 8), 64) : height; + + if (state.bbox.aspectRatio.isLocked) { + state.bbox.rect.width = roundToMultiple(state.bbox.rect.height * state.bbox.aspectRatio.value, 8); + } + + if (updateAspectRatio || !state.bbox.aspectRatio.isLocked) { + state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height; + state.bbox.aspectRatio.id = 'Free'; + state.bbox.aspectRatio.isLocked = false; + } + + syncScaledSize(state); + }, + bboxAspectRatioLockToggled: (state) => { + state.bbox.aspectRatio.isLocked = !state.bbox.aspectRatio.isLocked; + }, + bboxAspectRatioIdChanged: (state, action: PayloadAction<{ id: AspectRatioID }>) => { + const { id } = action.payload; + state.bbox.aspectRatio.id = id; + if (id === 'Free') { + state.bbox.aspectRatio.isLocked = false; + } else { + state.bbox.aspectRatio.isLocked = true; + state.bbox.aspectRatio.value = ASPECT_RATIO_MAP[id].ratio; + const { width, height } = calculateNewSize( + state.bbox.aspectRatio.value, + state.bbox.rect.width * state.bbox.rect.height + ); + state.bbox.rect.width = width; + state.bbox.rect.height = height; + } + + syncScaledSize(state); + }, + bboxDimensionsSwapped: (state) => { + state.bbox.aspectRatio.value = 1 / state.bbox.aspectRatio.value; + if (state.bbox.aspectRatio.id === 'Free') { + const newWidth = state.bbox.rect.height; + const newHeight = state.bbox.rect.width; + state.bbox.rect.width = newWidth; + state.bbox.rect.height = newHeight; + } else { + const { width, height } = calculateNewSize( + state.bbox.aspectRatio.value, + state.bbox.rect.width * state.bbox.rect.height + ); + state.bbox.rect.width = width; + state.bbox.rect.height = height; + state.bbox.aspectRatio.id = ASPECT_RATIO_MAP[state.bbox.aspectRatio.id].inverseID; + } + + syncScaledSize(state); + }, + bboxSizeOptimized: (state) => { + if (state.bbox.aspectRatio.isLocked) { + const { width, height } = calculateNewSize(state.bbox.aspectRatio.value, state.bbox.optimalDimension ** 2); + state.bbox.rect.width = width; + state.bbox.rect.height = height; + } else { + state.bbox.aspectRatio = deepClone(initialAspectRatioState); + state.bbox.rect.width = state.bbox.optimalDimension; + state.bbox.rect.height = state.bbox.optimalDimension; + } + + syncScaledSize(state); + }, + //#region Shared entity entitySelected: (state, action: PayloadAction) => { const { entityIdentifier } = action.payload; state.selectedEntityIdentifier = entityIdentifier; @@ -322,19 +1026,11 @@ export const canvasSlice = createSlice({ state.selectedEntityIdentifier = deepClone(initialState.selectedEntityIdentifier); }, canvasReset: (state) => { - state.bbox = deepClone(initialState.bbox); - state.bbox.rect.width = state.bbox.optimalDimension; - state.bbox.rect.height = state.bbox.optimalDimension; - const size = pick(state.bbox.rect, 'width', 'height'); - state.bbox.scaledSize = getScaledBoundingBoxDimensions(size, state.bbox.optimalDimension); - - state.ipAdapters = deepClone(initialState.ipAdapters); - state.rasterLayers = deepClone(initialState.rasterLayers); - state.controlLayers = deepClone(initialState.controlLayers); - state.regions = deepClone(initialState.regions); - state.inpaintMasks = deepClone(initialState.inpaintMasks); - - state.selectedEntityIdentifier = deepClone(initialState.selectedEntityIdentifier); + const { width, height } = state.bbox.rect; + const scaledSize = getScaledBoundingBoxDimensions({ width, height }, state.bbox.optimalDimension); + const newState = deepClone(initialState); + newState.bbox.scaledSize = scaledSize; + return newState; }, }, extraReducers(builder) { @@ -451,3 +1147,10 @@ export const canvasPersistConfig: PersistConfig = { migrate, persistDenylist: [], }; + +const syncScaledSize = (state: CanvasState) => { + if (state.bbox.scaleMethod === 'auto') { + const { width, height } = state.bbox.rect; + state.bbox.scaledSize = getScaledBoundingBoxDimensions({ width, height }, state.bbox.optimalDimension); + } +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/controlLayersReducers.ts b/invokeai/frontend/web/src/features/controlLayers/store/controlLayersReducers.ts deleted file mode 100644 index 3563f7e546..0000000000 --- a/invokeai/frontend/web/src/features/controlLayers/store/controlLayersReducers.ts +++ /dev/null @@ -1,162 +0,0 @@ -import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit'; -import { deepClone } from 'common/util/deepClone'; -import { getPrefixedId } from 'features/controlLayers/konva/util'; -import { selectEntity } from 'features/controlLayers/store/selectors'; -import { zModelIdentifierField } from 'features/nodes/types/common'; -import { merge, omit } from 'lodash-es'; -import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types'; - -import type { - CanvasControlLayerState, - CanvasRasterLayerState, - CanvasState, - ControlModeV2, - ControlNetConfig, - EntityIdentifierPayload, - T2IAdapterConfig, -} from './types'; -import { getEntityIdentifier, initialControlNet } from './types'; - -export const controlLayersReducers = { - controlLayerAdded: { - reducer: ( - state, - action: PayloadAction<{ id: string; overrides?: Partial; isSelected?: boolean }> - ) => { - const { id, overrides, isSelected } = action.payload; - const entity: CanvasControlLayerState = { - id, - name: null, - type: 'control_layer', - isEnabled: true, - withTransparencyEffect: true, - objects: [], - opacity: 1, - position: { x: 0, y: 0 }, - controlAdapter: deepClone(initialControlNet), - }; - merge(entity, overrides); - state.controlLayers.entities.push(entity); - if (isSelected) { - state.selectedEntityIdentifier = getEntityIdentifier(entity); - } - }, - prepare: (payload: { overrides?: Partial; isSelected?: boolean }) => ({ - payload: { ...payload, id: getPrefixedId('control_layer') }, - }), - }, - controlLayerRecalled: (state, action: PayloadAction<{ data: CanvasControlLayerState }>) => { - const { data } = action.payload; - state.controlLayers.entities.push(data); - state.selectedEntityIdentifier = { type: 'control_layer', id: data.id }; - }, - controlLayerConvertedToRasterLayer: { - reducer: (state, action: PayloadAction>) => { - const { entityIdentifier, newId } = action.payload; - const layer = selectEntity(state, entityIdentifier); - if (!layer) { - return; - } - - // Convert the raster layer to control layer - const rasterLayerState: CanvasRasterLayerState = { - ...omit(deepClone(layer), ['type', 'controlAdapter', 'withTransparencyEffect']), - id: newId, - type: 'raster_layer', - }; - - // Remove the control layer - state.controlLayers.entities = state.controlLayers.entities.filter((layer) => layer.id !== entityIdentifier.id); - - // Add the new raster layer - state.rasterLayers.entities.push(rasterLayerState); - - state.selectedEntityIdentifier = { type: rasterLayerState.type, id: rasterLayerState.id }; - }, - prepare: (payload: EntityIdentifierPayload) => ({ - payload: { ...payload, newId: getPrefixedId('raster_layer') }, - }), - }, - controlLayerModelChanged: ( - state, - action: PayloadAction< - EntityIdentifierPayload< - { - modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null; - }, - 'control_layer' - > - > - ) => { - const { entityIdentifier, modelConfig } = action.payload; - const layer = selectEntity(state, entityIdentifier); - if (!layer || !layer.controlAdapter) { - return; - } - if (!modelConfig) { - layer.controlAdapter.model = null; - return; - } - layer.controlAdapter.model = zModelIdentifierField.parse(modelConfig); - - // We may need to convert the CA to match the model - if (layer.controlAdapter.type === 't2i_adapter' && layer.controlAdapter.model.type === 'controlnet') { - // Converting from T2I Adapter to ControlNet - add `controlMode` - const controlNetConfig: ControlNetConfig = { - ...layer.controlAdapter, - type: 'controlnet', - controlMode: 'balanced', - }; - layer.controlAdapter = controlNetConfig; - } else if (layer.controlAdapter.type === 'controlnet' && layer.controlAdapter.model.type === 't2i_adapter') { - // Converting from ControlNet to T2I Adapter - remove `controlMode` - const { controlMode: _, ...rest } = layer.controlAdapter; - const t2iAdapterConfig: T2IAdapterConfig = { ...rest, type: 't2i_adapter' }; - layer.controlAdapter = t2iAdapterConfig; - } - }, - controlLayerControlModeChanged: ( - state, - action: PayloadAction> - ) => { - const { entityIdentifier, controlMode } = action.payload; - const layer = selectEntity(state, entityIdentifier); - if (!layer || !layer.controlAdapter || layer.controlAdapter.type !== 'controlnet') { - return; - } - layer.controlAdapter.controlMode = controlMode; - }, - controlLayerWeightChanged: ( - state, - action: PayloadAction> - ) => { - const { entityIdentifier, weight } = action.payload; - const layer = selectEntity(state, entityIdentifier); - if (!layer || !layer.controlAdapter) { - return; - } - layer.controlAdapter.weight = weight; - }, - controlLayerBeginEndStepPctChanged: ( - state, - action: PayloadAction> - ) => { - const { entityIdentifier, beginEndStepPct } = action.payload; - const layer = selectEntity(state, entityIdentifier); - if (!layer || !layer.controlAdapter) { - return; - } - layer.controlAdapter.beginEndStepPct = beginEndStepPct; - }, - controlLayerWithTransparencyEffectToggled: ( - state, - action: PayloadAction> - ) => { - const { entityIdentifier } = action.payload; - const layer = selectEntity(state, entityIdentifier); - if (!layer) { - return; - } - layer.withTransparencyEffect = !layer.withTransparencyEffect; - }, -} satisfies SliceCaseReducers; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/inpaintMaskReducers.ts b/invokeai/frontend/web/src/features/controlLayers/store/inpaintMaskReducers.ts deleted file mode 100644 index 6004ccfc36..0000000000 --- a/invokeai/frontend/web/src/features/controlLayers/store/inpaintMaskReducers.ts +++ /dev/null @@ -1,71 +0,0 @@ -import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit'; -import { getPrefixedId } from 'features/controlLayers/konva/util'; -import { selectEntity } from 'features/controlLayers/store/selectors'; -import type { - CanvasInpaintMaskState, - CanvasState, - EntityIdentifierPayload, - FillStyle, - RgbColor, -} from 'features/controlLayers/store/types'; -import { getEntityIdentifier } from 'features/controlLayers/store/types'; -import { merge } from 'lodash-es'; - -export const inpaintMaskReducers = { - inpaintMaskAdded: { - reducer: ( - state, - action: PayloadAction<{ id: string; overrides?: Partial; isSelected?: boolean }> - ) => { - const { id, overrides, isSelected } = action.payload; - const entity: CanvasInpaintMaskState = { - id, - name: null, - type: 'inpaint_mask', - isEnabled: true, - objects: [], - opacity: 1, - position: { x: 0, y: 0 }, - fill: { - style: 'diagonal', - color: { r: 255, g: 122, b: 0 }, // some orange color - }, - }; - merge(entity, overrides); - state.inpaintMasks.entities.push(entity); - if (isSelected) { - state.selectedEntityIdentifier = getEntityIdentifier(entity); - } - }, - prepare: (payload?: { overrides?: Partial; isSelected?: boolean }) => ({ - payload: { ...payload, id: getPrefixedId('inpaint_mask') }, - }), - }, - inpaintMaskRecalled: (state, action: PayloadAction<{ data: CanvasInpaintMaskState }>) => { - const { data } = action.payload; - state.inpaintMasks.entities = [data]; - state.selectedEntityIdentifier = { type: 'inpaint_mask', id: data.id }; - }, - inpaintMaskFillColorChanged: ( - state, - action: PayloadAction> - ) => { - const { color, entityIdentifier } = action.payload; - const entity = selectEntity(state, entityIdentifier); - if (!entity) { - return; - } - entity.fill.color = color; - }, - inpaintMaskFillStyleChanged: ( - state, - action: PayloadAction> - ) => { - const { style, entityIdentifier } = action.payload; - const entity = selectEntity(state, entityIdentifier); - if (!entity) { - return; - } - entity.fill.style = style; - }, -} satisfies SliceCaseReducers; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/ipAdaptersReducers.ts b/invokeai/frontend/web/src/features/controlLayers/store/ipAdaptersReducers.ts deleted file mode 100644 index fcadb3480e..0000000000 --- a/invokeai/frontend/web/src/features/controlLayers/store/ipAdaptersReducers.ts +++ /dev/null @@ -1,107 +0,0 @@ -import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit'; -import { deepClone } from 'common/util/deepClone'; -import { getPrefixedId } from 'features/controlLayers/konva/util'; -import { selectEntity } from 'features/controlLayers/store/selectors'; -import { zModelIdentifierField } from 'features/nodes/types/common'; -import { merge } from 'lodash-es'; -import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types'; - -import type { - CanvasIPAdapterState, - CanvasState, - CLIPVisionModelV2, - EntityIdentifierPayload, - IPMethodV2, -} from './types'; -import { getEntityIdentifier, imageDTOToImageWithDims, initialIPAdapter } from './types'; - -export const ipAdaptersReducers = { - ipaAdded: { - reducer: ( - state, - action: PayloadAction<{ id: string; overrides?: Partial; isSelected?: boolean }> - ) => { - const { id, overrides, isSelected } = action.payload; - const entity: CanvasIPAdapterState = { - id, - type: 'ip_adapter', - name: null, - isEnabled: true, - ipAdapter: deepClone(initialIPAdapter), - }; - merge(entity, overrides); - state.ipAdapters.entities.push(entity); - if (isSelected) { - state.selectedEntityIdentifier = getEntityIdentifier(entity); - } - }, - prepare: (payload?: { overrides?: Partial; isSelected?: boolean }) => ({ - payload: { ...payload, id: getPrefixedId('ip_adapter') }, - }), - }, - ipaRecalled: (state, action: PayloadAction<{ data: CanvasIPAdapterState }>) => { - const { data } = action.payload; - state.ipAdapters.entities.push(data); - state.selectedEntityIdentifier = { type: 'ip_adapter', id: data.id }; - }, - ipaImageChanged: ( - state, - action: PayloadAction> - ) => { - const { entityIdentifier, imageDTO } = action.payload; - const entity = selectEntity(state, entityIdentifier); - if (!entity) { - return; - } - entity.ipAdapter.image = imageDTO ? imageDTOToImageWithDims(imageDTO) : null; - }, - ipaMethodChanged: (state, action: PayloadAction>) => { - const { entityIdentifier, method } = action.payload; - const entity = selectEntity(state, entityIdentifier); - if (!entity) { - return; - } - entity.ipAdapter.method = method; - }, - ipaModelChanged: ( - state, - action: PayloadAction> - ) => { - const { entityIdentifier, modelConfig } = action.payload; - const entity = selectEntity(state, entityIdentifier); - if (!entity) { - return; - } - entity.ipAdapter.model = modelConfig ? zModelIdentifierField.parse(modelConfig) : null; - }, - ipaCLIPVisionModelChanged: ( - state, - action: PayloadAction> - ) => { - const { entityIdentifier, clipVisionModel } = action.payload; - const entity = selectEntity(state, entityIdentifier); - if (!entity) { - return; - } - entity.ipAdapter.clipVisionModel = clipVisionModel; - }, - ipaWeightChanged: (state, action: PayloadAction>) => { - const { entityIdentifier, weight } = action.payload; - const entity = selectEntity(state, entityIdentifier); - if (!entity) { - return; - } - entity.ipAdapter.weight = weight; - }, - ipaBeginEndStepPctChanged: ( - state, - action: PayloadAction> - ) => { - const { entityIdentifier, beginEndStepPct } = action.payload; - const entity = selectEntity(state, entityIdentifier); - if (!entity) { - return; - } - entity.ipAdapter.beginEndStepPct = beginEndStepPct; - }, -} satisfies SliceCaseReducers; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/rasterLayersReducers.ts b/invokeai/frontend/web/src/features/controlLayers/store/rasterLayersReducers.ts deleted file mode 100644 index 25ea31f9cc..0000000000 --- a/invokeai/frontend/web/src/features/controlLayers/store/rasterLayersReducers.ts +++ /dev/null @@ -1,70 +0,0 @@ -import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit'; -import { deepClone } from 'common/util/deepClone'; -import { getPrefixedId } from 'features/controlLayers/konva/util'; -import { selectEntity } from 'features/controlLayers/store/selectors'; -import { merge } from 'lodash-es'; - -import type { CanvasControlLayerState, CanvasRasterLayerState, CanvasState, EntityIdentifierPayload } from './types'; -import { getEntityIdentifier, initialControlNet } from './types'; - -export const rasterLayersReducers = { - rasterLayerAdded: { - reducer: ( - state, - action: PayloadAction<{ id: string; overrides?: Partial; isSelected?: boolean }> - ) => { - const { id, overrides, isSelected } = action.payload; - const entity: CanvasRasterLayerState = { - id, - name: null, - type: 'raster_layer', - isEnabled: true, - objects: [], - opacity: 1, - position: { x: 0, y: 0 }, - }; - merge(entity, overrides); - state.rasterLayers.entities.push(entity); - if (isSelected) { - state.selectedEntityIdentifier = getEntityIdentifier(entity); - } - }, - prepare: (payload: { overrides?: Partial; isSelected?: boolean }) => ({ - payload: { ...payload, id: getPrefixedId('raster_layer') }, - }), - }, - rasterLayerRecalled: (state, action: PayloadAction<{ data: CanvasRasterLayerState }>) => { - const { data } = action.payload; - state.rasterLayers.entities.push(data); - state.selectedEntityIdentifier = getEntityIdentifier(data); - }, - rasterLayerConvertedToControlLayer: { - reducer: (state, action: PayloadAction>) => { - const { entityIdentifier, newId } = action.payload; - const layer = selectEntity(state, entityIdentifier); - if (!layer) { - return; - } - - // Convert the raster layer to control layer - const controlLayerState: CanvasControlLayerState = { - ...deepClone(layer), - id: newId, - type: 'control_layer', - controlAdapter: deepClone(initialControlNet), - withTransparencyEffect: true, - }; - - // Remove the raster layer - state.rasterLayers.entities = state.rasterLayers.entities.filter((layer) => layer.id !== entityIdentifier.id); - - // Add the converted control layer - state.controlLayers.entities.push(controlLayerState); - - state.selectedEntityIdentifier = { type: controlLayerState.type, id: controlLayerState.id }; - }, - prepare: (payload: EntityIdentifierPayload) => ({ - payload: { ...payload, newId: getPrefixedId('control_layer') }, - }), - }, -} satisfies SliceCaseReducers; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/regionsReducers.ts b/invokeai/frontend/web/src/features/controlLayers/store/regionsReducers.ts deleted file mode 100644 index 87687fb395..0000000000 --- a/invokeai/frontend/web/src/features/controlLayers/store/regionsReducers.ts +++ /dev/null @@ -1,252 +0,0 @@ -import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit'; -import { deepClone } from 'common/util/deepClone'; -import { getPrefixedId } from 'features/controlLayers/konva/util'; -import { selectEntity, selectRegionalGuidanceIPAdapter } from 'features/controlLayers/store/selectors'; -import type { - CanvasState, - CLIPVisionModelV2, - EntityIdentifierPayload, - FillStyle, - IPMethodV2, - RegionalGuidanceIPAdapterConfig, - RgbColor, -} from 'features/controlLayers/store/types'; -import { getEntityIdentifier, imageDTOToImageWithDims, initialIPAdapter } from 'features/controlLayers/store/types'; -import { zModelIdentifierField } from 'features/nodes/types/common'; -import { isEqual, merge } from 'lodash-es'; -import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types'; -import { assert } from 'tsafe'; - -import type { CanvasRegionalGuidanceState } from './types'; - -const DEFAULT_MASK_COLORS: RgbColor[] = [ - { r: 121, g: 157, b: 219 }, // rgb(121, 157, 219) - { r: 131, g: 214, b: 131 }, // rgb(131, 214, 131) - { r: 250, g: 225, b: 80 }, // rgb(250, 225, 80) - { r: 220, g: 144, b: 101 }, // rgb(220, 144, 101) - { r: 224, g: 117, b: 117 }, // rgb(224, 117, 117) - { r: 213, g: 139, b: 202 }, // rgb(213, 139, 202) - { r: 161, g: 120, b: 214 }, // rgb(161, 120, 214) -]; - -const getRGMaskFill = (state: CanvasState): RgbColor => { - const lastFill = state.regions.entities.slice(-1)[0]?.fill.color; - let i = DEFAULT_MASK_COLORS.findIndex((c) => isEqual(c, lastFill)); - if (i === -1) { - i = 0; - } - i = (i + 1) % DEFAULT_MASK_COLORS.length; - const fill = DEFAULT_MASK_COLORS[i]; - assert(fill, 'This should never happen'); - return fill; -}; - -export const regionsReducers = { - rgAdded: { - reducer: ( - state, - action: PayloadAction<{ id: string; overrides?: Partial; isSelected?: boolean }> - ) => { - const { id, overrides, isSelected } = action.payload; - const entity: CanvasRegionalGuidanceState = { - id, - name: null, - type: 'regional_guidance', - isEnabled: true, - objects: [], - fill: { - style: 'solid', - color: getRGMaskFill(state), - }, - opacity: 0.5, - position: { x: 0, y: 0 }, - autoNegative: true, - positivePrompt: '', - negativePrompt: null, - ipAdapters: [], - }; - merge(entity, overrides); - state.regions.entities.push(entity); - if (isSelected) { - state.selectedEntityIdentifier = getEntityIdentifier(entity); - } - }, - prepare: (payload?: { overrides?: Partial; isSelected?: boolean }) => ({ - payload: { ...payload, id: getPrefixedId('regional_guidance') }, - }), - }, - rgRecalled: (state, action: PayloadAction<{ data: CanvasRegionalGuidanceState }>) => { - const { data } = action.payload; - state.regions.entities.push(data); - state.selectedEntityIdentifier = { type: 'regional_guidance', id: data.id }; - }, - rgPositivePromptChanged: ( - state, - action: PayloadAction> - ) => { - const { entityIdentifier, prompt } = action.payload; - const entity = selectEntity(state, entityIdentifier); - if (!entity) { - return; - } - entity.positivePrompt = prompt; - }, - rgNegativePromptChanged: ( - state, - action: PayloadAction> - ) => { - const { entityIdentifier, prompt } = action.payload; - const entity = selectEntity(state, entityIdentifier); - if (!entity) { - return; - } - entity.negativePrompt = prompt; - }, - rgFillColorChanged: ( - state, - action: PayloadAction> - ) => { - const { entityIdentifier, color } = action.payload; - const entity = selectEntity(state, entityIdentifier); - if (!entity) { - return; - } - entity.fill.color = color; - }, - rgFillStyleChanged: ( - state, - action: PayloadAction> - ) => { - const { entityIdentifier, style } = action.payload; - const entity = selectEntity(state, entityIdentifier); - if (!entity) { - return; - } - entity.fill.style = style; - }, - - rgAutoNegativeToggled: (state, action: PayloadAction>) => { - const { entityIdentifier } = action.payload; - const rg = selectEntity(state, entityIdentifier); - if (!rg) { - return; - } - rg.autoNegative = !rg.autoNegative; - }, - rgIPAdapterAdded: { - reducer: ( - state, - action: PayloadAction< - EntityIdentifierPayload< - { ipAdapterId: string; overrides?: Partial }, - 'regional_guidance' - > - > - ) => { - const { entityIdentifier, overrides, ipAdapterId } = action.payload; - const entity = selectEntity(state, entityIdentifier); - if (!entity) { - return; - } - const ipAdapter = { ...deepClone(initialIPAdapter), id: ipAdapterId }; - merge(ipAdapter, overrides); - entity.ipAdapters.push(ipAdapter); - }, - prepare: ( - payload: EntityIdentifierPayload<{ overrides?: Partial }, 'regional_guidance'> - ) => ({ - payload: { ...payload, ipAdapterId: getPrefixedId('regional_guidance_ip_adapter') }, - }), - }, - rgIPAdapterDeleted: ( - state, - action: PayloadAction> - ) => { - const { entityIdentifier, ipAdapterId } = action.payload; - const entity = selectEntity(state, entityIdentifier); - if (!entity) { - return; - } - entity.ipAdapters = entity.ipAdapters.filter((ipAdapter) => ipAdapter.id !== ipAdapterId); - }, - rgIPAdapterImageChanged: ( - state, - action: PayloadAction< - EntityIdentifierPayload<{ ipAdapterId: string; imageDTO: ImageDTO | null }, 'regional_guidance'> - > - ) => { - const { entityIdentifier, ipAdapterId, imageDTO } = action.payload; - const ipAdapter = selectRegionalGuidanceIPAdapter(state, entityIdentifier, ipAdapterId); - if (!ipAdapter) { - return; - } - ipAdapter.image = imageDTO ? imageDTOToImageWithDims(imageDTO) : null; - }, - rgIPAdapterWeightChanged: ( - state, - action: PayloadAction> - ) => { - const { entityIdentifier, ipAdapterId, weight } = action.payload; - const ipAdapter = selectRegionalGuidanceIPAdapter(state, entityIdentifier, ipAdapterId); - if (!ipAdapter) { - return; - } - ipAdapter.weight = weight; - }, - rgIPAdapterBeginEndStepPctChanged: ( - state, - action: PayloadAction< - EntityIdentifierPayload<{ ipAdapterId: string; beginEndStepPct: [number, number] }, 'regional_guidance'> - > - ) => { - const { entityIdentifier, ipAdapterId, beginEndStepPct } = action.payload; - const ipAdapter = selectRegionalGuidanceIPAdapter(state, entityIdentifier, ipAdapterId); - if (!ipAdapter) { - return; - } - ipAdapter.beginEndStepPct = beginEndStepPct; - }, - rgIPAdapterMethodChanged: ( - state, - action: PayloadAction> - ) => { - const { entityIdentifier, ipAdapterId, method } = action.payload; - const ipAdapter = selectRegionalGuidanceIPAdapter(state, entityIdentifier, ipAdapterId); - if (!ipAdapter) { - return; - } - ipAdapter.method = method; - }, - rgIPAdapterModelChanged: ( - state, - action: PayloadAction< - EntityIdentifierPayload< - { - ipAdapterId: string; - modelConfig: IPAdapterModelConfig | null; - }, - 'regional_guidance' - > - > - ) => { - const { entityIdentifier, ipAdapterId, modelConfig } = action.payload; - const ipAdapter = selectRegionalGuidanceIPAdapter(state, entityIdentifier, ipAdapterId); - if (!ipAdapter) { - return; - } - ipAdapter.model = modelConfig ? zModelIdentifierField.parse(modelConfig) : null; - }, - rgIPAdapterCLIPVisionModelChanged: ( - state, - action: PayloadAction< - EntityIdentifierPayload<{ ipAdapterId: string; clipVisionModel: CLIPVisionModelV2 }, 'regional_guidance'> - > - ) => { - const { entityIdentifier, ipAdapterId, clipVisionModel } = action.payload; - const ipAdapter = selectRegionalGuidanceIPAdapter(state, entityIdentifier, ipAdapterId); - if (!ipAdapter) { - return; - } - ipAdapter.clipVisionModel = clipVisionModel; - }, -} satisfies SliceCaseReducers;