fix(nodes): do not use double-underscores in cache service

This commit is contained in:
psychedelicious 2023-09-22 16:49:13 +10:00 committed by Kent Keirsey
parent 7d683b4db6
commit 7544eadd48

View File

@ -8,97 +8,97 @@ from invokeai.app.services.invoker import Invoker
class MemoryInvocationCache(InvocationCacheBase):
__cache: dict[Union[int, str], tuple[BaseInvocationOutput, str]]
__max_cache_size: int
__disabled: bool
__hits: int
__misses: int
__cache_ids: Queue
__invoker: Invoker
_cache: dict[Union[int, str], tuple[BaseInvocationOutput, str]]
_max_cache_size: int
_disabled: bool
_hits: int
_misses: int
_cache_ids: Queue
_invoker: Invoker
def __init__(self, max_cache_size: int = 0) -> None:
self.__cache = dict()
self.__max_cache_size = max_cache_size
self.__disabled = False
self.__hits = 0
self.__misses = 0
self.__cache_ids = Queue()
self._cache = dict()
self._max_cache_size = max_cache_size
self._disabled = False
self._hits = 0
self._misses = 0
self._cache_ids = Queue()
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
if self.__max_cache_size == 0:
self._invoker = invoker
if self._max_cache_size == 0:
return
self.__invoker.services.images.on_deleted(self._delete_by_match)
self.__invoker.services.latents.on_deleted(self._delete_by_match)
self._invoker.services.images.on_deleted(self._delete_by_match)
self._invoker.services.latents.on_deleted(self._delete_by_match)
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
if self.__max_cache_size == 0 or self.__disabled:
if self._max_cache_size == 0 or self._disabled:
return
item = self.__cache.get(key, None)
item = self._cache.get(key, None)
if item is not None:
self.__hits += 1
self._hits += 1
return item[0]
self.__misses += 1
self._misses += 1
def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
if self.__max_cache_size == 0 or self.__disabled:
if self._max_cache_size == 0 or self._disabled:
return
if key not in self.__cache:
self.__cache[key] = (invocation_output, invocation_output.json())
self.__cache_ids.put(key)
if self.__cache_ids.qsize() > self.__max_cache_size:
if key not in self._cache:
self._cache[key] = (invocation_output, invocation_output.json())
self._cache_ids.put(key)
if self._cache_ids.qsize() > self._max_cache_size:
try:
self.__cache.pop(self.__cache_ids.get())
self._cache.pop(self._cache_ids.get())
except KeyError:
# this means the cache_ids are somehow out of sync w/ the cache
pass
def delete(self, key: Union[int, str]) -> None:
if self.__max_cache_size == 0:
if self._max_cache_size == 0:
return
if key in self.__cache:
del self.__cache[key]
if key in self._cache:
del self._cache[key]
def clear(self, *args, **kwargs) -> None:
if self.__max_cache_size == 0:
if self._max_cache_size == 0:
return
self.__cache.clear()
self.__cache_ids = Queue()
self.__misses = 0
self.__hits = 0
self._cache.clear()
self._cache_ids = Queue()
self._misses = 0
self._hits = 0
def create_key(self, invocation: BaseInvocation) -> int:
return hash(invocation.json(exclude={"id"}))
def disable(self) -> None:
if self.__max_cache_size == 0:
if self._max_cache_size == 0:
return
self.__disabled = True
self._disabled = True
def enable(self) -> None:
if self.__max_cache_size == 0:
if self._max_cache_size == 0:
return
self.__disabled = False
self._disabled = False
def get_status(self) -> InvocationCacheStatus:
return InvocationCacheStatus(
hits=self.__hits,
misses=self.__misses,
enabled=not self.__disabled and self.__max_cache_size > 0,
size=len(self.__cache),
max_size=self.__max_cache_size,
hits=self._hits,
misses=self._misses,
enabled=not self._disabled and self._max_cache_size > 0,
size=len(self._cache),
max_size=self._max_cache_size,
)
def _delete_by_match(self, to_match: str) -> None:
if self.__max_cache_size == 0:
if self._max_cache_size == 0:
return
keys_to_delete = set()
for key, value_tuple in self.__cache.items():
for key, value_tuple in self._cache.items():
if to_match in value_tuple[1]:
keys_to_delete.add(key)
@ -108,4 +108,4 @@ class MemoryInvocationCache(InvocationCacheBase):
for key in keys_to_delete:
self.delete(key)
self.__invoker.services.logger.debug(f"Deleted {len(keys_to_delete)} cached invocation outputs for {to_match}")
self._invoker.services.logger.debug(f"Deleted {len(keys_to_delete)} cached invocation outputs for {to_match}")