mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Feature/lru caching 2 (#4657)
* fix(nodes): do not disable invocation cache delete methods When the runtime disabled flag is on, do not skip the delete methods. This could lead to a hit on a missing resource. Do skip them when the cache size is 0, because the user cannot change this (must restart app to change it). * fix(nodes): do not use double-underscores in cache service * Thread lock for cache * Making cache LRU * Bug fixes * bugfix * Switching to one Lock and OrderedDict cache * Removing unused imports * Move lock cache instance * Addressing PR comments --------- Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Co-authored-by: Martin Kristiansen <martin@modyfi.io>
This commit is contained in:
parent
f8392b2f78
commit
a2613948d8
@ -1,4 +1,7 @@
|
|||||||
from queue import Queue
|
from collections import OrderedDict
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from threading import Lock
|
||||||
|
from time import time
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||||
@ -7,22 +10,28 @@ from invokeai.app.services.invocation_cache.invocation_cache_common import Invoc
|
|||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(order=True)
|
||||||
|
class CachedItem:
|
||||||
|
invocation_output: BaseInvocationOutput = field(compare=False)
|
||||||
|
invocation_output_json: str = field(compare=False)
|
||||||
|
|
||||||
|
|
||||||
class MemoryInvocationCache(InvocationCacheBase):
|
class MemoryInvocationCache(InvocationCacheBase):
|
||||||
_cache: dict[Union[int, str], tuple[BaseInvocationOutput, str]]
|
_cache: OrderedDict[Union[int, str], CachedItem]
|
||||||
_max_cache_size: int
|
_max_cache_size: int
|
||||||
_disabled: bool
|
_disabled: bool
|
||||||
_hits: int
|
_hits: int
|
||||||
_misses: int
|
_misses: int
|
||||||
_cache_ids: Queue
|
|
||||||
_invoker: Invoker
|
_invoker: Invoker
|
||||||
|
_lock: Lock
|
||||||
|
|
||||||
def __init__(self, max_cache_size: int = 0) -> None:
|
def __init__(self, max_cache_size: int = 0) -> None:
|
||||||
self._cache = dict()
|
self._cache = OrderedDict()
|
||||||
self._max_cache_size = max_cache_size
|
self._max_cache_size = max_cache_size
|
||||||
self._disabled = False
|
self._disabled = False
|
||||||
self._hits = 0
|
self._hits = 0
|
||||||
self._misses = 0
|
self._misses = 0
|
||||||
self._cache_ids = Queue()
|
self._lock = Lock()
|
||||||
|
|
||||||
def start(self, invoker: Invoker) -> None:
|
def start(self, invoker: Invoker) -> None:
|
||||||
self._invoker = invoker
|
self._invoker = invoker
|
||||||
@ -32,80 +41,87 @@ class MemoryInvocationCache(InvocationCacheBase):
|
|||||||
self._invoker.services.latents.on_deleted(self._delete_by_match)
|
self._invoker.services.latents.on_deleted(self._delete_by_match)
|
||||||
|
|
||||||
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
|
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
|
||||||
if self._max_cache_size == 0 or self._disabled:
|
with self._lock:
|
||||||
return
|
if self._max_cache_size == 0 or self._disabled:
|
||||||
|
return None
|
||||||
item = self._cache.get(key, None)
|
item = self._cache.get(key, None)
|
||||||
if item is not None:
|
if item is not None:
|
||||||
self._hits += 1
|
self._hits += 1
|
||||||
return item[0]
|
self._cache.move_to_end(key)
|
||||||
self._misses += 1
|
return item.invocation_output
|
||||||
|
self._misses += 1
|
||||||
|
return None
|
||||||
|
|
||||||
def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
|
def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
|
||||||
if self._max_cache_size == 0 or self._disabled:
|
with self._lock:
|
||||||
return
|
if self._max_cache_size == 0 or self._disabled or key in self._cache:
|
||||||
|
return
|
||||||
|
# If the cache is full, we need to remove the least used
|
||||||
|
number_to_delete = len(self._cache) + 1 - self._max_cache_size
|
||||||
|
self._delete_oldest_access(number_to_delete)
|
||||||
|
self._cache[key] = CachedItem(time(), invocation_output, invocation_output.json())
|
||||||
|
|
||||||
if key not in self._cache:
|
def _delete_oldest_access(self, number_to_delete: int) -> None:
|
||||||
self._cache[key] = (invocation_output, invocation_output.json())
|
number_to_delete = min(number_to_delete, len(self._cache))
|
||||||
self._cache_ids.put(key)
|
for _ in range(number_to_delete):
|
||||||
if self._cache_ids.qsize() > self._max_cache_size:
|
self._cache.popitem(last=False)
|
||||||
try:
|
|
||||||
self._cache.pop(self._cache_ids.get())
|
|
||||||
except KeyError:
|
|
||||||
# this means the cache_ids are somehow out of sync w/ the cache
|
|
||||||
pass
|
|
||||||
|
|
||||||
def delete(self, key: Union[int, str]) -> None:
|
def _delete(self, key: Union[int, str]) -> None:
|
||||||
if self._max_cache_size == 0:
|
if self._max_cache_size == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
if key in self._cache:
|
if key in self._cache:
|
||||||
del self._cache[key]
|
del self._cache[key]
|
||||||
|
|
||||||
|
def delete(self, key: Union[int, str]) -> None:
|
||||||
|
with self._lock:
|
||||||
|
return self._delete(key)
|
||||||
|
|
||||||
def clear(self, *args, **kwargs) -> None:
|
def clear(self, *args, **kwargs) -> None:
|
||||||
if self._max_cache_size == 0:
|
with self._lock:
|
||||||
return
|
if self._max_cache_size == 0:
|
||||||
|
return
|
||||||
|
self._cache.clear()
|
||||||
|
self._misses = 0
|
||||||
|
self._hits = 0
|
||||||
|
|
||||||
self._cache.clear()
|
@staticmethod
|
||||||
self._cache_ids = Queue()
|
def create_key(invocation: BaseInvocation) -> int:
|
||||||
self._misses = 0
|
|
||||||
self._hits = 0
|
|
||||||
|
|
||||||
def create_key(self, invocation: BaseInvocation) -> int:
|
|
||||||
return hash(invocation.json(exclude={"id"}))
|
return hash(invocation.json(exclude={"id"}))
|
||||||
|
|
||||||
def disable(self) -> None:
|
def disable(self) -> None:
|
||||||
if self._max_cache_size == 0:
|
with self._lock:
|
||||||
return
|
if self._max_cache_size == 0:
|
||||||
self._disabled = True
|
return
|
||||||
|
self._disabled = True
|
||||||
|
|
||||||
def enable(self) -> None:
|
def enable(self) -> None:
|
||||||
if self._max_cache_size == 0:
|
with self._lock:
|
||||||
return
|
if self._max_cache_size == 0:
|
||||||
self._disabled = False
|
return
|
||||||
|
self._disabled = False
|
||||||
|
|
||||||
def get_status(self) -> InvocationCacheStatus:
|
def get_status(self) -> InvocationCacheStatus:
|
||||||
return InvocationCacheStatus(
|
with self._lock:
|
||||||
hits=self._hits,
|
return InvocationCacheStatus(
|
||||||
misses=self._misses,
|
hits=self._hits,
|
||||||
enabled=not self._disabled and self._max_cache_size > 0,
|
misses=self._misses,
|
||||||
size=len(self._cache),
|
enabled=not self._disabled and self._max_cache_size > 0,
|
||||||
max_size=self._max_cache_size,
|
size=len(self._cache),
|
||||||
)
|
max_size=self._max_cache_size,
|
||||||
|
)
|
||||||
|
|
||||||
def _delete_by_match(self, to_match: str) -> None:
|
def _delete_by_match(self, to_match: str) -> None:
|
||||||
if self._max_cache_size == 0:
|
with self._lock:
|
||||||
return
|
if self._max_cache_size == 0:
|
||||||
|
return
|
||||||
keys_to_delete = set()
|
keys_to_delete = set()
|
||||||
for key, value_tuple in self._cache.items():
|
for key, cached_item in self._cache.items():
|
||||||
if to_match in value_tuple[1]:
|
if to_match in cached_item.invocation_output_json:
|
||||||
keys_to_delete.add(key)
|
keys_to_delete.add(key)
|
||||||
|
if not keys_to_delete:
|
||||||
if not keys_to_delete:
|
return
|
||||||
return
|
for key in keys_to_delete:
|
||||||
|
self._delete(key)
|
||||||
for key in keys_to_delete:
|
self._invoker.services.logger.debug(
|
||||||
self.delete(key)
|
f"Deleted {len(keys_to_delete)} cached invocation outputs for {to_match}"
|
||||||
|
)
|
||||||
self._invoker.services.logger.debug(f"Deleted {len(keys_to_delete)} cached invocation outputs for {to_match}")
|
|
||||||
|
Loading…
Reference in New Issue
Block a user