use TextualInversionManager in place of embeddings (wip, doesn't work)

Damian Stewart 2022-12-16 12:48:38 +01:00
parent 023df37eff
commit 664a6e9e14
9 changed files with 290 additions and 228 deletions
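In outline, the wiring this commit is working toward looks roughly like the sketch below. This is not code from the diff: `cond_stage_model` and `embedding_path` are stand-ins for the WeightedFrozenCLIPEmbedder that LatentDiffusion owns and the embeddings directory that Generate.set_model scans, and the concept name passed at the end is only an example.

import os
from ldm.modules.textual_inversion_manager import TextualInversionManager

def wire_textual_inversions(cond_stage_model, embedding_path: str) -> TextualInversionManager:
    # attach a manager to the frozen CLIP embedder (the circular wiring flagged in ddpm.py)
    manager = TextualInversionManager(cond_stage_model, full_precision=True)
    cond_stage_model.set_textual_inversion_manager(manager)

    # load every embedding file found under embedding_path, as Generate.set_model now does
    for root, _, files in os.walk(embedding_path):
        for name in files:
            manager.load_textual_inversion(os.path.join(root, name))
    print('>> Textual inversions available: ' + ', '.join(manager.get_all_trigger_strings()))

    # Hugging Face concepts referenced in a prompt are pulled in on demand
    manager.load_huggingface_concepts(['example-concept'])  # hypothetical concept name
    return manager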

View File

@@ -22,6 +22,7 @@ import skimage
 from omegaconf import OmegaConf
 import ldm.invoke.conditioning
+from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
 from ldm.invoke.generator.base import downsampling
 from PIL import Image, ImageOps
 from torch import nn
@@ -41,7 +42,6 @@ from ldm.invoke.conditioning import get_uc_and_c_and_ec
 from ldm.invoke.model_cache import ModelCache
 from ldm.invoke.seamless import configure_model_padding
 from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
-from ldm.invoke.concepts_lib import Concepts

 def fix_func(orig):
     if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
@@ -438,7 +438,7 @@ class Generate:
             self._set_sampler()

         # apply the concepts library to the prompt
-        prompt = self.concept_lib().replace_concepts_with_triggers(prompt, lambda concepts: self.load_concepts(concepts))
+        prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(prompt, lambda concepts: self.load_huggingface_concepts(concepts))

         # bit of a hack to change the cached sampler's karras threshold to
         # whatever the user asked for
@@ -862,19 +862,22 @@ class Generate:
         seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
         if self.embedding_path is not None:
-            self.model.embedding_manager.load(
-                self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
-            )
+            for root, _, files in os.walk(self.embedding_path):
+                for name in files:
+                    ti_path = os.path.join(root, name)
+                    self.model.textual_inversion_manager.load_textual_inversion(ti_path)
+            print(f'>> Textual inversions available: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}')

         self._set_sampler()
         self.model_name = model_name
         return self.model

-    def load_concepts(self,concepts:list[str]):
-        self.model.embedding_manager.load_concepts(concepts, self.precision=='float32' or self.precision=='autocast')
+    def load_huggingface_concepts(self, concepts:list[str]):
+        self.model.textual_inversion_manager.load_huggingface_concepts(concepts)

-    def concept_lib(self)->Concepts:
-        return self.model.embedding_manager.concepts_library
+    @property
+    def huggingface_concepts_library(self) -> HuggingFaceConceptsLibrary:
+        return self.model.textual_inversion_manager.hf_concepts_library

     def correct_colors(self,
                        image_list,

View File

@@ -16,7 +16,7 @@ from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_f
 from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
 from ldm.invoke.image_util import make_grid
 from ldm.invoke.log import write_log
-from ldm.invoke.concepts_lib import Concepts
+from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
 from omegaconf import OmegaConf
 from pathlib import Path
 import pyparsing
@@ -133,6 +133,10 @@ def main():
         main_loop(gen, opt)
     except KeyboardInterrupt:
         print("\ngoodbye!")
+    except Exception:
+        print(">> An error occurred:")
+        traceback.print_exc()

 # TODO: main_loop() has gotten busy. Needs to be refactored.
 def main_loop(gen, opt):
@@ -310,7 +314,7 @@ def main_loop(gen, opt):
             if use_prefix is not None:
                 prefix = use_prefix
             postprocessed = upscaled if upscaled else operation=='postprocess'
-            opt.prompt = gen.concept_lib().replace_triggers_with_concepts(opt.prompt or prompt_in) # to avoid the problem of non-unique concept triggers
+            opt.prompt = gen.huggingface_concepts_library.replace_triggers_with_concepts(opt.prompt or prompt_in) # to avoid the problem of non-unique concept triggers
             filename, formatted_dream_prompt = prepare_image_metadata(
                 opt,
                 prefix,
@@ -809,7 +813,8 @@ def add_embedding_terms(gen,completer):
     Called after setting the model, updates the autocompleter with
     any terms loaded by the embedding manager.
     '''
-    completer.add_embedding_terms(gen.model.embedding_manager.list_terms())
+    trigger_strings = gen.model.textual_inversion_manager.get_all_trigger_strings()
+    completer.add_embedding_terms(trigger_strings)

 def split_variations(variations_string) -> list:
     # shotgun parsing, woo

View File

@@ -12,7 +12,7 @@ from urllib import request, error as ul_error
 from huggingface_hub import HfFolder, hf_hub_url, ModelSearchArguments, ModelFilter, HfApi
 from ldm.invoke.globals import Globals

-class Concepts(object):
+class HuggingFaceConceptsLibrary(object):
     def __init__(self, root=None):
        '''
        Initialize the Concepts object. May optionally pass a root directory.

View File

@@ -231,7 +231,7 @@ def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedProm

 def _get_tokens_length(model, fragments: list[Fragment]):
     fragment_texts = [x.text for x in fragments]
-    tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
+    tokens = model.cond_stage_model.get_token_ids(fragment_texts, include_start_and_end_markers=False)
     return sum([len(x) for x in tokens])

View File

@@ -12,7 +12,7 @@ import os
 import re
 import atexit
 from ldm.invoke.args import Args
-from ldm.invoke.concepts_lib import Concepts
+from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
 from ldm.invoke.globals import Globals

 # ---------------readline utilities---------------------
@@ -276,7 +276,7 @@ class Completer(object):
     def _concept_completions(self, text, state):
         if self.concepts is None:
-            self.concepts = set(Concepts().list_concepts())
+            self.concepts = set(HuggingFaceConceptsLibrary().list_concepts())
             self.embedding_terms.update(self.concepts)
         partial = text[1:] # this removes the leading '<'

View File

@@ -22,6 +22,7 @@ from pytorch_lightning.utilities.distributed import rank_zero_only
 from omegaconf import ListConfig

 import urllib
+from ldm.modules.textual_inversion_manager import TextualInversionManager

 from ldm.util import (
     log_txt_as_img,
     exists,
@@ -678,6 +679,9 @@ class LatentDiffusion(DDPM):
             self.embedding_manager = self.instantiate_embedding_manager(
                 personalization_config, self.cond_stage_model
             )
+            self.textual_inversion_manager = TextualInversionManager(self.cond_stage_model, full_precision=True)
+            # this circular component dependency is gross and bad, needs to be rethought
+            self.cond_stage_model.set_textual_inversion_manager(self.textual_inversion_manager)

             self.emb_ckpt_counter = 0

View File

@@ -6,7 +6,7 @@ from torch import nn

 import sys

-from ldm.invoke.concepts_lib import Concepts
+from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
 from ldm.data.personalized import per_img_token_list
 from transformers import CLIPTokenizer
 from functools import partial
@@ -31,157 +31,6 @@ def get_embedding_for_clip_token_id(embedder, token_id):
     token_id = torch.tensor(token_id, dtype=torch.int)
     return embedder(token_id.unsqueeze(0))[0, 0]

-@dataclass
-class TextualInversion:
-    trigger_string: str
-    token_id: int
-    embedding: torch.Tensor
-
-    @property
-    def embedding_vector_length(self) -> int:
-        return self.embedding.shape[0]
-
-class TextualInversionManager():
-    def __init__(self, clip_embedder):
-        self.clip_embedder = clip_embedder
-        default_textual_inversions: list[TextualInversion] = []
-        self.textual_inversions = default_textual_inversions
-
-    def load_textual_inversion(self, ckpt_path, full_precision=True):
-
-        scan_result = scan_file_path(ckpt_path)
-        if scan_result.infected_files == 1:
-            print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
-            print('### For your safety, InvokeAI will not load this embed.')
-            return
-
-        ckpt = torch.load(ckpt_path, map_location='cpu')
-
-        # Handle .pt textual inversion files
-        if 'string_to_token' in ckpt and 'string_to_param' in ckpt:
-            filename = os.path.basename(ckpt_path)
-            token_str = '.'.join(filename.split('.')[:-1]) # filename excluding extension
-            if len(ckpt["string_to_token"]) > 1:
-                print(f">> {ckpt_path} has >1 embedding, only the first will be used")
-
-            string_to_param_dict = ckpt['string_to_param']
-            embedding = list(string_to_param_dict.values())[0]
-            self.add_textual_inversion(token_str, embedding, full_precision)
-
-        # Handle .bin textual inversion files from Huggingface Concepts
-        # https://huggingface.co/sd-concepts-library
-        else:
-            for token_str in list(ckpt.keys()):
-                embedding = ckpt[token_str]
-                self.add_textual_inversion(token_str, embedding, full_precision)
-
-    def add_textual_inversion(self, token_str, embedding) -> int:
-        """
-        Add a textual inversion to be recognised.
-        :param token_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
-        :param embedding: The actual embedding data that will be inserted into the conditioning at the point where the token_str appears.
-        :return: The token id for the added embedding, either existing or newly-added.
-        """
-        if token_str in [ti.trigger_string for ti in self.textual_inversions]:
-            print(f">> TextualInversionManager refusing to overwrite already-loaded token '{token_str}'")
-            return
-        if len(embedding.shape) == 1:
-            embedding = embedding.unsqueeze(0)
-        elif len(embedding.shape) > 2:
-            raise ValueError(f"embedding shape {embedding.shape} is incorrect - must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2")
-
-        existing_token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, token_str)
-        if existing_token_id == self.clip_embedder.tokenizer.unk_token_id:
-            num_tokens_added = self.clip_embedder.tokenizer.add_tokens(token_str)
-            current_embeddings = self.clip_embedder.transformer.resize_token_embeddings(None)
-            current_token_count = current_embeddings.num_embeddings
-            new_token_count = current_token_count + num_tokens_added
-            self.clip_embedder.transformer.resize_token_embeddings(new_token_count)
-
-        token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, token_str)
-        self.textual_inversions.append(TextualInversion(
-            trigger_string=token_str,
-            token_id=token_id,
-            embedding=embedding
-        ))
-        return token_id
-
-    def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
-        try:
-            ti = self.get_textual_inversion_for_trigger_string(trigger_string)
-            return ti is not None
-        except StopIteration:
-            return False
-
-    def get_textual_inversion_for_trigger_string(self, trigger_string: str) -> TextualInversion:
-        return next(ti for ti in self.textual_inversions if ti.trigger_string == trigger_string)
-
-    def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
-        return next(ti for ti in self.textual_inversions if ti.token_id == token_id)
-
-    def expand_textual_inversion_token_ids(self, prompt_token_ids: list[int]) -> list[int]:
-        """
-        Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.
-        :param prompt_token_ids: The prompt as a list of token ids (`int`s). Should not include bos and eos markers.
-        :param pad_token_id: The token id to use to pad out the list to account for textual inversion vector lengths >1.
-        :return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too
-                 long - caller is reponsible for truncating it if necessary and prepending/appending eos and bos token ids.
-        """
-        if prompt_token_ids[0] == self.clip_embedder.tokenizer.bos_token_id:
-            raise ValueError("prompt_token_ids must not start with bos_token_id")
-        if prompt_token_ids[-1] == self.clip_embedder.tokenizer.eos_token_id:
-            raise ValueError("prompt_token_ids must not end with eos_token_id")
-        textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
-        prompt_token_ids = prompt_token_ids[:]
-        for i, token_id in reversed(list(enumerate(prompt_token_ids))):
-            if token_id in textual_inversion_token_ids:
-                textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
-                for pad_idx in range(1, textual_inversion.embedding_vector_length):
-                    prompt_token_ids.insert(i+1, self.clip_embedder.tokenizer.pad_token_id)
-
-        return prompt_token_ids
-
-    def overwrite_textual_inversion_embeddings(self, prompt_token_ids: list[int], prompt_embeddings: torch.Tensor) -> torch.Tensor:
-        """
-        For each token id in prompt_token_ids that refers to a loaded textual inversion, overwrite the corresponding
-        row in `prompt_embeddings` with the textual inversion embedding. If the embedding has vector length >1, overwrite
-        subsequent rows in `prompt_embeddings` as well.
-        :param `prompt_token_ids`: Prompt token ids, already expanded to account for any textual inversions with vector lenght
-                                   >1 (call `expand_textual_inversion_token_ids()` to do this) and including bos and eos markers.
-        :param `prompt_embeddings`: Prompt embeddings tensor of shape with indices aligning to token ids in
-                                    `prompt_token_ids` (i.e., also already expanded).
-        :return: `The prompt_embeddings` tensor overwritten as appropriate with the textual inversion embeddings.
-        """
-        if prompt_embeddings.shape[0] != self.clip_embedder.max_length: # typically 77
-            raise ValueError(f"prompt_embeddings must have {self.clip_embedder.max_length} entries (has: {prompt_embeddings.shape[0]})")
-        if len(prompt_token_ids) > self.clip_embedder.max_length:
-            raise ValueError(f"prompt_token_ids is too long (has {len(prompt_token_ids)} token ids, should have {self.clip_embedder.max_length})")
-        if len(prompt_token_ids) < self.clip_embedder.max_length:
-            raise ValueError(f"prompt_token_ids is too short (has {len(prompt_token_ids)} token ids, it must be fully padded out to {self.clip_embedder.max_length} entries)")
-        if prompt_token_ids[0] != self.clip_embedder.tokenizer.bos_token_id or prompt_token_ids[-1] != self.clip_embedder.tokenizer.eos_token_id:
-            raise ValueError("prompt_token_ids must start with with bos token id and end with the eos token id")
-
-        textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
-        pad_token_id = self.clip_embedder.tokenizer.pad_token_id
-        overwritten_prompt_embeddings = prompt_embeddings.clone()
-        for i, token_id in enumerate(prompt_token_ids):
-            if token_id == pad_token_id:
-                continue
-            if token_id in textual_inversion_token_ids:
-                textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
-                end_index = min(i + textual_inversion.embedding_vector_length, self.clip_embedder.max_length-1)
-                count_to_overwrite = end_index - i
-                for j in range(0, count_to_overwrite):
-                    # only overwrite the textual inversion token id or the padding token id
-                    if prompt_token_ids[i+j] != pad_token_id and prompt_token_ids[i+j] != token_id:
-                        break
-                    overwritten_prompt_embeddings[i+j] = textual_inversion.embedding[j]
-
-        return overwritten_prompt_embeddings

 class EmbeddingManager(nn.Module):
     def __init__(
@@ -197,8 +46,7 @@ class EmbeddingManager(nn.Module):
         super().__init__()

         self.embedder = embedder
-        self.concepts_library=Concepts()
-        self.concepts_loaded = dict()
+        self.concepts_library=HuggingFaceConceptsLibrary()

         self.string_to_token_dict = {}
         self.string_to_param_dict = nn.ParameterDict()
@@ -349,22 +197,6 @@ class EmbeddingManager(nn.Module):
             ckpt_path,
         )

-    def load_concepts(self, concepts:list[str], full=True):
-        bin_files = list()
-        for concept_name in concepts:
-            if concept_name in self.concepts_loaded:
-                continue
-            else:
-                bin_file = self.concepts_library.get_concept_model_path(concept_name)
-                if not bin_file:
-                    continue
-                bin_files.append(bin_file)
-                self.concepts_loaded[concept_name]=True
-        self.load(bin_files, full)
-
-    def list_terms(self) -> list[str]:
-        return self.concepts_loaded.keys()
-
     def load(self, ckpt_paths, full=True):
         if len(ckpt_paths) == 0:
             return

View File

@@ -9,6 +9,7 @@ from transformers import CLIPTokenizer, CLIPTextModel
 import kornia
 from ldm.invoke.devices import choose_torch_device
 from ldm.invoke.globals import Globals
+#from ldm.modules.textual_inversion_manager import TextualInversionManager

 from ldm.modules.x_transformer import (
     Encoder,
@@ -465,7 +466,12 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
     fragment_weights_key = "fragment_weights"
     return_tokens_key = "return_tokens"

+    def set_textual_inversion_manager(self, manager): #TextualInversionManager):
+        # TODO all of the weighting and expanding stuff needs to be moved out of this class
+        self.textual_inversion_manager = manager
+
     def forward(self, text: list, **kwargs):
+        # TODO all of the weighting and expanding stuff needs to be moved out of this class
         '''
         :param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different
@@ -560,19 +566,42 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
         else:
             return batch_z

-    def get_tokens(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
-        tokens = self.tokenizer(
+    def get_token_ids(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
+        """
+        Convert a list of strings like `["a cat", "sitting", "on a mat"]` into a list of lists of token ids like
+        `[[bos, 0, 1, eos], [bos, 2, eos], [bos, 3, 0, 4, eos]]`. bos/eos markers are skipped if
+        `include_start_and_end_markers` is `False`. Each list will be restricted to the maximum permitted length
+        (typically 75 tokens + eos/bos markers).
+
+        :param fragments: The strings to convert.
+        :param include_start_and_end_markers:
+        :return:
+        """
+        # for args documentation see ENCODE_KWARGS_DOCSTRING in tokenization_utils_base.py (in `transformers` lib)
+        token_ids_list = self.tokenizer(
             fragments,
             truncation=True,
             max_length=self.max_length,
             return_overflowing_tokens=False,
             padding='do_not_pad',
-            return_tensors=None, # just give me a list of ints
+            return_tensors=None, # just give me lists of ints
         )['input_ids']
-        if include_start_and_end_markers:
-            return tokens
-        else:
-            return [x[1:-1] for x in tokens]
+
+        result = []
+        for token_ids in token_ids_list:
+            # trim eos/bos
+            token_ids = token_ids[1:-1]
+            # pad for textual inversions with vector length >1
+            token_ids = self.textual_inversion_manager.expand_textual_inversion_token_ids(token_ids)
+            # restrict length to max_length-2 (leaving room for bos/eos)
+            token_ids = token_ids[0:self.max_length - 2]
+            # add back eos/bos if requested
+            if include_start_and_end_markers:
+                token_ids = [self.tokenizer.bos_token_id] + token_ids + [self.tokenizer.eos_token_id]
+            result.append(token_ids)
+        return result

     @classmethod
@@ -597,56 +626,60 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
         if len(fragments) == 0 and len(weights) == 0:
             fragments = ['']
             weights = [1]

-        item_encodings = self.tokenizer(
-            fragments,
-            truncation=True,
-            max_length=self.max_length,
-            return_overflowing_tokens=True,
-            padding='do_not_pad',
-            return_tensors=None, # just give me a list of ints
-        )['input_ids']
-        all_tokens = []
+        per_fragment_token_ids = self.get_token_ids(fragments, include_start_and_end_markers=False)
+        all_token_ids = []
         per_token_weights = []
         #print("all fragments:", fragments, weights)
-        for index, fragment in enumerate(item_encodings):
-            weight = weights[index]
+        for index, fragment in enumerate(per_fragment_token_ids):
+            weight = float(weights[index])
             #print("processing fragment", fragment, weight)
-            fragment_tokens = item_encodings[index]
-            #print("fragment", fragment, "processed to", fragment_tokens)
-            # trim bos and eos markers before appending
-            all_tokens.extend(fragment_tokens[1:-1])
-            per_token_weights.extend([weight] * (len(fragment_tokens) - 2))
+            this_fragment_token_ids = per_fragment_token_ids[index]
+            #print("fragment", fragment, "processed to", this_fragment_token_ids)
+            # append
+            all_token_ids += this_fragment_token_ids
+            # fill out weights tensor with one float per token
+            per_token_weights += [weight] * len(this_fragment_token_ids)

-        if (len(all_tokens) + 2) > self.max_length:
-            excess_token_count = (len(all_tokens) + 2) - self.max_length
+        # leave room for bos/eos
+        if len(all_token_ids) > self.max_length - 2:
+            excess_token_count = len(all_token_ids) - (self.max_length - 2)
+            # TODO build nice description string of how the truncation was applied
+            # this should be done by calling self.tokenizer.convert_ids_to_tokens() then passing the result to
+            # self.tokenizer.convert_tokens_to_string() for the token_ids on each side of the truncation limit.
             print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
-            all_tokens = all_tokens[:self.max_length - 2]
-            per_token_weights = per_token_weights[:self.max_length - 2]
+            all_token_ids = all_token_ids[0:self.max_length - 2]
+            per_token_weights = per_token_weights[0:self.max_length - 2]

         # pad out to a 77-entry array: [eos_token, <prompt tokens>, eos_token, ..., eos_token]
         # (77 = self.max_length)
-        pad_length = self.max_length - 1 - len(all_tokens)
-        all_tokens.insert(0, self.tokenizer.bos_token_id)
-        all_tokens.extend([self.tokenizer.eos_token_id] * pad_length)
-        per_token_weights.insert(0, 1)
-        per_token_weights.extend([1] * pad_length)
+        all_token_ids = [self.tokenizer.bos_token_id] + all_token_ids + [self.tokenizer.eos_token_id]
+        per_token_weights = [1.0] + per_token_weights + [1.0]
+        pad_length = self.max_length - len(all_token_ids)
+        all_token_ids += [self.tokenizer.eos_token_id] * pad_length
+        per_token_weights += [1.0] * pad_length

-        all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device)
+        all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long).to(self.device)
         per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device)
-        #print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}")
-        return all_tokens_tensor, per_token_weights_tensor
+        #print(f"assembled all_token_ids_tensor with shape {all_token_ids_tensor.shape}")
+        return all_token_ids_tensor, per_token_weights_tensor

-    def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
+    def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
         '''
         Build a tensor representing the passed-in tokens, each of which has a weight.
-        :param tokens: A tensor of shape (77) containing token ids (integers)
+        :param token_ids: A tensor of shape (77) containing token ids (integers)
         :param per_token_weights: A tensor of shape (77) containing weights (floats)
         :param method: Whether to multiply the whole feature vector for each token or just its distance from an "empty" feature vector
         :param kwargs: passed on to self.transformer()
         :return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings.
         '''
         #print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
-        z = self.transformer(input_ids=tokens.unsqueeze(0), **kwargs)
+        if token_ids.shape[0] != self.max_length:
+            raise ValueError(f"token_ids has shape {token_ids.shape} - expected [{self.max_length}]")
+
+        z = self.transformer(input_ids=token_ids.unsqueeze(0), **kwargs)
+        assert(z.shape[0] == 1)
+        new_z0 = self.textual_inversion_manager.overwrite_textual_inversion_embeddings(token_ids, z[0])
+        z[0] = new_z0
         batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)

         if weight_delta_from_empty:
@@ -660,7 +693,7 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
             z_delta_from_empty = z - empty_z
             weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)

-            weighted_z_delta_from_empty = (weighted_z-empty_z)
+            #weighted_z_delta_from_empty = (weighted_z-empty_z)
             #print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )

             #print("using empty-delta method, first 5 rows:")

View File

@@ -0,0 +1,185 @@
+import os
+from typing import Union
+
+import torch
+from attr import dataclass
+from picklescan.scanner import scan_file_path
+
+from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
+from ldm.modules.embedding_manager import get_clip_token_id_for_string
+from ldm.modules.encoders.modules import FrozenCLIPEmbedder
+
+
+@dataclass
+class TextualInversion:
+    trigger_string: str
+    token_id: int
+    embedding: torch.Tensor
+
+    @property
+    def embedding_vector_length(self) -> int:
+        return self.embedding.shape[0]
+
+
+class TextualInversionManager():
+    def __init__(self, clip_embedder: FrozenCLIPEmbedder, full_precision: bool):
+        self.clip_embedder = clip_embedder
+        self.full_precision = full_precision
+        self.hf_concepts_library = HuggingFaceConceptsLibrary()
+        default_textual_inversions: list[TextualInversion] = []
+        self.textual_inversions = default_textual_inversions
+
+    def load_huggingface_concepts(self, concepts: list[str]):
+        for concept_name in concepts:
+            if concept_name in self.hf_concepts_library.concepts_loaded:
+                continue
+            bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
+            if not bin_file:
+                continue
+            self.load_textual_inversion(bin_file)
+            self.hf_concepts_library.concepts_loaded[concept_name] = True
+
+    def get_all_trigger_strings(self) -> list[str]:
+        return [ti.trigger_string for ti in self.textual_inversions]
+
+    def load_textual_inversion(self, ckpt_path):
+        scan_result = scan_file_path(ckpt_path)
+        if scan_result.infected_files == 1:
+            print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
+            print('### For your safety, InvokeAI will not load this embed.')
+            return
+
+        ckpt = torch.load(ckpt_path, map_location='cpu')
+
+        # Handle .pt textual inversion files
+        if 'string_to_token' in ckpt and 'string_to_param' in ckpt:
+            filename = os.path.basename(ckpt_path)
+            trigger_str = '.'.join(filename.split('.')[:-1]) # filename excluding extension
+            if len(ckpt["string_to_token"]) > 1:
+                print(f">> {ckpt_path} has >1 embedding, only the first will be used")
+
+            string_to_param_dict = ckpt['string_to_param']
+            embedding = list(string_to_param_dict.values())[0]
+            self.add_textual_inversion(trigger_str, embedding)
+
+        # Handle .bin textual inversion files from Huggingface Concepts
+        # https://huggingface.co/sd-concepts-library
+        else:
+            for trigger_str in list(ckpt.keys()):
+                embedding = ckpt[trigger_str]
+                self.add_textual_inversion(trigger_str, embedding)
+
+    def add_textual_inversion(self, trigger_str, embedding) -> int:
+        """
+        Add a textual inversion to be recognised.
+        :param trigger_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
+        :param embedding: The actual embedding data that will be inserted into the conditioning at the point where the trigger_str appears.
+        :return: The token id for the added embedding, either existing or newly-added.
+        """
+        if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
+            print(f">> TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'")
+            return
+        if not self.full_precision:
+            embedding = embedding.half()
+        if len(embedding.shape) == 1:
+            embedding = embedding.unsqueeze(0)
+        elif len(embedding.shape) > 2:
+            raise ValueError(f"embedding shape {embedding.shape} is incorrect - must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2")
+
+        existing_token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, trigger_str)
+        if existing_token_id == self.clip_embedder.tokenizer.unk_token_id:
+            num_tokens_added = self.clip_embedder.tokenizer.add_tokens(trigger_str)
+            current_embeddings = self.clip_embedder.transformer.resize_token_embeddings(None)
+            current_token_count = current_embeddings.num_embeddings
+            new_token_count = current_token_count + num_tokens_added
+            self.clip_embedder.transformer.resize_token_embeddings(new_token_count)
+
+        token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, trigger_str)
+        self.textual_inversions.append(TextualInversion(
+            trigger_string=trigger_str,
+            token_id=token_id,
+            embedding=embedding
+        ))
+        return token_id
+
+    def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
+        try:
+            ti = self.get_textual_inversion_for_trigger_string(trigger_string)
+            return ti is not None
+        except StopIteration:
+            return False
+
+    def get_textual_inversion_for_trigger_string(self, trigger_string: str) -> TextualInversion:
+        return next(ti for ti in self.textual_inversions if ti.trigger_string == trigger_string)
+
+    def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
+        return next(ti for ti in self.textual_inversions if ti.token_id == token_id)
+
+    def expand_textual_inversion_token_ids(self, prompt_token_ids: list[int]) -> list[int]:
+        """
+        Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.
+
+        :param prompt_token_ids: The prompt as a list of token ids (`int`s). Should not include bos and eos markers.
+        :return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too
+                 long - caller is responsible for truncating it if necessary and prepending/appending eos and bos token ids.
+        """
+        if len(prompt_token_ids) == 0:
+            return prompt_token_ids
+
+        if prompt_token_ids[0] == self.clip_embedder.tokenizer.bos_token_id:
+            raise ValueError("prompt_token_ids must not start with bos_token_id")
+        if prompt_token_ids[-1] == self.clip_embedder.tokenizer.eos_token_id:
+            raise ValueError("prompt_token_ids must not end with eos_token_id")
+        textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
+        prompt_token_ids = prompt_token_ids[:]
+        for i, token_id in reversed(list(enumerate(prompt_token_ids))):
+            if token_id in textual_inversion_token_ids:
+                textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
+                for pad_idx in range(1, textual_inversion.embedding_vector_length):
+                    prompt_token_ids.insert(i+1, self.clip_embedder.tokenizer.pad_token_id)
+
+        return prompt_token_ids
+
+    def overwrite_textual_inversion_embeddings(self, prompt_token_ids: Union[torch.Tensor, list[int]], prompt_embeddings: torch.Tensor) -> torch.Tensor:
+        """
+        For each token id in prompt_token_ids that refers to a loaded textual inversion, overwrite the corresponding
+        row in `prompt_embeddings` with the textual inversion embedding. If the embedding has vector length >1, overwrite
+        subsequent rows in `prompt_embeddings` as well.
+
+        :param `prompt_token_ids`: Prompt token ids, already expanded to account for any textual inversions with vector length
+                                   >1 (call `expand_textual_inversion_token_ids()` to do this), padded to max length, and including bos and eos markers.
+        :param `prompt_embeddings`: Prompt embeddings tensor with indices aligning to the token ids in
+                                    `prompt_token_ids` (i.e., also already expanded).
+        :return: The `prompt_embeddings` tensor overwritten as appropriate with the textual inversion embeddings.
+        """
+        if type(prompt_token_ids) is torch.Tensor:
+            if prompt_token_ids.shape != torch.Size([self.clip_embedder.max_length]):
+                raise ValueError(f"prompt_token_ids must be a list of length {self.clip_embedder.max_length} or a tensor of shape [{self.clip_embedder.max_length}]")
+            prompt_token_ids = list(prompt_token_ids.cpu().numpy())
+        if prompt_embeddings.shape[0] != self.clip_embedder.max_length: # typically 77
+            raise ValueError(f"prompt_embeddings must have {self.clip_embedder.max_length} entries (has: {prompt_embeddings.shape[0]})")
+        if len(prompt_token_ids) > self.clip_embedder.max_length:
+            raise ValueError(f"prompt_token_ids is too long (has {len(prompt_token_ids)} token ids, should have {self.clip_embedder.max_length})")
+        if len(prompt_token_ids) < self.clip_embedder.max_length:
+            raise ValueError(f"prompt_token_ids is too short (has {len(prompt_token_ids)} token ids, it must be fully padded out to {self.clip_embedder.max_length} entries)")
+        if prompt_token_ids[0] != self.clip_embedder.tokenizer.bos_token_id or prompt_token_ids[-1] != self.clip_embedder.tokenizer.eos_token_id:
+            raise ValueError("prompt_token_ids must start with the bos token id and end with the eos token id")
+
+        textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
+        pad_token_id = self.clip_embedder.tokenizer.pad_token_id
+        overwritten_prompt_embeddings = prompt_embeddings.clone()
+        for i, token_id in enumerate(prompt_token_ids):
+            if token_id == pad_token_id:
+                continue
+            if token_id in textual_inversion_token_ids:
+                textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
+                end_index = min(i + textual_inversion.embedding_vector_length, self.clip_embedder.max_length - 1)
+                count_to_overwrite = end_index - i
+                for j in range(0, count_to_overwrite):
+                    # only overwrite the textual inversion token id or the padding token id
+                    if prompt_token_ids[i+j] != pad_token_id and prompt_token_ids[i+j] != token_id:
+                        break
+                    overwritten_prompt_embeddings[i+j] = textual_inversion.embedding[j]
+
+        return overwritten_prompt_embeddings
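
As a rough illustration of the expand/overwrite contract the new manager defines (not part of the commit): a trigger whose embedding has vector length 3 first gets pad token ids inserted after it, and the matching rows of the CLIP output are then overwritten with the embedding rows. The token ids below are illustrative (49408 is a made-up trigger id), the 768-dim width matches SD1's CLIP, and the real overwrite_textual_inversion_embeddings additionally expects the full bos/eos-wrapped, 77-entry padded sequence.

import torch

PAD_ID = 49407           # CLIP uses the same id for eos and pad
TRIGGER_ID = 49408       # hypothetical id the tokenizer assigned to the trigger string
ti_embedding = torch.randn(3, 768)   # a textual inversion with embedding_vector_length == 3

# prompt "a <trigger> cat" as token ids, without bos/eos (as expand_* expects)
prompt_ids = [320, TRIGGER_ID, 2368]

# step 1 - expand: insert one pad id per extra embedding row, right after the trigger id
expanded = []
for tid in prompt_ids:
    expanded.append(tid)
    if tid == TRIGGER_ID:
        expanded += [PAD_ID] * (ti_embedding.shape[0] - 1)
assert expanded == [320, TRIGGER_ID, PAD_ID, PAD_ID, 2368]

# step 2 - overwrite: once the transformer has produced one 768-dim row per token,
# replace the trigger row and its padding rows with the textual inversion rows
token_embeddings = torch.zeros(len(expanded), 768)
for i, tid in enumerate(expanded):
    if tid == TRIGGER_ID:
        token_embeddings[i:i + ti_embedding.shape[0]] = ti_embedding
assert torch.equal(token_embeddings[1:4], ti_embedding)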