re-enable tokenization logging

Damian at mba 2022-10-24 11:49:47 +02:00
parent 194c8e1c2e
commit 61a4897b71
2 changed files with 27 additions and 7 deletions


@@ -50,7 +50,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0]
parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
print("parsed prompt to", parsed_prompt)
print(f">> Parsed prompt to {parsed_prompt}")
conditioning = None
cac_args:CrossAttentionControl.Arguments = None
@@ -59,7 +59,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
blend: Blend = parsed_prompt
embeddings_to_blend = None
for flattened_prompt in blend.prompts:
this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt)
this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
(embeddings_to_blend, this_embedding))
conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
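For blends, each sub-prompt's embedding is concatenated into embeddings_to_blend and then weighted by WeightedFrozenCLIPEmbedder.apply_embedding_weights. As an illustration only (the real weighting may differ), a normalized weighted sum over a stack of per-prompt embeddings of shape (n_prompts, seq_len, dim) would look roughly like this:

import torch

def blend_embeddings_sketch(embeddings_to_blend: torch.Tensor, weights: list[float],
                            normalize: bool = True) -> torch.Tensor:
    # embeddings_to_blend: (n_prompts, seq_len, dim), one CLIP embedding per sub-prompt
    # weights: one blend weight per sub-prompt
    # NOTE: hypothetical sketch, not the actual apply_embedding_weights implementation
    w = torch.tensor(weights, dtype=embeddings_to_blend.dtype)
    if normalize:
        w = w / w.sum()
    # weighted sum over the prompt axis -> (seq_len, dim), a single conditioning tensor
    return (embeddings_to_blend * w.view(-1, 1, 1)).sum(dim=0)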
@@ -103,14 +103,14 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
edit_options.append(None)
original_token_count += count
edited_token_count += count
original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt)
original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt, log_tokens=log_tokens)
# naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
# subsequent tokens when there is >1 edit and earlier edits change the total token count.
# eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
# 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
# token 'smiling' in the inactive 'cat' edit.
# todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt)
edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt, log_tokens=log_tokens)
conditioning = original_embeddings
edited_conditioning = edited_embeddings
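The caveat in the comment above is easy to see with a toy, word-level "tokenization" (hypothetical, purely for illustration): an inactive earlier edit still shifts the absolute positions of later tokens, so a single edited_embeddings built from the fully edited prompt points at the wrong slot for later swaps.

# hypothetical word-level tokens, for illustration only
original_words = "a cat eating a hotdog".split()
edited_words   = "a smiling dog eating a pizza".split()

# 'hotdog' sits at index 4 in the original, but its replacement 'pizza' sits at
# index 5 in the edited sequence, because the extra token 'smiling' shifted
# everything after it, even when the cat edit itself is not yet active.
print(original_words.index("hotdog"))  # 4
print(edited_words.index("pizza"))     # 5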
@@ -121,10 +121,10 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
edit_options = edit_options
)
else:
conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt)
conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt)
unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt, log_tokens=log_tokens)
return (
unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
cross_attention_control_args=cac_args
@@ -138,12 +138,27 @@ def build_token_edit_opcodes(original_tokens, edited_tokens):
return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
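build_token_edit_opcodes relies on difflib.SequenceMatcher, whose get_opcodes() yields (tag, i1, i2, j1, j2) tuples mapping spans of the original token list onto the edited one. A quick standalone illustration with made-up token ids:

from difflib import SequenceMatcher

# made-up token id lists, for illustration only
original_ids = [49406, 320, 2368, 49407]   # e.g. <start> a cat <end>
edited_ids   = [49406, 320, 1929, 49407]   # e.g. <start> a dog <end>

for tag, i1, i2, j1, j2 in SequenceMatcher(None, original_ids, edited_ids).get_opcodes():
    print(tag, original_ids[i1:i2], "->", edited_ids[j1:j2])
# prints an 'equal' span, a 'replace' span for the single swapped token, then another 'equal' span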
def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt):
def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False):
if type(flattened_prompt) is not FlattenedPrompt:
raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
fragments = [x.text for x in flattened_prompt.children]
weights = [x.weight for x in flattened_prompt.children]
embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
if not flattened_prompt.is_empty and log_tokens:
start_token = model.cond_stage_model.tokenizer.bos_token_id
end_token = model.cond_stage_model.tokenizer.eos_token_id
tokens_list = tokens[0].tolist()
if tokens_list[0] == start_token:
tokens_list[0] = '<start>'
try:
first_end_token_index = tokens_list.index(end_token)
tokens_list[first_end_token_index] = '<end>'
tokens_list = tokens_list[:first_end_token_index+1]
except ValueError:
pass
print(f">> Prompt fragments {fragments}, tokenized to \n{tokens_list}")
return embeddings, tokens
def get_tokens_length(model, fragments: list[Fragment]):
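The newly added logging above marks the first <start>/<end> ids and drops trailing padding before printing. Written as a standalone helper with made-up token ids (the real ids come from model.cond_stage_model.tokenizer), the same transform behaves like this:

def format_tokens_for_log(token_ids: list, start_token: int, end_token: int) -> list:
    # mirrors the logging transform in build_embeddings_and_tokens_for_flattened_prompt:
    # label <start>, label the first <end>, and truncate the padding that follows it
    tokens_list = list(token_ids)
    if tokens_list and tokens_list[0] == start_token:
        tokens_list[0] = '<start>'
    try:
        first_end_token_index = tokens_list.index(end_token)
        tokens_list[first_end_token_index] = '<end>'
        tokens_list = tokens_list[:first_end_token_index + 1]
    except ValueError:
        pass  # no end token present; print the list unmodified
    return tokens_list

# e.g. with hypothetical ids 49406/49407 for <start>/<end>:
# format_tokens_for_log([49406, 320, 2368, 49407, 49407], 49406, 49407)
# -> ['<start>', 320, 2368, '<end>']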


@@ -51,6 +51,11 @@ class FlattenedPrompt():
raise PromptParser.ParsingException(
f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
@property
def is_empty(self):
return len(self.children) == 0 or \
(len(self.children) == 1 and len(self.children[0].text) == 0)
def __repr__(self):
return f"FlattenedPrompt:{self.children}"
def __eq__(self, other):
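The new is_empty property treats a prompt with no fragments, or with a single zero-length fragment, as empty; it is what the `if not flattened_prompt.is_empty and log_tokens` guard in the first file checks, so a blank prompt produces no token log. A simplified stand-in, for illustration only:

# simplified stand-ins for Fragment/FlattenedPrompt, for illustration only
class _Fragment:
    def __init__(self, text: str, weight: float = 1.0):
        self.text = text
        self.weight = weight

def _is_empty(children: list) -> bool:
    # mirrors FlattenedPrompt.is_empty: no fragments, or one fragment with empty text
    return len(children) == 0 or \
        (len(children) == 1 and len(children[0].text) == 0)

assert _is_empty([])                      # no fragments at all
assert _is_empty([_Fragment('')])         # a lone empty fragment, e.g. a blank negative prompt
assert not _is_empty([_Fragment('cat')])  # any real text counts as non-empty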