mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): refactor metadata handling
Refactor of metadata recall handling. This is in preparation for a backwards compatibility layer for models. - Create helpers to fetch a model outside react (e.g. not in a hook) - Created helpers to parse model metadata - Renamed a lot of types that were confusing and/or had naming collisions
This commit is contained in:
@ -6,16 +6,16 @@ import type { operations, paths } from 'services/api/schema';
|
||||
import type {
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ControlNetConfig,
|
||||
ControlNetModelConfig,
|
||||
ImportModelConfig,
|
||||
IPAdapterConfig,
|
||||
LoRAConfig,
|
||||
IPAdapterModelConfig,
|
||||
LoRAModelConfig,
|
||||
MainModelConfig,
|
||||
MergeModelConfig,
|
||||
ModelType,
|
||||
T2IAdapterConfig,
|
||||
TextualInversionConfig,
|
||||
VAEConfig,
|
||||
T2IAdapterModelConfig,
|
||||
TextualInversionModelConfig,
|
||||
VAEModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
import type { ApiTagDescription, tagTypes } from '..';
|
||||
@ -30,7 +30,7 @@ type UpdateMainModelArg = {
|
||||
type UpdateLoRAModelArg = {
|
||||
base_model: BaseModelType;
|
||||
model_name: string;
|
||||
body: LoRAConfig;
|
||||
body: LoRAModelConfig;
|
||||
};
|
||||
|
||||
type UpdateMainModelResponse =
|
||||
@ -97,27 +97,27 @@ export const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
export const loraModelsAdapter = createEntityAdapter<LoRAConfig, string>({
|
||||
export const loraModelsAdapter = createEntityAdapter<LoRAModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
export const controlNetModelsAdapter = createEntityAdapter<ControlNetConfig, string>({
|
||||
export const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
export const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterConfig, string>({
|
||||
export const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
export const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterConfig, string>({
|
||||
export const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
export const textualInversionModelsAdapter = createEntityAdapter<TextualInversionConfig, string>({
|
||||
export const textualInversionModelsAdapter = createEntityAdapter<TextualInversionModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
@ -125,7 +125,7 @@ export const textualInversionModelsAdapterSelectors = textualInversionModelsAdap
|
||||
undefined,
|
||||
getSelectorsOptions
|
||||
);
|
||||
export const vaeModelsAdapter = createEntityAdapter<VAEConfig, string>({
|
||||
export const vaeModelsAdapter = createEntityAdapter<VAEModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
@ -162,6 +162,8 @@ const buildTransformResponse =
|
||||
*/
|
||||
const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
|
||||
|
||||
// TODO(psyche): Ideally we can share the cache between the `getXYZModels` queries and `getModelConfig` query
|
||||
|
||||
export const modelsApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
|
||||
@ -257,10 +259,10 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
getLoRAModels: build.query<EntityState<LoRAConfig, string>, void>({
|
||||
getLoRAModels: build.query<EntityState<LoRAModelConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
|
||||
providesTags: buildProvidesTags<LoRAConfig>('LoRAModel'),
|
||||
transformResponse: buildTransformResponse<LoRAConfig>(loraModelsAdapter),
|
||||
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
|
||||
transformResponse: buildTransformResponse<LoRAModelConfig>(loraModelsAdapter),
|
||||
}),
|
||||
updateLoRAModels: build.mutation<UpdateLoRAModelResponse, UpdateLoRAModelArg>({
|
||||
query: ({ base_model, model_name, body }) => {
|
||||
@ -281,30 +283,30 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
|
||||
}),
|
||||
getControlNetModels: build.query<EntityState<ControlNetConfig, string>, void>({
|
||||
getControlNetModels: build.query<EntityState<ControlNetModelConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }),
|
||||
providesTags: buildProvidesTags<ControlNetConfig>('ControlNetModel'),
|
||||
transformResponse: buildTransformResponse<ControlNetConfig>(controlNetModelsAdapter),
|
||||
providesTags: buildProvidesTags<ControlNetModelConfig>('ControlNetModel'),
|
||||
transformResponse: buildTransformResponse<ControlNetModelConfig>(controlNetModelsAdapter),
|
||||
}),
|
||||
getIPAdapterModels: build.query<EntityState<IPAdapterConfig, string>, void>({
|
||||
getIPAdapterModels: build.query<EntityState<IPAdapterModelConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }),
|
||||
providesTags: buildProvidesTags<IPAdapterConfig>('IPAdapterModel'),
|
||||
transformResponse: buildTransformResponse<IPAdapterConfig>(ipAdapterModelsAdapter),
|
||||
providesTags: buildProvidesTags<IPAdapterModelConfig>('IPAdapterModel'),
|
||||
transformResponse: buildTransformResponse<IPAdapterModelConfig>(ipAdapterModelsAdapter),
|
||||
}),
|
||||
getT2IAdapterModels: build.query<EntityState<T2IAdapterConfig, string>, void>({
|
||||
getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }),
|
||||
providesTags: buildProvidesTags<T2IAdapterConfig>('T2IAdapterModel'),
|
||||
transformResponse: buildTransformResponse<T2IAdapterConfig>(t2iAdapterModelsAdapter),
|
||||
providesTags: buildProvidesTags<T2IAdapterModelConfig>('T2IAdapterModel'),
|
||||
transformResponse: buildTransformResponse<T2IAdapterModelConfig>(t2iAdapterModelsAdapter),
|
||||
}),
|
||||
getVaeModels: build.query<EntityState<VAEConfig, string>, void>({
|
||||
getVaeModels: build.query<EntityState<VAEModelConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }),
|
||||
providesTags: buildProvidesTags<VAEConfig>('VaeModel'),
|
||||
transformResponse: buildTransformResponse<VAEConfig>(vaeModelsAdapter),
|
||||
providesTags: buildProvidesTags<VAEModelConfig>('VaeModel'),
|
||||
transformResponse: buildTransformResponse<VAEModelConfig>(vaeModelsAdapter),
|
||||
}),
|
||||
getTextualInversionModels: build.query<EntityState<TextualInversionConfig, string>, void>({
|
||||
getTextualInversionModels: build.query<EntityState<TextualInversionModelConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }),
|
||||
providesTags: buildProvidesTags<TextualInversionConfig>('TextualInversionModel'),
|
||||
transformResponse: buildTransformResponse<TextualInversionConfig>(textualInversionModelsAdapter),
|
||||
providesTags: buildProvidesTags<TextualInversionModelConfig>('TextualInversionModel'),
|
||||
transformResponse: buildTransformResponse<TextualInversionModelConfig>(textualInversionModelsAdapter),
|
||||
}),
|
||||
getModelsInFolder: build.query<SearchFolderResponse, SearchFolderArg>({
|
||||
query: (arg) => {
|
||||
|
Reference in New Issue
Block a user