diff --git a/twitchdl/http.py b/twitchdl/http.py index 26e571e..515c837 100644 --- a/twitchdl/http.py +++ b/twitchdl/http.py @@ -4,7 +4,8 @@ import logging import os import time -from typing import List, Optional, Union +from abc import ABC, abstractmethod +from typing import List, Optional from twitchdl.progress import Progress @@ -25,7 +26,13 @@ https://www.python-httpx.org/advanced/#timeout-configuration """ -class TokenBucket: +class TokenBucket(ABC): + @abstractmethod + def advance(self, size: int): + pass + + +class LimitingTokenBucket(TokenBucket): """Limit the download speed by strategically inserting sleeps.""" def __init__(self, rate: int, capacity: Optional[int] = None): @@ -53,22 +60,19 @@ class TokenBucket: self.last_refilled = now -class EndlessTokenBucket: +class EndlessTokenBucket(TokenBucket): """Used when download speed is not limited.""" def advance(self, size: int): pass -AnyTokenBucket = Union[TokenBucket, EndlessTokenBucket] - - async def download( client: httpx.AsyncClient, task_id: int, source: str, target: str, progress: Progress, - token_bucket: AnyTokenBucket, + token_bucket: TokenBucket, ): # Download to a temp file first, then copy to target when over to avoid # getting saving chunks which may persist if canceled or --keep is used @@ -93,7 +97,7 @@ async def download_with_retries( source: str, target: str, progress: Progress, - token_bucket: AnyTokenBucket, + token_bucket: TokenBucket, ): async with semaphore: if os.path.exists(target): @@ -121,7 +125,7 @@ async def download_all( rate_limit: Optional[int] = None ): progress = Progress(len(sources)) - token_bucket = TokenBucket(rate_limit) if rate_limit else EndlessTokenBucket() + token_bucket = LimitingTokenBucket(rate_limit) if rate_limit else EndlessTokenBucket() async with httpx.AsyncClient(timeout=TIMEOUT) as client: semaphore = asyncio.Semaphore(workers) tasks = [download_with_retries(client, semaphore, task_id, source, target, progress, token_bucket)