fix(ui): fix up MM queries & types (wip)

This commit is contained in:
psychedelicious 2024-02-27 15:09:42 +11:00
parent 202e739404
commit 97ecd99b9c
5 changed files with 40 additions and 50 deletions

View File

@ -10,17 +10,18 @@ import PredictionTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/F
import RepoVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect'; import RepoVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast'; import { makeToast } from 'features/system/util/makeToast';
import { isNil, omitBy } from 'lodash-es';
import { useCallback, useEffect } from 'react'; import { useCallback, useEffect } from 'react';
import type { SubmitHandler } from 'react-hook-form'; import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form'; import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useImportAdvancedModelMutation } from 'services/api/endpoints/models'; import { useInstallModelMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types'; import type { AnyModelConfig } from 'services/api/types';
export const AdvancedImport = () => { export const AdvancedImport = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [importAdvancedModel] = useImportAdvancedModelMutation(); const [installModel] = useInstallModelMutation();
const { t } = useTranslation(); const { t } = useTranslation();
@ -49,15 +50,9 @@ export const AdvancedImport = () => {
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>( const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
(values) => { (values) => {
const cleanValues = Object.fromEntries( installModel({
Object.entries(values).filter(([value]) => value !== null && value !== undefined) source: values.path,
); config: omitBy(values, isNil),
importAdvancedModel({
source: {
path: cleanValues.path as string,
type: 'local',
},
config: cleanValues,
}) })
.unwrap() .unwrap()
.then((_) => { .then((_) => {
@ -86,7 +81,7 @@ export const AdvancedImport = () => {
} }
}); });
}, },
[dispatch, reset, t, importAdvancedModel] [installModel, dispatch, t, reset]
); );
const watchedModelType = watch('type'); const watchedModelType = watch('type');

View File

