mirror of https://github.com/invoke-ai/InvokeAI

Update the T5 8-bit quantized starter model to use the BnB LLM.int8() variant.
commit 75d8ac378c (parent b9dd354e2b)

@@ -9,7 +9,7 @@ import accelerate
 import torch
 import yaml
 from safetensors.torch import load_file
-from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
+from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
 
 from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.flux.model import Flux, FluxParams
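
Note: the new AutoModelForTextEncoding import is what lets the loader build an encoder-only model from a config alone. A minimal sketch of that mapping (tiny toy config values, not InvokeAI code):

from transformers import AutoModelForTextEncoding, T5Config, T5EncoderModel

# Illustrative tiny config; the loader in this diff reads the real one from disk
# via AutoConfig.from_pretrained(te2_model_path).
config = T5Config(d_model=64, d_ff=128, num_layers=2, num_heads=4)
model = AutoModelForTextEncoding.from_config(config)

# The text-encoding auto-class resolves a T5 config to T5EncoderModel: the
# encoder stack only, with no decoder or LM head.
assert isinstance(model, T5EncoderModel)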
@@ -33,7 +33,8 @@ from invokeai.backend.model_manager.config import (
 )
 from invokeai.backend.model_manager.load.load_default import ModelLoader
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
-from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
+from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
+from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
 from invokeai.backend.util.silence_warnings import SilenceWarnings
 
 try:
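
Note: quantize_model_llm_int8 itself is not part of this diff. As a rough sketch of what such a helper typically does, assuming bitsandbytes is installed (the function below is illustrative, not InvokeAI's implementation):

import torch
import bitsandbytes as bnb


def replace_linear_with_int8(module: torch.nn.Module, modules_to_not_convert: set[str], prefix: str = "") -> None:
    # Recursively swap each nn.Linear for a bitsandbytes Linear8bitLt of the same shape.
    for name, child in module.named_children():
        full_name = f"{prefix}.{name}" if prefix else name
        if isinstance(child, torch.nn.Linear) and full_name not in modules_to_not_convert:
            int8_linear = bnb.nn.Linear8bitLt(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                has_fp16_weights=False,  # keep int8 weights; no fp16 master copy
                threshold=6.0,  # outlier threshold from the LLM.int8() paper
            )
            setattr(module, name, int8_linear)
        else:
            replace_linear_with_int8(child, modules_to_not_convert, full_name)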
@@ -115,12 +116,33 @@ class T5Encoder8bCheckpointModel(ModelLoader):
             case SubModelType.Tokenizer2:
                 return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
             case SubModelType.TextEncoder2:
-                return FastQuantizedTransformersModel.from_pretrained(Path(config.path) / "text_encoder_2")
+                te2_model_path = Path(config.path) / "text_encoder_2"
+                model_config = AutoConfig.from_pretrained(te2_model_path)
+                with accelerate.init_empty_weights():
+                    model = AutoModelForTextEncoding.from_config(model_config)
+                    model = quantize_model_llm_int8(model, modules_to_not_convert=set())
+
+                state_dict_path = te2_model_path / "bnb_llm_int8_model.safetensors"
+                state_dict = load_file(state_dict_path)
+                self._load_state_dict_into_t5(model, state_dict)
+
+                return model
 
         raise ValueError(
             f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
         )
 
+    @classmethod
+    def _load_state_dict_into_t5(cls, model: T5EncoderModel, state_dict: dict[str, torch.Tensor]):
+        # There is a shared reference to a single weight tensor in the model.
+        # Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should
+        # be present in the state_dict.
+        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True)
+        assert len(unexpected_keys) == 0
+        assert set(missing_keys) == {"encoder.embed_tokens.weight"}
+        # Assert that the layers we expect to be shared are actually shared.
+        assert model.encoder.embed_tokens.weight is model.shared.weight
+
 
 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
 class T5EncoderCheckpointModel(ModelLoader):
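
Note: the asserts in _load_state_dict_into_t5 hinge on T5's weight tying: encoder.embed_tokens and shared are the same nn.Embedding object, so the serialized file stores only "shared.weight". A standalone toy (not InvokeAI code) showing why strict=False, assign=True reports exactly one missing key and preserves the tie:

import torch


class TiedToy(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.shared = torch.nn.Embedding(10, 4)
        self.embed_tokens = self.shared  # same module object, as in T5's encoder


toy = TiedToy()
state_dict = {"shared.weight": torch.randn(10, 4)}  # only the canonical key is stored
missing, unexpected = toy.load_state_dict(state_dict, strict=False, assign=True)

# The aliased view is reported missing, but assigning through one name updates both.
assert set(missing) == {"embed_tokens.weight"} and not unexpected
assert toy.embed_tokens.weight is toy.shared.weight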
@@ -2,7 +2,7 @@ from typing import Optional
 
 from pydantic import BaseModel
 
-from invokeai.backend.model_manager.config import BaseModelType, ModelType
+from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
 
 
 class StarterModelWithoutDependencies(BaseModel):
@@ -11,6 +11,7 @@ class StarterModelWithoutDependencies(BaseModel):
     name: str
     base: BaseModelType
     type: ModelType
+    format: Optional[ModelFormat] = None
     is_installed: bool = False
 
 
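
Note: because the new format field is Optional with a None default, existing starter entries that omit it continue to validate and serialize unchanged. A minimal standalone Pydantic sketch (toy class, not the real StarterModel):

from typing import Optional

from pydantic import BaseModel


class StarterModelToy(BaseModel):
    name: str
    format: Optional[str] = None  # the real field is Optional[ModelFormat]


assert StarterModelToy(name="t5_base_encoder").format is None
assert StarterModelToy(name="t5_bnb_int8_quantized_encoder", format="t5_encoder_8b").format == "t5_encoder_8b"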
@@ -54,17 +55,18 @@ cyberrealistic_negative = StarterModel(
 t5_base_encoder = StarterModel(
     name="t5_base_encoder",
     base=BaseModelType.Any,
-    source="InvokeAI/flux_schnell::t5_xxl_encoder/base",
+    source="InvokeAI/t5-v1_1-xxl::bfloat16",
     description="T5-XXL text encoder (used in FLUX pipelines). ~8GB",
     type=ModelType.T5Encoder,
 )
 
 t5_8b_quantized_encoder = StarterModel(
-    name="t5_8b_quantized_encoder",
+    name="t5_bnb_int8_quantized_encoder",
     base=BaseModelType.Any,
-    source="invokeai/flux_schnell::t5_xxl_encoder/optimum_quanto_qfloat8",
+    source="InvokeAI/t5-v1_1-xxl::bnb_llm_int8",
-    description="T5-XXL text encoder with optimum-quanto qfloat8 quantization (used in FLUX pipelines). ~6GB",
+    description="T5-XXL text encoder with bitsandbytes LLM.int8() quantization (used in FLUX pipelines). ~5GB",
     type=ModelType.T5Encoder,
+    format=ModelFormat.T5Encoder8b,
 )
 
 clip_l_encoder = StarterModel(
@@ -15,14 +15,14 @@ export const StarterModelsResultItem = memo(({ result }: Props) => {
   const _allSources = [
     {
       source: result.source,
-      config: { name: result.name, description: result.description, type: result.type, base: result.base },
+      config: { name: result.name, description: result.description, type: result.type, base: result.base, format: result.format },
     },
   ];
   if (result.dependencies) {
     for (const d of result.dependencies) {
       _allSources.push({
         source: d.source,
-        config: { name: d.name, description: d.description, type: d.type, base: d.base },
+        config: { name: d.name, description: d.description, type: d.type, base: d.base, format: d.format },
       });
     }
   }