From a681fa4b0394f18c8395fdc00f9af66d1e32eac3 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 7 Oct 2023 22:21:38 +1100 Subject: [PATCH] fix(ui): invalidate query cache for all models on sync models Also realised the tags were set up incorrectly, fixed that to get type safety with tags. --- .../web/src/services/api/endpoints/models.ts | 81 +++++++------------ .../frontend/web/src/services/api/index.ts | 18 ++++- 2 files changed, 43 insertions(+), 56 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 65cb151818..e095bce8ca 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -199,7 +199,10 @@ export const modelsApi = api.injectEndpoints({ return `models/?${query}`; }, providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'OnnxModel', id: LIST_TAG }]; + const tags: ApiTagDescription[] = [ + { type: 'OnnxModel', id: LIST_TAG }, + 'Model', + ]; if (result) { tags.push( @@ -236,7 +239,10 @@ export const modelsApi = api.injectEndpoints({ return `models/?${query}`; }, providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'MainModel', id: LIST_TAG }]; + const tags: ApiTagDescription[] = [ + { type: 'MainModel', id: LIST_TAG }, + 'Model', + ]; if (result) { tags.push( @@ -270,11 +276,7 @@ export const modelsApi = api.injectEndpoints({ body: body, }; }, - invalidatesTags: [ - { type: 'MainModel', id: LIST_TAG }, - { type: 'SDXLRefinerModel', id: LIST_TAG }, - { type: 'OnnxModel', id: LIST_TAG }, - ], + invalidatesTags: ['Model'], }), importMainModels: build.mutation< ImportMainModelResponse, @@ -287,11 +289,7 @@ export const modelsApi = api.injectEndpoints({ body: body, }; }, - invalidatesTags: [ - { type: 'MainModel', id: LIST_TAG }, - { type: 'SDXLRefinerModel', id: LIST_TAG }, - { type: 'OnnxModel', id: LIST_TAG }, - ], + invalidatesTags: ['Model'], }), addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({ query: ({ body }) => { @@ -301,11 +299,7 @@ export const modelsApi = api.injectEndpoints({ body: body, }; }, - invalidatesTags: [ - { type: 'MainModel', id: LIST_TAG }, - { type: 'SDXLRefinerModel', id: LIST_TAG }, - { type: 'OnnxModel', id: LIST_TAG }, - ], + invalidatesTags: ['Model'], }), deleteMainModels: build.mutation< DeleteMainModelResponse, @@ -317,11 +311,7 @@ export const modelsApi = api.injectEndpoints({ method: 'DELETE', }; }, - invalidatesTags: [ - { type: 'MainModel', id: LIST_TAG }, - { type: 'SDXLRefinerModel', id: LIST_TAG }, - { type: 'OnnxModel', id: LIST_TAG }, - ], + invalidatesTags: ['Model'], }), convertMainModels: build.mutation< ConvertMainModelResponse, @@ -334,11 +324,7 @@ export const modelsApi = api.injectEndpoints({ params: { convert_dest_directory }, }; }, - invalidatesTags: [ - { type: 'MainModel', id: LIST_TAG }, - { type: 'SDXLRefinerModel', id: LIST_TAG }, - { type: 'OnnxModel', id: LIST_TAG }, - ], + invalidatesTags: ['Model'], }), mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({ query: ({ base_model, body }) => { @@ -348,11 +334,7 @@ export const modelsApi = api.injectEndpoints({ body: body, }; }, - invalidatesTags: [ - { type: 'MainModel', id: LIST_TAG }, - { type: 'SDXLRefinerModel', id: LIST_TAG }, - { type: 'OnnxModel', id: LIST_TAG }, - ], + invalidatesTags: ['Model'], }), syncModels: build.mutation<SyncModelsResponse, void>({ query: () => { @@ -361,16 +343,15 @@ export const modelsApi = api.injectEndpoints({ method: 'POST', }; }, - invalidatesTags: [ - { type: 'MainModel', id: LIST_TAG }, - { type: 'SDXLRefinerModel', id: LIST_TAG }, - { type: 'OnnxModel', id: LIST_TAG }, - ], + invalidatesTags: ['Model'], }), getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({ query: () => ({ url: 'models/', params: { model_type: 'lora' } }), providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'LoRAModel', id: LIST_TAG }]; + const tags: ApiTagDescription[] = [ + { type: 'LoRAModel', id: LIST_TAG }, + 'Model', + ]; if (result) { tags.push( @@ -426,6 +407,7 @@ export const modelsApi = api.injectEndpoints({ providesTags: (result) => { const tags: ApiTagDescription[] = [ { type: 'ControlNetModel', id: LIST_TAG }, + 'Model', ]; if (result) { @@ -457,6 +439,7 @@ export const modelsApi = api.injectEndpoints({ providesTags: (result) => { const tags: ApiTagDescription[] = [ { type: 'IPAdapterModel', id: LIST_TAG }, + 'Model', ]; if (result) { @@ -488,6 +471,7 @@ export const modelsApi = api.injectEndpoints({ providesTags: (result) => { const tags: ApiTagDescription[] = [ { type: 'T2IAdapterModel', id: LIST_TAG }, + 'Model', ]; if (result) { @@ -514,7 +498,10 @@ export const modelsApi = api.injectEndpoints({ getVaeModels: build.query<EntityState<VaeModelConfigEntity>, void>({ query: () => ({ url: 'models/', params: { model_type: 'vae' } }), providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'VaeModel', id: LIST_TAG }]; + const tags: ApiTagDescription[] = [ + { type: 'VaeModel', id: LIST_TAG }, + 'Model', + ]; if (result) { tags.push( @@ -545,6 +532,7 @@ export const modelsApi = api.injectEndpoints({ providesTags: (result) => { const tags: ApiTagDescription[] = [ { type: 'TextualInversionModel', id: LIST_TAG }, + 'Model', ]; if (result) { @@ -577,21 +565,6 @@ export const modelsApi = api.injectEndpoints({ url: `/models/search?${folderQueryStr}`, }; }, - providesTags: (result) => { - const tags: ApiTagDescription[] = [ - { type: 'ScannedModels', id: LIST_TAG }, - ]; - - if (result) { - tags.push( - ...result.map((id) => ({ - type: 'ScannedModels' as const, - id, - })) - ); - } - return tags; - }, }), getCheckpointConfigs: build.query<CheckpointConfigsResponse, void>({ query: () => { diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts index b39b11af29..7a10377323 100644 --- a/invokeai/frontend/web/src/services/api/index.ts +++ b/invokeai/frontend/web/src/services/api/index.ts @@ -9,6 +9,8 @@ import { import { $authToken, $baseUrl, $projectId } from 'services/api/client'; export const tagTypes = [ + 'AppVersion', + 'AppConfig', 'Board', 'BoardImagesTotal', 'BoardAssetsTotal', @@ -17,15 +19,27 @@ export const tagTypes = [ 'ImageList', 'ImageMetadata', 'ImageMetadataFromFile', - 'Model', + 'IntermediatesCount', 'SessionQueueItem', 'SessionQueueItemDTO', 'SessionQueueItemDTOList', 'SessionQueueStatus', 'SessionProcessorStatus', + 'CurrentSessionQueueItem', + 'NextSessionQueueItem', 'BatchStatus', 'InvocationCacheStatus', -]; + 'Model', + 'T2IAdapterModel', + 'MainModel', + 'OnnxModel', + 'VaeModel', + 'IPAdapterModel', + 'TextualInversionModel', + 'ControlNetModel', + 'LoRAModel', + 'SDXLRefinerModel', +] as const; export type ApiTagDescription = TagDescription<(typeof tagTypes)[number]>; export const LIST_TAG = 'LIST';