feat(ui): split out loras state from canvas rendering state

This commit is contained in:
psychedelicious 2024-08-26 22:29:28 +10:00
parent 100832c66d
commit 52202e45de
16 changed files with 101 additions and 85 deletions

View File

@ -1,6 +1,6 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { loraDeleted } from 'features/controlLayers/store/canvasV2Slice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import { modelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { modelSelected } from 'features/parameters/store/actions';
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
@ -31,7 +31,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
let modelsCleared = 0;
// handle incompatible loras
state.canvasV2.loras.forEach((lora) => {
state.loras.loras.forEach((lora) => {
if (lora.model.base !== newBaseModel) {
dispatch(loraDeleted({ id: lora.id }));
modelsCleared += 1;

View File

@ -7,9 +7,9 @@ import {
bboxWidthChanged,
controlLayerModelChanged,
ipaModelChanged,
loraDeleted,
rgIPAdapterModelChanged,
} from 'features/controlLayers/store/canvasV2Slice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import { modelChanged, refinerModelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { calculateNewSize } from 'features/parameters/components/DocumentSize/calculateNewSize';
@ -161,7 +161,7 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => {
const loraModels = models.filter(isLoRAModelConfig);
state.canvasV2.loras.forEach((lora) => {
state.loras.loras.forEach((lora) => {
const isLoRAAvailable = loraModels.some((m) => m.key === lora.model.key);
if (isLoRAAvailable) {
return;

View File

@ -9,6 +9,7 @@ import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
import { canvasSessionPersistConfig, canvasSessionSlice } from 'features/controlLayers/store/canvasSessionSlice';
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { canvasV2PersistConfig, canvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
import { lorasPersistConfig, lorasSlice } from 'features/controlLayers/store/lorasSlice';
import { paramsPersistConfig, paramsSlice } from 'features/controlLayers/store/paramsSlice';
import { toolPersistConfig, toolSlice } from 'features/controlLayers/store/toolSlice';
import { deleteImageModalSlice } from 'features/deleteImageModal/store/slice';
@ -65,6 +66,7 @@ const allReducers = {
[toolSlice.name]: toolSlice.reducer,
[canvasSettingsSlice.name]: canvasSettingsSlice.reducer,
[canvasSessionSlice.name]: canvasSessionSlice.reducer,
[lorasSlice.name]: lorasSlice.reducer,
};
const rootReducer = combineReducers(allReducers);
@ -110,6 +112,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[toolPersistConfig.name]: toolPersistConfig,
[canvasSettingsPersistConfig.name]: canvasSettingsPersistConfig,
[canvasSessionPersistConfig.name]: canvasSessionPersistConfig,
[lorasPersistConfig.name]: lorasPersistConfig,
};
const unserialize: UnserializeFunction = (data, key) => {

View File

@ -8,7 +8,6 @@ import { bboxReducers } from 'features/controlLayers/store/bboxReducers';
import { controlLayersReducers } from 'features/controlLayers/store/controlLayersReducers';
import { inpaintMaskReducers } from 'features/controlLayers/store/inpaintMaskReducers';
import { ipAdaptersReducers } from 'features/controlLayers/store/ipAdaptersReducers';
import { lorasReducers } from 'features/controlLayers/store/lorasReducers';
import { modelChanged } from 'features/controlLayers/store/paramsSlice';
import { rasterLayersReducers } from 'features/controlLayers/store/rasterLayersReducers';
import { regionsReducers } from 'features/controlLayers/store/regionsReducers';
@ -52,7 +51,6 @@ const initialState: CanvasV2State = {
isHidden: false,
entities: [],
},
loras: [],
ipAdapters: { entities: [] },
bbox: {
rect: { x: 0, y: 0, width: 512, height: 512 },
@ -77,8 +75,6 @@ export const canvasV2Slice = createSlice({
...regionsReducers,
...inpaintMaskReducers,
...bboxReducers,
// move out
...lorasReducers,
entitySelected: (state, action: PayloadAction<EntityIdentifierPayload>) => {
const { entityIdentifier } = action.payload;
state.selectedEntityIdentifier = entityIdentifier;
@ -437,13 +433,6 @@ export const {
rgIPAdapterMethodChanged,
rgIPAdapterModelChanged,
rgIPAdapterCLIPVisionModelChanged,
// LoRAs
loraAdded,
loraRecalled,
loraDeleted,
loraWeightChanged,
loraIsEnabledChanged,
loraAllDeleted,
// Inpaint mask
inpaintMaskAdded,
// inpaintMaskRecalled,

View File

@ -1,50 +0,0 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import type { CanvasV2State, LoRA } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { LoRAModelConfig } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
export const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
weight: 0.75,
isEnabled: true,
};
const selectLoRA = (state: CanvasV2State, id: string) => state.loras.find((lora) => lora.id === id);
export const lorasReducers = {
loraAdded: {
reducer: (state, action: PayloadAction<{ model: LoRAModelConfig; id: string }>) => {
const { model, id } = action.payload;
const parsedModel = zModelIdentifierField.parse(model);
state.loras.push({ ...defaultLoRAConfig, model: parsedModel, id });
},
prepare: (payload: { model: LoRAModelConfig }) => ({ payload: { ...payload, id: uuidv4() } }),
},
loraRecalled: (state, action: PayloadAction<{ lora: LoRA }>) => {
const { lora } = action.payload;
state.loras.push(lora);
},
loraDeleted: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
state.loras = state.loras.filter((lora) => lora.id !== id);
},
loraWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => {
const { id, weight } = action.payload;
const lora = selectLoRA(state, id);
if (!lora) {
return;
}
lora.weight = weight;
},
loraIsEnabledChanged: (state, action: PayloadAction<{ id: string; isEnabled: boolean }>) => {
const { id, isEnabled } = action.payload;
const lora = selectLoRA(state, id);
if (!lora) {
return;
}
lora.isEnabled = isEnabled;
},
loraAllDeleted: (state) => {
state.loras = [];
},
} satisfies SliceCaseReducers<CanvasV2State>;

View File

@ -0,0 +1,80 @@
import { createSlice, type PayloadAction } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import type { LoRA } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { LoRAModelConfig } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
type LoRAsState = {
loras: LoRA[];
};
export const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
weight: 0.75,
isEnabled: true,
};
const initialState: LoRAsState = {
loras: [],
};
const selectLoRA = (state: LoRAsState, id: string) => state.loras.find((lora) => lora.id === id);
export const lorasSlice = createSlice({
name: 'loras',
initialState,
reducers: {
loraAdded: {
reducer: (state, action: PayloadAction<{ model: LoRAModelConfig; id: string }>) => {
const { model, id } = action.payload;
const parsedModel = zModelIdentifierField.parse(model);
state.loras.push({ ...defaultLoRAConfig, model: parsedModel, id });
},
prepare: (payload: { model: LoRAModelConfig }) => ({ payload: { ...payload, id: uuidv4() } }),
},
loraRecalled: (state, action: PayloadAction<{ lora: LoRA }>) => {
const { lora } = action.payload;
state.loras.push(lora);
},
loraDeleted: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
state.loras = state.loras.filter((lora) => lora.id !== id);
},
loraWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => {
const { id, weight } = action.payload;
const lora = selectLoRA(state, id);
if (!lora) {
return;
}
lora.weight = weight;
},
loraIsEnabledChanged: (state, action: PayloadAction<{ id: string; isEnabled: boolean }>) => {
const { id, isEnabled } = action.payload;
const lora = selectLoRA(state, id);
if (!lora) {
return;
}
lora.isEnabled = isEnabled;
},
loraAllDeleted: (state) => {
state.loras = [];
},
},
});
export const { loraAdded, loraRecalled, loraDeleted, loraWeightChanged, loraIsEnabledChanged, loraAllDeleted } =
lorasSlice.actions;
export const selectLoRAsSlice = (state: RootState) => state.loras;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
return state;
};
export const lorasPersistConfig: PersistConfig<LoRAsState> = {
name: lorasSlice.name,
initialState,
migrate,
persistDenylist: [],
};

View File

@ -714,7 +714,6 @@ export type CanvasV2State = {
ipAdapters: {
entities: CanvasIPAdapterState[];
};
loras: LoRA[];
bbox: {
rect: {
x: number;

View File

@ -11,7 +11,7 @@ import {
} from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { loraDeleted, loraIsEnabledChanged, loraWeightChanged } from 'features/controlLayers/store/canvasV2Slice';
import { loraDeleted, loraIsEnabledChanged, loraWeightChanged } from 'features/controlLayers/store/lorasSlice';
import type { LoRA } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { PiTrashSimpleBold } from 'react-icons/pi';

View File

@ -1,11 +1,11 @@
import { Flex } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectCanvasV2Slice } from 'features/controlLayers/store/selectors';
import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
import { LoRACard } from 'features/lora/components/LoRACard';
import { memo } from 'react';
const selectLoRAsArray = createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => canvasV2.loras);
const selectLoRAsArray = createMemoizedSelector(selectLoRAsSlice, (loras) => loras.loras);
export const LoRAList = memo(() => {
const lorasArray = useAppSelector(selectLoRAsArray);

View File

@ -4,14 +4,13 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { loraAdded } from 'features/controlLayers/store/canvasV2Slice';
import { selectCanvasV2Slice } from 'features/controlLayers/store/selectors';
import { loraAdded, selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useLoRAModels } from 'services/api/hooks/modelsByType';
import type { LoRAModelConfig } from 'services/api/types';
const selectLoRAs = createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => canvasV2.loras);
const selectLoRAs = createMemoizedSelector(selectLoRAsSlice, (loras) => loras.loras);
const LoRASelect = () => {
const dispatch = useAppDispatch();

View File

@ -1,5 +1,5 @@
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { defaultLoRAConfig } from 'features/controlLayers/store/lorasReducers';
import { defaultLoRAConfig } from 'features/controlLayers/store/lorasSlice';
import type {
CanvasControlLayerState,
CanvasInpaintMaskState,

View File

@ -1,10 +1,6 @@
import { getStore } from 'app/store/nanostores/store';
import {
bboxHeightChanged,
bboxWidthChanged,
loraAllDeleted,
loraRecalled,
} from 'features/controlLayers/store/canvasV2Slice';
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasV2Slice';
import { loraAllDeleted, loraRecalled } from 'features/controlLayers/store/lorasSlice';
import {
negativePrompt2Changed,
negativePromptChanged,

View File

@ -14,7 +14,7 @@ export const addLoRAs = (
posCond: Invocation<'compel'>,
negCond: Invocation<'compel'>
): void => {
const enabledLoRAs = state.canvasV2.loras.filter(
const enabledLoRAs = state.loras.loras.filter(
(l) => l.isEnabled && (l.model.base === 'sd-1' || l.model.base === 'sd-2')
);
const loraCount = enabledLoRAs.length;

View File

@ -13,7 +13,7 @@ export const addSDXLLoRAs = (
posCond: Invocation<'sdxl_compel_prompt'>,
negCond: Invocation<'sdxl_compel_prompt'>
): void => {
const enabledLoRAs = state.canvasV2.loras.filter((l) => l.isEnabled && l.model.base === 'sdxl');
const enabledLoRAs = state.loras.loras.filter((l) => l.isEnabled && l.model.base === 'sdxl');
const loraCount = enabledLoRAs.length;
if (loraCount === 0) {

View File

@ -18,7 +18,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
const { t } = useTranslation();
const mainModel = useAppSelector((s) => s.params.model);
const addedLoRAs = useAppSelector((s) => s.canvasV2.loras);
const addedLoRAs = useAppSelector((s) => s.loras.loras);
const { data: mainModelConfig, isLoading: isLoadingMainModelConfig } = useGetModelConfigQuery(
mainModel?.key ?? skipToken
);

View File

@ -3,7 +3,7 @@ import { Box, Expander, Flex, FormControlGroup, StandaloneAccordion } from '@inv
import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectCanvasV2Slice } from 'features/controlLayers/store/selectors';
import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
import { LoRAList } from 'features/lora/components/LoRAList';
import LoRASelect from 'features/lora/components/LoRASelect';
import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
@ -29,8 +29,8 @@ export const GenerationSettingsAccordion = memo(() => {
const activeTabName = useAppSelector(selectActiveTab);
const selectBadges = useMemo(
() =>
createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => {
const enabledLoRAsCount = canvasV2.loras.filter((l) => l.isEnabled).length;
createMemoizedSelector(selectLoRAsSlice, (loras) => {
const enabledLoRAsCount = loras.loras.filter((l) => l.isEnabled).length;
const loraTabBadges = enabledLoRAsCount ? [`${enabledLoRAsCount} ${t('models.concepts')}`] : EMPTY_ARRAY;
const accordionBadges = modelConfig ? [modelConfig.name, modelConfig.base] : EMPTY_ARRAY;
return { loraTabBadges, accordionBadges };