more wip sliced attention (.swap doesn't work yet)

Damian Stewart 2023-01-25 14:51:08 +01:00
parent 63c6019f92
commit a4aea1540b
3 changed files with 75 additions and 49 deletions


@@ -306,6 +306,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if is_xformers_available() and not Globals.disable_xformers:
             self.enable_xformers_memory_efficient_attention()
+        else:
+            slice_size = 2
+            self.enable_attention_slicing(slice_size=slice_size)

     def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                               conditioning_data: ConditioningData,
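
The new else branch gives the pipeline a fallback when xformers is unavailable or disabled: plain sliced attention. A minimal sketch of the same decision against the stock diffusers pipeline API; the model id and the fixed slice_size of 2 are illustrative assumptions, not part of this commit:

    from diffusers import StableDiffusionPipeline
    from diffusers.utils import is_xformers_available

    # illustrative checkpoint; any StableDiffusionPipeline behaves the same way here
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    if is_xformers_available():
        # memory-efficient attention kernels, fastest option when installed
        pipe.enable_xformers_memory_efficient_attention()
    else:
        # compute attention in smaller slices: lower peak VRAM, slightly slower
        pipe.enable_attention_slicing(slice_size=2)
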
@@ -370,11 +373,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if additional_guidance is None:
             additional_guidance = []
         extra_conditioning_info = conditioning_data.extra
-        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info,
-                                                                  step_count=len(self.scheduler.timesteps))
-        else:
-            self.invokeai_diffuser.remove_cross_attention_control()
+        with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info,
+                                                             step_count=len(self.scheduler.timesteps),
+                                                             do_attention_map_saving=False):

         yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
                                         latents=latents)
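
Replacing the paired setup/remove calls with a with-block means the swap attention processors are always uninstalled when generation ends, even if a denoising step raises. A minimal sketch of the pattern the pipeline now relies on, using hypothetical install_swap_processors()/restore_processors() helpers rather than the real InvokeAI functions:

    from contextlib import contextmanager

    @contextmanager
    def custom_attention_context(wants_cross_attention_control: bool):
        saved = install_swap_processors() if wants_cross_attention_control else None  # hypothetical helper
        try:
            yield
        finally:
            # runs on normal exit and on exceptions alike
            restore_processors(saved)  # hypothetical helper
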
@@ -385,7 +386,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         latents = self.scheduler.add_noise(latents, noise, batched_t)

         attention_map_saver: Optional[AttentionMapSaver] = None
-        self.invokeai_diffuser.remove_attention_map_saving()

         for i, t in enumerate(self.progress_bar(timesteps)):
             batched_t.fill_(t)
@@ -396,16 +396,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             latents = step_output.prev_sample
             predicted_original = getattr(step_output, 'pred_original_sample', None)

-            if i == len(timesteps)-1 and extra_conditioning_info is not None:
-                eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
-                attention_map_token_ids = range(1, eos_token_index)
-                attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
-                self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
+            # TODO resuscitate attention map saving
+            #if i == len(timesteps)-1 and extra_conditioning_info is not None:
+            #    eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
+            #    attention_map_token_ids = range(1, eos_token_index)
+            #    attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
+            #    self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)

             yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
                                             predicted_original=predicted_original, attention_map_saver=attention_map_saver)

-        self.invokeai_diffuser.remove_attention_map_saving()
         return latents, attention_map_saver

     @torch.inference_mode()
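
For reference, the disabled block chose which token positions an AttentionMapSaver would track: everything between BOS and EOS. A worked example of that index arithmetic with an illustrative 7-token prompt:

    # "a photo of a cat" -> [BOS, a, photo, of, a, cat, EOS]
    tokens_count_including_eos_bos = 7
    eos_token_index = tokens_count_including_eos_bos - 1    # 6
    attention_map_token_ids = range(1, eos_token_index)     # skip BOS (0) and EOS (6)
    print(list(attention_map_token_ids))                    # [1, 2, 3, 4, 5]
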


@@ -7,6 +7,7 @@ import torch
 import diffusers
 from torch import nn
 from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from diffusers.models.cross_attention import AttnProcessor

 from ldm.invoke.devices import torch_dtype
@@ -305,11 +306,10 @@ class InvokeAICrossAttentionMixin:

-def remove_cross_attention_control(model, is_running_diffusers: bool):
+def remove_cross_attention_control(model, is_running_diffusers: bool, restore_attention_processor: Optional[AttnProcessor]=None):
     if is_running_diffusers:
         unet = model
-        print("** need to know what cross attn processor to use by default, None in the following line is wrong")
-        unet.set_attn_processor(CrossAttnProcessor())
+        unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor())
     else:
         remove_attention_function(model)
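
remove_cross_attention_control() can now put back whatever processors were active before the swap instead of always resetting to a bare CrossAttnProcessor. A sketch of that save/restore round trip on a diffusers UNet2DConditionModel; the unet instance is assumed to exist, and MySwapProcessor is a hypothetical stand-in:

    from diffusers.models.cross_attention import CrossAttnProcessor

    saved = unet.attn_processors                      # dict of module name -> processor
    try:
        unet.set_attn_processor(MySwapProcessor())    # hypothetical custom processor
        ...                                           # run the edited-prompt denoising here
    finally:
        # set_attn_processor accepts either a single processor or a dict,
        # so the saved mapping can be reinstalled as-is
        unet.set_attn_processor(saved or CrossAttnProcessor())
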
@@ -343,10 +343,16 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers
     context.cross_attention_index_map = indices.to(device)
     if is_running_diffusers:
         unet = model
-        unet.set_attn_processor(SwapCrossAttnProcessor())
+        old_attn_processors = unet.attn_processors
+        # try to re-use an existing slice size
+        default_slice_size = 4
+        slice_size = next((p for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
+        unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
+        return old_attn_processors
     else:
         context.register_cross_attention_modules(model)
         inject_attention_function(model, context)
+        return None
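
The intent of the new setup path is to reuse whatever slice size enable_attention_slicing already configured on the UNet, falling back to a default when no sliced processor is installed. A hedged sketch of that lookup, reading the slice_size attribute that diffusers' SlicedAttnProcessor stores; the unet variable is assumed as above:

    from diffusers.models.cross_attention import SlicedAttnProcessor

    default_slice_size = 4
    slice_size = next(
        (p.slice_size for p in unet.attn_processors.values() if isinstance(p, SlicedAttnProcessor)),
        default_slice_size,
    )
    unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
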
@@ -509,13 +515,11 @@ class CrossAttnProcessor:
         return hidden_states
 """

-import enum
 from dataclasses import field, dataclass

 import torch

-from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor
-from ldm.models.diffusion.cross_attention_control import CrossAttentionType
+from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor, AttnProcessor

 @dataclass
@@ -523,7 +527,7 @@ class SwapCrossAttnContext:
     modified_text_embeddings: torch.Tensor
     index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt
     mask: torch.Tensor # in the target space of the index_map
-    cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=[])
+    cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)

     def __int__(self,
                 cac_types_to_do: [CrossAttentionType],
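
The default_factory change is more than cosmetic: dataclasses call the factory to build each instance's default, so it must be a callable such as list, not an already-built []. A minimal illustration:

    from dataclasses import dataclass, field

    @dataclass
    class Good:
        items: list = field(default_factory=list)    # list() is called per instance

    @dataclass
    class Bad:
        items: list = field(default_factory=[])      # a list literal is not callable

    Good()    # ok: items == []
    Bad()     # TypeError: 'list' object is not callable
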
@@ -629,9 +633,6 @@ class SwapCrossAttnProcessor(CrossAttnProcessor):

 class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):

-    def __init__(self, slice_size = 1e6):
-        self.slice_count = slice_size
-
     # TODO: dynamically pick slice size based on memory conditions

     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
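
With the local __init__ removed, SlicedSwapCrossAttnProcesser now takes its slice size from the parent constructor rather than keeping its own slice_count attribute with a default of 1e6 (large enough to effectively disable slicing). diffusers' SlicedAttnProcessor amounts to roughly the following, so construction stays a one-liner:

    # simplified view of the diffusers base class; the real __call__ chunks the attention math
    class SlicedAttnProcessor:
        def __init__(self, slice_size):
            self.slice_size = slice_size

    processor = SlicedSwapCrossAttnProcesser(slice_size=4)
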
@@ -660,7 +661,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
         original_text_key = attn.head_to_batch_dim(original_text_key)

         modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
         modified_text_key = attn.to_k(modified_text_embeddings)
-        modified_text_key = attn.head_to_batch_dim(original_text_key)
+        modified_text_key = attn.head_to_batch_dim(modified_text_key)

         # for the "value" just use the modified text embeddings.
         value = attn.to_v(modified_text_embeddings)
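
The one-line fix matters because head_to_batch_dim was previously applied to original_text_key a second time, so the modified prompt's keys were discarded. What that reshape does, sketched with plain torch and illustrative sizes (77 tokens, 8 heads):

    import torch

    batch, tokens, inner_dim, heads = 2, 77, 320, 8
    key = torch.randn(batch, tokens, inner_dim)

    # roughly what CrossAttention.head_to_batch_dim does: fold the head axis into the batch axis
    key = key.reshape(batch, tokens, heads, inner_dim // heads)
    key = key.permute(0, 2, 1, 3).reshape(batch * heads, tokens, inner_dim // heads)

    print(key.shape)    # torch.Size([16, 77, 40])
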


@@ -1,11 +1,13 @@
 import math
+from contextlib import contextmanager
 from dataclasses import dataclass
 from math import ceil
-from typing import Callable, Optional, Union, Any
+from typing import Callable, Optional, Union, Any, Dict

 import numpy as np
 import torch
+from diffusers.models.cross_attention import AttnProcessor

 from ldm.models.diffusion.cross_attention_control import Arguments, \
     remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, \
     CrossAttentionType, SwapCrossAttnContext
@@ -55,20 +57,43 @@ class InvokeAIDiffuserComponent:
         self.model_forward_callback = model_forward_callback
         self.cross_attention_control_context = None

-    def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
+    @contextmanager
+    def custom_attention_context(self,
+                                 extra_conditioning_info: Optional[ExtraConditioningInfo],
+                                 step_count: int,
+                                 do_attention_map_saving: bool):
+        do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
+        old_attn_processor = None
+        if do_swap:
+            old_attn_processor = self.setup_cross_attention_control(extra_conditioning_info,
+                                                                    step_count=step_count)
+        try:
+            yield None
+        finally:
+            self.remove_cross_attention_control(old_attn_processor)
+            # TODO resuscitate attention map saving
+            #self.remove_attention_map_saving()
+
+    def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
+        """
+        setup cross attention .swap control. for diffusers this replaces the attention processor, so
+        the previous attention processor is returned so that the caller can restore it later.
+        """
         self.conditioning = conditioning
         self.cross_attention_control_context = Context(
             arguments=self.conditioning.cross_attention_control_args,
             step_count=step_count
         )
-        setup_cross_attention_control(self.model,
-                                      self.cross_attention_control_context,
-                                      is_running_diffusers=self.is_running_diffusers)
+        return setup_cross_attention_control(self.model,
+                                             self.cross_attention_control_context,
+                                             is_running_diffusers=self.is_running_diffusers)

-    def remove_cross_attention_control(self):
+    def remove_cross_attention_control(self, restore_attention_processor: Optional['AttnProcessor']=None):
         self.conditioning = None
         self.cross_attention_control_context = None
-        remove_cross_attention_control(self.model, is_running_diffusers=self.is_running_diffusers)
+        remove_cross_attention_control(self.model,
+                                       is_running_diffusers=self.is_running_diffusers,
+                                       restore_attention_processor=restore_attention_processor)

     def setup_attention_map_saving(self, saver: AttentionMapSaver):
         def callback(slice, dim, offset, slice_size, key):
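
Taken together with the pipeline change earlier in this commit, the intended call pattern for the new context manager is roughly the following sketch; the invokeai_diffuser, conditioning_data and scheduler objects are assumed to exist, and do_attention_map_saving is still a no-op at this point:

    with invokeai_diffuser.custom_attention_context(
            extra_conditioning_info=conditioning_data.extra,
            step_count=len(scheduler.timesteps),
            do_attention_map_saving=False):
        for t in scheduler.timesteps:
            ...  # each denoising step runs with SlicedSwapCrossAttnProcesser installed
    # on exit the previous attention processors are restored, even if a step raised
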