mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

move all prompting stuff to use compel
@@ -9,6 +9,8 @@ from typing import List, Optional, Union

import click

from compel import PromptParser

if sys.platform == "darwin":
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
@@ -23,7 +25,6 @@ from ldm.invoke.image_util import make_grid
from ldm.invoke.log import write_log
from ldm.invoke.model_manager import ModelManager
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
from ldm.invoke.prompt_parser import PromptParser
from ldm.invoke.readline import Completer, get_completer
from ldm.util import url_attachment_name
@@ -749,7 +750,7 @@ def import_ckpt_model(
    base_name = Path(url_attachment_name(path_or_url)).name if is_a_url else Path(path_or_url).name
    default_name = Path(base_name).stem
    default_description = f"Imported model {default_name}"

    model_name, model_description = _get_model_name_and_desc(
        manager,
        completer,
@@ -834,7 +835,7 @@ def _ask_for_config_file(model_path: Union[str,Path], completer, plural: bool=Fa
        '2': 'v2-inference-v.yaml',
        '3': 'v1-inpainting-inference.yaml',
    }

    prompt = '''What type of models are these?:
[1] Models based on Stable Diffusion 1.X
[2] Models based on Stable Diffusion 2.X
@@ -843,7 +844,7 @@ def _ask_for_config_file(model_path: Union[str,Path], completer, plural: bool=Fa
[1] A model based on Stable Diffusion 1.X
[2] A model based on Stable Diffusion 2.X
[3] An inpainting model based on Stable Diffusion 1.X
[4] Something else'''
    print(prompt)
    choice = input(f'Your choice: [{default}] ')
    choice = choice.strip() or default
@@ -93,9 +93,9 @@ import shlex
import sys
import ldm.invoke
import ldm.invoke.pngwriter
from compel.prompt_parser import split_weighted_subprompts

from ldm.invoke.globals import Globals
from ldm.invoke.prompt_parser import split_weighted_subprompts
from argparse import Namespace
from pathlib import Path
@@ -9,59 +9,75 @@ get_uc_and_c_and_ec() get the conditioned and unconditioned latent, an
import re
from typing import Union

import torch

from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
    CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment
from ..models.diffusion import cross_attention_control
from compel import Compel
from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute
from .devices import torch_dtype
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
from ..modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
from ldm.invoke.globals import Globals


def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):

    # lazy-load any deferred textual inversions.
    # this might take a couple of seconds the first time a textual inversion is used.
    model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)

    prompt, negative_prompt = get_prompt_structure(prompt_string,
                                                   skip_normalize_legacy_blend=skip_normalize_legacy_blend)
    conditioning = _get_conditioning_for_prompt(prompt, negative_prompt, model, log_tokens)
    compel = Compel(tokenizer=model.tokenizer,
                    text_encoder=model.text_encoder,
                    textual_inversion_manager=model.textual_inversion_manager,
                    dtype_for_device_getter=torch_dtype)

    return conditioning
    positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
    positive_prompt = compel.parse_prompt_string(positive_prompt_string)
    negative_prompt = compel.parse_prompt_string(negative_prompt_string)

    if log_tokens or getattr(Globals, "log_tokenization", False):
        log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)

    c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
    uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)

    tokens_count = get_tokens_for_prompt(tokenizer=model.tokenizer, parsed_prompt=positive_prompt)

    ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
                                                         cross_attention_control_args=options.get(
                                                             'cross_attention_control', None))
    return uc, c, ec
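For orientation, the compel-based flow added above reads as four steps: build a Compel object from the model's tokenizer and text encoder, split the raw prompt string into positive and [bracketed negative] halves, parse each half, then build one conditioning tensor per half. A minimal sketch under that reading (the `model` attributes are those used in the hunk above; this is not code from the commit):

    from compel import Compel
    from ldm.invoke.devices import torch_dtype

    def make_conditioning(model, positive_string: str, negative_string: str):
        # same constructor arguments as in get_uc_and_c_and_ec() above
        compel = Compel(tokenizer=model.tokenizer,
                        text_encoder=model.text_encoder,
                        textual_inversion_manager=model.textual_inversion_manager,
                        dtype_for_device_getter=torch_dtype)
        positive = compel.parse_prompt_string(positive_string)   # -> FlattenedPrompt or Blend
        negative = compel.parse_prompt_string(negative_string)
        c, options = compel.build_conditioning_tensor_for_prompt_object(positive)
        uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative)
        return uc, c, options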
def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = False) -> (
        Union[FlattenedPrompt, Blend], FlattenedPrompt):
def get_prompt_structure(prompt_string, model, skip_normalize_legacy_blend: bool = False) -> (
        Union[FlattenedPrompt, Blend], FlattenedPrompt):
    """
    parse the passed-in prompt string and return tuple (positive_prompt, negative_prompt)
    """
    prompt, negative_prompt = _parse_prompt_string(prompt_string,
                                                   skip_normalize_legacy_blend=skip_normalize_legacy_blend)
    return prompt, negative_prompt
    compel = Compel(tokenizer=model.tokenizer,
                    text_encoder=model.text_encoder,
                    textual_inversion_manager=model.textual_inversion_manager,
                    dtype_for_device_getter=torch_dtype)

    positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
    positive_prompt = compel.parse_prompt_string(positive_prompt_string)
    negative_prompt = compel.parse_prompt_string(negative_prompt_string)

    return positive_prompt, negative_prompt


def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> [str]:
def get_tokens_for_prompt(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> [str]:
    text_fragments = [x.text if type(x) is Fragment else
                      (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else
                       str(x))
                      for x in parsed_prompt.children]
    text = " ".join(text_fragments)
    tokens = model.cond_stage_model.tokenizer.tokenize(text)
    tokens = tokenizer.tokenize(text)
    if truncate_if_too_long:
        max_tokens_length = model.cond_stage_model.max_length - 2  # typically 75
        max_tokens_length = tokenizer.model_max_length - 2  # typically 75
        tokens = tokens[0:max_tokens_length]
    return tokens


def _parse_prompt_string(prompt_string_uncleaned, skip_normalize_legacy_blend=False) -> Union[FlattenedPrompt, Blend]:
    # Extract Unconditioned Words From Prompt
def split_prompt_to_positive_and_negative(prompt_string_uncleaned):
    unconditioned_words = ''
    unconditional_regex = r'\[(.*?)\]'
    unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned)

    if len(unconditionals) > 0:
        unconditioned_words = ' '.join(unconditionals)
@@ -71,210 +87,57 @@ def _parse_prompt_string(prompt_string_uncleaned, skip_normalize_legacy_blend=Fa
        prompt_string_cleaned = re.sub(' +', ' ', clean_prompt)
    else:
        prompt_string_cleaned = prompt_string_uncleaned

    pp = PromptParser()

    parsed_prompt: Union[FlattenedPrompt, Blend] = None
    legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned, skip_normalize_legacy_blend)
    if legacy_blend is not None:
        parsed_prompt = legacy_blend
    else:
        # we don't support conjunctions for now
        parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0]

    parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
    return parsed_prompt, parsed_negative_prompt
    return prompt_string_cleaned, unconditioned_words
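To make the bracket handling concrete: split_prompt_to_positive_and_negative() collects everything inside square brackets into the negative prompt and keeps the remainder as the positive prompt. A small illustration using the same regexes shown above (the re.sub call is one plausible way to do the removal that the elided lines perform; the input string is made up):

    import re

    prompt = "a fantasy landscape [blurry, ugly] at sunset"
    negative = ' '.join(re.findall(r'\[(.*?)\]', prompt))   # 'blurry, ugly'
    positive = re.sub(r'\[(.*?)\]', '', prompt)             # 'a fantasy landscape  at sunset'
    positive = re.sub(' +', ' ', positive).strip()          # 'a fantasy landscape at sunset'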
def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], parsed_negative_prompt: FlattenedPrompt,
                                 model, log_tokens=False) \
        -> tuple[torch.Tensor, torch.Tensor, InvokeAIDiffuserComponent.ExtraConditioningInfo]:
    """
    Process prompt structure and tokens, and return (conditioning, unconditioning, extra_conditioning_info)
    """
def log_tokenization(positive_prompt: Blend | FlattenedPrompt,
                     negative_prompt: Blend | FlattenedPrompt,
                     tokenizer):
    print(f"\n>> [TOKENLOG] Parsed Prompt: {positive_prompt}")
    print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {negative_prompt}")

    if log_tokens or getattr(Globals, "log_tokenization", False):
        print(f"\n>> [TOKENLOG] Parsed Prompt: {parsed_prompt}")
        print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {parsed_negative_prompt}")
    log_tokenization_for_prompt_object(positive_prompt, tokenizer)
    log_tokenization_for_prompt_object(negative_prompt, tokenizer, display_label_prefix="(negative prompt)")

    conditioning = None
    cac_args: cross_attention_control.Arguments = None

    if type(parsed_prompt) is Blend:
        conditioning = _get_conditioning_for_blend(model, parsed_prompt, log_tokens)
    elif type(parsed_prompt) is FlattenedPrompt:
        if parsed_prompt.wants_cross_attention_control:
            conditioning, cac_args = _get_conditioning_for_cross_attention_control(model, parsed_prompt, log_tokens)
def log_tokenization_for_prompt_object(p: Blend | FlattenedPrompt, tokenizer, display_label_prefix=None):
    display_label_prefix = display_label_prefix or ""
    if type(p) is Blend:
        blend: Blend = p
        for i, c in enumerate(blend.prompts):
            log_tokenization_for_prompt_object(
                c, tokenizer,
                display_label_prefix=f"{display_label_prefix}(blend part {i + 1}, weight={blend.weights[i]})")
    elif type(p) is FlattenedPrompt:
        flattened_prompt: FlattenedPrompt = p
        if flattened_prompt.wants_cross_attention_control:
            original_fragments = []
            edited_fragments = []
            for f in flattened_prompt.children:
                if type(f) is CrossAttentionControlSubstitute:
                    original_fragments += f.original
                    edited_fragments += f.edited
                else:
                    original_fragments.append(f)
                    edited_fragments.append(f)

            original_text = " ".join([x.text for x in original_fragments])
            log_tokenization_for_text(original_text, tokenizer,
                                      display_label=f"{display_label_prefix}(.swap originals)")
            edited_text = " ".join([x.text for x in edited_fragments])
            log_tokenization_for_text(edited_text, tokenizer,
                                      display_label=f"{display_label_prefix}(.swap replacements)")
        else:
            conditioning, _ = _get_embeddings_and_tokens_for_prompt(model,
                                                                    parsed_prompt,
                                                                    log_tokens=log_tokens,
                                                                    log_display_label="(prompt)")
    else:
        raise ValueError(f"parsed_prompt is '{type(parsed_prompt)}' which is not a supported prompt type")

    unconditioning, _ = _get_embeddings_and_tokens_for_prompt(model,
                                                              parsed_negative_prompt,
                                                              log_tokens=log_tokens,
                                                              log_display_label="(unconditioning)")
    if isinstance(conditioning, dict):
        # hybrid conditioning is in play
        unconditioning, conditioning = _flatten_hybrid_conditioning(unconditioning, conditioning)
        if cac_args is not None:
            print(
                ">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
            cac_args = None

    if type(parsed_prompt) is Blend:
        blend: Blend = parsed_prompt
        all_token_sequences = [get_tokens_for_prompt(model, p) for p in blend.prompts]
        longest_token_sequence = max(all_token_sequences, key=lambda t: len(t))
        eos_token_index = len(longest_token_sequence)+1
    else:
        tokens = get_tokens_for_prompt(model, parsed_prompt)
        eos_token_index = len(tokens)+1
    return (
        unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
            tokens_count_including_eos_bos=eos_token_index + 1,
            cross_attention_control_args=cac_args
        )
    )
            text = " ".join([x.text for x in flattened_prompt.children])
            log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)
def _get_conditioning_for_cross_attention_control(model, prompt: FlattenedPrompt, log_tokens: bool = True):
    original_prompt = FlattenedPrompt()
    edited_prompt = FlattenedPrompt()
    # for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
    original_token_count = 0
    edited_token_count = 0
    edit_options = []
    edit_opcodes = []
    # beginning of sequence
    edit_opcodes.append(
        ('equal', original_token_count, original_token_count + 1, edited_token_count, edited_token_count + 1))
    edit_options.append(None)
    original_token_count += 1
    edited_token_count += 1
    for fragment in prompt.children:
        if type(fragment) is CrossAttentionControlSubstitute:
            original_prompt.append(fragment.original)
            edited_prompt.append(fragment.edited)

            to_replace_token_count = _get_tokens_length(model, fragment.original)
            replacement_token_count = _get_tokens_length(model, fragment.edited)
            edit_opcodes.append(('replace',
                                 original_token_count, original_token_count + to_replace_token_count,
                                 edited_token_count, edited_token_count + replacement_token_count
                                 ))
            original_token_count += to_replace_token_count
            edited_token_count += replacement_token_count
            edit_options.append(fragment.options)
        # elif type(fragment) is CrossAttentionControlAppend:
        #    edited_prompt.append(fragment.fragment)
        else:
            # regular fragment
            original_prompt.append(fragment)
            edited_prompt.append(fragment)

            count = _get_tokens_length(model, [fragment])
            edit_opcodes.append(('equal', original_token_count, original_token_count + count, edited_token_count,
                                 edited_token_count + count))
            edit_options.append(None)
            original_token_count += count
            edited_token_count += count
    # end of sequence
    edit_opcodes.append(
        ('equal', original_token_count, original_token_count + 1, edited_token_count, edited_token_count + 1))
    edit_options.append(None)
    original_token_count += 1
    edited_token_count += 1
    original_embeddings, original_tokens = _get_embeddings_and_tokens_for_prompt(model,
                                                                                 original_prompt,
                                                                                 log_tokens=log_tokens,
                                                                                 log_display_label="(.swap originals)")
    # naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
    # subsequent tokens when there is >1 edit and earlier edits change the total token count.
    # eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
    # 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
    # token 'smiling' in the inactive 'cat' edit.
    # todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
    edited_embeddings, edited_tokens = _get_embeddings_and_tokens_for_prompt(model,
                                                                             edited_prompt,
                                                                             log_tokens=log_tokens,
                                                                             log_display_label="(.swap replacements)")
    conditioning = original_embeddings
    edited_conditioning = edited_embeddings
    # print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
    cac_args = cross_attention_control.Arguments(
        edited_conditioning=edited_conditioning,
        edit_opcodes=edit_opcodes,
        edit_options=edit_options
    )
    return conditioning, cac_args
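To see what the opcode bookkeeping in the removed function produced, here is a worked example for a single substitution, assuming one CLIP token each for "a", "cat" and "dog" and three tokens for "on a mat" (illustrative, not output from the repo):

    # prompt: "a cat.swap(dog) on a mat"
    edit_opcodes = [
        ('equal',   0, 1, 0, 1),   # beginning-of-sequence token
        ('equal',   1, 2, 1, 2),   # "a"
        ('replace', 2, 3, 2, 3),   # "cat" -> "dog"
        ('equal',   3, 6, 3, 6),   # "on a mat"
        ('equal',   6, 7, 6, 7),   # end-of-sequence token
    ]
    # edit_options holds None for 'equal' spans and the .swap() options
    # (s_start, s_end, t_start, t_end) for the 'replace' span.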
def _get_conditioning_for_blend(model, blend: Blend, log_tokens: bool = False):
    embeddings_to_blend = None
    for i, flattened_prompt in enumerate(blend.prompts):
        this_embedding, _ = _get_embeddings_and_tokens_for_prompt(model,
                                                                  flattened_prompt,
                                                                  log_tokens=log_tokens,
                                                                  log_display_label=f"(blend part {i + 1}, weight={blend.weights[i]})")
        embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
            (embeddings_to_blend, this_embedding))
    conditioning = WeightedPromptFragmentsToEmbeddingsConverter.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
                                                                                        blend.weights,
                                                                                        normalize=blend.normalize_weights)
    return conditioning
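Stripped of the model plumbing, the removed blend logic amounts to a weighted sum over the per-prompt embedding tensors, optionally normalizing the weights first. A minimal torch sketch of that idea (not the converter's actual implementation):

    import torch

    def blend_embeddings(embeddings: torch.Tensor, weights: list[float],
                         normalize: bool = True) -> torch.Tensor:
        # embeddings: [n_prompts, seq_len, dim], one row per blend part
        w = torch.tensor(weights, dtype=embeddings.dtype)
        if normalize:
            w = w / w.sum()
        return (embeddings * w.view(-1, 1, 1)).sum(dim=0)   # -> [seq_len, dim]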
def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool = False,
                                          log_display_label: str = None):
    if type(flattened_prompt) is not FlattenedPrompt:
        raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
    fragments = [x.text for x in flattened_prompt.children]
    weights = [x.weight for x in flattened_prompt.children]
    embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
    if log_tokens or getattr(Globals, "log_tokenization", False):
        text = " ".join(fragments)
        log_tokenization(text, model, display_label=log_display_label)

    return embeddings, tokens


def _get_tokens_length(model, fragments: list[Fragment]):
    fragment_texts = [x.text for x in fragments]
    tokens = model.cond_stage_model.get_token_ids(fragment_texts, include_start_and_end_markers=False)
    return sum([len(x) for x in tokens])
def _flatten_hybrid_conditioning(uncond, cond):
    '''
    This handles the choice between a conditional conditioning
    that is a tensor (used by cross attention) vs one that has additional
    dimensions as well, as used by 'hybrid'
    '''
    assert isinstance(uncond, dict)
    assert isinstance(cond, dict)
    cond_flattened = dict()
    for k in cond:
        if isinstance(cond[k], list):
            cond_flattened[k] = [
                torch.cat([uncond[k][i], cond[k][i]])
                for i in range(len(cond[k]))
            ]
        else:
            cond_flattened[k] = torch.cat([uncond[k], cond[k]])
    return uncond, cond_flattened
def log_tokenization(text, model, display_label=None):
def log_tokenization_for_text(text, tokenizer, display_label=None):
    """ shows how the prompt is tokenized
    # usually tokens have '</w>' to indicate end-of-word,
    # but for readability it has been replaced with ' '
    """
    tokens = model.cond_stage_model.tokenizer.tokenize(text)
    tokens = tokenizer.tokenize(text)
    tokenized = ""
    discarded = ""
    usedTokens = 0
@@ -284,7 +147,7 @@ def log_tokenization(text, model, display_label=None):
        token = tokens[i].replace('</w>', ' ')
        # alternate color
        s = (usedTokens % 6) + 1
        if i < model.cond_stage_model.max_length:
        if i < tokenizer.model_max_length:
            tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
            usedTokens += 1
        else:  # over max token length
@@ -293,7 +156,7 @@ def log_tokenization(text, model, display_label=None):
    if usedTokens > 0:
        print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
        print(f'{tokenized}\x1b[0m')

    if discarded != "":
        print(f'\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):')
        print(f'{discarded}\x1b[0m')
@@ -31,7 +31,7 @@ from ldm.modules.textual_inversion_manager import TextualInversionManager
from ..devices import normalize_device, CPU_DEVICE
from ..offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
from compel import EmbeddingsProvider


@dataclass
@@ -294,7 +294,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            text_encoder=self.text_encoder,
            full_precision=use_full_precision)
        # InvokeAI's interface for text embeddings and whatnot
        self.prompt_fragments_to_embeddings_converter = WeightedPromptFragmentsToEmbeddingsConverter(
        self.embeddings_provider = EmbeddingsProvider(
            tokenizer=self.tokenizer,
            text_encoder=self.text_encoder,
            textual_inversion_manager=self.textual_inversion_manager
@@ -726,15 +726,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        """
        Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion.
        """
        return self.prompt_fragments_to_embeddings_converter.get_embeddings_for_weighted_prompt_fragments(
            text=c,
            fragment_weights=fragment_weights,
        return self.embeddings_provider.get_embeddings_for_weighted_prompt_fragments(
            text_batch=c,
            fragment_weights_batch=fragment_weights,
            should_return_tokens=return_tokens,
            device=self._model_group.device_for(self.unet))

    @property
    def cond_stage_model(self):
        return self.prompt_fragments_to_embeddings_converter
        return self.embeddings_provider

    @torch.inference_mode()
    def _tokenize(self, prompt: Union[str, List[str]]):
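Note the argument renames in the hunk above: compel's EmbeddingsProvider takes batched arguments (text_batch, fragment_weights_batch) where the old converter took text and fragment_weights. Assuming the compatibility shim shown is the pipeline's get_learned_conditioning (the c/fragment_weights/return_tokens parameters and the call site in conditioning.py suggest so), a call through it would look like:

    # one prompt, split into two weighted fragments (illustrative values)
    fragments = [["a fantasy landscape", "sharp focus"]]
    weights = [[1.0, 1.2]]
    embeddings, tokens = pipeline.get_learned_conditioning(fragments,
                                                           return_tokens=True,
                                                           fragment_weights=weights)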
@@ -1,655 +0,0 @@
import string
from typing import Union, Optional
import re
import pyparsing as pp
'''
This module parses prompt strings and produces tree-like structures that can be used to generate
and control the conditioning tensors for weighted subprompts.

Useful class exports:

PromptParser - parses prompts

Useful function exports:

split_weighted_subprompts()    split subprompts, normalize and weight them
log_tokenization()             print out colour-coded tokens and warn if truncated
'''
class Prompt():
    """
    Mid-level structure for storing the tree-like result of parsing a prompt. A Prompt may not represent the whole of
    the singular user-defined "prompt string" (although it can) - for example, if the user specifies a Blend, the objects
    that are to be blended together are stored individually as Prompt objects.

    Nesting makes this object not suitable for directly tokenizing; instead call flatten() on the containing Conjunction
    to produce a FlattenedPrompt.
    """
    def __init__(self, parts: list):
        for c in parts:
            if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults:
                raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} ({c}), only {[c.__name__ for c in BaseFragment.__subclasses__()]} are allowed")
        self.children = parts
    def __repr__(self):
        return f"Prompt:{self.children}"
    def __eq__(self, other):
        return type(other) is Prompt and other.children == self.children


class BaseFragment:
    pass


class FlattenedPrompt():
    """
    A Prompt that has been passed through flatten(). Its children can be readily tokenized.
    """
    def __init__(self, parts: list=[]):
        self.children = []
        for part in parts:
            self.append(part)

    def append(self, fragment: Union[list, BaseFragment, tuple]):
        # verify type correctness
        if type(fragment) is list:
            for x in fragment:
                self.append(x)
        elif issubclass(type(fragment), BaseFragment):
            self.children.append(fragment)
        elif type(fragment) is tuple:
            # upgrade tuples to Fragments
            if type(fragment[0]) is not str or (type(fragment[1]) is not float and type(fragment[1]) is not int):
                raise PromptParser.ParsingException(
                    f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
            self.children.append(Fragment(fragment[0], fragment[1]))
        else:
            raise PromptParser.ParsingException(
                f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")

    @property
    def is_empty(self):
        return len(self.children) == 0 or \
               (len(self.children) == 1 and len(self.children[0].text) == 0)

    @property
    def wants_cross_attention_control(self):
        return any(
            [issubclass(type(x), CrossAttentionControlledFragment) for x in self.children]
        )

    def __repr__(self):
        return f"FlattenedPrompt:{self.children}"
    def __eq__(self, other):
        return type(other) is FlattenedPrompt and other.children == self.children


class Fragment(BaseFragment):
    """
    A Fragment is a chunk of plain text and an optional weight. The text should be passed as-is to the CLIP tokenizer.
    """
    def __init__(self, text: str, weight: float=1):
        assert(type(text) is str)
        if '\\"' in text or '\\(' in text or '\\)' in text:
            #print("Fragment converting escaped \( \) \\\" into ( ) \"")
            text = text.replace('\\(', '(').replace('\\)', ')').replace('\\"', '"')
        self.text = text
        self.weight = float(weight)

    def __repr__(self):
        return "Fragment:'"+self.text+"'@"+str(self.weight)
    def __eq__(self, other):
        return type(other) is Fragment \
               and other.text == self.text \
               and other.weight == self.weight

class Attention():
    """
    Nestable weight control for fragments. Each object in the children array may in turn be an Attention object;
    weights should be considered to accumulate as the tree is traversed to deeper levels of nesting.

    Do not traverse directly; instead obtain a FlattenedPrompt by calling Flatten() on a top-level Conjunction object.
    """
    def __init__(self, weight: float, children: list):
        if type(weight) is not float:
            raise PromptParser.ParsingException(
                f"Attention weight must be float (got {type(weight).__name__} {weight})")
        self.weight = weight
        if type(children) is not list:
            raise PromptParser.ParsingException(f"cannot make Attention with non-list of children (got {type(children)})")
        assert(type(children) is list)
        self.children = children
        #print(f"A: requested attention '{children}' to {weight}")

    def __repr__(self):
        return f"Attention:{self.children} * {self.weight}"
    def __eq__(self, other):
        return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment

class CrossAttentionControlledFragment(BaseFragment):
    pass

class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
    """
    A Cross-Attention Controlled ('prompt2prompt') fragment, for use inside a Prompt, Attention, or FlattenedPrompt.
    Representing an "original" word sequence that supplies feature vectors for an initial diffusion operation, and an
    "edited" word sequence, to which the attention maps produced by the "original" word sequence are applied. Intuitively,
    the result should be an "edited" image that looks like the "original" image with concepts swapped.

    eg "a cat sitting on a car" (original) -> "a smiling dog sitting on a car" (edited): the edited image should look
    almost exactly the same as the original, but with a smiling dog rendered in place of the cat. The
    CrossAttentionControlSubstitute object representing this swap may be confined to the tokens being swapped:
        CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')])
    or it may represent a larger portion of the token sequence:
        CrossAttentionControlSubstitute(original=[Fragment('a cat sitting on a car')],
                                        edited=[Fragment('a smiling dog sitting on a car')])

    In either case expect it to be embedded in a Prompt or FlattenedPrompt:
        FlattenedPrompt([
                        Fragment('a'),
                        CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')]),
                        Fragment('sitting on a car')
                        ])
    """
    def __init__(self, original: list, edited: list, options: dict=None):
        self.original = original if len(original)>0 else [Fragment('')]
        self.edited = edited if len(edited)>0 else [Fragment('')]

        default_options = {
            's_start': 0.0,
            's_end': 0.2062994740159002,  # ~= shape_freedom=0.5
            't_start': 0.1,
            't_end': 1.0
        }
        merged_options = default_options
        if options is not None:
            shape_freedom = options.pop('shape_freedom', None)
            if shape_freedom is not None:
                # high shape freedom = SD can do what it wants with the shape of the object
                # high shape freedom => s_end = 0
                # low shape freedom => s_end = 1
                # shape freedom is in a "linear" space, while noticeable changes to s_end are typically closer around 0,
                # and there is very little perceptible difference as s_end increases above 0.5
                # so for shape_freedom = 0.5 we probably want s_end to be 0.2
                # -> cube root and subtract from 1.0
                merged_options['s_end'] = 1.0 - shape_freedom ** (1. / 3.)
                #print('converted shape_freedom argument to', merged_options)
            merged_options.update(options)

        self.options = merged_options

    def __repr__(self):
        return f"CrossAttentionControlSubstitute:({self.original}->{self.edited} ({self.options})"
    def __eq__(self, other):
        return type(other) is CrossAttentionControlSubstitute \
               and other.original == self.original \
               and other.edited == self.edited \
               and other.options == self.options
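The s_end default above is exactly the shape_freedom conversion applied to 0.5, which makes a quick sanity check of the cube-root formula:

    # s_end = 1.0 - shape_freedom ** (1/3)
    # shape_freedom = 0.5  ->  1.0 - 0.5 ** (1/3)
    #                       =  1.0 - 0.7937005259840998
    #                       =  0.2062994740159002    # the default s_end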
class CrossAttentionControlAppend(CrossAttentionControlledFragment):
    def __init__(self, fragment: Fragment):
        self.fragment = fragment
    def __repr__(self):
        return "CrossAttentionControlAppend:", self.fragment
    def __eq__(self, other):
        return type(other) is CrossAttentionControlAppend \
               and other.fragment == self.fragment



class Conjunction():
    """
    Storage for one or more Prompts or Blends, each of which is to be separately diffused and then the results merged
    by weighted sum in latent space.
    """
    def __init__(self, prompts: list, weights: list = None):
        # force everything to be a Prompt
        #print("making conjunction with", prompts, "types", [type(p).__name__ for p in prompts])
        self.prompts = [x if (type(x) is Prompt
                              or type(x) is Blend
                              or type(x) is FlattenedPrompt)
                        else Prompt(x) for x in prompts]
        self.weights = [1.0]*len(self.prompts) if (weights is None or len(weights)==0) else list(weights)
        if len(self.weights) != len(self.prompts):
            raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}")
        self.type = 'AND'

    def __repr__(self):
        return f"Conjunction:{self.prompts} | weights {self.weights}"
    def __eq__(self, other):
        return type(other) is Conjunction \
               and other.prompts == self.prompts \
               and other.weights == self.weights


class Blend():
    """
    Stores a Blend of multiple Prompts. To apply, build feature vectors for each of the child Prompts and then perform a
    weighted blend of the feature vectors to produce a single feature vector that is effectively a lerp between the
    Prompts.
    """
    def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
        #print("making Blend with prompts", prompts, "and weights", weights)
        weights = [1.0]*len(prompts) if (weights is None or len(weights)==0) else list(weights)
        if len(prompts) != len(weights):
            raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
        for p in prompts:
            if type(p) is not Prompt and type(p) is not FlattenedPrompt:
                raise(PromptParser.ParsingException(f"{type(p)} cannot be added to a Blend, only Prompts or FlattenedPrompts"))
            for f in p.children:
                if isinstance(f, CrossAttentionControlSubstitute):
                    raise(PromptParser.ParsingException(f"while parsing Blend: sorry, you cannot do .swap() as part of a Blend"))

        # upcast all lists to Prompt objects
        self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt)
                        else Prompt(x)
                        for x in prompts]
        self.prompts = prompts
        self.weights = weights
        self.normalize_weights = normalize_weights

    @property
    def wants_cross_attention_control(self):
        # blends cannot cross-attention control
        return False


    def __repr__(self):
        return f"Blend:{self.prompts} | weights {' ' if self.normalize_weights else '(non-normalized) '}{self.weights}"
    def __eq__(self, other):
        return other.__repr__() == self.__repr__()


class PromptParser():

    class ParsingException(Exception):
        pass

    class UnrecognizedOperatorException(ParsingException):
        def __init__(self, operator:str):
            super().__init__("Unrecognized operator: " + operator)

    def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):

        self.conjunction, self.prompt = build_parser_syntax(attention_plus_base, attention_minus_base)


    def parse_conjunction(self, prompt: str) -> Conjunction:
        '''
        :param prompt: The prompt string to parse
        :return: a Conjunction representing the parsed results.
        '''
        #print(f"!!parsing '{prompt}'")

        if len(prompt.strip()) == 0:
            return Conjunction(prompts=[FlattenedPrompt([('', 1.0)])], weights=[1.0])

        root = self.conjunction.parse_string(prompt)
        #print(f"'{prompt}' parsed to root", root)
        #fused = fuse_fragments(parts)
        #print("fused to", fused)

        return self.flatten(root[0])

    def parse_legacy_blend(self, text: str, skip_normalize: bool = False) -> Optional[Blend]:
        weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
        if len(weighted_subprompts) <= 1:
            return None
        strings = [x[0] for x in weighted_subprompts]
        weights = [x[1] for x in weighted_subprompts]

        parsed_conjunctions = [self.parse_conjunction(x) for x in strings]
        flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]

        return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)
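As an illustration of the legacy syntax this method accepts (input string made up): a colon-weighted prompt splits into weighted subprompts, which come back wrapped in a Blend:

    pp = PromptParser()
    blend = pp.parse_legacy_blend("mountain:2 lake:1")
    # split_weighted_subprompts -> [('mountain', 0.667), ('lake', 0.333)] after normalization
    # -> Blend of two FlattenedPrompts with those weights (normalize_weights=True)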
    def flatten(self, root: Conjunction, verbose = False) -> Conjunction:
        """
        Flattening a Conjunction traverses all of the nested tree-like structures in each of its Prompts or Blends,
        producing from each of these walks a linear sequence of Fragment or CrossAttentionControlSubstitute objects
        that can be readily tokenized without the need to walk a complex tree structure.

        :param root: The Conjunction to flatten.
        :return: A Conjunction containing the result of flattening each of the prompts in the passed-in root.
        """

        def fuse_fragments(items):
            # print("fusing fragments in ", items)
            result = []
            for x in items:
                if type(x) is CrossAttentionControlSubstitute:
                    original_fused = fuse_fragments(x.original)
                    edited_fused = fuse_fragments(x.edited)
                    result.append(CrossAttentionControlSubstitute(original_fused, edited_fused, options=x.options))
                else:
                    last_weight = result[-1].weight \
                        if (len(result) > 0 and not issubclass(type(result[-1]), CrossAttentionControlledFragment)) \
                        else None
                    this_text = x.text
                    this_weight = x.weight
                    if last_weight is not None and last_weight == this_weight:
                        last_text = result[-1].text
                        result[-1] = Fragment(last_text + ' ' + this_text, last_weight)
                    else:
                        result.append(x)
            return result

        def flatten_internal(node, weight_scale, results, prefix):
            verbose and print(prefix + "flattening", node, "...")
            if type(node) is pp.ParseResults or type(node) is list:
                for x in node:
                    results = flatten_internal(x, weight_scale, results, prefix+' pr ')
                #print(prefix, " ParseResults expanded, results is now", results)
            elif type(node) is Attention:
                # if node.weight < 1:
                # todo: inject a blend when flattening attention with weight <1"
                for index,c in enumerate(node.children):
                    results = flatten_internal(c, weight_scale * node.weight, results, prefix + f" att{index} ")
            elif type(node) is Fragment:
                results += [Fragment(node.text, node.weight*weight_scale)]
            elif type(node) is CrossAttentionControlSubstitute:
                original = flatten_internal(node.original, weight_scale, [], prefix + ' CAo ')
                edited = flatten_internal(node.edited, weight_scale, [], prefix + ' CAe ')
                results += [CrossAttentionControlSubstitute(original, edited, options=node.options)]
            elif type(node) is Blend:
                flattened_subprompts = []
                #print(" flattening blend with prompts", node.prompts, "weights", node.weights)
                for prompt in node.prompts:
                    # prompt is a list
                    flattened_subprompts = flatten_internal(prompt, weight_scale, flattened_subprompts, prefix+'B ')
                results += [Blend(prompts=flattened_subprompts, weights=node.weights, normalize_weights=node.normalize_weights)]
            elif type(node) is Prompt:
                #print(prefix + "about to flatten Prompt with children", node.children)
                flattened_prompt = []
                for child in node.children:
                    flattened_prompt = flatten_internal(child, weight_scale, flattened_prompt, prefix+'P ')
                results += [FlattenedPrompt(parts=fuse_fragments(flattened_prompt))]
                #print(prefix + "after flattening Prompt, results is", results)
            else:
                raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
            verbose and print(prefix + "-> after flattening", type(node).__name__, "results is", results)
            return results

        verbose and print("flattening", root)

        flattened_parts = []
        for part in root.prompts:
            flattened_parts += flatten_internal(part, 1.0, [], ' C| ')

        verbose and print("flattened to", flattened_parts)

        weights = root.weights
        return Conjunction(flattened_parts, weights)
def build_parser_syntax(attention_plus_base: float, attention_minus_base: float):
    def make_operator_object(x):
        #print('making operator for', x)
        target = x[0]
        operator = x[1]
        arguments = x[2]
        if operator == '.attend':
            weight_raw = arguments[0]
            weight = 1.0
            if type(weight_raw) is float or type(weight_raw) is int:
                weight = weight_raw
            elif type(weight_raw) is str:
                base = attention_plus_base if weight_raw[0] == '+' else attention_minus_base
                weight = pow(base, len(weight_raw))
            return Attention(weight=weight, children=[x for x in x[0]])
        elif operator == '.swap':
            return CrossAttentionControlSubstitute(target, arguments, x.as_dict())
        elif operator == '.blend':
            prompts = [Prompt(p) for p in x[0]]
            weights_raw = x[2]
            normalize_weights = True
            if len(weights_raw) > 0 and weights_raw[-1][0] == 'no_normalize':
                normalize_weights = False
                weights_raw = weights_raw[:-1]
            weights = [float(w[0]) for w in weights_raw]
            return Blend(prompts=prompts, weights=weights, normalize_weights=normalize_weights)
        elif operator == '.and' or operator == '.add':
            prompts = [Prompt(p) for p in x[0]]
            weights = [float(w[0]) for w in x[2]]
            return Conjunction(prompts=prompts, weights=weights)

        raise PromptParser.UnrecognizedOperatorException(operator)

    def parse_fragment_str(x, expression: pp.ParseExpression, in_quotes: bool = False, in_parens: bool = False):
        #print(f"parsing fragment string for {x}")
        fragment_string = x[0]
        if len(fragment_string.strip()) == 0:
            return Fragment('')

        if in_quotes:
            # escape unescaped quotes
            fragment_string = fragment_string.replace('"', '\\"')

        try:
            result = (expression + pp.StringEnd()).parse_string(fragment_string)
            #print("parsed to", result)
            return result
        except pp.ParseException as e:
            #print("parse_fragment_str couldn't parse prompt string:", e)
            raise

    # meaningful symbols
    lparen = pp.Literal("(").suppress()
    rparen = pp.Literal(")").suppress()
    quote = pp.Literal('"').suppress()
    comma = pp.Literal(",").suppress()
    dot = pp.Literal(".").suppress()
    equals = pp.Literal("=").suppress()

    escaped_lparen = pp.Literal('\\(')
    escaped_rparen = pp.Literal('\\)')
    escaped_quote = pp.Literal('\\"')
    escaped_comma = pp.Literal('\\,')
    escaped_dot = pp.Literal('\\.')
    escaped_plus = pp.Literal('\\+')
    escaped_minus = pp.Literal('\\-')
    escaped_equals = pp.Literal('\\=')

    syntactic_symbols = {
        '(': escaped_lparen,
        ')': escaped_rparen,
        '"': escaped_quote,
        ',': escaped_comma,
        '.': escaped_dot,
        '+': escaped_plus,
        '-': escaped_minus,
        '=': escaped_equals,
    }
    syntactic_chars = "".join(syntactic_symbols.keys())

    # accepts int or float notation, always maps to float
    number = pp.pyparsing_common.real | \
             pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float))

    # for options
    keyword = pp.Word(pp.alphanums + '_')

    # a word that absolutely does not contain any meaningful syntax
    non_syntax_word = pp.Combine(pp.OneOrMore(pp.MatchFirst([
        pp.Or(syntactic_symbols.values()),
        pp.one_of(['-', '+']) + pp.NotAny(pp.White() | pp.Char(syntactic_chars) | pp.StringEnd()),
        # build character-by-character
        pp.CharsNotIn(string.whitespace + syntactic_chars, exact=1)
    ])))
    non_syntax_word.set_parse_action(lambda x: [Fragment(t) for t in x])
    non_syntax_word.set_name('non_syntax_word')
    non_syntax_word.set_debug(False)

    # a word that can contain any character at all - greedily consumes syntax, so use with care
    free_word = pp.CharsNotIn(string.whitespace).set_parse_action(lambda x: Fragment(x[0]))
    free_word.set_name('free_word')
    free_word.set_debug(False)


    # ok here we go. forward declare some things..
    attention = pp.Forward()
    cross_attention_substitute = pp.Forward()
    parenthesized_fragment = pp.Forward()
    quoted_fragment = pp.Forward()

    # the types of things that can go into a fragment, consisting of syntax-full and/or strictly syntax-free components
    fragment_part_expressions = [
        attention,
        cross_attention_substitute,
        parenthesized_fragment,
        quoted_fragment,
        non_syntax_word
    ]
    # a fragment that is permitted to contain commas
    fragment_including_commas = pp.ZeroOrMore(pp.MatchFirst(
        fragment_part_expressions + [
            pp.Literal(',').set_parse_action(lambda x: Fragment(x[0]))
        ]
    ))
    # a fragment that is not permitted to contain commas
    fragment_excluding_commas = pp.ZeroOrMore(pp.MatchFirst(
        fragment_part_expressions
    ))

    # a fragment in double quotes (may be nested)
    quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
    quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, fragment_including_commas, in_quotes=True))

    # a fragment inside parentheses (may be nested)
    parenthesized_fragment << (lparen + fragment_including_commas + rparen)
    parenthesized_fragment.set_name('parenthesized_fragment')
    parenthesized_fragment.set_debug(False)

    # a string of the form (<keyword>=<float|keyword> | <float> | <keyword>) where keyword is alphanumeric + '_'
    option = pp.Group(pp.MatchFirst([
        keyword + equals + (number | keyword),  # option=value
        number.copy().set_parse_action(pp.token_map(str)),  # weight
        keyword  # flag
    ]))
    # options for an operator, eg "s_start=0.1, 0.3, no_normalize"
    options = pp.Dict(pp.Optional(pp.delimited_list(option)))
    options.set_name('options')
    options.set_debug(False)

    # a fragment which can be used as the target for an operator - either quoted or in parentheses, or a bare vanilla word
    potential_operator_target = (quoted_fragment | parenthesized_fragment | non_syntax_word)

    # a fragment whose weight has been increased or decreased by a given amount
    attention_weight_operator = pp.Word('+') | pp.Word('-') | number
    attention_explicit = (
        pp.Group(potential_operator_target)
        + pp.Literal('.attend')
        + lparen
        + pp.Group(attention_weight_operator)
        + rparen
    )
    attention_explicit.set_parse_action(make_operator_object)
    attention_implicit = (
        pp.Group(potential_operator_target)
        + pp.NotAny(pp.White())  # do not permit whitespace between term and operator
        + pp.Group(attention_weight_operator)
    )
    attention_implicit.set_parse_action(lambda x: make_operator_object([x[0], '.attend', x[1]]))
    attention << (attention_explicit | attention_implicit)
    attention.set_name('attention')
    attention.set_debug(False)

    # cross-attention control by swapping one fragment for another
    cross_attention_substitute << (
        pp.Group(potential_operator_target).set_name('ca-target').set_debug(False)
        + pp.Literal(".swap").set_name('ca-operator').set_debug(False)
        + lparen
        + pp.Group(fragment_excluding_commas).set_name('ca-replacement').set_debug(False)
        + pp.Optional(comma + options).set_name('ca-options').set_debug(False)
        + rparen
    )
    cross_attention_substitute.set_name('cross_attention_substitute')
    cross_attention_substitute.set_debug(False)
    cross_attention_substitute.set_parse_action(make_operator_object)


    # an entire self-contained prompt, which can be used in a Blend or Conjunction
    prompt = pp.ZeroOrMore(pp.MatchFirst([
        cross_attention_substitute,
        attention,
        quoted_fragment,
        parenthesized_fragment,
        free_word,
        pp.White().suppress()
    ]))
    quoted_prompt = quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, prompt, in_quotes=True))


    # a blend/lerp between the feature vectors for two or more prompts
    blend = (
        lparen
        + pp.Group(pp.delimited_list(pp.Group(potential_operator_target | quoted_prompt), min=1)).set_name('bl-target').set_debug(False)
        + rparen
        + pp.Literal(".blend").set_name('bl-operator').set_debug(False)
        + lparen
        + pp.Group(options).set_name('bl-options').set_debug(False)
        + rparen
    )
    blend.set_name('blend')
    blend.set_debug(False)
    blend.set_parse_action(make_operator_object)

    # an operator to direct stable diffusion to step multiple times, once for each target, and then add the results together with different weights
    explicit_conjunction = (
        lparen
        + pp.Group(pp.delimited_list(pp.Group(potential_operator_target | quoted_prompt), min=1)).set_name('cj-target').set_debug(False)
        + rparen
        + pp.one_of([".and", ".add"]).set_name('cj-operator').set_debug(False)
        + lparen
        + pp.Group(options).set_name('cj-options').set_debug(False)
        + rparen
    )
    explicit_conjunction.set_name('explicit_conjunction')
    explicit_conjunction.set_debug(False)
    explicit_conjunction.set_parse_action(make_operator_object)

    # by default a prompt consists of a Conjunction with a single term
    implicit_conjunction = (blend | pp.Group(prompt)) + pp.StringEnd()
    implicit_conjunction.set_parse_action(lambda x: Conjunction(x))

    conjunction = (explicit_conjunction | implicit_conjunction)

    return conjunction, prompt
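A few prompt strings the grammar above accepts, and the structures they parse to (illustrative, derived from the operator definitions in this function):

    parser = PromptParser()
    parser.parse_conjunction('a (big)1.4 dog')          # Attention: 'big' weighted 1.4
    parser.parse_conjunction('a big++ dog')             # implicit .attend: 1.1 ** 2 = 1.21
    parser.parse_conjunction('a cat.swap(dog, s_start=0.5)')              # CrossAttentionControlSubstitute
    parser.parse_conjunction('("in winter", "in summer").blend(1, 0.5)')  # Blend of two prompts
    parser.parse_conjunction('("a cat", "a dog").and(1.0, 0.5)')          # explicit Conjunction with weights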
def split_weighted_subprompts(text, skip_normalize=False)->list:
    """
    Legacy blend parsing.

    grabs all text up to the first occurrence of ':'
    uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
    if ':' has no value defined, defaults to 1.0
    repeats until no text remaining
    """
    prompt_parser = re.compile("""
            (?P<prompt>         # capture group for 'prompt'
            (?:\\\:|[^:])+      # match one or more non ':' characters or escaped colons '\:'
            )                   # end 'prompt'
            (?:                 # non-capture group
            :+                  # match one or more ':' characters
            (?P<weight>         # capture group for 'weight'
            -?\d+(?:\.\d+)?     # match positive or negative integer or decimal number
            )?                  # end weight capture group, make optional
            \s*                 # strip spaces after weight
            |                   # OR
            $                   # else, if no ':' then match end of line
            )                   # end non-capture group
            """, re.VERBOSE)
    parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
        match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
    if skip_normalize:
        return parsed_prompts
    weight_sum = sum(map(lambda x: x[1], parsed_prompts))
    if weight_sum == 0:
        print(
            "* Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
        equal_weight = 1 / max(len(parsed_prompts), 1)
        return [(x[0], equal_weight) for x in parsed_prompts]
    return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
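A worked example of the normalization at the end of split_weighted_subprompts (hypothetical input string):

    split_weighted_subprompts("mountain:2 lake:1 forest:1")
    # parsed: [('mountain', 2.0), ('lake', 1.0), ('forest', 1.0)]
    # weight_sum = 4.0 -> [('mountain', 0.5), ('lake', 0.25), ('forest', 0.25)]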