diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py index 5c0d0ef2ac..9cecd89bca 100644 --- a/invokeai/app/invocations/flux_text_encoder.py +++ b/invokeai/app/invocations/flux_text_encoder.py @@ -1,8 +1,4 @@ import torch - - -from einops import repeat -from diffusers.pipelines.flux.pipeline_flux import FluxPipeline from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation @@ -10,9 +6,8 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField from invokeai.app.invocations.model import CLIPField, T5EncoderField from invokeai.app.invocations.primitives import ConditioningOutput from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo -from invokeai.backend.util.devices import TorchDevice from invokeai.backend.flux.modules.conditioner import HFEncoder +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo @invocation( diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py index 1327f81709..fdb8e9c1dd 100644 --- a/invokeai/app/invocations/flux_text_to_image.py +++ b/invokeai/app/invocations/flux_text_to_image.py @@ -1,12 +1,6 @@ -from typing import Literal - -import accelerate import torch -from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel -from diffusers.pipelines.flux.pipeline_flux import FluxPipeline +from einops import rearrange, repeat from PIL import Image -from safetensors.torch import load_file -from transformers.models.auto import AutoModelForTextEncoding from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.fields import ( @@ -20,23 +14,12 
@@ from invokeai.app.invocations.fields import ( from invokeai.app.invocations.model import TransformerField, VAEField from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 -from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel -from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel +from invokeai.backend.flux.model import Flux +from invokeai.backend.flux.modules.autoencoder import AutoEncoder +from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, unpack from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo from invokeai.backend.util.devices import TorchDevice -TFluxModelKeys = Literal["flux-schnell"] -FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"} - - -class QuantizedFluxTransformer2DModel(FastQuantizedDiffusersModel): - base_class = FluxTransformer2DModel - - -class QuantizedModelForTextEncoding(FastQuantizedTransformersModel): - auto_class = AutoModelForTextEncoding - @invocation( "flux_text_to_image", @@ -78,7 +61,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): assert isinstance(flux_conditioning, FLUXConditioningInfo) latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds) - image = self._run_vae_decoding(context, latents) + image = self._run_vae_decoding(context, flux_ae_path, latents) image_dto = context.images.save(image=image) return ImageOutput.build(image_dto) @@ -89,14 +72,40 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): t5_embeddings: torch.Tensor, ): transformer_info = context.models.load(self.transformer.transformer) + inference_dtype = TorchDevice.choose_torch_dtype() + + # Prepare input 
noise. + # TODO(ryand): Does the seed behave the same on different devices? Should we re-implement this to always use a + # CPU RNG? + x = get_noise( + num_samples=1, + height=self.height, + width=self.width, + device=TorchDevice.choose_torch_device(), + dtype=inference_dtype, + seed=self.seed, + ) + + img, img_ids = self._prepare_latent_img_patches(x) + + # HACK(ryand): Find a better way to determine if this is a schnell model or not. + is_schnell = "schnell" in transformer_info.config.path if transformer_info.config else False + timesteps = get_schedule( + num_steps=self.num_steps, + image_seq_len=img.shape[1], + shift=not is_schnell, + ) + + bs, t5_seq_len, _ = t5_embeddings.shape + txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()) # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems # if the cache is not empty. - # context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30) + context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30) with transformer_info as transformer: - assert isinstance(transformer, FluxTransformer2DModel) + assert isinstance(transformer, Flux) x = denoise( model=transformer, @@ -144,21 +153,13 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): ) -> Image.Image: vae_info = context.models.load(self.vae.vae) with vae_info as vae: - assert isinstance(vae, AutoencoderKL) + assert isinstance(vae, AutoEncoder) + # TODO(ryand): Test that this works with both float16 and bfloat16.
+ with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()): + img = vae.decode(latents) - img.clamp(-1, 1) + img = img.clamp(-1, 1) img = rearrange(img[0], "c h w -> h w c") img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy()) - latents = flux_pipeline_with_vae._unpack_latents( - latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor - ) - latents = ( - latents / flux_pipeline_with_vae.vae.config.scaling_factor - ) + flux_pipeline_with_vae.vae.config.shift_factor - latents = latents.to(dtype=vae.dtype) - image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0] - image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0] - - assert isinstance(image, Image.Image) - return image + return img_pil diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 3908bef4da..f408dc3e0e 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -1,6 +1,6 @@ import copy from time import sleep -from typing import List, Optional, Literal, Dict +from typing import Dict, List, Literal, Optional from pydantic import BaseModel, Field @@ -12,10 +12,10 @@ from invokeai.app.invocations.baseinvocation import ( invocation_output, ) from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType +from invokeai.app.services.model_records import ModelRecordChanges from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.shared.models import FreeUConfig -from invokeai.app.services.model_records import ModelRecordChanges -from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType, ModelFormat +from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType class ModelIdentifierField(BaseModel): @@ -132,31 +132,22 @@ class ModelIdentifierInvocation(BaseInvocation): return
ModelIdentifierOutput(model=self.model) -T5_ENCODER_OPTIONS = Literal["base", "16b_quantized", "8b_quantized"] + +T5_ENCODER_OPTIONS = Literal["base", "8b_quantized"] T5_ENCODER_MAP: Dict[str, Dict[str, str]] = { "base": { - "text_encoder_repo": "black-forest-labs/FLUX.1-schnell::text_encoder_2", - "tokenizer_repo": "black-forest-labs/FLUX.1-schnell::tokenizer_2", - "text_encoder_name": "FLUX.1-schnell_text_encoder_2", - "tokenizer_name": "FLUX.1-schnell_tokenizer_2", + "repo": "invokeai/flux_dev::t5_xxl_encoder/base", + "name": "t5_base_encoder", "format": ModelFormat.T5Encoder, }, "8b_quantized": { - "text_encoder_repo": "hf_repo1", - "tokenizer_repo": "hf_repo1", - "text_encoder_name": "hf_repo1", - "tokenizer_name": "hf_repo1", - "format": ModelFormat.T5Encoder8b, - }, - "4b_quantized": { - "text_encoder_repo": "hf_repo2", - "tokenizer_repo": "hf_repo2", - "text_encoder_name": "hf_repo2", - "tokenizer_name": "hf_repo2", - "format": ModelFormat.T5Encoder8b, + "repo": "invokeai/flux_dev::t5_xxl_encoder/8b_quantized", + "name": "t5_8b_quantized_encoder", + "format": ModelFormat.T5Encoder, }, } + @invocation_output("flux_model_loader_output") class FluxModelLoaderOutput(BaseInvocationOutput): """Flux base model loader output""" @@ -176,7 +167,7 @@ class FluxModelLoaderInvocation(BaseInvocation): ui_type=UIType.FluxMainModel, input=Input.Direct, ) - + t5_encoder: T5_ENCODER_OPTIONS = InputField(description="The T5 Encoder model to use.") def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput: @@ -189,7 +180,15 @@ class FluxModelLoaderInvocation(BaseInvocation): tokenizer2 = self._get_model(context, SubModelType.Tokenizer2) clip_encoder = self._get_model(context, SubModelType.TextEncoder) t5_encoder = self._get_model(context, SubModelType.TextEncoder2) - vae = self._install_model(context, SubModelType.VAE, "FLUX.1-schnell_ae", "black-forest-labs/FLUX.1-schnell::ae.safetensors", ModelFormat.Checkpoint, ModelType.VAE, BaseModelType.Flux) + vae = 
self._install_model( + context, + SubModelType.VAE, + "FLUX.1-schnell_ae", + "black-forest-labs/FLUX.1-schnell::ae.safetensors", + ModelFormat.Checkpoint, + ModelType.VAE, + BaseModelType.Flux, + ) return FluxModelLoaderOutput( transformer=TransformerField(transformer=transformer), @@ -198,33 +197,59 @@ class FluxModelLoaderInvocation(BaseInvocation): vae=VAEField(vae=vae), ) - def _get_model(self, context: InvocationContext, submodel:SubModelType) -> ModelIdentifierField: - match(submodel): + def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField: + match submodel: case SubModelType.Transformer: return self.model.model_copy(update={"submodel_type": SubModelType.Transformer}) case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]: - return self._install_model(context, submodel, "clip-vit-large-patch14", "openai/clip-vit-large-patch14", ModelFormat.Diffusers, ModelType.CLIPEmbed, BaseModelType.Any) - case SubModelType.TextEncoder2: - return self._install_model(context, submodel, T5_ENCODER_MAP[self.t5_encoder]["text_encoder_name"], T5_ENCODER_MAP[self.t5_encoder]["text_encoder_repo"], ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]), ModelType.T5Encoder, BaseModelType.Any) - case SubModelType.Tokenizer2: - return self._install_model(context, submodel, T5_ENCODER_MAP[self.t5_encoder]["tokenizer_name"], T5_ENCODER_MAP[self.t5_encoder]["tokenizer_repo"], ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]), ModelType.T5Encoder, BaseModelType.Any) + return self._install_model( + context, + submodel, + "clip-vit-large-patch14", + "openai/clip-vit-large-patch14", + ModelFormat.Diffusers, + ModelType.CLIPEmbed, + BaseModelType.Any, + ) + case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]: + return self._install_model( + context, + submodel, + T5_ENCODER_MAP[self.t5_encoder]["name"], + T5_ENCODER_MAP[self.t5_encoder]["repo"], + 
ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]), + ModelType.T5Encoder, + BaseModelType.Any, + ) case _: - raise Exception(f"{submodel.value} is not a supported submodule for a flux model") + raise Exception(f"{submodel.value} is not a supported submodule for a flux model") - def _install_model(self, context: InvocationContext, submodel:SubModelType, name: str, repo_id: str, format: ModelFormat, type: ModelType, base: BaseModelType): - if (models := context.models.search_by_attrs(name=name, base=base, type=type)): + def _install_model( + self, + context: InvocationContext, + submodel: SubModelType, + name: str, + repo_id: str, + format: ModelFormat, + type: ModelType, + base: BaseModelType, + ): + if models := context.models.search_by_attrs(name=name, base=base, type=type): if len(models) != 1: raise Exception(f"Multiple models detected for selected model with name {name}") return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel}) else: model_path = context.models.download_and_cache_model(repo_id) - config = ModelRecordChanges(name = name, base = base, type=type, format=format) + config = ModelRecordChanges(name=name, base=base, type=type, format=format) model_install_job = context.models.import_local_model(model_path=model_path, config=config) while not model_install_job.in_terminal_state: sleep(0.01) if not model_install_job.config_out: raise Exception(f"Failed to install {name}") - return ModelIdentifierField.from_config(model_install_job.config_out).model_copy(update={"submodel_type": submodel}) + return ModelIdentifierField.from_config(model_install_job.config_out).model_copy( + update={"submodel_type": submodel} + ) + @invocation( "main_model_loader", diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index d1ec015242..1d0780efe1 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ 
b/invokeai/app/services/model_records/model_records_sql.py @@ -301,7 +301,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): for row in result: try: model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1]) - except pydantic.ValidationError as e: + except pydantic.ValidationError: # We catch this error so that the app can still run if there are invalid model configs in the database. # One reason that an invalid model config might be in the database is if someone had to rollback from a # newer version of the app that added a new model type. diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 9a5ac3fb5a..9ba1bf68f3 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -465,18 +465,20 @@ class ModelsInterface(InvocationContextInterface): return self._services.model_manager.install.download_and_cache_model(source=source) def import_local_model( - self, - model_path: Path, - config: Optional[ModelRecordChanges] = None, - access_token: Optional[str] = None, - inplace: Optional[bool] = False, + self, + model_path: Path, + config: Optional[ModelRecordChanges] = None, + access_token: Optional[str] = None, + inplace: Optional[bool] = False, ): """ TODO: Fill out description of this method """ if not model_path.exists(): raise Exception("Models provided to import_local_model must already exist on disk") - return self._services.model_manager.install.heuristic_import(str(model_path), config=config, access_token=access_token, inplace=inplace) + return self._services.model_manager.install.heuristic_import( + str(model_path), config=config, access_token=access_token, inplace=inplace + ) def load_local_model( self, diff --git a/invokeai/backend/flux/math.py b/invokeai/backend/flux/math.py index 71b91fa0f5..0156bb6a20 100644 --- a/invokeai/backend/flux/math.py +++ b/invokeai/backend/flux/math.py @@ -27,4 +27,4 
@@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tenso xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] - return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) \ No newline at end of file + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) diff --git a/invokeai/backend/flux/model.py b/invokeai/backend/flux/model.py index 2cb0aa102e..f7ef25bf4f 100644 --- a/invokeai/backend/flux/model.py +++ b/invokeai/backend/flux/model.py @@ -3,9 +3,15 @@ from dataclasses import dataclass import torch from torch import Tensor, nn -from invokeai.backend.flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, - MLPEmbedder, SingleStreamBlock, - timestep_embedding) +from invokeai.backend.flux.modules.layers import ( + DoubleStreamBlock, + EmbedND, + LastLayer, + MLPEmbedder, + SingleStreamBlock, + timestep_embedding, +) + @dataclass class FluxParams: @@ -35,9 +41,7 @@ class Flux(nn.Module): self.in_channels = params.in_channels self.out_channels = self.in_channels if params.hidden_size % params.num_heads != 0: - raise ValueError( - f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" - ) + raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") pe_dim = params.hidden_size // params.num_heads if sum(params.axes_dim) != pe_dim: raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") @@ -108,4 +112,4 @@ class Flux(nn.Module): img = img[:, txt.shape[1] :, ...] 
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) - return img \ No newline at end of file + return img diff --git a/invokeai/backend/flux/modules/autoencoder.py b/invokeai/backend/flux/modules/autoencoder.py index f6e072ecaa..75159f711f 100644 --- a/invokeai/backend/flux/modules/autoencoder.py +++ b/invokeai/backend/flux/modules/autoencoder.py @@ -309,4 +309,4 @@ class AutoEncoder(nn.Module): return self.decoder(z) def forward(self, x: Tensor) -> Tensor: - return self.decode(self.encode(x)) \ No newline at end of file + return self.decode(self.encode(x)) diff --git a/invokeai/backend/flux/modules/conditioner.py b/invokeai/backend/flux/modules/conditioner.py index 2a9e17c20e..974ad64ab3 100644 --- a/invokeai/backend/flux/modules/conditioner.py +++ b/invokeai/backend/flux/modules/conditioner.py @@ -1,5 +1,6 @@ from torch import Tensor, nn -from transformers import (PreTrainedModel, PreTrainedTokenizer) +from transformers import PreTrainedModel, PreTrainedTokenizer + class HFEncoder(nn.Module): def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int): @@ -27,4 +28,4 @@ class HFEncoder(nn.Module): attention_mask=None, output_hidden_states=False, ) - return outputs[self.output_key] \ No newline at end of file + return outputs[self.output_key] diff --git a/invokeai/backend/flux/modules/layers.py b/invokeai/backend/flux/modules/layers.py index cb4eee0c2d..4f9d515daf 100644 --- a/invokeai/backend/flux/modules/layers.py +++ b/invokeai/backend/flux/modules/layers.py @@ -36,9 +36,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10 """ t = time_factor * t half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( - t.device - ) + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device) args = t[:, None].float() * freqs[None] embedding = 
torch.cat([torch.cos(args), torch.sin(args)], dim=-1) @@ -250,4 +248,4 @@ class LastLayer(nn.Module): shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] x = self.linear(x) - return x \ No newline at end of file + return x diff --git a/invokeai/backend/flux/sampling.py b/invokeai/backend/flux/sampling.py new file mode 100644 index 0000000000..89d9d417e0 --- /dev/null +++ b/invokeai/backend/flux/sampling.py @@ -0,0 +1,134 @@ +import math +from typing import Callable + +import torch +from einops import rearrange, repeat +from torch import Tensor + +from .model import Flux +from .modules.conditioner import HFEncoder + + +def get_noise( + num_samples: int, + height: int, + width: int, + device: torch.device, + dtype: torch.dtype, + seed: int, +): + return torch.randn( + num_samples, + 16, + # allow for packing + 2 * math.ceil(height / 16), + 2 * math.ceil(width / 16), + device=device, + dtype=dtype, + generator=torch.Generator(device=device).manual_seed(seed), + ) + + +def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]: + bs, c, h, w = img.shape + if bs == 1 and not isinstance(prompt, str): + bs = len(prompt) + + img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if img.shape[0] == 1 and bs > 1: + img = repeat(img, "1 ... -> bs ...", bs=bs) + + img_ids = torch.zeros(h // 2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + if isinstance(prompt, str): + prompt = [prompt] + txt = t5(prompt) + if txt.shape[0] == 1 and bs > 1: + txt = repeat(txt, "1 ... -> bs ...", bs=bs) + txt_ids = torch.zeros(bs, txt.shape[1], 3) + + vec = clip(prompt) + if vec.shape[0] == 1 and bs > 1: + vec = repeat(vec, "1 ... 
-> bs ...", bs=bs) + + return { + "img": img, + "img_ids": img_ids.to(img.device), + "txt": txt.to(img.device), + "txt_ids": txt_ids.to(img.device), + "vec": vec.to(img.device), + } + + +def time_shift(mu: float, sigma: float, t: Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # estimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +def denoise( + model: Flux, + # model input + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + vec: Tensor, + # sampling parameters + timesteps: list[float], + guidance: float = 4.0, +): + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:], strict=False): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + ) + + img = img + (t_prev - t_curr) * pred + + return img + + +def unpack(x: Tensor, height: int, width: int) -> Tensor: + return rearrange( + x, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(height / 16), + w=math.ceil(width / 16), + ph=2, + pw=2, + ) diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py
b/invokeai/backend/model_manager/load/model_loaders/flux.py index 7a028a55e1..78ecfccfa3 100644 --- a/invokeai/backend/model_manager/load/model_loaders/flux.py +++ b/invokeai/backend/model_manager/load/model_loaders/flux.py @@ -1,14 +1,17 @@ # Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team """Class for Flux model loading in InvokeAI.""" -from pathlib import Path -import yaml - from dataclasses import fields -from safetensors.torch import load_file -from typing import Optional, Any -from transformers import T5EncoderModel, T5Tokenizer +from pathlib import Path +from typing import Any, Optional +import yaml +from safetensors.torch import load_file +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer + +from invokeai.app.services.config.config_default import get_config +from invokeai.backend.flux.model import Flux, FluxParams +from invokeai.backend.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams from invokeai.backend.model_manager import ( AnyModel, AnyModelConfig, @@ -19,20 +22,15 @@ from invokeai.backend.model_manager import ( ) from invokeai.backend.model_manager.config import ( CheckpointConfigBase, - MainCheckpointConfig, CLIPEmbedDiffusersConfig, + MainCheckpointConfig, T5EncoderConfig, VAECheckpointConfig, ) -from invokeai.app.services.config.config_default import get_config from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader -from invokeai.backend.util.silence_warnings import SilenceWarnings from invokeai.backend.util.devices import TorchDevice -from invokeai.backend.flux.model import Flux, FluxParams -from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams, AutoEncoder -from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel, - T5Tokenizer) +from invokeai.backend.util.silence_warnings import SilenceWarnings app_config = 
get_config() @@ -56,9 +54,9 @@ class FluxVAELoader(GenericDiffusersLoader): flux_conf = yaml.safe_load(stream) except: raise - + dataclass_fields = {f.name for f in fields(AutoEncoderParams)} - filtered_data = {k: v for k, v in flux_conf['params']['ae_params'].items() if k in dataclass_fields} + filtered_data = {k: v for k, v in flux_conf["params"]["ae_params"].items() if k in dataclass_fields} params = AutoEncoderParams(**filtered_data) with SilenceWarnings(): @@ -92,6 +90,7 @@ class ClipCheckpointModel(GenericDiffusersLoader): raise Exception("Only Checkpoint Flux models are currently supported.") + @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder) class T5EncoderCheckpointModel(GenericDiffusersLoader): """Class to load main models.""" @@ -106,9 +105,9 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader): match submodel_type: case SubModelType.Tokenizer2: - return T5Tokenizer.from_pretrained(Path(config.path), max_length=512) + return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer", max_length=512) case SubModelType.TextEncoder2: - return T5EncoderModel.from_pretrained(Path(config.path)) + return T5EncoderModel.from_pretrained(Path(config.path) / "encoder") raise Exception("Only Checkpoint Flux models are currently supported.") @@ -148,7 +147,7 @@ class FluxCheckpointModel(GenericDiffusersLoader): params = None model_path = Path(config.path) dataclass_fields = {f.name for f in fields(FluxParams)} - filtered_data = {k: v for k, v in flux_conf['params'].items() if k in dataclass_fields} + filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields} params = FluxParams(**filtered_data) with SilenceWarnings(): diff --git a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py index e034e11011..572859dbae 100644 ---
a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +++ b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py @@ -39,11 +39,15 @@ VARIANT_TO_IN_CHANNEL_MAP = { @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Diffusers) @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Diffusers) @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Diffusers) -@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register( + base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers +) @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint) @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint) @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint) -@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint) +@ModelLoaderRegistry.register( + base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint +) class StableDiffusionDiffusersModel(GenericDiffusersLoader): """Class to load main models.""" diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py index 6987e5222d..6f93fcbd75 100644 --- a/invokeai/backend/model_manager/load/model_util.py +++ b/invokeai/backend/model_manager/load/model_util.py @@ -9,7 +9,7 @@ from typing import Optional import torch from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.schedulers.scheduling_utils import SchedulerMixin -from transformers 
import CLIPTokenizer, T5TokenizerFast, T5Tokenizer +from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline @@ -52,7 +52,10 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int: return model.calc_size() elif isinstance( model, - (T5TokenizerFast,T5Tokenizer,), + ( + T5TokenizerFast, + T5Tokenizer, + ), ): return len(model) else: diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index a3a648806f..fcb4e9b2f0 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -56,7 +56,7 @@ LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[Sched }, BaseModelType.StableDiffusionXLRefiner: { ModelVariantType.Normal: "sd_xl_refiner.yaml", - } + }, } @@ -132,7 +132,7 @@ class ModelProbe(object): fields = {} model_path = model_path.resolve() - + format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint model_info = None model_type = ModelType(fields["type"]) if "type" in fields and fields["type"] else None @@ -323,7 +323,7 @@ class ModelProbe(object): if model_type is ModelType.Main: if base_type == BaseModelType.Flux: - config_file="flux/flux1-schnell.yaml" + config_file = "flux/flux1-schnell.yaml" else: config_file = LEGACY_CONFIGS[base_type][variant_type] if isinstance(config_file, dict): # need another tier for sd-2.x models @@ -727,6 +727,7 @@ class T5EncoderFolderProbe(FolderProbeBase): def get_format(self) -> ModelFormat: return ModelFormat.T5Encoder + class ONNXFolderProbe(PipelineFolderProbe): def get_base_type(self) -> BaseModelType: # Due to the way the installer is set up, the configuration file for safetensors