From 7f56e84a8d1668f5e8b85ad657541883f94a2c9f Mon Sep 17 00:00:00 2001 From: Jennifer Player Date: Fri, 23 Feb 2024 16:16:42 -0500 Subject: [PATCH] refactored and fixed issues with advanced import form --- .../AddModelPanel/AdvancedImport.tsx | 248 +++++++++++++++--- .../AdvancedImportCheckpoint.tsx | 160 ----------- .../AddModelPanel/AdvancedImportDiffusers.tsx | 132 ---------- .../ModelPanel/Fields/ModelFormatSelect.tsx | 12 +- .../web/src/services/api/endpoints/models.ts | 17 +- 5 files changed, 231 insertions(+), 338 deletions(-) delete mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/AdvancedImportCheckpoint.tsx delete mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/AdvancedImportDiffusers.tsx diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/AdvancedImport.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/AdvancedImport.tsx index 9ec39a0649..36af89baec 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/AdvancedImport.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/AdvancedImport.tsx @@ -1,56 +1,238 @@ -import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; -import { Combobox, Flex, FormLabel,Text } from '@invoke-ai/ui-library'; +import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input, Text, Textarea } from '@invoke-ai/ui-library'; +import { useAppDispatch } from 'app/store/storeHooks'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; -import { useCallback, useMemo, useState } from 'react'; +import BaseModelSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect'; +import BooleanSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BooleanSelect'; +import ModelFormatSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect'; +import ModelTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect'; +import ModelVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect'; +import PredictionTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/PredictionTypeSelect'; +import RepoVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect'; +import { addToast } from 'features/system/store/systemSlice'; +import { makeToast } from 'features/system/util/makeToast'; +import { useCallback, useEffect } from 'react'; +import type { SubmitHandler} from 'react-hook-form'; +import {useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; +import { useImportAdvancedModelMutation } from 'services/api/endpoints/models'; +import type { AnyModelConfig } from 'services/api/types'; import { z } from 'zod'; -import { AdvancedImportCheckpoint } from './AdvancedImportCheckpoint'; -import { AdvancedImportDiffusers } from './AdvancedImportDiffusers'; - export const zManualAddMode = z.enum(['diffusers', 'checkpoint']); export type ManualAddMode = z.infer; export const isManualAddMode = (v: unknown): v is ManualAddMode => zManualAddMode.safeParse(v).success; export const AdvancedImport = () => { - const [advancedAddMode, setAdvancedAddMode] = useState('diffusers'); + const dispatch = useAppDispatch(); + + const [importAdvancedModel] = useImportAdvancedModelMutation(); const { t } = useTranslation(); - const handleChange: ComboboxOnChange = useCallback((v) => { - if (!isManualAddMode(v?.value)) { - return; - } - setAdvancedAddMode(v.value); - }, []); - const options: ComboboxOption[] = useMemo( - () => [ - { label: t('modelManager.diffusersModels'), value: 'diffusers' }, - { label: t('modelManager.checkpointOrSafetensors'), value: 'checkpoint' }, - ], - [t] + const { + register, + handleSubmit, + control, + formState: { errors }, + setValue, + resetField, + reset, + watch, + } = useForm({ + defaultValues: { + name: '', + base: 'sd-1', + type: 'main', + path: '', + description: '', + format: 'diffusers', + vae: '', + variant: 'normal', + }, + mode: 'onChange', + }); + + const onSubmit = useCallback>( + (values) => { + const cleanValues = Object.fromEntries( + Object.entries(values).filter(([value]) => value !== null && value !== undefined) + ); + importAdvancedModel({ + source: { + path: cleanValues.path as string, + type: 'local', + }, + config: cleanValues, + }) + .unwrap() + .then((_) => { + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelAdded', { + modelName: values.name, + }), + status: 'success', + }) + ) + ); + reset(); + }) + .catch((error) => { + if (error) { + dispatch( + addToast( + makeToast({ + title: t('toast.modelAddFailed'), + status: 'error', + }) + ) + ); + } + }); + }, + [dispatch, reset, t, importAdvancedModel] ); - const value = useMemo(() => options.find((o) => o.value === advancedAddMode), [options, advancedAddMode]); + const watchedModelType = watch('type'); + const watchedModelFormat = watch('format'); + + useEffect(() => { + if (watchedModelType === 'main') { + setValue('format', 'diffusers'); + setValue('repo_variant', ''); + setValue('variant', 'normal'); + } + if (watchedModelType === 'lora') { + setValue('format', 'lycoris'); + } else if (watchedModelType === 'embedding') { + setValue('format', 'embedding_file'); + } else if (watchedModelType === 'ip_adapter') { + setValue('format', 'invokeai'); + } else { + setValue('format', 'diffusers'); + } + resetField('upcast_attention'); + resetField('ztsnr_training'); + resetField('vae'); + resetField('config'); + resetField('prediction_type'); + resetField('image_encoder_model_id'); + }, [watchedModelType, resetField, setValue]); return ( - - - - {t('modelManager.modelType')} - +
+ + + + Model Type + control={control} name="type" /> + + + {t('modelManager.advancedImportInfo')} + - - {t('modelManager.advancedImportInfo')} - - - - {advancedAddMode === 'diffusers' && } - {advancedAddMode === 'checkpoint' && } + + + + {t('modelManager.name')} + value.trim().length >= 3 || 'Must be at least 3 characters', + })} + /> + {errors.name?.message && {errors.name?.message}} + + + + + + {t('modelManager.description')} +