mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
wip new TextualInversionManager
This commit is contained in:
parent
5d20f47993
commit
2e80872e3b
@ -1,6 +1,7 @@
|
||||
import os.path
|
||||
from cmath import log
|
||||
import torch
|
||||
from attr import dataclass
|
||||
from torch import nn
|
||||
|
||||
import sys
|
||||
@ -14,7 +15,7 @@ from picklescan.scanner import scan_file_path
|
||||
PROGRESSIVE_SCALE = 2000
|
||||
|
||||
|
||||
def get_clip_token_for_string(tokenizer, string):
|
||||
def get_clip_token_id_for_string(tokenizer, string):
|
||||
batch_encoding = tokenizer(
|
||||
string,
|
||||
truncation=True,
|
||||
@ -25,9 +26,9 @@ def get_clip_token_for_string(tokenizer, string):
|
||||
return_tensors='pt',
|
||||
)
|
||||
tokens = batch_encoding['input_ids']
|
||||
""" assert (
|
||||
assert (
|
||||
torch.count_nonzero(tokens - 49407) == 2
|
||||
), f"String '{string}' maps to more than a single token. Please use another string" """
|
||||
), f"String '{string}' maps to more than a single token. Please use another string"
|
||||
|
||||
return tokens[0, 1]
|
||||
|
||||
@ -44,6 +45,134 @@ def get_bert_token_for_string(tokenizer, string):
|
||||
def get_embedding_for_clip_token(embedder, token):
|
||||
return embedder(token.unsqueeze(0))[0, 0]
|
||||
|
||||
@dataclass
|
||||
class TextualInversion:
|
||||
token_string: str
|
||||
token_id: int
|
||||
embedding: torch.Tensor
|
||||
|
||||
@property
|
||||
def embedding_vector_length(self) -> int:
|
||||
return self.embedding.shape[0]
|
||||
|
||||
class TextualInversionManager():
|
||||
def __init__(self, clip_embedder):
|
||||
self.clip_embedder = clip_embedder
|
||||
defatul_textual_inversions: list[TextualInversion] = []
|
||||
self.textual_inversions = defatul_textual_inversions
|
||||
|
||||
def load_textual_inversion(self, ckpt_path, full_precision=True):
|
||||
|
||||
scan_result = scan_file_path(ckpt_path)
|
||||
if scan_result.infected_files == 1:
|
||||
print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
|
||||
print('### For your safety, InvokeAI will not load this embed.')
|
||||
return
|
||||
|
||||
ckpt = torch.load(ckpt_path, map_location='cpu')
|
||||
|
||||
# Handle .pt textual inversion files
|
||||
if 'string_to_token' in ckpt and 'string_to_param' in ckpt:
|
||||
filename = os.path.basename(ckpt_path)
|
||||
token_str = '.'.join(filename.split('.')[:-1]) # filename excluding extension
|
||||
if len(ckpt["string_to_token"]) > 1:
|
||||
print(f">> {ckpt_path} has >1 embedding, only the first will be used")
|
||||
|
||||
string_to_param_dict = ckpt['string_to_param']
|
||||
embedding = list(string_to_param_dict.values())[0]
|
||||
self.add_textual_inversion(token_str, embedding, full_precision)
|
||||
|
||||
# Handle .bin textual inversion files from Huggingface Concepts
|
||||
# https://huggingface.co/sd-concepts-library
|
||||
else:
|
||||
for token_str in list(ckpt.keys()):
|
||||
embedding = ckpt[token_str]
|
||||
self.add_textual_inversion(token_str, embedding, full_precision)
|
||||
|
||||
def add_textual_inversion(self, token_str, embedding):
|
||||
"""
|
||||
Add a textual inversion to be recognised.
|
||||
:param token_str: The trigger text in the prompt that activates this textual inversion. Should be unknown to the embedder's tokenizer.
|
||||
:param embedding: The actual embedding data that will be inserted into the conditioning at the point where the token_str appears.
|
||||
:return: The token id of the added embedding.
|
||||
"""
|
||||
if token_str in [ti.token_string for ti in self.textual_inversions]:
|
||||
print(f">> Embedding manager refusing to overwrite already-loaded term '{token_str}'")
|
||||
return
|
||||
if len(embedding.shape) == 1:
|
||||
embedding = embedding.unsqueeze(0)
|
||||
elif len(embedding.shape) > 2:
|
||||
raise ValueError(f"embedding shape {embedding.shape} is incorrect - must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2")
|
||||
|
||||
num_tokens_added = self.clip_embedder.tokenizer.add_tokens(token_str)
|
||||
current_embeddings = self.clip_embedder.transformer.resize_token_embeddings(None)
|
||||
current_token_count = current_embeddings.num_embeddings
|
||||
new_token_count = current_token_count + num_tokens_added
|
||||
self.clip_embedder.transformer.resize_token_embeddings(new_token_count)
|
||||
|
||||
token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, token_str)
|
||||
self.textual_inversions.append(TextualInversion(
|
||||
token_string=token_str,
|
||||
token_id=token_id,
|
||||
embedding=embedding
|
||||
))
|
||||
|
||||
return token_id
|
||||
|
||||
def has_textual_inversion(self, token_str):
|
||||
return token_str in [ti.token_string for ti in self.textual_inversions]
|
||||
|
||||
def expand_textual_inversion_token_ids(self, prompt_token_ids: list[int]) -> list[int]:
|
||||
"""
|
||||
Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.
|
||||
|
||||
:param prompt_token_ids: The prompt as a list of token ids (`int`s). Should not include bos and eos markers.
|
||||
:param pad_token_id: The token id to use to pad out the list to account for textual inversion vector lengths >1.
|
||||
:return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too
|
||||
long - caller is reponsible for truncating it if necessary and prepending/appending eos and bos token ids.
|
||||
"""
|
||||
assert(prompt_token_ids[0] != self.clip_embedder.bos_token_id)
|
||||
assert(prompt_token_ids[-1] != self.clip_embedder.eos_token_id)
|
||||
textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
|
||||
for i, token_id in reversed(list(enumerate(prompt_token_ids))):
|
||||
if token_id in textual_inversion_token_ids:
|
||||
textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
|
||||
for pad_idx in range(1, textual_inversion.embedding_vector_length):
|
||||
prompt_token_ids.insert(i+1, self.clip_embedder.pad_token_id)
|
||||
|
||||
return prompt_token_ids
|
||||
|
||||
def overwrite_textual_inversion_embeddings(self, prompt_token_ids: list[int], prompt_embeddings: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
For each token id in prompt_token_ids that refers to a loaded textual inversion, overwrite the corresponding
|
||||
row in `prompt_embeddings` with the textual inversion embedding. If the embedding has vector length >1, overwrite
|
||||
subsequent rows in `prompt_embeddings` as well.
|
||||
|
||||
:param `prompt_token_ids`: Prompt token ids, already expanded to account for any textual inversions with vector lenght
|
||||
>1 (call `expand_textual_inversion_token_ids()` to do this)
|
||||
:param `prompt_embeddings`: Prompt embeddings tensor of shape with indices aligning to token ids in
|
||||
`prompt_token_ids` (i.e., also already expanded).
|
||||
:return: `The prompt_embeddings` tensor overwritten as appropriate with the textual inversion embeddings.
|
||||
"""
|
||||
assert(prompt_embeddings.shape[0] == self.clip_embedder.max_length, f"prompt_embeddings must have 77 entries (has: {prompt_embeddings.shape[0]})")
|
||||
textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
|
||||
for i, token_id in enumerate(prompt_token_ids):
|
||||
if token_id == pad_token_id:
|
||||
continue
|
||||
if token_id in textual_inversion_token_ids:
|
||||
textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
|
||||
for j in range(0, textual_inversion.embedding_vector_length):
|
||||
# only overwrite the textual inversion token id or the padding token id
|
||||
if prompt_token_ids[i+j] != self.clip_embedder.pad_token_id and prompt_token_ids[i+j] != token_id:
|
||||
break
|
||||
prompt_embeddings[i+j] = textual_inversion.embedding[j]
|
||||
|
||||
return prompt_embeddings
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class EmbeddingManager(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -78,7 +207,7 @@ class EmbeddingManager(nn.Module):
|
||||
): # using Stable Diffusion's CLIP encoder
|
||||
self.is_clip = True
|
||||
get_token_for_string = partial(
|
||||
get_clip_token_for_string, embedder.tokenizer
|
||||
get_clip_token_id_for_string, embedder.tokenizer
|
||||
)
|
||||
get_embedding_for_tkn = partial(
|
||||
get_embedding_for_clip_token,
|
||||
@ -241,7 +370,7 @@ class EmbeddingManager(nn.Module):
|
||||
# both will be stored in this dictionary
|
||||
for term in self.string_to_param_dict.keys():
|
||||
term = term.strip('<').strip('>')
|
||||
self.concepts_loaded[term] = True
|
||||
self.concepts_loaded[term] = True
|
||||
print(f'>> Current embedding manager terms: {", ".join(self.string_to_param_dict.keys())}')
|
||||
|
||||
def _expand_directories(self, paths:list[str]):
|
||||
@ -255,55 +384,6 @@ class EmbeddingManager(nn.Module):
|
||||
expanded_paths.append(os.path.join(root,name))
|
||||
return [x for x in expanded_paths if os.path.splitext(x)[1] in ('.pt','.bin')]
|
||||
|
||||
def _load(self, ckpt_path, full=True):
|
||||
|
||||
scan_result = scan_file_path(ckpt_path)
|
||||
if scan_result.infected_files == 1:
|
||||
print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
|
||||
print('### For your safety, InvokeAI will not load this embed.')
|
||||
return
|
||||
|
||||
ckpt = torch.load(ckpt_path, map_location='cpu')
|
||||
|
||||
# Handle .pt textual inversion files
|
||||
if 'string_to_token' in ckpt and 'string_to_param' in ckpt:
|
||||
filename = os.path.basename(ckpt_path)
|
||||
token_str = '.'.join(filename.split('.')[:-1]) # filename excluding extension
|
||||
if len(ckpt["string_to_token"]) > 1:
|
||||
print(f">> {ckpt_path} has >1 embedding, only the first will be used")
|
||||
|
||||
string_to_param_dict = ckpt['string_to_param']
|
||||
embedding = list(string_to_param_dict.values())[0]
|
||||
self.add_embedding(token_str, embedding, full)
|
||||
|
||||
# Handle .bin textual inversion files from Huggingface Concepts
|
||||
# https://huggingface.co/sd-concepts-library
|
||||
else:
|
||||
for token_str in list(ckpt.keys()):
|
||||
embedding = ckpt[token_str]
|
||||
self.add_embedding(token_str, embedding, full)
|
||||
|
||||
def add_embedding(self, token_str, embedding, full):
|
||||
if token_str in self.string_to_param_dict:
|
||||
print(f">> Embedding manager refusing to overwrite already-loaded term '{token_str}'")
|
||||
return
|
||||
if not full:
|
||||
embedding = embedding.half()
|
||||
if len(embedding.shape) == 1:
|
||||
embedding = embedding.unsqueeze(0)
|
||||
|
||||
num_tokens_added = self.embedder.tokenizer.add_tokens(token_str)
|
||||
current_embeddings = self.embedder.transformer.resize_token_embeddings(None)
|
||||
current_token_count = current_embeddings.num_embeddings
|
||||
new_token_count = current_token_count + num_tokens_added
|
||||
self.embedder.transformer.resize_token_embeddings(new_token_count)
|
||||
|
||||
token = get_clip_token_for_string(self.embedder.tokenizer, token_str)
|
||||
self.string_to_token_dict[token_str] = token
|
||||
self.string_to_param_dict[token_str] = torch.nn.Parameter(embedding)
|
||||
|
||||
def has_embedding_for_token(self, token_str):
|
||||
return token_str in self.string_to_token_dict
|
||||
|
||||
def get_embedding_norms_squared(self):
|
||||
all_params = torch.cat(
|
||||
|
43
tests/text_textual_inversion.py
Normal file
43
tests/text_textual_inversion.py
Normal file
@ -0,0 +1,43 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from ldm.modules.embedding_manager import TextualInversionManager
|
||||
|
||||
|
||||
class DummyClipEmbedder:
|
||||
max_length = 77
|
||||
bos_token_id = 49406
|
||||
eos_token_id = 49407
|
||||
|
||||
class TextualInversionManagerTestCase(unittest.TestCase):
|
||||
|
||||
|
||||
def test_construction(self):
|
||||
tim = TextualInversionManager(DummyClipEmbedder())
|
||||
|
||||
def test_add_embedding(self):
|
||||
tim = TextualInversionManager(DummyClipEmbedder())
|
||||
test_embedding = torch.random([1, 768])
|
||||
test_embedding_name = "test"
|
||||
token_id = tim.add_textual_inversion(test_embedding_name, test_embedding)
|
||||
self.assertTrue(tim.has_textual_inversion(test_embedding_name))
|
||||
|
||||
textual_inversion = next(ti for ti in tim.textual_inversions if ti.token_id == token_id)
|
||||
self.assertIsNotNone(textual_inversion)
|
||||
self.assertEqual(textual_inversion.embedding, test_embedding)
|
||||
self.assertEqual(textual_inversion.token_string, test_embedding_name)
|
||||
self.assertEqual(textual_inversion.token_id, token_id)
|
||||
|
||||
def test_pad_tokens_list(self):
|
||||
tim = TextualInversionManager(DummyClipEmbedder())
|
||||
prompt_token_ids = [DummyClipEmbedder.bos_token_id, 0, 1, 2, DummyClipEmbedder.eos_token_id]
|
||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids)
|
||||
self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
|
||||
|
||||
test_embedding = torch.random([1, 768])
|
||||
test_embedding_name = "test"
|
||||
tim.add_textual_inversion("<token>",
|
||||
|
||||
self.assertRaises()
|
Loading…
Reference in New Issue
Block a user