mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

more wip sliced attention (.swap doesn't work yet)

commit a4aea1540b (parent 63c6019f92)
@@ -306,6 +306,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if is_xformers_available() and not Globals.disable_xformers:
             self.enable_xformers_memory_efficient_attention()
+        else:
+            slice_size = 2
+            self.enable_attention_slicing(slice_size=slice_size)

     def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                               conditioning_data: ConditioningData,
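This hunk makes the pipeline fall back to sliced attention whenever xformers is unavailable or disabled. A minimal sketch of the same fallback on a stock diffusers pipeline (the model id and slice size are illustrative, not taken from this commit):

    import torch
    from diffusers import StableDiffusionPipeline
    from diffusers.utils.import_utils import is_xformers_available

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    if is_xformers_available():
        # memory-efficient attention kernels when the xformers package is installed
        pipe.enable_xformers_memory_efficient_attention()
    else:
        # otherwise compute attention in fixed-size slices to cap peak VRAM
        pipe.enable_attention_slicing(slice_size=2)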
@@ -370,43 +373,40 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if additional_guidance is None:
             additional_guidance = []
         extra_conditioning_info = conditioning_data.extra
-        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info,
-                                                                 step_count=len(self.scheduler.timesteps))
-        else:
-            self.invokeai_diffuser.remove_cross_attention_control()
+        with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info,
+                                                             step_count=len(self.scheduler.timesteps),
+                                                             do_attention_map_saving=False):

             yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
                                             latents=latents)

             batch_size = latents.shape[0]
             batched_t = torch.full((batch_size,), timesteps[0],
                                    dtype=timesteps.dtype, device=self.unet.device)
             latents = self.scheduler.add_noise(latents, noise, batched_t)

             attention_map_saver: Optional[AttentionMapSaver] = None
-            self.invokeai_diffuser.remove_attention_map_saving()

             for i, t in enumerate(self.progress_bar(timesteps)):
                 batched_t.fill_(t)
                 step_output = self.step(batched_t, latents, conditioning_data,
                                         step_index=i,
                                         total_step_count=len(timesteps),
                                         additional_guidance=additional_guidance)
                 latents = step_output.prev_sample
                 predicted_original = getattr(step_output, 'pred_original_sample', None)

-                if i == len(timesteps)-1 and extra_conditioning_info is not None:
-                    eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
-                    attention_map_token_ids = range(1, eos_token_index)
-                    attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
-                    self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
+                # TODO resuscitate attention map saving
+                #if i == len(timesteps)-1 and extra_conditioning_info is not None:
+                #    eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
+                #    attention_map_token_ids = range(1, eos_token_index)
+                #    attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
+                #    self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)

                 yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
                                                 predicted_original=predicted_original, attention_map_saver=attention_map_saver)

-            self.invokeai_diffuser.remove_attention_map_saving()
-            return latents, attention_map_saver
+            return latents, attention_map_saver

     @torch.inference_mode()
     def step(self, t: torch.Tensor, latents: torch.Tensor,
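Replacing the explicit setup/remove pair with a single with-block means the swapped-in attention processors are restored on every exit path, including when the caller stops consuming the generator early. A minimal sketch of the pattern (names here are illustrative, not the project's API):

    from contextlib import contextmanager

    @contextmanager
    def swapped_attention(setup, teardown):
        state = setup()
        try:
            yield state
        finally:
            teardown(state)  # runs on normal exit, on exception, and when the enclosing generator is closed early

    def denoise(timesteps, setup, teardown):
        with swapped_attention(setup, teardown):
            for t in timesteps:
                yield t  # intermediate states, as in the pipeline hunk above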
@@ -7,6 +7,7 @@ import torch
 import diffusers
 from torch import nn
 from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from diffusers.models.cross_attention import AttnProcessor
 from ldm.invoke.devices import torch_dtype


@@ -305,11 +306,10 @@ class InvokeAICrossAttentionMixin:



-def remove_cross_attention_control(model, is_running_diffusers: bool):
+def remove_cross_attention_control(model, is_running_diffusers: bool, restore_attention_processor: Optional[AttnProcessor]=None):
     if is_running_diffusers:
         unet = model
-        print("** need to know what cross attn processor to use by default, None in the following line is wrong")
-        unet.set_attn_processor(CrossAttnProcessor())
+        unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor())
     else:
         remove_attention_function(model)

@@ -343,10 +343,16 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers
     context.cross_attention_index_map = indices.to(device)
     if is_running_diffusers:
         unet = model
-        unet.set_attn_processor(SwapCrossAttnProcessor())
+        old_attn_processors = unet.attn_processors
+        # try to re-use an existing slice size
+        default_slice_size = 4
+        slice_size = next((p for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
+        unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
+        return old_attn_processors
     else:
         context.register_cross_attention_modules(model)
         inject_attention_function(model, context)
+        return None



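The diffusers branch now snapshots the UNet's existing attention processors and returns them so the caller can restore them later, and it tries to carry over an existing slice size. As written, though, the generator expression yields the matching processor object itself rather than its numeric slice size (the commit message flags .swap as not working yet). A sketch of extracting the number, assuming diffusers' SlicedAttnProcessor keeps the value on a slice_size attribute:

    from diffusers.models.cross_attention import SlicedAttnProcessor

    def existing_slice_size(attn_processors: dict, default: int = 4) -> int:
        # reuse the slice size of any SlicedAttnProcessor already installed on the unet,
        # otherwise fall back to a default value
        return next(
            (p.slice_size for p in attn_processors.values() if isinstance(p, SlicedAttnProcessor)),
            default,
        )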
@@ -509,13 +515,11 @@ class CrossAttnProcessor:
         return hidden_states

 """
-import enum
 from dataclasses import field, dataclass

 import torch

-from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor
-from ldm.models.diffusion.cross_attention_control import CrossAttentionType
+from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor, AttnProcessor


 @dataclass
@@ -523,7 +527,7 @@ class SwapCrossAttnContext:
     modified_text_embeddings: torch.Tensor
     index_map: torch.Tensor  # maps from original prompt token indices to the equivalent tokens in the modified prompt
     mask: torch.Tensor  # in the target space of the index_map
-    cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=[])
+    cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)

     def __int__(self,
                 cac_types_to_do: [CrossAttentionType],
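The default_factory fix matters because dataclasses call the factory to build a fresh default for each instance; passing a list literal instead of the list type fails as soon as an instance is created. A self-contained illustration:

    from dataclasses import dataclass, field

    @dataclass
    class Ok:
        items: list = field(default_factory=list)  # list() is called for each new instance

    @dataclass
    class Broken:
        items: list = field(default_factory=[])    # [] is not callable

    Ok()       # Ok(items=[])
    Broken()   # TypeError: 'list' object is not callable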
@@ -629,9 +633,6 @@ class SwapCrossAttnProcessor(CrossAttnProcessor):

 class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):

-    def __init__(self, slice_size = 1e6):
-        self.slice_count = slice_size
-
     # TODO: dynamically pick slice size based on memory conditions

     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
@@ -660,7 +661,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
         original_text_key = attn.head_to_batch_dim(original_text_key)
         modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
         modified_text_key = attn.to_k(modified_text_embeddings)
-        modified_text_key = attn.head_to_batch_dim(original_text_key)
+        modified_text_key = attn.head_to_batch_dim(modified_text_key)

         # for the "value" just use the modified text embeddings.
         value = attn.to_v(modified_text_embeddings)
@@ -1,11 +1,13 @@
 import math
+from contextlib import contextmanager
 from dataclasses import dataclass
 from math import ceil
-from typing import Callable, Optional, Union, Any
+from typing import Callable, Optional, Union, Any, Dict

 import numpy as np
 import torch

+from diffusers.models.cross_attention import AttnProcessor
 from ldm.models.diffusion.cross_attention_control import Arguments, \
     remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, \
     CrossAttentionType, SwapCrossAttnContext
@@ -55,20 +57,43 @@ class InvokeAIDiffuserComponent:
         self.model_forward_callback = model_forward_callback
         self.cross_attention_control_context = None

-    def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
+    @contextmanager
+    def custom_attention_context(self,
+                                 extra_conditioning_info: Optional[ExtraConditioningInfo],
+                                 step_count: int,
+                                 do_attention_map_saving: bool):
+        do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
+        old_attn_processor = None
+        if do_swap:
+            old_attn_processor = self.setup_cross_attention_control(extra_conditioning_info,
+                                                                     step_count=step_count)
+        try:
+            yield None
+        finally:
+            self.remove_cross_attention_control(old_attn_processor)
+            # TODO resuscitate attention map saving
+            #self.remove_attention_map_saving()
+
+    def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
+        """
+        setup cross attention .swap control. for diffusers this replaces the attention processor, so
+        the previous attention processor is returned so that the caller can restore it later.
+        """
         self.conditioning = conditioning
         self.cross_attention_control_context = Context(
             arguments=self.conditioning.cross_attention_control_args,
             step_count=step_count
         )
-        setup_cross_attention_control(self.model,
-                                      self.cross_attention_control_context,
-                                      is_running_diffusers=self.is_running_diffusers)
+        return setup_cross_attention_control(self.model,
+                                             self.cross_attention_control_context,
+                                             is_running_diffusers=self.is_running_diffusers)

-    def remove_cross_attention_control(self):
+    def remove_cross_attention_control(self, restore_attention_processor: Optional['AttnProcessor']=None):
         self.conditioning = None
         self.cross_attention_control_context = None
-        remove_cross_attention_control(self.model, is_running_diffusers=self.is_running_diffusers)
+        remove_cross_attention_control(self.model,
+                                       is_running_diffusers=self.is_running_diffusers,
+                                       restore_attention_processor=restore_attention_processor)

     def setup_attention_map_saving(self, saver: AttentionMapSaver):
         def callback(slice, dim, offset, slice_size, key):
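Taken together with the pipeline hunk above, the new context manager is used roughly like this (argument values are illustrative; diffuser stands for an InvokeAIDiffuserComponent instance):

    with diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info,
                                           step_count=len(scheduler.timesteps),
                                           do_attention_map_saving=False):
        for i, t in enumerate(timesteps):
            ...  # denoising steps run with the swap attention processors installed
    # on exit, remove_cross_attention_control() restores the previously installed processors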
|
Loading…
x
Reference in New Issue
Block a user