Merge branch 'main' into maryhipp/informational-popover

This commit is contained in:
chainchompa 2023-09-20 12:38:36 -04:00 committed by GitHub
commit b128db1d58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 147 additions and 27 deletions

View File

@ -103,3 +103,13 @@ async def set_log_level(
"""Sets the log verbosity level""" """Sets the log verbosity level"""
ApiDependencies.invoker.services.logger.setLevel(level) ApiDependencies.invoker.services.logger.setLevel(level)
return LogLevel(ApiDependencies.invoker.services.logger.level) return LogLevel(ApiDependencies.invoker.services.logger.level)
@app_router.delete(
"/invocation_cache",
operation_id="clear_invocation_cache",
responses={200: {"description": "The operation was successful"}},
)
async def clear_invocation_cache() -> None:
"""Clears the invocation cache"""
ApiDependencies.invoker.services.invocation_cache.clear()

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from logging import Logger from logging import Logger
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Callable, Optional
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
@ -38,6 +38,29 @@ if TYPE_CHECKING:
class ImageServiceABC(ABC): class ImageServiceABC(ABC):
"""High-level service for image management.""" """High-level service for image management."""
_on_changed_callbacks: list[Callable[[ImageDTO], None]]
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
self._on_changed_callbacks = list()
self._on_deleted_callbacks = list()
def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
"""Register a callback for when an image is changed"""
self._on_changed_callbacks.append(on_changed)
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
"""Register a callback for when an image is deleted"""
self._on_deleted_callbacks.append(on_deleted)
def _on_changed(self, item: ImageDTO) -> None:
for callback in self._on_changed_callbacks:
callback(item)
def _on_deleted(self, item_id: str) -> None:
for callback in self._on_deleted_callbacks:
callback(item_id)
@abstractmethod @abstractmethod
def create( def create(
self, self,
@ -161,6 +184,7 @@ class ImageService(ImageServiceABC):
_services: ImageServiceDependencies _services: ImageServiceDependencies
def __init__(self, services: ImageServiceDependencies): def __init__(self, services: ImageServiceDependencies):
super().__init__()
self._services = services self._services = services
def create( def create(
@ -217,6 +241,7 @@ class ImageService(ImageServiceABC):
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow) self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow)
image_dto = self.get_dto(image_name) image_dto = self.get_dto(image_name)
self._on_changed(image_dto)
return image_dto return image_dto
except ImageRecordSaveException: except ImageRecordSaveException:
self._services.logger.error("Failed to save image record") self._services.logger.error("Failed to save image record")
@ -235,7 +260,9 @@ class ImageService(ImageServiceABC):
) -> ImageDTO: ) -> ImageDTO:
try: try:
self._services.image_records.update(image_name, changes) self._services.image_records.update(image_name, changes)
return self.get_dto(image_name) image_dto = self.get_dto(image_name)
self._on_changed(image_dto)
return image_dto
except ImageRecordSaveException: except ImageRecordSaveException:
self._services.logger.error("Failed to update image record") self._services.logger.error("Failed to update image record")
raise raise
@ -374,6 +401,7 @@ class ImageService(ImageServiceABC):
try: try:
self._services.image_files.delete(image_name) self._services.image_files.delete(image_name)
self._services.image_records.delete(image_name) self._services.image_records.delete(image_name)
self._on_deleted(image_name)
except ImageRecordDeleteException: except ImageRecordDeleteException:
self._services.logger.error("Failed to delete image record") self._services.logger.error("Failed to delete image record")
raise raise
@ -390,6 +418,8 @@ class ImageService(ImageServiceABC):
for image_name in image_names: for image_name in image_names:
self._services.image_files.delete(image_name) self._services.image_files.delete(image_name)
self._services.image_records.delete_many(image_names) self._services.image_records.delete_many(image_names)
for image_name in image_names:
self._on_deleted(image_name)
except ImageRecordDeleteException: except ImageRecordDeleteException:
self._services.logger.error("Failed to delete image records") self._services.logger.error("Failed to delete image records")
raise raise
@ -406,6 +436,7 @@ class ImageService(ImageServiceABC):
count = len(image_names) count = len(image_names)
for image_name in image_names: for image_name in image_names:
self._services.image_files.delete(image_name) self._services.image_files.delete(image_name)
self._on_deleted(image_name)
return count return count
except ImageRecordDeleteException: except ImageRecordDeleteException:
self._services.logger.error("Failed to delete image records") self._services.logger.error("Failed to delete image records")

View File

@ -5,25 +5,42 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocati
class InvocationCacheBase(ABC): class InvocationCacheBase(ABC):
"""Base class for invocation caches.""" """
Base class for invocation caches.
When an invocation is executed, it is hashed and its output stored in the cache.
When new invocations are executed, if they are flagged with `use_cache`, they
will attempt to pull their value from the cache before executing.
Implementations should register for the `on_deleted` event of the `images` and `latents`
services, and delete any cached outputs that reference the deleted image or latent.
See the memory implementation for an example.
Implementations should respect the `node_cache_size` configuration value, and skip all
cache logic if the value is set to 0.
"""
@abstractmethod @abstractmethod
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]: def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
"""Retrieves and invocation output from the cache""" """Retrieves an invocation output from the cache"""
pass pass
@abstractmethod @abstractmethod
def save(self, key: Union[int, str], value: BaseInvocationOutput) -> None: def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
"""Stores an invocation output in the cache""" """Stores an invocation output in the cache"""
pass pass
@abstractmethod @abstractmethod
def delete(self, key: Union[int, str]) -> None: def delete(self, key: Union[int, str]) -> None:
"""Deleted an invocation output from the cache""" """Deleteds an invocation output from the cache"""
pass pass
@classmethod
@abstractmethod @abstractmethod
def create_key(cls, value: BaseInvocation) -> Union[int, str]: def clear(self) -> None:
"""Creates the cache key for an invocation""" """Clears the cache"""
pass
@abstractmethod
def create_key(self, invocation: BaseInvocation) -> int:
"""Gets the key for the invocation's cache item"""
pass pass

View File

@ -3,44 +3,79 @@ from typing import Optional, Union
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
from invokeai.app.services.invoker import Invoker
class MemoryInvocationCache(InvocationCacheBase): class MemoryInvocationCache(InvocationCacheBase):
__cache: dict[Union[int, str], BaseInvocationOutput] __cache: dict[Union[int, str], tuple[BaseInvocationOutput, str]]
__max_cache_size: int __max_cache_size: int
__cache_ids: Queue __cache_ids: Queue
__invoker: Invoker
def __init__(self, max_cache_size: int = 512) -> None: def __init__(self, max_cache_size: int = 0) -> None:
self.__cache = dict() self.__cache = dict()
self.__max_cache_size = max_cache_size self.__max_cache_size = max_cache_size
self.__cache_ids = Queue() self.__cache_ids = Queue()
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
if self.__max_cache_size == 0:
return
self.__invoker.services.images.on_deleted(self._delete_by_match)
self.__invoker.services.latents.on_deleted(self._delete_by_match)
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]: def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
if self.__max_cache_size == 0: if self.__max_cache_size == 0:
return None return
return self.__cache.get(key, None) item = self.__cache.get(key, None)
if item is not None:
return item[0]
def save(self, key: Union[int, str], value: BaseInvocationOutput) -> None: def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
if self.__max_cache_size == 0: if self.__max_cache_size == 0:
return None return
if key not in self.__cache: if key not in self.__cache:
self.__cache[key] = value self.__cache[key] = (invocation_output, invocation_output.json())
self.__cache_ids.put(key) self.__cache_ids.put(key)
if self.__cache_ids.qsize() > self.__max_cache_size: if self.__cache_ids.qsize() > self.__max_cache_size:
try: try:
self.__cache.pop(self.__cache_ids.get()) self.__cache.pop(self.__cache_ids.get())
except KeyError: except KeyError:
# this means the cache_ids are somehow out of sync w/ the cache
pass pass
def delete(self, key: Union[int, str]) -> None: def delete(self, key: Union[int, str]) -> None:
if self.__max_cache_size == 0: if self.__max_cache_size == 0:
return None return
if key in self.__cache: if key in self.__cache:
del self.__cache[key] del self.__cache[key]
@classmethod def clear(self, *args, **kwargs) -> None:
def create_key(cls, value: BaseInvocation) -> Union[int, str]: if self.__max_cache_size == 0:
return hash(value.json(exclude={"id"})) return
self.__cache.clear()
self.__cache_ids = Queue()
def create_key(self, invocation: BaseInvocation) -> int:
return hash(invocation.json(exclude={"id"}))
def _delete_by_match(self, to_match: str) -> None:
if self.__max_cache_size == 0:
return
keys_to_delete = set()
for key, value_tuple in self.__cache.items():
if to_match in value_tuple[1]:
keys_to_delete.add(key)
if not keys_to_delete:
return
for key in keys_to_delete:
self.delete(key)
self.__invoker.services.logger.debug(f"Deleted {len(keys_to_delete)} cached invocation outputs for {to_match}")

View File

