InvokeAI/ldm/invoke/conditioning.py
2023-02-05 22:55:29 +13:00

299 lines
14 KiB
Python

'''
This module handles the generation of the conditioning tensors.
Useful function exports:
get_uc_and_c_and_ec() returns the unconditioned and conditioned embeddings (in that order), plus extra conditioning info that carries the edited conditioning when cross-attention control is in use
'''
import re
from typing import Union
import torch
from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment
from ..models.diffusion import cross_attention_control
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
from ..modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
from ldm.invoke.globals import Globals
def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
    """
    Build the conditioning data for a raw prompt string.

    The prompt string is parsed into a positive and a negative prompt structure, and the
    result of `_get_conditioning_for_prompt` for those two structures is returned
    unchanged: the tuple (unconditioning, conditioning, extra_conditioning_info).

    :param prompt_string: the raw prompt text.
    :param model: the model used for textual inversion lookup and for producing embeddings.
    :param log_tokens: passed through to the conditioning code to enable token logging.
    :param skip_normalize_legacy_blend: passed through to the prompt parser.
    """
    # Textual inversions are loaded lazily; the first time a trigger term is used this
    # call may take a couple of seconds.
    textual_inversions = model.textual_inversion_manager
    textual_inversions.create_deferred_token_ids_for_any_trigger_terms(prompt_string)

    positive_prompt, negative_prompt = get_prompt_structure(
        prompt_string,
        skip_normalize_legacy_blend=skip_normalize_legacy_blend,
    )
    return _get_conditioning_for_prompt(positive_prompt, negative_prompt, model, log_tokens)
def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = False) -> (
        Union[FlattenedPrompt, Blend], FlattenedPrompt):
    """
    Parse `prompt_string` and hand back the pair (positive_prompt, negative_prompt).

    This is the public entry point to `_parse_prompt_string`; both arguments are
    forwarded to it as-is.
    """
    return _parse_prompt_string(
        prompt_string,
        skip_normalize_legacy_blend=skip_normalize_legacy_blend,
    )
def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> [str]:
    """
    Tokenize the text of a flattened prompt with the model's tokenizer.

    The text of every child of `parsed_prompt` is joined with single spaces before
    tokenizing. When `truncate_if_too_long` is true the token list is cut down to
    `model.cond_stage_model.max_length - 2` entries.

    :return: the list of tokens produced by the tokenizer (possibly truncated).
    """

    def _text_of(child):
        # plain fragments contribute their own text
        if type(child) is Fragment:
            return child.text
        # for a substitution only the original (pre-edit) wording is used
        if type(child) is CrossAttentionControlSubstitute:
            return " ".join([part.text for part in child.original])
        # anything else falls back to its string representation
        return str(child)

    joined_text = " ".join(_text_of(child) for child in parsed_prompt.children)
    token_list = model.cond_stage_model.tokenizer.tokenize(joined_text)
    if not truncate_if_too_long:
        return token_list

    limit = model.cond_stage_model.max_length - 2  # typically 75
    return token_list[0:limit]
def _parse_prompt_string(prompt_string_uncleaned, skip_normalize_legacy_blend=False) \
        -> tuple[Union[FlattenedPrompt, Blend], FlattenedPrompt]:
    """
    Split a raw prompt string into its positive and negative parts and parse both.

    Every span enclosed in square brackets is treated as unconditioned (negative) wording:
    the bracketed spans are removed from the positive prompt, and their contents are
    joined with single spaces to form the negative prompt text.

    :param prompt_string_uncleaned: the raw prompt text, possibly containing [bracketed] spans.
    :param skip_normalize_legacy_blend: forwarded to PromptParser.parse_legacy_blend.
    :return: the tuple (parsed_prompt, parsed_negative_prompt). parsed_prompt is whatever
        PromptParser.parse_legacy_blend returns when that is not None, otherwise the first
        prompt of the conjunction parsed from the cleaned text. parsed_negative_prompt is
        the first prompt of the conjunction parsed from the negative text.
    """
    # Extract the unconditioned words from the prompt
    unconditional_regex = r'\[(.*?)\]'
    unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned)

    if unconditionals:
        unconditioned_words = ' '.join(unconditionals)
        # Remove the bracketed spans from the prompt, then collapse the runs of spaces
        # that the removal leaves behind
        clean_prompt = re.sub(unconditional_regex, ' ', prompt_string_uncleaned)
        prompt_string_cleaned = re.sub(' +', ' ', clean_prompt)
    else:
        unconditioned_words = ''
        prompt_string_cleaned = prompt_string_uncleaned

    pp = PromptParser()

    legacy_blend = pp.parse_legacy_blend(prompt_string_cleaned, skip_normalize_legacy_blend)
    if legacy_blend is not None:
        parsed_prompt = legacy_blend
    else:
        # we don't support conjunctions for now
        parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0]

    parsed_negative_prompt = pp.parse_conjunction(unconditioned_words).prompts[0]
    return parsed_prompt, parsed_negative_prompt
