2022-10-20 13:56:46 +00:00
import string
2022-10-24 09:16:52 +00:00
from typing import Union , Optional
import re
2022-10-19 23:42:04 +00:00
import pyparsing as pp
2022-10-30 22:01:05 +00:00
'''
This module parses prompt strings and produces tree-like structures that can be used to generate and control
the conditioning tensors, including weighted subprompts.
Useful class exports :
PromptParser - parses prompts
Useful function exports :
split_weighted_subprompts() split subprompts, normalize and weight them
log_tokenization ( ) print out colour - coded tokens and warn if truncated
'''
2022-10-19 23:42:04 +00:00
class Prompt():
    """
    Mid-level container for the tree-like result of parsing (part of) a prompt string.

    A Prompt does not have to cover the whole user-supplied prompt string: when the user asks for a Blend,
    for instance, each of the things being blended is held in its own Prompt.
    Because children may be nested, a Prompt is not suitable for direct tokenization; pass the containing
    Conjunction to PromptParser.flatten() to obtain FlattenedPrompt objects instead.
    """

    def __init__(self, parts: list):
        """
        :param parts: the children of this Prompt. Each must be an Attention, an instance of a BaseFragment
            subclass, or a pyparsing ParseResults.
        :raises PromptParser.ParsingException: if any element of `parts` has another type.
        """
        for c in parts:
            acceptable = (
                type(c) is Attention
                or issubclass(type(c), BaseFragment)
                or type(c) is pp.ParseResults
            )
            if not acceptable:
                raise PromptParser.ParsingException(f" Prompt cannot contain {type(c).__name__} ( {c} ), only {[c.__name__ for c in BaseFragment.__subclasses__()]} are allowed ")
        # the list is stored as-is (not copied)
        self.children = parts

    def __repr__(self):
        return f" Prompt: {self.children} "

    def __eq__(self, other):
        # only another Prompt with equal children compares equal
        if type(other) is not Prompt:
            return False
        return other.children == self.children
2022-10-20 19:41:32 +00:00
class BaseFragment:
    """Marker base class: Prompt and FlattenedPrompt accept instances of its subclasses as children."""
2022-10-19 23:42:04 +00:00
class FlattenedPrompt():
    """
    A Prompt that has been passed through flatten(). Its children can be readily tokenized.

    `children` is a flat list of BaseFragment-derived objects; (str, number) tuples handed to the
    constructor or to append() are upgraded to Fragment objects.
    """

    def __init__(self, parts: Optional[list] = None):
        """
        :param parts: optional list of items to append (see append() for the accepted kinds).
            None, the default, creates an empty FlattenedPrompt.
        :raises PromptParser.ParsingException: if an item cannot be stored.
        """
        self.children = []
        # a None default avoids sharing one mutable list object between calls
        for part in (parts if parts is not None else []):
            self.append(part)

    def append(self, fragment: Union[list, "BaseFragment", tuple]):
        """
        Add `fragment` to the end of this prompt.

        :param fragment: one of
            - a list, whose elements are appended recursively, in order;
            - an instance of a BaseFragment subclass, stored as-is;
            - a (str, float|int) tuple, converted to a Fragment(text, weight).
        :raises PromptParser.ParsingException: for anything else, including malformed tuples.
        """
        if type(fragment) is list:
            for x in fragment:
                self.append(x)
        elif issubclass(type(fragment), BaseFragment):
            self.children.append(fragment)
        elif type(fragment) is tuple:
            # upgrade tuples to Fragments; too-short tuples are reported as a parsing error
            # rather than leaking an IndexError
            if len(fragment) < 2 \
                    or type(fragment[0]) is not str \
                    or type(fragment[1]) not in (float, int):
                raise PromptParser.ParsingException(
                    f" FlattenedPrompt cannot contain {fragment} , only Fragments or (str, float) tuples are allowed ")
            self.children.append(Fragment(fragment[0], fragment[1]))
        else:
            raise PromptParser.ParsingException(
                f" FlattenedPrompt cannot contain {fragment} , only Fragments or (str, float) tuples are allowed ")

    @property
    def is_empty(self):
        """
        True if there are no children, or if the only child carries an empty `text` string.

        A single child without a `text` attribute (e.g. a CrossAttentionControlSubstitute) counts as
        non-empty instead of raising AttributeError.
        """
        if len(self.children) == 0:
            return True
        if len(self.children) != 1:
            return False
        only_text = getattr(self.children[0], 'text', None)
        return isinstance(only_text, str) and len(only_text) == 0

    def __repr__(self):
        return f" FlattenedPrompt: {self.children} "

    def __eq__(self, other):
        return type(other) is FlattenedPrompt and other.children == self.children
