Feature/lru caching 2 (#4657)

* fix(nodes): do not disable invocation cache delete methods

When the runtime disabled flag is on, do not skip the delete methods. Skipping them would leave stale entries in the cache, which could lead to a hit on a missing resource once the cache is re-enabled.

Do skip them when the cache size is 0, because the user cannot change the cache size at runtime (the app must be restarted to change it).

* fix(nodes): do not use double-underscores in cache service

* Thread lock for cache

* Making cache LRU

* Bug fixes

* bugfix

* Switching to one Lock and OrderedDict cache

* Removing unused imports

* Move lock cache instance

* Addressing PR comments

---------

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Co-authored-by: Martin Kristiansen <martin@modyfi.io>
This commit is contained in:
Martin Kristiansen 2023-09-25 23:42:09 -04:00 committed by GitHub
parent f8392b2f78
commit a2613948d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,4 +1,7 @@
from queue import Queue from collections import OrderedDict
from dataclasses import dataclass, field
from threading import Lock
from time import time
from typing import Optional, Union from typing import Optional, Union
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
@ -7,22 +10,28 @@ from invokeai.app.services.invocation_cache.invocation_cache_common import Invoc
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
@dataclass(order=True)
class CachedItem:
    """One cache entry: an invocation output plus its JSON serialization.

    Both fields are declared with ``compare=False``, so neither takes part in
    the comparison methods generated by the dataclass.
    """

    # The output object returned to callers on a cache hit.
    invocation_output: BaseInvocationOutput = field(compare=False)
    # JSON text of the output; substring-searched when invalidating entries.
    invocation_output_json: str = field(compare=False)
class MemoryInvocationCache(InvocationCacheBase): class MemoryInvocationCache(InvocationCacheBase):
_cache: dict[Union[int, str], tuple[BaseInvocationOutput, str]] _cache: OrderedDict[Union[int, str], CachedItem]
_max_cache_size: int _max_cache_size: int
_disabled: bool _disabled: bool
_hits: int _hits: int
_misses: int _misses: int
_cache_ids: Queue
_invoker: Invoker _invoker: Invoker
_lock: Lock
def __init__(self, max_cache_size: int = 0) -> None:
    """Create an empty, enabled cache.

    Args:
        max_cache_size: Maximum number of entries to keep. With the default
            of ``0`` the other methods of this cache return early without
            touching the cache.
    """
    # Insertion/access order of the OrderedDict doubles as the LRU order.
    self._cache = OrderedDict()
    self._max_cache_size = max_cache_size
    self._disabled = False
    self._hits = 0
    self._misses = 0
    # Single lock guarding the cache, the counters and the disabled flag.
    self._lock = Lock()
def start(self, invoker: Invoker) -> None: def start(self, invoker: Invoker) -> None:
self._invoker = invoker self._invoker = invoker
@ -32,59 +41,67 @@ class MemoryInvocationCache(InvocationCacheBase):
self._invoker.services.latents.on_deleted(self._delete_by_match) self._invoker.services.latents.on_deleted(self._delete_by_match)
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
    """Return the cached output for ``key``, or ``None`` if there is none.

    A hit increments the hit counter and moves the entry to the
    most-recently-used end of the cache. A lookup that finds nothing
    increments the miss counter. When the cache size is 0 or the cache is
    disabled, ``None`` is returned and neither counter changes.
    """
    with self._lock:
        if self._max_cache_size == 0 or self._disabled:
            return None
        item = self._cache.get(key, None)
        if item is not None:
            self._hits += 1
            # Mark as most recently used so it is evicted last.
            self._cache.move_to_end(key)
            return item.invocation_output
        self._misses += 1
        return None
def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
    """Store ``invocation_output`` under ``key``, evicting old entries if full.

    Does nothing when the cache size is 0, the cache is disabled, or ``key``
    is already cached (the existing entry is neither replaced nor re-ordered).
    """
    with self._lock:
        if self._max_cache_size == 0 or self._disabled or key in self._cache:
            return
        # Free one slot for the new entry by evicting from the
        # least-recently-used end. A non-positive count evicts nothing.
        number_to_delete = len(self._cache) + 1 - self._max_cache_size
        self._delete_oldest_access(number_to_delete)
        # CachedItem declares exactly two fields; a leading timestamp passed
        # as a third positional argument would raise TypeError.
        self._cache[key] = CachedItem(invocation_output, invocation_output.json())
def _delete_oldest_access(self, number_to_delete: int) -> None:
    """Evict up to ``number_to_delete`` entries, least recently used first.

    The count is capped at the current cache size, and a non-positive count
    evicts nothing. This method does not acquire the lock itself; ``save``
    calls it while already holding the lock.
    """
    number_to_delete = min(number_to_delete, len(self._cache))
    for _ in range(number_to_delete):
        # last=False pops from the front, i.e. the least recently used entry.
        self._cache.popitem(last=False)
def _delete(self, key: Union[int, str]) -> None:
    """Remove ``key`` from the cache if present, without taking the lock.

    Callers in this class invoke it while already holding the lock. It is a
    no-op when the cache size is 0 or the key is not cached.
    """
    if self._max_cache_size == 0:
        return
    self._cache.pop(key, None)

def delete(self, key: Union[int, str]) -> None:
    """Remove ``key`` from the cache if present, holding the lock.

    Unlike ``get`` and ``save`` this does not check the runtime disabled
    flag, so entries can still be removed while the cache is disabled.
    """
    with self._lock:
        self._delete(key)
def clear(self, *args, **kwargs) -> None:
    """Drop every cached entry and reset the hit and miss counters.

    Positional and keyword arguments are accepted and ignored. Nothing is
    changed when the cache size is 0.
    """
    with self._lock:
        if self._max_cache_size == 0:
            return
        self._cache.clear()
        self._misses = 0
        self._hits = 0
@staticmethod
def create_key(invocation: BaseInvocation) -> int:
    """Derive a cache key from an invocation's JSON, ignoring its ``id``.

    The key is the built-in ``hash`` of the invocation serialized with the
    ``id`` field excluded.

    NOTE: Python salts ``str`` hashes per process unless ``PYTHONHASHSEED``
    is fixed, so these keys are only meaningful within one running process.
    """
    return hash(invocation.json(exclude={"id"}))
def disable(self) -> None:
    """Set the runtime disabled flag; no-op when the cache size is 0."""
    with self._lock:
        if self._max_cache_size == 0:
            return
        self._disabled = True
def enable(self) -> None:
    """Clear the runtime disabled flag; no-op when the cache size is 0."""
    with self._lock:
        if self._max_cache_size == 0:
            return
        self._disabled = False
def get_status(self) -> InvocationCacheStatus: def get_status(self) -> InvocationCacheStatus:
with self._lock:
return InvocationCacheStatus( return InvocationCacheStatus(
hits=self._hits, hits=self._hits,
misses=self._misses, misses=self._misses,
@ -94,18 +111,17 @@ class MemoryInvocationCache(InvocationCacheBase):
) )
def _delete_by_match(self, to_match: str) -> None:
    """Delete every entry whose serialized output contains ``to_match``.

    The match is a plain substring test against each entry's stored JSON.
    Nothing happens when the cache size is 0 or no entry matches; otherwise
    the number of deleted entries is logged at debug level.
    """
    with self._lock:
        if self._max_cache_size == 0:
            return
        # Collect the keys first so the dict is not mutated while iterating.
        keys_to_delete = {
            key
            for key, cached_item in self._cache.items()
            if to_match in cached_item.invocation_output_json
        }
        if not keys_to_delete:
            return
        for key in keys_to_delete:
            # Use the lock-free helper: the lock is already held here and
            # threading.Lock is not reentrant, so delete() would deadlock.
            self._delete(key)
        self._invoker.services.logger.debug(
            f"Deleted {len(keys_to_delete)} cached invocation outputs for {to_match}"
        )