mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Make quantized loading fast for both T5XXL and FLUX transformer.
This commit is contained in:
parent
d23ad1818d
commit
a8a2fc106d
@ -6,7 +6,6 @@ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
|
|||||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||||
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
|
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
|
||||||
from optimum.quanto import qfloat8
|
from optimum.quanto import qfloat8
|
||||||
from optimum.quanto.models import QuantizedDiffusersModel, QuantizedTransformersModel
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||||
from transformers.models.auto import AutoModelForTextEncoding
|
from transformers.models.auto import AutoModelForTextEncoding
|
||||||
@ -15,17 +14,19 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
|||||||
from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata
|
from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata
|
||||||
from invokeai.app.invocations.primitives import ImageOutput
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
|
||||||
|
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
TFluxModelKeys = Literal["flux-schnell"]
|
TFluxModelKeys = Literal["flux-schnell"]
|
||||||
FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
|
FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
|
||||||
|
|
||||||
|
|
||||||
class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
|
class QuantizedFluxTransformer2DModel(FastQuantizedDiffusersModel):
|
||||||
base_class = FluxTransformer2DModel
|
base_class = FluxTransformer2DModel
|
||||||
|
|
||||||
|
|
||||||
class QuantizedModelForTextEncoding(QuantizedTransformersModel):
|
class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
|
||||||
auto_class = AutoModelForTextEncoding
|
auto_class = AutoModelForTextEncoding
|
||||||
|
|
||||||
|
|
||||||
|
@ -0,0 +1,77 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from diffusers.models.model_loading_utils import load_state_dict
|
||||||
|
from diffusers.utils import (
|
||||||
|
CONFIG_NAME,
|
||||||
|
SAFE_WEIGHTS_INDEX_NAME,
|
||||||
|
SAFETENSORS_WEIGHTS_NAME,
|
||||||
|
_get_checkpoint_shard_files,
|
||||||
|
is_accelerate_available,
|
||||||
|
)
|
||||||
|
from optimum.quanto.models import QuantizedDiffusersModel
|
||||||
|
from optimum.quanto.models.shared_dict import ShardedStateDict
|
||||||
|
|
||||||
|
from invokeai.backend.requantize import requantize
|
||||||
|
|
||||||
|
|
||||||
|
class FastQuantizedDiffusersModel(QuantizedDiffusersModel):
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
|
||||||
|
"""We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
|
||||||
|
if cls.base_class is None:
|
||||||
|
raise ValueError("The `base_class` attribute needs to be configured.")
|
||||||
|
|
||||||
|
if not is_accelerate_available():
|
||||||
|
raise ValueError("Reloading a quantized diffusers model requires the accelerate library.")
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
|
||||||
|
if os.path.isdir(model_name_or_path):
|
||||||
|
# Look for a quantization map
|
||||||
|
qmap_path = os.path.join(model_name_or_path, cls._qmap_name())
|
||||||
|
if not os.path.exists(qmap_path):
|
||||||
|
raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?")
|
||||||
|
|
||||||
|
# Look for original model config file.
|
||||||
|
model_config_path = os.path.join(model_name_or_path, CONFIG_NAME)
|
||||||
|
if not os.path.exists(model_config_path):
|
||||||
|
raise ValueError(f"{CONFIG_NAME} not found in {model_name_or_path}.")
|
||||||
|
|
||||||
|
with open(qmap_path, "r", encoding="utf-8") as f:
|
||||||
|
qmap = json.load(f)
|
||||||
|
|
||||||
|
with open(model_config_path, "r", encoding="utf-8") as f:
|
||||||
|
original_model_cls_name = json.load(f)["_class_name"]
|
||||||
|
configured_cls_name = cls.base_class.__name__
|
||||||
|
if configured_cls_name != original_model_cls_name:
|
||||||
|
raise ValueError(
|
||||||
|
f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create an empty model
|
||||||
|
config = cls.base_class.load_config(model_name_or_path)
|
||||||
|
with init_empty_weights():
|
||||||
|
model = cls.base_class.from_config(config)
|
||||||
|
|
||||||
|
# Look for the index of a sharded checkpoint
|
||||||
|
checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
|
||||||
|
if os.path.exists(checkpoint_file):
|
||||||
|
# Convert the checkpoint path to a list of shards
|
||||||
|
_, sharded_metadata = _get_checkpoint_shard_files(model_name_or_path, checkpoint_file)
|
||||||
|
# Create a mapping for the sharded safetensor files
|
||||||
|
state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"])
|
||||||
|
else:
|
||||||
|
# Look for a single checkpoint file
|
||||||
|
checkpoint_file = os.path.join(model_name_or_path, SAFETENSORS_WEIGHTS_NAME)
|
||||||
|
if not os.path.exists(checkpoint_file):
|
||||||
|
raise ValueError(f"No safetensor weights found in {model_name_or_path}.")
|
||||||
|
# Get state_dict from model checkpoint
|
||||||
|
state_dict = load_state_dict(checkpoint_file)
|
||||||
|
|
||||||
|
# Requantize and load quantized weights from state_dict
|
||||||
|
requantize(model, state_dict=state_dict, quantization_map=qmap)
|
||||||
|
model.eval()
|
||||||
|
return cls(model)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
|
@ -0,0 +1,61 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from optimum.quanto.models import QuantizedTransformersModel
|
||||||
|
from optimum.quanto.models.shared_dict import ShardedStateDict
|
||||||
|
from transformers import AutoConfig
|
||||||
|
from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict
|
||||||
|
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available
|
||||||
|
|
||||||
|
from invokeai.backend.requantize import requantize
|
||||||
|
|
||||||
|
|
||||||
|
class FastQuantizedTransformersModel(QuantizedTransformersModel):
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
|
||||||
|
"""We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
|
||||||
|
if cls.auto_class is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Quantized models cannot be reloaded using {cls}: use a specialized quantized class such as QuantizedModelForCausalLM instead."
|
||||||
|
)
|
||||||
|
if not is_accelerate_available():
|
||||||
|
raise ValueError("Reloading a quantized transformers model requires the accelerate library.")
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
|
||||||
|
if os.path.isdir(model_name_or_path):
|
||||||
|
# Look for a quantization map
|
||||||
|
qmap_path = os.path.join(model_name_or_path, cls._qmap_name())
|
||||||
|
if not os.path.exists(qmap_path):
|
||||||
|
raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?")
|
||||||
|
with open(qmap_path, "r", encoding="utf-8") as f:
|
||||||
|
qmap = json.load(f)
|
||||||
|
# Create an empty model
|
||||||
|
config = AutoConfig.from_pretrained(model_name_or_path)
|
||||||
|
with init_empty_weights():
|
||||||
|
model = cls.auto_class.from_config(config)
|
||||||
|
# Look for the index of a sharded checkpoint
|
||||||
|
checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
|
||||||
|
if os.path.exists(checkpoint_file):
|
||||||
|
# Convert the checkpoint path to a list of shards
|
||||||
|
checkpoint_file, sharded_metadata = get_checkpoint_shard_files(model_name_or_path, checkpoint_file)
|
||||||
|
# Create a mapping for the sharded safetensor files
|
||||||
|
state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"])
|
||||||
|
else:
|
||||||
|
# Look for a single checkpoint file
|
||||||
|
checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_NAME)
|
||||||
|
if not os.path.exists(checkpoint_file):
|
||||||
|
raise ValueError(f"No safetensor weights found in {model_name_or_path}.")
|
||||||
|
# Get state_dict from model checkpoint
|
||||||
|
state_dict = load_state_dict(checkpoint_file)
|
||||||
|
# Requantize and load quantized weights from state_dict
|
||||||
|
requantize(model, state_dict=state_dict, quantization_map=qmap)
|
||||||
|
if getattr(model.config, "tie_word_embeddings", True):
|
||||||
|
# Tie output weight embeddings to input weight embeddings
|
||||||
|
# Note that if they were quantized they would NOT be tied
|
||||||
|
model.tie_weights()
|
||||||
|
# Set model in evaluation mode as it is done in transformers
|
||||||
|
model.eval()
|
||||||
|
return cls(model)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
|
Loading…
Reference in New Issue
Block a user