Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
abe4dc8ac1
- This PR adds support for embedding files that contain a single key "emb_params". The only example I know of this format is the "EasyNegative" embedding on HuggingFace, but there are certainly others.
- This PR also adds support for loading embedding files that have been saved in safetensors format.
- It also cleans up the code so that the logic of probing for and selecting the right format parser is clear.
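For reference, a minimal sketch of what an "emb_params"-style file looks like when inspected with safetensors; the file name below is a placeholder, not a file shipped with this PR:

    import safetensors.torch

    state = safetensors.torch.load_file("easynegative.safetensors")  # placeholder path
    print(list(state.keys()))           # expected: ['emb_params']
    print(state["emb_params"].shape)    # torch.Size([num_vectors_per_token, token_dim])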
425 lines
18 KiB
Python
import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union

import safetensors.torch
import torch

from compel.embeddings_provider import BaseTextualInversionManager
from picklescan.scanner import scan_file_path
from transformers import CLIPTextModel, CLIPTokenizer

from .concepts_lib import HuggingFaceConceptsLibrary

@dataclass
class TextualInversion:
    trigger_string: str
    embedding: torch.Tensor
    trigger_token_id: Optional[int] = None
    pad_token_ids: Optional[list[int]] = None

    @property
    def embedding_vector_length(self) -> int:
        return self.embedding.shape[0]


class TextualInversionManager(BaseTextualInversionManager):
    def __init__(
        self,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModel,
        full_precision: bool = True,
    ):
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.full_precision = full_precision
        self.hf_concepts_library = HuggingFaceConceptsLibrary()
        self.trigger_to_sourcefile = dict()
        default_textual_inversions: list[TextualInversion] = []
        self.textual_inversions = default_textual_inversions

    def load_huggingface_concepts(self, concepts: list[str]):
        for concept_name in concepts:
            if concept_name in self.hf_concepts_library.concepts_loaded:
                continue
            trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
            if (
                self.has_textual_inversion_for_trigger_string(trigger)
                or self.has_textual_inversion_for_trigger_string(concept_name)
                or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
            ):  # in case a token with literal angle brackets is encountered
                print(f">> Loaded local embedding for trigger {concept_name}")
                continue
            bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
            if not bin_file:
                continue
            print(f">> Loaded remote embedding for trigger {concept_name}")
            self.load_textual_inversion(bin_file)
            self.hf_concepts_library.concepts_loaded[concept_name] = True

    def get_all_trigger_strings(self) -> list[str]:
        return [ti.trigger_string for ti in self.textual_inversions]

    def load_textual_inversion(
        self, ckpt_path: Union[str, Path], defer_injecting_tokens: bool = False
    ):
        ckpt_path = Path(ckpt_path)

        if not ckpt_path.is_file():
            return

        if str(ckpt_path).endswith(".DS_Store"):
            return

        embedding_info = self._parse_embedding(str(ckpt_path))

        if embedding_info is None:
            # We've already put out an error message about the bad embedding in _parse_embedding, so just return.
            return
        elif (
            self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
            != embedding_info["token_dim"]
        ):
            print(
                f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
            )
            return

        # Resolve the situation in which an earlier embedding has claimed the same
        # trigger string. We replace the trigger with '<source_file>', as we used to.
        trigger_str = embedding_info["name"]
        sourcefile = (
            f"{ckpt_path.parent.name}/{ckpt_path.name}"
            if ckpt_path.name == "learned_embeds.bin"
            else ckpt_path.name
        )

        if trigger_str in self.trigger_to_sourcefile:
            replacement_trigger_str = (
                f"<{ckpt_path.parent.name}>"
                if ckpt_path.name == "learned_embeds.bin"
                else f"<{ckpt_path.stem}>"
            )
            print(
                f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
            )
            trigger_str = replacement_trigger_str

        try:
            self._add_textual_inversion(
                trigger_str,
                embedding_info["embedding"],
                defer_injecting_tokens=defer_injecting_tokens,
            )
            # remember which source file claims this trigger
            self.trigger_to_sourcefile[trigger_str] = sourcefile

        except ValueError as e:
            print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
            print(f" | The error was {str(e)}")

    def _add_textual_inversion(
        self, trigger_str, embedding, defer_injecting_tokens=False
    ) -> Optional[TextualInversion]:
        """
        Add a textual inversion to be recognised.
        :param trigger_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
        :param embedding: The actual embedding data that will be inserted into the conditioning at the point where the token_str appears.
        :return: The TextualInversion that was added, or None if the trigger string was already claimed by a previously loaded embedding.
        """
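        # Illustrative shapes: a 1-D embedding of shape [768] is treated below as a
        # single vector and reshaped to [1, 768]; a 2-D embedding of shape [V, 768]
        # keeps its V vectors (token_dim is 768 for SD1, 1280 for SD2).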
        if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
            print(
                f"** TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
            )
            return
        if not self.full_precision:
            embedding = embedding.half()
        if len(embedding.shape) == 1:
            embedding = embedding.unsqueeze(0)
        elif len(embedding.shape) > 2:
            raise ValueError(
                f"** TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2."
            )

        try:
            ti = TextualInversion(trigger_string=trigger_str, embedding=embedding)
            if not defer_injecting_tokens:
                self._inject_tokens_and_assign_embeddings(ti)
            self.textual_inversions.append(ti)
            return ti

        except ValueError as e:
            if str(e).startswith("Warning"):
                print(f">> {str(e)}")
            else:
                traceback.print_exc()
                print(
                    f"** TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
                )
                raise

    def _inject_tokens_and_assign_embeddings(self, ti: TextualInversion) -> int:
        if ti.trigger_token_id is not None:
            raise ValueError(
                f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'"
            )

        trigger_token_id = self._get_or_create_token_id_and_assign_embedding(
            ti.trigger_string, ti.embedding[0]
        )

        if ti.embedding_vector_length > 1:
            # for embeddings with vector length > 1
            pad_token_strings = [
                ti.trigger_string + "-!pad-" + str(pad_index)
                for pad_index in range(1, ti.embedding_vector_length)
            ]
            # todo: batched UI for faster loading when vector length >2
            pad_token_ids = [
                self._get_or_create_token_id_and_assign_embedding(
                    pad_token_str, ti.embedding[1 + i]
                )
                for (i, pad_token_str) in enumerate(pad_token_strings)
            ]
        else:
            pad_token_ids = []

        ti.trigger_token_id = trigger_token_id
        ti.pad_token_ids = pad_token_ids
        return ti.trigger_token_id

    def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
        try:
            ti = self.get_textual_inversion_for_trigger_string(trigger_string)
            return ti is not None
        except StopIteration:
            return False

    def get_textual_inversion_for_trigger_string(
        self, trigger_string: str
    ) -> TextualInversion:
        return next(
            ti for ti in self.textual_inversions if ti.trigger_string == trigger_string
        )

    def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
        return next(
            ti for ti in self.textual_inversions if ti.trigger_token_id == token_id
        )

    def create_deferred_token_ids_for_any_trigger_terms(
        self, prompt_string: str
    ) -> list[int]:
        injected_token_ids = []
        for ti in self.textual_inversions:
            if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
                if ti.embedding_vector_length > 1:
                    print(
                        f">> Preparing tokens for textual inversion {ti.trigger_string}..."
                    )
                try:
                    self._inject_tokens_and_assign_embeddings(ti)
                except ValueError as e:
                    print(
                        f" | Ignoring incompatible embedding trigger {ti.trigger_string}"
                    )
                    print(f" | The error was {str(e)}")
                    continue
                injected_token_ids.append(ti.trigger_token_id)
                injected_token_ids.extend(ti.pad_token_ids)
        return injected_token_ids

    def expand_textual_inversion_token_ids_if_necessary(
        self, prompt_token_ids: list[int]
    ) -> list[int]:
        """
        Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.

        :param prompt_token_ids: The prompt as a list of token ids (`int`s). Should not include bos and eos markers.
        :return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too
            long - caller is responsible for prepending/appending eos and bos token ids, and truncating if necessary.
        """
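        # Illustrative expansion (hypothetical token ids): if trigger id 49408 belongs
        # to a 3-vector embedding with pad ids [49409, 49410], then [320, 49408, 525]
        # becomes [320, 49408, 49409, 49410, 525].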
        if len(prompt_token_ids) == 0:
            return prompt_token_ids

        if prompt_token_ids[0] == self.tokenizer.bos_token_id:
            raise ValueError("prompt_token_ids must not start with bos_token_id")
        if prompt_token_ids[-1] == self.tokenizer.eos_token_id:
            raise ValueError("prompt_token_ids must not end with eos_token_id")
        textual_inversion_trigger_token_ids = [
            ti.trigger_token_id for ti in self.textual_inversions
        ]
        prompt_token_ids = prompt_token_ids.copy()
        for i, token_id in reversed(list(enumerate(prompt_token_ids))):
            if token_id in textual_inversion_trigger_token_ids:
                textual_inversion = next(
                    ti
                    for ti in self.textual_inversions
                    if ti.trigger_token_id == token_id
                )
                for pad_idx in range(0, textual_inversion.embedding_vector_length - 1):
                    prompt_token_ids.insert(
                        i + pad_idx + 1, textual_inversion.pad_token_ids[pad_idx]
                    )

        return prompt_token_ids

    def _get_or_create_token_id_and_assign_embedding(
        self, token_str: str, embedding: torch.Tensor
    ) -> int:
        if len(embedding.shape) != 1:
            raise ValueError(
                "Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2"
            )
        existing_token_id = self.tokenizer.convert_tokens_to_ids(token_str)
        if existing_token_id == self.tokenizer.unk_token_id:
            num_tokens_added = self.tokenizer.add_tokens(token_str)
            current_embeddings = self.text_encoder.resize_token_embeddings(None)
            current_token_count = current_embeddings.num_embeddings
            new_token_count = current_token_count + num_tokens_added
            # the following call is slow - todo make batched for better performance with vector length >1
            self.text_encoder.resize_token_embeddings(new_token_count)

        token_id = self.tokenizer.convert_tokens_to_ids(token_str)
        if token_id == self.tokenizer.unk_token_id:
            raise RuntimeError(f"Unable to find token id for token '{token_str}'")
        if (
            self.text_encoder.get_input_embeddings().weight.data[token_id].shape
            != embedding.shape
        ):
            raise ValueError(
                f"Warning. Cannot load embedding for {token_str}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {self.text_encoder.get_input_embeddings().weight.data[token_id].shape[0]}."
            )
        self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding

        return token_id

    def _parse_embedding(self, embedding_file: str) -> Optional[dict]:
        suffix = Path(embedding_file).suffix
        try:
            if suffix in [".pt", ".ckpt", ".bin"]:
                scan_result = scan_file_path(embedding_file)
                if scan_result.infected_files == 1:
                    print(
                        f" ** Security Issues Found in Model: {scan_result.issues_count}"
                    )
                    print(" ** For your safety, InvokeAI will not load this embed.")
                    return None
                ckpt = torch.load(embedding_file, map_location="cpu")
            else:
                ckpt = safetensors.torch.load_file(embedding_file)
        except Exception as e:
            print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
            return None

        # try to figure out what kind of embedding file it is and parse accordingly
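        # Known layouts, as probed below:
        #   v1: 'string_to_token', 'string_to_param', 'name' and 'step' keys (e.g. rem_rezero.pt)
        #   v2: only 'string_to_token' and 'string_to_param' keys (e.g. midj-strong.pt)
        #   v3: a single 'emb_params' key (e.g. easynegative.safetensors)
        #   v4: diffusers-style file keyed by token, usually a '.bin' such as learned_embeds.bin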
        keys = list(ckpt.keys())
        if all(x in keys for x in ['string_to_token', 'string_to_param', 'name', 'step']):
            return self._parse_embedding_v1(ckpt, embedding_file)  # example rem_rezero.pt

        elif all(x in keys for x in ['string_to_token', 'string_to_param']):
            return self._parse_embedding_v2(ckpt, embedding_file)  # example midj-strong.pt

        elif 'emb_params' in keys:
            return self._parse_embedding_v3(ckpt, embedding_file)  # example easynegative.safetensors

        else:
            return self._parse_embedding_v4(ckpt, embedding_file)  # usually a '.bin' file

    def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str):
        basename = Path(file_path).stem
        print(f' | Loading v1 embedding file: {basename}')

        embedding_info = {}
        embedding_info["name"] = embedding_ckpt["name"]

        # Check num of embeddings and warn user only the first will be used
        embedding_info["num_of_embeddings"] = len(embedding_ckpt["string_to_token"])
        if embedding_info["num_of_embeddings"] > 1:
            print(" | More than 1 embedding found. Will use the first one")
        embedding = list(embedding_ckpt["string_to_param"].values())[0]
        embedding_info["embedding"] = embedding
        embedding_info["num_vectors_per_token"] = embedding.size()[0]
        embedding_info["token_dim"] = embedding.size()[1]
        embedding_info["trained_steps"] = embedding_ckpt["step"]
        embedding_info["trained_model_name"] = embedding_ckpt["sd_checkpoint_name"]
        embedding_info["trained_model_checksum"] = embedding_ckpt["sd_checkpoint"]
        return embedding_info

    def _parse_embedding_v2(
        self, embedding_ckpt: dict, file_path: str
    ) -> Optional[dict]:
        """
        This handles embedding .pt file variant #2.
        """
        basename = Path(file_path).stem
        print(f' | Loading v2 embedding file: {basename}')
        embedding_info = {}
        if isinstance(
            list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
        ):
            for token in list(embedding_ckpt["string_to_token"].keys()):
                embedding_info["name"] = token if token != "*" else f"<{basename}>"
                embedding_info["embedding"] = embedding_ckpt[
                    "string_to_param"
                ].state_dict()[token]
                embedding_info["num_vectors_per_token"] = embedding_info[
                    "embedding"
                ].shape[0]
                embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
        else:
            print(f" ** {basename}: Unrecognized embedding format")
            embedding_info = None

        return embedding_info

    def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str):
        """
        Parse 'version 3' of the textual inversion embedding files, which contain
        a single 'emb_params' key and are typically saved in safetensors format.
        """
        basename = Path(file_path).stem
        print(f' | Loading v3 embedding file: {basename}')
        embedding_info = {}
        embedding_info["name"] = f'<{basename}>'
        embedding_info["num_of_embeddings"] = 1
        embedding = embedding_ckpt['emb_params']
        embedding_info["embedding"] = embedding
        embedding_info["num_vectors_per_token"] = embedding.size()[0]
        embedding_info["token_dim"] = embedding.size()[1]
        return embedding_info

    def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str):
        """
        Parse 'version 4' of the textual inversion embedding files. This one
        is usually associated with .bin files trained by HuggingFace diffusers.
        """
        basename = Path(filepath).stem
        short_path = Path(filepath).parents[0].name + '/' + Path(filepath).name

        print(f' | Loading v4 embedding file: {short_path}')
        embedding_info = {}
        if len(embedding_ckpt.keys()) == 0:
            print(f" ** Invalid embeddings file: {short_path}")
            embedding_info = None
        else:
            for token in list(embedding_ckpt.keys()):
                embedding_info["name"] = token or f"<{basename}>"
                embedding_info["embedding"] = embedding_ckpt[token]
                embedding_info["num_vectors_per_token"] = 1  # All Concepts seem to default to 1
                embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
        return embedding_info
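
# --- Hypothetical usage sketch (not part of the original module) ---
# A minimal illustration of how the manager above might be exercised; the model id
# and embedding path below are placeholders, not values taken from this file.
#
#     from transformers import CLIPTextModel, CLIPTokenizer
#
#     tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
#     text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
#     manager = TextualInversionManager(tokenizer=tokenizer, text_encoder=text_encoder)
#     manager.load_textual_inversion("embeddings/easynegative.safetensors")
#     print(manager.get_all_trigger_strings())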