fix(ui): invalidate query cache for all models on sync models

Also realised the tags were set up incorrectly, fixed that to get type safety with tags.
This commit is contained in:
psychedelicious 2023-10-07 22:21:38 +11:00
parent 1cc686734b
commit a681fa4b03
2 changed files with 43 additions and 56 deletions

View File

@ -199,7 +199,10 @@ export const modelsApi = api.injectEndpoints({
return `models/?${query}`;
},
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'OnnxModel', id: LIST_TAG }];
const tags: ApiTagDescription[] = [
{ type: 'OnnxModel', id: LIST_TAG },
'Model',
];
if (result) {
tags.push(
@ -236,7 +239,10 @@ export const modelsApi = api.injectEndpoints({
return `models/?${query}`;
},
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'MainModel', id: LIST_TAG }];
const tags: ApiTagDescription[] = [
{ type: 'MainModel', id: LIST_TAG },
'Model',
];
if (result) {
tags.push(
@ -270,11 +276,7 @@ export const modelsApi = api.injectEndpoints({
body: body,
};
},
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
invalidatesTags: ['Model'],
}),
importMainModels: build.mutation<
ImportMainModelResponse,
@ -287,11 +289,7 @@ export const modelsApi = api.injectEndpoints({
body: body,
};
},
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
invalidatesTags: ['Model'],
}),
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
query: ({ body }) => {
@ -301,11 +299,7 @@ export const modelsApi = api.injectEndpoints({
body: body,
};
},
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
invalidatesTags: ['Model'],
}),
deleteMainModels: build.mutation<
DeleteMainModelResponse,
@ -317,11 +311,7 @@ export const modelsApi = api.injectEndpoints({
method: 'DELETE',
};
},
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
invalidatesTags: ['Model'],
}),
convertMainModels: build.mutation<
ConvertMainModelResponse,
@ -334,11 +324,7 @@ export const modelsApi = api.injectEndpoints({
params: { convert_dest_directory },
};
},
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
invalidatesTags: ['Model'],
}),
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
query: ({ base_model, body }) => {
@ -348,11 +334,7 @@ export const modelsApi = api.injectEndpoints({
body: body,
};
},
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
invalidatesTags: ['Model'],
}),
syncModels: build.mutation<SyncModelsResponse, void>({
query: () => {
@ -361,16 +343,15 @@ export const modelsApi = api.injectEndpoints({
method: 'POST',
};
},
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
invalidatesTags: ['Model'],
}),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'LoRAModel', id: LIST_TAG }];
const tags: ApiTagDescription[] = [
{ type: 'LoRAModel', id: LIST_TAG },
'Model',
];
if (result) {
tags.push(
@ -426,6 +407,7 @@ export const modelsApi = api.injectEndpoints({
providesTags: (result) => {
const tags: ApiTagDescription[] = [
{ type: 'ControlNetModel', id: LIST_TAG },
'Model',
];
if (result) {
@ -457,6 +439,7 @@ export const modelsApi = api.injectEndpoints({
providesTags: (result) => {
const tags: ApiTagDescription[] = [
{ type: 'IPAdapterModel', id: LIST_TAG },
'Model',
];
if (result) {
@ -488,6 +471,7 @@ export const modelsApi = api.injectEndpoints({
providesTags: (result) => {
const tags: ApiTagDescription[] = [
{ type: 'T2IAdapterModel', id: LIST_TAG },
'Model',
];
if (result) {
@ -514,7 +498,10 @@ export const modelsApi = api.injectEndpoints({
getVaeModels: build.query<EntityState<VaeModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'VaeModel', id: LIST_TAG }];
const tags: ApiTagDescription[] = [
{ type: 'VaeModel', id: LIST_TAG },
'Model',
];
if (result) {
tags.push(
@ -545,6 +532,7 @@ export const modelsApi = api.injectEndpoints({
providesTags: (result) => {
const tags: ApiTagDescription[] = [
{ type: 'TextualInversionModel', id: LIST_TAG },
'Model',
];
if (result) {
@ -577,21 +565,6 @@ export const modelsApi = api.injectEndpoints({
url: `/models/search?${folderQueryStr}`,
};
},
providesTags: (result) => {
const tags: ApiTagDescription[] = [
{ type: 'ScannedModels', id: LIST_TAG },
];
if (result) {
tags.push(
...result.map((id) => ({
type: 'ScannedModels' as const,
id,
}))
);
}
return tags;
},
}),
getCheckpointConfigs: build.query<CheckpointConfigsResponse, void>({
query: () => {

View File

@ -9,6 +9,8 @@ import {
import { $authToken, $baseUrl, $projectId } from 'services/api/client';
export const tagTypes = [
'AppVersion',
'AppConfig',
'Board',
'BoardImagesTotal',
'BoardAssetsTotal',
@ -17,15 +19,27 @@ export const tagTypes = [
'ImageList',
'ImageMetadata',
'ImageMetadataFromFile',
'Model',
'IntermediatesCount',
'SessionQueueItem',
'SessionQueueItemDTO',
'SessionQueueItemDTOList',
'SessionQueueStatus',
'SessionProcessorStatus',
'CurrentSessionQueueItem',
'NextSessionQueueItem',
'BatchStatus',
'InvocationCacheStatus',
];
'Model',
'T2IAdapterModel',
'MainModel',
'OnnxModel',
'VaeModel',
'IPAdapterModel',
'TextualInversionModel',
'ControlNetModel',
'LoRAModel',
'SDXLRefinerModel',
] as const;
export type ApiTagDescription = TagDescription<(typeof tagTypes)[number]>;
export const LIST_TAG = 'LIST';