diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 9e9cb2d1c7..d9125f0f37 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -775,10 +775,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 denoising_end=self.denoising_end,
             )
 
-            (
-                result_latents,
-                result_attention_map_saver,
-            ) = pipeline.latents_from_embeddings(
+            result_latents = pipeline.latents_from_embeddings(
                 latents=latents,
                 timesteps=timesteps,
                 init_timestep=init_timestep,
diff --git a/invokeai/backend/stable_diffusion/__init__.py b/invokeai/backend/stable_diffusion/__init__.py
index 8b3f701064..ed6782eefa 100644
--- a/invokeai/backend/stable_diffusion/__init__.py
+++ b/invokeai/backend/stable_diffusion/__init__.py
@@ -4,13 +4,11 @@ Initialization file for the invokeai.backend.stable_diffusion package
 
 from .diffusers_pipeline import PipelineIntermediateState, StableDiffusionGeneratorPipeline  # noqa: F401
 from .diffusion import InvokeAIDiffuserComponent  # noqa: F401
-from .diffusion.cross_attention_map_saving import AttentionMapSaver  # noqa: F401
 from .seamless import set_seamless  # noqa: F401
 
 __all__ = [
     "PipelineIntermediateState",
     "StableDiffusionGeneratorPipeline",
     "InvokeAIDiffuserComponent",
-    "AttentionMapSaver",
     "set_seamless",
 ]
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 538e0ea990..9a08787878 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -12,7 +12,6 @@ import torch
 import torchvision.transforms as T
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.models.controlnet import ControlNetModel
-from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import KarrasDiffusionSchedulers
@@ -26,9 +25,9 @@ from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
 from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
+from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 
 from ..util import auto_detect_slice_size, normalize_device
-from .diffusion import AttentionMapSaver, InvokeAIDiffuserComponent
 
 
 @dataclass
@@ -39,7 +38,6 @@ class PipelineIntermediateState:
     timestep: int
     latents: torch.Tensor
     predicted_original: Optional[torch.Tensor] = None
-    attention_map_saver: Optional[AttentionMapSaver] = None
 
 
 @dataclass
@@ -190,19 +188,6 @@ class T2IAdapterData:
     end_step_percent: float = Field(default=1.0)
 
 
-@dataclass
-class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
-    r"""
-    Output class for InvokeAI's Stable Diffusion pipeline.
-
-    Args:
-        attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
-            after generation completes. Optional.
-    """
-
-    attention_map_saver: Optional[AttentionMapSaver]
-
-
 class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.
@@ -343,9 +328,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         masked_latents: Optional[torch.Tensor] = None,
         gradient_mask: Optional[bool] = False,
         seed: Optional[int] = None,
-    ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
+    ) -> torch.Tensor:
         if init_timestep.shape[0] == 0:
-            return latents, None
+            return latents
 
         if additional_guidance is None:
             additional_guidance = []
@@ -385,7 +370,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
 
         try:
-            latents, attention_map_saver = self.generate_latents_from_embeddings(
+            latents = self.generate_latents_from_embeddings(
                 latents,
                 timesteps,
                 conditioning_data,
@@ -402,7 +387,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if mask is not None and not gradient_mask:
             latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))
 
-        return latents, attention_map_saver
+        return latents
 
     def generate_latents_from_embeddings(
         self,
@@ -415,16 +400,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         ip_adapter_data: Optional[list[IPAdapterData]] = None,
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
         callback: Callable[[PipelineIntermediateState], None] = None,
-    ):
+    ) -> torch.Tensor:
         self._adjust_memory_efficient_attention(latents)
         if additional_guidance is None:
             additional_guidance = []
 
         batch_size = latents.shape[0]
-        attention_map_saver: Optional[AttentionMapSaver] = None
 
         if timesteps.shape[0] == 0:
-            return latents, attention_map_saver
+            return latents
 
         ip_adapter_unet_patcher = None
         extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning
@@ -432,7 +416,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             attn_ctx = self.invokeai_diffuser.custom_attention_context(
                 self.invokeai_diffuser.model,
                 extra_conditioning_info=extra_conditioning_info,
-                step_count=len(self.scheduler.timesteps),
             )
             self.use_ip_adapter = False
         elif ip_adapter_data is not None:
@@ -483,13 +466,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
                 predicted_original = getattr(step_output, "pred_original_sample", None)
 
