Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00

make InvokeAIDiffuserComponent.custom_attention_control a classmethod

This commit is contained in:
  parent aca4770481
  commit 037078c8ad
@@ -545,6 +545,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         additional_guidance = []
         extra_conditioning_info = conditioning_data.extra
         with self.invokeai_diffuser.custom_attention_context(
+            self.invokeai_diffuser.model,
             extra_conditioning_info=extra_conditioning_info,
             step_count=len(self.scheduler.timesteps),
         ):
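The only call-site change: the pipeline now passes its UNet into the context manager explicitly instead of letting the diffuser component reach into `self`. A minimal sketch of the resulting usage pattern — names like `run_denoising_loop` are placeholders, not repository code:

    # Hedged sketch: the classmethod context manager can now be entered
    # without an InvokeAIDiffuserComponent instance, given an explicit UNet.
    with InvokeAIDiffuserComponent.custom_attention_context(
        unet,                                             # a UNet2DConditionModel
        extra_conditioning_info=conditioning_data.extra,  # may be None
        step_count=len(scheduler.timesteps),
    ):
        result = run_denoising_loop()  # placeholder for the sampling loop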
@@ -10,6 +10,7 @@ import diffusers
 import psutil
 import torch
 from compel.cross_attention_control import Arguments
+from diffusers.models.unet_2d_condition import UNet2DConditionModel
 from diffusers.models.attention_processor import AttentionProcessor
 from torch import nn
 
@@ -352,8 +353,7 @@ def restore_default_cross_attention(
     else:
         remove_attention_function(model)
 
 
-def override_cross_attention(model, context: Context, is_running_diffusers=False):
+def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
     """
     Inject attention parameters and functions into the passed in model to enable cross attention editing.
 
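The rename drops the `is_running_diffusers` flag: the helper is now diffusers-only and receives the UNet directly. A hedged sketch of how the new entry point is driven — the `Context` construction mirrors the classmethod hunk further below, and `cross_attention_control_args` stands in for real arguments:

    # Hypothetical driver for the renamed helper; argument values are stand-ins.
    context = Context(
        arguments=cross_attention_control_args,  # compel cross-attention Arguments
        step_count=30,
    )
    setup_cross_attention_control_attention_processors(unet, context)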
@@ -379,8 +379,6 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False
 
     context.cross_attention_mask = mask.to(device)
     context.cross_attention_index_map = indices.to(device)
-    if is_running_diffusers:
-        unet = model
-        old_attn_processors = unet.attn_processors
-        if torch.backends.mps.is_available():
-            # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
+    old_attn_processors = unet.attn_processors
+    if torch.backends.mps.is_available():
+        # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
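`unet.attn_processors` returns a dict mapping attention-module names to their processors, and `set_attn_processor` accepts either a single processor for every module or such a dict — which is what makes the snapshot-and-restore pattern in this commit work. A self-contained illustration against diffusers' UNet API; the tiny config is arbitrary, chosen only so the model builds quickly:

    from diffusers import UNet2DConditionModel
    from diffusers.models.attention_processor import AttnProcessor

    # Tiny UNet, just big enough to have cross-attention blocks.
    unet = UNet2DConditionModel(
        sample_size=16,
        block_out_channels=(32, 64),
        down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
        cross_attention_dim=32,
    )

    old_attn_processors = unet.attn_processors    # snapshot: name -> processor
    unet.set_attn_processor(AttnProcessor())      # one instance, applied everywhere
    # ... run with the replacement processors ...
    unet.set_attn_processor(old_attn_processors)  # restore from the dict snapshot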
@@ -388,21 +386,8 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False
-        else:
-            # try to re-use an existing slice size
-            default_slice_size = 4
-            slice_size = next(
-                (
-                    p.slice_size
-                    for p in old_attn_processors.values()
-                    if type(p) is SlicedAttnProcessor
-                ),
-                default_slice_size,
-            )
-            unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
-        return old_attn_processors
-    else:
-        context.register_cross_attention_modules(model)
-        inject_attention_function(model, context)
-        return None
+    else:
+        # try to re-use an existing slice size
+        default_slice_size = 4
+        slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
+        unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
 
 
 def get_cross_attention_modules(
     model, which: CrossAttentionType
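The collapsed `next(...)` call keeps the old behaviour: take the slice size from the first `SlicedAttnProcessor` already installed, otherwise fall back to the default. The idiom in isolation — the stub class below stands in for diffusers' `SlicedAttnProcessor` so the example runs without loading a model:

    class SlicedAttnProcessor:
        # Minimal stand-in for diffusers.models.attention_processor.SlicedAttnProcessor.
        def __init__(self, slice_size: int):
            self.slice_size = slice_size

    default_slice_size = 4
    old_attn_processors = {"down.0": SlicedAttnProcessor(2), "up.0": object()}

    # next(generator, default): first matching value, or the fallback.
    slice_size = next(
        (p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor),
        default_slice_size,
    )
    print(slice_size)  # 2 here; 4 if no SlicedAttnProcessor is present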
@@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Optional, Union
 
 import numpy as np
 import torch
+from diffusers import UNet2DConditionModel
 from diffusers.models.attention_processor import AttentionProcessor
 from typing_extensions import TypeAlias
 
@@ -17,8 +18,8 @@ from .cross_attention_control import (
     CrossAttentionType,
     SwapCrossAttnContext,
     get_cross_attention_modules,
-    override_cross_attention,
     restore_default_cross_attention,
+    setup_cross_attention_control_attention_processors,
 )
 from .cross_attention_map_saving import AttentionMapSaver
 
@@ -79,24 +80,35 @@ class InvokeAIDiffuserComponent:
         self.cross_attention_control_context = None
         self.sequential_guidance = Globals.sequential_guidance
 
+    @classmethod
     @contextmanager
     def custom_attention_context(
-        self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
+        cls,
+        unet: UNet2DConditionModel,  # note: also may futz with the text encoder depending on requested LoRAs
+        extra_conditioning_info: Optional[ExtraConditioningInfo],
+        step_count: int
     ):
-        do_swap = (
-            extra_conditioning_info is not None
-            and extra_conditioning_info.wants_cross_attention_control
-        )
-        old_attn_processor = None
-        if do_swap:
-            old_attn_processor = self.override_cross_attention(
-                extra_conditioning_info, step_count=step_count
-            )
+        old_attn_processors = None
+        if extra_conditioning_info and (
+            extra_conditioning_info.wants_cross_attention_control
+        ):
+            old_attn_processors = unet.attn_processors
+            # Load lora conditions into the model
+            if extra_conditioning_info.wants_cross_attention_control:
+                cross_attention_control_context = Context(
+                    arguments=extra_conditioning_info.cross_attention_control_args,
+                    step_count=step_count,
+                )
+                setup_cross_attention_control_attention_processors(
+                    unet,
+                    cross_attention_control_context,
+                )
 
         try:
             yield None
         finally:
-            if old_attn_processor is not None:
-                self.restore_default_cross_attention(old_attn_processor)
+            if old_attn_processors is not None:
+                unet.set_attn_processor(old_attn_processors)
             # TODO resuscitate attention map saving
             # self.remove_attention_map_saving()
 
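Note the decorator order in the new signature: `@classmethod` sits above `@contextmanager`, so `contextmanager` first turns the generator into a context-manager factory and `classmethod` then binds that factory to the class, letting callers enter the context without an instance. A minimal, self-contained sketch of the same snapshot/yield/restore shape — none of these names are repository code:

    from contextlib import contextmanager

    class FakeUNet:
        """Toy stand-in for UNet2DConditionModel's processor API."""
        def __init__(self):
            self._processors = {"attn1": "default"}

        @property
        def attn_processors(self):
            return dict(self._processors)       # snapshot-friendly copy

        def set_attn_processor(self, processors):
            self._processors = dict(processors)

    class DiffuserComponent:
        @classmethod        # outer: bind to the class, no instance required
        @contextmanager     # inner: turn the generator into a context manager
        def custom_attention_context(cls, unet, wants_control: bool):
            old = unet.attn_processors if wants_control else None
            if wants_control:
                unet.set_attn_processor({"attn1": "swap"})
            try:
                yield
            finally:
                if old is not None:
                    unet.set_attn_processor(old)  # restore even if the body raised

    unet = FakeUNet()
    with DiffuserComponent.custom_attention_context(unet, wants_control=True):
        assert unet.attn_processors["attn1"] == "swap"
    assert unet.attn_processors["attn1"] == "default"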