feat(ui): use lruMemoize for all entity adapter selectors

This commit is contained in:
psychedelicious 2024-01-05 12:16:34 +11:00
parent 83fbd4bdf2
commit 6924b04d7c
13 changed files with 100 additions and 57 deletions

View File

@ -1,4 +1,9 @@
import { createSelectorCreator, lruMemoize } from '@reduxjs/toolkit'; import {
createDraftSafeSelectorCreator,
createSelectorCreator,
lruMemoize,
} from '@reduxjs/toolkit';
import type { GetSelectorsOptions } from '@reduxjs/toolkit/dist/entities/state_selectors';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
/** /**
@ -19,3 +24,12 @@ export const createLruSelector = createSelectorCreator({
memoize: lruMemoize, memoize: lruMemoize,
argsMemoize: lruMemoize, argsMemoize: lruMemoize,
}); });
export const createLruDraftSafeSelector = createDraftSafeSelectorCreator({
memoize: lruMemoize,
argsMemoize: lruMemoize,
});
export const getSelectorsOptions: GetSelectorsOptions = {
createSelector: createLruDraftSafeSelector,
};

View File

@ -3,7 +3,7 @@ import { imageSelected } from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types'; import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
import type { ImageCache } from 'services/api/types'; import type { ImageCache } from 'services/api/types';
import { getListImagesUrl, imagesAdapter } from 'services/api/util'; import { getListImagesUrl, imagesSelectors } from 'services/api/util';
import { startAppListening } from '..'; import { startAppListening } from '..';
@ -33,7 +33,7 @@ export const addFirstListImagesListener = () => {
if (data.ids.length > 0) { if (data.ids.length > 0) {
// Select the first image // Select the first image
const firstImage = imagesAdapter.getSelectors().selectAll(data)[0]; const firstImage = imagesSelectors.selectAll(data)[0];
dispatch(imageSelected(firstImage ?? null)); dispatch(imageSelected(firstImage ?? null));
} }
}, },

View File

@ -17,7 +17,7 @@ import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { clamp, forEach } from 'lodash-es'; import { clamp, forEach } from 'lodash-es';
import { api } from 'services/api'; import { api } from 'services/api';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
import { imagesAdapter } from 'services/api/util'; import { imagesSelectors } from 'services/api/util';
import { startAppListening } from '..'; import { startAppListening } from '..';
@ -54,7 +54,7 @@ export const addRequestedSingleImageDeletionListener = () => {
imagesApi.endpoints.listImages.select(baseQueryArgs)(state); imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
const cachedImageDTOs = data const cachedImageDTOs = data
? imagesAdapter.getSelectors().selectAll(data) ? imagesSelectors.selectAll(data)
: []; : [];
const deletedImageIndex = cachedImageDTOs.findIndex( const deletedImageIndex = cachedImageDTOs.findIndex(
@ -187,7 +187,7 @@ export const addRequestedMultipleImageDeletionListener = () => {
imagesApi.endpoints.listImages.select(baseQueryArgs)(state); imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
const newSelectedImageDTO = data const newSelectedImageDTO = data
? imagesAdapter.getSelectors().selectAll(data)[0] ? imagesSelectors.selectAll(data)[0]
: undefined; : undefined;
if (newSelectedImageDTO) { if (newSelectedImageDTO) {

View File

@ -17,9 +17,9 @@ import {
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice'; import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es'; import { forEach, some } from 'lodash-es';
import { import {
mainModelsAdapter, mainModelsAdapterSelectors,
modelsApi, modelsApi,
vaeModelsAdapter, vaeModelsAdapterSelectors,
} from 'services/api/endpoints/models'; } from 'services/api/endpoints/models';
import type { TypeGuardFor } from 'services/api/types'; import type { TypeGuardFor } from 'services/api/types';
@ -43,7 +43,7 @@ export const addModelsLoadedListener = () => {
); );
const currentModel = getState().generation.model; const currentModel = getState().generation.model;
const models = mainModelsAdapter.getSelectors().selectAll(action.payload); const models = mainModelsAdapterSelectors.selectAll(action.payload);
if (models.length === 0) { if (models.length === 0) {
// No models loaded at all // No models loaded at all
@ -94,7 +94,7 @@ export const addModelsLoadedListener = () => {
); );
const currentModel = getState().sdxl.refinerModel; const currentModel = getState().sdxl.refinerModel;
const models = mainModelsAdapter.getSelectors().selectAll(action.payload); const models = mainModelsAdapterSelectors.selectAll(action.payload);
if (models.length === 0) { if (models.length === 0) {
// No models loaded at all // No models loaded at all
@ -145,8 +145,7 @@ export const addModelsLoadedListener = () => {
return; return;
} }
const firstModel = vaeModelsAdapter const firstModel = vaeModelsAdapterSelectors
.getSelectors()
.selectAll(action.payload)[0]; .selectAll(action.payload)[0];
if (!firstModel) { if (!firstModel) {

View File

@ -1,7 +1,10 @@
import { modelChanged } from 'features/parameters/store/generationSlice'; import { modelChanged } from 'features/parameters/store/generationSlice';
import { setActiveTab } from 'features/ui/store/uiSlice'; import { setActiveTab } from 'features/ui/store/uiSlice';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants'; import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { mainModelsAdapter, modelsApi } from 'services/api/endpoints/models'; import {
mainModelsAdapterSelectors,
modelsApi,
} from 'services/api/endpoints/models';
import { startAppListening } from '..'; import { startAppListening } from '..';
@ -37,8 +40,7 @@ export const addTabChangedListener = () => {
} }
// need to filter out all the invalid canvas models (currently refiner & any) // need to filter out all the invalid canvas models (currently refiner & any)
const validCanvasModels = mainModelsAdapter const validCanvasModels = mainModelsAdapterSelectors
.getSelectors()
.selectAll(models) .selectAll(models)
.filter((model) => .filter((model) =>
['sd-1', 'sd-2', 'sdxl'].includes(model.base_model) ['sd-1', 'sd-2', 'sdxl'].includes(model.base_model)

View File

@ -1,9 +1,9 @@
import type { ControlAdapterType } from 'features/controlAdapters/store/types'; import type { ControlAdapterType } from 'features/controlAdapters/store/types';
import { useMemo } from 'react'; import { useMemo } from 'react';
import { import {
controlNetModelsAdapter, controlNetModelsAdapterSelectors,
ipAdapterModelsAdapter, ipAdapterModelsAdapterSelectors,
t2iAdapterModelsAdapter, t2iAdapterModelsAdapterSelectors,
useGetControlNetModelsQuery, useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery, useGetIPAdapterModelsQuery,
useGetT2IAdapterModelsQuery, useGetT2IAdapterModelsQuery,
@ -14,7 +14,7 @@ export const useControlAdapterModels = (type?: ControlAdapterType) => {
const controlNetModels = useMemo( const controlNetModels = useMemo(
() => () =>
controlNetModelsData controlNetModelsData
? controlNetModelsAdapter.getSelectors().selectAll(controlNetModelsData) ? controlNetModelsAdapterSelectors.selectAll(controlNetModelsData)
: [], : [],
[controlNetModelsData] [controlNetModelsData]
); );
@ -23,7 +23,7 @@ export const useControlAdapterModels = (type?: ControlAdapterType) => {
const t2iAdapterModels = useMemo( const t2iAdapterModels = useMemo(
() => () =>
t2iAdapterModelsData t2iAdapterModelsData
? t2iAdapterModelsAdapter.getSelectors().selectAll(t2iAdapterModelsData) ? t2iAdapterModelsAdapterSelectors.selectAll(t2iAdapterModelsData)
: [], : [],
[t2iAdapterModelsData] [t2iAdapterModelsData]
); );
@ -31,7 +31,7 @@ export const useControlAdapterModels = (type?: ControlAdapterType) => {
const ipAdapterModels = useMemo( const ipAdapterModels = useMemo(
() => () =>
ipAdapterModelsData ipAdapterModelsData
? ipAdapterModelsAdapter.getSelectors().selectAll(ipAdapterModelsData) ? ipAdapterModelsAdapterSelectors.selectAll(ipAdapterModelsData)
: [], : [],
[ipAdapterModelsData] [ipAdapterModelsData]
); );

View File

@ -1,5 +1,6 @@
import type { PayloadAction, Update } from '@reduxjs/toolkit'; import type { PayloadAction, Update } from '@reduxjs/toolkit';
import { createEntityAdapter, createSlice, isAnyOf } from '@reduxjs/toolkit'; import { createEntityAdapter, createSlice, isAnyOf } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter'; import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
import type { import type {
ParameterControlNetModel, ParameterControlNetModel,
@ -36,6 +37,10 @@ import {
export const caAdapter = createEntityAdapter<ControlAdapterConfig, string>({ export const caAdapter = createEntityAdapter<ControlAdapterConfig, string>({
selectId: (ca) => ca.id, selectId: (ca) => ca.id,
}); });
export const caAdapterSelectors = caAdapter.getSelectors(
undefined,
getSelectorsOptions
);
export const { export const {
selectById: selectControlAdapterById, selectById: selectControlAdapterById,
@ -43,7 +48,7 @@ export const {
selectEntities: selectControlAdapterEntities, selectEntities: selectControlAdapterEntities,
selectIds: selectControlAdapterIds, selectIds: selectControlAdapterIds,
selectTotal: selectControlAdapterTotal, selectTotal: selectControlAdapterTotal,
} = caAdapter.getSelectors(); } = caAdapterSelectors;
export const initialControlAdapterState: ControlAdaptersState = export const initialControlAdapterState: ControlAdaptersState =
caAdapter.getInitialState<{ caAdapter.getInitialState<{

View File

@ -16,7 +16,7 @@ import {
useLazyListImagesQuery, useLazyListImagesQuery,
} from 'services/api/endpoints/images'; } from 'services/api/endpoints/images';
import type { ListImagesArgs } from 'services/api/types'; import type { ListImagesArgs } from 'services/api/types';
import { imagesAdapter } from 'services/api/util'; import { imagesSelectors } from 'services/api/util';
export type UseNextPrevImageState = { export type UseNextPrevImageState = {
virtuosoRef: RefObject<VirtuosoGridHandle> | undefined; virtuosoRef: RefObject<VirtuosoGridHandle> | undefined;
@ -63,9 +63,7 @@ export const nextPrevImageButtonsSelector = createMemoizedSelector(
limit: IMAGE_LIMIT, limit: IMAGE_LIMIT,
}; };
const selectors = imagesAdapter.getSelectors(); const images = imagesSelectors.selectAll(data);
const images = selectors.selectAll(data);
const currentImageIndex = images.findIndex( const currentImageIndex = images.findIndex(
(i) => i.image_name === lastSelectedImage.image_name (i) => i.image_name === lastSelectedImage.image_name
@ -77,10 +75,10 @@ export const nextPrevImageButtonsSelector = createMemoizedSelector(
const prevImageId = images[prevImageIndex]?.image_name; const prevImageId = images[prevImageIndex]?.image_name;
const nextImage = nextImageId const nextImage = nextImageId
? selectors.selectById(data, nextImageId) ? imagesSelectors.selectById(data, nextImageId)
: undefined; : undefined;
const prevImage = prevImageId const prevImage = prevImageId
? selectors.selectById(data, prevImageId) ? imagesSelectors.selectById(data, prevImageId)
: undefined; : undefined;
const imagesLength = images.length; const imagesLength = images.length;

View File

@ -87,10 +87,10 @@ import { isNil } from 'lodash-es';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { import {
controlNetModelsAdapter, controlNetModelsAdapterSelectors,
ipAdapterModelsAdapter, ipAdapterModelsAdapterSelectors,
loraModelsAdapter, loraModelsAdapterSelectors,
t2iAdapterModelsAdapter, t2iAdapterModelsAdapterSelectors,
useGetControlNetModelsQuery, useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery, useGetIPAdapterModelsQuery,
useGetLoRAModelsQuery, useGetLoRAModelsQuery,
@ -488,9 +488,10 @@ export const useRecallParameters = () => {
const { base_model, model_name } = loraMetadataItem.lora; const { base_model, model_name } = loraMetadataItem.lora;
const matchingLoRA = loraModels const matchingLoRA = loraModels
? loraModelsAdapter ? loraModelsAdapterSelectors.selectById(
.getSelectors() loraModels,
.selectById(loraModels, `${base_model}/lora/${model_name}`) `${base_model}/lora/${model_name}`
)
: undefined; : undefined;
if (!matchingLoRA) { if (!matchingLoRA) {
@ -553,9 +554,7 @@ export const useRecallParameters = () => {
} = controlnetMetadataItem; } = controlnetMetadataItem;
const matchingControlNetModel = controlNetModels const matchingControlNetModel = controlNetModels
? controlNetModelsAdapter ? controlNetModelsAdapterSelectors.selectById(
.getSelectors()
.selectById(
controlNetModels, controlNetModels,
`${control_model.base_model}/controlnet/${control_model.model_name}` `${control_model.base_model}/controlnet/${control_model.model_name}`
) )
@ -649,9 +648,7 @@ export const useRecallParameters = () => {
} = t2iAdapterMetadataItem; } = t2iAdapterMetadataItem;
const matchingT2IAdapterModel = t2iAdapterModels const matchingT2IAdapterModel = t2iAdapterModels
? t2iAdapterModelsAdapter ? t2iAdapterModelsAdapterSelectors.selectById(
.getSelectors()
.selectById(
t2iAdapterModels, t2iAdapterModels,
`${t2i_adapter_model.base_model}/t2i_adapter/${t2i_adapter_model.model_name}` `${t2i_adapter_model.base_model}/t2i_adapter/${t2i_adapter_model.model_name}`
) )
@ -738,9 +735,7 @@ export const useRecallParameters = () => {
} = ipAdapterMetadataItem; } = ipAdapterMetadataItem;
const matchingIPAdapterModel = ipAdapterModels const matchingIPAdapterModel = ipAdapterModels
? ipAdapterModelsAdapter ? ipAdapterModelsAdapterSelectors.selectById(
.getSelectors()
.selectById(
ipAdapterModels, ipAdapterModels,
`${ip_adapter_model.base_model}/ip_adapter/${ip_adapter_model.model_name}` `${ip_adapter_model.base_model}/ip_adapter/${ip_adapter_model.model_name}`
) )

View File

@ -14,7 +14,7 @@ import { useTranslation } from 'react-i18next';
import type { Components, ItemContent } from 'react-virtuoso'; import type { Components, ItemContent } from 'react-virtuoso';
import { Virtuoso } from 'react-virtuoso'; import { Virtuoso } from 'react-virtuoso';
import { import {
queueItemsAdapter, queueItemsAdapterSelectors,
useListQueueItemsQuery, useListQueueItemsQuery,
} from 'services/api/endpoints/queue'; } from 'services/api/endpoints/queue';
import type { SessionQueueItemDTO } from 'services/api/types'; import type { SessionQueueItemDTO } from 'services/api/types';
@ -77,7 +77,7 @@ const QueueList = () => {
if (!listQueueItemsData) { if (!listQueueItemsData) {
return []; return [];
} }
return queueItemsAdapter.getSelectors().selectAll(listQueueItemsData); return queueItemsAdapterSelectors.selectAll(listQueueItemsData);
}, [listQueueItemsData]); }, [listQueueItemsData]);
const handleLoadMore = useCallback(() => { const handleLoadMore = useCallback(() => {

View File

@ -1,5 +1,6 @@
import type { EntityState } from '@reduxjs/toolkit'; import type { EntityState } from '@reduxjs/toolkit';
import { createEntityAdapter } from '@reduxjs/toolkit'; import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import { cloneDeep } from 'lodash-es'; import { cloneDeep } from 'lodash-es';
import queryString from 'query-string'; import queryString from 'query-string';
import type { operations, paths } from 'services/api/schema'; import type { operations, paths } from 'services/api/schema';
@ -134,28 +135,48 @@ type SearchFolderArg = operations['search_for_models']['parameters']['query'];
export const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({ export const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
}); });
export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(
undefined,
getSelectorsOptions
);
export const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({ export const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
}); });
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(
undefined,
getSelectorsOptions
);
export const controlNetModelsAdapter = export const controlNetModelsAdapter =
createEntityAdapter<ControlNetModelConfigEntity>({ createEntityAdapter<ControlNetModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
}); });
export const controlNetModelsAdapterSelectors =
controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const ipAdapterModelsAdapter = export const ipAdapterModelsAdapter =
createEntityAdapter<IPAdapterModelConfigEntity>({ createEntityAdapter<IPAdapterModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
}); });
export const ipAdapterModelsAdapterSelectors =
ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const t2iAdapterModelsAdapter = export const t2iAdapterModelsAdapter =
createEntityAdapter<T2IAdapterModelConfigEntity>({ createEntityAdapter<T2IAdapterModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
}); });
export const t2iAdapterModelsAdapterSelectors =
t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const textualInversionModelsAdapter = export const textualInversionModelsAdapter =
createEntityAdapter<TextualInversionModelConfigEntity>({ createEntityAdapter<TextualInversionModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
}); });
export const textualInversionModelsAdapterSelectors =
textualInversionModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const vaeModelsAdapter = createEntityAdapter<VaeModelConfigEntity>({ export const vaeModelsAdapter = createEntityAdapter<VaeModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
}); });
export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(
undefined,
getSelectorsOptions
);
export const getModelId = ({ export const getModelId = ({
base_model, base_model,

View File

@ -4,6 +4,7 @@ import type {
UnknownAction, UnknownAction,
} from '@reduxjs/toolkit'; } from '@reduxjs/toolkit';
import { createEntityAdapter } from '@reduxjs/toolkit'; import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import { $queueId } from 'app/store/nanostores/queueId'; import { $queueId } from 'app/store/nanostores/queueId';
import { listParamsReset } from 'features/queue/store/queueSlice'; import { listParamsReset } from 'features/queue/store/queueSlice';
import queryString from 'query-string'; import queryString from 'query-string';
@ -59,6 +60,10 @@ export const queueItemsAdapter = createEntityAdapter<
return 0; return 0;
}, },
}); });
export const queueItemsAdapterSelectors = queueItemsAdapter.getSelectors(
undefined,
getSelectorsOptions
);
export const queueApi = api.injectEndpoints({ export const queueApi = api.injectEndpoints({
endpoints: (build) => ({ endpoints: (build) => ({
@ -308,7 +313,7 @@ export const queueApi = api.injectEndpoints({
merge: (cache, response) => { merge: (cache, response) => {
queueItemsAdapter.addMany( queueItemsAdapter.addMany(
cache, cache,
queueItemsAdapter.getSelectors().selectAll(response) queueItemsAdapterSelectors.selectAll(response)
); );
cache.has_more = response.has_more; cache.has_more = response.has_more;
}, },

View File

@ -1,4 +1,5 @@
import { createEntityAdapter } from '@reduxjs/toolkit'; import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import { dateComparator } from 'common/util/dateComparator'; import { dateComparator } from 'common/util/dateComparator';
import { import {
ASSETS_CATEGORIES, ASSETS_CATEGORIES,
@ -82,7 +83,10 @@ export const imagesAdapter = createEntityAdapter<ImageDTO, string>({
}); });
// Create selectors for the adapter. // Create selectors for the adapter.
export const imagesSelectors = imagesAdapter.getSelectors(); export const imagesSelectors = imagesAdapter.getSelectors(
undefined,
getSelectorsOptions
);
// Helper to create the url for the listImages endpoint. Also we use it to create the cache key. // Helper to create the url for the listImages endpoint. Also we use it to create the cache key.
export const getListImagesUrl = (queryArgs: ListImagesArgs) => export const getListImagesUrl = (queryArgs: ListImagesArgs) =>