mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix broken embedding variants (#2037)
This commit is contained in:
parent
f1748d7017
commit
dc39f8d6a7
@ -366,12 +366,12 @@ class EmbeddingManager(nn.Module):
|
|||||||
'''
|
'''
|
||||||
embedding_info = {}
|
embedding_info = {}
|
||||||
if isinstance(list(embedding_ckpt['string_to_token'].values())[0],torch.Tensor):
|
if isinstance(list(embedding_ckpt['string_to_token'].values())[0],torch.Tensor):
|
||||||
print('>> Detected .pt file variant 1') # example at https://github.com/invoke-ai/InvokeAI/issues/1829
|
print(f'>> Variant Embedding Detected. Parsing: {embedding_file}') # example at https://github.com/invoke-ai/InvokeAI/issues/1829
|
||||||
for token in list(embedding_ckpt['string_to_token'].keys()):
|
token = list(embedding_ckpt['string_to_token'].keys())[0]
|
||||||
embedding_info['name'] = token if token != '*' else os.path.basename(os.path.splitext(embedding_file)[0])
|
embedding_info['name'] = os.path.basename(os.path.splitext(embedding_file)[0])
|
||||||
embedding_info['embedding'] = embedding_ckpt['string_to_param'].state_dict()[token]
|
embedding_info['embedding'] = embedding_ckpt['string_to_param'].state_dict()[token]
|
||||||
embedding_info['num_vectors_per_token'] = embedding_info['embedding'].shape[0]
|
embedding_info['num_vectors_per_token'] = embedding_info['embedding'].shape[0]
|
||||||
embedding_info['token_dim'] = embedding_info['embedding'].size()[0]
|
embedding_info['token_dim'] = embedding_info['embedding'].size()[0]
|
||||||
else:
|
else:
|
||||||
print('>> Invalid embedding format')
|
print('>> Invalid embedding format')
|
||||||
embedding_info = None
|
embedding_info = None
|
||||||
|
Loading…
Reference in New Issue
Block a user