mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
update config script to work with new config system
This commit is contained in:
parent
8d75e50435
commit
eadfd239a8
@ -17,7 +17,7 @@ from .api.dependencies import ApiDependencies
|
|||||||
from .api.routers import images, sessions, models
|
from .api.routers import images, sessions, models
|
||||||
from .api.sockets import SocketIO
|
from .api.sockets import SocketIO
|
||||||
from .invocations.baseinvocation import BaseInvocation
|
from .invocations.baseinvocation import BaseInvocation
|
||||||
from .services.config import InvokeAIWebConfig
|
from .services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
# Create the app
|
# Create the app
|
||||||
# TODO: create this all in a method so configuration/etc. can be passed in?
|
# TODO: create this all in a method so configuration/etc. can be passed in?
|
||||||
@ -38,7 +38,7 @@ socket_io = SocketIO(app)
|
|||||||
# parse command-line settings, environment and the init file
|
# parse command-line settings, environment and the init file
|
||||||
# (this is a module global)
|
# (this is a module global)
|
||||||
global web_config
|
global web_config
|
||||||
web_config = InvokeAIWebConfig()
|
web_config = InvokeAIAppConfig()
|
||||||
|
|
||||||
# Add startup event to load dependencies
|
# Add startup event to load dependencies
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
|
@ -137,7 +137,6 @@ class InvokeAISettings(BaseSettings):
|
|||||||
field.default = os.environ[env_name]
|
field.default = os.environ[env_name]
|
||||||
cls.add_field_argument(parser, name, field)
|
cls.add_field_argument(parser, name, field)
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def cmd_name(self, command_field: str='type')->str:
|
def cmd_name(self, command_field: str='type')->str:
|
||||||
hints = get_type_hints(self)
|
hints = get_type_hints(self)
|
||||||
@ -267,6 +266,12 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
|
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
|
||||||
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
|
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
|
||||||
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
|
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
|
||||||
|
allow_origins : List = Field(default=[], description="Allowed CORS origins", category='Cross-Origin Resource Sharing')
|
||||||
|
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", category='Cross-Origin Resource Sharing')
|
||||||
|
allow_methods : List = Field(default=["*"], description="Methods allowed for CORS", category='Cross-Origin Resource Sharing')
|
||||||
|
allow_headers : List = Field(default=["*"], description="Headers allowed for CORS", category='Cross-Origin Resource Sharing')
|
||||||
|
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
||||||
|
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
|
||||||
#fmt: on
|
#fmt: on
|
||||||
|
|
||||||
def __init__(self, conf: DictConfig = None, argv: List[str]=None, **kwargs):
|
def __init__(self, conf: DictConfig = None, argv: List[str]=None, **kwargs):
|
||||||
@ -403,21 +408,6 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
'''
|
'''
|
||||||
return _find_root()
|
return _find_root()
|
||||||
|
|
||||||
class InvokeAIWebConfig(InvokeAIAppConfig):
|
|
||||||
'''
|
|
||||||
Web-specific settings
|
|
||||||
'''
|
|
||||||
#fmt: off
|
|
||||||
type : Literal["web"] = "web"
|
|
||||||
allow_origins : List = Field(default=[], description="Allowed CORS origins", category='Cross-Origin Resource Sharing')
|
|
||||||
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", category='Cross-Origin Resource Sharing')
|
|
||||||
allow_methods : List = Field(default=["*"], description="Methods allowed for CORS", category='Cross-Origin Resource Sharing')
|
|
||||||
allow_headers : List = Field(default=["*"], description="Headers allowed for CORS", category='Cross-Origin Resource Sharing')
|
|
||||||
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
|
||||||
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
|
|
||||||
#fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig)->InvokeAISettings:
|
def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig)->InvokeAISettings:
|
||||||
'''
|
'''
|
||||||
This returns a singleton InvokeAIAppConfig configuration object.
|
This returns a singleton InvokeAIAppConfig configuration object.
|
||||||
|
@ -36,7 +36,6 @@ from transformers import (
|
|||||||
CLIPTokenizer,
|
CLIPTokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
import invokeai.configs as configs
|
import invokeai.configs as configs
|
||||||
|
|
||||||
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
||||||
@ -54,7 +53,6 @@ from invokeai.backend.config.model_install_backend import (
|
|||||||
)
|
)
|
||||||
from invokeai.app.services.config import (
|
from invokeai.app.services.config import (
|
||||||
get_invokeai_config,
|
get_invokeai_config,
|
||||||
InvokeAIWebConfig,
|
|
||||||
InvokeAIAppConfig,
|
InvokeAIAppConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -62,6 +60,7 @@ warnings.filterwarnings("ignore")
|
|||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
|
||||||
# --------------------------globals-----------------------
|
# --------------------------globals-----------------------
|
||||||
config = get_invokeai_config()
|
config = get_invokeai_config()
|
||||||
|
|
||||||
@ -86,13 +85,6 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
|||||||
# This is the InvokeAI initialization file, which contains command-line default values.
|
# This is the InvokeAI initialization file, which contains command-line default values.
|
||||||
# Feel free to edit. If anything goes wrong, you can re-initialize this file by deleting
|
# Feel free to edit. If anything goes wrong, you can re-initialize this file by deleting
|
||||||
# or renaming it and then running invokeai-configure again.
|
# or renaming it and then running invokeai-configure again.
|
||||||
# Place frequently-used startup commands here, one or more per line.
|
|
||||||
# Examples:
|
|
||||||
# --outdir=D:\data\images
|
|
||||||
# --no-nsfw_checker
|
|
||||||
# --web --host=0.0.0.0
|
|
||||||
# --steps=20
|
|
||||||
# -Ak_euler_a -C10.0
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -105,10 +97,9 @@ If you installed manually from source or with 'pip install': activate the virtua
|
|||||||
then run one of the following commands to start InvokeAI.
|
then run one of the following commands to start InvokeAI.
|
||||||
|
|
||||||
Web UI:
|
Web UI:
|
||||||
invokeai --web # (connect to http://localhost:9090)
|
invokeai-web
|
||||||
invokeai --web --host 0.0.0.0 # (connect to http://your-lan-ip:9090 from another computer on the local network)
|
|
||||||
|
|
||||||
Command-line interface:
|
Command-line client:
|
||||||
invokeai
|
invokeai
|
||||||
|
|
||||||
If you installed using an installation script, run:
|
If you installed using an installation script, run:
|
||||||
@ -340,7 +331,7 @@ class editOptsForm(npyscreen.FormMultiPage):
|
|||||||
def create(self):
|
def create(self):
|
||||||
program_opts = self.parentApp.program_opts
|
program_opts = self.parentApp.program_opts
|
||||||
old_opts = self.parentApp.invokeai_opts
|
old_opts = self.parentApp.invokeai_opts
|
||||||
first_time = not (config.root / 'invokeai.init').exists()
|
first_time = not (config.root / 'invokeai.yaml').exists()
|
||||||
access_token = HfFolder.get_token()
|
access_token = HfFolder.get_token()
|
||||||
window_width, window_height = get_terminal_size()
|
window_width, window_height = get_terminal_size()
|
||||||
for i in [
|
for i in [
|
||||||
@ -374,7 +365,7 @@ class editOptsForm(npyscreen.FormMultiPage):
|
|||||||
self.outdir = self.add_widget_intelligent(
|
self.outdir = self.add_widget_intelligent(
|
||||||
npyscreen.TitleFilename,
|
npyscreen.TitleFilename,
|
||||||
name="(<tab> autocompletes, ctrl-N advances):",
|
name="(<tab> autocompletes, ctrl-N advances):",
|
||||||
value=old_opts.outdir or str(default_output_dir()),
|
value=str(old_opts.outdir) or str(default_output_dir()),
|
||||||
select_dir=True,
|
select_dir=True,
|
||||||
must_exist=False,
|
must_exist=False,
|
||||||
use_two_lines=False,
|
use_two_lines=False,
|
||||||
@ -389,7 +380,7 @@ class editOptsForm(npyscreen.FormMultiPage):
|
|||||||
editable=False,
|
editable=False,
|
||||||
color="CONTROL",
|
color="CONTROL",
|
||||||
)
|
)
|
||||||
self.safety_checker = self.add_widget_intelligent(
|
self.nsfw_checker = self.add_widget_intelligent(
|
||||||
npyscreen.Checkbox,
|
npyscreen.Checkbox,
|
||||||
name="NSFW checker",
|
name="NSFW checker",
|
||||||
value=old_opts.nsfw_checker,
|
value=old_opts.nsfw_checker,
|
||||||
@ -443,7 +434,7 @@ class editOptsForm(npyscreen.FormMultiPage):
|
|||||||
relx=5,
|
relx=5,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
self.xformers = self.add_widget_intelligent(
|
self.xformers_enabled = self.add_widget_intelligent(
|
||||||
npyscreen.Checkbox,
|
npyscreen.Checkbox,
|
||||||
name="Enable xformers support if available",
|
name="Enable xformers support if available",
|
||||||
value=old_opts.xformers_enabled,
|
value=old_opts.xformers_enabled,
|
||||||
@ -578,10 +569,10 @@ class editOptsForm(npyscreen.FormMultiPage):
|
|||||||
|
|
||||||
for attr in [
|
for attr in [
|
||||||
"outdir",
|
"outdir",
|
||||||
"safety_checker",
|
"nsfw_checker",
|
||||||
"free_gpu_mem",
|
"free_gpu_mem",
|
||||||
"max_loaded_models",
|
"max_loaded_models",
|
||||||
"xformers",
|
"xformers_enabled",
|
||||||
"always_use_cpu",
|
"always_use_cpu",
|
||||||
"embedding_path",
|
"embedding_path",
|
||||||
]:
|
]:
|
||||||
@ -628,7 +619,7 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
|
|||||||
|
|
||||||
|
|
||||||
def default_startup_options(init_file: Path) -> Namespace:
|
def default_startup_options(init_file: Path) -> Namespace:
|
||||||
opts = InvokeAIWebConfig(argv=[])
|
opts = InvokeAIAppConfig(argv=[])
|
||||||
outdir = Path(opts.outdir)
|
outdir = Path(opts.outdir)
|
||||||
if not outdir.is_absolute():
|
if not outdir.is_absolute():
|
||||||
opts.outdir = str(config.root / opts.outdir)
|
opts.outdir = str(config.root / opts.outdir)
|
||||||
@ -689,53 +680,31 @@ def run_console_ui(
|
|||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def write_opts(opts: Namespace, init_file: Path):
|
def write_opts(opts: Namespace, init_file: Path):
|
||||||
"""
|
"""
|
||||||
Update the invokeai.init file with values from opts Namespace
|
Update the invokeai.init file with values from current settings.
|
||||||
"""
|
"""
|
||||||
# touch file if it doesn't exist
|
|
||||||
if not init_file.exists():
|
|
||||||
with open(init_file, "w") as f:
|
|
||||||
f.write(INIT_FILE_PREAMBLE)
|
|
||||||
|
|
||||||
# We want to write in the changed arguments without clobbering
|
if Path(init_file).exists():
|
||||||
# any other initialization values the user has entered. There is
|
config = OmegaConf.load(init_file)
|
||||||
# no good way to do this because of the one-way nature of
|
else:
|
||||||
# argparse: i.e. --outdir could be --outdir, --out, or -o
|
config = OmegaConf.create()
|
||||||
# initfile needs to be replaced with a fully structured format
|
|
||||||
# such as yaml; this is a hack that will work much of the time
|
|
||||||
args_to_skip = re.compile(
|
|
||||||
"^--?(o|out|no-xformer|xformer|no-ckpt|ckpt|free|no-nsfw|nsfw|prec|max_load|embed|always|ckpt|free_gpu)"
|
|
||||||
)
|
|
||||||
# fix windows paths
|
|
||||||
opts.outdir = opts.outdir.replace("\\", "/")
|
|
||||||
opts.embedding_path = opts.embedding_path.replace("\\", "/")
|
|
||||||
new_file = f"{init_file}.new"
|
|
||||||
try:
|
|
||||||
lines = [x.strip() for x in open(init_file, "r").readlines()]
|
|
||||||
with open(new_file, "w") as out_file:
|
|
||||||
for line in lines:
|
|
||||||
if len(line) > 0 and not args_to_skip.match(line):
|
|
||||||
out_file.write(line + "\n")
|
|
||||||
out_file.write(
|
|
||||||
f"""
|
|
||||||
--outdir={opts.outdir}
|
|
||||||
--embedding_path={opts.embedding_path}
|
|
||||||
--precision={opts.precision}
|
|
||||||
--max_loaded_models={int(opts.max_loaded_models)}
|
|
||||||
--{'no-' if not opts.safety_checker else ''}nsfw_checker
|
|
||||||
--{'no-' if not opts.xformers else ''}xformers
|
|
||||||
{'--free_gpu_mem' if opts.free_gpu_mem else ''}
|
|
||||||
{'--always_use_cpu' if opts.always_use_cpu else ''}
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
except OSError as e:
|
|
||||||
print(f"** An error occurred while writing the init file: {str(e)}")
|
|
||||||
|
|
||||||
os.replace(new_file, init_file)
|
if not config.globals:
|
||||||
|
config.globals = dict()
|
||||||
|
|
||||||
|
globals = config.globals
|
||||||
|
fields = list(get_type_hints(InvokeAIAppConfig).keys())
|
||||||
|
for attr in fields:
|
||||||
|
if hasattr(opts,attr):
|
||||||
|
setattr(globals,attr,getattr(opts,attr))
|
||||||
|
|
||||||
|
with open(init_file,'w', encoding='utf-8') as file:
|
||||||
|
file.write(OmegaConf.to_yaml(config))
|
||||||
|
|
||||||
if opts.hf_token:
|
if opts.hf_token:
|
||||||
HfLogin(opts.hf_token)
|
HfLogin(opts.hf_token)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def default_output_dir() -> Path:
|
def default_output_dir() -> Path:
|
||||||
return config.root / "outputs"
|
return config.root / "outputs"
|
||||||
@ -751,7 +720,7 @@ def write_default_options(program_opts: Namespace, initfile: Path):
|
|||||||
write_opts(opt, initfile)
|
write_opts(opt, initfile)
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
# This is ugly. We're going to bring in
|
# Here we bring in
|
||||||
# the legacy Args object in order to parse
|
# the legacy Args object in order to parse
|
||||||
# the old init file and write out the new
|
# the old init file and write out the new
|
||||||
# yaml format.
|
# yaml format.
|
||||||
@ -760,22 +729,21 @@ def migrate_init_file(legacy_format:Path):
|
|||||||
old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
|
old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
|
||||||
new = OmegaConf.create()
|
new = OmegaConf.create()
|
||||||
|
|
||||||
new.web = dict()
|
|
||||||
for attr in ['host','port']:
|
|
||||||
if hasattr(old,attr):
|
|
||||||
setattr(new.web,attr,getattr(old,attr))
|
|
||||||
# change of name
|
|
||||||
new.web.allow_origins = old.cors or []
|
|
||||||
|
|
||||||
new.globals = dict()
|
new.globals = dict()
|
||||||
globals = new.globals
|
globals = new.globals
|
||||||
|
for attr in ['host','port']:
|
||||||
|
if hasattr(old,attr):
|
||||||
|
setattr(globals,attr,getattr(old,attr))
|
||||||
|
# change of name
|
||||||
|
globals.allow_origins = old.cors or []
|
||||||
|
|
||||||
fields = list(get_type_hints(InvokeAIAppConfig).keys())
|
fields = list(get_type_hints(InvokeAIAppConfig).keys())
|
||||||
for attr in fields:
|
for attr in fields:
|
||||||
if hasattr(old,attr):
|
if hasattr(old,attr):
|
||||||
setattr(globals,attr,getattr(old,attr))
|
setattr(globals,attr,getattr(old,attr))
|
||||||
|
|
||||||
# a few places where the names have changed
|
# a few places where the names have changed
|
||||||
globals.nsfw_checker = old.safety_checker
|
globals.nsfw_checker = old.nsfw_checker
|
||||||
globals.xformers_enabled = old.xformers
|
globals.xformers_enabled = old.xformers
|
||||||
globals.conf_path = old.conf
|
globals.conf_path = old.conf
|
||||||
globals.embedding_dir = old.embedding_path
|
globals.embedding_dir = old.embedding_path
|
||||||
@ -862,14 +830,14 @@ def main():
|
|||||||
initialize_rootdir(config.root, opt.yes_to_all)
|
initialize_rootdir(config.root, opt.yes_to_all)
|
||||||
|
|
||||||
if opt.yes_to_all:
|
if opt.yes_to_all:
|
||||||
write_default_options(opt, init_file)
|
write_default_options(opt, new_init_file)
|
||||||
init_options = Namespace(
|
init_options = Namespace(
|
||||||
precision="float32" if opt.full_precision else "float16"
|
precision="float32" if opt.full_precision else "float16"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
init_options, models_to_download = run_console_ui(opt, new_init_file)
|
init_options, models_to_download = run_console_ui(opt, new_init_file)
|
||||||
if init_options:
|
if init_options:
|
||||||
write_opts(init_options, init_file)
|
write_opts(init_options, new_init_file)
|
||||||
else:
|
else:
|
||||||
print(
|
print(
|
||||||
'\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'
|
'\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'
|
||||||
|
Loading…
Reference in New Issue
Block a user