feat(ui): update model identifier to be key (wip)

- Update most model identifiers to be `{key: string}` instead of name/base/type. Doesn't change the model select components yet.
- Update model _parameters_, stored in redux, to be `{key: string, base: BaseModel}` - we need to store the base model to be able to check model compatibility. May want to store the whole config? Not sure...
This commit is contained in:
psychedelicious
2024-02-16 18:56:02 +11:00
parent 6df3c450e8
commit dab939f7d1
54 changed files with 267 additions and 453 deletions

View File

@ -1,64 +1,26 @@
import type { EntityState } from '@reduxjs/toolkit';
import type { EntityAdapter, EntityState } from '@reduxjs/toolkit';
import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import { cloneDeep } from 'lodash-es';
import queryString from 'query-string';
import type { operations, paths } from 'services/api/schema';
import type {
AnyModelConfig,
BaseModelType,
CheckpointModelConfig,
ControlNetModelConfig,
DiffusersModelConfig,
ControlNetConfig,
ImportModelConfig,
IPAdapterModelConfig,
LoRAModelConfig,
IPAdapterConfig,
LoRAConfig,
MainModelConfig,
MergeModelConfig,
ModelType,
T2IAdapterModelConfig,
TextualInversionModelConfig,
VaeModelConfig,
T2IAdapterConfig,
TextualInversionConfig,
VAEConfig,
} from 'services/api/types';
import type { ApiTagDescription } from '..';
import type { ApiTagDescription, tagTypes } from '..';
import { api, LIST_TAG } from '..';
export type DiffusersModelConfigEntity = DiffusersModelConfig & { id: string };
export type CheckpointModelConfigEntity = CheckpointModelConfig & {
id: string;
};
export type MainModelConfigEntity = DiffusersModelConfigEntity | CheckpointModelConfigEntity;
export type LoRAModelConfigEntity = LoRAModelConfig & { id: string };
export type ControlNetModelConfigEntity = ControlNetModelConfig & {
id: string;
};
export type IPAdapterModelConfigEntity = IPAdapterModelConfig & {
id: string;
};
export type T2IAdapterModelConfigEntity = T2IAdapterModelConfig & {
id: string;
};
export type TextualInversionModelConfigEntity = TextualInversionModelConfig & {
id: string;
};
export type VaeModelConfigEntity = VaeModelConfig & { id: string };
export type AnyModelConfigEntity =
| MainModelConfigEntity
| LoRAModelConfigEntity
| ControlNetModelConfigEntity
| IPAdapterModelConfigEntity
| T2IAdapterModelConfigEntity
| TextualInversionModelConfigEntity
| VaeModelConfigEntity;
type UpdateMainModelArg = {
base_model: BaseModelType;
model_name: string;
@ -68,11 +30,11 @@ type UpdateMainModelArg = {
type UpdateLoRAModelArg = {
base_model: BaseModelType;
model_name: string;
body: LoRAModelConfig;
body: LoRAConfig;
};
type UpdateMainModelResponse =
paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json'];
paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
type UpdateLoRAModelResponse = UpdateMainModelResponse;
@ -128,59 +90,71 @@ type CheckpointConfigsResponse =
type SearchFolderArg = operations['search_for_models']['parameters']['query'];
export const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
export const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
export const loraModelsAdapter = createEntityAdapter<LoRAConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
export const controlNetModelsAdapter = createEntityAdapter<ControlNetConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
export const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
export const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const textualInversionModelsAdapter = createEntityAdapter<TextualInversionModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
export const textualInversionModelsAdapter = createEntityAdapter<TextualInversionConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const textualInversionModelsAdapterSelectors = textualInversionModelsAdapter.getSelectors(
undefined,
getSelectorsOptions
);
export const vaeModelsAdapter = createEntityAdapter<VaeModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
export const vaeModelsAdapter = createEntityAdapter<VAEConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const getModelId = ({
base_model,
model_type,
model_name,
}: Pick<AnyModelConfig, 'base_model' | 'model_name' | 'model_type'>) => `${base_model}/${model_type}/${model_name}`;
const buildProvidesTags =
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
(result: EntityState<TEntity, string> | undefined) => {
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
const createModelEntities = <T extends AnyModelConfigEntity>(models: AnyModelConfig[]): T[] => {
const entityArray: T[] = [];
models.forEach((model) => {
const entity = {
...cloneDeep(model),
id: getModelId(model),
} as T;
entityArray.push(entity);
});
return entityArray;
};
if (result) {
tags.push(
...result.ids.map((id) => ({
type: tagType,
id,
}))
);
}
return tags;
};
const buildTransformResponse =
<T extends AnyModelConfig>(adapter: EntityAdapter<T, string>) =>
(response: { models: T[] }) => {
return adapter.setAll(adapter.getInitialState(), response.models);
};
export const modelsApi = api.injectEndpoints({
endpoints: (build) => ({
getMainModels: build.query<EntityState<MainModelConfigEntity, string>, BaseModelType[]>({
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
query: (base_models) => {
const params = {
model_type: 'main',
@ -190,24 +164,8 @@ export const modelsApi = api.injectEndpoints({
const query = queryString.stringify(params, { arrayFormat: 'none' });
return `models/?${query}`;
},
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'MainModel', id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'MainModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: MainModelConfig[] }) => {
const entities = createModelEntities<MainModelConfigEntity>(response.models);
return mainModelsAdapter.setAll(mainModelsAdapter.getInitialState(), entities);
},
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
}),
updateMainModels: build.mutation<UpdateMainModelResponse, UpdateMainModelArg>({
query: ({ base_model, model_name, body }) => {
@ -277,26 +235,10 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model'],
}),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity, string>, void>({
getLoRAModels: build.query<EntityState<LoRAConfig, string>, void>({
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'LoRAModel', id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'LoRAModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: LoRAModelConfig[] }) => {
const entities = createModelEntities<LoRAModelConfigEntity>(response.models);
return loraModelsAdapter.setAll(loraModelsAdapter.getInitialState(), entities);
},
providesTags: buildProvidesTags<LoRAConfig>('LoRAModel'),
transformResponse: buildTransformResponse<LoRAConfig>(loraModelsAdapter),
}),
updateLoRAModels: build.mutation<UpdateLoRAModelResponse, UpdateLoRAModelArg>({
query: ({ base_model, model_name, body }) => {
@ -317,110 +259,30 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
}),
getControlNetModels: build.query<EntityState<ControlNetModelConfigEntity, string>, void>({
getControlNetModels: build.query<EntityState<ControlNetConfig, string>, void>({
query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'ControlNetModel', id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'ControlNetModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: ControlNetModelConfig[] }) => {
const entities = createModelEntities<ControlNetModelConfigEntity>(response.models);
return controlNetModelsAdapter.setAll(controlNetModelsAdapter.getInitialState(), entities);
},
providesTags: buildProvidesTags<ControlNetConfig>('ControlNetModel'),
transformResponse: buildTransformResponse<ControlNetConfig>(controlNetModelsAdapter),
}),
getIPAdapterModels: build.query<EntityState<IPAdapterModelConfigEntity, string>, void>({
getIPAdapterModels: build.query<EntityState<IPAdapterConfig, string>, void>({
query: () => ({ url: 'models/', params: { model_type: 'ip_adapter' } }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'IPAdapterModel', id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'IPAdapterModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: IPAdapterModelConfig[] }) => {
const entities = createModelEntities<IPAdapterModelConfigEntity>(response.models);
return ipAdapterModelsAdapter.setAll(ipAdapterModelsAdapter.getInitialState(), entities);
},
providesTags: buildProvidesTags<IPAdapterConfig>('IPAdapterModel'),
transformResponse: buildTransformResponse<IPAdapterConfig>(ipAdapterModelsAdapter),
}),
getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfigEntity, string>, void>({
getT2IAdapterModels: build.query<EntityState<T2IAdapterConfig, string>, void>({
query: () => ({ url: 'models/', params: { model_type: 't2i_adapter' } }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'T2IAdapterModel', id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'T2IAdapterModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: T2IAdapterModelConfig[] }) => {
const entities = createModelEntities<T2IAdapterModelConfigEntity>(response.models);
return t2iAdapterModelsAdapter.setAll(t2iAdapterModelsAdapter.getInitialState(), entities);
},
providesTags: buildProvidesTags<T2IAdapterConfig>('T2IAdapterModel'),
transformResponse: buildTransformResponse<T2IAdapterConfig>(t2iAdapterModelsAdapter),
}),
getVaeModels: build.query<EntityState<VaeModelConfigEntity, string>, void>({
getVaeModels: build.query<EntityState<VAEConfig, string>, void>({
query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'VaeModel', id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'VaeModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: VaeModelConfig[] }) => {
const entities = createModelEntities<VaeModelConfigEntity>(response.models);
return vaeModelsAdapter.setAll(vaeModelsAdapter.getInitialState(), entities);
},
providesTags: buildProvidesTags<VAEConfig>('VaeModel'),
transformResponse: buildTransformResponse<VAEConfig>(vaeModelsAdapter),
}),
getTextualInversionModels: build.query<EntityState<TextualInversionModelConfigEntity, string>, void>({
getTextualInversionModels: build.query<EntityState<TextualInversionConfig, string>, void>({
query: () => ({ url: 'models/', params: { model_type: 'embedding' } }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'TextualInversionModel', id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'TextualInversionModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: TextualInversionModelConfig[] }) => {
const entities = createModelEntities<TextualInversionModelConfigEntity>(response.models);
return textualInversionModelsAdapter.setAll(textualInversionModelsAdapter.getInitialState(), entities);
},
providesTags: buildProvidesTags<TextualInversionConfig>('TextualInversionModel'),
transformResponse: buildTransformResponse<TextualInversionConfig>(textualInversionModelsAdapter),
}),
getModelsInFolder: build.query<SearchFolderResponse, SearchFolderArg>({
query: (arg) => {