make conditioning.py work with compel 1.1.5 (#3383)

This PR fixes the ValueError that was preventing all prompts from working: with compel 1.1.5, Compel.parse_prompt_string() returns a Conjunction, so the parsed FlattenedPrompt or Blend must now be unwrapped from conjunction.prompts[0].
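
For reference, a minimal sketch of the API change this PR adapts to (assuming compel ~=1.1.5 is installed; the prompt string is only illustrative):

    from compel import Compel
    from compel.prompt_parser import Blend, Conjunction, FlattenedPrompt

    # compel 1.1.5: parse_prompt_string() wraps the parsed prompt in a Conjunction
    conjunction = Compel.parse_prompt_string("a photo of a cat++")
    assert isinstance(conjunction, Conjunction)

    # the FlattenedPrompt or Blend that callers previously received must now
    # be unwrapped from the first element of the conjunction
    prompt = conjunction.prompts[0]
    assert isinstance(prompt, (FlattenedPrompt, Blend))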
Lincoln Stein 2023-05-15 09:46:04 -04:00, committed by GitHub
commit 34fb1c4b19
6 changed files with 84 additions and 96 deletions

View File

@@ -100,7 +100,8 @@ class CompelInvocation(BaseInvocation):
         # TODO: support legacy blend?
-        prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(prompt_str)
+        conjunction = Compel.parse_prompt_string(prompt_str)
+        prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]

         if getattr(Globals, "log_tokenization", False):
             log_tokenization_for_prompt_object(prompt, tokenizer)

View File

@@ -16,6 +16,7 @@ from compel.prompt_parser import (
     FlattenedPrompt,
     Fragment,
     PromptParser,
+    Conjunction,
 )

 import invokeai.backend.util.logging as logger
@@ -25,58 +26,48 @@ from ..stable_diffusion import InvokeAIDiffuserComponent
 from ..util import torch_dtype

-def get_uc_and_c_and_ec(
-    prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False
-):
+def get_uc_and_c_and_ec(prompt_string,
+                        model: InvokeAIDiffuserComponent,
+                        log_tokens=False, skip_normalize_legacy_blend=False):
     # lazy-load any deferred textual inversions.
     # this might take a couple of seconds the first time a textual inversion is used.
-    model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
-        prompt_string
-    )
+    model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)

-    tokenizer = model.tokenizer
-    compel = Compel(
-        tokenizer=tokenizer,
-        text_encoder=model.text_encoder,
-        textual_inversion_manager=model.textual_inversion_manager,
-        dtype_for_device_getter=torch_dtype,
-        truncate_long_prompts=False
-    )
+    compel = Compel(tokenizer=model.tokenizer,
+                    text_encoder=model.text_encoder,
+                    textual_inversion_manager=model.textual_inversion_manager,
+                    dtype_for_device_getter=torch_dtype,
+                    truncate_long_prompts=False,
+                    )

     # get rid of any newline characters
     prompt_string = prompt_string.replace("\n", " ")

-    (
-        positive_prompt_string,
-        negative_prompt_string,
-    ) = split_prompt_to_positive_and_negative(prompt_string)
-    legacy_blend = try_parse_legacy_blend(
-        positive_prompt_string, skip_normalize_legacy_blend
-    )
-    positive_prompt: Union[FlattenedPrompt, Blend]
+    positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
+    legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
+    positive_conjunction: Conjunction
     if legacy_blend is not None:
-        positive_prompt = legacy_blend
+        positive_conjunction = legacy_blend
     else:
-        positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
-    negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
-        negative_prompt_string
-    )
+        positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
+    positive_prompt = positive_conjunction.prompts[0]
+    negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
+    negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]

+    tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
     if log_tokens or getattr(Globals, "log_tokenization", False):
-        log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
+        log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)

     c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
     uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
     [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])

-    tokens_count = get_max_token_count(tokenizer, positive_prompt)
-
-    ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
-        tokens_count_including_eos_bos=tokens_count,
-        cross_attention_control_args=options.get("cross_attention_control", None),
-    )
+    ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
+                                                         cross_attention_control_args=options.get(
+                                                             'cross_attention_control', None))
     return uc, c, ec

 def get_prompt_structure(
     prompt_string, skip_normalize_legacy_blend: bool = False
 ) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt):
@@ -87,18 +78,17 @@ def get_prompt_structure(
     legacy_blend = try_parse_legacy_blend(
         positive_prompt_string, skip_normalize_legacy_blend
     )
-    positive_prompt: Union[FlattenedPrompt, Blend]
+    positive_prompt: Conjunction
     if legacy_blend is not None:
-        positive_prompt = legacy_blend
+        positive_conjunction = legacy_blend
     else:
-        positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
-    negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
-        negative_prompt_string
-    )
+        positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
+    positive_prompt = positive_conjunction.prompts[0]
+    negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
+    negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0]

     return positive_prompt, negative_prompt

 def get_max_token_count(
     tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
 ) -> int:
@@ -245,22 +235,21 @@ def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_t
         logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
         logger.debug(f"{discarded}\x1b[0m")

-def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
+def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Conjunction]:
     weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
     if len(weighted_subprompts) <= 1:
         return None
     strings = [x[0] for x in weighted_subprompts]
-    weights = [x[1] for x in weighted_subprompts]

     pp = PromptParser()
     parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
-    flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
-
-    return Blend(
-        prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize
-    )
+    flattened_prompts = []
+    weights = []
+    for i, x in enumerate(parsed_conjunctions):
+        if len(x.prompts)>0:
+            flattened_prompts.append(x.prompts[0])
+            weights.append(weighted_subprompts[i][1])
+    return Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)])

 def split_weighted_subprompts(text, skip_normalize=False) -> list:
     """

View File

@@ -548,8 +548,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         additional_guidance = []
         extra_conditioning_info = conditioning_data.extra
         with self.invokeai_diffuser.custom_attention_context(
-            extra_conditioning_info=extra_conditioning_info,
-            step_count=len(self.scheduler.timesteps),
+            self.invokeai_diffuser.model,
+            extra_conditioning_info=extra_conditioning_info,
+            step_count=len(self.scheduler.timesteps),
         ):
             yield PipelineIntermediateState(
                 run_id=run_id,

View File

@@ -10,6 +10,7 @@ import diffusers
 import psutil
 import torch
 from compel.cross_attention_control import Arguments
+from diffusers.models.unet_2d_condition import UNet2DConditionModel
 from diffusers.models.attention_processor import AttentionProcessor
 from torch import nn
@@ -352,8 +353,7 @@ def restore_default_cross_attention(
     else:
         remove_attention_function(model)

-def override_cross_attention(model, context: Context, is_running_diffusers=False):
+def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
     """
     Inject attention parameters and functions into the passed in model to enable cross attention editing.
@@ -372,37 +372,22 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False
     indices = torch.arange(max_length, dtype=torch.long)
     for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
         if b0 < max_length:
-            if name == "equal":  # or (name == "replace" and a1 - a0 == b1 - b0):
+            if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
                 # these tokens have not been edited
                 indices[b0:b1] = indices_target[a0:a1]
                 mask[b0:b1] = 1

     context.cross_attention_mask = mask.to(device)
     context.cross_attention_index_map = indices.to(device)
-    if is_running_diffusers:
-        unet = model
-        old_attn_processors = unet.attn_processors
-        if torch.backends.mps.is_available():
-            # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
-            unet.set_attn_processor(SwapCrossAttnProcessor())
-        else:
-            # try to re-use an existing slice size
-            default_slice_size = 4
-            slice_size = next(
-                (
-                    p.slice_size
-                    for p in old_attn_processors.values()
-                    if type(p) is SlicedAttnProcessor
-                ),
-                default_slice_size,
-            )
-            unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
-        return old_attn_processors
-    else:
-        context.register_cross_attention_modules(model)
-        inject_attention_function(model, context)
-        return None
+    old_attn_processors = unet.attn_processors
+    if torch.backends.mps.is_available():
+        # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
+        unet.set_attn_processor(SwapCrossAttnProcessor())
+    else:
+        # try to re-use an existing slice size
+        default_slice_size = 4
+        slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
+        unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))

 def get_cross_attention_modules(
     model, which: CrossAttentionType

View File

@@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Optional, Union
 import numpy as np
 import torch
+from diffusers import UNet2DConditionModel
 from diffusers.models.attention_processor import AttentionProcessor
 from typing_extensions import TypeAlias
@@ -17,8 +18,8 @@ from .cross_attention_control import (
     CrossAttentionType,
     SwapCrossAttnContext,
     get_cross_attention_modules,
-    override_cross_attention,
     restore_default_cross_attention,
+    setup_cross_attention_control_attention_processors,
 )

 from .cross_attention_map_saving import AttentionMapSaver
@@ -79,24 +80,35 @@ class InvokeAIDiffuserComponent:
         self.cross_attention_control_context = None
         self.sequential_guidance = Globals.sequential_guidance

+    @classmethod
     @contextmanager
     def custom_attention_context(
-        self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
+        cls,
+        unet: UNet2DConditionModel,  # note: also may futz with the text encoder depending on requested LoRAs
+        extra_conditioning_info: Optional[ExtraConditioningInfo],
+        step_count: int
     ):
-        do_swap = (
-            extra_conditioning_info is not None
-            and extra_conditioning_info.wants_cross_attention_control
-        )
-        old_attn_processor = None
-        if do_swap:
-            old_attn_processor = self.override_cross_attention(
-                extra_conditioning_info, step_count=step_count
-            )
+        old_attn_processors = None
+        if extra_conditioning_info and (
+            extra_conditioning_info.wants_cross_attention_control
+        ):
+            old_attn_processors = unet.attn_processors
+            # Load lora conditions into the model
+            if extra_conditioning_info.wants_cross_attention_control:
+                cross_attention_control_context = Context(
+                    arguments=extra_conditioning_info.cross_attention_control_args,
+                    step_count=step_count,
+                )
+                setup_cross_attention_control_attention_processors(
+                    unet,
+                    cross_attention_control_context,
+                )
+
         try:
             yield None
         finally:
-            if old_attn_processor is not None:
-                self.restore_default_cross_attention(old_attn_processor)
+            if old_attn_processors is not None:
+                unet.set_attn_processor(old_attn_processors)
             # TODO resuscitate attention map saving
             # self.remove_attention_map_saving()

View File

@@ -38,7 +38,7 @@ dependencies = [
     "albumentations",
     "click",
     "clip_anytorch",  # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
-    "compel~=1.1.5",
+    "compel~=1.1.5",
     "datasets",
     "diffusers[torch]~=0.16.1",
     "dnspython==2.2.1",