Damian Stewart 2022-12-16 02:36:54 +01:00
parent 05fac594ea
commit 023df37eff

@@ -15,21 +15,21 @@ from picklescan.scanner import scan_file_path
 PROGRESSIVE_SCALE = 2000
 
-def get_clip_token_id_for_string(tokenizer: CLIPTokenizer, token_str: str):
+def get_clip_token_id_for_string(tokenizer: CLIPTokenizer, token_str: str) -> int:
     token_id = tokenizer.convert_tokens_to_ids(token_str)
     return token_id
 
-def get_bert_token_for_string(tokenizer, string):
+def get_bert_token_id_for_string(tokenizer, string) -> int:
     token = tokenizer(string)
     # assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
     token = token[0, 1]
-    return token
+    return token.item()
 
-def get_embedding_for_clip_token(embedder, token):
-    return embedder(token.unsqueeze(0))[0, 0]
+def get_embedding_for_clip_token_id(embedder, token_id):
+    if type(token_id) is not torch.Tensor:
+        token_id = torch.tensor(token_id, dtype=torch.int)
+    return embedder(token_id.unsqueeze(0))[0, 0]
 
 @dataclass
 class TextualInversion:
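
For orientation, a minimal usage sketch of the renamed CLIP helpers (not part of the commit). The `embedder` object here is hypothetical; only its `.tokenizer` and `.transformer.text_model.embeddings` attributes, which the EmbeddingManager hunks below rely on, are assumed:

    import torch

    # Illustrative only: `embedder` stands in for the CLIP embedder wrapper used below.
    token_id = get_clip_token_id_for_string(embedder.tokenizer, "*")  # plain Python int
    with torch.no_grad():
        # The helper now wraps a bare int into a tensor itself, so the caller
        # no longer needs to build (or .cpu()) a token tensor first.
        embedding = get_embedding_for_clip_token_id(
            embedder.transformer.text_model.embeddings, token_id
        )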
@@ -183,9 +183,6 @@ class TextualInversionManager():
         return overwritten_prompt_embeddings
 
 class EmbeddingManager(nn.Module):
     def __init__(
         self,
@@ -222,8 +219,8 @@ class EmbeddingManager(nn.Module):
             get_token_id_for_string = partial(
                 get_clip_token_id_for_string, embedder.tokenizer
             )
-            get_embedding_for_tkn = partial(
-                get_embedding_for_clip_token,
+            get_embedding_for_tkn_id = partial(
+                get_embedding_for_clip_token_id,
                 embedder.transformer.text_model.embeddings,
             )
             # per bug report #572
@@ -232,9 +229,9 @@ class EmbeddingManager(nn.Module):
         else: # using LDM's BERT encoder
             self.is_clip = False
             get_token_id_for_string = partial(
-                get_bert_token_for_string, embedder.tknz_fn
+                get_bert_token_id_for_string, embedder.tknz_fn
             )
-            get_embedding_for_tkn = embedder.transformer.token_emb
+            get_embedding_for_tkn_id = embedder.transformer.token_emb
             token_dim = 1280
 
         if per_image_tokens:
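
On this BERT branch the renamed get_bert_token_id_for_string now returns token.item(), a plain int rather than a zero-dimensional tensor, so both encoder paths hand the same kind of token id to get_embedding_for_tkn_id. A small sketch (illustrative only; a BERT-style `embedder` with a `.tknz_fn` tokenizer, as referenced in the hunk above, is assumed):

    # Hypothetical check of the new return type; `embedder.tknz_fn` is assumed here.
    token_id = get_bert_token_id_for_string(embedder.tknz_fn, "hello")
    assert isinstance(token_id, int)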
@@ -248,9 +245,7 @@ class EmbeddingManager(nn.Module):
                 init_word_token_id = get_token_id_for_string(initializer_words[idx])
 
                 with torch.no_grad():
-                    init_word_embedding = get_embedding_for_tkn(
-                        init_word_token_id.cpu()
-                    )
+                    init_word_embedding = get_embedding_for_tkn_id(init_word_token_id)
 
                 token_params = torch.nn.Parameter(
                     init_word_embedding.unsqueeze(0).repeat(