mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
reverse stupid hack
This commit is contained in:
parent
d04c880cce
commit
edac01d4fb
@ -43,7 +43,6 @@ from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMet
|
||||
from invokeai.backend.model_manager.probe import ModelProbe
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.util import InvokeAILogger
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from .model_install_base import (
|
||||
MODEL_SOURCE_TO_TYPE_MAP,
|
||||
@ -637,7 +636,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
def _guess_variant(self) -> Optional[ModelRepoVariant]:
|
||||
"""Guess the best HuggingFace variant type to download."""
|
||||
precision = TorchDevice.choose_torch_dtype()
|
||||
precision = torch.float16 if self._app_config.precision == "auto" else torch.dtype(self._app_config.precision)
|
||||
return ModelRepoVariant.FP16 if precision == torch.float16 else None
|
||||
|
||||
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
|
@ -59,13 +59,7 @@ class TorchDevice:
|
||||
def choose_torch_device(cls) -> torch.device:
|
||||
"""Return the torch.device to use for accelerated inference."""
|
||||
if cls._model_cache:
|
||||
try:
|
||||
return cls._model_cache.get_execution_device()
|
||||
except ValueError as e: # May happen if no gpu was reserved. Return a generic device.
|
||||
if str(e).startswith("No GPU has been reserved"):
|
||||
pass
|
||||
else:
|
||||
raise e
|
||||
return cls._model_cache.get_execution_device()
|
||||
app_config = get_config()
|
||||
if app_config.device != "auto":
|
||||
device = torch.device(app_config.device)
|
||||
|
Loading…
Reference in New Issue
Block a user