import torch
from einops import rearrange
from PIL import Image

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    FluxConditioningField,
    Input,
    InputField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "flux_text_to_image",
    title="FLUX Text to Image",
    tags=["image", "flux"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Text-to-image generation using a FLUX model."""

    transformer: TransformerField = InputField(
        description=FieldDescriptions.flux_model,
        input=Input.Connection,
        title="Transformer",
    )
    vae: VAEField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )
    positive_text_conditioning: FluxConditioningField = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
    )
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    num_steps: int = InputField(
        default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50."
    )
    guidance: float = InputField(
        default=4.0,
        description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
    )
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        latents = self._run_diffusion(context)
        image = self._run_vae_decoding(context, latents)
        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)

    def _run_diffusion(
        self,
        context: InvocationContext,
    ):
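        # Run inference in bfloat16, the dtype FLUX models are published in.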
        inference_dtype = torch.bfloat16

        # Load the conditioning data.
        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
        assert len(cond_data.conditionings) == 1
        flux_conditioning = cond_data.conditionings[0]
        assert isinstance(flux_conditioning, FLUXConditioningInfo)
        flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
        t5_embeddings = flux_conditioning.t5_embeds
        clip_embeddings = flux_conditioning.clip_embeds

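        # Load the FLUX transformer (the denoising model).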
        transformer_info = context.models.load(self.transformer.transformer)

        # Prepare input noise.
        x = get_noise(
            num_samples=1,
            height=self.height,
            width=self.width,
            device=TorchDevice.choose_torch_device(),
            dtype=inference_dtype,
            seed=self.seed,
        )

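        # Pack the spatial latents into a sequence of 2x2 patches, with position IDs identifying each patch's location.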
        x, img_ids = prepare_latent_img_patches(x)

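        # schnell is timestep-distilled for few-step sampling, so its schedule is not shifted; for dev, the schedule is shifted based on the image sequence length.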
        is_schnell = "schnell" in transformer_info.config.config_path

        timesteps = get_schedule(
            num_steps=self.num_steps,
            image_seq_len=x.shape[1],
            shift=not is_schnell,
        )

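        # The text token position IDs are all zeros, so the text sequence carries no positional information.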
        bs, t5_seq_len, _ = t5_embeddings.shape
        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())

        with transformer_info as transformer:
            assert isinstance(transformer, Flux)

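            # Called by the sampler once per denoising step; for now it only checks for cancellation.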
            def step_callback() -> None:
                if context.util.is_canceled():
                    raise CanceledException

                # TODO: Make this look like the image before re-enabling
                # latent_image = unpack(img.float(), self.height, self.width)
                # latent_image = latent_image.squeeze()  # Remove unnecessary dimensions
                # flattened_tensor = latent_image.reshape(-1)  # Flatten to shape [48*128*128]
                # # Create a new tensor of the required shape [255, 255, 3]
                # latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3)  # Reshape to RGB format
                # # Convert to a NumPy array and then to a PIL Image
                # image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))
                # (width, height) = image.size
                # width *= 8
                # height *= 8
                # dataURL = image_to_dataURL(image, image_format="JPEG")
                # # TODO: move this whole function to invocation context to properly reference these variables
                # context._services.events.emit_invocation_denoise_progress(
                #     context._data.queue_item,
                #     context._data.invocation,
                #     state,
                #     ProgressImage(dataURL=dataURL, width=width, height=height),
                # )

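            # Run the denoising loop to produce the final packed latents.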
            x = denoise(
                model=transformer,
                img=x,
                img_ids=img_ids,
                txt=t5_embeddings,
                txt_ids=txt_ids,
                vec=clip_embeddings,
                timesteps=timesteps,
                step_callback=step_callback,
                guidance=self.guidance,
            )

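        # Unpack the patch sequence back into image-shaped latents for the VAE decoder.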
        x = unpack(x.float(), self.height, self.width)

        return x

    def _run_vae_decoding(
        self,
        context: InvocationContext,
        latents: torch.Tensor,
    ) -> Image.Image:
        vae_info = context.models.load(self.vae.vae)
        with vae_info as vae:
            assert isinstance(vae, AutoEncoder)
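            # Cast the latents to this device's preferred dtype before decoding.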
            latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
            img = vae.decode(latents)

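        # Map the decoded tensor from [-1, 1] to 8-bit RGB and convert to a PIL image.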
        img = img.clamp(-1, 1)
        img = rearrange(img[0], "c h w -> h w c")
        img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())

        return img_pil