Update IP-Adapter model to enable running multiple IP-Adapters at once. (Not tested yet.)

Ryan Dick 2023-10-06 11:05:25 -04:00 committed by Kent Keirsey
parent 78828b6b9c
commit 7ca456d674
6 changed files with 165 additions and 153 deletions

File: invokeai/backend/ip_adapter/attention_processor.py

@@ -8,6 +8,8 @@ import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
# Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
# loading.
@@ -45,18 +47,13 @@ class IPAttnProcessor2_0(torch.nn.Module):
the weight scale of image prompt.
"""
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
def __init__(self, weights: list[IPAttentionProcessorWeights]):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.scale = scale
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.weights = weights
def __call__(
self,
@@ -67,16 +64,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
temb=None,
ip_adapter_image_prompt_embeds=None,
):
if encoder_hidden_states is not None:
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
assert ip_adapter_image_prompt_embeds is not None
# The batch dimensions should match.
assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
# The channel dimensions should match.
assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
ip_hidden_states = ip_adapter_image_prompt_embeds
residual = hidden_states
if attn.spatial_norm is not None:
@@ -128,23 +115,36 @@ class IPAttnProcessor2_0(torch.nn.Module):
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if ip_hidden_states is not None:
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
if encoder_hidden_states is not None:
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
assert ip_adapter_image_prompt_embeds is not None
assert len(ip_adapter_image_prompt_embeds) == len(self.weights)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
for ipa_embed, ipa_weights in zip(ip_adapter_image_prompt_embeds, self.weights):
# The batch dimensions should match.
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
# The channel dimensions should match.
assert ipa_embed.shape[2] == encoder_hidden_states.shape[2]
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_hidden_states = ipa_embed
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
hidden_states = hidden_states + self.scale * ip_hidden_states
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# The output of sdpa has shape: (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
hidden_states = hidden_states + ipa_weights.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
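
The loop above gives every IP-Adapter its own key/value projections and scale, and each adapter's attention output is added onto the text cross-attention result independently. A minimal standalone sketch of that accumulation pattern follows; the IPWeights class and all tensor shapes are illustrative stand-ins, not the real InvokeAI types:

import torch
import torch.nn as nn
import torch.nn.functional as F


class IPWeights(nn.Module):
    """Stand-in for IPAttentionProcessorWeights: one K/V projection pair per adapter."""

    def __init__(self, dim: int, scale: float):
        super().__init__()
        self.scale = scale
        self.to_k_ip = nn.Linear(dim, dim, bias=False)
        self.to_v_ip = nn.Linear(dim, dim, bias=False)


def blend(query, text_attn_out, ip_embeds, weights, heads):
    """Add each adapter's scaled attention output onto the text attention output."""
    batch, _, _, head_dim = query.shape
    out = text_attn_out
    for embed, w in zip(ip_embeds, weights):
        # Project this adapter's image-prompt embedding to per-head keys/values.
        k = w.to_k_ip(embed).view(batch, -1, heads, head_dim).transpose(1, 2)
        v = w.to_v_ip(embed).view(batch, -1, heads, head_dim).transpose(1, 2)
        ip = F.scaled_dot_product_attention(query, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
        ip = ip.transpose(1, 2).reshape(batch, -1, heads * head_dim)
        out = out + w.scale * ip  # adapters compose additively, each with its own scale
    return out


# Two adapters, 8 heads of 40 channels (320 total), 64 query tokens, 4 image tokens each.
q = torch.randn(2, 8, 64, 40)
txt = torch.randn(2, 64, 320)
embeds = [torch.randn(2, 4, 320), torch.randn(2, 4, 320)]
ws = [IPWeights(320, 1.0), IPWeights(320, 0.5)]
print(blend(q, txt, embeds, ws, heads=8).shape)  # torch.Size([2, 64, 320])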

File: invokeai/backend/ip_adapter/ip_adapter.py

@@ -1,17 +1,15 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed
from contextlib import contextmanager
from typing import Optional, Union
import torch
from diffusers.models import UNet2DConditionModel
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
from invokeai.backend.model_management.models.base import calc_model_size_by_data
from .attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
from .resampler import Resampler
@@ -61,7 +59,7 @@ class IPAdapter:
def __init__(
self,
state_dict: dict[torch.Tensor],
state_dict: dict[str, torch.Tensor],
device: torch.device,
dtype: torch.dtype = torch.float16,
num_tokens: int = 4,
@@ -73,12 +71,11 @@ class IPAdapter:
self._clip_image_processor = CLIPImageProcessor()
self._state_dict = state_dict
self._image_proj_model = self._init_image_proj_model(state_dict["image_proj"])
self._image_proj_model = self._init_image_proj_model(self._state_dict["image_proj"])
# The _attn_processors will be initialized later when we have access to the UNet.
self._attn_processors = None
self.attn_weights = IPAttentionWeights.from_state_dict(state_dict["ip_adapter"]).to(
self.device, dtype=self.dtype
)
def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
self.device = device
@@ -86,99 +83,14 @@ class IPAdapter:
self.dtype = dtype
self._image_proj_model.to(device=self.device, dtype=self.dtype)
if self._attn_processors is not None:
torch.nn.ModuleList(self._attn_processors.values()).to(device=self.device, dtype=self.dtype)
self.attn_weights.to(device=self.device, dtype=self.dtype)
def calc_size(self):
if self._state_dict is not None:
image_proj_size = sum(
[tensor.nelement() * tensor.element_size() for tensor in self._state_dict["image_proj"].values()]
)
ip_adapter_size = sum(
[tensor.nelement() * tensor.element_size() for tensor in self._state_dict["ip_adapter"].values()]
)
return image_proj_size + ip_adapter_size
else:
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(
torch.nn.ModuleList(self._attn_processors.values())
)
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
def _init_image_proj_model(self, state_dict):
return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
"""Prepare a dict of attention processors that can later be injected into a unet, and load the IP-Adapter
attention weights into them.
Note that the `unet` param is only used to determine attention block dimensions and naming.
TODO(ryand): As a future improvement, this could all be inferred from the state_dict when the IPAdapter is
initialized.
"""
attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor2_0()
else:
attn_procs[name] = IPAttnProcessor2_0(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
).to(self.device, dtype=self.dtype)
ip_layers = torch.nn.ModuleList(attn_procs.values())
ip_layers.load_state_dict(self._state_dict["ip_adapter"])
self._attn_processors = attn_procs
self._state_dict = None
# @genomancer: pushed scaling back out into its own method (like original Tencent implementation)
# which makes implementing begin_step_percent and end_step_percent easier
# but based on self._attn_processors (ala @Ryan) instead of original Tencent unet.attn_processors,
# which should make it easier to implement multiple IPAdapters
def set_scale(self, scale):
if self._attn_processors is not None:
for attn_processor in self._attn_processors.values():
if isinstance(attn_processor, IPAttnProcessor2_0):
attn_processor.scale = scale
@contextmanager
def apply_ip_adapter_attention(self, unet: UNet2DConditionModel, scale: float):
"""A context manager that patches `unet` with this IP-Adapter's attention processors while it is active.
Yields:
None
"""
if self._attn_processors is None:
# We only have to call _prepare_attention_processors(...) once, and then the result is cached and can be
# used on any UNet model (with the same dimensions).
self._prepare_attention_processors(unet)
# Set scale
self.set_scale(scale)
# for attn_processor in self._attn_processors.values():
# if isinstance(attn_processor, IPAttnProcessor2_0):
# attn_processor.scale = scale
orig_attn_processors = unet.attn_processors
# Make a (moderately-) shallow copy of the self._attn_processors dict, because unet.set_attn_processor(...)
# actually pops elements from the passed dict.
ip_adapter_attn_processors = {k: v for k, v in self._attn_processors.items()}
try:
unet.set_attn_processor(ip_adapter_attn_processors)
yield None
finally:
unet.set_attn_processor(orig_attn_processors)
@torch.inference_mode()
def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
if isinstance(pil_image, Image.Image):
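
For context on calc_size above: while the state_dict is retained, the model size is measured directly from it by summing nelement() * element_size() over every tensor. A self-contained sketch of that accounting, with made-up tensor shapes:

import torch

# calc_size-style accounting: model bytes = sum of nelement() * element_size()
# over every tensor in the retained state_dict (shapes here are made up).
state_dict = {
    "image_proj": {"proj.weight": torch.zeros(768, 1024, dtype=torch.float16)},
    "ip_adapter": {"1.to_k_ip.weight": torch.zeros(320, 768, dtype=torch.float16)},
}
total_bytes = sum(
    t.nelement() * t.element_size()
    for group in state_dict.values()
    for t in group.values()
)
print(total_bytes)  # 768*1024*2 + 320*768*2 = 2064384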

File: invokeai/backend/ip_adapter/ip_attention_weights.py (new file)

@@ -0,0 +1,46 @@
import torch
class IPAttentionProcessorWeights(torch.nn.Module):
"""The IP-Adapter weights for a single attention processor.
This class is a torch.nn.Module sub-class to facilitate loading from a state_dict. It does not have a forward(...)
method.
"""
def __init__(self, in_dim: int, out_dim: int, scale: float = 1.0):
super().__init__()
self.scale = scale
self.to_k_ip = torch.nn.Linear(in_dim, out_dim, bias=False)
self.to_v_ip = torch.nn.Linear(in_dim, out_dim, bias=False)
class IPAttentionWeights(torch.nn.Module):
"""A collection of all the `IPAttentionProcessorWeights` objects for an IP-Adapter model.
This class is a torch.nn.Module sub-class so that it inherits the `.to(...)` functionality. It does not have a
forward(...) method.
"""
# Storing the weights in an nn.ModuleDict (rather than a plain dict) registers them as
# sub-modules, so .to(...) and state_dict(...) see them.
def __init__(self, weights: torch.nn.ModuleDict):
super().__init__()
self.weights = weights
def set_scale(self, scale: float):
"""Set the scale (a.k.a. 'weight') for all of the `IPAttentionProcessorWeights` in this collection."""
for w in self.weights.values():
w.scale = scale
@classmethod
def from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
attn_proc_weights: dict[str, IPAttentionProcessorWeights] = {}
for tensor_name, tensor in state_dict.items():
if "to_k_ip.weight" in tensor_name:
# nn.ModuleDict requires string keys, so the block index is kept as a string.
index = tensor_name.split(".")[0]
attn_proc_weights[index] = IPAttentionProcessorWeights(tensor.shape[1], tensor.shape[0])
attn_proc_weights_module_dict = torch.nn.ModuleDict(attn_proc_weights)
attn_proc_weights_module_dict.load_state_dict(state_dict)
return cls(attn_proc_weights_module_dict)
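
A usage sketch for from_state_dict, assuming the IPAttentionWeights class above is in scope. The '<block index>.to_k_ip.weight' / '<block index>.to_v_ip.weight' key naming is what the parser expects; the shapes are illustrative:

import torch

# Hypothetical two-tensor "ip_adapter" state_dict for a single attention block
# (index 1); real checkpoints hold one to_k_ip/to_v_ip pair per processor.
sd = {
    "1.to_k_ip.weight": torch.randn(320, 768),  # (out_dim, in_dim)
    "1.to_v_ip.weight": torch.randn(320, 768),
}
weights = IPAttentionWeights.from_state_dict(sd)
weights.set_scale(0.75)  # applies to every IPAttentionProcessorWeights in the collection
block_weights = weights.weights["1"]  # nn.ModuleDict keys are strings
print(block_weights.to_k_ip.weight.shape)  # torch.Size([320, 768])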

File: invokeai/backend/ip_adapter/unet_patcher.py (new file)

@@ -0,0 +1,51 @@
from contextlib import contextmanager
from diffusers.models import UNet2DConditionModel
from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
def _prepare_attention_processors(unet: UNet2DConditionModel, ip_adapters: list[IPAdapter]):
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
weights into them.
Note that the `unet` param is only used to determine attention block dimensions and naming.
"""
# TODO(ryand): This logic can be simplified.
# Construct a dict of attention processors based on the UNet's architecture.
attn_procs = {}
for idx, name in enumerate(unet.attn_processors.keys()):
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor2_0()
else:
# Collect the weights from each IP Adapter for the idx'th attention processor.
attn_procs[name] = IPAttnProcessor2_0([ip_adapter.attn_weights.weights[str(idx)] for ip_adapter in ip_adapters])
@contextmanager
def apply_ip_adapter_attention(unet: UNet2DConditionModel, ip_adapters: list[IPAdapter]):
"""A context manager that patches `unet` with IP-Adapter attention processors."""
attn_procs = _prepare_attention_processors(unet, ip_adapters)
orig_attn_processors = unet.attn_processors
try:
# Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the
# passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy
# of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
unet.set_attn_processor(attn_procs)
yield None
finally:
unet.set_attn_processor(orig_attn_processors)
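
A hypothetical usage of the patcher, where unet, ip_adapter_a, ip_adapter_b, latents, timestep, text_embeds, embeds_a, and embeds_b are assumed to be an already-loaded UNet2DConditionModel, two dimension-compatible IPAdapter instances, and appropriately shaped tensors:

with apply_ip_adapter_attention(unet, [ip_adapter_a, ip_adapter_b]):
    # Each cross-attention block now runs IPAttnProcessor2_0 with one weight set
    # per adapter; self-attention blocks keep the plain AttnProcessor2_0.
    noise_pred = unet(
        latents,
        timestep,
        encoder_hidden_states=text_embeds,
        # One image-prompt embedding tensor per adapter, in the same order.
        cross_attention_kwargs={"ip_adapter_image_prompt_embeds": [embeds_a, embeds_b]},
    ).sample
# On exit the original attention processors are restored, even if an error occurred.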

File: invokeai/backend/stable_diffusion/diffusers_pipeline.py

@@ -24,6 +24,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.ip_adapter.unet_patcher import apply_ip_adapter_attention
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
from ..util import auto_detect_slice_size, normalize_device
@@ -434,10 +435,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
elif ip_adapter_data is not None:
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
# As it is now, the IP-Adapter will silently be skipped.
weight = ip_adapter_data.weight[0] if isinstance(ip_adapter_data.weight, List) else ip_adapter_data.weight
attn_ctx = ip_adapter_data.ip_adapter_model.apply_ip_adapter_attention(
unet=self.invokeai_diffuser.model,
scale=weight,
attn_ctx = apply_ip_adapter_attention(
unet=self.invokeai_diffuser.model, ip_adapters=[ipa.ip_adapter_model for ipa in ip_adapter_data]
)
self.use_ip_adapter = True
else:
@@ -513,7 +512,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count: int,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[IPAdapterData] = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
@@ -527,20 +526,20 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# handle IP-Adapter
if self.use_ip_adapter and ip_adapter_data is not None: # somewhat redundant but logic is clearer
first_adapter_step = math.floor(ip_adapter_data.begin_step_percent * total_step_count)
last_adapter_step = math.ceil(ip_adapter_data.end_step_percent * total_step_count)
weight = (
ip_adapter_data.weight[step_index]
if isinstance(ip_adapter_data.weight, List)
else ip_adapter_data.weight
)
if step_index >= first_adapter_step and step_index <= last_adapter_step:
# only apply IP-Adapter if current step is within the IP-Adapter's begin/end step range
# ip_adapter_data.ip_adapter_model.set_scale(ip_adapter_data.weight)
ip_adapter_data.ip_adapter_model.set_scale(weight)
else:
# otherwise, set IP-Adapter scale to 0, so it has no effect
ip_adapter_data.ip_adapter_model.set_scale(0.0)
for single_ip_adapter_data in ip_adapter_data:
first_adapter_step = math.floor(single_ip_adapter_data.begin_step_percent * total_step_count)
last_adapter_step = math.ceil(single_ip_adapter_data.end_step_percent * total_step_count)
weight = (
single_ip_adapter_data.weight[step_index]
if isinstance(single_ip_adapter_data.weight, List)
else single_ip_adapter_data.weight
)
if step_index >= first_adapter_step and step_index <= last_adapter_step:
# Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
single_ip_adapter_data.ip_adapter_model.attn_weights.set_scale(weight)
else:
# Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
single_ip_adapter_data.ip_adapter_model.attn_weights.set_scale(0.0)
# Handle ControlNet(s) and T2I-Adapter(s)
down_block_additional_residuals = None
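
The per-adapter gating above boils down to a little arithmetic per step: an adapter's weight applies only while step_index lies in [floor(begin_step_percent * total_step_count), ceil(end_step_percent * total_step_count)], and is forced to 0.0 outside that range. A standalone sketch:

import math

# An adapter is active only while step_index lies in
# [floor(begin_step_percent * steps), ceil(end_step_percent * steps)].
def scale_for_step(step_index, total_step_count, begin_pct, end_pct, weight):
    first_step = math.floor(begin_pct * total_step_count)
    last_step = math.ceil(end_pct * total_step_count)
    w = weight[step_index] if isinstance(weight, list) else weight
    return w if first_step <= step_index <= last_step else 0.0

print(scale_for_step(5, 30, 0.0, 0.5, 0.8))   # 0.8  (step 5 is inside [0, 15])
print(scale_for_step(20, 30, 0.0, 0.5, 0.8))  # 0.0  (step 20 is past ceil(0.5 * 30) = 15)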

File: invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py

@@ -346,12 +346,10 @@ class InvokeAIDiffuserComponent:
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": torch.cat(
[
conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds,
conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds,
]
)
"ip_adapter_image_prompt_embeds": [
torch.cat([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
}
added_cond_kwargs = None
@@ -418,7 +416,10 @@ class InvokeAIDiffuserComponent:
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds
"ip_adapter_image_prompt_embeds": [
ipa_conditioning.uncond_image_prompt_embeds
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
}
added_cond_kwargs = None
@@ -444,7 +445,10 @@ class InvokeAIDiffuserComponent:
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds
"ip_adapter_image_prompt_embeds": [
ipa_conditioning.cond_image_prompt_embeds
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
}
added_cond_kwargs = None
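
The net effect of the three hunks above is that cross_attention_kwargs now carries a list with one image-prompt tensor per IP-Adapter, rather than a single concatenated tensor, so the attention processor can zip it against its per-adapter weights. A shape-level sketch with illustrative sizes:

import torch

# One (uncond, cond) embedding pair per adapter; shapes are illustrative
# (batch=1, 4 image tokens, 768 channels).
uncond_a, cond_a = torch.zeros(1, 4, 768), torch.ones(1, 4, 768)
uncond_b, cond_b = torch.zeros(1, 4, 768), torch.ones(1, 4, 768)
ip_adapter_conditioning = [(uncond_a, cond_a), (uncond_b, cond_b)]

cross_attention_kwargs = {
    "ip_adapter_image_prompt_embeds": [
        # The batch dim doubles to line up with the uncond+cond CFG batch.
        torch.cat([uncond, cond])
        for uncond, cond in ip_adapter_conditioning
    ]
}
print([t.shape for t in cross_attention_kwargs["ip_adapter_image_prompt_embeds"]])
# [torch.Size([2, 4, 768]), torch.Size([2, 4, 768])]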