Manage quantization of models within the loader

Brandon Rising 2024-08-12 18:01:42 -04:00 committed by Brandon
parent 1d8545a76c
commit 56fda669fd
9 changed files with 130 additions and 237 deletions


@@ -126,6 +126,7 @@ class FieldDescriptions:
     negative_cond = "Negative conditioning tensor"
     noise = "Noise tensor"
     clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
+    t5Encoder = "T5 tokenizer and text encoder"
     unet = "UNet (scheduler, LoRAs)"
     transformer = "Transformer"
     vae = "VAE"


@@ -6,8 +6,10 @@ from optimum.quanto import qfloat8
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.fields import InputField
-from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding, TFluxModelKeys
+from invokeai.app.invocations.model import CLIPField, T5EncoderField
+from invokeai.app.invocations.fields import InputField, FieldDescriptions, Input
+from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding
+from invokeai.app.invocations.model import CLIPField, T5EncoderField
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
@@ -22,9 +24,15 @@ from invokeai.backend.util.devices import TorchDevice
     version="1.0.0",
 )
 class FluxTextEncoderInvocation(BaseInvocation):
-    model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
-    use_8bit: bool = InputField(
-        default=False, description="Whether to quantize the transformer model to 8-bit precision."
+    clip: CLIPField = InputField(
+        title="CLIP",
+        description=FieldDescriptions.clip,
+        input=Input.Connection,
+    )
+    t5Encoder: T5EncoderField = InputField(
+        title="T5EncoderField",
+        description=FieldDescriptions.t5Encoder,
+        input=Input.Connection,
     )
     positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
@@ -32,47 +40,43 @@ class FluxTextEncoderInvocation(BaseInvocation):
     # compatible with other ConditioningOutputs.
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
-        model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
-        t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path)
+        t5_embeddings, clip_embeddings = self._encode_prompt(context)
         conditioning_data = ConditioningFieldData(
             conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
         )

         conditioning_name = context.conditioning.save(conditioning_data)
         return ConditioningOutput.build(conditioning_name)

-    def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
-        # Determine the T5 max sequence length based on the model.
-        if self.model == "flux-schnell":
-            max_seq_len = 256
-        # elif self.model == "flux-dev":
-        #     max_seq_len = 512
-        else:
-            raise ValueError(f"Unknown model: {self.model}")
-
-        # Load the CLIP tokenizer.
-        clip_tokenizer_path = flux_model_dir / "tokenizer"
-        clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
-        assert isinstance(clip_tokenizer, CLIPTokenizer)
-
-        # Load the T5 tokenizer.
-        t5_tokenizer_path = flux_model_dir / "tokenizer_2"
-        t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
-        assert isinstance(t5_tokenizer, T5TokenizerFast)
-
-        clip_text_encoder_path = flux_model_dir / "text_encoder"
-        t5_text_encoder_path = flux_model_dir / "text_encoder_2"
+    def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
+        # TODO: Determine the T5 max sequence length based on the model.
+        # if self.model == "flux-schnell":
+        max_seq_len = 256
+        # # elif self.model == "flux-dev":
+        # #   max_seq_len = 512
+        # else:
+        #     raise ValueError(f"Unknown model: {self.model}")
+
+        # Load CLIP.
+        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
+        clip_text_encoder_info = context.models.load(self.clip.text_encoder)
+
+        # Load T5.
+        t5_tokenizer_info = context.models.load(self.t5Encoder.tokenizer)
+        t5_text_encoder_info = context.models.load(self.t5Encoder.text_encoder)

         with (
-            context.models.load_local_model(
-                model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
-            ) as clip_text_encoder,
-            context.models.load_local_model(
-                model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
-            ) as t5_text_encoder,
+            clip_text_encoder_info as clip_text_encoder,
+            t5_text_encoder_info as t5_text_encoder,
+            clip_tokenizer_info as clip_tokenizer,
+            t5_tokenizer_info as t5_tokenizer,
         ):
             assert isinstance(clip_text_encoder, CLIPTextModel)
             assert isinstance(t5_text_encoder, T5EncoderModel)
+            assert isinstance(clip_tokenizer, CLIPTokenizer)
+            assert isinstance(t5_tokenizer, T5TokenizerFast)

             pipeline = FluxPipeline(
                 scheduler=None,
                 vae=None,
@@ -85,7 +89,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
             # prompt_embeds: T5 embeddings
             # pooled_prompt_embeds: CLIP embeddings
-            prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
+            prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt(
                 prompt=self.positive_prompt,
                 prompt_2=self.positive_prompt,
                 device=TorchDevice.choose_torch_device(),
@@ -95,41 +99,3 @@ class FluxTextEncoderInvocation(BaseInvocation):
         assert isinstance(prompt_embeds, torch.Tensor)
         assert isinstance(pooled_prompt_embeds, torch.Tensor)
         return prompt_embeds, pooled_prompt_embeds
-
-    @staticmethod
-    def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
-        model = CLIPTextModel.from_pretrained(path, local_files_only=True)
-        assert isinstance(model, CLIPTextModel)
-        return model
-
-    def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
-        if self.use_8bit:
-            model_8bit_path = path / "quantized"
-            if model_8bit_path.exists():
-                # The quantized model exists, load it.
-                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
-                # something that we should be able to make much faster.
-                q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)
-
-                # Access the underlying wrapped model.
-                # We access the wrapped model, even though it is private, because it simplifies the type checking by
-                # always returning a T5EncoderModel from this function.
-                model = q_model._wrapped
-            else:
-                # The quantized model does not exist yet, quantize and save it.
-                # TODO(ryand): dtype?
-                model = T5EncoderModel.from_pretrained(path, local_files_only=True)
-                assert isinstance(model, T5EncoderModel)
-
-                q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)
-
-                model_8bit_path.mkdir(parents=True, exist_ok=True)
-                q_model.save_pretrained(model_8bit_path)
-
-                # (See earlier comment about accessing the wrapped model.)
-                model = q_model._wrapped
-        else:
-            model = T5EncoderModel.from_pretrained(path, local_files_only=True)
-
-        assert isinstance(model, T5EncoderModel)
-        return model
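Note: after this file's changes, the text-encoder node no longer downloads weights or decides how to quantize them; it only asks the model manager for the submodels referenced by its input fields. A minimal sketch of that pattern, using the `context.models.load(...)` API and the field classes shown in this diff (the helper name is illustrative; the surrounding invocation class and the actual prompt encoding are elided):

from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.services.shared.invocation_context import InvocationContext


def encode_with_loader_managed_models(context: InvocationContext, clip: CLIPField, t5: T5EncoderField) -> None:
    # Ask the model manager for handles to each submodel. Any quantization is
    # applied (or loaded from cache) by the loader before the model is handed back.
    clip_tokenizer_info = context.models.load(clip.tokenizer)
    clip_text_encoder_info = context.models.load(clip.text_encoder)
    t5_tokenizer_info = context.models.load(t5.tokenizer)
    t5_text_encoder_info = context.models.load(t5.text_encoder)

    with (
        clip_text_encoder_info as clip_text_encoder,
        t5_text_encoder_info as t5_text_encoder,
        clip_tokenizer_info as clip_tokenizer,
        t5_tokenizer_info as t5_tokenizer,
    ):
        # The node only verifies that it received the classes it expects.
        assert isinstance(clip_text_encoder, CLIPTextModel)
        assert isinstance(t5_text_encoder, T5EncoderModel)
        assert isinstance(clip_tokenizer, CLIPTokenizer)
        assert isinstance(t5_tokenizer, T5TokenizerFast)
        # ...tokenize and encode the prompt here, as in the new _encode_prompt() above.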


@@ -6,7 +6,7 @@ import accelerate
 import torch
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
-from invokeai.app.invocations.model import ModelIdentifierField
+from invokeai.app.invocations.model import TransformerField, VAEField
 from optimum.quanto import qfloat8
 from PIL import Image
 from safetensors.torch import load_file
@@ -52,17 +52,14 @@ class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
 class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     """Text-to-image generation using a FLUX model."""

-    flux_model: ModelIdentifierField = InputField(
-        description="The Flux model",
-        input=Input.Any,
-        ui_type=UIType.FluxMainModel
+    transformer: TransformerField = InputField(
+        description=FieldDescriptions.unet,
+        input=Input.Connection,
+        title="Transformer",
     )
-    model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
-    quantization_type: Literal["raw", "NF4", "llm_int8"] = InputField(
-        default="raw", description="The type of quantization to use for the transformer model."
-    )
-    use_8bit: bool = InputField(
-        default=False, description="Whether to quantize the transformer model to 8-bit precision."
+    vae: VAEField = InputField(
+        description=FieldDescriptions.vae,
+        input=Input.Connection,
     )
     positive_text_conditioning: ConditioningField = InputField(
         description=FieldDescriptions.positive_cond, input=Input.Connection
@@ -78,13 +75,6 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        # model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
-        flux_transformer_path = context.models.download_and_cache_model(
-            "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors"
-        )
-        flux_ae_path = context.models.download_and_cache_model(
-            "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/ae.safetensors"
-        )
-
         # Load the conditioning data.
         cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
@@ -92,56 +82,31 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         flux_conditioning = cond_data.conditionings[0]
         assert isinstance(flux_conditioning, FLUXConditioningInfo)

-        latents = self._run_diffusion(
-            context, flux_transformer_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds
-        )
-        image = self._run_vae_decoding(context, flux_ae_path, latents)
+        latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
+        image = self._run_vae_decoding(context, latents)
         image_dto = context.images.save(image=image)
         return ImageOutput.build(image_dto)

     def _run_diffusion(
         self,
         context: InvocationContext,
-        flux_transformer_path: Path,
         clip_embeddings: torch.Tensor,
         t5_embeddings: torch.Tensor,
     ):
-        inference_dtype = TorchDevice.choose_torch_dtype()
-
-        # Prepare input noise.
-        # TODO(ryand): Does the seed behave the same on different devices? Should we re-implement this to always use a
-        # CPU RNG?
-        x = get_noise(
-            num_samples=1,
-            height=self.height,
-            width=self.width,
-            device=TorchDevice.choose_torch_device(),
-            dtype=inference_dtype,
-            seed=self.seed,
-        )
-
-        img, img_ids = self._prepare_latent_img_patches(x)
-
-        # HACK(ryand): Find a better way to determine if this is a schnell model or not.
-        is_schnell = "shnell" in str(flux_transformer_path)
-        timesteps = get_schedule(
-            num_steps=self.num_steps,
-            image_seq_len=img.shape[1],
-            shift=not is_schnell,
-        )
-
-        bs, t5_seq_len, _ = t5_embeddings.shape
-        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
+        scheduler_info = context.models.load(self.transformer.scheduler)
+        transformer_info = context.models.load(self.transformer.transformer)

         # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
         # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
         # if the cache is not empty.
-        context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
+        # context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)

-        with context.models.load_local_model(
-            model_path=flux_transformer_path, loader=self._load_flux_transformer
-        ) as transformer:
-            assert isinstance(transformer, Flux)
+        with (
+            transformer_info as transformer,
+            scheduler_info as scheduler
+        ):
+            assert isinstance(transformer, FluxTransformer2DModel)
+            assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)

             x = denoise(
                 model=transformer,
@@ -185,75 +150,25 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def _run_vae_decoding(
         self,
         context: InvocationContext,
-        flux_ae_path: Path,
         latents: torch.Tensor,
     ) -> Image.Image:
-        with context.models.load_local_model(model_path=flux_ae_path, loader=self._load_flux_vae) as vae:
-            assert isinstance(vae, AutoEncoder)
-            # TODO(ryand): Test that this works with both float16 and bfloat16.
-            with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
-                img = vae.decode(latents)
+        vae_info = context.models.load(self.vae.vae)
+        with vae_info as vae:
+            assert isinstance(vae, AutoencoderKL)

         img.clamp(-1, 1)
         img = rearrange(img[0], "c h w -> h w c")
         img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())

-        return img_pil
+        latents = flux_pipeline_with_vae._unpack_latents(
+            latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
+        )
+        latents = (
+            latents / flux_pipeline_with_vae.vae.config.scaling_factor
+        ) + flux_pipeline_with_vae.vae.config.shift_factor
+        latents = latents.to(dtype=vae.dtype)
+        image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
+        image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]
+        assert isinstance(image, Image.Image)
+        return image
-
-    def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
-        inference_dtype = TorchDevice.choose_torch_dtype()
-        if self.quantization_type == "raw":
-            # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
-            params = flux_configs["flux-schnell"].params
-
-            # Initialize the model on the "meta" device.
-            with accelerate.init_empty_weights():
-                model = Flux(params).to(inference_dtype)
-
-            state_dict = load_file(path)
-            # TODO(ryand): Cast the state_dict to the appropriate dtype?
-            model.load_state_dict(state_dict, strict=True, assign=True)
-        elif self.quantization_type == "NF4":
-            model_path = path.parent / "bnb_nf4.safetensors"
-            # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
-            params = flux_configs["flux-schnell"].params
-
-            # Initialize the model on the "meta" device.
-            with accelerate.init_empty_weights():
-                model = Flux(params)
-                model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
-
-            # TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
-            # this on GPUs without bfloat16 support.
-            state_dict = load_file(model_path)
-            model.load_state_dict(state_dict, strict=True, assign=True)
-        elif self.quantization_type == "llm_int8":
-            raise NotImplementedError("LLM int8 quantization is not yet supported.")
-            # model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
-            # with accelerate.init_empty_weights():
-            #     empty_model = FluxTransformer2DModel.from_config(model_config)
-            # assert isinstance(empty_model, FluxTransformer2DModel)
-            # model_int8_path = path / "bnb_llm_int8"
-            # assert model_int8_path.exists()
-            # with accelerate.init_empty_weights():
-            #     model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set())
-            # sd = load_file(model_int8_path / "model.safetensors")
-            # model.load_state_dict(sd, strict=True, assign=True)
-        else:
-            raise ValueError(f"Unsupported quantization type: {self.quantization_type}")
-
-        assert isinstance(model, FluxTransformer2DModel)
-        return model
-
-    @staticmethod
-    def _load_flux_vae(path: Path) -> AutoEncoder:
-        # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
-        ae_params = flux_configs["flux1-schnell"].ae_params
-        with accelerate.init_empty_weights():
-            ae = AutoEncoder(ae_params)
-
-        state_dict = load_file(path)
-        ae.load_state_dict(state_dict, strict=True, assign=True)
-
-        return ae
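Note: the new `_run_vae_decoding` body relies on a `flux_pipeline_with_vae` object that is not constructed in the hunks shown here; the sequence it performs is the usual diffusers FLUX latent-to-image path (unpack the packed latents, undo the VAE scaling/shift, decode, postprocess). A rough standalone sketch of that sequence, assuming a diffusers `AutoencoderKL` and taking the pipeline's `vae_scale_factor` as an argument (the helper name is illustrative, and `_unpack_latents` is the same private pipeline helper the diff calls on an instance):

import torch
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from PIL import Image


@torch.no_grad()
def decode_flux_latents(
    vae: AutoencoderKL, latents: torch.Tensor, height: int, width: int, vae_scale_factor: int
) -> Image.Image:
    # Packed FLUX latents are reshaped back into (B, C, h, w) before decoding.
    latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)

    # Undo the scaling/shift that was applied when the latents were produced.
    latents = latents / vae.config.scaling_factor + vae.config.shift_factor
    latents = latents.to(dtype=vae.dtype)

    image = vae.decode(latents, return_dict=False)[0]
    return VaeImageProcessor(vae_scale_factor=vae_scale_factor).postprocess(image, output_type="pil")[0]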


@@ -65,6 +65,10 @@ class TransformerField(BaseModel):
     transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
     scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")

+
+class T5EncoderField(BaseModel):
+    tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
+    text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
+

 class VAEField(BaseModel):
     vae: ModelIdentifierField = Field(description="Info to load vae submodel")
@@ -133,8 +137,8 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
     """Flux base model loader output"""

     transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
-    clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
-    clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
+    clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
+    t5Encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5Encoder, title="T5 Encoder")
     vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@@ -166,7 +170,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
         return FluxModelLoaderOutput(
             transformer=TransformerField(transformer=transformer, scheduler=scheduler),
             clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
-            clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
+            t5Encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=text_encoder2),
             vae=VAEField(vae=vae),
         )


@@ -78,7 +78,12 @@ class GenericDiffusersLoader(ModelLoader):
     # TO DO: Add exception handling
     def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin:  # fix with correct type
-        if module in ["diffusers", "transformers"]:
+        if module in [
+            "diffusers",
+            "transformers",
+            "invokeai.backend.quantization.fast_quantized_transformers_model",
+            "invokeai.backend.quantization.fast_quantized_diffusion_model",
+        ]:
             res_type = sys.modules[module]
         else:
             res_type = sys.modules["diffusers"].pipelines
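Note: widening this whitelist is what lets a model config point at the InvokeAI quantization wrappers; the class itself is still resolved by name from the chosen module a few lines below the hunk. A minimal sketch of that resolution as a standalone helper (not the loader's exact code), assuming the target module has already been imported and is therefore present in `sys.modules`:

import sys

import diffusers  # needed for the fallback branch below


def resolve_model_class(module: str, class_name: str):
    # Mirror the whitelist from the diff: trusted modules are used directly,
    # anything else falls back to diffusers.pipelines.
    if module in [
        "diffusers",
        "transformers",
        "invokeai.backend.quantization.fast_quantized_transformers_model",
        "invokeai.backend.quantization.fast_quantized_diffusion_model",
    ]:
        res_type = sys.modules[module]
    else:
        res_type = diffusers.pipelines
    return getattr(res_type, class_name)


# e.g. resolve_model_class("diffusers", "AutoencoderKL") -> the diffusers AutoencoderKL class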


@@ -9,7 +9,7 @@ from typing import Optional
 import torch
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SchedulerMixin
-from transformers import CLIPTokenizer
+from transformers import CLIPTokenizer, T5TokenizerFast

 from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
 from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
@@ -50,6 +50,13 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
         ),
     ):
         return model.calc_size()
+    elif isinstance(
+        model,
+        (
+            T5TokenizerFast,
+        ),
+    ):
+        return len(model)
     else:
         # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
         # supported model types.
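Note: for the tokenizer branch added above, `len(model)` on a Hugging Face tokenizer returns its full vocabulary size (base vocab plus added tokens), so the value recorded here is a token count standing in for a byte size; it mainly keeps the size accounting from falling through to the warning path for T5 tokenizers. For illustration (the checkpoint name is an example):

from transformers import T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")
print(len(tokenizer))  # full vocabulary size, on the order of 32k tokens for T5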


@@ -12,15 +12,17 @@ from diffusers.utils import (
 )
 from optimum.quanto.models import QuantizedDiffusersModel
 from optimum.quanto.models.shared_dict import ShardedStateDict
+from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel

 from invokeai.backend.requantize import requantize


 class FastQuantizedDiffusersModel(QuantizedDiffusersModel):
     @classmethod
-    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
+    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike], base_class = FluxTransformer2DModel, **kwargs):
         """We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
-        if cls.base_class is None:
+        base_class = base_class or cls.base_class
+        if base_class is None:
             raise ValueError("The `base_class` attribute needs to be configured.")

         if not is_accelerate_available():
@@ -43,16 +45,16 @@ class FastQuantizedDiffusersModel(QuantizedDiffusersModel):
             with open(model_config_path, "r", encoding="utf-8") as f:
                 original_model_cls_name = json.load(f)["_class_name"]
-            configured_cls_name = cls.base_class.__name__
+            configured_cls_name = base_class.__name__
             if configured_cls_name != original_model_cls_name:
                 raise ValueError(
                     f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
                 )

             # Create an empty model
-            config = cls.base_class.load_config(model_name_or_path)
+            config = base_class.load_config(model_name_or_path)
             with init_empty_weights():
-                model = cls.base_class.from_config(config)
+                model = base_class.from_config(config)

             # Look for the index of a sharded checkpoint
             checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
@@ -72,6 +74,6 @@ class FastQuantizedDiffusersModel(QuantizedDiffusersModel):
                 # Requantize and load quantized weights from state_dict
                 requantize(model, state_dict=state_dict, quantization_map=qmap)
                 model.eval()
-                return cls(model)
+                return cls(model)._wrapped
         else:
             raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
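Note: two behavioural changes are bundled in this file: `base_class` becomes a call-time argument (defaulting to `FluxTransformer2DModel`) instead of a class attribute that subclasses must define, and `from_pretrained()` now returns the underlying diffusers model via `._wrapped` rather than the quantized wrapper. A sketch of what a caller sees, assuming a directory previously written by `save_pretrained()` (the path is illustrative; the module path is taken from the loader whitelist earlier in this commit):

from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel

from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel

model = FastQuantizedDiffusersModel.from_pretrained(
    "/path/to/quantized/flux-transformer",  # illustrative path to a saved quantized checkpoint
    base_class=FluxTransformer2DModel,
)
# The wrapper is unwrapped before returning, so downstream code (the model loader
# and the invocation's isinstance checks) sees a plain FluxTransformer2DModel.
assert isinstance(model, FluxTransformer2DModel)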


@@ -1,5 +1,6 @@
 import json
 import os
+import torch
 from typing import Union

 from optimum.quanto.models import QuantizedTransformersModel
@@ -7,15 +8,17 @@ from optimum.quanto.models.shared_dict import ShardedStateDict
 from transformers import AutoConfig
 from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available
+from transformers.models.auto import AutoModelForTextEncoding

 from invokeai.backend.requantize import requantize


 class FastQuantizedTransformersModel(QuantizedTransformersModel):
     @classmethod
-    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
+    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike], auto_class = AutoModelForTextEncoding, **kwargs):
         """We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
-        if cls.auto_class is None:
+        auto_class = auto_class or cls.auto_class
+        if auto_class is None:
             raise ValueError(
                 "Quantized models cannot be reloaded using {cls}: use a specialized quantized class such as QuantizedModelForCausalLM instead."
             )
@@ -33,7 +36,7 @@ class FastQuantizedTransformersModel(QuantizedTransformersModel):
             # Create an empty model
             config = AutoConfig.from_pretrained(model_name_or_path)
             with init_empty_weights():
-                model = cls.auto_class.from_config(config)
+                model = auto_class.from_config(config)

             # Look for the index of a sharded checkpoint
             checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
             if os.path.exists(checkpoint_file):
@@ -56,6 +59,6 @@ class FastQuantizedTransformersModel(QuantizedTransformersModel):
                     model.tie_weights()
                 # Set model in evaluation mode as it is done in transformers
                 model.eval()
-                return cls(model)
+                return cls(model)._wrapped
         else:
             raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
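Note: the transformers-side wrapper gets the symmetric treatment: `auto_class` is now a parameter (defaulting to `AutoModelForTextEncoding`) and the unwrapped model is returned. A sketch of the expected call, mirroring the example above (path illustrative; module path taken from the loader whitelist):

from transformers import T5EncoderModel
from transformers.models.auto import AutoModelForTextEncoding

from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel

text_encoder = FastQuantizedTransformersModel.from_pretrained(
    "/path/to/quantized/t5-text-encoder",  # illustrative path to a saved quantized checkpoint
    auto_class=AutoModelForTextEncoding,
)
# For a T5 checkpoint, AutoModelForTextEncoding resolves to T5EncoderModel, and the
# unwrapped instance is what gets returned.
assert isinstance(text_encoder, T5EncoderModel)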


@@ -5697,15 +5697,15 @@ export type components = {
       */
      transformer: components["schemas"]["TransformerField"];
      /**
-      * CLIP 1
+      * CLIP
       * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
       */
      clip: components["schemas"]["CLIPField"];
      /**
-      * CLIP 2
-      * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
+      * T5 Encoder
+      * @description T5 tokenizer and text encoder
       */
-     clip2: components["schemas"]["CLIPField"];
+     t5Encoder: components["schemas"]["T5EncoderField"];
      /**
       * VAE
       * @description VAE
@@ -5739,19 +5739,17 @@ export type components = {
       */
      use_cache?: boolean;
      /**
-      * Model
-      * @description The FLUX model to use for text-to-image generation.
+      * CLIP
+      * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
       * @default null
-      * @constant
-      * @enum {string}
       */
-     model?: "flux-schnell";
+     clip?: components["schemas"]["CLIPField"];
      /**
-      * Use 8Bit
-      * @description Whether to quantize the transformer model to 8-bit precision.
-      * @default false
+      * T5EncoderField
+      * @description T5 tokenizer and text encoder
+      * @default null
       */
-     use_8bit?: boolean;
+     t5Encoder?: components["schemas"]["T5EncoderField"];
      /**
       * Positive Prompt
       * @description Positive prompt for text-to-image generation.
@@ -5799,31 +5797,16 @@ export type components = {
       */
      use_cache?: boolean;
      /**
-      * @description The Flux model
+      * Transformer
+      * @description UNet (scheduler, LoRAs)
       * @default null
       */
-     flux_model?: components["schemas"]["ModelIdentifierField"];
+     transformer?: components["schemas"]["TransformerField"];
      /**
-      * Model
-      * @description The FLUX model to use for text-to-image generation.
+      * @description VAE
       * @default null
-      * @constant
-      * @enum {string}
       */
-     model?: "flux-schnell";
-     /**
-      * Quantization Type
-      * @description The type of quantization to use for the transformer model.
-      * @default raw
-      * @enum {string}
-      */
-     quantization_type?: "raw" | "NF4" | "llm_int8";
-     /**
-      * Use 8Bit
-      * @description Whether to quantize the transformer model to 8-bit precision.
-      * @default false
-      */
-     use_8bit?: boolean;
+     vae?: components["schemas"]["VAEField"];
      /**
       * @description Positive conditioning tensor
       * @default null
@@ -14268,6 +14251,13 @@ export type components = {
       */
      type: "t2i_adapter_output";
    };
+    /** T5EncoderField */
+    T5EncoderField: {
+      /** @description Info to load tokenizer submodel */
+      tokenizer: components["schemas"]["ModelIdentifierField"];
+      /** @description Info to load text_encoder submodel */
+      text_encoder: components["schemas"]["ModelIdentifierField"];
+    };
    /** TBLR */
    TBLR: {
      /** Top */