feat: Scan models add to differentiate between ckpt and diffusers

This commit is contained in:
blessedcoolant 2023-07-17 19:40:08 +12:00
parent f398fe4136
commit cfdaa30d44
4 changed files with 66 additions and 11 deletions

View File

@ -11,18 +11,23 @@ import { DiffusersModelConfig } from 'services/api/types';
import BaseModelSelect from '../shared/BaseModelSelect';
import ModelVariantSelect from '../shared/ModelVariantSelect';
export default function AdvancedAddDiffusers() {
type AdvancedAddDiffusersProps = {
model_path?: string;
};
export default function AdvancedAddDiffusers(props: AdvancedAddDiffusersProps) {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { model_path } = props;
const [addMainModel] = useAddMainModelsMutation();
const advancedAddDiffusersForm = useForm<DiffusersModelConfig>({
initialValues: {
model_name: '',
model_name: model_path ? model_path.split('\\').splice(-1)[0] : '',
base_model: 'sd-1',
model_type: 'main',
path: '',
path: model_path ? model_path : '',
description: '',
model_format: 'diffusers',
error: undefined,
@ -30,6 +35,7 @@ export default function AdvancedAddDiffusers() {
variant: 'normal',
},
});
const advancedAddDiffusersFormHandler = (values: DiffusersModelConfig) => {
addMainModel({
body: values,

View File

@ -5,12 +5,12 @@ import { useState } from 'react';
import AdvancedAddCheckpoint from './AdvancedAddCheckpoint';
import AdvancedAddDiffusers from './AdvancedAddDiffusers';
const advancedAddModeData: SelectItem[] = [
export const advancedAddModeData: SelectItem[] = [
{ label: 'Diffusers', value: 'diffusers' },
{ label: 'Checkpoint / Safetensors', value: 'checkpoint' },
];
type ManualAddMode = 'diffusers' | 'checkpoint';
export type ManualAddMode = 'diffusers' | 'checkpoint';
export default function AdvancedAddModels() {
const [advancedAddMode, setAdvancedAddMode] =

View File

@ -2,16 +2,36 @@ import { Box, Flex, Text } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { motion } from 'framer-motion';
import { useEffect, useState } from 'react';
import { FaTimes } from 'react-icons/fa';
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
import AdvancedAddCheckpoint from './AdvancedAddCheckpoint';
import AdvancedAddDiffusers from './AdvancedAddDiffusers';
import { ManualAddMode, advancedAddModeData } from './AdvancedAddModels';
export default function ScanAdvancedAddModels() {
const advancedAddScanModel = useAppSelector(
(state: RootState) => state.modelmanager.advancedAddScanModel
);
const [advancedAddMode, setAdvancedAddMode] =
useState<ManualAddMode>('diffusers');
const [isCheckpoint, setIsCheckpoint] = useState(
advancedAddScanModel &&
['.ckpt', '.safetensors', '.pth', '.pt'].some((ext) =>
advancedAddScanModel.endsWith(ext)
)
);
useEffect(() => {
isCheckpoint
? setAdvancedAddMode('checkpoint')
: setAdvancedAddMode('diffusers');
}, [setAdvancedAddMode, isCheckpoint]);
const dispatch = useAppDispatch();
return (
@ -37,7 +57,9 @@ export default function ScanAdvancedAddModels() {
>
<Flex justifyContent="space-between" alignItems="center">
<Text size="xl" fontWeight={600}>
Add Checkpoint Model
{isCheckpoint || advancedAddMode === 'checkpoint'
? 'Add Checkpoint Model'
: 'Add Diffusers Model'}
</Text>
<IAIIconButton
icon={<FaTimes />}
@ -46,10 +68,31 @@ export default function ScanAdvancedAddModels() {
size="sm"
/>
</Flex>
<AdvancedAddCheckpoint
key={advancedAddScanModel}
model_path={advancedAddScanModel}
<IAIMantineSelect
label="Model Type"
value={advancedAddMode}
data={advancedAddModeData}
onChange={(v) => {
if (!v) return;
setAdvancedAddMode(v as ManualAddMode);
if (v === 'checkpoint') {
setIsCheckpoint(true);
} else {
setIsCheckpoint(false);
}
}}
/>
{isCheckpoint ? (
<AdvancedAddCheckpoint
key={advancedAddScanModel}
model_path={advancedAddScanModel}
/>
) : (
<AdvancedAddDiffusers
key={advancedAddScanModel}
model_path={advancedAddScanModel}
/>
)}
</Box>
)
);

View File

@ -7,7 +7,10 @@ import IAIInput from 'common/components/IAIInput';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { FaSearch, FaSync, FaTrash } from 'react-icons/fa';
import { setSearchFolder } from '../../store/modelManagerSlice';
import {
setAdvancedAddScanModel,
setSearchFolder,
} from '../../store/modelManagerSlice';
type SearchFolderForm = {
folder: string;
@ -101,7 +104,10 @@ function SearchFolderForm() {
tooltip={t('modelManager.clearCheckpointFolder')}
icon={<FaTrash />}
size="sm"
onClick={() => dispatch(setSearchFolder(null))}
onClick={() => {
dispatch(setSearchFolder(null));
dispatch(setAdvancedAddScanModel(null));
}}
isDisabled={!searchFolder}
colorScheme="red"
/>