mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: single app entrypoint with CLI arg parsing
We have two problems with how argparse is being utilized: - We parse CLI args as the `api_app.py` file is read. This causes a problem pytest, which has an incompatible set of CLI args. Some tests import the FastAPI app, which triggers the config to parse CLI args, which receives the pytest args and fails. - We've repeatedly had problems when something that uses the config is imported before the CLI args are parsed. When this happens, the root dir may not be set correctly, so we attempt to operate on incorrect paths. To resolve these issues, we need to lift CLI arg parsing outside of the application code, but still let the application access the CLI args. We can create a external app entrypoint to do this. - `InvokeAIArgs` is a simple helper class that parses CLI args and stores the result. - `run_app()` is the new entrypoint. It first parses CLI args, then runs `invoke_api` to start the app. The `invokeai-web` project script and `invokeai-web.py` dev script now call `run_app()` instead of `invoke_api()`. The first time `get_config()` is called to get the singleton config object, it retrieves the args from `InvokeAIArgs`, sets the root dir if provided, then merges settings in from `invokeai.yaml`. CLI arg parsing is now safely insulated from application code, but still accessible. And we don't need to worry about import order having an impact on anything, because by the time the app is running, we have already parsed CLI args. Whew!
This commit is contained in:
parent
5ecfa86cd0
commit
ce9aeeece3
@ -1,63 +1,57 @@
|
||||
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
|
||||
# which are imported/used before parse_args() is called will get the default config values instead of the
|
||||
# values from the command line or config file.
|
||||
import asyncio
|
||||
import mimetypes
|
||||
import socket
|
||||
from contextlib import asynccontextmanager
|
||||
from inspect import signature
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from pydantic.json_schema import models_json_schema
|
||||
from torch.backends.mps import is_available as is_mps_available
|
||||
|
||||
# for PyCharm:
|
||||
# noinspection PyUnresolvedReferences
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
import invokeai.frontend.web as web_dir
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import (
|
||||
app_info,
|
||||
board_images,
|
||||
boards,
|
||||
download_queue,
|
||||
images,
|
||||
model_manager,
|
||||
session_queue,
|
||||
utilities,
|
||||
workflows,
|
||||
)
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
UIConfigBase,
|
||||
)
|
||||
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
||||
|
||||
app_config = get_config()
|
||||
app_config.parse_args()
|
||||
app_config.merge_from_file()
|
||||
|
||||
if True: # hack to make flake8 happy with imports coming after setting up the config
|
||||
import asyncio
|
||||
import mimetypes
|
||||
import socket
|
||||
from contextlib import asynccontextmanager
|
||||
from inspect import signature
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from pydantic.json_schema import models_json_schema
|
||||
from torch.backends.mps import is_available as is_mps_available
|
||||
|
||||
# for PyCharm:
|
||||
# noinspection PyUnresolvedReferences
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
import invokeai.frontend.web as web_dir
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import (
|
||||
app_info,
|
||||
board_images,
|
||||
boards,
|
||||
download_queue,
|
||||
images,
|
||||
model_manager,
|
||||
session_queue,
|
||||
utilities,
|
||||
workflows,
|
||||
)
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
UIConfigBase,
|
||||
)
|
||||
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
||||
|
||||
if is_mps_available():
|
||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||
if is_mps_available():
|
||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||
|
||||
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
|
12
invokeai/app/run_app.py
Normal file
12
invokeai/app/run_app.py
Normal file
@ -0,0 +1,12 @@
|
||||
"""This is a wrapper around the main app entrypoint, to allow for CLI args to be parsed before running the app."""
|
||||
|
||||
|
||||
def run_app() -> None:
|
||||
# Before doing _anything_, parse CLI args!
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
InvokeAIArgs.parse_args()
|
||||
|
||||
from invokeai.app.api_app import invoke_api
|
||||
|
||||
invoke_api()
|
@ -11,7 +11,7 @@ from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
||||
from invokeai.frontend.cli.app_arg_parser import app_arg_parser
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
INIT_FILE = Path("invokeai.yaml")
|
||||
DB_FILE = Path("invokeai.db")
|
||||
@ -218,24 +218,12 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
|
||||
This function will write to the `invokeai.yaml` file if the config is migrated.
|
||||
|
||||
If there is no `invokeai.yaml` file, one will be written.
|
||||
|
||||
Args:
|
||||
source_path: Path to the config file. If not provided, the default path is used.
|
||||
"""
|
||||
path = source_path or self.init_file_path
|
||||
|
||||
if not path.exists():
|
||||
self.write_file(path)
|
||||
else:
|
||||
config_from_file = load_and_migrate_config(path)
|
||||
self.update_config(config_from_file)
|
||||
|
||||
def parse_args(self) -> None:
|
||||
"""Parse the CLI args and set the runtime root directory."""
|
||||
opt = app_arg_parser.parse_args()
|
||||
if root := getattr(opt, "root", None):
|
||||
self.set_root(Path(root))
|
||||
config_from_file = load_and_migrate_config(path)
|
||||
self.update_config(config_from_file)
|
||||
|
||||
def set_root(self, root: Path) -> None:
|
||||
"""Set the runtime root directory. This is typically set using a CLI arg."""
|
||||
@ -412,5 +400,29 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_config() -> InvokeAIAppConfig:
|
||||
"""Return the global singleton app config"""
|
||||
return InvokeAIAppConfig()
|
||||
"""Return the global singleton app config.
|
||||
|
||||
When called, this function will parse the CLI args and merge in config from the `invokeai.yaml` config file.
|
||||
"""
|
||||
config = InvokeAIAppConfig()
|
||||
|
||||
args = InvokeAIArgs.args
|
||||
|
||||
if root := getattr(args, "root", None):
|
||||
config.set_root(Path(root))
|
||||
|
||||
# TODO(psyche): This shouldn't be wrapped in a try/catch. The configuration script imports a number of classes
|
||||
# from throughout the app, which in turn call get_config(). At that time, there may not be a config file to
|
||||
# read from, and this raises.
|
||||
#
|
||||
# Once we move all* model installation to the web app, the configure script will not be reaching into the main app
|
||||
# and we can make this an unhandled error, which feels correct.
|
||||
#
|
||||
# *all user-facing models. Core model installation will still be handled by the configure/install script.
|
||||
|
||||
try:
|
||||
config.merge_from_file()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
return config
|
||||
|
@ -1,12 +0,0 @@
|
||||
from argparse import ArgumentParser, RawTextHelpFormatter
|
||||
|
||||
from invokeai.version import __version__
|
||||
|
||||
root_help = r"""Sets a root directory for the app. If omitted, the app will search for the root directory in the following order:
|
||||
- The `$INVOKEAI_ROOT` environment variable
|
||||
- The currently active virtual environment's parent directory
|
||||
- `$HOME/invokeai`"""
|
||||
|
||||
app_arg_parser = ArgumentParser(description="Invoke Studio", formatter_class=RawTextHelpFormatter)
|
||||
app_arg_parser.add_argument("--root", type=str, help=root_help)
|
||||
app_arg_parser.add_argument("--version", action="version", version=__version__, help="Displays the version and exits.")
|
41
invokeai/frontend/cli/arg_parser.py
Normal file
41
invokeai/frontend/cli/arg_parser.py
Normal file
@ -0,0 +1,41 @@
|
||||
from argparse import ArgumentParser, Namespace, RawTextHelpFormatter
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.version import __version__
|
||||
|
||||
_root_help = r"""Sets a root directory for the app. If omitted, the app will search for the root directory in the following order:
|
||||
- The `$INVOKEAI_ROOT` environment variable
|
||||
- The currently active virtual environment's parent directory
|
||||
- `$HOME/invokeai`"""
|
||||
|
||||
_parser = ArgumentParser(description="Invoke Studio", formatter_class=RawTextHelpFormatter)
|
||||
_parser.add_argument("--root", type=str, help=_root_help)
|
||||
_parser.add_argument("--version", action="version", version=__version__, help="Displays the version and exits.")
|
||||
|
||||
|
||||
class InvokeAIArgs:
|
||||
"""Helper class for parsing CLI args.
|
||||
|
||||
Args should never be parsed within the application code, only in the CLI entrypoints. Parsing args within the
|
||||
application creates conflicts when running tests or when using application modules directly.
|
||||
|
||||
If the args are needed within the application, the consumer should access them from this class.
|
||||
|
||||
Example:
|
||||
```
|
||||
# In a CLI wrapper
|
||||
from invokeai.frontend.cli.app_arg_parser import InvokeAIArgs
|
||||
InvokeAIArgs.parse_args()
|
||||
|
||||
# In the application
|
||||
from invokeai.frontend.cli.app_arg_parser import InvokeAIArgs
|
||||
args = InvokeAIArgs.args
|
||||
"""
|
||||
|
||||
args: Optional[Namespace] = None
|
||||
|
||||
@staticmethod
|
||||
def parse_args() -> Optional[Namespace]:
|
||||
"""Parse CLI args and store the result."""
|
||||
InvokeAIArgs.args = _parser.parse_args()
|
||||
return InvokeAIArgs.args
|
@ -135,7 +135,7 @@ dependencies = [
|
||||
# "invokeai" = "invokeai.frontend.legacy_launch_invokeai:main"
|
||||
|
||||
# new shortcut to launch web interface
|
||||
"invokeai-web" = "invokeai.app.api_app:invoke_api"
|
||||
"invokeai-web" = "invokeai.app.run_app:run_app"
|
||||
|
||||
# full commands
|
||||
"invokeai-configure" = "invokeai.frontend.install.invokeai_configure:invokeai_configure"
|
||||
@ -146,7 +146,6 @@ dependencies = [
|
||||
"invokeai-migrate3" = "invokeai.backend.install.migrate_to_3:main"
|
||||
"invokeai-update" = "invokeai.frontend.install.invokeai_update:main"
|
||||
"invokeai-metadata" = "invokeai.backend.image_util.invoke_metadata:main"
|
||||
"invokeai-node-web" = "invokeai.app.api_app:invoke_api"
|
||||
"invokeai-import-images" = "invokeai.frontend.install.import_images:main"
|
||||
"invokeai-db-maintenance" = "invokeai.backend.util.db_maintenance:main"
|
||||
|
||||
|
@ -5,22 +5,15 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from invokeai.frontend.cli.app_arg_parser import app_arg_parser
|
||||
from invokeai.app.run_app import run_app
|
||||
|
||||
logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage())
|
||||
|
||||
|
||||
def main():
|
||||
# Parse CLI args immediately to handle `version` and `help` commands. Once the app starts up, we will parse the
|
||||
# args again to get configuration args.
|
||||
app_arg_parser.parse_args()
|
||||
|
||||
# Change working directory to the repo root
|
||||
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from invokeai.app.api_app import invoke_api
|
||||
|
||||
invoke_api()
|
||||
run_app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user