diff --git a/ldm/modules/embedding_manager.py b/ldm/modules/embedding_manager.py
index 239fd346ab..af9383bbd6 100644
--- a/ldm/modules/embedding_manager.py
+++ b/ldm/modules/embedding_manager.py
@@ -256,32 +256,22 @@ class EmbeddingManager(nn.Module):
         return [x for x in expanded_paths if os.path.splitext(x)[1] in ('.pt','.bin')]
 
     def _load(self, ckpt_path, full=True):
-
-        scan_result = scan_file_path(ckpt_path)
-        if scan_result.infected_files == 1:
-            print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
-            print('### For your safety, InvokeAI will not load this embed.')
+        try:
+            scan_result = scan_file_path(ckpt_path)
+            if scan_result.infected_files == 1:
+                print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
+                print('### For your safety, InvokeAI will not load this embed.')
+                return
+        except Exception:
+            print(f'### WARNING: Invalid or corrupt embedding file. Ignoring: {ckpt_path}')
             return
-
-        ckpt = torch.load(ckpt_path, map_location='cpu')
 
-        # Handle .pt textual inversion files
-        if 'string_to_token' in ckpt and 'string_to_param' in ckpt:
-            filename = os.path.basename(ckpt_path)
-            token_str = '.'.join(filename.split('.')[:-1]) # filename excluding extension
-            if len(ckpt["string_to_token"]) > 1:
-                print(f">> {ckpt_path} has >1 embedding, only the first will be used")
-
-            string_to_param_dict = ckpt['string_to_param']
-            embedding = list(string_to_param_dict.values())[0]
-            self.add_embedding(token_str, embedding, full)
-
-        # Handle .bin textual inversion files from Huggingface Concepts
-        # https://huggingface.co/sd-concepts-library
+        embedding_info = self.parse_embedding(ckpt_path)
+        if embedding_info:
+            self.max_vectors_per_token = embedding_info['num_vectors_per_token']
+            self.add_embedding(embedding_info['name'], embedding_info['embedding'], full)
         else:
-            for token_str in list(ckpt.keys()):
-                embedding = ckpt[token_str]
-                self.add_embedding(token_str, embedding, full)
+            print(f'>> Failed to load embedding located at {ckpt_path}. Unsupported file.')
 
     def add_embedding(self, token_str, embedding, full):
         if token_str in self.string_to_param_dict:
@@ -302,6 +292,92 @@ class EmbeddingManager(nn.Module):
         self.string_to_token_dict[token_str] = token
         self.string_to_param_dict[token_str] = torch.nn.Parameter(embedding)
 
+    def parse_embedding(self, embedding_file: str):
+        file_type = embedding_file.split('.')[-1]
+        if file_type == 'pt':
+            return self.parse_embedding_pt(embedding_file)
+        elif file_type == 'bin':
+            return self.parse_embedding_bin(embedding_file)
+        else:
+            print(f'>> Not a recognized embedding file: {embedding_file}')
+
+    def parse_embedding_pt(self, embedding_file):
+        embedding_ckpt = torch.load(embedding_file, map_location='cpu')
+        embedding_info = {}
+
+        # Check whether this is a valid embedding file
+        if 'string_to_token' in embedding_ckpt and 'string_to_param' in embedding_ckpt:
+
+            # Catch variants that do not have the expected keys or values.
+            try:
+                embedding_info['name'] = embedding_ckpt['name'] or os.path.basename(os.path.splitext(embedding_file)[0])
+
+                # Check the number of embeddings and warn the user that only the first will be used
+                embedding_info['num_of_embeddings'] = len(embedding_ckpt['string_to_token'])
+                if embedding_info['num_of_embeddings'] > 1:
+                    print('>> More than 1 embedding found. Will use the first one')
+
+                embedding = list(embedding_ckpt['string_to_param'].values())[0]
+            except (AttributeError, KeyError):
+                return self.handle_broken_pt_variants(embedding_ckpt, embedding_file)
+
+            embedding_info['embedding'] = embedding
+            embedding_info['num_vectors_per_token'] = embedding.size()[0]
+            embedding_info['token_dim'] = embedding.size()[1]
+
+            try:
+                embedding_info['trained_steps'] = embedding_ckpt['step']
+                embedding_info['trained_model_name'] = embedding_ckpt['sd_checkpoint_name']
+                embedding_info['trained_model_checksum'] = embedding_ckpt['sd_checkpoint']
+            except (AttributeError, KeyError):
+                print('>> No training details found. Passing ...')
+
+        # .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
+        # are actually .bin files
+        elif len(embedding_ckpt.keys()) == 1:
+            print('>> Detected .bin file masquerading as .pt file')
+            embedding_info = self.parse_embedding_bin(embedding_file)
+
+        else:
+            print('>> Invalid embedding format')
+            embedding_info = None
+
+        return embedding_info
+
+    def parse_embedding_bin(self, embedding_file):
+        embedding_ckpt = torch.load(embedding_file, map_location='cpu')
+        embedding_info = {}
+
+        if len(embedding_ckpt.keys()) == 0:
+            print('>> Invalid concepts file')
+            embedding_info = None
+        else:
+            for token in list(embedding_ckpt.keys()):
+                embedding_info['name'] = token or os.path.basename(os.path.splitext(embedding_file)[0])
+                embedding_info['embedding'] = embedding_ckpt[token]
+                embedding_info['num_vectors_per_token'] = 1  # All Concepts seem to default to 1
+                embedding_info['token_dim'] = embedding_info['embedding'].size()[0]
+
+        return embedding_info
+
+    def handle_broken_pt_variants(self, embedding_ckpt: dict, embedding_file: str) -> dict:
+        '''
+        This handles the broken .pt file variants. We only know of one at present.
+        '''
+        embedding_info = {}
+        if isinstance(list(embedding_ckpt['string_to_token'].values())[0], torch.Tensor):
+            print('>> Detected .pt file variant 1')  # example at https://github.com/invoke-ai/InvokeAI/issues/1829
+            for token in list(embedding_ckpt['string_to_token'].keys()):
+                embedding_info['name'] = token if token != '*' else os.path.basename(os.path.splitext(embedding_file)[0])
+                embedding_info['embedding'] = embedding_ckpt['string_to_param'].state_dict()[token]
+                embedding_info['num_vectors_per_token'] = embedding_info['embedding'].shape[0]
+                embedding_info['token_dim'] = embedding_info['embedding'].size()[1]
+        else:
+            print('>> Invalid embedding format')
+            embedding_info = None
+
+        return embedding_info
+
     def has_embedding_for_token(self, token_str):
         return token_str in self.string_to_token_dict
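
For reviewers who want to sanity-check an embedding file against the parsing rules introduced above, here is a minimal standalone sketch. It is not part of the diff: `inspect_embedding` and the sample filename are hypothetical, and it only assumes the checkpoint layouts the diff itself relies on (a .pt file carrying 'string_to_token'/'string_to_param', or a single-key Huggingface Concepts .bin).

```python
# Hypothetical helper mirroring parse_embedding_pt / parse_embedding_bin, so an
# embedding file can be inspected without constructing an EmbeddingManager.
import os
import torch

def inspect_embedding(path: str) -> dict:
    ckpt = torch.load(path, map_location='cpu')
    info = {}
    if 'string_to_token' in ckpt and 'string_to_param' in ckpt:
        # Standard .pt textual inversion: take the first (usually only) tensor
        embedding = list(ckpt['string_to_param'].values())[0]
        info['name'] = ckpt.get('name') or os.path.basename(os.path.splitext(path)[0])
        info['num_vectors_per_token'] = embedding.size()[0]
        info['token_dim'] = embedding.size()[1]
    elif len(ckpt) == 1:
        # Huggingface Concepts .bin style: a single {token: tensor} pair
        token, embedding = next(iter(ckpt.items()))
        info['name'] = token
        info['num_vectors_per_token'] = 1
        info['token_dim'] = embedding.size()[0]
    return info

# Example (assuming 'my-embedding.pt' exists):
#   print(inspect_embedding('my-embedding.pt'))
```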