diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py index bd40c51269..d9d4262e47 100644 --- a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py +++ b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py @@ -21,6 +21,7 @@ import re import warnings from pathlib import Path from typing import Union +from packaging import version import torch from safetensors.torch import load_file @@ -63,6 +64,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import ( StableDiffusionSafetyChecker, ) from diffusers.utils import is_safetensors_available +import transformers from transformers import ( AutoFeatureExtractor, BertTokenizerFast, @@ -841,14 +843,16 @@ def convert_ldm_clip_checkpoint(checkpoint): key ] - try: - # transformers 4.31.0 and higher - this key no longer in state dict - position_ids = text_model_dict.pop("text_model.embeddings.position_ids") + # transformers 4.31.0 and higher - this key no longer in state dict + if version.parse(transformers.__version__) >= version.parse("4.31.0"): + position_ids = text_model_dict.pop("text_model.embeddings.position_ids", None) text_model.load_state_dict(text_model_dict) - except RuntimeError: - # transformers 4.30.2 and lower - put the key back! - text_model_dict["text_model.embeddings.position_ids"] = position_ids - text_model.text_model.embeddings.position_ids.copy_(position_ids) + if position_ids is not None: + text_model.text_model.embeddings.position_ids.copy_(position_ids) + + # transformers 4.30.2 and lower - position_ids is part of state_dict + else: + text_model.load_state_dict(text_model_dict) return text_model @@ -954,9 +958,16 @@ def convert_open_clip_checkpoint(checkpoint): text_model_dict[new_key] = checkpoint[key] - position_ids = text_model_dict.pop("text_model.embeddings.position_ids") - text_model.load_state_dict(text_model_dict) - text_model.text_model.embeddings.position_ids.copy_(position_ids) + # transformers 4.31.0 and higher - this key no longer in state dict + if version.parse(transformers.__version__) >= version.parse("4.31.0"): + position_ids = text_model_dict.pop("text_model.embeddings.position_ids", None) + text_model.load_state_dict(text_model_dict) + if position_ids is not None: + text_model.text_model.embeddings.position_ids.copy_(position_ids) + + # transformers 4.30.2 and lower - position_ids is part of state_dict + else: + text_model.load_state_dict(text_model_dict) return text_model