mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
make conditioning.py work with compel 1.1.5
This commit is contained in:
parent fb6ef61a4d
commit 5d5157fc65
@@ -16,6 +16,7 @@ from compel.prompt_parser import (
     FlattenedPrompt,
     Fragment,
     PromptParser,
+    Conjunction,
 )
 
 import invokeai.backend.util.logging as logger
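Note: judging from this diff, compel 1.1.5's Compel.parse_prompt_string() returns a Conjunction rather than a bare FlattenedPrompt/Blend, which is why Conjunction is added to the imports above. A minimal sketch of the unwrapping pattern the rest of the commit relies on; the prompt text is illustrative, and the class-level call mirrors how the patched code invokes it:

# Minimal sketch, mirroring the calls added in this commit.
from compel import Compel
from compel.prompt_parser import Conjunction

conjunction: Conjunction = Compel.parse_prompt_string("a cat playing with a ball in the forest")
prompt = conjunction.prompts[0]  # FlattenedPrompt (or Blend for blend syntax)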
@@ -25,58 +26,51 @@ from ..stable_diffusion import InvokeAIDiffuserComponent
 from ..util import torch_dtype
 
 
-def get_uc_and_c_and_ec(
-    prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False
-):
+def get_uc_and_c_and_ec(prompt_string,
+                        model: InvokeAIDiffuserComponent,
+                        log_tokens=False, skip_normalize_legacy_blend=False):
     # lazy-load any deferred textual inversions.
     # this might take a couple of seconds the first time a textual inversion is used.
-    model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
-        prompt_string
-    )
+    model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)
 
-    tokenizer = model.tokenizer
-    compel = Compel(
-        tokenizer=tokenizer,
+    compel = Compel(tokenizer=model.tokenizer,
                     text_encoder=model.text_encoder,
                     textual_inversion_manager=model.textual_inversion_manager,
                     dtype_for_device_getter=torch_dtype,
-                    truncate_long_prompts=False
-    )
+                    truncate_long_prompts=False,
+                    )
 
     # get rid of any newline characters
     prompt_string = prompt_string.replace("\n", " ")
-    (
-        positive_prompt_string,
-        negative_prompt_string,
-    ) = split_prompt_to_positive_and_negative(prompt_string)
-    legacy_blend = try_parse_legacy_blend(
-        positive_prompt_string, skip_normalize_legacy_blend
-    )
-    positive_prompt: Union[FlattenedPrompt, Blend]
+    positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
+    legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
+    positive_conjunction: Conjunction
     if legacy_blend is not None:
-        positive_prompt = legacy_blend
+        positive_conjunction = legacy_blend
     else:
-        positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
-    negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
-        negative_prompt_string
-    )
+        positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
+    positive_prompt = positive_conjunction.prompts[0]
 
+    negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
+    negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]
+
+    tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
     if log_tokens or getattr(Globals, "log_tokenization", False):
-        log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
+        log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)
 
-    c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
-    uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
-    [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
+    with InvokeAIDiffuserComponent.custom_attention_context(model.unet,
+                                                            extra_conditioning_info=None,
+                                                            step_count=-1):
+        c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
+        uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
 
-    tokens_count = get_max_token_count(tokenizer, positive_prompt)
-
-    ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
-        tokens_count_including_eos_bos=tokens_count,
-        cross_attention_control_args=options.get("cross_attention_control", None),
-    )
+    # now build the "real" ec
+    ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
+                                                         cross_attention_control_args=options.get(
+                                                             'cross_attention_control', None))
     return uc, c, ec
 
 
 def get_prompt_structure(
     prompt_string, skip_normalize_legacy_blend: bool = False
 ) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt):
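For context, a hypothetical caller sketch of the updated get_uc_and_c_and_ec() signature. The module path and the loaded model object are assumptions, not part of this diff; only the argument order and the returned (uc, c, ec) triple come from the patched code.

# Hypothetical usage sketch; the module path below is assumed, not taken from this diff.
from invokeai.backend.prompting.conditioning import get_uc_and_c_and_ec  # assumed path


def build_conditioning(model, prompt_string: str):
    # model: a loaded InvokeAI model exposing .tokenizer, .text_encoder,
    # .textual_inversion_manager and .unet, as the patched function requires
    uc, c, ec = get_uc_and_c_and_ec(
        prompt_string,
        model=model,
        log_tokens=True,
        skip_normalize_legacy_blend=False,
    )
    return uc, c, ec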
@@ -87,18 +81,17 @@ def get_prompt_structure(
     legacy_blend = try_parse_legacy_blend(
         positive_prompt_string, skip_normalize_legacy_blend
     )
-    positive_prompt: Union[FlattenedPrompt, Blend]
+    positive_prompt: Conjunction
     if legacy_blend is not None:
-        positive_prompt = legacy_blend
+        positive_conjunction = legacy_blend
     else:
-        positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
-    negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
-        negative_prompt_string
-    )
-
+        positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
+    positive_prompt = positive_conjunction.prompts[0]
+    negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
+    negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0]
     return positive_prompt, negative_prompt
 
 
 def get_max_token_count(
     tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
 ) -> int:
@@ -245,22 +238,21 @@ def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
         logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
         logger.debug(f"{discarded}\x1b[0m")
 
-def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
+def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Conjunction]:
     weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
     if len(weighted_subprompts) <= 1:
         return None
     strings = [x[0] for x in weighted_subprompts]
-    weights = [x[1] for x in weighted_subprompts]
 
     pp = PromptParser()
     parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
-    flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
-
-    return Blend(
-        prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize
-    )
+    flattened_prompts = []
+    weights = []
+    for i, x in enumerate(parsed_conjunctions):
+        if len(x.prompts)>0:
+            flattened_prompts.append(x.prompts[0])
+            weights.append(weighted_subprompts[i][1])
+    return Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)])
 
 
 def split_weighted_subprompts(text, skip_normalize=False) -> list:
     """
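try_parse_legacy_blend() now skips sub-prompts that parse to nothing and wraps the resulting Blend in a single-element Conjunction, matching the new parse_prompt_string() return type. A self-contained sketch of the same wrapping, using only compel classes; the weighted sub-prompt list is illustrative, standing in for the output of split_weighted_subprompts():

# Sketch of the new legacy-blend wrapping; sub-prompts and weights are illustrative.
from compel.prompt_parser import Blend, Conjunction, PromptParser

pp = PromptParser()
weighted_subprompts = [("a cat", 1.0), ("a dog", 0.5)]
parsed = [pp.parse_conjunction(text) for text, _ in weighted_subprompts]
flattened_prompts, weights = [], []
for (_, weight), conj in zip(weighted_subprompts, parsed):
    if len(conj.prompts) > 0:  # skip sub-prompts that parsed to nothing
        flattened_prompts.append(conj.prompts[0])
        weights.append(weight)
legacy_blend = Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True)])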