From 23eb80b40421b2bb8f4b6d3dd30490d11c447b36 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Wed, 14 Dec 2022 21:04:55 +0100 Subject: [PATCH] attention maps callback stuff for diffusers --- ldm/invoke/generator/base.py | 1 - ldm/invoke/generator/diffusers_pipeline.py | 62 ++++++--- ldm/invoke/generator/img2img.py | 7 +- ldm/invoke/generator/inpaint.py | 2 + ldm/invoke/generator/txt2img.py | 8 +- ldm/invoke/generator/txt2img2img.py | 2 +- .../diffusion/cross_attention_control.py | 122 ------------------ ldm/models/diffusion/ksampler.py | 10 +- 8 files changed, 63 insertions(+), 151 deletions(-) diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py index f207b3dc24..68c5ccdeff 100644 --- a/ldm/invoke/generator/base.py +++ b/ldm/invoke/generator/base.py @@ -63,7 +63,6 @@ class Generator: def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None, image_callback=None, step_callback=None, threshold=0.0, perlin=0.0, safety_checker:dict=None, - attention_maps_callback = None, **kwargs): scope = choose_autocast(self.precision) self.safety_checker = safety_checker diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 85689f0927..410761cd79 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -7,12 +7,14 @@ from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any, import PIL.Image import einops +import numpy as np import torch import torchvision.transforms as T from diffusers.models import attention from diffusers.utils.import_utils import is_xformers_available from ...models.diffusion import cross_attention_control +from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver # monkeypatch diffusers CrossAttention 🙈 # this is to make prompt2prompt and (future) attention maps work @@ -41,6 +43,7 @@ class PipelineIntermediateState: timestep: int latents: torch.Tensor predicted_original: Optional[torch.Tensor] = None + attention_map_saver: Optional[AttentionMapSaver] = None # copied from configs/stable-diffusion/v1-inference.yaml @@ -180,6 +183,17 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]): raise AssertionError("why was that an empty generator?") return result +@dataclass +class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput): + r""" + Output class for InvokeAI's Stable Diffusion pipeline. + + Args: + attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user + after generation completes. Optional. + """ + attention_map_saver: Optional[AttentionMapSaver] + class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): r""" @@ -255,7 +269,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): *, callback: Callable[[PipelineIntermediateState], None]=None, extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None, run_id=None, - **extra_step_kwargs) -> StableDiffusionPipelineOutput: + **extra_step_kwargs) -> InvokeAIStableDiffusionPipelineOutput: r""" Function invoked when calling the pipeline for generation. @@ -273,7 +287,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): :param run_id: :param extra_step_kwargs: """ - result_latents = self.latents_from_embeddings( + result_latents, result_attention_map_saver = self.latents_from_embeddings( latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale, extra_conditioning_info=extra_conditioning_info, run_id=run_id, callback=callback, **extra_step_kwargs @@ -283,7 +297,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): with torch.inference_mode(): image = self.decode_latents(result_latents) - output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[]) + output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_map_saver) return self.check_for_safety(output, dtype=text_embeddings.dtype) def latents_from_embeddings( @@ -302,13 +316,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device) timesteps = self.scheduler.timesteps infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState) - return infer_latents_from_embeddings( + result: PipelineIntermediateState = infer_latents_from_embeddings( latents, timesteps, text_embeddings, unconditioned_embeddings, guidance_scale, extra_conditioning_info=extra_conditioning_info, additional_guidance=additional_guidance, run_id=run_id, callback=callback, - **extra_step_kwargs).latents + **extra_step_kwargs) + return result.latents, result.attention_map_saver def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps, text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor, guidance_scale: float, *, @@ -334,6 +349,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): batched_t = torch.full((batch_size,), timesteps[0], dtype=timesteps.dtype, device=self.unet.device) + attention_map_saver: AttentionMapSaver = None + self.invokeai_diffuser.remove_attention_map_saving() for i, t in enumerate(self.progress_bar(timesteps)): batched_t.fill_(t) step_output = self.step(batched_t, latents, guidance_scale, @@ -342,9 +359,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): **extra_step_kwargs) latents = step_output.prev_sample predicted_original = getattr(step_output, 'pred_original_sample', None) + + if i == len(timesteps)-1 and extra_conditioning_info is not None: + eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1 + attention_map_token_ids = range(1, eos_token_index) + attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:]) + self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver) + yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents, - predicted_original=predicted_original) - return latents + predicted_original=predicted_original, attention_map_saver=attention_map_saver) + + self.invokeai_diffuser.remove_attention_map_saving() + return latents, attention_map_saver @torch.inference_mode() def step(self, t: torch.Tensor, latents: torch.Tensor, guidance_scale: float, @@ -393,7 +419,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None, run_id=None, noise_func=None, - **extra_step_kwargs) -> StableDiffusionPipelineOutput: + **extra_step_kwargs) -> InvokeAIStableDiffusionPipelineOutput: if isinstance(init_image, PIL.Image.Image): init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB')) @@ -412,7 +438,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale, strength, extra_conditioning_info, - noise_func, run_id=None, callback=None, **extra_step_kwargs): + noise_func, run_id=None, callback=None, **extra_step_kwargs) -> InvokeAIStableDiffusionPipelineOutput: device = self.unet.device batch_size = initial_latents.size(0) img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components) @@ -423,7 +449,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): noised_latents = self.scheduler.add_noise(initial_latents, noise, latent_timestep) latents = noised_latents - result_latents = self.latents_from_embeddings( + result_latents, result_attention_maps = self.latents_from_embeddings( latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale, extra_conditioning_info=extra_conditioning_info, timesteps=timesteps, @@ -435,7 +461,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): with torch.inference_mode(): image = self.decode_latents(result_latents) - output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[]) + output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps) return self.check_for_safety(output, dtype=text_embeddings.dtype) def inpaint_from_embeddings( @@ -450,7 +476,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None, run_id=None, noise_func=None, - **extra_step_kwargs) -> StableDiffusionPipelineOutput: + **extra_step_kwargs) -> InvokeAIStableDiffusionPipelineOutput: device = self.unet.device latents_dtype = self.unet.dtype batch_size = 1 @@ -493,7 +519,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): guidance.append(AddsMaskGuidance(mask, init_image_latents, self.scheduler, noise)) try: - result_latents = self.latents_from_embeddings( + result_latents, result_attention_maps = self.latents_from_embeddings( latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale, extra_conditioning_info=extra_conditioning_info, timesteps=timesteps, @@ -508,7 +534,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): with torch.inference_mode(): image = self.decode_latents(result_latents) - output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[]) + output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps) return self.check_for_safety(output, dtype=text_embeddings.dtype) def non_noised_latents_from_image(self, init_image, *, device, dtype): @@ -523,7 +549,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): with torch.inference_mode(): screened_images, has_nsfw_concept = self.run_safety_checker( output.images, device=self._execution_device, dtype=dtype) - return StableDiffusionPipelineOutput(screened_images, has_nsfw_concept) + screened_attention_map_saver = None + if has_nsfw_concept is None or not has_nsfw_concept: + screened_attention_map_saver = output.attention_map_saver + return InvokeAIStableDiffusionPipelineOutput(screened_images, + has_nsfw_concept, + # block the attention maps if NSFW content is detected + attention_map_saver=screened_attention_map_saver) @torch.inference_mode() def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None): diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py index 6ea41fda33..1a470d1ebf 100644 --- a/ldm/invoke/generator/img2img.py +++ b/ldm/invoke/generator/img2img.py @@ -14,7 +14,9 @@ class Img2Img(Generator): self.init_latent = None # by get_noise() def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, - conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs): + conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0, + attention_maps_callback=None, + **kwargs): """ Returns a function returning an image derived from the prompt and the initial image Return value depends on the seed at the time you call it. @@ -35,7 +37,8 @@ class Img2Img(Generator): noise_func=self.get_noise_like, callback=step_callback ) - + if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None: + attention_maps_callback(pipeline_output.attention_map_saver) return pipeline.numpy_to_pil(pipeline_output.images)[0] return make_image diff --git a/ldm/invoke/generator/inpaint.py b/ldm/invoke/generator/inpaint.py index 976121d720..79fbd542c1 100644 --- a/ldm/invoke/generator/inpaint.py +++ b/ldm/invoke/generator/inpaint.py @@ -273,6 +273,8 @@ class Inpaint(Img2Img): callback=step_callback, ) + if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None: + attention_maps_callback(pipeline_output.attention_map_saver) result = pipeline.numpy_to_pil(pipeline_output.images)[0] # Seam paint if this is our first pass (seam_size set to 0 during seam painting) diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py index ef3d35dbca..6c4c7a3f13 100644 --- a/ldm/invoke/generator/txt2img.py +++ b/ldm/invoke/generator/txt2img.py @@ -37,14 +37,12 @@ class Txt2Img(Generator): unconditioned_embeddings=uc, guidance_scale=cfg_scale, callback=step_callback, - extra_conditioning_info=extra_conditioning_info, + extra_conditioning_info=extra_conditioning_info # TODO: eta = ddim_eta, # TODO: threshold = threshold, - # FIXME: Attention Maps Callback merged from main, but not hooked up - # in diffusers branch yet. - keturn - # attention_maps_callback = attention_maps_callback, ) - + if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None: + attention_maps_callback(pipeline_output.attention_map_saver) return pipeline.numpy_to_pil(pipeline_output.images)[0] return make_image diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py index 29a7106246..56ebcc5bf4 100644 --- a/ldm/invoke/generator/txt2img2img.py +++ b/ldm/invoke/generator/txt2img2img.py @@ -36,7 +36,7 @@ class Txt2Img2Img(Generator): def make_image(x_T): - first_pass_latent_output = pipeline.latents_from_embeddings( + first_pass_latent_output, _ = pipeline.latents_from_embeddings( latents=x_T, num_inference_steps=steps, text_embeddings=c, diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index b32ccecae7..66c5567ebd 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -442,128 +442,6 @@ def get_mem_free_total(device): return mem_free_total -class InvokeAICrossAttentionMixin: - """ - Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls - through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling - and dymamic slicing strategy selection. - """ - def __init__(self): - self.mem_total_gb = psutil.virtual_memory().total // (1 << 30) - self.attention_slice_wrangler = None - self.slicing_strategy_getter = None - - def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]): - ''' - Set custom attention calculator to be called when attention is calculated - :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size), - which returns either the suggested_attention_slice or an adjusted equivalent. - `module` is the current CrossAttention module for which the callback is being invoked. - `suggested_attention_slice` is the default-calculated attention slice - `dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing. - If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length. - - Pass None to use the default attention calculation. - :return: - ''' - self.attention_slice_wrangler = wrangler - - def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]): - self.slicing_strategy_getter = getter - - def einsum_lowest_level(self, query, key, value, dim, offset, slice_size): - # calculate attention scores - #attention_scores = torch.einsum('b i d, b j d -> b i j', q, k) - if dim is not None: - print(f"sliced dim {dim}, offset {offset}, slice_size {slice_size}") - attention_scores = torch.baddbmm( - torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - - # calculate attention slice by taking the best scores for each latent pixel - default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype) - attention_slice_wrangler = self.attention_slice_wrangler - if attention_slice_wrangler is not None: - attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size) - else: - attention_slice = default_attention_slice - - hidden_states = torch.bmm(attention_slice, value) - return hidden_states - - def einsum_op_slice_dim0(self, q, k, v, slice_size): - r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - for i in range(0, q.shape[0], slice_size): - end = i + slice_size - r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size) - return r - - def einsum_op_slice_dim1(self, q, k, v, slice_size): - r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size) - return r - - def einsum_op_mps_v1(self, q, k, v): - if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 - return self.einsum_lowest_level(q, k, v, None, None, None) - else: - slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) - return self.einsum_op_slice_dim1(q, k, v, slice_size) - - def einsum_op_mps_v2(self, q, k, v): - if self.mem_total_gb > 8 and q.shape[1] <= 4096: - return self.einsum_lowest_level(q, k, v, None, None, None) - else: - return self.einsum_op_slice_dim0(q, k, v, 1) - - def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb): - size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) - if size_mb <= max_tensor_mb: - return self.einsum_lowest_level(q, k, v, None, None, None) - div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() - if div <= q.shape[0]: - return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div) - return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1)) - - def einsum_op_cuda(self, q, k, v): - # check if we already have a slicing strategy (this should only happen during cross-attention controlled generation) - slicing_strategy_getter = self.slicing_strategy_getter - if slicing_strategy_getter is not None: - (dim, slice_size) = slicing_strategy_getter(self) - if dim is not None: - # print("using saved slicing strategy with dim", dim, "slice size", slice_size) - if dim == 0: - return self.einsum_op_slice_dim0(q, k, v, slice_size) - elif dim == 1: - return self.einsum_op_slice_dim1(q, k, v, slice_size) - - # fallback for when there is no saved strategy, or saved strategy does not slice - mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device) - # Divide factor of safety as there's copying and fragmentation - return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) - - - def get_invokeai_attention_mem_efficient(self, q, k, v): - if q.device.type == 'cuda': - #print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device)) - return self.einsum_op_cuda(q, k, v) - - if q.device.type == 'mps' or q.device.type == 'cpu': - if self.mem_total_gb >= 32: - return self.einsum_op_mps_v1(q, k, v) - return self.einsum_op_mps_v2(q, k, v) - - # Smaller slices are faster due to L2/L3/SLC caches. - # Tested on i7 with 8MB L3 cache. - return self.einsum_op_tensor_mem(q, k, v, 32) - - class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin): def __init__(self, **kwargs): diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index 336ce1d7a0..0038c481e8 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -209,12 +209,12 @@ class KSampler(Sampler): model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info) # setup attention maps saving. checks for None are because there are multiple code paths to get here. - attention_maps_saver = None + attention_map_saver = None if attention_maps_callback is not None and extra_conditioning_info is not None: eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1 attention_map_token_ids = range(1, eos_token_index) - attention_maps_saver = AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:]) - model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_maps_saver) + attention_map_saver = AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:]) + model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_map_saver) extra_args = { 'cond': conditioning, @@ -229,8 +229,8 @@ class KSampler(Sampler): ), None, ) - if attention_maps_saver is not None: - attention_maps_callback(attention_maps_saver) + if attention_map_saver is not None: + attention_maps_callback(attention_map_saver) return sampling_result # this code will support inpainting if and when ksampler API modified or