mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix blend
This commit is contained in:
parent
4c1267338b
commit
404d59b1b8
@ -48,7 +48,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
|
|||||||
edited_conditioning = None
|
edited_conditioning = None
|
||||||
edit_opcodes = None
|
edit_opcodes = None
|
||||||
|
|
||||||
if parsed_prompt is Blend:
|
if type(parsed_prompt) is Blend:
|
||||||
blend: Blend = parsed_prompt
|
blend: Blend = parsed_prompt
|
||||||
embeddings_to_blend = None
|
embeddings_to_blend = None
|
||||||
for flattened_prompt in blend.prompts:
|
for flattened_prompt in blend.prompts:
|
||||||
@ -60,7 +60,8 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
|
|||||||
normalize=blend.normalize_weights)
|
normalize=blend.normalize_weights)
|
||||||
else:
|
else:
|
||||||
flattened_prompt: FlattenedPrompt = parsed_prompt
|
flattened_prompt: FlattenedPrompt = parsed_prompt
|
||||||
wants_cross_attention_control = any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children])
|
wants_cross_attention_control = type(flattened_prompt) is not Blend \
|
||||||
|
and any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children])
|
||||||
if wants_cross_attention_control:
|
if wants_cross_attention_control:
|
||||||
original_prompt = FlattenedPrompt()
|
original_prompt = FlattenedPrompt()
|
||||||
edited_prompt = FlattenedPrompt()
|
edited_prompt = FlattenedPrompt()
|
||||||
@ -95,7 +96,7 @@ def build_token_edit_opcodes(original_tokens, edited_tokens):
|
|||||||
|
|
||||||
def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt):
|
def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt):
|
||||||
if type(flattened_prompt) is not FlattenedPrompt:
|
if type(flattened_prompt) is not FlattenedPrompt:
|
||||||
raise f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead"
|
raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
|
||||||
fragments = [x.text for x in flattened_prompt.children]
|
fragments = [x.text for x in flattened_prompt.children]
|
||||||
weights = [x.weight for x in flattened_prompt.children]
|
weights = [x.weight for x in flattened_prompt.children]
|
||||||
embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
|
embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
|
||||||
|
@ -308,7 +308,8 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
|
|||||||
quotes = pp.Literal('"').suppress()
|
quotes = pp.Literal('"').suppress()
|
||||||
|
|
||||||
# accepts int or float notation, always maps to float
|
# accepts int or float notation, always maps to float
|
||||||
number = pp.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float))
|
number = pp.pyparsing_common.real | \
|
||||||
|
pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float))
|
||||||
greedy_word = pp.Word(pp.printables, exclude_chars=string.whitespace).set_name('greedy_word')
|
greedy_word = pp.Word(pp.printables, exclude_chars=string.whitespace).set_name('greedy_word')
|
||||||
|
|
||||||
attention = pp.Forward()
|
attention = pp.Forward()
|
||||||
@ -498,12 +499,13 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
|
|||||||
quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string)
|
quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string)
|
||||||
quoted_prompt.set_name('quoted_prompt')
|
quoted_prompt.set_name('quoted_prompt')
|
||||||
|
|
||||||
blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms')
|
debug_blend=True
|
||||||
blend_weights = pp.delimited_list(number).set_name('blend_weights')
|
blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms').set_debug(debug_blend)
|
||||||
|
blend_weights = pp.delimited_list(number).set_name('blend_weights').set_debug(debug_blend)
|
||||||
blend = pp.Group(lparen + pp.Group(blend_terms) + rparen
|
blend = pp.Group(lparen + pp.Group(blend_terms) + rparen
|
||||||
+ pp.Literal(".blend").suppress()
|
+ pp.Literal(".blend").suppress()
|
||||||
+ lparen + pp.Group(blend_weights) + rparen).set_name('blend')
|
+ lparen + pp.Group(blend_weights) + rparen).set_name('blend')
|
||||||
blend.set_debug(False)
|
blend.set_debug(debug_blend)
|
||||||
|
|
||||||
|
|
||||||
blend.set_parse_action(lambda x: Blend(prompts=x[0][0], weights=x[0][1]))
|
blend.set_parse_action(lambda x: Blend(prompts=x[0][0], weights=x[0][1]))
|
||||||
|
@ -153,6 +153,12 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
parse_prompt("(\"fire\", \" , \").blend(0.7, 1)")
|
parse_prompt("(\"fire\", \" , \").blend(0.7, 1)")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
Conjunction([Blend([FlattenedPrompt([('mountain, man, hairy', 1)]),
|
||||||
|
FlattenedPrompt([('face, teeth,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0])]),
|
||||||
|
parse_prompt('("mountain, man, hairy", "face, teeth, --eyes").blend(1,-1)')
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_nested(self):
|
def test_nested(self):
|
||||||
self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)]),
|
self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)]),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user