# Copyright (c) 2024 Ryan Dick, Lincoln D. Stein, and the InvokeAI Development Team
"""These classes implement model patching with LoRAs and Textual Inversions."""

from __future__ import annotations

import pickle
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union

import numpy as np
import torch
from diffusers import OnnxRuntimeModel, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage

"""
loras = [
    (lora_model1, 0.7),
    (lora_model2, 0.4),
]
with ModelPatcher.apply_lora_unet(unet, loras):
    # unet with applied loras
# unmodified unet

"""


class ModelPatcher:
    @staticmethod
    @contextmanager
    def patch_unet_attention_processor(unet: UNet2DConditionModel, processor_cls: Type[Any]):
        """A context manager that patches `unet` with the provided attention processor.

        Args:
            unet (UNet2DConditionModel): The UNet model to patch.
            processor_cls (Type[Any]): Class that will be instantiated for each key and passed to
                set_attn_processor(...).
        """
        unet_orig_processors = unet.attn_processors

        # Create a separate processor instance for each attention block so that each one can be
        # modified independently.
        unet_new_processors = {key: processor_cls() for key in unet_orig_processors.keys()}
        try:
            unet.set_attn_processor(unet_new_processors)
            yield None

        finally:
            unet.set_attn_processor(unet_orig_processors)
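
    # Illustrative usage (a sketch; `MyAttnProcessor` is a hypothetical attention processor class):
    #
    #     with ModelPatcher.patch_unet_attention_processor(unet, MyAttnProcessor):
    #         noise_pred = unet(latents, timestep, encoder_hidden_states=embeds).sample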

    @staticmethod
    def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
        assert "." not in lora_key

        if not lora_key.startswith(prefix):
            raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")

        module = model
        module_key = ""
        key_parts = lora_key[len(prefix) :].split("_")

        submodule_name = key_parts.pop(0)

        while len(key_parts) > 0:
            try:
                module = module.get_submodule(submodule_name)
                module_key += "." + submodule_name
                submodule_name = key_parts.pop(0)
            except Exception:
                submodule_name += "_" + key_parts.pop(0)

        module = module.get_submodule(submodule_name)
        module_key = (module_key + "." + submodule_name).lstrip(".")

        return (module_key, module)
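
    # Illustrative walk-through (a sketch, not executed): with prefix "lora_unet_" and a key like
    # "lora_unet_down_blocks_0_attentions_0_proj_in", the remainder is split on "_" and fragments
    # are greedily re-joined until get_submodule(...) succeeds, resolving to the module at
    # "down_blocks.0.attentions.0.proj_in" on a typical Stable Diffusion UNet.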

    @classmethod
    @contextmanager
    def apply_lora_unet(
        cls,
        unet: UNet2DConditionModel,
        loras: Iterator[Tuple[LoRAModelRaw, float]],
        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Generator[None, None, None]:
        with cls.apply_lora(
            unet,
            loras=loras,
            prefix="lora_unet_",
            cached_weights=cached_weights,
        ):
            yield

    @classmethod
    @contextmanager
    def apply_lora_text_encoder(
        cls,
        text_encoder: CLIPTextModel,
        loras: Iterator[Tuple[LoRAModelRaw, float]],
        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Generator[None, None, None]:
        with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights):
            yield

    @classmethod
    @contextmanager
    def apply_lora(
        cls,
        model: AnyModel,
        loras: Iterator[Tuple[LoRAModelRaw, float]],
        prefix: str,
        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Generator[None, None, None]:
        """
        Apply one or more LoRAs to a model.

        :param model: The model to patch.
        :param loras: An iterator that yields (LoRA model, patch weight) tuples to apply.
        :param prefix: The string prefix that keys in the LoRA's weight layers must begin with.
        :param cached_weights: Read-only copy of the model's state dict on the CPU, used to restore the
            original weights when unpatching.
        """
        original_weights = OriginalWeightsStorage(cached_weights)
        try:
            for lora_model, lora_weight in loras:
                LoRAExt.patch_model(
                    model=model,
                    prefix=prefix,
                    lora=lora_model,
                    lora_weight=lora_weight,
                    original_weights=original_weights,
                )
                del lora_model

            yield

        finally:
            with torch.no_grad():
                for param_key, weight in original_weights.get_changed_weights():
                    model.get_parameter(param_key).copy_(weight)
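
    # Illustrative usage (a sketch; `text_encoder` and the LoRA models are assumed to be loaded
    # elsewhere):
    #
    #     loras = [(lora_a, 0.7), (lora_b, 0.4)]
    #     with ModelPatcher.apply_lora(text_encoder, loras=iter(loras), prefix="lora_te_"):
    #         ...  # the text encoder runs with both LoRAs patched in; weights are restored on exit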

    @classmethod
    @contextmanager
    def apply_ti(
        cls,
        tokenizer: CLIPTokenizer,
        text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
        ti_list: List[Tuple[str, TextualInversionModelRaw]],
    ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
        init_tokens_count = None
        new_tokens_added = None

        # TODO: This is required since Transformers 4.32; see
        # https://github.com/huggingface/transformers/pull/25088
        # More information by NVIDIA:
        # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
        # This value might need to be changed in the future and take the GPU model into account, as there
        # seem to be ideal values for different GPUs. This value is temporary!
        # For references to the current discussion please see https://github.com/invoke-ai/InvokeAI/pull/4817
        pad_to_multiple_of = 8
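
        # Arithmetic sketch (vocab size assumed): CLIP's 49408-token vocabulary is already a
        # multiple of 8, so adding e.g. 3 trigger tokens below resizes the embedding table to
        # 49416 rows (49408 + 3 = 49411, rounded up to the next multiple of 8).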

        try:
            # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
            # workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
            # exiting this `apply_ti(...)` context manager.
            #
            # In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
            # but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
            ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
            ti_manager = TextualInversionManager(ti_tokenizer)
            init_tokens_count = text_encoder.resize_token_embeddings(None, pad_to_multiple_of).num_embeddings

            def _get_trigger(ti_name: str, index: int) -> str:
                trigger = ti_name
                if index > 0:
                    trigger += f"-!pad-{index}"
                return f"<{trigger}>"
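
            # Naming sketch (hypothetical 3-vector TI named "night-style"): the triggers generated
            # above are "<night-style>", "<night-style-!pad-1>", and "<night-style-!pad-2>";
            # ti_manager.pad_tokens later links the pad tokens back to the first trigger token.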

            def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionModelRaw) -> torch.Tensor:
                # for SDXL models, select the embedding that matches the text encoder's dimensions
                if ti.embedding_2 is not None:
                    return (
                        ti.embedding_2
                        if ti.embedding_2.shape[1] == model_embeddings.weight.data[0].shape[0]
                        else ti.embedding
                    )
                else:
                    return ti.embedding
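
            # Dimension sketch (typical values, assumed rather than checked here): SDXL's first
            # text encoder uses 768-dim token embeddings and its second uses 1280-dim ones, so a
            # TI file carrying both embeddings is matched to whichever encoder is being patched.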

            # modify tokenizer
            new_tokens_added = 0
            for ti_name, ti in ti_list:
                ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti)

                for i in range(ti_embedding.shape[0]):
                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))

            # Modify text_encoder.
            # resize_token_embeddings(...) constructs a new torch.nn.Embedding internally. Initializing the weights of
            # this embedding is slow and unnecessary, so we wrap this step in skip_torch_weight_init() to save some
            # time.
            with skip_torch_weight_init():
                text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
            model_embeddings = text_encoder.get_input_embeddings()

            for ti_name, ti in ti_list:
                assert isinstance(ti, TextualInversionModelRaw)
                ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti)

                ti_tokens = []
                for i in range(ti_embedding.shape[0]):
                    embedding = ti_embedding[i]
                    trigger = _get_trigger(ti_name, i)

                    token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
                    if token_id == ti_tokenizer.unk_token_id:
                        raise RuntimeError(f"Unable to find token id for token '{trigger}'")

                    if model_embeddings.weight.data[token_id].shape != embedding.shape:
                        raise ValueError(
                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
                            f" {embedding.shape[0]}, but the current model has token dimension"
                            f" {model_embeddings.weight.data[token_id].shape[0]}."
                        )

                    model_embeddings.weight.data[token_id] = embedding.to(
                        device=text_encoder.device, dtype=text_encoder.dtype
                    )
                    ti_tokens.append(token_id)

                if len(ti_tokens) > 1:
                    ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]

            yield ti_tokenizer, ti_manager

        finally:
            if init_tokens_count and new_tokens_added:
                text_encoder.resize_token_embeddings(init_tokens_count, pad_to_multiple_of)

    @classmethod
    @contextmanager
    def apply_clip_skip(
        cls,
        text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
        clip_skip: int,
    ) -> Generator[None, None, None]:
        skipped_layers = []
        try:
            for _i in range(clip_skip):
                skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1))

            yield

        finally:
            while len(skipped_layers) > 0:
                text_encoder.text_model.encoder.layers.append(skipped_layers.pop())
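
    # Behavior sketch (not executed): with clip_skip=2 on a 12-layer CLIP text encoder, the last
    # two encoder layers are popped for the duration of the context, so prompt embeddings come from
    # the 10th layer's output; the layers are re-appended on exit, restoring the original depth.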

    @classmethod
    @contextmanager
    def apply_freeu(
        cls,
        unet: UNet2DConditionModel,
        freeu_config: Optional[FreeUConfig] = None,
    ) -> Generator[None, None, None]:
        did_apply_freeu = False
        try:
            assert hasattr(unet, "enable_freeu")  # mypy doesn't pick up this attribute?
            if freeu_config is not None:
                unet.enable_freeu(b1=freeu_config.b1, b2=freeu_config.b2, s1=freeu_config.s1, s2=freeu_config.s2)
                did_apply_freeu = True

            yield

        finally:
            assert hasattr(unet, "disable_freeu")  # mypy doesn't pick up this attribute?
            if did_apply_freeu:
                unet.disable_freeu()


class ONNXModelPatcher:
    @classmethod
    @contextmanager
    def apply_lora_unet(
        cls,
        unet: OnnxRuntimeModel,
        loras: Iterator[Tuple[LoRAModelRaw, float]],
    ) -> Generator[None, None, None]:
        with cls.apply_lora(unet, loras, "lora_unet_"):
            yield

    @classmethod
    @contextmanager
    def apply_lora_text_encoder(
        cls,
        text_encoder: OnnxRuntimeModel,
        loras: List[Tuple[LoRAModelRaw, float]],
    ) -> Generator[None, None, None]:
        with cls.apply_lora(text_encoder, loras, "lora_te_"):
            yield

    # based on
    # https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
    @classmethod
    @contextmanager
    def apply_lora(
        cls,
        model: IAIOnnxRuntimeModel,
        loras: List[Tuple[LoRAModelRaw, float]],
        prefix: str,
    ) -> Generator[None, None, None]:
        if not isinstance(model, IAIOnnxRuntimeModel):
            raise Exception("Only IAIOnnxRuntimeModel models are supported")

        orig_weights: Dict[str, np.ndarray] = {}

        try:
            blended_loras: Dict[str, np.ndarray] = {}

            for lora, lora_weight in loras:
                for layer_key, layer in lora.layers.items():
                    if not layer_key.startswith(prefix):
                        continue

                    layer.to(dtype=torch.float32)
                    layer_key = layer_key.replace(prefix, "")
                    # TODO: rewrite to pass the original tensor weight (required by IA3)
                    layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight
                    if layer_key in blended_loras:
                        blended_loras[layer_key] += layer_weight
                    else:
                        blended_loras[layer_key] = layer_weight
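
            # At this point, blended_loras maps each stripped layer key to the accumulated delta,
            # i.e. delta_W = sum_i(weight_i * delta_W_i) over all supplied LoRAs.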

            node_names = {}
            for node in model.nodes.values():
                node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name

            for layer_key, lora_weight in blended_loras.items():
                conv_key = layer_key + "_Conv"
                gemm_key = layer_key + "_Gemm"
                matmul_key = layer_key + "_MatMul"

                if conv_key in node_names or gemm_key in node_names:
                    if conv_key in node_names:
                        conv_node = model.nodes[node_names[conv_key]]
                    else:
                        conv_node = model.nodes[node_names[gemm_key]]

                    weight_name = [n for n in conv_node.input if ".weight" in n][0]
                    orig_weight = model.tensors[weight_name]
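
                    # A 1x1 convolution is equivalent to a matrix multiply, so an (out, in, 1, 1)
                    # kernel is squeezed to 2-D before adding the LoRA delta and re-expanded after.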

                    if orig_weight.shape[-2:] == (1, 1):
                        if lora_weight.shape[-2:] == (1, 1):
                            new_weight = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2))
                        else:
                            new_weight = orig_weight.squeeze((3, 2)) + lora_weight

                        new_weight = np.expand_dims(new_weight, (2, 3))
                    else:
                        if orig_weight.shape != lora_weight.shape:
                            new_weight = orig_weight + lora_weight.reshape(orig_weight.shape)
                        else:
                            new_weight = orig_weight + lora_weight

                    orig_weights[weight_name] = orig_weight
                    model.tensors[weight_name] = new_weight.astype(orig_weight.dtype)

                elif matmul_key in node_names:
                    weight_node = model.nodes[node_names[matmul_key]]
                    matmul_name = [n for n in weight_node.input if "MatMul" in n][0]

                    orig_weight = model.tensors[matmul_name]
                    new_weight = orig_weight + lora_weight.transpose()

                    orig_weights[matmul_name] = orig_weight
                    model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype)

                else:
                    # TODO: Should this warn or raise instead? For now, layer keys with no matching
                    # Conv/Gemm/MatMul node are silently skipped.
                    pass

            yield

        finally:
            # restore original weights
            for name, orig_weight in orig_weights.items():
                model.tensors[name] = orig_weight

    @classmethod
    @contextmanager
    def apply_ti(
        cls,
        tokenizer: CLIPTokenizer,
        text_encoder: IAIOnnxRuntimeModel,
        ti_list: List[Tuple[str, Any]],
    ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
        if not isinstance(text_encoder, IAIOnnxRuntimeModel):
            raise Exception("Only IAIOnnxRuntimeModel models are supported")

        orig_embeddings = None

        try:
            # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
            # workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
            # exiting this `apply_ti(...)` context manager.
            #
            # In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
            # but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
            ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
            ti_manager = TextualInversionManager(ti_tokenizer)

            def _get_trigger(ti_name: str, index: int) -> str:
                trigger = ti_name
                if index > 0:
                    trigger += f"-!pad-{index}"
                return f"<{trigger}>"

            # modify text_encoder
            orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]

            # modify tokenizer
            new_tokens_added = 0
            for ti_name, ti in ti_list:
                if ti.embedding_2 is not None:
                    ti_embedding = (
                        ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[1] else ti.embedding
                    )
                else:
                    ti_embedding = ti.embedding

                for i in range(ti_embedding.shape[0]):
                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
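
            # Grow the embedding table by new_tokens_added zero-initialized rows; these rows are
            # filled with the actual TI vectors in the loop below.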

            embeddings = np.concatenate(
                (np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))),
                axis=0,
            )

            for ti_name, ti in ti_list:
                # Re-select the embedding for this TI; reusing the value left over from the loop
                # above would apply the last TI's embedding to every entry.
                if ti.embedding_2 is not None:
                    ti_embedding = (
                        ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[1] else ti.embedding
                    )
                else:
                    ti_embedding = ti.embedding

                ti_tokens = []
                for i in range(ti_embedding.shape[0]):
                    embedding = ti_embedding[i].detach().numpy()
                    trigger = _get_trigger(ti_name, i)

                    token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
                    if token_id == ti_tokenizer.unk_token_id:
                        raise RuntimeError(f"Unable to find token id for token '{trigger}'")

                    if embeddings[token_id].shape != embedding.shape:
                        raise ValueError(
                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
                            f" {embedding.shape[0]}, but the current model has token dimension"
                            f" {embeddings[token_id].shape[0]}."
                        )

                    embeddings[token_id] = embedding
                    ti_tokens.append(token_id)

                if len(ti_tokens) > 1:
                    ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]

            text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype(
                orig_embeddings.dtype
            )

            yield ti_tokenizer, ti_manager

        finally:
            # restore
            if orig_embeddings is not None:
                text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings