# Initially pulled from https://github.com/black-forest-labs/flux

import math
from typing import Callable

import torch
from einops import rearrange, repeat
from torch import Tensor
from tqdm import tqdm

from invokeai.app.util.step_callback import PipelineIntermediateState
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.conditioner import HFEncoder


def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
):
    """Generate seeded Gaussian noise shaped for the FLUX latent space.

    The returned tensor has shape
    ``(num_samples, 16, 2 * ceil(height / 16), 2 * ceil(width / 16))``, so its
    spatial dims are always even and can be packed into 2x2 patches.

    Args:
        num_samples: Batch size of the returned noise.
        height: Image height; divided by 16 and rounded up when sizing the noise.
        width: Image width; divided by 16 and rounded up when sizing the noise.
        device: Device of the returned tensor.
        dtype: Dtype of the returned tensor.
        seed: Seed for the CPU random generator.

    Returns:
        The noise tensor, moved to ``device`` and cast to ``dtype``.
    """
    # We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
    rand_device = "cpu"
    rand_dtype = torch.float16
    return torch.randn(
        num_samples,
        16,
        # allow for packing
        2 * math.ceil(height / 16),
        2 * math.ceil(width / 16),
        device=rand_device,
        dtype=rand_dtype,
        generator=torch.Generator(device=rand_device).manual_seed(seed),
    ).to(device=device, dtype=dtype)


def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
    """Pack a latent image into 2x2 patches, build position ids, and encode the prompt(s).

    Args:
        t5: Text encoder whose output becomes ``txt``; its dim 1 sets the length of ``txt_ids``.
        clip: Text encoder whose output becomes ``vec``.
        img: Latent image of shape ``(b, c, h, w)``; ``h`` and ``w`` must be even so the
            2x2 patch rearrange succeeds.
        prompt: A single prompt or a list of prompts. If ``img`` has batch size 1 and a
            list is given, the batch size is taken from the list length and ``img`` is repeated.

    Returns:
        A dict with keys ``img`` (packed patches, ``(bs, (h/2)*(w/2), c*4)``), ``img_ids``
        (``(bs, (h/2)*(w/2), 3)`` with row index in channel 1 and column index in channel 2),
        ``txt``, ``txt_ids`` (all zeros) and ``vec``, all on ``img.device``.
    """
    bs, _, h, w = img.shape
    if bs == 1 and not isinstance(prompt, str):
        bs = len(prompt)

    # Pixel-unshuffle by 2 and flatten the spatial dims into a sequence of patches.
    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and bs > 1:
        img = repeat(img, "1 ... -> bs ...", bs=bs)

    # Position ids: channel 0 stays zero, channel 1 = patch row, channel 2 = patch column.
    img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    if isinstance(prompt, str):
        prompt = [prompt]
    txt = t5(prompt)
    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    txt_ids = torch.zeros(bs, txt.shape[1], 3)

    vec = clip(prompt)
    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    return {
        "img": img,
        "img_ids": img_ids.to(img.device),
        "txt": txt.to(img.device),
        "txt_ids": txt_ids.to(img.device),
        "vec": vec.to(img.device),
    }


def time_shift(mu: float, sigma: float, t: Tensor) -> Tensor:
    """Warp timesteps ``t`` via ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)``.

    For ``t`` in (0, 1), a larger ``mu`` pushes the result toward 1; ``t == 1`` maps to 1
    (for positive ``sigma``).
    """
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
    """Return the linear function passing through ``(x1, y1)`` and ``(x2, y2)``."""
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Build a descending timestep schedule from 1 to 0.

    Args:
        num_steps: Number of denoising steps. The result has ``num_steps + 1`` entries.
        image_seq_len: Length of the packed image sequence, used to pick the shift ``mu``.
        base_shift: ``mu`` at ``image_seq_len == 256``.
        max_shift: ``mu`` at ``image_seq_len == 4096``.
        shift: Whether to apply ``time_shift`` to the linear schedule.

    Returns:
        Timesteps as Python floats, starting at 1.0 and ending at 0.0.
    """
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # shifting the schedule to favor high timesteps for higher signal images
    if shift:
        # estimate mu based on linear estimation between two points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()


def denoise(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    step_callback: Callable[[Tensor, PipelineIntermediateState], None],
    guidance: float = 4.0,
):
    """Run the sampling loop, taking one Euler step per consecutive timestep pair.

    Args:
        model: The FLUX transformer; its ``txt_in.bias`` dtype determines the working dtype.
        img: Packed image latents (e.g. from ``prepare``).
        img_ids: Image position ids.
        txt: Text conditioning from ``prepare``.
        txt_ids: Text position ids.
        vec: Pooled conditioning, passed to the model as ``y``.
        timesteps: Descending schedule including the final timestep; ``len(timesteps) - 1``
            steps are taken.
        step_callback: Called after every step with the updated latents and a
            ``PipelineIntermediateState``.
        guidance: Guidance value broadcast to the batch.

    Returns:
        The denoised packed latents.
    """
    dtype = model.txt_in.bias.dtype

    # TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
    img = img.to(dtype=dtype)
    img_ids = img_ids.to(dtype=dtype)
    txt = txt.to(dtype=dtype)
    txt_ids = txt_ids.to(dtype=dtype)
    vec = vec.to(dtype=dtype)

    # this is ignored for schnell
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

    # `timesteps` includes the terminal timestep, so N + 1 entries describe N steps.
    total_steps = len(timesteps) - 1
    step_pairs = list(zip(timesteps[:-1], timesteps[1:], strict=True))
    for step_index, (t_curr, t_prev) in enumerate(tqdm(step_pairs)):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )

        # Explicit Euler update from t_curr to t_prev.
        img = img + (t_prev - t_curr) * pred

        step_callback(
            img,
            PipelineIntermediateState(
                step=step_index,
                order=0,
                total_steps=total_steps,
                # NOTE(review): with timesteps from `get_schedule`, t_curr lies in [0, 1], so this
                # floors to 0 on every step except t == 1.0 — confirm what consumers expect here.
                timestep=math.floor(t_curr),
                latents=img,
            ),
        )

    return img


def unpack(x: Tensor, height: int, width: int) -> Tensor:
    """Invert the 2x2 patch packing back to a ``(b, c, h, w)`` latent.

    ``height`` and ``width`` must match the values used in ``get_noise``; the patch grid
    is ``ceil(height / 16)`` by ``ceil(width / 16)``.
    """
    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=math.ceil(height / 16),
        w=math.ceil(width / 16),
        ph=2,
        pw=2,
    )


def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert an input image in latent space to patches for diffusion.

    This implementation was extracted from:
    https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32

    Args:
        latent_img: Latent image of shape ``(b, c, h, w)``; ``h`` and ``w`` must be even.

    Returns:
        tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
    """
    bs, _, h, w = latent_img.shape

    # Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
    # The batch dim is preserved, so `img` already has `bs` rows and no repeat is needed.
    img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

    # Generate patch position ids.
    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    return img, img_ids