mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix merging embeddings (#226)
Fixed merging embeddings based on the changes made in textual inversion. Tested and working. Inverted their logic to prioritize Stable Diffusion implementation over alternatives, but left the option for alternatives to still be used.
This commit is contained in:
parent
b622819051
commit
58e3562652
@ -52,7 +52,7 @@ model:
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 512
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
@ -74,7 +74,7 @@ data:
|
||||
target: main.DataModuleFromConfig
|
||||
params:
|
||||
batch_size: 1
|
||||
num_workers: 16
|
||||
num_workers: 2
|
||||
wrap: false
|
||||
train:
|
||||
target: ldm.data.personalized.PersonalizedBase
|
||||
@ -105,4 +105,5 @@ lightning:
|
||||
|
||||
trainer:
|
||||
benchmark: True
|
||||
max_steps: 6100
|
||||
max_steps: 4000
|
||||
|
@ -1,4 +1,4 @@
|
||||
from ldm.modules.encoders.modules import BERTTokenizer
|
||||
from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder
|
||||
from ldm.modules.embedding_manager import EmbeddingManager
|
||||
|
||||
import argparse, os
|
||||
@ -6,7 +6,7 @@ from functools import partial
|
||||
|
||||
import torch
|
||||
|
||||
def get_placeholder_loop(placeholder_string, tokenizer):
|
||||
def get_placeholder_loop(placeholder_string, embedder, use_bert):
|
||||
|
||||
new_placeholder = None
|
||||
|
||||
@ -16,10 +16,36 @@ def get_placeholder_loop(placeholder_string, tokenizer):
|
||||
else:
|
||||
new_placeholder = input(f"Placeholder string '{new_placeholder}' maps to more than a single token. Please enter another string: ")
|
||||
|
||||
token = tokenizer(new_placeholder)
|
||||
token = get_bert_token_for_string(embedder.tknz_fn, new_placeholder) if use_bert else get_clip_token_for_string(embedder.tokenizer, new_placeholder)
|
||||
|
||||
if token is not None:
|
||||
return new_placeholder, token
|
||||
|
||||
def get_clip_token_for_string(tokenizer, string):
|
||||
batch_encoding = tokenizer(
|
||||
string,
|
||||
truncation=True,
|
||||
max_length=77,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
tokens = batch_encoding["input_ids"]
|
||||
|
||||
if torch.count_nonzero(tokens - 49407) == 2:
|
||||
return tokens[0, 1]
|
||||
|
||||
return None
|
||||
|
||||
def get_bert_token_for_string(tokenizer, string):
|
||||
token = tokenizer(string)
|
||||
if torch.count_nonzero(token) == 3:
|
||||
return token[0, 1]
|
||||
|
||||
return None
|
||||
|
||||
if torch.count_nonzero(token) == 3:
|
||||
return new_placeholder, token[0, 1]
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -40,10 +66,20 @@ if __name__ == "__main__":
|
||||
help="Output path for the merged manager",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-sd", "--use_bert",
|
||||
action="store_true",
|
||||
help="Flag to denote that we are not merging stable diffusion embeddings"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
tokenizer = BERTTokenizer(vq_interface=False, max_length=77)
|
||||
EmbeddingManager = partial(EmbeddingManager, tokenizer, ["*"])
|
||||
if args.use_bert:
|
||||
embedder = BERTEmbedder(n_embed=1280, n_layer=32).cuda()
|
||||
else:
|
||||
embedder = FrozenCLIPEmbedder().cuda()
|
||||
|
||||
EmbeddingManager = partial(EmbeddingManager, embedder, ["*"])
|
||||
|
||||
string_to_token_dict = {}
|
||||
string_to_param_dict = torch.nn.ParameterDict()
|
||||
@ -63,7 +99,7 @@ if __name__ == "__main__":
|
||||
|
||||
placeholder_to_src[placeholder_string] = manager_ckpt
|
||||
else:
|
||||
new_placeholder, new_token = get_placeholder_loop(placeholder_string, tokenizer)
|
||||
new_placeholder, new_token = get_placeholder_loop(placeholder_string, embedder, use_bert=args.use_bert)
|
||||
string_to_token_dict[new_placeholder] = new_token
|
||||
string_to_param_dict[new_placeholder] = manager.string_to_param_dict[placeholder_string]
|
||||
|
||||
@ -77,7 +113,3 @@ if __name__ == "__main__":
|
||||
|
||||
print("Managers merged. Final list of placeholders: ")
|
||||
print(placeholder_to_src)
|
||||
|
||||
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user