mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add nf4 bnb quantized format
This commit is contained in:
parent
1bd90e0fd4
commit
723f3ab0a9
@ -136,12 +136,12 @@ class ModelIdentifierInvocation(BaseInvocation):
|
|||||||
T5_ENCODER_OPTIONS = Literal["base", "8b_quantized"]
|
T5_ENCODER_OPTIONS = Literal["base", "8b_quantized"]
|
||||||
T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
|
T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
|
||||||
"base": {
|
"base": {
|
||||||
"repo": "invokeai/flux_dev::t5_xxl_encoder/base",
|
"repo": "InvokeAI/flux_schnell::t5_xxl_encoder/base",
|
||||||
"name": "t5_base_encoder",
|
"name": "t5_base_encoder",
|
||||||
"format": ModelFormat.T5Encoder,
|
"format": ModelFormat.T5Encoder,
|
||||||
},
|
},
|
||||||
"8b_quantized": {
|
"8b_quantized": {
|
||||||
"repo": "invokeai/flux_dev::t5_xxl_encoder/8b_quantized",
|
"repo": "invokeai/flux_dev::t5_xxl_encoder/optimum_quanto_qfloat8",
|
||||||
"name": "t5_8b_quantized_encoder",
|
"name": "t5_8b_quantized_encoder",
|
||||||
"format": ModelFormat.T5Encoder,
|
"format": ModelFormat.T5Encoder,
|
||||||
},
|
},
|
||||||
|
@ -111,6 +111,7 @@ class ModelFormat(str, Enum):
|
|||||||
T5Encoder = "t5_encoder"
|
T5Encoder = "t5_encoder"
|
||||||
T5Encoder8b = "t5_encoder_8b"
|
T5Encoder8b = "t5_encoder_8b"
|
||||||
T5Encoder4b = "t5_encoder_4b"
|
T5Encoder4b = "t5_encoder_4b"
|
||||||
|
BnbQuantizednf4b = "bnb_quantized_nf4b"
|
||||||
|
|
||||||
|
|
||||||
class SchedulerPredictionType(str, Enum):
|
class SchedulerPredictionType(str, Enum):
|
||||||
@ -193,7 +194,7 @@ class ModelConfigBase(BaseModel):
|
|||||||
class CheckpointConfigBase(ModelConfigBase):
|
class CheckpointConfigBase(ModelConfigBase):
|
||||||
"""Model config for checkpoint-style models."""
|
"""Model config for checkpoint-style models."""
|
||||||
|
|
||||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field(description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint)
|
||||||
config_path: str = Field(description="path to the checkpoint model config file")
|
config_path: str = Field(description="path to the checkpoint model config file")
|
||||||
converted_at: Optional[float] = Field(
|
converted_at: Optional[float] = Field(
|
||||||
description="When this model was last converted to diffusers", default_factory=time.time
|
description="When this model was last converted to diffusers", default_factory=time.time
|
||||||
@ -248,7 +249,6 @@ class VAECheckpointConfig(CheckpointConfigBase):
|
|||||||
"""Model config for standalone VAE models."""
|
"""Model config for standalone VAE models."""
|
||||||
|
|
||||||
type: Literal[ModelType.VAE] = ModelType.VAE
|
type: Literal[ModelType.VAE] = ModelType.VAE
|
||||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_tag() -> Tag:
|
def get_tag() -> Tag:
|
||||||
@ -287,7 +287,6 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase)
|
|||||||
"""Model config for ControlNet models (diffusers version)."""
|
"""Model config for ControlNet models (diffusers version)."""
|
||||||
|
|
||||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_tag() -> Tag:
|
def get_tag() -> Tag:
|
||||||
@ -336,6 +335,21 @@ class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
|
|||||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
|
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
|
||||||
|
|
||||||
|
|
||||||
|
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
    """Model config for main checkpoint models stored in the bitsandbytes NF4 quantized format."""

    # Pin the format declaratively instead of overwriting it in ``__init__``
    # after validation: the declared default then agrees with the tag returned
    # by ``get_tag()``, and a config carrying any other format is rejected at
    # validation time rather than being silently rewritten.
    format: Literal[ModelFormat.BnbQuantizednf4b] = ModelFormat.BnbQuantizednf4b
    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
    upcast_attention: bool = False

    @staticmethod
    def get_tag() -> Tag:
        """Return the tag under which this config is listed in the ``AnyModelConfig`` union."""
        return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")
|
||||||
|
|
||||||
|
|
||||||
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
|
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
|
||||||
"""Model config for main diffusers models."""
|
"""Model config for main diffusers models."""
|
||||||
|
|
||||||
@ -438,6 +452,7 @@ AnyModelConfig = Annotated[
|
|||||||
Union[
|
Union[
|
||||||
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
|
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
|
||||||
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
|
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
|
||||||
|
Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
|
||||||
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
|
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
|
||||||
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
|
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
|
||||||
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
|
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
|
||||||
|
@ -162,7 +162,7 @@ class ModelProbe(object):
|
|||||||
fields["description"] = (
|
fields["description"] = (
|
||||||
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
|
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
|
||||||
)
|
)
|
||||||
fields["format"] = ModelFormat(fields.get("format")) or probe.get_format()
|
fields["format"] = ModelFormat(fields.get("format")) if "format" in fields else probe.get_format()
|
||||||
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
|
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
|
||||||
|
|
||||||
fields["default_settings"] = fields.get("default_settings")
|
fields["default_settings"] = fields.get("default_settings")
|
||||||
@ -179,7 +179,7 @@ class ModelProbe(object):
|
|||||||
# additional fields needed for main and controlnet models
|
# additional fields needed for main and controlnet models
|
||||||
if (
|
if (
|
||||||
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
|
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
|
||||||
and fields["format"] is ModelFormat.Checkpoint
|
and fields["format"] in [ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b]
|
||||||
):
|
):
|
||||||
ckpt_config_path = cls._get_checkpoint_config_path(
|
ckpt_config_path = cls._get_checkpoint_config_path(
|
||||||
model_path,
|
model_path,
|
||||||
@ -323,6 +323,7 @@ class ModelProbe(object):
|
|||||||
|
|
||||||
if model_type is ModelType.Main:
|
if model_type is ModelType.Main:
|
||||||
if base_type == BaseModelType.Flux:
|
if base_type == BaseModelType.Flux:
|
||||||
|
# TODO: Decide between dev/schnell
|
||||||
config_file = "flux/flux1-schnell.yaml"
|
config_file = "flux/flux1-schnell.yaml"
|
||||||
else:
|
else:
|
||||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||||
@ -422,6 +423,9 @@ class CheckpointProbeBase(ProbeBase):
|
|||||||
self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
|
self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
|
||||||
|
|
||||||
def get_format(self) -> ModelFormat:
    """Report the format of the loaded checkpoint.

    Looks in the checkpoint's ``"state_dict"`` entry when that entry is
    present and non-empty, otherwise in the checkpoint mapping itself.
    Returns ``ModelFormat.BnbQuantizednf4b`` if the bitsandbytes NF4
    quant-state key for the first double block is present; otherwise
    returns the plain checkpoint format.
    """
    nf4_marker = "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4"
    weights = self.checkpoint.get("state_dict") or self.checkpoint
    if nf4_marker not in weights:
        return ModelFormat("checkpoint")
    return ModelFormat.BnbQuantizednf4b
|
||||||
|
|
||||||
def get_variant_type(self) -> ModelVariantType:
|
def get_variant_type(self) -> ModelVariantType:
|
||||||
|
Loading…
Reference in New Issue
Block a user