feat(ui): optimize model query caching

When we retrieve a list of models, upsert that data into the `getModelConfig` and `getModelConfigByAttrs` query caches.

With this change, calls to those two queries are almost always going to be free, because their caches will already have all models in them. The exception is queries for models that no longer exist.
This commit is contained in:
psychedelicious 2024-02-26 13:27:13 +11:00 committed by Kent Keirsey
parent 0f19176944
commit 3c103c89f3

View File

@ -1,4 +1,4 @@
import type { EntityAdapter, EntityState } from '@reduxjs/toolkit'; import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
import { createEntityAdapter } from '@reduxjs/toolkit'; import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector'; import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import queryString from 'query-string'; import queryString from 'query-string';
@ -111,6 +111,12 @@ export const vaeModelsAdapter = createEntityAdapter<VAEModelConfig, string>({
}); });
export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions); export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const anyModelConfigAdapter = createEntityAdapter<AnyModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const anyModelConfigAdapterSelectors = anyModelConfigAdapter.getSelectors(undefined, getSelectorsOptions);
const buildProvidesTags = const buildProvidesTags =
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) => <TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
(result: EntityState<TEntity, string> | undefined) => { (result: EntityState<TEntity, string> | undefined) => {
@ -141,8 +147,6 @@ const buildTransformResponse =
*/ */
const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`); const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
// TODO(psyche): Ideally we can share the cache between the `getXYZModels` queries and `getModelConfig` query
export const modelsApi = api.injectEndpoints({ export const modelsApi = api.injectEndpoints({
endpoints: (build) => ({ endpoints: (build) => ({
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({ getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
@ -157,6 +161,11 @@ export const modelsApi = api.injectEndpoints({
}, },
providesTags: buildProvidesTags<MainModelConfig>('MainModel'), providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter), transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}), }),
getModelMetadata: build.query<GetModelMetadataResponse, string>({ getModelMetadata: build.query<GetModelMetadataResponse, string>({
query: (key) => { query: (key) => {
@ -236,6 +245,7 @@ export const modelsApi = api.injectEndpoints({
return tags; return tags;
}, },
serializeQueryArgs: ({ queryArgs }) => `${queryArgs.name}.${queryArgs.base}.${queryArgs.type}`,
}), }),
syncModels: build.mutation<void, void>({ syncModels: build.mutation<void, void>({
query: () => { query: () => {
@ -250,31 +260,61 @@ export const modelsApi = api.injectEndpoints({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }), query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'), providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
transformResponse: buildTransformResponse<LoRAModelConfig>(loraModelsAdapter), transformResponse: buildTransformResponse<LoRAModelConfig>(loraModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}), }),
getControlNetModels: build.query<EntityState<ControlNetModelConfig, string>, void>({ getControlNetModels: build.query<EntityState<ControlNetModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }), query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }),
providesTags: buildProvidesTags<ControlNetModelConfig>('ControlNetModel'), providesTags: buildProvidesTags<ControlNetModelConfig>('ControlNetModel'),
transformResponse: buildTransformResponse<ControlNetModelConfig>(controlNetModelsAdapter), transformResponse: buildTransformResponse<ControlNetModelConfig>(controlNetModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}), }),
getIPAdapterModels: build.query<EntityState<IPAdapterModelConfig, string>, void>({ getIPAdapterModels: build.query<EntityState<IPAdapterModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }), query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }),
providesTags: buildProvidesTags<IPAdapterModelConfig>('IPAdapterModel'), providesTags: buildProvidesTags<IPAdapterModelConfig>('IPAdapterModel'),
transformResponse: buildTransformResponse<IPAdapterModelConfig>(ipAdapterModelsAdapter), transformResponse: buildTransformResponse<IPAdapterModelConfig>(ipAdapterModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}), }),
getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfig, string>, void>({ getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }), query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }),
providesTags: buildProvidesTags<T2IAdapterModelConfig>('T2IAdapterModel'), providesTags: buildProvidesTags<T2IAdapterModelConfig>('T2IAdapterModel'),
transformResponse: buildTransformResponse<T2IAdapterModelConfig>(t2iAdapterModelsAdapter), transformResponse: buildTransformResponse<T2IAdapterModelConfig>(t2iAdapterModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}), }),
getVaeModels: build.query<EntityState<VAEModelConfig, string>, void>({ getVaeModels: build.query<EntityState<VAEModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }), query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }),
providesTags: buildProvidesTags<VAEModelConfig>('VaeModel'), providesTags: buildProvidesTags<VAEModelConfig>('VaeModel'),
transformResponse: buildTransformResponse<VAEModelConfig>(vaeModelsAdapter), transformResponse: buildTransformResponse<VAEModelConfig>(vaeModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}), }),
getTextualInversionModels: build.query<EntityState<TextualInversionModelConfig, string>, void>({ getTextualInversionModels: build.query<EntityState<TextualInversionModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }), query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }),
providesTags: buildProvidesTags<TextualInversionModelConfig>('TextualInversionModel'), providesTags: buildProvidesTags<TextualInversionModelConfig>('TextualInversionModel'),
transformResponse: buildTransformResponse<TextualInversionModelConfig>(textualInversionModelsAdapter), transformResponse: buildTransformResponse<TextualInversionModelConfig>(textualInversionModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}), }),
scanModels: build.query<ScanFolderResponse, ScanFolderArg>({ scanModels: build.query<ScanFolderResponse, ScanFolderArg>({
query: (arg) => { query: (arg) => {
@ -336,3 +376,15 @@ export const {
useDeleteModelImportMutation, useDeleteModelImportMutation,
usePruneModelImportsMutation, usePruneModelImportsMutation,
} = modelsApi; } = modelsApi;
const upsertModelConfigs = (
modelConfigs: EntityState<AnyModelConfig, string>,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
dispatch: ThunkDispatch<any, any, UnknownAction>
) => {
anyModelConfigAdapterSelectors.selectAll(modelConfigs).forEach((modelConfig) => {
dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
const { base, name, type } = modelConfig;
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
});
};