mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Skip weight initialization when resizing text encoder token embeddings to accommodate new TI embeddings. This saves time.
This commit is contained in:
parent
8e17e29a5c
commit
f7f697849c
@ -13,6 +13,7 @@ from safetensors.torch import load_file
|
|||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
from invokeai.app.shared.models import FreeUConfig
|
from invokeai.app.shared.models import FreeUConfig
|
||||||
|
from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init
|
||||||
|
|
||||||
from .models.lora import LoRAModel
|
from .models.lora import LoRAModel
|
||||||
|
|
||||||
@ -211,7 +212,11 @@ class ModelPatcher:
|
|||||||
for i in range(ti_embedding.shape[0]):
|
for i in range(ti_embedding.shape[0]):
|
||||||
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
|
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
|
||||||
|
|
||||||
# modify text_encoder
|
# Modify text_encoder.
|
||||||
|
# resize_token_embeddings(...) constructs a new torch.nn.Embedding internally. Initializing the weights of
|
||||||
|
# this embedding is slow and unnecessary, so we wrap this step in skip_torch_weight_init() to save some
|
||||||
|
# time.
|
||||||
|
with skip_torch_weight_init():
|
||||||
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
|
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
|
||||||
model_embeddings = text_encoder.get_input_embeddings()
|
model_embeddings = text_encoder.get_input_embeddings()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user