Run ruff, setup initial text to image node

This commit is contained in:
Brandon Rising 2024-08-19 10:14:58 -04:00 committed by Brandon
parent 436f18ff55
commit 1bd90e0fd4
15 changed files with 291 additions and 124 deletions

View File

@ -1,8 +1,4 @@
import torch import torch
from einops import repeat
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
@ -10,9 +6,8 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import ConditioningOutput from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.flux.modules.conditioner import HFEncoder from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
@invocation( @invocation(

View File

@ -1,12 +1,6 @@
from typing import Literal
import accelerate
import torch import torch
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel from einops import rearrange, repeat
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from PIL import Image from PIL import Image
from safetensors.torch import load_file
from transformers.models.auto import AutoModelForTextEncoding
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
@ -20,23 +14,12 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import TransformerField, VAEField from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 from invokeai.backend.flux.model import Flux
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, unpack
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
TFluxModelKeys = Literal["flux-schnell"]
FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
class QuantizedFluxTransformer2DModel(FastQuantizedDiffusersModel):
base_class = FluxTransformer2DModel
class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
auto_class = AutoModelForTextEncoding
@invocation( @invocation(
"flux_text_to_image", "flux_text_to_image",
@ -78,7 +61,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
assert isinstance(flux_conditioning, FLUXConditioningInfo) assert isinstance(flux_conditioning, FLUXConditioningInfo)
latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds) latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
image = self._run_vae_decoding(context, latents) image = self._run_vae_decoding(context, flux_ae_path, latents)
image_dto = context.images.save(image=image) image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto) return ImageOutput.build(image_dto)
@ -89,14 +72,40 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
t5_embeddings: torch.Tensor, t5_embeddings: torch.Tensor,
): ):
transformer_info = context.models.load(self.transformer.transformer) transformer_info = context.models.load(self.transformer.transformer)
inference_dtype = TorchDevice.choose_torch_dtype()
# Prepare input noise.
# TODO(ryand): Does the seed behave the same on different devices? Should we re-implement this to always use a
# CPU RNG?
x = get_noise(
num_samples=1,
height=self.height,
width=self.width,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
seed=self.seed,
)
img, img_ids = self._prepare_latent_img_patches(x)
# HACK(ryand): Find a better way to determine if this is a schnell model or not.
is_schnell = "shnell" in transformer_info.config.path if transformer_info.config else ""
timesteps = get_schedule(
num_steps=self.num_steps,
image_seq_len=img.shape[1],
shift=not is_schnell,
)
bs, t5_seq_len, _ = t5_embeddings.shape
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
# if the cache is not empty. # if the cache is not empty.
# context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30) context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
with transformer_info as transformer: with transformer_info as transformer:
assert isinstance(transformer, FluxTransformer2DModel) assert isinstance(transformer, Flux)
x = denoise( x = denoise(
model=transformer, model=transformer,
@ -144,21 +153,13 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
) -> Image.Image: ) -> Image.Image:
vae_info = context.models.load(self.vae.vae) vae_info = context.models.load(self.vae.vae)
with vae_info as vae: with vae_info as vae:
assert isinstance(vae, AutoencoderKL) assert isinstance(vae, AutoEncoder)
# TODO(ryand): Test that this works with both float16 and bfloat16.
with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
img = vae.decode(latents)
img.clamp(-1, 1) img.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c") img = rearrange(img[0], "c h w -> h w c")
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy()) img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
latents = flux_pipeline_with_vae._unpack_latents( return img_pil
latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
)
latents = (
latents / flux_pipeline_with_vae.vae.config.scaling_factor
) + flux_pipeline_with_vae.vae.config.shift_factor
latents = latents.to(dtype=vae.dtype)
image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]
assert isinstance(image, Image.Image)
return image

View File

