# Source file: InvokeAI/invokeai/app/invocations/flux_text_to_image.py
# Extracted from a git blame web view (176 lines, 6.9 KiB, Python); revisions listed in
# .git-blame-ignore-revs were ignored by that view.

import numpy as np
import torch
from einops import rearrange
from PIL import Image
# (git blame timestamp: 2024-08-22 15:29:59 +00:00)
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxConditioningField,
Input,
InputField,
WithBoard,
WithMetadata,
)
# (git blame timestamp: 2024-08-15 14:27:42 +00:00)
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.session_processor.session_processor_common import CanceledException, ProgressImage
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.step_callback import PipelineIntermediateState
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
# (git blame timestamp: 2024-08-16 20:22:49 +00:00)
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.util import image_to_dataURL
@invocation(
    "flux_text_to_image",
    title="FLUX Text to Image",
    tags=["image", "flux"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Text-to-image generation using a FLUX model."""

    transformer: TransformerField = InputField(
        description=FieldDescriptions.flux_model,
        input=Input.Connection,
        title="Transformer",
    )
    vae: VAEField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )
    positive_text_conditioning: FluxConditioningField = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
    )
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    num_steps: int = InputField(
        default=4, description="Number of diffusion steps. Recommend values are schnell: 4, dev: 50."
    )
    guidance: float = InputField(
        default=4.0,
        description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
    )
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Generate an image from the connected conditioning, save it, and return it.

        Args:
            context: Invocation context used to load conditioning/models and to save the image.

        Returns:
            An ImageOutput built from the saved image.
        """
        # Load the conditioning data.
        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
        assert len(cond_data.conditionings) == 1
        flux_conditioning = cond_data.conditionings[0]
        assert isinstance(flux_conditioning, FLUXConditioningInfo)

        latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
        image = self._run_vae_decoding(context, latents)
        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)

    @staticmethod
    def _latents_to_preview(latents: torch.Tensor, max_side: int = 255) -> Image.Image:
        """Pack raw latent values into a small square RGB image for progress reporting.

        This is a placeholder visualisation: the latent values are cast straight to uint8, so the
        result is not expected to resemble the final image.

        Args:
            latents: Tensor of latent values. Any shape is accepted because it is flattened.
            max_side: Upper bound for the edge length (in pixels) of the square preview.

        Returns:
            A square RGB image with edge `side`, where `side` is the largest value not exceeding
            `max_side` for which `side * side * 3` values are available in `latents`.
        """
        flat = latents.reshape(-1)
        # The preview used to be reshaped unconditionally to 255x255x3, which raises when fewer
        # than 255*255*3 latent values are available (presumably the case for small output
        # resolutions). Shrink the preview to what actually fits; when enough values exist the
        # result is identical to the previous behaviour.
        side = min(max_side, int((flat.numel() // 3) ** 0.5))
        pixels = flat[: side * side * 3].reshape(side, side, 3)
        # NOTE(review): the float -> uint8 cast of unbounded values is kept as-is from the
        # original placeholder; it is not a meaningful colour mapping.
        return Image.fromarray(pixels.cpu().numpy().astype(np.uint8))

    def _run_diffusion(
        self,
        context: InvocationContext,
        clip_embeddings: torch.Tensor,
        t5_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """Run the FLUX denoising loop and return the unpacked latents.

        Args:
            context: Invocation context used to load the transformer and emit progress events.
            clip_embeddings: CLIP embeddings, passed to `denoise` as `vec`.
            t5_embeddings: T5 embeddings, passed to `denoise` as `txt`. Must be 3-dimensional;
                the first two dimensions are used as the batch size and sequence length of `txt_ids`.

        Returns:
            The denoised latents, converted to float32 and unpacked for `self.height` x `self.width`.

        Raises:
            CanceledException: If the invocation is canceled while a denoising step is reported.
        """
        transformer_info = context.models.load(self.transformer.transformer)
        inference_dtype = torch.bfloat16
        device = TorchDevice.choose_torch_device()

        # Prepare input noise.
        x = get_noise(
            num_samples=1,
            height=self.height,
            width=self.width,
            device=device,
            dtype=inference_dtype,
            seed=self.seed,
        )

        img, img_ids = prepare_latent_img_patches(x)

        # The schedule shift is disabled for "schnell" models, detected via the config path.
        is_schnell = "schnell" in transformer_info.config.config_path

        timesteps = get_schedule(
            num_steps=self.num_steps,
            image_seq_len=img.shape[1],
            shift=not is_schnell,
        )

        bs, t5_seq_len, _ = t5_embeddings.shape
        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=device)

        # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
        # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
        # if the cache is not empty.
        context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)

        with transformer_info as transformer:
            assert isinstance(transformer, Flux)

            def step_callback(img: torch.Tensor, state: PipelineIntermediateState) -> None:
                """Abort on cancellation, otherwise emit a progress event with a preview image."""
                if context.util.is_canceled():
                    raise CanceledException

                # TODO: Make this look like the image
                preview = self._latents_to_preview(unpack(img.float(), self.height, self.width))
                preview_width, preview_height = preview.size
                dataURL = image_to_dataURL(preview, image_format="JPEG")

                # TODO: move this whole function to invocation context to properly reference these variables
                context._services.events.emit_invocation_denoise_progress(
                    context._data.queue_item,
                    context._data.invocation,
                    state,
                    # The reported size is the preview size scaled by 8, as in the original code.
                    ProgressImage(dataURL=dataURL, width=preview_width * 8, height=preview_height * 8),
                )

            x = denoise(
                model=transformer,
                img=img,
                img_ids=img_ids,
                txt=t5_embeddings,
                txt_ids=txt_ids,
                vec=clip_embeddings,
                timesteps=timesteps,
                step_callback=step_callback,
                guidance=self.guidance,
            )

        x = unpack(x.float(), self.height, self.width)

        return x

    def _run_vae_decoding(
        self,
        context: InvocationContext,
        latents: torch.Tensor,
    ) -> Image.Image:
        """Decode latents with the connected VAE and convert the first sample to a PIL image.

        Args:
            context: Invocation context used to load the VAE.
            latents: Latents to decode; they are cast to `TorchDevice.choose_torch_dtype()` first.

        Returns:
            The decoded image. Decoder output is clamped to [-1, 1] and mapped linearly to [0, 255].
        """
        vae_info = context.models.load(self.vae.vae)
        with vae_info as vae:
            assert isinstance(vae, AutoEncoder)
            latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
            img = vae.decode(latents)

        img = img.clamp(-1, 1)
        # Only the first sample of the batch is converted; channels move last for PIL.
        img = rearrange(img[0], "c h w -> h w c")
        img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())

        return img_pil