Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
WIP - experimentation
This commit is contained in:
parent 3cf0365a35
commit 152da482cd
129  invokeai/backend/load_flux_model.py  Normal file
@@ -0,0 +1,129 @@
import json
import os
import time
from pathlib import Path
from typing import Union

import torch
from diffusers.models.model_loading_utils import load_state_dict
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.utils import (
    CONFIG_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFETENSORS_WEIGHTS_NAME,
    _get_checkpoint_shard_files,
    is_accelerate_available,
)
from optimum.quanto import qfloat8
from optimum.quanto.models import QuantizedDiffusersModel
from optimum.quanto.models.shared_dict import ShardedStateDict

from invokeai.backend.requantize import requantize


class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    base_class = FluxTransformer2DModel

    @classmethod
    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
        if cls.base_class is None:
            raise ValueError("The `base_class` attribute needs to be configured.")

        if not is_accelerate_available():
            raise ValueError("Reloading a quantized diffusers model requires the accelerate library.")
        from accelerate import init_empty_weights

        if os.path.isdir(model_name_or_path):
            # Look for a quantization map
            qmap_path = os.path.join(model_name_or_path, cls._qmap_name())
            if not os.path.exists(qmap_path):
                raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model?")

            # Look for original model config file.
            model_config_path = os.path.join(model_name_or_path, CONFIG_NAME)
            if not os.path.exists(model_config_path):
                raise ValueError(f"{CONFIG_NAME} not found in {model_name_or_path}.")

            with open(qmap_path, "r", encoding="utf-8") as f:
                qmap = json.load(f)

            with open(model_config_path, "r", encoding="utf-8") as f:
                original_model_cls_name = json.load(f)["_class_name"]
            configured_cls_name = cls.base_class.__name__
            if configured_cls_name != original_model_cls_name:
                raise ValueError(
                    f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
                )

            # Create an empty model
            config = cls.base_class.load_config(model_name_or_path)
            with init_empty_weights():
                model = cls.base_class.from_config(config)

            # Look for the index of a sharded checkpoint
            checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
            if os.path.exists(checkpoint_file):
                # Convert the checkpoint path to a list of shards
                _, sharded_metadata = _get_checkpoint_shard_files(model_name_or_path, checkpoint_file)
                # Create a mapping for the sharded safetensor files
                state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"])
            else:
                # Look for a single checkpoint file
                checkpoint_file = os.path.join(model_name_or_path, SAFETENSORS_WEIGHTS_NAME)
                if not os.path.exists(checkpoint_file):
                    raise ValueError(f"No safetensor weights found in {model_name_or_path}.")
                # Get state_dict from model checkpoint
                state_dict = load_state_dict(checkpoint_file)

            # Requantize and load quantized weights from state_dict
            requantize(model, state_dict=state_dict, quantization_map=qmap)
            model.eval()
            return cls(model)
        else:
            raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")


def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
    # model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
    model_8bit_path = path / "quantized"
    if model_8bit_path.exists():
        # The quantized model exists, load it.
        # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
        # something that we should be able to make much faster.
        q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)

        # Access the underlying wrapped model.
        # We access the wrapped model, even though it is private, because it simplifies the type checking by
        # always returning a FluxTransformer2DModel from this function.
        model = q_model._wrapped
    else:
        # The quantized model does not exist yet, quantize and save it.
        # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
        # GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
        # here.
        model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
        assert isinstance(model, FluxTransformer2DModel)

        q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)

        model_8bit_path.mkdir(parents=True, exist_ok=True)
        q_model.save_pretrained(model_8bit_path)

        # (See earlier comment about accessing the wrapped model.)
        model = q_model._wrapped

    assert isinstance(model, FluxTransformer2DModel)
    return model


def main():
    start = time.time()
    model = load_flux_transformer(
        Path("/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/")
    )
    print(f"Time to load: {time.time() - start}s")
    print("hi")


if __name__ == "__main__":
    main()
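Note: the quantize-save-reload round trip that load_flux_transformer() performs on the FLUX transformer can be illustrated with a much smaller, self-contained sketch using optimum.quanto's documented serialization helpers (quantize, freeze, quantization_map, requantize). This is not the InvokeAI code path; the toy Sequential module and the file names model_q.safetensors / qmap.json are illustrative assumptions.

    import json

    import torch
    from optimum.quanto import freeze, qfloat8, quantization_map, quantize, requantize
    from safetensors.torch import load_file, save_file

    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))

    # Quantize the Linear weights to float8 and freeze them so the quantized tensors are materialized.
    quantize(model, weights=qfloat8)
    freeze(model)

    # Persist the quantized weights plus the quantization map (which modules were quantized, and how).
    save_file(model.state_dict(), "model_q.safetensors")  # illustrative path
    with open("qmap.json", "w") as f:  # illustrative path
        json.dump(quantization_map(model), f)

    # Later: rebuild an identical (unquantized) module and restore the quantized weights into it.
    fresh = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))
    with open("qmap.json") as f:
        qmap = json.load(f)
    requantize(fresh, load_file("model_q.safetensors"), qmap, device=torch.device("cpu"))
    fresh.eval()

QuantizedFluxTransformer2DModel wraps the same flow behind the QuantizedDiffusersModel save_pretrained/from_pretrained interface, with the extra handling above for sharded safetensors checkpoints.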
54  invokeai/backend/requantize.py  Normal file
@@ -0,0 +1,54 @@
from typing import Any, Dict

import torch
from optimum.quanto.nn import QModuleMixin
from optimum.quanto.quantize import _quantize_submodule, freeze


def custom_freeze(model: torch.nn.Module):
    for name, m in model.named_modules():
        if isinstance(m, QModuleMixin):
            m.freeze()


def requantize(
    model: torch.nn.Module,
    state_dict: Dict[str, Any],
    quantization_map: Dict[str, Dict[str, str]],
    device: torch.device = None,
):
    if device is None:
        device = next(model.parameters()).device
        if device.type == "meta":
            device = torch.device("cpu")

    # Quantize the model with parameters from the quantization map
    for name, m in model.named_modules():
        qconfig = quantization_map.get(name, None)
        if qconfig is not None:
            weights = qconfig["weights"]
            if weights == "none":
                weights = None
            activations = qconfig["activations"]
            if activations == "none":
                activations = None
            _quantize_submodule(model, name, m, weights=weights, activations=activations)

    # Move model parameters and buffers to CPU before materializing quantized weights
    for name, m in model.named_modules():

        def move_tensor(t, device):
            if t.device.type == "meta":
                return torch.empty_like(t, device=device)
            return t.to(device)

        for name, param in m.named_parameters(recurse=False):
            setattr(m, name, torch.nn.Parameter(move_tensor(param, "cpu")))
        for name, param in m.named_buffers(recurse=False):
            setattr(m, name, move_tensor(param, "cpu"))
    # Freeze model and move to target device
    freeze(model)
    model.to(device)

    # Load the quantized model weights
    model.load_state_dict(state_dict, strict=False)
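For context, a minimal sketch of how a helper like this is driven, mirroring what QuantizedFluxTransformer2DModel.from_pretrained does above: the target model is first instantiated on the meta device via accelerate's init_empty_weights (so no real weights are allocated), and requantize() then swaps in quantized submodules per the quantization map, materializes their tensors on CPU, and loads the saved state dict. The toy build_model() stand-in and the file names (reused from the earlier sketch) are illustrative assumptions, not part of the commit.

    import json

    import torch
    from accelerate import init_empty_weights
    from safetensors.torch import load_file

    from invokeai.backend.requantize import requantize


    def build_model() -> torch.nn.Module:
        # Stand-in for cls.base_class.from_config(config) in the loader above.
        return torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))


    with init_empty_weights():
        model = build_model()  # parameters live on the "meta" device, nothing is allocated yet

    state_dict = load_file("model_q.safetensors")  # illustrative path
    with open("qmap.json") as f:  # illustrative path
        qmap = json.load(f)

    # Swap in quantized submodules per the map, materialize weights on CPU, then load them.
    requantize(model, state_dict=state_dict, quantization_map=qmap, device=torch.device("cpu"))
    model.eval()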