mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add nf4 bnb quantized format
This commit is contained in:
parent
1bd90e0fd4
commit
723f3ab0a9
@ -136,12 +136,12 @@ class ModelIdentifierInvocation(BaseInvocation):
|
||||
T5_ENCODER_OPTIONS = Literal["base", "8b_quantized"]
|
||||
T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
|
||||
"base": {
|
||||
"repo": "invokeai/flux_dev::t5_xxl_encoder/base",
|
||||
"repo": "InvokeAI/flux_schnell::t5_xxl_encoder/base",
|
||||
"name": "t5_base_encoder",
|
||||
"format": ModelFormat.T5Encoder,
|
||||
},
|
||||
"8b_quantized": {
|
||||
"repo": "invokeai/flux_dev::t5_xxl_encoder/8b_quantized",
|
||||
"repo": "invokeai/flux_dev::t5_xxl_encoder/optimum_quanto_qfloat8",
|
||||
"name": "t5_8b_quantized_encoder",
|
||||
"format": ModelFormat.T5Encoder,
|
||||
},
|
||||
|
@ -111,6 +111,7 @@ class ModelFormat(str, Enum):
|
||||
T5Encoder = "t5_encoder"
|
||||
T5Encoder8b = "t5_encoder_8b"
|
||||
T5Encoder4b = "t5_encoder_4b"
|
||||
BnbQuantizednf4b = "bnb_quantized_nf4b"
|
||||
|
||||
|
||||
class SchedulerPredictionType(str, Enum):
|
||||
@ -193,7 +194,7 @@ class ModelConfigBase(BaseModel):
|
||||
class CheckpointConfigBase(ModelConfigBase):
|
||||
"""Model config for checkpoint-style models."""
|
||||
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field(description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint)
|
||||
config_path: str = Field(description="path to the checkpoint model config file")
|
||||
converted_at: Optional[float] = Field(
|
||||
description="When this model was last converted to diffusers", default_factory=time.time
|
||||
@ -248,7 +249,6 @@ class VAECheckpointConfig(CheckpointConfigBase):
|
||||
"""Model config for standalone VAE models."""
|
||||
|
||||
type: Literal[ModelType.VAE] = ModelType.VAE
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
@ -287,7 +287,6 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase)
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
@ -336,6 +335,21 @@ class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
|
||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
|
||||
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
    """Model config for main checkpoint models in the bnb NF4-quantized format."""

    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
    upcast_attention: bool = False

    def __init__(self, *args, **kwargs):
        """Build the config, then pin ``format`` to ``BnbQuantizednf4b``."""
        super().__init__(*args, **kwargs)
        # Set after base initialisation so the instance always reports the
        # quantized format, whatever value was supplied or defaulted.
        self.format = ModelFormat.BnbQuantizednf4b

    @staticmethod
    def get_tag() -> Tag:
        """Return the tag for this config: "<main type>.<bnb nf4 format>"."""
        type_part = ModelType.Main.value
        format_part = ModelFormat.BnbQuantizednf4b.value
        return Tag(f"{type_part}.{format_part}")
|
||||
|
||||
|
||||
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
|
||||
"""Model config for main diffusers models."""
|
||||
|
||||
@ -438,6 +452,7 @@ AnyModelConfig = Annotated[
|
||||
Union[
|
||||
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
|
||||
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
|
||||
Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
|
||||
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
|
||||
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
|
||||
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
|
||||
|
@ -162,7 +162,7 @@ class ModelProbe(object):
|
||||
fields["description"] = (
|
||||
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
|
||||
)
|
||||
fields["format"] = ModelFormat(fields.get("format")) or probe.get_format()
|
||||
fields["format"] = ModelFormat(fields.get("format")) if "format" in fields else probe.get_format()
|
||||
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
|
||||
|
||||
fields["default_settings"] = fields.get("default_settings")
|
||||
@ -179,7 +179,7 @@ class ModelProbe(object):
|
||||
# additional fields needed for main, controlnet and VAE models
|
||||
if (
|
||||
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
|
||||
and fields["format"] is ModelFormat.Checkpoint
|
||||
and fields["format"] in [ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b]
|
||||
):
|
||||
ckpt_config_path = cls._get_checkpoint_config_path(
|
||||
model_path,
|
||||
@ -323,6 +323,7 @@ class ModelProbe(object):
|
||||
|
||||
if model_type is ModelType.Main:
|
||||
if base_type == BaseModelType.Flux:
|
||||
# TODO: Decide between dev/schnell
|
||||
config_file = "flux/flux1-schnell.yaml"
|
||||
else:
|
||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||
@ -422,6 +423,9 @@ class CheckpointProbeBase(ProbeBase):
|
||||
self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
|
||||
|
||||
def get_format(self) -> ModelFormat:
    """Return the model format of the loaded checkpoint.

    The weights are taken from the checkpoint's "state_dict" entry when that
    entry is present and truthy; otherwise the checkpoint mapping itself is
    used. If the weights contain the bitsandbytes NF4 quant-state key, the
    format is ``BnbQuantizednf4b``; in every other case it is "checkpoint".
    """
    nf4_marker = "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4"

    weights = self.checkpoint.get("state_dict")
    if not weights:
        # No nested state dict (or an empty one): the checkpoint is the weights.
        weights = self.checkpoint

    if nf4_marker not in weights:
        return ModelFormat("checkpoint")
    return ModelFormat.BnbQuantizednf4b
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
|
Loading…
Reference in New Issue
Block a user