import traceback from dataclasses import dataclass from pathlib import Path from typing import Optional, Union, List import safetensors.torch import torch from compel.embeddings_provider import BaseTextualInversionManager from picklescan.scanner import scan_file_path from transformers import CLIPTextModel, CLIPTokenizer from .concepts_lib import HuggingFaceConceptsLibrary @dataclass class EmbeddingInfo: name: str embedding: torch.Tensor num_vectors_per_token: int token_dim: int trained_steps: int = None trained_model_name: str = None trained_model_checksum: str = None @dataclass class TextualInversion: trigger_string: str embedding: torch.Tensor trigger_token_id: Optional[int] = None pad_token_ids: Optional[list[int]] = None @property def embedding_vector_length(self) -> int: return self.embedding.shape[0] class TextualInversionManager(BaseTextualInversionManager): def __init__( self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, full_precision: bool = True, ): self.tokenizer = tokenizer self.text_encoder = text_encoder self.full_precision = full_precision self.hf_concepts_library = HuggingFaceConceptsLibrary() self.trigger_to_sourcefile = dict() default_textual_inversions: list[TextualInversion] = [] self.textual_inversions = default_textual_inversions def load_huggingface_concepts(self, concepts: list[str]): for concept_name in concepts: if concept_name in self.hf_concepts_library.concepts_loaded: continue trigger = self.hf_concepts_library.concept_to_trigger(concept_name) if ( self.has_textual_inversion_for_trigger_string(trigger) or self.has_textual_inversion_for_trigger_string(concept_name) or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>") ): # in case a token with literal angle brackets encountered print(f">> Loaded local embedding for trigger {concept_name}") continue bin_file = self.hf_concepts_library.get_concept_model_path(concept_name) if not bin_file: continue print(f">> Loaded remote embedding for trigger {concept_name}") self.load_textual_inversion(bin_file) self.hf_concepts_library.concepts_loaded[concept_name] = True def get_all_trigger_strings(self) -> list[str]: return [ti.trigger_string for ti in self.textual_inversions] def load_textual_inversion( self, ckpt_path: Union[str, Path], defer_injecting_tokens: bool = False ): ckpt_path = Path(ckpt_path) if not ckpt_path.is_file(): return if str(ckpt_path).endswith(".DS_Store"): return embedding_list = self._parse_embedding(str(ckpt_path)) for embedding_info in embedding_list: if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim): print( f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}." ) continue # Resolve the situation in which an earlier embedding has claimed the same # trigger string. We replace the trigger with '', as we used to. trigger_str = embedding_info.name sourcefile = ( f"{ckpt_path.parent.name}/{ckpt_path.name}" if ckpt_path.name == "learned_embeds.bin" else ckpt_path.name ) if trigger_str in self.trigger_to_sourcefile: replacement_trigger_str = ( f"<{ckpt_path.parent.name}>" if ckpt_path.name == "learned_embeds.bin" else f"<{ckpt_path.stem}>" ) print( f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}" ) trigger_str = replacement_trigger_str try: self._add_textual_inversion( trigger_str, embedding_info.embedding, defer_injecting_tokens=defer_injecting_tokens, ) # remember which source file claims this trigger self.trigger_to_sourcefile[trigger_str] = sourcefile except ValueError as e: print(f' | Ignoring incompatible embedding {embedding_info["name"]}') print(f" | The error was {str(e)}") def _add_textual_inversion( self, trigger_str, embedding, defer_injecting_tokens=False ) -> Optional[TextualInversion]: """ Add a textual inversion to be recognised. :param trigger_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added. :param embedding: The actual embedding data that will be inserted into the conditioning at the point where the token_str appears. :return: The token id for the added embedding, either existing or newly-added. """ if trigger_str in [ti.trigger_string for ti in self.textual_inversions]: print( f"** TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'" ) return if not self.full_precision: embedding = embedding.half() if len(embedding.shape) == 1: embedding = embedding.unsqueeze(0) elif len(embedding.shape) > 2: raise ValueError( f"** TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2." ) try: ti = TextualInversion(trigger_string=trigger_str, embedding=embedding) if not defer_injecting_tokens: self._inject_tokens_and_assign_embeddings(ti) self.textual_inversions.append(ti) return ti except ValueError as e: if str(e).startswith("Warning"): print(f">> {str(e)}") else: traceback.print_exc() print( f"** TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}." ) raise def _inject_tokens_and_assign_embeddings(self, ti: TextualInversion) -> int: if ti.trigger_token_id is not None: raise ValueError( f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'" ) trigger_token_id = self._get_or_create_token_id_and_assign_embedding( ti.trigger_string, ti.embedding[0] ) if ti.embedding_vector_length > 1: # for embeddings with vector length > 1 pad_token_strings = [ ti.trigger_string + "-!pad-" + str(pad_index) for pad_index in range(1, ti.embedding_vector_length) ] # todo: batched UI for faster loading when vector length >2 pad_token_ids = [ self._get_or_create_token_id_and_assign_embedding( pad_token_str, ti.embedding[1 + i] ) for (i, pad_token_str) in enumerate(pad_token_strings) ] else: pad_token_ids = [] ti.trigger_token_id = trigger_token_id ti.pad_token_ids = pad_token_ids return ti.trigger_token_id def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool: try: ti = self.get_textual_inversion_for_trigger_string(trigger_string) return ti is not None except StopIteration: return False def get_textual_inversion_for_trigger_string( self, trigger_string: str ) -> TextualInversion: return next( ti for ti in self.textual_inversions if ti.trigger_string == trigger_string ) def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion: return next( ti for ti in self.textual_inversions if ti.trigger_token_id == token_id ) def create_deferred_token_ids_for_any_trigger_terms( self, prompt_string: str ) -> list[int]: injected_token_ids = [] for ti in self.textual_inversions: if ti.trigger_token_id is None and ti.trigger_string in prompt_string: if ti.embedding_vector_length > 1: print( f">> Preparing tokens for textual inversion {ti.trigger_string}..." ) try: self._inject_tokens_and_assign_embeddings(ti) except ValueError as e: print( f" | Ignoring incompatible embedding trigger {ti.trigger_string}" ) print(f" | The error was {str(e)}") continue injected_token_ids.append(ti.trigger_token_id) injected_token_ids.extend(ti.pad_token_ids) return injected_token_ids def expand_textual_inversion_token_ids_if_necessary( self, prompt_token_ids: list[int] ) -> list[int]: """ Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes. :param prompt_token_ids: The prompt as a list of token ids (`int`s). Should not include bos and eos markers. :return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too long - caller is responsible for prepending/appending eos and bos token ids, and truncating if necessary. """ if len(prompt_token_ids) == 0: return prompt_token_ids if prompt_token_ids[0] == self.tokenizer.bos_token_id: raise ValueError("prompt_token_ids must not start with bos_token_id") if prompt_token_ids[-1] == self.tokenizer.eos_token_id: raise ValueError("prompt_token_ids must not end with eos_token_id") textual_inversion_trigger_token_ids = [ ti.trigger_token_id for ti in self.textual_inversions ] prompt_token_ids = prompt_token_ids.copy() for i, token_id in reversed(list(enumerate(prompt_token_ids))): if token_id in textual_inversion_trigger_token_ids: textual_inversion = next( ti for ti in self.textual_inversions if ti.trigger_token_id == token_id ) for pad_idx in range(0, textual_inversion.embedding_vector_length - 1): prompt_token_ids.insert( i + pad_idx + 1, textual_inversion.pad_token_ids[pad_idx] ) return prompt_token_ids def _get_or_create_token_id_and_assign_embedding( self, token_str: str, embedding: torch.Tensor ) -> int: if len(embedding.shape) != 1: raise ValueError( "Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2" ) existing_token_id = self.tokenizer.convert_tokens_to_ids(token_str) if existing_token_id == self.tokenizer.unk_token_id: num_tokens_added = self.tokenizer.add_tokens(token_str) current_embeddings = self.text_encoder.resize_token_embeddings(None) current_token_count = current_embeddings.num_embeddings new_token_count = current_token_count + num_tokens_added # the following call is slow - todo make batched for better performance with vector length >1 self.text_encoder.resize_token_embeddings(new_token_count) token_id = self.tokenizer.convert_tokens_to_ids(token_str) if token_id == self.tokenizer.unk_token_id: raise RuntimeError(f"Unable to find token id for token '{token_str}'") if ( self.text_encoder.get_input_embeddings().weight.data[token_id].shape != embedding.shape ): raise ValueError( f"Warning. Cannot load embedding for {token_str}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {self.text_encoder.get_input_embeddings().weight.data[token_id].shape[0]}." ) self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding return token_id def _parse_embedding(self, embedding_file: str)->List[EmbeddingInfo]: suffix = Path(embedding_file).suffix try: if suffix in [".pt",".ckpt",".bin"]: scan_result = scan_file_path(embedding_file) if scan_result.infected_files > 0: print( f" ** Security Issues Found in Model: {scan_result.issues_count}" ) print(" ** For your safety, InvokeAI will not load this embed.") return list() ckpt = torch.load(embedding_file,map_location="cpu") else: ckpt = safetensors.torch.load_file(embedding_file) except Exception as e: print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}") return list() # try to figure out what kind of embedding file it is and parse accordingly keys = list(ckpt.keys()) if all(x in keys for x in ['string_to_token','string_to_param','name','step']): return self._parse_embedding_v1(ckpt, embedding_file) # example rem_rezero.pt elif all(x in keys for x in ['string_to_token','string_to_param']): return self._parse_embedding_v2(ckpt, embedding_file) # example midj-strong.pt elif 'emb_params' in keys: return self._parse_embedding_v3(ckpt, embedding_file) # example easynegative.safetensors else: return self._parse_embedding_v4(ckpt, embedding_file) # usually a '.bin' file def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]: basename = Path(file_path).stem print(f' | Loading v1 embedding file: {basename}') embeddings = list() token_counter = -1 for token,embedding in embedding_ckpt["string_to_param"].items(): if token_counter < 0: trigger = embedding_ckpt["name"] elif token_counter == 0: trigger = f'' else: trigger = f'<{basename}-{int(token_counter:=token_counter)}>' token_counter += 1 embedding_info = EmbeddingInfo( name = trigger, embedding = embedding, num_vectors_per_token = embedding.size()[0], token_dim = embedding.size()[1], trained_steps = embedding_ckpt["step"], trained_model_name = embedding_ckpt["sd_checkpoint_name"], trained_model_checksum = embedding_ckpt["sd_checkpoint"] ) embeddings.append(embedding_info) return embeddings def _parse_embedding_v2 ( self, embedding_ckpt: dict, file_path: str ) -> List[EmbeddingInfo]: """ This handles embedding .pt file variant #2. """ basename = Path(file_path).stem print(f' | Loading v2 embedding file: {basename}') embeddings = list() if isinstance( list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor ): token_counter = 0 for token,embedding in embedding_ckpt["string_to_param"].items(): trigger = token if token != '*' \ else f'<{basename}>' if token_counter == 0 \ else f'<{basename}-{int(token_counter:=token_counter+1)}>' embedding_info = EmbeddingInfo( name = trigger, embedding = embedding, num_vectors_per_token = embedding.size()[0], token_dim = embedding.size()[1], ) embeddings.append(embedding_info) else: print(f" ** {basename}: Unrecognized embedding format") return embeddings def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]: """ Parse 'version 3' of the .pt textual inversion embedding files. """ basename = Path(file_path).stem print(f' | Loading v3 embedding file: {basename}') embedding = embedding_ckpt['emb_params'] embedding_info = EmbeddingInfo( name = f'<{basename}>', embedding = embedding, num_vectors_per_token = embedding.size()[0], token_dim = embedding.size()[1], ) return [embedding_info] def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str)->List[EmbeddingInfo]: """ Parse 'version 4' of the textual inversion embedding files. This one is usually associated with .bin files trained by HuggingFace diffusers. """ basename = Path(filepath).stem short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name print(f' | Loading v4 embedding file: {short_path}') embeddings = list() if list(embedding_ckpt.keys()) == 0: print(f" ** Invalid embeddings file: {short_path}") else: for token,embedding in embedding_ckpt.items(): embedding_info = EmbeddingInfo( name = token or f"<{basename}>", embedding = embedding, num_vectors_per_token = 1, # All Concepts seem to default to 1 token_dim = embedding.size()[0], ) embeddings.append(embedding_info) return embeddings