2023-05-29 22:11:00 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
2023-05-30 23:12:27 +00:00
|
|
|
import copy
|
2023-05-29 22:11:00 +00:00
|
|
|
from contextlib import contextmanager
|
2023-07-05 02:37:16 +00:00
|
|
|
from pathlib import Path
|
2023-08-18 15:18:46 +00:00
|
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
2023-05-29 22:11:00 +00:00
|
|
|
|
2023-06-20 23:12:21 +00:00
|
|
|
import numpy as np
|
2023-08-17 22:45:25 +00:00
|
|
|
import torch
|
2023-05-30 23:12:27 +00:00
|
|
|
from compel.embeddings_provider import BaseTextualInversionManager
|
2023-07-05 02:37:16 +00:00
|
|
|
from diffusers.models import UNet2DConditionModel
|
2023-05-29 22:11:00 +00:00
|
|
|
from safetensors.torch import load_file
|
2023-07-05 20:40:47 +00:00
|
|
|
from transformers import CLIPTextModel, CLIPTokenizer
|
2023-05-30 23:12:27 +00:00
|
|
|
|
2023-08-17 22:45:25 +00:00
|
|
|
from .models.lora import LoRAModel
|
|
|
|
|
2023-05-29 22:11:00 +00:00
|
|
|
"""
|
|
|
|
loras = [
|
|
|
|
(lora_model1, 0.7),
|
|
|
|
(lora_model2, 0.4),
|
|
|
|
]
|
|
|
|
with ModelPatcher.apply_lora_unet(unet, loras):
|
|
|
|
# unet with applied loras
|
|
|
|
# unmodified unet
|
|
|
|
|
|
|
|
"""
|
2023-07-28 13:46:44 +00:00
|
|
|
|
|
|
|
|
2023-05-29 22:11:00 +00:00
|
|
|
# Context managers that temporarily patch torch models (LoRA weights, textual inversion tokens, CLIP skip) and undo the patch on exit.
|
2023-05-30 23:12:27 +00:00
|
|
|
class ModelPatcher:
    """Context managers that temporarily modify torch models and restore them on exit.

    Supported patches: merging LoRA weights into a module's ``weight`` tensors,
    adding textual-inversion tokens/embeddings to a tokenizer copy and CLIP text
    encoder, and removing the last CLIP encoder layers ("clip skip").
    """

    @staticmethod
    def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
        """Map a flattened LoRA layer key to the submodule of ``model`` it targets.

        LoRA keys spell the module path with ``_`` instead of ``.``. Real submodule
        names may themselves contain underscores (e.g. ``down_blocks``), so the path
        is resolved greedily: key parts are glued together with ``_`` until
        ``get_submodule`` finds a matching child.

        Args:
            model: Root module to search.
            lora_key: LoRA layer key; must not contain ``.``.
            prefix: Prefix the key must start with (e.g. ``"lora_unet_"``); it is
                stripped before resolving.

        Returns:
            ``(module_key, module)``, where ``module_key`` is the dotted path that
            ``model.get_submodule`` accepts.

        Raises:
            Exception: If ``lora_key`` does not start with ``prefix``.
            AttributeError: If the final path component cannot be resolved.
        """
        assert "." not in lora_key

        if not lora_key.startswith(prefix):
            raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")

        module = model
        module_key = ""
        key_parts = lora_key[len(prefix) :].split("_")

        submodule_name = key_parts.pop(0)

        while len(key_parts) > 0:
            try:
                module = module.get_submodule(submodule_name)
                module_key += "." + submodule_name
                submodule_name = key_parts.pop(0)
            except AttributeError:
                # No child with this name: the real name contains an underscore,
                # so append the next key part and retry.
                submodule_name += "_" + key_parts.pop(0)

        module = module.get_submodule(submodule_name)
        module_key = (module_key + "." + submodule_name).lstrip(".")

        return (module_key, module)

    @staticmethod
    def _lora_forward_hook(
        applied_loras: List[Tuple[LoRAModel, float]],
        layer_name: str,
    ):
        """Build a forward hook that adds LoRA contributions to a module's output.

        The returned callable takes ``(module, input_h, output)``. For every
        ``(lora, weight)`` in ``applied_loras`` whose ``layers`` contain
        ``layer_name``, ``layer.forward(module, input_h, weight)`` is added to
        ``output`` in place.

        ``applied_loras`` is captured by reference, so later changes to that list
        are seen by the hook.
        """

        def lora_forward(module, input_h, output):
            if len(applied_loras) == 0:
                return output

            for lora, weight in applied_loras:
                layer = lora.layers.get(layer_name, None)
                if layer is None:
                    continue
                output += layer.forward(module, input_h, weight)
            return output

        return lora_forward

    @classmethod
    @contextmanager
    def apply_lora_unet(
        cls,
        unet: UNet2DConditionModel,
        loras: List[Tuple[LoRAModel, float]],
    ):
        """Apply the ``lora_unet_``-prefixed layers of ``loras`` to ``unet`` for the context."""
        with cls.apply_lora(unet, loras, "lora_unet_"):
            yield

    @classmethod
    @contextmanager
    def apply_lora_text_encoder(
        cls,
        text_encoder: CLIPTextModel,
        loras: List[Tuple[LoRAModel, float]],
    ):
        """Apply the ``lora_te_``-prefixed layers of ``loras`` to ``text_encoder`` for the context."""
        with cls.apply_lora(text_encoder, loras, "lora_te_"):
            yield

    @classmethod
    @contextmanager
    def apply_sdxl_lora_text_encoder(
        cls,
        text_encoder: CLIPTextModel,
        loras: List[Tuple[LoRAModel, float]],
    ):
        """Apply the ``lora_te1_``-prefixed layers of ``loras`` to ``text_encoder`` for the context."""
        with cls.apply_lora(text_encoder, loras, "lora_te1_"):
            yield

    @classmethod
    @contextmanager
    def apply_sdxl_lora_text_encoder2(
        cls,
        text_encoder: CLIPTextModel,
        loras: List[Tuple[LoRAModel, float]],
    ):
        """Apply the ``lora_te2_``-prefixed layers of ``loras`` to ``text_encoder`` for the context."""
        with cls.apply_lora(text_encoder, loras, "lora_te2_"):
            yield

    @classmethod
    @contextmanager
    def apply_lora(
        cls,
        model: torch.nn.Module,
        loras: List[Tuple[LoRAModel, float]],
        prefix: str,
    ):
        """Merge LoRA deltas into ``model``'s weights for the duration of the context.

        For each LoRA layer whose key starts with ``prefix``, the target module is
        found with :meth:`_resolve_lora_key`. Its ``weight`` is then incremented in
        place by ``layer.get_weight(original) * lora_weight * scale``, where
        ``scale`` is ``alpha / rank`` when both are set and ``1.0`` otherwise.

        A CPU copy of every touched weight is taken before the first change. The
        copies are written back on exit, including when the ``with`` body raises.

        Args:
            model: Module patched in place.
            loras: ``(lora_model, weight)`` pairs, applied in order.
            prefix: Layer-key prefix selecting which LoRA layers target ``model``.
        """
        original_weights = dict()
        try:
            with torch.no_grad():
                for lora, lora_weight in loras:
                    # assert lora.device.type == "cpu"
                    for layer_key, layer in lora.layers.items():
                        if not layer_key.startswith(prefix):
                            continue

                        module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
                        if module_key not in original_weights:
                            # Snapshot the weight once, before any LoRA modifies it.
                            original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)

                        # enable autocast to calc fp16 loras on cpu
                        # with torch.autocast(device_type="cpu"):
                        # Cast the LoRA layer to float32 (called for its side effect;
                        # the return value is not used).
                        layer.to(dtype=torch.float32)
                        layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
                        layer_weight = layer.get_weight(original_weights[module_key]) * lora_weight * layer_scale

                        if module.weight.shape != layer_weight.shape:
                            # TODO: debug on lycoris
                            layer_weight = layer_weight.reshape(module.weight.shape)

                        module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype)

            yield  # wait for context manager exit

        finally:
            with torch.no_grad():
                for module_key, weight in original_weights.items():
                    model.get_submodule(module_key).weight.copy_(weight)

    @staticmethod
    def _get_trigger(ti_name: str, index: int) -> str:
        """Return the tokenizer token for vector ``index`` of embedding ``ti_name``.

        Vector 0 maps to ``<ti_name>``. Later vectors map to padding tokens named
        ``<ti_name-!pad-{index}>``.
        """
        trigger = ti_name
        if index > 0:
            trigger += f"-!pad-{index}"
        return f"<{trigger}>"

    @classmethod
    @contextmanager
    def apply_ti(
        cls,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModel,
        ti_list: List[Tuple[str, Any]],
    ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
        """Temporarily register textual-inversion embeddings with a tokenizer and encoder.

        For each ``(ti_name, ti)`` in ``ti_list``, one token per row of
        ``ti.embedding`` is added to a deep copy of ``tokenizer``. The tokens are
        ``<ti_name>``, then ``<ti_name-!pad-1>``, and so on. ``text_encoder``'s
        token embedding matrix is grown to match, and each row is written into it.
        On exit, if any tokens were added, the embedding matrix is resized back to
        its original size. The caller's ``tokenizer`` is never modified.

        Yields:
            ``(ti_tokenizer, ti_manager)``: the extended tokenizer copy, and a
            manager whose ``pad_tokens`` maps each multi-vector embedding's first
            token id to the ids of its padding tokens.

        Raises:
            RuntimeError: If an added trigger token cannot be resolved to an id.
            ValueError: If an embedding's dimension differs from the encoder's.
        """
        init_tokens_count = None
        new_tokens_added = None

        try:
            ti_tokenizer = copy.deepcopy(tokenizer)
            ti_manager = TextualInversionManager(ti_tokenizer)
            init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings

            # modify tokenizer
            new_tokens_added = 0
            for ti_name, ti in ti_list:
                for i in range(ti.embedding.shape[0]):
                    new_tokens_added += ti_tokenizer.add_tokens(cls._get_trigger(ti_name, i))

            # modify text_encoder
            text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
            model_embeddings = text_encoder.get_input_embeddings()

            for ti_name, ti in ti_list:
                ti_tokens = []
                for i in range(ti.embedding.shape[0]):
                    embedding = ti.embedding[i]
                    trigger = cls._get_trigger(ti_name, i)

                    token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
                    if token_id == ti_tokenizer.unk_token_id:
                        raise RuntimeError(f"Unable to find token id for token '{trigger}'")

                    if model_embeddings.weight.data[token_id].shape != embedding.shape:
                        raise ValueError(
                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
                            f" {embedding.shape[0]}, but the current model has token dimension"
                            f" {model_embeddings.weight.data[token_id].shape[0]}."
                        )

                    model_embeddings.weight.data[token_id] = embedding.to(
                        device=text_encoder.device, dtype=text_encoder.dtype
                    )
                    ti_tokens.append(token_id)

                if len(ti_tokens) > 1:
                    # First token -> padding tokens, used to expand prompts to all vectors.
                    ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]

            yield ti_tokenizer, ti_manager

        finally:
            if init_tokens_count and new_tokens_added:
                text_encoder.resize_token_embeddings(init_tokens_count)

    @classmethod
    @contextmanager
    def apply_clip_skip(
        cls,
        text_encoder: CLIPTextModel,
        clip_skip: int,
    ):
        """Remove the last ``clip_skip`` encoder layers for the duration of the context.

        The layers are popped from ``text_encoder.text_model.encoder.layers``. On
        exit they are appended back in their original order, including when the
        ``with`` body raises.
        """
        skipped_layers = []
        try:
            for _ in range(clip_skip):
                skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1))

            yield

        finally:
            while len(skipped_layers) > 0:
                text_encoder.text_model.encoder.layers.append(skipped_layers.pop())
