wip new TextualInversionManager

Damian Stewart 2022-12-15 10:57:57 +01:00
parent 5d20f47993
commit 2e80872e3b
2 changed files with 177 additions and 54 deletions

View File

@ -1,6 +1,7 @@
import os.path
from cmath import log
import torch
from attr import dataclass
from torch import nn
import sys
@ -14,7 +15,7 @@ from picklescan.scanner import scan_file_path
PROGRESSIVE_SCALE = 2000
def get_clip_token_id_for_string(tokenizer, string):
batch_encoding = tokenizer(
string,
truncation=True,
@ -25,9 +26,9 @@ def get_clip_token_for_string(tokenizer, string):
return_tensors='pt',
)
tokens = batch_encoding['input_ids']
assert (
torch.count_nonzero(tokens - 49407) == 2
), f"String '{string}' maps to more than a single token. Please use another string"
return tokens[0, 1]
@ -44,6 +45,134 @@ def get_bert_token_for_string(tokenizer, string):
def get_embedding_for_clip_token(embedder, token):
return embedder(token.unsqueeze(0))[0, 0]
@dataclass
class TextualInversion:
token_string: str
token_id: int
embedding: torch.Tensor
@property
def embedding_vector_length(self) -> int:
return self.embedding.shape[0]
class TextualInversionManager():
def __init__(self, clip_embedder):
self.clip_embedder = clip_embedder
default_textual_inversions: list[TextualInversion] = []
self.textual_inversions = default_textual_inversions
def load_textual_inversion(self, ckpt_path, full_precision=True):
scan_result = scan_file_path(ckpt_path)
if scan_result.infected_files == 1:
print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
print('### For your safety, InvokeAI will not load this embed.')
return
ckpt = torch.load(ckpt_path, map_location='cpu')
# Handle .pt textual inversion files
if 'string_to_token' in ckpt and 'string_to_param' in ckpt:
filename = os.path.basename(ckpt_path)
token_str = '.'.join(filename.split('.')[:-1]) # filename excluding extension
if len(ckpt["string_to_token"]) > 1:
print(f">> {ckpt_path} has >1 embedding, only the first will be used")
string_to_param_dict = ckpt['string_to_param']
embedding = list(string_to_param_dict.values())[0]
self.add_textual_inversion(token_str, embedding, full_precision)
# Handle .bin textual inversion files from Huggingface Concepts
# https://huggingface.co/sd-concepts-library
else:
for token_str in list(ckpt.keys()):
embedding = ckpt[token_str]
self.add_textual_inversion(token_str, embedding, full_precision)
def add_textual_inversion(self, token_str, embedding, full_precision=True):
"""
Add a textual inversion to be recognised.
:param token_str: The trigger text in the prompt that activates this textual inversion. Should be unknown to the embedder's tokenizer.
:param embedding: The actual embedding data that will be inserted into the conditioning at the point where the token_str appears.
:return: The token id of the added embedding.
"""
if token_str in [ti.token_string for ti in self.textual_inversions]:
print(f">> Embedding manager refusing to overwrite already-loaded term '{token_str}'")
return
if not full_precision:
embedding = embedding.half()
if len(embedding.shape) == 1:
embedding = embedding.unsqueeze(0)
elif len(embedding.shape) > 2:
raise ValueError(f"embedding shape {embedding.shape} is incorrect - must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2")
num_tokens_added = self.clip_embedder.tokenizer.add_tokens(token_str)
current_embeddings = self.clip_embedder.transformer.resize_token_embeddings(None)
current_token_count = current_embeddings.num_embeddings
new_token_count = current_token_count + num_tokens_added
self.clip_embedder.transformer.resize_token_embeddings(new_token_count)
token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, token_str)
self.textual_inversions.append(TextualInversion(
token_string=token_str,
token_id=token_id,
embedding=embedding
))
return token_id
def has_textual_inversion(self, token_str):
return token_str in [ti.token_string for ti in self.textual_inversions]
def expand_textual_inversion_token_ids(self, prompt_token_ids: list[int]) -> list[int]:
"""
Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.
:param prompt_token_ids: The prompt as a list of token ids (`int`s). Should not include bos and eos markers.
Padding uses the embedder's pad_token_id to account for textual inversion vector lengths >1.
:return: The prompt token ids with any necessary padding for textual inversions inserted. May be too
long - the caller is responsible for truncating it if necessary and prepending/appending the bos and eos token ids.
"""
assert(prompt_token_ids[0] != self.clip_embedder.bos_token_id)
assert(prompt_token_ids[-1] != self.clip_embedder.eos_token_id)
textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
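# iterate in reverse so that inserting pad ids does not shift the positions of tokens not yet visited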
for i, token_id in reversed(list(enumerate(prompt_token_ids))):
if token_id in textual_inversion_token_ids:
textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
for pad_idx in range(1, textual_inversion.embedding_vector_length):
prompt_token_ids.insert(i+1, self.clip_embedder.pad_token_id)
return prompt_token_ids
def overwrite_textual_inversion_embeddings(self, prompt_token_ids: list[int], prompt_embeddings: torch.Tensor) -> torch.Tensor:
"""
For each token id in prompt_token_ids that refers to a loaded textual inversion, overwrite the corresponding
row in `prompt_embeddings` with the textual inversion embedding. If the embedding has vector length >1, overwrite
subsequent rows in `prompt_embeddings` as well.
:param `prompt_token_ids`: Prompt token ids, already expanded to account for any textual inversions with vector length
>1 (call `expand_textual_inversion_token_ids()` to do this).
:param `prompt_embeddings`: Prompt embeddings tensor whose rows align with the token ids in
`prompt_token_ids` (i.e., also already expanded).
:return: The `prompt_embeddings` tensor, overwritten as appropriate with the textual inversion embeddings.
"""
assert prompt_embeddings.shape[0] == self.clip_embedder.max_length, f"prompt_embeddings must have {self.clip_embedder.max_length} entries (has: {prompt_embeddings.shape[0]})"
textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
for i, token_id in enumerate(prompt_token_ids):
if token_id == self.clip_embedder.pad_token_id:
continue
if token_id in textual_inversion_token_ids:
textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
for j in range(0, textual_inversion.embedding_vector_length):
# only overwrite the textual inversion token id or the padding token id
if prompt_token_ids[i+j] != self.clip_embedder.pad_token_id and prompt_token_ids[i+j] != token_id:
break
prompt_embeddings[i+j] = textual_inversion.embedding[j]
return prompt_embeddings
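The three public methods above are meant to be chained at prompt-encoding time: load or add an inversion, expand the prompt's token ids, then splice the learned vectors into the conditioning tensor. A minimal usage sketch follows; it is illustrative only and not part of this commit, and it assumes `clip_embedder` is the FrozenCLIPEmbedder-style object used elsewhere in this file and that `embeddings` is the [max_length, dim] conditioning tensor produced by running the text transformer over the bos/eos-wrapped, padded token ids:
# Illustrative usage sketch (not part of this commit); 'embeddings/my-style.pt' is a hypothetical path.
manager = TextualInversionManager(clip_embedder)
manager.load_textual_inversion('embeddings/my-style.pt')
token_ids = clip_embedder.tokenizer.encode('a photo of my-style', add_special_tokens=False)
token_ids = manager.expand_textual_inversion_token_ids(token_ids)
# ... truncate to max_length - 2, wrap with bos/eos, pad, and run the transformer to get `embeddings` ...
embeddings = manager.overwrite_textual_inversion_embeddings(token_ids, embeddings)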
class EmbeddingManager(nn.Module):
def __init__(
self,
@ -78,7 +207,7 @@ class EmbeddingManager(nn.Module):
): # using Stable Diffusion's CLIP encoder
self.is_clip = True
get_token_for_string = partial(
get_clip_token_id_for_string, embedder.tokenizer
)
get_embedding_for_tkn = partial(
get_embedding_for_clip_token,
@ -241,7 +370,7 @@ class EmbeddingManager(nn.Module):
# both will be stored in this dictionary
for term in self.string_to_param_dict.keys():
term = term.strip('<').strip('>')
self.concepts_loaded[term] = True
print(f'>> Current embedding manager terms: {", ".join(self.string_to_param_dict.keys())}')
def _expand_directories(self, paths:list[str]):
@ -255,55 +384,6 @@ class EmbeddingManager(nn.Module):
expanded_paths.append(os.path.join(root,name))
return [x for x in expanded_paths if os.path.splitext(x)[1] in ('.pt','.bin')]
def _load(self, ckpt_path, full=True):
scan_result = scan_file_path(ckpt_path)
if scan_result.infected_files == 1:
print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
print('### For your safety, InvokeAI will not load this embed.')
return
ckpt = torch.load(ckpt_path, map_location='cpu')
# Handle .pt textual inversion files
if 'string_to_token' in ckpt and 'string_to_param' in ckpt:
filename = os.path.basename(ckpt_path)
token_str = '.'.join(filename.split('.')[:-1]) # filename excluding extension
if len(ckpt["string_to_token"]) > 1:
print(f">> {ckpt_path} has >1 embedding, only the first will be used")
string_to_param_dict = ckpt['string_to_param']
embedding = list(string_to_param_dict.values())[0]
self.add_embedding(token_str, embedding, full)
# Handle .bin textual inversion files from Huggingface Concepts
# https://huggingface.co/sd-concepts-library
else:
for token_str in list(ckpt.keys()):
embedding = ckpt[token_str]
self.add_embedding(token_str, embedding, full)
def add_embedding(self, token_str, embedding, full):
if token_str in self.string_to_param_dict:
print(f">> Embedding manager refusing to overwrite already-loaded term '{token_str}'")
return
if not full:
embedding = embedding.half()
if len(embedding.shape) == 1:
embedding = embedding.unsqueeze(0)
num_tokens_added = self.embedder.tokenizer.add_tokens(token_str)
current_embeddings = self.embedder.transformer.resize_token_embeddings(None)
current_token_count = current_embeddings.num_embeddings
new_token_count = current_token_count + num_tokens_added
self.embedder.transformer.resize_token_embeddings(new_token_count)
token = get_clip_token_for_string(self.embedder.tokenizer, token_str)
self.string_to_token_dict[token_str] = token
self.string_to_param_dict[token_str] = torch.nn.Parameter(embedding)
def has_embedding_for_token(self, token_str):
return token_str in self.string_to_token_dict
def get_embedding_norms_squared(self):
all_params = torch.cat(

View File

@ -0,0 +1,43 @@
import unittest
import torch
from ldm.modules.embedding_manager import TextualInversionManager
class DummyClipEmbedder:
max_length = 77
bos_token_id = 49406
eos_token_id = 49407
class TextualInversionManagerTestCase(unittest.TestCase):
def test_construction(self):
tim = TextualInversionManager(DummyClipEmbedder())
def test_add_embedding(self):
tim = TextualInversionManager(DummyClipEmbedder())
test_embedding = torch.randn([1, 768])
test_embedding_name = "test"
token_id = tim.add_textual_inversion(test_embedding_name, test_embedding)
self.assertTrue(tim.has_textual_inversion(test_embedding_name))
textual_inversion = next(ti for ti in tim.textual_inversions if ti.token_id == token_id)
self.assertIsNotNone(textual_inversion)
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding))
self.assertEqual(textual_inversion.token_string, test_embedding_name)
self.assertEqual(textual_inversion.token_id, token_id)
def test_pad_tokens_list(self):
tim = TextualInversionManager(DummyClipEmbedder())
prompt_token_ids = [0, 1, 2]
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids(prompt_token_ids=prompt_token_ids)
self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
test_embedding = torch.randn([1, 768])
test_embedding_name = "test"
tim.add_textual_inversion("<token>",
self.assertRaises()
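The stub DummyClipEmbedder above exposes no pad_token_id, tokenizer or transformer yet, so here is a hedged, self-contained sketch of the padding behaviour for a multi-vector embedding. It registers a TextualInversion directly instead of going through add_textual_inversion, and the token id 49408 and the pad id are illustrative assumptions rather than values taken from this commit:
import torch
from ldm.modules.embedding_manager import TextualInversion, TextualInversionManager

class PaddedDummyClipEmbedder:
    max_length = 77
    bos_token_id = 49406
    eos_token_id = 49407
    pad_token_id = 49407  # assumed equal to eos here, purely for illustration

tim = TextualInversionManager(PaddedDummyClipEmbedder())
# bypass add_textual_inversion (which needs a real tokenizer/transformer) and register directly
tim.textual_inversions.append(TextualInversion(
    token_string='<token>', token_id=49408, embedding=torch.randn([2, 768])))
expanded = tim.expand_textual_inversion_token_ids([0, 49408, 1, 2])
# one pad token id is inserted directly after the 2-vector inversion token
assert expanded == [0, 49408, PaddedDummyClipEmbedder.pad_token_id, 1, 2]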