mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): select default control/ip adapter models in control layers
This commit is contained in:
parent
d14b315bc6
commit
c1666a8b5a
@ -1,6 +1,9 @@
|
|||||||
import { createAction } from '@reduxjs/toolkit';
|
import { createAction } from '@reduxjs/toolkit';
|
||||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||||
|
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||||
import { controlAdapterAdded, controlAdapterRemoved } from 'features/controlAdapters/store/controlAdaptersSlice';
|
import { controlAdapterAdded, controlAdapterRemoved } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
|
import type { ControlNetConfig, IPAdapterConfig } from 'features/controlAdapters/store/types';
|
||||||
|
import { isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
|
||||||
import {
|
import {
|
||||||
controlAdapterLayerAdded,
|
controlAdapterLayerAdded,
|
||||||
ipAdapterLayerAdded,
|
ipAdapterLayerAdded,
|
||||||
@ -10,6 +13,8 @@ import {
|
|||||||
maskLayerIPAdapterDeleted,
|
maskLayerIPAdapterDeleted,
|
||||||
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
import type { Layer } from 'features/regionalPrompts/store/types';
|
import type { Layer } from 'features/regionalPrompts/store/types';
|
||||||
|
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
|
||||||
|
import { isControlNetModelConfig, isIPAdapterModelConfig } from 'services/api/types';
|
||||||
import { assert } from 'tsafe';
|
import { assert } from 'tsafe';
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
|
||||||
@ -24,19 +29,52 @@ export const guidanceLayerIPAdapterDeleted = createAction<{ layerId: string; ipA
|
|||||||
export const addRegionalControlToControlAdapterBridge = (startAppListening: AppStartListening) => {
|
export const addRegionalControlToControlAdapterBridge = (startAppListening: AppStartListening) => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: guidanceLayerAdded,
|
actionCreator: guidanceLayerAdded,
|
||||||
effect: (action, { dispatch }) => {
|
effect: (action, { dispatch, getState }) => {
|
||||||
const type = action.payload;
|
const type = action.payload;
|
||||||
const layerId = uuidv4();
|
const layerId = uuidv4();
|
||||||
|
if (type === 'masked_guidance_layer') {
|
||||||
|
dispatch(maskedGuidanceLayerAdded({ layerId }));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const state = getState();
|
||||||
|
const baseModel = state.generation.model?.base;
|
||||||
|
const modelConfigs = modelsApi.endpoints.getModelConfigs.select(undefined)(state).data;
|
||||||
|
|
||||||
if (type === 'ip_adapter_layer') {
|
if (type === 'ip_adapter_layer') {
|
||||||
const ipAdapterId = uuidv4();
|
const ipAdapterId = uuidv4();
|
||||||
dispatch(controlAdapterAdded({ type: 'ip_adapter', overrides: { id: ipAdapterId } }));
|
const overrides: Partial<IPAdapterConfig> = {
|
||||||
|
id: ipAdapterId,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Find and select the first matching model
|
||||||
|
if (modelConfigs) {
|
||||||
|
const models = modelConfigsAdapterSelectors.selectAll(modelConfigs).filter(isIPAdapterModelConfig);
|
||||||
|
overrides.model = models.find((m) => m.base === baseModel) ?? null;
|
||||||
|
}
|
||||||
|
dispatch(controlAdapterAdded({ type: 'ip_adapter', overrides }));
|
||||||
dispatch(ipAdapterLayerAdded({ layerId, ipAdapterId }));
|
dispatch(ipAdapterLayerAdded({ layerId, ipAdapterId }));
|
||||||
} else if (type === 'control_adapter_layer') {
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type === 'control_adapter_layer') {
|
||||||
const controlNetId = uuidv4();
|
const controlNetId = uuidv4();
|
||||||
dispatch(controlAdapterAdded({ type: 'controlnet', overrides: { id: controlNetId } }));
|
const overrides: Partial<ControlNetConfig> = {
|
||||||
|
id: controlNetId,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Find and select the first matching model
|
||||||
|
if (modelConfigs) {
|
||||||
|
const models = modelConfigsAdapterSelectors.selectAll(modelConfigs).filter(isControlNetModelConfig);
|
||||||
|
const model = models.find((m) => m.base === baseModel) ?? null;
|
||||||
|
overrides.model = model;
|
||||||
|
const defaultPreprocessor = model?.default_settings?.preprocessor;
|
||||||
|
overrides.processorType = isControlAdapterProcessorType(defaultPreprocessor) ? defaultPreprocessor : 'none';
|
||||||
|
overrides.processorNode = CONTROLNET_PROCESSORS[overrides.processorType].buildDefaults(baseModel);
|
||||||
|
}
|
||||||
|
dispatch(controlAdapterAdded({ type: 'controlnet', overrides }));
|
||||||
dispatch(controlAdapterLayerAdded({ layerId, controlNetId }));
|
dispatch(controlAdapterLayerAdded({ layerId, controlNetId }));
|
||||||
} else if (type === 'masked_guidance_layer') {
|
return;
|
||||||
dispatch(maskedGuidanceLayerAdded({ layerId }));
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
@ -74,10 +112,23 @@ export const addRegionalControlToControlAdapterBridge = (startAppListening: AppS
|
|||||||
|
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: guidanceLayerIPAdapterAdded,
|
actionCreator: guidanceLayerIPAdapterAdded,
|
||||||
effect: (action, { dispatch }) => {
|
effect: (action, { dispatch, getState }) => {
|
||||||
const layerId = action.payload;
|
const layerId = action.payload;
|
||||||
const ipAdapterId = uuidv4();
|
const ipAdapterId = uuidv4();
|
||||||
dispatch(controlAdapterAdded({ type: 'ip_adapter', overrides: { id: ipAdapterId } }));
|
const overrides: Partial<IPAdapterConfig> = {
|
||||||
|
id: ipAdapterId,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Find and select the first matching model
|
||||||
|
const state = getState();
|
||||||
|
const baseModel = state.generation.model?.base;
|
||||||
|
const modelConfigs = modelsApi.endpoints.getModelConfigs.select(undefined)(state).data;
|
||||||
|
if (modelConfigs) {
|
||||||
|
const models = modelConfigsAdapterSelectors.selectAll(modelConfigs).filter(isIPAdapterModelConfig);
|
||||||
|
overrides.model = models.find((m) => m.base === baseModel) ?? null;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(controlAdapterAdded({ type: 'ip_adapter', overrides }));
|
||||||
dispatch(maskLayerIPAdapterAdded({ layerId, ipAdapterId }));
|
dispatch(maskLayerIPAdapterAdded({ layerId, ipAdapterId }));
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
Loading…
Reference in New Issue
Block a user