From d069f21d97594199443f2272c12c013c1f616315 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 26 Feb 2024 13:27:13 +1100 Subject: [PATCH] feat(ui): optimize model query caching When we retrieve a list of models, upsert that data into the `getModelConfig` and `getModelConfigByAttrs` query caches. With this change, calls to those two queries are almost always going to be free, because their caches will already have all models in them. The exception is queries for models that no longer exist. --- .../web/src/services/api/endpoints/models.ts | 58 ++++++++++++++++++- 1 file changed, 55 insertions(+), 3 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 3ccccf62e1..2d1d021bc6 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -1,4 +1,4 @@ -import type { EntityAdapter, EntityState } from '@reduxjs/toolkit'; +import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit'; import { createEntityAdapter } from '@reduxjs/toolkit'; import { getSelectorsOptions } from 'app/store/createMemoizedSelector'; import queryString from 'query-string'; @@ -111,6 +111,12 @@ export const vaeModelsAdapter = createEntityAdapter({ }); export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions); +export const anyModelConfigAdapter = createEntityAdapter({ + selectId: (entity) => entity.key, + sortComparer: (a, b) => a.name.localeCompare(b.name), +}); +export const anyModelConfigAdapterSelectors = anyModelConfigAdapter.getSelectors(undefined, getSelectorsOptions); + const buildProvidesTags = (tagType: (typeof tagTypes)[number]) => (result: EntityState | undefined) => { @@ -141,8 +147,6 @@ const buildTransformResponse = */ const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`); -// TODO(psyche): Ideally we can share the cache between the `getXYZModels` queries and `getModelConfig` query - export const modelsApi = api.injectEndpoints({ endpoints: (build) => ({ getMainModels: build.query, BaseModelType[]>({ @@ -157,6 +161,11 @@ export const modelsApi = api.injectEndpoints({ }, providesTags: buildProvidesTags('MainModel'), transformResponse: buildTransformResponse(mainModelsAdapter), + onQueryStarted: async (_, { dispatch, queryFulfilled }) => { + queryFulfilled.then(({ data }) => { + upsertModelConfigs(data, dispatch); + }); + }, }), getModelMetadata: build.query({ query: (key) => { @@ -236,6 +245,7 @@ export const modelsApi = api.injectEndpoints({ return tags; }, + serializeQueryArgs: ({ queryArgs }) => `${queryArgs.name}.${queryArgs.base}.${queryArgs.type}`, }), syncModels: build.mutation({ query: () => { @@ -250,31 +260,61 @@ export const modelsApi = api.injectEndpoints({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }), providesTags: buildProvidesTags('LoRAModel'), transformResponse: buildTransformResponse(loraModelsAdapter), + onQueryStarted: async (_, { dispatch, queryFulfilled }) => { + queryFulfilled.then(({ data }) => { + upsertModelConfigs(data, dispatch); + }); + }, }), getControlNetModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }), providesTags: buildProvidesTags('ControlNetModel'), transformResponse: buildTransformResponse(controlNetModelsAdapter), + onQueryStarted: async (_, { dispatch, queryFulfilled }) => { + queryFulfilled.then(({ data }) => { + upsertModelConfigs(data, dispatch); + }); + }, }), getIPAdapterModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }), providesTags: buildProvidesTags('IPAdapterModel'), transformResponse: buildTransformResponse(ipAdapterModelsAdapter), + onQueryStarted: async (_, { dispatch, queryFulfilled }) => { + queryFulfilled.then(({ data }) => { + upsertModelConfigs(data, dispatch); + }); + }, }), getT2IAdapterModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }), providesTags: buildProvidesTags('T2IAdapterModel'), transformResponse: buildTransformResponse(t2iAdapterModelsAdapter), + onQueryStarted: async (_, { dispatch, queryFulfilled }) => { + queryFulfilled.then(({ data }) => { + upsertModelConfigs(data, dispatch); + }); + }, }), getVaeModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }), providesTags: buildProvidesTags('VaeModel'), transformResponse: buildTransformResponse(vaeModelsAdapter), + onQueryStarted: async (_, { dispatch, queryFulfilled }) => { + queryFulfilled.then(({ data }) => { + upsertModelConfigs(data, dispatch); + }); + }, }), getTextualInversionModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }), providesTags: buildProvidesTags('TextualInversionModel'), transformResponse: buildTransformResponse(textualInversionModelsAdapter), + onQueryStarted: async (_, { dispatch, queryFulfilled }) => { + queryFulfilled.then(({ data }) => { + upsertModelConfigs(data, dispatch); + }); + }, }), scanModels: build.query({ query: (arg) => { @@ -336,3 +376,15 @@ export const { useDeleteModelImportMutation, usePruneModelImportsMutation, } = modelsApi; + +const upsertModelConfigs = ( + modelConfigs: EntityState, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + dispatch: ThunkDispatch +) => { + anyModelConfigAdapterSelectors.selectAll(modelConfigs).forEach((modelConfig) => { + dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig)); + const { base, name, type } = modelConfig; + dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig)); + }); +};