Use a queue+workers instead of semaphore

Ivan Habunek 2024-08-30 15:35:12 +02:00
parent 8c68132ddb
commit da51ffc31f
No known key found for this signature in database
GPG Key ID: 01DB3DD0D824504C
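
For orientation: before this change, concurrency was bounded by a semaphore wrapped around each download coroutine, with every coroutine created up front and gathered. A minimal, self-contained sketch of that pattern (not code from this repository; `fetch`, the item strings, and the limit of 3 are illustrative):

import asyncio


async def fetch(semaphore: asyncio.Semaphore, item: str) -> None:
    # Only as many bodies run concurrently as the semaphore's initial value allows.
    async with semaphore:
        await asyncio.sleep(0.1)  # stand-in for the actual download


async def main() -> None:
    semaphore = asyncio.Semaphore(3)
    items = ["a", "b", "c", "d", "e"]
    # One coroutine per item is created immediately; the semaphore throttles them.
    await asyncio.gather(*(fetch(semaphore, item) for item in items))


asyncio.run(main())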

@@ -4,7 +4,7 @@ import os
 import time
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, NamedTuple, Optional, Tuple
 
 import httpx
@@ -95,55 +95,81 @@ async def download(
 async def download_with_retries(
     client: httpx.AsyncClient,
-    semaphore: asyncio.Semaphore,
     task_id: int,
     source: str,
     target: Path,
     progress: Progress,
     token_bucket: TokenBucket,
 ):
-    async with semaphore:
-        if target.exists():
-            size = os.path.getsize(target)
-            progress.already_downloaded(task_id, size)
-            return
+    if target.exists():
+        size = os.path.getsize(target)
+        progress.already_downloaded(task_id, size)
+        return
 
-        for n in range(RETRY_COUNT):
-            try:
-                return await download(client, task_id, source, target, progress, token_bucket)
-            except httpx.RequestError:
-                logger.exception("Task {task_id} failed. Retrying. Maybe.")
-                progress.abort(task_id)
-                if n + 1 >= RETRY_COUNT:
-                    raise
+    for n in range(RETRY_COUNT):
+        try:
+            return await download(client, task_id, source, target, progress, token_bucket)
+        except httpx.RequestError:
+            logger.exception("Task {task_id} failed. Retrying. Maybe.")
+            progress.abort(task_id)
+            if n + 1 >= RETRY_COUNT:
+                raise
 
-        raise Exception("Should not happen")
+    raise Exception("Should not happen")
+
+
+class QueueItem(NamedTuple):
+    task_id: int
+    url: str
+    target: Path
+
+
+async def download_worker(
+    queue: asyncio.Queue[QueueItem],
+    client: httpx.AsyncClient,
+    progress: Progress,
+    token_bucket: TokenBucket,
+):
+    while True:
+        item = await queue.get()
+        await download_with_retries(
+            client,
+            item.task_id,
+            item.url,
+            item.target,
+            progress,
+            token_bucket,
+        )
+        queue.task_done()
 
 
 async def download_all(
     source_targets: Iterable[Tuple[str, Path]],
-    workers: int,
+    worker_count: int,
     *,
     count: Optional[int] = None,
     rate_limit: Optional[int] = None,
 ):
     progress = Progress(count)
     token_bucket = LimitingTokenBucket(rate_limit) if rate_limit else EndlessTokenBucket()
+    queue: asyncio.Queue[QueueItem] = asyncio.Queue()
+
     async with httpx.AsyncClient(timeout=TIMEOUT) as client:
-        semaphore = asyncio.Semaphore(workers)
         tasks = [
-            download_with_retries(
-                client,
-                semaphore,
-                task_id,
-                source,
-                target,
-                progress,
-                token_bucket,
-            )
-            for task_id, (source, target) in enumerate(source_targets)
+            asyncio.create_task(download_worker(queue, client, progress, token_bucket))
+            for _ in range(worker_count)
         ]
-        await asyncio.gather(*tasks)
+
+        for index, (source, target) in enumerate(source_targets):
+            await queue.put(QueueItem(index, source, target))
+
+        # Wait for queue to deplete
+        await queue.join()
+
+        # Cancel tasks and wait until they are cancelled
+        for task in tasks:
+            task.cancel()
+
+        await asyncio.gather(*tasks, return_exceptions=True)
 
 
 def download_file(url: str, target: Path, retries: int = RETRY_COUNT) -> None:
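
The pattern adopted above replaces the semaphore with a fixed pool of long-lived worker tasks consuming from an asyncio.Queue, so concurrency is bounded by the worker count. A minimal, self-contained sketch of the same pattern outside the repository (`worker`, the three-task pool, and plain string items standing in for QueueItem are illustrative):

import asyncio


async def worker(queue: asyncio.Queue) -> None:
    # Workers loop forever; they exit only when cancelled after queue.join().
    while True:
        item = await queue.get()
        try:
            await asyncio.sleep(0.1)  # stand-in for the actual download
            print(f"done: {item}")
        finally:
            queue.task_done()


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    workers = [asyncio.create_task(worker(queue)) for _ in range(3)]

    for item in ["a", "b", "c", "d", "e"]:
        await queue.put(item)

    await queue.join()  # returns once every queued item has been marked task_done()
    for task in workers:
        task.cancel()
    await asyncio.gather(*workers, return_exceptions=True)


asyncio.run(main())

Because the workers never return on their own, the shutdown sequence mirrors download_all above: wait for queue.join(), cancel the workers, then gather them with return_exceptions=True to absorb the CancelledError.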