2024-08-21 19:53:58 +00:00
|
|
|
# Initially pulled from https://github.com/black-forest-labs/flux
|
|
|
|
|
2024-08-19 14:14:58 +00:00
|
|
|
import math
|
|
|
|
from typing import Callable
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from einops import rearrange, repeat
|
|
|
|
from torch import Tensor
|
|
|
|
|
|
|
|
|
|
|
|
def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
):
    """Sample seeded initial noise, then move/cast it to the requested device and dtype.

    Sampling always happens on the CPU in float16, using a generator seeded with ``seed``.
    Only after that is the result converted to ``device``/``dtype``. Drawing on a fixed
    device/dtype is what keeps a given seed consistent across target devices and dtypes.

    Args:
        num_samples: Leading (batch) dimension of the result.
        height: Requested height; the noise's spatial height is ``2 * ceil(height / 16)``.
        width: Requested width; the noise's spatial width is ``2 * ceil(width / 16)``.
        device: Device the returned tensor is moved to.
        dtype: Dtype the returned tensor is cast to.
        seed: Seed for the CPU random generator.

    Returns:
        Tensor of shape ``(num_samples, 16, 2 * ceil(height / 16), 2 * ceil(width / 16))``.
    """
    sample_device = "cpu"
    sample_dtype = torch.float16
    rng = torch.Generator(device=sample_device).manual_seed(seed)

    # Doubling ceil(size / 16) keeps both spatial dims even so the result can be packed into 2x2 patches.
    latent_h = 2 * math.ceil(height / 16)
    latent_w = 2 * math.ceil(width / 16)

    noise = torch.randn(
        num_samples,
        16,
        latent_h,
        latent_w,
        device=sample_device,
        dtype=sample_dtype,
        generator=rng,
    )
    return noise.to(device=device, dtype=dtype)
|
2024-08-19 14:14:58 +00:00
|
|
|
|
|
|
|
|
|
|
|
def time_shift(mu: float, sigma: float, t: Tensor):
    """Warp timesteps ``t`` elementwise as ``e^mu / (e^mu + (1/t - 1) ** sigma)``.

    A timestep of 1 maps to 1. For a float tensor with ``sigma > 0``, an entry of 0 maps to 0,
    because ``1 / 0`` evaluates to ``inf`` and the denominator becomes infinite. For a fixed
    ``t`` strictly between 0 and 1, a larger ``mu`` moves the result closer to 1.
    """
    scale = math.exp(mu)
    # (1 / t - 1) == (1 - t) / t, the "odds" of t, raised to sigma.
    odds = (1 / t - 1) ** sigma
    return scale / (scale + odds)
|
|
|
|
|
|
|
|
|
|
|
|
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
    """Return the straight line through ``(x1, y1)`` and ``(x2, y2)`` as a callable ``f(x)``.

    With the defaults, ``f(256) == 0.5`` and ``f(4096) == 1.15`` (up to float rounding).
    Values outside that range are extrapolated linearly.
    ``x1 == x2`` raises ``ZeroDivisionError`` when the slope is computed.
    """
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def line(x: float) -> float:
        return slope * x + intercept

    return line
|
|
|
|
|
|
|
|
|
|
|
|
def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Build the timestep schedule for a sampling loop.

    Returns ``num_steps + 1`` values, evenly spaced from 1 down to 0; the extra entry is the
    trailing 0. When ``shift`` is true, the values are warped by ``time_shift``. Its ``mu`` is
    read off the line through ``(256, base_shift)`` and ``(4096, max_shift)`` at
    ``image_seq_len``. This shifts the schedule to favor high timesteps for higher-signal
    images.
    """
    # One more point than steps so the schedule ends at 0.
    schedule = torch.linspace(1, 0, num_steps + 1)

    if not shift:
        return schedule.tolist()

    # Estimate mu by linear interpolation/extrapolation on the image sequence length.
    mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
    return time_shift(mu, 1.0, schedule).tolist()
|
|
|
|
|
|
|
|
|
|
|
|
def unpack(x: Tensor, height: int, width: int) -> Tensor:
    """Reassemble a sequence of flattened 2x2 patches into a spatial latent grid.

    This is the mirror image of the ``pack`` rearrangement. Input is
    ``(b, grid_h * grid_w, c * 2 * 2)`` and output is ``(b, c, grid_h * 2, grid_w * 2)``,
    where ``grid_h = ceil(height / 16)`` and ``grid_w = ceil(width / 16)``.
    """
    grid_h = math.ceil(height / 16)
    grid_w = math.ceil(width / 16)
    unpatchify = "b (h w) (c ph pw) -> b c (h ph) (w pw)"
    return rearrange(x, unpatchify, h=grid_h, w=grid_w, ph=2, pw=2)
|
2024-08-22 17:18:43 +00:00
|
|
|
|
|
|
|
|
2024-08-29 19:05:44 +00:00
|
|
|
def pack(x: Tensor) -> Tensor:
    """Cut a ``(b, c, H, W)`` latent into 2x2 patches and flatten them into a sequence.

    This is a pixel-unshuffle with scale 2 followed by flattening the patch grid. The output
    is ``(b, (H / 2) * (W / 2), c * 4)``, so ``H`` and ``W`` must be divisible by 2.
    """
    patchify = "b c (h ph) (w pw) -> b (h w) (c ph pw)"
    return rearrange(x, patchify, ph=2, pw=2)
|
|
|
|
|
|
|
|
|
|
|
|
def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype) -> Tensor:
    """Build position ids for the 2x2 patches of an ``h`` x ``w`` latent.

    Each patch gets a 3-vector:
    - channel 0 is left at 0;
    - channel 1 is the patch row index (0 .. h // 2 - 1);
    - channel 2 is the patch column index (0 .. w // 2 - 1).

    The patch grid is flattened row-major, then repeated along a new leading batch axis.

    Returns:
        Tensor of shape ``(batch_size, (h // 2) * (w // 2), 3)`` on ``device`` with ``dtype``.
    """
    rows, cols = h // 2, w // 2
    ids = torch.zeros(rows, cols, 3, device=device, dtype=dtype)
    # A column vector of row indices and a row vector of column indices broadcast over the grid.
    ids[..., 1] = torch.arange(rows, device=device, dtype=dtype)[:, None]
    ids[..., 2] = torch.arange(cols, device=device, dtype=dtype)[None, :]
    return repeat(ids, "h w c -> b (h w) c", b=batch_size)
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_latent_img_patches(img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert an input image in latent space to patches for diffusion.

    This implementation was extracted from:
    https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32

    Args:
        img (torch.Tensor): Input image in latent space, unpacked as
            ``(batch, channels, height, width)``. ``height`` and ``width`` are expected to be
            even so they split cleanly into 2x2 patches.

    Returns:
        tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
            ``img`` is the packed patch sequence, shape
            ``(batch, (height // 2) * (width // 2), channels * 4)``.
            ``img_ids`` holds per-patch position ids, shape
            ``(batch, (height // 2) * (width // 2), 3)``, on the same device and with the same
            dtype as the packed ``img``.
    """
    batch_size, _, height, width = img.shape

    packed = pack(img)

    # Reuse the shared id builder instead of duplicating its construction inline.
    img_ids = generate_img_ids(height, width, batch_size, packed.device, packed.dtype)

    return packed, img_ids
|