From f7cd3cf1f41460be75dcc77cff315c1943ed9c37 Mon Sep 17 00:00:00 2001 From: Jennifer Player Date: Thu, 7 Mar 2024 16:57:28 -0500 Subject: [PATCH] added hf models import tab and route for getting available hf models --- invokeai/app/api/routers/model_manager.py | 26 +++- .../model_install/model_install_base.py | 14 +++ .../model_install/model_install_default.py | 16 +++ .../model_manager/util/select_hf_files.py | 13 +- invokeai/frontend/web/public/locales/en.json | 2 + .../HuggingFaceFolder/HuggingFaceForm.tsx | 49 ++++++++ .../HuggingFaceResultItem.tsx | 59 +++++++++ .../HuggingFaceFolder/HuggingFaceResults.tsx | 119 ++++++++++++++++++ .../subpanels/InstallModels.tsx | 5 + .../web/src/services/api/endpoints/models.ts | 8 ++ 10 files changed, 308 insertions(+), 3 deletions(-) create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceForm.tsx create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResultItem.tsx create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResults.tsx diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index d3c2510b1b..8740e8ce90 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -11,7 +11,7 @@ from fastapi import Body, Path, Query, Response, UploadFile from fastapi.responses import FileResponse from fastapi.routing import APIRouter from PIL import Image -from pydantic import BaseModel, ConfigDict, Field +from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field from starlette.exceptions import HTTPException from typing_extensions import Annotated @@ -29,6 +29,7 @@ from invokeai.backend.model_manager.config import ( ModelType, SubModelType, ) +from invokeai.backend.model_manager.metadata.metadata_base import RemoteModelFile from invokeai.backend.model_manager.search import ModelSearch from ..dependencies import ApiDependencies @@ -246,6 +247,29 @@ async def scan_for_models( return scan_results +@model_manager_router.get( + "/hugging_face", + operation_id="get_hugging_face_models", + responses={ + 200: {"description": "Hugging Face repo scanned successfully"}, + 400: {"description": "Invalid hugging face repo"}, + }, + status_code=200, + response_model=List[AnyHttpUrl], +) +async def get_hugging_face_models( + hugging_face_repo: str = Query(description="Hugging face repo to search for models", default=None), +) -> List[AnyHttpUrl]: + get_hugging_face_models = ApiDependencies.invoker.services.model_manager.install.get_hugging_face_models + get_hugging_face_models(hugging_face_repo) + + result = get_hugging_face_models( + source=hugging_face_repo, + ) + + return result + + @model_manager_router.patch( "/i/{key}", operation_id="update_model_record", diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index b7385495e5..3989b46c8e 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -20,6 +20,7 @@ from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant from invokeai.backend.model_manager.config import ModelSourceType from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata +from invokeai.backend.model_manager.metadata.metadata_base import RemoteModelFile class InstallStatus(str, Enum): @@ -405,6 +406,19 @@ class ModelInstallServiceBase(ABC): """ + @abstractmethod + def get_hugging_face_models( + self, + source: str, + ) -> List[AnyHttpUrl]: + """Get the available models in a HuggingFace repo. + + :param source: HuggingFace repo string + + This will get the urls for the available models in the indicated, + repo, and return them as a list of AnyHttpUrl strings. + """ + @abstractmethod def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]: """Return the ModelInstallJob(s) corresponding to the provided source.""" diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 138bde8bbf..67ba4072f8 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -233,6 +233,22 @@ class ModelInstallService(ModelInstallServiceBase): self._install_jobs.append(install_job) return install_job + def get_hugging_face_models(self, source: str) -> List[AnyHttpUrl]: + # Add user's cached access token to HuggingFace requests + access_token = HfFolder.get_token() + if not access_token: + self._logger.info("No HuggingFace access token present; some models may not be downloadable.") + + metadata = HuggingFaceMetadataFetch(self._session).from_id(source) + self._logger.info(f"metadata is {metadata}") + assert isinstance(metadata, ModelMetadataWithFiles) + remote_files = metadata.download_urls( + session=self._session, + ) + + # return array of remote_files.url + return [x.url for x in remote_files] + def list_jobs(self) -> List[ModelInstallJob]: # noqa D102 return self._install_jobs diff --git a/invokeai/backend/model_manager/util/select_hf_files.py b/invokeai/backend/model_manager/util/select_hf_files.py index 4a63ab27b7..d45c488492 100644 --- a/invokeai/backend/model_manager/util/select_hf_files.py +++ b/invokeai/backend/model_manager/util/select_hf_files.py @@ -51,7 +51,15 @@ def filter_files( ( "learned_embeds.bin", "ip_adapter.bin", - "lora_weights.safetensors", + # jennifer added a bunch of these, probably will break something + ".safetensors", + ".bin", + ".onnx", + ".xml", + ".pth", + ".pt", + ".ckpt", + ".msgpack", "weights.pb", "onnx_data", ) @@ -71,7 +79,8 @@ def filter_files( paths = [x for x in paths if x.parent == Path(subfolder)] # _filter_by_variant uniquifies the paths and returns a set - return sorted(_filter_by_variant(paths, variant)) + # jennifer removed the filter since it removed models, probably will break something but i dont understand why it removes valid models :| + return sorted(paths) @dataclass diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index ee92770620..4a2e1d30d3 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -762,6 +762,8 @@ "formMessageDiffusersVAELocationDesc": "If not provided, InvokeAI will look for the VAE file inside the model location given above.", "height": "Height", "heightValidationMsg": "Default height of your model.", + "huggingFace": "Hugging Face", + "huggingFaceRepoID": "HuggingFace Repo ID", "ignoreMismatch": "Ignore Mismatches Between Selected Models", "imageEncoderModelId": "Image Encoder Model ID", "importModels": "Import Models", diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceForm.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceForm.tsx new file mode 100644 index 0000000000..eb35ba718d --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceForm.tsx @@ -0,0 +1,49 @@ +import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@invoke-ai/ui-library'; +import type { ChangeEventHandler } from 'react'; +import { useCallback, useState } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models'; + +import { HuggingFaceResults } from './HuggingFaceResults'; + +export const HuggingFaceForm = () => { + const [huggingFaceRepo, setHuggingFaceRepo] = useState(''); + const [errorMessage, setErrorMessage] = useState(''); + const { t } = useTranslation(); + + const [_getHuggingFaceModels, { isLoading, data }] = useLazyGetHuggingFaceModelsQuery(); + + const scanFolder = useCallback(async () => { + _getHuggingFaceModels(huggingFaceRepo).catch((error) => { + if (error) { + setErrorMessage(error.data.detail); + } + }); + }, [_getHuggingFaceModels, huggingFaceRepo]); + + const handleSetHuggingFaceRepo: ChangeEventHandler = useCallback((e) => { + setHuggingFaceRepo(e.target.value); + setErrorMessage(''); + }, []); + + return ( + + + + + + {t('modelManager.huggingFaceRepoID')} + + + + + + {!!errorMessage.length && {errorMessage}} + + + {data && } + + ); +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResultItem.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResultItem.tsx new file mode 100644 index 0000000000..a0d84e4861 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResultItem.tsx @@ -0,0 +1,59 @@ +import { Box, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { addToast } from 'features/system/store/systemSlice'; +import { makeToast } from 'features/system/util/makeToast'; +import { useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { IoAdd } from 'react-icons/io5'; +import { useInstallModelMutation } from 'services/api/endpoints/models'; + +type Props = { + result: string; +}; +export const HuggingFaceResultItem = ({ result }: Props) => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + + const [installModel] = useInstallModelMutation(); + + const handleQuickAdd = useCallback(() => { + installModel({ source: result }) + .unwrap() + .then((_) => { + dispatch( + addToast( + makeToast({ + title: t('toast.modelAddedSimple'), + status: 'success', + }) + ) + ); + }) + .catch((error) => { + if (error) { + dispatch( + addToast( + makeToast({ + title: `${error.data.detail} `, + status: 'error', + }) + ) + ); + } + }); + }, [installModel, result, dispatch, t]); + + return ( + + + {result.split('/').slice(-1)[0]} + {result} + + + + } onClick={handleQuickAdd} /> + + + + ); +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResults.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResults.tsx new file mode 100644 index 0000000000..0a17a1b772 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResults.tsx @@ -0,0 +1,119 @@ +import { + Button, + Divider, + Flex, + Heading, + IconButton, + Input, + InputGroup, + InputRightElement, +} from '@invoke-ai/ui-library'; +import { useAppDispatch } from 'app/store/storeHooks'; +import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; +import { addToast } from 'features/system/store/systemSlice'; +import { makeToast } from 'features/system/util/makeToast'; +import type { ChangeEventHandler } from 'react'; +import { useCallback, useState } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiXBold } from 'react-icons/pi'; +import { useInstallModelMutation } from 'services/api/endpoints/models'; + +import { HuggingFaceResultItem } from './HuggingFaceResultItem'; + +type HuggingFaceResultsProps = { + // results: HuggingFaceFolderResponse; + results: string[]; +}; + +export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => { + const { t } = useTranslation(); + const [searchTerm, setSearchTerm] = useState(''); + const dispatch = useAppDispatch(); + + const [installModel] = useInstallModelMutation(); + + const handleSearch: ChangeEventHandler = useCallback((e) => { + setSearchTerm(e.target.value.trim()); + }, []); + + const clearSearch = useCallback(() => { + setSearchTerm(''); + }, []); + + const handleAddAll = useCallback(() => { + for (const result of results) { + installModel({ source: result }) + .unwrap() + .then((_) => { + dispatch( + addToast( + makeToast({ + title: t('toast.modelAddedSimple'), + status: 'success', + }) + ) + ); + }) + .catch((error) => { + if (error) { + dispatch( + addToast( + makeToast({ + title: `${error.data.detail} `, + status: 'error', + }) + ) + ); + } + }); + } + }, [results, installModel, dispatch, t]); + + return ( + <> + + + + + {t('modelManager.availableModels')} + + + + + + + {searchTerm && ( + + } + onClick={clearSearch} + /> + + )} + + + + + + + {results.map((result) => ( + + ))} + + + + + + ); +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/InstallModels.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/InstallModels.tsx index 9eb8d5185f..8aff86d47c 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/InstallModels.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/InstallModels.tsx @@ -1,6 +1,7 @@ import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library'; import { useTranslation } from 'react-i18next'; +import { HuggingFaceForm } from './AddModelPanel/HuggingFaceFolder/HuggingFaceForm'; import { InstallModelForm } from './AddModelPanel/InstallModelForm'; import { ModelInstallQueue } from './AddModelPanel/ModelInstallQueue/ModelInstallQueue'; import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm'; @@ -16,12 +17,16 @@ export const InstallModels = () => { {t('common.simple')} + {t('modelManager.huggingFace')} {t('modelManager.scan')} + + + diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index c3780a7ca0..7141df76a8 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -258,6 +258,13 @@ export const modelsApi = api.injectEndpoints({ }; }, }), + getHuggingFaceModels: build.query({ + query: (hugging_face_repo) => { + return { + url: buildModelsUrl(`hugging_face?hugging_face_repo=${hugging_face_repo}`), + }; + }, + }), listModelInstalls: build.query({ query: () => { return { @@ -381,6 +388,7 @@ export const { useConvertModelMutation, useSyncModelsMutation, useLazyScanFolderQuery, + useLazyGetHuggingFaceModelsQuery, useListModelInstallsQuery, useCancelModelInstallMutation, usePruneCompletedModelInstallsMutation,