diff --git a/invokeai/backend/invoke_ai_web_server.py b/invokeai/backend/invoke_ai_web_server.py index f624cb7710..f6e9446618 100644 --- a/invokeai/backend/invoke_ai_web_server.py +++ b/invokeai/backend/invoke_ai_web_server.py @@ -26,7 +26,7 @@ from invokeai.backend.modules.get_canvas_generation_mode import ( from invokeai.backend.modules.parameters import parameters_to_command from ldm.generate import Generate from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash -from ldm.invoke.conditioning import get_tokens_for_prompt_object, get_prompt_structure, get_tokenizer +from invokeai.backend.ldm.conditioning import get_tokens_for_prompt_object, get_prompt_structure, get_tokenizer from .generator import infill_methods, PipelineIntermediateState from ldm.invoke.globals import ( Globals, global_converted_ckpts_dir, global_models_dir diff --git a/invokeai/backend/ldm/modules/embedding_manager.py b/invokeai/backend/ldm/modules/embedding_manager.py deleted file mode 100644 index 3d60e70eb7..0000000000 --- a/invokeai/backend/ldm/modules/embedding_manager.py +++ /dev/null @@ -1,377 +0,0 @@ -import os.path -from cmath import log -import torch -from attr import dataclass -from torch import nn - -import sys - -from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary -from ldm.data.personalized import per_img_token_list -from transformers import CLIPTokenizer -from functools import partial -from picklescan.scanner import scan_file_path - -PROGRESSIVE_SCALE = 2000 - - -def get_clip_token_id_for_string(tokenizer: CLIPTokenizer, token_str: str) -> int: - token_id = tokenizer.convert_tokens_to_ids(token_str) - return token_id - -def get_embedding_for_clip_token_id(embedder, token_id): - if type(token_id) is not torch.Tensor: - token_id = torch.tensor(token_id, dtype=torch.int) - return embedder(token_id.unsqueeze(0))[0, 0] - - -class EmbeddingManager(nn.Module): - def __init__( - self, - embedder, - placeholder_strings=None, - initializer_words=None, - per_image_tokens=False, - num_vectors_per_token=1, - progressive_words=False, - **kwargs, - ): - super().__init__() - - self.embedder = embedder - self.concepts_library=HuggingFaceConceptsLibrary() - - self.string_to_token_dict = {} - self.string_to_param_dict = nn.ParameterDict() - - self.initial_embeddings = ( - nn.ParameterDict() - ) # These should not be optimized - - self.progressive_words = progressive_words - self.progressive_counter = 0 - - self.max_vectors_per_token = num_vectors_per_token - - if hasattr( - embedder, 'tokenizer' - ): # using Stable Diffusion's CLIP encoder - self.is_clip = True - get_token_id_for_string = partial( - get_clip_token_id_for_string, embedder.tokenizer - ) - get_embedding_for_tkn_id = partial( - get_embedding_for_clip_token_id, - embedder.transformer.text_model.embeddings, - ) - # per bug report #572 - #token_dim = 1280 - token_dim = 768 - else: # using LDM's BERT encoder - self.is_clip = False - get_token_id_for_string = partial( - get_bert_token_id_for_string, embedder.tknz_fn - ) - get_embedding_for_tkn_id = embedder.transformer.token_emb - token_dim = 1280 - - if per_image_tokens: - placeholder_strings.extend(per_img_token_list) - - for idx, placeholder_string in enumerate(placeholder_strings): - - token_id = get_token_id_for_string(placeholder_string) - - if initializer_words and idx < len(initializer_words): - init_word_token_id = get_token_id_for_string(initializer_words[idx]) - - with torch.no_grad(): - init_word_embedding = get_embedding_for_tkn_id(init_word_token_id) - - token_params = torch.nn.Parameter( - init_word_embedding.unsqueeze(0).repeat( - num_vectors_per_token, 1 - ), - requires_grad=True, - ) - self.initial_embeddings[ - placeholder_string - ] = torch.nn.Parameter( - init_word_embedding.unsqueeze(0).repeat( - num_vectors_per_token, 1 - ), - requires_grad=False, - ) - else: - token_params = torch.nn.Parameter( - torch.rand( - size=(num_vectors_per_token, token_dim), - requires_grad=True, - ) - ) - - self.string_to_token_dict[placeholder_string] = token_id - self.string_to_param_dict[placeholder_string] = token_params - - def forward( - self, - tokenized_text, - embedded_text, - ): - # torch.save(embedded_text, '/tmp/embedding-manager-uglysonic-pre-rewrite.pt') - - b, n, device = *tokenized_text.shape, tokenized_text.device - - for ( - placeholder_string, - placeholder_token, - ) in self.string_to_token_dict.items(): - - placeholder_embedding = self.string_to_param_dict[ - placeholder_string - ].to(device) - - if self.progressive_words: - self.progressive_counter += 1 - max_step_tokens = ( - 1 + self.progressive_counter // PROGRESSIVE_SCALE - ) - else: - max_step_tokens = self.max_vectors_per_token - - num_vectors_for_token = min( - placeholder_embedding.shape[0], max_step_tokens - ) - - placeholder_rows, placeholder_cols = torch.where( - tokenized_text == placeholder_token - ) - - if placeholder_rows.nelement() == 0: - continue - - sorted_cols, sort_idx = torch.sort( - placeholder_cols, descending=True - ) - sorted_rows = placeholder_rows[sort_idx] - - for idx in range(sorted_rows.shape[0]): - row = sorted_rows[idx] - col = sorted_cols[idx] - - new_token_row = torch.cat( - [ - tokenized_text[row][:col], - torch.tensor([placeholder_token] * num_vectors_for_token, device=device), - tokenized_text[row][col + 1 :], - ], - axis=0, - )[:n] - new_embed_row = torch.cat( - [ - embedded_text[row][:col], - placeholder_embedding[:num_vectors_for_token], - embedded_text[row][col + 1 :], - ], - axis=0, - )[:n] - - embedded_text[row] = new_embed_row - tokenized_text[row] = new_token_row - - return embedded_text - - def save(self, ckpt_path): - torch.save( - { - 'string_to_token': self.string_to_token_dict, - 'string_to_param': self.string_to_param_dict, - }, - ckpt_path, - ) - - def load(self, ckpt_paths, full=True): - if len(ckpt_paths) == 0: - return - if type(ckpt_paths) != list: - ckpt_paths = [ckpt_paths] - ckpt_paths = self._expand_directories(ckpt_paths) - for c in ckpt_paths: - self._load(c,full) - # remember that we know this term and don't try to download it again from the concepts library - # note that if the concept name is also provided and different from the trigger term, they - # both will be stored in this dictionary - for term in self.string_to_param_dict.keys(): - term = term.strip('<').strip('>') - self.concepts_loaded[term] = True - print(f'>> Current embedding manager terms: {", ".join(self.string_to_param_dict.keys())}') - - def _expand_directories(self, paths:list[str]): - expanded_paths = list() - for path in paths: - if os.path.isfile(path): - expanded_paths.append(path) - elif os.path.isdir(path): - for root, _, files in os.walk(path): - for name in files: - expanded_paths.append(os.path.join(root,name)) - return [x for x in expanded_paths if os.path.splitext(x)[1] in ('.pt','.bin')] - - def _load(self, ckpt_path, full=True): - try: - scan_result = scan_file_path(ckpt_path) - if scan_result.infected_files == 1: - print(f'\n### Security Issues Found in Model: {scan_result.issues_count}') - print('### For your safety, InvokeAI will not load this embed.') - return - except Exception: - print(f"### WARNING::: Invalid or corrupt embeddings found. Ignoring: {ckpt_path}") - return - - embedding_info = self.parse_embedding(ckpt_path) - if embedding_info: - self.max_vectors_per_token = embedding_info['num_vectors_per_token'] - self.add_embedding(embedding_info['name'], embedding_info['embedding'], full) - else: - print(f'>> Failed to load embedding located at {ckpt_path}. Unsupported file.') - - def add_embedding(self, token_str, embedding, full): - if token_str in self.string_to_param_dict: - print(f">> Embedding manager refusing to overwrite already-loaded term '{token_str}'") - return - if not full: - embedding = embedding.half() - if len(embedding.shape) == 1: - embedding = embedding.unsqueeze(0) - - existing_token_id = get_clip_token_id_for_string(self.embedder.tokenizer, token_str) - if existing_token_id == self.embedder.tokenizer.unk_token_id: - num_tokens_added = self.embedder.tokenizer.add_tokens(token_str) - current_embeddings = self.embedder.transformer.resize_token_embeddings(None) - current_token_count = current_embeddings.num_embeddings - new_token_count = current_token_count + num_tokens_added - self.embedder.transformer.resize_token_embeddings(new_token_count) - - token_id = get_clip_token_id_for_string(self.embedder.tokenizer, token_str) - self.string_to_token_dict[token_str] = token_id - self.string_to_param_dict[token_str] = torch.nn.Parameter(embedding) - - def parse_embedding(self, embedding_file: str): - file_type = embedding_file.split('.')[-1] - if file_type == 'pt': - return self.parse_embedding_pt(embedding_file) - elif file_type == 'bin': - return self.parse_embedding_bin(embedding_file) - else: - print(f'>> Not a recognized embedding file: {embedding_file}') - - def parse_embedding_pt(self, embedding_file): - embedding_ckpt = torch.load(embedding_file, map_location='cpu') - embedding_info = {} - - # Check if valid embedding file - if 'string_to_token' and 'string_to_param' in embedding_ckpt: - - # Catch variants that do not have the expected keys or values. - try: - embedding_info['name'] = embedding_ckpt['name'] or os.path.basename(os.path.splitext(embedding_file)[0]) - - # Check num of embeddings and warn user only the first will be used - embedding_info['num_of_embeddings'] = len(embedding_ckpt["string_to_token"]) - if embedding_info['num_of_embeddings'] > 1: - print('>> More than 1 embedding found. Will use the first one') - - embedding = list(embedding_ckpt['string_to_param'].values())[0] - except (AttributeError,KeyError): - return self.handle_broken_pt_variants(embedding_ckpt, embedding_file) - - embedding_info['embedding'] = embedding - embedding_info['num_vectors_per_token'] = embedding.size()[0] - embedding_info['token_dim'] = embedding.size()[1] - - try: - embedding_info['trained_steps'] = embedding_ckpt['step'] - embedding_info['trained_model_name'] = embedding_ckpt['sd_checkpoint_name'] - embedding_info['trained_model_checksum'] = embedding_ckpt['sd_checkpoint'] - except AttributeError: - print(">> No Training Details Found. Passing ...") - - # .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/ - # They are actually .bin files - elif len(embedding_ckpt.keys())==1: - print('>> Detected .bin file masquerading as .pt file') - embedding_info = self.parse_embedding_bin(embedding_file) - - else: - print('>> Invalid embedding format') - embedding_info = None - - return embedding_info - - def parse_embedding_bin(self, embedding_file): - embedding_ckpt = torch.load(embedding_file, map_location='cpu') - embedding_info = {} - - if list(embedding_ckpt.keys()) == 0: - print(">> Invalid concepts file") - embedding_info = None - else: - for token in list(embedding_ckpt.keys()): - embedding_info['name'] = token or os.path.basename(os.path.splitext(embedding_file)[0]) - embedding_info['embedding'] = embedding_ckpt[token] - embedding_info['num_vectors_per_token'] = 1 # All Concepts seem to default to 1 - embedding_info['token_dim'] = embedding_info['embedding'].size()[0] - - return embedding_info - - def handle_broken_pt_variants(self, embedding_ckpt:dict, embedding_file:str)->dict: - ''' - This handles the broken .pt file variants. We only know of one at present. - ''' - embedding_info = {} - if isinstance(list(embedding_ckpt['string_to_token'].values())[0],torch.Tensor): - print(f'>> Variant Embedding Detected. Parsing: {embedding_file}') # example at https://github.com/invoke-ai/InvokeAI/issues/1829 - token = list(embedding_ckpt['string_to_token'].keys())[0] - embedding_info['name'] = os.path.basename(os.path.splitext(embedding_file)[0]) - embedding_info['embedding'] = embedding_ckpt['string_to_param'].state_dict()[token] - embedding_info['num_vectors_per_token'] = embedding_info['embedding'].shape[0] - embedding_info['token_dim'] = embedding_info['embedding'].size()[0] - else: - print('>> Invalid embedding format') - embedding_info = None - - return embedding_info - - def has_embedding_for_token(self, token_str): - return token_str in self.string_to_token_dict - - def get_embedding_norms_squared(self): - all_params = torch.cat( - list(self.string_to_param_dict.values()), axis=0 - ) # num_placeholders x embedding_dim - param_norm_squared = (all_params * all_params).sum( - axis=-1 - ) # num_placeholders - - return param_norm_squared - - def embedding_parameters(self): - return self.string_to_param_dict.parameters() - - def embedding_to_coarse_loss(self): - - loss = 0.0 - num_embeddings = len(self.initial_embeddings) - - for key in self.initial_embeddings: - optimized = self.string_to_param_dict[key] - coarse = self.initial_embeddings[key].clone().to(optimized.device) - - loss = ( - loss - + (optimized - coarse) - @ (optimized - coarse).T - / num_embeddings - ) - - return loss