Make quantized loading fast for both T5XXL and FLUX transformer.

This commit is contained in:
Ryan Dick 2024-08-09 19:54:09 +00:00 committed by Brandon
parent 8b1cef978c
commit eeabb7ebe5
3 changed files with 142 additions and 3 deletions

View File

@ -6,7 +6,6 @@ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from optimum.quanto import qfloat8
from optimum.quanto.models import QuantizedDiffusersModel, QuantizedTransformersModel
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from transformers.models.auto import AutoModelForTextEncoding
@ -15,17 +14,19 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
from invokeai.backend.util.devices import TorchDevice
TFluxModelKeys = Literal["flux-schnell"]
FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
class QuantizedFluxTransformer2DModel(FastQuantizedDiffusersModel):
base_class = FluxTransformer2DModel
class QuantizedModelForTextEncoding(QuantizedTransformersModel):
class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
auto_class = AutoModelForTextEncoding

View File

@ -0,0 +1,77 @@
import json
import os
from typing import Union
from diffusers.models.model_loading_utils import load_state_dict
from diffusers.utils import (
CONFIG_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_WEIGHTS_NAME,
_get_checkpoint_shard_files,
is_accelerate_available,
)
from optimum.quanto.models import QuantizedDiffusersModel
from optimum.quanto.models.shared_dict import ShardedStateDict
from invokeai.backend.requantize import requantize
class FastQuantizedDiffusersModel(QuantizedDiffusersModel):
@classmethod
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
"""We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
if cls.base_class is None:
raise ValueError("The `base_class` attribute needs to be configured.")
if not is_accelerate_available():
raise ValueError("Reloading a quantized diffusers model requires the accelerate library.")
from accelerate import init_empty_weights
if os.path.isdir(model_name_or_path):
# Look for a quantization map
qmap_path = os.path.join(model_name_or_path, cls._qmap_name())
if not os.path.exists(qmap_path):
raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?")
# Look for original model config file.
model_config_path = os.path.join(model_name_or_path, CONFIG_NAME)
if not os.path.exists(model_config_path):
raise ValueError(f"{CONFIG_NAME} not found in {model_name_or_path}.")
with open(qmap_path, "r", encoding="utf-8") as f:
qmap = json.load(f)
with open(model_config_path, "r", encoding="utf-8") as f:
original_model_cls_name = json.load(f)["_class_name"]
configured_cls_name = cls.base_class.__name__
if configured_cls_name != original_model_cls_name:
raise ValueError(
f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
)
# Create an empty model
config = cls.base_class.load_config(model_name_or_path)
with init_empty_weights():
model = cls.base_class.from_config(config)
# Look for the index of a sharded checkpoint
checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
if os.path.exists(checkpoint_file):
# Convert the checkpoint path to a list of shards
_, sharded_metadata = _get_checkpoint_shard_files(model_name_or_path, checkpoint_file)
# Create a mapping for the sharded safetensor files
state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"])
else:
# Look for a single checkpoint file
checkpoint_file = os.path.join(model_name_or_path, SAFETENSORS_WEIGHTS_NAME)
if not os.path.exists(checkpoint_file):
raise ValueError(f"No safetensor weights found in {model_name_or_path}.")
# Get state_dict from model checkpoint
state_dict = load_state_dict(checkpoint_file)
# Requantize and load quantized weights from state_dict
requantize(model, state_dict=state_dict, quantization_map=qmap)
model.eval()
return cls(model)
else:
raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")

View File

@ -0,0 +1,61 @@
import json
import os
from typing import Union
from optimum.quanto.models import QuantizedTransformersModel
from optimum.quanto.models.shared_dict import ShardedStateDict
from transformers import AutoConfig
from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available
from invokeai.backend.requantize import requantize
class FastQuantizedTransformersModel(QuantizedTransformersModel):
@classmethod
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
"""We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
if cls.auto_class is None:
raise ValueError(
"Quantized models cannot be reloaded using {cls}: use a specialized quantized class such as QuantizedModelForCausalLM instead."
)
if not is_accelerate_available():
raise ValueError("Reloading a quantized transformers model requires the accelerate library.")
from accelerate import init_empty_weights
if os.path.isdir(model_name_or_path):
# Look for a quantization map
qmap_path = os.path.join(model_name_or_path, cls._qmap_name())
if not os.path.exists(qmap_path):
raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?")
with open(qmap_path, "r", encoding="utf-8") as f:
qmap = json.load(f)
# Create an empty model
config = AutoConfig.from_pretrained(model_name_or_path)
with init_empty_weights():
model = cls.auto_class.from_config(config)
# Look for the index of a sharded checkpoint
checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
if os.path.exists(checkpoint_file):
# Convert the checkpoint path to a list of shards
checkpoint_file, sharded_metadata = get_checkpoint_shard_files(model_name_or_path, checkpoint_file)
# Create a mapping for the sharded safetensor files
state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"])
else:
# Look for a single checkpoint file
checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_NAME)
if not os.path.exists(checkpoint_file):
raise ValueError(f"No safetensor weights found in {model_name_or_path}.")
# Get state_dict from model checkpoint
state_dict = load_state_dict(checkpoint_file)
# Requantize and load quantized weights from state_dict
requantize(model, state_dict=state_dict, quantization_map=qmap)
if getattr(model.config, "tie_word_embeddings", True):
# Tie output weight embeddings to input weight embeddings
# Note that if they were quantized they would NOT be tied
model.tie_weights()
# Set model in evaluation mode as it is done in transformers
model.eval()
return cls(model)
else:
raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")