fix(ui): select default control/ip adapter models in control layers

This commit is contained in:
psychedelicious 2024-04-30 10:37:45 +10:00 committed by Kent Keirsey
parent d14b315bc6
commit c1666a8b5a

View File

@ -1,6 +1,9 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { controlAdapterAdded, controlAdapterRemoved } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlNetConfig, IPAdapterConfig } from 'features/controlAdapters/store/types';
import { isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
import {
controlAdapterLayerAdded,
ipAdapterLayerAdded,
@ -10,6 +13,8 @@ import {
maskLayerIPAdapterDeleted,
} from 'features/regionalPrompts/store/regionalPromptsSlice';
import type { Layer } from 'features/regionalPrompts/store/types';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import { isControlNetModelConfig, isIPAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid';
@ -24,19 +29,52 @@ export const guidanceLayerIPAdapterDeleted = createAction<{ layerId: string; ipA
export const addRegionalControlToControlAdapterBridge = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: guidanceLayerAdded,
effect: (action, { dispatch }) => {
effect: (action, { dispatch, getState }) => {
const type = action.payload;
const layerId = uuidv4();
if (type === 'masked_guidance_layer') {
dispatch(maskedGuidanceLayerAdded({ layerId }));
return;
}
const state = getState();
const baseModel = state.generation.model?.base;
const modelConfigs = modelsApi.endpoints.getModelConfigs.select(undefined)(state).data;
if (type === 'ip_adapter_layer') {
const ipAdapterId = uuidv4();
dispatch(controlAdapterAdded({ type: 'ip_adapter', overrides: { id: ipAdapterId } }));
const overrides: Partial<IPAdapterConfig> = {
id: ipAdapterId,
};
// Find and select the first matching model
if (modelConfigs) {
const models = modelConfigsAdapterSelectors.selectAll(modelConfigs).filter(isIPAdapterModelConfig);
overrides.model = models.find((m) => m.base === baseModel) ?? null;
}
dispatch(controlAdapterAdded({ type: 'ip_adapter', overrides }));
dispatch(ipAdapterLayerAdded({ layerId, ipAdapterId }));
} else if (type === 'control_adapter_layer') {
return;
}
if (type === 'control_adapter_layer') {
const controlNetId = uuidv4();
dispatch(controlAdapterAdded({ type: 'controlnet', overrides: { id: controlNetId } }));
const overrides: Partial<ControlNetConfig> = {
id: controlNetId,
};
// Find and select the first matching model
if (modelConfigs) {
const models = modelConfigsAdapterSelectors.selectAll(modelConfigs).filter(isControlNetModelConfig);
const model = models.find((m) => m.base === baseModel) ?? null;
overrides.model = model;
const defaultPreprocessor = model?.default_settings?.preprocessor;
overrides.processorType = isControlAdapterProcessorType(defaultPreprocessor) ? defaultPreprocessor : 'none';
overrides.processorNode = CONTROLNET_PROCESSORS[overrides.processorType].buildDefaults(baseModel);
}
dispatch(controlAdapterAdded({ type: 'controlnet', overrides }));
dispatch(controlAdapterLayerAdded({ layerId, controlNetId }));
} else if (type === 'masked_guidance_layer') {
dispatch(maskedGuidanceLayerAdded({ layerId }));
return;
}
},
});
@ -74,10 +112,23 @@ export const addRegionalControlToControlAdapterBridge = (startAppListening: AppS
startAppListening({
actionCreator: guidanceLayerIPAdapterAdded,
effect: (action, { dispatch }) => {
effect: (action, { dispatch, getState }) => {
const layerId = action.payload;
const ipAdapterId = uuidv4();
dispatch(controlAdapterAdded({ type: 'ip_adapter', overrides: { id: ipAdapterId } }));
const overrides: Partial<IPAdapterConfig> = {
id: ipAdapterId,
};
// Find and select the first matching model
const state = getState();
const baseModel = state.generation.model?.base;
const modelConfigs = modelsApi.endpoints.getModelConfigs.select(undefined)(state).data;
if (modelConfigs) {
const models = modelConfigsAdapterSelectors.selectAll(modelConfigs).filter(isIPAdapterModelConfig);
overrides.model = models.find((m) => m.base === baseModel) ?? null;
}
dispatch(controlAdapterAdded({ type: 'ip_adapter', overrides }));
dispatch(maskLayerIPAdapterAdded({ layerId, ipAdapterId }));
},
});