mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
WIP on moving from diffusers to FLUX
This commit is contained in:
parent
3e8a550fab
commit
d7a39a4d67
@ -3,9 +3,12 @@ from typing import Literal
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
|
||||
from einops import rearrange, repeat
|
||||
from flux.model import Flux
|
||||
from flux.modules.autoencoder import AutoEncoder
|
||||
from flux.sampling import denoise, get_noise, get_schedule, unpack
|
||||
from flux.util import configs as flux_configs
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
from transformers.models.auto import AutoModelForTextEncoding
|
||||
@ -21,11 +24,11 @@ from invokeai.app.invocations.fields import (
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
|
||||
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
||||
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
|
||||
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
TFluxModelKeys = Literal["flux-schnell"]
|
||||
FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
|
||||
@ -70,7 +73,13 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
|
||||
# model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
|
||||
flux_transformer_path = context.models.download_and_cache_model(
|
||||
"https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors"
|
||||
)
|
||||
flux_ae_path = context.models.download_and_cache_model(
|
||||
"https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/ae.safetensors"
|
||||
)
|
||||
|
||||
# Load the conditioning data.
|
||||
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
|
||||
@ -78,123 +87,155 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
flux_conditioning = cond_data.conditionings[0]
|
||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||
|
||||
latents = self._run_diffusion(context, model_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
|
||||
image = self._run_vae_decoding(context, model_path, latents)
|
||||
latents = self._run_diffusion(
|
||||
context, flux_transformer_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds
|
||||
)
|
||||
image = self._run_vae_decoding(context, flux_ae_path, latents)
|
||||
image_dto = context.images.save(image=image)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
def _run_diffusion(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
flux_model_dir: Path,
|
||||
flux_transformer_path: Path,
|
||||
clip_embeddings: torch.Tensor,
|
||||
t5_embeddings: torch.Tensor,
|
||||
):
|
||||
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(flux_model_dir / "scheduler", local_files_only=True)
|
||||
inference_dtype = TorchDevice.choose_torch_dtype()
|
||||
|
||||
# Prepare input noise.
|
||||
# TODO(ryand): Does the seed behave the same on different devices? Should we re-implement this to always use a
|
||||
# CPU RNG?
|
||||
x = get_noise(
|
||||
num_samples=1,
|
||||
height=self.height,
|
||||
width=self.width,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
dtype=inference_dtype,
|
||||
seed=self.seed,
|
||||
)
|
||||
|
||||
img, img_ids = self._prepare_latent_img_patches(x)
|
||||
|
||||
# HACK(ryand): Find a better way to determine if this is a schnell model or not.
|
||||
is_schnell = "shnell" in str(flux_transformer_path)
|
||||
timesteps = get_schedule(
|
||||
num_steps=self.num_steps,
|
||||
image_seq_len=img.shape[1],
|
||||
shift=not is_schnell,
|
||||
)
|
||||
|
||||
bs, t5_seq_len, _ = t5_embeddings.shape
|
||||
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
|
||||
|
||||
# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
|
||||
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
|
||||
# if the cache is not empty.
|
||||
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
|
||||
|
||||
transformer_path = flux_model_dir / "transformer"
|
||||
with context.models.load_local_model(
|
||||
model_path=transformer_path, loader=self._load_flux_transformer
|
||||
model_path=flux_transformer_path, loader=self._load_flux_transformer
|
||||
) as transformer:
|
||||
assert isinstance(transformer, FluxTransformer2DModel)
|
||||
assert isinstance(transformer, Flux)
|
||||
|
||||
flux_pipeline_with_transformer = FluxPipeline(
|
||||
scheduler=scheduler,
|
||||
vae=None,
|
||||
text_encoder=None,
|
||||
tokenizer=None,
|
||||
text_encoder_2=None,
|
||||
tokenizer_2=None,
|
||||
transformer=transformer,
|
||||
x = denoise(
|
||||
model=transformer,
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=t5_embeddings,
|
||||
txt_ids=txt_ids,
|
||||
vec=clip_embeddings,
|
||||
timesteps=timesteps,
|
||||
guidance=self.guidance,
|
||||
)
|
||||
|
||||
dtype = torch.bfloat16
|
||||
t5_embeddings = t5_embeddings.to(dtype=dtype)
|
||||
clip_embeddings = clip_embeddings.to(dtype=dtype)
|
||||
x = unpack(x.float(), self.height, self.width)
|
||||
|
||||
latents = flux_pipeline_with_transformer(
|
||||
height=self.height,
|
||||
width=self.width,
|
||||
num_inference_steps=self.num_steps,
|
||||
guidance_scale=self.guidance,
|
||||
generator=torch.Generator().manual_seed(self.seed),
|
||||
prompt_embeds=t5_embeddings,
|
||||
pooled_prompt_embeds=clip_embeddings,
|
||||
output_type="latent",
|
||||
return_dict=False,
|
||||
)[0]
|
||||
return x
|
||||
|
||||
assert isinstance(latents, torch.Tensor)
|
||||
return latents
|
||||
def _prepare_latent_img_patches(self, latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Convert an input image in latent space to patches for diffusion.
|
||||
|
||||
This implementation was extracted from:
|
||||
https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
|
||||
"""
|
||||
bs, c, h, w = latent_img.shape
|
||||
|
||||
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
|
||||
img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img.shape[0] == 1 and bs > 1:
|
||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
# Generate patch position ids.
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
return img, img_ids
|
||||
|
||||
def _run_vae_decoding(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
flux_model_dir: Path,
|
||||
flux_ae_path: Path,
|
||||
latents: torch.Tensor,
|
||||
) -> Image.Image:
|
||||
vae_path = flux_model_dir / "vae"
|
||||
with context.models.load_local_model(model_path=vae_path, loader=self._load_flux_vae) as vae:
|
||||
assert isinstance(vae, AutoencoderKL)
|
||||
with context.models.load_local_model(model_path=flux_ae_path, loader=self._load_flux_vae) as vae:
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
# TODO(ryand): Test that this works with both float16 and bfloat16.
|
||||
with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
|
||||
img = vae.decode(latents)
|
||||
|
||||
flux_pipeline_with_vae = FluxPipeline(
|
||||
scheduler=None,
|
||||
vae=vae,
|
||||
text_encoder=None,
|
||||
tokenizer=None,
|
||||
text_encoder_2=None,
|
||||
tokenizer_2=None,
|
||||
transformer=None,
|
||||
)
|
||||
img.clamp(-1, 1)
|
||||
img = rearrange(img[0], "c h w -> h w c")
|
||||
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
|
||||
|
||||
latents = flux_pipeline_with_vae._unpack_latents(
|
||||
latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
|
||||
)
|
||||
latents = (
|
||||
latents / flux_pipeline_with_vae.vae.config.scaling_factor
|
||||
) + flux_pipeline_with_vae.vae.config.shift_factor
|
||||
latents = latents.to(dtype=vae.dtype)
|
||||
image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
|
||||
image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]
|
||||
|
||||
assert isinstance(image, Image.Image)
|
||||
return image
|
||||
return img_pil
|
||||
|
||||
def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
|
||||
inference_dtype = TorchDevice.choose_torch_dtype()
|
||||
if self.quantization_type == "raw":
|
||||
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
|
||||
elif self.quantization_type == "NF4":
|
||||
model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
|
||||
with accelerate.init_empty_weights():
|
||||
empty_model = FluxTransformer2DModel.from_config(model_config)
|
||||
assert isinstance(empty_model, FluxTransformer2DModel)
|
||||
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||
params = flux_configs["flux-schnell"].params
|
||||
|
||||
model_nf4_path = path / "bnb_nf4"
|
||||
assert model_nf4_path.exists()
|
||||
# Initialize the model on the "meta" device.
|
||||
with accelerate.init_empty_weights():
|
||||
model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
|
||||
model = Flux(params).to(inference_dtype)
|
||||
|
||||
state_dict = load_file(path)
|
||||
# TODO(ryand): Cast the state_dict to the appropriate dtype?
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
elif self.quantization_type == "NF4":
|
||||
model_path = path.parent / "bnb_nf4.safetensors"
|
||||
|
||||
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||
params = flux_configs["flux-schnell"].params
|
||||
# Initialize the model on the "meta" device.
|
||||
with accelerate.init_empty_weights():
|
||||
model = Flux(params)
|
||||
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
|
||||
|
||||
# TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
|
||||
# this on GPUs without bfloat16 support.
|
||||
sd = load_file(model_nf4_path / "model.safetensors")
|
||||
model.load_state_dict(sd, strict=True, assign=True)
|
||||
elif self.quantization_type == "llm_int8":
|
||||
model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
|
||||
with accelerate.init_empty_weights():
|
||||
empty_model = FluxTransformer2DModel.from_config(model_config)
|
||||
assert isinstance(empty_model, FluxTransformer2DModel)
|
||||
model_int8_path = path / "bnb_llm_int8"
|
||||
assert model_int8_path.exists()
|
||||
with accelerate.init_empty_weights():
|
||||
model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set())
|
||||
state_dict = load_file(model_path)
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
sd = load_file(model_int8_path / "model.safetensors")
|
||||
model.load_state_dict(sd, strict=True, assign=True)
|
||||
elif self.quantization_type == "llm_int8":
|
||||
raise NotImplementedError("LLM int8 quantization is not yet supported.")
|
||||
# model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
|
||||
# with accelerate.init_empty_weights():
|
||||
# empty_model = FluxTransformer2DModel.from_config(model_config)
|
||||
# assert isinstance(empty_model, FluxTransformer2DModel)
|
||||
# model_int8_path = path / "bnb_llm_int8"
|
||||
# assert model_int8_path.exists()
|
||||
# with accelerate.init_empty_weights():
|
||||
# model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set())
|
||||
|
||||
# sd = load_file(model_int8_path / "model.safetensors")
|
||||
# model.load_state_dict(sd, strict=True, assign=True)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization type: {self.quantization_type}")
|
||||
|
||||
@ -202,7 +243,12 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _load_flux_vae(path: Path) -> AutoencoderKL:
|
||||
model = AutoencoderKL.from_pretrained(path, local_files_only=True)
|
||||
assert isinstance(model, AutoencoderKL)
|
||||
return model
|
||||
def _load_flux_vae(path: Path) -> AutoEncoder:
|
||||
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||
ae_params = flux_configs["flux1-schnell"].ae_params
|
||||
with accelerate.init_empty_weights():
|
||||
ae = AutoEncoder(ae_params)
|
||||
|
||||
state_dict = load_file(path)
|
||||
ae.load_state_dict(state_dict, strict=True, assign=True)
|
||||
return ae
|
||||
|
@ -4,7 +4,8 @@ from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||
from flux.model import Flux
|
||||
from flux.util import configs as flux_configs
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
||||
@ -22,22 +23,24 @@ def log_time(name: str):
|
||||
|
||||
|
||||
def main():
|
||||
# Load the FLUX transformer model onto the meta device.
|
||||
model_path = Path(
|
||||
"/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/"
|
||||
"/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
|
||||
)
|
||||
|
||||
# inference_dtype = torch.bfloat16
|
||||
with log_time("Intialize FLUX transformer on meta device"):
|
||||
model_config = FluxTransformer2DModel.load_config(model_path, local_files_only=True)
|
||||
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||
params = flux_configs["flux-schnell"].params
|
||||
|
||||
# Initialize the model on the "meta" device.
|
||||
with accelerate.init_empty_weights():
|
||||
empty_model = FluxTransformer2DModel.from_config(model_config)
|
||||
assert isinstance(empty_model, FluxTransformer2DModel)
|
||||
model = Flux(params)
|
||||
|
||||
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
|
||||
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
|
||||
modules_to_not_convert: set[str] = set()
|
||||
|
||||
model_nf4_path = model_path / "bnb_nf4"
|
||||
model_nf4_path = model_path.parent / "bnb_nf4.safetensors"
|
||||
if model_nf4_path.exists():
|
||||
# The quantized model already exists, load it and return it.
|
||||
print(f"A pre-quantized model already exists at '{model_nf4_path}'. Attempting to load it...")
|
||||
@ -45,12 +48,12 @@ def main():
|
||||
# Replace the linear layers with NF4 quantized linear layers (still on the meta device).
|
||||
with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
|
||||
model = quantize_model_nf4(
|
||||
empty_model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
|
||||
model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
with log_time("Load state dict into model"):
|
||||
sd = load_file(model_nf4_path / "model.safetensors")
|
||||
model.load_state_dict(sd, strict=True, assign=True)
|
||||
state_dict = load_file(model_nf4_path)
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
with log_time("Move model to cuda"):
|
||||
model = model.to("cuda")
|
||||
@ -63,30 +66,24 @@ def main():
|
||||
|
||||
with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
|
||||
model = quantize_model_nf4(
|
||||
empty_model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
|
||||
model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
with log_time("Load state dict into model"):
|
||||
# Load sharded state dict.
|
||||
files = list(model_path.glob("*.safetensors"))
|
||||
state_dict = dict()
|
||||
for file in files:
|
||||
sd = load_file(file)
|
||||
state_dict.update(sd)
|
||||
|
||||
state_dict = load_file(model_path)
|
||||
# TODO(ryand): Cast the state_dict to the appropriate dtype?
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
with log_time("Move model to cuda and quantize"):
|
||||
model = model.to("cuda")
|
||||
|
||||
with log_time("Save quantized model"):
|
||||
model_nf4_path.mkdir(parents=True, exist_ok=True)
|
||||
output_path = model_nf4_path / "model.safetensors"
|
||||
save_file(model.state_dict(), output_path)
|
||||
model_nf4_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
save_file(model.state_dict(), model_nf4_path)
|
||||
|
||||
print(f"Successfully quantized and saved model to '{output_path}'.")
|
||||
print(f"Successfully quantized and saved model to '{model_nf4_path}'.")
|
||||
|
||||
assert isinstance(model, FluxTransformer2DModel)
|
||||
assert isinstance(model, Flux)
|
||||
return model
|
||||
|
||||
|
||||
|
@ -40,6 +40,7 @@ dependencies = [
|
||||
"controlnet-aux==0.0.7",
|
||||
# TODO(ryand): Bump this once the next diffusers release is ready.
|
||||
"diffusers[torch] @ git+https://github.com/huggingface/diffusers.git@4c6152c2fb0ade468aadb417102605a07a8635d3",
|
||||
"flux @ git+https://github.com/black-forest-labs/flux.git@c23ae247225daba30fbd56058d247cc1b1fc20a3",
|
||||
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
|
||||
"mediapipe==0.10.7", # needed for "mediapipeface" controlnet model
|
||||
"numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
|
||||
|
Loading…
Reference in New Issue
Block a user