refactor(ui): move loras to canvas slice

This commit is contained in:
psychedelicious 2024-06-17 12:23:32 +10:00
parent 083bcbc77d
commit 4b848798e7
14 changed files with 114 additions and 140 deletions

View File

@ -1,12 +1,10 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { caIsEnabledToggled, modelChanged, vaeSelected } from 'features/controlLayers/store/canvasV2Slice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import { caIsEnabledToggled, loraDeleted, modelChanged, vaeSelected } from 'features/controlLayers/store/canvasV2Slice';
import { modelSelected } from 'features/parameters/store/actions';
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { forEach } from 'lodash-es';
export const addModelSelectedListener = (startAppListening: AppStartListening) => {
startAppListening({
@ -32,9 +30,9 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
let modelsCleared = 0;
// handle incompatible loras
forEach(state.lora.loras, (lora, id) => {
state.canvasV2.loras.forEach((lora) => {
if (lora.model.base !== newBaseModel) {
dispatch(loraRemoved(id));
dispatch(loraDeleted({ id: lora.id }));
modelsCleared += 1;
}
});
@ -68,7 +66,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
}
}
dispatch(modelChanged(newModel, state.canvasV2.params.model));
dispatch(modelChanged({ model: newModel, previousModel: state.canvasV2.params.model }));
},
});
};

View File

