Mirror of https://github.com/invoke-ai/InvokeAI
Make 8-bit quantization save/reload work for the FLUX transformer. Reload is still very slow with the current optimum.quanto implementation.
This commit is contained in:
parent 55a242b2d6
commit 01d8c62c57
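As context for the diff below: the save half of the round trip that this commit moves to can be sketched as follows. This is only an illustration assembled from calls that appear in the diff (the QuantizedDiffusersModel wrapper, its quantize() and save_pretrained() methods, and the private _wrapped attribute); the quantize_and_save helper, its parameters, and its return value are hypothetical and not part of the commit.

from pathlib import Path

import torch
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from optimum.quanto import qfloat8
from optimum.quanto.models import QuantizedDiffusersModel


class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    # Tells optimum.quanto which diffusers class this wrapper (re)constructs.
    base_class = FluxTransformer2DModel


def quantize_and_save(model_path: Path, quantized_path: Path) -> FluxTransformer2DModel:
    # Load the bfloat16 transformer, quantize its weights to float8, and persist the
    # quantized model with save_pretrained() (replacing the manual weights/quantization-map
    # files written before this commit).
    model = FluxTransformer2DModel.from_pretrained(model_path, local_files_only=True, torch_dtype=torch.bfloat16)
    q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
    quantized_path.mkdir(parents=True, exist_ok=True)
    q_model.save_pretrained(quantized_path)
    # Unwrap so callers keep working with a plain FluxTransformer2DModel.
    return q_model._wrapped

Returning the private _wrapped attribute mirrors the diff's own design choice: it keeps the loader's return type a plain FluxTransformer2DModel, at the cost of reaching into optimum.quanto internals.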
@@ -1,4 +1,3 @@
-import json
 from pathlib import Path
 from typing import Literal
 
@@ -6,9 +5,9 @@ import torch
 from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
-from optimum.quanto import freeze, qfloat8, quantization_map, quantize, requantize
+from optimum.quanto import qfloat8
+from optimum.quanto.models import QuantizedDiffusersModel
 from PIL import Image
-from safetensors.torch import load_file, save_file
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
@@ -21,6 +20,10 @@ TFluxModelKeys = Literal["flux-schnell"]
 FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
 
 
+class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
+    base_class = FluxTransformer2DModel
+
+
 @invocation(
     "flux_text_to_image",
     title="FLUX Text to Image",
@@ -202,23 +205,16 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
         if self.use_8bit:
             model_8bit_path = path / "quantized"
-            model_8bit_weights_path = model_8bit_path / "weights.safetensors"
-            model_8bit_map_path = model_8bit_path / "quantization_map.json"
             if model_8bit_path.exists():
                 # The quantized model exists, load it.
-                # TODO(ryand): Make loading from quantized model work properly.
-                # Reference: https://gist.github.com/AmericanPresidentJimmyCarter/873985638e1f3541ba8b00137e7dacd9?permalink_comment_id=5141210#gistcomment-5141210
-                model = FluxTransformer2DModel.from_pretrained(
-                    path,
-                    local_files_only=True,
-                )
-                assert isinstance(model, FluxTransformer2DModel)
-                model = model.to(device=torch.device("meta"))
+                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
+                # something that we should be able to make much faster.
+                q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)
 
-                state_dict = load_file(model_8bit_weights_path)
-                with open(model_8bit_map_path, "r") as f:
-                    quant_map = json.load(f)
-                requantize(model=model, state_dict=state_dict, quantization_map=quant_map)
+                # Access the underlying wrapped model.
+                # We access the wrapped model, even though it is private, because it simplifies the type checking by
+                # always returning a FluxTransformer2DModel from this function.
+                model = q_model._wrapped
             else:
                 # The quantized model does not exist yet, quantize and save it.
                 # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
@@ -227,13 +223,13 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
                 model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
                 assert isinstance(model, FluxTransformer2DModel)
 
-                quantize(model, weights=qfloat8)
-                freeze(model)
+                q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
 
                 model_8bit_path.mkdir(parents=True, exist_ok=True)
-                save_file(model.state_dict(), model_8bit_weights_path)
-                with open(model_8bit_map_path, "w") as f:
-                    json.dump(quantization_map(model), f)
+                q_model.save_pretrained(model_8bit_path)
+
+                # (See earlier comment about accessing the wrapped model.)
+                model = q_model._wrapped
         else:
             model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
 
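The reload half, which the commit message and the TODO in the diff identify as still very slow, reduces to the sketch below. Again, only calls shown in the diff are used; the load_quantized helper name and its parameter are illustrative.

from pathlib import Path

from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from optimum.quanto.models import QuantizedDiffusersModel


class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    base_class = FluxTransformer2DModel


def load_quantized(quantized_path: Path) -> FluxTransformer2DModel:
    # from_pretrained() restores the quantized transformer from the directory written by
    # save_pretrained(); per the TODO in the diff above, the requantize(...) operation it
    # performs internally is what makes reload slow today.
    q_model = QuantizedFluxTransformer2DModel.from_pretrained(quantized_path)
    # Unwrap the private _wrapped attribute to return a plain FluxTransformer2DModel.
    return q_model._wrapped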