Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit e8aba99c92: Merge branch 'development' of github.com:invoke-ai/InvokeAI into development
@@ -1,12 +1,9 @@
 '''
-This module handles the generation of the conditioning tensors, including management of
-weighted subprompts.
+This module handles the generation of the conditioning tensors.
 
 Useful function exports:
 
 get_uc_and_c_and_ec()           get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control
-split_weighted_subpromopts()    split subprompts, normalize and weight them
-log_tokenization()              print out colour-coded tokens and warn if truncated
 
 '''
 import re
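
For orientation, the docstring above is where the module's purpose is stated: get_uc_and_c_and_ec() returns an unconditioned and a conditioned embedding (plus edited conditioning when cross-attention control is in play). Below is a minimal, illustrative sketch of why a sampler wants both embeddings, via classifier-free guidance; the function name, tensor shapes and guidance scale are assumptions for the example, not InvokeAI's actual API.

import torch

def cfg_combine(noise_uncond, noise_cond, guidance_scale):
    # classifier-free guidance: push the prediction away from the unconditioned
    # result and towards the conditioned one
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)

uc_noise = torch.randn(1, 4, 64, 64)   # noise predicted with the unconditioned embedding
c_noise = torch.randn(1, 4, 64, 64)    # noise predicted with the conditioned embedding
print(cfg_combine(uc_noise, c_noise, guidance_scale=7.5).shape)  # torch.Size([1, 4, 64, 64])
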
@@ -16,7 +13,7 @@ from typing import Union
 import torch
 
 from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
-    CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend, Fragment
+    CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment, log_tokenization
 from ..models.diffusion.cross_attention_control import CrossAttentionControl
 from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
@@ -58,8 +55,11 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
     if type(parsed_prompt) is Blend:
         blend: Blend = parsed_prompt
         embeddings_to_blend = None
-        for flattened_prompt in blend.prompts:
-            this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
+        for i,flattened_prompt in enumerate(blend.prompts):
+            this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
+                                                                                 flattened_prompt,
+                                                                                 log_tokens=log_tokens,
+                                                                                 log_display_label=f"(blend part {i+1}, weight={blend.weights[i]})" )
             embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
                 (embeddings_to_blend, this_embedding))
         conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
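
The blend branch above builds one embedding per blend part, concatenates them with torch.cat, and hands the stack plus blend.weights to WeightedFrozenCLIPEmbedder.apply_embedding_weights. Below is a rough, self-contained sketch of what weighted blending of stacked embeddings amounts to; blend_embeddings is a stand-in, and the real method's normalization details may differ.

import torch

def blend_embeddings(stacked, weights, normalize=True):
    # stacked: (num_parts, seq_len, dim), one row per blend part
    w = torch.tensor(weights, dtype=stacked.dtype).view(-1, 1, 1)
    if normalize:
        w = w / w.sum()
    return (stacked * w).sum(dim=0, keepdim=True)   # (1, seq_len, dim)

parts = torch.cat((torch.randn(1, 77, 768), torch.randn(1, 77, 768)))  # two blend parts
print(blend_embeddings(parts, [0.7, 0.3]).shape)    # torch.Size([1, 77, 768])
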
@@ -103,14 +103,20 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
                 edit_options.append(None)
                 original_token_count += count
                 edited_token_count += count
-        original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt, log_tokens=log_tokens)
+        original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
+                                                                                                original_prompt,
+                                                                                                log_tokens=log_tokens,
+                                                                                                log_display_label="(.swap originals)")
         # naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
         # subsequent tokens when there is >1 edit and earlier edits change the total token count.
         # eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
         # 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
         # token 'smiling' in the inactive 'cat' edit.
         # todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
-        edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt, log_tokens=log_tokens)
+        edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
+                                                                                            edited_prompt,
+                                                                                            log_tokens=log_tokens,
+                                                                                            log_display_label="(.swap replacements)")
 
         conditioning = original_embeddings
         edited_conditioning = edited_embeddings
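
The comment block above explains why a single edited_embeddings is naïve when a prompt carries more than one .swap(): an earlier edit that changes the token count shifts every later token. Since build_token_edit_opcodes (in a later hunk of this diff) is just difflib.SequenceMatcher over token sequences, the shift is easy to reproduce; the word lists below are illustrative stand-ins for real token ids.

from difflib import SequenceMatcher

original = ["a", "cat", "eating", "a", "hotdog"]
edited   = ["a", "smiling", "dog", "eating", "a", "pizza"]

for tag, i1, i2, j1, j2 in SequenceMatcher(None, original, edited).get_opcodes():
    print(tag, original[i1:i2], "->", edited[j1:j2])
# the 'hotdog' -> 'pizza' replacement is reported one position later in the edited
# sequence because the earlier 'cat' -> 'smiling dog' edit added a token
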
@@ -121,9 +127,15 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
             edit_options = edit_options
         )
     else:
-        conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
+        conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
+                                                                           flattened_prompt,
+                                                                           log_tokens=log_tokens,
+                                                                           log_display_label="(prompt)")
 
-    unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt, log_tokens=log_tokens)
+    unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
+                                                                         parsed_negative_prompt,
+                                                                         log_tokens=log_tokens,
+                                                                         log_display_label="(unconditioning)")
     if isinstance(conditioning, dict):
         # hybrid conditioning is in play
         unconditioning, conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
@@ -144,26 +156,15 @@ def build_token_edit_opcodes(original_tokens, edited_tokens):
 
     return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
 
-def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False):
+def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False, log_display_label: str=None):
     if type(flattened_prompt) is not FlattenedPrompt:
         raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
     fragments = [x.text for x in flattened_prompt.children]
     weights = [x.weight for x in flattened_prompt.children]
     embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
-    if not flattened_prompt.is_empty and log_tokens:
-        start_token = model.cond_stage_model.tokenizer.bos_token_id
-        end_token = model.cond_stage_model.tokenizer.eos_token_id
-        tokens_list = tokens[0].tolist()
-        if tokens_list[0] == start_token:
-            tokens_list[0] = '<start>'
-        try:
-            first_end_token_index = tokens_list.index(end_token)
-            tokens_list[first_end_token_index] = '<end>'
-            tokens_list = tokens_list[:first_end_token_index+1]
-        except ValueError:
-            pass
-
-        print(f">> Prompt fragments {fragments}, tokenized to \n{tokens_list}")
+    if log_tokens:
+        text = " ".join(fragments)
+        log_tokenization(text, model, display_label=log_display_label)
 
     return embeddings, tokens
 
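
build_embeddings_and_tokens_for_flattened_prompt splits a FlattenedPrompt's children into parallel text and weight lists before calling model.get_learned_conditioning, and now forwards the joined text to log_tokenization. Below is a tiny stand-in sketch of that split; Fragment here is a hypothetical dataclass, not the real prompt_parser.Fragment.

from dataclasses import dataclass

@dataclass
class Fragment:
    text: str
    weight: float = 1.0

children = [Fragment("a photo of"), Fragment("a castle", 1.3), Fragment("at night", 0.8)]
fragments = [x.text for x in children]
weights = [x.weight for x in children]
print(" ".join(fragments))   # the text that would be handed to log_tokenization
print(weights)               # [1.0, 1.3, 0.8]
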
@@ -2,6 +2,19 @@ import string
 from typing import Union, Optional
 import re
 import pyparsing as pp
+'''
+This module parses prompt strings and produces tree-like structures that can be used generate and control the conditioning tensors.
+weighted subprompts.
+
+Useful class exports:
+
+PromptParser - parses prompts
+
+Useful function exports:
+
+split_weighted_subpromopts()    split subprompts, normalize and weight them
+log_tokenization()              print out colour-coded tokens and warn if truncated
+'''
 
 class Prompt():
     """
@@ -205,12 +218,17 @@ class Blend():
         #print("making Blend with prompts", prompts, "and weights", weights)
         if len(prompts) != len(weights):
             raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
-        for c in prompts:
-            if type(c) is not Prompt and type(c) is not FlattenedPrompt:
-                raise(PromptParser.ParsingException(f"{type(c)} cannot be added to a Blend, only Prompts or FlattenedPrompts"))
+        for p in prompts:
+            if type(p) is not Prompt and type(p) is not FlattenedPrompt:
+                raise(PromptParser.ParsingException(f"{type(p)} cannot be added to a Blend, only Prompts or FlattenedPrompts"))
+            for f in p.children:
+                if isinstance(f, CrossAttentionControlSubstitute):
+                    raise(PromptParser.ParsingException(f"while parsing Blend: sorry, you cannot do .swap() as part of a Blend"))
+
         # upcast all lists to Prompt objects
         self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt)
-                        else Prompt(x) for x in prompts]
+                        else Prompt(x)
+                        for x in prompts]
         self.prompts = prompts
         self.weights = weights
         self.normalize_weights = normalize_weights
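
The loop added above enforces a new rule: a blend part whose children contain a cross-attention .swap() substitution is rejected with a ParsingException. Below is a stand-in sketch of that rule using placeholder classes rather than the real prompt_parser types.

class ParsingException(Exception):
    pass

class CrossAttentionControlSubstitute:
    pass

def validate_blend_parts(prompts):
    for p in prompts:
        for f in p:   # p stands in for a Prompt's .children list
            if isinstance(f, CrossAttentionControlSubstitute):
                raise ParsingException("you cannot do .swap() as part of a Blend")

validate_blend_parts([["a cat"], ["a dog"]])                               # accepted
try:
    validate_blend_parts([["a cat ", CrossAttentionControlSubstitute()]])  # rejected
except ParsingException as e:
    print(f"ParsingException: {e}")
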
@@ -662,9 +680,7 @@ def split_weighted_subprompts(text, skip_normalize=False)->list:
 # shows how the prompt is tokenized
 # usually tokens have '</w>' to indicate end-of-word,
 # but for readability it has been replaced with ' '
-def log_tokenization(text, model, log=False, weight=1):
-    if not log:
-        return
+def log_tokenization(text, model, display_label=None):
     tokens = model.cond_stage_model.tokenizer._tokenize(text)
     tokenized = ""
     discarded = ""
@@ -679,7 +695,7 @@ def log_tokenization(text, model, log=False, weight=1):
             usedTokens += 1
         else: # over max token length
             discarded = discarded + f"\x1b[0;3{s};40m{token}"
-    print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m")
+    print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
     if discarded != "":
         print(
             f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
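
log_tokenization now takes a display_label instead of the old log/weight arguments and prints colour-coded tokens, reporting anything past the model's token limit as discarded. Below is a self-contained sketch of the idea; the whitespace tokenizer, colour cycling and 75-token window are stand-ins for the real CLIP tokenizer behaviour.

def log_tokenization_sketch(text, display_label=None, max_tokens=75):
    tokens = text.split()              # stand-in for model.cond_stage_model.tokenizer._tokenize
    tokenized, discarded = "", ""
    used = 0
    for i, token in enumerate(tokens):
        s = i % 6 + 1                  # cycle through ANSI colours
        if i < max_tokens:
            tokenized += f"\x1b[0;3{s};40m{token} "
            used += 1
        else:                          # over max token length
            discarded += f"\x1b[0;3{s};40m{token} "
    print(f"\n>> Tokens {display_label or ''} ({used}):\n{tokenized}\x1b[0m")
    if discarded != "":
        print(f">> Tokens Discarded ({len(tokens) - used}):\n{discarded}\x1b[0m")

log_tokenization_sketch("a castle on a hill at sunset", display_label="(prompt)")
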
@@ -157,10 +157,11 @@ class InvokeAIDiffuserComponent:
             # percent_through will never reach 1.0 (but this is intended)
             return float(step_index) / float(self.cross_attention_control_context.step_count)
         # find the best possible index of the current sigma in the sigma sequence
-        sigma_index = torch.nonzero(self.model.sigmas <= sigma)[-1]
+        smaller_sigmas = torch.nonzero(self.model.sigmas <= sigma)
+        sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0
         # flip because sigmas[0] is for the fully denoised image
         # percent_through must be <1
-        return 1.0 - float(sigma_index.item() + 1) / float(self.model.sigmas.shape[0])
+        return 1.0 - float(sigma_index + 1) / float(self.model.sigmas.shape[0])
         # print('estimated percent_through', percent_through, 'from sigma', sigma.item())
 
 
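
The change above guards against torch.nonzero() returning an empty tensor: the old one-liner indexed the result with [-1], which raises an IndexError when no sigma in the schedule is <= the current sigma. Below is a minimal reproduction of the guard; the sigma values are made up and their ordering may differ from the real model.sigmas.

import torch

def best_sigma_index(sigmas, sigma):
    smaller_sigmas = torch.nonzero(sigmas <= sigma)
    return smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0

sigmas = torch.tensor([0.1, 1.2, 3.5, 7.0, 14.6])
print(best_sigma_index(sigmas, 5.0))    # 2  (largest entry still <= 5.0)
print(best_sigma_index(sigmas, 0.05))   # 0  (nothing matches; the old code would crash here)
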
@@ -10,6 +10,9 @@ import warnings
 import time
 import traceback
 import yaml
 
+from ldm.invoke.prompt_parser import PromptParser
+
 sys.path.append('.') # corrects a weird problem on Macs
 from ldm.invoke.readline import get_completer
 from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_from_png
@@ -18,7 +21,7 @@ from ldm.invoke.image_util import make_grid
 from ldm.invoke.log import write_log
 from omegaconf import OmegaConf
 from pathlib import Path
-from pyparsing import ParseException
+import pyparsing
 
 # global used in multiple functions (fix)
 infile = None
@@ -340,7 +343,7 @@ def main_loop(gen, opt):
                     catch_interrupts=catch_ctrl_c,
                     **vars(opt)
                 )
-            except ParseException as e:
+            except (PromptParser.ParsingException, pyparsing.ParseException) as e:
                 print('** An error occurred while processing your prompt **')
                 print(f'** {str(e)} **')
         elif operation == 'postprocess':
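
The widened except clause above means both the parser's own ParsingException and pyparsing's low-level ParseException are reported as prompt errors instead of escaping the REPL loop. Below is a minimal sketch of that handling; ParsingException here is a stand-in for PromptParser.ParsingException.

import pyparsing

class ParsingException(Exception):
    pass

def handle_prompt(parse):
    try:
        parse()
    except (ParsingException, pyparsing.ParseException) as e:
        print('** An error occurred while processing your prompt **')
        print(f'** {str(e)} **')

# a low-level pyparsing failure is caught the same way as a parser-level one
handle_prompt(lambda: pyparsing.Word(pyparsing.nums).parseString("not a number"))
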