mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into Millu-patch-1
This commit is contained in:
commit
d5aa74623d
@ -166,6 +166,15 @@ class ModelPatcher:
|
||||
init_tokens_count = None
|
||||
new_tokens_added = None
|
||||
|
||||
# TODO: This is required since Transformers 4.32 see
|
||||
# https://github.com/huggingface/transformers/pull/25088
|
||||
# More information by NVIDIA:
|
||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
|
||||
# This value might need to be changed in the future and take the GPUs model into account as there seem
|
||||
# to be ideal values for different GPUS. This value is temporary!
|
||||
# For references to the current discussion please see https://github.com/invoke-ai/InvokeAI/pull/4817
|
||||
pad_to_multiple_of = 8
|
||||
|
||||
try:
|
||||
# HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
|
||||
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
|
||||
@ -175,7 +184,7 @@ class ModelPatcher:
|
||||
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
|
||||
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
|
||||
ti_manager = TextualInversionManager(ti_tokenizer)
|
||||
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
|
||||
init_tokens_count = text_encoder.resize_token_embeddings(None, pad_to_multiple_of).num_embeddings
|
||||
|
||||
def _get_trigger(ti_name, index):
|
||||
trigger = ti_name
|
||||
@ -190,7 +199,7 @@ class ModelPatcher:
|
||||
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
|
||||
|
||||
# modify text_encoder
|
||||
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
|
||||
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
|
||||
model_embeddings = text_encoder.get_input_embeddings()
|
||||
|
||||
for ti_name, ti in ti_list:
|
||||
@ -222,7 +231,7 @@ class ModelPatcher:
|
||||
|
||||
finally:
|
||||
if init_tokens_count and new_tokens_added:
|
||||
text_encoder.resize_token_embeddings(init_tokens_count)
|
||||
text_encoder.resize_token_embeddings(init_tokens_count, pad_to_multiple_of)
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
|
@ -82,7 +82,7 @@ dependencies = [
|
||||
"torchvision~=0.16",
|
||||
"torchmetrics~=0.11.0",
|
||||
"torchsde~=0.2.5",
|
||||
"transformers~=4.31.0",
|
||||
"transformers~=4.35.0",
|
||||
"uvicorn[standard]~=0.21.1",
|
||||
"windows-curses; sys_platform=='win32'",
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user