diff --git a/invokeai/frontend/src/features/system/components/ModelManager/SearchModels.tsx b/invokeai/frontend/src/features/system/components/ModelManager/SearchModels.tsx index fd7da520bd..dfa197f375 100644 --- a/invokeai/frontend/src/features/system/components/ModelManager/SearchModels.tsx +++ b/invokeai/frontend/src/features/system/components/ModelManager/SearchModels.tsx @@ -3,7 +3,16 @@ import IAICheckbox from 'common/components/IAICheckbox'; import IAIIconButton from 'common/components/IAIIconButton'; import React from 'react'; -import { Box, Flex, FormControl, HStack, Text, VStack } from '@chakra-ui/react'; +import { + Box, + Flex, + FormControl, + HStack, + Radio, + RadioGroup, + Text, + VStack, +} from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/storeHooks'; import { systemSelector } from 'features/system/store/systemSelectors'; @@ -135,6 +144,8 @@ export default function SearchModels() { ); const [modelsToAdd, setModelsToAdd] = React.useState([]); + const [modelType, setModelType] = React.useState('v1'); + const [pathToConfig, setPathToConfig] = React.useState(''); const resetSearchModelHandler = () => { dispatch(setSearchFolder(null)); @@ -167,11 +178,19 @@ export default function SearchModels() { const modelsToBeAdded = foundModels?.filter((foundModel) => modelsToAdd.includes(foundModel.name) ); + + const configFiles = { + v1: 'configs/stable-diffusion/v1-inference.yaml', + v2: 'configs/stable-diffusion/v2-inference-v.yaml', + inpainting: 'configs/stable-diffusion/v1-inpainting-inference.yaml', + custom: pathToConfig, + }; + modelsToBeAdded?.forEach((model) => { const modelFormat = { name: model.name, description: '', - config: 'configs/stable-diffusion/v1-inference.yaml', + config: configFiles[modelType as keyof typeof configFiles], weights: model.location, vae: '', width: 512, @@ -346,6 +365,55 @@ export default function SearchModels() { {t('modelmanager:addSelected')} + + + + + Pick Model Type: + + setModelType(v)} + defaultValue="v1" + name="model_type" + > + + {t('modelmanager:v1')} + {t('modelmanager:v2')} + + {t('modelmanager:inpainting')} + + {t('modelmanager:customConfig')} + + + + + {modelType === 'custom' && ( + + + {t('modelmanager:pathToCustomConfig')} + + { + if (e.target.value !== '') setPathToConfig(e.target.value); + }} + width="42.5rem" + /> + + )} + +