diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py
index 65459b5c5f..7c095de7b7 100644
--- a/ldm/invoke/conditioning.py
+++ b/ldm/invoke/conditioning.py
@@ -50,7 +50,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
     parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0]
 
     parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
-    print("parsed prompt to", parsed_prompt)
+    print(f">> Parsed prompt to {parsed_prompt}")
 
     conditioning = None
     cac_args:CrossAttentionControl.Arguments = None
@@ -59,7 +59,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
         blend: Blend = parsed_prompt
         embeddings_to_blend = None
         for flattened_prompt in blend.prompts:
-            this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt)
+            this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
             embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
                 (embeddings_to_blend, this_embedding))
         conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
@@ -103,14 +103,14 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
                 edit_options.append(None)
                 original_token_count += count
                 edited_token_count += count
-        original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt)
+        original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt, log_tokens=log_tokens)
         # naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
         # subsequent tokens when there is >1 edit and earlier edits change the total token count.
         # eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
         # 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
         # token 'smiling' in the inactive 'cat' edit.
         # todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
-        edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt)
+        edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt, log_tokens=log_tokens)
 
         conditioning = original_embeddings
         edited_conditioning = edited_embeddings
@@ -121,10 +121,10 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
             edit_options = edit_options
         )
     else:
-        conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt)
+        conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
 
-    unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt)
+    unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt, log_tokens=log_tokens)
 
     return (
         unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
             cross_attention_control_args=cac_args
@@ -138,12 +138,27 @@ def build_token_edit_opcodes(original_tokens, edited_tokens):
 
     return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
 
-def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt):
+def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False):
     if type(flattened_prompt) is not FlattenedPrompt:
         raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
     fragments = [x.text for x in flattened_prompt.children]
     weights = [x.weight for x in flattened_prompt.children]
     embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
+    if not flattened_prompt.is_empty and log_tokens:
+        start_token = model.cond_stage_model.tokenizer.bos_token_id
+        end_token = model.cond_stage_model.tokenizer.eos_token_id
+        tokens_list = tokens[0].tolist()
+        if tokens_list[0] == start_token:
+            tokens_list[0] = ''
+        try:
+            first_end_token_index = tokens_list.index(end_token)
+            tokens_list[first_end_token_index] = ''
+            tokens_list = tokens_list[:first_end_token_index+1]
+        except ValueError:
+            pass
+
+        print(f">> Prompt fragments {fragments}, tokenized to \n{tokens_list}")
+
     return embeddings, tokens
 
 def get_tokens_length(model, fragments: list[Fragment]):
diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py
index 3a96d664f0..6709f48066 100644
--- a/ldm/invoke/prompt_parser.py
+++ b/ldm/invoke/prompt_parser.py
@@ -51,6 +51,11 @@ class FlattenedPrompt():
                 raise PromptParser.ParsingException(
                     f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
 
+    @property
+    def is_empty(self):
+        return len(self.children) == 0 or \
+            (len(self.children) == 1 and len(self.children[0].text) == 0)
+
     def __repr__(self):
        return f"FlattenedPrompt:{self.children}"
     def __eq__(self, other):
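For reference, here is a minimal, self-contained sketch of the token-log trimming that the build_embeddings_and_tokens_for_flattened_prompt change introduces. The helper name and the hard-coded BOS/EOS ids are assumptions (the ids match the standard CLIP tokenizer); the real code reads them from model.cond_stage_model.tokenizer and operates on the tensor returned by get_learned_conditioning.

BOS_TOKEN_ID = 49406  # assumption: CLIP tokenizer bos_token_id
EOS_TOKEN_ID = 49407  # assumption: CLIP tokenizer eos_token_id (also used as padding)

def trim_tokens_for_logging(token_ids: list) -> list:
    """Hypothetical helper mirroring the logging branch in the diff above."""
    tokens_list = list(token_ids)
    # blank out the leading BOS marker so it doesn't clutter the log
    if tokens_list and tokens_list[0] == BOS_TOKEN_ID:
        tokens_list[0] = ''
    # blank out the first EOS and drop the padding that follows it
    try:
        first_end = tokens_list.index(EOS_TOKEN_ID)
        tokens_list[first_end] = ''
        tokens_list = tokens_list[:first_end + 1]
    except ValueError:
        pass  # no EOS present; keep the whole sequence
    return tokens_list

# A 77-slot CLIP sequence with two content tokens, padded out with EOS:
padded = [49406, 320, 2368] + [49407] * 74
print(trim_tokens_for_logging(padded))  # prints ['', 320, 2368, '']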
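A second sketch illustrates the data that build_token_edit_opcodes hands to the CrossAttentionControl machinery: it is a thin wrapper over difflib.SequenceMatcher, so the opcodes describe which token spans are kept and which are swapped between the original and edited prompts. The token ids below are hypothetical stand-ins for real CLIP ids.

from difflib import SequenceMatcher

original_tokens = [49406, 320, 2368, 49407]        # hypothetical ids for "<bos> a cat <eos>"
edited_tokens   = [49406, 320, 8159, 1929, 49407]  # hypothetical ids for "<bos> a smiling dog <eos>"

opcodes = SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
print(opcodes)
# [('equal', 0, 2, 0, 2), ('replace', 2, 3, 2, 4), ('equal', 3, 4, 4, 5)]

This also makes the caveat in the naïve-single-edited_embeddings comment concrete: a 'replace' whose edited span is longer than the original shifts the absolute positions of every token that follows it.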