mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
feat(backend): clean up choose_precision
- Allow user-defined precision on MPS.
- Use more explicit logic to handle all possible cases.
- Add comments.
- Remove the app_config args (they were effectively unused; just get the config using the singleton getter util).
This commit is contained in:
parent 29cfe5a274
commit 9ab6655491
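In short, the precision resolution this commit lands behaves as follows. A minimal standalone sketch for illustration only, not the repo's code; SimpleNamespace stands in for the app config, which the real code fetches via the get_config() singleton getter:

from types import SimpleNamespace
from typing import Literal

import torch


def sketch_choose_precision(
    device: torch.device, config: SimpleNamespace
) -> Literal["float32", "float16", "bfloat16"]:
    # Mirrors the decision logic this commit introduces in choose_precision().
    if device.type == "cuda":
        device_name = torch.cuda.get_device_name(device)
        if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
            return "float32"  # these GPUs have limited float16 support
        if config.precision in ("auto", "autocast"):
            return "float16"  # CUDA default
        return config.precision  # user-defined precision
    if device.type == "mps":
        if config.precision in ("auto", "autocast"):
            return "float16"  # MPS default
        return config.precision  # user-defined precision (newly honored on MPS)
    return "float32"  # CPU / safe fallback


# An MPS user who sets precision="bfloat16" now gets it instead of a hardcoded float16:
print(sketch_choose_precision(torch.device("mps"), SimpleNamespace(precision="bfloat16")))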
@@ -37,7 +37,7 @@ class ModelLoader(ModelLoaderBase):
         self._logger = logger
         self._ram_cache = ram_cache
         self._convert_cache = convert_cache
-        self._torch_dtype = torch_dtype(choose_torch_device(), app_config)
+        self._torch_dtype = torch_dtype(choose_torch_device())
 
     def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
         """
@@ -6,8 +6,7 @@ from typing import Literal, Optional, Union
 import torch
 from torch import autocast
 
-from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.app.services.config.config_default import get_config
+from invokeai.app.services.config.config_default import PRECISION, get_config
 
 CPU_DEVICE = torch.device("cpu")
 CUDA_DEVICE = torch.device("cuda")
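The newly imported PRECISION is the config's precision type, presumably a Literal alias along these lines. The exact members are an assumption inferred from the values the rewritten function handles below:

from typing import Literal

# Assumed shape of the alias exported by config_default; the values match
# those handled in the rewritten choose_precision() below.
PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]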
@@ -33,35 +32,34 @@ def get_torch_device_name() -> str:
     return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
 
 
-# We are in transition here from using a single global AppConfig to allowing multiple
-# configurations. It is strongly recommended to pass the app_config to this function.
-def choose_precision(
-    device: torch.device, app_config: Optional[InvokeAIAppConfig] = None
-) -> Literal["float32", "float16", "bfloat16"]:
+def choose_precision(device: torch.device) -> Literal["float32", "float16", "bfloat16"]:
     """Return an appropriate precision for the given torch device."""
-    app_config = app_config or get_config()
+    app_config = get_config()
     if device.type == "cuda":
         device_name = torch.cuda.get_device_name(device)
-        if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
-            if app_config.precision == "float32":
-                return "float32"
-            elif app_config.precision == "bfloat16":
-                return "bfloat16"
-            else:
-                return "float16"
+        if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
+            # These GPUs have limited support for float16
+            return "float32"
+        elif app_config.precision == "auto" or app_config.precision == "autocast":
+            # Default to float16 for CUDA devices
+            return "float16"
+        else:
+            # Use the user-defined precision
+            return app_config.precision
     elif device.type == "mps":
-        return "float16"
+        if app_config.precision == "auto" or app_config.precision == "autocast":
+            # Default to float16 for MPS devices
+            return "float16"
+        else:
+            # Use the user-defined precision
+            return app_config.precision
+    # CPU / safe fallback
     return "float32"
 
 
-# We are in transition here from using a single global AppConfig to allowing multiple
-# configurations. It is strongly recommended to pass the app_config to this function.
-def torch_dtype(
-    device: Optional[torch.device] = None,
-    app_config: Optional[InvokeAIAppConfig] = None,
-) -> torch.dtype:
+def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
     device = device or choose_torch_device()
-    precision = choose_precision(device, app_config)
+    precision = choose_precision(device)
     if precision == "float16":
         return torch.float16
     if precision == "bfloat16":
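With the app_config parameters gone, callers no longer thread config through. A hedged usage sketch; the module path is an assumption based on the helpers edited above:

import torch

# Assumed import path for the helpers changed in this commit.
from invokeai.backend.util.devices import choose_torch_device, torch_dtype

device = choose_torch_device()
dtype = torch_dtype(device)  # precision now comes from the get_config() singleton
print(device, dtype)  # e.g. cuda torch.float16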
@@ -71,7 +69,7 @@ def torch_dtype(
     return torch.float32
 
 
-def choose_autocast(precision):
+def choose_autocast(precision: PRECISION):
     """Returns an autocast context or nullcontext for the given precision string"""
     # float16 currently requires autocast to avoid errors like:
     # 'expected scalar type Half but found Float'
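The hunk cuts off at the docstring and comments, so the body of choose_autocast is not shown. Plausibly it returns autocast for the precisions that need it and nullcontext otherwise; a sketch under that assumption:

from contextlib import nullcontext

from torch import autocast


def sketch_choose_autocast(precision: str):
    # Assumption: the real body (not shown in the diff) returns the autocast
    # context for half precision and a no-op context otherwise.
    if precision == "autocast" or precision == "float16":
        return autocast
    return nullcontext


# Both returns are usable the same way:
#   with sketch_choose_autocast(precision)("cuda"):
#       ...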