Manage quantization of models within the loader

Brandon Rising
2024-08-12 18:01:42 -04:00
committed by Brandon
parent 1d8545a76c
commit 56fda669fd
9 changed files with 130 additions and 237 deletions

View File

@@ -78,7 +78,12 @@ class GenericDiffusersLoader(ModelLoader):
     # TO DO: Add exception handling
     def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type
-        if module in ["diffusers", "transformers"]:
+        if module in [
+            "diffusers",
+            "transformers",
+            "invokeai.backend.quantization.fast_quantized_transformers_model",
+            "invokeai.backend.quantization.fast_quantized_diffusion_model",
+        ]:
             res_type = sys.modules[module]
         else:
             res_type = sys.modules["diffusers"].pipelines
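For orientation, a minimal sketch of the lookup this allow-list feeds into, assuming the surrounding loader resolves the class with a getattr() on the module; the importlib.import_module() call and the resolve_class name are illustrative, not part of the diff.

import importlib
import sys

def resolve_class(module: str, class_name: str):
    # Hypothetical helper mirroring _hf_definition_to_type: modules on the allow-list
    # (including the two invokeai quantization modules added above) are looked up
    # directly; anything else falls back to diffusers.pipelines.
    allowed = [
        "diffusers",
        "transformers",
        "invokeai.backend.quantization.fast_quantized_transformers_model",
        "invokeai.backend.quantization.fast_quantized_diffusion_model",
    ]
    if module in allowed:
        importlib.import_module(module)  # assumption: make sure the module is loaded into sys.modules
        res_type = sys.modules[module]
    else:
        res_type = importlib.import_module("diffusers").pipelines  # fallback, as in the loader
    return getattr(res_type, class_name)  # assumption: the loader retrieves the class by name like this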

View File

@@ -9,7 +9,7 @@ from typing import Optional
 import torch
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SchedulerMixin
-from transformers import CLIPTokenizer
+from transformers import CLIPTokenizer, T5TokenizerFast

 from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
 from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
@@ -50,6 +50,13 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
         ),
     ):
         return model.calc_size()
+    elif isinstance(
+        model,
+        (
+            T5TokenizerFast,
+        ),
+    ):
+        return len(model)
     else:
         # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
         # supported model types.
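As a hedged usage sketch, the new branch means a T5 tokenizer now reports a size instead of falling through to the warning branch; the import path for calc_model_size_by_data and the checkpoint name are assumptions, not taken from the diff.

import logging

from transformers import T5TokenizerFast

# Assumed import path for the helper shown in the hunk header above.
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data

logger = logging.getLogger(__name__)
tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")  # placeholder checkpoint
size = calc_model_size_by_data(logger, tokenizer)
# size == len(tokenizer): the vocabulary size, a rough stand-in rather than a byte count.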

View File

@@ -12,15 +12,17 @@ from diffusers.utils import (
 )
 from optimum.quanto.models import QuantizedDiffusersModel
 from optimum.quanto.models.shared_dict import ShardedStateDict
+from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel

 from invokeai.backend.requantize import requantize


 class FastQuantizedDiffusersModel(QuantizedDiffusersModel):
     @classmethod
-    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
+    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike], base_class = FluxTransformer2DModel, **kwargs):
         """We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
-        if cls.base_class is None:
+        base_class = base_class or cls.base_class
+        if base_class is None:
             raise ValueError("The `base_class` attribute needs to be configured.")

         if not is_accelerate_available():
@@ -43,16 +45,16 @@ class FastQuantizedDiffusersModel(QuantizedDiffusersModel):
             with open(model_config_path, "r", encoding="utf-8") as f:
                 original_model_cls_name = json.load(f)["_class_name"]
-            configured_cls_name = cls.base_class.__name__
+            configured_cls_name = base_class.__name__
             if configured_cls_name != original_model_cls_name:
                 raise ValueError(
                     f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
                 )

             # Create an empty model
-            config = cls.base_class.load_config(model_name_or_path)
+            config = base_class.load_config(model_name_or_path)
             with init_empty_weights():
-                model = cls.base_class.from_config(config)
+                model = base_class.from_config(config)

             # Look for the index of a sharded checkpoint
             checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
@@ -72,6 +74,6 @@ class FastQuantizedDiffusersModel(QuantizedDiffusersModel):
             # Requantize and load quantized weights from state_dict
             requantize(model, state_dict=state_dict, quantization_map=qmap)
             model.eval()
-            return cls(model)
+            return cls(model)._wrapped
         else:
             raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
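A hedged sketch of what a caller sees after this change: base_class now defaults to FluxTransformer2DModel, and from_pretrained() returns the unwrapped module via ._wrapped rather than the QuantizedDiffusersModel wrapper. The module path below is assumed from the loader allow-list in the first hunk, and the checkpoint directory is a placeholder.

from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel

# Placeholder path to a locally saved, quantized FLUX transformer (config, quantization
# map, and safetensors weights are expected to live in this directory).
model_path = "/models/flux/transformer-quantized"

transformer = FastQuantizedDiffusersModel.from_pretrained(model_path)
# With the change above, `transformer` is the bare FluxTransformer2DModel (requantized
# weights loaded, eval mode) instead of a QuantizedDiffusersModel wrapper.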

View File

@@ -1,5 +1,6 @@
 import json
 import os
+import torch
 from typing import Union

 from optimum.quanto.models import QuantizedTransformersModel
@@ -7,15 +8,17 @@ from optimum.quanto.models.shared_dict import ShardedStateDict
 from transformers import AutoConfig
 from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available
+from transformers.models.auto import AutoModelForTextEncoding

 from invokeai.backend.requantize import requantize


 class FastQuantizedTransformersModel(QuantizedTransformersModel):
     @classmethod
-    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
+    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike], auto_class = AutoModelForTextEncoding, **kwargs):
         """We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
-        if cls.auto_class is None:
+        auto_class = auto_class or cls.auto_class
+        if auto_class is None:
             raise ValueError(
                 "Quantized models cannot be reloaded using {cls}: use a specialized quantized class such as QuantizedModelForCausalLM instead."
             )
@@ -33,7 +36,7 @@ class FastQuantizedTransformersModel(QuantizedTransformersModel):
             # Create an empty model
             config = AutoConfig.from_pretrained(model_name_or_path)
             with init_empty_weights():
-                model = cls.auto_class.from_config(config)
+                model = auto_class.from_config(config)
             # Look for the index of a sharded checkpoint
             checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
             if os.path.exists(checkpoint_file):
@@ -56,6 +59,6 @@ class FastQuantizedTransformersModel(QuantizedTransformersModel):
             model.tie_weights()
             # Set model in evaluation mode as it is done in transformers
             model.eval()
-            return cls(model)
+            return cls(model)._wrapped
         else:
             raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
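Similarly, a hedged usage sketch for the transformers-side class: auto_class defaults to AutoModelForTextEncoding and the unwrapped model is returned. The module path is assumed from the loader allow-list in the first hunk, and the checkpoint directory is a placeholder.

from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel

# Placeholder path to a locally saved, quantized T5 text encoder.
model_path = "/models/flux/text_encoder_2-quantized"

text_encoder = FastQuantizedTransformersModel.from_pretrained(model_path)
# `text_encoder` is the underlying transformers model (weights tied, eval mode),
# not the QuantizedTransformersModel wrapper, because from_pretrained() now returns
# cls(model)._wrapped.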