From 3a5314f1caf446e510b8ea48cfeff1206e53f5c9 Mon Sep 17 00:00:00 2001 From: Jennifer Player Date: Mon, 11 Mar 2024 08:42:04 -0400 Subject: [PATCH] install model if diffusers or single file, cleaned up backend logic to not mess with existing model install --- .../model_install/model_install_default.py | 14 +++-- .../model_manager/util/select_hf_files.py | 13 +--- .../HuggingFaceFolder/HuggingFaceForm.tsx | 59 ++++++++++++++++--- .../HuggingFaceFolder/HuggingFaceResults.tsx | 1 - 4 files changed, 63 insertions(+), 24 deletions(-) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index ffc838fb7a..ba9e88d745 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -241,12 +241,16 @@ class ModelInstallService(ModelInstallServiceBase): metadata = HuggingFaceMetadataFetch(self._session).from_id(source) assert isinstance(metadata, ModelMetadataWithFiles) - remote_files = metadata.download_urls( - session=self._session, - ) + urls: List[AnyHttpUrl] = [] - # return array of remote_files.url - return [x.url for x in remote_files] + for file in metadata.files: + if str(file.url).endswith( + (".safetensors", ".bin", ".onnx", ".xml", ".pth", ".pt", ".ckpt", ".msgpack", "model_index.json") + ): + urls.append(file.url) + + self._logger.info(f"here are the metadata files {metadata.files}") + return urls def list_jobs(self) -> List[ModelInstallJob]: # noqa D102 return self._install_jobs diff --git a/invokeai/backend/model_manager/util/select_hf_files.py b/invokeai/backend/model_manager/util/select_hf_files.py index d45c488492..4a63ab27b7 100644 --- a/invokeai/backend/model_manager/util/select_hf_files.py +++ b/invokeai/backend/model_manager/util/select_hf_files.py @@ -51,15 +51,7 @@ def filter_files( ( "learned_embeds.bin", "ip_adapter.bin", - # jennifer added a bunch of these, probably will break something - ".safetensors", - ".bin", - ".onnx", - ".xml", - ".pth", - ".pt", - ".ckpt", - ".msgpack", + "lora_weights.safetensors", "weights.pb", "onnx_data", ) @@ -79,8 +71,7 @@ def filter_files( paths = [x for x in paths if x.parent == Path(subfolder)] # _filter_by_variant uniquifies the paths and returns a set - # jennifer removed the filter since it removed models, probably will break something but i dont understand why it removes valid models :| - return sorted(paths) + return sorted(_filter_by_variant(paths, variant)) @dataclass diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceForm.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceForm.tsx index eb35ba718d..78313a8863 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceForm.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceForm.tsx @@ -1,24 +1,69 @@ import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@invoke-ai/ui-library'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { addToast } from 'features/system/store/systemSlice'; +import { makeToast } from 'features/system/util/makeToast'; import type { ChangeEventHandler } from 'react'; import { useCallback, useState } from 'react'; import { useTranslation } from 'react-i18next'; -import { useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models'; +import { useInstallModelMutation, useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models'; import { HuggingFaceResults } from './HuggingFaceResults'; export const HuggingFaceForm = () => { const [huggingFaceRepo, setHuggingFaceRepo] = useState(''); + const [displayResults, setDisplayResults] = useState(false); const [errorMessage, setErrorMessage] = useState(''); const { t } = useTranslation(); + const dispatch = useAppDispatch(); const [_getHuggingFaceModels, { isLoading, data }] = useLazyGetHuggingFaceModelsQuery(); + const [installModel] = useInstallModelMutation(); + + const handleInstallModel = useCallback((source: string) => { + installModel({ source }) + .unwrap() + .then((_) => { + dispatch( + addToast( + makeToast({ + title: t('toast.modelAddedSimple'), + status: 'success', + }) + ) + ); + }) + .catch((error) => { + if (error) { + dispatch( + addToast( + makeToast({ + title: `${error.data.detail} `, + status: 'error', + }) + ) + ); + } + }); + }, [installModel, dispatch, t]); const scanFolder = useCallback(async () => { - _getHuggingFaceModels(huggingFaceRepo).catch((error) => { - if (error) { - setErrorMessage(error.data.detail); - } - }); + _getHuggingFaceModels(huggingFaceRepo) + .then((response) => { + if (response.data?.some((result) => result.endsWith('model_index.json'))) { + handleInstallModel(huggingFaceRepo); + setDisplayResults(false); + } else if (response.data?.length === 1 && response.data[0]) { + handleInstallModel(response.data[0]); + setDisplayResults(false); + } else { + setDisplayResults(true); + } + }) + .catch((error) => { + if (error) { + setErrorMessage(error.data.detail); + } + }); }, [_getHuggingFaceModels, huggingFaceRepo]); const handleSetHuggingFaceRepo: ChangeEventHandler = useCallback((e) => { @@ -43,7 +88,7 @@ export const HuggingFaceForm = () => { {!!errorMessage.length && {errorMessage}} - {data && } + {data && displayResults && } ); }; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResults.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResults.tsx index 0a17a1b772..f95504900a 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResults.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResults.tsx @@ -21,7 +21,6 @@ import { useInstallModelMutation } from 'services/api/endpoints/models'; import { HuggingFaceResultItem } from './HuggingFaceResultItem'; type HuggingFaceResultsProps = { - // results: HuggingFaceFolderResponse; results: string[]; };