fix(ui): reset loras/control adapters when using recall all or remix

This commit is contained in:
psychedelicious 2024-03-14 12:14:24 +11:00 committed by Kent Keirsey
parent bb8e6bbee6
commit 06ff105a1f
3 changed files with 45 additions and 3 deletions

View File

@ -338,6 +338,21 @@ export const controlAdaptersSlice = createSlice({
pendingControlImagesCleared: (state) => { pendingControlImagesCleared: (state) => {
state.pendingControlImages = []; state.pendingControlImages = [];
}, },
ipAdaptersReset: (state) => {
selectAllIPAdapters(state).forEach((ca) => {
caAdapter.removeOne(state, ca.id);
});
},
controlNetsReset: (state) => {
selectAllControlNets(state).forEach((ca) => {
caAdapter.removeOne(state, ca.id);
});
},
t2iAdaptersReset: (state) => {
selectAllT2IAdapters(state).forEach((ca) => {
caAdapter.removeOne(state, ca.id);
});
},
}, },
extraReducers: (builder) => { extraReducers: (builder) => {
builder.addCase(controlAdapterImageProcessed, (state, action) => { builder.addCase(controlAdapterImageProcessed, (state, action) => {
@ -376,6 +391,9 @@ export const {
controlAdapterAutoConfigToggled, controlAdapterAutoConfigToggled,
pendingControlImagesCleared, pendingControlImagesCleared,
controlAdapterModelCleared, controlAdapterModelCleared,
ipAdaptersReset,
controlNetsReset,
t2iAdaptersReset,
} = controlAdaptersSlice.actions; } = controlAdaptersSlice.actions;
export const isAnyControlAdapterAdded = isAnyOf(controlAdapterAdded, controlAdapterRecalled); export const isAnyControlAdapterAdded = isAnyOf(controlAdapterAdded, controlAdapterRecalled);

View File

@ -3,6 +3,7 @@ import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { PersistConfig, RootState } from 'app/store/store';
import { zModelIdentifierField } from 'features/nodes/types/common'; import { zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas'; import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import { cloneDeep } from 'lodash-es';
import type { LoRAModelConfig } from 'services/api/types'; import type { LoRAModelConfig } from 'services/api/types';
export type LoRA = { export type LoRA = {
@ -57,10 +58,12 @@ export const loraSlice = createSlice({
} }
lora.isEnabled = isEnabled; lora.isEnabled = isEnabled;
}, },
lorasReset: () => cloneDeep(initialLoraState),
}, },
}); });
export const { loraAdded, loraRemoved, loraWeightChanged, loraIsEnabledChanged, loraRecalled } = loraSlice.actions; export const { loraAdded, loraRemoved, loraWeightChanged, loraIsEnabledChanged, loraRecalled, lorasReset } =
loraSlice.actions;
export const selectLoraSlice = (state: RootState) => state.lora; export const selectLoraSlice = (state: RootState) => state.lora;

View File

@ -1,8 +1,13 @@
import { getStore } from 'app/store/nanostores/store'; import { getStore } from 'app/store/nanostores/store';
import { controlAdapterRecalled } from 'features/controlAdapters/store/controlAdaptersSlice'; import {
controlAdapterRecalled,
controlNetsReset,
ipAdaptersReset,
t2iAdaptersReset,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice'; import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice';
import type { LoRA } from 'features/lora/store/loraSlice'; import type { LoRA } from 'features/lora/store/loraSlice';
import { loraRecalled } from 'features/lora/store/loraSlice'; import { loraRecalled, lorasReset } from 'features/lora/store/loraSlice';
import type { import type {
ControlNetConfigMetadata, ControlNetConfigMetadata,
IPAdapterConfigMetadata, IPAdapterConfigMetadata,
@ -166,7 +171,11 @@ const recallLoRA: MetadataRecallFunc<LoRA> = (lora) => {
}; };
const recallAllLoRAs: MetadataRecallFunc<LoRA[]> = (loras) => { const recallAllLoRAs: MetadataRecallFunc<LoRA[]> = (loras) => {
if (!loras.length) {
return;
}
const { dispatch } = getStore(); const { dispatch } = getStore();
dispatch(lorasReset());
loras.forEach((lora) => { loras.forEach((lora) => {
dispatch(loraRecalled(lora)); dispatch(loraRecalled(lora));
}); });
@ -177,7 +186,11 @@ const recallControlNet: MetadataRecallFunc<ControlNetConfigMetadata> = (controlN
}; };
const recallControlNets: MetadataRecallFunc<ControlNetConfigMetadata[]> = (controlNets) => { const recallControlNets: MetadataRecallFunc<ControlNetConfigMetadata[]> = (controlNets) => {
if (!controlNets.length) {
return;
}
const { dispatch } = getStore(); const { dispatch } = getStore();
dispatch(controlNetsReset());
controlNets.forEach((controlNet) => { controlNets.forEach((controlNet) => {
dispatch(controlAdapterRecalled(controlNet)); dispatch(controlAdapterRecalled(controlNet));
}); });
@ -188,7 +201,11 @@ const recallT2IAdapter: MetadataRecallFunc<T2IAdapterConfigMetadata> = (t2iAdapt
}; };
const recallT2IAdapters: MetadataRecallFunc<T2IAdapterConfigMetadata[]> = (t2iAdapters) => { const recallT2IAdapters: MetadataRecallFunc<T2IAdapterConfigMetadata[]> = (t2iAdapters) => {
if (!t2iAdapters.length) {
return;
}
const { dispatch } = getStore(); const { dispatch } = getStore();
dispatch(t2iAdaptersReset());
t2iAdapters.forEach((t2iAdapter) => { t2iAdapters.forEach((t2iAdapter) => {
dispatch(controlAdapterRecalled(t2iAdapter)); dispatch(controlAdapterRecalled(t2iAdapter));
}); });
@ -199,7 +216,11 @@ const recallIPAdapter: MetadataRecallFunc<IPAdapterConfigMetadata> = (ipAdapter)
}; };
const recallIPAdapters: MetadataRecallFunc<IPAdapterConfigMetadata[]> = (ipAdapters) => { const recallIPAdapters: MetadataRecallFunc<IPAdapterConfigMetadata[]> = (ipAdapters) => {
if (!ipAdapters.length) {
return;
}
const { dispatch } = getStore(); const { dispatch } = getStore();
dispatch(ipAdaptersReset());
ipAdapters.forEach((ipAdapter) => { ipAdapters.forEach((ipAdapter) => {
dispatch(controlAdapterRecalled(ipAdapter)); dispatch(controlAdapterRecalled(ipAdapter));
}); });