From 84f5cbdd9775b2b05acfbad954fff4a476ed92ba Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Tue, 16 Apr 2024 19:19:19 -0400
Subject: [PATCH] make choose_torch_dtype() usable outside an invocation
 context

---
 .../app/services/model_install/model_install_default.py | 3 ++-
 invokeai/backend/util/devices.py                         | 7 ++++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py
index 32b31f744c..6a3117bcb8 100644
--- a/invokeai/app/services/model_install/model_install_default.py
+++ b/invokeai/app/services/model_install/model_install_default.py
@@ -43,6 +43,7 @@ from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMet
 from invokeai.backend.model_manager.probe import ModelProbe
 from invokeai.backend.model_manager.search import ModelSearch
 from invokeai.backend.util import InvokeAILogger
+from invokeai.backend.util.devices import TorchDevice

 from .model_install_base import (
     MODEL_SOURCE_TO_TYPE_MAP,
@@ -636,7 +637,7 @@ class ModelInstallService(ModelInstallServiceBase):

     def _guess_variant(self) -> Optional[ModelRepoVariant]:
         """Guess the best HuggingFace variant type to download."""
-        precision = torch.float16 if self._app_config.precision == "auto" else torch.dtype(self._app_config.precision)
+        precision = TorchDevice.choose_torch_dtype()
         return ModelRepoVariant.FP16 if precision == torch.float16 else None

     def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py
index 745c128099..dc2bafaa9c 100644
--- a/invokeai/backend/util/devices.py
+++ b/invokeai/backend/util/devices.py
@@ -60,6 +60,11 @@ class TorchDevice:
         """Return the torch.device to use for accelerated inference."""
         if cls._model_cache:
             return cls._model_cache.get_execution_device()
+        else:
+            return cls._choose_device()
+
+    @classmethod
+    def _choose_device(cls) -> torch.device:
         app_config = get_config()
         if app_config.device != "auto":
             device = torch.device(app_config.device)
@@ -82,8 +87,8 @@ class TorchDevice:
     @classmethod
     def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
         """Return the precision to use for accelerated inference."""
-        device = device or cls.choose_torch_device()
         config = get_config()
+        device = device or cls._choose_device()
         if device.type == "cuda" and torch.cuda.is_available():
             device_name = torch.cuda.get_device_name(device)
             if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
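
Usage note (outside the diff): with this change, choose_torch_dtype() resolves
its device through the new _choose_device() classmethod, which reads the app
config directly instead of asking the model cache for an execution device, so
it can be called from code that runs outside an invocation context, such as
the installer's _guess_variant() above. A minimal sketch of such a caller
follows; the script itself is hypothetical and only assumes what the patch
shows (TorchDevice.choose_torch_dtype() taking no required arguments):

    # hypothetical standalone script, not part of this patch
    import torch

    from invokeai.backend.util.devices import TorchDevice

    # No invocation context or model cache has been set up here; with this
    # patch, choose_torch_dtype() falls back to _choose_device(), which
    # consults get_config() rather than the model cache.
    dtype = TorchDevice.choose_torch_dtype()
    print(f"inference dtype outside an invocation context: {dtype}")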