2022-10-21 01:29:50 +00:00
2022-10-19 23:42:04 +00:00
class Fragment(BaseFragment):
    """
    A chunk of plain text together with a weight (1 unless given). The text is meant to be handed
    as-is to the CLIP tokenizer.
    """

    def __init__(self, text: str, weight: float = 1):
        """
        :param text: the fragment text; must be exactly a `str` (checked with an assert).
        :param weight: numeric weight, stored as a float.
        """
        assert (type(text) is str)
        # un-escape the backslash-escaped forms of ( ) and " ; replace() is a no-op when they are absent
        for escaped, plain in ((' \\ ( ', ' ( '), (' \\ ) ', ' ) '), (' \\ " ', ' " ')):
            text = text.replace(escaped, plain)
        self.text = text
        self.weight = float(weight)

    def __repr__(self):
        return f" Fragment: ' {self.text} ' @ {self.weight}"

    def __eq__(self, other):
        if type(other) is not Fragment:
            return False
        return other.text == self.text and other.weight == self.weight
2022-10-20 10:01:48 +00:00
class Attention():
    """
    Nestable weight control for fragments. Each object in the children list may in turn be an Attention
    object; weights should be considered to accumulate as the tree is traversed to deeper levels of nesting.

    Do not traverse directly; instead obtain FlattenedPrompts by passing the top-level Conjunction to
    PromptParser.flatten().
    """

    def __init__(self, weight: float, children: list):
        """
        :param weight: the weight applied to all children; must be exactly a `float`.
        :param children: the list of objects the weight applies to; must be exactly a `list`.
        :raises PromptParser.ParsingException: if `weight` is not a float or `children` is not a list.
        """
        if type(weight) is not float:
            raise PromptParser.ParsingException(
                f" Attention weight must be float (got {type(weight).__name__} {weight} ) ")
        self.weight = weight
        if type(children) is not list:
            raise PromptParser.ParsingException(f" cannot make Attention with non-list of children (got {type(children)} ) ")
        self.children = children

    def __repr__(self):
        return f" Attention: {self.children} * {self.weight} "

    def __eq__(self, other):
        # compare the attributes this class actually has (weight + children); the previous
        # implementation compared a non-existent `fragment` attribute and raised AttributeError
        return type(other) is Attention \
            and other.weight == self.weight \
            and other.children == self.children
2022-10-19 23:42:04 +00:00
class CrossAttentionControlledFragment(BaseFragment):
    """Common base class of the cross-attention control fragment types defined in this module."""
