update textual inversion training to use root_path rather than root_dir

This commit is contained in:
Lincoln Stein 2024-03-16 22:50:12 -04:00 committed by psychedelicious
parent 5d16a40b95
commit f1450c2c24

View File

@ -123,7 +123,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
value=str(
saved_args.get(
"train_data_dir",
config.root_dir / TRAINING_DATA / default_placeholder_token,
config.root_path / TRAINING_DATA / default_placeholder_token,
)
),
scroll_exit=True,
@ -136,7 +136,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
value=str(
saved_args.get(
"output_dir",
config.root_dir / TRAINING_DIR / default_placeholder_token,
config.root_path / TRAINING_DIR / default_placeholder_token,
)
),
scroll_exit=True,
@ -241,8 +241,8 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
def initializer_changed(self) -> None:
placeholder = self.placeholder_token.value
self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)"
self.train_data_dir.value = str(config.root_dir / TRAINING_DATA / placeholder)
self.output_dir.value = str(config.root_dir / TRAINING_DIR / placeholder)
self.train_data_dir.value = str(config.root_path / TRAINING_DATA / placeholder)
self.output_dir.value = str(config.root_path / TRAINING_DIR / placeholder)
self.resume_from_checkpoint.value = Path(self.output_dir.value).exists()
def on_ok(self):
@ -354,7 +354,7 @@ def copy_to_embeddings_folder(args: Dict[str, str]) -> None:
assert config is not None
source = Path(args["output_dir"], "learned_embeds.bin")
dest_dir_name = args["placeholder_token"].strip("<>")
destination = config.root_dir / "embeddings" / dest_dir_name
destination = config.root_path / "embeddings" / dest_dir_name
os.makedirs(destination, exist_ok=True)
logger.info(f"Training completed. Copying learned_embeds.bin into {str(destination)}")
shutil.copy(source, destination)
@ -369,7 +369,7 @@ def save_args(args: dict) -> None:
Save the current argument values to an omegaconf file
"""
assert config is not None
dest_dir = config.root_dir / TRAINING_DIR
dest_dir = config.root_path / TRAINING_DIR
os.makedirs(dest_dir, exist_ok=True)
conf_file = dest_dir / CONF_FILE
conf = OmegaConf.create(args)
@ -381,7 +381,7 @@ def previous_args() -> dict:
Get the previous arguments used.
"""
assert config is not None
conf_file = config.root_dir / TRAINING_DIR / CONF_FILE
conf_file = config.root_path / TRAINING_DIR / CONF_FILE
try:
conf = OmegaConf.load(conf_file)
conf["placeholder_token"] = conf["placeholder_token"].strip("<>")