feat(ui): create rtk-query hooks for individual model types

Eg `useGetMainModelsQuery()`, `useGetLoRAModelsQuery()` instead of `useListModelsQuery({base_type})`.

Add specific adapters for each model type. Just more organised and easier to consume models now.

Also updated LoRA UI to use the model name.
This commit is contained in:
psychedelicious
2023-07-05 11:52:02 +10:00
parent c21b56ba31
commit 52a09422c7
13 changed files with 395 additions and 146 deletions

View File

@ -1,35 +1,85 @@
import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
import { keyBy } from 'lodash-es';
import { ModelsList } from 'services/api/types';
import { cloneDeep } from 'lodash-es';
import {
AnyModelConfig,
ControlNetModelConfig,
LoRAModelConfig,
MainModelConfig,
TextualInversionModelConfig,
VaeModelConfig,
} from 'services/api/types';
import { ApiFullTagDescription, LIST_TAG, api } from '..';
import { paths } from '../schema';
type ModelConfig = ModelsList['models'][number];
export type MainModelConfigEntity = MainModelConfig & { id: string };
type ListModelsArg = NonNullable<
paths['/api/v1/models/']['get']['parameters']['query']
>;
export type LoRAModelConfigEntity = LoRAModelConfig & { id: string };
const modelsAdapter = createEntityAdapter<ModelConfig>({
selectId: (model) => getModelId(model),
export type ControlNetModelConfigEntity = ControlNetModelConfig & {
id: string;
};
export type TextualInversionModelConfigEntity = TextualInversionModelConfig & {
id: string;
};
export type VaeModelConfigEntity = VaeModelConfig & { id: string };
type AnyModelConfigEntity =
| MainModelConfigEntity
| LoRAModelConfigEntity
| ControlNetModelConfigEntity
| TextualInversionModelConfigEntity
| VaeModelConfigEntity;
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
const controlNetModelsAdapter =
createEntityAdapter<ControlNetModelConfigEntity>({
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
const textualInversionModelsAdapter =
createEntityAdapter<TextualInversionModelConfigEntity>({
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
const vaeModelsAdapter = createEntityAdapter<VaeModelConfigEntity>({
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
const getModelId = ({ base_model, type, name }: ModelConfig) =>
export const getModelId = ({ base_model, type, name }: AnyModelConfig) =>
`${base_model}/${type}/${name}`;
const createModelEntities = <T extends AnyModelConfigEntity>(
models: AnyModelConfig[]
): T[] => {
const entityArray: T[] = [];
models.forEach((model) => {
const entity = {
...cloneDeep(model),
id: getModelId(model),
} as T;
entityArray.push(entity);
});
return entityArray;
};
export const modelsApi = api.injectEndpoints({
endpoints: (build) => ({
listModels: build.query<EntityState<ModelConfig>, ListModelsArg>({
query: (arg) => ({ url: 'models/', params: arg }),
getMainModels: build.query<EntityState<MainModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'main' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [{ id: 'Model', type: LIST_TAG }];
const tags: ApiFullTagDescription[] = [
{ id: 'MainModel', type: LIST_TAG },
];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'Model' as const,
type: 'MainModel' as const,
id,
}))
);
@ -37,14 +87,161 @@ export const modelsApi = api.injectEndpoints({
return tags;
},
transformResponse: (response: ModelsList, meta, arg) => {
return modelsAdapter.setAll(
modelsAdapter.getInitialState(),
keyBy(response.models, getModelId)
transformResponse: (
response: { models: MainModelConfig[] },
meta,
arg
) => {
const entities = createModelEntities<MainModelConfigEntity>(
response.models
);
return mainModelsAdapter.setAll(
mainModelsAdapter.getInitialState(),
entities
);
},
}),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'LoRAModel', type: LIST_TAG },
];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'LoRAModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (
response: { models: LoRAModelConfig[] },
meta,
arg
) => {
const entities = createModelEntities<LoRAModelConfigEntity>(
response.models
);
return loraModelsAdapter.setAll(
loraModelsAdapter.getInitialState(),
entities
);
},
}),
getControlNetModels: build.query<
EntityState<ControlNetModelConfigEntity>,
void
>({
query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'ControlNetModel', type: LIST_TAG },
];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'ControlNetModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (
response: { models: ControlNetModelConfig[] },
meta,
arg
) => {
const entities = createModelEntities<ControlNetModelConfigEntity>(
response.models
);
return controlNetModelsAdapter.setAll(
controlNetModelsAdapter.getInitialState(),
entities
);
},
}),
getVaeModels: build.query<EntityState<VaeModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'VaeModel', type: LIST_TAG },
];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'VaeModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (
response: { models: VaeModelConfig[] },
meta,
arg
) => {
const entities = createModelEntities<VaeModelConfigEntity>(
response.models
);
return vaeModelsAdapter.setAll(
vaeModelsAdapter.getInitialState(),
entities
);
},
}),
getTextualInversionModels: build.query<
EntityState<TextualInversionModelConfigEntity>,
void
>({
query: () => ({ url: 'models/', params: { model_type: 'embedding' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'TextualInversionModel', type: LIST_TAG },
];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'TextualInversionModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (
response: { models: TextualInversionModelConfig[] },
meta,
arg
) => {
const entities = createModelEntities<TextualInversionModelConfigEntity>(
response.models
);
return textualInversionModelsAdapter.setAll(
textualInversionModelsAdapter.getInitialState(),
entities
);
},
}),
}),
});
export const { useListModelsQuery } = modelsApi;
export const {
useGetMainModelsQuery,
useGetControlNetModelsQuery,
useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery,
useGetVaeModelsQuery,
} = modelsApi;