mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Replace deepcopy with a pickle roundtrip in apply_ti(...) to improve speed.
This commit is contained in:
parent
a4a7b601a1
commit
bac2a757e8
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import pickle
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
@ -165,7 +165,13 @@ class ModelPatcher:
|
||||
new_tokens_added = None
|
||||
|
||||
try:
|
||||
ti_tokenizer = copy.deepcopy(tokenizer)
|
||||
# HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
|
||||
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
|
||||
# exiting this `apply_ti(...)` context manager.
|
||||
#
|
||||
# In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
|
||||
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
|
||||
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
|
||||
ti_manager = TextualInversionManager(ti_tokenizer)
|
||||
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
|
||||
|
||||
@ -439,7 +445,13 @@ class ONNXModelPatcher:
|
||||
orig_embeddings = None
|
||||
|
||||
try:
|
||||
ti_tokenizer = copy.deepcopy(tokenizer)
|
||||
# HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
|
||||
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
|
||||
# exiting this `apply_ti(...)` context manager.
|
||||
#
|
||||
# In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
|
||||
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
|
||||
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
|
||||
ti_manager = TextualInversionManager(ti_tokenizer)
|
||||
|
||||
def _get_trigger(ti_name, index):
|
||||
|
Loading…
Reference in New Issue
Block a user