Mirror of https://github.com/invoke-ai/InvokeAI
partially address --root CLI argument handling
- fix places where `get_config()` is being called at import time rather than at run time
- add a regression test for import-time `get_config()` calls
parent 8cd65755ef
commit d871fca643
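The bug class here is evaluation order: a module-level `config = get_config()` runs when the module is first imported, typically before `--root` (or any other CLI flag) has been parsed, so the config is frozen at its defaults. A minimal, self-contained sketch of the difference (the names below are illustrative, not InvokeAI's real API):

```python
from pathlib import Path

_root = Path("/default/root")  # stands in for state that CLI parsing mutates


def get_config() -> Path:
    """Stand-in for the real cached config accessor."""
    return _root


# Anti-pattern: evaluated once, at import time, before CLI parsing.
config_at_import = get_config()


def root_at_runtime() -> Path:
    # Fixed pattern: resolve the config inside the function, per call.
    return get_config()


_root = Path("/overridden/by/root-flag")  # simulates --root taking effect
print(config_at_import)   # /default/root  (stale)
print(root_at_runtime())  # /overridden/by/root-flag
```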
@@ -33,12 +33,12 @@ from tqdm import tqdm
 from transformers import AutoFeatureExtractor

 import invokeai.configs as model_configs
-from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.app.services.config.config_default import get_config
+from invokeai.app.services.config import InvokeAIAppConfig, get_config
 from invokeai.backend.install.install_helper import InstallHelper, InstallSelections
 from invokeai.backend.model_manager import ModelType
 from invokeai.backend.util import choose_precision, choose_torch_device
 from invokeai.backend.util.logging import InvokeAILogger
+from invokeai.frontend.cli.arg_parser import InvokeAIArgs
 from invokeai.frontend.install.model_install import addModelsForm

 # TO DO - Move all the frontend code into invokeai.frontend.install
@@ -63,8 +63,7 @@ def get_literal_fields(field: str) -> Tuple[Any]:


 # --------------------------globals-----------------------
-config = get_config()
+config = None

 PRECISION_CHOICES = get_literal_fields("precision")
 DEVICE_CHOICES = get_literal_fields("device")
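The installer script keeps its module-level `config` name but defers the real value: `None` at import, assigned in `main()` once arguments are parsed (see the `global config` and `config = get_config()` hunks below). A rough sketch of that deferred-global shape, with hypothetical names:

```python
from typing import Optional

config: Optional[dict] = None  # placeholder only; populated in main()


def build_config(root: str) -> dict:
    """Hypothetical stand-in for get_config()."""
    return {"root": root}


def main() -> None:
    global config
    # Parse CLI arguments first, *then* construct the config so that
    # flags like --root are visible to it.
    config = build_config(root="/tmp")


if __name__ == "__main__":
    main()
```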
@@ -745,6 +744,8 @@ def is_v2_install(root: Path) -> bool:
 # -------------------------------------
 def main() -> None:
     global FORCE_FULL_PRECISION  # FIXME
+    global config
+
     parser = argparse.ArgumentParser(description="InvokeAI model downloader")
     parser.add_argument(
         "--skip-sd-weights",
@@ -787,10 +788,12 @@ def main() -> None:
         default=None,
         help="path to root of install directory",
     )

     opt = parser.parse_args()
     updates: dict[str, Any] = {}
-    if opt.root:
-        config.set_root(Path(opt.root))
+    InvokeAIArgs.args = opt
+    config = get_config()
+
     if opt.full_precision:
         updates["precision"] = "float32"

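Note the shape of the fix in `main()`: instead of mutating an already-built config (`config.set_root(...)`), the parsed namespace is stashed on `InvokeAIArgs` before `get_config()` runs, letting config construction read `--root` itself. A hedged sketch of that hand-off, assuming `get_config()` consults `InvokeAIArgs.args` internally:

```python
import argparse
from typing import Optional


class InvokeAIArgs:
    """Minimal stand-in: a shared slot for the parsed CLI namespace."""

    args: Optional[argparse.Namespace] = None


def get_config() -> dict:
    # Sketch: assumes the real get_config() reads InvokeAIArgs.args while
    # building the config -- how --root reaches it without a set_root() call.
    root = getattr(InvokeAIArgs.args, "root", None) or "~/invokeai"
    return {"root": root}


parser = argparse.ArgumentParser(description="InvokeAI model downloader")
parser.add_argument("--root", default=None, help="path to root of install directory")

InvokeAIArgs.args = parser.parse_args(["--root", "/tmp"])
print(get_config())  # {'root': '/tmp'}
```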
@@ -12,11 +12,11 @@ from invokeai.app.services.config.config_default import get_config
 CPU_DEVICE = torch.device("cpu")
 CUDA_DEVICE = torch.device("cuda")
 MPS_DEVICE = torch.device("mps")
-config = get_config()


 def choose_torch_device() -> torch.device:
     """Convenience routine for guessing which GPU device to run model on"""
+    config = get_config()
     if config.device == "auto":
         if torch.cuda.is_available():
             return torch.device("cuda")
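With the module global gone from the devices module, `choose_torch_device()` now looks the config up on every call, so a config built after import (e.g. after `--root` handling) is the one consulted. A brief usage sketch, assuming an installed `invokeai` package:

```python
import torch

from invokeai.backend.util import choose_torch_device

# The config lookup happens inside this call, so CLI overrides applied
# before the first call are honored; nothing is frozen at import time.
device: torch.device = choose_torch_device()
print(device)  # e.g. cuda, mps, or cpu, depending on config.device and hardware
```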
@@ -34,7 +34,7 @@ def choose_precision(
     device: torch.device, app_config: Optional[InvokeAIAppConfig] = None
 ) -> Literal["float32", "float16", "bfloat16"]:
     """Return an appropriate precision for the given torch device."""
-    app_config = app_config or config
+    app_config = app_config or get_config()
     if device.type == "cuda":
         device_name = torch.cuda.get_device_name(device)
         if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
@@ -339,7 +339,8 @@ class InvokeAILogger(object):  # noqa D102
     loggers: Dict[str, logging.Logger] = {}

     @classmethod
-    def get_logger(cls, name: str = "InvokeAI", config: InvokeAIAppConfig = get_config()) -> logging.Logger:  # noqa D102
+    def get_logger(cls, name: str = "InvokeAI", config: Optional[InvokeAIAppConfig] = None) -> logging.Logger:  # noqa D102
+        config = config or get_config()
         if name in cls.loggers:
             return cls.loggers[name]

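The `get_logger` change fixes a standard Python pitfall: default parameter values are evaluated exactly once, when the `def` statement executes (i.e. at import time), so `config: InvokeAIAppConfig = get_config()` baked in whatever config existed at import. A self-contained demonstration of the gotcha and the `None`-sentinel fix:

```python
import time
from typing import Optional


def stamped(ts: float = time.time()) -> float:
    # time.time() above ran once, when this function was defined.
    return ts


def stamped_lazy(ts: Optional[float] = None) -> float:
    # Sentinel pattern: resolve the default on every call instead.
    return ts if ts is not None else time.time()


a = stamped()
time.sleep(0.01)
b = stamped()
assert a == b  # both calls share the single import-time default

c = stamped_lazy()
time.sleep(0.01)
d = stamped_lazy()
assert c != d  # resolved per call
```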
@@ -30,11 +30,11 @@ class InvokeAIArgs:
     Example:
     ```
     # In a CLI wrapper
-    from invokeai.frontend.cli.app_arg_parser import InvokeAIArgs
+    from invokeai.frontend.cli.arg_parser import InvokeAIArgs
     InvokeAIArgs.parse_args()

     # In the application
-    from invokeai.frontend.cli.app_arg_parser import InvokeAIArgs
+    from invokeai.frontend.cli.arg_parser import InvokeAIArgs
     args = InvokeAIArgs.args
     """
@@ -1,3 +1,4 @@
+import sys
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Any
@@ -7,6 +8,7 @@ from omegaconf import OmegaConf
 from pydantic import ValidationError

 from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config, load_and_migrate_config
+from invokeai.frontend.cli.arg_parser import InvokeAIArgs

 v4_config = """
 schema_version: 4
@@ -76,6 +78,13 @@ def test_read_config_from_file(tmp_path: Path):
     assert config.port == 8080


+def test_arg_parsing():
+    sys.argv = ["test_config.py", "--root", "/tmp"]
+    InvokeAIArgs.parse_args()
+    config = get_config()
+    assert config.root_path == Path("/tmp")
+
+
 def test_migrate_v3_config_from_file(tmp_path: Path):
     """Test reading configuration from a file."""
     temp_config_file = tmp_path / "temp_invokeai.yaml"
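One caveat about the new regression test: assigning to `sys.argv` mutates global state that can leak into later tests. A hedged alternative (not what this commit does) is pytest's built-in `monkeypatch` fixture, which restores `sys.argv` automatically when the test ends:

```python
from pathlib import Path

from invokeai.app.services.config.config_default import get_config
from invokeai.frontend.cli.arg_parser import InvokeAIArgs


def test_arg_parsing(monkeypatch) -> None:
    # monkeypatch undoes this assignment when the test finishes.
    monkeypatch.setattr("sys.argv", ["test_config.py", "--root", "/tmp"])
    InvokeAIArgs.parse_args()
    config = get_config()
    assert config.root_path == Path("/tmp")
```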