diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 3ccccf62e1..2d1d021bc6 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -1,4 +1,4 @@ -import type { EntityAdapter, EntityState } from '@reduxjs/toolkit'; +import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit'; import { createEntityAdapter } from '@reduxjs/toolkit'; import { getSelectorsOptions } from 'app/store/createMemoizedSelector'; import queryString from 'query-string'; @@ -111,6 +111,12 @@ export const vaeModelsAdapter = createEntityAdapter({ }); export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions); +export const anyModelConfigAdapter = createEntityAdapter({ + selectId: (entity) => entity.key, + sortComparer: (a, b) => a.name.localeCompare(b.name), +}); +export const anyModelConfigAdapterSelectors = anyModelConfigAdapter.getSelectors(undefined, getSelectorsOptions); + const buildProvidesTags = (tagType: (typeof tagTypes)[number]) => (result: EntityState | undefined) => { @@ -141,8 +147,6 @@ const buildTransformResponse = */ const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`); -// TODO(psyche): Ideally we can share the cache between the `getXYZModels` queries and `getModelConfig` query - export const modelsApi = api.injectEndpoints({ endpoints: (build) => ({ getMainModels: build.query, BaseModelType[]>({ @@ -157,6 +161,11 @@ export const modelsApi = api.injectEndpoints({ }, providesTags: buildProvidesTags('MainModel'), transformResponse: buildTransformResponse(mainModelsAdapter), + onQueryStarted: async (_, { dispatch, queryFulfilled }) => { + queryFulfilled.then(({ data }) => { + upsertModelConfigs(data, dispatch); + }); + }, }), getModelMetadata: build.query({ query: (key) => { @@ -236,6 +245,7 @@ export const modelsApi = api.injectEndpoints({ return tags; }, + serializeQueryArgs: ({ queryArgs }) => `${queryArgs.name}.${queryArgs.base}.${queryArgs.type}`, }), syncModels: build.mutation({ query: () => { @@ -250,31 +260,61 @@ export const modelsApi = api.injectEndpoints({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }), providesTags: buildProvidesTags('LoRAModel'), transformResponse: buildTransformResponse(loraModelsAdapter), + onQueryStarted: async (_, { dispatch, queryFulfilled }) => { + queryFulfilled.then(({ data }) => { + upsertModelConfigs(data, dispatch); + }); + }, }), getControlNetModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }), providesTags: buildProvidesTags('ControlNetModel'), transformResponse: buildTransformResponse(controlNetModelsAdapter), + onQueryStarted: async (_, { dispatch, queryFulfilled }) => { + queryFulfilled.then(({ data }) => { + upsertModelConfigs(data, dispatch); + }); + }, }), getIPAdapterModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }), providesTags: buildProvidesTags('IPAdapterModel'), transformResponse: buildTransformResponse(ipAdapterModelsAdapter), + onQueryStarted: async (_, { dispatch, queryFulfilled }) => { + queryFulfilled.then(({ data }) => { + upsertModelConfigs(data, dispatch); + }); + }, }), getT2IAdapterModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }), providesTags: buildProvidesTags('T2IAdapterModel'), transformResponse: buildTransformResponse(t2iAdapterModelsAdapter), + onQueryStarted: async (_, { dispatch, queryFulfilled }) => { + queryFulfilled.then(({ data }) => { + upsertModelConfigs(data, dispatch); + }); + }, }), getVaeModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }), providesTags: buildProvidesTags('VaeModel'), transformResponse: buildTransformResponse(vaeModelsAdapter), + onQueryStarted: async (_, { dispatch, queryFulfilled }) => { + queryFulfilled.then(({ data }) => { + upsertModelConfigs(data, dispatch); + }); + }, }), getTextualInversionModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }), providesTags: buildProvidesTags('TextualInversionModel'), transformResponse: buildTransformResponse(textualInversionModelsAdapter), + onQueryStarted: async (_, { dispatch, queryFulfilled }) => { + queryFulfilled.then(({ data }) => { + upsertModelConfigs(data, dispatch); + }); + }, }), scanModels: build.query({ query: (arg) => { @@ -336,3 +376,15 @@ export const { useDeleteModelImportMutation, usePruneModelImportsMutation, } = modelsApi; + +const upsertModelConfigs = ( + modelConfigs: EntityState, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + dispatch: ThunkDispatch +) => { + anyModelConfigAdapterSelectors.selectAll(modelConfigs).forEach((modelConfig) => { + dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig)); + const { base, name, type } = modelConfig; + dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig)); + }); +};