Add support for 8-bit quantizatino of the FLUX T5XXL text encoder.

This commit is contained in:
Ryan Dick 2024-08-08 18:23:20 +00:00 committed by Brandon Rising
parent 8cce4a40d4
commit 9381211508

View File

@ -6,9 +6,10 @@ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from optimum.quanto import qfloat8
from optimum.quanto.models import QuantizedDiffusersModel
from optimum.quanto.models import QuantizedDiffusersModel, QuantizedTransformersModel
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from transformers.models.auto import AutoModelForTextEncoding
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata
@ -24,6 +25,10 @@ class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
base_class = FluxTransformer2DModel
class QuantizedModelForTextEncoding(QuantizedTransformersModel):
auto_class = AutoModelForTextEncoding
@invocation(
"flux_text_to_image",
title="FLUX Text to Image",
@ -196,9 +201,35 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
assert isinstance(model, CLIPTextModel)
return model
@staticmethod
def _load_flux_text_encoder_2(path: Path) -> T5EncoderModel:
model = T5EncoderModel.from_pretrained(path, local_files_only=True)
def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
if self.use_8bit:
model_8bit_path = path / "quantized"
if model_8bit_path.exists():
# The quantized model exists, load it.
# TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
# something that we should be able to make much faster.
q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)
# Access the underlying wrapped model.
# We access the wrapped model, even though it is private, because it simplifies the type checking by
# always returning a T5EncoderModel from this function.
model = q_model._wrapped
else:
# The quantized model does not exist yet, quantize and save it.
# TODO(ryand): dtype?
model = T5EncoderModel.from_pretrained(path, local_files_only=True)
assert isinstance(model, T5EncoderModel)
q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)
model_8bit_path.mkdir(parents=True, exist_ok=True)
q_model.save_pretrained(model_8bit_path)
# (See earlier comment about accessing the wrapped model.)
model = q_model._wrapped
else:
model = T5EncoderModel.from_pretrained(path, local_files_only=True)
assert isinstance(model, T5EncoderModel)
return model