"""
This module handles the generation of the conditioning tensors for a text prompt.

Useful function exports:

get_uc_and_c_and_ec()
    Get the unconditioned and conditioned latent, plus extra conditioning info (which carries
    the edited conditioning when cross-attention control is in use).
"""
import re
from typing import Union

import torch

from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
    CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment
from ..models.diffusion import cross_attention_control
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder

# Text inside square brackets, e.g. "a cat [blurry]", is moved out of the prompt into the negative prompt.
_UNCONDITIONAL_REGEX = re.compile(r'\[(.*?)\]')
# Collapses the runs of spaces left behind once the bracketed spans have been cut out.
_MULTIPLE_SPACES_REGEX = re.compile(' +')
# Suffix the tokenizer puts on the last token of each word (as CLIP's BPE tokenizer does).
_END_OF_WORD_MARKER = '</w>'


def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
    """
    Parse ``prompt_string`` and build its conditioning.

    Square-bracketed text in the prompt is used as the negative (unconditioned) prompt.

    :return: tuple (unconditioning, conditioning, extra_conditioning_info), exactly as
        produced by ``_get_conditioning_for_prompt``.
    """
    prompt, negative_prompt = get_prompt_structure(prompt_string,
                                                   skip_normalize_legacy_blend=skip_normalize_legacy_blend)
    conditioning = _get_conditioning_for_prompt(prompt, negative_prompt, model, log_tokens)

    return conditioning


def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = False) \
        -> tuple[Union[FlattenedPrompt, Blend], FlattenedPrompt]:
    """
    Parse the passed-in prompt string and return the tuple (positive_prompt, negative_prompt).
    """
    prompt, negative_prompt = _parse_prompt_string(prompt_string,
                                                   skip_normalize_legacy_blend=skip_normalize_legacy_blend)
    return prompt, negative_prompt


def _text_for_tokenizing(fragment) -> str:
    """Return the text that ``get_tokens_for_prompt`` tokenizes for a single prompt fragment."""
    if type(fragment) is Fragment:
        return fragment.text
    if type(fragment) is CrossAttentionControlSubstitute:
        # only the original (pre-swap) side of a .swap() contributes to the token list
        return " ".join(f.text for f in fragment.original)
    return str(fragment)


def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt) -> list[str]:
    """
    Tokenize the text of ``parsed_prompt`` with ``model.cond_stage_model.tokenizer``.

    Plain fragments contribute their text; ``.swap()`` fragments contribute the text of their
    original fragments; any other child is converted with ``str()``.
    """
    text = " ".join(_text_for_tokenizing(x) for x in parsed_prompt.children)
    tokens = model.cond_stage_model.tokenizer.tokenize(text)
    return tokens


def _parse_prompt_string(prompt_string_uncleaned, skip_normalize_legacy_blend=False) \
        -> tuple[Union[FlattenedPrompt, Blend], FlattenedPrompt]:
    """
    Separate the negative text from the prompt and parse both parts.

    Every square-bracketed span is removed from the prompt; the span contents, joined with
    spaces, become the negative prompt. The remaining text is parsed as a legacy blend if
    ``PromptParser.parse_legacy_blend`` accepts it, otherwise as a conjunction of which only the
    first prompt is used.

    :return: tuple (parsed_prompt, parsed_negative_prompt)
    """
    # Extract Unconditioned Words From Prompt
    unconditionals = _UNCONDITIONAL_REGEX.findall(prompt_string_uncleaned)
    if unconditionals:
        unconditioned_words = ' '.join(unconditionals)
        # Remove Unconditioned Words From Prompt, then collapse the doubled spaces this leaves
        clean_prompt = _UNCONDITIONAL_REGEX.sub(' ', prompt_string_uncleaned)
        prompt_string_cleaned = _MULTIPLE_SPACES_REGEX.sub(' ', clean_prompt)
    else:
        unconditioned_words = ''
        prompt_string_cleaned = prompt_string_uncleaned

    pp = PromptParser()

    parsed_prompt: Union[FlattenedPrompt, Blend]
    legacy_blend: Union[Blend, None] = pp.parse_legacy_blend(prompt_string_cleaned, skip_normalize_legacy_blend)
    if legacy_blend is not None:
        parsed_prompt = legacy_blend
    else:
        # we don't support conjunctions for now
        parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0]

    parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
    return parsed_prompt, parsed_negative_prompt


