mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add support for yet another textual inversion embedding format
- This PR adds support for embedding files that contain a single key "emb_params". The only example I know of this format is the "EasyNegative" embedding on HuggingFace, but there are certainly others. - This PR also adds support for loading embedding files that have been saved in safetensors format. - It also cleans up the code so that the logic of probing for and selecting the right format parser is clear.
This commit is contained in:
parent
09dfde0ba1
commit
abe4dc8ac1
@ -1,17 +1,17 @@
|
||||
import os
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
|
||||
from compel.embeddings_provider import BaseTextualInversionManager
|
||||
from picklescan.scanner import scan_file_path
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from .concepts_lib import HuggingFaceConceptsLibrary
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextualInversion:
|
||||
trigger_string: str
|
||||
@ -72,20 +72,6 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
if str(ckpt_path).endswith(".DS_Store"):
|
||||
return
|
||||
|
||||
try:
|
||||
scan_result = scan_file_path(str(ckpt_path))
|
||||
if scan_result.infected_files == 1:
|
||||
print(
|
||||
f"\n### Security Issues Found in Model: {scan_result.issues_count}"
|
||||
)
|
||||
print("### For your safety, InvokeAI will not load this embed.")
|
||||
return
|
||||
except Exception:
|
||||
print(
|
||||
f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt."
|
||||
)
|
||||
return
|
||||
|
||||
embedding_info = self._parse_embedding(str(ckpt_path))
|
||||
|
||||
if embedding_info is None:
|
||||
@ -96,7 +82,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
!= embedding_info["token_dim"]
|
||||
):
|
||||
print(
|
||||
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
|
||||
f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
|
||||
)
|
||||
return
|
||||
|
||||
@ -309,92 +295,72 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
|
||||
return token_id
|
||||
|
||||
def _parse_embedding(self, embedding_file: str):
|
||||
file_type = embedding_file.split(".")[-1]
|
||||
if file_type == "pt":
|
||||
return self._parse_embedding_pt(embedding_file)
|
||||
elif file_type == "bin":
|
||||
return self._parse_embedding_bin(embedding_file)
|
||||
else:
|
||||
print(f"** Notice: unrecognized embedding file format: {embedding_file}")
|
||||
def _parse_embedding(self, embedding_file: str)->dict:
|
||||
suffix = Path(embedding_file).suffix
|
||||
try:
|
||||
if suffix in [".pt",".ckpt",".bin"]:
|
||||
scan_result = scan_file_path(embedding_file)
|
||||
if scan_result.infected_files == 1:
|
||||
print(
|
||||
f" ** Security Issues Found in Model: {scan_result.issues_count}"
|
||||
)
|
||||
print(" ** For your safety, InvokeAI will not load this embed.")
|
||||
return
|
||||
ckpt = torch.load(embedding_file,map_location="cpu")
|
||||
else:
|
||||
ckpt = safetensors.torch.load_file(embedding_file)
|
||||
except Exception as e:
|
||||
print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
|
||||
return None
|
||||
|
||||
def _parse_embedding_pt(self, embedding_file):
|
||||
embedding_ckpt = torch.load(embedding_file, map_location="cpu")
|
||||
embedding_info = {}
|
||||
|
||||
# Check if valid embedding file
|
||||
if "string_to_token" and "string_to_param" in embedding_ckpt:
|
||||
# Catch variants that do not have the expected keys or values.
|
||||
try:
|
||||
embedding_info["name"] = embedding_ckpt["name"] or os.path.basename(
|
||||
os.path.splitext(embedding_file)[0]
|
||||
)
|
||||
|
||||
# Check num of embeddings and warn user only the first will be used
|
||||
embedding_info["num_of_embeddings"] = len(
|
||||
embedding_ckpt["string_to_token"]
|
||||
)
|
||||
if embedding_info["num_of_embeddings"] > 1:
|
||||
print(">> More than 1 embedding found. Will use the first one")
|
||||
|
||||
embedding = list(embedding_ckpt["string_to_param"].values())[0]
|
||||
except (AttributeError, KeyError):
|
||||
return self._handle_broken_pt_variants(embedding_ckpt, embedding_file)
|
||||
|
||||
embedding_info["embedding"] = embedding
|
||||
embedding_info["num_vectors_per_token"] = embedding.size()[0]
|
||||
embedding_info["token_dim"] = embedding.size()[1]
|
||||
|
||||
try:
|
||||
embedding_info["trained_steps"] = embedding_ckpt["step"]
|
||||
embedding_info["trained_model_name"] = embedding_ckpt[
|
||||
"sd_checkpoint_name"
|
||||
]
|
||||
embedding_info["trained_model_checksum"] = embedding_ckpt[
|
||||
"sd_checkpoint"
|
||||
]
|
||||
except AttributeError:
|
||||
print(">> No Training Details Found. Passing ...")
|
||||
|
||||
# .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
|
||||
# They are actually .bin files
|
||||
elif len(embedding_ckpt.keys()) == 1:
|
||||
embedding_info = self._parse_embedding_bin(embedding_file)
|
||||
|
||||
|
||||
# try to figure out what kind of embedding file it is and parse accordingly
|
||||
keys = list(ckpt.keys())
|
||||
if all(x in keys for x in ['string_to_token','string_to_param','name','step']):
|
||||
return self._parse_embedding_v1(ckpt, embedding_file) # example rem_rezero.pt
|
||||
|
||||
elif all(x in keys for x in ['string_to_token','string_to_param']):
|
||||
return self._parse_embedding_v2(ckpt, embedding_file) # example midj-strong.pt
|
||||
|
||||
elif 'emb_params' in keys:
|
||||
return self._parse_embedding_v3(ckpt, embedding_file) # example easynegative.safetensors
|
||||
|
||||
else:
|
||||
print(">> Invalid embedding format")
|
||||
embedding_info = None
|
||||
return self._parse_embedding_v4(ckpt, embedding_file) # usually a '.bin' file
|
||||
|
||||
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str):
|
||||
basename = Path(file_path).stem
|
||||
print(f' | Loading v1 embedding file: {basename}')
|
||||
|
||||
embedding_info = {}
|
||||
embedding_info["name"] = embedding_ckpt["name"]
|
||||
|
||||
# Check num of embeddings and warn user only the first will be used
|
||||
embedding_info["num_of_embeddings"] = len(
|
||||
embedding_ckpt["string_to_token"]
|
||||
)
|
||||
if embedding_info["num_of_embeddings"] > 1:
|
||||
print(" | More than 1 embedding found. Will use the first one")
|
||||
embedding = list(embedding_ckpt["string_to_param"].values())[0]
|
||||
embedding_info["embedding"] = embedding
|
||||
embedding_info["num_vectors_per_token"] = embedding.size()[0]
|
||||
embedding_info["token_dim"] = embedding.size()[1]
|
||||
embedding_info["trained_steps"] = embedding_ckpt["step"]
|
||||
embedding_info["trained_model_name"] = embedding_ckpt[
|
||||
"sd_checkpoint_name"
|
||||
]
|
||||
embedding_info["trained_model_checksum"] = embedding_ckpt[
|
||||
"sd_checkpoint"
|
||||
]
|
||||
return embedding_info
|
||||
|
||||
def _parse_embedding_bin(self, embedding_file):
|
||||
embedding_ckpt = torch.load(embedding_file, map_location="cpu")
|
||||
embedding_info = {}
|
||||
|
||||
if list(embedding_ckpt.keys()) == 0:
|
||||
print(">> Invalid concepts file")
|
||||
embedding_info = None
|
||||
else:
|
||||
for token in list(embedding_ckpt.keys()):
|
||||
embedding_info["name"] = (
|
||||
token
|
||||
or f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
|
||||
)
|
||||
embedding_info["embedding"] = embedding_ckpt[token]
|
||||
embedding_info[
|
||||
"num_vectors_per_token"
|
||||
] = 1 # All Concepts seem to default to 1
|
||||
embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
|
||||
|
||||
return embedding_info
|
||||
|
||||
def _handle_broken_pt_variants(
|
||||
self, embedding_ckpt: dict, embedding_file: str
|
||||
def _parse_embedding_v2 (
|
||||
self, embedding_ckpt: dict, file_path: str
|
||||
) -> dict:
|
||||
"""
|
||||
This handles the broken .pt file variants. We only know of one at present.
|
||||
This handles embedding .pt file variant #2.
|
||||
"""
|
||||
basename = Path(file_path).stem
|
||||
print(f' | Loading v2 embedding file: {basename}')
|
||||
embedding_info = {}
|
||||
if isinstance(
|
||||
list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
|
||||
@ -403,7 +369,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
embedding_info["name"] = (
|
||||
token
|
||||
if token != "*"
|
||||
else f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
|
||||
else f"<{basename}>"
|
||||
)
|
||||
embedding_info["embedding"] = embedding_ckpt[
|
||||
"string_to_param"
|
||||
@ -413,7 +379,46 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
].shape[0]
|
||||
embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
|
||||
else:
|
||||
print(">> Invalid embedding format")
|
||||
print(f" ** {basename}: Unrecognized embedding format")
|
||||
embedding_info = None
|
||||
|
||||
return embedding_info
|
||||
|
||||
def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str):
|
||||
"""
|
||||
Parse 'version 3' of the .pt textual inversion embedding files.
|
||||
"""
|
||||
basename = Path(file_path).stem
|
||||
print(f' | Loading v3 embedding file: {basename}')
|
||||
embedding_info = {}
|
||||
embedding_info["name"] = f'<{basename}>'
|
||||
embedding_info["num_of_embeddings"] = 1
|
||||
embedding = embedding_ckpt['emb_params']
|
||||
embedding_info["embedding"] = embedding
|
||||
embedding_info["num_vectors_per_token"] = embedding.size()[0]
|
||||
embedding_info["token_dim"] = embedding.size()[1]
|
||||
return embedding_info
|
||||
|
||||
def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str):
|
||||
"""
|
||||
Parse 'version 4' of the textual inversion embedding files. This one
|
||||
is usually associated with .bin files trained by HuggingFace diffusers.
|
||||
"""
|
||||
basename = Path(filepath).stem
|
||||
short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
|
||||
|
||||
print(f' | Loading v4 embedding file: {short_path}')
|
||||
embedding_info = {}
|
||||
if list(embedding_ckpt.keys()) == 0:
|
||||
print(f" ** Invalid embeddings file: {short_path}")
|
||||
embedding_info = None
|
||||
else:
|
||||
for token in list(embedding_ckpt.keys()):
|
||||
embedding_info["name"] = (
|
||||
token
|
||||
or f"<{basename}>"
|
||||
)
|
||||
embedding_info["embedding"] = embedding_ckpt[token]
|
||||
embedding_info["num_vectors_per_token"] = 1 # All Concepts seem to default to 1
|
||||
embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
|
||||
return embedding_info
|
||||
|
Loading…
Reference in New Issue
Block a user