fix blend

This commit is contained in:
Damian at mba 2022-10-21 04:15:10 +02:00
parent 4c1267338b
commit 404d59b1b8
3 changed files with 16 additions and 7 deletions

View File

@ -48,7 +48,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
edited_conditioning = None
edit_opcodes = None
if parsed_prompt is Blend:
if type(parsed_prompt) is Blend:
blend: Blend = parsed_prompt
embeddings_to_blend = None
for flattened_prompt in blend.prompts:
@ -60,7 +60,8 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
normalize=blend.normalize_weights)
else:
flattened_prompt: FlattenedPrompt = parsed_prompt
wants_cross_attention_control = any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children])
wants_cross_attention_control = type(flattened_prompt) is not Blend \
and any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children])
if wants_cross_attention_control:
original_prompt = FlattenedPrompt()
edited_prompt = FlattenedPrompt()
@ -95,7 +96,7 @@ def build_token_edit_opcodes(original_tokens, edited_tokens):
def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt):
if type(flattened_prompt) is not FlattenedPrompt:
raise f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead"
raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
fragments = [x.text for x in flattened_prompt.children]
weights = [x.weight for x in flattened_prompt.children]
embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])

View File

@ -308,7 +308,8 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
quotes = pp.Literal('"').suppress()
# accepts int or float notation, always maps to float
number = pp.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float))
number = pp.pyparsing_common.real | \
pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float))
greedy_word = pp.Word(pp.printables, exclude_chars=string.whitespace).set_name('greedy_word')
attention = pp.Forward()
@ -498,12 +499,13 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string)
quoted_prompt.set_name('quoted_prompt')
blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms')
blend_weights = pp.delimited_list(number).set_name('blend_weights')
debug_blend=True
blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms').set_debug(debug_blend)
blend_weights = pp.delimited_list(number).set_name('blend_weights').set_debug(debug_blend)
blend = pp.Group(lparen + pp.Group(blend_terms) + rparen
+ pp.Literal(".blend").suppress()
+ lparen + pp.Group(blend_weights) + rparen).set_name('blend')
blend.set_debug(False)
blend.set_debug(debug_blend)
blend.set_parse_action(lambda x: Blend(prompts=x[0][0], weights=x[0][1]))

View File

@ -153,6 +153,12 @@ class PromptParserTestCase(unittest.TestCase):
parse_prompt("(\"fire\", \" , \").blend(0.7, 1)")
)
self.assertEqual(
Conjunction([Blend([FlattenedPrompt([('mountain, man, hairy', 1)]),
FlattenedPrompt([('face, teeth,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0])]),
parse_prompt('("mountain, man, hairy", "face, teeth, --eyes").blend(1,-1)')
)
def test_nested(self):
self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)]),