From beb1b08d9a98112ed2fe073580568e1a18698da3 Mon Sep 17 00:00:00 2001
From: Damian Stewart
Date: Thu, 15 Dec 2022 13:39:09 +0100
Subject: [PATCH] more explicit equality tests when overwriting

---
 tests/test_textual_inversion.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/tests/test_textual_inversion.py b/tests/test_textual_inversion.py
index 59d1ffa55e..0c3be656d8 100644
--- a/tests/test_textual_inversion.py
+++ b/tests/test_textual_inversion.py
@@ -319,8 +319,10 @@ class TextualInversionManagerTestCase(unittest.TestCase):
 
         overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
         self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:4], default_prompt_embeddings[0:4]))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[4], test_embedding_1v_1[0]))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[5], test_embedding_1v_2[0]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[6:77], default_prompt_embeddings[6:77]))
 
         # at the start
         prompt_token_ids = [test_embedding_1v_1_token_id, test_embedding_1v_2_token_id] + KNOWN_WORDS_TOKEN_IDS
@@ -331,8 +333,10 @@ class TextualInversionManagerTestCase(unittest.TestCase):
 
         overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
         self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:1], default_prompt_embeddings[0:1]))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[1], test_embedding_1v_1[0]))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v_2[0]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[3:77], default_prompt_embeddings[3:77]))
 
         # clumped in the middle
         prompt_token_ids = KNOWN_WORDS_TOKEN_IDS[0:1] + [test_embedding_1v_1_token_id, test_embedding_1v_2_token_id] + KNOWN_WORDS_TOKEN_IDS[1:3]
@@ -343,12 +347,13 @@ class TextualInversionManagerTestCase(unittest.TestCase):
 
         overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
         self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:2], default_prompt_embeddings[0:2]))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v_1[0]))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[3], test_embedding_1v_2[0]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[4:77], default_prompt_embeddings[4:77]))
 
         # scattered
-        """
-        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS[0:1] + [test_embedding_1v_1_token_id] + KNOWN_WORDS_TOKEN_IDS[1:2] + test_embedding_1v_2_token_id] + KNOWN_WORDS_TOKEN_IDS[2:3]
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS[0:1] + [test_embedding_1v_1_token_id] + KNOWN_WORDS_TOKEN_IDS[1:2] + [test_embedding_1v_2_token_id] + KNOWN_WORDS_TOKEN_IDS[2:3]
         expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
         padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
                                   expanded_prompt_token_ids + \
@@ -356,8 +361,10 @@ class TextualInversionManagerTestCase(unittest.TestCase):
 
         overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
         self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:2], default_prompt_embeddings[0:2]))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v_1[0]))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[3], test_embedding_1v_2[0]))
-        """
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[3], default_prompt_embeddings[3]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[4], test_embedding_1v_2[0]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[5:77], default_prompt_embeddings[5:77]))
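For reference, here is a minimal sketch of the overwrite behaviour these assertions pin down, not the actual TextualInversionManager implementation: rows of the 77-row embedding tensor whose token id has a textual-inversion embedding are replaced, and every other row (BOS, known words, EOS/padding) stays equal to the default embeddings. The token ids and the 768-dim width below are placeholder assumptions for illustration; the real test takes them from the CLIP tokenizer/embedder.

```python
import torch

# Hypothetical ids/sizes for illustration only.
BOS, EOS, KNOWN_A, KNOWN_B, KNOWN_C = 49406, 49407, 320, 321, 322
TI_1, TI_2 = 49408, 49409          # stand-ins for test_embedding_1v_*_token_id
DIM = 768

ti_embeddings = {TI_1: torch.randn(DIM), TI_2: torch.randn(DIM)}

def overwrite_rows(padded_token_ids, default_embeddings):
    """Replace only the rows whose token id has a textual-inversion embedding."""
    overwritten = default_embeddings.clone()
    for i, token_id in enumerate(padded_token_ids):
        if token_id in ti_embeddings:
            overwritten[i] = ti_embeddings[token_id]
    return overwritten

# "scattered" layout from the test: bos, known, TI_1, known, TI_2, known, padding...
padded = [BOS, KNOWN_A, TI_1, KNOWN_B, TI_2, KNOWN_C] + [EOS] * (77 - 6)
default = torch.randn(77, DIM)
result = overwrite_rows(padded, default)

assert torch.equal(result[0:2], default[0:2])        # bos + first known word untouched
assert torch.equal(result[2], ti_embeddings[TI_1])   # first TI row overwritten
assert torch.equal(result[3], default[3])            # known word between them untouched
assert torch.equal(result[4], ti_embeddings[TI_2])   # second TI row overwritten
assert torch.equal(result[5:77], default[5:77])      # trailing padding untouched
```

The new per-slice assertions in the patch follow this same pattern: rather than only checking that the overwritten tensor differs from the default, each test now verifies exactly which rows changed and that all remaining rows are byte-for-byte equal to the defaults.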