handle multiple tokens and embeddings in single file

Lincoln Stein 2023-03-29 22:05:06 -04:00
parent cdb3616dca
commit e11c1d66ab


@@ -1,7 +1,7 @@
 import traceback
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional, Union
+from typing import Optional, Union, List
 
 import safetensors.torch
 import torch
@@ -12,6 +12,16 @@ from transformers import CLIPTextModel, CLIPTokenizer
 
 from .concepts_lib import HuggingFaceConceptsLibrary
 
+@dataclass
+class EmbeddingInfo:
+    name: str
+    embedding: torch.Tensor
+    num_vectors_per_token: int
+    token_dim: int
+    trained_steps: int = None
+    trained_model_name: str = None
+    trained_model_checksum: str = None
+
 @dataclass
 class TextualInversion:
     trigger_string: str
@@ -72,23 +82,17 @@ class TextualInversionManager(BaseTextualInversionManager):
         if str(ckpt_path).endswith(".DS_Store"):
             return
-        embedding_info = self._parse_embedding(str(ckpt_path))
-
-        if embedding_info is None:
-            # We've already put out an error message about the bad embedding in _parse_embedding, so just return.
-            return
-        elif (
-            self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
-            != embedding_info["token_dim"]
-        ):
-            print(
-                f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
-            )
-            return
-
-        # Resolve the situation in which an earlier embedding has claimed the same
-        # trigger string. We replace the trigger with '<source_file>', as we used to.
-        trigger_str = embedding_info["name"]
-        sourcefile = (
-            f"{ckpt_path.parent.name}/{ckpt_path.name}"
-            if ckpt_path.name == "learned_embeds.bin"
+        embedding_list = self._parse_embedding(str(ckpt_path))
+        for embedding_info in embedding_list:
+            if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
+                print(
+                    f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
+                )
+                continue
+
+            # Resolve the situation in which an earlier embedding has claimed the same
+            # trigger string. We replace the trigger with '<source_file>', as we used to.
+            trigger_str = embedding_info.name
+            sourcefile = (
+                f"{ckpt_path.parent.name}/{ckpt_path.name}"
+                if ckpt_path.name == "learned_embeds.bin"
@@ -109,7 +113,7 @@ class TextualInversionManager(BaseTextualInversionManager):
             try:
                 self._add_textual_inversion(
                     trigger_str,
-                    embedding_info["embedding"],
+                    embedding_info.embedding,
                     defer_injecting_tokens=defer_injecting_tokens,
                 )
                 # remember which source file claims this trigger
@@ -295,23 +299,24 @@ class TextualInversionManager(BaseTextualInversionManager):
         return token_id
 
-    def _parse_embedding(self, embedding_file: str)->dict:
+
+    def _parse_embedding(self, embedding_file: str)->List[EmbeddingInfo]:
         suffix = Path(embedding_file).suffix
         try:
             if suffix in [".pt",".ckpt",".bin"]:
                 scan_result = scan_file_path(embedding_file)
-                if scan_result.infected_files == 1:
+                if scan_result.infected_files > 0:
                     print(
                         f" ** Security Issues Found in Model: {scan_result.issues_count}"
                     )
                     print(" ** For your safety, InvokeAI will not load this embed.")
-                    return
+                    return list()
                 ckpt = torch.load(embedding_file,map_location="cpu")
             else:
                 ckpt = safetensors.torch.load_file(embedding_file)
         except Exception as e:
             print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
-            return None
+            return list()
 
         # try to figure out what kind of embedding file it is and parse accordingly
         keys = list(ckpt.keys())
@@ -327,79 +332,78 @@ class TextualInversionManager(BaseTextualInversionManager):
         else:
             return self._parse_embedding_v4(ckpt, embedding_file) # usually a '.bin' file
 
-    def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str):
+    def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
         basename = Path(file_path).stem
         print(f' | Loading v1 embedding file: {basename}')
 
-        embedding_info = {}
-        embedding_info["name"] = embedding_ckpt["name"]
-
-        # Check num of embeddings and warn user only the first will be used
-        embedding_info["num_of_embeddings"] = len(
-            embedding_ckpt["string_to_token"]
-        )
-        if embedding_info["num_of_embeddings"] > 1:
-            print(" | More than 1 embedding found. Will use the first one")
-
-        embedding = list(embedding_ckpt["string_to_param"].values())[0]
-        embedding_info["embedding"] = embedding
-        embedding_info["num_vectors_per_token"] = embedding.size()[0]
-        embedding_info["token_dim"] = embedding.size()[1]
-        embedding_info["trained_steps"] = embedding_ckpt["step"]
-        embedding_info["trained_model_name"] = embedding_ckpt[
-            "sd_checkpoint_name"
-        ]
-        embedding_info["trained_model_checksum"] = embedding_ckpt[
-            "sd_checkpoint"
-        ]
-        return embedding_info
+        embeddings = list()
+        token_counter = -1
+        for token,embedding in embedding_ckpt["string_to_param"].items():
+            if token_counter < 0:
+                trigger = embedding_ckpt["name"]
+            elif token_counter == 0:
+                trigger = f'<basename>'
+            else:
+                trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
+            token_counter += 1
+            embedding_info = EmbeddingInfo(
+                name = trigger,
+                embedding = embedding,
+                num_vectors_per_token = embedding.size()[0],
+                token_dim = embedding.size()[1],
+                trained_steps = embedding_ckpt["step"],
+                trained_model_name = embedding_ckpt["sd_checkpoint_name"],
+                trained_model_checksum = embedding_ckpt["sd_checkpoint"]
+            )
+            embeddings.append(embedding_info)
+        return embeddings
 
     def _parse_embedding_v2 (
         self, embedding_ckpt: dict, file_path: str
-    ) -> dict:
+    ) -> List[EmbeddingInfo]:
         """
         This handles embedding .pt file variant #2.
         """
         basename = Path(file_path).stem
         print(f' | Loading v2 embedding file: {basename}')
-        embedding_info = {}
+        embeddings = list()
 
         if isinstance(
             list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
         ):
-            for token in list(embedding_ckpt["string_to_token"].keys()):
-                embedding_info["name"] = (
-                    token
-                    if token != "*"
-                    else f"<{basename}>"
-                )
-                embedding_info["embedding"] = embedding_ckpt[
-                    "string_to_param"
-                ].state_dict()[token]
-                embedding_info["num_vectors_per_token"] = embedding_info[
-                    "embedding"
-                ].shape[0]
-                embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
+            token_counter = 0
+            for token,embedding in embedding_ckpt["string_to_param"].items():
+                trigger = token if token != '*' \
+                    else f'<{basename}>' if token_counter == 0 \
+                    else f'<{basename}-{int(token_counter:=token_counter+1)}>'
+                embedding_info = EmbeddingInfo(
+                    name = trigger,
+                    embedding = embedding,
+                    num_vectors_per_token = embedding.size()[0],
+                    token_dim = embedding.size()[1],
+                )
+                embeddings.append(embedding_info)
         else:
             print(f" ** {basename}: Unrecognized embedding format")
-            embedding_info = None
 
-        return embedding_info
+        return embeddings
 
-    def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str):
+    def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
         """
         Parse 'version 3' of the .pt textual inversion embedding files.
         """
         basename = Path(file_path).stem
         print(f' | Loading v3 embedding file: {basename}')
-        embedding_info = {}
-        embedding_info["name"] = f'<{basename}>'
-        embedding_info["num_of_embeddings"] = 1
         embedding = embedding_ckpt['emb_params']
-        embedding_info["embedding"] = embedding
-        embedding_info["num_vectors_per_token"] = embedding.size()[0]
-        embedding_info["token_dim"] = embedding.size()[1]
-        return embedding_info
+        embedding_info = EmbeddingInfo(
+            name = f'<{basename}>',
+            embedding = embedding,
+            num_vectors_per_token = embedding.size()[0],
+            token_dim = embedding.size()[1],
+        )
+        return [embedding_info]
 
-    def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str):
+    def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str)->List[EmbeddingInfo]:
         """
         Parse 'version 4' of the textual inversion embedding files. This one
         is usually associated with .bin files trained by HuggingFace diffusers.
@@ -408,17 +412,17 @@ class TextualInversionManager(BaseTextualInversionManager):
         short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
         print(f' | Loading v4 embedding file: {short_path}')
 
-        embedding_info = {}
+        embeddings = list()
         if list(embedding_ckpt.keys()) == 0:
             print(f" ** Invalid embeddings file: {short_path}")
-            embedding_info = None
         else:
-            for token in list(embedding_ckpt.keys()):
-                embedding_info["name"] = (
-                    token
-                    or f"<{basename}>"
+            for token,embedding in embedding_ckpt.items():
+                embedding_info = EmbeddingInfo(
+                    name = token or f"<{basename}>",
+                    embedding = embedding,
+                    num_vectors_per_token = 1, # All Concepts seem to default to 1
+                    token_dim = embedding.size()[0],
                 )
-                embedding_info["embedding"] = embedding_ckpt[token]
-                embedding_info["num_vectors_per_token"] = 1 # All Concepts seem to default to 1
-                embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
-        return embedding_info
+                embeddings.append(embedding_info)
+        return embeddings
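
The practical effect of this commit is that _parse_embedding now yields a list of EmbeddingInfo records, one per token found in the file, rather than a single dict describing only the first embedding. The following is a minimal, hypothetical sketch (not code from this repository) of that calling convention; the simplified checkpoint layout and the parse_v1_like helper are illustrative assumptions that loosely mirror the v1 branch of the diff above.

# Hypothetical sketch: parsing a multi-token embedding checkpoint into a
# list of EmbeddingInfo records and iterating over the result.
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class EmbeddingInfo:
    name: str
    embedding: torch.Tensor
    num_vectors_per_token: int
    token_dim: int


def parse_v1_like(ckpt: dict, basename: str) -> List[EmbeddingInfo]:
    # Assumed checkpoint layout: {"name": str, "string_to_param": {token: Tensor}},
    # where each tensor is (num_vectors_per_token, token_dim).
    infos = []
    for i, (token, emb) in enumerate(ckpt["string_to_param"].items()):
        # First entry keeps the trained trigger name; later entries get
        # a derived trigger so every embedding in the file stays addressable.
        trigger = ckpt["name"] if i == 0 else f"<{basename}-{i}>"
        infos.append(
            EmbeddingInfo(
                name=trigger,
                embedding=emb,
                num_vectors_per_token=emb.size()[0],
                token_dim=emb.size()[1],
            )
        )
    return infos


if __name__ == "__main__":
    fake_ckpt = {
        "name": "<my-style>",
        "string_to_param": {
            "*": torch.zeros(2, 768),      # two vectors, 768-dim tokens
            "extra": torch.zeros(1, 768),  # a second embedding in the same file
        },
    }
    for info in parse_v1_like(fake_ckpt, basename="my-style"):
        print(info.name, info.num_vectors_per_token, info.token_dim)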