From 787df67cebab7ce0dabef3e9572bbceb10c2b541 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 20 Sep 2023 00:26:07 +1000 Subject: [PATCH] feat(backend): test node cache selective invalidation --- invokeai/app/invocations/baseinvocation.py | 2 +- invokeai/app/services/images.py | 50 ++++++++++++++++++- .../invocation_cache/invocation_cache_base.py | 7 ++- .../invocation_cache_memory.py | 32 ++++++++++-- 4 files changed, 83 insertions(+), 8 deletions(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 8aa3cf6ceb..08e3f10803 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -589,7 +589,7 @@ class BaseInvocation(ABC, BaseModel): if cached_value is None: context.services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}') output = self.invoke(context) - context.services.invocation_cache.save(key, output) + context.services.invocation_cache.save(output) return output else: context.services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}') diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index 2b0a3d62a5..cb1c150807 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from logging import Logger -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Callable, Optional from PIL.Image import Image as PILImageType @@ -38,6 +38,27 @@ if TYPE_CHECKING: class ImageServiceABC(ABC): """High-level service for image management.""" + _on_changed_callbacks: list[Callable[[ImageDTO], None]] + _on_deleted_callbacks: list[Callable[[str], None]] + + @abstractmethod + def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None: + """Register a callback for when an item is changed""" + pass + + @abstractmethod + def on_deleted(self, on_deleted: Callable[[str], None]) -> None: + """Register a callback for when an item is deleted""" + pass + + @abstractmethod + def _on_changed(self, item: ImageDTO) -> None: + pass + + @abstractmethod + def _on_deleted(self, item_id: str) -> None: + pass + @abstractmethod def create( self, @@ -159,6 +180,24 @@ class ImageServiceDependencies: class ImageService(ImageServiceABC): _services: ImageServiceDependencies + _on_changed_callbacks: list[Callable[[ImageDTO], None]] = list() + _on_deleted_callbacks: list[Callable[[str], None]] = list() + + def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None: + """Register a callback for when an item is changed""" + self._on_changed_callbacks.append(on_changed) + + def on_deleted(self, on_deleted: Callable[[str], None]) -> None: + """Register a callback for when an item is deleted""" + self._on_deleted_callbacks.append(on_deleted) + + def _on_changed(self, item: ImageDTO) -> None: + for callback in self._on_changed_callbacks: + callback(item) + + def _on_deleted(self, item_id: str) -> None: + for callback in self._on_deleted_callbacks: + callback(item_id) def __init__(self, services: ImageServiceDependencies): self._services = services @@ -217,6 +256,7 @@ class ImageService(ImageServiceABC): self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow) image_dto = self.get_dto(image_name) + self._on_changed(image_dto) return image_dto except ImageRecordSaveException: self._services.logger.error("Failed to save image record") @@ -235,7 +275,9 @@ class ImageService(ImageServiceABC): ) -> ImageDTO: try: self._services.image_records.update(image_name, changes) - return self.get_dto(image_name) + image_dto = self.get_dto(image_name) + self._on_changed(image_dto) + return image_dto except ImageRecordSaveException: self._services.logger.error("Failed to update image record") raise @@ -374,6 +416,7 @@ class ImageService(ImageServiceABC): try: self._services.image_files.delete(image_name) self._services.image_records.delete(image_name) + self._on_deleted(image_name) except ImageRecordDeleteException: self._services.logger.error("Failed to delete image record") raise @@ -390,6 +433,8 @@ class ImageService(ImageServiceABC): for image_name in image_names: self._services.image_files.delete(image_name) self._services.image_records.delete_many(image_names) + for image_name in image_names: + self._on_deleted(image_name) except ImageRecordDeleteException: self._services.logger.error("Failed to delete image records") raise @@ -406,6 +451,7 @@ class ImageService(ImageServiceABC): count = len(image_names) for image_name in image_names: self._services.image_files.delete(image_name) + self._on_deleted(image_name) return count except ImageRecordDeleteException: self._services.logger.error("Failed to delete image records") diff --git a/invokeai/app/services/invocation_cache/invocation_cache_base.py b/invokeai/app/services/invocation_cache/invocation_cache_base.py index e60284378e..96b20dbddb 100644 --- a/invokeai/app/services/invocation_cache/invocation_cache_base.py +++ b/invokeai/app/services/invocation_cache/invocation_cache_base.py @@ -13,7 +13,7 @@ class InvocationCacheBase(ABC): pass @abstractmethod - def save(self, key: Union[int, str], value: BaseInvocationOutput) -> None: + def save(self, value: BaseInvocationOutput) -> None: """Stores an invocation output in the cache""" pass @@ -22,6 +22,11 @@ class InvocationCacheBase(ABC): """Deleted an invocation output from the cache""" pass + @abstractmethod + def clear(self) -> None: + """Clears the cache""" + pass + @classmethod @abstractmethod def create_key(cls, value: BaseInvocation) -> Union[int, str]: diff --git a/invokeai/app/services/invocation_cache/invocation_cache_memory.py b/invokeai/app/services/invocation_cache/invocation_cache_memory.py index 8dd59cd5b7..1572743d08 100644 --- a/invokeai/app/services/invocation_cache/invocation_cache_memory.py +++ b/invokeai/app/services/invocation_cache/invocation_cache_memory.py @@ -1,32 +1,44 @@ from queue import Queue from typing import Optional, Union + from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase +from invokeai.app.services.invoker import Invoker class MemoryInvocationCache(InvocationCacheBase): - __cache: dict[Union[int, str], BaseInvocationOutput] + __cache: dict[Union[int, str], tuple[BaseInvocationOutput, str]] __max_cache_size: int __cache_ids: Queue + __invoker: Invoker def __init__(self, max_cache_size: int = 512) -> None: self.__cache = dict() self.__max_cache_size = max_cache_size self.__cache_ids = Queue() + def start(self, invoker: Invoker) -> None: + self.__invoker = invoker + self.__invoker.services.images.on_deleted(self.delete_by_match) + def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]: if self.__max_cache_size == 0: return None - return self.__cache.get(key, None) + item = self.__cache.get(key, None) + if item is not None: + return item[0] - def save(self, key: Union[int, str], value: BaseInvocationOutput) -> None: + def save(self, value: BaseInvocationOutput) -> None: if self.__max_cache_size == 0: return None + value_json = value.json(exclude={"id"}) + key = hash(value_json) + if key not in self.__cache: - self.__cache[key] = value + self.__cache[key] = (value, value_json) self.__cache_ids.put(key) if self.__cache_ids.qsize() > self.__max_cache_size: try: @@ -41,6 +53,18 @@ class MemoryInvocationCache(InvocationCacheBase): if key in self.__cache: del self.__cache[key] + def delete_by_match(self, to_match: str) -> None: + to_delete = [] + for name, item in self.__cache.items(): + if to_match in item[1]: + to_delete.append(name) + for key in to_delete: + self.delete(key) + + def clear(self, *args, **kwargs) -> None: + self.__cache.clear() + self.__cache_ids = Queue() + @classmethod def create_key(cls, value: BaseInvocation) -> Union[int, str]: return hash(value.json(exclude={"id"}))