diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py
index b68bb91513..43cf1f9d65 100644
--- a/invokeai/app/invocations/flux_text_to_image.py
+++ b/invokeai/app/invocations/flux_text_to_image.py
@@ -1,3 +1,4 @@
+import numpy as np
 import torch
 from einops import rearrange
 from PIL import Image
@@ -13,12 +14,15 @@ from invokeai.app.invocations.fields import (
 )
 from invokeai.app.invocations.model import TransformerField, VAEField
 from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.services.session_processor.session_processor_common import CanceledException, ProgressImage
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.app.util.step_callback import PipelineIntermediateState
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.util import image_to_dataURL
 
 
 @invocation(
@@ -108,6 +112,35 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         with transformer_info as transformer:
             assert isinstance(transformer, Flux)
 
+            def step_callback(img: torch.Tensor, state: PipelineIntermediateState) -> None:
+                if context.util.is_canceled():
+                    raise CanceledException
+
+                # TODO: Make this look like the image
+                latent_image = unpack(img.float(), self.height, self.width)
+                latent_image = latent_image.squeeze()  # Remove unnecessary dimensions
+                flattened_tensor = latent_image.reshape(-1)  # Flatten to shape [48*128*128]
+
+                # Create a new tensor of the required shape [255, 255, 3]
+                latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3)  # Reshape to RGB format
+
+                # Convert to a NumPy array and then to a PIL Image
+                image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))
+
+                (width, height) = image.size
+                width *= 8
+                height *= 8
+
+                dataURL = image_to_dataURL(image, image_format="JPEG")
+
+                # TODO: move this whole function to invocation context to properly reference these variables
+                context._services.events.emit_invocation_denoise_progress(
+                    context._data.queue_item,
+                    context._data.invocation,
+                    state,
+                    ProgressImage(dataURL=dataURL, width=width, height=height),
+                )
+
             x = denoise(
                 model=transformer,
                 img=img,
@@ -116,6 +149,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
                 txt_ids=txt_ids,
                 vec=clip_embeddings,
                 timesteps=timesteps,
+                step_callback=step_callback,
                 guidance=self.guidance,
             )
 
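For context, the preview trick in `step_callback` can be reproduced in isolation. Below is a minimal, self-contained sketch of the same conversion, assuming an unpacked, squeezed latent of shape [16, 128, 128] (a 1024x1024 image). `to_data_url` is a hypothetical stand-in for `image_to_dataURL`, and the normalization step is added here so the preview is legible; the diff casts raw floats to uint8 directly.

    import base64
    import io

    import numpy as np
    import torch
    from PIL import Image


    def to_data_url(image: Image.Image, image_format: str = "JPEG") -> str:
        # Hypothetical stand-in for invokeai.backend.util.util.image_to_dataURL.
        buffer = io.BytesIO()
        image.save(buffer, format=image_format)
        encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
        return f"data:image/{image_format.lower()};base64,{encoded}"


    # Stand-in for an unpacked, squeezed FLUX latent of a 1024x1024 image: [16, 128, 128].
    latent = torch.randn(16, 128, 128)

    # Scale into 0-255 before the uint8 cast (added here; not in the diff).
    latent = (latent - latent.min()) / (latent.max() - latent.min()) * 255.0

    # Same trick as step_callback: view the first 255*255*3 values as a 255x255 RGB image.
    preview = latent.reshape(-1)[: 255 * 255 * 3].reshape(255, 255, 3)
    image = Image.fromarray(preview.numpy().astype(np.uint8))
    print(to_data_url(image)[:64])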
diff --git a/invokeai/backend/flux/sampling.py b/invokeai/backend/flux/sampling.py
index 318a0bcdce..ab9d41797b 100644
--- a/invokeai/backend/flux/sampling.py
+++ b/invokeai/backend/flux/sampling.py
@@ -8,6 +8,7 @@ from einops import rearrange, repeat
 from torch import Tensor
 from tqdm import tqdm
 
+from invokeai.app.util.step_callback import PipelineIntermediateState
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.modules.conditioner import HFEncoder
 
@@ -108,6 +109,7 @@ def denoise(
     vec: Tensor,
     # sampling parameters
     timesteps: list[float],
+    step_callback: Callable[[Tensor, PipelineIntermediateState], None],
     guidance: float = 4.0,
 ):
     dtype = model.txt_in.bias.dtype
@@ -121,6 +123,7 @@
 
     # this is ignored for schnell
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
+    step_count = 0
     for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
         pred = model(
@@ -134,6 +137,17 @@
         )
 
         img = img + (t_prev - t_curr) * pred
+        step_callback(
+            img,
+            PipelineIntermediateState(
+                step=step_count,
+                order=0,
+                total_steps=len(timesteps),
+                timestep=math.floor(t_curr),
+                latents=img,
+            ),
+        )
+        step_count += 1
 
     return img
 
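The new `step_callback` parameter keeps `denoise` decoupled from the invocation layer: any callable matching `Callable[[Tensor, PipelineIntermediateState], None]` works. A minimal illustrative consumer, using only the `PipelineIntermediateState` fields populated above (`step`, `total_steps`, `latents`); the logging callback itself is not part of the diff.

    from torch import Tensor

    from invokeai.app.util.step_callback import PipelineIntermediateState


    def log_progress(img: Tensor, state: PipelineIntermediateState) -> None:
        # Report progress only; the FluxTextToImageInvocation callback above also
        # renders a preview image and raises CanceledException on cancellation.
        print(f"denoise step {state.step + 1}/{state.total_steps}, latents {tuple(state.latents.shape)}")


    # Wired up the same way as in the invocation, e.g.:
    #   x = denoise(model=transformer, img=img, ..., step_callback=log_progress)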
diff --git a/invokeai/backend/flux/util.py b/invokeai/backend/flux/util.py
new file mode 100644
index 0000000000..112d7111de
--- /dev/null
+++ b/invokeai/backend/flux/util.py
@@ -0,0 +1,86 @@
+# Initially pulled from https://github.com/black-forest-labs/flux
+
+import os
+from dataclasses import dataclass
+
+from invokeai.backend.flux.model import FluxParams
+from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
+
+
+@dataclass
+class ModelSpec:
+    params: FluxParams
+    ae_params: AutoEncoderParams
+    ckpt_path: str | None
+    ae_path: str | None
+    repo_id: str | None
+    repo_flow: str | None
+    repo_ae: str | None
+
+
+configs = {
+    "flux-dev": ModelSpec(
+        repo_id="black-forest-labs/FLUX.1-dev",
+        repo_flow="flux1-dev.safetensors",
+        repo_ae="ae.safetensors",
+        ckpt_path=os.getenv("FLUX_DEV"),
+        params=FluxParams(
+            in_channels=64,
+            vec_in_dim=768,
+            context_in_dim=4096,
+            hidden_size=3072,
+            mlp_ratio=4.0,
+            num_heads=24,
+            depth=19,
+            depth_single_blocks=38,
+            axes_dim=[16, 56, 56],
+            theta=10_000,
+            qkv_bias=True,
+            guidance_embed=True,
+        ),
+        ae_path=os.getenv("AE"),
+        ae_params=AutoEncoderParams(
+            resolution=256,
+            in_channels=3,
+            ch=128,
+            out_ch=3,
+            ch_mult=[1, 2, 4, 4],
+            num_res_blocks=2,
+            z_channels=16,
+            scale_factor=0.3611,
+            shift_factor=0.1159,
+        ),
+    ),
+    "flux-schnell": ModelSpec(
+        repo_id="black-forest-labs/FLUX.1-schnell",
+        repo_flow="flux1-schnell.safetensors",
+        repo_ae="ae.safetensors",
+        ckpt_path=os.getenv("FLUX_SCHNELL"),
+        params=FluxParams(
+            in_channels=64,
+            vec_in_dim=768,
+            context_in_dim=4096,
+            hidden_size=3072,
+            mlp_ratio=4.0,
+            num_heads=24,
+            depth=19,
+            depth_single_blocks=38,
+            axes_dim=[16, 56, 56],
+            theta=10_000,
+            qkv_bias=True,
+            guidance_embed=False,
+        ),
+        ae_path=os.getenv("AE"),
+        ae_params=AutoEncoderParams(
+            resolution=256,
+            in_channels=3,
+            ch=128,
+            out_ch=3,
+            ch_mult=[1, 2, 4, 4],
+            num_res_blocks=2,
+            z_channels=16,
+            scale_factor=0.3611,
+            shift_factor=0.1159,
+        ),
+    ),
+}
diff --git a/invokeai/backend/quantization/scripts/load_flux_model_bnb_llm_int8.py b/invokeai/backend/quantization/scripts/load_flux_model_bnb_llm_int8.py
index e8771dca22..286c96b527 100644
--- a/invokeai/backend/quantization/scripts/load_flux_model_bnb_llm_int8.py
+++ b/invokeai/backend/quantization/scripts/load_flux_model_bnb_llm_int8.py
@@ -1,8 +1,8 @@
 from pathlib import Path
 
 import accelerate
-from flux.model import Flux
-from flux.util import configs as flux_configs
+from invokeai.backend.flux.model import Flux
+from invokeai.backend.flux.util import configs as flux_configs
 from safetensors.torch import load_file, save_file
 
 from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
diff --git a/invokeai/backend/quantization/scripts/load_flux_model_bnb_nf4.py b/invokeai/backend/quantization/scripts/load_flux_model_bnb_nf4.py
index fe88b79d32..5415407a2b 100644
--- a/invokeai/backend/quantization/scripts/load_flux_model_bnb_nf4.py
+++ b/invokeai/backend/quantization/scripts/load_flux_model_bnb_nf4.py
@@ -4,10 +4,10 @@
 from pathlib import Path
 
 import accelerate
 import torch
-from flux.model import Flux
-from flux.util import configs as flux_configs
 from safetensors.torch import load_file, save_file
 
+from invokeai.backend.flux.model import Flux
+from invokeai.backend.flux.util import configs as flux_configs
 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
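Both quantization scripts consume `configs` from the newly vendored module. A short usage sketch of the pattern, assuming `Flux(params)` constructs the transformer from `FluxParams` as in `invokeai.backend.flux.model`; this illustrates the lookup, not the scripts' exact code.

    import accelerate

    from invokeai.backend.flux.model import Flux
    from invokeai.backend.flux.util import configs as flux_configs

    # Keys are "flux-dev" and "flux-schnell", per the configs dict above.
    params = flux_configs["flux-schnell"].params

    # Build the model skeleton on the meta device so no weight memory is
    # allocated before a real (or quantized) state dict is loaded.
    with accelerate.init_empty_weights():
        model = Flux(params)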