fix crash in textual inversion with "num_samples=0" error (#2762)

-At some point pathlib was added to the list of imported modules and
this broken the os.path code that assembled the sample data set.

-Now fixed by replacing os.path calls with Path methods
This commit is contained in:
Lincoln Stein 2023-02-22 12:31:28 -05:00 committed by GitHub
commit a4afb69615
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 6 deletions

View File

@ -421,7 +421,6 @@ def do_front_end(args: Namespace):
save_args(args) save_args(args)
try: try:
print(f"DEBUG: args = {args}")
do_textual_inversion_training(**args) do_textual_inversion_training(**args)
copy_to_embeddings_folder(args) copy_to_embeddings_folder(args)
except Exception as e: except Exception as e:
@ -454,7 +453,7 @@ def main():
'** Not enough window space for the interface. Please make your window larger and try again.' '** Not enough window space for the interface. Please make your window larger and try again.'
) )
else: else:
print(f"** A layout error has occurred: {str(e)}") print(f"** An error has occurred: {str(e)}")
sys.exit(-1) sys.exit(-1)

View File

@ -430,7 +430,7 @@ class TextualInversionDataset(Dataset):
placeholder_token="*", placeholder_token="*",
center_crop=False, center_crop=False,
): ):
self.data_root = data_root self.data_root = Path(data_root)
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.learnable_property = learnable_property self.learnable_property = learnable_property
self.size = size self.size = size
@ -439,9 +439,9 @@ class TextualInversionDataset(Dataset):
self.flip_p = flip_p self.flip_p = flip_p
self.image_paths = [ self.image_paths = [
os.path.join(self.data_root, file_path) self.data_root / file_path
for file_path in os.listdir(self.data_root) for file_path in self.data_root.iterdir()
if os.path.isfile(file_path) and file_path.endswith(('.png','.PNG','.jpg','.JPG','.jpeg','.JPEG','.gif','.GIF')) if file_path.is_file() and file_path.name.endswith(('.png','.PNG','.jpg','.JPG','.jpeg','.JPEG','.gif','.GIF'))
] ]
self.num_images = len(self.image_paths) self.num_images = len(self.image_paths)