Merge branch 'main' into feat/ui/fix-uploading

This commit is contained in:
blessedcoolant 2023-05-16 02:20:59 +12:00 committed by GitHub
commit 2fc70c509b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 84 additions and 96 deletions

View File

@ -100,7 +100,8 @@ class CompelInvocation(BaseInvocation):
# TODO: support legacy blend? # TODO: support legacy blend?
prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(prompt_str) conjunction = Compel.parse_prompt_string(prompt_str)
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
if getattr(Globals, "log_tokenization", False): if getattr(Globals, "log_tokenization", False):
log_tokenization_for_prompt_object(prompt, tokenizer) log_tokenization_for_prompt_object(prompt, tokenizer)

View File

@ -16,6 +16,7 @@ from compel.prompt_parser import (
FlattenedPrompt, FlattenedPrompt,
Fragment, Fragment,
PromptParser, PromptParser,
Conjunction,
) )
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
@ -25,58 +26,48 @@ from ..stable_diffusion import InvokeAIDiffuserComponent
from ..util import torch_dtype from ..util import torch_dtype
def get_uc_and_c_and_ec( def get_uc_and_c_and_ec(prompt_string,
prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False model: InvokeAIDiffuserComponent,
): log_tokens=False, skip_normalize_legacy_blend=False):
# lazy-load any deferred textual inversions. # lazy-load any deferred textual inversions.
# this might take a couple of seconds the first time a textual inversion is used. # this might take a couple of seconds the first time a textual inversion is used.
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms( model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)
prompt_string
)
tokenizer = model.tokenizer compel = Compel(tokenizer=model.tokenizer,
compel = Compel( text_encoder=model.text_encoder,
tokenizer=tokenizer, textual_inversion_manager=model.textual_inversion_manager,
text_encoder=model.text_encoder, dtype_for_device_getter=torch_dtype,
textual_inversion_manager=model.textual_inversion_manager, truncate_long_prompts=False,
dtype_for_device_getter=torch_dtype, )
truncate_long_prompts=False
)
# get rid of any newline characters # get rid of any newline characters
prompt_string = prompt_string.replace("\n", " ") prompt_string = prompt_string.replace("\n", " ")
( positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
positive_prompt_string,
negative_prompt_string,
) = split_prompt_to_positive_and_negative(prompt_string)
legacy_blend = try_parse_legacy_blend(
positive_prompt_string, skip_normalize_legacy_blend
)
positive_prompt: Union[FlattenedPrompt, Blend]
if legacy_blend is not None:
positive_prompt = legacy_blend
else:
positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
negative_prompt_string
)
legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
positive_conjunction: Conjunction
if legacy_blend is not None:
positive_conjunction = legacy_blend
else:
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
positive_prompt = positive_conjunction.prompts[0]
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]
tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
if log_tokens or getattr(Globals, "log_tokenization", False): if log_tokens or getattr(Globals, "log_tokenization", False):
log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer) log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)
c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt) c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt) uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc]) [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
tokens_count = get_max_token_count(tokenizer, positive_prompt) ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
cross_attention_control_args=options.get(
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( 'cross_attention_control', None))
tokens_count_including_eos_bos=tokens_count,
cross_attention_control_args=options.get("cross_attention_control", None),
)
return uc, c, ec return uc, c, ec
def get_prompt_structure( def get_prompt_structure(
prompt_string, skip_normalize_legacy_blend: bool = False prompt_string, skip_normalize_legacy_blend: bool = False
) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt): ) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt):
@ -87,18 +78,17 @@ def get_prompt_structure(
legacy_blend = try_parse_legacy_blend( legacy_blend = try_parse_legacy_blend(
positive_prompt_string, skip_normalize_legacy_blend positive_prompt_string, skip_normalize_legacy_blend
) )
positive_prompt: Union[FlattenedPrompt, Blend] positive_prompt: Conjunction
if legacy_blend is not None: if legacy_blend is not None:
positive_prompt = legacy_blend positive_conjunction = legacy_blend
else: else:
positive_prompt = Compel.parse_prompt_string(positive_prompt_string) positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string( positive_prompt = positive_conjunction.prompts[0]
negative_prompt_string negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
) negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0]
return positive_prompt, negative_prompt return positive_prompt, negative_prompt
def get_max_token_count( def get_max_token_count(
tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
) -> int: ) -> int:
@ -245,22 +235,21 @@ def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_t
logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):") logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
logger.debug(f"{discarded}\x1b[0m") logger.debug(f"{discarded}\x1b[0m")
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Conjunction]:
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize) weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
if len(weighted_subprompts) <= 1: if len(weighted_subprompts) <= 1:
return None return None
strings = [x[0] for x in weighted_subprompts] strings = [x[0] for x in weighted_subprompts]
weights = [x[1] for x in weighted_subprompts]
pp = PromptParser() pp = PromptParser()
parsed_conjunctions = [pp.parse_conjunction(x) for x in strings] parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions] flattened_prompts = []
weights = []
return Blend( for i, x in enumerate(parsed_conjunctions):
prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize if len(x.prompts)>0:
) flattened_prompts.append(x.prompts[0])
weights.append(weighted_subprompts[i][1])
return Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)])
def split_weighted_subprompts(text, skip_normalize=False) -> list: def split_weighted_subprompts(text, skip_normalize=False) -> list:
""" """

