mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): fix up MM queries & types (wip)
This commit is contained in:
parent
b361fabf81
commit
ca00fabd79
@ -10,17 +10,18 @@ import PredictionTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/F
|
||||
import RepoVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { isNil, omitBy } from 'lodash-es';
|
||||
import { useCallback, useEffect } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useImportAdvancedModelMutation } from 'services/api/endpoints/models';
|
||||
import { useInstallModelMutation } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
export const AdvancedImport = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [importAdvancedModel] = useImportAdvancedModelMutation();
|
||||
const [installModel] = useInstallModelMutation();
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
@ -49,15 +50,9 @@ export const AdvancedImport = () => {
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
|
||||
(values) => {
|
||||
const cleanValues = Object.fromEntries(
|
||||
Object.entries(values).filter(([value]) => value !== null && value !== undefined)
|
||||
);
|
||||
importAdvancedModel({
|
||||
source: {
|
||||
path: cleanValues.path as string,
|
||||
type: 'local',
|
||||
},
|
||||
config: cleanValues,
|
||||
installModel({
|
||||
source: values.path,
|
||||
config: omitBy(values, isNil),
|
||||
})
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
@ -86,7 +81,7 @@ export const AdvancedImport = () => {
|
||||
}
|
||||
});
|
||||
},
|
||||
[dispatch, reset, t, importAdvancedModel]
|
||||
[installModel, dispatch, t, reset]
|
||||
);
|
||||
|
||||
const watchedModelType = watch('type');
|
||||
|
@ -6,7 +6,7 @@ import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { IoAdd } from 'react-icons/io5';
|
||||
import type { ScanFolderResponse } from 'services/api/endpoints/models';
|
||||
import { useImportMainModelsMutation } from 'services/api/endpoints/models';
|
||||
import { useInstallModelMutation } from 'services/api/endpoints/models';
|
||||
|
||||
type Props = {
|
||||
result: ScanFolderResponse[number];
|
||||
@ -15,10 +15,10 @@ export const ScanModelResultItem = ({ result }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [importMainModel] = useImportMainModelsMutation();
|
||||
const [installModel] = useInstallModelMutation();
|
||||
|
||||
const handleQuickAdd = useCallback(() => {
|
||||
importMainModel({ source: result.path, config: undefined })
|
||||
installModel({ source: result.path })
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
@ -42,7 +42,7 @@ export const ScanModelResultItem = ({ result }: Props) => {
|
||||
);
|
||||
}
|
||||
});
|
||||
}, [importMainModel, result, dispatch, t]);
|
||||
}, [installModel, result, dispatch, t]);
|
||||
|
||||
return (
|
||||
<Flex justifyContent="space-between">
|
||||
|
@ -6,7 +6,7 @@ import { t } from 'i18next';
|
||||
import { useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useImportMainModelsMutation } from 'services/api/endpoints/models';
|
||||
import { useInstallModelMutation } from 'services/api/endpoints/models';
|
||||
|
||||
type SimpleImportModelConfig = {
|
||||
location: string;
|
||||
@ -15,7 +15,7 @@ type SimpleImportModelConfig = {
|
||||
export const SimpleImport = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
|
||||
const [installModel, { isLoading }] = useInstallModelMutation();
|
||||
|
||||
const { register, handleSubmit, formState, reset } = useForm<SimpleImportModelConfig>({
|
||||
defaultValues: {
|
||||
@ -30,7 +30,7 @@ export const SimpleImport = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
importMainModel({ source: values.location, config: undefined })
|
||||
installModel({ source: values.location })
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
@ -57,7 +57,7 @@ export const SimpleImport = () => {
|
||||
}
|
||||
});
|
||||
},
|
||||
[dispatch, reset, importMainModel]
|
||||
[dispatch, reset, installModel]
|
||||
);
|
||||
|
||||
return (
|
||||
|
@ -1,6 +1,7 @@
|
||||
import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
|
||||
import { createEntityAdapter } from '@reduxjs/toolkit';
|
||||
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
||||
import type { JSONObject } from 'common/types';
|
||||
import queryString from 'query-string';
|
||||
import type { operations, paths } from 'services/api/schema';
|
||||
import type {
|
||||
@ -19,8 +20,8 @@ import type { ApiTagDescription, tagTypes } from '..';
|
||||
import { api, buildV2Url, LIST_TAG } from '..';
|
||||
|
||||
type UpdateModelArg = {
|
||||
key: NonNullable<operations['update_model_record']['parameters']['path']['key']>;
|
||||
body: NonNullable<operations['update_model_record']['requestBody']['content']['application/json']>;
|
||||
key: paths['/api/v2/models/i/{key}']['patch']['parameters']['path']['key'];
|
||||
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
|
||||
};
|
||||
|
||||
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
|
||||
@ -40,14 +41,15 @@ type DeleteMainModelResponse = void;
|
||||
type ConvertMainModelResponse =
|
||||
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
|
||||
|
||||
type ImportMainModelArg = {
|
||||
source: NonNullable<operations['heuristic_install_model']['parameters']['query']['source']>;
|
||||
access_token?: operations['heuristic_install_model']['parameters']['query']['access_token'];
|
||||
config: NonNullable<operations['heuristic_install_model']['requestBody']['content']['application/json']>;
|
||||
type InstallModelArg = {
|
||||
source: paths['/api/v2/models/install']['post']['parameters']['query']['source'];
|
||||
access_token?: paths['/api/v2/models/install']['post']['parameters']['query']['access_token'];
|
||||
// TODO(MM2): This is typed as `Optional[Dict[str, Any]]` in backend...
|
||||
config?: JSONObject;
|
||||
// config: NonNullable<paths['/api/v2/models/heuristic_install']['post']['requestBody']>['content']['application/json'];
|
||||
};
|
||||
|
||||
type ImportMainModelResponse =
|
||||
paths['/api/v2/models/heuristic_install']['post']['responses']['201']['content']['application/json'];
|
||||
type InstallModelResponse = paths['/api/v2/models/install']['post']['responses']['201']['content']['application/json'];
|
||||
|
||||
type ListImportModelsResponse =
|
||||
paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json'];
|
||||
@ -58,14 +60,6 @@ type DeleteImportModelsResponse =
|
||||
type PruneModelImportsResponse =
|
||||
paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json'];
|
||||
|
||||
type ImportAdvancedModelArg = {
|
||||
source: NonNullable<operations['import_model']['requestBody']['content']['application/json']['source']>;
|
||||
config: NonNullable<operations['import_model']['requestBody']['content']['application/json']['config']>;
|
||||
};
|
||||
|
||||
type ImportAdvancedModelResponse =
|
||||
paths['/api/v2/models/import']['post']['responses']['201']['content']['application/json'];
|
||||
|
||||
export type ScanFolderResponse =
|
||||
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
|
||||
type ScanFolderArg = operations['scan_for_models']['parameters']['query'];
|
||||
@ -183,7 +177,7 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
importMainModels: build.mutation<ImportMainModelResponse, ImportMainModelArg>({
|
||||
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
|
||||
query: ({ source, config, access_token }) => {
|
||||
return {
|
||||
url: buildModelsUrl('heuristic_install'),
|
||||
@ -194,16 +188,6 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['Model', 'ModelImports'],
|
||||
}),
|
||||
importAdvancedModel: build.mutation<ImportAdvancedModelResponse, ImportAdvancedModelArg>({
|
||||
query: ({ source, config }) => {
|
||||
return {
|
||||
url: buildModelsUrl('install'),
|
||||
method: 'POST',
|
||||
body: { source, config },
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['Model', 'ModelImports'],
|
||||
}),
|
||||
deleteModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({
|
||||
query: ({ key }) => {
|
||||
return {
|
||||
@ -365,8 +349,7 @@ export const {
|
||||
useGetVaeModelsQuery,
|
||||
useDeleteModelsMutation,
|
||||
useUpdateModelsMutation,
|
||||
useImportMainModelsMutation,
|
||||
useImportAdvancedModelMutation,
|
||||
useInstallModelMutation,
|
||||
useConvertMainModelsMutation,
|
||||
useSyncModelsMutation,
|
||||
useScanModelsQuery,
|
||||
|
@ -70,6 +70,7 @@ export type T2IAdapterModelConfig = S['T2IConfig'];
|
||||
export type TextualInversionModelConfig = S['TextualInversionConfig'];
|
||||
export type DiffusersModelConfig = S['MainDiffusersConfig'];
|
||||
export type CheckpointModelConfig = S['MainCheckpointConfig'];
|
||||
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
|
||||
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
|
||||
export type RefinerMainModelConfig = Omit<MainModelConfig, 'base'> & { base: 'sdxl-refiner' };
|
||||
export type NonRefinerMainModelConfig = Omit<MainModelConfig, 'base'> & { base: 'any' | 'sd-1' | 'sd-2' | 'sdxl' };
|
||||
@ -81,7 +82,18 @@ export type AnyModelConfig =
|
||||
| T2IAdapterModelConfig
|
||||
| TextualInversionModelConfig
|
||||
| RefinerMainModelConfig
|
||||
| NonRefinerMainModelConfig;
|
||||
| NonRefinerMainModelConfig
|
||||
| CLIPVisionDiffusersConfig;
|
||||
|
||||
type AnyModelConfig2 =
|
||||
| (S['MainDiffusersConfig'] | S['MainCheckpointConfig'])
|
||||
| (S['VaeDiffusersConfig'] | S['VaeCheckpointConfig'])
|
||||
| (S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'])
|
||||
| S['LoRAConfig']
|
||||
| S['TextualInversionConfig']
|
||||
| S['IPAdapterConfig']
|
||||
| S['CLIPVisionDiffusersConfig']
|
||||
| S['T2IConfig'];
|
||||
|
||||
export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelConfig => {
|
||||
return config.type === 'lora';
|
||||
|
Loading…
Reference in New Issue
Block a user