feat: single app entrypoint with CLI arg parsing

We have two problems with how argparse is being utilized:
- We parse CLI args as the `api_app.py` file is read. This causes a problem pytest, which has an incompatible set of CLI args. Some tests import the FastAPI app, which triggers the config to parse CLI args, which receives the pytest args and fails.
- We've repeatedly had problems when something that uses the config is imported before the CLI args are parsed. When this happens, the root dir may not be set correctly, so we attempt to operate on incorrect paths.

To resolve these issues, we need to lift CLI arg parsing outside of the application code, but still let the application access the CLI args. We can create a external app entrypoint to do this.

- `InvokeAIArgs` is a simple helper class that parses CLI args and stores the result.
- `run_app()` is the new entrypoint. It first parses CLI args, then runs `invoke_api` to start the app.

The `invokeai-web` project script and `invokeai-web.py` dev script now call `run_app()` instead of `invoke_api()`.

The first time `get_config()` is called to get the singleton config object, it retrieves the args from `InvokeAIArgs`, sets the root dir if provided, then merges settings in from `invokeai.yaml`.

CLI arg parsing is now safely insulated from application code, but still accessible. And we don't need to worry about import order having an impact on anything, because by the time the app is running, we have already parsed CLI args. Whew!
This commit is contained in:
psychedelicious 2024-03-15 16:33:52 +11:00
parent 5ecfa86cd0
commit ce9aeeece3
7 changed files with 134 additions and 95 deletions

View File

@ -1,63 +1,57 @@
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
# which are imported/used before parse_args() is called will get the default config values instead of the
# values from the command line or config file.
import asyncio
import mimetypes
import socket
from contextlib import asynccontextmanager
from inspect import signature
from pathlib import Path
from typing import Any
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.json_schema import models_json_schema
from torch.backends.mps import is_available as is_mps_available
# for PyCharm:
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
from .api.routers import (
app_info,
board_images,
boards,
download_queue,
images,
model_manager,
session_queue,
utilities,
workflows,
)
from .api.sockets import SocketIO
from .invocations.baseinvocation import (
BaseInvocation,
UIConfigBase,
)
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
app_config = get_config()
app_config.parse_args()
app_config.merge_from_file()
if True: # hack to make flake8 happy with imports coming after setting up the config
import asyncio
import mimetypes
import socket
from contextlib import asynccontextmanager
from inspect import signature
from pathlib import Path
from typing import Any
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.json_schema import models_json_schema
from torch.backends.mps import is_available as is_mps_available
# for PyCharm:
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
from .api.routers import (
app_info,
board_images,
boards,
download_queue,
images,
model_manager,
session_queue,
utilities,
workflows,
)
from .api.sockets import SocketIO
from .invocations.baseinvocation import (
BaseInvocation,
UIConfigBase,
)
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
if is_mps_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
if is_mps_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
logger = InvokeAILogger.get_logger(config=app_config)

12
invokeai/app/run_app.py Normal file
View File

@ -0,0 +1,12 @@
"""This is a wrapper around the main app entrypoint, to allow for CLI args to be parsed before running the app."""
def run_app() -> None:
# Before doing _anything_, parse CLI args!
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
InvokeAIArgs.parse_args()
from invokeai.app.api_app import invoke_api
invoke_api()

View File

