mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add support for 8-bit quantizatino of the FLUX T5XXL text encoder.
This commit is contained in:
parent
8cce4a40d4
commit
9381211508
@ -6,9 +6,10 @@ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
|
|||||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||||
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
|
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
|
||||||
from optimum.quanto import qfloat8
|
from optimum.quanto import qfloat8
|
||||||
from optimum.quanto.models import QuantizedDiffusersModel
|
from optimum.quanto.models import QuantizedDiffusersModel, QuantizedTransformersModel
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||||
|
from transformers.models.auto import AutoModelForTextEncoding
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||||
from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata
|
from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata
|
||||||
@ -24,6 +25,10 @@ class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
|
|||||||
base_class = FluxTransformer2DModel
|
base_class = FluxTransformer2DModel
|
||||||
|
|
||||||
|
|
||||||
|
class QuantizedModelForTextEncoding(QuantizedTransformersModel):
|
||||||
|
auto_class = AutoModelForTextEncoding
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"flux_text_to_image",
|
"flux_text_to_image",
|
||||||
title="FLUX Text to Image",
|
title="FLUX Text to Image",
|
||||||
@ -196,9 +201,35 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
assert isinstance(model, CLIPTextModel)
|
assert isinstance(model, CLIPTextModel)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@staticmethod
|
def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
|
||||||
def _load_flux_text_encoder_2(path: Path) -> T5EncoderModel:
|
if self.use_8bit:
|
||||||
model = T5EncoderModel.from_pretrained(path, local_files_only=True)
|
model_8bit_path = path / "quantized"
|
||||||
|
if model_8bit_path.exists():
|
||||||
|
# The quantized model exists, load it.
|
||||||
|
# TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
|
||||||
|
# something that we should be able to make much faster.
|
||||||
|
q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)
|
||||||
|
|
||||||
|
# Access the underlying wrapped model.
|
||||||
|
# We access the wrapped model, even though it is private, because it simplifies the type checking by
|
||||||
|
# always returning a T5EncoderModel from this function.
|
||||||
|
model = q_model._wrapped
|
||||||
|
else:
|
||||||
|
# The quantized model does not exist yet, quantize and save it.
|
||||||
|
# TODO(ryand): dtype?
|
||||||
|
model = T5EncoderModel.from_pretrained(path, local_files_only=True)
|
||||||
|
assert isinstance(model, T5EncoderModel)
|
||||||
|
|
||||||
|
q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)
|
||||||
|
|
||||||
|
model_8bit_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
q_model.save_pretrained(model_8bit_path)
|
||||||
|
|
||||||
|
# (See earlier comment about accessing the wrapped model.)
|
||||||
|
model = q_model._wrapped
|
||||||
|
else:
|
||||||
|
model = T5EncoderModel.from_pretrained(path, local_files_only=True)
|
||||||
|
|
||||||
assert isinstance(model, T5EncoderModel)
|
assert isinstance(model, T5EncoderModel)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user