mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
remove unused embeddings code
This commit is contained in:
parent
850d1ee984
commit
1a7371ea17
@ -26,7 +26,7 @@ from invokeai.backend.modules.get_canvas_generation_mode import (
|
||||
from invokeai.backend.modules.parameters import parameters_to_command
|
||||
from ldm.generate import Generate
|
||||
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
|
||||
from ldm.invoke.conditioning import get_tokens_for_prompt_object, get_prompt_structure, get_tokenizer
|
||||
from invokeai.backend.ldm.conditioning import get_tokens_for_prompt_object, get_prompt_structure, get_tokenizer
|
||||
from .generator import infill_methods, PipelineIntermediateState
|
||||
from ldm.invoke.globals import ( Globals, global_converted_ckpts_dir,
|
||||
global_models_dir
|
||||
|
@ -1,377 +0,0 @@
|
||||
import os.path
|
||||
from cmath import log
|
||||
import torch
|
||||
from attr import dataclass
|
||||
from torch import nn
|
||||
|
||||
import sys
|
||||
|
||||
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
|
||||
from ldm.data.personalized import per_img_token_list
|
||||
from transformers import CLIPTokenizer
|
||||
from functools import partial
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
PROGRESSIVE_SCALE = 2000
|
||||
|
||||
|
||||
def get_clip_token_id_for_string(tokenizer: CLIPTokenizer, token_str: str) -> int:
|
||||
token_id = tokenizer.convert_tokens_to_ids(token_str)
|
||||
return token_id
|
||||
|
||||
def get_embedding_for_clip_token_id(embedder, token_id):
|
||||
if type(token_id) is not torch.Tensor:
|
||||
token_id = torch.tensor(token_id, dtype=torch.int)
|
||||
return embedder(token_id.unsqueeze(0))[0, 0]
|
||||
|
||||
|
||||
class EmbeddingManager(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedder,
|
||||
placeholder_strings=None,
|
||||
initializer_words=None,
|
||||
per_image_tokens=False,
|
||||
num_vectors_per_token=1,
|
||||
progressive_words=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.embedder = embedder
|
||||
self.concepts_library=HuggingFaceConceptsLibrary()
|
||||
|
||||
self.string_to_token_dict = {}
|
||||
self.string_to_param_dict = nn.ParameterDict()
|
||||
|
||||
self.initial_embeddings = (
|
||||
nn.ParameterDict()
|
||||
) # These should not be optimized
|
||||
|
||||
self.progressive_words = progressive_words
|
||||
self.progressive_counter = 0
|
||||
|
||||
self.max_vectors_per_token = num_vectors_per_token
|
||||
|
||||
if hasattr(
|
||||
embedder, 'tokenizer'
|
||||
): # using Stable Diffusion's CLIP encoder
|
||||
self.is_clip = True
|
||||
get_token_id_for_string = partial(
|
||||
get_clip_token_id_for_string, embedder.tokenizer
|
||||
)
|
||||
get_embedding_for_tkn_id = partial(
|
||||
get_embedding_for_clip_token_id,
|
||||
embedder.transformer.text_model.embeddings,
|
||||
)
|
||||
# per bug report #572
|
||||
#token_dim = 1280
|
||||
token_dim = 768
|
||||
else: # using LDM's BERT encoder
|
||||
self.is_clip = False
|
||||
get_token_id_for_string = partial(
|
||||
get_bert_token_id_for_string, embedder.tknz_fn
|
||||
)
|
||||
get_embedding_for_tkn_id = embedder.transformer.token_emb
|
||||
token_dim = 1280
|
||||
|
||||
if per_image_tokens:
|
||||
placeholder_strings.extend(per_img_token_list)
|
||||
|
||||
for idx, placeholder_string in enumerate(placeholder_strings):
|
||||
|
||||
token_id = get_token_id_for_string(placeholder_string)
|
||||
|
||||
if initializer_words and idx < len(initializer_words):
|
||||
init_word_token_id = get_token_id_for_string(initializer_words[idx])
|
||||
|
||||
with torch.no_grad():
|
||||
init_word_embedding = get_embedding_for_tkn_id(init_word_token_id)
|
||||
|
||||
token_params = torch.nn.Parameter(
|
||||
init_word_embedding.unsqueeze(0).repeat(
|
||||
num_vectors_per_token, 1
|
||||
),
|
||||
requires_grad=True,
|
||||
)
|
||||
self.initial_embeddings[
|
||||
placeholder_string
|
||||
] = torch.nn.Parameter(
|
||||
init_word_embedding.unsqueeze(0).repeat(
|
||||
num_vectors_per_token, 1
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
else:
|
||||
token_params = torch.nn.Parameter(
|
||||
torch.rand(
|
||||
size=(num_vectors_per_token, token_dim),
|
||||
requires_grad=True,
|
||||
)
|
||||
)
|
||||
|
||||
self.string_to_token_dict[placeholder_string] = token_id
|
||||
self.string_to_param_dict[placeholder_string] = token_params
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tokenized_text,
|
||||
embedded_text,
|
||||
):
|
||||
# torch.save(embedded_text, '/tmp/embedding-manager-uglysonic-pre-rewrite.pt')
|
||||
|
||||
b, n, device = *tokenized_text.shape, tokenized_text.device
|
||||
|
||||
for (
|
||||
placeholder_string,
|
||||
placeholder_token,
|
||||
) in self.string_to_token_dict.items():
|
||||
|
||||
placeholder_embedding = self.string_to_param_dict[
|
||||
placeholder_string
|
||||
].to(device)
|
||||
|
||||
if self.progressive_words:
|
||||
self.progressive_counter += 1
|
||||
max_step_tokens = (
|
||||
1 + self.progressive_counter // PROGRESSIVE_SCALE
|
||||
)
|
||||
else:
|
||||
max_step_tokens = self.max_vectors_per_token
|
||||
|
||||
num_vectors_for_token = min(
|
||||
placeholder_embedding.shape[0], max_step_tokens
|
||||
)
|
||||
|
||||
placeholder_rows, placeholder_cols = torch.where(
|
||||
tokenized_text == placeholder_token
|
||||
)
|
||||
|
||||
if placeholder_rows.nelement() == 0:
|
||||
continue
|
||||
|
||||
sorted_cols, sort_idx = torch.sort(
|
||||
placeholder_cols, descending=True
|
||||
)
|
||||
sorted_rows = placeholder_rows[sort_idx]
|
||||
|
||||
for idx in range(sorted_rows.shape[0]):
|
||||
row = sorted_rows[idx]
|
||||
col = sorted_cols[idx]
|
||||
|
||||
new_token_row = torch.cat(
|
||||
[
|
||||
tokenized_text[row][:col],
|
||||
torch.tensor([placeholder_token] * num_vectors_for_token, device=device),
|
||||
tokenized_text[row][col + 1 :],
|
||||
],
|
||||
axis=0,
|
||||
)[:n]
|
||||
new_embed_row = torch.cat(
|
||||
[
|
||||
embedded_text[row][:col],
|
||||
placeholder_embedding[:num_vectors_for_token],
|
||||
embedded_text[row][col + 1 :],
|
||||
],
|
||||
axis=0,
|
||||
)[:n]
|
||||
|
||||
embedded_text[row] = new_embed_row
|
||||
tokenized_text[row] = new_token_row
|
||||
|
||||
return embedded_text
|
||||
|
||||
def save(self, ckpt_path):
|
||||
torch.save(
|
||||
{
|
||||
'string_to_token': self.string_to_token_dict,
|
||||
'string_to_param': self.string_to_param_dict,
|
||||
},
|
||||
ckpt_path,
|
||||
)
|
||||
|
||||
def load(self, ckpt_paths, full=True):
|
||||
if len(ckpt_paths) == 0:
|
||||
return
|
||||
if type(ckpt_paths) != list:
|
||||
ckpt_paths = [ckpt_paths]
|
||||
ckpt_paths = self._expand_directories(ckpt_paths)
|
||||
for c in ckpt_paths:
|
||||
self._load(c,full)
|
||||
# remember that we know this term and don't try to download it again from the concepts library
|
||||
# note that if the concept name is also provided and different from the trigger term, they
|
||||
# both will be stored in this dictionary
|
||||
for term in self.string_to_param_dict.keys():
|
||||
term = term.strip('<').strip('>')
|
||||
self.concepts_loaded[term] = True
|
||||
print(f'>> Current embedding manager terms: {", ".join(self.string_to_param_dict.keys())}')
|
||||
|
||||
def _expand_directories(self, paths:list[str]):
|
||||
expanded_paths = list()
|
||||
for path in paths:
|
||||
if os.path.isfile(path):
|
||||
expanded_paths.append(path)
|
||||
elif os.path.isdir(path):
|
||||
for root, _, files in os.walk(path):
|
||||
for name in files:
|
||||
expanded_paths.append(os.path.join(root,name))
|
||||
return [x for x in expanded_paths if os.path.splitext(x)[1] in ('.pt','.bin')]
|
||||
|
||||
def _load(self, ckpt_path, full=True):
|
||||
try:
|
||||
scan_result = scan_file_path(ckpt_path)
|
||||
if scan_result.infected_files == 1:
|
||||
print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
|
||||
print('### For your safety, InvokeAI will not load this embed.')
|
||||
return
|
||||
except Exception:
|
||||
print(f"### WARNING::: Invalid or corrupt embeddings found. Ignoring: {ckpt_path}")
|
||||
return
|
||||
|
||||
embedding_info = self.parse_embedding(ckpt_path)
|
||||
if embedding_info:
|
||||
self.max_vectors_per_token = embedding_info['num_vectors_per_token']
|
||||
self.add_embedding(embedding_info['name'], embedding_info['embedding'], full)
|
||||
else:
|
||||
print(f'>> Failed to load embedding located at {ckpt_path}. Unsupported file.')
|
||||
|
||||
def add_embedding(self, token_str, embedding, full):
|
||||
if token_str in self.string_to_param_dict:
|
||||
print(f">> Embedding manager refusing to overwrite already-loaded term '{token_str}'")
|
||||
return
|
||||
if not full:
|
||||
embedding = embedding.half()
|
||||
if len(embedding.shape) == 1:
|
||||
embedding = embedding.unsqueeze(0)
|
||||
|
||||
existing_token_id = get_clip_token_id_for_string(self.embedder.tokenizer, token_str)
|
||||
if existing_token_id == self.embedder.tokenizer.unk_token_id:
|
||||
num_tokens_added = self.embedder.tokenizer.add_tokens(token_str)
|
||||
current_embeddings = self.embedder.transformer.resize_token_embeddings(None)
|
||||
current_token_count = current_embeddings.num_embeddings
|
||||
new_token_count = current_token_count + num_tokens_added
|
||||
self.embedder.transformer.resize_token_embeddings(new_token_count)
|
||||
|
||||
token_id = get_clip_token_id_for_string(self.embedder.tokenizer, token_str)
|
||||
self.string_to_token_dict[token_str] = token_id
|
||||
self.string_to_param_dict[token_str] = torch.nn.Parameter(embedding)
|
||||
|
||||
def parse_embedding(self, embedding_file: str):
|
||||
file_type = embedding_file.split('.')[-1]
|
||||
if file_type == 'pt':
|
||||
return self.parse_embedding_pt(embedding_file)
|
||||
elif file_type == 'bin':
|
||||
return self.parse_embedding_bin(embedding_file)
|
||||
else:
|
||||
print(f'>> Not a recognized embedding file: {embedding_file}')
|
||||
|
||||
def parse_embedding_pt(self, embedding_file):
|
||||
embedding_ckpt = torch.load(embedding_file, map_location='cpu')
|
||||
embedding_info = {}
|
||||
|
||||
# Check if valid embedding file
|
||||
if 'string_to_token' and 'string_to_param' in embedding_ckpt:
|
||||
|
||||
# Catch variants that do not have the expected keys or values.
|
||||
try:
|
||||
embedding_info['name'] = embedding_ckpt['name'] or os.path.basename(os.path.splitext(embedding_file)[0])
|
||||
|
||||
# Check num of embeddings and warn user only the first will be used
|
||||
embedding_info['num_of_embeddings'] = len(embedding_ckpt["string_to_token"])
|
||||
if embedding_info['num_of_embeddings'] > 1:
|
||||
print('>> More than 1 embedding found. Will use the first one')
|
||||
|
||||
embedding = list(embedding_ckpt['string_to_param'].values())[0]
|
||||
except (AttributeError,KeyError):
|
||||
return self.handle_broken_pt_variants(embedding_ckpt, embedding_file)
|
||||
|
||||
embedding_info['embedding'] = embedding
|
||||
embedding_info['num_vectors_per_token'] = embedding.size()[0]
|
||||
embedding_info['token_dim'] = embedding.size()[1]
|
||||
|
||||
try:
|
||||
embedding_info['trained_steps'] = embedding_ckpt['step']
|
||||
embedding_info['trained_model_name'] = embedding_ckpt['sd_checkpoint_name']
|
||||
embedding_info['trained_model_checksum'] = embedding_ckpt['sd_checkpoint']
|
||||
except AttributeError:
|
||||
print(">> No Training Details Found. Passing ...")
|
||||
|
||||
# .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
|
||||
# They are actually .bin files
|
||||
elif len(embedding_ckpt.keys())==1:
|
||||
print('>> Detected .bin file masquerading as .pt file')
|
||||
embedding_info = self.parse_embedding_bin(embedding_file)
|
||||
|
||||
else:
|
||||
print('>> Invalid embedding format')
|
||||
embedding_info = None
|
||||
|
||||
return embedding_info
|
||||
|
||||
def parse_embedding_bin(self, embedding_file):
|
||||
embedding_ckpt = torch.load(embedding_file, map_location='cpu')
|
||||
embedding_info = {}
|
||||
|
||||
if list(embedding_ckpt.keys()) == 0:
|
||||
print(">> Invalid concepts file")
|
||||
embedding_info = None
|
||||
else:
|
||||
for token in list(embedding_ckpt.keys()):
|
||||
embedding_info['name'] = token or os.path.basename(os.path.splitext(embedding_file)[0])
|
||||
embedding_info['embedding'] = embedding_ckpt[token]
|
||||
embedding_info['num_vectors_per_token'] = 1 # All Concepts seem to default to 1
|
||||
embedding_info['token_dim'] = embedding_info['embedding'].size()[0]
|
||||
|
||||
return embedding_info
|
||||
|
||||
def handle_broken_pt_variants(self, embedding_ckpt:dict, embedding_file:str)->dict:
|
||||
'''
|
||||
This handles the broken .pt file variants. We only know of one at present.
|
||||
'''
|
||||
embedding_info = {}
|
||||
if isinstance(list(embedding_ckpt['string_to_token'].values())[0],torch.Tensor):
|
||||
print(f'>> Variant Embedding Detected. Parsing: {embedding_file}') # example at https://github.com/invoke-ai/InvokeAI/issues/1829
|
||||
token = list(embedding_ckpt['string_to_token'].keys())[0]
|
||||
embedding_info['name'] = os.path.basename(os.path.splitext(embedding_file)[0])
|
||||
embedding_info['embedding'] = embedding_ckpt['string_to_param'].state_dict()[token]
|
||||
embedding_info['num_vectors_per_token'] = embedding_info['embedding'].shape[0]
|
||||
embedding_info['token_dim'] = embedding_info['embedding'].size()[0]
|
||||
else:
|
||||
print('>> Invalid embedding format')
|
||||
embedding_info = None
|
||||
|
||||
return embedding_info
|
||||
|
||||
def has_embedding_for_token(self, token_str):
|
||||
return token_str in self.string_to_token_dict
|
||||
|
||||
def get_embedding_norms_squared(self):
|
||||
all_params = torch.cat(
|
||||
list(self.string_to_param_dict.values()), axis=0
|
||||
) # num_placeholders x embedding_dim
|
||||
param_norm_squared = (all_params * all_params).sum(
|
||||
axis=-1
|
||||
) # num_placeholders
|
||||
|
||||
return param_norm_squared
|
||||
|
||||
def embedding_parameters(self):
|
||||
return self.string_to_param_dict.parameters()
|
||||
|
||||
def embedding_to_coarse_loss(self):
|
||||
|
||||
loss = 0.0
|
||||
num_embeddings = len(self.initial_embeddings)
|
||||
|
||||
for key in self.initial_embeddings:
|
||||
optimized = self.string_to_param_dict[key]
|
||||
coarse = self.initial_embeddings[key].clone().to(optimized.device)
|
||||
|
||||
loss = (
|
||||
loss
|
||||
+ (optimized - coarse)
|
||||
@ (optimized - coarse).T
|
||||
/ num_embeddings
|
||||
)
|
||||
|
||||
return loss
|
Loading…
Reference in New Issue
Block a user