bring in prompt parser from fix-prompts branch

attention is parsed but ignored, blends using the old syntax don't work,
and conjunctions are parsed but ignored. The only parts that are used
here are the new .blend() syntax and cross-attention control using
.swap()
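
For context, the two features wired up here look roughly like this in a
prompt (illustrative strings based on the grammar in this diff; the
.blend() grammar itself is not shown below):

    a ("fluffy cat").swap("smiling dog") eating a sandwich    # cross-attention control
    ("a tall thin man", "a short fat man").blend(1.0, 0.5)    # new .blend() syntax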
Author: Damian at mba
Date:   2022-10-20 12:01:48 +02:00
Parent: 42883545f9
Commit: c9d27634b4
10 changed files with 169 additions and 61 deletions

View File

@@ -4,7 +4,7 @@ weighted subprompts.
 Useful function exports:
-get_uc_and_c() get the conditioned and unconditioned latent
+get_uc_and_c_and_ec() get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control
 split_weighted_subpromopts() split subprompts, normalize and weight them
 log_tokenization() print out colour-coded tokens and warn if truncated
@@ -39,8 +39,10 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
     pp = PromptParser()
-    parsed_prompt: Union[FlattenedPrompt, Blend] = pp.parse(prompt_string_cleaned)
-    parsed_negative_prompt: FlattenedPrompt = pp.parse(unconditioned_words)
+    # we don't support conjunctions for now
+    parsed_prompt: Union[FlattenedPrompt, Blend] = pp.parse(prompt_string_cleaned).prompts[0]
+    parsed_negative_prompt: FlattenedPrompt = pp.parse(unconditioned_words).prompts[0]
+    print("parsed prompt to", parsed_prompt)

     conditioning = None
     edited_conditioning = None
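
Since parse() now returns a Conjunction (see the prompt parser changes
below), callers take its first prompt. A rough sketch of the new return
shape for a plain one-part prompt (the reprs are approximate, an
assumption rather than actual output):

    # pp.parse("a cat")            -> Conjunction with .prompts == [FlattenedPrompt([Fragment('a cat', 1)])]
    # pp.parse("a cat").prompts[0] -> FlattenedPrompt([Fragment('a cat', 1)])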
@@ -50,7 +52,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
         blend: Blend = parsed_prompt
         embeddings_to_blend = None
         for flattened_prompt in blend.prompts:
-            this_embedding = make_embeddings_for_flattened_prompt(model, flattened_prompt)
+            this_embedding = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt)
             embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
                 (embeddings_to_blend, this_embedding))
         conditioning, _ = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
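
apply_embedding_weights itself is not part of this diff; a minimal
sketch of what blending the stacked per-prompt embeddings plausibly
computes (an assumption, not the actual WeightedFrozenCLIPEmbedder code):

    import torch

    def blend_embeddings_sketch(embeddings: torch.Tensor, weights, normalize=True) -> torch.Tensor:
        # embeddings: [n_prompts, n_tokens, dim], one entry per blended prompt
        w = torch.tensor(weights, dtype=embeddings.dtype).view(-1, 1, 1)
        if normalize:
            w = w / w.abs().sum()           # make the weights sum to 1
        return (embeddings * w).sum(dim=0)  # weighted sum -> [n_tokens, dim]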
@@ -72,16 +74,16 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
                     # regular fragment
                     original_prompt.append(fragment)
                     edited_prompt.append(fragment)
-            original_embeddings, original_tokens = make_embeddings_for_flattened_prompt(model, original_prompt)
-            edited_embeddings, edited_tokens = make_embeddings_for_flattened_prompt(model, edited_prompt)
+            original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt)
+            edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt)
             conditioning = original_embeddings
             edited_conditioning = edited_embeddings
             edit_opcodes = build_token_edit_opcodes(original_tokens, edited_tokens)
         else:
-            conditioning, _ = make_embeddings_for_flattened_prompt(model, flattened_prompt)
+            conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt)

-    unconditioning = make_embeddings_for_flattened_prompt(parsed_negative_prompt)
+    unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt)
     return (unconditioning, conditioning, edited_conditioning, edit_opcodes)
@@ -91,11 +93,11 @@ def build_token_edit_opcodes(original_tokens, edited_tokens):
     return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()

-def make_embeddings_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt):
+def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt):
     if type(flattened_prompt) is not FlattenedPrompt:
         raise f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead"
-    fragments = [x[0] for x in flattened_prompt.children]
-    embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True)
+    fragments = [x.text for x in flattened_prompt.children]
+    embeddings, tokens = model.get_learned_conditioning([' '.join(fragments)], return_tokens=True)
     return embeddings, tokens
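
build_token_edit_opcodes leans on the standard library:
difflib.SequenceMatcher.get_opcodes() returns (tag, i1, i2, j1, j2)
tuples describing how to turn the original token list into the edited
one. For example (token ids made up):

    from difflib import SequenceMatcher

    original_tokens = [320, 2368, 525]   # e.g. tokens for "a cat ..."
    edited_tokens   = [320, 1929, 525]   # one token substituted
    print(SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes())
    # [('equal', 0, 1, 0, 1), ('replace', 1, 2, 1, 2), ('equal', 2, 3, 2, 3)]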

View File

@@ -34,7 +34,7 @@ class Img2Img(Generator):
         t_enc = int(strength * steps)
         uc, c, ec, edit_opcodes = conditioning
-        structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
+        extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)

         def make_image(x_T):
             # encode (scaled latent)
@@ -52,7 +52,7 @@ class Img2Img(Generator):
                 unconditional_guidance_scale=cfg_scale,
                 unconditional_conditioning=uc,
                 init_latent = self.init_latent,
-                structured_conditioning = structured_conditioning
+                extra_conditioning_info = extra_conditioning_info
                 # changes how noising is performed in ksampler
             )

View File

@@ -22,7 +22,7 @@ class Txt2Img(Generator):
         """
         self.perlin = perlin
         uc, c, ec, edit_opcodes = conditioning
-        structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
+        extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)

         @torch.no_grad()
         def make_image(x_T):
@@ -46,7 +46,7 @@ class Txt2Img(Generator):
                 verbose = False,
                 unconditional_guidance_scale = cfg_scale,
                 unconditional_conditioning = uc,
-                structured_conditioning = structured_conditioning,
+                extra_conditioning_info = extra_conditioning_info,
                 eta = ddim_eta,
                 img_callback = step_callback,
                 threshold = threshold,

View File

@@ -24,7 +24,7 @@ class Txt2Img2Img(Generator):
         kwargs are 'width' and 'height'
         """
         uc, c, ec, edit_opcodes = conditioning
-        structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
+        extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)

         @torch.no_grad()
         def make_image(x_T):
@@ -63,7 +63,7 @@ class Txt2Img2Img(Generator):
                 unconditional_conditioning = uc,
                 eta = ddim_eta,
                 img_callback = step_callback,
-                structured_conditioning = structured_conditioning
+                extra_conditioning_info = extra_conditioning_info
             )

             print(
@@ -97,7 +97,7 @@ class Txt2Img2Img(Generator):
                 img_callback = step_callback,
                 unconditional_guidance_scale=cfg_scale,
                 unconditional_conditioning=uc,
-                structured_conditioning = structured_conditioning
+                extra_conditioning_info = extra_conditioning_info
             )

         if self.free_gpu_mem:

View File

@@ -7,7 +7,7 @@ class Prompt():
     def __init__(self, parts: list):
         for c in parts:
-            if not issubclass(type(c), BaseFragment):
+            if type(c) is not Attention and not issubclass(type(c), BaseFragment):
                 raise PromptParser.ParsingException(f"Prompt cannot contain {type(c)}, only {BaseFragment.__subclasses__()} are allowed")
         self.children = parts
     def __repr__(self):
@@ -56,6 +56,17 @@ class Fragment(BaseFragment):
                and other.text == self.text \
                and other.weight == self.weight

+class Attention():
+    def __init__(self, weight: float, children: list):
+        self.weight = weight
+        self.children = children
+        #print(f"A: requested attention '{children}' to {weight}")
+    def __repr__(self):
+        return f"Attention:'{self.children}' @ {self.weight}"
+    def __eq__(self, other):
+        return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment
+
 class CrossAttentionControlledFragment(BaseFragment):
     pass
@@ -65,7 +76,7 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
         self.edited = edited
     def __repr__(self):
-        return f"CrossAttentionControlSubstitute:('{self.original}'->'{self.edited}')"
+        return f"CrossAttentionControlSubstitute:({self.original}->{self.edited})"
     def __eq__(self, other):
         return type(other) is CrossAttentionControlSubstitute \
                and other.original == self.original \
@@ -137,7 +148,7 @@ class PromptParser():
         self.root = self.build_parser_logic()

-    def parse(self, prompt: str) -> [list]:
+    def parse(self, prompt: str) -> Conjunction:
         '''
         :param prompt: The prompt string to parse
         :return: a tuple
@@ -181,13 +192,17 @@
                 for x in node:
                     results = flatten_internal(x, weight_scale, results, prefix+'pr')
                 #print(prefix, " ParseResults expanded, results is now", results)
-            elif issubclass(type(node), BaseFragment):
-                results.append(node)
-            #elif type(node) is Attention:
-            #    #if node.weight < 1:
-            #    #    todo: inject a blend when flattening attention with weight <1"
-            #    for c in node.children:
-            #        results = flatten_internal(c, weight_scale*node.weight, results, prefix+' ')
+            elif type(node) is Attention:
+                # if node.weight < 1:
+                #     todo: inject a blend when flattening attention with weight <1"
+                for c in node.children:
+                    results = flatten_internal(c, weight_scale * node.weight, results, prefix + ' ')
+            elif type(node) is Fragment:
+                results += [Fragment(node.text, node.weight*weight_scale)]
+            elif type(node) is CrossAttentionControlSubstitute:
+                original = flatten_internal(node.original, weight_scale, [], ' CAo ')
+                edited = flatten_internal(node.edited, weight_scale, [], ' CAe ')
+                results += [CrossAttentionControlSubstitute(original, edited)]
             elif type(node) is Blend:
                 flattened_subprompts = []
                 #print(" flattening blend with prompts", node.prompts, "weights", node.weights)
@@ -204,7 +219,7 @@
                 #print(prefix + "after flattening Prompt, results is", results)
             else:
                 raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
-            #print(prefix + "-> after flattening", type(node), "results is", results)
+            print(prefix + "-> after flattening", type(node), "results is", results)
             return results

         #print("flattening", root)
@@ -239,32 +254,83 @@
             else:
                 raise PromptParser.ParsingException("Cannot make fragment from " + str(x))

+        # attention control of the form +(phrase) / -(phrase) / <weight>(phrase)
+        # phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight
+        attention = pp.Forward()
+        attention_head = (number | pp.Word('+') | pp.Word('-'))\
+            .set_name("attention_head")\
+            .set_debug(False)
+        fragment_inside_attention = pp.CharsNotIn(SPACE_CHARS+'()')\
+            .set_parse_action(make_fragment)\
+            .set_name("fragment_inside_attention")\
+            .set_debug(False)
+        attention_with_parens = pp.Forward()
+        attention_with_parens_body = pp.nested_expr(content=pp.delimited_list((attention_with_parens | fragment_inside_attention), delim=SPACE_CHARS))
+        attention_with_parens << (attention_head + attention_with_parens_body)
+
+        def make_attention(x):
+            # print("making Attention from parsing with args", x0, x1)
+            weight = 1
+            # number(str)
+            if type(x[0]) is float or type(x[0]) is int:
+                weight = float(x[0])
+            # +(str) or -(str) or +str or -str
+            elif type(x[0]) is str:
+                base = self.attention_plus_base if x[0][0] == '+' else self.attention_minus_base
+                weight = pow(base, len(x[0]))
+            # print("Making attention with children of type", [str(type(x)) for x in x1])
+            return Attention(weight=weight, children=x[1])
+
+        attention_with_parens.set_parse_action(make_attention)\
+            .set_name("attention_with_parens")\
+            .set_debug(False)
+
+        # attention control of the form ++word --word (no parens)
+        attention_without_parens = (
+            (pp.Word('+') | pp.Word('-')) +
+            pp.CharsNotIn(SPACE_CHARS+'()').set_parse_action(lambda x: [[make_fragment(x)]])
+        )\
+            .set_name("attention_without_parens")\
+            .set_debug(False)
+        attention_without_parens.set_parse_action(make_attention)
+
+        attention << (attention_with_parens | attention_without_parens)\
+            .set_name("attention")\
+            .set_debug(False)
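
make_attention collapses the attention head into a single weight: an
explicit number is used as-is, while a run of '+' or '-' becomes
base ** run_length. A quick check of that arithmetic, assuming
attention_plus_base = 1.1 and attention_minus_base = 0.9 (neither value
is shown in this hunk):

    plus_base, minus_base = 1.1, 0.9   # assumed values, not visible in this diff
    for head in ['+', '++', '-', '---']:
        base = plus_base if head[0] == '+' else minus_base
        print(head, round(pow(base, len(head)), 4))
    # + 1.1
    # ++ 1.21
    # - 0.9
    # --- 0.729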
+        # cross-attention control
+        empty_string = ((lparen + rparen) |
+                        pp.Literal('""').suppress() |
+                        (lparen + pp.Literal('""').suppress() + rparen)
+                        ).set_parse_action(lambda x: Fragment(""))
         original_words = (
-            (lparen + pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress() + rparen).set_name('term1').set_debug(False) |
-            (pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress()).set_name('term2').set_debug(False) |
-            (lparen + pp.CharsNotIn(')') + rparen).set_name('term3').set_debug(False)
+            (lparen + pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress() + rparen).set_name('term1').set_debug(False) |
+            (pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress()).set_name('term2').set_debug(False) |
+            (lparen + (pp.OneOrMore(attention) | pp.CharsNotIn(')').set_parse_action(make_fragment)) + rparen).set_name('term3').set_debug(False)
         ).set_name('original_words')
         edited_words = (
-            (pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress()).set_name('termA').set_debug(False) |
-            pp.CharsNotIn(')').set_name('termB').set_debug(False)
+            (pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress()).set_name('termA').set_debug(False) |
+            pp.Literal('""').suppress().set_parse_action(lambda x: Fragment("")) |
+            (pp.OneOrMore(attention) | pp.CharsNotIn(')').set_parse_action(make_fragment)).set_name('termB').set_debug(True)
         ).set_name('edited_words')
-        cross_attention_substitute = original_words + \
+        cross_attention_substitute = (empty_string | original_words) + \
             pp.Literal(".swap").suppress() + \
-            lparen + edited_words + rparen
+            (empty_string | (lparen + edited_words + rparen)
+            )
         cross_attention_substitute.set_name('cross_attention_substitute')

         def make_cross_attention_substitute(x):
             #print("making cacs for", x)
-            return CrossAttentionControlSubstitute(x[0], x[1])
+            cacs = CrossAttentionControlSubstitute(x[0], x[1])
+            #print("made", cacs)
+            #return cacs
+            return cacs
         cross_attention_substitute.set_parse_action(make_cross_attention_substitute)

         # simple fragments of text
         prompt_part << (cross_attention_substitute
-                        #| attention
+                        | attention
                         | word
                         )
         prompt_part.set_debug(False)
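
Taken together, the reworked grammar should now accept inputs like these
(a hedged sketch; the reprs are approximate):

    # a ("fluffy cat").swap("smiling dog") eating a sandwich
    #   -> ..., CrossAttentionControlSubstitute([Fragment('fluffy cat', 1)] -> [Fragment('smiling dog', 1)]), ...
    # "cat".swap("")   # an empty edited side now parses via empty_string / the '""' branch
    # ++(cute kitten)  # attention now parses as a prompt_part, though per the commit message it is still ignored downstream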