feat(backend): selective invalidation for invocation cache (#4597)

## What type of PR is this? (check all applicable)

- [ ] Refactor
- [x] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [x] Yes
- [ ] No, because:

## Description

This change enhances the invocation cache logic to delete cache entries
when the resources to which they refer are deleted.

For example, a cached output may refer to "some_image.png". If that
image is deleted, and this particular cache entry is later retrieved by
a node, that node's successors will receive references to the now
non-existent "some_image.png". When they attempt to use that image, they
will fail.

To resolve this, we need to invalidate the cache when the resources to
which it refers are deleted. Two options:
- Invalidate the whole cache on every image/latents/etc delete
- Selectively invalidate cache entries when their resources are deleted

Node outputs can be any shape, with any number of resource references in
arbitrarily nested pydantic models, so traversing that structure to
identify resources is not trivial. On the other hand, invalidating the
whole cache on every delete is heavy-handed. It would be nice to be more
selective.

Simple solution (sketched in code below):
- Invocation outputs' resource references are always string identifiers, like the image's or latents' name
- Invocation outputs can be stringified, which includes said identifiers
- When the invocation is cached, we store the stringified output
alongside the "live" output classes
- When a resource is deleted, pass its identifier to the cache service,
which can then invalidate any cache entries that refer to it
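In code, the idea looks roughly like this. This is a minimal sketch with invented names (`ToyInvocationCache`, `invalidate`), not the actual implementation; the real one is in the diff below:

```python
from typing import Union

from pydantic import BaseModel


class ToyInvocationCache:
    """Illustrative only: store each output's JSON alongside the live object,
    then invalidate entries by substring match on a deleted resource's name."""

    def __init__(self) -> None:
        # key -> (live output object, its JSON serialization)
        self._cache: dict[Union[int, str], tuple[BaseModel, str]] = {}

    def save(self, key: Union[int, str], output: BaseModel) -> None:
        # The JSON string contains every resource identifier the output refers
        # to (e.g. "some_image.png"), no matter how deeply nested.
        self._cache[key] = (output, output.json())

    def invalidate(self, resource_name: str) -> None:
        # Drop every entry whose serialized form mentions the deleted resource.
        stale = [key for key, (_, json_str) in self._cache.items() if resource_name in json_str]
        for key in stale:
            del self._cache[key]
```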

The images and latents storage services have been outfitted with
`on_deleted()` callbacks, and the cache service registers itself to
handle those events. This logic was copied from `ItemStorageABC`.

`on_changed()` callbacks are also added to the images and latents
services, though these are not currently used; they just follow the
existing pattern.
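The callback plumbing itself is just a list of subscribers on each service. A simplified sketch follows (the `OnDeletedMixin` name is hypothetical; the real classes appear in the diff below):

```python
from typing import Callable


class OnDeletedMixin:
    """Simplified sketch of the on_deleted() pattern added to the
    images and latents services."""

    def __init__(self) -> None:
        self._on_deleted_callbacks: list[Callable[[str], None]] = []

    def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
        # Consumers (e.g. the invocation cache) register a callback at startup.
        self._on_deleted_callbacks.append(on_deleted)

    def _on_deleted(self, item_id: str) -> None:
        # The service calls this after deleting a resource, notifying subscribers.
        for callback in self._on_deleted_callbacks:
            callback(item_id)
```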

## Related Tickets & Documents

<!--
For pull requests that relate to or close an issue, please include them
below.

For example, having the text "closes #1234" would connect the current
pull request to issue 1234. When we merge the pull request, GitHub will
automatically close the issue.
-->

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

Reproduce the issue on main:
- Create a graph in the workflow editor with two connected resize nodes
- Add an image to the first
- Enable cache on both
- Run the graph
- Clear Intermediates (in settings)
- Disable cache on the *second* node
- Run the graph; it should fail

Switch to the PR branch and start over, doing the exact same steps. You
shouldn't get any errors.

Example graph to start with:

![image](https://github.com/invoke-ai/InvokeAI/assets/4822129/c2f0f170-fff4-44f8-8d56-2d8b07ef6440)


## Added/updated tests?

- [~] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_
---

Commit `f7f0630d97` by Kent Keirsey, 2023-09-20 11:09:39 -04:00 (committed by GitHub).
7 changed files with 147 additions and 27 deletions

**`invokeai/app/api/routers/app_info.py`**

```diff
@@ -103,3 +103,13 @@ async def set_log_level(
     """Sets the log verbosity level"""
     ApiDependencies.invoker.services.logger.setLevel(level)
     return LogLevel(ApiDependencies.invoker.services.logger.level)
+
+
+@app_router.delete(
+    "/invocation_cache",
+    operation_id="clear_invocation_cache",
+    responses={200: {"description": "The operation was successful"}},
+)
+async def clear_invocation_cache() -> None:
+    """Clears the invocation cache"""
+    ApiDependencies.invoker.services.invocation_cache.clear()
```

**`invokeai/app/services/images.py`**

```diff
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from logging import Logger
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Callable, Optional
 
 from PIL.Image import Image as PILImageType
@@ -38,6 +38,29 @@ if TYPE_CHECKING:
 class ImageServiceABC(ABC):
     """High-level service for image management."""
 
+    _on_changed_callbacks: list[Callable[[ImageDTO], None]]
+    _on_deleted_callbacks: list[Callable[[str], None]]
+
+    def __init__(self) -> None:
+        self._on_changed_callbacks = list()
+        self._on_deleted_callbacks = list()
+
+    def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
+        """Register a callback for when an image is changed"""
+        self._on_changed_callbacks.append(on_changed)
+
+    def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
+        """Register a callback for when an image is deleted"""
+        self._on_deleted_callbacks.append(on_deleted)
+
+    def _on_changed(self, item: ImageDTO) -> None:
+        for callback in self._on_changed_callbacks:
+            callback(item)
+
+    def _on_deleted(self, item_id: str) -> None:
+        for callback in self._on_deleted_callbacks:
+            callback(item_id)
+
     @abstractmethod
     def create(
         self,
@@ -161,6 +184,7 @@ class ImageService(ImageServiceABC):
     _services: ImageServiceDependencies
 
     def __init__(self, services: ImageServiceDependencies):
+        super().__init__()
         self._services = services
 
     def create(
@@ -217,6 +241,7 @@ class ImageService(ImageServiceABC):
             self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow)
             image_dto = self.get_dto(image_name)
 
+            self._on_changed(image_dto)
             return image_dto
         except ImageRecordSaveException:
             self._services.logger.error("Failed to save image record")
@@ -235,7 +260,9 @@
     ) -> ImageDTO:
         try:
             self._services.image_records.update(image_name, changes)
-            return self.get_dto(image_name)
+            image_dto = self.get_dto(image_name)
+            self._on_changed(image_dto)
+            return image_dto
         except ImageRecordSaveException:
             self._services.logger.error("Failed to update image record")
             raise
@@ -374,6 +401,7 @@
         try:
             self._services.image_files.delete(image_name)
             self._services.image_records.delete(image_name)
+            self._on_deleted(image_name)
         except ImageRecordDeleteException:
             self._services.logger.error("Failed to delete image record")
             raise
@@ -390,6 +418,8 @@
             for image_name in image_names:
                 self._services.image_files.delete(image_name)
             self._services.image_records.delete_many(image_names)
+            for image_name in image_names:
+                self._on_deleted(image_name)
         except ImageRecordDeleteException:
             self._services.logger.error("Failed to delete image records")
             raise
@@ -406,6 +436,7 @@
             count = len(image_names)
             for image_name in image_names:
                 self._services.image_files.delete(image_name)
+                self._on_deleted(image_name)
             return count
         except ImageRecordDeleteException:
             self._services.logger.error("Failed to delete image records")
```

**`invokeai/app/services/invocation_cache/invocation_cache_base.py`**

```diff
@@ -5,25 +5,42 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
 
 
 class InvocationCacheBase(ABC):
-    """Base class for invocation caches."""
+    """
+    Base class for invocation caches.
+
+    When an invocation is executed, it is hashed and its output stored in the cache.
+    When new invocations are executed, if they are flagged with `use_cache`, they
+    will attempt to pull their value from the cache before executing.
+
+    Implementations should register for the `on_deleted` event of the `images` and `latents`
+    services, and delete any cached outputs that reference the deleted image or latent.
+
+    See the memory implementation for an example.
+
+    Implementations should respect the `node_cache_size` configuration value, and skip all
+    cache logic if the value is set to 0.
+    """
 
     @abstractmethod
     def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
-        """Retrieves and invocation output from the cache"""
+        """Retrieves an invocation output from the cache"""
        pass
 
     @abstractmethod
-    def save(self, key: Union[int, str], value: BaseInvocationOutput) -> None:
+    def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
         """Stores an invocation output in the cache"""
         pass
 
     @abstractmethod
     def delete(self, key: Union[int, str]) -> None:
-        """Deleted an invocation output from the cache"""
+        """Deletes an invocation output from the cache"""
         pass
 
-    @classmethod
     @abstractmethod
-    def create_key(cls, value: BaseInvocation) -> Union[int, str]:
-        """Creates the cache key for an invocation"""
+    def clear(self) -> None:
+        """Clears the cache"""
+        pass
+
+    @abstractmethod
+    def create_key(self, invocation: BaseInvocation) -> int:
+        """Gets the key for the invocation's cache item"""
         pass
```

**`invokeai/app/services/invocation_cache/invocation_cache_memory.py`**

```diff
@@ -3,44 +3,79 @@ from typing import Optional, Union
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
 from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
+from invokeai.app.services.invoker import Invoker
 
 
 class MemoryInvocationCache(InvocationCacheBase):
-    __cache: dict[Union[int, str], BaseInvocationOutput]
+    __cache: dict[Union[int, str], tuple[BaseInvocationOutput, str]]
     __max_cache_size: int
     __cache_ids: Queue
+    __invoker: Invoker
 
-    def __init__(self, max_cache_size: int = 512) -> None:
+    def __init__(self, max_cache_size: int = 0) -> None:
         self.__cache = dict()
         self.__max_cache_size = max_cache_size
         self.__cache_ids = Queue()
 
+    def start(self, invoker: Invoker) -> None:
+        self.__invoker = invoker
+        if self.__max_cache_size == 0:
+            return
+        self.__invoker.services.images.on_deleted(self._delete_by_match)
+        self.__invoker.services.latents.on_deleted(self._delete_by_match)
+
     def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
         if self.__max_cache_size == 0:
-            return None
-        return self.__cache.get(key, None)
+            return
+        item = self.__cache.get(key, None)
+        if item is not None:
+            return item[0]
 
-    def save(self, key: Union[int, str], value: BaseInvocationOutput) -> None:
+    def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
         if self.__max_cache_size == 0:
-            return None
+            return
         if key not in self.__cache:
-            self.__cache[key] = value
+            self.__cache[key] = (invocation_output, invocation_output.json())
             self.__cache_ids.put(key)
             if self.__cache_ids.qsize() > self.__max_cache_size:
                 try:
                     self.__cache.pop(self.__cache_ids.get())
                 except KeyError:
                     # this means the cache_ids are somehow out of sync w/ the cache
                     pass
 
     def delete(self, key: Union[int, str]) -> None:
         if self.__max_cache_size == 0:
-            return None
+            return
         if key in self.__cache:
             del self.__cache[key]
 
-    @classmethod
-    def create_key(cls, value: BaseInvocation) -> Union[int, str]:
-        return hash(value.json(exclude={"id"}))
+    def clear(self, *args, **kwargs) -> None:
+        if self.__max_cache_size == 0:
+            return
+        self.__cache.clear()
+        self.__cache_ids = Queue()
+
+    def create_key(self, invocation: BaseInvocation) -> int:
+        return hash(invocation.json(exclude={"id"}))
+
+    def _delete_by_match(self, to_match: str) -> None:
+        if self.__max_cache_size == 0:
+            return
+        keys_to_delete = set()
+        for key, value_tuple in self.__cache.items():
+            if to_match in value_tuple[1]:
+                keys_to_delete.add(key)
+        if not keys_to_delete:
+            return
+        for key in keys_to_delete:
+            self.delete(key)
+        self.__invoker.services.logger.debug(f"Deleted {len(keys_to_delete)} cached invocation outputs for {to_match}")
```

**`invokeai/app/services/latent_storage.py`**

```diff
@@ -3,7 +3,7 @@
 from abc import ABC, abstractmethod
 from pathlib import Path
 from queue import Queue
-from typing import Dict, Optional, Union
+from typing import Callable, Dict, Optional, Union
 
 import torch
@@ -11,6 +11,13 @@ import torch
 class LatentsStorageBase(ABC):
     """Responsible for storing and retrieving latents."""
 
+    _on_changed_callbacks: list[Callable[[torch.Tensor], None]]
+    _on_deleted_callbacks: list[Callable[[str], None]]
+
+    def __init__(self) -> None:
+        self._on_changed_callbacks = list()
+        self._on_deleted_callbacks = list()
+
     @abstractmethod
     def get(self, name: str) -> torch.Tensor:
         pass
@@ -23,6 +30,22 @@
     def delete(self, name: str) -> None:
         pass
 
+    def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None:
+        """Register a callback for when an item is changed"""
+        self._on_changed_callbacks.append(on_changed)
+
+    def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
+        """Register a callback for when an item is deleted"""
+        self._on_deleted_callbacks.append(on_deleted)
+
+    def _on_changed(self, item: torch.Tensor) -> None:
+        for callback in self._on_changed_callbacks:
+            callback(item)
+
+    def _on_deleted(self, item_id: str) -> None:
+        for callback in self._on_deleted_callbacks:
+            callback(item_id)
+
 
 class ForwardCacheLatentsStorage(LatentsStorageBase):
     """Caches the latest N latents in memory, writing through to and reading from underlying storage"""
@@ -33,6 +56,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
     __underlying_storage: LatentsStorageBase
 
     def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
+        super().__init__()
         self.__underlying_storage = underlying_storage
         self.__cache = dict()
         self.__cache_ids = Queue()
@@ -50,11 +74,13 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
     def save(self, name: str, data: torch.Tensor) -> None:
         self.__underlying_storage.save(name, data)
         self.__set_cache(name, data)
+        self._on_changed(data)
 
     def delete(self, name: str) -> None:
         self.__underlying_storage.delete(name)
         if name in self.__cache:
             del self.__cache[name]
+        self._on_deleted(name)
 
     def __get_cache(self, name: str) -> Optional[torch.Tensor]:
         return None if name not in self.__cache else self.__cache[name]
```

**`tests/nodes/test_graph_execution_state.py`**

```diff
@@ -3,8 +3,6 @@ import threading
 
 import pytest
 
-from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
-
 # This import must happen before other invoke imports or test in other files(!!) break
 from .test_nodes import (  # isort: split
     PromptCollectionTestInvocation,
@@ -17,7 +15,9 @@ import sqlite3
 from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
 from invokeai.app.invocations.collections import RangeInvocation
 from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
+from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
 from invokeai.app.services.graph import CollectInvocation, Graph, GraphExecutionState, IterateInvocation, LibraryGraph
+from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
 from invokeai.app.services.invocation_queue import MemoryInvocationQueue
 from invokeai.app.services.invocation_services import InvocationServices
 from invokeai.app.services.invocation_stats import InvocationStatsService
@@ -61,7 +61,7 @@ def mock_services() -> InvocationServices:
         graph_execution_manager=graph_execution_manager,
         performance_statistics=InvocationStatsService(graph_execution_manager),
         processor=DefaultInvocationProcessor(),
-        configuration=None,  # type: ignore
+        configuration=InvokeAIAppConfig(node_cache_size=0),  # type: ignore
         session_queue=None,  # type: ignore
         session_processor=None,  # type: ignore
         invocation_cache=MemoryInvocationCache(),  # type: ignore
```

**`tests/nodes/test_invoker.py`**

```diff
@@ -4,6 +4,8 @@ import threading
 
 import pytest
 
+from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
+
 # This import must happen before other invoke imports or test in other files(!!) break
 from .test_nodes import (  # isort: split
     ErrorInvocation,
@@ -14,7 +16,6 @@ from .test_nodes import (  # isort: split
     wait_until,
 )
 
-from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.app.services.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph
 from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
 from invokeai.app.services.invocation_queue import MemoryInvocationQueue
@@ -70,10 +71,10 @@ def mock_services() -> InvocationServices:
         graph_execution_manager=graph_execution_manager,
         processor=DefaultInvocationProcessor(),
         performance_statistics=InvocationStatsService(graph_execution_manager),
-        configuration=InvokeAIAppConfig(),
+        configuration=InvokeAIAppConfig(node_cache_size=0),  # type: ignore
         session_queue=None,  # type: ignore
         session_processor=None,  # type: ignore
-        invocation_cache=MemoryInvocationCache(),
+        invocation_cache=MemoryInvocationCache(max_cache_size=0),
     )
```