Refactor root detection code

This commit is contained in:
Lincoln Stein 2023-07-31 21:15:44 -04:00
parent 52437205bb
commit 7cd8b2f207
2 changed files with 6 additions and 9 deletions

View File

@ -274,7 +274,7 @@ class InvokeAISettings(BaseSettings):
@classmethod @classmethod
def _excluded(self) -> List[str]: def _excluded(self) -> List[str]:
# internal fields that shouldn't be exposed as command line options # internal fields that shouldn't be exposed as command line options
return ["type", "initconf", "cached_root"] return ["type", "initconf"]
@classmethod @classmethod
def _excluded_from_yaml(self) -> List[str]: def _excluded_from_yaml(self) -> List[str]:
@ -290,7 +290,6 @@ class InvokeAISettings(BaseSettings):
"restore", "restore",
"root", "root",
"nsfw_checker", "nsfw_checker",
"cached_root",
] ]
class Config: class Config:
@ -357,6 +356,7 @@ def _find_root() -> Path:
venv = Path(os.environ.get("VIRTUAL_ENV") or ".") venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
if os.environ.get("INVOKEAI_ROOT"): if os.environ.get("INVOKEAI_ROOT"):
root = Path(os.environ.get("INVOKEAI_ROOT")).resolve() root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
os.environ["INVOKEAI_ROOT"] = str(root) # absolutize it to protect against code doing a cwd()
elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]): elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
root = (venv.parent).resolve() root = (venv.parent).resolve()
else: else:
@ -424,7 +424,6 @@ class InvokeAIAppConfig(InvokeAISettings):
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging") log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other") version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
cached_root : Path = Field(default=None, description="internal use only", category="DEPRECATED")
# fmt: on # fmt: on
def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False): def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False):
@ -472,15 +471,11 @@ class InvokeAIAppConfig(InvokeAISettings):
""" """
Path to the runtime root directory Path to the runtime root directory
""" """
# we cache value of root to protect against it being '.' and the cwd changing if self.root:
if self.cached_root:
root = self.cached_root
elif self.root:
root = Path(self.root).expanduser().absolute() root = Path(self.root).expanduser().absolute()
else: else:
root = self.find_root() root = self.find_root()
self.cached_root = root return root
return self.cached_root
@property @property
def root_dir(self) -> Path: def root_dir(self) -> Path:

View File

@ -767,7 +767,9 @@ def main():
invoke_args.extend(["--root", opt.root]) invoke_args.extend(["--root", opt.root])
if opt.full_precision: if opt.full_precision:
invoke_args.extend(["--precision", "float32"]) invoke_args.extend(["--precision", "float32"])
print(f"DEBUG: {invoke_args}")
config.parse_args(invoke_args) config.parse_args(invoke_args)
print(f"DEBUG: {config.root} {config.root_path}")
logger = InvokeAILogger().getLogger(config=config) logger = InvokeAILogger().getLogger(config=config)
if not config.model_conf_path.exists(): if not config.model_conf_path.exists():