class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
    """
    A Cross-Attention Controlled ('prompt2prompt') fragment, for use inside a Prompt, Attention, or
    FlattenedPrompt.

    It represents an "original" word sequence that supplies feature vectors for an initial diffusion
    operation, and an "edited" word sequence, to which the attention maps produced by the "original" word
    sequence are applied. Intuitively, the result should be an "edited" image that looks like the
    "original" image with concepts swapped.

    eg "a cat sitting on a car" (original) -> "a smiling dog sitting on a car" (edited): the edited image
    should look almost exactly the same as the original, but with a smiling dog rendered in place of the
    cat. The CrossAttentionControlSubstitute object representing this swap may be confined to the tokens
    being swapped:
        CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')])
    or it may represent a larger portion of the token sequence:
        CrossAttentionControlSubstitute(original=[Fragment('a cat sitting on a car')],
                                        edited=[Fragment('a smiling dog sitting on a car')])

    In either case expect it to be embedded in a Prompt or FlattenedPrompt:
        FlattenedPrompt([
            Fragment('a'),
            CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')]),
            Fragment('sitting on a car')
        ])
    """

    def __init__(self, original: list, edited: list, options: dict = None):
        """
        :param original: the fragments being replaced; an empty sequence is replaced by one blank Fragment.
        :param edited: the replacement fragments; an empty sequence is replaced by one blank Fragment.
        :param options: optional overrides for the default option values. A 'shape_freedom' entry is
            converted to an 's_end' value and is not stored itself. The passed-in dict is not modified.
        """
        self.original = original if len(original) > 0 else [Fragment(' ')]
        self.edited = edited if len(edited) > 0 else [Fragment(' ')]

        merged_options = {
            ' s_start ': 0.0,
            ' s_end ': 0.2062994740159002,  # ~= shape_freedom=0.5
            ' t_start ': 0.0,
            ' t_end ': 1.0
        }
        if options is not None:
            # work on a copy so that removing shape_freedom does not mutate the caller's dict
            overrides = dict(options)
            shape_freedom = overrides.pop(' shape_freedom ', None)
            if shape_freedom is not None:
                # high shape freedom = SD can do what it wants with the shape of the object
                # high shape freedom => s_end = 0
                # low shape freedom => s_end = 1
                # shape freedom is in a "linear" space, while noticeable changes to s_end are typically
                # closer around 0, and there is very little perceptible difference as s_end increases
                # above 0.5, so for shape_freedom = 0.5 we probably want s_end to be 0.2
                #  -> cube root and subtract from 1.0
                merged_options[' s_end '] = 1.0 - shape_freedom ** (1. / 3.)
            # explicit entries (including an explicit s_end) win over defaults and over shape_freedom
            merged_options.update(overrides)

        self.options = merged_options

    def __repr__(self):
        return f" CrossAttentionControlSubstitute:( {self.original} -> {self.edited} ( {self.options} ) "

    def __eq__(self, other):
        return type(other) is CrossAttentionControlSubstitute \
            and other.original == self.original \
            and other.edited == self.edited \
            and other.options == self.options
2022-10-19 23:42:04 +00:00
2022-10-21 01:29:50 +00:00
2022-10-19 23:42:04 +00:00
class CrossAttentionControlAppend(CrossAttentionControlledFragment):
    """A cross-attention controlled fragment that wraps a single Fragment."""

    def __init__(self, fragment: Fragment):
        """
        :param fragment: the wrapped Fragment.
        """
        self.fragment = fragment

    def __repr__(self):
        # __repr__ must return a str; the previous implementation returned a (str, Fragment) tuple,
        # which made repr() raise TypeError
        return f" CrossAttentionControlAppend: {self.fragment} "

    def __eq__(self, other):
        return type(other) is CrossAttentionControlAppend \
            and other.fragment == self.fragment
class Conjunction():
    """
    Holds one or more Prompts or Blends together with one weight per entry. Each entry is to be diffused
    separately and the results merged by weighted sum in latent space.
    """

    def __init__(self, prompts: list, weights: list = None):
        """
        :param prompts: the entries. Prompt, Blend and FlattenedPrompt objects are kept as they are;
            anything else is wrapped with Prompt(...).
        :param weights: one weight per entry; None or an empty sequence means 1.0 for every entry.
        :raises PromptParser.ParsingException: if the number of weights differs from the number of prompts.
        """
        # force everything to be a Prompt (or one of the other accepted container types)
        normalized = []
        for candidate in prompts:
            if type(candidate) is Prompt or type(candidate) is Blend or type(candidate) is FlattenedPrompt:
                normalized.append(candidate)
            else:
                normalized.append(Prompt(candidate))
        self.prompts = normalized

        if weights is None or len(weights) == 0:
            self.weights = [1.0] * len(self.prompts)
        else:
            self.weights = list(weights)

        if len(self.weights) != len(self.prompts):
            raise PromptParser.ParsingException(f" while parsing Conjunction: mismatched parts/weights counts {prompts} , {weights} ")
        self.type = ' AND '

    def __repr__(self):
        return f" Conjunction: {self.prompts} | weights {self.weights} "

    def __eq__(self, other):
        if type(other) is not Conjunction:
            return False
        return other.prompts == self.prompts and other.weights == self.weights
