InvokeAI/invokeai/app/invocations/flux_text_to_image.py
2024-08-26 20:17:50 -04:00

248 lines
11 KiB
Python

import json
from pathlib import Path
from typing import Literal
import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from optimum.quanto import freeze, qfloat8, quantization_map, quantize, requantize
from PIL import Image
from safetensors.torch import load_file, save_file
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.util.devices import TorchDevice
# Keys of the FLUX variants this node knows how to run.
TFluxModelKeys = Literal["flux-schnell"]

# Maps each supported model key to the repository id passed to `download_and_cache_model`.
FLUX_MODELS: dict[TFluxModelKeys, str] = {
    "flux-schnell": "black-forest-labs/FLUX.1-schnell",
}
@invocation(
    "flux_text_to_image",
    title="FLUX Text to Image",
    tags=["image"],
    category="image",
    version="1.0.0",
)
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Text-to-image generation using a FLUX model."""

    model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
    use_8bit: bool = InputField(
        default=False, description="Whether to quantize the transformer model to 8-bit precision."
    )
    positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    num_steps: int = InputField(default=4, description="Number of diffusion steps.")
    guidance: float = InputField(
        default=4.0,
        description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images.",
    )
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Run the full FLUX pipeline and save the resulting image.

        Stages: download/cache the model repo -> encode the prompt (T5 + CLIP) -> denoise -> VAE decode -> save.
        Each stage loads only the sub-models it needs.
        """
        model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])

        t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path)
        latents = self._run_diffusion(context, model_path, clip_embeddings, t5_embeddings)
        image = self._run_vae_decoding(context, model_path, latents)

        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)

    def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode ``self.positive_prompt`` with the CLIP and T5 text encoders.

        Args:
            context: Invocation context used to load the text encoders.
            flux_model_dir: Root of the downloaded FLUX model repository.

        Returns:
            ``(t5_embeddings, clip_embeddings)``: the ``prompt_embeds`` and ``pooled_prompt_embeds`` returned by
            ``FluxPipeline.encode_prompt``.

        Raises:
            ValueError: If ``self.model`` has no known T5 max sequence length.
        """
        # Determine the T5 max sequence length based on the model.
        if self.model == "flux-schnell":
            max_seq_len = 256
        # elif self.model == "flux-dev":
        #     max_seq_len = 512
        else:
            raise ValueError(f"Unknown model: {self.model}")

        # Load the CLIP tokenizer.
        clip_tokenizer_path = flux_model_dir / "tokenizer"
        clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
        assert isinstance(clip_tokenizer, CLIPTokenizer)

        # Load the T5 tokenizer.
        t5_tokenizer_path = flux_model_dir / "tokenizer_2"
        t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
        assert isinstance(t5_tokenizer, T5TokenizerFast)

        clip_text_encoder_path = flux_model_dir / "text_encoder"
        t5_text_encoder_path = flux_model_dir / "text_encoder_2"
        with (
            context.models.load_local_model(
                model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
            ) as clip_text_encoder,
            context.models.load_local_model(
                model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
            ) as t5_text_encoder,
        ):
            assert isinstance(clip_text_encoder, CLIPTextModel)
            assert isinstance(t5_text_encoder, T5EncoderModel)
            # A text-encoder-only pipeline: only encode_prompt() is used, so the other components are None.
            pipeline = FluxPipeline(
                scheduler=None,
                vae=None,
                text_encoder=clip_text_encoder,
                tokenizer=clip_tokenizer,
                text_encoder_2=t5_text_encoder,
                tokenizer_2=t5_tokenizer,
                transformer=None,
            )

            # prompt_embeds: T5 embeddings
            # pooled_prompt_embeds: CLIP embeddings
            # The third return value (text ids) is not needed here.
            prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt(
                prompt=self.positive_prompt,
                prompt_2=self.positive_prompt,
                device=TorchDevice.choose_torch_device(),
                max_sequence_length=max_seq_len,
            )

        assert isinstance(prompt_embeds, torch.Tensor)
        assert isinstance(pooled_prompt_embeds, torch.Tensor)
        return prompt_embeds, pooled_prompt_embeds

    def _run_diffusion(
        self,
        context: InvocationContext,
        flux_model_dir: Path,
        clip_embeddings: torch.Tensor,
        t5_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """Denoise from seeded random noise, conditioned on the prompt embeddings.

        Args:
            context: Invocation context used to load the transformer.
            flux_model_dir: Root of the downloaded FLUX model repository.
            clip_embeddings: Pooled CLIP prompt embeddings.
            t5_embeddings: T5 prompt embeddings.

        Returns:
            The latents produced by the pipeline with ``output_type="latent"``. They are unpacked later in
            ``_run_vae_decoding``.
        """
        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(flux_model_dir / "scheduler", local_files_only=True)

        # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
        # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
        # if the cache is not empty.
        context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)

        transformer_path = flux_model_dir / "transformer"
        with context.models.load_local_model(
            model_path=transformer_path, loader=self._load_flux_transformer
        ) as transformer:
            assert isinstance(transformer, FluxTransformer2DModel)

            # A transformer-only pipeline: the prompt is passed as precomputed embeddings.
            flux_pipeline_with_transformer = FluxPipeline(
                scheduler=scheduler,
                vae=None,
                text_encoder=None,
                tokenizer=None,
                text_encoder_2=None,
                tokenizer_2=None,
                transformer=transformer,
            )

            latents = flux_pipeline_with_transformer(
                height=self.height,
                width=self.width,
                num_inference_steps=self.num_steps,
                guidance_scale=self.guidance,
                generator=torch.Generator().manual_seed(self.seed),
                prompt_embeds=t5_embeddings,
                pooled_prompt_embeds=clip_embeddings,
                output_type="latent",
                return_dict=False,
            )[0]

        assert isinstance(latents, torch.Tensor)
        return latents

    def _run_vae_decoding(
        self,
        context: InvocationContext,
        flux_model_dir: Path,
        latents: torch.Tensor,
    ) -> Image.Image:
        """Decode the latents from ``_run_diffusion`` into a PIL image with the FLUX VAE.

        Args:
            context: Invocation context used to load the VAE.
            flux_model_dir: Root of the downloaded FLUX model repository.
            latents: Latents returned by the pipeline with ``output_type="latent"``.

        Returns:
            The first decoded image as a ``PIL.Image.Image``.
        """
        vae_path = flux_model_dir / "vae"
        with context.models.load_local_model(model_path=vae_path, loader=self._load_flux_vae) as vae:
            assert isinstance(vae, AutoencoderKL)

            # A VAE-only pipeline, used for its latent unpacking and image post-processing helpers.
            flux_pipeline_with_vae = FluxPipeline(
                scheduler=None,
                vae=vae,
                text_encoder=None,
                tokenizer=None,
                text_encoder_2=None,
                tokenizer_2=None,
                transformer=None,
            )

            latents = flux_pipeline_with_vae._unpack_latents(
                latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
            )
            # Undo the VAE's latent scaling and shift before decoding.
            latents = (
                latents / flux_pipeline_with_vae.vae.config.scaling_factor
            ) + flux_pipeline_with_vae.vae.config.shift_factor
            latents = latents.to(dtype=vae.dtype)
            image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
            image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]

        assert isinstance(image, Image.Image)
        return image

    @staticmethod
    def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
        """Load the CLIP text encoder from local files only."""
        model = CLIPTextModel.from_pretrained(path, local_files_only=True)
        assert isinstance(model, CLIPTextModel)
        return model

    @staticmethod
    def _load_flux_text_encoder_2(path: Path) -> T5EncoderModel:
        """Load the T5 text encoder from local files only."""
        model = T5EncoderModel.from_pretrained(path, local_files_only=True)
        assert isinstance(model, T5EncoderModel)
        return model

    def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
        """Load the FLUX transformer in bfloat16, optionally quantized to 8-bit (qfloat8).

        When ``self.use_8bit`` is set, a quantized copy is cached under ``path / "quantized"``
        (``weights.safetensors`` plus ``quantization_map.json``):

        - If both files exist, the model is rebuilt from them.
        - Otherwise the model is quantized now and the cache is written.
        """
        if not self.use_8bit:
            model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
            assert isinstance(model, FluxTransformer2DModel)
            return model

        model_8bit_path = path / "quantized"
        model_8bit_weights_path = model_8bit_path / "weights.safetensors"
        model_8bit_map_path = model_8bit_path / "quantization_map.json"

        # Check for both cache files rather than just the directory. A run interrupted after mkdir(), or after
        # writing the weights but before the map, would otherwise leave a half-written cache that breaks every
        # later load.
        if model_8bit_weights_path.exists() and model_8bit_map_path.exists():
            # The quantized model exists, load it.
            # TODO(ryand): Make loading from quantized model work properly.
            # Reference: https://gist.github.com/AmericanPresidentJimmyCarter/873985638e1f3541ba8b00137e7dacd9?permalink_comment_id=5141210#gistcomment-5141210
            # Load in bfloat16, the same dtype the cache was written from (see the else-branch). This keeps the
            # skeleton given to requantize() consistent with the quantize-and-save path, and avoids materializing
            # the full model in the library default dtype (presumably float32) only to discard it.
            model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
            assert isinstance(model, FluxTransformer2DModel)
            # Discard the loaded weights; requantize() is given the cached state dict instead.
            model = model.to(device=torch.device("meta"))
            state_dict = load_file(model_8bit_weights_path)
            with open(model_8bit_map_path, "r", encoding="utf-8") as f:
                quant_map = json.load(f)
            requantize(model=model, state_dict=state_dict, quantization_map=quant_map)
        else:
            # The quantized model does not exist yet, quantize and save it.
            # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
            # GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
            # here.
            model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
            assert isinstance(model, FluxTransformer2DModel)

            quantize(model, weights=qfloat8)
            freeze(model)

            model_8bit_path.mkdir(parents=True, exist_ok=True)
            # Weights are written before the map, so a partial write leaves the map missing and the next call
            # re-quantizes instead of loading a broken cache.
            save_file(model.state_dict(), model_8bit_weights_path)
            with open(model_8bit_map_path, "w", encoding="utf-8") as f:
                json.dump(quantization_map(model), f)

        return model

    @staticmethod
    def _load_flux_vae(path: Path) -> AutoencoderKL:
        """Load the FLUX VAE from local files only."""
        model = AutoencoderKL.from_pretrained(path, local_files_only=True)
        assert isinstance(model, AutoencoderKL)
        return model