mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: Add Manual Checkpoint / Safetensor Models
This commit is contained in:
parent
421fcb761b
commit
d93d42af4a
@ -36,6 +36,7 @@ export default function IAIMantineTextInput(props: IAIMantineTextInputProps) {
|
|||||||
label: {
|
label: {
|
||||||
color: mode(base700, base300)(colorMode),
|
color: mode(base700, base300)(colorMode),
|
||||||
fontWeight: 'normal',
|
fontWeight: 'normal',
|
||||||
|
marginBottom: 4,
|
||||||
},
|
},
|
||||||
})}
|
})}
|
||||||
{...rest}
|
{...rest}
|
||||||
|
@ -13,7 +13,7 @@ export default function AddModels() {
|
|||||||
flexDirection="column"
|
flexDirection="column"
|
||||||
width="100%"
|
width="100%"
|
||||||
overflow="scroll"
|
overflow="scroll"
|
||||||
maxHeight={window.innerHeight - 270}
|
maxHeight={window.innerHeight - 250}
|
||||||
gap={4}
|
gap={4}
|
||||||
>
|
>
|
||||||
<ButtonGroup isAttached>
|
<ButtonGroup isAttached>
|
||||||
|
@ -1,7 +1,23 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
import { useForm } from '@mantine/form';
|
import { useForm } from '@mantine/form';
|
||||||
|
import { makeToast } from 'app/components/Toaster';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
import IAIMantineTextInput from 'common/components/IAIMantineInput';
|
||||||
|
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { useState } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
|
||||||
import { CheckpointModelConfig } from 'services/api/types';
|
import { CheckpointModelConfig } from 'services/api/types';
|
||||||
|
import BaseModelSelect from '../shared/BaseModelSelect';
|
||||||
|
import CheckpointConfigsSelect from '../shared/CheckpointConfigsSelect';
|
||||||
|
import ModelVariantSelect from '../shared/ModelVariantSelect';
|
||||||
|
|
||||||
export default function ManualAddCheckpoint() {
|
export default function ManualAddCheckpoint() {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const manualAddCheckpointForm = useForm<CheckpointModelConfig>({
|
const manualAddCheckpointForm = useForm<CheckpointModelConfig>({
|
||||||
initialValues: {
|
initialValues: {
|
||||||
model_name: '',
|
model_name: '',
|
||||||
@ -13,8 +29,99 @@ export default function ManualAddCheckpoint() {
|
|||||||
error: undefined,
|
error: undefined,
|
||||||
vae: '',
|
vae: '',
|
||||||
variant: 'normal',
|
variant: 'normal',
|
||||||
config: '',
|
config: 'configs\\stable-diffusion\\v1-inference.yaml',
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
return <div>ManualAddCheckpoint</div>;
|
|
||||||
|
const [addMainModel] = useAddMainModelsMutation();
|
||||||
|
|
||||||
|
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
|
||||||
|
|
||||||
|
const manualAddCheckpointFormHandler = (values: CheckpointModelConfig) => {
|
||||||
|
addMainModel({
|
||||||
|
body: values,
|
||||||
|
})
|
||||||
|
.unwrap()
|
||||||
|
.then((_) => {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: `Model Added: ${values.model_name}`,
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
manualAddCheckpointForm.reset();
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
if (error) {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: 'Model Add Failed',
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<form
|
||||||
|
onSubmit={manualAddCheckpointForm.onSubmit((v) =>
|
||||||
|
manualAddCheckpointFormHandler(v)
|
||||||
|
)}
|
||||||
|
style={{ width: '100%' }}
|
||||||
|
>
|
||||||
|
<Flex flexDirection="column" gap={2}>
|
||||||
|
<IAIMantineTextInput
|
||||||
|
label="Model Name"
|
||||||
|
required
|
||||||
|
{...manualAddCheckpointForm.getInputProps('model_name')}
|
||||||
|
/>
|
||||||
|
<BaseModelSelect
|
||||||
|
{...manualAddCheckpointForm.getInputProps('base_model')}
|
||||||
|
/>
|
||||||
|
<IAIMantineTextInput
|
||||||
|
label="Model Location"
|
||||||
|
required
|
||||||
|
{...manualAddCheckpointForm.getInputProps('path')}
|
||||||
|
/>
|
||||||
|
<IAIMantineTextInput
|
||||||
|
label="Description"
|
||||||
|
{...manualAddCheckpointForm.getInputProps('description')}
|
||||||
|
/>
|
||||||
|
<IAIMantineTextInput
|
||||||
|
label="VAE Location"
|
||||||
|
{...manualAddCheckpointForm.getInputProps('vae')}
|
||||||
|
/>
|
||||||
|
<ModelVariantSelect
|
||||||
|
{...manualAddCheckpointForm.getInputProps('variant')}
|
||||||
|
/>
|
||||||
|
<Flex flexDirection="column" width="100%" gap={2}>
|
||||||
|
{!useCustomConfig ? (
|
||||||
|
<CheckpointConfigsSelect
|
||||||
|
width="100%"
|
||||||
|
{...manualAddCheckpointForm.getInputProps('config')}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<IAIMantineTextInput
|
||||||
|
required
|
||||||
|
label="Custom Config File Location"
|
||||||
|
{...manualAddCheckpointForm.getInputProps('config')}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
<IAISimpleCheckbox
|
||||||
|
isChecked={useCustomConfig}
|
||||||
|
onChange={() => setUseCustomConfig(!useCustomConfig)}
|
||||||
|
label="Use Custom Config"
|
||||||
|
/>
|
||||||
|
<IAIButton mt={2} type="submit">
|
||||||
|
{t('modelManager.addModel')}
|
||||||
|
</IAIButton>
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
</form>
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
@ -31,7 +31,6 @@ export default function ManualAddDiffusers() {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
const manualAddDiffusersFormHandler = (values: DiffusersModelConfig) => {
|
const manualAddDiffusersFormHandler = (values: DiffusersModelConfig) => {
|
||||||
console.log(values);
|
|
||||||
addMainModel({
|
addMainModel({
|
||||||
body: values,
|
body: values,
|
||||||
})
|
})
|
||||||
@ -80,6 +79,7 @@ export default function ManualAddDiffusers() {
|
|||||||
<IAIMantineTextInput
|
<IAIMantineTextInput
|
||||||
required
|
required
|
||||||
label="Model Location"
|
label="Model Location"
|
||||||
|
placeholder="Provide the path to a local folder where your Diffusers Model is stored"
|
||||||
{...manualAddDiffusersForm.getInputProps('path')}
|
{...manualAddDiffusersForm.getInputProps('path')}
|
||||||
/>
|
/>
|
||||||
<IAIMantineTextInput
|
<IAIMantineTextInput
|
||||||
|
@ -0,0 +1,22 @@
|
|||||||
|
import IAIMantineSelect, {
|
||||||
|
IAISelectProps,
|
||||||
|
} from 'common/components/IAIMantineSelect';
|
||||||
|
import { useGetCheckpointConfigsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
|
type CheckpointConfigSelectProps = Omit<IAISelectProps, 'data'>;
|
||||||
|
|
||||||
|
export default function CheckpointConfigsSelect(
|
||||||
|
props: CheckpointConfigSelectProps
|
||||||
|
) {
|
||||||
|
const { data: availableCheckpointConfigs } = useGetCheckpointConfigsQuery();
|
||||||
|
const { ...rest } = props;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAIMantineSelect
|
||||||
|
label="Config File"
|
||||||
|
placeholder="Select A Config File"
|
||||||
|
data={availableCheckpointConfigs ? availableCheckpointConfigs : []}
|
||||||
|
{...rest}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
@ -96,6 +96,9 @@ type AddMainModelResponse =
|
|||||||
type SearchFolderResponse =
|
type SearchFolderResponse =
|
||||||
paths['/api/v1/models/search']['get']['responses']['200']['content']['application/json'];
|
paths['/api/v1/models/search']['get']['responses']['200']['content']['application/json'];
|
||||||
|
|
||||||
|
type CheckpointConfigsResponse =
|
||||||
|
paths['/api/v1/models/ckpt_confs']['get']['responses']['200']['content']['application/json'];
|
||||||
|
|
||||||
type SearchFolderArg = operations['search_for_models']['parameters']['query'];
|
type SearchFolderArg = operations['search_for_models']['parameters']['query'];
|
||||||
|
|
||||||
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
||||||
@ -383,6 +386,13 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
};
|
};
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
|
getCheckpointConfigs: build.query<CheckpointConfigsResponse, void>({
|
||||||
|
query: () => {
|
||||||
|
return {
|
||||||
|
url: `/models/ckpt_confs`,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -399,4 +409,5 @@ export const {
|
|||||||
useConvertMainModelsMutation,
|
useConvertMainModelsMutation,
|
||||||
useMergeMainModelsMutation,
|
useMergeMainModelsMutation,
|
||||||
useGetModelsInFolderQuery,
|
useGetModelsInFolderQuery,
|
||||||
|
useGetCheckpointConfigsQuery,
|
||||||
} = modelsApi;
|
} = modelsApi;
|
||||||
|
Loading…
Reference in New Issue
Block a user