InvokeAI/ldm/models/diffusion/cross_attention_map_saving.py
Damian Stewart 786b8878d6
Save and display per-token attention maps (#1866)
* attention maps saving to /tmp

* tidy up diffusers branch backporting of cross attention refactoring

* base64-encoding the attention maps image for generationResult

* cleanup/refactor conditioning.py

* attention maps and tokens being sent to web UI

* attention maps: restrict count to actual token count and improve robustness

* add argument type hint to image_to_dataURL function

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>

Co-authored-by: damian <git@damianstewart.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2022-12-10 15:57:41 +01:00

96 lines
3.9 KiB
Python

import math
import PIL
import torch
from torchvision.transforms.functional import resize as tv_resize, InterpolationMode
from ldm.models.diffusion.cross_attention_control import get_cross_attention_modules, CrossAttentionType
class AttentionMapSaver():
def __init__(self, token_ids: range, latents_shape: torch.Size):
self.token_ids = token_ids
self.latents_shape = latents_shape
#self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
self.collated_maps = {}
def clear_maps(self):
self.collated_maps = {}
def add_attention_maps(self, maps: torch.Tensor, key: str):
"""
Accumulate the given attention maps and store by summing with existing maps at the passed-in key (if any).
:param maps: Attention maps to store. Expected shape [A, (H*W), N] where A is attention heads count, H and W are the map size (fixed per-key) and N is the number of tokens (typically 77).
:param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the maps sizes (H and W) should match.
:return: None
"""
key_and_size = f'{key}_{maps.shape[1]}'
# extract desired tokens
maps = maps[:, :, self.token_ids]
# merge attention heads to a single map per token
maps = torch.sum(maps, 0)
# store
if key_and_size not in self.collated_maps:
self.collated_maps[key_and_size] = torch.zeros_like(maps, device='cpu')
self.collated_maps[key_and_size] += maps.cpu()
def write_maps_to_disk(self, path: str):
pil_image = self.get_stacked_maps_image()
pil_image.save(path, 'PNG')
def get_stacked_maps_image(self) -> PIL.Image:
"""
Scale all collected attention maps to the same size, blend them together and return as an image.
:return: An image containing a vertical stack of blended attention maps, one for each requested token.
"""
num_tokens = len(self.token_ids)
if num_tokens == 0:
return None
latents_height = self.latents_shape[0]
latents_width = self.latents_shape[1]
merged = None
for key, maps in self.collated_maps.items():
# maps has shape [(H*W), N] for N tokens
# but we want [N, H, W]
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
this_maps_height = int(float(latents_height) * this_scale_factor)
this_maps_width = int(float(latents_width) * this_scale_factor)
# and we need to do some dimension juggling
maps = torch.reshape(torch.swapdims(maps, 0, 1), [num_tokens, this_maps_height, this_maps_width])
# scale to output size if necessary
if this_scale_factor != 1:
maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)
# normalize
maps_min = torch.min(maps)
maps_range = torch.max(maps) - maps_min
#print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
maps_normalized = (maps - maps_min) / maps_range
# expand to (-0.1, 1.1) and clamp
maps_normalized_expanded = maps_normalized * 1.1 - 0.05
maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)
# merge together, producing a vertical stack
maps_stacked = torch.reshape(maps_normalized_expanded_clamped, [num_tokens * latents_height, latents_width])
if merged is None:
merged = maps_stacked
else:
# screen blend
merged = 1 - (1 - maps_stacked)*(1 - merged)
if merged is None:
return None
merged_bytes = merged.mul(0xff).byte()
return PIL.Image.fromarray(merged_bytes.numpy(), mode='L')