From f454304c9149bc98f0d2cbb84ac0efd1106cbc26 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Sun, 10 Sep 2023 17:20:47 -0400
Subject: [PATCH] make it possible to pause/resume repo_id downloads
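
Repo_id downloads now run on a subqueue attached to the parent job, so
pausing or cancelling the parent propagates to the per-file subjobs, and
partial downloads can be preserved and resumed. A rough sketch of the
queue-level API this patch provides (the repo_id and destination path
are illustrative only):

    from pathlib import Path
    from invokeai.backend.model_manager.download import DownloadQueue

    queue = DownloadQueue()
    job = queue.create_download_job(
        source="runwayml/stable-diffusion-v1-5",  # example repo_id
        destdir=Path("/tmp/models"),              # example destination
    )
    queue.pause_job(job)   # subjobs are cancelled, but partial files are kept
    queue.start_job(job)   # resumes from where the download left off
    queue.join()           # block until all jobs reach a terminal state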
---
 invokeai/app/services/download_manager.py   |  31 ++---
 .../backend/model_manager/download/base.py  |  28 +++--
 .../backend/model_manager/download/queue.py | 111 ++++++++++++++----
 invokeai/backend/model_manager/install.py   | 110 ++++++++++++-----
 tests/test_model_download.py                |   9 +-
 5 files changed, 208 insertions(+), 81 deletions(-)

diff --git a/invokeai/app/services/download_manager.py b/invokeai/app/services/download_manager.py
index 1be47cb89c..ab51510d0e 100644
--- a/invokeai/app/services/download_manager.py
+++ b/invokeai/app/services/download_manager.py
@@ -5,7 +5,8 @@ Model download service.
 
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Optional, List
+from typing import Optional, List, Union
+from pydantic.networks import AnyHttpUrl
 
 from .events import EventServiceBase
 from invokeai.backend.model_manager.download import DownloadQueue, DownloadJobBase, DownloadEventHandler
@@ -16,7 +17,7 @@ class DownloadQueueServiceBase(ABC):
     @abstractmethod
     def create_download_job(
         self,
-        source: str,
+        source: Union[str, Path, AnyHttpUrl],
         destdir: Path,
         filename: Optional[Path] = None,
         start: bool = True,
@@ -26,7 +27,7 @@ class DownloadQueueServiceBase(ABC):
         """
         Create a download job.
 
-        :param source: Source of the download - URL or repo_id
+        :param source: Source of the download - URL, repo_id or local Path
        :param destdir: Directory to download into.
        :param filename: Optional name of file, if not provided will use the
        content-disposition field to assign the name.
@@ -126,13 +127,13 @@ class DownloadQueueService(DownloadQueueServiceBase):
 
     def create_download_job(
         self,
-        source: str,
+        source: Union[str, Path, AnyHttpUrl],
         destdir: Path,
         filename: Optional[Path] = None,
         start: bool = True,
         access_token: Optional[str] = None,
         event_handlers: Optional[List[DownloadEventHandler]] = None,
-    ) -> DownloadJobBase:
+    ) -> DownloadJobBase:  # noqa D102
         event_handlers = event_handlers or []
         if self._event_bus:
-            event_handlers.append([self._event_bus.emit_model_download_event])
+            event_handlers.append(self._event_bus.emit_model_download_event)
         return self._queue.create_download_job(
             source=source,
             destdir=destdir,
             filename=filename,
             start=start,
             access_token=access_token,
             event_handlers=event_handlers,
         )
 
-    def list_jobs(self) -> List[DownloadJobBase]:
+    def list_jobs(self) -> List[DownloadJobBase]:  # noqa D102
         return self._queue.list_jobs()
 
-    def id_to_job(self, id: int) -> DownloadJobBase:
+    def id_to_job(self, id: int) -> DownloadJobBase:  # noqa D102
         return self._queue.id_to_job(id)
 
-    def start_all_jobs(self):
+    def start_all_jobs(self):  # noqa D102
         return self._queue.start_all_jobs()
 
-    def pause_all_jobs(self):
+    def pause_all_jobs(self):  # noqa D102
         return self._queue.pause_all_jobs()
 
-    def cancel_all_jobs(self):
+    def cancel_all_jobs(self):  # noqa D102
         return self._queue.cancel_all_jobs()
 
-    def start_job(self, job: DownloadJobBase):
-        return self._queue.start_job(id)
+    def start_job(self, job: DownloadJobBase):  # noqa D102
+        return self._queue.start_job(job)
 
-    def pause_job(self, job: DownloadJobBase):
-        return self._queue.pause_job(id)
+    def pause_job(self, job: DownloadJobBase):  # noqa D102
+        return self._queue.pause_job(job)
 
-    def cancel_job(self, job: DownloadJobBase):
-        return self._queue.cancel_job(id)
+    def cancel_job(self, job: DownloadJobBase):  # noqa D102
+        return self._queue.cancel_job(job)
 
-    def change_priority(self, job: DownloadJobBase, delta: int):
-        return self._queue.change_priority(id, delta)
+    def change_priority(self, job: DownloadJobBase, delta: int):  # noqa D102
+        return self._queue.change_priority(job, delta)
 
-    def join(self):
+    def join(self):  # noqa D102
         return self._queue.join()
diff --git a/invokeai/backend/model_manager/download/base.py b/invokeai/backend/model_manager/download/base.py
index 2ab9140d19..5d3defc288 100644
--- a/invokeai/backend/model_manager/download/base.py
+++ b/invokeai/backend/model_manager/download/base.py
@@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
 from enum import Enum
 from functools import total_ordering
 from pathlib import Path
-from typing import List, Optional, Callable
+from typing import List, Optional, Callable, Union
 
 from pydantic import BaseModel, Field
 from pydantic.networks import AnyHttpUrl
@@ -63,6 +63,12 @@ class DownloadJobBase(BaseModel):
     job_sequence: Optional[int] = Field(
         description="Counter that records order in which this job was dequeued (for debugging)"
     )
+    subqueue: Optional["DownloadQueueBase"] = Field(
+        description="a subqueue used for downloading repo_ids", default=None
+    )
+    preserve_partial_downloads: bool = Field(
+        description="if true, then preserve partial downloads when cancelled or errored", default=False
+    )
     error: Optional[Exception] = Field(default=None, description="Exception that caused an error")
 
     def add_event_handler(self, handler: DownloadEventHandler):
@@ -96,7 +102,7 @@ class DownloadQueueBase(ABC):
     @abstractmethod
     def create_download_job(
         self,
-        source: str,
+        source: Union[str, Path, AnyHttpUrl],
         destdir: Path,
         filename: Optional[Path] = None,
         start: bool = True,
@@ -107,7 +113,7 @@ class DownloadQueueBase(ABC):
         """
         Create a download job.
 
-        :param source: Source of the download - URL or repo_id
+        :param source: Source of the download - URL, repo_id or Path
        :param destdir: Directory to download into.
        :param filename: Optional name of file, if not provided will use the
        content-disposition field to assign the name.
@@ -165,8 +171,12 @@ class DownloadQueueBase(ABC):
         pass
 
     @abstractmethod
-    def cancel_all_jobs(self):
-        """Cancel all active and enquedjobs."""
+    def cancel_all_jobs(self, preserve_partial: bool = False):
+        """
+        Cancel all active and enqueued jobs.
+
+        :param preserve_partial: Keep partially downloaded files [False].
+        """
         pass
 
     @abstractmethod
@@ -180,8 +190,12 @@ class DownloadQueueBase(ABC):
         pass
 
     @abstractmethod
-    def cancel_job(self, job: DownloadJobBase):
-        """Cancel the job, clearing partial downloads and putting it into ERROR state."""
+    def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False):
+        """
+        Cancel the job, clearing partial downloads and putting it into CANCELLED state.
+
+        :param preserve_partial: Keep partial downloads [False]
+        """
         pass
 
     @abstractmethod
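
The new `subqueue` and `preserve_partial_downloads` fields let a repo_id
job own the per-file jobs it spawns. Event handlers remain the way to
observe a job. A minimal handler sketch, assuming DownloadEventHandler is
a plain callable that receives the job (as add_event_handler() above
suggests); `bytes`, `total_bytes` and `status` are the job fields used
elsewhere in this patch:

    from invokeai.backend.model_manager.download import DownloadJobBase

    def progress_handler(job: DownloadJobBase) -> None:
        # Called on every status/progress update. Handlers should not raise;
        # the queue implementation below turns a handler exception into a
        # job error.
        if job.total_bytes:
            print(f"{job.source}: {job.bytes}/{job.total_bytes} [{job.status}]")

Attach it with job.add_event_handler(progress_handler), or pass
event_handlers=[progress_handler] to create_download_job().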
diff --git a/invokeai/backend/model_manager/download/queue.py b/invokeai/backend/model_manager/download/queue.py
index eff137ac67..7d6e8e551d 100644
--- a/invokeai/backend/model_manager/download/queue.py
+++ b/invokeai/backend/model_manager/download/queue.py
@@ -12,7 +12,7 @@ import traceback
 from json import JSONDecodeError
 from pathlib import Path
 from requests import HTTPError
-from typing import Dict, Optional, Set, List, Tuple
+from typing import Dict, Optional, Set, List, Tuple, Union
+import shutil
 
 from pydantic import Field, validator, ValidationError
 from pydantic.networks import AnyHttpUrl
@@ -32,6 +32,9 @@ from .base import (
 )
 from ..storage import DuplicateModelException
 
+# Maximum number of bytes to download during each call to requests.iter_content()
+DOWNLOAD_CHUNK_SIZE = 100000
+
 # marker that the queue is done and that thread should exit
 STOP_JOB = DownloadJobBase(id=-99, priority=-99, source="dummy", destination="/")
 
@@ -61,6 +64,12 @@ class DownloadJobRepoID(DownloadJobBase):
         return v
 
 
+class DownloadJobPath(DownloadJobBase):
+    """Handle file paths."""
+
+    source: Path = Field(description="Path to a file or directory to install")
+
+
 class DownloadQueue(DownloadQueueBase):
     """Class for queued download of models."""
 
@@ -74,6 +83,10 @@ class DownloadQueue(DownloadQueueBase):
     _sequence: int = 0  # This is for debugging and used to tag jobs in dequeueing order
     _requests: requests.sessions.Session
 
+    # for debugging
+    _gets: int = 0
+    _dones: int = 0
+
     def __init__(
         self,
         max_parallel_dl: int = 5,
@@ -99,9 +112,13 @@ class DownloadQueue(DownloadQueueBase):
 
         self._start_workers(max_parallel_dl)
 
+        # debugging - get rid of this
+        self._gets = 0
+        self._dones = 0
+
     def create_download_job(
         self,
-        source: str,
+        source: Union[str, Path, AnyHttpUrl],
         destdir: Path,
         filename: Optional[Path] = None,
         start: bool = True,
@@ -110,12 +127,18 @@ class DownloadQueue(DownloadQueueBase):
         event_handlers: Optional[List[DownloadEventHandler]] = None,
     ) -> DownloadJobBase:
         """Create a download job and return its ID."""
-        if re.match(r"^[\w-]+/[\w-]+$", source):
+        kwargs = dict()
+
+        if Path(source).exists():
+            cls = DownloadJobPath
+        elif re.match(r"^[\w-]+/[\w-]+$", str(source)):
             cls = DownloadJobRepoID
             kwargs = dict(variant=variant)
-        else:
+        elif re.match(r"^https?://", str(source)):
             cls = DownloadJobURL
-            kwargs = dict()
+        else:
+            raise NotImplementedError(f"Don't know what to do with this type of source: {source}")
+
         try:
             self._lock.acquire()
             id = self._next_job_id
@@ -160,7 +183,7 @@ class DownloadQueue(DownloadQueueBase):
         finally:
             self._lock.release()
 
-    def cancel_job(self, job: DownloadJobBase):
+    def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False):
         """
         Cancel the indicated job.
 
@@ -170,7 +193,10 @@ class DownloadQueue(DownloadQueueBase):
         try:
             self._lock.acquire()
             assert isinstance(self._jobs[job.id], DownloadJobBase)
+            job.preserve_partial_downloads = preserve_partial
             self._update_job_status(job, DownloadJobStatus.CANCELLED)
+            if job.subqueue:
+                job.subqueue.cancel_all_jobs(preserve_partial=preserve_partial)
         except (AssertionError, KeyError) as excp:
             raise UnknownJobIDException("Unrecognized job") from excp
         finally:
@@ -196,13 +222,16 @@ class DownloadQueue(DownloadQueueBase):
         """
         Pause (dequeue) the indicated job.
 
-        In theory the job can be restarted and the download will pick up
+        The job can be restarted with start_job() and the download will pick up
         from where it left off.
""" try: self._lock.acquire() assert isinstance(self._jobs[job.id], DownloadJobBase) self._update_job_status(job, DownloadJobStatus.PAUSED) + if job.subqueue: + job.subqueue.cancel_all_jobs(preserve_partial=True) + job.subqueue.release() except (AssertionError, KeyError) as excp: raise UnknownJobIDException("Unrecognized job") from excp finally: @@ -213,7 +242,7 @@ class DownloadQueue(DownloadQueueBase): try: self._lock.acquire() for job in self._jobs.values(): - if job.status in [DownloadJobStatus.IDLE or DownloadJobStatus.PAUSED]: + if job.status in [DownloadJobStatus.IDLE, DownloadJobStatus.PAUSED]: self.start_job(job) finally: self._lock.release() @@ -222,19 +251,19 @@ class DownloadQueue(DownloadQueueBase): """Pause all running jobs.""" try: self._lock.acquire() - for id, job in self._jobs.items(): - if job.stats == DownloadJobStatus.RUNNING: - self.pause_job(id) + for job in self._jobs.values(): + if job.status == DownloadJobStatus.RUNNING: + self.pause_job(job) finally: self._lock.release() - def cancel_all_jobs(self): + def cancel_all_jobs(self, preserve_partial: bool = False): """Cancel all running jobs.""" try: self._lock.acquire() - for id, job in self._jobs.items(): + for job in self._jobs.values(): if not self._in_terminal_state(job): - self.cancel_job(id) + self.cancel_job(job, preserve_partial) finally: self._lock.release() @@ -254,10 +283,12 @@ class DownloadQueue(DownloadQueueBase): def _download_next_item(self): """Worker thread gets next job on priority queue.""" - while True: + done = False + while not done: job = self._queue.get() + self._gets += 1 - try: + try: # this is for debugging priority self._lock.acquire() job.job_sequence = self._sequence self._sequence += 1 @@ -265,13 +296,16 @@ class DownloadQueue(DownloadQueueBase): self._lock.release() if job == STOP_JOB: # marker that queue is done - break + done = True if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for non-enqueued jobs (shouldn't happen) + # There should be a better way to dispatch on the job type if isinstance(job, DownloadJobURL): self._download_with_resume(job) elif isinstance(job, DownloadJobRepoID): self._download_repoid(job) + elif isinstance(job, DownloadJobPath): + self._download_path(job) else: raise NotImplementedError(f"Don't know what to do with this job: {job}") @@ -280,6 +314,8 @@ class DownloadQueue(DownloadQueueBase): if self._in_terminal_state(job): del self._jobs[job.id] + + self._dones += 1 self._queue.task_done() def _fetch_metadata(self, job: DownloadJobBase) -> Tuple[AnyHttpUrl, ModelSourceMetadata]: @@ -375,7 +411,7 @@ class DownloadQueue(DownloadQueueBase): if resp.status_code == 206 or exist_size > 0: self._logger.warning(f"{dest}: partial file found. 
Resuming") elif resp.status_code != 200: - raise HTTPError(f"status code {resp.status_code}: {resp.reason}") + raise HTTPError(resp.reason) else: self._logger.info(f"{job.source}: Downloading {job.destination}") @@ -384,7 +420,7 @@ class DownloadQueue(DownloadQueueBase): self._update_job_status(job, DownloadJobStatus.RUNNING) with open(dest, open_mode) as file: - for data in resp.iter_content(chunk_size=100000): + for data in resp.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE): if job.status != DownloadJobStatus.RUNNING: # cancelled, paused or errored return job.bytes += file.write(data) @@ -393,6 +429,8 @@ class DownloadQueue(DownloadQueueBase): self._update_job_status(job) self._update_job_status(job, DownloadJobStatus.COMPLETED) + except KeyboardInterrupt as excp: + raise excp except DuplicateModelException as excp: self._logger.error(f"A model with the same hash as {dest} is already installed.") job.error = excp @@ -426,6 +464,8 @@ class DownloadQueue(DownloadQueueBase): for handler in job.event_handlers: try: handler(job) + except KeyboardInterrupt as excp: + raise excp except Exception as excp: job.error = excp self._update_job_status(job, DownloadJobStatus.ERROR) @@ -442,7 +482,7 @@ class DownloadQueue(DownloadQueueBase): if subjob.status == DownloadJobStatus.ERROR: job.error = subjob.error - subqueue.cancel_all_jobs() + subjob.subqueue.cancel_all_jobs() self._update_job_status(job, DownloadJobStatus.ERROR) return @@ -452,7 +492,7 @@ class DownloadQueue(DownloadQueueBase): self._update_job_status(job, DownloadJobStatus.RUNNING) return - subqueue = self.__class__( + job.subqueue = self.__class__( event_handlers=[subdownload_event], requests_session=self._requests, ) @@ -460,28 +500,32 @@ class DownloadQueue(DownloadQueueBase): repo_id = job.source variant = job.variant urls_to_download, metadata = self._get_repo_info(repo_id, variant) - job.destination = job.destination / Path(repo_id).name + if job.destination.stem != Path(repo_id).stem: + job.destination = job.destination / Path(repo_id).stem job.metadata = metadata bytes_downloaded = dict() + job.total_bytes = 0 for url, subdir, file, size in urls_to_download: job.total_bytes += size - subqueue.create_download_job( + job.subqueue.create_download_job( source=url, destdir=job.destination / subdir, filename=file, variant=variant, access_token=job.access_token, ) + except KeyboardInterrupt as excp: + raise excp except Exception as excp: job.error = excp self._update_job_status(job, DownloadJobStatus.ERROR) self._logger.error(job.error) finally: - subqueue.join() - if not job.status == DownloadJobStatus.ERROR: + job.subqueue.join() + if job.status == DownloadJobStatus.RUNNING: self._update_job_status(job, DownloadJobStatus.COMPLETED) - subqueue.release() # get rid of the subqueue + job.subqueue.release() # get rid of the subqueue def _get_repo_info( self, @@ -543,7 +587,22 @@ class DownloadQueue(DownloadQueueBase): result.add(v) return result + def _download_path(self, job: DownloadJobBase): + """Call when the source is a Path or pathlike object.""" + source = Path(job.source).resolve() + destination = Path(job.destination).resolve() + job.metadata = ModelSourceMetadata() + try: + if source != destination: + shutil.move(source, destination) + self._update_job_status(job, DownloadJobStatus.COMPLETED) + except OSError as excp: + job.error = excp + self._update_job_status(job, DownloadJobStatus.ERROR) + def _cleanup_cancelled_job(self, job: DownloadJobBase): + if job.preserve_partial_downloads: + return self._logger.warning("Cleaning up 
diff --git a/invokeai/backend/model_manager/install.py b/invokeai/backend/model_manager/install.py
index 4dac223fcf..740eae0da0 100644
--- a/invokeai/backend/model_manager/install.py
+++ b/invokeai/backend/model_manager/install.py
@@ -15,17 +15,18 @@ Typical usage:
   installer = ModelInstall(store=store, config=config, download=download)
 
   # register config, don't move path
-  id: str = installer.register_model('/path/to/model')
+  id: str = installer.register_path('/path/to/model')
 
   # register config, and install model in `models`
-  id: str = installer.install_model('/path/to/model')
+  id: str = installer.install_path('/path/to/model')
 
   # download some remote models and install them in the background
-  installer.download('stabilityai/stable-diffusion-2-1')
-  installer.download('https://civitai.com/api/download/models/154208')
-  installer.download('runwayml/stable-diffusion-v1-5')
+  installer.install('stabilityai/stable-diffusion-2-1')
+  installer.install('https://civitai.com/api/download/models/154208')
+  installer.install('runwayml/stable-diffusion-v1-5')
+  installer.install('/home/user/models/stable-diffusion-v1-5', inplace=True)
 
-  installed_ids = installer.wait_for_downloads()
+  installed_ids = installer.wait_for_installs()
   id1 = installed_ids['stabilityai/stable-diffusion-2-1']
   id2 = installed_ids['https://civitai.com/api/download/models/154208']
 
@@ -94,8 +95,14 @@ class ModelInstallBase(ABC):
         """
         pass
 
+    @property
     @abstractmethod
-    def register(self, model_path: Union[Path, str]) -> str:
+    def queue(self) -> DownloadQueueBase:
+        """Return the download queue used by the installer."""
+        pass
+
+    @abstractmethod
+    def register_path(self, model_path: Union[Path, str]) -> str:
         """
         Probe and register the model at model_path.
 
@@ -105,7 +112,7 @@ class ModelInstallBase(ABC):
         pass
 
     @abstractmethod
-    def install(self, model_path: Union[Path, str]) -> str:
+    def install_path(self, model_path: Union[Path, str]) -> str:
         """
         Probe, register and install the model in the models directory.
 
@@ -118,9 +125,11 @@ class ModelInstallBase(ABC):
         pass
 
     @abstractmethod
-    def download(self, source: Union[str, AnyHttpUrl]) -> DownloadJobBase:
+    def install(
+        self, source: Union[str, Path, AnyHttpUrl], inplace: bool = True, variant: Optional[str] = None
+    ) -> DownloadJobBase:
         """
-        Download and install the model located at remote site.
+        Download and install the indicated model.
 
         This will download the model located at `source`,
         probe it, and install it into the models directory.
@@ -128,18 +137,25 @@ class ModelInstallBase(ABC):
         thread, and the returned object is a
         invokeai.backend.model_manager.download.DownloadJobBase
         object which can be interrogated to get the status of
-        the download and install process. Call our `wait_for_downloads()`
-        method to wait for all downloads to complete.
+        the download and install process. Call our `wait_for_installs()`
+        method to wait for all downloads and installations to complete.
 
-        :param source: Either a URL or a HuggingFace repo_id.
-        :returns queue: DownloadQueueBase object.
+        :param source: A URL, a HuggingFace repo_id, or a local Path.
+        :param inplace: If True, local paths will not be moved into
+        the models directory, but registered in place (the default).
+        :param variant: For HuggingFace models, this optional parameter
+        specifies which variant to download (e.g. 'fp16')
+        :returns: DownloadJobBase object.
+
+        The `inplace` flag does not affect the behavior of downloaded
+        models, which are always moved into the `models` directory.
         """
         pass
 
     @abstractmethod
-    def wait_for_downloads(self) -> Dict[str, str]:
+    def wait_for_installs(self) -> Dict[str, str]:
         """
-        Wait for all pending downloads to complete.
+        Wait for all pending installs to complete.
 
         This will block until all pending downloads have
         completed, been cancelled, or errored out. It will
@@ -147,7 +163,7 @@ class ModelInstallBase(ABC):
         paused state. It will return a dict that maps the source
-        URL or repo_id to the ID of the installed model.
+        path, URL or repo_id to the ID of the installed model.
         """
         pass
 
@@ -259,7 +275,12 @@ class ModelInstall(ModelInstallBase):
         self._async_installs = dict()
         self._tmpdir = None
 
+    @property
+    def queue(self) -> DownloadQueueBase:
+        """Return the queue."""
+        return self._download_queue
+
-    def register(self, model_path: Union[Path, str]) -> str:  # noqa D102
+    def register_path(self, model_path: Union[Path, str]) -> str:  # noqa D102
         model_path = Path(model_path)
         info: ModelProbeInfo = ModelProbe.probe(model_path)
         return self._register(model_path, info)
@@ -293,7 +314,7 @@ class ModelInstall(ModelInstallBase):
         self._store.add_model(id, registration_data)
         return id
 
-    def install(self, model_path: Union[Path, str]) -> str:  # noqa D102
+    def install_path(self, model_path: Union[Path, str]) -> str:  # noqa D102
         model_path = Path(model_path)
         info: ModelProbeInfo = ModelProbe.probe(model_path)
         dest_path = self._config.models_path / info.base_type.value / info.model_type.value / model_path.name
@@ -318,20 +339,21 @@ class ModelInstall(ModelInstallBase):
             rmtree(model.path)
         self.unregister(id)
 
-    def download(self, source: Union[str, AnyHttpUrl]) -> DownloadJobBase:  # noqa D102
+    def install(
+        self, source: Union[str, Path, AnyHttpUrl], inplace: bool = True, variant: Optional[str] = None
+    ) -> DownloadJobBase:  # noqa D102
         # choose a temporary directory inside the models directory
         models_dir = self._config.models_path
         queue = self._download_queue
-        self._async_installs[source] = None
 
         def complete_installation(job: DownloadJobBase):
             if job.status == "completed":
                 self._logger.info(f"{job.source}: Download finished with status {job.status}. Installing.")
-                model_id = self.install(job.destination)
+                model_id = self.install_path(job.destination)
                 info = self._store.get_model(model_id)
                 info.source = str(job.source)
                 metadata: ModelSourceMetadata = job.metadata
-                info.description = metadata.description or f"Downloaded model {info.name}"
+                info.description = metadata.description or f"Imported model {info.name}"
                 info.author = metadata.author
                 info.tags = metadata.tags
                 info.license = metadata.license
@@ -339,22 +361,46 @@ class ModelInstall(ModelInstallBase):
                 self._store.update_model(model_id, info)
                 self._async_installs[job.source] = model_id
             elif job.status == "error":
-                self._logger.warning(f"{job.source}: Download finished with error: {job.error}")
+                self._logger.warning(f"{job.source}: Model installation error: {job.error}")
             elif job.status == "cancelled":
-                self._logger.warning(f"{job.source}: Download cancelled at caller's request.")
+                self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
             jobs = queue.list_jobs()
-            if len(jobs) <= 1 and job.status in ["completed", "error", "cancelled"]:
+            if self._tmpdir and len(jobs) <= 1 and job.status in ["completed", "error", "cancelled"]:
                 self._tmpdir.cleanup()
                 self._tmpdir = None
 
-        # note - this is probably not going to work. The tmpdir
-        # will be deleted before the job actually runs.
-        # Better to do the cleanup in the callback
-        self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
-        job = queue.create_download_job(source=source, destdir=self._tmpdir.name)
-        job.add_event_handler(complete_installation)
+        def complete_registration(job: DownloadJobBase):
+            if job.status == "completed":
+                self._logger.info(f"{job.source}: Installing in place.")
+                model_id = self.register_path(job.destination)
+                info = self._store.get_model(model_id)
+                info.source = str(job.source)
+                info.description = f"Imported model {info.name}"
+                self._store.update_model(model_id, info)
+                self._async_installs[job.source] = model_id
+            elif job.status == "error":
+                self._logger.warning(f"{job.source}: Model installation error: {job.error}")
+            elif job.status == "cancelled":
+                self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
+
+        # In the event that we are being asked to install a path that is already on disk,
+        # we simply probe and register/install it. The job does not actually do anything, but we
+        # create one anyway in order to have similar behavior for local files, URLs and repo_ids.
+        if Path(source).exists():  # a path that is already on disk
+            source = Path(source)
+            destdir = source
+            job = queue.create_download_job(source=source, destdir=destdir, start=False, variant=variant)
+            job.add_event_handler(complete_registration if inplace else complete_installation)
+        else:
+            self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
+            job = queue.create_download_job(source=source, destdir=self._tmpdir.name, start=False, variant=variant)
+            job.add_event_handler(complete_installation)
+
+        self._async_installs[source] = None
+        queue.start_job(job)
+        return job
 
-    def wait_for_downloads(self) -> Dict[str, str]:  # noqa D102
+    def wait_for_installs(self) -> Dict[str, str]:  # noqa D102
         self._download_queue.join()
         id_map = self._async_installs
         self._async_installs = dict()
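
The installer API renames register to register_path, install to
install_path, and download to install. A sketch of the intended calling
pattern, taken from the module docstring above (the store/config/download
objects and all paths are illustrative):

    installer = ModelInstall(store=store, config=config, download=download)

    # remote sources are downloaded to a tmpdir, then moved into `models`
    installer.install("stabilityai/stable-diffusion-2-1")

    # local paths are registered in place by default (inplace=True)
    installer.install("/home/user/models/stable-diffusion-v1-5")

    # maps each source (path, URL or repo_id) to the installed model ID
    ids = installer.wait_for_installs()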
diff --git a/tests/test_model_download.py b/tests/test_model_download.py
index 4730d815f8..400408b945 100644
--- a/tests/test_model_download.py
+++ b/tests/test_model_download.py
@@ -13,9 +13,17 @@ from invokeai.backend.model_manager.download import (
     DownloadJobBase,
     UnknownJobIDException,
 )
+import invokeai.backend.model_manager.download.queue as download_queue
 
+# Allow for at least one chunk to be fetched during the pause/unpause test.
+# Otherwise the pause test doesn't work, because the whole file contents are
+# read before the pause request is received.
+download_queue.DOWNLOAD_CHUNK_SIZE = 16500
+
+# Prevent pytest deprecation warnings
 TestAdapter.__test__ = False
 
+# Disable some tests that require the internet.
 INTERNET_AVAILABLE = requests.get("http://www.google.com/").status_code == 200
 
 ########################################################################################
@@ -264,7 +272,6 @@ def test_pause_cancel_url():  # this one is tricky because of potential race con
         time.sleep(0.5)  # slow down the thread by blocking it just a bit at every step
 
     queue = DownloadQueue(requests_session=session, event_handlers=[event_handler])
-
     with tempfile.TemporaryDirectory() as tmpdir:
         job1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
         job2 = queue.create_download_job(source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False)
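
DOWNLOAD_CHUNK_SIZE is deliberately a module-level constant so that tests
can shrink it, as the hunk above does. If a per-test override is preferred,
the same effect can be had with pytest's monkeypatch fixture (a sketch,
assuming the suite runs under pytest; the test name is hypothetical):

    import invokeai.backend.model_manager.download.queue as download_queue

    def test_pause_uses_small_chunks(monkeypatch):
        # Shrink the chunk size for this test only; monkeypatch restores
        # the original value automatically when the test finishes.
        monkeypatch.setattr(download_queue, "DOWNLOAD_CHUNK_SIZE", 16500)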