mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): use default settings for control adapters for processor
This commit is contained in:
parent
dbd7c94e7c
commit
53b7f6be37
@ -1,6 +1,7 @@
|
||||
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
|
||||
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
|
||||
import { useControlAdapterShouldAutoConfig } from 'features/controlAdapters/hooks/useControlAdapterShouldAutoConfig';
|
||||
import { controlAdapterAutoConfigToggled } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { isNil } from 'lodash-es';
|
||||
@ -14,12 +15,13 @@ type Props = {
|
||||
const ControlAdapterShouldAutoConfig = ({ id }: Props) => {
|
||||
const isEnabled = useControlAdapterIsEnabled(id);
|
||||
const shouldAutoConfig = useControlAdapterShouldAutoConfig(id);
|
||||
const { modelConfig } = useControlAdapterModel(id);
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleShouldAutoConfigChanged = useCallback(() => {
|
||||
dispatch(controlAdapterAutoConfigToggled({ id }));
|
||||
}, [id, dispatch]);
|
||||
dispatch(controlAdapterAutoConfigToggled({ id, modelConfig }));
|
||||
}, [id, dispatch, modelConfig]);
|
||||
|
||||
if (isNil(shouldAutoConfig)) {
|
||||
return null;
|
||||
|
@ -6,7 +6,6 @@ import { useControlAdapterModel } from 'features/controlAdapters/hooks/useContro
|
||||
import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery';
|
||||
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
|
||||
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
@ -17,21 +16,21 @@ type ParamControlAdapterModelProps = {
|
||||
const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
||||
const isEnabled = useControlAdapterIsEnabled(id);
|
||||
const controlAdapterType = useControlAdapterType(id);
|
||||
const model = useControlAdapterModel(id);
|
||||
const { modelConfig } = useControlAdapterModel(id);
|
||||
const dispatch = useAppDispatch();
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
|
||||
const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType);
|
||||
|
||||
const _onChange = useCallback(
|
||||
(model: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
|
||||
if (!model) {
|
||||
(modelConfig: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
|
||||
if (!modelConfig) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
controlAdapterModelChanged({
|
||||
id,
|
||||
model: getModelKeyAndBase(model),
|
||||
modelConfig,
|
||||
})
|
||||
);
|
||||
},
|
||||
@ -39,8 +38,8 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
||||
);
|
||||
|
||||
const selectedModel = useMemo(
|
||||
() => (model && controlAdapterType ? { ...model, model_type: controlAdapterType } : null),
|
||||
[controlAdapterType, model]
|
||||
() => (modelConfig && controlAdapterType ? { ...modelConfig, model_type: controlAdapterType } : null),
|
||||
[controlAdapterType, modelConfig]
|
||||
);
|
||||
|
||||
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
|
||||
|
@ -1,7 +1,9 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||
import { controlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
|
||||
import { type ControlAdapterType, isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
import { useControlAdapterModels } from './useControlAdapterModels';
|
||||
|
||||
@ -11,7 +13,7 @@ export const useAddControlAdapter = (type: ControlAdapterType) => {
|
||||
|
||||
const models = useControlAdapterModels(type);
|
||||
|
||||
const firstModel = useMemo(() => {
|
||||
const firstModel: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig | undefined = useMemo(() => {
|
||||
// prefer to use a model that matches the base model
|
||||
const firstCompatibleModel = models.filter((m) => (baseModel ? m.base === baseModel : true))[0];
|
||||
|
||||
@ -28,6 +30,26 @@ export const useAddControlAdapter = (type: ControlAdapterType) => {
|
||||
if (isDisabled) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (
|
||||
(type === 'controlnet' || type === 't2i_adapter') &&
|
||||
(firstModel?.type === 'controlnet' || firstModel?.type === 't2i_adapter')
|
||||
) {
|
||||
const defaultPreprocessor = firstModel.default_settings?.preprocessor;
|
||||
const processorType = isControlAdapterProcessorType(defaultPreprocessor) ? defaultPreprocessor : 'none';
|
||||
const processorNode = CONTROLNET_PROCESSORS[processorType].default;
|
||||
dispatch(
|
||||
controlAdapterAdded({
|
||||
type,
|
||||
overrides: {
|
||||
model: firstModel,
|
||||
processorType,
|
||||
processorNode,
|
||||
},
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
controlAdapterAdded({
|
||||
type,
|
||||
|
@ -1,3 +1,4 @@
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
@ -5,18 +6,22 @@ import {
|
||||
selectControlAdaptersSlice,
|
||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { useMemo } from 'react';
|
||||
import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelConfigWithTypeGuard';
|
||||
import { isControlAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
export const useControlAdapterModel = (id: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(
|
||||
selectControlAdaptersSlice,
|
||||
(controlAdapters) => selectControlAdapterById(controlAdapters, id)?.model
|
||||
(controlAdapters) => selectControlAdapterById(controlAdapters, id)?.model?.key
|
||||
),
|
||||
[id]
|
||||
);
|
||||
|
||||
const model = useAppSelector(selector);
|
||||
const key = useAppSelector(selector);
|
||||
|
||||
return model;
|
||||
const result = useGetModelConfigWithTypeGuard(key ?? skipToken, isControlAdapterModelConfig);
|
||||
|
||||
return result;
|
||||
};
|
||||
|
@ -3,20 +3,14 @@ import { createEntityAdapter, createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
|
||||
import type {
|
||||
ParameterControlNetModel,
|
||||
ParameterIPAdapterModel,
|
||||
ParameterT2IAdapterModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
|
||||
import { cloneDeep, merge, uniq } from 'lodash-es';
|
||||
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import { socketInvocationError } from 'services/events/actions';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
import { controlAdapterImageProcessed } from './actions';
|
||||
import {
|
||||
CONTROLNET_MODEL_DEFAULT_PROCESSORS as CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS,
|
||||
CONTROLNET_PROCESSORS,
|
||||
} from './constants';
|
||||
import { CONTROLNET_PROCESSORS } from './constants';
|
||||
import type {
|
||||
ControlAdapterConfig,
|
||||
ControlAdapterProcessorType,
|
||||
@ -194,15 +188,17 @@ export const controlAdaptersSlice = createSlice({
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
id: string;
|
||||
model: ParameterControlNetModel | ParameterT2IAdapterModel | ParameterIPAdapterModel;
|
||||
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig;
|
||||
}>
|
||||
) => {
|
||||
const { id, model } = action.payload;
|
||||
const { id, modelConfig } = action.payload;
|
||||
const cn = selectControlAdapterById(state, id);
|
||||
if (!cn) {
|
||||
return;
|
||||
}
|
||||
|
||||
const model = { key: modelConfig.key, base: modelConfig.base };
|
||||
|
||||
if (!isControlNetOrT2IAdapter(cn)) {
|
||||
caAdapter.updateOne(state, { id, changes: { model } });
|
||||
return;
|
||||
@ -215,24 +211,14 @@ export const controlAdaptersSlice = createSlice({
|
||||
|
||||
update.changes.processedControlImage = null;
|
||||
|
||||
let processorType: ControlAdapterProcessorType | undefined = undefined;
|
||||
|
||||
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
|
||||
// TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType
|
||||
if (model.key.includes(modelSubstring)) {
|
||||
processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||
break;
|
||||
}
|
||||
if (modelConfig.type === 'ip_adapter') {
|
||||
// should never happen...
|
||||
return;
|
||||
}
|
||||
|
||||
if (processorType) {
|
||||
update.changes.processorType = processorType;
|
||||
update.changes.processorNode = CONTROLNET_PROCESSORS[processorType]
|
||||
.default as RequiredControlAdapterProcessorNode;
|
||||
} else {
|
||||
update.changes.processorType = 'none';
|
||||
update.changes.processorNode = CONTROLNET_PROCESSORS.none.default as RequiredControlAdapterProcessorNode;
|
||||
}
|
||||
const processor = buildControlAdapterProcessor(modelConfig);
|
||||
update.changes.processorType = processor.processorType;
|
||||
update.changes.processorNode = processor.processorNode;
|
||||
|
||||
caAdapter.updateOne(state, update);
|
||||
},
|
||||
@ -324,39 +310,23 @@ export const controlAdaptersSlice = createSlice({
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
id: string;
|
||||
modelConfig?: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig;
|
||||
}>
|
||||
) => {
|
||||
const { id } = action.payload;
|
||||
const { id, modelConfig } = action.payload;
|
||||
const cn = selectControlAdapterById(state, id);
|
||||
if (!cn || !isControlNetOrT2IAdapter(cn)) {
|
||||
if (!cn || !isControlNetOrT2IAdapter(cn) || modelConfig?.type === 'ip_adapter') {
|
||||
return;
|
||||
}
|
||||
|
||||
const update: Update<ControlNetConfig | T2IAdapterConfig, string> = {
|
||||
id,
|
||||
changes: { shouldAutoConfig: !cn.shouldAutoConfig },
|
||||
};
|
||||
|
||||
if (update.changes.shouldAutoConfig) {
|
||||
// manage the processor for the user
|
||||
let processorType: ControlAdapterProcessorType | undefined = undefined;
|
||||
|
||||
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
|
||||
// TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType
|
||||
if (cn.model?.key.includes(modelSubstring)) {
|
||||
processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (processorType) {
|
||||
update.changes.processorType = processorType;
|
||||
update.changes.processorNode = CONTROLNET_PROCESSORS[processorType]
|
||||
.default as RequiredControlAdapterProcessorNode;
|
||||
} else {
|
||||
update.changes.processorType = 'none';
|
||||
update.changes.processorNode = CONTROLNET_PROCESSORS.none.default as RequiredControlAdapterProcessorNode;
|
||||
}
|
||||
if (update.changes.shouldAutoConfig && modelConfig) {
|
||||
const processor = buildControlAdapterProcessor(modelConfig);
|
||||
update.changes.processorType = processor.processorType;
|
||||
update.changes.processorNode = processor.processorNode;
|
||||
}
|
||||
|
||||
caAdapter.updateOne(state, update);
|
||||
|
@ -0,0 +1,13 @@
|
||||
import type {
|
||||
ControlAdapterProcessorType,
|
||||
zControlAdapterProcessorType,
|
||||
} from 'features/controlAdapters/store/types';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
import { describe, test } from 'vitest';
|
||||
import type { z } from 'zod';
|
||||
|
||||
describe('Control Adapter Types', () => {
|
||||
test('ControlAdapterProcessorType', () =>
|
||||
assert<Equals<ControlAdapterProcessorType, z.infer<typeof zControlAdapterProcessorType>>>());
|
||||
});
|
@ -47,6 +47,25 @@ export type ControlAdapterProcessorNode =
|
||||
* Any ControlNet processor type
|
||||
*/
|
||||
export type ControlAdapterProcessorType = NonNullable<ControlAdapterProcessorNode['type'] | 'none'>;
|
||||
export const zControlAdapterProcessorType = z.enum([
|
||||
'canny_image_processor',
|
||||
'color_map_image_processor',
|
||||
'content_shuffle_image_processor',
|
||||
'depth_anything_image_processor',
|
||||
'hed_image_processor',
|
||||
'lineart_anime_image_processor',
|
||||
'lineart_image_processor',
|
||||
'mediapipe_face_processor',
|
||||
'midas_depth_image_processor',
|
||||
'mlsd_image_processor',
|
||||
'normalbae_image_processor',
|
||||
'dw_openpose_image_processor',
|
||||
'pidi_image_processor',
|
||||
'zoe_depth_image_processor',
|
||||
'none',
|
||||
]);
|
||||
export const isControlAdapterProcessorType = (v: unknown): v is ControlAdapterProcessorType =>
|
||||
zControlAdapterProcessorType.safeParse(v).success;
|
||||
|
||||
/**
|
||||
* The Canny processor node, with parameters flagged as required
|
||||
|
@ -0,0 +1,11 @@
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||
import { isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
|
||||
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
export const buildControlAdapterProcessor = (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => {
|
||||
const defaultPreprocessor = modelConfig.default_settings?.preprocessor;
|
||||
const processorType = isControlAdapterProcessorType(defaultPreprocessor) ? defaultPreprocessor : 'none';
|
||||
const processorNode = CONTROLNET_PROCESSORS[processorType].default;
|
||||
|
||||
return { processorType, processorNode };
|
||||
};
|
@ -1,4 +1,5 @@
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||
import { isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
|
||||
import {
|
||||
initialControlNet,
|
||||
initialIPAdapter,
|
||||
@ -253,8 +254,9 @@ const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (meta
|
||||
.catch(null)
|
||||
.parse(getProperty(metadataItem, 'resize_mode'));
|
||||
|
||||
const processorType = 'none';
|
||||
const processorNode = CONTROLNET_PROCESSORS.none.default;
|
||||
const defaultPreprocessor = controlNetModel.default_settings?.preprocessor;
|
||||
const processorType = isControlAdapterProcessorType(defaultPreprocessor) ? defaultPreprocessor : 'none';
|
||||
const processorNode = CONTROLNET_PROCESSORS[processorType].default;
|
||||
|
||||
const controlNet: ControlNetConfigMetadata = {
|
||||
type: 'controlnet',
|
||||
@ -305,8 +307,9 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (meta
|
||||
.catch(null)
|
||||
.parse(getProperty(metadataItem, 'resize_mode'));
|
||||
|
||||
const processorType = 'none';
|
||||
const processorNode = CONTROLNET_PROCESSORS.none.default;
|
||||
const defaultPreprocessor = t2iAdapterModel.default_settings?.preprocessor;
|
||||
const processorType = isControlAdapterProcessorType(defaultPreprocessor) ? defaultPreprocessor : 'none';
|
||||
const processorNode = CONTROLNET_PROCESSORS[processorType].default;
|
||||
|
||||
const t2iAdapter: T2IAdapterConfigMetadata = {
|
||||
type: 't2i_adapter',
|
||||
|
@ -0,0 +1,20 @@
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
export const useGetModelConfigWithTypeGuard = <T extends AnyModelConfig>(
|
||||
key: string | typeof skipToken,
|
||||
typeGuard: (config: AnyModelConfig) => config is T
|
||||
) => {
|
||||
const result = useGetModelConfigQuery(key ?? skipToken, {
|
||||
selectFromResult: (result) => {
|
||||
const modelConfig = result.data;
|
||||
return {
|
||||
...result,
|
||||
modelConfig: modelConfig && typeGuard(modelConfig) ? modelConfig : undefined,
|
||||
};
|
||||
},
|
||||
});
|
||||
|
||||
return result;
|
||||
};
|
@ -83,6 +83,12 @@ export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAd
|
||||
return config.type === 't2i_adapter';
|
||||
};
|
||||
|
||||
export const isControlAdapterModelConfig = (
|
||||
config: AnyModelConfig
|
||||
): config is ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig => {
|
||||
return isControlNetModelConfig(config) || isT2IAdapterModelConfig(config) || isIPAdapterModelConfig(config);
|
||||
};
|
||||
|
||||
export const isNonRefinerMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||
return config.type === 'main' && config.base !== 'sdxl-refiner';
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user