Replace deepcopy with a pickle roundtrip in apply_ti(...) to improve speed.

This commit is contained in:
Ryan Dick 2023-11-02 15:49:44 -04:00 committed by Kent Keirsey
parent a4a7b601a1
commit bac2a757e8

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
import copy import pickle
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
@ -165,7 +165,13 @@ class ModelPatcher:
new_tokens_added = None new_tokens_added = None
try: try:
ti_tokenizer = copy.deepcopy(tokenizer) # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
# exiting this `apply_ti(...)` context manager.
#
# In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
ti_manager = TextualInversionManager(ti_tokenizer) ti_manager = TextualInversionManager(ti_tokenizer)
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
@ -439,7 +445,13 @@ class ONNXModelPatcher:
orig_embeddings = None orig_embeddings = None
try: try:
ti_tokenizer = copy.deepcopy(tokenizer) # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
# exiting this `apply_ti(...)` context manager.
#
# In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
ti_manager = TextualInversionManager(ti_tokenizer) ti_manager = TextualInversionManager(ti_tokenizer)
def _get_trigger(ti_name, index): def _get_trigger(ti_name, index):