make convert work with both 4.30.2 and 4.31.0

This commit is contained in:
Lincoln Stein 2023-07-18 22:18:13 -04:00
parent f29bafd6ec
commit a690cca5b5

@ -841,9 +841,14 @@ def convert_ldm_clip_checkpoint(checkpoint):
key
]
position_ids = text_model_dict.pop("text_model.embeddings.position_ids")
text_model.load_state_dict(text_model_dict)
text_model.text_model.embeddings.position_ids.copy_(position_ids)
try:
# transformers 4.31.0 and higher - this key no longer in state dict
position_ids = text_model_dict.pop("text_model.embeddings.position_ids")
text_model.load_state_dict(text_model_dict)
except RuntimeError:
# transformers 4.30.2 and lower - put the key back!
text_model_dict["text_model.embeddings.position_ids"] = position_ids
text_model.text_model.embeddings.position_ids.copy_(position_ids)
return text_model