Merge branch 'main' into maryhipp/informational-popover

This commit is contained in:
chainchompa 2023-09-20 12:38:36 -04:00 committed by GitHub
commit b128db1d58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 147 additions and 27 deletions

View File

@ -103,3 +103,13 @@ async def set_log_level(
"""Sets the log verbosity level"""
ApiDependencies.invoker.services.logger.setLevel(level)
return LogLevel(ApiDependencies.invoker.services.logger.level)
@app_router.delete(
"/invocation_cache",
operation_id="clear_invocation_cache",
responses={200: {"description": "The operation was successful"}},
)
async def clear_invocation_cache() -> None:
"""Clears the invocation cache"""
ApiDependencies.invoker.services.invocation_cache.clear()

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from logging import Logger
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Callable, Optional
from PIL.Image import Image as PILImageType
@ -38,6 +38,29 @@ if TYPE_CHECKING:
class ImageServiceABC(ABC):
"""High-level service for image management."""
_on_changed_callbacks: list[Callable[[ImageDTO], None]]
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
self._on_changed_callbacks = list()
self._on_deleted_callbacks = list()
def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
"""Register a callback for when an image is changed"""
self._on_changed_callbacks.append(on_changed)
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
"""Register a callback for when an image is deleted"""
self._on_deleted_callbacks.append(on_deleted)
def _on_changed(self, item: ImageDTO) -> None:
for callback in self._on_changed_callbacks:
callback(item)
def _on_deleted(self, item_id: str) -> None:
for callback in self._on_deleted_callbacks:
callback(item_id)
@abstractmethod
def create(
self,
@ -161,6 +184,7 @@ class ImageService(ImageServiceABC):
_services: ImageServiceDependencies
def __init__(self, services: ImageServiceDependencies):
super().__init__()
self._services = services
def create(
@ -217,6 +241,7 @@ class ImageService(ImageServiceABC):
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow)
image_dto = self.get_dto(image_name)
self._on_changed(image_dto)
return image_dto
except ImageRecordSaveException:
self._services.logger.error("Failed to save image record")
@ -235,7 +260,9 @@ class ImageService(ImageServiceABC):
) -> ImageDTO:
try:
self._services.image_records.update(image_name, changes)
return self.get_dto(image_name)
image_dto = self.get_dto(image_name)
self._on_changed(image_dto)
return image_dto
except ImageRecordSaveException:
self._services.logger.error("Failed to update image record")
raise
@ -374,6 +401,7 @@ class ImageService(ImageServiceABC):
try:
self._services.image_files.delete(image_name)
self._services.image_records.delete(image_name)
self._on_deleted(image_name)
except ImageRecordDeleteException:
self._services.logger.error("Failed to delete image record")
raise
@ -390,6 +418,8 @@ class ImageService(ImageServiceABC):
for image_name in image_names:
self._services.image_files.delete(image_name)
self._services.image_records.delete_many(image_names)
for image_name in image_names:
self._on_deleted(image_name)
except ImageRecordDeleteException:
self._services.logger.error("Failed to delete image records")
raise
@ -406,6 +436,7 @@ class ImageService(ImageServiceABC):
count = len(image_names)
for image_name in image_names:
self._services.image_files.delete(image_name)
self._on_deleted(image_name)
return count
except ImageRecordDeleteException:
self._services.logger.error("Failed to delete image records")

View File

@ -5,25 +5,42 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocati
class InvocationCacheBase(ABC):
"""Base class for invocation caches."""
"""
Base class for invocation caches.
When an invocation is executed, it is hashed and its output stored in the cache.
When new invocations are executed, if they are flagged with `use_cache`, they
will attempt to pull their value from the cache before executing.
Implementations should register for the `on_deleted` event of the `images` and `latents`
services, and delete any cached outputs that reference the deleted image or latent.
See the memory implementation for an example.
Implementations should respect the `node_cache_size` configuration value, and skip all
cache logic if the value is set to 0.
"""
@abstractmethod
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
"""Retrieves and invocation output from the cache"""
"""Retrieves an invocation output from the cache"""
pass
@abstractmethod
def save(self, key: Union[int, str], value: BaseInvocationOutput) -> None:
def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
"""Stores an invocation output in the cache"""
pass
@abstractmethod
def delete(self, key: Union[int, str]) -> None:
"""Deleted an invocation output from the cache"""
"""Deleteds an invocation output from the cache"""
pass
@classmethod
@abstractmethod
def create_key(cls, value: BaseInvocation) -> Union[int, str]:
"""Creates the cache key for an invocation"""
def clear(self) -> None:
"""Clears the cache"""
pass
@abstractmethod
def create_key(self, invocation: BaseInvocation) -> int:
"""Gets the key for the invocation's cache item"""
pass

View File

@ -3,44 +3,79 @@ from typing import Optional, Union
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
from invokeai.app.services.invoker import Invoker
class MemoryInvocationCache(InvocationCacheBase):
__cache: dict[Union[int, str], BaseInvocationOutput]
__cache: dict[Union[int, str], tuple[BaseInvocationOutput, str]]
__max_cache_size: int
__cache_ids: Queue
__invoker: Invoker
def __init__(self, max_cache_size: int = 512) -> None:
def __init__(self, max_cache_size: int = 0) -> None:
self.__cache = dict()
self.__max_cache_size = max_cache_size
self.__cache_ids = Queue()
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
if self.__max_cache_size == 0:
return
self.__invoker.services.images.on_deleted(self._delete_by_match)
self.__invoker.services.latents.on_deleted(self._delete_by_match)
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
if self.__max_cache_size == 0:
return None
return
return self.__cache.get(key, None)
item = self.__cache.get(key, None)
if item is not None:
return item[0]
def save(self, key: Union[int, str], value: BaseInvocationOutput) -> None:
def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
if self.__max_cache_size == 0:
return None
return
if key not in self.__cache:
self.__cache[key] = value
self.__cache[key] = (invocation_output, invocation_output.json())
self.__cache_ids.put(key)
if self.__cache_ids.qsize() > self.__max_cache_size:
try:
self.__cache.pop(self.__cache_ids.get())
except KeyError:
# this means the cache_ids are somehow out of sync w/ the cache
pass
def delete(self, key: Union[int, str]) -> None:
if self.__max_cache_size == 0:
return None
return
if key in self.__cache:
del self.__cache[key]
@classmethod
def create_key(cls, value: BaseInvocation) -> Union[int, str]:
return hash(value.json(exclude={"id"}))
def clear(self, *args, **kwargs) -> None:
if self.__max_cache_size == 0:
return
self.__cache.clear()
self.__cache_ids = Queue()
def create_key(self, invocation: BaseInvocation) -> int:
return hash(invocation.json(exclude={"id"}))
def _delete_by_match(self, to_match: str) -> None:
if self.__max_cache_size == 0:
return
keys_to_delete = set()
for key, value_tuple in self.__cache.items():
if to_match in value_tuple[1]:
keys_to_delete.add(key)
if not keys_to_delete:
return
for key in keys_to_delete:
self.delete(key)
self.__invoker.services.logger.debug(f"Deleted {len(keys_to_delete)} cached invocation outputs for {to_match}")

View File

@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import Dict, Optional, Union
from typing import Callable, Dict, Optional, Union
import torch
@ -11,6 +11,13 @@ import torch
class LatentsStorageBase(ABC):
"""Responsible for storing and retrieving latents."""
_on_changed_callbacks: list[Callable[[torch.Tensor], None]]
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
self._on_changed_callbacks = list()
self._on_deleted_callbacks = list()
@abstractmethod
def get(self, name: str) -> torch.Tensor:
pass
@ -23,6 +30,22 @@ class LatentsStorageBase(ABC):
def delete(self, name: str) -> None:
pass
def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None:
"""Register a callback for when an item is changed"""
self._on_changed_callbacks.append(on_changed)
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
"""Register a callback for when an item is deleted"""
self._on_deleted_callbacks.append(on_deleted)
def _on_changed(self, item: torch.Tensor) -> None:
for callback in self._on_changed_callbacks:
callback(item)
def _on_deleted(self, item_id: str) -> None:
for callback in self._on_deleted_callbacks:
callback(item_id)
class ForwardCacheLatentsStorage(LatentsStorageBase):
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
@ -33,6 +56,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
__underlying_storage: LatentsStorageBase
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
super().__init__()
self.__underlying_storage = underlying_storage
self.__cache = dict()
self.__cache_ids = Queue()
@ -50,11 +74,13 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
def save(self, name: str, data: torch.Tensor) -> None:
self.__underlying_storage.save(name, data)
self.__set_cache(name, data)
self._on_changed(data)
def delete(self, name: str) -> None:
self.__underlying_storage.delete(name)
if name in self.__cache:
del self.__cache[name]
self._on_deleted(name)
def __get_cache(self, name: str) -> Optional[torch.Tensor]:
return None if name not in self.__cache else self.__cache[name]

View File

@ -3,8 +3,6 @@ import threading
import pytest
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
# This import must happen before other invoke imports or test in other files(!!) break
from .test_nodes import ( # isort: split
PromptCollectionTestInvocation,
@ -17,7 +15,9 @@ import sqlite3
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from invokeai.app.invocations.collections import RangeInvocation
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
from invokeai.app.services.graph import CollectInvocation, Graph, GraphExecutionState, IterateInvocation, LibraryGraph
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats import InvocationStatsService
@ -61,7 +61,7 @@ def mock_services() -> InvocationServices:
graph_execution_manager=graph_execution_manager,
performance_statistics=InvocationStatsService(graph_execution_manager),
processor=DefaultInvocationProcessor(),
configuration=None, # type: ignore
configuration=InvokeAIAppConfig(node_cache_size=0), # type: ignore
session_queue=None, # type: ignore
session_processor=None, # type: ignore
invocation_cache=MemoryInvocationCache(), # type: ignore

View File

@ -4,6 +4,8 @@ import threading
import pytest
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
# This import must happen before other invoke imports or test in other files(!!) break
from .test_nodes import ( # isort: split
ErrorInvocation,
@ -14,7 +16,6 @@ from .test_nodes import ( # isort: split
wait_until,
)
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
@ -70,10 +71,10 @@ def mock_services() -> InvocationServices:
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
performance_statistics=InvocationStatsService(graph_execution_manager),
configuration=InvokeAIAppConfig(),
configuration=InvokeAIAppConfig(node_cache_size=0), # type: ignore
session_queue=None, # type: ignore
session_processor=None, # type: ignore
invocation_cache=MemoryInvocationCache(),
invocation_cache=MemoryInvocationCache(max_cache_size=0),
)