feat(backend): clean up choose_precision

- Allow user-defined precision on MPS.
- Use more explicit logic to handle all possible cases.
- Add comments.
- Remove the app_config args (they were effectively unused, just get the config using the singleton getter util)
This commit is contained in:
psychedelicious 2024-04-07 14:28:29 +10:00 committed by Kent Keirsey
parent 29cfe5a274
commit 9ab6655491
2 changed files with 23 additions and 25 deletions

View File

@ -37,7 +37,7 @@ class ModelLoader(ModelLoaderBase):
self._logger = logger self._logger = logger
self._ram_cache = ram_cache self._ram_cache = ram_cache
self._convert_cache = convert_cache self._convert_cache = convert_cache
self._torch_dtype = torch_dtype(choose_torch_device(), app_config) self._torch_dtype = torch_dtype(choose_torch_device())
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
""" """

View File

@ -6,8 +6,7 @@ from typing import Literal, Optional, Union
import torch import torch
from torch import autocast from torch import autocast
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config.config_default import PRECISION, get_config
from invokeai.app.services.config.config_default import get_config
CPU_DEVICE = torch.device("cpu") CPU_DEVICE = torch.device("cpu")
CUDA_DEVICE = torch.device("cuda") CUDA_DEVICE = torch.device("cuda")
@ -33,35 +32,34 @@ def get_torch_device_name() -> str:
return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper() return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
# We are in transition here from using a single global AppConfig to allowing multiple def choose_precision(device: torch.device) -> Literal["float32", "float16", "bfloat16"]:
# configurations. It is strongly recommended to pass the app_config to this function.
def choose_precision(
device: torch.device, app_config: Optional[InvokeAIAppConfig] = None
) -> Literal["float32", "float16", "bfloat16"]:
"""Return an appropriate precision for the given torch device.""" """Return an appropriate precision for the given torch device."""
app_config = app_config or get_config() app_config = get_config()
if device.type == "cuda": if device.type == "cuda":
device_name = torch.cuda.get_device_name(device) device_name = torch.cuda.get_device_name(device)
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name): if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
if app_config.precision == "float32": # These GPUs have limited support for float16
return "float32" return "float32"
elif app_config.precision == "bfloat16": elif app_config.precision == "auto" or app_config.precision == "autocast":
return "bfloat16" # Default to float16 for CUDA devices
else: return "float16"
return "float16" else:
# Use the user-defined precision
return app_config.precision
elif device.type == "mps": elif device.type == "mps":
return "float16" if app_config.precision == "auto" or app_config.precision == "autocast":
# Default to float16 for MPS devices
return "float16"
else:
# Use the user-defined precision
return app_config.precision
# CPU / safe fallback
return "float32" return "float32"
# We are in transition here from using a single global AppConfig to allowing multiple def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
# configurations. It is strongly recommended to pass the app_config to this function.
def torch_dtype(
device: Optional[torch.device] = None,
app_config: Optional[InvokeAIAppConfig] = None,
) -> torch.dtype:
device = device or choose_torch_device() device = device or choose_torch_device()
precision = choose_precision(device, app_config) precision = choose_precision(device)
if precision == "float16": if precision == "float16":
return torch.float16 return torch.float16
if precision == "bfloat16": if precision == "bfloat16":
@ -71,7 +69,7 @@ def torch_dtype(
return torch.float32 return torch.float32
def choose_autocast(precision): def choose_autocast(precision: PRECISION):
"""Returns an autocast context or nullcontext for the given precision string""" """Returns an autocast context or nullcontext for the given precision string"""
# float16 currently requires autocast to avoid errors like: # float16 currently requires autocast to avoid errors like:
# 'expected scalar type Half but found Float' # 'expected scalar type Half but found Float'