Mirror of https://github.com/ihabunek/twitch-dl, synced 2024-08-30 18:32:25 +00:00
Use a queue+workers instead of semaphore
This commit is contained in:
parent 8c68132ddb
commit da51ffc31f
@@ -4,7 +4,7 @@ import os
 import time
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, NamedTuple, Optional, Tuple
 
 import httpx
 
@@ -95,55 +95,81 @@ async def download(
 
 async def download_with_retries(
     client: httpx.AsyncClient,
-    semaphore: asyncio.Semaphore,
     task_id: int,
     source: str,
     target: Path,
     progress: Progress,
     token_bucket: TokenBucket,
 ):
-    async with semaphore:
-        if target.exists():
-            size = os.path.getsize(target)
-            progress.already_downloaded(task_id, size)
-            return
+    if target.exists():
+        size = os.path.getsize(target)
+        progress.already_downloaded(task_id, size)
+        return
 
-        for n in range(RETRY_COUNT):
-            try:
-                return await download(client, task_id, source, target, progress, token_bucket)
-            except httpx.RequestError:
-                logger.exception("Task {task_id} failed. Retrying. Maybe.")
-                progress.abort(task_id)
-                if n + 1 >= RETRY_COUNT:
-                    raise
+    for n in range(RETRY_COUNT):
+        try:
+            return await download(client, task_id, source, target, progress, token_bucket)
+        except httpx.RequestError:
+            logger.exception("Task {task_id} failed. Retrying. Maybe.")
+            progress.abort(task_id)
+            if n + 1 >= RETRY_COUNT:
+                raise
 
-        raise Exception("Should not happen")
+    raise Exception("Should not happen")
 
 
+class QueueItem(NamedTuple):
+    task_id: int
+    url: str
+    target: Path
+
+
+async def download_worker(
+    queue: asyncio.Queue[QueueItem],
+    client: httpx.AsyncClient,
+    progress: Progress,
+    token_bucket: TokenBucket,
+):
+    while True:
+        item = await queue.get()
+        await download_with_retries(
+            client,
+            item.task_id,
+            item.url,
+            item.target,
+            progress,
+            token_bucket,
+        )
+        queue.task_done()
+
+
 async def download_all(
     source_targets: Iterable[Tuple[str, Path]],
-    workers: int,
+    worker_count: int,
     *,
     count: Optional[int] = None,
     rate_limit: Optional[int] = None,
 ):
     progress = Progress(count)
     token_bucket = LimitingTokenBucket(rate_limit) if rate_limit else EndlessTokenBucket()
+    queue: asyncio.Queue[QueueItem] = asyncio.Queue()
 
     async with httpx.AsyncClient(timeout=TIMEOUT) as client:
-        semaphore = asyncio.Semaphore(workers)
         tasks = [
-            download_with_retries(
-                client,
-                semaphore,
-                task_id,
-                source,
-                target,
-                progress,
-                token_bucket,
-            )
-            for task_id, (source, target) in enumerate(source_targets)
+            asyncio.create_task(download_worker(queue, client, progress, token_bucket))
+            for _ in range(worker_count)
         ]
-        await asyncio.gather(*tasks)
+
+        for index, (source, target) in enumerate(source_targets):
+            await queue.put(QueueItem(index, source, target))
+
+        # Wait for queue to deplete
+        await queue.join()
+
+        # Cancel tasks and wait until they are cancelled
+        for task in tasks:
+            task.cancel()
+        await asyncio.gather(*tasks, return_exceptions=True)
 
 
 def download_file(url: str, target: Path, retries: int = RETRY_COUNT) -> None:
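
In essence, the change replaces per-coroutine semaphore gating with a fixed pool of worker tasks fed from an asyncio.Queue: workers pull items off the queue, the producer awaits queue.join(), and the idle workers are then cancelled. Below is a minimal, self-contained sketch of that general pattern, not the project's actual code; Item, handle_item, and run_all are hypothetical names used only for illustration.

import asyncio
from typing import List, NamedTuple


class Item(NamedTuple):
    task_id: int
    url: str


async def handle_item(item: Item) -> None:
    # Placeholder for the real per-item work (e.g. an HTTP download).
    await asyncio.sleep(0.1)
    print(f"done: {item.task_id} {item.url}")


async def worker(queue: "asyncio.Queue[Item]") -> None:
    # Each worker loops forever; overall concurrency is bounded by the
    # number of workers, not by a semaphore.
    while True:
        item = await queue.get()
        try:
            await handle_item(item)
        finally:
            queue.task_done()


async def run_all(urls: List[str], worker_count: int) -> None:
    queue: "asyncio.Queue[Item]" = asyncio.Queue()

    # Start a fixed pool of workers instead of one task per item.
    workers = [asyncio.create_task(worker(queue)) for _ in range(worker_count)]

    # Enqueue all work items.
    for task_id, url in enumerate(urls):
        await queue.put(Item(task_id, url))

    # Wait until every queued item has been processed...
    await queue.join()

    # ...then cancel the now-idle workers and wait for them to exit.
    for w in workers:
        w.cancel()
    await asyncio.gather(*workers, return_exceptions=True)


if __name__ == "__main__":
    asyncio.run(run_all([f"https://example.com/{i}" for i in range(10)], worker_count=3))

Bounding concurrency by the number of worker tasks means only worker_count coroutines exist at any time, and queue.join() gives a natural point to wait for completion before cancelling the workers, which is what the diff above does inside download_all.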