From 56fda669fd6242c14828ce1004715de4f4532fb1 Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Mon, 12 Aug 2024 18:01:42 -0400
Subject: [PATCH] Manage quantization of models within the loader

---
 invokeai/app/invocations/fields.py            |   1 +
 invokeai/app/invocations/flux_text_encoder.py | 108 +++++--------
 .../app/invocations/flux_text_to_image.py     | 151 ++++--------------
 invokeai/app/invocations/model.py             |  10 +-
 .../load/model_loaders/generic_diffusers.py   |   7 +-
 .../backend/model_manager/load/model_util.py  |   9 +-
 .../fast_quantized_diffusion_model.py         |  14 +-
 .../fast_quantized_transformers_model.py      |  11 +-
 .../frontend/web/src/services/api/schema.ts   |  56 +++----
 9 files changed, 130 insertions(+), 237 deletions(-)

diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py
index 91dfcb51a7..ba2c75aa13 100644
--- a/invokeai/app/invocations/fields.py
+++ b/invokeai/app/invocations/fields.py
@@ -126,6 +126,7 @@ class FieldDescriptions:
     negative_cond = "Negative conditioning tensor"
     noise = "Noise tensor"
     clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
+    t5Encoder = "T5 tokenizer and text encoder"
     unet = "UNet (scheduler, LoRAs)"
     transformer = "Transformer"
     vae = "VAE"
diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py
index 582ae6fabc..ce173a49a1 100644
--- a/invokeai/app/invocations/flux_text_encoder.py
+++ b/invokeai/app/invocations/flux_text_encoder.py
@@ -6,8 +6,9 @@ from optimum.quanto import qfloat8
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.fields import InputField
-from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding, TFluxModelKeys
+from invokeai.app.invocations.fields import InputField, FieldDescriptions, Input
+from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding
+from invokeai.app.invocations.model import CLIPField, T5EncoderField
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
@@ -22,9 +23,15 @@ from invokeai.backend.util.devices import TorchDevice
     version="1.0.0",
 )
 class FluxTextEncoderInvocation(BaseInvocation):
-    model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
-    use_8bit: bool = InputField(
-        default=False, description="Whether to quantize the transformer model to 8-bit precision."
+    clip: CLIPField = InputField(
+        title="CLIP",
+        description=FieldDescriptions.clip,
+        input=Input.Connection,
+    )
+    t5Encoder: T5EncoderField = InputField(
+        title="T5EncoderField",
+        description=FieldDescriptions.t5Encoder,
+        input=Input.Connection,
     )
     positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
 
@@ -32,47 +39,43 @@ class FluxTextEncoderInvocation(BaseInvocation):
     # compatible with other ConditioningOutputs.
@torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: - model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model]) - t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path) + t5_embeddings, clip_embeddings = self._encode_prompt(context) conditioning_data = ConditioningFieldData( conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)] ) conditioning_name = context.conditioning.save(conditioning_data) return ConditioningOutput.build(conditioning_name) + + def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]: + # TODO: Determine the T5 max sequence length based on the model. + # if self.model == "flux-schnell": + max_seq_len = 256 + # # elif self.model == "flux-dev": + # # max_seq_len = 512 + # else: + # raise ValueError(f"Unknown model: {self.model}") - def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]: - # Determine the T5 max sequence length based on the model. - if self.model == "flux-schnell": - max_seq_len = 256 - # elif self.model == "flux-dev": - # max_seq_len = 512 - else: - raise ValueError(f"Unknown model: {self.model}") + # Load CLIP. + clip_tokenizer_info = context.models.load(self.clip.tokenizer) + clip_text_encoder_info = context.models.load(self.clip.text_encoder) - # Load the CLIP tokenizer. - clip_tokenizer_path = flux_model_dir / "tokenizer" - clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True) - assert isinstance(clip_tokenizer, CLIPTokenizer) + # Load T5. + t5_tokenizer_info = context.models.load(self.t5Encoder.tokenizer) + t5_text_encoder_info = context.models.load(self.t5Encoder.text_encoder) - # Load the T5 tokenizer. 
- t5_tokenizer_path = flux_model_dir / "tokenizer_2" - t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True) - assert isinstance(t5_tokenizer, T5TokenizerFast) - - clip_text_encoder_path = flux_model_dir / "text_encoder" - t5_text_encoder_path = flux_model_dir / "text_encoder_2" with ( - context.models.load_local_model( - model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder - ) as clip_text_encoder, - context.models.load_local_model( - model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2 - ) as t5_text_encoder, + clip_text_encoder_info as clip_text_encoder, + t5_text_encoder_info as t5_text_encoder, + clip_tokenizer_info as clip_tokenizer, + t5_tokenizer_info as t5_tokenizer, ): assert isinstance(clip_text_encoder, CLIPTextModel) assert isinstance(t5_text_encoder, T5EncoderModel) + assert isinstance(clip_tokenizer, CLIPTokenizer) + assert isinstance(t5_tokenizer, T5TokenizerFast) + pipeline = FluxPipeline( scheduler=None, vae=None, @@ -85,7 +89,7 @@ class FluxTextEncoderInvocation(BaseInvocation): # prompt_embeds: T5 embeddings # pooled_prompt_embeds: CLIP embeddings - prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt( + prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt( prompt=self.positive_prompt, prompt_2=self.positive_prompt, device=TorchDevice.choose_torch_device(), @@ -95,41 +99,3 @@ class FluxTextEncoderInvocation(BaseInvocation): assert isinstance(prompt_embeds, torch.Tensor) assert isinstance(pooled_prompt_embeds, torch.Tensor) return prompt_embeds, pooled_prompt_embeds - - @staticmethod - def _load_flux_text_encoder(path: Path) -> CLIPTextModel: - model = CLIPTextModel.from_pretrained(path, local_files_only=True) - assert isinstance(model, CLIPTextModel) - return model - - def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel: - if self.use_8bit: - model_8bit_path = path / "quantized" - if model_8bit_path.exists(): - # The quantized model exists, load it. - # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like - # something that we should be able to make much faster. - q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path) - - # Access the underlying wrapped model. - # We access the wrapped model, even though it is private, because it simplifies the type checking by - # always returning a T5EncoderModel from this function. - model = q_model._wrapped - else: - # The quantized model does not exist yet, quantize and save it. - # TODO(ryand): dtype? - model = T5EncoderModel.from_pretrained(path, local_files_only=True) - assert isinstance(model, T5EncoderModel) - - q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8) - - model_8bit_path.mkdir(parents=True, exist_ok=True) - q_model.save_pretrained(model_8bit_path) - - # (See earlier comment about accessing the wrapped model.) 
- model = q_model._wrapped - else: - model = T5EncoderModel.from_pretrained(path, local_files_only=True) - - assert isinstance(model, T5EncoderModel) - return model diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py index 7a577215f8..334e8fd1ea 100644 --- a/invokeai/app/invocations/flux_text_to_image.py +++ b/invokeai/app/invocations/flux_text_to_image.py @@ -6,7 +6,7 @@ import accelerate import torch from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel from diffusers.pipelines.flux.pipeline_flux import FluxPipeline -from invokeai.app.invocations.model import ModelIdentifierField +from invokeai.app.invocations.model import TransformerField, VAEField from optimum.quanto import qfloat8 from PIL import Image from safetensors.torch import load_file @@ -52,17 +52,14 @@ class QuantizedModelForTextEncoding(FastQuantizedTransformersModel): class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): """Text-to-image generation using a FLUX model.""" - flux_model: ModelIdentifierField = InputField( - description="The Flux model", - input=Input.Any, - ui_type=UIType.FluxMainModel + transformer: TransformerField = InputField( + description=FieldDescriptions.unet, + input=Input.Connection, + title="Transformer", ) - model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.") - quantization_type: Literal["raw", "NF4", "llm_int8"] = InputField( - default="raw", description="The type of quantization to use for the transformer model." - ) - use_8bit: bool = InputField( - default=False, description="Whether to quantize the transformer model to 8-bit precision." + vae: VAEField = InputField( + description=FieldDescriptions.vae, + input=Input.Connection, ) positive_text_conditioning: ConditioningField = InputField( description=FieldDescriptions.positive_cond, input=Input.Connection @@ -78,13 +75,6 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): @torch.no_grad() def invoke(self, context: InvocationContext) -> ImageOutput: - # model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model]) - flux_transformer_path = context.models.download_and_cache_model( - "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors" - ) - flux_ae_path = context.models.download_and_cache_model( - "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/ae.safetensors" - ) # Load the conditioning data. cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name) @@ -92,56 +82,31 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): flux_conditioning = cond_data.conditionings[0] assert isinstance(flux_conditioning, FLUXConditioningInfo) - latents = self._run_diffusion( - context, flux_transformer_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds - ) - image = self._run_vae_decoding(context, flux_ae_path, latents) + latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds) + image = self._run_vae_decoding(context, latents) image_dto = context.images.save(image=image) return ImageOutput.build(image_dto) def _run_diffusion( self, context: InvocationContext, - flux_transformer_path: Path, clip_embeddings: torch.Tensor, t5_embeddings: torch.Tensor, ): - inference_dtype = TorchDevice.choose_torch_dtype() - - # Prepare input noise. - # TODO(ryand): Does the seed behave the same on different devices? 
Should we re-implement this to always use a
-        # CPU RNG?
-        x = get_noise(
-            num_samples=1,
-            height=self.height,
-            width=self.width,
-            device=TorchDevice.choose_torch_device(),
-            dtype=inference_dtype,
-            seed=self.seed,
-        )
-
-        img, img_ids = self._prepare_latent_img_patches(x)
-
-        # HACK(ryand): Find a better way to determine if this is a schnell model or not.
-        is_schnell = "shnell" in str(flux_transformer_path)
-        timesteps = get_schedule(
-            num_steps=self.num_steps,
-            image_seq_len=img.shape[1],
-            shift=not is_schnell,
-        )
-
-        bs, t5_seq_len, _ = t5_embeddings.shape
-        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
+        scheduler_info = context.models.load(self.transformer.scheduler)
+        transformer_info = context.models.load(self.transformer.transformer)
 
         # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
         # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
         # if the cache is not empty.
-        context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
+        # context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
 
-        with context.models.load_local_model(
-            model_path=flux_transformer_path, loader=self._load_flux_transformer
-        ) as transformer:
-            assert isinstance(transformer, Flux)
+        with (
+            transformer_info as transformer,
+            scheduler_info as scheduler,
+        ):
+            assert isinstance(transformer, FluxTransformer2DModel)
+            assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)
 
             x = denoise(
                 model=transformer,
@@ -185,75 +150,34 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def _run_vae_decoding(
         self,
         context: InvocationContext,
-        flux_ae_path: Path,
         latents: torch.Tensor,
     ) -> Image.Image:
-        with context.models.load_local_model(model_path=flux_ae_path, loader=self._load_flux_vae) as vae:
-            assert isinstance(vae, AutoEncoder)
-            # TODO(ryand): Test that this works with both float16 and bfloat16.
-            with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
-                img = vae.decode(latents)
+        vae_info = context.models.load(self.vae.vae)
+        with vae_info as vae:
+            assert isinstance(vae, AutoencoderKL)
 
-        img.clamp(-1, 1)
-        img = rearrange(img[0], "c h w -> h w c")
-        img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
-        return img_pil
+            # Wrap the VAE in a FluxPipeline so its latent-unpacking and image
+            # post-processing helpers can be reused for decoding.
+            flux_pipeline_with_vae = FluxPipeline(
+                scheduler=None,
+                vae=vae,
+                text_encoder=None,
+                tokenizer=None,
+                text_encoder_2=None,
+                tokenizer_2=None,
+                transformer=None,
+            )
+
+            latents = flux_pipeline_with_vae._unpack_latents(
+                latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
+            )
+            latents = (
+                latents / flux_pipeline_with_vae.vae.config.scaling_factor
+            ) + flux_pipeline_with_vae.vae.config.shift_factor
+            latents = latents.to(dtype=vae.dtype)
+            image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
+            image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]
 
-    def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
-        inference_dtype = TorchDevice.choose_torch_dtype()
-        if self.quantization_type == "raw":
-            # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
-            params = flux_configs["flux-schnell"].params
-
-            # Initialize the model on the "meta" device.
-            with accelerate.init_empty_weights():
-                model = Flux(params).to(inference_dtype)
-
-            state_dict = load_file(path)
-            # TODO(ryand): Cast the state_dict to the appropriate dtype?
- model.load_state_dict(state_dict, strict=True, assign=True) - elif self.quantization_type == "NF4": - model_path = path.parent / "bnb_nf4.safetensors" - - # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config. - params = flux_configs["flux-schnell"].params - # Initialize the model on the "meta" device. - with accelerate.init_empty_weights(): - model = Flux(params) - model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16) - - # TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle - # this on GPUs without bfloat16 support. - state_dict = load_file(model_path) - model.load_state_dict(state_dict, strict=True, assign=True) - - elif self.quantization_type == "llm_int8": - raise NotImplementedError("LLM int8 quantization is not yet supported.") - # model_config = FluxTransformer2DModel.load_config(path, local_files_only=True) - # with accelerate.init_empty_weights(): - # empty_model = FluxTransformer2DModel.from_config(model_config) - # assert isinstance(empty_model, FluxTransformer2DModel) - # model_int8_path = path / "bnb_llm_int8" - # assert model_int8_path.exists() - # with accelerate.init_empty_weights(): - # model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set()) - - # sd = load_file(model_int8_path / "model.safetensors") - # model.load_state_dict(sd, strict=True, assign=True) - else: - raise ValueError(f"Unsupported quantization type: {self.quantization_type}") - - assert isinstance(model, FluxTransformer2DModel) - return model - - @staticmethod - def _load_flux_vae(path: Path) -> AutoEncoder: - # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config. - ae_params = flux_configs["flux1-schnell"].ae_params - with accelerate.init_empty_weights(): - ae = AutoEncoder(ae_params) - - state_dict = load_file(path) - ae.load_state_dict(state_dict, strict=True, assign=True) - return ae + assert isinstance(image, Image.Image) + return image diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index dd12109269..4672f6a83d 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -65,6 +65,10 @@ class TransformerField(BaseModel): transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel") scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel") +class T5EncoderField(BaseModel): + tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel") + text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel") + class VAEField(BaseModel): vae: ModelIdentifierField = Field(description="Info to load vae submodel") @@ -133,8 +137,8 @@ class FluxModelLoaderOutput(BaseInvocationOutput): """Flux base model loader output""" transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer") - clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1") - clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2") + clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP") + t5Encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5Encoder, title="T5 Encoder") vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE") @@ -166,7 +170,7 @@ class FluxModelLoaderInvocation(BaseInvocation): return FluxModelLoaderOutput( 
transformer=TransformerField(transformer=transformer, scheduler=scheduler), clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0), - clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0), + t5Encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=text_encoder2), vae=VAEField(vae=vae), ) diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py index dfe38aa79c..f1691ec4d4 100644 --- a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py @@ -78,7 +78,12 @@ class GenericDiffusersLoader(ModelLoader): # TO DO: Add exception handling def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type - if module in ["diffusers", "transformers"]: + if module in [ + "diffusers", + "transformers", + "invokeai.backend.quantization.fast_quantized_transformers_model", + "invokeai.backend.quantization.fast_quantized_diffusion_model", + ]: res_type = sys.modules[module] else: res_type = sys.modules["diffusers"].pipelines diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py index bc612043e3..b3b78104d9 100644 --- a/invokeai/backend/model_manager/load/model_util.py +++ b/invokeai/backend/model_manager/load/model_util.py @@ -9,7 +9,7 @@ from typing import Optional import torch from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.schedulers.scheduling_utils import SchedulerMixin -from transformers import CLIPTokenizer +from transformers import CLIPTokenizer, T5TokenizerFast from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline @@ -50,6 +50,13 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int: ), ): return model.calc_size() + elif isinstance( + model, + ( + T5TokenizerFast, + ), + ): + return len(model) else: # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the # supported model types. 
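The `_hf_definition_to_type()` whitelist above is what lets a model config's `module` field point at the InvokeAI quantization wrappers (the two files below) in addition to `diffusers` and `transformers`. A minimal sketch of that resolution logic, assuming the target module is importable; `resolve_model_class` is a hypothetical standalone helper, not part of this patch:

    import sys
    from importlib import import_module

    ALLOWED_MODULES = {
        "diffusers",
        "transformers",
        "invokeai.backend.quantization.fast_quantized_transformers_model",
        "invokeai.backend.quantization.fast_quantized_diffusion_model",
    }

    def resolve_model_class(module: str, class_name: str) -> type:
        """Resolve a model class from the (module, class_name) pair stored in a model config."""
        if module in ALLOWED_MODULES:
            # Whitelisted modules are looked up directly (importing on demand),
            # mirroring GenericDiffusersLoader._hf_definition_to_type().
            source = sys.modules.get(module) or import_module(module)
        else:
            # Anything else is assumed to name a diffusers pipeline class.
            source = import_module("diffusers").pipelines
        return getattr(source, class_name)
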
diff --git a/invokeai/backend/quantization/fast_quantized_diffusion_model.py b/invokeai/backend/quantization/fast_quantized_diffusion_model.py index 0759984bf9..395efc99c4 100644 --- a/invokeai/backend/quantization/fast_quantized_diffusion_model.py +++ b/invokeai/backend/quantization/fast_quantized_diffusion_model.py @@ -12,15 +12,17 @@ from diffusers.utils import ( ) from optimum.quanto.models import QuantizedDiffusersModel from optimum.quanto.models.shared_dict import ShardedStateDict +from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel from invokeai.backend.requantize import requantize class FastQuantizedDiffusersModel(QuantizedDiffusersModel): @classmethod - def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]): + def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike], base_class = FluxTransformer2DModel, **kwargs): """We override the `from_pretrained()` method in order to use our custom `requantize()` implementation.""" - if cls.base_class is None: + base_class = base_class or cls.base_class + if base_class is None: raise ValueError("The `base_class` attribute needs to be configured.") if not is_accelerate_available(): @@ -43,16 +45,16 @@ class FastQuantizedDiffusersModel(QuantizedDiffusersModel): with open(model_config_path, "r", encoding="utf-8") as f: original_model_cls_name = json.load(f)["_class_name"] - configured_cls_name = cls.base_class.__name__ + configured_cls_name = base_class.__name__ if configured_cls_name != original_model_cls_name: raise ValueError( f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})." ) # Create an empty model - config = cls.base_class.load_config(model_name_or_path) + config = base_class.load_config(model_name_or_path) with init_empty_weights(): - model = cls.base_class.from_config(config) + model = base_class.from_config(config) # Look for the index of a sharded checkpoint checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) @@ -72,6 +74,6 @@ class FastQuantizedDiffusersModel(QuantizedDiffusersModel): # Requantize and load quantized weights from state_dict requantize(model, state_dict=state_dict, quantization_map=qmap) model.eval() - return cls(model) + return cls(model)._wrapped else: raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.") diff --git a/invokeai/backend/quantization/fast_quantized_transformers_model.py b/invokeai/backend/quantization/fast_quantized_transformers_model.py index ce5cc7a3a9..99f889b4af 100644 --- a/invokeai/backend/quantization/fast_quantized_transformers_model.py +++ b/invokeai/backend/quantization/fast_quantized_transformers_model.py @@ -1,5 +1,6 @@ import json import os +import torch from typing import Union from optimum.quanto.models import QuantizedTransformersModel @@ -7,15 +8,17 @@ from optimum.quanto.models.shared_dict import ShardedStateDict from transformers import AutoConfig from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available +from transformers.models.auto import AutoModelForTextEncoding from invokeai.backend.requantize import requantize class FastQuantizedTransformersModel(QuantizedTransformersModel): @classmethod - def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]): + def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike], auto_class = 
AutoModelForTextEncoding, **kwargs): """We override the `from_pretrained()` method in order to use our custom `requantize()` implementation.""" - if cls.auto_class is None: + auto_class = auto_class or cls.auto_class + if auto_class is None: raise ValueError( "Quantized models cannot be reloaded using {cls}: use a specialized quantized class such as QuantizedModelForCausalLM instead." ) @@ -33,7 +36,7 @@ class FastQuantizedTransformersModel(QuantizedTransformersModel): # Create an empty model config = AutoConfig.from_pretrained(model_name_or_path) with init_empty_weights(): - model = cls.auto_class.from_config(config) + model = auto_class.from_config(config) # Look for the index of a sharded checkpoint checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) if os.path.exists(checkpoint_file): @@ -56,6 +59,6 @@ class FastQuantizedTransformersModel(QuantizedTransformersModel): model.tie_weights() # Set model in evaluation mode as it is done in transformers model.eval() - return cls(model) + return cls(model)._wrapped else: raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.") diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index ef0b869b8e..b8cdc2e88d 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -5697,15 +5697,15 @@ export type components = { */ transformer: components["schemas"]["TransformerField"]; /** - * CLIP 1 + * CLIP * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count */ clip: components["schemas"]["CLIPField"]; /** - * CLIP 2 - * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count + * T5 Encoder + * @description T5 tokenizer and text encoder */ - clip2: components["schemas"]["CLIPField"]; + t5Encoder: components["schemas"]["T5EncoderField"]; /** * VAE * @description VAE @@ -5739,19 +5739,17 @@ export type components = { */ use_cache?: boolean; /** - * Model - * @description The FLUX model to use for text-to-image generation. + * CLIP + * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count * @default null - * @constant - * @enum {string} */ - model?: "flux-schnell"; + clip?: components["schemas"]["CLIPField"]; /** - * Use 8Bit - * @description Whether to quantize the transformer model to 8-bit precision. - * @default false + * T5EncoderField + * @description T5 tokenizer and text encoder + * @default null */ - use_8bit?: boolean; + t5Encoder?: components["schemas"]["T5EncoderField"]; /** * Positive Prompt * @description Positive prompt for text-to-image generation. @@ -5799,31 +5797,16 @@ export type components = { */ use_cache?: boolean; /** - * @description The Flux model + * Transformer + * @description UNet (scheduler, LoRAs) * @default null */ - flux_model?: components["schemas"]["ModelIdentifierField"]; + transformer?: components["schemas"]["TransformerField"]; /** - * Model - * @description The FLUX model to use for text-to-image generation. + * @description VAE * @default null - * @constant - * @enum {string} */ - model?: "flux-schnell"; - /** - * Quantization Type - * @description The type of quantization to use for the transformer model. - * @default raw - * @enum {string} - */ - quantization_type?: "raw" | "NF4" | "llm_int8"; - /** - * Use 8Bit - * @description Whether to quantize the transformer model to 8-bit precision. 
- * @default false - */ - use_8bit?: boolean; + vae?: components["schemas"]["VAEField"]; /** * @description Positive conditioning tensor * @default null @@ -14268,6 +14251,13 @@ export type components = { */ type: "t2i_adapter_output"; }; + /** T5EncoderField */ + T5EncoderField: { + /** @description Info to load tokenizer submodel */ + tokenizer: components["schemas"]["ModelIdentifierField"]; + /** @description Info to load text_encoder submodel */ + text_encoder: components["schemas"]["ModelIdentifierField"]; + }; /** TBLR */ TBLR: { /** Top */