feat(ui): optimize model query caching

When we retrieve a list of models, upsert that data into the `getModelConfig` and `getModelConfigByAttrs` query caches.

With this change, calls to those two queries are almost always going to be free, because their caches will already have all models in them. The exception is queries for models that no longer exist.
This commit is contained in:
psychedelicious 2024-02-26 13:27:13 +11:00
parent 0b54bfb7c5
commit 7176c5d9d6

View File

@ -1,4 +1,4 @@
import type { EntityAdapter, EntityState } from '@reduxjs/toolkit'; import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
import { createEntityAdapter } from '@reduxjs/toolkit'; import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector'; import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import queryString from 'query-string'; import queryString from 'query-string';
@ -111,6 +111,12 @@ export const vaeModelsAdapter = createEntityAdapter<VAEModelConfig, string>({
}); });
export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions); export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const anyModelConfigAdapter = createEntityAdapter<AnyModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const anyModelConfigAdapterSelectors = anyModelConfigAdapter.getSelectors(undefined, getSelectorsOptions);
const buildProvidesTags = const buildProvidesTags =
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) => <TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
(result: EntityState<TEntity, string> | undefined) => { (result: EntityState<TEntity, string> | undefined) => {
@ -141,8 +147,6 @@ const buildTransformResponse =
*/ */
const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`); const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
// TODO(psyche): Ideally we can share the cache between the `getXYZModels` queries and `getModelConfig` query
export const modelsApi = api.injectEndpoints({ export const modelsApi = api.injectEndpoints({
endpoints: (build) => ({ endpoints: (build) => ({
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({ getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
@ -157,6 +161,11 @@ export const modelsApi = api.injectEndpoints({
}, },
providesTags: buildProvidesTags<MainModelConfig>('MainModel'), providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter), transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}), }),
getModelMetadata: build.query<GetModelMetadataResponse, string>({ getModelMetadata: build.query<GetModelMetadataResponse, string>({
query: (key) => { query: (key) => {
@ -236,6 +245,7 @@ export const modelsApi = api.injectEndpoints({
return tags; return tags;
}, },
serializeQueryArgs: ({ queryArgs }) => `${queryArgs.name}.${queryArgs.base}.${queryArgs.type}`,
}), }),
syncModels: build.mutation<void, void>({ syncModels: build.mutation<void, void>({
query: () => { query: () => {
@ -250,31 +260,61 @@ export const modelsApi = api.injectEndpoints({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }), query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'), providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
transformResponse: buildTransformResponse<LoRAModelConfig>(loraModelsAdapter), transformResponse: buildTransformResponse<LoRAModelConfig>(loraModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}), }),
getControlNetModels: build.query<EntityState<ControlNetModelConfig, string>, void>({ getControlNetModels: build.query<EntityState<ControlNetModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }), query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }),
providesTags: buildProvidesTags<ControlNetModelConfig>('ControlNetModel'), providesTags: buildProvidesTags<ControlNetModelConfig>('ControlNetModel'),
transformResponse: buildTransformResponse<ControlNetModelConfig>(controlNetModelsAdapter), transformResponse: buildTransformResponse<ControlNetModelConfig>(controlNetModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}), }),
getIPAdapterModels: build.query<EntityState<IPAdapterModelConfig, string>, void>({ getIPAdapterModels: build.query<EntityState<IPAdapterModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }), query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }),
providesTags: buildProvidesTags<IPAdapterModelConfig>('IPAdapterModel'), providesTags: buildProvidesTags<IPAdapterModelConfig>('IPAdapterModel'),
transformResponse: buildTransformResponse<IPAdapterModelConfig>(ipAdapterModelsAdapter), transformResponse: buildTransformResponse<IPAdapterModelConfig>(ipAdapterModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}), }),
getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfig, string>, void>({ getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }), query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }),
providesTags: buildProvidesTags<T2IAdapterModelConfig>('T2IAdapterModel'), providesTags: buildProvidesTags<T2IAdapterModelConfig>('T2IAdapterModel'),
transformResponse: buildTransformResponse<T2IAdapterModelConfig>(t2iAdapterModelsAdapter), transformResponse: buildTransformResponse<T2IAdapterModelConfig>(t2iAdapterModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}), }),
getVaeModels: build.query<EntityState<VAEModelConfig, string>, void>({ getVaeModels: build.query<EntityState<VAEModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }), query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }),
providesTags: buildProvidesTags<VAEModelConfig>('VaeModel'), providesTags: buildProvidesTags<VAEModelConfig>('VaeModel'),
transformResponse: buildTransformResponse<VAEModelConfig>(vaeModelsAdapter), transformResponse: buildTransformResponse<VAEModelConfig>(vaeModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}), }),
getTextualInversionModels: build.query<EntityState<TextualInversionModelConfig, string>, void>({ getTextualInversionModels: build.query<EntityState<TextualInversionModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }), query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }),
providesTags: buildProvidesTags<TextualInversionModelConfig>('TextualInversionModel'), providesTags: buildProvidesTags<TextualInversionModelConfig>('TextualInversionModel'),
transformResponse: buildTransformResponse<TextualInversionModelConfig>(textualInversionModelsAdapter), transformResponse: buildTransformResponse<TextualInversionModelConfig>(textualInversionModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}), }),
scanModels: build.query<ScanFolderResponse, ScanFolderArg>({ scanModels: build.query<ScanFolderResponse, ScanFolderArg>({
query: (arg) => { query: (arg) => {
@ -336,3 +376,15 @@ export const {
useDeleteModelImportMutation, useDeleteModelImportMutation,
usePruneModelImportsMutation, usePruneModelImportsMutation,
} = modelsApi; } = modelsApi;
const upsertModelConfigs = (
modelConfigs: EntityState<AnyModelConfig, string>,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
dispatch: ThunkDispatch<any, any, UnknownAction>
) => {
anyModelConfigAdapterSelectors.selectAll(modelConfigs).forEach((modelConfig) => {
dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
const { base, name, type } = modelConfig;
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
});
};