mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
bring in prompt parser from fix-prompts branch
attention is parsed but ignored, blends old syntax doesn't work, conjunctions are parsed but ignored, the only part that's used here is the new .blend() syntax and cross-attention control using .swap()
This commit is contained in:
parent
42883545f9
commit
c9d27634b4
@ -4,7 +4,7 @@ weighted subprompts.
|
||||
|
||||
Useful function exports:
|
||||
|
||||
get_uc_and_c() get the conditioned and unconditioned latent
|
||||
get_uc_and_c_and_ec() get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control
|
||||
split_weighted_subpromopts() split subprompts, normalize and weight them
|
||||
log_tokenization() print out colour-coded tokens and warn if truncated
|
||||
|
||||
@ -39,8 +39,10 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
|
||||
|
||||
pp = PromptParser()
|
||||
|
||||
parsed_prompt: Union[FlattenedPrompt, Blend] = pp.parse(prompt_string_cleaned)
|
||||
parsed_negative_prompt: FlattenedPrompt = pp.parse(unconditioned_words)
|
||||
# we don't support conjunctions for now
|
||||
parsed_prompt: Union[FlattenedPrompt, Blend] = pp.parse(prompt_string_cleaned).prompts[0]
|
||||
parsed_negative_prompt: FlattenedPrompt = pp.parse(unconditioned_words).prompts[0]
|
||||
print("parsed prompt to", parsed_prompt)
|
||||
|
||||
conditioning = None
|
||||
edited_conditioning = None
|
||||
@ -50,7 +52,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
|
||||
blend: Blend = parsed_prompt
|
||||
embeddings_to_blend = None
|
||||
for flattened_prompt in blend.prompts:
|
||||
this_embedding = make_embeddings_for_flattened_prompt(model, flattened_prompt)
|
||||
this_embedding = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt)
|
||||
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
|
||||
(embeddings_to_blend, this_embedding))
|
||||
conditioning, _ = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
|
||||
@ -72,16 +74,16 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
|
||||
# regular fragment
|
||||
original_prompt.append(fragment)
|
||||
edited_prompt.append(fragment)
|
||||
original_embeddings, original_tokens = make_embeddings_for_flattened_prompt(model, original_prompt)
|
||||
edited_embeddings, edited_tokens = make_embeddings_for_flattened_prompt(model, edited_prompt)
|
||||
original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt)
|
||||
edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt)
|
||||
|
||||
conditioning = original_embeddings
|
||||
edited_conditioning = edited_embeddings
|
||||
edit_opcodes = build_token_edit_opcodes(original_tokens, edited_tokens)
|
||||
else:
|
||||
conditioning, _ = make_embeddings_for_flattened_prompt(model, flattened_prompt)
|
||||
conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt)
|
||||
|
||||
unconditioning = make_embeddings_for_flattened_prompt(parsed_negative_prompt)
|
||||
unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt)
|
||||
return (unconditioning, conditioning, edited_conditioning, edit_opcodes)
|
||||
|
||||
|
||||
@ -91,11 +93,11 @@ def build_token_edit_opcodes(original_tokens, edited_tokens):
|
||||
|
||||
return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
|
||||
|
||||
def make_embeddings_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt):
|
||||
def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt):
|
||||
if type(flattened_prompt) is not FlattenedPrompt:
|
||||
raise f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead"
|
||||
fragments = [x[0] for x in flattened_prompt.children]
|
||||
embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True)
|
||||
fragments = [x.text for x in flattened_prompt.children]
|
||||
embeddings, tokens = model.get_learned_conditioning([' '.join(fragments)], return_tokens=True)
|
||||
return embeddings, tokens
|
||||
|
||||
|
||||
|
@ -34,7 +34,7 @@ class Img2Img(Generator):
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
uc, c, ec, edit_opcodes = conditioning
|
||||
structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
|
||||
extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
|
||||
|
||||
def make_image(x_T):
|
||||
# encode (scaled latent)
|
||||
@ -52,7 +52,7 @@ class Img2Img(Generator):
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
init_latent = self.init_latent,
|
||||
structured_conditioning = structured_conditioning
|
||||
extra_conditioning_info = extra_conditioning_info
|
||||
# changes how noising is performed in ksampler
|
||||
)
|
||||
|
||||
|
@ -22,7 +22,7 @@ class Txt2Img(Generator):
|
||||
"""
|
||||
self.perlin = perlin
|
||||
uc, c, ec, edit_opcodes = conditioning
|
||||
structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
|
||||
extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
@ -46,7 +46,7 @@ class Txt2Img(Generator):
|
||||
verbose = False,
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc,
|
||||
structured_conditioning = structured_conditioning,
|
||||
extra_conditioning_info = extra_conditioning_info,
|
||||
eta = ddim_eta,
|
||||
img_callback = step_callback,
|
||||
threshold = threshold,
|
||||
|
@ -24,7 +24,7 @@ class Txt2Img2Img(Generator):
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
uc, c, ec, edit_opcodes = conditioning
|
||||
structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
|
||||
extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
@ -63,7 +63,7 @@ class Txt2Img2Img(Generator):
|
||||
unconditional_conditioning = uc,
|
||||
eta = ddim_eta,
|
||||
img_callback = step_callback,
|
||||
structured_conditioning = structured_conditioning
|
||||
extra_conditioning_info = extra_conditioning_info
|
||||
)
|
||||
|
||||
print(
|
||||
@ -97,7 +97,7 @@ class Txt2Img2Img(Generator):
|
||||
img_callback = step_callback,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
structured_conditioning = structured_conditioning
|
||||
extra_conditioning_info = extra_conditioning_info
|
||||
)
|
||||
|
||||
if self.free_gpu_mem:
|
||||
|
@ -7,7 +7,7 @@ class Prompt():
|
||||
|
||||
def __init__(self, parts: list):
|
||||
for c in parts:
|
||||
if not issubclass(type(c), BaseFragment):
|
||||
if type(c) is not Attention and not issubclass(type(c), BaseFragment):
|
||||
raise PromptParser.ParsingException(f"Prompt cannot contain {type(c)}, only {BaseFragment.__subclasses__()} are allowed")
|
||||
self.children = parts
|
||||
def __repr__(self):
|
||||
@ -56,6 +56,17 @@ class Fragment(BaseFragment):
|
||||
and other.text == self.text \
|
||||
and other.weight == self.weight
|
||||
|
||||
class Attention():
|
||||
def __init__(self, weight: float, children: list):
|
||||
self.weight = weight
|
||||
self.children = children
|
||||
#print(f"A: requested attention '{children}' to {weight}")
|
||||
|
||||
def __repr__(self):
|
||||
return f"Attention:'{self.children}' @ {self.weight}"
|
||||
def __eq__(self, other):
|
||||
return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment
|
||||
|
||||
class CrossAttentionControlledFragment(BaseFragment):
|
||||
pass
|
||||
|
||||
@ -65,7 +76,7 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
|
||||
self.edited = edited
|
||||
|
||||
def __repr__(self):
|
||||
return f"CrossAttentionControlSubstitute:('{self.original}'->'{self.edited}')"
|
||||
return f"CrossAttentionControlSubstitute:({self.original}->{self.edited})"
|
||||
def __eq__(self, other):
|
||||
return type(other) is CrossAttentionControlSubstitute \
|
||||
and other.original == self.original \
|
||||
@ -137,7 +148,7 @@ class PromptParser():
|
||||
self.root = self.build_parser_logic()
|
||||
|
||||
|
||||
def parse(self, prompt: str) -> [list]:
|
||||
def parse(self, prompt: str) -> Conjunction:
|
||||
'''
|
||||
:param prompt: The prompt string to parse
|
||||
:return: a tuple
|
||||
@ -181,13 +192,17 @@ class PromptParser():
|
||||
for x in node:
|
||||
results = flatten_internal(x, weight_scale, results, prefix+'pr')
|
||||
#print(prefix, " ParseResults expanded, results is now", results)
|
||||
elif issubclass(type(node), BaseFragment):
|
||||
results.append(node)
|
||||
#elif type(node) is Attention:
|
||||
# #if node.weight < 1:
|
||||
# # todo: inject a blend when flattening attention with weight <1"
|
||||
# for c in node.children:
|
||||
# results = flatten_internal(c, weight_scale*node.weight, results, prefix+' ')
|
||||
elif type(node) is Attention:
|
||||
# if node.weight < 1:
|
||||
# todo: inject a blend when flattening attention with weight <1"
|
||||
for c in node.children:
|
||||
results = flatten_internal(c, weight_scale * node.weight, results, prefix + ' ')
|
||||
elif type(node) is Fragment:
|
||||
results += [Fragment(node.text, node.weight*weight_scale)]
|
||||
elif type(node) is CrossAttentionControlSubstitute:
|
||||
original = flatten_internal(node.original, weight_scale, [], ' CAo ')
|
||||
edited = flatten_internal(node.edited, weight_scale, [], ' CAe ')
|
||||
results += [CrossAttentionControlSubstitute(original, edited)]
|
||||
elif type(node) is Blend:
|
||||
flattened_subprompts = []
|
||||
#print(" flattening blend with prompts", node.prompts, "weights", node.weights)
|
||||
@ -204,7 +219,7 @@ class PromptParser():
|
||||
#print(prefix + "after flattening Prompt, results is", results)
|
||||
else:
|
||||
raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
|
||||
#print(prefix + "-> after flattening", type(node), "results is", results)
|
||||
print(prefix + "-> after flattening", type(node), "results is", results)
|
||||
return results
|
||||
|
||||
#print("flattening", root)
|
||||
@ -239,32 +254,83 @@ class PromptParser():
|
||||
else:
|
||||
raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
|
||||
|
||||
# attention control of the form +(phrase) / -(phrase) / <weight>(phrase)
|
||||
# phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight
|
||||
attention = pp.Forward()
|
||||
attention_head = (number | pp.Word('+') | pp.Word('-'))\
|
||||
.set_name("attention_head")\
|
||||
.set_debug(False)
|
||||
fragment_inside_attention = pp.CharsNotIn(SPACE_CHARS+'()')\
|
||||
.set_parse_action(make_fragment)\
|
||||
.set_name("fragment_inside_attention")\
|
||||
.set_debug(False)
|
||||
attention_with_parens = pp.Forward()
|
||||
attention_with_parens_body = pp.nested_expr(content=pp.delimited_list((attention_with_parens | fragment_inside_attention), delim=SPACE_CHARS))
|
||||
attention_with_parens << (attention_head + attention_with_parens_body)
|
||||
|
||||
def make_attention(x):
|
||||
# print("making Attention from parsing with args", x0, x1)
|
||||
weight = 1
|
||||
# number(str)
|
||||
if type(x[0]) is float or type(x[0]) is int:
|
||||
weight = float(x[0])
|
||||
# +(str) or -(str) or +str or -str
|
||||
elif type(x[0]) is str:
|
||||
base = self.attention_plus_base if x[0][0] == '+' else self.attention_minus_base
|
||||
weight = pow(base, len(x[0]))
|
||||
# print("Making attention with children of type", [str(type(x)) for x in x1])
|
||||
return Attention(weight=weight, children=x[1])
|
||||
|
||||
attention_with_parens.set_parse_action(make_attention)\
|
||||
.set_name("attention_with_parens")\
|
||||
.set_debug(False)
|
||||
|
||||
# attention control of the form ++word --word (no parens)
|
||||
attention_without_parens = (
|
||||
(pp.Word('+') | pp.Word('-')) +
|
||||
pp.CharsNotIn(SPACE_CHARS+'()').set_parse_action(lambda x: [[make_fragment(x)]])
|
||||
)\
|
||||
.set_name("attention_without_parens")\
|
||||
.set_debug(False)
|
||||
attention_without_parens.set_parse_action(make_attention)
|
||||
|
||||
attention << (attention_with_parens | attention_without_parens)\
|
||||
.set_name("attention")\
|
||||
.set_debug(False)
|
||||
|
||||
# cross-attention control
|
||||
empty_string = ((lparen + rparen) |
|
||||
pp.Literal('""').suppress() |
|
||||
(lparen + pp.Literal('""').suppress() + rparen)
|
||||
).set_parse_action(lambda x: Fragment(""))
|
||||
|
||||
original_words = (
|
||||
(lparen + pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress() + rparen).set_name('term1').set_debug(False) |
|
||||
(pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress()).set_name('term2').set_debug(False) |
|
||||
(lparen + pp.CharsNotIn(')') + rparen).set_name('term3').set_debug(False)
|
||||
(lparen + pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress() + rparen).set_name('term1').set_debug(False) |
|
||||
(pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress()).set_name('term2').set_debug(False) |
|
||||
(lparen + (pp.OneOrMore(attention) | pp.CharsNotIn(')').set_parse_action(make_fragment)) + rparen).set_name('term3').set_debug(False)
|
||||
).set_name('original_words')
|
||||
edited_words = (
|
||||
(pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress()).set_name('termA').set_debug(False) |
|
||||
pp.CharsNotIn(')').set_name('termB').set_debug(False)
|
||||
(pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress()).set_name('termA').set_debug(False) |
|
||||
pp.Literal('""').suppress().set_parse_action(lambda x: Fragment("")) |
|
||||
(pp.OneOrMore(attention) | pp.CharsNotIn(')').set_parse_action(make_fragment)).set_name('termB').set_debug(True)
|
||||
).set_name('edited_words')
|
||||
cross_attention_substitute = original_words + \
|
||||
cross_attention_substitute = (empty_string | original_words) + \
|
||||
pp.Literal(".swap").suppress() + \
|
||||
lparen + edited_words + rparen
|
||||
(empty_string | (lparen + edited_words + rparen)
|
||||
)
|
||||
cross_attention_substitute.set_name('cross_attention_substitute')
|
||||
|
||||
def make_cross_attention_substitute(x):
|
||||
#print("making cacs for", x)
|
||||
return CrossAttentionControlSubstitute(x[0], x[1])
|
||||
cacs = CrossAttentionControlSubstitute(x[0], x[1])
|
||||
#print("made", cacs)
|
||||
#return cacs
|
||||
return cacs
|
||||
|
||||
cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
|
||||
|
||||
# simple fragments of text
|
||||
prompt_part << (cross_attention_substitute
|
||||
#| attention
|
||||
| attention
|
||||
| word
|
||||
)
|
||||
prompt_part.set_debug(False)
|
||||
|
@ -16,10 +16,10 @@ class DDIMSampler(Sampler):
|
||||
def prepare_to_sample(self, t_enc, **kwargs):
|
||||
super().prepare_to_sample(t_enc, **kwargs)
|
||||
|
||||
structured_conditioning = kwargs.get('structured_conditioning', None)
|
||||
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
|
||||
|
||||
if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning)
|
||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info)
|
||||
else:
|
||||
self.invokeai_diffuser.remove_cross_attention_control()
|
||||
|
||||
|
@ -34,10 +34,10 @@ class CFGDenoiser(nn.Module):
|
||||
|
||||
def prepare_to_sample(self, t_enc, **kwargs):
|
||||
|
||||
structured_conditioning = kwargs.get('structured_conditioning', None)
|
||||
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
|
||||
|
||||
if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning)
|
||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info)
|
||||
else:
|
||||
self.invokeai_diffuser.remove_cross_attention_control()
|
||||
|
||||
@ -164,7 +164,7 @@ class KSampler(Sampler):
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
structured_conditioning=None,
|
||||
extra_conditioning_info=None,
|
||||
threshold = 0,
|
||||
perlin = 0,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
@ -197,7 +197,7 @@ class KSampler(Sampler):
|
||||
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
|
||||
|
||||
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
|
||||
model_wrap_cfg.prepare_to_sample(S, structured_conditioning=structured_conditioning)
|
||||
model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
|
||||
extra_args = {
|
||||
'cond': conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
@ -224,7 +224,7 @@ class KSampler(Sampler):
|
||||
index,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
structured_conditioning=None,
|
||||
extra_conditioning_info=None,
|
||||
**kwargs,
|
||||
):
|
||||
if self.model_wrap is None:
|
||||
@ -250,7 +250,7 @@ class KSampler(Sampler):
|
||||
# so the actual formula for indexing into sigmas:
|
||||
# sigma_index = (steps-index)
|
||||
s_index = t_enc - index - 1
|
||||
self.model_wrap.prepare_to_sample(s_index, structured_conditioning=structured_conditioning)
|
||||
self.model_wrap.prepare_to_sample(s_index, extra_conditioning_info=extra_conditioning_info)
|
||||
img = K.sampling.__dict__[f'_{self.schedule}'](
|
||||
self.model_wrap,
|
||||
img,
|
||||
|
@ -20,10 +20,10 @@ class PLMSSampler(Sampler):
|
||||
def prepare_to_sample(self, t_enc, **kwargs):
|
||||
super().prepare_to_sample(t_enc, **kwargs)
|
||||
|
||||
structured_conditioning = kwargs.get('structured_conditioning', None)
|
||||
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
|
||||
|
||||
if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning)
|
||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info)
|
||||
else:
|
||||
self.invokeai_diffuser.remove_cross_attention_control()
|
||||
|
||||
|
@ -439,6 +439,13 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text, **kwargs):
|
||||
|
||||
should_return_tokens = False
|
||||
if 'return_tokens' in kwargs:
|
||||
should_return_tokens = kwargs.get('return_tokens', False)
|
||||
# self.transformer doesn't like having extra kwargs
|
||||
kwargs.pop('return_tokens')
|
||||
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
@ -451,7 +458,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
tokens = batch_encoding['input_ids'].to(self.device)
|
||||
z = self.transformer(input_ids=tokens, **kwargs)
|
||||
|
||||
if kwargs.get('return_tokens', False):
|
||||
if should_return_tokens:
|
||||
return z, tokens
|
||||
else:
|
||||
return z
|
||||
|
@ -1,6 +1,7 @@
|
||||
import unittest
|
||||
|
||||
from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute
|
||||
from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute, \
|
||||
Fragment
|
||||
|
||||
|
||||
def parse_prompt(prompt_string):
|
||||
@ -135,7 +136,7 @@ class PromptParserTestCase(unittest.TestCase):
|
||||
|
||||
def test_cross_attention_control(self):
|
||||
fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \
|
||||
CrossAttentionControlSubstitute('flames', 'trees')])])
|
||||
CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])])
|
||||
self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)'))
|
||||
self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap(trees)'))
|
||||
self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap(trees)'))
|
||||
@ -144,13 +145,13 @@ class PromptParserTestCase(unittest.TestCase):
|
||||
self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap("trees")'))
|
||||
|
||||
fire_flames_to_trees_and_houses = Conjunction([FlattenedPrompt([('fire', 1.0), \
|
||||
CrossAttentionControlSubstitute('flames', 'trees and houses')])])
|
||||
CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees and houses', 1)])])])
|
||||
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire ("flames").swap("trees and houses")'))
|
||||
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire (flames).swap("trees and houses")'))
|
||||
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire "flames".swap("trees and houses")'))
|
||||
|
||||
trees_and_houses_to_flames = Conjunction([FlattenedPrompt([('fire', 1.0), \
|
||||
CrossAttentionControlSubstitute('trees and houses', 'flames')])])
|
||||
CrossAttentionControlSubstitute([Fragment('trees and houses', 1)], [Fragment('flames',1)])])])
|
||||
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap("flames")'))
|
||||
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap("flames")'))
|
||||
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap("flames")'))
|
||||
@ -159,14 +160,46 @@ class PromptParserTestCase(unittest.TestCase):
|
||||
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap(flames)'))
|
||||
|
||||
flames_to_trees_fire = Conjunction([FlattenedPrompt([
|
||||
CrossAttentionControlSubstitute('flames', 'trees'),
|
||||
CrossAttentionControlSubstitute([Fragment('flames',1)], [Fragment('trees',1)]),
|
||||
(', fire', 1.0)])])
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire'))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire '))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire '))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap("trees"), fire'))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap("trees"), fire'))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap("trees"), fire'))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire'))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire '))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire '))
|
||||
|
||||
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
|
||||
parse_prompt('a forest landscape "".swap("in winter")'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||
CrossAttentionControlSubstitute([Fragment(' ',1)], [Fragment('in winter',1)])])]),
|
||||
parse_prompt('a forest landscape " ".swap("in winter")'))
|
||||
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
|
||||
parse_prompt('a forest landscape "in winter".swap("")'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
|
||||
parse_prompt('a forest landscape "in winter".swap()'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment(' ',1)])])]),
|
||||
parse_prompt('a forest landscape "in winter".swap(" ")'))
|
||||
|
||||
def test_cross_attention_control_with_attention(self):
|
||||
flames_to_trees_fire = Conjunction([FlattenedPrompt([
|
||||
CrossAttentionControlSubstitute([Fragment('flames',0.5)], [Fragment('trees',0.7)]),
|
||||
Fragment(',', 1), Fragment('fire', 2.0)])])
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(flames)".swap("0.7(trees)"), 2.0(fire)'))
|
||||
flames_to_trees_fire = Conjunction([FlattenedPrompt([
|
||||
CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7)]),
|
||||
Fragment(',', 1), Fragment('fire', 2.0)])])
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees)"), 2.0(fire)'))
|
||||
flames_to_trees_fire = Conjunction([FlattenedPrompt([
|
||||
CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7), Fragment('houses', 1)]),
|
||||
Fragment(',', 1), Fragment('fire', 2.0)])])
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
Reference in New Issue
Block a user