reverse stupid hack

This commit is contained in:
Lincoln Stein 2024-04-16 18:13:59 -04:00
parent d04c880cce
commit edac01d4fb
2 changed files with 2 additions and 9 deletions

View File

@ -43,7 +43,6 @@ from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMet
from invokeai.backend.model_manager.probe import ModelProbe from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import InvokeAILogger from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.devices import TorchDevice
from .model_install_base import ( from .model_install_base import (
MODEL_SOURCE_TO_TYPE_MAP, MODEL_SOURCE_TO_TYPE_MAP,
@ -637,7 +636,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _guess_variant(self) -> Optional[ModelRepoVariant]: def _guess_variant(self) -> Optional[ModelRepoVariant]:
"""Guess the best HuggingFace variant type to download.""" """Guess the best HuggingFace variant type to download."""
precision = TorchDevice.choose_torch_dtype() precision = torch.float16 if self._app_config.precision == "auto" else torch.dtype(self._app_config.precision)
return ModelRepoVariant.FP16 if precision == torch.float16 else None return ModelRepoVariant.FP16 if precision == torch.float16 else None
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob: def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:

View File

@ -59,13 +59,7 @@ class TorchDevice:
def choose_torch_device(cls) -> torch.device: def choose_torch_device(cls) -> torch.device:
"""Return the torch.device to use for accelerated inference.""" """Return the torch.device to use for accelerated inference."""
if cls._model_cache: if cls._model_cache:
try:
return cls._model_cache.get_execution_device() return cls._model_cache.get_execution_device()
except ValueError as e: # May happen if no gpu was reserved. Return a generic device.
if str(e).startswith("No GPU has been reserved"):
pass
else:
raise e
app_config = get_config() app_config = get_config()
if app_config.device != "auto": if app_config.device != "auto":
device = torch.device(app_config.device) device = torch.device(app_config.device)