diff --git a/twitchdl/http.py b/twitchdl/http.py new file mode 100644 index 0000000..2452b6a --- /dev/null +++ b/twitchdl/http.py @@ -0,0 +1,119 @@ +import asyncio +import httpx +import logging +import time + +from typing import List, Optional, Union + +from twitchdl.progress import Progress + +logger = logging.getLogger(__name__) + +KB = 1024 + +CHUNK_SIZE = 256 * KB +"""How much of a VOD to download in each iteration""" + +RETRY_COUNT = 5 +"""Number of times to retry failed downloads before aborting.""" + +TIMEOUT = 30 +""" +Number of seconds to wait before aborting when there is no network activity. +https://www.python-httpx.org/advanced/#timeout-configuration +""" + + +class TokenBucket: + """Limit the download speed by strategically inserting sleeps.""" + + def __init__(self, rate: int, capacity: Optional[int] = None): + self.rate: int = rate + self.capacity: int = capacity or rate * 2 + self.available: int = 0 + self.last_refilled: float = time.time() + + def advance(self, size: int): + """Called every time a chunk of data is downloaded.""" + self._refill() + + if self.available < size: + deficit = size - self.available + time.sleep(deficit / self.rate) + + self.available -= size + + def _refill(self): + """Increase available capacity according to elapsed time since last refill.""" + now = time.time() + elapsed = now - self.last_refilled + refill_amount = int(elapsed * self.rate) + self.available = min(self.available + refill_amount, self.capacity) + self.last_refilled = now + + +class EndlessTokenBucket: + """Used when download speed is not limited.""" + def advance(self, size): + pass + + +AnyTokenBucket = Union[TokenBucket, EndlessTokenBucket] + + +async def download( + client: httpx.AsyncClient, + task_id: int, + source: str, + target: str, + progress: Progress, + token_bucket: AnyTokenBucket, +): + with open(target, "wb") as f: + async with client.stream("GET", source) as response: + size = int(response.headers.get("content-length")) + progress.start(task_id, size) + async for chunk in response.aiter_bytes(chunk_size=CHUNK_SIZE): + f.write(chunk) + size = len(chunk) + token_bucket.advance(size) + progress.advance(task_id, size) + progress.end(task_id) + + +async def download_with_retries( + client: httpx.AsyncClient, + semaphore: asyncio.Semaphore, + task_id: int, + source: str, + target: str, + progress: Progress, + token_bucket: AnyTokenBucket, +): + async with semaphore: + for n in range(RETRY_COUNT): + try: + return await download(client, task_id, source, target, progress, token_bucket) + except httpx.RequestError: + logger.exception("Task {task_id} failed. Retrying. Maybe.") + progress.abort(task_id) + if n + 1 >= RETRY_COUNT: + raise + + raise Exception("Should not happen") + + +async def download_all( + sources: List[str], + targets: List[str], + workers: int, + /, *, + rate_limit: Optional[int] = None +): + progress = Progress(len(sources)) + token_bucket = TokenBucket(rate_limit) if rate_limit else EndlessTokenBucket() + async with httpx.AsyncClient(timeout=TIMEOUT) as client: + semaphore = asyncio.Semaphore(workers) + tasks = [download_with_retries(client, semaphore, task_id, source, target, progress, token_bucket) + for task_id, (source, target) in enumerate(zip(sources, targets))] + await asyncio.gather(*tasks)