feat: Move SDXL Refiner to own route & set appropriate disabled statuses

This commit is contained in:
blessedcoolant
2023-07-25 22:14:19 +12:00
committed by psychedelicious
parent 8d1b8179af
commit 5202610160
9 changed files with 159 additions and 32 deletions

View File

@ -107,6 +107,9 @@ type SearchFolderArg = operations['search_for_models']['parameters']['query'];
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
const sdxlRefinerModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
@ -176,6 +179,43 @@ export const modelsApi = api.injectEndpoints({
);
},
}),
getSDXLRefinerModels: build.query<EntityState<MainModelConfigEntity>, void>(
{
query: () => ({
url: 'models/',
params: { model_type: 'main', base_models: 'sdxl-refiner' },
}),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ type: 'SDXLRefinerModel', id: LIST_TAG },
];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'SDXLRefinerModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (
response: { models: MainModelConfig[] },
meta,
arg
) => {
const entities = createModelEntities<MainModelConfigEntity>(
response.models
);
return sdxlRefinerModelsAdapter.setAll(
sdxlRefinerModelsAdapter.getInitialState(),
entities
);
},
}
),
updateMainModels: build.mutation<
UpdateMainModelResponse,
UpdateMainModelArg
@ -187,7 +227,10 @@ export const modelsApi = api.injectEndpoints({
body: body,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
],
}),
importMainModels: build.mutation<
ImportMainModelResponse,
@ -200,7 +243,10 @@ export const modelsApi = api.injectEndpoints({
body: body,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
],
}),
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
query: ({ body }) => {
@ -210,7 +256,10 @@ export const modelsApi = api.injectEndpoints({
body: body,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
],
}),
deleteMainModels: build.mutation<
DeleteMainModelResponse,
@ -222,7 +271,10 @@ export const modelsApi = api.injectEndpoints({
method: 'DELETE',
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
],
}),
convertMainModels: build.mutation<
ConvertMainModelResponse,
@ -235,7 +287,10 @@ export const modelsApi = api.injectEndpoints({
params: params,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
],
}),
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
query: ({ base_model, body }) => {
@ -245,7 +300,10 @@ export const modelsApi = api.injectEndpoints({
body: body,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
],
}),
syncModels: build.mutation<SyncModelsResponse, void>({
query: () => {
@ -254,7 +312,10 @@ export const modelsApi = api.injectEndpoints({
method: 'POST',
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
],
}),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
@ -425,6 +486,7 @@ export const modelsApi = api.injectEndpoints({
export const {
useGetMainModelsQuery,
useGetSDXLRefinerModelsQuery,
useGetControlNetModelsQuery,
useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery,