Create a UNetAttentionPatcher for patching UNet models with CustomAttnProcessor2_0 modules.

This commit is contained in:
Ryan Dick 2024-03-08 14:15:16 -05:00 committed by Kent Keirsey
parent 31c456c1e6
commit 7ca677578e
4 changed files with 25 additions and 205 deletions

View File

@ -1,182 +0,0 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed
# tencent-ailab comment:
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
# Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
# loading.
class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
def __init__(self):
DiffusersAttnProcessor2_0.__init__(self)
nn.Module.__init__(self)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
ip_adapter_image_prompt_embeds=None,
):
"""Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the
ip_adapter_image_prompt_embeds parameter.
"""
return DiffusersAttnProcessor2_0.__call__(
self, attn, hidden_states, encoder_hidden_states, attention_mask, temb
)
class IPAttnProcessor2_0(torch.nn.Module):
r"""
Attention processor for IP-Adapater for PyTorch 2.0.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
"""
def __init__(self, weights: list[IPAttentionProcessorWeights], scales: list[float]):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
assert len(weights) == len(scales)
self._weights = weights
self._scales = scales
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
ip_adapter_image_prompt_embeds=None,
):
"""Apply IP-Adapter attention.
Args:
ip_adapter_image_prompt_embeds (torch.Tensor): The image prompt embeddings.
Shape: (batch_size, num_ip_images, seq_len, ip_embedding_len).
"""
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
assert ip_adapter_image_prompt_embeds is not None
assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
for ipa_embed, ipa_weights, scale in zip(
ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True
):
# The batch dimensions should match.
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
# The token_len dimensions should match.
assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
ip_hidden_states = ipa_embed
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
# Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
hidden_states = hidden_states + scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states

View File

@ -22,12 +22,12 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config.config_default import get_config
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
IPAdapterConditioningInfo, IPAdapterConditioningInfo,
TextConditioningData, TextConditioningData,
) )
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
from invokeai.backend.util.attention import auto_detect_slice_size from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import normalize_device from invokeai.backend.util.devices import normalize_device
@ -412,7 +412,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
elif ip_adapter_data is not None: elif ip_adapter_data is not None:
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active? # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
# As it is now, the IP-Adapter will silently be skipped. # As it is now, the IP-Adapter will silently be skipped.
ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data]) ip_adapter_unet_patcher = UNetAttentionPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model) attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
self.use_ip_adapter = True self.use_ip_adapter = True
else: else:
@ -476,7 +476,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
control_data: List[ControlNetData] = None, control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None, ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None, t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
ip_adapter_unet_patcher: Optional[UNetPatcher] = None, ip_adapter_unet_patcher: Optional[UNetAttentionPatcher] = None,
): ):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0] timestep = t[0]

View File

@ -1,52 +1,54 @@
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional
from diffusers.models import UNet2DConditionModel from diffusers.models import UNet2DConditionModel
from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
class UNetPatcher: class UNetAttentionPatcher:
"""A class that contains multiple IP-Adapters and can apply them to a UNet.""" """A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""
def __init__(self, ip_adapters: list[IPAdapter]): def __init__(self, ip_adapters: Optional[list[IPAdapter]]):
self._ip_adapters = ip_adapters self._ip_adapters = ip_adapters
self._scales = [1.0] * len(self._ip_adapters) self._ip_adapter_scales = None
if self._ip_adapters is not None:
self._ip_adapter_scales = [1.0] * len(self._ip_adapters)
def set_scale(self, idx: int, value: float): def set_scale(self, idx: int, value: float):
self._scales[idx] = value self._ip_adapter_scales[idx] = value
def _prepare_attention_processors(self, unet: UNet2DConditionModel): def _prepare_attention_processors(self, unet: UNet2DConditionModel):
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
weights into them. weights into them (if IP-Adapters are being applied).
Note that the `unet` param is only used to determine attention block dimensions and naming. Note that the `unet` param is only used to determine attention block dimensions and naming.
""" """
# Construct a dict of attention processors based on the UNet's architecture. # Construct a dict of attention processors based on the UNet's architecture.
attn_procs = {} attn_procs = {}
for idx, name in enumerate(unet.attn_processors.keys()): for idx, name in enumerate(unet.attn_processors.keys()):
if name.endswith("attn1.processor"): if name.endswith("attn1.processor") or self._ip_adapters is None:
attn_procs[name] = AttnProcessor2_0() # "attn1" processors do not use IP-Adapters.
attn_procs[name] = CustomAttnProcessor2_0()
else: else:
# Collect the weights from each IP Adapter for the idx'th attention processor. # Collect the weights from each IP Adapter for the idx'th attention processor.
attn_procs[name] = IPAttnProcessor2_0( attn_procs[name] = CustomAttnProcessor2_0(
[ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters], [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
self._scales, self._ip_adapter_scales,
) )
return attn_procs return attn_procs
@contextmanager @contextmanager
def apply_ip_adapter_attention(self, unet: UNet2DConditionModel): def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
"""A context manager that patches `unet` with IP-Adapter attention processors.""" """A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers."""
attn_procs = self._prepare_attention_processors(unet) attn_procs = self._prepare_attention_processors(unet)
orig_attn_processors = unet.attn_processors orig_attn_processors = unet.attn_processors
try: try:
# Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from
# passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy # the passed dict. So, if you wanted to keep the dict for future use, you'd have to make a
# of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`. # moderately-shallow copy of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
unet.set_attn_processor(attn_procs) unet.set_attn_processor(attn_procs)
yield None yield None
finally: finally:

View File

@ -1,8 +1,8 @@
import pytest import pytest
import torch import torch
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
from invokeai.backend.util.test_utils import install_and_load_model from invokeai.backend.util.test_utils import install_and_load_model
@ -77,7 +77,7 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device) ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device)
cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]} cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]}
ip_adapter_unet_patcher = UNetPatcher([ip_adapter]) ip_adapter_unet_patcher = UNetAttentionPatcher([ip_adapter])
with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet): with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet):
output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample