Update IP-Adapter model to enable running multiple IP-Adapters at once. (Not tested yet.)

Ryan Dick 2023-10-06 11:05:25 -04:00 committed by Kent Keirsey
parent 78828b6b9c
commit 7ca456d674
6 changed files with 165 additions and 153 deletions

View File: invokeai/backend/ip_adapter/attention_processor.py

@@ -8,6 +8,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0
 
+from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
+
 # Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
 # loading.
@@ -45,18 +47,13 @@ class IPAttnProcessor2_0(torch.nn.Module):
            the weight scale of image prompt.
    """
 
-    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
+    def __init__(self, weights: list[IPAttentionProcessorWeights]):
         super().__init__()
 
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-        self.scale = scale
-
-        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        self.weights = weights
 
     def __call__(
         self,
@@ -67,16 +64,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
         temb=None,
         ip_adapter_image_prompt_embeds=None,
     ):
-        if encoder_hidden_states is not None:
-            # If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
-            # we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
-            assert ip_adapter_image_prompt_embeds is not None
-            # The batch dimensions should match.
-            assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
-            # The channel dimensions should match.
-            assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
-            ip_hidden_states = ip_adapter_image_prompt_embeds
-
         residual = hidden_states
 
         if attn.spatial_norm is not None:
@@ -128,23 +115,36 @@ class IPAttnProcessor2_0(torch.nn.Module):
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
-        if ip_hidden_states is not None:
-            ip_key = self.to_k_ip(ip_hidden_states)
-            ip_value = self.to_v_ip(ip_hidden_states)
-
-            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-            # the output of sdp = (batch, num_heads, seq_len, head_dim)
-            # TODO: add support for attn.scale when we move to Torch 2.1
-            ip_hidden_states = F.scaled_dot_product_attention(
-                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-            )
-
-            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-            ip_hidden_states = ip_hidden_states.to(query.dtype)
-
-            hidden_states = hidden_states + self.scale * ip_hidden_states
+        if encoder_hidden_states is not None:
+            # If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
+            # we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
+            assert ip_adapter_image_prompt_embeds is not None
+            assert len(ip_adapter_image_prompt_embeds) == len(self.weights)
+
+            for ipa_embed, ipa_weights in zip(ip_adapter_image_prompt_embeds, self.weights):
+                # The batch dimensions should match.
+                assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
+                # The channel dimensions should match.
+                assert ipa_embed.shape[2] == encoder_hidden_states.shape[2]
+
+                ip_hidden_states = ipa_embed
+
+                ip_key = ipa_weights.to_k_ip(ip_hidden_states)
+                ip_value = ipa_weights.to_v_ip(ip_hidden_states)
+
+                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                # The output of sdpa has shape: (batch, num_heads, seq_len, head_dim).
+                # TODO: add support for attn.scale when we move to Torch 2.1
+                ip_hidden_states = F.scaled_dot_product_attention(
+                    query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                )
+
+                ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+                ip_hidden_states = ip_hidden_states.to(query.dtype)
+
+                hidden_states = hidden_states + ipa_weights.scale * ip_hidden_states
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
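
The net effect of the new processor is that each IP-Adapter contributes its own attention result additively, weighted by its own per-adapter scale, rather than sharing a single scale. A standalone sketch of that composition with toy tensors (the shapes and the two-adapter setup here are illustrative, not taken from the commit):

import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 1, 8, 77, 40
query = torch.randn(batch, heads, seq_len, head_dim)
hidden_states = torch.randn(batch, seq_len, heads * head_dim)

# One (key, value, scale) triple per IP-Adapter, analogous to one IPAttentionProcessorWeights each.
adapters = [
    (torch.randn(batch, heads, 4, head_dim), torch.randn(batch, heads, 4, head_dim), 1.0),
    (torch.randn(batch, heads, 4, head_dim), torch.randn(batch, heads, 4, head_dim), 0.5),
]

for ip_key, ip_value, scale in adapters:
    ip_out = F.scaled_dot_product_attention(query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False)
    ip_out = ip_out.transpose(1, 2).reshape(batch, seq_len, heads * head_dim)
    # Each adapter's attention output is added to the base result, weighted by that adapter's scale.
    hidden_states = hidden_states + scale * ip_out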

View File: invokeai/backend/ip_adapter/ip_adapter.py

@@ -1,17 +1,15 @@
 # copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
 # and modified as needed
 
-from contextlib import contextmanager
 from typing import Optional, Union
 
 import torch
-from diffusers.models import UNet2DConditionModel
 from PIL import Image
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 
+from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
 from invokeai.backend.model_management.models.base import calc_model_size_by_data
 
-from .attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
 from .resampler import Resampler
@@ -61,7 +59,7 @@ class IPAdapter:
     def __init__(
         self,
-        state_dict: dict[torch.Tensor],
+        state_dict: dict[str, torch.Tensor],
         device: torch.device,
         dtype: torch.dtype = torch.float16,
         num_tokens: int = 4,
@@ -73,12 +71,11 @@ class IPAdapter:
 
         self._clip_image_processor = CLIPImageProcessor()
 
-        self._state_dict = state_dict
-        self._image_proj_model = self._init_image_proj_model(self._state_dict["image_proj"])
-
-        # The _attn_processors will be initialized later when we have access to the UNet.
-        self._attn_processors = None
+        self._image_proj_model = self._init_image_proj_model(state_dict["image_proj"])
+        self.attn_weights = IPAttentionWeights.from_state_dict(state_dict["ip_adapter"]).to(
+            self.device, dtype=self.dtype
+        )
 
     def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
         self.device = device
@@ -86,99 +83,14 @@ class IPAdapter:
             self.dtype = dtype
 
         self._image_proj_model.to(device=self.device, dtype=self.dtype)
-        if self._attn_processors is not None:
-            torch.nn.ModuleList(self._attn_processors.values()).to(device=self.device, dtype=self.dtype)
+        self.attn_weights.to(device=self.device, dtype=self.dtype)
 
     def calc_size(self):
-        if self._state_dict is not None:
-            image_proj_size = sum(
-                [tensor.nelement() * tensor.element_size() for tensor in self._state_dict["image_proj"].values()]
-            )
-            ip_adapter_size = sum(
-                [tensor.nelement() * tensor.element_size() for tensor in self._state_dict["ip_adapter"].values()]
-            )
-            return image_proj_size + ip_adapter_size
-        else:
-            return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(
-                torch.nn.ModuleList(self._attn_processors.values())
-            )
+        return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
 
     def _init_image_proj_model(self, state_dict):
         return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)
 
-    def _prepare_attention_processors(self, unet: UNet2DConditionModel):
-        """Prepare a dict of attention processors that can later be injected into a unet, and load the IP-Adapter
-        attention weights into them.
-
-        Note that the `unet` param is only used to determine attention block dimensions and naming.
-        TODO(ryand): As a future improvement, this could all be inferred from the state_dict when the IPAdapter is
-        intialized.
-        """
-        attn_procs = {}
-        for name in unet.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = unet.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = unet.config.block_out_channels[block_id]
-            if cross_attention_dim is None:
-                attn_procs[name] = AttnProcessor2_0()
-            else:
-                attn_procs[name] = IPAttnProcessor2_0(
-                    hidden_size=hidden_size,
-                    cross_attention_dim=cross_attention_dim,
-                    scale=1.0,
-                ).to(self.device, dtype=self.dtype)
-
-        ip_layers = torch.nn.ModuleList(attn_procs.values())
-        ip_layers.load_state_dict(self._state_dict["ip_adapter"])
-        self._attn_processors = attn_procs
-        self._state_dict = None
-
-    # @genomancer: pushed scaling back out into its own method (like original Tencent implementation)
-    # which makes implementing begin_step_percent and end_step_percent easier
-    # but based on self._attn_processors (ala @Ryan) instead of original Tencent unet.attn_processors,
-    # which should make it easier to implement multiple IPAdapters
-    def set_scale(self, scale):
-        if self._attn_processors is not None:
-            for attn_processor in self._attn_processors.values():
-                if isinstance(attn_processor, IPAttnProcessor2_0):
-                    attn_processor.scale = scale
-
-    @contextmanager
-    def apply_ip_adapter_attention(self, unet: UNet2DConditionModel, scale: float):
-        """A context manager that patches `unet` with this IP-Adapter's attention processors while it is active.
-
-        Yields:
-            None
-        """
-        if self._attn_processors is None:
-            # We only have to call _prepare_attention_processors(...) once, and then the result is cached and can be
-            # used on any UNet model (with the same dimensions).
-            self._prepare_attention_processors(unet)
-
-        # Set scale
-        self.set_scale(scale)
-        # for attn_processor in self._attn_processors.values():
-        #     if isinstance(attn_processor, IPAttnProcessor2_0):
-        #         attn_processor.scale = scale
-
-        orig_attn_processors = unet.attn_processors
-
-        # Make a (moderately-) shallow copy of the self._attn_processors dict, because unet.set_attn_processor(...)
-        # actually pops elements from the passed dict.
-        ip_adapter_attn_processors = {k: v for k, v in self._attn_processors.items()}
-
-        try:
-            unet.set_attn_processor(ip_adapter_attn_processors)
-            yield None
-        finally:
-            unet.set_attn_processor(orig_attn_processors)
-
     @torch.inference_mode()
     def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
         if isinstance(pil_image, Image.Image):

View File: invokeai/backend/ip_adapter/ip_attention_weights.py

@@ -0,0 +1,46 @@
+import torch
+
+
+class IPAttentionProcessorWeights(torch.nn.Module):
+    """The IP-Adapter weights for a single attention processor.
+
+    This class is a torch.nn.Module sub-class to facilitate loading from a state_dict. It does not have a forward(...)
+    method.
+    """
+
+    def __init__(self, in_dim: int, out_dim: int, scale: float = 1.0):
+        super().__init__()
+        self.scale = scale
+        self.to_k_ip = torch.nn.Linear(in_dim, out_dim, bias=False)
+        self.to_v_ip = torch.nn.Linear(in_dim, out_dim, bias=False)
+
+
+class IPAttentionWeights(torch.nn.Module):
+    """A collection of all the `IPAttentionProcessorWeights` objects for an IP-Adapter model.
+
+    This class is a torch.nn.Module sub-class so that it inherits the `.to(...)` functionality. It does not have a
+    forward(...) method.
+    """
+
+    def __init__(self, weights: dict[int, IPAttentionProcessorWeights]):
+        super().__init__()
+        self.weights = weights
+
+    def set_scale(self, scale: float):
+        """Set the scale (a.k.a. 'weight') for all of the `IPAttentionProcessorWeights` in this collection."""
+        for w in self.weights.values():
+            w.scale = scale
+
+    @classmethod
+    def from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
+        attn_proc_weights: dict[int, IPAttentionProcessorWeights] = {}
+
+        for tensor_name, tensor in state_dict.items():
+            if "to_k_ip.weight" in tensor_name:
+                index = int(tensor_name.split(".")[0])
+                attn_proc_weights[index] = IPAttentionProcessorWeights(tensor.shape[1], tensor.shape[0])
+
+        attn_proc_weights_module_dict = torch.nn.ModuleDict({str(i): w for i, w in attn_proc_weights.items()})
+        attn_proc_weights_module_dict.load_state_dict(state_dict)
+
+        return cls(attn_proc_weights)
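
For reference, `from_state_dict(...)` above assumes the flat key layout used by IP-Adapter checkpoints, where each cross-attention processor's weights sit under an integer prefix (e.g. `1.to_k_ip.weight`, `1.to_v_ip.weight`), and `torch.nn.ModuleDict` requires string keys for loading. A small sketch with a synthetic state dict (the dimensions and indices here are made up for illustration; real checkpoints use the UNet's hidden sizes):

import torch

from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights

# Synthetic weights for two attention processors (indices 1 and 3), shaped like nn.Linear
# weights: (hidden_size, cross_attention_dim).
cross_attention_dim, hidden_size = 768, 320
state_dict = {}
for idx in (1, 3):
    state_dict[f"{idx}.to_k_ip.weight"] = torch.randn(hidden_size, cross_attention_dim)
    state_dict[f"{idx}.to_v_ip.weight"] = torch.randn(hidden_size, cross_attention_dim)

attn_weights = IPAttentionWeights.from_state_dict(state_dict)
print(sorted(attn_weights.weights.keys()))  # [1, 3]
attn_weights.set_scale(0.5)  # Applies to every IPAttentionProcessorWeights in the collection.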

View File: invokeai/backend/ip_adapter/unet_patcher.py

@@ -0,0 +1,51 @@
+from contextlib import contextmanager
+
+from diffusers.models import UNet2DConditionModel
+
+from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
+from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+
+
+def _prepare_attention_processors(unet: UNet2DConditionModel, ip_adapters: list[IPAdapter]):
+    """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
+    weights into them.
+
+    Note that the `unet` param is only used to determine attention block dimensions and naming.
+    """
+    # TODO(ryand): This logic can be simplified.
+
+    # Construct a dict of attention processors based on the UNet's architecture.
+    attn_procs = {}
+    for idx, name in enumerate(unet.attn_processors.keys()):
+        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+        if name.startswith("mid_block"):
+            hidden_size = unet.config.block_out_channels[-1]
+        elif name.startswith("up_blocks"):
+            block_id = int(name[len("up_blocks.")])
+            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+        elif name.startswith("down_blocks"):
+            block_id = int(name[len("down_blocks.")])
+            hidden_size = unet.config.block_out_channels[block_id]
+
+        if cross_attention_dim is None:
+            attn_procs[name] = AttnProcessor2_0()
+        else:
+            # Collect the weights from each IP-Adapter for the idx'th attention processor.
+            attn_procs[name] = IPAttnProcessor2_0([ip_adapter.attn_weights.weights[idx] for ip_adapter in ip_adapters])
+
+    return attn_procs
+
+
+@contextmanager
+def apply_ip_adapter_attention(unet: UNet2DConditionModel, ip_adapters: list[IPAdapter]):
+    """A context manager that patches `unet` with IP-Adapter attention processors."""
+    attn_procs = _prepare_attention_processors(unet, ip_adapters)
+
+    orig_attn_processors = unet.attn_processors
+
+    try:
+        # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the
+        # passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy
+        # of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
+        unet.set_attn_processor(attn_procs)
+        yield None
+    finally:
+        unet.set_attn_processor(orig_attn_processors)
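
As a usage sketch, the patcher is meant to wrap the denoising calls as a context manager. The names below (`unet`, `ip_adapter_a`, `ip_adapter_b`, `latents`, `timestep`, `text_embeds`, `embeds_a`, `embeds_b`) are assumed to be loaded or prepared elsewhere; they are not defined in this commit:

from invokeai.backend.ip_adapter.unet_patcher import apply_ip_adapter_attention

with apply_ip_adapter_attention(unet, [ip_adapter_a, ip_adapter_b]):
    # While the context is active, every cross-attention layer runs IPAttnProcessor2_0 with one
    # weight set per adapter. The original attention processors are restored on exit.
    noise_pred = unet(
        latents,
        timestep,
        encoder_hidden_states=text_embeds,
        cross_attention_kwargs={"ip_adapter_image_prompt_embeds": [embeds_a, embeds_b]},
    ).sample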

View File: invokeai/backend/stable_diffusion/diffusers_pipeline.py

@@ -24,6 +24,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+from invokeai.backend.ip_adapter.unet_patcher import apply_ip_adapter_attention
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
 
 from ..util import auto_detect_slice_size, normalize_device
@@ -434,10 +435,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         elif ip_adapter_data is not None:
             # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
             # As it is now, the IP-Adapter will silently be skipped.
-            weight = ip_adapter_data.weight[0] if isinstance(ip_adapter_data.weight, List) else ip_adapter_data.weight
-            attn_ctx = ip_adapter_data.ip_adapter_model.apply_ip_adapter_attention(
-                unet=self.invokeai_diffuser.model,
-                scale=weight,
+            attn_ctx = apply_ip_adapter_attention(
+                unet=self.invokeai_diffuser.model, ip_adapters=[ipa.ip_adapter_model for ipa in ip_adapter_data]
             )
             self.use_ip_adapter = True
         else:
@@ -513,7 +512,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         total_step_count: int,
         additional_guidance: List[Callable] = None,
         control_data: List[ControlNetData] = None,
-        ip_adapter_data: Optional[IPAdapterData] = None,
+        ip_adapter_data: Optional[list[IPAdapterData]] = None,
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
@@ -527,20 +526,20 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
         # handle IP-Adapter
         if self.use_ip_adapter and ip_adapter_data is not None:  # somewhat redundant but logic is clearer
-            first_adapter_step = math.floor(ip_adapter_data.begin_step_percent * total_step_count)
-            last_adapter_step = math.ceil(ip_adapter_data.end_step_percent * total_step_count)
-            weight = (
-                ip_adapter_data.weight[step_index]
-                if isinstance(ip_adapter_data.weight, List)
-                else ip_adapter_data.weight
-            )
-            if step_index >= first_adapter_step and step_index <= last_adapter_step:
-                # only apply IP-Adapter if current step is within the IP-Adapter's begin/end step range
-                # ip_adapter_data.ip_adapter_model.set_scale(ip_adapter_data.weight)
-                ip_adapter_data.ip_adapter_model.set_scale(weight)
-            else:
-                # otherwise, set IP-Adapter scale to 0, so it has no effect
-                ip_adapter_data.ip_adapter_model.set_scale(0.0)
+            for single_ip_adapter_data in ip_adapter_data:
+                first_adapter_step = math.floor(single_ip_adapter_data.begin_step_percent * total_step_count)
+                last_adapter_step = math.ceil(single_ip_adapter_data.end_step_percent * total_step_count)
+                weight = (
+                    single_ip_adapter_data.weight[step_index]
+                    if isinstance(single_ip_adapter_data.weight, List)
+                    else single_ip_adapter_data.weight
+                )
+                if step_index >= first_adapter_step and step_index <= last_adapter_step:
+                    # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
+                    single_ip_adapter_data.ip_adapter_model.attn_weights.set_scale(weight)
+                else:
+                    # Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
+                    single_ip_adapter_data.ip_adapter_model.attn_weights.set_scale(0.0)
 
         # Handle ControlNet(s) and T2I-Adapter(s)
         down_block_additional_residuals = None
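
The per-step gating above reduces to a small piece of arithmetic: an adapter is active only between `floor(begin_step_percent * total_step_count)` and `ceil(end_step_percent * total_step_count)`, and its scale is set to 0.0 outside that window. A self-contained sketch (the helper name is ours, not the commit's):

import math

def ip_adapter_scale_for_step(step_index: int, total_step_count: int,
                              begin_step_percent: float, end_step_percent: float, weight: float) -> float:
    """Return the scale an IP-Adapter should use at this step (0.0 outside its active window)."""
    first_adapter_step = math.floor(begin_step_percent * total_step_count)
    last_adapter_step = math.ceil(end_step_percent * total_step_count)
    return weight if first_adapter_step <= step_index <= last_adapter_step else 0.0

# e.g. an adapter active for the first half of a 30-step run:
print([ip_adapter_scale_for_step(i, 30, 0.0, 0.5, 0.75) for i in (0, 10, 15, 16, 29)])
# [0.75, 0.75, 0.75, 0.0, 0.0]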

View File: invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py

@@ -346,12 +346,10 @@ class InvokeAIDiffuserComponent:
         cross_attention_kwargs = None
         if conditioning_data.ip_adapter_conditioning is not None:
             cross_attention_kwargs = {
-                "ip_adapter_image_prompt_embeds": torch.cat(
-                    [
-                        conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds,
-                        conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds,
-                    ]
-                )
+                "ip_adapter_image_prompt_embeds": [
+                    torch.cat([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
+                    for ipa_conditioning in conditioning_data.ip_adapter_conditioning
+                ]
             }
 
         added_cond_kwargs = None
@@ -418,7 +416,10 @@ class InvokeAIDiffuserComponent:
         cross_attention_kwargs = None
         if conditioning_data.ip_adapter_conditioning is not None:
             cross_attention_kwargs = {
-                "ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds
+                "ip_adapter_image_prompt_embeds": [
+                    ipa_conditioning.uncond_image_prompt_embeds
+                    for ipa_conditioning in conditioning_data.ip_adapter_conditioning
+                ]
             }
 
         added_cond_kwargs = None
@@ -444,7 +445,10 @@ class InvokeAIDiffuserComponent:
         cross_attention_kwargs = None
         if conditioning_data.ip_adapter_conditioning is not None:
             cross_attention_kwargs = {
-                "ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds
+                "ip_adapter_image_prompt_embeds": [
+                    ipa_conditioning.cond_image_prompt_embeds
+                    for ipa_conditioning in conditioning_data.ip_adapter_conditioning
+                ]
            }
 
         added_cond_kwargs = None
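
Note the contract implied by these changes: `ip_adapter_image_prompt_embeds` is now a list with one tensor per IP-Adapter, and its order must match the order of the `ip_adapters` list passed to `apply_ip_adapter_attention(...)`, because `IPAttnProcessor2_0` zips the two together. A minimal sketch with placeholder tensors (the shapes assume a batch of 2, the default 4 image tokens, and a hypothetical cross-attention dim of 768):

import torch

# One embedding tensor per IP-Adapter, shaped (batch, num_image_tokens, cross_attention_dim).
embeds_for_adapter_0 = torch.randn(2, 4, 768)
embeds_for_adapter_1 = torch.randn(2, 4, 768)

# List order must line up, adapter-for-adapter, with the weights held by each IPAttnProcessor2_0.
cross_attention_kwargs = {
    "ip_adapter_image_prompt_embeds": [embeds_for_adapter_0, embeds_for_adapter_1],
}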