Fix broken embedding variants (#2037)

This commit is contained in:
blessedcoolant 2022-12-17 16:07:05 +13:00 committed by GitHub
parent f1748d7017
commit dc39f8d6a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -366,12 +366,12 @@ class EmbeddingManager(nn.Module):
''' '''
embedding_info = {} embedding_info = {}
if isinstance(list(embedding_ckpt['string_to_token'].values())[0],torch.Tensor): if isinstance(list(embedding_ckpt['string_to_token'].values())[0],torch.Tensor):
print('>> Detected .pt file variant 1') # example at https://github.com/invoke-ai/InvokeAI/issues/1829 print(f'>> Variant Embedding Detected. Parsing: {embedding_file}') # example at https://github.com/invoke-ai/InvokeAI/issues/1829
for token in list(embedding_ckpt['string_to_token'].keys()): token = list(embedding_ckpt['string_to_token'].keys())[0]
embedding_info['name'] = token if token != '*' else os.path.basename(os.path.splitext(embedding_file)[0]) embedding_info['name'] = os.path.basename(os.path.splitext(embedding_file)[0])
embedding_info['embedding'] = embedding_ckpt['string_to_param'].state_dict()[token] embedding_info['embedding'] = embedding_ckpt['string_to_param'].state_dict()[token]
embedding_info['num_vectors_per_token'] = embedding_info['embedding'].shape[0] embedding_info['num_vectors_per_token'] = embedding_info['embedding'].shape[0]
embedding_info['token_dim'] = embedding_info['embedding'].size()[0] embedding_info['token_dim'] = embedding_info['embedding'].size()[0]
else: else:
print('>> Invalid embedding format') print('>> Invalid embedding format')
embedding_info = None embedding_info = None