restore ability of textual inversion manager to read .pt files

- Fixes longstanding bug in the token vector size code which caused
  .pt files to be assigned the wrong token vector length. These
  were then tossed out during directory scanning.
This commit is contained in:
Lincoln Stein 2023-02-20 14:34:14 -05:00
parent 63e790b79b
commit 0d22fd59ed

View File

@ -61,9 +61,15 @@ class TextualInversionManager:
def load_textual_inversion(self, ckpt_path: Union[str,Path], defer_injecting_tokens: bool = False):
ckpt_path = Path(ckpt_path)
if not ckpt_path.is_file():
return
if str(ckpt_path).endswith(".DS_Store"):
return
try:
print(f'>> Scanning {str(ckpt_path)} for embedding terms')
scan_result = scan_file_path(str(ckpt_path))
if scan_result.infected_files == 1:
print(
@ -87,7 +93,7 @@ class TextualInversionManager:
!= embedding_info['token_dim']
):
print(
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with a different token dimension. It can't be used with this model."
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
)
return
@ -333,7 +339,6 @@ class TextualInversionManager:
# .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
# They are actually .bin files
elif len(embedding_ckpt.keys()) == 1:
print(">> Detected .bin file masquerading as .pt file")
embedding_info = self._parse_embedding_bin(embedding_file)
else:
@ -372,9 +377,6 @@ class TextualInversionManager:
if isinstance(
list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
):
print(
">> Detected .pt file variant 1"
) # example at https://github.com/invoke-ai/InvokeAI/issues/1829
for token in list(embedding_ckpt["string_to_token"].keys()):
embedding_info["name"] = (
token
@ -387,7 +389,7 @@ class TextualInversionManager:
embedding_info["num_vectors_per_token"] = embedding_info[
"embedding"
].shape[0]
embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
else:
print(">> Invalid embedding format")
embedding_info = None