import gc
from typing import Optional

import psutil
import torch

from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2

GB = 1 << 30  # Bytes per "GB" used for display below; this is 2**30 (a gibibyte), not 10**9.


class MemorySnapshot:
    """Point-in-time record of process RAM, torch CUDA memory, and (when available) libc malloc statistics.

    All numeric values are in bytes.
    """

    def __init__(self, process_ram: int, vram: Optional[int], malloc_info: Optional[Struct_mallinfo2]):
        """Store the three memory measurements as attributes.

        Callers normally obtain instances via `MemorySnapshot.capture()` rather than constructing them directly.

        Args:
            process_ram (int): Resident memory of the current process (psutil RSS when produced by `capture()`).
            vram (Optional[int]): Bytes allocated by torch on CUDA, or None when not measured.
            malloc_info (Optional[Struct_mallinfo2]): Result of `LibcUtil().mallinfo2()`, or None when unavailable.
        """
        self.process_ram = process_ram
        self.vram = vram
        self.malloc_info = malloc_info

    @classmethod
    def capture(cls, run_garbage_collector: bool = True):
        """Measure the current memory usage and return it as a new snapshot.

        Note: this is comparatively expensive, especially when `run_garbage_collector` is True.

        Args:
            run_garbage_collector (bool, optional): When True (the default), `gc.collect()` is called before the
                process RSS is read.

        Returns:
            MemorySnapshot: A snapshot holding the RSS, the CUDA allocation (or None), and the mallinfo2 result
                (or None).
        """
        if run_garbage_collector:
            gc.collect()

        # psutil documents `rss` as supported on every platform:
        # https://psutil.readthedocs.io/en/latest/#psutil.Process.memory_info
        rss_bytes = psutil.Process().memory_info().rss

        # Only CUDA is measured for now.
        # TODO: torch.mps.current_allocated_memory() could be supported too once it has been tested properly.
        cuda_bytes = torch.cuda.memory_allocated() if torch.cuda.is_available() else None

        try:
            libc_stats = LibcUtil().mallinfo2()
        except (OSError, AttributeError):
            # OSError: expected where the 'libc.so.6' shared library is not present.
            # AttributeError: expected where `libc.so.6` exists but lacks `mallinfo2` (e.g. glibc < 2.33).
            # TODO: Does `mallinfo` work?
            libc_stats = None

        return cls(rss_bytes, cuda_bytes, libc_stats)


def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: Optional[MemorySnapshot]) -> str:
    """Render a human-readable report of how memory usage changed between two `MemorySnapshot`s.

    Each metric becomes one newline-terminated row:
    `<label left-aligned to 30 chars> (<signed delta>): <before>GB -> <after>GB`, with byte counts divided by `GB`.

    Args:
        snapshot_1: The "before" snapshot. If None, an empty string is returned.
        snapshot_2: The "after" snapshot. If None, an empty string is returned.

    Returns:
        The concatenated rows. The libc rows are emitted only when both snapshots carry `malloc_info`, and the VRAM
        row only when both carry `vram`.
    """

    def format_row(label: str, before: int, after: int) -> str:
        delta = after - before
        return f"{label: <30} ({(delta/GB):+5.3f}): {(before/GB):5.3f}GB -> {(after/GB):5.3f}GB\n"

    if snapshot_1 is None or snapshot_2 is None:
        return ""

    rows = [format_row("Process RAM", snapshot_1.process_ram, snapshot_2.process_ram)]

    info_1 = snapshot_1.malloc_info
    info_2 = snapshot_2.malloc_info
    if info_1 is not None and info_2 is not None:
        rows.append(format_row("libc mmap allocated", info_1.hblkhd, info_2.hblkhd))
        rows.append(format_row("libc arena used", info_1.uordblks, info_2.uordblks))
        rows.append(format_row("libc arena free", info_1.fordblks, info_2.fordblks))
        # "Total allocated" sums the `arena` field and the mmap'd bytes (`hblkhd`).
        rows.append(format_row("libc total allocated", info_1.arena + info_1.hblkhd, info_2.arena + info_2.hblkhd))
        # "Total used" sums the in-use arena bytes (`uordblks`) and the mmap'd bytes (`hblkhd`).
        rows.append(format_row("libc total used", info_1.uordblks + info_1.hblkhd, info_2.uordblks + info_2.hblkhd))

    if snapshot_1.vram is not None and snapshot_2.vram is not None:
        rows.append(format_row("VRAM", snapshot_1.vram, snapshot_2.vram))

    return "".join(rows)