Add Embedding Parsing (#1973)
* Add Embedding Parsing

* Return token_dim in embedding_info

* Fixes to handle other variants:

  1. Handle the case of a .bin file being mislabeled .pt (seen in the wild at https://cyberes.github.io/stable-diffusion-textual-inversion-models/)
  2. Handle the "broken" .pt files reported by https://github.com/invoke-ai/InvokeAI/issues/1829
  3. When the token name is not available, use the basename of the .pt or .bin file rather than the whole path.

  Fixes #1829

* Remove whitespace

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
parent 1050f2726a
commit 69cc0993f8
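For reference, the three on-disk layouts the new parsing code distinguishes look roughly as follows. This is a hedged sketch: the key names come from the diff below, but the tensor shapes and the example token and concept names are illustrative only.

    import torch

    # 1. A standard textual-inversion .pt file: both lookup dicts are present, and
    #    string_to_param maps each token to a [num_vectors, token_dim] tensor.
    standard_pt = {
        'string_to_token': {'*': 265},                  # token string -> token id
        'string_to_param': {'*': torch.zeros(2, 768)},  # token string -> embedding tensor
        'name': 'my-concept',                           # optional training metadata
    }

    # 2. A Huggingface Concepts .bin file (sd-concepts-library): a single
    #    token -> 1-D tensor mapping, sometimes mislabeled with a .pt extension.
    concepts_bin = {'<my-concept>': torch.zeros(768)}

    # 3. The "broken" variant from issue #1829: string_to_token maps the token
    #    string directly to a Tensor, and the embedding itself has to be pulled
    #    out of string_to_param via state_dict().
    broken_pt = {
        'string_to_token': {'*': torch.zeros(2, 768)},
        'string_to_param': None,  # in the wild: an object whose state_dict() holds the tensor
    }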
@@ -256,32 +256,22 @@ class EmbeddingManager(nn.Module):
         return [x for x in expanded_paths if os.path.splitext(x)[1] in ('.pt','.bin')]
 
     def _load(self, ckpt_path, full=True):
         try:
             scan_result = scan_file_path(ckpt_path)
             if scan_result.infected_files == 1:
                 print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
                 print('### For your safety, InvokeAI will not load this embed.')
                 return
         except Exception:
             print(f"### WARNING::: Invalid or corrupt embeddings found. Ignoring: {ckpt_path}")
             return
 
-        ckpt = torch.load(ckpt_path, map_location='cpu')
-
-        # Handle .pt textual inversion files
-        if 'string_to_token' in ckpt and 'string_to_param' in ckpt:
-            filename = os.path.basename(ckpt_path)
-            token_str = '.'.join(filename.split('.')[:-1]) # filename excluding extension
-            if len(ckpt["string_to_token"]) > 1:
-                print(f">> {ckpt_path} has >1 embedding, only the first will be used")
-
-            string_to_param_dict = ckpt['string_to_param']
-            embedding = list(string_to_param_dict.values())[0]
-            self.add_embedding(token_str, embedding, full)
-
-        # Handle .bin textual inversion files from Huggingface Concepts
-        # https://huggingface.co/sd-concepts-library
+        embedding_info = self.parse_embedding(ckpt_path)
+        if embedding_info:
+            self.max_vectors_per_token = embedding_info['num_vectors_per_token']
+            self.add_embedding(embedding_info['name'], embedding_info['embedding'], full)
         else:
-            for token_str in list(ckpt.keys()):
-                embedding = ckpt[token_str]
-                self.add_embedding(token_str, embedding, full)
+            print(f'>> Failed to load embedding located at {ckpt_path}. Unsupported file.')
 
     def add_embedding(self, token_str, embedding, full):
         if token_str in self.string_to_param_dict:
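The malware guard retained at the top of _load relies on scan_file_path, which (assuming the module-level import already present elsewhere in this file) comes from the picklescan library; the issues_count and infected_files attributes used above are part of its scan result. A minimal standalone check might look like this:

    from picklescan.scanner import scan_file_path  # assumed source of the helper used above

    result = scan_file_path('embeddings/my-concept.pt')  # hypothetical path
    if result.infected_files:
        print(f'refusing to load: {result.issues_count} issue(s) found')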
@@ -302,6 +292,92 @@ class EmbeddingManager(nn.Module):
         self.string_to_token_dict[token_str] = token
         self.string_to_param_dict[token_str] = torch.nn.Parameter(embedding)
 
+    def parse_embedding(self, embedding_file: str):
+        file_type = embedding_file.split('.')[-1]
+        if file_type == 'pt':
+            return self.parse_embedding_pt(embedding_file)
+        elif file_type == 'bin':
+            return self.parse_embedding_bin(embedding_file)
+        else:
+            print(f'>> Not a recognized embedding file: {embedding_file}')
+
+    def parse_embedding_pt(self, embedding_file):
+        embedding_ckpt = torch.load(embedding_file, map_location='cpu')
+        embedding_info = {}
+
+        # Check if valid embedding file
+        if 'string_to_token' in embedding_ckpt and 'string_to_param' in embedding_ckpt:
+            # Catch variants that do not have the expected keys or values.
+            try:
+                embedding_info['name'] = embedding_ckpt['name'] or os.path.basename(os.path.splitext(embedding_file)[0])
+
+                # Check num of embeddings and warn user only the first will be used
+                embedding_info['num_of_embeddings'] = len(embedding_ckpt["string_to_token"])
+                if embedding_info['num_of_embeddings'] > 1:
+                    print('>> More than 1 embedding found. Will use the first one')
+
+                embedding = list(embedding_ckpt['string_to_param'].values())[0]
+            except (AttributeError, KeyError):
+                return self.handle_broken_pt_variants(embedding_ckpt, embedding_file)
+
+            embedding_info['embedding'] = embedding
+            embedding_info['num_vectors_per_token'] = embedding.size()[0]
+            embedding_info['token_dim'] = embedding.size()[1]
+
+            try:
+                embedding_info['trained_steps'] = embedding_ckpt['step']
+                embedding_info['trained_model_name'] = embedding_ckpt['sd_checkpoint_name']
+                embedding_info['trained_model_checksum'] = embedding_ckpt['sd_checkpoint']
+            except (AttributeError, KeyError):
+                # a missing dict key raises KeyError, not AttributeError, so catch both
+                print(">> No Training Details Found. Passing ...")
+
+        # .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
+        # They are actually .bin files
+        elif len(embedding_ckpt.keys()) == 1:
+            print('>> Detected .bin file masquerading as .pt file')
+            embedding_info = self.parse_embedding_bin(embedding_file)
+
+        else:
+            print('>> Invalid embedding format')
+            embedding_info = None
+
+        return embedding_info
+
+    def parse_embedding_bin(self, embedding_file):
+        embedding_ckpt = torch.load(embedding_file, map_location='cpu')
+        embedding_info = {}
+
+        if len(embedding_ckpt.keys()) == 0:
+            print(">> Invalid concepts file")
+            embedding_info = None
+        else:
+            for token in list(embedding_ckpt.keys()):
+                embedding_info['name'] = token or os.path.basename(os.path.splitext(embedding_file)[0])
+                embedding_info['embedding'] = embedding_ckpt[token]
+                embedding_info['num_vectors_per_token'] = 1 # All Concepts seem to default to 1
+                embedding_info['token_dim'] = embedding_info['embedding'].size()[0]
+
+        return embedding_info
+
+    def handle_broken_pt_variants(self, embedding_ckpt: dict, embedding_file: str) -> dict:
+        '''
+        This handles the broken .pt file variants. We only know of one at present.
+        '''
+        embedding_info = {}
+        if isinstance(list(embedding_ckpt['string_to_token'].values())[0], torch.Tensor):
+            print('>> Detected .pt file variant 1') # example at https://github.com/invoke-ai/InvokeAI/issues/1829
+            for token in list(embedding_ckpt['string_to_token'].keys()):
+                embedding_info['name'] = token if token != '*' else os.path.basename(os.path.splitext(embedding_file)[0])
+                embedding_info['embedding'] = embedding_ckpt['string_to_param'].state_dict()[token]
+                embedding_info['num_vectors_per_token'] = embedding_info['embedding'].shape[0]
+                embedding_info['token_dim'] = embedding_info['embedding'].size()[0]
+        else:
+            print('>> Invalid embedding format')
+            embedding_info = None
+
+        return embedding_info
+
     def has_embedding_for_token(self, token_str):
         return token_str in self.string_to_token_dict
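Putting the pieces together, a hedged usage sketch of the new path (assuming an already-constructed EmbeddingManager instance named manager; the file name is hypothetical):

    info = manager.parse_embedding('embeddings/my-concept.bin')
    if info:
        print(f"{info['name']}: {info['num_vectors_per_token']} vector(s), token_dim={info['token_dim']}")
        manager.add_embedding(info['name'], info['embedding'], True)  # full=True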