feat(nodes): use ItemStorageABC for tensors and conditioning

Turns out `ItemStorageABC` was almost identical to `PickleStorageBase`. Instead of maintaining separate classes, we can use `ItemStorageABC` for both.

There's only one change needed - the `ItemStorageABC.set` method must return the newly stored item's ID. This allows us to let the service handle the responsibility of naming the item, but still create the requisite output objects during node execution.

The naming implementation is improved here. It extracts the name of the generic and appends a UUID to that string when saving items.
This commit is contained in:
psychedelicious
2024-02-07 19:39:03 +11:00
parent ca09bd63a3
commit a50c7c1cd7
10 changed files with 145 additions and 204 deletions

View File

@ -12,7 +12,6 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.misc import uuid_string
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.model_manager import ModelInfo
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
@ -224,26 +223,7 @@ class TensorsInterface(InvocationContextInterface):
:param tensor: The tensor to save.
"""
# Previously, we added a suffix indicating the type of Tensor we were saving, e.g.
# "mask", "noise", "masked_latents", etc.
#
# Retaining that capability in this wrapper would require either many different methods
# to save tensors, or extra args for this method. Instead of complicating the API, we
# will use the same naming scheme for all tensors.
#
# This has a very minor impact as we don't use them after a session completes.
# Previously, invocations chose the name for their tensors. This is a bit risky, so we
# will generate a name for them instead. We use a uuid to ensure the name is unique.
#
# Because the name of the tensors file will includes the session and invocation IDs,
# we don't need to worry about collisions. A truncated UUIDv4 is fine.
name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}"
self._services.tensors.save(
name=name,
data=tensor,
)
name = self._services.tensors.set(item=tensor)
return name
def get(self, tensor_name: str) -> Tensor:
@ -263,13 +243,7 @@ class ConditioningInterface(InvocationContextInterface):
:param conditioning_context_data: The conditioning data to save.
"""
# See comment in TensorsInterface.save for why we generate the name here.
name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}"
self._services.conditioning.save(
name=name,
data=conditioning_data,
)
name = self._services.conditioning.set(item=conditioning_data)
return name
def get(self, conditioning_name: str) -> ConditioningFieldData: