mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
wip textual inversion manager (unit tests passing for base stuff + padding)
This commit is contained in:
parent
2e80872e3b
commit
417c2b57d9
@ -15,23 +15,9 @@ from picklescan.scanner import scan_file_path
|
|||||||
PROGRESSIVE_SCALE = 2000
|
PROGRESSIVE_SCALE = 2000
|
||||||
|
|
||||||
|
|
||||||
def get_clip_token_id_for_string(tokenizer, string):
|
def get_clip_token_id_for_string(tokenizer: CLIPTokenizer, token_str: str):
|
||||||
batch_encoding = tokenizer(
|
token_id = tokenizer.convert_tokens_to_ids(token_str)
|
||||||
string,
|
return token_id
|
||||||
truncation=True,
|
|
||||||
max_length=77,
|
|
||||||
return_length=True,
|
|
||||||
return_overflowing_tokens=False,
|
|
||||||
padding='max_length',
|
|
||||||
return_tensors='pt',
|
|
||||||
)
|
|
||||||
tokens = batch_encoding['input_ids']
|
|
||||||
assert (
|
|
||||||
torch.count_nonzero(tokens - 49407) == 2
|
|
||||||
), f"String '{string}' maps to more than a single token. Please use another string"
|
|
||||||
|
|
||||||
return tokens[0, 1]
|
|
||||||
|
|
||||||
|
|
||||||
def get_bert_token_for_string(tokenizer, string):
|
def get_bert_token_for_string(tokenizer, string):
|
||||||
token = tokenizer(string)
|
token = tokenizer(string)
|
||||||
@ -47,7 +33,7 @@ def get_embedding_for_clip_token(embedder, token):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TextualInversion:
|
class TextualInversion:
|
||||||
token_string: str
|
trigger_string: str
|
||||||
token_id: int
|
token_id: int
|
||||||
embedding: torch.Tensor
|
embedding: torch.Tensor
|
||||||
|
|
||||||
@ -58,8 +44,8 @@ class TextualInversion:
|
|||||||
class TextualInversionManager():
|
class TextualInversionManager():
|
||||||
def __init__(self, clip_embedder):
|
def __init__(self, clip_embedder):
|
||||||
self.clip_embedder = clip_embedder
|
self.clip_embedder = clip_embedder
|
||||||
defatul_textual_inversions: list[TextualInversion] = []
|
default_textual_inversions: list[TextualInversion] = []
|
||||||
self.textual_inversions = defatul_textual_inversions
|
self.textual_inversions = default_textual_inversions
|
||||||
|
|
||||||
def load_textual_inversion(self, ckpt_path, full_precision=True):
|
def load_textual_inversion(self, ckpt_path, full_precision=True):
|
||||||
|
|
||||||
@ -89,21 +75,23 @@ class TextualInversionManager():
|
|||||||
embedding = ckpt[token_str]
|
embedding = ckpt[token_str]
|
||||||
self.add_textual_inversion(token_str, embedding, full_precision)
|
self.add_textual_inversion(token_str, embedding, full_precision)
|
||||||
|
|
||||||
def add_textual_inversion(self, token_str, embedding):
|
def add_textual_inversion(self, token_str, embedding) -> int:
|
||||||
"""
|
"""
|
||||||
Add a textual inversion to be recognised.
|
Add a textual inversion to be recognised.
|
||||||
:param token_str: The trigger text in the prompt that activates this textual inversion. Should be unknown to the embedder's tokenizer.
|
:param token_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
|
||||||
:param embedding: The actual embedding data that will be inserted into the conditioning at the point where the token_str appears.
|
:param embedding: The actual embedding data that will be inserted into the conditioning at the point where the token_str appears.
|
||||||
:return: The token id of the added embedding.
|
:return: The token id for the added embedding, either existing or newly-added.
|
||||||
"""
|
"""
|
||||||
if token_str in [ti.token_string for ti in self.textual_inversions]:
|
if token_str in [ti.trigger_string for ti in self.textual_inversions]:
|
||||||
print(f">> Embedding manager refusing to overwrite already-loaded term '{token_str}'")
|
print(f">> TextualInversionManager refusing to overwrite already-loaded token '{token_str}'")
|
||||||
return
|
return
|
||||||
if len(embedding.shape) == 1:
|
if len(embedding.shape) == 1:
|
||||||
embedding = embedding.unsqueeze(0)
|
embedding = embedding.unsqueeze(0)
|
||||||
elif len(embedding.shape) > 2:
|
elif len(embedding.shape) > 2:
|
||||||
raise ValueError(f"embedding shape {embedding.shape} is incorrect - must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2")
|
raise ValueError(f"embedding shape {embedding.shape} is incorrect - must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2")
|
||||||
|
|
||||||
|
existing_token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, token_str)
|
||||||
|
if existing_token_id == self.clip_embedder.tokenizer.unk_token_id:
|
||||||
num_tokens_added = self.clip_embedder.tokenizer.add_tokens(token_str)
|
num_tokens_added = self.clip_embedder.tokenizer.add_tokens(token_str)
|
||||||
current_embeddings = self.clip_embedder.transformer.resize_token_embeddings(None)
|
current_embeddings = self.clip_embedder.transformer.resize_token_embeddings(None)
|
||||||
current_token_count = current_embeddings.num_embeddings
|
current_token_count = current_embeddings.num_embeddings
|
||||||
@ -112,15 +100,25 @@ class TextualInversionManager():
|
|||||||
|
|
||||||
token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, token_str)
|
token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, token_str)
|
||||||
self.textual_inversions.append(TextualInversion(
|
self.textual_inversions.append(TextualInversion(
|
||||||
token_string=token_str,
|
trigger_string=token_str,
|
||||||
token_id=token_id,
|
token_id=token_id,
|
||||||
embedding=embedding
|
embedding=embedding
|
||||||
))
|
))
|
||||||
|
|
||||||
return token_id
|
return token_id
|
||||||
|
|
||||||
def has_textual_inversion(self, token_str):
|
def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
|
||||||
return token_str in [ti.token_string for ti in self.textual_inversions]
|
try:
|
||||||
|
ti = self.get_textual_inversion_for_trigger_string(trigger_string)
|
||||||
|
return ti is not None
|
||||||
|
except StopIteration:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_textual_inversion_for_trigger_string(self, trigger_string: str) -> TextualInversion:
|
||||||
|
return next(ti for ti in self.textual_inversions if ti.trigger_string == trigger_string)
|
||||||
|
|
||||||
|
|
||||||
|
def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
|
||||||
|
return next(ti for ti in self.textual_inversions if ti.token_id == token_id)
|
||||||
|
|
||||||
def expand_textual_inversion_token_ids(self, prompt_token_ids: list[int]) -> list[int]:
|
def expand_textual_inversion_token_ids(self, prompt_token_ids: list[int]) -> list[int]:
|
||||||
"""
|
"""
|
||||||
@ -131,14 +129,17 @@ class TextualInversionManager():
|
|||||||
:return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too
|
:return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too
|
||||||
long - caller is reponsible for truncating it if necessary and prepending/appending eos and bos token ids.
|
long - caller is reponsible for truncating it if necessary and prepending/appending eos and bos token ids.
|
||||||
"""
|
"""
|
||||||
assert(prompt_token_ids[0] != self.clip_embedder.bos_token_id)
|
if prompt_token_ids[0] == self.clip_embedder.tokenizer.bos_token_id:
|
||||||
assert(prompt_token_ids[-1] != self.clip_embedder.eos_token_id)
|
raise ValueError("prompt_token_ids must not start with bos_token_id")
|
||||||
|
if prompt_token_ids[-1] == self.clip_embedder.tokenizer.eos_token_id:
|
||||||
|
raise ValueError("prompt_token_ids must not end with eos_token_id")
|
||||||
textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
|
textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
|
||||||
|
prompt_token_ids = prompt_token_ids[:]
|
||||||
for i, token_id in reversed(list(enumerate(prompt_token_ids))):
|
for i, token_id in reversed(list(enumerate(prompt_token_ids))):
|
||||||
if token_id in textual_inversion_token_ids:
|
if token_id in textual_inversion_token_ids:
|
||||||
textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
|
textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
|
||||||
for pad_idx in range(1, textual_inversion.embedding_vector_length):
|
for pad_idx in range(1, textual_inversion.embedding_vector_length):
|
||||||
prompt_token_ids.insert(i+1, self.clip_embedder.pad_token_id)
|
prompt_token_ids.insert(i+1, self.clip_embedder.tokenizer.pad_token_id)
|
||||||
|
|
||||||
return prompt_token_ids
|
return prompt_token_ids
|
||||||
|
|
||||||
@ -154,8 +155,9 @@ class TextualInversionManager():
|
|||||||
`prompt_token_ids` (i.e., also already expanded).
|
`prompt_token_ids` (i.e., also already expanded).
|
||||||
:return: `The prompt_embeddings` tensor overwritten as appropriate with the textual inversion embeddings.
|
:return: `The prompt_embeddings` tensor overwritten as appropriate with the textual inversion embeddings.
|
||||||
"""
|
"""
|
||||||
assert(prompt_embeddings.shape[0] == self.clip_embedder.max_length, f"prompt_embeddings must have 77 entries (has: {prompt_embeddings.shape[0]})")
|
assert prompt_embeddings.shape[0] == self.clip_embedder.max_length, f"prompt_embeddings must have 77 entries (has: {prompt_embeddings.shape[0]})"
|
||||||
textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
|
textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
|
||||||
|
pad_token_id = self.clip_embedder.pad_token_id
|
||||||
for i, token_id in enumerate(prompt_token_ids):
|
for i, token_id in enumerate(prompt_token_ids):
|
||||||
if token_id == pad_token_id:
|
if token_id == pad_token_id:
|
||||||
continue
|
continue
|
||||||
@ -163,7 +165,7 @@ class TextualInversionManager():
|
|||||||
textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
|
textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
|
||||||
for j in range(0, textual_inversion.embedding_vector_length):
|
for j in range(0, textual_inversion.embedding_vector_length):
|
||||||
# only overwrite the textual inversion token id or the padding token id
|
# only overwrite the textual inversion token id or the padding token id
|
||||||
if prompt_token_ids[i+j] != self.clip_embedder.pad_token_id and prompt_token_ids[i+j] != token_id:
|
if prompt_token_ids[i+j] != pad_token_id and prompt_token_ids[i+j] != token_id:
|
||||||
break
|
break
|
||||||
prompt_embeddings[i+j] = textual_inversion.embedding[j]
|
prompt_embeddings[i+j] = textual_inversion.embedding[j]
|
||||||
|
|
||||||
@ -206,7 +208,7 @@ class EmbeddingManager(nn.Module):
|
|||||||
embedder, 'tokenizer'
|
embedder, 'tokenizer'
|
||||||
): # using Stable Diffusion's CLIP encoder
|
): # using Stable Diffusion's CLIP encoder
|
||||||
self.is_clip = True
|
self.is_clip = True
|
||||||
get_token_for_string = partial(
|
get_token_id_for_string = partial(
|
||||||
get_clip_token_id_for_string, embedder.tokenizer
|
get_clip_token_id_for_string, embedder.tokenizer
|
||||||
)
|
)
|
||||||
get_embedding_for_tkn = partial(
|
get_embedding_for_tkn = partial(
|
||||||
@ -218,7 +220,7 @@ class EmbeddingManager(nn.Module):
|
|||||||
token_dim = 768
|
token_dim = 768
|
||||||
else: # using LDM's BERT encoder
|
else: # using LDM's BERT encoder
|
||||||
self.is_clip = False
|
self.is_clip = False
|
||||||
get_token_for_string = partial(
|
get_token_id_for_string = partial(
|
||||||
get_bert_token_for_string, embedder.tknz_fn
|
get_bert_token_for_string, embedder.tknz_fn
|
||||||
)
|
)
|
||||||
get_embedding_for_tkn = embedder.transformer.token_emb
|
get_embedding_for_tkn = embedder.transformer.token_emb
|
||||||
@ -229,14 +231,14 @@ class EmbeddingManager(nn.Module):
|
|||||||
|
|
||||||
for idx, placeholder_string in enumerate(placeholder_strings):
|
for idx, placeholder_string in enumerate(placeholder_strings):
|
||||||
|
|
||||||
token = get_token_for_string(placeholder_string)
|
token_id = get_token_id_for_string(placeholder_string)
|
||||||
|
|
||||||
if initializer_words and idx < len(initializer_words):
|
if initializer_words and idx < len(initializer_words):
|
||||||
init_word_token = get_token_for_string(initializer_words[idx])
|
init_word_token_id = get_token_id_for_string(initializer_words[idx])
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
init_word_embedding = get_embedding_for_tkn(
|
init_word_embedding = get_embedding_for_tkn(
|
||||||
init_word_token.cpu()
|
init_word_token_id.cpu()
|
||||||
)
|
)
|
||||||
|
|
||||||
token_params = torch.nn.Parameter(
|
token_params = torch.nn.Parameter(
|
||||||
@ -261,7 +263,7 @@ class EmbeddingManager(nn.Module):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.string_to_token_dict[placeholder_string] = token
|
self.string_to_token_dict[placeholder_string] = token_id
|
||||||
self.string_to_param_dict[placeholder_string] = token_params
|
self.string_to_param_dict[placeholder_string] = token_params
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -384,6 +386,57 @@ class EmbeddingManager(nn.Module):
|
|||||||
expanded_paths.append(os.path.join(root,name))
|
expanded_paths.append(os.path.join(root,name))
|
||||||
return [x for x in expanded_paths if os.path.splitext(x)[1] in ('.pt','.bin')]
|
return [x for x in expanded_paths if os.path.splitext(x)[1] in ('.pt','.bin')]
|
||||||
|
|
||||||
|
def _load(self, ckpt_path, full=True):
|
||||||
|
|
||||||
|
scan_result = scan_file_path(ckpt_path)
|
||||||
|
if scan_result.infected_files == 1:
|
||||||
|
print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
|
||||||
|
print('### For your safety, InvokeAI will not load this embed.')
|
||||||
|
return
|
||||||
|
|
||||||
|
ckpt = torch.load(ckpt_path, map_location='cpu')
|
||||||
|
|
||||||
|
# Handle .pt textual inversion files
|
||||||
|
if 'string_to_token' in ckpt and 'string_to_param' in ckpt:
|
||||||
|
filename = os.path.basename(ckpt_path)
|
||||||
|
token_str = '.'.join(filename.split('.')[:-1]) # filename excluding extension
|
||||||
|
if len(ckpt["string_to_token"]) > 1:
|
||||||
|
print(f">> {ckpt_path} has >1 embedding, only the first will be used")
|
||||||
|
|
||||||
|
string_to_param_dict = ckpt['string_to_param']
|
||||||
|
embedding = list(string_to_param_dict.values())[0]
|
||||||
|
self.add_embedding(token_str, embedding, full)
|
||||||
|
|
||||||
|
# Handle .bin textual inversion files from Huggingface Concepts
|
||||||
|
# https://huggingface.co/sd-concepts-library
|
||||||
|
else:
|
||||||
|
for token_str in list(ckpt.keys()):
|
||||||
|
embedding = ckpt[token_str]
|
||||||
|
self.add_embedding(token_str, embedding, full)
|
||||||
|
|
||||||
|
def add_embedding(self, token_str, embedding, full):
|
||||||
|
if token_str in self.string_to_param_dict:
|
||||||
|
print(f">> Embedding manager refusing to overwrite already-loaded term '{token_str}'")
|
||||||
|
return
|
||||||
|
if not full:
|
||||||
|
embedding = embedding.half()
|
||||||
|
if len(embedding.shape) == 1:
|
||||||
|
embedding = embedding.unsqueeze(0)
|
||||||
|
|
||||||
|
existing_token_id = get_clip_token_id_for_string(self.embedder.tokenizer, token_str)
|
||||||
|
if existing_token_id == self.embedder.tokenizer.unk_token_id:
|
||||||
|
num_tokens_added = self.embedder.tokenizer.add_tokens(token_str)
|
||||||
|
current_embeddings = self.embedder.transformer.resize_token_embeddings(None)
|
||||||
|
current_token_count = current_embeddings.num_embeddings
|
||||||
|
new_token_count = current_token_count + num_tokens_added
|
||||||
|
self.embedder.transformer.resize_token_embeddings(new_token_count)
|
||||||
|
|
||||||
|
token_id = get_clip_token_id_for_string(self.embedder.tokenizer, token_str)
|
||||||
|
self.string_to_token_dict[token_str] = token_id
|
||||||
|
self.string_to_param_dict[token_str] = torch.nn.Parameter(embedding)
|
||||||
|
|
||||||
|
def has_embedding_for_token(self, token_str):
|
||||||
|
return token_str in self.string_to_token_dict
|
||||||
|
|
||||||
def get_embedding_norms_squared(self):
|
def get_embedding_norms_squared(self):
|
||||||
all_params = torch.cat(
|
all_params = torch.cat(
|
||||||
|
205
tests/test_textual_inversion.py
Normal file
205
tests/test_textual_inversion.py
Normal file
@ -0,0 +1,205 @@
|
|||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ldm.modules.embedding_manager import TextualInversionManager
|
||||||
|
|
||||||
|
|
||||||
|
KNOWN_WORDS = ['a', 'b', 'c']
|
||||||
|
UNKNOWN_WORDS = ['d', 'e', 'f']
|
||||||
|
|
||||||
|
class DummyEmbeddingsList(list):
|
||||||
|
def __getattr__(self, name):
|
||||||
|
if name == 'num_embeddings':
|
||||||
|
return len(self)
|
||||||
|
|
||||||
|
class DummyTransformer:
|
||||||
|
def __init__(self):
|
||||||
|
self.embeddings = DummyEmbeddingsList([0] * len(KNOWN_WORDS))
|
||||||
|
|
||||||
|
def resize_token_embeddings(self, new_size=None):
|
||||||
|
if new_size is None:
|
||||||
|
return self.embeddings
|
||||||
|
else:
|
||||||
|
while len(self.embeddings) > new_size:
|
||||||
|
self.embeddings.pop(-1)
|
||||||
|
while len(self.embeddings) < new_size:
|
||||||
|
self.embeddings.append(0)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyTokenizer():
|
||||||
|
def __init__(self):
|
||||||
|
self.tokens = KNOWN_WORDS.copy()
|
||||||
|
self.bos_token_id = 49406
|
||||||
|
self.eos_token_id = 49407
|
||||||
|
self.pad_token_id = 49407
|
||||||
|
self.unk_token_id = 49407
|
||||||
|
|
||||||
|
def convert_tokens_to_ids(self, token_str):
|
||||||
|
try:
|
||||||
|
return self.tokens.index(token_str)
|
||||||
|
except ValueError:
|
||||||
|
return self.unk_token_id
|
||||||
|
|
||||||
|
def add_tokens(self, token_str):
|
||||||
|
self.tokens.append(token_str)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
class DummyClipEmbedder:
|
||||||
|
def __init__(self):
|
||||||
|
self.max_length = 77
|
||||||
|
self.transformer = DummyTransformer()
|
||||||
|
self.tokenizer = DummyTokenizer()
|
||||||
|
|
||||||
|
|
||||||
|
class TextualInversionManagerTestCase(unittest.TestCase):
|
||||||
|
|
||||||
|
|
||||||
|
def test_construction(self):
|
||||||
|
tim = TextualInversionManager(DummyClipEmbedder())
|
||||||
|
|
||||||
|
def test_add_embedding_for_known_token(self):
|
||||||
|
tim = TextualInversionManager(DummyClipEmbedder())
|
||||||
|
test_embedding = torch.randn([1, 768])
|
||||||
|
test_embedding_name = KNOWN_WORDS[0]
|
||||||
|
self.assertFalse(tim.has_textual_inversion_for_trigger_string(test_embedding_name))
|
||||||
|
|
||||||
|
pre_embeddings_count = len(tim.clip_embedder.transformer.resize_token_embeddings(None))
|
||||||
|
|
||||||
|
token_id = tim.add_textual_inversion(test_embedding_name, test_embedding)
|
||||||
|
self.assertEqual(token_id, 0)
|
||||||
|
|
||||||
|
|
||||||
|
# check adding 'test' did not create a new word
|
||||||
|
embeddings_count = len(tim.clip_embedder.transformer.resize_token_embeddings(None))
|
||||||
|
self.assertEqual(pre_embeddings_count, embeddings_count)
|
||||||
|
|
||||||
|
# check it was added
|
||||||
|
self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name))
|
||||||
|
textual_inversion = tim.get_textual_inversion_for_trigger_string(test_embedding_name)
|
||||||
|
self.assertIsNotNone(textual_inversion)
|
||||||
|
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding))
|
||||||
|
self.assertEqual(textual_inversion.trigger_string, test_embedding_name)
|
||||||
|
self.assertEqual(textual_inversion.token_id, token_id)
|
||||||
|
|
||||||
|
def test_add_embedding_for_unknown_token(self):
|
||||||
|
tim = TextualInversionManager(DummyClipEmbedder())
|
||||||
|
test_embedding_1 = torch.randn([1, 768])
|
||||||
|
test_embedding_name_1 = UNKNOWN_WORDS[0]
|
||||||
|
|
||||||
|
pre_embeddings_count = len(tim.clip_embedder.transformer.resize_token_embeddings(None))
|
||||||
|
|
||||||
|
added_token_id_1 = tim.add_textual_inversion(test_embedding_name_1, test_embedding_1)
|
||||||
|
# new token id should get added on the end
|
||||||
|
self.assertEqual(added_token_id_1, len(KNOWN_WORDS))
|
||||||
|
|
||||||
|
# check adding did create a new word
|
||||||
|
embeddings_count = len(tim.clip_embedder.transformer.resize_token_embeddings(None))
|
||||||
|
self.assertEqual(pre_embeddings_count+1, embeddings_count)
|
||||||
|
|
||||||
|
# check it was added
|
||||||
|
self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_1))
|
||||||
|
textual_inversion = next(ti for ti in tim.textual_inversions if ti.token_id == added_token_id_1)
|
||||||
|
self.assertIsNotNone(textual_inversion)
|
||||||
|
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_1))
|
||||||
|
self.assertEqual(textual_inversion.trigger_string, test_embedding_name_1)
|
||||||
|
self.assertEqual(textual_inversion.token_id, added_token_id_1)
|
||||||
|
|
||||||
|
# add another one
|
||||||
|
test_embedding_2 = torch.randn([1, 768])
|
||||||
|
test_embedding_name_2 = UNKNOWN_WORDS[1]
|
||||||
|
|
||||||
|
pre_embeddings_count = len(tim.clip_embedder.transformer.resize_token_embeddings(None))
|
||||||
|
|
||||||
|
added_token_id_2 = tim.add_textual_inversion(test_embedding_name_2, test_embedding_2)
|
||||||
|
self.assertEqual(added_token_id_2, len(KNOWN_WORDS)+1)
|
||||||
|
|
||||||
|
# check adding did create a new word
|
||||||
|
embeddings_count = len(tim.clip_embedder.transformer.resize_token_embeddings(None))
|
||||||
|
self.assertEqual(pre_embeddings_count+1, embeddings_count)
|
||||||
|
|
||||||
|
# check it was added
|
||||||
|
self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_2))
|
||||||
|
textual_inversion = next(ti for ti in tim.textual_inversions if ti.token_id == added_token_id_2)
|
||||||
|
self.assertIsNotNone(textual_inversion)
|
||||||
|
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_2))
|
||||||
|
self.assertEqual(textual_inversion.trigger_string, test_embedding_name_2)
|
||||||
|
self.assertEqual(textual_inversion.token_id, added_token_id_2)
|
||||||
|
|
||||||
|
# check the old one is still there
|
||||||
|
self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_1))
|
||||||
|
textual_inversion = next(ti for ti in tim.textual_inversions if ti.token_id == added_token_id_1)
|
||||||
|
self.assertIsNotNone(textual_inversion)
|
||||||
|
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_1))
|
||||||
|
self.assertEqual(textual_inversion.trigger_string, test_embedding_name_1)
|
||||||
|
self.assertEqual(textual_inversion.token_id, added_token_id_1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pad_raises_on_eos_bos(self):
|
||||||
|
tim = TextualInversionManager(DummyClipEmbedder())
|
||||||
|
prompt_token_ids_with_eos_bos = [tim.clip_embedder.tokenizer.bos_token_id,
|
||||||
|
0, 1, 2,
|
||||||
|
tim.clip_embedder.tokenizer.eos_token_id]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_with_eos_bos)
|
||||||
|
|
||||||
|
def test_pad_tokens_list_vector_length_1(self):
|
||||||
|
tim = TextualInversionManager(DummyClipEmbedder())
|
||||||
|
prompt_token_ids = [0, 1, 2]
|
||||||
|
|
||||||
|
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids)
|
||||||
|
self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
|
||||||
|
|
||||||
|
test_embedding_1v = torch.randn([1, 768])
|
||||||
|
test_embedding_1v_token = "<inversion-trigger-vector-length-1>"
|
||||||
|
test_embedding_1v_token_id = tim.add_textual_inversion(test_embedding_1v_token, test_embedding_1v)
|
||||||
|
self.assertEqual(test_embedding_1v_token_id, len(KNOWN_WORDS))
|
||||||
|
|
||||||
|
# at the end
|
||||||
|
prompt_token_ids_1v_append = prompt_token_ids + [test_embedding_1v_token_id]
|
||||||
|
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_1v_append)
|
||||||
|
self.assertEqual(prompt_token_ids_1v_append, expanded_prompt_token_ids)
|
||||||
|
|
||||||
|
# at the start
|
||||||
|
prompt_token_ids_1v_prepend = [test_embedding_1v_token_id] + prompt_token_ids
|
||||||
|
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_1v_prepend)
|
||||||
|
self.assertEqual(prompt_token_ids_1v_prepend, expanded_prompt_token_ids)
|
||||||
|
|
||||||
|
# in the middle
|
||||||
|
prompt_token_ids_1v_insert = prompt_token_ids[0:2] + [test_embedding_1v_token_id] + prompt_token_ids[2:3]
|
||||||
|
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_1v_insert)
|
||||||
|
self.assertEqual(prompt_token_ids_1v_insert, expanded_prompt_token_ids)
|
||||||
|
|
||||||
|
def test_pad_tokens_list_vector_length_2(self):
|
||||||
|
tim = TextualInversionManager(DummyClipEmbedder())
|
||||||
|
prompt_token_ids = [0, 1, 2]
|
||||||
|
|
||||||
|
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids)
|
||||||
|
self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
|
||||||
|
|
||||||
|
test_embedding_2v = torch.randn([2, 768])
|
||||||
|
test_embedding_2v_token = "<inversion-trigger-vector-length-2>"
|
||||||
|
test_embedding_2v_token_id = tim.add_textual_inversion(test_embedding_2v_token, test_embedding_2v)
|
||||||
|
self.assertEqual(test_embedding_2v_token_id, len(KNOWN_WORDS))
|
||||||
|
|
||||||
|
# at the end
|
||||||
|
prompt_token_ids_2v_append = prompt_token_ids + [test_embedding_2v_token_id]
|
||||||
|
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_2v_append)
|
||||||
|
self.assertNotEqual(prompt_token_ids_2v_append, expanded_prompt_token_ids)
|
||||||
|
self.assertEqual(prompt_token_ids + [test_embedding_2v_token_id, tim.clip_embedder.tokenizer.pad_token_id], expanded_prompt_token_ids)
|
||||||
|
|
||||||
|
# at the start
|
||||||
|
prompt_token_ids_2v_prepend = [test_embedding_2v_token_id] + prompt_token_ids
|
||||||
|
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_2v_prepend)
|
||||||
|
self.assertNotEqual(prompt_token_ids_2v_prepend, expanded_prompt_token_ids)
|
||||||
|
self.assertEqual([test_embedding_2v_token_id, tim.clip_embedder.tokenizer.pad_token_id] + prompt_token_ids, expanded_prompt_token_ids)
|
||||||
|
|
||||||
|
# in the middle
|
||||||
|
prompt_token_ids_2v_insert = prompt_token_ids[0:2] + [test_embedding_2v_token_id] + prompt_token_ids[2:3]
|
||||||
|
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_2v_insert)
|
||||||
|
self.assertNotEqual(prompt_token_ids_2v_insert, expanded_prompt_token_ids)
|
||||||
|
self.assertEqual(prompt_token_ids[0:2] + [test_embedding_2v_token_id, tim.clip_embedder.tokenizer.pad_token_id] + prompt_token_ids[2:3], expanded_prompt_token_ids)
|
||||||
|
|
||||||
|
|
@ -1,43 +0,0 @@
|
|||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from ldm.modules.embedding_manager import TextualInversionManager
|
|
||||||
|
|
||||||
|
|
||||||
class DummyClipEmbedder:
|
|
||||||
max_length = 77
|
|
||||||
bos_token_id = 49406
|
|
||||||
eos_token_id = 49407
|
|
||||||
|
|
||||||
class TextualInversionManagerTestCase(unittest.TestCase):
|
|
||||||
|
|
||||||
|
|
||||||
def test_construction(self):
|
|
||||||
tim = TextualInversionManager(DummyClipEmbedder())
|
|
||||||
|
|
||||||
def test_add_embedding(self):
|
|
||||||
tim = TextualInversionManager(DummyClipEmbedder())
|
|
||||||
test_embedding = torch.random([1, 768])
|
|
||||||
test_embedding_name = "test"
|
|
||||||
token_id = tim.add_textual_inversion(test_embedding_name, test_embedding)
|
|
||||||
self.assertTrue(tim.has_textual_inversion(test_embedding_name))
|
|
||||||
|
|
||||||
textual_inversion = next(ti for ti in tim.textual_inversions if ti.token_id == token_id)
|
|
||||||
self.assertIsNotNone(textual_inversion)
|
|
||||||
self.assertEqual(textual_inversion.embedding, test_embedding)
|
|
||||||
self.assertEqual(textual_inversion.token_string, test_embedding_name)
|
|
||||||
self.assertEqual(textual_inversion.token_id, token_id)
|
|
||||||
|
|
||||||
def test_pad_tokens_list(self):
|
|
||||||
tim = TextualInversionManager(DummyClipEmbedder())
|
|
||||||
prompt_token_ids = [DummyClipEmbedder.bos_token_id, 0, 1, 2, DummyClipEmbedder.eos_token_id]
|
|
||||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids)
|
|
||||||
self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
|
|
||||||
|
|
||||||
test_embedding = torch.random([1, 768])
|
|
||||||
test_embedding_name = "test"
|
|
||||||
tim.add_textual_inversion("<token>",
|
|
||||||
|
|
||||||
self.assertRaises()
|
|
Loading…
Reference in New Issue
Block a user