feat: Add Manual Checkpoint / Safetensor Models

This commit is contained in:
blessedcoolant 2023-07-16 15:21:49 +12:00
parent 421fcb761b
commit d93d42af4a
6 changed files with 145 additions and 4 deletions

View File

@ -36,6 +36,7 @@ export default function IAIMantineTextInput(props: IAIMantineTextInputProps) {
label: { label: {
color: mode(base700, base300)(colorMode), color: mode(base700, base300)(colorMode),
fontWeight: 'normal', fontWeight: 'normal',
marginBottom: 4,
}, },
})} })}
{...rest} {...rest}

View File

@ -13,7 +13,7 @@ export default function AddModels() {
flexDirection="column" flexDirection="column"
width="100%" width="100%"
overflow="scroll" overflow="scroll"
maxHeight={window.innerHeight - 270} maxHeight={window.innerHeight - 250}
gap={4} gap={4}
> >
<ButtonGroup isAttached> <ButtonGroup isAttached>

View File

@ -1,7 +1,23 @@
import { Flex } from '@chakra-ui/react';
import { useForm } from '@mantine/form'; import { useForm } from '@mantine/form';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import { addToast } from 'features/system/store/systemSlice';
import { useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
import { CheckpointModelConfig } from 'services/api/types'; import { CheckpointModelConfig } from 'services/api/types';
import BaseModelSelect from '../shared/BaseModelSelect';
import CheckpointConfigsSelect from '../shared/CheckpointConfigsSelect';
import ModelVariantSelect from '../shared/ModelVariantSelect';
export default function ManualAddCheckpoint() { export default function ManualAddCheckpoint() {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const manualAddCheckpointForm = useForm<CheckpointModelConfig>({ const manualAddCheckpointForm = useForm<CheckpointModelConfig>({
initialValues: { initialValues: {
model_name: '', model_name: '',
@ -13,8 +29,99 @@ export default function ManualAddCheckpoint() {
error: undefined, error: undefined,
vae: '', vae: '',
variant: 'normal', variant: 'normal',
config: '', config: 'configs\\stable-diffusion\\v1-inference.yaml',
}, },
}); });
return <div>ManualAddCheckpoint</div>;
const [addMainModel] = useAddMainModelsMutation();
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
const manualAddCheckpointFormHandler = (values: CheckpointModelConfig) => {
addMainModel({
body: values,
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: `Model Added: ${values.model_name}`,
status: 'success',
})
)
);
manualAddCheckpointForm.reset();
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: 'Model Add Failed',
status: 'error',
})
)
);
}
});
};
return (
<form
onSubmit={manualAddCheckpointForm.onSubmit((v) =>
manualAddCheckpointFormHandler(v)
)}
style={{ width: '100%' }}
>
<Flex flexDirection="column" gap={2}>
<IAIMantineTextInput
label="Model Name"
required
{...manualAddCheckpointForm.getInputProps('model_name')}
/>
<BaseModelSelect
{...manualAddCheckpointForm.getInputProps('base_model')}
/>
<IAIMantineTextInput
label="Model Location"
required
{...manualAddCheckpointForm.getInputProps('path')}
/>
<IAIMantineTextInput
label="Description"
{...manualAddCheckpointForm.getInputProps('description')}
/>
<IAIMantineTextInput
label="VAE Location"
{...manualAddCheckpointForm.getInputProps('vae')}
/>
<ModelVariantSelect
{...manualAddCheckpointForm.getInputProps('variant')}
/>
<Flex flexDirection="column" width="100%" gap={2}>
{!useCustomConfig ? (
<CheckpointConfigsSelect
width="100%"
{...manualAddCheckpointForm.getInputProps('config')}
/>
) : (
<IAIMantineTextInput
required
label="Custom Config File Location"
{...manualAddCheckpointForm.getInputProps('config')}
/>
)}
<IAISimpleCheckbox
isChecked={useCustomConfig}
onChange={() => setUseCustomConfig(!useCustomConfig)}
label="Use Custom Config"
/>
<IAIButton mt={2} type="submit">
{t('modelManager.addModel')}
</IAIButton>
</Flex>
</Flex>
</form>
);
} }

View File

@ -31,7 +31,6 @@ export default function ManualAddDiffusers() {
}, },
}); });
const manualAddDiffusersFormHandler = (values: DiffusersModelConfig) => { const manualAddDiffusersFormHandler = (values: DiffusersModelConfig) => {
console.log(values);
addMainModel({ addMainModel({
body: values, body: values,
}) })
@ -80,6 +79,7 @@ export default function ManualAddDiffusers() {
<IAIMantineTextInput <IAIMantineTextInput
required required
label="Model Location" label="Model Location"
placeholder="Provide the path to a local folder where your Diffusers Model is stored"
{...manualAddDiffusersForm.getInputProps('path')} {...manualAddDiffusersForm.getInputProps('path')}
/> />
<IAIMantineTextInput <IAIMantineTextInput

View File

@ -0,0 +1,22 @@
import IAIMantineSelect, {
IAISelectProps,
} from 'common/components/IAIMantineSelect';
import { useGetCheckpointConfigsQuery } from 'services/api/endpoints/models';
type CheckpointConfigSelectProps = Omit<IAISelectProps, 'data'>;
export default function CheckpointConfigsSelect(
props: CheckpointConfigSelectProps
) {
const { data: availableCheckpointConfigs } = useGetCheckpointConfigsQuery();
const { ...rest } = props;
return (
<IAIMantineSelect
label="Config File"
placeholder="Select A Config File"
data={availableCheckpointConfigs ? availableCheckpointConfigs : []}
{...rest}
/>
);
}

View File

@ -96,6 +96,9 @@ type AddMainModelResponse =
type SearchFolderResponse = type SearchFolderResponse =
paths['/api/v1/models/search']['get']['responses']['200']['content']['application/json']; paths['/api/v1/models/search']['get']['responses']['200']['content']['application/json'];
type CheckpointConfigsResponse =
paths['/api/v1/models/ckpt_confs']['get']['responses']['200']['content']['application/json'];
type SearchFolderArg = operations['search_for_models']['parameters']['query']; type SearchFolderArg = operations['search_for_models']['parameters']['query'];
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({ const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
@ -383,6 +386,13 @@ export const modelsApi = api.injectEndpoints({
}; };
}, },
}), }),
getCheckpointConfigs: build.query<CheckpointConfigsResponse, void>({
query: () => {
return {
url: `/models/ckpt_confs`,
};
},
}),
}), }),
}); });
@ -399,4 +409,5 @@ export const {
useConvertMainModelsMutation, useConvertMainModelsMutation,
useMergeMainModelsMutation, useMergeMainModelsMutation,
useGetModelsInFolderQuery, useGetModelsInFolderQuery,
useGetCheckpointConfigsQuery,
} = modelsApi; } = modelsApi;