feat(backend): test node cache selective invalidation

This commit is contained in:
psychedelicious
2023-09-20 00:26:07 +10:00
parent 4b149ab521
commit 787df67ceb
4 changed files with 83 additions and 8 deletions

View File

@ -589,7 +589,7 @@ class BaseInvocation(ABC, BaseModel):
if cached_value is None:
context.services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
output = self.invoke(context)
context.services.invocation_cache.save(key, output)
context.services.invocation_cache.save(output)
return output
else:
context.services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}')

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from logging import Logger
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Callable, Optional
from PIL.Image import Image as PILImageType
@ -38,6 +38,27 @@ if TYPE_CHECKING:
class ImageServiceABC(ABC):
"""High-level service for image management."""
_on_changed_callbacks: list[Callable[[ImageDTO], None]]
_on_deleted_callbacks: list[Callable[[str], None]]
@abstractmethod
def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
"""Register a callback for when an item is changed"""
pass
@abstractmethod
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
"""Register a callback for when an item is deleted"""
pass
@abstractmethod
def _on_changed(self, item: ImageDTO) -> None:
pass
@abstractmethod
def _on_deleted(self, item_id: str) -> None:
pass
@abstractmethod
def create(
self,
@ -159,6 +180,24 @@ class ImageServiceDependencies:
class ImageService(ImageServiceABC):
_services: ImageServiceDependencies
_on_changed_callbacks: list[Callable[[ImageDTO], None]] = list()
_on_deleted_callbacks: list[Callable[[str], None]] = list()
def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
"""Register a callback for when an item is changed"""
self._on_changed_callbacks.append(on_changed)
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
"""Register a callback for when an item is deleted"""
self._on_deleted_callbacks.append(on_deleted)
def _on_changed(self, item: ImageDTO) -> None:
for callback in self._on_changed_callbacks:
callback(item)
def _on_deleted(self, item_id: str) -> None:
for callback in self._on_deleted_callbacks:
callback(item_id)
def __init__(self, services: ImageServiceDependencies):
self._services = services
@ -217,6 +256,7 @@ class ImageService(ImageServiceABC):
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow)
image_dto = self.get_dto(image_name)
self._on_changed(image_dto)
return image_dto
except ImageRecordSaveException:
self._services.logger.error("Failed to save image record")
@ -235,7 +275,9 @@ class ImageService(ImageServiceABC):
) -> ImageDTO:
try:
self._services.image_records.update(image_name, changes)
return self.get_dto(image_name)
image_dto = self.get_dto(image_name)
self._on_changed(image_dto)
return image_dto
except ImageRecordSaveException:
self._services.logger.error("Failed to update image record")
raise
@ -374,6 +416,7 @@ class ImageService(ImageServiceABC):
try:
self._services.image_files.delete(image_name)
self._services.image_records.delete(image_name)
self._on_deleted(image_name)
except ImageRecordDeleteException:
self._services.logger.error("Failed to delete image record")
raise
@ -390,6 +433,8 @@ class ImageService(ImageServiceABC):
for image_name in image_names:
self._services.image_files.delete(image_name)
self._services.image_records.delete_many(image_names)
for image_name in image_names:
self._on_deleted(image_name)
except ImageRecordDeleteException:
self._services.logger.error("Failed to delete image records")
raise
@ -406,6 +451,7 @@ class ImageService(ImageServiceABC):
count = len(image_names)
for image_name in image_names:
self._services.image_files.delete(image_name)
self._on_deleted(image_name)
return count
except ImageRecordDeleteException:
self._services.logger.error("Failed to delete image records")

View File

@ -13,7 +13,7 @@ class InvocationCacheBase(ABC):
pass
@abstractmethod
def save(self, key: Union[int, str], value: BaseInvocationOutput) -> None:
def save(self, value: BaseInvocationOutput) -> None:
"""Stores an invocation output in the cache"""
pass
@ -22,6 +22,11 @@ class InvocationCacheBase(ABC):
"""Deleted an invocation output from the cache"""
pass
@abstractmethod
def clear(self) -> None:
"""Clears the cache"""
pass
@classmethod
@abstractmethod
def create_key(cls, value: BaseInvocation) -> Union[int, str]:

View File

@ -1,32 +1,44 @@
from queue import Queue
from typing import Optional, Union
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
from invokeai.app.services.invoker import Invoker
class MemoryInvocationCache(InvocationCacheBase):
__cache: dict[Union[int, str], BaseInvocationOutput]
__cache: dict[Union[int, str], tuple[BaseInvocationOutput, str]]
__max_cache_size: int
__cache_ids: Queue
__invoker: Invoker
def __init__(self, max_cache_size: int = 512) -> None:
self.__cache = dict()
self.__max_cache_size = max_cache_size
self.__cache_ids = Queue()
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
self.__invoker.services.images.on_deleted(self.delete_by_match)
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
if self.__max_cache_size == 0:
return None
return self.__cache.get(key, None)
item = self.__cache.get(key, None)
if item is not None:
return item[0]
def save(self, key: Union[int, str], value: BaseInvocationOutput) -> None:
def save(self, value: BaseInvocationOutput) -> None:
if self.__max_cache_size == 0:
return None
value_json = value.json(exclude={"id"})
key = hash(value_json)
if key not in self.__cache:
self.__cache[key] = value
self.__cache[key] = (value, value_json)
self.__cache_ids.put(key)
if self.__cache_ids.qsize() > self.__max_cache_size:
try:
@ -41,6 +53,18 @@ class MemoryInvocationCache(InvocationCacheBase):
if key in self.__cache:
del self.__cache[key]
def delete_by_match(self, to_match: str) -> None:
to_delete = []
for name, item in self.__cache.items():
if to_match in item[1]:
to_delete.append(name)
for key in to_delete:
self.delete(key)
def clear(self, *args, **kwargs) -> None:
self.__cache.clear()
self.__cache_ids = Queue()
@classmethod
def create_key(cls, value: BaseInvocation) -> Union[int, str]:
return hash(value.json(exclude={"id"}))