restore ability of textual inversion manager to read .pt files (#2746)

- Fixes longstanding bug in the token vector size code which caused .pt
files to be assigned the wrong token vector length. These were then
tossed out during directory scanning.
This commit is contained in:
Lincoln Stein 2023-02-20 15:34:56 -05:00 committed by GitHub
commit 17294bfa55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 7 deletions

View File

@ -964,6 +964,7 @@ class Generate:
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
if self.embedding_path is not None:
print(f'>> Loading embeddings from {self.embedding_path}')
for root, _, files in os.walk(self.embedding_path):
for name in files:
ti_path = os.path.join(root, name)
@ -971,7 +972,7 @@ class Generate:
ti_path, defer_injecting_tokens=True
)
print(
f'>> Textual inversions available: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}'
f'>> Textual inversion triggers: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}'
)
self.model_name = model_name

View File

@ -61,8 +61,13 @@ class TextualInversionManager:
def load_textual_inversion(self, ckpt_path: Union[str,Path], defer_injecting_tokens: bool = False):
ckpt_path = Path(ckpt_path)
if not ckpt_path.is_file():
return
if str(ckpt_path).endswith(".DS_Store"):
return
try:
scan_result = scan_file_path(str(ckpt_path))
if scan_result.infected_files == 1:
@ -87,7 +92,7 @@ class TextualInversionManager:
!= embedding_info['token_dim']
):
print(
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with a different token dimension. It can't be used with this model."
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
)
return
@ -333,7 +338,6 @@ class TextualInversionManager:
# .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
# They are actually .bin files
elif len(embedding_ckpt.keys()) == 1:
print(">> Detected .bin file masquerading as .pt file")
embedding_info = self._parse_embedding_bin(embedding_file)
else:
@ -372,9 +376,6 @@ class TextualInversionManager:
if isinstance(
list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
):
print(
">> Detected .pt file variant 1"
) # example at https://github.com/invoke-ai/InvokeAI/issues/1829
for token in list(embedding_ckpt["string_to_token"].keys()):
embedding_info["name"] = (
token
@ -387,7 +388,7 @@ class TextualInversionManager:
embedding_info["num_vectors_per_token"] = embedding_info[
"embedding"
].shape[0]
embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
else:
print(">> Invalid embedding format")
embedding_info = None