update textual inversion training to use root_path rather than root_dir

Author: Lincoln Stein, 2024-03-16 22:50:12 -04:00; committed by psychedelicious
parent 5d16a40b95
commit f1450c2c24


@@ -123,7 +123,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
             value=str(
                 saved_args.get(
                     "train_data_dir",
-                    config.root_dir / TRAINING_DATA / default_placeholder_token,
+                    config.root_path / TRAINING_DATA / default_placeholder_token,
                 )
             ),
             scroll_exit=True,
@@ -136,7 +136,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
             value=str(
                 saved_args.get(
                     "output_dir",
-                    config.root_dir / TRAINING_DIR / default_placeholder_token,
+                    config.root_path / TRAINING_DIR / default_placeholder_token,
                 )
             ),
             scroll_exit=True,
@@ -241,8 +241,8 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
     def initializer_changed(self) -> None:
         placeholder = self.placeholder_token.value
         self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)"
-        self.train_data_dir.value = str(config.root_dir / TRAINING_DATA / placeholder)
-        self.output_dir.value = str(config.root_dir / TRAINING_DIR / placeholder)
+        self.train_data_dir.value = str(config.root_path / TRAINING_DATA / placeholder)
+        self.output_dir.value = str(config.root_path / TRAINING_DIR / placeholder)
         self.resume_from_checkpoint.value = Path(self.output_dir.value).exists()

     def on_ok(self):
@@ -354,7 +354,7 @@ def copy_to_embeddings_folder(args: Dict[str, str]) -> None:
     assert config is not None
     source = Path(args["output_dir"], "learned_embeds.bin")
     dest_dir_name = args["placeholder_token"].strip("<>")
-    destination = config.root_dir / "embeddings" / dest_dir_name
+    destination = config.root_path / "embeddings" / dest_dir_name
     os.makedirs(destination, exist_ok=True)
     logger.info(f"Training completed. Copying learned_embeds.bin into {str(destination)}")
     shutil.copy(source, destination)
@@ -369,7 +369,7 @@ def save_args(args: dict) -> None:
     Save the current argument values to an omegaconf file
     """
     assert config is not None
-    dest_dir = config.root_dir / TRAINING_DIR
+    dest_dir = config.root_path / TRAINING_DIR
     os.makedirs(dest_dir, exist_ok=True)
     conf_file = dest_dir / CONF_FILE
     conf = OmegaConf.create(args)
@@ -381,7 +381,7 @@ def previous_args() -> dict:
     Get the previous arguments used.
     """
     assert config is not None
-    conf_file = config.root_dir / TRAINING_DIR / CONF_FILE
+    conf_file = config.root_path / TRAINING_DIR / CONF_FILE
     try:
         conf = OmegaConf.load(conf_file)
         conf["placeholder_token"] = conf["placeholder_token"].strip("<>")