diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 8fde088b36..a6f46a3834 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -34,8 +34,8 @@ from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.model_management.models import ModelType, SilenceWarnings from ...backend.model_management.lora import ModelPatcher -from ...backend.model_management.seamless import set_seamless from ...backend.model_management.models import BaseModelType +from ...backend.model_management.seamless import set_seamless from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion.diffusers_pipeline import ( ConditioningData, @@ -43,7 +43,9 @@ from ...backend.stable_diffusion.diffusers_pipeline import ( StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor, ) -from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings +from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import ( + PostprocessingSettings, +) from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.util.devices import choose_precision, choose_torch_device from ..models.image import ImageCategory, ResourceOrigin @@ -485,9 +487,12 @@ class DenoiseLatentsInvocation(BaseInvocation): **self.unet.unet.dict(), context=context, ) - with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet( - unet_info.context.model, _lora_loader() - ), set_seamless(unet_info.context.model, self.unet.seamless_axes), unet_info as unet: + with ( + ExitStack() as exit_stack, + ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), + set_seamless(unet_info.context.model, self.unet.seamless_axes), + unet_info as unet, + ): latents = latents.to(device=unet.device, dtype=unet.dtype) if noise is not None: noise = noise.to(device=unet.device, dtype=unet.dtype) @@ -524,7 +529,7 @@ class DenoiseLatentsInvocation(BaseInvocation): denoising_end=self.denoising_end, ) - result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( + result_latents = pipeline.latents_from_embeddings( latents=latents, timesteps=timesteps, init_timestep=init_timestep, diff --git a/invokeai/backend/stable_diffusion/__init__.py b/invokeai/backend/stable_diffusion/__init__.py index a958750802..a3d5c72a2d 100644 --- a/invokeai/backend/stable_diffusion/__init__.py +++ b/invokeai/backend/stable_diffusion/__init__.py @@ -7,9 +7,8 @@ from .diffusers_pipeline import ( # noqa: F401 StableDiffusionGeneratorPipeline, ) from .diffusion import InvokeAIDiffuserComponent # noqa: F401 -from .diffusion.cross_attention_map_saving import AttentionMapSaver # noqa: F401 from .diffusion.shared_invokeai_diffusion import ( # noqa: F401 - PostprocessingSettings, BasicConditioningInfo, + PostprocessingSettings, SDXLConditioningInfo, ) diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index d88313f455..766fbe668f 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -5,14 +5,13 @@ import inspect from dataclasses import dataclass, field from typing import Any, Callable, List, Optional, Union -import PIL.Image import einops +import PIL.Image import psutil import torch import torchvision.transforms as T from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models.controlnet 
import ControlNetModel -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( StableDiffusionPipeline, ) @@ -27,13 +26,13 @@ from pydantic import Field from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from invokeai.app.services.config import InvokeAIAppConfig + +from ..util import auto_detect_slice_size, normalize_device from .diffusion import ( - AttentionMapSaver, + BasicConditioningInfo, InvokeAIDiffuserComponent, PostprocessingSettings, - BasicConditioningInfo, ) -from ..util import normalize_device, auto_detect_slice_size @dataclass @@ -44,7 +43,6 @@ class PipelineIntermediateState: timestep: int latents: torch.Tensor predicted_original: Optional[torch.Tensor] = None - attention_map_saver: Optional[AttentionMapSaver] = None @dataclass @@ -103,7 +101,7 @@ class AddsMaskGuidance: # Mask anything that has the same shape as prev_sample, return others as-is. return output_class( { - k: (self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v) + k: self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v for k, v in step_output.items() } ) @@ -205,18 +203,6 @@ class ConditioningData: return dataclasses.replace(self, scheduler_args=scheduler_args) -@dataclass -class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput): - r""" - Output class for InvokeAI's Stable Diffusion pipeline. - - Args: - attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user - after generation completes. Optional. - """ - attention_map_saver: Optional[AttentionMapSaver] - - class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): r""" Pipeline for text-to-image generation using Stable Diffusion. 
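Note on the next hunk: it narrows the return type of latents_from_embeddings to a bare torch.Tensor, but the early-return path that survives as unchanged context a few lines further down still reads "return latents, None", which no longer matches the annotated signature. A hypothetical follow-up (not included in this diff, shown only as a sketch consistent with the new signature) would keep that exit in step:

    if init_timestep.shape[0] == 0:
        return latents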
@@ -360,7 +346,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): mask: Optional[torch.Tensor] = None, masked_latents: Optional[torch.Tensor] = None, seed: Optional[int] = None, - ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]: + ) -> torch.Tensor: if init_timestep.shape[0] == 0: return latents, None @@ -402,7 +388,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise)) try: - latents, attention_map_saver = self.generate_latents_from_embeddings( + latents = self.generate_latents_from_embeddings( latents, timesteps, conditioning_data, @@ -417,7 +403,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if mask is not None: latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)) - return latents, attention_map_saver + return latents def generate_latents_from_embeddings( self, @@ -434,16 +420,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): additional_guidance = [] batch_size = latents.shape[0] - attention_map_saver: Optional[AttentionMapSaver] = None if timesteps.shape[0] == 0: - return latents, attention_map_saver + return latents extra_conditioning_info = conditioning_data.extra with self.invokeai_diffuser.custom_attention_context( self.invokeai_diffuser.model, extra_conditioning_info=extra_conditioning_info, - step_count=len(self.scheduler.timesteps), ): if callback is not None: callback( @@ -480,13 +464,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): predicted_original = getattr(step_output, "pred_original_sample", None) - # TODO resuscitate attention map saving - # if i == len(timesteps)-1 and extra_conditioning_info is not None: - # eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1 - # attention_map_token_ids = range(1, eos_token_index) - # attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:]) - # self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver) - if callback is not None: callback( PipelineIntermediateState( @@ -496,11 +473,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): timestep=int(t), latents=latents, predicted_original=predicted_original, - attention_map_saver=attention_map_saver, ) ) - return latents, attention_map_saver + return latents @torch.inference_mode() def step( diff --git a/invokeai/backend/stable_diffusion/diffusion/__init__.py b/invokeai/backend/stable_diffusion/diffusion/__init__.py index 2bcc595889..1d789a2706 100644 --- a/invokeai/backend/stable_diffusion/diffusion/__init__.py +++ b/invokeai/backend/stable_diffusion/diffusion/__init__.py @@ -1,11 +1,9 @@ """ Initialization file for invokeai.models.diffusion """ -from .cross_attention_control import InvokeAICrossAttentionMixin # noqa: F401 -from .cross_attention_map_saving import AttentionMapSaver # noqa: F401 from .shared_invokeai_diffusion import ( # noqa: F401 + BasicConditioningInfo, InvokeAIDiffuserComponent, PostprocessingSettings, - BasicConditioningInfo, SDXLConditioningInfo, ) diff --git a/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py b/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py index 35d4800859..aabc92a85b 100644 --- a/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py +++ b/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py @@ -5,22 +5,14 @@ import enum 
import math from dataclasses import dataclass, field -from typing import Callable, Optional +from typing import Optional -import diffusers -import psutil import torch from compel.cross_attention_control import Arguments +from diffusers.models.attention_processor import Attention, SlicedAttnProcessor from diffusers.models.unet_2d_condition import UNet2DConditionModel -from diffusers.models.attention_processor import AttentionProcessor -from diffusers.models.attention_processor import ( - Attention, - AttnProcessor, - SlicedAttnProcessor, -) from torch import nn -import invokeai.backend.util.logging as logger from ...util import torch_dtype @@ -33,68 +25,14 @@ class Context: cross_attention_mask: Optional[torch.Tensor] cross_attention_index_map: Optional[torch.Tensor] - class Action(enum.Enum): - NONE = 0 - SAVE = (1,) - APPLY = 2 - - def __init__(self, arguments: Arguments, step_count: int): + def __init__(self, arguments: Arguments): """ :param arguments: Arguments for the cross-attention control process - :param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run) """ self.cross_attention_mask = None self.cross_attention_index_map = None - self.self_cross_attention_action = Context.Action.NONE - self.tokens_cross_attention_action = Context.Action.NONE + self.arguments = arguments - self.step_count = step_count - - self.self_cross_attention_module_identifiers = [] - self.tokens_cross_attention_module_identifiers = [] - - self.saved_cross_attention_maps = {} - - self.clear_requests(cleanup=True) - - def register_cross_attention_modules(self, model): - for name, module in get_cross_attention_modules(model, CrossAttentionType.SELF): - if name in self.self_cross_attention_module_identifiers: - assert False, f"name {name} cannot appear more than once" - self.self_cross_attention_module_identifiers.append(name) - for name, module in get_cross_attention_modules(model, CrossAttentionType.TOKENS): - if name in self.tokens_cross_attention_module_identifiers: - assert False, f"name {name} cannot appear more than once" - self.tokens_cross_attention_module_identifiers.append(name) - - def request_save_attention_maps(self, cross_attention_type: CrossAttentionType): - if cross_attention_type == CrossAttentionType.SELF: - self.self_cross_attention_action = Context.Action.SAVE - else: - self.tokens_cross_attention_action = Context.Action.SAVE - - def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType): - if cross_attention_type == CrossAttentionType.SELF: - self.self_cross_attention_action = Context.Action.APPLY - else: - self.tokens_cross_attention_action = Context.Action.APPLY - - def is_tokens_cross_attention(self, module_identifier) -> bool: - return module_identifier in self.tokens_cross_attention_module_identifiers - - def get_should_save_maps(self, module_identifier: str) -> bool: - if module_identifier in self.self_cross_attention_module_identifiers: - return self.self_cross_attention_action == Context.Action.SAVE - elif module_identifier in self.tokens_cross_attention_module_identifiers: - return self.tokens_cross_attention_action == Context.Action.SAVE - return False - - def get_should_apply_saved_maps(self, module_identifier: str) -> bool: - if module_identifier in self.self_cross_attention_module_identifiers: - return self.self_cross_attention_action == Context.Action.APPLY - elif module_identifier in self.tokens_cross_attention_module_identifiers: - return 
self.tokens_cross_attention_action == Context.Action.APPLY - return False def get_active_cross_attention_control_types_for_step( self, percent_through: float = None @@ -115,217 +53,6 @@ class Context: to_control.append(CrossAttentionType.TOKENS) return to_control - def save_slice( - self, - identifier: str, - slice: torch.Tensor, - dim: Optional[int], - offset: int, - slice_size: Optional[int], - ): - if identifier not in self.saved_cross_attention_maps: - self.saved_cross_attention_maps[identifier] = { - "dim": dim, - "slice_size": slice_size, - "slices": {offset or 0: slice}, - } - else: - self.saved_cross_attention_maps[identifier]["slices"][offset or 0] = slice - - def get_slice( - self, - identifier: str, - requested_dim: Optional[int], - requested_offset: int, - slice_size: int, - ): - saved_attention_dict = self.saved_cross_attention_maps[identifier] - if requested_dim is None: - if saved_attention_dict["dim"] is not None: - raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}") - return saved_attention_dict["slices"][0] - - if saved_attention_dict["dim"] == requested_dim: - if slice_size != saved_attention_dict["slice_size"]: - raise RuntimeError( - f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}" - ) - return saved_attention_dict["slices"][requested_offset] - - if saved_attention_dict["dim"] is None: - whole_saved_attention = saved_attention_dict["slices"][0] - if requested_dim == 0: - return whole_saved_attention[requested_offset : requested_offset + slice_size] - elif requested_dim == 1: - return whole_saved_attention[:, requested_offset : requested_offset + slice_size] - - raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}") - - def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]: - saved_attention = self.saved_cross_attention_maps.get(identifier, None) - if saved_attention is None: - return None, None - return saved_attention["dim"], saved_attention["slice_size"] - - def clear_requests(self, cleanup=True): - self.tokens_cross_attention_action = Context.Action.NONE - self.self_cross_attention_action = Context.Action.NONE - if cleanup: - self.saved_cross_attention_maps = {} - - def offload_saved_attention_slices_to_cpu(self): - for key, map_dict in self.saved_cross_attention_maps.items(): - for offset, slice in map_dict["slices"].items(): - map_dict[offset] = slice.to("cpu") - - -class InvokeAICrossAttentionMixin: - """ - Enable InvokeAI-flavoured Attention calculation, which does aggressive low-memory slicing and calls - through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling - and dymamic slicing strategy selection. - """ - - def __init__(self): - self.mem_total_gb = psutil.virtual_memory().total // (1 << 30) - self.attention_slice_wrangler = None - self.slicing_strategy_getter = None - self.attention_slice_calculated_callback = None - - def set_attention_slice_wrangler( - self, - wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]], - ): - """ - Set custom attention calculator to be called when attention is calculated - :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size), - which returns either the suggested_attention_slice or an adjusted equivalent. - `module` is the current Attention module for which the callback is being invoked. 
- `suggested_attention_slice` is the default-calculated attention slice - `dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing. - If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length. - - Pass None to use the default attention calculation. - :return: - """ - self.attention_slice_wrangler = wrangler - - def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int, int]]]): - self.slicing_strategy_getter = getter - - def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]): - self.attention_slice_calculated_callback = callback - - def einsum_lowest_level(self, query, key, value, dim, offset, slice_size): - # calculate attention scores - # attention_scores = torch.einsum('b i d, b j d -> b i j', q, k) - attention_scores = torch.baddbmm( - torch.empty( - query.shape[0], - query.shape[1], - key.shape[1], - dtype=query.dtype, - device=query.device, - ), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - - # calculate attention slice by taking the best scores for each latent pixel - default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype) - attention_slice_wrangler = self.attention_slice_wrangler - if attention_slice_wrangler is not None: - attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size) - else: - attention_slice = default_attention_slice - - if self.attention_slice_calculated_callback is not None: - self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size) - - hidden_states = torch.bmm(attention_slice, value) - return hidden_states - - def einsum_op_slice_dim0(self, q, k, v, slice_size): - r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - for i in range(0, q.shape[0], slice_size): - end = i + slice_size - r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size) - return r - - def einsum_op_slice_dim1(self, q, k, v, slice_size): - r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size) - return r - - def einsum_op_mps_v1(self, q, k, v): - if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 - return self.einsum_lowest_level(q, k, v, None, None, None) - else: - slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) - return self.einsum_op_slice_dim1(q, k, v, slice_size) - - def einsum_op_mps_v2(self, q, k, v): - if self.mem_total_gb > 8 and q.shape[1] <= 4096: - return self.einsum_lowest_level(q, k, v, None, None, None) - else: - return self.einsum_op_slice_dim0(q, k, v, 1) - - def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb): - size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) - if size_mb <= max_tensor_mb: - return self.einsum_lowest_level(q, k, v, None, None, None) - div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() - if div <= q.shape[0]: - return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div) - return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1)) - - def einsum_op_cuda(self, q, k, v): - # check if we already have a slicing strategy (this should only happen during cross-attention controlled generation) - slicing_strategy_getter = self.slicing_strategy_getter - if 
slicing_strategy_getter is not None: - (dim, slice_size) = slicing_strategy_getter(self) - if dim is not None: - # print("using saved slicing strategy with dim", dim, "slice size", slice_size) - if dim == 0: - return self.einsum_op_slice_dim0(q, k, v, slice_size) - elif dim == 1: - return self.einsum_op_slice_dim1(q, k, v, slice_size) - - # fallback for when there is no saved strategy, or saved strategy does not slice - mem_free_total = get_mem_free_total(q.device) - # Divide factor of safety as there's copying and fragmentation - return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) - - def get_invokeai_attention_mem_efficient(self, q, k, v): - if q.device.type == "cuda": - # print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device)) - return self.einsum_op_cuda(q, k, v) - - if q.device.type == "mps" or q.device.type == "cpu": - if self.mem_total_gb >= 32: - return self.einsum_op_mps_v1(q, k, v) - return self.einsum_op_mps_v2(q, k, v) - - # Smaller slices are faster due to L2/L3/SLC caches. - # Tested on i7 with 8MB L3 cache. - return self.einsum_op_tensor_mem(q, k, v, 32) - - -def restore_default_cross_attention( - model, - is_running_diffusers: bool, - restore_attention_processor: Optional[AttentionProcessor] = None, -): - if is_running_diffusers: - unet = model - unet.set_attn_processor(restore_attention_processor or AttnProcessor()) - else: - remove_attention_function(model) - def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context): """ @@ -366,136 +93,6 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionMode unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size)) -def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]: - cross_attention_class: type = InvokeAIDiffusersCrossAttention - which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2" - attention_module_tuples = [ - (name, module) - for name, module in model.named_modules() - if isinstance(module, cross_attention_class) and which_attn in name - ] - cross_attention_modules_in_model_count = len(attention_module_tuples) - expected_count = 16 - if cross_attention_modules_in_model_count != expected_count: - # non-fatal error but .swap() won't work. - logger.error( - f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model " - + f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed " - + "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, " - + f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows " - + "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not " - + "work properly until it is fixed." 
- ) - return attention_module_tuples - - -def inject_attention_function(unet, context: Context): - # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 - - def attention_slice_wrangler(module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size): - # memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement() - - attention_slice = suggested_attention_slice - - if context.get_should_save_maps(module.identifier): - # print(module.identifier, "saving suggested_attention_slice of shape", - # suggested_attention_slice.shape, "dim", dim, "offset", offset) - slice_to_save = attention_slice.to("cpu") if dim is not None else attention_slice - context.save_slice( - module.identifier, - slice_to_save, - dim=dim, - offset=offset, - slice_size=slice_size, - ) - elif context.get_should_apply_saved_maps(module.identifier): - # print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset) - saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size) - - # slice may have been offloaded to CPU - saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device) - - if context.is_tokens_cross_attention(module.identifier): - index_map = context.cross_attention_index_map - remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map) - this_attention_slice = suggested_attention_slice - - mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device)) - saved_mask = mask - this_mask = 1 - mask - attention_slice = remapped_saved_attention_slice * saved_mask + this_attention_slice * this_mask - else: - # just use everything - attention_slice = saved_attention_slice - - return attention_slice - - cross_attention_modules = get_cross_attention_modules( - unet, CrossAttentionType.TOKENS - ) + get_cross_attention_modules(unet, CrossAttentionType.SELF) - for identifier, module in cross_attention_modules: - module.identifier = identifier - try: - module.set_attention_slice_wrangler(attention_slice_wrangler) - module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier)) - except AttributeError as e: - if is_attribute_error_about(e, "set_attention_slice_wrangler"): - print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO - else: - raise - - -def remove_attention_function(unet): - cross_attention_modules = get_cross_attention_modules( - unet, CrossAttentionType.TOKENS - ) + get_cross_attention_modules(unet, CrossAttentionType.SELF) - for identifier, module in cross_attention_modules: - try: - # clear wrangler callback - module.set_attention_slice_wrangler(None) - module.set_slicing_strategy_getter(None) - except AttributeError as e: - if is_attribute_error_about(e, "set_attention_slice_wrangler"): - print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") - else: - raise - - -def is_attribute_error_about(error: AttributeError, attribute: str): - if hasattr(error, "name"): # Python 3.10 - return error.name == attribute - else: # Python 3.9 - return attribute in str(error) - - -def get_mem_free_total(device): - # only on cuda - if not torch.cuda.is_available(): - return None - stats = torch.cuda.memory_stats(device) - mem_active = stats["active_bytes.all.current"] - mem_reserved = stats["reserved_bytes.all.current"] - mem_free_cuda, _ = torch.cuda.mem_get_info(device) - mem_free_torch = 
mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - return mem_free_total - - -class InvokeAIDiffusersCrossAttention(diffusers.models.attention.Attention, InvokeAICrossAttentionMixin): - def __init__(self, **kwargs): - super().__init__(**kwargs) - InvokeAICrossAttentionMixin.__init__(self) - - def _attention(self, query, key, value, attention_mask=None): - # default_result = super()._attention(query, key, value) - if attention_mask is not None: - print(f"{type(self).__name__} ignoring passed-in attention_mask") - attention_result = self.get_invokeai_attention_mem_efficient(query, key, value) - - hidden_states = self.reshape_batch_dim_to_heads(attention_result) - return hidden_states - - ## 🧨diffusers implementation follows diff --git a/invokeai/backend/stable_diffusion/diffusion/cross_attention_map_saving.py b/invokeai/backend/stable_diffusion/diffusion/cross_attention_map_saving.py deleted file mode 100644 index abef979b1c..0000000000 --- a/invokeai/backend/stable_diffusion/diffusion/cross_attention_map_saving.py +++ /dev/null @@ -1,98 +0,0 @@ -import math - -import PIL -import torch -from torchvision.transforms.functional import InterpolationMode -from torchvision.transforms.functional import resize as tv_resize - - -class AttentionMapSaver: - def __init__(self, token_ids: range, latents_shape: torch.Size): - self.token_ids = token_ids - self.latents_shape = latents_shape - # self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]]) - self.collated_maps = {} - - def clear_maps(self): - self.collated_maps = {} - - def add_attention_maps(self, maps: torch.Tensor, key: str): - """ - Accumulate the given attention maps and store by summing with existing maps at the passed-in key (if any). - :param maps: Attention maps to store. Expected shape [A, (H*W), N] where A is attention heads count, H and W are the map size (fixed per-key) and N is the number of tokens (typically 77). - :param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the maps sizes (H and W) should match. - :return: None - """ - key_and_size = f"{key}_{maps.shape[1]}" - - # extract desired tokens - maps = maps[:, :, self.token_ids] - - # merge attention heads to a single map per token - maps = torch.sum(maps, 0) - - # store - if key_and_size not in self.collated_maps: - self.collated_maps[key_and_size] = torch.zeros_like(maps, device="cpu") - self.collated_maps[key_and_size] += maps.cpu() - - def write_maps_to_disk(self, path: str): - pil_image = self.get_stacked_maps_image() - pil_image.save(path, "PNG") - - def get_stacked_maps_image(self) -> PIL.Image: - """ - Scale all collected attention maps to the same size, blend them together and return as an image. - :return: An image containing a vertical stack of blended attention maps, one for each requested token. 
- """ - num_tokens = len(self.token_ids) - if num_tokens == 0: - return None - - latents_height = self.latents_shape[0] - latents_width = self.latents_shape[1] - - merged = None - - for key, maps in self.collated_maps.items(): - # maps has shape [(H*W), N] for N tokens - # but we want [N, H, W] - this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height)) - this_maps_height = int(float(latents_height) * this_scale_factor) - this_maps_width = int(float(latents_width) * this_scale_factor) - # and we need to do some dimension juggling - maps = torch.reshape( - torch.swapdims(maps, 0, 1), - [num_tokens, this_maps_height, this_maps_width], - ) - - # scale to output size if necessary - if this_scale_factor != 1: - maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC) - - # normalize - maps_min = torch.min(maps) - maps_range = torch.max(maps) - maps_min - # print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}") - maps_normalized = (maps - maps_min) / maps_range - # expand to (-0.1, 1.1) and clamp - maps_normalized_expanded = maps_normalized * 1.1 - 0.05 - maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1) - - # merge together, producing a vertical stack - maps_stacked = torch.reshape( - maps_normalized_expanded_clamped, - [num_tokens * latents_height, latents_width], - ) - - if merged is None: - merged = maps_stacked - else: - # screen blend - merged = 1 - (1 - maps_stacked) * (1 - merged) - - if merged is None: - return None - - merged_bytes = merged.mul(0xFF).byte() - return PIL.Image.fromarray(merged_bytes.numpy(), mode="L") diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index f05adafca2..64c29573ec 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -1,8 +1,8 @@ from __future__ import annotations +import math from contextlib import contextmanager from dataclasses import dataclass -import math from typing import Any, Callable, Optional, Union import torch @@ -14,12 +14,9 @@ from invokeai.app.services.config import InvokeAIAppConfig from .cross_attention_control import ( Arguments, Context, - CrossAttentionType, SwapCrossAttnContext, - get_cross_attention_modules, setup_cross_attention_control_attention_processors, ) -from .cross_attention_map_saving import AttentionMapSaver ModelForwardCallback: TypeAlias = Union[ # x, t, conditioning, Optional[cross-attention kwargs] @@ -105,7 +102,6 @@ class InvokeAIDiffuserComponent: self, unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs extra_conditioning_info: Optional[ExtraConditioningInfo], - step_count: int, ): old_attn_processors = None if extra_conditioning_info and (extra_conditioning_info.wants_cross_attention_control): @@ -114,7 +110,6 @@ class InvokeAIDiffuserComponent: if extra_conditioning_info.wants_cross_attention_control: self.cross_attention_control_context = Context( arguments=extra_conditioning_info.cross_attention_control_args, - step_count=step_count, ) setup_cross_attention_control_attention_processors( unet, @@ -127,27 +122,6 @@ class InvokeAIDiffuserComponent: self.cross_attention_control_context = None if old_attn_processors is not None: unet.set_attn_processor(old_attn_processors) - # TODO resuscitate attention map saving - # 
self.remove_attention_map_saving() - - def setup_attention_map_saving(self, saver: AttentionMapSaver): - def callback(slice, dim, offset, slice_size, key): - if dim is not None: - # sliced tokens attention map saving is not implemented - return - saver.add_attention_maps(slice, key) - - tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS) - for identifier, module in tokens_cross_attention_modules: - key = "down" if identifier.startswith("down") else "up" if identifier.startswith("up") else "mid" - module.set_attention_slice_calculated_callback( - lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key) - ) - - def remove_attention_map_saving(self): - tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS) - for _, module in tokens_cross_attention_modules: - module.set_attention_slice_calculated_callback(None) def do_controlnet_step( self,
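For orientation, a minimal sketch of the caller-side pattern after these changes (not part of the diff): pipeline, latents, timesteps, init_timestep, and result_latents are assumed to be set up as in DenoiseLatentsInvocation above; step_callback is a hypothetical helper name, and the elided arguments are left as in the original call site.

    def step_callback(state: PipelineIntermediateState) -> None:
        # PipelineIntermediateState no longer carries an attention_map_saver field;
        # only the denoising-progress fields remain.
        print(f"step {state.step}/{state.total_steps}, t={state.timestep}")

    # latents_from_embeddings now returns the latents tensor alone rather than a
    # (latents, attention_map_saver) tuple, so callers bind a single value.
    result_latents = pipeline.latents_from_embeddings(
        latents=latents,
        timesteps=timesteps,
        init_timestep=init_timestep,
        # ... remaining arguments unchanged from the call site shown above
    )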