@ -6,7 +6,7 @@ import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { IoAdd } from 'react-icons/io5'; import { IoAdd } from 'react-icons/io5';
import type { ScanFolderResponse } from 'services/api/endpoints/models'; import type { ScanFolderResponse } from 'services/api/endpoints/models';
import { useImportMainModelsMutation } from 'services/api/endpoints/models'; import { useInstallModelMutation } from 'services/api/endpoints/models';
type Props = { type Props = {
result: ScanFolderResponse[number]; result: ScanFolderResponse[number];
@ -15,10 +15,10 @@ export const ScanModelResultItem = ({ result }: Props) => {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [importMainModel] = useImportMainModelsMutation(); const [installModel] = useInstallModelMutation();
const handleQuickAdd = useCallback(() => { const handleQuickAdd = useCallback(() => {
importMainModel({ source: result.path, config: undefined }) installModel({ source: result.path })
.unwrap() .unwrap()
.then((_) => { .then((_) => {
dispatch( dispatch(
@ -42,7 +42,7 @@ export const ScanModelResultItem = ({ result }: Props) => {
); );
} }
}); });
}, [importMainModel, result, dispatch, t]); }, [installModel, result, dispatch, t]);
return ( return (
<Flex justifyContent="space-between"> <Flex justifyContent="space-between">

View File

@ -6,7 +6,7 @@ import { t } from 'i18next';
import { useCallback } from 'react'; import { useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form'; import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form'; import { useForm } from 'react-hook-form';
import { useImportMainModelsMutation } from 'services/api/endpoints/models'; import { useInstallModelMutation } from 'services/api/endpoints/models';
type SimpleImportModelConfig = { type SimpleImportModelConfig = {
location: string; location: string;
@ -15,7 +15,7 @@ type SimpleImportModelConfig = {
export const SimpleImport = () => { export const SimpleImport = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [importMainModel, { isLoading }] = useImportMainModelsMutation(); const [installModel, { isLoading }] = useInstallModelMutation();
const { register, handleSubmit, formState, reset } = useForm<SimpleImportModelConfig>({ const { register, handleSubmit, formState, reset } = useForm<SimpleImportModelConfig>({
defaultValues: { defaultValues: {
@ -30,7 +30,7 @@ export const SimpleImport = () => {
return; return;
} }
importMainModel({ source: values.location, config: undefined }) installModel({ source: values.location })
.unwrap() .unwrap()
.then((_) => { .then((_) => {
dispatch( dispatch(
@ -57,7 +57,7 @@ export const SimpleImport = () => {
} }
}); });
}, },
[dispatch, reset, importMainModel] [dispatch, reset, installModel]
); );
return ( return (

View File

@ -1,6 +1,7 @@
import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit'; import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
import { createEntityAdapter } from '@reduxjs/toolkit'; import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector'; import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import type { JSONObject } from 'common/types';
import queryString from 'query-string'; import queryString from 'query-string';
import type { operations, paths } from 'services/api/schema'; import type { operations, paths } from 'services/api/schema';
import type { import type {
@ -19,8 +20,8 @@ import type { ApiTagDescription, tagTypes } from '..';
import { api, buildV2Url, LIST_TAG } from '..'; import { api, buildV2Url, LIST_TAG } from '..';
type UpdateModelArg = { type UpdateModelArg = {
key: NonNullable<operations['update_model_record']['parameters']['path']['key']>; key: paths['/api/v2/models/i/{key}']['patch']['parameters']['path']['key'];
body: NonNullable<operations['update_model_record']['requestBody']['content']['application/json']>; body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
}; };
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json']; type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
@ -40,14 +41,15 @@ type DeleteMainModelResponse = void;
type ConvertMainModelResponse = type ConvertMainModelResponse =
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json']; paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
type ImportMainModelArg = { type InstallModelArg = {
source: NonNullable<operations['heuristic_install_model']['parameters']['query']['source']>; source: paths['/api/v2/models/install']['post']['parameters']['query']['source'];
access_token?: operations['heuristic_install_model']['parameters']['query']['access_token']; access_token?: paths['/api/v2/models/install']['post']['parameters']['query']['access_token'];
config: NonNullable<operations['heuristic_install_model']['requestBody']['content']['application/json']>; // TODO(MM2): This is typed as `Optional[Dict[str, Any]]` in backend...
config?: JSONObject;
// config: NonNullable<paths['/api/v2/models/heuristic_install']['post']['requestBody']>['content']['application/json'];
}; };
type ImportMainModelResponse = type InstallModelResponse = paths['/api/v2/models/install']['post']['responses']['201']['content']['application/json'];
paths['/api/v2/models/heuristic_install']['post']['responses']['201']['content']['application/json'];
type ListImportModelsResponse = type ListImportModelsResponse =
paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json']; paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json'];
@ -58,14 +60,6 @@ type DeleteImportModelsResponse =
type PruneModelImportsResponse = type PruneModelImportsResponse =
paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json']; paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json'];
type ImportAdvancedModelArg = {
source: NonNullable<operations['import_model']['requestBody']['content']['application/json']['source']>;
config: NonNullable<operations['import_model']['requestBody']['content']['application/json']['config']>;
};
type ImportAdvancedModelResponse =
paths['/api/v2/models/import']['post']['responses']['201']['content']['application/json'];
export type ScanFolderResponse = export type ScanFolderResponse =
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json']; paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
type ScanFolderArg = operations['scan_for_models']['parameters']['query']; type ScanFolderArg = operations['scan_for_models']['parameters']['query'];
@ -183,7 +177,7 @@ export const modelsApi = api.injectEndpoints({
}, },
invalidatesTags: ['Model'], invalidatesTags: ['Model'],
}), }),
importMainModels: build.mutation<ImportMainModelResponse, ImportMainModelArg>({ installModel: build.mutation<InstallModelResponse, InstallModelArg>({
query: ({ source, config, access_token }) => { query: ({ source, config, access_token }) => {
return { return {
url: buildModelsUrl('heuristic_install'), url: buildModelsUrl('heuristic_install'),
@ -194,16 +188,6 @@ export const modelsApi = api.injectEndpoints({
}, },
invalidatesTags: ['Model', 'ModelImports'], invalidatesTags: ['Model', 'ModelImports'],
}), }),
importAdvancedModel: build.mutation<ImportAdvancedModelResponse, ImportAdvancedModelArg>({
query: ({ source, config }) => {
return {
url: buildModelsUrl('install'),
method: 'POST',
body: { source, config },
};
},
invalidatesTags: ['Model', 'ModelImports'],
}),
deleteModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({ deleteModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({
query: ({ key }) => { query: ({ key }) => {
return { return {
@ -365,8 +349,7 @@ export const {
useGetVaeModelsQuery, useGetVaeModelsQuery,
useDeleteModelsMutation, useDeleteModelsMutation,
useUpdateModelsMutation, useUpdateModelsMutation,
useImportMainModelsMutation, useInstallModelMutation,
useImportAdvancedModelMutation,
useConvertMainModelsMutation, useConvertMainModelsMutation,
useSyncModelsMutation, useSyncModelsMutation,
useScanModelsQuery, useScanModelsQuery,

View File

@ -70,6 +70,7 @@ export type T2IAdapterModelConfig = S['T2IConfig'];
export type TextualInversionModelConfig = S['TextualInversionConfig']; export type TextualInversionModelConfig = S['TextualInversionConfig'];
export type DiffusersModelConfig = S['MainDiffusersConfig']; export type DiffusersModelConfig = S['MainDiffusersConfig'];
export type CheckpointModelConfig = S['MainCheckpointConfig']; export type CheckpointModelConfig = S['MainCheckpointConfig'];
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig; export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type RefinerMainModelConfig = Omit<MainModelConfig, 'base'> & { base: 'sdxl-refiner' }; export type RefinerMainModelConfig = Omit<MainModelConfig, 'base'> & { base: 'sdxl-refiner' };
export type NonRefinerMainModelConfig = Omit<MainModelConfig, 'base'> & { base: 'any' | 'sd-1' | 'sd-2' | 'sdxl' }; export type NonRefinerMainModelConfig = Omit<MainModelConfig, 'base'> & { base: 'any' | 'sd-1' | 'sd-2' | 'sdxl' };
@ -81,7 +82,18 @@ export type AnyModelConfig =
| T2IAdapterModelConfig | T2IAdapterModelConfig
| TextualInversionModelConfig | TextualInversionModelConfig
| RefinerMainModelConfig | RefinerMainModelConfig
| NonRefinerMainModelConfig; | NonRefinerMainModelConfig
| CLIPVisionDiffusersConfig;
type AnyModelConfig2 =
| (S['MainDiffusersConfig'] | S['MainCheckpointConfig'])
| (S['VaeDiffusersConfig'] | S['VaeCheckpointConfig'])
| (S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'])
| S['LoRAConfig']
| S['TextualInversionConfig']
| S['IPAdapterConfig']
| S['CLIPVisionDiffusersConfig']
| S['T2IConfig'];
export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelConfig => { export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelConfig => {
return config.type === 'lora'; return config.type === 'lora';