diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py
index a680908461..0a7290214d 100644
--- a/invokeai/app/invocations/flux_text_to_image.py
+++ b/invokeai/app/invocations/flux_text_to_image.py
@@ -6,7 +6,6 @@ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
 from optimum.quanto import qfloat8
-from optimum.quanto.models import QuantizedDiffusersModel, QuantizedTransformersModel
 from PIL import Image
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
 from transformers.models.auto import AutoModelForTextEncoding
@@ -15,17 +14,19 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
+from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
 from invokeai.backend.util.devices import TorchDevice
 
 TFluxModelKeys = Literal["flux-schnell"]
 FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
 
 
-class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
+class QuantizedFluxTransformer2DModel(FastQuantizedDiffusersModel):
     base_class = FluxTransformer2DModel
 
 
-class QuantizedModelForTextEncoding(QuantizedTransformersModel):
+class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
     auto_class = AutoModelForTextEncoding
 
 
diff --git a/invokeai/backend/quantization/fast_quantized_diffusion_model.py b/invokeai/backend/quantization/fast_quantized_diffusion_model.py
new file mode 100644
index 0000000000..0759984bf9
--- /dev/null
+++ b/invokeai/backend/quantization/fast_quantized_diffusion_model.py
@@ -0,0 +1,77 @@
+import json
+import os
+from typing import Union
+
+from diffusers.models.model_loading_utils import load_state_dict
+from diffusers.utils import (
+    CONFIG_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFETENSORS_WEIGHTS_NAME,
+    _get_checkpoint_shard_files,
+    is_accelerate_available,
+)
+from optimum.quanto.models import QuantizedDiffusersModel
+from optimum.quanto.models.shared_dict import ShardedStateDict
+
+from invokeai.backend.requantize import requantize
+
+
+class FastQuantizedDiffusersModel(QuantizedDiffusersModel):
+    @classmethod
+    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
+        """We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
+        if cls.base_class is None:
+            raise ValueError("The `base_class` attribute needs to be configured.")
+
+        if not is_accelerate_available():
+            raise ValueError("Reloading a quantized diffusers model requires the accelerate library.")
+        from accelerate import init_empty_weights
+
+        if os.path.isdir(model_name_or_path):
+            # Look for a quantization map
+            qmap_path = os.path.join(model_name_or_path, cls._qmap_name())
+            if not os.path.exists(qmap_path):
+                raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?")
+
+            # Look for original model config file.
+            model_config_path = os.path.join(model_name_or_path, CONFIG_NAME)
+            if not os.path.exists(model_config_path):
+                raise ValueError(f"{CONFIG_NAME} not found in {model_name_or_path}.")
+
+            with open(qmap_path, "r", encoding="utf-8") as f:
+                qmap = json.load(f)
+
+            with open(model_config_path, "r", encoding="utf-8") as f:
+                original_model_cls_name = json.load(f)["_class_name"]
+            configured_cls_name = cls.base_class.__name__
+            if configured_cls_name != original_model_cls_name:
+                raise ValueError(
+                    f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
+                )
+
+            # Create an empty model
+            config = cls.base_class.load_config(model_name_or_path)
+            with init_empty_weights():
+                model = cls.base_class.from_config(config)
+
+            # Look for the index of a sharded checkpoint
+            checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
+            if os.path.exists(checkpoint_file):
+                # Convert the checkpoint path to a list of shards
+                _, sharded_metadata = _get_checkpoint_shard_files(model_name_or_path, checkpoint_file)
+                # Create a mapping for the sharded safetensor files
+                state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"])
+            else:
+                # Look for a single checkpoint file
+                checkpoint_file = os.path.join(model_name_or_path, SAFETENSORS_WEIGHTS_NAME)
+                if not os.path.exists(checkpoint_file):
+                    raise ValueError(f"No safetensor weights found in {model_name_or_path}.")
+                # Get state_dict from model checkpoint
+                state_dict = load_state_dict(checkpoint_file)
+
+            # Requantize and load quantized weights from state_dict
+            requantize(model, state_dict=state_dict, quantization_map=qmap)
+            model.eval()
+            return cls(model)
+        else:
+            raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
diff --git a/invokeai/backend/quantization/fast_quantized_transformers_model.py b/invokeai/backend/quantization/fast_quantized_transformers_model.py
new file mode 100644
index 0000000000..ce5cc7a3a9
--- /dev/null
+++ b/invokeai/backend/quantization/fast_quantized_transformers_model.py
@@ -0,0 +1,61 @@
+import json
+import os
+from typing import Union
+
+from optimum.quanto.models import QuantizedTransformersModel
+from optimum.quanto.models.shared_dict import ShardedStateDict
+from transformers import AutoConfig
+from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available
+
+from invokeai.backend.requantize import requantize
+
+
+class FastQuantizedTransformersModel(QuantizedTransformersModel):
+    @classmethod
+    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
+        """We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
+        if cls.auto_class is None:
+            raise ValueError(
+                f"Quantized models cannot be reloaded using {cls}: use a specialized quantized class such as QuantizedModelForCausalLM instead."
+ ) + if not is_accelerate_available(): + raise ValueError("Reloading a quantized transformers model requires the accelerate library.") + from accelerate import init_empty_weights + + if os.path.isdir(model_name_or_path): + # Look for a quantization map + qmap_path = os.path.join(model_name_or_path, cls._qmap_name()) + if not os.path.exists(qmap_path): + raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?") + with open(qmap_path, "r", encoding="utf-8") as f: + qmap = json.load(f) + # Create an empty model + config = AutoConfig.from_pretrained(model_name_or_path) + with init_empty_weights(): + model = cls.auto_class.from_config(config) + # Look for the index of a sharded checkpoint + checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) + if os.path.exists(checkpoint_file): + # Convert the checkpoint path to a list of shards + checkpoint_file, sharded_metadata = get_checkpoint_shard_files(model_name_or_path, checkpoint_file) + # Create a mapping for the sharded safetensor files + state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"]) + else: + # Look for a single checkpoint file + checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_NAME) + if not os.path.exists(checkpoint_file): + raise ValueError(f"No safetensor weights found in {model_name_or_path}.") + # Get state_dict from model checkpoint + state_dict = load_state_dict(checkpoint_file) + # Requantize and load quantized weights from state_dict + requantize(model, state_dict=state_dict, quantization_map=qmap) + if getattr(model.config, "tie_word_embeddings", True): + # Tie output weight embeddings to input weight embeddings + # Note that if they were quantized they would NOT be tied + model.tie_weights() + # Set model in evaluation mode as it is done in transformers + model.eval() + return cls(model) + else: + raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")