feat: Add Manual Checkpoint / Safetensor Models

This commit is contained in:
blessedcoolant 2023-07-16 15:21:49 +12:00
parent 421fcb761b
commit d93d42af4a
6 changed files with 145 additions and 4 deletions

View File

@ -36,6 +36,7 @@ export default function IAIMantineTextInput(props: IAIMantineTextInputProps) {
label: {
color: mode(base700, base300)(colorMode),
fontWeight: 'normal',
marginBottom: 4,
},
})}
{...rest}

View File

@ -13,7 +13,7 @@ export default function AddModels() {
flexDirection="column"
width="100%"
overflow="scroll"
maxHeight={window.innerHeight - 270}
maxHeight={window.innerHeight - 250}
gap={4}
>
<ButtonGroup isAttached>

View File

@ -1,7 +1,23 @@
import { Flex } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import { addToast } from 'features/system/store/systemSlice';
import { useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
import { CheckpointModelConfig } from 'services/api/types';
import BaseModelSelect from '../shared/BaseModelSelect';
import CheckpointConfigsSelect from '../shared/CheckpointConfigsSelect';
import ModelVariantSelect from '../shared/ModelVariantSelect';
export default function ManualAddCheckpoint() {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const manualAddCheckpointForm = useForm<CheckpointModelConfig>({
initialValues: {
model_name: '',
@ -13,8 +29,99 @@ export default function ManualAddCheckpoint() {
error: undefined,
vae: '',
variant: 'normal',
config: '',
config: 'configs\\stable-diffusion\\v1-inference.yaml',
},
});
return <div>ManualAddCheckpoint</div>;
const [addMainModel] = useAddMainModelsMutation();
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
const manualAddCheckpointFormHandler = (values: CheckpointModelConfig) => {
addMainModel({
body: values,
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: `Model Added: ${values.model_name}`,
status: 'success',
})
)
);
manualAddCheckpointForm.reset();
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: 'Model Add Failed',
status: 'error',
})
)
);
}
});
};
return (
<form
onSubmit={manualAddCheckpointForm.onSubmit((v) =>
manualAddCheckpointFormHandler(v)
)}
style={{ width: '100%' }}
>
<Flex flexDirection="column" gap={2}>
<IAIMantineTextInput
label="Model Name"
required
{...manualAddCheckpointForm.getInputProps('model_name')}
/>
<BaseModelSelect
{...manualAddCheckpointForm.getInputProps('base_model')}
/>
<IAIMantineTextInput
label="Model Location"
required
{...manualAddCheckpointForm.getInputProps('path')}
/>
<IAIMantineTextInput
label="Description"
{...manualAddCheckpointForm.getInputProps('description')}
/>
<IAIMantineTextInput
label="VAE Location"
{...manualAddCheckpointForm.getInputProps('vae')}
/>
<ModelVariantSelect
{...manualAddCheckpointForm.getInputProps('variant')}
/>
<Flex flexDirection="column" width="100%" gap={2}>
{!useCustomConfig ? (
<CheckpointConfigsSelect
width="100%"
{...manualAddCheckpointForm.getInputProps('config')}
/>
) : (
<IAIMantineTextInput
required
label="Custom Config File Location"
{...manualAddCheckpointForm.getInputProps('config')}
/>
)}
<IAISimpleCheckbox
isChecked={useCustomConfig}
onChange={() => setUseCustomConfig(!useCustomConfig)}
label="Use Custom Config"
/>
<IAIButton mt={2} type="submit">
{t('modelManager.addModel')}
</IAIButton>
</Flex>
</Flex>
</form>
);
}

View File

@ -31,7 +31,6 @@ export default function ManualAddDiffusers() {
},
});
const manualAddDiffusersFormHandler = (values: DiffusersModelConfig) => {
console.log(values);
addMainModel({
body: values,
})
@ -80,6 +79,7 @@ export default function ManualAddDiffusers() {
<IAIMantineTextInput
required
label="Model Location"
placeholder="Provide the path to a local folder where your Diffusers Model is stored"
{...manualAddDiffusersForm.getInputProps('path')}
/>
<IAIMantineTextInput

View File

@ -0,0 +1,22 @@
import IAIMantineSelect, {
IAISelectProps,
} from 'common/components/IAIMantineSelect';
import { useGetCheckpointConfigsQuery } from 'services/api/endpoints/models';
type CheckpointConfigSelectProps = Omit<IAISelectProps, 'data'>;
export default function CheckpointConfigsSelect(
props: CheckpointConfigSelectProps
) {
const { data: availableCheckpointConfigs } = useGetCheckpointConfigsQuery();
const { ...rest } = props;
return (
<IAIMantineSelect
label="Config File"
placeholder="Select A Config File"
data={availableCheckpointConfigs ? availableCheckpointConfigs : []}
{...rest}
/>
);
}

View File

@ -96,6 +96,9 @@ type AddMainModelResponse =
type SearchFolderResponse =
paths['/api/v1/models/search']['get']['responses']['200']['content']['application/json'];
type CheckpointConfigsResponse =
paths['/api/v1/models/ckpt_confs']['get']['responses']['200']['content']['application/json'];
type SearchFolderArg = operations['search_for_models']['parameters']['query'];
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
@ -383,6 +386,13 @@ export const modelsApi = api.injectEndpoints({
};
},
}),
getCheckpointConfigs: build.query<CheckpointConfigsResponse, void>({
query: () => {
return {
url: `/models/ckpt_confs`,
};
},
}),
}),
});
@ -399,4 +409,5 @@ export const {
useConvertMainModelsMutation,
useMergeMainModelsMutation,
useGetModelsInFolderQuery,
useGetCheckpointConfigsQuery,
} = modelsApi;