mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Fix support for 8b quantized t5 encoders, update exception messages in flux loaders
This commit is contained in:
parent e49105ece5
commit dee6d2c98e
@@ -143,9 +143,9 @@ T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
         "format": ModelFormat.T5Encoder,
     },
     "8b_quantized": {
-        "repo": "invokeai/flux_dev::t5_xxl_encoder/optimum_quanto_qfloat8",
+        "repo": "invokeai/flux_schnell::t5_xxl_encoder/optimum_quanto_qfloat8",
         "name": "t5_8b_quantized_encoder",
-        "format": ModelFormat.T5Encoder,
+        "format": ModelFormat.T5Encoder8b,
     },
 }
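The `optimum_quanto_qfloat8` suffix in the repo source above refers to a T5-XXL encoder whose weights were quantized to qfloat8 with optimum-quanto. A minimal sketch of producing such a checkpoint, assuming optimum-quanto and transformers are installed (the source model id is illustrative, not the exact recipe behind the `invokeai/flux_schnell` artifacts):

```python
# Sketch: qfloat8 weight-only quantization with optimum-quanto.
from optimum.quanto import freeze, qfloat8, quantize
from transformers import T5EncoderModel

model = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl")  # illustrative id
quantize(model, weights=qfloat8)  # wrap Linear weights in qfloat8 QTensors
freeze(model)  # materialize quantized weights in place of the originals
```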
@@ -225,6 +225,14 @@ class T5EncoderConfig(T5EncoderConfigBase):
         return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.T5Encoder.value}")


+class T5Encoder8bConfig(T5EncoderConfigBase):
+    format: Literal[ModelFormat.T5Encoder8b] = ModelFormat.T5Encoder8b
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.T5Encoder8b.value}")
+
+
 class LoRALyCORISConfig(LoRAConfigBase):
     """Model config for LoRA/Lycoris models."""

@@ -460,6 +468,7 @@ AnyModelConfig = Annotated[
     Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
     Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
     Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
+    Annotated[T5Encoder8bConfig, T5Encoder8bConfig.get_tag()],
     Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
     Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
     Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
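Each config class contributes a `Tag` to the `AnyModelConfig` discriminated union, keyed as `"<type>.<format>"`. A self-contained pydantic v2 sketch of that pattern; the enum string values (`t5_encoder`, `t5_encoder_8b`) are assumptions inferred from the tag expressions above:

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter

class T5EncoderConfig(BaseModel):
    type: Literal["t5_encoder"] = "t5_encoder"
    format: Literal["t5_encoder"] = "t5_encoder"

class T5Encoder8bConfig(BaseModel):
    type: Literal["t5_encoder"] = "t5_encoder"
    format: Literal["t5_encoder_8b"] = "t5_encoder_8b"

def tag_of(v) -> str:
    # Reproduce the "<type>.<format>" scheme used by get_tag() above.
    if isinstance(v, dict):
        return f"{v['type']}.{v['format']}"
    return f"{v.type}.{v.format}"

AnyConfig = Annotated[
    Union[
        Annotated[T5EncoderConfig, Tag("t5_encoder.t5_encoder")],
        Annotated[T5Encoder8bConfig, Tag("t5_encoder.t5_encoder_8b")],
    ],
    Discriminator(tag_of),
]

parsed = TypeAdapter(AnyConfig).validate_python(
    {"type": "t5_encoder", "format": "t5_encoder_8b"}
)
assert isinstance(parsed, T5Encoder8bConfig)  # routed to the 8-bit config
```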
@@ -28,12 +28,14 @@ from invokeai.backend.model_manager.config import (
     MainBnbQuantized4bCheckpointConfig,
     MainCheckpointConfig,
     T5EncoderConfig,
+    T5Encoder8bConfig,
     VAECheckpointConfig,
 )
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
 from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
 from invokeai.backend.util.silence_warnings import SilenceWarnings
+from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel

 app_config = get_config()
@@ -82,7 +84,7 @@ class ClipCheckpointModel(GenericDiffusersLoader):
         submodel_type: Optional[SubModelType] = None,
     ) -> AnyModel:
         if not isinstance(config, CLIPEmbedDiffusersConfig):
-            raise Exception("Only Checkpoint Flux models are currently supported.")
+            raise Exception("Only CLIPEmbedDiffusersConfig models are currently supported here.")

         match submodel_type:
             case SubModelType.Tokenizer:
@@ -90,7 +92,28 @@ class ClipCheckpointModel(GenericDiffusersLoader):
             case SubModelType.TextEncoder:
                 return CLIPTextModel.from_pretrained(config.path)

-        raise Exception("Only Checkpoint Flux models are currently supported.")
+        raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.")
+
+
+@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder8b)
+class T5Encoder8bCheckpointModel(GenericDiffusersLoader):
+    """Class to load 8-bit quantized T5 encoder models."""
+
+    def _load_model(
+        self,
+        config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if not isinstance(config, T5Encoder8bConfig):
+            raise Exception("Only T5Encoder8bConfig models are currently supported here.")
+
+        match submodel_type:
+            case SubModelType.Tokenizer2:
+                return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
+            case SubModelType.TextEncoder2:
+                return FastQuantizedTransformersModel.from_pretrained(Path(config.path) / "text_encoder_2")
+
+        raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.")


 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
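`FastQuantizedTransformersModel`, used in the `TextEncoder2` branch above, extends optimum-quanto's `QuantizedTransformersModel`, which rebuilds the quantized module structure before loading the saved weights. The base pattern looks roughly like this; `AutoModelForTextEncoding` as the `auto_class` is an assumption, and InvokeAI's loading optimizations are not shown in this diff:

```python
# Sketch of an optimum-quanto transformers wrapper (not InvokeAI's subclass).
from optimum.quanto import QuantizedTransformersModel
from transformers import AutoModelForTextEncoding  # assumed auto_class

class QuantizedTextEncoder(QuantizedTransformersModel):
    auto_class = AutoModelForTextEncoding

# encoder = QuantizedTextEncoder.from_pretrained("path/to/text_encoder_2")
```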
@@ -103,7 +126,7 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader):
         submodel_type: Optional[SubModelType] = None,
     ) -> AnyModel:
         if not isinstance(config, T5EncoderConfig):
-            raise Exception("Only Checkpoint Flux models are currently supported.")
+            raise Exception("Only T5EncoderConfig models are currently supported here.")

         match submodel_type:
             case SubModelType.Tokenizer2:
@@ -113,7 +136,7 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader):
                     Path(config.path) / "text_encoder_2"
                 )  # TODO: Fix hf subfolder install

-        raise Exception("Only Checkpoint Flux models are currently supported.")
+        raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.")


 @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
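The `# TODO: Fix hf subfolder install` comment above refers to Hugging Face's `subfolder` convention for nested checkpoints; loading straight from the hub would look like this (repo id illustrative, and this is not what the loader currently does):

```python
from transformers import T5EncoderModel

# Sketch only: from_pretrained accepts a subfolder kwarg for nested layouts.
encoder = T5EncoderModel.from_pretrained(
    "some-org/some-flux-bundle",  # illustrative repo id
    subfolder="text_encoder_2",
)
```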
@@ -126,7 +149,7 @@ class FluxCheckpointModel(GenericDiffusersLoader):
         submodel_type: Optional[SubModelType] = None,
     ) -> AnyModel:
         if not isinstance(config, CheckpointConfigBase):
-            raise Exception("Only Checkpoint Flux models are currently supported.")
+            raise Exception("Only CheckpointConfigBase models are currently supported here.")
         legacy_config_path = app_config.legacy_conf_path / config.config_path
         config_path = legacy_config_path.as_posix()
         with open(config_path, "r") as stream:
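The `with open(...)` block is cut off by the hunk boundary; in the file it parses the legacy YAML config into the `flux_conf` dict consumed by `_load_from_singlefile` in the next hunk. A sketch of that standard pattern, assuming `yaml` is imported at module top:

```python
import yaml  # assumed present at module top in the real file

with open(config_path, "r") as stream:
    flux_conf = yaml.safe_load(stream)  # dict handed to _load_from_singlefile()
```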
@@ -139,7 +162,7 @@ class FluxCheckpointModel(GenericDiffusersLoader):
             case SubModelType.Transformer:
                 return self._load_from_singlefile(config, flux_conf)

-        raise Exception("Only Checkpoint Flux models are currently supported.")
+        raise Exception("Only Transformer submodels are currently supported.")

     def _load_from_singlefile(
         self,
@@ -171,7 +194,7 @@ class FluxBnbQuantizednf4bCheckpointModel(GenericDiffusersLoader):
         submodel_type: Optional[SubModelType] = None,
     ) -> AnyModel:
         if not isinstance(config, CheckpointConfigBase):
-            raise Exception("Only Checkpoint Flux models are currently supported.")
+            raise Exception("Only CheckpointConfigBase models are currently supported here.")
         legacy_config_path = app_config.legacy_conf_path / config.config_path
         config_path = legacy_config_path.as_posix()
         with open(config_path, "r") as stream:
@@ -184,7 +207,7 @@ class FluxBnbQuantizednf4bCheckpointModel(GenericDiffusersLoader):
             case SubModelType.Transformer:
                 return self._load_from_singlefile(config, flux_conf)

-        raise Exception("Only Checkpoint Flux models are currently supported.")
+        raise Exception("Only Transformer submodels are currently supported.")

     def _load_from_singlefile(
         self,
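Unlike the qfloat8 path, this loader's transformer branch relies on bitsandbytes NF4 via the `quantize_model_nf4` helper imported earlier; its internals are outside this diff. For orientation, the generic transformers route to NF4 quantization looks like this (a sketch, not the helper's implementation):

```python
# Generic NF4 sketch via transformers + bitsandbytes; this is NOT
# invokeai.backend.quantization.bnb_nf4.quantize_model_nf4.
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",  # 4-bit NormalFloat data type
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# pass quantization_config=bnb_config to a from_pretrained() call
```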