mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor(ui): move loras to canvas slice
This commit is contained in:
parent
083bcbc77d
commit
4b848798e7
@ -1,12 +1,10 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { caIsEnabledToggled, modelChanged, vaeSelected } from 'features/controlLayers/store/canvasV2Slice';
|
||||
import { loraRemoved } from 'features/lora/store/loraSlice';
|
||||
import { caIsEnabledToggled, loraDeleted, modelChanged, vaeSelected } from 'features/controlLayers/store/canvasV2Slice';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { forEach } from 'lodash-es';
|
||||
|
||||
export const addModelSelectedListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
@ -32,9 +30,9 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
|
||||
let modelsCleared = 0;
|
||||
|
||||
// handle incompatible loras
|
||||
forEach(state.lora.loras, (lora, id) => {
|
||||
state.canvasV2.loras.forEach((lora) => {
|
||||
if (lora.model.base !== newBaseModel) {
|
||||
dispatch(loraRemoved(id));
|
||||
dispatch(loraDeleted({ id: lora.id }));
|
||||
modelsCleared += 1;
|
||||
}
|
||||
});
|
||||
@ -68,7 +66,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
|
||||
}
|
||||
}
|
||||
|
||||
dispatch(modelChanged(newModel, state.canvasV2.params.model));
|
||||
dispatch(modelChanged({ model: newModel, previousModel: state.canvasV2.params.model }));
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -6,18 +6,17 @@ import {
|
||||
caModelChanged,
|
||||
heightChanged,
|
||||
ipaModelChanged,
|
||||
loraDeleted,
|
||||
modelChanged,
|
||||
refinerModelChanged,
|
||||
rgIPAdapterModelChanged,
|
||||
vaeSelected,
|
||||
widthChanged,
|
||||
} from 'features/controlLayers/store/canvasV2Slice';
|
||||
import { loraRemoved } from 'features/lora/store/loraSlice';
|
||||
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
|
||||
import { postProcessingModelChanged, upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import { forEach } from 'lodash-es';
|
||||
import type { Logger } from 'roarr';
|
||||
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
@ -161,15 +160,13 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
};
|
||||
|
||||
const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => {
|
||||
const loras = state.lora.loras;
|
||||
const loraModels = models.filter(isLoRAModelConfig);
|
||||
|
||||
forEach(loras, (lora, id) => {
|
||||
state.canvasV2.loras.forEach((lora) => {
|
||||
const isLoRAAvailable = loraModels.some((m) => m.key === lora.model.key);
|
||||
if (isLoRAAvailable) {
|
||||
return;
|
||||
}
|
||||
dispatch(loraRemoved(id));
|
||||
dispatch(loraDeleted({ id: lora.id }));
|
||||
});
|
||||
};
|
||||
|
||||
|
@ -10,7 +10,6 @@ import { deleteImageModalSlice } from 'features/deleteImageModal/store/slice';
|
||||
import { dynamicPromptsPersistConfig, dynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
||||
import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/gallerySlice';
|
||||
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
|
||||
import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice';
|
||||
import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/nodes/store/nodesSlice';
|
||||
import { workflowSettingsPersistConfig, workflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
|
||||
@ -40,7 +39,6 @@ import { listenerMiddleware } from './middleware/listenerMiddleware';
|
||||
const allReducers = {
|
||||
[api.reducerPath]: api.reducer,
|
||||
[gallerySlice.name]: gallerySlice.reducer,
|
||||
// [generationSlice.name]: generationSlice.reducer,
|
||||
[nodesSlice.name]: undoable(nodesSlice.reducer, nodesUndoableConfig),
|
||||
[systemSlice.name]: systemSlice.reducer,
|
||||
[configSlice.name]: configSlice.reducer,
|
||||
@ -48,9 +46,7 @@ const allReducers = {
|
||||
[dynamicPromptsSlice.name]: dynamicPromptsSlice.reducer,
|
||||
[deleteImageModalSlice.name]: deleteImageModalSlice.reducer,
|
||||
[changeBoardModalSlice.name]: changeBoardModalSlice.reducer,
|
||||
[loraSlice.name]: loraSlice.reducer,
|
||||
[modelManagerV2Slice.name]: modelManagerV2Slice.reducer,
|
||||
// [sdxlSlice.name]: sdxlSlice.reducer,
|
||||
[queueSlice.name]: queueSlice.reducer,
|
||||
[workflowSlice.name]: workflowSlice.reducer,
|
||||
[hrfSlice.name]: hrfSlice.reducer,
|
||||
@ -88,14 +84,11 @@ export type PersistConfig<T = any> = {
|
||||
|
||||
const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
|
||||
[galleryPersistConfig.name]: galleryPersistConfig,
|
||||
// [generationPersistConfig.name]: generationPersistConfig,
|
||||
[nodesPersistConfig.name]: nodesPersistConfig,
|
||||
[systemPersistConfig.name]: systemPersistConfig,
|
||||
[workflowPersistConfig.name]: workflowPersistConfig,
|
||||
[uiPersistConfig.name]: uiPersistConfig,
|
||||
[dynamicPromptsPersistConfig.name]: dynamicPromptsPersistConfig,
|
||||
// [sdxlPersistConfig.name]: sdxlPersistConfig,
|
||||
[loraPersistConfig.name]: loraPersistConfig,
|
||||
[modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig,
|
||||
[hrfPersistConfig.name]: hrfPersistConfig,
|
||||
[canvasV2PersistConfig.name]: canvasV2PersistConfig,
|
||||
|
@ -8,6 +8,7 @@ import { compositingReducers } from 'features/controlLayers/store/compositingRed
|
||||
import { controlAdaptersReducers } from 'features/controlLayers/store/controlAdaptersReducers';
|
||||
import { ipAdaptersReducers } from 'features/controlLayers/store/ipAdaptersReducers';
|
||||
import { layersReducers } from 'features/controlLayers/store/layersReducers';
|
||||
import { lorasReducers } from 'features/controlLayers/store/lorasReducers';
|
||||
import { paramsReducers } from 'features/controlLayers/store/paramsReducers';
|
||||
import { regionsReducers } from 'features/controlLayers/store/regionsReducers';
|
||||
import { settingsReducers } from 'features/controlLayers/store/settingsReducers';
|
||||
@ -27,6 +28,7 @@ const initialState: CanvasV2State = {
|
||||
controlAdapters: [],
|
||||
ipAdapters: [],
|
||||
regions: [],
|
||||
loras: [],
|
||||
tool: {
|
||||
selected: 'bbox',
|
||||
selectedBuffer: null,
|
||||
@ -113,6 +115,7 @@ export const canvasV2Slice = createSlice({
|
||||
...ipAdaptersReducers,
|
||||
...controlAdaptersReducers,
|
||||
...regionsReducers,
|
||||
...lorasReducers,
|
||||
...paramsReducers,
|
||||
...compositingReducers,
|
||||
...settingsReducers,
|
||||
@ -287,6 +290,13 @@ export const {
|
||||
setRefinerNegativeAestheticScore,
|
||||
setRefinerStart,
|
||||
modelChanged,
|
||||
// LoRAs
|
||||
loraAdded,
|
||||
loraRecalled,
|
||||
loraDeleted,
|
||||
loraWeightChanged,
|
||||
loraIsEnabledChanged,
|
||||
loraAllDeleted,
|
||||
} = canvasV2Slice.actions;
|
||||
|
||||
export const selectCanvasV2Slice = (state: RootState) => state.canvasV2;
|
||||
|
@ -0,0 +1,56 @@
|
||||
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
|
||||
import type { CanvasV2State, LoRA } from 'features/controlLayers/store/types';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
export const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
|
||||
weight: 0.75,
|
||||
isEnabled: true,
|
||||
};
|
||||
|
||||
export const selectLoRA = (state: CanvasV2State, id: string) => state.loras.find((lora) => lora.id === id);
|
||||
export const selectLoRAOrThrow = (state: CanvasV2State, id: string) => {
|
||||
const lora = selectLoRA(state, id);
|
||||
assert(lora, `LoRA with id ${id} not found`);
|
||||
return lora;
|
||||
};
|
||||
|
||||
export const lorasReducers = {
|
||||
loraAdded: {
|
||||
reducer: (state, action: PayloadAction<{ model: LoRAModelConfig; id: string }>) => {
|
||||
const { model, id } = action.payload;
|
||||
const parsedModel = zModelIdentifierField.parse(model);
|
||||
state.loras.push({ ...defaultLoRAConfig, model: parsedModel, id });
|
||||
},
|
||||
prepare: (payload: { model: LoRAModelConfig }) => ({ payload: { ...payload, id: uuidv4() } }),
|
||||
},
|
||||
loraRecalled: (state, action: PayloadAction<{ lora: LoRA }>) => {
|
||||
const { lora } = action.payload;
|
||||
state.loras.push(lora);
|
||||
},
|
||||
loraDeleted: (state, action: PayloadAction<{ id: string }>) => {
|
||||
const { id } = action.payload;
|
||||
state.loras = state.loras.filter((lora) => lora.id !== id);
|
||||
},
|
||||
loraWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => {
|
||||
const { id, weight } = action.payload;
|
||||
const lora = selectLoRA(state, id);
|
||||
if (!lora) {
|
||||
return;
|
||||
}
|
||||
lora.weight = weight;
|
||||
},
|
||||
loraIsEnabledChanged: (state, action: PayloadAction<{ id: string; isEnabled: boolean }>) => {
|
||||
const { id, isEnabled } = action.payload;
|
||||
const lora = selectLoRA(state, id);
|
||||
if (!lora) {
|
||||
return;
|
||||
}
|
||||
lora.isEnabled = isEnabled;
|
||||
},
|
||||
loraAllDeleted: (state) => {
|
||||
state.loras = [];
|
||||
},
|
||||
} satisfies SliceCaseReducers<CanvasV2State>;
|
@ -6,6 +6,7 @@ import type {
|
||||
ParameterCFGRescaleMultiplier,
|
||||
ParameterCFGScale,
|
||||
ParameterHeight,
|
||||
ParameterLoRAModel,
|
||||
ParameterMaskBlurMethod,
|
||||
ParameterModel,
|
||||
ParameterNegativePrompt,
|
||||
@ -799,6 +800,13 @@ export type Dimensions = {
|
||||
height: number;
|
||||
};
|
||||
|
||||
export type LoRA = {
|
||||
id: string;
|
||||
isEnabled: boolean;
|
||||
model: ParameterLoRAModel;
|
||||
weight: number;
|
||||
};
|
||||
|
||||
export type CanvasV2State = {
|
||||
_version: 3;
|
||||
selectedEntityIdentifier: CanvasEntityIdentifier | null;
|
||||
@ -806,6 +814,7 @@ export type CanvasV2State = {
|
||||
controlAdapters: ControlAdapterData[];
|
||||
ipAdapters: IPAdapterData[];
|
||||
regions: RegionalGuidanceData[];
|
||||
loras: LoRA[];
|
||||
tool: {
|
||||
selected: Tool;
|
||||
selectedBuffer: Tool | null;
|
||||
|
@ -11,8 +11,8 @@ import {
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import { loraIsEnabledChanged, loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice';
|
||||
import { loraDeleted, loraIsEnabledChanged, loraWeightChanged } from 'features/controlLayers/store/canvasV2Slice';
|
||||
import type { LoRA } from 'features/controlLayers/store/types';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
@ -21,6 +21,8 @@ type LoRACardProps = {
|
||||
lora: LoRA;
|
||||
};
|
||||
|
||||
const marks = [-1, 0, 1, 2];
|
||||
|
||||
export const LoRACard = memo((props: LoRACardProps) => {
|
||||
const { lora } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
@ -28,18 +30,18 @@ export const LoRACard = memo((props: LoRACardProps) => {
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: number) => {
|
||||
dispatch(loraWeightChanged({ key: lora.model.key, weight: v }));
|
||||
dispatch(loraWeightChanged({ id: lora.id, weight: v }));
|
||||
},
|
||||
[dispatch, lora.model.key]
|
||||
[dispatch, lora.id]
|
||||
);
|
||||
|
||||
const handleSetLoraToggle = useCallback(() => {
|
||||
dispatch(loraIsEnabledChanged({ key: lora.model.key, isEnabled: !lora.isEnabled }));
|
||||
}, [dispatch, lora.model.key, lora.isEnabled]);
|
||||
dispatch(loraIsEnabledChanged({ id: lora.id, isEnabled: !lora.isEnabled }));
|
||||
}, [dispatch, lora.id, lora.isEnabled]);
|
||||
|
||||
const handleRemoveLora = useCallback(() => {
|
||||
dispatch(loraRemoved(lora.model.key));
|
||||
}, [dispatch, lora.model.key]);
|
||||
dispatch(loraDeleted({ id: lora.id }));
|
||||
}, [dispatch, lora.id]);
|
||||
|
||||
return (
|
||||
<Card variant="lora">
|
||||
@ -90,5 +92,3 @@ export const LoRACard = memo((props: LoRACardProps) => {
|
||||
});
|
||||
|
||||
LoRACard.displayName = 'LoRACard';
|
||||
|
||||
const marks = [-1, 0, 1, 2];
|
||||
|
@ -1,12 +1,11 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
|
||||
import { LoRACard } from 'features/lora/components/LoRACard';
|
||||
import { selectLoraSlice } from 'features/lora/store/loraSlice';
|
||||
import { map } from 'lodash-es';
|
||||
import { memo } from 'react';
|
||||
|
||||
const selectLoRAsArray = createMemoizedSelector(selectLoraSlice, (lora) => map(lora.loras));
|
||||
const selectLoRAsArray = createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => canvasV2.loras);
|
||||
|
||||
export const LoRAList = memo(() => {
|
||||
const lorasArray = useAppSelector(selectLoRAsArray);
|
||||
@ -18,7 +17,7 @@ export const LoRAList = memo(() => {
|
||||
return (
|
||||
<Flex flexWrap="wrap" gap={2}>
|
||||
{lorasArray.map((lora) => (
|
||||
<LoRACard key={lora.model.key} lora={lora} />
|
||||
<LoRACard key={lora.id} lora={lora} />
|
||||
))}
|
||||
</Flex>
|
||||
);
|
||||
|
@ -4,34 +4,34 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice';
|
||||
import { loraAdded, selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useLoRAModels } from 'services/api/hooks/modelsByType';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
|
||||
const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras);
|
||||
const selectLoRAs = createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => canvasV2.loras);
|
||||
|
||||
const LoRASelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useLoRAModels();
|
||||
const { t } = useTranslation();
|
||||
const addedLoRAs = useAppSelector(selectAddedLoRAs);
|
||||
const addedLoRAs = useAppSelector(selectLoRAs);
|
||||
const currentBaseModel = useAppSelector((s) => s.canvasV2.params.model?.base);
|
||||
|
||||
const getIsDisabled = (lora: LoRAModelConfig): boolean => {
|
||||
const isCompatible = currentBaseModel === lora.base;
|
||||
const isAdded = Boolean(addedLoRAs[lora.key]);
|
||||
const getIsDisabled = (model: LoRAModelConfig): boolean => {
|
||||
const isCompatible = currentBaseModel === model.base;
|
||||
const isAdded = Boolean(addedLoRAs.find((lora) => lora.model.key === model.key));
|
||||
const hasMainModel = Boolean(currentBaseModel);
|
||||
return !hasMainModel || !isCompatible || isAdded;
|
||||
};
|
||||
|
||||
const _onChange = useCallback(
|
||||
(lora: LoRAModelConfig | null) => {
|
||||
if (!lora) {
|
||||
(model: LoRAModelConfig | null) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
dispatch(loraAdded(lora));
|
||||
dispatch(loraAdded({ model }));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
@ -1,87 +0,0 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
|
||||
export type LoRA = {
|
||||
model: ParameterLoRAModel;
|
||||
weight: number;
|
||||
isEnabled?: boolean;
|
||||
};
|
||||
|
||||
export const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
|
||||
weight: 0.75,
|
||||
isEnabled: true,
|
||||
};
|
||||
|
||||
type LoraState = {
|
||||
_version: 2;
|
||||
loras: Record<string, LoRA>;
|
||||
};
|
||||
|
||||
const initialLoraState: LoraState = {
|
||||
_version: 2,
|
||||
loras: {},
|
||||
};
|
||||
|
||||
export const loraSlice = createSlice({
|
||||
name: 'lora',
|
||||
initialState: initialLoraState,
|
||||
reducers: {
|
||||
loraAdded: (state, action: PayloadAction<LoRAModelConfig>) => {
|
||||
const model = zModelIdentifierField.parse(action.payload);
|
||||
state.loras[model.key] = { ...defaultLoRAConfig, model };
|
||||
},
|
||||
loraRecalled: (state, action: PayloadAction<LoRA>) => {
|
||||
state.loras[action.payload.model.key] = action.payload;
|
||||
},
|
||||
loraRemoved: (state, action: PayloadAction<string>) => {
|
||||
const key = action.payload;
|
||||
delete state.loras[key];
|
||||
},
|
||||
loraWeightChanged: (state, action: PayloadAction<{ key: string; weight: number }>) => {
|
||||
const { key, weight } = action.payload;
|
||||
const lora = state.loras[key];
|
||||
if (!lora) {
|
||||
return;
|
||||
}
|
||||
lora.weight = weight;
|
||||
},
|
||||
loraIsEnabledChanged: (state, action: PayloadAction<{ key: string; isEnabled: boolean }>) => {
|
||||
const { key, isEnabled } = action.payload;
|
||||
const lora = state.loras[key];
|
||||
if (!lora) {
|
||||
return;
|
||||
}
|
||||
lora.isEnabled = isEnabled;
|
||||
},
|
||||
lorasReset: () => deepClone(initialLoraState),
|
||||
},
|
||||
});
|
||||
|
||||
export const { loraAdded, loraRemoved, loraWeightChanged, loraIsEnabledChanged, loraRecalled, lorasReset } =
|
||||
loraSlice.actions;
|
||||
|
||||
export const selectLoraSlice = (state: RootState) => state.lora;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrateLoRAState = (state: any): any => {
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
if (state._version === 1) {
|
||||
// Model type has changed, so we need to reset the state - too risky to migrate
|
||||
state = deepClone(initialLoraState);
|
||||
}
|
||||
return state;
|
||||
};
|
||||
|
||||
export const loraPersistConfig: PersistConfig<LoraState> = {
|
||||
name: loraSlice.name,
|
||||
initialState: initialLoraState,
|
||||
migrate: migrateLoRAState,
|
||||
persistDenylist: [],
|
||||
};
|
@ -1,4 +1,4 @@
|
||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import type { LoRA } from 'features/controlLayers/store/types';
|
||||
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
|
||||
import type { MetadataHandlers } from 'features/metadata/types';
|
||||
import { handlers } from 'features/metadata/util/handlers';
|
||||
|
@ -2,7 +2,6 @@ import type { RootState } from 'app/store/store';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { LORA_LOADER } from 'features/nodes/util/graph/constants';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { filter, size } from 'lodash-es';
|
||||
import type { Invocation, S } from 'services/api/types';
|
||||
|
||||
export const addLoRAs = (
|
||||
@ -15,8 +14,10 @@ export const addLoRAs = (
|
||||
posCond: Invocation<'compel'>,
|
||||
negCond: Invocation<'compel'>
|
||||
): void => {
|
||||
const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
|
||||
const loraCount = size(enabledLoRAs);
|
||||
const enabledLoRAs = state.canvasV2.loras.filter(
|
||||
(l) => l.isEnabled && (l.model.base === 'sd-1' || l.model.base === 'sd-2')
|
||||
);
|
||||
const loraCount = enabledLoRAs.length;
|
||||
|
||||
if (loraCount === 0) {
|
||||
return;
|
||||
|
@ -2,7 +2,6 @@ import type { RootState } from 'app/store/store';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { LORA_LOADER } from 'features/nodes/util/graph/constants';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { filter, size } from 'lodash-es';
|
||||
import type { Invocation, S } from 'services/api/types';
|
||||
|
||||
export const addSDXLLoRas = (
|
||||
@ -14,8 +13,8 @@ export const addSDXLLoRas = (
|
||||
posCond: Invocation<'sdxl_compel_prompt'>,
|
||||
negCond: Invocation<'sdxl_compel_prompt'>
|
||||
): void => {
|
||||
const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
|
||||
const loraCount = size(enabledLoRAs);
|
||||
const enabledLoRAs = state.canvasV2.loras.filter((l) => l.isEnabled && l.model.base === 'sdxl');
|
||||
const loraCount = enabledLoRAs.length;
|
||||
|
||||
if (loraCount === 0) {
|
||||
return;
|
||||
|
@ -3,9 +3,9 @@ import { Box, Expander, Flex, FormControlGroup, StandaloneAccordion } from '@inv
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
|
||||
import { LoRAList } from 'features/lora/components/LoRAList';
|
||||
import LoRASelect from 'features/lora/components/LoRASelect';
|
||||
import { selectLoraSlice } from 'features/lora/store/loraSlice';
|
||||
import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
|
||||
import ParamScheduler from 'features/parameters/components/Core/ParamScheduler';
|
||||
import ParamSteps from 'features/parameters/components/Core/ParamSteps';
|
||||
@ -15,7 +15,6 @@ import { UseDefaultSettingsButton } from 'features/parameters/components/MainMod
|
||||
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
|
||||
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { filter } from 'lodash-es';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSelectedModelConfig } from 'services/api/hooks/useSelectedModelConfig';
|
||||
@ -30,8 +29,8 @@ export const GenerationSettingsAccordion = memo(() => {
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
const selectBadges = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectLoraSlice, (lora) => {
|
||||
const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length;
|
||||
createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => {
|
||||
const enabledLoRAsCount = canvasV2.loras.filter((l) => l.isEnabled).length;
|
||||
const loraTabBadges = enabledLoRAsCount ? [`${enabledLoRAsCount} ${t('models.concepts')}`] : EMPTY_ARRAY;
|
||||
const accordionBadges = modelConfig ? [modelConfig.name, modelConfig.base] : EMPTY_ARRAY;
|
||||
return { loraTabBadges, accordionBadges };
|
||||
|
Loading…
Reference in New Issue
Block a user