refactor(ui): add adapterType to ControlAdapterData

This commit is contained in:
psychedelicious 2024-06-16 09:10:56 +10:00
parent 5159fcbc33
commit e55192ae2a
3 changed files with 50 additions and 22 deletions

View File

@ -122,7 +122,7 @@ export const CASettings = memo(({ id }: Props) => {
</Flex>
<Flex gap={3} w="full">
<Flex flexDir="column" gap={3} w="full" h="full">
{controlAdapter.controlMode && (
{controlAdapter.adapterType === 'controlnet' && (
<CAControlModeSelect controlMode={controlAdapter.controlMode} onChange={onChangeControlMode} />
)}
<Weight weight={controlAdapter.weight} onChange={onChangeWeight} />

View File

@ -9,11 +9,14 @@ import { v4 as uuidv4 } from 'uuid';
import type {
CanvasV2State,
ControlAdapterConfig,
ControlAdapterData,
ControlModeV2,
ControlNetConfig,
ControlNetData,
Filter,
ProcessorConfig,
T2IAdapterConfig,
T2IAdapterData,
} from './types';
import { buildControlAdapterProcessorV2, imageDTOToImageWithDims } from './types';
@ -26,7 +29,7 @@ export const selectCAOrThrow = (state: CanvasV2State, id: string) => {
export const controlAdaptersReducers = {
caAdded: {
reducer: (state, action: PayloadAction<{ id: string; config: ControlAdapterConfig }>) => {
reducer: (state, action: PayloadAction<{ id: string; config: ControlNetConfig | T2IAdapterConfig }>) => {
const { id, config } = action.payload;
state.controlAdapters.push({
id,
@ -42,7 +45,7 @@ export const controlAdaptersReducers = {
...config,
});
},
prepare: (config: ControlAdapterConfig) => ({
prepare: (config: ControlNetConfig | T2IAdapterConfig) => ({
payload: { id: uuidv4(), config },
}),
},
@ -169,13 +172,6 @@ export const controlAdaptersReducers = {
}
ca.model = zModelIdentifierField.parse(modelConfig);
// We may need to convert the CA to match the model
if (!ca.controlMode && ca.model.type === 'controlnet') {
ca.controlMode = 'balanced';
} else if (ca.controlMode && ca.model.type === 't2i_adapter') {
ca.controlMode = null;
}
const candidateProcessorConfig = buildControlAdapterProcessorV2(modelConfig);
if (candidateProcessorConfig?.type !== ca.processorConfig?.type) {
// The processor has changed. For example, the previous model was a Canny model and the new model is a Depth
@ -183,11 +179,21 @@ export const controlAdaptersReducers = {
ca.processedImage = null;
ca.processorConfig = candidateProcessorConfig;
}
// We may need to convert the CA to match the model
if (ca.adapterType === 't2i_adapter' && ca.model.type === 'controlnet') {
const convertedCA: ControlNetData = { ...ca, adapterType: 'controlnet', controlMode: 'balanced' };
state.controlAdapters.splice(state.controlAdapters.indexOf(ca), 1, convertedCA);
} else if (ca.adapterType === 'controlnet' && ca.model.type === 't2i_adapter') {
const { controlMode: _, ...rest } = ca;
const convertedCA: T2IAdapterData = { ...rest, adapterType: 't2i_adapter' };
state.controlAdapters.splice(state.controlAdapters.indexOf(ca), 1, convertedCA);
}
},
caControlModeChanged: (state, action: PayloadAction<{ id: string; controlMode: ControlModeV2 }>) => {
const { id, controlMode } = action.payload;
const ca = selectCA(state, id);
if (!ca) {
if (!ca || ca.adapterType !== 'controlnet') {
return;
}
ca.controlMode = controlMode;

View File

@ -681,7 +681,7 @@ export type InpaintMaskData = z.infer<typeof zInpaintMaskData>;
const zFilter = z.enum(['none', 'LightnessToAlphaFilter']);
export type Filter = z.infer<typeof zFilter>;
const zControlAdapterData = z.object({
const zControlAdapterDataBase = z.object({
id: zId,
type: z.literal('control_adapter'),
isEnabled: z.boolean(),
@ -698,15 +698,37 @@ const zControlAdapterData = z.object({
processorPendingBatchId: z.string().nullable().default(null),
beginEndStepPct: zBeginEndStepPct,
model: zModelIdentifierField.nullable(),
controlMode: zControlModeV2.nullable(),
});
const zControlNetData = zControlAdapterDataBase.extend({
adapterType: z.literal('controlnet'),
controlMode: zControlModeV2,
});
export type ControlNetData = z.infer<typeof zControlNetData>;
const zT2IAdapterData = zControlAdapterDataBase.extend({
adapterType: z.literal('t2i_adapter'),
});
export type T2IAdapterData = z.infer<typeof zT2IAdapterData>;
const zControlAdapterData = z.discriminatedUnion('adapterType', [zControlNetData, zT2IAdapterData]);
export type ControlAdapterData = z.infer<typeof zControlAdapterData>;
export type ControlAdapterConfig = Pick<
ControlAdapterData,
'weight' | 'image' | 'processedImage' | 'processorConfig' | 'beginEndStepPct' | 'model' | 'controlMode'
export type ControlNetConfig = Pick<
ControlNetData,
| 'adapterType'
| 'weight'
| 'image'
| 'processedImage'
| 'processorConfig'
| 'beginEndStepPct'
| 'model'
| 'controlMode'
>;
export type T2IAdapterConfig = Pick<
T2IAdapterData,
'adapterType' | 'weight' | 'image' | 'processedImage' | 'processorConfig' | 'beginEndStepPct' | 'model'
>;
export const initialControlNetV2: ControlAdapterConfig = {
export const initialControlNetV2: ControlNetConfig = {
adapterType: 'controlnet',
model: null,
weight: 1,
beginEndStepPct: [0, 1],
@ -716,11 +738,11 @@ export const initialControlNetV2: ControlAdapterConfig = {
processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(),
};
export const initialT2IAdapterV2: ControlAdapterConfig = {
export const initialT2IAdapterV2: T2IAdapterConfig = {
adapterType: 't2i_adapter',
model: null,
weight: 1,
beginEndStepPct: [0, 1],
controlMode: null,
image: null,
processedImage: null,
processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(),
@ -735,11 +757,11 @@ export const initialIPAdapterV2: IPAdapterConfig = {
weight: 1,
};
export const buildControlNet = (id: string, overrides?: Partial<ControlAdapterConfig>): ControlAdapterConfig => {
export const buildControlNet = (id: string, overrides?: Partial<ControlNetConfig>): ControlNetConfig => {
return merge(deepClone(initialControlNetV2), { id, ...overrides });
};
export const buildT2IAdapter = (id: string, overrides?: Partial<ControlAdapterConfig>): ControlAdapterConfig => {
export const buildT2IAdapter = (id: string, overrides?: Partial<T2IAdapterConfig>): T2IAdapterConfig => {
return merge(deepClone(initialT2IAdapterV2), { id, ...overrides });
};