mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
partially address --root CLI argument handling
- fix places where `get_config()` is being called at import time rather than at run time. - add regression test for import time get_config() calling.
This commit is contained in:
committed by
psychedelicious
parent
8cd65755ef
commit
d871fca643
@ -12,11 +12,11 @@ from invokeai.app.services.config.config_default import get_config
|
||||
CPU_DEVICE = torch.device("cpu")
|
||||
CUDA_DEVICE = torch.device("cuda")
|
||||
MPS_DEVICE = torch.device("mps")
|
||||
config = get_config()
|
||||
|
||||
|
||||
def choose_torch_device() -> torch.device:
|
||||
"""Convenience routine for guessing which GPU device to run model on"""
|
||||
config = get_config()
|
||||
if config.device == "auto":
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
@ -34,7 +34,7 @@ def choose_precision(
|
||||
device: torch.device, app_config: Optional[InvokeAIAppConfig] = None
|
||||
) -> Literal["float32", "float16", "bfloat16"]:
|
||||
"""Return an appropriate precision for the given torch device."""
|
||||
app_config = app_config or config
|
||||
app_config = app_config or get_config()
|
||||
if device.type == "cuda":
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
|
||||
|
@ -339,7 +339,8 @@ class InvokeAILogger(object): # noqa D102
|
||||
loggers: Dict[str, logging.Logger] = {}
|
||||
|
||||
@classmethod
|
||||
def get_logger(cls, name: str = "InvokeAI", config: InvokeAIAppConfig = get_config()) -> logging.Logger: # noqa D102
|
||||
def get_logger(cls, name: str = "InvokeAI", config: Optional[InvokeAIAppConfig] = None) -> logging.Logger: # noqa D102
|
||||
config = config or get_config()
|
||||
if name in cls.loggers:
|
||||
return cls.loggers[name]
|
||||
|
||||
|
Reference in New Issue
Block a user