2022-08-13 09:41:13 +00:00
|
|
|
import asyncio
|
|
|
|
import httpx
|
|
|
|
import logging
|
2022-08-14 08:02:41 +00:00
|
|
|
import os
|
2022-08-13 09:41:13 +00:00
|
|
|
import time
|
|
|
|
|
|
|
|
from typing import List, Optional, Union
|
|
|
|
|
|
|
|
from twitchdl.progress import Progress
|
|
|
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Number of bytes in a kibibyte.
KB = 1024

# Size in bytes (256 KiB) passed as `chunk_size` when iterating over a
# streamed response body.
CHUNK_SIZE = 256 * KB
"""How much of a VOD to download in each iteration"""

# NOTE: download_with_retries loops `range(RETRY_COUNT)` times, so this is the
# maximum number of attempts in total (the first try included).
RETRY_COUNT = 5
"""Number of times to retry failed downloads before aborting."""

# Passed as the `timeout` argument when the httpx.AsyncClient is created.
TIMEOUT = 30
"""
Number of seconds to wait before aborting when there is no network activity.
https://www.python-httpx.org/advanced/#timeout-configuration
"""
|
|
|
|
|
|
|
|
|
|
|
|
class TokenBucket:
    """Limit the download speed by strategically inserting sleeps.

    The bucket holds up to ``capacity`` tokens (one token per byte) and is
    refilled at ``rate`` tokens per second. Every downloaded chunk consumes
    tokens; when there are not enough, the caller is put to sleep for as long
    as it takes to earn the missing ones.

    Note that ``advance`` uses ``time.sleep``, which blocks the calling thread.
    """

    def __init__(self, rate: int, capacity: Optional[int] = None):
        """
        Args:
            rate: Refill rate in tokens (bytes) per second. Must be positive.
            capacity: Maximum number of tokens the bucket can hold. Defaults
                to twice the rate when omitted (or zero).

        Raises:
            ValueError: If ``rate`` is not positive.
        """
        if rate <= 0:
            # A non-positive rate would otherwise only surface later as a
            # ZeroDivisionError or a negative sleep inside `advance`.
            raise ValueError(f"rate must be positive, got {rate!r}")

        self.rate: int = rate
        self.capacity: int = capacity or rate * 2
        self.available: int = 0
        # Monotonic clock: immune to wall-clock adjustments (NTP, DST, ...).
        self.last_refilled: float = time.monotonic()

    def advance(self, size: int):
        """Called every time a chunk of data is downloaded.

        Consumes ``size`` tokens, sleeping first if the bucket does not hold
        enough of them.
        """
        self._refill()

        if self.available < size:
            deficit = size - self.available
            time.sleep(deficit / self.rate)

        # May go negative after a sleep; the time slept is credited back by
        # the next refill, which brings the balance back up.
        self.available -= size

    def _refill(self):
        """Increase available capacity according to elapsed time since last refill."""
        now = time.monotonic()
        elapsed = now - self.last_refilled
        refill_amount = int(elapsed * self.rate)

        if refill_amount <= 0:
            # Not even one whole token earned yet. Leave `last_refilled`
            # untouched so the fractional progress is not thrown away.
            return

        if self.available + refill_amount >= self.capacity:
            # Bucket is full, anything earned beyond capacity is discarded.
            self.available = self.capacity
            self.last_refilled = now
        else:
            self.available += refill_amount
            # Only account for the time that produced whole tokens; the
            # remainder carries over to the next refill.
            self.last_refilled += refill_amount / self.rate
|
|
|
|
|
|
|
|
|
|
|
|
class EndlessTokenBucket:
    """No-op counterpart of TokenBucket, used when download speed is not limited."""

    def advance(self, size):
        """Accept the size of a downloaded chunk and do nothing with it."""
        return None
