mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
added download manager service and began repo_id download
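The commit wraps the backend download queue behind an app-level service. A minimal usage sketch, assuming an already-constructed event bus service (the `events` variable and the URL below are illustrative, not part of this commit):

from pathlib import Path
from invokeai.app.services.download_manager import DownloadQueueService

queue = DownloadQueueService(event_bus=events)  # `events` is a hypothetical EventServicesBase instance
# a plain URL downloads a single file; an "owner/name" string is treated as a HuggingFace repo_id
job_id = queue.create_download_job(
    source="https://example.com/models/model.safetensors",  # illustrative URL
    destdir=Path("/tmp/models"),
)
queue.join()  # block until every job has left the queue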
178
invokeai/app/services/download_manager.py
Normal file
@@ -0,0 +1,178 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Model download service.
"""

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, List

from .events import EventServicesBase
from invokeai.backend.model_manager.download import DownloadQueue, DownloadJobBase, DownloadEventHandler


class DownloadQueueServiceBase(ABC):
    """Multithreaded queue for downloading models via URL or repo_id."""

    @abstractmethod
    def create_download_job(
        self,
        source: str,
        destdir: Path,
        filename: Optional[Path] = None,
        start: bool = True,
        access_token: Optional[str] = None,
    ) -> int:
        """
        Create a download job.

        :param source: Source of the download - URL or repo_id
        :param destdir: Directory to download into.
        :param filename: Optional name of the file. If not provided,
           the content-disposition field will be used to assign the name.
        :param start: Immediately start the job [True]
        :returns job id: The numeric ID of the DownloadJobBase object for this task.
        """
        pass

    @abstractmethod
    def list_jobs(self) -> List[DownloadJobBase]:
        """
        List active DownloadJobBases.

        :returns List[DownloadJobBase]: List of download jobs whose state is not "completed."
        """
        pass

    @abstractmethod
    def id_to_job(self, id: int) -> DownloadJobBase:
        """
        Return the DownloadJobBase corresponding to the numeric job ID.

        :param id: ID of the DownloadJobBase.

        Exceptions:
        * UnknownJobIDException
        """
        pass

    @abstractmethod
    def start_all_jobs(self):
        """Enqueue all stopped jobs."""
        pass

    @abstractmethod
    def pause_all_jobs(self):
        """Pause and dequeue all active jobs."""
        pass

    @abstractmethod
    def cancel_all_jobs(self):
        """Cancel all active and enqueued jobs."""
        pass

    @abstractmethod
    def start_job(self, id: int):
        """Start the job, putting it into ENQUEUED state."""
        pass

    @abstractmethod
    def pause_job(self, id: int):
        """Pause the job, putting it into PAUSED state."""
        pass

    @abstractmethod
    def cancel_job(self, id: int):
        """Cancel the job, clearing partial downloads and putting it into ERROR state."""
        pass

    @abstractmethod
    def change_priority(self, id: int, delta: int):
        """
        Change the job's priority.

        :param id: ID of the job
        :param delta: Value by which to increment or decrement the priority.

        Lower values are higher priority. The default starting value is 10.
        Thus to make this a really high priority job:
           job.change_priority(-10)
        """
        pass

    @abstractmethod
    def join(self):
        """Wait until all jobs are off the queue."""
        pass


class DownloadQueueService(DownloadQueueServiceBase):
    """Multithreaded queue for downloading models via URL or repo_id."""

    _event_bus: EventServicesBase
    _queue: DownloadQueue

    def __init__(self, event_bus: EventServicesBase, **kwargs):
        """
        Initialize a new DownloadQueueService object.

        :param event_bus: EventServicesBase object for reporting progress.
        :param **kwargs: Any of the arguments taken by invokeai.backend.model_manager.download.DownloadQueue,
           e.g. `max_parallel_dl`.
        """
        self._event_bus = event_bus
        self._queue = DownloadQueue(event_handler=self._forward_event)

    def _forward_event(self, job: DownloadJobBase):
        if self._event_bus:
            self._event_bus.emit_model_download_event(job)

    def _wrap_handler(self, event_handler: DownloadEventHandler) -> DownloadEventHandler:
        def __wrapper(job: DownloadJobBase):
            self._forward_event(job)
            event_handler(job)
        return __wrapper

    def create_download_job(
        self,
        source: str,
        destdir: Path,
        filename: Optional[Path] = None,
        start: bool = True,
        access_token: Optional[str] = None,
        event_handler: Optional[DownloadEventHandler] = None,
    ) -> int:
        if event_handler:
            event_handler = self._wrap_handler(event_handler)
        return self._queue.create_download_job(
            source, destdir, filename, start, access_token,
            event_handler=event_handler
        )

    def list_jobs(self) -> List[DownloadJobBase]:
        return self._queue.list_jobs()

    def id_to_job(self, id: int) -> DownloadJobBase:
        return self._queue.id_to_job(id)

    def start_all_jobs(self):
        return self._queue.start_all_jobs()

    def pause_all_jobs(self):
        return self._queue.pause_all_jobs()

    def cancel_all_jobs(self):
        return self._queue.cancel_all_jobs()

    def start_job(self, id: int):
        return self._queue.start_job(id)

    def pause_job(self, id: int):
        return self._queue.pause_job(id)

    def cancel_job(self, id: int):
        return self._queue.cancel_job(id)

    def change_priority(self, id: int, delta: int):
        return self._queue.change_priority(id, delta)

    def join(self):
        return self._queue.join()
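A per-job handler passed to the service is wrapped by `_wrap_handler` so that every status change is both forwarded to the event bus and delivered to the caller. A sketch of how a caller might use this (the handler body, the repo_id, and the `events` bus are illustrative):

from pathlib import Path
from invokeai.app.services.download_manager import DownloadQueueService
from invokeai.backend.model_manager.download import DownloadJobBase

def on_change(job: DownloadJobBase) -> None:
    # invoked on every status change, in addition to the event-bus forwarding
    print(f"job {job.id}: {job.status} {job.bytes}/{job.total_bytes} bytes")

queue = DownloadQueueService(event_bus=events)  # `events` is the app's event service (assumed to exist)
queue.create_download_job(
    source="owner/repo_name",      # illustrative repo_id-style source
    destdir=Path("/tmp/models"),
    event_handler=on_change,       # wrapped before being handed to the backend queue
)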
12
invokeai/backend/model_manager/download/__init__.py
Normal file
@@ -0,0 +1,12 @@
"""Initialization file for threaded download manager."""

from .base import (  # noqa F401
    DownloadQueueBase,
    DownloadJobStatus,
    DownloadEventHandler,
    UnknownJobIDException,
    CancelledJobException,
    DownloadJobBase
)

from .queue import DownloadQueue  # noqa F401
177
invokeai/backend/model_manager/download/base.py
Normal file
@@ -0,0 +1,177 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Abstract base class for a multithreaded model download queue.
"""

from abc import ABC, abstractmethod
from enum import Enum
from functools import total_ordering
from pathlib import Path
from typing import Set, List, Optional, Dict, Callable

from pydantic import BaseModel, Field, validator, ValidationError
from pydantic.networks import AnyHttpUrl


class DownloadJobStatus(str, Enum):
    """State of a download job."""

    IDLE = "idle"            # not enqueued, will not run
    ENQUEUED = "enqueued"    # enqueued but not yet active
    RUNNING = "running"      # actively downloading
    PAUSED = "paused"        # previously started, now paused
    COMPLETED = "completed"  # finished running
    ERROR = "error"          # terminated with an error message


class UnknownJobIDException(Exception):
    """Raised when an invalid Job ID is requested."""


class CancelledJobException(Exception):
    """Raised when a job is cancelled."""


DownloadEventHandler = Callable[["DownloadJobBase"], None]


@total_ordering
class DownloadJobBase(BaseModel):
    """Class to monitor and control a model download request."""

    priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
    id: int = Field(description="Numeric ID of this job")
    source: str = Field(description="URL or repo_id to download")
    destination: Path = Field(description="Destination of URL on local disk")
    access_token: Optional[str] = Field(description="access token needed to access this resource")
    status: DownloadJobStatus = Field(default=DownloadJobStatus.IDLE, description="Status of the download")
    bytes: int = Field(default=0, description="Bytes downloaded so far")
    total_bytes: int = Field(default=0, description="Total bytes to download")
    event_handler: Optional[DownloadEventHandler] = Field(description="Callable that will be called whenever job status changes")
    error: Optional[Exception] = Field(default=None, description="Exception that caused an error")

    class Config():
        """Config object for this pydantic class."""

        arbitrary_types_allowed = True
        validate_assignment = True

    def __lt__(self, other: "DownloadJobBase") -> bool:
        """
        Return True if self.priority < other.priority.

        :param other: The DownloadJobBase that this will be compared against.
        """
        if not hasattr(other, "id"):
            return NotImplemented
        return self.id < other.id


class DownloadQueueBase(ABC):
    """Abstract base class for managing model downloads."""

    @abstractmethod
    def create_download_job(
        self,
        source: str,
        destdir: Path,
        filename: Optional[Path] = None,
        start: bool = True,
        variant: Optional[str] = None,
        access_token: Optional[str] = None,
        event_handler: Optional[DownloadEventHandler] = None,
    ) -> int:
        """
        Create a download job.

        :param source: Source of the download - URL or repo_id
        :param destdir: Directory to download into.
        :param filename: Optional name of the file. If not provided,
           the content-disposition field will be used to assign the name.
        :param start: Immediately start the job [True]
        :param variant: Variant to download, such as "fp16" (repo_ids only).
        :param event_handler: Optional callable that will be called whenever job status changes.
        :returns job id: The numeric ID of the DownloadJobBase object for this task.
        """
        pass

    @abstractmethod
    def release(self) -> int:
        """
        Release resources used by the queue.

        If threaded downloads are used, then this will stop the threads.
        """
        pass

    @abstractmethod
    def list_jobs(self) -> List[DownloadJobBase]:
        """
        List active DownloadJobBases.

        :returns List[DownloadJobBase]: List of download jobs whose state is not "completed."
        """
        pass

    @abstractmethod
    def id_to_job(self, id: int) -> DownloadJobBase:
        """
        Return the DownloadJobBase corresponding to the numeric job ID.

        :param id: ID of the DownloadJobBase.

        Exceptions:
        * UnknownJobIDException
        """
        pass

    @abstractmethod
    def start_all_jobs(self):
        """Enqueue all stopped jobs."""
        pass

    @abstractmethod
    def pause_all_jobs(self):
        """Pause and dequeue all active jobs."""
        pass

    @abstractmethod
    def cancel_all_jobs(self):
        """Cancel all active and enqueued jobs."""
        pass

    @abstractmethod
    def start_job(self, id: int):
        """Start the job, putting it into ENQUEUED state."""
        pass

    @abstractmethod
    def pause_job(self, id: int):
        """Pause the job, putting it into PAUSED state."""
        pass

    @abstractmethod
    def cancel_job(self, id: int):
        """Cancel the job, clearing partial downloads and putting it into ERROR state."""
        pass

    @abstractmethod
    def change_priority(self, id: int, delta: int):
        """
        Change the job's priority.

        :param id: ID of the job
        :param delta: Value by which to increment or decrement the priority.

        Lower values are higher priority. The default starting value is 10.
        Thus to make this a really high priority job:
           job.change_priority(-10)
        """
        pass

    @abstractmethod
    def join(self):
        """Wait until all jobs are off the queue."""
        pass
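Since DownloadEventHandler is just a callable that receives the job, progress reporting can be built directly on the pydantic fields above. A small sketch of a handler attached to a backend DownloadQueue (the percentage arithmetic and print formatting are the only things added here, not part of this commit):

from invokeai.backend.model_manager.download import DownloadQueue, DownloadJobBase

def log_progress(job: DownloadJobBase) -> None:
    # total_bytes is 0 until the first response headers arrive, so guard the division
    pct = (100 * job.bytes / job.total_bytes) if job.total_bytes else 0
    print(f"[{job.id}] {job.source} -> {job.destination}: {job.status.value} ({pct:.0f}%)")

queue = DownloadQueue(event_handler=log_progress)  # the handler is invoked on every status change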
@@ -1,241 +1,113 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Abstract base class for a multithreaded model download queue.
"""
# Copyright (c) 2023, Lincoln D. Stein
"""Implementation of multithreaded download queue for invokeai."""

import os
import re
import os
import requests
import threading

from abc import ABC, abstractmethod
from enum import Enum
from functools import total_ordering
from queue import PriorityQueue
from pathlib import Path
from threading import Thread
from typing import Set, List, Optional, Dict, Callable
from typing import Dict, Optional, Set, List

from pydantic import BaseModel, Field, validator, PrivateAttr
from pydantic import Field, validator, ValidationError
from pydantic.networks import AnyHttpUrl
from queue import PriorityQueue

from invokeai.backend.util.logging import InvokeAILogger
from .base import (
    DownloadQueueBase,
    DownloadJobStatus,
    DownloadEventHandler,
    UnknownJobIDException,
    CancelledJobException,
    DownloadJobBase
)

# marker that the queue is done and that thread should exit
STOP_JOB = DownloadJobBase(
    id=-99,
    priority=-99,
    source='dummy',
    destination='/')


class EventServicesBase:  # forward declaration
    pass


class DownloadJobURL(DownloadJobBase):
    """Job declaration for downloading individual URLs."""


class DownloadJobStatus(str, Enum):
    """State of a download job."""

    IDLE = "idle"            # not enqueued, will not run
    ENQUEUED = "enqueued"    # enqueued but not yet active
    RUNNING = "running"      # actively downloading
    PAUSED = "paused"        # previously started, now paused
    COMPLETED = "completed"  # finished running
    ERROR = "error"          # terminated with an error message

    source: AnyHttpUrl = Field(description="URL to download")


class UnknownJobIDException(Exception):
    """Raised when an invalid Job ID is requested."""


class DownloadJobRepoID(DownloadJobBase):
    """Download repo ids."""

    variant: Optional[str] = Field(description="Variant, such as 'fp16', to download")


class CancelledJobException(Exception):
    """Raised when a job is cancelled."""


DownloadEventHandler = Callable[["DownloadJobBase"], None]


@total_ordering
class DownloadJob(BaseModel):
    """Class to monitor and control a model download request."""

    priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
    id: int = Field(description="Numeric ID of this job")
    url: AnyHttpUrl = Field(description="URL to download")
    destination: Path = Field(description="Destination of URL on local disk")
    access_token: Optional[str] = Field(description="access token needed to access this resource")
    status: DownloadJobStatus = Field(default=DownloadJobStatus.IDLE, description="Status of the download")
    bytes: int = Field(default=0, description="Bytes downloaded so far")
    total_bytes: int = Field(default=0, description="Total bytes to download")
    event_handler: Optional[DownloadEventHandler] = Field(description="Callable will be called whenever job status changes")
    error: Exception = Field(default=None, description="Exception that caused an error")

    class Config():
        """Config object for this pydantic class."""

        arbitrary_types_allowed = True
        validate_assignment = True

    # @validator('destination')
    # def path_doesnt_exist(cls, v):
    #     """Don't allow a destination to clobber an existing file."""
    #     if v.exists():
    #         raise ValueError(f"{v} already exists")
    #     return v

    def __lt__(self, other: "DownloadJob") -> bool:
        """
        Return True if self.priority < other.priority.

        :param other: The DownloadJob that this will be compared against.
        """
        if not hasattr(other, "id"):
            return NotImplemented
        return self.id < other.id


class DownloadQueueBase(ABC):
    """Abstract base class for managing model downloads."""

    @abstractmethod
    def create_download_job(
        self,
        url: str,
        destdir: Path,
        filename: Optional[Path] = None,
        start: bool = True,
        access_token: Optional[str] = None,
        event_handler: Optional[DownloadEventHandler] = None,
    ) -> int:
        """
        Create a download job.

        :param url: URL to download.
        :param destdir: Directory to download into.
        :param filename: Optional name of file, if not provided
           will use the content-disposition field to assign the name.
        :param start: Immediately start job [True]
        :param event_handler: Optional callable that will be called whenever job status changes.
        :returns job id: The numeric ID of the DownloadJob object for this task.
        """
        pass

    @abstractmethod
    def list_jobs(self) -> List[DownloadJob]:
        """
        List active DownloadJobs.

        :returns List[DownloadJob]: List of download jobs whose state is not "completed."
        """
        pass

    @abstractmethod
    def id_to_job(self, id: int) -> DownloadJob:
        """
        Return the DownloadJob corresponding to the string ID.

        :param id: ID of the DownloadJob.

        Exceptions:
        * UnknownJobIDException
        """
        pass

    @abstractmethod
    def start_all_jobs(self):
        """Enqueue all stopped jobs."""
        pass

    @abstractmethod
    def pause_all_jobs(self):
        """Pause and dequeue all active jobs."""
        pass

    @abstractmethod
    def cancel_all_jobs(self):
        """Cancel all active and enqueued jobs."""
        pass

    @abstractmethod
    def start_job(self, id: int):
        """Start the job, putting it into ENQUEUED state."""
        pass

    @abstractmethod
    def pause_job(self, id: int):
        """Pause the job, putting it into PAUSED state."""
        pass

    @abstractmethod
    def cancel_job(self, id: int):
        """Cancel the job, clearing partial downloads and putting it into ERROR state."""
        pass

    @abstractmethod
    def change_priority(self, id: int, delta: int):
        """
        Change the job's priority.

        :param id: ID of the job
        :param delta: Value to increment or decrement priority.

        Lower values are higher priority. The default starting value is 10.
        Thus to make this a really high priority job:
           job.change_priority(-10).
        """
        pass

    @abstractmethod
    def join(self):
        """Wait until all jobs are off the queue."""
        pass

    @validator('source')
    @classmethod
    def _validate_source(cls, v: str) -> str:
        if not re.match(r'^\w+/\w+$', v):
            raise ValidationError(f'{v} invalid repo_id')
        return v


class DownloadQueue(DownloadQueueBase):
    """Class for queued download of models."""

    _jobs: Dict[int, DownloadJob]
    _worker_pool: Set[Thread]
    _jobs: Dict[int, DownloadJobBase]
    _worker_pool: Set[threading.Thread]
    _queue: PriorityQueue
    _lock: threading.Lock
    _logger: InvokeAILogger
    _event_bus: Optional[EventServicesBase]
    _event_handler: Optional[DownloadEventHandler]
    _next_job_id: int = 0

    def __init__(self,
                 max_parallel_dl: int = 5,
                 events: Optional["EventServicesBase"] = None,
                 event_handler: Optional[DownloadEventHandler] = None,
                 ):
        """
        Initialize DownloadQueue.

        :param max_parallel_dl: Number of simultaneous downloads allowed [5].
        :param events: Optional EventServices bus for reporting events.
        :param event_handler: Optional callable that will be called each time a job status changes.
        """
        print('IN __INIT__')
        self._jobs = dict()
        self._next_job_id = 0
        self._queue = PriorityQueue()
        self._worker_pool = set()
        self._lock = threading.RLock()
        self._logger = InvokeAILogger.getLogger()
        self._event_bus = events
        self._event_handler = event_handler

        self._start_workers(max_parallel_dl)

    def create_download_job(
        self,
        url: str,
        source: str,
        destdir: Path,
        filename: Optional[Path] = None,
        start: bool = True,
        variant: Optional[str] = None,
        access_token: Optional[str] = None,
        event_handler: Optional[DownloadEventHandler] = None,
    ) -> int:
        if re.match(r'^\w+/\w+$', source):
            cls = DownloadJobRepoID
            kwargs = dict(variant=variant)
        else:
            cls = DownloadJobURL
            kwargs = dict()
        try:
            self._lock.acquire()
            id = self._next_job_id
            self._jobs[id] = DownloadJob(
            self._jobs[id] = cls(
                id=id,
                url=url,
                source=source,
                destination=Path(destdir) / (filename or "."),
                access_token=access_token,
                event_handler=(event_handler or self._event_handler),
                **kwargs,
            )
            self._next_job_id += 1
            job = self._jobs[id]
@@ -245,13 +117,18 @@ class DownloadQueue(DownloadQueueBase):
            self.start_job(id)
        return job.id

    def release(self):
        """Signal our threads to exit when queue done."""
        for thread in self._worker_pool:
            if thread.is_alive():
                self._queue.put(STOP_JOB)

    def join(self):
        self._queue.join()

    def list_jobs(self) -> List[DownloadJob]:
    def list_jobs(self) -> List[DownloadJobBase]:
        return self._jobs.values()

    def change_priority(self, id: int, delta: int):
        try:
            self._lock.acquire()
@@ -262,7 +139,7 @@ class DownloadQueue(DownloadQueueBase):
        finally:
            self._lock.release()

    def cancel_job(self, job: DownloadJob):
    def cancel_job(self, job: DownloadJobBase):
        try:
            self._lock.acquire()
            job.status = DownloadJobStatus.ERROR
@@ -272,7 +149,7 @@ class DownloadQueue(DownloadQueueBase):
        finally:
            self._lock.release()

    def id_to_job(self, id: int) -> DownloadJob:
    def id_to_job(self, id: int) -> DownloadJobBase:
        try:
            return self._jobs[id]
        except KeyError as excp:
@@ -322,7 +199,7 @@ class DownloadQueue(DownloadQueueBase):

    def _start_workers(self, max_workers: int):
        for i in range(0, max_workers):
            worker = Thread(target=self._download_next_item, daemon=True)
            worker = threading.Thread(target=self._download_next_item, daemon=True)
            worker.start()
            self._worker_pool.add(worker)

@@ -330,17 +207,25 @@ class DownloadQueue(DownloadQueueBase):
        """Worker thread gets next job on priority queue."""
        while True:
            job = self._queue.get()
            if job == STOP_JOB:  # marker that queue is done
                print(f'DEBUG: thread {threading.current_thread().native_id} exiting')
                break
            if job.status == DownloadJobStatus.ENQUEUED:  # Don't do anything for cancelled or errored jobs
                self._download_with_resume(job)
                if isinstance(job, DownloadJobURL):
                    self._download_with_resume(job)
                elif isinstance(job, DownloadJobRepoID):
                    self._download_repoid(job)
                else:
                    raise NotImplementedError(f"Don't know what to do with this job: {job}")
            self._queue.task_done()

    def _download_with_resume(self, job: DownloadJob):
    def _download_with_resume(self, job: DownloadJobBase):
        """Do the actual download."""
        header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
        open_mode = "wb"
        exist_size = 0

        resp = requests.get(job.url, header, stream=True)
        resp = requests.get(job.source, header, stream=True)
        content_length = int(resp.headers.get("content-length", 0))
        job.total_bytes = content_length

@@ -348,8 +233,9 @@ class DownloadQueue(DownloadQueueBase):
            try:
                file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)
            except AttributeError:
                file_name = os.path.basename(job.url)
            dest = job.destination / file_name
                file_name = os.path.basename(job.source)
            job.destination = job.destination / file_name
            dest = job.destination
        else:
            dest = job.destination
        dest.parent.mkdir(parents=True, exist_ok=True)
@@ -358,7 +244,7 @@ class DownloadQueue(DownloadQueueBase):
            job.bytes = dest.stat().st_size
            header["Range"] = f"bytes={job.bytes}-"
            open_mode = "ab"
            resp = requests.get(job.url, headers=header, stream=True)  # new request with range
            resp = requests.get(job.source, headers=header, stream=True)  # new request with range

        if exist_size > content_length:
            self._logger.warning("corrupt existing file found. re-downloading")
@@ -398,14 +284,18 @@ class DownloadQueue(DownloadQueueBase):
            self._update_job_status(job, DownloadJobStatus.ERROR)

    def _update_job_status(self,
                           job: DownloadJob,
                           job: DownloadJobBase,
                           new_status: Optional[DownloadJobStatus] = None
                           ):
        """Optionally change the job status and send an event indicating a change of state."""
        if new_status:
            job.status = new_status
        if bus := self._event_bus:
            bus.emit_model_download_event(job)
        self._logger.debug(f"Status update for download job {job.id}: {job}")
        if job.event_handler:
            job.event_handler(job)

    def _download_repoid(self, job: DownloadJobBase):
        """Download a job that holds a huggingface repoid."""
        repo_id = job.source
        variant = job.variant
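The diff ends here, and the commit message notes that repo_id downloads are only begun. One plausible continuation, sketched against the public huggingface_hub API rather than anything in this commit (HfApi and hf_hub_download are real huggingface_hub entry points; the helper name and loop are illustrative):

from huggingface_hub import HfApi, hf_hub_download

def _download_repoid_sketch(job):
    """Illustrative only: fetch every file in the repo named by job.source."""
    repo_id = job.source
    for sibling in HfApi().model_info(repo_id).siblings:
        # hf_hub_download returns the path of the file in the local HuggingFace cache
        hf_hub_download(repo_id=repo_id, filename=sibling.rfilename)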