mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

commit 84f5cbdd97 (parent edac01d4fb)
make choose_torch_dtype() usable outside an invocation context

@@ -43,6 +43,7 @@ from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMet
 from invokeai.backend.model_manager.probe import ModelProbe
 from invokeai.backend.model_manager.search import ModelSearch
 from invokeai.backend.util import InvokeAILogger
+from invokeai.backend.util.devices import TorchDevice
 from .model_install_base import (
     MODEL_SOURCE_TO_TYPE_MAP,
@@ -636,7 +637,7 @@ class ModelInstallService(ModelInstallServiceBase):

     def _guess_variant(self) -> Optional[ModelRepoVariant]:
         """Guess the best HuggingFace variant type to download."""
-        precision = torch.float16 if self._app_config.precision == "auto" else torch.dtype(self._app_config.precision)
+        precision = TorchDevice.choose_torch_dtype()
         return ModelRepoVariant.FP16 if precision == torch.float16 else None

     def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
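
The _guess_variant hunk replaces the service's hand-rolled precision lookup with the shared helper. Notably, the removed branch's torch.dtype(self._app_config.precision) would raise at runtime, since torch.dtype instances cannot be constructed from a string, so delegating to TorchDevice.choose_torch_dtype() also sidesteps that. A minimal, runnable sketch of the resulting variant selection; the standalone guess_variant function and the trimmed-down enum are illustrative stand-ins, not the real service code:

    from enum import Enum
    from typing import Optional

    import torch


    class ModelRepoVariant(str, Enum):
        # Only the member exercised here; the real enum defines more variants.
        FP16 = "fp16"


    def guess_variant(precision: torch.dtype) -> Optional[ModelRepoVariant]:
        # Mirrors _guess_variant(): prefer the fp16 repo variant only when
        # the chosen inference dtype is float16.
        return ModelRepoVariant.FP16 if precision == torch.float16 else None


    # TorchDevice.choose_torch_dtype() supplies the precision; CUDA hosts
    # typically resolve to float16 and CPU hosts to float32 (illustrative
    # values, not guaranteed on every configuration).
    assert guess_variant(torch.float16) is ModelRepoVariant.FP16
    assert guess_variant(torch.float32) is None
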
@@ -60,6 +60,11 @@ class TorchDevice:
         """Return the torch.device to use for accelerated inference."""
         if cls._model_cache:
             return cls._model_cache.get_execution_device()
+        else:
+            return cls._choose_device()
+
+    @classmethod
+    def _choose_device(cls) -> torch.device:
         app_config = get_config()
         if app_config.device != "auto":
             device = torch.device(app_config.device)
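
Splitting the fallback into a private _choose_device() is what makes the helper usable outside an invocation context: device selection no longer assumes a live model cache. A minimal sketch of the resulting control flow, assuming a stubbed get_config() and a simplified cuda/cpu fallback (the real method handles more device settings than shown here):

    import torch


    class _StubConfig:
        # Stand-in for the real app config; "auto" defers to hardware probing.
        device = "auto"


    def get_config() -> _StubConfig:
        return _StubConfig()


    class TorchDeviceSketch:
        _model_cache = None  # only populated inside an invocation context

        @classmethod
        def choose_torch_device(cls) -> torch.device:
            """Return the torch.device to use for accelerated inference."""
            if cls._model_cache:
                return cls._model_cache.get_execution_device()
            else:
                return cls._choose_device()

        @classmethod
        def _choose_device(cls) -> torch.device:
            # Config- and hardware-driven fallback; it needs no model cache,
            # so it is safe to call outside an invocation context.
            app_config = get_config()
            if app_config.device != "auto":
                return torch.device(app_config.device)
            if torch.cuda.is_available():
                return torch.device("cuda")
            return torch.device("cpu")


    print(TorchDeviceSketch.choose_torch_device())  # works with no cache set
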
@@ -82,8 +87,8 @@
     @classmethod
     def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
         """Return the precision to use for accelerated inference."""
-        device = device or cls.choose_torch_device()
         config = get_config()
+        device = device or cls._choose_device()
         if device.type == "cuda" and torch.cuda.is_available():
             device_name = torch.cuda.get_device_name(device)
             if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
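
With the device lookup routed through _choose_device(), choose_torch_dtype() can now be called from plain utility or service code with no invocation context. A hedged usage sketch, assuming a default configuration; the printed dtype depends on the host's hardware:

    import torch

    from invokeai.backend.util.devices import TorchDevice

    # No invocation context and no model cache: choose_torch_dtype() now
    # resolves a device itself via _choose_device().
    dtype = TorchDevice.choose_torch_dtype()
    print(dtype)  # e.g. torch.float16 on most CUDA GPUs

    # Passing a device explicitly still short-circuits the lookup.
    print(TorchDevice.choose_torch_dtype(torch.device("cpu")))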