mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
more explicit equality tests when overwriting
This commit is contained in:
parent
44d8a5a7c8
commit
beb1b08d9a
@ -319,8 +319,10 @@ class TextualInversionManagerTestCase(unittest.TestCase):
|
||||
|
||||
overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
|
||||
self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
|
||||
self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:4], default_prompt_embeddings[0:4]))
|
||||
self.assertTrue(torch.equal(overwritten_prompt_embeddings[4], test_embedding_1v_1[0]))
|
||||
self.assertTrue(torch.equal(overwritten_prompt_embeddings[5], test_embedding_1v_2[0]))
|
||||
self.assertTrue(torch.equal(overwritten_prompt_embeddings[6:77], default_prompt_embeddings[6:77]))
|
||||
|
||||
# at the start
|
||||
prompt_token_ids = [test_embedding_1v_1_token_id, test_embedding_1v_2_token_id] + KNOWN_WORDS_TOKEN_IDS
|
||||
@ -331,8 +333,10 @@ class TextualInversionManagerTestCase(unittest.TestCase):
|
||||
|
||||
overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
|
||||
self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
|
||||
self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:1], default_prompt_embeddings[0:1]))
|
||||
self.assertTrue(torch.equal(overwritten_prompt_embeddings[1], test_embedding_1v_1[0]))
|
||||
self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v_2[0]))
|
||||
self.assertTrue(torch.equal(overwritten_prompt_embeddings[3:77], default_prompt_embeddings[3:77]))
|
||||
|
||||
# clumped in the middle
|
||||
prompt_token_ids = KNOWN_WORDS_TOKEN_IDS[0:1] + [test_embedding_1v_1_token_id, test_embedding_1v_2_token_id] + KNOWN_WORDS_TOKEN_IDS[1:3]
|
||||
@ -343,12 +347,13 @@ class TextualInversionManagerTestCase(unittest.TestCase):
|
||||
|
||||
overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
|
||||
self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
|
||||
self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:2], default_prompt_embeddings[0:2]))
|
||||
self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v_1[0]))
|
||||
self.assertTrue(torch.equal(overwritten_prompt_embeddings[3], test_embedding_1v_2[0]))
|
||||
self.assertTrue(torch.equal(overwritten_prompt_embeddings[4:77], default_prompt_embeddings[4:77]))
|
||||
|
||||
# scattered
|
||||
"""
|
||||
prompt_token_ids = KNOWN_WORDS_TOKEN_IDS[0:1] + [test_embedding_1v_1_token_id] + KNOWN_WORDS_TOKEN_IDS[1:2] + test_embedding_1v_2_token_id] + KNOWN_WORDS_TOKEN_IDS[2:3]
|
||||
prompt_token_ids = KNOWN_WORDS_TOKEN_IDS[0:1] + [test_embedding_1v_1_token_id] + KNOWN_WORDS_TOKEN_IDS[1:2] + [test_embedding_1v_2_token_id] + KNOWN_WORDS_TOKEN_IDS[2:3]
|
||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
|
||||
padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
|
||||
expanded_prompt_token_ids + \
|
||||
@ -356,8 +361,10 @@ class TextualInversionManagerTestCase(unittest.TestCase):
|
||||
|
||||
overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
|
||||
self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
|
||||
self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:2], default_prompt_embeddings[0:2]))
|
||||
self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v_1[0]))
|
||||
self.assertTrue(torch.equal(overwritten_prompt_embeddings[3], test_embedding_1v_2[0]))
|
||||
"""
|
||||
self.assertTrue(torch.equal(overwritten_prompt_embeddings[3], default_prompt_embeddings[3]))
|
||||
self.assertTrue(torch.equal(overwritten_prompt_embeddings[4], test_embedding_1v_2[0]))
|
||||
self.assertTrue(torch.equal(overwritten_prompt_embeddings[5:77], default_prompt_embeddings[5:77]))
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user