WIP on moving from diffusers to FLUX

This commit is contained in:
Ryan Dick 2024-08-16 20:22:49 +00:00
parent d40c9ff60a
commit 823c663e1b
3 changed files with 153 additions and 109 deletions

View File

@ -3,9 +3,12 @@ from typing import Literal
import accelerate import accelerate
import torch import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline from einops import rearrange, repeat
from flux.model import Flux
from flux.modules.autoencoder import AutoEncoder
from flux.sampling import denoise, get_noise, get_schedule, unpack
from flux.util import configs as flux_configs
from PIL import Image from PIL import Image
from safetensors.torch import load_file from safetensors.torch import load_file
from transformers.models.auto import AutoModelForTextEncoding from transformers.models.auto import AutoModelForTextEncoding
@ -21,11 +24,11 @@ from invokeai.app.invocations.fields import (
) )
from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
TFluxModelKeys = Literal["flux-schnell"] TFluxModelKeys = Literal["flux-schnell"]
FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"} FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
@ -70,7 +73,13 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model]) # model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
flux_transformer_path = context.models.download_and_cache_model(
"https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors"
)
flux_ae_path = context.models.download_and_cache_model(
"https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/ae.safetensors"
)
# Load the conditioning data. # Load the conditioning data.
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name) cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
@ -78,123 +87,155 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
flux_conditioning = cond_data.conditionings[0] flux_conditioning = cond_data.conditionings[0]
assert isinstance(flux_conditioning, FLUXConditioningInfo) assert isinstance(flux_conditioning, FLUXConditioningInfo)
latents = self._run_diffusion(context, model_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds) latents = self._run_diffusion(
image = self._run_vae_decoding(context, model_path, latents) context, flux_transformer_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds
)
image = self._run_vae_decoding(context, flux_ae_path, latents)
image_dto = context.images.save(image=image) image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto) return ImageOutput.build(image_dto)
def _run_diffusion( def _run_diffusion(
self, self,
context: InvocationContext, context: InvocationContext,
flux_model_dir: Path, flux_transformer_path: Path,
clip_embeddings: torch.Tensor, clip_embeddings: torch.Tensor,
t5_embeddings: torch.Tensor, t5_embeddings: torch.Tensor,
): ):
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(flux_model_dir / "scheduler", local_files_only=True) inference_dtype = TorchDevice.choose_torch_dtype()
# Prepare input noise.
# TODO(ryand): Does the seed behave the same on different devices? Should we re-implement this to always use a
# CPU RNG?
x = get_noise(
num_samples=1,
height=self.height,
width=self.width,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
seed=self.seed,
)
img, img_ids = self._prepare_latent_img_patches(x)
# HACK(ryand): Find a better way to determine if this is a schnell model or not.
is_schnell = "shnell" in str(flux_transformer_path)
timesteps = get_schedule(
num_steps=self.num_steps,
image_seq_len=img.shape[1],
shift=not is_schnell,
)
bs, t5_seq_len, _ = t5_embeddings.shape
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
# if the cache is not empty. # if the cache is not empty.
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30) context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
transformer_path = flux_model_dir / "transformer"
with context.models.load_local_model( with context.models.load_local_model(
model_path=transformer_path, loader=self._load_flux_transformer model_path=flux_transformer_path, loader=self._load_flux_transformer
) as transformer: ) as transformer:
assert isinstance(transformer, FluxTransformer2DModel) assert isinstance(transformer, Flux)
flux_pipeline_with_transformer = FluxPipeline( x = denoise(
scheduler=scheduler, model=transformer,
vae=None, img=img,
text_encoder=None, img_ids=img_ids,
tokenizer=None, txt=t5_embeddings,
text_encoder_2=None, txt_ids=txt_ids,
tokenizer_2=None, vec=clip_embeddings,
transformer=transformer, timesteps=timesteps,
guidance=self.guidance,
) )
dtype = torch.bfloat16 x = unpack(x.float(), self.height, self.width)
t5_embeddings = t5_embeddings.to(dtype=dtype)
clip_embeddings = clip_embeddings.to(dtype=dtype)
latents = flux_pipeline_with_transformer( return x
height=self.height,
width=self.width,
num_inference_steps=self.num_steps,
guidance_scale=self.guidance,
generator=torch.Generator().manual_seed(self.seed),
prompt_embeds=t5_embeddings,
pooled_prompt_embeds=clip_embeddings,
output_type="latent",
return_dict=False,
)[0]
assert isinstance(latents, torch.Tensor) def _prepare_latent_img_patches(self, latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
return latents """Convert an input image in latent space to patches for diffusion.
This implementation was extracted from:
https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
Returns:
tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
"""
bs, c, h, w = latent_img.shape
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
# Generate patch position ids.
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
return img, img_ids
def _run_vae_decoding( def _run_vae_decoding(
self, self,
context: InvocationContext, context: InvocationContext,
flux_model_dir: Path, flux_ae_path: Path,
latents: torch.Tensor, latents: torch.Tensor,
) -> Image.Image: ) -> Image.Image:
vae_path = flux_model_dir / "vae" with context.models.load_local_model(model_path=flux_ae_path, loader=self._load_flux_vae) as vae:
with context.models.load_local_model(model_path=vae_path, loader=self._load_flux_vae) as vae: assert isinstance(vae, AutoEncoder)
assert isinstance(vae, AutoencoderKL) # TODO(ryand): Test that this works with both float16 and bfloat16.
with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
img = vae.decode(latents)
flux_pipeline_with_vae = FluxPipeline( img.clamp(-1, 1)
scheduler=None, img = rearrange(img[0], "c h w -> h w c")
vae=vae, img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
text_encoder=None,
tokenizer=None,
text_encoder_2=None,
tokenizer_2=None,
transformer=None,
)
latents = flux_pipeline_with_vae._unpack_latents( return img_pil
latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
)
latents = (
latents / flux_pipeline_with_vae.vae.config.scaling_factor
) + flux_pipeline_with_vae.vae.config.shift_factor
latents = latents.to(dtype=vae.dtype)
image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]
assert isinstance(image, Image.Image)
return image
def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel: def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
inference_dtype = TorchDevice.choose_torch_dtype()
if self.quantization_type == "raw": if self.quantization_type == "raw":
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16) # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
elif self.quantization_type == "NF4": params = flux_configs["flux-schnell"].params
model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
with accelerate.init_empty_weights():
empty_model = FluxTransformer2DModel.from_config(model_config)
assert isinstance(empty_model, FluxTransformer2DModel)
model_nf4_path = path / "bnb_nf4" # Initialize the model on the "meta" device.
assert model_nf4_path.exists()
with accelerate.init_empty_weights(): with accelerate.init_empty_weights():
model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16) model = Flux(params).to(inference_dtype)
state_dict = load_file(path)
# TODO(ryand): Cast the state_dict to the appropriate dtype?
model.load_state_dict(state_dict, strict=True, assign=True)
elif self.quantization_type == "NF4":
model_path = path.parent / "bnb_nf4.safetensors"
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
params = flux_configs["flux-schnell"].params
# Initialize the model on the "meta" device.
with accelerate.init_empty_weights():
model = Flux(params)
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
# TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle # TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
# this on GPUs without bfloat16 support. # this on GPUs without bfloat16 support.
sd = load_file(model_nf4_path / "model.safetensors") state_dict = load_file(model_path)
model.load_state_dict(sd, strict=True, assign=True) model.load_state_dict(state_dict, strict=True, assign=True)
elif self.quantization_type == "llm_int8":
model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
with accelerate.init_empty_weights():
empty_model = FluxTransformer2DModel.from_config(model_config)
assert isinstance(empty_model, FluxTransformer2DModel)
model_int8_path = path / "bnb_llm_int8"
assert model_int8_path.exists()
with accelerate.init_empty_weights():
model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set())
sd = load_file(model_int8_path / "model.safetensors") elif self.quantization_type == "llm_int8":
model.load_state_dict(sd, strict=True, assign=True) raise NotImplementedError("LLM int8 quantization is not yet supported.")
# model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
# with accelerate.init_empty_weights():
# empty_model = FluxTransformer2DModel.from_config(model_config)
# assert isinstance(empty_model, FluxTransformer2DModel)
# model_int8_path = path / "bnb_llm_int8"
# assert model_int8_path.exists()
# with accelerate.init_empty_weights():
# model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set())
# sd = load_file(model_int8_path / "model.safetensors")
# model.load_state_dict(sd, strict=True, assign=True)
else: else:
raise ValueError(f"Unsupported quantization type: {self.quantization_type}") raise ValueError(f"Unsupported quantization type: {self.quantization_type}")
@ -202,7 +243,12 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
return model return model
@staticmethod @staticmethod
def _load_flux_vae(path: Path) -> AutoencoderKL: def _load_flux_vae(path: Path) -> AutoEncoder:
model = AutoencoderKL.from_pretrained(path, local_files_only=True) # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
assert isinstance(model, AutoencoderKL) ae_params = flux_configs["flux1-schnell"].ae_params
return model with accelerate.init_empty_weights():
ae = AutoEncoder(ae_params)
state_dict = load_file(path)
ae.load_state_dict(state_dict, strict=True, assign=True)
return ae

