Replace deepcopy with a pickle roundtrip in apply_ti(...) to improve speed.
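The change swaps copy.deepcopy(tokenizer) for pickle.loads(pickle.dumps(tokenizer)). Below is a minimal, self-contained sketch (not from the commit; a plain dict stands in for the CLIPTokenizer, which is assumed to be picklable) showing that the two approaches yield equivalent, independent copies and how their timings can be compared:

    import copy
    import pickle
    import timeit

    # Hypothetical stand-in object; the real code copies a CLIPTokenizer.
    obj = {"vocab": {f"token_{i}": i for i in range(50_000)}}

    deep_copied = copy.deepcopy(obj)
    pickled_copy = pickle.loads(pickle.dumps(obj))
    assert deep_copied == pickled_copy                 # equivalent results
    assert pickled_copy["vocab"] is not obj["vocab"]   # fully independent copy

    print("deepcopy:        ", timeit.timeit(lambda: copy.deepcopy(obj), number=10))
    print("pickle roundtrip:", timeit.timeit(lambda: pickle.loads(pickle.dumps(obj)), number=10))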

Ryan Dick 2023-11-02 15:49:44 -04:00 committed by Kent Keirsey
parent a4a7b601a1
commit bac2a757e8


@@ -1,6 +1,6 @@
 from __future__ import annotations
-import copy
+import pickle
 from contextlib import contextmanager
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -165,7 +165,13 @@ class ModelPatcher:
         new_tokens_added = None
         try:
-            ti_tokenizer = copy.deepcopy(tokenizer)
+            # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
+            # workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
+            # exiting this `apply_ti(...)` context manager.
+            #
+            # In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
+            # but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
+            ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
             ti_manager = TextualInversionManager(ti_tokenizer)
             init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
@@ -439,7 +445,13 @@ class ONNXModelPatcher:
         orig_embeddings = None
         try:
-            ti_tokenizer = copy.deepcopy(tokenizer)
+            # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
+            # workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
+            # exiting this `apply_ti(...)` context manager.
+            #
+            # In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
+            # but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
+            ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
             ti_manager = TextualInversionManager(ti_tokenizer)

             def _get_trigger(ti_name, index):
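
Both ModelPatcher.apply_ti and ONNXModelPatcher.apply_ti rely on the same pattern described in the HACK comment: the context manager adds tokens only to a throwaway copy of the tokenizer, so the original's behavior is restored simply by exiting. A minimal sketch of that copy-then-mutate pattern follows; scoped_copy and the dict stand-in are hypothetical and not code from the repository.

    import pickle
    from contextlib import contextmanager

    @contextmanager
    def scoped_copy(obj):
        # Fast deep copy via a pickle roundtrip; mutations apply only to the copy.
        obj_copy = pickle.loads(pickle.dumps(obj))
        yield obj_copy
        # Nothing to undo on exit: the original object was never modified.

    original = {"added_tokens": []}
    with scoped_copy(original) as patched:
        patched["added_tokens"].append("<ti-trigger-0>")

    assert original == {"added_tokens": []}  # original behavior preserved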