wip textual inversion manager (unit tests passing for 1v embedding overwriting)

Damian Stewart 2022-12-15 13:30:13 +01:00
parent 417c2b57d9
commit 44d8a5a7c8
2 changed files with 167 additions and 8 deletions

View File

@@ -157,7 +157,8 @@ class TextualInversionManager():
         """
         assert prompt_embeddings.shape[0] == self.clip_embedder.max_length, f"prompt_embeddings must have 77 entries (has: {prompt_embeddings.shape[0]})"
         textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
-        pad_token_id = self.clip_embedder.pad_token_id
+        pad_token_id = self.clip_embedder.tokenizer.pad_token_id
+        overwritten_prompt_embeddings = prompt_embeddings.clone()
         for i, token_id in enumerate(prompt_token_ids):
             if token_id == pad_token_id:
                 continue
@@ -167,9 +168,9 @@ class TextualInversionManager():
                 # only overwrite the textual inversion token id or the padding token id
                 if prompt_token_ids[i+j] != pad_token_id and prompt_token_ids[i+j] != token_id:
                     break
-                prompt_embeddings[i+j] = textual_inversion.embedding[j]
+                overwritten_prompt_embeddings[i+j] = textual_inversion.embedding[j]
-        return prompt_embeddings
+        return overwritten_prompt_embeddings
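Together with the tests in the second file, the changed method implies a three-step flow for applying textual inversions: expand the prompt token ids so each trigger is followed by pad token ids, pad the sequence to the embedder's max length of 77, then overwrite the matching rows of the 77x768 prompt embeddings. A minimal usage sketch under those assumptions (clip_embedder and the example token ids are placeholders, not part of this commit):

import torch
from ldm.modules.embedding_manager import TextualInversionManager

tim = TextualInversionManager(clip_embedder)  # placeholder: a CLIP-style embedder with .tokenizer and .max_length == 77
trigger_id = tim.add_textual_inversion("<my-trigger>", torch.randn([2, 768]))  # a 2-vector embedding

# 1. expand: each extra vector of the trigger gets a pad token id after it
expanded = tim.expand_textual_inversion_token_ids(prompt_token_ids=[0, 1, trigger_id, 2])
# expected: [0, 1, trigger_id, pad_token_id, 2]

# 2. pad to 77 entries: BOS + expanded prompt + EOS padding
tok = tim.clip_embedder.tokenizer
padded = [tok.bos_token_id] + expanded + (76 - len(expanded)) * [tok.eos_token_id]

# 3. overwrite: rows at the trigger/pad positions are replaced by the learned vectors
overwritten = tim.overwrite_textual_inversion_embeddings(padded, torch.randn([77, 768]))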

View File

@@ -7,6 +7,7 @@ from ldm.modules.embedding_manager import TextualInversionManager
 KNOWN_WORDS = ['a', 'b', 'c']
+KNOWN_WORDS_TOKEN_IDS = [0, 1, 2]
 UNKNOWN_WORDS = ['d', 'e', 'f']

 class DummyEmbeddingsList(list):
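The DummyClipEmbedder and DummyEmbeddingsList fixtures themselves sit outside the changed hunks. From the way the new tests use the embedder, it must expose max_length == 77 and a tokenizer carrying bos_token_id, eos_token_id and pad_token_id, and must map KNOWN_WORDS to token ids 0 through 2 so that newly added triggers receive ids starting at len(KNOWN_WORDS). A minimal sketch of that assumed interface (the id values are placeholders, not the repo's actual fixture):

class DummyTokenizer:
    # placeholder ids, chosen well clear of KNOWN_WORDS_TOKEN_IDS and trigger ids
    bos_token_id = 49406
    eos_token_id = 49407
    pad_token_id = 49408

class DummyClipEmbedder:
    max_length = 77  # matches the shape assert in overwrite_textual_inversion_embeddings

    def __init__(self):
        self.tokenizer = DummyTokenizer()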
@@ -139,15 +140,15 @@ class TextualInversionManagerTestCase(unittest.TestCase):
     def test_pad_raises_on_eos_bos(self):
         tim = TextualInversionManager(DummyClipEmbedder())
-        prompt_token_ids_with_eos_bos = [tim.clip_embedder.tokenizer.bos_token_id,
-                                         0, 1, 2,
-                                         tim.clip_embedder.tokenizer.eos_token_id]
+        prompt_token_ids_with_eos_bos = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                        KNOWN_WORDS_TOKEN_IDS + \
+                                        [tim.clip_embedder.tokenizer.eos_token_id]
         with self.assertRaises(ValueError):
             expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_with_eos_bos)

     def test_pad_tokens_list_vector_length_1(self):
         tim = TextualInversionManager(DummyClipEmbedder())
-        prompt_token_ids = [0, 1, 2]
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy()
         expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids)
         self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
@@ -174,7 +175,7 @@ class TextualInversionManagerTestCase(unittest.TestCase):
     def test_pad_tokens_list_vector_length_2(self):
         tim = TextualInversionManager(DummyClipEmbedder())
-        prompt_token_ids = [0, 1, 2]
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy()
         expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids)
         self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
@@ -202,4 +203,161 @@ class TextualInversionManagerTestCase(unittest.TestCase):
         self.assertNotEqual(prompt_token_ids_2v_insert, expanded_prompt_token_ids)
         self.assertEqual(prompt_token_ids[0:2] + [test_embedding_2v_token_id, tim.clip_embedder.tokenizer.pad_token_id] + prompt_token_ids[2:3], expanded_prompt_token_ids)
+
+    def test_pad_tokens_list_vector_length_8(self):
+        tim = TextualInversionManager(DummyClipEmbedder())
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy()
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids)
+        self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
+
+        test_embedding_8v = torch.randn([8, 768])
+        test_embedding_8v_token = "<inversion-trigger-vector-length-8>"
+        test_embedding_8v_token_id = tim.add_textual_inversion(test_embedding_8v_token, test_embedding_8v)
+        self.assertEqual(test_embedding_8v_token_id, len(KNOWN_WORDS))
+
+        # at the end
+        prompt_token_ids_8v_append = prompt_token_ids + [test_embedding_8v_token_id]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_8v_append)
+        self.assertNotEqual(prompt_token_ids_8v_append, expanded_prompt_token_ids)
+        self.assertEqual(prompt_token_ids + [test_embedding_8v_token_id] + [tim.clip_embedder.tokenizer.pad_token_id]*7, expanded_prompt_token_ids)
+
+        # at the start
+        prompt_token_ids_8v_prepend = [test_embedding_8v_token_id] + prompt_token_ids
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_8v_prepend)
+        self.assertNotEqual(prompt_token_ids_8v_prepend, expanded_prompt_token_ids)
+        self.assertEqual([test_embedding_8v_token_id] + [tim.clip_embedder.tokenizer.pad_token_id]*7 + prompt_token_ids, expanded_prompt_token_ids)
+
+        # in the middle
+        prompt_token_ids_8v_insert = prompt_token_ids[0:2] + [test_embedding_8v_token_id] + prompt_token_ids[2:3]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_8v_insert)
+        self.assertNotEqual(prompt_token_ids_8v_insert, expanded_prompt_token_ids)
+        self.assertEqual(prompt_token_ids[0:2] + [test_embedding_8v_token_id] + [tim.clip_embedder.tokenizer.pad_token_id]*7 + prompt_token_ids[2:3], expanded_prompt_token_ids)
+
+    def test_overwrite_textual_inversion_noop(self):
+        tim = TextualInversionManager(DummyClipEmbedder())
+        prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                           KNOWN_WORDS_TOKEN_IDS + \
+                           (77-4) * [tim.clip_embedder.tokenizer.eos_token_id]
+        prompt_embeddings = torch.randn([77, 768])
+
+        # add embedding
+        test_embedding_1v = torch.randn([1, 768])
+        test_embedding_1v_token = "<inversion-trigger-vector-length-1>"
+        test_embedding_1v_token_id = tim.add_textual_inversion(test_embedding_1v_token, test_embedding_1v)
+        self.assertEqual(test_embedding_1v_token_id, len(KNOWN_WORDS))
+
+        overwritten_embeddings = tim.overwrite_textual_inversion_embeddings(prompt_token_ids, prompt_embeddings)
+        self.assertTrue(torch.equal(prompt_embeddings, overwritten_embeddings))
+
+    def test_overwrite_textual_inversion_1v_single(self):
+        tim = TextualInversionManager(DummyClipEmbedder())
+        default_prompt_embeddings = torch.randn([77, 768])
+
+        # add embedding
+        test_embedding_1v = torch.randn([1, 768])
+        test_embedding_1v_token = "<inversion-trigger-vector-length-1>"
+        test_embedding_1v_token_id = tim.add_textual_inversion(test_embedding_1v_token, test_embedding_1v)
+        self.assertEqual(test_embedding_1v_token_id, len(KNOWN_WORDS))
+
+        # at the end
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS + [test_embedding_1v_token_id]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[4], test_embedding_1v[0]))
+
+        # at the start
+        prompt_token_ids = [test_embedding_1v_token_id] + KNOWN_WORDS_TOKEN_IDS
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[1], test_embedding_1v[0]))
+
+        # in the middle
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS[0:1] + [test_embedding_1v_token_id] + KNOWN_WORDS_TOKEN_IDS[1:3]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v[0]))
+
+    def test_overwrite_textual_inversion_1v_multiple(self):
+        tim = TextualInversionManager(DummyClipEmbedder())
+        default_prompt_embeddings = torch.randn([77, 768])
+
+        # add embeddings
+        test_embedding_1v_1 = torch.randn([1, 768])
+        test_embedding_1v_1_token = "<inversion-trigger-vector-length-1-a>"
+        test_embedding_1v_1_token_id = tim.add_textual_inversion(test_embedding_1v_1_token, test_embedding_1v_1)
+        self.assertEqual(test_embedding_1v_1_token_id, len(KNOWN_WORDS))
+
+        test_embedding_1v_2 = torch.randn([1, 768])
+        test_embedding_1v_2_token = "<inversion-trigger-vector-length-1-b>"
+        test_embedding_1v_2_token_id = tim.add_textual_inversion(test_embedding_1v_2_token, test_embedding_1v_2)
+        self.assertEqual(test_embedding_1v_2_token_id, len(KNOWN_WORDS)+1)
+
+        # at the end
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS + [test_embedding_1v_1_token_id, test_embedding_1v_2_token_id]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[4], test_embedding_1v_1[0]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[5], test_embedding_1v_2[0]))
+
+        # at the start
+        prompt_token_ids = [test_embedding_1v_1_token_id, test_embedding_1v_2_token_id] + KNOWN_WORDS_TOKEN_IDS
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[1], test_embedding_1v_1[0]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v_2[0]))
+
+        # clumped in the middle
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS[0:1] + [test_embedding_1v_1_token_id, test_embedding_1v_2_token_id] + KNOWN_WORDS_TOKEN_IDS[1:3]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v_1[0]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[3], test_embedding_1v_2[0]))
+
+        # scattered
+        """
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS[0:1] + [test_embedding_1v_1_token_id] + KNOWN_WORDS_TOKEN_IDS[1:2] + [test_embedding_1v_2_token_id] + KNOWN_WORDS_TOKEN_IDS[2:3]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v_1[0]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[4], test_embedding_1v_2[0]))
+        """