feat(ui): use lruMemoize for all entity adapter selectors

This commit is contained in:
psychedelicious 2024-01-05 12:16:34 +11:00
parent 83fbd4bdf2
commit 6924b04d7c
13 changed files with 100 additions and 57 deletions

View File

@ -1,4 +1,9 @@
import { createSelectorCreator, lruMemoize } from '@reduxjs/toolkit';
import {
createDraftSafeSelectorCreator,
createSelectorCreator,
lruMemoize,
} from '@reduxjs/toolkit';
import type { GetSelectorsOptions } from '@reduxjs/toolkit/dist/entities/state_selectors';
import { isEqual } from 'lodash-es';
/**
@ -19,3 +24,12 @@ export const createLruSelector = createSelectorCreator({
memoize: lruMemoize,
argsMemoize: lruMemoize,
});
export const createLruDraftSafeSelector = createDraftSafeSelectorCreator({
memoize: lruMemoize,
argsMemoize: lruMemoize,
});
export const getSelectorsOptions: GetSelectorsOptions = {
createSelector: createLruDraftSafeSelector,
};

View File

@ -3,7 +3,7 @@ import { imageSelected } from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageCache } from 'services/api/types';
import { getListImagesUrl, imagesAdapter } from 'services/api/util';
import { getListImagesUrl, imagesSelectors } from 'services/api/util';
import { startAppListening } from '..';
@ -33,7 +33,7 @@ export const addFirstListImagesListener = () => {
if (data.ids.length > 0) {
// Select the first image
const firstImage = imagesAdapter.getSelectors().selectAll(data)[0];
const firstImage = imagesSelectors.selectAll(data)[0];
dispatch(imageSelected(firstImage ?? null));
}
},

View File

@ -17,7 +17,7 @@ import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { clamp, forEach } from 'lodash-es';
import { api } from 'services/api';
import { imagesApi } from 'services/api/endpoints/images';
import { imagesAdapter } from 'services/api/util';
import { imagesSelectors } from 'services/api/util';
import { startAppListening } from '..';
@ -54,7 +54,7 @@ export const addRequestedSingleImageDeletionListener = () => {
imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
const cachedImageDTOs = data
? imagesAdapter.getSelectors().selectAll(data)
? imagesSelectors.selectAll(data)
: [];
const deletedImageIndex = cachedImageDTOs.findIndex(
@ -187,7 +187,7 @@ export const addRequestedMultipleImageDeletionListener = () => {
imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
const newSelectedImageDTO = data
? imagesAdapter.getSelectors().selectAll(data)[0]
? imagesSelectors.selectAll(data)[0]
: undefined;
if (newSelectedImageDTO) {

View File

@ -17,9 +17,9 @@ import {
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es';
import {
mainModelsAdapter,
mainModelsAdapterSelectors,
modelsApi,
vaeModelsAdapter,
vaeModelsAdapterSelectors,
} from 'services/api/endpoints/models';
import type { TypeGuardFor } from 'services/api/types';
@ -43,7 +43,7 @@ export const addModelsLoadedListener = () => {
);
const currentModel = getState().generation.model;
const models = mainModelsAdapter.getSelectors().selectAll(action.payload);
const models = mainModelsAdapterSelectors.selectAll(action.payload);
if (models.length === 0) {
// No models loaded at all
@ -94,7 +94,7 @@ export const addModelsLoadedListener = () => {
);
const currentModel = getState().sdxl.refinerModel;
const models = mainModelsAdapter.getSelectors().selectAll(action.payload);
const models = mainModelsAdapterSelectors.selectAll(action.payload);
if (models.length === 0) {
// No models loaded at all
@ -145,8 +145,7 @@ export const addModelsLoadedListener = () => {
return;
}
const firstModel = vaeModelsAdapter
.getSelectors()
const firstModel = vaeModelsAdapterSelectors
.selectAll(action.payload)[0];
if (!firstModel) {

View File

@ -1,7 +1,10 @@
import { modelChanged } from 'features/parameters/store/generationSlice';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { mainModelsAdapter, modelsApi } from 'services/api/endpoints/models';
import {
mainModelsAdapterSelectors,
modelsApi,
} from 'services/api/endpoints/models';
import { startAppListening } from '..';
@ -37,8 +40,7 @@ export const addTabChangedListener = () => {
}
// need to filter out all the invalid canvas models (currently refiner & any)
const validCanvasModels = mainModelsAdapter
.getSelectors()
const validCanvasModels = mainModelsAdapterSelectors
.selectAll(models)
.filter((model) =>
['sd-1', 'sd-2', 'sdxl'].includes(model.base_model)

View File

@ -1,9 +1,9 @@
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
import { useMemo } from 'react';
import {
controlNetModelsAdapter,
ipAdapterModelsAdapter,
t2iAdapterModelsAdapter,
controlNetModelsAdapterSelectors,
ipAdapterModelsAdapterSelectors,
t2iAdapterModelsAdapterSelectors,
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetT2IAdapterModelsQuery,
@ -14,7 +14,7 @@ export const useControlAdapterModels = (type?: ControlAdapterType) => {
const controlNetModels = useMemo(
() =>
controlNetModelsData
? controlNetModelsAdapter.getSelectors().selectAll(controlNetModelsData)
? controlNetModelsAdapterSelectors.selectAll(controlNetModelsData)
: [],
[controlNetModelsData]
);
@ -23,7 +23,7 @@ export const useControlAdapterModels = (type?: ControlAdapterType) => {
const t2iAdapterModels = useMemo(
() =>
t2iAdapterModelsData
? t2iAdapterModelsAdapter.getSelectors().selectAll(t2iAdapterModelsData)
? t2iAdapterModelsAdapterSelectors.selectAll(t2iAdapterModelsData)
: [],
[t2iAdapterModelsData]
);
@ -31,7 +31,7 @@ export const useControlAdapterModels = (type?: ControlAdapterType) => {
const ipAdapterModels = useMemo(
() =>
ipAdapterModelsData
? ipAdapterModelsAdapter.getSelectors().selectAll(ipAdapterModelsData)
? ipAdapterModelsAdapterSelectors.selectAll(ipAdapterModelsData)
: [],
[ipAdapterModelsData]
);

View File

@ -1,5 +1,6 @@
import type { PayloadAction, Update } from '@reduxjs/toolkit';
import { createEntityAdapter, createSlice, isAnyOf } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
import type {
ParameterControlNetModel,
@ -36,6 +37,10 @@ import {
export const caAdapter = createEntityAdapter<ControlAdapterConfig, string>({
selectId: (ca) => ca.id,
});
export const caAdapterSelectors = caAdapter.getSelectors(
undefined,
getSelectorsOptions
);
export const {
selectById: selectControlAdapterById,
@ -43,7 +48,7 @@ export const {
selectEntities: selectControlAdapterEntities,
selectIds: selectControlAdapterIds,
selectTotal: selectControlAdapterTotal,
} = caAdapter.getSelectors();
} = caAdapterSelectors;
export const initialControlAdapterState: ControlAdaptersState =
caAdapter.getInitialState<{

View File

@ -16,7 +16,7 @@ import {
useLazyListImagesQuery,
} from 'services/api/endpoints/images';
import type { ListImagesArgs } from 'services/api/types';
import { imagesAdapter } from 'services/api/util';
import { imagesSelectors } from 'services/api/util';
export type UseNextPrevImageState = {
virtuosoRef: RefObject<VirtuosoGridHandle> | undefined;
@ -63,9 +63,7 @@ export const nextPrevImageButtonsSelector = createMemoizedSelector(
limit: IMAGE_LIMIT,
};
const selectors = imagesAdapter.getSelectors();
const images = selectors.selectAll(data);
const images = imagesSelectors.selectAll(data);
const currentImageIndex = images.findIndex(
(i) => i.image_name === lastSelectedImage.image_name
@ -77,10 +75,10 @@ export const nextPrevImageButtonsSelector = createMemoizedSelector(
const prevImageId = images[prevImageIndex]?.image_name;
const nextImage = nextImageId
? selectors.selectById(data, nextImageId)
? imagesSelectors.selectById(data, nextImageId)
: undefined;
const prevImage = prevImageId
? selectors.selectById(data, prevImageId)
? imagesSelectors.selectById(data, prevImageId)
: undefined;
const imagesLength = images.length;

View File

@ -87,10 +87,10 @@ import { isNil } from 'lodash-es';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
controlNetModelsAdapter,
ipAdapterModelsAdapter,
loraModelsAdapter,
t2iAdapterModelsAdapter,
controlNetModelsAdapterSelectors,
ipAdapterModelsAdapterSelectors,
loraModelsAdapterSelectors,
t2iAdapterModelsAdapterSelectors,
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetLoRAModelsQuery,
@ -488,9 +488,10 @@ export const useRecallParameters = () => {
const { base_model, model_name } = loraMetadataItem.lora;
const matchingLoRA = loraModels
? loraModelsAdapter
.getSelectors()
.selectById(loraModels, `${base_model}/lora/${model_name}`)
? loraModelsAdapterSelectors.selectById(
loraModels,
`${base_model}/lora/${model_name}`
)
: undefined;
if (!matchingLoRA) {
@ -553,12 +554,10 @@ export const useRecallParameters = () => {
} = controlnetMetadataItem;
const matchingControlNetModel = controlNetModels
? controlNetModelsAdapter
.getSelectors()
.selectById(
controlNetModels,
`${control_model.base_model}/controlnet/${control_model.model_name}`
)
? controlNetModelsAdapterSelectors.selectById(
controlNetModels,
`${control_model.base_model}/controlnet/${control_model.model_name}`
)
: undefined;
if (!matchingControlNetModel) {
@ -649,12 +648,10 @@ export const useRecallParameters = () => {
} = t2iAdapterMetadataItem;
const matchingT2IAdapterModel = t2iAdapterModels
? t2iAdapterModelsAdapter
.getSelectors()
.selectById(
t2iAdapterModels,
`${t2i_adapter_model.base_model}/t2i_adapter/${t2i_adapter_model.model_name}`
)
? t2iAdapterModelsAdapterSelectors.selectById(
t2iAdapterModels,
`${t2i_adapter_model.base_model}/t2i_adapter/${t2i_adapter_model.model_name}`
)
: undefined;
if (!matchingT2IAdapterModel) {
@ -738,12 +735,10 @@ export const useRecallParameters = () => {
} = ipAdapterMetadataItem;
const matchingIPAdapterModel = ipAdapterModels
? ipAdapterModelsAdapter
.getSelectors()
.selectById(
ipAdapterModels,
`${ip_adapter_model.base_model}/ip_adapter/${ip_adapter_model.model_name}`
)
? ipAdapterModelsAdapterSelectors.selectById(
ipAdapterModels,
`${ip_adapter_model.base_model}/ip_adapter/${ip_adapter_model.model_name}`
)
: undefined;
if (!matchingIPAdapterModel) {

View File

@ -14,7 +14,7 @@ import { useTranslation } from 'react-i18next';
import type { Components, ItemContent } from 'react-virtuoso';
import { Virtuoso } from 'react-virtuoso';
import {
queueItemsAdapter,
queueItemsAdapterSelectors,
useListQueueItemsQuery,
} from 'services/api/endpoints/queue';
import type { SessionQueueItemDTO } from 'services/api/types';
@ -77,7 +77,7 @@ const QueueList = () => {
if (!listQueueItemsData) {
return [];
}
return queueItemsAdapter.getSelectors().selectAll(listQueueItemsData);
return queueItemsAdapterSelectors.selectAll(listQueueItemsData);
}, [listQueueItemsData]);
const handleLoadMore = useCallback(() => {

View File

@ -1,5 +1,6 @@
import type { EntityState } from '@reduxjs/toolkit';
import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import { cloneDeep } from 'lodash-es';
import queryString from 'query-string';
import type { operations, paths } from 'services/api/schema';
@ -134,28 +135,48 @@ type SearchFolderArg = operations['search_for_models']['parameters']['query'];
export const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(
undefined,
getSelectorsOptions
);
export const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(
undefined,
getSelectorsOptions
);
export const controlNetModelsAdapter =
createEntityAdapter<ControlNetModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
export const controlNetModelsAdapterSelectors =
controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const ipAdapterModelsAdapter =
createEntityAdapter<IPAdapterModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
export const ipAdapterModelsAdapterSelectors =
ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const t2iAdapterModelsAdapter =
createEntityAdapter<T2IAdapterModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
export const t2iAdapterModelsAdapterSelectors =
t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const textualInversionModelsAdapter =
createEntityAdapter<TextualInversionModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
export const textualInversionModelsAdapterSelectors =
textualInversionModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const vaeModelsAdapter = createEntityAdapter<VaeModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(
undefined,
getSelectorsOptions
);
export const getModelId = ({
base_model,

View File

@ -4,6 +4,7 @@ import type {
UnknownAction,
} from '@reduxjs/toolkit';
import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import { $queueId } from 'app/store/nanostores/queueId';
import { listParamsReset } from 'features/queue/store/queueSlice';
import queryString from 'query-string';
@ -59,6 +60,10 @@ export const queueItemsAdapter = createEntityAdapter<
return 0;
},
});
export const queueItemsAdapterSelectors = queueItemsAdapter.getSelectors(
undefined,
getSelectorsOptions
);
export const queueApi = api.injectEndpoints({
endpoints: (build) => ({
@ -308,7 +313,7 @@ export const queueApi = api.injectEndpoints({
merge: (cache, response) => {
queueItemsAdapter.addMany(
cache,
queueItemsAdapter.getSelectors().selectAll(response)
queueItemsAdapterSelectors.selectAll(response)
);
cache.has_more = response.has_more;
},

View File

@ -1,4 +1,5 @@
import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import { dateComparator } from 'common/util/dateComparator';
import {
ASSETS_CATEGORIES,
@ -82,7 +83,10 @@ export const imagesAdapter = createEntityAdapter<ImageDTO, string>({
});
// Create selectors for the adapter.
export const imagesSelectors = imagesAdapter.getSelectors();
export const imagesSelectors = imagesAdapter.getSelectors(
undefined,
getSelectorsOptions
);
// Helper to create the url for the listImages endpoint. Also we use it to create the cache key.
export const getListImagesUrl = (queryArgs: ListImagesArgs) =>