fix(ui): invalidate query cache for all models on sync models

Also realised the tags were set up incorrectly, fixed that to get type safety with tags.
This commit is contained in:
psychedelicious
2023-10-07 22:21:38 +11:00
parent 1cc686734b
commit a681fa4b03
2 changed files with 43 additions and 56 deletions

View File

@ -199,7 +199,10 @@ export const modelsApi = api.injectEndpoints({
return `models/?${query}`;
},
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'OnnxModel', id: LIST_TAG }];
const tags: ApiTagDescription[] = [
{ type: 'OnnxModel', id: LIST_TAG },
'Model',
];
if (result) {
tags.push(
@ -236,7 +239,10 @@ export const modelsApi = api.injectEndpoints({
return `models/?${query}`;
},
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'MainModel', id: LIST_TAG }];
const tags: ApiTagDescription[] = [
{ type: 'MainModel', id: LIST_TAG },
'Model',
];
if (result) {
tags.push(
@ -270,11 +276,7 @@ export const modelsApi = api.injectEndpoints({
body: body,
};
},
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
invalidatesTags: ['Model'],
}),
importMainModels: build.mutation<
ImportMainModelResponse,
@ -287,11 +289,7 @@ export const modelsApi = api.injectEndpoints({
body: body,
};
},
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
invalidatesTags: ['Model'],
}),
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
query: ({ body }) => {
@ -301,11 +299,7 @@ export const modelsApi = api.injectEndpoints({
body: body,
};
},
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
invalidatesTags: ['Model'],
}),
deleteMainModels: build.mutation<
DeleteMainModelResponse,
@ -317,11 +311,7 @@ export const modelsApi = api.injectEndpoints({
method: 'DELETE',
};
},
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
invalidatesTags: ['Model'],
}),
convertMainModels: build.mutation<
ConvertMainModelResponse,
@ -334,11 +324,7 @@ export const modelsApi = api.injectEndpoints({
params: { convert_dest_directory },
};
},
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
invalidatesTags: ['Model'],
}),
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
query: ({ base_model, body }) => {
@ -348,11 +334,7 @@ export const modelsApi = api.injectEndpoints({
body: body,
};
},
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
invalidatesTags: ['Model'],
}),
syncModels: build.mutation<SyncModelsResponse, void>({
query: () => {
@ -361,16 +343,15 @@ export const modelsApi = api.injectEndpoints({
method: 'POST',
};
},
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
invalidatesTags: ['Model'],
}),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'LoRAModel', id: LIST_TAG }];
const tags: ApiTagDescription[] = [
{ type: 'LoRAModel', id: LIST_TAG },
'Model',
];
if (result) {
tags.push(
@ -426,6 +407,7 @@ export const modelsApi = api.injectEndpoints({
providesTags: (result) => {
const tags: ApiTagDescription[] = [
{ type: 'ControlNetModel', id: LIST_TAG },
'Model',
];
if (result) {
@ -457,6 +439,7 @@ export const modelsApi = api.injectEndpoints({
providesTags: (result) => {
const tags: ApiTagDescription[] = [
{ type: 'IPAdapterModel', id: LIST_TAG },
'Model',
];
if (result) {
@ -488,6 +471,7 @@ export const modelsApi = api.injectEndpoints({
providesTags: (result) => {
const tags: ApiTagDescription[] = [
{ type: 'T2IAdapterModel', id: LIST_TAG },
'Model',
];
if (result) {
@ -514,7 +498,10 @@ export const modelsApi = api.injectEndpoints({
getVaeModels: build.query<EntityState<VaeModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'VaeModel', id: LIST_TAG }];
const tags: ApiTagDescription[] = [
{ type: 'VaeModel', id: LIST_TAG },
'Model',
];
if (result) {
tags.push(
@ -545,6 +532,7 @@ export const modelsApi = api.injectEndpoints({
providesTags: (result) => {
const tags: ApiTagDescription[] = [
{ type: 'TextualInversionModel', id: LIST_TAG },
'Model',
];
if (result) {
@ -577,21 +565,6 @@ export const modelsApi = api.injectEndpoints({
url: `/models/search?${folderQueryStr}`,
};
},
providesTags: (result) => {
const tags: ApiTagDescription[] = [
{ type: 'ScannedModels', id: LIST_TAG },
];
if (result) {
tags.push(
...result.map((id) => ({
type: 'ScannedModels' as const,
id,
}))
);
}
return tags;
},
}),
getCheckpointConfigs: build.query<CheckpointConfigsResponse, void>({
query: () => {