From b7c5a396856509fbe9b613e8e765bd84452e1d6c Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 17 May 2023 12:19:19 -0400 Subject: [PATCH] make invokeai.yaml more hierarchical; fix list configuration bug --- invokeai/app/services/config.py | 186 ++++++++++++------ invokeai/backend/config/invokeai_configure.py | 1 + tests/test_config.py | 18 +- 3 files changed, 143 insertions(+), 62 deletions(-) diff --git a/invokeai/app/services/config.py b/invokeai/app/services/config.py index 780f12fe24..8634d595a3 100644 --- a/invokeai/app/services/config.py +++ b/invokeai/app/services/config.py @@ -4,33 +4,57 @@ Arguments and fields are taken from the pydantic definition of the model. Defaults can be set by creating a yaml configuration file that -has top-level keys corresponding to an invocation name, a command, or -"globals" for global values such as `xformers_enabled`. Currently -graphs cannot be configured this way, but their constituents can be. +has a top-level key of "InvokeAI" and subheadings for each of the +categories returned by `invokeai --help`. The file looks like this: [file: invokeai.yaml] - globals: - nsfw_checker: False - max_loaded_models: 5 - - txt2img: - steps: 20 - scheduler: k_heun - width: 768 - - img2img: - width: 1024 - height: 1024 +InvokeAI: + Paths: + root: /home/lstein/invokeai-main + conf_path: configs/models.yaml + legacy_conf_dir: configs/stable-diffusion + outdir: outputs + embedding_dir: embeddings + lora_dir: loras + autoconvert_dir: null + gfpgan_model_dir: models/gfpgan/GFPGANv1.4.pth + Models: + model: stable-diffusion-1.5 + embeddings: true + Memory/Performance: + xformers_enabled: false + sequential_guidance: false + precision: float16 + max_loaded_models: 4 + always_use_cpu: false + free_gpu_mem: false + Features: + nsfw_checker: true + restore: true + esrgan: true + patchmatch: true + internet_available: true + log_tokenization: false + Cross-Origin Resource Sharing: + allow_origins: [] + allow_credentials: true + allow_methods: + - '*' + allow_headers: + - '*' + Web Server: + host: 127.0.0.1 + port: 8081 The default name of the configuration file is `invokeai.yaml`, located -in INVOKEAI_ROOT. You can use any OmegaConf dictionary by passing it -to the config object at initialization time: +in INVOKEAI_ROOT. You can replace supersede this by providing any +OmegaConf dictionary object initialization time: omegaconf = OmegaConf.load('/tmp/init.yaml') conf = InvokeAIAppConfig(conf=omegaconf) -By default, InvokeAIAppConfig will parse the contents of argv at +By default, InvokeAIAppConfig will parse the contents of `sys.argv` at initialization time. You may pass a list of strings in the optional `argv` argument to use instead of the system argv: @@ -42,9 +66,9 @@ has highest priority. conf = InvokeAIAppConfig(xformers_enabled=True) Any setting can be overwritten by setting an environment variable of -form: "INVOKEAI__", as in: +form: "INVOKEAI_", as in: - export INVOKEAI_txt2img_steps=30 + export INVOKEAI_port=8080 Order of precedence (from highest): 1) initialization options @@ -86,8 +110,40 @@ does this: config = get_invokeai_config() print(config.root) +# Subclassing + +If you wish to create a similar class, please subclass the +`InvokeAISettings` class and define a Literal field named "type", +which is set to the desired top-level name. For example, to create a +"InvokeBatch" configuration, define like this: + + class InvokeBatch(InvokeAISettings): + type: Literal["InvokeBatch"] = "InvokeBatch" + node_count : int = Field(default=1, description="Number of nodes to run on", category='Resources') + cpu_count : int = Field(default=8, description="Number of GPUs to run on per node", category='Resources') + +This will now read and write from the "InvokeBatch" section of the +config file, look for environment variables named INVOKEBATCH_*, and +accept the command-line arguments `--node_count` and `--cpu_count`. The +two configs are kept in separate sections of the config file: + + # invokeai.yaml + + InvokeBatch: + Resources: + node_count: 1 + cpu_count: 8 + + InvokeAI: + Paths: + root: /home/lstein/invokeai-main + conf_path: configs/models.yaml + legacy_conf_dir: configs/stable-diffusion + outdir: outputs + ... ''' import argparse +import typing import os import sys from argparse import ArgumentParser @@ -117,33 +173,62 @@ class InvokeAISettings(BaseSettings): if name not in self._excluded(): setattr(self, name, getattr(opt,name)) + def to_yaml(self)->str: + """ + Return a YAML string representing our settings. This can be used + as the contents of `invokeai.yaml` to restore settings later. + """ + cls = self.__class__ + type = get_args(get_type_hints(cls)['type'])[0] + field_dict = dict({type:dict()}) + for name,field in self.__fields__.items(): + if name in cls._excluded(): + continue + category = field.field_info.extra.get("category") or "Uncategorized" + value = getattr(self,name) + if category not in field_dict[type]: + field_dict[type][category] = dict() + # keep paths as strings to make it easier to read + field_dict[type][category][name] = str(value) if isinstance(value,Path) else value + conf = OmegaConf.create(field_dict) + return OmegaConf.to_yaml(conf) + @classmethod def add_parser_arguments(cls, parser): - env_prefix = cls.Config.env_prefix if hasattr(cls.Config,'env_prefix') else 'INVOKEAI_' if 'type' in get_type_hints(cls): - default_settings_stanza = get_args(get_type_hints(cls)['type'])[0] + settings_stanza = get_args(get_type_hints(cls)['type'])[0] else: - default_settings_stanza = 'globals' - initconf = cls.initconf.get(default_settings_stanza) if cls.initconf and default_settings_stanza in cls.initconf else None + settings_stanza = "Uncategorized" + + env_prefix = cls.Config.env_prefix if hasattr(cls.Config,'env_prefix') else settings_stanza.upper() + + initconf = cls.initconf.get(settings_stanza) \ + if cls.initconf and settings_stanza in cls.initconf \ + else OmegaConf.create() fields = cls.__fields__ cls.argparse_groups = {} for name, field in fields.items(): if name not in cls._excluded(): - env_name = env_prefix+f'{cls.cmd_name()}_{name}' - if initconf and name in initconf: - field.default = initconf.get(name) + current_default = field.default + + category = field.field_info.extra.get("category","Uncategorized") + env_name = env_prefix + '_' + name + if category in initconf and name in initconf.get(category): + field.default = initconf.get(category).get(name) if env_name in os.environ: field.default = os.environ[env_name] cls.add_field_argument(parser, name, field) + field.default = current_default + @classmethod def cmd_name(self, command_field: str='type')->str: hints = get_type_hints(self) if command_field in hints: return get_args(hints[command_field])[0] else: - return 'globals' + return 'Uncategorized' @classmethod def get_parser(cls)->ArgumentParser: @@ -165,31 +250,11 @@ class InvokeAISettings(BaseSettings): class Config: env_file_encoding = 'utf-8' arbitrary_types_allowed = True - env_prefix = 'INVOKEAI_' case_sensitive = True - @classmethod - def customise_sources( - cls, - init_settings, - env_settings, - file_secret_settings, - ): - return ( - init_settings, - cls._omegaconf_settings_source, - env_settings, - file_secret_settings, - ) - - @classmethod - def _omegaconf_settings_source(cls, settings: BaseSettings) -> dict[str, Any]: - if initconf := InvokeAISettings.initconf: - return initconf.get(settings.cmd_name(),{}) - else: - return {} @classmethod def add_field_argument(cls, command_parser, name: str, field, default_override = None): + field_type = get_type_hints(cls).get(name) default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory() if category := field.field_info.extra.get("category"): if category not in cls.argparse_groups: @@ -198,7 +263,7 @@ class InvokeAISettings(BaseSettings): else: argparse_group = command_parser - if get_origin(field.type_) == Literal: + if get_origin(field_type) == Literal: allowed_values = get_args(field.type_) allowed_types = set() for val in allowed_values: @@ -214,6 +279,17 @@ class InvokeAISettings(BaseSettings): choices=allowed_values, help=field.field_info.description, ) + + elif get_origin(field_type) == list: + argparse_group.add_argument( + f"--{name}", + dest=name, + nargs='*', + type=field.type_, + default=default, + action=argparse.BooleanOptionalAction if field.type_==bool else 'store', + help=field.field_info.description, + ) else: argparse_group.add_argument( f"--{name}", @@ -243,7 +319,7 @@ class InvokeAIAppConfig(InvokeAISettings): Application-wide settings. ''' #fmt: off - type: Literal["globals"] = "globals" + type: Literal["InvokeAI"] = "InvokeAI" root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths') conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths') legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths') @@ -266,10 +342,10 @@ class InvokeAIAppConfig(InvokeAISettings): patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features') internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features') log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features') - allow_origins : List = Field(default=[], description="Allowed CORS origins", category='Cross-Origin Resource Sharing') + allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", category='Cross-Origin Resource Sharing') allow_credentials : bool = Field(default=True, description="Allow CORS credentials", category='Cross-Origin Resource Sharing') - allow_methods : List = Field(default=["*"], description="Methods allowed for CORS", category='Cross-Origin Resource Sharing') - allow_headers : List = Field(default=["*"], description="Headers allowed for CORS", category='Cross-Origin Resource Sharing') + allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", category='Cross-Origin Resource Sharing') + allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", category='Cross-Origin Resource Sharing') host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server') port : int = Field(default=9090, description="Port to bind to", category='Web Server') #fmt: on @@ -286,7 +362,7 @@ class InvokeAIAppConfig(InvokeAISettings): # Set the runtime root directory. We parse command-line switches here # in order to pick up the --root_dir option. self.parse_args(argv) - if not conf: + if conf is None: try: conf = OmegaConf.load(self.root_dir / INIT_FILE) except: diff --git a/invokeai/backend/config/invokeai_configure.py b/invokeai/backend/config/invokeai_configure.py index e1a336a9b8..12260be208 100755 --- a/invokeai/backend/config/invokeai_configure.py +++ b/invokeai/backend/config/invokeai_configure.py @@ -727,6 +727,7 @@ def write_default_options(program_opts: Namespace, initfile: Path): def migrate_init_file(legacy_format:Path): old = legacy_parser.parse_args([f'@{str(legacy_format)}']) + new = new = OmegaConf.create() new.globals = dict() diff --git a/tests/test_config.py b/tests/test_config.py index c371e77a58..f61216ab15 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -10,17 +10,21 @@ from invokeai.app.invocations.generate import TextToImageInvocation init1 = OmegaConf.create( ''' -globals: - nsfw_checker: False - max_loaded_models: 5 +InvokeAI: + Features: + nsfw_checker: False + Memory/Performance: + max_loaded_models: 5 ''' ) init2 = OmegaConf.create( ''' - globals: - nsfw_checker: True - max_loaded_models: 2 +InvokeAI: + Features: + nsfw_checker: true + Memory/Performance: + max_loaded_models: 2 ''' ) @@ -50,7 +54,7 @@ def test_env_override(): conf = InvokeAIAppConfig(conf=init1,argv=['--max_loaded=10']) assert conf.nsfw_checker==False - os.environ['INVOKEAI_globals_nsfw_checker'] = 'True' + os.environ['INVOKEAI_nsfw_checker'] = 'True' conf = InvokeAIAppConfig(conf=init1,argv=['--max_loaded=10']) assert conf.nsfw_checker==True