Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Delete unused code for attention map saving.
Commit a22331fbe9 (parent a98c37b7a3)
@@ -34,8 +34,8 @@ from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend.model_management.models import ModelType, SilenceWarnings
 
 from ...backend.model_management.lora import ModelPatcher
-from ...backend.model_management.seamless import set_seamless
 from ...backend.model_management.models import BaseModelType
+from ...backend.model_management.seamless import set_seamless
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.stable_diffusion.diffusers_pipeline import (
     ConditioningData,
@@ -43,7 +43,9 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
     StableDiffusionGeneratorPipeline,
     image_resized_to_grid_as_tensor,
 )
-from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
+from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
+    PostprocessingSettings,
+)
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from ...backend.util.devices import choose_precision, choose_torch_device
 from ..models.image import ImageCategory, ResourceOrigin
@@ -485,9 +487,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
             **self.unet.unet.dict(),
             context=context,
         )
-        with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
-            unet_info.context.model, _lora_loader()
-        ), set_seamless(unet_info.context.model, self.unet.seamless_axes), unet_info as unet:
+        with (
+            ExitStack() as exit_stack,
+            ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),
+            set_seamless(unet_info.context.model, self.unet.seamless_axes),
+            unet_info as unet,
+        ):
             latents = latents.to(device=unet.device, dtype=unet.dtype)
             if noise is not None:
                 noise = noise.to(device=unet.device, dtype=unet.dtype)
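Note: the rewritten block above uses the parenthesized form of the with statement, which groups several context managers in one statement (officially part of the grammar in Python 3.10; CPython's 3.9 parser already accepts it). A minimal standalone sketch of the same pattern, using stand-in context managers rather than InvokeAI APIs:

from contextlib import ExitStack, nullcontext

with (
    ExitStack() as stack,                  # entered first, exited last
    nullcontext("patched-unet") as unet,   # entered second; bindings work exactly as in the nested form
):
    print(unet)  # -> "patched-unet"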
@@ -524,7 +529,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 denoising_end=self.denoising_end,
             )
 
-            result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
+            result_latents = pipeline.latents_from_embeddings(
                 latents=latents,
                 timesteps=timesteps,
                 init_timestep=init_timestep,
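Note: the visible API change for callers in this hunk is that latents_from_embeddings now returns a single tensor instead of a (latents, attention_map_saver) pair. A hedged before/after sketch of a caller; only the keyword arguments shown in this diff are used, anything else a real call site passes is omitted here:

# before this commit: a two-element tuple was returned
result_latents, _attention_map_saver = pipeline.latents_from_embeddings(
    latents=latents,
    timesteps=timesteps,
    init_timestep=init_timestep,
)

# after this commit: a plain torch.Tensor of denoised latents
result_latents = pipeline.latents_from_embeddings(
    latents=latents,
    timesteps=timesteps,
    init_timestep=init_timestep,
)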
@@ -7,9 +7,8 @@ from .diffusers_pipeline import ( # noqa: F401
     StableDiffusionGeneratorPipeline,
 )
 from .diffusion import InvokeAIDiffuserComponent # noqa: F401
-from .diffusion.cross_attention_map_saving import AttentionMapSaver # noqa: F401
 from .diffusion.shared_invokeai_diffusion import ( # noqa: F401
-    PostprocessingSettings,
     BasicConditioningInfo,
+    PostprocessingSettings,
     SDXLConditioningInfo,
 )
@@ -5,14 +5,13 @@ import inspect
 from dataclasses import dataclass, field
 from typing import Any, Callable, List, Optional, Union
 
-import PIL.Image
 import einops
+import PIL.Image
 import psutil
 import torch
 import torchvision.transforms as T
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.models.controlnet import ControlNetModel
-from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
     StableDiffusionPipeline,
 )
@@ -27,13 +26,13 @@ from pydantic import Field
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from invokeai.app.services.config import InvokeAIAppConfig
 
+from ..util import auto_detect_slice_size, normalize_device
 from .diffusion import (
-    AttentionMapSaver,
+    BasicConditioningInfo,
     InvokeAIDiffuserComponent,
     PostprocessingSettings,
-    BasicConditioningInfo,
 )
-from ..util import normalize_device, auto_detect_slice_size
 
 
 @dataclass
@@ -44,7 +43,6 @@ class PipelineIntermediateState:
     timestep: int
     latents: torch.Tensor
     predicted_original: Optional[torch.Tensor] = None
-    attention_map_saver: Optional[AttentionMapSaver] = None
 
 
 @dataclass
@@ -103,7 +101,7 @@ class AddsMaskGuidance:
         # Mask anything that has the same shape as prev_sample, return others as-is.
         return output_class(
             {
-                k: (self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v)
+                k: self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v
                 for k, v in step_output.items()
             }
         )
@@ -205,18 +203,6 @@ class ConditioningData:
         return dataclasses.replace(self, scheduler_args=scheduler_args)
 
 
-@dataclass
-class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
-    r"""
-    Output class for InvokeAI's Stable Diffusion pipeline.
-
-    Args:
-        attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
-            after generation completes. Optional.
-    """
-    attention_map_saver: Optional[AttentionMapSaver]
-
-
 class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.
@@ -360,7 +346,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         mask: Optional[torch.Tensor] = None,
         masked_latents: Optional[torch.Tensor] = None,
         seed: Optional[int] = None,
-    ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
+    ) -> torch.Tensor:
         if init_timestep.shape[0] == 0:
             return latents, None
 
@@ -402,7 +388,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise))
 
         try:
-            latents, attention_map_saver = self.generate_latents_from_embeddings(
+            latents = self.generate_latents_from_embeddings(
                 latents,
                 timesteps,
                 conditioning_data,
@@ -417,7 +403,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if mask is not None:
             latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))
 
-        return latents, attention_map_saver
+        return latents
 
     def generate_latents_from_embeddings(
         self,
@@ -434,16 +420,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         additional_guidance = []
 
         batch_size = latents.shape[0]
-        attention_map_saver: Optional[AttentionMapSaver] = None
 
         if timesteps.shape[0] == 0:
-            return latents, attention_map_saver
+            return latents
 
         extra_conditioning_info = conditioning_data.extra
         with self.invokeai_diffuser.custom_attention_context(
             self.invokeai_diffuser.model,
             extra_conditioning_info=extra_conditioning_info,
-            step_count=len(self.scheduler.timesteps),
         ):
             if callback is not None:
                 callback(
@@ -480,13 +464,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
                 predicted_original = getattr(step_output, "pred_original_sample", None)
 
-                # TODO resuscitate attention map saving
-                # if i == len(timesteps)-1 and extra_conditioning_info is not None:
-                #     eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
-                #     attention_map_token_ids = range(1, eos_token_index)
-                #     attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
-                #     self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
-
                 if callback is not None:
                     callback(
                         PipelineIntermediateState(
@@ -496,11 +473,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                             timestep=int(t),
                             latents=latents,
                             predicted_original=predicted_original,
-                            attention_map_saver=attention_map_saver,
                         )
                     )
 
-        return latents, attention_map_saver
+        return latents
 
     @torch.inference_mode()
     def step(
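Note: PipelineIntermediateState no longer carries an attention_map_saver field, so progress callbacks should rely only on the remaining fields. A small sketch of a hypothetical callback, assuming just the fields visible in this diff (timestep, latents, predicted_original):

def on_intermediate_state(state: PipelineIntermediateState) -> None:
    # latents is the current noisy sample; predicted_original, when present,
    # is the model's x0 prediction, which previews are usually built from.
    print(f"t={state.timestep} latents={tuple(state.latents.shape)}")
    if state.predicted_original is not None:
        preview_latents = state.predicted_original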
@@ -1,11 +1,9 @@
 """
 Initialization file for invokeai.models.diffusion
 """
-from .cross_attention_control import InvokeAICrossAttentionMixin # noqa: F401
-from .cross_attention_map_saving import AttentionMapSaver # noqa: F401
 from .shared_invokeai_diffusion import ( # noqa: F401
+    BasicConditioningInfo,
     InvokeAIDiffuserComponent,
     PostprocessingSettings,
-    BasicConditioningInfo,
     SDXLConditioningInfo,
 )
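Note: with this change the diffusion package (imported elsewhere in this diff as invokeai.backend.stable_diffusion.diffusion) no longer re-exports AttentionMapSaver or InvokeAICrossAttentionMixin. A sketch of the imports that remain valid according to the __init__.py above:

from invokeai.backend.stable_diffusion.diffusion import (
    BasicConditioningInfo,
    InvokeAIDiffuserComponent,
    PostprocessingSettings,
    SDXLConditioningInfo,
)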
@@ -5,22 +5,14 @@
 import enum
 import math
 from dataclasses import dataclass, field
-from typing import Callable, Optional
+from typing import Optional
 
-import diffusers
-import psutil
 import torch
 from compel.cross_attention_control import Arguments
+from diffusers.models.attention_processor import Attention, SlicedAttnProcessor
 from diffusers.models.unet_2d_condition import UNet2DConditionModel
-from diffusers.models.attention_processor import AttentionProcessor
-from diffusers.models.attention_processor import (
-    Attention,
-    AttnProcessor,
-    SlicedAttnProcessor,
-)
 from torch import nn
 
-import invokeai.backend.util.logging as logger
 from ...util import torch_dtype
 
 
@@ -33,68 +25,14 @@ class Context:
     cross_attention_mask: Optional[torch.Tensor]
     cross_attention_index_map: Optional[torch.Tensor]
 
-    class Action(enum.Enum):
-        NONE = 0
-        SAVE = (1,)
-        APPLY = 2
-
-    def __init__(self, arguments: Arguments, step_count: int):
+    def __init__(self, arguments: Arguments):
         """
         :param arguments: Arguments for the cross-attention control process
-        :param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
         """
         self.cross_attention_mask = None
         self.cross_attention_index_map = None
-        self.self_cross_attention_action = Context.Action.NONE
-        self.tokens_cross_attention_action = Context.Action.NONE
         self.arguments = arguments
-        self.step_count = step_count
-
-        self.self_cross_attention_module_identifiers = []
-        self.tokens_cross_attention_module_identifiers = []
-
-        self.saved_cross_attention_maps = {}
-
-        self.clear_requests(cleanup=True)
-
-    def register_cross_attention_modules(self, model):
-        for name, module in get_cross_attention_modules(model, CrossAttentionType.SELF):
-            if name in self.self_cross_attention_module_identifiers:
-                assert False, f"name {name} cannot appear more than once"
-            self.self_cross_attention_module_identifiers.append(name)
-        for name, module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
-            if name in self.tokens_cross_attention_module_identifiers:
-                assert False, f"name {name} cannot appear more than once"
-            self.tokens_cross_attention_module_identifiers.append(name)
-
-    def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
-        if cross_attention_type == CrossAttentionType.SELF:
-            self.self_cross_attention_action = Context.Action.SAVE
-        else:
-            self.tokens_cross_attention_action = Context.Action.SAVE
-
-    def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType):
-        if cross_attention_type == CrossAttentionType.SELF:
-            self.self_cross_attention_action = Context.Action.APPLY
-        else:
-            self.tokens_cross_attention_action = Context.Action.APPLY
-
-    def is_tokens_cross_attention(self, module_identifier) -> bool:
-        return module_identifier in self.tokens_cross_attention_module_identifiers
-
-    def get_should_save_maps(self, module_identifier: str) -> bool:
-        if module_identifier in self.self_cross_attention_module_identifiers:
-            return self.self_cross_attention_action == Context.Action.SAVE
-        elif module_identifier in self.tokens_cross_attention_module_identifiers:
-            return self.tokens_cross_attention_action == Context.Action.SAVE
-        return False
-
-    def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
-        if module_identifier in self.self_cross_attention_module_identifiers:
-            return self.self_cross_attention_action == Context.Action.APPLY
-        elif module_identifier in self.tokens_cross_attention_module_identifiers:
-            return self.tokens_cross_attention_action == Context.Action.APPLY
-        return False
 
     def get_active_cross_attention_control_types_for_step(
         self, percent_through: float = None
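Note: with the Action enum and the saved-map bookkeeping removed, Context.__init__ now takes only the cross-attention-control arguments; the step_count parameter is gone. A one-line sketch matching the call site later in this diff:

cross_attention_control_context = Context(arguments=extra_conditioning_info.cross_attention_control_args)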
@@ -115,217 +53,6 @@ class Context:
                 to_control.append(CrossAttentionType.TOKENS)
         return to_control
 
-    def save_slice(
-        self,
-        identifier: str,
-        slice: torch.Tensor,
-        dim: Optional[int],
-        offset: int,
-        slice_size: Optional[int],
-    ):
-        if identifier not in self.saved_cross_attention_maps:
-            self.saved_cross_attention_maps[identifier] = {
-                "dim": dim,
-                "slice_size": slice_size,
-                "slices": {offset or 0: slice},
-            }
-        else:
-            self.saved_cross_attention_maps[identifier]["slices"][offset or 0] = slice
-
-    def get_slice(
-        self,
-        identifier: str,
-        requested_dim: Optional[int],
-        requested_offset: int,
-        slice_size: int,
-    ):
-        saved_attention_dict = self.saved_cross_attention_maps[identifier]
-        if requested_dim is None:
-            if saved_attention_dict["dim"] is not None:
-                raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}")
-            return saved_attention_dict["slices"][0]
-
-        if saved_attention_dict["dim"] == requested_dim:
-            if slice_size != saved_attention_dict["slice_size"]:
-                raise RuntimeError(
-                    f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}"
-                )
-            return saved_attention_dict["slices"][requested_offset]
-
-        if saved_attention_dict["dim"] is None:
-            whole_saved_attention = saved_attention_dict["slices"][0]
-            if requested_dim == 0:
-                return whole_saved_attention[requested_offset : requested_offset + slice_size]
-            elif requested_dim == 1:
-                return whole_saved_attention[:, requested_offset : requested_offset + slice_size]
-
-        raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
-
-    def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
-        saved_attention = self.saved_cross_attention_maps.get(identifier, None)
-        if saved_attention is None:
-            return None, None
-        return saved_attention["dim"], saved_attention["slice_size"]
-
-    def clear_requests(self, cleanup=True):
-        self.tokens_cross_attention_action = Context.Action.NONE
-        self.self_cross_attention_action = Context.Action.NONE
-        if cleanup:
-            self.saved_cross_attention_maps = {}
-
-    def offload_saved_attention_slices_to_cpu(self):
-        for key, map_dict in self.saved_cross_attention_maps.items():
-            for offset, slice in map_dict["slices"].items():
-                map_dict[offset] = slice.to("cpu")
-
-
-class InvokeAICrossAttentionMixin:
-    """
-    Enable InvokeAI-flavoured Attention calculation, which does aggressive low-memory slicing and calls
-    through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
-    and dymamic slicing strategy selection.
-    """
-
-    def __init__(self):
-        self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
-        self.attention_slice_wrangler = None
-        self.slicing_strategy_getter = None
-        self.attention_slice_calculated_callback = None
-
-    def set_attention_slice_wrangler(
-        self,
-        wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]],
-    ):
-        """
-        Set custom attention calculator to be called when attention is calculated
-        :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
-            which returns either the suggested_attention_slice or an adjusted equivalent.
-            `module` is the current Attention module for which the callback is being invoked.
-            `suggested_attention_slice` is the default-calculated attention slice
-            `dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
-            If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
-
-        Pass None to use the default attention calculation.
-        :return:
-        """
-        self.attention_slice_wrangler = wrangler
-
-    def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int, int]]]):
-        self.slicing_strategy_getter = getter
-
-    def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]):
-        self.attention_slice_calculated_callback = callback
-
-    def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
-        # calculate attention scores
-        # attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
-        attention_scores = torch.baddbmm(
-            torch.empty(
-                query.shape[0],
-                query.shape[1],
-                key.shape[1],
-                dtype=query.dtype,
-                device=query.device,
-            ),
-            query,
-            key.transpose(-1, -2),
-            beta=0,
-            alpha=self.scale,
-        )
-
-        # calculate attention slice by taking the best scores for each latent pixel
-        default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
-        attention_slice_wrangler = self.attention_slice_wrangler
-        if attention_slice_wrangler is not None:
-            attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
-        else:
-            attention_slice = default_attention_slice
-
-        if self.attention_slice_calculated_callback is not None:
-            self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)
-
-        hidden_states = torch.bmm(attention_slice, value)
-        return hidden_states
-
-    def einsum_op_slice_dim0(self, q, k, v, slice_size):
-        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-        for i in range(0, q.shape[0], slice_size):
-            end = i + slice_size
-            r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
-        return r
-
-    def einsum_op_slice_dim1(self, q, k, v, slice_size):
-        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-        for i in range(0, q.shape[1], slice_size):
-            end = i + slice_size
-            r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
-        return r
-
-    def einsum_op_mps_v1(self, q, k, v):
-        if q.shape[1] <= 4096:  # (512x512) max q.shape[1]: 4096
-            return self.einsum_lowest_level(q, k, v, None, None, None)
-        else:
-            slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
-            return self.einsum_op_slice_dim1(q, k, v, slice_size)
-
-    def einsum_op_mps_v2(self, q, k, v):
-        if self.mem_total_gb > 8 and q.shape[1] <= 4096:
-            return self.einsum_lowest_level(q, k, v, None, None, None)
-        else:
-            return self.einsum_op_slice_dim0(q, k, v, 1)
-
-    def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
-        size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
-        if size_mb <= max_tensor_mb:
-            return self.einsum_lowest_level(q, k, v, None, None, None)
-        div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
-        if div <= q.shape[0]:
-            return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
-        return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
-
-    def einsum_op_cuda(self, q, k, v):
-        # check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
-        slicing_strategy_getter = self.slicing_strategy_getter
-        if slicing_strategy_getter is not None:
-            (dim, slice_size) = slicing_strategy_getter(self)
-            if dim is not None:
-                # print("using saved slicing strategy with dim", dim, "slice size", slice_size)
-                if dim == 0:
-                    return self.einsum_op_slice_dim0(q, k, v, slice_size)
-                elif dim == 1:
-                    return self.einsum_op_slice_dim1(q, k, v, slice_size)
-
-        # fallback for when there is no saved strategy, or saved strategy does not slice
-        mem_free_total = get_mem_free_total(q.device)
-        # Divide factor of safety as there's copying and fragmentation
-        return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
-
-    def get_invokeai_attention_mem_efficient(self, q, k, v):
-        if q.device.type == "cuda":
-            # print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
-            return self.einsum_op_cuda(q, k, v)
-
-        if q.device.type == "mps" or q.device.type == "cpu":
-            if self.mem_total_gb >= 32:
-                return self.einsum_op_mps_v1(q, k, v)
-            return self.einsum_op_mps_v2(q, k, v)
-
-        # Smaller slices are faster due to L2/L3/SLC caches.
-        # Tested on i7 with 8MB L3 cache.
-        return self.einsum_op_tensor_mem(q, k, v, 32)
-
-
-def restore_default_cross_attention(
-    model,
-    is_running_diffusers: bool,
-    restore_attention_processor: Optional[AttentionProcessor] = None,
-):
-    if is_running_diffusers:
-        unet = model
-        unet.set_attn_processor(restore_attention_processor or AttnProcessor())
-    else:
-        remove_attention_function(model)
 
 
 def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
     """
@@ -366,136 +93,6 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
     unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
 
 
-def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
-    cross_attention_class: type = InvokeAIDiffusersCrossAttention
-    which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
-    attention_module_tuples = [
-        (name, module)
-        for name, module in model.named_modules()
-        if isinstance(module, cross_attention_class) and which_attn in name
-    ]
-    cross_attention_modules_in_model_count = len(attention_module_tuples)
-    expected_count = 16
-    if cross_attention_modules_in_model_count != expected_count:
-        # non-fatal error but .swap() won't work.
-        logger.error(
-            f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
-            + f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
-            + "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
-            + f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
-            + "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
-            + "work properly until it is fixed."
-        )
-    return attention_module_tuples
-
-
-def inject_attention_function(unet, context: Context):
-    # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
-
-    def attention_slice_wrangler(module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size):
-        # memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()
-
-        attention_slice = suggested_attention_slice
-
-        if context.get_should_save_maps(module.identifier):
-            # print(module.identifier, "saving suggested_attention_slice of shape",
-            #       suggested_attention_slice.shape, "dim", dim, "offset", offset)
-            slice_to_save = attention_slice.to("cpu") if dim is not None else attention_slice
-            context.save_slice(
-                module.identifier,
-                slice_to_save,
-                dim=dim,
-                offset=offset,
-                slice_size=slice_size,
-            )
-        elif context.get_should_apply_saved_maps(module.identifier):
-            # print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
-            saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)
-
-            # slice may have been offloaded to CPU
-            saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)
-
-            if context.is_tokens_cross_attention(module.identifier):
-                index_map = context.cross_attention_index_map
-                remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
-                this_attention_slice = suggested_attention_slice
-
-                mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device))
-                saved_mask = mask
-                this_mask = 1 - mask
-                attention_slice = remapped_saved_attention_slice * saved_mask + this_attention_slice * this_mask
-            else:
-                # just use everything
-                attention_slice = saved_attention_slice
-
-        return attention_slice
-
-    cross_attention_modules = get_cross_attention_modules(
-        unet, CrossAttentionType.TOKENS
-    ) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
-    for identifier, module in cross_attention_modules:
-        module.identifier = identifier
-        try:
-            module.set_attention_slice_wrangler(attention_slice_wrangler)
-            module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier))
-        except AttributeError as e:
-            if is_attribute_error_about(e, "set_attention_slice_wrangler"):
-                print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")  # TODO
-            else:
-                raise
-
-
-def remove_attention_function(unet):
-    cross_attention_modules = get_cross_attention_modules(
-        unet, CrossAttentionType.TOKENS
-    ) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
-    for identifier, module in cross_attention_modules:
-        try:
-            # clear wrangler callback
-            module.set_attention_slice_wrangler(None)
-            module.set_slicing_strategy_getter(None)
-        except AttributeError as e:
-            if is_attribute_error_about(e, "set_attention_slice_wrangler"):
-                print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")
-            else:
-                raise
-
-
-def is_attribute_error_about(error: AttributeError, attribute: str):
-    if hasattr(error, "name"):  # Python 3.10
-        return error.name == attribute
-    else:  # Python 3.9
-        return attribute in str(error)
-
-
-def get_mem_free_total(device):
-    # only on cuda
-    if not torch.cuda.is_available():
-        return None
-    stats = torch.cuda.memory_stats(device)
-    mem_active = stats["active_bytes.all.current"]
-    mem_reserved = stats["reserved_bytes.all.current"]
-    mem_free_cuda, _ = torch.cuda.mem_get_info(device)
-    mem_free_torch = mem_reserved - mem_active
-    mem_free_total = mem_free_cuda + mem_free_torch
-    return mem_free_total
-
-
-class InvokeAIDiffusersCrossAttention(diffusers.models.attention.Attention, InvokeAICrossAttentionMixin):
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        InvokeAICrossAttentionMixin.__init__(self)
-
-    def _attention(self, query, key, value, attention_mask=None):
-        # default_result = super()._attention(query, key, value)
-        if attention_mask is not None:
-            print(f"{type(self).__name__} ignoring passed-in attention_mask")
-        attention_result = self.get_invokeai_attention_mem_efficient(query, key, value)
-
-        hidden_states = self.reshape_batch_dim_to_heads(attention_result)
-        return hidden_states
-
-
 ## 🧨diffusers implementation follows
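Note: among the deleted helpers above, einsum_op_tensor_mem chose a slicing factor from the size of the full attention matrix. A standalone sketch of its arithmetic with illustrative numbers:

# q: [8, 4096, 40] in fp16 attending to k: [8, 4096, 40]
size_mb = 8 * 4096 * 4096 * 2 // (1 << 20)                   # 256 MB for the full score matrix
max_tensor_mb = 32
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()   # 8: next power of two covering 256 / 32
slice_size = 8 // div                                        # div <= q.shape[0], so dim 0 is sliced into chunks of 1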
@@ -1,98 +0,0 @@
-import math
-
-import PIL
-import torch
-from torchvision.transforms.functional import InterpolationMode
-from torchvision.transforms.functional import resize as tv_resize
-
-
-class AttentionMapSaver:
-    def __init__(self, token_ids: range, latents_shape: torch.Size):
-        self.token_ids = token_ids
-        self.latents_shape = latents_shape
-        # self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
-        self.collated_maps = {}
-
-    def clear_maps(self):
-        self.collated_maps = {}
-
-    def add_attention_maps(self, maps: torch.Tensor, key: str):
-        """
-        Accumulate the given attention maps and store by summing with existing maps at the passed-in key (if any).
-        :param maps: Attention maps to store. Expected shape [A, (H*W), N] where A is attention heads count, H and W are the map size (fixed per-key) and N is the number of tokens (typically 77).
-        :param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the maps sizes (H and W) should match.
-        :return: None
-        """
-        key_and_size = f"{key}_{maps.shape[1]}"
-
-        # extract desired tokens
-        maps = maps[:, :, self.token_ids]
-
-        # merge attention heads to a single map per token
-        maps = torch.sum(maps, 0)
-
-        # store
-        if key_and_size not in self.collated_maps:
-            self.collated_maps[key_and_size] = torch.zeros_like(maps, device="cpu")
-        self.collated_maps[key_and_size] += maps.cpu()
-
-    def write_maps_to_disk(self, path: str):
-        pil_image = self.get_stacked_maps_image()
-        pil_image.save(path, "PNG")
-
-    def get_stacked_maps_image(self) -> PIL.Image:
-        """
-        Scale all collected attention maps to the same size, blend them together and return as an image.
-        :return: An image containing a vertical stack of blended attention maps, one for each requested token.
-        """
-        num_tokens = len(self.token_ids)
-        if num_tokens == 0:
-            return None
-
-        latents_height = self.latents_shape[0]
-        latents_width = self.latents_shape[1]
-
-        merged = None
-
-        for key, maps in self.collated_maps.items():
-            # maps has shape [(H*W), N] for N tokens
-            # but we want [N, H, W]
-            this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
-            this_maps_height = int(float(latents_height) * this_scale_factor)
-            this_maps_width = int(float(latents_width) * this_scale_factor)
-            # and we need to do some dimension juggling
-            maps = torch.reshape(
-                torch.swapdims(maps, 0, 1),
-                [num_tokens, this_maps_height, this_maps_width],
-            )
-
-            # scale to output size if necessary
-            if this_scale_factor != 1:
-                maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)
-
-            # normalize
-            maps_min = torch.min(maps)
-            maps_range = torch.max(maps) - maps_min
-            # print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
-            maps_normalized = (maps - maps_min) / maps_range
-            # expand to (-0.1, 1.1) and clamp
-            maps_normalized_expanded = maps_normalized * 1.1 - 0.05
-            maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)
-
-            # merge together, producing a vertical stack
-            maps_stacked = torch.reshape(
-                maps_normalized_expanded_clamped,
-                [num_tokens * latents_height, latents_width],
-            )
-
-            if merged is None:
-                merged = maps_stacked
-            else:
-                # screen blend
-                merged = 1 - (1 - maps_stacked) * (1 - merged)
-
-        if merged is None:
-            return None
-
-        merged_bytes = merged.mul(0xFF).byte()
-        return PIL.Image.fromarray(merged_bytes.numpy(), mode="L")
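Note: the deleted get_stacked_maps_image combined per-token maps with a screen blend, merged = 1 - (1 - a) * (1 - b), which keeps a pixel bright if it is bright in either map. A tiny self-contained sketch of that operation:

import torch

a = torch.tensor([0.0, 0.5, 1.0])
b = torch.tensor([0.5, 0.5, 0.5])
merged = 1 - (1 - a) * (1 - b)  # tensor([0.5000, 0.7500, 1.0000])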
@@ -1,8 +1,8 @@
 from __future__ import annotations
 
+import math
 from contextlib import contextmanager
 from dataclasses import dataclass
-import math
 from typing import Any, Callable, Optional, Union
 
 import torch
@@ -14,12 +14,9 @@ from invokeai.app.services.config import InvokeAIAppConfig
 from .cross_attention_control import (
     Arguments,
     Context,
-    CrossAttentionType,
     SwapCrossAttnContext,
-    get_cross_attention_modules,
     setup_cross_attention_control_attention_processors,
 )
-from .cross_attention_map_saving import AttentionMapSaver
 
 ModelForwardCallback: TypeAlias = Union[
     # x, t, conditioning, Optional[cross-attention kwargs]
@@ -105,7 +102,6 @@ class InvokeAIDiffuserComponent:
         self,
         unet: UNet2DConditionModel,  # note: also may futz with the text encoder depending on requested LoRAs
         extra_conditioning_info: Optional[ExtraConditioningInfo],
-        step_count: int,
     ):
         old_attn_processors = None
         if extra_conditioning_info and (extra_conditioning_info.wants_cross_attention_control):
@@ -114,7 +110,6 @@ class InvokeAIDiffuserComponent:
             if extra_conditioning_info.wants_cross_attention_control:
                 self.cross_attention_control_context = Context(
                     arguments=extra_conditioning_info.cross_attention_control_args,
-                    step_count=step_count,
                 )
                 setup_cross_attention_control_attention_processors(
                     unet,
@@ -127,27 +122,6 @@ class InvokeAIDiffuserComponent:
             self.cross_attention_control_context = None
             if old_attn_processors is not None:
                 unet.set_attn_processor(old_attn_processors)
-            # TODO resuscitate attention map saving
-            # self.remove_attention_map_saving()
-
-    def setup_attention_map_saving(self, saver: AttentionMapSaver):
-        def callback(slice, dim, offset, slice_size, key):
-            if dim is not None:
-                # sliced tokens attention map saving is not implemented
-                return
-            saver.add_attention_maps(slice, key)
-
-        tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
-        for identifier, module in tokens_cross_attention_modules:
-            key = "down" if identifier.startswith("down") else "up" if identifier.startswith("up") else "mid"
-            module.set_attention_slice_calculated_callback(
-                lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key)
-            )
-
-    def remove_attention_map_saving(self):
-        tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
-        for _, module in tokens_cross_attention_modules:
-            module.set_attention_slice_calculated_callback(None)
 
     def do_controlnet_step(
         self,
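Note: since both Context and custom_attention_context lose their step_count parameter in this commit, callers enter the attention-control context with only the patched model and the extra conditioning info, as the pipeline hunk earlier in this diff already shows. A condensed sketch:

with self.invokeai_diffuser.custom_attention_context(
    self.invokeai_diffuser.model,
    extra_conditioning_info=extra_conditioning_info,
):
    # run the denoising loop; per the hunk above, any previously installed
    # attention processors are put back when the context exits
    ...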