mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): use ItemStorageABC
for tensors and conditioning
Turns out `ItemStorageABC` was almost identical to `PickleStorageBase`. Instead of maintaining separate classes, we can use `ItemStorageABC` for both. There's only one change needed - the `ItemStorageABC.set` method must return the newly stored item's ID. This allows us to let the service handle the responsibility of naming the item, but still create the requisite output objects during node execution. The naming implementation is improved here. It extracts the name of the generic and appends a UUID to that string when saving items.
This commit is contained in:
@ -12,7 +12,6 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_management.model_manager import ModelInfo
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
||||
@ -224,26 +223,7 @@ class TensorsInterface(InvocationContextInterface):
|
||||
:param tensor: The tensor to save.
|
||||
"""
|
||||
|
||||
# Previously, we added a suffix indicating the type of Tensor we were saving, e.g.
|
||||
# "mask", "noise", "masked_latents", etc.
|
||||
#
|
||||
# Retaining that capability in this wrapper would require either many different methods
|
||||
# to save tensors, or extra args for this method. Instead of complicating the API, we
|
||||
# will use the same naming scheme for all tensors.
|
||||
#
|
||||
# This has a very minor impact as we don't use them after a session completes.
|
||||
|
||||
# Previously, invocations chose the name for their tensors. This is a bit risky, so we
|
||||
# will generate a name for them instead. We use a uuid to ensure the name is unique.
|
||||
#
|
||||
# Because the name of the tensors file will includes the session and invocation IDs,
|
||||
# we don't need to worry about collisions. A truncated UUIDv4 is fine.
|
||||
|
||||
name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}"
|
||||
self._services.tensors.save(
|
||||
name=name,
|
||||
data=tensor,
|
||||
)
|
||||
name = self._services.tensors.set(item=tensor)
|
||||
return name
|
||||
|
||||
def get(self, tensor_name: str) -> Tensor:
|
||||
@ -263,13 +243,7 @@ class ConditioningInterface(InvocationContextInterface):
|
||||
:param conditioning_context_data: The conditioning data to save.
|
||||
"""
|
||||
|
||||
# See comment in TensorsInterface.save for why we generate the name here.
|
||||
|
||||
name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}"
|
||||
self._services.conditioning.save(
|
||||
name=name,
|
||||
data=conditioning_data,
|
||||
)
|
||||
name = self._services.conditioning.set(item=conditioning_data)
|
||||
return name
|
||||
|
||||
def get(self, conditioning_name: str) -> ConditioningFieldData:
|
||||
|
Reference in New Issue
Block a user