mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Update Transformers to 4.34 and fix pad_to_multiple_of
This commit is contained in:
@ -166,6 +166,13 @@ class ModelPatcher:
|
||||
init_tokens_count = None
|
||||
new_tokens_added = None
|
||||
|
||||
# This is required since Transformers 4.32, see transformers/pull/25088
|
||||
# More information: https://tinyurl.com/ycxxzdhh
|
||||
if "A100" in torch.cuda.get_device_name():
|
||||
pad_to_multiple_of = 64
|
||||
else:
|
||||
pad_to_multiple_of = 8
|
||||
|
||||
try:
|
||||
# HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
|
||||
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
|
||||
@ -175,7 +182,7 @@ class ModelPatcher:
|
||||
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
|
||||
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
|
||||
ti_manager = TextualInversionManager(ti_tokenizer)
|
||||
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
|
||||
init_tokens_count = text_encoder.resize_token_embeddings(None, pad_to_multiple_of).num_embeddings
|
||||
|
||||
def _get_trigger(ti_name, index):
|
||||
trigger = ti_name
|
||||
@ -190,7 +197,7 @@ class ModelPatcher:
|
||||
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
|
||||
|
||||
# modify text_encoder
|
||||
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
|
||||
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
|
||||
model_embeddings = text_encoder.get_input_embeddings()
|
||||
|
||||
for ti_name, ti in ti_list:
|
||||
@ -222,7 +229,7 @@ class ModelPatcher:
|
||||
|
||||
finally:
|
||||
if init_tokens_count and new_tokens_added:
|
||||
text_encoder.resize_token_embeddings(init_tokens_count)
|
||||
text_encoder.resize_token_embeddings(init_tokens_count, pad_to_multiple_of)
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
|
Reference in New Issue
Block a user