mirror of
https://github.com/ihabunek/twitch-dl
synced 2024-08-30 18:32:25 +00:00
Use an abstract base class for TokenBucket
This commit is contained in:
parent
a9aefa871d
commit
8592c41a50
@ -4,7 +4,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from typing import List, Optional, Union
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
from twitchdl.progress import Progress
|
from twitchdl.progress import Progress
|
||||||
|
|
||||||
@ -25,7 +26,13 @@ https://www.python-httpx.org/advanced/#timeout-configuration
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class TokenBucket:
|
class TokenBucket(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def advance(self, size: int):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class LimitingTokenBucket(TokenBucket):
|
||||||
"""Limit the download speed by strategically inserting sleeps."""
|
"""Limit the download speed by strategically inserting sleeps."""
|
||||||
|
|
||||||
def __init__(self, rate: int, capacity: Optional[int] = None):
|
def __init__(self, rate: int, capacity: Optional[int] = None):
|
||||||
@ -53,22 +60,19 @@ class TokenBucket:
|
|||||||
self.last_refilled = now
|
self.last_refilled = now
|
||||||
|
|
||||||
|
|
||||||
class EndlessTokenBucket:
|
class EndlessTokenBucket(TokenBucket):
|
||||||
"""Used when download speed is not limited."""
|
"""Used when download speed is not limited."""
|
||||||
def advance(self, size: int):
|
def advance(self, size: int):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
AnyTokenBucket = Union[TokenBucket, EndlessTokenBucket]
|
|
||||||
|
|
||||||
|
|
||||||
async def download(
|
async def download(
|
||||||
client: httpx.AsyncClient,
|
client: httpx.AsyncClient,
|
||||||
task_id: int,
|
task_id: int,
|
||||||
source: str,
|
source: str,
|
||||||
target: str,
|
target: str,
|
||||||
progress: Progress,
|
progress: Progress,
|
||||||
token_bucket: AnyTokenBucket,
|
token_bucket: TokenBucket,
|
||||||
):
|
):
|
||||||
# Download to a temp file first, then copy to target when over to avoid
|
# Download to a temp file first, then copy to target when over to avoid
|
||||||
# getting saving chunks which may persist if canceled or --keep is used
|
# getting saving chunks which may persist if canceled or --keep is used
|
||||||
@ -93,7 +97,7 @@ async def download_with_retries(
|
|||||||
source: str,
|
source: str,
|
||||||
target: str,
|
target: str,
|
||||||
progress: Progress,
|
progress: Progress,
|
||||||
token_bucket: AnyTokenBucket,
|
token_bucket: TokenBucket,
|
||||||
):
|
):
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
if os.path.exists(target):
|
if os.path.exists(target):
|
||||||
@ -121,7 +125,7 @@ async def download_all(
|
|||||||
rate_limit: Optional[int] = None
|
rate_limit: Optional[int] = None
|
||||||
):
|
):
|
||||||
progress = Progress(len(sources))
|
progress = Progress(len(sources))
|
||||||
token_bucket = TokenBucket(rate_limit) if rate_limit else EndlessTokenBucket()
|
token_bucket = LimitingTokenBucket(rate_limit) if rate_limit else EndlessTokenBucket()
|
||||||
async with httpx.AsyncClient(timeout=TIMEOUT) as client:
|
async with httpx.AsyncClient(timeout=TIMEOUT) as client:
|
||||||
semaphore = asyncio.Semaphore(workers)
|
semaphore = asyncio.Semaphore(workers)
|
||||||
tasks = [download_with_retries(client, semaphore, task_id, source, target, progress, token_bucket)
|
tasks = [download_with_retries(client, semaphore, task_id, source, target, progress, token_bucket)
|
||||||
|
Loading…
Reference in New Issue
Block a user