make invokeai.yaml more hierarchical; fix list configuration bug

This commit is contained in:
Lincoln Stein 2023-05-17 12:19:19 -04:00
parent eadfd239a8
commit b7c5a39685
3 changed files with 143 additions and 62 deletions

View File

@ -4,33 +4,57 @@
Arguments and fields are taken from the pydantic definition of the
model. Defaults can be set by creating a yaml configuration file that
has top-level keys corresponding to an invocation name, a command, or
"globals" for global values such as `xformers_enabled`. Currently
graphs cannot be configured this way, but their constituents can be.
has a top-level key of "InvokeAI" and subheadings for each of the
categories returned by `invokeai --help`. The file looks like this:
[file: invokeai.yaml]
globals:
nsfw_checker: False
max_loaded_models: 5
txt2img:
steps: 20
scheduler: k_heun
width: 768
img2img:
width: 1024
height: 1024
InvokeAI:
Paths:
root: /home/lstein/invokeai-main
conf_path: configs/models.yaml
legacy_conf_dir: configs/stable-diffusion
outdir: outputs
embedding_dir: embeddings
lora_dir: loras
autoconvert_dir: null
gfpgan_model_dir: models/gfpgan/GFPGANv1.4.pth
Models:
model: stable-diffusion-1.5
embeddings: true
Memory/Performance:
xformers_enabled: false
sequential_guidance: false
precision: float16
max_loaded_models: 4
always_use_cpu: false
free_gpu_mem: false
Features:
nsfw_checker: true
restore: true
esrgan: true
patchmatch: true
internet_available: true
log_tokenization: false
Cross-Origin Resource Sharing:
allow_origins: []
allow_credentials: true
allow_methods:
- '*'
allow_headers:
- '*'
Web Server:
host: 127.0.0.1
port: 8081
The default name of the configuration file is `invokeai.yaml`, located
in INVOKEAI_ROOT. You can use any OmegaConf dictionary by passing it
to the config object at initialization time:
in INVOKEAI_ROOT. You can replace supersede this by providing any
OmegaConf dictionary object initialization time:
omegaconf = OmegaConf.load('/tmp/init.yaml')
conf = InvokeAIAppConfig(conf=omegaconf)
By default, InvokeAIAppConfig will parse the contents of argv at
By default, InvokeAIAppConfig will parse the contents of `sys.argv` at
initialization time. You may pass a list of strings in the optional
`argv` argument to use instead of the system argv:
@ -42,9 +66,9 @@ has highest priority.
conf = InvokeAIAppConfig(xformers_enabled=True)
Any setting can be overwritten by setting an environment variable of
form: "INVOKEAI_<command>_<value>", as in:
form: "INVOKEAI_<setting>", as in:
export INVOKEAI_txt2img_steps=30
export INVOKEAI_port=8080
Order of precedence (from highest):
1) initialization options
@ -86,8 +110,40 @@ does this:
config = get_invokeai_config()
print(config.root)
# Subclassing
If you wish to create a similar class, please subclass the
`InvokeAISettings` class and define a Literal field named "type",
which is set to the desired top-level name. For example, to create a
"InvokeBatch" configuration, define like this:
class InvokeBatch(InvokeAISettings):
type: Literal["InvokeBatch"] = "InvokeBatch"
node_count : int = Field(default=1, description="Number of nodes to run on", category='Resources')
cpu_count : int = Field(default=8, description="Number of GPUs to run on per node", category='Resources')
This will now read and write from the "InvokeBatch" section of the
config file, look for environment variables named INVOKEBATCH_*, and
accept the command-line arguments `--node_count` and `--cpu_count`. The
two configs are kept in separate sections of the config file:
# invokeai.yaml
InvokeBatch:
Resources:
node_count: 1
cpu_count: 8
InvokeAI:
Paths:
root: /home/lstein/invokeai-main
conf_path: configs/models.yaml
legacy_conf_dir: configs/stable-diffusion
outdir: outputs
...
'''
import argparse
import typing
import os
import sys
from argparse import ArgumentParser
@ -117,33 +173,62 @@ class InvokeAISettings(BaseSettings):
if name not in self._excluded():
setattr(self, name, getattr(opt,name))
def to_yaml(self)->str:
"""
Return a YAML string representing our settings. This can be used
as the contents of `invokeai.yaml` to restore settings later.
"""
cls = self.__class__
type = get_args(get_type_hints(cls)['type'])[0]
field_dict = dict({type:dict()})
for name,field in self.__fields__.items():
if name in cls._excluded():
continue
category = field.field_info.extra.get("category") or "Uncategorized"
value = getattr(self,name)
if category not in field_dict[type]:
field_dict[type][category] = dict()
# keep paths as strings to make it easier to read
field_dict[type][category][name] = str(value) if isinstance(value,Path) else value
conf = OmegaConf.create(field_dict)
return OmegaConf.to_yaml(conf)
@classmethod
def add_parser_arguments(cls, parser):
env_prefix = cls.Config.env_prefix if hasattr(cls.Config,'env_prefix') else 'INVOKEAI_'
if 'type' in get_type_hints(cls):
default_settings_stanza = get_args(get_type_hints(cls)['type'])[0]
settings_stanza = get_args(get_type_hints(cls)['type'])[0]
else:
default_settings_stanza = 'globals'
initconf = cls.initconf.get(default_settings_stanza) if cls.initconf and default_settings_stanza in cls.initconf else None
settings_stanza = "Uncategorized"
env_prefix = cls.Config.env_prefix if hasattr(cls.Config,'env_prefix') else settings_stanza.upper()
initconf = cls.initconf.get(settings_stanza) \
if cls.initconf and settings_stanza in cls.initconf \
else OmegaConf.create()
fields = cls.__fields__
cls.argparse_groups = {}
for name, field in fields.items():
if name not in cls._excluded():
env_name = env_prefix+f'{cls.cmd_name()}_{name}'
if initconf and name in initconf:
field.default = initconf.get(name)
current_default = field.default
category = field.field_info.extra.get("category","Uncategorized")
env_name = env_prefix + '_' + name
if category in initconf and name in initconf.get(category):
field.default = initconf.get(category).get(name)
if env_name in os.environ:
field.default = os.environ[env_name]
cls.add_field_argument(parser, name, field)
field.default = current_default
@classmethod
def cmd_name(self, command_field: str='type')->str:
hints = get_type_hints(self)
if command_field in hints:
return get_args(hints[command_field])[0]
else:
return 'globals'
return 'Uncategorized'
@classmethod
def get_parser(cls)->ArgumentParser:
@ -165,31 +250,11 @@ class InvokeAISettings(BaseSettings):
class Config:
env_file_encoding = 'utf-8'
arbitrary_types_allowed = True
env_prefix = 'INVOKEAI_'
case_sensitive = True
@classmethod
def customise_sources(
cls,
init_settings,
env_settings,
file_secret_settings,
):
return (
init_settings,
cls._omegaconf_settings_source,
env_settings,
file_secret_settings,
)
@classmethod
def _omegaconf_settings_source(cls, settings: BaseSettings) -> dict[str, Any]:
if initconf := InvokeAISettings.initconf:
return initconf.get(settings.cmd_name(),{})
else:
return {}
@classmethod
def add_field_argument(cls, command_parser, name: str, field, default_override = None):
field_type = get_type_hints(cls).get(name)
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
if category := field.field_info.extra.get("category"):
if category not in cls.argparse_groups:
@ -198,7 +263,7 @@ class InvokeAISettings(BaseSettings):
else:
argparse_group = command_parser
if get_origin(field.type_) == Literal:
if get_origin(field_type) == Literal:
allowed_values = get_args(field.type_)
allowed_types = set()
for val in allowed_values:
@ -214,6 +279,17 @@ class InvokeAISettings(BaseSettings):
choices=allowed_values,
help=field.field_info.description,
)
elif get_origin(field_type) == list:
argparse_group.add_argument(
f"--{name}",
dest=name,
nargs='*',
type=field.type_,
default=default,
action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
help=field.field_info.description,
)
else:
argparse_group.add_argument(
f"--{name}",
@ -243,7 +319,7 @@ class InvokeAIAppConfig(InvokeAISettings):
Application-wide settings.
'''
#fmt: off
type: Literal["globals"] = "globals"
type: Literal["InvokeAI"] = "InvokeAI"
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
@ -266,10 +342,10 @@ class InvokeAIAppConfig(InvokeAISettings):
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
allow_origins : List = Field(default=[], description="Allowed CORS origins", category='Cross-Origin Resource Sharing')
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", category='Cross-Origin Resource Sharing')
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", category='Cross-Origin Resource Sharing')
allow_methods : List = Field(default=["*"], description="Methods allowed for CORS", category='Cross-Origin Resource Sharing')
allow_headers : List = Field(default=["*"], description="Headers allowed for CORS", category='Cross-Origin Resource Sharing')
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", category='Cross-Origin Resource Sharing')
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", category='Cross-Origin Resource Sharing')
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
#fmt: on
@ -286,7 +362,7 @@ class InvokeAIAppConfig(InvokeAISettings):
# Set the runtime root directory. We parse command-line switches here
# in order to pick up the --root_dir option.
self.parse_args(argv)
if not conf:
if conf is None:
try:
conf = OmegaConf.load(self.root_dir / INIT_FILE)
except:

View File

@ -727,6 +727,7 @@ def write_default_options(program_opts: Namespace, initfile: Path):
def migrate_init_file(legacy_format:Path):
old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
new =
new = OmegaConf.create()
new.globals = dict()

View File

@ -10,17 +10,21 @@ from invokeai.app.invocations.generate import TextToImageInvocation
init1 = OmegaConf.create(
'''
globals:
nsfw_checker: False
max_loaded_models: 5
InvokeAI:
Features:
nsfw_checker: False
Memory/Performance:
max_loaded_models: 5
'''
)
init2 = OmegaConf.create(
'''
globals:
nsfw_checker: True
max_loaded_models: 2
InvokeAI:
Features:
nsfw_checker: true
Memory/Performance:
max_loaded_models: 2
'''
)
@ -50,7 +54,7 @@ def test_env_override():
conf = InvokeAIAppConfig(conf=init1,argv=['--max_loaded=10'])
assert conf.nsfw_checker==False
os.environ['INVOKEAI_globals_nsfw_checker'] = 'True'
os.environ['INVOKEAI_nsfw_checker'] = 'True'
conf = InvokeAIAppConfig(conf=init1,argv=['--max_loaded=10'])
assert conf.nsfw_checker==True