fix(ui): fix up MM queries & types (wip)

This commit is contained in:
psychedelicious 2024-02-27 15:09:42 +11:00 committed by Kent Keirsey
parent b361fabf81
commit ca00fabd79
5 changed files with 40 additions and 50 deletions

View File

@ -10,17 +10,18 @@ import PredictionTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/F
import RepoVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { isNil, omitBy } from 'lodash-es';
import { useCallback, useEffect } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useImportAdvancedModelMutation } from 'services/api/endpoints/models';
import { useInstallModelMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
export const AdvancedImport = () => {
const dispatch = useAppDispatch();
const [importAdvancedModel] = useImportAdvancedModelMutation();
const [installModel] = useInstallModelMutation();
const { t } = useTranslation();
@ -49,15 +50,9 @@ export const AdvancedImport = () => {
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
(values) => {
const cleanValues = Object.fromEntries(
Object.entries(values).filter(([value]) => value !== null && value !== undefined)
);
importAdvancedModel({
source: {
path: cleanValues.path as string,
type: 'local',
},
config: cleanValues,
installModel({
source: values.path,
config: omitBy(values, isNil),
})
.unwrap()
.then((_) => {
@ -86,7 +81,7 @@ export const AdvancedImport = () => {
}
});
},
[dispatch, reset, t, importAdvancedModel]
[installModel, dispatch, t, reset]
);
const watchedModelType = watch('type');

View File

@ -6,7 +6,7 @@ import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { IoAdd } from 'react-icons/io5';
import type { ScanFolderResponse } from 'services/api/endpoints/models';
import { useImportMainModelsMutation } from 'services/api/endpoints/models';
import { useInstallModelMutation } from 'services/api/endpoints/models';
type Props = {
result: ScanFolderResponse[number];
@ -15,10 +15,10 @@ export const ScanModelResultItem = ({ result }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const [importMainModel] = useImportMainModelsMutation();
const [installModel] = useInstallModelMutation();
const handleQuickAdd = useCallback(() => {
importMainModel({ source: result.path, config: undefined })
installModel({ source: result.path })
.unwrap()
.then((_) => {
dispatch(
@ -42,7 +42,7 @@ export const ScanModelResultItem = ({ result }: Props) => {
);
}
});
}, [importMainModel, result, dispatch, t]);
}, [installModel, result, dispatch, t]);
return (
<Flex justifyContent="space-between">

View File

@ -6,7 +6,7 @@ import { t } from 'i18next';
import { useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useImportMainModelsMutation } from 'services/api/endpoints/models';
import { useInstallModelMutation } from 'services/api/endpoints/models';
type SimpleImportModelConfig = {
location: string;
@ -15,7 +15,7 @@ type SimpleImportModelConfig = {
export const SimpleImport = () => {
const dispatch = useAppDispatch();
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
const [installModel, { isLoading }] = useInstallModelMutation();
const { register, handleSubmit, formState, reset } = useForm<SimpleImportModelConfig>({
defaultValues: {
@ -30,7 +30,7 @@ export const SimpleImport = () => {
return;
}
importMainModel({ source: values.location, config: undefined })
installModel({ source: values.location })
.unwrap()
.then((_) => {
dispatch(
@ -57,7 +57,7 @@ export const SimpleImport = () => {
}
});
},
[dispatch, reset, importMainModel]
[dispatch, reset, installModel]
);
return (

View File

@ -1,6 +1,7 @@
import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import type { JSONObject } from 'common/types';
import queryString from 'query-string';
import type { operations, paths } from 'services/api/schema';
import type {
@ -19,8 +20,8 @@ import type { ApiTagDescription, tagTypes } from '..';
import { api, buildV2Url, LIST_TAG } from '..';
type UpdateModelArg = {
key: NonNullable<operations['update_model_record']['parameters']['path']['key']>;
body: NonNullable<operations['update_model_record']['requestBody']['content']['application/json']>;
key: paths['/api/v2/models/i/{key}']['patch']['parameters']['path']['key'];
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
};
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
@ -40,14 +41,15 @@ type DeleteMainModelResponse = void;
type ConvertMainModelResponse =
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
type ImportMainModelArg = {
source: NonNullable<operations['heuristic_install_model']['parameters']['query']['source']>;
access_token?: operations['heuristic_install_model']['parameters']['query']['access_token'];
config: NonNullable<operations['heuristic_install_model']['requestBody']['content']['application/json']>;
type InstallModelArg = {
source: paths['/api/v2/models/install']['post']['parameters']['query']['source'];
access_token?: paths['/api/v2/models/install']['post']['parameters']['query']['access_token'];
// TODO(MM2): This is typed as `Optional[Dict[str, Any]]` in backend...
config?: JSONObject;
// config: NonNullable<paths['/api/v2/models/heuristic_install']['post']['requestBody']>['content']['application/json'];
};
type ImportMainModelResponse =
paths['/api/v2/models/heuristic_install']['post']['responses']['201']['content']['application/json'];
type InstallModelResponse = paths['/api/v2/models/install']['post']['responses']['201']['content']['application/json'];
type ListImportModelsResponse =
paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json'];
@ -58,14 +60,6 @@ type DeleteImportModelsResponse =
type PruneModelImportsResponse =
paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json'];
type ImportAdvancedModelArg = {
source: NonNullable<operations['import_model']['requestBody']['content']['application/json']['source']>;
config: NonNullable<operations['import_model']['requestBody']['content']['application/json']['config']>;
};
type ImportAdvancedModelResponse =
paths['/api/v2/models/import']['post']['responses']['201']['content']['application/json'];
export type ScanFolderResponse =
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
type ScanFolderArg = operations['scan_for_models']['parameters']['query'];
@ -183,7 +177,7 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model'],
}),
importMainModels: build.mutation<ImportMainModelResponse, ImportMainModelArg>({
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
query: ({ source, config, access_token }) => {
return {
url: buildModelsUrl('heuristic_install'),
@ -194,16 +188,6 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model', 'ModelImports'],
}),
importAdvancedModel: build.mutation<ImportAdvancedModelResponse, ImportAdvancedModelArg>({
query: ({ source, config }) => {
return {
url: buildModelsUrl('install'),
method: 'POST',
body: { source, config },
};
},
invalidatesTags: ['Model', 'ModelImports'],
}),
deleteModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({
query: ({ key }) => {
return {
@ -365,8 +349,7 @@ export const {
useGetVaeModelsQuery,
useDeleteModelsMutation,
useUpdateModelsMutation,
useImportMainModelsMutation,
useImportAdvancedModelMutation,
useInstallModelMutation,
useConvertMainModelsMutation,
useSyncModelsMutation,
useScanModelsQuery,

View File

@ -70,6 +70,7 @@ export type T2IAdapterModelConfig = S['T2IConfig'];
export type TextualInversionModelConfig = S['TextualInversionConfig'];
export type DiffusersModelConfig = S['MainDiffusersConfig'];
export type CheckpointModelConfig = S['MainCheckpointConfig'];
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type RefinerMainModelConfig = Omit<MainModelConfig, 'base'> & { base: 'sdxl-refiner' };
export type NonRefinerMainModelConfig = Omit<MainModelConfig, 'base'> & { base: 'any' | 'sd-1' | 'sd-2' | 'sdxl' };
@ -81,7 +82,18 @@ export type AnyModelConfig =
| T2IAdapterModelConfig
| TextualInversionModelConfig
| RefinerMainModelConfig
| NonRefinerMainModelConfig;
| NonRefinerMainModelConfig
| CLIPVisionDiffusersConfig;
type AnyModelConfig2 =
| (S['MainDiffusersConfig'] | S['MainCheckpointConfig'])
| (S['VaeDiffusersConfig'] | S['VaeCheckpointConfig'])
| (S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'])
| S['LoRAConfig']
| S['TextualInversionConfig']
| S['IPAdapterConfig']
| S['CLIPVisionDiffusersConfig']
| S['T2IConfig'];
export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelConfig => {
return config.type === 'lora';