2024-08-30 14:46:04 +00:00
from typing import Callable , Optional
2024-08-29 14:17:08 +00:00
2024-08-06 21:51:22 +00:00
import torch
2024-08-29 19:05:44 +00:00
import torchvision . transforms as tv_transforms
from torchvision . transforms . functional import resize as tv_resize
2024-08-06 21:51:22 +00:00
2024-08-22 15:29:59 +00:00
from invokeai . app . invocations . baseinvocation import BaseInvocation , Classification , invocation
2024-08-12 18:23:02 +00:00
from invokeai . app . invocations . fields import (
2024-08-29 19:05:44 +00:00
DenoiseMaskField ,
2024-08-12 18:23:02 +00:00
FieldDescriptions ,
2024-08-23 17:50:01 +00:00
FluxConditioningField ,
2024-08-12 18:23:02 +00:00
Input ,
InputField ,
2024-08-29 14:17:08 +00:00
LatentsField ,
2024-08-12 18:23:02 +00:00
WithBoard ,
WithMetadata ,
)
2024-08-29 15:08:11 +00:00
from invokeai . app . invocations . model import TransformerField
2024-08-29 14:50:58 +00:00
from invokeai . app . invocations . primitives import LatentsOutput
2024-08-26 17:14:48 +00:00
from invokeai . app . services . session_processor . session_processor_common import CanceledException
2024-08-06 21:51:22 +00:00
from invokeai . app . services . shared . invocation_context import InvocationContext
2024-08-29 19:05:44 +00:00
from invokeai . backend . flux . denoise import denoise
2024-08-30 14:46:04 +00:00
from invokeai . backend . flux . inpaint_extension import InpaintExtension
2024-08-19 14:14:58 +00:00
from invokeai . backend . flux . model import Flux
2024-08-29 19:05:44 +00:00
from invokeai . backend . flux . sampling_utils import (
generate_img_ids ,
get_noise ,
get_schedule ,
pack ,
unpack ,
)
2024-08-12 18:23:02 +00:00
from invokeai . backend . stable_diffusion . diffusion . conditioning_data import FLUXConditioningInfo
2024-08-16 20:22:49 +00:00
from invokeai . backend . util . devices import TorchDevice
2024-08-06 21:51:22 +00:00
2024-08-08 18:23:20 +00:00
2024-08-06 21:51:22 +00:00
@invocation(
    "flux_text_to_image",
    title="FLUX Text to Image",
    tags=["image", "flux"],
    category="image",
    version="2.0.0",
    classification=Classification.Prototype,
)
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Text-to-image generation using a FLUX model.

    When `latents` is connected, the node performs image-to-image instead (starting from the noised input latents).
    When `denoise_mask` is also connected, the run is treated as inpainting.
    """

    # If latents is provided, this means we are doing image-to-image.
    latents: Optional[LatentsField] = InputField(
        default=None,
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )
    # denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
    # It requires `latents` to also be provided.
    denoise_mask: Optional[DenoiseMaskField] = InputField(
        default=None,
        description=FieldDescriptions.denoise_mask,
        input=Input.Connection,
    )
    # Fraction of the timestep schedule to skip (image-to-image only). Must be 0 when `latents` is not provided.
    denoising_start: float = InputField(
        default=0.0,
        ge=0,
        le=1,
        description=FieldDescriptions.denoising_start,
    )
    transformer: TransformerField = InputField(
        description=FieldDescriptions.flux_model,
        input=Input.Connection,
        title="Transformer",
    )
    positive_text_conditioning: FluxConditioningField = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
    )
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    num_steps: int = InputField(
        default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50."
    )
    guidance: float = InputField(
        default=4.0,
        description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
    )
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        """Run FLUX denoising and save the resulting latents.

        Args:
            context (InvocationContext): The invocation context.

        Returns:
            LatentsOutput: Output referencing the saved latents tensor (moved to CPU). No seed is recorded.
        """
        latents = self._run_diffusion(context)
        latents = latents.detach().to("cpu")
        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

    def _run_diffusion(
        self,
        context: InvocationContext,
    ) -> torch.Tensor:
        """Prepare all inputs and run the FLUX denoising loop.

        Args:
            context (InvocationContext): The invocation context, used to load conditioning, latents, the mask and
                the transformer model.

        Returns:
            torch.Tensor: The denoised latents in 'unpacked' format, cast to float32 before unpacking.

        Raises:
            ValueError: If `denoising_start` is non-zero or `denoise_mask` is set while `latents` is not provided.
        """
        inference_dtype = torch.bfloat16
        device = TorchDevice.choose_torch_device()

        # Validate the image-to-image related inputs before doing any expensive work.
        if self.latents is None:
            if self.denoising_start > 1e-5:
                raise ValueError("denoising_start should be 0 when initial latents are not provided.")
            if self.denoise_mask is not None:
                raise ValueError("denoise_mask requires initial latents to be provided.")

        # Load the conditioning data.
        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
        assert len(cond_data.conditionings) == 1
        flux_conditioning = cond_data.conditionings[0]
        assert isinstance(flux_conditioning, FLUXConditioningInfo)
        flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
        t5_embeddings = flux_conditioning.t5_embeds
        clip_embeddings = flux_conditioning.clip_embeds

        # Load the input latents, if provided.
        init_latents = context.tensors.load(self.latents.latents_name) if self.latents is not None else None
        if init_latents is not None:
            init_latents = init_latents.to(device=device, dtype=inference_dtype)

        # Prepare input noise.
        noise = get_noise(
            num_samples=1,
            height=self.height,
            width=self.width,
            device=device,
            dtype=inference_dtype,
            seed=self.seed,
        )

        transformer_info = context.models.load(self.transformer.transformer)
        is_schnell = "schnell" in transformer_info.config.config_path

        # Calculate the timestep schedule.
        # Packing groups 2x2 latent patches into one token, hence the division by 4.
        image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
        timesteps = get_schedule(
            num_steps=self.num_steps,
            image_seq_len=image_seq_len,
            shift=not is_schnell,
        )

        # Prepare input latent image.
        if init_latents is not None:
            # If init_latents is provided, we are doing image-to-image.
            if is_schnell:
                context.logger.warning(
                    "Running image-to-image with a FLUX schnell model. This is not recommended. The results are likely "
                    "to be poor. Consider using a FLUX dev model instead."
                )

            # Clip the timesteps schedule based on denoising_start.
            # TODO(ryand): Should we apply denoising_start in timestep-space rather than timestep-index-space?
            start_idx = int(self.denoising_start * len(timesteps))
            # Always keep at least one timestep. With denoising_start == 1.0 the unclamped index equals
            # len(timesteps), which would leave an empty schedule and make `timesteps[0]` raise IndexError.
            # NOTE(review): presumably the schedule ends at t=0, so this yields x == init_latents — confirm in
            # get_schedule.
            start_idx = min(start_idx, len(timesteps) - 1)
            timesteps = timesteps[start_idx:]

            # Noise the orig_latents by the appropriate amount for the first timestep.
            t_0 = timesteps[0]
            x = t_0 * noise + (1.0 - t_0) * init_latents
        else:
            # init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
            # denoising_start == 0 was already validated above.
            x = noise

        inpaint_mask = self._prep_inpaint_mask(context, x)

        b, _c, h, w = x.shape
        img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)

        bs, t5_seq_len, _ = t5_embeddings.shape
        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=device)

        # Pack all latent tensors.
        init_latents = pack(init_latents) if init_latents is not None else None
        inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
        noise = pack(noise)
        x = pack(x)

        # Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly.
        assert image_seq_len == x.shape[1]

        # Prepare inpaint extension.
        inpaint_extension: InpaintExtension | None = None
        if inpaint_mask is not None:
            # Guaranteed by the input validation at the top of this method; kept for type narrowing.
            assert init_latents is not None
            inpaint_extension = InpaintExtension(
                init_latents=init_latents,
                inpaint_mask=inpaint_mask,
                noise=noise,
            )

        with transformer_info as transformer:
            assert isinstance(transformer, Flux)

            x = denoise(
                model=transformer,
                img=x,
                img_ids=img_ids,
                txt=t5_embeddings,
                txt_ids=txt_ids,
                vec=clip_embeddings,
                timesteps=timesteps,
                step_callback=self._build_step_callback(context),
                guidance=self.guidance,
                inpaint_extension=inpaint_extension,
            )

        x = unpack(x.float(), self.height, self.width)
        return x

    def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
        """Prepare the inpaint mask.

        - Loads the mask
        - Resizes if necessary
        - Casts to same device/dtype as latents
        - Expands mask to the same shape as latents so that they line up after 'packing'

        Args:
            context (InvocationContext): The invocation context, for loading the inpaint mask.
            latents (torch.Tensor): A latent image tensor. In 'unpacked' format. Used to determine the target shape,
                device, and dtype for the inpaint mask.

        Returns:
            torch.Tensor | None: Inpaint mask, or None if no `denoise_mask` was provided.
        """
        if self.denoise_mask is None:
            return None

        mask = context.tensors.load(self.denoise_mask.mask_name)

        _, _, latent_height, latent_width = latents.shape
        mask = tv_resize(
            img=mask,
            size=[latent_height, latent_width],
            interpolation=tv_transforms.InterpolationMode.BILINEAR,
            antialias=False,
        )

        mask = mask.to(device=latents.device, dtype=latents.dtype)

        # Expand the inpaint mask to the same shape as `latents` so that when we 'pack' `mask` it lines up with
        # `latents`.
        return mask.expand_as(latents)

    def _build_step_callback(self, context: InvocationContext) -> Callable[[], None]:
        """Build the per-step callback passed to `denoise`.

        Args:
            context (InvocationContext): The invocation context, polled for cancellation.

        Returns:
            Callable[[], None]: A callback that raises CanceledException if the session has been canceled.
        """

        def step_callback() -> None:
            if context.util.is_canceled():
                raise CanceledException
            # TODO: Emit a progress image (context._services.events.emit_invocation_denoise_progress) once the
            # unpacked latents can be rendered to look like the image being generated.

        return step_callback