From 544ab296e77931767de380b84e69552b3c5f6e1e Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 23 Aug 2024 17:44:03 +0000 Subject: [PATCH] Update the T5 8-bit quantized starter model to use the BnB LLM.int8() variant. --- .../model_manager/load/model_loaders/flux.py | 28 +++++++++++++++++-- .../backend/model_manager/starter_models.py | 12 ++++---- .../StarterModels/StartModelsResultItem.tsx | 4 +-- 3 files changed, 34 insertions(+), 10 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py index 58b4843395..f3e44fc221 100644 --- a/invokeai/backend/model_manager/load/model_loaders/flux.py +++ b/invokeai/backend/model_manager/load/model_loaders/flux.py @@ -9,7 +9,7 @@ import accelerate import torch import yaml from safetensors.torch import load_file -from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer +from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer from invokeai.app.services.config.config_default import get_config from invokeai.backend.flux.model import Flux, FluxParams @@ -33,7 +33,8 @@ from invokeai.backend.model_manager.config import ( ) from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry -from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel +from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8 +from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 from invokeai.backend.util.silence_warnings import SilenceWarnings try: @@ -115,12 +116,33 @@ class T5Encoder8bCheckpointModel(ModelLoader): case SubModelType.Tokenizer2: return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512) case SubModelType.TextEncoder2: - return FastQuantizedTransformersModel.from_pretrained(Path(config.path) / "text_encoder_2") + te2_model_path = Path(config.path) / "text_encoder_2" + model_config = AutoConfig.from_pretrained(te2_model_path) + with accelerate.init_empty_weights(): + model = AutoModelForTextEncoding.from_config(model_config) + model = quantize_model_llm_int8(model, modules_to_not_convert=set()) + + state_dict_path = te2_model_path / "bnb_llm_int8_model.safetensors" + state_dict = load_file(state_dict_path) + self._load_state_dict_into_t5(model, state_dict) + + return model raise ValueError( f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}" ) + @classmethod + def _load_state_dict_into_t5(cls, model: T5EncoderModel, state_dict: dict[str, torch.Tensor]): + # There is a shared reference to a single weight tensor in the model. + # Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should + # be present in the state_dict. + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True) + assert len(unexpected_keys) == 0 + assert set(missing_keys) == {"encoder.embed_tokens.weight"} + # Assert that the layers we expect to be shared are actually shared. + assert model.encoder.embed_tokens.weight is model.shared.weight + @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder) class T5EncoderCheckpointModel(ModelLoader): diff --git a/invokeai/backend/model_manager/starter_models.py b/invokeai/backend/model_manager/starter_models.py index 7d5233d767..13a22ee219 100644 --- a/invokeai/backend/model_manager/starter_models.py +++ b/invokeai/backend/model_manager/starter_models.py @@ -2,7 +2,7 @@ from typing import Optional from pydantic import BaseModel -from invokeai.backend.model_manager.config import BaseModelType, ModelType +from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType class StarterModelWithoutDependencies(BaseModel): @@ -11,6 +11,7 @@ class StarterModelWithoutDependencies(BaseModel): name: str base: BaseModelType type: ModelType + format: Optional[ModelFormat] = None is_installed: bool = False @@ -54,17 +55,18 @@ cyberrealistic_negative = StarterModel( t5_base_encoder = StarterModel( name="t5_base_encoder", base=BaseModelType.Any, - source="InvokeAI/flux_schnell::t5_xxl_encoder/base", + source="InvokeAI/t5-v1_1-xxl::bfloat16", description="T5-XXL text encoder (used in FLUX pipelines). ~8GB", type=ModelType.T5Encoder, ) t5_8b_quantized_encoder = StarterModel( - name="t5_8b_quantized_encoder", + name="t5_bnb_int8_quantized_encoder", base=BaseModelType.Any, - source="invokeai/flux_schnell::t5_xxl_encoder/optimum_quanto_qfloat8", - description="T5-XXL text encoder with optimum-quanto qfloat8 quantization (used in FLUX pipelines). ~6GB", + source="InvokeAI/t5-v1_1-xxl::bnb_llm_int8", + description="T5-XXL text encoder with bitsandbytes LLM.int8() quantization (used in FLUX pipelines). ~5GB", type=ModelType.T5Encoder, + format=ModelFormat.T5Encoder8b, ) clip_l_encoder = StarterModel( diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StartModelsResultItem.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StartModelsResultItem.tsx index 4fc8390890..bd6a2b4268 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StartModelsResultItem.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StartModelsResultItem.tsx @@ -15,14 +15,14 @@ export const StarterModelsResultItem = memo(({ result }: Props) => { const _allSources = [ { source: result.source, - config: { name: result.name, description: result.description, type: result.type, base: result.base }, + config: { name: result.name, description: result.description, type: result.type, base: result.base, format: result.format }, }, ]; if (result.dependencies) { for (const d of result.dependencies) { _allSources.push({ source: d.source, - config: { name: d.name, description: d.description, type: d.type, base: d.base }, + config: { name: d.name, description: d.description, type: d.type, base: d.base, format: d.format }, }); } }