# pyright: reportPrivateUsage=false
from contextlib import suppress

from invokeai.app.invocations.fields import ImageField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from tests.test_nodes import PromptTestInvocation


def test_invocation_cache_memory_max_cache_size():
    """A default-constructed cache has a max size of zero and therefore stores nothing."""
    zero_cache = MemoryInvocationCache()
    assert zero_cache._max_cache_size == 0

    zero_cache.save(1, ImageOutput(image=ImageField(image_name="foo"), width=512, height=512))
    assert zero_cache.get(1) is None

    # Neither counter moves for a zero-size cache.
    # TODO: when cache size is zero, should we consider it a miss?
    assert (zero_cache._hits, zero_cache._misses) == (0, 0)
    assert not zero_cache._cache


def test_invocation_cache_memory_creates_deterministic_keys():
    """Equal invocations produce equal keys; a different prompt produces a different key."""
    key_for = MemoryInvocationCache.create_key
    foo_first, foo_second, bar = [key_for(PromptTestInvocation(prompt=text)) for text in ("foo", "foo", "bar")]

    assert foo_first == foo_second
    assert foo_first != bar


def test_invocation_cache_memory_adds_invocation():
    """Outputs saved under a key can be read back with that same key."""
    expected = {
        1: ImageOutput(image=ImageField(image_name="foo"), width=512, height=512),
        2: ImageOutput(image=ImageField(image_name="bar"), width=512, height=512),
    }
    cache = MemoryInvocationCache(max_cache_size=5)

    for key, output in expected.items():
        cache.save(key, output)
    for key, output in expected.items():
        assert cache.get(key) == output


def test_invocation_cache_memory_tracks_hits():
    """Lookups of a stored key count as hits; lookups of an absent key count as misses."""
    cache = MemoryInvocationCache(max_cache_size=5)
    cache.save(1, ImageOutput(image=ImageField(image_name="foo"), width=512, height=512))

    for _ in range(3):
        cache.get(1)  # hit
    for _ in range(2):
        cache.get(2)  # miss

    assert cache._hits == 3
    assert cache._misses == 2


def test_invocation_cache_memory_is_lru():
    """When full, the cache evicts the least recently used entry, and reads refresh recency."""
    outputs = [
        ImageOutput(image=ImageField(image_name=name), width=512, height=512) for name in ("foo", "bar", "baz")
    ]
    cache = MemoryInvocationCache(max_cache_size=2)
    for key, output in enumerate(outputs, start=1):
        cache.save(key, output)

    # With room for two entries, saving key 3 pushed out key 1.
    assert cache.get(1) is None
    assert cache.get(2) == outputs[1]
    assert cache.get(3) == outputs[2]
    assert len(cache._cache) == 2
    assert list(cache._cache.keys()) == [2, 3]

    # Reading key 2 again moves it behind key 3 in the ordering.
    cache.get(2)
    assert list(cache._cache.keys()) == [3, 2]


def test_invocation_cache_memory_disables_and_enables():
    """While disabled the cache neither serves nor stores entries; re-enabling restores both."""
    foo = ImageOutput(image=ImageField(image_name="foo"), width=512, height=512)
    bar = ImageOutput(image=ImageField(image_name="bar"), width=512, height=512)
    cache = MemoryInvocationCache(max_cache_size=2)
    cache.save(1, foo)

    cache.disable()
    assert cache.get(1) is None
    cache.save(2, bar)
    assert cache.get(2) is None
    # Only the entry saved before disabling is present, and no stats were recorded.
    assert len(cache._cache) == 1
    assert (cache._hits, cache._misses) == (0, 0)

    cache.enable()
    cache.save(2, bar)
    assert cache.get(2) is bar
    assert len(cache._cache) == 2
    assert (cache._hits, cache._misses) == (1, 0)


