Add nf4 bnb quantized format

This commit is contained in:
Brandon Rising 2024-08-19 12:08:24 -04:00 committed by Brandon
parent 1bd90e0fd4
commit 723f3ab0a9
3 changed files with 26 additions and 7 deletions

View File

@ -136,12 +136,12 @@ class ModelIdentifierInvocation(BaseInvocation):
T5_ENCODER_OPTIONS = Literal["base", "8b_quantized"] T5_ENCODER_OPTIONS = Literal["base", "8b_quantized"]
T5_ENCODER_MAP: Dict[str, Dict[str, str]] = { T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
"base": { "base": {
"repo": "invokeai/flux_dev::t5_xxl_encoder/base", "repo": "InvokeAI/flux_schnell::t5_xxl_encoder/base",
"name": "t5_base_encoder", "name": "t5_base_encoder",
"format": ModelFormat.T5Encoder, "format": ModelFormat.T5Encoder,
}, },
"8b_quantized": { "8b_quantized": {
"repo": "invokeai/flux_dev::t5_xxl_encoder/8b_quantized", "repo": "invokeai/flux_dev::t5_xxl_encoder/optimum_quanto_qfloat8",
"name": "t5_8b_quantized_encoder", "name": "t5_8b_quantized_encoder",
"format": ModelFormat.T5Encoder, "format": ModelFormat.T5Encoder,
}, },

View File

@ -111,6 +111,7 @@ class ModelFormat(str, Enum):
T5Encoder = "t5_encoder" T5Encoder = "t5_encoder"
T5Encoder8b = "t5_encoder_8b" T5Encoder8b = "t5_encoder_8b"
T5Encoder4b = "t5_encoder_4b" T5Encoder4b = "t5_encoder_4b"
BnbQuantizednf4b = "bnb_quantized_nf4b"
class SchedulerPredictionType(str, Enum): class SchedulerPredictionType(str, Enum):
@ -193,7 +194,7 @@ class ModelConfigBase(BaseModel):
class CheckpointConfigBase(ModelConfigBase): class CheckpointConfigBase(ModelConfigBase):
"""Model config for checkpoint-style models.""" """Model config for checkpoint-style models."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field(description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint)
config_path: str = Field(description="path to the checkpoint model config file") config_path: str = Field(description="path to the checkpoint model config file")
converted_at: Optional[float] = Field( converted_at: Optional[float] = Field(
description="When this model was last converted to diffusers", default_factory=time.time description="When this model was last converted to diffusers", default_factory=time.time
@ -248,7 +249,6 @@ class VAECheckpointConfig(CheckpointConfigBase):
"""Model config for standalone VAE models.""" """Model config for standalone VAE models."""
type: Literal[ModelType.VAE] = ModelType.VAE type: Literal[ModelType.VAE] = ModelType.VAE
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod @staticmethod
def get_tag() -> Tag: def get_tag() -> Tag:
@ -287,7 +287,6 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase)
"""Model config for ControlNet models (diffusers version).""" """Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod @staticmethod
def get_tag() -> Tag: def get_tag() -> Tag:
@ -336,6 +335,21 @@ class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}") return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
    """Model config for main checkpoint models stored in the bitsandbytes NF4 quantized format.

    The ``format`` attribute is always overwritten with
    ``ModelFormat.BnbQuantizednf4b`` once base-class initialisation has
    finished, whatever ``format`` value was passed to the constructor.
    """

    # Defaults to epsilon prediction.
    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
    # Defaults to False.
    upcast_attention: bool = False

    def __init__(self, *args, **kwargs):
        # Run the inherited initialisation first so that all fields exist,
        # then pin the format: the inherited default is the plain checkpoint
        # format, which would not match this class's tag.
        super().__init__(*args, **kwargs)
        self.format = ModelFormat.BnbQuantizednf4b

    @staticmethod
    def get_tag() -> Tag:
        """Return the ``Tag`` for this config: ``"<ModelType.Main value>.<BnbQuantizednf4b value>"``."""
        return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase): class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
"""Model config for main diffusers models.""" """Model config for main diffusers models."""
@ -438,6 +452,7 @@ AnyModelConfig = Annotated[
Union[ Union[
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()], Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()], Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()], Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()], Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()], Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],

View File

@ -162,7 +162,7 @@ class ModelProbe(object):
fields["description"] = ( fields["description"] = (
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}" fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
) )
fields["format"] = ModelFormat(fields.get("format")) or probe.get_format() fields["format"] = ModelFormat(fields.get("format")) if "format" in fields else probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path) fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
fields["default_settings"] = fields.get("default_settings") fields["default_settings"] = fields.get("default_settings")
@ -179,7 +179,7 @@ class ModelProbe(object):
# additional fields needed for main and controlnet models # additional fields needed for main and controlnet models
if ( if (
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
and fields["format"] is ModelFormat.Checkpoint and fields["format"] in [ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b]
): ):
ckpt_config_path = cls._get_checkpoint_config_path( ckpt_config_path = cls._get_checkpoint_config_path(
model_path, model_path,
@ -323,6 +323,7 @@ class ModelProbe(object):
if model_type is ModelType.Main: if model_type is ModelType.Main:
if base_type == BaseModelType.Flux: if base_type == BaseModelType.Flux:
# TODO: Decide between dev/schnell
config_file = "flux/flux1-schnell.yaml" config_file = "flux/flux1-schnell.yaml"
else: else:
config_file = LEGACY_CONFIGS[base_type][variant_type] config_file = LEGACY_CONFIGS[base_type][variant_type]
@ -422,6 +423,9 @@ class CheckpointProbeBase(ProbeBase):
self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path) self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
def get_format(self) -> ModelFormat:
    """Report the storage format of the loaded checkpoint.

    The weights are read from the checkpoint's ``"state_dict"`` entry when
    that entry is present and truthy; otherwise the checkpoint mapping
    itself is inspected. If the weights contain the bitsandbytes NF4
    quant-state marker key, the NF4 quantized format is returned; in every
    other case the plain checkpoint format is returned.
    """
    # Key whose presence marks a bitsandbytes NF4-quantized checkpoint.
    nf4_marker = "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4"
    weights = self.checkpoint.get("state_dict") or self.checkpoint
    return ModelFormat.BnbQuantizednf4b if nf4_marker in weights else ModelFormat("checkpoint")
def get_variant_type(self) -> ModelVariantType: def get_variant_type(self) -> ModelVariantType: