mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
revert(nodes): revert making tensors/conditioning use item storage
Turns out they are just different enough in purpose that the implementations would be rather unintuitive. I've made a separate ObjectSerializer service to handle tensors and conditioning. Refined the class a bit too.
This commit is contained in:
parent
614f0e8086
commit
d9dc5d58be
@ -4,9 +4,9 @@ from logging import Logger
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.item_storage.item_storage_ephemeral_disk import ItemStorageEphemeralDisk
|
||||
from invokeai.app.services.item_storage.item_storage_forward_cache import ItemStorageForwardCache
|
||||
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
||||
from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import ObjectSerializerEphemeralDisk
|
||||
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
@ -90,9 +90,9 @@ class ApiDependencies:
|
||||
image_records = SqliteImageRecordStorage(db=db)
|
||||
images = ImageService()
|
||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||
tensors = ItemStorageForwardCache(ItemStorageEphemeralDisk[torch.Tensor](output_folder / "tensors"))
|
||||
conditioning = ItemStorageForwardCache(
|
||||
ItemStorageEphemeralDisk[ConditioningFieldData](output_folder / "conditioning")
|
||||
tensors = ObjectSerializerForwardCache(ObjectSerializerEphemeralDisk[torch.Tensor](output_folder / "tensors"))
|
||||
conditioning = ObjectSerializerForwardCache(
|
||||
ObjectSerializerEphemeralDisk[ConditioningFieldData](output_folder / "conditioning")
|
||||
)
|
||||
model_manager = ModelManagerService(config, logger)
|
||||
model_record_service = ModelRecordServiceSQL(db=db)
|
||||
|
@ -304,11 +304,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
unet,
|
||||
seed,
|
||||
) -> ConditioningData:
|
||||
positive_cond_data = context.conditioning.get(self.positive_conditioning.conditioning_name)
|
||||
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
|
||||
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
extra_conditioning_info = c.extra_conditioning
|
||||
|
||||
negative_cond_data = context.conditioning.get(self.negative_conditioning.conditioning_name)
|
||||
negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
|
||||
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
conditioning_data = ConditioningData(
|
||||
@ -621,10 +621,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
if self.denoise_mask is None:
|
||||
return None, None
|
||||
|
||||
mask = context.tensors.get(self.denoise_mask.mask_name)
|
||||
mask = context.tensors.load(self.denoise_mask.mask_name)
|
||||
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
if self.denoise_mask.masked_latents_name is not None:
|
||||
masked_latents = context.tensors.get(self.denoise_mask.masked_latents_name)
|
||||
masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name)
|
||||
else:
|
||||
masked_latents = None
|
||||
|
||||
@ -636,11 +636,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
seed = None
|
||||
noise = None
|
||||
if self.noise is not None:
|
||||
noise = context.tensors.get(self.noise.latents_name)
|
||||
noise = context.tensors.load(self.noise.latents_name)
|
||||
seed = self.noise.seed
|
||||
|
||||
if self.latents is not None:
|
||||
latents = context.tensors.get(self.latents.latents_name)
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
if seed is None:
|
||||
seed = self.latents.seed
|
||||
|
||||
@ -779,7 +779,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.get(self.latents.latents_name)
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||
|
||||
@ -870,7 +870,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.tensors.get(self.latents.latents_name)
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
# TODO:
|
||||
device = choose_torch_device()
|
||||
@ -911,7 +911,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.tensors.get(self.latents.latents_name)
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
# TODO:
|
||||
device = choose_torch_device()
|
||||
@ -1048,8 +1048,8 @@ class BlendLatentsInvocation(BaseInvocation):
|
||||
alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents_a = context.tensors.get(self.latents_a.latents_name)
|
||||
latents_b = context.tensors.get(self.latents_b.latents_name)
|
||||
latents_a = context.tensors.load(self.latents_a.latents_name)
|
||||
latents_b = context.tensors.load(self.latents_b.latents_name)
|
||||
|
||||
if latents_a.shape != latents_b.shape:
|
||||
raise Exception("Latents to blend must be the same size.")
|
||||
@ -1149,7 +1149,7 @@ class CropLatentsCoreInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.tensors.get(self.latents.latents_name)
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
x1 = self.x // LATENT_SCALE_FACTOR
|
||||
y1 = self.y // LATENT_SCALE_FACTOR
|
||||
|
@ -344,7 +344,7 @@ class LatentsInvocation(BaseInvocation):
|
||||
latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.tensors.get(self.latents.latents_name)
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
return LatentsOutput.build(self.latents.latents_name, latents)
|
||||
|
||||
|
@ -3,6 +3,8 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
@ -65,8 +67,8 @@ class InvocationServices:
|
||||
names: "NameServiceBase",
|
||||
urls: "UrlServiceBase",
|
||||
workflow_records: "WorkflowRecordsStorageBase",
|
||||
tensors: "ItemStorageABC[torch.Tensor]",
|
||||
conditioning: "ItemStorageABC[ConditioningFieldData]",
|
||||
tensors: "ObjectSerializerBase[torch.Tensor]",
|
||||
conditioning: "ObjectSerializerBase[ConditioningFieldData]",
|
||||
):
|
||||
self.board_images = board_images
|
||||
self.board_image_records = board_image_records
|
||||
|
@ -1,7 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Generic, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
from pydantic import BaseModel
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class ItemStorageABC(ABC, Generic[T]):
|
||||
@ -26,9 +28,9 @@ class ItemStorageABC(ABC, Generic[T]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(self, item: T) -> str:
|
||||
def set(self, item: T) -> None:
|
||||
"""
|
||||
Sets the item. The id will be extracted based on id_field.
|
||||
Sets the item.
|
||||
:param item: the item to set
|
||||
"""
|
||||
pass
|
||||
|
@ -1,15 +1,5 @@
|
||||
from pathlib import Path
|
||||
from typing import Callable, TypeAlias, TypeVar
|
||||
|
||||
|
||||
class ItemNotFoundError(KeyError):
|
||||
"""Raised when an item is not found in storage"""
|
||||
|
||||
def __init__(self, item_id: str) -> None:
|
||||
super().__init__(f"Item with id {item_id} not found")
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
SaveFunc: TypeAlias = Callable[[T, Path], None]
|
||||
LoadFunc: TypeAlias = Callable[[Path], T]
|
||||
|
@ -1,97 +0,0 @@
|
||||
import typing
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC
|
||||
from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError, LoadFunc, SaveFunc
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ItemStorageEphemeralDisk(ItemStorageABC[T]):
|
||||
"""Provides a disk-backed ephemeral storage. The storage is cleared at startup.
|
||||
|
||||
:param output_folder: The folder where the items will be stored
|
||||
:param save: The function to use to save the items to disk [torch.save]
|
||||
:param load: The function to use to load the items from disk [torch.load]
|
||||
:param load_exc: The exception that is raised when an item is not found [FileNotFoundError]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
output_folder: Path,
|
||||
save: SaveFunc[T] = torch.save, # pyright: ignore [reportUnknownMemberType]
|
||||
load: LoadFunc[T] = torch.load, # pyright: ignore [reportUnknownMemberType]
|
||||
load_exc: Type[Exception] = FileNotFoundError,
|
||||
):
|
||||
super().__init__()
|
||||
self._output_folder = output_folder
|
||||
self._output_folder.mkdir(parents=True, exist_ok=True)
|
||||
self._save = save
|
||||
self._load = load
|
||||
self._load_exc = load_exc
|
||||
self.__item_class_name: Optional[str] = None
|
||||
|
||||
@property
|
||||
def _item_class_name(self) -> str:
|
||||
if not self.__item_class_name:
|
||||
# `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason
|
||||
self.__item_class_name = typing.get_args(self.__orig_class__)[0].__name__ # pyright: ignore [reportUnknownMemberType, reportGeneralTypeIssues]
|
||||
return self.__item_class_name
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
self._delete_all_items()
|
||||
|
||||
def get(self, item_id: str) -> T:
|
||||
file_path = self._get_path(item_id)
|
||||
try:
|
||||
return self._load(file_path)
|
||||
except self._load_exc as e:
|
||||
raise ItemNotFoundError(item_id) from e
|
||||
|
||||
def set(self, item: T) -> str:
|
||||
self._output_folder.mkdir(parents=True, exist_ok=True)
|
||||
item_id = self._new_item_id()
|
||||
file_path = self._get_path(item_id)
|
||||
self._save(item, file_path)
|
||||
return item_id
|
||||
|
||||
def delete(self, item_id: str) -> None:
|
||||
file_path = self._get_path(item_id)
|
||||
file_path.unlink()
|
||||
|
||||
def _get_path(self, item_id: str) -> Path:
|
||||
return self._output_folder / item_id
|
||||
|
||||
def _new_item_id(self) -> str:
|
||||
return f"{self._item_class_name}_{uuid_string()}"
|
||||
|
||||
def _delete_all_items(self) -> None:
|
||||
"""
|
||||
Deletes all pickled items from disk.
|
||||
Must be called after we have access to `self._invoker` (e.g. in `start()`).
|
||||
"""
|
||||
|
||||
# We could try using a temporary directory here, but they aren't cleared in the event of a crash, so we'd have
|
||||
# to manually clear them on startup anyways. This is a bit simpler and more reliable.
|
||||
|
||||
if not self._invoker:
|
||||
raise ValueError("Invoker is not set. Must call `start()` first.")
|
||||
|
||||
deleted_count = 0
|
||||
freed_space = 0
|
||||
for file in Path(self._output_folder).glob("*"):
|
||||
if file.is_file():
|
||||
freed_space += file.stat().st_size
|
||||
deleted_count += 1
|
||||
file.unlink()
|
||||
if deleted_count > 0:
|
||||
freed_space_in_mb = round(freed_space / 1024 / 1024, 2)
|
||||
self._invoker.services.logger.info(
|
||||
f"Deleted {deleted_count} {self._item_class_name} files (freed {freed_space_in_mb}MB)"
|
||||
)
|
@ -1,61 +0,0 @@
|
||||
from queue import Queue
|
||||
from typing import Optional, TypeVar
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ItemStorageForwardCache(ItemStorageABC[T]):
|
||||
"""Provides a simple forward cache for an underlying storage. The cache is LRU and has a maximum size."""
|
||||
|
||||
def __init__(self, underlying_storage: ItemStorageABC[T], max_cache_size: int = 20):
|
||||
super().__init__()
|
||||
self._underlying_storage = underlying_storage
|
||||
self._cache: dict[str, T] = {}
|
||||
self._cache_ids = Queue[str]()
|
||||
self._max_cache_size = max_cache_size
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
start_op = getattr(self._underlying_storage, "start", None)
|
||||
if callable(start_op):
|
||||
start_op(invoker)
|
||||
|
||||
def stop(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
stop_op = getattr(self._underlying_storage, "stop", None)
|
||||
if callable(stop_op):
|
||||
stop_op(invoker)
|
||||
|
||||
def get(self, item_id: str) -> T:
|
||||
cache_item = self._get_cache(item_id)
|
||||
if cache_item is not None:
|
||||
return cache_item
|
||||
|
||||
latent = self._underlying_storage.get(item_id)
|
||||
self._set_cache(item_id, latent)
|
||||
return latent
|
||||
|
||||
def set(self, item: T) -> str:
|
||||
item_id = self._underlying_storage.set(item)
|
||||
self._set_cache(item_id, item)
|
||||
self._on_changed(item)
|
||||
return item_id
|
||||
|
||||
def delete(self, item_id: str) -> None:
|
||||
self._underlying_storage.delete(item_id)
|
||||
if item_id in self._cache:
|
||||
del self._cache[item_id]
|
||||
self._on_deleted(item_id)
|
||||
|
||||
def _get_cache(self, item_id: str) -> Optional[T]:
|
||||
return None if item_id not in self._cache else self._cache[item_id]
|
||||
|
||||
def _set_cache(self, item_id: str, data: T):
|
||||
if item_id not in self._cache:
|
||||
self._cache[item_id] = data
|
||||
self._cache_ids.put(item_id)
|
||||
if self._cache_ids.qsize() > self._max_cache_size:
|
||||
self._cache.pop(self._cache_ids.get())
|
@ -34,7 +34,7 @@ class ItemStorageMemory(ItemStorageABC[T], Generic[T]):
|
||||
self._items[item_id] = item
|
||||
return item
|
||||
|
||||
def set(self, item: T) -> str:
|
||||
def set(self, item: T) -> None:
|
||||
item_id = getattr(item, self._id_field)
|
||||
if item_id in self._items:
|
||||
# If item already exists, remove it and add it to the end
|
||||
@ -44,7 +44,6 @@ class ItemStorageMemory(ItemStorageABC[T], Generic[T]):
|
||||
self._items.popitem(last=False)
|
||||
self._items[item_id] = item
|
||||
self._on_changed(item)
|
||||
return item_id
|
||||
|
||||
def delete(self, item_id: str) -> None:
|
||||
# This is a no-op if the item doesn't exist.
|
||||
|
@ -0,0 +1,53 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Generic, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ObjectSerializerBase(ABC, Generic[T]):
|
||||
"""Saves and loads arbitrary python objects."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._on_saved_callbacks: list[Callable[[str, T], None]] = []
|
||||
self._on_deleted_callbacks: list[Callable[[str], None]] = []
|
||||
|
||||
@abstractmethod
|
||||
def load(self, name: str) -> T:
|
||||
"""
|
||||
Loads the object.
|
||||
:param name: The name of the object to load.
|
||||
:raises ObjectNotFoundError: if the object is not found
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, obj: T) -> str:
|
||||
"""
|
||||
Saves the object, returning its name.
|
||||
:param obj: The object to save.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, name: str) -> None:
|
||||
"""
|
||||
Deletes the object, if it exists.
|
||||
:param name: The name of the object to delete.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_saved(self, on_saved: Callable[[str, T], None]) -> None:
|
||||
"""Register a callback for when an object is saved"""
|
||||
self._on_saved_callbacks.append(on_saved)
|
||||
|
||||
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
||||
"""Register a callback for when an object is deleted"""
|
||||
self._on_deleted_callbacks.append(on_deleted)
|
||||
|
||||
def _on_saved(self, name: str, obj: T) -> None:
|
||||
for callback in self._on_saved_callbacks:
|
||||
callback(name, obj)
|
||||
|
||||
def _on_deleted(self, name: str) -> None:
|
||||
for callback in self._on_deleted_callbacks:
|
||||
callback(name)
|
@ -0,0 +1,5 @@
|
||||
class ObjectNotFoundError(KeyError):
|
||||
"""Raised when an object is not found while loading"""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
super().__init__(f"Object with name {name} not found")
|
@ -0,0 +1,84 @@
|
||||
import typing
|
||||
from pathlib import Path
|
||||
from typing import Optional, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
|
||||
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ObjectSerializerEphemeralDisk(ObjectSerializerBase[T]):
|
||||
"""Provides a disk-backed ephemeral storage for arbitrary python objects. The storage is cleared at startup.
|
||||
|
||||
:param output_folder: The folder where the objects will be stored
|
||||
"""
|
||||
|
||||
def __init__(self, output_dir: Path):
|
||||
super().__init__()
|
||||
self._output_dir = output_dir
|
||||
self._output_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.__obj_class_name: Optional[str] = None
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
self._delete_all()
|
||||
|
||||
def load(self, name: str) -> T:
|
||||
file_path = self._get_path(name)
|
||||
try:
|
||||
return torch.load(file_path) # pyright: ignore [reportUnknownMemberType]
|
||||
except FileNotFoundError as e:
|
||||
raise ObjectNotFoundError(name) from e
|
||||
|
||||
def save(self, obj: T) -> str:
|
||||
name = self._new_name()
|
||||
file_path = self._get_path(name)
|
||||
torch.save(obj, file_path) # pyright: ignore [reportUnknownMemberType]
|
||||
return name
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
file_path = self._get_path(name)
|
||||
file_path.unlink()
|
||||
|
||||
@property
|
||||
def _obj_class_name(self) -> str:
|
||||
if not self.__obj_class_name:
|
||||
# `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason
|
||||
self.__obj_class_name = typing.get_args(self.__orig_class__)[0].__name__ # pyright: ignore [reportUnknownMemberType, reportGeneralTypeIssues]
|
||||
return self.__obj_class_name
|
||||
|
||||
def _get_path(self, name: str) -> Path:
|
||||
return self._output_dir / name
|
||||
|
||||
def _new_name(self) -> str:
|
||||
return f"{self._obj_class_name}_{uuid_string()}"
|
||||
|
||||
def _delete_all(self) -> None:
|
||||
"""
|
||||
Deletes all objects from disk.
|
||||
Must be called after we have access to `self._invoker` (e.g. in `start()`).
|
||||
"""
|
||||
|
||||
# We could try using a temporary directory here, but they aren't cleared in the event of a crash, so we'd have
|
||||
# to manually clear them on startup anyways. This is a bit simpler and more reliable.
|
||||
|
||||
if not self._invoker:
|
||||
raise ValueError("Invoker is not set. Must call `start()` first.")
|
||||
|
||||
deleted_count = 0
|
||||
freed_space = 0
|
||||
for file in Path(self._output_dir).glob("*"):
|
||||
if file.is_file():
|
||||
freed_space += file.stat().st_size
|
||||
deleted_count += 1
|
||||
file.unlink()
|
||||
if deleted_count > 0:
|
||||
freed_space_in_mb = round(freed_space / 1024 / 1024, 2)
|
||||
self._invoker.services.logger.info(
|
||||
f"Deleted {deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)"
|
||||
)
|
@ -0,0 +1,61 @@
|
||||
from queue import Queue
|
||||
from typing import Optional, TypeVar
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
|
||||
"""Provides a simple forward cache for an underlying storage. The cache is LRU and has a maximum size."""
|
||||
|
||||
def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20):
|
||||
super().__init__()
|
||||
self._underlying_storage = underlying_storage
|
||||
self._cache: dict[str, T] = {}
|
||||
self._cache_ids = Queue[str]()
|
||||
self._max_cache_size = max_cache_size
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
start_op = getattr(self._underlying_storage, "start", None)
|
||||
if callable(start_op):
|
||||
start_op(invoker)
|
||||
|
||||
def stop(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
stop_op = getattr(self._underlying_storage, "stop", None)
|
||||
if callable(stop_op):
|
||||
stop_op(invoker)
|
||||
|
||||
def load(self, name: str) -> T:
|
||||
cache_item = self._get_cache(name)
|
||||
if cache_item is not None:
|
||||
return cache_item
|
||||
|
||||
latent = self._underlying_storage.load(name)
|
||||
self._set_cache(name, latent)
|
||||
return latent
|
||||
|
||||
def save(self, obj: T) -> str:
|
||||
name = self._underlying_storage.save(obj)
|
||||
self._set_cache(name, obj)
|
||||
self._on_saved(name, obj)
|
||||
return name
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
self._underlying_storage.delete(name)
|
||||
if name in self._cache:
|
||||
del self._cache[name]
|
||||
self._on_deleted(name)
|
||||
|
||||
def _get_cache(self, name: str) -> Optional[T]:
|
||||
return None if name not in self._cache else self._cache[name]
|
||||
|
||||
def _set_cache(self, name: str, data: T):
|
||||
if name not in self._cache:
|
||||
self._cache[name] = data
|
||||
self._cache_ids.put(name)
|
||||
if self._cache_ids.qsize() > self._max_cache_size:
|
||||
self._cache.pop(self._cache_ids.get())
|
@ -223,16 +223,16 @@ class TensorsInterface(InvocationContextInterface):
|
||||
:param tensor: The tensor to save.
|
||||
"""
|
||||
|
||||
name = self._services.tensors.set(item=tensor)
|
||||
return name
|
||||
tensor_id = self._services.tensors.save(obj=tensor)
|
||||
return tensor_id
|
||||
|
||||
def get(self, tensor_name: str) -> Tensor:
|
||||
def load(self, name: str) -> Tensor:
|
||||
"""
|
||||
Gets a tensor by name.
|
||||
Loads a tensor by name.
|
||||
|
||||
:param tensor_name: The name of the tensor to get.
|
||||
:param name: The name of the tensor to load.
|
||||
"""
|
||||
return self._services.tensors.get(tensor_name)
|
||||
return self._services.tensors.load(name)
|
||||
|
||||
|
||||
class ConditioningInterface(InvocationContextInterface):
|
||||
@ -243,17 +243,17 @@ class ConditioningInterface(InvocationContextInterface):
|
||||
:param conditioning_context_data: The conditioning data to save.
|
||||
"""
|
||||
|
||||
name = self._services.conditioning.set(item=conditioning_data)
|
||||
return name
|
||||
conditioning_id = self._services.conditioning.save(obj=conditioning_data)
|
||||
return conditioning_id
|
||||
|
||||
def get(self, conditioning_name: str) -> ConditioningFieldData:
|
||||
def load(self, name: str) -> ConditioningFieldData:
|
||||
"""
|
||||
Gets conditioning data by name.
|
||||
Loads conditioning data by name.
|
||||
|
||||
:param conditioning_name: The name of the conditioning data to get.
|
||||
:param name: The name of the conditioning data to load.
|
||||
"""
|
||||
|
||||
return self._services.conditioning.get(conditioning_name)
|
||||
return self._services.conditioning.load(name)
|
||||
|
||||
|
||||
class ModelsInterface(InvocationContextInterface):
|
||||
|
Loading…
Reference in New Issue
Block a user