Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
cleanup
This commit is contained in:
parent 05fac594ea
commit 023df37eff
@@ -15,21 +15,21 @@ from picklescan.scanner import scan_file_path

 PROGRESSIVE_SCALE = 2000


-def get_clip_token_id_for_string(tokenizer: CLIPTokenizer, token_str: str):
+def get_clip_token_id_for_string(tokenizer: CLIPTokenizer, token_str: str) -> int:
     token_id = tokenizer.convert_tokens_to_ids(token_str)
     return token_id


-def get_bert_token_for_string(tokenizer, string):
+def get_bert_token_id_for_string(tokenizer, string) -> int:
     token = tokenizer(string)
     # assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"

     token = token[0, 1]

-    return token
+    return token.item()


-def get_embedding_for_clip_token(embedder, token):
-    return embedder(token.unsqueeze(0))[0, 0]
+def get_embedding_for_clip_token_id(embedder, token_id):
+    if type(token_id) is not torch.Tensor:
+        token_id = torch.tensor(token_id, dtype=torch.int)
+    return embedder(token_id.unsqueeze(0))[0, 0]


 @dataclass
 class TextualInversion:
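For orientation, a minimal sketch of how the two renamed helpers fit together after this change. It assumes the Hugging Face transformers package; the checkpoint name and the "*" placeholder are illustrative, not taken from this commit, and the two helpers are the ones defined in the hunk above.

    # Sketch only: illustrative checkpoint and placeholder; helpers as defined above.
    import torch
    from transformers import CLIPTokenizer, CLIPTextModel

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # String -> single vocabulary id, now returned as a plain int.
    token_id = get_clip_token_id_for_string(tokenizer, "*")

    # Id -> embedding row; the helper converts a plain int to a tensor itself.
    embedding = get_embedding_for_clip_token_id(
        text_encoder.text_model.embeddings, token_id
    )
    print(embedding.shape)  # e.g. torch.Size([768]) for this checkpoint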
@@ -183,9 +183,6 @@ class TextualInversionManager():
         return overwritten_prompt_embeddings


-
-
-
 class EmbeddingManager(nn.Module):
     def __init__(
         self,
@@ -222,8 +219,8 @@ class EmbeddingManager(nn.Module):
             get_token_id_for_string = partial(
                 get_clip_token_id_for_string, embedder.tokenizer
             )
-            get_embedding_for_tkn = partial(
-                get_embedding_for_clip_token,
+            get_embedding_for_tkn_id = partial(
+                get_embedding_for_clip_token_id,
                 embedder.transformer.text_model.embeddings,
             )
             # per bug report #572
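The functools.partial calls above bind the encoder-specific objects (tokenizer, embedding table) up front, so the rest of __init__ can call one-argument helpers without caring which encoder is active. A tiny self-contained illustration of that binding pattern; the toy vocabulary below is made up.

    # Toy illustration of the partial() binding used above; the vocabulary is hypothetical.
    from functools import partial

    def lookup_token_id(vocab: dict, token_str: str) -> int:
        return vocab[token_str]

    toy_vocab = {"*": 265}
    get_token_id_for_string = partial(lookup_token_id, toy_vocab)  # bind the table once

    assert get_token_id_for_string("*") == 265  # callers pass only the string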
@@ -232,9 +229,9 @@ class EmbeddingManager(nn.Module):
         else:  # using LDM's BERT encoder
             self.is_clip = False
             get_token_id_for_string = partial(
-                get_bert_token_for_string, embedder.tknz_fn
+                get_bert_token_id_for_string, embedder.tknz_fn
             )
-            get_embedding_for_tkn = embedder.transformer.token_emb
+            get_embedding_for_tkn_id = embedder.transformer.token_emb
             token_dim = 1280

         if per_image_tokens:
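On the BERT branch no partial wrapper is needed, because a torch embedding module is already a callable from token-id tensors to embedding rows. A small sketch of that behaviour; the vocabulary size is illustrative, while token_dim = 1280 matches the hunk above.

    # Sketch: an nn.Embedding maps token ids straight to embedding rows.
    import torch
    import torch.nn as nn

    token_emb = nn.Embedding(num_embeddings=30522, embedding_dim=1280)  # vocab size is illustrative
    row = token_emb(torch.tensor([42]))[0]
    print(row.shape)  # torch.Size([1280])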
@@ -248,9 +245,7 @@ class EmbeddingManager(nn.Module):
                 init_word_token_id = get_token_id_for_string(initializer_words[idx])

-                with torch.no_grad():
-                    init_word_embedding = get_embedding_for_tkn(
-                        init_word_token_id.cpu()
-                    )
+                init_word_embedding = get_embedding_for_tkn_id(init_word_token_id)

                 token_params = torch.nn.Parameter(
                     init_word_embedding.unsqueeze(0).repeat(
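The repeat arguments are cut off at the end of the hunk, but the unsqueeze(0).repeat(...) idiom tiles the single initializer embedding into one row per learned vector before wrapping it in a Parameter. A toy sketch with assumed sizes; the row count below is an assumption, not read from this diff.

    # Toy sketch of the unsqueeze/repeat initialization; sizes are assumed.
    import torch

    token_dim = 1280          # matches the BERT branch above
    rows_per_token = 4        # assumed number of learned vectors per token

    init_word_embedding = torch.randn(token_dim)                     # shape (1280,)
    token_params = torch.nn.Parameter(
        init_word_embedding.unsqueeze(0).repeat(rows_per_token, 1)   # shape (4, 1280)
    )
    assert token_params.shape == (4, 1280)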