make conditioning.py work with compel 1.1.5

Lincoln Stein 2023-05-10 18:08:33 -04:00
parent fb6ef61a4d
commit 5d5157fc65


@@ -16,6 +16,7 @@ from compel.prompt_parser import (
     FlattenedPrompt,
     Fragment,
     PromptParser,
+    Conjunction,
 )
 import invokeai.backend.util.logging as logger
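For context on why Conjunction is now imported: in compel 1.1.x, Compel.parse_prompt_string() returns a Conjunction rather than a bare FlattenedPrompt or Blend, so callers unwrap .prompts[0]. A minimal sketch, assuming compel >= 1.1.5 is installed; the prompt text is illustrative and the class-level call form follows the code changed in this commit:

# Minimal sketch, assuming compel>=1.1.5; the prompt string is illustrative only.
from compel import Compel
from compel.prompt_parser import Blend, Conjunction, FlattenedPrompt

conjunction = Compel.parse_prompt_string("a photo of an astronaut riding a horse")
assert isinstance(conjunction, Conjunction)
first_prompt = conjunction.prompts[0]  # FlattenedPrompt (or Blend for blend syntax)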
@@ -25,58 +26,51 @@ from ..stable_diffusion import InvokeAIDiffuserComponent
 from ..util import torch_dtype

-def get_uc_and_c_and_ec(
-    prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False
-):
+def get_uc_and_c_and_ec(prompt_string,
+                        model: InvokeAIDiffuserComponent,
+                        log_tokens=False, skip_normalize_legacy_blend=False):
     # lazy-load any deferred textual inversions.
     # this might take a couple of seconds the first time a textual inversion is used.
-    model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
-        prompt_string
-    )
+    model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)

-    tokenizer = model.tokenizer
-    compel = Compel(
-        tokenizer=tokenizer,
-        text_encoder=model.text_encoder,
-        textual_inversion_manager=model.textual_inversion_manager,
-        dtype_for_device_getter=torch_dtype,
-        truncate_long_prompts=False
-    )
+    compel = Compel(tokenizer=model.tokenizer,
+                    text_encoder=model.text_encoder,
+                    textual_inversion_manager=model.textual_inversion_manager,
+                    dtype_for_device_getter=torch_dtype,
+                    truncate_long_prompts=False,
+                    )

     # get rid of any newline characters
     prompt_string = prompt_string.replace("\n", " ")

-    (
-        positive_prompt_string,
-        negative_prompt_string,
-    ) = split_prompt_to_positive_and_negative(prompt_string)
-    legacy_blend = try_parse_legacy_blend(
-        positive_prompt_string, skip_normalize_legacy_blend
-    )
-    positive_prompt: Union[FlattenedPrompt, Blend]
+    positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
+    legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
+    positive_conjunction: Conjunction
     if legacy_blend is not None:
-        positive_prompt = legacy_blend
+        positive_conjunction = legacy_blend
     else:
-        positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
-    negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
-        negative_prompt_string
-    )
+        positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
+    positive_prompt = positive_conjunction.prompts[0]
+    negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
+    negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]

+    tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
     if log_tokens or getattr(Globals, "log_tokenization", False):
-        log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
+        log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)

-    c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
-    uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
-    [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
+    with InvokeAIDiffuserComponent.custom_attention_context(model.unet,
+                                                            extra_conditioning_info=None,
+                                                            step_count=-1):
+        c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
+        uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)

-    tokens_count = get_max_token_count(tokenizer, positive_prompt)
-    ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
-        tokens_count_including_eos_bos=tokens_count,
-        cross_attention_control_args=options.get("cross_attention_control", None),
-    )
+    # now build the "real" ec
+    ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
+                                                         cross_attention_control_args=options.get(
+                                                             'cross_attention_control', None))
     return uc, c, ec


 def get_prompt_structure(
     prompt_string, skip_normalize_legacy_blend: bool = False
 ) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt):
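A hypothetical caller sketch for the reworked get_uc_and_c_and_ec() above. The model object, the prompt text, and the bracketed-negative syntax are assumptions for illustration; only the (uc, c, ec) return order and the contents of ExtraConditioningInfo come from the code in this hunk:

# Hypothetical usage sketch: `model` is assumed to be an already-loaded InvokeAI
# model object exposing .tokenizer, .text_encoder, .unet and
# .textual_inversion_manager, as the function above expects.
uc, c, ec = get_uc_and_c_and_ec(
    "a watercolor fox in a snowy forest [blurry, low quality]",
    model=model,
    log_tokens=True,
)
# uc / c are the unconditioned and conditioned embedding tensors built by compel;
# ec carries the token count and any cross-attention-control arguments.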
@@ -87,18 +81,17 @@ def get_prompt_structure(
     legacy_blend = try_parse_legacy_blend(
         positive_prompt_string, skip_normalize_legacy_blend
     )
-    positive_prompt: Union[FlattenedPrompt, Blend]
+    positive_prompt: Conjunction
     if legacy_blend is not None:
-        positive_prompt = legacy_blend
+        positive_conjunction = legacy_blend
     else:
-        positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
-    negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
-        negative_prompt_string
-    )
+        positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
+    positive_prompt = positive_conjunction.prompts[0]
+    negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
+    negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0]

     return positive_prompt, negative_prompt


 def get_max_token_count(
     tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
 ) -> int:
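The same Conjunction unwrapping applies to get_prompt_structure(). A sketch of the expected call, with an illustrative prompt string and the assumption (taken from the surrounding code) that the positive and negative halves are split out of the single prompt string:

# Sketch only: both halves come back as FlattenedPrompt/Blend objects, already
# unwrapped from the Conjunction that compel 1.1.x returns internally.
positive_prompt, negative_prompt = get_prompt_structure(
    "a misty harbor at dawn [oversaturated]"
)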
@@ -245,22 +238,21 @@ def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_t
     logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
     logger.debug(f"{discarded}\x1b[0m")


-def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
+def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Conjunction]:
     weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
     if len(weighted_subprompts) <= 1:
         return None
     strings = [x[0] for x in weighted_subprompts]
-    weights = [x[1] for x in weighted_subprompts]
     pp = PromptParser()
     parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
-    flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
-    return Blend(
-        prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize
-    )
+    flattened_prompts = []
+    weights = []
+    for i, x in enumerate(parsed_conjunctions):
+        if len(x.prompts)>0:
+            flattened_prompts.append(x.prompts[0])
+            weights.append(weighted_subprompts[i][1])
+    return Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)])


 def split_weighted_subprompts(text, skip_normalize=False) -> list:
     """