Refactor download_all to work when count is unknown

This commit is contained in:
Ivan Habunek 2024-08-30 07:37:50 +02:00
parent 5be97b9e13
commit f57612ffcc
No known key found for this signature in database
GPG Key ID: 01DB3DD0D824504C
2 changed files with 14 additions and 6 deletions

View File

@ -233,7 +233,15 @@ def _download_video(video_id: str, args: DownloadOptions) -> None:
sources = [base_uri + vod.path for vod in vods]
targets = [target_dir / f"{vod.index:05d}.ts" for vod in vods]
asyncio.run(download_all(sources, targets, args.max_workers, rate_limit=args.rate_limit))
asyncio.run(
download_all(
zip(sources, targets),
args.max_workers,
rate_limit=args.rate_limit,
count=len(vods),
)
)
join_playlist = make_join_playlist(vods_m3u8, vods, targets)
join_playlist_path = target_dir / "playlist_downloaded.m3u8"

View File

@ -4,7 +4,7 @@ import os
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional
from typing import Iterable, Optional, Tuple
import httpx
@ -121,13 +121,13 @@ async def download_with_retries(
async def download_all(
sources: List[str],
targets: List[Path],
source_targets: Iterable[Tuple[str, Path]],
workers: int,
*,
count: Optional[int] = None,
rate_limit: Optional[int] = None,
):
progress = Progress(len(sources))
progress = Progress(count)
token_bucket = LimitingTokenBucket(rate_limit) if rate_limit else EndlessTokenBucket()
async with httpx.AsyncClient(timeout=TIMEOUT) as client:
semaphore = asyncio.Semaphore(workers)
@ -141,7 +141,7 @@ async def download_all(
progress,
token_bucket,
)
for task_id, (source, target) in enumerate(zip(sources, targets))
for task_id, (source, target) in enumerate(source_targets)
]
await asyncio.gather(*tasks)