InvokeAI/scripts/merge_embeddings.py

125 lines
3.9 KiB
Python
Raw Normal View History

from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder
2022-08-23 22:26:28 +00:00
from ldm.modules.embedding_manager import EmbeddingManager
from ldm.modules.globals import Globals
2022-08-23 22:26:28 +00:00
import argparse, os
from functools import partial
import torch
def get_placeholder_loop(placeholder_string, embedder, use_bert):
    """Interactively ask for a replacement placeholder until it is a single token.

    Args:
        placeholder_string: The placeholder that collided; only used in the
            first prompt shown to the user.
        embedder: Object exposing ``tknz_fn`` (read when ``use_bert`` is
            truthy) or ``tokenizer`` (read otherwise).
        use_bert: Selects the BERT token lookup instead of the CLIP one.

    Returns:
        A ``(replacement_string, token)`` tuple, where ``token`` is whatever
        the selected lookup helper returned for the accepted string.
    """
    prompt = f"Placeholder string {placeholder_string} was already used. Please enter a replacement string: "
    while True:
        candidate = input(prompt)

        if use_bert:
            token = get_bert_token_for_string(embedder.tknz_fn, candidate)
        else:
            token = get_clip_token_for_string(embedder.tokenizer, candidate)

        if token is not None:
            return candidate, token

        # The lookup rejected the candidate; every later prompt names it.
        prompt = f"Placeholder string '{candidate}' maps to more than a single token. Please enter another string: "
def get_clip_token_for_string(tokenizer, string, max_length=77, end_token_id=49407):
    """Return the token id for ``string`` if it encodes to exactly one token.

    Args:
        tokenizer: Callable invoked with ``string`` and padding/truncation
            keyword arguments; its result must provide an ``"input_ids"``
            tensor indexable as ``[0, 1]``.
        string: Candidate placeholder text.
        max_length: Length the encoding is truncated/padded to. Defaults to
            77, the value previously hard-coded here.
        end_token_id: Id that fills every position after the real tokens.
            Defaults to 49407, the value previously hard-coded here
            (presumably CLIP's end-of-text id, which also serves as padding;
            confirm against the tokenizer in use).

    Returns:
        ``input_ids[0, 1]`` (a zero-dimensional tensor) when exactly two
        positions differ from ``end_token_id``, otherwise ``None``.
    """
    batch_encoding = tokenizer(
        string,
        truncation=True,
        max_length=max_length,
        return_length=True,
        return_overflowing_tokens=False,
        padding="max_length",
        return_tensors="pt",
    )
    tokens = batch_encoding["input_ids"]

    # Exactly two ids that are not the filler id means one leading marker
    # (presumably start-of-text at index 0) plus a single real token at
    # index 1. Any other count means zero or several real tokens.
    if torch.count_nonzero(tokens - end_token_id) != 2:
        return None
    return tokens[0, 1]
def get_bert_token_for_string(tokenizer, string):
    """Return the token id for ``string`` if it encodes to exactly one token.

    ``tokenizer`` is called with ``string`` alone and must return a tensor
    indexable as ``[0, 1]``. The string is accepted only when the encoding
    has exactly three non-zero entries (presumably two special markers
    around one real token, with zero padding; confirm against the
    tokenizer in use).

    Returns:
        ``encoded[0, 1]`` when accepted, otherwise ``None``.
    """
    encoded = tokenizer(string)
    is_single_token = torch.count_nonzero(encoded) == 3
    return encoded[0, 1] if is_single_token else None
2022-08-23 22:26:28 +00:00
if __name__ == "__main__":
    # Command-line entry point: load several embedding-manager checkpoints,
    # merge their placeholder -> token / parameter tables into one manager,
    # and save the result. Placeholder collisions are resolved interactively.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--root_dir",
        type=str,
        default='.',
        help="Path to the InvokeAI install directory containing 'models', 'outputs' and 'configs'."
    )
    parser.add_argument(
        "--manager_ckpts",
        type=str,
        nargs="+",
        required=True,
        help="Paths to a set of embedding managers to be merged."
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Output path for the merged manager",
    )
    parser.add_argument(
        "-sd", "--use_bert",
        action="store_true",
        help="Flag to denote that we are not merging stable diffusion embeddings"
    )
    args = parser.parse_args()
    Globals.root = args.root_dir

    # NOTE(review): both embedders are moved to CUDA unconditionally, so the
    # script needs a CUDA device as written.
    if args.use_bert:
        embedder = BERTEmbedder(n_embed=1280, n_layer=32).cuda()
    else:
        embedder = FrozenCLIPEmbedder().cuda()

    # Factory for managers bound to this embedder. It gets its own name so
    # the imported EmbeddingManager class is not shadowed.
    make_manager = partial(EmbeddingManager, embedder, ["*"])

    string_to_token_dict = {}
    string_to_param_dict = torch.nn.ParameterDict()
    placeholder_to_src = {}

    for manager_ckpt in args.manager_ckpts:
        print(f"Parsing {manager_ckpt}...")
        manager = make_manager()
        manager.load(manager_ckpt)

        for placeholder_string in manager.string_to_token_dict:
            if placeholder_string not in string_to_token_dict:
                string_to_token_dict[placeholder_string] = manager.string_to_token_dict[placeholder_string]
                string_to_param_dict[placeholder_string] = manager.string_to_param_dict[placeholder_string]
                placeholder_to_src[placeholder_string] = manager_ckpt
                continue

            # Collision: ask the user for a replacement placeholder.
            new_placeholder, new_token = get_placeholder_loop(placeholder_string, embedder, use_bert=args.use_bert)
            # The replacement may itself already be in the merged table.
            # Keep asking, otherwise the assignment below would silently
            # overwrite an embedding merged earlier.
            while new_placeholder in string_to_token_dict:
                new_placeholder, new_token = get_placeholder_loop(new_placeholder, embedder, use_bert=args.use_bert)

            string_to_token_dict[new_placeholder] = new_token
            string_to_param_dict[new_placeholder] = manager.string_to_param_dict[placeholder_string]
            placeholder_to_src[new_placeholder] = manager_ckpt

    print("Saving combined manager...")
    merged_manager = make_manager()
    merged_manager.string_to_param_dict = string_to_param_dict
    merged_manager.string_to_token_dict = string_to_token_dict
    merged_manager.save(args.output_path)

    print("Managers merged. Final list of placeholders: ")
    print(placeholder_to_src)