InvokeAI/tests/test_invocation_cache_memory.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

184 lines
6.7 KiB
Python
Raw Normal View History

2024-03-06 12:18:48 +00:00
# pyright: reportPrivateUsage=false
from contextlib import suppress
from invokeai.app.invocations.fields import ImageField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from tests.test_nodes import PromptTestInvocation
def test_invocation_cache_memory_max_cache_size():
    """A cache built with default arguments has a max size of zero and stores nothing."""
    zero_sized = MemoryInvocationCache()
    assert zero_sized._max_cache_size == 0

    image_output = ImageOutput(image=ImageField(image_name="foo"), width=512, height=512)
    zero_sized.save(1, image_output)

    # The save above is expected to be dropped, and the lookup must leave both counters alone.
    assert zero_sized.get(1) is None
    assert zero_sized._hits == 0
    assert zero_sized._misses == 0  # TODO: when cache size is zero, should we consider it a miss?
    assert len(zero_sized._cache) == 0
def test_invocation_cache_memory_creates_deterministic_keys():
    """Invocations with the same prompt yield equal keys; a different prompt yields a different key."""
    key_foo_a, key_foo_b, key_bar = (
        MemoryInvocationCache.create_key(PromptTestInvocation(prompt=prompt))
        for prompt in ("foo", "foo", "bar")
    )

    assert key_foo_a == key_foo_b
    assert key_foo_a != key_bar
def test_invocation_cache_memory_adds_invocation():
    """Outputs saved under a key can be read back under that same key."""
    saved = {
        1: ImageOutput(image=ImageField(image_name="foo"), width=512, height=512),
        2: ImageOutput(image=ImageField(image_name="bar"), width=512, height=512),
    }
    cache = MemoryInvocationCache(max_cache_size=5)

    # Store everything first, then read back, so both entries coexist in the cache.
    for key, output in saved.items():
        cache.save(key, output)
    for key, output in saved.items():
        assert cache.get(key) == output
def test_invocation_cache_memory_tracks_hits():
    """Lookups of a stored key count as hits; lookups of an absent key count as misses."""
    cache = MemoryInvocationCache(max_cache_size=5)
    cache.save(1, ImageOutput(image=ImageField(image_name="foo"), width=512, height=512))

    for _ in range(3):
        cache.get(1)  # hit
    for _ in range(2):
        cache.get(2)  # miss

    assert cache._hits == 3
    assert cache._misses == 2
def test_invocation_cache_memory_is_lru():
    """Overflowing the cache evicts the oldest entry, and a read reorders the keys."""
    foo, bar, baz = (
        ImageOutput(image=ImageField(image_name=name), width=512, height=512)
        for name in ("foo", "bar", "baz")
    )
    lru = MemoryInvocationCache(max_cache_size=2)
    for key, output in enumerate((foo, bar, baz), start=1):
        lru.save(key, output)

    # Three saves into a two-slot cache: key 1 is expected to be gone.
    assert lru.get(1) is None
    assert lru.get(2) == bar
    assert lru.get(3) == baz
    assert len(lru._cache) == 2
    assert list(lru._cache.keys()) == [2, 3]

    # Reading key 2 is expected to move it behind key 3 in the key order.
    lru.get(2)
    assert list(lru._cache.keys()) == [3, 2]
def test_invocation_cache_memory_disables_and_enables():
    """A disabled cache ignores reads and writes; re-enabling restores normal behaviour."""

    def make_output(image_name):
        return ImageOutput(image=ImageField(image_name=image_name), width=512, height=512)

    first = make_output("foo")
    second = make_output("bar")
    cache = MemoryInvocationCache(max_cache_size=2)
    cache.save(1, first)

    cache.disable()
    # While disabled: lookups return nothing, saves are dropped, counters do not move.
    assert cache.get(1) is None
    cache.save(2, second)
    assert cache.get(2) is None
    assert len(cache._cache) == 1
    assert cache._hits == 0
    assert cache._misses == 0

    cache.enable()
    # Once re-enabled, the same save sticks and the lookup is counted as a hit.
    cache.save(2, second)
    assert cache.get(2) is second
    assert len(cache._cache) == 2
    assert cache._hits == 1
    assert cache._misses == 0
