be less verbose when assembling prompt

Damian at mba 2022-10-16 01:53:44 +02:00
parent c6ae9f1176
commit 61357e4e6e
2 changed files with 17 additions and 15 deletions

View File

@@ -35,9 +35,10 @@ def get_uc_and_c(prompt_string_uncleaned, model, log_tokens=False, skip_normaliz
     pp = PromptParser()
 
-    def build_conditioning_list(prompt_string:str):
+    def build_conditioning_list(prompt_string:str, verbose:bool = False):
         parsed_conjunction: Conjunction = pp.parse(prompt_string)
-        print(f"parsed '{prompt_string}' to {parsed_conjunction}")
+        if verbose:
+            print(f"parsed '{prompt_string}' to {parsed_conjunction}")
         assert (type(parsed_conjunction) is Conjunction)
 
         conditioning_list = []
@@ -46,7 +47,7 @@ def get_uc_and_c(prompt_string_uncleaned, model, log_tokens=False, skip_normaliz
                 raise f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead"
             fragments = [x[0] for x in flattened_prompt.children]
             attention_weights = [x[1] for x in flattened_prompt.children]
-            print(fragments, attention_weights)
+            #print(fragments, attention_weights)
             return model.get_learned_conditioning([fragments], attention_weights=[attention_weights])
 
         for part,weight in zip(parsed_conjunction.prompts, parsed_conjunction.weights):
@@ -65,14 +66,14 @@ def get_uc_and_c(prompt_string_uncleaned, model, log_tokens=False, skip_normaliz
         return conditioning_list
 
-    positive_conditioning_list = build_conditioning_list(prompt_string_cleaned)
-    negative_conditioning_list = build_conditioning_list(unconditioned_words)
+    positive_conditioning_list = build_conditioning_list(prompt_string_cleaned, verbose=True)
+    negative_conditioning_list = build_conditioning_list(unconditioned_words, verbose=(len(unconditioned_words)>0))
 
     if len(negative_conditioning_list) == 0:
         negative_conditioning = model.get_learned_conditioning([['']], attention_weights=[[1]])
     else:
         if len(negative_conditioning_list)>1:
-            print("cannot do conjunctions on unconditioning for now")
+            print("cannot do conjunctions on unconditioning for now, everything except the first prompt will be ignored")
         negative_conditioning = negative_conditioning_list[0][0]
 
     #positive_conditioning_list.append((get_blend_prompts_and_weights(prompt), this_weight))
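
For reference, the first file ends up with a simple gating pattern: parse chatter is behind a verbose flag, and the negative-prompt call only enables it when a negative prompt actually exists. A minimal standalone sketch of that gating (parse_prompt is a hypothetical stand-in for PromptParser.parse, not the project's API):

# Minimal sketch of the verbose-gating pattern in this commit.
# parse_prompt is a hypothetical stand-in for PromptParser.parse.
def parse_prompt(prompt_string: str) -> list:
    # toy "parse": each word becomes a (fragment, weight) pair
    return [(word, 1.0) for word in prompt_string.split()]

def build_conditioning_list(prompt_string: str, verbose: bool = False) -> list:
    parsed = parse_prompt(prompt_string)
    if verbose:
        print(f"parsed '{prompt_string}' to {parsed}")
    return parsed

# positive prompts always log; negative prompts log only when non-empty,
# mirroring verbose=(len(unconditioned_words)>0) above
positive = build_conditioning_list("a cat on a mat", verbose=True)
unconditioned_words = ""
negative = build_conditioning_list(unconditioned_words, verbose=(len(unconditioned_words) > 0))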

View File

@@ -547,13 +547,13 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
             lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0)
 
-            print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
+            #print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
 
             # append to batch
             batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat((batch_z, lerped_embeddings.unsqueeze(0)), dim=1)
 
         # should have shape (B, 77, 768)
-        print(f"assembled all tokens into tensor of shape {batch_z.shape}")
+        #print(f"assembled all tokens into tensor of shape {batch_z.shape}")
         return batch_z
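
The two prints silenced above traced this method's shape flow: each prompt's fragment embeddings are combined into one (77, 768) tensor, then stacked into a batch. A rough sketch of that flow, assuming apply_embedding_weights with normalize=True takes a weight-normalized average (toy data, not the project's implementation; this sketch stacks on dim=0 to reach the (B, 77, 768) shape the comment describes):

import torch

# Rough sketch of the shape flow the silenced prints used to trace.
# Assumption: normalize=True scales the per-embedding weights to sum
# to 1 before taking a weighted average of the blend members.
def lerp_embeddings(embeddings: torch.Tensor, weights: list) -> torch.Tensor:
    w = torch.tensor(weights, dtype=embeddings.dtype)
    w = w / w.sum()                                     # normalize weights
    return (embeddings * w.view(-1, 1, 1)).sum(dim=0)  # (N, 77, 768) -> (77, 768)

batch_z = None
for _ in range(2):                                      # two prompts in the batch
    embeddings = torch.randn(3, 77, 768)                # 3 blend members (toy data)
    lerped = lerp_embeddings(embeddings, [1.0, 0.5, 0.5])
    batch_z = lerped.unsqueeze(0) if batch_z is None \
        else torch.cat((batch_z, lerped.unsqueeze(0)), dim=0)

print(batch_z.shape)                                    # torch.Size([2, 77, 768])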
@@ -589,18 +589,19 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
         )['input_ids']
         all_tokens = []
         per_token_weights = []
-        print("all fragments:", fragments, weights)
+        #print("all fragments:", fragments, weights)
         for index, fragment in enumerate(item_encodings):
             weight = weights[index]
-            print("processing fragment", fragment, weight)
+            #print("processing fragment", fragment, weight)
             fragment_tokens = item_encodings[index]
-            print("fragment", fragment, "processed to", fragment_tokens)
+            #print("fragment", fragment, "processed to", fragment_tokens)
             # trim bos and eos markers before appending
             all_tokens.extend(fragment_tokens[1:-1])
             per_token_weights.extend([weight] * (len(fragment_tokens) - 2))
 
-        if len(all_tokens) > self.max_length - 2:
-            print("prompt is too long and has been truncated")
+        if (len(all_tokens) + 2) > self.max_length:
+            excess_token_count = (len(all_tokens) + 2) - self.max_length
+            print(f"prompt is {excess_token_count} token(s) too long and has been truncated")
             all_tokens = all_tokens[:self.max_length - 2]
 
         # pad out to a 77-entry array: [eos_token, <prompt tokens>, eos_token, ..., eos_token]
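
The reworked check makes the truncation arithmetic explicit: a prompt fits only if its content tokens plus the two BOS/EOS markers stay within max_length (77 for CLIP). A small worked example of the new arithmetic:

# Worked example of the truncation arithmetic above, with max_length = 77.
max_length = 77
all_tokens = list(range(80))            # 80 content tokens, markers not yet added

if (len(all_tokens) + 2) > max_length:  # 82 > 77, so truncate
    excess_token_count = (len(all_tokens) + 2) - max_length  # 5
    print(f"prompt is {excess_token_count} token(s) too long and has been truncated")
    all_tokens = all_tokens[:max_length - 2]  # keep 75, leaving room for the markers

assert len(all_tokens) + 2 == max_length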
@@ -613,7 +614,7 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
         all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device)
         per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device)
-        print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}")
+        #print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}")
         return all_tokens_tensor, per_token_weights_tensor
 
     def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
@@ -641,7 +642,7 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
         weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
         weighted_z_delta_from_empty = (weighted_z-empty_z)
-        print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )
+        #print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )
 
         #print("using empty-delta method, first 5 rows:")
         #print(weighted_z[:5])
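
The silenced diagnostic here traced the weight-delta-from-empty method visible just above it: token weights scale each token's offset from the empty-prompt embedding, so a weight of 1.0 reproduces the unweighted embedding exactly and a weight of 0.0 collapses the token to the empty prompt. A minimal numeric sketch of that formula (toy shapes, zeros standing in for the real '' embedding):

import torch

# Minimal numeric sketch of: weighted_z = empty_z + (z - empty_z) * weights
empty_z = torch.zeros(1, 4, 8)                 # toy stand-in for the '' embedding
z = torch.randn(1, 4, 8)                       # toy prompt embedding
per_token_weights = torch.tensor([1.0, 1.1, 0.5, 0.0])

batch_weights_expanded = per_token_weights.view(1, -1, 1).expand(z.shape)
z_delta_from_empty = z - empty_z
weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)

# weight 1.0 leaves a token untouched; weight 0.0 collapses it to empty_z
assert torch.allclose(weighted_z[0, 0], z[0, 0])
assert torch.allclose(weighted_z[0, 3], empty_z[0, 3])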