import asyncio
import httpx
import logging
import os
import time

from abc import ABC, abstractmethod
from typing import List, Optional

from twitchdl.progress import Progress

logger = logging.getLogger(__name__)

KB = 1024  #: Number of bytes in a kibibyte.

#: Size of each piece read from an HTTP response body while streaming it to disk.
CHUNK_SIZE = 256 * KB

#: Total number of attempts made for a download before the last error is re-raised.
RETRY_COUNT = 5

#: Seconds to wait before aborting when there is no network activity.
#: Passed to httpx as the client timeout, see:
#: https://www.python-httpx.org/advanced/#timeout-configuration
TIMEOUT = 30


class TokenBucket(ABC):
    """Interface for throttling download throughput.

    Implementations are notified about every chunk of data received and may
    delay the caller to keep the overall transfer rate in check.
    """

    @abstractmethod
    def advance(self, size: int):
        """Account for ``size`` newly received bytes, possibly by waiting."""


class LimitingTokenBucket(TokenBucket):
    """Limit the download speed by strategically inserting sleeps.

    Token bucket: tokens (bytes) accrue at ``rate`` per second, up to
    ``capacity``. Consuming more than is currently available sleeps for the
    shortfall. After such a sleep ``available`` may be negative; that debt is
    repaid by the next refill, because the sleep time counts as elapsed time.
    """

    def __init__(self, rate: int, capacity: Optional[int] = None):
        """
        Args:
            rate: Sustained throughput in bytes per second; must be positive.
            capacity: Maximum burst size in bytes. Defaults to ``2 * rate``
                (a falsy value such as 0 also selects the default).

        Raises:
            ValueError: if ``rate`` is not positive.
        """
        if rate <= 0:
            raise ValueError(f"rate must be positive, got {rate!r}")
        self.rate: int = rate
        self.capacity: int = capacity or rate * 2
        self.available: int = 0
        # Monotonic clock: unaffected by wall-clock adjustments (NTP steps,
        # manual changes) that would make time.time() jump backwards and
        # produce a negative elapsed time in _refill().
        self.last_refilled: float = time.monotonic()

    def advance(self, size: int):
        """Called every time a chunk of data is downloaded.

        NOTE: this uses the blocking ``time.sleep``. When called from a
        coroutine it pauses the whole event loop, which throttles every
        concurrent download sharing that loop.
        """
        self._refill()

        if self.available < size:
            deficit = size - self.available
            time.sleep(deficit / self.rate)

        self.available -= size

    def _refill(self):
        """Increase available capacity according to elapsed time since last refill."""
        now = time.monotonic()
        elapsed = now - self.last_refilled
        refill_amount = int(elapsed * self.rate)
        self.available = min(self.available + refill_amount, self.capacity)
        self.last_refilled = now


class EndlessTokenBucket(TokenBucket):
    """No-op bucket, used when the download speed is not limited."""

    def advance(self, size: int):
        """Ignore ``size``: unlimited downloads never have to wait."""
        return None


async def download(
    client: httpx.AsyncClient,
    task_id: int,
    source: str,
    target: str,
    progress: Progress,
    token_bucket: TokenBucket,
):
    """Stream ``source`` into ``target``, reporting progress and honouring the rate limit.

    The body is written to ``{target}.tmp`` first and moved onto ``target`` only
    once the whole response has been received. This way an interrupted or
    cancelled download never leaves a partial file under the final name.

    Raises:
        httpx.HTTPStatusError: if the server responds with an error status,
            instead of silently saving the error body as if it were the file.
        httpx.RequestError: on network failures (left to the caller to retry).
    """
    tmp_target = f"{target}.tmp"
    async with client.stream("GET", source) as response:
        # Check the status before touching the disk. Otherwise an error page
        # (e.g. 403/404) would be saved under ``target``, and later runs would
        # treat it as an already completed download.
        response.raise_for_status()
        # NOTE(review): int(None) raises TypeError if the server omits
        # Content-Length (e.g. chunked transfer encoding).
        total_size = int(response.headers.get("content-length"))
        progress.start(task_id, total_size)
        with open(tmp_target, "wb") as f:
            async for chunk in response.aiter_bytes(chunk_size=CHUNK_SIZE):
                f.write(chunk)
                received = len(chunk)
                token_bucket.advance(received)
                progress.advance(task_id, received)
        progress.end(task_id)
    # os.replace overwrites an existing destination on every platform, whereas
    # os.rename raises on Windows when ``target`` already exists.
    os.replace(tmp_target, target)


async def download_with_retries(
    client: httpx.AsyncClient,
    semaphore: asyncio.Semaphore,
    task_id: int,
    source: str,
    target: str,
    progress: Progress,
    token_bucket: TokenBucket,
):
    """Download ``source`` to ``target``, retrying on network errors.

    Runs while holding ``semaphore`` so that the number of concurrent downloads
    stays bounded. If ``target`` already exists, the download is skipped and its
    size is reported to ``progress`` as already downloaded.

    Raises:
        httpx.RequestError: re-raised from the final attempt once all
            RETRY_COUNT attempts have failed.
    """
    async with semaphore:
        if os.path.exists(target):
            size = os.path.getsize(target)
            progress.already_downloaded(task_id, size)
            return

        for attempt in range(1, RETRY_COUNT + 1):
            try:
                return await download(client, task_id, source, target, progress, token_bucket)
            except httpx.RequestError:
                is_last_attempt = attempt >= RETRY_COUNT
                # Lazy %-args so the task id is actually interpolated.
                logger.exception(
                    "Task %s failed (attempt %d of %d). %s",
                    task_id,
                    attempt,
                    RETRY_COUNT,
                    "Giving up." if is_last_attempt else "Retrying.",
                )
                progress.abort(task_id)
                if is_last_attempt:
                    raise

        # Only reachable when RETRY_COUNT < 1, i.e. no attempt was made at all.
        raise Exception("Should not happen")


async def download_all(
    sources: List[str],
    targets: List[str],
    workers: int,
    *,
    rate_limit: Optional[int] = None
):
    """Download every ``sources[i]`` to ``targets[i]`` concurrently.

    Args:
        sources: URLs to fetch.
        targets: Local file paths, paired positionally with ``sources``.
        workers: Maximum number of downloads in flight at once.
        rate_limit: Optional overall speed limit in bytes per second, shared by
            all downloads; a falsy value means unlimited.

    Raises:
        ValueError: if ``sources`` and ``targets`` differ in length.
        Any exception raised by a failed download. The remaining downloads are
        cancelled and awaited before it propagates.
    """
    # zip() would silently drop the unmatched tail while Progress still
    # expects len(sources) tasks.
    if len(sources) != len(targets):
        raise ValueError(
            f"sources and targets must have the same length, "
            f"got {len(sources)} and {len(targets)}"
        )

    progress = Progress(len(sources))
    token_bucket = LimitingTokenBucket(rate_limit) if rate_limit else EndlessTokenBucket()
    async with httpx.AsyncClient(timeout=TIMEOUT) as client:
        semaphore = asyncio.Semaphore(workers)
        tasks = [
            asyncio.create_task(
                download_with_retries(client, semaphore, task_id, source, target, progress, token_bucket)
            )
            for task_id, (source, target) in enumerate(zip(sources, targets))
        ]
        try:
            await asyncio.gather(*tasks)
        except BaseException:
            # gather() does not cancel the other tasks when one of them fails.
            # Cancel them here so nothing keeps using ``client`` after it is
            # closed on exiting this block.
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise