Load text_model.embeddings.position_ids outsude state_dict

This commit is contained in:
Sergey Borisov 2023-07-19 04:18:43 +03:00
parent 3c5a0c95b3
commit 0aa7193d3b

View File

@ -841,7 +841,9 @@ def convert_ldm_clip_checkpoint(checkpoint):
key
]
position_ids = text_model_dict.pop("text_model.embeddings.position_ids")
text_model.load_state_dict(text_model_dict)
text_model.text_model.embeddings.position_ids.copy_(position_ids)
return text_model
@ -947,7 +949,9 @@ def convert_open_clip_checkpoint(checkpoint):
text_model_dict[new_key] = checkpoint[key]
position_ids = text_model_dict.pop("text_model.embeddings.position_ids")
text_model.load_state_dict(text_model_dict)
text_model.text_model.embeddings.position_ids.copy_(position_ids)
return text_model