mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): update model identifier to be key (wip)
- Update most model identifiers to be `{key: string}` instead of name/base/type. Doesn't change the model select components yet. - Update model _parameters_, stored in redux, to be `{key: string, base: BaseModel}` - we need to store the base model to be able to check model compatibility. May want to store the whole config? Not sure...
This commit is contained in:
@ -1,64 +1,26 @@
|
||||
import type { EntityState } from '@reduxjs/toolkit';
|
||||
import type { EntityAdapter, EntityState } from '@reduxjs/toolkit';
|
||||
import { createEntityAdapter } from '@reduxjs/toolkit';
|
||||
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
||||
import { cloneDeep } from 'lodash-es';
|
||||
import queryString from 'query-string';
|
||||
import type { operations, paths } from 'services/api/schema';
|
||||
import type {
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
CheckpointModelConfig,
|
||||
ControlNetModelConfig,
|
||||
DiffusersModelConfig,
|
||||
ControlNetConfig,
|
||||
ImportModelConfig,
|
||||
IPAdapterModelConfig,
|
||||
LoRAModelConfig,
|
||||
IPAdapterConfig,
|
||||
LoRAConfig,
|
||||
MainModelConfig,
|
||||
MergeModelConfig,
|
||||
ModelType,
|
||||
T2IAdapterModelConfig,
|
||||
TextualInversionModelConfig,
|
||||
VaeModelConfig,
|
||||
T2IAdapterConfig,
|
||||
TextualInversionConfig,
|
||||
VAEConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
import type { ApiTagDescription } from '..';
|
||||
import type { ApiTagDescription, tagTypes } from '..';
|
||||
import { api, LIST_TAG } from '..';
|
||||
|
||||
export type DiffusersModelConfigEntity = DiffusersModelConfig & { id: string };
|
||||
export type CheckpointModelConfigEntity = CheckpointModelConfig & {
|
||||
id: string;
|
||||
};
|
||||
export type MainModelConfigEntity = DiffusersModelConfigEntity | CheckpointModelConfigEntity;
|
||||
|
||||
export type LoRAModelConfigEntity = LoRAModelConfig & { id: string };
|
||||
|
||||
export type ControlNetModelConfigEntity = ControlNetModelConfig & {
|
||||
id: string;
|
||||
};
|
||||
|
||||
export type IPAdapterModelConfigEntity = IPAdapterModelConfig & {
|
||||
id: string;
|
||||
};
|
||||
|
||||
export type T2IAdapterModelConfigEntity = T2IAdapterModelConfig & {
|
||||
id: string;
|
||||
};
|
||||
|
||||
export type TextualInversionModelConfigEntity = TextualInversionModelConfig & {
|
||||
id: string;
|
||||
};
|
||||
|
||||
export type VaeModelConfigEntity = VaeModelConfig & { id: string };
|
||||
|
||||
export type AnyModelConfigEntity =
|
||||
| MainModelConfigEntity
|
||||
| LoRAModelConfigEntity
|
||||
| ControlNetModelConfigEntity
|
||||
| IPAdapterModelConfigEntity
|
||||
| T2IAdapterModelConfigEntity
|
||||
| TextualInversionModelConfigEntity
|
||||
| VaeModelConfigEntity;
|
||||
|
||||
type UpdateMainModelArg = {
|
||||
base_model: BaseModelType;
|
||||
model_name: string;
|
||||
@ -68,11 +30,11 @@ type UpdateMainModelArg = {
|
||||
type UpdateLoRAModelArg = {
|
||||
base_model: BaseModelType;
|
||||
model_name: string;
|
||||
body: LoRAModelConfig;
|
||||
body: LoRAConfig;
|
||||
};
|
||||
|
||||
type UpdateMainModelResponse =
|
||||
paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json'];
|
||||
paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
|
||||
|
||||
type UpdateLoRAModelResponse = UpdateMainModelResponse;
|
||||
|
||||
@ -128,59 +90,71 @@ type CheckpointConfigsResponse =
|
||||
|
||||
type SearchFolderArg = operations['search_for_models']['parameters']['query'];
|
||||
|
||||
export const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
export const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
export const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
export const loraModelsAdapter = createEntityAdapter<LoRAConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
export const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
export const controlNetModelsAdapter = createEntityAdapter<ControlNetConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
export const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
export const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
export const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
export const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
export const textualInversionModelsAdapter = createEntityAdapter<TextualInversionModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
export const textualInversionModelsAdapter = createEntityAdapter<TextualInversionConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const textualInversionModelsAdapterSelectors = textualInversionModelsAdapter.getSelectors(
|
||||
undefined,
|
||||
getSelectorsOptions
|
||||
);
|
||||
export const vaeModelsAdapter = createEntityAdapter<VaeModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
export const vaeModelsAdapter = createEntityAdapter<VAEConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
|
||||
export const getModelId = ({
|
||||
base_model,
|
||||
model_type,
|
||||
model_name,
|
||||
}: Pick<AnyModelConfig, 'base_model' | 'model_name' | 'model_type'>) => `${base_model}/${model_type}/${model_name}`;
|
||||
const buildProvidesTags =
|
||||
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
|
||||
(result: EntityState<TEntity, string> | undefined) => {
|
||||
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
|
||||
|
||||
const createModelEntities = <T extends AnyModelConfigEntity>(models: AnyModelConfig[]): T[] => {
|
||||
const entityArray: T[] = [];
|
||||
models.forEach((model) => {
|
||||
const entity = {
|
||||
...cloneDeep(model),
|
||||
id: getModelId(model),
|
||||
} as T;
|
||||
entityArray.push(entity);
|
||||
});
|
||||
return entityArray;
|
||||
};
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: tagType,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
};
|
||||
|
||||
const buildTransformResponse =
|
||||
<T extends AnyModelConfig>(adapter: EntityAdapter<T, string>) =>
|
||||
(response: { models: T[] }) => {
|
||||
return adapter.setAll(adapter.getInitialState(), response.models);
|
||||
};
|
||||
|
||||
export const modelsApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
getMainModels: build.query<EntityState<MainModelConfigEntity, string>, BaseModelType[]>({
|
||||
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
|
||||
query: (base_models) => {
|
||||
const params = {
|
||||
model_type: 'main',
|
||||
@ -190,24 +164,8 @@ export const modelsApi = api.injectEndpoints({
|
||||
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
||||
return `models/?${query}`;
|
||||
},
|
||||
providesTags: (result) => {
|
||||
const tags: ApiTagDescription[] = [{ type: 'MainModel', id: LIST_TAG }, 'Model'];
|
||||
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: 'MainModel' as const,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
},
|
||||
transformResponse: (response: { models: MainModelConfig[] }) => {
|
||||
const entities = createModelEntities<MainModelConfigEntity>(response.models);
|
||||
return mainModelsAdapter.setAll(mainModelsAdapter.getInitialState(), entities);
|
||||
},
|
||||
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
|
||||
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
|
||||
}),
|
||||
updateMainModels: build.mutation<UpdateMainModelResponse, UpdateMainModelArg>({
|
||||
query: ({ base_model, model_name, body }) => {
|
||||
@ -277,26 +235,10 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity, string>, void>({
|
||||
getLoRAModels: build.query<EntityState<LoRAConfig, string>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
|
||||
providesTags: (result) => {
|
||||
const tags: ApiTagDescription[] = [{ type: 'LoRAModel', id: LIST_TAG }, 'Model'];
|
||||
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: 'LoRAModel' as const,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
},
|
||||
transformResponse: (response: { models: LoRAModelConfig[] }) => {
|
||||
const entities = createModelEntities<LoRAModelConfigEntity>(response.models);
|
||||
return loraModelsAdapter.setAll(loraModelsAdapter.getInitialState(), entities);
|
||||
},
|
||||
providesTags: buildProvidesTags<LoRAConfig>('LoRAModel'),
|
||||
transformResponse: buildTransformResponse<LoRAConfig>(loraModelsAdapter),
|
||||
}),
|
||||
updateLoRAModels: build.mutation<UpdateLoRAModelResponse, UpdateLoRAModelArg>({
|
||||
query: ({ base_model, model_name, body }) => {
|
||||
@ -317,110 +259,30 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
|
||||
}),
|
||||
getControlNetModels: build.query<EntityState<ControlNetModelConfigEntity, string>, void>({
|
||||
getControlNetModels: build.query<EntityState<ControlNetConfig, string>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }),
|
||||
providesTags: (result) => {
|
||||
const tags: ApiTagDescription[] = [{ type: 'ControlNetModel', id: LIST_TAG }, 'Model'];
|
||||
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: 'ControlNetModel' as const,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
},
|
||||
transformResponse: (response: { models: ControlNetModelConfig[] }) => {
|
||||
const entities = createModelEntities<ControlNetModelConfigEntity>(response.models);
|
||||
return controlNetModelsAdapter.setAll(controlNetModelsAdapter.getInitialState(), entities);
|
||||
},
|
||||
providesTags: buildProvidesTags<ControlNetConfig>('ControlNetModel'),
|
||||
transformResponse: buildTransformResponse<ControlNetConfig>(controlNetModelsAdapter),
|
||||
}),
|
||||
getIPAdapterModels: build.query<EntityState<IPAdapterModelConfigEntity, string>, void>({
|
||||
getIPAdapterModels: build.query<EntityState<IPAdapterConfig, string>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'ip_adapter' } }),
|
||||
providesTags: (result) => {
|
||||
const tags: ApiTagDescription[] = [{ type: 'IPAdapterModel', id: LIST_TAG }, 'Model'];
|
||||
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: 'IPAdapterModel' as const,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
},
|
||||
transformResponse: (response: { models: IPAdapterModelConfig[] }) => {
|
||||
const entities = createModelEntities<IPAdapterModelConfigEntity>(response.models);
|
||||
return ipAdapterModelsAdapter.setAll(ipAdapterModelsAdapter.getInitialState(), entities);
|
||||
},
|
||||
providesTags: buildProvidesTags<IPAdapterConfig>('IPAdapterModel'),
|
||||
transformResponse: buildTransformResponse<IPAdapterConfig>(ipAdapterModelsAdapter),
|
||||
}),
|
||||
getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfigEntity, string>, void>({
|
||||
getT2IAdapterModels: build.query<EntityState<T2IAdapterConfig, string>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 't2i_adapter' } }),
|
||||
providesTags: (result) => {
|
||||
const tags: ApiTagDescription[] = [{ type: 'T2IAdapterModel', id: LIST_TAG }, 'Model'];
|
||||
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: 'T2IAdapterModel' as const,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
},
|
||||
transformResponse: (response: { models: T2IAdapterModelConfig[] }) => {
|
||||
const entities = createModelEntities<T2IAdapterModelConfigEntity>(response.models);
|
||||
return t2iAdapterModelsAdapter.setAll(t2iAdapterModelsAdapter.getInitialState(), entities);
|
||||
},
|
||||
providesTags: buildProvidesTags<T2IAdapterConfig>('T2IAdapterModel'),
|
||||
transformResponse: buildTransformResponse<T2IAdapterConfig>(t2iAdapterModelsAdapter),
|
||||
}),
|
||||
getVaeModels: build.query<EntityState<VaeModelConfigEntity, string>, void>({
|
||||
getVaeModels: build.query<EntityState<VAEConfig, string>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
|
||||
providesTags: (result) => {
|
||||
const tags: ApiTagDescription[] = [{ type: 'VaeModel', id: LIST_TAG }, 'Model'];
|
||||
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: 'VaeModel' as const,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
},
|
||||
transformResponse: (response: { models: VaeModelConfig[] }) => {
|
||||
const entities = createModelEntities<VaeModelConfigEntity>(response.models);
|
||||
return vaeModelsAdapter.setAll(vaeModelsAdapter.getInitialState(), entities);
|
||||
},
|
||||
providesTags: buildProvidesTags<VAEConfig>('VaeModel'),
|
||||
transformResponse: buildTransformResponse<VAEConfig>(vaeModelsAdapter),
|
||||
}),
|
||||
getTextualInversionModels: build.query<EntityState<TextualInversionModelConfigEntity, string>, void>({
|
||||
getTextualInversionModels: build.query<EntityState<TextualInversionConfig, string>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'embedding' } }),
|
||||
providesTags: (result) => {
|
||||
const tags: ApiTagDescription[] = [{ type: 'TextualInversionModel', id: LIST_TAG }, 'Model'];
|
||||
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: 'TextualInversionModel' as const,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
},
|
||||
transformResponse: (response: { models: TextualInversionModelConfig[] }) => {
|
||||
const entities = createModelEntities<TextualInversionModelConfigEntity>(response.models);
|
||||
return textualInversionModelsAdapter.setAll(textualInversionModelsAdapter.getInitialState(), entities);
|
||||
},
|
||||
providesTags: buildProvidesTags<TextualInversionConfig>('TextualInversionModel'),
|
||||
transformResponse: buildTransformResponse<TextualInversionConfig>(textualInversionModelsAdapter),
|
||||
}),
|
||||
getModelsInFolder: build.query<SearchFolderResponse, SearchFolderArg>({
|
||||
query: (arg) => {
|
||||
|
Reference in New Issue
Block a user