class Blend():
    """
    Stores a Blend of multiple Prompts. To apply, build feature vectors for each of the child Prompts and
    then perform a weighted blend of the feature vectors to produce a single feature vector that is
    effectively a lerp between the Prompts.
    """

    def __init__(self, prompts: list, weights: list[float], normalize_weights: bool = True):
        """
        :param prompts: the things to blend; every element must be a Prompt or a FlattenedPrompt and
            must not directly contain a CrossAttentionControlSubstitute child.
        :param weights: one weight per prompt; None or an empty sequence means 1.0 for every prompt.
        :param normalize_weights: stored flag, shown in the repr (and therefore part of equality).
        :raises PromptParser.ParsingException: on a prompts/weights count mismatch, on an element of an
            unsupported type, or when a prompt contains a .swap() substitution.
        """
        weights = [1.0] * len(prompts) if (weights is None or len(weights) == 0) else list(weights)
        if len(prompts) != len(weights):
            raise PromptParser.ParsingException(f" while parsing Blend: mismatched prompts/weights counts {prompts} , {weights} ")
        for p in prompts:
            if type(p) is not Prompt and type(p) is not FlattenedPrompt:
                raise (PromptParser.ParsingException(f" {type(p)} cannot be added to a Blend, only Prompts or FlattenedPrompts "))
            for f in p.children:
                if isinstance(f, CrossAttentionControlSubstitute):
                    raise (PromptParser.ParsingException(f" while parsing Blend: sorry, you cannot do .swap() as part of a Blend "))

        # every element has been validated above, so no upcasting is needed; the list is stored as-is
        self.prompts = prompts
        self.weights = weights
        self.normalize_weights = normalize_weights

    def __repr__(self):
        return f" Blend: {self.prompts} | weights {' ' if self.normalize_weights else ' (non-normalized) '} {self.weights} "

    def __eq__(self, other):
        # like the sibling classes, only another Blend can be equal; the contents are compared via repr
        return type(other) is Blend and other.__repr__() == self.__repr__()
class PromptParser():
    """
    Parses prompt strings into Conjunction objects whose entries have been flattened for tokenization.

    The grammar objects are produced by build_parser_syntax(); the two attention bases are passed
    straight through to it.
    """

    class ParsingException(Exception):
        """Raised when a prompt, or one of the structures built from it, is invalid."""

    class UnrecognizedOperatorException(ParsingException):
        """Raised for an operator name that has no known meaning."""

        def __init__(self, operator: str):
            super().__init__(" Unrecognized operator: " + operator)

    def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
        """
        :param attention_plus_base: forwarded to build_parser_syntax().
        :param attention_minus_base: forwarded to build_parser_syntax().
        """
        conjunction_syntax, prompt_syntax = build_parser_syntax(attention_plus_base, attention_minus_base)
        self.conjunction = conjunction_syntax
        self.prompt = prompt_syntax

    def parse_conjunction(self, prompt: str) -> Conjunction:
        '''
        :param prompt: The prompt string to parse
        :return: a Conjunction representing the parsed results. A blank prompt yields a Conjunction
            holding a single FlattenedPrompt built from one (text, 1.0) tuple.
        '''
        if not prompt.strip():
            return Conjunction(prompts=[FlattenedPrompt([(' ', 1.0)])], weights=[1.0])

        parsed_root = self.conjunction.parse_string(prompt)
        return self.flatten(parsed_root[0])

    def parse_legacy_blend(self, text: str) -> Optional[Blend]:
        """
        Interpret `text` using the legacy "subprompt:weight" syntax.

        :param text: the prompt string.
        :return: a Blend of the individually parsed subprompts with normalized weights, or None when
            split_weighted_subprompts() finds at most one subprompt.
        """
        subprompts = split_weighted_subprompts(text, skip_normalize=False)
        if len(subprompts) <= 1:
            return None
        subprompt_texts = [entry[0] for entry in subprompts]
        subprompt_weights = [entry[1] for entry in subprompts]
        conjunctions = [self.parse_conjunction(subprompt_text) for subprompt_text in subprompt_texts]
        # each parsed subprompt contributes only its first entry
        first_prompts = [conjunction.prompts[0] for conjunction in conjunctions]
        return Blend(prompts=first_prompts, weights=subprompt_weights, normalize_weights=True)

    def flatten(self, root: Conjunction, verbose=False) -> Conjunction:
        """
        Flattening a Conjunction walks the nested tree-like structure of each of its Prompts or Blends and
        turns each walk into a linear sequence of Fragment or CrossAttentionControlSubstitute objects that
        can be tokenized without walking a tree.

        :param root: The Conjunction to flatten.
        :param verbose: when truthy, progress is printed while flattening.
        :return: A Conjunction containing the result of flattening each of the prompts in the passed-in
            root, with root's weights.
        """

        def fuse_fragments(items):
            # merge each item into its predecessor when both carry the same weight
            fused = []
            for item in items:
                if type(item) is CrossAttentionControlSubstitute:
                    # fuse both sides of the substitution independently
                    fused.append(CrossAttentionControlSubstitute(
                        fuse_fragments(item.original),
                        fuse_fragments(item.edited),
                        options=item.options))
                    continue

                previous_weight = None
                if len(fused) > 0 and not issubclass(type(fused[-1]), CrossAttentionControlledFragment):
                    previous_weight = fused[-1].weight
                item_text = item.text
                item_weight = item.weight
                if previous_weight is not None and previous_weight == item_weight:
                    fused[-1] = Fragment(fused[-1].text + ' ' + item_text, previous_weight)
                else:
                    fused.append(item)
            return fused

        def flatten_internal(node, weight_scale, results, prefix):
            # appends the flattened form of `node` to `results` (in place) and returns `results`
            if verbose:
                print(prefix + " flattening ", node, " ... ")

            node_type = type(node)
            if node_type is pp.ParseResults or node_type is list:
                for element in node:
                    results = flatten_internal(element, weight_scale, results, prefix + ' pr ')
            elif node_type is Attention:
                # nested attention weights multiply on the way down
                for index, child in enumerate(node.children):
                    results = flatten_internal(child, weight_scale * node.weight, results, prefix + f" att {index} ")
            elif node_type is Fragment:
                results.append(Fragment(node.text, node.weight * weight_scale))
            elif node_type is CrossAttentionControlSubstitute:
                flat_original = flatten_internal(node.original, weight_scale, [], prefix + ' CAo ')
                flat_edited = flatten_internal(node.edited, weight_scale, [], prefix + ' CAe ')
                results.append(CrossAttentionControlSubstitute(flat_original, flat_edited, options=node.options))
            elif node_type is Blend:
                blend_parts = []
                for blend_prompt in node.prompts:
                    blend_parts = flatten_internal(blend_prompt, weight_scale, blend_parts, prefix + ' B ')
                results.append(Blend(prompts=blend_parts, weights=node.weights, normalize_weights=node.normalize_weights))
            elif node_type is Prompt:
                prompt_parts = []
                for child in node.children:
                    prompt_parts = flatten_internal(child, weight_scale, prompt_parts, prefix + ' P ')
                results.append(FlattenedPrompt(parts=fuse_fragments(prompt_parts)))
            else:
                raise PromptParser.ParsingException(f" unhandled node type {type(node)} when flattening {node} ")

            if verbose:
                print(prefix + " -> after flattening ", type(node).__name__, " results is ", results)
            return results

        if verbose:
            print(" flattening ", root)

        flattened_parts = []
        for part in root.prompts:
            flattened_parts.extend(flatten_internal(part, 1.0, [], ' C| '))

        if verbose:
            print(" flattened to ", flattened_parts)

        return Conjunction(flattened_parts, root.weights)