View File

@ -4,7 +4,8 @@ from pathlib import Path
import accelerate import accelerate
import torch import torch
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel from flux.model import Flux
from flux.util import configs as flux_configs
from safetensors.torch import load_file, save_file from safetensors.torch import load_file, save_file
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
@ -22,22 +23,24 @@ def log_time(name: str):
def main(): def main():
# Load the FLUX transformer model onto the meta device.
model_path = Path( model_path = Path(
"/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/" "/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
) )
# inference_dtype = torch.bfloat16
with log_time("Intialize FLUX transformer on meta device"): with log_time("Intialize FLUX transformer on meta device"):
model_config = FluxTransformer2DModel.load_config(model_path, local_files_only=True) # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
params = flux_configs["flux-schnell"].params
# Initialize the model on the "meta" device.
with accelerate.init_empty_weights(): with accelerate.init_empty_weights():
empty_model = FluxTransformer2DModel.from_config(model_config) model = Flux(params)
assert isinstance(empty_model, FluxTransformer2DModel)
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize. # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
modules_to_not_convert: set[str] = set() modules_to_not_convert: set[str] = set()
model_nf4_path = model_path / "bnb_nf4" model_nf4_path = model_path.parent / "bnb_nf4.safetensors"
if model_nf4_path.exists(): if model_nf4_path.exists():
# The quantized model already exists, load it and return it. # The quantized model already exists, load it and return it.
print(f"A pre-quantized model already exists at '{model_nf4_path}'. Attempting to load it...") print(f"A pre-quantized model already exists at '{model_nf4_path}'. Attempting to load it...")
@ -45,12 +48,12 @@ def main():
# Replace the linear layers with NF4 quantized linear layers (still on the meta device). # Replace the linear layers with NF4 quantized linear layers (still on the meta device).
with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights(): with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
model = quantize_model_nf4( model = quantize_model_nf4(
empty_model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16 model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
) )
with log_time("Load state dict into model"): with log_time("Load state dict into model"):
sd = load_file(model_nf4_path / "model.safetensors") state_dict = load_file(model_nf4_path)
model.load_state_dict(sd, strict=True, assign=True) model.load_state_dict(state_dict, strict=True, assign=True)
with log_time("Move model to cuda"): with log_time("Move model to cuda"):
model = model.to("cuda") model = model.to("cuda")
@ -63,30 +66,24 @@ def main():
with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights(): with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
model = quantize_model_nf4( model = quantize_model_nf4(
empty_model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16 model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
) )
with log_time("Load state dict into model"): with log_time("Load state dict into model"):
# Load sharded state dict. state_dict = load_file(model_path)
files = list(model_path.glob("*.safetensors")) # TODO(ryand): Cast the state_dict to the appropriate dtype?
state_dict = dict()
for file in files:
sd = load_file(file)
state_dict.update(sd)
model.load_state_dict(state_dict, strict=True, assign=True) model.load_state_dict(state_dict, strict=True, assign=True)
with log_time("Move model to cuda and quantize"): with log_time("Move model to cuda and quantize"):
model = model.to("cuda") model = model.to("cuda")
with log_time("Save quantized model"): with log_time("Save quantized model"):
model_nf4_path.mkdir(parents=True, exist_ok=True) model_nf4_path.parent.mkdir(parents=True, exist_ok=True)
output_path = model_nf4_path / "model.safetensors" save_file(model.state_dict(), model_nf4_path)
save_file(model.state_dict(), output_path)
print(f"Successfully quantized and saved model to '{output_path}'.") print(f"Successfully quantized and saved model to '{model_nf4_path}'.")
assert isinstance(model, FluxTransformer2DModel) assert isinstance(model, Flux)
return model return model

View File

@ -40,6 +40,7 @@ dependencies = [
"controlnet-aux==0.0.7", "controlnet-aux==0.0.7",
# TODO(ryand): Bump this once the next diffusers release is ready. # TODO(ryand): Bump this once the next diffusers release is ready.
"diffusers[torch] @ git+https://github.com/huggingface/diffusers.git@4c6152c2fb0ade468aadb417102605a07a8635d3", "diffusers[torch] @ git+https://github.com/huggingface/diffusers.git@4c6152c2fb0ade468aadb417102605a07a8635d3",
"flux @ git+https://github.com/black-forest-labs/flux.git@c23ae247225daba30fbd56058d247cc1b1fc20a3",
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids "invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
"mediapipe==0.10.7", # needed for "mediapipeface" controlnet model "mediapipe==0.10.7", # needed for "mediapipeface" controlnet model
"numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal() "numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()