fix(ui): reset loras/control adapters when using recall all or remix

This commit is contained in:
psychedelicious 2024-03-14 12:14:24 +11:00 committed by Kent Keirsey
parent bb8e6bbee6
commit 06ff105a1f
3 changed files with 45 additions and 3 deletions

View File

@ -338,6 +338,21 @@ export const controlAdaptersSlice = createSlice({
pendingControlImagesCleared: (state) => {
state.pendingControlImages = [];
},
ipAdaptersReset: (state) => {
selectAllIPAdapters(state).forEach((ca) => {
caAdapter.removeOne(state, ca.id);
});
},
controlNetsReset: (state) => {
selectAllControlNets(state).forEach((ca) => {
caAdapter.removeOne(state, ca.id);
});
},
t2iAdaptersReset: (state) => {
selectAllT2IAdapters(state).forEach((ca) => {
caAdapter.removeOne(state, ca.id);
});
},
},
extraReducers: (builder) => {
builder.addCase(controlAdapterImageProcessed, (state, action) => {
@ -376,6 +391,9 @@ export const {
controlAdapterAutoConfigToggled,
pendingControlImagesCleared,
controlAdapterModelCleared,
ipAdaptersReset,
controlNetsReset,
t2iAdaptersReset,
} = controlAdaptersSlice.actions;
export const isAnyControlAdapterAdded = isAnyOf(controlAdapterAdded, controlAdapterRecalled);

View File

@ -3,6 +3,7 @@ import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import { cloneDeep } from 'lodash-es';
import type { LoRAModelConfig } from 'services/api/types';
export type LoRA = {
@ -57,10 +58,12 @@ export const loraSlice = createSlice({
}
lora.isEnabled = isEnabled;
},
lorasReset: () => cloneDeep(initialLoraState),
},
});
export const { loraAdded, loraRemoved, loraWeightChanged, loraIsEnabledChanged, loraRecalled } = loraSlice.actions;
export const { loraAdded, loraRemoved, loraWeightChanged, loraIsEnabledChanged, loraRecalled, lorasReset } =
loraSlice.actions;
export const selectLoraSlice = (state: RootState) => state.lora;

View File

@ -1,8 +1,13 @@
import { getStore } from 'app/store/nanostores/store';
import { controlAdapterRecalled } from 'features/controlAdapters/store/controlAdaptersSlice';
import {
controlAdapterRecalled,
controlNetsReset,
ipAdaptersReset,
t2iAdaptersReset,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice';
import type { LoRA } from 'features/lora/store/loraSlice';
import { loraRecalled } from 'features/lora/store/loraSlice';
import { loraRecalled, lorasReset } from 'features/lora/store/loraSlice';
import type {
ControlNetConfigMetadata,
IPAdapterConfigMetadata,
@ -166,7 +171,11 @@ const recallLoRA: MetadataRecallFunc<LoRA> = (lora) => {
};
const recallAllLoRAs: MetadataRecallFunc<LoRA[]> = (loras) => {
if (!loras.length) {
return;
}
const { dispatch } = getStore();
dispatch(lorasReset());
loras.forEach((lora) => {
dispatch(loraRecalled(lora));
});
@ -177,7 +186,11 @@ const recallControlNet: MetadataRecallFunc<ControlNetConfigMetadata> = (controlN
};
const recallControlNets: MetadataRecallFunc<ControlNetConfigMetadata[]> = (controlNets) => {
if (!controlNets.length) {
return;
}
const { dispatch } = getStore();
dispatch(controlNetsReset());
controlNets.forEach((controlNet) => {
dispatch(controlAdapterRecalled(controlNet));
});
@ -188,7 +201,11 @@ const recallT2IAdapter: MetadataRecallFunc<T2IAdapterConfigMetadata> = (t2iAdapt
};
const recallT2IAdapters: MetadataRecallFunc<T2IAdapterConfigMetadata[]> = (t2iAdapters) => {
if (!t2iAdapters.length) {
return;
}
const { dispatch } = getStore();
dispatch(t2iAdaptersReset());
t2iAdapters.forEach((t2iAdapter) => {
dispatch(controlAdapterRecalled(t2iAdapter));
});
@ -199,7 +216,11 @@ const recallIPAdapter: MetadataRecallFunc<IPAdapterConfigMetadata> = (ipAdapter)
};
const recallIPAdapters: MetadataRecallFunc<IPAdapterConfigMetadata[]> = (ipAdapters) => {
if (!ipAdapters.length) {
return;
}
const { dispatch } = getStore();
dispatch(ipAdaptersReset());
ipAdapters.forEach((ipAdapter) => {
dispatch(controlAdapterRecalled(ipAdapter));
});