feat(ui): use default settings for control adapters for processor

This commit is contained in:
psychedelicious 2024-03-08 19:45:46 +11:00 committed by Brandon
parent dbd7c94e7c
commit 53b7f6be37
11 changed files with 138 additions and 68 deletions

View File

@ -1,6 +1,7 @@
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library'; import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled'; import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
import { useControlAdapterShouldAutoConfig } from 'features/controlAdapters/hooks/useControlAdapterShouldAutoConfig'; import { useControlAdapterShouldAutoConfig } from 'features/controlAdapters/hooks/useControlAdapterShouldAutoConfig';
import { controlAdapterAutoConfigToggled } from 'features/controlAdapters/store/controlAdaptersSlice'; import { controlAdapterAutoConfigToggled } from 'features/controlAdapters/store/controlAdaptersSlice';
import { isNil } from 'lodash-es'; import { isNil } from 'lodash-es';
@ -14,12 +15,13 @@ type Props = {
const ControlAdapterShouldAutoConfig = ({ id }: Props) => { const ControlAdapterShouldAutoConfig = ({ id }: Props) => {
const isEnabled = useControlAdapterIsEnabled(id); const isEnabled = useControlAdapterIsEnabled(id);
const shouldAutoConfig = useControlAdapterShouldAutoConfig(id); const shouldAutoConfig = useControlAdapterShouldAutoConfig(id);
const { modelConfig } = useControlAdapterModel(id);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const handleShouldAutoConfigChanged = useCallback(() => { const handleShouldAutoConfigChanged = useCallback(() => {
dispatch(controlAdapterAutoConfigToggled({ id })); dispatch(controlAdapterAutoConfigToggled({ id, modelConfig }));
}, [id, dispatch]); }, [id, dispatch, modelConfig]);
if (isNil(shouldAutoConfig)) { if (isNil(shouldAutoConfig)) {
return null; return null;

View File

@ -6,7 +6,6 @@ import { useControlAdapterModel } from 'features/controlAdapters/hooks/useContro
import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery'; import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery';
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType'; import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice'; import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types'; import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
@ -17,21 +16,21 @@ type ParamControlAdapterModelProps = {
const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => { const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
const isEnabled = useControlAdapterIsEnabled(id); const isEnabled = useControlAdapterIsEnabled(id);
const controlAdapterType = useControlAdapterType(id); const controlAdapterType = useControlAdapterType(id);
const model = useControlAdapterModel(id); const { modelConfig } = useControlAdapterModel(id);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base); const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType); const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType);
const _onChange = useCallback( const _onChange = useCallback(
(model: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => { (modelConfig: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
if (!model) { if (!modelConfig) {
return; return;
} }
dispatch( dispatch(
controlAdapterModelChanged({ controlAdapterModelChanged({
id, id,
model: getModelKeyAndBase(model), modelConfig,
}) })
); );
}, },
@ -39,8 +38,8 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
); );
const selectedModel = useMemo( const selectedModel = useMemo(
() => (model && controlAdapterType ? { ...model, model_type: controlAdapterType } : null), () => (modelConfig && controlAdapterType ? { ...modelConfig, model_type: controlAdapterType } : null),
[controlAdapterType, model] [controlAdapterType, modelConfig]
); );
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({ const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({

View File

@ -1,7 +1,9 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { controlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice'; import { controlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlAdapterType } from 'features/controlAdapters/store/types'; import { type ControlAdapterType, isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { useControlAdapterModels } from './useControlAdapterModels'; import { useControlAdapterModels } from './useControlAdapterModels';
@ -11,7 +13,7 @@ export const useAddControlAdapter = (type: ControlAdapterType) => {
const models = useControlAdapterModels(type); const models = useControlAdapterModels(type);
const firstModel = useMemo(() => { const firstModel: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig | undefined = useMemo(() => {
// prefer to use a model that matches the base model // prefer to use a model that matches the base model
const firstCompatibleModel = models.filter((m) => (baseModel ? m.base === baseModel : true))[0]; const firstCompatibleModel = models.filter((m) => (baseModel ? m.base === baseModel : true))[0];
@ -28,6 +30,26 @@ export const useAddControlAdapter = (type: ControlAdapterType) => {
if (isDisabled) { if (isDisabled) {
return; return;
} }
if (
(type === 'controlnet' || type === 't2i_adapter') &&
(firstModel?.type === 'controlnet' || firstModel?.type === 't2i_adapter')
) {
const defaultPreprocessor = firstModel.default_settings?.preprocessor;
const processorType = isControlAdapterProcessorType(defaultPreprocessor) ? defaultPreprocessor : 'none';
const processorNode = CONTROLNET_PROCESSORS[processorType].default;
dispatch(
controlAdapterAdded({
type,
overrides: {
model: firstModel,
processorType,
processorNode,
},
})
);
return;
}
dispatch( dispatch(
controlAdapterAdded({ controlAdapterAdded({
type, type,

View File

@ -1,3 +1,4 @@
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { import {
@ -5,18 +6,22 @@ import {
selectControlAdaptersSlice, selectControlAdaptersSlice,
} from 'features/controlAdapters/store/controlAdaptersSlice'; } from 'features/controlAdapters/store/controlAdaptersSlice';
import { useMemo } from 'react'; import { useMemo } from 'react';
import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelConfigWithTypeGuard';
import { isControlAdapterModelConfig } from 'services/api/types';
export const useControlAdapterModel = (id: string) => { export const useControlAdapterModel = (id: string) => {
const selector = useMemo( const selector = useMemo(
() => () =>
createMemoizedSelector( createMemoizedSelector(
selectControlAdaptersSlice, selectControlAdaptersSlice,
(controlAdapters) => selectControlAdapterById(controlAdapters, id)?.model (controlAdapters) => selectControlAdapterById(controlAdapters, id)?.model?.key
), ),
[id] [id]
); );
const model = useAppSelector(selector); const key = useAppSelector(selector);
return model; const result = useGetModelConfigWithTypeGuard(key ?? skipToken, isControlAdapterModelConfig);
return result;
}; };

View File

@ -3,20 +3,14 @@ import { createEntityAdapter, createSlice, isAnyOf } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector'; import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import type { PersistConfig, RootState } from 'app/store/store'; import type { PersistConfig, RootState } from 'app/store/store';
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter'; import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
import type { import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
ParameterControlNetModel,
ParameterIPAdapterModel,
ParameterT2IAdapterModel,
} from 'features/parameters/types/parameterSchemas';
import { cloneDeep, merge, uniq } from 'lodash-es'; import { cloneDeep, merge, uniq } from 'lodash-es';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { socketInvocationError } from 'services/events/actions'; import { socketInvocationError } from 'services/events/actions';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
import { controlAdapterImageProcessed } from './actions'; import { controlAdapterImageProcessed } from './actions';
import { import { CONTROLNET_PROCESSORS } from './constants';
CONTROLNET_MODEL_DEFAULT_PROCESSORS as CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS,
CONTROLNET_PROCESSORS,
} from './constants';
import type { import type {
ControlAdapterConfig, ControlAdapterConfig,
ControlAdapterProcessorType, ControlAdapterProcessorType,
@ -194,15 +188,17 @@ export const controlAdaptersSlice = createSlice({
state, state,
action: PayloadAction<{ action: PayloadAction<{
id: string; id: string;
model: ParameterControlNetModel | ParameterT2IAdapterModel | ParameterIPAdapterModel; modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig;
}> }>
) => { ) => {
const { id, model } = action.payload; const { id, modelConfig } = action.payload;
const cn = selectControlAdapterById(state, id); const cn = selectControlAdapterById(state, id);
if (!cn) { if (!cn) {
return; return;
} }
const model = { key: modelConfig.key, base: modelConfig.base };
if (!isControlNetOrT2IAdapter(cn)) { if (!isControlNetOrT2IAdapter(cn)) {
caAdapter.updateOne(state, { id, changes: { model } }); caAdapter.updateOne(state, { id, changes: { model } });
return; return;
@ -215,24 +211,14 @@ export const controlAdaptersSlice = createSlice({
update.changes.processedControlImage = null; update.changes.processedControlImage = null;
let processorType: ControlAdapterProcessorType | undefined = undefined; if (modelConfig.type === 'ip_adapter') {
// should never happen...
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) { return;
// TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType
if (model.key.includes(modelSubstring)) {
processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
} }
if (processorType) { const processor = buildControlAdapterProcessor(modelConfig);
update.changes.processorType = processorType; update.changes.processorType = processor.processorType;
update.changes.processorNode = CONTROLNET_PROCESSORS[processorType] update.changes.processorNode = processor.processorNode;
.default as RequiredControlAdapterProcessorNode;
} else {
update.changes.processorType = 'none';
update.changes.processorNode = CONTROLNET_PROCESSORS.none.default as RequiredControlAdapterProcessorNode;
}
caAdapter.updateOne(state, update); caAdapter.updateOne(state, update);
}, },
@ -324,39 +310,23 @@ export const controlAdaptersSlice = createSlice({
state, state,
action: PayloadAction<{ action: PayloadAction<{
id: string; id: string;
modelConfig?: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig;
}> }>
) => { ) => {
const { id } = action.payload; const { id, modelConfig } = action.payload;
const cn = selectControlAdapterById(state, id); const cn = selectControlAdapterById(state, id);
if (!cn || !isControlNetOrT2IAdapter(cn)) { if (!cn || !isControlNetOrT2IAdapter(cn) || modelConfig?.type === 'ip_adapter') {
return; return;
} }
const update: Update<ControlNetConfig | T2IAdapterConfig, string> = { const update: Update<ControlNetConfig | T2IAdapterConfig, string> = {
id, id,
changes: { shouldAutoConfig: !cn.shouldAutoConfig }, changes: { shouldAutoConfig: !cn.shouldAutoConfig },
}; };
if (update.changes.shouldAutoConfig) { if (update.changes.shouldAutoConfig && modelConfig) {
// manage the processor for the user const processor = buildControlAdapterProcessor(modelConfig);
let processorType: ControlAdapterProcessorType | undefined = undefined; update.changes.processorType = processor.processorType;
update.changes.processorNode = processor.processorNode;
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
// TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType
if (cn.model?.key.includes(modelSubstring)) {
processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) {
update.changes.processorType = processorType;
update.changes.processorNode = CONTROLNET_PROCESSORS[processorType]
.default as RequiredControlAdapterProcessorNode;
} else {
update.changes.processorType = 'none';
update.changes.processorNode = CONTROLNET_PROCESSORS.none.default as RequiredControlAdapterProcessorNode;
}
} }
caAdapter.updateOne(state, update); caAdapter.updateOne(state, update);

View File

@ -0,0 +1,13 @@
import type {
ControlAdapterProcessorType,
zControlAdapterProcessorType,
} from 'features/controlAdapters/store/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { describe, test } from 'vitest';
import type { z } from 'zod';
describe('Control Adapter Types', () => {
test('ControlAdapterProcessorType', () =>
assert<Equals<ControlAdapterProcessorType, z.infer<typeof zControlAdapterProcessorType>>>());
});

View File

@ -47,6 +47,25 @@ export type ControlAdapterProcessorNode =
* Any ControlNet processor type * Any ControlNet processor type
*/ */
export type ControlAdapterProcessorType = NonNullable<ControlAdapterProcessorNode['type'] | 'none'>; export type ControlAdapterProcessorType = NonNullable<ControlAdapterProcessorNode['type'] | 'none'>;
export const zControlAdapterProcessorType = z.enum([
'canny_image_processor',
'color_map_image_processor',
'content_shuffle_image_processor',
'depth_anything_image_processor',
'hed_image_processor',
'lineart_anime_image_processor',
'lineart_image_processor',
'mediapipe_face_processor',
'midas_depth_image_processor',
'mlsd_image_processor',
'normalbae_image_processor',
'dw_openpose_image_processor',
'pidi_image_processor',
'zoe_depth_image_processor',
'none',
]);
export const isControlAdapterProcessorType = (v: unknown): v is ControlAdapterProcessorType =>
zControlAdapterProcessorType.safeParse(v).success;
/** /**
* The Canny processor node, with parameters flagged as required * The Canny processor node, with parameters flagged as required

View File

@ -0,0 +1,11 @@
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
export const buildControlAdapterProcessor = (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => {
const defaultPreprocessor = modelConfig.default_settings?.preprocessor;
const processorType = isControlAdapterProcessorType(defaultPreprocessor) ? defaultPreprocessor : 'none';
const processorNode = CONTROLNET_PROCESSORS[processorType].default;
return { processorType, processorNode };
};

View File

@ -1,4 +1,5 @@
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
import { import {
initialControlNet, initialControlNet,
initialIPAdapter, initialIPAdapter,
@ -253,8 +254,9 @@ const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (meta
.catch(null) .catch(null)
.parse(getProperty(metadataItem, 'resize_mode')); .parse(getProperty(metadataItem, 'resize_mode'));
const processorType = 'none'; const defaultPreprocessor = controlNetModel.default_settings?.preprocessor;
const processorNode = CONTROLNET_PROCESSORS.none.default; const processorType = isControlAdapterProcessorType(defaultPreprocessor) ? defaultPreprocessor : 'none';
const processorNode = CONTROLNET_PROCESSORS[processorType].default;
const controlNet: ControlNetConfigMetadata = { const controlNet: ControlNetConfigMetadata = {
type: 'controlnet', type: 'controlnet',
@ -305,8 +307,9 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (meta
.catch(null) .catch(null)
.parse(getProperty(metadataItem, 'resize_mode')); .parse(getProperty(metadataItem, 'resize_mode'));
const processorType = 'none'; const defaultPreprocessor = t2iAdapterModel.default_settings?.preprocessor;
const processorNode = CONTROLNET_PROCESSORS.none.default; const processorType = isControlAdapterProcessorType(defaultPreprocessor) ? defaultPreprocessor : 'none';
const processorNode = CONTROLNET_PROCESSORS[processorType].default;
const t2iAdapter: T2IAdapterConfigMetadata = { const t2iAdapter: T2IAdapterConfigMetadata = {
type: 't2i_adapter', type: 't2i_adapter',

View File

@ -0,0 +1,20 @@
import { skipToken } from '@reduxjs/toolkit/query';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
export const useGetModelConfigWithTypeGuard = <T extends AnyModelConfig>(
key: string | typeof skipToken,
typeGuard: (config: AnyModelConfig) => config is T
) => {
const result = useGetModelConfigQuery(key ?? skipToken, {
selectFromResult: (result) => {
const modelConfig = result.data;
return {
...result,
modelConfig: modelConfig && typeGuard(modelConfig) ? modelConfig : undefined,
};
},
});
return result;
};

View File

@ -83,6 +83,12 @@ export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAd
return config.type === 't2i_adapter'; return config.type === 't2i_adapter';
}; };
export const isControlAdapterModelConfig = (
config: AnyModelConfig
): config is ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig => {
return isControlNetModelConfig(config) || isT2IAdapterModelConfig(config) || isIPAdapterModelConfig(config);
};
export const isNonRefinerMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => { export const isNonRefinerMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base !== 'sdxl-refiner'; return config.type === 'main' && config.base !== 'sdxl-refiner';
}; };