This commit is contained in:
Damian Stewart 2022-12-16 02:36:54 +01:00
parent 05fac594ea
commit 023df37eff

View File

@ -15,21 +15,21 @@ from picklescan.scanner import scan_file_path
# Step size used when progressively scaling training (see callers).
PROGRESSIVE_SCALE = 2000
def get_clip_token_id_for_string(tokenizer: CLIPTokenizer, token_str: str) -> int:
    """Return the CLIP vocabulary id for *token_str* via the tokenizer's lookup."""
    return tokenizer.convert_tokens_to_ids(token_str)
def get_bert_token_id_for_string(tokenizer, string) -> int:
    """Tokenize *string* with a BERT-style tokenizer and return the id of its
    single content token (position [0, 1], i.e. the token after the start marker)
    as a plain Python int."""
    # NOTE(review): assumes the string maps to exactly one token; a stricter
    # version would assert torch.count_nonzero(token) == 3 here.
    encoded = tokenizer(string)
    content_token = encoded[0, 1]
    return content_token.item()
def get_embedding_for_clip_token_id(embedder, token_id):
    """Look up the embedding vector for a single token id.

    Args:
        embedder: callable mapping a 1-D tensor of token ids to a batched
            embedding tensor (indexed here as ``[0, 0]`` for the first token
            of the first batch element).
        token_id: a Python int or a 0-d/1-element ``torch.Tensor``.

    Returns:
        The embedding for the token, i.e. ``embedder(ids)[0, 0]``.
    """
    # isinstance (not `type(...) is`) so Tensor subclasses are accepted too.
    if not isinstance(token_id, torch.Tensor):
        # torch.long: integer index tensors for embedding lookups are
        # conventionally int64; int32 can be rejected by some embedding ops.
        token_id = torch.tensor(token_id, dtype=torch.long)
    return embedder(token_id.unsqueeze(0))[0, 0]
@dataclass @dataclass
class TextualInversion: class TextualInversion:
@ -183,9 +183,6 @@ class TextualInversionManager():
return overwritten_prompt_embeddings return overwritten_prompt_embeddings
class EmbeddingManager(nn.Module): class EmbeddingManager(nn.Module):
def __init__( def __init__(
self, self,
@ -222,8 +219,8 @@ class EmbeddingManager(nn.Module):
get_token_id_for_string = partial( get_token_id_for_string = partial(
get_clip_token_id_for_string, embedder.tokenizer get_clip_token_id_for_string, embedder.tokenizer
) )
get_embedding_for_tkn = partial( get_embedding_for_tkn_id = partial(
get_embedding_for_clip_token, get_embedding_for_clip_token_id,
embedder.transformer.text_model.embeddings, embedder.transformer.text_model.embeddings,
) )
# per bug report #572 # per bug report #572
@ -232,9 +229,9 @@ class EmbeddingManager(nn.Module):
else: # using LDM's BERT encoder else: # using LDM's BERT encoder
self.is_clip = False self.is_clip = False
get_token_id_for_string = partial( get_token_id_for_string = partial(
get_bert_token_for_string, embedder.tknz_fn get_bert_token_id_for_string, embedder.tknz_fn
) )
get_embedding_for_tkn = embedder.transformer.token_emb get_embedding_for_tkn_id = embedder.transformer.token_emb
token_dim = 1280 token_dim = 1280
if per_image_tokens: if per_image_tokens:
@ -248,9 +245,7 @@ class EmbeddingManager(nn.Module):
init_word_token_id = get_token_id_for_string(initializer_words[idx]) init_word_token_id = get_token_id_for_string(initializer_words[idx])
with torch.no_grad(): with torch.no_grad():
init_word_embedding = get_embedding_for_tkn( init_word_embedding = get_embedding_for_tkn_id(init_word_token_id)
init_word_token_id.cpu()
)
token_params = torch.nn.Parameter( token_params = torch.nn.Parameter(
init_word_embedding.unsqueeze(0).repeat( init_word_embedding.unsqueeze(0).repeat(