mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): select default control/ip adapter models in control layers
This commit is contained in:
parent
d14b315bc6
commit
c1666a8b5a
@ -1,6 +1,9 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||
import { controlAdapterAdded, controlAdapterRemoved } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import type { ControlNetConfig, IPAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import { isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
|
||||
import {
|
||||
controlAdapterLayerAdded,
|
||||
ipAdapterLayerAdded,
|
||||
@ -10,6 +13,8 @@ import {
|
||||
maskLayerIPAdapterDeleted,
|
||||
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import type { Layer } from 'features/regionalPrompts/store/types';
|
||||
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
|
||||
import { isControlNetModelConfig, isIPAdapterModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
@ -24,19 +29,52 @@ export const guidanceLayerIPAdapterDeleted = createAction<{ layerId: string; ipA
|
||||
export const addRegionalControlToControlAdapterBridge = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: guidanceLayerAdded,
|
||||
effect: (action, { dispatch }) => {
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const type = action.payload;
|
||||
const layerId = uuidv4();
|
||||
if (type === 'masked_guidance_layer') {
|
||||
dispatch(maskedGuidanceLayerAdded({ layerId }));
|
||||
return;
|
||||
}
|
||||
|
||||
const state = getState();
|
||||
const baseModel = state.generation.model?.base;
|
||||
const modelConfigs = modelsApi.endpoints.getModelConfigs.select(undefined)(state).data;
|
||||
|
||||
if (type === 'ip_adapter_layer') {
|
||||
const ipAdapterId = uuidv4();
|
||||
dispatch(controlAdapterAdded({ type: 'ip_adapter', overrides: { id: ipAdapterId } }));
|
||||
const overrides: Partial<IPAdapterConfig> = {
|
||||
id: ipAdapterId,
|
||||
};
|
||||
|
||||
// Find and select the first matching model
|
||||
if (modelConfigs) {
|
||||
const models = modelConfigsAdapterSelectors.selectAll(modelConfigs).filter(isIPAdapterModelConfig);
|
||||
overrides.model = models.find((m) => m.base === baseModel) ?? null;
|
||||
}
|
||||
dispatch(controlAdapterAdded({ type: 'ip_adapter', overrides }));
|
||||
dispatch(ipAdapterLayerAdded({ layerId, ipAdapterId }));
|
||||
} else if (type === 'control_adapter_layer') {
|
||||
return;
|
||||
}
|
||||
|
||||
if (type === 'control_adapter_layer') {
|
||||
const controlNetId = uuidv4();
|
||||
dispatch(controlAdapterAdded({ type: 'controlnet', overrides: { id: controlNetId } }));
|
||||
const overrides: Partial<ControlNetConfig> = {
|
||||
id: controlNetId,
|
||||
};
|
||||
|
||||
// Find and select the first matching model
|
||||
if (modelConfigs) {
|
||||
const models = modelConfigsAdapterSelectors.selectAll(modelConfigs).filter(isControlNetModelConfig);
|
||||
const model = models.find((m) => m.base === baseModel) ?? null;
|
||||
overrides.model = model;
|
||||
const defaultPreprocessor = model?.default_settings?.preprocessor;
|
||||
overrides.processorType = isControlAdapterProcessorType(defaultPreprocessor) ? defaultPreprocessor : 'none';
|
||||
overrides.processorNode = CONTROLNET_PROCESSORS[overrides.processorType].buildDefaults(baseModel);
|
||||
}
|
||||
dispatch(controlAdapterAdded({ type: 'controlnet', overrides }));
|
||||
dispatch(controlAdapterLayerAdded({ layerId, controlNetId }));
|
||||
} else if (type === 'masked_guidance_layer') {
|
||||
dispatch(maskedGuidanceLayerAdded({ layerId }));
|
||||
return;
|
||||
}
|
||||
},
|
||||
});
|
||||
@ -74,10 +112,23 @@ export const addRegionalControlToControlAdapterBridge = (startAppListening: AppS
|
||||
|
||||
startAppListening({
|
||||
actionCreator: guidanceLayerIPAdapterAdded,
|
||||
effect: (action, { dispatch }) => {
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const layerId = action.payload;
|
||||
const ipAdapterId = uuidv4();
|
||||
dispatch(controlAdapterAdded({ type: 'ip_adapter', overrides: { id: ipAdapterId } }));
|
||||
const overrides: Partial<IPAdapterConfig> = {
|
||||
id: ipAdapterId,
|
||||
};
|
||||
|
||||
// Find and select the first matching model
|
||||
const state = getState();
|
||||
const baseModel = state.generation.model?.base;
|
||||
const modelConfigs = modelsApi.endpoints.getModelConfigs.select(undefined)(state).data;
|
||||
if (modelConfigs) {
|
||||
const models = modelConfigsAdapterSelectors.selectAll(modelConfigs).filter(isIPAdapterModelConfig);
|
||||
overrides.model = models.find((m) => m.base === baseModel) ?? null;
|
||||
}
|
||||
|
||||
dispatch(controlAdapterAdded({ type: 'ip_adapter', overrides }));
|
||||
dispatch(maskLayerIPAdapterAdded({ layerId, ipAdapterId }));
|
||||
},
|
||||
});
|
||||
|
Loading…
Reference in New Issue
Block a user