feat(nodes): use ItemStorageABC for tensors and conditioning

Turns out `ItemStorageABC` was almost identical to `PickleStorageBase`. Instead of maintaining separate classes, we can use `ItemStorageABC` for both.

There's only one change needed - the `ItemStorageABC.set` method must return the newly stored item's ID. This allows us to let the service handle the responsibility of naming the item, but still create the requisite output objects during node execution.

The naming implementation is improved here. It extracts the name of the generic and appends a UUID to that string when saving items.
This commit is contained in:
psychedelicious
2024-02-07 19:39:03 +11:00
parent ca09bd63a3
commit a50c7c1cd7
10 changed files with 145 additions and 204 deletions

View File

@ -4,9 +4,9 @@ from logging import Logger
import torch
from invokeai.app.services.item_storage.item_storage_ephemeral_disk import ItemStorageEphemeralDisk
from invokeai.app.services.item_storage.item_storage_forward_cache import ItemStorageForwardCache
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
from invokeai.app.services.pickle_storage.pickle_storage_forward_cache import PickleStorageForwardCache
from invokeai.app.services.pickle_storage.pickle_storage_torch import PickleStorageTorch
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
@ -90,9 +90,9 @@ class ApiDependencies:
image_records = SqliteImageRecordStorage(db=db)
images = ImageService()
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
tensors = PickleStorageForwardCache(PickleStorageTorch[torch.Tensor](output_folder / "tensors", "tensor"))
conditioning = PickleStorageForwardCache(
PickleStorageTorch[ConditioningFieldData](output_folder / "conditioning", "conditioning")
tensors = ItemStorageForwardCache(ItemStorageEphemeralDisk[torch.Tensor](output_folder / "tensors"))
conditioning = ItemStorageForwardCache(
ItemStorageEphemeralDisk[ConditioningFieldData](output_folder / "conditioning")
)
model_manager = ModelManagerService(config, logger)
model_record_service = ModelRecordServiceSQL(db=db)