From 01d8c62c57008a4c8e0cf17daf3058de18e8da4b Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Thu, 8 Aug 2024 16:40:11 +0000
Subject: [PATCH] Make 8-bit quantization save/reload work for the FLUX
 transformer.

Reload is still very slow with the current optimum.quanto implementation.
---
 .../app/invocations/flux_text_to_image.py | 40 +++++++++----------
 1 file changed, 18 insertions(+), 22 deletions(-)

diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py
index 2efa76b4ec..caca495ccd 100644
--- a/invokeai/app/invocations/flux_text_to_image.py
+++ b/invokeai/app/invocations/flux_text_to_image.py
@@ -1,4 +1,3 @@
-import json
 from pathlib import Path
 from typing import Literal
 
@@ -6,9 +5,9 @@ import torch
 from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
-from optimum.quanto import freeze, qfloat8, quantization_map, quantize, requantize
+from optimum.quanto import qfloat8
+from optimum.quanto.models import QuantizedDiffusersModel
 from PIL import Image
-from safetensors.torch import load_file, save_file
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
@@ -21,6 +20,10 @@ TFluxModelKeys = Literal["flux-schnell"]
 FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
 
 
+class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
+    base_class = FluxTransformer2DModel
+
+
 @invocation(
     "flux_text_to_image",
     title="FLUX Text to Image",
@@ -202,23 +205,16 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
         if self.use_8bit:
             model_8bit_path = path / "quantized"
-            model_8bit_weights_path = model_8bit_path / "weights.safetensors"
-            model_8bit_map_path = model_8bit_path / "quantization_map.json"
             if model_8bit_path.exists():
                 # The quantized model exists, load it.
-                # TODO(ryand): Make loading from quantized model work properly.
-                # Reference: https://gist.github.com/AmericanPresidentJimmyCarter/873985638e1f3541ba8b00137e7dacd9?permalink_comment_id=5141210#gistcomment-5141210
-                model = FluxTransformer2DModel.from_pretrained(
-                    path,
-                    local_files_only=True,
-                )
-                assert isinstance(model, FluxTransformer2DModel)
-                model = model.to(device=torch.device("meta"))
+                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
+                # something that we should be able to make much faster.
+                q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)
 
-                state_dict = load_file(model_8bit_weights_path)
-                with open(model_8bit_map_path, "r") as f:
-                    quant_map = json.load(f)
-                requantize(model=model, state_dict=state_dict, quantization_map=quant_map)
+                # Access the underlying wrapped model.
+                # We access the wrapped model, even though it is private, because it simplifies the type checking by
+                # always returning a FluxTransformer2DModel from this function.
+                model = q_model._wrapped
             else:
                 # The quantized model does not exist yet, quantize and save it.
                 # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
@@ -227,13 +223,13 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
                 model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
                 assert isinstance(model, FluxTransformer2DModel)
 
-                quantize(model, weights=qfloat8)
-                freeze(model)
+                q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
 
                 model_8bit_path.mkdir(parents=True, exist_ok=True)
-                save_file(model.state_dict(), model_8bit_weights_path)
-                with open(model_8bit_map_path, "w") as f:
-                    json.dump(quantization_map(model), f)
+                q_model.save_pretrained(model_8bit_path)
+
+                # (See earlier comment about accessing the wrapped model.)
+                model = q_model._wrapped
             else:
                 model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
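
For reference, the quantize/save/reload round trip introduced by this patch can be exercised in isolation with roughly the sketch below. It condenses the patched _load_flux_transformer() logic into a standalone helper; the load_or_create_quantized_transformer name is illustrative, while the quanto calls (the QuantizedDiffusersModel subclass, quantize(), save_pretrained(), from_pretrained(), and the private _wrapped attribute) are exactly the ones the diff uses.

    from pathlib import Path

    import torch
    from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
    from optimum.quanto import qfloat8
    from optimum.quanto.models import QuantizedDiffusersModel


    class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
        base_class = FluxTransformer2DModel


    def load_or_create_quantized_transformer(path: Path) -> FluxTransformer2DModel:
        """Load the 8-bit transformer from path / "quantized", quantizing and saving it on first use.

        Illustrative helper; mirrors the patched _load_flux_transformer() flow.
        """
        model_8bit_path = path / "quantized"

        if model_8bit_path.exists():
            # Reload: from_pretrained() restores the saved quantization state.
            # As the TODO in the diff notes, the requantize step inside it is still very slow.
            q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)
        else:
            # First run: load the full-precision weights, quantize to float8, and persist.
            model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
            assert isinstance(model, FluxTransformer2DModel)
            q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
            model_8bit_path.mkdir(parents=True, exist_ok=True)
            q_model.save_pretrained(model_8bit_path)

        # Unwrap the private _wrapped attribute so callers always get a FluxTransformer2DModel
        # (see the comment in the diff about why the private attribute is used).
        return q_model._wrapped

Note that this only demonstrates correctness of the save/reload cycle, not its performance: per the commit message, from_pretrained() still requantizes on every load, which is why reload remains slow.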