From 2e80872e3b6f7fd7d8eb8928822bd824b63cb2ff Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Thu, 15 Dec 2022 10:57:57 +0100 Subject: [PATCH] wip new TextualInversionManager --- ldm/modules/embedding_manager.py | 188 ++++++++++++++++++++++--------- tests/text_textual_inversion.py | 43 +++++++ 2 files changed, 177 insertions(+), 54 deletions(-) create mode 100644 tests/text_textual_inversion.py diff --git a/ldm/modules/embedding_manager.py b/ldm/modules/embedding_manager.py index 239fd346ab..7f53f83039 100644 --- a/ldm/modules/embedding_manager.py +++ b/ldm/modules/embedding_manager.py @@ -1,6 +1,7 @@ import os.path from cmath import log import torch +from attr import dataclass from torch import nn import sys @@ -14,7 +15,7 @@ from picklescan.scanner import scan_file_path PROGRESSIVE_SCALE = 2000 -def get_clip_token_for_string(tokenizer, string): +def get_clip_token_id_for_string(tokenizer, string): batch_encoding = tokenizer( string, truncation=True, @@ -25,9 +26,9 @@ def get_clip_token_for_string(tokenizer, string): return_tensors='pt', ) tokens = batch_encoding['input_ids'] - """ assert ( + assert ( torch.count_nonzero(tokens - 49407) == 2 - ), f"String '{string}' maps to more than a single token. Please use another string" """ + ), f"String '{string}' maps to more than a single token. Please use another string" return tokens[0, 1] @@ -44,6 +45,134 @@ def get_bert_token_for_string(tokenizer, string): def get_embedding_for_clip_token(embedder, token): return embedder(token.unsqueeze(0))[0, 0] +@dataclass +class TextualInversion: + token_string: str + token_id: int + embedding: torch.Tensor + + @property + def embedding_vector_length(self) -> int: + return self.embedding.shape[0] + +class TextualInversionManager(): + def __init__(self, clip_embedder): + self.clip_embedder = clip_embedder + defatul_textual_inversions: list[TextualInversion] = [] + self.textual_inversions = defatul_textual_inversions + + def load_textual_inversion(self, ckpt_path, full_precision=True): + + scan_result = scan_file_path(ckpt_path) + if scan_result.infected_files == 1: + print(f'\n### Security Issues Found in Model: {scan_result.issues_count}') + print('### For your safety, InvokeAI will not load this embed.') + return + + ckpt = torch.load(ckpt_path, map_location='cpu') + + # Handle .pt textual inversion files + if 'string_to_token' in ckpt and 'string_to_param' in ckpt: + filename = os.path.basename(ckpt_path) + token_str = '.'.join(filename.split('.')[:-1]) # filename excluding extension + if len(ckpt["string_to_token"]) > 1: + print(f">> {ckpt_path} has >1 embedding, only the first will be used") + + string_to_param_dict = ckpt['string_to_param'] + embedding = list(string_to_param_dict.values())[0] + self.add_textual_inversion(token_str, embedding, full_precision) + + # Handle .bin textual inversion files from Huggingface Concepts + # https://huggingface.co/sd-concepts-library + else: + for token_str in list(ckpt.keys()): + embedding = ckpt[token_str] + self.add_textual_inversion(token_str, embedding, full_precision) + + def add_textual_inversion(self, token_str, embedding): + """ + Add a textual inversion to be recognised. + :param token_str: The trigger text in the prompt that activates this textual inversion. Should be unknown to the embedder's tokenizer. + :param embedding: The actual embedding data that will be inserted into the conditioning at the point where the token_str appears. + :return: The token id of the added embedding. + """ + if token_str in [ti.token_string for ti in self.textual_inversions]: + print(f">> Embedding manager refusing to overwrite already-loaded term '{token_str}'") + return + if len(embedding.shape) == 1: + embedding = embedding.unsqueeze(0) + elif len(embedding.shape) > 2: + raise ValueError(f"embedding shape {embedding.shape} is incorrect - must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2") + + num_tokens_added = self.clip_embedder.tokenizer.add_tokens(token_str) + current_embeddings = self.clip_embedder.transformer.resize_token_embeddings(None) + current_token_count = current_embeddings.num_embeddings + new_token_count = current_token_count + num_tokens_added + self.clip_embedder.transformer.resize_token_embeddings(new_token_count) + + token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, token_str) + self.textual_inversions.append(TextualInversion( + token_string=token_str, + token_id=token_id, + embedding=embedding + )) + + return token_id + + def has_textual_inversion(self, token_str): + return token_str in [ti.token_string for ti in self.textual_inversions] + + def expand_textual_inversion_token_ids(self, prompt_token_ids: list[int]) -> list[int]: + """ + Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes. + + :param prompt_token_ids: The prompt as a list of token ids (`int`s). Should not include bos and eos markers. + :param pad_token_id: The token id to use to pad out the list to account for textual inversion vector lengths >1. + :return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too + long - caller is reponsible for truncating it if necessary and prepending/appending eos and bos token ids. + """ + assert(prompt_token_ids[0] != self.clip_embedder.bos_token_id) + assert(prompt_token_ids[-1] != self.clip_embedder.eos_token_id) + textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions] + for i, token_id in reversed(list(enumerate(prompt_token_ids))): + if token_id in textual_inversion_token_ids: + textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id) + for pad_idx in range(1, textual_inversion.embedding_vector_length): + prompt_token_ids.insert(i+1, self.clip_embedder.pad_token_id) + + return prompt_token_ids + + def overwrite_textual_inversion_embeddings(self, prompt_token_ids: list[int], prompt_embeddings: torch.Tensor) -> torch.Tensor: + """ + For each token id in prompt_token_ids that refers to a loaded textual inversion, overwrite the corresponding + row in `prompt_embeddings` with the textual inversion embedding. If the embedding has vector length >1, overwrite + subsequent rows in `prompt_embeddings` as well. + + :param `prompt_token_ids`: Prompt token ids, already expanded to account for any textual inversions with vector lenght + >1 (call `expand_textual_inversion_token_ids()` to do this) + :param `prompt_embeddings`: Prompt embeddings tensor of shape with indices aligning to token ids in + `prompt_token_ids` (i.e., also already expanded). + :return: `The prompt_embeddings` tensor overwritten as appropriate with the textual inversion embeddings. + """ + assert(prompt_embeddings.shape[0] == self.clip_embedder.max_length, f"prompt_embeddings must have 77 entries (has: {prompt_embeddings.shape[0]})") + textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions] + for i, token_id in enumerate(prompt_token_ids): + if token_id == pad_token_id: + continue + if token_id in textual_inversion_token_ids: + textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id) + for j in range(0, textual_inversion.embedding_vector_length): + # only overwrite the textual inversion token id or the padding token id + if prompt_token_ids[i+j] != self.clip_embedder.pad_token_id and prompt_token_ids[i+j] != token_id: + break + prompt_embeddings[i+j] = textual_inversion.embedding[j] + + return prompt_embeddings + + + + + class EmbeddingManager(nn.Module): def __init__( self, @@ -78,7 +207,7 @@ class EmbeddingManager(nn.Module): ): # using Stable Diffusion's CLIP encoder self.is_clip = True get_token_for_string = partial( - get_clip_token_for_string, embedder.tokenizer + get_clip_token_id_for_string, embedder.tokenizer ) get_embedding_for_tkn = partial( get_embedding_for_clip_token, @@ -241,7 +370,7 @@ class EmbeddingManager(nn.Module): # both will be stored in this dictionary for term in self.string_to_param_dict.keys(): term = term.strip('<').strip('>') - self.concepts_loaded[term] = True + self.concepts_loaded[term] = True print(f'>> Current embedding manager terms: {", ".join(self.string_to_param_dict.keys())}') def _expand_directories(self, paths:list[str]): @@ -255,55 +384,6 @@ class EmbeddingManager(nn.Module): expanded_paths.append(os.path.join(root,name)) return [x for x in expanded_paths if os.path.splitext(x)[1] in ('.pt','.bin')] - def _load(self, ckpt_path, full=True): - - scan_result = scan_file_path(ckpt_path) - if scan_result.infected_files == 1: - print(f'\n### Security Issues Found in Model: {scan_result.issues_count}') - print('### For your safety, InvokeAI will not load this embed.') - return - - ckpt = torch.load(ckpt_path, map_location='cpu') - - # Handle .pt textual inversion files - if 'string_to_token' in ckpt and 'string_to_param' in ckpt: - filename = os.path.basename(ckpt_path) - token_str = '.'.join(filename.split('.')[:-1]) # filename excluding extension - if len(ckpt["string_to_token"]) > 1: - print(f">> {ckpt_path} has >1 embedding, only the first will be used") - - string_to_param_dict = ckpt['string_to_param'] - embedding = list(string_to_param_dict.values())[0] - self.add_embedding(token_str, embedding, full) - - # Handle .bin textual inversion files from Huggingface Concepts - # https://huggingface.co/sd-concepts-library - else: - for token_str in list(ckpt.keys()): - embedding = ckpt[token_str] - self.add_embedding(token_str, embedding, full) - - def add_embedding(self, token_str, embedding, full): - if token_str in self.string_to_param_dict: - print(f">> Embedding manager refusing to overwrite already-loaded term '{token_str}'") - return - if not full: - embedding = embedding.half() - if len(embedding.shape) == 1: - embedding = embedding.unsqueeze(0) - - num_tokens_added = self.embedder.tokenizer.add_tokens(token_str) - current_embeddings = self.embedder.transformer.resize_token_embeddings(None) - current_token_count = current_embeddings.num_embeddings - new_token_count = current_token_count + num_tokens_added - self.embedder.transformer.resize_token_embeddings(new_token_count) - - token = get_clip_token_for_string(self.embedder.tokenizer, token_str) - self.string_to_token_dict[token_str] = token - self.string_to_param_dict[token_str] = torch.nn.Parameter(embedding) - - def has_embedding_for_token(self, token_str): - return token_str in self.string_to_token_dict def get_embedding_norms_squared(self): all_params = torch.cat( diff --git a/tests/text_textual_inversion.py b/tests/text_textual_inversion.py new file mode 100644 index 0000000000..c7f7105252 --- /dev/null +++ b/tests/text_textual_inversion.py @@ -0,0 +1,43 @@ + +import unittest + +import torch + +from ldm.modules.embedding_manager import TextualInversionManager + + +class DummyClipEmbedder: + max_length = 77 + bos_token_id = 49406 + eos_token_id = 49407 + +class TextualInversionManagerTestCase(unittest.TestCase): + + + def test_construction(self): + tim = TextualInversionManager(DummyClipEmbedder()) + + def test_add_embedding(self): + tim = TextualInversionManager(DummyClipEmbedder()) + test_embedding = torch.random([1, 768]) + test_embedding_name = "test" + token_id = tim.add_textual_inversion(test_embedding_name, test_embedding) + self.assertTrue(tim.has_textual_inversion(test_embedding_name)) + + textual_inversion = next(ti for ti in tim.textual_inversions if ti.token_id == token_id) + self.assertIsNotNone(textual_inversion) + self.assertEqual(textual_inversion.embedding, test_embedding) + self.assertEqual(textual_inversion.token_string, test_embedding_name) + self.assertEqual(textual_inversion.token_id, token_id) + + def test_pad_tokens_list(self): + tim = TextualInversionManager(DummyClipEmbedder()) + prompt_token_ids = [DummyClipEmbedder.bos_token_id, 0, 1, 2, DummyClipEmbedder.eos_token_id] + expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids) + self.assertEqual(prompt_token_ids, expanded_prompt_token_ids) + + test_embedding = torch.random([1, 768]) + test_embedding_name = "test" + tim.add_textual_inversion("", + + self.assertRaises()