From 0a09f84b07bf193995d1efd1f65c884402605d87 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 20 Sep 2023 18:26:47 +1000 Subject: [PATCH 1/4] feat(backend): selective invalidation for invocation cache This change enhances the invocation cache logic to delete cache entries when the resources to which they refer are deleted. For example, a cached output may refer to "some_image.png". If that image is deleted, and this particular cache entry is later retrieved by a node, that node's successors will receive references to the now non-existent "some_image.png". When they attempt to use that image, they will fail. To resolve this, we need to invalidate the cache when the resources to which it refers are deleted. Two options: - Invalidate the whole cache on every image/latents/etc delete - Selectively invalidate cache entries when their resources are deleted Node outputs can be any shape, with any number of resource references in arbitrarily nested pydantic models. Traversing that structure to identify resources is not trivial. But invalidating the whole cache is a bit heavy-handed. It would be nice to be more selective. Simple solution: - Invocation outputs' resource references are always string identifiers - like the image's or latents' name - Invocation outputs can be stringified, which includes said identifiers - When the invocation is cached, we store the stringified output alongside the "live" output classes - When a resource is deleted, pass its identifier to the cache service, which can then invalidate any cache entries that refer to it The images and latents storage services have been outfitted with `on_deleted()` callbacks, and the cache service registers itself to handle those events. This logic was copied from `ItemStorageABC`. `on_changed()` callback are also added to the images and latents services, though these are not currently used. Just following the existing pattern. --- invokeai/app/services/images.py | 35 +++++++++++- .../invocation_cache/invocation_cache_base.py | 31 ++++++++--- .../invocation_cache_memory.py | 53 +++++++++++++++---- invokeai/app/services/latent_storage.py | 28 +++++++++- 4 files changed, 127 insertions(+), 20 deletions(-) diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index 2b0a3d62a5..08d5093a70 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from logging import Logger -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Callable, Optional from PIL.Image import Image as PILImageType @@ -38,6 +38,29 @@ if TYPE_CHECKING: class ImageServiceABC(ABC): """High-level service for image management.""" + _on_changed_callbacks: list[Callable[[ImageDTO], None]] + _on_deleted_callbacks: list[Callable[[str], None]] + + def __init__(self) -> None: + self._on_changed_callbacks = list() + self._on_deleted_callbacks = list() + + def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None: + """Register a callback for when an image is changed""" + self._on_changed_callbacks.append(on_changed) + + def on_deleted(self, on_deleted: Callable[[str], None]) -> None: + """Register a callback for when an image is deleted""" + self._on_deleted_callbacks.append(on_deleted) + + def _on_changed(self, item: ImageDTO) -> None: + for callback in self._on_changed_callbacks: + callback(item) + + def _on_deleted(self, item_id: str) -> None: + for callback in self._on_deleted_callbacks: + callback(item_id) + @abstractmethod def create( self, @@ -161,6 +184,7 @@ class ImageService(ImageServiceABC): _services: ImageServiceDependencies def __init__(self, services: ImageServiceDependencies): + super().__init__() self._services = services def create( @@ -217,6 +241,7 @@ class ImageService(ImageServiceABC): self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow) image_dto = self.get_dto(image_name) + self._on_changed(image_dto) return image_dto except ImageRecordSaveException: self._services.logger.error("Failed to save image record") @@ -235,7 +260,9 @@ class ImageService(ImageServiceABC): ) -> ImageDTO: try: self._services.image_records.update(image_name, changes) - return self.get_dto(image_name) + image_dto = self.get_dto(image_name) + self._on_changed(image_dto) + return image_dto except ImageRecordSaveException: self._services.logger.error("Failed to update image record") raise @@ -374,6 +401,7 @@ class ImageService(ImageServiceABC): try: self._services.image_files.delete(image_name) self._services.image_records.delete(image_name) + self._on_deleted(image_name) except ImageRecordDeleteException: self._services.logger.error("Failed to delete image record") raise @@ -390,6 +418,8 @@ class ImageService(ImageServiceABC): for image_name in image_names: self._services.image_files.delete(image_name) self._services.image_records.delete_many(image_names) + for image_name in image_names: + self._on_deleted(image_name) except ImageRecordDeleteException: self._services.logger.error("Failed to delete image records") raise @@ -406,6 +436,7 @@ class ImageService(ImageServiceABC): count = len(image_names) for image_name in image_names: self._services.image_files.delete(image_name) + self._on_deleted(image_name) return count except ImageRecordDeleteException: self._services.logger.error("Failed to delete image records") diff --git a/invokeai/app/services/invocation_cache/invocation_cache_base.py b/invokeai/app/services/invocation_cache/invocation_cache_base.py index e60284378e..c35a31f851 100644 --- a/invokeai/app/services/invocation_cache/invocation_cache_base.py +++ b/invokeai/app/services/invocation_cache/invocation_cache_base.py @@ -5,25 +5,42 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocati class InvocationCacheBase(ABC): - """Base class for invocation caches.""" + """ + Base class for invocation caches. + When an invocation is executed, it is hashed and its output stored in the cache. + When new invocations are executed, if they are flagged with `use_cache`, they + will attempt to pull their value from the cache before executing. + + Implementations should register for the `on_deleted` event of the `images` and `latents` + services, and delete any cached outputs that reference the deleted image or latent. + + See the memory implementation for an example. + + Implementations should respect the `node_cache_size` configuration value, and skip all + cache logic if the value is set to 0. + """ @abstractmethod def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]: - """Retrieves and invocation output from the cache""" + """Retrieves an invocation output from the cache""" pass @abstractmethod - def save(self, key: Union[int, str], value: BaseInvocationOutput) -> None: + def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None: """Stores an invocation output in the cache""" pass @abstractmethod def delete(self, key: Union[int, str]) -> None: - """Deleted an invocation output from the cache""" + """Deleteds an invocation output from the cache""" pass - @classmethod @abstractmethod - def create_key(cls, value: BaseInvocation) -> Union[int, str]: - """Creates the cache key for an invocation""" + def clear(self) -> None: + """Clears the cache""" + pass + + @abstractmethod + def create_key(self, invocation: BaseInvocation) -> int: + """Gets the key for the invocation's cache item""" pass diff --git a/invokeai/app/services/invocation_cache/invocation_cache_memory.py b/invokeai/app/services/invocation_cache/invocation_cache_memory.py index 8dd59cd5b7..455a78d9d1 100644 --- a/invokeai/app/services/invocation_cache/invocation_cache_memory.py +++ b/invokeai/app/services/invocation_cache/invocation_cache_memory.py @@ -3,44 +3,77 @@ from typing import Optional, Union from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase +from invokeai.app.services.invoker import Invoker class MemoryInvocationCache(InvocationCacheBase): - __cache: dict[Union[int, str], BaseInvocationOutput] + __cache: dict[Union[int, str], tuple[BaseInvocationOutput, str]] __max_cache_size: int __cache_ids: Queue + __invoker: Invoker def __init__(self, max_cache_size: int = 512) -> None: self.__cache = dict() self.__max_cache_size = max_cache_size self.__cache_ids = Queue() + def start(self, invoker: Invoker) -> None: + self.__invoker = invoker + self.__invoker.services.images.on_deleted(self._delete_by_match) + self.__invoker.services.latents.on_deleted(self._delete_by_match) + def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]: if self.__max_cache_size == 0: - return None + return - return self.__cache.get(key, None) + item = self.__cache.get(key, None) + if item is not None: + return item[0] - def save(self, key: Union[int, str], value: BaseInvocationOutput) -> None: + def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None: if self.__max_cache_size == 0: - return None + return if key not in self.__cache: - self.__cache[key] = value + self.__cache[key] = (invocation_output, invocation_output.json()) self.__cache_ids.put(key) if self.__cache_ids.qsize() > self.__max_cache_size: try: self.__cache.pop(self.__cache_ids.get()) except KeyError: + # this means the cache_ids are somehow out of sync w/ the cache pass def delete(self, key: Union[int, str]) -> None: if self.__max_cache_size == 0: - return None + return if key in self.__cache: del self.__cache[key] - @classmethod - def create_key(cls, value: BaseInvocation) -> Union[int, str]: - return hash(value.json(exclude={"id"})) + def clear(self, *args, **kwargs) -> None: + if self.__max_cache_size == 0: + return + + self.__cache.clear() + self.__cache_ids = Queue() + + def create_key(self, invocation: BaseInvocation) -> int: + return hash(invocation.json(exclude={"id"})) + + def _delete_by_match(self, to_match: str) -> None: + if self.__max_cache_size == 0: + return + + keys_to_delete = set() + for key, value_tuple in self.__cache.items(): + if to_match in value_tuple[1]: + keys_to_delete.add(key) + + if not keys_to_delete: + return + + for key in keys_to_delete: + self.delete(key) + + self.__invoker.services.logger.debug(f"Deleted {len(keys_to_delete)} cached invocation outputs for {to_match}") diff --git a/invokeai/app/services/latent_storage.py b/invokeai/app/services/latent_storage.py index f0b1dc9fe7..8605ef5abd 100644 --- a/invokeai/app/services/latent_storage.py +++ b/invokeai/app/services/latent_storage.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from pathlib import Path from queue import Queue -from typing import Dict, Optional, Union +from typing import Callable, Dict, Optional, Union import torch @@ -11,6 +11,13 @@ import torch class LatentsStorageBase(ABC): """Responsible for storing and retrieving latents.""" + _on_changed_callbacks: list[Callable[[torch.Tensor], None]] + _on_deleted_callbacks: list[Callable[[str], None]] + + def __init__(self) -> None: + self._on_changed_callbacks = list() + self._on_deleted_callbacks = list() + @abstractmethod def get(self, name: str) -> torch.Tensor: pass @@ -23,6 +30,22 @@ class LatentsStorageBase(ABC): def delete(self, name: str) -> None: pass + def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None: + """Register a callback for when an item is changed""" + self._on_changed_callbacks.append(on_changed) + + def on_deleted(self, on_deleted: Callable[[str], None]) -> None: + """Register a callback for when an item is deleted""" + self._on_deleted_callbacks.append(on_deleted) + + def _on_changed(self, item: torch.Tensor) -> None: + for callback in self._on_changed_callbacks: + callback(item) + + def _on_deleted(self, item_id: str) -> None: + for callback in self._on_deleted_callbacks: + callback(item_id) + class ForwardCacheLatentsStorage(LatentsStorageBase): """Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage""" @@ -33,6 +56,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase): __underlying_storage: LatentsStorageBase def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20): + super().__init__() self.__underlying_storage = underlying_storage self.__cache = dict() self.__cache_ids = Queue() @@ -50,11 +74,13 @@ class ForwardCacheLatentsStorage(LatentsStorageBase): def save(self, name: str, data: torch.Tensor) -> None: self.__underlying_storage.save(name, data) self.__set_cache(name, data) + self._on_changed(data) def delete(self, name: str) -> None: self.__underlying_storage.delete(name) if name in self.__cache: del self.__cache[name] + self._on_deleted(name) def __get_cache(self, name: str) -> Optional[torch.Tensor]: return None if name not in self.__cache else self.__cache[name] From c1aa2b82eb8cd7158ee2b5c3518cdb03ccf13bc7 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 20 Sep 2023 18:40:24 +1000 Subject: [PATCH 2/4] feat(nodes): default `node_cache_size` in `MemoryInvocationCache` to 0 (fully disabled) --- .../app/services/invocation_cache/invocation_cache_memory.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/invokeai/app/services/invocation_cache/invocation_cache_memory.py b/invokeai/app/services/invocation_cache/invocation_cache_memory.py index 455a78d9d1..4c0eb2106f 100644 --- a/invokeai/app/services/invocation_cache/invocation_cache_memory.py +++ b/invokeai/app/services/invocation_cache/invocation_cache_memory.py @@ -12,13 +12,15 @@ class MemoryInvocationCache(InvocationCacheBase): __cache_ids: Queue __invoker: Invoker - def __init__(self, max_cache_size: int = 512) -> None: + def __init__(self, max_cache_size: int = 0) -> None: self.__cache = dict() self.__max_cache_size = max_cache_size self.__cache_ids = Queue() def start(self, invoker: Invoker) -> None: self.__invoker = invoker + if self.__max_cache_size == 0: + return self.__invoker.services.images.on_deleted(self._delete_by_match) self.__invoker.services.latents.on_deleted(self._delete_by_match) From bfed08673ae85a18b9142955480bbdfc16b27239 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 20 Sep 2023 18:40:40 +1000 Subject: [PATCH 3/4] fix(test): fix tests --- tests/nodes/test_graph_execution_state.py | 6 +++--- tests/nodes/test_invoker.py | 13 +++++++------ 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index 41ca93551a..a8a6590f68 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -3,8 +3,6 @@ import threading import pytest -from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache - # This import must happen before other invoke imports or test in other files(!!) break from .test_nodes import ( # isort: split PromptCollectionTestInvocation, @@ -17,7 +15,9 @@ import sqlite3 from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from invokeai.app.invocations.collections import RangeInvocation from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation +from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig from invokeai.app.services.graph import CollectInvocation, Graph, GraphExecutionState, IterateInvocation, LibraryGraph +from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache from invokeai.app.services.invocation_queue import MemoryInvocationQueue from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_stats import InvocationStatsService @@ -61,7 +61,7 @@ def mock_services() -> InvocationServices: graph_execution_manager=graph_execution_manager, performance_statistics=InvocationStatsService(graph_execution_manager), processor=DefaultInvocationProcessor(), - configuration=None, # type: ignore + configuration=InvokeAIAppConfig(node_cache_size=0), # type: ignore session_queue=None, # type: ignore session_processor=None, # type: ignore invocation_cache=MemoryInvocationCache(), # type: ignore diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index 7dc5cf57b3..c3b508f675 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -4,6 +4,8 @@ import threading import pytest +from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig + # This import must happen before other invoke imports or test in other files(!!) break from .test_nodes import ( # isort: split ErrorInvocation, @@ -14,7 +16,6 @@ from .test_nodes import ( # isort: split wait_until, ) -from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache from invokeai.app.services.invocation_queue import MemoryInvocationQueue @@ -70,10 +71,10 @@ def mock_services() -> InvocationServices: graph_execution_manager=graph_execution_manager, processor=DefaultInvocationProcessor(), performance_statistics=InvocationStatsService(graph_execution_manager), - configuration=InvokeAIAppConfig(), + configuration=InvokeAIAppConfig(node_cache_size=0), # type: ignore session_queue=None, # type: ignore session_processor=None, # type: ignore - invocation_cache=MemoryInvocationCache(), + invocation_cache=MemoryInvocationCache(max_cache_size=0), ) @@ -102,7 +103,7 @@ def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph): # @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor") def test_can_invoke(mock_invoker: Invoker, simple_graph): g = mock_invoker.create_execution_state(graph=simple_graph) - invocation_id = mock_invoker.invoke(queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g) + invocation_id = mock_invoker.invoke(queue_item_id=1, queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g) assert invocation_id is not None def has_executed_any(g: GraphExecutionState): @@ -120,7 +121,7 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph): def test_can_invoke_all(mock_invoker: Invoker, simple_graph): g = mock_invoker.create_execution_state(graph=simple_graph) invocation_id = mock_invoker.invoke( - queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True + queue_item_id=1, queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True ) assert invocation_id is not None @@ -140,7 +141,7 @@ def test_handles_errors(mock_invoker: Invoker): g = mock_invoker.create_execution_state() g.graph.add_node(ErrorInvocation(id="1")) - mock_invoker.invoke(queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True) + mock_invoker.invoke(queue_item_id=1, queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True) def has_executed_all(g: GraphExecutionState): g = mock_invoker.services.graph_execution_manager.get(g.id) From 4cdca45228fb0d6f7e664cc3e19b3ce0236bc676 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 20 Sep 2023 22:53:25 +1000 Subject: [PATCH 4/4] feat(api): add route to clear invocation cache --- invokeai/app/api/routers/app_info.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/invokeai/app/api/routers/app_info.py b/invokeai/app/api/routers/app_info.py index afa17d5bd7..a98c8edc6a 100644 --- a/invokeai/app/api/routers/app_info.py +++ b/invokeai/app/api/routers/app_info.py @@ -103,3 +103,13 @@ async def set_log_level( """Sets the log verbosity level""" ApiDependencies.invoker.services.logger.setLevel(level) return LogLevel(ApiDependencies.invoker.services.logger.level) + + +@app_router.delete( + "/invocation_cache", + operation_id="clear_invocation_cache", + responses={200: {"description": "The operation was successful"}}, +) +async def clear_invocation_cache() -> None: + """Clears the invocation cache""" + ApiDependencies.invoker.services.invocation_cache.clear()