non-normalized blend

This commit is contained in:
Damian at mba 2022-10-21 04:34:53 +02:00
parent d965540103
commit b385fdd7de

View File

@ -187,7 +187,7 @@ class Blend():
self.normalize_weights = normalize_weights
def __repr__(self):
return f"Blend:{self.prompts} | weights {self.weights}"
return f"Blend:{self.prompts} | weights {' ' if self.normalize_weights else '(non-normalized) '}{self.weights}"
def __eq__(self, other):
return other.__repr__() == self.__repr__()
@ -276,7 +276,7 @@ class PromptParser():
for prompt in node.prompts:
# prompt is a list
flattened_subprompts = flatten_internal(prompt, weight_scale, flattened_subprompts, prefix+'B ')
results += [Blend(prompts=flattened_subprompts, weights=node.weights)]
results += [Blend(prompts=flattened_subprompts, weights=node.weights, normalize_weights=node.normalize_weights)]
elif type(node) is Prompt:
#print(prefix + "about to flatten Prompt with children", node.children)
flattened_prompt = []
@ -501,14 +501,22 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
debug_blend=True
blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms').set_debug(debug_blend)
blend_weights = pp.delimited_list(number).set_name('blend_weights').set_debug(debug_blend)
blend_weights = (pp.delimited_list(number) + pp.Optional(pp.Char(",").suppress() + "no_normalize")).set_name('blend_weights').set_debug(debug_blend)
blend = pp.Group(lparen + pp.Group(blend_terms) + rparen
+ pp.Literal(".blend").suppress()
+ lparen + pp.Group(blend_weights) + rparen).set_name('blend')
blend.set_debug(debug_blend)
def make_blend(x):
prompts = x[0][0]
weights = x[0][1]
normalize = True
if weights[-1] == 'no_normalize':
normalize = False
weights = weights[:-1]
return Blend(prompts=prompts, weights=weights, normalize_weights=normalize)
blend.set_parse_action(lambda x: Blend(prompts=x[0][0], weights=x[0][1]))
blend.set_parse_action(make_blend)
conjunction_terms = blend_terms.copy().set_name('conjunction_terms')
conjunction_weights = blend_weights.copy().set_name('conjunction_weights')