def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], parsed_negative_prompt: FlattenedPrompt,
                                 model, log_tokens=False) \
        -> tuple[torch.Tensor, torch.Tensor, InvokeAIDiffuserComponent.ExtraConditioningInfo]:
    """
    Turn the parsed positive and negative prompts into embeddings.

    :return: the tuple (unconditioning, conditioning, extra_conditioning_info) - the
        unconditioning comes first. extra_conditioning_info carries the token count
        (including the begin/end markers) and any cross-attention control arguments.
    :raises ValueError: if `parsed_prompt` is neither a Blend nor a FlattenedPrompt.
    """
    if log_tokens or Globals.log_tokenization:
        print(f"\n>> [TOKENLOG] Parsed Prompt: {parsed_prompt}")
        print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {parsed_negative_prompt}")

    prompt_is_blend = type(parsed_prompt) is Blend
    cac_args = None

    # positive prompt -> conditioning
    if prompt_is_blend:
        conditioning = _get_conditioning_for_blend(model, parsed_prompt, log_tokens)
    elif type(parsed_prompt) is not FlattenedPrompt:
        raise ValueError(f"parsed_prompt is '{type(parsed_prompt)}' which is not a supported prompt type")
    elif parsed_prompt.wants_cross_attention_control:
        conditioning, cac_args = _get_conditioning_for_cross_attention_control(model, parsed_prompt, log_tokens)
    else:
        conditioning, _ = _get_embeddings_and_tokens_for_prompt(
            model,
            parsed_prompt,
            log_tokens=log_tokens,
            log_display_label="(prompt)",
        )

    # negative prompt -> unconditioning
    unconditioning, _ = _get_embeddings_and_tokens_for_prompt(
        model,
        parsed_negative_prompt,
        log_tokens=log_tokens,
        log_display_label="(unconditioning)",
    )

    if isinstance(conditioning, dict):
        # hybrid conditioning is in play
        unconditioning, conditioning = _flatten_hybrid_conditioning(unconditioning, conditioning)
        if cac_args is not None:
            print(
                ">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
            cac_args = None

    # for a blend, the longest of the constituent prompts decides the token count
    if prompt_is_blend:
        token_count = max(len(get_tokens_for_prompt(model, part)) for part in parsed_prompt.prompts)
    else:
        token_count = len(get_tokens_for_prompt(model, parsed_prompt))
    eos_token_index = token_count + 1

    extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(
        tokens_count_including_eos_bos=eos_token_index + 1,
        cross_attention_control_args=cac_args
    )
    return unconditioning, conditioning, extra_conditioning_info
def _get_conditioning_for_cross_attention_control(model, prompt: FlattenedPrompt, log_tokens: bool = True):
    """
    Build the conditioning and the cross-attention control arguments for a prompt that
    contains substitutions.

    Two prompts are assembled from the children of `prompt`: one with the original wording
    of every substitution and one with the edited wording. While doing so a list of edit
    opcodes is recorded, each of the form
    (name, original_start, original_end, edited_start, edited_end), measured in tokens,
    together with a parallel list of per-opcode options (None for unchanged spans).

    :return: the tuple (conditioning, cac_args) where conditioning is the embedding of the
        original prompt and cac_args bundles the edited embedding, opcodes and options.
    """
    original_prompt = FlattenedPrompt()
    edited_prompt = FlattenedPrompt()

    # for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
    edit_opcodes = []
    edit_options = []
    original_token_count = 0
    edited_token_count = 0

    def _record(opcode_name, original_length, edited_length, options):
        # append one opcode spanning the given token lengths and advance both positions
        nonlocal original_token_count, edited_token_count
        original_end = original_token_count + original_length
        edited_end = edited_token_count + edited_length
        edit_opcodes.append((opcode_name, original_token_count, original_end, edited_token_count, edited_end))
        edit_options.append(options)
        original_token_count = original_end
        edited_token_count = edited_end

    # beginning of sequence
    _record('equal', 1, 1, None)

    for fragment in prompt.children:
        if type(fragment) is CrossAttentionControlSubstitute:
            original_prompt.append(fragment.original)
            edited_prompt.append(fragment.edited)
            _record('replace',
                    _get_tokens_length(model, fragment.original),
                    _get_tokens_length(model, fragment.edited),
                    fragment.options)
        # elif type(fragment) is CrossAttentionControlAppend:
        #     edited_prompt.append(fragment.fragment)
        else:
            # regular fragment: identical in both prompts
            original_prompt.append(fragment)
            edited_prompt.append(fragment)
            shared_length = _get_tokens_length(model, [fragment])
            _record('equal', shared_length, shared_length, None)

    # end of sequence
    _record('equal', 1, 1, None)

    conditioning, _original_tokens = _get_embeddings_and_tokens_for_prompt(
        model,
        original_prompt,
        log_tokens=log_tokens,
        log_display_label="(.swap originals)",
    )
    # Building one single edited embedding like this is naive: it ignores the fact that,
    # when there is more than one edit and an earlier edit changes the total token count,
    # the absolute positions of all later tokens shift.
    # e.g. "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - while the
    # 'pizza' edit is active but the 'cat' edit is not, the 'pizza' feature vector is still
    # affected by the extra token 'smiling' introduced by the inactive 'cat' edit.
    # todo: build multiple edited_embeddings, one for each edit, and pass just the edited
    # fragments through to the CrossAttentionControl functions
    edited_conditioning, _edited_tokens = _get_embeddings_and_tokens_for_prompt(
        model,
        edited_prompt,
        log_tokens=log_tokens,
        log_display_label="(.swap replacements)",
    )

    # print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
    cac_args = cross_attention_control.Arguments(
        edited_conditioning=edited_conditioning,
        edit_opcodes=edit_opcodes,
        edit_options=edit_options
    )
    return conditioning, cac_args
def _get_conditioning_for_blend(model, blend: Blend, log_tokens: bool = False):
    """
    Compute the conditioning for a blend of several prompts.

    Each prompt in `blend.prompts` is embedded on its own; the embeddings are concatenated
    with torch.cat (along its default dimension), given an extra leading dimension, and
    handed to WeightedPromptFragmentsToEmbeddingsConverter.apply_embedding_weights together
    with `blend.weights` and `blend.normalize_weights`.

    :param model: the model used to produce the embeddings.
    :param blend: the Blend whose prompts and weights are combined.
    :param log_tokens: forwarded to the per-prompt embedding call to enable token logging.
    :return: whatever apply_embedding_weights returns for the stacked embeddings.
    :raises ValueError: if the blend contains no prompts.
    """
    part_embeddings = []
    for i, flattened_prompt in enumerate(blend.prompts):
        this_embedding, _ = _get_embeddings_and_tokens_for_prompt(
            model,
            flattened_prompt,
            log_tokens=log_tokens,
            log_display_label=f"(blend part {i + 1}, weight={blend.weights[i]})",
        )
        part_embeddings.append(this_embedding)

    if not part_embeddings:
        # fail with a clear message instead of an AttributeError on None further down
        raise ValueError("cannot compute conditioning for a blend that contains no prompts")

    # concatenate once rather than re-copying the growing tensor on every iteration
    embeddings_to_blend = torch.cat(part_embeddings)
    conditioning = WeightedPromptFragmentsToEmbeddingsConverter.apply_embedding_weights(
        embeddings_to_blend.unsqueeze(0),
        blend.weights,
        normalize=blend.normalize_weights,
    )
    return conditioning
def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool = False,
                                          log_display_label: str = None):
    """
    Ask the model for the embeddings and tokens of a single flattened prompt.

    The text and weight of every child of `flattened_prompt` are passed to
    `model.get_learned_conditioning` as one batch entry. When `log_tokens` or the global
    tokenization-logging flag is set, the tokenization of the space-joined fragment text
    is printed under `log_display_label`.

    :return: the tuple (embeddings, tokens) as returned by the model.
    :raises Exception: if `flattened_prompt` is not exactly a FlattenedPrompt.
    """
    if type(flattened_prompt) is not FlattenedPrompt:
        raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")

    children = flattened_prompt.children
    fragment_texts = [child.text for child in children]
    fragment_weights = [child.weight for child in children]

    embeddings, tokens = model.get_learned_conditioning(
        [fragment_texts],
        return_tokens=True,
        fragment_weights=[fragment_weights],
    )

    should_log = log_tokens or Globals.log_tokenization
    if should_log:
        log_tokenization(" ".join(fragment_texts), model, display_label=log_display_label)

    return embeddings, tokens
