Update the T5 8-bit quantized starter model to use the BnB LLM.int8() variant.

This commit is contained in:
Ryan Dick 2024-08-23 17:44:03 +00:00
parent 86e49c423c
commit 544ab296e7
3 changed files with 34 additions and 10 deletions

View File

@ -9,7 +9,7 @@ import accelerate
import torch import torch
import yaml import yaml
from safetensors.torch import load_file from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.model import Flux, FluxParams from invokeai.backend.flux.model import Flux, FluxParams
@ -33,7 +33,8 @@ from invokeai.backend.model_manager.config import (
) )
from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
from invokeai.backend.util.silence_warnings import SilenceWarnings from invokeai.backend.util.silence_warnings import SilenceWarnings
try: try:
@ -115,12 +116,33 @@ class T5Encoder8bCheckpointModel(ModelLoader):
case SubModelType.Tokenizer2: case SubModelType.Tokenizer2:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512) return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2: case SubModelType.TextEncoder2:
return FastQuantizedTransformersModel.from_pretrained(Path(config.path) / "text_encoder_2") te2_model_path = Path(config.path) / "text_encoder_2"
model_config = AutoConfig.from_pretrained(te2_model_path)
with accelerate.init_empty_weights():
model = AutoModelForTextEncoding.from_config(model_config)
model = quantize_model_llm_int8(model, modules_to_not_convert=set())
state_dict_path = te2_model_path / "bnb_llm_int8_model.safetensors"
state_dict = load_file(state_dict_path)
self._load_state_dict_into_t5(model, state_dict)
return model
raise ValueError( raise ValueError(
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}" f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
) )
@classmethod
def _load_state_dict_into_t5(cls, model: T5EncoderModel, state_dict: dict[str, torch.Tensor]):
# There is a shared reference to a single weight tensor in the model.
# Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should
# be present in the state_dict.
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True)
assert len(unexpected_keys) == 0
assert set(missing_keys) == {"encoder.embed_tokens.weight"}
# Assert that the layers we expect to be shared are actually shared.
assert model.encoder.embed_tokens.weight is model.shared.weight
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
class T5EncoderCheckpointModel(ModelLoader): class T5EncoderCheckpointModel(ModelLoader):

View File

@ -2,7 +2,7 @@ from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
from invokeai.backend.model_manager.config import BaseModelType, ModelType from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
class StarterModelWithoutDependencies(BaseModel): class StarterModelWithoutDependencies(BaseModel):
@ -11,6 +11,7 @@ class StarterModelWithoutDependencies(BaseModel):
name: str name: str
base: BaseModelType base: BaseModelType
type: ModelType type: ModelType
format: Optional[ModelFormat] = None
is_installed: bool = False is_installed: bool = False
@ -54,17 +55,18 @@ cyberrealistic_negative = StarterModel(
t5_base_encoder = StarterModel( t5_base_encoder = StarterModel(
name="t5_base_encoder", name="t5_base_encoder",
base=BaseModelType.Any, base=BaseModelType.Any,
source="InvokeAI/flux_schnell::t5_xxl_encoder/base", source="InvokeAI/t5-v1_1-xxl::bfloat16",
description="T5-XXL text encoder (used in FLUX pipelines). ~8GB", description="T5-XXL text encoder (used in FLUX pipelines). ~8GB",
type=ModelType.T5Encoder, type=ModelType.T5Encoder,
) )
t5_8b_quantized_encoder = StarterModel( t5_8b_quantized_encoder = StarterModel(
name="t5_8b_quantized_encoder", name="t5_bnb_int8_quantized_encoder",
base=BaseModelType.Any, base=BaseModelType.Any,
source="invokeai/flux_schnell::t5_xxl_encoder/optimum_quanto_qfloat8", source="InvokeAI/t5-v1_1-xxl::bnb_llm_int8",
description="T5-XXL text encoder with optimum-quanto qfloat8 quantization (used in FLUX pipelines). ~6GB", description="T5-XXL text encoder with bitsandbytes LLM.int8() quantization (used in FLUX pipelines). ~5GB",
type=ModelType.T5Encoder, type=ModelType.T5Encoder,
format=ModelFormat.T5Encoder8b,
) )
clip_l_encoder = StarterModel( clip_l_encoder = StarterModel(

View File

@ -15,14 +15,14 @@ export const StarterModelsResultItem = memo(({ result }: Props) => {
const _allSources = [ const _allSources = [
{ {
source: result.source, source: result.source,
config: { name: result.name, description: result.description, type: result.type, base: result.base }, config: { name: result.name, description: result.description, type: result.type, base: result.base, format: result.format },
}, },
]; ];
if (result.dependencies) { if (result.dependencies) {
for (const d of result.dependencies) { for (const d of result.dependencies) {
_allSources.push({ _allSources.push({
source: d.source, source: d.source,
config: { name: d.name, description: d.description, type: d.type, base: d.base }, config: { name: d.name, description: d.description, type: d.type, base: d.base, format: d.format },
}); });
} }
} }