def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], parsed_negative_prompt: FlattenedPrompt,
                                 model, log_tokens=False) \
        -> tuple[torch.Tensor, torch.Tensor, InvokeAIDiffuserComponent.ExtraConditioningInfo]:
    """
    Process prompt structure and tokens, and return (unconditioning, conditioning, extra_conditioning_info).

    When the model yields dict ("hybrid") conditioning, the conditioning values are concatenated
    with the unconditioning (see ``_flatten_hybrid_conditioning``), and any cross-attention
    control is dropped with a warning.

    :raises ValueError: if ``parsed_prompt`` is neither a Blend nor a FlattenedPrompt.
    """
    if log_tokens:
        print(f">> Parsed prompt to {parsed_prompt}")
        print(f">> Parsed negative prompt to {parsed_negative_prompt}")

    cac_args: Union[cross_attention_control.Arguments, None] = None

    if type(parsed_prompt) is Blend:
        conditioning = _get_conditioning_for_blend(model, parsed_prompt, log_tokens)
    elif type(parsed_prompt) is FlattenedPrompt:
        if parsed_prompt.wants_cross_attention_control:
            conditioning, cac_args = _get_conditioning_for_cross_attention_control(model, parsed_prompt, log_tokens)
        else:
            conditioning, _ = _get_embeddings_and_tokens_for_prompt(model,
                                                                    parsed_prompt,
                                                                    log_tokens=log_tokens,
                                                                    log_display_label="(prompt)")
    else:
        raise ValueError(f"parsed_prompt is '{type(parsed_prompt)}' which is not a supported prompt type")

    unconditioning, _ = _get_embeddings_and_tokens_for_prompt(model,
                                                              parsed_negative_prompt,
                                                              log_tokens=log_tokens,
                                                              log_display_label="(unconditioning)")
    if isinstance(conditioning, dict):
        # hybrid conditioning is in play
        unconditioning, conditioning = _flatten_hybrid_conditioning(unconditioning, conditioning)
        if cac_args is not None:
            print(
                ">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
            cac_args = None

    # Index 0 holds the start token, so for N prompt tokens the end token sits at N + 1.
    # NOTE(review): a Blend is not tokenized here and keeps index 1 (i.e. start+end only) — confirm intended.
    eos_token_index = 1
    if type(parsed_prompt) is not Blend:
        tokens = get_tokens_for_prompt(model, parsed_prompt)
        eos_token_index = len(tokens) + 1
    return (
        unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
            tokens_count_including_eos_bos=eos_token_index + 1,
            cross_attention_control_args=cac_args
        )
    )


def _get_conditioning_for_cross_attention_control(model, prompt: FlattenedPrompt, log_tokens: bool = True):
    """
    Build the conditioning and cross-attention-control arguments for a prompt containing ``.swap()`` edits.

    The prompt is expanded into two FlattenedPrompts: an *original* prompt (every swap replaced by
    its original fragments) and an *edited* prompt (every swap replaced by its edited fragments).
    At the same time a list of edit opcodes is built that maps token spans of the original prompt
    onto token spans of the edited prompt.

    :return: tuple (conditioning, cac_args). ``conditioning`` is the embedding of the original
        prompt; ``cac_args`` is a ``cross_attention_control.Arguments`` holding the edited
        embedding, the opcodes, and one options entry per opcode (``None`` for unchanged spans).
    """
    original_prompt = FlattenedPrompt()
    edited_prompt = FlattenedPrompt()
    # Each opcode is (name, a0, a1, b0, b1) with name 'equal' or 'replace': tokens [a0, a1) of the
    # original prompt correspond to tokens [b0, b1) of the edited prompt (the same layout as
    # difflib.SequenceMatcher.get_opcodes()). edit_options[i] belongs to edit_opcodes[i].
    edit_opcodes = []
    edit_options = []
    original_token_count = 0
    edited_token_count = 0

    def add_span(name, original_length, edited_length, options):
        # Record the next pair of spans and advance both running token counts.
        nonlocal original_token_count, edited_token_count
        edit_opcodes.append((name,
                             original_token_count, original_token_count + original_length,
                             edited_token_count, edited_token_count + edited_length))
        edit_options.append(options)
        original_token_count += original_length
        edited_token_count += edited_length

    # beginning-of-sequence token
    add_span('equal', 1, 1, None)

    for fragment in prompt.children:
        if type(fragment) is CrossAttentionControlSubstitute:
            original_prompt.append(fragment.original)
            edited_prompt.append(fragment.edited)
            to_replace_token_count = _get_tokens_length(model, fragment.original)
            replacement_token_count = _get_tokens_length(model, fragment.edited)
            add_span('replace', to_replace_token_count, replacement_token_count, fragment.options)
        else:
            # regular fragment: identical in both prompts
            original_prompt.append(fragment)
            edited_prompt.append(fragment)
            count = _get_tokens_length(model, [fragment])
            add_span('equal', count, count, None)

    # end-of-sequence token
    add_span('equal', 1, 1, None)

    original_embeddings, _ = _get_embeddings_and_tokens_for_prompt(model,
                                                                   original_prompt,
                                                                   log_tokens=log_tokens,
                                                                   log_display_label="(.swap originals)")
    # naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
    # subsequent tokens when there is >1 edit and earlier edits change the total token count.
    # eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
    # 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
    # token 'smiling' in the inactive 'cat' edit.
    # TODO: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the
    #  CrossAttentionControl functions
    edited_embeddings, _ = _get_embeddings_and_tokens_for_prompt(model,
                                                                 edited_prompt,
                                                                 log_tokens=log_tokens,
                                                                 log_display_label="(.swap replacements)")

    cac_args = cross_attention_control.Arguments(
        edited_conditioning=edited_embeddings,
        edit_opcodes=edit_opcodes,
        edit_options=edit_options
    )
    return original_embeddings, cac_args


