mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): create rtk-query hooks for individual model types
Eg `useGetMainModelsQuery()`, `useGetLoRAModelsQuery()` instead of `useListModelsQuery({base_type})`. Add specific adapters for each model type. Just more organised and easier to consume models now. Also updated LoRA UI to use the model name.
This commit is contained in:
@ -1,35 +1,85 @@
|
||||
import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
|
||||
import { keyBy } from 'lodash-es';
|
||||
import { ModelsList } from 'services/api/types';
|
||||
import { cloneDeep } from 'lodash-es';
|
||||
import {
|
||||
AnyModelConfig,
|
||||
ControlNetModelConfig,
|
||||
LoRAModelConfig,
|
||||
MainModelConfig,
|
||||
TextualInversionModelConfig,
|
||||
VaeModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
import { ApiFullTagDescription, LIST_TAG, api } from '..';
|
||||
import { paths } from '../schema';
|
||||
|
||||
type ModelConfig = ModelsList['models'][number];
|
||||
export type MainModelConfigEntity = MainModelConfig & { id: string };
|
||||
|
||||
type ListModelsArg = NonNullable<
|
||||
paths['/api/v1/models/']['get']['parameters']['query']
|
||||
>;
|
||||
export type LoRAModelConfigEntity = LoRAModelConfig & { id: string };
|
||||
|
||||
const modelsAdapter = createEntityAdapter<ModelConfig>({
|
||||
selectId: (model) => getModelId(model),
|
||||
export type ControlNetModelConfigEntity = ControlNetModelConfig & {
|
||||
id: string;
|
||||
};
|
||||
|
||||
export type TextualInversionModelConfigEntity = TextualInversionModelConfig & {
|
||||
id: string;
|
||||
};
|
||||
|
||||
export type VaeModelConfigEntity = VaeModelConfig & { id: string };
|
||||
|
||||
type AnyModelConfigEntity =
|
||||
| MainModelConfigEntity
|
||||
| LoRAModelConfigEntity
|
||||
| ControlNetModelConfigEntity
|
||||
| TextualInversionModelConfigEntity
|
||||
| VaeModelConfigEntity;
|
||||
|
||||
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
const controlNetModelsAdapter =
|
||||
createEntityAdapter<ControlNetModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
const textualInversionModelsAdapter =
|
||||
createEntityAdapter<TextualInversionModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
const vaeModelsAdapter = createEntityAdapter<VaeModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
|
||||
const getModelId = ({ base_model, type, name }: ModelConfig) =>
|
||||
export const getModelId = ({ base_model, type, name }: AnyModelConfig) =>
|
||||
`${base_model}/${type}/${name}`;
|
||||
|
||||
const createModelEntities = <T extends AnyModelConfigEntity>(
|
||||
models: AnyModelConfig[]
|
||||
): T[] => {
|
||||
const entityArray: T[] = [];
|
||||
models.forEach((model) => {
|
||||
const entity = {
|
||||
...cloneDeep(model),
|
||||
id: getModelId(model),
|
||||
} as T;
|
||||
entityArray.push(entity);
|
||||
});
|
||||
return entityArray;
|
||||
};
|
||||
|
||||
export const modelsApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
listModels: build.query<EntityState<ModelConfig>, ListModelsArg>({
|
||||
query: (arg) => ({ url: 'models/', params: arg }),
|
||||
getMainModels: build.query<EntityState<MainModelConfigEntity>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'main' } }),
|
||||
providesTags: (result, error, arg) => {
|
||||
const tags: ApiFullTagDescription[] = [{ id: 'Model', type: LIST_TAG }];
|
||||
const tags: ApiFullTagDescription[] = [
|
||||
{ id: 'MainModel', type: LIST_TAG },
|
||||
];
|
||||
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: 'Model' as const,
|
||||
type: 'MainModel' as const,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
@ -37,14 +87,161 @@ export const modelsApi = api.injectEndpoints({
|
||||
|
||||
return tags;
|
||||
},
|
||||
transformResponse: (response: ModelsList, meta, arg) => {
|
||||
return modelsAdapter.setAll(
|
||||
modelsAdapter.getInitialState(),
|
||||
keyBy(response.models, getModelId)
|
||||
transformResponse: (
|
||||
response: { models: MainModelConfig[] },
|
||||
meta,
|
||||
arg
|
||||
) => {
|
||||
const entities = createModelEntities<MainModelConfigEntity>(
|
||||
response.models
|
||||
);
|
||||
return mainModelsAdapter.setAll(
|
||||
mainModelsAdapter.getInitialState(),
|
||||
entities
|
||||
);
|
||||
},
|
||||
}),
|
||||
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
|
||||
providesTags: (result, error, arg) => {
|
||||
const tags: ApiFullTagDescription[] = [
|
||||
{ id: 'LoRAModel', type: LIST_TAG },
|
||||
];
|
||||
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: 'LoRAModel' as const,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
},
|
||||
transformResponse: (
|
||||
response: { models: LoRAModelConfig[] },
|
||||
meta,
|
||||
arg
|
||||
) => {
|
||||
const entities = createModelEntities<LoRAModelConfigEntity>(
|
||||
response.models
|
||||
);
|
||||
return loraModelsAdapter.setAll(
|
||||
loraModelsAdapter.getInitialState(),
|
||||
entities
|
||||
);
|
||||
},
|
||||
}),
|
||||
getControlNetModels: build.query<
|
||||
EntityState<ControlNetModelConfigEntity>,
|
||||
void
|
||||
>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }),
|
||||
providesTags: (result, error, arg) => {
|
||||
const tags: ApiFullTagDescription[] = [
|
||||
{ id: 'ControlNetModel', type: LIST_TAG },
|
||||
];
|
||||
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: 'ControlNetModel' as const,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
},
|
||||
transformResponse: (
|
||||
response: { models: ControlNetModelConfig[] },
|
||||
meta,
|
||||
arg
|
||||
) => {
|
||||
const entities = createModelEntities<ControlNetModelConfigEntity>(
|
||||
response.models
|
||||
);
|
||||
return controlNetModelsAdapter.setAll(
|
||||
controlNetModelsAdapter.getInitialState(),
|
||||
entities
|
||||
);
|
||||
},
|
||||
}),
|
||||
getVaeModels: build.query<EntityState<VaeModelConfigEntity>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
|
||||
providesTags: (result, error, arg) => {
|
||||
const tags: ApiFullTagDescription[] = [
|
||||
{ id: 'VaeModel', type: LIST_TAG },
|
||||
];
|
||||
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: 'VaeModel' as const,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
},
|
||||
transformResponse: (
|
||||
response: { models: VaeModelConfig[] },
|
||||
meta,
|
||||
arg
|
||||
) => {
|
||||
const entities = createModelEntities<VaeModelConfigEntity>(
|
||||
response.models
|
||||
);
|
||||
return vaeModelsAdapter.setAll(
|
||||
vaeModelsAdapter.getInitialState(),
|
||||
entities
|
||||
);
|
||||
},
|
||||
}),
|
||||
getTextualInversionModels: build.query<
|
||||
EntityState<TextualInversionModelConfigEntity>,
|
||||
void
|
||||
>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'embedding' } }),
|
||||
providesTags: (result, error, arg) => {
|
||||
const tags: ApiFullTagDescription[] = [
|
||||
{ id: 'TextualInversionModel', type: LIST_TAG },
|
||||
];
|
||||
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: 'TextualInversionModel' as const,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
},
|
||||
transformResponse: (
|
||||
response: { models: TextualInversionModelConfig[] },
|
||||
meta,
|
||||
arg
|
||||
) => {
|
||||
const entities = createModelEntities<TextualInversionModelConfigEntity>(
|
||||
response.models
|
||||
);
|
||||
return textualInversionModelsAdapter.setAll(
|
||||
textualInversionModelsAdapter.getInitialState(),
|
||||
entities
|
||||
);
|
||||
},
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
export const { useListModelsQuery } = modelsApi;
|
||||
export const {
|
||||
useGetMainModelsQuery,
|
||||
useGetControlNetModelsQuery,
|
||||
useGetLoRAModelsQuery,
|
||||
useGetTextualInversionModelsQuery,
|
||||
useGetVaeModelsQuery,
|
||||
} = modelsApi;
|
||||
|
Reference in New Issue
Block a user