mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Got FLUX schnell working with 8-bit quantization. Still lots of rough edges to clean up.
This commit is contained in:
parent
3319491861
commit
45263b339f
@@ -1,11 +1,14 @@
|
|||||||
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
|
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
|
||||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||||
from diffusers.pipelines.flux import FluxPipeline
|
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
|
||||||
|
from optimum.quanto import freeze, qfloat8, quantization_map, quantize, requantize
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from safetensors.torch import load_file, save_file
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||||
@@ -29,6 +32,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
"""Text-to-image generation using a FLUX model."""
|
"""Text-to-image generation using a FLUX model."""
|
||||||
|
|
||||||
model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
|
model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
|
||||||
|
use_8bit: bool = InputField(
|
||||||
|
default=False, description="Whether to quantize the T5 model and transformer model to 8-bit precision."
|
||||||
|
)
|
||||||
positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
|
positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
|
||||||
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
|
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
|
||||||
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
|
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
|
||||||
@@ -110,7 +116,10 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
clip_embeddings: torch.Tensor,
|
clip_embeddings: torch.Tensor,
|
||||||
t5_embeddings: torch.Tensor,
|
t5_embeddings: torch.Tensor,
|
||||||
):
|
):
|
||||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(flux_model_dir / "scheduler", local_files_only=True)
|
||||||
|
|
||||||
|
# HACK(ryand): Manually empty the cache.
|
||||||
|
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
|
||||||
|
|
||||||
transformer_path = flux_model_dir / "transformer"
|
transformer_path = flux_model_dir / "transformer"
|
||||||
with context.models.load_local_model(
|
with context.models.load_local_model(
|
||||||
@@ -144,7 +153,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
self,
|
self,
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
flux_model_dir: Path,
|
flux_model_dir: Path,
|
||||||
latent: torch.Tensor,
|
latents: torch.Tensor,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
vae_path = flux_model_dir / "vae"
|
vae_path = flux_model_dir / "vae"
|
||||||
with context.models.load_local_model(model_path=vae_path, loader=self._load_flux_vae) as vae:
|
with context.models.load_local_model(model_path=vae_path, loader=self._load_flux_vae) as vae:
|
||||||
@@ -166,8 +175,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
latents = (
|
latents = (
|
||||||
latents / flux_pipeline_with_vae.vae.config.scaling_factor
|
latents / flux_pipeline_with_vae.vae.config.scaling_factor
|
||||||
) + flux_pipeline_with_vae.vae.config.shift_factor
|
) + flux_pipeline_with_vae.vae.config.shift_factor
|
||||||
|
latents = latents.to(dtype=vae.dtype)
|
||||||
image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
|
image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
|
||||||
image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")
|
image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]
|
||||||
|
|
||||||
assert isinstance(image, Image.Image)
|
assert isinstance(image, Image.Image)
|
||||||
return image
|
return image
|
||||||
@@ -184,9 +194,38 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
assert isinstance(model, T5EncoderModel)
|
assert isinstance(model, T5EncoderModel)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@staticmethod
|
def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
|
||||||
def _load_flux_transformer(path: Path) -> FluxTransformer2DModel:
|
if self.use_8bit:
|
||||||
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True)
|
model_8bit_path = path / "quantized"
|
||||||
|
model_8bit_weights_path = model_8bit_path / "weights.safetensors"
|
||||||
|
model_8bit_map_path = model_8bit_path / "quantization_map.json"
|
||||||
|
if model_8bit_path.exists():
|
||||||
|
# The quantized model exists, load it.
|
||||||
|
with torch.device("meta"):
|
||||||
|
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True)
|
||||||
|
assert isinstance(model, FluxTransformer2DModel)
|
||||||
|
|
||||||
|
state_dict = load_file(model_8bit_weights_path)
|
||||||
|
with open(model_8bit_map_path, "r") as f:
|
||||||
|
quant_map = json.load(f)
|
||||||
|
requantize(model=model, state_dict=state_dict, quantization_map=quant_map)
|
||||||
|
else:
|
||||||
|
# The quantized model does not exist yet, quantize and save it.
|
||||||
|
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
|
||||||
|
assert isinstance(model, FluxTransformer2DModel)
|
||||||
|
|
||||||
|
quantize(model, weights=qfloat8)
|
||||||
|
freeze(model)
|
||||||
|
|
||||||
|
model_8bit_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
save_file(model.state_dict(), model_8bit_weights_path)
|
||||||
|
with open(model_8bit_map_path, "w") as f:
|
||||||
|
json.dump(quantization_map(model), f)
|
||||||
|
else:
|
||||||
|
model = FluxTransformer2DModel.from_pretrained(
|
||||||
|
path, local_files_only=True, torch_dtype=TorchDevice.choose_torch_dtype()
|
||||||
|
)
|
||||||
|
|
||||||
assert isinstance(model, FluxTransformer2DModel)
|
assert isinstance(model, FluxTransformer2DModel)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
@@ -45,16 +45,17 @@ dependencies = [
|
|||||||
"onnx==1.15.0",
|
"onnx==1.15.0",
|
||||||
"onnxruntime==1.16.3",
|
"onnxruntime==1.16.3",
|
||||||
"opencv-python==4.9.0.80",
|
"opencv-python==4.9.0.80",
|
||||||
|
"optimum-quanto==0.2.4",
|
||||||
"pytorch-lightning==2.1.3",
|
"pytorch-lightning==2.1.3",
|
||||||
"safetensors==0.4.3",
|
"safetensors==0.4.3",
|
||||||
# sentencepiece is required to load T5TokenizerFast (used by FLUX).
|
# sentencepiece is required to load T5TokenizerFast (used by FLUX).
|
||||||
"sentencepiece==0.2.0",
|
"sentencepiece==0.2.0",
|
||||||
"spandrel==0.3.4",
|
"spandrel==0.3.4",
|
||||||
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
|
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
|
||||||
"torch==2.2.2",
|
"torch==2.4.0",
|
||||||
"torchmetrics==0.11.4",
|
"torchmetrics==0.11.4",
|
||||||
"torchsde==0.2.6",
|
"torchsde==0.2.6",
|
||||||
"torchvision==0.17.2",
|
"torchvision==0.19.0",
|
||||||
"transformers==4.41.1",
|
"transformers==4.41.1",
|
||||||
|
|
||||||
# Core application dependencies, pinned for reproducible builds.
|
# Core application dependencies, pinned for reproducible builds.
|
||||||
|
Loading…
Reference in New Issue
Block a user