mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add blend_noise node
This commit is contained in:
parent
e3de996525
commit
4113fd0ccf
@ -1,6 +1,7 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
||||||
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from pydantic import validator
|
from pydantic import validator
|
||||||
|
|
||||||
@ -12,6 +13,7 @@ from .baseinvocation import (
|
|||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
FieldDescriptions,
|
FieldDescriptions,
|
||||||
|
Input,
|
||||||
InputField,
|
InputField,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
OutputField,
|
OutputField,
|
||||||
@ -63,7 +65,7 @@ Nodes
|
|||||||
|
|
||||||
@invocation_output("noise_output")
|
@invocation_output("noise_output")
|
||||||
class NoiseOutput(BaseInvocationOutput):
|
class NoiseOutput(BaseInvocationOutput):
|
||||||
"""Invocation noise output"""
|
"""Invocation noise output."""
|
||||||
|
|
||||||
noise: LatentsField = OutputField(default=None, description=FieldDescriptions.noise)
|
noise: LatentsField = OutputField(default=None, description=FieldDescriptions.noise)
|
||||||
width: int = OutputField(description=FieldDescriptions.width)
|
width: int = OutputField(description=FieldDescriptions.width)
|
||||||
@ -121,3 +123,62 @@ class NoiseInvocation(BaseInvocation):
|
|||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
context.services.latents.save(name, noise)
|
context.services.latents.save(name, noise)
|
||||||
return build_noise_output(latents_name=name, latents=noise, seed=self.seed)
|
return build_noise_output(latents_name=name, latents=noise, seed=self.seed)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation(
|
||||||
|
"blend_noise", title="Blend Noise", tags=["latents", "noise", "variations"], category="latents", version="1.0.0"
|
||||||
|
)
|
||||||
|
class BlendNoiseInvocation(BaseInvocation):
|
||||||
|
"""Blend two noise tensors according to a proportion. Useful for generating variations."""
|
||||||
|
|
||||||
|
noise_A: LatentsField = InputField(description=FieldDescriptions.noise, input=Input.Connection, ui_order=0)
|
||||||
|
noise_B: LatentsField = InputField(description=FieldDescriptions.noise, input=Input.Connection, ui_order=1)
|
||||||
|
blend_ratio: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.blend_alpha)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def invoke(self, context: InvocationContext) -> NoiseOutput:
|
||||||
|
"""Combine two noise vectors, returning a blend that can be used to generate variations."""
|
||||||
|
noise_a = context.services.latents.get(self.noise_A.latents_name)
|
||||||
|
noise_b = context.services.latents.get(self.noise_B.latents_name)
|
||||||
|
|
||||||
|
if noise_a is None or noise_b is None:
|
||||||
|
raise Exception("Both noise_A and noise_B must be provided.")
|
||||||
|
if noise_a.shape != noise_b.shape:
|
||||||
|
raise Exception("Both noise_A and noise_B must be same dimensions.")
|
||||||
|
|
||||||
|
seed = self.noise_A.seed
|
||||||
|
alpha = self.blend_ratio
|
||||||
|
merged_noise = self.slerp(alpha, noise_a, noise_b)
|
||||||
|
|
||||||
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
|
context.services.latents.save(name, merged_noise)
|
||||||
|
return build_noise_output(latents_name=name, latents=merged_noise, seed=seed)
|
||||||
|
|
||||||
|
def slerp(self, t: float, v0: torch.tensor, v1: torch.tensor, DOT_THRESHOLD: float = 0.9995):
|
||||||
|
"""
|
||||||
|
Spherical linear interpolation.
|
||||||
|
|
||||||
|
:param t: Mixing value, float between 0.0 and 1.0.
|
||||||
|
:param v0: Source noise
|
||||||
|
:param v1: Target noise
|
||||||
|
:DOT_THRESHOLD: Threshold for considering two vectors colineal. Don't change.
|
||||||
|
|
||||||
|
:Returns: Interpolation vector between v0 and v1
|
||||||
|
"""
|
||||||
|
device = v0.device or choose_torch_device()
|
||||||
|
v0 = v0.detach().cpu().numpy()
|
||||||
|
v1 = v1.detach().cpu().numpy()
|
||||||
|
|
||||||
|
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
|
||||||
|
if np.abs(dot) > DOT_THRESHOLD:
|
||||||
|
v2 = (1 - t) * v0 + t * v1
|
||||||
|
else:
|
||||||
|
theta_0 = np.arccos(dot)
|
||||||
|
sin_theta_0 = np.sin(theta_0)
|
||||||
|
theta_t = theta_0 * t
|
||||||
|
sin_theta_t = np.sin(theta_t)
|
||||||
|
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
|
||||||
|
s1 = sin_theta_t / sin_theta_0
|
||||||
|
v2 = s0 * v0 + s1 * v1
|
||||||
|
|
||||||
|
return torch.from_numpy(v2).to(device)
|
||||||
|
Loading…
Reference in New Issue
Block a user