Use an abstract base class for TokenBucket

This commit is contained in:
Ivan Habunek 2024-03-23 07:56:50 +01:00
parent a9aefa871d
commit 8592c41a50
No known key found for this signature in database
GPG Key ID: F5F0623FF5EBCB3D

View File

@ -4,7 +4,8 @@ import logging
import os import os
import time import time
from typing import List, Optional, Union from abc import ABC, abstractmethod
from typing import List, Optional
from twitchdl.progress import Progress from twitchdl.progress import Progress
@ -25,7 +26,13 @@ https://www.python-httpx.org/advanced/#timeout-configuration
""" """
class TokenBucket: class TokenBucket(ABC):
@abstractmethod
def advance(self, size: int):
pass
class LimitingTokenBucket(TokenBucket):
"""Limit the download speed by strategically inserting sleeps.""" """Limit the download speed by strategically inserting sleeps."""
def __init__(self, rate: int, capacity: Optional[int] = None): def __init__(self, rate: int, capacity: Optional[int] = None):
@ -53,22 +60,19 @@ class TokenBucket:
self.last_refilled = now self.last_refilled = now
class EndlessTokenBucket: class EndlessTokenBucket(TokenBucket):
"""Used when download speed is not limited.""" """Used when download speed is not limited."""
def advance(self, size: int): def advance(self, size: int):
pass pass
AnyTokenBucket = Union[TokenBucket, EndlessTokenBucket]
async def download( async def download(
client: httpx.AsyncClient, client: httpx.AsyncClient,
task_id: int, task_id: int,
source: str, source: str,
target: str, target: str,
progress: Progress, progress: Progress,
token_bucket: AnyTokenBucket, token_bucket: TokenBucket,
): ):
# Download to a temp file first, then copy to target when over to avoid # Download to a temp file first, then copy to target when over to avoid
# getting saving chunks which may persist if canceled or --keep is used # getting saving chunks which may persist if canceled or --keep is used
@ -93,7 +97,7 @@ async def download_with_retries(
source: str, source: str,
target: str, target: str,
progress: Progress, progress: Progress,
token_bucket: AnyTokenBucket, token_bucket: TokenBucket,
): ):
async with semaphore: async with semaphore:
if os.path.exists(target): if os.path.exists(target):
@ -121,7 +125,7 @@ async def download_all(
rate_limit: Optional[int] = None rate_limit: Optional[int] = None
): ):
progress = Progress(len(sources)) progress = Progress(len(sources))
token_bucket = TokenBucket(rate_limit) if rate_limit else EndlessTokenBucket() token_bucket = LimitingTokenBucket(rate_limit) if rate_limit else EndlessTokenBucket()
async with httpx.AsyncClient(timeout=TIMEOUT) as client: async with httpx.AsyncClient(timeout=TIMEOUT) as client:
semaphore = asyncio.Semaphore(workers) semaphore = asyncio.Semaphore(workers)
tasks = [download_with_retries(client, semaphore, task_id, source, target, progress, token_bucket) tasks = [download_with_retries(client, semaphore, task_id, source, target, progress, token_bucket)