From a0be83e370bdfeefb4e56b1af299e44ac7aa465f Mon Sep 17 00:00:00 2001
From: Wubbbi
Date: Fri, 6 Oct 2023 18:55:59 +0200
Subject: [PATCH] Update Transformers to 4.34 and fix pad_to_multiple_of

---
 invokeai/backend/model_management/lora.py | 13 ++++++++++---
 pyproject.toml                            |  2 +-
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py
index 48fa8dfbe2..7b6fa8c9b6 100644
--- a/invokeai/backend/model_management/lora.py
+++ b/invokeai/backend/model_management/lora.py
@@ -166,6 +166,13 @@ class ModelPatcher:
         init_tokens_count = None
         new_tokens_added = None
 
+        # This is required since Transformers 4.32, see transformers/pull/25088
+        # More information: https://tinyurl.com/ycxxzdhh
+        if "A100" in torch.cuda.get_device_name():
+            pad_to_multiple_of = 64
+        else:
+            pad_to_multiple_of = 8
+
         try:
             # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
             # workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
@@ -175,7 +182,7 @@ class ModelPatcher:
             # but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
             ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
             ti_manager = TextualInversionManager(ti_tokenizer)
-            init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
+            init_tokens_count = text_encoder.resize_token_embeddings(None, pad_to_multiple_of).num_embeddings
 
             def _get_trigger(ti_name, index):
                 trigger = ti_name
@@ -190,7 +197,7 @@ class ModelPatcher:
                     new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
 
             # modify text_encoder
-            text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
+            text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
             model_embeddings = text_encoder.get_input_embeddings()
 
             for ti_name, ti in ti_list:
@@ -222,7 +229,7 @@ class ModelPatcher:
 
         finally:
             if init_tokens_count and new_tokens_added:
-                text_encoder.resize_token_embeddings(init_tokens_count)
+                text_encoder.resize_token_embeddings(init_tokens_count, pad_to_multiple_of)
 
     @classmethod
     @contextmanager
diff --git a/pyproject.toml b/pyproject.toml
index 18b5605133..6e4c32774f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -82,7 +82,7 @@ dependencies = [
     "torchvision~=0.16",
     "torchmetrics~=0.11.0",
     "torchsde~=0.2.5",
-    "transformers~=4.31.0",
+    "transformers~=4.34.0",
     "uvicorn[standard]~=0.21.1",
     "windows-curses; sys_platform=='win32'",
 ]
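
For reference, the second argument threaded through the resize_token_embeddings()
calls above exists since Transformers 4.32 (huggingface/transformers#25088):
pad_to_multiple_of rounds the embedding matrix up to the nearest multiple of the
given value, keeping the vocab dimension tensor-core friendly (64 on A100,
8 elsewhere). Below is a minimal standalone sketch of that behavior, not part of
the patch; the CLIP checkpoint name is illustrative, and torch.cuda.is_available()
is checked first because torch.cuda.get_device_name() raises on machines without
CUDA (the patched code above assumes a CUDA device is present).

    import torch
    from transformers import CLIPTextModel

    def choose_pad_multiple() -> int:
        # Mirror the patch's heuristic: A100 tensor cores prefer multiples
        # of 64; 8 is a reasonable default for other devices.
        if torch.cuda.is_available() and "A100" in torch.cuda.get_device_name():
            return 64
        return 8

    # Illustrative checkpoint; InvokeAI passes in its own text encoder.
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # new_num_tokens=None pads the *current* vocab size up to the nearest
    # multiple of pad_to_multiple_of and returns the embedding module.
    embeddings = text_encoder.resize_token_embeddings(
        None, pad_to_multiple_of=choose_pad_multiple()
    )
    print(embeddings.num_embeddings)  # 49408: already a multiple of 8 and 64

In this example the reported size is unchanged, since CLIP's 49408-token
vocabulary is already a multiple of both 8 and 64; for vocabularies that are
not, the embedding matrix grows to the next multiple, which is why the patch
also passes pad_to_multiple_of when shrinking back to init_tokens_count.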