@ -1,6 +1,6 @@
import copy import copy
from time import sleep from time import sleep
from typing import List, Optional, Literal, Dict from typing import Dict, List, Literal, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -12,10 +12,10 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output, invocation_output,
) )
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.model_records import ModelRecordChanges
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig from invokeai.app.shared.models import FreeUConfig
from invokeai.app.services.model_records import ModelRecordChanges from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType, ModelFormat
class ModelIdentifierField(BaseModel): class ModelIdentifierField(BaseModel):
@ -132,31 +132,22 @@ class ModelIdentifierInvocation(BaseInvocation):
return ModelIdentifierOutput(model=self.model) return ModelIdentifierOutput(model=self.model)
T5_ENCODER_OPTIONS = Literal["base", "16b_quantized", "8b_quantized"]
T5_ENCODER_OPTIONS = Literal["base", "8b_quantized"]
T5_ENCODER_MAP: Dict[str, Dict[str, str]] = { T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
"base": { "base": {
"text_encoder_repo": "black-forest-labs/FLUX.1-schnell::text_encoder_2", "repo": "invokeai/flux_dev::t5_xxl_encoder/base",
"tokenizer_repo": "black-forest-labs/FLUX.1-schnell::tokenizer_2", "name": "t5_base_encoder",
"text_encoder_name": "FLUX.1-schnell_text_encoder_2",
"tokenizer_name": "FLUX.1-schnell_tokenizer_2",
"format": ModelFormat.T5Encoder, "format": ModelFormat.T5Encoder,
}, },
"8b_quantized": { "8b_quantized": {
"text_encoder_repo": "hf_repo1", "repo": "invokeai/flux_dev::t5_xxl_encoder/8b_quantized",
"tokenizer_repo": "hf_repo1", "name": "t5_8b_quantized_encoder",
"text_encoder_name": "hf_repo1", "format": ModelFormat.T5Encoder,
"tokenizer_name": "hf_repo1",
"format": ModelFormat.T5Encoder8b,
},
"4b_quantized": {
"text_encoder_repo": "hf_repo2",
"tokenizer_repo": "hf_repo2",
"text_encoder_name": "hf_repo2",
"tokenizer_name": "hf_repo2",
"format": ModelFormat.T5Encoder8b,
}, },
} }
@invocation_output("flux_model_loader_output") @invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput): class FluxModelLoaderOutput(BaseInvocationOutput):
"""Flux base model loader output""" """Flux base model loader output"""
@ -176,7 +167,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
ui_type=UIType.FluxMainModel, ui_type=UIType.FluxMainModel,
input=Input.Direct, input=Input.Direct,
) )
t5_encoder: T5_ENCODER_OPTIONS = InputField(description="The T5 Encoder model to use.") t5_encoder: T5_ENCODER_OPTIONS = InputField(description="The T5 Encoder model to use.")
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput: def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
@ -189,7 +180,15 @@ class FluxModelLoaderInvocation(BaseInvocation):
tokenizer2 = self._get_model(context, SubModelType.Tokenizer2) tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
clip_encoder = self._get_model(context, SubModelType.TextEncoder) clip_encoder = self._get_model(context, SubModelType.TextEncoder)
t5_encoder = self._get_model(context, SubModelType.TextEncoder2) t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
vae = self._install_model(context, SubModelType.VAE, "FLUX.1-schnell_ae", "black-forest-labs/FLUX.1-schnell::ae.safetensors", ModelFormat.Checkpoint, ModelType.VAE, BaseModelType.Flux) vae = self._install_model(
context,
SubModelType.VAE,
"FLUX.1-schnell_ae",
"black-forest-labs/FLUX.1-schnell::ae.safetensors",
ModelFormat.Checkpoint,
ModelType.VAE,
BaseModelType.Flux,
)
return FluxModelLoaderOutput( return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer), transformer=TransformerField(transformer=transformer),
@ -198,33 +197,59 @@ class FluxModelLoaderInvocation(BaseInvocation):
vae=VAEField(vae=vae), vae=VAEField(vae=vae),
) )
def _get_model(self, context: InvocationContext, submodel:SubModelType) -> ModelIdentifierField: def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
match(submodel): match submodel:
case SubModelType.Transformer: case SubModelType.Transformer:
return self.model.model_copy(update={"submodel_type": SubModelType.Transformer}) return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]: case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
return self._install_model(context, submodel, "clip-vit-large-patch14", "openai/clip-vit-large-patch14", ModelFormat.Diffusers, ModelType.CLIPEmbed, BaseModelType.Any) return self._install_model(
case SubModelType.TextEncoder2: context,
return self._install_model(context, submodel, T5_ENCODER_MAP[self.t5_encoder]["text_encoder_name"], T5_ENCODER_MAP[self.t5_encoder]["text_encoder_repo"], ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]), ModelType.T5Encoder, BaseModelType.Any) submodel,
case SubModelType.Tokenizer2: "clip-vit-large-patch14",
return self._install_model(context, submodel, T5_ENCODER_MAP[self.t5_encoder]["tokenizer_name"], T5_ENCODER_MAP[self.t5_encoder]["tokenizer_repo"], ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]), ModelType.T5Encoder, BaseModelType.Any) "openai/clip-vit-large-patch14",
ModelFormat.Diffusers,
ModelType.CLIPEmbed,
BaseModelType.Any,
)
case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]:
return self._install_model(
context,
submodel,
T5_ENCODER_MAP[self.t5_encoder]["name"],
T5_ENCODER_MAP[self.t5_encoder]["repo"],
ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]),
ModelType.T5Encoder,
BaseModelType.Any,
)
case _: case _:
raise Exception(f"{submodel.value} is not a supported submodule for a flux model") raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
def _install_model(self, context: InvocationContext, submodel:SubModelType, name: str, repo_id: str, format: ModelFormat, type: ModelType, base: BaseModelType): def _install_model(
if (models := context.models.search_by_attrs(name=name, base=base, type=type)): self,
context: InvocationContext,
submodel: SubModelType,
name: str,
repo_id: str,
format: ModelFormat,
type: ModelType,
base: BaseModelType,
):
if models := context.models.search_by_attrs(name=name, base=base, type=type):
if len(models) != 1: if len(models) != 1:
raise Exception(f"Multiple models detected for selected model with name {name}") raise Exception(f"Multiple models detected for selected model with name {name}")
return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel}) return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
else: else:
model_path = context.models.download_and_cache_model(repo_id) model_path = context.models.download_and_cache_model(repo_id)
config = ModelRecordChanges(name = name, base = base, type=type, format=format) config = ModelRecordChanges(name=name, base=base, type=type, format=format)
model_install_job = context.models.import_local_model(model_path=model_path, config=config) model_install_job = context.models.import_local_model(model_path=model_path, config=config)
while not model_install_job.in_terminal_state: while not model_install_job.in_terminal_state:
sleep(0.01) sleep(0.01)
if not model_install_job.config_out: if not model_install_job.config_out:
raise Exception(f"Failed to install {name}") raise Exception(f"Failed to install {name}")
return ModelIdentifierField.from_config(model_install_job.config_out).model_copy(update={"submodel_type": submodel}) return ModelIdentifierField.from_config(model_install_job.config_out).model_copy(
update={"submodel_type": submodel}
)
@invocation( @invocation(
"main_model_loader", "main_model_loader",

View File

@ -301,7 +301,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
for row in result: for row in result:
try: try:
model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1]) model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1])
except pydantic.ValidationError as e: except pydantic.ValidationError:
# We catch this error so that the app can still run if there are invalid model configs in the database. # We catch this error so that the app can still run if there are invalid model configs in the database.
# One reason that an invalid model config might be in the database is if someone had to rollback from a # One reason that an invalid model config might be in the database is if someone had to rollback from a
# newer version of the app that added a new model type. # newer version of the app that added a new model type.

