refactor(ui): add adapterType to ControlAdapterData

This commit is contained in:
psychedelicious 2024-06-16 09:10:56 +10:00
parent 5159fcbc33
commit e55192ae2a
3 changed files with 50 additions and 22 deletions

View File

@ -122,7 +122,7 @@ export const CASettings = memo(({ id }: Props) => {
</Flex> </Flex>
<Flex gap={3} w="full"> <Flex gap={3} w="full">
<Flex flexDir="column" gap={3} w="full" h="full"> <Flex flexDir="column" gap={3} w="full" h="full">
{controlAdapter.controlMode && ( {controlAdapter.adapterType === 'controlnet' && (
<CAControlModeSelect controlMode={controlAdapter.controlMode} onChange={onChangeControlMode} /> <CAControlModeSelect controlMode={controlAdapter.controlMode} onChange={onChangeControlMode} />
)} )}
<Weight weight={controlAdapter.weight} onChange={onChangeWeight} /> <Weight weight={controlAdapter.weight} onChange={onChangeWeight} />

View File

@ -9,11 +9,14 @@ import { v4 as uuidv4 } from 'uuid';
import type { import type {
CanvasV2State, CanvasV2State,
ControlAdapterConfig,
ControlAdapterData, ControlAdapterData,
ControlModeV2, ControlModeV2,
ControlNetConfig,
ControlNetData,
Filter, Filter,
ProcessorConfig, ProcessorConfig,
T2IAdapterConfig,
T2IAdapterData,
} from './types'; } from './types';
import { buildControlAdapterProcessorV2, imageDTOToImageWithDims } from './types'; import { buildControlAdapterProcessorV2, imageDTOToImageWithDims } from './types';
@ -26,7 +29,7 @@ export const selectCAOrThrow = (state: CanvasV2State, id: string) => {
export const controlAdaptersReducers = { export const controlAdaptersReducers = {
caAdded: { caAdded: {
reducer: (state, action: PayloadAction<{ id: string; config: ControlAdapterConfig }>) => { reducer: (state, action: PayloadAction<{ id: string; config: ControlNetConfig | T2IAdapterConfig }>) => {
const { id, config } = action.payload; const { id, config } = action.payload;
state.controlAdapters.push({ state.controlAdapters.push({
id, id,
@ -42,7 +45,7 @@ export const controlAdaptersReducers = {
...config, ...config,
}); });
}, },
prepare: (config: ControlAdapterConfig) => ({ prepare: (config: ControlNetConfig | T2IAdapterConfig) => ({
payload: { id: uuidv4(), config }, payload: { id: uuidv4(), config },
}), }),
}, },
@ -169,13 +172,6 @@ export const controlAdaptersReducers = {
} }
ca.model = zModelIdentifierField.parse(modelConfig); ca.model = zModelIdentifierField.parse(modelConfig);
// We may need to convert the CA to match the model
if (!ca.controlMode && ca.model.type === 'controlnet') {
ca.controlMode = 'balanced';
} else if (ca.controlMode && ca.model.type === 't2i_adapter') {
ca.controlMode = null;
}
const candidateProcessorConfig = buildControlAdapterProcessorV2(modelConfig); const candidateProcessorConfig = buildControlAdapterProcessorV2(modelConfig);
if (candidateProcessorConfig?.type !== ca.processorConfig?.type) { if (candidateProcessorConfig?.type !== ca.processorConfig?.type) {
// The processor has changed. For example, the previous model was a Canny model and the new model is a Depth // The processor has changed. For example, the previous model was a Canny model and the new model is a Depth
@ -183,11 +179,21 @@ export const controlAdaptersReducers = {
ca.processedImage = null; ca.processedImage = null;
ca.processorConfig = candidateProcessorConfig; ca.processorConfig = candidateProcessorConfig;
} }
// We may need to convert the CA to match the model
if (ca.adapterType === 't2i_adapter' && ca.model.type === 'controlnet') {
const convertedCA: ControlNetData = { ...ca, adapterType: 'controlnet', controlMode: 'balanced' };
state.controlAdapters.splice(state.controlAdapters.indexOf(ca), 1, convertedCA);
} else if (ca.adapterType === 'controlnet' && ca.model.type === 't2i_adapter') {
const { controlMode: _, ...rest } = ca;
const convertedCA: T2IAdapterData = { ...rest, adapterType: 't2i_adapter' };
state.controlAdapters.splice(state.controlAdapters.indexOf(ca), 1, convertedCA);
}
}, },
caControlModeChanged: (state, action: PayloadAction<{ id: string; controlMode: ControlModeV2 }>) => { caControlModeChanged: (state, action: PayloadAction<{ id: string; controlMode: ControlModeV2 }>) => {
const { id, controlMode } = action.payload; const { id, controlMode } = action.payload;
const ca = selectCA(state, id); const ca = selectCA(state, id);
if (!ca) { if (!ca || ca.adapterType !== 'controlnet') {
return; return;
} }
ca.controlMode = controlMode; ca.controlMode = controlMode;

View File

@ -681,7 +681,7 @@ export type InpaintMaskData = z.infer<typeof zInpaintMaskData>;
const zFilter = z.enum(['none', 'LightnessToAlphaFilter']); const zFilter = z.enum(['none', 'LightnessToAlphaFilter']);
export type Filter = z.infer<typeof zFilter>; export type Filter = z.infer<typeof zFilter>;
const zControlAdapterData = z.object({ const zControlAdapterDataBase = z.object({
id: zId, id: zId,
type: z.literal('control_adapter'), type: z.literal('control_adapter'),
isEnabled: z.boolean(), isEnabled: z.boolean(),
@ -698,15 +698,37 @@ const zControlAdapterData = z.object({
processorPendingBatchId: z.string().nullable().default(null), processorPendingBatchId: z.string().nullable().default(null),
beginEndStepPct: zBeginEndStepPct, beginEndStepPct: zBeginEndStepPct,
model: zModelIdentifierField.nullable(), model: zModelIdentifierField.nullable(),
controlMode: zControlModeV2.nullable(),
}); });
const zControlNetData = zControlAdapterDataBase.extend({
adapterType: z.literal('controlnet'),
controlMode: zControlModeV2,
});
export type ControlNetData = z.infer<typeof zControlNetData>;
const zT2IAdapterData = zControlAdapterDataBase.extend({
adapterType: z.literal('t2i_adapter'),
});
export type T2IAdapterData = z.infer<typeof zT2IAdapterData>;
const zControlAdapterData = z.discriminatedUnion('adapterType', [zControlNetData, zT2IAdapterData]);
export type ControlAdapterData = z.infer<typeof zControlAdapterData>; export type ControlAdapterData = z.infer<typeof zControlAdapterData>;
export type ControlAdapterConfig = Pick< export type ControlNetConfig = Pick<
ControlAdapterData, ControlNetData,
'weight' | 'image' | 'processedImage' | 'processorConfig' | 'beginEndStepPct' | 'model' | 'controlMode' | 'adapterType'
| 'weight'
| 'image'
| 'processedImage'
| 'processorConfig'
| 'beginEndStepPct'
| 'model'
| 'controlMode'
>;
export type T2IAdapterConfig = Pick<
T2IAdapterData,
'adapterType' | 'weight' | 'image' | 'processedImage' | 'processorConfig' | 'beginEndStepPct' | 'model'
>; >;
export const initialControlNetV2: ControlAdapterConfig = { export const initialControlNetV2: ControlNetConfig = {
adapterType: 'controlnet',
model: null, model: null,
weight: 1, weight: 1,
beginEndStepPct: [0, 1], beginEndStepPct: [0, 1],
@ -716,11 +738,11 @@ export const initialControlNetV2: ControlAdapterConfig = {
processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(), processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(),
}; };
export const initialT2IAdapterV2: ControlAdapterConfig = { export const initialT2IAdapterV2: T2IAdapterConfig = {
adapterType: 't2i_adapter',
model: null, model: null,
weight: 1, weight: 1,
beginEndStepPct: [0, 1], beginEndStepPct: [0, 1],
controlMode: null,
image: null, image: null,
processedImage: null, processedImage: null,
processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(), processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(),
@ -735,11 +757,11 @@ export const initialIPAdapterV2: IPAdapterConfig = {
weight: 1, weight: 1,
}; };
export const buildControlNet = (id: string, overrides?: Partial<ControlAdapterConfig>): ControlAdapterConfig => { export const buildControlNet = (id: string, overrides?: Partial<ControlNetConfig>): ControlNetConfig => {
return merge(deepClone(initialControlNetV2), { id, ...overrides }); return merge(deepClone(initialControlNetV2), { id, ...overrides });
}; };
export const buildT2IAdapter = (id: string, overrides?: Partial<ControlAdapterConfig>): ControlAdapterConfig => { export const buildT2IAdapter = (id: string, overrides?: Partial<T2IAdapterConfig>): T2IAdapterConfig => {
return merge(deepClone(initialT2IAdapterV2), { id, ...overrides }); return merge(deepClone(initialT2IAdapterV2), { id, ...overrides });
}; };