From a681fa4b0394f18c8395fdc00f9af66d1e32eac3 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Sat, 7 Oct 2023 22:21:38 +1100
Subject: [PATCH] fix(ui): invalidate query cache for all models on sync models

Also realised the tags were set up incorrectly, fixed that to get type safety with tags.
---
 .../web/src/services/api/endpoints/models.ts  | 81 +++++++------------
 .../frontend/web/src/services/api/index.ts    | 18 ++++-
 2 files changed, 43 insertions(+), 56 deletions(-)

diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts
index 65cb151818..e095bce8ca 100644
--- a/invokeai/frontend/web/src/services/api/endpoints/models.ts
+++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts
@@ -199,7 +199,10 @@ export const modelsApi = api.injectEndpoints({
         return `models/?${query}`;
       },
       providesTags: (result) => {
-        const tags: ApiTagDescription[] = [{ type: 'OnnxModel', id: LIST_TAG }];
+        const tags: ApiTagDescription[] = [
+          { type: 'OnnxModel', id: LIST_TAG },
+          'Model',
+        ];
 
         if (result) {
           tags.push(
@@ -236,7 +239,10 @@ export const modelsApi = api.injectEndpoints({
         return `models/?${query}`;
       },
       providesTags: (result) => {
-        const tags: ApiTagDescription[] = [{ type: 'MainModel', id: LIST_TAG }];
+        const tags: ApiTagDescription[] = [
+          { type: 'MainModel', id: LIST_TAG },
+          'Model',
+        ];
 
         if (result) {
           tags.push(
@@ -270,11 +276,7 @@ export const modelsApi = api.injectEndpoints({
           body: body,
         };
       },
-      invalidatesTags: [
-        { type: 'MainModel', id: LIST_TAG },
-        { type: 'SDXLRefinerModel', id: LIST_TAG },
-        { type: 'OnnxModel', id: LIST_TAG },
-      ],
+      invalidatesTags: ['Model'],
     }),
     importMainModels: build.mutation<
       ImportMainModelResponse,
@@ -287,11 +289,7 @@ export const modelsApi = api.injectEndpoints({
           body: body,
         };
       },
-      invalidatesTags: [
-        { type: 'MainModel', id: LIST_TAG },
-        { type: 'SDXLRefinerModel', id: LIST_TAG },
-        { type: 'OnnxModel', id: LIST_TAG },
-      ],
+      invalidatesTags: ['Model'],
     }),
     addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
       query: ({ body }) => {
@@ -301,11 +299,7 @@ export const modelsApi = api.injectEndpoints({
           body: body,
         };
       },
-      invalidatesTags: [
-        { type: 'MainModel', id: LIST_TAG },
-        { type: 'SDXLRefinerModel', id: LIST_TAG },
-        { type: 'OnnxModel', id: LIST_TAG },
-      ],
+      invalidatesTags: ['Model'],
     }),
     deleteMainModels: build.mutation<
       DeleteMainModelResponse,
@@ -317,11 +311,7 @@ export const modelsApi = api.injectEndpoints({
           method: 'DELETE',
         };
       },
-      invalidatesTags: [
-        { type: 'MainModel', id: LIST_TAG },
-        { type: 'SDXLRefinerModel', id: LIST_TAG },
-        { type: 'OnnxModel', id: LIST_TAG },
-      ],
+      invalidatesTags: ['Model'],
     }),
     convertMainModels: build.mutation<
       ConvertMainModelResponse,
@@ -334,11 +324,7 @@ export const modelsApi = api.injectEndpoints({
           params: { convert_dest_directory },
         };
       },
-      invalidatesTags: [
-        { type: 'MainModel', id: LIST_TAG },
-        { type: 'SDXLRefinerModel', id: LIST_TAG },
-        { type: 'OnnxModel', id: LIST_TAG },
-      ],
+      invalidatesTags: ['Model'],
     }),
     mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
       query: ({ base_model, body }) => {
@@ -348,11 +334,7 @@ export const modelsApi = api.injectEndpoints({
           body: body,
         };
       },
-      invalidatesTags: [
-        { type: 'MainModel', id: LIST_TAG },
-        { type: 'SDXLRefinerModel', id: LIST_TAG },
-        { type: 'OnnxModel', id: LIST_TAG },
-      ],
+      invalidatesTags: ['Model'],
     }),
     syncModels: build.mutation<SyncModelsResponse, void>({
       query: () => {
@@ -361,16 +343,15 @@ export const modelsApi = api.injectEndpoints({
           method: 'POST',
         };
       },
-      invalidatesTags: [
-        { type: 'MainModel', id: LIST_TAG },
-        { type: 'SDXLRefinerModel', id: LIST_TAG },
-        { type: 'OnnxModel', id: LIST_TAG },
-      ],
+      invalidatesTags: ['Model'],
     }),
     getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
       query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
       providesTags: (result) => {
-        const tags: ApiTagDescription[] = [{ type: 'LoRAModel', id: LIST_TAG }];
+        const tags: ApiTagDescription[] = [
+          { type: 'LoRAModel', id: LIST_TAG },
+          'Model',
+        ];
 
         if (result) {
           tags.push(
@@ -426,6 +407,7 @@ export const modelsApi = api.injectEndpoints({
       providesTags: (result) => {
         const tags: ApiTagDescription[] = [
           { type: 'ControlNetModel', id: LIST_TAG },
+          'Model',
         ];
 
         if (result) {
@@ -457,6 +439,7 @@ export const modelsApi = api.injectEndpoints({
       providesTags: (result) => {
         const tags: ApiTagDescription[] = [
           { type: 'IPAdapterModel', id: LIST_TAG },
+          'Model',
         ];
 
         if (result) {
@@ -488,6 +471,7 @@ export const modelsApi = api.injectEndpoints({
       providesTags: (result) => {
         const tags: ApiTagDescription[] = [
           { type: 'T2IAdapterModel', id: LIST_TAG },
+          'Model',
         ];
 
         if (result) {
@@ -514,7 +498,10 @@ export const modelsApi = api.injectEndpoints({
     getVaeModels: build.query<EntityState<VaeModelConfigEntity>, void>({
       query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
       providesTags: (result) => {
-        const tags: ApiTagDescription[] = [{ type: 'VaeModel', id: LIST_TAG }];
+        const tags: ApiTagDescription[] = [
+          { type: 'VaeModel', id: LIST_TAG },
+          'Model',
+        ];
 
         if (result) {
           tags.push(
@@ -545,6 +532,7 @@ export const modelsApi = api.injectEndpoints({
       providesTags: (result) => {
         const tags: ApiTagDescription[] = [
           { type: 'TextualInversionModel', id: LIST_TAG },
+          'Model',
         ];
 
         if (result) {
@@ -577,21 +565,6 @@ export const modelsApi = api.injectEndpoints({
           url: `/models/search?${folderQueryStr}`,
         };
       },
-      providesTags: (result) => {
-        const tags: ApiTagDescription[] = [
-          { type: 'ScannedModels', id: LIST_TAG },
-        ];
-
-        if (result) {
-          tags.push(
-            ...result.map((id) => ({
-              type: 'ScannedModels' as const,
-              id,
-            }))
-          );
-        }
-        return tags;
-      },
     }),
     getCheckpointConfigs: build.query<CheckpointConfigsResponse, void>({
       query: () => {
diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts
index b39b11af29..7a10377323 100644
--- a/invokeai/frontend/web/src/services/api/index.ts
+++ b/invokeai/frontend/web/src/services/api/index.ts
@@ -9,6 +9,8 @@ import {
 import { $authToken, $baseUrl, $projectId } from 'services/api/client';
 
 export const tagTypes = [
+  'AppVersion',
+  'AppConfig',
   'Board',
   'BoardImagesTotal',
   'BoardAssetsTotal',
@@ -17,15 +19,27 @@ export const tagTypes = [
   'ImageList',
   'ImageMetadata',
   'ImageMetadataFromFile',
-  'Model',
+  'IntermediatesCount',
   'SessionQueueItem',
   'SessionQueueItemDTO',
   'SessionQueueItemDTOList',
   'SessionQueueStatus',
   'SessionProcessorStatus',
+  'CurrentSessionQueueItem',
+  'NextSessionQueueItem',
   'BatchStatus',
   'InvocationCacheStatus',
-];
+  'Model',
+  'T2IAdapterModel',
+  'MainModel',
+  'OnnxModel',
+  'VaeModel',
+  'IPAdapterModel',
+  'TextualInversionModel',
+  'ControlNetModel',
+  'LoRAModel',
+  'SDXLRefinerModel',
+] as const;
 export type ApiTagDescription = TagDescription<(typeof tagTypes)[number]>;
 export const LIST_TAG = 'LIST';