From abe4dc8ac11e8dd53121b1c390194dda785ec49d Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Mon, 27 Mar 2023 09:39:03 -0400
Subject: [PATCH] Add support for yet another textual inversion embedding
 format

- This PR adds support for embedding files that contain a single key
  "emb_params". The only example I know of this format is the
  "EasyNegative" embedding on HuggingFace, but there are certainly
  others.

- This PR also adds support for loading embedding files that have been
  saved in safetensors format.

- It also cleans up the code so that the logic of probing for and
  selecting the right format parser is clear.
---
 .../textual_inversion_manager.py              | 199 +++++++++---------
 1 file changed, 102 insertions(+), 97 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/textual_inversion_manager.py b/invokeai/backend/stable_diffusion/textual_inversion_manager.py
index 2b043afab7..9fa076693e 100644
--- a/invokeai/backend/stable_diffusion/textual_inversion_manager.py
+++ b/invokeai/backend/stable_diffusion/textual_inversion_manager.py
@@ -1,17 +1,17 @@
-import os
 import traceback
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional, Union
 
+import safetensors.torch
 import torch
+
 from compel.embeddings_provider import BaseTextualInversionManager
 from picklescan.scanner import scan_file_path
 from transformers import CLIPTextModel, CLIPTokenizer
 
 from .concepts_lib import HuggingFaceConceptsLibrary
 
-
 @dataclass
 class TextualInversion:
     trigger_string: str
@@ -72,20 +72,6 @@ class TextualInversionManager(BaseTextualInversionManager):
         if str(ckpt_path).endswith(".DS_Store"):
             return
 
-        try:
-            scan_result = scan_file_path(str(ckpt_path))
-            if scan_result.infected_files == 1:
-                print(
-                    f"\n### Security Issues Found in Model: {scan_result.issues_count}"
-                )
-                print("### For your safety, InvokeAI will not load this embed.")
-                return
-        except Exception:
-            print(
-                f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt."
-            )
-            return
-
         embedding_info = self._parse_embedding(str(ckpt_path))
 
         if embedding_info is None:
@@ -96,7 +82,7 @@ class TextualInversionManager(BaseTextualInversionManager):
             != embedding_info["token_dim"]
         ):
             print(
-                f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
+                f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
             )
             return
 
@@ -309,92 +295,72 @@ class TextualInversionManager(BaseTextualInversionManager):
 
         return token_id
 
-    def _parse_embedding(self, embedding_file: str):
-        file_type = embedding_file.split(".")[-1]
-        if file_type == "pt":
-            return self._parse_embedding_pt(embedding_file)
-        elif file_type == "bin":
-            return self._parse_embedding_bin(embedding_file)
-        else:
-            print(f"** Notice: unrecognized embedding file format: {embedding_file}")
+    def _parse_embedding(self, embedding_file: str) -> dict:
+        suffix = Path(embedding_file).suffix
+        try:
+            if suffix in [".pt", ".ckpt", ".bin"]:
+                scan_result = scan_file_path(embedding_file)
+                if scan_result.infected_files == 1:
+                    print(
+                        f" ** Security Issues Found in Model: {scan_result.issues_count}"
+                    )
+                    print(" ** For your safety, InvokeAI will not load this embed.")
+                    return
+                ckpt = torch.load(embedding_file, map_location="cpu")
+            else:
+                ckpt = safetensors.torch.load_file(embedding_file)
+        except Exception as e:
+            print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
             return None
-
-    def _parse_embedding_pt(self, embedding_file):
-        embedding_ckpt = torch.load(embedding_file, map_location="cpu")
-        embedding_info = {}
-
-        # Check if valid embedding file
-        if "string_to_token" and "string_to_param" in embedding_ckpt:
-            # Catch variants that do not have the expected keys or values.
-            try:
-                embedding_info["name"] = embedding_ckpt["name"] or os.path.basename(
-                    os.path.splitext(embedding_file)[0]
-                )
-
-                # Check num of embeddings and warn user only the first will be used
-                embedding_info["num_of_embeddings"] = len(
-                    embedding_ckpt["string_to_token"]
-                )
-                if embedding_info["num_of_embeddings"] > 1:
-                    print(">> More than 1 embedding found. Will use the first one")
-
-                embedding = list(embedding_ckpt["string_to_param"].values())[0]
-            except (AttributeError, KeyError):
-                return self._handle_broken_pt_variants(embedding_ckpt, embedding_file)
-
-        embedding_info["embedding"] = embedding
-        embedding_info["num_vectors_per_token"] = embedding.size()[0]
-        embedding_info["token_dim"] = embedding.size()[1]
-
-        try:
-            embedding_info["trained_steps"] = embedding_ckpt["step"]
-            embedding_info["trained_model_name"] = embedding_ckpt[
-                "sd_checkpoint_name"
-            ]
-            embedding_info["trained_model_checksum"] = embedding_ckpt[
-                "sd_checkpoint"
-            ]
-        except AttributeError:
-            print(">> No Training Details Found. Passing ...")
-
-        # .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
-        # They are actually .bin files
-        elif len(embedding_ckpt.keys()) == 1:
-            embedding_info = self._parse_embedding_bin(embedding_file)
-
+
+        # try to figure out what kind of embedding file it is and parse accordingly
+        keys = list(ckpt.keys())
+        if all(x in keys for x in ['string_to_token', 'string_to_param', 'name', 'step']):
+            return self._parse_embedding_v1(ckpt, embedding_file)  # example rem_rezero.pt
+
+        elif all(x in keys for x in ['string_to_token', 'string_to_param']):
+            return self._parse_embedding_v2(ckpt, embedding_file)  # example midj-strong.pt
+
+        elif 'emb_params' in keys:
+            return self._parse_embedding_v3(ckpt, embedding_file)  # example easynegative.safetensors
+
         else:
