mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Switch to using torch 2.0 attention for IP-Adapter (more memory-efficient).
This commit is contained in:
parent
382e2139bd
commit
b05b8ef677
@ -10,28 +10,8 @@ from diffusers.models.attention_processor import AttnProcessor as DiffusersAttnP
|
|||||||
from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0
|
from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0
|
||||||
|
|
||||||
|
|
||||||
# Create versions of AttnProcessor and AttnProcessor2_0 that are sub-classes of nn.Module. This is required for
|
# Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
|
||||||
# IP-Adapter state_dict loading.
|
# loading.
|
||||||
class AttnProcessor(DiffusersAttnProcessor, nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
DiffusersAttnProcessor.__init__(self)
|
|
||||||
nn.Module.__init__(self)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
attn,
|
|
||||||
hidden_states,
|
|
||||||
encoder_hidden_states=None,
|
|
||||||
attention_mask=None,
|
|
||||||
temb=None,
|
|
||||||
ip_adapter_image_prompt_embeds=None,
|
|
||||||
):
|
|
||||||
"""Re-definition of DiffusersAttnProcessor.__call__(...) that accepts and ignores the
|
|
||||||
ip_adapter_image_prompt_embeds parameter.
|
|
||||||
"""
|
|
||||||
return DiffusersAttnProcessor.__call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, temb)
|
|
||||||
|
|
||||||
|
|
||||||
class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
|
class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
DiffusersAttnProcessor2_0.__init__(self)
|
DiffusersAttnProcessor2_0.__init__(self)
|
||||||
@ -54,113 +34,6 @@ class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class IPAttnProcessor(nn.Module):
|
|
||||||
r"""
|
|
||||||
Attention processor for IP-Adapater.
|
|
||||||
Args:
|
|
||||||
hidden_size (`int`):
|
|
||||||
The hidden size of the attention layer.
|
|
||||||
cross_attention_dim (`int`):
|
|
||||||
The number of channels in the `encoder_hidden_states`.
|
|
||||||
scale (`float`, defaults to 1.0):
|
|
||||||
the weight scale of image prompt.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.cross_attention_dim = cross_attention_dim
|
|
||||||
self.scale = scale
|
|
||||||
|
|
||||||
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
|
||||||
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
attn,
|
|
||||||
hidden_states,
|
|
||||||
encoder_hidden_states=None,
|
|
||||||
attention_mask=None,
|
|
||||||
temb=None,
|
|
||||||
ip_adapter_image_prompt_embeds=None,
|
|
||||||
):
|
|
||||||
if encoder_hidden_states is not None:
|
|
||||||
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
|
|
||||||
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
|
|
||||||
assert ip_adapter_image_prompt_embeds is not None
|
|
||||||
# The batch dimensions should match.
|
|
||||||
assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
|
|
||||||
# The channel dimensions should match.
|
|
||||||
assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
|
|
||||||
ip_hidden_states = ip_adapter_image_prompt_embeds
|
|
||||||
|
|
||||||
residual = hidden_states
|
|
||||||
|
|
||||||
if attn.spatial_norm is not None:
|
|
||||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
|
||||||
|
|
||||||
input_ndim = hidden_states.ndim
|
|
||||||
|
|
||||||
if input_ndim == 4:
|
|
||||||
batch_size, channel, height, width = hidden_states.shape
|
|
||||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
|
||||||
|
|
||||||
batch_size, sequence_length, _ = (
|
|
||||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
|
||||||
)
|
|
||||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
|
||||||
|
|
||||||
if attn.group_norm is not None:
|
|
||||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
|
||||||
|
|
||||||
query = attn.to_q(hidden_states)
|
|
||||||
|
|
||||||
if encoder_hidden_states is None:
|
|
||||||
encoder_hidden_states = hidden_states
|
|
||||||
elif attn.norm_cross:
|
|
||||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
|
||||||
|
|
||||||
key = attn.to_k(encoder_hidden_states)
|
|
||||||
value = attn.to_v(encoder_hidden_states)
|
|
||||||
|
|
||||||
query = attn.head_to_batch_dim(query)
|
|
||||||
key = attn.head_to_batch_dim(key)
|
|
||||||
value = attn.head_to_batch_dim(value)
|
|
||||||
|
|
||||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
|
||||||
hidden_states = torch.bmm(attention_probs, value)
|
|
||||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
|
||||||
|
|
||||||
if ip_hidden_states is not None:
|
|
||||||
ip_key = self.to_k_ip(ip_hidden_states)
|
|
||||||
ip_value = self.to_v_ip(ip_hidden_states)
|
|
||||||
|
|
||||||
ip_key = attn.head_to_batch_dim(ip_key)
|
|
||||||
ip_value = attn.head_to_batch_dim(ip_value)
|
|
||||||
|
|
||||||
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
|
||||||
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
|
||||||
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
|
||||||
|
|
||||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
|
||||||
|
|
||||||
# linear proj
|
|
||||||
hidden_states = attn.to_out[0](hidden_states)
|
|
||||||
# dropout
|
|
||||||
hidden_states = attn.to_out[1](hidden_states)
|
|
||||||
|
|
||||||
if input_ndim == 4:
|
|
||||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
|
||||||
|
|
||||||
if attn.residual_connection:
|
|
||||||
hidden_states = hidden_states + residual
|
|
||||||
|
|
||||||
hidden_states = hidden_states / attn.rescale_output_factor
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class IPAttnProcessor2_0(torch.nn.Module):
|
class IPAttnProcessor2_0(torch.nn.Module):
|
||||||
r"""
|
r"""
|
||||||
Attention processor for IP-Adapater for PyTorch 2.0.
|
Attention processor for IP-Adapater for PyTorch 2.0.
|
||||||
@ -256,7 +129,7 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
|||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
hidden_states = hidden_states.to(query.dtype)
|
hidden_states = hidden_states.to(query.dtype)
|
||||||
|
|
||||||
if ip_hidden_states:
|
if ip_hidden_states is not None:
|
||||||
ip_key = self.to_k_ip(ip_hidden_states)
|
ip_key = self.to_k_ip(ip_hidden_states)
|
||||||
ip_value = self.to_v_ip(ip_hidden_states)
|
ip_value = self.to_v_ip(ip_hidden_states)
|
||||||
|
|
||||||
|
@ -6,18 +6,10 @@ from typing import Optional, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers.models import UNet2DConditionModel
|
from diffusers.models import UNet2DConditionModel
|
||||||
|
|
||||||
# FIXME: Getting errors when trying to use PyTorch 2.0 versions of IPAttnProcessor and AttnProcessor
|
|
||||||
# so for now falling back to the default versions
|
|
||||||
# from .utils import is_torch2_available
|
|
||||||
# if is_torch2_available:
|
|
||||||
# from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
|
|
||||||
# else:
|
|
||||||
# from .attention_processor import IPAttnProcessor, AttnProcessor
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||||
|
|
||||||
from .attention_processor import AttnProcessor, IPAttnProcessor
|
from .attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
|
||||||
from .resampler import Resampler
|
from .resampler import Resampler
|
||||||
|
|
||||||
|
|
||||||
@ -118,9 +110,9 @@ class IPAdapter:
|
|||||||
block_id = int(name[len("down_blocks.")])
|
block_id = int(name[len("down_blocks.")])
|
||||||
hidden_size = unet.config.block_out_channels[block_id]
|
hidden_size = unet.config.block_out_channels[block_id]
|
||||||
if cross_attention_dim is None:
|
if cross_attention_dim is None:
|
||||||
attn_procs[name] = AttnProcessor()
|
attn_procs[name] = AttnProcessor2_0()
|
||||||
else:
|
else:
|
||||||
attn_procs[name] = IPAttnProcessor(
|
attn_procs[name] = IPAttnProcessor2_0(
|
||||||
hidden_size=hidden_size,
|
hidden_size=hidden_size,
|
||||||
cross_attention_dim=cross_attention_dim,
|
cross_attention_dim=cross_attention_dim,
|
||||||
scale=1.0,
|
scale=1.0,
|
||||||
@ -138,7 +130,7 @@ class IPAdapter:
|
|||||||
def set_scale(self, scale):
|
def set_scale(self, scale):
|
||||||
if self._attn_processors is not None:
|
if self._attn_processors is not None:
|
||||||
for attn_processor in self._attn_processors.values():
|
for attn_processor in self._attn_processors.values():
|
||||||
if isinstance(attn_processor, IPAttnProcessor):
|
if isinstance(attn_processor, IPAttnProcessor2_0):
|
||||||
attn_processor.scale = scale
|
attn_processor.scale = scale
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
@ -156,7 +148,7 @@ class IPAdapter:
|
|||||||
# Set scale
|
# Set scale
|
||||||
self.set_scale(scale)
|
self.set_scale(scale)
|
||||||
# for attn_processor in self._attn_processors.values():
|
# for attn_processor in self._attn_processors.values():
|
||||||
# if isinstance(attn_processor, IPAttnProcessor):
|
# if isinstance(attn_processor, IPAttnProcessor2_0):
|
||||||
# attn_processor.scale = scale
|
# attn_processor.scale = scale
|
||||||
|
|
||||||
orig_attn_processors = unet.attn_processors
|
orig_attn_processors = unet.attn_processors
|
||||||
|
Loading…
x
Reference in New Issue
Block a user