def _get_conditioning_for_blend(model, blend: Blend, log_tokens: bool = False):
    """
    Embed every prompt of ``blend`` and combine the embeddings using ``blend.weights``.

    The per-prompt embeddings are concatenated along dim 0, given a new leading dim, and passed
    to ``WeightedFrozenCLIPEmbedder.apply_embedding_weights`` (normalized if
    ``blend.normalize_weights`` is set).

    :raises ValueError: if the blend contains no prompts.
    """
    embeddings = []
    for i, flattened_prompt in enumerate(blend.prompts):
        this_embedding, _ = _get_embeddings_and_tokens_for_prompt(model,
                                                                  flattened_prompt,
                                                                  log_tokens=log_tokens,
                                                                  log_display_label=f"(blend part {i + 1}, weight={blend.weights[i]})")
        embeddings.append(this_embedding)
    if not embeddings:
        raise ValueError("cannot build conditioning for a Blend with no prompts")

    # one concatenation instead of re-concatenating the growing tensor on every iteration
    embeddings_to_blend = torch.cat(embeddings)
    conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
                                                                      blend.weights,
                                                                      normalize=blend.normalize_weights)
    return conditioning


def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool = False,
                                          log_display_label: Union[str, None] = None):
    """
    Compute the embeddings for ``flattened_prompt`` via ``model.get_learned_conditioning``.

    Each child fragment is passed as separate text together with its weight.

    :return: tuple (embeddings, tokens), as returned by ``get_learned_conditioning(..., return_tokens=True)``.
    :raises TypeError: if ``flattened_prompt`` is not exactly a FlattenedPrompt.
    """
    if type(flattened_prompt) is not FlattenedPrompt:
        raise TypeError(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
    fragments = [x.text for x in flattened_prompt.children]
    weights = [x.weight for x in flattened_prompt.children]
    embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
    if log_tokens:
        text = " ".join(fragments)
        log_tokenization(text, model, display_label=log_display_label)

    return embeddings, tokens


def _get_tokens_length(model, fragments: list[Fragment]):
    """Return the total number of tokens in ``fragments``, excluding start and end markers."""
    fragment_texts = [x.text for x in fragments]
    tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
    return sum(len(x) for x in tokens)


def _flatten_hybrid_conditioning(uncond, cond):
    """
    Combine dict ("hybrid") unconditioning and conditioning.

    For every key of ``cond``, the unconditioning value is concatenated in front of the
    conditioning value along dim 0; list values are concatenated element by element. This is the
    counterpart to a conditioning that is a single tensor (as used by cross attention).

    :return: tuple (uncond, flattened_cond) — ``uncond`` is returned unchanged.
    """
    assert isinstance(uncond, dict)
    assert isinstance(cond, dict)
    cond_flattened = dict()
    for k in cond:
        if isinstance(cond[k], list):
            cond_flattened[k] = [
                torch.cat([uncond[k][i], cond[k][i]])
                for i in range(len(cond[k]))
            ]
        else:
            cond_flattened[k] = torch.cat([uncond[k], cond[k]])
    return uncond, cond_flattened


def log_tokenization(text, model, display_label=None):
    """
    Print how ``text`` is tokenized, each token in an alternating ANSI colour.

    Tokens usually carry a '</w>' suffix marking the end of a word; for readability it is
    replaced with a space. Tokens whose index is at or beyond
    ``model.cond_stage_model.max_length`` are printed separately as discarded.
    """
    tokens = model.cond_stage_model.tokenizer.tokenize(text)
    max_length = model.cond_stage_model.max_length
    tokenized = ""
    discarded = ""
    used_tokens = 0
    total_tokens = len(tokens)
    for i, raw_token in enumerate(tokens):
        token = raw_token.replace(_END_OF_WORD_MARKER, ' ')
        # alternate color (ANSI foreground colours 31-36)
        s = (used_tokens % 6) + 1
        if i < max_length:
            tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
            used_tokens += 1
        else:  # over max token length
            discarded = discarded + f"\x1b[0;3{s};40m{token}"
    print(f"\n>> Tokens {display_label or ''} ({used_tokens}):\n{tokenized}\x1b[0m")
    if discarded != "":
        print(
            f">> Tokens Discarded ({total_tokens - used_tokens}):\n{discarded}\x1b[0m"
        )