mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): optimize model query caching
When we retrieve a list of models, upsert that data into the `getModelConfig` and `getModelConfigByAttrs` query caches. With this change, calls to those two queries are almost always going to be free, because their caches will already have all models in them. The exception is queries for models that no longer exist.
This commit is contained in:
parent
7d23120c2e
commit
d069f21d97
@ -1,4 +1,4 @@
|
|||||||
import type { EntityAdapter, EntityState } from '@reduxjs/toolkit';
|
import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
|
||||||
import { createEntityAdapter } from '@reduxjs/toolkit';
|
import { createEntityAdapter } from '@reduxjs/toolkit';
|
||||||
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
||||||
import queryString from 'query-string';
|
import queryString from 'query-string';
|
||||||
@ -111,6 +111,12 @@ export const vaeModelsAdapter = createEntityAdapter<VAEModelConfig, string>({
|
|||||||
});
|
});
|
||||||
export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||||
|
|
||||||
|
export const anyModelConfigAdapter = createEntityAdapter<AnyModelConfig, string>({
|
||||||
|
selectId: (entity) => entity.key,
|
||||||
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||||
|
});
|
||||||
|
export const anyModelConfigAdapterSelectors = anyModelConfigAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||||
|
|
||||||
const buildProvidesTags =
|
const buildProvidesTags =
|
||||||
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
|
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
|
||||||
(result: EntityState<TEntity, string> | undefined) => {
|
(result: EntityState<TEntity, string> | undefined) => {
|
||||||
@ -141,8 +147,6 @@ const buildTransformResponse =
|
|||||||
*/
|
*/
|
||||||
const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
|
const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
|
||||||
|
|
||||||
// TODO(psyche): Ideally we can share the cache between the `getXYZModels` queries and `getModelConfig` query
|
|
||||||
|
|
||||||
export const modelsApi = api.injectEndpoints({
|
export const modelsApi = api.injectEndpoints({
|
||||||
endpoints: (build) => ({
|
endpoints: (build) => ({
|
||||||
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
|
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
|
||||||
@ -157,6 +161,11 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
},
|
},
|
||||||
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
|
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
|
||||||
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
|
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
|
||||||
|
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
||||||
|
queryFulfilled.then(({ data }) => {
|
||||||
|
upsertModelConfigs(data, dispatch);
|
||||||
|
});
|
||||||
|
},
|
||||||
}),
|
}),
|
||||||
getModelMetadata: build.query<GetModelMetadataResponse, string>({
|
getModelMetadata: build.query<GetModelMetadataResponse, string>({
|
||||||
query: (key) => {
|
query: (key) => {
|
||||||
@ -236,6 +245,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
|
|
||||||
return tags;
|
return tags;
|
||||||
},
|
},
|
||||||
|
serializeQueryArgs: ({ queryArgs }) => `${queryArgs.name}.${queryArgs.base}.${queryArgs.type}`,
|
||||||
}),
|
}),
|
||||||
syncModels: build.mutation<void, void>({
|
syncModels: build.mutation<void, void>({
|
||||||
query: () => {
|
query: () => {
|
||||||
@ -250,31 +260,61 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
|
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
|
||||||
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
|
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
|
||||||
transformResponse: buildTransformResponse<LoRAModelConfig>(loraModelsAdapter),
|
transformResponse: buildTransformResponse<LoRAModelConfig>(loraModelsAdapter),
|
||||||
|
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
||||||
|
queryFulfilled.then(({ data }) => {
|
||||||
|
upsertModelConfigs(data, dispatch);
|
||||||
|
});
|
||||||
|
},
|
||||||
}),
|
}),
|
||||||
getControlNetModels: build.query<EntityState<ControlNetModelConfig, string>, void>({
|
getControlNetModels: build.query<EntityState<ControlNetModelConfig, string>, void>({
|
||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }),
|
query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }),
|
||||||
providesTags: buildProvidesTags<ControlNetModelConfig>('ControlNetModel'),
|
providesTags: buildProvidesTags<ControlNetModelConfig>('ControlNetModel'),
|
||||||
transformResponse: buildTransformResponse<ControlNetModelConfig>(controlNetModelsAdapter),
|
transformResponse: buildTransformResponse<ControlNetModelConfig>(controlNetModelsAdapter),
|
||||||
|
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
||||||
|
queryFulfilled.then(({ data }) => {
|
||||||
|
upsertModelConfigs(data, dispatch);
|
||||||
|
});
|
||||||
|
},
|
||||||
}),
|
}),
|
||||||
getIPAdapterModels: build.query<EntityState<IPAdapterModelConfig, string>, void>({
|
getIPAdapterModels: build.query<EntityState<IPAdapterModelConfig, string>, void>({
|
||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }),
|
query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }),
|
||||||
providesTags: buildProvidesTags<IPAdapterModelConfig>('IPAdapterModel'),
|
providesTags: buildProvidesTags<IPAdapterModelConfig>('IPAdapterModel'),
|
||||||
transformResponse: buildTransformResponse<IPAdapterModelConfig>(ipAdapterModelsAdapter),
|
transformResponse: buildTransformResponse<IPAdapterModelConfig>(ipAdapterModelsAdapter),
|
||||||
|
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
||||||
|
queryFulfilled.then(({ data }) => {
|
||||||
|
upsertModelConfigs(data, dispatch);
|
||||||
|
});
|
||||||
|
},
|
||||||
}),
|
}),
|
||||||
getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfig, string>, void>({
|
getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfig, string>, void>({
|
||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }),
|
query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }),
|
||||||
providesTags: buildProvidesTags<T2IAdapterModelConfig>('T2IAdapterModel'),
|
providesTags: buildProvidesTags<T2IAdapterModelConfig>('T2IAdapterModel'),
|
||||||
transformResponse: buildTransformResponse<T2IAdapterModelConfig>(t2iAdapterModelsAdapter),
|
transformResponse: buildTransformResponse<T2IAdapterModelConfig>(t2iAdapterModelsAdapter),
|
||||||
|
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
||||||
|
queryFulfilled.then(({ data }) => {
|
||||||
|
upsertModelConfigs(data, dispatch);
|
||||||
|
});
|
||||||
|
},
|
||||||
}),
|
}),
|
||||||
getVaeModels: build.query<EntityState<VAEModelConfig, string>, void>({
|
getVaeModels: build.query<EntityState<VAEModelConfig, string>, void>({
|
||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }),
|
query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }),
|
||||||
providesTags: buildProvidesTags<VAEModelConfig>('VaeModel'),
|
providesTags: buildProvidesTags<VAEModelConfig>('VaeModel'),
|
||||||
transformResponse: buildTransformResponse<VAEModelConfig>(vaeModelsAdapter),
|
transformResponse: buildTransformResponse<VAEModelConfig>(vaeModelsAdapter),
|
||||||
|
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
||||||
|
queryFulfilled.then(({ data }) => {
|
||||||
|
upsertModelConfigs(data, dispatch);
|
||||||
|
});
|
||||||
|
},
|
||||||
}),
|
}),
|
||||||
getTextualInversionModels: build.query<EntityState<TextualInversionModelConfig, string>, void>({
|
getTextualInversionModels: build.query<EntityState<TextualInversionModelConfig, string>, void>({
|
||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }),
|
query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }),
|
||||||
providesTags: buildProvidesTags<TextualInversionModelConfig>('TextualInversionModel'),
|
providesTags: buildProvidesTags<TextualInversionModelConfig>('TextualInversionModel'),
|
||||||
transformResponse: buildTransformResponse<TextualInversionModelConfig>(textualInversionModelsAdapter),
|
transformResponse: buildTransformResponse<TextualInversionModelConfig>(textualInversionModelsAdapter),
|
||||||
|
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
||||||
|
queryFulfilled.then(({ data }) => {
|
||||||
|
upsertModelConfigs(data, dispatch);
|
||||||
|
});
|
||||||
|
},
|
||||||
}),
|
}),
|
||||||
scanModels: build.query<ScanFolderResponse, ScanFolderArg>({
|
scanModels: build.query<ScanFolderResponse, ScanFolderArg>({
|
||||||
query: (arg) => {
|
query: (arg) => {
|
||||||
@ -336,3 +376,15 @@ export const {
|
|||||||
useDeleteModelImportMutation,
|
useDeleteModelImportMutation,
|
||||||
usePruneModelImportsMutation,
|
usePruneModelImportsMutation,
|
||||||
} = modelsApi;
|
} = modelsApi;
|
||||||
|
|
||||||
|
const upsertModelConfigs = (
|
||||||
|
modelConfigs: EntityState<AnyModelConfig, string>,
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
|
dispatch: ThunkDispatch<any, any, UnknownAction>
|
||||||
|
) => {
|
||||||
|
anyModelConfigAdapterSelectors.selectAll(modelConfigs).forEach((modelConfig) => {
|
||||||
|
dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
|
||||||
|
const { base, name, type } = modelConfig;
|
||||||
|
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
|
||||||
|
});
|
||||||
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user