Merge branch 'development' of github.com:invoke-ai/InvokeAI into development

commit e8aba99c92
Lincoln Stein 2022-10-31 09:35:21 -04:00
4 changed files with 59 additions and 38 deletions

ldm/invoke/conditioning.py

@@ -1,12 +1,9 @@
 '''
-This module handles the generation of the conditioning tensors, including management of
-weighted subprompts.
+This module handles the generation of the conditioning tensors.

 Useful function exports:

 get_uc_and_c_and_ec()           get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control
-split_weighted_subpromopts()    split subprompts, normalize and weight them
-log_tokenization()              print out colour-coded tokens and warn if truncated

 '''
 import re
@@ -16,7 +13,7 @@ from typing import Union

 import torch

 from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
-    CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend, Fragment
+    CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment, log_tokenization
 from ..models.diffusion.cross_attention_control import CrossAttentionControl
 from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
@@ -58,8 +55,11 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
     if type(parsed_prompt) is Blend:
         blend: Blend = parsed_prompt
         embeddings_to_blend = None
-        for flattened_prompt in blend.prompts:
-            this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
+        for i, flattened_prompt in enumerate(blend.prompts):
+            this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
+                                                                                 flattened_prompt,
+                                                                                 log_tokens=log_tokens,
+                                                                                 log_display_label=f"(blend part {i+1}, weight={blend.weights[i]})")
             embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
                 (embeddings_to_blend, this_embedding))
         conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
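The blend path above stacks one embedding per sub-prompt with torch.cat and hands the stack plus blend.weights to WeightedFrozenCLIPEmbedder.apply_embedding_weights. A minimal sketch of the weighted-combination idea; the normalization step shown is an assumption here (the authoritative behaviour, including the normalize_weights flag, lives in apply_embedding_weights):

import torch

def blend_embeddings(embeddings: torch.Tensor, weights: list, normalize: bool = True) -> torch.Tensor:
    # embeddings: (n_prompts, 77, 768) - one CLIP embedding stack per sub-prompt
    w = torch.tensor(weights, dtype=embeddings.dtype)
    if normalize:
        w = w / w.sum()  # assumption: blend weights are normalized to sum to 1
    # weighted sum over the prompt axis -> (77, 768)
    return (embeddings * w.view(-1, 1, 1)).sum(dim=0)

stack = torch.randn(2, 77, 768)          # stand-in for the torch.cat'd embeddings
blended = blend_embeddings(stack, [3.0, 1.0])
assert blended.shape == (77, 768)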
@@ -103,14 +103,20 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
                 edit_options.append(None)
                 original_token_count += count
                 edited_token_count += count
-        original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt, log_tokens=log_tokens)
+        original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
+                                                                                                original_prompt,
+                                                                                                log_tokens=log_tokens,
+                                                                                                log_display_label="(.swap originals)")
         # naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
         # subsequent tokens when there is >1 edit and earlier edits change the total token count.
         # eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
         # 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
         # token 'smiling' in the inactive 'cat' edit.
         # todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
-        edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt, log_tokens=log_tokens)
+        edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
+                                                                                            edited_prompt,
+                                                                                            log_tokens=log_tokens,
+                                                                                            log_display_label="(.swap replacements)")

         conditioning = original_embeddings
         edited_conditioning = edited_embeddings
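To make the comment in this hunk concrete, the token-shift problem it describes looks like this (word-level tokens for readability; the real sequences are CLIP BPE token ids):

# prompt: "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)"
# original: ["a", "cat",            "eating", "a", "hotdog"]   # "hotdog" at index 4
# edited:   ["a", "smiling", "dog", "eating", "a", "pizza"]    # "pizza" shifted to index 5
#
# edited_embeddings is built once from the fully edited prompt, so the "pizza" vector
# is computed at position 5 even while the cat->smiling-dog edit is inactive; a lone
# hotdog->pizza edit would have produced it at position 4.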
@@ -121,9 +127,15 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
            edit_options = edit_options
        )
    else:
-        conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
+        conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
+                                                                           flattened_prompt,
+                                                                           log_tokens=log_tokens,
+                                                                           log_display_label="(prompt)")

-    unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt, log_tokens=log_tokens)
+    unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
+                                                                         parsed_negative_prompt,
+                                                                         log_tokens=log_tokens,
+                                                                         log_display_label="(unconditioning)")
    if isinstance(conditioning, dict):
        # hybrid conditioning is in play
        unconditioning, conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
@@ -144,26 +156,15 @@ def build_token_edit_opcodes(original_tokens, edited_tokens):
     return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()

-def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False):
+def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False, log_display_label: str=None):
     if type(flattened_prompt) is not FlattenedPrompt:
         raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
     fragments = [x.text for x in flattened_prompt.children]
     weights = [x.weight for x in flattened_prompt.children]
     embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
-    if not flattened_prompt.is_empty and log_tokens:
-        start_token = model.cond_stage_model.tokenizer.bos_token_id
-        end_token = model.cond_stage_model.tokenizer.eos_token_id
-        tokens_list = tokens[0].tolist()
-        if tokens_list[0] == start_token:
-            tokens_list[0] = '<start>'
-        try:
-            first_end_token_index = tokens_list.index(end_token)
-            tokens_list[first_end_token_index] = '<end>'
-            tokens_list = tokens_list[:first_end_token_index+1]
-        except ValueError:
-            pass
-        print(f">> Prompt fragments {fragments}, tokenized to \n{tokens_list}")
+    if log_tokens:
+        text = " ".join(fragments)
+        log_tokenization(text, model, display_label=log_display_label)
     return embeddings, tokens
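For reference, build_token_edit_opcodes delegates straight to difflib; the opcodes for a one-token swap look like this (the token ids are made up for the example):

from difflib import SequenceMatcher

original_tokens = [320, 2368, 4168, 320, 8112]   # made-up ids: "a cat eating a hotdog"
edited_tokens   = [320, 1929, 4168, 320, 13437]  # made-up ids: "a dog eating a pizza"

for op, i1, i2, j1, j2 in SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes():
    print(op, original_tokens[i1:i2], '->', edited_tokens[j1:j2])
# equal   [320]        -> [320]
# replace [2368]       -> [1929]
# equal   [4168, 320]  -> [4168, 320]
# replace [8112]       -> [13437]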

ldm/invoke/prompt_parser.py

@@ -2,6 +2,19 @@ import string
 from typing import Union, Optional
 import re
 import pyparsing as pp

+'''
+This module parses prompt strings and produces tree-like structures that can be used
+to generate and control the conditioning tensors, including management of weighted subprompts.
+
+Useful class exports:
+
+PromptParser - parses prompts
+
+Useful function exports:
+
+split_weighted_subprompts()     split subprompts, normalize and weight them
+log_tokenization()              print out colour-coded tokens and warn if truncated
+'''

 class Prompt():
     """
@@ -205,12 +218,17 @@ class Blend():
         #print("making Blend with prompts", prompts, "and weights", weights)
         if len(prompts) != len(weights):
             raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
-        for c in prompts:
-            if type(c) is not Prompt and type(c) is not FlattenedPrompt:
-                raise(PromptParser.ParsingException(f"{type(c)} cannot be added to a Blend, only Prompts or FlattenedPrompts"))
+        for p in prompts:
+            if type(p) is not Prompt and type(p) is not FlattenedPrompt:
+                raise(PromptParser.ParsingException(f"{type(p)} cannot be added to a Blend, only Prompts or FlattenedPrompts"))
+            for f in p.children:
+                if isinstance(f, CrossAttentionControlSubstitute):
+                    raise(PromptParser.ParsingException(f"while parsing Blend: sorry, you cannot do .swap() as part of a Blend"))
+
         # upcast all lists to Prompt objects
         self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt)
-                        else Prompt(x) for x in prompts]
+                        else Prompt(x)
+                        for x in prompts]
         self.prompts = prompts
         self.weights = weights
         self.normalize_weights = normalize_weights
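The practical effect of the new child check, expressed in InvokeAI's prompt syntax (blend syntax as documented around this release; the example prompts are illustrative):

# still fine: two plain prompts blended 70/30
#   ("a cat playing with a ball", "a dog chasing a car").blend(0.7, 0.3)
#
# now rejected at parse time with PromptParser.ParsingException,
# because one blended child contains a CrossAttentionControlSubstitute:
#   ("a cat.swap(dog) playing with a ball", "a forest").blend(1, 1)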
@@ -662,9 +680,7 @@ def split_weighted_subprompts(text, skip_normalize=False)->list:

 # shows how the prompt is tokenized
 # usually tokens have '</w>' to indicate end-of-word,
 # but for readability it has been replaced with ' '
-def log_tokenization(text, model, log=False, weight=1):
-    if not log:
-        return
+def log_tokenization(text, model, display_label=None):
     tokens = model.cond_stage_model.tokenizer._tokenize(text)
     tokenized = ""
     discarded = ""
@@ -679,7 +695,7 @@ def log_tokenization(text, model, log=False, weight=1):
             usedTokens += 1
         else: # over max token length
             discarded = discarded + f"\x1b[0;3{s};40m{token}"
-    print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m")
+    print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
     if discarded != "":
         print(
             f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"

ldm/models/diffusion/shared_invokeai_diffusion.py

@@ -157,10 +157,11 @@ class InvokeAIDiffuserComponent:
             # percent_through will never reach 1.0 (but this is intended)
             return float(step_index) / float(self.cross_attention_control_context.step_count)
         # find the best possible index of the current sigma in the sigma sequence
-        sigma_index = torch.nonzero(self.model.sigmas <= sigma)[-1]
+        smaller_sigmas = torch.nonzero(self.model.sigmas <= sigma)
+        sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0
         # flip because sigmas[0] is for the fully denoised image
         # percent_through must be <1
-        return 1.0 - float(sigma_index.item() + 1) / float(self.model.sigmas.shape[0])
+        return 1.0 - float(sigma_index + 1) / float(self.model.sigmas.shape[0])
         # print('estimated percent_through', percent_through, 'from sigma', sigma.item())
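The guard above handles the case where sigma is smaller than every entry in self.model.sigmas: torch.nonzero then returns an empty tensor and the old [-1] index raised IndexError. A small numeric illustration with a made-up (ascending) sigma table:

import torch

sigmas = torch.tensor([0.03, 0.6, 3.1, 14.6])   # made-up table; low sigma = nearly denoised
sigma = torch.tensor(0.01)                       # smaller than every entry

smaller_sigmas = torch.nonzero(sigmas <= sigma)  # shape (0, 1): empty
sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0
percent_through = 1.0 - float(sigma_index + 1) / float(sigmas.shape[0])
print(percent_through)  # 0.75, instead of an IndexError from the old [-1] lookup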

scripts/invoke.py

@@ -10,6 +10,9 @@ import warnings
 import time
 import traceback
 import yaml
+
+from ldm.invoke.prompt_parser import PromptParser
+
 sys.path.append('.') # corrects a weird problem on Macs
 from ldm.invoke.readline import get_completer
 from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_from_png
@@ -18,7 +21,7 @@ from ldm.invoke.image_util import make_grid
 from ldm.invoke.log import write_log
 from omegaconf import OmegaConf
 from pathlib import Path
-from pyparsing import ParseException
+import pyparsing

 # global used in multiple functions (fix)
 infile = None
@@ -340,7 +343,7 @@ def main_loop(gen, opt):
                 catch_interrupts=catch_ctrl_c,
                 **vars(opt)
             )
-        except ParseException as e:
+        except (PromptParser.ParsingException, pyparsing.ParseException) as e:
             print('** An error occurred while processing your prompt **')
             print(f'** {str(e)} **')
     elif operation == 'postprocess':
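The broadened except clause matters because prompt failures now come from two layers: pyparsing.ParseException when the grammar itself fails to match, and PromptParser.ParsingException when a prompt parses but is semantically invalid (for example .swap() inside a blend, rejected in Blend's constructor above). A minimal runnable sketch of the first kind, using only stock pyparsing:

import pyparsing

try:
    # a quoted string that is never closed fails at the grammar level
    pyparsing.QuotedString('"').parseString('"unterminated')
except pyparsing.ParseException as e:
    print(f'** {str(e)} **')   # same handling path as main_loop's new except clause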