mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add backend functions and classes for Flux implementation, Update the way flux encoders/tokenizers are loaded for prompt encoding, Update way flux vae is loaded
This commit is contained in:
30
invokeai/backend/flux/math.py
Normal file
30
invokeai/backend/flux/math.py
Normal file
@ -0,0 +1,30 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    """Rotate q/k with ``apply_rope``, run scaled dot-product attention and merge the heads.

    The attention output is read as ``B H L D`` by the einops pattern and returned as
    ``B L (H D)``.
    """
    rotated_q, rotated_k = apply_rope(q, k, pe)
    attended = torch.nn.functional.scaled_dot_product_attention(rotated_q, rotated_k, v)
    return rearrange(attended, "B H L D -> B L (H D)")
|
||||
|
||||
|
||||
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Build 2x2 rotation matrices for rotary position embedding.

    Args:
        pos: Positions of shape ``(..., n)``; any number of leading dimensions is accepted.
        dim: Number of embedding channels covered by this axis. Must be even, because one
            rotation matrix is produced per pair of channels.
        theta: Base of the geometric frequency progression.

    Returns:
        A float32 tensor of shape ``(..., n, dim // 2, 2, 2)`` whose trailing 2x2 blocks are
        ``[[cos, -sin], [sin, cos]]``.

    Raises:
        ValueError: If ``dim`` is odd.
    """
    # An explicit raise (rather than ``assert``) keeps the check alive under ``python -O``.
    if dim % 2 != 0:
        raise ValueError(f"rope dim must be even, got {dim}")

    # Frequencies are computed in float64 and only cast down at the very end.
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    angles = torch.einsum("...n,d->...nd", pos, omega)

    cos, sin = torch.cos(angles), torch.sin(angles)
    out = torch.stack([cos, -sin, sin, cos], dim=-1)
    # Split the trailing 4 entries into a 2x2 matrix without assuming a fixed number of
    # leading dimensions (the einsum above already accepts arbitrary ones).
    out = out.reshape(*out.shape[:-1], 2, 2)
    return out.float()
|
||||
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Apply the 2x2 rotation matrices in ``freqs_cis`` to the query and key tensors.

    Each input is cast to float32, its last dimension is viewed as consecutive pairs, every
    pair is multiplied by the matching matrix, and the result is reshaped and cast back to
    the input's original shape and dtype.
    """

    def _rotate(x: Tensor) -> Tensor:
        pairs = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        # Column 0 of each matrix weights the first pair element, column 1 the second.
        rotated = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
        return rotated.reshape(*x.shape).type_as(x)

    return _rotate(xq), _rotate(xk)
|
111
invokeai/backend/flux/model.py
Normal file
111
invokeai/backend/flux/model.py
Normal file
@ -0,0 +1,111 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from invokeai.backend.flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
||||
MLPEmbedder, SingleStreamBlock,
|
||||
timestep_embedding)
|
||||
|
||||
@dataclass
class FluxParams:
    """Hyper-parameters consumed by ``Flux.__init__``."""

    # Input feature sizes of the image, pooled-vector and text projections.
    in_channels: int
    vec_in_dim: int
    context_in_dim: int
    # Transformer width; must be divisible by num_heads.
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    # Number of DoubleStreamBlock / SingleStreamBlock layers.
    depth: int
    depth_single_blocks: int
    # Per-axis rotary embedding sizes; must sum to hidden_size // num_heads.
    axes_dim: list[int]
    theta: int
    qkv_bias: bool
    # When true, the model requires a guidance tensor in forward().
    guidance_embed: bool
