Run ruff, setup initial text to image node

Brandon Rising
2024-08-19 10:14:58 -04:00
committed by Brandon
parent 436f18ff55
commit 1bd90e0fd4
15 changed files with 291 additions and 124 deletions

View File

@@ -27,4 +27,4 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tenso
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
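For context on what this helper does (the change above only adds a trailing newline): apply_rope treats every consecutive channel pair as a 2-vector and rotates it by the per-position angles packed into freqs_cis. A minimal sanity check of that reading, assuming the [[cos, -sin], [sin, cos]] per-pair layout of the reference Flux rope() helper, which is not shown in this diff:

import math

import torch

# One (even, odd) channel pair and one assumed 2x2 rotation block from freqs_cis.
theta = 0.3
x = torch.tensor([1.0, 2.0])
rot = torch.tensor([[math.cos(theta), -math.sin(theta)],
                    [math.sin(theta),  math.cos(theta)]])

# apply_rope's elementwise form: freqs_cis[..., 0] * x[0] + freqs_cis[..., 1] * x[1]
out = rot[:, 0] * x[0] + rot[:, 1] * x[1]
assert torch.allclose(out, rot @ x)  # identical to a plain 2D rotation of the pair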

View File

@@ -3,9 +3,15 @@ from dataclasses import dataclass
import torch
from torch import Tensor, nn
from invokeai.backend.flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
MLPEmbedder, SingleStreamBlock,
timestep_embedding)
from invokeai.backend.flux.modules.layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
@dataclass
class FluxParams:
@@ -35,9 +41,7 @@ class Flux(nn.Module):
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
@@ -108,4 +112,4 @@ class Flux(nn.Module):
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
return img
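The two ValueError checks above pin down how hidden_size, num_heads and axes_dim must relate. A worked check with the dimensions published for FLUX.1 (assumed here for illustration; they are not part of this diff):

# hidden_size / num_heads gives the per-head positional dim, which axes_dim must sum to.
hidden_size, num_heads = 3072, 24
axes_dim = (16, 56, 56)
pe_dim = hidden_size // num_heads        # 128
assert hidden_size % num_heads == 0      # first check passes
assert sum(axes_dim) == pe_dim           # 16 + 56 + 56 == 128, second check passes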

View File

@@ -309,4 +309,4 @@ class AutoEncoder(nn.Module):
return self.decoder(z)
def forward(self, x: Tensor) -> Tensor:
return self.decode(self.encode(x))
return self.decode(self.encode(x))

View File

@@ -1,5 +1,6 @@
from torch import Tensor, nn
from transformers import (PreTrainedModel, PreTrainedTokenizer)
from transformers import PreTrainedModel, PreTrainedTokenizer
class HFEncoder(nn.Module):
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
@@ -27,4 +28,4 @@ class HFEncoder(nn.Module):
attention_mask=None,
output_hidden_states=False,
)
return outputs[self.output_key]
return outputs[self.output_key]

View File

@@ -36,9 +36,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
"""
t = time_factor * t
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
t.device
)
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
@@ -250,4 +248,4 @@ class LastLayer(nn.Module):
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x
return x

View File

@@ -0,0 +1,134 @@
import math
from typing import Callable
import torch
from einops import rearrange, repeat
from torch import Tensor
from .model import Flux
from .modules.conditioner import HFEncoder
def get_noise(
num_samples: int,
height: int,
width: int,
device: torch.device,
dtype: torch.dtype,
seed: int,
):
return torch.randn(
num_samples,
16,
# allow for packing
2 * math.ceil(height / 16),
2 * math.ceil(width / 16),
device=device,
dtype=dtype,
generator=torch.Generator(device=device).manual_seed(seed),
)
def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
bs, c, h, w = img.shape
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
if isinstance(prompt, str):
prompt = [prompt]
txt = t5(prompt)
if txt.shape[0] == 1 and bs > 1:
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
txt_ids = torch.zeros(bs, txt.shape[1], 3)
vec = clip(prompt)
if vec.shape[0] == 1 and bs > 1:
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
return {
"img": img,
"img_ids": img_ids.to(img.device),
"txt": txt.to(img.device),
"txt_ids": txt_ids.to(img.device),
"vec": vec.to(img.device),
}
def time_shift(mu: float, sigma: float, t: Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# estimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
def denoise(
model: Flux,
# model input
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
vec: Tensor,
# sampling parameters
timesteps: list[float],
guidance: float = 4.0,
):
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:], strict=False):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
)
img = img + (t_prev - t_curr) * pred
return img
def unpack(x: Tensor, height: int, width: int) -> Tensor:
return rearrange(
x,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(height / 16),
w=math.ceil(width / 16),
ph=2,
pw=2,
)
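Taken together, the helpers in this new file form the sampling loop behind the text-to-image node this commit starts setting up. A hedged sketch of how they compose; model, ae, t5 and clip are assumed to be loaded elsewhere, and the defaults below are illustrative rather than taken from this diff:

import torch

def generate_sketch(model, ae, t5, clip, prompt: str, height: int = 1024, width: int = 1024,
                    num_steps: int = 4, guidance: float = 4.0, seed: int = 0) -> torch.Tensor:
    device, dtype = torch.device("cuda"), torch.bfloat16
    x = get_noise(1, height, width, device=device, dtype=dtype, seed=seed)
    inp = prepare(t5, clip, x, prompt)
    # image_seq_len is the packed ceil(h/16) * ceil(w/16) token count produced by prepare().
    timesteps = get_schedule(num_steps, inp["img"].shape[1], shift=True)
    x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)
    x = unpack(x.float(), height, width)
    with torch.no_grad():
        return ae.decode(x)  # AutoEncoder.decode maps latents back to pixel space

When shift is enabled, get_schedule interpolates mu linearly between (256, base_shift) and (4096, max_shift) as a function of the packed image sequence length, so larger images keep more of their steps at high timesteps, matching the comment in the code above.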

View File

@@ -1,14 +1,17 @@
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
"""Class for Flux model loading in InvokeAI."""
from pathlib import Path
import yaml
from dataclasses import fields
from safetensors.torch import load_file
from typing import Optional, Any
from transformers import T5EncoderModel, T5Tokenizer
from pathlib import Path
from typing import Any, Optional
import yaml
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.model import Flux, FluxParams
from invokeai.backend.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
@@ -19,20 +22,15 @@ from invokeai.backend.model_manager import (
)
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
MainCheckpointConfig,
CLIPEmbedDiffusersConfig,
MainCheckpointConfig,
T5EncoderConfig,
VAECheckpointConfig,
)
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.silence_warnings import SilenceWarnings
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.flux.model import Flux, FluxParams
from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams, AutoEncoder
from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
T5Tokenizer)
from invokeai.backend.util.silence_warnings import SilenceWarnings
app_config = get_config()
@@ -56,9 +54,9 @@ class FluxVAELoader(GenericDiffusersLoader):
flux_conf = yaml.safe_load(stream)
except:
raise
dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
filtered_data = {k: v for k, v in flux_conf['params']['ae_params'].items() if k in dataclass_fields}
filtered_data = {k: v for k, v in flux_conf["params"]["ae_params"].items() if k in dataclass_fields}
params = AutoEncoderParams(**filtered_data)
with SilenceWarnings():
@@ -92,6 +90,7 @@ class ClipCheckpointModel(GenericDiffusersLoader):
raise Exception("Only Checkpoint Flux models are currently supported.")
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
class T5EncoderCheckpointModel(GenericDiffusersLoader):
"""Class to load main models."""
@@ -106,9 +105,9 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader):
match submodel_type:
case SubModelType.Tokenizer2:
return T5Tokenizer.from_pretrained(Path(config.path), max_length=512)
return T5Tokenizer.from_pretrained(Path(config.path) / "encoder", max_length=512)
case SubModelType.TextEncoder2:
return T5EncoderModel.from_pretrained(Path(config.path))
return T5EncoderModel.from_pretrained(Path(config.path) / "tokenizer")
raise Exception("Only Checkpoint Flux models are currently supported.")
@@ -148,7 +147,7 @@ class FluxCheckpointModel(GenericDiffusersLoader):
params = None
model_path = Path(config.path)
dataclass_fields = {f.name for f in fields(FluxParams)}
filtered_data = {k: v for k, v in flux_conf['params'].items() if k in dataclass_fields}
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
params = FluxParams(**filtered_data)
with SilenceWarnings():
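Both loaders shown above hydrate a params dataclass from the checkpoint's YAML by keeping only the keys that are actual dataclass fields, so unknown config entries are silently dropped. The idiom in isolation, with made-up field names for illustration:

from dataclasses import dataclass, fields

@dataclass
class ExampleParams:
    resolution: int = 256
    z_channels: int = 16

raw_conf = {"resolution": 1024, "z_channels": 16, "unused_key": True}
allowed = {f.name for f in fields(ExampleParams)}
params = ExampleParams(**{k: v for k, v in raw_conf.items() if k in allowed})
# params == ExampleParams(resolution=1024, z_channels=16); "unused_key" is ignored.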

View File

@@ -39,11 +39,15 @@ VARIANT_TO_IN_CHANNEL_MAP = {
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint
)
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
"""Class to load main models."""

View File

@@ -9,7 +9,7 @@ from typing import Optional
import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from transformers import CLIPTokenizer, T5TokenizerFast, T5Tokenizer
from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
@@ -52,7 +52,10 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
return model.calc_size()
elif isinstance(
model,
(T5TokenizerFast,T5Tokenizer,),
(
T5TokenizerFast,
T5Tokenizer,
),
):
return len(model)
else:

View File

@@ -56,7 +56,7 @@ LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[Sched
},
BaseModelType.StableDiffusionXLRefiner: {
ModelVariantType.Normal: "sd_xl_refiner.yaml",
}
},
}
@@ -132,7 +132,7 @@ class ModelProbe(object):
fields = {}
model_path = model_path.resolve()
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
model_info = None
model_type = ModelType(fields["type"]) if "type" in fields and fields["type"] else None
@@ -323,7 +323,7 @@ class ModelProbe(object):
if model_type is ModelType.Main:
if base_type == BaseModelType.Flux:
config_file="flux/flux1-schnell.yaml"
config_file = "flux/flux1-schnell.yaml"
else:
config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
@@ -727,6 +727,7 @@ class T5EncoderFolderProbe(FolderProbeBase):
def get_format(self) -> ModelFormat:
return ModelFormat.T5Encoder
class ONNXFolderProbe(PipelineFolderProbe):
def get_base_type(self) -> BaseModelType:
# Due to the way the installer is set up, the configuration file for safetensors