diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py
index caca495ccd..b059ab23da 100644
--- a/invokeai/app/invocations/flux_text_to_image.py
+++ b/invokeai/app/invocations/flux_text_to_image.py
@@ -6,9 +6,10 @@ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
 from optimum.quanto import qfloat8
-from optimum.quanto.models import QuantizedDiffusersModel
+from optimum.quanto.models import QuantizedDiffusersModel, QuantizedTransformersModel
 from PIL import Image
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+from transformers.models.auto import AutoModelForTextEncoding

 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata
@@ -24,6 +25,10 @@ class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
     base_class = FluxTransformer2DModel


+class QuantizedModelForTextEncoding(QuantizedTransformersModel):
+    auto_class = AutoModelForTextEncoding
+
+
 @invocation(
     "flux_text_to_image",
     title="FLUX Text to Image",
@@ -196,9 +201,35 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         assert isinstance(model, CLIPTextModel)
         return model

-    @staticmethod
-    def _load_flux_text_encoder_2(path: Path) -> T5EncoderModel:
-        model = T5EncoderModel.from_pretrained(path, local_files_only=True)
+    def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
+        if self.use_8bit:
+            model_8bit_path = path / "quantized"
+            if model_8bit_path.exists():
+                # The quantized model exists, load it.
+                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
+                # something that we should be able to make much faster.
+                q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)
+
+                # Access the underlying wrapped model.
+                # We access the wrapped model, even though it is private, because it simplifies the type checking by
+                # always returning a T5EncoderModel from this function.
+                model = q_model._wrapped
+            else:
+                # The quantized model does not exist yet, quantize and save it.
+                # TODO(ryand): dtype?
+                model = T5EncoderModel.from_pretrained(path, local_files_only=True)
+                assert isinstance(model, T5EncoderModel)
+
+                q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)
+
+                model_8bit_path.mkdir(parents=True, exist_ok=True)
+                q_model.save_pretrained(model_8bit_path)
+
+                # (See earlier comment about accessing the wrapped model.)
+                model = q_model._wrapped
+        else:
+            model = T5EncoderModel.from_pretrained(path, local_files_only=True)
+
         assert isinstance(model, T5EncoderModel)
         return model
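
For reference, the quantize-and-cache pattern this patch introduces can be exercised on its own. The sketch below is not part of the diff: `load_t5_qfloat8`, the model path, and the dummy token ids are placeholders, while the `QuantizedModelForTextEncoding` wrapper and the optimum-quanto calls (`quantize`, `save_pretrained`, `from_pretrained`, `_wrapped`) mirror what the patch uses.

from pathlib import Path

import torch
from optimum.quanto import qfloat8
from optimum.quanto.models import QuantizedTransformersModel
from transformers import T5EncoderModel
from transformers.models.auto import AutoModelForTextEncoding


class QuantizedModelForTextEncoding(QuantizedTransformersModel):
    # Same wrapper the patch adds: routes loading through AutoModelForTextEncoding.
    auto_class = AutoModelForTextEncoding


def load_t5_qfloat8(path: Path) -> T5EncoderModel:
    # Hypothetical standalone version of the patch's _load_flux_text_encoder_2.
    quantized_path = path / "quantized"
    if quantized_path.exists():
        # Cache hit: requantize from the saved qfloat8 state dict.
        q_model = QuantizedModelForTextEncoding.from_pretrained(quantized_path)
    else:
        # Cache miss: quantize the full-precision weights once and save them.
        model = T5EncoderModel.from_pretrained(path, local_files_only=True)
        assert isinstance(model, T5EncoderModel)
        q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)
        quantized_path.mkdir(parents=True, exist_ok=True)
        q_model.save_pretrained(quantized_path)
    # _wrapped is private, but returning it keeps the return type a plain T5EncoderModel.
    return q_model._wrapped


if __name__ == "__main__":
    encoder = load_t5_qfloat8(Path("/models/t5_xxl_encoder"))  # placeholder path
    with torch.no_grad():
        dummy_ids = torch.tensor([[0, 5, 10, 1]])  # arbitrary ids, just to run a forward pass
        out = encoder(input_ids=dummy_ids).last_hidden_state
    print(out.shape)  # (1, 4, hidden_dim)

The first call pays the full quantization cost and writes the cache; later calls only requantize from the saved state dict, which is the step the patch's TODO flags as slow in from_pretrained(...).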