From 45263b339f97e11436004aee2ab9afc4ac7bb46a Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 7 Aug 2024 19:50:03 +0000 Subject: [PATCH] Got FLUX schnell working with 8-bit quantization. Still lots of rough edges to clean up. --- .../app/invocations/flux_text_to_image.py | 53 ++++++++++++++++--- pyproject.toml | 5 +- 2 files changed, 49 insertions(+), 9 deletions(-) diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py index 0b992909ab..2f78713b0c 100644 --- a/invokeai/app/invocations/flux_text_to_image.py +++ b/invokeai/app/invocations/flux_text_to_image.py @@ -1,11 +1,14 @@ +import json from pathlib import Path from typing import Literal import torch from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel -from diffusers.pipelines.flux import FluxPipeline +from diffusers.pipelines.flux.pipeline_flux import FluxPipeline +from optimum.quanto import freeze, qfloat8, quantization_map, quantize, requantize from PIL import Image +from safetensors.torch import load_file, save_file from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation @@ -29,6 +32,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): """Text-to-image generation using a FLUX model.""" model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.") + use_8bit: bool = InputField( + default=False, description="Whether to quantize the T5 model and transformer model to 8-bit precision." + ) positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.") width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.") height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.") @@ -110,7 +116,10 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): clip_embeddings: torch.Tensor, t5_embeddings: torch.Tensor, ): - scheduler = FlowMatchEulerDiscreteScheduler() + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(flux_model_dir / "scheduler", local_files_only=True) + + # HACK(ryand): Manually empty the cache. 
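+ # make_room(...) asks the RAM cache to free 24GB so that the full-size
+ # transformer checkpoint loaded below fits in memory.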
+ context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30) transformer_path = flux_model_dir / "transformer" with context.models.load_local_model( @@ -144,7 +153,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): self, context: InvocationContext, flux_model_dir: Path, - latent: torch.Tensor, + latents: torch.Tensor, ) -> Image.Image: vae_path = flux_model_dir / "vae" with context.models.load_local_model(model_path=vae_path, loader=self._load_flux_vae) as vae: @@ -166,8 +175,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): latents = ( latents / flux_pipeline_with_vae.vae.config.scaling_factor ) + flux_pipeline_with_vae.vae.config.shift_factor + latents = latents.to(dtype=vae.dtype) image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0] - image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil") + image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0] assert isinstance(image, Image.Image) return image @@ -184,9 +194,38 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): assert isinstance(model, T5EncoderModel) return model - @staticmethod - def _load_flux_transformer(path: Path) -> FluxTransformer2DModel: - model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True) + def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel: + if self.use_8bit: + model_8bit_path = path / "quantized" + model_8bit_weights_path = model_8bit_path / "weights.safetensors" + model_8bit_map_path = model_8bit_path / "quantization_map.json" + if model_8bit_path.exists(): + # The quantized model exists, load it. + with torch.device("meta"): + model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True) + assert isinstance(model, FluxTransformer2DModel) + + state_dict = load_file(model_8bit_weights_path) + with open(model_8bit_map_path, "r") as f: + quant_map = json.load(f) + requantize(model=model, state_dict=state_dict, quantization_map=quant_map) + else: + # The quantized model does not exist yet, quantize and save it. + model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16) + assert isinstance(model, FluxTransformer2DModel) + + quantize(model, weights=qfloat8) + freeze(model) + + model_8bit_path.mkdir(parents=True, exist_ok=True) + save_file(model.state_dict(), model_8bit_weights_path) + with open(model_8bit_map_path, "w") as f: + json.dump(quantization_map(model), f) + else: + model = FluxTransformer2DModel.from_pretrained( + path, local_files_only=True, torch_dtype=TorchDevice.choose_torch_dtype() + ) + assert isinstance(model, FluxTransformer2DModel) return model diff --git a/pyproject.toml b/pyproject.toml index 1c4e087e54..c6dc025a00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,16 +45,17 @@ dependencies = [ "onnx==1.15.0", "onnxruntime==1.16.3", "opencv-python==4.9.0.80", + "optimum-quanto==0.2.4", "pytorch-lightning==2.1.3", "safetensors==0.4.3", # sentencepiece is required to load T5TokenizerFast (used by FLUX). "sentencepiece==0.2.0", "spandrel==0.3.4", "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26 - "torch==2.2.2", + "torch==2.4.0", "torchmetrics==0.11.4", "torchsde==0.2.6", - "torchvision==0.17.2", + "torchvision==0.19.0", "transformers==4.41.1", # Core application dependencies, pinned for reproducible builds.
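For reviewers, below is a minimal, standalone sketch (not part of the patch) of the quantize-once / requantize-later round trip that the new _load_flux_transformer performs. It uses a toy torch.nn module in place of FluxTransformer2DModel; the quantized/ directory layout and file names mirror the patch, everything else is illustrative.

    import json
    from pathlib import Path

    import torch
    from optimum.quanto import freeze, qfloat8, quantization_map, quantize, requantize
    from safetensors.torch import load_file, save_file


    def build_model() -> torch.nn.Module:
        # Stand-in for FluxTransformer2DModel.from_pretrained(...).
        return torch.nn.Sequential(
            torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 8)
        )


    quantized_dir = Path("quantized")  # Mirrors `path / "quantized"` in the patch.
    weights_path = quantized_dir / "weights.safetensors"
    map_path = quantized_dir / "quantization_map.json"

    if not quantized_dir.exists():
        # First run: quantize the weights to float8, freeze them (i.e. replace
        # the original float weights with the quantized versions), and persist
        # both the weights and the quantization map needed to rebuild the model.
        model = build_model()
        quantize(model, weights=qfloat8)
        freeze(model)
        quantized_dir.mkdir(parents=True, exist_ok=True)
        save_file(model.state_dict(), weights_path)
        with open(map_path, "w") as f:
            json.dump(quantization_map(model), f)
    else:
        # Later runs: build the module skeleton on the meta device (no weight
        # storage is allocated), then requantize() materializes the quantized
        # weights from the saved state dict.
        with torch.device("meta"):
            model = build_model()
        state_dict = load_file(weights_path)
        with open(map_path) as f:
            quant_map = json.load(f)
        requantize(model=model, state_dict=state_dict, quantization_map=quant_map)

In the patch the same sequence is keyed off model_8bit_path.exists(), and the skeleton comes from from_pretrained(...) under the meta device, which avoids allocating a second full-precision copy of the transformer weights on the quantized load path.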