View File

@ -465,18 +465,20 @@ class ModelsInterface(InvocationContextInterface):
return self._services.model_manager.install.download_and_cache_model(source=source) return self._services.model_manager.install.download_and_cache_model(source=source)
def import_local_model( def import_local_model(
self, self,
model_path: Path, model_path: Path,
config: Optional[ModelRecordChanges] = None, config: Optional[ModelRecordChanges] = None,
access_token: Optional[str] = None, access_token: Optional[str] = None,
inplace: Optional[bool] = False, inplace: Optional[bool] = False,
): ):
""" """
TODO: Fill out description of this method TODO: Fill out description of this method
""" """
if not model_path.exists(): if not model_path.exists():
raise Exception("Models provided to import_local_model must already exist on disk") raise Exception("Models provided to import_local_model must already exist on disk")
return self._services.model_manager.install.heuristic_import(str(model_path), config=config, access_token=access_token, inplace=inplace) return self._services.model_manager.install.heuristic_import(
str(model_path), config=config, access_token=access_token, inplace=inplace
)
def load_local_model( def load_local_model(
self, self,

View File

@ -27,4 +27,4 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tenso
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

View File

@ -3,9 +3,15 @@ from dataclasses import dataclass
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
from invokeai.backend.flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, from invokeai.backend.flux.modules.layers import (
MLPEmbedder, SingleStreamBlock, DoubleStreamBlock,
timestep_embedding) EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
@dataclass @dataclass
class FluxParams: class FluxParams:
@ -35,9 +41,7 @@ class Flux(nn.Module):
self.in_channels = params.in_channels self.in_channels = params.in_channels
self.out_channels = self.in_channels self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0: if params.hidden_size % params.num_heads != 0:
raise ValueError( raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
pe_dim = params.hidden_size // params.num_heads pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim: if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
@ -108,4 +112,4 @@ class Flux(nn.Module):
img = img[:, txt.shape[1] :, ...] img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img return img

View File

@ -309,4 +309,4 @@ class AutoEncoder(nn.Module):
return self.decoder(z) return self.decoder(z)
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
return self.decode(self.encode(x)) return self.decode(self.encode(x))

View File

@ -1,5 +1,6 @@
from torch import Tensor, nn from torch import Tensor, nn
from transformers import (PreTrainedModel, PreTrainedTokenizer) from transformers import PreTrainedModel, PreTrainedTokenizer
class HFEncoder(nn.Module): class HFEncoder(nn.Module):
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int): def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
@ -27,4 +28,4 @@ class HFEncoder(nn.Module):
attention_mask=None, attention_mask=None,
output_hidden_states=False, output_hidden_states=False,
) )
return outputs[self.output_key] return outputs[self.output_key]