|
||||
|
||||
|
||||
class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.

    Image and text tokens first pass through the double-stream blocks, are then concatenated
    (text first) for the single-stream blocks, and only the image tokens are fed to the final
    layer.
    """

    def __init__(self, params: FluxParams):
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels

        # The hidden width has to split evenly across heads, and the per-axis rotary sizes
        # have to add up to exactly one head's width.
        pe_dim, leftover = divmod(params.hidden_size, params.num_heads)
        if leftover != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")

        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads

        # Input embedders (creation order is kept so parameter registration is unchanged).
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        if params.guidance_embed:
            self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        else:
            self.guidance_in = nn.Identity()
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            DoubleStreamBlock(
                self.hidden_size,
                self.num_heads,
                mlp_ratio=params.mlp_ratio,
                qkv_bias=params.qkv_bias,
            )
            for _ in range(params.depth)
        )

        self.single_blocks = nn.ModuleList(
            SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
            for _ in range(params.depth_single_blocks)
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> Tensor:
        """Run the transformer and return the final-layer output for the image tokens.

        Raises:
            ValueError: If ``img`` or ``txt`` is not 3-dimensional, or if the model was built
                with ``guidance_embed`` and ``guidance`` is ``None``.
        """
        if not (img.ndim == 3 and txt.ndim == 3):
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)

        # Conditioning vector: timestep embedding (+ guidance embedding) + projected y.
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        # Positional embedding over the concatenated (text, image) ids.
        pe = self.pe_embedder(torch.cat((txt_ids, img_ids), dim=1))

        for double_block in self.double_blocks:
            img, txt = double_block(img=img, txt=txt, vec=vec, pe=pe)

        # Single-stream phase works on text tokens followed by image tokens.
        txt_len = txt.shape[1]
        tokens = torch.cat((txt, img), 1)
        for single_block in self.single_blocks:
            tokens = single_block(tokens, vec=vec, pe=pe)
        img = tokens[:, txt_len:, ...]

        return self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
|
312
invokeai/backend/flux/modules/autoencoder.py
Normal file
312
invokeai/backend/flux/modules/autoencoder.py
Normal file
@ -0,0 +1,312 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
@dataclass
class AutoEncoderParams:
    """Configuration consumed by ``AutoEncoder.__init__``."""

    # Passed to both Encoder and Decoder.
    resolution: int
    in_channels: int
    ch: int
    # Decoder output channels.
    out_ch: int
    # One channel multiplier per resolution level.
    ch_mult: list[int]
    num_res_blocks: int
    z_channels: int
    # Latent normalisation: encode() returns scale_factor * (z - shift_factor).
    scale_factor: float
    shift_factor: float
|
||||
|
||||
|
||||
def swish(x: Tensor) -> Tensor:
    """Return ``x * sigmoid(x)`` element-wise."""
    gate = torch.sigmoid(x)
    return x * gate
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
    """Residual single-head self-attention over the spatial positions of a feature map."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        # 1x1 convolutions act as per-pixel linear projections.
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        """Normalise ``h_``, attend across its ``h * w`` positions and return a ``b c h w`` map."""
        normed = self.norm(h_)
        q, k, v = self.q(normed), self.k(normed), self.v(normed)

        b, c, h, w = q.shape
        # Flatten the spatial grid into a sequence with a singleton head dimension.
        to_sequence = "b c h w -> b 1 (h w) c"
        q, k, v = (rearrange(t, to_sequence).contiguous() for t in (q, k, v))
        attended = nn.functional.scaled_dot_product_attention(q, k, v)

        return rearrange(attended, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` plus the projected attention output."""
        residual = self.proj_out(self.attention(x))
        return x + residual
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = swish(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
h = self.norm2(h)
|
||||
h = swish(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class Downsample(nn.Module):
    """Stride-2 3x3 convolution preceded by a manual one-pixel pad on the right and bottom."""

    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor):
        """Zero-pad the last two dims by (0, 1) each, then apply the strided convolution."""
        padded = nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling followed by a 3x3 convolution."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor):
        """Double the spatial size of ``x`` and convolve the result."""
        enlarged = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(enlarged)
|
||||
|
||||
|
||||
class Encoder(nn.Module):
    """Convolutional encoder: input conv, per-level ResnetBlocks with downsampling between
    levels, a ResnetBlock/AttnBlock/ResnetBlock middle, and an output conv producing
    ``2 * z_channels`` channels."""

    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        """
        Args:
            resolution: Stored on the instance; not otherwise used by this class.
            in_channels: Channels of the input tensor.
            ch: Base channel count; level ``i`` uses ``ch * ch_mult[i]`` channels.
            ch_mult: One channel multiplier per resolution level.
            num_res_blocks: Number of ResnetBlocks per level.
            z_channels: The output conv emits ``2 * z_channels`` channels.
        """
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        # Level i takes ch * in_ch_mult[i] channels in and emits ch * ch_mult[i].
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        # Fallback width for the middle blocks when ch_mult is empty.
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            # No per-level attention blocks are created; the (empty) list is kept as a submodule.
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            # Every level except the last halves the spatial size.
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x`` and return the output of ``conv_out``."""
        # downsampling
        # Only the most recent activation is ever consumed, so a single running tensor is
        # kept instead of a list that would hold every intermediate feature map alive.
        h = self.conv_in(x)
        last_level = self.num_resolutions - 1
        for i_level in range(self.num_resolutions):
            level = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = level.block[i_block](h)
                if len(level.attn) > 0:
                    h = level.attn[i_block](h)
            if i_level != last_level:
                h = level.downsample(h)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
    """Convolutional decoder: input conv, a ResnetBlock/AttnBlock/ResnetBlock middle, then
    per-level ResnetBlocks with upsampling between levels, and an output conv to ``out_ch``."""

    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        top_level = self.num_resolutions - 1
        self.ffactor = 2 ** top_level

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[top_level]
        curr_res = resolution // 2 ** top_level
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        # Levels are built from the lowest resolution upwards but stored in ascending
        # level order, so self.up[i] always belongs to level i.
        self.up = nn.ModuleList()
        for i_level in range(top_level, -1, -1):
            stage = nn.Module()
            stage.block = nn.ModuleList()
            # No per-level attention blocks are created; the empty list is kept as a submodule.
            stage.attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                stage.block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            if i_level != 0:
                stage.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, stage)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        """Decode the latent ``z`` and return the output of ``conv_out``."""
        # z to block_in
        h = self.conv_in(z)

        # middle
        for mid_layer in (self.mid.block_1, self.mid.attn_1, self.mid.block_2):
            h = mid_layer(h)

        # upsampling
        for i_level in range(self.num_resolutions - 1, -1, -1):
            stage = self.up[i_level]
            for i_block in range(self.num_res_blocks + 1):
                h = stage.block[i_block](h)
                if len(stage.attn) > 0:
                    h = stage.attn[i_block](h)
            if i_level != 0:
                h = stage.upsample(h)

        # end
        return self.conv_out(swish(self.norm_out(h)))
|
||||
|
||||
|
||||
class DiagonalGaussian(nn.Module):
    """Split a tensor into mean and log-variance halves and optionally draw a sample."""

    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        """Return ``mean + std * noise`` when sampling is enabled, otherwise just the mean.

        ``z`` is chunked in two along ``chunk_dim``: first half is the mean, second half the
        log-variance.
        """
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if not self.sample:
            return mean
        std = torch.exp(0.5 * logvar)
        return mean + std * torch.randn_like(mean)
|
||||
|
||||
|
||||
class AutoEncoder(nn.Module):
    """Encoder + DiagonalGaussian regulariser + Decoder, with a scale/shift on the latents."""

    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.reg = DiagonalGaussian()

        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

    def encode(self, x: Tensor) -> Tensor:
        """Encode ``x``, pass it through the regulariser and return the shifted, scaled latent."""
        latent = self.reg(self.encoder(x))
        return self.scale_factor * (latent - self.shift_factor)

    def decode(self, z: Tensor) -> Tensor:
        """Undo the scale/shift applied by ``encode`` and run the decoder."""
        unscaled = z / self.scale_factor + self.shift_factor
        return self.decoder(unscaled)

    def forward(self, x: Tensor) -> Tensor:
        """Round-trip ``x`` through ``encode`` and ``decode``."""
        latent = self.encode(x)
        return self.decode(latent)
|
30
invokeai/backend/flux/modules/conditioner.py
Normal file
30
invokeai/backend/flux/modules/conditioner.py
Normal file
@ -0,0 +1,30 @@
|
||||
from torch import Tensor, nn
|
||||
from transformers import (PreTrainedModel, PreTrainedTokenizer)
|
||||
|
||||
class HFEncoder(nn.Module):
    """Wrap a tokenizer/encoder pair and return one selected field of the encoder output."""

    def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
        super().__init__()
        self.max_length = max_length
        self.is_clip = is_clip
        # CLIP encoders contribute their pooled output; others their last hidden state.
        if self.is_clip:
            self.output_key = "pooler_output"
        else:
            self.output_key = "last_hidden_state"
        self.tokenizer = tokenizer
        # The wrapped encoder is put in eval mode with gradients disabled.
        self.hf_module = encoder.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        """Tokenize ``text`` (truncated and padded to ``max_length``) and encode it."""
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        token_ids = batch_encoding["input_ids"].to(self.hf_module.device)

        # No attention mask is passed to the encoder.
        outputs = self.hf_module(
            input_ids=token_ids,
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
|
253
invokeai/backend/flux/modules/layers.py
Normal file
253
invokeai/backend/flux/modules/layers.py
Normal file
@ -0,0 +1,253 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..math import attention, rope
|
||||
|
||||
|
||||
class EmbedND(nn.Module):
    """Build a multi-axis rotary embedding by calling ``rope`` once per id axis."""

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        """Concatenate the per-axis ``rope`` outputs along dim ``-3`` and insert a size-1 dim at index 1.

        The last dimension of ``ids`` indexes the axes; axis ``i`` uses ``axes_dim[i]``.
        """
        per_axis = [rope(ids[..., axis], self.axes_dim[axis], self.theta) for axis in range(ids.shape[-1])]
        return torch.cat(per_axis, dim=-3).unsqueeze(1)
|
||||
|
||||
|
||||
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.

    :param t: a 1-D Tensor of N indices, one per batch element.
              These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: multiplier applied to ``t`` before the embedding is computed.
    :return: an (N, D) Tensor of positional embeddings.
    """
    scaled_t = time_factor * t
    half = dim // 2

    # Geometric frequency ladder, built in float32 and then moved to the device of t.
    exponents = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(exponents).to(scaled_t.device)

    angles = scaled_t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
    if dim % 2:
        # Odd output sizes get one trailing zero column.
        zero_column = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, zero_column], dim=-1)
    if torch.is_floating_point(scaled_t):
        # Match the dtype/device of the (scaled) input when it is floating point.
        embedding = embedding.to(scaled_t)
    return embedding
|
||||
|
||||
|
||||
class MLPEmbedder(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) mapping ``in_dim`` to ``hidden_dim``."""

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the input layer, the SiLU activation and the output layer in sequence."""
        hidden = self.in_layer(x)
        activated = self.silu(hidden)
        return self.out_layer(activated)
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned per-channel scale."""

    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        """Normalise ``x`` in float32, cast back to the input dtype, then multiply by ``scale``."""
        input_dtype = x.dtype
        x32 = x.float()
        mean_square = torch.mean(x32**2, dim=-1, keepdim=True)
        rrms = torch.rsqrt(mean_square + 1e-6)
        normalised = (x32 * rrms).to(dtype=input_dtype)
        return normalised * self.scale
|
||||
|
||||
|
||||
class QKNorm(torch.nn.Module):
    """Apply separate RMSNorm layers to queries and keys."""

    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        """Return normalised ``q`` and ``k``, each converted with ``.to(v)``; ``v`` itself is not returned."""
        normed_q = self.query_norm(q)
        normed_k = self.key_norm(k)
        return normed_q.to(v), normed_k.to(v)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
    """Multi-head self-attention with a fused qkv projection, QK normalisation and rotary embedding."""

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        """Project ``x`` to q/k/v, normalise q and k, attend with ``pe`` and apply the output projection."""
        fused = self.qkv(x)
        # Split the fused projection into its three parts and separate the heads.
        q, k, v = rearrange(fused, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        attended = attention(q, k, v, pe=pe)
        return self.proj(attended)
|
||||
|
||||
|
||||
@dataclass
class ModulationOut:
    """One (shift, scale, gate) triple produced by ``Modulation``."""

    # Consumers in this file compute (1 + scale) * normed + shift and multiply the
    # residual branch by gate.
    shift: Tensor
    scale: Tensor
    gate: Tensor
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, dim: int, double: bool):
|
||||
super().__init__()
|
||||
self.is_double = double
|
||||
self.multiplier = 6 if double else 3
|
||||
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
||||
|
||||
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
|
||||
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
||||
|
||||
return (
|
||||
ModulationOut(*out[:3]),
|
||||
ModulationOut(*out[3:]) if self.is_double else None,
|
||||
)
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
    """Transformer block with separate image and text weights but one joint attention call.

    Each stream has its own modulation, norms, qkv/proj layers and MLP; queries, keys and
    values of both streams are concatenated (text first) for the attention itself.
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size

        # Image stream.
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        # Text stream.
        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
        """Update both streams and return ``(img, txt)``."""
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        def _split_heads(qkv: Tensor):
            # Separate the fused projection into q/k/v and the heads.
            return rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)

        # prepare image for attention
        img_modulated = (1 + img_mod1.scale) * self.img_norm1(img) + img_mod1.shift
        img_q, img_k, img_v = _split_heads(self.img_attn.qkv(img_modulated))
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = (1 + txt_mod1.scale) * self.txt_norm1(txt) + txt_mod1.shift
        txt_q, txt_k, txt_v = _split_heads(self.txt_attn.qkv(txt_modulated))
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention over the joint (text, image) sequence
        joint_q = torch.cat((txt_q, img_q), dim=2)
        joint_k = torch.cat((txt_k, img_k), dim=2)
        joint_v = torch.cat((txt_v, img_v), dim=2)
        attn = attention(joint_q, joint_k, joint_v, pe=pe)

        # The first txt_len positions belong to the text stream, the rest to the image stream.
        txt_len = txt.shape[1]
        txt_attn, img_attn = attn[:, :txt_len], attn[:, txt_len:]

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img_mlp_in = (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        img = img + img_mod2.gate * self.img_mlp(img_mlp_in)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt_mlp_in = (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        txt = txt + txt_mod2.gate * self.txt_mlp(txt_mlp_in)
        return img, txt
|
||||
|
||||
|
||||
class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.

    One fused linear layer produces both qkv and the MLP input; a second fused layer consumes
    the attention output concatenated with the activated MLP stream.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        # Stored on the instance; forward() below does not read it.
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        """Return ``x`` plus the gated output of the fused attention/MLP branch."""
        mod, _ = self.modulation(vec)
        modulated = (1 + mod.scale) * self.pre_norm(x) + mod.shift

        # Split the fused projection into the qkv part and the MLP part.
        split_sizes = [3 * self.hidden_size, self.mlp_hidden_dim]
        qkv, mlp = torch.split(self.linear1(modulated), split_sizes, dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        fused = torch.cat((attn, self.mlp_act(mlp)), 2)
        return x + mod.gate * self.linear2(fused)
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
    """Final modulated LayerNorm followed by a linear projection to ``patch_size**2 * out_channels``."""

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        """Modulate the normalised ``x`` with shift/scale derived from ``vec`` and project it."""
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        # A size-1 dim is inserted at index 1 so shift/scale broadcast along dim 1 of x.
        modulated = (1 + scale.unsqueeze(1)) * self.norm_final(x) + shift.unsqueeze(1)
        return self.linear(modulated)
|
@ -67,7 +67,9 @@ class ModelType(str, Enum):
|
||||
TextualInversion = "embedding"
|
||||
IPAdapter = "ip_adapter"
|
||||
CLIPVision = "clip_vision"
|
||||
CLIPEmbed = "clip_embed"
|
||||
T2IAdapter = "t2i_adapter"
|
||||
T5Encoder = "t5_encoder"
|
||||
SpandrelImageToImage = "spandrel_image_to_image"
|
||||
|
||||
|
||||
@ -106,6 +108,9 @@ class ModelFormat(str, Enum):
|
||||
EmbeddingFile = "embedding_file"
|
||||
EmbeddingFolder = "embedding_folder"
|
||||
InvokeAI = "invokeai"
|
||||
T5Encoder = "t5_encoder"
|
||||
T5Encoder8b = "t5_encoder_8b"
|
||||
T5Encoder4b = "t5_encoder_4b"
|
||||
|
||||
|
||||
class SchedulerPredictionType(str, Enum):
|
||||
@ -207,6 +212,18 @@ class LoRAConfigBase(ModelConfigBase):
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
|
||||
|
||||
class T5EncoderConfigBase(ModelConfigBase):
    """Base model config for T5 encoder models."""

    type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder
|
||||
|
||||
|
||||
class T5EncoderConfig(T5EncoderConfigBase):
    """Model config for T5 encoder models in the T5Encoder format."""

    format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder

    @staticmethod
    def get_tag() -> Tag:
        """Return the discriminator tag, formed as ``<type value>.<format value>``."""
        model_type = ModelType.T5Encoder.value
        model_format = ModelFormat.T5Encoder.value
        return Tag(f"{model_type}.{model_format}")
|
||||
|
||||
|
||||
class LoRALyCORISConfig(LoRAConfigBase):
|
||||
"""Model config for LoRA/Lycoris models."""
|
||||
|
||||
@ -352,6 +369,17 @@ class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
|
||||
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
|
||||
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
    """Model config for Clip Embeddings."""

    type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

    @staticmethod
    def get_tag() -> Tag:
        """Return the discriminator tag, formed as ``<type value>.<format value>``."""
        model_type = ModelType.CLIPEmbed.value
        model_format = ModelFormat.Diffusers.value
        return Tag(f"{model_type}.{model_format}")
|
||||
|
||||
|
||||
class CLIPVisionDiffusersConfig(DiffusersConfigBase):
|
||||
"""Model config for CLIPVision."""
|
||||
|
||||
@ -416,6 +444,7 @@ AnyModelConfig = Annotated[
|
||||
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
|
||||
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
|
||||
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
|
||||
Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
|
||||
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
|
||||
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
|
||||
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
|
||||
@ -423,6 +452,7 @@ AnyModelConfig = Annotated[
|
||||
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
|
||||
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
|
||||
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
|
||||
Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
|
||||
],
|
||||
Discriminator(get_model_discriminator_value),
|
||||
]
|
||||
|
159
invokeai/backend/model_manager/load/model_loaders/flux.py
Normal file
159
invokeai/backend/model_manager/load/model_loaders/flux.py
Normal file
@ -0,0 +1,159 @@
|
||||
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
|
||||
"""Class for Flux model loading in InvokeAI."""
|
||||
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
|
||||
from dataclasses import fields
|
||||
from safetensors.torch import load_file
|
||||
from typing import Optional, Any
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
MainCheckpointConfig,
|
||||
CLIPEmbedDiffusersConfig,
|
||||
T5EncoderConfig,
|
||||
VAECheckpointConfig,
|
||||
)
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.flux.model import Flux, FluxParams
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams, AutoEncoder
|
||||
from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
|
||||
T5Tokenizer)
|
||||
|
||||
app_config = get_config()
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class FluxVAELoader(GenericDiffusersLoader):
    """Load a Flux VAE (``AutoEncoder``) from a single safetensors checkpoint."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        """Build the Flux autoencoder described by ``config`` and load its weights.

        Args:
            config: Model record. A ``VAECheckpointConfig`` is handled here; any
                other config is delegated to the parent loader.
            submodel_type: Only forwarded to the parent loader; it is not used on
                the checkpoint path.

        Returns:
            An ``AutoEncoder`` populated from ``config.path``, or whatever the
            parent loader returns for non-checkpoint configs.

        Raises:
            OSError: If the legacy YAML config file cannot be opened.
            yaml.YAMLError: If the legacy YAML config cannot be parsed.
            KeyError: If the parsed YAML mapping has no ``params.ae_params`` entry.
        """
        if not isinstance(config, VAECheckpointConfig):
            return super()._load_model(config, submodel_type)

        legacy_config_path = app_config.legacy_conf_path / config.config_path
        with open(legacy_config_path, "r", encoding="utf-8") as stream:
            flux_conf = yaml.safe_load(stream)

        # Pass through only the keys AutoEncoderParams declares, so extra
        # entries in the YAML do not break the dataclass constructor.
        known_fields = {f.name for f in fields(AutoEncoderParams)}
        ae_kwargs = {k: v for k, v in flux_conf["params"]["ae_params"].items() if k in known_fields}
        params = AutoEncoderParams(**ae_kwargs)

        with SilenceWarnings():
            model = AutoEncoder(params).to(self._torch_dtype)
            # The device is passed as a string because the safetensors loader
            # does not accept a torch.device.
            sd = load_file(Path(config.path), device=str(TorchDevice.choose_torch_device()))
            # strict=False tolerates missing/unexpected keys; assign=True adopts
            # the loaded tensors instead of copying into the existing parameters.
            model.load_state_dict(sd, strict=False, assign=True)

        return model
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
|
||||
class ClipCheckpointModel(GenericDiffusersLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, CLIPEmbedDiffusersConfig):
|
||||
raise Exception("Only Checkpoint Flux models are currently supported.")
|
||||
|
||||
match submodel_type:
|
||||
case SubModelType.Tokenizer:
|
||||
return CLIPTokenizer.from_pretrained(config.path, max_length=77)
|
||||
case SubModelType.TextEncoder:
|
||||
return CLIPTextModel.from_pretrained(config.path)
|
||||
|
||||
raise Exception("Only Checkpoint Flux models are currently supported.")
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
|
||||
class T5EncoderCheckpointModel(GenericDiffusersLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, T5EncoderConfig):
|
||||
raise Exception("Only Checkpoint Flux models are currently supported.")
|
||||
|
||||
match submodel_type:
|
||||
case SubModelType.Tokenizer2:
|
||||
return T5Tokenizer.from_pretrained(Path(config.path), max_length=512)
|
||||
case SubModelType.TextEncoder2:
|
||||
return T5EncoderModel.from_pretrained(Path(config.path))
|
||||
|
||||
raise Exception("Only Checkpoint Flux models are currently supported.")
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
class FluxCheckpointModel(GenericDiffusersLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, CheckpointConfigBase):
|
||||
raise Exception("Only Checkpoint Flux models are currently supported.")
|
||||
legacy_config_path = app_config.legacy_conf_path / config.config_path
|
||||
config_path = legacy_config_path.as_posix()
|
||||
with open(config_path, "r") as stream:
|
||||
try:
|
||||
flux_conf = yaml.safe_load(stream)
|
||||
except:
|
||||
raise
|
||||
|
||||
match submodel_type:
|
||||
case SubModelType.Transformer:
|
||||
return self._load_from_singlefile(config, flux_conf)
|
||||
|
||||
raise Exception("Only Checkpoint Flux models are currently supported.")
|
||||
|
||||
def _load_from_singlefile(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
flux_conf: Any,
|
||||
) -> AnyModel:
|
||||
assert isinstance(config, MainCheckpointConfig)
|
||||
load_class = Flux
|
||||
params = None
|
||||
model_path = Path(config.path)
|
||||
dataclass_fields = {f.name for f in fields(FluxParams)}
|
||||
filtered_data = {k: v for k, v in flux_conf['params'].items() if k in dataclass_fields}
|
||||
params = FluxParams(**filtered_data)
|
||||
|
||||
with SilenceWarnings():
|
||||
model = load_class(params).to(self._torch_dtype)
|
||||
# load_sft doesn't support torch.device
|
||||
sd = load_file(model_path, device=str(TorchDevice.choose_torch_device()))
|
||||
model.load_state_dict(sd, strict=False, assign=True)
|
||||
return model
|
@ -36,8 +36,14 @@ VARIANT_TO_IN_CHANNEL_MAP = {
|
||||
}
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
|
@ -9,7 +9,7 @@ from typing import Optional
|
||||
import torch
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast, T5Tokenizer
|
||||
|
||||
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
||||
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
||||
@ -52,7 +52,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
|
||||
return model.calc_size()
|
||||
elif isinstance(
|
||||
model,
|
||||
(T5TokenizerFast,),
|
||||
(T5TokenizerFast,T5Tokenizer,),
|
||||
):
|
||||
return len(model)
|
||||
else:
|
||||
|
@ -56,7 +56,7 @@ LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[Sched
|
||||
},
|
||||
BaseModelType.StableDiffusionXLRefiner: {
|
||||
ModelVariantType.Normal: "sd_xl_refiner.yaml",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -132,7 +132,7 @@ class ModelProbe(object):
|
||||
fields = {}
|
||||
|
||||
model_path = model_path.resolve()
|
||||
|
||||
|
||||
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
||||
model_info = None
|
||||
model_type = ModelType(fields["type"]) if "type" in fields and fields["type"] else None
|
||||
@ -162,7 +162,7 @@ class ModelProbe(object):
|
||||
fields["description"] = (
|
||||
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
|
||||
)
|
||||
fields["format"] = fields.get("format") or probe.get_format()
|
||||
fields["format"] = ModelFormat(fields.get("format")) or probe.get_format()
|
||||
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
|
||||
|
||||
fields["default_settings"] = fields.get("default_settings")
|
||||
@ -223,7 +223,7 @@ class ModelProbe(object):
|
||||
ckpt = ckpt.get("state_dict", ckpt)
|
||||
|
||||
for key in [str(k) for k in ckpt.keys()]:
|
||||
if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")):
|
||||
if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.", "double_blocks.")):
|
||||
return ModelType.Main
|
||||
elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
|
||||
return ModelType.VAE
|
||||
@ -322,10 +322,13 @@ class ModelProbe(object):
|
||||
return possible_conf.absolute()
|
||||
|
||||
if model_type is ModelType.Main:
|
||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
config_file = config_file[prediction_type]
|
||||
config_file = f"stable-diffusion/{config_file}"
|
||||
if base_type == BaseModelType.Flux:
|
||||
config_file="flux/flux1-schnell.yaml"
|
||||
else:
|
||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
config_file = config_file[prediction_type]
|
||||
config_file = f"stable-diffusion/{config_file}"
|
||||
elif model_type is ModelType.ControlNet:
|
||||
config_file = (
|
||||
"controlnet/cldm_v15.yaml"
|
||||
@ -334,7 +337,9 @@ class ModelProbe(object):
|
||||
)
|
||||
elif model_type is ModelType.VAE:
|
||||
config_file = (
|
||||
"stable-diffusion/v1-inference.yaml"
|
||||
"flux/flux1-schnell.yaml"
|
||||
if base_type is BaseModelType.Flux
|
||||
else "stable-diffusion/v1-inference.yaml"
|
||||
if base_type is BaseModelType.StableDiffusion1
|
||||
else "stable-diffusion/sd_xl_base.yaml"
|
||||
if base_type is BaseModelType.StableDiffusionXL
|
||||
@ -421,7 +426,8 @@ class CheckpointProbeBase(ProbeBase):
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
|
||||
if model_type != ModelType.Main:
|
||||
base_type = self.get_base_type()
|
||||
if model_type != ModelType.Main or base_type == BaseModelType.Flux:
|
||||
return ModelVariantType.Normal
|
||||
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
||||
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
@ -441,6 +447,8 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
state_dict = self.checkpoint.get("state_dict") or checkpoint
|
||||
if "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
|
||||
return BaseModelType.Flux
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
@ -483,6 +491,7 @@ class VaeCheckpointProbe(CheckpointProbeBase):
|
||||
(r"xl", BaseModelType.StableDiffusionXL),
|
||||
(r"sd2", BaseModelType.StableDiffusion2),
|
||||
(r"vae", BaseModelType.StableDiffusion1),
|
||||
(r"FLUX.1-schnell_ae", BaseModelType.Flux),
|
||||
]:
|
||||
if re.search(regexp, self.model_path.name, re.IGNORECASE):
|
||||
return basetype
|
||||
@ -627,10 +636,6 @@ class FolderProbeBase(ProbeBase):
|
||||
|
||||
class PipelineFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
with open(f"{self.model_path}/model_index.json", "r") as file:
|
||||
conf = json.load(file)
|
||||
if "_class_name" in conf and conf.get("_class_name") == "FluxPipeline":
|
||||
return BaseModelType.Flux
|
||||
with open(self.model_path / "unet" / "config.json", "r") as file:
|
||||
unet_conf = json.load(file)
|
||||
if unet_conf["cross_attention_dim"] == 768:
|
||||
@ -718,6 +723,10 @@ class TextualInversionFolderProbe(FolderProbeBase):
|
||||
return TextualInversionCheckpointProbe(path).get_base_type()
|
||||
|
||||
|
||||
class T5EncoderFolderProbe(FolderProbeBase):
    """Folder probe for T5 encoder models."""

    def get_format(self) -> ModelFormat:
        """Return the dedicated T5 encoder model format."""
        return ModelFormat.T5Encoder
|
||||
|
||||
class ONNXFolderProbe(PipelineFolderProbe):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
# Due to the way the installer is set up, the configuration file for safetensors
|
||||
@ -810,6 +819,11 @@ class CLIPVisionFolderProbe(FolderProbeBase):
|
||||
return BaseModelType.Any
|
||||
|
||||
|
||||
class CLIPEmbedFolderProbe(FolderProbeBase):
    """Folder probe for CLIP embed models."""

    def get_base_type(self) -> BaseModelType:
        """Return ``BaseModelType.Any``: the probe does not tie the model to a base."""
        return BaseModelType.Any
|
||||
|
||||
|
||||
class SpandrelImageToImageFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
raise NotImplementedError()
|
||||
@ -840,8 +854,10 @@ ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.T5Encoder, T5EncoderFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.CLIPEmbed, CLIPEmbedFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
|
||||
|
Reference in New Issue
Block a user