mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
[WIP] Load text_model.embeddings.position_ids outsude state_dict (#3829)
## What type of PR is this? (check all applicable)
- [ ] Refactor
- [ ] Feature
- [x] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
## Description
In transformers 4.31.0 `text_model.embeddings.position_ids` no longer
part of state_dict.
Fix untested as can't run right now but should be correct. Also need to
check how transformers 4.30.2 works with this fix.
## Related Tickets & Documents
8e5d1619b3 (diff-7f53db5caa73a4cbeb0dca3b396e3d52f30f025b8c48d4daf51eb7abb6e2b949R191)
https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer
## QA Instructions, Screenshots, Recordings
```
File "C:\Users\artis\Documents\invokeai\.venv\lib\site-packages\invokeai\backend\model_management\convert_ckpt_to_diffusers.py", line 844, in convert_ldm_clip_checkpoint
text_model.load_state_dict(text_model_dict)
File "C:\Users\artis\Documents\invokeai\.venv\lib\site-packages\torch\nn\modules\module.py", line 2041, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for CLIPTextModel:
Unexpected key(s) in state_dict: "text_model.embeddings.position_ids".
```
This commit is contained in:
commit
aca5c6de9a
@ -141,7 +141,7 @@ class EventServiceBase:
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
hash=model_info.hash,
|
||||
location=model_info.location,
|
||||
location=str(model_info.location),
|
||||
precision=str(model_info.precision),
|
||||
),
|
||||
)
|
||||
|
@ -21,6 +21,7 @@ import re
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from packaging import version
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
@ -63,6 +64,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from diffusers.utils import is_safetensors_available
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoFeatureExtractor,
|
||||
BertTokenizerFast,
|
||||
@ -841,7 +843,16 @@ def convert_ldm_clip_checkpoint(checkpoint):
|
||||
key
|
||||
]
|
||||
|
||||
text_model.load_state_dict(text_model_dict)
|
||||
# transformers 4.31.0 and higher - this key no longer in state dict
|
||||
if version.parse(transformers.__version__) >= version.parse("4.31.0"):
|
||||
position_ids = text_model_dict.pop("text_model.embeddings.position_ids", None)
|
||||
text_model.load_state_dict(text_model_dict)
|
||||
if position_ids is not None:
|
||||
text_model.text_model.embeddings.position_ids.copy_(position_ids)
|
||||
|
||||
# transformers 4.30.2 and lower - position_ids is part of state_dict
|
||||
else:
|
||||
text_model.load_state_dict(text_model_dict)
|
||||
|
||||
return text_model
|
||||
|
||||
@ -947,7 +958,16 @@ def convert_open_clip_checkpoint(checkpoint):
|
||||
|
||||
text_model_dict[new_key] = checkpoint[key]
|
||||
|
||||
text_model.load_state_dict(text_model_dict)
|
||||
# transformers 4.31.0 and higher - this key no longer in state dict
|
||||
if version.parse(transformers.__version__) >= version.parse("4.31.0"):
|
||||
position_ids = text_model_dict.pop("text_model.embeddings.position_ids", None)
|
||||
text_model.load_state_dict(text_model_dict)
|
||||
if position_ids is not None:
|
||||
text_model.text_model.embeddings.position_ids.copy_(position_ids)
|
||||
|
||||
# transformers 4.30.2 and lower - position_ids is part of state_dict
|
||||
else:
|
||||
text_model.load_state_dict(text_model_dict)
|
||||
|
||||
return text_model
|
||||
|
||||
|
@ -83,7 +83,7 @@ dependencies = [
|
||||
"torchvision>=0.14.1",
|
||||
"torchmetrics==0.11.4",
|
||||
"torchsde==0.2.5",
|
||||
"transformers==4.30.2",
|
||||
"transformers~=4.31.0",
|
||||
"uvicorn[standard]==0.21.1",
|
||||
"windows-curses; sys_platform=='win32'",
|
||||
]
|
||||
|
Loading…
x
Reference in New Issue
Block a user