Support both pre and post 4.31.0 transformers

This commit is contained in:
Sergey Borisov 2023-07-19 06:15:17 +03:00
parent a690cca5b5
commit 2e7fc055c4

View File

@@ -21,6 +21,7 @@ import re
 import warnings
 from pathlib import Path
 from typing import Union
+from packaging import version

 import torch
 from safetensors.torch import load_file
@@ -63,6 +64,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
     StableDiffusionSafetyChecker,
 )
 from diffusers.utils import is_safetensors_available
+import transformers
 from transformers import (
     AutoFeatureExtractor,
     BertTokenizerFast,
@@ -841,15 +843,17 @@ def convert_ldm_clip_checkpoint(checkpoint):
             key
         ]

-    try:
-        # transformers 4.31.0 and higher - this key no longer in state dict
-        position_ids = text_model_dict.pop("text_model.embeddings.position_ids")
-        text_model.load_state_dict(text_model_dict)
-    except RuntimeError:
-        # transformers 4.30.2 and lower - put the key back!
-        text_model_dict["text_model.embeddings.position_ids"] = position_ids
-        text_model.text_model.embeddings.position_ids.copy_(position_ids)
+    # transformers 4.31.0 and higher - this key no longer in state dict
+    if version.parse(transformers.__version__) >= version.parse("4.31.0"):
+        position_ids = text_model_dict.pop("text_model.embeddings.position_ids", None)
+        text_model.load_state_dict(text_model_dict)
+        if position_ids is not None:
+            text_model.text_model.embeddings.position_ids.copy_(position_ids)
+    # transformers 4.30.2 and lower - position_ids is part of state_dict
+    else:
+        text_model.load_state_dict(text_model_dict)

     return text_model
@@ -954,10 +958,17 @@ def convert_open_clip_checkpoint(checkpoint):
         text_model_dict[new_key] = checkpoint[key]

-    position_ids = text_model_dict.pop("text_model.embeddings.position_ids")
-    text_model.load_state_dict(text_model_dict)
-    text_model.text_model.embeddings.position_ids.copy_(position_ids)
+    # transformers 4.31.0 and higher - this key no longer in state dict
+    if version.parse(transformers.__version__) >= version.parse("4.31.0"):
+        position_ids = text_model_dict.pop("text_model.embeddings.position_ids", None)
+        text_model.load_state_dict(text_model_dict)
+        if position_ids is not None:
+            text_model.text_model.embeddings.position_ids.copy_(position_ids)
+    # transformers 4.30.2 and lower - position_ids is part of state_dict
+    else:
+        text_model.load_state_dict(text_model_dict)

     return text_model


 def replace_checkpoint_vae(checkpoint, vae_path: str):