mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

bring in prompt parser from fix-prompts branch

Attention is parsed but ignored, blends using the old syntax don't work, and conjunctions are parsed but ignored. The only parts exercised here are the new .blend() syntax and cross-attention control using .swap().

parent 42883545f9
commit c9d27634b4
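
For orientation, the .swap() syntax this commit wires up looks like the following; the expected result is taken verbatim from the test cases at the end of this diff (.blend() is mentioned in the message but not exercised in these hunks):

    pp = PromptParser()
    pp.parse('fire "flames".swap(trees)')
    # -> Conjunction([FlattenedPrompt([('fire', 1.0),
    #        CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])])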
@@ -4,7 +4,7 @@ weighted subprompts.

 Useful function exports:

-get_uc_and_c()                get the conditioned and unconditioned latent
+get_uc_and_c_and_ec()         get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control
 split_weighted_subpromopts()  split subprompts, normalize and weight them
 log_tokenization()            print out colour-coded tokens and warn if truncated

@@ -39,8 +39,10 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n

     pp = PromptParser()

-    parsed_prompt: Union[FlattenedPrompt, Blend] = pp.parse(prompt_string_cleaned)
-    parsed_negative_prompt: FlattenedPrompt = pp.parse(unconditioned_words)
+    # we don't support conjunctions for now
+    parsed_prompt: Union[FlattenedPrompt, Blend] = pp.parse(prompt_string_cleaned).prompts[0]
+    parsed_negative_prompt: FlattenedPrompt = pp.parse(unconditioned_words).prompts[0]
+    print("parsed prompt to", parsed_prompt)

     conditioning = None
     edited_conditioning = None
@@ -50,7 +52,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
         blend: Blend = parsed_prompt
         embeddings_to_blend = None
         for flattened_prompt in blend.prompts:
-            this_embedding = make_embeddings_for_flattened_prompt(model, flattened_prompt)
+            this_embedding = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt)
             embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
                 (embeddings_to_blend, this_embedding))
         conditioning, _ = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
@@ -72,16 +74,16 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
                 # regular fragment
                 original_prompt.append(fragment)
                 edited_prompt.append(fragment)
-        original_embeddings, original_tokens = make_embeddings_for_flattened_prompt(model, original_prompt)
-        edited_embeddings, edited_tokens = make_embeddings_for_flattened_prompt(model, edited_prompt)
+        original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt)
+        edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt)

         conditioning = original_embeddings
         edited_conditioning = edited_embeddings
         edit_opcodes = build_token_edit_opcodes(original_tokens, edited_tokens)
     else:
-        conditioning, _ = make_embeddings_for_flattened_prompt(model, flattened_prompt)
+        conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt)

-    unconditioning = make_embeddings_for_flattened_prompt(parsed_negative_prompt)
+    unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt)
     return (unconditioning, conditioning, edited_conditioning, edit_opcodes)

@@ -91,11 +93,11 @@ def build_token_edit_opcodes(original_tokens, edited_tokens):

     return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()

-def make_embeddings_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt):
+def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt):
     if type(flattened_prompt) is not FlattenedPrompt:
         raise f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead"
-    fragments = [x[0] for x in flattened_prompt.children]
-    embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True)
+    fragments = [x.text for x in flattened_prompt.children]
+    embeddings, tokens = model.get_learned_conditioning([' '.join(fragments)], return_tokens=True)
     return embeddings, tokens

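
build_token_edit_opcodes is a thin wrapper over difflib's SequenceMatcher; the opcodes mark which token spans are kept ('equal') versus substituted ('replace'), which is exactly what cross-attention control needs to know. A runnable sketch with hypothetical token ids:

    from difflib import SequenceMatcher

    original_tokens = [320, 1042, 539, 9850]   # hypothetical CLIP token ids
    edited_tokens   = [320, 1042, 863, 9850]
    print(SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes())
    # [('equal', 0, 2, 0, 2), ('replace', 2, 3, 2, 3), ('equal', 3, 4, 3, 4)]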
@@ -34,7 +34,7 @@ class Img2Img(Generator):

         t_enc = int(strength * steps)
         uc, c, ec, edit_opcodes = conditioning
-        structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
+        extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)

         def make_image(x_T):
             # encode (scaled latent)
@@ -52,7 +52,7 @@ class Img2Img(Generator):
                 unconditional_guidance_scale=cfg_scale,
                 unconditional_conditioning=uc,
                 init_latent = self.init_latent,
-                structured_conditioning = structured_conditioning
+                extra_conditioning_info = extra_conditioning_info
                 # changes how noising is performed in ksampler
             )

@@ -22,7 +22,7 @@ class Txt2Img(Generator):
         """
         self.perlin = perlin
         uc, c, ec, edit_opcodes = conditioning
-        structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
+        extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)

         @torch.no_grad()
         def make_image(x_T):
@@ -46,7 +46,7 @@ class Txt2Img(Generator):
                 verbose = False,
                 unconditional_guidance_scale = cfg_scale,
                 unconditional_conditioning = uc,
-                structured_conditioning = structured_conditioning,
+                extra_conditioning_info = extra_conditioning_info,
                 eta = ddim_eta,
                 img_callback = step_callback,
                 threshold = threshold,
@@ -24,7 +24,7 @@ class Txt2Img2Img(Generator):
         kwargs are 'width' and 'height'
         """
         uc, c, ec, edit_opcodes = conditioning
-        structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
+        extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)

         @torch.no_grad()
         def make_image(x_T):
@@ -63,7 +63,7 @@ class Txt2Img2Img(Generator):
                 unconditional_conditioning = uc,
                 eta = ddim_eta,
                 img_callback = step_callback,
-                structured_conditioning = structured_conditioning
+                extra_conditioning_info = extra_conditioning_info
             )

             print(
@@ -97,7 +97,7 @@ class Txt2Img2Img(Generator):
                 img_callback = step_callback,
                 unconditional_guidance_scale=cfg_scale,
                 unconditional_conditioning=uc,
-                structured_conditioning = structured_conditioning
+                extra_conditioning_info = extra_conditioning_info
             )

             if self.free_gpu_mem:
@@ -7,7 +7,7 @@ class Prompt():

     def __init__(self, parts: list):
         for c in parts:
-            if not issubclass(type(c), BaseFragment):
+            if type(c) is not Attention and not issubclass(type(c), BaseFragment):
                 raise PromptParser.ParsingException(f"Prompt cannot contain {type(c)}, only {BaseFragment.__subclasses__()} are allowed")
         self.children = parts
     def __repr__(self):
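
With the loosened check, Attention nodes (added just below) may now sit directly in a Prompt's parts list alongside BaseFragment subclasses; a minimal sketch:

    # previously raised PromptParser.ParsingException; allowed after this change:
    p = Prompt([Fragment('fire', 1), Attention(0.5, [Fragment('flames', 1)])])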
@@ -56,6 +56,17 @@ class Fragment(BaseFragment):
             and other.text == self.text \
             and other.weight == self.weight

+class Attention():
+    def __init__(self, weight: float, children: list):
+        self.weight = weight
+        self.children = children
+        #print(f"A: requested attention '{children}' to {weight}")
+
+    def __repr__(self):
+        return f"Attention:'{self.children}' @ {self.weight}"
+    def __eq__(self, other):
+        return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment
+
 class CrossAttentionControlledFragment(BaseFragment):
     pass

@@ -65,7 +76,7 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
         self.edited = edited

     def __repr__(self):
-        return f"CrossAttentionControlSubstitute:('{self.original}'->'{self.edited}')"
+        return f"CrossAttentionControlSubstitute:({self.original}->{self.edited})"
     def __eq__(self, other):
         return type(other) is CrossAttentionControlSubstitute \
             and other.original == self.original \
@@ -137,7 +148,7 @@ class PromptParser():
         self.root = self.build_parser_logic()


-    def parse(self, prompt: str) -> [list]:
+    def parse(self, prompt: str) -> Conjunction:
         '''
         :param prompt: The prompt string to parse
         :return: a tuple
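
parse() now always returns a Conjunction, even though conjunctions themselves are still ignored; callers unwrap the first prompt, as get_uc_and_c_and_ec does above. A minimal sketch:

    conjunction = PromptParser().parse('fire "flames".swap(trees)')
    parsed_prompt = conjunction.prompts[0]   # a FlattenedPrompt or Blend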
@@ -181,13 +192,17 @@ class PromptParser():
                 for x in node:
                     results = flatten_internal(x, weight_scale, results, prefix+'pr')
                 #print(prefix, " ParseResults expanded, results is now", results)
-            elif issubclass(type(node), BaseFragment):
-                results.append(node)
-            #elif type(node) is Attention:
-            #    #if node.weight < 1:
-            #    #    todo: inject a blend when flattening attention with weight <1"
-            #    for c in node.children:
-            #        results = flatten_internal(c, weight_scale*node.weight, results, prefix+' ')
+            elif type(node) is Attention:
+                # if node.weight < 1:
+                #    todo: inject a blend when flattening attention with weight <1"
+                for c in node.children:
+                    results = flatten_internal(c, weight_scale * node.weight, results, prefix + ' ')
+            elif type(node) is Fragment:
+                results += [Fragment(node.text, node.weight*weight_scale)]
+            elif type(node) is CrossAttentionControlSubstitute:
+                original = flatten_internal(node.original, weight_scale, [], ' CAo ')
+                edited = flatten_internal(node.edited, weight_scale, [], ' CAe ')
+                results += [CrossAttentionControlSubstitute(original, edited)]
             elif type(node) is Blend:
                 flattened_subprompts = []
                 #print(" flattening blend with prompts", node.prompts, "weights", node.weights)
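
Note that the Attention branch recurses with weight_scale * node.weight, so nested attention compounds multiplicatively; mirroring the test_cross_attention_control_with_attention cases at the end of this diff:

    # '0.5(fire 0.5(flames))' flattens to fire@0.5 and flames@0.25:
    inner = Attention(0.5, [Fragment('flames', 1)])
    outer = Attention(0.5, [Fragment('fire', 1), inner])
    # flattening outer yields [Fragment('fire', 0.5), Fragment('flames', 0.25)]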
@@ -204,7 +219,7 @@ class PromptParser():
                 #print(prefix + "after flattening Prompt, results is", results)
             else:
                 raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
-            #print(prefix + "-> after flattening", type(node), "results is", results)
+            print(prefix + "-> after flattening", type(node), "results is", results)
             return results

         #print("flattening", root)
@@ -239,32 +254,83 @@ class PromptParser():
             else:
                 raise PromptParser.ParsingException("Cannot make fragment from " + str(x))

+        # attention control of the form +(phrase) / -(phrase) / <weight>(phrase)
+        # phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight
+        attention = pp.Forward()
+        attention_head = (number | pp.Word('+') | pp.Word('-'))\
+            .set_name("attention_head")\
+            .set_debug(False)
+        fragment_inside_attention = pp.CharsNotIn(SPACE_CHARS+'()')\
+            .set_parse_action(make_fragment)\
+            .set_name("fragment_inside_attention")\
+            .set_debug(False)
+        attention_with_parens = pp.Forward()
+        attention_with_parens_body = pp.nested_expr(content=pp.delimited_list((attention_with_parens | fragment_inside_attention), delim=SPACE_CHARS))
+        attention_with_parens << (attention_head + attention_with_parens_body)
+
+        def make_attention(x):
+            # print("making Attention from parsing with args", x0, x1)
+            weight = 1
+            # number(str)
+            if type(x[0]) is float or type(x[0]) is int:
+                weight = float(x[0])
+            # +(str) or -(str) or +str or -str
+            elif type(x[0]) is str:
+                base = self.attention_plus_base if x[0][0] == '+' else self.attention_minus_base
+                weight = pow(base, len(x[0]))
+            # print("Making attention with children of type", [str(type(x)) for x in x1])
+            return Attention(weight=weight, children=x[1])
+
+        attention_with_parens.set_parse_action(make_attention)\
+            .set_name("attention_with_parens")\
+            .set_debug(False)
+
+        # attention control of the form ++word --word (no parens)
+        attention_without_parens = (
+            (pp.Word('+') | pp.Word('-')) +
+            pp.CharsNotIn(SPACE_CHARS+'()').set_parse_action(lambda x: [[make_fragment(x)]])
+        )\
+            .set_name("attention_without_parens")\
+            .set_debug(False)
+        attention_without_parens.set_parse_action(make_attention)
+
+        attention << (attention_with_parens | attention_without_parens)\
+            .set_name("attention")\
+            .set_debug(False)
+
+        # cross-attention control
+        empty_string = ((lparen + rparen) |
+                        pp.Literal('""').suppress() |
+                        (lparen + pp.Literal('""').suppress() + rparen)
+                        ).set_parse_action(lambda x: Fragment(""))
+
         original_words = (
-            (lparen + pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress() + rparen).set_name('term1').set_debug(False) |
-            (pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress()).set_name('term2').set_debug(False) |
-            (lparen + pp.CharsNotIn(')') + rparen).set_name('term3').set_debug(False)
+            (lparen + pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress() + rparen).set_name('term1').set_debug(False) |
+            (pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress()).set_name('term2').set_debug(False) |
+            (lparen + (pp.OneOrMore(attention) | pp.CharsNotIn(')').set_parse_action(make_fragment)) + rparen).set_name('term3').set_debug(False)
         ).set_name('original_words')
         edited_words = (
-            (pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress()).set_name('termA').set_debug(False) |
-            pp.CharsNotIn(')').set_name('termB').set_debug(False)
+            (pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress()).set_name('termA').set_debug(False) |
+            pp.Literal('""').suppress().set_parse_action(lambda x: Fragment("")) |
+            (pp.OneOrMore(attention) | pp.CharsNotIn(')').set_parse_action(make_fragment)).set_name('termB').set_debug(True)
         ).set_name('edited_words')
-        cross_attention_substitute = original_words + \
+        cross_attention_substitute = (empty_string | original_words) + \
             pp.Literal(".swap").suppress() + \
-            lparen + edited_words + rparen
+            (empty_string | (lparen + edited_words + rparen)
+            )
         cross_attention_substitute.set_name('cross_attention_substitute')

         def make_cross_attention_substitute(x):
             #print("making cacs for", x)
-            return CrossAttentionControlSubstitute(x[0], x[1])
+            cacs = CrossAttentionControlSubstitute(x[0], x[1])
             #print("made", cacs)
-            #return cacs
+            return cacs

         cross_attention_substitute.set_parse_action(make_cross_attention_substitute)

         # simple fragments of text
         prompt_part << (cross_attention_substitute
-            #| attention
+            | attention
             | word
         )
         prompt_part.set_debug(False)
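
make_attention maps a run of '+' or '-' signs to base**count, and a bare number is taken as the weight directly. The bases come from self.attention_plus_base and self.attention_minus_base, which are set elsewhere in the class; 1.1 and 0.9 below are illustrative assumptions only:

    print(pow(1.1, len('+++')))   # 1.331 -> '+++(phrase)' boosts roughly 33% (hypothetical plus base)
    print(pow(0.9, len('--')))    # 0.81  -> '--phrase' dims roughly 19% (hypothetical minus base)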
@@ -16,10 +16,10 @@ class DDIMSampler(Sampler):
     def prepare_to_sample(self, t_enc, **kwargs):
         super().prepare_to_sample(t_enc, **kwargs)

-        structured_conditioning = kwargs.get('structured_conditioning', None)
+        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)

-        if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control:
-            self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning)
+        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
+            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info)
         else:
             self.invokeai_diffuser.remove_cross_attention_control()

@@ -34,10 +34,10 @@ class CFGDenoiser(nn.Module):

     def prepare_to_sample(self, t_enc, **kwargs):

-        structured_conditioning = kwargs.get('structured_conditioning', None)
+        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)

-        if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control:
-            self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning)
+        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
+            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info)
         else:
             self.invokeai_diffuser.remove_cross_attention_control()

@@ -164,7 +164,7 @@ class KSampler(Sampler):
         log_every_t=100,
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
-        structured_conditioning=None,
+        extra_conditioning_info=None,
         threshold = 0,
         perlin = 0,
         # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
@@ -197,7 +197,7 @@ class KSampler(Sampler):
             x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]

         model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
-        model_wrap_cfg.prepare_to_sample(S, structured_conditioning=structured_conditioning)
+        model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
         extra_args = {
             'cond': conditioning,
             'uncond': unconditional_conditioning,
@@ -224,7 +224,7 @@ class KSampler(Sampler):
         index,
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
-        structured_conditioning=None,
+        extra_conditioning_info=None,
         **kwargs,
     ):
         if self.model_wrap is None:
@@ -250,7 +250,7 @@ class KSampler(Sampler):
         # so the actual formula for indexing into sigmas:
         # sigma_index = (steps-index)
         s_index = t_enc - index - 1
-        self.model_wrap.prepare_to_sample(s_index, structured_conditioning=structured_conditioning)
+        self.model_wrap.prepare_to_sample(s_index, extra_conditioning_info=extra_conditioning_info)
         img = K.sampling.__dict__[f'_{self.schedule}'](
             self.model_wrap,
             img,
@@ -20,10 +20,10 @@ class PLMSSampler(Sampler):
     def prepare_to_sample(self, t_enc, **kwargs):
         super().prepare_to_sample(t_enc, **kwargs)

-        structured_conditioning = kwargs.get('structured_conditioning', None)
+        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)

-        if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control:
-            self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning)
+        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
+            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info)
         else:
             self.invokeai_diffuser.remove_cross_attention_control()

@@ -439,6 +439,13 @@ class FrozenCLIPEmbedder(AbstractEncoder):
                 param.requires_grad = False

     def forward(self, text, **kwargs):
+
+        should_return_tokens = False
+        if 'return_tokens' in kwargs:
+            should_return_tokens = kwargs.get('return_tokens', False)
+            # self.transformer doesn't like having extra kwargs
+            kwargs.pop('return_tokens')
+
         batch_encoding = self.tokenizer(
             text,
             truncation=True,
@@ -451,7 +458,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         tokens = batch_encoding['input_ids'].to(self.device)
         z = self.transformer(input_ids=tokens, **kwargs)

-        if kwargs.get('return_tokens', False):
+        if should_return_tokens:
             return z, tokens
         else:
             return z
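
The forward() change consumes return_tokens before delegating, because self.transformer rejects unknown keyword arguments. dict.pop with a default collapses the check/get/pop sequence into one step; a generic sketch of the same pattern (hypothetical helper, not the actual class):

    def call_transformer(transformer, tokens, **kwargs):
        should_return_tokens = kwargs.pop('return_tokens', False)  # consume before delegating
        z = transformer(input_ids=tokens, **kwargs)                # transformer never sees it
        return (z, tokens) if should_return_tokens else z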
@@ -1,6 +1,7 @@
 import unittest

-from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute
+from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute, \
+    Fragment


 def parse_prompt(prompt_string):
@@ -135,7 +136,7 @@ class PromptParserTestCase(unittest.TestCase):

     def test_cross_attention_control(self):
         fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \
-                                                             CrossAttentionControlSubstitute('flames', 'trees')])])
+                                                             CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])])
         self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)'))
         self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap(trees)'))
         self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap(trees)'))
@@ -144,13 +145,13 @@ class PromptParserTestCase(unittest.TestCase):
         self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap("trees")'))

         fire_flames_to_trees_and_houses = Conjunction([FlattenedPrompt([('fire', 1.0), \
-                                                                        CrossAttentionControlSubstitute('flames', 'trees and houses')])])
+                                                                        CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees and houses', 1)])])])
         self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire ("flames").swap("trees and houses")'))
         self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire (flames).swap("trees and houses")'))
         self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire "flames".swap("trees and houses")'))

         trees_and_houses_to_flames = Conjunction([FlattenedPrompt([('fire', 1.0), \
-                                                                   CrossAttentionControlSubstitute('trees and houses', 'flames')])])
+                                                                   CrossAttentionControlSubstitute([Fragment('trees and houses', 1)], [Fragment('flames',1)])])])
         self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap("flames")'))
         self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap("flames")'))
         self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap("flames")'))
@@ -159,14 +160,46 @@ class PromptParserTestCase(unittest.TestCase):
         self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap(flames)'))

         flames_to_trees_fire = Conjunction([FlattenedPrompt([
-            CrossAttentionControlSubstitute('flames', 'trees'),
+            CrossAttentionControlSubstitute([Fragment('flames',1)], [Fragment('trees',1)]),
             (', fire', 1.0)])])
-        self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire'))
-        self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire '))
-        self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire '))
         self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap("trees"), fire'))
         self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap("trees"), fire'))
         self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap("trees"), fire'))
+        self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire'))
+        self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire '))
+        self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire '))
+
+
+        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
+                                                       CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
+                         parse_prompt('a forest landscape "".swap("in winter")'))
+        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
+                                                       CrossAttentionControlSubstitute([Fragment(' ',1)], [Fragment('in winter',1)])])]),
+                         parse_prompt('a forest landscape " ".swap("in winter")'))
+
+        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
+                                                       CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
+                         parse_prompt('a forest landscape "in winter".swap("")'))
+        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
+                                                       CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
+                         parse_prompt('a forest landscape "in winter".swap()'))
+        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
+                                                       CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment(' ',1)])])]),
+                         parse_prompt('a forest landscape "in winter".swap(" ")'))
+
+    def test_cross_attention_control_with_attention(self):
+        flames_to_trees_fire = Conjunction([FlattenedPrompt([
+            CrossAttentionControlSubstitute([Fragment('flames',0.5)], [Fragment('trees',0.7)]),
+            Fragment(',', 1), Fragment('fire', 2.0)])])
+        self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(flames)".swap("0.7(trees)"), 2.0(fire)'))
+        flames_to_trees_fire = Conjunction([FlattenedPrompt([
+            CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7)]),
+            Fragment(',', 1), Fragment('fire', 2.0)])])
+        self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees)"), 2.0(fire)'))
+        flames_to_trees_fire = Conjunction([FlattenedPrompt([
+            CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7), Fragment('houses', 1)]),
+            Fragment(',', 1), Fragment('fire', 2.0)])])
+        self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)'))


 if __name__ == '__main__':