attention maps callback stuff for diffusers

Damian Stewart 2022-12-14 21:04:55 +01:00
parent 2c6db2e77c
commit 23eb80b404
8 changed files with 63 additions and 151 deletions

View File

@ -63,7 +63,6 @@ class Generator:
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
safety_checker:dict=None,
attention_maps_callback = None,
**kwargs):
scope = choose_autocast(self.precision)
self.safety_checker = safety_checker

View File

@ -7,12 +7,14 @@ from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any,
import PIL.Image
import einops
import numpy as np
import torch
import torchvision.transforms as T
from diffusers.models import attention
from diffusers.utils.import_utils import is_xformers_available
from ...models.diffusion import cross_attention_control
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
# monkeypatch diffusers CrossAttention 🙈
# this is to make prompt2prompt and (future) attention maps work
@ -41,6 +43,7 @@ class PipelineIntermediateState:
timestep: int
latents: torch.Tensor
predicted_original: Optional[torch.Tensor] = None
attention_map_saver: Optional[AttentionMapSaver] = None
# copied from configs/stable-diffusion/v1-inference.yaml
@ -180,6 +183,17 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
raise AssertionError("why was that an empty generator?")
return result
@dataclass
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
r"""
Output class for InvokeAI's Stable Diffusion pipeline.
Args:
attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
after generation completes. Optional.
"""
attention_map_saver: Optional[AttentionMapSaver]
class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
r"""
@ -255,7 +269,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
*, callback: Callable[[PipelineIntermediateState], None]=None,
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
run_id=None,
**extra_step_kwargs) -> StableDiffusionPipelineOutput:
**extra_step_kwargs) -> InvokeAIStableDiffusionPipelineOutput:
r"""
Function invoked when calling the pipeline for generation.
@ -273,7 +287,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
:param run_id:
:param extra_step_kwargs:
"""
result_latents = self.latents_from_embeddings(
result_latents, result_attention_map_saver = self.latents_from_embeddings(
latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
extra_conditioning_info=extra_conditioning_info,
run_id=run_id, callback=callback, **extra_step_kwargs
@ -283,7 +297,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
with torch.inference_mode():
image = self.decode_latents(result_latents)
output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_map_saver)
return self.check_for_safety(output, dtype=text_embeddings.dtype)
def latents_from_embeddings(
@ -302,13 +316,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
timesteps = self.scheduler.timesteps
infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
return infer_latents_from_embeddings(
result: PipelineIntermediateState = infer_latents_from_embeddings(
latents, timesteps, text_embeddings, unconditioned_embeddings, guidance_scale,
extra_conditioning_info=extra_conditioning_info,
additional_guidance=additional_guidance,
run_id=run_id,
callback=callback,
**extra_step_kwargs).latents
**extra_step_kwargs)
return result.latents, result.attention_map_saver
def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps, text_embeddings: torch.Tensor,
unconditioned_embeddings: torch.Tensor, guidance_scale: float, *,
@ -334,6 +349,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
batched_t = torch.full((batch_size,), timesteps[0],
dtype=timesteps.dtype, device=self.unet.device)
attention_map_saver: Optional[AttentionMapSaver] = None
self.invokeai_diffuser.remove_attention_map_saving()
for i, t in enumerate(self.progress_bar(timesteps)):
batched_t.fill_(t)
step_output = self.step(batched_t, latents, guidance_scale,
@ -342,9 +359,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
**extra_step_kwargs)
latents = step_output.prev_sample
predicted_original = getattr(step_output, 'pred_original_sample', None)
if i == len(timesteps)-1 and extra_conditioning_info is not None:
eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
attention_map_token_ids = range(1, eos_token_index)
attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
predicted_original=predicted_original)
return latents
predicted_original=predicted_original, attention_map_saver=attention_map_saver)
self.invokeai_diffuser.remove_attention_map_saving()
return latents, attention_map_saver
@torch.inference_mode()
def step(self, t: torch.Tensor, latents: torch.Tensor, guidance_scale: float,
@ -393,7 +419,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
run_id=None,
noise_func=None,
**extra_step_kwargs) -> StableDiffusionPipelineOutput:
**extra_step_kwargs) -> InvokeAIStableDiffusionPipelineOutput:
if isinstance(init_image, PIL.Image.Image):
init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
@ -412,7 +438,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_steps, text_embeddings,
unconditioned_embeddings, guidance_scale, strength, extra_conditioning_info,
noise_func, run_id=None, callback=None, **extra_step_kwargs):
noise_func, run_id=None, callback=None, **extra_step_kwargs) -> InvokeAIStableDiffusionPipelineOutput:
device = self.unet.device
batch_size = initial_latents.size(0)
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
@ -423,7 +449,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noised_latents = self.scheduler.add_noise(initial_latents, noise, latent_timestep)
latents = noised_latents
result_latents = self.latents_from_embeddings(
result_latents, result_attention_maps = self.latents_from_embeddings(
latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
extra_conditioning_info=extra_conditioning_info,
timesteps=timesteps,
@ -435,7 +461,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
with torch.inference_mode():
image = self.decode_latents(result_latents)
output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps)
return self.check_for_safety(output, dtype=text_embeddings.dtype)
def inpaint_from_embeddings(
@ -450,7 +476,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
run_id=None,
noise_func=None,
**extra_step_kwargs) -> StableDiffusionPipelineOutput:
**extra_step_kwargs) -> InvokeAIStableDiffusionPipelineOutput:
device = self.unet.device
latents_dtype = self.unet.dtype
batch_size = 1
@ -493,7 +519,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
guidance.append(AddsMaskGuidance(mask, init_image_latents, self.scheduler, noise))
try:
result_latents = self.latents_from_embeddings(
result_latents, result_attention_maps = self.latents_from_embeddings(
latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
extra_conditioning_info=extra_conditioning_info,
timesteps=timesteps,
@ -508,7 +534,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
with torch.inference_mode():
image = self.decode_latents(result_latents)
output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps)
return self.check_for_safety(output, dtype=text_embeddings.dtype)
def non_noised_latents_from_image(self, init_image, *, device, dtype):
@ -523,7 +549,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
with torch.inference_mode():
screened_images, has_nsfw_concept = self.run_safety_checker(
output.images, device=self._execution_device, dtype=dtype)
return StableDiffusionPipelineOutput(screened_images, has_nsfw_concept)
screened_attention_map_saver = None
if has_nsfw_concept is None or not has_nsfw_concept:
screened_attention_map_saver = output.attention_map_saver
return InvokeAIStableDiffusionPipelineOutput(screened_images,
has_nsfw_concept,
# block the attention maps if NSFW content is detected
attention_map_saver=screened_attention_map_saver)
@torch.inference_mode()
def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
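
The attention_map_saver field this file adds to InvokeAIStableDiffusionPipelineOutput is meant to be read off the result once generation completes, exactly as the generator changes further down do. A minimal consumer sketch (pipeline_output stands for a result returned by one of the pipeline's generation entry points; show_token_heatmaps is a hypothetical helper, not part of this commit):

# sketch only: consuming the output field added in this commit
maps = pipeline_output.attention_map_saver
if maps is not None:  # None when the NSFW screen blocked the maps
    show_token_heatmaps(maps)  # hypothetical: e.g. render one heatmap per prompt token
images = pipeline.numpy_to_pil(pipeline_output.images)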

View File

@ -14,7 +14,9 @@ class Img2Img(Generator):
self.init_latent = None # by get_noise()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,
attention_maps_callback=None,
**kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it.
@ -35,7 +37,8 @@ class Img2Img(Generator):
noise_func=self.get_noise_like,
callback=step_callback
)
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
attention_maps_callback(pipeline_output.attention_map_saver)
return pipeline.numpy_to_pil(pipeline_output.images)[0]
return make_image
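
Together with the matching Txt2Img and Inpaint changes below, this establishes the calling convention: pass attention_maps_callback into get_make_image and it is invoked with the AttentionMapSaver after sampling finishes. A usage sketch, assuming img2img is an Img2Img generator instance and x_T is the initial noise tensor; the callback body and the collected dict are illustrative only:

collected = {}

def on_attention_maps(attention_map_saver):
    # illustrative callback: stash the saver so a UI can render the maps later
    collected['attention_maps'] = attention_map_saver

make_image = img2img.get_make_image(
    prompt, sampler, steps, cfg_scale, ddim_eta,
    conditioning, init_image, strength,
    step_callback=step_callback,
    attention_maps_callback=on_attention_maps,  # keyword added in this commit
)
image = make_image(x_T)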

View File

@ -273,6 +273,8 @@ class Inpaint(Img2Img):
callback=step_callback,
)
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
attention_maps_callback(pipeline_output.attention_map_saver)
result = pipeline.numpy_to_pil(pipeline_output.images)[0]
# Seam paint if this is our first pass (seam_size set to 0 during seam painting)

View File

@ -37,14 +37,12 @@ class Txt2Img(Generator):
unconditioned_embeddings=uc,
guidance_scale=cfg_scale,
callback=step_callback,
extra_conditioning_info=extra_conditioning_info,
extra_conditioning_info=extra_conditioning_info
# TODO: eta = ddim_eta,
# TODO: threshold = threshold,
# FIXME: Attention Maps Callback merged from main, but not hooked up
# in diffusers branch yet. - keturn
# attention_maps_callback = attention_maps_callback,
)
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
attention_maps_callback(pipeline_output.attention_map_saver)
return pipeline.numpy_to_pil(pipeline_output.images)[0]
return make_image

View File

@ -36,7 +36,7 @@ class Txt2Img2Img(Generator):
def make_image(x_T):
first_pass_latent_output = pipeline.latents_from_embeddings(
first_pass_latent_output, _ = pipeline.latents_from_embeddings(
latents=x_T,
num_inference_steps=steps,
text_embeddings=c,

View File

@ -442,128 +442,6 @@ def get_mem_free_total(device):
return mem_free_total
class InvokeAICrossAttentionMixin:
"""
Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
and dynamic slicing strategy selection.
"""
def __init__(self):
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
self.attention_slice_wrangler = None
self.slicing_strategy_getter = None
def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
'''
Set custom attention calculator to be called when attention is calculated
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
which returns either the suggested_attention_slice or an adjusted equivalent.
`module` is the current CrossAttention module for which the callback is being invoked.
`suggested_attention_slice` is the default-calculated attention slice
`dim` is None if the attention map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
Pass None to use the default attention calculation.
:return:
'''
self.attention_slice_wrangler = wrangler
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
self.slicing_strategy_getter = getter
def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
# calculate attention scores
#attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
if dim is not None:
print(f"sliced dim {dim}, offset {offset}, slice_size {slice_size}")
attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
# calculate attention slice by taking the best scores for each latent pixel
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
attention_slice_wrangler = self.attention_slice_wrangler
if attention_slice_wrangler is not None:
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
else:
attention_slice = default_attention_slice
hidden_states = torch.bmm(attention_slice, value)
return hidden_states
def einsum_op_slice_dim0(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size):
end = i + slice_size
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
return r
def einsum_op_slice_dim1(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
return r
def einsum_op_mps_v1(self, q, k, v):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
return self.einsum_op_slice_dim1(q, k, v, slice_size)
def einsum_op_mps_v2(self, q, k, v):
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
return self.einsum_op_slice_dim0(q, k, v, 1)
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
if size_mb <= max_tensor_mb:
return self.einsum_lowest_level(q, k, v, None, None, None)
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
if div <= q.shape[0]:
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
def einsum_op_cuda(self, q, k, v):
# check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
slicing_strategy_getter = self.slicing_strategy_getter
if slicing_strategy_getter is not None:
(dim, slice_size) = slicing_strategy_getter(self)
if dim is not None:
# print("using saved slicing strategy with dim", dim, "slice size", slice_size)
if dim == 0:
return self.einsum_op_slice_dim0(q, k, v, slice_size)
elif dim == 1:
return self.einsum_op_slice_dim1(q, k, v, slice_size)
# fallback for when there is no saved strategy, or saved strategy does not slice
mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device)
# Divide by a factor of safety as there's copying and fragmentation
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
def get_invokeai_attention_mem_efficient(self, q, k, v):
if q.device.type == 'cuda':
#print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
return self.einsum_op_cuda(q, k, v)
if q.device.type == 'mps' or q.device.type == 'cpu':
if self.mem_total_gb >= 32:
return self.einsum_op_mps_v1(q, k, v)
return self.einsum_op_mps_v2(q, k, v)
# Smaller slices are faster due to L2/L3/SLC caches.
# Tested on i7 with 8MB L3 cache.
return self.einsum_op_tensor_mem(q, k, v, 32)
class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin):
def __init__(self, **kwargs):
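
For context on the hook described in the docstring above (the mixin body is being removed from this file by this commit; the sketch below only illustrates the documented signature): an attention_slice_wrangler receives (module, suggested_attention_slice, dim, offset, slice_size) and must return the suggested slice or an adjusted equivalent. A minimal pass-through sketch; the names are illustrative, and some_cross_attention stands for any InvokeAICrossAttentionMixin-derived module:

import torch
import torch.nn as nn

def passthrough_wrangler(module: nn.Module, suggested_attention_slice: torch.Tensor,
                         dim, offset, slice_size) -> torch.Tensor:
    # inspect (or modify) the softmaxed attention slice, then return it
    if dim is not None and dim >= 0:
        pass  # this call covers [offset, offset + slice_size) along dimension `dim`
    return suggested_attention_slice

# some_cross_attention.set_attention_slice_wrangler(passthrough_wrangler)
# some_cross_attention.set_attention_slice_wrangler(None)  # restore default calculation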

View File

@ -209,12 +209,12 @@ class KSampler(Sampler):
model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
# set up attention map saving. checks for None are because there are multiple code paths to get here.
attention_maps_saver = None
attention_map_saver = None
if attention_maps_callback is not None and extra_conditioning_info is not None:
eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
attention_map_token_ids = range(1, eos_token_index)
attention_maps_saver = AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_maps_saver)
attention_map_saver = AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
extra_args = {
'cond': conditioning,
@ -229,8 +229,8 @@ class KSampler(Sampler):
),
None,
)
if attention_maps_saver is not None:
attention_maps_callback(attention_maps_saver)
if attention_map_saver is not None:
attention_maps_callback(attention_map_saver)
return sampling_result
# this code will support inpainting if and when ksampler API modified or