mirror of https://github.com/invoke-ai/InvokeAI
more wip sliced attention (.swap doesn't work yet)
This commit is contained in:
parent 63c6019f92
commit a4aea1540b
@@ -306,6 +306,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

         if is_xformers_available() and not Globals.disable_xformers:
             self.enable_xformers_memory_efficient_attention()
+        else:
+            slice_size = 2
+            self.enable_attention_slicing(slice_size=slice_size)

     def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                              conditioning_data: ConditioningData,
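For context, the fallback added here follows the standard diffusers memory controls: xformers memory-efficient attention is preferred when installed, and attention slicing otherwise bounds peak memory. A minimal standalone sketch (the model id and the fixed slice size are illustrative, not taken from this commit):

# Sketch: prefer xformers memory-efficient attention, else fall back to slicing.
import torch
from diffusers import StableDiffusionPipeline
from diffusers.utils import is_xformers_available

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",  # illustrative model id
                                               torch_dtype=torch.float16)
if is_xformers_available():
    pipe.enable_xformers_memory_efficient_attention()
else:
    # compute attention in smaller slices: slower, but lower peak memory
    pipe.enable_attention_slicing(slice_size=2)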
@@ -370,43 +373,40 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if additional_guidance is None:
             additional_guidance = []
         extra_conditioning_info = conditioning_data.extra
-        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info,
-                                                                 step_count=len(self.scheduler.timesteps))
-        else:
-            self.invokeai_diffuser.remove_cross_attention_control()
+        with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info,
+                                                             step_count=len(self.scheduler.timesteps),
+                                                             do_attention_map_saving=False):

-        yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
-                                        latents=latents)
+            yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
+                                            latents=latents)

-        batch_size = latents.shape[0]
-        batched_t = torch.full((batch_size,), timesteps[0],
-                               dtype=timesteps.dtype, device=self.unet.device)
-        latents = self.scheduler.add_noise(latents, noise, batched_t)
+            batch_size = latents.shape[0]
+            batched_t = torch.full((batch_size,), timesteps[0],
+                                   dtype=timesteps.dtype, device=self.unet.device)
+            latents = self.scheduler.add_noise(latents, noise, batched_t)

-        attention_map_saver: Optional[AttentionMapSaver] = None
-        self.invokeai_diffuser.remove_attention_map_saving()
+            attention_map_saver: Optional[AttentionMapSaver] = None

-        for i, t in enumerate(self.progress_bar(timesteps)):
-            batched_t.fill_(t)
-            step_output = self.step(batched_t, latents, conditioning_data,
-                                    step_index=i,
-                                    total_step_count=len(timesteps),
-                                    additional_guidance=additional_guidance)
-            latents = step_output.prev_sample
-            predicted_original = getattr(step_output, 'pred_original_sample', None)
+            for i, t in enumerate(self.progress_bar(timesteps)):
+                batched_t.fill_(t)
+                step_output = self.step(batched_t, latents, conditioning_data,
+                                        step_index=i,
+                                        total_step_count=len(timesteps),
+                                        additional_guidance=additional_guidance)
+                latents = step_output.prev_sample
+                predicted_original = getattr(step_output, 'pred_original_sample', None)

-            if i == len(timesteps)-1 and extra_conditioning_info is not None:
-                eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
-                attention_map_token_ids = range(1, eos_token_index)
-                attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
-                self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
+                # TODO resuscitate attention map saving
+                #if i == len(timesteps)-1 and extra_conditioning_info is not None:
+                #    eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
+                #    attention_map_token_ids = range(1, eos_token_index)
+                #    attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
+                #    self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)

-            yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
-                                            predicted_original=predicted_original, attention_map_saver=attention_map_saver)
+                yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
+                                                predicted_original=predicted_original, attention_map_saver=attention_map_saver)

-        self.invokeai_diffuser.remove_attention_map_saving()
-        return latents, attention_map_saver
+            return latents, attention_map_saver

     @torch.inference_mode()
     def step(self, t: torch.Tensor, latents: torch.Tensor,
@@ -7,6 +7,7 @@ import torch
 import diffusers
 from torch import nn
 from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from diffusers.models.cross_attention import AttnProcessor
 from ldm.invoke.devices import torch_dtype
@@ -305,11 +306,10 @@ class InvokeAICrossAttentionMixin:



-def remove_cross_attention_control(model, is_running_diffusers: bool):
+def remove_cross_attention_control(model, is_running_diffusers: bool, restore_attention_processor: Optional[AttnProcessor]=None):
     if is_running_diffusers:
         unet = model
-        print("** need to know what cross attn processor to use by default, None in the following line is wrong")
-        unet.set_attn_processor(CrossAttnProcessor())
+        unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor())
     else:
         remove_attention_function(model)
@@ -343,10 +343,16 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers
     context.cross_attention_index_map = indices.to(device)
     if is_running_diffusers:
         unet = model
-        unet.set_attn_processor(SwapCrossAttnProcessor())
+        old_attn_processors = unet.attn_processors
+        # try to re-use an existing slice size
+        default_slice_size = 4
+        slice_size = next((p for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
+        unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
+        return old_attn_processors
     else:
         context.register_cross_attention_modules(model)
         inject_attention_function(model, context)
+        return None
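The hunk above snapshots unet.attn_processors before installing the swap processor so the caller can put the originals back later. A minimal sketch of that save/swap/restore pattern, assuming diffusers' SlicedAttnProcessor keeps its chunk size on a slice_size attribute (the helper name and the default of 4 are illustrative):

from diffusers.models.cross_attention import SlicedAttnProcessor

def swap_in_processor(unet, make_processor, default_slice_size=4):
    # snapshot the current processors (a dict of module name -> processor)
    old_attn_processors = unet.attn_processors
    # reuse the slice size of any existing sliced processor, else the default
    slice_size = next((p.slice_size for p in old_attn_processors.values()
                       if isinstance(p, SlicedAttnProcessor)), default_slice_size)
    unet.set_attn_processor(make_processor(slice_size))
    return old_attn_processors

# restoring later is a single call:
#   unet.set_attn_processor(old_attn_processors)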
@@ -509,13 +515,11 @@ class CrossAttnProcessor:
         return hidden_states

 """
 import enum
 from dataclasses import field, dataclass

 import torch

-from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor
 from ldm.models.diffusion.cross_attention_control import CrossAttentionType
+from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor, AttnProcessor


 @dataclass
@@ -523,7 +527,7 @@ class SwapCrossAttnContext:
     modified_text_embeddings: torch.Tensor
     index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt
     mask: torch.Tensor # in the target space of the index_map
-    cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=[])
+    cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)

     def __int__(self,
                 cac_types_to_do: [CrossAttentionType],
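The fix above matters because default_factory must be a zero-argument callable; passing an already-built [] only fails later, when an instance is constructed without an explicit value for the field. A minimal illustration:

from dataclasses import dataclass, field

@dataclass
class Example:
    # `list` is called once per instance, giving each its own fresh list.
    # With field(default_factory=[]) this would instead raise
    # "TypeError: 'list' object is not callable" as soon as Example() is built.
    items: list = field(default_factory=list)

print(Example())  # Example(items=[])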
@@ -629,9 +633,6 @@ class SwapCrossAttnProcessor(CrossAttnProcessor):

 class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):

-    def __init__(self, slice_size = 1e6):
-        self.slice_count = slice_size
-
     # TODO: dynamically pick slice size based on memory conditions

     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
@@ -660,7 +661,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
         original_text_key = attn.head_to_batch_dim(original_text_key)
         modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
         modified_text_key = attn.to_k(modified_text_embeddings)
-        modified_text_key = attn.head_to_batch_dim(original_text_key)
+        modified_text_key = attn.head_to_batch_dim(modified_text_key)

         # for the "value" just use the modified text embeddings.
         value = attn.to_v(modified_text_embeddings)
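The corrected line reshapes the freshly projected modified key instead of reusing the original one. Sketched standalone with illustrative tensor names (head_to_batch_dim folds the head dimension into the batch dimension so attention can be computed per head, and optionally in slices):

# illustrative sketch of the intended key/value preparation inside the processor
original_text_key = attn.to_k(original_text_embeddings)
original_text_key = attn.head_to_batch_dim(original_text_key)

modified_text_key = attn.to_k(modified_text_embeddings)
modified_text_key = attn.head_to_batch_dim(modified_text_key)  # not original_text_key

# the "value" comes from the modified prompt only
value = attn.to_v(modified_text_embeddings)
value = attn.head_to_batch_dim(value)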
@@ -1,11 +1,13 @@
 import math
+from contextlib import contextmanager
 from dataclasses import dataclass
 from math import ceil
-from typing import Callable, Optional, Union, Any
+from typing import Callable, Optional, Union, Any, Dict

 import numpy as np
 import torch

+from diffusers.models.cross_attention import AttnProcessor
 from ldm.models.diffusion.cross_attention_control import Arguments, \
     remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, \
     CrossAttentionType, SwapCrossAttnContext
@@ -55,20 +57,43 @@ class InvokeAIDiffuserComponent:
         self.model_forward_callback = model_forward_callback
         self.cross_attention_control_context = None

-    def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
+    @contextmanager
+    def custom_attention_context(self,
+                                 extra_conditioning_info: Optional[ExtraConditioningInfo],
+                                 step_count: int,
+                                 do_attention_map_saving: bool):
+        do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
+        old_attn_processor = None
+        if do_swap:
+            old_attn_processor = self.setup_cross_attention_control(extra_conditioning_info,
+                                                                    step_count=step_count)
+        try:
+            yield None
+        finally:
+            self.remove_cross_attention_control(old_attn_processor)
+            # TODO resuscitate attention map saving
+            #self.remove_attention_map_saving()
+
+    def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
+        """
+        setup cross attention .swap control. for diffusers this replaces the attention processor, so
+        the previous attention processor is returned so that the caller can restore it later.
+        """
         self.conditioning = conditioning
         self.cross_attention_control_context = Context(
             arguments=self.conditioning.cross_attention_control_args,
             step_count=step_count
         )
-        setup_cross_attention_control(self.model,
-                                      self.cross_attention_control_context,
-                                      is_running_diffusers=self.is_running_diffusers)
+        return setup_cross_attention_control(self.model,
+                                             self.cross_attention_control_context,
+                                             is_running_diffusers=self.is_running_diffusers)

-    def remove_cross_attention_control(self):
+    def remove_cross_attention_control(self, restore_attention_processor: Optional['AttnProcessor']=None):
         self.conditioning = None
         self.cross_attention_control_context = None
-        remove_cross_attention_control(self.model, is_running_diffusers=self.is_running_diffusers)
+        remove_cross_attention_control(self.model,
+                                       is_running_diffusers=self.is_running_diffusers,
+                                       restore_attention_processor=restore_attention_processor)

     def setup_attention_map_saving(self, saver: AttentionMapSaver):
         def callback(slice, dim, offset, slice_size, key):
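The custom_attention_context introduced above follows the usual contextlib shape: perform the processor swap before yield, and put the restore in finally so the UNet's original attention processors come back even if the denoising loop raises. A generic sketch of that shape (the helper and variable names are illustrative):

from contextlib import contextmanager

@contextmanager
def swapped_attn_processors(unet, replacement_processor):
    old_processors = unet.attn_processors           # snapshot current processors
    unet.set_attn_processor(replacement_processor)  # install the replacement
    try:
        yield
    finally:
        # runs on normal exit and when the body raises
        unet.set_attn_processor(old_processors)

# usage sketch:
#   with swapped_attn_processors(unet, SlicedSwapCrossAttnProcesser(slice_size=2)):
#       ...run the denoising loop...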