mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
First draft of FluxTextToImageInvocation.
This commit is contained in:
parent 0b77511271
commit b39031ea53
201
invokeai/app/invocations/flux_text_to_image.py
Normal file
@@ -0,0 +1,201 @@
from pathlib import Path
from typing import Literal

import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux import FluxPipeline
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.util.devices import TorchDevice

TFluxModelKeys = Literal["flux-schnell"]
FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
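# FLUX.1-schnell is the fast, timestep-distilled FLUX variant, which is why num_steps
# defaults to just 4 below. (A "flux-dev" entry is sketched but not yet enabled in
# _run_t5_text_encoder.)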


@invocation(
    "flux_text_to_image",
    title="FLUX Text to Image",
    tags=["image"],
    category="image",
    version="1.0.0",
)
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Text-to-image generation using a FLUX model."""

    model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
    positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    num_steps: int = InputField(default=4, description="Number of diffusion steps.")
    guidance: float = InputField(
        default=4.0,
        description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images.",
    )
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")

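    # Generation is split into four stages (CLIP encode, T5 encode, transformer
    # denoising, VAE decode), each loading its sub-model independently via a context
    # manager so that only one large component needs to be resident at a time.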
    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])

        clip_embeddings = self._run_clip_text_encoder(context, model_path)
        t5_embeddings = self._run_t5_text_encoder(context, model_path)
        latents = self._run_diffusion(context, model_path, clip_embeddings, t5_embeddings)
        image = self._run_vae_decoding(context, model_path, latents)
        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)

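    # Note for each stage below: a FluxPipeline is constructed with only the
    # components that stage needs (the rest are passed as None), which reuses the
    # diffusers pipeline helpers without materializing the full set of sub-models.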
    def _run_clip_text_encoder(self, context: InvocationContext, flux_model_dir: Path) -> torch.Tensor:
        """Run the CLIP text encoder."""
        tokenizer_path = flux_model_dir / "tokenizer"
        tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
        assert isinstance(tokenizer, CLIPTokenizer)

        text_encoder_path = flux_model_dir / "text_encoder"
        with context.models.load_local_model(
            model_path=text_encoder_path, loader=self._load_flux_text_encoder
        ) as text_encoder:
            assert isinstance(text_encoder, CLIPTextModel)
            flux_pipeline_with_te = FluxPipeline(
                scheduler=None,
                vae=None,
                text_encoder=text_encoder,
                tokenizer=tokenizer,
                text_encoder_2=None,
                tokenizer_2=None,
                transformer=None,
            )

            return flux_pipeline_with_te._get_clip_prompt_embeds(
                prompt=self.positive_prompt, device=TorchDevice.choose_torch_device()
            )

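    # The T5 max sequence length depends on the model variant: 256 for flux-schnell
    # (and, per the commented-out branch below, 512 for flux-dev).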
    def _run_t5_text_encoder(self, context: InvocationContext, flux_model_dir: Path) -> torch.Tensor:
        """Run the T5 text encoder."""

        if self.model == "flux-schnell":
            max_seq_len = 256
        # elif self.model == "flux-dev":
        #     max_seq_len = 512
        else:
            raise ValueError(f"Unknown model: {self.model}")

        tokenizer_path = flux_model_dir / "tokenizer_2"
        tokenizer_2 = T5TokenizerFast.from_pretrained(tokenizer_path, local_files_only=True)
        assert isinstance(tokenizer_2, T5TokenizerFast)

        text_encoder_path = flux_model_dir / "text_encoder_2"
        with context.models.load_local_model(
            model_path=text_encoder_path, loader=self._load_flux_text_encoder_2
        ) as text_encoder_2:
            flux_pipeline_with_te2 = FluxPipeline(
                scheduler=None,
                vae=None,
                text_encoder=None,
                tokenizer=None,
                text_encoder_2=text_encoder_2,
                tokenizer_2=tokenizer_2,
                transformer=None,
            )

            return flux_pipeline_with_te2._get_t5_prompt_embeds(
                prompt=self.positive_prompt, max_sequence_length=max_seq_len, device=TorchDevice.choose_torch_device()
            )

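    # Denoising runs in FLUX's packed latent space; output_type="latent" defers VAE
    # decoding to _run_vae_decoding. Per the parameter names, the T5 embeddings are
    # the main prompt conditioning and the pooled CLIP embeddings a global signal.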
    def _run_diffusion(
        self,
        context: InvocationContext,
        flux_model_dir: Path,
        clip_embeddings: torch.Tensor,
        t5_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        scheduler = FlowMatchEulerDiscreteScheduler()

        transformer_path = flux_model_dir / "transformer"
        with context.models.load_local_model(
            model_path=transformer_path, loader=self._load_flux_transformer
        ) as transformer:
            assert isinstance(transformer, FluxTransformer2DModel)

            flux_pipeline_with_transformer = FluxPipeline(
                scheduler=scheduler,
                vae=None,
                text_encoder=None,
                tokenizer=None,
                text_encoder_2=None,
                tokenizer_2=None,
                transformer=transformer,
            )

            return flux_pipeline_with_transformer(
                height=self.height,
                width=self.width,
                num_inference_steps=self.num_steps,
                guidance_scale=self.guidance,
                generator=torch.Generator().manual_seed(self.seed),
                prompt_embeds=t5_embeddings,
                pooled_prompt_embeds=clip_embeddings,
                output_type="latent",
                return_dict=False,
            )[0]

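    # Decoding reverses the transformer's latent packing (_unpack_latents) and the
    # VAE's normalization (scaling_factor / shift_factor) before running the decoder.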
    def _run_vae_decoding(
        self,
        context: InvocationContext,
        flux_model_dir: Path,
        latents: torch.Tensor,
    ) -> Image.Image:
        vae_path = flux_model_dir / "vae"
        with context.models.load_local_model(model_path=vae_path, loader=self._load_flux_vae) as vae:
            assert isinstance(vae, AutoencoderKL)

            flux_pipeline_with_vae = FluxPipeline(
                scheduler=None,
                vae=vae,
                text_encoder=None,
                tokenizer=None,
                text_encoder_2=None,
                tokenizer_2=None,
                transformer=None,
            )

            latents = flux_pipeline_with_vae._unpack_latents(
                latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
            )
            latents = (
                latents / flux_pipeline_with_vae.vae.config.scaling_factor
            ) + flux_pipeline_with_vae.vae.config.shift_factor
            image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
            # postprocess(..., output_type="pil") returns a list of PIL images; take
            # the single image so the assert below holds.
            image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]

            assert isinstance(image, Image.Image)
            return image

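    # Loader callbacks handed to context.models.load_local_model() above. Each one
    # loads a single sub-model from the cached FLUX checkpoint directory.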
    @staticmethod
    def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
        model = CLIPTextModel.from_pretrained(path, local_files_only=True)
        assert isinstance(model, CLIPTextModel)
        return model

    @staticmethod
    def _load_flux_text_encoder_2(path: Path) -> T5EncoderModel:
        model = T5EncoderModel.from_pretrained(path, local_files_only=True)
        assert isinstance(model, T5EncoderModel)
        return model

    @staticmethod
    def _load_flux_transformer(path: Path) -> FluxTransformer2DModel:
        model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True)
        assert isinstance(model, FluxTransformer2DModel)
        return model

    @staticmethod
    def _load_flux_vae(path: Path) -> AutoencoderKL:
        model = AutoencoderKL.from_pretrained(path, local_files_only=True)
        assert isinstance(model, AutoencoderKL)
        return model