import os.path from cmath import log import torch from attr import dataclass from torch import nn import sys from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary from ldm.data.personalized import per_img_token_list from transformers import CLIPTokenizer from functools import partial from picklescan.scanner import scan_file_path PROGRESSIVE_SCALE = 2000 def get_clip_token_id_for_string(tokenizer: CLIPTokenizer, token_str: str) -> int: token_id = tokenizer.convert_tokens_to_ids(token_str) return token_id def get_embedding_for_clip_token_id(embedder, token_id): if type(token_id) is not torch.Tensor: token_id = torch.tensor(token_id, dtype=torch.int) return embedder(token_id.unsqueeze(0))[0, 0] class EmbeddingManager(nn.Module): def __init__( self, embedder, placeholder_strings=None, initializer_words=None, per_image_tokens=False, num_vectors_per_token=1, progressive_words=False, **kwargs, ): super().__init__() self.embedder = embedder self.concepts_library=HuggingFaceConceptsLibrary() self.string_to_token_dict = {} self.string_to_param_dict = nn.ParameterDict() self.initial_embeddings = ( nn.ParameterDict() ) # These should not be optimized self.progressive_words = progressive_words self.progressive_counter = 0 self.max_vectors_per_token = num_vectors_per_token if hasattr( embedder, 'tokenizer' ): # using Stable Diffusion's CLIP encoder self.is_clip = True get_token_id_for_string = partial( get_clip_token_id_for_string, embedder.tokenizer ) get_embedding_for_tkn_id = partial( get_embedding_for_clip_token_id, embedder.transformer.text_model.embeddings, ) # per bug report #572 #token_dim = 1280 token_dim = 768 else: # using LDM's BERT encoder self.is_clip = False get_token_id_for_string = partial( get_bert_token_id_for_string, embedder.tknz_fn ) get_embedding_for_tkn_id = embedder.transformer.token_emb token_dim = 1280 if per_image_tokens: placeholder_strings.extend(per_img_token_list) for idx, placeholder_string in enumerate(placeholder_strings): token_id = get_token_id_for_string(placeholder_string) if initializer_words and idx < len(initializer_words): init_word_token_id = get_token_id_for_string(initializer_words[idx]) with torch.no_grad(): init_word_embedding = get_embedding_for_tkn_id(init_word_token_id) token_params = torch.nn.Parameter( init_word_embedding.unsqueeze(0).repeat( num_vectors_per_token, 1 ), requires_grad=True, ) self.initial_embeddings[ placeholder_string ] = torch.nn.Parameter( init_word_embedding.unsqueeze(0).repeat( num_vectors_per_token, 1 ), requires_grad=False, ) else: token_params = torch.nn.Parameter( torch.rand( size=(num_vectors_per_token, token_dim), requires_grad=True, ) ) self.string_to_token_dict[placeholder_string] = token_id self.string_to_param_dict[placeholder_string] = token_params def forward( self, tokenized_text, embedded_text, ): # torch.save(embedded_text, '/tmp/embedding-manager-uglysonic-pre-rewrite.pt') b, n, device = *tokenized_text.shape, tokenized_text.device for ( placeholder_string, placeholder_token, ) in self.string_to_token_dict.items(): placeholder_embedding = self.string_to_param_dict[ placeholder_string ].to(device) if self.progressive_words: self.progressive_counter += 1 max_step_tokens = ( 1 + self.progressive_counter // PROGRESSIVE_SCALE ) else: max_step_tokens = self.max_vectors_per_token num_vectors_for_token = min( placeholder_embedding.shape[0], max_step_tokens ) placeholder_rows, placeholder_cols = torch.where( tokenized_text == placeholder_token ) if placeholder_rows.nelement() == 0: continue sorted_cols, sort_idx = torch.sort( placeholder_cols, descending=True ) sorted_rows = placeholder_rows[sort_idx] for idx in range(sorted_rows.shape[0]): row = sorted_rows[idx] col = sorted_cols[idx] new_token_row = torch.cat( [ tokenized_text[row][:col], torch.tensor([placeholder_token] * num_vectors_for_token, device=device), tokenized_text[row][col + 1 :], ], axis=0, )[:n] new_embed_row = torch.cat( [ embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1 :], ], axis=0, )[:n] embedded_text[row] = new_embed_row tokenized_text[row] = new_token_row return embedded_text def save(self, ckpt_path): torch.save( { 'string_to_token': self.string_to_token_dict, 'string_to_param': self.string_to_param_dict, }, ckpt_path, ) def load(self, ckpt_paths, full=True): if len(ckpt_paths) == 0: return if type(ckpt_paths) != list: ckpt_paths = [ckpt_paths] ckpt_paths = self._expand_directories(ckpt_paths) for c in ckpt_paths: self._load(c,full) # remember that we know this term and don't try to download it again from the concepts library # note that if the concept name is also provided and different from the trigger term, they # both will be stored in this dictionary for term in self.string_to_param_dict.keys(): term = term.strip('<').strip('>') self.concepts_loaded[term] = True print(f'>> Current embedding manager terms: {", ".join(self.string_to_param_dict.keys())}') def _expand_directories(self, paths:list[str]): expanded_paths = list() for path in paths: if os.path.isfile(path): expanded_paths.append(path) elif os.path.isdir(path): for root, _, files in os.walk(path): for name in files: expanded_paths.append(os.path.join(root,name)) return [x for x in expanded_paths if os.path.splitext(x)[1] in ('.pt','.bin')] def _load(self, ckpt_path, full=True): try: scan_result = scan_file_path(ckpt_path) if scan_result.infected_files == 1: print(f'\n### Security Issues Found in Model: {scan_result.issues_count}') print('### For your safety, InvokeAI will not load this embed.') return except Exception: print(f"### WARNING::: Invalid or corrupt embeddings found. Ignoring: {ckpt_path}") return embedding_info = self.parse_embedding(ckpt_path) if embedding_info: self.max_vectors_per_token = embedding_info['num_vectors_per_token'] self.add_embedding(embedding_info['name'], embedding_info['embedding'], full) else: print(f'>> Failed to load embedding located at {ckpt_path}. Unsupported file.') def add_embedding(self, token_str, embedding, full): if token_str in self.string_to_param_dict: print(f">> Embedding manager refusing to overwrite already-loaded term '{token_str}'") return if not full: embedding = embedding.half() if len(embedding.shape) == 1: embedding = embedding.unsqueeze(0) existing_token_id = get_clip_token_id_for_string(self.embedder.tokenizer, token_str) if existing_token_id == self.embedder.tokenizer.unk_token_id: num_tokens_added = self.embedder.tokenizer.add_tokens(token_str) current_embeddings = self.embedder.transformer.resize_token_embeddings(None) current_token_count = current_embeddings.num_embeddings new_token_count = current_token_count + num_tokens_added self.embedder.transformer.resize_token_embeddings(new_token_count) token_id = get_clip_token_id_for_string(self.embedder.tokenizer, token_str) self.string_to_token_dict[token_str] = token_id self.string_to_param_dict[token_str] = torch.nn.Parameter(embedding) def parse_embedding(self, embedding_file: str): file_type = embedding_file.split('.')[-1] if file_type == 'pt': return self.parse_embedding_pt(embedding_file) elif file_type == 'bin': return self.parse_embedding_bin(embedding_file) else: print(f'>> Not a recognized embedding file: {embedding_file}') def parse_embedding_pt(self, embedding_file): embedding_ckpt = torch.load(embedding_file, map_location='cpu') embedding_info = {} # Check if valid embedding file if 'string_to_token' and 'string_to_param' in embedding_ckpt: # Catch variants that do not have the expected keys or values. try: embedding_info['name'] = embedding_ckpt['name'] or os.path.basename(os.path.splitext(embedding_file)[0]) # Check num of embeddings and warn user only the first will be used embedding_info['num_of_embeddings'] = len(embedding_ckpt["string_to_token"]) if embedding_info['num_of_embeddings'] > 1: print('>> More than 1 embedding found. Will use the first one') embedding = list(embedding_ckpt['string_to_param'].values())[0] except (AttributeError,KeyError): return self.handle_broken_pt_variants(embedding_ckpt, embedding_file) embedding_info['embedding'] = embedding embedding_info['num_vectors_per_token'] = embedding.size()[0] embedding_info['token_dim'] = embedding.size()[1] try: embedding_info['trained_steps'] = embedding_ckpt['step'] embedding_info['trained_model_name'] = embedding_ckpt['sd_checkpoint_name'] embedding_info['trained_model_checksum'] = embedding_ckpt['sd_checkpoint'] except AttributeError: print(">> No Training Details Found. Passing ...") # .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/ # They are actually .bin files elif len(embedding_ckpt.keys())==1: print('>> Detected .bin file masquerading as .pt file') embedding_info = self.parse_embedding_bin(embedding_file) else: print('>> Invalid embedding format') embedding_info = None return embedding_info def parse_embedding_bin(self, embedding_file): embedding_ckpt = torch.load(embedding_file, map_location='cpu') embedding_info = {} if list(embedding_ckpt.keys()) == 0: print(">> Invalid concepts file") embedding_info = None else: for token in list(embedding_ckpt.keys()): embedding_info['name'] = token or os.path.basename(os.path.splitext(embedding_file)[0]) embedding_info['embedding'] = embedding_ckpt[token] embedding_info['num_vectors_per_token'] = 1 # All Concepts seem to default to 1 embedding_info['token_dim'] = embedding_info['embedding'].size()[0] return embedding_info def handle_broken_pt_variants(self, embedding_ckpt:dict, embedding_file:str)->dict: ''' This handles the broken .pt file variants. We only know of one at present. ''' embedding_info = {} if isinstance(list(embedding_ckpt['string_to_token'].values())[0],torch.Tensor): print(f'>> Variant Embedding Detected. Parsing: {embedding_file}') # example at https://github.com/invoke-ai/InvokeAI/issues/1829 token = list(embedding_ckpt['string_to_token'].keys())[0] embedding_info['name'] = os.path.basename(os.path.splitext(embedding_file)[0]) embedding_info['embedding'] = embedding_ckpt['string_to_param'].state_dict()[token] embedding_info['num_vectors_per_token'] = embedding_info['embedding'].shape[0] embedding_info['token_dim'] = embedding_info['embedding'].size()[0] else: print('>> Invalid embedding format') embedding_info = None return embedding_info def has_embedding_for_token(self, token_str): return token_str in self.string_to_token_dict def get_embedding_norms_squared(self): all_params = torch.cat( list(self.string_to_param_dict.values()), axis=0 ) # num_placeholders x embedding_dim param_norm_squared = (all_params * all_params).sum( axis=-1 ) # num_placeholders return param_norm_squared def embedding_parameters(self): return self.string_to_param_dict.parameters() def embedding_to_coarse_loss(self): loss = 0.0 num_embeddings = len(self.initial_embeddings) for key in self.initial_embeddings: optimized = self.string_to_param_dict[key] coarse = self.initial_embeddings[key].clone().to(optimized.device) loss = ( loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings ) return loss