mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Support both pre and post 4.31.0 transformers
This commit is contained in:
parent
a690cca5b5
commit
2e7fc055c4
@ -21,6 +21,7 @@ import re
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from packaging import version
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
@ -63,6 +64,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from diffusers.utils import is_safetensors_available
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoFeatureExtractor,
|
||||
BertTokenizerFast,
|
||||
@ -841,15 +843,17 @@ def convert_ldm_clip_checkpoint(checkpoint):
|
||||
key
|
||||
]
|
||||
|
||||
try:
|
||||
# transformers 4.31.0 and higher - this key no longer in state dict
|
||||
position_ids = text_model_dict.pop("text_model.embeddings.position_ids")
|
||||
if version.parse(transformers.__version__) >= version.parse("4.31.0"):
|
||||
position_ids = text_model_dict.pop("text_model.embeddings.position_ids", None)
|
||||
text_model.load_state_dict(text_model_dict)
|
||||
except RuntimeError:
|
||||
# transformers 4.30.2 and lower - put the key back!
|
||||
text_model_dict["text_model.embeddings.position_ids"] = position_ids
|
||||
if position_ids is not None:
|
||||
text_model.text_model.embeddings.position_ids.copy_(position_ids)
|
||||
|
||||
# transformers 4.30.2 and lower - position_ids is part of state_dict
|
||||
else:
|
||||
text_model.load_state_dict(text_model_dict)
|
||||
|
||||
return text_model
|
||||
|
||||
|
||||
@ -954,10 +958,17 @@ def convert_open_clip_checkpoint(checkpoint):
|
||||
|
||||
text_model_dict[new_key] = checkpoint[key]
|
||||
|
||||
position_ids = text_model_dict.pop("text_model.embeddings.position_ids")
|
||||
# transformers 4.31.0 and higher - this key no longer in state dict
|
||||
if version.parse(transformers.__version__) >= version.parse("4.31.0"):
|
||||
position_ids = text_model_dict.pop("text_model.embeddings.position_ids", None)
|
||||
text_model.load_state_dict(text_model_dict)
|
||||
if position_ids is not None:
|
||||
text_model.text_model.embeddings.position_ids.copy_(position_ids)
|
||||
|
||||
# transformers 4.30.2 and lower - position_ids is part of state_dict
|
||||
else:
|
||||
text_model.load_state_dict(text_model_dict)
|
||||
|
||||
return text_model
|
||||
|
||||
def replace_checkpoint_vae(checkpoint, vae_path: str):
|
||||
|
Loading…
Reference in New Issue
Block a user