mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix image processors to work with new cnet/t2i model format
This commit is contained in:
parent
ad70cdfe87
commit
4bbe6f3548
@ -55,6 +55,7 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
|
||||
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
|
||||
import { addControlAdapterAutoProcessorUpdateListener } from './listeners/controlAdapterAutoProcessorUpdateListener';
|
||||
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
@ -125,6 +126,7 @@ addBulkDownloadListeners(startAppListening);
|
||||
// ControlNet
|
||||
addControlNetImageProcessedListener(startAppListening);
|
||||
addControlNetAutoProcessListener(startAppListening);
|
||||
addControlAdapterAutoProcessorUpdateListener(startAppListening);
|
||||
|
||||
// Boards
|
||||
addImageAddedToBoardFulfilledListener(startAppListening);
|
||||
|
@ -0,0 +1,77 @@
|
||||
import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { controlAdapterModelChanged } from 'features/controlAdapters/store/actions';
|
||||
import { CONTROLNET_MODEL_DEFAULT_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||
import {
|
||||
controlAdapterAutoConfigToggled,
|
||||
controlAdapterProcessortTypeChanged,
|
||||
selectControlAdapterById,
|
||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import type { ControlAdapterProcessorType } from 'features/controlAdapters/store/types';
|
||||
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
|
||||
export const addControlAdapterAutoProcessorUpdateListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
matcher: isAnyOf(controlAdapterModelChanged, controlAdapterAutoConfigToggled),
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
let id;
|
||||
let model;
|
||||
|
||||
const state = getState();
|
||||
const { controlAdapters: controlAdaptersState } = state;
|
||||
|
||||
if (controlAdapterModelChanged.match(action)) {
|
||||
model = action.payload.model;
|
||||
id = action.payload.id;
|
||||
|
||||
const cn = selectControlAdapterById(controlAdaptersState, id);
|
||||
if (!cn) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!isControlNetOrT2IAdapter(cn)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (controlAdapterAutoConfigToggled.match(action)) {
|
||||
id = action.payload.id;
|
||||
|
||||
const cn = selectControlAdapterById(controlAdaptersState, id);
|
||||
if (!cn) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!isControlNetOrT2IAdapter(cn)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// if they turned off autoconfig, return
|
||||
if (!cn.shouldAutoConfig) {
|
||||
return;
|
||||
}
|
||||
|
||||
model = cn.model;
|
||||
}
|
||||
|
||||
if (!model || !id) {
|
||||
return;
|
||||
}
|
||||
|
||||
let processorType: ControlAdapterProcessorType | undefined = undefined;
|
||||
const { data: modelConfig } = modelsApi.endpoints.getModelConfig.select(model.key)(state);
|
||||
|
||||
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
||||
if (modelConfig?.name.includes(modelSubstring)) {
|
||||
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
dispatch(
|
||||
controlAdapterProcessortTypeChanged({ id, processorType: processorType || 'none', shouldAutoConfig: true })
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
@ -2,11 +2,10 @@ import type { AnyListenerPredicate } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { controlAdapterImageProcessed } from 'features/controlAdapters/store/actions';
|
||||
import { controlAdapterImageProcessed, controlAdapterModelChanged } from 'features/controlAdapters/store/actions';
|
||||
import {
|
||||
controlAdapterAutoConfigToggled,
|
||||
controlAdapterImageChanged,
|
||||
controlAdapterModelChanged,
|
||||
controlAdapterProcessorParamsChanged,
|
||||
controlAdapterProcessortTypeChanged,
|
||||
selectControlAdapterById,
|
||||
|
@ -5,7 +5,7 @@ import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useCo
|
||||
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
|
||||
import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery';
|
||||
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
|
||||
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { controlAdapterModelChanged } from 'features/controlAdapters/store/actions';
|
||||
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
@ -25,6 +25,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
||||
|
||||
const _onChange = useCallback(
|
||||
(model: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
|
||||
console.log('on change');
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
@ -1,6 +1,12 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { CONTROLNET_MODEL_DEFAULT_PROCESSORS, CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||
import { controlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
|
||||
import type {
|
||||
ControlAdapterProcessorType,
|
||||
ControlAdapterType,
|
||||
RequiredControlAdapterProcessorNode,
|
||||
} from 'features/controlAdapters/store/types';
|
||||
import { cloneDeep } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
|
||||
import { useControlAdapterModels } from './useControlAdapterModels';
|
||||
@ -22,6 +28,29 @@ export const useAddControlAdapter = (type: ControlAdapterType) => {
|
||||
return models[0];
|
||||
}, [baseModel, models]);
|
||||
|
||||
const processor = useMemo(() => {
|
||||
let processorType;
|
||||
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
||||
if (firstModel?.name.includes(modelSubstring)) {
|
||||
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!processorType) {
|
||||
processorType = 'none';
|
||||
}
|
||||
|
||||
const processorNode =
|
||||
processorType === 'none'
|
||||
? (cloneDeep(CONTROLNET_PROCESSORS.none.default) as RequiredControlAdapterProcessorNode)
|
||||
: (cloneDeep(
|
||||
CONTROLNET_PROCESSORS[processorType as ControlAdapterProcessorType].default
|
||||
) as RequiredControlAdapterProcessorNode);
|
||||
|
||||
return { processorType, processorNode };
|
||||
}, [firstModel]);
|
||||
|
||||
const isDisabled = useMemo(() => !firstModel, [firstModel]);
|
||||
|
||||
const addControlAdapter = useCallback(() => {
|
||||
@ -31,10 +60,10 @@ export const useAddControlAdapter = (type: ControlAdapterType) => {
|
||||
dispatch(
|
||||
controlAdapterAdded({
|
||||
type,
|
||||
overrides: { model: firstModel },
|
||||
overrides: { model: firstModel, ...processor },
|
||||
})
|
||||
);
|
||||
}, [dispatch, firstModel, isDisabled, type]);
|
||||
}, [dispatch, firstModel, isDisabled, type, processor]);
|
||||
|
||||
return [addControlAdapter, isDisabled] as const;
|
||||
};
|
||||
|
@ -1,5 +1,15 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import type {
|
||||
ParameterControlNetModel,
|
||||
ParameterIPAdapterModel,
|
||||
ParameterT2IAdapterModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
|
||||
export const controlAdapterImageProcessed = createAction<{
|
||||
id: string;
|
||||
}>('controlAdapters/imageProcessed');
|
||||
|
||||
export const controlAdapterModelChanged = createAction<{
|
||||
id: string;
|
||||
model: ParameterControlNetModel | ParameterT2IAdapterModel | ParameterIPAdapterModel;
|
||||
}>('controlAdapters/controlAdapterModelChanged');
|
||||
|
@ -3,20 +3,12 @@ import { createEntityAdapter, createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
|
||||
import type {
|
||||
ParameterControlNetModel,
|
||||
ParameterIPAdapterModel,
|
||||
ParameterT2IAdapterModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { cloneDeep, merge, uniq } from 'lodash-es';
|
||||
import { socketInvocationError } from 'services/events/actions';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
import { controlAdapterImageProcessed } from './actions';
|
||||
import {
|
||||
CONTROLNET_MODEL_DEFAULT_PROCESSORS as CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS,
|
||||
CONTROLNET_PROCESSORS,
|
||||
} from './constants';
|
||||
import { CONTROLNET_PROCESSORS } from './constants';
|
||||
import type {
|
||||
ControlAdapterConfig,
|
||||
ControlAdapterProcessorType,
|
||||
@ -190,52 +182,6 @@ export const controlAdaptersSlice = createSlice({
|
||||
changes: { model: null },
|
||||
});
|
||||
},
|
||||
controlAdapterModelChanged: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
id: string;
|
||||
model: ParameterControlNetModel | ParameterT2IAdapterModel | ParameterIPAdapterModel;
|
||||
}>
|
||||
) => {
|
||||
const { id, model } = action.payload;
|
||||
const cn = selectControlAdapterById(state, id);
|
||||
if (!cn) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!isControlNetOrT2IAdapter(cn)) {
|
||||
caAdapter.updateOne(state, { id, changes: { model } });
|
||||
return;
|
||||
}
|
||||
|
||||
const update: Update<ControlNetConfig | T2IAdapterConfig, string> = {
|
||||
id,
|
||||
changes: { model, shouldAutoConfig: true },
|
||||
};
|
||||
|
||||
update.changes.processedControlImage = null;
|
||||
|
||||
let processorType: ControlAdapterProcessorType | undefined = undefined;
|
||||
|
||||
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
|
||||
// TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType
|
||||
if (model.key.includes(modelSubstring)) {
|
||||
processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (processorType) {
|
||||
update.changes.processorType = processorType;
|
||||
update.changes.processorNode = CONTROLNET_PROCESSORS[processorType]
|
||||
.default as RequiredControlAdapterProcessorNode;
|
||||
} else {
|
||||
update.changes.processorType = 'none';
|
||||
update.changes.processorNode = CONTROLNET_PROCESSORS.none.default as RequiredControlAdapterProcessorNode;
|
||||
}
|
||||
|
||||
caAdapter.updateOne(state, update);
|
||||
},
|
||||
controlAdapterWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => {
|
||||
const { id, weight } = action.payload;
|
||||
caAdapter.updateOne(state, { id, changes: { weight } });
|
||||
@ -298,17 +244,19 @@ export const controlAdaptersSlice = createSlice({
|
||||
action: PayloadAction<{
|
||||
id: string;
|
||||
processorType: ControlAdapterProcessorType;
|
||||
shouldAutoConfig?: boolean;
|
||||
}>
|
||||
) => {
|
||||
const { id, processorType } = action.payload;
|
||||
const { id, processorType, shouldAutoConfig = false } = action.payload;
|
||||
const cn = selectControlAdapterById(state, id);
|
||||
if (!cn || !isControlNetOrT2IAdapter(cn)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const processorNode = cloneDeep(
|
||||
CONTROLNET_PROCESSORS[processorType].default
|
||||
) as RequiredControlAdapterProcessorNode;
|
||||
const processorNode =
|
||||
processorType === 'none'
|
||||
? (cloneDeep(CONTROLNET_PROCESSORS.none.default) as RequiredControlAdapterProcessorNode)
|
||||
: (cloneDeep(CONTROLNET_PROCESSORS[processorType].default) as RequiredControlAdapterProcessorNode);
|
||||
|
||||
caAdapter.updateOne(state, {
|
||||
id,
|
||||
@ -316,7 +264,7 @@ export const controlAdaptersSlice = createSlice({
|
||||
processorType,
|
||||
processedControlImage: null,
|
||||
processorNode,
|
||||
shouldAutoConfig: false,
|
||||
shouldAutoConfig,
|
||||
},
|
||||
});
|
||||
},
|
||||
@ -337,28 +285,6 @@ export const controlAdaptersSlice = createSlice({
|
||||
changes: { shouldAutoConfig: !cn.shouldAutoConfig },
|
||||
};
|
||||
|
||||
if (update.changes.shouldAutoConfig) {
|
||||
// manage the processor for the user
|
||||
let processorType: ControlAdapterProcessorType | undefined = undefined;
|
||||
|
||||
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
|
||||
// TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType
|
||||
if (cn.model?.key.includes(modelSubstring)) {
|
||||
processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (processorType) {
|
||||
update.changes.processorType = processorType;
|
||||
update.changes.processorNode = CONTROLNET_PROCESSORS[processorType]
|
||||
.default as RequiredControlAdapterProcessorNode;
|
||||
} else {
|
||||
update.changes.processorType = 'none';
|
||||
update.changes.processorNode = CONTROLNET_PROCESSORS.none.default as RequiredControlAdapterProcessorNode;
|
||||
}
|
||||
}
|
||||
|
||||
caAdapter.updateOne(state, update);
|
||||
},
|
||||
controlAdaptersReset: () => {
|
||||
@ -393,7 +319,6 @@ export const {
|
||||
controlAdapterImageChanged,
|
||||
controlAdapterProcessedImageChanged,
|
||||
controlAdapterIsEnabledChanged,
|
||||
controlAdapterModelChanged,
|
||||
controlAdapterWeightChanged,
|
||||
controlAdapterBeginStepPctChanged,
|
||||
controlAdapterEndStepPctChanged,
|
||||
|
Loading…
Reference in New Issue
Block a user