mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add blend_noise node
This commit is contained in:
parent
e3de996525
commit
4113fd0ccf
@ -1,6 +1,7 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
||||
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pydantic import validator
|
||||
|
||||
@ -12,6 +13,7 @@ from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
@ -63,7 +65,7 @@ Nodes
|
||||
|
||||
@invocation_output("noise_output")
|
||||
class NoiseOutput(BaseInvocationOutput):
|
||||
"""Invocation noise output"""
|
||||
"""Invocation noise output."""
|
||||
|
||||
noise: LatentsField = OutputField(default=None, description=FieldDescriptions.noise)
|
||||
width: int = OutputField(description=FieldDescriptions.width)
|
||||
@ -121,3 +123,62 @@ class NoiseInvocation(BaseInvocation):
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, noise)
|
||||
return build_noise_output(latents_name=name, latents=noise, seed=self.seed)
|
||||
|
||||
|
||||
@invocation(
|
||||
"blend_noise", title="Blend Noise", tags=["latents", "noise", "variations"], category="latents", version="1.0.0"
|
||||
)
|
||||
class BlendNoiseInvocation(BaseInvocation):
|
||||
"""Blend two noise tensors according to a proportion. Useful for generating variations."""
|
||||
|
||||
noise_A: LatentsField = InputField(description=FieldDescriptions.noise, input=Input.Connection, ui_order=0)
|
||||
noise_B: LatentsField = InputField(description=FieldDescriptions.noise, input=Input.Connection, ui_order=1)
|
||||
blend_ratio: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.blend_alpha)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> NoiseOutput:
|
||||
"""Combine two noise vectors, returning a blend that can be used to generate variations."""
|
||||
noise_a = context.services.latents.get(self.noise_A.latents_name)
|
||||
noise_b = context.services.latents.get(self.noise_B.latents_name)
|
||||
|
||||
if noise_a is None or noise_b is None:
|
||||
raise Exception("Both noise_A and noise_B must be provided.")
|
||||
if noise_a.shape != noise_b.shape:
|
||||
raise Exception("Both noise_A and noise_B must be same dimensions.")
|
||||
|
||||
seed = self.noise_A.seed
|
||||
alpha = self.blend_ratio
|
||||
merged_noise = self.slerp(alpha, noise_a, noise_b)
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, merged_noise)
|
||||
return build_noise_output(latents_name=name, latents=merged_noise, seed=seed)
|
||||
|
||||
def slerp(self, t: float, v0: torch.tensor, v1: torch.tensor, DOT_THRESHOLD: float = 0.9995):
|
||||
"""
|
||||
Spherical linear interpolation.
|
||||
|
||||
:param t: Mixing value, float between 0.0 and 1.0.
|
||||
:param v0: Source noise
|
||||
:param v1: Target noise
|
||||
:DOT_THRESHOLD: Threshold for considering two vectors colineal. Don't change.
|
||||
|
||||
:Returns: Interpolation vector between v0 and v1
|
||||
"""
|
||||
device = v0.device or choose_torch_device()
|
||||
v0 = v0.detach().cpu().numpy()
|
||||
v1 = v1.detach().cpu().numpy()
|
||||
|
||||
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
|
||||
if np.abs(dot) > DOT_THRESHOLD:
|
||||
v2 = (1 - t) * v0 + t * v1
|
||||
else:
|
||||
theta_0 = np.arccos(dot)
|
||||
sin_theta_0 = np.sin(theta_0)
|
||||
theta_t = theta_0 * t
|
||||
sin_theta_t = np.sin(theta_t)
|
||||
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
|
||||
s1 = sin_theta_t / sin_theta_0
|
||||
v2 = s0 * v0 + s1 * v1
|
||||
|
||||
return torch.from_numpy(v2).to(device)
|
||||
|
Loading…
Reference in New Issue
Block a user