diff --git a/ldm/modules/embedding_manager.py b/ldm/modules/embedding_manager.py
index 285ef7429f..9532493b70 100644
--- a/ldm/modules/embedding_manager.py
+++ b/ldm/modules/embedding_manager.py
@@ -157,7 +157,8 @@ class TextualInversionManager():
         """
         assert prompt_embeddings.shape[0] == self.clip_embedder.max_length, f"prompt_embeddings must have 77 entries (has: {prompt_embeddings.shape[0]})"
         textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
-        pad_token_id = self.clip_embedder.pad_token_id
+        pad_token_id = self.clip_embedder.tokenizer.pad_token_id
+        overwritten_prompt_embeddings = prompt_embeddings.clone()
         for i, token_id in enumerate(prompt_token_ids):
             if token_id == pad_token_id:
                 continue
@@ -167,9 +168,9 @@ class TextualInversionManager():
                     # only overwrite the textual inversion token id or the padding token id
                     if prompt_token_ids[i+j] != pad_token_id and prompt_token_ids[i+j] != token_id:
                         break
-                    prompt_embeddings[i+j] = textual_inversion.embedding[j]
+                    overwritten_prompt_embeddings[i+j] = textual_inversion.embedding[j]
 
-        return prompt_embeddings
+        return overwritten_prompt_embeddings
diff --git a/tests/test_textual_inversion.py b/tests/test_textual_inversion.py
index 17b4da0c88..59d1ffa55e 100644
--- a/tests/test_textual_inversion.py
+++ b/tests/test_textual_inversion.py
@@ -7,6 +7,7 @@
 from ldm.modules.embedding_manager import TextualInversionManager
 
 KNOWN_WORDS = ['a', 'b', 'c']
+KNOWN_WORDS_TOKEN_IDS = [0, 1, 2]
 UNKNOWN_WORDS = ['d', 'e', 'f']
 
 class DummyEmbeddingsList(list):
@@ -139,15 +140,15 @@ class TextualInversionManagerTestCase(unittest.TestCase):
 
     def test_pad_raises_on_eos_bos(self):
         tim = TextualInversionManager(DummyClipEmbedder())
-        prompt_token_ids_with_eos_bos = [tim.clip_embedder.tokenizer.bos_token_id,
-                                         0, 1, 2,
-                                         tim.clip_embedder.tokenizer.eos_token_id]
+        prompt_token_ids_with_eos_bos = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                        KNOWN_WORDS_TOKEN_IDS + \
+                                        [tim.clip_embedder.tokenizer.eos_token_id]
         with self.assertRaises(ValueError):
             expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_with_eos_bos)
 
     def test_pad_tokens_list_vector_length_1(self):
         tim = TextualInversionManager(DummyClipEmbedder())
-        prompt_token_ids = [0, 1, 2]
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy()
 
         expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids)
         self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
@@ -174,7 +175,7 @@ class TextualInversionManagerTestCase(unittest.TestCase):
 
     def test_pad_tokens_list_vector_length_2(self):
         tim = TextualInversionManager(DummyClipEmbedder())
-        prompt_token_ids = [0, 1, 2]
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy()
 
         expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids)
         self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
@@ -202,4 +203,161 @@ class TextualInversionManagerTestCase(unittest.TestCase):
         self.assertNotEqual(prompt_token_ids_2v_insert, expanded_prompt_token_ids)
         self.assertEqual(prompt_token_ids[0:2] + [test_embedding_2v_token_id, tim.clip_embedder.tokenizer.pad_token_id] + prompt_token_ids[2:3], expanded_prompt_token_ids)
 
+    def test_pad_tokens_list_vector_length_8(self):
+        tim = TextualInversionManager(DummyClipEmbedder())
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy()
+
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids)
+        self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
+
+        test_embedding_8v = torch.randn([8, 768])
+        test_embedding_8v_token = ""
+        test_embedding_8v_token_id = tim.add_textual_inversion(test_embedding_8v_token, test_embedding_8v)
+        self.assertEqual(test_embedding_8v_token_id, len(KNOWN_WORDS))
+
+        # at the end
+        prompt_token_ids_8v_append = prompt_token_ids + [test_embedding_8v_token_id]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_8v_append)
+        self.assertNotEqual(prompt_token_ids_8v_append, expanded_prompt_token_ids)
+        self.assertEqual(prompt_token_ids + [test_embedding_8v_token_id] + [tim.clip_embedder.tokenizer.pad_token_id]*7, expanded_prompt_token_ids)
+
+        # at the start
+        prompt_token_ids_8v_prepend = [test_embedding_8v_token_id] + prompt_token_ids
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_8v_prepend)
+        self.assertNotEqual(prompt_token_ids_8v_prepend, expanded_prompt_token_ids)
+        self.assertEqual([test_embedding_8v_token_id] + [tim.clip_embedder.tokenizer.pad_token_id]*7 + prompt_token_ids, expanded_prompt_token_ids)
+
+        # in the middle
+        prompt_token_ids_8v_insert = prompt_token_ids[0:2] + [test_embedding_8v_token_id] + prompt_token_ids[2:3]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_8v_insert)
+        self.assertNotEqual(prompt_token_ids_8v_insert, expanded_prompt_token_ids)
+        self.assertEqual(prompt_token_ids[0:2] + [test_embedding_8v_token_id] + [tim.clip_embedder.tokenizer.pad_token_id]*7 + prompt_token_ids[2:3], expanded_prompt_token_ids)
+
+
+    def test_overwrite_textual_inversion_noop(self):
+        tim = TextualInversionManager(DummyClipEmbedder())
+        prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                           KNOWN_WORDS_TOKEN_IDS + \
+                           (77-4) * [tim.clip_embedder.tokenizer.eos_token_id]
+        prompt_embeddings = torch.randn([77, 768])
+
+        # add embedding
+        test_embedding_1v = torch.randn([1, 768])
+        test_embedding_1v_token = ""
+        test_embedding_1v_token_id = tim.add_textual_inversion(test_embedding_1v_token, test_embedding_1v)
+        self.assertEqual(test_embedding_1v_token_id, len(KNOWN_WORDS))
+
+        overwritten_embeddings = tim.overwrite_textual_inversion_embeddings(prompt_token_ids, prompt_embeddings)
+        self.assertTrue(torch.equal(prompt_embeddings, overwritten_embeddings))
+
+    def test_overwrite_textual_inversion_1v_single(self):
+        tim = TextualInversionManager(DummyClipEmbedder())
+        default_prompt_embeddings = torch.randn([77, 768])
+
+        # add embedding
+        test_embedding_1v = torch.randn([1, 768])
+        test_embedding_1v_token = ""
+        test_embedding_1v_token_id = tim.add_textual_inversion(test_embedding_1v_token, test_embedding_1v)
+        self.assertEqual(test_embedding_1v_token_id, len(KNOWN_WORDS))
+
+        # at the end
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS + [test_embedding_1v_token_id]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[4], test_embedding_1v[0]))
+
+        # at the start
+        prompt_token_ids = [test_embedding_1v_token_id] + KNOWN_WORDS_TOKEN_IDS
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[1], test_embedding_1v[0]))
+
+        # in the middle
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS[0:1] + [test_embedding_1v_token_id] + KNOWN_WORDS_TOKEN_IDS[1:3]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v[0]))
+
+
+
+    def test_overwrite_textual_inversion_1v_multiple(self):
+        tim = TextualInversionManager(DummyClipEmbedder())
+        default_prompt_embeddings = torch.randn([77, 768])
+
+        # add embeddings
+        test_embedding_1v_1 = torch.randn([1, 768])
+        test_embedding_1v_1_token = ""
+        test_embedding_1v_1_token_id = tim.add_textual_inversion(test_embedding_1v_1_token, test_embedding_1v_1)
+        self.assertEqual(test_embedding_1v_1_token_id, len(KNOWN_WORDS))
+
+        test_embedding_1v_2 = torch.randn([1, 768])
+        test_embedding_1v_2_token = ""
+        test_embedding_1v_2_token_id = tim.add_textual_inversion(test_embedding_1v_2_token, test_embedding_1v_2)
+        self.assertEqual(test_embedding_1v_2_token_id, len(KNOWN_WORDS)+1)
+
+        # at the end
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS + [test_embedding_1v_1_token_id, test_embedding_1v_2_token_id]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[4], test_embedding_1v_1[0]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[5], test_embedding_1v_2[0]))
+
+        # at the start
+        prompt_token_ids = [test_embedding_1v_1_token_id, test_embedding_1v_2_token_id] + KNOWN_WORDS_TOKEN_IDS
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[1], test_embedding_1v_1[0]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v_2[0]))
+
+        # clumped in the middle
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS[0:1] + [test_embedding_1v_1_token_id, test_embedding_1v_2_token_id] + KNOWN_WORDS_TOKEN_IDS[1:3]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v_1[0]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[3], test_embedding_1v_2[0]))
+
+        # scattered
+        """
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS[0:1] + [test_embedding_1v_1_token_id] + KNOWN_WORDS_TOKEN_IDS[1:2] + [test_embedding_1v_2_token_id] + KNOWN_WORDS_TOKEN_IDS[2:3]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v_1[0]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[4], test_embedding_1v_2[0]))
+        """
+