mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

commit 023df37eff (parent 05fac594ea)

    cleanup
@@ -15,21 +15,21 @@ from picklescan.scanner import scan_file_path

 PROGRESSIVE_SCALE = 2000


-def get_clip_token_id_for_string(tokenizer: CLIPTokenizer, token_str: str):
+def get_clip_token_id_for_string(tokenizer: CLIPTokenizer, token_str: str) -> int:
     token_id = tokenizer.convert_tokens_to_ids(token_str)
     return token_id


-def get_bert_token_for_string(tokenizer, string):
+def get_bert_token_id_for_string(tokenizer, string) -> int:
     token = tokenizer(string)
     # assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"

     token = token[0, 1]
-    return token
+    return token.item()


-def get_embedding_for_clip_token(embedder, token):
-    return embedder(token.unsqueeze(0))[0, 0]
+def get_embedding_for_clip_token_id(embedder, token_id):
+    if type(token_id) is not torch.Tensor:
+        token_id = torch.tensor(token_id, dtype=torch.int)
+    return embedder(token_id.unsqueeze(0))[0, 0]


 @dataclass
 class TextualInversion:
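Note: taken together, the changed helpers now traffic in plain integer token ids, with tensor coercion handled inside the embedding lookup. A minimal usage sketch (hypothetical, not part of this commit; the checkpoint name and the "</w>" word-end marker are assumptions about a standard CLIP tokenizer):

    # Hypothetical usage sketch -- not part of this commit.
    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # Returns a plain int; CLIP's BPE marks word-final tokens with "</w>".
    token_id = get_clip_token_id_for_string(tokenizer, "cat</w>")

    # The int is coerced to a tensor inside the helper, so the caller no
    # longer wraps the id (or moves it to the CPU) itself.
    embedding = get_embedding_for_clip_token_id(
        text_model.text_model.embeddings, token_id
    )
    print(embedding.shape)  # torch.Size([768]) for ViT-L/14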
@@ -183,9 +183,6 @@ class TextualInversionManager():
         return overwritten_prompt_embeddings
-
-
-
 
 
 class EmbeddingManager(nn.Module):
     def __init__(
         self,
@@ -222,8 +219,8 @@ class EmbeddingManager(nn.Module):
             get_token_id_for_string = partial(
                 get_clip_token_id_for_string, embedder.tokenizer
             )
-            get_embedding_for_tkn = partial(
-                get_embedding_for_clip_token,
+            get_embedding_for_tkn_id = partial(
+                get_embedding_for_clip_token_id,
                 embedder.transformer.text_model.embeddings,
             )
             # per bug report #572
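Note: both encoder branches bind their tokenizer into a single callable via functools.partial, so later code uses one signature regardless of backend. A toy sketch of the pattern (names here are hypothetical):

    # Toy sketch of the partial-application pattern used above.
    from functools import partial

    def token_id_lookup(tokenizer, text):
        return tokenizer.convert_tokens_to_ids(text)

    # Bind the tokenizer now; callers later supply only the string.
    get_token_id = partial(token_id_lookup, some_tokenizer)
    token_id = get_token_id("hello</w>")  # same call shape for CLIP or BERT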
@@ -232,9 +229,9 @@ class EmbeddingManager(nn.Module):
         else: # using LDM's BERT encoder
             self.is_clip = False
             get_token_id_for_string = partial(
-                get_bert_token_for_string, embedder.tknz_fn
+                get_bert_token_id_for_string, embedder.tknz_fn
             )
-            get_embedding_for_tkn = embedder.transformer.token_emb
+            get_embedding_for_tkn_id = embedder.transformer.token_emb
             token_dim = 1280

         if per_image_tokens:
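Note: the .item() added in get_bert_token_id_for_string is what lets this branch line up with the CLIP branch; both now return a plain Python int instead of a 0-dim tensor. A small illustration with assumed shapes:

    # Illustration only -- assumes the tokenizer returns a [1, seq] id tensor.
    import torch

    token = torch.tensor([[101, 2138, 102]])  # e.g. [CLS] word [SEP]
    token = token[0, 1]       # 0-dim tensor(2138), still carries dtype/device
    token_id = token.item()   # plain int 2138, uniform with the CLIP branch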
@@ -248,9 +245,7 @@ class EmbeddingManager(nn.Module):
             init_word_token_id = get_token_id_for_string(initializer_words[idx])

             with torch.no_grad():
-                init_word_embedding = get_embedding_for_tkn(
-                    init_word_token_id.cpu()
-                )
+                init_word_embedding = get_embedding_for_tkn_id(init_word_token_id)

             token_params = torch.nn.Parameter(
                 init_word_embedding.unsqueeze(0).repeat(
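Note: with tensor coercion moved into get_embedding_for_clip_token_id, the call site drops the .cpu() shuffle and the multi-line call. A sketch of the surrounding initialization pattern (num_vectors_per_token is an assumed value for illustration):

    # Sketch of the seeded-parameter init pattern; values are assumptions.
    import torch

    num_vectors_per_token = 1  # assumption for the sketch
    with torch.no_grad():      # seed lookup stays out of the autograd graph
        init_word_embedding = get_embedding_for_tkn_id(init_word_token_id)

    # One trainable vector per placeholder slot, seeded from the init word.
    token_params = torch.nn.Parameter(
        init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1),
        requires_grad=True,
    )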