def test_invocation_cache_memory_deletes_by_match():
    """Check that ``_delete_by_match`` removes exactly the entries matching a string.

    Three outputs are cached and then deleted one match at a time; after each
    deletion the remaining entries and their key order are asserted.
    """
    cache = MemoryInvocationCache(max_cache_size=5)

    def delete_by_match(to_match: str) -> None:
        # The _delete_by_match method attempts to log but the logger is not set up in the test
        # environment. Only that call is wrapped: suppressing AttributeError around the whole
        # test body would let the first error silently skip every assertion that follows.
        with suppress(AttributeError):
            cache._delete_by_match(to_match)

    output_1 = ImageOutput(image=ImageField(image_name="foo"), width=512, height=512)
    output_2 = ImageOutput(image=ImageField(image_name="bar"), width=512, height=512)
    output_3 = ImageOutput(image=ImageField(image_name="baz"), width=512, height=512)
    cache.save(1, output_1)
    cache.save(2, output_2)
    cache.save(3, output_3)

    delete_by_match("bar")
    assert cache.get(1) == output_1
    assert cache.get(2) is None
    assert cache.get(3) == output_3
    assert len(cache._cache) == 2
    assert list(cache._cache.keys()) == [1, 3]

    delete_by_match("foo")
    assert cache.get(1) is None
    assert cache.get(2) is None
    assert cache.get(3) == output_3
    assert len(cache._cache) == 1
    assert list(cache._cache.keys()) == [3]

    delete_by_match("baz")
    assert cache.get(1) is None
    assert cache.get(2) is None
    assert cache.get(3) is None
    assert len(cache._cache) == 0
    assert list(cache._cache.keys()) == []

    # shouldn't raise on empty cache
    delete_by_match("foo")
def test_invocation_cache_memory_clears():
    """Clearing the cache drops every entry and resets the hit and miss counters."""
    cache = MemoryInvocationCache(max_cache_size=5)
    output_1 = ImageOutput(image=ImageField(image_name="foo"), width=512, height=512)
    output_2 = ImageOutput(image=ImageField(image_name="bar"), width=512, height=512)
    output_3 = ImageOutput(image=ImageField(image_name="baz"), width=512, height=512)
    cache.save(1, output_1)
    cache.save(2, output_2)
    cache.save(3, output_3)

    # Exercise both counters before clearing so that the reset is observable.
    cache.get(1)
    cache.get(2)
    cache.get(3)
    cache.get("foo")  # miss
    cache.get("bar")  # miss

    cache.clear()
    assert len(cache._cache) == 0
    assert cache._hits == 0
    assert cache._misses == 0

    # Keys that were stored before the clear must no longer resolve.
    assert cache.get(1) is None
    assert cache.get(2) is None
    assert cache.get(3) is None
def test_invocation_cache_memory_status():
    """``get_status`` reports hits, misses, size, max size and the enabled flag."""
    cache = MemoryInvocationCache(max_cache_size=5)
    outputs = [
        ImageOutput(image=ImageField(image_name=name), width=512, height=512)
        for name in ("foo", "bar", "baz")
    ]
    for key, output in enumerate(outputs, start=1):
        cache.save(key, output)
    for key in (1, 2, 3):
        cache.get(key)
    for absent_key in ("foo", "bar"):
        cache.get(absent_key)  # miss

    populated = cache.get_status()
    assert populated.hits == 3
    assert populated.misses == 2
    assert populated.enabled
    assert populated.size == 3
    assert populated.max_size == 5

    # Toggling the cache off and back on is reflected in the enabled flag.
    cache.disable()
    assert not cache.get_status().enabled
    cache.enable()
    assert cache.get_status().enabled

    cache.clear()
    cleared = cache.get_status()
    assert cleared.size == 0
    assert cleared.hits == 0
    assert cleared.misses == 0
    assert cleared.enabled
    assert cleared.max_size == 5

    cache._max_cache_size = 0  # cache should be disabled when max_cache_size is zero
    zero_sized = cache.get_status()
    assert not zero_sized.enabled
    assert zero_sized.size == 0
    assert zero_sized.hits == 0
    assert zero_sized.misses == 0
    assert zero_sized.max_size == 0