feat(ui): use default settings for control adapters for processor

This commit is contained in:
psychedelicious 2024-03-08 19:45:46 +11:00 committed by Brandon
parent dbd7c94e7c
commit 53b7f6be37
11 changed files with 138 additions and 68 deletions

View File

@ -1,6 +1,7 @@
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
import { useControlAdapterShouldAutoConfig } from 'features/controlAdapters/hooks/useControlAdapterShouldAutoConfig';
import { controlAdapterAutoConfigToggled } from 'features/controlAdapters/store/controlAdaptersSlice';
import { isNil } from 'lodash-es';
@ -14,12 +15,13 @@ type Props = {
const ControlAdapterShouldAutoConfig = ({ id }: Props) => {
const isEnabled = useControlAdapterIsEnabled(id);
const shouldAutoConfig = useControlAdapterShouldAutoConfig(id);
const { modelConfig } = useControlAdapterModel(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleShouldAutoConfigChanged = useCallback(() => {
dispatch(controlAdapterAutoConfigToggled({ id }));
}, [id, dispatch]);
dispatch(controlAdapterAutoConfigToggled({ id, modelConfig }));
}, [id, dispatch, modelConfig]);
if (isNil(shouldAutoConfig)) {
return null;

View File

@ -6,7 +6,6 @@ import { useControlAdapterModel } from 'features/controlAdapters/hooks/useContro
import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery';
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
import { memo, useCallback, useMemo } from 'react';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
@ -17,21 +16,21 @@ type ParamControlAdapterModelProps = {
const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
const isEnabled = useControlAdapterIsEnabled(id);
const controlAdapterType = useControlAdapterType(id);
const model = useControlAdapterModel(id);
const { modelConfig } = useControlAdapterModel(id);
const dispatch = useAppDispatch();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType);
const _onChange = useCallback(
(model: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
if (!model) {
(modelConfig: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
if (!modelConfig) {
return;
}
dispatch(
controlAdapterModelChanged({
id,
model: getModelKeyAndBase(model),
modelConfig,
})
);
},
@ -39,8 +38,8 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
);
const selectedModel = useMemo(
() => (model && controlAdapterType ? { ...model, model_type: controlAdapterType } : null),
[controlAdapterType, model]
() => (modelConfig && controlAdapterType ? { ...modelConfig, model_type: controlAdapterType } : null),
[controlAdapterType, modelConfig]
);
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({

View File

@ -1,7 +1,9 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { controlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
import { type ControlAdapterType, isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
import { useCallback, useMemo } from 'react';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { useControlAdapterModels } from './useControlAdapterModels';
@ -11,7 +13,7 @@ export const useAddControlAdapter = (type: ControlAdapterType) => {
const models = useControlAdapterModels(type);
const firstModel = useMemo(() => {
const firstModel: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig | undefined = useMemo(() => {
// prefer to use a model that matches the base model
const firstCompatibleModel = models.filter((m) => (baseModel ? m.base === baseModel : true))[0];
@ -28,6 +30,26 @@ export const useAddControlAdapter = (type: ControlAdapterType) => {
if (isDisabled) {
return;
}
if (
(type === 'controlnet' || type === 't2i_adapter') &&
(firstModel?.type === 'controlnet' || firstModel?.type === 't2i_adapter')
) {
const defaultPreprocessor = firstModel.default_settings?.preprocessor;
const processorType = isControlAdapterProcessorType(defaultPreprocessor) ? defaultPreprocessor : 'none';
const processorNode = CONTROLNET_PROCESSORS[processorType].default;
dispatch(
controlAdapterAdded({
type,
overrides: {
model: firstModel,
processorType,
processorNode,
},
})
);
return;
}
dispatch(
controlAdapterAdded({
type,

View File

@ -1,3 +1,4 @@
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import {
@ -5,18 +6,22 @@ import {
selectControlAdaptersSlice,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { useMemo } from 'react';
import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelConfigWithTypeGuard';
import { isControlAdapterModelConfig } from 'services/api/types';
export const useControlAdapterModel = (id: string) => {
const selector = useMemo(
() =>
createMemoizedSelector(
selectControlAdaptersSlice,
(controlAdapters) => selectControlAdapterById(controlAdapters, id)?.model
(controlAdapters) => selectControlAdapterById(controlAdapters, id)?.model?.key
),
[id]
);
const model = useAppSelector(selector);
const key = useAppSelector(selector);
return model;
const result = useGetModelConfigWithTypeGuard(key ?? skipToken, isControlAdapterModelConfig);
return result;
};

View File

@ -3,20 +3,14 @@ import { createEntityAdapter, createSlice, isAnyOf } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import type { PersistConfig, RootState } from 'app/store/store';
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
import type {
ParameterControlNetModel,
ParameterIPAdapterModel,
ParameterT2IAdapterModel,
} from 'features/parameters/types/parameterSchemas';
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
import { cloneDeep, merge, uniq } from 'lodash-es';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { socketInvocationError } from 'services/events/actions';
import { v4 as uuidv4 } from 'uuid';
import { controlAdapterImageProcessed } from './actions';
import {
CONTROLNET_MODEL_DEFAULT_PROCESSORS as CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS,
CONTROLNET_PROCESSORS,
} from './constants';
import { CONTROLNET_PROCESSORS } from './constants';
import type {
ControlAdapterConfig,
ControlAdapterProcessorType,
@ -194,15 +188,17 @@ export const controlAdaptersSlice = createSlice({
state,
action: PayloadAction<{
id: string;
model: ParameterControlNetModel | ParameterT2IAdapterModel | ParameterIPAdapterModel;
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig;
}>
) => {
const { id, model } = action.payload;
const { id, modelConfig } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn) {
return;
}
const model = { key: modelConfig.key, base: modelConfig.base };
if (!isControlNetOrT2IAdapter(cn)) {
caAdapter.updateOne(state, { id, changes: { model } });
return;
@ -215,24 +211,14 @@ export const controlAdaptersSlice = createSlice({
update.changes.processedControlImage = null;
let processorType: ControlAdapterProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
// TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType
if (model.key.includes(modelSubstring)) {
processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
if (modelConfig.type === 'ip_adapter') {
// should never happen...
return;
}
if (processorType) {
update.changes.processorType = processorType;
update.changes.processorNode = CONTROLNET_PROCESSORS[processorType]
.default as RequiredControlAdapterProcessorNode;
} else {
update.changes.processorType = 'none';
update.changes.processorNode = CONTROLNET_PROCESSORS.none.default as RequiredControlAdapterProcessorNode;
}
const processor = buildControlAdapterProcessor(modelConfig);
update.changes.processorType = processor.processorType;
update.changes.processorNode = processor.processorNode;
caAdapter.updateOne(state, update);
},
@ -324,39 +310,23 @@ export const controlAdaptersSlice = createSlice({
state,
action: PayloadAction<{
id: string;
modelConfig?: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig;
}>
) => {
const { id } = action.payload;
const { id, modelConfig } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn || !isControlNetOrT2IAdapter(cn)) {
if (!cn || !isControlNetOrT2IAdapter(cn) || modelConfig?.type === 'ip_adapter') {
return;
}
const update: Update<ControlNetConfig | T2IAdapterConfig, string> = {
id,
changes: { shouldAutoConfig: !cn.shouldAutoConfig },
};
if (update.changes.shouldAutoConfig) {
// manage the processor for the user
let processorType: ControlAdapterProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
// TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType
if (cn.model?.key.includes(modelSubstring)) {
processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) {
update.changes.processorType = processorType;
update.changes.processorNode = CONTROLNET_PROCESSORS[processorType]
.default as RequiredControlAdapterProcessorNode;
} else {
update.changes.processorType = 'none';
update.changes.processorNode = CONTROLNET_PROCESSORS.none.default as RequiredControlAdapterProcessorNode;
}
if (update.changes.shouldAutoConfig && modelConfig) {
const processor = buildControlAdapterProcessor(modelConfig);
update.changes.processorType = processor.processorType;
update.changes.processorNode = processor.processorNode;
}
caAdapter.updateOne(state, update);

View File

@ -0,0 +1,13 @@
import type {
ControlAdapterProcessorType,
zControlAdapterProcessorType,
} from 'features/controlAdapters/store/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { describe, test } from 'vitest';
import type { z } from 'zod';
describe('Control Adapter Types', () => {
test('ControlAdapterProcessorType', () =>
assert<Equals<ControlAdapterProcessorType, z.infer<typeof zControlAdapterProcessorType>>>());
});

View File

@ -47,6 +47,25 @@ export type ControlAdapterProcessorNode =
* Any ControlNet processor type
*/
export type ControlAdapterProcessorType = NonNullable<ControlAdapterProcessorNode['type'] | 'none'>;
export const zControlAdapterProcessorType = z.enum([
'canny_image_processor',
'color_map_image_processor',
'content_shuffle_image_processor',
'depth_anything_image_processor',
'hed_image_processor',
'lineart_anime_image_processor',
'lineart_image_processor',
'mediapipe_face_processor',
'midas_depth_image_processor',
'mlsd_image_processor',
'normalbae_image_processor',
'dw_openpose_image_processor',
'pidi_image_processor',
'zoe_depth_image_processor',
'none',
]);
export const isControlAdapterProcessorType = (v: unknown): v is ControlAdapterProcessorType =>
zControlAdapterProcessorType.safeParse(v).success;
/**
* The Canny processor node, with parameters flagged as required

View File

@ -0,0 +1,11 @@
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
export const buildControlAdapterProcessor = (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => {
const defaultPreprocessor = modelConfig.default_settings?.preprocessor;
const processorType = isControlAdapterProcessorType(defaultPreprocessor) ? defaultPreprocessor : 'none';
const processorNode = CONTROLNET_PROCESSORS[processorType].default;
return { processorType, processorNode };
};

View File

@ -1,4 +1,5 @@
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
import {
initialControlNet,
initialIPAdapter,
@ -253,8 +254,9 @@ const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (meta
.catch(null)
.parse(getProperty(metadataItem, 'resize_mode'));
const processorType = 'none';
const processorNode = CONTROLNET_PROCESSORS.none.default;
const defaultPreprocessor = controlNetModel.default_settings?.preprocessor;
const processorType = isControlAdapterProcessorType(defaultPreprocessor) ? defaultPreprocessor : 'none';
const processorNode = CONTROLNET_PROCESSORS[processorType].default;
const controlNet: ControlNetConfigMetadata = {
type: 'controlnet',
@ -305,8 +307,9 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (meta
.catch(null)
.parse(getProperty(metadataItem, 'resize_mode'));
const processorType = 'none';
const processorNode = CONTROLNET_PROCESSORS.none.default;
const defaultPreprocessor = t2iAdapterModel.default_settings?.preprocessor;
const processorType = isControlAdapterProcessorType(defaultPreprocessor) ? defaultPreprocessor : 'none';
const processorNode = CONTROLNET_PROCESSORS[processorType].default;
const t2iAdapter: T2IAdapterConfigMetadata = {
type: 't2i_adapter',

View File

@ -0,0 +1,20 @@
import { skipToken } from '@reduxjs/toolkit/query';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
export const useGetModelConfigWithTypeGuard = <T extends AnyModelConfig>(
key: string | typeof skipToken,
typeGuard: (config: AnyModelConfig) => config is T
) => {
const result = useGetModelConfigQuery(key ?? skipToken, {
selectFromResult: (result) => {
const modelConfig = result.data;
return {
...result,
modelConfig: modelConfig && typeGuard(modelConfig) ? modelConfig : undefined,
};
},
});
return result;
};

View File

@ -83,6 +83,12 @@ export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAd
return config.type === 't2i_adapter';
};
export const isControlAdapterModelConfig = (
config: AnyModelConfig
): config is ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig => {
return isControlNetModelConfig(config) || isT2IAdapterModelConfig(config) || isIPAdapterModelConfig(config);
};
export const isNonRefinerMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base !== 'sdxl-refiner';
};