mirror of
https://github.com/ihabunek/twitch-dl
synced 2024-08-30 18:32:25 +00:00
Use an abstract base class for TokenBucket
This commit is contained in:
parent
a9aefa871d
commit
8592c41a50
@ -4,7 +4,8 @@ import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
from typing import List, Optional, Union
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
from twitchdl.progress import Progress
|
||||
|
||||
@ -25,7 +26,13 @@ https://www.python-httpx.org/advanced/#timeout-configuration
|
||||
"""
|
||||
|
||||
|
||||
class TokenBucket:
|
||||
class TokenBucket(ABC):
|
||||
@abstractmethod
|
||||
def advance(self, size: int):
|
||||
pass
|
||||
|
||||
|
||||
class LimitingTokenBucket(TokenBucket):
|
||||
"""Limit the download speed by strategically inserting sleeps."""
|
||||
|
||||
def __init__(self, rate: int, capacity: Optional[int] = None):
|
||||
@ -53,22 +60,19 @@ class TokenBucket:
|
||||
self.last_refilled = now
|
||||
|
||||
|
||||
class EndlessTokenBucket:
|
||||
class EndlessTokenBucket(TokenBucket):
|
||||
"""Used when download speed is not limited."""
|
||||
def advance(self, size: int):
|
||||
pass
|
||||
|
||||
|
||||
AnyTokenBucket = Union[TokenBucket, EndlessTokenBucket]
|
||||
|
||||
|
||||
async def download(
|
||||
client: httpx.AsyncClient,
|
||||
task_id: int,
|
||||
source: str,
|
||||
target: str,
|
||||
progress: Progress,
|
||||
token_bucket: AnyTokenBucket,
|
||||
token_bucket: TokenBucket,
|
||||
):
|
||||
# Download to a temp file first, then copy to target when over to avoid
|
||||
# getting saving chunks which may persist if canceled or --keep is used
|
||||
@ -93,7 +97,7 @@ async def download_with_retries(
|
||||
source: str,
|
||||
target: str,
|
||||
progress: Progress,
|
||||
token_bucket: AnyTokenBucket,
|
||||
token_bucket: TokenBucket,
|
||||
):
|
||||
async with semaphore:
|
||||
if os.path.exists(target):
|
||||
@ -121,7 +125,7 @@ async def download_all(
|
||||
rate_limit: Optional[int] = None
|
||||
):
|
||||
progress = Progress(len(sources))
|
||||
token_bucket = TokenBucket(rate_limit) if rate_limit else EndlessTokenBucket()
|
||||
token_bucket = LimitingTokenBucket(rate_limit) if rate_limit else EndlessTokenBucket()
|
||||
async with httpx.AsyncClient(timeout=TIMEOUT) as client:
|
||||
semaphore = asyncio.Semaphore(workers)
|
||||
tasks = [download_with_retries(client, semaphore, task_id, source, target, progress, token_bucket)
|
||||
|
Loading…
Reference in New Issue
Block a user