diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 388f4a5ba2..ccb8e3772e 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -177,6 +177,7 @@ class ModelInstallJob(BaseModel): ) # internal flags and transitory settings _install_tmpdir: Optional[Path] = PrivateAttr(default=None) + _do_install: Optional[bool] = PrivateAttr(default=True) _exception: Optional[Exception] = PrivateAttr(default=None) def set_error(self, e: Exception) -> None: @@ -407,6 +408,21 @@ class ModelInstallServiceBase(ABC): """ + @abstractmethod + def download_diffusers_model( + self, + source: HFModelSource, + download_to: Path, + ) -> ModelInstallJob: + """ + Download, but do not install, a diffusers model. + + :param source: An HFModelSource object containing a repo_id + :param download_to: Path to directory that will contain the downloaded model. + + Returns: a ModelInstallJob + """ + @abstractmethod def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]: """Return the ModelInstallJob(s) corresponding to the provided source.""" diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 1a08624f8e..fe932649c4 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -249,6 +249,9 @@ class ModelInstallService(ModelInstallServiceBase): self._install_jobs.append(install_job) return install_job + def download_diffusers_model(self, source: HFModelSource, download_to: Path) -> ModelInstallJob: + return self._import_from_hf(source, download_path=download_to) + def list_jobs(self) -> List[ModelInstallJob]: # noqa D102 return self._install_jobs @@ -641,7 +644,12 @@ class ModelInstallService(ModelInstallServiceBase): inplace=source.inplace or False, ) - def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob: + def _import_from_hf( + self, + source: HFModelSource, + config: Optional[Dict[str, Any]] = None, + download_path: Optional[Path] = None, + ) -> ModelInstallJob: # Add user's cached access token to HuggingFace requests source.access_token = source.access_token or HfFolder.get_token() if not source.access_token: @@ -660,9 +668,14 @@ class ModelInstallService(ModelInstallServiceBase): config=config, remote_files=remote_files, metadata=metadata, + download_path=download_path, ) - def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob: + def _import_from_url( + self, + source: URLModelSource, + config: Optional[Dict[str, Any]], + ) -> ModelInstallJob: # URLs from HuggingFace will be handled specially metadata = None fetcher = None @@ -676,6 +689,7 @@ class ModelInstallService(ModelInstallServiceBase): self._logger.debug(f"metadata={metadata}") if metadata and isinstance(metadata, ModelMetadataWithFiles): remote_files = metadata.download_urls(session=self._session) + print(remote_files) else: remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)] return self._import_remote_model( @@ -691,13 +705,14 @@ class ModelInstallService(ModelInstallServiceBase): remote_files: List[RemoteModelFile], metadata: Optional[AnyModelRepoMetadata], config: Optional[Dict[str, Any]], + download_path: Optional[Path] = None, # if defined, download only - don't install! ) -> ModelInstallJob: # TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up. # Currently the tmpdir isn't automatically removed at exit because it is # being held in a daemon thread. if len(remote_files) == 0: raise ValueError(f"{source}: No downloadable files found") - tmpdir = Path( + destdir = download_path or Path( mkdtemp( dir=self._app_config.models_path, prefix=TMPDIR_PREFIX, @@ -708,7 +723,7 @@ class ModelInstallService(ModelInstallServiceBase): source=source, config_in=config or {}, source_metadata=metadata, - local_path=tmpdir, # local path may change once the download has started due to content-disposition handling + local_path=destdir, # local path may change once the download has started due to content-disposition handling bytes=0, total_bytes=0, ) @@ -722,9 +737,10 @@ class ModelInstallService(ModelInstallServiceBase): root = Path(".") subfolder = Path(".") - # we remember the path up to the top of the tmpdir so that it may be + # we remember the path up to the top of the destdir so that it may be # removed safely at the end of the install process. - install_job._install_tmpdir = tmpdir + install_job._install_tmpdir = destdir + install_job._do_install = download_path is None assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below files_string = "file" if len(remote_files) == 1 else "file" @@ -736,7 +752,7 @@ class ModelInstallService(ModelInstallServiceBase): self._logger.debug(f"Downloading {url} => {path}") install_job.total_bytes += model_file.size assert hasattr(source, "access_token") - dest = tmpdir / path.parent + dest = destdir / path.parent dest.mkdir(parents=True, exist_ok=True) download_job = DownloadJob( source=url, @@ -805,7 +821,8 @@ class ModelInstallService(ModelInstallServiceBase): # are there any more active jobs left in this task? if install_job.downloading and all(x.complete for x in install_job.download_parts): self._signal_job_downloads_done(install_job) - self._put_in_queue(install_job) + if install_job._do_install: + self._put_in_queue(install_job) # Let other threads know that the number of downloads has changed self._download_cache.pop(download_job.source, None) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 9e30d96016..e3c99c5644 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -115,7 +115,7 @@ class SchedulerPredictionType(str, Enum): class ModelRepoVariant(str, Enum): """Various hugging face variants on the diffusers format.""" - Default = "" # model files without "fp16" or other qualifier - empty str + Default = "" # model files without "fp16" or other qualifier FP16 = "fp16" FP32 = "fp32" ONNX = "onnx" diff --git a/invokeai/backend/model_manager/metadata/fetch/huggingface.py b/invokeai/backend/model_manager/metadata/fetch/huggingface.py index 4e3625fdbe..ab78b3e064 100644 --- a/invokeai/backend/model_manager/metadata/fetch/huggingface.py +++ b/invokeai/backend/model_manager/metadata/fetch/huggingface.py @@ -83,7 +83,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase): assert s.size is not None files.append( RemoteModelFile( - url=hf_hub_url(id, s.rfilename, revision=variant), + url=hf_hub_url(id, s.rfilename, revision=variant or "main"), path=Path(name, s.rfilename), size=s.size, sha256=s.lfs.get("sha256") if s.lfs else None, diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index c755d3c491..ba84455240 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -14,6 +14,7 @@ from pydantic_core import Url from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.model_install import ( + HFModelSource, InstallStatus, LocalModelSource, ModelInstallJob, @@ -21,7 +22,13 @@ from invokeai.app.services.model_install import ( URLModelSource, ) from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException -from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType +from invokeai.backend.model_manager.config import ( + BaseModelType, + InvalidModelConfigException, + ModelFormat, + ModelRepoVariant, + ModelType, +) from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 OS = platform.uname().system @@ -247,7 +254,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: @pytest.mark.timeout(timeout=20, method="thread") -def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: +def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo")) bus = mm2_installer.event_bus @@ -278,6 +285,17 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co } +@pytest.mark.timeout(timeout=20, method="thread") +def test_huggingface_download(mm2_installer: ModelInstallServiceBase, tmp_path: Path) -> None: + source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default) + job = mm2_installer.download_diffusers_model(source, tmp_path) + mm2_installer.wait_for_installs(timeout=5) + print(job.local_path) + assert job.status == InstallStatus.DOWNLOADS_DONE + assert (tmp_path / "sdxl-turbo").exists() + assert (tmp_path / "sdxl-turbo" / "model_index.json").exists() + + def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://test.com/missing_model.safetensors")) job = mm2_installer.import_model(source) @@ -327,7 +345,7 @@ def test_other_error_during_install( }, ], ) -@pytest.mark.timeout(timeout=40, method="thread") +@pytest.mark.timeout(timeout=20, method="thread") def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]): """Test whether or not type is respected on configs when passed to heuristic import.""" assert "name" in model_params and "type" in model_params