From 0d22fd59edf546c2a79c3e8d6844d2b4f72565a2 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 20 Feb 2023 14:34:14 -0500 Subject: [PATCH 1/2] restore ability of textual inversion manager to read .pt files - Fixes longstanding bug in the token vector size code which caused .pt files to be assigned the wrong token vector length. These were then tossed out during directory scanning. --- ldm/modules/textual_inversion_manager.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/ldm/modules/textual_inversion_manager.py b/ldm/modules/textual_inversion_manager.py index 19fa5f6b0f..3a29aa5702 100644 --- a/ldm/modules/textual_inversion_manager.py +++ b/ldm/modules/textual_inversion_manager.py @@ -61,9 +61,15 @@ class TextualInversionManager: def load_textual_inversion(self, ckpt_path: Union[str,Path], defer_injecting_tokens: bool = False): ckpt_path = Path(ckpt_path) + + if not ckpt_path.is_file(): + return + if str(ckpt_path).endswith(".DS_Store"): return + try: + print(f'>> Scanning {str(ckpt_path)} for embedding terms') scan_result = scan_file_path(str(ckpt_path)) if scan_result.infected_files == 1: print( @@ -87,7 +93,7 @@ class TextualInversionManager: != embedding_info['token_dim'] ): print( - f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with a different token dimension. It can't be used with this model." + f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}." ) return @@ -333,7 +339,6 @@ class TextualInversionManager: # .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/ # They are actually .bin files elif len(embedding_ckpt.keys()) == 1: - print(">> Detected .bin file masquerading as .pt file") embedding_info = self._parse_embedding_bin(embedding_file) else: @@ -372,9 +377,6 @@ class TextualInversionManager: if isinstance( list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor ): - print( - ">> Detected .pt file variant 1" - ) # example at https://github.com/invoke-ai/InvokeAI/issues/1829 for token in list(embedding_ckpt["string_to_token"].keys()): embedding_info["name"] = ( token @@ -387,7 +389,7 @@ class TextualInversionManager: embedding_info["num_vectors_per_token"] = embedding_info[ "embedding" ].shape[0] - embedding_info["token_dim"] = embedding_info["embedding"].size()[0] + embedding_info["token_dim"] = embedding_info["embedding"].size()[1] else: print(">> Invalid embedding format") embedding_info = None From 1d9845557f90646265e4cea4eaa89caed7429b18 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 20 Feb 2023 15:18:55 -0500 Subject: [PATCH 2/2] reduced verbosity of embed loading messages --- ldm/generate.py | 3 ++- ldm/modules/textual_inversion_manager.py | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ldm/generate.py b/ldm/generate.py index d7ea87a5fd..76629ff2f9 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -964,6 +964,7 @@ class Generate: seed_everything(random.randrange(0, np.iinfo(np.uint32).max)) if self.embedding_path is not None: + print(f'>> Loading embeddings from {self.embedding_path}') for root, _, files in os.walk(self.embedding_path): for name in files: ti_path = os.path.join(root, name) @@ -971,7 +972,7 @@ class Generate: ti_path, defer_injecting_tokens=True ) print( - f'>> Textual inversions available: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}' + f'>> Textual inversion triggers: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}' ) self.model_name = model_name diff --git a/ldm/modules/textual_inversion_manager.py b/ldm/modules/textual_inversion_manager.py index 3a29aa5702..bf0d3ed8b9 100644 --- a/ldm/modules/textual_inversion_manager.py +++ b/ldm/modules/textual_inversion_manager.py @@ -69,7 +69,6 @@ class TextualInversionManager: return try: - print(f'>> Scanning {str(ckpt_path)} for embedding terms') scan_result = scan_file_path(str(ckpt_path)) if scan_result.infected_files == 1: print(