View File

@ -36,9 +36,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
""" """
t = time_factor * t t = time_factor * t
half = dim // 2 half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
t.device
)
args = t[:, None].float() * freqs[None] args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
@ -250,4 +248,4 @@ class LastLayer(nn.Module):
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x) x = self.linear(x)
return x return x

View File

@ -0,0 +1,134 @@
import math
from typing import Callable
import torch
from einops import rearrange, repeat
from torch import Tensor
from .model import Flux
from .modules.conditioner import HFEncoder
def get_noise(
num_samples: int,
height: int,
width: int,
device: torch.device,
dtype: torch.dtype,
seed: int,
):
return torch.randn(
num_samples,
16,
# allow for packing
2 * math.ceil(height / 16),
2 * math.ceil(width / 16),
device=device,
dtype=dtype,
generator=torch.Generator(device=device).manual_seed(seed),
)
def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
bs, c, h, w = img.shape
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
if isinstance(prompt, str):
prompt = [prompt]
txt = t5(prompt)
if txt.shape[0] == 1 and bs > 1:
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
txt_ids = torch.zeros(bs, txt.shape[1], 3)
vec = clip(prompt)
if vec.shape[0] == 1 and bs > 1:
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
return {
"img": img,
"img_ids": img_ids.to(img.device),
"txt": txt.to(img.device),
"txt_ids": txt_ids.to(img.device),
"vec": vec.to(img.device),
}
def time_shift(mu: float, sigma: float, t: Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# eastimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
def denoise(
model: Flux,
# model input
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
vec: Tensor,
# sampling parameters
timesteps: list[float],
guidance: float = 4.0,
):
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:], strict=False):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
)
img = img + (t_prev - t_curr) * pred
return img
def unpack(x: Tensor, height: int, width: int) -> Tensor:
return rearrange(
x,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(height / 16),
w=math.ceil(width / 16),
ph=2,
pw=2,
)

View File

