<N>ext and <P>revious fields.
-Use cursor arrows to make a checkbox selection, and space to toggle.
-"""
- self.nextrely -= 1
- for i in textwrap.wrap(label, width=window_width - 6):
- self.add_widget_intelligent(
- npyscreen.FixedText,
- value=i,
- editable=False,
- color="CONTROL",
- )
-
- self.nextrely += 1
- label = """HuggingFace access token (OPTIONAL) for automatic model downloads. See https://huggingface.co/settings/tokens."""
- for line in textwrap.wrap(label, width=window_width - 6):
- self.add_widget_intelligent(
- npyscreen.FixedText,
- value=line,
- editable=False,
- color="CONTROL",
- )
-
- self.hf_token = self.add_widget_intelligent(
- npyscreen.TitlePassword,
- name="Access Token (ctrl-shift-V pastes):",
- value=access_token,
- begin_entry_at=42,
- use_two_lines=False,
- scroll_exit=True,
- )
-
- # old settings for defaults
- precision = old_opts.precision or ("float32" if program_opts.full_precision else "auto")
- device = old_opts.device
- attention_type = old_opts.attention_type
- attention_slice_size = old_opts.attention_slice_size
- self.nextrely += 1
- self.add_widget_intelligent(
- npyscreen.TitleFixedText,
- name="Image Generation Options:",
- editable=False,
- color="CONTROL",
- scroll_exit=True,
- )
- self.nextrely -= 2
- self.generation_options = self.add_widget_intelligent(
- MultiSelectColumns,
- columns=3,
- values=GENERATION_OPT_CHOICES,
- value=[GENERATION_OPT_CHOICES.index(x) for x in GENERATION_OPT_CHOICES if getattr(old_opts, x)],
- relx=30,
- max_height=2,
- max_width=80,
- scroll_exit=True,
- )
-
- self.add_widget_intelligent(
- npyscreen.TitleFixedText,
- name="Floating Point Precision:",
- begin_entry_at=0,
- editable=False,
- color="CONTROL",
- scroll_exit=True,
- )
- self.nextrely -= 2
- self.precision = self.add_widget_intelligent(
- SingleSelectColumnsSimple,
- columns=len(PRECISION_CHOICES),
- name="Precision",
- values=PRECISION_CHOICES,
- value=PRECISION_CHOICES.index(precision),
- begin_entry_at=3,
- max_height=2,
- relx=30,
- max_width=80,
- scroll_exit=True,
- )
- self.add_widget_intelligent(
- npyscreen.TitleFixedText,
- name="Generation Device:",
- begin_entry_at=0,
- editable=False,
- color="CONTROL",
- scroll_exit=True,
- )
- self.nextrely -= 2
- self.device = self.add_widget_intelligent(
- SingleSelectColumnsSimple,
- columns=len(DEVICE_CHOICES),
- values=DEVICE_CHOICES,
- value=[DEVICE_CHOICES.index(device)],
- begin_entry_at=3,
- relx=30,
- max_height=2,
- max_width=60,
- scroll_exit=True,
- )
- self.add_widget_intelligent(
- npyscreen.TitleFixedText,
- name="Attention Type:",
- begin_entry_at=0,
- editable=False,
- color="CONTROL",
- scroll_exit=True,
- )
- self.nextrely -= 2
- self.attention_type = self.add_widget_intelligent(
- SingleSelectColumnsSimple,
- columns=len(ATTENTION_CHOICES),
- values=ATTENTION_CHOICES,
- value=[ATTENTION_CHOICES.index(attention_type)],
- begin_entry_at=3,
- max_height=2,
- relx=30,
- max_width=80,
- scroll_exit=True,
- )
- self.attention_type.on_changed = self.show_hide_slice_sizes
- self.attention_slice_label = self.add_widget_intelligent(
- npyscreen.TitleFixedText,
- name="Attention Slice Size:",
- relx=5,
- editable=False,
- hidden=attention_type != "sliced",
- color="CONTROL",
- scroll_exit=True,
- )
- self.nextrely -= 2
- self.attention_slice_size = self.add_widget_intelligent(
- SingleSelectColumnsSimple,
- columns=len(ATTENTION_SLICE_CHOICES),
- values=ATTENTION_SLICE_CHOICES,
- value=[ATTENTION_SLICE_CHOICES.index(attention_slice_size)],
- relx=30,
- hidden=attention_type != "sliced",
- max_height=2,
- max_width=110,
- scroll_exit=True,
- )
- self.add_widget_intelligent(
- npyscreen.TitleFixedText,
- name="Model disk conversion cache size (GB). This is used to cache safetensors files that need to be converted to diffusers..",
- begin_entry_at=0,
- editable=False,
- color="CONTROL",
- scroll_exit=True,
- )
- self.nextrely -= 1
- self.disk = self.add_widget_intelligent(
- npyscreen.Slider,
- value=clip(old_opts.convert_cache, range=(0, 100), step=0.5),
- out_of=100,
- lowest=0.0,
- step=0.5,
- relx=8,
- scroll_exit=True,
- )
- self.nextrely += 1
- self.add_widget_intelligent(
- npyscreen.TitleFixedText,
- name="Model RAM cache size (GB). Make this at least large enough to hold a single full model (2GB for SD-1, 6GB for SDXL).",
- begin_entry_at=0,
- editable=False,
- color="CONTROL",
- scroll_exit=True,
- )
- self.nextrely -= 1
- self.ram = self.add_widget_intelligent(
- npyscreen.Slider,
- value=clip(old_opts.ram, range=(3.0, MAX_RAM), step=0.5),
- out_of=round(MAX_RAM),
- lowest=0.0,
- step=0.5,
- relx=8,
- scroll_exit=True,
- )
- if HAS_CUDA:
- self.nextrely += 1
- self.add_widget_intelligent(
- npyscreen.TitleFixedText,
- name="Model VRAM cache size (GB). Reserving a small amount of VRAM will modestly speed up the start of image generation.",
- begin_entry_at=0,
- editable=False,
- color="CONTROL",
- scroll_exit=True,
- )
- self.nextrely -= 1
- self.vram = self.add_widget_intelligent(
- npyscreen.Slider,
- value=clip(old_opts.vram, range=(0, MAX_VRAM), step=0.25),
- out_of=round(MAX_VRAM * 2) / 2,
- lowest=0.0,
- relx=8,
- step=0.25,
- scroll_exit=True,
- )
- else:
- self.vram = DummyWidgetValue.zero
-
- self.nextrely += 1
- self.add_widget_intelligent(
- npyscreen.FixedText,
- value="Location of the database used to store model path and configuration information:",
- editable=False,
- color="CONTROL",
- )
- self.nextrely += 1
- self.outdir = self.add_widget_intelligent(
- FileBox,
- name="Output directory for images ( autocompletes, ctrl-N advances):",
- value=str(default_output_dir()),
- select_dir=True,
- must_exist=False,
- use_two_lines=False,
- labelColor="GOOD",
- begin_entry_at=40,
- max_height=3,
- max_width=127,
- scroll_exit=True,
- )
- self.autoimport_dirs = {}
- self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent(
- FileBox,
- name="Optional folder to scan for new checkpoints, ControlNets, LoRAs and TI models",
- value=str(config.autoimport_path),
- select_dir=True,
- must_exist=False,
- use_two_lines=False,
- labelColor="GOOD",
- begin_entry_at=32,
- max_height=3,
- max_width=127,
- scroll_exit=True,
- )
- self.nextrely += 1
- label = """BY DOWNLOADING THE STABLE DIFFUSION WEIGHT FILES, YOU AGREE TO HAVE READ
-AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSES LOCATED AT
-https://huggingface.co/spaces/CompVis/stable-diffusion-license and
-https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENSE.md
-"""
- for i in textwrap.wrap(label, width=window_width - 6):
- self.add_widget_intelligent(
- npyscreen.FixedText,
- value=i,
- editable=False,
- color="CONTROL",
- )
- self.license_acceptance = self.add_widget_intelligent(
- npyscreen.Checkbox,
- name="I accept the CreativeML Responsible AI Licenses",
- value=not first_time,
- relx=2,
- scroll_exit=True,
- )
- self.nextrely += 1
- label = "DONE" if program_opts.skip_sd_weights or program_opts.default_only else "NEXT"
- self.ok_button = self.add_widget_intelligent(
- CenteredButtonPress,
- name=label,
- relx=(window_width - len(label)) // 2,
- when_pressed_function=self.on_ok,
- )
-
- def show_hide_slice_sizes(self, value):
- show = ATTENTION_CHOICES[value[0]] == "sliced"
- self.attention_slice_label.hidden = not show
- self.attention_slice_size.hidden = not show
-
- def show_hide_model_conf_override(self, value):
- self.model_conf_override.hidden = value
- self.model_conf_override.display()
-
- def on_ok(self):
- options = self.marshall_arguments()
- if self.validate_field_values(options):
- self.parentApp.new_opts = options
- if hasattr(self.parentApp, "model_select"):
- self.parentApp.setNextForm("MODELS")
- else:
- self.parentApp.setNextForm(None)
- self.editing = False
- else:
- self.editing = True
-
- def validate_field_values(self, opt: Namespace) -> bool:
- bad_fields = []
- if not opt.license_acceptance:
- bad_fields.append("Please accept the license terms before proceeding to model downloads")
- if not Path(opt.outdir).parent.exists():
- bad_fields.append(
- f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory."
- )
- if len(bad_fields) > 0:
- message = "The following problems were detected and must be corrected:\n"
- for problem in bad_fields:
- message += f"* {problem}\n"
- npyscreen.notify_confirm(message)
- return False
- else:
- return True
-
- def marshall_arguments(self) -> Namespace:
- new_opts = Namespace()
-
- for attr in [
- "ram",
- "vram",
- "convert_cache",
- "outdir",
- ]:
- if hasattr(self, attr):
- setattr(new_opts, attr, getattr(self, attr).value)
-
- for attr in self.autoimport_dirs:
- if not self.autoimport_dirs[attr].value:
- continue
- directory = Path(self.autoimport_dirs[attr].value)
- if directory.is_relative_to(config.root_path):
- directory = directory.relative_to(config.root_path)
- setattr(new_opts, attr, directory)
-
- new_opts.hf_token = self.hf_token.value
- new_opts.license_acceptance = self.license_acceptance.value
- new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
- new_opts.device = DEVICE_CHOICES[self.device.value[0]]
- new_opts.attention_type = ATTENTION_CHOICES[self.attention_type.value[0]]
- new_opts.attention_slice_size = ATTENTION_SLICE_CHOICES[self.attention_slice_size.value[0]]
- generation_options = [GENERATION_OPT_CHOICES[x] for x in self.generation_options.value]
- for v in GENERATION_OPT_CHOICES:
- setattr(new_opts, v, v in generation_options)
- return new_opts
-
-
-class EditOptApplication(npyscreen.NPSAppManaged):
- def __init__(self, program_opts: Namespace, invokeai_opts: InvokeAIAppConfig, install_helper: InstallHelper):
- super().__init__()
- self.program_opts = program_opts
- self.invokeai_opts = invokeai_opts
- self.user_cancelled = False
- self.autoload_pending = True
- self.install_helper = install_helper
- self.install_selections = default_user_selections(program_opts, install_helper)
-
- def onStart(self):
- npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
- self.options = self.addForm(
- "MAIN",
- editOptsForm,
- name="InvokeAI Startup Options",
- cycle_widgets=False,
- )
- if not (self.program_opts.skip_sd_weights or self.program_opts.default_only):
- self.model_select = self.addForm(
- "MODELS",
- addModelsForm,
- name="Install Stable Diffusion Models",
- multipage=True,
- cycle_widgets=False,
- )
-
-
-def get_default_ram_cache_size() -> float:
- """Run a heuristic for the default RAM cache based on installed RAM."""
-
- # Note that on my 64 GB machine, psutil.virtual_memory().total gives 62 GB,
-    # so we adjust everything down a bit.
- return (
- 15.0 if MAX_RAM >= 60 else 7.5 if MAX_RAM >= 30 else 4 if MAX_RAM >= 14 else 2.1
- ) # 2.1 is just large enough for sd 1.5 ;-)
-
-
-def get_default_config() -> InvokeAIAppConfig:
- """Builds a new config object, setting the ram and precision using the appropriate heuristic."""
- config = InvokeAIAppConfig()
- config.ram = get_default_ram_cache_size()
- config.precision = "float32" if FORCE_FULL_PRECISION else choose_precision(torch.device(choose_torch_device()))
- return config
-
-
-def default_user_selections(program_opts: Namespace, install_helper: InstallHelper) -> InstallSelections:
- default_model = install_helper.default_model()
- assert default_model is not None
- default_models = [default_model] if program_opts.default_only else install_helper.recommended_models()
- return InstallSelections(
- install_models=default_models if program_opts.yes_to_all else [],
- )
-
-
-# -------------------------------------
-def clip(value: float, range: tuple[float, float], step: float) -> float:
- minimum, maximum = range
- if value < minimum:
- value = minimum
- if value > maximum:
- value = maximum
- return round(value / step) * step
-
-
-# -------------------------------------
-def initialize_rootdir(root: Path, yes_to_all: bool = False):
- logger.info("Initializing InvokeAI runtime directory")
- for name in ("models", "databases", "text-inversion-output", "text-inversion-training-data", "configs"):
- os.makedirs(os.path.join(root, name), exist_ok=True)
- for model_type in ModelType:
- Path(root, "autoimport", model_type.value).mkdir(parents=True, exist_ok=True)
-
- configs_src = Path(model_configs.__path__[0])
- configs_dest = root / "configs"
- if not os.path.samefile(configs_src, configs_dest):
- shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
-
- dest = root / "models"
- dest.mkdir(parents=True, exist_ok=True)
-
-
-# -------------------------------------
-def run_console_ui(
- program_opts: Namespace, install_helper: InstallHelper
-) -> Tuple[Optional[Namespace], Optional[InstallSelections]]:
- first_time = not config.init_file_path.exists()
- config_opts = get_default_config() if first_time else config
- if program_opts.root:
- config_opts.set_root(Path(program_opts.root))
-
- if not set_min_terminal_size(MIN_COLS, MIN_LINES):
- raise WindowTooSmallException(
- "Could not increase terminal size. Try running again with a larger window or smaller font size."
- )
-
- editApp = EditOptApplication(program_opts, config_opts, install_helper)
- editApp.run()
- if editApp.user_cancelled:
- return (None, None)
- else:
- return (editApp.new_opts, editApp.install_selections)
-
-
-# -------------------------------------
-def default_output_dir() -> Path:
- return config.root_path / "outputs"
-
-
-def is_v2_install(root: Path) -> bool:
-    # Check whether the runtime directory contains a legacy v2.3 install.
- old_init_file = root / "invokeai.init"
- new_init_file = root / "invokeai.yaml"
- old_hub = root / "models/hub"
- is_v2 = (old_init_file.exists() and not new_init_file.exists()) and old_hub.exists()
- return is_v2
-
-
-# -------------------------------------
-def main(opt: Namespace) -> None:
- global FORCE_FULL_PRECISION # FIXME
- global config
-
- updates: dict[str, Any] = {}
-
- config = get_config()
- if opt.full_precision:
- updates["precision"] = "float32"
-
- try:
- # Attempt to read the config file into the config object
- config.merge_from_file()
- except FileNotFoundError:
- # No config file, first time running the app
- pass
-
- config.update_config(updates)
- logger = InvokeAILogger().get_logger(config=config)
-
- errors: set[str] = set()
- FORCE_FULL_PRECISION = opt.full_precision # FIXME global
-
- # Before we write anything else, make a backup of the existing init file
- new_init_file = config.init_file_path
- backup_init_file = new_init_file.with_suffix(".bak")
- if new_init_file.exists():
- copy(new_init_file, backup_init_file)
-
- try:
- # v2.3 -> v4.0.0 upgrade is no longer supported
- if is_v2_install(config.root_path):
- logger.error("Migration from v2.3 to v4.0.0 is no longer supported. Please install a fresh copy.")
- sys.exit(0)
-
- # run this unconditionally in case new directories need to be added
- initialize_rootdir(config.root_path, opt.yes_to_all)
-
- # this will initialize and populate the models tables if not present
- install_helper = InstallHelper(config, logger)
-
- models_to_download = default_user_selections(opt, install_helper)
-
- if opt.yes_to_all:
- # We will not show the UI - just write the default config to the file and move on to installing models.
- get_default_config().write_file(new_init_file)
- else:
- # Run the UI to get the user's options & model choices
- user_opts, models_to_download = run_console_ui(opt, install_helper)
- if user_opts:
- # Create a dict of the user's opts, omitting any fields that are not config settings (like `hf_token`)
- user_opts_dict = {k: v for k, v in vars(user_opts).items() if k in config.model_fields}
- # Merge the user's opts back into the config object & write it
- config.update_config(user_opts_dict)
- config.write_file(config.init_file_path)
-
- if hasattr(user_opts, "hf_token") and user_opts.hf_token:
- HfLogin(user_opts.hf_token)
- else:
- logger.info('\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n')
- sys.exit(0)
-
- if opt.skip_support_models:
- logger.info("Skipping support models at user's request")
- else:
- logger.info("Installing support models")
- download_support_models()
-
- if opt.skip_sd_weights:
- logger.warning("Skipping diffusion weights download per user request")
- elif models_to_download:
- install_helper.add_or_delete(models_to_download)
-
- postscript(errors=errors)
-
- if not opt.yes_to_all:
- input("Press any key to continue...")
- except WindowTooSmallException as e:
- logger.error(str(e))
- if backup_init_file.exists():
- move(backup_init_file, new_init_file)
- except KeyboardInterrupt:
- print("\nGoodbye! Come back soon.")
- if backup_init_file.exists():
- move(backup_init_file, new_init_file)
- except Exception:
- print("An error occurred during installation.")
- if backup_init_file.exists():
- move(backup_init_file, new_init_file)
- print(traceback.format_exc(), file=sys.stderr)
-
-
-# -------------------------------------
-if __name__ == "__main__":
- main()
diff --git a/invokeai/backend/install/legacy_arg_parsing.py b/invokeai/backend/install/legacy_arg_parsing.py
deleted file mode 100644
index b3cb00c94d..0000000000
--- a/invokeai/backend/install/legacy_arg_parsing.py
+++ /dev/null
@@ -1,379 +0,0 @@
-# Copyright 2023 Lincoln D. Stein and the InvokeAI Team
-
-import argparse
-import shlex
-from argparse import ArgumentParser
-
-# note that this includes both old sampler names and new scheduler names
-# in order to be able to parse both 2.0 and 3.0-pre-nodes versions of invokeai.init
-SAMPLER_CHOICES = [
- "ddim",
- "ddpm",
- "deis",
- "lms",
- "lms_k",
- "pndm",
- "heun",
- "heun_k",
- "euler",
- "euler_k",
- "euler_a",
- "kdpm_2",
- "kdpm_2_a",
- "dpmpp_2s",
- "dpmpp_2s_k",
- "dpmpp_2m",
- "dpmpp_2m_k",
- "dpmpp_2m_sde",
- "dpmpp_2m_sde_k",
- "dpmpp_sde",
- "dpmpp_sde_k",
- "unipc",
- "k_dpm_2_a",
- "k_dpm_2",
- "k_dpmpp_2_a",
- "k_dpmpp_2",
- "k_euler_a",
- "k_euler",
- "k_heun",
- "k_lms",
- "plms",
- "lcm",
-]
-
-PRECISION_CHOICES = [
- "auto",
- "float32",
- "autocast",
- "float16",
-]
-
-
-class FileArgumentParser(ArgumentParser):
- """
- Supports reading defaults from an init file.
- """
-
- def convert_arg_line_to_args(self, arg_line):
- return shlex.split(arg_line, comments=True)
-
-
-legacy_parser = FileArgumentParser(
- description="""
-Generate images using Stable Diffusion.
- Use --web to launch the web interface.
- Use --from_file to load prompts from a file path or standard input ("-").
- Otherwise you will be dropped into an interactive command prompt (type -h for help.)
- Other command-line arguments are defaults that can usually be overridden
- at the command prompt.
- """,
- fromfile_prefix_chars="@",
-)
-general_group = legacy_parser.add_argument_group("General")
-model_group = legacy_parser.add_argument_group("Model selection")
-file_group = legacy_parser.add_argument_group("Input/output")
-web_server_group = legacy_parser.add_argument_group("Web server")
-render_group = legacy_parser.add_argument_group("Rendering")
-postprocessing_group = legacy_parser.add_argument_group("Postprocessing")
-deprecated_group = legacy_parser.add_argument_group("Deprecated options")
-
-deprecated_group.add_argument("--laion400m")
-deprecated_group.add_argument("--weights") # deprecated
-general_group.add_argument("--version", "-V", action="store_true", help="Print InvokeAI version number")
-model_group.add_argument(
- "--root_dir",
- default=None,
- help='Path to directory containing "models", "outputs" and "configs". If not present will read from environment variable INVOKEAI_ROOT. Defaults to ~/invokeai.',
-)
-model_group.add_argument(
- "--config",
- "-c",
- "-config",
- dest="conf",
- default="./configs/models.yaml",
- help="Path to configuration file for alternate models.",
-)
-model_group.add_argument(
- "--model",
- help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
-)
-model_group.add_argument(
- "--weight_dirs",
- nargs="+",
- type=str,
- help="List of one or more directories that will be auto-scanned for new model weights to import",
-)
-model_group.add_argument(
- "--png_compression",
- "-z",
- type=int,
- default=6,
- choices=range(0, 9),
- dest="png_compression",
- help="level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.",
-)
-model_group.add_argument(
- "-F",
- "--full_precision",
- dest="full_precision",
- action="store_true",
- help="Deprecated way to set --precision=float32",
-)
-model_group.add_argument(
- "--max_loaded_models",
- dest="max_loaded_models",
- type=int,
- default=2,
- help="Maximum number of models to keep in memory for fast switching, including the one in GPU",
-)
-model_group.add_argument(
- "--free_gpu_mem",
- dest="free_gpu_mem",
- action="store_true",
- help="Force free gpu memory before final decoding",
-)
-model_group.add_argument(
- "--sequential_guidance",
- dest="sequential_guidance",
- action="store_true",
- help="Calculate guidance in serial instead of in parallel, lowering memory requirement " "at the expense of speed",
-)
-model_group.add_argument(
- "--xformers",
- action=argparse.BooleanOptionalAction,
- default=True,
- help="Enable/disable xformers support (default enabled if installed)",
-)
-model_group.add_argument(
- "--always_use_cpu", dest="always_use_cpu", action="store_true", help="Force use of CPU even if GPU is available"
-)
-model_group.add_argument(
- "--precision",
- dest="precision",
- type=str,
- choices=PRECISION_CHOICES,
- metavar="PRECISION",
- help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
- default="auto",
-)
-model_group.add_argument(
- "--ckpt_convert",
- action=argparse.BooleanOptionalAction,
- dest="ckpt_convert",
- default=True,
- help="Deprecated option. Legacy ckpt files are now always converted to diffusers when loaded.",
-)
-model_group.add_argument(
- "--internet",
- action=argparse.BooleanOptionalAction,
- dest="internet_available",
- default=True,
- help="Indicate whether internet is available for just-in-time model downloading (default: probe automatically).",
-)
-model_group.add_argument(
- "--nsfw_checker",
- "--safety_checker",
- action=argparse.BooleanOptionalAction,
- dest="safety_checker",
- default=False,
- help="Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.",
-)
-model_group.add_argument(
- "--autoimport",
- default=None,
- type=str,
- help="Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly",
-)
-model_group.add_argument(
- "--autoconvert",
- default=None,
- type=str,
- help="Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models",
-)
-model_group.add_argument(
- "--patchmatch",
- action=argparse.BooleanOptionalAction,
- default=True,
- help="Load the patchmatch extension for outpainting. Use --no-patchmatch to disable.",
-)
-file_group.add_argument(
- "--from_file",
- dest="infile",
- type=str,
- help="If specified, load prompts from this file",
-)
-file_group.add_argument(
- "--outdir",
- "-o",
- type=str,
- help="Directory to save generated images and a log of prompts and seeds. Default: ROOTDIR/outputs",
- default="outputs",
-)
-file_group.add_argument(
- "--prompt_as_dir",
- "-p",
- action="store_true",
- help="Place images in subdirectories named after the prompt.",
-)
-render_group.add_argument(
- "--fnformat",
- default="{prefix}.{seed}.png",
- type=str,
- help="Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png",
-)
-render_group.add_argument("-s", "--steps", type=int, default=50, help="Number of steps")
-render_group.add_argument(
- "-W",
- "--width",
- type=int,
- help="Image width, multiple of 64",
-)
-render_group.add_argument(
- "-H",
- "--height",
- type=int,
- help="Image height, multiple of 64",
-)
-render_group.add_argument(
- "-C",
- "--cfg_scale",
- default=7.5,
- type=float,
- help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
-)
-render_group.add_argument(
- "--sampler",
- "-A",
- "-m",
- dest="sampler_name",
- type=str,
- choices=SAMPLER_CHOICES,
- metavar="SAMPLER_NAME",
- help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
- default="k_lms",
-)
-render_group.add_argument(
- "--log_tokenization", "-t", action="store_true", help="shows how the prompt is split into tokens"
-)
-render_group.add_argument(
- "-f",
- "--strength",
- type=float,
- help="img2img strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely",
-)
-render_group.add_argument(
- "-T",
- "-fit",
- "--fit",
- action=argparse.BooleanOptionalAction,
- help="If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)",
-)
-
-render_group.add_argument("--grid", "-g", action=argparse.BooleanOptionalAction, help="generate a grid")
-render_group.add_argument(
- "--embedding_directory",
- "--embedding_path",
- dest="embedding_path",
- default="embeddings",
- type=str,
- help="Path to a directory containing .bin and/or .pt files, or a single .bin/.pt file. You may use subdirectories. (default is ROOTDIR/embeddings)",
-)
-render_group.add_argument(
- "--lora_directory",
- dest="lora_path",
- default="loras",
- type=str,
- help="Path to a directory containing LoRA files; subdirectories are not supported. (default is ROOTDIR/loras)",
-)
-render_group.add_argument(
- "--embeddings",
- action=argparse.BooleanOptionalAction,
- default=True,
- help="Enable embedding directory (default). Use --no-embeddings to disable.",
-)
-render_group.add_argument("--enable_image_debugging", action="store_true", help="Generates debugging image to display")
-render_group.add_argument(
- "--karras_max",
- type=int,
- default=None,
- help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29].",
-)
-# Restoration related args
-postprocessing_group.add_argument(
- "--no_restore",
- dest="restore",
- action="store_false",
- help="Disable face restoration with GFPGAN or codeformer",
-)
-postprocessing_group.add_argument(
- "--no_upscale",
- dest="esrgan",
- action="store_false",
- help="Disable upscaling with ESRGAN",
-)
-postprocessing_group.add_argument(
- "--esrgan_bg_tile",
- type=int,
- default=400,
- help="Tile size for background sampler, 0 for no tile during testing. Default: 400.",
-)
-postprocessing_group.add_argument(
- "--esrgan_denoise_str",
- type=float,
- default=0.75,
- help="esrgan denoise str. 0 is no denoise, 1 is max denoise. Default: 0.75",
-)
-postprocessing_group.add_argument(
- "--gfpgan_model_path",
- type=str,
- default="./models/gfpgan/GFPGANv1.4.pth",
- help="Indicates the path to the GFPGAN model",
-)
-web_server_group.add_argument(
- "--web",
- dest="web",
- action="store_true",
- help="Start in web server mode.",
-)
-web_server_group.add_argument(
- "--web_develop",
- dest="web_develop",
- action="store_true",
- help="Start in web server development mode.",
-)
-web_server_group.add_argument(
- "--web_verbose",
- action="store_true",
- help="Enables verbose logging",
-)
-web_server_group.add_argument(
- "--cors",
- nargs="*",
- type=str,
- help="Additional allowed origins, comma-separated",
-)
-web_server_group.add_argument(
- "--host",
- type=str,
- default="127.0.0.1",
- help="Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.",
-)
-web_server_group.add_argument("--port", type=int, default="9090", help="Web server: Port to listen on")
-web_server_group.add_argument(
- "--certfile",
- type=str,
- default=None,
- help="Web server: Path to certificate file to use for SSL. Use together with --keyfile",
-)
-web_server_group.add_argument(
- "--keyfile",
- type=str,
- default=None,
- help="Web server: Path to private key file to use for SSL. Use together with --certfile",
-)
-web_server_group.add_argument(
- "--gui",
- dest="gui",
- action="store_true",
- help="Start InvokeAI GUI",
-)
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 7ea39bb5c3..dae55a0751 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -25,8 +25,8 @@ from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
-
-from ..util import auto_detect_slice_size, normalize_device
+from invokeai.backend.util.attention import auto_detect_slice_size
+from invokeai.backend.util.devices import normalize_device
@dataclass
diff --git a/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py b/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py
index 4278f08bff..2a0fcccd89 100644
--- a/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py
+++ b/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py
@@ -11,7 +11,7 @@ from compel.cross_attention_control import Arguments
from diffusers.models.attention_processor import Attention, SlicedAttnProcessor
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
-from ...util import torch_dtype
+from invokeai.backend.util.devices import torch_dtype
class CrossAttentionType(enum.Enum):
diff --git a/invokeai/backend/training/__init__.py b/invokeai/backend/training/__init__.py
deleted file mode 100644
index 6b5aa7327d..0000000000
--- a/invokeai/backend/training/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-"""
-Initialization file for invokeai.backend.training
-"""
-
-from .textual_inversion_training import do_textual_inversion_training, parse_args # noqa: F401
diff --git a/invokeai/backend/training/textual_inversion_training.py b/invokeai/backend/training/textual_inversion_training.py
deleted file mode 100644
index 6e214739a0..0000000000
--- a/invokeai/backend/training/textual_inversion_training.py
+++ /dev/null
@@ -1,924 +0,0 @@
-# This code was copied from
-# https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py
-# on January 2, 2023
-# and modified slightly by Lincoln Stein (@lstein) to work with InvokeAI
-
-"""
-This is the backend to "textual_inversion.py"
-"""
-
-import logging
-import math
-import os
-import random
-from argparse import Namespace
-from pathlib import Path
-from typing import Optional
-
-import datasets
-import diffusers
-import numpy as np
-import PIL
-import torch
-import torch.nn.functional as F
-import torch.utils.checkpoint
-import transformers
-from accelerate import Accelerator
-from accelerate.logging import get_logger
-from accelerate.utils import ProjectConfiguration, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
-from diffusers.optimization import get_scheduler
-from diffusers.utils import check_min_version
-from diffusers.utils.import_utils import is_xformers_available
-from huggingface_hub import HfFolder, Repository, whoami
-from packaging import version
-from PIL import Image
-from torch.utils.data import Dataset
-from torchvision import transforms
-from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer
-
-# invokeai stuff
-from invokeai.app.services.config import InvokeAIAppConfig, PagingArgumentParser
-from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.install.install_helper import initialize_record_store
-from invokeai.backend.model_manager import BaseModelType, ModelType
-
-if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
- PIL_INTERPOLATION = {
- "linear": PIL.Image.Resampling.BILINEAR,
- "bilinear": PIL.Image.Resampling.BILINEAR,
- "bicubic": PIL.Image.Resampling.BICUBIC,
- "lanczos": PIL.Image.Resampling.LANCZOS,
- "nearest": PIL.Image.Resampling.NEAREST,
- }
-else:
- PIL_INTERPOLATION = {
- "linear": PIL.Image.LINEAR,
- "bilinear": PIL.Image.BILINEAR,
- "bicubic": PIL.Image.BICUBIC,
- "lanczos": PIL.Image.LANCZOS,
- "nearest": PIL.Image.NEAREST,
- }
-# ------------------------------------------------------------------------------
-
-
-# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.10.0.dev0")
-
-
-logger = get_logger(__name__)
-
-
-def save_progress(text_encoder, placeholder_token_id, accelerator, placeholder_token, save_path):
- logger.info("Saving embeddings")
- learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
- learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
- torch.save(learned_embeds_dict, save_path)
-
-
-def parse_args() -> Namespace:
- config = get_config()
- parser = PagingArgumentParser(description="Textual inversion training")
- general_group = parser.add_argument_group("General")
- model_group = parser.add_argument_group("Models and Paths")
- image_group = parser.add_argument_group("Training Image Location and Options")
- trigger_group = parser.add_argument_group("Trigger Token")
- training_group = parser.add_argument_group("Training Parameters")
- checkpointing_group = parser.add_argument_group("Checkpointing and Resume")
- integration_group = parser.add_argument_group("Integration")
- general_group.add_argument(
- "--front_end",
- "--gui",
- dest="front_end",
- action="store_true",
- default=False,
- help="Activate the text-based graphical front end for collecting parameters. Aside from --root_dir, other parameters will be ignored.",
- )
- general_group.add_argument(
- "--root_dir",
- "--root",
- type=Path,
- default=config.root_path,
- help="Path to the invokeai runtime directory",
- )
- general_group.add_argument(
- "--logging_dir",
- type=Path,
- default="logs",
- help=(
- "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
- " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
- ),
- )
- general_group.add_argument(
- "--output_dir",
- type=Path,
- default=f"{config.root_path}/text-inversion-model",
- help="The output directory where the model predictions and checkpoints will be written.",
- )
- model_group.add_argument(
- "--model",
- type=str,
- default="sd-1/main/stable-diffusion-v1-5",
- help="Name of the diffusers model to train against.",
- )
- model_group.add_argument(
- "--revision",
- type=str,
- default=None,
- required=False,
- help="Revision of pretrained model identifier from huggingface.co/models.",
- )
-
- model_group.add_argument(
- "--tokenizer_name",
- type=str,
- default=None,
- help="Pretrained tokenizer name or path if not the same as model_name",
- )
- image_group.add_argument(
- "--train_data_dir",
- type=Path,
- default=None,
- help="A folder containing the training data.",
- )
- image_group.add_argument(
- "--resolution",
- type=int,
- default=512,
- help=(
- "The resolution for input images, all the images in the train/validation dataset will be resized to this"
- " resolution"
- ),
- )
- image_group.add_argument(
- "--center_crop",
- action="store_true",
- help="Whether to center crop images before resizing to resolution",
- )
- trigger_group.add_argument(
- "--placeholder_token",
- "--trigger_term",
- dest="placeholder_token",
- type=str,
- default=None,
- help='A token to use as a placeholder for the concept. This token will trigger the concept when included in the prompt as "".',
- )
- trigger_group.add_argument(
- "--learnable_property",
- type=str,
- choices=["object", "style"],
- default="object",
- help="Choose between 'object' and 'style'",
- )
- trigger_group.add_argument(
- "--initializer_token",
- type=str,
- default="*",
- help="A symbol to use as the initializer word.",
- )
- checkpointing_group.add_argument(
- "--checkpointing_steps",
- type=int,
- default=500,
- help=(
- "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
- " training using `--resume_from_checkpoint`."
- ),
- )
- checkpointing_group.add_argument(
- "--resume_from_checkpoint",
- type=Path,
- default=None,
- help=(
- "Whether training should be resumed from a previous checkpoint. Use a path saved by"
- ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
- ),
- )
- checkpointing_group.add_argument(
- "--save_steps",
- type=int,
- default=500,
- help="Save learned_embeds.bin every X updates steps.",
- )
- training_group.add_argument(
- "--repeats",
- type=int,
- default=100,
- help="How many times to repeat the training data.",
- )
- training_group.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
- training_group.add_argument(
- "--train_batch_size",
- type=int,
- default=16,
- help="Batch size (per device) for the training dataloader.",
- )
- training_group.add_argument("--num_train_epochs", type=int, default=100)
- training_group.add_argument(
- "--max_train_steps",
- type=int,
- default=5000,
- help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
- )
- training_group.add_argument(
- "--gradient_accumulation_steps",
- type=int,
- default=1,
- help="Number of updates steps to accumulate before performing a backward/update pass.",
- )
- training_group.add_argument(
- "--gradient_checkpointing",
- action="store_true",
- help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
- )
- training_group.add_argument(
- "--learning_rate",
- type=float,
- default=1e-4,
- help="Initial learning rate (after the potential warmup period) to use.",
- )
- training_group.add_argument(
- "--scale_lr",
- action="store_true",
- default=True,
- help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
- )
- training_group.add_argument(
- "--lr_scheduler",
- type=str,
- default="constant",
- help=(
- 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
- ' "constant", "constant_with_warmup"]'
- ),
- )
- training_group.add_argument(
- "--lr_warmup_steps",
- type=int,
- default=500,
- help="Number of steps for the warmup in the lr scheduler.",
- )
- training_group.add_argument(
- "--adam_beta1",
- type=float,
- default=0.9,
- help="The beta1 parameter for the Adam optimizer.",
- )
- training_group.add_argument(
- "--adam_beta2",
- type=float,
- default=0.999,
- help="The beta2 parameter for the Adam optimizer.",
- )
- training_group.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
- training_group.add_argument(
- "--adam_epsilon",
- type=float,
- default=1e-08,
- help="Epsilon value for the Adam optimizer",
- )
- training_group.add_argument(
- "--mixed_precision",
- type=str,
- default="no",
- choices=["no", "fp16", "bf16"],
- help=(
- "Whether to use mixed precision. Choose"
- "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
- "and an Nvidia Ampere GPU."
- ),
- )
- training_group.add_argument(
- "--allow_tf32",
- action="store_true",
- help=(
- "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
- " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
- ),
- )
- training_group.add_argument(
- "--local_rank",
- type=int,
- default=-1,
- help="For distributed training: local_rank",
- )
- parser.add_argument(
- "--enable_xformers_memory_efficient_attention",
- action="store_true",
- help="Whether or not to use xformers.",
- )
-
- integration_group.add_argument(
- "--only_save_embeds",
- action="store_true",
- default=False,
- help="Save only the embeddings for the new concept.",
- )
- integration_group.add_argument(
- "--hub_model_id",
- type=str,
- default=None,
- help="The name of the repository to keep in sync with the local `output_dir`.",
- )
- integration_group.add_argument(
- "--report_to",
- type=str,
- default="tensorboard",
- help=(
- 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
- ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
- ),
- )
- integration_group.add_argument(
- "--push_to_hub",
- action="store_true",
- help="Whether or not to push the model to the Hub.",
- )
- integration_group.add_argument(
- "--hub_token",
- type=str,
- default=None,
- help="The token to use to push to the Model Hub.",
- )
- args = parser.parse_args()
- return args
-
-
-imagenet_templates_small = [
- "a photo of a {}",
- "a rendering of a {}",
- "a cropped photo of the {}",
- "the photo of a {}",
- "a photo of a clean {}",
- "a photo of a dirty {}",
- "a dark photo of the {}",
- "a photo of my {}",
- "a photo of the cool {}",
- "a close-up photo of a {}",
- "a bright photo of the {}",
- "a cropped photo of a {}",
- "a photo of the {}",
- "a good photo of the {}",
- "a photo of one {}",
- "a close-up photo of the {}",
- "a rendition of the {}",
- "a photo of the clean {}",
- "a rendition of a {}",
- "a photo of a nice {}",
- "a good photo of a {}",
- "a photo of the nice {}",
- "a photo of the small {}",
- "a photo of the weird {}",
- "a photo of the large {}",
- "a photo of a cool {}",
- "a photo of a small {}",
-]
-
-imagenet_style_templates_small = [
- "a painting in the style of {}",
- "a rendering in the style of {}",
- "a cropped painting in the style of {}",
- "the painting in the style of {}",
- "a clean painting in the style of {}",
- "a dirty painting in the style of {}",
- "a dark painting in the style of {}",
- "a picture in the style of {}",
- "a cool painting in the style of {}",
- "a close-up painting in the style of {}",
- "a bright painting in the style of {}",
- "a cropped painting in the style of {}",
- "a good painting in the style of {}",
- "a close-up painting in the style of {}",
- "a rendition in the style of {}",
- "a nice painting in the style of {}",
- "a small painting in the style of {}",
- "a weird painting in the style of {}",
- "a large painting in the style of {}",
-]
-
-
-class TextualInversionDataset(Dataset):
- def __init__(
- self,
- data_root,
- tokenizer,
- learnable_property="object", # [object, style]
- size=512,
- repeats=100,
- interpolation="bicubic",
- flip_p=0.5,
- set="train",
- placeholder_token="*",
- center_crop=False,
- ):
- self.data_root = Path(data_root)
- self.tokenizer = tokenizer
- self.learnable_property = learnable_property
- self.size = size
- self.placeholder_token = placeholder_token
- self.center_crop = center_crop
- self.flip_p = flip_p
-
- self.image_paths = [
- self.data_root / file_path
- for file_path in self.data_root.iterdir()
- if file_path.is_file()
- and file_path.name.endswith((".png", ".PNG", ".jpg", ".JPG", ".jpeg", ".JPEG", ".gif", ".GIF"))
- ]
-
- self.num_images = len(self.image_paths)
- self._length = self.num_images
-
- if set == "train":
- self._length = self.num_images * repeats
-
- self.interpolation = {
- "linear": PIL_INTERPOLATION["linear"],
- "bilinear": PIL_INTERPOLATION["bilinear"],
- "bicubic": PIL_INTERPOLATION["bicubic"],
- "lanczos": PIL_INTERPOLATION["lanczos"],
- }[interpolation]
-
- self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
- self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
-
- def __len__(self) -> int:
- return self._length
-
- def __getitem__(self, i):
- example = {}
- image = Image.open(self.image_paths[i % self.num_images])
-
- if not image.mode == "RGB":
- image = image.convert("RGB")
-
- placeholder_string = self.placeholder_token
- text = random.choice(self.templates).format(placeholder_string)
-
- example["input_ids"] = self.tokenizer(
- text,
- padding="max_length",
- truncation=True,
- max_length=self.tokenizer.model_max_length,
- return_tensors="pt",
- ).input_ids[0]
-
- # default to score-sde preprocessing
- img = np.array(image).astype(np.uint8)
-
- if self.center_crop:
- crop = min(img.shape[0], img.shape[1])
- (
- h,
- w,
- ) = (
- img.shape[0],
- img.shape[1],
- )
- img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
-
- image = Image.fromarray(img)
- image = image.resize((self.size, self.size), resample=self.interpolation)
-
- image = self.flip_transform(image)
- image = np.array(image).astype(np.uint8)
- image = (image / 127.5 - 1.0).astype(np.float32)
-
- example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
- return example
-
-
-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
- if token is None:
- token = HfFolder.get_token()
- if organization is None:
- username = whoami(token)["name"]
- return f"{username}/{model_id}"
- else:
- return f"{organization}/{model_id}"
-
-
-def do_textual_inversion_training(
- config: InvokeAIAppConfig,
- model: str,
- train_data_dir: Path,
- output_dir: Path,
- placeholder_token: str,
- initializer_token: str,
- save_steps: int = 500,
- only_save_embeds: bool = False,
- tokenizer_name: Optional[str] = None,
- learnable_property: str = "object",
- repeats: int = 100,
- seed: Optional[int] = None,
- resolution: int = 512,
- center_crop: bool = False,
- train_batch_size: int = 16,
- num_train_epochs: int = 100,
- max_train_steps: int = 5000,
- gradient_accumulation_steps: int = 1,
- gradient_checkpointing: bool = False,
- learning_rate: float = 1e-4,
- scale_lr: bool = True,
- lr_scheduler: str = "constant",
- lr_warmup_steps: int = 500,
- adam_beta1: float = 0.9,
- adam_beta2: float = 0.999,
- adam_weight_decay: float = 1e-02,
- adam_epsilon: float = 1e-08,
- push_to_hub: bool = False,
- hub_token: Optional[str] = None,
- logging_dir: Path = Path("logs"),
- mixed_precision: str = "fp16",
- allow_tf32: bool = False,
- report_to: str = "tensorboard",
- local_rank: int = -1,
- checkpointing_steps: int = 500,
- resume_from_checkpoint: Optional[Path] = None,
- enable_xformers_memory_efficient_attention: bool = False,
- hub_model_id: Optional[str] = None,
- **kwargs,
-) -> None:
- assert model, "Please specify a base model with --model"
- assert train_data_dir, "Please specify a directory containing the training images using --train_data_dir"
- assert placeholder_token, "Please specify a trigger term using --placeholder_token"
- env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
- if env_local_rank != -1 and env_local_rank != local_rank:
- local_rank = env_local_rank
-
- # setting up things the way invokeai expects them
- if not os.path.isabs(output_dir):
- output_dir = config.root_path / output_dir
-
- logging_dir = output_dir / logging_dir
-
- accelerator_config = ProjectConfiguration()
- accelerator_config.logging_dir = logging_dir
- accelerator = Accelerator(
- gradient_accumulation_steps=gradient_accumulation_steps,
- mixed_precision=mixed_precision,
- log_with=report_to,
- project_config=accelerator_config,
- )
-
- # Make one log on every process with the configuration for debugging.
- logging.basicConfig(
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
- datefmt="%m/%d/%Y %H:%M:%S",
- level=logging.INFO,
- )
- logger.info(accelerator.state, main_process_only=False)
- if accelerator.is_local_main_process:
- datasets.utils.logging.set_verbosity_warning()
- transformers.utils.logging.set_verbosity_warning()
- diffusers.utils.logging.set_verbosity_info()
- else:
- datasets.utils.logging.set_verbosity_error()
- transformers.utils.logging.set_verbosity_error()
- diffusers.utils.logging.set_verbosity_error()
-
- # If passed along, set the training seed now.
- if seed is not None:
- set_seed(seed)
-
- # Handle the repository creation
- if accelerator.is_main_process:
- if push_to_hub:
- if hub_model_id is None:
- repo_name = get_full_repo_name(Path(output_dir).name, token=hub_token)
- else:
- repo_name = hub_model_id
- repo = Repository(output_dir, clone_from=repo_name)
-
- with open(os.path.join(output_dir, ".gitignore"), "w+") as gitignore:
- if "step_*" not in gitignore:
- gitignore.write("step_*\n")
- if "epoch_*" not in gitignore:
- gitignore.write("epoch_*\n")
- elif output_dir is not None:
- os.makedirs(output_dir, exist_ok=True)
-
- model_records = initialize_record_store(config)
- base, type, name = model.split("/") # note frontend still returns old-style keys
- try:
- model_config = model_records.search_by_attr(
- model_name=name, model_type=ModelType(type), base_model=BaseModelType(base)
- )[0]
- except IndexError:
- raise Exception(f"Unknown model {model}")
- model_path = config.models_path / model_config.path
-
- pipeline_args = {"local_files_only": True}
- if tokenizer_name:
- tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
- else:
- tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer", **pipeline_args)
-
- # Load scheduler and models
- noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler", **pipeline_args)
- text_encoder = CLIPTextModel.from_pretrained(
- model_path,
- subfolder="text_encoder",
- **pipeline_args,
- )
- vae = AutoencoderKL.from_pretrained(
- model_path,
- subfolder="vae",
- **pipeline_args,
- )
- unet = UNet2DConditionModel.from_pretrained(
- model_path,
- subfolder="unet",
- **pipeline_args,
- )
-
- # Add the placeholder token in tokenizer
- num_added_tokens = tokenizer.add_tokens(placeholder_token)
- if num_added_tokens == 0:
- raise ValueError(
- f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
- " `placeholder_token` that is not already in the tokenizer."
- )
-
- # Convert the initializer_token, placeholder_token to ids
- token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
- # Check if initializer_token is a single token or a sequence of tokens
- if len(token_ids) > 1:
- raise ValueError(
- f"The initializer token must be a single token. Provided initializer={initializer_token}. Token ids={token_ids}"
- )
-
- initializer_token_id = token_ids[0]
- placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
-
- # Resize the token embeddings as we are adding new special tokens to the tokenizer
- text_encoder.resize_token_embeddings(len(tokenizer))
-
- # Initialise the newly added placeholder token with the embeddings of the initializer token
- token_embeds = text_encoder.get_input_embeddings().weight.data
- token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
-
- # Freeze vae and unet
- vae.requires_grad_(False)
- unet.requires_grad_(False)
- # Freeze all parameters except for the token embeddings in text encoder
- text_encoder.text_model.encoder.requires_grad_(False)
- text_encoder.text_model.final_layer_norm.requires_grad_(False)
- text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
-
- if gradient_checkpointing:
- # Keep unet in train mode if we are using gradient checkpointing to save memory.
- # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
- unet.train()
- text_encoder.gradient_checkpointing_enable()
- unet.enable_gradient_checkpointing()
-
- if enable_xformers_memory_efficient_attention:
- if is_xformers_available():
- unet.enable_xformers_memory_efficient_attention()
- else:
- raise ValueError("xformers is not available. Make sure it is installed correctly")
-
- # Enable TF32 for faster training on Ampere GPUs,
- # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
- if allow_tf32:
- torch.backends.cuda.matmul.allow_tf32 = True
-
- if scale_lr:
- learning_rate = learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
-
- # Initialize the optimizer
- optimizer = torch.optim.AdamW(
- text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
- lr=learning_rate,
- betas=(adam_beta1, adam_beta2),
- weight_decay=adam_weight_decay,
- eps=adam_epsilon,
- )
-
- # Dataset and DataLoaders creation:
- train_dataset = TextualInversionDataset(
- data_root=train_data_dir,
- tokenizer=tokenizer,
- size=resolution,
- placeholder_token=placeholder_token,
- repeats=repeats,
- learnable_property=learnable_property,
- center_crop=center_crop,
- set="train",
- )
- train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
-
- # Scheduler and math around the number of training steps.
- overrode_max_train_steps = False
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
- if max_train_steps is None:
- max_train_steps = num_train_epochs * num_update_steps_per_epoch
- overrode_max_train_steps = True
-
- scheduler = get_scheduler(
- lr_scheduler,
- optimizer=optimizer,
- num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
- num_training_steps=max_train_steps * gradient_accumulation_steps,
- )
-
- # Prepare everything with our `accelerator`.
- text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
- text_encoder, optimizer, train_dataloader, scheduler
- )
-
- # For mixed precision training we cast the unet and vae weights to half-precision
- # as these models are only used for inference, keeping weights in full precision is not required.
- weight_dtype = torch.float32
- if accelerator.mixed_precision == "fp16":
- weight_dtype = torch.float16
- elif accelerator.mixed_precision == "bf16":
- weight_dtype = torch.bfloat16
-
- # Move vae and unet to device and cast to weight_dtype
- unet.to(accelerator.device, dtype=weight_dtype)
- vae.to(accelerator.device, dtype=weight_dtype)
-
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
- if overrode_max_train_steps:
- max_train_steps = num_train_epochs * num_update_steps_per_epoch
- # Afterwards we recalculate our number of training epochs
- num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
-
- # We need to initialize the trackers we use, and also store our configuration.
-    # The trackers initialize automatically on the main process.
- if accelerator.is_main_process:
- params = locals()
- for k in params: # init_trackers() doesn't like objects
- params[k] = str(params[k]) if isinstance(params[k], object) else params[k]
- accelerator.init_trackers("textual_inversion", config=params)
-
- # Train!
- total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
-
- logger.info("***** Running training *****")
- logger.info(f" Num examples = {len(train_dataset)}")
- logger.info(f" Num Epochs = {num_train_epochs}")
- logger.info(f" Instantaneous batch size per device = {train_batch_size}")
- logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
- logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
- logger.info(f" Total optimization steps = {max_train_steps}")
- global_step = 0
- first_epoch = 0
- resume_step = None
-
- # Potentially load in the weights and states from a previous save
- if resume_from_checkpoint:
- if resume_from_checkpoint != "latest":
- path = os.path.basename(resume_from_checkpoint)
- else:
- # Get the most recent checkpoint
- dirs = os.listdir(output_dir)
- dirs = [d for d in dirs if d.startswith("checkpoint")]
- dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
- path = dirs[-1] if len(dirs) > 0 else None
-
- if path is None:
- accelerator.print(f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run.")
- resume_from_checkpoint = None
- else:
- accelerator.print(f"Resuming from checkpoint {path}")
- accelerator.load_state(os.path.join(output_dir, path))
- global_step = int(path.split("-")[1])
-
- resume_global_step = global_step * gradient_accumulation_steps
- first_epoch = global_step // num_update_steps_per_epoch
- resume_step = resume_global_step % (num_update_steps_per_epoch * gradient_accumulation_steps)
-
- # Only show the progress bar once on each machine.
- progress_bar = tqdm(
- range(global_step, max_train_steps),
- disable=not accelerator.is_local_main_process,
- )
- progress_bar.set_description("Steps")
-
- # keep original embeddings as reference
- orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
-
- for epoch in range(first_epoch, num_train_epochs):
- text_encoder.train()
- for step, batch in enumerate(train_dataloader):
- # Skip steps until we reach the resumed step
- if resume_step and resume_from_checkpoint and epoch == first_epoch and step < resume_step:
- if step % gradient_accumulation_steps == 0:
- progress_bar.update(1)
- continue
-
- with accelerator.accumulate(text_encoder):
- # Convert images to latent space
- latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
- latents = latents * 0.18215
-
- # Sample noise that we'll add to the latents
- noise = torch.randn_like(latents)
- bsz = latents.shape[0]
- # Sample a random timestep for each image
- timesteps = torch.randint(
- 0,
- noise_scheduler.config.num_train_timesteps,
- (bsz,),
- device=latents.device,
- )
- timesteps = timesteps.long()
-
- # Add noise to the latents according to the noise magnitude at each timestep
- # (this is the forward diffusion process)
- noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
- # Get the text embedding for conditioning
- encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
-
- # Predict the noise residual
- model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
- # Get the target for loss depending on the prediction type
- if noise_scheduler.config.prediction_type == "epsilon":
- target = noise
- elif noise_scheduler.config.prediction_type == "v_prediction":
- target = noise_scheduler.get_velocity(latents, noise, timesteps)
- else:
- raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
- accelerator.backward(loss)
-
- optimizer.step()
- scheduler.step()
- optimizer.zero_grad()
-
- # Let's make sure we don't update any embedding weights besides the newly added token
- index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
- with torch.no_grad():
- accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
- orig_embeds_params[index_no_updates]
- )
-
- # Checks if the accelerator has performed an optimization step behind the scenes
- if accelerator.sync_gradients:
- progress_bar.update(1)
- global_step += 1
- if global_step % save_steps == 0:
- save_path = os.path.join(output_dir, f"learned_embeds-steps-{global_step}.bin")
- save_progress(
- text_encoder,
- placeholder_token_id,
- accelerator,
- placeholder_token,
- save_path,
- )
-
- if global_step % checkpointing_steps == 0:
- if accelerator.is_main_process:
- save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
- accelerator.save_state(save_path)
- logger.info(f"Saved state to {save_path}")
-
- logs = {"loss": loss.detach().item(), "lr": scheduler.get_last_lr()[0]}
- progress_bar.set_postfix(**logs)
- accelerator.log(logs, step=global_step)
-
- if global_step >= max_train_steps:
- break
-
- # Create the pipeline using the trained modules and save it.
- accelerator.wait_for_everyone()
- if accelerator.is_main_process:
- if push_to_hub and only_save_embeds:
- logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
- save_full_model = True
- else:
- save_full_model = not only_save_embeds
- if save_full_model:
- pipeline = StableDiffusionPipeline.from_pretrained(
- model_path,
- text_encoder=accelerator.unwrap_model(text_encoder),
- vae=vae,
- unet=unet,
- tokenizer=tokenizer,
- **pipeline_args,
- )
- pipeline.save_pretrained(output_dir)
- # Save the newly trained embeddings
- save_path = os.path.join(output_dir, "learned_embeds.bin")
- save_progress(
- text_encoder,
- placeholder_token_id,
- accelerator,
- placeholder_token,
- save_path,
- )
-
- if push_to_hub:
- repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
-
- accelerator.end_training()
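For reference, a minimal sketch of consuming the learned_embeds.bin file written by save_progress() above, assuming it stores a {placeholder_token: embedding} mapping as in the upstream diffusers textual-inversion example; the helper name load_learned_embedding is illustrative.

import torch
from transformers import CLIPTextModel, CLIPTokenizer


def load_learned_embedding(embeds_path: str, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel) -> str:
    # Assumes the file stores {placeholder_token: embedding_tensor}; adjust if save_progress differs.
    learned = torch.load(embeds_path, map_location="cpu")
    placeholder_token, embedding = next(iter(learned.items()))
    tokenizer.add_tokens(placeholder_token)               # register the new token
    text_encoder.resize_token_embeddings(len(tokenizer))  # grow the embedding matrix to match
    token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
    with torch.no_grad():
        text_encoder.get_input_embeddings().weight[token_id] = embedding
    return placeholder_token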
diff --git a/invokeai/backend/util/__init__.py b/invokeai/backend/util/__init__.py
index ee6793376f..2c9cceff2c 100644
--- a/invokeai/backend/util/__init__.py
+++ b/invokeai/backend/util/__init__.py
@@ -2,32 +2,14 @@
Initialization file for invokeai.backend.util
"""
-from .attention import auto_detect_slice_size # noqa: F401
-from .devices import ( # noqa: F401
- CPU_DEVICE,
- CUDA_DEVICE,
- MPS_DEVICE,
- choose_precision,
- choose_torch_device,
- normalize_device,
- torch_dtype,
-)
+from .devices import choose_precision, choose_torch_device
from .logging import InvokeAILogger
-from .util import ( # TO DO: Clean this up; remove the unused symbols
- GIG,
- Chdir,
- ask_user, # noqa
- directory_size,
- download_with_resume,
- instantiate_from_config, # noqa
- url_attachment_name, # noqa
-)
+from .util import GIG, Chdir, directory_size
__all__ = [
"GIG",
"directory_size",
"Chdir",
- "download_with_resume",
"InvokeAILogger",
"choose_precision",
"choose_torch_device",
diff --git a/invokeai/backend/util/log.py b/invokeai/backend/util/log.py
deleted file mode 100644
index 3919d456b9..0000000000
--- a/invokeai/backend/util/log.py
+++ /dev/null
@@ -1,67 +0,0 @@
-"""
-Functions for better-formatted logging
- write_log -- logs the name of the output image, prompt, and prompt args to the terminal and to different types of files
- 1 write_log_message -- Writes a message to the console
- 2 write_log_files -- Writes a message to files
- 2.1 write_log_default -- File in plain text
- 2.2 write_log_txt -- File in txt format
- 2.3 write_log_markdown -- File in markdown format
-"""
-
-import os
-
-
-def write_log(results, log_path, file_types, output_cntr):
- """
- logs the name of the output image, prompt, and prompt args to the terminal and files
- """
- output_cntr = write_log_message(results, output_cntr)
- write_log_files(results, log_path, file_types)
- return output_cntr
-
-
-def write_log_message(results, output_cntr):
- """logs to the terminal"""
- if len(results) == 0:
- return output_cntr
- log_lines = [f"{path}: {prompt}\n" for path, prompt in results]
- if len(log_lines) > 1:
- subcntr = 1
- for ll in log_lines:
- print(f"[{output_cntr}.{subcntr}] {ll}", end="")
- subcntr += 1
- else:
- print(f"[{output_cntr}] {log_lines[0]}", end="")
- return output_cntr + 1
-
-
-def write_log_files(results, log_path, file_types):
- for file_type in file_types:
- if file_type == "txt":
- write_log_txt(log_path, results)
- elif file_type == "md" or file_type == "markdown":
- write_log_markdown(log_path, results)
- else:
- print(f"'{file_type}' format is not supported, so write in plain text")
- write_log_default(log_path, results, file_type)
-
-
-def write_log_default(log_path, results, file_type):
- plain_txt_lines = [f"{path}: {prompt}\n" for path, prompt in results]
- with open(log_path + "." + file_type, "a", encoding="utf-8") as file:
- file.writelines(plain_txt_lines)
-
-
-def write_log_txt(log_path, results):
- txt_lines = [f"{path}: {prompt}\n" for path, prompt in results]
- with open(log_path + ".txt", "a", encoding="utf-8") as file:
- file.writelines(txt_lines)
-
-
-def write_log_markdown(log_path, results):
- md_lines = []
- for path, prompt in results:
- file_name = os.path.basename(path)
- md_lines.append(f"## {file_name}\n![]({file_name})\n\n{prompt}\n")
- with open(log_path + ".md", "a", encoding="utf-8") as file:
- file.writelines(md_lines)
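A brief usage sketch for the removed write_log() helper: results is a list of (image_path, prompt) pairs, log_path is a path prefix (the per-format extension is appended), and the returned counter feeds the next call. The paths and prompts below are illustrative.

results = [
    ("outputs/000001.png", "a watercolor fox"),
    ("outputs/000002.png", "a watercolor owl"),
]
output_cntr = 1
output_cntr = write_log(results, "invoke_log", ["txt", "md"], output_cntr)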
diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py
index ac7a64e807..7d0d9d03f7 100644
--- a/invokeai/backend/util/util.py
+++ b/invokeai/backend/util/util.py
@@ -1,29 +1,13 @@
import base64
-import importlib
import io
-import math
-import multiprocessing as mp
import os
-import re
import warnings
-from collections import abc
-from inspect import isfunction
from pathlib import Path
-from queue import Queue
-from threading import Thread
-import numpy as np
-import requests
-import torch
from diffusers import logging as diffusers_logging
-from PIL import Image, ImageDraw, ImageFont
-from tqdm import tqdm
+from PIL import Image
from transformers import logging as transformers_logging
-import invokeai.backend.util.logging as logger
-
-from .devices import torch_dtype
-
# actual size of a gig
GIG = 1073741824
@@ -41,340 +25,6 @@ def directory_size(directory: Path) -> int:
return sum
-def log_txt_as_img(wh, xc, size=10):
- # wh a tuple of (width, height)
- # xc a list of captions to plot
- b = len(xc)
- txts = []
- for bi in range(b):
- txt = Image.new("RGB", wh, color="white")
- draw = ImageDraw.Draw(txt)
- font = ImageFont.load_default()
- nc = int(40 * (wh[0] / 256))
- lines = "\n".join(xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc))
-
- try:
- draw.text((0, 0), lines, fill="black", font=font)
- except UnicodeEncodeError:
- logger.warning("Cant encode string for logging. Skipping.")
-
- txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
- txts.append(txt)
- txts = np.stack(txts)
- txts = torch.tensor(txts)
- return txts
-
-
-def ismap(x):
- if not isinstance(x, torch.Tensor):
- return False
- return (len(x.shape) == 4) and (x.shape[1] > 3)
-
-
-def isimage(x):
- if not isinstance(x, torch.Tensor):
- return False
- return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
-
-
-def exists(x):
- return x is not None
-
-
-def default(val, d):
- if exists(val):
- return val
- return d() if isfunction(d) else d
-
-
-def mean_flat(tensor):
- """
- https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
- Take the mean over all non-batch dimensions.
- """
- return tensor.mean(dim=list(range(1, len(tensor.shape))))
-
-
-def count_params(model, verbose=False):
- total_params = sum(p.numel() for p in model.parameters())
- if verbose:
- logger.debug(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
- return total_params
-
-
-def instantiate_from_config(config, **kwargs):
- if "target" not in config:
- if config == "__is_first_stage__":
- return None
- elif config == "__is_unconditional__":
- return None
- raise KeyError("Expected key `target` to instantiate.")
- return get_obj_from_str(config["target"])(**config.get("params", {}), **kwargs)
-
-
-def get_obj_from_str(string, reload=False):
- module, cls = string.rsplit(".", 1)
- if reload:
- module_imp = importlib.import_module(module)
- importlib.reload(module_imp)
- return getattr(importlib.import_module(module, package=None), cls)
-
-
-def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
- # create dummy dataset instance
-
- # run prefetching
- if idx_to_fn:
- res = func(data, worker_id=idx)
- else:
- res = func(data)
- Q.put([idx, res])
- Q.put("Done")
-
-
-def parallel_data_prefetch(
- func: callable,
- data,
- n_proc,
- target_data_type="ndarray",
- cpu_intensive=True,
- use_worker_id=False,
-):
- # if target_data_type not in ["ndarray", "list"]:
- # raise ValueError(
- # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
- # )
- if isinstance(data, np.ndarray) and target_data_type == "list":
- raise ValueError("list expected but function got ndarray.")
- elif isinstance(data, abc.Iterable):
- if isinstance(data, dict):
- logger.warning(
- '"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
- )
- data = list(data.values())
- if target_data_type == "ndarray":
- data = np.asarray(data)
- else:
- data = list(data)
- else:
- raise TypeError(
- f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
- )
-
- if cpu_intensive:
- Q = mp.Queue(1000)
- proc = mp.Process
- else:
- Q = Queue(1000)
- proc = Thread
- # spawn processes
- if target_data_type == "ndarray":
- arguments = [[func, Q, part, i, use_worker_id] for i, part in enumerate(np.array_split(data, n_proc))]
- else:
- step = int(len(data) / n_proc + 1) if len(data) % n_proc != 0 else int(len(data) / n_proc)
- arguments = [
- [func, Q, part, i, use_worker_id]
- for i, part in enumerate([data[i : i + step] for i in range(0, len(data), step)])
- ]
- processes = []
- for i in range(n_proc):
- p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
- processes += [p]
-
- # start processes
- logger.info("Start prefetching...")
- import time
-
- start = time.time()
- gather_res = [[] for _ in range(n_proc)]
- try:
- for p in processes:
- p.start()
-
- k = 0
- while k < n_proc:
- # get result
- res = Q.get()
- if res == "Done":
- k += 1
- else:
- gather_res[res[0]] = res[1]
-
- except Exception as e:
- logger.error("Exception: ", e)
- for p in processes:
- p.terminate()
-
- raise e
- finally:
- for p in processes:
- p.join()
- logger.info(f"Prefetching complete. [{time.time() - start} sec.]")
-
- if target_data_type == "ndarray":
- if not isinstance(gather_res[0], np.ndarray):
- return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
-
- # order outputs
- return np.concatenate(gather_res, axis=0)
- elif target_data_type == "list":
- out = []
- for r in gather_res:
- out.extend(r)
- return out
- else:
- return gather_res
-
-
-def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
- delta = (res[0] / shape[0], res[1] / shape[1])
- d = (shape[0] // res[0], shape[1] // res[1])
-
- grid = (
- torch.stack(
- torch.meshgrid(
- torch.arange(0, res[0], delta[0]),
- torch.arange(0, res[1], delta[1]),
- indexing="ij",
- ),
- dim=-1,
- ).to(device)
- % 1
- )
-
- rand_val = torch.rand(res[0] + 1, res[1] + 1)
-
- angles = 2 * math.pi * rand_val
- gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1).to(device)
-
- def tile_grads(slice1, slice2):
- return (
- gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
- .repeat_interleave(d[0], 0)
- .repeat_interleave(d[1], 1)
- )
-
- def dot(grad, shift):
- return (
- torch.stack(
- (
- grid[: shape[0], : shape[1], 0] + shift[0],
- grid[: shape[0], : shape[1], 1] + shift[1],
- ),
- dim=-1,
- )
- * grad[: shape[0], : shape[1]]
- ).sum(dim=-1)
-
- n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]).to(device)
- n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]).to(device)
- n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]).to(device)
- n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]).to(device)
- t = fade(grid[: shape[0], : shape[1]])
- noise = math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(
- device
- )
- return noise.to(dtype=torch_dtype(device))
-
-
-def ask_user(question: str, answers: list):
- from itertools import chain, repeat
-
- user_prompt = f"\n>> {question} {answers}: "
- invalid_answer_msg = "Invalid answer. Please try again."
- pose_question = chain([user_prompt], repeat("\n".join([invalid_answer_msg, user_prompt])))
- user_answers = map(input, pose_question)
- valid_response = next(filter(answers.__contains__, user_answers))
- return valid_response
-
-
-# -------------------------------------
-def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path:
- """
- Download a model file.
- :param url: https, http or ftp URL
- :param dest: A Path object. If path exists and is a directory, then we try to derive the filename
- from the URL's Content-Disposition header and copy the URL contents into
- dest/filename
- :param access_token: Access token to access this resource
- """
- header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
- open_mode = "wb"
- exist_size = 0
-
- resp = requests.get(url, headers=header, stream=True, allow_redirects=True)
- content_length = int(resp.headers.get("content-length", 0))
-
- if dest.is_dir():
- try:
- file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)
- except AttributeError:
- file_name = os.path.basename(url)
- dest = dest / file_name
- else:
- dest.parent.mkdir(parents=True, exist_ok=True)
-
- if dest.exists():
- exist_size = dest.stat().st_size
- header["Range"] = f"bytes={exist_size}-"
- open_mode = "ab"
- resp = requests.get(url, headers=header, stream=True) # new request with range
-
- if exist_size > content_length:
- logger.warning("corrupt existing file found. re-downloading")
- os.remove(dest)
- exist_size = 0
-
- if resp.status_code == 416 or (content_length > 0 and exist_size == content_length):
- logger.warning(f"{dest}: complete file found. Skipping.")
- return dest
- elif resp.status_code == 206 or exist_size > 0:
- logger.warning(f"{dest}: partial file found. Resuming...")
- elif resp.status_code != 200:
- logger.error(f"An error occurred during downloading {dest}: {resp.reason}")
- else:
- logger.info(f"{dest}: Downloading...")
-
- try:
- if content_length < 2000:
- logger.error(f"ERROR DOWNLOADING {url}: {resp.text}")
- return None
-
- with (
- open(dest, open_mode) as file,
- tqdm(
- desc=str(dest),
- initial=exist_size,
- total=content_length,
- unit="iB",
- unit_scale=True,
- unit_divisor=1000,
- ) as bar,
- ):
- for data in resp.iter_content(chunk_size=1024):
- size = file.write(data)
- bar.update(size)
- except Exception as e:
- logger.error(f"An error occurred while downloading {dest}: {str(e)}")
- return None
-
- return dest
-
-
-def url_attachment_name(url: str) -> dict:
- try:
- resp = requests.get(url, stream=True)
- match = re.search('filename="(.+)"', resp.headers.get("Content-Disposition"))
- return match.group(1)
- except Exception:
- return None
-
-
-def download_with_progress_bar(url: str, dest: Path) -> bool:
- result = download_with_resume(url, dest, access_token=None)
- return result is not None
-
-
def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str:
"""
Converts an image into a base64 image dataURL.
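For reference, a hedged usage sketch for the download_with_resume() helper removed above: given a directory destination it derives the filename from the Content-Disposition header (or the URL), resumes partial files via an HTTP Range request, and returns the final Path or None on failure. The destination directory below is illustrative.

from pathlib import Path

dest = download_with_resume(
    "https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors",
    Path("models/embedding"),  # an existing directory: the filename is derived automatically
    access_token=None,         # optional bearer token for gated repositories
)
if dest is None:
    print("download failed")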
diff --git a/invokeai/configs/INITIAL_MODELS.yaml b/invokeai/configs/INITIAL_MODELS.yaml
deleted file mode 100644
index ff4d8217d0..0000000000
--- a/invokeai/configs/INITIAL_MODELS.yaml
+++ /dev/null
@@ -1,189 +0,0 @@
-# This file predefines a few models that the user may want to install.
-sd-1/main/stable-diffusion-v1-5:
- description: Stable Diffusion version 1.5 diffusers model (4.27 GB)
- source: runwayml/stable-diffusion-v1-5
- recommended: True
- default: True
-sd-1/main/stable-diffusion-v1-5-inpainting:
- description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)
- source: runwayml/stable-diffusion-inpainting
- recommended: True
-sd-2/main/stable-diffusion-2-1:
- description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)
- source: stabilityai/stable-diffusion-2-1
- recommended: False
-sd-2/main/stable-diffusion-2-inpainting:
- description: Stable Diffusion version 2.0 inpainting model (5.21 GB)
- source: stabilityai/stable-diffusion-2-inpainting
- recommended: False
-sdxl/main/stable-diffusion-xl-base-1-0:
- description: Stable Diffusion XL base model (12 GB)
- source: stabilityai/stable-diffusion-xl-base-1.0
- recommended: True
-sdxl-refiner/main/stable-diffusion-xl-refiner-1-0:
- description: Stable Diffusion XL refiner model (12 GB)
- source: stabilityai/stable-diffusion-xl-refiner-1.0
- recommended: False
-sdxl/vae/sdxl-vae-fp16-fix:
- description: Version of the SDXL-1.0 VAE that works in half precision mode
- source: madebyollin/sdxl-vae-fp16-fix
- recommended: True
-sd-1/main/Analog-Diffusion:
- description: An SD-1.5 model trained on diverse analog photographs (2.13 GB)
- source: wavymulder/Analog-Diffusion
- recommended: False
-sd-1/main/Deliberate:
- description: Versatile model that produces detailed images up to 768px (4.27 GB)
- source: https://huggingface.co/XpucT/Deliberate/resolve/main/Deliberate_v5.safetensors?download=true
- recommended: False
-sd-1/main/Dungeons-and-Diffusion:
- description: Dungeons & Dragons characters (2.13 GB)
- source: 0xJustin/Dungeons-and-Diffusion
- recommended: False
-sd-1/main/dreamlike-photoreal-2:
- description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB)
- source: dreamlike-art/dreamlike-photoreal-2.0
- recommended: False
-sd-1/main/Inkpunk-Diffusion:
- description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB)
- source: Envvi/Inkpunk-Diffusion
- recommended: False
-sd-1/main/openjourney:
- description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB)
- source: prompthero/openjourney
- recommended: False
-sd-1/main/seek.art_MEGA:
- source: coreco/seek.art_MEGA
- description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB)
- recommended: False
-sd-1/main/trinart_stable_diffusion_v2:
- description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB)
- source: naclbit/trinart_stable_diffusion_v2
- recommended: False
-sd-1/controlnet/qrcode_monster:
- source: monster-labs/control_v1p_sd15_qrcode_monster
- description: Controlnet model that generates scannable creative QR codes
- subfolder: v2
-sd-1/controlnet/canny:
- description: Controlnet weights trained on sd-1.5 with canny conditioning.
- source: lllyasviel/control_v11p_sd15_canny
- recommended: True
-sd-1/controlnet/inpaint:
- source: lllyasviel/control_v11p_sd15_inpaint
- description: Controlnet weights trained on sd-1.5 with inpaint conditioning
-sd-1/controlnet/mlsd:
- description: Controlnet weights trained on sd-1.5 with MLSD conditioning
- source: lllyasviel/control_v11p_sd15_mlsd
-sd-1/controlnet/depth:
- description: Controlnet weights trained on sd-1.5 with depth conditioning
- source: lllyasviel/control_v11f1p_sd15_depth
- recommended: True
-sd-1/controlnet/normal_bae:
- description: Controlnet weights trained on sd-1.5 with normalbae image conditioning
- source: lllyasviel/control_v11p_sd15_normalbae
-sd-1/controlnet/seg:
- description: Controlnet weights trained on sd-1.5 with seg image conditioning
- source: lllyasviel/control_v11p_sd15_seg
-sd-1/controlnet/lineart:
- description: Controlnet weights trained on sd-1.5 with lineart image conditioning
- source: lllyasviel/control_v11p_sd15_lineart
- recommended: True
-sd-1/controlnet/lineart_anime:
- description: Controlnet weights trained on sd-1.5 with anime image conditioning
- source: lllyasviel/control_v11p_sd15s2_lineart_anime
-sd-1/controlnet/openpose:
- description: Controlnet weights trained on sd-1.5 with openpose image conditioning
- source: lllyasviel/control_v11p_sd15_openpose
- recommended: True
-sd-1/controlnet/scribble:
- source: lllyasviel/control_v11p_sd15_scribble
- description: Controlnet weights trained on sd-1.5 with scribble image conditioning
- recommended: False
-sd-1/controlnet/softedge:
- source: lllyasviel/control_v11p_sd15_softedge
- description: Controlnet weights trained on sd-1.5 with soft edge conditioning
-sd-1/controlnet/shuffle:
- source: lllyasviel/control_v11e_sd15_shuffle
- description: Controlnet weights trained on sd-1.5 with shuffle image conditioning
-sd-1/controlnet/tile:
- source: lllyasviel/control_v11f1e_sd15_tile
- description: Controlnet weights trained on sd-1.5 with tiled image conditioning
-sd-1/controlnet/ip2p:
- source: lllyasviel/control_v11e_sd15_ip2p
- description: Controlnet weights trained on sd-1.5 with ip2p conditioning.
-sdxl/controlnet/canny-sdxl:
- description: Controlnet weights trained on sdxl-1.0 with canny conditioning.
- source: diffusers/controlnet-canny-sdxl-1.0
- recommended: True
-sdxl/controlnet/depth-sdxl:
- description: Controlnet weights trained on sdxl-1.0 with depth conditioning.
- source: diffusers/controlnet-depth-sdxl-1.0
- recommended: True
-sdxl/controlnet/softedge-dexined-sdxl:
- description: Controlnet weights trained on sdxl-1.0 with dexined soft edge preprocessing.
- source: SargeZT/controlnet-sd-xl-1.0-softedge-dexined
-sdxl/controlnet/depth-16bit-zoe-sdxl:
- description: Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).
- source: SargeZT/controlnet-sd-xl-1.0-depth-16bit-zoe
-sdxl/controlnet/depth-zoe-sdxl:
- description: Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).
- source: diffusers/controlnet-zoe-depth-sdxl-1.0
-sd-1/t2i_adapter/canny-sd15:
- source: TencentARC/t2iadapter_canny_sd15v2
-sd-1/t2i_adapter/sketch-sd15:
- source: TencentARC/t2iadapter_sketch_sd15v2
-sd-1/t2i_adapter/depth-sd15:
- source: TencentARC/t2iadapter_depth_sd15v2
-sd-1/t2i_adapter/zoedepth-sd15:
- source: TencentARC/t2iadapter_zoedepth_sd15v1
-sdxl/t2i_adapter/canny-sdxl:
- source: TencentARC/t2i-adapter-canny-sdxl-1.0
-sdxl/t2i_adapter/zoedepth-sdxl:
- source: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0
-sdxl/t2i_adapter/lineart-sdxl:
- source: TencentARC/t2i-adapter-lineart-sdxl-1.0
-sdxl/t2i_adapter/sketch-sdxl:
- source: TencentARC/t2i-adapter-sketch-sdxl-1.0
-sd-1/embedding/EasyNegative:
- source: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
- recommended: True
- description: A textual inversion to use in the negative prompt to reduce bad anatomy
-sd-1/lora/FlatColor:
- source: https://civitai.com/models/6433/loraflatcolor
- recommended: True
- description: A LoRA that generates scenery using solid blocks of color
-sd-1/lora/Ink scenery:
- source: https://civitai.com/api/download/models/83390
- description: Generate India ink-like landscapes
-sd-1/ip_adapter/ip_adapter_sd15:
- source: InvokeAI/ip_adapter_sd15
- recommended: True
- requires:
- - InvokeAI/ip_adapter_sd_image_encoder
- description: IP-Adapter for SD 1.5 models
-sd-1/ip_adapter/ip_adapter_plus_sd15:
- source: InvokeAI/ip_adapter_plus_sd15
- recommended: False
- requires:
- - InvokeAI/ip_adapter_sd_image_encoder
- description: Refined IP-Adapter for SD 1.5 models
-sd-1/ip_adapter/ip_adapter_plus_face_sd15:
- source: InvokeAI/ip_adapter_plus_face_sd15
- recommended: False
- requires:
- - InvokeAI/ip_adapter_sd_image_encoder
- description: Refined IP-Adapter for SD 1.5 models, adapted for faces
-sdxl/ip_adapter/ip_adapter_sdxl:
- source: InvokeAI/ip_adapter_sdxl
- recommended: False
- requires:
- - InvokeAI/ip_adapter_sdxl_image_encoder
- description: IP-Adapter for SDXL models
-any/clip_vision/ip_adapter_sd_image_encoder:
- source: InvokeAI/ip_adapter_sd_image_encoder
- recommended: False
- description: Required model for using IP-Adapters with SD-1/2 models
-any/clip_vision/ip_adapter_sdxl_image_encoder:
- source: InvokeAI/ip_adapter_sdxl_image_encoder
- recommended: False
- description: Required model for using IP-Adapters with SDXL models
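A sketch of how an installer might consume this file before its removal: each top-level key encodes base/type/name, and each entry carries a source plus optional description, recommended, requires and subfolder fields. The parsing code below is illustrative, not the project's actual loader.

import yaml

with open("invokeai/configs/INITIAL_MODELS.yaml") as f:
    initial_models = yaml.safe_load(f)

for key, spec in initial_models.items():
    base, model_type, name = key.split("/", 2)
    if spec.get("recommended"):
        print(f"{base:12} {model_type:12} {name:40} <- {spec['source']}")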
diff --git a/invokeai/frontend/install/invokeai_configure.py b/invokeai/frontend/install/invokeai_configure.py
deleted file mode 100644
index 6b2cc5236c..0000000000
--- a/invokeai/frontend/install/invokeai_configure.py
+++ /dev/null
@@ -1,60 +0,0 @@
-"""
-Wrapper for invokeai.backend.install.invokeai_configure
-"""
-
-import argparse
-
-
-def run_configure() -> None:
- # Before doing _anything_, parse CLI args!
- from invokeai.frontend.cli.arg_parser import InvokeAIArgs
-
- parser = argparse.ArgumentParser(description="InvokeAI model downloader")
- parser.add_argument(
- "--skip-sd-weights",
- dest="skip_sd_weights",
- action=argparse.BooleanOptionalAction,
- default=False,
- help="skip downloading the large Stable Diffusion weight files",
- )
- parser.add_argument(
- "--skip-support-models",
- dest="skip_support_models",
- action=argparse.BooleanOptionalAction,
- default=False,
- help="skip downloading the support models",
- )
- parser.add_argument(
- "--full-precision",
- dest="full_precision",
- action=argparse.BooleanOptionalAction,
- type=bool,
- default=False,
- help="use 32-bit weights instead of faster 16-bit weights",
- )
- parser.add_argument(
- "--yes",
- "-y",
- dest="yes_to_all",
- action="store_true",
- help='answer "yes" to all prompts',
- )
- parser.add_argument(
- "--default_only",
- action="store_true",
- help="when --yes specified, only install the default model",
- )
- parser.add_argument(
- "--root_dir",
- dest="root",
- type=str,
- default=None,
- help="path to root of install directory",
- )
-
- opt = parser.parse_args()
- InvokeAIArgs.args = opt
-
- from invokeai.backend.install.invokeai_configure import main as invokeai_configure
-
- invokeai_configure(opt)
diff --git a/invokeai/frontend/install/model_install.py b/invokeai/frontend/install/model_install.py
deleted file mode 100644
index b4b3e36d83..0000000000
--- a/invokeai/frontend/install/model_install.py
+++ /dev/null
@@ -1,652 +0,0 @@
-#!/usr/bin/env python
-# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
-# Before running stable-diffusion on an internet-isolated machine,
-# run this script from one with internet connectivity. The
-# two machines must share a common .cache directory.
-
-"""
-This is the npyscreen frontend to the model installation application.
-"""
-
-import argparse
-import curses
-import pathlib
-import sys
-import traceback
-import warnings
-from argparse import Namespace
-from shutil import get_terminal_size
-from typing import Any, Dict, List, Optional, Set
-
-import npyscreen
-import torch
-from npyscreen import widget
-
-from invokeai.app.services.config.config_default import get_config
-from invokeai.app.services.model_install import ModelInstallServiceBase
-from invokeai.backend.install.check_directories import validate_directories
-from invokeai.backend.install.install_helper import InstallHelper, InstallSelections, UnifiedModelInfo
-from invokeai.backend.model_manager import ModelType
-from invokeai.backend.util import choose_precision, choose_torch_device
-from invokeai.backend.util.logging import InvokeAILogger
-from invokeai.frontend.install.widgets import (
- MIN_COLS,
- MIN_LINES,
- CenteredTitleText,
- CyclingForm,
- MultiSelectColumns,
- SingleSelectColumns,
- TextBox,
- WindowTooSmallException,
- set_min_terminal_size,
-)
-
-warnings.filterwarnings("ignore", category=UserWarning) # noqa: E402
-config = get_config()
-logger = InvokeAILogger.get_logger("ModelInstallService", config=config)
-# logger.setLevel("WARNING")
-# logger.setLevel('DEBUG')
-
-# build a table mapping all non-printable characters to None
-# for stripping control characters
-# from https://stackoverflow.com/questions/92438/stripping-non-printable-characters-from-a-string-in-python
-NOPRINT_TRANS_TABLE = {i: None for i in range(0, sys.maxunicode + 1) if not chr(i).isprintable()}
-
-# maximum number of installed models we can display before overflowing vertically
-MAX_OTHER_MODELS = 72
-
-
-def make_printable(s: str) -> str:
- """Replace non-printable characters in a string."""
- return s.translate(NOPRINT_TRANS_TABLE)
-
-
-class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
- """Main form for interactive TUI."""
-
- # for responsive resizing set to False, but this seems to cause a crash!
- FIX_MINIMUM_SIZE_WHEN_CREATED = True
-
- # for persistence
- current_tab = 0
-
- def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, multipage: bool = False, **keywords: Any):
- self.multipage = multipage
- self.subprocess = None
- super().__init__(parentApp=parentApp, name=name, **keywords)
-
- def create(self) -> None:
- self.installer = self.parentApp.install_helper.installer
- self.model_labels = self._get_model_labels()
- self.keypress_timeout = 10
- self.counter = 0
- self.subprocess_connection = None
-
- window_width, window_height = get_terminal_size()
-
- # npyscreen has no typing hints
- self.nextrely -= 1 # type: ignore
- self.add_widget_intelligent(
- npyscreen.FixedText,
- value="Use ctrl-N and ctrl-P to move to the ext and revious fields. Cursor keys navigate, and selects.",
- editable=False,
- color="CAUTION",
- )
- self.nextrely += 1 # type: ignore
- self.tabs = self.add_widget_intelligent(
- SingleSelectColumns,
- values=[
- "STARTERS",
- "MAINS",
- "CONTROLNETS",
- "T2I-ADAPTERS",
- "IP-ADAPTERS",
- "LORAS",
- "TI EMBEDDINGS",
- ],
- value=[self.current_tab],
- columns=7,
- max_height=2,
- relx=8,
- scroll_exit=True,
- )
- self.tabs.on_changed = self._toggle_tables
-
- top_of_table = self.nextrely # type: ignore
- self.starter_pipelines = self.add_starter_pipelines()
- bottom_of_table = self.nextrely # type: ignore
-
- self.nextrely = top_of_table
- self.pipeline_models = self.add_pipeline_widgets(
- model_type=ModelType.Main, window_width=window_width, exclude=self.starter_models
- )
- # self.pipeline_models['autoload_pending'] = True
- bottom_of_table = max(bottom_of_table, self.nextrely)
-
- self.nextrely = top_of_table
- self.controlnet_models = self.add_model_widgets(
- model_type=ModelType.ControlNet,
- window_width=window_width,
- )
- bottom_of_table = max(bottom_of_table, self.nextrely)
-
- self.nextrely = top_of_table
- self.t2i_models = self.add_model_widgets(
- model_type=ModelType.T2IAdapter,
- window_width=window_width,
- )
- bottom_of_table = max(bottom_of_table, self.nextrely)
- self.nextrely = top_of_table
- self.ipadapter_models = self.add_model_widgets(
- model_type=ModelType.IPAdapter,
- window_width=window_width,
- )
- bottom_of_table = max(bottom_of_table, self.nextrely)
-
- self.nextrely = top_of_table
- self.lora_models = self.add_model_widgets(
- model_type=ModelType.LoRA,
- window_width=window_width,
- )
- bottom_of_table = max(bottom_of_table, self.nextrely)
-
- self.nextrely = top_of_table
- self.ti_models = self.add_model_widgets(
- model_type=ModelType.TextualInversion,
- window_width=window_width,
- )
- bottom_of_table = max(bottom_of_table, self.nextrely)
-
- self.nextrely = bottom_of_table + 1
-
- self.nextrely += 1
- back_label = "BACK"
- cancel_label = "CANCEL"
- current_position = self.nextrely
- if self.multipage:
- self.back_button = self.add_widget_intelligent(
- npyscreen.ButtonPress,
- name=back_label,
- when_pressed_function=self.on_back,
- )
- else:
- self.nextrely = current_position
- self.cancel_button = self.add_widget_intelligent(
- npyscreen.ButtonPress, name=cancel_label, when_pressed_function=self.on_cancel
- )
- self.nextrely = current_position
-
- label = "APPLY CHANGES"
- self.nextrely = current_position
- self.done = self.add_widget_intelligent(
- npyscreen.ButtonPress,
- name=label,
- relx=window_width - len(label) - 15,
- when_pressed_function=self.on_done,
- )
-
- # This restores the selected page on return from an installation
- for _i in range(1, self.current_tab + 1):
- self.tabs.h_cursor_line_down(1)
- self._toggle_tables([self.current_tab])
-
- ############# diffusers tab ##########
- def add_starter_pipelines(self) -> dict[str, npyscreen.widget]:
- """Add widgets responsible for selecting diffusers models"""
- widgets: Dict[str, npyscreen.widget] = {}
-
- all_models = self.all_models # master dict of all models, indexed by key
- model_list = [x for x in self.starter_models if all_models[x].type in ["main", "vae"]]
- model_labels = [self.model_labels[x] for x in model_list]
-
- widgets.update(
- label1=self.add_widget_intelligent(
- CenteredTitleText,
- name="Select from a starter set of Stable Diffusion models from HuggingFace and Civitae.",
- editable=False,
- labelColor="CAUTION",
- )
- )
-
- self.nextrely -= 1
- # if user has already installed some initial models, then don't patronize them
- # by showing more recommendations
- show_recommended = len(self.installed_models) == 0
-
- checked = [
- model_list.index(x)
- for x in model_list
- if (show_recommended and all_models[x].recommended) or all_models[x].installed
- ]
- widgets.update(
- models_selected=self.add_widget_intelligent(
- MultiSelectColumns,
- columns=1,
- name="Install Starter Models",
- values=model_labels,
- value=checked,
- max_height=len(model_list) + 1,
- relx=4,
- scroll_exit=True,
- ),
- models=model_list,
- )
-
- self.nextrely += 1
- return widgets
-
- ############# Add a set of model install widgets ########
- def add_model_widgets(
- self,
- model_type: ModelType,
- window_width: int = 120,
- install_prompt: Optional[str] = None,
- exclude: Optional[Set[str]] = None,
- ) -> dict[str, npyscreen.widget]:
- """Generic code to create model selection widgets"""
- if exclude is None:
- exclude = set()
- widgets: Dict[str, npyscreen.widget] = {}
- all_models = self.all_models
- model_list = sorted(
- [x for x in all_models if all_models[x].type == model_type and x not in exclude],
- key=lambda x: all_models[x].name or "",
- )
- model_labels = [self.model_labels[x] for x in model_list]
-
- show_recommended = len(self.installed_models) == 0
- truncated = False
- if len(model_list) > 0:
- max_width = max([len(x) for x in model_labels])
- columns = window_width // (max_width + 8) # 8 characters for "[x] " and padding
- columns = min(len(model_list), columns) or 1
- prompt = (
- install_prompt
- or f"Select the desired {model_type.value.title()} models to install. Unchecked models will be purged from disk."
- )
-
- widgets.update(
- label1=self.add_widget_intelligent(
- CenteredTitleText,
- name=prompt,
- editable=False,
- labelColor="CAUTION",
- )
- )
-
- if len(model_labels) > MAX_OTHER_MODELS:
- model_labels = model_labels[0:MAX_OTHER_MODELS]
- truncated = True
-
- widgets.update(
- models_selected=self.add_widget_intelligent(
- MultiSelectColumns,
- columns=columns,
- name=f"Install {model_type} Models",
- values=model_labels,
- value=[
- model_list.index(x)
- for x in model_list
- if (show_recommended and all_models[x].recommended) or all_models[x].installed
- ],
- max_height=len(model_list) // columns + 1,
- relx=4,
- scroll_exit=True,
- ),
- models=model_list,
- )
-
- if truncated:
- widgets.update(
- warning_message=self.add_widget_intelligent(
- npyscreen.FixedText,
- value=f"Too many models to display (max={MAX_OTHER_MODELS}). Some are not displayed.",
- editable=False,
- color="CAUTION",
- )
- )
-
- self.nextrely += 1
- widgets.update(
- download_ids=self.add_widget_intelligent(
- TextBox,
- name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):",
- max_height=6,
- scroll_exit=True,
- editable=True,
- )
- )
- return widgets
-
- ### Tab for arbitrary diffusers widgets ###
- def add_pipeline_widgets(
- self,
- model_type: ModelType = ModelType.Main,
- window_width: int = 120,
- **kwargs,
- ) -> dict[str, npyscreen.widget]:
- """Similar to add_model_widgets() but adds some additional widgets at the bottom
- to support the autoload directory"""
- widgets = self.add_model_widgets(
- model_type=model_type,
- window_width=window_width,
- install_prompt=f"Installed {model_type.value.title()} models. Unchecked models in the InvokeAI root directory will be deleted. Enter URLs, paths or repo_ids to import.",
- **kwargs,
- )
-
- return widgets
-
- def resize(self) -> None:
- super().resize()
- if s := self.starter_pipelines.get("models_selected"):
- if model_list := self.starter_pipelines.get("models"):
- s.values = [self.model_labels[x] for x in model_list]
-
- def _toggle_tables(self, value: List[int]) -> None:
- selected_tab = value[0]
- widgets = [
- self.starter_pipelines,
- self.pipeline_models,
- self.controlnet_models,
- self.t2i_models,
- self.ipadapter_models,
- self.lora_models,
- self.ti_models,
- ]
-
- for group in widgets:
- for _k, v in group.items():
- try:
- v.hidden = True
- v.editable = False
- except Exception:
- pass
- for _k, v in widgets[selected_tab].items():
- try:
- v.hidden = False
- if not isinstance(v, (npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)):
- v.editable = True
- except Exception:
- pass
- self.__class__.current_tab = selected_tab # for persistence
- self.display()
-
- def _get_model_labels(self) -> dict[str, str]:
- """Return a list of trimmed labels for all models."""
- window_width, window_height = get_terminal_size()
- checkbox_width = 4
- spacing_width = 2
- result = {}
-
- models = self.all_models
- label_width = max([len(models[x].name or "") for x in self.starter_models])
- description_width = window_width - label_width - checkbox_width - spacing_width
-
- for key in self.all_models:
- description = models[key].description
- description = (
- description[0 : description_width - 3] + "..."
- if description and len(description) > description_width
- else description
- if description
- else ""
- )
- result[key] = f"%-{label_width}s %s" % (models[key].name, description)
-
- return result
-
- def _get_columns(self) -> int:
- window_width, window_height = get_terminal_size()
- cols = 4 if window_width > 240 else 3 if window_width > 160 else 2 if window_width > 80 else 1
- return min(cols, len(self.installed_models))
-
- def confirm_deletions(self, selections: InstallSelections) -> bool:
- remove_models = selections.remove_models
- if remove_models:
- model_names = [self.all_models[x].name or "" for x in remove_models]
- mods = "\n".join(model_names)
- is_ok = npyscreen.notify_ok_cancel(
- f"These unchecked models will be deleted from disk. Continue?\n---------\n{mods}"
- )
- assert isinstance(is_ok, bool) # npyscreen doesn't have return type annotations
- return is_ok
- else:
- return True
-
- @property
- def all_models(self) -> Dict[str, UnifiedModelInfo]:
- # npyscreen doesn't have typing hints
- return self.parentApp.install_helper.all_models # type: ignore
-
- @property
- def starter_models(self) -> List[str]:
- return self.parentApp.install_helper._starter_models # type: ignore
-
- @property
- def installed_models(self) -> List[str]:
- return self.parentApp.install_helper._installed_models # type: ignore
-
- def on_back(self) -> None:
- self.parentApp.switchFormPrevious()
- self.editing = False
-
- def on_cancel(self) -> None:
- self.parentApp.setNextForm(None)
- self.parentApp.user_cancelled = True
- self.editing = False
-
- def on_done(self) -> None:
- self.marshall_arguments()
- if not self.confirm_deletions(self.parentApp.install_selections):
- return
- self.parentApp.setNextForm(None)
- self.parentApp.user_cancelled = False
- self.editing = False
-
- def marshall_arguments(self) -> None:
- """
- Assemble arguments and store as attributes of the application:
- .starter_models: dict of model names to install from INITIAL_MODELS.yaml
- True => Install
- False => Remove
- .scan_directory: Path to a directory of models to scan and import
- .autoscan_on_startup: True if invokeai should scan and import at startup time
- .import_model_paths: list of URLs, repo_ids and file paths to import
- """
- selections = self.parentApp.install_selections
- all_models = self.all_models
-
- # Defined models (in INITIAL_MODELS.yaml or invokeai.db) to add/remove
- ui_sections = [
- self.starter_pipelines,
- self.pipeline_models,
- self.controlnet_models,
- self.t2i_models,
- self.ipadapter_models,
- self.lora_models,
- self.ti_models,
- ]
- for section in ui_sections:
- if "models_selected" not in section:
- continue
- selected = {section["models"][x] for x in section["models_selected"].value}
- models_to_install = [x for x in selected if not self.all_models[x].installed]
- models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed]
- selections.remove_models.extend(models_to_remove)
- selections.install_models.extend([all_models[x] for x in models_to_install])
-
- # models located in the "download_ids" section
- for section in ui_sections:
- if downloads := section.get("download_ids"):
- models = [UnifiedModelInfo(source=x) for x in downloads.value.split()]
- selections.install_models.extend(models)
-
-
-class AddModelApplication(npyscreen.NPSAppManaged): # type: ignore
- def __init__(self, opt: Namespace, install_helper: InstallHelper):
- super().__init__()
- self.program_opts = opt
- self.user_cancelled = False
- self.install_selections = InstallSelections()
- self.install_helper = install_helper
-
- def onStart(self) -> None:
- npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
- self.main_form = self.addForm(
- "MAIN",
- addModelsForm,
- name="Install Stable Diffusion Models",
- cycle_widgets=False,
- )
-
-
-def list_models(installer: ModelInstallServiceBase, model_type: ModelType):
- """Print out all models of type model_type."""
- models = installer.record_store.search_by_attr(model_type=model_type)
- print(f"Installed models of type `{model_type}`:")
- for model in models:
- path = (config.models_path / model.path).resolve()
- print(f"{model.name:40}{model.base.value:5}{model.type.value:8}{model.format.value:12}{path}")
-
-
-# --------------------------------------------------------
-def select_and_download_models(opt: Namespace) -> None:
- """Prompt user for install/delete selections and execute."""
- precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
- config.precision = precision
- install_helper = InstallHelper(config, logger)
- installer = install_helper.installer
-
- if opt.list_models:
- list_models(installer, opt.list_models)
-
- elif opt.add or opt.delete:
- selections = InstallSelections(
- install_models=[UnifiedModelInfo(source=x) for x in (opt.add or [])], remove_models=opt.delete or []
- )
- install_helper.add_or_delete(selections)
-
- elif opt.default_only:
- default_model = install_helper.default_model()
- assert default_model is not None
- selections = InstallSelections(install_models=[default_model])
- install_helper.add_or_delete(selections)
-
- elif opt.yes_to_all:
- selections = InstallSelections(install_models=install_helper.recommended_models())
- install_helper.add_or_delete(selections)
-
- # this is where the TUI is called
- else:
- if not set_min_terminal_size(MIN_COLS, MIN_LINES):
- raise WindowTooSmallException(
- "Could not increase terminal size. Try running again with a larger window or smaller font size."
- )
-
- installApp = AddModelApplication(opt, install_helper)
- try:
- installApp.run()
- except KeyboardInterrupt:
- print("Aborted...")
- sys.exit(-1)
-
- install_helper.add_or_delete(installApp.install_selections)
-
-
-# -------------------------------------
-def main() -> None:
- parser = argparse.ArgumentParser(description="InvokeAI model downloader")
- parser.add_argument(
- "--add",
- nargs="*",
- help="List of URLs, local paths or repo_ids of models to install",
- )
- parser.add_argument(
- "--delete",
- nargs="*",
- help="List of names of models to delete. Use type:name to disambiguate, as in `controlnet:my_model`",
- )
- parser.add_argument(
- "--full-precision",
- dest="full_precision",
- action=argparse.BooleanOptionalAction,
- type=bool,
- default=False,
- help="use 32-bit weights instead of faster 16-bit weights",
- )
- parser.add_argument(
- "--yes",
- "-y",
- dest="yes_to_all",
- action="store_true",
- help='answer "yes" to all prompts',
- )
- parser.add_argument(
- "--default_only",
- action="store_true",
- help="Only install the default model",
- )
- parser.add_argument(
- "--list-models",
- choices=[x.value for x in ModelType],
- help="list installed models",
- )
- parser.add_argument(
- "--root_dir",
- dest="root",
- type=pathlib.Path,
- default=None,
- help="path to root of install directory",
- )
- opt = parser.parse_args()
-
- invoke_args: dict[str, Any] = {}
- if opt.full_precision:
- invoke_args["precision"] = "float32"
- config.update_config(invoke_args)
- if opt.root:
- config.set_root(opt.root)
-
- logger = InvokeAILogger().get_logger(config=config)
-
- try:
- validate_directories(config)
- except AssertionError:
- logger.info("Your InvokeAI root directory is not set up. Calling invokeai-configure.")
- sys.argv = ["invokeai_configure", "--yes", "--skip-sd-weights"]
- from invokeai.frontend.install.invokeai_configure import invokeai_configure
-
- invokeai_configure()
- sys.exit(0)
-
- try:
- select_and_download_models(opt)
- except AssertionError as e:
- logger.error(e)
- sys.exit(-1)
- except KeyboardInterrupt:
- curses.nocbreak()
- curses.echo()
- curses.endwin()
- logger.info("Goodbye! Come back soon.")
- except WindowTooSmallException as e:
- logger.error(str(e))
- except widget.NotEnoughSpaceForWidget as e:
- if str(e).startswith("Height of 1 allocated"):
- logger.error("Insufficient vertical space for the interface. Please make your window taller and try again")
- input("Press any key to continue...")
- except Exception as e:
- if str(e).startswith("addwstr"):
- logger.error(
- "Insufficient horizontal space for the interface. Please make your window wider and try again."
- )
- else:
- print(f"An exception has occurred: {str(e)} Details:")
- print(traceback.format_exc(), file=sys.stderr)
- input("Press any key to continue...")
-
-
-# -------------------------------------
-if __name__ == "__main__":
- main()
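A hedged sketch of driving the removed installer through its argparse interface; the flags mirror those defined in main() above, and the console-script name used for argv[0] is illustrative.

import sys

sys.argv = [
    "invokeai-model-install",                   # illustrative program name
    "--add", "runwayml/stable-diffusion-v1-5",  # repo_id, URL or local path
    "--delete", "controlnet:my_model",          # type:name disambiguation, as in the help text
    "--yes",
]
from invokeai.frontend.install.model_install import main  # module removed by this diff

main()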
diff --git a/invokeai/frontend/install/widgets.py b/invokeai/frontend/install/widgets.py
deleted file mode 100644
index 49ca1e3583..0000000000
--- a/invokeai/frontend/install/widgets.py
+++ /dev/null
@@ -1,441 +0,0 @@
-"""
-Widget class definitions used by model_select.py, merge_diffusers.py and textual_inversion.py
-"""
-
-import curses
-import math
-import os
-import platform
-import struct
-import subprocess
-import sys
-import textwrap
-from curses import BUTTON2_CLICKED, BUTTON3_CLICKED
-from shutil import get_terminal_size
-from typing import Optional
-
-import npyscreen
-import npyscreen.wgmultiline as wgmultiline
-import pyperclip
-from npyscreen import fmPopup
-
-# minimum size for UIs
-MIN_COLS = 150
-MIN_LINES = 40
-
-
-class WindowTooSmallException(Exception):
- pass
-
-
-# -------------------------------------
-def set_terminal_size(columns: int, lines: int) -> bool:
- OS = platform.uname().system
- screen_ok = False
- while not screen_ok:
- ts = get_terminal_size()
- width = max(columns, ts.columns)
- height = max(lines, ts.lines)
-
- if OS == "Windows":
- pass
- # not working reliably - ask user to adjust the window
- # _set_terminal_size_powershell(width,height)
- elif OS in ["Darwin", "Linux"]:
- _set_terminal_size_unix(width, height)
-
- # check whether it worked....
- ts = get_terminal_size()
- if ts.columns < columns or ts.lines < lines:
- print(
- f"\033[1mThis window is too small for the interface. InvokeAI requires {columns}x{lines} (w x h) characters, but window is {ts.columns}x{ts.lines}\033[0m"
- )
- resp = input(
- "Maximize the window and/or decrease the font size then press any key to continue. Type [Q] to give up.."
- )
- if resp.upper().startswith("Q"):
- break
- else:
- screen_ok = True
- return screen_ok
-
-
-def _set_terminal_size_powershell(width: int, height: int):
- script = f"""
-$pshost = get-host
-$pswindow = $pshost.ui.rawui
-$newsize = $pswindow.buffersize
-$newsize.height = 3000
-$newsize.width = {width}
-$pswindow.buffersize = $newsize
-$newsize = $pswindow.windowsize
-$newsize.height = {height}
-$newsize.width = {width}
-$pswindow.windowsize = $newsize
-"""
- subprocess.run(["powershell", "-Command", "-"], input=script, text=True)
-
-
-def _set_terminal_size_unix(width: int, height: int):
- import fcntl
- import termios
-
- # These terminals accept the size command and report that the
- # size changed, but they lie!!!
- for bad_terminal in ["TERMINATOR_UUID", "ALACRITTY_WINDOW_ID"]:
- if os.environ.get(bad_terminal):
- return
-
- winsize = struct.pack("HHHH", height, width, 0, 0)
- fcntl.ioctl(sys.stdout.fileno(), termios.TIOCSWINSZ, winsize)
- sys.stdout.write("\x1b[8;{height};{width}t".format(height=height, width=width))
- sys.stdout.flush()
-
-
-def set_min_terminal_size(min_cols: int, min_lines: int) -> bool:
- # make sure there's enough room for the ui
- term_cols, term_lines = get_terminal_size()
- if term_cols >= min_cols and term_lines >= min_lines:
- return True
- cols = max(term_cols, min_cols)
- lines = max(term_lines, min_lines)
- return set_terminal_size(cols, lines)
-
-
-class IntSlider(npyscreen.Slider):
- def translate_value(self):
- stri = "%2d / %2d" % (self.value, self.out_of)
- length = (len(str(self.out_of))) * 2 + 4
- stri = stri.rjust(length)
- return stri
-
-
-# -------------------------------------
-# fix npyscreen form so that cursor wraps both forward and backward
-class CyclingForm(object):
- def find_previous_editable(self, *args):
- done = False
- n = self.editw - 1
- while not done:
- if self._widgets__[n].editable and not self._widgets__[n].hidden:
- self.editw = n
- done = True
- n -= 1
- if n < 0:
- if self.cycle_widgets:
- n = len(self._widgets__) - 1
- else:
- done = True
-
-
-# -------------------------------------
-class CenteredTitleText(npyscreen.TitleText):
- def __init__(self, *args, **keywords):
- super().__init__(*args, **keywords)
- self.resize()
-
- def resize(self):
- super().resize()
- maxy, maxx = self.parent.curses_pad.getmaxyx()
- label = self.name
- self.relx = (maxx - len(label)) // 2
-
-
-# -------------------------------------
-class CenteredButtonPress(npyscreen.ButtonPress):
- def resize(self):
- super().resize()
- maxy, maxx = self.parent.curses_pad.getmaxyx()
- label = self.name
- self.relx = (maxx - len(label)) // 2
-
-
-# -------------------------------------
-class OffsetButtonPress(npyscreen.ButtonPress):
- def __init__(self, screen, offset=0, *args, **keywords):
- super().__init__(screen, *args, **keywords)
- self.offset = offset
-
- def resize(self):
- maxy, maxx = self.parent.curses_pad.getmaxyx()
- width = len(self.name)
- self.relx = self.offset + (maxx - width) // 2
-
-
-class IntTitleSlider(npyscreen.TitleText):
- _entry_type = IntSlider
-
-
-class FloatSlider(npyscreen.Slider):
- # this is supposed to adjust display precision, but doesn't
- def translate_value(self):
- stri = "%3.2f / %3.2f" % (self.value, self.out_of)
- length = (len(str(self.out_of))) * 2 + 4
- stri = stri.rjust(length)
- return stri
-
-
-class FloatTitleSlider(npyscreen.TitleText):
- _entry_type = npyscreen.Slider
-
-
-class SelectColumnBase:
- """Base class for selection widget arranged in columns."""
-
- def make_contained_widgets(self):
- self._my_widgets = []
- column_width = self.width // self.columns
- for h in range(self.value_cnt):
- self._my_widgets.append(
- self._contained_widgets(
- self.parent,
- rely=self.rely + (h % self.rows) * self._contained_widget_height,
- relx=self.relx + (h // self.rows) * column_width,
- max_width=column_width,
- max_height=self.__class__._contained_widget_height,
- )
- )
-
- def set_up_handlers(self):
- super().set_up_handlers()
- self.handlers.update(
- {
- curses.KEY_UP: self.h_cursor_line_left,
- curses.KEY_DOWN: self.h_cursor_line_right,
- }
- )
-
- def h_cursor_line_down(self, ch):
- self.cursor_line += self.rows
- if self.cursor_line >= len(self.values):
- if self.scroll_exit:
- self.cursor_line = len(self.values) - self.rows
- self.h_exit_down(ch)
- return True
- else:
- self.cursor_line -= self.rows
- return True
-
- def h_cursor_line_up(self, ch):
- self.cursor_line -= self.rows
- if self.cursor_line < 0:
- if self.scroll_exit:
- self.cursor_line = 0
- self.h_exit_up(ch)
- else:
- self.cursor_line = 0
-
- def h_cursor_line_left(self, ch):
- super().h_cursor_line_up(ch)
-
- def h_cursor_line_right(self, ch):
- super().h_cursor_line_down(ch)
-
- def handle_mouse_event(self, mouse_event):
- mouse_id, rel_x, rel_y, z, bstate = self.interpret_mouse_event(mouse_event)
- column_width = self.width // self.columns
- column_height = math.ceil(self.value_cnt / self.columns)
- column_no = rel_x // column_width
- row_no = rel_y // self._contained_widget_height
- self.cursor_line = column_no * column_height + row_no
- if bstate & curses.BUTTON1_DOUBLE_CLICKED:
- if hasattr(self, "on_mouse_double_click"):
- self.on_mouse_double_click(self.cursor_line)
- self.display()
-
-
-class MultiSelectColumns(SelectColumnBase, npyscreen.MultiSelect):
- def __init__(self, screen, columns: int = 1, values: Optional[list] = None, **keywords):
- if values is None:
- values = []
- self.columns = columns
- self.value_cnt = len(values)
- self.rows = math.ceil(self.value_cnt / self.columns)
- super().__init__(screen, values=values, **keywords)
-
- def on_mouse_double_click(self, cursor_line):
- self.h_select_toggle(cursor_line)
-
-
-class SingleSelectWithChanged(npyscreen.SelectOne):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.on_changed = None
-
- def h_select(self, ch):
- super().h_select(ch)
- if self.on_changed:
- self.on_changed(self.value)
-
-
-class CheckboxWithChanged(npyscreen.Checkbox):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.on_changed = None
-
- def whenToggled(self):
- super().whenToggled()
- if self.on_changed:
- self.on_changed(self.value)
-
-
-class SingleSelectColumnsSimple(SelectColumnBase, SingleSelectWithChanged):
- """Row of radio buttons. Spacebar to select."""
-
- def __init__(self, screen, columns: int = 1, values: list = None, **keywords):
- if values is None:
- values = []
- self.columns = columns
- self.value_cnt = len(values)
- self.rows = math.ceil(self.value_cnt / self.columns)
- self.on_changed = None
- super().__init__(screen, values=values, **keywords)
-
- def h_cursor_line_right(self, ch):
- self.h_exit_down("bye bye")
-
- def h_cursor_line_left(self, ch):
- self.h_exit_up("bye bye")
-
-
-class SingleSelectColumns(SingleSelectColumnsSimple):
- """Row of radio buttons. When tabbing over a selection, it is auto selected."""
-
- def when_cursor_moved(self):
- self.h_select(self.cursor_line)
-
-
-class TextBoxInner(npyscreen.MultiLineEdit):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.yank = None
- self.handlers.update(
- {
- "^A": self.h_cursor_to_start,
- "^E": self.h_cursor_to_end,
- "^K": self.h_kill,
- "^F": self.h_cursor_right,
- "^B": self.h_cursor_left,
- "^Y": self.h_yank,
- "^V": self.h_paste,
- }
- )
-
- def h_cursor_to_start(self, input):
- self.cursor_position = 0
-
- def h_cursor_to_end(self, input):
- self.cursor_position = len(self.value)
-
- def h_kill(self, input):
- self.yank = self.value[self.cursor_position :]
- self.value = self.value[: self.cursor_position]
-
- def h_yank(self, input):
- if self.yank:
- self.paste(self.yank)
-
- def paste(self, text: str):
- self.value = self.value[: self.cursor_position] + text + self.value[self.cursor_position :]
- self.cursor_position += len(text)
-
- def h_paste(self, input: int = 0):
- try:
- text = pyperclip.paste()
- except ModuleNotFoundError:
- text = "To paste with the mouse on Linux, please install the 'xclip' program."
- self.paste(text)
-
- def handle_mouse_event(self, mouse_event):
- mouse_id, rel_x, rel_y, z, bstate = self.interpret_mouse_event(mouse_event)
- if bstate & (BUTTON2_CLICKED | BUTTON3_CLICKED):
- self.h_paste()
-
-
-class TextBox(npyscreen.BoxTitle):
- _contained_widget = TextBoxInner
-
-
-class BufferBox(npyscreen.BoxTitle):
- _contained_widget = npyscreen.BufferPager
-
-
-class ConfirmCancelPopup(fmPopup.ActionPopup):
- DEFAULT_COLUMNS = 100
-
- def on_ok(self):
- self.value = True
-
- def on_cancel(self):
- self.value = False
-
-
-class FileBox(npyscreen.BoxTitle):
- _contained_widget = npyscreen.Filename
-
-
-class PrettyTextBox(npyscreen.BoxTitle):
- _contained_widget = TextBox
-
-
-def _wrap_message_lines(message, line_length):
- lines = []
- for line in message.split("\n"):
- lines.extend(textwrap.wrap(line.rstrip(), line_length))
- return lines
-
-
-def _prepare_message(message):
- if isinstance(message, list) or isinstance(message, tuple):
- return "\n".join([s.rstrip() for s in message])
- # return "\n".join(message)
- else:
- return message
-
-
-def select_stable_diffusion_config_file(
- form_color: str = "DANGER",
- wrap: bool = True,
- model_name: str = "Unknown",
-):
- message = f"Please select the correct prediction type for the checkpoint named '{model_name}'. Press to skip installation."
- title = "CONFIG FILE SELECTION"
- options = [
- "'epsilon' - most v1.5 models and v2 models trained on 512 pixel images",
- "'vprediction' - v2 models trained on 768 pixel images and a few v1.5 models)",
- "Accept the best guess; you can fix it in the Web UI later",
- ]
-
- F = ConfirmCancelPopup(
- name=title,
- color=form_color,
- cycle_widgets=True,
- lines=16,
- )
- F.preserve_selected_widget = True
-
- mlw = F.add(
- wgmultiline.Pager,
- max_height=4,
- editable=False,
- )
- mlw_width = mlw.width - 1
- if wrap:
- message = _wrap_message_lines(message, mlw_width)
- mlw.values = message
-
- choice = F.add(
- npyscreen.SelectOne,
- values=options,
- value=[2],
- max_height=len(options) + 1,
- scroll_exit=True,
- )
-
- F.editw = 1
- F.edit()
- if not F.value:
- return None
- assert choice.value[0] in range(0, 3), "invalid choice"
- choices = ["epsilon", "v", "guess"]
- return choices[choice.value[0]]
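
For reference, a minimal sketch of how the dialog above could be invoked during checkpoint import (the placeholder model name is illustrative; the surrounding installer code is not shown here):

# Sketch only: invoke the removed dialog defined above.
# "my-checkpoint.safetensors" is a placeholder model name.
prediction_type = select_stable_diffusion_config_file(model_name="my-checkpoint.safetensors")
if prediction_type is None:
    print("User cancelled; skipping installation")
else:
    print(f"Selected prediction type: {prediction_type}")  # "epsilon", "v", or "guess"
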
diff --git a/invokeai/frontend/legacy_launch_invokeai.py b/invokeai/frontend/legacy_launch_invokeai.py
deleted file mode 100644
index 9e4cca7eac..0000000000
--- a/invokeai/frontend/legacy_launch_invokeai.py
+++ /dev/null
@@ -1,22 +0,0 @@
-import argparse
-import sys
-
-
-def main():
- parser = argparse.ArgumentParser()
- parser.add_argument("--web", action="store_true")
- opts, _ = parser.parse_known_args()
-
- if opts.web:
- sys.argv.pop(sys.argv.index("--web"))
- from invokeai.app.api_app import invoke_api
-
- invoke_api()
- else:
- from invokeai.app.cli_app import invoke_cli
-
- invoke_cli()
-
-
-if __name__ == "__main__":
- main()
diff --git a/invokeai/frontend/merge/__init__.py b/invokeai/frontend/merge/__init__.py
deleted file mode 100644
index 4e56b146f4..0000000000
--- a/invokeai/frontend/merge/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-"""
-Initialization file for invokeai.frontend.merge
-"""
-
-from .merge_diffusers import main as invokeai_merge_diffusers # noqa: F401
diff --git a/invokeai/frontend/merge/merge_diffusers.py b/invokeai/frontend/merge/merge_diffusers.py
deleted file mode 100644
index 58c872df85..0000000000
--- a/invokeai/frontend/merge/merge_diffusers.py
+++ /dev/null
@@ -1,448 +0,0 @@
-"""
-invokeai.frontend.merge exports a single function, merge_diffusion_models(),
-used to merge 2-3 models together and create a new InvokeAI-registered diffusion model.
-
-Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
-"""
-
-import argparse
-import curses
-import re
-import sys
-from argparse import Namespace
-from pathlib import Path
-from typing import List, Optional, Tuple
-
-import npyscreen
-from npyscreen import widget
-
-from invokeai.app.services.config.config_default import get_config
-from invokeai.app.services.download import DownloadQueueService
-from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
-from invokeai.app.services.model_install import ModelInstallService
-from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
-from invokeai.app.services.shared.sqlite.sqlite_util import init_db
-from invokeai.backend.model_manager import (
- BaseModelType,
- ModelFormat,
- ModelType,
- ModelVariantType,
-)
-from invokeai.backend.model_manager.merge import ModelMerger
-from invokeai.backend.util.logging import InvokeAILogger
-from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox
-
-config = get_config()
-logger = InvokeAILogger.get_logger()
-
-BASE_TYPES = [
- (BaseModelType.StableDiffusion1, "Models Built on SD-1.x"),
- (BaseModelType.StableDiffusion2, "Models Built on SD-2.x"),
- (BaseModelType.StableDiffusionXL, "Models Built on SDXL"),
-]
-
-
-def _parse_args() -> Namespace:
- parser = argparse.ArgumentParser(description="InvokeAI model merging")
- parser.add_argument(
- "--root_dir",
- type=Path,
- default=config.root_path,
- help="Path to the invokeai runtime directory",
- )
- parser.add_argument(
- "--front_end",
- "--gui",
- dest="front_end",
- action="store_true",
- default=False,
- help="Activate the text-based graphical front end for collecting parameters. Aside from --root_dir, other parameters will be ignored.",
- )
- parser.add_argument(
- "--models",
- dest="model_names",
- type=str,
- nargs="+",
- help="Two to three model names to be merged",
- )
- parser.add_argument(
- "--base_model",
- type=str,
- choices=[x[0].value for x in BASE_TYPES],
- help="The base model shared by the models to be merged",
- )
- parser.add_argument(
- "--merged_model_name",
- "--destination",
- dest="merged_model_name",
- type=str,
- help="Name of the output model. If not specified, will be the concatenation of the input model names.",
- )
- parser.add_argument(
- "--alpha",
- type=float,
- default=0.5,
- help="The interpolation parameter, ranging from 0 to 1. It affects the ratio in which the checkpoints are merged. Higher values give more weight to the 2d and 3d models",
- )
- parser.add_argument(
- "--interpolation",
- dest="interp",
- type=str,
- choices=["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"],
- default="weighted_sum",
- help='Interpolation method to use. If three models are present, only "add_difference" will work.',
- )
- parser.add_argument(
- "--force",
- action="store_true",
- help="Try to merge models even if they are incompatible with each other",
- )
- parser.add_argument(
- "--clobber",
- "--overwrite",
- dest="clobber",
- action="store_true",
- help="Overwrite the merged model if --merged_model_name already exists",
- )
- return parser.parse_args()
-
-
-# ------------------------- GUI HERE -------------------------
-class mergeModelsForm(npyscreen.FormMultiPageAction):
- interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid"]
-
- def __init__(self, parentApp, name):
- self.parentApp = parentApp
- self.ALLOW_RESIZE = True
- self.FIX_MINIMUM_SIZE_WHEN_CREATED = False
- super().__init__(parentApp, name)
-
- @property
- def record_store(self):
- return self.parentApp.record_store
-
- def afterEditing(self):
- self.parentApp.setNextForm(None)
-
- def create(self):
- window_height, window_width = curses.initscr().getmaxyx()
- self.current_base = 0
- self.models = self.get_models(BASE_TYPES[self.current_base][0])
- self.model_names = [x[1] for x in self.models]
- max_width = max([len(x) for x in self.model_names])
- max_width += 6
- horizontal_layout = max_width * 3 < window_width
-
- self.add_widget_intelligent(
- npyscreen.FixedText,
- color="CONTROL",
- value="Select two models to merge and optionally a third.",
- editable=False,
- )
- self.add_widget_intelligent(
- npyscreen.FixedText,
- color="CONTROL",
- value="Use up and down arrows to move, to select an item, and to move from one field to the next.",
- editable=False,
- )
- self.nextrely += 1
- self.base_select = self.add_widget_intelligent(
- SingleSelectColumns,
- values=[x[1] for x in BASE_TYPES],
- value=[self.current_base],
- columns=4,
- max_height=2,
- relx=8,
- scroll_exit=True,
- )
- self.base_select.on_changed = self._populate_models
- self.add_widget_intelligent(
- npyscreen.FixedText,
- value="MODEL 1",
- color="GOOD",
- editable=False,
- rely=6 if horizontal_layout else None,
- )
- self.model1 = self.add_widget_intelligent(
- npyscreen.SelectOne,
- values=self.model_names,
- value=0,
- max_height=len(self.model_names),
- max_width=max_width,
- scroll_exit=True,
- rely=7,
- )
- self.add_widget_intelligent(
- npyscreen.FixedText,
- value="MODEL 2",
- color="GOOD",
- editable=False,
- relx=max_width + 3 if horizontal_layout else None,
- rely=6 if horizontal_layout else None,
- )
- self.model2 = self.add_widget_intelligent(
- npyscreen.SelectOne,
- name="(2)",
- values=self.model_names,
- value=1,
- max_height=len(self.model_names),
- max_width=max_width,
- relx=max_width + 3 if horizontal_layout else None,
- rely=7 if horizontal_layout else None,
- scroll_exit=True,
- )
- self.add_widget_intelligent(
- npyscreen.FixedText,
- value="MODEL 3",
- color="GOOD",
- editable=False,
- relx=max_width * 2 + 3 if horizontal_layout else None,
- rely=6 if horizontal_layout else None,
- )
- models_plus_none = self.model_names.copy()
- models_plus_none.insert(0, "None")
- self.model3 = self.add_widget_intelligent(
- npyscreen.SelectOne,
- name="(3)",
- values=models_plus_none,
- value=0,
- max_height=len(self.model_names) + 1,
- max_width=max_width,
- scroll_exit=True,
- relx=max_width * 2 + 3 if horizontal_layout else None,
- rely=7 if horizontal_layout else None,
- )
- for m in [self.model1, self.model2, self.model3]:
- m.when_value_edited = self.models_changed
- self.merged_model_name = self.add_widget_intelligent(
- TextBox,
- name="Name for merged model:",
- labelColor="CONTROL",
- max_height=3,
- value="",
- scroll_exit=True,
- )
- self.force = self.add_widget_intelligent(
- npyscreen.Checkbox,
- name="Force merge of models created by different diffusers library versions",
- labelColor="CONTROL",
- value=True,
- scroll_exit=True,
- )
- self.nextrely += 1
- self.merge_method = self.add_widget_intelligent(
- npyscreen.TitleSelectOne,
- name="Merge Method:",
- values=self.interpolations,
- value=0,
- labelColor="CONTROL",
- max_height=len(self.interpolations) + 1,
- scroll_exit=True,
- )
- self.alpha = self.add_widget_intelligent(
- FloatTitleSlider,
- name="Weight (alpha) to assign to second and third models:",
- out_of=1.0,
- step=0.01,
- lowest=0,
- value=0.5,
- labelColor="CONTROL",
- scroll_exit=True,
- )
- self.model1.editing = True
-
- def models_changed(self):
- models = self.model1.values
- selected_model1 = self.model1.value[0]
- selected_model2 = self.model2.value[0]
- selected_model3 = self.model3.value[0]
- merged_model_name = f"{models[selected_model1]}+{models[selected_model2]}"
- self.merged_model_name.value = merged_model_name
-
- if selected_model3 > 0:
- self.merge_method.values = ["add_difference ( A+(B-C) )"]
- self.merged_model_name.value += f"+{models[selected_model3 -1]}" # In model3 there is one more element in the list (None). So we have to subtract one.
- else:
- self.merge_method.values = self.interpolations
- self.merge_method.value = 0
-
- def on_ok(self):
- if self.validate_field_values() and self.check_for_overwrite():
- self.parentApp.setNextForm(None)
- self.editing = False
- self.parentApp.merge_arguments = self.marshall_arguments()
- npyscreen.notify("Starting the merge...")
- else:
- self.editing = True
-
- def on_cancel(self):
- sys.exit(0)
-
- def marshall_arguments(self) -> dict:
- model_keys = [x[0] for x in self.models]
- models = [
- model_keys[self.model1.value[0]],
- model_keys[self.model2.value[0]],
- ]
- if self.model3.value[0] > 0:
- models.append(model_keys[self.model3.value[0] - 1])
- interp = "add_difference"
- else:
- interp = self.interpolations[self.merge_method.value[0]]
-
- args = {
- "model_keys": models,
- "base_model": tuple(BaseModelType)[self.base_select.value[0]],
- "alpha": self.alpha.value,
- "interp": interp,
- "force": self.force.value,
- "merged_model_name": self.merged_model_name.value,
- }
- return args
-
- def check_for_overwrite(self) -> bool:
- model_out = self.merged_model_name.value
- if model_out not in self.model_names:
- return True
- else:
- return npyscreen.notify_yes_no(
- f"The chosen merged model destination, {model_out}, is already in use. Overwrite?"
- )
-
- def validate_field_values(self) -> bool:
- bad_fields = []
- model_names = self.model_names
- selected_models = {model_names[self.model1.value[0]], model_names[self.model2.value[0]]}
- if self.model3.value[0] > 0:
- selected_models.add(model_names[self.model3.value[0] - 1])
- if len(selected_models) < 2:
- bad_fields.append(f"Please select two or three DIFFERENT models to compare. You selected {selected_models}")
- if len(bad_fields) > 0:
- message = "The following problems were detected and must be corrected:"
- for problem in bad_fields:
- message += f"\n* {problem}"
- npyscreen.notify_confirm(message)
- return False
- else:
- return True
-
- def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name
- models = [
- (x.key, x.name)
- for x in self.record_store.search_by_attr(model_type=ModelType.Main, base_model=base_model)
- if x.format == ModelFormat("diffusers") and x.variant == ModelVariantType("normal")
- ]
- return sorted(models, key=lambda x: x[1])
-
- def _populate_models(self, value: List[int]):
- base_model = BASE_TYPES[value[0]][0]
- self.models = self.get_models(base_model)
- self.model_names = [x[1] for x in self.models]
-
- models_plus_none = self.model_names.copy()
- models_plus_none.insert(0, "None")
- self.model1.values = self.model_names
- self.model2.values = self.model_names
- self.model3.values = models_plus_none
-
- self.display()
-
-
-class Mergeapp(npyscreen.NPSAppManaged):
- def __init__(self, record_store: ModelRecordServiceBase):
- super().__init__()
- self.record_store = record_store
-
- def onStart(self):
- npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
- self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
-
-
-def run_gui(args: Namespace) -> None:
- record_store: ModelRecordServiceBase = get_config_store()
- mergeapp = Mergeapp(record_store)
- mergeapp.run()
- args = mergeapp.merge_arguments
- merger = get_model_merger(record_store)
- merger.merge_diffusion_models_and_save(**args)
- merged_model_name = args["merged_model_name"]
- logger.info(f'Models merged into new model: "{merged_model_name}".')
-
-
-def run_cli(args: Namespace):
- assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
- assert (
- args.model_names and len(args.model_names) >= 2 and len(args.model_names) <= 3
- ), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."
-
- if not args.merged_model_name:
- args.merged_model_name = "+".join(args.model_names)
- logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"')
-
- record_store: ModelRecordServiceBase = get_config_store()
- assert (
- len(record_store.search_by_attr(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber
- ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
-
- merger = get_model_merger(record_store)
- model_keys = []
- for name in args.model_names:
- if len(name) == 32 and re.match(r"^[0-9a-f]{32}$", name):
- model_keys.append(name)
- else:
- models = record_store.search_by_attr(
- model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model)
- )
- assert len(models) > 0, f"{name}: Unknown model"
- assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead."
- model_keys.append(models[0].key)
-
- merger.merge_diffusion_models_and_save(
- alpha=args.alpha,
- model_keys=model_keys,
- merged_model_name=args.merged_model_name,
- interp=args.interp,
- force=args.force,
- )
- logger.info(f'Models merged into new model: "{args.merged_model_name}".')
-
-
-def get_config_store() -> ModelRecordServiceSQL:
- output_path = config.outputs_path
- assert output_path is not None
- image_files = DiskImageFileStorage(output_path / "images")
- db = init_db(config=config, logger=InvokeAILogger.get_logger(), image_files=image_files)
- return ModelRecordServiceSQL(db)
-
-
-def get_model_merger(record_store: ModelRecordServiceBase) -> ModelMerger:
- installer = ModelInstallService(app_config=config, record_store=record_store, download_queue=DownloadQueueService())
- installer.start()
- return ModelMerger(installer)
-
-
-def main():
- args = _parse_args()
- if args.root_dir:
- config.set_root(Path(args.root_dir))
-
- try:
- if args.front_end:
- run_gui(args)
- else:
- run_cli(args)
- except widget.NotEnoughSpaceForWidget as e:
- if str(e).startswith("Height of 1 allocated"):
- logger.error("You need to have at least two diffusers models in order to merge")
- else:
- logger.error("Not enough room for the user interface. Try making this window larger.")
- sys.exit(-1)
- except Exception as e:
- logger.error(e)
- sys.exit(-1)
- except KeyboardInterrupt:
- sys.exit(-1)
-
-
-if __name__ == "__main__":
- main()
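
As the module docstring above notes, the heavy lifting is done by ModelMerger.merge_diffusion_models_and_save(). A hedged sketch of driving the merge programmatically with the helpers defined in this removed file (the model keys are placeholders):

# Sketch only: assumes an initialized InvokeAI root and uses the helpers above.
record_store = get_config_store()
merger = get_model_merger(record_store)
merger.merge_diffusion_models_and_save(
    model_keys=["<key-of-model-1>", "<key-of-model-2>"],  # placeholder model keys
    base_model=BaseModelType.StableDiffusion1,
    alpha=0.5,                 # weight given to the 2nd (and 3rd) model
    interp="weighted_sum",     # or "sigmoid", "inv_sigmoid", "add_difference"
    force=False,
    merged_model_name="my-merged-model",
)
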
diff --git a/invokeai/frontend/training/__init__.py b/invokeai/frontend/training/__init__.py
deleted file mode 100644
index 7e002b4c03..0000000000
--- a/invokeai/frontend/training/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-"""
-Initialization file for invokeai.frontend.training
-"""
-
-from .textual_inversion import main as invokeai_textual_inversion # noqa: F401
diff --git a/invokeai/frontend/training/textual_inversion.py b/invokeai/frontend/training/textual_inversion.py
deleted file mode 100644
index 05f4a347ec..0000000000
--- a/invokeai/frontend/training/textual_inversion.py
+++ /dev/null
@@ -1,452 +0,0 @@
-#!/usr/bin/env python
-
-"""
-This is the frontend to "textual_inversion_training.py".
-
-Copyright (c) 2023-24 Lincoln Stein and the InvokeAI Development Team
-"""
-
-import os
-import re
-import shutil
-import sys
-import traceback
-from argparse import Namespace
-from pathlib import Path
-from typing import Dict, List, Optional, Tuple
-
-import npyscreen
-from npyscreen import widget
-from omegaconf import OmegaConf
-
-import invokeai.backend.util.logging as logger
-from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.install.install_helper import initialize_installer
-from invokeai.backend.model_manager import ModelType
-from invokeai.backend.training import do_textual_inversion_training, parse_args
-
-TRAINING_DATA = "text-inversion-training-data"
-TRAINING_DIR = "text-inversion-output"
-CONF_FILE = "preferences.conf"
-config = None
-
-
-class textualInversionForm(npyscreen.FormMultiPageAction):
- resolutions = [512, 768, 1024]
- lr_schedulers = [
- "linear",
- "cosine",
- "cosine_with_restarts",
- "polynomial",
- "constant",
- "constant_with_warmup",
- ]
- precisions = ["no", "fp16", "bf16"]
- learnable_properties = ["object", "style"]
-
- def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, saved_args: Optional[Dict[str, str]] = None):
- self.saved_args = saved_args or {}
- super().__init__(parentApp, name)
-
- def afterEditing(self) -> None:
- self.parentApp.setNextForm(None)
-
- def create(self) -> None:
- self.model_names, default = self.get_model_names()
- default_initializer_token = "★"
- default_placeholder_token = ""
- saved_args = self.saved_args
-
- assert config is not None
-
- try:
- default = self.model_names.index(saved_args["model"])
- except Exception:
- pass
-
- self.add_widget_intelligent(
- npyscreen.FixedText,
- value="Use ctrl-N and ctrl-P to move to the ext and revious fields, cursor arrows to make a selection, and space to toggle checkboxes.",
- editable=False,
- )
-
- self.model = self.add_widget_intelligent(
- npyscreen.TitleSelectOne,
- name="Model Name:",
- values=sorted(self.model_names),
- value=default,
- max_height=len(self.model_names) + 1,
- scroll_exit=True,
- )
- self.placeholder_token = self.add_widget_intelligent(
- npyscreen.TitleText,
- name="Trigger Term:",
- value="", # saved_args.get('placeholder_token',''), # to restore previous term
- scroll_exit=True,
- )
- self.placeholder_token.when_value_edited = self.initializer_changed
- self.nextrely -= 1
- self.nextrelx += 30
- self.prompt_token = self.add_widget_intelligent(
- npyscreen.FixedText,
- name="Trigger term for use in prompt",
- value="",
- editable=False,
- scroll_exit=True,
- )
- self.nextrelx -= 30
- self.initializer_token = self.add_widget_intelligent(
- npyscreen.TitleText,
- name="Initializer:",
- value=saved_args.get("initializer_token", default_initializer_token),
- scroll_exit=True,
- )
- self.resume_from_checkpoint = self.add_widget_intelligent(
- npyscreen.Checkbox,
- name="Resume from last saved checkpoint",
- value=False,
- scroll_exit=True,
- )
- self.learnable_property = self.add_widget_intelligent(
- npyscreen.TitleSelectOne,
- name="Learnable property:",
- values=self.learnable_properties,
- value=self.learnable_properties.index(saved_args.get("learnable_property", "object")),
- max_height=4,
- scroll_exit=True,
- )
- self.train_data_dir = self.add_widget_intelligent(
- npyscreen.TitleFilename,
- name="Data Training Directory:",
- select_dir=True,
- must_exist=False,
- value=str(
- saved_args.get(
- "train_data_dir",
- config.root_path / TRAINING_DATA / default_placeholder_token,
- )
- ),
- scroll_exit=True,
- )
- self.output_dir = self.add_widget_intelligent(
- npyscreen.TitleFilename,
- name="Output Destination Directory:",
- select_dir=True,
- must_exist=False,
- value=str(
- saved_args.get(
- "output_dir",
- config.root_path / TRAINING_DIR / default_placeholder_token,
- )
- ),
- scroll_exit=True,
- )
- self.resolution = self.add_widget_intelligent(
- npyscreen.TitleSelectOne,
- name="Image resolution (pixels):",
- values=self.resolutions,
- value=self.resolutions.index(saved_args.get("resolution", 512)),
- max_height=4,
- scroll_exit=True,
- )
- self.center_crop = self.add_widget_intelligent(
- npyscreen.Checkbox,
- name="Center crop images before resizing to resolution",
- value=saved_args.get("center_crop", False),
- scroll_exit=True,
- )
- self.mixed_precision = self.add_widget_intelligent(
- npyscreen.TitleSelectOne,
- name="Mixed Precision:",
- values=self.precisions,
- value=self.precisions.index(saved_args.get("mixed_precision", "fp16")),
- max_height=4,
- scroll_exit=True,
- )
- self.num_train_epochs = self.add_widget_intelligent(
- npyscreen.TitleSlider,
- name="Number of training epochs:",
- out_of=1000,
- step=50,
- lowest=1,
- value=saved_args.get("num_train_epochs", 100),
- scroll_exit=True,
- )
- self.max_train_steps = self.add_widget_intelligent(
- npyscreen.TitleSlider,
- name="Max Training Steps:",
- out_of=10000,
- step=500,
- lowest=1,
- value=saved_args.get("max_train_steps", 3000),
- scroll_exit=True,
- )
- self.train_batch_size = self.add_widget_intelligent(
- npyscreen.TitleSlider,
- name="Batch Size (reduce if you run out of memory):",
- out_of=50,
- step=1,
- lowest=1,
- value=saved_args.get("train_batch_size", 8),
- scroll_exit=True,
- )
- self.gradient_accumulation_steps = self.add_widget_intelligent(
- npyscreen.TitleSlider,
- name="Gradient Accumulation Steps (may need to decrease this to resume from a checkpoint):",
- out_of=10,
- step=1,
- lowest=1,
- value=saved_args.get("gradient_accumulation_steps", 4),
- scroll_exit=True,
- )
- self.lr_warmup_steps = self.add_widget_intelligent(
- npyscreen.TitleSlider,
- name="Warmup Steps:",
- out_of=100,
- step=1,
- lowest=0,
- value=saved_args.get("lr_warmup_steps", 0),
- scroll_exit=True,
- )
- self.learning_rate = self.add_widget_intelligent(
- npyscreen.TitleText,
- name="Learning Rate:",
- value=str(
- saved_args.get("learning_rate", "5.0e-04"),
- ),
- scroll_exit=True,
- )
- self.scale_lr = self.add_widget_intelligent(
- npyscreen.Checkbox,
- name="Scale learning rate by number GPUs, steps and batch size",
- value=saved_args.get("scale_lr", True),
- scroll_exit=True,
- )
- self.enable_xformers_memory_efficient_attention = self.add_widget_intelligent(
- npyscreen.Checkbox,
- name="Use xformers acceleration",
- value=saved_args.get("enable_xformers_memory_efficient_attention", False),
- scroll_exit=True,
- )
- self.lr_scheduler = self.add_widget_intelligent(
- npyscreen.TitleSelectOne,
- name="Learning rate scheduler:",
- values=self.lr_schedulers,
- max_height=7,
- value=self.lr_schedulers.index(saved_args.get("lr_scheduler", "constant")),
- scroll_exit=True,
- )
- self.model.editing = True
-
- def initializer_changed(self) -> None:
- placeholder = self.placeholder_token.value
- self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)"
- self.train_data_dir.value = str(config.root_path / TRAINING_DATA / placeholder)
- self.output_dir.value = str(config.root_path / TRAINING_DIR / placeholder)
- self.resume_from_checkpoint.value = Path(self.output_dir.value).exists()
-
- def on_ok(self):
- if self.validate_field_values():
- self.parentApp.setNextForm(None)
- self.editing = False
- self.parentApp.ti_arguments = self.marshall_arguments()
- npyscreen.notify("Launching textual inversion training. This will take a while...")
- else:
- self.editing = True
-
- def ok_cancel(self):
- sys.exit(0)
-
- def validate_field_values(self) -> bool:
- bad_fields = []
- if self.model.value is None:
- bad_fields.append("Model Name must correspond to a known model in invokeai.db")
- if not re.match("^[a-zA-Z0-9.-]+$", self.placeholder_token.value):
- bad_fields.append("Trigger term must only contain alphanumeric characters, the dot and hyphen")
- if self.train_data_dir.value is None:
- bad_fields.append("Data Training Directory cannot be empty")
- if self.output_dir.value is None:
- bad_fields.append("The Output Destination Directory cannot be empty")
- if len(bad_fields) > 0:
- message = "The following problems were detected and must be corrected:"
- for problem in bad_fields:
- message += f"\n* {problem}"
- npyscreen.notify_confirm(message)
- return False
- else:
- return True
-
- def get_model_names(self) -> Tuple[List[str], int]:
- global config
- assert config is not None
- installer = initialize_installer(config)
- store = installer.record_store
- main_models = store.search_by_attr(model_type=ModelType.Main)
- model_names = [f"{x.base.value}/{x.type.value}/{x.name}" for x in main_models if x.format == "diffusers"]
- default = 0
- return (model_names, default)
-
- def marshall_arguments(self) -> dict:
- args = {}
-
- # the choices
- args.update(
- model=self.model_names[self.model.value[0]],
- resolution=self.resolutions[self.resolution.value[0]],
- lr_scheduler=self.lr_schedulers[self.lr_scheduler.value[0]],
- mixed_precision=self.precisions[self.mixed_precision.value[0]],
- learnable_property=self.learnable_properties[self.learnable_property.value[0]],
- )
-
- # all the strings and booleans
- for attr in (
- "initializer_token",
- "placeholder_token",
- "train_data_dir",
- "output_dir",
- "scale_lr",
- "center_crop",
- "enable_xformers_memory_efficient_attention",
- ):
- args[attr] = getattr(self, attr).value
-
- # all the integers
- for attr in (
- "train_batch_size",
- "gradient_accumulation_steps",
- "num_train_epochs",
- "max_train_steps",
- "lr_warmup_steps",
- ):
- args[attr] = int(getattr(self, attr).value)
-
- # the floats (just one)
- args.update(learning_rate=float(self.learning_rate.value))
-
- # a special case
- if self.resume_from_checkpoint.value and Path(self.output_dir.value).exists():
- args["resume_from_checkpoint"] = "latest"
-
- return args
-
-
-class MyApplication(npyscreen.NPSAppManaged):
- def __init__(self, saved_args: Optional[Dict[str, str]] = None):
- super().__init__()
- self.ti_arguments = None
- self.saved_args = saved_args
-
- def onStart(self):
- npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
- self.main = self.addForm(
- "MAIN",
- textualInversionForm,
- name="Textual Inversion Settings",
- saved_args=self.saved_args,
- )
-
-
-def copy_to_embeddings_folder(args: Dict[str, str]) -> None:
- """
- Copy learned_embeds.bin into the embeddings folder, and offer to
- delete the full model and checkpoints.
- """
- assert config is not None
- source = Path(args["output_dir"], "learned_embeds.bin")
- dest_dir_name = args["placeholder_token"].strip("<>")
- destination = config.root_path / "embeddings" / dest_dir_name
- os.makedirs(destination, exist_ok=True)
- logger.info(f"Training completed. Copying learned_embeds.bin into {str(destination)}")
- shutil.copy(source, destination)
- if (input("Delete training logs and intermediate checkpoints? [y] ") or "y").startswith(("y", "Y")):
- shutil.rmtree(Path(args["output_dir"]))
- else:
- logger.info(f'Keeping {args["output_dir"]}')
-
-
-def save_args(args: dict) -> None:
- """
- Save the current argument values to an omegaconf file
- """
- assert config is not None
- dest_dir = config.root_path / TRAINING_DIR
- os.makedirs(dest_dir, exist_ok=True)
- conf_file = dest_dir / CONF_FILE
- conf = OmegaConf.create(args)
- OmegaConf.save(config=conf, f=conf_file)
-
-
-def previous_args() -> dict:
- """
- Get the previous arguments used.
- """
- assert config is not None
- conf_file = config.root_path / TRAINING_DIR / CONF_FILE
- try:
- conf = OmegaConf.load(conf_file)
- conf["placeholder_token"] = conf["placeholder_token"].strip("<>")
- except Exception:
- conf = None
-
- return conf
-
-
-def do_front_end() -> None:
- global config
- saved_args = previous_args()
- myapplication = MyApplication(saved_args=saved_args)
- myapplication.run()
-
- if my_args := myapplication.ti_arguments:
- os.makedirs(my_args["output_dir"], exist_ok=True)
-
- # Automatically add angle brackets around the trigger
- if not re.match("^<.+>$", my_args["placeholder_token"]):
- my_args["placeholder_token"] = f"<{my_args['placeholder_token']}>"
-
- my_args["only_save_embeds"] = True
- save_args(my_args)
-
- try:
- print(my_args)
- do_textual_inversion_training(config, **my_args)
- copy_to_embeddings_folder(my_args)
- except Exception as e:
- logger.error("An exception occurred during training. The exception was:")
- logger.error(str(e))
- logger.error("DETAILS:")
- logger.error(traceback.format_exc())
-
-
-def main() -> None:
- global config
-
- args: Namespace = parse_args()
- config = get_config()
-
- # change root if needed
- if args.root_dir:
- config.set_root(args.root_dir)
-
- try:
- if args.front_end:
- do_front_end()
- else:
- do_textual_inversion_training(config, **vars(args))
- except AssertionError as e:
- logger.error(e)
- sys.exit(-1)
- except KeyboardInterrupt:
- pass
- except (widget.NotEnoughSpaceForWidget, Exception) as e:
- if str(e).startswith("Height of 1 allocated"):
- logger.error("You need to have at least one diffusers models defined in invokeai.db in order to train")
- elif str(e).startswith("addwstr"):
- logger.error("Not enough window space for the interface. Please make your window larger and try again.")
- else:
- logger.error(e)
- sys.exit(-1)
-
-
-if __name__ == "__main__":
- main()
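
The removed form ultimately marshals its fields into keyword arguments for do_textual_inversion_training(). A hedged sketch of an equivalent direct call, using the same keys marshalled above (values are illustrative, the model identifier is hypothetical, and some backend-required arguments may be omitted):

# Sketch only: mirrors the keys produced by marshall_arguments() above.
config = get_config()
do_textual_inversion_training(
    config,
    model="sd-1/main/stable-diffusion-1.5",  # hypothetical model identifier
    placeholder_token="<my-token>",
    initializer_token="★",
    learnable_property="object",
    train_data_dir=str(config.root_path / TRAINING_DATA / "my-token"),
    output_dir=str(config.root_path / TRAINING_DIR / "my-token"),
    resolution=512,
    mixed_precision="fp16",
    train_batch_size=8,
    gradient_accumulation_steps=4,
    num_train_epochs=100,
    max_train_steps=3000,
    learning_rate=5.0e-04,
    scale_lr=True,
    lr_scheduler="constant",
    lr_warmup_steps=0,
    center_crop=False,
    enable_xformers_memory_efficient_attention=False,
    only_save_embeds=True,
)
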
diff --git a/pyproject.toml b/pyproject.toml
index 013a2e2b0d..be42437492 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -73,7 +73,7 @@ dependencies = [
"easing-functions",
"einops",
"facexlib",
- "matplotlib", # needed for plotting of Penner easing functions
+ "matplotlib", # needed for plotting of Penner easing functions
"npyscreen",
"omegaconf",
"picklescan",
@@ -127,27 +127,8 @@ dependencies = [
]
[project.scripts]
-
-# legacy entrypoints; provided for backwards compatibility
-"configure_invokeai.py" = "invokeai.frontend.install.invokeai_configure:run_configure"
-"textual_inversion.py" = "invokeai.frontend.training:invokeai_textual_inversion"
-
-# shortcut commands to start web ui
-# "invokeai --web" will launch the web interface
-# "invokeai" = "invokeai.frontend.legacy_launch_invokeai:main"
-
-# new shortcut to launch web interface
"invokeai-web" = "invokeai.app.run_app:run_app"
-
-# full commands
-"invokeai-configure" = "invokeai.frontend.install.invokeai_configure:run_configure"
-"invokeai-merge" = "invokeai.frontend.merge.merge_diffusers:main"
-"invokeai-ti" = "invokeai.frontend.training:invokeai_textual_inversion"
-"invokeai-model-install" = "invokeai.frontend.install.model_install:main"
-"invokeai-model-install2" = "invokeai.frontend.install.model_install2:main" # will eventually be renamed to invokeai-model-install
-"invokeai-migrate3" = "invokeai.backend.install.migrate_to_3:main"
"invokeai-update" = "invokeai.frontend.install.invokeai_update:main"
-"invokeai-metadata" = "invokeai.backend.image_util.invoke_metadata:main"
"invokeai-import-images" = "invokeai.frontend.install.import_images:main"
"invokeai-db-maintenance" = "invokeai.backend.util.db_maintenance:main"
@@ -191,7 +172,7 @@ version = { attr = "invokeai.version.__version__" }
addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers -m \"not slow\""
markers = [
"slow: Marks tests as slow. Disabled by default. To run all tests, use -m \"\". To run only slow tests, use -m \"slow\".",
- "timeout: Marks the timeout override."
+ "timeout: Marks the timeout override.",
]
[tool.coverage.run]
branch = true
diff --git a/scripts/configure_invokeai.py b/scripts/configure_invokeai.py
deleted file mode 100755
index a4125d8a58..0000000000
--- a/scripts/configure_invokeai.py
+++ /dev/null
@@ -1,12 +0,0 @@
-#!/usr/bin/env python
-# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
-
-import warnings
-
-from invokeai.frontend.install.invokeai_configure import run_configure as configure
-
-if __name__ == "__main__":
- warnings.warn(
- "configure_invokeai.py is deprecated, running 'invokeai-configure'...", DeprecationWarning, stacklevel=2
- )
- configure()
diff --git a/scripts/images2prompt.py b/scripts/images2prompt.py
deleted file mode 100755
index b12ff562fa..0000000000
--- a/scripts/images2prompt.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#!/usr/bin/env python
-"""This script reads the "Invoke" Stable Diffusion prompt embedded in files generated by invoke.py"""
-
-import sys
-
-from PIL import Image
-
-if len(sys.argv) < 2:
- print("Usage: file2prompt.py ...")
- print(
- "This script opens up the indicated invoke.py-generated PNG file(s) and prints out the prompt used to generate them."
- )
- exit(-1)
-
-filenames = sys.argv[1:]
-for f in filenames:
- try:
- im = Image.open(f)
- try:
- prompt = im.text["Dream"]
- except KeyError:
- prompt = ""
- print(f"{f}: {prompt}")
- except FileNotFoundError:
- sys.stderr.write(f"{f} not found\n")
- continue
- except PermissionError:
- sys.stderr.write(f"{f} could not be opened due to inadequate permissions\n")
- continue
diff --git a/scripts/invokeai-cli.py b/scripts/invokeai-cli.py
deleted file mode 100755
index 73c0f6e7c3..0000000000
--- a/scripts/invokeai-cli.py
+++ /dev/null
@@ -1,22 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
-
-import logging
-import os
-
-logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage())
-
-
-def main():
- # Change working directory to the repo root
- os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
-
- # TODO: Parse some top-level args here.
- from invokeai.app.cli_app import invoke_cli
-
- invoke_cli()
-
-
-if __name__ == "__main__":
- main()
diff --git a/scripts/invokeai-model-install.py b/scripts/invokeai-model-install.py
deleted file mode 100755
index 0cac857f8c..0000000000
--- a/scripts/invokeai-model-install.py
+++ /dev/null
@@ -1,5 +0,0 @@
-#!/usr/bin/env python
-
-from invokeai.frontend.install.model_install import main
-
-main()
diff --git a/scripts/make_models_markdown_table.py b/scripts/make_models_markdown_table.py
deleted file mode 100755
index 0af5085ff5..0000000000
--- a/scripts/make_models_markdown_table.py
+++ /dev/null
@@ -1,25 +0,0 @@
-#!/usr/bin/env python
-
-"""
-This script is used at release time to generate a markdown table describing the
-starter models. This text is then manually copied into 050_INSTALL_MODELS.md.
-"""
-
-from pathlib import Path
-
-from omegaconf import OmegaConf
-
-
-def main():
- initial_models_file = Path(__file__).parent / "../invokeai/configs/INITIAL_MODELS.yaml"
- models = OmegaConf.load(initial_models_file)
- print("|Model Name | HuggingFace Repo ID | Description | URL |")
- print("|---------- | ---------- | ----------- | --- |")
- for model in models:
- repo_id = models[model].repo_id
- url = f"https://huggingface.co/{repo_id}"
- print(f"|{model}|{repo_id}|{models[model].description}|{url} |")
-
-
-if __name__ == "__main__":
- main()
diff --git a/scripts/pypi_helper.py b/scripts/pypi_helper.py
deleted file mode 100755
index 6c1f9b9033..0000000000
--- a/scripts/pypi_helper.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#!/usr/bin/env python
-
-import requests
-
-from invokeai.version import __version__
-
-local_version = str(__version__).replace("-", "")
-package_name = "InvokeAI"
-
-
-def get_pypi_versions(package_name=package_name) -> list[str]:
- """Get the versions of the package from PyPI"""
- url = f"https://pypi.org/pypi/{package_name}/json"
- response = requests.get(url).json()
- versions: list[str] = list(response["releases"].keys())
- return versions
-
-
-def local_on_pypi(package_name=package_name, local_version=local_version) -> bool:
- """Compare the versions of the package from PyPI and the local package"""
- pypi_versions = get_pypi_versions(package_name)
- return local_version in pypi_versions
-
-
-if __name__ == "__main__":
- if local_on_pypi():
- print(f"Package {package_name} is up to date")
- else:
- print(f"Package {package_name} is not up to date")
diff --git a/scripts/scan_models_directory.py b/scripts/scan_models_directory.py
deleted file mode 100755
index a85fb793dd..0000000000
--- a/scripts/scan_models_directory.py
+++ /dev/null
@@ -1,61 +0,0 @@
-#!/usr/bin/env python
-
-"""
-Scan the models directory and print out a new models.yaml
-"""
-
-import argparse
-import os
-import sys
-from pathlib import Path
-
-from omegaconf import OmegaConf
-
-
-def main():
- parser = argparse.ArgumentParser(description="Model directory scanner")
- parser.add_argument("models_directory")
- parser.add_argument(
- "--all-models",
- default=False,
- action="store_true",
- help="If true, then generates stanzas for all models; otherwise just diffusers",
- )
-
- args = parser.parse_args()
- directory = args.models_directory
-
- conf = OmegaConf.create()
- conf["_version"] = "3.0.0"
-
- for root, dirs, files in os.walk(directory):
- parents = root.split("/")
- subpaths = parents[parents.index("models") + 1 :]
- if len(subpaths) < 2:
- continue
- base, model_type, *_ = subpaths
-
- if args.all_models or model_type == "diffusers":
- for d in dirs:
- conf[f"{base}/{model_type}/{d}"] = {
- "path": os.path.join(root, d),
- "description": f"{model_type} model {d}",
- "format": "folder",
- "base": base,
- }
-
- for f in files:
- basename = Path(f).stem
- format = Path(f).suffix[1:]
- conf[f"{base}/{model_type}/{basename}"] = {
- "path": os.path.join(root, f),
- "description": f"{model_type} model {basename}",
- "format": format,
- "base": base,
- }
-
- OmegaConf.save(config=dict(sorted(conf.items())), f=sys.stdout)
-
-
-if __name__ == "__main__":
- main()
diff --git a/scripts/sd-metadata.py b/scripts/sd-metadata.py
deleted file mode 100755
index b3ce4fd66e..0000000000
--- a/scripts/sd-metadata.py
+++ /dev/null
@@ -1,23 +0,0 @@
-#!/usr/bin/env python
-
-import json
-import sys
-
-from invokeai.backend.image_util import retrieve_metadata
-
-if len(sys.argv) < 2:
- print("Usage: file2prompt.py ...")
- print("This script opens up the indicated invoke.py-generated PNG file(s) and prints out their metadata.")
- exit(-1)
-
-filenames = sys.argv[1:]
-for f in filenames:
- try:
- metadata = retrieve_metadata(f)
- print(f"{f}:\n", json.dumps(metadata["sd-metadata"], indent=4))
- except FileNotFoundError:
- sys.stderr.write(f"{f} not found\n")
- continue
- except PermissionError:
- sys.stderr.write(f"{f} could not be opened due to inadequate permissions\n")
- continue
From 0e514950718a5d616e8c5ac5c54f5dbee5304aa3 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Tue, 19 Mar 2024 16:56:36 +1100
Subject: [PATCH 22/52] chore(ui): lint
---
invokeai/frontend/web/src/app/components/App.tsx | 3 +--
.../modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx | 1 -
2 files changed, 1 insertion(+), 3 deletions(-)
diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx
index 59a03233e8..ae1a762f61 100644
--- a/invokeai/frontend/web/src/app/components/App.tsx
+++ b/invokeai/frontend/web/src/app/components/App.tsx
@@ -69,8 +69,7 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
dispatch(appStarted());
}, [dispatch]);
-
- useStarterModelsToast()
+ useStarterModelsToast();
return (
diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx
index 4fa1c25ae8..b7ce8f8105 100644
--- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx
+++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx
@@ -121,4 +121,3 @@ const modelsFilter = (
return matchesFilter && matchesType;
});
};
-
From 2eacbb4d9d5edf484b751b26024333175164549d Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Tue, 19 Mar 2024 17:00:09 +1100
Subject: [PATCH 23/52] fix(nodes): do not load NSFW checker model on startup
Just check if the path exists to determine if it is "available". When needed, load it.
---
invokeai/backend/image_util/safety_checker.py | 12 +++++++-----
1 file changed, 7 insertions(+), 5 deletions(-)
diff --git a/invokeai/backend/image_util/safety_checker.py b/invokeai/backend/image_util/safety_checker.py
index a93d15ed73..682603e770 100644
--- a/invokeai/backend/image_util/safety_checker.py
+++ b/invokeai/backend/image_util/safety_checker.py
@@ -4,6 +4,8 @@ wraps the safety_checker model. It respects the global "nsfw_checker"
 configuration variable, which allows the checker to be suppressed.
"""
+from pathlib import Path
+
import numpy as np
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from PIL import Image
@@ -41,15 +43,15 @@ class SafetyChecker:
@classmethod
def safety_checker_available(cls) -> bool:
- cls._load_safety_checker()
- return cls.safety_checker is not None
+ return Path(get_config().models_path, CHECKER_PATH).exists()
@classmethod
def has_nsfw_concept(cls, image: Image.Image) -> bool:
- if not cls.safety_checker_available():
+ if not cls.safety_checker_available() and cls.tried_load:
+ return False
+ cls._load_safety_checker()
+ if cls.safety_checker is None or cls.feature_extractor is None:
return False
- assert cls.safety_checker is not None
- assert cls.feature_extractor is not None
device = choose_torch_device()
features = cls.feature_extractor([image], return_tensors="pt")
features.to(device)
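
The patch replaces eager loading with a cheap existence check plus on-demand loading. A standalone sketch of that lazy-availability pattern (simplified names, not the literal SafetyChecker class):

# Sketch of the lazy-availability pattern introduced above (illustrative only).
from pathlib import Path


class LazyChecker:
    _model = None
    _tried_load = False

    @classmethod
    def available(cls, model_dir: Path) -> bool:
        # Cheap startup check: only verify that the weights exist on disk.
        return model_dir.exists()

    @classmethod
    def get(cls, model_dir: Path):
        # Load the model the first time it is actually needed, and only once.
        if cls._model is None and not cls._tried_load:
            cls._tried_load = True
            if model_dir.exists():
                cls._model = object()  # stand-in for the real from_pretrained() call
        return cls._model
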
From 96ef7e3889f6b3d88125ed3fc53dc4b81b512f4c Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Tue, 19 Mar 2024 17:16:09 +1100
Subject: [PATCH 24/52] docs: add link to docs to invokeai.yaml template
---
invokeai/app/services/config/config_default.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py
index e09567cc8c..6bd3d4ef64 100644
--- a/invokeai/app/services/config/config_default.py
+++ b/invokeai/app/services/config/config_default.py
@@ -248,7 +248,7 @@ class InvokeAIAppConfig(BaseSettings):
file.write("# Internal metadata - do not edit:\n")
file.write(yaml.dump(meta_dict, sort_keys=False))
file.write("\n")
- file.write("# Put user settings here:\n")
+ file.write("# Put user settings here - see https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/:\n")
if len(config_dict) > 0:
file.write(yaml.dump(config_dict, sort_keys=False))
From 13c72206d83eb430fb9ec98361bc5c3b00b4b835 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Tue, 19 Mar 2024 17:16:23 +1100
Subject: [PATCH 25/52] docs: update CONFIGURATION.md
---
docs/features/CONFIGURATION.md | 19 +++++++++----------
1 file changed, 9 insertions(+), 10 deletions(-)
diff --git a/docs/features/CONFIGURATION.md b/docs/features/CONFIGURATION.md
index f608f80467..3f05c0cf9b 100644
--- a/docs/features/CONFIGURATION.md
+++ b/docs/features/CONFIGURATION.md
@@ -18,9 +18,6 @@ Settings sources are used in this order:
- `invokeai.yaml` settings
- Fallback: defaults
-The most commonly changed settings are also accessible
-graphically via the `invokeai-configure` script.
-
### InvokeAI Root Directory
On startup, InvokeAI searches for its "root" directory. This is the directory
@@ -42,10 +39,9 @@ It has two sections - one for internal use and one for user settings:
```yaml
# Internal metadata - do not edit:
-meta:
- schema_version: 4
+schema_version: 4
-# Put user settings here:
+# Put user settings here - see https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/:
host: 0.0.0.0 # serve the app on your local network
models_dir: D:\invokeai\models # store models on an external drive
precision: float16 # always use fp16 precision
@@ -62,6 +58,12 @@ You can fix a broken `invokeai.yaml` by deleting it and running the
configuration script again -- option [6] in the launcher, "Re-run the
configure script".
+#### Custom Config File Location
+
+You can use any config file with the `--config` CLI arg. Pass in the path to the `invokeai.yaml` file you want to use.
+
+Note that environment variables will trump any settings in the config file.
+
### Environment Variables
All settings may be set via environment variables by prefixing `INVOKEAI_`
@@ -81,13 +83,10 @@ We suggest using `invokeai.yaml`, as it is more user-friendly.
A subset of settings may be specified using CLI args:
- `--root`: specify the root directory
-- `--ignore_missing_core-models`: if set, do not check for models needed
- to convert checkpoint/safetensor models to diffusers
+- `--config`: override the default `invokeai.yaml` file location
### All Settings
-The config is managed by the `InvokeAIAppConfig` class. The below docs are autogenerated from the class.
-
Following the table are additional explanations for certain settings.
From 040ea8f41b4d8fc1c222c109e573f21cfb5295f9 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Tue, 19 Mar 2024 17:23:18 +1100
Subject: [PATCH 26/52] tidy: do not show msg when loading NSFW checker
---
invokeai/backend/image_util/safety_checker.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/invokeai/backend/image_util/safety_checker.py b/invokeai/backend/image_util/safety_checker.py
index 682603e770..7bceae8da7 100644
--- a/invokeai/backend/image_util/safety_checker.py
+++ b/invokeai/backend/image_util/safety_checker.py
@@ -36,7 +36,6 @@ class SafetyChecker:
try:
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(get_config().models_path / CHECKER_PATH)
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(get_config().models_path / CHECKER_PATH)
- logger.info("NSFW checker initialized")
except Exception as e:
logger.warning(f"Could not load NSFW checker: {str(e)}")
cls.tried_load = True
From 6c13fa13ea7e484050b981913486de220aba4912 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Tue, 19 Mar 2024 18:12:23 +1100
Subject: [PATCH 27/52] fix(mm): regression from change to legacy conf dir
change
---
.../model_install/model_install_default.py | 2 +-
.../load/model_loaders/controlnet.py | 4 +---
.../load/model_loaders/stable_diffusion.py | 3 +--
invokeai/backend/model_manager/probe.py | 15 ++++++++-------
4 files changed, 11 insertions(+), 13 deletions(-)
diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py
index a80149c7f0..bd67535f79 100644
--- a/invokeai/app/services/model_install/model_install_default.py
+++ b/invokeai/app/services/model_install/model_install_default.py
@@ -604,7 +604,7 @@ class ModelInstallService(ModelInstallServiceBase):
info.path = model_path.as_posix()
- # add 'main' specific fields
+ # Checkpoints have a config file needed for conversion - resolve this to an absolute path
if isinstance(info, CheckpointConfigBase):
legacy_conf = (self.app_config.legacy_conf_path / info.config_path).resolve()
info.config_path = legacy_conf.as_posix()
diff --git a/invokeai/backend/model_manager/load/model_loaders/controlnet.py b/invokeai/backend/model_manager/load/model_loaders/controlnet.py
index 736bb65548..d994f2147a 100644
--- a/invokeai/backend/model_manager/load/model_loaders/controlnet.py
+++ b/invokeai/backend/model_manager/load/model_loaders/controlnet.py
@@ -35,8 +35,6 @@ class ControlNetLoader(GenericDiffusersLoader):
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
assert isinstance(config, CheckpointConfigBase)
- config_file = config.config_path
-
image_size = (
512
if config.base == BaseModelType.StableDiffusion1
@@ -46,7 +44,7 @@ class ControlNetLoader(GenericDiffusersLoader):
)
self._logger.info(f"Converting {model_path} to diffusers format")
- with open(self._app_config.root_path / config_file, "r") as config_stream:
+ with open(config.config_path, "r") as config_stream:
convert_controlnet_to_diffusers(
model_path,
output_path,
diff --git a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py
index 3fb2e29f60..fa66c56364 100644
--- a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py
+++ b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py
@@ -76,7 +76,6 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
assert isinstance(config, MainCheckpointConfig)
base = config.base
- config_file = config.config_path
prediction_type = config.prediction_type.value
upcast_attention = config.upcast_attention
image_size = (
@@ -92,7 +91,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
model_path,
output_path,
model_type=self.model_base_to_model_type[base],
- original_config_file=self._app_config.root_path / config_file,
+ original_config_file=config.config_path,
extract_ema=True,
from_safetensors=model_path.suffix == ".safetensors",
precision=self._torch_dtype,
diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py
index d43b160750..cac4a297ae 100644
--- a/invokeai/backend/model_manager/probe.py
+++ b/invokeai/backend/model_manager/probe.py
@@ -178,13 +178,14 @@ class ModelProbe(object):
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
and fields["format"] is ModelFormat.Checkpoint
):
- fields["config_path"] = cls._get_checkpoint_config_path(
+ ckpt_config_path = cls._get_checkpoint_config_path(
model_path,
model_type=fields["type"],
base_type=fields["base"],
variant_type=fields["variant"],
prediction_type=fields["prediction_type"],
- ).as_posix()
+ )
+ fields["config_path"] = str(ckpt_config_path)
# additional fields needed for main non-checkpoint models
elif fields["type"] == ModelType.Main and fields["format"] in [
@@ -298,23 +299,23 @@ class ModelProbe(object):
config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
config_file = config_file[prediction_type]
+ config_file = f"stable-diffusion/{config_file}"
elif model_type is ModelType.ControlNet:
config_file = (
- "../controlnet/cldm_v15.yaml"
+ "controlnet/cldm_v15.yaml"
if base_type is BaseModelType.StableDiffusion1
- else "../controlnet/cldm_v21.yaml"
+ else "controlnet/cldm_v21.yaml"
)
elif model_type is ModelType.VAE:
config_file = (
- "../stable-diffusion/v1-inference.yaml"
+ "stable-diffusion/v1-inference.yaml"
if base_type is BaseModelType.StableDiffusion1
- else "../stable-diffusion/v2-inference.yaml"
+ else "stable-diffusion/v2-inference.yaml"
)
else:
raise InvalidModelConfigException(
f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}"
)
- assert isinstance(config_file, str)
return Path(config_file)
@classmethod
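
Together with the installer change earlier in this patch, a checkpoint's config file is now stored relative to the legacy config directory and resolved to an absolute path at install time, roughly as follows (paths are illustrative):

# Sketch only: resolving a checkpoint's legacy config file, as in the installer
# change above. Paths are illustrative.
from pathlib import Path

legacy_conf_path = Path("/invokeai/configs")          # app_config.legacy_conf_path
config_path = "stable-diffusion/v1-inference.yaml"    # value produced by the probe
absolute_config = (legacy_conf_path / config_path).resolve()
print(absolute_config)  # /invokeai/configs/stable-diffusion/v1-inference.yaml
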
From b02f2da71d6f70244e3236a9954594a92a1ee49c Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Tue, 19 Mar 2024 19:26:13 +1100
Subject: [PATCH 28/52] fix(config): handle legacy_conf_dir setting migration
---
invokeai/app/services/config/config_default.py | 8 ++++++++
1 file changed, 8 insertions(+)
diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py
index 6bd3d4ef64..6e373151ae 100644
--- a/invokeai/app/services/config/config_default.py
+++ b/invokeai/app/services/config/config_default.py
@@ -353,6 +353,14 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
parsed_config_dict["vram"] = v
if k == "conf_path":
parsed_config_dict["legacy_models_yaml_path"] = v
+ if k == "legacy_conf_dir":
+ # The old default for this was "configs/stable-diffusion". If the incoming config has that as the value, we won't set it.
+ # Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
+ # Else we do not attempt to migrate this setting
+ if v != "configs/stable-diffusion":
+ parsed_config_dict["legacy_conf_dir"] = v
+ elif Path(v).name == "stable-diffusion":
+ parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
elif k in InvokeAIAppConfig.model_fields:
# skip unknown fields
parsed_config_dict[k] = v
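
The migration rule described in the comments above can be read as a small standalone function; the following sketch implements the commented intent (it is illustrative, not the literal migration code):

# Sketch of the legacy_conf_dir migration rule as described in the comments above.
from pathlib import Path
from typing import Optional

def migrate_legacy_conf_dir(v: str) -> Optional[str]:
    if v == "configs/stable-diffusion":
        return None                 # old default: do not carry it over
    if Path(v).name == "stable-diffusion":
        return str(Path(v).parent)  # assume the parent is the new correct path
    return v                        # otherwise keep the value as-is

assert migrate_legacy_conf_dir("configs/stable-diffusion") is None
assert migrate_legacy_conf_dir("/data/configs/stable-diffusion") == "/data/configs"
assert migrate_legacy_conf_dir("/data/my_configs") == "/data/my_configs"
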
From f5337c7ce29d831a53c2daa7935048533c9eb9e6 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Tue, 19 Mar 2024 20:24:02 +1100
Subject: [PATCH 29/52] fix(config): handle relative paths to v3 models.yamls
---
invokeai/app/services/config/config_default.py | 6 ++----
.../app/services/model_install/model_install_default.py | 4 ++++
2 files changed, 6 insertions(+), 4 deletions(-)
diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py
index 6e373151ae..6edd3283ca 100644
--- a/invokeai/app/services/config/config_default.py
+++ b/invokeai/app/services/config/config_default.py
@@ -127,6 +127,7 @@ class InvokeAIAppConfig(BaseSettings):
# INTERNAL
schema_version: int = Field(default=CONFIG_SCHEMA_VERSION, description="Schema version of the config file. This is not a user-configurable setting.")
+ # This is only used during v3 models.yaml migration
legacy_models_yaml_path: Optional[Path] = Field(default=None, description="Path to the legacy models.yaml file. This is not a user-configurable setting.")
# WEB
@@ -231,11 +232,8 @@ class InvokeAIAppConfig(BaseSettings):
dest_path: Path to write the config to.
"""
with open(dest_path, "w") as file:
- # Meta fields should be written in a separate stanza
+ # Meta fields should be written in a separate stanza - skip legacy_models_yaml_path
meta_dict = self.model_dump(mode="json", include={"schema_version"})
- # Only include the legacy_models_yaml_path if it's set
- if self.legacy_models_yaml_path:
- meta_dict.update(self.model_dump(mode="json", include={"legacy_models_yaml_path"}))
# User settings
config_dict = self.model_dump(
diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py
index bd67535f79..bd41022037 100644
--- a/invokeai/app/services/model_install/model_install_default.py
+++ b/invokeai/app/services/model_install/model_install_default.py
@@ -304,6 +304,10 @@ class ModelInstallService(ModelInstallServiceBase):
self._app_config.legacy_models_yaml_path or self._app_config.root_path / "configs" / "models.yaml"
)
+ # The old path may be relative to the root path
+ if not legacy_models_yaml_path.exists():
+ legacy_models_yaml_path = Path(self._app_config.root_path, legacy_models_yaml_path)
+
if legacy_models_yaml_path.exists():
legacy_models_yaml = yaml.safe_load(legacy_models_yaml_path.read_text())
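A small sketch of the lookup order this patch establishes for the legacy `models.yaml` path; the function name and signature are illustrative, not code from the patch.

```python
from pathlib import Path
from typing import Optional


def resolve_legacy_models_yaml(root_path: Path, configured: Optional[Path]) -> Path:
    # Prefer the explicitly configured path; fall back to <root>/configs/models.yaml.
    candidate = Path(configured or root_path / "configs" / "models.yaml")
    # v3 configs may have stored conf_path relative to the root, so if the path
    # does not exist as given, retry it anchored at the root directory.
    if not candidate.exists():
        candidate = Path(root_path, candidate)
    return candidate


# Example: a relative "configs/models.yaml" is re-anchored under the install root.
print(resolve_legacy_models_yaml(Path("/opt/invokeai"), Path("configs/models.yaml")))
```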
From 02329df1dfc4ccd43856d0d0b924a211d1016b1f Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Tue, 19 Mar 2024 21:28:07 +1100
Subject: [PATCH 30/52] feat(config): write example config file out on app
startup
---
.../app/services/config/config_default.py | 23 +++++++++++++++----
1 file changed, 18 insertions(+), 5 deletions(-)
diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py
index 6edd3283ca..47da2f6628 100644
--- a/invokeai/app/services/config/config_default.py
+++ b/invokeai/app/services/config/config_default.py
@@ -32,7 +32,7 @@ ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
-CONFIG_SCHEMA_VERSION = 4
+CONFIG_SCHEMA_VERSION = "4.0.0"
def get_default_ram_cache_size() -> float:
@@ -126,7 +126,7 @@ class InvokeAIAppConfig(BaseSettings):
# fmt: off
# INTERNAL
- schema_version: int = Field(default=CONFIG_SCHEMA_VERSION, description="Schema version of the config file. This is not a user-configurable setting.")
+ schema_version: str = Field(default=CONFIG_SCHEMA_VERSION, description="Schema version of the config file. This is not a user-configurable setting.")
# This is only used during v3 models.yaml migration
legacy_models_yaml_path: Optional[Path] = Field(default=None, description="Path to the legacy models.yaml file. This is not a user-configurable setting.")
@@ -223,7 +223,7 @@ class InvokeAIAppConfig(BaseSettings):
if new_value != current_value:
setattr(self, field_name, new_value)
- def write_file(self, dest_path: Path) -> None:
+ def write_file(self, dest_path: Path, as_example: bool = False) -> None:
"""Write the current configuration to file. This will overwrite the existing file.
A `meta` stanza is added to the top of the file, containing metadata about the config file. This is not stored in the config object.
@@ -238,11 +238,16 @@ class InvokeAIAppConfig(BaseSettings):
# User settings
config_dict = self.model_dump(
mode="json",
- exclude_unset=True,
- exclude_defaults=True,
+ exclude_unset=False if as_example else True,
+ exclude_defaults=False if as_example else True,
+ exclude_none=True if as_example else False,
exclude={"schema_version", "legacy_models_yaml_path"},
)
+ if as_example:
+ file.write(
+ "# This is an example file with default and example settings. Use the values here as a baseline.\n\n"
+ )
file.write("# Internal metadata - do not edit:\n")
file.write(yaml.dump(meta_dict, sort_keys=False))
file.write("\n")
@@ -436,6 +441,14 @@ def get_config() -> InvokeAIAppConfig:
if config_file := getattr(args, "config_file", None):
config._config_file = Path(config_file)
+ # Create the example file from a deep copy, with some extra values provided
+ example_config = config.model_copy(deep=True)
+ example_config.remote_api_tokens = [
+ URLRegexTokenPair(url_regex="cool-models.com", token="my_secret_token"),
+ URLRegexTokenPair(url_regex="nifty-models.com", token="some_other_token"),
+ ]
+ example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True)
+
# Log in to HF
hf_login()
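A hedged sketch of the difference the new `as_example` flag makes when writing the config file; the overridden field is just an example, and the exact YAML output depends on your settings.

```python
from pathlib import Path

from invokeai.app.services.config.config_default import InvokeAIAppConfig

config = InvokeAIAppConfig()
config.host = "192.168.1.1"  # an example non-default setting

# Normal write: only explicitly-set, non-default fields appear under the user stanza.
config.write_file(Path("invokeai.yaml"))

# Example write: every field is dumped (defaults included, None values skipped),
# prefixed with the explanatory comment added by this patch, giving users a
# documented baseline to copy settings from.
config.write_file(Path("invokeai.example.yaml"), as_example=True)
```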
From dea9142cb8d72503a5a89617899439d92a870045 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Tue, 19 Mar 2024 21:33:09 +1100
Subject: [PATCH 31/52] tests: fix config test after changing config schema
version format
---
tests/test_config.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/tests/test_config.py b/tests/test_config.py
index e0e0050bf5..5576db2bc9 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -9,14 +9,14 @@ from pydantic import ValidationError
from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config, load_and_migrate_config
v4_config = """
-schema_version: 4
+schema_version: 4.0.0
host: "192.168.1.1"
port: 8080
"""
invalid_v5_config = """
-schema_version: 5
+schema_version: 5.0.0
host: "192.168.1.1"
port: 8080
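One likely benefit of moving from an integer to a dotted version string is that future migrations can compare versions component-wise. A sketch of such a comparison, assuming the standard `packaging` library; the gate function is hypothetical and not part of this patch.

```python
from packaging.version import Version

assert Version("4.0.0") < Version("4.1.0") < Version("5.0.0")


def needs_migration(schema_version: str) -> bool:
    # Hypothetical future gate: anything older than 4.0.0 gets migrated.
    return Version(schema_version) < Version("4.0.0")


print(needs_migration("3.0.0"))  # True
print(needs_migration("4.0.0"))  # False
```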
From 9a5575b46bd70c02029773fee323a55525472318 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Wed, 20 Mar 2024 09:21:48 +1100
Subject: [PATCH 32/52] feat(mm): move HF token helper to route
---
invokeai/app/api/routers/model_manager.py | 52 +++++++++++++++++++
.../app/services/config/config_default.py | 4 --
invokeai/app/util/hf_login.py | 46 ----------------
3 files changed, 52 insertions(+), 50 deletions(-)
delete mode 100644 invokeai/app/util/hf_login.py
diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py
index 39b9d793af..5a785eeb0b 100644
--- a/invokeai/app/api/routers/model_manager.py
+++ b/invokeai/app/api/routers/model_manager.py
@@ -1,13 +1,16 @@
# Copyright (c) 2023 Lincoln D. Stein
"""FastAPI route for model configuration records."""
+import contextlib
import io
import pathlib
import shutil
import traceback
from copy import deepcopy
+from enum import Enum
from typing import Any, Dict, List, Optional
+import huggingface_hub
from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
@@ -22,6 +25,7 @@ from invokeai.app.services.model_records import (
UnknownModelException,
)
from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges
+from invokeai.app.util.suppress_output import SuppressOutput
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
@@ -794,3 +798,51 @@ async def get_starter_models() -> list[StarterModel]:
model.is_installed = True
return starter_models
+
+
+class HFTokenStatus(str, Enum):
+ VALID = "valid"
+ INVALID = "invalid"
+ UNKNOWN = "unknown"
+
+
+class HFTokenHelper:
+ @classmethod
+ def get_status(cls) -> HFTokenStatus:
+ try:
+ if huggingface_hub.get_token_permission(huggingface_hub.get_token()):
+ # Valid token!
+ return HFTokenStatus.VALID
+ # No token set
+ return HFTokenStatus.INVALID
+ except Exception:
+ return HFTokenStatus.UNKNOWN
+
+ @classmethod
+ def set_token(cls, token: str) -> HFTokenStatus:
+ with SuppressOutput(), contextlib.suppress(Exception):
+ huggingface_hub.login(token=token, add_to_git_credential=False)
+ return cls.get_status()
+
+
+@model_manager_router.get("/hf_login", operation_id="get_hf_login_status", response_model=HFTokenStatus)
+async def get_hf_login_status() -> HFTokenStatus:
+ token_status = HFTokenHelper.get_status()
+
+ if token_status is HFTokenStatus.UNKNOWN:
+ ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")
+
+ return token_status
+
+
+@model_manager_router.post("/hf_login", operation_id="do_hf_login", response_model=HFTokenStatus)
+async def do_hf_login(
+ token: str = Body(description="Hugging Face token to use for login", embed=True),
+) -> HFTokenStatus:
+ HFTokenHelper.set_token(token)
+ token_status = HFTokenHelper.get_status()
+
+ if token_status is HFTokenStatus.UNKNOWN:
+ ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")
+
+ return token_status
diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py
index 47da2f6628..9bca00615c 100644
--- a/invokeai/app/services/config/config_default.py
+++ b/invokeai/app/services/config/config_default.py
@@ -16,7 +16,6 @@ from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
import invokeai.configs as model_configs
-from invokeai.app.util.hf_login import hf_login
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
@@ -449,9 +448,6 @@ def get_config() -> InvokeAIAppConfig:
]
example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True)
- # Log in to HF
- hf_login()
-
# Copy all legacy configs - We know `__path__[0]` is correct here
configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True)
diff --git a/invokeai/app/util/hf_login.py b/invokeai/app/util/hf_login.py
deleted file mode 100644
index 125010c2bb..0000000000
--- a/invokeai/app/util/hf_login.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import huggingface_hub
-from pwinput import pwinput
-
-from invokeai.app.util.suppress_output import SuppressOutput
-
-
-def hf_login() -> None:
- """Prompts the user for their HuggingFace token. If a valid token is already saved, this function will do nothing.
-
- Returns:
- bool: True if the login was successful, False if the user canceled.
-
- Raises:
- RuntimeError: If the user cancels the login prompt.
- """
-
- current_token = huggingface_hub.get_token()
-
- try:
- if huggingface_hub.get_token_permission(current_token):
- # We have a valid token already
- return
- except ConnectionError:
- print("Unable to reach HF to verify token. Skipping...")
- # No internet connection, so we can't check the token
- pass
-
- # InvokeAILogger depends on the config, and this class is used within the config, so we can't use the app logger here
- print("Enter your HuggingFace token. This is required to convert checkpoint/safetensors models to diffusers.")
- print("For more information, see https://huggingface.co/docs/hub/security-tokens#how-to-manage-user-access-tokens")
- print("Press Ctrl+C to skip.")
-
- while True:
- try:
- access_token = pwinput(prompt="HF token: ")
- # The login function prints to stdout
- with SuppressOutput():
- huggingface_hub.login(token=access_token, add_to_git_credential=False)
- print("Token verified.")
- break
- except ValueError:
- print("Invalid token!")
- continue
- except KeyboardInterrupt:
- print("\nToken verification canceled.")
- break
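With the interactive prompt removed, token management now happens over HTTP. A minimal sketch of exercising the two new routes with `requests`, assuming a local server at the default address (adjust the base URL and supply a real token).

```python
import requests

BASE_URL = "http://127.0.0.1:9090"  # assumed local InvokeAI server

# GET returns the current status: "valid", "invalid", or "unknown".
status = requests.get(f"{BASE_URL}/api/v2/models/hf_login", timeout=10).json()
print(f"HF token status: {status}")

if status != "valid":
    # POST accepts the token in the request body and returns the re-checked status.
    resp = requests.post(
        f"{BASE_URL}/api/v2/models/hf_login",
        json={"token": "hf_xxxxxxxxxxxxxxxx"},  # placeholder token
        timeout=10,
    )
    print(f"Status after login: {resp.json()}")
```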
From 3f6f8199f6f77e0d113ce8c676ccdd13cc6fe0be Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Wed, 20 Mar 2024 09:21:56 +1100
Subject: [PATCH 33/52] chore(ui): typegen
---
.../frontend/web/src/services/api/schema.ts | 56 ++++++++++++++++++-
1 file changed, 54 insertions(+), 2 deletions(-)
diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts
index c1c0bd323e..2e69a071a6 100644
--- a/invokeai/frontend/web/src/services/api/schema.ts
+++ b/invokeai/frontend/web/src/services/api/schema.ts
@@ -159,6 +159,12 @@ export type paths = {
/** Get Starter Models */
get: operations["get_starter_models"];
};
+ "/api/v2/models/hf_login": {
+ /** Get Hf Login Status */
+ get: operations["get_hf_login_status"];
+ /** Do Hf Login */
+ post: operations["do_hf_login"];
+ };
"/api/v1/download_queue/": {
/**
* List Downloads
@@ -1022,6 +1028,14 @@ export type components = {
*/
image_names: string[];
};
+ /** Body_do_hf_login */
+ Body_do_hf_login: {
+ /**
+ * Token
+ * @description Hugging Face token to use for login
+ */
+ token: string;
+ };
/** Body_download */
Body_download: {
/**
@@ -4116,7 +4130,7 @@ export type components = {
* @description The nodes in this graph
*/
nodes: {
- [key: string]: components["schemas"]["SubtractInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | 
components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["VAELoaderInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["ResizeLatentsInvocation"];
+ [key: string]: components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["VAELoaderInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | 
components["schemas"]["InfillColorInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["UnsharpMaskInvocation"];
};
/**
* Edges
@@ -4153,7 +4167,7 @@ export type components = {
* @description The results of node executions
*/
results: {
- [key: string]: components["schemas"]["IntegerOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["String2Output"] | components["schemas"]["NoiseOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["FloatCollectionOutput"];
+ [key: string]: components["schemas"]["MetadataOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["String2Output"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["LoRALoaderOutput"];
};
/**
* Errors
@@ -4199,6 +4213,11 @@ export type components = {
*/
type?: "hf";
};
+ /**
+ * HFTokenStatus
+ * @enum {string}
+ */
+ HFTokenStatus: "valid" | "invalid" | "unknown";
/** HTTPValidationError */
HTTPValidationError: {
/** Detail */
@@ -11617,6 +11636,39 @@ export type operations = {
};
};
};
+ /** Get Hf Login Status */
+ get_hf_login_status: {
+ responses: {
+ /** @description Successful Response */
+ 200: {
+ content: {
+ "application/json": components["schemas"]["HFTokenStatus"];
+ };
+ };
+ };
+ };
+ /** Do Hf Login */
+ do_hf_login: {
+ requestBody: {
+ content: {
+ "application/json": components["schemas"]["Body_do_hf_login"];
+ };
+ };
+ responses: {
+ /** @description Successful Response */
+ 200: {
+ content: {
+ "application/json": components["schemas"]["HFTokenStatus"];
+ };
+ };
+ /** @description Validation Error */
+ 422: {
+ content: {
+ "application/json": components["schemas"]["HTTPValidationError"];
+ };
+ };
+ };
+ };
/**
* List Downloads
* @description Get a list of active and inactive jobs.
From bdb52cfcf7fdbe6cda7ca5fee44cd909a0f41eed Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Wed, 20 Mar 2024 09:24:10 +1100
Subject: [PATCH 34/52] feat(ui): set HF token in MM tab
- Display a toast on UI launch if the HF token is invalid
- Show a form in the MM tab if the token is invalid or cannot be verified, letting the user set the token via this form
---
invokeai/frontend/web/public/locales/en.json | 8 ++
.../frontend/web/src/app/components/App.tsx | 2 +
.../frontend/web/src/app/types/invokeai.ts | 3 +-
.../modelManagerV2/components/HFToken.tsx | 81 +++++++++++++++++
.../modelManagerV2/hooks/useHFLoginToast.tsx | 89 +++++++++++++++++++
.../modelManagerV2/subpanels/ModelManager.tsx | 2 +
.../web/src/services/api/endpoints/models.ts | 27 ++++++
.../frontend/web/src/services/api/index.ts | 1 +
8 files changed, 212 insertions(+), 1 deletion(-)
create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/components/HFToken.tsx
create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/hooks/useHFLoginToast.tsx
diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json
index 0cfaf912f0..733f558819 100644
--- a/invokeai/frontend/web/public/locales/en.json
+++ b/invokeai/frontend/web/public/locales/en.json
@@ -638,6 +638,14 @@
"huggingFacePlaceholder": "owner/model-name",
"huggingFaceRepoID": "HuggingFace Repo ID",
"huggingFaceHelper": "If multiple models are found in this repo, you will be prompted to select one to install.",
+ "hfToken": "HuggingFace Token",
+ "hfTokenHelperText": "A HF token is required to use checkpoint models. Click here to create or get your token.",
+ "hfTokenInvalid": "Invalid or Missing HF Token",
+ "hfTokenInvalidErrorMessage": "Invalid or missing HuggingFace token.",
+ "hfTokenInvalidErrorMessage2": "Update it in the ",
+ "hfTokenUnableToVerify": "Unable to Verify HF Token",
+ "hfTokenUnableToVerifyErrorMessage": "Unable to verify HuggingFace token. This is likely due to a network error. Please try again later.",
+ "hfTokenSaved": "HF Token Saved",
"imageEncoderModelId": "Image Encoder Model ID",
"installQueue": "Install Queue",
"inplaceInstall": "In-place install",
diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx
index ae1a762f61..13231e6aeb 100644
--- a/invokeai/frontend/web/src/app/components/App.tsx
+++ b/invokeai/frontend/web/src/app/components/App.tsx
@@ -11,6 +11,7 @@ import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
+import { useHFLoginToast } from 'features/modelManagerV2/hooks/useHFLoginToast';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
import { configChanged } from 'features/system/store/configSlice';
import { languageSelector } from 'features/system/store/systemSelectors';
@@ -70,6 +71,7 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
}, [dispatch]);
useStarterModelsToast();
+ useHFLoginToast();
return (
diff --git a/invokeai/frontend/web/src/app/types/invokeai.ts b/invokeai/frontend/web/src/app/types/invokeai.ts
index 1e40b7074d..4982dbb83f 100644
--- a/invokeai/frontend/web/src/app/types/invokeai.ts
+++ b/invokeai/frontend/web/src/app/types/invokeai.ts
@@ -25,7 +25,8 @@ export type AppFeature =
| 'prependQueue'
| 'invocationCache'
| 'bulkDownload'
- | 'starterModels';
+ | 'starterModels'
+ | 'hfToken';
/**
* A disable-able Stable Diffusion feature
diff --git a/invokeai/frontend/web/src/features/modelManagerV2/components/HFToken.tsx b/invokeai/frontend/web/src/features/modelManagerV2/components/HFToken.tsx
new file mode 100644
index 0000000000..86ca7e128b
--- /dev/null
+++ b/invokeai/frontend/web/src/features/modelManagerV2/components/HFToken.tsx
@@ -0,0 +1,81 @@
+import {
+ Button,
+ ExternalLink,
+ Flex,
+ FormControl,
+ FormErrorMessage,
+ FormHelperText,
+ FormLabel,
+ Input,
+ useToast,
+} from '@invoke-ai/ui-library';
+import { skipToken } from '@reduxjs/toolkit/query';
+import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
+import type { ChangeEvent } from 'react';
+import { useCallback, useMemo, useState } from 'react';
+import { useTranslation } from 'react-i18next';
+import { useGetHFTokenStatusQuery, useSetHFTokenMutation } from 'services/api/endpoints/models';
+
+export const HFToken = () => {
+ const { t } = useTranslation();
+ const isEnabled = useFeatureStatus('hfToken').isFeatureEnabled;
+ const [token, setToken] = useState('');
+ const { currentData } = useGetHFTokenStatusQuery(isEnabled ? undefined : skipToken);
+ const [trigger, { isLoading }] = useSetHFTokenMutation();
+ const toast = useToast();
+ const onChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
+ setToken(e.target.value);
+ }, []);
+ const onClick = useCallback(() => {
+ trigger({ token })
+ .unwrap()
+ .then((res) => {
+ if (res === 'valid') {
+ setToken('');
+ toast({
+ title: t('modelManager.hfTokenSaved'),
+ status: 'success',
+ duration: 3000,
+ });
+ }
+ });
+ }, [t, toast, token, trigger]);
+
+ const error = useMemo(() => {
+ if (!currentData || isLoading) {
+ return null;
+ }
+ if (currentData === 'invalid') {
+ return t('modelManager.hfTokenInvalidErrorMessage');
+ }
+ if (currentData === 'unknown') {
+ return t('modelManager.hfTokenUnableToVerifyErrorMessage');
+ }
+ return null;
+ }, [currentData, isLoading, t]);
+
+ if (!currentData || currentData === 'valid') {
+ return null;
+ }
+
+ return (
+
+
+ {t('modelManager.hfToken')}
+
+
+
+
+
+
+
+ {error}
+
+
+ );
+};
diff --git a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useHFLoginToast.tsx b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useHFLoginToast.tsx
new file mode 100644
index 0000000000..972dfaccb3
--- /dev/null
+++ b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useHFLoginToast.tsx
@@ -0,0 +1,89 @@
+import { Button, Text, useToast } from '@invoke-ai/ui-library';
+import { skipToken } from '@reduxjs/toolkit/query';
+import { useAppDispatch } from 'app/store/storeHooks';
+import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
+import { setActiveTab } from 'features/ui/store/uiSlice';
+import { t } from 'i18next';
+import { useCallback, useEffect, useState } from 'react';
+import { useTranslation } from 'react-i18next';
+import { useGetHFTokenStatusQuery } from 'services/api/endpoints/models';
+import type { S } from 'services/api/types';
+
+const FEATURE_ID = 'hfToken';
+
+const getTitle = (token_status: S['HFTokenStatus']) => {
+ switch (token_status) {
+ case 'invalid':
+ return t('modelManager.hfTokenInvalid');
+ case 'unknown':
+ return t('modelManager.hfTokenUnableToVerify');
+ }
+};
+
+export const useHFLoginToast = () => {
+ const { t } = useTranslation();
+ const isEnabled = useFeatureStatus(FEATURE_ID).isFeatureEnabled;
+ const [didToast, setDidToast] = useState(false);
+ const { data } = useGetHFTokenStatusQuery(isEnabled ? undefined : skipToken);
+ const toast = useToast();
+
+ useEffect(() => {
+ if (toast.isActive(FEATURE_ID)) {
+ if (data === 'valid') {
+ setDidToast(true);
+ toast.close(FEATURE_ID);
+ }
+ return;
+ }
+ if (data && data !== 'valid' && !didToast && isEnabled) {
+ const title = getTitle(data);
+ toast({
+ id: FEATURE_ID,
+ title,
+ description: <ToastDescription token_status={data} />,
+ status: 'info',
+ isClosable: true,
+ duration: null,
+ onCloseComplete: () => setDidToast(true),
+ });
+ }
+ }, [data, didToast, isEnabled, t, toast]);
+};
+
+type Props = {
+ token_status: S['HFTokenStatus'];
+};
+
+const ToastDescription = ({ token_status }: Props) => {
+ const { t } = useTranslation();
+ const dispatch = useAppDispatch();
+ const toast = useToast();
+
+ const onClick = useCallback(() => {
+ dispatch(setActiveTab('modelManager'));
+ toast.close(FEATURE_ID);
+ }, [dispatch, toast]);
+
+ if (token_status === 'invalid') {
+ return (
+
+ {t('modelManager.hfTokenInvalidErrorMessage')}{' '}
+ {t('modelManager.hfTokenInvalidErrorMessage2')}
+
+
+ );
+ }
+
+ if (token_status === 'unknown') {
+ return (
+
+ {t('modelManager.hfTokenUnableToVerifyErrorMessage')}{' '}
+
+
+ );
+ }
+};
diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx
index ed75f86078..51c2ab0f7b 100644
--- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx
+++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx
@@ -1,5 +1,6 @@
import { Button, Flex, Heading, Spacer } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
+import { HFToken } from 'features/modelManagerV2/components/HFToken';
import { SyncModelsButton } from 'features/modelManagerV2/components/SyncModels/SyncModelsButton';
import { setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { useCallback } from 'react';
@@ -27,6 +28,7 @@ export const ModelManager = () => {