mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): split out loras state from canvas rendering state
This commit is contained in:
parent
100832c66d
commit
52202e45de
@ -1,6 +1,6 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { loraDeleted } from 'features/controlLayers/store/canvasV2Slice';
|
||||
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
|
||||
import { modelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
@ -31,7 +31,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
|
||||
let modelsCleared = 0;
|
||||
|
||||
// handle incompatible loras
|
||||
state.canvasV2.loras.forEach((lora) => {
|
||||
state.loras.loras.forEach((lora) => {
|
||||
if (lora.model.base !== newBaseModel) {
|
||||
dispatch(loraDeleted({ id: lora.id }));
|
||||
modelsCleared += 1;
|
||||
|
@ -7,9 +7,9 @@ import {
|
||||
bboxWidthChanged,
|
||||
controlLayerModelChanged,
|
||||
ipaModelChanged,
|
||||
loraDeleted,
|
||||
rgIPAdapterModelChanged,
|
||||
} from 'features/controlLayers/store/canvasV2Slice';
|
||||
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
|
||||
import { modelChanged, refinerModelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice';
|
||||
import { getEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { calculateNewSize } from 'features/parameters/components/DocumentSize/calculateNewSize';
|
||||
@ -161,7 +161,7 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
|
||||
const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => {
|
||||
const loraModels = models.filter(isLoRAModelConfig);
|
||||
state.canvasV2.loras.forEach((lora) => {
|
||||
state.loras.loras.forEach((lora) => {
|
||||
const isLoRAAvailable = loraModels.some((m) => m.key === lora.model.key);
|
||||
if (isLoRAAvailable) {
|
||||
return;
|
||||
|
@ -9,6 +9,7 @@ import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
|
||||
import { canvasSessionPersistConfig, canvasSessionSlice } from 'features/controlLayers/store/canvasSessionSlice';
|
||||
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { canvasV2PersistConfig, canvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
|
||||
import { lorasPersistConfig, lorasSlice } from 'features/controlLayers/store/lorasSlice';
|
||||
import { paramsPersistConfig, paramsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { toolPersistConfig, toolSlice } from 'features/controlLayers/store/toolSlice';
|
||||
import { deleteImageModalSlice } from 'features/deleteImageModal/store/slice';
|
||||
@ -65,6 +66,7 @@ const allReducers = {
|
||||
[toolSlice.name]: toolSlice.reducer,
|
||||
[canvasSettingsSlice.name]: canvasSettingsSlice.reducer,
|
||||
[canvasSessionSlice.name]: canvasSessionSlice.reducer,
|
||||
[lorasSlice.name]: lorasSlice.reducer,
|
||||
};
|
||||
|
||||
const rootReducer = combineReducers(allReducers);
|
||||
@ -110,6 +112,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
|
||||
[toolPersistConfig.name]: toolPersistConfig,
|
||||
[canvasSettingsPersistConfig.name]: canvasSettingsPersistConfig,
|
||||
[canvasSessionPersistConfig.name]: canvasSessionPersistConfig,
|
||||
[lorasPersistConfig.name]: lorasPersistConfig,
|
||||
};
|
||||
|
||||
const unserialize: UnserializeFunction = (data, key) => {
|
||||
|
@ -8,7 +8,6 @@ import { bboxReducers } from 'features/controlLayers/store/bboxReducers';
|
||||
import { controlLayersReducers } from 'features/controlLayers/store/controlLayersReducers';
|
||||
import { inpaintMaskReducers } from 'features/controlLayers/store/inpaintMaskReducers';
|
||||
import { ipAdaptersReducers } from 'features/controlLayers/store/ipAdaptersReducers';
|
||||
import { lorasReducers } from 'features/controlLayers/store/lorasReducers';
|
||||
import { modelChanged } from 'features/controlLayers/store/paramsSlice';
|
||||
import { rasterLayersReducers } from 'features/controlLayers/store/rasterLayersReducers';
|
||||
import { regionsReducers } from 'features/controlLayers/store/regionsReducers';
|
||||
@ -52,7 +51,6 @@ const initialState: CanvasV2State = {
|
||||
isHidden: false,
|
||||
entities: [],
|
||||
},
|
||||
loras: [],
|
||||
ipAdapters: { entities: [] },
|
||||
bbox: {
|
||||
rect: { x: 0, y: 0, width: 512, height: 512 },
|
||||
@ -77,8 +75,6 @@ export const canvasV2Slice = createSlice({
|
||||
...regionsReducers,
|
||||
...inpaintMaskReducers,
|
||||
...bboxReducers,
|
||||
// move out
|
||||
...lorasReducers,
|
||||
entitySelected: (state, action: PayloadAction<EntityIdentifierPayload>) => {
|
||||
const { entityIdentifier } = action.payload;
|
||||
state.selectedEntityIdentifier = entityIdentifier;
|
||||
@ -437,13 +433,6 @@ export const {
|
||||
rgIPAdapterMethodChanged,
|
||||
rgIPAdapterModelChanged,
|
||||
rgIPAdapterCLIPVisionModelChanged,
|
||||
// LoRAs
|
||||
loraAdded,
|
||||
loraRecalled,
|
||||
loraDeleted,
|
||||
loraWeightChanged,
|
||||
loraIsEnabledChanged,
|
||||
loraAllDeleted,
|
||||
// Inpaint mask
|
||||
inpaintMaskAdded,
|
||||
// inpaintMaskRecalled,
|
||||
|
@ -1,50 +0,0 @@
|
||||
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
|
||||
import type { CanvasV2State, LoRA } from 'features/controlLayers/store/types';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
export const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
|
||||
weight: 0.75,
|
||||
isEnabled: true,
|
||||
};
|
||||
|
||||
const selectLoRA = (state: CanvasV2State, id: string) => state.loras.find((lora) => lora.id === id);
|
||||
|
||||
export const lorasReducers = {
|
||||
loraAdded: {
|
||||
reducer: (state, action: PayloadAction<{ model: LoRAModelConfig; id: string }>) => {
|
||||
const { model, id } = action.payload;
|
||||
const parsedModel = zModelIdentifierField.parse(model);
|
||||
state.loras.push({ ...defaultLoRAConfig, model: parsedModel, id });
|
||||
},
|
||||
prepare: (payload: { model: LoRAModelConfig }) => ({ payload: { ...payload, id: uuidv4() } }),
|
||||
},
|
||||
loraRecalled: (state, action: PayloadAction<{ lora: LoRA }>) => {
|
||||
const { lora } = action.payload;
|
||||
state.loras.push(lora);
|
||||
},
|
||||
loraDeleted: (state, action: PayloadAction<{ id: string }>) => {
|
||||
const { id } = action.payload;
|
||||
state.loras = state.loras.filter((lora) => lora.id !== id);
|
||||
},
|
||||
loraWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => {
|
||||
const { id, weight } = action.payload;
|
||||
const lora = selectLoRA(state, id);
|
||||
if (!lora) {
|
||||
return;
|
||||
}
|
||||
lora.weight = weight;
|
||||
},
|
||||
loraIsEnabledChanged: (state, action: PayloadAction<{ id: string; isEnabled: boolean }>) => {
|
||||
const { id, isEnabled } = action.payload;
|
||||
const lora = selectLoRA(state, id);
|
||||
if (!lora) {
|
||||
return;
|
||||
}
|
||||
lora.isEnabled = isEnabled;
|
||||
},
|
||||
loraAllDeleted: (state) => {
|
||||
state.loras = [];
|
||||
},
|
||||
} satisfies SliceCaseReducers<CanvasV2State>;
|
@ -0,0 +1,80 @@
|
||||
import { createSlice, type PayloadAction } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import type { LoRA } from 'features/controlLayers/store/types';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
type LoRAsState = {
|
||||
loras: LoRA[];
|
||||
};
|
||||
|
||||
export const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
|
||||
weight: 0.75,
|
||||
isEnabled: true,
|
||||
};
|
||||
|
||||
const initialState: LoRAsState = {
|
||||
loras: [],
|
||||
};
|
||||
|
||||
const selectLoRA = (state: LoRAsState, id: string) => state.loras.find((lora) => lora.id === id);
|
||||
|
||||
export const lorasSlice = createSlice({
|
||||
name: 'loras',
|
||||
initialState,
|
||||
reducers: {
|
||||
loraAdded: {
|
||||
reducer: (state, action: PayloadAction<{ model: LoRAModelConfig; id: string }>) => {
|
||||
const { model, id } = action.payload;
|
||||
const parsedModel = zModelIdentifierField.parse(model);
|
||||
state.loras.push({ ...defaultLoRAConfig, model: parsedModel, id });
|
||||
},
|
||||
prepare: (payload: { model: LoRAModelConfig }) => ({ payload: { ...payload, id: uuidv4() } }),
|
||||
},
|
||||
loraRecalled: (state, action: PayloadAction<{ lora: LoRA }>) => {
|
||||
const { lora } = action.payload;
|
||||
state.loras.push(lora);
|
||||
},
|
||||
loraDeleted: (state, action: PayloadAction<{ id: string }>) => {
|
||||
const { id } = action.payload;
|
||||
state.loras = state.loras.filter((lora) => lora.id !== id);
|
||||
},
|
||||
loraWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => {
|
||||
const { id, weight } = action.payload;
|
||||
const lora = selectLoRA(state, id);
|
||||
if (!lora) {
|
||||
return;
|
||||
}
|
||||
lora.weight = weight;
|
||||
},
|
||||
loraIsEnabledChanged: (state, action: PayloadAction<{ id: string; isEnabled: boolean }>) => {
|
||||
const { id, isEnabled } = action.payload;
|
||||
const lora = selectLoRA(state, id);
|
||||
if (!lora) {
|
||||
return;
|
||||
}
|
||||
lora.isEnabled = isEnabled;
|
||||
},
|
||||
loraAllDeleted: (state) => {
|
||||
state.loras = [];
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export const { loraAdded, loraRecalled, loraDeleted, loraWeightChanged, loraIsEnabledChanged, loraAllDeleted } =
|
||||
lorasSlice.actions;
|
||||
|
||||
export const selectLoRAsSlice = (state: RootState) => state.loras;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrate = (state: any): any => {
|
||||
return state;
|
||||
};
|
||||
|
||||
export const lorasPersistConfig: PersistConfig<LoRAsState> = {
|
||||
name: lorasSlice.name,
|
||||
initialState,
|
||||
migrate,
|
||||
persistDenylist: [],
|
||||
};
|
@ -714,7 +714,6 @@ export type CanvasV2State = {
|
||||
ipAdapters: {
|
||||
entities: CanvasIPAdapterState[];
|
||||
};
|
||||
loras: LoRA[];
|
||||
bbox: {
|
||||
rect: {
|
||||
x: number;
|
||||
|
@ -11,7 +11,7 @@ import {
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { loraDeleted, loraIsEnabledChanged, loraWeightChanged } from 'features/controlLayers/store/canvasV2Slice';
|
||||
import { loraDeleted, loraIsEnabledChanged, loraWeightChanged } from 'features/controlLayers/store/lorasSlice';
|
||||
import type { LoRA } from 'features/controlLayers/store/types';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||
|
@ -1,11 +1,11 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectCanvasV2Slice } from 'features/controlLayers/store/selectors';
|
||||
import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
|
||||
import { LoRACard } from 'features/lora/components/LoRACard';
|
||||
import { memo } from 'react';
|
||||
|
||||
const selectLoRAsArray = createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => canvasV2.loras);
|
||||
const selectLoRAsArray = createMemoizedSelector(selectLoRAsSlice, (loras) => loras.loras);
|
||||
|
||||
export const LoRAList = memo(() => {
|
||||
const lorasArray = useAppSelector(selectLoRAsArray);
|
||||
|
@ -4,14 +4,13 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { loraAdded } from 'features/controlLayers/store/canvasV2Slice';
|
||||
import { selectCanvasV2Slice } from 'features/controlLayers/store/selectors';
|
||||
import { loraAdded, selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useLoRAModels } from 'services/api/hooks/modelsByType';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
|
||||
const selectLoRAs = createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => canvasV2.loras);
|
||||
const selectLoRAs = createMemoizedSelector(selectLoRAsSlice, (loras) => loras.loras);
|
||||
|
||||
const LoRASelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { defaultLoRAConfig } from 'features/controlLayers/store/lorasReducers';
|
||||
import { defaultLoRAConfig } from 'features/controlLayers/store/lorasSlice';
|
||||
import type {
|
||||
CanvasControlLayerState,
|
||||
CanvasInpaintMaskState,
|
||||
|
@ -1,10 +1,6 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import {
|
||||
bboxHeightChanged,
|
||||
bboxWidthChanged,
|
||||
loraAllDeleted,
|
||||
loraRecalled,
|
||||
} from 'features/controlLayers/store/canvasV2Slice';
|
||||
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasV2Slice';
|
||||
import { loraAllDeleted, loraRecalled } from 'features/controlLayers/store/lorasSlice';
|
||||
import {
|
||||
negativePrompt2Changed,
|
||||
negativePromptChanged,
|
||||
|
@ -14,7 +14,7 @@ export const addLoRAs = (
|
||||
posCond: Invocation<'compel'>,
|
||||
negCond: Invocation<'compel'>
|
||||
): void => {
|
||||
const enabledLoRAs = state.canvasV2.loras.filter(
|
||||
const enabledLoRAs = state.loras.loras.filter(
|
||||
(l) => l.isEnabled && (l.model.base === 'sd-1' || l.model.base === 'sd-2')
|
||||
);
|
||||
const loraCount = enabledLoRAs.length;
|
||||
|
@ -13,7 +13,7 @@ export const addSDXLLoRAs = (
|
||||
posCond: Invocation<'sdxl_compel_prompt'>,
|
||||
negCond: Invocation<'sdxl_compel_prompt'>
|
||||
): void => {
|
||||
const enabledLoRAs = state.canvasV2.loras.filter((l) => l.isEnabled && l.model.base === 'sdxl');
|
||||
const enabledLoRAs = state.loras.loras.filter((l) => l.isEnabled && l.model.base === 'sdxl');
|
||||
const loraCount = enabledLoRAs.length;
|
||||
|
||||
if (loraCount === 0) {
|
||||
|
@ -18,7 +18,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
|
||||
const { t } = useTranslation();
|
||||
|
||||
const mainModel = useAppSelector((s) => s.params.model);
|
||||
const addedLoRAs = useAppSelector((s) => s.canvasV2.loras);
|
||||
const addedLoRAs = useAppSelector((s) => s.loras.loras);
|
||||
const { data: mainModelConfig, isLoading: isLoadingMainModelConfig } = useGetModelConfigQuery(
|
||||
mainModel?.key ?? skipToken
|
||||
);
|
||||
|
@ -3,7 +3,7 @@ import { Box, Expander, Flex, FormControlGroup, StandaloneAccordion } from '@inv
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectCanvasV2Slice } from 'features/controlLayers/store/selectors';
|
||||
import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
|
||||
import { LoRAList } from 'features/lora/components/LoRAList';
|
||||
import LoRASelect from 'features/lora/components/LoRASelect';
|
||||
import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
|
||||
@ -29,8 +29,8 @@ export const GenerationSettingsAccordion = memo(() => {
|
||||
const activeTabName = useAppSelector(selectActiveTab);
|
||||
const selectBadges = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => {
|
||||
const enabledLoRAsCount = canvasV2.loras.filter((l) => l.isEnabled).length;
|
||||
createMemoizedSelector(selectLoRAsSlice, (loras) => {
|
||||
const enabledLoRAsCount = loras.loras.filter((l) => l.isEnabled).length;
|
||||
const loraTabBadges = enabledLoRAsCount ? [`${enabledLoRAsCount} ${t('models.concepts')}`] : EMPTY_ARRAY;
|
||||
const accordionBadges = modelConfig ? [modelConfig.name, modelConfig.base] : EMPTY_ARRAY;
|
||||
return { loraTabBadges, accordionBadges };
|
||||
|
Loading…
Reference in New Issue
Block a user