support full-precision embeddings in half precision mode

Paul Sajna
2022-08-26 08:30:58 -07:00
parent 8b682ac83b
commit 9b5101cd8d
2 changed files with 6 additions and 4 deletions


@@ -135,11 +135,13 @@ class EmbeddingManager(nn.Module):
         torch.save({"string_to_token": self.string_to_token_dict,
                     "string_to_param": self.string_to_param_dict}, ckpt_path)
-    def load(self, ckpt_path):
+    def load(self, ckpt_path, full=True):
         ckpt = torch.load(ckpt_path, map_location='cpu')
         self.string_to_token_dict = ckpt["string_to_token"]
         self.string_to_param_dict = ckpt["string_to_param"]
+        if not full:
+            for key, value in self.string_to_param_dict.items():
+                self.string_to_param_dict[key] = torch.nn.Parameter(value.half())
     def get_embedding_norms_squared(self):
         all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim
@@ -161,4 +163,4 @@ class EmbeddingManager(nn.Module):
             loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings
-        return loss
+        return loss
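
A minimal usage sketch of the new flag (hypothetical caller code, not part of this commit; the checkpoint filename and the pre-built `manager` instance are assumed): `full=True` keeps the learned embeddings at their stored float32 precision even when the rest of the model runs in half precision, while `full=False` casts each loaded embedding to float16.

    import torch

    # Assume `manager` is an existing EmbeddingManager instance (construction omitted).
    # "embeddings.pt" is a placeholder path for a checkpoint written by manager.save().

    # Default: keep the loaded embeddings in full (float32) precision,
    # even if the surrounding model was cast to half precision.
    manager.load("embeddings.pt", full=True)

    # Opt out: cast each loaded embedding tensor to float16.
    manager.load("embeddings.pt", full=False)
    assert all(p.dtype == torch.float16 for p in manager.string_to_param_dict.values())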