mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor(ui): add adapterType
to ControlAdapterData
This commit is contained in:
parent
5159fcbc33
commit
e55192ae2a
@ -122,7 +122,7 @@ export const CASettings = memo(({ id }: Props) => {
|
||||
</Flex>
|
||||
<Flex gap={3} w="full">
|
||||
<Flex flexDir="column" gap={3} w="full" h="full">
|
||||
{controlAdapter.controlMode && (
|
||||
{controlAdapter.adapterType === 'controlnet' && (
|
||||
<CAControlModeSelect controlMode={controlAdapter.controlMode} onChange={onChangeControlMode} />
|
||||
)}
|
||||
<Weight weight={controlAdapter.weight} onChange={onChangeWeight} />
|
||||
|
@ -9,11 +9,14 @@ import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
import type {
|
||||
CanvasV2State,
|
||||
ControlAdapterConfig,
|
||||
ControlAdapterData,
|
||||
ControlModeV2,
|
||||
ControlNetConfig,
|
||||
ControlNetData,
|
||||
Filter,
|
||||
ProcessorConfig,
|
||||
T2IAdapterConfig,
|
||||
T2IAdapterData,
|
||||
} from './types';
|
||||
import { buildControlAdapterProcessorV2, imageDTOToImageWithDims } from './types';
|
||||
|
||||
@ -26,7 +29,7 @@ export const selectCAOrThrow = (state: CanvasV2State, id: string) => {
|
||||
|
||||
export const controlAdaptersReducers = {
|
||||
caAdded: {
|
||||
reducer: (state, action: PayloadAction<{ id: string; config: ControlAdapterConfig }>) => {
|
||||
reducer: (state, action: PayloadAction<{ id: string; config: ControlNetConfig | T2IAdapterConfig }>) => {
|
||||
const { id, config } = action.payload;
|
||||
state.controlAdapters.push({
|
||||
id,
|
||||
@ -42,7 +45,7 @@ export const controlAdaptersReducers = {
|
||||
...config,
|
||||
});
|
||||
},
|
||||
prepare: (config: ControlAdapterConfig) => ({
|
||||
prepare: (config: ControlNetConfig | T2IAdapterConfig) => ({
|
||||
payload: { id: uuidv4(), config },
|
||||
}),
|
||||
},
|
||||
@ -169,13 +172,6 @@ export const controlAdaptersReducers = {
|
||||
}
|
||||
ca.model = zModelIdentifierField.parse(modelConfig);
|
||||
|
||||
// We may need to convert the CA to match the model
|
||||
if (!ca.controlMode && ca.model.type === 'controlnet') {
|
||||
ca.controlMode = 'balanced';
|
||||
} else if (ca.controlMode && ca.model.type === 't2i_adapter') {
|
||||
ca.controlMode = null;
|
||||
}
|
||||
|
||||
const candidateProcessorConfig = buildControlAdapterProcessorV2(modelConfig);
|
||||
if (candidateProcessorConfig?.type !== ca.processorConfig?.type) {
|
||||
// The processor has changed. For example, the previous model was a Canny model and the new model is a Depth
|
||||
@ -183,11 +179,21 @@ export const controlAdaptersReducers = {
|
||||
ca.processedImage = null;
|
||||
ca.processorConfig = candidateProcessorConfig;
|
||||
}
|
||||
|
||||
// We may need to convert the CA to match the model
|
||||
if (ca.adapterType === 't2i_adapter' && ca.model.type === 'controlnet') {
|
||||
const convertedCA: ControlNetData = { ...ca, adapterType: 'controlnet', controlMode: 'balanced' };
|
||||
state.controlAdapters.splice(state.controlAdapters.indexOf(ca), 1, convertedCA);
|
||||
} else if (ca.adapterType === 'controlnet' && ca.model.type === 't2i_adapter') {
|
||||
const { controlMode: _, ...rest } = ca;
|
||||
const convertedCA: T2IAdapterData = { ...rest, adapterType: 't2i_adapter' };
|
||||
state.controlAdapters.splice(state.controlAdapters.indexOf(ca), 1, convertedCA);
|
||||
}
|
||||
},
|
||||
caControlModeChanged: (state, action: PayloadAction<{ id: string; controlMode: ControlModeV2 }>) => {
|
||||
const { id, controlMode } = action.payload;
|
||||
const ca = selectCA(state, id);
|
||||
if (!ca) {
|
||||
if (!ca || ca.adapterType !== 'controlnet') {
|
||||
return;
|
||||
}
|
||||
ca.controlMode = controlMode;
|
||||
|
@ -681,7 +681,7 @@ export type InpaintMaskData = z.infer<typeof zInpaintMaskData>;
|
||||
const zFilter = z.enum(['none', 'LightnessToAlphaFilter']);
|
||||
export type Filter = z.infer<typeof zFilter>;
|
||||
|
||||
const zControlAdapterData = z.object({
|
||||
const zControlAdapterDataBase = z.object({
|
||||
id: zId,
|
||||
type: z.literal('control_adapter'),
|
||||
isEnabled: z.boolean(),
|
||||
@ -698,15 +698,37 @@ const zControlAdapterData = z.object({
|
||||
processorPendingBatchId: z.string().nullable().default(null),
|
||||
beginEndStepPct: zBeginEndStepPct,
|
||||
model: zModelIdentifierField.nullable(),
|
||||
controlMode: zControlModeV2.nullable(),
|
||||
});
|
||||
const zControlNetData = zControlAdapterDataBase.extend({
|
||||
adapterType: z.literal('controlnet'),
|
||||
controlMode: zControlModeV2,
|
||||
});
|
||||
export type ControlNetData = z.infer<typeof zControlNetData>;
|
||||
const zT2IAdapterData = zControlAdapterDataBase.extend({
|
||||
adapterType: z.literal('t2i_adapter'),
|
||||
});
|
||||
export type T2IAdapterData = z.infer<typeof zT2IAdapterData>;
|
||||
|
||||
const zControlAdapterData = z.discriminatedUnion('adapterType', [zControlNetData, zT2IAdapterData]);
|
||||
export type ControlAdapterData = z.infer<typeof zControlAdapterData>;
|
||||
export type ControlAdapterConfig = Pick<
|
||||
ControlAdapterData,
|
||||
'weight' | 'image' | 'processedImage' | 'processorConfig' | 'beginEndStepPct' | 'model' | 'controlMode'
|
||||
export type ControlNetConfig = Pick<
|
||||
ControlNetData,
|
||||
| 'adapterType'
|
||||
| 'weight'
|
||||
| 'image'
|
||||
| 'processedImage'
|
||||
| 'processorConfig'
|
||||
| 'beginEndStepPct'
|
||||
| 'model'
|
||||
| 'controlMode'
|
||||
>;
|
||||
export type T2IAdapterConfig = Pick<
|
||||
T2IAdapterData,
|
||||
'adapterType' | 'weight' | 'image' | 'processedImage' | 'processorConfig' | 'beginEndStepPct' | 'model'
|
||||
>;
|
||||
|
||||
export const initialControlNetV2: ControlAdapterConfig = {
|
||||
export const initialControlNetV2: ControlNetConfig = {
|
||||
adapterType: 'controlnet',
|
||||
model: null,
|
||||
weight: 1,
|
||||
beginEndStepPct: [0, 1],
|
||||
@ -716,11 +738,11 @@ export const initialControlNetV2: ControlAdapterConfig = {
|
||||
processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(),
|
||||
};
|
||||
|
||||
export const initialT2IAdapterV2: ControlAdapterConfig = {
|
||||
export const initialT2IAdapterV2: T2IAdapterConfig = {
|
||||
adapterType: 't2i_adapter',
|
||||
model: null,
|
||||
weight: 1,
|
||||
beginEndStepPct: [0, 1],
|
||||
controlMode: null,
|
||||
image: null,
|
||||
processedImage: null,
|
||||
processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(),
|
||||
@ -735,11 +757,11 @@ export const initialIPAdapterV2: IPAdapterConfig = {
|
||||
weight: 1,
|
||||
};
|
||||
|
||||
export const buildControlNet = (id: string, overrides?: Partial<ControlAdapterConfig>): ControlAdapterConfig => {
|
||||
export const buildControlNet = (id: string, overrides?: Partial<ControlNetConfig>): ControlNetConfig => {
|
||||
return merge(deepClone(initialControlNetV2), { id, ...overrides });
|
||||
};
|
||||
|
||||
export const buildT2IAdapter = (id: string, overrides?: Partial<ControlAdapterConfig>): ControlAdapterConfig => {
|
||||
export const buildT2IAdapter = (id: string, overrides?: Partial<T2IAdapterConfig>): T2IAdapterConfig => {
|
||||
return merge(deepClone(initialT2IAdapterV2), { id, ...overrides });
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user