mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
various prompting fixes
This commit is contained in:
parent
80f2cfe3e3
commit
ced9c83e96
@ -1,12 +1,9 @@
|
||||
'''
|
||||
This module handles the generation of the conditioning tensors, including management of
|
||||
weighted subprompts.
|
||||
This module handles the generation of the conditioning tensors.
|
||||
|
||||
Useful function exports:
|
||||
|
||||
get_uc_and_c_and_ec() get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control
|
||||
split_weighted_subpromopts() split subprompts, normalize and weight them
|
||||
log_tokenization() print out colour-coded tokens and warn if truncated
|
||||
|
||||
'''
|
||||
import re
|
||||
@ -16,7 +13,7 @@ from typing import Union
|
||||
import torch
|
||||
|
||||
from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
|
||||
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend, Fragment
|
||||
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment, log_tokenization
|
||||
from ..models.diffusion.cross_attention_control import CrossAttentionControl
|
||||
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
|
||||
@ -58,8 +55,11 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
|
||||
if type(parsed_prompt) is Blend:
|
||||
blend: Blend = parsed_prompt
|
||||
embeddings_to_blend = None
|
||||
for flattened_prompt in blend.prompts:
|
||||
this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
|
||||
for i,flattened_prompt in enumerate(blend.prompts):
|
||||
this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
|
||||
flattened_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label=f"(blend part {i+1}, weight={blend.weights[i]})" )
|
||||
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
|
||||
(embeddings_to_blend, this_embedding))
|
||||
conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
|
||||
@ -103,14 +103,20 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
|
||||
edit_options.append(None)
|
||||
original_token_count += count
|
||||
edited_token_count += count
|
||||
original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt, log_tokens=log_tokens)
|
||||
original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
|
||||
original_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label="(.swap originals)")
|
||||
# naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
|
||||
# subsequent tokens when there is >1 edit and earlier edits change the total token count.
|
||||
# eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
|
||||
# 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
|
||||
# token 'smiling' in the inactive 'cat' edit.
|
||||
# todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
|
||||
edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt, log_tokens=log_tokens)
|
||||
edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
|
||||
edited_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label="(.swap replacements)")
|
||||
|
||||
conditioning = original_embeddings
|
||||
edited_conditioning = edited_embeddings
|
||||
@ -121,9 +127,15 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
|
||||
edit_options = edit_options
|
||||
)
|
||||
else:
|
||||
conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
|
||||
conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
|
||||
flattened_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label="(prompt)")
|
||||
|
||||
unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt, log_tokens=log_tokens)
|
||||
unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
|
||||
parsed_negative_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label="(unconditioning)")
|
||||
if isinstance(conditioning, dict):
|
||||
# hybrid conditioning is in play
|
||||
unconditioning, conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
|
||||
@ -144,26 +156,15 @@ def build_token_edit_opcodes(original_tokens, edited_tokens):
|
||||
|
||||
return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
|
||||
|
||||
def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False):
|
||||
def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False, log_display_label: str=None):
|
||||
if type(flattened_prompt) is not FlattenedPrompt:
|
||||
raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
|
||||
fragments = [x.text for x in flattened_prompt.children]
|
||||
weights = [x.weight for x in flattened_prompt.children]
|
||||
embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
|
||||
if not flattened_prompt.is_empty and log_tokens:
|
||||
start_token = model.cond_stage_model.tokenizer.bos_token_id
|
||||
end_token = model.cond_stage_model.tokenizer.eos_token_id
|
||||
tokens_list = tokens[0].tolist()
|
||||
if tokens_list[0] == start_token:
|
||||
tokens_list[0] = '<start>'
|
||||
try:
|
||||
first_end_token_index = tokens_list.index(end_token)
|
||||
tokens_list[first_end_token_index] = '<end>'
|
||||
tokens_list = tokens_list[:first_end_token_index+1]
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
print(f">> Prompt fragments {fragments}, tokenized to \n{tokens_list}")
|
||||
if log_tokens:
|
||||
text = " ".join(fragments)
|
||||
log_tokenization(text, model, display_label=log_display_label)
|
||||
|
||||
return embeddings, tokens
|
||||
|
||||
|
@ -2,6 +2,19 @@ import string
|
||||
from typing import Union, Optional
|
||||
import re
|
||||
import pyparsing as pp
|
||||
'''
|
||||
This module parses prompt strings and produces tree-like structures that can be used generate and control the conditioning tensors.
|
||||
weighted subprompts.
|
||||
|
||||
Useful class exports:
|
||||
|
||||
PromptParser - parses prompts
|
||||
|
||||
Useful function exports:
|
||||
|
||||
split_weighted_subpromopts() split subprompts, normalize and weight them
|
||||
log_tokenization() print out colour-coded tokens and warn if truncated
|
||||
'''
|
||||
|
||||
class Prompt():
|
||||
"""
|
||||
@ -205,12 +218,17 @@ class Blend():
|
||||
#print("making Blend with prompts", prompts, "and weights", weights)
|
||||
if len(prompts) != len(weights):
|
||||
raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
|
||||
for c in prompts:
|
||||
if type(c) is not Prompt and type(c) is not FlattenedPrompt:
|
||||
raise(PromptParser.ParsingException(f"{type(c)} cannot be added to a Blend, only Prompts or FlattenedPrompts"))
|
||||
for p in prompts:
|
||||
if type(p) is not Prompt and type(p) is not FlattenedPrompt:
|
||||
raise(PromptParser.ParsingException(f"{type(p)} cannot be added to a Blend, only Prompts or FlattenedPrompts"))
|
||||
for f in p.children:
|
||||
if isinstance(f, CrossAttentionControlSubstitute):
|
||||
raise(PromptParser.ParsingException(f"while parsing Blend: sorry, you cannot do .swap() as part of a Blend"))
|
||||
|
||||
# upcast all lists to Prompt objects
|
||||
self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt)
|
||||
else Prompt(x) for x in prompts]
|
||||
else Prompt(x)
|
||||
for x in prompts]
|
||||
self.prompts = prompts
|
||||
self.weights = weights
|
||||
self.normalize_weights = normalize_weights
|
||||
@ -662,9 +680,7 @@ def split_weighted_subprompts(text, skip_normalize=False)->list:
|
||||
# shows how the prompt is tokenized
|
||||
# usually tokens have '</w>' to indicate end-of-word,
|
||||
# but for readability it has been replaced with ' '
|
||||
def log_tokenization(text, model, log=False, weight=1):
|
||||
if not log:
|
||||
return
|
||||
def log_tokenization(text, model, display_label=None):
|
||||
tokens = model.cond_stage_model.tokenizer._tokenize(text)
|
||||
tokenized = ""
|
||||
discarded = ""
|
||||
@ -679,7 +695,7 @@ def log_tokenization(text, model, log=False, weight=1):
|
||||
usedTokens += 1
|
||||
else: # over max token length
|
||||
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
||||
print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m")
|
||||
print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
|
||||
if discarded != "":
|
||||
print(
|
||||
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
|
||||
|
@ -157,10 +157,11 @@ class InvokeAIDiffuserComponent:
|
||||
# percent_through will never reach 1.0 (but this is intended)
|
||||
return float(step_index) / float(self.cross_attention_control_context.step_count)
|
||||
# find the best possible index of the current sigma in the sigma sequence
|
||||
sigma_index = torch.nonzero(self.model.sigmas <= sigma)[-1]
|
||||
smaller_sigmas = torch.nonzero(self.model.sigmas <= sigma)
|
||||
sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0
|
||||
# flip because sigmas[0] is for the fully denoised image
|
||||
# percent_through must be <1
|
||||
return 1.0 - float(sigma_index.item() + 1) / float(self.model.sigmas.shape[0])
|
||||
return 1.0 - float(sigma_index + 1) / float(self.model.sigmas.shape[0])
|
||||
# print('estimated percent_through', percent_through, 'from sigma', sigma.item())
|
||||
|
||||
|
||||
|
@ -10,6 +10,9 @@ import warnings
|
||||
import time
|
||||
import traceback
|
||||
import yaml
|
||||
|
||||
from ldm.invoke.prompt_parser import PromptParser
|
||||
|
||||
sys.path.append('.') # corrects a weird problem on Macs
|
||||
from ldm.invoke.readline import get_completer
|
||||
from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_from_png
|
||||
@ -18,7 +21,7 @@ from ldm.invoke.image_util import make_grid
|
||||
from ldm.invoke.log import write_log
|
||||
from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
from pyparsing import ParseException
|
||||
import pyparsing
|
||||
|
||||
# global used in multiple functions (fix)
|
||||
infile = None
|
||||
@ -340,7 +343,7 @@ def main_loop(gen, opt):
|
||||
catch_interrupts=catch_ctrl_c,
|
||||
**vars(opt)
|
||||
)
|
||||
except ParseException as e:
|
||||
except (PromptParser.ParsingException, pyparsing.ParseException) as e:
|
||||
print('** An error occurred while processing your prompt **')
|
||||
print(f'** {str(e)} **')
|
||||
elif operation == 'postprocess':
|
||||
|
Loading…
Reference in New Issue
Block a user