@ -1,14 +1,17 @@
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team # Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
"""Class for Flux model loading in InvokeAI.""" """Class for Flux model loading in InvokeAI."""
from pathlib import Path
import yaml
from dataclasses import fields from dataclasses import fields
from safetensors.torch import load_file from pathlib import Path
from typing import Optional, Any from typing import Any, Optional
from transformers import T5EncoderModel, T5Tokenizer
import yaml
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.model import Flux, FluxParams
from invokeai.backend.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
from invokeai.backend.model_manager import ( from invokeai.backend.model_manager import (
AnyModel, AnyModel,
AnyModelConfig, AnyModelConfig,
@ -19,20 +22,15 @@ from invokeai.backend.model_manager import (
) )
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
CheckpointConfigBase, CheckpointConfigBase,
MainCheckpointConfig,
CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig,
MainCheckpointConfig,
T5EncoderConfig, T5EncoderConfig,
VAECheckpointConfig, VAECheckpointConfig,
) )
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.silence_warnings import SilenceWarnings
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.flux.model import Flux, FluxParams from invokeai.backend.util.silence_warnings import SilenceWarnings
from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams, AutoEncoder
from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
T5Tokenizer)
app_config = get_config() app_config = get_config()
@ -56,9 +54,9 @@ class FluxVAELoader(GenericDiffusersLoader):
flux_conf = yaml.safe_load(stream) flux_conf = yaml.safe_load(stream)
except: except:
raise raise
dataclass_fields = {f.name for f in fields(AutoEncoderParams)} dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
filtered_data = {k: v for k, v in flux_conf['params']['ae_params'].items() if k in dataclass_fields} filtered_data = {k: v for k, v in flux_conf["params"]["ae_params"].items() if k in dataclass_fields}
params = AutoEncoderParams(**filtered_data) params = AutoEncoderParams(**filtered_data)
with SilenceWarnings(): with SilenceWarnings():
@ -92,6 +90,7 @@ class ClipCheckpointModel(GenericDiffusersLoader):
raise Exception("Only Checkpoint Flux models are currently supported.") raise Exception("Only Checkpoint Flux models are currently supported.")
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
class T5EncoderCheckpointModel(GenericDiffusersLoader): class T5EncoderCheckpointModel(GenericDiffusersLoader):
"""Class to load main models.""" """Class to load main models."""
@ -106,9 +105,9 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader):
match submodel_type: match submodel_type:
case SubModelType.Tokenizer2: case SubModelType.Tokenizer2:
return T5Tokenizer.from_pretrained(Path(config.path), max_length=512) return T5Tokenizer.from_pretrained(Path(config.path) / "encoder", max_length=512)
case SubModelType.TextEncoder2: case SubModelType.TextEncoder2:
return T5EncoderModel.from_pretrained(Path(config.path)) return T5EncoderModel.from_pretrained(Path(config.path) / "tokenizer")
raise Exception("Only Checkpoint Flux models are currently supported.") raise Exception("Only Checkpoint Flux models are currently supported.")
@ -148,7 +147,7 @@ class FluxCheckpointModel(GenericDiffusersLoader):
params = None params = None
model_path = Path(config.path) model_path = Path(config.path)
dataclass_fields = {f.name for f in fields(FluxParams)} dataclass_fields = {f.name for f in fields(FluxParams)}
filtered_data = {k: v for k, v in flux_conf['params'].items() if k in dataclass_fields} filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
params = FluxParams(**filtered_data) params = FluxParams(**filtered_data)
with SilenceWarnings(): with SilenceWarnings():

View File

@ -39,11 +39,15 @@ VARIANT_TO_IN_CHANNEL_MAP = {
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Diffusers) @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Diffusers) @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Diffusers) @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers) @ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint) @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint) @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint) @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint) @ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint
)
class StableDiffusionDiffusersModel(GenericDiffusersLoader): class StableDiffusionDiffusersModel(GenericDiffusersLoader):
"""Class to load main models.""" """Class to load main models."""

View File

@ -9,7 +9,7 @@ from typing import Optional
import torch import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SchedulerMixin from diffusers.schedulers.scheduling_utils import SchedulerMixin
from transformers import CLIPTokenizer, T5TokenizerFast, T5Tokenizer from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
@ -52,7 +52,10 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
return model.calc_size() return model.calc_size()
elif isinstance( elif isinstance(
model, model,
(T5TokenizerFast,T5Tokenizer,), (
T5TokenizerFast,
T5Tokenizer,
),
): ):
return len(model) return len(model)
else: else:

View File

@ -56,7 +56,7 @@ LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[Sched
}, },
BaseModelType.StableDiffusionXLRefiner: { BaseModelType.StableDiffusionXLRefiner: {
ModelVariantType.Normal: "sd_xl_refiner.yaml", ModelVariantType.Normal: "sd_xl_refiner.yaml",
} },
} }
@ -132,7 +132,7 @@ class ModelProbe(object):
fields = {} fields = {}
model_path = model_path.resolve() model_path = model_path.resolve()
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
model_info = None model_info = None
model_type = ModelType(fields["type"]) if "type" in fields and fields["type"] else None model_type = ModelType(fields["type"]) if "type" in fields and fields["type"] else None
@ -323,7 +323,7 @@ class ModelProbe(object):
if model_type is ModelType.Main: if model_type is ModelType.Main:
if base_type == BaseModelType.Flux: if base_type == BaseModelType.Flux:
config_file="flux/flux1-schnell.yaml" config_file = "flux/flux1-schnell.yaml"
else: else:
config_file = LEGACY_CONFIGS[base_type][variant_type] config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models if isinstance(config_file, dict): # need another tier for sd-2.x models
@ -727,6 +727,7 @@ class T5EncoderFolderProbe(FolderProbeBase):
def get_format(self) -> ModelFormat: def get_format(self) -> ModelFormat:
return ModelFormat.T5Encoder return ModelFormat.T5Encoder
class ONNXFolderProbe(PipelineFolderProbe): class ONNXFolderProbe(PipelineFolderProbe):
def get_base_type(self) -> BaseModelType: def get_base_type(self) -> BaseModelType:
# Due to the way the installer is set up, the configuration file for safetensors # Due to the way the installer is set up, the configuration file for safetensors