Save and display per-token attention maps (#1866)

* attention maps saving to /tmp

* tidy up diffusers branch backporting of cross attention refactoring

* base64-encoding the attention maps image for generationResult

* cleanup/refactor conditioning.py

* attention maps and tokens being sent to web UI

* attention maps: restrict count to actual token count and improve robustness

* add argument type hint to image_to_dataURL function

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>

Co-authored-by: damian <git@damianstewart.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
Damian Stewart
2022-12-10 15:57:41 +01:00
committed by GitHub
parent 55132f6463
commit 786b8878d6
13 changed files with 636 additions and 346 deletions

View File

@ -4,6 +4,7 @@ import k_diffusion as K
import torch
from torch import nn
from .cross_attention_map_saving import AttentionMapSaver
from .sampler import Sampler
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
@ -36,6 +37,7 @@ class CFGDenoiser(nn.Module):
self.invokeai_diffuser = InvokeAIDiffuserComponent(model,
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
def prepare_to_sample(self, t_enc, **kwargs):
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
@ -106,12 +108,12 @@ class KSampler(Sampler):
else:
print(f'>> Ksampler using karras noise schedule (steps < {self.karras_max})')
self.sigmas = self.karras_sigmas
# ALERT: We are completely overriding the sample() method in the base class, which
# means that inpainting will not work. To get this to work we need to be able to
# modify the inner loop of k_heun, k_lms, etc, as is done in an ugly way
# in the lstein/k-diffusion branch.
@torch.no_grad()
def decode(
self,
@ -145,7 +147,7 @@ class KSampler(Sampler):
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
return x0
# Most of these arguments are ignored and are only present for compatibility with
# other samplers
@torch.no_grad()
@ -158,6 +160,7 @@ class KSampler(Sampler):
callback=None,
normals_sequence=None,
img_callback=None,
attention_maps_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
@ -171,7 +174,7 @@ class KSampler(Sampler):
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
extra_conditioning_info=None,
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
threshold = 0,
perlin = 0,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
@ -204,6 +207,12 @@ class KSampler(Sampler):
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
attention_map_token_ids = range(1, extra_conditioning_info.tokens_count_including_eos_bos - 1)
attention_maps_saver = None if attention_maps_callback is None else AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
if attention_maps_callback is not None:
model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_maps_saver)
extra_args = {
'cond': conditioning,
'uncond': unconditional_conditioning,
@ -217,6 +226,8 @@ class KSampler(Sampler):
),
None,
)
if attention_maps_callback is not None:
attention_maps_callback(attention_maps_saver)
return sampling_result
# this code will support inpainting if and when ksampler API modified or
@ -248,7 +259,7 @@ class KSampler(Sampler):
# terrible, confusing names here
steps = self.ddim_num_steps
t_enc = self.t_enc
# sigmas spans the full number of steps, but t_enc might
# be less. We start in the middle of the sigma array
# and work our way to the end after t_enc steps.
@ -280,7 +291,7 @@ class KSampler(Sampler):
return x_T + x
else:
return x
def prepare_to_sample(self,t_enc,**kwargs):
self.t_enc = t_enc
self.model_wrap = None