mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
use TextualInversionManager in place of embeddings (wip, doesn't work)
This commit is contained in:
parent
023df37eff
commit
664a6e9e14
@ -22,6 +22,7 @@ import skimage
|
|||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
import ldm.invoke.conditioning
|
import ldm.invoke.conditioning
|
||||||
|
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
|
||||||
from ldm.invoke.generator.base import downsampling
|
from ldm.invoke.generator.base import downsampling
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -41,7 +42,6 @@ from ldm.invoke.conditioning import get_uc_and_c_and_ec
|
|||||||
from ldm.invoke.model_cache import ModelCache
|
from ldm.invoke.model_cache import ModelCache
|
||||||
from ldm.invoke.seamless import configure_model_padding
|
from ldm.invoke.seamless import configure_model_padding
|
||||||
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
|
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
|
||||||
from ldm.invoke.concepts_lib import Concepts
|
|
||||||
|
|
||||||
def fix_func(orig):
|
def fix_func(orig):
|
||||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||||
@ -438,7 +438,7 @@ class Generate:
|
|||||||
self._set_sampler()
|
self._set_sampler()
|
||||||
|
|
||||||
# apply the concepts library to the prompt
|
# apply the concepts library to the prompt
|
||||||
prompt = self.concept_lib().replace_concepts_with_triggers(prompt, lambda concepts: self.load_concepts(concepts))
|
prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(prompt, lambda concepts: self.load_huggingface_concepts(concepts))
|
||||||
|
|
||||||
# bit of a hack to change the cached sampler's karras threshold to
|
# bit of a hack to change the cached sampler's karras threshold to
|
||||||
# whatever the user asked for
|
# whatever the user asked for
|
||||||
@ -862,19 +862,22 @@ class Generate:
|
|||||||
|
|
||||||
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
|
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
|
||||||
if self.embedding_path is not None:
|
if self.embedding_path is not None:
|
||||||
self.model.embedding_manager.load(
|
for root, _, files in os.walk(self.embedding_path):
|
||||||
self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
|
for name in files:
|
||||||
)
|
ti_path = os.path.join(root, name)
|
||||||
|
self.model.textual_inversion_manager.load_textual_inversion(ti_path)
|
||||||
|
print(f'>> Textual inversions available: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}')
|
||||||
|
|
||||||
self._set_sampler()
|
self._set_sampler()
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def load_concepts(self,concepts:list[str]):
|
def load_huggingface_concepts(self, concepts:list[str]):
|
||||||
self.model.embedding_manager.load_concepts(concepts, self.precision=='float32' or self.precision=='autocast')
|
self.model.textual_inversion_manager.load_huggingface_concepts(concepts)
|
||||||
|
|
||||||
def concept_lib(self)->Concepts:
|
@property
|
||||||
return self.model.embedding_manager.concepts_library
|
def huggingface_concepts_library(self) -> HuggingFaceConceptsLibrary:
|
||||||
|
return self.model.textual_inversion_manager.hf_concepts_library
|
||||||
|
|
||||||
def correct_colors(self,
|
def correct_colors(self,
|
||||||
image_list,
|
image_list,
|
||||||
|
@ -16,7 +16,7 @@ from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_f
|
|||||||
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
|
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
|
||||||
from ldm.invoke.image_util import make_grid
|
from ldm.invoke.image_util import make_grid
|
||||||
from ldm.invoke.log import write_log
|
from ldm.invoke.log import write_log
|
||||||
from ldm.invoke.concepts_lib import Concepts
|
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import pyparsing
|
import pyparsing
|
||||||
@ -133,6 +133,10 @@ def main():
|
|||||||
main_loop(gen, opt)
|
main_loop(gen, opt)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("\ngoodbye!")
|
print("\ngoodbye!")
|
||||||
|
except Exception:
|
||||||
|
print(">> An error occurred:")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
# TODO: main_loop() has gotten busy. Needs to be refactored.
|
# TODO: main_loop() has gotten busy. Needs to be refactored.
|
||||||
def main_loop(gen, opt):
|
def main_loop(gen, opt):
|
||||||
@ -310,7 +314,7 @@ def main_loop(gen, opt):
|
|||||||
if use_prefix is not None:
|
if use_prefix is not None:
|
||||||
prefix = use_prefix
|
prefix = use_prefix
|
||||||
postprocessed = upscaled if upscaled else operation=='postprocess'
|
postprocessed = upscaled if upscaled else operation=='postprocess'
|
||||||
opt.prompt = gen.concept_lib().replace_triggers_with_concepts(opt.prompt or prompt_in) # to avoid the problem of non-unique concept triggers
|
opt.prompt = gen.huggingface_concepts_library.replace_triggers_with_concepts(opt.prompt or prompt_in) # to avoid the problem of non-unique concept triggers
|
||||||
filename, formatted_dream_prompt = prepare_image_metadata(
|
filename, formatted_dream_prompt = prepare_image_metadata(
|
||||||
opt,
|
opt,
|
||||||
prefix,
|
prefix,
|
||||||
@ -809,7 +813,8 @@ def add_embedding_terms(gen,completer):
|
|||||||
Called after setting the model, updates the autocompleter with
|
Called after setting the model, updates the autocompleter with
|
||||||
any terms loaded by the embedding manager.
|
any terms loaded by the embedding manager.
|
||||||
'''
|
'''
|
||||||
completer.add_embedding_terms(gen.model.embedding_manager.list_terms())
|
trigger_strings = gen.model.textual_inversion_manager.get_all_trigger_strings()
|
||||||
|
completer.add_embedding_terms(trigger_strings)
|
||||||
|
|
||||||
def split_variations(variations_string) -> list:
|
def split_variations(variations_string) -> list:
|
||||||
# shotgun parsing, woo
|
# shotgun parsing, woo
|
||||||
|
@ -12,7 +12,7 @@ from urllib import request, error as ul_error
|
|||||||
from huggingface_hub import HfFolder, hf_hub_url, ModelSearchArguments, ModelFilter, HfApi
|
from huggingface_hub import HfFolder, hf_hub_url, ModelSearchArguments, ModelFilter, HfApi
|
||||||
from ldm.invoke.globals import Globals
|
from ldm.invoke.globals import Globals
|
||||||
|
|
||||||
class Concepts(object):
|
class HuggingFaceConceptsLibrary(object):
|
||||||
def __init__(self, root=None):
|
def __init__(self, root=None):
|
||||||
'''
|
'''
|
||||||
Initialize the Concepts object. May optionally pass a root directory.
|
Initialize the Concepts object. May optionally pass a root directory.
|
||||||
@ -116,11 +116,11 @@ class Concepts(object):
|
|||||||
self.download_concept(concept_name)
|
self.download_concept(concept_name)
|
||||||
path = os.path.join(self._concept_path(concept_name), file_name)
|
path = os.path.join(self._concept_path(concept_name), file_name)
|
||||||
return path if os.path.exists(path) else None
|
return path if os.path.exists(path) else None
|
||||||
|
|
||||||
def concept_is_downloaded(self, concept_name)->bool:
|
def concept_is_downloaded(self, concept_name)->bool:
|
||||||
concept_directory = self._concept_path(concept_name)
|
concept_directory = self._concept_path(concept_name)
|
||||||
return os.path.exists(concept_directory)
|
return os.path.exists(concept_directory)
|
||||||
|
|
||||||
def download_concept(self,concept_name)->bool:
|
def download_concept(self,concept_name)->bool:
|
||||||
repo_id = self._concept_id(concept_name)
|
repo_id = self._concept_id(concept_name)
|
||||||
dest = self._concept_path(concept_name)
|
dest = self._concept_path(concept_name)
|
||||||
@ -133,7 +133,7 @@ class Concepts(object):
|
|||||||
|
|
||||||
os.makedirs(dest, exist_ok=True)
|
os.makedirs(dest, exist_ok=True)
|
||||||
succeeded = True
|
succeeded = True
|
||||||
|
|
||||||
bytes = 0
|
bytes = 0
|
||||||
def tally_download_size(chunk, size, total):
|
def tally_download_size(chunk, size, total):
|
||||||
nonlocal bytes
|
nonlocal bytes
|
||||||
|
@ -231,7 +231,7 @@ def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedProm
|
|||||||
|
|
||||||
def _get_tokens_length(model, fragments: list[Fragment]):
|
def _get_tokens_length(model, fragments: list[Fragment]):
|
||||||
fragment_texts = [x.text for x in fragments]
|
fragment_texts = [x.text for x in fragments]
|
||||||
tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
|
tokens = model.cond_stage_model.get_token_ids(fragment_texts, include_start_and_end_markers=False)
|
||||||
return sum([len(x) for x in tokens])
|
return sum([len(x) for x in tokens])
|
||||||
|
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import atexit
|
import atexit
|
||||||
from ldm.invoke.args import Args
|
from ldm.invoke.args import Args
|
||||||
from ldm.invoke.concepts_lib import Concepts
|
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
|
||||||
from ldm.invoke.globals import Globals
|
from ldm.invoke.globals import Globals
|
||||||
|
|
||||||
# ---------------readline utilities---------------------
|
# ---------------readline utilities---------------------
|
||||||
@ -276,7 +276,7 @@ class Completer(object):
|
|||||||
|
|
||||||
def _concept_completions(self, text, state):
|
def _concept_completions(self, text, state):
|
||||||
if self.concepts is None:
|
if self.concepts is None:
|
||||||
self.concepts = set(Concepts().list_concepts())
|
self.concepts = set(HuggingFaceConceptsLibrary().list_concepts())
|
||||||
self.embedding_terms.update(self.concepts)
|
self.embedding_terms.update(self.concepts)
|
||||||
|
|
||||||
partial = text[1:] # this removes the leading '<'
|
partial = text[1:] # this removes the leading '<'
|
||||||
|
@ -22,6 +22,7 @@ from pytorch_lightning.utilities.distributed import rank_zero_only
|
|||||||
from omegaconf import ListConfig
|
from omegaconf import ListConfig
|
||||||
import urllib
|
import urllib
|
||||||
|
|
||||||
|
from ldm.modules.textual_inversion_manager import TextualInversionManager
|
||||||
from ldm.util import (
|
from ldm.util import (
|
||||||
log_txt_as_img,
|
log_txt_as_img,
|
||||||
exists,
|
exists,
|
||||||
@ -678,6 +679,9 @@ class LatentDiffusion(DDPM):
|
|||||||
self.embedding_manager = self.instantiate_embedding_manager(
|
self.embedding_manager = self.instantiate_embedding_manager(
|
||||||
personalization_config, self.cond_stage_model
|
personalization_config, self.cond_stage_model
|
||||||
)
|
)
|
||||||
|
self.textual_inversion_manager = TextualInversionManager(self.cond_stage_model, full_precision=True)
|
||||||
|
# this circular component dependency is gross and bad, needs to be rethought
|
||||||
|
self.cond_stage_model.set_textual_inversion_manager(self.textual_inversion_manager)
|
||||||
|
|
||||||
self.emb_ckpt_counter = 0
|
self.emb_ckpt_counter = 0
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ from torch import nn
|
|||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from ldm.invoke.concepts_lib import Concepts
|
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
|
||||||
from ldm.data.personalized import per_img_token_list
|
from ldm.data.personalized import per_img_token_list
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@ -31,157 +31,6 @@ def get_embedding_for_clip_token_id(embedder, token_id):
|
|||||||
token_id = torch.tensor(token_id, dtype=torch.int)
|
token_id = torch.tensor(token_id, dtype=torch.int)
|
||||||
return embedder(token_id.unsqueeze(0))[0, 0]
|
return embedder(token_id.unsqueeze(0))[0, 0]
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TextualInversion:
|
|
||||||
trigger_string: str
|
|
||||||
token_id: int
|
|
||||||
embedding: torch.Tensor
|
|
||||||
|
|
||||||
@property
|
|
||||||
def embedding_vector_length(self) -> int:
|
|
||||||
return self.embedding.shape[0]
|
|
||||||
|
|
||||||
class TextualInversionManager():
|
|
||||||
def __init__(self, clip_embedder):
|
|
||||||
self.clip_embedder = clip_embedder
|
|
||||||
default_textual_inversions: list[TextualInversion] = []
|
|
||||||
self.textual_inversions = default_textual_inversions
|
|
||||||
|
|
||||||
def load_textual_inversion(self, ckpt_path, full_precision=True):
|
|
||||||
|
|
||||||
scan_result = scan_file_path(ckpt_path)
|
|
||||||
if scan_result.infected_files == 1:
|
|
||||||
print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
|
|
||||||
print('### For your safety, InvokeAI will not load this embed.')
|
|
||||||
return
|
|
||||||
|
|
||||||
ckpt = torch.load(ckpt_path, map_location='cpu')
|
|
||||||
|
|
||||||
# Handle .pt textual inversion files
|
|
||||||
if 'string_to_token' in ckpt and 'string_to_param' in ckpt:
|
|
||||||
filename = os.path.basename(ckpt_path)
|
|
||||||
token_str = '.'.join(filename.split('.')[:-1]) # filename excluding extension
|
|
||||||
if len(ckpt["string_to_token"]) > 1:
|
|
||||||
print(f">> {ckpt_path} has >1 embedding, only the first will be used")
|
|
||||||
|
|
||||||
string_to_param_dict = ckpt['string_to_param']
|
|
||||||
embedding = list(string_to_param_dict.values())[0]
|
|
||||||
self.add_textual_inversion(token_str, embedding, full_precision)
|
|
||||||
|
|
||||||
# Handle .bin textual inversion files from Huggingface Concepts
|
|
||||||
# https://huggingface.co/sd-concepts-library
|
|
||||||
else:
|
|
||||||
for token_str in list(ckpt.keys()):
|
|
||||||
embedding = ckpt[token_str]
|
|
||||||
self.add_textual_inversion(token_str, embedding, full_precision)
|
|
||||||
|
|
||||||
def add_textual_inversion(self, token_str, embedding) -> int:
|
|
||||||
"""
|
|
||||||
Add a textual inversion to be recognised.
|
|
||||||
:param token_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
|
|
||||||
:param embedding: The actual embedding data that will be inserted into the conditioning at the point where the token_str appears.
|
|
||||||
:return: The token id for the added embedding, either existing or newly-added.
|
|
||||||
"""
|
|
||||||
if token_str in [ti.trigger_string for ti in self.textual_inversions]:
|
|
||||||
print(f">> TextualInversionManager refusing to overwrite already-loaded token '{token_str}'")
|
|
||||||
return
|
|
||||||
if len(embedding.shape) == 1:
|
|
||||||
embedding = embedding.unsqueeze(0)
|
|
||||||
elif len(embedding.shape) > 2:
|
|
||||||
raise ValueError(f"embedding shape {embedding.shape} is incorrect - must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2")
|
|
||||||
|
|
||||||
existing_token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, token_str)
|
|
||||||
if existing_token_id == self.clip_embedder.tokenizer.unk_token_id:
|
|
||||||
num_tokens_added = self.clip_embedder.tokenizer.add_tokens(token_str)
|
|
||||||
current_embeddings = self.clip_embedder.transformer.resize_token_embeddings(None)
|
|
||||||
current_token_count = current_embeddings.num_embeddings
|
|
||||||
new_token_count = current_token_count + num_tokens_added
|
|
||||||
self.clip_embedder.transformer.resize_token_embeddings(new_token_count)
|
|
||||||
|
|
||||||
token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, token_str)
|
|
||||||
self.textual_inversions.append(TextualInversion(
|
|
||||||
trigger_string=token_str,
|
|
||||||
token_id=token_id,
|
|
||||||
embedding=embedding
|
|
||||||
))
|
|
||||||
return token_id
|
|
||||||
|
|
||||||
def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
|
|
||||||
try:
|
|
||||||
ti = self.get_textual_inversion_for_trigger_string(trigger_string)
|
|
||||||
return ti is not None
|
|
||||||
except StopIteration:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_textual_inversion_for_trigger_string(self, trigger_string: str) -> TextualInversion:
|
|
||||||
return next(ti for ti in self.textual_inversions if ti.trigger_string == trigger_string)
|
|
||||||
|
|
||||||
|
|
||||||
def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
|
|
||||||
return next(ti for ti in self.textual_inversions if ti.token_id == token_id)
|
|
||||||
|
|
||||||
def expand_textual_inversion_token_ids(self, prompt_token_ids: list[int]) -> list[int]:
|
|
||||||
"""
|
|
||||||
Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.
|
|
||||||
|
|
||||||
:param prompt_token_ids: The prompt as a list of token ids (`int`s). Should not include bos and eos markers.
|
|
||||||
:param pad_token_id: The token id to use to pad out the list to account for textual inversion vector lengths >1.
|
|
||||||
:return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too
|
|
||||||
long - caller is reponsible for truncating it if necessary and prepending/appending eos and bos token ids.
|
|
||||||
"""
|
|
||||||
if prompt_token_ids[0] == self.clip_embedder.tokenizer.bos_token_id:
|
|
||||||
raise ValueError("prompt_token_ids must not start with bos_token_id")
|
|
||||||
if prompt_token_ids[-1] == self.clip_embedder.tokenizer.eos_token_id:
|
|
||||||
raise ValueError("prompt_token_ids must not end with eos_token_id")
|
|
||||||
textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
|
|
||||||
prompt_token_ids = prompt_token_ids[:]
|
|
||||||
for i, token_id in reversed(list(enumerate(prompt_token_ids))):
|
|
||||||
if token_id in textual_inversion_token_ids:
|
|
||||||
textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
|
|
||||||
for pad_idx in range(1, textual_inversion.embedding_vector_length):
|
|
||||||
prompt_token_ids.insert(i+1, self.clip_embedder.tokenizer.pad_token_id)
|
|
||||||
|
|
||||||
return prompt_token_ids
|
|
||||||
|
|
||||||
def overwrite_textual_inversion_embeddings(self, prompt_token_ids: list[int], prompt_embeddings: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
For each token id in prompt_token_ids that refers to a loaded textual inversion, overwrite the corresponding
|
|
||||||
row in `prompt_embeddings` with the textual inversion embedding. If the embedding has vector length >1, overwrite
|
|
||||||
subsequent rows in `prompt_embeddings` as well.
|
|
||||||
|
|
||||||
:param `prompt_token_ids`: Prompt token ids, already expanded to account for any textual inversions with vector lenght
|
|
||||||
>1 (call `expand_textual_inversion_token_ids()` to do this) and including bos and eos markers.
|
|
||||||
:param `prompt_embeddings`: Prompt embeddings tensor of shape with indices aligning to token ids in
|
|
||||||
`prompt_token_ids` (i.e., also already expanded).
|
|
||||||
:return: `The prompt_embeddings` tensor overwritten as appropriate with the textual inversion embeddings.
|
|
||||||
"""
|
|
||||||
if prompt_embeddings.shape[0] != self.clip_embedder.max_length: # typically 77
|
|
||||||
raise ValueError(f"prompt_embeddings must have {self.clip_embedder.max_length} entries (has: {prompt_embeddings.shape[0]})")
|
|
||||||
if len(prompt_token_ids) > self.clip_embedder.max_length:
|
|
||||||
raise ValueError(f"prompt_token_ids is too long (has {len(prompt_token_ids)} token ids, should have {self.clip_embedder.max_length})")
|
|
||||||
if len(prompt_token_ids) < self.clip_embedder.max_length:
|
|
||||||
raise ValueError(f"prompt_token_ids is too short (has {len(prompt_token_ids)} token ids, it must be fully padded out to {self.clip_embedder.max_length} entries)")
|
|
||||||
if prompt_token_ids[0] != self.clip_embedder.tokenizer.bos_token_id or prompt_token_ids[-1] != self.clip_embedder.tokenizer.eos_token_id:
|
|
||||||
raise ValueError("prompt_token_ids must start with with bos token id and end with the eos token id")
|
|
||||||
|
|
||||||
textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
|
|
||||||
pad_token_id = self.clip_embedder.tokenizer.pad_token_id
|
|
||||||
overwritten_prompt_embeddings = prompt_embeddings.clone()
|
|
||||||
for i, token_id in enumerate(prompt_token_ids):
|
|
||||||
if token_id == pad_token_id:
|
|
||||||
continue
|
|
||||||
if token_id in textual_inversion_token_ids:
|
|
||||||
textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
|
|
||||||
end_index = min(i + textual_inversion.embedding_vector_length, self.clip_embedder.max_length-1)
|
|
||||||
count_to_overwrite = end_index - i
|
|
||||||
for j in range(0, count_to_overwrite):
|
|
||||||
# only overwrite the textual inversion token id or the padding token id
|
|
||||||
if prompt_token_ids[i+j] != pad_token_id and prompt_token_ids[i+j] != token_id:
|
|
||||||
break
|
|
||||||
overwritten_prompt_embeddings[i+j] = textual_inversion.embedding[j]
|
|
||||||
|
|
||||||
return overwritten_prompt_embeddings
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingManager(nn.Module):
|
class EmbeddingManager(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -197,8 +46,7 @@ class EmbeddingManager(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.embedder = embedder
|
self.embedder = embedder
|
||||||
self.concepts_library=Concepts()
|
self.concepts_library=HuggingFaceConceptsLibrary()
|
||||||
self.concepts_loaded = dict()
|
|
||||||
|
|
||||||
self.string_to_token_dict = {}
|
self.string_to_token_dict = {}
|
||||||
self.string_to_param_dict = nn.ParameterDict()
|
self.string_to_param_dict = nn.ParameterDict()
|
||||||
@ -349,22 +197,6 @@ class EmbeddingManager(nn.Module):
|
|||||||
ckpt_path,
|
ckpt_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_concepts(self, concepts:list[str], full=True):
|
|
||||||
bin_files = list()
|
|
||||||
for concept_name in concepts:
|
|
||||||
if concept_name in self.concepts_loaded:
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
bin_file = self.concepts_library.get_concept_model_path(concept_name)
|
|
||||||
if not bin_file:
|
|
||||||
continue
|
|
||||||
bin_files.append(bin_file)
|
|
||||||
self.concepts_loaded[concept_name]=True
|
|
||||||
self.load(bin_files, full)
|
|
||||||
|
|
||||||
def list_terms(self) -> list[str]:
|
|
||||||
return self.concepts_loaded.keys()
|
|
||||||
|
|
||||||
def load(self, ckpt_paths, full=True):
|
def load(self, ckpt_paths, full=True):
|
||||||
if len(ckpt_paths) == 0:
|
if len(ckpt_paths) == 0:
|
||||||
return
|
return
|
||||||
|
@ -9,6 +9,7 @@ from transformers import CLIPTokenizer, CLIPTextModel
|
|||||||
import kornia
|
import kornia
|
||||||
from ldm.invoke.devices import choose_torch_device
|
from ldm.invoke.devices import choose_torch_device
|
||||||
from ldm.invoke.globals import Globals
|
from ldm.invoke.globals import Globals
|
||||||
|
#from ldm.modules.textual_inversion_manager import TextualInversionManager
|
||||||
|
|
||||||
from ldm.modules.x_transformer import (
|
from ldm.modules.x_transformer import (
|
||||||
Encoder,
|
Encoder,
|
||||||
@ -465,7 +466,12 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
|||||||
fragment_weights_key = "fragment_weights"
|
fragment_weights_key = "fragment_weights"
|
||||||
return_tokens_key = "return_tokens"
|
return_tokens_key = "return_tokens"
|
||||||
|
|
||||||
|
def set_textual_inversion_manager(self, manager): #TextualInversionManager):
|
||||||
|
# TODO all of the weighting and expanding stuff needs be moved out of this class
|
||||||
|
self.textual_inversion_manager = manager
|
||||||
|
|
||||||
def forward(self, text: list, **kwargs):
|
def forward(self, text: list, **kwargs):
|
||||||
|
# TODO all of the weighting and expanding stuff needs be moved out of this class
|
||||||
'''
|
'''
|
||||||
|
|
||||||
:param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different
|
:param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different
|
||||||
@ -560,19 +566,42 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
|||||||
else:
|
else:
|
||||||
return batch_z
|
return batch_z
|
||||||
|
|
||||||
def get_tokens(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
|
def get_token_ids(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
|
||||||
tokens = self.tokenizer(
|
"""
|
||||||
|
Convert a list of strings like `["a cat", "sitting", "on a mat"]` into a list of lists of token ids like
|
||||||
|
`[[bos, 0, 1, eos], [bos, 2, eos], [bos, 3, 0, 4, eos]]`. bos/eos markers are skipped if
|
||||||
|
`include_start_and_end_markers` is `False`. Each list will be restricted to the maximum permitted length
|
||||||
|
(typically 75 tokens + eos/bos markers).
|
||||||
|
|
||||||
|
:param fragments: The strings to convert.
|
||||||
|
:param include_start_and_end_markers:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# for args documentation see ENCODE_KWARGS_DOCSTRING in tokenization_utils_base.py (in `transformers` lib)
|
||||||
|
token_ids_list = self.tokenizer(
|
||||||
fragments,
|
fragments,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=self.max_length,
|
max_length=self.max_length,
|
||||||
return_overflowing_tokens=False,
|
return_overflowing_tokens=False,
|
||||||
padding='do_not_pad',
|
padding='do_not_pad',
|
||||||
return_tensors=None, # just give me a list of ints
|
return_tensors=None, # just give me lists of ints
|
||||||
)['input_ids']
|
)['input_ids']
|
||||||
if include_start_and_end_markers:
|
|
||||||
return tokens
|
result = []
|
||||||
else:
|
for token_ids in token_ids_list:
|
||||||
return [x[1:-1] for x in tokens]
|
# trim eos/bos
|
||||||
|
token_ids = token_ids[1:-1]
|
||||||
|
# pad for textual inversions with vector length >1
|
||||||
|
token_ids = self.textual_inversion_manager.expand_textual_inversion_token_ids(token_ids)
|
||||||
|
# restrict length to max_length-2 (leaving room for bos/eos)
|
||||||
|
token_ids = token_ids[0:self.max_length - 2]
|
||||||
|
# add back eos/bos if requested
|
||||||
|
if include_start_and_end_markers:
|
||||||
|
token_ids = [self.tokenizer.bos_token_id] + token_ids + [self.tokenizer.eos_token_id]
|
||||||
|
|
||||||
|
result.append(token_ids)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -597,56 +626,60 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
|||||||
if len(fragments) == 0 and len(weights) == 0:
|
if len(fragments) == 0 and len(weights) == 0:
|
||||||
fragments = ['']
|
fragments = ['']
|
||||||
weights = [1]
|
weights = [1]
|
||||||
item_encodings = self.tokenizer(
|
per_fragment_token_ids = self.get_token_ids(fragments, include_start_and_end_markers=False)
|
||||||
fragments,
|
all_token_ids = []
|
||||||
truncation=True,
|
|
||||||
max_length=self.max_length,
|
|
||||||
return_overflowing_tokens=True,
|
|
||||||
padding='do_not_pad',
|
|
||||||
return_tensors=None, # just give me a list of ints
|
|
||||||
)['input_ids']
|
|
||||||
all_tokens = []
|
|
||||||
per_token_weights = []
|
per_token_weights = []
|
||||||
#print("all fragments:", fragments, weights)
|
#print("all fragments:", fragments, weights)
|
||||||
for index, fragment in enumerate(item_encodings):
|
for index, fragment in enumerate(per_fragment_token_ids):
|
||||||
weight = weights[index]
|
weight = float(weights[index])
|
||||||
#print("processing fragment", fragment, weight)
|
#print("processing fragment", fragment, weight)
|
||||||
fragment_tokens = item_encodings[index]
|
this_fragment_token_ids = per_fragment_token_ids[index]
|
||||||
#print("fragment", fragment, "processed to", fragment_tokens)
|
#print("fragment", fragment, "processed to", this_fragment_token_ids)
|
||||||
# trim bos and eos markers before appending
|
# append
|
||||||
all_tokens.extend(fragment_tokens[1:-1])
|
all_token_ids += this_fragment_token_ids
|
||||||
per_token_weights.extend([weight] * (len(fragment_tokens) - 2))
|
# fill out weights tensor with one float per token
|
||||||
|
per_token_weights += [weight] * len(this_fragment_token_ids)
|
||||||
|
|
||||||
if (len(all_tokens) + 2) > self.max_length:
|
# leave room for bos/eos
|
||||||
excess_token_count = (len(all_tokens) + 2) - self.max_length
|
if len(all_token_ids) > self.max_length - 2:
|
||||||
|
excess_token_count = len(all_token_ids) - self.max_length - 2
|
||||||
|
# TODO build nice description string of how the truncation was applied
|
||||||
|
# this should be done by calling self.tokenizer.convert_ids_to_tokens() then passing the result to
|
||||||
|
# self.tokenizer.convert_tokens_to_string() for the token_ids on each side of the truncation limit.
|
||||||
print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
|
print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
|
||||||
all_tokens = all_tokens[:self.max_length - 2]
|
all_token_ids = all_token_ids[0:self.max_length]
|
||||||
per_token_weights = per_token_weights[:self.max_length - 2]
|
per_token_weights = per_token_weights[0:self.max_length]
|
||||||
|
|
||||||
# pad out to a 77-entry array: [eos_token, <prompt tokens>, eos_token, ..., eos_token]
|
# pad out to a 77-entry array: [eos_token, <prompt tokens>, eos_token, ..., eos_token]
|
||||||
# (77 = self.max_length)
|
# (77 = self.max_length)
|
||||||
pad_length = self.max_length - 1 - len(all_tokens)
|
all_token_ids = [self.tokenizer.bos_token_id] + all_token_ids + [self.tokenizer.eos_token_id]
|
||||||
all_tokens.insert(0, self.tokenizer.bos_token_id)
|
per_token_weights = [1.0] + per_token_weights + [1.0]
|
||||||
all_tokens.extend([self.tokenizer.eos_token_id] * pad_length)
|
pad_length = self.max_length - len(all_token_ids)
|
||||||
per_token_weights.insert(0, 1)
|
all_token_ids += [self.tokenizer.eos_token_id] * pad_length
|
||||||
per_token_weights.extend([1] * pad_length)
|
per_token_weights += [1.0] * pad_length
|
||||||
|
|
||||||
all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device)
|
all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long).to(self.device)
|
||||||
per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device)
|
per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device)
|
||||||
#print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}")
|
#print(f"assembled all_token_ids_tensor with shape {all_token_ids_tensor.shape}")
|
||||||
return all_tokens_tensor, per_token_weights_tensor
|
return all_token_ids_tensor, per_token_weights_tensor
|
||||||
|
|
||||||
def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
|
def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
|
||||||
'''
|
'''
|
||||||
Build a tensor representing the passed-in tokens, each of which has a weight.
|
Build a tensor representing the passed-in tokens, each of which has a weight.
|
||||||
:param tokens: A tensor of shape (77) containing token ids (integers)
|
:param token_ids: A tensor of shape (77) containing token ids (integers)
|
||||||
:param per_token_weights: A tensor of shape (77) containing weights (floats)
|
:param per_token_weights: A tensor of shape (77) containing weights (floats)
|
||||||
:param method: Whether to multiply the whole feature vector for each token or just its distance from an "empty" feature vector
|
:param method: Whether to multiply the whole feature vector for each token or just its distance from an "empty" feature vector
|
||||||
:param kwargs: passed on to self.transformer()
|
:param kwargs: passed on to self.transformer()
|
||||||
:return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings.
|
:return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings.
|
||||||
'''
|
'''
|
||||||
#print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
|
#print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
|
||||||
z = self.transformer(input_ids=tokens.unsqueeze(0), **kwargs)
|
if token_ids.shape[0] != self.max_length:
|
||||||
|
raise ValueError(f"token_ids has shape {token_ids.shape} - expected [{self.max_length}]")
|
||||||
|
z = self.transformer(input_ids=token_ids.unsqueeze(0), **kwargs)
|
||||||
|
assert(z.shape[0] == 1)
|
||||||
|
new_z0 = self.textual_inversion_manager.overwrite_textual_inversion_embeddings(token_ids, z[0])
|
||||||
|
z[0] = new_z0
|
||||||
|
|
||||||
batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
|
batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
|
||||||
|
|
||||||
if weight_delta_from_empty:
|
if weight_delta_from_empty:
|
||||||
@ -660,7 +693,7 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
|||||||
z_delta_from_empty = z - empty_z
|
z_delta_from_empty = z - empty_z
|
||||||
weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
|
weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
|
||||||
|
|
||||||
weighted_z_delta_from_empty = (weighted_z-empty_z)
|
#weighted_z_delta_from_empty = (weighted_z-empty_z)
|
||||||
#print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )
|
#print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )
|
||||||
|
|
||||||
#print("using empty-delta method, first 5 rows:")
|
#print("using empty-delta method, first 5 rows:")
|
||||||
|
185
ldm/modules/textual_inversion_manager.py
Normal file
185
ldm/modules/textual_inversion_manager.py
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
import os
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from attr import dataclass
|
||||||
|
from picklescan.scanner import scan_file_path
|
||||||
|
|
||||||
|
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
|
||||||
|
from ldm.modules.embedding_manager import get_clip_token_id_for_string
|
||||||
|
from ldm.modules.encoders.modules import FrozenCLIPEmbedder
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TextualInversion:
|
||||||
|
trigger_string: str
|
||||||
|
token_id: int
|
||||||
|
embedding: torch.Tensor
|
||||||
|
|
||||||
|
@property
|
||||||
|
def embedding_vector_length(self) -> int:
|
||||||
|
return self.embedding.shape[0]
|
||||||
|
|
||||||
|
class TextualInversionManager():
|
||||||
|
def __init__(self, clip_embedder: FrozenCLIPEmbedder, full_precision: bool):
|
||||||
|
self.clip_embedder = clip_embedder
|
||||||
|
self.full_precision = full_precision
|
||||||
|
self.hf_concepts_library = HuggingFaceConceptsLibrary()
|
||||||
|
default_textual_inversions: list[TextualInversion] = []
|
||||||
|
self.textual_inversions = default_textual_inversions
|
||||||
|
|
||||||
|
def load_huggingface_concepts(self, concepts: list[str]):
|
||||||
|
for concept_name in concepts:
|
||||||
|
if concept_name in self.hf_concepts_library.concepts_loaded:
|
||||||
|
continue
|
||||||
|
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
||||||
|
if not bin_file:
|
||||||
|
continue
|
||||||
|
self.load_textual_inversion(bin_file)
|
||||||
|
self.hf_concepts_library.concepts_loaded[concept_name]=True
|
||||||
|
|
||||||
|
def get_all_trigger_strings(self) -> list[str]:
|
||||||
|
return [ti.trigger_string for ti in self.textual_inversions]
|
||||||
|
|
||||||
|
def load_textual_inversion(self, ckpt_path):
|
||||||
|
scan_result = scan_file_path(ckpt_path)
|
||||||
|
if scan_result.infected_files == 1:
|
||||||
|
print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
|
||||||
|
print('### For your safety, InvokeAI will not load this embed.')
|
||||||
|
return
|
||||||
|
|
||||||
|
ckpt = torch.load(ckpt_path, map_location='cpu')
|
||||||
|
|
||||||
|
# Handle .pt textual inversion files
|
||||||
|
if 'string_to_token' in ckpt and 'string_to_param' in ckpt:
|
||||||
|
filename = os.path.basename(ckpt_path)
|
||||||
|
trigger_str = '.'.join(filename.split('.')[:-1]) # filename excluding extension
|
||||||
|
if len(ckpt["string_to_token"]) > 1:
|
||||||
|
print(f">> {ckpt_path} has >1 embedding, only the first will be used")
|
||||||
|
|
||||||
|
string_to_param_dict = ckpt['string_to_param']
|
||||||
|
embedding = list(string_to_param_dict.values())[0]
|
||||||
|
self.add_textual_inversion(trigger_str, embedding)
|
||||||
|
|
||||||
|
# Handle .bin textual inversion files from Huggingface Concepts
|
||||||
|
# https://huggingface.co/sd-concepts-library
|
||||||
|
else:
|
||||||
|
for trigger_str in list(ckpt.keys()):
|
||||||
|
embedding = ckpt[trigger_str]
|
||||||
|
self.add_textual_inversion(trigger_str, embedding)
|
||||||
|
|
||||||
|
def add_textual_inversion(self, trigger_str, embedding) -> int:
|
||||||
|
"""
|
||||||
|
Add a textual inversion to be recognised.
|
||||||
|
:param trigger_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
|
||||||
|
:param embedding: The actual embedding data that will be inserted into the conditioning at the point where the token_str appears.
|
||||||
|
:return: The token id for the added embedding, either existing or newly-added.
|
||||||
|
"""
|
||||||
|
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
|
||||||
|
print(f">> TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'")
|
||||||
|
return
|
||||||
|
if not self.full_precision:
|
||||||
|
embedding = embedding.half()
|
||||||
|
if len(embedding.shape) == 1:
|
||||||
|
embedding = embedding.unsqueeze(0)
|
||||||
|
elif len(embedding.shape) > 2:
|
||||||
|
raise ValueError(f"embedding shape {embedding.shape} is incorrect - must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2")
|
||||||
|
|
||||||
|
existing_token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, trigger_str)
|
||||||
|
if existing_token_id == self.clip_embedder.tokenizer.unk_token_id:
|
||||||
|
num_tokens_added = self.clip_embedder.tokenizer.add_tokens(trigger_str)
|
||||||
|
current_embeddings = self.clip_embedder.transformer.resize_token_embeddings(None)
|
||||||
|
current_token_count = current_embeddings.num_embeddings
|
||||||
|
new_token_count = current_token_count + num_tokens_added
|
||||||
|
self.clip_embedder.transformer.resize_token_embeddings(new_token_count)
|
||||||
|
|
||||||
|
token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, trigger_str)
|
||||||
|
self.textual_inversions.append(TextualInversion(
|
||||||
|
trigger_string=trigger_str,
|
||||||
|
token_id=token_id,
|
||||||
|
embedding=embedding
|
||||||
|
))
|
||||||
|
return token_id
|
||||||
|
|
||||||
|
def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
|
||||||
|
try:
|
||||||
|
ti = self.get_textual_inversion_for_trigger_string(trigger_string)
|
||||||
|
return ti is not None
|
||||||
|
except StopIteration:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_textual_inversion_for_trigger_string(self, trigger_string: str) -> TextualInversion:
|
||||||
|
return next(ti for ti in self.textual_inversions if ti.trigger_string == trigger_string)
|
||||||
|
|
||||||
|
|
||||||
|
def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
|
||||||
|
return next(ti for ti in self.textual_inversions if ti.token_id == token_id)
|
||||||
|
|
||||||
|
def expand_textual_inversion_token_ids(self, prompt_token_ids: list[int]) -> list[int]:
|
||||||
|
"""
|
||||||
|
Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.
|
||||||
|
|
||||||
|
:param prompt_token_ids: The prompt as a list of token ids (`int`s). Should not include bos and eos markers.
|
||||||
|
:param pad_token_id: The token id to use to pad out the list to account for textual inversion vector lengths >1.
|
||||||
|
:return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too
|
||||||
|
long - caller is reponsible for truncating it if necessary and prepending/appending eos and bos token ids.
|
||||||
|
"""
|
||||||
|
if len(prompt_token_ids) == 0:
|
||||||
|
return prompt_token_ids
|
||||||
|
|
||||||
|
if prompt_token_ids[0] == self.clip_embedder.tokenizer.bos_token_id:
|
||||||
|
raise ValueError("prompt_token_ids must not start with bos_token_id")
|
||||||
|
if prompt_token_ids[-1] == self.clip_embedder.tokenizer.eos_token_id:
|
||||||
|
raise ValueError("prompt_token_ids must not end with eos_token_id")
|
||||||
|
textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
|
||||||
|
prompt_token_ids = prompt_token_ids[:]
|
||||||
|
for i, token_id in reversed(list(enumerate(prompt_token_ids))):
|
||||||
|
if token_id in textual_inversion_token_ids:
|
||||||
|
textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
|
||||||
|
for pad_idx in range(1, textual_inversion.embedding_vector_length):
|
||||||
|
prompt_token_ids.insert(i+1, self.clip_embedder.tokenizer.pad_token_id)
|
||||||
|
|
||||||
|
return prompt_token_ids
|
||||||
|
|
||||||
|
def overwrite_textual_inversion_embeddings(self, prompt_token_ids: Union[torch.Tensor,list[int]], prompt_embeddings: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
For each token id in prompt_token_ids that refers to a loaded textual inversion, overwrite the corresponding
|
||||||
|
row in `prompt_embeddings` with the textual inversion embedding. If the embedding has vector length >1, overwrite
|
||||||
|
subsequent rows in `prompt_embeddings` as well.
|
||||||
|
|
||||||
|
:param `prompt_token_ids`: Prompt token ids, already expanded to account for any textual inversions with vector length
|
||||||
|
>1 (call `expand_textual_inversion_token_ids()` to do this), padded to max length, and including bos and eos markers.
|
||||||
|
:param `prompt_embeddings`: Prompt embeddings tensor of shape with indices aligning to token ids in
|
||||||
|
`prompt_token_ids` (i.e., also already expanded).
|
||||||
|
:return: `The prompt_embeddings` tensor overwritten as appropriate with the textual inversion embeddings.
|
||||||
|
"""
|
||||||
|
if type(prompt_token_ids) is torch.Tensor:
|
||||||
|
if prompt_token_ids.shape != torch.Size([self.clip_embedder.max_length]):
|
||||||
|
raise ValueError(f"prompt_token_ids must be a list of length {self.clip_embedder.max_length} or a tensor of shape [{self.clip_embedder.max_length}]")
|
||||||
|
prompt_token_ids = list(prompt_token_ids.cpu().numpy())
|
||||||
|
if prompt_embeddings.shape[0] != self.clip_embedder.max_length: # typically 77
|
||||||
|
raise ValueError(f"prompt_embeddings must have {self.clip_embedder.max_length} entries (has: {prompt_embeddings.shape[0]})")
|
||||||
|
if len(prompt_token_ids) > self.clip_embedder.max_length:
|
||||||
|
raise ValueError(f"prompt_token_ids is too long (has {len(prompt_token_ids)} token ids, should have {self.clip_embedder.max_length})")
|
||||||
|
if len(prompt_token_ids) < self.clip_embedder.max_length:
|
||||||
|
raise ValueError(f"prompt_token_ids is too short (has {len(prompt_token_ids)} token ids, it must be fully padded out to {self.clip_embedder.max_length} entries)")
|
||||||
|
if prompt_token_ids[0] != self.clip_embedder.tokenizer.bos_token_id or prompt_token_ids[-1] != self.clip_embedder.tokenizer.eos_token_id:
|
||||||
|
raise ValueError("prompt_token_ids must start with with bos token id and end with the eos token id")
|
||||||
|
|
||||||
|
textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
|
||||||
|
pad_token_id = self.clip_embedder.tokenizer.pad_token_id
|
||||||
|
overwritten_prompt_embeddings = prompt_embeddings.clone()
|
||||||
|
for i, token_id in enumerate(prompt_token_ids):
|
||||||
|
if token_id == pad_token_id:
|
||||||
|
continue
|
||||||
|
if token_id in textual_inversion_token_ids:
|
||||||
|
textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
|
||||||
|
end_index = min(i + textual_inversion.embedding_vector_length, self.clip_embedder.max_length-1)
|
||||||
|
count_to_overwrite = end_index - i
|
||||||
|
for j in range(0, count_to_overwrite):
|
||||||
|
# only overwrite the textual inversion token id or the padding token id
|
||||||
|
if prompt_token_ids[i+j] != pad_token_id and prompt_token_ids[i+j] != token_id:
|
||||||
|
break
|
||||||
|
overwritten_prompt_embeddings[i+j] = textual_inversion.embedding[j]
|
||||||
|
|
||||||
|
return overwritten_prompt_embeddings
|
Loading…
Reference in New Issue
Block a user