def test_invocation_cache_memory_deletes_by_match():
    """``_delete_by_match`` removes matching entries and leaves the rest in their original order."""
    cache = MemoryInvocationCache(max_cache_size=5)
    output_1 = ImageOutput(image=ImageField(image_name="foo"), width=512, height=512)
    output_2 = ImageOutput(image=ImageField(image_name="bar"), width=512, height=512)
    output_3 = ImageOutput(image=ImageField(image_name="baz"), width=512, height=512)
    cache.save(1, output_1)
    cache.save(2, output_2)
    cache.save(3, output_3)

    def delete_by_match(to_match: str) -> None:
        # The _delete_by_match method attempts to log but the logger is not set up in the
        # test environment, so it may raise AttributeError. Suppress it around each call
        # only. Wrapping the whole test would silently skip every assertion after the
        # first raise and let the test pass vacuously.
        # NOTE(review): this assumes the log call happens after the deletions — confirm.
        with suppress(AttributeError):
            cache._delete_by_match(to_match)

    delete_by_match("bar")
    assert cache.get(1) == output_1
    assert cache.get(2) is None
    assert cache.get(3) == output_3
    assert len(cache._cache) == 2
    assert list(cache._cache.keys()) == [1, 3]

    delete_by_match("foo")
    assert cache.get(1) is None
    assert cache.get(2) is None
    assert cache.get(3) == output_3
    assert len(cache._cache) == 1
    assert list(cache._cache.keys()) == [3]

    delete_by_match("baz")
    assert cache.get(1) is None
    assert cache.get(2) is None
    assert cache.get(3) is None
    assert len(cache._cache) == 0
    assert list(cache._cache.keys()) == []

    # shouldn't raise (other than the logger issue above) on an empty cache
    delete_by_match("foo")


def test_invocation_cache_memory_clears():
    """``clear`` empties the cache and resets the hit/miss counters."""
    cache = MemoryInvocationCache(max_cache_size=5)
    output_1 = ImageOutput(image=ImageField(image_name="foo"), width=512, height=512)
    output_2 = ImageOutput(image=ImageField(image_name="bar"), width=512, height=512)
    output_3 = ImageOutput(image=ImageField(image_name="baz"), width=512, height=512)
    cache.save(1, output_1)
    cache.save(2, output_2)
    cache.save(3, output_3)
    # Populate both counters so the reset below is meaningful.
    cache.get(1)
    cache.get(2)
    cache.get(3)
    cache.get("foo")  # miss
    cache.get("bar")  # miss

    cache.clear()
    assert len(cache._cache) == 0
    assert cache._hits == 0
    assert cache._misses == 0
    assert cache.get(1) is None
    assert cache.get(2) is None
    assert cache.get(3) is None


def test_invocation_cache_memory_status():
    """``get_status`` reports hits, misses, size, max size and whether the cache is enabled."""
    cache = MemoryInvocationCache(max_cache_size=5)
    for key, name in enumerate(("foo", "bar", "baz"), start=1):
        cache.save(key, ImageOutput(image=ImageField(image_name=name), width=512, height=512))

    for key in (1, 2, 3):
        cache.get(key)  # hit
    for key in ("foo", "bar"):
        cache.get(key)  # miss

    status = cache.get_status()
    assert (status.hits, status.misses) == (3, 2)
    assert status.enabled
    assert status.size == 3
    assert status.max_size == 5

    # Toggling the enabled flag is reflected in the reported status.
    cache.disable()
    assert not cache.get_status().enabled
    cache.enable()
    assert cache.get_status().enabled

    # Clearing resets size and counters but keeps the configuration.
    cache.clear()
    status = cache.get_status()
    assert (status.size, status.hits, status.misses) == (0, 0, 0)
    assert status.enabled
    assert status.max_size == 5

    cache._max_cache_size = 0  # cache should be disabled when max_cache_size is zero
    status = cache.get_status()
    assert not status.enabled
    assert (status.size, status.hits, status.misses, status.max_size) == (0, 0, 0, 0)