From 823c663e1b3350924ef48c5582b07151a64936b9 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Fri, 16 Aug 2024 20:22:49 +0000
Subject: [PATCH] WIP on moving from diffusers to FLUX

---
 .../app/invocations/flux_text_to_image.py    | 218 +++++++++++-------
 .../quantization/load_flux_model_bnb_nf4.py  |  43 ++--
 pyproject.toml                               |   1 +
 3 files changed, 153 insertions(+), 109 deletions(-)

diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py
index de34a6eb5e..19829c47a4 100644
--- a/invokeai/app/invocations/flux_text_to_image.py
+++ b/invokeai/app/invocations/flux_text_to_image.py
@@ -3,9 +3,12 @@ from typing import Literal
 
 import accelerate
 import torch
-from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
-from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
+from einops import rearrange, repeat
+from flux.model import Flux
+from flux.modules.autoencoder import AutoEncoder
+from flux.sampling import denoise, get_noise, get_schedule, unpack
+from flux.util import configs as flux_configs
 from PIL import Image
 from safetensors.torch import load_file
 from transformers.models.auto import AutoModelForTextEncoding
@@ -21,11 +24,11 @@ from invokeai.app.invocations.fields import (
 )
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
 from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
 from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
+from invokeai.backend.util.devices import TorchDevice
 
 TFluxModelKeys = Literal["flux-schnell"]
 FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
@@ -70,7 +73,13 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
 
     @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
-        model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
+        # model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
+        flux_transformer_path = context.models.download_and_cache_model(
+            "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors"
+        )
+        flux_ae_path = context.models.download_and_cache_model(
+            "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/ae.safetensors"
+        )
 
         # Load the conditioning data.
         cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
@@ -78,123 +87,155 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         flux_conditioning = cond_data.conditionings[0]
         assert isinstance(flux_conditioning, FLUXConditioningInfo)
 
-        latents = self._run_diffusion(context, model_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
-        image = self._run_vae_decoding(context, model_path, latents)
+        latents = self._run_diffusion(
+            context, flux_transformer_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds
+        )
+        image = self._run_vae_decoding(context, flux_ae_path, latents)
         image_dto = context.images.save(image=image)
         return ImageOutput.build(image_dto)
 
     def _run_diffusion(
         self,
         context: InvocationContext,
-        flux_model_dir: Path,
+        flux_transformer_path: Path,
         clip_embeddings: torch.Tensor,
         t5_embeddings: torch.Tensor,
     ):
-        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(flux_model_dir / "scheduler", local_files_only=True)
+        inference_dtype = TorchDevice.choose_torch_dtype()
+
+        # Prepare input noise.
+        # TODO(ryand): Does the seed behave the same on different devices? Should we re-implement this to always use a
+        # CPU RNG?
+        x = get_noise(
+            num_samples=1,
+            height=self.height,
+            width=self.width,
+            device=TorchDevice.choose_torch_device(),
+            dtype=inference_dtype,
+            seed=self.seed,
+        )
+
+        img, img_ids = self._prepare_latent_img_patches(x)
+
+        # HACK(ryand): Find a better way to determine if this is a schnell model or not.
+        is_schnell = "schnell" in str(flux_transformer_path)
+        timesteps = get_schedule(
+            num_steps=self.num_steps,
+            image_seq_len=img.shape[1],
+            shift=not is_schnell,
+        )
+
+        bs, t5_seq_len, _ = t5_embeddings.shape
+        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
 
         # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
         # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
         # if the cache is not empty.
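+        # 24 * 2**30 bytes = 24 GiB, roughly the bfloat16 footprint of the 12B-parameter FLUX transformer.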
         context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
 
-        transformer_path = flux_model_dir / "transformer"
         with context.models.load_local_model(
-            model_path=transformer_path, loader=self._load_flux_transformer
+            model_path=flux_transformer_path, loader=self._load_flux_transformer
         ) as transformer:
-            assert isinstance(transformer, FluxTransformer2DModel)
+            assert isinstance(transformer, Flux)
 
-            flux_pipeline_with_transformer = FluxPipeline(
-                scheduler=scheduler,
-                vae=None,
-                text_encoder=None,
-                tokenizer=None,
-                text_encoder_2=None,
-                tokenizer_2=None,
-                transformer=transformer,
+            x = denoise(
+                model=transformer,
+                img=img,
+                img_ids=img_ids,
+                txt=t5_embeddings,
+                txt_ids=txt_ids,
+                vec=clip_embeddings,
+                timesteps=timesteps,
+                guidance=self.guidance,
             )
 
-            dtype = torch.bfloat16
-            t5_embeddings = t5_embeddings.to(dtype=dtype)
-            clip_embeddings = clip_embeddings.to(dtype=dtype)
+        x = unpack(x.float(), self.height, self.width)
 
-            latents = flux_pipeline_with_transformer(
-                height=self.height,
-                width=self.width,
-                num_inference_steps=self.num_steps,
-                guidance_scale=self.guidance,
-                generator=torch.Generator().manual_seed(self.seed),
-                prompt_embeds=t5_embeddings,
-                pooled_prompt_embeds=clip_embeddings,
-                output_type="latent",
-                return_dict=False,
-            )[0]
+        return x
 
-        assert isinstance(latents, torch.Tensor)
-        return latents
+    def _prepare_latent_img_patches(self, latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Convert an input image in latent space to patches for diffusion.
+
+        This implementation was extracted from:
+        https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
+
+        Returns:
+            tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
+        """
+        bs, c, h, w = latent_img.shape
+
+        # Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
+        img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+        if img.shape[0] == 1 and bs > 1:
+            img = repeat(img, "1 ... -> bs ...", bs=bs)
+
+        # Generate patch position ids.
+        img_ids = torch.zeros(h // 2, w // 2, 3)
+        img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+        img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+        return img, img_ids
 
     def _run_vae_decoding(
         self,
         context: InvocationContext,
-        flux_model_dir: Path,
+        flux_ae_path: Path,
        latents: torch.Tensor,
     ) -> Image.Image:
-        vae_path = flux_model_dir / "vae"
-        with context.models.load_local_model(model_path=vae_path, loader=self._load_flux_vae) as vae:
-            assert isinstance(vae, AutoencoderKL)
+        with context.models.load_local_model(model_path=flux_ae_path, loader=self._load_flux_vae) as vae:
+            assert isinstance(vae, AutoEncoder)
+            # TODO(ryand): Test that this works with both float16 and bfloat16.
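+            # Decode under autocast so the decode ops run in the dtype returned by TorchDevice.choose_torch_dtype().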
+            with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
+                img = vae.decode(latents)
 
-            flux_pipeline_with_vae = FluxPipeline(
-                scheduler=None,
-                vae=vae,
-                text_encoder=None,
-                tokenizer=None,
-                text_encoder_2=None,
-                tokenizer_2=None,
-                transformer=None,
-            )
+            img = img.clamp(-1, 1)
+            img = rearrange(img[0], "c h w -> h w c")
+            img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
 
-            latents = flux_pipeline_with_vae._unpack_latents(
-                latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
-            )
-            latents = (
-                latents / flux_pipeline_with_vae.vae.config.scaling_factor
-            ) + flux_pipeline_with_vae.vae.config.shift_factor
-            latents = latents.to(dtype=vae.dtype)
-            image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
-            image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]
-
-        assert isinstance(image, Image.Image)
-        return image
+        return img_pil
 
     def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
+        inference_dtype = TorchDevice.choose_torch_dtype()
         if self.quantization_type == "raw":
-            model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
-        elif self.quantization_type == "NF4":
-            model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
-            with accelerate.init_empty_weights():
-                empty_model = FluxTransformer2DModel.from_config(model_config)
-            assert isinstance(empty_model, FluxTransformer2DModel)
+            # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
+            params = flux_configs["flux-schnell"].params
 
-            model_nf4_path = path / "bnb_nf4"
-            assert model_nf4_path.exists()
+            # Initialize the model on the "meta" device.
             with accelerate.init_empty_weights():
-                model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
+                model = Flux(params).to(inference_dtype)
+
+            state_dict = load_file(path)
+            # TODO(ryand): Cast the state_dict to the appropriate dtype?
+            model.load_state_dict(state_dict, strict=True, assign=True)
+        elif self.quantization_type == "NF4":
+            model_path = path.parent / "bnb_nf4.safetensors"
+
+            # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
+            params = flux_configs["flux-schnell"].params
+            # Initialize the model on the "meta" device.
+            with accelerate.init_empty_weights():
+                model = Flux(params)
+                model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
 
             # TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
             # this on GPUs without bfloat16 support.
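+            # Attach the pre-quantized weights to the meta-initialized model below; assign=True replaces the
+            # meta tensors with the loaded tensors instead of copying into them.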
-            sd = load_file(model_nf4_path / "model.safetensors")
-            model.load_state_dict(sd, strict=True, assign=True)
-        elif self.quantization_type == "llm_int8":
-            model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
-            with accelerate.init_empty_weights():
-                empty_model = FluxTransformer2DModel.from_config(model_config)
-            assert isinstance(empty_model, FluxTransformer2DModel)
-            model_int8_path = path / "bnb_llm_int8"
-            assert model_int8_path.exists()
-            with accelerate.init_empty_weights():
-                model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set())
+            state_dict = load_file(model_path)
+            model.load_state_dict(state_dict, strict=True, assign=True)
 
-            sd = load_file(model_int8_path / "model.safetensors")
-            model.load_state_dict(sd, strict=True, assign=True)
+        elif self.quantization_type == "llm_int8":
+            raise NotImplementedError("LLM int8 quantization is not yet supported.")
+            # model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
+            # with accelerate.init_empty_weights():
+            #     empty_model = FluxTransformer2DModel.from_config(model_config)
+            # assert isinstance(empty_model, FluxTransformer2DModel)
+            # model_int8_path = path / "bnb_llm_int8"
+            # assert model_int8_path.exists()
+            # with accelerate.init_empty_weights():
+            #     model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set())
+
+            # sd = load_file(model_int8_path / "model.safetensors")
+            # model.load_state_dict(sd, strict=True, assign=True)
         else:
             raise ValueError(f"Unsupported quantization type: {self.quantization_type}")
 
@@ -202,7 +243,12 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         return model
 
     @staticmethod
-    def _load_flux_vae(path: Path) -> AutoencoderKL:
-        model = AutoencoderKL.from_pretrained(path, local_files_only=True)
-        assert isinstance(model, AutoencoderKL)
-        return model
+    def _load_flux_vae(path: Path) -> AutoEncoder:
+        # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
+        ae_params = flux_configs["flux-schnell"].ae_params
+        with accelerate.init_empty_weights():
+            ae = AutoEncoder(ae_params)
+
+        state_dict = load_file(path)
+        ae.load_state_dict(state_dict, strict=True, assign=True)
+        return ae
diff --git a/invokeai/backend/quantization/load_flux_model_bnb_nf4.py b/invokeai/backend/quantization/load_flux_model_bnb_nf4.py
index b55c56a032..80f3e71901 100644
--- a/invokeai/backend/quantization/load_flux_model_bnb_nf4.py
+++ b/invokeai/backend/quantization/load_flux_model_bnb_nf4.py
@@ -4,7 +4,8 @@ from pathlib import Path
 
 import accelerate
 import torch
-from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
+from flux.model import Flux
+from flux.util import configs as flux_configs
 from safetensors.torch import load_file, save_file
 
 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
@@ -22,22 +23,24 @@ def log_time(name: str):
 
 
 def main():
-    # Load the FLUX transformer model onto the meta device.
     model_path = Path(
-        "/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/"
+        "/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
     )
+    # inference_dtype = torch.bfloat16
 
     with log_time("Intialize FLUX transformer on meta device"):
-        model_config = FluxTransformer2DModel.load_config(model_path, local_files_only=True)
+        # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
+        params = flux_configs["flux-schnell"].params
+
+        # Initialize the model on the "meta" device.
         with accelerate.init_empty_weights():
-            empty_model = FluxTransformer2DModel.from_config(model_config)
-        assert isinstance(empty_model, FluxTransformer2DModel)
+            model = Flux(params)
 
     # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
     # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
     modules_to_not_convert: set[str] = set()
 
-    model_nf4_path = model_path / "bnb_nf4"
+    model_nf4_path = model_path.parent / "bnb_nf4.safetensors"
     if model_nf4_path.exists():
         # The quantized model already exists, load it and return it.
         print(f"A pre-quantized model already exists at '{model_nf4_path}'. Attempting to load it...")
@@ -45,12 +48,12 @@ def main():
         # Replace the linear layers with NF4 quantized linear layers (still on the meta device).
         with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
             model = quantize_model_nf4(
-                empty_model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
+                model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
             )
 
         with log_time("Load state dict into model"):
-            sd = load_file(model_nf4_path / "model.safetensors")
-            model.load_state_dict(sd, strict=True, assign=True)
+            state_dict = load_file(model_nf4_path)
+            model.load_state_dict(state_dict, strict=True, assign=True)
 
         with log_time("Move model to cuda"):
             model = model.to("cuda")
@@ -63,30 +66,24 @@ def main():
 
     with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
         model = quantize_model_nf4(
-            empty_model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
+            model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
         )
 
     with log_time("Load state dict into model"):
-        # Load sharded state dict.
-        files = list(model_path.glob("*.safetensors"))
-        state_dict = dict()
-        for file in files:
-            sd = load_file(file)
-            state_dict.update(sd)
-
+        state_dict = load_file(model_path)
+        # TODO(ryand): Cast the state_dict to the appropriate dtype?
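+        # safetensors' load_file returns the tensors on CPU in their serialized (on-disk) dtype.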
         model.load_state_dict(state_dict, strict=True, assign=True)
 
     with log_time("Move model to cuda and quantize"):
         model = model.to("cuda")
 
     with log_time("Save quantized model"):
-        model_nf4_path.mkdir(parents=True, exist_ok=True)
-        output_path = model_nf4_path / "model.safetensors"
-        save_file(model.state_dict(), output_path)
+        model_nf4_path.parent.mkdir(parents=True, exist_ok=True)
+        save_file(model.state_dict(), model_nf4_path)
 
-    print(f"Successfully quantized and saved model to '{output_path}'.")
+    print(f"Successfully quantized and saved model to '{model_nf4_path}'.")
 
-    assert isinstance(model, FluxTransformer2DModel)
+    assert isinstance(model, Flux)
     return model
 
 
diff --git a/pyproject.toml b/pyproject.toml
index 93394f2955..3561c658b1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,6 +40,7 @@ dependencies = [
   "controlnet-aux==0.0.7",
   # TODO(ryand): Bump this once the next diffusers release is ready.
   "diffusers[torch] @ git+https://github.com/huggingface/diffusers.git@4c6152c2fb0ade468aadb417102605a07a8635d3",
+  "flux @ git+https://github.com/black-forest-labs/flux.git@c23ae247225daba30fbd56058d247cc1b1fc20a3",
   "invisible-watermark==0.2.0",  # needed to install SDXL base and refiner using their repo_ids
   "mediapipe==0.10.7",  # needed for "mediapipeface" controlnet model
   "numpy==1.26.4",  # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()