mirror of
https://github.com/ihabunek/twitch-dl
synced 2024-08-30 18:32:25 +00:00
Use a queue+workers instead of semaphore
This commit is contained in:
parent
8c68132ddb
commit
da51ffc31f
@ -4,7 +4,7 @@ import os
|
|||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterable, Optional, Tuple
|
from typing import Iterable, NamedTuple, Optional, Tuple
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
@ -95,14 +95,12 @@ async def download(
|
|||||||
|
|
||||||
async def download_with_retries(
|
async def download_with_retries(
|
||||||
client: httpx.AsyncClient,
|
client: httpx.AsyncClient,
|
||||||
semaphore: asyncio.Semaphore,
|
|
||||||
task_id: int,
|
task_id: int,
|
||||||
source: str,
|
source: str,
|
||||||
target: Path,
|
target: Path,
|
||||||
progress: Progress,
|
progress: Progress,
|
||||||
token_bucket: TokenBucket,
|
token_bucket: TokenBucket,
|
||||||
):
|
):
|
||||||
async with semaphore:
|
|
||||||
if target.exists():
|
if target.exists():
|
||||||
size = os.path.getsize(target)
|
size = os.path.getsize(target)
|
||||||
progress.already_downloaded(task_id, size)
|
progress.already_downloaded(task_id, size)
|
||||||
@ -120,30 +118,58 @@ async def download_with_retries(
|
|||||||
raise Exception("Should not happen")
|
raise Exception("Should not happen")
|
||||||
|
|
||||||
|
|
||||||
|
class QueueItem(NamedTuple):
|
||||||
|
task_id: int
|
||||||
|
url: str
|
||||||
|
target: Path
|
||||||
|
|
||||||
|
|
||||||
|
async def download_worker(
|
||||||
|
queue: asyncio.Queue[QueueItem],
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
progress: Progress,
|
||||||
|
token_bucket: TokenBucket,
|
||||||
|
):
|
||||||
|
while True:
|
||||||
|
item = await queue.get()
|
||||||
|
await download_with_retries(
|
||||||
|
client,
|
||||||
|
item.task_id,
|
||||||
|
item.url,
|
||||||
|
item.target,
|
||||||
|
progress,
|
||||||
|
token_bucket,
|
||||||
|
)
|
||||||
|
queue.task_done()
|
||||||
|
|
||||||
|
|
||||||
async def download_all(
|
async def download_all(
|
||||||
source_targets: Iterable[Tuple[str, Path]],
|
source_targets: Iterable[Tuple[str, Path]],
|
||||||
workers: int,
|
worker_count: int,
|
||||||
*,
|
*,
|
||||||
count: Optional[int] = None,
|
count: Optional[int] = None,
|
||||||
rate_limit: Optional[int] = None,
|
rate_limit: Optional[int] = None,
|
||||||
):
|
):
|
||||||
progress = Progress(count)
|
progress = Progress(count)
|
||||||
token_bucket = LimitingTokenBucket(rate_limit) if rate_limit else EndlessTokenBucket()
|
token_bucket = LimitingTokenBucket(rate_limit) if rate_limit else EndlessTokenBucket()
|
||||||
|
queue: asyncio.Queue[QueueItem] = asyncio.Queue()
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=TIMEOUT) as client:
|
async with httpx.AsyncClient(timeout=TIMEOUT) as client:
|
||||||
semaphore = asyncio.Semaphore(workers)
|
|
||||||
tasks = [
|
tasks = [
|
||||||
download_with_retries(
|
asyncio.create_task(download_worker(queue, client, progress, token_bucket))
|
||||||
client,
|
for _ in range(worker_count)
|
||||||
semaphore,
|
|
||||||
task_id,
|
|
||||||
source,
|
|
||||||
target,
|
|
||||||
progress,
|
|
||||||
token_bucket,
|
|
||||||
)
|
|
||||||
for task_id, (source, target) in enumerate(source_targets)
|
|
||||||
]
|
]
|
||||||
await asyncio.gather(*tasks)
|
|
||||||
|
for index, (source, target) in enumerate(source_targets):
|
||||||
|
await queue.put(QueueItem(index, source, target))
|
||||||
|
|
||||||
|
# Wait for queue to deplete
|
||||||
|
await queue.join()
|
||||||
|
|
||||||
|
# Cancel tasks and wait until they are cancelled
|
||||||
|
for task in tasks:
|
||||||
|
task.cancel()
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
|
||||||
def download_file(url: str, target: Path, retries: int = RETRY_COUNT) -> None:
|
def download_file(url: str, target: Path, retries: int = RETRY_COUNT) -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user