From da766f5a7e1d5f7c56533b028159be8c693efddf Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Tue, 20 Aug 2024 12:37:12 -0400 Subject: [PATCH] Fix support for 8b quantized t5 encoders, update exception messages in flux loaders --- invokeai/app/invocations/model.py | 4 +- invokeai/backend/model_manager/config.py | 9 +++++ .../model_manager/load/model_loaders/flux.py | 39 +++++++++++++++---- 3 files changed, 42 insertions(+), 10 deletions(-) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 3d5f38927d..9c9d8eb834 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -143,9 +143,9 @@ T5_ENCODER_MAP: Dict[str, Dict[str, str]] = { "format": ModelFormat.T5Encoder, }, "8b_quantized": { - "repo": "invokeai/flux_dev::t5_xxl_encoder/optimum_quanto_qfloat8", + "repo": "invokeai/flux_schnell::t5_xxl_encoder/optimum_quanto_qfloat8", "name": "t5_8b_quantized_encoder", - "format": ModelFormat.T5Encoder, + "format": ModelFormat.T5Encoder8b, }, } diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index ce6b8ed8cc..5dd74dbacc 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -225,6 +225,14 @@ class T5EncoderConfig(T5EncoderConfigBase): return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.T5Encoder.value}") +class T5Encoder8bConfig(T5EncoderConfigBase): + format: Literal[ModelFormat.T5Encoder8b] = ModelFormat.T5Encoder8b + + @staticmethod + def get_tag() -> Tag: + return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.T5Encoder8b.value}") + + class LoRALyCORISConfig(LoRAConfigBase): """Model config for LoRA/Lycoris models.""" @@ -460,6 +468,7 @@ AnyModelConfig = Annotated[ Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()], Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()], Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()], + Annotated[T5Encoder8bConfig, T5Encoder8bConfig.get_tag()], Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()], Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()], Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()], diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py index 3ba933bf48..5872936965 100644 --- a/invokeai/backend/model_manager/load/model_loaders/flux.py +++ b/invokeai/backend/model_manager/load/model_loaders/flux.py @@ -28,12 +28,14 @@ from invokeai.backend.model_manager.config import ( MainBnbQuantized4bCheckpointConfig, MainCheckpointConfig, T5EncoderConfig, + T5Encoder8bConfig, VAECheckpointConfig, ) from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 from invokeai.backend.util.silence_warnings import SilenceWarnings +from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel app_config = get_config() @@ -82,7 +84,7 @@ class ClipCheckpointModel(GenericDiffusersLoader): submodel_type: Optional[SubModelType] = None, ) -> AnyModel: if not isinstance(config, CLIPEmbedDiffusersConfig): - raise Exception("Only Checkpoint Flux models are currently supported.") + raise Exception("Only CLIPEmbedDiffusersConfig models are currently supported here.") match submodel_type: case SubModelType.Tokenizer: @@ -90,7 +92,28 @@ class ClipCheckpointModel(GenericDiffusersLoader): case SubModelType.TextEncoder: return CLIPTextModel.from_pretrained(config.path) - raise Exception("Only Checkpoint Flux models are currently supported.") + raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.") + + +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder8b) +class T5Encoder8bCheckpointModel(GenericDiffusersLoader): + """Class to load main models.""" + + def _load_model( + self, + config: AnyModelConfig, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if not isinstance(config, T5Encoder8bConfig): + raise Exception("Only T5Encoder8bConfig models are currently supported here.") + + match submodel_type: + case SubModelType.Tokenizer2: + return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512) + case SubModelType.TextEncoder2: + return FastQuantizedTransformersModel.from_pretrained(Path(config.path) / "text_encoder_2") + + raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.") @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder) @@ -103,7 +126,7 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader): submodel_type: Optional[SubModelType] = None, ) -> AnyModel: if not isinstance(config, T5EncoderConfig): - raise Exception("Only Checkpoint Flux models are currently supported.") + raise Exception("Only T5EncoderConfig models are currently supported here.") match submodel_type: case SubModelType.Tokenizer2: @@ -113,7 +136,7 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader): Path(config.path) / "text_encoder_2" ) # TODO: Fix hf subfolder install - raise Exception("Only Checkpoint Flux models are currently supported.") + raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.") @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint) @@ -126,7 +149,7 @@ class FluxCheckpointModel(GenericDiffusersLoader): submodel_type: Optional[SubModelType] = None, ) -> AnyModel: if not isinstance(config, CheckpointConfigBase): - raise Exception("Only Checkpoint Flux models are currently supported.") + raise Exception("Only CheckpointConfigBase models are currently supported here.") legacy_config_path = app_config.legacy_conf_path / config.config_path config_path = legacy_config_path.as_posix() with open(config_path, "r") as stream: @@ -139,7 +162,7 @@ class FluxCheckpointModel(GenericDiffusersLoader): case SubModelType.Transformer: return self._load_from_singlefile(config, flux_conf) - raise Exception("Only Checkpoint Flux models are currently supported.") + raise Exception("Only Transformer submodels are currently supported.") def _load_from_singlefile( self, @@ -171,7 +194,7 @@ class FluxBnbQuantizednf4bCheckpointModel(GenericDiffusersLoader): submodel_type: Optional[SubModelType] = None, ) -> AnyModel: if not isinstance(config, CheckpointConfigBase): - raise Exception("Only Checkpoint Flux models are currently supported.") + raise Exception("Only CheckpointConfigBase models are currently supported here.") legacy_config_path = app_config.legacy_conf_path / config.config_path config_path = legacy_config_path.as_posix() with open(config_path, "r") as stream: @@ -184,7 +207,7 @@ class FluxBnbQuantizednf4bCheckpointModel(GenericDiffusersLoader): case SubModelType.Transformer: return self._load_from_singlefile(config, flux_conf) - raise Exception("Only Checkpoint Flux models are currently supported.") + raise Exception("Only Transformer submodels are currently supported.") def _load_from_singlefile( self,