diff --git a/invokeai/backend/load_flux_model.py b/invokeai/backend/load_flux_model.py
new file mode 100644
index 0000000000..9273122396
--- /dev/null
+++ b/invokeai/backend/load_flux_model.py
@@ -0,0 +1,129 @@
+import json
+import os
+import time
+from pathlib import Path
+from typing import Union
+
+import torch
+from diffusers.models.model_loading_utils import load_state_dict
+from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
+from diffusers.utils import (
+    CONFIG_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFETENSORS_WEIGHTS_NAME,
+    _get_checkpoint_shard_files,
+    is_accelerate_available,
+)
+from optimum.quanto import qfloat8
+from optimum.quanto.models import QuantizedDiffusersModel
+from optimum.quanto.models.shared_dict import ShardedStateDict
+
+from invokeai.backend.requantize import requantize
+
+
+class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
+    base_class = FluxTransformer2DModel
+
+    @classmethod
+    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
+        if cls.base_class is None:
+            raise ValueError("The `base_class` attribute needs to be configured.")
+
+        if not is_accelerate_available():
+            raise ValueError("Reloading a quantized diffusers model requires the accelerate library.")
+        from accelerate import init_empty_weights
+
+        if os.path.isdir(model_name_or_path):
+            # Look for a quantization map
+            qmap_path = os.path.join(model_name_or_path, cls._qmap_name())
+            if not os.path.exists(qmap_path):
+                raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model?")
+
+            # Look for original model config file.
+            model_config_path = os.path.join(model_name_or_path, CONFIG_NAME)
+            if not os.path.exists(model_config_path):
+                raise ValueError(f"{CONFIG_NAME} not found in {model_name_or_path}.")
+
+            with open(qmap_path, "r", encoding="utf-8") as f:
+                qmap = json.load(f)
+
+            with open(model_config_path, "r", encoding="utf-8") as f:
+                original_model_cls_name = json.load(f)["_class_name"]
+            configured_cls_name = cls.base_class.__name__
+            if configured_cls_name != original_model_cls_name:
+                raise ValueError(
+                    f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
+                )
+
+            # Create an empty model
+            config = cls.base_class.load_config(model_name_or_path)
+            with init_empty_weights():
+                model = cls.base_class.from_config(config)
+
+            # Look for the index of a sharded checkpoint
+            checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
+            if os.path.exists(checkpoint_file):
+                # Convert the checkpoint path to a list of shards
+                _, sharded_metadata = _get_checkpoint_shard_files(model_name_or_path, checkpoint_file)
+                # Create a mapping for the sharded safetensor files
+                state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"])
+            else:
+                # Look for a single checkpoint file
+                checkpoint_file = os.path.join(model_name_or_path, SAFETENSORS_WEIGHTS_NAME)
+                if not os.path.exists(checkpoint_file):
+                    raise ValueError(f"No safetensor weights found in {model_name_or_path}.")
+                # Get state_dict from model checkpoint
+                state_dict = load_state_dict(checkpoint_file)
+
+            # Requantize and load quantized weights from state_dict
+            requantize(model, state_dict=state_dict, quantization_map=qmap)
+            model.eval()
+            return cls(model)
+        else:
+            raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
+
+
+def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
+    # model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
+    model_8bit_path = path / "quantized"
+    if model_8bit_path.exists():
+        # The quantized model exists, load it.
+        # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
+        # something that we should be able to make much faster.
+        q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)
+
+        # Access the underlying wrapped model.
+        # We access the wrapped model, even though it is private, because it simplifies the type checking by
+        # always returning a FluxTransformer2DModel from this function.
+        model = q_model._wrapped
+    else:
+        # The quantized model does not exist yet, quantize and save it.
+        # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
+        # GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
+        # here.
+        model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
+        assert isinstance(model, FluxTransformer2DModel)
+
+        q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
+
+        model_8bit_path.mkdir(parents=True, exist_ok=True)
+        q_model.save_pretrained(model_8bit_path)
+
+        # (See earlier comment about accessing the wrapped model.)
+        model = q_model._wrapped
+
+    assert isinstance(model, FluxTransformer2DModel)
+    return model
+
+
+def main():
+    start = time.time()
+    model = load_flux_transformer(
+        Path("/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/")
+    )
+    print(f"Time to load: {time.time() - start}s")
+    print("hi")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/invokeai/backend/requantize.py b/invokeai/backend/requantize.py
new file mode 100644
index 0000000000..0e9356b60b
--- /dev/null
+++ b/invokeai/backend/requantize.py
@@ -0,0 +1,54 @@
+from typing import Any, Dict, Optional
+
+import torch
+from optimum.quanto.nn import QModuleMixin
+from optimum.quanto.quantize import _quantize_submodule, freeze
+
+
+def custom_freeze(model: torch.nn.Module):
+    for name, m in model.named_modules():
+        if isinstance(m, QModuleMixin):
+            m.freeze()
+
+
+def requantize(
+    model: torch.nn.Module,
+    state_dict: Dict[str, Any],
+    quantization_map: Dict[str, Dict[str, str]],
+    device: Optional[torch.device] = None,
+):
+    if device is None:
+        device = next(model.parameters()).device
+        if device.type == "meta":
+            device = torch.device("cpu")
+
+    # Quantize the model with parameters from the quantization map
+    for name, m in model.named_modules():
+        qconfig = quantization_map.get(name, None)
+        if qconfig is not None:
+            weights = qconfig["weights"]
+            if weights == "none":
+                weights = None
+            activations = qconfig["activations"]
+            if activations == "none":
+                activations = None
+            _quantize_submodule(model, name, m, weights=weights, activations=activations)
+
+    # Move model parameters and buffers to CPU before materializing quantized weights
+    for name, m in model.named_modules():
+
+        def move_tensor(t, device):
+            if t.device.type == "meta":
+                return torch.empty_like(t, device=device)
+            return t.to(device)
+
+        for name, param in m.named_parameters(recurse=False):
+            setattr(m, name, torch.nn.Parameter(move_tensor(param, "cpu")))
+        for name, param in m.named_buffers(recurse=False):
+            setattr(m, name, move_tensor(param, "cpu"))
+    # Freeze model and move to target device
+    freeze(model)
+    model.to(device)
+
+    # Load the quantized model weights
+    model.load_state_dict(state_dict, strict=False)
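
For reference, a minimal usage sketch of the new loader, assuming an environment where these InvokeAI modules are importable and a locally downloaded FLUX.1-schnell transformer directory; the path below is a placeholder and will differ per install:

    from pathlib import Path

    from invokeai.backend.load_flux_model import load_flux_transformer

    # Placeholder path to the FLUX transformer weights; adjust for your install.
    transformer_dir = Path("/path/to/FLUX.1-schnell/transformer")

    # The first call loads in bfloat16, quantizes the weights to qfloat8, and writes a
    # "quantized/" subdirectory next to the original weights; subsequent calls reload
    # the cached quantized model instead of re-quantizing.
    model = load_flux_transformer(transformer_dir)
    print(type(model).__name__)  # FluxTransformer2DModel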
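
The requantize() helper closely follows optimum-quanto's own serialization flow: it expects a quantization map of {module_name: {"weights": ..., "activations": ...}} entries (where "none" means "not quantized") plus a matching state dict. A small round-trip sketch, assuming optimum-quanto's public quantize/freeze/quantization_map API and using a toy module purely for illustration:

    import torch
    from optimum.quanto import freeze, qfloat8, quantization_map, quantize

    from invokeai.backend.requantize import requantize

    # Toy float model, quantized in place to qfloat8 weights and then frozen.
    src = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
    quantize(src, weights=qfloat8)
    freeze(src)

    # The quantization map records, per module name, which qtypes were applied.
    qmap = quantization_map(src)
    state_dict = src.state_dict()

    # Rebuild an identically shaped float model and restore the quantized weights into it.
    dst = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
    requantize(dst, state_dict=state_dict, quantization_map=qmap)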