WIP on moving from diffusers to FLUX

Ryan Dick
2024-08-16 20:22:49 +00:00
committed by Brandon
parent d3a5ca5247
commit 1fa6bddc89
3 changed files with 153 additions and 109 deletions


@@ -4,7 +4,8 @@ from pathlib import Path
 import accelerate
 import torch
-from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
+from flux.model import Flux
+from flux.util import configs as flux_configs
 from safetensors.torch import load_file, save_file

 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
@@ -22,22 +23,24 @@ def log_time(name: str):
 def main():
     # Load the FLUX transformer model onto the meta device.
     model_path = Path(
-        "/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/"
+        "/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
     )

     # inference_dtype = torch.bfloat16
     with log_time("Intialize FLUX transformer on meta device"):
-        model_config = FluxTransformer2DModel.load_config(model_path, local_files_only=True)
+        # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
+        params = flux_configs["flux-schnell"].params
+
+        # Initialize the model on the "meta" device.
         with accelerate.init_empty_weights():
-            empty_model = FluxTransformer2DModel.from_config(model_config)
-        assert isinstance(empty_model, FluxTransformer2DModel)
+            model = Flux(params)

     # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
     # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
     modules_to_not_convert: set[str] = set()

-    model_nf4_path = model_path / "bnb_nf4"
+    model_nf4_path = model_path.parent / "bnb_nf4.safetensors"
     if model_nf4_path.exists():
         # The quantized model already exists, load it and return it.
         print(f"A pre-quantized model already exists at '{model_nf4_path}'. Attempting to load it...")
@@ -45,12 +48,12 @@ def main():
         # Replace the linear layers with NF4 quantized linear layers (still on the meta device).
         with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
             model = quantize_model_nf4(
-                empty_model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
+                model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
             )

         with log_time("Load state dict into model"):
-            sd = load_file(model_nf4_path / "model.safetensors")
-            model.load_state_dict(sd, strict=True, assign=True)
+            state_dict = load_file(model_nf4_path)
+            model.load_state_dict(state_dict, strict=True, assign=True)

         with log_time("Move model to cuda"):
             model = model.to("cuda")
@@ -63,30 +66,24 @@ def main():
         with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
             model = quantize_model_nf4(
-                empty_model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
+                model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
             )

         with log_time("Load state dict into model"):
-            # Load sharded state dict.
-            files = list(model_path.glob("*.safetensors"))
-            state_dict = dict()
-            for file in files:
-                sd = load_file(file)
-                state_dict.update(sd)
+            state_dict = load_file(model_path)
             # TODO(ryand): Cast the state_dict to the appropriate dtype?
             model.load_state_dict(state_dict, strict=True, assign=True)

         with log_time("Move model to cuda and quantize"):
             model = model.to("cuda")

         with log_time("Save quantized model"):
-            model_nf4_path.mkdir(parents=True, exist_ok=True)
-            output_path = model_nf4_path / "model.safetensors"
-            save_file(model.state_dict(), output_path)
+            model_nf4_path.parent.mkdir(parents=True, exist_ok=True)
+            save_file(model.state_dict(), model_nf4_path)

-        print(f"Successfully quantized and saved model to '{output_path}'.")
+        print(f"Successfully quantized and saved model to '{model_nf4_path}'.")

-    assert isinstance(model, FluxTransformer2DModel)
+    assert isinstance(model, Flux)
     return model