add BlendInvocation

This commit is contained in:
Damian Stewart 2023-08-20 20:49:18 +02:00 committed by psychedelicious
parent beb3e5aeb7
commit 2bcded78e1
2 changed files with 128 additions and 0 deletions

View File

@ -71,6 +71,7 @@ class FieldDescriptions:
safe_mode = "Whether or not to use safe mode"
scribble_mode = "Whether or not to use scribble mode"
scale_factor = "The factor by which to scale"
blend_alpha = "Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B."
num_1 = "The first number"
num_2 = "The second number"
mask = "The mask to use for the operation"

View File

@ -4,6 +4,7 @@ from contextlib import ExitStack
from typing import List, Literal, Optional, Union
import einops
import numpy as np
import torch
import torchvision.transforms as T
from diffusers.image_processor import VaeImageProcessor
@ -720,3 +721,129 @@ class ImageToLatentsInvocation(BaseInvocation):
latents = latents.to("cpu")
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents, seed=None)
@title("Resize Latents")
@tags("latents", "resize")
class ResizeLatentsInvocation(BaseInvocation):
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
type: Literal["lresize"] = "lresize"
# Inputs
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
width: int = InputField(
ge=64,
multiple_of=8,
description=FieldDescriptions.width,
)
height: int = InputField(
ge=64,
multiple_of=8,
description=FieldDescriptions.width,
)
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
# TODO:
device = choose_torch_device()
resized_latents = torch.nn.functional.interpolate(
latents.to(device),
size=(self.height // 8, self.width // 8),
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
# context.services.latents.set(name, resized_latents)
context.services.latents.save(name, resized_latents)
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@title("Blend Latents")
@tags("latents", "blend")
class BlendLatentsInvocation(BaseInvocation):
"""Blend two latents using a given alpha. Latents must have same size."""
type: Literal["lblend"] = "lblend"
# Inputs
latents_a: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
latents_b: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents_a = context.services.latents.get(self.latents_a.latents_name)
latents_b = context.services.latents.get(self.latents_b.latents_name)
if latents_a.shape != latents_b.shape:
raise "Latents to blend must be the same size."
# TODO:
device = choose_torch_device()
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
"""
Spherical linear interpolation
Args:
t (float/np.ndarray): Float value between 0.0 and 1.0
v0 (np.ndarray): Starting vector
v1 (np.ndarray): Final vector
DOT_THRESHOLD (float): Threshold for considering the two vectors as
colineal. Not recommended to alter this.
Returns:
v2 (np.ndarray): Interpolation vector between v0 and v1
"""
inputs_are_torch = False
if not isinstance(v0, np.ndarray):
inputs_are_torch = True
v0 = v0.detach().cpu().numpy()
if not isinstance(v1, np.ndarray):
inputs_are_torch = True
v1 = v1.detach().cpu().numpy()
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
if np.abs(dot) > DOT_THRESHOLD:
v2 = (1 - t) * v0 + t * v1
else:
theta_0 = np.arccos(dot)
sin_theta_0 = np.sin(theta_0)
theta_t = theta_0 * t
sin_theta_t = np.sin(theta_t)
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
s1 = sin_theta_t / sin_theta_0
v2 = s0 * v0 + s1 * v1
if inputs_are_torch:
v2 = torch.from_numpy(v2).to(device)
return v2
# blend
blended_latents = slerp(self.alpha, latents_a, latents_b)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
blended_latents = blended_latents.to("cpu")
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
# context.services.latents.set(name, resized_latents)
context.services.latents.save(name, blended_latents)
return build_latents_output(latents_name=name, latents=blended_latents)