Use an abstract base class for TokenBucket

This commit is contained in:
Ivan Habunek 2024-03-23 07:56:50 +01:00
parent a9aefa871d
commit 8592c41a50
No known key found for this signature in database
GPG Key ID: F5F0623FF5EBCB3D

View File

@ -4,7 +4,8 @@ import logging
import os
import time
from typing import List, Optional, Union
from abc import ABC, abstractmethod
from typing import List, Optional
from twitchdl.progress import Progress
@ -25,7 +26,13 @@ https://www.python-httpx.org/advanced/#timeout-configuration
"""
class TokenBucket:
class TokenBucket(ABC):
@abstractmethod
def advance(self, size: int):
pass
class LimitingTokenBucket(TokenBucket):
"""Limit the download speed by strategically inserting sleeps."""
def __init__(self, rate: int, capacity: Optional[int] = None):
@ -53,22 +60,19 @@ class TokenBucket:
self.last_refilled = now
class EndlessTokenBucket:
class EndlessTokenBucket(TokenBucket):
"""Used when download speed is not limited."""
def advance(self, size: int):
pass
AnyTokenBucket = Union[TokenBucket, EndlessTokenBucket]
async def download(
client: httpx.AsyncClient,
task_id: int,
source: str,
target: str,
progress: Progress,
token_bucket: AnyTokenBucket,
token_bucket: TokenBucket,
):
# Download to a temp file first, then copy to target when over to avoid
# getting saving chunks which may persist if canceled or --keep is used
@ -93,7 +97,7 @@ async def download_with_retries(
source: str,
target: str,
progress: Progress,
token_bucket: AnyTokenBucket,
token_bucket: TokenBucket,
):
async with semaphore:
if os.path.exists(target):
@ -121,7 +125,7 @@ async def download_all(
rate_limit: Optional[int] = None
):
progress = Progress(len(sources))
token_bucket = TokenBucket(rate_limit) if rate_limit else EndlessTokenBucket()
token_bucket = LimitingTokenBucket(rate_limit) if rate_limit else EndlessTokenBucket()
async with httpx.AsyncClient(timeout=TIMEOUT) as client:
semaphore = asyncio.Semaphore(workers)
tasks = [download_with_retries(client, semaphore, task_id, source, target, progress, token_bucket)