Merge branch 'main' into lstein/feat/simple-mm2-api

This commit is contained in:
Lincoln Stein
2024-04-17 17:18:13 -04:00
committed by GitHub
28 changed files with 396 additions and 79 deletions

View File

@ -21,12 +21,9 @@ from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
IPAdapterData,
TextConditioningData,
)
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import TorchDevice
@ -394,8 +391,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
unet_attention_patcher = None
self.use_ip_adapter = use_ip_adapter
attn_ctx = nullcontext()
if use_ip_adapter or use_regional_prompting:
ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None
ip_adapters: Optional[List[UNetIPAdapterData]] = (
[{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
if use_ip_adapter
else None
)
unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)

View File

@ -53,6 +53,7 @@ class IPAdapterData:
ip_adapter_model: IPAdapter
ip_adapter_conditioning: IPAdapterConditioningInfo
mask: torch.Tensor
target_blocks: List[str]
# Either a single weight applied to all steps, or a list of weights for each step.
weight: Union[float, List[float]] = 1.0

View File

@ -1,4 +1,5 @@
from typing import Optional
from dataclasses import dataclass
from typing import List, Optional, cast
import torch
import torch.nn.functional as F
@ -9,6 +10,12 @@ from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import Regiona
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
@dataclass
class IPAdapterAttentionWeights:
ip_adapter_weights: IPAttentionProcessorWeights
skip: bool
class CustomAttnProcessor2_0(AttnProcessor2_0):
"""A custom implementation of AttnProcessor2_0 that supports additional Invoke features.
This implementation is based on
@ -20,7 +27,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
def __init__(
self,
ip_adapter_weights: Optional[list[IPAttentionProcessorWeights]] = None,
ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None,
):
"""Initialize a CustomAttnProcessor2_0.
Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
@ -30,23 +37,22 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
for the i'th IP-Adapter.
"""
super().__init__()
self._ip_adapter_weights = ip_adapter_weights
def _is_ip_adapter_enabled(self) -> bool:
return self._ip_adapter_weights is not None
self._ip_adapter_attention_weights = ip_adapter_attention_weights
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
# For regional prompting:
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
# For Regional Prompting:
regional_prompt_data: Optional[RegionalPromptData] = None,
percent_through: Optional[torch.FloatTensor] = None,
percent_through: Optional[torch.Tensor] = None,
# For IP-Adapter:
regional_ip_data: Optional[RegionalIPData] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
"""Apply attention.
Args:
@ -130,17 +136,19 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# Apply IP-Adapter conditioning.
if is_cross_attention:
if self._is_ip_adapter_enabled():
if self._ip_adapter_attention_weights:
assert regional_ip_data is not None
ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
assert (
len(regional_ip_data.image_prompt_embeds)
== len(self._ip_adapter_weights)
== len(self._ip_adapter_attention_weights)
== len(regional_ip_data.scales)
== ip_masks.shape[1]
)
for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
ipa_weights = self._ip_adapter_weights[ipa_index]
ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights
ipa_scale = regional_ip_data.scales[ipa_index]
ip_mask = ip_masks[0, ipa_index, ...]
@ -153,29 +161,33 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
if not self._ip_adapter_attention_weights[ipa_index].skip:
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
# Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
# Expected ip_key and ip_value shape:
# (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
# Expected ip_key and ip_value shape:
# (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
ip_hidden_states = ip_hidden_states.to(query.dtype)
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
else:
# If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
assert regional_ip_data is None
@ -188,11 +200,15 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# End of unmodified block from AttnProcessor2_0
return hidden_states
# casting torch.Tensor to torch.FloatTensor to avoid type issues
return cast(torch.FloatTensor, hidden_states)

View File

@ -1,17 +1,25 @@
from contextlib import contextmanager
from typing import Optional
from typing import List, Optional, TypedDict
from diffusers.models import UNet2DConditionModel
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import (
CustomAttnProcessor2_0,
IPAdapterAttentionWeights,
)
class UNetIPAdapterData(TypedDict):
ip_adapter: IPAdapter
target_blocks: List[str]
class UNetAttentionPatcher:
"""A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""
def __init__(self, ip_adapters: Optional[list[IPAdapter]]):
self._ip_adapters = ip_adapters
def __init__(self, ip_adapter_data: Optional[List[UNetIPAdapterData]]):
self._ip_adapters = ip_adapter_data
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
@ -26,9 +34,22 @@ class UNetAttentionPatcher:
attn_procs[name] = CustomAttnProcessor2_0()
else:
# Collect the weights from each IP Adapter for the idx'th attention processor.
attn_procs[name] = CustomAttnProcessor2_0(
[ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
)
ip_adapter_attention_weights_collection: list[IPAdapterAttentionWeights] = []
for ip_adapter in self._ip_adapters:
ip_adapter_weights = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
skip = True
for block in ip_adapter["target_blocks"]:
if block in name:
skip = False
break
ip_adapter_attention_weights: IPAdapterAttentionWeights = IPAdapterAttentionWeights(
ip_adapter_weights=ip_adapter_weights, skip=skip
)
ip_adapter_attention_weights_collection.append(ip_adapter_attention_weights)
attn_procs[name] = CustomAttnProcessor2_0(ip_adapter_attention_weights_collection)
return attn_procs
@contextmanager