mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
wip: Model Search
Going to rework the whole thing. The old system is convoluted and too difficult to plug back.
This commit is contained in:
parent
558c26d78f
commit
4a2f34f77f
@ -0,0 +1,34 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGetModelsInFolderQuery } from 'services/api/endpoints/models';
|
||||
|
||||
export default function FoundModelsList() {
|
||||
const searchFolder = useAppSelector(
|
||||
(state: RootState) => state.modelmanager.searchFolder
|
||||
);
|
||||
|
||||
const { data: foundModels } = useGetModelsInFolderQuery({
|
||||
search_path: searchFolder ? searchFolder : '',
|
||||
});
|
||||
|
||||
console.log(foundModels);
|
||||
|
||||
const renderFoundModels = () => {
|
||||
if (!searchFolder) return;
|
||||
|
||||
if (!foundModels || foundModels.length === 0) {
|
||||
return <Flex>No Models Found</Flex>;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex sx={{ flexDirection: 'column' }}>
|
||||
{foundModels.map((model) => (
|
||||
<Flex key={model}>{model}</Flex>
|
||||
))}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
return renderFoundModels();
|
||||
}
|
@ -0,0 +1,114 @@
|
||||
import { Flex, Text } from '@chakra-ui/react';
|
||||
import { useForm } from '@mantine/form';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaSearch, FaTrash } from 'react-icons/fa';
|
||||
import { setSearchFolder } from '../../store/modelManagerSlice';
|
||||
|
||||
type SearchFolderForm = {
|
||||
folder: string;
|
||||
};
|
||||
|
||||
function SearchFolderForm() {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const searchFolder = useAppSelector(
|
||||
(state: RootState) => state.modelmanager.searchFolder
|
||||
);
|
||||
|
||||
const searchFolderForm = useForm<SearchFolderForm>({
|
||||
initialValues: {
|
||||
folder: '',
|
||||
},
|
||||
});
|
||||
|
||||
const searchFolderFormSubmitHandler = useCallback(
|
||||
(values: SearchFolderForm) => {
|
||||
dispatch(setSearchFolder(values.folder));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<form
|
||||
onSubmit={searchFolderForm.onSubmit((values) =>
|
||||
searchFolderFormSubmitHandler(values)
|
||||
)}
|
||||
style={{ width: '100%' }}
|
||||
>
|
||||
<Flex
|
||||
sx={{
|
||||
w: '100%',
|
||||
gap: 2,
|
||||
p: 4,
|
||||
borderRadius: 4,
|
||||
background: 'base.400',
|
||||
alignItems: 'center',
|
||||
_dark: {
|
||||
background: 'base.800',
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Flex w="100%" alignItems="center" gap={4} minH={12}>
|
||||
<Text
|
||||
sx={{
|
||||
fontWeight: 600,
|
||||
color: 'base.700',
|
||||
minW: 'max-content',
|
||||
_dark: { color: 'base.300' },
|
||||
}}
|
||||
>
|
||||
Search Folder
|
||||
</Text>
|
||||
{!searchFolder ? (
|
||||
<IAIInput
|
||||
w="100%"
|
||||
size="md"
|
||||
{...searchFolderForm.getInputProps('folder')}
|
||||
/>
|
||||
) : (
|
||||
<Flex
|
||||
sx={{
|
||||
w: '100%',
|
||||
p: 2,
|
||||
px: 4,
|
||||
bg: 'base.200',
|
||||
borderRadius: 4,
|
||||
_dark: { bg: 'base.700' },
|
||||
}}
|
||||
>
|
||||
{searchFolder}
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
|
||||
<Flex gap={2}>
|
||||
<IAIIconButton
|
||||
aria-label={t('modelManager.scanAgain')}
|
||||
tooltip={t('modelManager.scanAgain')}
|
||||
icon={<FaSearch />}
|
||||
fontSize={18}
|
||||
size="sm"
|
||||
type="submit"
|
||||
/>
|
||||
<IAIIconButton
|
||||
aria-label={t('modelManager.clearCheckpointFolder')}
|
||||
tooltip={t('modelManager.clearCheckpointFolder')}
|
||||
icon={<FaTrash />}
|
||||
size="sm"
|
||||
onClick={() => dispatch(setSearchFolder(null))}
|
||||
isDisabled={!searchFolder}
|
||||
colorScheme="red"
|
||||
/>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</form>
|
||||
);
|
||||
}
|
||||
|
||||
export default memo(SearchFolderForm);
|
@ -1,430 +1,17 @@
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import React from 'react';
|
||||
|
||||
import {
|
||||
Badge,
|
||||
Flex,
|
||||
FormControl,
|
||||
HStack,
|
||||
Radio,
|
||||
RadioGroup,
|
||||
Spacer,
|
||||
Text,
|
||||
} from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { FaSearch, FaTrash } from 'react-icons/fa';
|
||||
|
||||
// import { addNewModel, searchForModels } from 'app/socketio/actions';
|
||||
import {
|
||||
setFoundModels,
|
||||
setSearchFolder,
|
||||
} from 'features/system/store/systemSlice';
|
||||
import { setShouldShowExistingModelsInSearch } from 'features/ui/store/uiSlice';
|
||||
|
||||
import type { FoundModel } from 'app/types/invokeai';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import { Field, Formik } from 'formik';
|
||||
import { forEach, remove } from 'lodash-es';
|
||||
import type { ChangeEvent, ReactNode } from 'react';
|
||||
import IAIForm from 'common/components/IAIForm';
|
||||
|
||||
const existingModelsSelector = createSelector([systemSelector], (system) => {
|
||||
const { model_list } = system;
|
||||
|
||||
const existingModels: string[] = [];
|
||||
|
||||
forEach(model_list, (value) => {
|
||||
existingModels.push(value.weights);
|
||||
});
|
||||
|
||||
return existingModels;
|
||||
});
|
||||
|
||||
interface SearchModelEntry {
|
||||
model: FoundModel;
|
||||
modelsToAdd: string[];
|
||||
setModelsToAdd: React.Dispatch<React.SetStateAction<string[]>>;
|
||||
}
|
||||
|
||||
function SearchModelEntry({
|
||||
model,
|
||||
modelsToAdd,
|
||||
setModelsToAdd,
|
||||
}: SearchModelEntry) {
|
||||
const { t } = useTranslation();
|
||||
const existingModels = useAppSelector(existingModelsSelector);
|
||||
|
||||
const foundModelsChangeHandler = (e: ChangeEvent<HTMLInputElement>) => {
|
||||
if (!modelsToAdd.includes(e.target.value)) {
|
||||
setModelsToAdd([...modelsToAdd, e.target.value]);
|
||||
} else {
|
||||
setModelsToAdd(remove(modelsToAdd, (v) => v !== e.target.value));
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Flex
|
||||
flexDirection="column"
|
||||
gap={2}
|
||||
backgroundColor={
|
||||
modelsToAdd.includes(model.name) ? 'accent.650' : 'base.800'
|
||||
}
|
||||
paddingX={4}
|
||||
paddingY={2}
|
||||
borderRadius={4}
|
||||
>
|
||||
<Flex gap={4} alignItems="center" justifyContent="space-between">
|
||||
<IAISimpleCheckbox
|
||||
value={model.name}
|
||||
label={<Text fontWeight={500}>{model.name}</Text>}
|
||||
isChecked={modelsToAdd.includes(model.name)}
|
||||
isDisabled={existingModels.includes(model.location)}
|
||||
onChange={foundModelsChangeHandler}
|
||||
></IAISimpleCheckbox>
|
||||
{existingModels.includes(model.location) && (
|
||||
<Badge colorScheme="accent">{t('modelManager.modelExists')}</Badge>
|
||||
)}
|
||||
</Flex>
|
||||
<Text fontStyle="italic" variant="subtext">
|
||||
{model.location}
|
||||
</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
import FoundModelsList from './FoundModelsList';
|
||||
import SearchFolderForm from './SearchFolderForm';
|
||||
|
||||
export default function SearchModels() {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const searchFolder = useAppSelector(
|
||||
(state: RootState) => state.system.searchFolder
|
||||
);
|
||||
|
||||
const foundModels = useAppSelector(
|
||||
(state: RootState) => state.system.foundModels
|
||||
);
|
||||
|
||||
const existingModels = useAppSelector(existingModelsSelector);
|
||||
|
||||
const shouldShowExistingModelsInSearch = useAppSelector(
|
||||
(state: RootState) => state.ui.shouldShowExistingModelsInSearch
|
||||
);
|
||||
|
||||
const isProcessing = useAppSelector(
|
||||
(state: RootState) => state.system.isProcessing
|
||||
);
|
||||
|
||||
const [modelsToAdd, setModelsToAdd] = React.useState<string[]>([]);
|
||||
const [modelType, setModelType] = React.useState<string>('v1');
|
||||
const [pathToConfig, setPathToConfig] = React.useState<string>('');
|
||||
|
||||
const resetSearchModelHandler = () => {
|
||||
dispatch(setSearchFolder(null));
|
||||
dispatch(setFoundModels(null));
|
||||
setModelsToAdd([]);
|
||||
};
|
||||
|
||||
const findModelsHandler = (values: { checkpointFolder: string }) => {
|
||||
dispatch(searchForModels(values.checkpointFolder));
|
||||
};
|
||||
|
||||
const addAllToSelected = () => {
|
||||
setModelsToAdd([]);
|
||||
if (foundModels) {
|
||||
foundModels.forEach((model) => {
|
||||
if (!existingModels.includes(model.location)) {
|
||||
setModelsToAdd((currentModels) => {
|
||||
return [...currentModels, model.name];
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const removeAllFromSelected = () => {
|
||||
setModelsToAdd([]);
|
||||
};
|
||||
|
||||
const addSelectedModels = () => {
|
||||
const modelsToBeAdded = foundModels?.filter((foundModel) =>
|
||||
modelsToAdd.includes(foundModel.name)
|
||||
);
|
||||
|
||||
const configFiles = {
|
||||
v1: 'configs/stable-diffusion/v1-inference.yaml',
|
||||
v2_base: 'configs/stable-diffusion/v2-inference-v.yaml',
|
||||
v2_768: 'configs/stable-diffusion/v2-inference-v.yaml',
|
||||
inpainting: 'configs/stable-diffusion/v1-inpainting-inference.yaml',
|
||||
custom: pathToConfig,
|
||||
};
|
||||
|
||||
modelsToBeAdded?.forEach((model) => {
|
||||
const modelFormat = {
|
||||
name: model.name,
|
||||
description: '',
|
||||
config: configFiles[modelType as keyof typeof configFiles],
|
||||
weights: model.location,
|
||||
vae: '',
|
||||
width: 512,
|
||||
height: 512,
|
||||
default: false,
|
||||
format: 'ckpt',
|
||||
};
|
||||
dispatch(addNewModel(modelFormat));
|
||||
});
|
||||
setModelsToAdd([]);
|
||||
};
|
||||
|
||||
const renderFoundModels = () => {
|
||||
const newFoundModels: ReactNode[] = [];
|
||||
const existingFoundModels: ReactNode[] = [];
|
||||
|
||||
if (foundModels) {
|
||||
foundModels.forEach((model, index) => {
|
||||
if (existingModels.includes(model.location)) {
|
||||
existingFoundModels.push(
|
||||
<SearchModelEntry
|
||||
key={index}
|
||||
model={model}
|
||||
modelsToAdd={modelsToAdd}
|
||||
setModelsToAdd={setModelsToAdd}
|
||||
/>
|
||||
);
|
||||
} else {
|
||||
newFoundModels.push(
|
||||
<SearchModelEntry
|
||||
key={index}
|
||||
model={model}
|
||||
modelsToAdd={modelsToAdd}
|
||||
setModelsToAdd={setModelsToAdd}
|
||||
/>
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" rowGap={4}>
|
||||
{newFoundModels}
|
||||
{shouldShowExistingModelsInSearch && existingFoundModels}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
{searchFolder ? (
|
||||
<Flex
|
||||
sx={{
|
||||
padding: 4,
|
||||
gap: 2,
|
||||
position: 'relative',
|
||||
borderRadius: 'base',
|
||||
alignItems: 'center',
|
||||
w: 'full',
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
<Flex
|
||||
sx={{
|
||||
flexDir: 'column',
|
||||
gap: 2,
|
||||
}}
|
||||
>
|
||||
<Text
|
||||
sx={{
|
||||
fontWeight: 500,
|
||||
}}
|
||||
variant="subtext"
|
||||
>
|
||||
{t('modelManager.checkpointFolder')}
|
||||
</Text>
|
||||
<Text sx={{ fontWeight: 500 }}>{searchFolder}</Text>
|
||||
</Flex>
|
||||
<Spacer />
|
||||
<IAIIconButton
|
||||
aria-label={t('modelManager.scanAgain')}
|
||||
tooltip={t('modelManager.scanAgain')}
|
||||
icon={<FaSearch />}
|
||||
fontSize={18}
|
||||
disabled={isProcessing}
|
||||
onClick={() => dispatch(searchForModels(searchFolder))}
|
||||
/>
|
||||
<IAIIconButton
|
||||
aria-label={t('modelManager.clearCheckpointFolder')}
|
||||
tooltip={t('modelManager.clearCheckpointFolder')}
|
||||
icon={<FaTrash />}
|
||||
onClick={resetSearchModelHandler}
|
||||
/>
|
||||
</Flex>
|
||||
) : (
|
||||
<Formik
|
||||
initialValues={{ checkpointFolder: '' }}
|
||||
onSubmit={(values) => {
|
||||
findModelsHandler(values);
|
||||
}}
|
||||
>
|
||||
{({ handleSubmit }) => (
|
||||
<IAIForm onSubmit={handleSubmit} width="100%">
|
||||
<HStack columnGap={2} alignItems="flex-end">
|
||||
<FormControl flexGrow={1}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="checkpointFolder"
|
||||
name="checkpointFolder"
|
||||
type="text"
|
||||
size="md"
|
||||
label={t('modelManager.checkpointFolder')}
|
||||
/>
|
||||
</FormControl>
|
||||
<IAIButton
|
||||
leftIcon={<FaSearch />}
|
||||
aria-label={t('modelManager.findModels')}
|
||||
tooltip={t('modelManager.findModels')}
|
||||
type="submit"
|
||||
disabled={isProcessing}
|
||||
px={8}
|
||||
>
|
||||
{t('modelManager.findModels')}
|
||||
</IAIButton>
|
||||
</HStack>
|
||||
</IAIForm>
|
||||
)}
|
||||
</Formik>
|
||||
)}
|
||||
{foundModels && (
|
||||
<Flex flexDirection="column" rowGap={4} width="full">
|
||||
<Flex justifyContent="space-between" alignItems="center">
|
||||
<p>
|
||||
{t('modelManager.modelsFound')}: {foundModels.length}
|
||||
</p>
|
||||
<p>
|
||||
{t('modelManager.selected')}: {modelsToAdd.length}
|
||||
</p>
|
||||
</Flex>
|
||||
<Flex columnGap={2} justifyContent="space-between">
|
||||
<Flex columnGap={2}>
|
||||
<IAIButton
|
||||
isDisabled={modelsToAdd.length === foundModels.length}
|
||||
onClick={addAllToSelected}
|
||||
>
|
||||
{t('modelManager.selectAll')}
|
||||
</IAIButton>
|
||||
<IAIButton
|
||||
isDisabled={modelsToAdd.length === 0}
|
||||
onClick={removeAllFromSelected}
|
||||
>
|
||||
{t('modelManager.deselectAll')}
|
||||
</IAIButton>
|
||||
<IAISimpleCheckbox
|
||||
label={t('modelManager.showExisting')}
|
||||
isChecked={shouldShowExistingModelsInSearch}
|
||||
onChange={() =>
|
||||
dispatch(
|
||||
setShouldShowExistingModelsInSearch(
|
||||
!shouldShowExistingModelsInSearch
|
||||
)
|
||||
)
|
||||
}
|
||||
/>
|
||||
</Flex>
|
||||
|
||||
<IAIButton
|
||||
isDisabled={modelsToAdd.length === 0}
|
||||
onClick={addSelectedModels}
|
||||
colorScheme="accent"
|
||||
>
|
||||
{t('modelManager.addSelected')}
|
||||
</IAIButton>
|
||||
</Flex>
|
||||
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
padding: 4,
|
||||
rowGap: 4,
|
||||
borderRadius: 'base',
|
||||
width: 'full',
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
<Flex gap={4}>
|
||||
<Text fontWeight={500} variant="subtext">
|
||||
{t('modelManager.pickModelType')}
|
||||
</Text>
|
||||
<RadioGroup
|
||||
value={modelType}
|
||||
onChange={(v) => setModelType(v)}
|
||||
defaultValue="v1"
|
||||
name="model_type"
|
||||
>
|
||||
<Flex gap={4}>
|
||||
<Radio value="v1">
|
||||
<Text fontSize="sm">{t('modelManager.v1')}</Text>
|
||||
</Radio>
|
||||
<Radio value="v2_base">
|
||||
<Text fontSize="sm">{t('modelManager.v2_base')}</Text>
|
||||
</Radio>
|
||||
<Radio value="v2_768">
|
||||
<Text fontSize="sm">{t('modelManager.v2_768')}</Text>
|
||||
</Radio>
|
||||
<Radio value="inpainting">
|
||||
<Text fontSize="sm">{t('modelManager.inpainting')}</Text>
|
||||
</Radio>
|
||||
<Radio value="custom">
|
||||
<Text fontSize="sm">{t('modelManager.customConfig')}</Text>
|
||||
</Radio>
|
||||
</Flex>
|
||||
</RadioGroup>
|
||||
</Flex>
|
||||
|
||||
{modelType === 'custom' && (
|
||||
<Flex flexDirection="column" rowGap={2}>
|
||||
<Text fontWeight="500" fontSize="sm" variant="subtext">
|
||||
{t('modelManager.pathToCustomConfig')}
|
||||
</Text>
|
||||
<IAIInput
|
||||
value={pathToConfig}
|
||||
onChange={(e) => {
|
||||
if (e.target.value !== '') setPathToConfig(e.target.value);
|
||||
}}
|
||||
width="full"
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
|
||||
<Flex
|
||||
flexDirection="column"
|
||||
maxHeight={72}
|
||||
overflowY="scroll"
|
||||
borderRadius="sm"
|
||||
gap={2}
|
||||
>
|
||||
{foundModels.length > 0 ? (
|
||||
renderFoundModels()
|
||||
) : (
|
||||
<Text
|
||||
fontWeight="500"
|
||||
padding={2}
|
||||
borderRadius="sm"
|
||||
textAlign="center"
|
||||
variant="subtext"
|
||||
>
|
||||
{t('modelManager.noModelsFound')}
|
||||
</Text>
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
)}
|
||||
</>
|
||||
<Flex flexDirection="column" w="100%">
|
||||
<SearchFolderForm />
|
||||
<FoundModelsList />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
@ -0,0 +1,430 @@
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import React from 'react';
|
||||
|
||||
import {
|
||||
Badge,
|
||||
Flex,
|
||||
FormControl,
|
||||
HStack,
|
||||
Radio,
|
||||
RadioGroup,
|
||||
Spacer,
|
||||
Text,
|
||||
} from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { FaSearch, FaTrash } from 'react-icons/fa';
|
||||
|
||||
// import { addNewModel, searchForModels } from 'app/socketio/actions';
|
||||
import {
|
||||
setFoundModels,
|
||||
setSearchFolder,
|
||||
} from 'features/system/store/systemSlice';
|
||||
import { setShouldShowExistingModelsInSearch } from 'features/ui/store/uiSlice';
|
||||
|
||||
import type { FoundModel } from 'app/types/invokeai';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import { Field, Formik } from 'formik';
|
||||
import { forEach, remove } from 'lodash-es';
|
||||
import type { ChangeEvent, ReactNode } from 'react';
|
||||
import IAIForm from 'common/components/IAIForm';
|
||||
|
||||
const existingModelsSelector = createSelector([systemSelector], (system) => {
|
||||
const { model_list } = system;
|
||||
|
||||
const existingModels: string[] = [];
|
||||
|
||||
forEach(model_list, (value) => {
|
||||
existingModels.push(value.weights);
|
||||
});
|
||||
|
||||
return existingModels;
|
||||
});
|
||||
|
||||
interface SearchModelEntry {
|
||||
model: FoundModel;
|
||||
modelsToAdd: string[];
|
||||
setModelsToAdd: React.Dispatch<React.SetStateAction<string[]>>;
|
||||
}
|
||||
|
||||
function SearchModelEntry({
|
||||
model,
|
||||
modelsToAdd,
|
||||
setModelsToAdd,
|
||||
}: SearchModelEntry) {
|
||||
const { t } = useTranslation();
|
||||
const existingModels = useAppSelector(existingModelsSelector);
|
||||
|
||||
const foundModelsChangeHandler = (e: ChangeEvent<HTMLInputElement>) => {
|
||||
if (!modelsToAdd.includes(e.target.value)) {
|
||||
setModelsToAdd([...modelsToAdd, e.target.value]);
|
||||
} else {
|
||||
setModelsToAdd(remove(modelsToAdd, (v) => v !== e.target.value));
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Flex
|
||||
flexDirection="column"
|
||||
gap={2}
|
||||
backgroundColor={
|
||||
modelsToAdd.includes(model.name) ? 'accent.650' : 'base.800'
|
||||
}
|
||||
paddingX={4}
|
||||
paddingY={2}
|
||||
borderRadius={4}
|
||||
>
|
||||
<Flex gap={4} alignItems="center" justifyContent="space-between">
|
||||
<IAISimpleCheckbox
|
||||
value={model.name}
|
||||
label={<Text fontWeight={500}>{model.name}</Text>}
|
||||
isChecked={modelsToAdd.includes(model.name)}
|
||||
isDisabled={existingModels.includes(model.location)}
|
||||
onChange={foundModelsChangeHandler}
|
||||
></IAISimpleCheckbox>
|
||||
{existingModels.includes(model.location) && (
|
||||
<Badge colorScheme="accent">{t('modelManager.modelExists')}</Badge>
|
||||
)}
|
||||
</Flex>
|
||||
<Text fontStyle="italic" variant="subtext">
|
||||
{model.location}
|
||||
</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
export default function SearchModels() {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const searchFolder = useAppSelector(
|
||||
(state: RootState) => state.system.searchFolder
|
||||
);
|
||||
|
||||
const foundModels = useAppSelector(
|
||||
(state: RootState) => state.system.foundModels
|
||||
);
|
||||
|
||||
const existingModels = useAppSelector(existingModelsSelector);
|
||||
|
||||
const shouldShowExistingModelsInSearch = useAppSelector(
|
||||
(state: RootState) => state.ui.shouldShowExistingModelsInSearch
|
||||
);
|
||||
|
||||
const isProcessing = useAppSelector(
|
||||
(state: RootState) => state.system.isProcessing
|
||||
);
|
||||
|
||||
const [modelsToAdd, setModelsToAdd] = React.useState<string[]>([]);
|
||||
const [modelType, setModelType] = React.useState<string>('v1');
|
||||
const [pathToConfig, setPathToConfig] = React.useState<string>('');
|
||||
|
||||
const resetSearchModelHandler = () => {
|
||||
dispatch(setSearchFolder(null));
|
||||
dispatch(setFoundModels(null));
|
||||
setModelsToAdd([]);
|
||||
};
|
||||
|
||||
const findModelsHandler = (values: { checkpointFolder: string }) => {
|
||||
dispatch(searchForModels(values.checkpointFolder));
|
||||
};
|
||||
|
||||
const addAllToSelected = () => {
|
||||
setModelsToAdd([]);
|
||||
if (foundModels) {
|
||||
foundModels.forEach((model) => {
|
||||
if (!existingModels.includes(model.location)) {
|
||||
setModelsToAdd((currentModels) => {
|
||||
return [...currentModels, model.name];
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const removeAllFromSelected = () => {
|
||||
setModelsToAdd([]);
|
||||
};
|
||||
|
||||
const addSelectedModels = () => {
|
||||
const modelsToBeAdded = foundModels?.filter((foundModel) =>
|
||||
modelsToAdd.includes(foundModel.name)
|
||||
);
|
||||
|
||||
const configFiles = {
|
||||
v1: 'configs/stable-diffusion/v1-inference.yaml',
|
||||
v2_base: 'configs/stable-diffusion/v2-inference-v.yaml',
|
||||
v2_768: 'configs/stable-diffusion/v2-inference-v.yaml',
|
||||
inpainting: 'configs/stable-diffusion/v1-inpainting-inference.yaml',
|
||||
custom: pathToConfig,
|
||||
};
|
||||
|
||||
modelsToBeAdded?.forEach((model) => {
|
||||
const modelFormat = {
|
||||
name: model.name,
|
||||
description: '',
|
||||
config: configFiles[modelType as keyof typeof configFiles],
|
||||
weights: model.location,
|
||||
vae: '',
|
||||
width: 512,
|
||||
height: 512,
|
||||
default: false,
|
||||
format: 'ckpt',
|
||||
};
|
||||
dispatch(addNewModel(modelFormat));
|
||||
});
|
||||
setModelsToAdd([]);
|
||||
};
|
||||
|
||||
const renderFoundModels = () => {
|
||||
const newFoundModels: ReactNode[] = [];
|
||||
const existingFoundModels: ReactNode[] = [];
|
||||
|
||||
if (foundModels) {
|
||||
foundModels.forEach((model, index) => {
|
||||
if (existingModels.includes(model.location)) {
|
||||
existingFoundModels.push(
|
||||
<SearchModelEntry
|
||||
key={index}
|
||||
model={model}
|
||||
modelsToAdd={modelsToAdd}
|
||||
setModelsToAdd={setModelsToAdd}
|
||||
/>
|
||||
);
|
||||
} else {
|
||||
newFoundModels.push(
|
||||
<SearchModelEntry
|
||||
key={index}
|
||||
model={model}
|
||||
modelsToAdd={modelsToAdd}
|
||||
setModelsToAdd={setModelsToAdd}
|
||||
/>
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" rowGap={4}>
|
||||
{newFoundModels}
|
||||
{shouldShowExistingModelsInSearch && existingFoundModels}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
{searchFolder ? (
|
||||
<Flex
|
||||
sx={{
|
||||
padding: 4,
|
||||
gap: 2,
|
||||
position: 'relative',
|
||||
borderRadius: 'base',
|
||||
alignItems: 'center',
|
||||
w: 'full',
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
<Flex
|
||||
sx={{
|
||||
flexDir: 'column',
|
||||
gap: 2,
|
||||
}}
|
||||
>
|
||||
<Text
|
||||
sx={{
|
||||
fontWeight: 500,
|
||||
}}
|
||||
variant="subtext"
|
||||
>
|
||||
{t('modelManager.checkpointFolder')}
|
||||
</Text>
|
||||
<Text sx={{ fontWeight: 500 }}>{searchFolder}</Text>
|
||||
</Flex>
|
||||
<Spacer />
|
||||
<IAIIconButton
|
||||
aria-label={t('modelManager.scanAgain')}
|
||||
tooltip={t('modelManager.scanAgain')}
|
||||
icon={<FaSearch />}
|
||||
fontSize={18}
|
||||
disabled={isProcessing}
|
||||
onClick={() => dispatch(searchForModels(searchFolder))}
|
||||
/>
|
||||
<IAIIconButton
|
||||
aria-label={t('modelManager.clearCheckpointFolder')}
|
||||
tooltip={t('modelManager.clearCheckpointFolder')}
|
||||
icon={<FaTrash />}
|
||||
onClick={resetSearchModelHandler}
|
||||
/>
|
||||
</Flex>
|
||||
) : (
|
||||
<Formik
|
||||
initialValues={{ checkpointFolder: '' }}
|
||||
onSubmit={(values) => {
|
||||
findModelsHandler(values);
|
||||
}}
|
||||
>
|
||||
{({ handleSubmit }) => (
|
||||
<IAIForm onSubmit={handleSubmit} width="100%">
|
||||
<HStack columnGap={2} alignItems="flex-end">
|
||||
<FormControl flexGrow={1}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="checkpointFolder"
|
||||
name="checkpointFolder"
|
||||
type="text"
|
||||
size="md"
|
||||
label={t('modelManager.checkpointFolder')}
|
||||
/>
|
||||
</FormControl>
|
||||
<IAIButton
|
||||
leftIcon={<FaSearch />}
|
||||
aria-label={t('modelManager.findModels')}
|
||||
tooltip={t('modelManager.findModels')}
|
||||
type="submit"
|
||||
disabled={isProcessing}
|
||||
px={8}
|
||||
>
|
||||
{t('modelManager.findModels')}
|
||||
</IAIButton>
|
||||
</HStack>
|
||||
</IAIForm>
|
||||
)}
|
||||
</Formik>
|
||||
)}
|
||||
{foundModels && (
|
||||
<Flex flexDirection="column" rowGap={4} width="full">
|
||||
<Flex justifyContent="space-between" alignItems="center">
|
||||
<p>
|
||||
{t('modelManager.modelsFound')}: {foundModels.length}
|
||||
</p>
|
||||
<p>
|
||||
{t('modelManager.selected')}: {modelsToAdd.length}
|
||||
</p>
|
||||
</Flex>
|
||||
<Flex columnGap={2} justifyContent="space-between">
|
||||
<Flex columnGap={2}>
|
||||
<IAIButton
|
||||
isDisabled={modelsToAdd.length === foundModels.length}
|
||||
onClick={addAllToSelected}
|
||||
>
|
||||
{t('modelManager.selectAll')}
|
||||
</IAIButton>
|
||||
<IAIButton
|
||||
isDisabled={modelsToAdd.length === 0}
|
||||
onClick={removeAllFromSelected}
|
||||
>
|
||||
{t('modelManager.deselectAll')}
|
||||
</IAIButton>
|
||||
<IAISimpleCheckbox
|
||||
label={t('modelManager.showExisting')}
|
||||
isChecked={shouldShowExistingModelsInSearch}
|
||||
onChange={() =>
|
||||
dispatch(
|
||||
setShouldShowExistingModelsInSearch(
|
||||
!shouldShowExistingModelsInSearch
|
||||
)
|
||||
)
|
||||
}
|
||||
/>
|
||||
</Flex>
|
||||
|
||||
<IAIButton
|
||||
isDisabled={modelsToAdd.length === 0}
|
||||
onClick={addSelectedModels}
|
||||
colorScheme="accent"
|
||||
>
|
||||
{t('modelManager.addSelected')}
|
||||
</IAIButton>
|
||||
</Flex>
|
||||
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
padding: 4,
|
||||
rowGap: 4,
|
||||
borderRadius: 'base',
|
||||
width: 'full',
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
<Flex gap={4}>
|
||||
<Text fontWeight={500} variant="subtext">
|
||||
{t('modelManager.pickModelType')}
|
||||
</Text>
|
||||
<RadioGroup
|
||||
value={modelType}
|
||||
onChange={(v) => setModelType(v)}
|
||||
defaultValue="v1"
|
||||
name="model_type"
|
||||
>
|
||||
<Flex gap={4}>
|
||||
<Radio value="v1">
|
||||
<Text fontSize="sm">{t('modelManager.v1')}</Text>
|
||||
</Radio>
|
||||
<Radio value="v2_base">
|
||||
<Text fontSize="sm">{t('modelManager.v2_base')}</Text>
|
||||
</Radio>
|
||||
<Radio value="v2_768">
|
||||
<Text fontSize="sm">{t('modelManager.v2_768')}</Text>
|
||||
</Radio>
|
||||
<Radio value="inpainting">
|
||||
<Text fontSize="sm">{t('modelManager.inpainting')}</Text>
|
||||
</Radio>
|
||||
<Radio value="custom">
|
||||
<Text fontSize="sm">{t('modelManager.customConfig')}</Text>
|
||||
</Radio>
|
||||
</Flex>
|
||||
</RadioGroup>
|
||||
</Flex>
|
||||
|
||||
{modelType === 'custom' && (
|
||||
<Flex flexDirection="column" rowGap={2}>
|
||||
<Text fontWeight="500" fontSize="sm" variant="subtext">
|
||||
{t('modelManager.pathToCustomConfig')}
|
||||
</Text>
|
||||
<IAIInput
|
||||
value={pathToConfig}
|
||||
onChange={(e) => {
|
||||
if (e.target.value !== '') setPathToConfig(e.target.value);
|
||||
}}
|
||||
width="full"
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
|
||||
<Flex
|
||||
flexDirection="column"
|
||||
maxHeight={72}
|
||||
overflowY="scroll"
|
||||
borderRadius="sm"
|
||||
gap={2}
|
||||
>
|
||||
{foundModels.length > 0 ? (
|
||||
renderFoundModels()
|
||||
) : (
|
||||
<Text
|
||||
fontWeight="500"
|
||||
padding={2}
|
||||
borderRadius="sm"
|
||||
textAlign="center"
|
||||
variant="subtext"
|
||||
>
|
||||
{t('modelManager.noModelsFound')}
|
||||
</Text>
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
@ -14,8 +14,9 @@ import {
|
||||
VaeModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
import queryString from 'query-string';
|
||||
import { ApiFullTagDescription, LIST_TAG, api } from '..';
|
||||
import { paths } from '../schema';
|
||||
import { operations, paths } from '../schema';
|
||||
|
||||
export type DiffusersModelConfigEntity = DiffusersModelConfig & { id: string };
|
||||
export type CheckpointModelConfigEntity = CheckpointModelConfig & {
|
||||
@ -77,6 +78,11 @@ type MergeMainModelArg = {
|
||||
type MergeMainModelResponse =
|
||||
paths['/api/v1/models/merge/{base_model}']['put']['responses']['200']['content']['application/json'];
|
||||
|
||||
type SearchFolderResponse =
|
||||
paths['/api/v1/models/search']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
type SearchFolderArg = operations['search_for_models']['parameters']['query'];
|
||||
|
||||
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
});
|
||||
@ -331,6 +337,14 @@ export const modelsApi = api.injectEndpoints({
|
||||
);
|
||||
},
|
||||
}),
|
||||
getModelsInFolder: build.query<SearchFolderResponse, SearchFolderArg>({
|
||||
query: (arg) => {
|
||||
const folderQueryStr = queryString.stringify(arg, {});
|
||||
return {
|
||||
url: `/models/search?${folderQueryStr}`,
|
||||
};
|
||||
},
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
@ -344,4 +358,5 @@ export const {
|
||||
useDeleteMainModelsMutation,
|
||||
useConvertMainModelsMutation,
|
||||
useMergeMainModelsMutation,
|
||||
useGetModelsInFolderQuery,
|
||||
} = modelsApi;
|
||||
|
@ -4655,18 +4655,18 @@ export type components = {
|
||||
*/
|
||||
image?: components["schemas"]["ImageField"];
|
||||
};
|
||||
/**
|
||||
* StableDiffusion1ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusion2ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusion1ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||
};
|
||||
responses: never;
|
||||
parameters: never;
|
||||
|
@ -58,6 +58,7 @@ export type AnyModelConfig =
|
||||
|
||||
export type MergeModelConfig = components['schemas']['Body_merge_models'];
|
||||
export type ConvertModelConfig = components['schemas']['Body_convert_model'];
|
||||
export type SearchFolderConfig = components['schemas'];
|
||||
|
||||
// Graphs
|
||||
export type Graph = components['schemas']['Graph'];
|
||||
|
Loading…
Reference in New Issue
Block a user