more explicit equality tests when overwriting

Damian Stewart 2022-12-15 13:39:09 +01:00
parent 44d8a5a7c8
commit beb1b08d9a
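
The assertions added in this diff pin down exactly which rows of the prompt-embedding tensor are allowed to change: the rows at the textual-inversion token positions must equal the injected embedding vectors, and every other row must still equal the default embeddings. A minimal, self-contained sketch of that pattern on synthetic tensors (shapes, positions, and the setup below are illustrative, not InvokeAI's TextualInversionManager API):

import torch

# Stand-ins for the tensors used in the test: a 77-token CLIP prompt, 768-dim embeddings.
default_prompt_embeddings = torch.randn(77, 768)
test_embedding_1v_1 = torch.randn(1, 768)   # a 1-vector textual inversion embedding
test_embedding_1v_2 = torch.randn(1, 768)

# Suppose the manager overwrote positions 4 and 5 (the "at the end" case:
# BOS token, three known words, then the two embedding tokens).
overwritten_prompt_embeddings = default_prompt_embeddings.clone()
overwritten_prompt_embeddings[4] = test_embedding_1v_1[0]
overwritten_prompt_embeddings[5] = test_embedding_1v_2[0]

# "More explicit" equality testing: check the untouched slices as well as the
# overwritten rows, rather than only asserting that something changed.
assert not torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings)
assert torch.equal(overwritten_prompt_embeddings[0:4], default_prompt_embeddings[0:4])
assert torch.equal(overwritten_prompt_embeddings[4], test_embedding_1v_1[0])
assert torch.equal(overwritten_prompt_embeddings[5], test_embedding_1v_2[0])
assert torch.equal(overwritten_prompt_embeddings[6:77], default_prompt_embeddings[6:77])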


@@ -319,8 +319,10 @@ class TextualInversionManagerTestCase(unittest.TestCase):
overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:4], default_prompt_embeddings[0:4]))
self.assertTrue(torch.equal(overwritten_prompt_embeddings[4], test_embedding_1v_1[0]))
self.assertTrue(torch.equal(overwritten_prompt_embeddings[5], test_embedding_1v_2[0]))
self.assertTrue(torch.equal(overwritten_prompt_embeddings[6:77], default_prompt_embeddings[6:77]))
# at the start
prompt_token_ids = [test_embedding_1v_1_token_id, test_embedding_1v_2_token_id] + KNOWN_WORDS_TOKEN_IDS
@@ -331,8 +333,10 @@ class TextualInversionManagerTestCase(unittest.TestCase):
overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:1], default_prompt_embeddings[0:1]))
self.assertTrue(torch.equal(overwritten_prompt_embeddings[1], test_embedding_1v_1[0]))
self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v_2[0]))
self.assertTrue(torch.equal(overwritten_prompt_embeddings[3:77], default_prompt_embeddings[3:77]))
# clumped in the middle
prompt_token_ids = KNOWN_WORDS_TOKEN_IDS[0:1] + [test_embedding_1v_1_token_id, test_embedding_1v_2_token_id] + KNOWN_WORDS_TOKEN_IDS[1:3]
@@ -343,12 +347,13 @@ class TextualInversionManagerTestCase(unittest.TestCase):
overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:2], default_prompt_embeddings[0:2]))
self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v_1[0]))
self.assertTrue(torch.equal(overwritten_prompt_embeddings[3], test_embedding_1v_2[0]))
self.assertTrue(torch.equal(overwritten_prompt_embeddings[4:77], default_prompt_embeddings[4:77]))
# scattered
"""
prompt_token_ids = KNOWN_WORDS_TOKEN_IDS[0:1] + [test_embedding_1v_1_token_id] + KNOWN_WORDS_TOKEN_IDS[1:2] + test_embedding_1v_2_token_id] + KNOWN_WORDS_TOKEN_IDS[2:3]
prompt_token_ids = KNOWN_WORDS_TOKEN_IDS[0:1] + [test_embedding_1v_1_token_id] + KNOWN_WORDS_TOKEN_IDS[1:2] + [test_embedding_1v_2_token_id] + KNOWN_WORDS_TOKEN_IDS[2:3]
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
expanded_prompt_token_ids + \
@@ -356,8 +361,10 @@ class TextualInversionManagerTestCase(unittest.TestCase):
overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:2], default_prompt_embeddings[0:2]))
self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v_1[0]))
self.assertTrue(torch.equal(overwritten_prompt_embeddings[3], test_embedding_1v_2[0]))
"""
self.assertTrue(torch.equal(overwritten_prompt_embeddings[3], default_prompt_embeddings[3]))
self.assertTrue(torch.equal(overwritten_prompt_embeddings[4], test_embedding_1v_2[0]))
self.assertTrue(torch.equal(overwritten_prompt_embeddings[5:77], default_prompt_embeddings[5:77]))
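
The final hunk also corrects the previously commented-out "scattered" case: with the BOS token at index 0, the padded prompt lays out as known word, embedding 1, known word, embedding 2, known word, so the second embedding lands at index 4 rather than index 3, and index 3 should still match the default embeddings. A quick sanity check of that layout (the token id values below are made up for illustration, and BOS_TOKEN_ID stands in for tim.clip_embedder.tokenizer.bos_token_id):

BOS_TOKEN_ID = 49406
KNOWN_WORDS_TOKEN_IDS = [320, 1125, 539]
test_embedding_1v_1_token_id = 49408
test_embedding_1v_2_token_id = 49409

prompt_token_ids = (KNOWN_WORDS_TOKEN_IDS[0:1] + [test_embedding_1v_1_token_id]
                    + KNOWN_WORDS_TOKEN_IDS[1:2] + [test_embedding_1v_2_token_id]
                    + KNOWN_WORDS_TOKEN_IDS[2:3])
padded_prompt_token_ids = [BOS_TOKEN_ID] + prompt_token_ids   # EOS/padding omitted

assert padded_prompt_token_ids.index(test_embedding_1v_1_token_id) == 2   # overwritten_prompt_embeddings[2]
assert padded_prompt_token_ids.index(test_embedding_1v_2_token_id) == 4   # hence the corrected assertion at index 4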