mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Manage quantization of models within the loader

This commit is contained in:
parent 1d8545a76c
commit 56fda669fd
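At a high level, the diff below replaces the FLUX invocations' inline download-and-quantize logic with submodel fields (CLIPField, T5EncoderField, TransformerField, VAEField) that the model manager resolves via context.models.load(...). A minimal sketch of that shape, using simplified stand-in classes rather than the exact InvokeAI API:

from pydantic import BaseModel, Field


class ModelIdentifierField(BaseModel):
    """Stand-in: an opaque reference that the model manager resolves to a loaded submodel."""

    key: str


class T5EncoderField(BaseModel):
    tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
    text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")


class TransformerField(BaseModel):
    transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
    scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")

The model loader node populates these fields; the text-encoder and text-to-image invocations then call context.models.load(...) on each identifier instead of downloading and quantizing checkpoints themselves.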
@@ -126,6 +126,7 @@ class FieldDescriptions:
    negative_cond = "Negative conditioning tensor"
    noise = "Noise tensor"
    clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
    t5Encoder = "T5 tokenizer and text encoder"
    unet = "UNet (scheduler, LoRAs)"
    transformer = "Transformer"
    vae = "VAE"
@@ -6,8 +6,10 @@ from optimum.quanto import qfloat8
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import InputField
from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding, TFluxModelKeys
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.fields import InputField, FieldDescriptions, Input
from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
@@ -22,9 +24,15 @@ from invokeai.backend.util.devices import TorchDevice
    version="1.0.0",
)
class FluxTextEncoderInvocation(BaseInvocation):
    model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
    use_8bit: bool = InputField(
        default=False, description="Whether to quantize the transformer model to 8-bit precision."
    clip: CLIPField = InputField(
        title="CLIP",
        description=FieldDescriptions.clip,
        input=Input.Connection,
    )
    t5Encoder: T5EncoderField = InputField(
        title="T5EncoderField",
        description=FieldDescriptions.t5Encoder,
        input=Input.Connection,
    )
    positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
@@ -32,47 +40,43 @@ class FluxTextEncoderInvocation(BaseInvocation):
    # compatible with other ConditioningOutputs.
    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ConditioningOutput:
        model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])

        t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path)
        t5_embeddings, clip_embeddings = self._encode_prompt(context)
        conditioning_data = ConditioningFieldData(
            conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
        )

        conditioning_name = context.conditioning.save(conditioning_data)
        return ConditioningOutput.build(conditioning_name)

    def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
        # TODO: Determine the T5 max sequence length based on the model.
        # if self.model == "flux-schnell":
        max_seq_len = 256
        # # elif self.model == "flux-dev":
        # # max_seq_len = 512
        # else:
        # raise ValueError(f"Unknown model: {self.model}")

    def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
        # Determine the T5 max sequence length based on the model.
        if self.model == "flux-schnell":
            max_seq_len = 256
        # elif self.model == "flux-dev":
        #     max_seq_len = 512
        else:
            raise ValueError(f"Unknown model: {self.model}")
        # Load CLIP.
        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
        clip_text_encoder_info = context.models.load(self.clip.text_encoder)

        # Load the CLIP tokenizer.
        clip_tokenizer_path = flux_model_dir / "tokenizer"
        clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
        assert isinstance(clip_tokenizer, CLIPTokenizer)
        # Load T5.
        t5_tokenizer_info = context.models.load(self.t5Encoder.tokenizer)
        t5_text_encoder_info = context.models.load(self.t5Encoder.text_encoder)

        # Load the T5 tokenizer.
        t5_tokenizer_path = flux_model_dir / "tokenizer_2"
        t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
        assert isinstance(t5_tokenizer, T5TokenizerFast)

        clip_text_encoder_path = flux_model_dir / "text_encoder"
        t5_text_encoder_path = flux_model_dir / "text_encoder_2"
        with (
            context.models.load_local_model(
                model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
            ) as clip_text_encoder,
            context.models.load_local_model(
                model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
            ) as t5_text_encoder,
            clip_text_encoder_info as clip_text_encoder,
            t5_text_encoder_info as t5_text_encoder,
            clip_tokenizer_info as clip_tokenizer,
            t5_tokenizer_info as t5_tokenizer,
        ):
            assert isinstance(clip_text_encoder, CLIPTextModel)
            assert isinstance(t5_text_encoder, T5EncoderModel)
            assert isinstance(clip_tokenizer, CLIPTokenizer)
            assert isinstance(t5_tokenizer, T5TokenizerFast)

            pipeline = FluxPipeline(
                scheduler=None,
                vae=None,
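The new _encode_prompt enters four loader contexts in a single parenthesized with statement. For reference, a rough equivalent using contextlib.ExitStack, which can read more clearly when the set of submodels is built up dynamically (the context and field objects are assumed to follow the API shape shown above):

from contextlib import ExitStack


def load_text_encoders(context, clip_field, t5_field):
    """Sketch: hold several model-manager loads open at once."""
    with ExitStack() as stack:
        clip_tokenizer = stack.enter_context(context.models.load(clip_field.tokenizer))
        clip_text_encoder = stack.enter_context(context.models.load(clip_field.text_encoder))
        t5_tokenizer = stack.enter_context(context.models.load(t5_field.tokenizer))
        t5_text_encoder = stack.enter_context(context.models.load(t5_field.text_encoder))
        # All four models stay locked in the cache until the stack exits;
        # tokenization and prompt encoding would happen here.
        ...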
@@ -85,7 +89,7 @@ class FluxTextEncoderInvocation(BaseInvocation):

            # prompt_embeds: T5 embeddings
            # pooled_prompt_embeds: CLIP embeddings
            prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
            prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt(
                prompt=self.positive_prompt,
                prompt_2=self.positive_prompt,
                device=TorchDevice.choose_torch_device(),
@@ -95,41 +99,3 @@ class FluxTextEncoderInvocation(BaseInvocation):
        assert isinstance(prompt_embeds, torch.Tensor)
        assert isinstance(pooled_prompt_embeds, torch.Tensor)
        return prompt_embeds, pooled_prompt_embeds

    @staticmethod
    def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
        model = CLIPTextModel.from_pretrained(path, local_files_only=True)
        assert isinstance(model, CLIPTextModel)
        return model

    def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
        if self.use_8bit:
            model_8bit_path = path / "quantized"
            if model_8bit_path.exists():
                # The quantized model exists, load it.
                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
                # something that we should be able to make much faster.
                q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)

                # Access the underlying wrapped model.
                # We access the wrapped model, even though it is private, because it simplifies the type checking by
                # always returning a T5EncoderModel from this function.
                model = q_model._wrapped
            else:
                # The quantized model does not exist yet, quantize and save it.
                # TODO(ryand): dtype?
                model = T5EncoderModel.from_pretrained(path, local_files_only=True)
                assert isinstance(model, T5EncoderModel)

                q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)

                model_8bit_path.mkdir(parents=True, exist_ok=True)
                q_model.save_pretrained(model_8bit_path)

                # (See earlier comment about accessing the wrapped model.)
                model = q_model._wrapped
        else:
            model = T5EncoderModel.from_pretrained(path, local_files_only=True)

        assert isinstance(model, T5EncoderModel)
        return model
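The removed _load_flux_text_encoder_2 is the quantize-on-first-use pattern that this commit pushes down into the loader: quantize with optimum-quanto, cache the result on disk, and reload it on later runs. For reference, a stripped-down sketch of the underlying optimum-quanto calls that the wrapper class builds on (the path handling here is illustrative, not InvokeAI's):

from optimum.quanto import freeze, qfloat8, quantize
from transformers import T5EncoderModel


def quantize_t5_to_qfloat8(path: str) -> T5EncoderModel:
    """Sketch: 8-bit float quantization of a T5 text encoder with optimum-quanto."""
    model = T5EncoderModel.from_pretrained(path, local_files_only=True)
    quantize(model, weights=qfloat8)  # swap Linear weights for qfloat8 tensors
    freeze(model)  # materialize the quantized weights so they can be saved and reused
    return model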
@@ -6,7 +6,7 @@ import accelerate
import torch
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.model import TransformerField, VAEField
from optimum.quanto import qfloat8
from PIL import Image
from safetensors.torch import load_file
@@ -52,17 +52,14 @@ class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Text-to-image generation using a FLUX model."""

    flux_model: ModelIdentifierField = InputField(
        description="The Flux model",
        input=Input.Any,
        ui_type=UIType.FluxMainModel
    transformer: TransformerField = InputField(
        description=FieldDescriptions.unet,
        input=Input.Connection,
        title="Transformer",
    )
    model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
    quantization_type: Literal["raw", "NF4", "llm_int8"] = InputField(
        default="raw", description="The type of quantization to use for the transformer model."
    )
    use_8bit: bool = InputField(
        default=False, description="Whether to quantize the transformer model to 8-bit precision."
    vae: VAEField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )
    positive_text_conditioning: ConditioningField = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
@@ -78,13 +75,6 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        # model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
        flux_transformer_path = context.models.download_and_cache_model(
            "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors"
        )
        flux_ae_path = context.models.download_and_cache_model(
            "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/ae.safetensors"
        )

        # Load the conditioning data.
        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
@@ -92,56 +82,31 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
        flux_conditioning = cond_data.conditionings[0]
        assert isinstance(flux_conditioning, FLUXConditioningInfo)

        latents = self._run_diffusion(
            context, flux_transformer_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds
        )
        image = self._run_vae_decoding(context, flux_ae_path, latents)
        latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
        image = self._run_vae_decoding(context, latents)
        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)

    def _run_diffusion(
        self,
        context: InvocationContext,
        flux_transformer_path: Path,
        clip_embeddings: torch.Tensor,
        t5_embeddings: torch.Tensor,
    ):
        inference_dtype = TorchDevice.choose_torch_dtype()

        # Prepare input noise.
        # TODO(ryand): Does the seed behave the same on different devices? Should we re-implement this to always use a
        # CPU RNG?
        x = get_noise(
            num_samples=1,
            height=self.height,
            width=self.width,
            device=TorchDevice.choose_torch_device(),
            dtype=inference_dtype,
            seed=self.seed,
        )

        img, img_ids = self._prepare_latent_img_patches(x)

        # HACK(ryand): Find a better way to determine if this is a schnell model or not.
        is_schnell = "shnell" in str(flux_transformer_path)
        timesteps = get_schedule(
            num_steps=self.num_steps,
            image_seq_len=img.shape[1],
            shift=not is_schnell,
        )

        bs, t5_seq_len, _ = t5_embeddings.shape
        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
        scheduler_info = context.models.load(self.transformer.scheduler)
        transformer_info = context.models.load(self.transformer.transformer)

        # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
        # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
        # if the cache is not empty.
        context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
        # context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)

        with context.models.load_local_model(
            model_path=flux_transformer_path, loader=self._load_flux_transformer
        ) as transformer:
            assert isinstance(transformer, Flux)
        with (
            transformer_info as transformer,
            scheduler_info as scheduler
        ):
            assert isinstance(transformer, FluxTransformer2DModel)
            assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)

            x = denoise(
                model=transformer,
@@ -185,75 +150,25 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    def _run_vae_decoding(
        self,
        context: InvocationContext,
        flux_ae_path: Path,
        latents: torch.Tensor,
    ) -> Image.Image:
        with context.models.load_local_model(model_path=flux_ae_path, loader=self._load_flux_vae) as vae:
            assert isinstance(vae, AutoEncoder)
            # TODO(ryand): Test that this works with both float16 and bfloat16.
            with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
                img = vae.decode(latents)
        vae_info = context.models.load(self.vae.vae)
        with vae_info as vae:
            assert isinstance(vae, AutoencoderKL)

            img.clamp(-1, 1)
            img = rearrange(img[0], "c h w -> h w c")
            img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())

            return img_pil
            latents = flux_pipeline_with_vae._unpack_latents(
                latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
            )
            latents = (
                latents / flux_pipeline_with_vae.vae.config.scaling_factor
            ) + flux_pipeline_with_vae.vae.config.shift_factor
            latents = latents.to(dtype=vae.dtype)
            image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
            image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]

    def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
        inference_dtype = TorchDevice.choose_torch_dtype()
        if self.quantization_type == "raw":
            # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
            params = flux_configs["flux-schnell"].params

            # Initialize the model on the "meta" device.
            with accelerate.init_empty_weights():
                model = Flux(params).to(inference_dtype)

            state_dict = load_file(path)
            # TODO(ryand): Cast the state_dict to the appropriate dtype?
            model.load_state_dict(state_dict, strict=True, assign=True)
        elif self.quantization_type == "NF4":
            model_path = path.parent / "bnb_nf4.safetensors"

            # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
            params = flux_configs["flux-schnell"].params
            # Initialize the model on the "meta" device.
            with accelerate.init_empty_weights():
                model = Flux(params)
                model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)

            # TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
            # this on GPUs without bfloat16 support.
            state_dict = load_file(model_path)
            model.load_state_dict(state_dict, strict=True, assign=True)

        elif self.quantization_type == "llm_int8":
            raise NotImplementedError("LLM int8 quantization is not yet supported.")
            # model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
            # with accelerate.init_empty_weights():
            # empty_model = FluxTransformer2DModel.from_config(model_config)
            # assert isinstance(empty_model, FluxTransformer2DModel)
            # model_int8_path = path / "bnb_llm_int8"
            # assert model_int8_path.exists()
            # with accelerate.init_empty_weights():
            # model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set())

            # sd = load_file(model_int8_path / "model.safetensors")
            # model.load_state_dict(sd, strict=True, assign=True)
        else:
            raise ValueError(f"Unsupported quantization type: {self.quantization_type}")

        assert isinstance(model, FluxTransformer2DModel)
        return model

    @staticmethod
    def _load_flux_vae(path: Path) -> AutoEncoder:
        # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
        ae_params = flux_configs["flux1-schnell"].ae_params
        with accelerate.init_empty_weights():
            ae = AutoEncoder(ae_params)

        state_dict = load_file(path)
        ae.load_state_dict(state_dict, strict=True, assign=True)
        return ae
            assert isinstance(image, Image.Image)
            return image
@@ -65,6 +65,10 @@ class TransformerField(BaseModel):
    transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
    scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")

class T5EncoderField(BaseModel):
    tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
    text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")


class VAEField(BaseModel):
    vae: ModelIdentifierField = Field(description="Info to load vae submodel")
@@ -133,8 +137,8 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
    """Flux base model loader output"""

    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
    clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
    clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
    clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
    t5Encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5Encoder, title="T5 Encoder")
    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@@ -166,7 +170,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
        return FluxModelLoaderOutput(
            transformer=TransformerField(transformer=transformer, scheduler=scheduler),
            clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
            clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
            t5Encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=text_encoder2),
            vae=VAEField(vae=vae),
        )
@@ -78,7 +78,12 @@ class GenericDiffusersLoader(ModelLoader):

    # TO DO: Add exception handling
    def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin:  # fix with correct type
        if module in ["diffusers", "transformers"]:
        if module in [
            "diffusers",
            "transformers",
            "invokeai.backend.quantization.fast_quantized_transformers_model",
            "invokeai.backend.quantization.fast_quantized_diffusion_model",
        ]:
            res_type = sys.modules[module]
        else:
            res_type = sys.modules["diffusers"].pipelines
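The lookup above resolves a class by name from a module that is already imported; the change only adds the two InvokeAI quantization modules to the allow-list. A toy version of the same resolution (the final getattr step is assumed, since the hunk ends before it):

import sys


def resolve_model_class(module: str, class_name: str):
    """Sketch: map a (module, class_name) pair from a model config to a Python class."""
    allowed = {
        "diffusers",
        "transformers",
        "invokeai.backend.quantization.fast_quantized_transformers_model",
        "invokeai.backend.quantization.fast_quantized_diffusion_model",
    }
    if module in allowed:
        namespace = sys.modules[module]  # the module must already be imported
    else:
        namespace = sys.modules["diffusers"].pipelines
    return getattr(namespace, class_name)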
@@ -9,7 +9,7 @@ from typing import Optional
import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from transformers import CLIPTokenizer
from transformers import CLIPTokenizer, T5TokenizerFast

from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
@@ -50,6 +50,13 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
        ),
    ):
        return model.calc_size()
    elif isinstance(
        model,
        (
            T5TokenizerFast,
        ),
    ):
        return len(model)
    else:
        # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
        # supported model types.
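Tokenizers carry no weight tensors, so the new branch falls back to len(model), which for a Hugging Face fast tokenizer is its vocabulary size. A quick illustration (the checkpoint name is only an example):

from transformers import T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("google/flan-t5-small")
# len() of a fast tokenizer is its vocabulary size (including added tokens);
# the size estimator above uses it as a cheap stand-in for an in-memory size.
print(len(tokenizer))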
@@ -12,15 +12,17 @@ from diffusers.utils import (
)
from optimum.quanto.models import QuantizedDiffusersModel
from optimum.quanto.models.shared_dict import ShardedStateDict
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel

from invokeai.backend.requantize import requantize


class FastQuantizedDiffusersModel(QuantizedDiffusersModel):
    @classmethod
    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike], base_class=FluxTransformer2DModel, **kwargs):
        """We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
        if cls.base_class is None:
        base_class = base_class or cls.base_class
        if base_class is None:
            raise ValueError("The `base_class` attribute needs to be configured.")

        if not is_accelerate_available():
@@ -43,16 +45,16 @@ class FastQuantizedDiffusersModel(QuantizedDiffusersModel):

            with open(model_config_path, "r", encoding="utf-8") as f:
                original_model_cls_name = json.load(f)["_class_name"]
            configured_cls_name = cls.base_class.__name__
            configured_cls_name = base_class.__name__
            if configured_cls_name != original_model_cls_name:
                raise ValueError(
                    f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
                )

            # Create an empty model
            config = cls.base_class.load_config(model_name_or_path)
            config = base_class.load_config(model_name_or_path)
            with init_empty_weights():
                model = cls.base_class.from_config(config)
                model = base_class.from_config(config)

            # Look for the index of a sharded checkpoint
            checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
@@ -72,6 +74,6 @@ class FastQuantizedDiffusersModel(QuantizedDiffusersModel):
            # Requantize and load quantized weights from state_dict
            requantize(model, state_dict=state_dict, quantization_map=qmap)
            model.eval()
            return cls(model)
            return cls(model)._wrapped
        else:
            raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
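With base_class now a parameter (and the unwrapped module returned), the same helper can deserialize quantized checkpoints for different diffusers classes without defining a subclass per model. A usage sketch, assuming the class lives in the invokeai.backend.quantization module referenced earlier in the diff and that a quantized checkpoint directory already exists at the hypothetical path below:

from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel

from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel

# Hypothetical directory produced by save_pretrained() on the quantized model.
transformer = FastQuantizedDiffusersModel.from_pretrained(
    "/models/flux/transformer/quantized",
    base_class=FluxTransformer2DModel,
)
# The override returns cls(model)._wrapped, so the caller receives a plain
# FluxTransformer2DModel rather than the quanto wrapper object.
assert isinstance(transformer, FluxTransformer2DModel)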
@@ -1,5 +1,6 @@
import json
import os
import torch
from typing import Union

from optimum.quanto.models import QuantizedTransformersModel
@@ -7,15 +8,17 @@ from optimum.quanto.models.shared_dict import ShardedStateDict
from transformers import AutoConfig
from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available
from transformers.models.auto import AutoModelForTextEncoding

from invokeai.backend.requantize import requantize


class FastQuantizedTransformersModel(QuantizedTransformersModel):
    @classmethod
    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike], auto_class=AutoModelForTextEncoding, **kwargs):
        """We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
        if cls.auto_class is None:
        auto_class = auto_class or cls.auto_class
        if auto_class is None:
            raise ValueError(
                "Quantized models cannot be reloaded using {cls}: use a specialized quantized class such as QuantizedModelForCausalLM instead."
            )
@@ -33,7 +36,7 @@ class FastQuantizedTransformersModel(QuantizedTransformersModel):
            # Create an empty model
            config = AutoConfig.from_pretrained(model_name_or_path)
            with init_empty_weights():
                model = cls.auto_class.from_config(config)
                model = auto_class.from_config(config)
            # Look for the index of a sharded checkpoint
            checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
            if os.path.exists(checkpoint_file):
@@ -56,6 +59,6 @@ class FastQuantizedTransformersModel(QuantizedTransformersModel):
            model.tie_weights()
            # Set model in evaluation mode as it is done in transformers
            model.eval()
            return cls(model)
            return cls(model)._wrapped
        else:
            raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
@@ -5697,15 +5697,15 @@ export type components = {
       */
      transformer: components["schemas"]["TransformerField"];
      /**
       * CLIP 1
       * CLIP
       * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
       */
      clip: components["schemas"]["CLIPField"];
      /**
       * CLIP 2
       * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
       * T5 Encoder
       * @description T5 tokenizer and text encoder
       */
      clip2: components["schemas"]["CLIPField"];
      t5Encoder: components["schemas"]["T5EncoderField"];
      /**
       * VAE
       * @description VAE
@@ -5739,19 +5739,17 @@ export type components = {
       */
      use_cache?: boolean;
      /**
       * Model
       * @description The FLUX model to use for text-to-image generation.
       * CLIP
       * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
       * @default null
       * @constant
       * @enum {string}
       */
      model?: "flux-schnell";
      clip?: components["schemas"]["CLIPField"];
      /**
       * Use 8Bit
       * @description Whether to quantize the transformer model to 8-bit precision.
       * @default false
       * T5EncoderField
       * @description T5 tokenizer and text encoder
       * @default null
       */
      use_8bit?: boolean;
      t5Encoder?: components["schemas"]["T5EncoderField"];
      /**
       * Positive Prompt
       * @description Positive prompt for text-to-image generation.
@@ -5799,31 +5797,16 @@ export type components = {
       */
      use_cache?: boolean;
      /**
       * @description The Flux model
       * Transformer
       * @description UNet (scheduler, LoRAs)
       * @default null
       */
      flux_model?: components["schemas"]["ModelIdentifierField"];
      transformer?: components["schemas"]["TransformerField"];
      /**
       * Model
       * @description The FLUX model to use for text-to-image generation.
       * @description VAE
       * @default null
       * @constant
       * @enum {string}
       */
      model?: "flux-schnell";
      /**
       * Quantization Type
       * @description The type of quantization to use for the transformer model.
       * @default raw
       * @enum {string}
       */
      quantization_type?: "raw" | "NF4" | "llm_int8";
      /**
       * Use 8Bit
       * @description Whether to quantize the transformer model to 8-bit precision.
       * @default false
       */
      use_8bit?: boolean;
      vae?: components["schemas"]["VAEField"];
      /**
       * @description Positive conditioning tensor
       * @default null
@@ -14268,6 +14251,13 @@ export type components = {
       */
      type: "t2i_adapter_output";
    };
    /** T5EncoderField */
    T5EncoderField: {
      /** @description Info to load tokenizer submodel */
      tokenizer: components["schemas"]["ModelIdentifierField"];
      /** @description Info to load text_encoder submodel */
      text_encoder: components["schemas"]["ModelIdentifierField"];
    };
    /** TBLR */
    TBLR: {
      /** Top */