mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
make invokeai.yaml more hierarchical; fix list configuration bug
This commit is contained in:
parent
eadfd239a8
commit
b7c5a39685
@ -4,33 +4,57 @@
|
|||||||
|
|
||||||
Arguments and fields are taken from the pydantic definition of the
|
Arguments and fields are taken from the pydantic definition of the
|
||||||
model. Defaults can be set by creating a yaml configuration file that
|
model. Defaults can be set by creating a yaml configuration file that
|
||||||
has top-level keys corresponding to an invocation name, a command, or
|
has a top-level key of "InvokeAI" and subheadings for each of the
|
||||||
"globals" for global values such as `xformers_enabled`. Currently
|
categories returned by `invokeai --help`. The file looks like this:
|
||||||
graphs cannot be configured this way, but their constituents can be.
|
|
||||||
|
|
||||||
[file: invokeai.yaml]
|
[file: invokeai.yaml]
|
||||||
|
|
||||||
globals:
|
InvokeAI:
|
||||||
nsfw_checker: False
|
Paths:
|
||||||
max_loaded_models: 5
|
root: /home/lstein/invokeai-main
|
||||||
|
conf_path: configs/models.yaml
|
||||||
txt2img:
|
legacy_conf_dir: configs/stable-diffusion
|
||||||
steps: 20
|
outdir: outputs
|
||||||
scheduler: k_heun
|
embedding_dir: embeddings
|
||||||
width: 768
|
lora_dir: loras
|
||||||
|
autoconvert_dir: null
|
||||||
img2img:
|
gfpgan_model_dir: models/gfpgan/GFPGANv1.4.pth
|
||||||
width: 1024
|
Models:
|
||||||
height: 1024
|
model: stable-diffusion-1.5
|
||||||
|
embeddings: true
|
||||||
|
Memory/Performance:
|
||||||
|
xformers_enabled: false
|
||||||
|
sequential_guidance: false
|
||||||
|
precision: float16
|
||||||
|
max_loaded_models: 4
|
||||||
|
always_use_cpu: false
|
||||||
|
free_gpu_mem: false
|
||||||
|
Features:
|
||||||
|
nsfw_checker: true
|
||||||
|
restore: true
|
||||||
|
esrgan: true
|
||||||
|
patchmatch: true
|
||||||
|
internet_available: true
|
||||||
|
log_tokenization: false
|
||||||
|
Cross-Origin Resource Sharing:
|
||||||
|
allow_origins: []
|
||||||
|
allow_credentials: true
|
||||||
|
allow_methods:
|
||||||
|
- '*'
|
||||||
|
allow_headers:
|
||||||
|
- '*'
|
||||||
|
Web Server:
|
||||||
|
host: 127.0.0.1
|
||||||
|
port: 8081
|
||||||
|
|
||||||
The default name of the configuration file is `invokeai.yaml`, located
|
The default name of the configuration file is `invokeai.yaml`, located
|
||||||
in INVOKEAI_ROOT. You can use any OmegaConf dictionary by passing it
|
in INVOKEAI_ROOT. You can replace supersede this by providing any
|
||||||
to the config object at initialization time:
|
OmegaConf dictionary object initialization time:
|
||||||
|
|
||||||
omegaconf = OmegaConf.load('/tmp/init.yaml')
|
omegaconf = OmegaConf.load('/tmp/init.yaml')
|
||||||
conf = InvokeAIAppConfig(conf=omegaconf)
|
conf = InvokeAIAppConfig(conf=omegaconf)
|
||||||
|
|
||||||
By default, InvokeAIAppConfig will parse the contents of argv at
|
By default, InvokeAIAppConfig will parse the contents of `sys.argv` at
|
||||||
initialization time. You may pass a list of strings in the optional
|
initialization time. You may pass a list of strings in the optional
|
||||||
`argv` argument to use instead of the system argv:
|
`argv` argument to use instead of the system argv:
|
||||||
|
|
||||||
@ -42,9 +66,9 @@ has highest priority.
|
|||||||
conf = InvokeAIAppConfig(xformers_enabled=True)
|
conf = InvokeAIAppConfig(xformers_enabled=True)
|
||||||
|
|
||||||
Any setting can be overwritten by setting an environment variable of
|
Any setting can be overwritten by setting an environment variable of
|
||||||
form: "INVOKEAI_<command>_<value>", as in:
|
form: "INVOKEAI_<setting>", as in:
|
||||||
|
|
||||||
export INVOKEAI_txt2img_steps=30
|
export INVOKEAI_port=8080
|
||||||
|
|
||||||
Order of precedence (from highest):
|
Order of precedence (from highest):
|
||||||
1) initialization options
|
1) initialization options
|
||||||
@ -86,8 +110,40 @@ does this:
|
|||||||
config = get_invokeai_config()
|
config = get_invokeai_config()
|
||||||
print(config.root)
|
print(config.root)
|
||||||
|
|
||||||
|
# Subclassing
|
||||||
|
|
||||||
|
If you wish to create a similar class, please subclass the
|
||||||
|
`InvokeAISettings` class and define a Literal field named "type",
|
||||||
|
which is set to the desired top-level name. For example, to create a
|
||||||
|
"InvokeBatch" configuration, define like this:
|
||||||
|
|
||||||
|
class InvokeBatch(InvokeAISettings):
|
||||||
|
type: Literal["InvokeBatch"] = "InvokeBatch"
|
||||||
|
node_count : int = Field(default=1, description="Number of nodes to run on", category='Resources')
|
||||||
|
cpu_count : int = Field(default=8, description="Number of GPUs to run on per node", category='Resources')
|
||||||
|
|
||||||
|
This will now read and write from the "InvokeBatch" section of the
|
||||||
|
config file, look for environment variables named INVOKEBATCH_*, and
|
||||||
|
accept the command-line arguments `--node_count` and `--cpu_count`. The
|
||||||
|
two configs are kept in separate sections of the config file:
|
||||||
|
|
||||||
|
# invokeai.yaml
|
||||||
|
|
||||||
|
InvokeBatch:
|
||||||
|
Resources:
|
||||||
|
node_count: 1
|
||||||
|
cpu_count: 8
|
||||||
|
|
||||||
|
InvokeAI:
|
||||||
|
Paths:
|
||||||
|
root: /home/lstein/invokeai-main
|
||||||
|
conf_path: configs/models.yaml
|
||||||
|
legacy_conf_dir: configs/stable-diffusion
|
||||||
|
outdir: outputs
|
||||||
|
...
|
||||||
'''
|
'''
|
||||||
import argparse
|
import argparse
|
||||||
|
import typing
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
@ -117,33 +173,62 @@ class InvokeAISettings(BaseSettings):
|
|||||||
if name not in self._excluded():
|
if name not in self._excluded():
|
||||||
setattr(self, name, getattr(opt,name))
|
setattr(self, name, getattr(opt,name))
|
||||||
|
|
||||||
|
def to_yaml(self)->str:
|
||||||
|
"""
|
||||||
|
Return a YAML string representing our settings. This can be used
|
||||||
|
as the contents of `invokeai.yaml` to restore settings later.
|
||||||
|
"""
|
||||||
|
cls = self.__class__
|
||||||
|
type = get_args(get_type_hints(cls)['type'])[0]
|
||||||
|
field_dict = dict({type:dict()})
|
||||||
|
for name,field in self.__fields__.items():
|
||||||
|
if name in cls._excluded():
|
||||||
|
continue
|
||||||
|
category = field.field_info.extra.get("category") or "Uncategorized"
|
||||||
|
value = getattr(self,name)
|
||||||
|
if category not in field_dict[type]:
|
||||||
|
field_dict[type][category] = dict()
|
||||||
|
# keep paths as strings to make it easier to read
|
||||||
|
field_dict[type][category][name] = str(value) if isinstance(value,Path) else value
|
||||||
|
conf = OmegaConf.create(field_dict)
|
||||||
|
return OmegaConf.to_yaml(conf)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_parser_arguments(cls, parser):
|
def add_parser_arguments(cls, parser):
|
||||||
env_prefix = cls.Config.env_prefix if hasattr(cls.Config,'env_prefix') else 'INVOKEAI_'
|
|
||||||
if 'type' in get_type_hints(cls):
|
if 'type' in get_type_hints(cls):
|
||||||
default_settings_stanza = get_args(get_type_hints(cls)['type'])[0]
|
settings_stanza = get_args(get_type_hints(cls)['type'])[0]
|
||||||
else:
|
else:
|
||||||
default_settings_stanza = 'globals'
|
settings_stanza = "Uncategorized"
|
||||||
initconf = cls.initconf.get(default_settings_stanza) if cls.initconf and default_settings_stanza in cls.initconf else None
|
|
||||||
|
env_prefix = cls.Config.env_prefix if hasattr(cls.Config,'env_prefix') else settings_stanza.upper()
|
||||||
|
|
||||||
|
initconf = cls.initconf.get(settings_stanza) \
|
||||||
|
if cls.initconf and settings_stanza in cls.initconf \
|
||||||
|
else OmegaConf.create()
|
||||||
|
|
||||||
fields = cls.__fields__
|
fields = cls.__fields__
|
||||||
cls.argparse_groups = {}
|
cls.argparse_groups = {}
|
||||||
for name, field in fields.items():
|
for name, field in fields.items():
|
||||||
if name not in cls._excluded():
|
if name not in cls._excluded():
|
||||||
env_name = env_prefix+f'{cls.cmd_name()}_{name}'
|
current_default = field.default
|
||||||
if initconf and name in initconf:
|
|
||||||
field.default = initconf.get(name)
|
category = field.field_info.extra.get("category","Uncategorized")
|
||||||
|
env_name = env_prefix + '_' + name
|
||||||
|
if category in initconf and name in initconf.get(category):
|
||||||
|
field.default = initconf.get(category).get(name)
|
||||||
if env_name in os.environ:
|
if env_name in os.environ:
|
||||||
field.default = os.environ[env_name]
|
field.default = os.environ[env_name]
|
||||||
cls.add_field_argument(parser, name, field)
|
cls.add_field_argument(parser, name, field)
|
||||||
|
|
||||||
|
field.default = current_default
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def cmd_name(self, command_field: str='type')->str:
|
def cmd_name(self, command_field: str='type')->str:
|
||||||
hints = get_type_hints(self)
|
hints = get_type_hints(self)
|
||||||
if command_field in hints:
|
if command_field in hints:
|
||||||
return get_args(hints[command_field])[0]
|
return get_args(hints[command_field])[0]
|
||||||
else:
|
else:
|
||||||
return 'globals'
|
return 'Uncategorized'
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_parser(cls)->ArgumentParser:
|
def get_parser(cls)->ArgumentParser:
|
||||||
@ -165,31 +250,11 @@ class InvokeAISettings(BaseSettings):
|
|||||||
class Config:
|
class Config:
|
||||||
env_file_encoding = 'utf-8'
|
env_file_encoding = 'utf-8'
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
env_prefix = 'INVOKEAI_'
|
|
||||||
case_sensitive = True
|
case_sensitive = True
|
||||||
@classmethod
|
|
||||||
def customise_sources(
|
|
||||||
cls,
|
|
||||||
init_settings,
|
|
||||||
env_settings,
|
|
||||||
file_secret_settings,
|
|
||||||
):
|
|
||||||
return (
|
|
||||||
init_settings,
|
|
||||||
cls._omegaconf_settings_source,
|
|
||||||
env_settings,
|
|
||||||
file_secret_settings,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _omegaconf_settings_source(cls, settings: BaseSettings) -> dict[str, Any]:
|
|
||||||
if initconf := InvokeAISettings.initconf:
|
|
||||||
return initconf.get(settings.cmd_name(),{})
|
|
||||||
else:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_field_argument(cls, command_parser, name: str, field, default_override = None):
|
def add_field_argument(cls, command_parser, name: str, field, default_override = None):
|
||||||
|
field_type = get_type_hints(cls).get(name)
|
||||||
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
|
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
|
||||||
if category := field.field_info.extra.get("category"):
|
if category := field.field_info.extra.get("category"):
|
||||||
if category not in cls.argparse_groups:
|
if category not in cls.argparse_groups:
|
||||||
@ -198,7 +263,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
else:
|
else:
|
||||||
argparse_group = command_parser
|
argparse_group = command_parser
|
||||||
|
|
||||||
if get_origin(field.type_) == Literal:
|
if get_origin(field_type) == Literal:
|
||||||
allowed_values = get_args(field.type_)
|
allowed_values = get_args(field.type_)
|
||||||
allowed_types = set()
|
allowed_types = set()
|
||||||
for val in allowed_values:
|
for val in allowed_values:
|
||||||
@ -214,6 +279,17 @@ class InvokeAISettings(BaseSettings):
|
|||||||
choices=allowed_values,
|
choices=allowed_values,
|
||||||
help=field.field_info.description,
|
help=field.field_info.description,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif get_origin(field_type) == list:
|
||||||
|
argparse_group.add_argument(
|
||||||
|
f"--{name}",
|
||||||
|
dest=name,
|
||||||
|
nargs='*',
|
||||||
|
type=field.type_,
|
||||||
|
default=default,
|
||||||
|
action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
|
||||||
|
help=field.field_info.description,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
argparse_group.add_argument(
|
argparse_group.add_argument(
|
||||||
f"--{name}",
|
f"--{name}",
|
||||||
@ -243,7 +319,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
Application-wide settings.
|
Application-wide settings.
|
||||||
'''
|
'''
|
||||||
#fmt: off
|
#fmt: off
|
||||||
type: Literal["globals"] = "globals"
|
type: Literal["InvokeAI"] = "InvokeAI"
|
||||||
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
|
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
|
||||||
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
|
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
|
||||||
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
|
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
|
||||||
@ -266,10 +342,10 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
|
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
|
||||||
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
|
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
|
||||||
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
|
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
|
||||||
allow_origins : List = Field(default=[], description="Allowed CORS origins", category='Cross-Origin Resource Sharing')
|
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", category='Cross-Origin Resource Sharing')
|
||||||
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", category='Cross-Origin Resource Sharing')
|
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", category='Cross-Origin Resource Sharing')
|
||||||
allow_methods : List = Field(default=["*"], description="Methods allowed for CORS", category='Cross-Origin Resource Sharing')
|
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", category='Cross-Origin Resource Sharing')
|
||||||
allow_headers : List = Field(default=["*"], description="Headers allowed for CORS", category='Cross-Origin Resource Sharing')
|
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", category='Cross-Origin Resource Sharing')
|
||||||
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
||||||
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
|
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
|
||||||
#fmt: on
|
#fmt: on
|
||||||
@ -286,7 +362,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
# Set the runtime root directory. We parse command-line switches here
|
# Set the runtime root directory. We parse command-line switches here
|
||||||
# in order to pick up the --root_dir option.
|
# in order to pick up the --root_dir option.
|
||||||
self.parse_args(argv)
|
self.parse_args(argv)
|
||||||
if not conf:
|
if conf is None:
|
||||||
try:
|
try:
|
||||||
conf = OmegaConf.load(self.root_dir / INIT_FILE)
|
conf = OmegaConf.load(self.root_dir / INIT_FILE)
|
||||||
except:
|
except:
|
||||||
|
@ -727,6 +727,7 @@ def write_default_options(program_opts: Namespace, initfile: Path):
|
|||||||
def migrate_init_file(legacy_format:Path):
|
def migrate_init_file(legacy_format:Path):
|
||||||
|
|
||||||
old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
|
old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
|
||||||
|
new =
|
||||||
new = OmegaConf.create()
|
new = OmegaConf.create()
|
||||||
|
|
||||||
new.globals = dict()
|
new.globals = dict()
|
||||||
|
@ -10,17 +10,21 @@ from invokeai.app.invocations.generate import TextToImageInvocation
|
|||||||
|
|
||||||
init1 = OmegaConf.create(
|
init1 = OmegaConf.create(
|
||||||
'''
|
'''
|
||||||
globals:
|
InvokeAI:
|
||||||
nsfw_checker: False
|
Features:
|
||||||
max_loaded_models: 5
|
nsfw_checker: False
|
||||||
|
Memory/Performance:
|
||||||
|
max_loaded_models: 5
|
||||||
'''
|
'''
|
||||||
)
|
)
|
||||||
|
|
||||||
init2 = OmegaConf.create(
|
init2 = OmegaConf.create(
|
||||||
'''
|
'''
|
||||||
globals:
|
InvokeAI:
|
||||||
nsfw_checker: True
|
Features:
|
||||||
max_loaded_models: 2
|
nsfw_checker: true
|
||||||
|
Memory/Performance:
|
||||||
|
max_loaded_models: 2
|
||||||
'''
|
'''
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -50,7 +54,7 @@ def test_env_override():
|
|||||||
conf = InvokeAIAppConfig(conf=init1,argv=['--max_loaded=10'])
|
conf = InvokeAIAppConfig(conf=init1,argv=['--max_loaded=10'])
|
||||||
assert conf.nsfw_checker==False
|
assert conf.nsfw_checker==False
|
||||||
|
|
||||||
os.environ['INVOKEAI_globals_nsfw_checker'] = 'True'
|
os.environ['INVOKEAI_nsfw_checker'] = 'True'
|
||||||
conf = InvokeAIAppConfig(conf=init1,argv=['--max_loaded=10'])
|
conf = InvokeAIAppConfig(conf=init1,argv=['--max_loaded=10'])
|
||||||
assert conf.nsfw_checker==True
|
assert conf.nsfw_checker==True
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user