mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix blend
This commit is contained in:
parent
4c1267338b
commit
404d59b1b8
@ -48,7 +48,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
|
||||
edited_conditioning = None
|
||||
edit_opcodes = None
|
||||
|
||||
if parsed_prompt is Blend:
|
||||
if type(parsed_prompt) is Blend:
|
||||
blend: Blend = parsed_prompt
|
||||
embeddings_to_blend = None
|
||||
for flattened_prompt in blend.prompts:
|
||||
@ -60,7 +60,8 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
|
||||
normalize=blend.normalize_weights)
|
||||
else:
|
||||
flattened_prompt: FlattenedPrompt = parsed_prompt
|
||||
wants_cross_attention_control = any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children])
|
||||
wants_cross_attention_control = type(flattened_prompt) is not Blend \
|
||||
and any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children])
|
||||
if wants_cross_attention_control:
|
||||
original_prompt = FlattenedPrompt()
|
||||
edited_prompt = FlattenedPrompt()
|
||||
@ -95,7 +96,7 @@ def build_token_edit_opcodes(original_tokens, edited_tokens):
|
||||
|
||||
def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt):
|
||||
if type(flattened_prompt) is not FlattenedPrompt:
|
||||
raise f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead"
|
||||
raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
|
||||
fragments = [x.text for x in flattened_prompt.children]
|
||||
weights = [x.weight for x in flattened_prompt.children]
|
||||
embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
|
||||
|
@ -308,7 +308,8 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
|
||||
quotes = pp.Literal('"').suppress()
|
||||
|
||||
# accepts int or float notation, always maps to float
|
||||
number = pp.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float))
|
||||
number = pp.pyparsing_common.real | \
|
||||
pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float))
|
||||
greedy_word = pp.Word(pp.printables, exclude_chars=string.whitespace).set_name('greedy_word')
|
||||
|
||||
attention = pp.Forward()
|
||||
@ -498,12 +499,13 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
|
||||
quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string)
|
||||
quoted_prompt.set_name('quoted_prompt')
|
||||
|
||||
blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms')
|
||||
blend_weights = pp.delimited_list(number).set_name('blend_weights')
|
||||
debug_blend=True
|
||||
blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms').set_debug(debug_blend)
|
||||
blend_weights = pp.delimited_list(number).set_name('blend_weights').set_debug(debug_blend)
|
||||
blend = pp.Group(lparen + pp.Group(blend_terms) + rparen
|
||||
+ pp.Literal(".blend").suppress()
|
||||
+ lparen + pp.Group(blend_weights) + rparen).set_name('blend')
|
||||
blend.set_debug(False)
|
||||
blend.set_debug(debug_blend)
|
||||
|
||||
|
||||
blend.set_parse_action(lambda x: Blend(prompts=x[0][0], weights=x[0][1]))
|
||||
|
@ -153,6 +153,12 @@ class PromptParserTestCase(unittest.TestCase):
|
||||
parse_prompt("(\"fire\", \" , \").blend(0.7, 1)")
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
Conjunction([Blend([FlattenedPrompt([('mountain, man, hairy', 1)]),
|
||||
FlattenedPrompt([('face, teeth,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0])]),
|
||||
parse_prompt('("mountain, man, hairy", "face, teeth, --eyes").blend(1,-1)')
|
||||
)
|
||||
|
||||
|
||||
def test_nested(self):
|
||||
self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)]),
|
||||
|
Loading…
Reference in New Issue
Block a user