From edac01d4fb8e921b620147efb5bb067c87422229 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 16 Apr 2024 18:13:59 -0400 Subject: [PATCH] reverse stupid hack --- .../app/services/model_install/model_install_default.py | 3 +-- invokeai/backend/util/devices.py | 8 +------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 6a3117bcb8..32b31f744c 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -43,7 +43,6 @@ from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMet from invokeai.backend.model_manager.probe import ModelProbe from invokeai.backend.model_manager.search import ModelSearch from invokeai.backend.util import InvokeAILogger -from invokeai.backend.util.devices import TorchDevice from .model_install_base import ( MODEL_SOURCE_TO_TYPE_MAP, @@ -637,7 +636,7 @@ class ModelInstallService(ModelInstallServiceBase): def _guess_variant(self) -> Optional[ModelRepoVariant]: """Guess the best HuggingFace variant type to download.""" - precision = TorchDevice.choose_torch_dtype() + precision = torch.float16 if self._app_config.precision == "auto" else torch.dtype(self._app_config.precision) return ModelRepoVariant.FP16 if precision == torch.float16 else None def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob: diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index c7db33a667..745c128099 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -59,13 +59,7 @@ class TorchDevice: def choose_torch_device(cls) -> torch.device: """Return the torch.device to use for accelerated inference.""" if cls._model_cache: - try: - return cls._model_cache.get_execution_device() - except ValueError as e: # May happen if no gpu was reserved. Return a generic device. - if str(e).startswith("No GPU has been reserved"): - pass - else: - raise e + return cls._model_cache.get_execution_device() app_config = get_config() if app_config.device != "auto": device = torch.device(app_config.device)