diff --git a/invokeai/backend/stable_diffusion/textual_inversion_manager.py b/invokeai/backend/stable_diffusion/textual_inversion_manager.py
index 2b043afab7..2dba2b88d3 100644
--- a/invokeai/backend/stable_diffusion/textual_inversion_manager.py
+++ b/invokeai/backend/stable_diffusion/textual_inversion_manager.py
@@ -1,16 +1,26 @@
-import os
 import traceback
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional, Union
+from typing import Optional, Union, List
 
+import safetensors.torch
 import torch
+
 from compel.embeddings_provider import BaseTextualInversionManager
 from picklescan.scanner import scan_file_path
 from transformers import CLIPTextModel, CLIPTokenizer
 
 from .concepts_lib import HuggingFaceConceptsLibrary
 
+@dataclass
+class EmbeddingInfo:
+    name: str
+    embedding: torch.Tensor
+    num_vectors_per_token: int
+    token_dim: int
+    trained_steps: Optional[int] = None
+    trained_model_name: Optional[str] = None
+    trained_model_checksum: Optional[str] = None
 
 @dataclass
 class TextualInversion:
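Review note: this hunk introduces the typed EmbeddingInfo dataclass that replaces the loose dict the old parsers returned; every parser below now returns a List[EmbeddingInfo], so a single file can carry several embeddings. A minimal sketch of what a consumer sees, assuming the class is importable from the file this patch touches (the trigger string and tensor shape are illustrative; 768 is the SD-1.x token dimension):

    import torch
    from invokeai.backend.stable_diffusion.textual_inversion_manager import EmbeddingInfo

    info = EmbeddingInfo(
        name="<my-trigger>",            # hypothetical trigger string
        embedding=torch.zeros(2, 768),  # two vectors, SD-1.x token_dim
        num_vectors_per_token=2,
        token_dim=768,
    )
    # the invariant the loader relies on:
    assert info.embedding.shape == (info.num_vectors_per_token, info.token_dim)
    assert info.trained_steps is None   # training metadata is optional
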
@@ -72,66 +82,46 @@ class TextualInversionManager(BaseTextualInversionManager):
         if str(ckpt_path).endswith(".DS_Store"):
             return
 
-        try:
-            scan_result = scan_file_path(str(ckpt_path))
-            if scan_result.infected_files == 1:
+        embedding_list = self._parse_embedding(str(ckpt_path))
+        for embedding_info in embedding_list:
+            if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
                 print(
-                    f"\n### Security Issues Found in Model: {scan_result.issues_count}"
+                    f"   ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
                 )
-                print("### For your safety, InvokeAI will not load this embed.")
-                return
-        except Exception:
-            print(
-                f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt."
-            )
-            return
+                continue
 
-        embedding_info = self._parse_embedding(str(ckpt_path))
-
-        if embedding_info is None:
-            # We've already put out an error message about the bad embedding in _parse_embedding, so just return.
-            return
-        elif (
-            self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
-            != embedding_info["token_dim"]
-        ):
-            print(
-                f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
-            )
-            return
-
-        # Resolve the situation in which an earlier embedding has claimed the same
-        # trigger string. We replace the trigger with '<filename>', as we used to.
-        trigger_str = embedding_info["name"]
-        sourcefile = (
-            f"{ckpt_path.parent.name}/{ckpt_path.name}"
-            if ckpt_path.name == "learned_embeds.bin"
-            else ckpt_path.name
-        )
-
-        if trigger_str in self.trigger_to_sourcefile:
-            replacement_trigger_str = (
-                f"<{ckpt_path.parent.name}>"
-                if ckpt_path.name == "learned_embeds.bin"
-                else f"<{ckpt_path.stem}>"
-            )
-            print(
-                f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
-            )
-            trigger_str = replacement_trigger_str
-
-        try:
-            self._add_textual_inversion(
-                trigger_str,
-                embedding_info["embedding"],
-                defer_injecting_tokens=defer_injecting_tokens,
-            )
-            # remember which source file claims this trigger
-            self.trigger_to_sourcefile[trigger_str] = sourcefile
-
-        except ValueError as e:
-            print(f'   | Ignoring incompatible embedding {embedding_info["name"]}')
-            print(f"   | The error was {str(e)}")
+            # Resolve the situation in which an earlier embedding has claimed the same
+            # trigger string. We replace the trigger with '<filename>', as we used to.
+            trigger_str = embedding_info.name
+            sourcefile = (
+                f"{ckpt_path.parent.name}/{ckpt_path.name}"
+                if ckpt_path.name == "learned_embeds.bin"
+                else ckpt_path.name
+            )
+
+            if trigger_str in self.trigger_to_sourcefile:
+                replacement_trigger_str = (
+                    f"<{ckpt_path.parent.name}>"
+                    if ckpt_path.name == "learned_embeds.bin"
+                    else f"<{ckpt_path.stem}>"
+                )
+                print(
+                    f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
+                )
+                trigger_str = replacement_trigger_str
 
+            try:
+                self._add_textual_inversion(
+                    trigger_str,
+                    embedding_info.embedding,
+                    defer_injecting_tokens=defer_injecting_tokens,
+                )
+                # remember which source file claims this trigger
+                self.trigger_to_sourcefile[trigger_str] = sourcefile
+
+            except ValueError as e:
+                print(f'   | Ignoring incompatible embedding {embedding_info.name}')
+                print(f"   | The error was {str(e)}")
 
     def _add_textual_inversion(
         self, trigger_str, embedding, defer_injecting_tokens=False
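Review note: the security scan moves into _parse_embedding, and an incompatible token dimension now triggers a continue instead of a return, so one bad embedding no longer blocks the rest of a multi-embedding file. The gate itself is just a width comparison; a sketch, using the same transformers classes the module already imports (the `manager` instance and the file path are hypothetical; SD-2.x encoders would report 1024 here and be skipped):

    from transformers import CLIPTextModel

    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    token_dim = text_encoder.get_input_embeddings().weight.data[0].shape[0]  # 768 for SD-1.x

    embedding_list = manager._parse_embedding("embeddings/easynegative.safetensors")
    compatible = [e for e in embedding_list if e.token_dim == token_dim]
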
@@ -309,111 +299,130 @@ class TextualInversionManager(BaseTextualInversionManager):
 
         return token_id
 
-    def _parse_embedding(self, embedding_file: str):
-        file_type = embedding_file.split(".")[-1]
-        if file_type == "pt":
-            return self._parse_embedding_pt(embedding_file)
-        elif file_type == "bin":
-            return self._parse_embedding_bin(embedding_file)
+    def _parse_embedding(self, embedding_file: str) -> List[EmbeddingInfo]:
+        suffix = Path(embedding_file).suffix
+        try:
+            if suffix in [".pt", ".ckpt", ".bin"]:
+                scan_result = scan_file_path(embedding_file)
+                if scan_result.infected_files > 0:
+                    print(
+                        f"   ** Security Issues Found in Model: {scan_result.issues_count}"
+                    )
+                    print("   ** For your safety, InvokeAI will not load this embed.")
+                    return list()
+                ckpt = torch.load(embedding_file, map_location="cpu")
+            else:
+                ckpt = safetensors.torch.load_file(embedding_file)
+        except Exception as e:
+            print(f"   ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
+            return list()
+
+        # try to figure out what kind of embedding file it is and parse accordingly
+        keys = list(ckpt.keys())
+        if all(x in keys for x in ['string_to_token', 'string_to_param', 'name', 'step']):
+            return self._parse_embedding_v1(ckpt, embedding_file)  # example rem_rezero.pt
+        elif all(x in keys for x in ['string_to_token', 'string_to_param']):
+            return self._parse_embedding_v2(ckpt, embedding_file)  # example midj-strong.pt
+        elif 'emb_params' in keys:
+            return self._parse_embedding_v3(ckpt, embedding_file)  # example easynegative.safetensors
         else:
-            print(f"** Notice: unrecognized embedding file format: {embedding_file}")
-            return None
+            return self._parse_embedding_v4(ckpt, embedding_file)  # usually a '.bin' file
 
-    def _parse_embedding_pt(self, embedding_file):
-        embedding_ckpt = torch.load(embedding_file, map_location="cpu")
-        embedding_info = {}
+    def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str) -> List[EmbeddingInfo]:
+        basename = Path(file_path).stem
+        print(f'   | Loading v1 embedding file: {basename}')
 
-        # Check if valid embedding file
-        if "string_to_token" and "string_to_param" in embedding_ckpt:
-            # Catch variants that do not have the expected keys or values.
-            try:
-                embedding_info["name"] = embedding_ckpt["name"] or os.path.basename(
-                    os.path.splitext(embedding_file)[0]
-                )
-
-                # Check num of embeddings and warn user only the first will be used
-                embedding_info["num_of_embeddings"] = len(
-                    embedding_ckpt["string_to_token"]
-                )
-                if embedding_info["num_of_embeddings"] > 1:
-                    print(">> More than 1 embedding found. Will use the first one")
-
-                embedding = list(embedding_ckpt["string_to_param"].values())[0]
-            except (AttributeError, KeyError):
-                return self._handle_broken_pt_variants(embedding_ckpt, embedding_file)
-
-            embedding_info["embedding"] = embedding
-            embedding_info["num_vectors_per_token"] = embedding.size()[0]
-            embedding_info["token_dim"] = embedding.size()[1]
-
-            try:
-                embedding_info["trained_steps"] = embedding_ckpt["step"]
-                embedding_info["trained_model_name"] = embedding_ckpt[
-                    "sd_checkpoint_name"
-                ]
-                embedding_info["trained_model_checksum"] = embedding_ckpt[
-                    "sd_checkpoint"
-                ]
-            except AttributeError:
-                print(">> No Training Details Found. Passing ...")
-
-        # .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
-        # They are actually .bin files
-        elif len(embedding_ckpt.keys()) == 1:
-            embedding_info = self._parse_embedding_bin(embedding_file)
-
-        else:
-            print(">> Invalid embedding format")
-            embedding_info = None
-
-        return embedding_info
-
-    def _parse_embedding_bin(self, embedding_file):
-        embedding_ckpt = torch.load(embedding_file, map_location="cpu")
-        embedding_info = {}
-
-        if list(embedding_ckpt.keys()) == 0:
-            print(">> Invalid concepts file")
-            embedding_info = None
-        else:
-            for token in list(embedding_ckpt.keys()):
-                embedding_info["name"] = (
-                    token
-                    or f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
-                )
-                embedding_info["embedding"] = embedding_ckpt[token]
-                embedding_info[
-                    "num_vectors_per_token"
-                ] = 1  # All Concepts seem to default to 1
-                embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
-
-        return embedding_info
-
-    def _handle_broken_pt_variants(
-        self, embedding_ckpt: dict, embedding_file: str
-    ) -> dict:
+        embeddings = list()
+        token_counter = -1
+        for token, embedding in embedding_ckpt["string_to_param"].items():
+            if token_counter < 0:
+                trigger = embedding_ckpt["name"]
+            elif token_counter == 0:
+                trigger = f'<{basename}>'
+            else:
+                trigger = f'<{basename}-{token_counter}>'
+            token_counter += 1
+            embedding_info = EmbeddingInfo(
+                name = trigger,
+                embedding = embedding,
+                num_vectors_per_token = embedding.size()[0],
+                token_dim = embedding.size()[1],
+                trained_steps = embedding_ckpt["step"],
+                trained_model_name = embedding_ckpt.get("sd_checkpoint_name"),  # not always present
+                trained_model_checksum = embedding_ckpt.get("sd_checkpoint"),
+            )
+            embeddings.append(embedding_info)
+        return embeddings
+
+    def _parse_embedding_v2(
+        self, embedding_ckpt: dict, file_path: str
+    ) -> List[EmbeddingInfo]:
         """
-        This handles the broken .pt file variants. We only know of one at present.
+        This handles embedding .pt file variant #2.
""" - embedding_info = {} + basename = Path(file_path).stem + print(f' | Loading v2 embedding file: {basename}') + embeddings = list() + if isinstance( list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor ): - for token in list(embedding_ckpt["string_to_token"].keys()): - embedding_info["name"] = ( - token - if token != "*" - else f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>" + token_counter = 0 + for token,embedding in embedding_ckpt["string_to_param"].items(): + trigger = token if token != '*' \ + else f'<{basename}>' if token_counter == 0 \ + else f'<{basename}-{int(token_counter:=token_counter+1)}>' + embedding_info = EmbeddingInfo( + name = trigger, + embedding = embedding, + num_vectors_per_token = embedding.size()[0], + token_dim = embedding.size()[1], ) - embedding_info["embedding"] = embedding_ckpt[ - "string_to_param" - ].state_dict()[token] - embedding_info["num_vectors_per_token"] = embedding_info[ - "embedding" - ].shape[0] - embedding_info["token_dim"] = embedding_info["embedding"].size()[1] + embeddings.append(embedding_info) else: - print(">> Invalid embedding format") - embedding_info = None + print(f" ** {basename}: Unrecognized embedding format") - return embedding_info + return embeddings + + def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]: + """ + Parse 'version 3' of the .pt textual inversion embedding files. + """ + basename = Path(file_path).stem + print(f' | Loading v3 embedding file: {basename}') + embedding = embedding_ckpt['emb_params'] + embedding_info = EmbeddingInfo( + name = f'<{basename}>', + embedding = embedding, + num_vectors_per_token = embedding.size()[0], + token_dim = embedding.size()[1], + ) + return [embedding_info] + + def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str)->List[EmbeddingInfo]: + """ + Parse 'version 4' of the textual inversion embedding files. This one + is usually associated with .bin files trained by HuggingFace diffusers. + """ + basename = Path(filepath).stem + short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name + + print(f' | Loading v4 embedding file: {short_path}') + + embeddings = list() + if list(embedding_ckpt.keys()) == 0: + print(f" ** Invalid embeddings file: {short_path}") + else: + for token,embedding in embedding_ckpt.items(): + embedding_info = EmbeddingInfo( + name = token or f"<{basename}>", + embedding = embedding, + num_vectors_per_token = 1, # All Concepts seem to default to 1 + token_dim = embedding.size()[0], + ) + embeddings.append(embedding_info) + return embeddings