diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py
index 6a44986f8d..924ea39c77 100644
--- a/ldm/invoke/conditioning.py
+++ b/ldm/invoke/conditioning.py
@@ -16,7 +16,8 @@ from typing import Union
 import torch
 
 from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
-    CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend
+    CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend, Fragment
+from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
 
@@ -65,27 +66,54 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
         if wants_cross_attention_control:
             original_prompt = FlattenedPrompt()
             edited_prompt = FlattenedPrompt()
+            # for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
+            original_token_count = 0
+            edited_token_count = 0
+            edit_opcodes = []
+            edit_options = []
             for fragment in flattened_prompt.children:
                 if type(fragment) is CrossAttentionControlSubstitute:
                     original_prompt.append(fragment.original)
                     edited_prompt.append(fragment.edited)
+
+                    to_replace_token_count = get_tokens_length(model, fragment.original)
+                    replacement_token_count = get_tokens_length(model, fragment.edited)
+                    edit_opcodes.append(('replace',
+                                         original_token_count, original_token_count + to_replace_token_count,
+                                         edited_token_count, edited_token_count + replacement_token_count
+                                         ))
+                    original_token_count += to_replace_token_count
+                    edited_token_count += replacement_token_count
+                    edit_options.append(fragment.options)
                 #elif type(fragment) is CrossAttentionControlAppend:
                 #    edited_prompt.append(fragment.fragment)
                 else:
                     # regular fragment
                     original_prompt.append(fragment)
                     edited_prompt.append(fragment)
+
+                    count = get_tokens_length(model, [fragment])
+                    edit_opcodes.append(('equal', original_token_count, original_token_count + count, edited_token_count, edited_token_count + count))
+                    edit_options.append(None)
+                    original_token_count += count
+                    edited_token_count += count
 
             original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt)
             edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt)
 
             conditioning = original_embeddings
             edited_conditioning = edited_embeddings
-            edit_opcodes = build_token_edit_opcodes(original_tokens, edited_tokens)
+            print('got edit_opcodes', edit_opcodes, 'options', edit_options)
     else:
         conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt)
 
+
     unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt)
-    return (unconditioning, conditioning, edited_conditioning, edit_opcodes)
+    return (
+        unconditioning, conditioning, edited_conditioning, edit_opcodes
+        #InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=edited_conditioning,
+        #                                                edit_opcodes=edit_opcodes,
+        #                                                edit_options=edit_options)
+    )
 
 
 def build_token_edit_opcodes(original_tokens, edited_tokens):
@@ -102,6 +130,10 @@ def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: Fl
     embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
     return embeddings, tokens
 
+def get_tokens_length(model, fragments: list[Fragment]):
+    fragment_texts = [x.text for x in fragments]
+    tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
+    return sum([len(x) for x in tokens])
 
 def split_weighted_subprompts(text, skip_normalize=False)->list:
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index 2883f24d1a..8917a27a40 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -557,6 +557,21 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
         else:
             return batch_z
 
+    def get_tokens(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
+        tokens = self.tokenizer(
+            fragments,
+            truncation=True,
+            max_length=self.max_length,
+            return_overflowing_tokens=False,
+            padding='do_not_pad',
+            return_tensors=None,  # just give me a list of ints
+        )['input_ids']
+        if include_start_and_end_markers:
+            return tokens
+        else:
+            return [x[1:-1] for x in tokens]
+
+
     @classmethod
     def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor:
         per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device)
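
Note on the edit_opcodes accounting in get_uc_and_c_and_ec above: every prompt
fragment yields one opcode; 'equal' spans advance the original and edited token
counters in lockstep, while 'replace' spans advance them independently by the
original and replacement token lengths. Below is a minimal standalone sketch of
that accounting; toy_tokens_length and build_opcodes are illustrative stand-ins
(a whitespace "tokenizer"), not part of the diff, which uses the real CLIP
tokenizer via get_tokens_length().

def toy_tokens_length(text: str) -> int:
    # stand-in for get_tokens_length(); real CLIP tokenization differs
    return len(text.split())

def build_opcodes(fragments):
    # fragments: list of ('equal', text) or ('replace', original_text, edited_text)
    opcodes = []
    original_count = edited_count = 0
    for frag in fragments:
        if frag[0] == 'replace':
            a = toy_tokens_length(frag[1])
            b = toy_tokens_length(frag[2])
            opcodes.append(('replace', original_count, original_count + a,
                            edited_count, edited_count + b))
            original_count += a
            edited_count += b
        else:
            n = toy_tokens_length(frag[1])
            opcodes.append(('equal', original_count, original_count + n,
                            edited_count, edited_count + n))
            original_count += n
            edited_count += n
    return opcodes

# roughly "a cat.swap(dog) sitting on a mat"
print(build_opcodes([('equal', 'a'),
                     ('replace', 'cat', 'dog'),
                     ('equal', 'sitting on a mat')]))
# [('equal', 0, 1, 0, 1), ('replace', 1, 2, 1, 2), ('equal', 2, 6, 2, 6)]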
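
And a quick illustration of what the new WeightedFrozenCLIPEmbedder.get_tokens()
returns, exercised against the HuggingFace CLIP tokenizer directly (the
checkpoint name and max_length=77 are assumptions matching the standard CLIP
setup; inside the embedder the same call goes through self.tokenizer with
max_length=self.max_length):

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')

ids = tokenizer(['sitting on a mat'],
                truncation=True,
                max_length=77,
                return_overflowing_tokens=False,
                padding='do_not_pad',
                return_tensors=None)['input_ids']
print(ids)                     # one list of ids per fragment, wrapped in begin/end marker tokens
print([x[1:-1] for x in ids])  # markers stripped, i.e. include_start_and_end_markers=False
# get_tokens_length() sums len() over these stripped lists, which is why the
# opcode spans count only the prompt's content tokens.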