mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Run ruff, setup initial text to image node
This commit is contained in:
@ -27,4 +27,4 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tenso
|
||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
@ -3,9 +3,15 @@ from dataclasses import dataclass
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from invokeai.backend.flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
||||
MLPEmbedder, SingleStreamBlock,
|
||||
timestep_embedding)
|
||||
from invokeai.backend.flux.modules.layers import (
|
||||
DoubleStreamBlock,
|
||||
EmbedND,
|
||||
LastLayer,
|
||||
MLPEmbedder,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FluxParams:
|
||||
@ -35,9 +41,7 @@ class Flux(nn.Module):
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = self.in_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
)
|
||||
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||
@ -108,4 +112,4 @@ class Flux(nn.Module):
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
return img
|
||||
|
@ -309,4 +309,4 @@ class AutoEncoder(nn.Module):
|
||||
return self.decoder(z)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.decode(self.encode(x))
|
||||
return self.decode(self.encode(x))
|
||||
|
@ -1,5 +1,6 @@
|
||||
from torch import Tensor, nn
|
||||
from transformers import (PreTrainedModel, PreTrainedTokenizer)
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
|
||||
class HFEncoder(nn.Module):
|
||||
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
|
||||
@ -27,4 +28,4 @@ class HFEncoder(nn.Module):
|
||||
attention_mask=None,
|
||||
output_hidden_states=False,
|
||||
)
|
||||
return outputs[self.output_key]
|
||||
return outputs[self.output_key]
|
||||
|
@ -36,9 +36,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
|
||||
"""
|
||||
t = time_factor * t
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
||||
t.device
|
||||
)
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
|
||||
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
@ -250,4 +248,4 @@ class LastLayer(nn.Module):
|
||||
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
return x
|
||||
|
134
invokeai/backend/flux/sampling.py
Normal file
134
invokeai/backend/flux/sampling.py
Normal file
@ -0,0 +1,134 @@
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from torch import Tensor
|
||||
|
||||
from .model import Flux
|
||||
from .modules.conditioner import HFEncoder
|
||||
|
||||
|
||||
def get_noise(
|
||||
num_samples: int,
|
||||
height: int,
|
||||
width: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
):
|
||||
return torch.randn(
|
||||
num_samples,
|
||||
16,
|
||||
# allow for packing
|
||||
2 * math.ceil(height / 16),
|
||||
2 * math.ceil(width / 16),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
generator=torch.Generator(device=device).manual_seed(seed),
|
||||
)
|
||||
|
||||
|
||||
def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
|
||||
bs, c, h, w = img.shape
|
||||
if bs == 1 and not isinstance(prompt, str):
|
||||
bs = len(prompt)
|
||||
|
||||
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img.shape[0] == 1 and bs > 1:
|
||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
txt = t5(prompt)
|
||||
if txt.shape[0] == 1 and bs > 1:
|
||||
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
|
||||
txt_ids = torch.zeros(bs, txt.shape[1], 3)
|
||||
|
||||
vec = clip(prompt)
|
||||
if vec.shape[0] == 1 and bs > 1:
|
||||
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
return {
|
||||
"img": img,
|
||||
"img_ids": img_ids.to(img.device),
|
||||
"txt": txt.to(img.device),
|
||||
"txt_ids": txt_ids.to(img.device),
|
||||
"vec": vec.to(img.device),
|
||||
}
|
||||
|
||||
|
||||
def time_shift(mu: float, sigma: float, t: Tensor):
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
|
||||
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
|
||||
m = (y2 - y1) / (x2 - x1)
|
||||
b = y1 - m * x1
|
||||
return lambda x: m * x + b
|
||||
|
||||
|
||||
def get_schedule(
|
||||
num_steps: int,
|
||||
image_seq_len: int,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
shift: bool = True,
|
||||
) -> list[float]:
|
||||
# extra step for zero
|
||||
timesteps = torch.linspace(1, 0, num_steps + 1)
|
||||
|
||||
# shifting the schedule to favor high timesteps for higher signal images
|
||||
if shift:
|
||||
# eastimate mu based on linear estimation between two points
|
||||
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
||||
timesteps = time_shift(mu, 1.0, timesteps)
|
||||
|
||||
return timesteps.tolist()
|
||||
|
||||
|
||||
def denoise(
|
||||
model: Flux,
|
||||
# model input
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
vec: Tensor,
|
||||
# sampling parameters
|
||||
timesteps: list[float],
|
||||
guidance: float = 4.0,
|
||||
):
|
||||
# this is ignored for schnell
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:], strict=False):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def unpack(x: Tensor, height: int, width: int) -> Tensor:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
||||
h=math.ceil(height / 16),
|
||||
w=math.ceil(width / 16),
|
||||
ph=2,
|
||||
pw=2,
|
||||
)
|
Reference in New Issue
Block a user