partially address --root CLI argument handling

- fix places where `get_config()` is being called at import time rather
  than at run time.

- add regression test for import time get_config() calling.
This commit is contained in:
Lincoln Stein
2024-03-16 22:25:19 -04:00
committed by psychedelicious
parent 8cd65755ef
commit d871fca643
5 changed files with 24 additions and 11 deletions

View File

@ -12,11 +12,11 @@ from invokeai.app.services.config.config_default import get_config
CPU_DEVICE = torch.device("cpu")
CUDA_DEVICE = torch.device("cuda")
MPS_DEVICE = torch.device("mps")
config = get_config()
def choose_torch_device() -> torch.device:
"""Convenience routine for guessing which GPU device to run model on"""
config = get_config()
if config.device == "auto":
if torch.cuda.is_available():
return torch.device("cuda")
@ -34,7 +34,7 @@ def choose_precision(
device: torch.device, app_config: Optional[InvokeAIAppConfig] = None
) -> Literal["float32", "float16", "bfloat16"]:
"""Return an appropriate precision for the given torch device."""
app_config = app_config or config
app_config = app_config or get_config()
if device.type == "cuda":
device_name = torch.cuda.get_device_name(device)
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):

View File

@ -339,7 +339,8 @@ class InvokeAILogger(object): # noqa D102
loggers: Dict[str, logging.Logger] = {}
@classmethod
def get_logger(cls, name: str = "InvokeAI", config: InvokeAIAppConfig = get_config()) -> logging.Logger: # noqa D102
def get_logger(cls, name: str = "InvokeAI", config: Optional[InvokeAIAppConfig] = None) -> logging.Logger: # noqa D102
config = config or get_config()
if name in cls.loggers:
return cls.loggers[name]