2022-11-01 09:08:42 +00:00
def build_parser_syntax(attention_plus_base: float, attention_minus_base: float):
    """
    Build the pyparsing grammar used by PromptParser.

    :param attention_plus_base: base that is raised to the length of a run of '+' weight characters.
    :param attention_minus_base: base used in the same way for any other string-valued weight.
    :return: a (conjunction, prompt) tuple of pyparsing expressions: the top-level conjunction grammar
        and the grammar for a single self-contained prompt.
    """

    def make_operator_object(x):
        # x is indexable as (target, operator name, arguments); builds the matching prompt object
        target = x[0]
        operator = x[1]
        arguments = x[2]

        if operator == ' .attend ':
            weight_raw = arguments[0]
            if type(weight_raw) is float or type(weight_raw) is int:
                weight = weight_raw
            elif type(weight_raw) is str:
                # a run of +/- characters: the run length is the exponent
                base = attention_plus_base if weight_raw[0] == ' + ' else attention_minus_base
                weight = pow(base, len(weight_raw))
            else:
                weight = 1.0
            return Attention(weight=weight, children=[child for child in target])

        if operator == ' .swap ':
            return CrossAttentionControlSubstitute(target, arguments, x.as_dict())

        if operator == ' .blend ':
            prompts = [Prompt(p) for p in target]
            weights_raw = arguments
            normalize_weights = True
            # a trailing no_normalize flag is an option, not a weight
            if len(weights_raw) > 0 and weights_raw[-1][0] == ' no_normalize ':
                normalize_weights = False
                weights_raw = weights_raw[:-1]
            weights = [float(w[0]) for w in weights_raw]
            return Blend(prompts=prompts, weights=weights, normalize_weights=normalize_weights)

        if operator in (' .and ', ' .add '):
            prompts = [Prompt(p) for p in target]
            weights = [float(w[0]) for w in arguments]
            return Conjunction(prompts=prompts, weights=weights)

        raise PromptParser.UnrecognizedOperatorException(operator)

    def parse_fragment_str(x, expression: pp.ParseExpression, in_quotes: bool = False, in_parens: bool = False):
        # re-parse the text captured by a quoted string with `expression`; ParseExceptions propagate
        fragment_string = x[0]
        if len(fragment_string.strip()) == 0:
            return Fragment(' ')

        if in_quotes:
            # escape unescaped quotes
            fragment_string = fragment_string.replace(' " ', ' \\ " ')

        return (expression + pp.StringEnd()).parse_string(fragment_string)

    # meaningful symbols
    lparen = pp.Literal(" ( ").suppress()
    rparen = pp.Literal(" ) ").suppress()
    quote = pp.Literal(' " ').suppress()
    comma = pp.Literal(" , ").suppress()
    dot = pp.Literal(" . ").suppress()
    equals = pp.Literal(" = ").suppress()

    # each syntactic character mapped to the literal for its backslash-escaped form (order matters)
    syntactic_symbols = {
        symbol: pp.Literal(escaped_form)
        for symbol, escaped_form in (
            (' ( ', ' \\ ( '),
            (' ) ', ' \\ ) '),
            (' " ', ' \\ " '),
            (' , ', ' \\ , '),
            (' . ', ' \\ . '),
            (' + ', ' \\ + '),
            (' - ', ' \\ - '),
            (' = ', ' \\ = '),
        )
    }
    syntactic_chars = " ".join(syntactic_symbols.keys())

    # accepts int or float notation, always maps to float
    integer_as_float = pp.Combine(pp.Optional(" - ") + pp.Word(pp.nums)).set_parse_action(pp.token_map(float))
    number = pp.pyparsing_common.real | integer_as_float

    # for options
    keyword = pp.Word(pp.alphanums + ' _ ')

    # a word that absolutely does not contain any meaningful syntax
    non_syntax_word = pp.Combine(pp.OneOrMore(pp.MatchFirst([
        pp.Or(syntactic_symbols.values()),
        pp.one_of([' - ', ' + ']) + pp.NotAny(pp.White() | pp.Char(syntactic_chars) | pp.StringEnd()),
        # build character-by-character
        pp.CharsNotIn(string.whitespace + syntactic_chars, exact=1)
    ])))
    non_syntax_word.set_parse_action(lambda tokens: [Fragment(t) for t in tokens])
    non_syntax_word.set_name(' non_syntax_word ')
    non_syntax_word.set_debug(False)

    # a word that can contain any character at all - greedily consumes syntax, so use with care
    free_word = pp.CharsNotIn(string.whitespace).set_parse_action(lambda tokens: Fragment(tokens[0]))
    free_word.set_name(' free_word ')
    free_word.set_debug(False)

    # forward declarations for the mutually recursive parts of the grammar
    attention = pp.Forward()
    cross_attention_substitute = pp.Forward()
    parenthesized_fragment = pp.Forward()
    quoted_fragment = pp.Forward()

    # the types of things that can go into a fragment, consisting of syntax-full and/or strictly
    # syntax-free components
    fragment_part_expressions = [
        attention,
        cross_attention_substitute,
        parenthesized_fragment,
        quoted_fragment,
        non_syntax_word
    ]

    # a fragment that is permitted to contain commas
    comma_as_fragment = pp.Literal(' , ').set_parse_action(lambda tokens: Fragment(tokens[0]))
    fragment_including_commas = pp.ZeroOrMore(pp.MatchFirst(
        fragment_part_expressions + [comma_as_fragment]
    ))

    # a fragment that is not permitted to contain commas
    fragment_excluding_commas = pp.ZeroOrMore(pp.MatchFirst(
        fragment_part_expressions
    ))

    # a fragment in double quotes (may be nested)
    quoted_fragment << pp.QuotedString(quote_char=' " ', esc_char=None, esc_quote=' \\ " ')
    quoted_fragment.set_parse_action(
        lambda tokens: parse_fragment_str(tokens, fragment_including_commas, in_quotes=True))

    # a fragment inside parentheses (may be nested)
    parenthesized_fragment << (lparen + fragment_including_commas + rparen)
    parenthesized_fragment.set_name(' parenthesized_fragment ')
    parenthesized_fragment.set_debug(False)

    # a string of the form (<keyword>=<float|keyword> | <float> | <keyword>) where keyword is
    # alphanumeric + '_'
    option = pp.Group(pp.MatchFirst([
        keyword + equals + (number | keyword),  # option=value
        number.copy().set_parse_action(pp.token_map(str)),  # weight
        keyword  # flag
    ]))

    # options for an operator, eg "s_start=0.1, 0.3, no_normalize"
    options = pp.Dict(pp.Optional(pp.delimited_list(option)))
    options.set_name(' options ')
    options.set_debug(False)

    # a fragment which can be used as the target for an operator - either quoted or in parentheses,
    # or a bare vanilla word
    potential_operator_target = (quoted_fragment | parenthesized_fragment | non_syntax_word)

    # a fragment whose weight has been increased or decreased by a given amount
    attention_weight_operator = pp.Word(' + ') | pp.Word(' - ') | number
    attention_explicit = (
        pp.Group(potential_operator_target)
        + pp.Literal(' .attend ')
        + lparen
        + pp.Group(attention_weight_operator)
        + rparen
    )
    attention_explicit.set_parse_action(make_operator_object)
    attention_implicit = (
        pp.Group(potential_operator_target)
        + pp.NotAny(pp.White())  # do not permit whitespace between term and operator
        + pp.Group(attention_weight_operator)
    )
    # the implicit form carries no operator token, so one is supplied for the factory
    attention_implicit.set_parse_action(
        lambda tokens: make_operator_object([tokens[0], ' .attend ', tokens[1]]))
    attention << (attention_explicit | attention_implicit)
    attention.set_name(' attention ')
    attention.set_debug(False)

    # cross-attention control by swapping one fragment for another
    swap_target = pp.Group(potential_operator_target).set_name(' ca-target ').set_debug(False)
    swap_operator = pp.Literal(" .swap ").set_name(' ca-operator ').set_debug(False)
    swap_replacement = pp.Group(fragment_excluding_commas).set_name(' ca-replacement ').set_debug(False)
    swap_options = pp.Optional(comma + options).set_name(' ca-options ').set_debug(False)
    cross_attention_substitute << (
        swap_target + swap_operator + lparen + swap_replacement + swap_options + rparen
    )
    cross_attention_substitute.set_name(' cross_attention_substitute ')
    cross_attention_substitute.set_debug(False)
    cross_attention_substitute.set_parse_action(make_operator_object)

    # an entire self-contained prompt, which can be used in a Blend or Conjunction
    prompt = pp.ZeroOrMore(pp.MatchFirst([
        cross_attention_substitute,
        attention,
        quoted_fragment,
        parenthesized_fragment,
        free_word,
        pp.White().suppress()
    ]))
    quoted_prompt = quoted_fragment.copy().set_parse_action(
        lambda tokens: parse_fragment_str(tokens, prompt, in_quotes=True))

    # a blend/lerp between the feature vectors for two or more prompts
    blend_targets = pp.Group(
        pp.delimited_list(pp.Group(potential_operator_target | quoted_prompt), min=1)
    ).set_name(' bl-target ').set_debug(False)
    blend_operator = pp.Literal(" .blend ").set_name(' bl-operator ').set_debug(False)
    blend_options = pp.Group(options).set_name(' bl-options ').set_debug(False)
    blend = lparen + blend_targets + rparen + blend_operator + lparen + blend_options + rparen
    blend.set_name(' blend ')
    blend.set_debug(False)
    blend.set_parse_action(make_operator_object)

    # an operator to direct stable diffusion to step multiple times, once for each target, and then
    # add the results together with different weights
    conjunction_targets = pp.Group(
        pp.delimited_list(pp.Group(potential_operator_target | quoted_prompt), min=1)
    ).set_name(' cj-target ').set_debug(False)
    conjunction_operator = pp.one_of([" .and ", " .add "]).set_name(' cj-operator ').set_debug(False)
    conjunction_options = pp.Group(options).set_name(' cj-options ').set_debug(False)
    explicit_conjunction = (
        lparen + conjunction_targets + rparen + conjunction_operator + lparen + conjunction_options + rparen
    )
    explicit_conjunction.set_name(' explicit_conjunction ')
    explicit_conjunction.set_debug(False)
    explicit_conjunction.set_parse_action(make_operator_object)

    # by default a prompt consists of a Conjunction with a single term
    implicit_conjunction = (blend | pp.Group(prompt)) + pp.StringEnd()
    implicit_conjunction.set_parse_action(lambda tokens: Conjunction(tokens))

    conjunction = (explicit_conjunction | implicit_conjunction)

    return conjunction, prompt