View File

@ -548,8 +548,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance = [] additional_guidance = []
extra_conditioning_info = conditioning_data.extra extra_conditioning_info = conditioning_data.extra
with self.invokeai_diffuser.custom_attention_context( with self.invokeai_diffuser.custom_attention_context(
extra_conditioning_info=extra_conditioning_info, self.invokeai_diffuser.model,
step_count=len(self.scheduler.timesteps), extra_conditioning_info=extra_conditioning_info,
step_count=len(self.scheduler.timesteps),
): ):
yield PipelineIntermediateState( yield PipelineIntermediateState(
run_id=run_id, run_id=run_id,

View File

@ -10,6 +10,7 @@ import diffusers
import psutil import psutil
import torch import torch
from compel.cross_attention_control import Arguments from compel.cross_attention_control import Arguments
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.models.attention_processor import AttentionProcessor from diffusers.models.attention_processor import AttentionProcessor
from torch import nn from torch import nn
@ -352,8 +353,7 @@ def restore_default_cross_attention(
else: else:
remove_attention_function(model) remove_attention_function(model)
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
def override_cross_attention(model, context: Context, is_running_diffusers=False):
""" """
Inject attention parameters and functions into the passed in model to enable cross attention editing. Inject attention parameters and functions into the passed in model to enable cross attention editing.
@ -372,37 +372,22 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False
indices = torch.arange(max_length, dtype=torch.long) indices = torch.arange(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes: for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
if b0 < max_length: if b0 < max_length:
if name == "equal": # or (name == "replace" and a1 - a0 == b1 - b0): if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
# these tokens have not been edited # these tokens have not been edited
indices[b0:b1] = indices_target[a0:a1] indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1 mask[b0:b1] = 1
context.cross_attention_mask = mask.to(device) context.cross_attention_mask = mask.to(device)
context.cross_attention_index_map = indices.to(device) context.cross_attention_index_map = indices.to(device)
if is_running_diffusers: old_attn_processors = unet.attn_processors
unet = model if torch.backends.mps.is_available():
old_attn_processors = unet.attn_processors # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
if torch.backends.mps.is_available(): unet.set_attn_processor(SwapCrossAttnProcessor())
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
unet.set_attn_processor(SwapCrossAttnProcessor())
else:
# try to re-use an existing slice size
default_slice_size = 4
slice_size = next(
(
p.slice_size
for p in old_attn_processors.values()
if type(p) is SlicedAttnProcessor
),
default_slice_size,
)
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
return old_attn_processors
else: else:
context.register_cross_attention_modules(model) # try to re-use an existing slice size
inject_attention_function(model, context) default_slice_size = 4
return None slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
def get_cross_attention_modules( def get_cross_attention_modules(
model, which: CrossAttentionType model, which: CrossAttentionType

View File

@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Optional, Union
import numpy as np import numpy as np
import torch import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttentionProcessor from diffusers.models.attention_processor import AttentionProcessor
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
@ -17,8 +18,8 @@ from .cross_attention_control import (
CrossAttentionType, CrossAttentionType,
SwapCrossAttnContext, SwapCrossAttnContext,
get_cross_attention_modules, get_cross_attention_modules,
override_cross_attention,
restore_default_cross_attention, restore_default_cross_attention,
setup_cross_attention_control_attention_processors,
) )
from .cross_attention_map_saving import AttentionMapSaver from .cross_attention_map_saving import AttentionMapSaver
@ -79,24 +80,35 @@ class InvokeAIDiffuserComponent:
self.cross_attention_control_context = None self.cross_attention_control_context = None
self.sequential_guidance = Globals.sequential_guidance self.sequential_guidance = Globals.sequential_guidance
@classmethod
@contextmanager @contextmanager
def custom_attention_context( def custom_attention_context(
self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int cls,
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
extra_conditioning_info: Optional[ExtraConditioningInfo],
step_count: int
): ):
do_swap = ( old_attn_processors = None
extra_conditioning_info is not None if extra_conditioning_info and (
and extra_conditioning_info.wants_cross_attention_control extra_conditioning_info.wants_cross_attention_control
) ):
old_attn_processor = None old_attn_processors = unet.attn_processors
if do_swap: # Load lora conditions into the model
old_attn_processor = self.override_cross_attention( if extra_conditioning_info.wants_cross_attention_control:
extra_conditioning_info, step_count=step_count cross_attention_control_context = Context(
) arguments=extra_conditioning_info.cross_attention_control_args,
step_count=step_count,
)
setup_cross_attention_control_attention_processors(
unet,
cross_attention_control_context,
)
try: try:
yield None yield None
finally: finally:
if old_attn_processor is not None: if old_attn_processors is not None:
self.restore_default_cross_attention(old_attn_processor) unet.set_attn_processor(old_attn_processors)
# TODO resuscitate attention map saving # TODO resuscitate attention map saving
# self.remove_attention_map_saving() # self.remove_attention_map_saving()

View File

@ -38,7 +38,7 @@ dependencies = [
"albumentations", "albumentations",
"click", "click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel~=1.1.5", "compel~=1.1.5",
"datasets", "datasets",
"diffusers[torch]~=0.16.1", "diffusers[torch]~=0.16.1",
"dnspython==2.2.1", "dnspython==2.2.1",