From d93d42af4aaea7eca4bc43525e6428311b2bb55f Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sun, 16 Jul 2023 15:21:49 +1200 Subject: [PATCH] feat: Add Manual Checkpoint / Safetensor Models --- .../src/common/components/IAIMantineInput.tsx | 1 + .../subpanels/AddModelsPanel/AddModels.tsx | 2 +- .../AddModelsPanel/ManualAddCheckpoint.tsx | 111 +++++++++++++++++- .../AddModelsPanel/ManualAddDiffusers.tsx | 2 +- .../shared/CheckpointConfigsSelect.tsx | 22 ++++ .../web/src/services/api/endpoints/models.ts | 11 ++ 6 files changed, 145 insertions(+), 4 deletions(-) create mode 100644 invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/shared/CheckpointConfigsSelect.tsx diff --git a/invokeai/frontend/web/src/common/components/IAIMantineInput.tsx b/invokeai/frontend/web/src/common/components/IAIMantineInput.tsx index d60f6614df..afe8b9442b 100644 --- a/invokeai/frontend/web/src/common/components/IAIMantineInput.tsx +++ b/invokeai/frontend/web/src/common/components/IAIMantineInput.tsx @@ -36,6 +36,7 @@ export default function IAIMantineTextInput(props: IAIMantineTextInputProps) { label: { color: mode(base700, base300)(colorMode), fontWeight: 'normal', + marginBottom: 4, }, })} {...rest} diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddModels.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddModels.tsx index 4e1f3d8240..fdb890152b 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddModels.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddModels.tsx @@ -13,7 +13,7 @@ export default function AddModels() { flexDirection="column" width="100%" overflow="scroll" - maxHeight={window.innerHeight - 270} + maxHeight={window.innerHeight - 250} gap={4} > diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/ManualAddCheckpoint.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/ManualAddCheckpoint.tsx index efee656d81..f63085aaaa 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/ManualAddCheckpoint.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/ManualAddCheckpoint.tsx @@ -1,7 +1,23 @@ +import { Flex } from '@chakra-ui/react'; import { useForm } from '@mantine/form'; +import { makeToast } from 'app/components/Toaster'; +import { useAppDispatch } from 'app/store/storeHooks'; +import IAIButton from 'common/components/IAIButton'; +import IAIMantineTextInput from 'common/components/IAIMantineInput'; +import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; +import { addToast } from 'features/system/store/systemSlice'; +import { useState } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useAddMainModelsMutation } from 'services/api/endpoints/models'; import { CheckpointModelConfig } from 'services/api/types'; +import BaseModelSelect from '../shared/BaseModelSelect'; +import CheckpointConfigsSelect from '../shared/CheckpointConfigsSelect'; +import ModelVariantSelect from '../shared/ModelVariantSelect'; export default function ManualAddCheckpoint() { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const manualAddCheckpointForm = useForm({ initialValues: { model_name: '', @@ -13,8 +29,99 @@ export default function ManualAddCheckpoint() { error: undefined, vae: '', variant: 'normal', - config: '', + config: 'configs\\stable-diffusion\\v1-inference.yaml', }, }); - return
ManualAddCheckpoint
; + + const [addMainModel] = useAddMainModelsMutation(); + + const [useCustomConfig, setUseCustomConfig] = useState(false); + + const manualAddCheckpointFormHandler = (values: CheckpointModelConfig) => { + addMainModel({ + body: values, + }) + .unwrap() + .then((_) => { + dispatch( + addToast( + makeToast({ + title: `Model Added: ${values.model_name}`, + status: 'success', + }) + ) + ); + manualAddCheckpointForm.reset(); + }) + .catch((error) => { + if (error) { + dispatch( + addToast( + makeToast({ + title: 'Model Add Failed', + status: 'error', + }) + ) + ); + } + }); + }; + + return ( +
+ manualAddCheckpointFormHandler(v) + )} + style={{ width: '100%' }} + > + + + + + + + + + {!useCustomConfig ? ( + + ) : ( + + )} + setUseCustomConfig(!useCustomConfig)} + label="Use Custom Config" + /> + + {t('modelManager.addModel')} + + + +
+ ); } diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/ManualAddDiffusers.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/ManualAddDiffusers.tsx index 1e51006dcd..a4b6870a54 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/ManualAddDiffusers.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/ManualAddDiffusers.tsx @@ -31,7 +31,6 @@ export default function ManualAddDiffusers() { }, }); const manualAddDiffusersFormHandler = (values: DiffusersModelConfig) => { - console.log(values); addMainModel({ body: values, }) @@ -80,6 +79,7 @@ export default function ManualAddDiffusers() { ; + +export default function CheckpointConfigsSelect( + props: CheckpointConfigSelectProps +) { + const { data: availableCheckpointConfigs } = useGetCheckpointConfigsQuery(); + const { ...rest } = props; + + return ( + + ); +} diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index e0177ca8d1..d9cbaf69dd 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -96,6 +96,9 @@ type AddMainModelResponse = type SearchFolderResponse = paths['/api/v1/models/search']['get']['responses']['200']['content']['application/json']; +type CheckpointConfigsResponse = + paths['/api/v1/models/ckpt_confs']['get']['responses']['200']['content']['application/json']; + type SearchFolderArg = operations['search_for_models']['parameters']['query']; const mainModelsAdapter = createEntityAdapter({ @@ -383,6 +386,13 @@ export const modelsApi = api.injectEndpoints({ }; }, }), + getCheckpointConfigs: build.query({ + query: () => { + return { + url: `/models/ckpt_confs`, + }; + }, + }), }), }); @@ -399,4 +409,5 @@ export const { useConvertMainModelsMutation, useMergeMainModelsMutation, useGetModelsInFolderQuery, + useGetCheckpointConfigsQuery, } = modelsApi;