Got FLUX schnell working with 8-bit quantization. Still lots of rough edges to clean up.

This commit is contained in:
Ryan Dick 2024-08-07 19:50:03 +00:00 committed by Brandon
parent 3319491861
commit 45263b339f
2 changed files with 49 additions and 9 deletions

View File

@ -1,11 +1,14 @@
import json
from pathlib import Path from pathlib import Path
from typing import Literal from typing import Literal
import torch import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux import FluxPipeline from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from optimum.quanto import freeze, qfloat8, quantization_map, quantize, requantize
from PIL import Image from PIL import Image
from safetensors.torch import load_file, save_file
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
@ -29,6 +32,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Text-to-image generation using a FLUX model.""" """Text-to-image generation using a FLUX model."""
model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.") model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
use_8bit: bool = InputField(
default=False, description="Whether to quantize the T5 model and transformer model to 8-bit precision."
)
positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.") positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.") width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.") height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
@ -110,7 +116,10 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
clip_embeddings: torch.Tensor, clip_embeddings: torch.Tensor,
t5_embeddings: torch.Tensor, t5_embeddings: torch.Tensor,
): ):
scheduler = FlowMatchEulerDiscreteScheduler() scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(flux_model_dir / "scheduler", local_files_only=True)
# HACK(ryand): Manually empty the cache.
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
transformer_path = flux_model_dir / "transformer" transformer_path = flux_model_dir / "transformer"
with context.models.load_local_model( with context.models.load_local_model(
@ -144,7 +153,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
self, self,
context: InvocationContext, context: InvocationContext,
flux_model_dir: Path, flux_model_dir: Path,
latent: torch.Tensor, latents: torch.Tensor,
) -> Image.Image: ) -> Image.Image:
vae_path = flux_model_dir / "vae" vae_path = flux_model_dir / "vae"
with context.models.load_local_model(model_path=vae_path, loader=self._load_flux_vae) as vae: with context.models.load_local_model(model_path=vae_path, loader=self._load_flux_vae) as vae:
@ -166,8 +175,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
latents = ( latents = (
latents / flux_pipeline_with_vae.vae.config.scaling_factor latents / flux_pipeline_with_vae.vae.config.scaling_factor
) + flux_pipeline_with_vae.vae.config.shift_factor ) + flux_pipeline_with_vae.vae.config.shift_factor
latents = latents.to(dtype=vae.dtype)
image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0] image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil") image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]
assert isinstance(image, Image.Image) assert isinstance(image, Image.Image)
return image return image
@ -184,9 +194,38 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
assert isinstance(model, T5EncoderModel) assert isinstance(model, T5EncoderModel)
return model return model
@staticmethod def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
def _load_flux_transformer(path: Path) -> FluxTransformer2DModel: if self.use_8bit:
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True) model_8bit_path = path / "quantized"
model_8bit_weights_path = model_8bit_path / "weights.safetensors"
model_8bit_map_path = model_8bit_path / "quantization_map.json"
if model_8bit_path.exists():
# The quantized model exists, load it.
with torch.device("meta"):
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True)
assert isinstance(model, FluxTransformer2DModel)
state_dict = load_file(model_8bit_weights_path)
with open(model_8bit_map_path, "r") as f:
quant_map = json.load(f)
requantize(model=model, state_dict=state_dict, quantization_map=quant_map)
else:
# The quantized model does not exist yet, quantize and save it.
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
assert isinstance(model, FluxTransformer2DModel)
quantize(model, weights=qfloat8)
freeze(model)
model_8bit_path.mkdir(parents=True, exist_ok=True)
save_file(model.state_dict(), model_8bit_weights_path)
with open(model_8bit_map_path, "w") as f:
json.dump(quantization_map(model), f)
else:
model = FluxTransformer2DModel.from_pretrained(
path, local_files_only=True, torch_dtype=TorchDevice.choose_torch_dtype()
)
assert isinstance(model, FluxTransformer2DModel) assert isinstance(model, FluxTransformer2DModel)
return model return model

View File

@ -45,16 +45,17 @@ dependencies = [
"onnx==1.15.0", "onnx==1.15.0",
"onnxruntime==1.16.3", "onnxruntime==1.16.3",
"opencv-python==4.9.0.80", "opencv-python==4.9.0.80",
"optimum-quanto==0.2.4",
"pytorch-lightning==2.1.3", "pytorch-lightning==2.1.3",
"safetensors==0.4.3", "safetensors==0.4.3",
# sentencepiece is required to load T5TokenizerFast (used by FLUX). # sentencepiece is required to load T5TokenizerFast (used by FLUX).
"sentencepiece==0.2.0", "sentencepiece==0.2.0",
"spandrel==0.3.4", "spandrel==0.3.4",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26 "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"torch==2.2.2", "torch==2.4.0",
"torchmetrics==0.11.4", "torchmetrics==0.11.4",
"torchsde==0.2.6", "torchsde==0.2.6",
"torchvision==0.17.2", "torchvision==0.19.0",
"transformers==4.41.1", "transformers==4.41.1",
# Core application dependencies, pinned for reproducible builds. # Core application dependencies, pinned for reproducible builds.