Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Fix support for 8b quantized t5 encoders, update exception messages in flux loaders

This commit is contained in:
parent 120e1cf1e9
commit da766f5a7e
@@ -143,9 +143,9 @@ T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
         "format": ModelFormat.T5Encoder,
     },
     "8b_quantized": {
-        "repo": "invokeai/flux_dev::t5_xxl_encoder/optimum_quanto_qfloat8",
+        "repo": "invokeai/flux_schnell::t5_xxl_encoder/optimum_quanto_qfloat8",
         "name": "t5_8b_quantized_encoder",
-        "format": ModelFormat.T5Encoder,
+        "format": ModelFormat.T5Encoder8b,
     },
 }
 
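The "8b_quantized" starter entry now points at a qfloat8 checkpoint under the flux_schnell repo and tags it with the new T5Encoder8b format. For context, here is a minimal sketch of how such an optimum-quanto qfloat8 T5 encoder can be produced; the base repo id and the omitted save step are assumptions, not part of this commit:

from optimum.quanto import freeze, qfloat8, quantize
from transformers import T5EncoderModel

# Assumed base model; FLUX uses a T5-XXL text encoder.
model = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl")
quantize(model, weights=qfloat8)  # swap Linear weights for qfloat8 tensors
freeze(model)                     # materialize the quantized weights
# Serialization into the text_encoder_2/ layout expected by the loaders below
# happens elsewhere and is not shown here.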
@@ -225,6 +225,14 @@ class T5EncoderConfig(T5EncoderConfigBase):
         return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.T5Encoder.value}")
 
 
+class T5Encoder8bConfig(T5EncoderConfigBase):
+    format: Literal[ModelFormat.T5Encoder8b] = ModelFormat.T5Encoder8b
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.T5Encoder8b.value}")
+
+
 class LoRALyCORISConfig(LoRAConfigBase):
     """Model config for LoRA/Lycoris models."""
 
@@ -460,6 +468,7 @@ AnyModelConfig = Annotated[
         Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
         Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
         Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
+        Annotated[T5Encoder8bConfig, T5Encoder8bConfig.get_tag()],
         Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
         Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
         Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
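T5Encoder8bConfig joins the AnyModelConfig discriminated union through the same Tag mechanism as the other configs. Below is a standalone sketch of that pydantic pattern, using simplified tag strings rather than InvokeAI's real ModelType/ModelFormat values:

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter


class T5EncoderCfg(BaseModel):
    format: Literal["t5_encoder"] = "t5_encoder"


class T5Encoder8bCfg(BaseModel):
    format: Literal["t5_encoder_8b"] = "t5_encoder_8b"


def discriminate(value) -> str:
    # Stand-in for get_tag()/the union discriminator: route on the format field.
    return value["format"] if isinstance(value, dict) else value.format


AnyCfg = Annotated[
    Union[
        Annotated[T5EncoderCfg, Tag("t5_encoder")],
        Annotated[T5Encoder8bCfg, Tag("t5_encoder_8b")],
    ],
    Discriminator(discriminate),
]

# Records tagged with the 8b format now validate to the new config class.
cfg = TypeAdapter(AnyCfg).validate_python({"format": "t5_encoder_8b"})
assert isinstance(cfg, T5Encoder8bCfg)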
@@ -28,12 +28,14 @@ from invokeai.backend.model_manager.config import (
     MainBnbQuantized4bCheckpointConfig,
     MainCheckpointConfig,
     T5EncoderConfig,
+    T5Encoder8bConfig,
     VAECheckpointConfig,
 )
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
 from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
 from invokeai.backend.util.silence_warnings import SilenceWarnings
+from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
 
 app_config = get_config()
 
@@ -82,7 +84,7 @@ class ClipCheckpointModel(GenericDiffusersLoader):
         submodel_type: Optional[SubModelType] = None,
     ) -> AnyModel:
         if not isinstance(config, CLIPEmbedDiffusersConfig):
-            raise Exception("Only Checkpoint Flux models are currently supported.")
+            raise Exception("Only CLIPEmbedDiffusersConfig models are currently supported here.")
 
         match submodel_type:
             case SubModelType.Tokenizer:
@@ -90,7 +92,28 @@ class ClipCheckpointModel(GenericDiffusersLoader):
             case SubModelType.TextEncoder:
                 return CLIPTextModel.from_pretrained(config.path)
 
-        raise Exception("Only Checkpoint Flux models are currently supported.")
+        raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.")
 
 
+@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder8b)
+class T5Encoder8bCheckpointModel(GenericDiffusersLoader):
+    """Class to load main models."""
+
+    def _load_model(
+        self,
+        config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if not isinstance(config, T5Encoder8bConfig):
+            raise Exception("Only T5Encoder8bConfig models are currently supported here.")
+
+        match submodel_type:
+            case SubModelType.Tokenizer2:
+                return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
+            case SubModelType.TextEncoder2:
+                return FastQuantizedTransformersModel.from_pretrained(Path(config.path) / "text_encoder_2")
+
+        raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.")
+
+
 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
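The decorator registers the new loader for the (Any, T5Encoder, T5Encoder8b) combination, so config records carrying the new format resolve to it instead of the plain T5Encoder loader below. A standalone sketch of that registration/dispatch pattern (toy registry, not InvokeAI's ModelLoaderRegistry):

from typing import Callable, Dict, Tuple, Type

_REGISTRY: Dict[Tuple[str, str, str], Type] = {}


def register(base: str, type: str, format: str) -> Callable[[Type], Type]:
    # Map a (base, type, format) triple to the decorated loader class.
    def wrap(cls: Type) -> Type:
        _REGISTRY[(base, type, format)] = cls
        return cls

    return wrap


@register(base="any", type="t5_encoder", format="t5_encoder_8b")
class Toy8bLoader:
    def load(self, submodel: str) -> str:
        return f"loaded {submodel} with the 8b loader"


# Dispatch: the fields of a config record select the loader implementation.
loader_cls = _REGISTRY[("any", "t5_encoder", "t5_encoder_8b")]
print(loader_cls().load("text_encoder_2"))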
@@ -103,7 +126,7 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader):
         submodel_type: Optional[SubModelType] = None,
     ) -> AnyModel:
         if not isinstance(config, T5EncoderConfig):
-            raise Exception("Only Checkpoint Flux models are currently supported.")
+            raise Exception("Only T5EncoderConfig models are currently supported here.")
 
         match submodel_type:
             case SubModelType.Tokenizer2:
@@ -113,7 +136,7 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader):
                     Path(config.path) / "text_encoder_2"
                 )  # TODO: Fix hf subfolder install
 
-        raise Exception("Only Checkpoint Flux models are currently supported.")
+        raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.")
 
 
 @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
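The "TODO: Fix hf subfolder install" comment sits on a load that joins config.path with text_encoder_2 by hand. One possible direction (an assumption, not necessarily what the TODO intends) is transformers' subfolder argument:

from pathlib import Path

from transformers import T5EncoderModel

model_dir = Path("/models/t5_xxl_encoder")  # hypothetical install location
# Equivalent to from_pretrained(model_dir / "text_encoder_2"), but lets
# transformers resolve the nested folder itself.
encoder = T5EncoderModel.from_pretrained(model_dir, subfolder="text_encoder_2")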
@@ -126,7 +149,7 @@ class FluxCheckpointModel(GenericDiffusersLoader):
         submodel_type: Optional[SubModelType] = None,
     ) -> AnyModel:
         if not isinstance(config, CheckpointConfigBase):
-            raise Exception("Only Checkpoint Flux models are currently supported.")
+            raise Exception("Only CheckpointConfigBase models are currently supported here.")
         legacy_config_path = app_config.legacy_conf_path / config.config_path
         config_path = legacy_config_path.as_posix()
         with open(config_path, "r") as stream:
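This hunk ends just before the legacy config is parsed into flux_conf; the parse itself is outside the diff. A plausible reading of that step, with the file name and yaml.safe_load both being assumptions:

import yaml

config_path = "flux-dev.yaml"  # hypothetical legacy config file
with open(config_path, "r") as stream:
    flux_conf = yaml.safe_load(stream)  # feeds the submodel match further down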
@@ -139,7 +162,7 @@ class FluxCheckpointModel(GenericDiffusersLoader):
             case SubModelType.Transformer:
                 return self._load_from_singlefile(config, flux_conf)
 
-        raise Exception("Only Checkpoint Flux models are currently supported.")
+        raise Exception("Only Transformer submodels are currently supported.")
 
     def _load_from_singlefile(
         self,
@@ -171,7 +194,7 @@ class FluxBnbQuantizednf4bCheckpointModel(GenericDiffusersLoader):
         submodel_type: Optional[SubModelType] = None,
     ) -> AnyModel:
         if not isinstance(config, CheckpointConfigBase):
-            raise Exception("Only Checkpoint Flux models are currently supported.")
+            raise Exception("Only CheckpointConfigBase models are currently supported here.")
         legacy_config_path = app_config.legacy_conf_path / config.config_path
         config_path = legacy_config_path.as_posix()
         with open(config_path, "r") as stream:
@@ -184,7 +207,7 @@ class FluxBnbQuantizednf4bCheckpointModel(GenericDiffusersLoader):
             case SubModelType.Transformer:
                 return self._load_from_singlefile(config, flux_conf)
 
-        raise Exception("Only Checkpoint Flux models are currently supported.")
+        raise Exception("Only Transformer submodels are currently supported.")
 
     def _load_from_singlefile(
         self,
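For context on what the tokenizer_2/text_encoder_2 pair returned by the T5 loaders is for: FLUX conditioning uses the T5 hidden states for the prompt. A generic transformers sketch (not InvokeAI invocation code; the repo id is an assumption, the 512-token length follows the loader above):

import torch
from transformers import T5EncoderModel, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl")

tokens = tokenizer(
    "a photo of an astronaut riding a horse",
    return_tensors="pt",
    padding="max_length",
    max_length=512,
    truncation=True,
)
with torch.no_grad():
    # (1, 512, hidden_size) hidden states used as FLUX text conditioning.
    prompt_embeds = encoder(tokens.input_ids).last_hidden_state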