|
|
|
|
|
|
|
|
|
|
|
|
# Type alias covering both bucket flavours; each exposes `advance(size)`.
AnyTokenBucket = Union[TokenBucket, EndlessTokenBucket]
|
|
|
|
|
|
|
|
|
|
|
|
async def download(
    client: httpx.AsyncClient,
    task_id: int,
    source: str,
    target: str,
    progress: Progress,
    token_bucket: AnyTokenBucket,
):
    """Stream `source` over HTTP into the file `target`.

    The body is written to `<target>.tmp` first and only moved to `target`
    once the whole response has been received, so a canceled or failed
    download never leaves a partial file under the final name.

    Args:
        client: HTTP client used to perform the streaming GET request.
        task_id: Identifier under which progress is reported.
        source: URL to download.
        target: Path of the resulting file.
        progress: Receives `start`, `advance` and `end` notifications.
        token_bucket: Rate limiter, advanced by the size of every chunk.

    Raises:
        httpx.HTTPStatusError: If the server answers with a 4xx/5xx status.
        httpx.RequestError: On network level failures.
    """
    tmp_target = f"{target}.tmp"
    with open(tmp_target, "wb") as f:
        async with client.stream("GET", source) as response:
            # Without this check an error page would be saved as if it were
            # the requested data.
            response.raise_for_status()

            # NOTE(review): assumes the server always sends Content-Length;
            # a missing header makes int(None) raise TypeError.
            total_size = int(response.headers.get("content-length"))
            progress.start(task_id, total_size)

            async for chunk in response.aiter_bytes(chunk_size=CHUNK_SIZE):
                f.write(chunk)
                chunk_size = len(chunk)
                token_bucket.advance(chunk_size)
                progress.advance(task_id, chunk_size)

            progress.end(task_id)

    # Unlike os.rename, os.replace overwrites an existing target on every
    # platform (os.rename fails on Windows when the destination exists).
    os.replace(tmp_target, target)
|
2022-08-13 09:41:13 +00:00
|
|
|
|
|
|
|
|
|
|
|
async def download_with_retries(
    client: httpx.AsyncClient,
    semaphore: asyncio.Semaphore,
    task_id: int,
    source: str,
    target: str,
    progress: Progress,
    token_bucket: AnyTokenBucket,
):
    """Download `source` to `target`, retrying on request errors.

    The semaphore is held for the whole duration of the call, which limits
    how many downloads run at the same time. If `target` already exists the
    download is skipped and its size is reported as already downloaded.

    At most RETRY_COUNT attempts are made. Only `httpx.RequestError` triggers
    another attempt; the error raised by the last attempt is propagated to
    the caller.
    """
    async with semaphore:
        if os.path.exists(target):
            size = os.path.getsize(target)
            progress.already_downloaded(task_id, size)
            return

        for n in range(RETRY_COUNT):
            try:
                return await download(client, task_id, source, target, progress, token_bucket)
            except httpx.RequestError:
                # Lazy logging args: the task id is actually interpolated
                # (the message used to contain a literal "{task_id}").
                logger.exception("Task %s failed. Retrying. Maybe.", task_id)
                progress.abort(task_id)
                if n + 1 >= RETRY_COUNT:
                    raise

        # Only reachable when RETRY_COUNT is not positive and the loop body
        # never runs.
        raise Exception("Should not happen")
|
|
|
|
|
|
|
|
|
|
|
|
async def download_all(
    sources: List[str],
    targets: List[str],
    workers: int,
    /, *,
    rate_limit: Optional[int] = None
):
    """Download every source to the target at the same position.

    Args:
        sources: URLs to download.
        targets: Destination paths, paired with `sources` by position.
        workers: Maximum number of downloads running concurrently.
        rate_limit: Passed to TokenBucket as its rate when truthy; otherwise
            the download speed is not limited.
    """
    progress = Progress(len(sources))

    if rate_limit:
        token_bucket = TokenBucket(rate_limit)
    else:
        token_bucket = EndlessTokenBucket()

    async with httpx.AsyncClient(timeout=TIMEOUT) as client:
        # A single semaphore shared by all tasks caps the concurrency.
        semaphore = asyncio.Semaphore(workers)

        pending = []
        for task_id, pair in enumerate(zip(sources, targets)):
            source, target = pair
            pending.append(
                download_with_retries(
                    client,
                    semaphore,
                    task_id,
                    source,
                    target,
                    progress,
                    token_bucket,
                )
            )

        await asyncio.gather(*pending)
|