diff --git a/ldm/modules/embedding_manager.py b/ldm/modules/embedding_manager.py
index af9383bbd6..8b37ec602a 100644
--- a/ldm/modules/embedding_manager.py
+++ b/ldm/modules/embedding_manager.py
@@ -366,12 +366,12 @@ class EmbeddingManager(nn.Module):
         '''
         embedding_info = {}
         if isinstance(list(embedding_ckpt['string_to_token'].values())[0],torch.Tensor):
-            print('>> Detected .pt file variant 1') # example at https://github.com/invoke-ai/InvokeAI/issues/1829
-            for token in list(embedding_ckpt['string_to_token'].keys()):
-                embedding_info['name'] = token if token != '*' else os.path.basename(os.path.splitext(embedding_file)[0])
-                embedding_info['embedding'] = embedding_ckpt['string_to_param'].state_dict()[token]
-                embedding_info['num_vectors_per_token'] = embedding_info['embedding'].shape[0]
-                embedding_info['token_dim'] = embedding_info['embedding'].size()[0]
+            print(f'>> Variant Embedding Detected. Parsing: {embedding_file}') # example at https://github.com/invoke-ai/InvokeAI/issues/1829
+            token = list(embedding_ckpt['string_to_token'].keys())[0]
+            embedding_info['name'] = os.path.basename(os.path.splitext(embedding_file)[0])
+            embedding_info['embedding'] = embedding_ckpt['string_to_param'].state_dict()[token]
+            embedding_info['num_vectors_per_token'] = embedding_info['embedding'].shape[0]
+            embedding_info['token_dim'] = embedding_info['embedding'].size()[0]
         else:
             print('>> Invalid embedding format')
             embedding_info = None