-                # TODO resuscitate attention map saving
-                # if i == len(timesteps)-1 and extra_conditioning_info is not None:
-                #     eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
-                #     attention_map_token_ids = range(1, eos_token_index)
-                #     attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
-                #     self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
-
                 if callback is not None:
                     callback(
                         PipelineIntermediateState(
@@ -499,11 +475,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                             timestep=int(t),
                             latents=latents,
                             predicted_original=predicted_original,
-                            attention_map_saver=attention_map_saver,
                         )
                     )
 
-        return latents, attention_map_saver
+        return latents
 
     @torch.inference_mode()
     def step(
diff --git a/invokeai/backend/stable_diffusion/diffusion/__init__.py b/invokeai/backend/stable_diffusion/diffusion/__init__.py
index e68340168a..854d127a36 100644
--- a/invokeai/backend/stable_diffusion/diffusion/__init__.py
+++ b/invokeai/backend/stable_diffusion/diffusion/__init__.py
@@ -2,6 +2,4 @@
 Initialization file for invokeai.models.diffusion
 """
 
-from .cross_attention_control import InvokeAICrossAttentionMixin  # noqa: F401
-from .cross_attention_map_saving import AttentionMapSaver  # noqa: F401
 from .shared_invokeai_diffusion import InvokeAIDiffuserComponent  # noqa: F401
diff --git a/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py b/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py
index 2bbee87f09..4278f08bff 100644
--- a/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py
+++ b/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py
@@ -3,19 +3,13 @@
 
 
 import enum
-import math
 from dataclasses import dataclass, field
-from typing import Callable, Optional
+from typing import Optional
 
-import diffusers
-import psutil
 import torch
 from compel.cross_attention_control import Arguments
-from diffusers.models.attention_processor import Attention, AttentionProcessor, AttnProcessor, SlicedAttnProcessor
+from diffusers.models.attention_processor import Attention, SlicedAttnProcessor
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
-from torch import nn
-
-import invokeai.backend.util.logging as logger
 
 from ...util import torch_dtype
 
@@ -25,72 +19,14 @@ class CrossAttentionType(enum.Enum):
     TOKENS = 2
 
 
-class Context:
-    cross_attention_mask: Optional[torch.Tensor]
-    cross_attention_index_map: Optional[torch.Tensor]
-
-    class Action(enum.Enum):
-        NONE = 0
-        SAVE = (1,)
-        APPLY = 2
-
-    def __init__(self, arguments: Arguments, step_count: int):
+class CrossAttnControlContext:
+    def __init__(self, arguments: Arguments):
         """
         :param arguments: Arguments for the cross-attention control process
-        :param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
         """
-        self.cross_attention_mask = None
-        self.cross_attention_index_map = None
-        self.self_cross_attention_action = Context.Action.NONE
-        self.tokens_cross_attention_action = Context.Action.NONE
+        self.cross_attention_mask: Optional[torch.Tensor] = None
+        self.cross_attention_index_map: Optional[torch.Tensor] = None
         self.arguments = arguments
-        self.step_count = step_count
-
-        self.self_cross_attention_module_identifiers = []
-        self.tokens_cross_attention_module_identifiers = []
-
-        self.saved_cross_attention_maps = {}
-
-        self.clear_requests(cleanup=True)
-
-    def register_cross_attention_modules(self, model):
-        for name, _module in get_cross_attention_modules(model, CrossAttentionType.SELF):
-            if name in self.self_cross_attention_module_identifiers:
-                raise AssertionError(f"name {name} cannot appear more than once")
-            self.self_cross_attention_module_identifiers.append(name)
-        for name, _module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
-            if name in self.tokens_cross_attention_module_identifiers:
-                raise AssertionError(f"name {name} cannot appear more than once")
-            self.tokens_cross_attention_module_identifiers.append(name)
-
-    def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
-        if cross_attention_type == CrossAttentionType.SELF:
-            self.self_cross_attention_action = Context.Action.SAVE
-        else:
-            self.tokens_cross_attention_action = Context.Action.SAVE
-
-    def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType):
-        if cross_attention_type == CrossAttentionType.SELF:
-            self.self_cross_attention_action = Context.Action.APPLY
-        else:
-            self.tokens_cross_attention_action = Context.Action.APPLY
-
-    def is_tokens_cross_attention(self, module_identifier) -> bool:
-        return module_identifier in self.tokens_cross_attention_module_identifiers
-
-    def get_should_save_maps(self, module_identifier: str) -> bool:
-        if module_identifier in self.self_cross_attention_module_identifiers:
-            return self.self_cross_attention_action == Context.Action.SAVE
-        elif module_identifier in self.tokens_cross_attention_module_identifiers:
-            return self.tokens_cross_attention_action == Context.Action.SAVE
-        return False
-
-    def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
-        if module_identifier in self.self_cross_attention_module_identifiers:
-            return self.self_cross_attention_action == Context.Action.APPLY
-        elif module_identifier in self.tokens_cross_attention_module_identifiers:
-            return self.tokens_cross_attention_action == Context.Action.APPLY
-        return False
 
     def get_active_cross_attention_control_types_for_step(
         self, percent_through: float = None
@@ -111,219 +47,8 @@ class Context:
             to_control.append(CrossAttentionType.TOKENS)
         return to_control
 
-    def save_slice(
-        self,
-        identifier: str,
-        slice: torch.Tensor,
-        dim: Optional[int],
-        offset: int,
-        slice_size: Optional[int],
-    ):
-        if identifier not in self.saved_cross_attention_maps:
-            self.saved_cross_attention_maps[identifier] = {
-                "dim": dim,
-                "slice_size": slice_size,
-                "slices": {offset or 0: slice},
-            }
-        else:
-            self.saved_cross_attention_maps[identifier]["slices"][offset or 0] = slice
-
-    def get_slice(
-        self,
-        identifier: str,
-        requested_dim: Optional[int],
-        requested_offset: int,
-        slice_size: int,
-    ):
-        saved_attention_dict = self.saved_cross_attention_maps[identifier]
-        if requested_dim is None:
-            if saved_attention_dict["dim"] is not None:
-                raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}")
-            return saved_attention_dict["slices"][0]
-
-        if saved_attention_dict["dim"] == requested_dim:
-            if slice_size != saved_attention_dict["slice_size"]:
-                raise RuntimeError(
-                    f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}"
-                )
-            return saved_attention_dict["slices"][requested_offset]
-
-        if saved_attention_dict["dim"] is None:
-            whole_saved_attention = saved_attention_dict["slices"][0]
-            if requested_dim == 0:
-                return whole_saved_attention[requested_offset : requested_offset + slice_size]
-            elif requested_dim == 1:
-                return whole_saved_attention[:, requested_offset : requested_offset + slice_size]
-
-        raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
-
-    def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
-        saved_attention = self.saved_cross_attention_maps.get(identifier, None)
-        if saved_attention is None:
-            return None, None
-        return saved_attention["dim"], saved_attention["slice_size"]
-
-    def clear_requests(self, cleanup=True):
-        self.tokens_cross_attention_action = Context.Action.NONE
-        self.self_cross_attention_action = Context.Action.NONE
-        if cleanup:
-            self.saved_cross_attention_maps = {}
-
-    def offload_saved_attention_slices_to_cpu(self):
-        for _key, map_dict in self.saved_cross_attention_maps.items():
-            for offset, slice in map_dict["slices"].items():
-                map_dict[offset] = slice.to("cpu")
-
-
-class InvokeAICrossAttentionMixin:
-    """
-    Enable InvokeAI-flavoured Attention calculation, which does aggressive low-memory slicing and calls
-    through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
-    and dymamic slicing strategy selection.
- """ - - def __init__(self): - self.mem_total_gb = psutil.virtual_memory().total // (1 << 30) - self.attention_slice_wrangler = None - self.slicing_strategy_getter = None - self.attention_slice_calculated_callback = None - - def set_attention_slice_wrangler( - self, - wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]], - ): - """ - Set custom attention calculator to be called when attention is calculated - :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size), - which returns either the suggested_attention_slice or an adjusted equivalent. - `module` is the current Attention module for which the callback is being invoked. - `suggested_attention_slice` is the default-calculated attention slice - `dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing. - If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length. - - Pass None to use the default attention calculation. - :return: - """ - self.attention_slice_wrangler = wrangler - - def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int, int]]]): - self.slicing_strategy_getter = getter - - def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]): - self.attention_slice_calculated_callback = callback - - def einsum_lowest_level(self, query, key, value, dim, offset, slice_size): - # calculate attention scores - # attention_scores = torch.einsum('b i d, b j d -> b i j', q, k) - attention_scores = torch.baddbmm( - torch.empty( - query.shape[0], - query.shape[1], - key.shape[1], - dtype=query.dtype, - device=query.device, - ), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - - # calculate attention slice by taking the best scores for each latent pixel - default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype) - attention_slice_wrangler = self.attention_slice_wrangler - if attention_slice_wrangler is not None: - attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size) - else: - attention_slice = default_attention_slice - - if self.attention_slice_calculated_callback is not None: - self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size) - - hidden_states = torch.bmm(attention_slice, value) - return hidden_states - - def einsum_op_slice_dim0(self, q, k, v, slice_size): - r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - for i in range(0, q.shape[0], slice_size): - end = i + slice_size - r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size) - return r - - def einsum_op_slice_dim1(self, q, k, v, slice_size): - r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size) - return r - - def einsum_op_mps_v1(self, q, k, v): - if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 - return self.einsum_lowest_level(q, k, v, None, None, None) - else: - slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) - return self.einsum_op_slice_dim1(q, k, v, slice_size) - - def einsum_op_mps_v2(self, q, k, v): - if self.mem_total_gb > 8 and q.shape[1] <= 4096: - return self.einsum_lowest_level(q, k, v, None, None, None) - else: - return 
self.einsum_op_slice_dim0(q, k, v, 1) - - def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb): - size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) - if size_mb <= max_tensor_mb: - return self.einsum_lowest_level(q, k, v, None, None, None) - div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() - if div <= q.shape[0]: - return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div) - return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1)) - - def einsum_op_cuda(self, q, k, v): - # check if we already have a slicing strategy (this should only happen during cross-attention controlled generation) - slicing_strategy_getter = self.slicing_strategy_getter - if slicing_strategy_getter is not None: - (dim, slice_size) = slicing_strategy_getter(self) - if dim is not None: - # print("using saved slicing strategy with dim", dim, "slice size", slice_size) - if dim == 0: - return self.einsum_op_slice_dim0(q, k, v, slice_size) - elif dim == 1: - return self.einsum_op_slice_dim1(q, k, v, slice_size) - - # fallback for when there is no saved strategy, or saved strategy does not slice - mem_free_total = get_mem_free_total(q.device) - # Divide factor of safety as there's copying and fragmentation - return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) - - def get_invokeai_attention_mem_efficient(self, q, k, v): - if q.device.type == "cuda": - # print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device)) - return self.einsum_op_cuda(q, k, v) - - if q.device.type == "mps" or q.device.type == "cpu": - if self.mem_total_gb >= 32: - return self.einsum_op_mps_v1(q, k, v) - return self.einsum_op_mps_v2(q, k, v) - - # Smaller slices are faster due to L2/L3/SLC caches. - # Tested on i7 with 8MB L3 cache. - return self.einsum_op_tensor_mem(q, k, v, 32) - - -def restore_default_cross_attention( - model, - is_running_diffusers: bool, - restore_attention_processor: Optional[AttentionProcessor] = None, -): - if is_running_diffusers: - unet = model - unet.set_attn_processor(restore_attention_processor or AttnProcessor()) - else: - remove_attention_function(model) - - -def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context): +def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: CrossAttnControlContext): """ Inject attention parameters and functions into the passed in model to enable cross attention editing. @@ -362,170 +87,6 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionMode unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size)) -def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]: - cross_attention_class: type = InvokeAIDiffusersCrossAttention - which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2" - attention_module_tuples = [ - (name, module) - for name, module in model.named_modules() - if isinstance(module, cross_attention_class) and which_attn in name - ] - cross_attention_modules_in_model_count = len(attention_module_tuples) - expected_count = 16 - if cross_attention_modules_in_model_count != expected_count: - # non-fatal error but .swap() won't work. - logger.error( - f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model " - f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). 
Either monkey-patching " - "failed or some assumption has changed about the structure of the model itself. Please fix the " - f"monkey-patching, and/or update the {expected_count} above to an appropriate number, and/or find and " - "inform someone who knows what it means. This error is non-fatal, but it is likely that .swap() and " - "attention map display will not work properly until it is fixed." - ) - return attention_module_tuples - - -def inject_attention_function(unet, context: Context): - # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 - - def attention_slice_wrangler(module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size): - # memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement() - - attention_slice = suggested_attention_slice - - if context.get_should_save_maps(module.identifier): - # print(module.identifier, "saving suggested_attention_slice of shape", - # suggested_attention_slice.shape, "dim", dim, "offset", offset) - slice_to_save = attention_slice.to("cpu") if dim is not None else attention_slice - context.save_slice( - module.identifier, - slice_to_save, - dim=dim, - offset=offset, - slice_size=slice_size, - ) - elif context.get_should_apply_saved_maps(module.identifier): - # print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset) - saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size) - - # slice may have been offloaded to CPU - saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device) - - if context.is_tokens_cross_attention(module.identifier): - index_map = context.cross_attention_index_map - remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map) - this_attention_slice = suggested_attention_slice - - mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device)) - saved_mask = mask - this_mask = 1 - mask - attention_slice = remapped_saved_attention_slice * saved_mask + this_attention_slice * this_mask - else: - # just use everything - attention_slice = saved_attention_slice - - return attention_slice - - cross_attention_modules = get_cross_attention_modules( - unet, CrossAttentionType.TOKENS - ) + get_cross_attention_modules(unet, CrossAttentionType.SELF) - for identifier, module in cross_attention_modules: - module.identifier = identifier - try: - module.set_attention_slice_wrangler(attention_slice_wrangler) - module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier)) # noqa: B023 - except AttributeError as e: - if is_attribute_error_about(e, "set_attention_slice_wrangler"): - print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO - else: - raise - - -def remove_attention_function(unet): - cross_attention_modules = get_cross_attention_modules( - unet, CrossAttentionType.TOKENS - ) + get_cross_attention_modules(unet, CrossAttentionType.SELF) - for _identifier, module in cross_attention_modules: - try: - # clear wrangler callback - module.set_attention_slice_wrangler(None) - module.set_slicing_strategy_getter(None) - except AttributeError as e: - if is_attribute_error_about(e, "set_attention_slice_wrangler"): - print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") - else: - raise - - -def is_attribute_error_about(error: AttributeError, attribute: str): - if hasattr(error, "name"): # Python 
3.10 - return error.name == attribute - else: # Python 3.9 - return attribute in str(error) - - -def get_mem_free_total(device): - # only on cuda - if not torch.cuda.is_available(): - return None - stats = torch.cuda.memory_stats(device) - mem_active = stats["active_bytes.all.current"] - mem_reserved = stats["reserved_bytes.all.current"] - mem_free_cuda, _ = torch.cuda.mem_get_info(device) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - return mem_free_total - - -class InvokeAIDiffusersCrossAttention(diffusers.models.attention.Attention, InvokeAICrossAttentionMixin): - def __init__(self, **kwargs): - super().__init__(**kwargs) - InvokeAICrossAttentionMixin.__init__(self) - - def _attention(self, query, key, value, attention_mask=None): - # default_result = super()._attention(query, key, value) - if attention_mask is not None: - print(f"{type(self).__name__} ignoring passed-in attention_mask") - attention_result = self.get_invokeai_attention_mem_efficient(query, key, value) - - hidden_states = self.reshape_batch_dim_to_heads(attention_result) - return hidden_states - - -## 🧨diffusers implementation follows - - -""" -# base implementation - -class AttnProcessor: - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) - - query = attn.to_q(hidden_states) - query = attn.head_to_batch_dim(query) - - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - -""" - - @dataclass class SwapCrossAttnContext: modified_text_embeddings: torch.Tensor diff --git a/invokeai/backend/stable_diffusion/diffusion/cross_attention_map_saving.py b/invokeai/backend/stable_diffusion/diffusion/cross_attention_map_saving.py deleted file mode 100644 index 82c9f1dcea..0000000000 --- a/invokeai/backend/stable_diffusion/diffusion/cross_attention_map_saving.py +++ /dev/null @@ -1,100 +0,0 @@ -import math -from typing import Optional - -import torch -from PIL import Image -from torchvision.transforms.functional import InterpolationMode -from torchvision.transforms.functional import resize as tv_resize - - -class AttentionMapSaver: - def __init__(self, token_ids: range, latents_shape: torch.Size): - self.token_ids = token_ids - self.latents_shape = latents_shape - # self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]]) - self.collated_maps: dict[str, torch.Tensor] = {} - - def clear_maps(self): - self.collated_maps = {} - - def add_attention_maps(self, maps: torch.Tensor, key: str): - """ - Accumulate the given attention maps and store by summing with existing maps at the passed-in key (if any). - :param maps: Attention maps to store. Expected shape [A, (H*W), N] where A is attention heads count, H and W are the map size (fixed per-key) and N is the number of tokens (typically 77). - :param key: Storage key. 
-        :return: None
-        """
-        key_and_size = f"{key}_{maps.shape[1]}"
-
-        # extract desired tokens
-        maps = maps[:, :, self.token_ids]
-
-        # merge attention heads to a single map per token
-        maps = torch.sum(maps, 0)
-
-        # store
-        if key_and_size not in self.collated_maps:
-            self.collated_maps[key_and_size] = torch.zeros_like(maps, device="cpu")
-        self.collated_maps[key_and_size] += maps.cpu()
-
-    def write_maps_to_disk(self, path: str):
-        pil_image = self.get_stacked_maps_image()
-        if pil_image is not None:
-            pil_image.save(path, "PNG")
-
-    def get_stacked_maps_image(self) -> Optional[Image.Image]:
-        """
-        Scale all collected attention maps to the same size, blend them together and return as an image.
-        :return: An image containing a vertical stack of blended attention maps, one for each requested token.
-        """
-        num_tokens = len(self.token_ids)
-        if num_tokens == 0:
-            return None
-
-        latents_height = self.latents_shape[0]
-        latents_width = self.latents_shape[1]
-
-        merged = None
-
-        for _key, maps in self.collated_maps.items():
-            # maps has shape [(H*W), N] for N tokens
-            # but we want [N, H, W]
-            this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
-            this_maps_height = int(float(latents_height) * this_scale_factor)
-            this_maps_width = int(float(latents_width) * this_scale_factor)
-            # and we need to do some dimension juggling
-            maps = torch.reshape(
-                torch.swapdims(maps, 0, 1),
-                [num_tokens, this_maps_height, this_maps_width],
-            )
-
-            # scale to output size if necessary
-            if this_scale_factor != 1:
-                maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)
-
-            # normalize
-            maps_min = torch.min(maps)
-            maps_range = torch.max(maps) - maps_min
-            # print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
-            maps_normalized = (maps - maps_min) / maps_range
-            # expand to (-0.1, 1.1) and clamp
-            maps_normalized_expanded = maps_normalized * 1.1 - 0.05
-            maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)
-
-            # merge together, producing a vertical stack
-            maps_stacked = torch.reshape(
-                maps_normalized_expanded_clamped,
-                [num_tokens * latents_height, latents_width],
-            )
-
-            if merged is None:
-                merged = maps_stacked
-            else:
-                # screen blend
-                merged = 1 - (1 - maps_stacked) * (1 - merged)
-
-        if merged is None:
-            return None
-
-        merged_bytes = merged.mul(0xFF).byte()
-        return Image.fromarray(merged_bytes.numpy(), mode="L")
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index c6b85d2bd6..58ab16bae8 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -17,13 +17,11 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
 )
 
 from .cross_attention_control import (
-    Context,
     CrossAttentionType,
+    CrossAttnControlContext,
     SwapCrossAttnContext,
-    get_cross_attention_modules,
     setup_cross_attention_control_attention_processors,
 )
-from .cross_attention_map_saving import AttentionMapSaver
 
 ModelForwardCallback: TypeAlias = Union[
     # x, t, conditioning, Optional[cross-attention kwargs]
@@ -69,14 +67,12 @@ class InvokeAIDiffuserComponent:
         self,
         unet: UNet2DConditionModel,
         extra_conditioning_info: Optional[ExtraConditioningInfo],
-        step_count: int,
     ):
         old_attn_processors = unet.attn_processors
 
         try:
-            self.cross_attention_control_context = Context(
+            self.cross_attention_control_context = CrossAttnControlContext(
                 arguments=extra_conditioning_info.cross_attention_control_args,
-                step_count=step_count,
             )
             setup_cross_attention_control_attention_processors(
                 unet,
@@ -87,27 +83,6 @@ class InvokeAIDiffuserComponent:
         finally:
             self.cross_attention_control_context = None
             unet.set_attn_processor(old_attn_processors)
-            # TODO resuscitate attention map saving
-            # self.remove_attention_map_saving()
-
-    def setup_attention_map_saving(self, saver: AttentionMapSaver):
-        def callback(slice, dim, offset, slice_size, key):
-            if dim is not None:
-                # sliced tokens attention map saving is not implemented
-                return
-            saver.add_attention_maps(slice, key)
-
-        tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
-        for identifier, module in tokens_cross_attention_modules:
-            key = "down" if identifier.startswith("down") else "up" if identifier.startswith("up") else "mid"
-            module.set_attention_slice_calculated_callback(
-                lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key)
-            )
-
-    def remove_attention_map_saving(self):
-        tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
-        for _, module in tokens_cross_attention_modules:
-            module.set_attention_slice_calculated_callback(None)
 
     def do_controlnet_step(
         self,