From 417c2b57d90924a839616bfb66804faab8039e4c Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Thu, 15 Dec 2022 12:30:55 +0100 Subject: [PATCH] wip textual inversion manager (unit tests passing for base stuff + padding) --- ldm/modules/embedding_manager.py | 143 ++++++++++++++------- tests/test_textual_inversion.py | 205 +++++++++++++++++++++++++++++++ tests/text_textual_inversion.py | 43 ------- 3 files changed, 303 insertions(+), 88 deletions(-) create mode 100644 tests/test_textual_inversion.py delete mode 100644 tests/text_textual_inversion.py diff --git a/ldm/modules/embedding_manager.py b/ldm/modules/embedding_manager.py index 7f53f83039..285ef7429f 100644 --- a/ldm/modules/embedding_manager.py +++ b/ldm/modules/embedding_manager.py @@ -15,23 +15,9 @@ from picklescan.scanner import scan_file_path PROGRESSIVE_SCALE = 2000 -def get_clip_token_id_for_string(tokenizer, string): - batch_encoding = tokenizer( - string, - truncation=True, - max_length=77, - return_length=True, - return_overflowing_tokens=False, - padding='max_length', - return_tensors='pt', - ) - tokens = batch_encoding['input_ids'] - assert ( - torch.count_nonzero(tokens - 49407) == 2 - ), f"String '{string}' maps to more than a single token. Please use another string" - - return tokens[0, 1] - +def get_clip_token_id_for_string(tokenizer: CLIPTokenizer, token_str: str): + token_id = tokenizer.convert_tokens_to_ids(token_str) + return token_id def get_bert_token_for_string(tokenizer, string): token = tokenizer(string) @@ -47,7 +33,7 @@ def get_embedding_for_clip_token(embedder, token): @dataclass class TextualInversion: - token_string: str + trigger_string: str token_id: int embedding: torch.Tensor @@ -58,8 +44,8 @@ class TextualInversion: class TextualInversionManager(): def __init__(self, clip_embedder): self.clip_embedder = clip_embedder - defatul_textual_inversions: list[TextualInversion] = [] - self.textual_inversions = defatul_textual_inversions + default_textual_inversions: list[TextualInversion] = [] + self.textual_inversions = default_textual_inversions def load_textual_inversion(self, ckpt_path, full_precision=True): @@ -89,38 +75,50 @@ class TextualInversionManager(): embedding = ckpt[token_str] self.add_textual_inversion(token_str, embedding, full_precision) - def add_textual_inversion(self, token_str, embedding): + def add_textual_inversion(self, token_str, embedding) -> int: """ Add a textual inversion to be recognised. - :param token_str: The trigger text in the prompt that activates this textual inversion. Should be unknown to the embedder's tokenizer. + :param token_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added. :param embedding: The actual embedding data that will be inserted into the conditioning at the point where the token_str appears. - :return: The token id of the added embedding. + :return: The token id for the added embedding, either existing or newly-added. """ - if token_str in [ti.token_string for ti in self.textual_inversions]: - print(f">> Embedding manager refusing to overwrite already-loaded term '{token_str}'") + if token_str in [ti.trigger_string for ti in self.textual_inversions]: + print(f">> TextualInversionManager refusing to overwrite already-loaded token '{token_str}'") return if len(embedding.shape) == 1: embedding = embedding.unsqueeze(0) elif len(embedding.shape) > 2: raise ValueError(f"embedding shape {embedding.shape} is incorrect - must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2") - num_tokens_added = self.clip_embedder.tokenizer.add_tokens(token_str) - current_embeddings = self.clip_embedder.transformer.resize_token_embeddings(None) - current_token_count = current_embeddings.num_embeddings - new_token_count = current_token_count + num_tokens_added - self.clip_embedder.transformer.resize_token_embeddings(new_token_count) + existing_token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, token_str) + if existing_token_id == self.clip_embedder.tokenizer.unk_token_id: + num_tokens_added = self.clip_embedder.tokenizer.add_tokens(token_str) + current_embeddings = self.clip_embedder.transformer.resize_token_embeddings(None) + current_token_count = current_embeddings.num_embeddings + new_token_count = current_token_count + num_tokens_added + self.clip_embedder.transformer.resize_token_embeddings(new_token_count) token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, token_str) self.textual_inversions.append(TextualInversion( - token_string=token_str, + trigger_string=token_str, token_id=token_id, embedding=embedding )) - return token_id - def has_textual_inversion(self, token_str): - return token_str in [ti.token_string for ti in self.textual_inversions] + def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool: + try: + ti = self.get_textual_inversion_for_trigger_string(trigger_string) + return ti is not None + except StopIteration: + return False + + def get_textual_inversion_for_trigger_string(self, trigger_string: str) -> TextualInversion: + return next(ti for ti in self.textual_inversions if ti.trigger_string == trigger_string) + + + def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion: + return next(ti for ti in self.textual_inversions if ti.token_id == token_id) def expand_textual_inversion_token_ids(self, prompt_token_ids: list[int]) -> list[int]: """ @@ -131,14 +129,17 @@ class TextualInversionManager(): :return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too long - caller is reponsible for truncating it if necessary and prepending/appending eos and bos token ids. """ - assert(prompt_token_ids[0] != self.clip_embedder.bos_token_id) - assert(prompt_token_ids[-1] != self.clip_embedder.eos_token_id) + if prompt_token_ids[0] == self.clip_embedder.tokenizer.bos_token_id: + raise ValueError("prompt_token_ids must not start with bos_token_id") + if prompt_token_ids[-1] == self.clip_embedder.tokenizer.eos_token_id: + raise ValueError("prompt_token_ids must not end with eos_token_id") textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions] + prompt_token_ids = prompt_token_ids[:] for i, token_id in reversed(list(enumerate(prompt_token_ids))): if token_id in textual_inversion_token_ids: textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id) for pad_idx in range(1, textual_inversion.embedding_vector_length): - prompt_token_ids.insert(i+1, self.clip_embedder.pad_token_id) + prompt_token_ids.insert(i+1, self.clip_embedder.tokenizer.pad_token_id) return prompt_token_ids @@ -154,8 +155,9 @@ class TextualInversionManager(): `prompt_token_ids` (i.e., also already expanded). :return: `The prompt_embeddings` tensor overwritten as appropriate with the textual inversion embeddings. """ - assert(prompt_embeddings.shape[0] == self.clip_embedder.max_length, f"prompt_embeddings must have 77 entries (has: {prompt_embeddings.shape[0]})") + assert prompt_embeddings.shape[0] == self.clip_embedder.max_length, f"prompt_embeddings must have 77 entries (has: {prompt_embeddings.shape[0]})" textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions] + pad_token_id = self.clip_embedder.pad_token_id for i, token_id in enumerate(prompt_token_ids): if token_id == pad_token_id: continue @@ -163,7 +165,7 @@ class TextualInversionManager(): textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id) for j in range(0, textual_inversion.embedding_vector_length): # only overwrite the textual inversion token id or the padding token id - if prompt_token_ids[i+j] != self.clip_embedder.pad_token_id and prompt_token_ids[i+j] != token_id: + if prompt_token_ids[i+j] != pad_token_id and prompt_token_ids[i+j] != token_id: break prompt_embeddings[i+j] = textual_inversion.embedding[j] @@ -206,7 +208,7 @@ class EmbeddingManager(nn.Module): embedder, 'tokenizer' ): # using Stable Diffusion's CLIP encoder self.is_clip = True - get_token_for_string = partial( + get_token_id_for_string = partial( get_clip_token_id_for_string, embedder.tokenizer ) get_embedding_for_tkn = partial( @@ -218,7 +220,7 @@ class EmbeddingManager(nn.Module): token_dim = 768 else: # using LDM's BERT encoder self.is_clip = False - get_token_for_string = partial( + get_token_id_for_string = partial( get_bert_token_for_string, embedder.tknz_fn ) get_embedding_for_tkn = embedder.transformer.token_emb @@ -229,14 +231,14 @@ class EmbeddingManager(nn.Module): for idx, placeholder_string in enumerate(placeholder_strings): - token = get_token_for_string(placeholder_string) + token_id = get_token_id_for_string(placeholder_string) if initializer_words and idx < len(initializer_words): - init_word_token = get_token_for_string(initializer_words[idx]) + init_word_token_id = get_token_id_for_string(initializer_words[idx]) with torch.no_grad(): init_word_embedding = get_embedding_for_tkn( - init_word_token.cpu() + init_word_token_id.cpu() ) token_params = torch.nn.Parameter( @@ -261,7 +263,7 @@ class EmbeddingManager(nn.Module): ) ) - self.string_to_token_dict[placeholder_string] = token + self.string_to_token_dict[placeholder_string] = token_id self.string_to_param_dict[placeholder_string] = token_params def forward( @@ -384,6 +386,57 @@ class EmbeddingManager(nn.Module): expanded_paths.append(os.path.join(root,name)) return [x for x in expanded_paths if os.path.splitext(x)[1] in ('.pt','.bin')] + def _load(self, ckpt_path, full=True): + + scan_result = scan_file_path(ckpt_path) + if scan_result.infected_files == 1: + print(f'\n### Security Issues Found in Model: {scan_result.issues_count}') + print('### For your safety, InvokeAI will not load this embed.') + return + + ckpt = torch.load(ckpt_path, map_location='cpu') + + # Handle .pt textual inversion files + if 'string_to_token' in ckpt and 'string_to_param' in ckpt: + filename = os.path.basename(ckpt_path) + token_str = '.'.join(filename.split('.')[:-1]) # filename excluding extension + if len(ckpt["string_to_token"]) > 1: + print(f">> {ckpt_path} has >1 embedding, only the first will be used") + + string_to_param_dict = ckpt['string_to_param'] + embedding = list(string_to_param_dict.values())[0] + self.add_embedding(token_str, embedding, full) + + # Handle .bin textual inversion files from Huggingface Concepts + # https://huggingface.co/sd-concepts-library + else: + for token_str in list(ckpt.keys()): + embedding = ckpt[token_str] + self.add_embedding(token_str, embedding, full) + + def add_embedding(self, token_str, embedding, full): + if token_str in self.string_to_param_dict: + print(f">> Embedding manager refusing to overwrite already-loaded term '{token_str}'") + return + if not full: + embedding = embedding.half() + if len(embedding.shape) == 1: + embedding = embedding.unsqueeze(0) + + existing_token_id = get_clip_token_id_for_string(self.embedder.tokenizer, token_str) + if existing_token_id == self.embedder.tokenizer.unk_token_id: + num_tokens_added = self.embedder.tokenizer.add_tokens(token_str) + current_embeddings = self.embedder.transformer.resize_token_embeddings(None) + current_token_count = current_embeddings.num_embeddings + new_token_count = current_token_count + num_tokens_added + self.embedder.transformer.resize_token_embeddings(new_token_count) + + token_id = get_clip_token_id_for_string(self.embedder.tokenizer, token_str) + self.string_to_token_dict[token_str] = token_id + self.string_to_param_dict[token_str] = torch.nn.Parameter(embedding) + + def has_embedding_for_token(self, token_str): + return token_str in self.string_to_token_dict def get_embedding_norms_squared(self): all_params = torch.cat( diff --git a/tests/test_textual_inversion.py b/tests/test_textual_inversion.py new file mode 100644 index 0000000000..17b4da0c88 --- /dev/null +++ b/tests/test_textual_inversion.py @@ -0,0 +1,205 @@ + +import unittest + +import torch + +from ldm.modules.embedding_manager import TextualInversionManager + + +KNOWN_WORDS = ['a', 'b', 'c'] +UNKNOWN_WORDS = ['d', 'e', 'f'] + +class DummyEmbeddingsList(list): + def __getattr__(self, name): + if name == 'num_embeddings': + return len(self) + +class DummyTransformer: + def __init__(self): + self.embeddings = DummyEmbeddingsList([0] * len(KNOWN_WORDS)) + + def resize_token_embeddings(self, new_size=None): + if new_size is None: + return self.embeddings + else: + while len(self.embeddings) > new_size: + self.embeddings.pop(-1) + while len(self.embeddings) < new_size: + self.embeddings.append(0) + + +class DummyTokenizer(): + def __init__(self): + self.tokens = KNOWN_WORDS.copy() + self.bos_token_id = 49406 + self.eos_token_id = 49407 + self.pad_token_id = 49407 + self.unk_token_id = 49407 + + def convert_tokens_to_ids(self, token_str): + try: + return self.tokens.index(token_str) + except ValueError: + return self.unk_token_id + + def add_tokens(self, token_str): + self.tokens.append(token_str) + return 1 + + +class DummyClipEmbedder: + def __init__(self): + self.max_length = 77 + self.transformer = DummyTransformer() + self.tokenizer = DummyTokenizer() + + +class TextualInversionManagerTestCase(unittest.TestCase): + + + def test_construction(self): + tim = TextualInversionManager(DummyClipEmbedder()) + + def test_add_embedding_for_known_token(self): + tim = TextualInversionManager(DummyClipEmbedder()) + test_embedding = torch.randn([1, 768]) + test_embedding_name = KNOWN_WORDS[0] + self.assertFalse(tim.has_textual_inversion_for_trigger_string(test_embedding_name)) + + pre_embeddings_count = len(tim.clip_embedder.transformer.resize_token_embeddings(None)) + + token_id = tim.add_textual_inversion(test_embedding_name, test_embedding) + self.assertEqual(token_id, 0) + + + # check adding 'test' did not create a new word + embeddings_count = len(tim.clip_embedder.transformer.resize_token_embeddings(None)) + self.assertEqual(pre_embeddings_count, embeddings_count) + + # check it was added + self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name)) + textual_inversion = tim.get_textual_inversion_for_trigger_string(test_embedding_name) + self.assertIsNotNone(textual_inversion) + self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding)) + self.assertEqual(textual_inversion.trigger_string, test_embedding_name) + self.assertEqual(textual_inversion.token_id, token_id) + + def test_add_embedding_for_unknown_token(self): + tim = TextualInversionManager(DummyClipEmbedder()) + test_embedding_1 = torch.randn([1, 768]) + test_embedding_name_1 = UNKNOWN_WORDS[0] + + pre_embeddings_count = len(tim.clip_embedder.transformer.resize_token_embeddings(None)) + + added_token_id_1 = tim.add_textual_inversion(test_embedding_name_1, test_embedding_1) + # new token id should get added on the end + self.assertEqual(added_token_id_1, len(KNOWN_WORDS)) + + # check adding did create a new word + embeddings_count = len(tim.clip_embedder.transformer.resize_token_embeddings(None)) + self.assertEqual(pre_embeddings_count+1, embeddings_count) + + # check it was added + self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_1)) + textual_inversion = next(ti for ti in tim.textual_inversions if ti.token_id == added_token_id_1) + self.assertIsNotNone(textual_inversion) + self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_1)) + self.assertEqual(textual_inversion.trigger_string, test_embedding_name_1) + self.assertEqual(textual_inversion.token_id, added_token_id_1) + + # add another one + test_embedding_2 = torch.randn([1, 768]) + test_embedding_name_2 = UNKNOWN_WORDS[1] + + pre_embeddings_count = len(tim.clip_embedder.transformer.resize_token_embeddings(None)) + + added_token_id_2 = tim.add_textual_inversion(test_embedding_name_2, test_embedding_2) + self.assertEqual(added_token_id_2, len(KNOWN_WORDS)+1) + + # check adding did create a new word + embeddings_count = len(tim.clip_embedder.transformer.resize_token_embeddings(None)) + self.assertEqual(pre_embeddings_count+1, embeddings_count) + + # check it was added + self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_2)) + textual_inversion = next(ti for ti in tim.textual_inversions if ti.token_id == added_token_id_2) + self.assertIsNotNone(textual_inversion) + self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_2)) + self.assertEqual(textual_inversion.trigger_string, test_embedding_name_2) + self.assertEqual(textual_inversion.token_id, added_token_id_2) + + # check the old one is still there + self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_1)) + textual_inversion = next(ti for ti in tim.textual_inversions if ti.token_id == added_token_id_1) + self.assertIsNotNone(textual_inversion) + self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_1)) + self.assertEqual(textual_inversion.trigger_string, test_embedding_name_1) + self.assertEqual(textual_inversion.token_id, added_token_id_1) + + + def test_pad_raises_on_eos_bos(self): + tim = TextualInversionManager(DummyClipEmbedder()) + prompt_token_ids_with_eos_bos = [tim.clip_embedder.tokenizer.bos_token_id, + 0, 1, 2, + tim.clip_embedder.tokenizer.eos_token_id] + with self.assertRaises(ValueError): + expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_with_eos_bos) + + def test_pad_tokens_list_vector_length_1(self): + tim = TextualInversionManager(DummyClipEmbedder()) + prompt_token_ids = [0, 1, 2] + + expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids) + self.assertEqual(prompt_token_ids, expanded_prompt_token_ids) + + test_embedding_1v = torch.randn([1, 768]) + test_embedding_1v_token = "" + test_embedding_1v_token_id = tim.add_textual_inversion(test_embedding_1v_token, test_embedding_1v) + self.assertEqual(test_embedding_1v_token_id, len(KNOWN_WORDS)) + + # at the end + prompt_token_ids_1v_append = prompt_token_ids + [test_embedding_1v_token_id] + expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_1v_append) + self.assertEqual(prompt_token_ids_1v_append, expanded_prompt_token_ids) + + # at the start + prompt_token_ids_1v_prepend = [test_embedding_1v_token_id] + prompt_token_ids + expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_1v_prepend) + self.assertEqual(prompt_token_ids_1v_prepend, expanded_prompt_token_ids) + + # in the middle + prompt_token_ids_1v_insert = prompt_token_ids[0:2] + [test_embedding_1v_token_id] + prompt_token_ids[2:3] + expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_1v_insert) + self.assertEqual(prompt_token_ids_1v_insert, expanded_prompt_token_ids) + + def test_pad_tokens_list_vector_length_2(self): + tim = TextualInversionManager(DummyClipEmbedder()) + prompt_token_ids = [0, 1, 2] + + expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids) + self.assertEqual(prompt_token_ids, expanded_prompt_token_ids) + + test_embedding_2v = torch.randn([2, 768]) + test_embedding_2v_token = "" + test_embedding_2v_token_id = tim.add_textual_inversion(test_embedding_2v_token, test_embedding_2v) + self.assertEqual(test_embedding_2v_token_id, len(KNOWN_WORDS)) + + # at the end + prompt_token_ids_2v_append = prompt_token_ids + [test_embedding_2v_token_id] + expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_2v_append) + self.assertNotEqual(prompt_token_ids_2v_append, expanded_prompt_token_ids) + self.assertEqual(prompt_token_ids + [test_embedding_2v_token_id, tim.clip_embedder.tokenizer.pad_token_id], expanded_prompt_token_ids) + + # at the start + prompt_token_ids_2v_prepend = [test_embedding_2v_token_id] + prompt_token_ids + expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_2v_prepend) + self.assertNotEqual(prompt_token_ids_2v_prepend, expanded_prompt_token_ids) + self.assertEqual([test_embedding_2v_token_id, tim.clip_embedder.tokenizer.pad_token_id] + prompt_token_ids, expanded_prompt_token_ids) + + # in the middle + prompt_token_ids_2v_insert = prompt_token_ids[0:2] + [test_embedding_2v_token_id] + prompt_token_ids[2:3] + expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids_2v_insert) + self.assertNotEqual(prompt_token_ids_2v_insert, expanded_prompt_token_ids) + self.assertEqual(prompt_token_ids[0:2] + [test_embedding_2v_token_id, tim.clip_embedder.tokenizer.pad_token_id] + prompt_token_ids[2:3], expanded_prompt_token_ids) + + diff --git a/tests/text_textual_inversion.py b/tests/text_textual_inversion.py deleted file mode 100644 index c7f7105252..0000000000 --- a/tests/text_textual_inversion.py +++ /dev/null @@ -1,43 +0,0 @@ - -import unittest - -import torch - -from ldm.modules.embedding_manager import TextualInversionManager - - -class DummyClipEmbedder: - max_length = 77 - bos_token_id = 49406 - eos_token_id = 49407 - -class TextualInversionManagerTestCase(unittest.TestCase): - - - def test_construction(self): - tim = TextualInversionManager(DummyClipEmbedder()) - - def test_add_embedding(self): - tim = TextualInversionManager(DummyClipEmbedder()) - test_embedding = torch.random([1, 768]) - test_embedding_name = "test" - token_id = tim.add_textual_inversion(test_embedding_name, test_embedding) - self.assertTrue(tim.has_textual_inversion(test_embedding_name)) - - textual_inversion = next(ti for ti in tim.textual_inversions if ti.token_id == token_id) - self.assertIsNotNone(textual_inversion) - self.assertEqual(textual_inversion.embedding, test_embedding) - self.assertEqual(textual_inversion.token_string, test_embedding_name) - self.assertEqual(textual_inversion.token_id, token_id) - - def test_pad_tokens_list(self): - tim = TextualInversionManager(DummyClipEmbedder()) - prompt_token_ids = [DummyClipEmbedder.bos_token_id, 0, 1, 2, DummyClipEmbedder.eos_token_id] - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids) - self.assertEqual(prompt_token_ids, expanded_prompt_token_ids) - - test_embedding = torch.random([1, 768]) - test_embedding_name = "test" - tim.add_textual_inversion("", - - self.assertRaises()