From e55192ae2a006fb3c3dabb3d4dd2dd91de3c5015 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 16 Jun 2024 09:10:56 +1000 Subject: [PATCH] refactor(ui): add `adapterType` to ControlAdapterData --- .../components/ControlAdapter/CASettings.tsx | 2 +- .../store/controlAdaptersReducers.ts | 28 ++++++++----- .../src/features/controlLayers/store/types.ts | 42 ++++++++++++++----- 3 files changed, 50 insertions(+), 22 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/components/ControlAdapter/CASettings.tsx b/invokeai/frontend/web/src/features/controlLayers/components/ControlAdapter/CASettings.tsx index 292e456584..70cde437e7 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/ControlAdapter/CASettings.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/ControlAdapter/CASettings.tsx @@ -122,7 +122,7 @@ export const CASettings = memo(({ id }: Props) => { - {controlAdapter.controlMode && ( + {controlAdapter.adapterType === 'controlnet' && ( )} diff --git a/invokeai/frontend/web/src/features/controlLayers/store/controlAdaptersReducers.ts b/invokeai/frontend/web/src/features/controlLayers/store/controlAdaptersReducers.ts index 6cda64871f..84c12e5f3a 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/controlAdaptersReducers.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/controlAdaptersReducers.ts @@ -9,11 +9,14 @@ import { v4 as uuidv4 } from 'uuid'; import type { CanvasV2State, - ControlAdapterConfig, ControlAdapterData, ControlModeV2, + ControlNetConfig, + ControlNetData, Filter, ProcessorConfig, + T2IAdapterConfig, + T2IAdapterData, } from './types'; import { buildControlAdapterProcessorV2, imageDTOToImageWithDims } from './types'; @@ -26,7 +29,7 @@ export const selectCAOrThrow = (state: CanvasV2State, id: string) => { export const controlAdaptersReducers = { caAdded: { - reducer: (state, action: PayloadAction<{ id: string; config: ControlAdapterConfig }>) => { + reducer: (state, action: PayloadAction<{ id: string; config: ControlNetConfig | T2IAdapterConfig }>) => { const { id, config } = action.payload; state.controlAdapters.push({ id, @@ -42,7 +45,7 @@ export const controlAdaptersReducers = { ...config, }); }, - prepare: (config: ControlAdapterConfig) => ({ + prepare: (config: ControlNetConfig | T2IAdapterConfig) => ({ payload: { id: uuidv4(), config }, }), }, @@ -169,13 +172,6 @@ export const controlAdaptersReducers = { } ca.model = zModelIdentifierField.parse(modelConfig); - // We may need to convert the CA to match the model - if (!ca.controlMode && ca.model.type === 'controlnet') { - ca.controlMode = 'balanced'; - } else if (ca.controlMode && ca.model.type === 't2i_adapter') { - ca.controlMode = null; - } - const candidateProcessorConfig = buildControlAdapterProcessorV2(modelConfig); if (candidateProcessorConfig?.type !== ca.processorConfig?.type) { // The processor has changed. For example, the previous model was a Canny model and the new model is a Depth @@ -183,11 +179,21 @@ export const controlAdaptersReducers = { ca.processedImage = null; ca.processorConfig = candidateProcessorConfig; } + + // We may need to convert the CA to match the model + if (ca.adapterType === 't2i_adapter' && ca.model.type === 'controlnet') { + const convertedCA: ControlNetData = { ...ca, adapterType: 'controlnet', controlMode: 'balanced' }; + state.controlAdapters.splice(state.controlAdapters.indexOf(ca), 1, convertedCA); + } else if (ca.adapterType === 'controlnet' && ca.model.type === 't2i_adapter') { + const { controlMode: _, ...rest } = ca; + const convertedCA: T2IAdapterData = { ...rest, adapterType: 't2i_adapter' }; + state.controlAdapters.splice(state.controlAdapters.indexOf(ca), 1, convertedCA); + } }, caControlModeChanged: (state, action: PayloadAction<{ id: string; controlMode: ControlModeV2 }>) => { const { id, controlMode } = action.payload; const ca = selectCA(state, id); - if (!ca) { + if (!ca || ca.adapterType !== 'controlnet') { return; } ca.controlMode = controlMode; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 64e2ffea05..cd8b112e49 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -681,7 +681,7 @@ export type InpaintMaskData = z.infer; const zFilter = z.enum(['none', 'LightnessToAlphaFilter']); export type Filter = z.infer; -const zControlAdapterData = z.object({ +const zControlAdapterDataBase = z.object({ id: zId, type: z.literal('control_adapter'), isEnabled: z.boolean(), @@ -698,15 +698,37 @@ const zControlAdapterData = z.object({ processorPendingBatchId: z.string().nullable().default(null), beginEndStepPct: zBeginEndStepPct, model: zModelIdentifierField.nullable(), - controlMode: zControlModeV2.nullable(), }); +const zControlNetData = zControlAdapterDataBase.extend({ + adapterType: z.literal('controlnet'), + controlMode: zControlModeV2, +}); +export type ControlNetData = z.infer; +const zT2IAdapterData = zControlAdapterDataBase.extend({ + adapterType: z.literal('t2i_adapter'), +}); +export type T2IAdapterData = z.infer; + +const zControlAdapterData = z.discriminatedUnion('adapterType', [zControlNetData, zT2IAdapterData]); export type ControlAdapterData = z.infer; -export type ControlAdapterConfig = Pick< - ControlAdapterData, - 'weight' | 'image' | 'processedImage' | 'processorConfig' | 'beginEndStepPct' | 'model' | 'controlMode' +export type ControlNetConfig = Pick< + ControlNetData, + | 'adapterType' + | 'weight' + | 'image' + | 'processedImage' + | 'processorConfig' + | 'beginEndStepPct' + | 'model' + | 'controlMode' +>; +export type T2IAdapterConfig = Pick< + T2IAdapterData, + 'adapterType' | 'weight' | 'image' | 'processedImage' | 'processorConfig' | 'beginEndStepPct' | 'model' >; -export const initialControlNetV2: ControlAdapterConfig = { +export const initialControlNetV2: ControlNetConfig = { + adapterType: 'controlnet', model: null, weight: 1, beginEndStepPct: [0, 1], @@ -716,11 +738,11 @@ export const initialControlNetV2: ControlAdapterConfig = { processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(), }; -export const initialT2IAdapterV2: ControlAdapterConfig = { +export const initialT2IAdapterV2: T2IAdapterConfig = { + adapterType: 't2i_adapter', model: null, weight: 1, beginEndStepPct: [0, 1], - controlMode: null, image: null, processedImage: null, processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(), @@ -735,11 +757,11 @@ export const initialIPAdapterV2: IPAdapterConfig = { weight: 1, }; -export const buildControlNet = (id: string, overrides?: Partial): ControlAdapterConfig => { +export const buildControlNet = (id: string, overrides?: Partial): ControlNetConfig => { return merge(deepClone(initialControlNetV2), { id, ...overrides }); }; -export const buildT2IAdapter = (id: string, overrides?: Partial): ControlAdapterConfig => { +export const buildT2IAdapter = (id: string, overrides?: Partial): T2IAdapterConfig => { return merge(deepClone(initialT2IAdapterV2), { id, ...overrides }); };