from typing import Any, Union

import numpy as np
import numpy.typing as npt
import torch

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "lblend",
    title="Blend Latents",
    tags=["latents", "blend"],
    category="latents",
    version="1.0.3",
)
class BlendLatentsInvocation(BaseInvocation):
    """Blend two latents using a given alpha. Latents must have the same size."""

    latents_a: LatentsField = InputField(
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )
    latents_b: LatentsField = InputField(
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )
    alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents_a = context.tensors.load(self.latents_a.latents_name)
        latents_b = context.tensors.load(self.latents_b.latents_name)

        if latents_a.shape != latents_b.shape:
            raise Exception("Latents to blend must be the same size.")

        device = TorchDevice.choose_torch_device()

        def slerp(
            t: Union[float, npt.NDArray[Any]],  # FIXME: maybe use np.float32 here?
            v0: Union[torch.Tensor, npt.NDArray[Any]],
            v1: Union[torch.Tensor, npt.NDArray[Any]],
            DOT_THRESHOLD: float = 0.9995,
        ) -> Union[torch.Tensor, npt.NDArray[Any]]:
            """
            Spherical linear interpolation

            Args:
                t (float/np.ndarray): Float value between 0.0 and 1.0
                v0 (np.ndarray): Starting vector
                v1 (np.ndarray): Final vector
                DOT_THRESHOLD (float): Threshold for considering the two vectors as
                                       collinear. Not recommended to alter this.

            Returns:
                v2 (np.ndarray): Interpolation vector between v0 and v1
            """
            # Normalize torch inputs to numpy; remember to convert back on return.
            inputs_are_torch = False
            if not isinstance(v0, np.ndarray):
                inputs_are_torch = True
                v0 = v0.detach().cpu().numpy()
            if not isinstance(v1, np.ndarray):
                inputs_are_torch = True
                v1 = v1.detach().cpu().numpy()

            dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
            if np.abs(dot) > DOT_THRESHOLD:
                # Vectors are nearly collinear; fall back to plain linear
                # interpolation to avoid numerical instability in arccos.
                v2 = (1 - t) * v0 + t * v1
            else:
                theta_0 = np.arccos(dot)
                sin_theta_0 = np.sin(theta_0)
                theta_t = theta_0 * t
                sin_theta_t = np.sin(theta_t)
                s0 = np.sin(theta_0 - theta_t) / sin_theta_0
                s1 = sin_theta_t / sin_theta_0
                v2 = s0 * v0 + s1 * v1

            if inputs_are_torch:
                v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
                return v2_torch
            else:
                assert isinstance(v2, np.ndarray)
                return v2

        # blend
        bl = slerp(self.alpha, latents_a, latents_b)
        assert isinstance(bl, torch.Tensor)
        blended_latents: torch.Tensor = bl  # for type checking convenience

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        blended_latents = blended_latents.to("cpu")

        TorchDevice.empty_cache()

        name = context.tensors.save(tensor=blended_latents)
        return LatentsOutput.build(latents_name=name, latents=blended_latents, seed=self.latents_a.seed)
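

# A minimal, self-contained sketch of the spherical interpolation used above,
# included purely for illustration: `_slerp_demo` is not part of the InvokeAI
# API, and the vectors and alpha below are made-up values. It walks through the
# non-collinear branch of slerp() on two orthogonal 2-D unit vectors, where the
# t=0.5 result should be the unit vector halfway between them.
def _slerp_demo() -> None:
    v0 = np.array([1.0, 0.0], dtype=np.float32)  # starting vector
    v1 = np.array([0.0, 1.0], dtype=np.float32)  # final vector
    t = 0.5  # blend alpha

    # Cosine of the angle between v0 and v1 (0.0 here, so theta_0 = pi/2).
    dot = float(np.dot(v0, v1) / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    theta_0 = np.arccos(dot)
    theta_t = theta_0 * t  # interpolated angle
    s0 = np.sin(theta_0 - theta_t) / np.sin(theta_0)
    s1 = np.sin(theta_t) / np.sin(theta_0)
    v2 = s0 * v0 + s1 * v1

    # Prints approximately [0.7071 0.7071]: the midpoint on the unit circle,
    # which still has unit norm (unlike the plain linear average [0.5 0.5]).
    print(v2)


if __name__ == "__main__":
    _slerp_demo()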