diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py index e37b12c4f7..58b4843395 100644 --- a/invokeai/backend/model_manager/load/model_loaders/flux.py +++ b/invokeai/backend/model_manager/load/model_loaders/flux.py @@ -33,10 +33,16 @@ from invokeai.backend.model_manager.config import ( ) from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry -from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel from invokeai.backend.util.silence_warnings import SilenceWarnings +try: + from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 + + bnb_nf4_available = True +except ImportError: + bnb_nf4_available = False + app_config = get_config() @@ -213,6 +219,10 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader): flux_conf: Any, ) -> AnyModel: assert isinstance(config, MainBnbQuantized4bCheckpointConfig) + if not bnb_nf4_available: + raise ImportError( + "The bnb_nf4 module is not available. Please install bitsandbytes if available on your platform." + ) model_path = Path(config.path) dataclass_fields = {f.name for f in fields(FluxParams)} filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}