-            print(">> Invalid embedding format")
-            embedding_info = None
+            return self._parse_embedding_v4(ckpt, embedding_file)  # usually a '.bin' file
 
+    def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str):
+        basename = Path(file_path).stem
+        print(f' | Loading v1 embedding file: {basename}')
+
+        embedding_info = {}
+        embedding_info["name"] = embedding_ckpt["name"]
+
+        # Check num of embeddings and warn user only the first will be used
+        embedding_info["num_of_embeddings"] = len(
+            embedding_ckpt["string_to_token"]
+        )
+        if embedding_info["num_of_embeddings"] > 1:
+            print(" | More than 1 embedding found. Will use the first one")
+        embedding = list(embedding_ckpt["string_to_param"].values())[0]
+        embedding_info["embedding"] = embedding
+        embedding_info["num_vectors_per_token"] = embedding.size()[0]
+        embedding_info["token_dim"] = embedding.size()[1]
+        embedding_info["trained_steps"] = embedding_ckpt["step"]
+        embedding_info["trained_model_name"] = embedding_ckpt[
+            "sd_checkpoint_name"
+        ]
+        embedding_info["trained_model_checksum"] = embedding_ckpt[
+            "sd_checkpoint"
+        ]
         return embedding_info
 
-    def _parse_embedding_bin(self, embedding_file):
-        embedding_ckpt = torch.load(embedding_file, map_location="cpu")
-        embedding_info = {}
-
-        if list(embedding_ckpt.keys()) == 0:
-            print(">> Invalid concepts file")
-            embedding_info = None
-        else:
-            for token in list(embedding_ckpt.keys()):
-                embedding_info["name"] = (
-                    token
-                    or f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
-                )
-                embedding_info["embedding"] = embedding_ckpt[token]
-                embedding_info[
-                    "num_vectors_per_token"
-                ] = 1  # All Concepts seem to default to 1
-                embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
-
-        return embedding_info
-
-    def _handle_broken_pt_variants(
-        self, embedding_ckpt: dict, embedding_file: str
+    def _parse_embedding_v2(
+        self, embedding_ckpt: dict, file_path: str
     ) -> dict:
         """
-        This handles the broken .pt file variants. We only know of one at present.
+        This handles embedding .pt file variant #2.
""" + basename = Path(file_path).stem + print(f' | Loading v2 embedding file: {basename}') embedding_info = {} if isinstance( list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor @@ -403,7 +369,7 @@ class TextualInversionManager(BaseTextualInversionManager): embedding_info["name"] = ( token if token != "*" - else f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>" + else f"<{basename}>" ) embedding_info["embedding"] = embedding_ckpt[ "string_to_param" @@ -413,7 +379,46 @@ class TextualInversionManager(BaseTextualInversionManager): ].shape[0] embedding_info["token_dim"] = embedding_info["embedding"].size()[1] else: - print(">> Invalid embedding format") + print(f" ** {basename}: Unrecognized embedding format") embedding_info = None return embedding_info + + def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str): + """ + Parse 'version 3' of the .pt textual inversion embedding files. + """ + basename = Path(file_path).stem + print(f' | Loading v3 embedding file: {basename}') + embedding_info = {} + embedding_info["name"] = f'<{basename}>' + embedding_info["num_of_embeddings"] = 1 + embedding = embedding_ckpt['emb_params'] + embedding_info["embedding"] = embedding + embedding_info["num_vectors_per_token"] = embedding.size()[0] + embedding_info["token_dim"] = embedding.size()[1] + return embedding_info + + def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str): + """ + Parse 'version 4' of the textual inversion embedding files. This one + is usually associated with .bin files trained by HuggingFace diffusers. + """ + basename = Path(filepath).stem + short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name + + print(f' | Loading v4 embedding file: {short_path}') + embedding_info = {} + if list(embedding_ckpt.keys()) == 0: + print(f" ** Invalid embeddings file: {short_path}") + embedding_info = None + else: + for token in list(embedding_ckpt.keys()): + embedding_info["name"] = ( + token + or f"<{basename}>" + ) + embedding_info["embedding"] = embedding_ckpt[token] + embedding_info["num_vectors_per_token"] = 1 # All Concepts seem to default to 1 + embedding_info["token_dim"] = embedding_info["embedding"].size()[0] + return embedding_info