fix newlines causing negative prompt to be parsed incorrectly (#2838)

This is the same fix that was applied to main in PR 2837.
This commit is contained in:
Lincoln Stein 2023-03-06 18:37:44 -05:00 committed by GitHub
commit 62cfdb9f11
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -9,7 +9,7 @@ get_uc_and_c_and_ec() get the conditioned and unconditioned latent, an
import re import re
from typing import Union, Optional, Any from typing import Union, Optional, Any
from transformers import CLIPTokenizer, CLIPTextModel from transformers import CLIPTokenizer
from compel import Compel from compel import Compel
from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser
@ -52,6 +52,8 @@ def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_l
textual_inversion_manager=model.textual_inversion_manager, textual_inversion_manager=model.textual_inversion_manager,
dtype_for_device_getter=torch_dtype) dtype_for_device_getter=torch_dtype)
# get rid of any newline characters
prompt_string = prompt_string.replace("\n", " ")
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string) positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend) legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
positive_prompt: FlattenedPrompt|Blend positive_prompt: FlattenedPrompt|Blend
@ -113,7 +115,7 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun
return tokens return tokens
def split_prompt_to_positive_and_negative(prompt_string_uncleaned): def split_prompt_to_positive_and_negative(prompt_string_uncleaned: str):
unconditioned_words = '' unconditioned_words = ''
unconditional_regex = r'\[(.*?)\]' unconditional_regex = r'\[(.*?)\]'
unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned) unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned)