Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Add support for yet another TI embedding format (main version) (#3050)
- This PR adds support for embedding files that contain a single key "emb_params". The only example I know of this format is the "EasyNegative" embedding on HuggingFace, but there are certainly others.
- This PR also adds support for loading embedding files that have been saved in safetensors format.
- It also cleans up the code so that the logic of probing for and selecting the right format parser is clear (a sketch of the probe follows this list).
- This is the same as #3045, which is on the 2.3 branch.
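For orientation, the probe described in the last two bullets amounts to: load the file with torch.load for the pickled suffixes (.pt/.ckpt/.bin) or with safetensors.torch.load_file otherwise, then dispatch on which keys the resulting dict contains. A minimal standalone sketch of that idea follows; the helper name classify_embedding_format is illustrative only and does not exist in the codebase:

    from pathlib import Path

    import safetensors.torch
    import torch

    def classify_embedding_format(embedding_file: str) -> str:
        """Illustrative helper: guess which textual-inversion layout a file uses."""
        if Path(embedding_file).suffix in [".pt", ".ckpt", ".bin"]:
            ckpt = torch.load(embedding_file, map_location="cpu")   # pickled formats
        else:
            ckpt = safetensors.torch.load_file(embedding_file)      # safetensors
        keys = set(ckpt.keys())
        if {"string_to_token", "string_to_param", "name", "step"} <= keys:
            return "v1"   # full .pt layout with training metadata
        elif {"string_to_token", "string_to_param"} <= keys:
            return "v2"   # same layout without the metadata keys
        elif "emb_params" in keys:
            return "v3"   # single-key format, e.g. EasyNegative
        else:
            return "v4"   # diffusers-style {token: tensor} .bin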
This commit is contained in: commit c4e6511a59
@@ -1,16 +1,26 @@
-import os
 import traceback
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional, Union
+from typing import Optional, Union, List

+import safetensors.torch
 import torch

 from compel.embeddings_provider import BaseTextualInversionManager
 from picklescan.scanner import scan_file_path
 from transformers import CLIPTextModel, CLIPTokenizer

 from .concepts_lib import HuggingFaceConceptsLibrary

+@dataclass
+class EmbeddingInfo:
+    name: str
+    embedding: torch.Tensor
+    num_vectors_per_token: int
+    token_dim: int
+    trained_steps: int = None
+    trained_model_name: str = None
+    trained_model_checksum: str = None
+
 @dataclass
 class TextualInversion:
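The new parsers below return a list of these EmbeddingInfo records instead of a bare dict. A self-contained illustration of how one gets filled in; the tensor shape (2 vectors of dimension 768) and the trigger name are made-up example values:

    from dataclasses import dataclass

    import torch

    @dataclass
    class EmbeddingInfo:            # same fields as the dataclass added above
        name: str
        embedding: torch.Tensor
        num_vectors_per_token: int
        token_dim: int
        trained_steps: int = None
        trained_model_name: str = None
        trained_model_checksum: str = None

    demo = torch.zeros(2, 768)      # 2 vectors, 768-dim tokens (illustrative sizes)
    info = EmbeddingInfo(
        name="<easynegative>",      # the trigger string a prompt would use
        embedding=demo,
        num_vectors_per_token=demo.size()[0],
        token_dim=demo.size()[1],
    )                               # training metadata stays at its None defaults
    print(info.name, info.num_vectors_per_token, info.token_dim)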
@@ -72,66 +82,46 @@ class TextualInversionManager(BaseTextualInversionManager):
         if str(ckpt_path).endswith(".DS_Store"):
             return

-        try:
-            scan_result = scan_file_path(str(ckpt_path))
-            if scan_result.infected_files == 1:
-                print(
-                    f"\n### Security Issues Found in Model: {scan_result.issues_count}"
-                )
-                print("### For your safety, InvokeAI will not load this embed.")
-                return
-        except Exception:
-            print(
-                f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt."
-            )
-            return
-
-        embedding_info = self._parse_embedding(str(ckpt_path))
-
-        if embedding_info is None:
-            # We've already put out an error message about the bad embedding in _parse_embedding, so just return.
-            return
-        elif (
-            self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
-            != embedding_info["token_dim"]
-        ):
-            print(
-                f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
-            )
-            return
-
-        # Resolve the situation in which an earlier embedding has claimed the same
-        # trigger string. We replace the trigger with '<source_file>', as we used to.
-        trigger_str = embedding_info["name"]
-        sourcefile = (
-            f"{ckpt_path.parent.name}/{ckpt_path.name}"
-            if ckpt_path.name == "learned_embeds.bin"
-            else ckpt_path.name
-        )
-
-        if trigger_str in self.trigger_to_sourcefile:
-            replacement_trigger_str = (
-                f"<{ckpt_path.parent.name}>"
-                if ckpt_path.name == "learned_embeds.bin"
-                else f"<{ckpt_path.stem}>"
-            )
-            print(
-                f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
-            )
-            trigger_str = replacement_trigger_str
-
-        try:
-            self._add_textual_inversion(
-                trigger_str,
-                embedding_info["embedding"],
-                defer_injecting_tokens=defer_injecting_tokens,
-            )
-            # remember which source file claims this trigger
-            self.trigger_to_sourcefile[trigger_str] = sourcefile
-
-        except ValueError as e:
-            print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
-            print(f" | The error was {str(e)}")
+        embedding_list = self._parse_embedding(str(ckpt_path))
+        for embedding_info in embedding_list:
+            if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
+                print(
+                    f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
+                )
+                continue
+
+            # Resolve the situation in which an earlier embedding has claimed the same
+            # trigger string. We replace the trigger with '<source_file>', as we used to.
+            trigger_str = embedding_info.name
+            sourcefile = (
+                f"{ckpt_path.parent.name}/{ckpt_path.name}"
+                if ckpt_path.name == "learned_embeds.bin"
+                else ckpt_path.name
+            )
+
+            if trigger_str in self.trigger_to_sourcefile:
+                replacement_trigger_str = (
+                    f"<{ckpt_path.parent.name}>"
+                    if ckpt_path.name == "learned_embeds.bin"
+                    else f"<{ckpt_path.stem}>"
+                )
+                print(
+                    f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
+                )
+                trigger_str = replacement_trigger_str
+
+            try:
+                self._add_textual_inversion(
+                    trigger_str,
+                    embedding_info.embedding,
+                    defer_injecting_tokens=defer_injecting_tokens,
+                )
+                # remember which source file claims this trigger
+                self.trigger_to_sourcefile[trigger_str] = sourcefile
+
+            except ValueError as e:
+                print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
+                print(f" | The error was {str(e)}")

     def _add_textual_inversion(
         self, trigger_str, embedding, defer_injecting_tokens=False
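The collision handling inside the loop above can be read on its own: an embedding whose trigger is already claimed by another file still loads, but under a replacement trigger derived from its source file. A small standalone sketch of that rule; resolve_trigger and the sample dictionary are illustrative, not part of the diff:

    from pathlib import Path

    def resolve_trigger(trigger_str: str, ckpt_path: Path, trigger_to_sourcefile: dict) -> str:
        """Illustrative: pick the trigger a colliding embedding will actually respond to."""
        if trigger_str not in trigger_to_sourcefile:
            return trigger_str
        # collision: fall back to a trigger derived from the source file,
        # <folder> for a diffusers-style learned_embeds.bin, <file stem> otherwise
        return (
            f"<{ckpt_path.parent.name}>"
            if ckpt_path.name == "learned_embeds.bin"
            else f"<{ckpt_path.stem}>"
        )

    # example: 'easynegative' is already claimed by another file
    taken = {"easynegative": "other/easynegative.pt"}
    print(resolve_trigger("easynegative", Path("embeddings/easynegative.safetensors"), taken))
    # -> <easynegative>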
@@ -309,111 +299,130 @@ class TextualInversionManager(BaseTextualInversionManager):

         return token_id

-    def _parse_embedding(self, embedding_file: str):
-        file_type = embedding_file.split(".")[-1]
-        if file_type == "pt":
-            return self._parse_embedding_pt(embedding_file)
-        elif file_type == "bin":
-            return self._parse_embedding_bin(embedding_file)
-        else:
-            print(f"** Notice: unrecognized embedding file format: {embedding_file}")
-            return None
-
-    def _parse_embedding_pt(self, embedding_file):
-        embedding_ckpt = torch.load(embedding_file, map_location="cpu")
-        embedding_info = {}
-
-        # Check if valid embedding file
-        if "string_to_token" and "string_to_param" in embedding_ckpt:
-            # Catch variants that do not have the expected keys or values.
-            try:
-                embedding_info["name"] = embedding_ckpt["name"] or os.path.basename(
-                    os.path.splitext(embedding_file)[0]
-                )
-
-                # Check num of embeddings and warn user only the first will be used
-                embedding_info["num_of_embeddings"] = len(
-                    embedding_ckpt["string_to_token"]
-                )
-                if embedding_info["num_of_embeddings"] > 1:
-                    print(">> More than 1 embedding found. Will use the first one")
-
-                embedding = list(embedding_ckpt["string_to_param"].values())[0]
-            except (AttributeError, KeyError):
-                return self._handle_broken_pt_variants(embedding_ckpt, embedding_file)
-
-            embedding_info["embedding"] = embedding
-            embedding_info["num_vectors_per_token"] = embedding.size()[0]
-            embedding_info["token_dim"] = embedding.size()[1]
-
-            try:
-                embedding_info["trained_steps"] = embedding_ckpt["step"]
-                embedding_info["trained_model_name"] = embedding_ckpt[
-                    "sd_checkpoint_name"
-                ]
-                embedding_info["trained_model_checksum"] = embedding_ckpt[
-                    "sd_checkpoint"
-                ]
-            except AttributeError:
-                print(">> No Training Details Found. Passing ...")
-
-        # .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
-        # They are actually .bin files
-        elif len(embedding_ckpt.keys()) == 1:
-            embedding_info = self._parse_embedding_bin(embedding_file)
-
-        else:
-            print(">> Invalid embedding format")
-            embedding_info = None
-
-        return embedding_info
-
-    def _parse_embedding_bin(self, embedding_file):
-        embedding_ckpt = torch.load(embedding_file, map_location="cpu")
-        embedding_info = {}
-
-        if list(embedding_ckpt.keys()) == 0:
-            print(">> Invalid concepts file")
-            embedding_info = None
-        else:
-            for token in list(embedding_ckpt.keys()):
-                embedding_info["name"] = (
-                    token
-                    or f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
-                )
-                embedding_info["embedding"] = embedding_ckpt[token]
-                embedding_info[
-                    "num_vectors_per_token"
-                ] = 1  # All Concepts seem to default to 1
-                embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
-
-        return embedding_info
-
-    def _handle_broken_pt_variants(
-        self, embedding_ckpt: dict, embedding_file: str
-    ) -> dict:
-        """
-        This handles the broken .pt file variants. We only know of one at present.
-        """
-        embedding_info = {}
-        if isinstance(
-            list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
-        ):
-            for token in list(embedding_ckpt["string_to_token"].keys()):
-                embedding_info["name"] = (
-                    token
-                    if token != "*"
-                    else f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
-                )
-                embedding_info["embedding"] = embedding_ckpt[
-                    "string_to_param"
-                ].state_dict()[token]
-                embedding_info["num_vectors_per_token"] = embedding_info[
-                    "embedding"
-                ].shape[0]
-                embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
-        else:
-            print(">> Invalid embedding format")
-            embedding_info = None
-
-        return embedding_info
+    def _parse_embedding(self, embedding_file: str)->List[EmbeddingInfo]:
+        suffix = Path(embedding_file).suffix
+        try:
+            if suffix in [".pt",".ckpt",".bin"]:
+                scan_result = scan_file_path(embedding_file)
+                if scan_result.infected_files > 0:
+                    print(
+                        f" ** Security Issues Found in Model: {scan_result.issues_count}"
+                    )
+                    print(" ** For your safety, InvokeAI will not load this embed.")
+                    return list()
+                ckpt = torch.load(embedding_file,map_location="cpu")
+            else:
+                ckpt = safetensors.torch.load_file(embedding_file)
+        except Exception as e:
+            print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
+            return list()
+
+        # try to figure out what kind of embedding file it is and parse accordingly
+        keys = list(ckpt.keys())
+        if all(x in keys for x in ['string_to_token','string_to_param','name','step']):
+            return self._parse_embedding_v1(ckpt, embedding_file) # example rem_rezero.pt
+
+        elif all(x in keys for x in ['string_to_token','string_to_param']):
+            return self._parse_embedding_v2(ckpt, embedding_file) # example midj-strong.pt
+
+        elif 'emb_params' in keys:
+            return self._parse_embedding_v3(ckpt, embedding_file) # example easynegative.safetensors
+
+        else:
+            return self._parse_embedding_v4(ckpt, embedding_file) # usually a '.bin' file
+
+    def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
+        basename = Path(file_path).stem
+        print(f'   | Loading v1 embedding file: {basename}')
+
+        embeddings = list()
+        token_counter = -1
+        for token,embedding in embedding_ckpt["string_to_param"].items():
+            if token_counter < 0:
+                trigger = embedding_ckpt["name"]
+            elif token_counter == 0:
+                trigger = f'<basename>'
+            else:
+                trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
+            token_counter += 1
+            embedding_info = EmbeddingInfo(
+                name = trigger,
+                embedding = embedding,
+                num_vectors_per_token = embedding.size()[0],
+                token_dim = embedding.size()[1],
+                trained_steps = embedding_ckpt["step"],
+                trained_model_name = embedding_ckpt["sd_checkpoint_name"],
+                trained_model_checksum = embedding_ckpt["sd_checkpoint"]
+            )
+            embeddings.append(embedding_info)
+        return embeddings
+
+    def _parse_embedding_v2 (
+        self, embedding_ckpt: dict, file_path: str
+    ) -> List[EmbeddingInfo]:
+        """
+        This handles embedding .pt file variant #2.
+        """
+        basename = Path(file_path).stem
+        print(f'   | Loading v2 embedding file: {basename}')
+        embeddings = list()
+
+        if isinstance(
+            list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
+        ):
+            token_counter = 0
+            for token,embedding in embedding_ckpt["string_to_param"].items():
+                trigger = token if token != '*' \
+                    else f'<{basename}>' if token_counter == 0 \
+                    else f'<{basename}-{int(token_counter:=token_counter+1)}>'
+                embedding_info = EmbeddingInfo(
+                    name = trigger,
+                    embedding = embedding,
+                    num_vectors_per_token = embedding.size()[0],
+                    token_dim = embedding.size()[1],
+                )
+                embeddings.append(embedding_info)
+        else:
+            print(f" ** {basename}: Unrecognized embedding format")
+
+        return embeddings
+
+    def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
+        """
+        Parse 'version 3' of the .pt textual inversion embedding files.
+        """
+        basename = Path(file_path).stem
+        print(f'   | Loading v3 embedding file: {basename}')
+        embedding = embedding_ckpt['emb_params']
+        embedding_info = EmbeddingInfo(
+            name = f'<{basename}>',
+            embedding = embedding,
+            num_vectors_per_token = embedding.size()[0],
+            token_dim = embedding.size()[1],
+        )
+        return [embedding_info]
+
+    def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str)->List[EmbeddingInfo]:
+        """
+        Parse 'version 4' of the textual inversion embedding files. This one
+        is usually associated with .bin files trained by HuggingFace diffusers.
+        """
+        basename = Path(filepath).stem
+        short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
+
+        print(f'   | Loading v4 embedding file: {short_path}')
+
+        embeddings = list()
+        if list(embedding_ckpt.keys()) == 0:
+            print(f" ** Invalid embeddings file: {short_path}")
+        else:
+            for token,embedding in embedding_ckpt.items():
+                embedding_info = EmbeddingInfo(
+                    name = token or f"<{basename}>",
+                    embedding = embedding,
+                    num_vectors_per_token = 1, # All Concepts seem to default to 1
+                    token_dim = embedding.size()[0],
+                )
+                embeddings.append(embedding_info)
+        return embeddings
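Taken together, the probe plus _parse_embedding_v3 is what lets a single-key safetensors file such as EasyNegative load: the non-pickled suffix routes it through safetensors.torch.load_file, and its one 'emb_params' tensor becomes one EmbeddingInfo. A self-contained sketch of that round trip with a toy file; the file name and tensor shape are invented for the example:

    import safetensors.torch
    import torch

    # write a toy single-key embedding using the 'emb_params' layout
    safetensors.torch.save_file({"emb_params": torch.zeros(8, 768)}, "toy_negative.safetensors")

    # load it back the way _parse_embedding does for non-pickled suffixes
    ckpt = safetensors.torch.load_file("toy_negative.safetensors")
    embedding = ckpt["emb_params"]
    print(embedding.size()[0], embedding.size()[1])   # 8 vectors, token_dim 768 -> one EmbeddingInfo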