Fix merging embeddings (#226)

Fixed merging embeddings based on the changes made in textual inversion. Tested and working. Inverted their logic to prioritize the Stable Diffusion (CLIP) implementation over the BERT alternative, but left the option to use BERT via the new --use_bert flag.
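
For reference, a condensed sketch of the selection logic this change introduces. The class names and constructor arguments are taken from the diff below; build_embedder is only an illustrative helper name and is not part of the patched script.

from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder

def build_embedder(use_bert: bool):
    if use_bert:
        # Opt-in path for the original latent-diffusion (BERT) embeddings.
        return BERTEmbedder(n_embed=1280, n_layer=32).cuda()
    # Stable Diffusion (CLIP) is now the default.
    return FrozenCLIPEmbedder().cuda()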
This commit is contained in:
David Ford 2022-08-31 10:24:11 -05:00 committed by GitHub
parent b622819051
commit 58e3562652
2 changed files with 48 additions and 15 deletions


@@ -52,7 +52,7 @@ model:
         ddconfig:
           double_z: true
           z_channels: 4
-          resolution: 512
+          resolution: 256
           in_channels: 3
           out_ch: 3
           ch: 128
@@ -74,7 +74,7 @@ data:
   target: main.DataModuleFromConfig
   params:
     batch_size: 1
-    num_workers: 16
+    num_workers: 2
     wrap: false
     train:
       target: ldm.data.personalized.PersonalizedBase
@@ -105,4 +105,5 @@ lightning:
   trainer:
     benchmark: True
-    max_steps: 6100
+    max_steps: 4000
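
For anyone double-checking the new values locally, a minimal read-back sketch using OmegaConf (which the training entry point already uses); the config path is an assumption, and the attribute paths mirror the data and lightning hunks above.

from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/stable-diffusion/v1-finetune.yaml")  # assumed path
print(cfg.data.params.num_workers)      # expected: 2 after this change
print(cfg.lightning.trainer.max_steps)  # expected: 4000 after this change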


@@ -1,4 +1,4 @@
-from ldm.modules.encoders.modules import BERTTokenizer
+from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder
 from ldm.modules.embedding_manager import EmbeddingManager
 
 import argparse, os
@@ -6,7 +6,7 @@ from functools import partial
 import torch
 
-def get_placeholder_loop(placeholder_string, tokenizer):
+def get_placeholder_loop(placeholder_string, embedder, use_bert):
 
     new_placeholder = None
@@ -16,10 +16,36 @@ def get_placeholder_loop(placeholder_string, tokenizer):
         else:
             new_placeholder = input(f"Placeholder string '{new_placeholder}' maps to more than a single token. Please enter another string: ")
 
-        token = tokenizer(new_placeholder)
+        token = get_bert_token_for_string(embedder.tknz_fn, new_placeholder) if use_bert else get_clip_token_for_string(embedder.tokenizer, new_placeholder)
+
+        if token is not None:
+            return new_placeholder, token
+
+def get_clip_token_for_string(tokenizer, string):
+    batch_encoding = tokenizer(
+        string,
+        truncation=True,
+        max_length=77,
+        return_length=True,
+        return_overflowing_tokens=False,
+        padding="max_length",
+        return_tensors="pt"
+    )
+    tokens = batch_encoding["input_ids"]
+
+    if torch.count_nonzero(tokens - 49407) == 2:
+        return tokens[0, 1]
+
+    return None
+
+def get_bert_token_for_string(tokenizer, string):
+    token = tokenizer(string)
 
-        if torch.count_nonzero(token) == 3:
-            return new_placeholder, token[0, 1]
+    if torch.count_nonzero(token) == 3:
+        return token[0, 1]
+
+    return None
 
 if __name__ == "__main__":
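
A note on the two single-token checks added above: get_clip_token_for_string relies on the CLIP tokenizer padding with its end-of-text id (49407), so a placeholder that maps to exactly one token leaves two entries that differ from it (the start-of-text id and the token itself), while the BERT path keeps the original count_nonzero(token) == 3 test. Below is a hedged illustration of the CLIP case, assuming the Hugging Face tokenizer that FrozenCLIPEmbedder wraps (openai/clip-vit-large-patch14).

import torch
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
batch = tokenizer(
    "cat",  # a word that maps to a single CLIP BPE token
    truncation=True,
    max_length=77,
    return_length=True,
    return_overflowing_tokens=False,
    padding="max_length",
    return_tensors="pt",
)
tokens = batch["input_ids"]
# Padding reuses the end-of-text id, so only the start-of-text id and the
# word's own id differ from 49407 when the word is a single token.
print(torch.count_nonzero(tokens - 49407).item())  # 2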
@@ -40,10 +66,20 @@ if __name__ == "__main__":
         help="Output path for the merged manager",
     )
 
+    parser.add_argument(
+        "-sd", "--use_bert",
+        action="store_true",
+        help="Flag to denote that we are not merging stable diffusion embeddings"
+    )
+
     args = parser.parse_args()
 
-    tokenizer = BERTTokenizer(vq_interface=False, max_length=77)
-    EmbeddingManager = partial(EmbeddingManager, tokenizer, ["*"])
+    if args.use_bert:
+        embedder = BERTEmbedder(n_embed=1280, n_layer=32).cuda()
+    else:
+        embedder = FrozenCLIPEmbedder().cuda()
+
+    EmbeddingManager = partial(EmbeddingManager, embedder, ["*"])
 
     string_to_token_dict = {}
     string_to_param_dict = torch.nn.ParameterDict()
@@ -63,7 +99,7 @@ if __name__ == "__main__":
                 placeholder_to_src[placeholder_string] = manager_ckpt
             else:
-                new_placeholder, new_token = get_placeholder_loop(placeholder_string, tokenizer)
+                new_placeholder, new_token = get_placeholder_loop(placeholder_string, embedder, use_bert=args.use_bert)
 
                 string_to_token_dict[new_placeholder] = new_token
                 string_to_param_dict[new_placeholder] = manager.string_to_param_dict[placeholder_string]
@@ -77,7 +113,3 @@ if __name__ == "__main__":
     print("Managers merged. Final list of placeholders: ")
     print(placeholder_to_src)
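
As a quick sanity check after merging, the output file can be loaded back through an EmbeddingManager built the same way as in the script and its placeholders listed. A minimal sketch; the output filename is a placeholder, and manager.load() is assumed to be the same loader the script uses for its input checkpoints.

from ldm.modules.encoders.modules import FrozenCLIPEmbedder
from ldm.modules.embedding_manager import EmbeddingManager

embedder = FrozenCLIPEmbedder().cuda()
manager = EmbeddingManager(embedder, ["*"])
manager.load("merged_manager.pt")  # placeholder path for the merged output
print(list(manager.string_to_param_dict.keys()))  # merged placeholder strings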