diff --git a/tests/test_invocation_cache_memory.py b/tests/test_invocation_cache_memory.py new file mode 100644 index 0000000000..f776b86ee2 --- /dev/null +++ b/tests/test_invocation_cache_memory.py @@ -0,0 +1,183 @@ +# pyright: reportPrivateUsage=false +from contextlib import suppress + +from invokeai.app.invocations.fields import ImageField +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache +from tests.test_nodes import PromptTestInvocation + + +def test_invocation_cache_memory_max_cache_size(): + cache = MemoryInvocationCache() + assert cache._max_cache_size == 0 + output_1 = ImageOutput(image=ImageField(image_name="foo"), width=512, height=512) + cache.save(1, output_1) + assert cache.get(1) is None + assert cache._hits == 0 + assert cache._misses == 0 # TODO: when cache size is zero, should we consider it a miss? + assert len(cache._cache) == 0 + + +def test_invocation_cache_memory_creates_deterministic_keys(): + hash1 = MemoryInvocationCache.create_key(PromptTestInvocation(prompt="foo")) + hash2 = MemoryInvocationCache.create_key(PromptTestInvocation(prompt="foo")) + hash3 = MemoryInvocationCache.create_key(PromptTestInvocation(prompt="bar")) + + assert hash1 == hash2 + assert hash1 != hash3 + + +def test_invocation_cache_memory_adds_invocation(): + output_1 = ImageOutput(image=ImageField(image_name="foo"), width=512, height=512) + output_2 = ImageOutput(image=ImageField(image_name="bar"), width=512, height=512) + cache = MemoryInvocationCache(max_cache_size=5) + cache.save(1, output_1) + cache.save(2, output_2) + assert cache.get(1) == output_1 + assert cache.get(2) == output_2 + + +def test_invocation_cache_memory_tracks_hits(): + output_1 = ImageOutput(image=ImageField(image_name="foo"), width=512, height=512) + cache = MemoryInvocationCache(max_cache_size=5) + cache.save(1, output_1) + cache.get(1) # hit + cache.get(1) # hit + cache.get(1) # hit + cache.get(2) # miss + cache.get(2) # miss + assert cache._hits == 3 + assert cache._misses == 2 + + +def test_invocation_cache_memory_is_lru(): + output_1 = ImageOutput(image=ImageField(image_name="foo"), width=512, height=512) + output_2 = ImageOutput(image=ImageField(image_name="bar"), width=512, height=512) + output_3 = ImageOutput(image=ImageField(image_name="baz"), width=512, height=512) + cache = MemoryInvocationCache(max_cache_size=2) + cache.save(1, output_1) + cache.save(2, output_2) + cache.save(3, output_3) + assert cache.get(1) is None + assert cache.get(2) == output_2 + assert cache.get(3) == output_3 + assert len(cache._cache) == 2 + assert list(cache._cache.keys()) == [2, 3] + cache.get(2) + assert list(cache._cache.keys()) == [3, 2] + + +def test_invocation_cache_memory_disables_and_enables(): + output_1 = ImageOutput(image=ImageField(image_name="foo"), width=512, height=512) + output_2 = ImageOutput(image=ImageField(image_name="bar"), width=512, height=512) + cache = MemoryInvocationCache(max_cache_size=2) + cache.save(1, output_1) + cache.disable() + assert cache.get(1) is None + cache.save(2, output_2) + assert cache.get(2) is None + assert len(cache._cache) == 1 + assert cache._hits == 0 + assert cache._misses == 0 + cache.enable() + cache.save(2, output_2) + assert cache.get(2) is output_2 + assert len(cache._cache) == 2 + assert cache._hits == 1 + assert cache._misses == 0 + + +def test_invocation_cache_memory_deletes_by_match(): + # The _delete_by_match method attempts to log but the logger is not set up in the test environment + with suppress(AttributeError): + cache = MemoryInvocationCache(max_cache_size=5) + output_1 = ImageOutput(image=ImageField(image_name="foo"), width=512, height=512) + output_2 = ImageOutput(image=ImageField(image_name="bar"), width=512, height=512) + output_3 = ImageOutput(image=ImageField(image_name="baz"), width=512, height=512) + cache.save(1, output_1) + cache.save(2, output_2) + cache.save(3, output_3) + cache._delete_by_match("bar") + assert cache.get(1) == output_1 + assert cache.get(2) is None + assert cache.get(3) == output_3 + assert len(cache._cache) == 2 + assert list(cache._cache.keys()) == [1, 3] + cache._delete_by_match("foo") + assert cache.get(1) is None + assert cache.get(2) is None + assert cache.get(3) == output_3 + assert len(cache._cache) == 1 + assert list(cache._cache.keys()) == [3] + cache._delete_by_match("baz") + assert cache.get(1) is None + assert cache.get(2) is None + assert cache.get(3) is None + assert len(cache._cache) == 0 + assert list(cache._cache.keys()) == [] + # shouldn't raise on empty cache + cache._delete_by_match("foo") + + +def test_invocation_cache_memory_clears(): + cache = MemoryInvocationCache(max_cache_size=5) + output_1 = ImageOutput(image=ImageField(image_name="foo"), width=512, height=512) + output_2 = ImageOutput(image=ImageField(image_name="bar"), width=512, height=512) + output_3 = ImageOutput(image=ImageField(image_name="baz"), width=512, height=512) + cache.save(1, output_1) + cache.save(2, output_2) + cache.save(3, output_3) + cache.get(1) + cache.get(2) + cache.get(3) + cache.get("foo") # miss + cache.get("bar") # miss + cache.clear() + assert len(cache._cache) == 0 + assert cache._hits == 0 + assert cache._misses == 0 + assert cache._misses == 0 + assert cache.get(1) is None + assert cache.get(2) is None + assert cache.get(3) is None + + +def test_invocation_cache_memory_status(): + cache = MemoryInvocationCache(max_cache_size=5) + output_1 = ImageOutput(image=ImageField(image_name="foo"), width=512, height=512) + output_2 = ImageOutput(image=ImageField(image_name="bar"), width=512, height=512) + output_3 = ImageOutput(image=ImageField(image_name="baz"), width=512, height=512) + cache.save(1, output_1) + cache.save(2, output_2) + cache.save(3, output_3) + cache.get(1) + cache.get(2) + cache.get(3) + cache.get("foo") # miss + cache.get("bar") # miss + status = cache.get_status() + assert status.hits == 3 + assert status.misses == 2 + assert status.enabled + assert status.size == 3 + assert status.max_size == 5 + cache.disable() + status = cache.get_status() + assert not status.enabled + cache.enable() + status = cache.get_status() + assert status.enabled + cache.clear() + status = cache.get_status() + assert status.size == 0 + assert status.hits == 0 + assert status.misses == 0 + assert status.enabled + assert status.max_size == 5 + cache._max_cache_size = 0 # cache should be disabled when max_cache_size is zero + status = cache.get_status() + assert not status.enabled + assert status.size == 0 + assert status.hits == 0 + assert status.misses == 0 + assert status.max_size == 0