2024-08-06 21:51:22 +00:00
import torch
2024-08-19 14:14:58 +00:00
from einops import rearrange , repeat
2024-08-06 21:51:22 +00:00
from PIL import Image
2024-08-22 15:29:59 +00:00
from invokeai . app . invocations . baseinvocation import BaseInvocation , Classification , invocation
2024-08-12 18:23:02 +00:00
from invokeai . app . invocations . fields import (
ConditioningField ,
FieldDescriptions ,
Input ,
InputField ,
WithBoard ,
WithMetadata ,
)
2024-08-15 14:27:42 +00:00
from invokeai . app . invocations . model import TransformerField , VAEField
2024-08-06 21:51:22 +00:00
from invokeai . app . invocations . primitives import ImageOutput
from invokeai . app . services . shared . invocation_context import InvocationContext
2024-08-19 14:14:58 +00:00
from invokeai . backend . flux . model import Flux
from invokeai . backend . flux . modules . autoencoder import AutoEncoder
from invokeai . backend . flux . sampling import denoise , get_noise , get_schedule , unpack
2024-08-20 14:39:33 +00:00
from invokeai . backend . model_manager . config import CheckpointConfigBase
2024-08-12 18:23:02 +00:00
from invokeai . backend . stable_diffusion . diffusion . conditioning_data import FLUXConditioningInfo
2024-08-16 20:22:49 +00:00
from invokeai . backend . util . devices import TorchDevice
2024-08-06 21:51:22 +00:00
2024-08-08 18:23:20 +00:00
2024-08-06 21:51:22 +00:00
@invocation(
    "flux_text_to_image",
    title="FLUX Text to Image",
    tags=["image", "flux"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Text-to-image generation using a FLUX model."""

    transformer: TransformerField = InputField(
        description=FieldDescriptions.flux_model,
        input=Input.Connection,
        title="Transformer",
    )
    vae: VAEField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )
    positive_text_conditioning: ConditioningField = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
    )
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    num_steps: int = InputField(
        default=4, description="Number of diffusion steps. Recommend values are schnell: 4, dev: 50."
    )
    guidance: float = InputField(
        default=4.0,
        description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
    )
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Generate an image from the positive text conditioning and save it.

        Loads the FLUX conditioning (CLIP + T5 embeddings), runs diffusion with the transformer,
        decodes the resulting latents with the VAE, and saves the image via the invocation context.
        """
        # Load the conditioning data. Exactly one FLUX conditioning entry is expected.
        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
        assert len(cond_data.conditionings) == 1
        flux_conditioning = cond_data.conditionings[0]
        assert isinstance(flux_conditioning, FLUXConditioningInfo)

        latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
        image = self._run_vae_decoding(context, latents)
        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)

    def _run_diffusion(
        self,
        context: InvocationContext,
        clip_embeddings: torch.Tensor,
        t5_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """Run the FLUX denoising loop and return the unpacked latents.

        Args:
            context: The invocation context, used to load the transformer model.
            clip_embeddings: Pooled CLIP embeddings, passed to the transformer as `vec`.
            t5_embeddings: T5 sequence embeddings, passed to the transformer as `txt`. Unpacked as
                (batch, seq_len, _), so this must be a rank-3 tensor.

        Returns:
            The denoised latents, unpacked from patch form and cast to float32.
        """
        transformer_info = context.models.load(self.transformer.transformer)
        inference_dtype = torch.bfloat16
        device = TorchDevice.choose_torch_device()

        # Prepare input noise.
        x = get_noise(
            num_samples=1,
            height=self.height,
            width=self.width,
            device=device,
            dtype=inference_dtype,
            seed=self.seed,
        )
        img, img_ids = self._prepare_latent_img_patches(x)

        # HACK(ryand): Find a better way to determine if this is a schnell model or not.
        # Always a bool; models without a checkpoint config are treated as non-schnell.
        config = transformer_info.config
        is_schnell = bool(config) and isinstance(config, CheckpointConfigBase) and "schnell" in config.config_path

        # The timestep schedule is only shifted for non-schnell (dev) models.
        timesteps = get_schedule(
            num_steps=self.num_steps,
            image_seq_len=img.shape[1],
            shift=not is_schnell,
        )

        # Text position ids are all zero: one 3-component id per T5 token.
        bs, t5_seq_len, _ = t5_embeddings.shape
        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=device)

        # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
        # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
        # if the cache is not empty.
        context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)

        with transformer_info as transformer:
            assert isinstance(transformer, Flux)

            x = denoise(
                model=transformer,
                img=img,
                img_ids=img_ids,
                txt=t5_embeddings,
                txt_ids=txt_ids,
                vec=clip_embeddings,
                timesteps=timesteps,
                guidance=self.guidance,
            )

        # Convert the patch sequence back into a latent image, in float32 for VAE decoding.
        return unpack(x.float(), self.height, self.width)

    def _prepare_latent_img_patches(self, latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Convert an input image in latent space to patches for diffusion.

        This implementation was extracted from:
        https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32

        Args:
            latent_img: A rank-4 tensor (b, c, h, w); h and w must be divisible by 2.

        Returns:
            tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
                img has shape (b, (h/2)*(w/2), c*4); img_ids has shape (b, (h/2)*(w/2), 3).
        """
        bs, _, h, w = latent_img.shape

        # Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
        # `rearrange` keeps the batch dimension, so img.shape[0] == bs and no batch broadcast is needed.
        img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

        # Generate patch position ids: component 0 stays zero, component 1 is the patch row,
        # component 2 is the patch column.
        img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
        img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        return img, img_ids

    def _run_vae_decoding(
        self,
        context: InvocationContext,
        latents: torch.Tensor,
    ) -> Image.Image:
        """Decode latents with the VAE and convert the first image in the batch to a PIL image.

        Args:
            context: The invocation context, used to load the VAE model.
            latents: The latent image to decode.

        Returns:
            An 8-bit PIL image built from `latents[0]` after decoding.
        """
        vae_info = context.models.load(self.vae.vae)
        with vae_info as vae:
            assert isinstance(vae, AutoEncoder)
            # TODO(ryand): Test that this works with both float16 and bfloat16.
            # Decode in float32. Module.to() converts the VAE in place, but Tensor.to() returns a new
            # tensor, so the latents must be rebound for the cast to take effect.
            vae.to(torch.float32)
            latents = latents.to(torch.float32)
            img = vae.decode(latents)

        # Map decoder output from [-1, 1] to [0, 255] (HWC layout) for PIL.
        img = img.clamp(-1, 1)
        img = rearrange(img[0], "c h w -> h w c")
        img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())

        return img_pil