From 8b7925edf3b54bcffecba1ead70af0c2f8017b0b Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 22 Feb 2023 11:29:30 -0500 Subject: [PATCH] fix crash in textual inversion with "num_samples=0" error -At some point pathlib was added to the list of imported modules and this broken the os.path code that assembled the sample data set. -Now fixed by replacing os.path calls with Path methods --- ldm/invoke/training/textual_inversion.py | 3 +-- ldm/invoke/training/textual_inversion_training.py | 8 ++++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/ldm/invoke/training/textual_inversion.py b/ldm/invoke/training/textual_inversion.py index 7ea7970ecf..2961e4d99c 100755 --- a/ldm/invoke/training/textual_inversion.py +++ b/ldm/invoke/training/textual_inversion.py @@ -421,7 +421,6 @@ def do_front_end(args: Namespace): save_args(args) try: - print(f"DEBUG: args = {args}") do_textual_inversion_training(**args) copy_to_embeddings_folder(args) except Exception as e: @@ -454,7 +453,7 @@ def main(): '** Not enough window space for the interface. Please make your window larger and try again.' ) else: - print(f"** A layout error has occurred: {str(e)}") + print(f"** An error has occurred: {str(e)}") sys.exit(-1) diff --git a/ldm/invoke/training/textual_inversion_training.py b/ldm/invoke/training/textual_inversion_training.py index 0c781519af..58c67b2ca8 100644 --- a/ldm/invoke/training/textual_inversion_training.py +++ b/ldm/invoke/training/textual_inversion_training.py @@ -430,7 +430,7 @@ class TextualInversionDataset(Dataset): placeholder_token="*", center_crop=False, ): - self.data_root = data_root + self.data_root = Path(data_root) self.tokenizer = tokenizer self.learnable_property = learnable_property self.size = size @@ -439,9 +439,9 @@ class TextualInversionDataset(Dataset): self.flip_p = flip_p self.image_paths = [ - os.path.join(self.data_root, file_path) - for file_path in os.listdir(self.data_root) - if os.path.isfile(file_path) and file_path.endswith(('.png','.PNG','.jpg','.JPG','.jpeg','.JPEG','.gif','.GIF')) + self.data_root / file_path + for file_path in self.data_root.iterdir() + if file_path.is_file() and file_path.name.endswith(('.png','.PNG','.jpg','.JPG','.jpeg','.JPEG','.gif','.GIF')) ] self.num_images = len(self.image_paths)