update config script to work with new config system

This commit is contained in:
Lincoln Stein 2023-05-17 00:18:19 -04:00
parent 8d75e50435
commit eadfd239a8
3 changed files with 45 additions and 87 deletions

View File

@ -17,7 +17,7 @@ from .api.dependencies import ApiDependencies
from .api.routers import images, sessions, models from .api.routers import images, sessions, models
from .api.sockets import SocketIO from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation from .invocations.baseinvocation import BaseInvocation
from .services.config import InvokeAIWebConfig from .services.config import InvokeAIAppConfig
# Create the app # Create the app
# TODO: create this all in a method so configuration/etc. can be passed in? # TODO: create this all in a method so configuration/etc. can be passed in?
@ -38,7 +38,7 @@ socket_io = SocketIO(app)
# parse command-line settings, environment and the init file # parse command-line settings, environment and the init file
# (this is a module global) # (this is a module global)
global web_config global web_config
web_config = InvokeAIWebConfig() web_config = InvokeAIAppConfig()
# Add startup event to load dependencies # Add startup event to load dependencies
@app.on_event("startup") @app.on_event("startup")

View File

@ -137,7 +137,6 @@ class InvokeAISettings(BaseSettings):
field.default = os.environ[env_name] field.default = os.environ[env_name]
cls.add_field_argument(parser, name, field) cls.add_field_argument(parser, name, field)
@classmethod @classmethod
def cmd_name(self, command_field: str='type')->str: def cmd_name(self, command_field: str='type')->str:
hints = get_type_hints(self) hints = get_type_hints(self)
@ -267,6 +266,12 @@ class InvokeAIAppConfig(InvokeAISettings):
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features') patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features') internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features') log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
allow_origins : List = Field(default=[], description="Allowed CORS origins", category='Cross-Origin Resource Sharing')
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", category='Cross-Origin Resource Sharing')
allow_methods : List = Field(default=["*"], description="Methods allowed for CORS", category='Cross-Origin Resource Sharing')
allow_headers : List = Field(default=["*"], description="Headers allowed for CORS", category='Cross-Origin Resource Sharing')
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
#fmt: on #fmt: on
def __init__(self, conf: DictConfig = None, argv: List[str]=None, **kwargs): def __init__(self, conf: DictConfig = None, argv: List[str]=None, **kwargs):
@ -403,21 +408,6 @@ class InvokeAIAppConfig(InvokeAISettings):
''' '''
return _find_root() return _find_root()
class InvokeAIWebConfig(InvokeAIAppConfig):
'''
Web-specific settings
'''
#fmt: off
type : Literal["web"] = "web"
allow_origins : List = Field(default=[], description="Allowed CORS origins", category='Cross-Origin Resource Sharing')
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", category='Cross-Origin Resource Sharing')
allow_methods : List = Field(default=["*"], description="Methods allowed for CORS", category='Cross-Origin Resource Sharing')
allow_headers : List = Field(default=["*"], description="Headers allowed for CORS", category='Cross-Origin Resource Sharing')
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
#fmt: on
def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig)->InvokeAISettings: def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig)->InvokeAISettings:
''' '''
This returns a singleton InvokeAIAppConfig configuration object. This returns a singleton InvokeAIAppConfig configuration object.

View File

