mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): use lruMemoize
for all entity adapter selectors
This commit is contained in:
parent
83fbd4bdf2
commit
6924b04d7c
@ -1,4 +1,9 @@
|
||||
import { createSelectorCreator, lruMemoize } from '@reduxjs/toolkit';
|
||||
import {
|
||||
createDraftSafeSelectorCreator,
|
||||
createSelectorCreator,
|
||||
lruMemoize,
|
||||
} from '@reduxjs/toolkit';
|
||||
import type { GetSelectorsOptions } from '@reduxjs/toolkit/dist/entities/state_selectors';
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
/**
|
||||
@ -19,3 +24,12 @@ export const createLruSelector = createSelectorCreator({
|
||||
memoize: lruMemoize,
|
||||
argsMemoize: lruMemoize,
|
||||
});
|
||||
|
||||
export const createLruDraftSafeSelector = createDraftSafeSelectorCreator({
|
||||
memoize: lruMemoize,
|
||||
argsMemoize: lruMemoize,
|
||||
});
|
||||
|
||||
export const getSelectorsOptions: GetSelectorsOptions = {
|
||||
createSelector: createLruDraftSafeSelector,
|
||||
};
|
||||
|
@ -3,7 +3,7 @@ import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageCache } from 'services/api/types';
|
||||
import { getListImagesUrl, imagesAdapter } from 'services/api/util';
|
||||
import { getListImagesUrl, imagesSelectors } from 'services/api/util';
|
||||
|
||||
import { startAppListening } from '..';
|
||||
|
||||
@ -33,7 +33,7 @@ export const addFirstListImagesListener = () => {
|
||||
|
||||
if (data.ids.length > 0) {
|
||||
// Select the first image
|
||||
const firstImage = imagesAdapter.getSelectors().selectAll(data)[0];
|
||||
const firstImage = imagesSelectors.selectAll(data)[0];
|
||||
dispatch(imageSelected(firstImage ?? null));
|
||||
}
|
||||
},
|
||||
|
@ -17,7 +17,7 @@ import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||
import { clamp, forEach } from 'lodash-es';
|
||||
import { api } from 'services/api';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { imagesAdapter } from 'services/api/util';
|
||||
import { imagesSelectors } from 'services/api/util';
|
||||
|
||||
import { startAppListening } from '..';
|
||||
|
||||
@ -54,7 +54,7 @@ export const addRequestedSingleImageDeletionListener = () => {
|
||||
imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
|
||||
|
||||
const cachedImageDTOs = data
|
||||
? imagesAdapter.getSelectors().selectAll(data)
|
||||
? imagesSelectors.selectAll(data)
|
||||
: [];
|
||||
|
||||
const deletedImageIndex = cachedImageDTOs.findIndex(
|
||||
@ -187,7 +187,7 @@ export const addRequestedMultipleImageDeletionListener = () => {
|
||||
imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
|
||||
|
||||
const newSelectedImageDTO = data
|
||||
? imagesAdapter.getSelectors().selectAll(data)[0]
|
||||
? imagesSelectors.selectAll(data)[0]
|
||||
: undefined;
|
||||
|
||||
if (newSelectedImageDTO) {
|
||||
|
@ -17,9 +17,9 @@ import {
|
||||
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
|
||||
import { forEach, some } from 'lodash-es';
|
||||
import {
|
||||
mainModelsAdapter,
|
||||
mainModelsAdapterSelectors,
|
||||
modelsApi,
|
||||
vaeModelsAdapter,
|
||||
vaeModelsAdapterSelectors,
|
||||
} from 'services/api/endpoints/models';
|
||||
import type { TypeGuardFor } from 'services/api/types';
|
||||
|
||||
@ -43,7 +43,7 @@ export const addModelsLoadedListener = () => {
|
||||
);
|
||||
|
||||
const currentModel = getState().generation.model;
|
||||
const models = mainModelsAdapter.getSelectors().selectAll(action.payload);
|
||||
const models = mainModelsAdapterSelectors.selectAll(action.payload);
|
||||
|
||||
if (models.length === 0) {
|
||||
// No models loaded at all
|
||||
@ -94,7 +94,7 @@ export const addModelsLoadedListener = () => {
|
||||
);
|
||||
|
||||
const currentModel = getState().sdxl.refinerModel;
|
||||
const models = mainModelsAdapter.getSelectors().selectAll(action.payload);
|
||||
const models = mainModelsAdapterSelectors.selectAll(action.payload);
|
||||
|
||||
if (models.length === 0) {
|
||||
// No models loaded at all
|
||||
@ -145,8 +145,7 @@ export const addModelsLoadedListener = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
const firstModel = vaeModelsAdapter
|
||||
.getSelectors()
|
||||
const firstModel = vaeModelsAdapterSelectors
|
||||
.selectAll(action.payload)[0];
|
||||
|
||||
if (!firstModel) {
|
||||
|
@ -1,7 +1,10 @@
|
||||
import { modelChanged } from 'features/parameters/store/generationSlice';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import { mainModelsAdapter, modelsApi } from 'services/api/endpoints/models';
|
||||
import {
|
||||
mainModelsAdapterSelectors,
|
||||
modelsApi,
|
||||
} from 'services/api/endpoints/models';
|
||||
|
||||
import { startAppListening } from '..';
|
||||
|
||||
@ -37,8 +40,7 @@ export const addTabChangedListener = () => {
|
||||
}
|
||||
|
||||
// need to filter out all the invalid canvas models (currently refiner & any)
|
||||
const validCanvasModels = mainModelsAdapter
|
||||
.getSelectors()
|
||||
const validCanvasModels = mainModelsAdapterSelectors
|
||||
.selectAll(models)
|
||||
.filter((model) =>
|
||||
['sd-1', 'sd-2', 'sdxl'].includes(model.base_model)
|
||||
|
@ -1,9 +1,9 @@
|
||||
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
|
||||
import { useMemo } from 'react';
|
||||
import {
|
||||
controlNetModelsAdapter,
|
||||
ipAdapterModelsAdapter,
|
||||
t2iAdapterModelsAdapter,
|
||||
controlNetModelsAdapterSelectors,
|
||||
ipAdapterModelsAdapterSelectors,
|
||||
t2iAdapterModelsAdapterSelectors,
|
||||
useGetControlNetModelsQuery,
|
||||
useGetIPAdapterModelsQuery,
|
||||
useGetT2IAdapterModelsQuery,
|
||||
@ -14,7 +14,7 @@ export const useControlAdapterModels = (type?: ControlAdapterType) => {
|
||||
const controlNetModels = useMemo(
|
||||
() =>
|
||||
controlNetModelsData
|
||||
? controlNetModelsAdapter.getSelectors().selectAll(controlNetModelsData)
|
||||
? controlNetModelsAdapterSelectors.selectAll(controlNetModelsData)
|
||||
: [],
|
||||
[controlNetModelsData]
|
||||
);
|
||||
@ -23,7 +23,7 @@ export const useControlAdapterModels = (type?: ControlAdapterType) => {
|
||||
const t2iAdapterModels = useMemo(
|
||||
() =>
|
||||
t2iAdapterModelsData
|
||||
? t2iAdapterModelsAdapter.getSelectors().selectAll(t2iAdapterModelsData)
|
||||
? t2iAdapterModelsAdapterSelectors.selectAll(t2iAdapterModelsData)
|
||||
: [],
|
||||
[t2iAdapterModelsData]
|
||||
);
|
||||
@ -31,7 +31,7 @@ export const useControlAdapterModels = (type?: ControlAdapterType) => {
|
||||
const ipAdapterModels = useMemo(
|
||||
() =>
|
||||
ipAdapterModelsData
|
||||
? ipAdapterModelsAdapter.getSelectors().selectAll(ipAdapterModelsData)
|
||||
? ipAdapterModelsAdapterSelectors.selectAll(ipAdapterModelsData)
|
||||
: [],
|
||||
[ipAdapterModelsData]
|
||||
);
|
||||
|
@ -1,5 +1,6 @@
|
||||
import type { PayloadAction, Update } from '@reduxjs/toolkit';
|
||||
import { createEntityAdapter, createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
||||
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
|
||||
import type {
|
||||
ParameterControlNetModel,
|
||||
@ -36,6 +37,10 @@ import {
|
||||
export const caAdapter = createEntityAdapter<ControlAdapterConfig, string>({
|
||||
selectId: (ca) => ca.id,
|
||||
});
|
||||
export const caAdapterSelectors = caAdapter.getSelectors(
|
||||
undefined,
|
||||
getSelectorsOptions
|
||||
);
|
||||
|
||||
export const {
|
||||
selectById: selectControlAdapterById,
|
||||
@ -43,7 +48,7 @@ export const {
|
||||
selectEntities: selectControlAdapterEntities,
|
||||
selectIds: selectControlAdapterIds,
|
||||
selectTotal: selectControlAdapterTotal,
|
||||
} = caAdapter.getSelectors();
|
||||
} = caAdapterSelectors;
|
||||
|
||||
export const initialControlAdapterState: ControlAdaptersState =
|
||||
caAdapter.getInitialState<{
|
||||
|
@ -16,7 +16,7 @@ import {
|
||||
useLazyListImagesQuery,
|
||||
} from 'services/api/endpoints/images';
|
||||
import type { ListImagesArgs } from 'services/api/types';
|
||||
import { imagesAdapter } from 'services/api/util';
|
||||
import { imagesSelectors } from 'services/api/util';
|
||||
|
||||
export type UseNextPrevImageState = {
|
||||
virtuosoRef: RefObject<VirtuosoGridHandle> | undefined;
|
||||
@ -63,9 +63,7 @@ export const nextPrevImageButtonsSelector = createMemoizedSelector(
|
||||
limit: IMAGE_LIMIT,
|
||||
};
|
||||
|
||||
const selectors = imagesAdapter.getSelectors();
|
||||
|
||||
const images = selectors.selectAll(data);
|
||||
const images = imagesSelectors.selectAll(data);
|
||||
|
||||
const currentImageIndex = images.findIndex(
|
||||
(i) => i.image_name === lastSelectedImage.image_name
|
||||
@ -77,10 +75,10 @@ export const nextPrevImageButtonsSelector = createMemoizedSelector(
|
||||
const prevImageId = images[prevImageIndex]?.image_name;
|
||||
|
||||
const nextImage = nextImageId
|
||||
? selectors.selectById(data, nextImageId)
|
||||
? imagesSelectors.selectById(data, nextImageId)
|
||||
: undefined;
|
||||
const prevImage = prevImageId
|
||||
? selectors.selectById(data, prevImageId)
|
||||
? imagesSelectors.selectById(data, prevImageId)
|
||||
: undefined;
|
||||
|
||||
const imagesLength = images.length;
|
||||
|
@ -87,10 +87,10 @@ import { isNil } from 'lodash-es';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
controlNetModelsAdapter,
|
||||
ipAdapterModelsAdapter,
|
||||
loraModelsAdapter,
|
||||
t2iAdapterModelsAdapter,
|
||||
controlNetModelsAdapterSelectors,
|
||||
ipAdapterModelsAdapterSelectors,
|
||||
loraModelsAdapterSelectors,
|
||||
t2iAdapterModelsAdapterSelectors,
|
||||
useGetControlNetModelsQuery,
|
||||
useGetIPAdapterModelsQuery,
|
||||
useGetLoRAModelsQuery,
|
||||
@ -488,9 +488,10 @@ export const useRecallParameters = () => {
|
||||
const { base_model, model_name } = loraMetadataItem.lora;
|
||||
|
||||
const matchingLoRA = loraModels
|
||||
? loraModelsAdapter
|
||||
.getSelectors()
|
||||
.selectById(loraModels, `${base_model}/lora/${model_name}`)
|
||||
? loraModelsAdapterSelectors.selectById(
|
||||
loraModels,
|
||||
`${base_model}/lora/${model_name}`
|
||||
)
|
||||
: undefined;
|
||||
|
||||
if (!matchingLoRA) {
|
||||
@ -553,12 +554,10 @@ export const useRecallParameters = () => {
|
||||
} = controlnetMetadataItem;
|
||||
|
||||
const matchingControlNetModel = controlNetModels
|
||||
? controlNetModelsAdapter
|
||||
.getSelectors()
|
||||
.selectById(
|
||||
controlNetModels,
|
||||
`${control_model.base_model}/controlnet/${control_model.model_name}`
|
||||
)
|
||||
? controlNetModelsAdapterSelectors.selectById(
|
||||
controlNetModels,
|
||||
`${control_model.base_model}/controlnet/${control_model.model_name}`
|
||||
)
|
||||
: undefined;
|
||||
|
||||
if (!matchingControlNetModel) {
|
||||
@ -649,12 +648,10 @@ export const useRecallParameters = () => {
|
||||
} = t2iAdapterMetadataItem;
|
||||
|
||||
const matchingT2IAdapterModel = t2iAdapterModels
|
||||
? t2iAdapterModelsAdapter
|
||||
.getSelectors()
|
||||
.selectById(
|
||||
t2iAdapterModels,
|
||||
`${t2i_adapter_model.base_model}/t2i_adapter/${t2i_adapter_model.model_name}`
|
||||
)
|
||||
? t2iAdapterModelsAdapterSelectors.selectById(
|
||||
t2iAdapterModels,
|
||||
`${t2i_adapter_model.base_model}/t2i_adapter/${t2i_adapter_model.model_name}`
|
||||
)
|
||||
: undefined;
|
||||
|
||||
if (!matchingT2IAdapterModel) {
|
||||
@ -738,12 +735,10 @@ export const useRecallParameters = () => {
|
||||
} = ipAdapterMetadataItem;
|
||||
|
||||
const matchingIPAdapterModel = ipAdapterModels
|
||||
? ipAdapterModelsAdapter
|
||||
.getSelectors()
|
||||
.selectById(
|
||||
ipAdapterModels,
|
||||
`${ip_adapter_model.base_model}/ip_adapter/${ip_adapter_model.model_name}`
|
||||
)
|
||||
? ipAdapterModelsAdapterSelectors.selectById(
|
||||
ipAdapterModels,
|
||||
`${ip_adapter_model.base_model}/ip_adapter/${ip_adapter_model.model_name}`
|
||||
)
|
||||
: undefined;
|
||||
|
||||
if (!matchingIPAdapterModel) {
|
||||
|
@ -14,7 +14,7 @@ import { useTranslation } from 'react-i18next';
|
||||
import type { Components, ItemContent } from 'react-virtuoso';
|
||||
import { Virtuoso } from 'react-virtuoso';
|
||||
import {
|
||||
queueItemsAdapter,
|
||||
queueItemsAdapterSelectors,
|
||||
useListQueueItemsQuery,
|
||||
} from 'services/api/endpoints/queue';
|
||||
import type { SessionQueueItemDTO } from 'services/api/types';
|
||||
@ -77,7 +77,7 @@ const QueueList = () => {
|
||||
if (!listQueueItemsData) {
|
||||
return [];
|
||||
}
|
||||
return queueItemsAdapter.getSelectors().selectAll(listQueueItemsData);
|
||||
return queueItemsAdapterSelectors.selectAll(listQueueItemsData);
|
||||
}, [listQueueItemsData]);
|
||||
|
||||
const handleLoadMore = useCallback(() => {
|
||||
|
@ -1,5 +1,6 @@
|
||||
import type { EntityState } from '@reduxjs/toolkit';
|
||||
import { createEntityAdapter } from '@reduxjs/toolkit';
|
||||
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
||||
import { cloneDeep } from 'lodash-es';
|
||||
import queryString from 'query-string';
|
||||
import type { operations, paths } from 'services/api/schema';
|
||||
@ -134,28 +135,48 @@ type SearchFolderArg = operations['search_for_models']['parameters']['query'];
|
||||
export const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
});
|
||||
export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(
|
||||
undefined,
|
||||
getSelectorsOptions
|
||||
);
|
||||
export const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
});
|
||||
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(
|
||||
undefined,
|
||||
getSelectorsOptions
|
||||
);
|
||||
export const controlNetModelsAdapter =
|
||||
createEntityAdapter<ControlNetModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
});
|
||||
export const controlNetModelsAdapterSelectors =
|
||||
controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
export const ipAdapterModelsAdapter =
|
||||
createEntityAdapter<IPAdapterModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
});
|
||||
export const ipAdapterModelsAdapterSelectors =
|
||||
ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
export const t2iAdapterModelsAdapter =
|
||||
createEntityAdapter<T2IAdapterModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
});
|
||||
export const t2iAdapterModelsAdapterSelectors =
|
||||
t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
export const textualInversionModelsAdapter =
|
||||
createEntityAdapter<TextualInversionModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
});
|
||||
export const textualInversionModelsAdapterSelectors =
|
||||
textualInversionModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
export const vaeModelsAdapter = createEntityAdapter<VaeModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
});
|
||||
export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(
|
||||
undefined,
|
||||
getSelectorsOptions
|
||||
);
|
||||
|
||||
export const getModelId = ({
|
||||
base_model,
|
||||
|
@ -4,6 +4,7 @@ import type {
|
||||
UnknownAction,
|
||||
} from '@reduxjs/toolkit';
|
||||
import { createEntityAdapter } from '@reduxjs/toolkit';
|
||||
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
||||
import { $queueId } from 'app/store/nanostores/queueId';
|
||||
import { listParamsReset } from 'features/queue/store/queueSlice';
|
||||
import queryString from 'query-string';
|
||||
@ -59,6 +60,10 @@ export const queueItemsAdapter = createEntityAdapter<
|
||||
return 0;
|
||||
},
|
||||
});
|
||||
export const queueItemsAdapterSelectors = queueItemsAdapter.getSelectors(
|
||||
undefined,
|
||||
getSelectorsOptions
|
||||
);
|
||||
|
||||
export const queueApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
@ -308,7 +313,7 @@ export const queueApi = api.injectEndpoints({
|
||||
merge: (cache, response) => {
|
||||
queueItemsAdapter.addMany(
|
||||
cache,
|
||||
queueItemsAdapter.getSelectors().selectAll(response)
|
||||
queueItemsAdapterSelectors.selectAll(response)
|
||||
);
|
||||
cache.has_more = response.has_more;
|
||||
},
|
||||
|
@ -1,4 +1,5 @@
|
||||
import { createEntityAdapter } from '@reduxjs/toolkit';
|
||||
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
||||
import { dateComparator } from 'common/util/dateComparator';
|
||||
import {
|
||||
ASSETS_CATEGORIES,
|
||||
@ -82,7 +83,10 @@ export const imagesAdapter = createEntityAdapter<ImageDTO, string>({
|
||||
});
|
||||
|
||||
// Create selectors for the adapter.
|
||||
export const imagesSelectors = imagesAdapter.getSelectors();
|
||||
export const imagesSelectors = imagesAdapter.getSelectors(
|
||||
undefined,
|
||||
getSelectorsOptions
|
||||
);
|
||||
|
||||
// Helper to create the url for the listImages endpoint. Also we use it to create the cache key.
|
||||
export const getListImagesUrl = (queryArgs: ListImagesArgs) =>
|
||||
|
Loading…
Reference in New Issue
Block a user