fix(ui): fix refiner missing from model manager

Rolled back the earlier split of the refiner model query.

Now, when you use `useGetMainModelsQuery()`, you must provide it an array of base model types.

They are provided as constants for simplicity:
- ALL_BASE_MODELS
- NON_REFINER_BASE_MODELS
- REFINER_BASE_MODELS

Opted to just use args for the hook instead of wrapping the hook in another hook, we can tidy this up later if desired.
This commit is contained in:
psychedelicious
2023-07-26 11:04:02 +10:00
parent 6fa244a343
commit cbcd416b70
19 changed files with 72 additions and 75 deletions

View File

@ -107,9 +107,6 @@ type SearchFolderArg = operations['search_for_models']['parameters']['query'];
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
const sdxlRefinerModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
@ -147,11 +144,14 @@ const createModelEntities = <T extends AnyModelConfigEntity>(
export const modelsApi = api.injectEndpoints({
endpoints: (build) => ({
getMainModels: build.query<EntityState<MainModelConfigEntity>, void>({
query: () => {
getMainModels: build.query<
EntityState<MainModelConfigEntity>,
BaseModelType[]
>({
query: (base_models) => {
const params = {
model_type: 'main',
base_models: ['sd-1', 'sd-2', 'sdxl'],
base_models,
};
const query = queryString.stringify(params, { arrayFormat: 'none' });
@ -187,43 +187,6 @@ export const modelsApi = api.injectEndpoints({
);
},
}),
getSDXLRefinerModels: build.query<EntityState<MainModelConfigEntity>, void>(
{
query: () => ({
url: 'models/',
params: { model_type: 'main', base_models: ['sdxl-refiner'] },
}),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ type: 'SDXLRefinerModel', id: LIST_TAG },
];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'SDXLRefinerModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (
response: { models: MainModelConfig[] },
meta,
arg
) => {
const entities = createModelEntities<MainModelConfigEntity>(
response.models
);
return sdxlRefinerModelsAdapter.setAll(
sdxlRefinerModelsAdapter.getInitialState(),
entities
);
},
}
),
updateMainModels: build.mutation<
UpdateMainModelResponse,
UpdateMainModelArg
@ -494,7 +457,6 @@ export const modelsApi = api.injectEndpoints({
export const {
useGetMainModelsQuery,
useGetSDXLRefinerModelsQuery,
useGetControlNetModelsQuery,
useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery,