mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor(ui): add adapterType
to ControlAdapterData
This commit is contained in:
parent
5159fcbc33
commit
e55192ae2a
@ -122,7 +122,7 @@ export const CASettings = memo(({ id }: Props) => {
|
|||||||
</Flex>
|
</Flex>
|
||||||
<Flex gap={3} w="full">
|
<Flex gap={3} w="full">
|
||||||
<Flex flexDir="column" gap={3} w="full" h="full">
|
<Flex flexDir="column" gap={3} w="full" h="full">
|
||||||
{controlAdapter.controlMode && (
|
{controlAdapter.adapterType === 'controlnet' && (
|
||||||
<CAControlModeSelect controlMode={controlAdapter.controlMode} onChange={onChangeControlMode} />
|
<CAControlModeSelect controlMode={controlAdapter.controlMode} onChange={onChangeControlMode} />
|
||||||
)}
|
)}
|
||||||
<Weight weight={controlAdapter.weight} onChange={onChangeWeight} />
|
<Weight weight={controlAdapter.weight} onChange={onChangeWeight} />
|
||||||
|
@ -9,11 +9,14 @@ import { v4 as uuidv4 } from 'uuid';
|
|||||||
|
|
||||||
import type {
|
import type {
|
||||||
CanvasV2State,
|
CanvasV2State,
|
||||||
ControlAdapterConfig,
|
|
||||||
ControlAdapterData,
|
ControlAdapterData,
|
||||||
ControlModeV2,
|
ControlModeV2,
|
||||||
|
ControlNetConfig,
|
||||||
|
ControlNetData,
|
||||||
Filter,
|
Filter,
|
||||||
ProcessorConfig,
|
ProcessorConfig,
|
||||||
|
T2IAdapterConfig,
|
||||||
|
T2IAdapterData,
|
||||||
} from './types';
|
} from './types';
|
||||||
import { buildControlAdapterProcessorV2, imageDTOToImageWithDims } from './types';
|
import { buildControlAdapterProcessorV2, imageDTOToImageWithDims } from './types';
|
||||||
|
|
||||||
@ -26,7 +29,7 @@ export const selectCAOrThrow = (state: CanvasV2State, id: string) => {
|
|||||||
|
|
||||||
export const controlAdaptersReducers = {
|
export const controlAdaptersReducers = {
|
||||||
caAdded: {
|
caAdded: {
|
||||||
reducer: (state, action: PayloadAction<{ id: string; config: ControlAdapterConfig }>) => {
|
reducer: (state, action: PayloadAction<{ id: string; config: ControlNetConfig | T2IAdapterConfig }>) => {
|
||||||
const { id, config } = action.payload;
|
const { id, config } = action.payload;
|
||||||
state.controlAdapters.push({
|
state.controlAdapters.push({
|
||||||
id,
|
id,
|
||||||
@ -42,7 +45,7 @@ export const controlAdaptersReducers = {
|
|||||||
...config,
|
...config,
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
prepare: (config: ControlAdapterConfig) => ({
|
prepare: (config: ControlNetConfig | T2IAdapterConfig) => ({
|
||||||
payload: { id: uuidv4(), config },
|
payload: { id: uuidv4(), config },
|
||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
@ -169,13 +172,6 @@ export const controlAdaptersReducers = {
|
|||||||
}
|
}
|
||||||
ca.model = zModelIdentifierField.parse(modelConfig);
|
ca.model = zModelIdentifierField.parse(modelConfig);
|
||||||
|
|
||||||
// We may need to convert the CA to match the model
|
|
||||||
if (!ca.controlMode && ca.model.type === 'controlnet') {
|
|
||||||
ca.controlMode = 'balanced';
|
|
||||||
} else if (ca.controlMode && ca.model.type === 't2i_adapter') {
|
|
||||||
ca.controlMode = null;
|
|
||||||
}
|
|
||||||
|
|
||||||
const candidateProcessorConfig = buildControlAdapterProcessorV2(modelConfig);
|
const candidateProcessorConfig = buildControlAdapterProcessorV2(modelConfig);
|
||||||
if (candidateProcessorConfig?.type !== ca.processorConfig?.type) {
|
if (candidateProcessorConfig?.type !== ca.processorConfig?.type) {
|
||||||
// The processor has changed. For example, the previous model was a Canny model and the new model is a Depth
|
// The processor has changed. For example, the previous model was a Canny model and the new model is a Depth
|
||||||
@ -183,11 +179,21 @@ export const controlAdaptersReducers = {
|
|||||||
ca.processedImage = null;
|
ca.processedImage = null;
|
||||||
ca.processorConfig = candidateProcessorConfig;
|
ca.processorConfig = candidateProcessorConfig;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We may need to convert the CA to match the model
|
||||||
|
if (ca.adapterType === 't2i_adapter' && ca.model.type === 'controlnet') {
|
||||||
|
const convertedCA: ControlNetData = { ...ca, adapterType: 'controlnet', controlMode: 'balanced' };
|
||||||
|
state.controlAdapters.splice(state.controlAdapters.indexOf(ca), 1, convertedCA);
|
||||||
|
} else if (ca.adapterType === 'controlnet' && ca.model.type === 't2i_adapter') {
|
||||||
|
const { controlMode: _, ...rest } = ca;
|
||||||
|
const convertedCA: T2IAdapterData = { ...rest, adapterType: 't2i_adapter' };
|
||||||
|
state.controlAdapters.splice(state.controlAdapters.indexOf(ca), 1, convertedCA);
|
||||||
|
}
|
||||||
},
|
},
|
||||||
caControlModeChanged: (state, action: PayloadAction<{ id: string; controlMode: ControlModeV2 }>) => {
|
caControlModeChanged: (state, action: PayloadAction<{ id: string; controlMode: ControlModeV2 }>) => {
|
||||||
const { id, controlMode } = action.payload;
|
const { id, controlMode } = action.payload;
|
||||||
const ca = selectCA(state, id);
|
const ca = selectCA(state, id);
|
||||||
if (!ca) {
|
if (!ca || ca.adapterType !== 'controlnet') {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
ca.controlMode = controlMode;
|
ca.controlMode = controlMode;
|
||||||
|
@ -681,7 +681,7 @@ export type InpaintMaskData = z.infer<typeof zInpaintMaskData>;
|
|||||||
const zFilter = z.enum(['none', 'LightnessToAlphaFilter']);
|
const zFilter = z.enum(['none', 'LightnessToAlphaFilter']);
|
||||||
export type Filter = z.infer<typeof zFilter>;
|
export type Filter = z.infer<typeof zFilter>;
|
||||||
|
|
||||||
const zControlAdapterData = z.object({
|
const zControlAdapterDataBase = z.object({
|
||||||
id: zId,
|
id: zId,
|
||||||
type: z.literal('control_adapter'),
|
type: z.literal('control_adapter'),
|
||||||
isEnabled: z.boolean(),
|
isEnabled: z.boolean(),
|
||||||
@ -698,15 +698,37 @@ const zControlAdapterData = z.object({
|
|||||||
processorPendingBatchId: z.string().nullable().default(null),
|
processorPendingBatchId: z.string().nullable().default(null),
|
||||||
beginEndStepPct: zBeginEndStepPct,
|
beginEndStepPct: zBeginEndStepPct,
|
||||||
model: zModelIdentifierField.nullable(),
|
model: zModelIdentifierField.nullable(),
|
||||||
controlMode: zControlModeV2.nullable(),
|
|
||||||
});
|
});
|
||||||
|
const zControlNetData = zControlAdapterDataBase.extend({
|
||||||
|
adapterType: z.literal('controlnet'),
|
||||||
|
controlMode: zControlModeV2,
|
||||||
|
});
|
||||||
|
export type ControlNetData = z.infer<typeof zControlNetData>;
|
||||||
|
const zT2IAdapterData = zControlAdapterDataBase.extend({
|
||||||
|
adapterType: z.literal('t2i_adapter'),
|
||||||
|
});
|
||||||
|
export type T2IAdapterData = z.infer<typeof zT2IAdapterData>;
|
||||||
|
|
||||||
|
const zControlAdapterData = z.discriminatedUnion('adapterType', [zControlNetData, zT2IAdapterData]);
|
||||||
export type ControlAdapterData = z.infer<typeof zControlAdapterData>;
|
export type ControlAdapterData = z.infer<typeof zControlAdapterData>;
|
||||||
export type ControlAdapterConfig = Pick<
|
export type ControlNetConfig = Pick<
|
||||||
ControlAdapterData,
|
ControlNetData,
|
||||||
'weight' | 'image' | 'processedImage' | 'processorConfig' | 'beginEndStepPct' | 'model' | 'controlMode'
|
| 'adapterType'
|
||||||
|
| 'weight'
|
||||||
|
| 'image'
|
||||||
|
| 'processedImage'
|
||||||
|
| 'processorConfig'
|
||||||
|
| 'beginEndStepPct'
|
||||||
|
| 'model'
|
||||||
|
| 'controlMode'
|
||||||
|
>;
|
||||||
|
export type T2IAdapterConfig = Pick<
|
||||||
|
T2IAdapterData,
|
||||||
|
'adapterType' | 'weight' | 'image' | 'processedImage' | 'processorConfig' | 'beginEndStepPct' | 'model'
|
||||||
>;
|
>;
|
||||||
|
|
||||||
export const initialControlNetV2: ControlAdapterConfig = {
|
export const initialControlNetV2: ControlNetConfig = {
|
||||||
|
adapterType: 'controlnet',
|
||||||
model: null,
|
model: null,
|
||||||
weight: 1,
|
weight: 1,
|
||||||
beginEndStepPct: [0, 1],
|
beginEndStepPct: [0, 1],
|
||||||
@ -716,11 +738,11 @@ export const initialControlNetV2: ControlAdapterConfig = {
|
|||||||
processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(),
|
processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(),
|
||||||
};
|
};
|
||||||
|
|
||||||
export const initialT2IAdapterV2: ControlAdapterConfig = {
|
export const initialT2IAdapterV2: T2IAdapterConfig = {
|
||||||
|
adapterType: 't2i_adapter',
|
||||||
model: null,
|
model: null,
|
||||||
weight: 1,
|
weight: 1,
|
||||||
beginEndStepPct: [0, 1],
|
beginEndStepPct: [0, 1],
|
||||||
controlMode: null,
|
|
||||||
image: null,
|
image: null,
|
||||||
processedImage: null,
|
processedImage: null,
|
||||||
processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(),
|
processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(),
|
||||||
@ -735,11 +757,11 @@ export const initialIPAdapterV2: IPAdapterConfig = {
|
|||||||
weight: 1,
|
weight: 1,
|
||||||
};
|
};
|
||||||
|
|
||||||
export const buildControlNet = (id: string, overrides?: Partial<ControlAdapterConfig>): ControlAdapterConfig => {
|
export const buildControlNet = (id: string, overrides?: Partial<ControlNetConfig>): ControlNetConfig => {
|
||||||
return merge(deepClone(initialControlNetV2), { id, ...overrides });
|
return merge(deepClone(initialControlNetV2), { id, ...overrides });
|
||||||
};
|
};
|
||||||
|
|
||||||
export const buildT2IAdapter = (id: string, overrides?: Partial<ControlAdapterConfig>): ControlAdapterConfig => {
|
export const buildT2IAdapter = (id: string, overrides?: Partial<T2IAdapterConfig>): T2IAdapterConfig => {
|
||||||
return merge(deepClone(initialT2IAdapterV2), { id, ...overrides });
|
return merge(deepClone(initialT2IAdapterV2), { id, ...overrides });
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user