Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Update load_flux_model_bnb_llm_int8.py to work with a single-file FLUX transformer checkpoint.
parent 19a68afb3a
commit 4105a78b83
load_flux_model_bnb_llm_int8.py
@@ -1,7 +1,8 @@
 from pathlib import Path
 
 import accelerate
-from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
+from flux.model import Flux
+from flux.util import configs as flux_configs
 from safetensors.torch import load_file, save_file
 
 from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
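
Each phase in the hunks below is wrapped in log_time, which the script imports from invokeai.backend.quantization.load_flux_model_bnb_nf4 (see the hunk header that follows). Its implementation is not part of this diff; a rough stand-in, assuming it is a simple timing context manager, might look like:

    import time
    from contextlib import contextmanager

    # Hypothetical stand-in for log_time; the real helper lives in
    # invokeai.backend.quantization.load_flux_model_bnb_nf4 and may differ.
    @contextmanager
    def log_time(name: str):
        start = time.time()
        try:
            yield
        finally:
            print(f"'{name}' took {time.time() - start:.2f}s.")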
@@ -11,30 +12,32 @@ from invokeai.backend.quantization.load_flux_model_bnb_nf4 import log_time
 def main():
     # Load the FLUX transformer model onto the meta device.
     model_path = Path(
-        "/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/"
+        "/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
     )
 
     with log_time("Initialize FLUX transformer on meta device"):
-        model_config = FluxTransformer2DModel.load_config(model_path, local_files_only=True)
+        # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
+        params = flux_configs["flux-schnell"].params
+
+        # Initialize the model on the "meta" device.
         with accelerate.init_empty_weights():
-            empty_model = FluxTransformer2DModel.from_config(model_config)
-        assert isinstance(empty_model, FluxTransformer2DModel)
+            model = Flux(params)
 
     # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
     # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
     modules_to_not_convert: set[str] = set()
 
-    model_int8_path = model_path / "bnb_llm_int8"
+    model_int8_path = model_path.parent / "bnb_llm_int8.safetensors"
     if model_int8_path.exists():
         # The quantized model already exists, load it and return it.
         print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...")
 
         # Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device).
         with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
-            model = quantize_model_llm_int8(empty_model, modules_to_not_convert=modules_to_not_convert)
+            model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)
 
         with log_time("Load state dict into model"):
-            sd = load_file(model_int8_path / "model.safetensors")
+            sd = load_file(model_int8_path)
             model.load_state_dict(sd, strict=True, assign=True)
 
         with log_time("Move model to cuda"):
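
For context on the hunk above: accelerate.init_empty_weights() constructs the module on PyTorch's "meta" device, so no weight memory is allocated until a real state dict is attached via load_state_dict(..., assign=True). A minimal sketch of the same pattern with a toy module (not part of the commit; requires a PyTorch version that supports assign=True):

    import accelerate
    import torch

    # Construct on the meta device: parameters exist but hold no real storage.
    with accelerate.init_empty_weights():
        layer = torch.nn.Linear(8, 8)

    # assign=True swaps in the loaded tensors instead of copying into the
    # empty meta tensors, mirroring the load_state_dict calls in the script.
    state_dict = {"weight": torch.randn(8, 8), "bias": torch.zeros(8)}
    layer.load_state_dict(state_dict, assign=True)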
@@ -47,29 +50,23 @@ def main():
         print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...")
 
         with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
-            model = quantize_model_llm_int8(empty_model, modules_to_not_convert=modules_to_not_convert)
+            model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)
 
         with log_time("Load state dict into model"):
-            # Load sharded state dict.
-            files = list(model_path.glob("*.safetensors"))
-            state_dict = {}
-            for file in files:
-                sd = load_file(file)
-                state_dict.update(sd)
-
+            state_dict = load_file(model_path)
+            # TODO(ryand): Cast the state_dict to the appropriate dtype?
             model.load_state_dict(state_dict, strict=True, assign=True)
 
         with log_time("Move model to cuda and quantize"):
             model = model.to("cuda")
 
         with log_time("Save quantized model"):
-            model_int8_path.mkdir(parents=True, exist_ok=True)
-            output_path = model_int8_path / "model.safetensors"
-            save_file(model.state_dict(), output_path)
+            model_int8_path.parent.mkdir(parents=True, exist_ok=True)
+            save_file(model.state_dict(), model_int8_path)
 
-            print(f"Successfully quantized and saved model to '{output_path}'.")
+            print(f"Successfully quantized and saved model to '{model_int8_path}'.")
 
-    assert isinstance(model, FluxTransformer2DModel)
+    assert isinstance(model, Flux)
     return model
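
quantize_model_llm_int8 comes from invokeai.backend.quantization.bnb_llm_int8 and is not shown in this diff. Conceptually, LLM.int8() quantization swaps nn.Linear modules for bitsandbytes Linear8bitLt layers while the model is still on the meta device; the actual int8 conversion then happens when the weights are moved to CUDA, which is why the timing block above is labelled "Move model to cuda and quantize". A rough, hypothetical illustration of that swap (the real helper differs in detail):

    import bitsandbytes as bnb
    import torch

    def replace_linear_with_int8(module: torch.nn.Module, modules_to_not_convert: set[str]) -> torch.nn.Module:
        # Hypothetical sketch of an LLM.int8() layer swap, not InvokeAI's implementation.
        for name, child in module.named_children():
            if isinstance(child, torch.nn.Linear) and name not in modules_to_not_convert:
                setattr(
                    module,
                    name,
                    bnb.nn.Linear8bitLt(
                        child.in_features,
                        child.out_features,
                        bias=child.bias is not None,
                        has_fp16_weights=False,  # store int8 weights, not fp16 shadows
                    ),
                )
            else:
                replace_linear_with_int8(child, modules_to_not_convert)
        return module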