|
|
|
|
|
2023-07-28 13:46:44 +00:00
|
|
|
|
2023-05-30 23:12:27 +00:00
|
|
|
class TextualInversionModel:
    """A textual-inversion embedding loaded from a checkpoint file."""

    embedding: torch.Tensor  # [n, 768]|[n, 1280]

    @classmethod
    def from_checkpoint(
        cls,
        file_path: Union[str, Path],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Load an embedding from a ``.safetensors`` file or a torch-pickled file.

        Recognized layouts:
          * ``{"string_to_param": {...}}`` (v1/v2): the first tensor is used.
          * ``{"emb_params": tensor}`` (v3).
          * any other mapping (v4, diffusers ``.bin``): its first value is used.

        Args:
            file_path: Path to the embedding file.
            device: Currently unused; tensors are always loaded on CPU.
            dtype: Currently unused; the stored dtype is kept.

        Returns:
            A ``TextualInversionModel`` whose ``embedding`` is at least 2-D. A 1-D
            vector is unsqueezed to shape ``(1, dim)``.

        Raises:
            ValueError: If no tensor embedding can be found in the file.
        """
        if not isinstance(file_path, Path):
            file_path = Path(file_path)

        result = cls()  # TODO:

        if file_path.suffix == ".safetensors":
            state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
        else:
            # NOTE(security): torch.load may unpickle arbitrary objects; only load
            # embedding files from trusted sources.
            state_dict = torch.load(file_path, map_location="cpu")

        # both v1 and v2 format embeddings
        # difference mostly in metadata
        if "string_to_param" in state_dict:
            if len(state_dict["string_to_param"]) > 1:
                print(
                    f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first'
                    " token will be used."
                )

            embedding = next(iter(state_dict["string_to_param"].values()), None)

        # v3 (easynegative)
        elif "emb_params" in state_dict:
            embedding = state_dict["emb_params"]

        # v4(diffusers bin files)
        else:
            embedding = next(iter(state_dict.values()), None)

        # Validate before touching tensor attributes, so malformed or empty files
        # raise a clear ValueError instead of AttributeError/StopIteration.
        if not isinstance(embedding, torch.Tensor):
            raise ValueError(f"Invalid embeddings file: {file_path.name}")

        if len(embedding.shape) == 1:
            embedding = embedding.unsqueeze(0)

        result.embedding = embedding
        return result
|
|
|
|
|
|
|
|
|
|
|
|
class TextualInversionManager(BaseTextualInversionManager):
    """Expands multi-vector textual-inversion tokens into their padding tokens.

    ``pad_tokens`` maps the token id of a multi-vector embedding's first token to
    the list of ids of its padding tokens.
    """

    # first-token id -> padding-token ids
    pad_tokens: Dict[int, List[int]]
    # tokenizer whose BOS/EOS ids are used to validate input
    tokenizer: CLIPTokenizer

    def __init__(self, tokenizer: CLIPTokenizer):
        self.pad_tokens = dict()
        self.tokenizer = tokenizer

    def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
        """Insert padding token ids after every registered multi-vector token.

        Args:
            token_ids: Token ids without the BOS/EOS wrapper tokens.

        Returns:
            ``token_ids`` itself when no padding is registered or it is empty.
            Otherwise a new list in which each id found in ``pad_tokens`` is
            followed by its padding ids.

        Raises:
            ValueError: If ``token_ids`` starts with BOS or ends with EOS.
        """
        # Empty input has nothing to expand (and would fail the BOS/EOS indexing below).
        if len(self.pad_tokens) == 0 or len(token_ids) == 0:
            return token_ids

        if token_ids[0] == self.tokenizer.bos_token_id:
            raise ValueError("token_ids must not start with bos_token_id")
        if token_ids[-1] == self.tokenizer.eos_token_id:
            raise ValueError("token_ids must not end with eos_token_id")

        new_token_ids = []
        for token_id in token_ids:
            new_token_ids.append(token_id)
            if token_id in self.pad_tokens:
                new_token_ids.extend(self.pad_tokens[token_id])

        return new_token_ids
|
|
|
|
|
2023-06-20 23:12:21 +00:00
|
|
|
|
|
|
|
class ONNXModelPatcher:
    """Context managers that temporarily patch ``IAIOnnxRuntimeModel`` tensors.

    LoRA deltas and textual-inversion embeddings are written into
    ``model.tensors``. The original arrays are put back on exit.
    """

    from diffusers import OnnxRuntimeModel

    from .models.base import IAIOnnxRuntimeModel

    @classmethod
    @contextmanager
    def apply_lora_unet(
        cls,
        unet: OnnxRuntimeModel,
        loras: List[Tuple[LoRAModel, float]],
    ):
        """Apply the ``lora_unet_``-prefixed layers of ``loras`` to an ONNX UNet for the context."""
        with cls.apply_lora(unet, loras, "lora_unet_"):
            yield

    @classmethod
    @contextmanager
    def apply_lora_text_encoder(
        cls,
        text_encoder: OnnxRuntimeModel,
        loras: List[Tuple[LoRAModel, float]],
    ):
        """Apply the ``lora_te_``-prefixed layers of ``loras`` to an ONNX text encoder for the context."""
        with cls.apply_lora(text_encoder, loras, "lora_te_"):
            yield

    # based on
    # https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
    @classmethod
    @contextmanager
    def apply_lora(
        cls,
        model: IAIOnnxRuntimeModel,
        loras: List[Tuple[LoRAModel, float]],
        prefix: str,
    ):
        """Merge LoRA deltas into the weight tensors of an ONNX model for the context.

        Every LoRA layer whose key starts with ``prefix`` is scaled by its
        ``lora_weight``, and also by ``alpha / rank`` when both are set. Deltas
        for the same layer key are summed across LoRAs. Each summed delta is added
        to the weight tensor feeding the matching ``<key>_Conv``/``<key>_Gemm``
        node, or to the ``<key>_MatMul`` node's MatMul input. Keys with no
        matching node are skipped. The original tensors are restored on exit.

        Raises:
            Exception: If ``model`` is not an ``IAIOnnxRuntimeModel``.
        """
        from .models.base import IAIOnnxRuntimeModel

        if not isinstance(model, IAIOnnxRuntimeModel):
            raise Exception("Only IAIOnnxRuntimeModel models supported")

        orig_weights = dict()

        try:
            blended_loras = dict()

            for lora, lora_weight in loras:
                for layer_key, layer in lora.layers.items():
                    if not layer_key.startswith(prefix):
                        continue

                    layer.to(dtype=torch.float32)
                    # Strip only the leading prefix (str.replace would strip every occurrence).
                    layer_key = layer_key[len(prefix) :]
                    # Same alpha/rank scaling as the torch LoRA path.
                    layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
                    # TODO: rewrite to pass original tensor weight(required by ia3)
                    layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight * layer_scale
                    # Sum contributions of several LoRAs targeting the same layer.
                    if layer_key in blended_loras:
                        blended_loras[layer_key] += layer_weight
                    else:
                        blended_loras[layer_key] = layer_weight

            # Map sanitized node names ("/" and "." replaced by "_") back to real node names.
            node_names = dict()
            for node in model.nodes.values():
                node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name

            for layer_key, lora_weight in blended_loras.items():
                conv_key = layer_key + "_Conv"
                gemm_key = layer_key + "_Gemm"
                matmul_key = layer_key + "_MatMul"

                if conv_key in node_names or gemm_key in node_names:
                    if conv_key in node_names:
                        conv_node = model.nodes[node_names[conv_key]]
                    else:
                        conv_node = model.nodes[node_names[gemm_key]]

                    weight_name = [n for n in conv_node.input if ".weight" in n][0]
                    orig_weight = model.tensors[weight_name]

                    if orig_weight.shape[-2:] == (1, 1):
                        # 1x1 kernel: add in 2-D, then restore the trailing kernel dims.
                        if lora_weight.shape[-2:] == (1, 1):
                            new_weight = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2))
                        else:
                            new_weight = orig_weight.squeeze((3, 2)) + lora_weight

                        new_weight = np.expand_dims(new_weight, (2, 3))
                    else:
                        if orig_weight.shape != lora_weight.shape:
                            new_weight = orig_weight + lora_weight.reshape(orig_weight.shape)
                        else:
                            new_weight = orig_weight + lora_weight

                    # Keep the first (unpatched) tensor if this name is hit more than once.
                    orig_weights.setdefault(weight_name, orig_weight)
                    model.tensors[weight_name] = new_weight.astype(orig_weight.dtype)

                elif matmul_key in node_names:
                    weight_node = model.nodes[node_names[matmul_key]]
                    matmul_name = [n for n in weight_node.input if "MatMul" in n][0]

                    orig_weight = model.tensors[matmul_name]
                    # The delta is transposed before adding; presumably the MatMul
                    # weight is stored as [in, out] -- TODO confirm.
                    new_weight = orig_weight + lora_weight.transpose()

                    orig_weights.setdefault(matmul_name, orig_weight)
                    model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype)

                else:
                    # warn? err?
                    pass

            yield

        finally:
            # restore original weights
            for name, orig_weight in orig_weights.items():
                model.tensors[name] = orig_weight

    @classmethod
    @contextmanager
    def apply_ti(
        cls,
        tokenizer: CLIPTokenizer,
        text_encoder: IAIOnnxRuntimeModel,
        ti_list: List[Tuple[str, Any]],
    ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
        """Temporarily register textual-inversion embeddings with an ONNX text encoder.

        Adds one token per row of each ``ti.embedding`` to a deep copy of
        ``tokenizer``. It then replaces
        ``text_encoder.tensors["text_model.embeddings.token_embedding.weight"]``
        with a copy extended by those rows. The original array is restored on exit.

        Yields:
            ``(ti_tokenizer, ti_manager)``: the extended tokenizer copy and a
            manager whose ``pad_tokens`` maps each multi-vector embedding's first
            token id to its padding token ids.

        Raises:
            Exception: If ``text_encoder`` is not an ``IAIOnnxRuntimeModel``.
            RuntimeError: If an added trigger token cannot be resolved to an id.
            ValueError: If an embedding's dimension differs from the encoder's.
        """
        from .models.base import IAIOnnxRuntimeModel

        if not isinstance(text_encoder, IAIOnnxRuntimeModel):
            raise Exception("Only IAIOnnxRuntimeModel models supported")

        orig_embeddings = None

        try:
            ti_tokenizer = copy.deepcopy(tokenizer)
            ti_manager = TextualInversionManager(ti_tokenizer)

            def _get_trigger(ti_name, index):
                # Vector 0 -> "<name>", later vectors -> "<name-!pad-{index}>".
                trigger = ti_name
                if index > 0:
                    trigger += f"-!pad-{index}"
                return f"<{trigger}>"

            # modify tokenizer
            new_tokens_added = 0
            for ti_name, ti in ti_list:
                for i in range(ti.embedding.shape[0]):
                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))

            # modify text_encoder
            orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]

            embeddings = np.concatenate(
                (np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))),
                axis=0,
            )

            for ti_name, ti in ti_list:
                ti_tokens = []
                for i in range(ti.embedding.shape[0]):
                    # .cpu() so embeddings that live on another device can be converted.
                    embedding = ti.embedding[i].detach().cpu().numpy()
                    trigger = _get_trigger(ti_name, i)

                    token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
                    if token_id == ti_tokenizer.unk_token_id:
                        raise RuntimeError(f"Unable to find token id for token '{trigger}'")

                    if embeddings[token_id].shape != embedding.shape:
                        raise ValueError(
                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
                            f" {embedding.shape[0]}, but the current model has token dimension"
                            f" {embeddings[token_id].shape[0]}."
                        )

                    embeddings[token_id] = embedding
                    ti_tokens.append(token_id)

                if len(ti_tokens) > 1:
                    ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]

            text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype(
                orig_embeddings.dtype
            )

            yield ti_tokenizer, ti_manager

        finally:
            # restore
            if orig_embeddings is not None:
                text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings
|