Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
make conditioning.py work with compel 1.1.5 (#3383)
This PR fixes the ValueError issue that was preventing all prompts from working.
Commit: 34fb1c4b19
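The change this PR adapts to is that Compel.parse_prompt_string() now returns a Conjunction rather than a bare FlattenedPrompt or Blend, so every caller in the diff below unwraps conjunction.prompts[0]. A minimal sketch of the new return shape, assuming compel ~= 1.1.5 is installed (the prompt text is arbitrary):

    from compel.prompt_parser import Conjunction, PromptParser

    # PromptParser is compel's low-level parser; it needs no tokenizer or model weights.
    pp = PromptParser()
    conjunction: Conjunction = pp.parse_conjunction("a photo of a cat++ sitting on a sofa")

    # A Conjunction wraps one or more prompts; code that used to receive a FlattenedPrompt
    # or Blend directly now takes the first entry, exactly as the patch does.
    prompt = conjunction.prompts[0]
    print(type(prompt).__name__)  # FlattenedPrompt for a plain prompt like this one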
@@ -100,7 +100,8 @@ class CompelInvocation(BaseInvocation):
         # TODO: support legacy blend?

-        prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(prompt_str)
+        conjunction = Compel.parse_prompt_string(prompt_str)
+        prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]

         if getattr(Globals, "log_tokenization", False):
             log_tokenization_for_prompt_object(prompt, tokenizer)
@@ -16,6 +16,7 @@ from compel.prompt_parser import (
     FlattenedPrompt,
     Fragment,
     PromptParser,
+    Conjunction,
 )

 import invokeai.backend.util.logging as logger
@@ -25,58 +26,48 @@ from ..stable_diffusion import InvokeAIDiffuserComponent
 from ..util import torch_dtype


-def get_uc_and_c_and_ec(
-    prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False
-):
+def get_uc_and_c_and_ec(prompt_string,
+                        model: InvokeAIDiffuserComponent,
+                        log_tokens=False, skip_normalize_legacy_blend=False):
     # lazy-load any deferred textual inversions.
     # this might take a couple of seconds the first time a textual inversion is used.
-    model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
-        prompt_string
-    )
+    model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)

-    tokenizer = model.tokenizer
-    compel = Compel(
-        tokenizer=tokenizer,
-        text_encoder=model.text_encoder,
-        textual_inversion_manager=model.textual_inversion_manager,
-        dtype_for_device_getter=torch_dtype,
-        truncate_long_prompts=False
-    )
+    compel = Compel(tokenizer=model.tokenizer,
+                    text_encoder=model.text_encoder,
+                    textual_inversion_manager=model.textual_inversion_manager,
+                    dtype_for_device_getter=torch_dtype,
+                    truncate_long_prompts=False,
+                    )

     # get rid of any newline characters
     prompt_string = prompt_string.replace("\n", " ")
-    (
-        positive_prompt_string,
-        negative_prompt_string,
-    ) = split_prompt_to_positive_and_negative(prompt_string)
-    legacy_blend = try_parse_legacy_blend(
-        positive_prompt_string, skip_normalize_legacy_blend
-    )
-    positive_prompt: Union[FlattenedPrompt, Blend]
+    positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
+
+    legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
+    positive_conjunction: Conjunction
     if legacy_blend is not None:
-        positive_prompt = legacy_blend
+        positive_conjunction = legacy_blend
     else:
-        positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
-    negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
-        negative_prompt_string
-    )
+        positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
+    positive_prompt = positive_conjunction.prompts[0]

+    negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
+    negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]
+
+    tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
     if log_tokens or getattr(Globals, "log_tokenization", False):
-        log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
+        log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)

     c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
     uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
     [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])

-    tokens_count = get_max_token_count(tokenizer, positive_prompt)
-
-    ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
-        tokens_count_including_eos_bos=tokens_count,
-        cross_attention_control_args=options.get("cross_attention_control", None),
-    )
+    ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
+                                                         cross_attention_control_args=options.get(
+                                                             'cross_attention_control', None))
     return uc, c, ec


 def get_prompt_structure(
     prompt_string, skip_normalize_legacy_blend: bool = False
 ) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt):
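For reference, the conditioning-tensor plumbing that get_uc_and_c_and_ec() keeps using is unchanged by this PR. The following is a hedged, self-contained sketch of how those compel calls fit together; tokenizer and text_encoder stand in for whatever the loaded Stable Diffusion model provides, and build_conditioning is a hypothetical helper name, not InvokeAI API:

    from compel import Compel

    def build_conditioning(tokenizer, text_encoder, positive_prompt, negative_prompt):
        # truncate_long_prompts=False keeps prompts past the encoder's 77-token window, so
        # the positive and negative tensors can end up with different lengths and must be
        # padded to a common length before classifier-free guidance can compare them.
        compel = Compel(tokenizer=tokenizer,
                        text_encoder=text_encoder,
                        truncate_long_prompts=False)
        c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
        uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
        [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
        return c, uc, options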
@@ -87,18 +78,17 @@ def get_prompt_structure(
     legacy_blend = try_parse_legacy_blend(
         positive_prompt_string, skip_normalize_legacy_blend
     )
-    positive_prompt: Union[FlattenedPrompt, Blend]
+    positive_prompt: Conjunction
     if legacy_blend is not None:
-        positive_prompt = legacy_blend
+        positive_conjunction = legacy_blend
     else:
-        positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
-    negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
-        negative_prompt_string
-    )
+        positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
+    positive_prompt = positive_conjunction.prompts[0]
+    negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
+    negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0]

     return positive_prompt, negative_prompt


 def get_max_token_count(
     tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
 ) -> int:
@@ -245,22 +235,21 @@ def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_t
         logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
         logger.debug(f"{discarded}\x1b[0m")


-def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
+def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Conjunction]:
     weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
     if len(weighted_subprompts) <= 1:
         return None
     strings = [x[0] for x in weighted_subprompts]
-    weights = [x[1] for x in weighted_subprompts]

     pp = PromptParser()
     parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
-    flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
-
-    return Blend(
-        prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize
-    )
+    flattened_prompts = []
+    weights = []
+    for i, x in enumerate(parsed_conjunctions):
+        if len(x.prompts)>0:
+            flattened_prompts.append(x.prompts[0])
+            weights.append(weighted_subprompts[i][1])
+    return Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)])


 def split_weighted_subprompts(text, skip_normalize=False) -> list:
     """
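The reworked try_parse_legacy_blend() now skips sub-prompts that parse to an empty conjunction and wraps the resulting Blend in a single-element Conjunction, so its return type matches Compel.parse_prompt_string(). A small sketch of that shape using compel's parser directly; the sub-prompt strings and weights are made-up examples, and weighted_subprompts stands in for the output of InvokeAI's split_weighted_subprompts():

    from compel.prompt_parser import Blend, Conjunction, PromptParser

    weighted_subprompts = [("a castle on a hill", 1.0), ("a stormy sky", 0.5)]

    pp = PromptParser()
    flattened_prompts, weights = [], []
    for text, weight in weighted_subprompts:
        conj = pp.parse_conjunction(text)
        if len(conj.prompts) > 0:          # drop sub-prompts that parse to nothing
            flattened_prompts.append(conj.prompts[0])
            weights.append(weight)

    # Wrapping the Blend in a one-element Conjunction lets callers treat the result
    # exactly like Compel.parse_prompt_string() output.
    result = Conjunction([Blend(prompts=flattened_prompts,
                                weights=weights,
                                normalize_weights=True)])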
@@ -548,8 +548,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             additional_guidance = []
         extra_conditioning_info = conditioning_data.extra
         with self.invokeai_diffuser.custom_attention_context(
-            extra_conditioning_info=extra_conditioning_info,
-            step_count=len(self.scheduler.timesteps),
+            self.invokeai_diffuser.model,
+            extra_conditioning_info=extra_conditioning_info,
+            step_count=len(self.scheduler.timesteps),
         ):
             yield PipelineIntermediateState(
                 run_id=run_id,
@@ -10,6 +10,7 @@ import diffusers
 import psutil
 import torch
 from compel.cross_attention_control import Arguments
+from diffusers.models.unet_2d_condition import UNet2DConditionModel
 from diffusers.models.attention_processor import AttentionProcessor
 from torch import nn

@@ -352,8 +353,7 @@ def restore_default_cross_attention(
     else:
         remove_attention_function(model)

-
-def override_cross_attention(model, context: Context, is_running_diffusers=False):
+def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
     """
     Inject attention parameters and functions into the passed in model to enable cross attention editing.

@@ -372,37 +372,22 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False
     indices = torch.arange(max_length, dtype=torch.long)
     for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
         if b0 < max_length:
-            if name == "equal": # or (name == "replace" and a1 - a0 == b1 - b0):
+            if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
                 # these tokens have not been edited
                 indices[b0:b1] = indices_target[a0:a1]
                 mask[b0:b1] = 1

     context.cross_attention_mask = mask.to(device)
     context.cross_attention_index_map = indices.to(device)
-    if is_running_diffusers:
-        unet = model
-        old_attn_processors = unet.attn_processors
-        if torch.backends.mps.is_available():
-            # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
-            unet.set_attn_processor(SwapCrossAttnProcessor())
-        else:
-            # try to re-use an existing slice size
-            default_slice_size = 4
-            slice_size = next(
-                (
-                    p.slice_size
-                    for p in old_attn_processors.values()
-                    if type(p) is SlicedAttnProcessor
-                ),
-                default_slice_size,
-            )
-            unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
-        return old_attn_processors
-    else:
-        context.register_cross_attention_modules(model)
-        inject_attention_function(model, context)
-        return None
+    old_attn_processors = unet.attn_processors
+    if torch.backends.mps.is_available():
+        # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
+        unet.set_attn_processor(SwapCrossAttnProcessor())
+    else:
+        # try to re-use an existing slice size
+        default_slice_size = 4
+        slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
+        unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))


 def get_cross_attention_modules(
     model, which: CrossAttentionType
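The slice-size lookup in the new setup_cross_attention_control_attention_processors() relies on next() with a default value: the generator yields slice_size for any already-installed SlicedAttnProcessor, and next() falls back to default_slice_size when nothing matches. A toy illustration with stand-in classes (SlicedProc and PlainProc are hypothetical, not diffusers types):

    class SlicedProc:
        def __init__(self, slice_size):
            self.slice_size = slice_size

    class PlainProc:
        pass

    attn_processors = {"down": PlainProc(), "mid": SlicedProc(2), "up": PlainProc()}
    default_slice_size = 4
    # Take the slice size of the first SlicedProc found, otherwise use the default.
    slice_size = next((p.slice_size for p in attn_processors.values()
                       if type(p) is SlicedProc), default_slice_size)
    assert slice_size == 2   # would be 4 if no SlicedProc were present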
|
@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Optional, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from diffusers import UNet2DConditionModel
|
||||||
from diffusers.models.attention_processor import AttentionProcessor
|
from diffusers.models.attention_processor import AttentionProcessor
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
@@ -17,8 +18,8 @@ from .cross_attention_control import (
     CrossAttentionType,
     SwapCrossAttnContext,
     get_cross_attention_modules,
-    override_cross_attention,
     restore_default_cross_attention,
+    setup_cross_attention_control_attention_processors,
 )
 from .cross_attention_map_saving import AttentionMapSaver

@@ -79,24 +80,35 @@ class InvokeAIDiffuserComponent:
         self.cross_attention_control_context = None
         self.sequential_guidance = Globals.sequential_guidance

+    @classmethod
     @contextmanager
     def custom_attention_context(
-        self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
+        cls,
+        unet: UNet2DConditionModel,  # note: also may futz with the text encoder depending on requested LoRAs
+        extra_conditioning_info: Optional[ExtraConditioningInfo],
+        step_count: int
     ):
-        do_swap = (
-            extra_conditioning_info is not None
-            and extra_conditioning_info.wants_cross_attention_control
-        )
-        old_attn_processor = None
-        if do_swap:
-            old_attn_processor = self.override_cross_attention(
-                extra_conditioning_info, step_count=step_count
-            )
+        old_attn_processors = None
+        if extra_conditioning_info and (
+            extra_conditioning_info.wants_cross_attention_control
+        ):
+            old_attn_processors = unet.attn_processors
+            # Load lora conditions into the model
+            if extra_conditioning_info.wants_cross_attention_control:
+                cross_attention_control_context = Context(
+                    arguments=extra_conditioning_info.cross_attention_control_args,
+                    step_count=step_count,
+                )
+                setup_cross_attention_control_attention_processors(
+                    unet,
+                    cross_attention_control_context,
+                )
+
         try:
             yield None
         finally:
-            if old_attn_processor is not None:
-                self.restore_default_cross_attention(old_attn_processor)
+            if old_attn_processors is not None:
+                unet.set_attn_processor(old_attn_processors)
             # TODO resuscitate attention map saving
             # self.remove_attention_map_saving()
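custom_attention_context() is now a classmethod that receives the unet explicitly, saves unet.attn_processors up front, and restores them in the finally block. That save/override/restore shape is the standard classmethod-plus-contextmanager pattern; a runnable toy sketch (AttentionSwapper and ToyUNet are stand-ins, not InvokeAI or diffusers classes):

    from contextlib import contextmanager

    class AttentionSwapper:
        @classmethod
        @contextmanager
        def custom_attention_context(cls, unet, step_count: int):
            old_processors = dict(unet.processors)   # remember what is installed now
            unet.processors["cross_attn"] = "swap"   # install the override
            try:
                yield None
            finally:
                unet.processors.clear()
                unet.processors.update(old_processors)  # always restore on exit

    class ToyUNet:
        def __init__(self):
            self.processors = {"cross_attn": "default"}

    unet = ToyUNet()
    with AttentionSwapper.custom_attention_context(unet, step_count=30):
        assert unet.processors["cross_attn"] == "swap"
    assert unet.processors["cross_attn"] == "default"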
@@ -38,7 +38,7 @@ dependencies = [
     "albumentations",
     "click",
     "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
     "compel~=1.1.5",
     "datasets",
     "diffusers[torch]~=0.16.1",
     "dnspython==2.2.1",
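The compel pin uses a compatible-release specifier: "compel~=1.1.5" allows 1.1.5 and later 1.1.x patch releases but excludes 1.2.0. A quick check of that behaviour, assuming the packaging library is available:

    from packaging.specifiers import SpecifierSet

    spec = SpecifierSet("~=1.1.5")   # equivalent to >=1.1.5, ==1.1.*
    print("1.1.7" in spec)           # True
    print("1.2.0" in spec)           # False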