@ -36,7 +36,6 @@ from transformers import (
CLIPTokenizer, CLIPTokenizer,
) )
import invokeai.configs as configs import invokeai.configs as configs
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
@ -54,7 +53,6 @@ from invokeai.backend.config.model_install_backend import (
) )
from invokeai.app.services.config import ( from invokeai.app.services.config import (
get_invokeai_config, get_invokeai_config,
InvokeAIWebConfig,
InvokeAIAppConfig, InvokeAIAppConfig,
) )
@ -62,6 +60,7 @@ warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
# --------------------------globals----------------------- # --------------------------globals-----------------------
config = get_invokeai_config() config = get_invokeai_config()
@ -86,13 +85,6 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
# This is the InvokeAI initialization file, which contains command-line default values. # This is the InvokeAI initialization file, which contains command-line default values.
# Feel free to edit. If anything goes wrong, you can re-initialize this file by deleting # Feel free to edit. If anything goes wrong, you can re-initialize this file by deleting
# or renaming it and then running invokeai-configure again. # or renaming it and then running invokeai-configure again.
# Place frequently-used startup commands here, one or more per line.
# Examples:
# --outdir=D:\data\images
# --no-nsfw_checker
# --web --host=0.0.0.0
# --steps=20
# -Ak_euler_a -C10.0
""" """
@ -105,10 +97,9 @@ If you installed manually from source or with 'pip install': activate the virtua
then run one of the following commands to start InvokeAI. then run one of the following commands to start InvokeAI.
Web UI: Web UI:
invokeai --web # (connect to http://localhost:9090) invokeai-web
invokeai --web --host 0.0.0.0 # (connect to http://your-lan-ip:9090 from another computer on the local network)
Command-line interface: Command-line client:
invokeai invokeai
If you installed using an installation script, run: If you installed using an installation script, run:
@ -340,7 +331,7 @@ class editOptsForm(npyscreen.FormMultiPage):
def create(self): def create(self):
program_opts = self.parentApp.program_opts program_opts = self.parentApp.program_opts
old_opts = self.parentApp.invokeai_opts old_opts = self.parentApp.invokeai_opts
first_time = not (config.root / 'invokeai.init').exists() first_time = not (config.root / 'invokeai.yaml').exists()
access_token = HfFolder.get_token() access_token = HfFolder.get_token()
window_width, window_height = get_terminal_size() window_width, window_height = get_terminal_size()
for i in [ for i in [
@ -374,7 +365,7 @@ class editOptsForm(npyscreen.FormMultiPage):
self.outdir = self.add_widget_intelligent( self.outdir = self.add_widget_intelligent(
npyscreen.TitleFilename, npyscreen.TitleFilename,
name="(<tab> autocompletes, ctrl-N advances):", name="(<tab> autocompletes, ctrl-N advances):",
value=old_opts.outdir or str(default_output_dir()), value=str(old_opts.outdir) or str(default_output_dir()),
select_dir=True, select_dir=True,
must_exist=False, must_exist=False,
use_two_lines=False, use_two_lines=False,
@ -389,7 +380,7 @@ class editOptsForm(npyscreen.FormMultiPage):
editable=False, editable=False,
color="CONTROL", color="CONTROL",
) )
self.safety_checker = self.add_widget_intelligent( self.nsfw_checker = self.add_widget_intelligent(
npyscreen.Checkbox, npyscreen.Checkbox,
name="NSFW checker", name="NSFW checker",
value=old_opts.nsfw_checker, value=old_opts.nsfw_checker,
@ -443,7 +434,7 @@ class editOptsForm(npyscreen.FormMultiPage):
relx=5, relx=5,
scroll_exit=True, scroll_exit=True,
) )
self.xformers = self.add_widget_intelligent( self.xformers_enabled = self.add_widget_intelligent(
npyscreen.Checkbox, npyscreen.Checkbox,
name="Enable xformers support if available", name="Enable xformers support if available",
value=old_opts.xformers_enabled, value=old_opts.xformers_enabled,
@ -578,10 +569,10 @@ class editOptsForm(npyscreen.FormMultiPage):
for attr in [ for attr in [
"outdir", "outdir",
"safety_checker", "nsfw_checker",
"free_gpu_mem", "free_gpu_mem",
"max_loaded_models", "max_loaded_models",
"xformers", "xformers_enabled",
"always_use_cpu", "always_use_cpu",
"embedding_path", "embedding_path",
]: ]:
@ -628,7 +619,7 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
def default_startup_options(init_file: Path) -> Namespace: def default_startup_options(init_file: Path) -> Namespace:
opts = InvokeAIWebConfig(argv=[]) opts = InvokeAIAppConfig(argv=[])
outdir = Path(opts.outdir) outdir = Path(opts.outdir)
if not outdir.is_absolute(): if not outdir.is_absolute():
opts.outdir = str(config.root / opts.outdir) opts.outdir = str(config.root / opts.outdir)
@ -689,53 +680,31 @@ def run_console_ui(
# ------------------------------------- # -------------------------------------
def write_opts(opts: Namespace, init_file: Path): def write_opts(opts: Namespace, init_file: Path):
""" """
Update the invokeai.init file with values from opts Namespace Update the invokeai.init file with values from current settings.
""" """
# touch file if it doesn't exist
if not init_file.exists():
with open(init_file, "w") as f:
f.write(INIT_FILE_PREAMBLE)
# We want to write in the changed arguments without clobbering if Path(init_file).exists():
# any other initialization values the user has entered. There is config = OmegaConf.load(init_file)
# no good way to do this because of the one-way nature of else:
# argparse: i.e. --outdir could be --outdir, --out, or -o config = OmegaConf.create()
# initfile needs to be replaced with a fully structured format
# such as yaml; this is a hack that will work much of the time
args_to_skip = re.compile(
"^--?(o|out|no-xformer|xformer|no-ckpt|ckpt|free|no-nsfw|nsfw|prec|max_load|embed|always|ckpt|free_gpu)"
)
# fix windows paths
opts.outdir = opts.outdir.replace("\\", "/")
opts.embedding_path = opts.embedding_path.replace("\\", "/")
new_file = f"{init_file}.new"
try:
lines = [x.strip() for x in open(init_file, "r").readlines()]
with open(new_file, "w") as out_file:
for line in lines:
if len(line) > 0 and not args_to_skip.match(line):
out_file.write(line + "\n")
out_file.write(
f"""
--outdir={opts.outdir}
--embedding_path={opts.embedding_path}
--precision={opts.precision}
--max_loaded_models={int(opts.max_loaded_models)}
--{'no-' if not opts.safety_checker else ''}nsfw_checker
--{'no-' if not opts.xformers else ''}xformers
{'--free_gpu_mem' if opts.free_gpu_mem else ''}
{'--always_use_cpu' if opts.always_use_cpu else ''}
"""
)
except OSError as e:
print(f"** An error occurred while writing the init file: {str(e)}")
os.replace(new_file, init_file) if not config.globals:
config.globals = dict()
globals = config.globals
fields = list(get_type_hints(InvokeAIAppConfig).keys())
for attr in fields:
if hasattr(opts,attr):
setattr(globals,attr,getattr(opts,attr))
with open(init_file,'w', encoding='utf-8') as file:
file.write(OmegaConf.to_yaml(config))
if opts.hf_token: if opts.hf_token:
HfLogin(opts.hf_token) HfLogin(opts.hf_token)
# ------------------------------------- # -------------------------------------
def default_output_dir() -> Path: def default_output_dir() -> Path:
return config.root / "outputs" return config.root / "outputs"
@ -751,7 +720,7 @@ def write_default_options(program_opts: Namespace, initfile: Path):
write_opts(opt, initfile) write_opts(opt, initfile)
# ------------------------------------- # -------------------------------------
# This is ugly. We're going to bring in # Here we bring in
# the legacy Args object in order to parse # the legacy Args object in order to parse
# the old init file and write out the new # the old init file and write out the new
# yaml format. # yaml format.
@ -760,22 +729,21 @@ def migrate_init_file(legacy_format:Path):
old = legacy_parser.parse_args([f'@{str(legacy_format)}']) old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
new = OmegaConf.create() new = OmegaConf.create()
new.web = dict()
for attr in ['host','port']:
if hasattr(old,attr):
setattr(new.web,attr,getattr(old,attr))
# change of name
new.web.allow_origins = old.cors or []
new.globals = dict() new.globals = dict()
globals = new.globals globals = new.globals
for attr in ['host','port']:
if hasattr(old,attr):
setattr(globals,attr,getattr(old,attr))
# change of name
globals.allow_origins = old.cors or []
fields = list(get_type_hints(InvokeAIAppConfig).keys()) fields = list(get_type_hints(InvokeAIAppConfig).keys())
for attr in fields: for attr in fields:
if hasattr(old,attr): if hasattr(old,attr):
setattr(globals,attr,getattr(old,attr)) setattr(globals,attr,getattr(old,attr))
# a few places where the names have changed # a few places where the names have changed
globals.nsfw_checker = old.safety_checker globals.nsfw_checker = old.nsfw_checker
globals.xformers_enabled = old.xformers globals.xformers_enabled = old.xformers
globals.conf_path = old.conf globals.conf_path = old.conf
globals.embedding_dir = old.embedding_path globals.embedding_dir = old.embedding_path
@ -862,14 +830,14 @@ def main():
initialize_rootdir(config.root, opt.yes_to_all) initialize_rootdir(config.root, opt.yes_to_all)
if opt.yes_to_all: if opt.yes_to_all:
write_default_options(opt, init_file) write_default_options(opt, new_init_file)
init_options = Namespace( init_options = Namespace(
precision="float32" if opt.full_precision else "float16" precision="float32" if opt.full_precision else "float16"
) )
else: else:
init_options, models_to_download = run_console_ui(opt, new_init_file) init_options, models_to_download = run_console_ui(opt, new_init_file)
if init_options: if init_options:
write_opts(init_options, init_file) write_opts(init_options, new_init_file)
else: else:
print( print(
'\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n' '\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'