From fbede84405a30c55d9d02e8a61bc4ca2d0c20458 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 22 Dec 2023 12:35:57 -0500 Subject: [PATCH] [feature] Download Queue (#5225) * add base definition of download manager * basic functionality working * add unit tests for download queue * add documentation and FastAPI route * fix docs * add missing test dependency; fix import ordering * fix file path length checking on windows * fix ruff check error * move release() into the __del__ method * disable testing of stderr messages due to issues with pytest capsys fixture * fix unsorted imports * harmonized implementation of start() and stop() calls in download and & install modules * Update invokeai/app/services/download/download_base.py Co-authored-by: Ryan Dick * replace test datadir fixture with tmp_path * replace DownloadJobBase->DownloadJob in download manager documentation * make source and dest arguments to download_queue.download() an AnyHttpURL and Path respectively * fix pydantic typecheck errors in the download unit test * ruff formatting * add "job cancelled" as an event rather than an exception * fix ruff errors * Update invokeai/app/services/download/download_default.py Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com> * use threading.Event to stop service worker threads; handle unfinished job edge cases * remove dangling STOP job definition * fix ruff complaint * fix ruff check again * avoid race condition when start() and stop() are called simultaneously from different threads * avoid race condition in stop() when a job becomes active while shutting down --------- Co-authored-by: Lincoln Stein Co-authored-by: Ryan Dick Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Co-authored-by: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com> --- docs/contributing/DOWNLOAD_QUEUE.md | 277 ++++++++++++ invokeai/app/api/dependencies.py | 3 + invokeai/app/api/routers/download_queue.py | 111 +++++ invokeai/app/api_app.py | 2 + invokeai/app/services/download/__init__.py | 12 + .../app/services/download/download_base.py | 217 +++++++++ .../app/services/download/download_default.py | 418 ++++++++++++++++++ invokeai/app/services/events/events_base.py | 81 ++++ invokeai/app/services/invocation_services.py | 4 + .../model_install/model_install_base.py | 9 +- .../model_install/model_install_default.py | 10 +- mkdocs.yml | 2 + pyproject.toml | 1 + tests/aa_nodes/test_graph_execution_state.py | 1 + tests/aa_nodes/test_invoker.py | 1 + .../services/download/test_download_queue.py | 223 ++++++++++ .../model_install/test_model_install.py | 4 +- 17 files changed, 1367 insertions(+), 9 deletions(-) create mode 100644 docs/contributing/DOWNLOAD_QUEUE.md create mode 100644 invokeai/app/api/routers/download_queue.py create mode 100644 invokeai/app/services/download/__init__.py create mode 100644 invokeai/app/services/download/download_base.py create mode 100644 invokeai/app/services/download/download_default.py create mode 100644 tests/app/services/download/test_download_queue.py diff --git a/docs/contributing/DOWNLOAD_QUEUE.md b/docs/contributing/DOWNLOAD_QUEUE.md new file mode 100644 index 0000000000..d43c670d2c --- /dev/null +++ b/docs/contributing/DOWNLOAD_QUEUE.md @@ -0,0 +1,277 @@ +# The InvokeAI Download Queue + +The DownloadQueueService provides a multithreaded parallel download +queue for arbitrary URLs, with queue prioritization, event handling, +and restart capabilities. 
+ +## Simple Example + +``` +from invokeai.app.services.download import DownloadQueueService, TqdmProgress + +download_queue = DownloadQueueService() +for url in ['https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/a-painting-of-a-fire.png?raw=true', + 'https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/birdhouse.png?raw=true', + 'https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/missing.png', + 'https://civitai.com/api/download/models/152309?type=Model&format=SafeTensor', + ]: + + # urls start downloading as soon as download() is called + download_queue.download(source=url, + dest='/tmp/downloads', + on_progress=TqdmProgress().update + ) + +download_queue.join() # wait for all downloads to finish +for job in download_queue.list_jobs(): + print(job.model_dump_json(exclude_none=True, indent=4),"\n") +``` + +Output: + +``` +{ + "source": "https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/a-painting-of-a-fire.png?raw=true", + "dest": "/tmp/downloads", + "id": 0, + "priority": 10, + "status": "completed", + "download_path": "/tmp/downloads/a-painting-of-a-fire.png", + "job_started": "2023-12-04T05:34:41.742174", + "job_ended": "2023-12-04T05:34:42.592035", + "bytes": 666734, + "total_bytes": 666734 +} + +{ + "source": "https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/birdhouse.png?raw=true", + "dest": "/tmp/downloads", + "id": 1, + "priority": 10, + "status": "completed", + "download_path": "/tmp/downloads/birdhouse.png", + "job_started": "2023-12-04T05:34:41.741975", + "job_ended": "2023-12-04T05:34:42.652841", + "bytes": 774949, + "total_bytes": 774949 +} + +{ + "source": "https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/missing.png", + "dest": "/tmp/downloads", + "id": 2, + "priority": 10, + "status": "error", + "job_started": "2023-12-04T05:34:41.742079", + "job_ended": "2023-12-04T05:34:42.147625", + "bytes": 0, + "total_bytes": 0, + "error_type": "HTTPError(Not Found)", + "error": "Traceback (most recent call last):\n File \"/home/lstein/Projects/InvokeAI/invokeai/app/services/download/download_default.py\", line 182, in _download_next_item\n self._do_download(job)\n File \"/home/lstein/Projects/InvokeAI/invokeai/app/services/download/download_default.py\", line 206, in _do_download\n raise HTTPError(resp.reason)\nrequests.exceptions.HTTPError: Not Found\n" +} + +{ + "source": "https://civitai.com/api/download/models/152309?type=Model&format=SafeTensor", + "dest": "/tmp/downloads", + "id": 3, + "priority": 10, + "status": "completed", + "download_path": "/tmp/downloads/xl_more_art-full_v1.safetensors", + "job_started": "2023-12-04T05:34:42.147645", + "job_ended": "2023-12-04T05:34:43.735990", + "bytes": 719020768, + "total_bytes": 719020768 +} +``` + +## The API + +The default download queue is `DownloadQueueService`, an +implementation of ABC `DownloadQueueServiceBase`. It juggles multiple +background download requests and provides facilities for interrogating +and cancelling the requests. Access to a current or past download task +is mediated via `DownloadJob` objects which report the current status +of a job request + +### The Queue Object + +A default download queue is located in +`ApiDependencies.invoker.services.download_queue`. However, you can +create additional instances if you need to isolate your queue from the +main one. 
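+
+Note that a queue must be started by calling its `start()` method
+before it will accept download requests; until then, calls to
+`download()` raise a `ServiceInactiveException`. Call `stop()` to shut
+the worker threads down gracefully. A minimal construction, where
+`events` is assumed to be the application-wide `EventServiceBase`
+instance, looks like this: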
+
+```
+queue = DownloadQueueService(event_bus=events)
+```
+
+`DownloadQueueService()` takes three optional arguments:
+
+| **Argument** | **Type** | **Default** | **Description** |
+|----------------|-----------------|---------------|-----------------|
+| `max_parallel_dl` | int | 5 | Maximum number of simultaneous downloads allowed |
+| `event_bus` | EventServiceBase | None | System-wide FastAPI event bus for reporting download events |
+| `requests_session` | requests.sessions.Session | None | An alternative requests Session object to use for the download |
+
+`max_parallel_dl` specifies how many download jobs are allowed to run
+simultaneously. Each will run in a different thread of execution.
+
+`event_bus` is an EventServiceBase, typically the one created at
+InvokeAI startup. If present, download events are periodically emitted
+on this bus to allow clients to follow download progress.
+
+`requests_session` is a `requests` library Session object. It is used
+mainly for unit testing.
+
+### The Job object
+
+The queue operates on a series of download job objects. These objects
+specify the source and destination of the download, and keep track of
+the progress of the download.
+
+The only job type currently implemented is `DownloadJob`, a pydantic
+object with the following fields:
+
+| **Field** | **Type** | **Default** | **Description** |
+|----------------|-----------------|---------------|-----------------|
+| _Fields passed in at job creation time_ |
+| `source` | AnyHttpUrl | | Where to download from |
+| `dest` | Path | | Where to download to |
+| `access_token` | str | | [optional] string containing authentication token for access |
+| `on_start` | Callable | | [optional] callback when the download starts |
+| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
+| `on_complete` | Callable | | [optional] callback called after successful download completion |
+| `on_cancelled` | Callable | | [optional] callback called after the job is cancelled |
+| `on_error` | Callable | | [optional] callback called after an error occurs |
+| `id` | int | auto assigned | Job ID, an integer >= 0 |
+| `priority` | int | 10 | Job priority. Lower values run before higher values |
+| | | | |
+| _Fields updated over the course of the download task_ |
+| `status` | DownloadJobStatus | | Status code |
+| `download_path` | Path | | Path to the location of the downloaded file |
+| `job_started` | str | | ISO 8601 timestamp for when the job started running |
+| `job_ended` | str | | ISO 8601 timestamp for when the job completed or errored out |
+| `bytes` | int | 0 | Bytes downloaded so far |
+| `total_bytes` | int | 0 | Total size of the file at the remote site |
+| `error_type` | str | | String version of the exception that caused an error during download |
+| `error` | str | | String version of the traceback associated with an error |
+| `cancelled` | bool | False | Set to true if the job was cancelled by the caller |
+
+When you create a job, you can assign it a `priority`. If multiple
+jobs are queued, the job with the lowest priority value runs first.
+
+Every job has a `source` and a `dest`. `source` is a
+`pydantic.networks.AnyHttpUrl` object. The `dest` is a path on the
+local filesystem that specifies the destination for the downloaded
+object. Its semantics are described below.
+
+When the job is submitted, it is assigned a numeric `id`. The id can
+then be used to fetch the job object from the queue.
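+
+For example, here is a sketch of round-tripping a job through its ID
+(it assumes the started `download_queue` instance from the Simple
+Example above):
+
+```
+from pathlib import Path
+
+job = download_queue.download(
+    source='https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/birdhouse.png?raw=true',
+    dest=Path('/tmp/downloads'),
+)
+assert job.id >= 0                          # numeric ID assigned at submission time
+fetched = download_queue.id_to_job(job.id)  # raises UnknownJobIDException for unknown IDs
+assert fetched is job                       # the queue hands back the same DownloadJob object
+```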
+The `status` field is updated by the queue to indicate where the job
+is in its lifecycle. Values are defined in the string enum
+`DownloadJobStatus`, a symbol importable from
+`invokeai.app.services.download`. Possible values are:
+
+| **Value** | **String Value** | **Description** |
+|--------------|---------------------|-------------------|
+| `WAITING` | waiting | Job is on the queue but not yet running |
+| `RUNNING` | running | The download is in progress |
+| `COMPLETED` | completed | Job has finished its work without an error |
+| `CANCELLED` | cancelled | Job was cancelled at the caller's request |
+| `ERROR` | error | Job encountered an error and will not run again |
+
+`job_started` and `job_ended` indicate when the job was started and
+when it completed or errored out, as ISO 8601 timestamp strings.
+
+In case of an error, the job's status will be set to
+`DownloadJobStatus.ERROR`, the text of the exception that caused the
+error will be placed in the `error_type` field, and the traceback that
+led to the error will be in `error`.
+
+A cancelled job will have status `DownloadJobStatus.CANCELLED` and its
+`cancelled` property set to True. Any partially-downloaded file is
+removed from disk as part of cleanup.
+
+### Callbacks
+
+Download jobs can be associated with a series of callbacks, each with
+the signature `Callable[["DownloadJob"], None]`. The callbacks are
+assigned using the optional arguments `on_start`, `on_progress`,
+`on_complete`, `on_cancelled` and `on_error`. When the corresponding
+event occurs, the callback will be invoked and passed the job. The
+callback will be run in a `try:` context in the same thread as the
+download job. Any exceptions that occur during execution of the
+callback will be caught and converted into a log error message,
+thereby allowing the download to continue.
+
+#### `TqdmProgress`
+
+The `invokeai.app.services.download.download_default` module defines a
+class named `TqdmProgress` which can be used as an `on_progress`
+handler to display a completion bar in the console. Use as follows:
+
+```
+from invokeai.app.services.download import TqdmProgress
+
+download_queue.download(source='http://some.server.somewhere/some_file',
+                        dest='/tmp/downloads',
+                        on_progress=TqdmProgress().update
+                        )
+```
+
+### Events
+
+If the queue was initialized with the InvokeAI event bus (the case
+when using `ApiDependencies.invoker.services.download_queue`), then
+download events will also be issued on the bus. The events are:
+
+* `download_started` -- This is issued when a job is taken off the
+queue and a request is made to the remote server for the URL headers,
+but before any data has been downloaded. The event payload will
+contain the keys `source` and `download_path`. The latter contains the
+path that the URL will be downloaded to.
+
+* `download_progress` -- This is issued periodically as the download
+runs. The payload contains the keys `source`, `download_path`,
+`current_bytes` and `total_bytes`. The latter two fields can be used
+to display the percent complete.
+
+* `download_complete` -- This is issued when the download completes
+successfully. The payload contains the keys `source`, `download_path`
+and `total_bytes`.
+
+* `download_cancelled` -- This is issued when the download is
+cancelled at the caller's request. The payload contains the key
+`source`.
+
+* `download_error` -- This is issued when the download stops because
+of an error condition. The payload contains the keys `source`,
+`error_type` and `error`. `error_type` is the text representation of
+the exception, and `error` is a traceback showing where the error
+occurred.
+
+### Job control
+
+To create a job call the queue's `download()` method.
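+
+For example, here is a sketch of cancelling a single in-flight job
+(the URL and destination below are placeholders):
+
+```
+job = queue.download(source='http://some.server.somewhere/some_file',
+                     dest='/tmp/downloads',
+                     on_cancelled=lambda job: print(f'{job.source}: cancelled'),
+                     )
+queue.cancel_job(job)  # request cancellation; the worker thread honors it
+queue.join()           # wait for the queue to finish processing the job
+
+# a very fast download may still complete before the cancellation is seen;
+# otherwise the job ends up in the CANCELLED state
+assert job.status == DownloadJobStatus.CANCELLED
+assert job.cancelled
+```
+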
+You can list all jobs using `list_jobs()`, fetch a single job by its
+ID with `id_to_job()`, cancel a running job with `cancel_job()`,
+cancel all running jobs with `cancel_all_jobs()`, and wait for all
+jobs to finish with `join()`.
+
+#### job = queue.download(source, dest, priority, access_token)
+
+Create a new download job and put it on the queue, returning the
+DownloadJob object.
+
+#### jobs = queue.list_jobs()
+
+Return a list of all active and inactive `DownloadJob`s.
+
+#### job = queue.id_to_job(id)
+
+Return the job corresponding to the given ID, raising
+`UnknownJobIDException` if no such job exists.
+
+#### queue.cancel_job(job)
+
+Cancel the indicated job. Any partially-downloaded file is cleaned up
+and the job is placed in the CANCELLED state.
+
+#### queue.cancel_all_jobs()
+
+Cancel all jobs that are not already in a terminal state.
+
+#### queue.prune_jobs()
+
+Remove inactive (complete or errored) jobs from the listing returned
+by `list_jobs()`.
+
+#### queue.join()
+
+Block until all pending jobs have run to completion or errored out.
diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py
index eed178ee8b..9a8e06ac1a 100644
--- a/invokeai/app/api/dependencies.py
+++ b/invokeai/app/api/dependencies.py
@@ -11,6 +11,7 @@ from ..services.board_images.board_images_default import BoardImagesService
 from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage
 from ..services.boards.boards_default import BoardService
 from ..services.config import InvokeAIAppConfig
+from ..services.download import DownloadQueueService
 from ..services.image_files.image_files_disk import DiskImageFileStorage
 from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
 from ..services.images.images_default import ImageService
@@ -85,6 +86,7 @@ class ApiDependencies:
         latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
         model_manager = ModelManagerService(config, logger)
         model_record_service = ModelRecordServiceSQL(db=db)
+        download_queue_service = DownloadQueueService(event_bus=events)
         model_install_service = ModelInstallService(
             app_config=config, record_store=model_record_service, event_bus=events
         )
@@ -113,6 +115,7 @@ class ApiDependencies:
             logger=logger,
             model_manager=model_manager,
             model_records=model_record_service,
+            download_queue=download_queue_service,
             model_install=model_install_service,
             names=names,
             performance_statistics=performance_statistics,
diff --git a/invokeai/app/api/routers/download_queue.py b/invokeai/app/api/routers/download_queue.py
new file mode 100644
index 0000000000..92b658c370
--- /dev/null
+++ b/invokeai/app/api/routers/download_queue.py
@@ -0,0 +1,111 @@
+# Copyright (c) 2023 Lincoln D.
Stein +"""FastAPI route for the download queue.""" + +from typing import List, Optional + +from fastapi import Body, Path, Response +from fastapi.routing import APIRouter +from pydantic.networks import AnyHttpUrl +from starlette.exceptions import HTTPException + +from invokeai.app.services.download import ( + DownloadJob, + UnknownJobIDException, +) + +from ..dependencies import ApiDependencies + +download_queue_router = APIRouter(prefix="/v1/download_queue", tags=["download_queue"]) + + +@download_queue_router.get( + "/", + operation_id="list_downloads", +) +async def list_downloads() -> List[DownloadJob]: + """Get a list of active and inactive jobs.""" + queue = ApiDependencies.invoker.services.download_queue + return queue.list_jobs() + + +@download_queue_router.patch( + "/", + operation_id="prune_downloads", + responses={ + 204: {"description": "All completed jobs have been pruned"}, + 400: {"description": "Bad request"}, + }, +) +async def prune_downloads(): + """Prune completed and errored jobs.""" + queue = ApiDependencies.invoker.services.download_queue + queue.prune_jobs() + return Response(status_code=204) + + +@download_queue_router.post( + "/i/", + operation_id="download", +) +async def download( + source: AnyHttpUrl = Body(description="download source"), + dest: str = Body(description="download destination"), + priority: int = Body(default=10, description="queue priority"), + access_token: Optional[str] = Body(default=None, description="token for authorization to download"), +) -> DownloadJob: + """Download the source URL to the file or directory indicted in dest.""" + queue = ApiDependencies.invoker.services.download_queue + return queue.download(source, dest, priority, access_token) + + +@download_queue_router.get( + "/i/{id}", + operation_id="get_download_job", + responses={ + 200: {"description": "Success"}, + 404: {"description": "The requested download JobID could not be found"}, + }, +) +async def get_download_job( + id: int = Path(description="ID of the download job to fetch."), +) -> DownloadJob: + """Get a download job using its ID.""" + try: + job = ApiDependencies.invoker.services.download_queue.id_to_job(id) + return job + except UnknownJobIDException as e: + raise HTTPException(status_code=404, detail=str(e)) + + +@download_queue_router.delete( + "/i/{id}", + operation_id="cancel_download_job", + responses={ + 204: {"description": "Job has been cancelled"}, + 404: {"description": "The requested download JobID could not be found"}, + }, +) +async def cancel_download_job( + id: int = Path(description="ID of the download job to cancel."), +): + """Cancel a download job using its ID.""" + try: + queue = ApiDependencies.invoker.services.download_queue + job = queue.id_to_job(id) + queue.cancel_job(job) + return Response(status_code=204) + except UnknownJobIDException as e: + raise HTTPException(status_code=404, detail=str(e)) + + +@download_queue_router.delete( + "/i", + operation_id="cancel_all_download_jobs", + responses={ + 204: {"description": "Download jobs have been cancelled"}, + }, +) +async def cancel_all_download_jobs(): + """Cancel all download jobs.""" + ApiDependencies.invoker.services.download_queue.cancel_all_jobs() + return Response(status_code=204) diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index ea28cdfe8e..8cbae23399 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -45,6 +45,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c app_info, board_images, boards, + download_queue, 
images, model_records, models, @@ -116,6 +117,7 @@ app.include_router(sessions.session_router, prefix="/api") app.include_router(utilities.utilities_router, prefix="/api") app.include_router(models.models_router, prefix="/api") app.include_router(model_records.model_records_router, prefix="/api") +app.include_router(download_queue.download_queue_router, prefix="/api") app.include_router(images.images_router, prefix="/api") app.include_router(boards.boards_router, prefix="/api") app.include_router(board_images.board_images_router, prefix="/api") diff --git a/invokeai/app/services/download/__init__.py b/invokeai/app/services/download/__init__.py new file mode 100644 index 0000000000..04c1dfdb1d --- /dev/null +++ b/invokeai/app/services/download/__init__.py @@ -0,0 +1,12 @@ +"""Init file for download queue.""" +from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException +from .download_default import DownloadQueueService, TqdmProgress + +__all__ = [ + "DownloadJob", + "DownloadQueueServiceBase", + "DownloadQueueService", + "TqdmProgress", + "DownloadJobStatus", + "UnknownJobIDException", +] diff --git a/invokeai/app/services/download/download_base.py b/invokeai/app/services/download/download_base.py new file mode 100644 index 0000000000..7ac5425443 --- /dev/null +++ b/invokeai/app/services/download/download_base.py @@ -0,0 +1,217 @@ +# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team +"""Model download service.""" + +from abc import ABC, abstractmethod +from enum import Enum +from functools import total_ordering +from pathlib import Path +from typing import Any, Callable, List, Optional + +from pydantic import BaseModel, Field, PrivateAttr +from pydantic.networks import AnyHttpUrl + + +class DownloadJobStatus(str, Enum): + """State of a download job.""" + + WAITING = "waiting" # not enqueued, will not run + RUNNING = "running" # actively downloading + COMPLETED = "completed" # finished running + CANCELLED = "cancelled" # user cancelled + ERROR = "error" # terminated with an error message + + +class DownloadJobCancelledException(Exception): + """This exception is raised when a download job is cancelled.""" + + +class UnknownJobIDException(Exception): + """This exception is raised when an invalid job id is referened.""" + + +class ServiceInactiveException(Exception): + """This exception is raised when user attempts to initiate a download before the service is started.""" + + +DownloadEventHandler = Callable[["DownloadJob"], None] + + +@total_ordering +class DownloadJob(BaseModel): + """Class to monitor and control a model download request.""" + + # required variables to be passed in on creation + source: AnyHttpUrl = Field(description="Where to download from. 
Specific types specified in child classes.") + dest: Path = Field(description="Destination of downloaded model on local disk; a directory or file path") + access_token: Optional[str] = Field(default=None, description="authorization token for protected resources") + # automatically assigned on creation + id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel + priority: int = Field(default=10, description="Queue priority; lower values are higher priority") + + # set internally during download process + status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download") + download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file") + job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started") + job_ended: Optional[str] = Field( + default=None, description="Timestamp for when the download job ende1d (completed or errored)" + ) + bytes: int = Field(default=0, description="Bytes downloaded so far") + total_bytes: int = Field(default=0, description="Total file size (bytes)") + + # set when an error occurs + error_type: Optional[str] = Field(default=None, description="Name of exception that caused an error") + error: Optional[str] = Field(default=None, description="Traceback of the exception that caused an error") + + # internal flag + _cancelled: bool = PrivateAttr(default=False) + + # optional event handlers passed in on creation + _on_start: Optional[DownloadEventHandler] = PrivateAttr(default=None) + _on_progress: Optional[DownloadEventHandler] = PrivateAttr(default=None) + _on_complete: Optional[DownloadEventHandler] = PrivateAttr(default=None) + _on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None) + _on_error: Optional[DownloadEventHandler] = PrivateAttr(default=None) + + def __le__(self, other: "DownloadJob") -> bool: + """Return True if this job's priority is less than another's.""" + return self.priority <= other.priority + + def cancel(self) -> None: + """Call to cancel the job.""" + self._cancelled = True + + # cancelled and the callbacks are private attributes in order to prevent + # them from being serialized and/or used in the Json Schema + @property + def cancelled(self) -> bool: + """Call to cancel the job.""" + return self._cancelled + + @property + def on_start(self) -> Optional[DownloadEventHandler]: + """Return the on_start event handler.""" + return self._on_start + + @property + def on_progress(self) -> Optional[DownloadEventHandler]: + """Return the on_progress event handler.""" + return self._on_progress + + @property + def on_complete(self) -> Optional[DownloadEventHandler]: + """Return the on_complete event handler.""" + return self._on_complete + + @property + def on_error(self) -> Optional[DownloadEventHandler]: + """Return the on_error event handler.""" + return self._on_error + + @property + def on_cancelled(self) -> Optional[DownloadEventHandler]: + """Return the on_cancelled event handler.""" + return self._on_cancelled + + def set_callbacks( + self, + on_start: Optional[DownloadEventHandler] = None, + on_progress: Optional[DownloadEventHandler] = None, + on_complete: Optional[DownloadEventHandler] = None, + on_cancelled: Optional[DownloadEventHandler] = None, + on_error: Optional[DownloadEventHandler] = None, + ) -> None: + """Set the callbacks for download events.""" + self._on_start = on_start + self._on_progress = on_progress + self._on_complete = on_complete + self._on_error = 
on_error + self._on_cancelled = on_cancelled + + +class DownloadQueueServiceBase(ABC): + """Multithreaded queue for downloading models via URL.""" + + @abstractmethod + def start(self, *args: Any, **kwargs: Any) -> None: + """Start the download worker threads.""" + + @abstractmethod + def stop(self, *args: Any, **kwargs: Any) -> None: + """Stop the download worker threads.""" + + @abstractmethod + def download( + self, + source: AnyHttpUrl, + dest: Path, + priority: int = 10, + access_token: Optional[str] = None, + on_start: Optional[DownloadEventHandler] = None, + on_progress: Optional[DownloadEventHandler] = None, + on_complete: Optional[DownloadEventHandler] = None, + on_cancelled: Optional[DownloadEventHandler] = None, + on_error: Optional[DownloadEventHandler] = None, + ) -> DownloadJob: + """ + Create a download job. + + :param source: Source of the download as a URL. + :param dest: Path to download to. See below. + :param on_start, on_progress, on_complete, on_error: Callbacks for the indicated + events. + :returns: A DownloadJob object for monitoring the state of the download. + + The `dest` argument is a Path object. Its behavior is: + + 1. If the path exists and is a directory, then the URL contents will be downloaded + into that directory using the filename indicated in the response's `Content-Disposition` field. + If no content-disposition is present, then the last component of the URL will be used (similar to + wget's behavior). + 2. If the path does not exist, then it is taken as the name of a new file to create with the downloaded + content. + 3. If the path exists and is an existing file, then the downloader will try to resume the download from + the end of the existing file. + + """ + pass + + @abstractmethod + def list_jobs(self) -> List[DownloadJob]: + """ + List active download jobs. + + :returns List[DownloadJob]: List of download jobs whose state is not "completed." + """ + pass + + @abstractmethod + def id_to_job(self, id: int) -> DownloadJob: + """ + Return the DownloadJob corresponding to the integer ID. + + :param id: ID of the DownloadJob. + + Exceptions: + * UnknownJobIDException + """ + pass + + @abstractmethod + def cancel_all_jobs(self): + """Cancel all active and enquedjobs.""" + pass + + @abstractmethod + def prune_jobs(self): + """Prune completed and errored queue items from the job list.""" + pass + + @abstractmethod + def cancel_job(self, job: DownloadJob): + """Cancel the job, clearing partial downloads and putting it into ERROR state.""" + pass + + @abstractmethod + def join(self): + """Wait until all jobs are off the queue.""" + pass diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py new file mode 100644 index 0000000000..0f1dca4adc --- /dev/null +++ b/invokeai/app/services/download/download_default.py @@ -0,0 +1,418 @@ +# Copyright (c) 2023, Lincoln D. 
Stein +"""Implementation of multithreaded download queue for invokeai.""" + +import os +import re +import threading +import traceback +from logging import Logger +from pathlib import Path +from queue import Empty, PriorityQueue +from typing import Any, Dict, List, Optional, Set + +import requests +from pydantic.networks import AnyHttpUrl +from requests import HTTPError +from tqdm import tqdm + +from invokeai.app.services.events.events_base import EventServiceBase +from invokeai.app.util.misc import get_iso_timestamp +from invokeai.backend.util.logging import InvokeAILogger + +from .download_base import ( + DownloadEventHandler, + DownloadJob, + DownloadJobCancelledException, + DownloadJobStatus, + DownloadQueueServiceBase, + ServiceInactiveException, + UnknownJobIDException, +) + +# Maximum number of bytes to download during each call to requests.iter_content() +DOWNLOAD_CHUNK_SIZE = 100000 + + +class DownloadQueueService(DownloadQueueServiceBase): + """Class for queued download of models.""" + + _jobs: Dict[int, DownloadJob] + _max_parallel_dl: int = 5 + _worker_pool: Set[threading.Thread] + _queue: PriorityQueue[DownloadJob] + _stop_event: threading.Event + _lock: threading.Lock + _logger: Logger + _events: Optional[EventServiceBase] = None + _next_job_id: int = 0 + _accept_download_requests: bool = False + _requests: requests.sessions.Session + + def __init__( + self, + max_parallel_dl: int = 5, + event_bus: Optional[EventServiceBase] = None, + requests_session: Optional[requests.sessions.Session] = None, + ): + """ + Initialize DownloadQueue. + + :param max_parallel_dl: Number of simultaneous downloads allowed [5]. + :param requests_session: Optional requests.sessions.Session object, for unit tests. + """ + self._jobs = {} + self._next_job_id = 0 + self._queue = PriorityQueue() + self._stop_event = threading.Event() + self._worker_pool = set() + self._lock = threading.Lock() + self._logger = InvokeAILogger.get_logger("DownloadQueueService") + self._event_bus = event_bus + self._requests = requests_session or requests.Session() + self._accept_download_requests = False + self._max_parallel_dl = max_parallel_dl + + def start(self, *args: Any, **kwargs: Any) -> None: + """Start the download worker threads.""" + with self._lock: + if self._worker_pool: + raise Exception("Attempt to start the download service twice") + self._stop_event.clear() + self._start_workers(self._max_parallel_dl) + self._accept_download_requests = True + + def stop(self, *args: Any, **kwargs: Any) -> None: + """Stop the download worker threads.""" + with self._lock: + if not self._worker_pool: + raise Exception("Attempt to stop the download service before it was started") + self._accept_download_requests = False # reject attempts to add new jobs to queue + queued_jobs = [x for x in self.list_jobs() if x.status == DownloadJobStatus.WAITING] + active_jobs = [x for x in self.list_jobs() if x.status == DownloadJobStatus.RUNNING] + if queued_jobs: + self._logger.warning(f"Cancelling {len(queued_jobs)} queued downloads") + if active_jobs: + self._logger.info(f"Waiting for {len(active_jobs)} active download jobs to complete") + with self._queue.mutex: + self._queue.queue.clear() + self.join() # wait for all active jobs to finish + self._stop_event.set() + self._worker_pool.clear() + + def download( + self, + source: AnyHttpUrl, + dest: Path, + priority: int = 10, + access_token: Optional[str] = None, + on_start: Optional[DownloadEventHandler] = None, + on_progress: Optional[DownloadEventHandler] = None, + on_complete: 
Optional[DownloadEventHandler] = None, + on_cancelled: Optional[DownloadEventHandler] = None, + on_error: Optional[DownloadEventHandler] = None, + ) -> DownloadJob: + """Create a download job and return its ID.""" + if not self._accept_download_requests: + raise ServiceInactiveException( + "The download service is not currently accepting requests. Please call start() to initialize the service." + ) + with self._lock: + id = self._next_job_id + self._next_job_id += 1 + job = DownloadJob( + id=id, + source=source, + dest=dest, + priority=priority, + access_token=access_token, + ) + job.set_callbacks( + on_start=on_start, + on_progress=on_progress, + on_complete=on_complete, + on_cancelled=on_cancelled, + on_error=on_error, + ) + self._jobs[id] = job + self._queue.put(job) + return job + + def join(self) -> None: + """Wait for all jobs to complete.""" + self._queue.join() + + def list_jobs(self) -> List[DownloadJob]: + """List all the jobs.""" + return list(self._jobs.values()) + + def prune_jobs(self) -> None: + """Prune completed and errored queue items from the job list.""" + with self._lock: + to_delete = set() + for job_id, job in self._jobs.items(): + if self._in_terminal_state(job): + to_delete.add(job_id) + for job_id in to_delete: + del self._jobs[job_id] + + def id_to_job(self, id: int) -> DownloadJob: + """Translate a job ID into a DownloadJob object.""" + try: + return self._jobs[id] + except KeyError as excp: + raise UnknownJobIDException("Unrecognized job") from excp + + def cancel_job(self, job: DownloadJob) -> None: + """ + Cancel the indicated job. + + If it is running it will be stopped. + job.status will be set to DownloadJobStatus.CANCELLED + """ + with self._lock: + job.cancel() + + def cancel_all_jobs(self, preserve_partial: bool = False) -> None: + """Cancel all jobs (those not in enqueued, running or paused state).""" + for job in self._jobs.values(): + if not self._in_terminal_state(job): + self.cancel_job(job) + + def _in_terminal_state(self, job: DownloadJob) -> bool: + return job.status in [ + DownloadJobStatus.COMPLETED, + DownloadJobStatus.CANCELLED, + DownloadJobStatus.ERROR, + ] + + def _start_workers(self, max_workers: int) -> None: + """Start the requested number of worker threads.""" + self._stop_event.clear() + for i in range(0, max_workers): # noqa B007 + worker = threading.Thread(target=self._download_next_item, daemon=True) + self._logger.debug(f"Download queue worker thread {worker.name} starting.") + worker.start() + self._worker_pool.add(worker) + + def _download_next_item(self) -> None: + """Worker thread gets next job on priority queue.""" + done = False + while not done: + if self._stop_event.is_set(): + done = True + continue + try: + job = self._queue.get(timeout=1) + except Empty: + continue + + try: + job.job_started = get_iso_timestamp() + self._do_download(job) + self._signal_job_complete(job) + + except (OSError, HTTPError) as excp: + job.error_type = excp.__class__.__name__ + f"({str(excp)})" + job.error = traceback.format_exc() + self._signal_job_error(job) + except DownloadJobCancelledException: + self._signal_job_cancelled(job) + self._cleanup_cancelled_job(job) + + finally: + job.job_ended = get_iso_timestamp() + self._queue.task_done() + self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.") + + def _do_download(self, job: DownloadJob) -> None: + """Do the actual download.""" + url = job.source + header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {} + open_mode = 
"wb" + + # Make a streaming request. This will retrieve headers including + # content-length and content-disposition, but not fetch any content itself + resp = self._requests.get(str(url), headers=header, stream=True) + if not resp.ok: + raise HTTPError(resp.reason) + content_length = int(resp.headers.get("content-length", 0)) + job.total_bytes = content_length + + if job.dest.is_dir(): + file_name = os.path.basename(str(url.path)) # default is to use the last bit of the URL + + if match := re.search('filename="(.+)"', resp.headers.get("Content-Disposition", "")): + remote_name = match.group(1) + if self._validate_filename(job.dest.as_posix(), remote_name): + file_name = remote_name + + job.download_path = job.dest / file_name + + else: + job.dest.parent.mkdir(parents=True, exist_ok=True) + job.download_path = job.dest + + assert job.download_path + + # Don't clobber an existing file. See commit 82c2c85202f88c6d24ff84710f297cfc6ae174af + # for code that instead resumes an interrupted download. + if job.download_path.exists(): + raise OSError(f"[Errno 17] File {job.download_path} exists") + + # append ".downloading" to the path + in_progress_path = self._in_progress_path(job.download_path) + + # signal caller that the download is starting. At this point, key fields such as + # download_path and total_bytes will be populated. We call it here because the might + # discover that the local file is already complete and generate a COMPLETED status. + self._signal_job_started(job) + + # "range not satisfiable" - local file is at least as large as the remote file + if resp.status_code == 416 or (content_length > 0 and job.bytes >= content_length): + self._logger.warning(f"{job.download_path}: complete file found. Skipping.") + return + + # "partial content" - local file is smaller than remote file + elif resp.status_code == 206 or job.bytes > 0: + self._logger.warning(f"{job.download_path}: partial file found. 
Resuming") + + # some other error + elif resp.status_code != 200: + raise HTTPError(resp.reason) + + self._logger.debug(f"{job.source}: Downloading {job.download_path}") + report_delta = job.total_bytes / 100 # report every 1% change + last_report_bytes = 0 + + # DOWNLOAD LOOP + with open(in_progress_path, open_mode) as file: + for data in resp.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE): + if job.cancelled: + raise DownloadJobCancelledException("Job was cancelled at caller's request") + + job.bytes += file.write(data) + if (job.bytes - last_report_bytes >= report_delta) or (job.bytes >= job.total_bytes): + last_report_bytes = job.bytes + self._signal_job_progress(job) + + # if we get here we are done and can rename the file to the original dest + in_progress_path.rename(job.download_path) + + def _validate_filename(self, directory: str, filename: str) -> bool: + pc_name_max = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 260 # hardcoded for windows + pc_path_max = ( + os.pathconf(directory, "PC_PATH_MAX") if hasattr(os, "pathconf") else 32767 + ) # hardcoded for windows with long names enabled + if "/" in filename: + return False + if filename.startswith(".."): + return False + if len(filename) > pc_name_max: + return False + if len(os.path.join(directory, filename)) > pc_path_max: + return False + return True + + def _in_progress_path(self, path: Path) -> Path: + return path.with_name(path.name + ".downloading") + + def _signal_job_started(self, job: DownloadJob) -> None: + job.status = DownloadJobStatus.RUNNING + if job.on_start: + try: + job.on_start(job) + except Exception as e: + self._logger.error(e) + if self._event_bus: + assert job.download_path + self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix()) + + def _signal_job_progress(self, job: DownloadJob) -> None: + if job.on_progress: + try: + job.on_progress(job) + except Exception as e: + self._logger.error(e) + if self._event_bus: + assert job.download_path + self._event_bus.emit_download_progress( + str(job.source), + download_path=job.download_path.as_posix(), + current_bytes=job.bytes, + total_bytes=job.total_bytes, + ) + + def _signal_job_complete(self, job: DownloadJob) -> None: + job.status = DownloadJobStatus.COMPLETED + if job.on_complete: + try: + job.on_complete(job) + except Exception as e: + self._logger.error(e) + if self._event_bus: + assert job.download_path + self._event_bus.emit_download_complete( + str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes + ) + + def _signal_job_cancelled(self, job: DownloadJob) -> None: + job.status = DownloadJobStatus.CANCELLED + if job.on_cancelled: + try: + job.on_cancelled(job) + except Exception as e: + self._logger.error(e) + if self._event_bus: + self._event_bus.emit_download_cancelled(str(job.source)) + + def _signal_job_error(self, job: DownloadJob) -> None: + job.status = DownloadJobStatus.ERROR + if job.on_error: + try: + job.on_error(job) + except Exception as e: + self._logger.error(e) + if self._event_bus: + assert job.error_type + assert job.error + self._event_bus.emit_download_error(str(job.source), error_type=job.error_type, error=job.error) + + def _cleanup_cancelled_job(self, job: DownloadJob) -> None: + self._logger.warning(f"Cleaning up leftover files from cancelled download job {job.download_path}") + try: + if job.download_path: + partial_file = self._in_progress_path(job.download_path) + partial_file.unlink() + except OSError as excp: + self._logger.warning(excp) + + 
+# Example on_progress event handler to display a TQDM status bar +# Activate with: +# download_service.download('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().job_update +class TqdmProgress(object): + """TQDM-based progress bar object to use in on_progress handlers.""" + + _bars: Dict[int, tqdm] # the tqdm object + _last: Dict[int, int] # last bytes downloaded + + def __init__(self) -> None: # noqa D107 + self._bars = {} + self._last = {} + + def update(self, job: DownloadJob) -> None: # noqa D102 + job_id = job.id + # new job + if job_id not in self._bars: + assert job.download_path + dest = Path(job.download_path).name + self._bars[job_id] = tqdm( + desc=dest, + initial=0, + total=job.total_bytes, + unit="iB", + unit_scale=True, + ) + self._last[job_id] = 0 + self._bars[job_id].update(job.bytes - self._last[job_id]) + self._last[job_id] = job.bytes diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index 93b84afaf1..16e7d72b2a 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -17,6 +17,7 @@ from invokeai.backend.model_management.models.base import BaseModelType, ModelTy class EventServiceBase: queue_event: str = "queue_event" + download_event: str = "download_event" model_event: str = "model_event" """Basic event bus, to have an empty stand-in when not needed""" @@ -32,6 +33,13 @@ class EventServiceBase: payload={"event": event_name, "data": payload}, ) + def __emit_download_event(self, event_name: str, payload: dict) -> None: + payload["timestamp"] = get_timestamp() + self.dispatch( + event_name=EventServiceBase.download_event, + payload={"event": event_name, "data": payload}, + ) + def __emit_model_event(self, event_name: str, payload: dict) -> None: payload["timestamp"] = get_timestamp() self.dispatch( @@ -323,6 +331,79 @@ class EventServiceBase: payload={"queue_id": queue_id}, ) + def emit_download_started(self, source: str, download_path: str) -> None: + """ + Emit when a download job is started. + + :param url: The downloaded url + """ + self.__emit_download_event( + event_name="download_started", + payload={"source": source, "download_path": download_path}, + ) + + def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None: + """ + Emit "download_progress" events at regular intervals during a download job. + + :param source: The downloaded source + :param download_path: The local downloaded file + :param current_bytes: Number of bytes downloaded so far + :param total_bytes: The size of the file being downloaded (if known) + """ + self.__emit_download_event( + event_name="download_progress", + payload={ + "source": source, + "download_path": download_path, + "current_bytes": current_bytes, + "total_bytes": total_bytes, + }, + ) + + def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None: + """ + Emit a "download_complete" event at the end of a successful download. 
+ + :param source: Source URL + :param download_path: Path to the locally downloaded file + :param total_bytes: The size of the downloaded file + """ + self.__emit_download_event( + event_name="download_complete", + payload={ + "source": source, + "download_path": download_path, + "total_bytes": total_bytes, + }, + ) + + def emit_download_cancelled(self, source: str) -> None: + """Emit a "download_cancelled" event in the event that the download was cancelled by user.""" + self.__emit_download_event( + event_name="download_cancelled", + payload={ + "source": source, + }, + ) + + def emit_download_error(self, source: str, error_type: str, error: str) -> None: + """ + Emit a "download_error" event when an download job encounters an exception. + + :param source: Source URL + :param error_type: The name of the exception that raised the error + :param error: The traceback from this error + """ + self.__emit_download_event( + event_name="download_error", + payload={ + "source": source, + "error_type": error_type, + "error": error, + }, + ) + def emit_model_install_started(self, source: str) -> None: """ Emitted when an install job is started. diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index d99a9aff25..11a4de99d6 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from .board_records.board_records_base import BoardRecordStorageBase from .boards.boards_base import BoardServiceABC from .config import InvokeAIAppConfig + from .download import DownloadQueueServiceBase from .events.events_base import EventServiceBase from .image_files.image_files_base import ImageFileStorageBase from .image_records.image_records_base import ImageRecordStorageBase @@ -50,6 +51,7 @@ class InvocationServices: logger: "Logger" model_manager: "ModelManagerServiceBase" model_records: "ModelRecordServiceBase" + download_queue: "DownloadQueueServiceBase" model_install: "ModelInstallServiceBase" processor: "InvocationProcessorABC" performance_statistics: "InvocationStatsServiceBase" @@ -77,6 +79,7 @@ class InvocationServices: logger: "Logger", model_manager: "ModelManagerServiceBase", model_records: "ModelRecordServiceBase", + download_queue: "DownloadQueueServiceBase", model_install: "ModelInstallServiceBase", processor: "InvocationProcessorABC", performance_statistics: "InvocationStatsServiceBase", @@ -102,6 +105,7 @@ class InvocationServices: self.logger = logger self.model_manager = model_manager self.model_records = model_records + self.download_queue = download_queue self.model_install = model_install self.processor = processor self.performance_statistics = performance_statistics diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 80b493d02e..3146b5350a 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -11,7 +11,6 @@ from typing_extensions import Annotated from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.events import EventServiceBase -from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.backend.model_manager import AnyModelConfig @@ -157,12 +156,12 @@ class ModelInstallServiceBase(ABC): :param event_bus: InvokeAI event bus for reporting events to. 
""" - def start(self, invoker: Invoker) -> None: - """Call at InvokeAI startup time.""" - self.sync_to_config() + @abstractmethod + def start(self, *args: Any, **kwarg: Any) -> None: + """Start the installer service.""" @abstractmethod - def stop(self) -> None: + def stop(self, *args: Any, **kwarg: Any) -> None: """Stop the model install service. After this the objection can be safely deleted.""" @property diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 70cc4d5018..3dcb7c527e 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -71,7 +71,6 @@ class ModelInstallService(ModelInstallServiceBase): self._install_queue = Queue() self._cached_model_paths = set() self._models_installed = set() - self._start_installer_thread() @property def app_config(self) -> InvokeAIAppConfig: # noqa D102 @@ -85,8 +84,13 @@ class ModelInstallService(ModelInstallServiceBase): def event_bus(self) -> Optional[EventServiceBase]: # noqa D102 return self._event_bus - def stop(self, *args, **kwargs) -> None: - """Stop the install thread; after this the object can be deleted and garbage collected.""" + def start(self, *args: Any, **kwarg: Any) -> None: + """Start the installer thread.""" + self._start_installer_thread() + self.sync_to_config() + + def stop(self, *args: Any, **kwarg: Any) -> None: + """Stop the installer thread; after this the object can be deleted and garbage collected.""" self._install_queue.put(STOP_JOB) def _start_installer_thread(self) -> None: diff --git a/mkdocs.yml b/mkdocs.yml index 7c67a2777a..c8875c0ff1 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -172,6 +172,8 @@ nav: - Adding Tests: 'contributing/TESTS.md' - Documentation: 'contributing/contribution_guides/documentation.md' - Nodes: 'contributing/INVOCATIONS.md' + - Model Manager: 'contributing/MODEL_MANAGER.md' + - Download Queue: 'contributing/DOWNLOAD_QUEUE.md' - Translation: 'contributing/contribution_guides/translation.md' - Tutorials: 'contributing/contribution_guides/tutorials.md' - Changelog: 'CHANGELOG.md' diff --git a/pyproject.toml b/pyproject.toml index 98018dc7cb..52fbdd01a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,6 +105,7 @@ dependencies = [ "pytest>6.0.0", "pytest-cov", "pytest-datadir", + "requests_testadapter", ] "xformers" = [ "xformers==0.0.23; sys_platform!='darwin'", diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index b7bf7771a6..bb31161426 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -68,6 +68,7 @@ def mock_services() -> InvocationServices: logger=logging, # type: ignore model_manager=None, # type: ignore model_records=None, # type: ignore + download_queue=None, # type: ignore model_install=None, # type: ignore names=None, # type: ignore performance_statistics=InvocationStatsService(), diff --git a/tests/aa_nodes/test_invoker.py b/tests/aa_nodes/test_invoker.py index 63908193fd..d4959282a1 100644 --- a/tests/aa_nodes/test_invoker.py +++ b/tests/aa_nodes/test_invoker.py @@ -74,6 +74,7 @@ def mock_services() -> InvocationServices: logger=logging, # type: ignore model_manager=None, # type: ignore model_records=None, # type: ignore + download_queue=None, # type: ignore model_install=None, # type: ignore names=None, # type: ignore performance_statistics=InvocationStatsService(), diff --git 
a/tests/app/services/download/test_download_queue.py b/tests/app/services/download/test_download_queue.py new file mode 100644 index 0000000000..6e36af75ce --- /dev/null +++ b/tests/app/services/download/test_download_queue.py @@ -0,0 +1,223 @@ +"""Test the queued download facility""" +import re +import time +from pathlib import Path +from typing import Any, Dict, List + +import pytest +import requests +from pydantic import BaseModel +from pydantic.networks import AnyHttpUrl +from requests.sessions import Session +from requests_testadapter import TestAdapter + +from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService +from invokeai.app.services.events.events_base import EventServiceBase + +# Prevent pytest deprecation warnings +TestAdapter.__test__ = False + + +@pytest.fixture +def session() -> requests.sessions.Session: + sess = requests.Session() + for i in ["12345", "9999", "54321"]: + content = ( + b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000) + ) # for pause tests, must make content large + sess.mount( + f"http://www.civitai.com/models/{i}", + TestAdapter( + content, + headers={ + "Content-Length": len(content), + "Content-Disposition": f'filename="mock{i}.safetensors"', + }, + ), + ) + + # here are some malformed URLs to test + # missing the content length + sess.mount( + "http://www.civitai.com/models/missing", + TestAdapter( + b"Missing content length", + headers={ + "Content-Disposition": 'filename="missing.txt"', + }, + ), + ) + # not found test + sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404)) + + return sess + + +class DummyEvent(BaseModel): + """Dummy Event to use with Dummy Event service.""" + + event_name: str + payload: Dict[str, Any] + + +# A dummy event service for testing event issuing +class DummyEventService(EventServiceBase): + """Dummy event service for testing.""" + + events: List[DummyEvent] + + def __init__(self) -> None: + super().__init__() + self.events = [] + + def dispatch(self, event_name: str, payload: Any) -> None: + """Dispatch an event by appending it to self.events.""" + self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"])) + + +def test_basic_queue_download(tmp_path: Path, session: Session) -> None: + events = set() + + def event_handler(job: DownloadJob) -> None: + events.add(job.status) + + queue = DownloadQueueService( + requests_session=session, + ) + queue.start() + job = queue.download( + source=AnyHttpUrl("http://www.civitai.com/models/12345"), + dest=tmp_path, + on_start=event_handler, + on_progress=event_handler, + on_complete=event_handler, + on_error=event_handler, + ) + assert isinstance(job, DownloadJob), "expected the job to be of type DownloadJobBase" + assert isinstance(job.id, int), "expected the job id to be numeric" + queue.join() + + assert job.status == DownloadJobStatus("completed"), "expected job status to be completed" + assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist" + + assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED} + queue.stop() + + +def test_errors(tmp_path: Path, session: Session) -> None: + queue = DownloadQueueService( + requests_session=session, + ) + queue.start() + + for bad_url in ["http://www.civitai.com/models/broken", "http://www.civitai.com/models/missing"]: + queue.download(AnyHttpUrl(bad_url), dest=tmp_path) + + queue.join() + jobs = queue.list_jobs() + print(jobs) + assert len(jobs) == 
2 + jobs_dict = {str(x.source): x for x in jobs} + assert jobs_dict["http://www.civitai.com/models/broken"].status == DownloadJobStatus.ERROR + assert jobs_dict["http://www.civitai.com/models/broken"].error_type == "HTTPError(NOT FOUND)" + assert jobs_dict["http://www.civitai.com/models/missing"].status == DownloadJobStatus.COMPLETED + assert jobs_dict["http://www.civitai.com/models/missing"].total_bytes == 0 + queue.stop() + + +def test_event_bus(tmp_path: Path, session: Session) -> None: + event_bus = DummyEventService() + + queue = DownloadQueueService(requests_session=session, event_bus=event_bus) + queue.start() + queue.download( + source=AnyHttpUrl("http://www.civitai.com/models/12345"), + dest=tmp_path, + ) + queue.join() + events = event_bus.events + assert len(events) == 3 + assert events[0].payload["timestamp"] <= events[1].payload["timestamp"] + assert events[1].payload["timestamp"] <= events[2].payload["timestamp"] + assert events[0].event_name == "download_started" + assert events[1].event_name == "download_progress" + assert events[1].payload["total_bytes"] > 0 + assert events[1].payload["current_bytes"] <= events[1].payload["total_bytes"] + assert events[2].event_name == "download_complete" + assert events[2].payload["total_bytes"] == 32029 + + # test a failure + event_bus.events = [] # reset our accumulator + queue.download(source=AnyHttpUrl("http://www.civitai.com/models/broken"), dest=tmp_path) + queue.join() + events = event_bus.events + print("\n".join([x.model_dump_json() for x in events])) + assert len(events) == 1 + assert events[0].event_name == "download_error" + assert events[0].payload["error_type"] == "HTTPError(NOT FOUND)" + assert events[0].payload["error"] is not None + assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].payload["error"]) + queue.stop() + + +def test_broken_callbacks(tmp_path: Path, session: requests.sessions.Session, capsys) -> None: + queue = DownloadQueueService( + requests_session=session, + ) + queue.start() + + callback_ran = False + + def broken_callback(job: DownloadJob) -> None: + nonlocal callback_ran + callback_ran = True + print(1 / 0) # deliberate error here + + job = queue.download( + source=AnyHttpUrl("http://www.civitai.com/models/12345"), + dest=tmp_path, + on_progress=broken_callback, + ) + + queue.join() + assert job.status == DownloadJobStatus.COMPLETED # should complete even though the callback is borked + assert Path(tmp_path, "mock12345.safetensors").exists() + assert callback_ran + # LS: The pytest capsys fixture does not seem to be working. I can see the + # correct stderr message in the pytest log, but it is not appearing in + # capsys.readouterr(). 
+ # captured = capsys.readouterr() + # assert re.search("division by zero", captured.err) + queue.stop() + + +def test_cancel(tmp_path: Path, session: requests.sessions.Session) -> None: + event_bus = DummyEventService() + + queue = DownloadQueueService(requests_session=session, event_bus=event_bus) + queue.start() + + cancelled = False + + def slow_callback(job: DownloadJob) -> None: + time.sleep(2) + + def cancelled_callback(job: DownloadJob) -> None: + nonlocal cancelled + cancelled = True + + job = queue.download( + source=AnyHttpUrl("http://www.civitai.com/models/12345"), + dest=tmp_path, + on_start=slow_callback, + on_cancelled=cancelled_callback, + ) + queue.cancel_job(job) + queue.join() + + assert job.status == DownloadJobStatus.CANCELLED + assert cancelled + events = event_bus.events + assert events[-1].event_name == "download_cancelled" + assert events[-1].payload["source"] == "http://www.civitai.com/models/12345" + queue.stop() diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index 310bcfa0c1..9010f9f296 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -48,11 +48,13 @@ def store( @pytest.fixture def installer(app_config: InvokeAIAppConfig, store: ModelRecordServiceBase) -> ModelInstallServiceBase: - return ModelInstallService( + installer = ModelInstallService( app_config=app_config, record_store=store, event_bus=DummyEventService(), ) + installer.start() + return installer class DummyEvent(BaseModel):