mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Move BlendLatentsInvocation to its own file. No functional changes.
This commit is contained in:
parent
ed03d281e6
commit
595096bdcf
98
invokeai/app/invocations/blend_latents.py
Normal file
98
invokeai/app/invocations/blend_latents.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||||
|
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField
|
||||||
|
from invokeai.app.invocations.primitives import LatentsOutput
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
|
|
||||||
|
@invocation(
|
||||||
|
"lblend",
|
||||||
|
title="Blend Latents",
|
||||||
|
tags=["latents", "blend"],
|
||||||
|
category="latents",
|
||||||
|
version="1.0.3",
|
||||||
|
)
|
||||||
|
class BlendLatentsInvocation(BaseInvocation):
|
||||||
|
"""Blend two latents using a given alpha. Latents must have same size."""
|
||||||
|
|
||||||
|
latents_a: LatentsField = InputField(
|
||||||
|
description=FieldDescriptions.latents,
|
||||||
|
input=Input.Connection,
|
||||||
|
)
|
||||||
|
latents_b: LatentsField = InputField(
|
||||||
|
description=FieldDescriptions.latents,
|
||||||
|
input=Input.Connection,
|
||||||
|
)
|
||||||
|
alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
|
latents_a = context.tensors.load(self.latents_a.latents_name)
|
||||||
|
latents_b = context.tensors.load(self.latents_b.latents_name)
|
||||||
|
|
||||||
|
if latents_a.shape != latents_b.shape:
|
||||||
|
raise Exception("Latents to blend must be the same size.")
|
||||||
|
|
||||||
|
device = TorchDevice.choose_torch_device()
|
||||||
|
|
||||||
|
def slerp(
|
||||||
|
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
|
||||||
|
v0: Union[torch.Tensor, npt.NDArray[Any]],
|
||||||
|
v1: Union[torch.Tensor, npt.NDArray[Any]],
|
||||||
|
DOT_THRESHOLD: float = 0.9995,
|
||||||
|
) -> Union[torch.Tensor, npt.NDArray[Any]]:
|
||||||
|
"""
|
||||||
|
Spherical linear interpolation
|
||||||
|
Args:
|
||||||
|
t (float/np.ndarray): Float value between 0.0 and 1.0
|
||||||
|
v0 (np.ndarray): Starting vector
|
||||||
|
v1 (np.ndarray): Final vector
|
||||||
|
DOT_THRESHOLD (float): Threshold for considering the two vectors as
|
||||||
|
colineal. Not recommended to alter this.
|
||||||
|
Returns:
|
||||||
|
v2 (np.ndarray): Interpolation vector between v0 and v1
|
||||||
|
"""
|
||||||
|
inputs_are_torch = False
|
||||||
|
if not isinstance(v0, np.ndarray):
|
||||||
|
inputs_are_torch = True
|
||||||
|
v0 = v0.detach().cpu().numpy()
|
||||||
|
if not isinstance(v1, np.ndarray):
|
||||||
|
inputs_are_torch = True
|
||||||
|
v1 = v1.detach().cpu().numpy()
|
||||||
|
|
||||||
|
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
|
||||||
|
if np.abs(dot) > DOT_THRESHOLD:
|
||||||
|
v2 = (1 - t) * v0 + t * v1
|
||||||
|
else:
|
||||||
|
theta_0 = np.arccos(dot)
|
||||||
|
sin_theta_0 = np.sin(theta_0)
|
||||||
|
theta_t = theta_0 * t
|
||||||
|
sin_theta_t = np.sin(theta_t)
|
||||||
|
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
|
||||||
|
s1 = sin_theta_t / sin_theta_0
|
||||||
|
v2 = s0 * v0 + s1 * v1
|
||||||
|
|
||||||
|
if inputs_are_torch:
|
||||||
|
v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
|
||||||
|
return v2_torch
|
||||||
|
else:
|
||||||
|
assert isinstance(v2, np.ndarray)
|
||||||
|
return v2
|
||||||
|
|
||||||
|
# blend
|
||||||
|
bl = slerp(self.alpha, latents_a, latents_b)
|
||||||
|
assert isinstance(bl, torch.Tensor)
|
||||||
|
blended_latents: torch.Tensor = bl # for type checking convenience
|
||||||
|
|
||||||
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
|
blended_latents = blended_latents.to("cpu")
|
||||||
|
|
||||||
|
TorchDevice.empty_cache()
|
||||||
|
|
||||||
|
name = context.tensors.save(tensor=blended_latents)
|
||||||
|
return LatentsOutput.build(latents_name=name, latents=blended_latents, seed=self.latents_a.seed)
|
@ -6,7 +6,6 @@ from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union
|
|||||||
|
|
||||||
import einops
|
import einops
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
@ -1304,90 +1303,3 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
assert isinstance(vae, torch.nn.Module)
|
assert isinstance(vae, torch.nn.Module)
|
||||||
latents: torch.FloatTensor = vae.encode(image_tensor).latents
|
latents: torch.FloatTensor = vae.encode(image_tensor).latents
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
|
||||||
"lblend",
|
|
||||||
title="Blend Latents",
|
|
||||||
tags=["latents", "blend"],
|
|
||||||
category="latents",
|
|
||||||
version="1.0.3",
|
|
||||||
)
|
|
||||||
class BlendLatentsInvocation(BaseInvocation):
|
|
||||||
"""Blend two latents using a given alpha. Latents must have same size."""
|
|
||||||
|
|
||||||
latents_a: LatentsField = InputField(
|
|
||||||
description=FieldDescriptions.latents,
|
|
||||||
input=Input.Connection,
|
|
||||||
)
|
|
||||||
latents_b: LatentsField = InputField(
|
|
||||||
description=FieldDescriptions.latents,
|
|
||||||
input=Input.Connection,
|
|
||||||
)
|
|
||||||
alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
|
||||||
latents_a = context.tensors.load(self.latents_a.latents_name)
|
|
||||||
latents_b = context.tensors.load(self.latents_b.latents_name)
|
|
||||||
|
|
||||||
if latents_a.shape != latents_b.shape:
|
|
||||||
raise Exception("Latents to blend must be the same size.")
|
|
||||||
|
|
||||||
device = TorchDevice.choose_torch_device()
|
|
||||||
|
|
||||||
def slerp(
|
|
||||||
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
|
|
||||||
v0: Union[torch.Tensor, npt.NDArray[Any]],
|
|
||||||
v1: Union[torch.Tensor, npt.NDArray[Any]],
|
|
||||||
DOT_THRESHOLD: float = 0.9995,
|
|
||||||
) -> Union[torch.Tensor, npt.NDArray[Any]]:
|
|
||||||
"""
|
|
||||||
Spherical linear interpolation
|
|
||||||
Args:
|
|
||||||
t (float/np.ndarray): Float value between 0.0 and 1.0
|
|
||||||
v0 (np.ndarray): Starting vector
|
|
||||||
v1 (np.ndarray): Final vector
|
|
||||||
DOT_THRESHOLD (float): Threshold for considering the two vectors as
|
|
||||||
colineal. Not recommended to alter this.
|
|
||||||
Returns:
|
|
||||||
v2 (np.ndarray): Interpolation vector between v0 and v1
|
|
||||||
"""
|
|
||||||
inputs_are_torch = False
|
|
||||||
if not isinstance(v0, np.ndarray):
|
|
||||||
inputs_are_torch = True
|
|
||||||
v0 = v0.detach().cpu().numpy()
|
|
||||||
if not isinstance(v1, np.ndarray):
|
|
||||||
inputs_are_torch = True
|
|
||||||
v1 = v1.detach().cpu().numpy()
|
|
||||||
|
|
||||||
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
|
|
||||||
if np.abs(dot) > DOT_THRESHOLD:
|
|
||||||
v2 = (1 - t) * v0 + t * v1
|
|
||||||
else:
|
|
||||||
theta_0 = np.arccos(dot)
|
|
||||||
sin_theta_0 = np.sin(theta_0)
|
|
||||||
theta_t = theta_0 * t
|
|
||||||
sin_theta_t = np.sin(theta_t)
|
|
||||||
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
|
|
||||||
s1 = sin_theta_t / sin_theta_0
|
|
||||||
v2 = s0 * v0 + s1 * v1
|
|
||||||
|
|
||||||
if inputs_are_torch:
|
|
||||||
v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
|
|
||||||
return v2_torch
|
|
||||||
else:
|
|
||||||
assert isinstance(v2, np.ndarray)
|
|
||||||
return v2
|
|
||||||
|
|
||||||
# blend
|
|
||||||
bl = slerp(self.alpha, latents_a, latents_b)
|
|
||||||
assert isinstance(bl, torch.Tensor)
|
|
||||||
blended_latents: torch.Tensor = bl # for type checking convenience
|
|
||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
|
||||||
blended_latents = blended_latents.to("cpu")
|
|
||||||
|
|
||||||
TorchDevice.empty_cache()
|
|
||||||
|
|
||||||
name = context.tensors.save(tensor=blended_latents)
|
|
||||||
return LatentsOutput.build(latents_name=name, latents=blended_latents, seed=self.latents_a.seed)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user