diff --git a/ldm/modules/textual_inversion_manager.py b/ldm/modules/textual_inversion_manager.py
index 19fa5f6b0f..3a29aa5702 100644
--- a/ldm/modules/textual_inversion_manager.py
+++ b/ldm/modules/textual_inversion_manager.py
@@ -61,9 +61,15 @@ class TextualInversionManager:
 
     def load_textual_inversion(self, ckpt_path: Union[str,Path], defer_injecting_tokens: bool = False):
         ckpt_path = Path(ckpt_path)
+
+        if not ckpt_path.is_file():
+            return
+
         if str(ckpt_path).endswith(".DS_Store"):
             return
+
         try:
+            print(f'>> Scanning {str(ckpt_path)} for embedding terms')
             scan_result = scan_file_path(str(ckpt_path))
             if scan_result.infected_files == 1:
                 print(
@@ -87,7 +93,7 @@ class TextualInversionManager:
                 != embedding_info['token_dim']
             ):
                 print(
-                    f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with a different token dimension. It can't be used with this model."
+                    f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
                 )
                 return
 
@@ -333,7 +339,6 @@ class TextualInversionManager:
         # .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
         # They are actually .bin files
         elif len(embedding_ckpt.keys()) == 1:
-            print(">> Detected .bin file masquerading as .pt file")
             embedding_info = self._parse_embedding_bin(embedding_file)
 
         else:
@@ -372,9 +377,6 @@ class TextualInversionManager:
             if isinstance(
                 list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
             ):
-                print(
-                    ">> Detected .pt file variant 1"
-                )  # example at https://github.com/invoke-ai/InvokeAI/issues/1829
                 for token in list(embedding_ckpt["string_to_token"].keys()):
                     embedding_info["name"] = (
                         token
@@ -387,7 +389,7 @@ class TextualInversionManager:
                 embedding_info["num_vectors_per_token"] = embedding_info[
                     "embedding"
                 ].shape[0]
-                embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
+                embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
             else:
                 print(">> Invalid embedding format")
                 embedding_info = None
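
Note on the token_dim fix (illustration only, not part of the patch): a variant-1 .pt embedding stores its tensor as [num_vectors_per_token, token_dim], so index 0 is the vector count and index 1 is the token dimension; the old .size()[0] therefore compared the vector count against the text encoder's embedding width and spuriously rejected multi-vector embeddings. A minimal sketch of the shape logic, using a hypothetical 768-wide encoder and a two-vector embedding:

    import torch

    # Hypothetical two-vector embedding trained against a 768-dim text encoder.
    embedding = torch.zeros(2, 768)  # [num_vectors_per_token, token_dim]

    num_vectors_per_token = embedding.shape[0]  # 2   (unchanged by the patch)
    token_dim = embedding.size()[1]             # 768 (was .size()[0], which gave 2)

    # load_textual_inversion compares token_dim against the width of one row of
    # the model's input-embedding matrix:
    encoder_width = 768  # stands in for text_encoder.get_input_embeddings().weight.data[0].shape[0]
    assert token_dim == encoder_width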