Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30.
be less verbose when assembling prompt
commit 61357e4e6e (parent c6ae9f1176)
@@ -35,9 +35,10 @@ def get_uc_and_c(prompt_string_uncleaned, model, log_tokens=False, skip_normaliz
     pp = PromptParser()
 
-    def build_conditioning_list(prompt_string:str):
+    def build_conditioning_list(prompt_string:str, verbose:bool = False):
         parsed_conjunction: Conjunction = pp.parse(prompt_string)
-        print(f"parsed '{prompt_string}' to {parsed_conjunction}")
+        if verbose:
+            print(f"parsed '{prompt_string}' to {parsed_conjunction}")
         assert (type(parsed_conjunction) is Conjunction)
 
         conditioning_list = []
@@ -46,7 +47,7 @@ def get_uc_and_c(prompt_string_uncleaned, model, log_tokens=False, skip_normaliz
                 raise f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead"
             fragments = [x[0] for x in flattened_prompt.children]
             attention_weights = [x[1] for x in flattened_prompt.children]
-            print(fragments, attention_weights)
+            #print(fragments, attention_weights)
             return model.get_learned_conditioning([fragments], attention_weights=[attention_weights])
 
         for part,weight in zip(parsed_conjunction.prompts, parsed_conjunction.weights):
@@ -65,14 +66,14 @@ def get_uc_and_c(prompt_string_uncleaned, model, log_tokens=False, skip_normaliz
 
         return conditioning_list
 
-    positive_conditioning_list = build_conditioning_list(prompt_string_cleaned)
-    negative_conditioning_list = build_conditioning_list(unconditioned_words)
+    positive_conditioning_list = build_conditioning_list(prompt_string_cleaned, verbose=True)
+    negative_conditioning_list = build_conditioning_list(unconditioned_words, verbose=(len(unconditioned_words)>0) )
 
     if len(negative_conditioning_list) == 0:
         negative_conditioning = model.get_learned_conditioning([['']], attention_weights=[[1]])
     else:
         if len(negative_conditioning_list)>1:
-            print("cannot do conjunctions on unconditioning for now")
+            print("cannot do conjunctions on unconditioning for now, everything except the first prompt will be ignored")
         negative_conditioning = negative_conditioning_list[0][0]
 
     #positive_conditioning_list.append((get_blend_prompts_and_weights(prompt), this_weight))
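Editor's note: this hunk threads an opt-in verbose flag through build_conditioning_list, so the positive prompt always logs its parse result while the negative prompt logs only when the user actually supplied negative words; an empty unconditioning string stays silent. A minimal self-contained sketch of the pattern, with the real PromptParser replaced by a trivial stand-in for illustration:

def build_conditioning_list(prompt_string: str, verbose: bool = False):
    parsed = prompt_string.split()   # stand-in for pp.parse(prompt_string)
    if verbose:
        # debug output now appears only when the caller opts in
        print(f"parsed '{prompt_string}' to {parsed}")
    return [(fragment, 1.0) for fragment in parsed]

positive = build_conditioning_list("a cat on a mat", verbose=True)   # logs
negative = build_conditioning_list("", verbose=(len("") > 0))        # silent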
@@ -547,13 +547,13 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
 
             lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0)
 
-            print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
+            #print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
 
             # append to batch
             batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat((batch_z, lerped_embeddings.unsqueeze(0)), dim=1)
 
         # should have shape (B, 77, 768)
-        print(f"assembled all tokens into tensor of shape {batch_z.shape}")
+        #print(f"assembled all tokens into tensor of shape {batch_z.shape}")
 
         return batch_z
 
@@ -589,18 +589,19 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
         )['input_ids']
         all_tokens = []
         per_token_weights = []
-        print("all fragments:", fragments, weights)
+        #print("all fragments:", fragments, weights)
         for index, fragment in enumerate(item_encodings):
             weight = weights[index]
-            print("processing fragment", fragment, weight)
+            #print("processing fragment", fragment, weight)
             fragment_tokens = item_encodings[index]
-            print("fragment", fragment, "processed to", fragment_tokens)
+            #print("fragment", fragment, "processed to", fragment_tokens)
             # trim bos and eos markers before appending
             all_tokens.extend(fragment_tokens[1:-1])
             per_token_weights.extend([weight] * (len(fragment_tokens) - 2))
 
-        if len(all_tokens) > self.max_length - 2:
-            print("prompt is too long and has been truncated")
+        if (len(all_tokens) + 2) > self.max_length:
+            excess_token_count = (len(all_tokens) + 2) - self.max_length
+            print(f"prompt is {excess_token_count} token(s) too long and has been truncated")
             all_tokens = all_tokens[:self.max_length - 2]
 
         # pad out to a 77-entry array: [eos_token, <prompt tokens>, eos_token, ..., eos_token]
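Editor's note: the rewritten length check accounts for the bos/eos markers that are added back around the prompt tokens, so with CLIP's 77-token context at most 75 prompt tokens fit, and the message now reports exactly how many tokens were dropped. A worked sketch of the arithmetic, assuming max_length = 77 and dummy token values:

max_length = 77                                 # CLIP context size
all_tokens = list(range(80))                    # 80 prompt tokens, pre-bos/eos

if (len(all_tokens) + 2) > max_length:          # +2 for the bos/eos slots
    excess_token_count = (len(all_tokens) + 2) - max_length   # 82 - 77 = 5
    print(f"prompt is {excess_token_count} token(s) too long and has been truncated")
    all_tokens = all_tokens[:max_length - 2]    # keep the first 75 tokens

assert len(all_tokens) == 75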
@@ -613,7 +614,7 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
 
         all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device)
         per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device)
-        print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}")
+        #print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}")
         return all_tokens_tensor, per_token_weights_tensor
 
     def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
@@ -641,7 +642,7 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
         weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
 
         weighted_z_delta_from_empty = (weighted_z-empty_z)
-        print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )
+        #print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )
 
         #print("using empty-delta method, first 5 rows:")
         #print(weighted_z[:5])
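Editor's note: for context, the surrounding code applies per-token weights as a lerp away from the empty-prompt embedding, weighted_z = empty_z + (z - empty_z) * w, so weight 1.0 leaves a token's embedding unchanged and weight 0.0 collapses it to the empty-prompt embedding; that delta is what the now-silenced print was measuring. A minimal tensor sketch with illustrative shapes (the real tensors are (B, 77, 768) CLIP embeddings):

import torch

z = torch.randn(1, 4, 8)            # embeddings of the actual prompt tokens
empty_z = torch.randn(1, 4, 8)      # embeddings of the empty prompt ""
weights = torch.tensor([1.0, 1.5, 0.5, 0.0])
batch_weights_expanded = weights.view(1, 4, 1).expand(z.shape)

z_delta_from_empty = z - empty_z
weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)

# weight 1.0 is the identity; weight 0.0 collapses to the empty embedding
assert torch.allclose(weighted_z[0, 0], z[0, 0])
assert torch.allclose(weighted_z[0, 3], empty_z[0, 3])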