@ -6,18 +6,17 @@ import {
caModelChanged,
heightChanged,
ipaModelChanged,
loraDeleted,
modelChanged,
refinerModelChanged,
rgIPAdapterModelChanged,
vaeSelected,
widthChanged,
} from 'features/controlLayers/store/canvasV2Slice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
import { postProcessingModelChanged, upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { forEach } from 'lodash-es';
import type { Logger } from 'roarr';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
@ -161,15 +160,13 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
};
const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => {
const loras = state.lora.loras;
const loraModels = models.filter(isLoRAModelConfig);
forEach(loras, (lora, id) => {
state.canvasV2.loras.forEach((lora) => {
const isLoRAAvailable = loraModels.some((m) => m.key === lora.model.key);
if (isLoRAAvailable) {
return;
}
dispatch(loraRemoved(id));
dispatch(loraDeleted({ id: lora.id }));
});
};

View File

@ -10,7 +10,6 @@ import { deleteImageModalSlice } from 'features/deleteImageModal/store/slice';
import { dynamicPromptsPersistConfig, dynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/gallerySlice';
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice';
import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/nodes/store/nodesSlice';
import { workflowSettingsPersistConfig, workflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
@ -40,7 +39,6 @@ import { listenerMiddleware } from './middleware/listenerMiddleware';
const allReducers = {
[api.reducerPath]: api.reducer,
[gallerySlice.name]: gallerySlice.reducer,
// [generationSlice.name]: generationSlice.reducer,
[nodesSlice.name]: undoable(nodesSlice.reducer, nodesUndoableConfig),
[systemSlice.name]: systemSlice.reducer,
[configSlice.name]: configSlice.reducer,
@ -48,9 +46,7 @@ const allReducers = {
[dynamicPromptsSlice.name]: dynamicPromptsSlice.reducer,
[deleteImageModalSlice.name]: deleteImageModalSlice.reducer,
[changeBoardModalSlice.name]: changeBoardModalSlice.reducer,
[loraSlice.name]: loraSlice.reducer,
[modelManagerV2Slice.name]: modelManagerV2Slice.reducer,
// [sdxlSlice.name]: sdxlSlice.reducer,
[queueSlice.name]: queueSlice.reducer,
[workflowSlice.name]: workflowSlice.reducer,
[hrfSlice.name]: hrfSlice.reducer,
@ -88,14 +84,11 @@ export type PersistConfig<T = any> = {
const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[galleryPersistConfig.name]: galleryPersistConfig,
// [generationPersistConfig.name]: generationPersistConfig,
[nodesPersistConfig.name]: nodesPersistConfig,
[systemPersistConfig.name]: systemPersistConfig,
[workflowPersistConfig.name]: workflowPersistConfig,
[uiPersistConfig.name]: uiPersistConfig,
[dynamicPromptsPersistConfig.name]: dynamicPromptsPersistConfig,
// [sdxlPersistConfig.name]: sdxlPersistConfig,
[loraPersistConfig.name]: loraPersistConfig,
[modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig,
[hrfPersistConfig.name]: hrfPersistConfig,
[canvasV2PersistConfig.name]: canvasV2PersistConfig,

View File

@ -8,6 +8,7 @@ import { compositingReducers } from 'features/controlLayers/store/compositingRed
import { controlAdaptersReducers } from 'features/controlLayers/store/controlAdaptersReducers';
import { ipAdaptersReducers } from 'features/controlLayers/store/ipAdaptersReducers';
import { layersReducers } from 'features/controlLayers/store/layersReducers';
import { lorasReducers } from 'features/controlLayers/store/lorasReducers';
import { paramsReducers } from 'features/controlLayers/store/paramsReducers';
import { regionsReducers } from 'features/controlLayers/store/regionsReducers';
import { settingsReducers } from 'features/controlLayers/store/settingsReducers';
@ -27,6 +28,7 @@ const initialState: CanvasV2State = {
controlAdapters: [],
ipAdapters: [],
regions: [],
loras: [],
tool: {
selected: 'bbox',
selectedBuffer: null,
@ -113,6 +115,7 @@ export const canvasV2Slice = createSlice({
...ipAdaptersReducers,
...controlAdaptersReducers,
...regionsReducers,
...lorasReducers,
...paramsReducers,
...compositingReducers,
...settingsReducers,
@ -287,6 +290,13 @@ export const {
setRefinerNegativeAestheticScore,
setRefinerStart,
modelChanged,
// LoRAs
loraAdded,
loraRecalled,
loraDeleted,
loraWeightChanged,
loraIsEnabledChanged,
loraAllDeleted,
} = canvasV2Slice.actions;
export const selectCanvasV2Slice = (state: RootState) => state.canvasV2;

View File

@ -0,0 +1,56 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import type { CanvasV2State, LoRA } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { LoRAModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid';
export const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
weight: 0.75,
isEnabled: true,
};
export const selectLoRA = (state: CanvasV2State, id: string) => state.loras.find((lora) => lora.id === id);
export const selectLoRAOrThrow = (state: CanvasV2State, id: string) => {
const lora = selectLoRA(state, id);
assert(lora, `LoRA with id ${id} not found`);
return lora;
};
export const lorasReducers = {
loraAdded: {
reducer: (state, action: PayloadAction<{ model: LoRAModelConfig; id: string }>) => {
const { model, id } = action.payload;
const parsedModel = zModelIdentifierField.parse(model);
state.loras.push({ ...defaultLoRAConfig, model: parsedModel, id });
},
prepare: (payload: { model: LoRAModelConfig }) => ({ payload: { ...payload, id: uuidv4() } }),
},
loraRecalled: (state, action: PayloadAction<{ lora: LoRA }>) => {
const { lora } = action.payload;
state.loras.push(lora);
},
loraDeleted: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
state.loras = state.loras.filter((lora) => lora.id !== id);
},
loraWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => {
const { id, weight } = action.payload;
const lora = selectLoRA(state, id);
if (!lora) {
return;
}
lora.weight = weight;
},
loraIsEnabledChanged: (state, action: PayloadAction<{ id: string; isEnabled: boolean }>) => {
const { id, isEnabled } = action.payload;
const lora = selectLoRA(state, id);
if (!lora) {
return;
}
lora.isEnabled = isEnabled;
},
loraAllDeleted: (state) => {
state.loras = [];
},
} satisfies SliceCaseReducers<CanvasV2State>;

View File

@ -6,6 +6,7 @@ import type {
ParameterCFGRescaleMultiplier,
ParameterCFGScale,
ParameterHeight,
ParameterLoRAModel,
ParameterMaskBlurMethod,
ParameterModel,
ParameterNegativePrompt,
@ -799,6 +800,13 @@ export type Dimensions = {
height: number;
};
export type LoRA = {
id: string;
isEnabled: boolean;
model: ParameterLoRAModel;
weight: number;
};
export type CanvasV2State = {
_version: 3;
selectedEntityIdentifier: CanvasEntityIdentifier | null;
@ -806,6 +814,7 @@ export type CanvasV2State = {
controlAdapters: ControlAdapterData[];
ipAdapters: IPAdapterData[];
regions: RegionalGuidanceData[];
loras: LoRA[];
tool: {
selected: Tool;
selectedBuffer: Tool | null;

View File

@ -11,8 +11,8 @@ import {
} from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import type { LoRA } from 'features/lora/store/loraSlice';
import { loraIsEnabledChanged, loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice';
import { loraDeleted, loraIsEnabledChanged, loraWeightChanged } from 'features/controlLayers/store/canvasV2Slice';
import type { LoRA } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { PiTrashSimpleBold } from 'react-icons/pi';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
@ -21,6 +21,8 @@ type LoRACardProps = {
lora: LoRA;
};
const marks = [-1, 0, 1, 2];
export const LoRACard = memo((props: LoRACardProps) => {
const { lora } = props;
const dispatch = useAppDispatch();
@ -28,18 +30,18 @@ export const LoRACard = memo((props: LoRACardProps) => {
const handleChange = useCallback(
(v: number) => {
dispatch(loraWeightChanged({ key: lora.model.key, weight: v }));
dispatch(loraWeightChanged({ id: lora.id, weight: v }));
},
[dispatch, lora.model.key]
[dispatch, lora.id]
);
const handleSetLoraToggle = useCallback(() => {
dispatch(loraIsEnabledChanged({ key: lora.model.key, isEnabled: !lora.isEnabled }));
}, [dispatch, lora.model.key, lora.isEnabled]);
dispatch(loraIsEnabledChanged({ id: lora.id, isEnabled: !lora.isEnabled }));
}, [dispatch, lora.id, lora.isEnabled]);
const handleRemoveLora = useCallback(() => {
dispatch(loraRemoved(lora.model.key));
}, [dispatch, lora.model.key]);
dispatch(loraDeleted({ id: lora.id }));
}, [dispatch, lora.id]);
return (
<Card variant="lora">
@ -90,5 +92,3 @@ export const LoRACard = memo((props: LoRACardProps) => {
});
LoRACard.displayName = 'LoRACard';
const marks = [-1, 0, 1, 2];

View File

@ -1,12 +1,11 @@
import { Flex } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
import { LoRACard } from 'features/lora/components/LoRACard';
import { selectLoraSlice } from 'features/lora/store/loraSlice';
import { map } from 'lodash-es';
import { memo } from 'react';
const selectLoRAsArray = createMemoizedSelector(selectLoraSlice, (lora) => map(lora.loras));
const selectLoRAsArray = createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => canvasV2.loras);
export const LoRAList = memo(() => {
const lorasArray = useAppSelector(selectLoRAsArray);
@ -18,7 +17,7 @@ export const LoRAList = memo(() => {
return (
<Flex flexWrap="wrap" gap={2}>
{lorasArray.map((lora) => (
<LoRACard key={lora.model.key} lora={lora} />
<LoRACard key={lora.id} lora={lora} />
))}
</Flex>
);

View File

@ -4,34 +4,34 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice';
import { loraAdded, selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useLoRAModels } from 'services/api/hooks/modelsByType';
import type { LoRAModelConfig } from 'services/api/types';
const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras);
const selectLoRAs = createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => canvasV2.loras);
const LoRASelect = () => {
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useLoRAModels();
const { t } = useTranslation();
const addedLoRAs = useAppSelector(selectAddedLoRAs);
const addedLoRAs = useAppSelector(selectLoRAs);
const currentBaseModel = useAppSelector((s) => s.canvasV2.params.model?.base);
const getIsDisabled = (lora: LoRAModelConfig): boolean => {
const isCompatible = currentBaseModel === lora.base;
const isAdded = Boolean(addedLoRAs[lora.key]);
const getIsDisabled = (model: LoRAModelConfig): boolean => {
const isCompatible = currentBaseModel === model.base;
const isAdded = Boolean(addedLoRAs.find((lora) => lora.model.key === model.key));
const hasMainModel = Boolean(currentBaseModel);
return !hasMainModel || !isCompatible || isAdded;
};
const _onChange = useCallback(
(lora: LoRAModelConfig | null) => {
if (!lora) {
(model: LoRAModelConfig | null) => {
if (!model) {
return;
}
dispatch(loraAdded(lora));
dispatch(loraAdded({ model }));
},
[dispatch]
);

View File

@ -1,87 +0,0 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import type { LoRAModelConfig } from 'services/api/types';
export type LoRA = {
model: ParameterLoRAModel;
weight: number;
isEnabled?: boolean;
};
export const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
weight: 0.75,
isEnabled: true,
};
type LoraState = {
_version: 2;
loras: Record<string, LoRA>;
};
const initialLoraState: LoraState = {
_version: 2,
loras: {},
};
export const loraSlice = createSlice({
name: 'lora',
initialState: initialLoraState,
reducers: {
loraAdded: (state, action: PayloadAction<LoRAModelConfig>) => {
const model = zModelIdentifierField.parse(action.payload);
state.loras[model.key] = { ...defaultLoRAConfig, model };
},
loraRecalled: (state, action: PayloadAction<LoRA>) => {
state.loras[action.payload.model.key] = action.payload;
},
loraRemoved: (state, action: PayloadAction<string>) => {
const key = action.payload;
delete state.loras[key];
},
loraWeightChanged: (state, action: PayloadAction<{ key: string; weight: number }>) => {
const { key, weight } = action.payload;
const lora = state.loras[key];
if (!lora) {
return;
}
lora.weight = weight;
},
loraIsEnabledChanged: (state, action: PayloadAction<{ key: string; isEnabled: boolean }>) => {
const { key, isEnabled } = action.payload;
const lora = state.loras[key];
if (!lora) {
return;
}
lora.isEnabled = isEnabled;
},
lorasReset: () => deepClone(initialLoraState),
},
});
export const { loraAdded, loraRemoved, loraWeightChanged, loraIsEnabledChanged, loraRecalled, lorasReset } =
loraSlice.actions;
export const selectLoraSlice = (state: RootState) => state.lora;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateLoRAState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
if (state._version === 1) {
// Model type has changed, so we need to reset the state - too risky to migrate
state = deepClone(initialLoraState);
}
return state;
};
export const loraPersistConfig: PersistConfig<LoraState> = {
name: loraSlice.name,
initialState: initialLoraState,
migrate: migrateLoRAState,
persistDenylist: [],
};

View File

@ -1,4 +1,4 @@
import type { LoRA } from 'features/lora/store/loraSlice';
import type { LoRA } from 'features/controlLayers/store/types';
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
import type { MetadataHandlers } from 'features/metadata/types';
import { handlers } from 'features/metadata/util/handlers';

View File

@ -2,7 +2,6 @@ import type { RootState } from 'app/store/store';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { LORA_LOADER } from 'features/nodes/util/graph/constants';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { filter, size } from 'lodash-es';
import type { Invocation, S } from 'services/api/types';
export const addLoRAs = (
@ -15,8 +14,10 @@ export const addLoRAs = (
posCond: Invocation<'compel'>,
negCond: Invocation<'compel'>
): void => {
const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
const loraCount = size(enabledLoRAs);
const enabledLoRAs = state.canvasV2.loras.filter(
(l) => l.isEnabled && (l.model.base === 'sd-1' || l.model.base === 'sd-2')
);
const loraCount = enabledLoRAs.length;
if (loraCount === 0) {
return;

View File

@ -2,7 +2,6 @@ import type { RootState } from 'app/store/store';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { LORA_LOADER } from 'features/nodes/util/graph/constants';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { filter, size } from 'lodash-es';
import type { Invocation, S } from 'services/api/types';
export const addSDXLLoRas = (
@ -14,8 +13,8 @@ export const addSDXLLoRas = (
posCond: Invocation<'sdxl_compel_prompt'>,
negCond: Invocation<'sdxl_compel_prompt'>
): void => {
const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
const loraCount = size(enabledLoRAs);
const enabledLoRAs = state.canvasV2.loras.filter((l) => l.isEnabled && l.model.base === 'sdxl');
const loraCount = enabledLoRAs.length;
if (loraCount === 0) {
return;

View File

@ -3,9 +3,9 @@ import { Box, Expander, Flex, FormControlGroup, StandaloneAccordion } from '@inv
import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
import { LoRAList } from 'features/lora/components/LoRAList';
import LoRASelect from 'features/lora/components/LoRASelect';
import { selectLoraSlice } from 'features/lora/store/loraSlice';
import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
import ParamScheduler from 'features/parameters/components/Core/ParamScheduler';
import ParamSteps from 'features/parameters/components/Core/ParamSteps';
@ -15,7 +15,6 @@ import { UseDefaultSettingsButton } from 'features/parameters/components/MainMod
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { filter } from 'lodash-es';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useSelectedModelConfig } from 'services/api/hooks/useSelectedModelConfig';
@ -30,8 +29,8 @@ export const GenerationSettingsAccordion = memo(() => {
const activeTabName = useAppSelector(activeTabNameSelector);
const selectBadges = useMemo(
() =>
createMemoizedSelector(selectLoraSlice, (lora) => {
const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length;
createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => {
const enabledLoRAsCount = canvasV2.loras.filter((l) => l.isEnabled).length;
const loraTabBadges = enabledLoRAsCount ? [`${enabledLoRAsCount} ${t('models.concepts')}`] : EMPTY_ARRAY;
const accordionBadges = modelConfig ? [modelConfig.name, modelConfig.base] : EMPTY_ARRAY;
return { loraTabBadges, accordionBadges };