diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AdvancedAddDiffusers.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AdvancedAddDiffusers.tsx index d7b1561912..7c0bcf0ab1 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AdvancedAddDiffusers.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AdvancedAddDiffusers.tsx @@ -11,18 +11,23 @@ import { DiffusersModelConfig } from 'services/api/types'; import BaseModelSelect from '../shared/BaseModelSelect'; import ModelVariantSelect from '../shared/ModelVariantSelect'; -export default function AdvancedAddDiffusers() { +type AdvancedAddDiffusersProps = { + model_path?: string; +}; + +export default function AdvancedAddDiffusers(props: AdvancedAddDiffusersProps) { const { t } = useTranslation(); const dispatch = useAppDispatch(); + const { model_path } = props; const [addMainModel] = useAddMainModelsMutation(); const advancedAddDiffusersForm = useForm({ initialValues: { - model_name: '', + model_name: model_path ? model_path.split('\\').splice(-1)[0] : '', base_model: 'sd-1', model_type: 'main', - path: '', + path: model_path ? model_path : '', description: '', model_format: 'diffusers', error: undefined, @@ -30,6 +35,7 @@ export default function AdvancedAddDiffusers() { variant: 'normal', }, }); + const advancedAddDiffusersFormHandler = (values: DiffusersModelConfig) => { addMainModel({ body: values, diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AdvancedAddModels.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AdvancedAddModels.tsx index 18d576c843..88e83fadc8 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AdvancedAddModels.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AdvancedAddModels.tsx @@ -5,12 +5,12 @@ import { useState } from 'react'; import AdvancedAddCheckpoint from './AdvancedAddCheckpoint'; import AdvancedAddDiffusers from './AdvancedAddDiffusers'; -const advancedAddModeData: SelectItem[] = [ +export const advancedAddModeData: SelectItem[] = [ { label: 'Diffusers', value: 'diffusers' }, { label: 'Checkpoint / Safetensors', value: 'checkpoint' }, ]; -type ManualAddMode = 'diffusers' | 'checkpoint'; +export type ManualAddMode = 'diffusers' | 'checkpoint'; export default function AdvancedAddModels() { const [advancedAddMode, setAdvancedAddMode] = diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/ScanAdvancedAddModels.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/ScanAdvancedAddModels.tsx index 65454c3363..e5b89c7bbf 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/ScanAdvancedAddModels.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/ScanAdvancedAddModels.tsx @@ -2,16 +2,36 @@ import { Box, Flex, Text } from '@chakra-ui/react'; import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIIconButton from 'common/components/IAIIconButton'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { motion } from 'framer-motion'; +import { useEffect, useState } from 'react'; import { FaTimes } from 'react-icons/fa'; import { setAdvancedAddScanModel } from '../../store/modelManagerSlice'; import AdvancedAddCheckpoint from './AdvancedAddCheckpoint'; +import AdvancedAddDiffusers from './AdvancedAddDiffusers'; +import { ManualAddMode, advancedAddModeData } from './AdvancedAddModels'; export default function ScanAdvancedAddModels() { const advancedAddScanModel = useAppSelector( (state: RootState) => state.modelmanager.advancedAddScanModel ); + const [advancedAddMode, setAdvancedAddMode] = + useState('diffusers'); + + const [isCheckpoint, setIsCheckpoint] = useState( + advancedAddScanModel && + ['.ckpt', '.safetensors', '.pth', '.pt'].some((ext) => + advancedAddScanModel.endsWith(ext) + ) + ); + + useEffect(() => { + isCheckpoint + ? setAdvancedAddMode('checkpoint') + : setAdvancedAddMode('diffusers'); + }, [setAdvancedAddMode, isCheckpoint]); + const dispatch = useAppDispatch(); return ( @@ -37,7 +57,9 @@ export default function ScanAdvancedAddModels() { > - Add Checkpoint Model + {isCheckpoint || advancedAddMode === 'checkpoint' + ? 'Add Checkpoint Model' + : 'Add Diffusers Model'} } @@ -46,10 +68,31 @@ export default function ScanAdvancedAddModels() { size="sm" /> - { + if (!v) return; + setAdvancedAddMode(v as ManualAddMode); + if (v === 'checkpoint') { + setIsCheckpoint(true); + } else { + setIsCheckpoint(false); + } + }} /> + {isCheckpoint ? ( + + ) : ( + + )} ) ); diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/SearchFolderForm.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/SearchFolderForm.tsx index dd0aa818c8..255307f6f4 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/SearchFolderForm.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/SearchFolderForm.tsx @@ -7,7 +7,10 @@ import IAIInput from 'common/components/IAIInput'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { FaSearch, FaSync, FaTrash } from 'react-icons/fa'; -import { setSearchFolder } from '../../store/modelManagerSlice'; +import { + setAdvancedAddScanModel, + setSearchFolder, +} from '../../store/modelManagerSlice'; type SearchFolderForm = { folder: string; @@ -101,7 +104,10 @@ function SearchFolderForm() { tooltip={t('modelManager.clearCheckpointFolder')} icon={} size="sm" - onClick={() => dispatch(setSearchFolder(null))} + onClick={() => { + dispatch(setSearchFolder(null)); + dispatch(setAdvancedAddScanModel(null)); + }} isDisabled={!searchFolder} colorScheme="red" />