From 009f32ed39a7280997c3ffab112adadee0b44279 Mon Sep 17 00:00:00 2001 From: damian Date: Thu, 15 Dec 2022 21:29:47 +0100 Subject: [PATCH] unit tests passing for embeddings with vector length >1 --- backend/invoke_ai_web_server.py | 2 + ldm/modules/embedding_manager.py | 14 ++- tests/test_textual_inversion.py | 169 +++++++++++++++++++++++++++++++ 3 files changed, 182 insertions(+), 3 deletions(-) diff --git a/backend/invoke_ai_web_server.py b/backend/invoke_ai_web_server.py index d91d66e5be..654a8826a8 100644 --- a/backend/invoke_ai_web_server.py +++ b/backend/invoke_ai_web_server.py @@ -1099,6 +1099,8 @@ class InvokeAIWebServer: get_tokens_for_prompt(self.generate.model, parsed_prompt) attention_maps_image_base64_url = None if attention_maps_image is None \ else image_to_dataURL(attention_maps_image) + if attention_maps_image is not None: + attention_maps_image.save(path + '.attention.png', 'PNG') self.socketio.emit( "generationResult", diff --git a/ldm/modules/embedding_manager.py b/ldm/modules/embedding_manager.py index 9532493b70..603c23a94a 100644 --- a/ldm/modules/embedding_manager.py +++ b/ldm/modules/embedding_manager.py @@ -150,12 +150,18 @@ class TextualInversionManager(): subsequent rows in `prompt_embeddings` as well. :param `prompt_token_ids`: Prompt token ids, already expanded to account for any textual inversions with vector lenght - >1 (call `expand_textual_inversion_token_ids()` to do this) + >1 (call `expand_textual_inversion_token_ids()` to do this) and including bos and eos markers. :param `prompt_embeddings`: Prompt embeddings tensor of shape with indices aligning to token ids in `prompt_token_ids` (i.e., also already expanded). :return: `The prompt_embeddings` tensor overwritten as appropriate with the textual inversion embeddings. """ - assert prompt_embeddings.shape[0] == self.clip_embedder.max_length, f"prompt_embeddings must have 77 entries (has: {prompt_embeddings.shape[0]})" + if prompt_embeddings.shape[0] != self.clip_embedder.max_length: # typically 77 + raise ValueError(f"prompt_embeddings must have {self.clip_embedder.max_length} entries (has: {prompt_embeddings.shape[0]})") + if len(prompt_token_ids) != self.clip_embedder.max_length: + raise ValueError(f"prompt_token_ids must be fully padded out to {self.clip_embedder.max_length} entries (has: {prompt_embeddings.shape[0]})") + if prompt_token_ids[0] != self.clip_embedder.tokenizer.bos_token_id or prompt_token_ids[-1] != self.clip_embedder.tokenizer.eos_token_id: + raise ValueError("prompt_token_ids must start with with bos token id and end with the eos token id") + textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions] pad_token_id = self.clip_embedder.tokenizer.pad_token_id overwritten_prompt_embeddings = prompt_embeddings.clone() @@ -164,7 +170,9 @@ class TextualInversionManager(): continue if token_id in textual_inversion_token_ids: textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id) - for j in range(0, textual_inversion.embedding_vector_length): + end_index = min(i + textual_inversion.embedding_vector_length, self.clip_embedder.max_length-1) + count_to_overwrite = end_index - i + for j in range(0, count_to_overwrite): # only overwrite the textual inversion token id or the padding token id if prompt_token_ids[i+j] != pad_token_id and prompt_token_ids[i+j] != token_id: break diff --git a/tests/test_textual_inversion.py b/tests/test_textual_inversion.py index 0c3be656d8..30b4c6cbcd 100644 --- a/tests/test_textual_inversion.py +++ b/tests/test_textual_inversion.py @@ -269,7 +269,9 @@ class TextualInversionManagerTestCase(unittest.TestCase): overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings) self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings)) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:4], default_prompt_embeddings[0:4])) self.assertTrue(torch.equal(overwritten_prompt_embeddings[4], test_embedding_1v[0])) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[5:77], default_prompt_embeddings[5:77])) # at the start prompt_token_ids = [test_embedding_1v_token_id] + KNOWN_WORDS_TOKEN_IDS @@ -280,7 +282,9 @@ class TextualInversionManagerTestCase(unittest.TestCase): overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings) self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings)) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:1], default_prompt_embeddings[0:1])) self.assertTrue(torch.equal(overwritten_prompt_embeddings[1], test_embedding_1v[0])) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[2:77], default_prompt_embeddings[2:77])) # in the middle prompt_token_ids = KNOWN_WORDS_TOKEN_IDS[0:1] + [test_embedding_1v_token_id] + KNOWN_WORDS_TOKEN_IDS[1:3] @@ -291,7 +295,9 @@ class TextualInversionManagerTestCase(unittest.TestCase): overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings) self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings)) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:2], default_prompt_embeddings[0:2])) self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v[0])) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[3:77], default_prompt_embeddings[3:77])) @@ -367,4 +373,167 @@ class TextualInversionManagerTestCase(unittest.TestCase): self.assertTrue(torch.equal(overwritten_prompt_embeddings[4], test_embedding_1v_2[0])) self.assertTrue(torch.equal(overwritten_prompt_embeddings[5:77], default_prompt_embeddings[5:77])) + def test_overwrite_textual_inversion_4v_single(self): + tim = TextualInversionManager(DummyClipEmbedder()) + default_prompt_embeddings = torch.randn([77, 768]) + # add embedding + test_embedding_4v = torch.randn([4, 768]) + test_embedding_4v_token = "" + test_embedding_4v_token_id = tim.add_textual_inversion(test_embedding_4v_token, test_embedding_4v) + self.assertEqual(test_embedding_4v_token_id, len(KNOWN_WORDS)) + + # at the end + prompt_token_ids = KNOWN_WORDS_TOKEN_IDS + [test_embedding_4v_token_id] + expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids) + padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \ + expanded_prompt_token_ids + \ + (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id] + + overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings) + self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings)) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:4], default_prompt_embeddings[0:4])) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[4:8], test_embedding_4v)) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[8:77], default_prompt_embeddings[8:77])) + + # at the start + prompt_token_ids = [test_embedding_4v_token_id] + KNOWN_WORDS_TOKEN_IDS + expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids) + padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \ + expanded_prompt_token_ids + \ + (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id] + + overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings) + self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings)) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:1], default_prompt_embeddings[0:1])) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[1:5], test_embedding_4v)) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[5:77], default_prompt_embeddings[5:77])) + + # in the middle + prompt_token_ids = KNOWN_WORDS_TOKEN_IDS[0:1] + [test_embedding_4v_token_id] + KNOWN_WORDS_TOKEN_IDS[1:3] + expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids) + padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \ + expanded_prompt_token_ids + \ + (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id] + + overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings) + self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings)) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:2], default_prompt_embeddings[0:2])) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[2:6], test_embedding_4v)) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[6:77], default_prompt_embeddings[6:77])) + + def test_overwrite_textual_inversion_4v_overflow(self): + tim = TextualInversionManager(DummyClipEmbedder()) + default_prompt_embeddings = torch.randn([77, 768]) + + # add embedding + test_embedding_4v = torch.randn([4, 768]) + test_embedding_4v_token = "" + test_embedding_4v_token_id = tim.add_textual_inversion(test_embedding_4v_token, test_embedding_4v) + self.assertEqual(test_embedding_4v_token_id, len(KNOWN_WORDS)) + + base_prompt = KNOWN_WORDS_TOKEN_IDS * 24 + + # at the end + prompt_token_ids = base_prompt + [test_embedding_4v_token_id] + expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids) + padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \ + expanded_prompt_token_ids + \ + (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id] + + overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings) + self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings)) + base_prompt_length = len(base_prompt) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:base_prompt_length+1], default_prompt_embeddings[0:base_prompt_length+1])) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[base_prompt_length+1:base_prompt_length+1+3], test_embedding_4v[0:3])) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[base_prompt_length+1+3:77], default_prompt_embeddings[base_prompt_length+1+3:77])) + + # at the start + prompt_token_ids = [test_embedding_4v_token_id] + base_prompt + expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids) + expanded_prompt_token_ids = expanded_prompt_token_ids[0:75] + padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \ + expanded_prompt_token_ids + \ + (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id] + + overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings) + self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings)) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:1], default_prompt_embeddings[0:1])) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[1:5], test_embedding_4v)) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[5:77], default_prompt_embeddings[5:77])) + + # in the middle + prompt_token_ids = base_prompt[0:20] + [test_embedding_4v_token_id] + base_prompt[20:-1] + expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids) + padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \ + expanded_prompt_token_ids + \ + (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id] + + overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings) + self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings)) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:21], default_prompt_embeddings[0:21])) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[21:25], test_embedding_4v)) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[25:77], default_prompt_embeddings[25:77])) + + + def test_overwrite_textual_inversion_4v_multiple(self): + tim = TextualInversionManager(DummyClipEmbedder()) + default_prompt_embeddings = torch.randn([77, 768]) + + # add embedding + test_embedding_4v_1 = torch.randn([4, 768]) + test_embedding_4v_1_token = "" + test_embedding_4v_1_token_id = tim.add_textual_inversion(test_embedding_4v_1_token, test_embedding_4v_1) + self.assertEqual(test_embedding_4v_1_token_id, len(KNOWN_WORDS)) + + test_embedding_4v_2 = torch.randn([4, 768]) + test_embedding_4v_2_token = "" + test_embedding_4v_2_token_id = tim.add_textual_inversion(test_embedding_4v_2_token, test_embedding_4v_2) + self.assertEqual(test_embedding_4v_2_token_id, len(KNOWN_WORDS)+1) + + base_prompt = KNOWN_WORDS_TOKEN_IDS * 20 + + # at the end + prompt_token_ids = base_prompt + [test_embedding_4v_1_token_id] + [test_embedding_4v_2_token_id] + expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids) + padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \ + expanded_prompt_token_ids + \ + (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id] + + overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings) + self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings)) + base_prompt_length = len(base_prompt) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:base_prompt_length+1], default_prompt_embeddings[0:base_prompt_length+1])) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[base_prompt_length+1:base_prompt_length+1+4], test_embedding_4v_1)) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[base_prompt_length+1+4:base_prompt_length+1+4+4], test_embedding_4v_2)) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[base_prompt_length+1+4+4:77], default_prompt_embeddings[base_prompt_length+1+4+4:77])) + + # at the start + prompt_token_ids = [test_embedding_4v_1_token_id] + [test_embedding_4v_2_token_id] + base_prompt + expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids) + expanded_prompt_token_ids = expanded_prompt_token_ids[0:75] + padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \ + expanded_prompt_token_ids + \ + (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id] + + overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings) + self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings)) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:1], default_prompt_embeddings[0:1])) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[1:5], test_embedding_4v_1)) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[5:9], test_embedding_4v_2)) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[9:77], default_prompt_embeddings[9:77])) + + # in the middle + prompt_token_ids = base_prompt[0:10] + [test_embedding_4v_1_token_id] + base_prompt[10:20] + [test_embedding_4v_2_token_id] + base_prompt[20:-1] + expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids) + padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \ + expanded_prompt_token_ids + \ + (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id] + + overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings) + self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings)) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:11], default_prompt_embeddings[0:11])) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[11:15], test_embedding_4v_1)) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[15:25], default_prompt_embeddings[15:25])) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[25:29], test_embedding_4v_2)) + self.assertTrue(torch.equal(overwritten_prompt_embeddings[29:77], default_prompt_embeddings[29:77]))