Add new asyncio downloader code with rate limiting

This commit is contained in:
Ivan Habunek 2022-08-13 11:41:13 +02:00
parent 85631c8ce5
commit 23f1a74aa6
No known key found for this signature in database
GPG Key ID: CDBD63C43A30BB95

119
twitchdl/http.py Normal file
View File

@ -0,0 +1,119 @@
import asyncio
import httpx
import logging
import time
from typing import List, Optional, Union
from twitchdl.progress import Progress
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Bytes per kibibyte; used to express CHUNK_SIZE.
KB = 1024
# Passed as `chunk_size=` to `response.aiter_bytes()` in download().
CHUNK_SIZE = 256 * KB
"""How much of a VOD to download in each iteration"""
# NOTE: download_with_retries() loops over range(RETRY_COUNT), so this is the
# total number of attempts (the first try included), not retries on top of it.
RETRY_COUNT = 5
"""Number of times to retry failed downloads before aborting."""
# Passed as `timeout=` to httpx.AsyncClient in download_all().
TIMEOUT = 30
"""
Number of seconds to wait before aborting when there is no network activity.
https://www.python-httpx.org/advanced/#timeout-configuration
"""
class TokenBucket:
    """Limit the download speed by strategically inserting sleeps.

    The bucket holds up to ``capacity`` tokens (one token per byte passed to
    ``advance``) and is refilled at ``rate`` tokens per second. When a chunk
    is larger than the tokens currently available, ``advance`` sleeps long
    enough for the missing tokens to accrue.

    Elapsed time is measured with ``time.monotonic()`` so that wall-clock
    adjustments (NTP, manual changes) cannot produce a negative or inflated
    refill.
    """

    def __init__(self, rate: int, capacity: Optional[int] = None):
        """Create an empty bucket.

        Args:
            rate: Tokens added per second.
            capacity: Maximum number of stored tokens. A falsy value
                (``None`` or ``0``) selects the default of ``rate * 2``.
        """
        self.rate: int = rate
        self.capacity: int = capacity or rate * 2
        # The bucket starts empty, so the first chunk already pays a delay.
        self.available: int = 0
        # Timestamp on the monotonic clock, not seconds since the epoch.
        self.last_refilled: float = time.monotonic()

    def advance(self, size: int):
        """Called every time a chunk of data is downloaded.

        Consumes ``size`` tokens, sleeping first when not enough are
        available.
        """
        self._refill()

        if self.available < size:
            deficit = size - self.available
            # NOTE: time.sleep() blocks the calling thread; when this is called
            # from a coroutine, nothing else on that event loop runs meanwhile.
            time.sleep(deficit / self.rate)

        # May go negative after a sleep; the next refill covers the time slept.
        self.available -= size

    def _refill(self):
        """Increase available capacity according to elapsed time since last refill."""
        now = time.monotonic()
        elapsed = now - self.last_refilled
        refill_amount = int(elapsed * self.rate)
        self.available = min(self.available + refill_amount, self.capacity)
        self.last_refilled = now
class EndlessTokenBucket:
    """Bucket that never throttles; used when download speed is not limited."""

    def advance(self, size):
        """Accept a chunk size and do nothing with it."""
        return None
# Either bucket implementation; both provide `advance(size)`, the only method
# download() calls on it.
AnyTokenBucket = Union[TokenBucket, EndlessTokenBucket]
async def download(
    client: httpx.AsyncClient,
    task_id: int,
    source: str,
    target: str,
    progress: Progress,
    token_bucket: AnyTokenBucket,
):
    """Stream the response for ``source`` into the file at ``target``.

    Args:
        client: Shared HTTP client used to issue the GET request.
        task_id: Identifier passed to every ``progress`` call for this download.
        source: URL to fetch.
        target: Path of the file to write; opened in ``"wb"`` mode, so any
            existing content is truncated.
        progress: Receives ``start`` (with the Content-Length value),
            ``advance`` (with each chunk's length) and ``end``.
        token_bucket: ``advance`` is called with each chunk's length after
            the chunk is written.

    Raises:
        httpx.HTTPStatusError: If the server answers with a 4xx/5xx status.
        httpx.RequestError: On transport-level failures raised by httpx.
    """
    with open(target, "wb") as f:
        async with client.stream("GET", source) as response:
            # Without this, the body of an error response (404, 500, ...) would
            # be written to `target` and reported as a successful download.
            response.raise_for_status()

            # NOTE(review): a response without a Content-Length header makes
            # this raise TypeError (int(None)). Confirm whether Progress can
            # accept an unknown size before relaxing it.
            total_size = int(response.headers.get("content-length"))
            progress.start(task_id, total_size)

            async for chunk in response.aiter_bytes(chunk_size=CHUNK_SIZE):
                f.write(chunk)
                chunk_len = len(chunk)
                token_bucket.advance(chunk_len)
                progress.advance(task_id, chunk_len)

            progress.end(task_id)
async def download_with_retries(
    client: httpx.AsyncClient,
    semaphore: asyncio.Semaphore,
    task_id: int,
    source: str,
    target: str,
    progress: Progress,
    token_bucket: AnyTokenBucket,
):
    """Run ``download`` under ``semaphore``, retrying on ``httpx.RequestError``.

    The semaphore is held for the whole sequence of attempts. Up to
    ``RETRY_COUNT`` attempts are made in total; after each failed attempt the
    error is logged and ``progress.abort(task_id)`` is called.

    Returns:
        Whatever ``download`` returns on the first successful attempt.

    Raises:
        httpx.RequestError: Re-raised when the final attempt fails.
        Exception: If the loop finishes without returning or raising, which
            can only happen when ``RETRY_COUNT`` is not positive.
    """
    async with semaphore:
        for n in range(RETRY_COUNT):
            try:
                return await download(client, task_id, source, target, progress, token_bucket)
            except httpx.RequestError:
                # Lazy %-style args: the original literal lacked the `f` prefix,
                # so "{task_id}" was logged verbatim instead of the task number.
                logger.exception("Task %s failed. Retrying. Maybe.", task_id)
                progress.abort(task_id)
                # Out of attempts: let the last error propagate to the caller.
                if n + 1 >= RETRY_COUNT:
                    raise

        raise Exception("Should not happen")
async def download_all(
    sources: List[str],
    targets: List[str],
    workers: int,
    /, *,
    rate_limit: Optional[int] = None
):
    """Download every source to its paired target with bounded concurrency.

    Args:
        sources: URLs to fetch.
        targets: File paths to write; paired with ``sources`` by position
            (pairing stops at the shorter of the two lists).
        workers: Maximum number of downloads running at the same time.
        rate_limit: Rate handed to ``TokenBucket``; since the bucket is
            advanced by chunk byte counts, this is bytes per second. A falsy
            value (``None`` or ``0``) disables limiting.

    Raises:
        Whatever the first failing download raises. Remaining downloads are
        cancelled before the exception propagates.
    """
    progress = Progress(len(sources))
    token_bucket = TokenBucket(rate_limit) if rate_limit else EndlessTokenBucket()

    async with httpx.AsyncClient(timeout=TIMEOUT) as client:
        semaphore = asyncio.Semaphore(workers)
        tasks = [
            asyncio.create_task(
                download_with_retries(
                    client, semaphore, task_id, source, target, progress, token_bucket
                )
            )
            for task_id, (source, target) in enumerate(zip(sources, targets))
        ]

        try:
            await asyncio.gather(*tasks)
        except BaseException:
            # gather() propagates the first error but leaves the other tasks
            # running. Stop them and wait for them to finish here, so none is
            # still using `client` when the `async with` block closes it.
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise