diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py
index f5d66ca51a..d72f55794d 100644
--- a/invokeai/backend/model_management/lora.py
+++ b/invokeai/backend/model_management/lora.py
@@ -13,6 +13,7 @@ from safetensors.torch import load_file
 from transformers import CLIPTextModel, CLIPTokenizer
 
 from invokeai.app.shared.models import FreeUConfig
+from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init
 
 from .models.lora import LoRAModel
 
@@ -211,8 +212,12 @@ class ModelPatcher:
             for i in range(ti_embedding.shape[0]):
                 new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
 
-            # modify text_encoder
-            text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
+            # Modify text_encoder.
+            # resize_token_embeddings(...) constructs a new torch.nn.Embedding internally. Initializing the weights of
+            # this embedding is slow and unnecessary, so we wrap this step in skip_torch_weight_init() to save some
+            # time.
+            with skip_torch_weight_init():
+                text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
             model_embeddings = text_encoder.get_input_embeddings()
 
             for ti_name, ti in ti_list:
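
Note: the imported model_load_optimizations.skip_torch_weight_init is not part of this diff, so its implementation is not shown here. Below is a minimal sketch of how such a context manager could work, assuming it temporarily replaces reset_parameters() on common torch.nn layer classes with a no-op so that newly constructed layers keep their torch.empty() storage instead of running the default (slow) initializers. This is illustrative only, not the module referenced by the import.

# Hypothetical sketch, not the actual model_load_optimizations module.
from contextlib import contextmanager

import torch


def _no_op(*args, **kwargs) -> None:
    pass


@contextmanager
def skip_torch_weight_init():
    """Temporarily disable weight initialization for common torch.nn layer types."""
    layer_classes = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
    saved = [cls.reset_parameters for cls in layer_classes]
    try:
        for cls in layer_classes:
            cls.reset_parameters = _no_op
        yield
    finally:
        # Always restore the original initializers, even if the wrapped code raises.
        for cls, fn in zip(layer_classes, saved):
            cls.reset_parameters = fn

Skipping initialization is safe in the resize_token_embeddings(...) call above because the resized embedding's rows are immediately overwritten: existing rows are copied from the old embedding, and the newly added token rows are filled with the textual-inversion embeddings right afterwards.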