mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): fix up MM queries & types (wip)
This commit is contained in:
parent
202e739404
commit
97ecd99b9c
@ -10,17 +10,18 @@ import PredictionTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/F
|
|||||||
import RepoVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect';
|
import RepoVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
|
import { isNil, omitBy } from 'lodash-es';
|
||||||
import { useCallback, useEffect } from 'react';
|
import { useCallback, useEffect } from 'react';
|
||||||
import type { SubmitHandler } from 'react-hook-form';
|
import type { SubmitHandler } from 'react-hook-form';
|
||||||
import { useForm } from 'react-hook-form';
|
import { useForm } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useImportAdvancedModelMutation } from 'services/api/endpoints/models';
|
import { useInstallModelMutation } from 'services/api/endpoints/models';
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
export const AdvancedImport = () => {
|
export const AdvancedImport = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const [importAdvancedModel] = useImportAdvancedModelMutation();
|
const [installModel] = useInstallModelMutation();
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
@ -49,15 +50,9 @@ export const AdvancedImport = () => {
|
|||||||
|
|
||||||
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
|
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
|
||||||
(values) => {
|
(values) => {
|
||||||
const cleanValues = Object.fromEntries(
|
installModel({
|
||||||
Object.entries(values).filter(([value]) => value !== null && value !== undefined)
|
source: values.path,
|
||||||
);
|
config: omitBy(values, isNil),
|
||||||
importAdvancedModel({
|
|
||||||
source: {
|
|
||||||
path: cleanValues.path as string,
|
|
||||||
type: 'local',
|
|
||||||
},
|
|
||||||
config: cleanValues,
|
|
||||||
})
|
})
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.then((_) => {
|
.then((_) => {
|
||||||
@ -86,7 +81,7 @@ export const AdvancedImport = () => {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
[dispatch, reset, t, importAdvancedModel]
|
[installModel, dispatch, t, reset]
|
||||||
);
|
);
|
||||||
|
|
||||||
const watchedModelType = watch('type');
|
const watchedModelType = watch('type');
|
||||||
|
@ -6,7 +6,7 @@ import { useCallback } from 'react';
|
|||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { IoAdd } from 'react-icons/io5';
|
import { IoAdd } from 'react-icons/io5';
|
||||||
import type { ScanFolderResponse } from 'services/api/endpoints/models';
|
import type { ScanFolderResponse } from 'services/api/endpoints/models';
|
||||||
import { useImportMainModelsMutation } from 'services/api/endpoints/models';
|
import { useInstallModelMutation } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
result: ScanFolderResponse[number];
|
result: ScanFolderResponse[number];
|
||||||
@ -15,10 +15,10 @@ export const ScanModelResultItem = ({ result }: Props) => {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const [importMainModel] = useImportMainModelsMutation();
|
const [installModel] = useInstallModelMutation();
|
||||||
|
|
||||||
const handleQuickAdd = useCallback(() => {
|
const handleQuickAdd = useCallback(() => {
|
||||||
importMainModel({ source: result.path, config: undefined })
|
installModel({ source: result.path })
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.then((_) => {
|
.then((_) => {
|
||||||
dispatch(
|
dispatch(
|
||||||
@ -42,7 +42,7 @@ export const ScanModelResultItem = ({ result }: Props) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}, [importMainModel, result, dispatch, t]);
|
}, [installModel, result, dispatch, t]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex justifyContent="space-between">
|
<Flex justifyContent="space-between">
|
||||||
|
@ -6,7 +6,7 @@ import { t } from 'i18next';
|
|||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import type { SubmitHandler } from 'react-hook-form';
|
import type { SubmitHandler } from 'react-hook-form';
|
||||||
import { useForm } from 'react-hook-form';
|
import { useForm } from 'react-hook-form';
|
||||||
import { useImportMainModelsMutation } from 'services/api/endpoints/models';
|
import { useInstallModelMutation } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
type SimpleImportModelConfig = {
|
type SimpleImportModelConfig = {
|
||||||
location: string;
|
location: string;
|
||||||
@ -15,7 +15,7 @@ type SimpleImportModelConfig = {
|
|||||||
export const SimpleImport = () => {
|
export const SimpleImport = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
|
const [installModel, { isLoading }] = useInstallModelMutation();
|
||||||
|
|
||||||
const { register, handleSubmit, formState, reset } = useForm<SimpleImportModelConfig>({
|
const { register, handleSubmit, formState, reset } = useForm<SimpleImportModelConfig>({
|
||||||
defaultValues: {
|
defaultValues: {
|
||||||
@ -30,7 +30,7 @@ export const SimpleImport = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
importMainModel({ source: values.location, config: undefined })
|
installModel({ source: values.location })
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.then((_) => {
|
.then((_) => {
|
||||||
dispatch(
|
dispatch(
|
||||||
@ -57,7 +57,7 @@ export const SimpleImport = () => {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
[dispatch, reset, importMainModel]
|
[dispatch, reset, installModel]
|
||||||
);
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
|
import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
|
||||||
import { createEntityAdapter } from '@reduxjs/toolkit';
|
import { createEntityAdapter } from '@reduxjs/toolkit';
|
||||||
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
||||||
|
import type { JSONObject } from 'common/types';
|
||||||
import queryString from 'query-string';
|
import queryString from 'query-string';
|
||||||
import type { operations, paths } from 'services/api/schema';
|
import type { operations, paths } from 'services/api/schema';
|
||||||
import type {
|
import type {
|
||||||
@ -19,8 +20,8 @@ import type { ApiTagDescription, tagTypes } from '..';
|
|||||||
import { api, buildV2Url, LIST_TAG } from '..';
|
import { api, buildV2Url, LIST_TAG } from '..';
|
||||||
|
|
||||||
type UpdateModelArg = {
|
type UpdateModelArg = {
|
||||||
key: NonNullable<operations['update_model_record']['parameters']['path']['key']>;
|
key: paths['/api/v2/models/i/{key}']['patch']['parameters']['path']['key'];
|
||||||
body: NonNullable<operations['update_model_record']['requestBody']['content']['application/json']>;
|
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
|
||||||
};
|
};
|
||||||
|
|
||||||
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
|
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
|
||||||
@ -40,14 +41,15 @@ type DeleteMainModelResponse = void;
|
|||||||
type ConvertMainModelResponse =
|
type ConvertMainModelResponse =
|
||||||
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
|
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
|
||||||
|
|
||||||
type ImportMainModelArg = {
|
type InstallModelArg = {
|
||||||
source: NonNullable<operations['heuristic_install_model']['parameters']['query']['source']>;
|
source: paths['/api/v2/models/install']['post']['parameters']['query']['source'];
|
||||||
access_token?: operations['heuristic_install_model']['parameters']['query']['access_token'];
|
access_token?: paths['/api/v2/models/install']['post']['parameters']['query']['access_token'];
|
||||||
config: NonNullable<operations['heuristic_install_model']['requestBody']['content']['application/json']>;
|
// TODO(MM2): This is typed as `Optional[Dict[str, Any]]` in backend...
|
||||||
|
config?: JSONObject;
|
||||||
|
// config: NonNullable<paths['/api/v2/models/heuristic_install']['post']['requestBody']>['content']['application/json'];
|
||||||
};
|
};
|
||||||
|
|
||||||
type ImportMainModelResponse =
|
type InstallModelResponse = paths['/api/v2/models/install']['post']['responses']['201']['content']['application/json'];
|
||||||
paths['/api/v2/models/heuristic_install']['post']['responses']['201']['content']['application/json'];
|
|
||||||
|
|
||||||
type ListImportModelsResponse =
|
type ListImportModelsResponse =
|
||||||
paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json'];
|
paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json'];
|
||||||
@ -58,14 +60,6 @@ type DeleteImportModelsResponse =
|
|||||||
type PruneModelImportsResponse =
|
type PruneModelImportsResponse =
|
||||||
paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json'];
|
paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json'];
|
||||||
|
|
||||||
type ImportAdvancedModelArg = {
|
|
||||||
source: NonNullable<operations['import_model']['requestBody']['content']['application/json']['source']>;
|
|
||||||
config: NonNullable<operations['import_model']['requestBody']['content']['application/json']['config']>;
|
|
||||||
};
|
|
||||||
|
|
||||||
type ImportAdvancedModelResponse =
|
|
||||||
paths['/api/v2/models/import']['post']['responses']['201']['content']['application/json'];
|
|
||||||
|
|
||||||
export type ScanFolderResponse =
|
export type ScanFolderResponse =
|
||||||
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
|
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
|
||||||
type ScanFolderArg = operations['scan_for_models']['parameters']['query'];
|
type ScanFolderArg = operations['scan_for_models']['parameters']['query'];
|
||||||
@ -183,7 +177,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
},
|
},
|
||||||
invalidatesTags: ['Model'],
|
invalidatesTags: ['Model'],
|
||||||
}),
|
}),
|
||||||
importMainModels: build.mutation<ImportMainModelResponse, ImportMainModelArg>({
|
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
|
||||||
query: ({ source, config, access_token }) => {
|
query: ({ source, config, access_token }) => {
|
||||||
return {
|
return {
|
||||||
url: buildModelsUrl('heuristic_install'),
|
url: buildModelsUrl('heuristic_install'),
|
||||||
@ -194,16 +188,6 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
},
|
},
|
||||||
invalidatesTags: ['Model', 'ModelImports'],
|
invalidatesTags: ['Model', 'ModelImports'],
|
||||||
}),
|
}),
|
||||||
importAdvancedModel: build.mutation<ImportAdvancedModelResponse, ImportAdvancedModelArg>({
|
|
||||||
query: ({ source, config }) => {
|
|
||||||
return {
|
|
||||||
url: buildModelsUrl('install'),
|
|
||||||
method: 'POST',
|
|
||||||
body: { source, config },
|
|
||||||
};
|
|
||||||
},
|
|
||||||
invalidatesTags: ['Model', 'ModelImports'],
|
|
||||||
}),
|
|
||||||
deleteModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({
|
deleteModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({
|
||||||
query: ({ key }) => {
|
query: ({ key }) => {
|
||||||
return {
|
return {
|
||||||
@ -365,8 +349,7 @@ export const {
|
|||||||
useGetVaeModelsQuery,
|
useGetVaeModelsQuery,
|
||||||
useDeleteModelsMutation,
|
useDeleteModelsMutation,
|
||||||
useUpdateModelsMutation,
|
useUpdateModelsMutation,
|
||||||
useImportMainModelsMutation,
|
useInstallModelMutation,
|
||||||
useImportAdvancedModelMutation,
|
|
||||||
useConvertMainModelsMutation,
|
useConvertMainModelsMutation,
|
||||||
useSyncModelsMutation,
|
useSyncModelsMutation,
|
||||||
useScanModelsQuery,
|
useScanModelsQuery,
|
||||||
|
@ -70,6 +70,7 @@ export type T2IAdapterModelConfig = S['T2IConfig'];
|
|||||||
export type TextualInversionModelConfig = S['TextualInversionConfig'];
|
export type TextualInversionModelConfig = S['TextualInversionConfig'];
|
||||||
export type DiffusersModelConfig = S['MainDiffusersConfig'];
|
export type DiffusersModelConfig = S['MainDiffusersConfig'];
|
||||||
export type CheckpointModelConfig = S['MainCheckpointConfig'];
|
export type CheckpointModelConfig = S['MainCheckpointConfig'];
|
||||||
|
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
|
||||||
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
|
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
|
||||||
export type RefinerMainModelConfig = Omit<MainModelConfig, 'base'> & { base: 'sdxl-refiner' };
|
export type RefinerMainModelConfig = Omit<MainModelConfig, 'base'> & { base: 'sdxl-refiner' };
|
||||||
export type NonRefinerMainModelConfig = Omit<MainModelConfig, 'base'> & { base: 'any' | 'sd-1' | 'sd-2' | 'sdxl' };
|
export type NonRefinerMainModelConfig = Omit<MainModelConfig, 'base'> & { base: 'any' | 'sd-1' | 'sd-2' | 'sdxl' };
|
||||||
@ -81,7 +82,18 @@ export type AnyModelConfig =
|
|||||||
| T2IAdapterModelConfig
|
| T2IAdapterModelConfig
|
||||||
| TextualInversionModelConfig
|
| TextualInversionModelConfig
|
||||||
| RefinerMainModelConfig
|
| RefinerMainModelConfig
|
||||||
| NonRefinerMainModelConfig;
|
| NonRefinerMainModelConfig
|
||||||
|
| CLIPVisionDiffusersConfig;
|
||||||
|
|
||||||
|
type AnyModelConfig2 =
|
||||||
|
| (S['MainDiffusersConfig'] | S['MainCheckpointConfig'])
|
||||||
|
| (S['VaeDiffusersConfig'] | S['VaeCheckpointConfig'])
|
||||||
|
| (S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'])
|
||||||
|
| S['LoRAConfig']
|
||||||
|
| S['TextualInversionConfig']
|
||||||
|
| S['IPAdapterConfig']
|
||||||
|
| S['CLIPVisionDiffusersConfig']
|
||||||
|
| S['T2IConfig'];
|
||||||
|
|
||||||
export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelConfig => {
|
export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelConfig => {
|
||||||
return config.type === 'lora';
|
return config.type === 'lora';
|
||||||
|
Loading…
Reference in New Issue
Block a user