diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py
index f408dc3e0e..a6ec64d5c7 100644
--- a/invokeai/app/invocations/model.py
+++ b/invokeai/app/invocations/model.py
@@ -136,12 +136,12 @@ class ModelIdentifierInvocation(BaseInvocation):
 T5_ENCODER_OPTIONS = Literal["base", "8b_quantized"]
 T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
     "base": {
-        "repo": "invokeai/flux_dev::t5_xxl_encoder/base",
+        "repo": "InvokeAI/flux_schnell::t5_xxl_encoder/base",
         "name": "t5_base_encoder",
         "format": ModelFormat.T5Encoder,
     },
     "8b_quantized": {
-        "repo": "invokeai/flux_dev::t5_xxl_encoder/8b_quantized",
+        "repo": "invokeai/flux_dev::t5_xxl_encoder/optimum_quanto_qfloat8",
         "name": "t5_8b_quantized_encoder",
         "format": ModelFormat.T5Encoder,
     },
diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py
index dfa6cef29b..ce6b8ed8cc 100644
--- a/invokeai/backend/model_manager/config.py
+++ b/invokeai/backend/model_manager/config.py
@@ -111,6 +111,7 @@ class ModelFormat(str, Enum):
     T5Encoder = "t5_encoder"
     T5Encoder8b = "t5_encoder_8b"
     T5Encoder4b = "t5_encoder_4b"
+    BnbQuantizednf4b = "bnb_quantized_nf4b"
 
 
 class SchedulerPredictionType(str, Enum):
@@ -193,7 +194,7 @@ class ModelConfigBase(BaseModel):
 class CheckpointConfigBase(ModelConfigBase):
     """Model config for checkpoint-style models."""
 
-    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
+    format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field(description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint)
     config_path: str = Field(description="path to the checkpoint model config file")
     converted_at: Optional[float] = Field(
         description="When this model was last converted to diffusers", default_factory=time.time
     )
@@ -248,7 +249,6 @@ class VAECheckpointConfig(CheckpointConfigBase):
     """Model config for standalone VAE models."""
 
     type: Literal[ModelType.VAE] = ModelType.VAE
-    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
 
     @staticmethod
     def get_tag() -> Tag:
@@ -287,7 +287,6 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase)
     """Model config for ControlNet models (diffusers version)."""
 
     type: Literal[ModelType.ControlNet] = ModelType.ControlNet
-    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
 
     @staticmethod
     def get_tag() -> Tag:
@@ -336,6 +335,21 @@ class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
         return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
 
 
+class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
+    """Model config for main checkpoint models."""
+
+    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
+    upcast_attention: bool = False
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.format = ModelFormat.BnbQuantizednf4b
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")
+
+
 class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
     """Model config for main diffusers models."""
 
@@ -438,6 +452,7 @@ AnyModelConfig = Annotated[
     Union[
         Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
         Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
+        Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
         Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
         Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
         Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py
index fcb4e9b2f0..dbc2275d85 100644
--- a/invokeai/backend/model_manager/probe.py
+++ b/invokeai/backend/model_manager/probe.py
@@ -162,7 +162,7 @@ class ModelProbe(object):
         fields["description"] = (
             fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
         )
-        fields["format"] = ModelFormat(fields.get("format")) or probe.get_format()
+        fields["format"] = ModelFormat(fields.get("format")) if "format" in fields else probe.get_format()
         fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
 
         fields["default_settings"] = fields.get("default_settings")
@@ -179,7 +179,7 @@ class ModelProbe(object):
         # additional fields needed for main and controlnet models
         if (
             fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
-            and fields["format"] is ModelFormat.Checkpoint
+            and fields["format"] in [ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b]
         ):
             ckpt_config_path = cls._get_checkpoint_config_path(
                 model_path,
@@ -323,6 +323,7 @@ class ModelProbe(object):
 
         if model_type is ModelType.Main:
             if base_type == BaseModelType.Flux:
+                # TODO: Decide between dev/schnell
                 config_file = "flux/flux1-schnell.yaml"
             else:
                 config_file = LEGACY_CONFIGS[base_type][variant_type]
@@ -422,6 +423,9 @@ class CheckpointProbeBase(ProbeBase):
         self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
 
     def get_format(self) -> ModelFormat:
+        state_dict = self.checkpoint.get("state_dict") or self.checkpoint
+        if "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict:
+            return ModelFormat.BnbQuantizednf4b
        return ModelFormat("checkpoint")
 
     def get_variant_type(self) -> ModelVariantType:
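
For context on the probe change: the new `CheckpointProbeBase.get_format()` logic keys entirely off one bitsandbytes-serialized state-dict entry for the first FLUX double-stream block. Below is a minimal, self-contained sketch of that heuristic; the helper name and the sample dicts are illustrative only and are not part of this PR.

```python
# Sketch of the bnb-NF4 detection heuristic added in probe.py (names are hypothetical).
from typing import Any, Dict

BNB_NF4_MARKER = "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4"


def sniff_checkpoint_format(checkpoint: Dict[str, Any]) -> str:
    # Some checkpoints nest tensors under "state_dict"; otherwise use the top level.
    state_dict = checkpoint.get("state_dict") or checkpoint
    if BNB_NF4_MARKER in state_dict:
        return "bnb_quantized_nf4b"
    return "checkpoint"


# bitsandbytes NF4 serialization stores packed weights plus per-tensor quant_state
# entries, so the marker key is present only in a quantized FLUX transformer.
fake_quantized = {
    "double_blocks.0.img_attn.proj.weight": b"<packed uint8 data>",
    BNB_NF4_MARKER: b"<serialized quant state>",
}
assert sniff_checkpoint_format(fake_quantized) == "bnb_quantized_nf4b"
assert sniff_checkpoint_format({"model.diffusion_model.foo": 0.0}) == "checkpoint"
```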
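On the config side, `AnyModelConfig` is a tagged union, so each config class must report a distinct `"{type}.{format}"` tag; that is why `MainBnbQuantized4bCheckpointConfig` adds a new `get_tag()` and forces `self.format` in `__init__`. The toy Pydantic v2 sketch below shows the pattern under simplified assumptions: the class names, fields, and discriminator function here are stand-ins, since the diff does not show the real discriminator used by `AnyModelConfig`.

```python
# Hypothetical, simplified illustration of a "{type}.{format}" tagged union in Pydantic v2.
from typing import Annotated, Any, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter


def config_tag(v: Any) -> str:
    # Build the discriminator tag from the type and format fields (dict or model input).
    if isinstance(v, dict):
        return f"{v.get('type')}.{v.get('format')}"
    return f"{getattr(v, 'type', None)}.{getattr(v, 'format', None)}"


class MainCheckpoint(BaseModel):
    type: str = "main"
    format: str = "checkpoint"
    path: str


class MainBnbNF4(BaseModel):
    type: str = "main"
    format: str = "bnb_quantized_nf4b"
    path: str


AnyConfig = Annotated[
    Union[
        Annotated[MainCheckpoint, Tag("main.checkpoint")],
        Annotated[MainBnbNF4, Tag("main.bnb_quantized_nf4b")],
    ],
    Discriminator(config_tag),
]

adapter = TypeAdapter(AnyConfig)
cfg = adapter.validate_python(
    {"type": "main", "format": "bnb_quantized_nf4b", "path": "/models/flux-nf4.safetensors"}
)
print(type(cfg).__name__)  # MainBnbNF4
```

Because the tag is derived from the serialized `format` value, overriding `self.format` in the quantized config's `__init__` keeps stored configs round-tripping to the correct class.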