def split_weighted_subprompts(text, skip_normalize=False) -> list:
    """
    Legacy blend parsing.

    Grabs all text up to the first occurrence of ':', uses the grabbed text as a sub-prompt, and takes
    the value following ':' as its weight. If ':' has no value defined, the weight defaults to 1.0.
    A colon escaped as '\\:' is kept as part of the sub-prompt (and un-escaped to ':').
    Repeats until no text is remaining.

    :param text: the prompt string to split.
    :param skip_normalize: when True, return the raw weights instead of normalizing them.
    :return: a list of (subprompt, weight) tuples; empty when `text` contains no sub-prompt.
        Unless skip_normalize is set, the weights are divided by their sum; if that sum is zero a
        warning is printed and every sub-prompt gets the same weight.
    """
    # raw string: the pattern is passed to `re` exactly as written. Under re.VERBOSE whitespace
    # between tokens is ignored, but group openers such as (?P< and (?: and the contents of a
    # character class must be written without embedded spaces.
    prompt_parser = re.compile(r"""
            (?P<prompt>         # capture group for 'prompt'
            (?:\\:|[^:])+       # match one or more non ':' characters or escaped colons '\:'
            )                   # end 'prompt'
            (?:                 # non-capture group
            :+                  # match one or more ':' characters
            (?P<weight>         # capture group for 'weight'
            -?\d+(?:\.\d+)?     # match positive or negative integer or decimal number
            )?                  # end weight capture group, make optional
            \s*                 # strip spaces after weight
            |                   # OR
            $                   # else, if no ':' then match end of line
            )                   # end non-capture group
            """, re.VERBOSE)
    parsed_prompts = [
        (match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1))
        for match in prompt_parser.finditer(text)
    ]
    # nothing to normalize for empty input: return early instead of warning about a zero sum
    if skip_normalize or not parsed_prompts:
        return parsed_prompts

    weight_sum = sum(weight for _, weight in parsed_prompts)
    if weight_sum == 0:
        print(
            " Warning: Subprompt weights add up to zero. Discarding and using even weights instead. ")
        equal_weight = 1 / len(parsed_prompts)
        return [(subprompt, equal_weight) for subprompt, _ in parsed_prompts]
    return [(subprompt, weight / weight_sum) for subprompt, weight in parsed_prompts]
# shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
2022-10-30 22:01:05 +00:00
def log_tokenization(text, model, display_label=None):
    """
    Print how `text` is tokenized, one ANSI colour code per token.

    Tokens whose index is below model.cond_stage_model.max_length are listed as used; any remaining
    tokens are listed separately as discarded.

    :param text: the prompt text to tokenize.
    :param model: object providing cond_stage_model.tokenizer._tokenize() and cond_stage_model.max_length.
    :param display_label: optional label included in the first printed header.
    """
    tokens = model.cond_stage_model.tokenizer._tokenize(text)
    total_count = len(tokens)
    used_count = 0
    kept_pieces = []
    dropped_pieces = []

    for index, raw_token in enumerate(tokens):
        token = raw_token.replace(' </w> ', ' x` ')
        # the colour digit only advances with the number of used tokens
        colour = (used_count % 6) + 1
        piece = f" \x1b [0;3 {colour} ;40m {token} "
        if index < model.cond_stage_model.max_length:
            kept_pieces.append(piece)
            used_count += 1
        else:  # over max token length
            dropped_pieces.append(piece)

    tokenized = " " + "".join(kept_pieces)
    print(f" \n >> Tokens {display_label or ' '} ( {used_count} ): \n {tokenized} \x1b [0m ")

    if dropped_pieces:
        discarded = " " + "".join(dropped_pieces)
        print(
            f" >> Tokens Discarded ( {total_count - used_count} ): \n {discarded} \x1b [0m "
        )