Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Add support for 8-bit quantization of the FLUX T5XXL text encoder.
This commit is contained in:
parent 8cce4a40d4
commit 9381211508
@@ -6,9 +6,10 @@ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
 from optimum.quanto import qfloat8
-from optimum.quanto.models import QuantizedDiffusersModel
+from optimum.quanto.models import QuantizedDiffusersModel, QuantizedTransformersModel
 from PIL import Image
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+from transformers.models.auto import AutoModelForTextEncoding
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata
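Note on the new import: AutoModelForTextEncoding is the transformers auto-class that resolves a checkpoint's config to its encoder-only model class (for a T5 config, that is T5EncoderModel), which is what lets the generic wrapper below reinstantiate the right model type. A minimal sketch of what it does in isolation; the checkpoint path is a placeholder:

from transformers.models.auto import AutoModelForTextEncoding

# For a T5 checkpoint this resolves to T5EncoderModel under the hood.
model = AutoModelForTextEncoding.from_pretrained("/models/t5_xxl_encoder")  # placeholder path
print(type(model).__name__)  # "T5EncoderModel"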
@@ -24,6 +25,10 @@ class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
     base_class = FluxTransformer2DModel
 
 
+class QuantizedModelForTextEncoding(QuantizedTransformersModel):
+    auto_class = AutoModelForTextEncoding
+
+
 @invocation(
     "flux_text_to_image",
     title="FLUX Text to Image",
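The new wrapper mirrors the existing QuantizedFluxTransformer2DModel pattern: optimum-quanto's QuantizedTransformersModel supplies quantize(...), save_pretrained(...), and from_pretrained(...), and the subclass only pins the auto-class used to reinstantiate the model. A standalone sketch of the round trip the loader below performs, using only calls that appear in this diff; the paths are placeholders:

from pathlib import Path

from optimum.quanto import qfloat8
from optimum.quanto.models import QuantizedTransformersModel
from transformers import T5EncoderModel
from transformers.models.auto import AutoModelForTextEncoding


class QuantizedModelForTextEncoding(QuantizedTransformersModel):
    auto_class = AutoModelForTextEncoding


encoder_path = Path("/models/t5_xxl_encoder")  # placeholder

# One-time: quantize the full-precision encoder to 8-bit float weights and cache it.
model = T5EncoderModel.from_pretrained(encoder_path, local_files_only=True)
q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)
q_model.save_pretrained(encoder_path / "quantized")

# Subsequent runs: reload the cached 8-bit weights directly.
q_model = QuantizedModelForTextEncoding.from_pretrained(encoder_path / "quantized")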
@@ -196,10 +201,36 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         assert isinstance(model, CLIPTextModel)
         return model
 
-    @staticmethod
-    def _load_flux_text_encoder_2(path: Path) -> T5EncoderModel:
-        model = T5EncoderModel.from_pretrained(path, local_files_only=True)
+    def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
+        if self.use_8bit:
+            model_8bit_path = path / "quantized"
+            if model_8bit_path.exists():
+                # The quantized model exists, load it.
+                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
+                # something that we should be able to make much faster.
+                q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)
+
+                # Access the underlying wrapped model.
+                # We access the wrapped model, even though it is private, because it simplifies the type checking by
+                # always returning a T5EncoderModel from this function.
+                model = q_model._wrapped
+            else:
+                # The quantized model does not exist yet, quantize and save it.
+                # TODO(ryand): dtype?
+                model = T5EncoderModel.from_pretrained(path, local_files_only=True)
+                assert isinstance(model, T5EncoderModel)
+
+                q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)
+
+                model_8bit_path.mkdir(parents=True, exist_ok=True)
+                q_model.save_pretrained(model_8bit_path)
+
+                # (See earlier comment about accessing the wrapped model.)
+                model = q_model._wrapped
+        else:
+            model = T5EncoderModel.from_pretrained(path, local_files_only=True)
+
         assert isinstance(model, T5EncoderModel)
         return model
 
     def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
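One detail worth noting: QuantizedTransformersModel wraps the Hugging Face model rather than subclassing it, which is why the loader reaches into the private _wrapped attribute before returning. An illustrative check, assuming the quantized cache from the earlier sketch exists:

q_model = QuantizedModelForTextEncoding.from_pretrained(encoder_path / "quantized")
assert not isinstance(q_model, T5EncoderModel)  # the wrapper is not itself a T5EncoderModel
assert isinstance(q_model._wrapped, T5EncoderModel)  # the underlying encoder is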