From e079cc9f07303422baec0c607c22ef832dbff030 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 11 Oct 2023 22:42:07 -0400 Subject: [PATCH] add back source URL validation to download job hierarchy --- invokeai/app/api/dependencies.py | 2 +- invokeai/app/services/config/invokeai_config.py | 6 +++--- invokeai/app/services/model_install_service.py | 3 ++- invokeai/backend/model_manager/download/model_queue.py | 6 +++++- tests/AC_model_manager/test_model_install_service.py | 2 +- 5 files changed, 12 insertions(+), 7 deletions(-) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index b3f704af66..a05d6d0d34 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -130,7 +130,7 @@ class ApiDependencies: ) ) - download_queue = DownloadQueueService(event_bus=events, config=config) + download_queue = DownloadQueueService(event_bus=events) model_record_store = ModelRecordServiceBase.open(config, conn=db_conn, lock=lock) model_loader = ModelLoadService(config, model_record_store) model_installer = ModelInstallService(config, queue=download_queue, store=model_record_store, event_bus=events) diff --git a/invokeai/app/services/config/invokeai_config.py b/invokeai/app/services/config/invokeai_config.py index 2c51550dc6..d17bca7791 100644 --- a/invokeai/app/services/config/invokeai_config.py +++ b/invokeai/app/services/config/invokeai_config.py @@ -241,8 +241,8 @@ class InvokeAIAppConfig(InvokeAISettings): version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other") # CACHE - ram : Union[float, Literal["auto"]] = Field(default=DEFAULT_RAM_CACHE, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", ) - vram : Union[float, Literal["auto"]] = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (floating point number or 'auto')", category="Model Cache", ) + ram : float = Field(default=DEFAULT_RAM_CACHE, gt=0, description="Maximum memory amount used by model cache for rapid switching", category="Model Cache", ) + vram : float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage", category="Model Cache", ) disk : float = Field(default=DEFAULT_MAX_DISK_CACHE, ge=0, description="Maximum size (in GB) for the disk-based diffusers model conversion cache", category="Model Cache", ) lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", ) @@ -443,7 +443,7 @@ class InvokeAIAppConfig(InvokeAISettings): return self.disk @property - def vram_cache_size(self) -> Union[Literal["auto"], float]: + def vram_cache_size(self) -> float: return self.max_vram_cache_size or self.vram @property diff --git a/invokeai/app/services/model_install_service.py b/invokeai/app/services/model_install_service.py index 333fcc4055..cfb1887d4b 100644 --- a/invokeai/app/services/model_install_service.py +++ b/invokeai/app/services/model_install_service.py @@ -25,6 +25,7 @@ from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.download.model_queue import ( HTTP_RE, REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, + DownloadJobMetadataURL, DownloadJobRepoID, DownloadJobWithMetadata, ) @@ -60,7 +61,7 @@ class ModelInstallJob(DownloadJobBase): ) -class ModelInstallURLJob(DownloadJobWithMetadata, ModelInstallJob): +class ModelInstallURLJob(DownloadJobMetadataURL, ModelInstallJob): """Job for installing URLs.""" diff --git a/invokeai/backend/model_manager/download/model_queue.py b/invokeai/backend/model_manager/download/model_queue.py index 496d4ba619..270df4c798 100644 --- a/invokeai/backend/model_manager/download/model_queue.py +++ b/invokeai/backend/model_manager/download/model_queue.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, Field, parse_obj_as, validator from pydantic.networks import AnyHttpUrl from .base import DownloadEventHandler, DownloadJobBase, DownloadJobStatus, DownloadQueueBase -from .queue import HTTP_RE, DownloadJobRemoteSource, DownloadQueue +from .queue import HTTP_RE, DownloadJobRemoteSource, DownloadJobURL, DownloadQueue # regular expressions used to dispatch appropriate downloaders and metadata retrievers # endpoint for civitai get-model API @@ -40,6 +40,10 @@ class DownloadJobWithMetadata(DownloadJobRemoteSource): ) +class DownloadJobMetadataURL(DownloadJobWithMetadata, DownloadJobURL): + """DownloadJobWithMetadata with validation of the source URL.""" + + class DownloadJobRepoID(DownloadJobWithMetadata): """Download repo ids.""" diff --git a/tests/AC_model_manager/test_model_install_service.py b/tests/AC_model_manager/test_model_install_service.py index a96105a879..8f7c9b8378 100644 --- a/tests/AC_model_manager/test_model_install_service.py +++ b/tests/AC_model_manager/test_model_install_service.py @@ -50,7 +50,7 @@ def test_install(datadir: Path): ) event_bus = DummyEventService() - mm_store = ModelRecordServiceBase.get_impl(config) + mm_store = ModelRecordServiceBase.open(config) mm_load = ModelLoadService(config, mm_store) mm_install = ModelInstallService(config=config, store=mm_store, event_bus=event_bus)