From 97562504b73ffd10860f4909a48495c864aceef4 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 23 Aug 2024 17:48:29 +0000 Subject: [PATCH] Remove all references to optimum-quanto and downgrade diffusers. --- .../fast_quantized_diffusion_model.py | 79 ------------------- .../fast_quantized_transformers_model.py | 65 --------------- invokeai/backend/quantization/requantize.py | 56 ------------- pyproject.toml | 4 +- 4 files changed, 1 insertion(+), 203 deletions(-) delete mode 100644 invokeai/backend/quantization/fast_quantized_diffusion_model.py delete mode 100644 invokeai/backend/quantization/fast_quantized_transformers_model.py delete mode 100644 invokeai/backend/quantization/requantize.py diff --git a/invokeai/backend/quantization/fast_quantized_diffusion_model.py b/invokeai/backend/quantization/fast_quantized_diffusion_model.py deleted file mode 100644 index 6ad82b8e9e..0000000000 --- a/invokeai/backend/quantization/fast_quantized_diffusion_model.py +++ /dev/null @@ -1,79 +0,0 @@ -import json -import os -from typing import Union - -from diffusers.models.model_loading_utils import load_state_dict -from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel -from diffusers.utils import ( - CONFIG_NAME, - SAFE_WEIGHTS_INDEX_NAME, - SAFETENSORS_WEIGHTS_NAME, - _get_checkpoint_shard_files, - is_accelerate_available, -) -from optimum.quanto.models import QuantizedDiffusersModel -from optimum.quanto.models.shared_dict import ShardedStateDict - -from invokeai.backend.quantization.requantize import requantize - - -class FastQuantizedDiffusersModel(QuantizedDiffusersModel): - @classmethod - def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike], base_class=FluxTransformer2DModel, **kwargs): - """We override the `from_pretrained()` method in order to use our custom `requantize()` implementation.""" - base_class = base_class or cls.base_class - if base_class is None: - raise ValueError("The `base_class` attribute needs to be configured.") - - if not is_accelerate_available(): - raise ValueError("Reloading a quantized diffusers model requires the accelerate library.") - from accelerate import init_empty_weights - - if os.path.isdir(model_name_or_path): - # Look for a quantization map - qmap_path = os.path.join(model_name_or_path, cls._qmap_name()) - if not os.path.exists(qmap_path): - raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?") - - # Look for original model config file. - model_config_path = os.path.join(model_name_or_path, CONFIG_NAME) - if not os.path.exists(model_config_path): - raise ValueError(f"{CONFIG_NAME} not found in {model_name_or_path}.") - - with open(qmap_path, "r", encoding="utf-8") as f: - qmap = json.load(f) - - with open(model_config_path, "r", encoding="utf-8") as f: - original_model_cls_name = json.load(f)["_class_name"] - configured_cls_name = base_class.__name__ - if configured_cls_name != original_model_cls_name: - raise ValueError( - f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})." - ) - - # Create an empty model - config = base_class.load_config(model_name_or_path) - with init_empty_weights(): - model = base_class.from_config(config) - - # Look for the index of a sharded checkpoint - checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) - if os.path.exists(checkpoint_file): - # Convert the checkpoint path to a list of shards - _, sharded_metadata = _get_checkpoint_shard_files(model_name_or_path, checkpoint_file) - # Create a mapping for the sharded safetensor files - state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"]) - else: - # Look for a single checkpoint file - checkpoint_file = os.path.join(model_name_or_path, SAFETENSORS_WEIGHTS_NAME) - if not os.path.exists(checkpoint_file): - raise ValueError(f"No safetensor weights found in {model_name_or_path}.") - # Get state_dict from model checkpoint - state_dict = load_state_dict(checkpoint_file) - - # Requantize and load quantized weights from state_dict - requantize(model, state_dict=state_dict, quantization_map=qmap) - model.eval() - return cls(model)._wrapped - else: - raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.") diff --git a/invokeai/backend/quantization/fast_quantized_transformers_model.py b/invokeai/backend/quantization/fast_quantized_transformers_model.py deleted file mode 100644 index b811b598e7..0000000000 --- a/invokeai/backend/quantization/fast_quantized_transformers_model.py +++ /dev/null @@ -1,65 +0,0 @@ -import json -import os -from typing import Union - -from optimum.quanto.models import QuantizedTransformersModel -from optimum.quanto.models.shared_dict import ShardedStateDict -from transformers import AutoConfig -from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict -from transformers.models.auto import AutoModelForTextEncoding -from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available - -from invokeai.backend.quantization.requantize import requantize - - -class FastQuantizedTransformersModel(QuantizedTransformersModel): - @classmethod - def from_pretrained( - cls, model_name_or_path: Union[str, os.PathLike], auto_class=AutoModelForTextEncoding, **kwargs - ): - """We override the `from_pretrained()` method in order to use our custom `requantize()` implementation.""" - auto_class = auto_class or cls.auto_class - if auto_class is None: - raise ValueError( - "Quantized models cannot be reloaded using {cls}: use a specialized quantized class such as QuantizedModelForCausalLM instead." - ) - if not is_accelerate_available(): - raise ValueError("Reloading a quantized transformers model requires the accelerate library.") - from accelerate import init_empty_weights - - if os.path.isdir(model_name_or_path): - # Look for a quantization map - qmap_path = os.path.join(model_name_or_path, cls._qmap_name()) - if not os.path.exists(qmap_path): - raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?") - with open(qmap_path, "r", encoding="utf-8") as f: - qmap = json.load(f) - # Create an empty model - config = AutoConfig.from_pretrained(model_name_or_path) - with init_empty_weights(): - model = auto_class.from_config(config) - # Look for the index of a sharded checkpoint - checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) - if os.path.exists(checkpoint_file): - # Convert the checkpoint path to a list of shards - checkpoint_file, sharded_metadata = get_checkpoint_shard_files(model_name_or_path, checkpoint_file) - # Create a mapping for the sharded safetensor files - state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"]) - else: - # Look for a single checkpoint file - checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_NAME) - if not os.path.exists(checkpoint_file): - raise ValueError(f"No safetensor weights found in {model_name_or_path}.") - # Get state_dict from model checkpoint - state_dict = load_state_dict(checkpoint_file) - # Requantize and load quantized weights from state_dict - requantize(model, state_dict=state_dict, quantization_map=qmap) - if getattr(model.config, "tie_word_embeddings", True): - # Tie output weight embeddings to input weight embeddings - # Note that if they were quantized they would NOT be tied - model.tie_weights() - # Set model in evaluation mode as it is done in transformers - model.eval() - return cls(model)._wrapped - else: - raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.") diff --git a/invokeai/backend/quantization/requantize.py b/invokeai/backend/quantization/requantize.py deleted file mode 100644 index aae85bed7c..0000000000 --- a/invokeai/backend/quantization/requantize.py +++ /dev/null @@ -1,56 +0,0 @@ -from typing import Any, Dict - -import torch -from optimum.quanto.quantize import _quantize_submodule - - -def requantize( - model: torch.nn.Module, - state_dict: Dict[str, Any], - quantization_map: Dict[str, Dict[str, str]], - device: torch.device | None = None, -): - """This function was initially copied from: - https://github.com/huggingface/optimum-quanto/blob/832f7f5c3926c91fe4f923aaaf037a780ac3e6c3/optimum/quanto/quantize.py#L101 - - The function was modified to remove the `freeze()` call. The `freeze()` call is very slow and unnecessary when the - weights are about to be loaded from a state_dict. - - TODO(ryand): Unless I'm overlooking something, this should be contributed upstream to the `optimum-quanto` library. - """ - if device is None: - device = next(model.parameters()).device - if device.type == "meta": - device = torch.device("cpu") - - # Quantize the model with parameters from the quantization map - for name, m in model.named_modules(): - qconfig = quantization_map.get(name, None) - if qconfig is not None: - weights = qconfig["weights"] - if weights == "none": - weights = None - activations = qconfig["activations"] - if activations == "none": - activations = None - _quantize_submodule(model, name, m, weights=weights, activations=activations) - - # Move model parameters and buffers to CPU before materializing quantized weights - for name, m in model.named_modules(): - - def move_tensor(t, device): - if t.device.type == "meta": - return torch.empty_like(t, device=device) - return t.to(device) - - for name, param in m.named_parameters(recurse=False): - setattr(m, name, torch.nn.Parameter(move_tensor(param, "cpu"))) - for name, param in m.named_buffers(recurse=False): - setattr(m, name, move_tensor(param, "cpu")) - - # Freeze model and move to target device - # freeze(model) - # model.to(device) - - # Load the quantized model weights - model.load_state_dict(state_dict, strict=False) diff --git a/pyproject.toml b/pyproject.toml index 5be3117af3..1537bf7e6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,8 +38,7 @@ dependencies = [ "clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", "compel==2.0.2", "controlnet-aux==0.0.7", - # TODO(ryand): Bump this once the next diffusers release is ready. - "diffusers[torch] @ git+https://github.com/huggingface/diffusers.git@4c6152c2fb0ade468aadb417102605a07a8635d3", + "diffusers[torch]==0.27.2", "flux @ git+https://github.com/black-forest-labs/flux.git@c23ae247225daba30fbd56058d247cc1b1fc20a3", "invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids "mediapipe==0.10.7", # needed for "mediapipeface" controlnet model @@ -47,7 +46,6 @@ dependencies = [ "onnx==1.15.0", "onnxruntime==1.16.3", "opencv-python==4.9.0.80", - "optimum-quanto==0.2.4", "pytorch-lightning==2.1.3", "safetensors==0.4.3", # sentencepiece is required to load T5TokenizerFast (used by FLUX).