feat(ui): refactor metadata handling

Refactor of metadata recall handling. This is in preparation for a backwards compatibility layer for models.

- Create helpers to fetch a model outside react (e.g. not in a hook)
- Created helpers to parse model metadata
- Renamed a lot of types that were confusing and/or had naming collisions
This commit is contained in:
psychedelicious
2024-02-22 17:33:20 +11:00
parent 79b16596b5
commit 3ed2963f43
16 changed files with 443 additions and 486 deletions

View File

@ -6,16 +6,16 @@ import type { operations, paths } from 'services/api/schema';
import type {
AnyModelConfig,
BaseModelType,
ControlNetConfig,
ControlNetModelConfig,
ImportModelConfig,
IPAdapterConfig,
LoRAConfig,
IPAdapterModelConfig,
LoRAModelConfig,
MainModelConfig,
MergeModelConfig,
ModelType,
T2IAdapterConfig,
TextualInversionConfig,
VAEConfig,
T2IAdapterModelConfig,
TextualInversionModelConfig,
VAEModelConfig,
} from 'services/api/types';
import type { ApiTagDescription, tagTypes } from '..';
@ -30,7 +30,7 @@ type UpdateMainModelArg = {
type UpdateLoRAModelArg = {
base_model: BaseModelType;
model_name: string;
body: LoRAConfig;
body: LoRAModelConfig;
};
type UpdateMainModelResponse =
@ -97,27 +97,27 @@ export const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const loraModelsAdapter = createEntityAdapter<LoRAConfig, string>({
export const loraModelsAdapter = createEntityAdapter<LoRAModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const controlNetModelsAdapter = createEntityAdapter<ControlNetConfig, string>({
export const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterConfig, string>({
export const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterConfig, string>({
export const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const textualInversionModelsAdapter = createEntityAdapter<TextualInversionConfig, string>({
export const textualInversionModelsAdapter = createEntityAdapter<TextualInversionModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
@ -125,7 +125,7 @@ export const textualInversionModelsAdapterSelectors = textualInversionModelsAdap
undefined,
getSelectorsOptions
);
export const vaeModelsAdapter = createEntityAdapter<VAEConfig, string>({
export const vaeModelsAdapter = createEntityAdapter<VAEModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
@ -162,6 +162,8 @@ const buildTransformResponse =
*/
const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
// TODO(psyche): Ideally we can share the cache between the `getXYZModels` queries and `getModelConfig` query
export const modelsApi = api.injectEndpoints({
endpoints: (build) => ({
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
@ -257,10 +259,10 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model'],
}),
getLoRAModels: build.query<EntityState<LoRAConfig, string>, void>({
getLoRAModels: build.query<EntityState<LoRAModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
providesTags: buildProvidesTags<LoRAConfig>('LoRAModel'),
transformResponse: buildTransformResponse<LoRAConfig>(loraModelsAdapter),
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
transformResponse: buildTransformResponse<LoRAModelConfig>(loraModelsAdapter),
}),
updateLoRAModels: build.mutation<UpdateLoRAModelResponse, UpdateLoRAModelArg>({
query: ({ base_model, model_name, body }) => {
@ -281,30 +283,30 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
}),
getControlNetModels: build.query<EntityState<ControlNetConfig, string>, void>({
getControlNetModels: build.query<EntityState<ControlNetModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }),
providesTags: buildProvidesTags<ControlNetConfig>('ControlNetModel'),
transformResponse: buildTransformResponse<ControlNetConfig>(controlNetModelsAdapter),
providesTags: buildProvidesTags<ControlNetModelConfig>('ControlNetModel'),
transformResponse: buildTransformResponse<ControlNetModelConfig>(controlNetModelsAdapter),
}),
getIPAdapterModels: build.query<EntityState<IPAdapterConfig, string>, void>({
getIPAdapterModels: build.query<EntityState<IPAdapterModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }),
providesTags: buildProvidesTags<IPAdapterConfig>('IPAdapterModel'),
transformResponse: buildTransformResponse<IPAdapterConfig>(ipAdapterModelsAdapter),
providesTags: buildProvidesTags<IPAdapterModelConfig>('IPAdapterModel'),
transformResponse: buildTransformResponse<IPAdapterModelConfig>(ipAdapterModelsAdapter),
}),
getT2IAdapterModels: build.query<EntityState<T2IAdapterConfig, string>, void>({
getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }),
providesTags: buildProvidesTags<T2IAdapterConfig>('T2IAdapterModel'),
transformResponse: buildTransformResponse<T2IAdapterConfig>(t2iAdapterModelsAdapter),
providesTags: buildProvidesTags<T2IAdapterModelConfig>('T2IAdapterModel'),
transformResponse: buildTransformResponse<T2IAdapterModelConfig>(t2iAdapterModelsAdapter),
}),
getVaeModels: build.query<EntityState<VAEConfig, string>, void>({
getVaeModels: build.query<EntityState<VAEModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }),
providesTags: buildProvidesTags<VAEConfig>('VaeModel'),
transformResponse: buildTransformResponse<VAEConfig>(vaeModelsAdapter),
providesTags: buildProvidesTags<VAEModelConfig>('VaeModel'),
transformResponse: buildTransformResponse<VAEModelConfig>(vaeModelsAdapter),
}),
getTextualInversionModels: build.query<EntityState<TextualInversionConfig, string>, void>({
getTextualInversionModels: build.query<EntityState<TextualInversionModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }),
providesTags: buildProvidesTags<TextualInversionConfig>('TextualInversionModel'),
transformResponse: buildTransformResponse<TextualInversionConfig>(textualInversionModelsAdapter),
providesTags: buildProvidesTags<TextualInversionModelConfig>('TextualInversionModel'),
transformResponse: buildTransformResponse<TextualInversionModelConfig>(textualInversionModelsAdapter),
}),
getModelsInFolder: build.query<SearchFolderResponse, SearchFolderArg>({
query: (arg) => {