diff --git a/invokeai/app/services/invocation_cache/invocation_cache_memory.py b/invokeai/app/services/invocation_cache/invocation_cache_memory.py index be07029f4d..b40243f285 100644 --- a/invokeai/app/services/invocation_cache/invocation_cache_memory.py +++ b/invokeai/app/services/invocation_cache/invocation_cache_memory.py @@ -1,4 +1,7 @@ -from queue import Queue +from collections import OrderedDict +from dataclasses import dataclass, field +from threading import Lock +from time import time from typing import Optional, Union from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput @@ -7,22 +10,28 @@ from invokeai.app.services.invocation_cache.invocation_cache_common import Invoc from invokeai.app.services.invoker import Invoker +@dataclass(order=True) +class CachedItem: + invocation_output: BaseInvocationOutput = field(compare=False) + invocation_output_json: str = field(compare=False) + + class MemoryInvocationCache(InvocationCacheBase): - _cache: dict[Union[int, str], tuple[BaseInvocationOutput, str]] + _cache: OrderedDict[Union[int, str], CachedItem] _max_cache_size: int _disabled: bool _hits: int _misses: int - _cache_ids: Queue _invoker: Invoker + _lock: Lock def __init__(self, max_cache_size: int = 0) -> None: - self._cache = dict() + self._cache = OrderedDict() self._max_cache_size = max_cache_size self._disabled = False self._hits = 0 self._misses = 0 - self._cache_ids = Queue() + self._lock = Lock() def start(self, invoker: Invoker) -> None: self._invoker = invoker @@ -32,80 +41,87 @@ class MemoryInvocationCache(InvocationCacheBase): self._invoker.services.latents.on_deleted(self._delete_by_match) def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]: - if self._max_cache_size == 0 or self._disabled: - return - - item = self._cache.get(key, None) - if item is not None: - self._hits += 1 - return item[0] - self._misses += 1 + with self._lock: + if self._max_cache_size == 0 or self._disabled: + return None + item = self._cache.get(key, None) + if item is not None: + self._hits += 1 + self._cache.move_to_end(key) + return item.invocation_output + self._misses += 1 + return None def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None: - if self._max_cache_size == 0 or self._disabled: - return + with self._lock: + if self._max_cache_size == 0 or self._disabled or key in self._cache: + return + # If the cache is full, we need to remove the least used + number_to_delete = len(self._cache) + 1 - self._max_cache_size + self._delete_oldest_access(number_to_delete) + self._cache[key] = CachedItem(time(), invocation_output, invocation_output.json()) - if key not in self._cache: - self._cache[key] = (invocation_output, invocation_output.json()) - self._cache_ids.put(key) - if self._cache_ids.qsize() > self._max_cache_size: - try: - self._cache.pop(self._cache_ids.get()) - except KeyError: - # this means the cache_ids are somehow out of sync w/ the cache - pass + def _delete_oldest_access(self, number_to_delete: int) -> None: + number_to_delete = min(number_to_delete, len(self._cache)) + for _ in range(number_to_delete): + self._cache.popitem(last=False) - def delete(self, key: Union[int, str]) -> None: + def _delete(self, key: Union[int, str]) -> None: if self._max_cache_size == 0: return - if key in self._cache: del self._cache[key] + def delete(self, key: Union[int, str]) -> None: + with self._lock: + return self._delete(key) + def clear(self, *args, **kwargs) -> None: - if self._max_cache_size == 0: - return + with self._lock: + if self._max_cache_size == 0: + return + self._cache.clear() + self._misses = 0 + self._hits = 0 - self._cache.clear() - self._cache_ids = Queue() - self._misses = 0 - self._hits = 0 - - def create_key(self, invocation: BaseInvocation) -> int: + @staticmethod + def create_key(invocation: BaseInvocation) -> int: return hash(invocation.json(exclude={"id"})) def disable(self) -> None: - if self._max_cache_size == 0: - return - self._disabled = True + with self._lock: + if self._max_cache_size == 0: + return + self._disabled = True def enable(self) -> None: - if self._max_cache_size == 0: - return - self._disabled = False + with self._lock: + if self._max_cache_size == 0: + return + self._disabled = False def get_status(self) -> InvocationCacheStatus: - return InvocationCacheStatus( - hits=self._hits, - misses=self._misses, - enabled=not self._disabled and self._max_cache_size > 0, - size=len(self._cache), - max_size=self._max_cache_size, - ) + with self._lock: + return InvocationCacheStatus( + hits=self._hits, + misses=self._misses, + enabled=not self._disabled and self._max_cache_size > 0, + size=len(self._cache), + max_size=self._max_cache_size, + ) def _delete_by_match(self, to_match: str) -> None: - if self._max_cache_size == 0: - return - - keys_to_delete = set() - for key, value_tuple in self._cache.items(): - if to_match in value_tuple[1]: - keys_to_delete.add(key) - - if not keys_to_delete: - return - - for key in keys_to_delete: - self.delete(key) - - self._invoker.services.logger.debug(f"Deleted {len(keys_to_delete)} cached invocation outputs for {to_match}") + with self._lock: + if self._max_cache_size == 0: + return + keys_to_delete = set() + for key, cached_item in self._cache.items(): + if to_match in cached_item.invocation_output_json: + keys_to_delete.add(key) + if not keys_to_delete: + return + for key in keys_to_delete: + self._delete(key) + self._invoker.services.logger.debug( + f"Deleted {len(keys_to_delete)} cached invocation outputs for {to_match}" + )