Add support for yet another TI embedding format (main version) (#3050)

- This PR adds support for embedding files that contain a single key,
"emb_params". The only example of this format I know of is the
"EasyNegative" embedding on HuggingFace, but there are certainly others.

- This PR also adds support for loading embedding files that have been
saved in safetensors format.

- It also cleans up the code so that the logic for probing a file and
selecting the right format parser is clear (see the sketch after this list).

- This is the same as #3045, which is on the 2.3 branch.
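
For reference, here is a rough sketch of the four on-disk layouts the new probe distinguishes. The tensors, shapes, and file names are illustrative only; the key sets mirror the dispatch logic in `_parse_embedding` below.

```python
import torch

# Synthetic stand-ins for the four embedding layouts (shapes are made up).
v1 = {  # e.g. rem_rezero.pt: params plus full training metadata
    "string_to_token": {"*": 265},
    "string_to_param": {"*": torch.zeros(2, 768)},
    "name": "rem_rezero",
    "step": 1000,
    "sd_checkpoint_name": "v1-5-pruned",
    "sd_checkpoint": "aabbccdd",
}
v2 = {  # e.g. midj-strong.pt: like v1 but without the name/step metadata
    "string_to_token": {"*": 265},
    "string_to_param": {"*": torch.zeros(2, 768)},
}
v3 = {  # e.g. easynegative.safetensors: a single "emb_params" key
    "emb_params": torch.zeros(8, 768),
}
v4 = {  # learned_embeds.bin from HuggingFace diffusers: token -> tensor
    "<concept>": torch.zeros(768),
}
```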
Commit c4e6511a59 by Lincoln Stein, 2023-03-31 03:57:57 -04:00, committed by GitHub.

```diff
@@ -1,16 +1,26 @@
 import os
 import traceback
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional, Union
+from typing import Optional, Union, List
 
+import safetensors.torch
 import torch
 from compel.embeddings_provider import BaseTextualInversionManager
 from picklescan.scanner import scan_file_path
 from transformers import CLIPTextModel, CLIPTokenizer
 
 from .concepts_lib import HuggingFaceConceptsLibrary
 
+@dataclass
+class EmbeddingInfo:
+    name: str
+    embedding: torch.Tensor
+    num_vectors_per_token: int
+    token_dim: int
+    trained_steps: int = None
+    trained_model_name: str = None
+    trained_model_checksum: str = None
+
 @dataclass
 class TextualInversion:
```
```diff
@@ -72,66 +82,46 @@ class TextualInversionManager(BaseTextualInversionManager):
         if str(ckpt_path).endswith(".DS_Store"):
             return
 
-        try:
-            scan_result = scan_file_path(str(ckpt_path))
-            if scan_result.infected_files == 1:
-                print(
-                    f"\n### Security Issues Found in Model: {scan_result.issues_count}"
-                )
-                print("### For your safety, InvokeAI will not load this embed.")
-                return
-        except Exception:
-            print(
-                f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt."
-            )
-            return
-
-        embedding_info = self._parse_embedding(str(ckpt_path))
-
-        if embedding_info is None:
-            # We've already put out an error message about the bad embedding in _parse_embedding, so just return.
-            return
-        elif (
-            self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
-            != embedding_info["token_dim"]
-        ):
-            print(
-                f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
-            )
-            return
-
-        # Resolve the situation in which an earlier embedding has claimed the same
-        # trigger string. We replace the trigger with '<source_file>', as we used to.
-        trigger_str = embedding_info["name"]
-        sourcefile = (
-            f"{ckpt_path.parent.name}/{ckpt_path.name}"
-            if ckpt_path.name == "learned_embeds.bin"
-            else ckpt_path.name
-        )
-
-        if trigger_str in self.trigger_to_sourcefile:
-            replacement_trigger_str = (
-                f"<{ckpt_path.parent.name}>"
-                if ckpt_path.name == "learned_embeds.bin"
-                else f"<{ckpt_path.stem}>"
-            )
-            print(
-                f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
-            )
-            trigger_str = replacement_trigger_str
-
-        try:
-            self._add_textual_inversion(
-                trigger_str,
-                embedding_info["embedding"],
-                defer_injecting_tokens=defer_injecting_tokens,
-            )
-            # remember which source file claims this trigger
-            self.trigger_to_sourcefile[trigger_str] = sourcefile
-
-        except ValueError as e:
-            print(f'   | Ignoring incompatible embedding {embedding_info["name"]}')
-            print(f"   | The error was {str(e)}")
+        embedding_list = self._parse_embedding(str(ckpt_path))
+        for embedding_info in embedding_list:
+            if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
+                print(
+                    f"   ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
+                )
+                continue
+
+            # Resolve the situation in which an earlier embedding has claimed the same
+            # trigger string. We replace the trigger with '<source_file>', as we used to.
+            trigger_str = embedding_info.name
+            sourcefile = (
+                f"{ckpt_path.parent.name}/{ckpt_path.name}"
+                if ckpt_path.name == "learned_embeds.bin"
+                else ckpt_path.name
+            )
+
+            if trigger_str in self.trigger_to_sourcefile:
+                replacement_trigger_str = (
+                    f"<{ckpt_path.parent.name}>"
+                    if ckpt_path.name == "learned_embeds.bin"
+                    else f"<{ckpt_path.stem}>"
+                )
+                print(
+                    f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
+                )
+                trigger_str = replacement_trigger_str
+
+            try:
+                self._add_textual_inversion(
+                    trigger_str,
+                    embedding_info.embedding,
+                    defer_injecting_tokens=defer_injecting_tokens,
+                )
+                # remember which source file claims this trigger
+                self.trigger_to_sourcefile[trigger_str] = sourcefile
+            except ValueError as e:
+                print(f"   | Ignoring incompatible embedding {embedding_info.name}")
+                print(f"   | The error was {str(e)}")
 
     def _add_textual_inversion(
         self, trigger_str, embedding, defer_injecting_tokens=False
```
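
To make the collision handling above concrete, here is a condensed, hypothetical walk-through (file names invented) of what happens when a second file claims an already-taken trigger:

```python
# A first file has already claimed the bare trigger.
trigger_to_sourcefile = {"easynegative": "easynegative.pt"}

# A diffusers folder arrives with the same trigger; it is renamed to the
# bracketed form of its parent directory, then recorded under the new name.
trigger = "easynegative"
if trigger in trigger_to_sourcefile:
    trigger = "<easynegative>"  # parent dir for learned_embeds.bin, file stem otherwise
trigger_to_sourcefile[trigger] = "easynegative/learned_embeds.bin"
```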
```diff
@@ -309,111 +299,130 @@ class TextualInversionManager(BaseTextualInversionManager):
         return token_id
 
-    def _parse_embedding(self, embedding_file: str):
-        file_type = embedding_file.split(".")[-1]
-        if file_type == "pt":
-            return self._parse_embedding_pt(embedding_file)
-        elif file_type == "bin":
-            return self._parse_embedding_bin(embedding_file)
-        else:
-            print(f"** Notice: unrecognized embedding file format: {embedding_file}")
-            return None
+    def _parse_embedding(self, embedding_file: str) -> List[EmbeddingInfo]:
+        suffix = Path(embedding_file).suffix
+        try:
+            if suffix in [".pt", ".ckpt", ".bin"]:
+                scan_result = scan_file_path(embedding_file)
+                if scan_result.infected_files > 0:
+                    print(
+                        f"   ** Security Issues Found in Model: {scan_result.issues_count}"
+                    )
+                    print("   ** For your safety, InvokeAI will not load this embed.")
+                    return list()
+                ckpt = torch.load(embedding_file, map_location="cpu")
+            else:
+                ckpt = safetensors.torch.load_file(embedding_file)
+        except Exception as e:
+            print(f"   ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
+            return list()
+
+        # try to figure out what kind of embedding file it is and parse accordingly
+        keys = list(ckpt.keys())
+        if all(x in keys for x in ["string_to_token", "string_to_param", "name", "step"]):
+            return self._parse_embedding_v1(ckpt, embedding_file)  # example rem_rezero.pt
+        elif all(x in keys for x in ["string_to_token", "string_to_param"]):
+            return self._parse_embedding_v2(ckpt, embedding_file)  # example midj-strong.pt
+        elif "emb_params" in keys:
+            return self._parse_embedding_v3(ckpt, embedding_file)  # example easynegative.safetensors
+        else:
+            return self._parse_embedding_v4(ckpt, embedding_file)  # usually a '.bin' file
 
-    def _parse_embedding_pt(self, embedding_file):
-        embedding_ckpt = torch.load(embedding_file, map_location="cpu")
-        embedding_info = {}
-
-        # Check if valid embedding file
-        if "string_to_token" and "string_to_param" in embedding_ckpt:
-            # Catch variants that do not have the expected keys or values.
-            try:
-                embedding_info["name"] = embedding_ckpt["name"] or os.path.basename(
-                    os.path.splitext(embedding_file)[0]
-                )
-
-                # Check num of embeddings and warn user only the first will be used
-                embedding_info["num_of_embeddings"] = len(
-                    embedding_ckpt["string_to_token"]
-                )
-                if embedding_info["num_of_embeddings"] > 1:
-                    print(">> More than 1 embedding found. Will use the first one")
-
-                embedding = list(embedding_ckpt["string_to_param"].values())[0]
-            except (AttributeError, KeyError):
-                return self._handle_broken_pt_variants(embedding_ckpt, embedding_file)
-
-            embedding_info["embedding"] = embedding
-            embedding_info["num_vectors_per_token"] = embedding.size()[0]
-            embedding_info["token_dim"] = embedding.size()[1]
-
-            try:
-                embedding_info["trained_steps"] = embedding_ckpt["step"]
-                embedding_info["trained_model_name"] = embedding_ckpt[
-                    "sd_checkpoint_name"
-                ]
-                embedding_info["trained_model_checksum"] = embedding_ckpt[
-                    "sd_checkpoint"
-                ]
-            except AttributeError:
-                print(">> No Training Details Found. Passing ...")
-
-        # .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
-        # They are actually .bin files
-        elif len(embedding_ckpt.keys()) == 1:
-            embedding_info = self._parse_embedding_bin(embedding_file)
-
-        else:
-            print(">> Invalid embedding format")
-            embedding_info = None
-
-        return embedding_info
+    def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str) -> List[EmbeddingInfo]:
+        basename = Path(file_path).stem
+        print(f'   | Loading v1 embedding file: {basename}')
+
+        embeddings = list()
+        token_counter = -1
+        for token, embedding in embedding_ckpt["string_to_param"].items():
+            if token_counter < 0:
+                trigger = embedding_ckpt["name"]
+            elif token_counter == 0:
+                trigger = f'<{basename}>'
+            else:
+                trigger = f'<{basename}-{token_counter}>'
+            token_counter += 1
+            embedding_info = EmbeddingInfo(
+                name = trigger,
+                embedding = embedding,
+                num_vectors_per_token = embedding.size()[0],
+                token_dim = embedding.size()[1],
+                trained_steps = embedding_ckpt["step"],
+                trained_model_name = embedding_ckpt["sd_checkpoint_name"],
+                trained_model_checksum = embedding_ckpt["sd_checkpoint"]
+            )
+            embeddings.append(embedding_info)
+        return embeddings
 
-    def _parse_embedding_bin(self, embedding_file):
-        embedding_ckpt = torch.load(embedding_file, map_location="cpu")
-        embedding_info = {}
-
-        if list(embedding_ckpt.keys()) == 0:
-            print(">> Invalid concepts file")
-            embedding_info = None
-        else:
-            for token in list(embedding_ckpt.keys()):
-                embedding_info["name"] = (
-                    token
-                    or f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
-                )
-                embedding_info["embedding"] = embedding_ckpt[token]
-                embedding_info[
-                    "num_vectors_per_token"
-                ] = 1  # All Concepts seem to default to 1
-                embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
-        return embedding_info
-
-    def _handle_broken_pt_variants(
-        self, embedding_ckpt: dict, embedding_file: str
-    ) -> dict:
+    def _parse_embedding_v2(
+        self, embedding_ckpt: dict, file_path: str
+    ) -> List[EmbeddingInfo]:
         """
-        This handles the broken .pt file variants. We only know of one at present.
+        This handles embedding .pt file variant #2.
         """
-        embedding_info = {}
+        basename = Path(file_path).stem
+        print(f'   | Loading v2 embedding file: {basename}')
+        embeddings = list()
+
         if isinstance(
             list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
         ):
-            for token in list(embedding_ckpt["string_to_token"].keys()):
-                embedding_info["name"] = (
-                    token
-                    if token != "*"
-                    else f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
-                )
-                embedding_info["embedding"] = embedding_ckpt[
-                    "string_to_param"
-                ].state_dict()[token]
-                embedding_info["num_vectors_per_token"] = embedding_info[
-                    "embedding"
-                ].shape[0]
-                embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
+            token_counter = 0
+            for token, embedding in embedding_ckpt["string_to_param"].items():
+                trigger = token if token != '*' \
+                    else f'<{basename}>' if token_counter == 0 \
+                    else f'<{basename}-{int(token_counter:=token_counter+1)}>'
+                embedding_info = EmbeddingInfo(
+                    name = trigger,
+                    embedding = embedding,
+                    num_vectors_per_token = embedding.size()[0],
+                    token_dim = embedding.size()[1],
+                )
+                embeddings.append(embedding_info)
         else:
-            print(">> Invalid embedding format")
-            embedding_info = None
+            print(f"   ** {basename}: Unrecognized embedding format")
 
-        return embedding_info
+        return embeddings
+
+    def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str) -> List[EmbeddingInfo]:
+        """
+        Parse 'version 3' of the textual inversion embedding files: a
+        safetensors file holding a single 'emb_params' key.
+        """
+        basename = Path(file_path).stem
+        print(f'   | Loading v3 embedding file: {basename}')
+        embedding = embedding_ckpt['emb_params']
+        embedding_info = EmbeddingInfo(
+            name = f'<{basename}>',
+            embedding = embedding,
+            num_vectors_per_token = embedding.size()[0],
+            token_dim = embedding.size()[1],
+        )
+        return [embedding_info]
+
+    def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str) -> List[EmbeddingInfo]:
+        """
+        Parse 'version 4' of the textual inversion embedding files. This one
+        is usually associated with .bin files trained by HuggingFace diffusers.
+        """
+        basename = Path(filepath).stem
+        short_path = Path(filepath).parents[0].name + '/' + Path(filepath).name
+
+        print(f'   | Loading v4 embedding file: {short_path}')
+
+        embeddings = list()
+        if len(embedding_ckpt.keys()) == 0:
+            print(f"   ** Invalid embeddings file: {short_path}")
+        else:
+            for token, embedding in embedding_ckpt.items():
+                embedding_info = EmbeddingInfo(
+                    name = token or f"<{basename}>",
+                    embedding = embedding,
+                    num_vectors_per_token = 1,  # All Concepts seem to default to 1
+                    token_dim = embedding.size()[0],
+                )
+                embeddings.append(embedding_info)
+        return embeddings
```
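
A quick way to exercise the two newly supported paths end to end. The file path is hypothetical, and `manager` is assumed to be an already-constructed `TextualInversionManager`:

```python
import safetensors.torch

# Any safetensors TI file with a single "emb_params" key (e.g. EasyNegative
# from HuggingFace) should be routed to _parse_embedding_v3.
path = "embeddings/easynegative.safetensors"  # hypothetical location
ckpt = safetensors.torch.load_file(path)
assert "emb_params" in ckpt

for info in manager._parse_embedding(path):
    print(info.name, info.num_vectors_per_token, info.token_dim)
```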