more wip sliced attention (.swap doesn't work yet)

Damian Stewart 2023-01-25 14:51:08 +01:00
parent 63c6019f92
commit a4aea1540b
3 changed files with 75 additions and 49 deletions

View File

@@ -306,6 +306,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if is_xformers_available() and not Globals.disable_xformers:
             self.enable_xformers_memory_efficient_attention()
+        else:
+            slice_size = 2
+            self.enable_attention_slicing(slice_size=slice_size)
 
     def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                               conditioning_data: ConditioningData,
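
For orientation, the fallback this hunk adds mirrors the usual diffusers memory-optimisation choice: use xformers when it is installed, otherwise fall back to sliced attention. A minimal stand-alone sketch of the same decision on a stock pipeline (the pipeline name and slice size here are illustrative, not taken from this commit); `Globals.disable_xformers` in the hunk is InvokeAI's own opt-out switch and is not part of diffusers:

from diffusers import StableDiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

if is_xformers_available():
    # memory-efficient attention kernels; no slicing needed
    pipe.enable_xformers_memory_efficient_attention()
else:
    # compute attention in slices along the batch*heads dimension;
    # smaller slices trade speed for lower peak memory
    pipe.enable_attention_slicing(slice_size=2)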
@@ -370,43 +373,40 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if additional_guidance is None:
             additional_guidance = []
         extra_conditioning_info = conditioning_data.extra
-        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info,
-                                                                 step_count=len(self.scheduler.timesteps))
-        else:
-            self.invokeai_diffuser.remove_cross_attention_control()
+        with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info,
+                                                             step_count=len(self.scheduler.timesteps),
+                                                             do_attention_map_saving=False):
 
-        yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
-                                        latents=latents)
+            yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
+                                            latents=latents)
 
-        batch_size = latents.shape[0]
-        batched_t = torch.full((batch_size,), timesteps[0],
-                               dtype=timesteps.dtype, device=self.unet.device)
-        latents = self.scheduler.add_noise(latents, noise, batched_t)
+            batch_size = latents.shape[0]
+            batched_t = torch.full((batch_size,), timesteps[0],
+                                   dtype=timesteps.dtype, device=self.unet.device)
+            latents = self.scheduler.add_noise(latents, noise, batched_t)
 
-        attention_map_saver: Optional[AttentionMapSaver] = None
-        self.invokeai_diffuser.remove_attention_map_saving()
+            attention_map_saver: Optional[AttentionMapSaver] = None
 
-        for i, t in enumerate(self.progress_bar(timesteps)):
-            batched_t.fill_(t)
-            step_output = self.step(batched_t, latents, conditioning_data,
-                                    step_index=i,
-                                    total_step_count=len(timesteps),
-                                    additional_guidance=additional_guidance)
-            latents = step_output.prev_sample
-            predicted_original = getattr(step_output, 'pred_original_sample', None)
+            for i, t in enumerate(self.progress_bar(timesteps)):
+                batched_t.fill_(t)
+                step_output = self.step(batched_t, latents, conditioning_data,
+                                        step_index=i,
+                                        total_step_count=len(timesteps),
+                                        additional_guidance=additional_guidance)
+                latents = step_output.prev_sample
+                predicted_original = getattr(step_output, 'pred_original_sample', None)
 
-            if i == len(timesteps)-1 and extra_conditioning_info is not None:
-                eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
-                attention_map_token_ids = range(1, eos_token_index)
-                attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
-                self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
+                # TODO resuscitate attention map saving
+                #if i == len(timesteps)-1 and extra_conditioning_info is not None:
+                #    eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
+                #    attention_map_token_ids = range(1, eos_token_index)
+                #    attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
+                #    self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
 
-            yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
-                                            predicted_original=predicted_original, attention_map_saver=attention_map_saver)
+                yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
+                                                predicted_original=predicted_original, attention_map_saver=attention_map_saver)
 
-        self.invokeai_diffuser.remove_attention_map_saving()
-        return latents, attention_map_saver
+            return latents, attention_map_saver
 
     @torch.inference_mode()
     def step(self, t: torch.Tensor, latents: torch.Tensor,
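
The rewrite above trades an explicit setup/remove pair for a context manager, so the swapped attention processors are torn down even if a denoising step raises. A minimal sketch of that pattern, assuming a diffusers UNet whose set_attn_processor also accepts the dict returned by attn_processors (names here are illustrative, not the commit's API):

from contextlib import contextmanager

@contextmanager
def swapped_attn_processor(unet, replacement_processor):
    # remember whatever processors are currently installed (dict: module name -> processor)
    old_processors = unet.attn_processors
    unet.set_attn_processor(replacement_processor)
    try:
        yield
    finally:
        # runs on normal exit and on exception, so the UNet is never left
        # with the replacement processor installed
        unet.set_attn_processor(old_processors)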

View File

@@ -7,6 +7,7 @@ import torch
 import diffusers
 from torch import nn
 from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from diffusers.models.cross_attention import AttnProcessor
 from ldm.invoke.devices import torch_dtype
@ -305,11 +306,10 @@ class InvokeAICrossAttentionMixin:
def remove_cross_attention_control(model, is_running_diffusers: bool):
def remove_cross_attention_control(model, is_running_diffusers: bool, restore_attention_processor: Optional[AttnProcessor]=None):
if is_running_diffusers:
unet = model
print("** need to know what cross attn processor to use by default, None in the following line is wrong")
unet.set_attn_processor(CrossAttnProcessor())
unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor())
else:
remove_attention_function(model)
@@ -343,10 +343,16 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers
     context.cross_attention_index_map = indices.to(device)
     if is_running_diffusers:
         unet = model
-        unet.set_attn_processor(SwapCrossAttnProcessor())
+        old_attn_processors = unet.attn_processors
+        # try to re-use an existing slice size
+        default_slice_size = 4
+        slice_size = next((p for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
+        unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
+        return old_attn_processors
     else:
         context.register_cross_attention_modules(model)
         inject_attention_function(model, context)
+        return None
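
The replacement hunk captures the UNet's existing processors before installing SlicedSwapCrossAttnProcesser, and tries to carry over a previously configured slice size. A hedged sketch of that lookup, assuming SlicedAttnProcessor exposes its configured size as a .slice_size attribute (note the generator would then need to yield p.slice_size rather than the processor object itself, which the WIP hunk above does not yet do):

from diffusers.models.cross_attention import SlicedAttnProcessor

def existing_slice_size(attn_processors: dict, default_slice_size: int = 4) -> int:
    """Return the slice size of any SlicedAttnProcessor already installed, else a default."""
    return next((p.slice_size for p in attn_processors.values()
                 if isinstance(p, SlicedAttnProcessor)),
                default_slice_size)

The dict returned by unet.attn_processors is what later gets handed back through remove_cross_attention_control to restore the original configuration.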
@@ -509,13 +515,11 @@ class CrossAttnProcessor:
         return hidden_states
 """
 import enum
 from dataclasses import field, dataclass
 
 import torch
 
-from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor
 from ldm.models.diffusion.cross_attention_control import CrossAttentionType
+from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor, AttnProcessor
 
 @dataclass
@@ -523,7 +527,7 @@ class SwapCrossAttnContext:
     modified_text_embeddings: torch.Tensor
     index_map: torch.Tensor  # maps from original prompt token indices to the equivalent tokens in the modified prompt
     mask: torch.Tensor  # in the target space of the index_map
-    cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=[])
+    cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)
 
     def __int__(self,
                 cac_types_to_do: [CrossAttentionType],
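
The field(default_factory=list) fix above is the standard dataclasses idiom: default_factory must be a zero-argument callable that is invoked per instance, whereas the previous default_factory=[] passes a list object and fails as soon as the default is needed. A small self-contained illustration, independent of the commit:

from dataclasses import dataclass, field

@dataclass
class Example:
    # the factory is called once per instance, so instances never share a list
    items: list = field(default_factory=list)

a, b = Example(), Example()
a.items.append(1)
assert b.items == []  # b got its own fresh list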
@@ -629,9 +633,6 @@ class SwapCrossAttnProcessor(CrossAttnProcessor):
 class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
 
-    def __init__(self, slice_size = 1e6):
-        self.slice_count = slice_size
-
     # TODO: dynamically pick slice size based on memory conditions
 
     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
@@ -660,7 +661,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
         original_text_key = attn.head_to_batch_dim(original_text_key)
 
         modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
         modified_text_key = attn.to_k(modified_text_embeddings)
-        modified_text_key = attn.head_to_batch_dim(original_text_key)
+        modified_text_key = attn.head_to_batch_dim(modified_text_key)
 
         # for the "value" just use the modified text embeddings.
         value = attn.to_v(modified_text_embeddings)
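
The one-line fix above matters because each prompt's keys must be projected and reshaped from its own embeddings; reshaping original_text_key a second time would feed the original prompt's keys in where the modified prompt's keys are expected. A sketch of the intended projection, using the CrossAttention helpers (to_k, to_v, head_to_batch_dim) from the diffusers version this commit targets:

import torch
from diffusers.models.cross_attention import CrossAttention

def project_swap_keys(attn: CrossAttention,
                      original_text_embeddings: torch.Tensor,
                      modified_text_embeddings: torch.Tensor):
    # keys for the original prompt
    original_text_key = attn.head_to_batch_dim(attn.to_k(original_text_embeddings))
    # keys for the modified prompt -- note the reshape uses the *modified* key
    modified_text_key = attn.head_to_batch_dim(attn.to_k(modified_text_embeddings))
    # values always come from the modified prompt
    value = attn.head_to_batch_dim(attn.to_v(modified_text_embeddings))
    return original_text_key, modified_text_key, value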

View File

@@ -1,11 +1,13 @@
 import math
+from contextlib import contextmanager
 from dataclasses import dataclass
 from math import ceil
-from typing import Callable, Optional, Union, Any
+from typing import Callable, Optional, Union, Any, Dict
 
 import numpy as np
 import torch
+from diffusers.models.cross_attention import AttnProcessor
 from ldm.models.diffusion.cross_attention_control import Arguments, \
     remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, \
     CrossAttentionType, SwapCrossAttnContext
@@ -55,20 +57,43 @@ class InvokeAIDiffuserComponent:
         self.model_forward_callback = model_forward_callback
         self.cross_attention_control_context = None
 
-    def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
+    @contextmanager
+    def custom_attention_context(self,
+                                 extra_conditioning_info: Optional[ExtraConditioningInfo],
+                                 step_count: int,
+                                 do_attention_map_saving: bool):
+        do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
+        old_attn_processor = None
+        if do_swap:
+            old_attn_processor = self.setup_cross_attention_control(extra_conditioning_info,
+                                                                    step_count=step_count)
+        try:
+            yield None
+        finally:
+            self.remove_cross_attention_control(old_attn_processor)
+            # TODO resuscitate attention map saving
+            #self.remove_attention_map_saving()
+
+    def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
+        """
+        setup cross attention .swap control. for diffusers this replaces the attention processor, so
+        the previous attention processor is returned so that the caller can restore it later.
+        """
         self.conditioning = conditioning
         self.cross_attention_control_context = Context(
             arguments=self.conditioning.cross_attention_control_args,
             step_count=step_count
         )
-        setup_cross_attention_control(self.model,
-                                      self.cross_attention_control_context,
-                                      is_running_diffusers=self.is_running_diffusers)
+        return setup_cross_attention_control(self.model,
+                                             self.cross_attention_control_context,
+                                             is_running_diffusers=self.is_running_diffusers)
 
-    def remove_cross_attention_control(self):
+    def remove_cross_attention_control(self, restore_attention_processor: Optional['AttnProcessor']=None):
         self.conditioning = None
         self.cross_attention_control_context = None
-        remove_cross_attention_control(self.model, is_running_diffusers=self.is_running_diffusers)
+        remove_cross_attention_control(self.model,
+                                       is_running_diffusers=self.is_running_diffusers,
+                                       restore_attention_processor=restore_attention_processor)
 
     def setup_attention_map_saving(self, saver: AttentionMapSaver):
         def callback(slice, dim, offset, slice_size, key):
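
From the caller's side (see the pipeline change in the first file), the whole denoising loop now runs inside custom_attention_context. A hypothetical, simplified driver showing the intended usage; step_fn and the surrounding names are placeholders, not the pipeline's real signature:

def denoise_with_attention_control(diffuser, latents, timesteps, step_fn,
                                   extra_conditioning_info=None):
    # .swap processors are installed on entry and restored on exit,
    # whether the loop finishes or raises
    with diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info,
                                           step_count=len(timesteps),
                                           do_attention_map_saving=False):
        for t in timesteps:
            latents = step_fn(t, latents)
    return latents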