# Initially pulled from https://github.com/black-forest-labs/flux

import math
from typing import Callable

import torch
from einops import rearrange, repeat


def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
):
    # We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
    rand_device = "cpu"
    rand_dtype = torch.float16
    return torch.randn(
        num_samples,
        16,
        # allow for packing
        2 * math.ceil(height / 16),
        2 * math.ceil(width / 16),
        device=rand_device,
        dtype=rand_dtype,
        generator=torch.Generator(device=rand_device).manual_seed(seed),
    ).to(device=device, dtype=dtype)
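

# Example (for illustration only; the 1024x1024 size and bfloat16 dtype are arbitrary): since
# math.ceil(1024 / 16) == 64, a 1024x1024 request produces noise of shape (num_samples, 16, 128, 128):
#
#   noise = get_noise(1, 1024, 1024, torch.device("cpu"), torch.bfloat16, seed=0)
#   assert noise.shape == (1, 16, 128, 128)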


def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
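

# Note (illustrative): the remapping keeps the endpoints fixed (t=1 -> 1, t=0 -> 0) and, for mu > 0,
# pushes intermediate timesteps toward the high-noise end. For sigma == 1 it amounts to shifting t by
# mu in logit space. For example, with mu = 1.15:
#
#   time_shift(1.15, 1.0, torch.tensor(0.5))  # ~0.76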


def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # shifting the schedule to favor high timesteps for higher signal images
    if shift:
        # estimate mu based on linear estimation between two points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()
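

# Example (for illustration): a 1024x1024 image packs into image_seq_len = 64 * 64 = 4096 patch tokens
# (see pack() below), for which the default linear interpolation gives mu = max_shift = 1.15, bending the
# schedule toward the high-noise end:
#
#   get_schedule(num_steps=4, image_seq_len=4096, shift=False)
#   # -> [1.0, 0.75, 0.5, 0.25, 0.0]
#   get_schedule(num_steps=4, image_seq_len=4096)
#   # -> roughly [1.0, 0.90, 0.76, 0.51, 0.0]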


def _find_last_index_ge_val(timesteps: list[float], val: float, eps: float = 1e-6) -> int:
    """Find the last index in timesteps that is >= val.

    We use epsilon-close equality to avoid potential floating point errors.
    """
    # Assumes timesteps is sorted in descending order (as produced by get_schedule()), so the count of
    # elements >= (val - eps) minus one is the index of the last such element.
    idx = len(list(filter(lambda t: t >= (val - eps), timesteps))) - 1
    assert idx >= 0
    return idx


def clip_timestep_schedule(timesteps: list[float], denoising_start: float, denoising_end: float) -> list[float]:
    """Clip the timestep schedule to the denoising range.

    Args:
        timesteps (list[float]): The original timestep schedule: [1.0, ..., 0.0].
        denoising_start (float): A value in [0, 1] specifying the start of the denoising process. E.g. a value of 0.2
            means that the denoising process starts at the last timestep in the schedule that is >= 0.8.
        denoising_end (float): A value in [0, 1] specifying the end of the denoising process. E.g. a value of 0.8
            means that the denoising process ends at the last timestep in the schedule that is >= 0.2.

    Returns:
        list[float]: The clipped timestep schedule.
    """
    assert 0.0 <= denoising_start <= 1.0
    assert 0.0 <= denoising_end <= 1.0
    assert denoising_start <= denoising_end

    t_start_val = 1.0 - denoising_start
    t_end_val = 1.0 - denoising_end

    t_start_idx = _find_last_index_ge_val(timesteps, t_start_val)
    t_end_idx = _find_last_index_ge_val(timesteps, t_end_val)

    clipped_timesteps = timesteps[t_start_idx : t_end_idx + 1]

    return clipped_timesteps
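

# Example (for illustration, e.g. an image-to-image run that skips the first 30% of denoising): with the
# unshifted 4-step schedule [1.0, 0.75, 0.5, 0.25, 0.0], denoising_start=0.3 maps to t_start_val=0.7, whose
# last index >= 0.7 is 1, so denoising begins at t=0.75:
#
#   clip_timestep_schedule([1.0, 0.75, 0.5, 0.25, 0.0], denoising_start=0.3, denoising_end=1.0)
#   # -> [0.75, 0.5, 0.25, 0.0]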


def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Unpack flat array of patch embeddings to latent image."""
    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=math.ceil(height / 16),
        w=math.ceil(width / 16),
        ph=2,
        pw=2,
    )


def pack(x: torch.Tensor) -> torch.Tensor:
    """Pack latent image to flattened array of patch embeddings."""
    # Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
    return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
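

# Example (for illustration): pack() and unpack() are inverses for even latent dimensions. A
# (1, 16, 128, 128) latent (i.e. a 1024x1024 image) packs into a 64x64 grid of 2x2 patches,
# each with 16 * 2 * 2 = 64 channels:
#
#   latent = torch.zeros(1, 16, 128, 128)
#   packed = pack(latent)                  # (1, 4096, 64)
#   restored = unpack(packed, 1024, 1024)  # (1, 16, 128, 128)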


def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Generate tensor of image position ids.

    Args:
        h (int): Height of image in latent space.
        w (int): Width of image in latent space.
        batch_size (int): Batch size.
        device (torch.device): Device.
        dtype (torch.dtype): dtype.

    Returns:
        torch.Tensor: Image position ids.
    """
    img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
    return img_ids
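

# Example (for illustration): each packed patch token gets a (0, row, col) position id, in the same
# row-major order that pack() produces. For a 4x4 latent (h=w=4, i.e. a 2x2 patch grid):
#
#   generate_img_ids(4, 4, batch_size=1, device=torch.device("cpu"), dtype=torch.float32)
#   # -> tensor([[[0., 0., 0.],
#   #             [0., 0., 1.],
#   #             [0., 1., 0.],
#   #             [0., 1., 1.]]])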