From 5d5157fc6512a3e264fa310c9d1030701ceef8b6 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Wed, 10 May 2023 18:08:33 -0400
Subject: [PATCH] make conditioning.py work with compel 1.1.5

---
 invokeai/backend/prompting/conditioning.py | 100 ++++++++++-----------
 1 file changed, 46 insertions(+), 54 deletions(-)

diff --git a/invokeai/backend/prompting/conditioning.py b/invokeai/backend/prompting/conditioning.py
index d9130ace04..f94f82ef72 100644
--- a/invokeai/backend/prompting/conditioning.py
+++ b/invokeai/backend/prompting/conditioning.py
@@ -16,6 +16,7 @@ from compel.prompt_parser import (
     FlattenedPrompt,
     Fragment,
     PromptParser,
+    Conjunction,
 )
 
 import invokeai.backend.util.logging as logger
@@ -25,58 +26,51 @@ from ..stable_diffusion import InvokeAIDiffuserComponent
 from ..util import torch_dtype
 
 
-def get_uc_and_c_and_ec(
-    prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False
-):
+def get_uc_and_c_and_ec(prompt_string,
+                        model: InvokeAIDiffuserComponent,
+                        log_tokens=False, skip_normalize_legacy_blend=False):
     # lazy-load any deferred textual inversions.
     # this might take a couple of seconds the first time a textual inversion is used.
-    model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
-        prompt_string
-    )
+    model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)
 
-    tokenizer = model.tokenizer
-    compel = Compel(
-        tokenizer=tokenizer,
-        text_encoder=model.text_encoder,
-        textual_inversion_manager=model.textual_inversion_manager,
-        dtype_for_device_getter=torch_dtype,
-        truncate_long_prompts=False
-    )
+    compel = Compel(tokenizer=model.tokenizer,
+                    text_encoder=model.text_encoder,
+                    textual_inversion_manager=model.textual_inversion_manager,
+                    dtype_for_device_getter=torch_dtype,
+                    truncate_long_prompts=False,
+                    )
 
     # get rid of any newline characters
     prompt_string = prompt_string.replace("\n", " ")
-    (
-        positive_prompt_string,
-        negative_prompt_string,
-    ) = split_prompt_to_positive_and_negative(prompt_string)
-    legacy_blend = try_parse_legacy_blend(
-        positive_prompt_string, skip_normalize_legacy_blend
-    )
-    positive_prompt: Union[FlattenedPrompt, Blend]
+    positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
+
+    legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
+    positive_conjunction: Conjunction
     if legacy_blend is not None:
-        positive_prompt = legacy_blend
+        positive_conjunction = legacy_blend
     else:
-        positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
-    negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
-        negative_prompt_string
-    )
+        positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
+    positive_prompt = positive_conjunction.prompts[0]
+    negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
+    negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]
+
+    tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
 
     if log_tokens or getattr(Globals, "log_tokenization", False):
-        log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
+        log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)
 
-    c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
-    uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
-    [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
+    with InvokeAIDiffuserComponent.custom_attention_context(model.unet,
+                                                            extra_conditioning_info=None,
+                                                            step_count=-1):
+        c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
+        uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
 
-    tokens_count = get_max_token_count(tokenizer, positive_prompt)
-
-    ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
-        tokens_count_including_eos_bos=tokens_count,
-        cross_attention_control_args=options.get("cross_attention_control", None),
-    )
+    # now build the "real" ec
+    ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
+                                                         cross_attention_control_args=options.get(
+                                                             'cross_attention_control', None))
     return uc, c, ec
 
-
 def get_prompt_structure(
     prompt_string, skip_normalize_legacy_blend: bool = False
 ) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt):
@@ -87,18 +81,17 @@ def get_prompt_structure(
     legacy_blend = try_parse_legacy_blend(
         positive_prompt_string, skip_normalize_legacy_blend
     )
-    positive_prompt: Union[FlattenedPrompt, Blend]
+    positive_prompt: Conjunction
     if legacy_blend is not None:
-        positive_prompt = legacy_blend
+        positive_conjunction = legacy_blend
     else:
-        positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
-    negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
-        negative_prompt_string
-    )
+        positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
+    positive_prompt = positive_conjunction.prompts[0]
+    negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
+    negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0]
 
     return positive_prompt, negative_prompt
 
-
 def get_max_token_count(
     tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
 ) -> int:
@@ -245,22 +238,21 @@ def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_t
         logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
         logger.debug(f"{discarded}\x1b[0m")
 
-
-def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
+def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Conjunction]:
     weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
     if len(weighted_subprompts) <= 1:
         return None
     strings = [x[0] for x in weighted_subprompts]
-    weights = [x[1] for x in weighted_subprompts]
 
     pp = PromptParser()
     parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
-    flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
-
-    return Blend(
-        prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize
-    )
-
+    flattened_prompts = []
+    weights = []
+    for i, x in enumerate(parsed_conjunctions):
+        if len(x.prompts)>0:
+            flattened_prompts.append(x.prompts[0])
+            weights.append(weighted_subprompts[i][1])
+    return Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)])
 
 def split_weighted_subprompts(text, skip_normalize=False) -> list:
     """
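
For reference, the compel 1.1.5 behavior this patch adapts to: Compel.parse_prompt_string()
now returns a Conjunction wrapper rather than a bare FlattenedPrompt/Blend, so callers unwrap
.prompts[0] before building conditioning tensors. A minimal sketch of that unwrapping step,
assuming compel 1.1.5 is installed; the prompt text below is illustrative only and is not part
of this patch:

    # Illustrative-only sketch of the parse/unwrap step handled in the patch above;
    # no model, tokenizer, or text encoder is wired in, so only parsing is shown.
    from compel import Compel
    from compel.prompt_parser import Blend, Conjunction, FlattenedPrompt

    conjunction = Compel.parse_prompt_string("a photo of a cat")
    assert isinstance(conjunction, Conjunction)

    # As in get_uc_and_c_and_ec() above, the first prompt of the conjunction is
    # what gets passed to build_conditioning_tensor_for_prompt_object().
    prompt = conjunction.prompts[0]
    assert isinstance(prompt, (FlattenedPrompt, Blend))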