Support both pre and post 4.31.0 transformers

This commit is contained in:
Sergey Borisov 2023-07-19 06:15:17 +03:00
parent a690cca5b5
commit 2e7fc055c4

View File

@@ -21,6 +21,7 @@ import re
import warnings
from pathlib import Path
from typing import Union
from packaging import version
import torch
from safetensors.torch import load_file
@@ -63,6 +64,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from diffusers.utils import is_safetensors_available
import transformers
from transformers import (
AutoFeatureExtractor,
BertTokenizerFast,
@@ -841,15 +843,17 @@ def convert_ldm_clip_checkpoint(checkpoint):
key
]
try:
# transformers 4.31.0 and higher - this key no longer in state dict
position_ids = text_model_dict.pop("text_model.embeddings.position_ids")
if version.parse(transformers.__version__) >= version.parse("4.31.0"):
position_ids = text_model_dict.pop("text_model.embeddings.position_ids", None)
text_model.load_state_dict(text_model_dict)
except RuntimeError:
# transformers 4.30.2 and lower - put the key back!
text_model_dict["text_model.embeddings.position_ids"] = position_ids
if position_ids is not None:
text_model.text_model.embeddings.position_ids.copy_(position_ids)
# transformers 4.30.2 and lower - position_ids is part of state_dict
else:
text_model.load_state_dict(text_model_dict)
return text_model
@@ -954,10 +958,17 @@ def convert_open_clip_checkpoint(checkpoint):
text_model_dict[new_key] = checkpoint[key]
position_ids = text_model_dict.pop("text_model.embeddings.position_ids")
# transformers 4.31.0 and higher - this key no longer in state dict
if version.parse(transformers.__version__) >= version.parse("4.31.0"):
position_ids = text_model_dict.pop("text_model.embeddings.position_ids", None)
text_model.load_state_dict(text_model_dict)
if position_ids is not None:
text_model.text_model.embeddings.position_ids.copy_(position_ids)
# transformers 4.30.2 and lower - position_ids is part of state_dict
else:
text_model.load_state_dict(text_model_dict)
return text_model
def replace_checkpoint_vae(checkpoint, vae_path: str):