From 664a6e9e146b42d96703f0cc8baf8f5efec04ee1 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Fri, 16 Dec 2022 12:48:38 +0100 Subject: [PATCH] use TextualInversionManager in place of embeddings (wip, doesn't work) --- ldm/generate.py | 21 +-- ldm/invoke/CLI.py | 11 +- ldm/invoke/concepts_lib.py | 8 +- ldm/invoke/conditioning.py | 2 +- ldm/invoke/readline.py | 4 +- ldm/models/diffusion/ddpm.py | 4 + ldm/modules/embedding_manager.py | 172 +-------------------- ldm/modules/encoders/modules.py | 111 +++++++++----- ldm/modules/textual_inversion_manager.py | 185 +++++++++++++++++++++++ 9 files changed, 290 insertions(+), 228 deletions(-) create mode 100644 ldm/modules/textual_inversion_manager.py diff --git a/ldm/generate.py b/ldm/generate.py index db36717135..c6e51f65f2 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -22,6 +22,7 @@ import skimage from omegaconf import OmegaConf import ldm.invoke.conditioning +from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary from ldm.invoke.generator.base import downsampling from PIL import Image, ImageOps from torch import nn @@ -41,7 +42,6 @@ from ldm.invoke.conditioning import get_uc_and_c_and_ec from ldm.invoke.model_cache import ModelCache from ldm.invoke.seamless import configure_model_padding from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale -from ldm.invoke.concepts_lib import Concepts def fix_func(orig): if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): @@ -438,7 +438,7 @@ class Generate: self._set_sampler() # apply the concepts library to the prompt - prompt = self.concept_lib().replace_concepts_with_triggers(prompt, lambda concepts: self.load_concepts(concepts)) + prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(prompt, lambda concepts: self.load_huggingface_concepts(concepts)) # bit of a hack to change the cached sampler's karras threshold to # whatever the user asked for @@ -862,19 +862,22 @@ class Generate: seed_everything(random.randrange(0, np.iinfo(np.uint32).max)) if self.embedding_path is not None: - self.model.embedding_manager.load( - self.embedding_path, self.precision == 'float32' or self.precision == 'autocast' - ) + for root, _, files in os.walk(self.embedding_path): + for name in files: + ti_path = os.path.join(root, name) + self.model.textual_inversion_manager.load_textual_inversion(ti_path) + print(f'>> Textual inversions available: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}') self._set_sampler() self.model_name = model_name return self.model - def load_concepts(self,concepts:list[str]): - self.model.embedding_manager.load_concepts(concepts, self.precision=='float32' or self.precision=='autocast') + def load_huggingface_concepts(self, concepts:list[str]): + self.model.textual_inversion_manager.load_huggingface_concepts(concepts) - def concept_lib(self)->Concepts: - return self.model.embedding_manager.concepts_library + @property + def huggingface_concepts_library(self) -> HuggingFaceConceptsLibrary: + return self.model.textual_inversion_manager.hf_concepts_library def correct_colors(self, image_list, diff --git a/ldm/invoke/CLI.py b/ldm/invoke/CLI.py index b5e32d7222..43fa47f597 100644 --- a/ldm/invoke/CLI.py +++ b/ldm/invoke/CLI.py @@ -16,7 +16,7 @@ from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_f from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata from ldm.invoke.image_util import make_grid from ldm.invoke.log import write_log -from ldm.invoke.concepts_lib 
import Concepts +from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary from omegaconf import OmegaConf from pathlib import Path import pyparsing @@ -133,6 +133,10 @@ def main(): main_loop(gen, opt) except KeyboardInterrupt: print("\ngoodbye!") + except Exception: + print(">> An error occurred:") + traceback.print_exc() + # TODO: main_loop() has gotten busy. Needs to be refactored. def main_loop(gen, opt): @@ -310,7 +314,7 @@ def main_loop(gen, opt): if use_prefix is not None: prefix = use_prefix postprocessed = upscaled if upscaled else operation=='postprocess' - opt.prompt = gen.concept_lib().replace_triggers_with_concepts(opt.prompt or prompt_in) # to avoid the problem of non-unique concept triggers + opt.prompt = gen.huggingface_concepts_library.replace_triggers_with_concepts(opt.prompt or prompt_in) # to avoid the problem of non-unique concept triggers filename, formatted_dream_prompt = prepare_image_metadata( opt, prefix, @@ -809,7 +813,8 @@ def add_embedding_terms(gen,completer): Called after setting the model, updates the autocompleter with any terms loaded by the embedding manager. ''' - completer.add_embedding_terms(gen.model.embedding_manager.list_terms()) + trigger_strings = gen.model.textual_inversion_manager.get_all_trigger_strings() + completer.add_embedding_terms(trigger_strings) def split_variations(variations_string) -> list: # shotgun parsing, woo diff --git a/ldm/invoke/concepts_lib.py b/ldm/invoke/concepts_lib.py index 942406acd3..679c1d9a62 100644 --- a/ldm/invoke/concepts_lib.py +++ b/ldm/invoke/concepts_lib.py @@ -12,7 +12,7 @@ from urllib import request, error as ul_error from huggingface_hub import HfFolder, hf_hub_url, ModelSearchArguments, ModelFilter, HfApi from ldm.invoke.globals import Globals -class Concepts(object): +class HuggingFaceConceptsLibrary(object): def __init__(self, root=None): ''' Initialize the Concepts object. May optionally pass a root directory. 
@@ -116,11 +116,11 @@ class Concepts(object): self.download_concept(concept_name) path = os.path.join(self._concept_path(concept_name), file_name) return path if os.path.exists(path) else None - + def concept_is_downloaded(self, concept_name)->bool: concept_directory = self._concept_path(concept_name) return os.path.exists(concept_directory) - + def download_concept(self,concept_name)->bool: repo_id = self._concept_id(concept_name) dest = self._concept_path(concept_name) @@ -133,7 +133,7 @@ class Concepts(object): os.makedirs(dest, exist_ok=True) succeeded = True - + bytes = 0 def tally_download_size(chunk, size, total): nonlocal bytes diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index aba329ccde..d70ce43589 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -231,7 +231,7 @@ def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedProm def _get_tokens_length(model, fragments: list[Fragment]): fragment_texts = [x.text for x in fragments] - tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False) + tokens = model.cond_stage_model.get_token_ids(fragment_texts, include_start_and_end_markers=False) return sum([len(x) for x in tokens]) diff --git a/ldm/invoke/readline.py b/ldm/invoke/readline.py index 9c21180dea..376e009296 100644 --- a/ldm/invoke/readline.py +++ b/ldm/invoke/readline.py @@ -12,7 +12,7 @@ import os import re import atexit from ldm.invoke.args import Args -from ldm.invoke.concepts_lib import Concepts +from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary from ldm.invoke.globals import Globals # ---------------readline utilities--------------------- @@ -276,7 +276,7 @@ class Completer(object): def _concept_completions(self, text, state): if self.concepts is None: - self.concepts = set(Concepts().list_concepts()) + self.concepts = set(HuggingFaceConceptsLibrary().list_concepts()) self.embedding_terms.update(self.concepts) partial = text[1:] # this removes the leading '<' diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py index e2e38459ff..d9fa762f0b 100644 --- a/ldm/models/diffusion/ddpm.py +++ b/ldm/models/diffusion/ddpm.py @@ -22,6 +22,7 @@ from pytorch_lightning.utilities.distributed import rank_zero_only from omegaconf import ListConfig import urllib +from ldm.modules.textual_inversion_manager import TextualInversionManager from ldm.util import ( log_txt_as_img, exists, @@ -678,6 +679,9 @@ class LatentDiffusion(DDPM): self.embedding_manager = self.instantiate_embedding_manager( personalization_config, self.cond_stage_model ) + self.textual_inversion_manager = TextualInversionManager(self.cond_stage_model, full_precision=True) + # this circular component dependency is gross and bad, needs to be rethought + self.cond_stage_model.set_textual_inversion_manager(self.textual_inversion_manager) self.emb_ckpt_counter = 0 diff --git a/ldm/modules/embedding_manager.py b/ldm/modules/embedding_manager.py index 613bf7a430..3f81918bd4 100644 --- a/ldm/modules/embedding_manager.py +++ b/ldm/modules/embedding_manager.py @@ -6,7 +6,7 @@ from torch import nn import sys -from ldm.invoke.concepts_lib import Concepts +from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary from ldm.data.personalized import per_img_token_list from transformers import CLIPTokenizer from functools import partial @@ -31,157 +31,6 @@ def get_embedding_for_clip_token_id(embedder, token_id): token_id = torch.tensor(token_id, dtype=torch.int) return 
embedder(token_id.unsqueeze(0))[0, 0] -@dataclass -class TextualInversion: - trigger_string: str - token_id: int - embedding: torch.Tensor - - @property - def embedding_vector_length(self) -> int: - return self.embedding.shape[0] - -class TextualInversionManager(): - def __init__(self, clip_embedder): - self.clip_embedder = clip_embedder - default_textual_inversions: list[TextualInversion] = [] - self.textual_inversions = default_textual_inversions - - def load_textual_inversion(self, ckpt_path, full_precision=True): - - scan_result = scan_file_path(ckpt_path) - if scan_result.infected_files == 1: - print(f'\n### Security Issues Found in Model: {scan_result.issues_count}') - print('### For your safety, InvokeAI will not load this embed.') - return - - ckpt = torch.load(ckpt_path, map_location='cpu') - - # Handle .pt textual inversion files - if 'string_to_token' in ckpt and 'string_to_param' in ckpt: - filename = os.path.basename(ckpt_path) - token_str = '.'.join(filename.split('.')[:-1]) # filename excluding extension - if len(ckpt["string_to_token"]) > 1: - print(f">> {ckpt_path} has >1 embedding, only the first will be used") - - string_to_param_dict = ckpt['string_to_param'] - embedding = list(string_to_param_dict.values())[0] - self.add_textual_inversion(token_str, embedding, full_precision) - - # Handle .bin textual inversion files from Huggingface Concepts - # https://huggingface.co/sd-concepts-library - else: - for token_str in list(ckpt.keys()): - embedding = ckpt[token_str] - self.add_textual_inversion(token_str, embedding, full_precision) - - def add_textual_inversion(self, token_str, embedding) -> int: - """ - Add a textual inversion to be recognised. - :param token_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added. - :param embedding: The actual embedding data that will be inserted into the conditioning at the point where the token_str appears. - :return: The token id for the added embedding, either existing or newly-added. 
- """ - if token_str in [ti.trigger_string for ti in self.textual_inversions]: - print(f">> TextualInversionManager refusing to overwrite already-loaded token '{token_str}'") - return - if len(embedding.shape) == 1: - embedding = embedding.unsqueeze(0) - elif len(embedding.shape) > 2: - raise ValueError(f"embedding shape {embedding.shape} is incorrect - must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2") - - existing_token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, token_str) - if existing_token_id == self.clip_embedder.tokenizer.unk_token_id: - num_tokens_added = self.clip_embedder.tokenizer.add_tokens(token_str) - current_embeddings = self.clip_embedder.transformer.resize_token_embeddings(None) - current_token_count = current_embeddings.num_embeddings - new_token_count = current_token_count + num_tokens_added - self.clip_embedder.transformer.resize_token_embeddings(new_token_count) - - token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, token_str) - self.textual_inversions.append(TextualInversion( - trigger_string=token_str, - token_id=token_id, - embedding=embedding - )) - return token_id - - def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool: - try: - ti = self.get_textual_inversion_for_trigger_string(trigger_string) - return ti is not None - except StopIteration: - return False - - def get_textual_inversion_for_trigger_string(self, trigger_string: str) -> TextualInversion: - return next(ti for ti in self.textual_inversions if ti.trigger_string == trigger_string) - - - def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion: - return next(ti for ti in self.textual_inversions if ti.token_id == token_id) - - def expand_textual_inversion_token_ids(self, prompt_token_ids: list[int]) -> list[int]: - """ - Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes. - - :param prompt_token_ids: The prompt as a list of token ids (`int`s). Should not include bos and eos markers. - :param pad_token_id: The token id to use to pad out the list to account for textual inversion vector lengths >1. - :return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too - long - caller is reponsible for truncating it if necessary and prepending/appending eos and bos token ids. 
- """ - if prompt_token_ids[0] == self.clip_embedder.tokenizer.bos_token_id: - raise ValueError("prompt_token_ids must not start with bos_token_id") - if prompt_token_ids[-1] == self.clip_embedder.tokenizer.eos_token_id: - raise ValueError("prompt_token_ids must not end with eos_token_id") - textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions] - prompt_token_ids = prompt_token_ids[:] - for i, token_id in reversed(list(enumerate(prompt_token_ids))): - if token_id in textual_inversion_token_ids: - textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id) - for pad_idx in range(1, textual_inversion.embedding_vector_length): - prompt_token_ids.insert(i+1, self.clip_embedder.tokenizer.pad_token_id) - - return prompt_token_ids - - def overwrite_textual_inversion_embeddings(self, prompt_token_ids: list[int], prompt_embeddings: torch.Tensor) -> torch.Tensor: - """ - For each token id in prompt_token_ids that refers to a loaded textual inversion, overwrite the corresponding - row in `prompt_embeddings` with the textual inversion embedding. If the embedding has vector length >1, overwrite - subsequent rows in `prompt_embeddings` as well. - - :param `prompt_token_ids`: Prompt token ids, already expanded to account for any textual inversions with vector lenght - >1 (call `expand_textual_inversion_token_ids()` to do this) and including bos and eos markers. - :param `prompt_embeddings`: Prompt embeddings tensor of shape with indices aligning to token ids in - `prompt_token_ids` (i.e., also already expanded). - :return: `The prompt_embeddings` tensor overwritten as appropriate with the textual inversion embeddings. - """ - if prompt_embeddings.shape[0] != self.clip_embedder.max_length: # typically 77 - raise ValueError(f"prompt_embeddings must have {self.clip_embedder.max_length} entries (has: {prompt_embeddings.shape[0]})") - if len(prompt_token_ids) > self.clip_embedder.max_length: - raise ValueError(f"prompt_token_ids is too long (has {len(prompt_token_ids)} token ids, should have {self.clip_embedder.max_length})") - if len(prompt_token_ids) < self.clip_embedder.max_length: - raise ValueError(f"prompt_token_ids is too short (has {len(prompt_token_ids)} token ids, it must be fully padded out to {self.clip_embedder.max_length} entries)") - if prompt_token_ids[0] != self.clip_embedder.tokenizer.bos_token_id or prompt_token_ids[-1] != self.clip_embedder.tokenizer.eos_token_id: - raise ValueError("prompt_token_ids must start with with bos token id and end with the eos token id") - - textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions] - pad_token_id = self.clip_embedder.tokenizer.pad_token_id - overwritten_prompt_embeddings = prompt_embeddings.clone() - for i, token_id in enumerate(prompt_token_ids): - if token_id == pad_token_id: - continue - if token_id in textual_inversion_token_ids: - textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id) - end_index = min(i + textual_inversion.embedding_vector_length, self.clip_embedder.max_length-1) - count_to_overwrite = end_index - i - for j in range(0, count_to_overwrite): - # only overwrite the textual inversion token id or the padding token id - if prompt_token_ids[i+j] != pad_token_id and prompt_token_ids[i+j] != token_id: - break - overwritten_prompt_embeddings[i+j] = textual_inversion.embedding[j] - - return overwritten_prompt_embeddings - class EmbeddingManager(nn.Module): def __init__( @@ -197,8 +46,7 @@ class EmbeddingManager(nn.Module): 
super().__init__() self.embedder = embedder - self.concepts_library=Concepts() - self.concepts_loaded = dict() + self.concepts_library=HuggingFaceConceptsLibrary() self.string_to_token_dict = {} self.string_to_param_dict = nn.ParameterDict() @@ -349,22 +197,6 @@ class EmbeddingManager(nn.Module): ckpt_path, ) - def load_concepts(self, concepts:list[str], full=True): - bin_files = list() - for concept_name in concepts: - if concept_name in self.concepts_loaded: - continue - else: - bin_file = self.concepts_library.get_concept_model_path(concept_name) - if not bin_file: - continue - bin_files.append(bin_file) - self.concepts_loaded[concept_name]=True - self.load(bin_files, full) - - def list_terms(self) -> list[str]: - return self.concepts_loaded.keys() - def load(self, ckpt_paths, full=True): if len(ckpt_paths) == 0: return diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index be9f88cdd2..4db2b1fbab 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -9,6 +9,7 @@ from transformers import CLIPTokenizer, CLIPTextModel import kornia from ldm.invoke.devices import choose_torch_device from ldm.invoke.globals import Globals +#from ldm.modules.textual_inversion_manager import TextualInversionManager from ldm.modules.x_transformer import ( Encoder, @@ -465,7 +466,12 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder): fragment_weights_key = "fragment_weights" return_tokens_key = "return_tokens" + def set_textual_inversion_manager(self, manager): #TextualInversionManager): + # TODO all of the weighting and expanding stuff needs be moved out of this class + self.textual_inversion_manager = manager + def forward(self, text: list, **kwargs): + # TODO all of the weighting and expanding stuff needs be moved out of this class ''' :param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different @@ -560,19 +566,42 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder): else: return batch_z - def get_tokens(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]: - tokens = self.tokenizer( + def get_token_ids(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]: + """ + Convert a list of strings like `["a cat", "sitting", "on a mat"]` into a list of lists of token ids like + `[[bos, 0, 1, eos], [bos, 2, eos], [bos, 3, 0, 4, eos]]`. bos/eos markers are skipped if + `include_start_and_end_markers` is `False`. Each list will be restricted to the maximum permitted length + (typically 75 tokens + eos/bos markers). + + :param fragments: The strings to convert. 
+        :param include_start_and_end_markers: Whether to bracket each returned list with the bos and eos marker token ids.
+        :return: A list of lists of token ids, one list per input fragment.
+        """
+        # for args documentation see ENCODE_KWARGS_DOCSTRING in tokenization_utils_base.py (in `transformers` lib)
+        token_ids_list = self.tokenizer(
             fragments,
             truncation=True,
             max_length=self.max_length,
             return_overflowing_tokens=False,
             padding='do_not_pad',
-            return_tensors=None, # just give me a list of ints
+            return_tensors=None, # just give me lists of ints
         )['input_ids']
-        if include_start_and_end_markers:
-            return tokens
-        else:
-            return [x[1:-1] for x in tokens]
+
+        result = []
+        for token_ids in token_ids_list:
+            # trim eos/bos
+            token_ids = token_ids[1:-1]
+            # pad for textual inversions with vector length >1
+            token_ids = self.textual_inversion_manager.expand_textual_inversion_token_ids(token_ids)
+            # restrict length to max_length-2 (leaving room for bos/eos)
+            token_ids = token_ids[0:self.max_length - 2]
+            # add back eos/bos if requested
+            if include_start_and_end_markers:
+                token_ids = [self.tokenizer.bos_token_id] + token_ids + [self.tokenizer.eos_token_id]
+
+            result.append(token_ids)
+
+        return result
 
     @classmethod
@@ -597,56 +626,60 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
         if len(fragments) == 0 and len(weights) == 0:
             fragments = ['']
             weights = [1]
-        item_encodings = self.tokenizer(
-            fragments,
-            truncation=True,
-            max_length=self.max_length,
-            return_overflowing_tokens=True,
-            padding='do_not_pad',
-            return_tensors=None, # just give me a list of ints
-        )['input_ids']
-        all_tokens = []
+        per_fragment_token_ids = self.get_token_ids(fragments, include_start_and_end_markers=False)
+        all_token_ids = []
         per_token_weights = []
         #print("all fragments:", fragments, weights)
-        for index, fragment in enumerate(item_encodings):
-            weight = weights[index]
+        for index, fragment in enumerate(per_fragment_token_ids):
+            weight = float(weights[index])
             #print("processing fragment", fragment, weight)
-            fragment_tokens = item_encodings[index]
-            #print("fragment", fragment, "processed to", fragment_tokens)
-            # trim bos and eos markers before appending
-            all_tokens.extend(fragment_tokens[1:-1])
-            per_token_weights.extend([weight] * (len(fragment_tokens) - 2))
+            this_fragment_token_ids = per_fragment_token_ids[index]
+            #print("fragment", fragment, "processed to", this_fragment_token_ids)
+            # append
+            all_token_ids += this_fragment_token_ids
+            # fill out weights tensor with one float per token
+            per_token_weights += [weight] * len(this_fragment_token_ids)
 
-        if (len(all_tokens) + 2) > self.max_length:
-            excess_token_count = (len(all_tokens) + 2) - self.max_length
+        # leave room for bos/eos
+        if len(all_token_ids) > self.max_length - 2:
+            excess_token_count = len(all_token_ids) - (self.max_length - 2)
+            # TODO build nice description string of how the truncation was applied
+            # this should be done by calling self.tokenizer.convert_ids_to_tokens() then passing the result to
+            # self.tokenizer.convert_tokens_to_string() for the token_ids on each side of the truncation limit.
             print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
-            all_tokens = all_tokens[:self.max_length - 2]
-            per_token_weights = per_token_weights[:self.max_length - 2]
+            all_token_ids = all_token_ids[0:self.max_length - 2]
+            per_token_weights = per_token_weights[0:self.max_length - 2]
 
         # pad out to a 77-entry array: [eos_token, , eos_token, ..., eos_token]
         # (77 = self.max_length)
-        pad_length = self.max_length - 1 - len(all_tokens)
-        all_tokens.insert(0, self.tokenizer.bos_token_id)
-        all_tokens.extend([self.tokenizer.eos_token_id] * pad_length)
-        per_token_weights.insert(0, 1)
-        per_token_weights.extend([1] * pad_length)
+        all_token_ids = [self.tokenizer.bos_token_id] + all_token_ids + [self.tokenizer.eos_token_id]
+        per_token_weights = [1.0] + per_token_weights + [1.0]
+        pad_length = self.max_length - len(all_token_ids)
+        all_token_ids += [self.tokenizer.eos_token_id] * pad_length
+        per_token_weights += [1.0] * pad_length
 
-        all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device)
+        all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long).to(self.device)
         per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device)
-        #print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}")
-        return all_tokens_tensor, per_token_weights_tensor
+        #print(f"assembled all_token_ids_tensor with shape {all_token_ids_tensor.shape}")
+        return all_token_ids_tensor, per_token_weights_tensor
 
-    def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
+    def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
         '''
         Build a tensor representing the passed-in tokens, each of which has a weight.
-        :param tokens: A tensor of shape (77) containing token ids (integers)
+        :param token_ids: A tensor of shape (77) containing token ids (integers)
         :param per_token_weights: A tensor of shape (77) containing weights (floats)
         :param method: Whether to multiply the whole feature vector for each token or just its distance from an "empty" feature vector
         :param kwargs: passed on to self.transformer()
         :return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings.
''' #print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}") - z = self.transformer(input_ids=tokens.unsqueeze(0), **kwargs) + if token_ids.shape[0] != self.max_length: + raise ValueError(f"token_ids has shape {token_ids.shape} - expected [{self.max_length}]") + z = self.transformer(input_ids=token_ids.unsqueeze(0), **kwargs) + assert(z.shape[0] == 1) + new_z0 = self.textual_inversion_manager.overwrite_textual_inversion_embeddings(token_ids, z[0]) + z[0] = new_z0 + batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape) if weight_delta_from_empty: @@ -660,7 +693,7 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder): z_delta_from_empty = z - empty_z weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded) - weighted_z_delta_from_empty = (weighted_z-empty_z) + #weighted_z_delta_from_empty = (weighted_z-empty_z) #print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() ) #print("using empty-delta method, first 5 rows:") diff --git a/ldm/modules/textual_inversion_manager.py b/ldm/modules/textual_inversion_manager.py new file mode 100644 index 0000000000..90293e0c82 --- /dev/null +++ b/ldm/modules/textual_inversion_manager.py @@ -0,0 +1,185 @@ +import os +from typing import Union + +import torch +from attr import dataclass +from picklescan.scanner import scan_file_path + +from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary +from ldm.modules.embedding_manager import get_clip_token_id_for_string +from ldm.modules.encoders.modules import FrozenCLIPEmbedder + + +@dataclass +class TextualInversion: + trigger_string: str + token_id: int + embedding: torch.Tensor + + @property + def embedding_vector_length(self) -> int: + return self.embedding.shape[0] + +class TextualInversionManager(): + def __init__(self, clip_embedder: FrozenCLIPEmbedder, full_precision: bool): + self.clip_embedder = clip_embedder + self.full_precision = full_precision + self.hf_concepts_library = HuggingFaceConceptsLibrary() + default_textual_inversions: list[TextualInversion] = [] + self.textual_inversions = default_textual_inversions + + def load_huggingface_concepts(self, concepts: list[str]): + for concept_name in concepts: + if concept_name in self.hf_concepts_library.concepts_loaded: + continue + bin_file = self.hf_concepts_library.get_concept_model_path(concept_name) + if not bin_file: + continue + self.load_textual_inversion(bin_file) + self.hf_concepts_library.concepts_loaded[concept_name]=True + + def get_all_trigger_strings(self) -> list[str]: + return [ti.trigger_string for ti in self.textual_inversions] + + def load_textual_inversion(self, ckpt_path): + scan_result = scan_file_path(ckpt_path) + if scan_result.infected_files == 1: + print(f'\n### Security Issues Found in Model: {scan_result.issues_count}') + print('### For your safety, InvokeAI will not load this embed.') + return + + ckpt = torch.load(ckpt_path, map_location='cpu') + + # Handle .pt textual inversion files + if 'string_to_token' in ckpt and 'string_to_param' in ckpt: + filename = os.path.basename(ckpt_path) + trigger_str = '.'.join(filename.split('.')[:-1]) # filename excluding extension + if len(ckpt["string_to_token"]) > 1: + print(f">> {ckpt_path} has >1 embedding, only the first will be used") + + string_to_param_dict = ckpt['string_to_param'] + embedding = list(string_to_param_dict.values())[0] + self.add_textual_inversion(trigger_str, embedding) + + # Handle 
.bin textual inversion files from Huggingface Concepts
+        # https://huggingface.co/sd-concepts-library
+        else:
+            for trigger_str in list(ckpt.keys()):
+                embedding = ckpt[trigger_str]
+                self.add_textual_inversion(trigger_str, embedding)
+
+    def add_textual_inversion(self, trigger_str, embedding) -> int:
+        """
+        Add a textual inversion to be recognised.
+        :param trigger_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
+        :param embedding: The actual embedding data that will be inserted into the conditioning at the point where the trigger_str appears.
+        :return: The token id for the added embedding, either existing or newly-added.
+        """
+        if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
+            print(f">> TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'")
+            return
+        if not self.full_precision:
+            embedding = embedding.half()
+        if len(embedding.shape) == 1:
+            embedding = embedding.unsqueeze(0)
+        elif len(embedding.shape) > 2:
+            raise ValueError(f"embedding shape {embedding.shape} is incorrect - must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2")
+
+        existing_token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, trigger_str)
+        if existing_token_id == self.clip_embedder.tokenizer.unk_token_id:
+            num_tokens_added = self.clip_embedder.tokenizer.add_tokens(trigger_str)
+            current_embeddings = self.clip_embedder.transformer.resize_token_embeddings(None)
+            current_token_count = current_embeddings.num_embeddings
+            new_token_count = current_token_count + num_tokens_added
+            self.clip_embedder.transformer.resize_token_embeddings(new_token_count)
+
+        token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, trigger_str)
+        self.textual_inversions.append(TextualInversion(
+            trigger_string=trigger_str,
+            token_id=token_id,
+            embedding=embedding
+        ))
+        return token_id
+
+    def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
+        try:
+            ti = self.get_textual_inversion_for_trigger_string(trigger_string)
+            return ti is not None
+        except StopIteration:
+            return False
+
+    def get_textual_inversion_for_trigger_string(self, trigger_string: str) -> TextualInversion:
+        return next(ti for ti in self.textual_inversions if ti.trigger_string == trigger_string)
+
+
+    def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
+        return next(ti for ti in self.textual_inversions if ti.token_id == token_id)
+
+    def expand_textual_inversion_token_ids(self, prompt_token_ids: list[int]) -> list[int]:
+        """
+        Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.
+        Padding uses the embedder tokenizer's pad_token_id, to account for textual inversion vector lengths >1.
+
+        :param prompt_token_ids: The prompt as a list of token ids (`int`s). Should not include bos and eos markers.
+        :return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too
+        long - caller is responsible for truncating it if necessary and prepending/appending eos and bos token ids.
+        """
+        if len(prompt_token_ids) == 0:
+            return prompt_token_ids
+
+        if prompt_token_ids[0] == self.clip_embedder.tokenizer.bos_token_id:
+            raise ValueError("prompt_token_ids must not start with bos_token_id")
+        if prompt_token_ids[-1] == self.clip_embedder.tokenizer.eos_token_id:
+            raise ValueError("prompt_token_ids must not end with eos_token_id")
+        textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
+        prompt_token_ids = prompt_token_ids[:]
+        for i, token_id in reversed(list(enumerate(prompt_token_ids))):
+            if token_id in textual_inversion_token_ids:
+                textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
+                for pad_idx in range(1, textual_inversion.embedding_vector_length):
+                    prompt_token_ids.insert(i+1, self.clip_embedder.tokenizer.pad_token_id)
+
+        return prompt_token_ids
+
+    def overwrite_textual_inversion_embeddings(self, prompt_token_ids: Union[torch.Tensor,list[int]], prompt_embeddings: torch.Tensor) -> torch.Tensor:
+        """
+        For each token id in prompt_token_ids that refers to a loaded textual inversion, overwrite the corresponding
+        row in `prompt_embeddings` with the textual inversion embedding. If the embedding has vector length >1, overwrite
+        subsequent rows in `prompt_embeddings` as well.
+
+        :param `prompt_token_ids`: Prompt token ids, already expanded to account for any textual inversions with vector length
+        >1 (call `expand_textual_inversion_token_ids()` to do this), padded to max length, and including bos and eos markers.
+        :param `prompt_embeddings`: Prompt embeddings tensor whose rows align with the token ids in
+        `prompt_token_ids` (i.e., also already expanded).
+        :return: The `prompt_embeddings` tensor overwritten as appropriate with the textual inversion embeddings.
+        """
+        if type(prompt_token_ids) is torch.Tensor:
+            if prompt_token_ids.shape != torch.Size([self.clip_embedder.max_length]):
+                raise ValueError(f"prompt_token_ids must be a list of length {self.clip_embedder.max_length} or a tensor of shape [{self.clip_embedder.max_length}]")
+            prompt_token_ids = list(prompt_token_ids.cpu().numpy())
+        if prompt_embeddings.shape[0] != self.clip_embedder.max_length: # typically 77
+            raise ValueError(f"prompt_embeddings must have {self.clip_embedder.max_length} entries (has: {prompt_embeddings.shape[0]})")
+        if len(prompt_token_ids) > self.clip_embedder.max_length:
+            raise ValueError(f"prompt_token_ids is too long (has {len(prompt_token_ids)} token ids, should have {self.clip_embedder.max_length})")
+        if len(prompt_token_ids) < self.clip_embedder.max_length:
+            raise ValueError(f"prompt_token_ids is too short (has {len(prompt_token_ids)} token ids, it must be fully padded out to {self.clip_embedder.max_length} entries)")
+        if prompt_token_ids[0] != self.clip_embedder.tokenizer.bos_token_id or prompt_token_ids[-1] != self.clip_embedder.tokenizer.eos_token_id:
+            raise ValueError("prompt_token_ids must start with the bos token id and end with the eos token id")
+
+        textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
+        pad_token_id = self.clip_embedder.tokenizer.pad_token_id
+        overwritten_prompt_embeddings = prompt_embeddings.clone()
+        for i, token_id in enumerate(prompt_token_ids):
+            if token_id == pad_token_id:
+                continue
+            if token_id in textual_inversion_token_ids:
+                textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
+                end_index = min(i + textual_inversion.embedding_vector_length, self.clip_embedder.max_length-1)
+                count_to_overwrite = end_index - i
+                for j in range(0, count_to_overwrite):
+                    # only overwrite the textual inversion token id or the padding token id
+                    if prompt_token_ids[i+j] != pad_token_id and prompt_token_ids[i+j] != token_id:
+                        break
+                    overwritten_prompt_embeddings[i+j] = textual_inversion.embedding[j]
+
+        return overwritten_prompt_embeddings