fix(nodes): do not use double-underscores in cache service

This commit is contained in:
psychedelicious 2023-09-22 16:49:13 +10:00 committed by Kent Keirsey
parent 7d683b4db6
commit 7544eadd48

View File

@ -8,97 +8,97 @@ from invokeai.app.services.invoker import Invoker
class MemoryInvocationCache(InvocationCacheBase): class MemoryInvocationCache(InvocationCacheBase):
__cache: dict[Union[int, str], tuple[BaseInvocationOutput, str]] _cache: dict[Union[int, str], tuple[BaseInvocationOutput, str]]
__max_cache_size: int _max_cache_size: int
__disabled: bool _disabled: bool
__hits: int _hits: int
__misses: int _misses: int
__cache_ids: Queue _cache_ids: Queue
__invoker: Invoker _invoker: Invoker
def __init__(self, max_cache_size: int = 0) -> None: def __init__(self, max_cache_size: int = 0) -> None:
self.__cache = dict() self._cache = dict()
self.__max_cache_size = max_cache_size self._max_cache_size = max_cache_size
self.__disabled = False self._disabled = False
self.__hits = 0 self._hits = 0
self.__misses = 0 self._misses = 0
self.__cache_ids = Queue() self._cache_ids = Queue()
def start(self, invoker: Invoker) -> None: def start(self, invoker: Invoker) -> None:
self.__invoker = invoker self._invoker = invoker
if self.__max_cache_size == 0: if self._max_cache_size == 0:
return return
self.__invoker.services.images.on_deleted(self._delete_by_match) self._invoker.services.images.on_deleted(self._delete_by_match)
self.__invoker.services.latents.on_deleted(self._delete_by_match) self._invoker.services.latents.on_deleted(self._delete_by_match)
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]: def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
if self.__max_cache_size == 0 or self.__disabled: if self._max_cache_size == 0 or self._disabled:
return return
item = self.__cache.get(key, None) item = self._cache.get(key, None)
if item is not None: if item is not None:
self.__hits += 1 self._hits += 1
return item[0] return item[0]
self.__misses += 1 self._misses += 1
def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None: def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
if self.__max_cache_size == 0 or self.__disabled: if self._max_cache_size == 0 or self._disabled:
return return
if key not in self.__cache: if key not in self._cache:
self.__cache[key] = (invocation_output, invocation_output.json()) self._cache[key] = (invocation_output, invocation_output.json())
self.__cache_ids.put(key) self._cache_ids.put(key)
if self.__cache_ids.qsize() > self.__max_cache_size: if self._cache_ids.qsize() > self._max_cache_size:
try: try:
self.__cache.pop(self.__cache_ids.get()) self._cache.pop(self._cache_ids.get())
except KeyError: except KeyError:
# this means the cache_ids are somehow out of sync w/ the cache # this means the cache_ids are somehow out of sync w/ the cache
pass pass
def delete(self, key: Union[int, str]) -> None: def delete(self, key: Union[int, str]) -> None:
if self.__max_cache_size == 0: if self._max_cache_size == 0:
return return
if key in self.__cache: if key in self._cache:
del self.__cache[key] del self._cache[key]
def clear(self, *args, **kwargs) -> None: def clear(self, *args, **kwargs) -> None:
if self.__max_cache_size == 0: if self._max_cache_size == 0:
return return
self.__cache.clear() self._cache.clear()
self.__cache_ids = Queue() self._cache_ids = Queue()
self.__misses = 0 self._misses = 0
self.__hits = 0 self._hits = 0
def create_key(self, invocation: BaseInvocation) -> int: def create_key(self, invocation: BaseInvocation) -> int:
return hash(invocation.json(exclude={"id"})) return hash(invocation.json(exclude={"id"}))
def disable(self) -> None: def disable(self) -> None:
if self.__max_cache_size == 0: if self._max_cache_size == 0:
return return
self.__disabled = True self._disabled = True
def enable(self) -> None: def enable(self) -> None:
if self.__max_cache_size == 0: if self._max_cache_size == 0:
return return
self.__disabled = False self._disabled = False
def get_status(self) -> InvocationCacheStatus: def get_status(self) -> InvocationCacheStatus:
return InvocationCacheStatus( return InvocationCacheStatus(
hits=self.__hits, hits=self._hits,
misses=self.__misses, misses=self._misses,
enabled=not self.__disabled and self.__max_cache_size > 0, enabled=not self._disabled and self._max_cache_size > 0,
size=len(self.__cache), size=len(self._cache),
max_size=self.__max_cache_size, max_size=self._max_cache_size,
) )
def _delete_by_match(self, to_match: str) -> None: def _delete_by_match(self, to_match: str) -> None:
if self.__max_cache_size == 0: if self._max_cache_size == 0:
return return
keys_to_delete = set() keys_to_delete = set()
for key, value_tuple in self.__cache.items(): for key, value_tuple in self._cache.items():
if to_match in value_tuple[1]: if to_match in value_tuple[1]:
keys_to_delete.add(key) keys_to_delete.add(key)
@ -108,4 +108,4 @@ class MemoryInvocationCache(InvocationCacheBase):
for key in keys_to_delete: for key in keys_to_delete:
self.delete(key) self.delete(key)
self.__invoker.services.logger.debug(f"Deleted {len(keys_to_delete)} cached invocation outputs for {to_match}") self._invoker.services.logger.debug(f"Deleted {len(keys_to_delete)} cached invocation outputs for {to_match}")