Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
wip textual inversion manager (unit tests passing for 1v embedding overwriting)
This commit is contained in:
parent 417c2b57d9
commit 44d8a5a7c8
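For orientation, "1v embedding overwriting" in the commit message means: after a prompt has been tokenized and encoded by CLIP, any position whose token id belongs to a single-vector (1v) textual inversion trigger has its embedding row replaced with the learned vector. Below is a minimal standalone sketch of that idea; the names (replace_trigger_rows, the toy token ids) are hypothetical and not part of this commit.

import torch

def replace_trigger_rows(token_ids, embeddings, learned):
    # Hypothetical helper: `learned` maps a trigger token id to a learned [768] vector.
    # Every row of `embeddings` whose token id is a trigger gets overwritten in a copy.
    out = embeddings.clone()
    for i, tid in enumerate(token_ids):
        if tid in learned:
            out[i] = learned[tid]
    return out

token_ids = [49406, 0, 1, 2, 3, 49407]          # bos, 'a', 'b', 'c', trigger, eos (toy ids)
embeddings = torch.randn([77, 768])             # stand-in for CLIP text encoder output
learned = {3: torch.randn([768])}               # one single-vector textual inversion
overwritten = replace_trigger_rows(token_ids, embeddings, learned)
assert torch.equal(overwritten[4], learned[3])  # the trigger position now holds the learned vector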
@@ -157,7 +157,8 @@ class TextualInversionManager():
         """
         assert prompt_embeddings.shape[0] == self.clip_embedder.max_length, f"prompt_embeddings must have 77 entries (has: {prompt_embeddings.shape[0]})"
         textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
-        pad_token_id = self.clip_embedder.pad_token_id
+        pad_token_id = self.clip_embedder.tokenizer.pad_token_id
+        overwritten_prompt_embeddings = prompt_embeddings.clone()
         for i, token_id in enumerate(prompt_token_ids):
             if token_id == pad_token_id:
                 continue
@@ -167,9 +168,9 @@ class TextualInversionManager():
                 # only overwrite the textual inversion token id or the padding token id
                 if prompt_token_ids[i+j] != pad_token_id and prompt_token_ids[i+j] != token_id:
                     break
-                prompt_embeddings[i+j] = textual_inversion.embedding[j]
+                overwritten_prompt_embeddings[i+j] = textual_inversion.embedding[j]
 
-        return prompt_embeddings
+        return overwritten_prompt_embeddings
 
 
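As a usage note, not part of the diff: the calling pattern the tests below exercise composes the two methods touched in this file roughly as follows. `tim`, `prompt_token_ids`, and the stand-in encoder output are assumptions made for the sketch, not code from the commit.

# Sketch only. Assumes tim = TextualInversionManager(clip_embedder) is already built
# and prompt_token_ids holds raw token ids without bos/eos/padding.
expanded_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids)
padded_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
             expanded_ids + \
             (76 - len(expanded_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
prompt_embeddings = torch.randn([77, 768])   # stand-in for the CLIP text encoder output
overwritten = tim.overwrite_textual_inversion_embeddings(padded_ids, prompt_embeddings)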
@@ -7,6 +7,7 @@ from ldm.modules.embedding_manager import TextualInversionManager
 
 
 KNOWN_WORDS = ['a', 'b', 'c']
+KNOWN_WORDS_TOKEN_IDS = [0, 1, 2]
 UNKNOWN_WORDS = ['d', 'e', 'f']
 
 class DummyEmbeddingsList(list):
@@ -139,15 +140,15 @@ class TextualInversionManagerTestCase(unittest.TestCase):
 
     def test_pad_raises_on_eos_bos(self):
         tim = TextualInversionManager(DummyClipEmbedder())
-        prompt_token_ids_with_eos_bos = [tim.clip_embedder.tokenizer.bos_token_id,
-                                         0, 1, 2,
-                                         tim.clip_embedder.tokenizer.eos_token_id]
+        prompt_token_ids_with_eos_bos = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                        KNOWN_WORDS_TOKEN_IDS + \
+                                        [tim.clip_embedder.tokenizer.eos_token_id]
         with self.assertRaises(ValueError):
             expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_with_eos_bos)
 
     def test_pad_tokens_list_vector_length_1(self):
         tim = TextualInversionManager(DummyClipEmbedder())
-        prompt_token_ids = [0, 1, 2]
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy()
 
         expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids)
         self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
@@ -174,7 +175,7 @@ class TextualInversionManagerTestCase(unittest.TestCase):
 
     def test_pad_tokens_list_vector_length_2(self):
         tim = TextualInversionManager(DummyClipEmbedder())
-        prompt_token_ids = [0, 1, 2]
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy()
 
         expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids)
         self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
@@ -202,4 +203,161 @@ class TextualInversionManagerTestCase(unittest.TestCase):
         self.assertNotEqual(prompt_token_ids_2v_insert, expanded_prompt_token_ids)
         self.assertEqual(prompt_token_ids[0:2] + [test_embedding_2v_token_id, tim.clip_embedder.tokenizer.pad_token_id] + prompt_token_ids[2:3], expanded_prompt_token_ids)
 
+    def test_pad_tokens_list_vector_length_8(self):
+        tim = TextualInversionManager(DummyClipEmbedder())
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy()
+
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids)
+        self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
+
+        test_embedding_8v = torch.randn([8, 768])
+        test_embedding_8v_token = "<inversion-trigger-vector-length-8>"
+        test_embedding_8v_token_id = tim.add_textual_inversion(test_embedding_8v_token, test_embedding_8v)
+        self.assertEqual(test_embedding_8v_token_id, len(KNOWN_WORDS))
+
+        # at the end
+        prompt_token_ids_8v_append = prompt_token_ids + [test_embedding_8v_token_id]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_8v_append)
+        self.assertNotEqual(prompt_token_ids_8v_append, expanded_prompt_token_ids)
+        self.assertEqual(prompt_token_ids + [test_embedding_8v_token_id] + [tim.clip_embedder.tokenizer.pad_token_id]*7, expanded_prompt_token_ids)
+
+        # at the start
+        prompt_token_ids_8v_prepend = [test_embedding_8v_token_id] + prompt_token_ids
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_8v_prepend)
+        self.assertNotEqual(prompt_token_ids_8v_prepend, expanded_prompt_token_ids)
+        self.assertEqual([test_embedding_8v_token_id] + [tim.clip_embedder.tokenizer.pad_token_id]*7 + prompt_token_ids, expanded_prompt_token_ids)
+
+        # in the middle
+        prompt_token_ids_8v_insert = prompt_token_ids[0:2] + [test_embedding_8v_token_id] + prompt_token_ids[2:3]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_8v_insert)
+        self.assertNotEqual(prompt_token_ids_8v_insert, expanded_prompt_token_ids)
+        self.assertEqual(prompt_token_ids[0:2] + [test_embedding_8v_token_id] + [tim.clip_embedder.tokenizer.pad_token_id]*7 + prompt_token_ids[2:3], expanded_prompt_token_ids)
+
+    def test_overwrite_textual_inversion_noop(self):
+        tim = TextualInversionManager(DummyClipEmbedder())
+        prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                           KNOWN_WORDS_TOKEN_IDS + \
+                           (77-4) * [tim.clip_embedder.tokenizer.eos_token_id]
+        prompt_embeddings = torch.randn([77, 768])
+
+        # add embedding
+        test_embedding_1v = torch.randn([1, 768])
+        test_embedding_1v_token = "<inversion-trigger-vector-length-1>"
+        test_embedding_1v_token_id = tim.add_textual_inversion(test_embedding_1v_token, test_embedding_1v)
+        self.assertEqual(test_embedding_1v_token_id, len(KNOWN_WORDS))
+
+        overwritten_embeddings = tim.overwrite_textual_inversion_embeddings(prompt_token_ids, prompt_embeddings)
+        self.assertTrue(torch.equal(prompt_embeddings, overwritten_embeddings))
+
+    def test_overwrite_textual_inversion_1v_single(self):
+        tim = TextualInversionManager(DummyClipEmbedder())
+        default_prompt_embeddings = torch.randn([77, 768])
+
+        # add embedding
+        test_embedding_1v = torch.randn([1, 768])
+        test_embedding_1v_token = "<inversion-trigger-vector-length-1>"
+        test_embedding_1v_token_id = tim.add_textual_inversion(test_embedding_1v_token, test_embedding_1v)
+        self.assertEqual(test_embedding_1v_token_id, len(KNOWN_WORDS))
+
+        # at the end
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS + [test_embedding_1v_token_id]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[4], test_embedding_1v[0]))
+
+        # at the start
+        prompt_token_ids = [test_embedding_1v_token_id] + KNOWN_WORDS_TOKEN_IDS
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[1], test_embedding_1v[0]))
+
+        # in the middle
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS[0:1] + [test_embedding_1v_token_id] + KNOWN_WORDS_TOKEN_IDS[1:3]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v[0]))
+
+    def test_overwrite_textual_inversion_1v_multiple(self):
+        tim = TextualInversionManager(DummyClipEmbedder())
+        default_prompt_embeddings = torch.randn([77, 768])
+
+        # add embeddings
+        test_embedding_1v_1 = torch.randn([1, 768])
+        test_embedding_1v_1_token = "<inversion-trigger-vector-length-1-a>"
+        test_embedding_1v_1_token_id = tim.add_textual_inversion(test_embedding_1v_1_token, test_embedding_1v_1)
+        self.assertEqual(test_embedding_1v_1_token_id, len(KNOWN_WORDS))
+
+        test_embedding_1v_2 = torch.randn([1, 768])
+        test_embedding_1v_2_token = "<inversion-trigger-vector-length-1-b>"
+        test_embedding_1v_2_token_id = tim.add_textual_inversion(test_embedding_1v_2_token, test_embedding_1v_2)
+        self.assertEqual(test_embedding_1v_2_token_id, len(KNOWN_WORDS)+1)
+
+        # at the end
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS + [test_embedding_1v_1_token_id, test_embedding_1v_2_token_id]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[4], test_embedding_1v_1[0]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[5], test_embedding_1v_2[0]))
+
+        # at the start
+        prompt_token_ids = [test_embedding_1v_1_token_id, test_embedding_1v_2_token_id] + KNOWN_WORDS_TOKEN_IDS
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[1], test_embedding_1v_1[0]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v_2[0]))
+
+        # clumped in the middle
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS[0:1] + [test_embedding_1v_1_token_id, test_embedding_1v_2_token_id] + KNOWN_WORDS_TOKEN_IDS[1:3]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v_1[0]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[3], test_embedding_1v_2[0]))
+
+        # scattered
+        """
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS[0:1] + [test_embedding_1v_1_token_id] + KNOWN_WORDS_TOKEN_IDS[1:2] + [test_embedding_1v_2_token_id] + KNOWN_WORDS_TOKEN_IDS[2:3]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v_1[0]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[3], test_embedding_1v_2[0]))
+        """
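The tests above construct TextualInversionManager around a DummyClipEmbedder that is defined earlier in the test module, outside the visible hunks. A rough reconstruction of the surface it has to expose follows, offered purely as a reading aid and an assumption, not the actual fixture; the specific token id values are guesses.

import torch

KNOWN_WORDS = ['a', 'b', 'c']              # as in the test module; mapped to token ids 0, 1, 2

class DummyTokenizer:
    # Hypothetical reconstruction: only the attributes the tests and the manager read.
    def __init__(self):
        self.bos_token_id = 49406          # assumed CLIP-style special ids; the real fixture may differ
        self.eos_token_id = 49407
        self.pad_token_id = 49407
        self.token_to_id = dict((w, i) for i, w in enumerate(KNOWN_WORDS))

class DummyClipEmbedder:
    max_length = 77                        # prompt_embeddings must have 77 rows (see the assert in the manager)
    def __init__(self):
        self.tokenizer = DummyTokenizer()

add_textual_inversion() is then expected to hand out new token ids starting at len(KNOWN_WORDS), which is why each test asserts that the returned token id equals len(KNOWN_WORDS) (or len(KNOWN_WORDS)+1 for the second trigger).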