Split a FluxTextEncoderInvocation out from the FluxTextToImageInvocation. This has the advantage that we benefit from automatic invocation caching when the prompt isn't changed.
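For orientation, here is a minimal sketch of why the split helps: if the graph executor memoizes invocation outputs keyed on the node type and its input field values, a text-encoder node whose prompt has not changed can be served from the cache, and only the denoising node has to re-run. The InvocationCache class and make_key helper below are illustrative assumptions for this sketch, not InvokeAI's actual cache API.

import hashlib
import json
from typing import Any, Callable


class InvocationCache:
    """Illustrative memoization of invocation outputs keyed on (node type, input fields)."""

    def __init__(self) -> None:
        self._cache: dict[str, Any] = {}

    @staticmethod
    def make_key(node_type: str, fields: dict[str, Any]) -> str:
        # Hash the node type together with its JSON-serializable input fields.
        payload = json.dumps({"type": node_type, "fields": fields}, sort_keys=True)
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()

    def get_or_run(self, node_type: str, fields: dict[str, Any], run: Callable[[], Any]) -> Any:
        key = self.make_key(node_type, fields)
        if key not in self._cache:
            self._cache[key] = run()  # only executed on a cache miss
        return self._cache[key]


# With text encoding in its own node, an unchanged prompt hits the cache and only the
# image-generation node needs to run again.
cache = InvocationCache()
cond = cache.get_or_run("flux_text_encoder", {"prompt": "a red barn"}, lambda: "conditioning-tensor")
cond_again = cache.get_or_run("flux_text_encoder", {"prompt": "a red barn"}, lambda: "conditioning-tensor")
assert cond is cond_again  # the second call is served from the cache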
This commit is contained in:
parent a8a2fc106d
commit f0baf880b5

invokeai/app/invocations/flux_text_encoder.py (new file, 135 lines)

@@ -0,0 +1,135 @@
from pathlib import Path

import torch
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from optimum.quanto import qfloat8
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import InputField
from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding, TFluxModelKeys
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "flux_text_encoder",
    title="FLUX Text Encoding",
    tags=["image"],
    category="image",
    version="1.0.0",
)
class FluxTextEncoderInvocation(BaseInvocation):
    model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
    use_8bit: bool = InputField(
        default=False, description="Whether to quantize the transformer model to 8-bit precision."
    )
    positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")

    # TODO(ryand): Should we create a new return type for this invocation? This ConditioningOutput is clearly not
    # compatible with other ConditioningOutputs.
    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ConditioningOutput:
        model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])

        t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path)
        conditioning_data = ConditioningFieldData(
            conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
        )

        conditioning_name = context.conditioning.save(conditioning_data)
        return ConditioningOutput.build(conditioning_name)

    def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
        # Determine the T5 max sequence length based on the model.
        if self.model == "flux-schnell":
            max_seq_len = 256
        # elif self.model == "flux-dev":
        #     max_seq_len = 512
        else:
            raise ValueError(f"Unknown model: {self.model}")

        # Load the CLIP tokenizer.
        clip_tokenizer_path = flux_model_dir / "tokenizer"
        clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
        assert isinstance(clip_tokenizer, CLIPTokenizer)

        # Load the T5 tokenizer.
        t5_tokenizer_path = flux_model_dir / "tokenizer_2"
        t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
        assert isinstance(t5_tokenizer, T5TokenizerFast)

        clip_text_encoder_path = flux_model_dir / "text_encoder"
        t5_text_encoder_path = flux_model_dir / "text_encoder_2"
        with (
            context.models.load_local_model(
                model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
            ) as clip_text_encoder,
            context.models.load_local_model(
                model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
            ) as t5_text_encoder,
        ):
            assert isinstance(clip_text_encoder, CLIPTextModel)
            assert isinstance(t5_text_encoder, T5EncoderModel)
            pipeline = FluxPipeline(
                scheduler=None,
                vae=None,
                text_encoder=clip_text_encoder,
                tokenizer=clip_tokenizer,
                text_encoder_2=t5_text_encoder,
                tokenizer_2=t5_tokenizer,
                transformer=None,
            )

            # prompt_embeds: T5 embeddings
            # pooled_prompt_embeds: CLIP embeddings
            prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
                prompt=self.positive_prompt,
                prompt_2=self.positive_prompt,
                device=TorchDevice.choose_torch_device(),
                max_sequence_length=max_seq_len,
            )

            assert isinstance(prompt_embeds, torch.Tensor)
            assert isinstance(pooled_prompt_embeds, torch.Tensor)
            return prompt_embeds, pooled_prompt_embeds

    @staticmethod
    def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
        model = CLIPTextModel.from_pretrained(path, local_files_only=True)
        assert isinstance(model, CLIPTextModel)
        return model

    def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
        if self.use_8bit:
            model_8bit_path = path / "quantized"
            if model_8bit_path.exists():
                # The quantized model exists, load it.
                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
                # something that we should be able to make much faster.
                q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)

                # Access the underlying wrapped model.
                # We access the wrapped model, even though it is private, because it simplifies the type checking by
                # always returning a T5EncoderModel from this function.
                model = q_model._wrapped
            else:
                # The quantized model does not exist yet, quantize and save it.
                # TODO(ryand): dtype?
                model = T5EncoderModel.from_pretrained(path, local_files_only=True)
                assert isinstance(model, T5EncoderModel)

                q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)

                model_8bit_path.mkdir(parents=True, exist_ok=True)
                q_model.save_pretrained(model_8bit_path)

                # (See earlier comment about accessing the wrapped model.)
                model = q_model._wrapped
        else:
            model = T5EncoderModel.from_pretrained(path, local_files_only=True)

        assert isinstance(model, T5EncoderModel)
        return model
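The quantize-on-first-use branch above goes through the QuantizedModelForTextEncoding wrapper defined in flux_text_to_image.py. As a rough sketch of what such a wrapper presumably does with optimum.quanto (an assumption about its internals, not a description of the actual class; the helper function name is made up here), weight-only qfloat8 quantization of the T5 encoder looks roughly like this:

from optimum.quanto import freeze, qfloat8, quantize
from transformers import T5EncoderModel


def quantize_t5_to_qfloat8(path: str) -> T5EncoderModel:
    # Load the full-precision T5 text encoder from a local FLUX checkout.
    model = T5EncoderModel.from_pretrained(path, local_files_only=True)
    # Swap supported modules (e.g. nn.Linear) for quantized equivalents with qfloat8 weights.
    quantize(model, weights=qfloat8)
    # Materialize the quantized weights in place of the float originals.
    freeze(model)
    return model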
invokeai/app/invocations/flux_text_to_image.py

@@ -7,16 +7,22 @@ from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from optimum.quanto import qfloat8
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from transformers.models.auto import AutoModelForTextEncoding

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata
from invokeai.app.invocations.fields import (
    ConditioningField,
    FieldDescriptions,
    Input,
    InputField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo

TFluxModelKeys = Literal["flux-schnell"]
FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
@@ -44,7 +50,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    use_8bit: bool = InputField(
        default=False, description="Whether to quantize the transformer model to 8-bit precision."
    )
    positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
    positive_text_conditioning: ConditioningField = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
    )
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    num_steps: int = InputField(default=4, description="Number of diffusion steps.")
@@ -58,66 +66,17 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    def invoke(self, context: InvocationContext) -> ImageOutput:
        model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])

        t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path)
        latents = self._run_diffusion(context, model_path, clip_embeddings, t5_embeddings)
        # Load the conditioning data.
        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
        assert len(cond_data.conditionings) == 1
        flux_conditioning = cond_data.conditionings[0]
        assert isinstance(flux_conditioning, FLUXConditioningInfo)

        latents = self._run_diffusion(context, model_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
        image = self._run_vae_decoding(context, model_path, latents)
        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)

    def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
        # Determine the T5 max sequence length based on the model.
        if self.model == "flux-schnell":
            max_seq_len = 256
        # elif self.model == "flux-dev":
        #     max_seq_len = 512
        else:
            raise ValueError(f"Unknown model: {self.model}")

        # Load the CLIP tokenizer.
        clip_tokenizer_path = flux_model_dir / "tokenizer"
        clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
        assert isinstance(clip_tokenizer, CLIPTokenizer)

        # Load the T5 tokenizer.
        t5_tokenizer_path = flux_model_dir / "tokenizer_2"
        t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
        assert isinstance(t5_tokenizer, T5TokenizerFast)

        clip_text_encoder_path = flux_model_dir / "text_encoder"
        t5_text_encoder_path = flux_model_dir / "text_encoder_2"
        with (
            context.models.load_local_model(
                model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
            ) as clip_text_encoder,
            context.models.load_local_model(
                model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
            ) as t5_text_encoder,
        ):
            assert isinstance(clip_text_encoder, CLIPTextModel)
            assert isinstance(t5_text_encoder, T5EncoderModel)
            pipeline = FluxPipeline(
                scheduler=None,
                vae=None,
                text_encoder=clip_text_encoder,
                tokenizer=clip_tokenizer,
                text_encoder_2=t5_text_encoder,
                tokenizer_2=t5_tokenizer,
                transformer=None,
            )

            # prompt_embeds: T5 embeddings
            # pooled_prompt_embeds: CLIP embeddings
            prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
                prompt=self.positive_prompt,
                prompt_2=self.positive_prompt,
                device=TorchDevice.choose_torch_device(),
                max_sequence_length=max_seq_len,
            )

            assert isinstance(prompt_embeds, torch.Tensor)
            assert isinstance(pooled_prompt_embeds, torch.Tensor)
            return prompt_embeds, pooled_prompt_embeds

    def _run_diffusion(
        self,
        context: InvocationContext,
@@ -199,44 +158,6 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
        assert isinstance(image, Image.Image)
        return image

    @staticmethod
    def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
        model = CLIPTextModel.from_pretrained(path, local_files_only=True)
        assert isinstance(model, CLIPTextModel)
        return model

    def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
        if self.use_8bit:
            model_8bit_path = path / "quantized"
            if model_8bit_path.exists():
                # The quantized model exists, load it.
                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
                # something that we should be able to make much faster.
                q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)

                # Access the underlying wrapped model.
                # We access the wrapped model, even though it is private, because it simplifies the type checking by
                # always returning a T5EncoderModel from this function.
                model = q_model._wrapped
            else:
                # The quantized model does not exist yet, quantize and save it.
                # TODO(ryand): dtype?
                model = T5EncoderModel.from_pretrained(path, local_files_only=True)
                assert isinstance(model, T5EncoderModel)

                q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)

                model_8bit_path.mkdir(parents=True, exist_ok=True)
                q_model.save_pretrained(model_8bit_path)

                # (See earlier comment about accessing the wrapped model.)
                model = q_model._wrapped
        else:
            model = T5EncoderModel.from_pretrained(path, local_files_only=True)

        assert isinstance(model, T5EncoderModel)
        return model

    def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
        if self.use_8bit:
            model_8bit_path = path / "quantized"
invokeai/backend/stable_diffusion/diffusion/conditioning_data.py

@@ -25,11 +25,6 @@ class BasicConditioningInfo:
        return self


@dataclass
class ConditioningFieldData:
    conditionings: List[BasicConditioningInfo]


@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
    """SDXL text conditioning information produced by Compel."""

@@ -43,6 +38,17 @@ class SDXLConditioningInfo(BasicConditioningInfo):
        return super().to(device=device, dtype=dtype)


@dataclass
class FLUXConditioningInfo:
    clip_embeds: torch.Tensor
    t5_embeds: torch.Tensor


@dataclass
class ConditioningFieldData:
    conditionings: List[BasicConditioningInfo] | List[SDXLConditioningInfo] | List[FLUXConditioningInfo]


@dataclass
class IPAdapterConditioningInfo:
    cond_image_prompt_embeds: torch.Tensor
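With the ConditioningFieldData union widened above, a consumer that receives FLUX conditioning narrows the list element type before touching the embeddings, which is exactly what the updated FluxTextToImageInvocation does. A small sketch of that round trip follows; the tensor shapes are placeholders for illustration only.

import torch

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    ConditioningFieldData,
    FLUXConditioningInfo,
)

# Build conditioning data the way FluxTextEncoderInvocation does (shapes are placeholders).
clip_embeds = torch.zeros(1, 768)
t5_embeds = torch.zeros(1, 256, 4096)
cond_data = ConditioningFieldData(
    conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeds, t5_embeds=t5_embeds)]
)

# Consume it the way FluxTextToImageInvocation does: narrow the union before use.
assert len(cond_data.conditionings) == 1
flux_conditioning = cond_data.conditionings[0]
assert isinstance(flux_conditioning, FLUXConditioningInfo)
print(flux_conditioning.clip_embeds.shape, flux_conditioning.t5_embeds.shape)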