Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Make 8-bit quantization save/reload work for the FLUX transformer. Reload is still very slow with the current optimum.quanto implementation.
This commit is contained in:
parent 55a242b2d6
commit 01d8c62c57
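For reference, the save/reload round trip that the diff below switches to can be sketched in isolation roughly as follows. This is a minimal sketch based on the optimum.quanto QuantizedDiffusersModel calls used in this commit; the helper names quantize_and_save / load_quantized and the paths are placeholders, not part of the InvokeAI code.

from pathlib import Path

import torch
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from optimum.quanto import qfloat8
from optimum.quanto.models import QuantizedDiffusersModel


class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    # Tell the quanto wrapper which diffusers model class it wraps.
    base_class = FluxTransformer2DModel


def quantize_and_save(model_path: Path, quantized_path: Path) -> None:
    # Load the full-precision transformer, quantize its weights to float8, and
    # save the quantized weights plus quantization metadata in one call.
    model = FluxTransformer2DModel.from_pretrained(model_path, local_files_only=True, torch_dtype=torch.bfloat16)
    q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
    quantized_path.mkdir(parents=True, exist_ok=True)
    q_model.save_pretrained(quantized_path)


def load_quantized(quantized_path: Path) -> FluxTransformer2DModel:
    # Reload the quantized checkpoint. The requantize(...) step inside
    # from_pretrained(...) is what currently makes reload slow.
    q_model = QuantizedFluxTransformer2DModel.from_pretrained(quantized_path)
    # Return the wrapped model so callers see a plain FluxTransformer2DModel.
    return q_model._wrapped

Returning q_model._wrapped relies on a private attribute, but it keeps the return type a plain FluxTransformer2DModel, which is the trade-off noted in the comments in the diff below.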
@@ -1,4 +1,3 @@
-import json
 from pathlib import Path
 from typing import Literal
 
@@ -6,9 +5,9 @@ import torch
 from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
-from optimum.quanto import freeze, qfloat8, quantization_map, quantize, requantize
+from optimum.quanto import qfloat8
+from optimum.quanto.models import QuantizedDiffusersModel
 from PIL import Image
-from safetensors.torch import load_file, save_file
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
@@ -21,6 +20,10 @@ TFluxModelKeys = Literal["flux-schnell"]
 FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
 
 
+class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
+    base_class = FluxTransformer2DModel
+
+
 @invocation(
     "flux_text_to_image",
     title="FLUX Text to Image",
@@ -202,23 +205,16 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
         if self.use_8bit:
             model_8bit_path = path / "quantized"
-            model_8bit_weights_path = model_8bit_path / "weights.safetensors"
-            model_8bit_map_path = model_8bit_path / "quantization_map.json"
             if model_8bit_path.exists():
                 # The quantized model exists, load it.
-                # TODO(ryand): Make loading from quantized model work properly.
-                # Reference: https://gist.github.com/AmericanPresidentJimmyCarter/873985638e1f3541ba8b00137e7dacd9?permalink_comment_id=5141210#gistcomment-5141210
-                model = FluxTransformer2DModel.from_pretrained(
-                    path,
-                    local_files_only=True,
-                )
-                assert isinstance(model, FluxTransformer2DModel)
-                model = model.to(device=torch.device("meta"))
+                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
+                # something that we should be able to make much faster.
+                q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)
 
-                state_dict = load_file(model_8bit_weights_path)
-                with open(model_8bit_map_path, "r") as f:
-                    quant_map = json.load(f)
-                requantize(model=model, state_dict=state_dict, quantization_map=quant_map)
+                # Access the underlying wrapped model.
+                # We access the wrapped model, even though it is private, because it simplifies the type checking by
+                # always returning a FluxTransformer2DModel from this function.
+                model = q_model._wrapped
             else:
                 # The quantized model does not exist yet, quantize and save it.
                 # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
@@ -227,13 +223,13 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
                 model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
                 assert isinstance(model, FluxTransformer2DModel)
 
-                quantize(model, weights=qfloat8)
-                freeze(model)
+                q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
 
                 model_8bit_path.mkdir(parents=True, exist_ok=True)
-                save_file(model.state_dict(), model_8bit_weights_path)
-                with open(model_8bit_map_path, "w") as f:
-                    json.dump(quantization_map(model), f)
+                q_model.save_pretrained(model_8bit_path)
+
+                # (See earlier comment about accessing the wrapped model.)
+                model = q_model._wrapped
         else:
             model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
 