@ -3,7 +3,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from queue import Queue from queue import Queue
from typing import Dict, Optional, Union from typing import Callable, Dict, Optional, Union
import torch import torch
@ -11,6 +11,13 @@ import torch
class LatentsStorageBase(ABC): class LatentsStorageBase(ABC):
"""Responsible for storing and retrieving latents.""" """Responsible for storing and retrieving latents."""
_on_changed_callbacks: list[Callable[[torch.Tensor], None]]
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
self._on_changed_callbacks = list()
self._on_deleted_callbacks = list()
@abstractmethod @abstractmethod
def get(self, name: str) -> torch.Tensor: def get(self, name: str) -> torch.Tensor:
pass pass
@ -23,6 +30,22 @@ class LatentsStorageBase(ABC):
def delete(self, name: str) -> None: def delete(self, name: str) -> None:
pass pass
def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None:
"""Register a callback for when an item is changed"""
self._on_changed_callbacks.append(on_changed)
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
"""Register a callback for when an item is deleted"""
self._on_deleted_callbacks.append(on_deleted)
def _on_changed(self, item: torch.Tensor) -> None:
for callback in self._on_changed_callbacks:
callback(item)
def _on_deleted(self, item_id: str) -> None:
for callback in self._on_deleted_callbacks:
callback(item_id)
class ForwardCacheLatentsStorage(LatentsStorageBase): class ForwardCacheLatentsStorage(LatentsStorageBase):
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage""" """Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
@ -33,6 +56,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
__underlying_storage: LatentsStorageBase __underlying_storage: LatentsStorageBase
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20): def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
super().__init__()
self.__underlying_storage = underlying_storage self.__underlying_storage = underlying_storage
self.__cache = dict() self.__cache = dict()
self.__cache_ids = Queue() self.__cache_ids = Queue()
@ -50,11 +74,13 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
def save(self, name: str, data: torch.Tensor) -> None: def save(self, name: str, data: torch.Tensor) -> None:
self.__underlying_storage.save(name, data) self.__underlying_storage.save(name, data)
self.__set_cache(name, data) self.__set_cache(name, data)
self._on_changed(data)
def delete(self, name: str) -> None: def delete(self, name: str) -> None:
self.__underlying_storage.delete(name) self.__underlying_storage.delete(name)
if name in self.__cache: if name in self.__cache:
del self.__cache[name] del self.__cache[name]
self._on_deleted(name)
def __get_cache(self, name: str) -> Optional[torch.Tensor]: def __get_cache(self, name: str) -> Optional[torch.Tensor]:
return None if name not in self.__cache else self.__cache[name] return None if name not in self.__cache else self.__cache[name]

View File

@ -3,8 +3,6 @@ import threading
import pytest import pytest
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
# This import must happen before other invoke imports or test in other files(!!) break # This import must happen before other invoke imports or test in other files(!!) break
from .test_nodes import ( # isort: split from .test_nodes import ( # isort: split
PromptCollectionTestInvocation, PromptCollectionTestInvocation,
@ -17,7 +15,9 @@ import sqlite3
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from invokeai.app.invocations.collections import RangeInvocation from invokeai.app.invocations.collections import RangeInvocation
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
from invokeai.app.services.graph import CollectInvocation, Graph, GraphExecutionState, IterateInvocation, LibraryGraph from invokeai.app.services.graph import CollectInvocation, Graph, GraphExecutionState, IterateInvocation, LibraryGraph
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_queue import MemoryInvocationQueue from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats import InvocationStatsService from invokeai.app.services.invocation_stats import InvocationStatsService
@ -61,7 +61,7 @@ def mock_services() -> InvocationServices:
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
performance_statistics=InvocationStatsService(graph_execution_manager), performance_statistics=InvocationStatsService(graph_execution_manager),
processor=DefaultInvocationProcessor(), processor=DefaultInvocationProcessor(),
configuration=None, # type: ignore configuration=InvokeAIAppConfig(node_cache_size=0), # type: ignore
session_queue=None, # type: ignore session_queue=None, # type: ignore
session_processor=None, # type: ignore session_processor=None, # type: ignore
invocation_cache=MemoryInvocationCache(), # type: ignore invocation_cache=MemoryInvocationCache(), # type: ignore

View File

@ -4,6 +4,8 @@ import threading
import pytest import pytest
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
# This import must happen before other invoke imports or test in other files(!!) break # This import must happen before other invoke imports or test in other files(!!) break
from .test_nodes import ( # isort: split from .test_nodes import ( # isort: split
ErrorInvocation, ErrorInvocation,
@ -14,7 +16,6 @@ from .test_nodes import ( # isort: split
wait_until, wait_until,
) )
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph from invokeai.app.services.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_queue import MemoryInvocationQueue from invokeai.app.services.invocation_queue import MemoryInvocationQueue
@ -70,10 +71,10 @@ def mock_services() -> InvocationServices:
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(), processor=DefaultInvocationProcessor(),
performance_statistics=InvocationStatsService(graph_execution_manager), performance_statistics=InvocationStatsService(graph_execution_manager),
configuration=InvokeAIAppConfig(), configuration=InvokeAIAppConfig(node_cache_size=0), # type: ignore
session_queue=None, # type: ignore session_queue=None, # type: ignore
session_processor=None, # type: ignore session_processor=None, # type: ignore
invocation_cache=MemoryInvocationCache(), invocation_cache=MemoryInvocationCache(max_cache_size=0),
) )