Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit 009f32ed39, parent beb1b08d9a:

unit tests passing for embeddings with vector length >1
@@ -1099,6 +1099,8 @@ class InvokeAIWebServer:
             get_tokens_for_prompt(self.generate.model, parsed_prompt)
             attention_maps_image_base64_url = None if attention_maps_image is None \
                 else image_to_dataURL(attention_maps_image)
+            if attention_maps_image is not None:
+                attention_maps_image.save(path + '.attention.png', 'PNG')

             self.socketio.emit(
                 "generationResult",
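For reference, image_to_dataURL here turns the PIL attention-maps image into something the web frontend can render inline. A minimal sketch of that conversion, assuming the usual base64-encoded PNG data URL (not necessarily the repo's exact implementation):

    import base64
    import io

    from PIL import Image

    def image_to_dataURL_sketch(image: Image.Image) -> str:
        # serialize the image to PNG bytes in memory
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        # base64-encode and wrap as a data URL the browser can display directly
        encoded = base64.b64encode(buffer.getvalue()).decode('ascii')
        return 'data:image/png;base64,' + encoded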
@@ -150,12 +150,18 @@ class TextualInversionManager():
         subsequent rows in `prompt_embeddings` as well.

         :param `prompt_token_ids`: Prompt token ids, already expanded to account for any textual inversions with vector length
-            >1 (call `expand_textual_inversion_token_ids()` to do this)
+            >1 (call `expand_textual_inversion_token_ids()` to do this) and including bos and eos markers.
         :param `prompt_embeddings`: Prompt embeddings tensor with indices aligning to token ids in
             `prompt_token_ids` (i.e., also already expanded).
         :return: The `prompt_embeddings` tensor overwritten as appropriate with the textual inversion embeddings.
         """
-        assert prompt_embeddings.shape[0] == self.clip_embedder.max_length, f"prompt_embeddings must have 77 entries (has: {prompt_embeddings.shape[0]})"
+        if prompt_embeddings.shape[0] != self.clip_embedder.max_length:  # typically 77
+            raise ValueError(f"prompt_embeddings must have {self.clip_embedder.max_length} entries (has: {prompt_embeddings.shape[0]})")
+        if len(prompt_token_ids) != self.clip_embedder.max_length:
+            raise ValueError(f"prompt_token_ids must be fully padded out to {self.clip_embedder.max_length} entries (has: {len(prompt_token_ids)})")
+        if prompt_token_ids[0] != self.clip_embedder.tokenizer.bos_token_id or prompt_token_ids[-1] != self.clip_embedder.tokenizer.eos_token_id:
+            raise ValueError("prompt_token_ids must start with the bos token id and end with the eos token id")

         textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
         pad_token_id = self.clip_embedder.tokenizer.pad_token_id
         overwritten_prompt_embeddings = prompt_embeddings.clone()
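The three new checks pin down the call contract: token ids and embeddings must both already be expanded and padded to the tokenizer's full sequence length, bracketed by bos and eos. A minimal sketch of input that satisfies them, assuming the CLIP-typical max_length of 77 and embedding width of 768 used in the tests below (the ids are illustrative, not real vocabulary entries):

    import torch

    max_length = 77                    # tim.clip_embedder.max_length
    bos_id, eos_id = 49406, 49407      # typical CLIP values; read them off the tokenizer in practice

    expanded_prompt_token_ids = [320, 1125, 539]   # already-expanded prompt tokens (illustrative)
    padded_prompt_token_ids = [bos_id] + expanded_prompt_token_ids + \
        (max_length - 1 - len(expanded_prompt_token_ids)) * [eos_id]
    prompt_embeddings = torch.randn([max_length, 768])

    assert len(padded_prompt_token_ids) == max_length
    assert padded_prompt_token_ids[0] == bos_id and padded_prompt_token_ids[-1] == eos_id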
@@ -164,7 +170,9 @@ class TextualInversionManager():
                 continue
             if token_id in textual_inversion_token_ids:
                 textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
-                for j in range(0, textual_inversion.embedding_vector_length):
+                end_index = min(i + textual_inversion.embedding_vector_length, self.clip_embedder.max_length - 1)
+                count_to_overwrite = end_index - i
+                for j in range(0, count_to_overwrite):
                     # only overwrite the textual inversion token id or the padding token id
                     if prompt_token_ids[i+j] != pad_token_id and prompt_token_ids[i+j] != token_id:
                         break
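This clamp is what makes the overflow tests below pass: a multi-vector embedding near the end of the prompt is truncated rather than allowed to spill into the final slot. A worked example with the numbers from test_overwrite_textual_inversion_4v_overflow:

    # 4-vector trigger starting at row 73 of a 77-row sequence
    max_length = 77
    i = 73                          # bos (1) + 72 base prompt tokens
    embedding_vector_length = 4
    end_index = min(i + embedding_vector_length, max_length - 1)   # min(77, 76) -> 76
    count_to_overwrite = end_index - i                             # 3: only the first 3 of 4 vectors fit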
@@ -269,7 +269,9 @@ class TextualInversionManagerTestCase(unittest.TestCase):

         overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
         self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:4], default_prompt_embeddings[0:4]))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[4], test_embedding_1v[0]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[5:77], default_prompt_embeddings[5:77]))

         # at the start
         prompt_token_ids = [test_embedding_1v_token_id] + KNOWN_WORDS_TOKEN_IDS
@@ -280,7 +282,9 @@ class TextualInversionManagerTestCase(unittest.TestCase):

         overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
         self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:1], default_prompt_embeddings[0:1]))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[1], test_embedding_1v[0]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[2:77], default_prompt_embeddings[2:77]))

         # in the middle
         prompt_token_ids = KNOWN_WORDS_TOKEN_IDS[0:1] + [test_embedding_1v_token_id] + KNOWN_WORDS_TOKEN_IDS[1:3]
@@ -291,7 +295,9 @@ class TextualInversionManagerTestCase(unittest.TestCase):

         overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
         self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:2], default_prompt_embeddings[0:2]))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v[0]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[3:77], default_prompt_embeddings[3:77]))


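Each of the three hunks above adds the same pair of bracket checks: rows before and after the overwritten slot must be untouched. A hypothetical helper (an illustration, not part of this commit) would state that invariant once:

    import torch

    def assert_unchanged_outside(testcase, overwritten, original, start, end, max_length=77):
        # everything before the overwritten span is bit-identical ...
        testcase.assertTrue(torch.equal(overwritten[0:start], original[0:start]))
        # ... and so is everything after it
        testcase.assertTrue(torch.equal(overwritten[end:max_length], original[end:max_length]))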
@@ -367,4 +373,167 @@ class TextualInversionManagerTestCase(unittest.TestCase):
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[4], test_embedding_1v_2[0]))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[5:77], default_prompt_embeddings[5:77]))

+    def test_overwrite_textual_inversion_4v_single(self):
+        tim = TextualInversionManager(DummyClipEmbedder())
+        default_prompt_embeddings = torch.randn([77, 768])
+
+        # add embedding
+        test_embedding_4v = torch.randn([4, 768])
+        test_embedding_4v_token = "<inversion-trigger-vector-length-4>"
+        test_embedding_4v_token_id = tim.add_textual_inversion(test_embedding_4v_token, test_embedding_4v)
+        self.assertEqual(test_embedding_4v_token_id, len(KNOWN_WORDS))
+
+        # at the end
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS + [test_embedding_4v_token_id]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:4], default_prompt_embeddings[0:4]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[4:8], test_embedding_4v))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[8:77], default_prompt_embeddings[8:77]))
+
+        # at the start
+        prompt_token_ids = [test_embedding_4v_token_id] + KNOWN_WORDS_TOKEN_IDS
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:1], default_prompt_embeddings[0:1]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[1:5], test_embedding_4v))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[5:77], default_prompt_embeddings[5:77]))
+
+        # in the middle
+        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS[0:1] + [test_embedding_4v_token_id] + KNOWN_WORDS_TOKEN_IDS[1:3]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:2], default_prompt_embeddings[0:2]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[2:6], test_embedding_4v))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[6:77], default_prompt_embeddings[6:77]))
+
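The bos/eos padding expression recurs in every test case; a hypothetical helper (an illustration, not part of the commit) captures what it builds:

    def pad_to_clip_length(token_ids, bos_id, eos_id, max_length=77):
        # bos, then the expanded prompt, then eos-padding out to exactly max_length slots
        return [bos_id] + token_ids + (max_length - 1 - len(token_ids)) * [eos_id]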
+    def test_overwrite_textual_inversion_4v_overflow(self):
+        tim = TextualInversionManager(DummyClipEmbedder())
+        default_prompt_embeddings = torch.randn([77, 768])
+
+        # add embedding
+        test_embedding_4v = torch.randn([4, 768])
+        test_embedding_4v_token = "<inversion-trigger-vector-length-4>"
+        test_embedding_4v_token_id = tim.add_textual_inversion(test_embedding_4v_token, test_embedding_4v)
+        self.assertEqual(test_embedding_4v_token_id, len(KNOWN_WORDS))
+
+        base_prompt = KNOWN_WORDS_TOKEN_IDS * 24
+
+        # at the end
+        prompt_token_ids = base_prompt + [test_embedding_4v_token_id]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        base_prompt_length = len(base_prompt)
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:base_prompt_length+1], default_prompt_embeddings[0:base_prompt_length+1]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[base_prompt_length+1:base_prompt_length+1+3], test_embedding_4v[0:3]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[base_prompt_length+1+3:77], default_prompt_embeddings[base_prompt_length+1+3:77]))
+
+        # at the start
+        prompt_token_ids = [test_embedding_4v_token_id] + base_prompt
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        expanded_prompt_token_ids = expanded_prompt_token_ids[0:75]
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:1], default_prompt_embeddings[0:1]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[1:5], test_embedding_4v))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[5:77], default_prompt_embeddings[5:77]))
+
+        # in the middle
+        prompt_token_ids = base_prompt[0:20] + [test_embedding_4v_token_id] + base_prompt[20:-1]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:21], default_prompt_embeddings[0:21]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[21:25], test_embedding_4v))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[25:77], default_prompt_embeddings[25:77]))
+
+
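expand_textual_inversion_token_ids is exercised in every case above but its body is not part of this diff. Given that the overwrite loop only touches slots holding either the trigger id or the pad id, a plausible sketch of the expansion (an assumption about the implementation, not the actual code) is:

    def expand_textual_inversion_token_ids_sketch(token_ids, textual_inversions, pad_token_id):
        # assumption: each trigger id keeps its slot and claims
        # (embedding_vector_length - 1) pad slots after it
        expanded = []
        for token_id in token_ids:
            expanded.append(token_id)
            ti = next((ti for ti in textual_inversions if ti.token_id == token_id), None)
            if ti is not None:
                expanded.extend([pad_token_id] * (ti.embedding_vector_length - 1))
        return expanded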
+    def test_overwrite_textual_inversion_4v_multiple(self):
+        tim = TextualInversionManager(DummyClipEmbedder())
+        default_prompt_embeddings = torch.randn([77, 768])
+
+        # add embedding
+        test_embedding_4v_1 = torch.randn([4, 768])
+        test_embedding_4v_1_token = "<inversion-trigger-vector-length-4-a>"
+        test_embedding_4v_1_token_id = tim.add_textual_inversion(test_embedding_4v_1_token, test_embedding_4v_1)
+        self.assertEqual(test_embedding_4v_1_token_id, len(KNOWN_WORDS))
+
+        test_embedding_4v_2 = torch.randn([4, 768])
+        test_embedding_4v_2_token = "<inversion-trigger-vector-length-4-b>"
+        test_embedding_4v_2_token_id = tim.add_textual_inversion(test_embedding_4v_2_token, test_embedding_4v_2)
+        self.assertEqual(test_embedding_4v_2_token_id, len(KNOWN_WORDS)+1)
+
+        base_prompt = KNOWN_WORDS_TOKEN_IDS * 20
+
+        # at the end
+        prompt_token_ids = base_prompt + [test_embedding_4v_1_token_id] + [test_embedding_4v_2_token_id]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        base_prompt_length = len(base_prompt)
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:base_prompt_length+1], default_prompt_embeddings[0:base_prompt_length+1]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[base_prompt_length+1:base_prompt_length+1+4], test_embedding_4v_1))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[base_prompt_length+1+4:base_prompt_length+1+4+4], test_embedding_4v_2))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[base_prompt_length+1+4+4:77], default_prompt_embeddings[base_prompt_length+1+4+4:77]))
+
+        # at the start
+        prompt_token_ids = [test_embedding_4v_1_token_id] + [test_embedding_4v_2_token_id] + base_prompt
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        expanded_prompt_token_ids = expanded_prompt_token_ids[0:75]
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:1], default_prompt_embeddings[0:1]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[1:5], test_embedding_4v_1))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[5:9], test_embedding_4v_2))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[9:77], default_prompt_embeddings[9:77]))
+
+        # in the middle
+        prompt_token_ids = base_prompt[0:10] + [test_embedding_4v_1_token_id] + base_prompt[10:20] + [test_embedding_4v_2_token_id] + base_prompt[20:-1]
+        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids)
+        padded_prompt_token_ids = [tim.clip_embedder.tokenizer.bos_token_id] + \
+                                  expanded_prompt_token_ids + \
+                                  (76 - len(expanded_prompt_token_ids)) * [tim.clip_embedder.tokenizer.eos_token_id]
+
+        overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
+        self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:11], default_prompt_embeddings[0:11]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[11:15], test_embedding_4v_1))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[15:25], default_prompt_embeddings[15:25]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[25:29], test_embedding_4v_2))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[29:77], default_prompt_embeddings[29:77]))
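The slice bounds in the "in the middle" case follow directly from the layout. A short check of the arithmetic, restating the test's own numbers:

    # bos(1) + 10 base tokens -> trigger 1 occupies rows 11:15;
    # 10 more base tokens -> trigger 2 occupies rows 25:29.
    start_1 = 1 + 10            # 11
    end_1 = start_1 + 4         # 15
    start_2 = end_1 + 10        # 25
    end_2 = start_2 + 4         # 29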