Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit 97562504b7 (parent 75d8ac378c): Remove all references to optimum-quanto and downgrade diffusers.
@@ -1,79 +0,0 @@
import json
import os
from typing import Union

from diffusers.models.model_loading_utils import load_state_dict
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.utils import (
    CONFIG_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFETENSORS_WEIGHTS_NAME,
    _get_checkpoint_shard_files,
    is_accelerate_available,
)
from optimum.quanto.models import QuantizedDiffusersModel
from optimum.quanto.models.shared_dict import ShardedStateDict

from invokeai.backend.quantization.requantize import requantize


class FastQuantizedDiffusersModel(QuantizedDiffusersModel):
    @classmethod
    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike], base_class=FluxTransformer2DModel, **kwargs):
        """We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
        base_class = base_class or cls.base_class
        if base_class is None:
            raise ValueError("The `base_class` attribute needs to be configured.")

        if not is_accelerate_available():
            raise ValueError("Reloading a quantized diffusers model requires the accelerate library.")
        from accelerate import init_empty_weights

        if os.path.isdir(model_name_or_path):
            # Look for a quantization map
            qmap_path = os.path.join(model_name_or_path, cls._qmap_name())
            if not os.path.exists(qmap_path):
                raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?")

            # Look for original model config file.
            model_config_path = os.path.join(model_name_or_path, CONFIG_NAME)
            if not os.path.exists(model_config_path):
                raise ValueError(f"{CONFIG_NAME} not found in {model_name_or_path}.")

            with open(qmap_path, "r", encoding="utf-8") as f:
                qmap = json.load(f)

            with open(model_config_path, "r", encoding="utf-8") as f:
                original_model_cls_name = json.load(f)["_class_name"]
            configured_cls_name = base_class.__name__
            if configured_cls_name != original_model_cls_name:
                raise ValueError(
                    f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
                )

            # Create an empty model
            config = base_class.load_config(model_name_or_path)
            with init_empty_weights():
                model = base_class.from_config(config)

            # Look for the index of a sharded checkpoint
            checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
            if os.path.exists(checkpoint_file):
                # Convert the checkpoint path to a list of shards
                _, sharded_metadata = _get_checkpoint_shard_files(model_name_or_path, checkpoint_file)
                # Create a mapping for the sharded safetensor files
                state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"])
            else:
                # Look for a single checkpoint file
                checkpoint_file = os.path.join(model_name_or_path, SAFETENSORS_WEIGHTS_NAME)
                if not os.path.exists(checkpoint_file):
                    raise ValueError(f"No safetensor weights found in {model_name_or_path}.")
                # Get state_dict from model checkpoint
                state_dict = load_state_dict(checkpoint_file)

            # Requantize and load quantized weights from state_dict
            requantize(model, state_dict=state_dict, quantization_map=qmap)
            model.eval()
            return cls(model)._wrapped
        else:
            raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
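For context, a minimal usage sketch of the deleted loader above (not part of the diff; the import path and local directory below are hypothetical placeholders):

# Assumed module path for the class defined above; substitute the deleted file's real location.
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel

# The directory is expected to hold config.json, the quanto quantization map, and
# safetensors weights (a single file or a sharded set with an index).
transformer = FastQuantizedDiffusersModel.from_pretrained("/path/to/quantized-flux-transformer")
# from_pretrained() returns the wrapped FluxTransformer2DModel (cls(model)._wrapped),
# already requantized and switched to eval mode.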
@@ -1,65 +0,0 @@
import json
import os
from typing import Union

from optimum.quanto.models import QuantizedTransformersModel
from optimum.quanto.models.shared_dict import ShardedStateDict
from transformers import AutoConfig
from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict
from transformers.models.auto import AutoModelForTextEncoding
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available

from invokeai.backend.quantization.requantize import requantize


class FastQuantizedTransformersModel(QuantizedTransformersModel):
    @classmethod
    def from_pretrained(
        cls, model_name_or_path: Union[str, os.PathLike], auto_class=AutoModelForTextEncoding, **kwargs
    ):
        """We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
        auto_class = auto_class or cls.auto_class
        if auto_class is None:
            raise ValueError(
                "Quantized models cannot be reloaded using {cls}: use a specialized quantized class such as QuantizedModelForCausalLM instead."
            )
        if not is_accelerate_available():
            raise ValueError("Reloading a quantized transformers model requires the accelerate library.")
        from accelerate import init_empty_weights

        if os.path.isdir(model_name_or_path):
            # Look for a quantization map
            qmap_path = os.path.join(model_name_or_path, cls._qmap_name())
            if not os.path.exists(qmap_path):
                raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?")
            with open(qmap_path, "r", encoding="utf-8") as f:
                qmap = json.load(f)
            # Create an empty model
            config = AutoConfig.from_pretrained(model_name_or_path)
            with init_empty_weights():
                model = auto_class.from_config(config)
            # Look for the index of a sharded checkpoint
            checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
            if os.path.exists(checkpoint_file):
                # Convert the checkpoint path to a list of shards
                checkpoint_file, sharded_metadata = get_checkpoint_shard_files(model_name_or_path, checkpoint_file)
                # Create a mapping for the sharded safetensor files
                state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"])
            else:
                # Look for a single checkpoint file
                checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_NAME)
                if not os.path.exists(checkpoint_file):
                    raise ValueError(f"No safetensor weights found in {model_name_or_path}.")
                # Get state_dict from model checkpoint
                state_dict = load_state_dict(checkpoint_file)
            # Requantize and load quantized weights from state_dict
            requantize(model, state_dict=state_dict, quantization_map=qmap)
            if getattr(model.config, "tie_word_embeddings", True):
                # Tie output weight embeddings to input weight embeddings
                # Note that if they were quantized they would NOT be tied
                model.tie_weights()
            # Set model in evaluation mode as it is done in transformers
            model.eval()
            return cls(model)._wrapped
        else:
            raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
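A similar hedged sketch for the deleted transformers-side loader (again, the module path and directory are hypothetical placeholders):

# Assumed module path for the class defined above; substitute the deleted file's real location.
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel

# Loads a locally quantized text encoder; auto_class defaults to AutoModelForTextEncoding.
text_encoder = FastQuantizedTransformersModel.from_pretrained("/path/to/quantized-t5-encoder")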
@@ -1,56 +0,0 @@
from typing import Any, Dict

import torch
from optimum.quanto.quantize import _quantize_submodule


def requantize(
    model: torch.nn.Module,
    state_dict: Dict[str, Any],
    quantization_map: Dict[str, Dict[str, str]],
    device: torch.device | None = None,
):
    """This function was initially copied from:
    https://github.com/huggingface/optimum-quanto/blob/832f7f5c3926c91fe4f923aaaf037a780ac3e6c3/optimum/quanto/quantize.py#L101

    The function was modified to remove the `freeze()` call. The `freeze()` call is very slow and unnecessary when the
    weights are about to be loaded from a state_dict.

    TODO(ryand): Unless I'm overlooking something, this should be contributed upstream to the `optimum-quanto` library.
    """
    if device is None:
        device = next(model.parameters()).device
        if device.type == "meta":
            device = torch.device("cpu")

    # Quantize the model with parameters from the quantization map
    for name, m in model.named_modules():
        qconfig = quantization_map.get(name, None)
        if qconfig is not None:
            weights = qconfig["weights"]
            if weights == "none":
                weights = None
            activations = qconfig["activations"]
            if activations == "none":
                activations = None
            _quantize_submodule(model, name, m, weights=weights, activations=activations)

    # Move model parameters and buffers to CPU before materializing quantized weights
    for name, m in model.named_modules():

        def move_tensor(t, device):
            if t.device.type == "meta":
                return torch.empty_like(t, device=device)
            return t.to(device)

        for name, param in m.named_parameters(recurse=False):
            setattr(m, name, torch.nn.Parameter(move_tensor(param, "cpu")))
        for name, param in m.named_buffers(recurse=False):
            setattr(m, name, move_tensor(param, "cpu"))

    # Freeze model and move to target device
    # freeze(model)
    # model.to(device)

    # Load the quantized model weights
    model.load_state_dict(state_dict, strict=False)
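An illustrative sketch (an assumption, not part of the diff) of the inputs requantize() expects: the map keys are submodule names, and "none" disables quantization for that tensor kind, matching how the loop above interprets each entry.

# Hypothetical quanto-style quantization map covering two attention projections.
quantization_map = {
    "transformer_blocks.0.attn.to_q": {"weights": "qint8", "activations": "none"},
    "transformer_blocks.0.attn.to_k": {"weights": "qint8", "activations": "none"},
}
# model and state_dict would come from the loaders above (init_empty_weights() plus
# load_state_dict() or ShardedStateDict); then:
# requantize(model, state_dict=state_dict, quantization_map=quantization_map)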
@@ -38,8 +38,7 @@ dependencies = [
     "clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
     "compel==2.0.2",
     "controlnet-aux==0.0.7",
-    # TODO(ryand): Bump this once the next diffusers release is ready.
-    "diffusers[torch] @ git+https://github.com/huggingface/diffusers.git@4c6152c2fb0ade468aadb417102605a07a8635d3",
+    "diffusers[torch]==0.27.2",
     "flux @ git+https://github.com/black-forest-labs/flux.git@c23ae247225daba30fbd56058d247cc1b1fc20a3",
     "invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
     "mediapipe==0.10.7", # needed for "mediapipeface" controlnet model
@@ -47,7 +46,6 @@ dependencies = [
     "onnx==1.15.0",
     "onnxruntime==1.16.3",
     "opencv-python==4.9.0.80",
-    "optimum-quanto==0.2.4",
     "pytorch-lightning==2.1.3",
     "safetensors==0.4.3",
     # sentencepiece is required to load T5TokenizerFast (used by FLUX).