import os import traceback from dataclasses import dataclass from pathlib import Path from typing import Optional, Union import torch from picklescan.scanner import scan_file_path from transformers import CLIPTextModel, CLIPTokenizer from compel.embeddings_provider import BaseTextualInversionManager from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary @dataclass class TextualInversion: trigger_string: str embedding: torch.Tensor trigger_token_id: Optional[int] = None pad_token_ids: Optional[list[int]] = None @property def embedding_vector_length(self) -> int: return self.embedding.shape[0] class TextualInversionManager(BaseTextualInversionManager): def __init__( self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, full_precision: bool = True, ): self.tokenizer = tokenizer self.text_encoder = text_encoder self.full_precision = full_precision self.hf_concepts_library = HuggingFaceConceptsLibrary() self.trigger_to_sourcefile = dict() default_textual_inversions: list[TextualInversion] = [] self.textual_inversions = default_textual_inversions def load_huggingface_concepts(self, concepts: list[str]): for concept_name in concepts: if concept_name in self.hf_concepts_library.concepts_loaded: continue trigger = self.hf_concepts_library.concept_to_trigger(concept_name) if ( self.has_textual_inversion_for_trigger_string(trigger) or self.has_textual_inversion_for_trigger_string(concept_name) or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>") ): # in case a token with literal angle brackets encountered print(f">> Loaded local embedding for trigger {concept_name}") continue bin_file = self.hf_concepts_library.get_concept_model_path(concept_name) if not bin_file: continue print(f">> Loaded remote embedding for trigger {concept_name}") self.load_textual_inversion(bin_file) self.hf_concepts_library.concepts_loaded[concept_name] = True def get_all_trigger_strings(self) -> list[str]: return [ti.trigger_string for ti in self.textual_inversions] def load_textual_inversion( self, ckpt_path: Union[str, Path], defer_injecting_tokens: bool = False ): ckpt_path = Path(ckpt_path) if not ckpt_path.is_file(): return if str(ckpt_path).endswith(".DS_Store"): return try: scan_result = scan_file_path(str(ckpt_path)) if scan_result.infected_files == 1: print( f"\n### Security Issues Found in Model: {scan_result.issues_count}" ) print("### For your safety, InvokeAI will not load this embed.") return except Exception: print( f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt." ) return embedding_info = self._parse_embedding(str(ckpt_path)) if embedding_info is None: # We've already put out an error message about the bad embedding in _parse_embedding, so just return. return elif ( self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info["token_dim"] ): print( f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}." ) return # Resolve the situation in which an earlier embedding has claimed the same # trigger string. We replace the trigger with '', as we used to. trigger_str = embedding_info["name"] sourcefile = ( f"{ckpt_path.parent.name}/{ckpt_path.name}" if ckpt_path.name == "learned_embeds.bin" else ckpt_path.name ) if trigger_str in self.trigger_to_sourcefile: replacement_trigger_str = ( f"<{ckpt_path.parent.name}>" if ckpt_path.name == "learned_embeds.bin" else f"<{ckpt_path.stem}>" ) print( f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}" ) trigger_str = replacement_trigger_str try: self._add_textual_inversion( trigger_str, embedding_info["embedding"], defer_injecting_tokens=defer_injecting_tokens, ) # remember which source file claims this trigger self.trigger_to_sourcefile[trigger_str] = sourcefile except ValueError as e: print(f' | Ignoring incompatible embedding {embedding_info["name"]}') print(f" | The error was {str(e)}") def _add_textual_inversion( self, trigger_str, embedding, defer_injecting_tokens=False ) -> Optional[TextualInversion]: """ Add a textual inversion to be recognised. :param trigger_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added. :param embedding: The actual embedding data that will be inserted into the conditioning at the point where the token_str appears. :return: The token id for the added embedding, either existing or newly-added. """ if trigger_str in [ti.trigger_string for ti in self.textual_inversions]: print( f"** TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'" ) return if not self.full_precision: embedding = embedding.half() if len(embedding.shape) == 1: embedding = embedding.unsqueeze(0) elif len(embedding.shape) > 2: raise ValueError( f"** TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2." ) try: ti = TextualInversion(trigger_string=trigger_str, embedding=embedding) if not defer_injecting_tokens: self._inject_tokens_and_assign_embeddings(ti) self.textual_inversions.append(ti) return ti except ValueError as e: if str(e).startswith("Warning"): print(f">> {str(e)}") else: traceback.print_exc() print( f"** TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}." ) raise def _inject_tokens_and_assign_embeddings(self, ti: TextualInversion) -> int: if ti.trigger_token_id is not None: raise ValueError( f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'" ) trigger_token_id = self._get_or_create_token_id_and_assign_embedding( ti.trigger_string, ti.embedding[0] ) if ti.embedding_vector_length > 1: # for embeddings with vector length > 1 pad_token_strings = [ ti.trigger_string + "-!pad-" + str(pad_index) for pad_index in range(1, ti.embedding_vector_length) ] # todo: batched UI for faster loading when vector length >2 pad_token_ids = [ self._get_or_create_token_id_and_assign_embedding( pad_token_str, ti.embedding[1 + i] ) for (i, pad_token_str) in enumerate(pad_token_strings) ] else: pad_token_ids = [] ti.trigger_token_id = trigger_token_id ti.pad_token_ids = pad_token_ids return ti.trigger_token_id def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool: try: ti = self.get_textual_inversion_for_trigger_string(trigger_string) return ti is not None except StopIteration: return False def get_textual_inversion_for_trigger_string( self, trigger_string: str ) -> TextualInversion: return next( ti for ti in self.textual_inversions if ti.trigger_string == trigger_string ) def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion: return next( ti for ti in self.textual_inversions if ti.trigger_token_id == token_id ) def create_deferred_token_ids_for_any_trigger_terms( self, prompt_string: str ) -> list[int]: injected_token_ids = [] for ti in self.textual_inversions: if ti.trigger_token_id is None and ti.trigger_string in prompt_string: if ti.embedding_vector_length > 1: print( f">> Preparing tokens for textual inversion {ti.trigger_string}..." ) try: self._inject_tokens_and_assign_embeddings(ti) except ValueError as e: print( f" | Ignoring incompatible embedding trigger {ti.trigger_string}" ) print(f" | The error was {str(e)}") continue injected_token_ids.append(ti.trigger_token_id) injected_token_ids.extend(ti.pad_token_ids) return injected_token_ids def expand_textual_inversion_token_ids_if_necessary( self, prompt_token_ids: list[int] ) -> list[int]: """ Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes. :param prompt_token_ids: The prompt as a list of token ids (`int`s). Should not include bos and eos markers. :return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too long - caller is responsible for prepending/appending eos and bos token ids, and truncating if necessary. """ if len(prompt_token_ids) == 0: return prompt_token_ids if prompt_token_ids[0] == self.tokenizer.bos_token_id: raise ValueError("prompt_token_ids must not start with bos_token_id") if prompt_token_ids[-1] == self.tokenizer.eos_token_id: raise ValueError("prompt_token_ids must not end with eos_token_id") textual_inversion_trigger_token_ids = [ ti.trigger_token_id for ti in self.textual_inversions ] prompt_token_ids = prompt_token_ids.copy() for i, token_id in reversed(list(enumerate(prompt_token_ids))): if token_id in textual_inversion_trigger_token_ids: textual_inversion = next( ti for ti in self.textual_inversions if ti.trigger_token_id == token_id ) for pad_idx in range(0, textual_inversion.embedding_vector_length - 1): prompt_token_ids.insert( i + pad_idx + 1, textual_inversion.pad_token_ids[pad_idx] ) return prompt_token_ids def _get_or_create_token_id_and_assign_embedding( self, token_str: str, embedding: torch.Tensor ) -> int: if len(embedding.shape) != 1: raise ValueError( "Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2" ) existing_token_id = self.tokenizer.convert_tokens_to_ids(token_str) if existing_token_id == self.tokenizer.unk_token_id: num_tokens_added = self.tokenizer.add_tokens(token_str) current_embeddings = self.text_encoder.resize_token_embeddings(None) current_token_count = current_embeddings.num_embeddings new_token_count = current_token_count + num_tokens_added # the following call is slow - todo make batched for better performance with vector length >1 self.text_encoder.resize_token_embeddings(new_token_count) token_id = self.tokenizer.convert_tokens_to_ids(token_str) if token_id == self.tokenizer.unk_token_id: raise RuntimeError(f"Unable to find token id for token '{token_str}'") if ( self.text_encoder.get_input_embeddings().weight.data[token_id].shape != embedding.shape ): raise ValueError( f"Warning. Cannot load embedding for {token_str}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {self.text_encoder.get_input_embeddings().weight.data[token_id].shape[0]}." ) self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding return token_id def _parse_embedding(self, embedding_file: str): file_type = embedding_file.split(".")[-1] if file_type == "pt": return self._parse_embedding_pt(embedding_file) elif file_type == "bin": return self._parse_embedding_bin(embedding_file) else: print(f"** Notice: unrecognized embedding file format: {embedding_file}") return None def _parse_embedding_pt(self, embedding_file): embedding_ckpt = torch.load(embedding_file, map_location="cpu") embedding_info = {} # Check if valid embedding file if "string_to_token" and "string_to_param" in embedding_ckpt: # Catch variants that do not have the expected keys or values. try: embedding_info["name"] = embedding_ckpt["name"] or os.path.basename( os.path.splitext(embedding_file)[0] ) # Check num of embeddings and warn user only the first will be used embedding_info["num_of_embeddings"] = len( embedding_ckpt["string_to_token"] ) if embedding_info["num_of_embeddings"] > 1: print(">> More than 1 embedding found. Will use the first one") embedding = list(embedding_ckpt["string_to_param"].values())[0] except (AttributeError, KeyError): return self._handle_broken_pt_variants(embedding_ckpt, embedding_file) embedding_info["embedding"] = embedding embedding_info["num_vectors_per_token"] = embedding.size()[0] embedding_info["token_dim"] = embedding.size()[1] try: embedding_info["trained_steps"] = embedding_ckpt["step"] embedding_info["trained_model_name"] = embedding_ckpt[ "sd_checkpoint_name" ] embedding_info["trained_model_checksum"] = embedding_ckpt[ "sd_checkpoint" ] except AttributeError: print(">> No Training Details Found. Passing ...") # .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/ # They are actually .bin files elif len(embedding_ckpt.keys()) == 1: embedding_info = self._parse_embedding_bin(embedding_file) else: print(">> Invalid embedding format") embedding_info = None return embedding_info def _parse_embedding_bin(self, embedding_file): embedding_ckpt = torch.load(embedding_file, map_location="cpu") embedding_info = {} if list(embedding_ckpt.keys()) == 0: print(">> Invalid concepts file") embedding_info = None else: for token in list(embedding_ckpt.keys()): embedding_info["name"] = ( token or f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>" ) embedding_info["embedding"] = embedding_ckpt[token] embedding_info[ "num_vectors_per_token" ] = 1 # All Concepts seem to default to 1 embedding_info["token_dim"] = embedding_info["embedding"].size()[0] return embedding_info def _handle_broken_pt_variants( self, embedding_ckpt: dict, embedding_file: str ) -> dict: """ This handles the broken .pt file variants. We only know of one at present. """ embedding_info = {} if isinstance( list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor ): for token in list(embedding_ckpt["string_to_token"].keys()): embedding_info["name"] = ( token if token != "*" else f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>" ) embedding_info["embedding"] = embedding_ckpt[ "string_to_param" ].state_dict()[token] embedding_info["num_vectors_per_token"] = embedding_info[ "embedding" ].shape[0] embedding_info["token_dim"] = embedding_info["embedding"].size()[1] else: print(">> Invalid embedding format") embedding_info = None return embedding_info