diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index c12046293c..d53198b98e 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -397,26 +397,25 @@ In the event you wish to create a new installer, you may use the following initialization pattern: ``` -from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.config import get_config from invokeai.app.services.model_records import ModelRecordServiceSQL from invokeai.app.services.model_install import ModelInstallService from invokeai.app.services.download import DownloadQueueService -from invokeai.app.services.shared.sqlite import SqliteDatabase +from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.backend.util.logging import InvokeAILogger -config = InvokeAIAppConfig.get_config() -config.parse_args() +config = get_config() logger = InvokeAILogger.get_logger(config=config) -db = SqliteDatabase(config, logger) +db = SqliteDatabase(config.db_path, logger) record_store = ModelRecordServiceSQL(db) queue = DownloadQueueService() queue.start() -installer = ModelInstallService(app_config=config, +installer = ModelInstallService(app_config=config, record_store=record_store, - download_queue=queue - ) + download_queue=queue + ) installer.start() ``` diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 68cf9591e0..b622c8dade 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -466,17 +466,14 @@ class ModelInstallServiceBase(ABC): """ @abstractmethod - def download_and_cache_ckpt( + def download_and_cache_model( self, - source: str | AnyHttpUrl, - access_token: Optional[str] = None, - timeout: int = 0, + source: str, ) -> Path: """ Download the model file located at source to the models cache and return its Path. - :param source: A Url or a string that can be converted into one. - :param access_token: Optional access token to access restricted resources. + :param source: A string representing a URL or repo_id. The model file will be downloaded into the system-wide model cache (`models/.cache`) if it isn't already there. Note that the model cache diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 1d77b2c6e1..a6bb7ad10d 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -9,7 +9,7 @@ from pathlib import Path from queue import Empty, Queue from shutil import copyfile, copytree, move, rmtree from tempfile import mkdtemp -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch import yaml @@ -18,7 +18,7 @@ from pydantic.networks import AnyHttpUrl from requests import Session from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob, TqdmProgress +from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase @@ -208,26 +208,12 @@ class ModelInstallService(ModelInstallServiceBase): access_token: Optional[str] = None, inplace: Optional[bool] = False, ) -> ModelInstallJob: - variants = "|".join(ModelRepoVariant.__members__.values()) - hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$" - source_obj: Optional[StringLikeSource] = None - - if Path(source).exists(): # A local file or directory - source_obj = LocalModelSource(path=Path(source), inplace=inplace) - elif match := re.match(hf_repoid_re, source): - source_obj = HFModelSource( - repo_id=match.group(1), - variant=match.group(2) if match.group(2) else None, # pass None rather than '' - subfolder=Path(match.group(3)) if match.group(3) else None, - access_token=access_token, - ) - elif re.match(r"^https?://[^/]+", source): - source_obj = URLModelSource( - url=AnyHttpUrl(source), - access_token=access_token, - ) - else: - raise ValueError(f"Unsupported model source: '{source}'") + """Install a model using pattern matching to infer the type of source.""" + source_obj = self._guess_source(source) + if isinstance(source_obj, LocalModelSource): + source_obj.inplace = inplace + elif isinstance(source_obj, HFModelSource) or isinstance(source_obj, URLModelSource): + source_obj.access_token = access_token return self.import_model(source_obj, config) def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102 @@ -383,37 +369,86 @@ class ModelInstallService(ModelInstallServiceBase): escaped_source = slugify(str(source)) return app_config.download_cache_path / escaped_source - def download_and_cache_ckpt( + def download_and_cache_model( self, - source: str | AnyHttpUrl, - access_token: Optional[str] = None, - timeout: int = 0, + source: str, ) -> Path: """Download the model file located at source to the models cache and return its Path.""" - model_path = self._download_cache_path(source, self._app_config) + model_path = self._download_cache_path(str(source), self._app_config) - # We expect the cache directory to contain one and only one downloaded file. + # We expect the cache directory to contain one and only one downloaded file or directory. # We don't know the file's name in advance, as it is set by the download # content-disposition header. if model_path.exists(): - contents = [x for x in model_path.iterdir() if x.is_file()] + contents: List[Path] = list(model_path.iterdir()) if len(contents) > 0: return contents[0] model_path.mkdir(parents=True, exist_ok=True) - job = self._download_queue.download( - source=AnyHttpUrl(str(source)), + model_source = self._guess_source(source) + remote_files, _ = self._remote_files_from_source(model_source) + job = self._download_queue.multifile_download( + parts=remote_files, dest=model_path, - access_token=access_token, - on_progress=TqdmProgress().update, ) - self._download_queue.wait_for_job(job, timeout) + files_string = "file" if len(remote_files) == 1 else "file" + self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})") + self._download_queue.wait_for_job(job) if job.complete: assert job.download_path is not None return job.download_path else: raise Exception(job.error) + def _remote_files_from_source( + self, source: ModelSource + ) -> Tuple[List[RemoteModelFile], Optional[AnyModelRepoMetadata]]: + metadata = None + if isinstance(source, HFModelSource): + metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant) + assert isinstance(metadata, ModelMetadataWithFiles) + return metadata.download_urls( + variant=source.variant or self._guess_variant(), + subfolder=source.subfolder, + session=self._session, + ), metadata + + if isinstance(source, URLModelSource): + try: + fetcher = self.get_fetcher_from_url(str(source.url)) + kwargs: dict[str, Any] = {"session": self._session} + metadata = fetcher(**kwargs).from_url(source.url) + assert isinstance(metadata, ModelMetadataWithFiles) + return metadata.download_urls(session=self._session), metadata + except ValueError: + pass + + return [RemoteModelFile(url=source.url, path=Path("."), size=0)], None + + raise Exception(f"No files associated with {source}") + + def _guess_source(self, source: str) -> ModelSource: + """Turn a source string into a ModelSource object.""" + variants = "|".join(ModelRepoVariant.__members__.values()) + hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$" + source_obj: Optional[StringLikeSource] = None + + if Path(source).exists(): # A local file or directory + source_obj = LocalModelSource(path=Path(source)) + elif match := re.match(hf_repoid_re, source): + source_obj = HFModelSource( + repo_id=match.group(1), + variant=match.group(2) if match.group(2) else None, # pass None rather than '' + subfolder=Path(match.group(3)) if match.group(3) else None, + ) + elif re.match(r"^https?://[^/]+", source): + source_obj = URLModelSource( + url=AnyHttpUrl(source), + ) + else: + raise ValueError(f"Unsupported model source: '{source}'") + return source_obj + # -------------------------------------------------------------------------------------------- # Internal functions that manage the installer threads # -------------------------------------------------------------------------------------------- @@ -650,18 +685,9 @@ class ModelInstallService(ModelInstallServiceBase): config: Optional[Dict[str, Any]] = None, ) -> ModelInstallJob: # Add user's cached access token to HuggingFace requests - source.access_token = source.access_token or HfFolder.get_token() - if not source.access_token: - self._logger.info("No HuggingFace access token present; some models may not be downloadable.") - - metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant) - assert isinstance(metadata, ModelMetadataWithFiles) - remote_files = metadata.download_urls( - variant=source.variant or self._guess_variant(), - subfolder=source.subfolder, - session=self._session, - ) - + if source.access_token is None: + source.access_token = HfFolder.get_token() + remote_files, metadata = self._remote_files_from_source(source) return self._import_remote_model( source=source, config=config, @@ -674,21 +700,7 @@ class ModelInstallService(ModelInstallServiceBase): source: URLModelSource, config: Optional[Dict[str, Any]], ) -> ModelInstallJob: - # URLs from HuggingFace will be handled specially - metadata = None - fetcher = None - try: - fetcher = self.get_fetcher_from_url(str(source.url)) - except ValueError: - pass - kwargs: dict[str, Any] = {"session": self._session} - if fetcher is not None: - metadata = fetcher(**kwargs).from_url(source.url) - self._logger.debug(f"metadata={metadata}") - if metadata and isinstance(metadata, ModelMetadataWithFiles): - remote_files = metadata.download_urls(session=self._session) - else: - remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)] + remote_files, metadata = self._remote_files_from_source(source) return self._import_remote_model( source=source, config=config, @@ -733,26 +745,17 @@ class ModelInstallService(ModelInstallServiceBase): root = Path(".") subfolder = Path(".") - # we remember the path up to the top of the destdir so that it may be - # removed safely at the end of the install process. - install_job._install_tmpdir = destdir - parts: List[RemoteModelFile] = [] for model_file in remote_files: assert install_job.total_bytes is not None assert model_file.size is not None install_job.total_bytes += model_file.size parts.append(RemoteModelFile(url=model_file.url, path=model_file.path.relative_to(subfolder))) - multifile_job = self._download_queue.multifile_download( + multifile_job = self._multifile_download( parts=parts, dest=destdir, access_token=source.access_token, - submit_job=False, - on_start=self._download_started_callback, - on_progress=self._download_progress_callback, - on_complete=self._download_complete_callback, - on_error=self._download_error_callback, - on_cancelled=self._download_cancelled_callback, + submit_job=False, # Important! Don't submit the job until we have set our _download_cache dict ) self._download_cache[multifile_job.id] = install_job install_job._download_job = multifile_job @@ -772,6 +775,21 @@ class ModelInstallService(ModelInstallServiceBase): size += sum(self._stat_size(Path(root, x)) for x in files) return size + def _multifile_download( + self, parts: List[RemoteModelFile], dest: Path, access_token: Optional[str] = None, submit_job: bool = True + ) -> MultiFileDownloadJob: + return self._download_queue.multifile_download( + parts=parts, + dest=dest, + access_token=access_token, + submit_job=submit_job, + on_start=self._download_started_callback, + on_progress=self._download_progress_callback, + on_complete=self._download_complete_callback, + on_error=self._download_error_callback, + on_cancelled=self._download_cancelled_callback, + ) + # ------------------------------------------------------------------ # Callbacks are executed by the download queue in a separate thread # ------------------------------------------------------------------ @@ -875,10 +893,9 @@ class ModelInstallService(ModelInstallServiceBase): assert job.local_path is not None assert job.config_out is not None key = job.config_out.key - self._event_bus.emit_model_install_completed(source=str(job.source), - key=key, - id=job.id, - total_bytes=job.bytes) + self._event_bus.emit_model_install_completed( + source=str(job.source), key=key, id=job.id, total_bytes=job.bytes + ) def _signal_job_errored(self, job: ModelInstallJob) -> None: self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}") diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 094ade6383..57531cf3c1 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -12,15 +12,13 @@ from pydantic import BaseModel, Field from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.util.model_exclude_null import BaseModelExcludeNull -from invokeai.backend.model_manager import ( +from invokeai.backend.model_manager.config import ( AnyModelConfig, BaseModelType, - ModelFormat, - ModelType, -) -from invokeai.backend.model_manager.config import ( ControlAdapterDefaultSettings, MainModelDefaultSettings, + ModelFormat, + ModelType, ModelVariantType, SchedulerPredictionType, ) diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index f73b827534..ca8616238f 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -222,7 +222,7 @@ def test_delete_register( store.get_model(key) -@pytest.mark.timeout(timeout=20, method="thread") +@pytest.mark.timeout(timeout=10, method="thread") def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors")) @@ -253,7 +253,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: ] -@pytest.mark.timeout(timeout=20, method="thread") +@pytest.mark.timeout(timeout=10, method="thread") def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo")) @@ -285,9 +285,8 @@ def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_con } -@pytest.mark.timeout(timeout=20, method="thread") +@pytest.mark.timeout(timeout=10, method="thread") def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: - # TODO: Test subfolder download source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default) bus = mm2_installer.event_bus @@ -323,6 +322,7 @@ def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_con assert job.total_bytes == completed_events[0].payload["total_bytes"] assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].payload["parts"]) + def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://test.com/missing_model.safetensors")) job = mm2_installer.import_model(source) @@ -371,7 +371,7 @@ def test_other_error_during_install( }, ], ) -@pytest.mark.timeout(timeout=20, method="thread") +@pytest.mark.timeout(timeout=10, method="thread") def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]): """Test whether or not type is respected on configs when passed to heuristic import.""" assert "name" in model_params and "type" in model_params @@ -387,7 +387,7 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode } assert "repo_id" in model_params install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1) - mm2_installer.wait_for_job(install_job1, timeout=20) + mm2_installer.wait_for_job(install_job1, timeout=10) if model_params["type"] != "embedding": assert install_job1.errored assert install_job1.error_type == "InvalidModelConfigException" @@ -396,6 +396,6 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2) - mm2_installer.wait_for_job(install_job2, timeout=20) + mm2_installer.wait_for_job(install_job2, timeout=10) assert install_job2.complete assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out