def _get_tokens_length(model, fragments: list[Fragment]):
    """
    Count the tokens that the given fragments occupy.

    The text of every fragment is sent to `model.cond_stage_model.get_token_ids` with the
    start and end markers excluded; the lengths of the returned token id sequences are
    added up.

    :return: the total number of token ids over all fragments (0 for no fragments).
    """
    encoder = model.cond_stage_model
    token_id_lists = encoder.get_token_ids(
        [fragment.text for fragment in fragments],
        include_start_and_end_markers=False,
    )
    return sum(map(len, token_id_lists))
def _flatten_hybrid_conditioning(uncond, cond):
    '''
    Merge dict-shaped ('hybrid') unconditioning and conditioning into a single dict.

    Both arguments must be dicts. For every key of `cond` the value from `uncond` is
    concatenated in front of the value from `cond` with torch.cat; when the `cond` value
    is a list, the concatenation is done element by element and a list is produced.

    Returns the tuple (uncond, flattened) where `uncond` is the dict that was passed in
    and `flattened` is the newly built dict.
    '''
    assert isinstance(uncond, dict)
    assert isinstance(cond, dict)

    def _join(key, cond_value):
        # list values are paired up by position with the matching uncond list
        if isinstance(cond_value, list):
            return [
                torch.cat([uncond[key][position], item])
                for position, item in enumerate(cond_value)
            ]
        return torch.cat([uncond[key], cond_value])

    flattened = {key: _join(key, cond_value) for key, cond_value in cond.items()}
    return uncond, flattened
def log_tokenization(text, model, display_label=None):
    """ Print how the given text is split into tokens by the model's tokenizer.

    Tokens usually carry '</w>' to mark the end of a word; for readability that marker is
    shown as a space. Tokens are printed in alternating colours. Tokens whose position is
    at or beyond `model.cond_stage_model.max_length` are listed separately as discarded.
    """
    clip = model.cond_stage_model
    raw_tokens = clip.tokenizer.tokenize(text)

    kept_pieces = []
    discarded_pieces = []
    for position, raw_token in enumerate(raw_tokens):
        readable = raw_token.replace('</w>', ' ')
        # the colour cycles with the number of tokens kept so far
        colour = (len(kept_pieces) % 6) + 1
        piece = f"\x1b[0;3{colour};40m{readable}"
        if position < clip.max_length:
            kept_pieces.append(piece)
        else:  # over max token length
            discarded_pieces.append(piece)

    used_count = len(kept_pieces)
    if used_count > 0:
        print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({used_count}):')
        print(f'{"".join(kept_pieces)}\x1b[0m')

    if discarded_pieces:
        print(f'\n>> [TOKENLOG] Tokens Discarded ({len(raw_tokens) - used_count}):')
        print(f'{"".join(discarded_pieces)}\x1b[0m')