Make 8-bit quantization save/reload work for the FLUX transformer. Reload is still very slow with the current optimum.quanto implementation.

This commit is contained in:
Ryan Dick 2024-08-08 16:40:11 +00:00 committed by Brandon
parent 55a242b2d6
commit 01d8c62c57

View File

@ -1,4 +1,3 @@
import json
from pathlib import Path
from typing import Literal
@ -6,9 +5,9 @@ import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from optimum.quanto import freeze, qfloat8, quantization_map, quantize, requantize
from optimum.quanto import qfloat8
from optimum.quanto.models import QuantizedDiffusersModel
from PIL import Image
from safetensors.torch import load_file, save_file
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
@ -21,6 +20,10 @@ TFluxModelKeys = Literal["flux-schnell"]
FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
base_class = FluxTransformer2DModel
@invocation(
"flux_text_to_image",
title="FLUX Text to Image",
@ -202,23 +205,16 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
if self.use_8bit:
model_8bit_path = path / "quantized"
model_8bit_weights_path = model_8bit_path / "weights.safetensors"
model_8bit_map_path = model_8bit_path / "quantization_map.json"
if model_8bit_path.exists():
# The quantized model exists, load it.
# TODO(ryand): Make loading from quantized model work properly.
# Reference: https://gist.github.com/AmericanPresidentJimmyCarter/873985638e1f3541ba8b00137e7dacd9?permalink_comment_id=5141210#gistcomment-5141210
model = FluxTransformer2DModel.from_pretrained(
path,
local_files_only=True,
)
assert isinstance(model, FluxTransformer2DModel)
model = model.to(device=torch.device("meta"))
# TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
# something that we should be able to make much faster.
q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)
state_dict = load_file(model_8bit_weights_path)
with open(model_8bit_map_path, "r") as f:
quant_map = json.load(f)
requantize(model=model, state_dict=state_dict, quantization_map=quant_map)
# Access the underlying wrapped model.
# We access the wrapped model, even though it is private, because it simplifies the type checking by
# always returning a FluxTransformer2DModel from this function.
model = q_model._wrapped
else:
# The quantized model does not exist yet, quantize and save it.
# TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
@ -227,13 +223,13 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
assert isinstance(model, FluxTransformer2DModel)
quantize(model, weights=qfloat8)
freeze(model)
q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
model_8bit_path.mkdir(parents=True, exist_ok=True)
save_file(model.state_dict(), model_8bit_weights_path)
with open(model_8bit_map_path, "w") as f:
json.dump(quantization_map(model), f)
q_model.save_pretrained(model_8bit_path)
# (See earlier comment about accessing the wrapped model.)
model = q_model._wrapped
else:
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)