From 023df37efffa67434f77def7fc3c9dfb29f699fd Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Fri, 16 Dec 2022 02:36:54 +0100 Subject: [PATCH] cleanup --- ldm/modules/embedding_manager.py | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/ldm/modules/embedding_manager.py b/ldm/modules/embedding_manager.py index 97b98dc5e7..613bf7a430 100644 --- a/ldm/modules/embedding_manager.py +++ b/ldm/modules/embedding_manager.py @@ -15,21 +15,21 @@ from picklescan.scanner import scan_file_path PROGRESSIVE_SCALE = 2000 -def get_clip_token_id_for_string(tokenizer: CLIPTokenizer, token_str: str): +def get_clip_token_id_for_string(tokenizer: CLIPTokenizer, token_str: str) -> int: token_id = tokenizer.convert_tokens_to_ids(token_str) return token_id -def get_bert_token_for_string(tokenizer, string): +def get_bert_token_id_for_string(tokenizer, string) -> int: token = tokenizer(string) # assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string" - token = token[0, 1] - - return token + return token.item() -def get_embedding_for_clip_token(embedder, token): - return embedder(token.unsqueeze(0))[0, 0] +def get_embedding_for_clip_token_id(embedder, token_id): + if type(token_id) is not torch.Tensor: + token_id = torch.tensor(token_id, dtype=torch.int) + return embedder(token_id.unsqueeze(0))[0, 0] @dataclass class TextualInversion: @@ -183,9 +183,6 @@ class TextualInversionManager(): return overwritten_prompt_embeddings - - - class EmbeddingManager(nn.Module): def __init__( self, @@ -222,8 +219,8 @@ class EmbeddingManager(nn.Module): get_token_id_for_string = partial( get_clip_token_id_for_string, embedder.tokenizer ) - get_embedding_for_tkn = partial( - get_embedding_for_clip_token, + get_embedding_for_tkn_id = partial( + get_embedding_for_clip_token_id, embedder.transformer.text_model.embeddings, ) # per bug report #572 @@ -232,9 +229,9 @@ class EmbeddingManager(nn.Module): else: # using LDM's BERT encoder self.is_clip = False get_token_id_for_string = partial( - get_bert_token_for_string, embedder.tknz_fn + get_bert_token_id_for_string, embedder.tknz_fn ) - get_embedding_for_tkn = embedder.transformer.token_emb + get_embedding_for_tkn_id = embedder.transformer.token_emb token_dim = 1280 if per_image_tokens: @@ -248,9 +245,7 @@ class EmbeddingManager(nn.Module): init_word_token_id = get_token_id_for_string(initializer_words[idx]) with torch.no_grad(): - init_word_embedding = get_embedding_for_tkn( - init_word_token_id.cpu() - ) + init_word_embedding = get_embedding_for_tkn_id(init_word_token_id) token_params = torch.nn.Parameter( init_word_embedding.unsqueeze(0).repeat(