mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): replace latents service with tensors and conditioning services
- New generic class `PickleStorageBase`, implements the same API as `LatentsStorageBase`, use for storing non-serializable data via pickling - Implementation `PickleStorageTorch` uses `torch.save` and `torch.load`, same as `LatentsStorageDisk` - Add `tensors: PickleStorageBase[torch.Tensor]` to `InvocationServices` - Add `conditioning: PickleStorageBase[ConditioningFieldData]` to `InvocationServices` - Remove `latents` service and all `LatentsStorage` classes - Update `InvocationContext` and all usage of old `latents` service to use the new services/context wrapper methods
This commit is contained in:
parent
31db62ba99
commit
0710fb3fb0
@ -2,9 +2,14 @@
|
||||
|
||||
from logging import Logger
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
||||
from invokeai.app.services.pickle_storage.pickle_storage_forward_cache import PickleStorageForwardCache
|
||||
from invokeai.app.services.pickle_storage.pickle_storage_torch import PickleStorageTorch
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
@ -23,8 +28,6 @@ from ..services.invocation_queue.invocation_queue_memory import MemoryInvocation
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
|
||||
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
|
||||
from ..services.model_install import ModelInstallService
|
||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||
from ..services.model_records import ModelRecordServiceSQL
|
||||
@ -68,6 +71,9 @@ class ApiDependencies:
|
||||
logger.debug(f"Internet connectivity is {config.internet_available}")
|
||||
|
||||
output_folder = config.output_path
|
||||
if output_folder is None:
|
||||
raise ValueError("Output folder is not set")
|
||||
|
||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||
|
||||
db = init_db(config=config, logger=logger, image_files=image_files)
|
||||
@ -84,7 +90,10 @@ class ApiDependencies:
|
||||
image_records = SqliteImageRecordStorage(db=db)
|
||||
images = ImageService()
|
||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||
tensors = PickleStorageForwardCache(PickleStorageTorch[torch.Tensor](output_folder / "tensors", "tensor"))
|
||||
conditioning = PickleStorageForwardCache(
|
||||
PickleStorageTorch[ConditioningFieldData](output_folder / "conditioning", "conditioning")
|
||||
)
|
||||
model_manager = ModelManagerService(config, logger)
|
||||
model_record_service = ModelRecordServiceSQL(db=db)
|
||||
download_queue_service = DownloadQueueService(event_bus=events)
|
||||
@ -117,7 +126,6 @@ class ApiDependencies:
|
||||
image_records=image_records,
|
||||
images=images,
|
||||
invocation_cache=invocation_cache,
|
||||
latents=latents,
|
||||
logger=logger,
|
||||
model_manager=model_manager,
|
||||
model_records=model_record_service,
|
||||
@ -131,6 +139,8 @@ class ApiDependencies:
|
||||
session_queue=session_queue,
|
||||
urls=urls,
|
||||
workflow_records=workflow_records,
|
||||
tensors=tensors,
|
||||
conditioning=conditioning,
|
||||
)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
@ -163,11 +163,11 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
# TODO:
|
||||
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
|
||||
|
||||
masked_latents_name = context.latents.save(tensor=masked_latents)
|
||||
masked_latents_name = context.tensors.save(tensor=masked_latents)
|
||||
else:
|
||||
masked_latents_name = None
|
||||
|
||||
mask_name = context.latents.save(tensor=mask)
|
||||
mask_name = context.tensors.save(tensor=mask)
|
||||
|
||||
return DenoiseMaskOutput.build(
|
||||
mask_name=mask_name,
|
||||
@ -621,10 +621,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
if self.denoise_mask is None:
|
||||
return None, None
|
||||
|
||||
mask = context.latents.get(self.denoise_mask.mask_name)
|
||||
mask = context.tensors.get(self.denoise_mask.mask_name)
|
||||
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
if self.denoise_mask.masked_latents_name is not None:
|
||||
masked_latents = context.latents.get(self.denoise_mask.masked_latents_name)
|
||||
masked_latents = context.tensors.get(self.denoise_mask.masked_latents_name)
|
||||
else:
|
||||
masked_latents = None
|
||||
|
||||
@ -636,11 +636,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
seed = None
|
||||
noise = None
|
||||
if self.noise is not None:
|
||||
noise = context.latents.get(self.noise.latents_name)
|
||||
noise = context.tensors.get(self.noise.latents_name)
|
||||
seed = self.noise.seed
|
||||
|
||||
if self.latents is not None:
|
||||
latents = context.latents.get(self.latents.latents_name)
|
||||
latents = context.tensors.get(self.latents.latents_name)
|
||||
if seed is None:
|
||||
seed = self.latents.seed
|
||||
|
||||
@ -752,7 +752,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
name = context.latents.save(tensor=result_latents)
|
||||
name = context.tensors.save(tensor=result_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=seed)
|
||||
|
||||
|
||||
@ -779,7 +779,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.latents.get(self.latents.latents_name)
|
||||
latents = context.tensors.get(self.latents.latents_name)
|
||||
|
||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||
|
||||
@ -870,7 +870,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.latents.get(self.latents.latents_name)
|
||||
latents = context.tensors.get(self.latents.latents_name)
|
||||
|
||||
# TODO:
|
||||
device = choose_torch_device()
|
||||
@ -888,7 +888,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
if device == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
name = context.latents.save(tensor=resized_latents)
|
||||
name = context.tensors.save(tensor=resized_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
|
||||
|
||||
@ -911,7 +911,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.latents.get(self.latents.latents_name)
|
||||
latents = context.tensors.get(self.latents.latents_name)
|
||||
|
||||
# TODO:
|
||||
device = choose_torch_device()
|
||||
@ -930,7 +930,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
if device == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
name = context.latents.save(tensor=resized_latents)
|
||||
name = context.tensors.save(tensor=resized_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
|
||||
|
||||
@ -1011,7 +1011,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
name = context.latents.save(tensor=latents)
|
||||
name = context.tensors.save(tensor=latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
@singledispatchmethod
|
||||
@ -1048,8 +1048,8 @@ class BlendLatentsInvocation(BaseInvocation):
|
||||
alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents_a = context.latents.get(self.latents_a.latents_name)
|
||||
latents_b = context.latents.get(self.latents_b.latents_name)
|
||||
latents_a = context.tensors.get(self.latents_a.latents_name)
|
||||
latents_b = context.tensors.get(self.latents_b.latents_name)
|
||||
|
||||
if latents_a.shape != latents_b.shape:
|
||||
raise Exception("Latents to blend must be the same size.")
|
||||
@ -1103,7 +1103,7 @@ class BlendLatentsInvocation(BaseInvocation):
|
||||
if device == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
name = context.latents.save(tensor=blended_latents)
|
||||
name = context.tensors.save(tensor=blended_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=blended_latents)
|
||||
|
||||
|
||||
@ -1149,7 +1149,7 @@ class CropLatentsCoreInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.latents.get(self.latents.latents_name)
|
||||
latents = context.tensors.get(self.latents.latents_name)
|
||||
|
||||
x1 = self.x // LATENT_SCALE_FACTOR
|
||||
y1 = self.y // LATENT_SCALE_FACTOR
|
||||
@ -1158,7 +1158,7 @@ class CropLatentsCoreInvocation(BaseInvocation):
|
||||
|
||||
cropped_latents = latents[..., y1:y2, x1:x2]
|
||||
|
||||
name = context.latents.save(tensor=cropped_latents)
|
||||
name = context.tensors.save(tensor=cropped_latents)
|
||||
|
||||
return LatentsOutput.build(latents_name=name, latents=cropped_latents)
|
||||
|
||||
|
@ -121,5 +121,5 @@ class NoiseInvocation(BaseInvocation):
|
||||
seed=self.seed,
|
||||
use_cpu=self.use_cpu,
|
||||
)
|
||||
name = context.latents.save(tensor=noise)
|
||||
name = context.tensors.save(tensor=noise)
|
||||
return NoiseOutput.build(latents_name=name, latents=noise, seed=self.seed)
|
||||
|
@ -313,9 +313,7 @@ class DenoiseMaskOutput(BaseInvocationOutput):
|
||||
class LatentsOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single latents tensor"""
|
||||
|
||||
latents: LatentsField = OutputField(
|
||||
description=FieldDescriptions.latents,
|
||||
)
|
||||
latents: LatentsField = OutputField(description=FieldDescriptions.latents)
|
||||
width: int = OutputField(description=FieldDescriptions.width)
|
||||
height: int = OutputField(description=FieldDescriptions.height)
|
||||
|
||||
@ -346,7 +344,7 @@ class LatentsInvocation(BaseInvocation):
|
||||
latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.latents.get(self.latents.latents_name)
|
||||
latents = context.tensors.get(self.latents.latents_name)
|
||||
|
||||
return LatentsOutput.build(self.latents.latents_name, latents)
|
||||
|
||||
|
@ -37,7 +37,8 @@ class MemoryInvocationCache(InvocationCacheBase):
|
||||
if self._max_cache_size == 0:
|
||||
return
|
||||
self._invoker.services.images.on_deleted(self._delete_by_match)
|
||||
self._invoker.services.latents.on_deleted(self._delete_by_match)
|
||||
self._invoker.services.tensors.on_deleted(self._delete_by_match)
|
||||
self._invoker.services.conditioning.on_deleted(self._delete_by_match)
|
||||
|
||||
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
|
||||
with self._lock:
|
||||
|
@ -6,6 +6,10 @@ from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
|
||||
from .board_image_records.board_image_records_base import BoardImageRecordStorageBase
|
||||
from .board_images.board_images_base import BoardImagesServiceABC
|
||||
from .board_records.board_records_base import BoardRecordStorageBase
|
||||
@ -21,11 +25,11 @@ if TYPE_CHECKING:
|
||||
from .invocation_queue.invocation_queue_base import InvocationQueueABC
|
||||
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
|
||||
from .item_storage.item_storage_base import ItemStorageABC
|
||||
from .latents_storage.latents_storage_base import LatentsStorageBase
|
||||
from .model_install import ModelInstallServiceBase
|
||||
from .model_manager.model_manager_base import ModelManagerServiceBase
|
||||
from .model_records import ModelRecordServiceBase
|
||||
from .names.names_base import NameServiceBase
|
||||
from .pickle_storage.pickle_storage_base import PickleStorageBase
|
||||
from .session_processor.session_processor_base import SessionProcessorBase
|
||||
from .session_queue.session_queue_base import SessionQueueBase
|
||||
from .shared.graph import GraphExecutionState
|
||||
@ -48,7 +52,6 @@ class InvocationServices:
|
||||
images: "ImageServiceABC",
|
||||
image_files: "ImageFileStorageBase",
|
||||
image_records: "ImageRecordStorageBase",
|
||||
latents: "LatentsStorageBase",
|
||||
logger: "Logger",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
model_records: "ModelRecordServiceBase",
|
||||
@ -63,6 +66,8 @@ class InvocationServices:
|
||||
names: "NameServiceBase",
|
||||
urls: "UrlServiceBase",
|
||||
workflow_records: "WorkflowRecordsStorageBase",
|
||||
tensors: "PickleStorageBase[torch.Tensor]",
|
||||
conditioning: "PickleStorageBase[ConditioningFieldData]",
|
||||
):
|
||||
self.board_images = board_images
|
||||
self.board_image_records = board_image_records
|
||||
@ -74,7 +79,6 @@ class InvocationServices:
|
||||
self.images = images
|
||||
self.image_files = image_files
|
||||
self.image_records = image_records
|
||||
self.latents = latents
|
||||
self.logger = logger
|
||||
self.model_manager = model_manager
|
||||
self.model_records = model_records
|
||||
@ -89,3 +93,5 @@ class InvocationServices:
|
||||
self.names = names
|
||||
self.urls = urls
|
||||
self.workflow_records = workflow_records
|
||||
self.tensors = tensors
|
||||
self.conditioning = conditioning
|
||||
|
@ -1,58 +0,0 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
from .latents_storage_base import LatentsStorageBase
|
||||
|
||||
|
||||
class DiskLatentsStorage(LatentsStorageBase):
|
||||
"""Stores latents in a folder on disk without caching"""
|
||||
|
||||
__output_folder: Path
|
||||
|
||||
def __init__(self, output_folder: Union[str, Path]):
|
||||
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
self._delete_all_latents()
|
||||
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
latent_path = self.get_path(name)
|
||||
return torch.load(latent_path)
|
||||
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
||||
latent_path = self.get_path(name)
|
||||
torch.save(data, latent_path)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
latent_path = self.get_path(name)
|
||||
latent_path.unlink()
|
||||
|
||||
def get_path(self, name: str) -> Path:
|
||||
return self.__output_folder / name
|
||||
|
||||
def _delete_all_latents(self) -> None:
|
||||
"""
|
||||
Deletes all latents from disk.
|
||||
Must be called after we have access to `self._invoker` (e.g. in `start()`).
|
||||
"""
|
||||
deleted_latents_count = 0
|
||||
freed_space = 0
|
||||
for latents_file in Path(self.__output_folder).glob("*"):
|
||||
if latents_file.is_file():
|
||||
freed_space += latents_file.stat().st_size
|
||||
deleted_latents_count += 1
|
||||
latents_file.unlink()
|
||||
if deleted_latents_count > 0:
|
||||
freed_space_in_mb = round(freed_space / 1024 / 1024, 2)
|
||||
self._invoker.services.logger.info(
|
||||
f"Deleted {deleted_latents_count} latents files (freed {freed_space_in_mb}MB)"
|
||||
)
|
@ -1,68 +0,0 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from queue import Queue
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
from .latents_storage_base import LatentsStorageBase
|
||||
|
||||
|
||||
class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
|
||||
|
||||
__cache: Dict[str, torch.Tensor]
|
||||
__cache_ids: Queue
|
||||
__max_cache_size: int
|
||||
__underlying_storage: LatentsStorageBase
|
||||
|
||||
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
|
||||
super().__init__()
|
||||
self.__underlying_storage = underlying_storage
|
||||
self.__cache = {}
|
||||
self.__cache_ids = Queue()
|
||||
self.__max_cache_size = max_cache_size
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
start_op = getattr(self.__underlying_storage, "start", None)
|
||||
if callable(start_op):
|
||||
start_op(invoker)
|
||||
|
||||
def stop(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
stop_op = getattr(self.__underlying_storage, "stop", None)
|
||||
if callable(stop_op):
|
||||
stop_op(invoker)
|
||||
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
cache_item = self.__get_cache(name)
|
||||
if cache_item is not None:
|
||||
return cache_item
|
||||
|
||||
latent = self.__underlying_storage.get(name)
|
||||
self.__set_cache(name, latent)
|
||||
return latent
|
||||
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
self.__underlying_storage.save(name, data)
|
||||
self.__set_cache(name, data)
|
||||
self._on_changed(data)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
self.__underlying_storage.delete(name)
|
||||
if name in self.__cache:
|
||||
del self.__cache[name]
|
||||
self._on_deleted(name)
|
||||
|
||||
def __get_cache(self, name: str) -> Optional[torch.Tensor]:
|
||||
return None if name not in self.__cache else self.__cache[name]
|
||||
|
||||
def __set_cache(self, name: str, data: torch.Tensor):
|
||||
if name not in self.__cache:
|
||||
self.__cache[name] = data
|
||||
self.__cache_ids.put(name)
|
||||
if self.__cache_ids.qsize() > self.__max_cache_size:
|
||||
self.__cache.pop(self.__cache_ids.get())
|
@ -1,15 +1,15 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable
|
||||
from typing import Callable, Generic, TypeVar
|
||||
|
||||
import torch
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class LatentsStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving latents."""
|
||||
class PickleStorageBase(ABC, Generic[T]):
|
||||
"""Responsible for storing and retrieving non-serializable data using a pickler."""
|
||||
|
||||
_on_changed_callbacks: list[Callable[[torch.Tensor], None]]
|
||||
_on_changed_callbacks: list[Callable[[T], None]]
|
||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
@ -17,18 +17,18 @@ class LatentsStorageBase(ABC):
|
||||
self._on_deleted_callbacks = []
|
||||
|
||||
@abstractmethod
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
def get(self, name: str) -> T:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
def save(self, name: str, data: T) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, name: str) -> None:
|
||||
pass
|
||||
|
||||
def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None:
|
||||
def on_changed(self, on_changed: Callable[[T], None]) -> None:
|
||||
"""Register a callback for when an item is changed"""
|
||||
self._on_changed_callbacks.append(on_changed)
|
||||
|
||||
@ -36,7 +36,7 @@ class LatentsStorageBase(ABC):
|
||||
"""Register a callback for when an item is deleted"""
|
||||
self._on_deleted_callbacks.append(on_deleted)
|
||||
|
||||
def _on_changed(self, item: torch.Tensor) -> None:
|
||||
def _on_changed(self, item: T) -> None:
|
||||
for callback in self._on_changed_callbacks:
|
||||
callback(item)
|
||||
|
@ -0,0 +1,58 @@
|
||||
from queue import Queue
|
||||
from typing import Optional, TypeVar
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class PickleStorageForwardCache(PickleStorageBase[T]):
|
||||
def __init__(self, underlying_storage: PickleStorageBase[T], max_cache_size: int = 20):
|
||||
super().__init__()
|
||||
self._underlying_storage = underlying_storage
|
||||
self._cache: dict[str, T] = {}
|
||||
self._cache_ids = Queue[str]()
|
||||
self._max_cache_size = max_cache_size
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
start_op = getattr(self._underlying_storage, "start", None)
|
||||
if callable(start_op):
|
||||
start_op(invoker)
|
||||
|
||||
def stop(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
stop_op = getattr(self._underlying_storage, "stop", None)
|
||||
if callable(stop_op):
|
||||
stop_op(invoker)
|
||||
|
||||
def get(self, name: str) -> T:
|
||||
cache_item = self._get_cache(name)
|
||||
if cache_item is not None:
|
||||
return cache_item
|
||||
|
||||
latent = self._underlying_storage.get(name)
|
||||
self._set_cache(name, latent)
|
||||
return latent
|
||||
|
||||
def save(self, name: str, data: T) -> None:
|
||||
self._underlying_storage.save(name, data)
|
||||
self._set_cache(name, data)
|
||||
self._on_changed(data)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
self._underlying_storage.delete(name)
|
||||
if name in self._cache:
|
||||
del self._cache[name]
|
||||
self._on_deleted(name)
|
||||
|
||||
def _get_cache(self, name: str) -> Optional[T]:
|
||||
return None if name not in self._cache else self._cache[name]
|
||||
|
||||
def _set_cache(self, name: str, data: T):
|
||||
if name not in self._cache:
|
||||
self._cache[name] = data
|
||||
self._cache_ids.put(name)
|
||||
if self._cache_ids.qsize() > self._max_cache_size:
|
||||
self._cache.pop(self._cache_ids.get())
|
62
invokeai/app/services/pickle_storage/pickle_storage_torch.py
Normal file
62
invokeai/app/services/pickle_storage/pickle_storage_torch.py
Normal file
@ -0,0 +1,62 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class PickleStorageTorch(PickleStorageBase[T]):
|
||||
"""Responsible for storing and retrieving non-serializable data using `torch.save` and `torch.load`."""
|
||||
|
||||
def __init__(self, output_folder: Path, item_type_name: "str"):
|
||||
self._output_folder = output_folder
|
||||
self._output_folder.mkdir(parents=True, exist_ok=True)
|
||||
self._item_type_name = item_type_name
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
self._delete_all_items()
|
||||
|
||||
def get(self, name: str) -> T:
|
||||
latent_path = self._get_path(name)
|
||||
return torch.load(latent_path)
|
||||
|
||||
def save(self, name: str, data: T) -> None:
|
||||
self._output_folder.mkdir(parents=True, exist_ok=True)
|
||||
latent_path = self._get_path(name)
|
||||
torch.save(data, latent_path)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
latent_path = self._get_path(name)
|
||||
latent_path.unlink()
|
||||
|
||||
def _get_path(self, name: str) -> Path:
|
||||
return self._output_folder / name
|
||||
|
||||
def _delete_all_items(self) -> None:
|
||||
"""
|
||||
Deletes all pickled items from disk.
|
||||
Must be called after we have access to `self._invoker` (e.g. in `start()`).
|
||||
"""
|
||||
|
||||
if not self._invoker:
|
||||
raise ValueError("Invoker is not set. Must call `start()` first.")
|
||||
|
||||
deleted_latents_count = 0
|
||||
freed_space = 0
|
||||
for latents_file in Path(self._output_folder).glob("*"):
|
||||
if latents_file.is_file():
|
||||
freed_space += latents_file.stat().st_size
|
||||
deleted_latents_count += 1
|
||||
latents_file.unlink()
|
||||
if deleted_latents_count > 0:
|
||||
freed_space_in_mb = round(freed_space / 1024 / 1024, 2)
|
||||
self._invoker.services.logger.info(
|
||||
f"Deleted {deleted_latents_count} {self._item_type_name} files (freed {freed_space_in_mb}MB)"
|
||||
)
|
@ -216,48 +216,46 @@ class ImagesInterface(InvocationContextInterface):
|
||||
return self._services.images.get_dto(image_name)
|
||||
|
||||
|
||||
class LatentsInterface(InvocationContextInterface):
|
||||
class TensorsInterface(InvocationContextInterface):
|
||||
def save(self, tensor: Tensor) -> str:
|
||||
"""
|
||||
Saves a latents tensor, returning its name.
|
||||
Saves a tensor, returning its name.
|
||||
|
||||
:param tensor: The latents tensor to save.
|
||||
:param tensor: The tensor to save.
|
||||
"""
|
||||
|
||||
# Previously, we added a suffix indicating the type of Tensor we were saving, e.g.
|
||||
# "mask", "noise", "masked_latents", etc.
|
||||
#
|
||||
# Retaining that capability in this wrapper would require either many different methods
|
||||
# to save latents, or extra args for this method. Instead of complicating the API, we
|
||||
# will use the same naming scheme for all latents.
|
||||
# to save tensors, or extra args for this method. Instead of complicating the API, we
|
||||
# will use the same naming scheme for all tensors.
|
||||
#
|
||||
# This has a very minor impact as we don't use them after a session completes.
|
||||
|
||||
# Previously, invocations chose the name for their latents. This is a bit risky, so we
|
||||
# Previously, invocations chose the name for their tensors. This is a bit risky, so we
|
||||
# will generate a name for them instead. We use a uuid to ensure the name is unique.
|
||||
#
|
||||
# Because the name of the latents file will includes the session and invocation IDs,
|
||||
# Because the name of the tensors file will includes the session and invocation IDs,
|
||||
# we don't need to worry about collisions. A truncated UUIDv4 is fine.
|
||||
|
||||
name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}"
|
||||
self._services.latents.save(
|
||||
self._services.tensors.save(
|
||||
name=name,
|
||||
data=tensor,
|
||||
)
|
||||
return name
|
||||
|
||||
def get(self, latents_name: str) -> Tensor:
|
||||
def get(self, tensor_name: str) -> Tensor:
|
||||
"""
|
||||
Gets a latents tensor by name.
|
||||
Gets a tensor by name.
|
||||
|
||||
:param latents_name: The name of the latents tensor to get.
|
||||
:param tensor_name: The name of the tensor to get.
|
||||
"""
|
||||
return self._services.latents.get(latents_name)
|
||||
return self._services.tensors.get(tensor_name)
|
||||
|
||||
|
||||
class ConditioningInterface(InvocationContextInterface):
|
||||
# TODO(psyche): We are (ab)using the latents storage service as a general pickle storage
|
||||
# service, but it is typed to work with Tensors only. We have to fudge the types here.
|
||||
def save(self, conditioning_data: ConditioningFieldData) -> str:
|
||||
"""
|
||||
Saves a conditioning data object, returning its name.
|
||||
@ -265,15 +263,12 @@ class ConditioningInterface(InvocationContextInterface):
|
||||
:param conditioning_context_data: The conditioning data to save.
|
||||
"""
|
||||
|
||||
# Conditioning data is *not* a Tensor, so we will suffix it to indicate this.
|
||||
#
|
||||
# See comment for `LatentsInterface.save` for more info about this method (it's very
|
||||
# similar).
|
||||
# See comment in TensorsInterface.save for why we generate the name here.
|
||||
|
||||
name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}__conditioning"
|
||||
self._services.latents.save(
|
||||
name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}"
|
||||
self._services.conditioning.save(
|
||||
name=name,
|
||||
data=conditioning_data, # type: ignore [arg-type]
|
||||
data=conditioning_data,
|
||||
)
|
||||
return name
|
||||
|
||||
@ -284,7 +279,7 @@ class ConditioningInterface(InvocationContextInterface):
|
||||
:param conditioning_name: The name of the conditioning data to get.
|
||||
"""
|
||||
|
||||
return self._services.latents.get(conditioning_name) # type: ignore [return-value]
|
||||
return self._services.conditioning.get(conditioning_name)
|
||||
|
||||
|
||||
class ModelsInterface(InvocationContextInterface):
|
||||
@ -400,7 +395,7 @@ class InvocationContext:
|
||||
def __init__(
|
||||
self,
|
||||
images: ImagesInterface,
|
||||
latents: LatentsInterface,
|
||||
tensors: TensorsInterface,
|
||||
conditioning: ConditioningInterface,
|
||||
models: ModelsInterface,
|
||||
logger: LoggerInterface,
|
||||
@ -412,8 +407,8 @@ class InvocationContext:
|
||||
) -> None:
|
||||
self.images = images
|
||||
"""Provides methods to save, get and update images and their metadata."""
|
||||
self.latents = latents
|
||||
"""Provides methods to save and get latents tensors, including image, noise, masks, and masked images."""
|
||||
self.tensors = tensors
|
||||
"""Provides methods to save and get tensors, including image, noise, masks, and masked images."""
|
||||
self.conditioning = conditioning
|
||||
"""Provides methods to save and get conditioning data."""
|
||||
self.models = models
|
||||
@ -532,7 +527,7 @@ def build_invocation_context(
|
||||
|
||||
logger = LoggerInterface(services=services, context_data=context_data)
|
||||
images = ImagesInterface(services=services, context_data=context_data)
|
||||
latents = LatentsInterface(services=services, context_data=context_data)
|
||||
tensors = TensorsInterface(services=services, context_data=context_data)
|
||||
models = ModelsInterface(services=services, context_data=context_data)
|
||||
config = ConfigInterface(services=services, context_data=context_data)
|
||||
util = UtilInterface(services=services, context_data=context_data)
|
||||
@ -543,7 +538,7 @@ def build_invocation_context(
|
||||
images=images,
|
||||
logger=logger,
|
||||
config=config,
|
||||
latents=latents,
|
||||
tensors=tensors,
|
||||
models=models,
|
||||
context_data=context_data,
|
||||
util=util,
|
||||
|
Loading…
Reference in New Issue
Block a user