make conditioning.py work with compel 1.1.5

This commit is contained in:
Lincoln Stein 2023-05-10 18:08:33 -04:00
parent fb6ef61a4d
commit 5d5157fc65

View File

@@ -16,6 +16,7 @@ from compel.prompt_parser import (
FlattenedPrompt, FlattenedPrompt,
Fragment, Fragment,
PromptParser, PromptParser,
Conjunction,
) )
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
@@ -25,58 +26,51 @@ from ..stable_diffusion import InvokeAIDiffuserComponent
from ..util import torch_dtype from ..util import torch_dtype
def get_uc_and_c_and_ec( def get_uc_and_c_and_ec(prompt_string,
prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False model: InvokeAIDiffuserComponent,
): log_tokens=False, skip_normalize_legacy_blend=False):
# lazy-load any deferred textual inversions. # lazy-load any deferred textual inversions.
# this might take a couple of seconds the first time a textual inversion is used. # this might take a couple of seconds the first time a textual inversion is used.
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms( model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)
prompt_string
)
tokenizer = model.tokenizer compel = Compel(tokenizer=model.tokenizer,
compel = Compel(
tokenizer=tokenizer,
text_encoder=model.text_encoder, text_encoder=model.text_encoder,
textual_inversion_manager=model.textual_inversion_manager, textual_inversion_manager=model.textual_inversion_manager,
dtype_for_device_getter=torch_dtype, dtype_for_device_getter=torch_dtype,
truncate_long_prompts=False truncate_long_prompts=False,
) )
# get rid of any newline characters # get rid of any newline characters
prompt_string = prompt_string.replace("\n", " ") prompt_string = prompt_string.replace("\n", " ")
( positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
positive_prompt_string,
negative_prompt_string, legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
) = split_prompt_to_positive_and_negative(prompt_string) positive_conjunction: Conjunction
legacy_blend = try_parse_legacy_blend(
positive_prompt_string, skip_normalize_legacy_blend
)
positive_prompt: Union[FlattenedPrompt, Blend]
if legacy_blend is not None: if legacy_blend is not None:
positive_prompt = legacy_blend positive_conjunction = legacy_blend
else: else:
positive_prompt = Compel.parse_prompt_string(positive_prompt_string) positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string( positive_prompt = positive_conjunction.prompts[0]
negative_prompt_string
)
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]
tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
if log_tokens or getattr(Globals, "log_tokenization", False): if log_tokens or getattr(Globals, "log_tokenization", False):
log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer) log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)
with InvokeAIDiffuserComponent.custom_attention_context(model.unet,
extra_conditioning_info=None,
step_count=-1):
c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt) c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt) uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
tokens_count = get_max_token_count(tokenizer, positive_prompt) # now build the "real" ec
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( cross_attention_control_args=options.get(
tokens_count_including_eos_bos=tokens_count, 'cross_attention_control', None))
cross_attention_control_args=options.get("cross_attention_control", None),
)
return uc, c, ec return uc, c, ec
def get_prompt_structure( def get_prompt_structure(
prompt_string, skip_normalize_legacy_blend: bool = False prompt_string, skip_normalize_legacy_blend: bool = False
) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt): ) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt):
@@ -87,18 +81,17 @@ def get_prompt_structure(
legacy_blend = try_parse_legacy_blend( legacy_blend = try_parse_legacy_blend(
positive_prompt_string, skip_normalize_legacy_blend positive_prompt_string, skip_normalize_legacy_blend
) )
positive_prompt: Union[FlattenedPrompt, Blend] positive_prompt: Conjunction
if legacy_blend is not None: if legacy_blend is not None:
positive_prompt = legacy_blend positive_conjunction = legacy_blend
else: else:
positive_prompt = Compel.parse_prompt_string(positive_prompt_string) positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string( positive_prompt = positive_conjunction.prompts[0]
negative_prompt_string negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
) negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0]
return positive_prompt, negative_prompt return positive_prompt, negative_prompt
def get_max_token_count( def get_max_token_count(
tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
) -> int: ) -> int:
@@ -245,22 +238,21 @@ def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_t
logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):") logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
logger.debug(f"{discarded}\x1b[0m") logger.debug(f"{discarded}\x1b[0m")
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Conjunction]:
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize) weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
if len(weighted_subprompts) <= 1: if len(weighted_subprompts) <= 1:
return None return None
strings = [x[0] for x in weighted_subprompts] strings = [x[0] for x in weighted_subprompts]
weights = [x[1] for x in weighted_subprompts]
pp = PromptParser() pp = PromptParser()
parsed_conjunctions = [pp.parse_conjunction(x) for x in strings] parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions] flattened_prompts = []
weights = []
return Blend( for i, x in enumerate(parsed_conjunctions):
prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize if len(x.prompts)>0:
) flattened_prompts.append(x.prompts[0])
weights.append(weighted_subprompts[i][1])
return Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)])
def split_weighted_subprompts(text, skip_normalize=False) -> list: def split_weighted_subprompts(text, skip_normalize=False) -> list:
""" """