From 8bc1ca046c3d7a843b2ab60e5bc6e421231f3709 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 24 Sep 2023 10:08:21 -0400 Subject: [PATCH] allow priority to be set at install job submission time --- invokeai/app/api/routers/models.py | 15 ++++++++++++++- invokeai/app/services/model_manager_service.py | 6 ++++++ invokeai/backend/model_manager/download/base.py | 4 +++- invokeai/backend/model_manager/download/queue.py | 8 ++++---- invokeai/backend/model_manager/install.py | 8 ++++++-- 5 files changed, 33 insertions(+), 8 deletions(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 45f746c590..cc2fdabe09 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -134,6 +134,10 @@ async def import_model( prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body( description="Prediction type for SDv2 checkpoint files", default="v_prediction" ), + priority: Optional[int] = Body( + description="Which import jobs run first. Lower values run before higher ones.", + default=10, + ), ) -> ModelImportStatus: """ Add a model using its local path, repo_id, or remote URL. @@ -142,6 +146,13 @@ async def import_model( series of background threads. The return object has a `job_id` property that can be used to control the download job. + The priority controls which import jobs run first. Lower values run before + higher ones. + + The prediction_type applies to SDv2 models only and can be one of + "v_prediction", "epsilon", or "sample". Default if not provided is + "v_prediction". + Listen on the event bus for a series of `model_event` events with an `id` matching the returned job id to get the progress, completion status, errors, and information on the model that was installed. 
@@ -149,7 +160,9 @@ async def import_model( logger = ApiDependencies.invoker.services.logger try: result = ApiDependencies.invoker.services.model_manager.install_model( - location, model_attributes={"prediction_type": SchedulerPredictionType(prediction_type)} + location, + model_attributes={"prediction_type": SchedulerPredictionType(prediction_type)}, + priority=priority, ) return ModelImportStatus( job_id=result.id, diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 05c10d38bb..7343104cc1 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -198,11 +198,13 @@ class ModelManagerServiceBase(ABC): self, source: Union[str, Path, AnyHttpUrl], model_attributes: Optional[Dict[str, Any]] = None, + priority: Optional[int] = 10, ) -> ModelInstallJob: """Import a path, repo_id or URL. Returns an ModelInstallJob. :param model_attributes: Additional attributes to supplement/override the model information gained from automated probing. + :param priority: Queue priority. Lower values run before higher ones. Typical usage: job = model_manager.install( @@ -413,6 +415,7 @@ class ModelManagerService(ModelManagerServiceBase): self, source: Union[str, Path, AnyHttpUrl], model_attributes: Optional[Dict[str, Any]] = None, + priority: Optional[int] = 10, ) -> ModelInstallJob: """ Add a model using a path, repo_id or URL. @@ -420,6 +423,8 @@ class ModelManagerService(ModelManagerServiceBase): :param model_attributes: Dictionary of ModelConfigBase fields to attach to the model. When installing a URL or repo_id, some metadata values, such as `tags` will be automagically added. + :param priority: Queue priority for this install job. Lower value jobs + will run before higher value ones. 
""" self.logger.debug(f"add model {source}") variant = "fp16" if self._loader.precision == "float16" else None @@ -427,6 +432,7 @@ class ModelManagerService(ModelManagerServiceBase): source, probe_override=model_attributes, variant=variant, + priority=priority, ) def list_install_jobs(self) -> List[ModelInstallJob]: diff --git a/invokeai/backend/model_manager/download/base.py b/invokeai/backend/model_manager/download/base.py index 07e90261bd..b46e285c1d 100644 --- a/invokeai/backend/model_manager/download/base.py +++ b/invokeai/backend/model_manager/download/base.py @@ -135,8 +135,9 @@ class DownloadQueueBase(ABC): self, source: Union[str, Path, AnyHttpUrl], destdir: Path, - filename: Optional[Path] = None, start: bool = True, + priority: int = 10, + filename: Optional[Path] = None, variant: Optional[str] = None, access_token: Optional[str] = None, event_handlers: Optional[List[DownloadEventHandler]] = None, @@ -146,6 +147,7 @@ class DownloadQueueBase(ABC): :param source: Source of the download - URL, repo_id or Path :param destdir: Directory to download into. + :param priority: Initial priority for this job [10] :param filename: Optional name of file, if not provided will use the content-disposition field to assign the name. :param start: Immediately start job [True] diff --git a/invokeai/backend/model_manager/download/queue.py b/invokeai/backend/model_manager/download/queue.py index 21b444080d..c51473c182 100644 --- a/invokeai/backend/model_manager/download/queue.py +++ b/invokeai/backend/model_manager/download/queue.py @@ -116,15 +116,14 @@ class DownloadQueue(DownloadQueueBase): self, source: Union[str, Path, AnyHttpUrl], destdir: Path, - filename: Optional[Path] = None, start: bool = True, + priority: int = 10, + filename: Optional[Path] = None, variant: Optional[str] = None, access_token: Optional[str] = None, event_handlers: Optional[List[DownloadEventHandler]] = None, ) -> DownloadJobBase: - """ - Create a download job and return its ID. 
- """ + """Create a download job and return its ID.""" kwargs = dict() if Path(source).exists(): @@ -142,6 +141,7 @@ class DownloadQueue(DownloadQueueBase): destination=Path(destdir) / (filename or "."), access_token=access_token, event_handlers=(event_handlers or self._event_handlers), + priority=priority, **kwargs, ) diff --git a/invokeai/backend/model_manager/install.py b/invokeai/backend/model_manager/install.py index d231c4e6c0..3903ceae69 100644 --- a/invokeai/backend/model_manager/install.py +++ b/invokeai/backend/model_manager/install.py @@ -492,10 +492,11 @@ class ModelInstall(ModelInstallBase): probe_override: Optional[Dict[str, Any]] = None, metadata: Optional[ModelSourceMetadata] = None, access_token: Optional[str] = None, + priority: Optional[int] = 10, ) -> DownloadJobBase: # noqa D102 queue = self._download_queue - job = self._make_download_job(source, variant, access_token) + job = self._make_download_job(source, variant=variant, access_token=access_token, priority=priority) handler = ( self._complete_registration_handler if inplace and Path(source).exists() @@ -581,6 +582,7 @@ class ModelInstall(ModelInstallBase): source: Union[str, Path, AnyHttpUrl], variant: Optional[str] = None, access_token: Optional[str] = None, + priority: Optional[int] = 10, ) -> DownloadJobBase: # Clean up a common source of error. Doesn't work with Paths. if isinstance(source, str): @@ -606,7 +608,9 @@ class ModelInstall(ModelInstallBase): kwargs = {} else: raise ValueError(f"'{source}' is not recognized as a local file, directory, repo_id or URL") - return cls(source=source, destination=Path(self._tmpdir.name), access_token=access_token, **kwargs) + return cls( + source=source, destination=Path(self._tmpdir.name), access_token=access_token, priority=priority, **kwargs + ) def wait_for_installs(self) -> Dict[str, str]: # noqa D102 self._download_queue.join()