Fix support for 8b quantized t5 encoders, update exception messages in flux loaders

This commit is contained in:
Brandon Rising 2024-08-20 12:37:12 -04:00 committed by Brandon
parent e49105ece5
commit dee6d2c98e
3 changed files with 42 additions and 10 deletions

View File

@ -143,9 +143,9 @@ T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
"format": ModelFormat.T5Encoder,
},
"8b_quantized": {
"repo": "invokeai/flux_dev::t5_xxl_encoder/optimum_quanto_qfloat8",
"repo": "invokeai/flux_schnell::t5_xxl_encoder/optimum_quanto_qfloat8",
"name": "t5_8b_quantized_encoder",
"format": ModelFormat.T5Encoder,
"format": ModelFormat.T5Encoder8b,
},
}

View File

@ -225,6 +225,14 @@ class T5EncoderConfig(T5EncoderConfigBase):
return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.T5Encoder.value}")
class T5Encoder8bConfig(T5EncoderConfigBase):
format: Literal[ModelFormat.T5Encoder8b] = ModelFormat.T5Encoder8b
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.T5Encoder8b.value}")
class LoRALyCORISConfig(LoRAConfigBase):
"""Model config for LoRA/Lycoris models."""
@ -460,6 +468,7 @@ AnyModelConfig = Annotated[
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
Annotated[T5Encoder8bConfig, T5Encoder8bConfig.get_tag()],
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],

View File

@ -28,12 +28,14 @@ from invokeai.backend.model_manager.config import (
MainBnbQuantized4bCheckpointConfig,
MainCheckpointConfig,
T5EncoderConfig,
T5Encoder8bConfig,
VAECheckpointConfig,
)
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
from invokeai.backend.util.silence_warnings import SilenceWarnings
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
app_config = get_config()
@ -82,7 +84,7 @@ class ClipCheckpointModel(GenericDiffusersLoader):
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CLIPEmbedDiffusersConfig):
raise Exception("Only Checkpoint Flux models are currently supported.")
raise Exception("Only CLIPEmbedDiffusersConfig models are currently supported here.")
match submodel_type:
case SubModelType.Tokenizer:
@ -90,7 +92,28 @@ class ClipCheckpointModel(GenericDiffusersLoader):
case SubModelType.TextEncoder:
return CLIPTextModel.from_pretrained(config.path)
raise Exception("Only Checkpoint Flux models are currently supported.")
raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.")
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder8b)
class T5Encoder8bCheckpointModel(GenericDiffusersLoader):
"""Class to load main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, T5Encoder8bConfig):
raise Exception("Only T5Encoder8bConfig models are currently supported here.")
match submodel_type:
case SubModelType.Tokenizer2:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
return FastQuantizedTransformersModel.from_pretrained(Path(config.path) / "text_encoder_2")
raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.")
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
@ -103,7 +126,7 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader):
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, T5EncoderConfig):
raise Exception("Only Checkpoint Flux models are currently supported.")
raise Exception("Only T5EncoderConfig models are currently supported here.")
match submodel_type:
case SubModelType.Tokenizer2:
@ -113,7 +136,7 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader):
Path(config.path) / "text_encoder_2"
) # TODO: Fix hf subfolder install
raise Exception("Only Checkpoint Flux models are currently supported.")
raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.")
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ -126,7 +149,7 @@ class FluxCheckpointModel(GenericDiffusersLoader):
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
raise Exception("Only Checkpoint Flux models are currently supported.")
raise Exception("Only CheckpointConfigBase models are currently supported here.")
legacy_config_path = app_config.legacy_conf_path / config.config_path
config_path = legacy_config_path.as_posix()
with open(config_path, "r") as stream:
@ -139,7 +162,7 @@ class FluxCheckpointModel(GenericDiffusersLoader):
case SubModelType.Transformer:
return self._load_from_singlefile(config, flux_conf)
raise Exception("Only Checkpoint Flux models are currently supported.")
raise Exception("Only Transformer submodels are currently supported.")
def _load_from_singlefile(
self,
@ -171,7 +194,7 @@ class FluxBnbQuantizednf4bCheckpointModel(GenericDiffusersLoader):
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
raise Exception("Only Checkpoint Flux models are currently supported.")
raise Exception("Only CheckpointConfigBase models are currently supported here.")
legacy_config_path = app_config.legacy_conf_path / config.config_path
config_path = legacy_config_path.as_posix()
with open(config_path, "r") as stream:
@ -184,7 +207,7 @@ class FluxBnbQuantizednf4bCheckpointModel(GenericDiffusersLoader):
case SubModelType.Transformer:
return self._load_from_singlefile(config, flux_conf)
raise Exception("Only Checkpoint Flux models are currently supported.")
raise Exception("Only Transformer submodels are currently supported.")
def _load_from_singlefile(
self,