[WIP] Load text_model.embeddings.position_ids outsude state_dict (#3829)

## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [x] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
      
## Description
In transformers 4.31.0 `text_model.embeddings.position_ids` no longer
part of state_dict.
Fix untested as can't run right now but should be correct. Also need to
check how transformers 4.30.2 works with this fix.

## Related Tickets & Documents


8e5d1619b3 (diff-7f53db5caa73a4cbeb0dca3b396e3d52f30f025b8c48d4daf51eb7abb6e2b949R191)

https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer

## QA Instructions, Screenshots, Recordings

```
  File "C:\Users\artis\Documents\invokeai\.venv\lib\site-packages\invokeai\backend\model_management\convert_ckpt_to_diffusers.py", line 844, in convert_ldm_clip_checkpoint
    text_model.load_state_dict(text_model_dict)
  File "C:\Users\artis\Documents\invokeai\.venv\lib\site-packages\torch\nn\modules\module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for CLIPTextModel:
        Unexpected key(s) in state_dict: "text_model.embeddings.position_ids".
```
This commit is contained in:
Lincoln Stein 2023-07-19 09:58:02 -04:00 committed by GitHub
commit aca5c6de9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 24 additions and 4 deletions

View File

@ -141,7 +141,7 @@ class EventServiceBase:
model_type=model_type,
submodel=submodel,
hash=model_info.hash,
location=model_info.location,
location=str(model_info.location),
precision=str(model_info.precision),
),
)

View File

@ -21,6 +21,7 @@ import re
import warnings
from pathlib import Path
from typing import Union
from packaging import version
import torch
from safetensors.torch import load_file
@ -63,6 +64,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from diffusers.utils import is_safetensors_available
import transformers
from transformers import (
AutoFeatureExtractor,
BertTokenizerFast,
@ -841,7 +843,16 @@ def convert_ldm_clip_checkpoint(checkpoint):
key
]
text_model.load_state_dict(text_model_dict)
# transformers 4.31.0 and higher - this key no longer in state dict
if version.parse(transformers.__version__) >= version.parse("4.31.0"):
position_ids = text_model_dict.pop("text_model.embeddings.position_ids", None)
text_model.load_state_dict(text_model_dict)
if position_ids is not None:
text_model.text_model.embeddings.position_ids.copy_(position_ids)
# transformers 4.30.2 and lower - position_ids is part of state_dict
else:
text_model.load_state_dict(text_model_dict)
return text_model
@ -947,7 +958,16 @@ def convert_open_clip_checkpoint(checkpoint):
text_model_dict[new_key] = checkpoint[key]
text_model.load_state_dict(text_model_dict)
# transformers 4.31.0 and higher - this key no longer in state dict
if version.parse(transformers.__version__) >= version.parse("4.31.0"):
position_ids = text_model_dict.pop("text_model.embeddings.position_ids", None)
text_model.load_state_dict(text_model_dict)
if position_ids is not None:
text_model.text_model.embeddings.position_ids.copy_(position_ids)
# transformers 4.30.2 and lower - position_ids is part of state_dict
else:
text_model.load_state_dict(text_model_dict)
return text_model

View File

@ -83,7 +83,7 @@ dependencies = [
"torchvision>=0.14.1",
"torchmetrics==0.11.4",
"torchsde==0.2.5",
"transformers==4.30.2",
"transformers~=4.31.0",
"uvicorn[standard]==0.21.1",
"windows-curses; sys_platform=='win32'",
]