restore ability of textual inversion manager to read .pt files (#2746)

- Fixes a longstanding bug in the token-vector-size code that caused .pt
files to be assigned the wrong token vector length; the mis-sized files
were then discarded during directory scanning.
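In outline, the loader parses each checkpoint into an embedding_info dict, rejects it when its token_dim does not match the text encoder's embedding width, and only then injects the trigger tokens. The bug sat in the parse step for one .pt variant: token_dim was read off the wrong tensor axis, so the dimension check (second file below) rejected the file. A sketch of the dict the parser is expected to produce, with field names taken from the diff and values fabricated for illustration:

    import torch

    embedding_info = {
        "name": "<my-trigger>",            # hypothetical trigger token
        "embedding": torch.zeros(2, 768),  # shape: (num_vectors_per_token, token_dim)
        "num_vectors_per_token": 2,
        "token_dim": 768,                  # must equal the text encoder's embedding width
    }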
Lincoln Stein 2023-02-20 15:34:56 -05:00 committed by GitHub
commit 17294bfa55
2 changed files with 9 additions and 7 deletions


@@ -964,6 +964,7 @@ class Generate:
         seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
         if self.embedding_path is not None:
+            print(f'>> Loading embeddings from {self.embedding_path}')
             for root, _, files in os.walk(self.embedding_path):
                 for name in files:
                     ti_path = os.path.join(root, name)
@@ -971,7 +972,7 @@ class Generate:
                         ti_path, defer_injecting_tokens=True
                     )
             print(
-                f'>> Textual inversions available: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}'
+                f'>> Textual inversion triggers: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}'
             )
         self.model_name = model_name
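The change above adds a progress message and reports the collected triggers once the whole embeddings tree has been walked. The same pattern as a standalone sketch (the manager argument stands in for self.model.textual_inversion_manager; both methods it calls appear in the diff):

    import os

    def load_embeddings(manager, embedding_path: str) -> None:
        # Feed every file under embedding_path to the manager; token
        # injection is deferred until the model actually needs it.
        for root, _, files in os.walk(embedding_path):
            for name in files:
                ti_path = os.path.join(root, name)
                manager.load_textual_inversion(ti_path, defer_injecting_tokens=True)
        print(f'>> Textual inversion triggers: {", ".join(manager.get_all_trigger_strings())}')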


@@ -61,8 +61,13 @@ class TextualInversionManager:
     def load_textual_inversion(self, ckpt_path: Union[str,Path], defer_injecting_tokens: bool = False):
         ckpt_path = Path(ckpt_path)
+
+        if not ckpt_path.is_file():
+            return
+
         if str(ckpt_path).endswith(".DS_Store"):
             return
+
         try:
             scan_result = scan_file_path(str(ckpt_path))
             if scan_result.infected_files == 1:
@@ -87,7 +92,7 @@ class TextualInversionManager:
                 != embedding_info['token_dim']
             ):
                 print(
-                    f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with a different token dimension. It can't be used with this model."
+                    f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
                 )
                 return
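The notice above fires when the checkpoint's token_dim disagrees with the embedding width of the loaded text encoder (768 for the SD 1.x CLIP encoder, 1024 for SD 2.x). A minimal sketch of that comparison, assuming a transformers-style text encoder; is_compatible is a hypothetical helper:

    def is_compatible(text_encoder, embedding_info: dict) -> bool:
        # Width of a single token-embedding row in the loaded text encoder.
        model_dim = text_encoder.get_input_embeddings().weight.data[0].shape[0]
        # token_dim as parsed from the .pt/.bin file; the pre-fix parser
        # returned the vector count here instead, so this never matched.
        return model_dim == embedding_info["token_dim"]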
@@ -333,7 +338,6 @@ class TextualInversionManager:
         # .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
         # They are actually .bin files
         elif len(embedding_ckpt.keys()) == 1:
-            print(">> Detected .bin file masquerading as .pt file")
             embedding_info = self._parse_embedding_bin(embedding_file)
         else:
@@ -372,9 +376,6 @@ class TextualInversionManager:
         if isinstance(
             list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
         ):
-            print(
-                ">> Detected .pt file variant 1"
-            )  # example at https://github.com/invoke-ai/InvokeAI/issues/1829
             for token in list(embedding_ckpt["string_to_token"].keys()):
                 embedding_info["name"] = (
                     token
@@ -387,7 +388,7 @@ class TextualInversionManager:
             embedding_info["num_vectors_per_token"] = embedding_info[
                 "embedding"
             ].shape[0]
-            embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
+            embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
         else:
             print(">> Invalid embedding format")
             embedding_info = None
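The final hunk is the fix itself: a variant-1 .pt file stores its vectors as a 2-D tensor of shape (num_vectors_per_token, token_dim), and the old code read token_dim off axis 0. A fabricated two-vector example showing the before/after values:

    import torch

    # Hypothetical two-vector embedding trained against a 768-dim encoder.
    embedding = torch.zeros(2, 768)

    num_vectors_per_token = embedding.shape[0]  # 2
    token_dim_old = embedding.size()[0]         # 2   (pre-fix: wrong axis)
    token_dim_new = embedding.size()[1]         # 768 (post-fix: matches the encoder)

    # With token_dim reported as 2, the compatibility check above could never
    # pass, so every such .pt file was discarded during the directory scan.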