mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: Add Manual Checkpoint / Safetensor Models
This commit is contained in:
parent
421fcb761b
commit
d93d42af4a
@ -36,6 +36,7 @@ export default function IAIMantineTextInput(props: IAIMantineTextInputProps) {
|
||||
label: {
|
||||
color: mode(base700, base300)(colorMode),
|
||||
fontWeight: 'normal',
|
||||
marginBottom: 4,
|
||||
},
|
||||
})}
|
||||
{...rest}
|
||||
|
@ -13,7 +13,7 @@ export default function AddModels() {
|
||||
flexDirection="column"
|
||||
width="100%"
|
||||
overflow="scroll"
|
||||
maxHeight={window.innerHeight - 270}
|
||||
maxHeight={window.innerHeight - 250}
|
||||
gap={4}
|
||||
>
|
||||
<ButtonGroup isAttached>
|
||||
|
@ -1,7 +1,23 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { useForm } from '@mantine/form';
|
||||
import { makeToast } from 'app/components/Toaster';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIMantineTextInput from 'common/components/IAIMantineInput';
|
||||
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
|
||||
import { CheckpointModelConfig } from 'services/api/types';
|
||||
import BaseModelSelect from '../shared/BaseModelSelect';
|
||||
import CheckpointConfigsSelect from '../shared/CheckpointConfigsSelect';
|
||||
import ModelVariantSelect from '../shared/ModelVariantSelect';
|
||||
|
||||
export default function ManualAddCheckpoint() {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const manualAddCheckpointForm = useForm<CheckpointModelConfig>({
|
||||
initialValues: {
|
||||
model_name: '',
|
||||
@ -13,8 +29,99 @@ export default function ManualAddCheckpoint() {
|
||||
error: undefined,
|
||||
vae: '',
|
||||
variant: 'normal',
|
||||
config: '',
|
||||
config: 'configs\\stable-diffusion\\v1-inference.yaml',
|
||||
},
|
||||
});
|
||||
return <div>ManualAddCheckpoint</div>;
|
||||
|
||||
const [addMainModel] = useAddMainModelsMutation();
|
||||
|
||||
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
|
||||
|
||||
const manualAddCheckpointFormHandler = (values: CheckpointModelConfig) => {
|
||||
addMainModel({
|
||||
body: values,
|
||||
})
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `Model Added: ${values.model_name}`,
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
manualAddCheckpointForm.reset();
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: 'Model Add Failed',
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<form
|
||||
onSubmit={manualAddCheckpointForm.onSubmit((v) =>
|
||||
manualAddCheckpointFormHandler(v)
|
||||
)}
|
||||
style={{ width: '100%' }}
|
||||
>
|
||||
<Flex flexDirection="column" gap={2}>
|
||||
<IAIMantineTextInput
|
||||
label="Model Name"
|
||||
required
|
||||
{...manualAddCheckpointForm.getInputProps('model_name')}
|
||||
/>
|
||||
<BaseModelSelect
|
||||
{...manualAddCheckpointForm.getInputProps('base_model')}
|
||||
/>
|
||||
<IAIMantineTextInput
|
||||
label="Model Location"
|
||||
required
|
||||
{...manualAddCheckpointForm.getInputProps('path')}
|
||||
/>
|
||||
<IAIMantineTextInput
|
||||
label="Description"
|
||||
{...manualAddCheckpointForm.getInputProps('description')}
|
||||
/>
|
||||
<IAIMantineTextInput
|
||||
label="VAE Location"
|
||||
{...manualAddCheckpointForm.getInputProps('vae')}
|
||||
/>
|
||||
<ModelVariantSelect
|
||||
{...manualAddCheckpointForm.getInputProps('variant')}
|
||||
/>
|
||||
<Flex flexDirection="column" width="100%" gap={2}>
|
||||
{!useCustomConfig ? (
|
||||
<CheckpointConfigsSelect
|
||||
width="100%"
|
||||
{...manualAddCheckpointForm.getInputProps('config')}
|
||||
/>
|
||||
) : (
|
||||
<IAIMantineTextInput
|
||||
required
|
||||
label="Custom Config File Location"
|
||||
{...manualAddCheckpointForm.getInputProps('config')}
|
||||
/>
|
||||
)}
|
||||
<IAISimpleCheckbox
|
||||
isChecked={useCustomConfig}
|
||||
onChange={() => setUseCustomConfig(!useCustomConfig)}
|
||||
label="Use Custom Config"
|
||||
/>
|
||||
<IAIButton mt={2} type="submit">
|
||||
{t('modelManager.addModel')}
|
||||
</IAIButton>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</form>
|
||||
);
|
||||
}
|
||||
|
@ -31,7 +31,6 @@ export default function ManualAddDiffusers() {
|
||||
},
|
||||
});
|
||||
const manualAddDiffusersFormHandler = (values: DiffusersModelConfig) => {
|
||||
console.log(values);
|
||||
addMainModel({
|
||||
body: values,
|
||||
})
|
||||
@ -80,6 +79,7 @@ export default function ManualAddDiffusers() {
|
||||
<IAIMantineTextInput
|
||||
required
|
||||
label="Model Location"
|
||||
placeholder="Provide the path to a local folder where your Diffusers Model is stored"
|
||||
{...manualAddDiffusersForm.getInputProps('path')}
|
||||
/>
|
||||
<IAIMantineTextInput
|
||||
|
@ -0,0 +1,22 @@
|
||||
import IAIMantineSelect, {
|
||||
IAISelectProps,
|
||||
} from 'common/components/IAIMantineSelect';
|
||||
import { useGetCheckpointConfigsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
type CheckpointConfigSelectProps = Omit<IAISelectProps, 'data'>;
|
||||
|
||||
export default function CheckpointConfigsSelect(
|
||||
props: CheckpointConfigSelectProps
|
||||
) {
|
||||
const { data: availableCheckpointConfigs } = useGetCheckpointConfigsQuery();
|
||||
const { ...rest } = props;
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
label="Config File"
|
||||
placeholder="Select A Config File"
|
||||
data={availableCheckpointConfigs ? availableCheckpointConfigs : []}
|
||||
{...rest}
|
||||
/>
|
||||
);
|
||||
}
|
@ -96,6 +96,9 @@ type AddMainModelResponse =
|
||||
type SearchFolderResponse =
|
||||
paths['/api/v1/models/search']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
type CheckpointConfigsResponse =
|
||||
paths['/api/v1/models/ckpt_confs']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
type SearchFolderArg = operations['search_for_models']['parameters']['query'];
|
||||
|
||||
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
||||
@ -383,6 +386,13 @@ export const modelsApi = api.injectEndpoints({
|
||||
};
|
||||
},
|
||||
}),
|
||||
getCheckpointConfigs: build.query<CheckpointConfigsResponse, void>({
|
||||
query: () => {
|
||||
return {
|
||||
url: `/models/ckpt_confs`,
|
||||
};
|
||||
},
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
@ -399,4 +409,5 @@ export const {
|
||||
useConvertMainModelsMutation,
|
||||
useMergeMainModelsMutation,
|
||||
useGetModelsInFolderQuery,
|
||||
useGetCheckpointConfigsQuery,
|
||||
} = modelsApi;
|
||||
|
Loading…
Reference in New Issue
Block a user