diff --git a/invokeai/app/api/routers/app_info.py b/invokeai/app/api/routers/app_info.py
index afa17d5bd7..a98c8edc6a 100644
--- a/invokeai/app/api/routers/app_info.py
+++ b/invokeai/app/api/routers/app_info.py
@@ -103,3 +103,13 @@ async def set_log_level(
     """Sets the log verbosity level"""
     ApiDependencies.invoker.services.logger.setLevel(level)
     return LogLevel(ApiDependencies.invoker.services.logger.level)
+
+
+@app_router.delete(
+    "/invocation_cache",
+    operation_id="clear_invocation_cache",
+    responses={200: {"description": "The operation was successful"}},
+)
+async def clear_invocation_cache() -> None:
+    """Clears the invocation cache"""
+    ApiDependencies.invoker.services.invocation_cache.clear()
diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py
index 2b0a3d62a5..08d5093a70 100644
--- a/invokeai/app/services/images.py
+++ b/invokeai/app/services/images.py
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from logging import Logger
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Callable, Optional
 
 from PIL.Image import Image as PILImageType
 
@@ -38,6 +38,29 @@ if TYPE_CHECKING:
 class ImageServiceABC(ABC):
     """High-level service for image management."""
 
+    _on_changed_callbacks: list[Callable[[ImageDTO], None]]
+    _on_deleted_callbacks: list[Callable[[str], None]]
+
+    def __init__(self) -> None:
+        self._on_changed_callbacks = list()
+        self._on_deleted_callbacks = list()
+
+    def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
+        """Register a callback for when an image is changed"""
+        self._on_changed_callbacks.append(on_changed)
+
+    def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
+        """Register a callback for when an image is deleted"""
+        self._on_deleted_callbacks.append(on_deleted)
+
+    def _on_changed(self, item: ImageDTO) -> None:
+        for callback in self._on_changed_callbacks:
+            callback(item)
+
+    def _on_deleted(self, item_id: str) -> None:
+        for callback in self._on_deleted_callbacks:
+            callback(item_id)
+
     @abstractmethod
     def create(
         self,
@@ -161,6 +184,7 @@ class ImageService(ImageServiceABC):
     _services: ImageServiceDependencies
 
     def __init__(self, services: ImageServiceDependencies):
+        super().__init__()
         self._services = services
 
     def create(
@@ -217,6 +241,7 @@ class ImageService(ImageServiceABC):
             self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow)
 
             image_dto = self.get_dto(image_name)
+            self._on_changed(image_dto)
             return image_dto
         except ImageRecordSaveException:
             self._services.logger.error("Failed to save image record")
@@ -235,7 +260,9 @@ class ImageService(ImageServiceABC):
     ) -> ImageDTO:
         try:
             self._services.image_records.update(image_name, changes)
-            return self.get_dto(image_name)
+            image_dto = self.get_dto(image_name)
+            self._on_changed(image_dto)
+            return image_dto
         except ImageRecordSaveException:
             self._services.logger.error("Failed to update image record")
             raise
@@ -374,6 +401,7 @@ class ImageService(ImageServiceABC):
         try:
             self._services.image_files.delete(image_name)
             self._services.image_records.delete(image_name)
+            self._on_deleted(image_name)
         except ImageRecordDeleteException:
             self._services.logger.error("Failed to delete image record")
             raise
@@ -390,6 +418,8 @@ class ImageService(ImageServiceABC):
             for image_name in image_names:
                 self._services.image_files.delete(image_name)
             self._services.image_records.delete_many(image_names)
+            for image_name in image_names:
+                self._on_deleted(image_name)
         except ImageRecordDeleteException:
             self._services.logger.error("Failed to delete image records")
             raise
@@ -406,6 +436,7 @@ class ImageService(ImageServiceABC):
             count = len(image_names)
             for image_name in image_names:
                 self._services.image_files.delete(image_name)
+                self._on_deleted(image_name)
             return count
         except ImageRecordDeleteException:
             self._services.logger.error("Failed to delete image records")
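The `on_changed`/`on_deleted` hooks added to `ImageServiceABC` above are a plain observer pattern: consumers register a callable, and the service invokes every registered callback after the underlying mutation succeeds. A minimal, self-contained sketch of the shape of that plumbing (the `ImageEventSource` class is invented for illustration; only the `on_deleted` name and its `Callable[[str], None]` signature come from this diff):

```python
from typing import Callable


class ImageEventSource:
    """Stand-in for ImageServiceABC's callback plumbing (same shape as the diff)."""

    def __init__(self) -> None:
        self._on_deleted_callbacks: list[Callable[[str], None]] = []

    def on_deleted(self, cb: Callable[[str], None]) -> None:
        # Register an observer; it will be called once per deleted image name.
        self._on_deleted_callbacks.append(cb)

    def delete(self, image_name: str) -> None:
        # ... delete the file and record first, then notify observers ...
        for cb in self._on_deleted_callbacks:
            cb(image_name)


events = ImageEventSource()
events.on_deleted(lambda name: print(f"evicting cache entries for {name}"))
events.delete("im_001.png")  # prints: evicting cache entries for im_001.png
```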
self._services.logger.error("Failed to delete image records") raise @@ -406,6 +436,7 @@ class ImageService(ImageServiceABC): count = len(image_names) for image_name in image_names: self._services.image_files.delete(image_name) + self._on_deleted(image_name) return count except ImageRecordDeleteException: self._services.logger.error("Failed to delete image records") diff --git a/invokeai/app/services/invocation_cache/invocation_cache_base.py b/invokeai/app/services/invocation_cache/invocation_cache_base.py index e60284378e..c35a31f851 100644 --- a/invokeai/app/services/invocation_cache/invocation_cache_base.py +++ b/invokeai/app/services/invocation_cache/invocation_cache_base.py @@ -5,25 +5,42 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocati class InvocationCacheBase(ABC): - """Base class for invocation caches.""" + """ + Base class for invocation caches. + When an invocation is executed, it is hashed and its output stored in the cache. + When new invocations are executed, if they are flagged with `use_cache`, they + will attempt to pull their value from the cache before executing. + + Implementations should register for the `on_deleted` event of the `images` and `latents` + services, and delete any cached outputs that reference the deleted image or latent. + + See the memory implementation for an example. + + Implementations should respect the `node_cache_size` configuration value, and skip all + cache logic if the value is set to 0. + """ @abstractmethod def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]: - """Retrieves and invocation output from the cache""" + """Retrieves an invocation output from the cache""" pass @abstractmethod - def save(self, key: Union[int, str], value: BaseInvocationOutput) -> None: + def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None: """Stores an invocation output in the cache""" pass @abstractmethod def delete(self, key: Union[int, str]) -> None: - """Deleted an invocation output from the cache""" + """Deleteds an invocation output from the cache""" pass - @classmethod @abstractmethod - def create_key(cls, value: BaseInvocation) -> Union[int, str]: - """Creates the cache key for an invocation""" + def clear(self) -> None: + """Clears the cache""" + pass + + @abstractmethod + def create_key(self, invocation: BaseInvocation) -> int: + """Gets the key for the invocation's cache item""" pass diff --git a/invokeai/app/services/invocation_cache/invocation_cache_memory.py b/invokeai/app/services/invocation_cache/invocation_cache_memory.py index 8dd59cd5b7..4c0eb2106f 100644 --- a/invokeai/app/services/invocation_cache/invocation_cache_memory.py +++ b/invokeai/app/services/invocation_cache/invocation_cache_memory.py @@ -3,44 +3,79 @@ from typing import Optional, Union from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase +from invokeai.app.services.invoker import Invoker class MemoryInvocationCache(InvocationCacheBase): - __cache: dict[Union[int, str], BaseInvocationOutput] + __cache: dict[Union[int, str], tuple[BaseInvocationOutput, str]] __max_cache_size: int __cache_ids: Queue + __invoker: Invoker - def __init__(self, max_cache_size: int = 512) -> None: + def __init__(self, max_cache_size: int = 0) -> None: self.__cache = dict() self.__max_cache_size = max_cache_size self.__cache_ids = Queue() + def start(self, invoker: Invoker) -> None: + 
diff --git a/invokeai/app/services/invocation_cache/invocation_cache_memory.py b/invokeai/app/services/invocation_cache/invocation_cache_memory.py
index 8dd59cd5b7..4c0eb2106f 100644
--- a/invokeai/app/services/invocation_cache/invocation_cache_memory.py
+++ b/invokeai/app/services/invocation_cache/invocation_cache_memory.py
@@ -3,44 +3,79 @@ from typing import Optional, Union
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
 from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
+from invokeai.app.services.invoker import Invoker
 
 
 class MemoryInvocationCache(InvocationCacheBase):
-    __cache: dict[Union[int, str], BaseInvocationOutput]
+    __cache: dict[Union[int, str], tuple[BaseInvocationOutput, str]]
     __max_cache_size: int
     __cache_ids: Queue
+    __invoker: Invoker
 
-    def __init__(self, max_cache_size: int = 512) -> None:
+    def __init__(self, max_cache_size: int = 0) -> None:
         self.__cache = dict()
         self.__max_cache_size = max_cache_size
         self.__cache_ids = Queue()
 
+    def start(self, invoker: Invoker) -> None:
+        self.__invoker = invoker
+        if self.__max_cache_size == 0:
+            return
+        self.__invoker.services.images.on_deleted(self._delete_by_match)
+        self.__invoker.services.latents.on_deleted(self._delete_by_match)
+
     def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
         if self.__max_cache_size == 0:
-            return None
+            return
 
-        return self.__cache.get(key, None)
+        item = self.__cache.get(key, None)
+        if item is not None:
+            return item[0]
 
-    def save(self, key: Union[int, str], value: BaseInvocationOutput) -> None:
+    def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
         if self.__max_cache_size == 0:
-            return None
+            return
 
         if key not in self.__cache:
-            self.__cache[key] = value
+            self.__cache[key] = (invocation_output, invocation_output.json())
             self.__cache_ids.put(key)
             if self.__cache_ids.qsize() > self.__max_cache_size:
                 try:
                     self.__cache.pop(self.__cache_ids.get())
                 except KeyError:
+                    # this means the cache_ids are somehow out of sync w/ the cache
                     pass
 
     def delete(self, key: Union[int, str]) -> None:
         if self.__max_cache_size == 0:
-            return None
+            return
 
         if key in self.__cache:
             del self.__cache[key]
 
-    @classmethod
-    def create_key(cls, value: BaseInvocation) -> Union[int, str]:
-        return hash(value.json(exclude={"id"}))
+    def clear(self, *args, **kwargs) -> None:
+        if self.__max_cache_size == 0:
+            return
+
+        self.__cache.clear()
+        self.__cache_ids = Queue()
+
+    def create_key(self, invocation: BaseInvocation) -> int:
+        return hash(invocation.json(exclude={"id"}))
+
+    def _delete_by_match(self, to_match: str) -> None:
+        if self.__max_cache_size == 0:
+            return
+
+        keys_to_delete = set()
+        for key, value_tuple in self.__cache.items():
+            if to_match in value_tuple[1]:
+                keys_to_delete.add(key)
+
+        if not keys_to_delete:
+            return
+
+        for key in keys_to_delete:
+            self.delete(key)
+
+        self.__invoker.services.logger.debug(f"Deleted {len(keys_to_delete)} cached invocation outputs for {to_match}")
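Storing `(output, output.json())` tuples is what makes `_delete_by_match` cheap: a deletion event carries only a name, and any cached output whose serialized JSON mentions that name is stale. A toy reproduction of the matching logic (the dict contents are made up; the substring test mirrors `to_match in value_tuple[1]` above):

```python
# Keys are invocation hashes; values are (output, serialized_output) pairs.
cache = {
    101: (object(), '{"image": {"image_name": "im_001.png"}}'),
    102: (object(), '{"latents": {"latents_name": "lat_007"}}'),
}

def delete_by_match(to_match: str) -> None:
    # Collect first, then delete, to avoid mutating the dict while iterating.
    stale = {k for k, (_, serialized) in cache.items() if to_match in serialized}
    for k in stale:
        cache.pop(k)

delete_by_match("im_001.png")
assert list(cache) == [102]  # only the latents entry survives
```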
diff --git a/invokeai/app/services/latent_storage.py b/invokeai/app/services/latent_storage.py
index f0b1dc9fe7..8605ef5abd 100644
--- a/invokeai/app/services/latent_storage.py
+++ b/invokeai/app/services/latent_storage.py
@@ -3,7 +3,7 @@
 from abc import ABC, abstractmethod
 from pathlib import Path
 from queue import Queue
-from typing import Dict, Optional, Union
+from typing import Callable, Dict, Optional, Union
 
 import torch
 
@@ -11,6 +11,13 @@ import torch
 class LatentsStorageBase(ABC):
     """Responsible for storing and retrieving latents."""
 
+    _on_changed_callbacks: list[Callable[[torch.Tensor], None]]
+    _on_deleted_callbacks: list[Callable[[str], None]]
+
+    def __init__(self) -> None:
+        self._on_changed_callbacks = list()
+        self._on_deleted_callbacks = list()
+
     @abstractmethod
     def get(self, name: str) -> torch.Tensor:
         pass
@@ -23,6 +30,22 @@ class LatentsStorageBase(ABC):
     def delete(self, name: str) -> None:
         pass
 
+    def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None:
+        """Register a callback for when an item is changed"""
+        self._on_changed_callbacks.append(on_changed)
+
+    def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
+        """Register a callback for when an item is deleted"""
+        self._on_deleted_callbacks.append(on_deleted)
+
+    def _on_changed(self, item: torch.Tensor) -> None:
+        for callback in self._on_changed_callbacks:
+            callback(item)
+
+    def _on_deleted(self, item_id: str) -> None:
+        for callback in self._on_deleted_callbacks:
+            callback(item_id)
+
 
 class ForwardCacheLatentsStorage(LatentsStorageBase):
     """Caches the latest N latents in memory, writing-through to and reading from underlying storage"""
@@ -33,6 +56,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
     __underlying_storage: LatentsStorageBase
 
     def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
+        super().__init__()
         self.__underlying_storage = underlying_storage
         self.__cache = dict()
         self.__cache_ids = Queue()
@@ -50,11 +74,13 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
     def save(self, name: str, data: torch.Tensor) -> None:
         self.__underlying_storage.save(name, data)
         self.__set_cache(name, data)
+        self._on_changed(data)
 
     def delete(self, name: str) -> None:
         self.__underlying_storage.delete(name)
         if name in self.__cache:
             del self.__cache[name]
+        self._on_deleted(name)
 
     def __get_cache(self, name: str) -> Optional[torch.Tensor]:
         return None if name not in self.__cache else self.__cache[name]
diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py
index 9009140134..e43075bd32 100644
--- a/tests/nodes/test_graph_execution_state.py
+++ b/tests/nodes/test_graph_execution_state.py
@@ -3,8 +3,6 @@ import threading
 
 import pytest
 
-from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
-
 # This import must happen before other invoke imports or test in other files(!!) break
 from .test_nodes import (  # isort: split
     PromptCollectionTestInvocation,
@@ -17,7 +15,9 @@ import sqlite3
 from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
 from invokeai.app.invocations.collections import RangeInvocation
 from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
+from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
 from invokeai.app.services.graph import CollectInvocation, Graph, GraphExecutionState, IterateInvocation, LibraryGraph
+from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
 from invokeai.app.services.invocation_queue import MemoryInvocationQueue
 from invokeai.app.services.invocation_services import InvocationServices
 from invokeai.app.services.invocation_stats import InvocationStatsService
@@ -61,7 +61,7 @@ def mock_services() -> InvocationServices:
         graph_execution_manager=graph_execution_manager,
         performance_statistics=InvocationStatsService(graph_execution_manager),
         processor=DefaultInvocationProcessor(),
-        configuration=None,  # type: ignore
+        configuration=InvokeAIAppConfig(node_cache_size=0),  # type: ignore
         session_queue=None,  # type: ignore
         session_processor=None,  # type: ignore
         invocation_cache=MemoryInvocationCache(),  # type: ignore
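Both test fixtures (here and in test_invoker.py below) now pin the cache off. With `max_cache_size` defaulting to 0, every method of `MemoryInvocationCache` short-circuits, so the tests exercise real execution rather than cached outputs. A quick illustration, assuming the imports resolve in an InvokeAI dev environment:

```python
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache

disabled = MemoryInvocationCache()       # new default: size 0, cache is inert
assert disabled.get("any-key") is None   # get() returns immediately, nothing is stored

enabled = MemoryInvocationCache(max_cache_size=512)  # caching must now be opted into
```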
diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py
index 119ac70498..7c636c3eca 100644
--- a/tests/nodes/test_invoker.py
+++ b/tests/nodes/test_invoker.py
@@ -4,6 +4,8 @@ import threading
 
 import pytest
 
+from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
+
 # This import must happen before other invoke imports or test in other files(!!) break
 from .test_nodes import (  # isort: split
     ErrorInvocation,
@@ -14,7 +16,6 @@ from .test_nodes import (  # isort: split
     wait_until,
 )
 
-from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.app.services.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph
 from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
 from invokeai.app.services.invocation_queue import MemoryInvocationQueue
@@ -70,10 +71,10 @@ def mock_services() -> InvocationServices:
         graph_execution_manager=graph_execution_manager,
         processor=DefaultInvocationProcessor(),
         performance_statistics=InvocationStatsService(graph_execution_manager),
-        configuration=InvokeAIAppConfig(),
+        configuration=InvokeAIAppConfig(node_cache_size=0),  # type: ignore
         session_queue=None,  # type: ignore
         session_processor=None,  # type: ignore
-        invocation_cache=MemoryInvocationCache(),
+        invocation_cache=MemoryInvocationCache(max_cache_size=0),
 )
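With the router change at the top of this patch, the cache can also be cleared over HTTP. A client-side sketch (the host, port, and `/api/v1/app` prefix are assumptions based on InvokeAI's default mount points, not something this diff establishes):

```python
import requests

# DELETE /invocation_cache on the app router clears all cached outputs.
resp = requests.delete("http://127.0.0.1:9090/api/v1/app/invocation_cache")
resp.raise_for_status()  # 200: "The operation was successful"
```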