mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
handle multiple tokens and embeddings in single file
This commit is contained in:
parent
cdb3616dca
commit
e11c1d66ab
@ -1,7 +1,7 @@
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Union, List
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
@ -12,6 +12,16 @@ from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from .concepts_lib import HuggingFaceConceptsLibrary
|
||||
|
||||
@dataclass
|
||||
class EmbeddingInfo:
|
||||
name: str
|
||||
embedding: torch.Tensor
|
||||
num_vectors_per_token: int
|
||||
token_dim: int
|
||||
trained_steps: int = None
|
||||
trained_model_name: str = None
|
||||
trained_model_checksum: str = None
|
||||
|
||||
@dataclass
|
||||
class TextualInversion:
|
||||
trigger_string: str
|
||||
@ -72,52 +82,46 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
if str(ckpt_path).endswith(".DS_Store"):
|
||||
return
|
||||
|
||||
embedding_info = self._parse_embedding(str(ckpt_path))
|
||||
embedding_list = self._parse_embedding(str(ckpt_path))
|
||||
for embedding_info in embedding_list:
|
||||
if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
|
||||
print(
|
||||
f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
|
||||
)
|
||||
continue
|
||||
|
||||
if embedding_info is None:
|
||||
# We've already put out an error message about the bad embedding in _parse_embedding, so just return.
|
||||
return
|
||||
elif (
|
||||
self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
|
||||
!= embedding_info["token_dim"]
|
||||
):
|
||||
print(
|
||||
f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
|
||||
)
|
||||
return
|
||||
|
||||
# Resolve the situation in which an earlier embedding has claimed the same
|
||||
# trigger string. We replace the trigger with '<source_file>', as we used to.
|
||||
trigger_str = embedding_info["name"]
|
||||
sourcefile = (
|
||||
f"{ckpt_path.parent.name}/{ckpt_path.name}"
|
||||
if ckpt_path.name == "learned_embeds.bin"
|
||||
else ckpt_path.name
|
||||
)
|
||||
|
||||
if trigger_str in self.trigger_to_sourcefile:
|
||||
replacement_trigger_str = (
|
||||
f"<{ckpt_path.parent.name}>"
|
||||
# Resolve the situation in which an earlier embedding has claimed the same
|
||||
# trigger string. We replace the trigger with '<source_file>', as we used to.
|
||||
trigger_str = embedding_info.name
|
||||
sourcefile = (
|
||||
f"{ckpt_path.parent.name}/{ckpt_path.name}"
|
||||
if ckpt_path.name == "learned_embeds.bin"
|
||||
else f"<{ckpt_path.stem}>"
|
||||
else ckpt_path.name
|
||||
)
|
||||
print(
|
||||
f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
|
||||
)
|
||||
trigger_str = replacement_trigger_str
|
||||
|
||||
try:
|
||||
self._add_textual_inversion(
|
||||
trigger_str,
|
||||
embedding_info["embedding"],
|
||||
defer_injecting_tokens=defer_injecting_tokens,
|
||||
)
|
||||
# remember which source file claims this trigger
|
||||
self.trigger_to_sourcefile[trigger_str] = sourcefile
|
||||
if trigger_str in self.trigger_to_sourcefile:
|
||||
replacement_trigger_str = (
|
||||
f"<{ckpt_path.parent.name}>"
|
||||
if ckpt_path.name == "learned_embeds.bin"
|
||||
else f"<{ckpt_path.stem}>"
|
||||
)
|
||||
print(
|
||||
f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
|
||||
)
|
||||
trigger_str = replacement_trigger_str
|
||||
|
||||
except ValueError as e:
|
||||
print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
|
||||
print(f" | The error was {str(e)}")
|
||||
try:
|
||||
self._add_textual_inversion(
|
||||
trigger_str,
|
||||
embedding_info.embedding,
|
||||
defer_injecting_tokens=defer_injecting_tokens,
|
||||
)
|
||||
# remember which source file claims this trigger
|
||||
self.trigger_to_sourcefile[trigger_str] = sourcefile
|
||||
|
||||
except ValueError as e:
|
||||
print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
|
||||
print(f" | The error was {str(e)}")
|
||||
|
||||
def _add_textual_inversion(
|
||||
self, trigger_str, embedding, defer_injecting_tokens=False
|
||||
@ -295,23 +299,24 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
|
||||
return token_id
|
||||
|
||||
def _parse_embedding(self, embedding_file: str)->dict:
|
||||
|
||||
def _parse_embedding(self, embedding_file: str)->List[EmbeddingInfo]:
|
||||
suffix = Path(embedding_file).suffix
|
||||
try:
|
||||
if suffix in [".pt",".ckpt",".bin"]:
|
||||
scan_result = scan_file_path(embedding_file)
|
||||
if scan_result.infected_files == 1:
|
||||
if scan_result.infected_files > 0:
|
||||
print(
|
||||
f" ** Security Issues Found in Model: {scan_result.issues_count}"
|
||||
)
|
||||
print(" ** For your safety, InvokeAI will not load this embed.")
|
||||
return
|
||||
return list()
|
||||
ckpt = torch.load(embedding_file,map_location="cpu")
|
||||
else:
|
||||
ckpt = safetensors.torch.load_file(embedding_file)
|
||||
except Exception as e:
|
||||
print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
|
||||
return None
|
||||
return list()
|
||||
|
||||
# try to figure out what kind of embedding file it is and parse accordingly
|
||||
keys = list(ckpt.keys())
|
||||
@ -327,79 +332,78 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
else:
|
||||
return self._parse_embedding_v4(ckpt, embedding_file) # usually a '.bin' file
|
||||
|
||||
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str):
|
||||
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
|
||||
basename = Path(file_path).stem
|
||||
print(f' | Loading v1 embedding file: {basename}')
|
||||
|
||||
embedding_info = {}
|
||||
embedding_info["name"] = embedding_ckpt["name"]
|
||||
|
||||
# Check num of embeddings and warn user only the first will be used
|
||||
embedding_info["num_of_embeddings"] = len(
|
||||
embedding_ckpt["string_to_token"]
|
||||
)
|
||||
if embedding_info["num_of_embeddings"] > 1:
|
||||
print(" | More than 1 embedding found. Will use the first one")
|
||||
embedding = list(embedding_ckpt["string_to_param"].values())[0]
|
||||
embedding_info["embedding"] = embedding
|
||||
embedding_info["num_vectors_per_token"] = embedding.size()[0]
|
||||
embedding_info["token_dim"] = embedding.size()[1]
|
||||
embedding_info["trained_steps"] = embedding_ckpt["step"]
|
||||
embedding_info["trained_model_name"] = embedding_ckpt[
|
||||
"sd_checkpoint_name"
|
||||
]
|
||||
embedding_info["trained_model_checksum"] = embedding_ckpt[
|
||||
"sd_checkpoint"
|
||||
]
|
||||
return embedding_info
|
||||
embeddings = list()
|
||||
token_counter = -1
|
||||
for token,embedding in embedding_ckpt["string_to_param"].items():
|
||||
if token_counter < 0:
|
||||
trigger = embedding_ckpt["name"]
|
||||
elif token_counter == 0:
|
||||
trigger = f'<basename>'
|
||||
else:
|
||||
trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
|
||||
token_counter += 1
|
||||
embedding_info = EmbeddingInfo(
|
||||
name = trigger,
|
||||
embedding = embedding,
|
||||
num_vectors_per_token = embedding.size()[0],
|
||||
token_dim = embedding.size()[1],
|
||||
trained_steps = embedding_ckpt["step"],
|
||||
trained_model_name = embedding_ckpt["sd_checkpoint_name"],
|
||||
trained_model_checksum = embedding_ckpt["sd_checkpoint"]
|
||||
)
|
||||
embeddings.append(embedding_info)
|
||||
return embeddings
|
||||
|
||||
def _parse_embedding_v2 (
|
||||
self, embedding_ckpt: dict, file_path: str
|
||||
) -> dict:
|
||||
) -> List[EmbeddingInfo]:
|
||||
"""
|
||||
This handles embedding .pt file variant #2.
|
||||
"""
|
||||
basename = Path(file_path).stem
|
||||
print(f' | Loading v2 embedding file: {basename}')
|
||||
embedding_info = {}
|
||||
embeddings = list()
|
||||
|
||||
if isinstance(
|
||||
list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
|
||||
):
|
||||
for token in list(embedding_ckpt["string_to_token"].keys()):
|
||||
embedding_info["name"] = (
|
||||
token
|
||||
if token != "*"
|
||||
else f"<{basename}>"
|
||||
token_counter = 0
|
||||
for token,embedding in embedding_ckpt["string_to_param"].items():
|
||||
trigger = token if token != '*' \
|
||||
else f'<{basename}>' if token_counter == 0 \
|
||||
else f'<{basename}-{int(token_counter:=token_counter+1)}>'
|
||||
embedding_info = EmbeddingInfo(
|
||||
name = trigger,
|
||||
embedding = embedding,
|
||||
num_vectors_per_token = embedding.size()[0],
|
||||
token_dim = embedding.size()[1],
|
||||
)
|
||||
embedding_info["embedding"] = embedding_ckpt[
|
||||
"string_to_param"
|
||||
].state_dict()[token]
|
||||
embedding_info["num_vectors_per_token"] = embedding_info[
|
||||
"embedding"
|
||||
].shape[0]
|
||||
embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
|
||||
embeddings.append(embedding_info)
|
||||
else:
|
||||
print(f" ** {basename}: Unrecognized embedding format")
|
||||
embedding_info = None
|
||||
|
||||
return embedding_info
|
||||
return embeddings
|
||||
|
||||
def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str):
|
||||
def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
|
||||
"""
|
||||
Parse 'version 3' of the .pt textual inversion embedding files.
|
||||
"""
|
||||
basename = Path(file_path).stem
|
||||
print(f' | Loading v3 embedding file: {basename}')
|
||||
embedding_info = {}
|
||||
embedding_info["name"] = f'<{basename}>'
|
||||
embedding_info["num_of_embeddings"] = 1
|
||||
embedding = embedding_ckpt['emb_params']
|
||||
embedding_info["embedding"] = embedding
|
||||
embedding_info["num_vectors_per_token"] = embedding.size()[0]
|
||||
embedding_info["token_dim"] = embedding.size()[1]
|
||||
return embedding_info
|
||||
embedding_info = EmbeddingInfo(
|
||||
name = f'<{basename}>',
|
||||
embedding = embedding,
|
||||
num_vectors_per_token = embedding.size()[0],
|
||||
token_dim = embedding.size()[1],
|
||||
)
|
||||
return [embedding_info]
|
||||
|
||||
def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str):
|
||||
def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str)->List[EmbeddingInfo]:
|
||||
"""
|
||||
Parse 'version 4' of the textual inversion embedding files. This one
|
||||
is usually associated with .bin files trained by HuggingFace diffusers.
|
||||
@ -408,17 +412,17 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
|
||||
|
||||
print(f' | Loading v4 embedding file: {short_path}')
|
||||
embedding_info = {}
|
||||
|
||||
embeddings = list()
|
||||
if list(embedding_ckpt.keys()) == 0:
|
||||
print(f" ** Invalid embeddings file: {short_path}")
|
||||
embedding_info = None
|
||||
else:
|
||||
for token in list(embedding_ckpt.keys()):
|
||||
embedding_info["name"] = (
|
||||
token
|
||||
or f"<{basename}>"
|
||||
for token,embedding in embedding_ckpt.items():
|
||||
embedding_info = EmbeddingInfo(
|
||||
name = token or f"<{basename}>",
|
||||
embedding = embedding,
|
||||
num_vectors_per_token = 1, # All Concepts seem to default to 1
|
||||
token_dim = embedding.size()[0],
|
||||
)
|
||||
embedding_info["embedding"] = embedding_ckpt[token]
|
||||
embedding_info["num_vectors_per_token"] = 1 # All Concepts seem to default to 1
|
||||
embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
|
||||
return embedding_info
|
||||
embeddings.append(embedding_info)
|
||||
return embeddings
|
||||
|
Loading…
Reference in New Issue
Block a user