From bac2a757e8f77d875cd3356f58a52dc0ea3ad518 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 2 Nov 2023 15:49:44 -0400 Subject: [PATCH] Replace deepcopy with a pickle roundtrip in apply_ti(...) to improve speed. --- invokeai/backend/model_management/lora.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index 5002f278cc..2a0e465e03 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -1,6 +1,6 @@ from __future__ import annotations -import copy +import pickle from contextlib import contextmanager from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union @@ -165,7 +165,13 @@ class ModelPatcher: new_tokens_added = None try: - ti_tokenizer = copy.deepcopy(tokenizer) + # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a + # workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after + # exiting this `apply_ti(...)` context manager. + # + # In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`, + # but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs). + ti_tokenizer = pickle.loads(pickle.dumps(tokenizer)) ti_manager = TextualInversionManager(ti_tokenizer) init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings @@ -439,7 +445,13 @@ class ONNXModelPatcher: orig_embeddings = None try: - ti_tokenizer = copy.deepcopy(tokenizer) + # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a + # workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after + # exiting this `apply_ti(...)` context manager. + # + # In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`, + # but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs). + ti_tokenizer = pickle.loads(pickle.dumps(tokenizer)) ti_manager = TextualInversionManager(ti_tokenizer) def _get_trigger(ti_name, index):