From 7176c5d9d6a5a4d33013fbdbcbdb2703fea5ec14 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 26 Feb 2024 13:27:13 +1100 Subject: [PATCH] feat(ui): optimize model query caching When we retrieve a list of models, upsert that data into the `getModelConfig` and `getModelConfigByAttrs` query caches. With this change, calls to those two queries are almost always going to be free, because their caches will already have all models in them. The exception is queries for models that no longer exist. --- .../web/src/services/api/endpoints/models.ts | 58 ++++++++++++++++++- 1 file changed, 55 insertions(+), 3 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 3ccccf62e1..2d1d021bc6 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -1,4 +1,4 @@ -import type { EntityAdapter, EntityState } from '@reduxjs/toolkit'; +import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit'; import { createEntityAdapter } from '@reduxjs/toolkit'; import { getSelectorsOptions } from 'app/store/createMemoizedSelector'; import queryString from 'query-string'; @@ -111,6 +111,12 @@ export const vaeModelsAdapter = createEntityAdapter({ }); export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions); +export const anyModelConfigAdapter = createEntityAdapter({ + selectId: (entity) => entity.key, + sortComparer: (a, b) => a.name.localeCompare(b.name), +}); +export const anyModelConfigAdapterSelectors = anyModelConfigAdapter.getSelectors(undefined, getSelectorsOptions); + const buildProvidesTags = (tagType: (typeof tagTypes)[number]) => (result: EntityState | undefined) => { @@ -141,8 +147,6 @@ const buildTransformResponse = */ const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`); -// TODO(psyche): Ideally we can share the cache between the `getXYZModels` queries and `getModelConfig` query - export const modelsApi = api.injectEndpoints({ endpoints: (build) => ({ getMainModels: build.query, BaseModelType[]>({ @@ -157,6 +161,11 @@ export const modelsApi = api.injectEndpoints({ }, providesTags: buildProvidesTags('MainModel'), transformResponse: buildTransformResponse(mainModelsAdapter), + onQueryStarted: async (_, { dispatch, queryFulfilled }) => { + queryFulfilled.then(({ data }) => { + upsertModelConfigs(data, dispatch); + }); + }, }), getModelMetadata: build.query({ query: (key) => { @@ -236,6 +245,7 @@ export const modelsApi = api.injectEndpoints({ return tags; }, + serializeQueryArgs: ({ queryArgs }) => `${queryArgs.name}.${queryArgs.base}.${queryArgs.type}`, }), syncModels: build.mutation({ query: () => { @@ -250,31 +260,61 @@ export const modelsApi = api.injectEndpoints({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }), providesTags: buildProvidesTags('LoRAModel'), transformResponse: buildTransformResponse(loraModelsAdapter), + onQueryStarted: async (_, { dispatch, queryFulfilled }) => { + queryFulfilled.then(({ data }) => { + upsertModelConfigs(data, dispatch); + }); + }, }), getControlNetModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }), providesTags: buildProvidesTags('ControlNetModel'), transformResponse: buildTransformResponse(controlNetModelsAdapter), + onQueryStarted: async (_, { dispatch, queryFulfilled }) => { + queryFulfilled.then(({ data }) => { + upsertModelConfigs(data, dispatch); + }); + }, }), getIPAdapterModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }), providesTags: buildProvidesTags('IPAdapterModel'), transformResponse: buildTransformResponse(ipAdapterModelsAdapter), + onQueryStarted: async (_, { dispatch, queryFulfilled }) => { + queryFulfilled.then(({ data }) => { + upsertModelConfigs(data, dispatch); + }); + }, }), getT2IAdapterModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }), providesTags: buildProvidesTags('T2IAdapterModel'), transformResponse: buildTransformResponse(t2iAdapterModelsAdapter), + onQueryStarted: async (_, { dispatch, queryFulfilled }) => { + queryFulfilled.then(({ data }) => { + upsertModelConfigs(data, dispatch); + }); + }, }), getVaeModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }), providesTags: buildProvidesTags('VaeModel'), transformResponse: buildTransformResponse(vaeModelsAdapter), + onQueryStarted: async (_, { dispatch, queryFulfilled }) => { + queryFulfilled.then(({ data }) => { + upsertModelConfigs(data, dispatch); + }); + }, }), getTextualInversionModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }), providesTags: buildProvidesTags('TextualInversionModel'), transformResponse: buildTransformResponse(textualInversionModelsAdapter), + onQueryStarted: async (_, { dispatch, queryFulfilled }) => { + queryFulfilled.then(({ data }) => { + upsertModelConfigs(data, dispatch); + }); + }, }), scanModels: build.query({ query: (arg) => { @@ -336,3 +376,15 @@ export const { useDeleteModelImportMutation, usePruneModelImportsMutation, } = modelsApi; + +const upsertModelConfigs = ( + modelConfigs: EntityState, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + dispatch: ThunkDispatch +) => { + anyModelConfigAdapterSelectors.selectAll(modelConfigs).forEach((modelConfig) => { + dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig)); + const { base, name, type } = modelConfig; + dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig)); + }); +};