@ -11,7 +11,7 @@ from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.frontend.cli.app_arg_parser import app_arg_parser
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
@ -218,24 +218,12 @@ class InvokeAIAppConfig(BaseSettings):
This function will write to the `invokeai.yaml` file if the config is migrated.
If there is no `invokeai.yaml` file, one will be written.
Args:
source_path: Path to the config file. If not provided, the default path is used.
"""
path = source_path or self.init_file_path
if not path.exists():
self.write_file(path)
else:
config_from_file = load_and_migrate_config(path)
self.update_config(config_from_file)
def parse_args(self) -> None:
"""Parse the CLI args and set the runtime root directory."""
opt = app_arg_parser.parse_args()
if root := getattr(opt, "root", None):
self.set_root(Path(root))
config_from_file = load_and_migrate_config(path)
self.update_config(config_from_file)
def set_root(self, root: Path) -> None:
"""Set the runtime root directory. This is typically set using a CLI arg."""
@ -412,5 +400,29 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
@lru_cache(maxsize=1)
def get_config() -> InvokeAIAppConfig:
"""Return the global singleton app config"""
return InvokeAIAppConfig()
"""Return the global singleton app config.
When called, this function will parse the CLI args and merge in config from the `invokeai.yaml` config file.
"""
config = InvokeAIAppConfig()
args = InvokeAIArgs.args
if root := getattr(args, "root", None):
config.set_root(Path(root))
# TODO(psyche): This shouldn't be wrapped in a try/catch. The configuration script imports a number of classes
# from throughout the app, which in turn call get_config(). At that time, there may not be a config file to
# read from, and this raises.
#
# Once we move all* model installation to the web app, the configure script will not be reaching into the main app
# and we can make this an unhandled error, which feels correct.
#
# *all user-facing models. Core model installation will still be handled by the configure/install script.
try:
config.merge_from_file()
except FileNotFoundError:
pass
return config

View File

@ -1,12 +0,0 @@
from argparse import ArgumentParser, RawTextHelpFormatter
from invokeai.version import __version__
root_help = r"""Sets a root directory for the app. If omitted, the app will search for the root directory in the following order:
- The `$INVOKEAI_ROOT` environment variable
- The currently active virtual environment's parent directory
- `$HOME/invokeai`"""
app_arg_parser = ArgumentParser(description="Invoke Studio", formatter_class=RawTextHelpFormatter)
app_arg_parser.add_argument("--root", type=str, help=root_help)
app_arg_parser.add_argument("--version", action="version", version=__version__, help="Displays the version and exits.")

View File

@ -0,0 +1,41 @@
from argparse import ArgumentParser, Namespace, RawTextHelpFormatter
from typing import Optional
from invokeai.version import __version__
_root_help = r"""Sets a root directory for the app. If omitted, the app will search for the root directory in the following order:
- The `$INVOKEAI_ROOT` environment variable
- The currently active virtual environment's parent directory
- `$HOME/invokeai`"""
_parser = ArgumentParser(description="Invoke Studio", formatter_class=RawTextHelpFormatter)
_parser.add_argument("--root", type=str, help=_root_help)
_parser.add_argument("--version", action="version", version=__version__, help="Displays the version and exits.")
class InvokeAIArgs:
"""Helper class for parsing CLI args.
Args should never be parsed within the application code, only in the CLI entrypoints. Parsing args within the
application creates conflicts when running tests or when using application modules directly.
If the args are needed within the application, the consumer should access them from this class.
Example:
```
# In a CLI wrapper
from invokeai.frontend.cli.app_arg_parser import InvokeAIArgs
InvokeAIArgs.parse_args()
# In the application
from invokeai.frontend.cli.app_arg_parser import InvokeAIArgs
args = InvokeAIArgs.args
"""
args: Optional[Namespace] = None
@staticmethod
def parse_args() -> Optional[Namespace]:
"""Parse CLI args and store the result."""
InvokeAIArgs.args = _parser.parse_args()
return InvokeAIArgs.args

View File

@ -135,7 +135,7 @@ dependencies = [
# "invokeai" = "invokeai.frontend.legacy_launch_invokeai:main"
# new shortcut to launch web interface
"invokeai-web" = "invokeai.app.api_app:invoke_api"
"invokeai-web" = "invokeai.app.run_app:run_app"
# full commands
"invokeai-configure" = "invokeai.frontend.install.invokeai_configure:invokeai_configure"
@ -146,7 +146,6 @@ dependencies = [
"invokeai-migrate3" = "invokeai.backend.install.migrate_to_3:main"
"invokeai-update" = "invokeai.frontend.install.invokeai_update:main"
"invokeai-metadata" = "invokeai.backend.image_util.invoke_metadata:main"
"invokeai-node-web" = "invokeai.app.api_app:invoke_api"
"invokeai-import-images" = "invokeai.frontend.install.import_images:main"
"invokeai-db-maintenance" = "invokeai.backend.util.db_maintenance:main"

View File

@ -5,22 +5,15 @@
import logging
import os
from invokeai.frontend.cli.app_arg_parser import app_arg_parser
from invokeai.app.run_app import run_app
logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage())
def main():
# Parse CLI args immediately to handle `version` and `help` commands. Once the app starts up, we will parse the
# args again to get configuration args.
app_arg_parser.parse_args()
# Change working directory to the repo root
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from invokeai.app.api_app import invoke_api
invoke_api()
run_app()
if __name__ == "__main__":