mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
786b8878d6
* attention maps saving to /tmp
* tidy up diffusers branch backporting of cross attention refactoring
* base64-encoding the attention maps image for generationResult
* cleanup/refactor conditioning.py
* attention maps and tokens being sent to web UI
* attention maps: restrict count to actual token count and improve robustness
* add argument type hint to image_to_dataURL function

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Co-authored-by: damian <git@damianstewart.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
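The commit message mentions base64-encoding the attention maps image for generationResult via an image_to_dataURL helper. As a rough illustration of that step, here is a minimal sketch of encoding a PIL image into a data URL; the function name is borrowed from the commit message, but the signature and body below are assumptions, not the repository's actual implementation.

# Illustrative sketch only: not the repository's image_to_dataURL implementation.
import base64
import io

from PIL import Image

def image_to_dataURL(image: Image.Image, image_format: str = 'PNG') -> str:
    """Encode a PIL image as a base64 data URL string."""
    buffer = io.BytesIO()
    image.save(buffer, format=image_format)
    encoded = base64.b64encode(buffer.getvalue()).decode('ascii')
    return f'data:image/{image_format.lower()};base64,{encoded}'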
96 lines
3.9 KiB
Python
import math

import PIL
import torch
from torchvision.transforms.functional import resize as tv_resize, InterpolationMode

from ldm.models.diffusion.cross_attention_control import get_cross_attention_modules, CrossAttentionType


class AttentionMapSaver():

    def __init__(self, token_ids: range, latents_shape: torch.Size):
        self.token_ids = token_ids
        self.latents_shape = latents_shape
        #self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
        self.collated_maps = {}

    def clear_maps(self):
        self.collated_maps = {}

    def add_attention_maps(self, maps: torch.Tensor, key: str):
        """
        Accumulate the given attention maps and store by summing with existing maps at the passed-in key (if any).
        :param maps: Attention maps to store. Expected shape [A, (H*W), N] where A is attention heads count, H and W are the map size (fixed per-key) and N is the number of tokens (typically 77).
        :param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the maps sizes (H and W) should match.
        :return: None
        """
        key_and_size = f'{key}_{maps.shape[1]}'

        # extract desired tokens
        maps = maps[:, :, self.token_ids]

        # merge attention heads to a single map per token
        maps = torch.sum(maps, 0)

        # store
        if key_and_size not in self.collated_maps:
            self.collated_maps[key_and_size] = torch.zeros_like(maps, device='cpu')
        self.collated_maps[key_and_size] += maps.cpu()

    def write_maps_to_disk(self, path: str):
        pil_image = self.get_stacked_maps_image()
        pil_image.save(path, 'PNG')

    def get_stacked_maps_image(self) -> PIL.Image:
        """
        Scale all collected attention maps to the same size, blend them together and return as an image.
        :return: An image containing a vertical stack of blended attention maps, one for each requested token.
        """
        num_tokens = len(self.token_ids)
        if num_tokens == 0:
            return None

        latents_height = self.latents_shape[0]
        latents_width = self.latents_shape[1]

        merged = None

        for key, maps in self.collated_maps.items():

            # maps has shape [(H*W), N] for N tokens
            # but we want [N, H, W]
            this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
            this_maps_height = int(float(latents_height) * this_scale_factor)
            this_maps_width = int(float(latents_width) * this_scale_factor)
            # and we need to do some dimension juggling
            maps = torch.reshape(torch.swapdims(maps, 0, 1), [num_tokens, this_maps_height, this_maps_width])

            # scale to output size if necessary
            if this_scale_factor != 1:
                maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)

            # normalize
            maps_min = torch.min(maps)
            maps_range = torch.max(maps) - maps_min
            #print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
            maps_normalized = (maps - maps_min) / maps_range
            # expand to (-0.1, 1.1) and clamp
            maps_normalized_expanded = maps_normalized * 1.1 - 0.05
            maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)

            # merge together, producing a vertical stack
            maps_stacked = torch.reshape(maps_normalized_expanded_clamped, [num_tokens * latents_height, latents_width])

            if merged is None:
                merged = maps_stacked
            else:
                # screen blend
                merged = 1 - (1 - maps_stacked)*(1 - merged)

        if merged is None:
            return None

        merged_bytes = merged.mul(0xff).byte()
        return PIL.Image.fromarray(merged_bytes.numpy(), mode='L')
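As a quick, hypothetical usage sketch of the class above: the token range, latent shape, head count, attention key and output path below are made-up illustrative values, and the tensor is random noise rather than real cross-attention data.

# Hypothetical usage sketch; all concrete values here are illustrative assumptions.
if __name__ == '__main__':
    saver = AttentionMapSaver(token_ids=range(0, 4), latents_shape=torch.Size([64, 64]))

    # pretend cross-attention data: 8 heads, a 64x64 map flattened to 4096 positions, 77 tokens
    fake_maps = torch.rand([8, 64 * 64, 77])
    saver.add_attention_maps(fake_maps, key='up_blocks.1')

    # stacks the per-token maps vertically, screen-blends across keys and saves a greyscale PNG
    saver.write_maps_to_disk('/tmp/attention_maps.png')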