make InvokeAIDiffuserComponent.custom_attention_control a classmethod

Lincoln Stein 2023-05-11 21:13:18 -04:00
parent aca4770481
commit 037078c8ad
3 changed files with 39 additions and 41 deletions

View File

@@ -545,6 +545,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             additional_guidance = []
         extra_conditioning_info = conditioning_data.extra
         with self.invokeai_diffuser.custom_attention_context(
+            self.invokeai_diffuser.model,
             extra_conditioning_info=extra_conditioning_info,
             step_count=len(self.scheduler.timesteps),
         ):
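Note: after this change the pipeline hands its UNet to the context manager explicitly instead of letting the diffuser component reach for its own model. A minimal sketch of the new call shape, where `pipeline` and `conditioning_data` are stand-ins for whatever the surrounding generation code already has in scope:

# Sketch only: `pipeline` is assumed to be a StableDiffusionGeneratorPipeline
# and `conditioning_data.extra` an ExtraConditioningInfo (or None), as above.
with pipeline.invokeai_diffuser.custom_attention_context(
    pipeline.invokeai_diffuser.model,  # the UNet is now an explicit argument
    extra_conditioning_info=conditioning_data.extra,
    step_count=len(pipeline.scheduler.timesteps),
):
    # attention processors are swapped only inside this block
    # and restored when it exits
    ...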

View File

@@ -10,6 +10,7 @@ import diffusers
 import psutil
 import torch
 from compel.cross_attention_control import Arguments
+from diffusers.models.unet_2d_condition import UNet2DConditionModel
 from diffusers.models.attention_processor import AttentionProcessor
 from torch import nn
@@ -352,8 +353,7 @@ def restore_default_cross_attention(
     else:
         remove_attention_function(model)


-def override_cross_attention(model, context: Context, is_running_diffusers=False):
+def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
     """
     Inject attention parameters and functions into the passed in model to enable cross attention editing.
@@ -372,15 +372,13 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False
     indices = torch.arange(max_length, dtype=torch.long)
     for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
         if b0 < max_length:
-            if name == "equal": # or (name == "replace" and a1 - a0 == b1 - b0):
+            if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
                 # these tokens have not been edited
                 indices[b0:b1] = indices_target[a0:a1]
                 mask[b0:b1] = 1

     context.cross_attention_mask = mask.to(device)
     context.cross_attention_index_map = indices.to(device)
-    if is_running_diffusers:
-        unet = model
-        old_attn_processors = unet.attn_processors
-        if torch.backends.mps.is_available():
-            # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
+    old_attn_processors = unet.attn_processors
+    if torch.backends.mps.is_available():
+        # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
@@ -388,21 +386,8 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False
-        else:
-            # try to re-use an existing slice size
-            default_slice_size = 4
-            slice_size = next(
-                (
-                    p.slice_size
-                    for p in old_attn_processors.values()
-                    if type(p) is SlicedAttnProcessor
-                ),
-                default_slice_size,
-            )
-            unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
-            return old_attn_processors
-    else:
-        context.register_cross_attention_modules(model)
-        inject_attention_function(model, context)
-        return None
+    else:
+        # try to re-use an existing slice size
+        default_slice_size = 4
+        slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
+        unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))


 def get_cross_attention_modules(
     model, which: CrossAttentionType
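Note: the renamed helper is now driven purely by a UNet2DConditionModel and a cross-attention Context; the old is_running_diffusers branch (module registration plus inject_attention_function) is gone, and restoring the original processors is left to the caller. A hedged sketch of direct use, with `unet`, `extra_conditioning_info`, and `total_steps` assumed to come from the surrounding inference code:

# Sketch only; Context and setup_cross_attention_control_attention_processors
# are the objects shown in the diff above.
old_attn_processors = unet.attn_processors  # keep a handle so they can be restored
context = Context(
    arguments=extra_conditioning_info.cross_attention_control_args,
    step_count=total_steps,
)
setup_cross_attention_control_attention_processors(unet, context)
try:
    ...  # run inference with the swap (or sliced-swap) processors installed
finally:
    unet.set_attn_processor(old_attn_processors)  # put the originals back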

View File

@@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Optional, Union
 import numpy as np
 import torch
+from diffusers import UNet2DConditionModel
 from diffusers.models.attention_processor import AttentionProcessor
 from typing_extensions import TypeAlias
@@ -17,8 +18,8 @@ from .cross_attention_control import (
     CrossAttentionType,
     SwapCrossAttnContext,
     get_cross_attention_modules,
-    override_cross_attention,
     restore_default_cross_attention,
+    setup_cross_attention_control_attention_processors,
 )
 from .cross_attention_map_saving import AttentionMapSaver
@@ -79,24 +80,35 @@ class InvokeAIDiffuserComponent:
         self.cross_attention_control_context = None
         self.sequential_guidance = Globals.sequential_guidance

+    @classmethod
     @contextmanager
     def custom_attention_context(
-        self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
+        cls,
+        unet: UNet2DConditionModel,  # note: also may futz with the text encoder depending on requested LoRAs
+        extra_conditioning_info: Optional[ExtraConditioningInfo],
+        step_count: int
     ):
-        do_swap = (
-            extra_conditioning_info is not None
-            and extra_conditioning_info.wants_cross_attention_control
-        )
-        old_attn_processor = None
-        if do_swap:
-            old_attn_processor = self.override_cross_attention(
-                extra_conditioning_info, step_count=step_count
-            )
+        old_attn_processors = None
+        if extra_conditioning_info and (
+            extra_conditioning_info.wants_cross_attention_control
+        ):
+            old_attn_processors = unet.attn_processors
+            # Load lora conditions into the model
+            if extra_conditioning_info.wants_cross_attention_control:
+                cross_attention_control_context = Context(
+                    arguments=extra_conditioning_info.cross_attention_control_args,
+                    step_count=step_count,
+                )
+                setup_cross_attention_control_attention_processors(
+                    unet,
+                    cross_attention_control_context,
+                )
         try:
             yield None
         finally:
-            if old_attn_processor is not None:
-                self.restore_default_cross_attention(old_attn_processor)
+            if old_attn_processors is not None:
+                unet.set_attn_processor(old_attn_processors)
             # TODO resuscitate attention map saving
             # self.remove_attention_map_saving()
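Note: because custom_attention_context is now a classmethod, it can be entered with any UNet2DConditionModel without first constructing an InvokeAIDiffuserComponent. A hedged sketch of the resulting usage; the module path, `unet`, `extra_info`, and `num_inference_steps` are assumptions, not part of this diff:

# Sketch only; the import path may differ between InvokeAI versions.
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
    InvokeAIDiffuserComponent,
)

with InvokeAIDiffuserComponent.custom_attention_context(
    unet,                                # a diffusers UNet2DConditionModel
    extra_conditioning_info=extra_info,  # may be None, in which case nothing is swapped
    step_count=num_inference_steps,
):
    ...  # denoise; the original attention processors are restored on exit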