mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Only import bnb quantize file if bitsandbytes is installed
This commit is contained in:
parent
6764dcfdaa
commit
1047584b3e
@ -33,10 +33,16 @@ from invokeai.backend.model_manager.config import (
|
|||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||||
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
|
||||||
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
|
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
|
||||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||||
|
|
||||||
|
try:
|
||||||
|
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
||||||
|
|
||||||
|
bnb_nf4_available = True
|
||||||
|
except ImportError:
|
||||||
|
bnb_nf4_available = False
|
||||||
|
|
||||||
app_config = get_config()
|
app_config = get_config()
|
||||||
|
|
||||||
|
|
||||||
@ -213,6 +219,10 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
|
|||||||
flux_conf: Any,
|
flux_conf: Any,
|
||||||
) -> AnyModel:
|
) -> AnyModel:
|
||||||
assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
|
assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
|
||||||
|
if not bnb_nf4_available:
|
||||||
|
raise ImportError(
|
||||||
|
"The bnb_nf4 module is not available. Please install bitsandbytes if available on your platform."
|
||||||
|
)
|
||||||
model_path = Path(config.path)
|
model_path = Path(config.path)
|
||||||
dataclass_fields = {f.name for f in fields(FluxParams)}
|
dataclass_fields = {f.name for f in fields(FluxParams)}
|
||||||
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
|
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
|
||||||
|
Loading…
Reference in New Issue
Block a user