Use a queue+workers instead of semaphore

This commit is contained in:
Ivan Habunek 2024-08-30 15:35:12 +02:00
parent 8c68132ddb
commit da51ffc31f
No known key found for this signature in database
GPG Key ID: 01DB3DD0D824504C

View File

@ -4,7 +4,7 @@ import os
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Iterable, Optional, Tuple from typing import Iterable, NamedTuple, Optional, Tuple
import httpx import httpx
@ -95,14 +95,12 @@ async def download(
async def download_with_retries( async def download_with_retries(
client: httpx.AsyncClient, client: httpx.AsyncClient,
semaphore: asyncio.Semaphore,
task_id: int, task_id: int,
source: str, source: str,
target: Path, target: Path,
progress: Progress, progress: Progress,
token_bucket: TokenBucket, token_bucket: TokenBucket,
): ):
async with semaphore:
if target.exists(): if target.exists():
size = os.path.getsize(target) size = os.path.getsize(target)
progress.already_downloaded(task_id, size) progress.already_downloaded(task_id, size)
@ -120,30 +118,58 @@ async def download_with_retries(
raise Exception("Should not happen") raise Exception("Should not happen")
class QueueItem(NamedTuple):
task_id: int
url: str
target: Path
async def download_worker(
queue: asyncio.Queue[QueueItem],
client: httpx.AsyncClient,
progress: Progress,
token_bucket: TokenBucket,
):
while True:
item = await queue.get()
await download_with_retries(
client,
item.task_id,
item.url,
item.target,
progress,
token_bucket,
)
queue.task_done()
async def download_all( async def download_all(
source_targets: Iterable[Tuple[str, Path]], source_targets: Iterable[Tuple[str, Path]],
workers: int, worker_count: int,
*, *,
count: Optional[int] = None, count: Optional[int] = None,
rate_limit: Optional[int] = None, rate_limit: Optional[int] = None,
): ):
progress = Progress(count) progress = Progress(count)
token_bucket = LimitingTokenBucket(rate_limit) if rate_limit else EndlessTokenBucket() token_bucket = LimitingTokenBucket(rate_limit) if rate_limit else EndlessTokenBucket()
queue: asyncio.Queue[QueueItem] = asyncio.Queue()
async with httpx.AsyncClient(timeout=TIMEOUT) as client: async with httpx.AsyncClient(timeout=TIMEOUT) as client:
semaphore = asyncio.Semaphore(workers)
tasks = [ tasks = [
download_with_retries( asyncio.create_task(download_worker(queue, client, progress, token_bucket))
client, for _ in range(worker_count)
semaphore,
task_id,
source,
target,
progress,
token_bucket,
)
for task_id, (source, target) in enumerate(source_targets)
] ]
await asyncio.gather(*tasks)
for index, (source, target) in enumerate(source_targets):
await queue.put(QueueItem(index, source, target))
# Wait for queue to deplete
await queue.join()
# Cancel tasks and wait until they are cancelled
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
def download_file(url: str, target: Path, retries: int = RETRY_COUNT) -> None: def download_file(url: str, target: Path, retries: int = RETRY_COUNT) -> None: