Add support for yet another TI embedding format (main version) (#3050)

- This PR adds support for embedding files that contain a single key
"emb_params". The only example of this format I know of is the
"EasyNegative" embedding on HuggingFace, but there are certainly others
(see the short sketch after this list).

- This PR also adds support for loading embedding files that have been
saved in safetensors format.

- It also cleans up the code so that the logic of probing for and
selecting the right format parser is clear.

- This is the same as #3045, which is on the 2.3 branch.
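
For context, a minimal sketch of loading such a single-key file, the way the new code does (the file name and the example 8x768 shape are assumptions for illustration, not taken from this PR):

```python
import safetensors.torch

# Hypothetical EasyNegative-style checkpoint: a dict holding one "emb_params"
# tensor of shape [num_vectors_per_token, token_dim], e.g. 8 x 768 for an
# SD-1.x text encoder.
ckpt = safetensors.torch.load_file("easynegative.safetensors")
embedding = ckpt["emb_params"]
print(embedding.size()[0], embedding.size()[1])  # num_vectors_per_token, token_dim
```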
commit c4e6511a59
Author: Lincoln Stein (committed by GitHub)
Date: 2023-03-31 03:57:57 -04:00

@@ -1,16 +1,26 @@
-import os
 import traceback
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional, Union
+from typing import Optional, Union, List
 
+import safetensors.torch
 import torch
 from compel.embeddings_provider import BaseTextualInversionManager
 from picklescan.scanner import scan_file_path
 from transformers import CLIPTextModel, CLIPTokenizer
 
 from .concepts_lib import HuggingFaceConceptsLibrary
 
+@dataclass
+class EmbeddingInfo:
+    name: str
+    embedding: torch.Tensor
+    num_vectors_per_token: int
+    token_dim: int
+    trained_steps: int = None
+    trained_model_name: str = None
+    trained_model_checksum: str = None
+
 @dataclass
 class TextualInversion:
@ -72,37 +82,17 @@ class TextualInversionManager(BaseTextualInversionManager):
if str(ckpt_path).endswith(".DS_Store"): if str(ckpt_path).endswith(".DS_Store"):
return return
try: embedding_list = self._parse_embedding(str(ckpt_path))
scan_result = scan_file_path(str(ckpt_path)) for embedding_info in embedding_list:
if scan_result.infected_files == 1: if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
print( print(
f"\n### Security Issues Found in Model: {scan_result.issues_count}" f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
) )
print("### For your safety, InvokeAI will not load this embed.") continue
return
except Exception:
print(
f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt."
)
return
embedding_info = self._parse_embedding(str(ckpt_path))
if embedding_info is None:
# We've already put out an error message about the bad embedding in _parse_embedding, so just return.
return
elif (
self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
!= embedding_info["token_dim"]
):
print(
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
)
return
# Resolve the situation in which an earlier embedding has claimed the same # Resolve the situation in which an earlier embedding has claimed the same
# trigger string. We replace the trigger with '<source_file>', as we used to. # trigger string. We replace the trigger with '<source_file>', as we used to.
trigger_str = embedding_info["name"] trigger_str = embedding_info.name
sourcefile = ( sourcefile = (
f"{ckpt_path.parent.name}/{ckpt_path.name}" f"{ckpt_path.parent.name}/{ckpt_path.name}"
if ckpt_path.name == "learned_embeds.bin" if ckpt_path.name == "learned_embeds.bin"
@@ -123,7 +113,7 @@ class TextualInversionManager(BaseTextualInversionManager):
             try:
                 self._add_textual_inversion(
                     trigger_str,
-                    embedding_info["embedding"],
+                    embedding_info.embedding,
                     defer_injecting_tokens=defer_injecting_tokens,
                 )
                 # remember which source file claims this trigger
@@ -309,111 +299,130 @@ class TextualInversionManager(BaseTextualInversionManager):
         return token_id
 
-    def _parse_embedding(self, embedding_file: str):
-        file_type = embedding_file.split(".")[-1]
-        if file_type == "pt":
-            return self._parse_embedding_pt(embedding_file)
-        elif file_type == "bin":
-            return self._parse_embedding_bin(embedding_file)
-        else:
-            print(f"** Notice: unrecognized embedding file format: {embedding_file}")
-            return None
-
+    def _parse_embedding(self, embedding_file: str)->List[EmbeddingInfo]:
+        suffix = Path(embedding_file).suffix
+        try:
+            if suffix in [".pt",".ckpt",".bin"]:
+                scan_result = scan_file_path(embedding_file)
+                if scan_result.infected_files > 0:
+                    print(
+                        f" ** Security Issues Found in Model: {scan_result.issues_count}"
+                    )
+                    print(" ** For your safety, InvokeAI will not load this embed.")
+                    return list()
+                ckpt = torch.load(embedding_file,map_location="cpu")
+            else:
+                ckpt = safetensors.torch.load_file(embedding_file)
+        except Exception as e:
+            print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
+            return list()
+
+        # try to figure out what kind of embedding file it is and parse accordingly
+        keys = list(ckpt.keys())
+        if all(x in keys for x in ['string_to_token','string_to_param','name','step']):
+            return self._parse_embedding_v1(ckpt, embedding_file)  # example rem_rezero.pt
+        elif all(x in keys for x in ['string_to_token','string_to_param']):
+            return self._parse_embedding_v2(ckpt, embedding_file)  # example midj-strong.pt
+        elif 'emb_params' in keys:
+            return self._parse_embedding_v3(ckpt, embedding_file)  # example easynegative.safetensors
+        else:
+            return self._parse_embedding_v4(ckpt, embedding_file)  # usually a '.bin' file
+
-    def _parse_embedding_pt(self, embedding_file):
-        embedding_ckpt = torch.load(embedding_file, map_location="cpu")
-        embedding_info = {}
-
-        # Check if valid embedding file
-        if "string_to_token" and "string_to_param" in embedding_ckpt:
-            # Catch variants that do not have the expected keys or values.
-            try:
-                embedding_info["name"] = embedding_ckpt["name"] or os.path.basename(
-                    os.path.splitext(embedding_file)[0]
-                )
-
-                # Check num of embeddings and warn user only the first will be used
-                embedding_info["num_of_embeddings"] = len(
-                    embedding_ckpt["string_to_token"]
-                )
-                if embedding_info["num_of_embeddings"] > 1:
-                    print(">> More than 1 embedding found. Will use the first one")
-
-                embedding = list(embedding_ckpt["string_to_param"].values())[0]
-            except (AttributeError, KeyError):
-                return self._handle_broken_pt_variants(embedding_ckpt, embedding_file)
-
-            embedding_info["embedding"] = embedding
-            embedding_info["num_vectors_per_token"] = embedding.size()[0]
-            embedding_info["token_dim"] = embedding.size()[1]
-
-            try:
-                embedding_info["trained_steps"] = embedding_ckpt["step"]
-                embedding_info["trained_model_name"] = embedding_ckpt[
-                    "sd_checkpoint_name"
-                ]
-                embedding_info["trained_model_checksum"] = embedding_ckpt[
-                    "sd_checkpoint"
-                ]
-            except AttributeError:
-                print(">> No Training Details Found. Passing ...")
-
-        # .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
-        # They are actually .bin files
-        elif len(embedding_ckpt.keys()) == 1:
-            embedding_info = self._parse_embedding_bin(embedding_file)
-
-        else:
-            print(">> Invalid embedding format")
-            embedding_info = None
-
-        return embedding_info
-
+    def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
+        basename = Path(file_path).stem
+        print(f' | Loading v1 embedding file: {basename}')
+
+        embeddings = list()
+        token_counter = -1
+        for token,embedding in embedding_ckpt["string_to_param"].items():
+            if token_counter < 0:
+                trigger = embedding_ckpt["name"]
+            elif token_counter == 0:
+                trigger = f'<basename>'
+            else:
+                trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
+            token_counter += 1
+            embedding_info = EmbeddingInfo(
+                name = trigger,
+                embedding = embedding,
+                num_vectors_per_token = embedding.size()[0],
+                token_dim = embedding.size()[1],
+                trained_steps = embedding_ckpt["step"],
+                trained_model_name = embedding_ckpt["sd_checkpoint_name"],
+                trained_model_checksum = embedding_ckpt["sd_checkpoint"]
+            )
+            embeddings.append(embedding_info)
+        return embeddings
+
-    def _parse_embedding_bin(self, embedding_file):
-        embedding_ckpt = torch.load(embedding_file, map_location="cpu")
-        embedding_info = {}
-
-        if list(embedding_ckpt.keys()) == 0:
-            print(">> Invalid concepts file")
-            embedding_info = None
-        else:
-            for token in list(embedding_ckpt.keys()):
-                embedding_info["name"] = (
-                    token
-                    or f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
-                )
-                embedding_info["embedding"] = embedding_ckpt[token]
-                embedding_info[
-                    "num_vectors_per_token"
-                ] = 1  # All Concepts seem to default to 1
-                embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
-
-        return embedding_info
-
+    def _parse_embedding_v2 (
+        self, embedding_ckpt: dict, file_path: str
+    ) -> List[EmbeddingInfo]:
+        """
+        This handles embedding .pt file variant #2.
+        """
+        basename = Path(file_path).stem
+        print(f' | Loading v2 embedding file: {basename}')
+        embeddings = list()
+
+        if isinstance(
+            list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
+        ):
+            token_counter = 0
+            for token,embedding in embedding_ckpt["string_to_param"].items():
+                trigger = token if token != '*' \
+                    else f'<{basename}>' if token_counter == 0 \
+                    else f'<{basename}-{int(token_counter:=token_counter+1)}>'
+                embedding_info = EmbeddingInfo(
+                    name = trigger,
+                    embedding = embedding,
+                    num_vectors_per_token = embedding.size()[0],
+                    token_dim = embedding.size()[1],
+                )
+                embeddings.append(embedding_info)
+        else:
+            print(f" ** {basename}: Unrecognized embedding format")
+
+        return embeddings
+
-    def _handle_broken_pt_variants(
-        self, embedding_ckpt: dict, embedding_file: str
-    ) -> dict:
-        """
-        This handles the broken .pt file variants. We only know of one at present.
-        """
-        embedding_info = {}
-        if isinstance(
-            list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
-        ):
-            for token in list(embedding_ckpt["string_to_token"].keys()):
-                embedding_info["name"] = (
-                    token
-                    if token != "*"
-                    else f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
-                )
-                embedding_info["embedding"] = embedding_ckpt[
-                    "string_to_param"
-                ].state_dict()[token]
-                embedding_info["num_vectors_per_token"] = embedding_info[
-                    "embedding"
-                ].shape[0]
-                embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
-        else:
-            print(">> Invalid embedding format")
-            embedding_info = None
-
-        return embedding_info
+    def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
+        """
+        Parse 'version 3' of the .pt textual inversion embedding files.
+        """
+        basename = Path(file_path).stem
+        print(f' | Loading v3 embedding file: {basename}')
+        embedding = embedding_ckpt['emb_params']
+        embedding_info = EmbeddingInfo(
+            name = f'<{basename}>',
+            embedding = embedding,
+            num_vectors_per_token = embedding.size()[0],
+            token_dim = embedding.size()[1],
+        )
+        return [embedding_info]
+
+    def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str)->List[EmbeddingInfo]:
+        """
+        Parse 'version 4' of the textual inversion embedding files. This one
+        is usually associated with .bin files trained by HuggingFace diffusers.
+        """
+        basename = Path(filepath).stem
+        short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
+        print(f' | Loading v4 embedding file: {short_path}')
+
+        embeddings = list()
+        if list(embedding_ckpt.keys()) == 0:
+            print(f" ** Invalid embeddings file: {short_path}")
+        else:
+            for token,embedding in embedding_ckpt.items():
+                embedding_info = EmbeddingInfo(
+                    name = token or f"<{basename}>",
+                    embedding = embedding,
+                    num_vectors_per_token = 1,  # All Concepts seem to default to 1
+                    token_dim = embedding.size()[0],
+                )
+                embeddings.append(embedding_info)
+        return embeddings
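
As a usage note, the new probe in `_parse_embedding` keys entirely off the checkpoint's top-level keys. Below is a minimal, standalone sketch of the same decision table, handy for checking a file by hand. The helper name `classify_embedding_file` and the example file names are illustrative and not part of this PR, and the picklescan malware check that the real code runs before `torch.load` is omitted here:

```python
from pathlib import Path

import safetensors.torch
import torch


def classify_embedding_file(path: str) -> str:
    """Mirror of the key-probing logic in _parse_embedding (illustrative only)."""
    if Path(path).suffix in [".pt", ".ckpt", ".bin"]:
        ckpt = torch.load(path, map_location="cpu")  # real code scans with picklescan first
    else:
        ckpt = safetensors.torch.load_file(path)
    keys = list(ckpt.keys())
    if all(x in keys for x in ["string_to_token", "string_to_param", "name", "step"]):
        return "v1"  # classic .pt with training metadata, e.g. rem_rezero.pt
    elif all(x in keys for x in ["string_to_token", "string_to_param"]):
        return "v2"  # .pt variant without training metadata, e.g. midj-strong.pt
    elif "emb_params" in keys:
        return "v3"  # single-key format, e.g. easynegative.safetensors
    else:
        return "v4"  # usually a diffusers-style learned_embeds.bin
```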