mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into lstein/feat/simple-mm2-api
This commit is contained in:
@ -21,12 +21,9 @@ from pydantic import Field
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
IPAdapterData,
|
||||
TextConditioningData,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
|
||||
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
|
||||
from invokeai.backend.util.attention import auto_detect_slice_size
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@ -394,8 +391,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
unet_attention_patcher = None
|
||||
self.use_ip_adapter = use_ip_adapter
|
||||
attn_ctx = nullcontext()
|
||||
|
||||
if use_ip_adapter or use_regional_prompting:
|
||||
ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None
|
||||
ip_adapters: Optional[List[UNetIPAdapterData]] = (
|
||||
[{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
|
||||
if use_ip_adapter
|
||||
else None
|
||||
)
|
||||
unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
|
||||
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
||||
|
||||
|
@ -53,6 +53,7 @@ class IPAdapterData:
|
||||
ip_adapter_model: IPAdapter
|
||||
ip_adapter_conditioning: IPAdapterConditioningInfo
|
||||
mask: torch.Tensor
|
||||
target_blocks: List[str]
|
||||
|
||||
# Either a single weight applied to all steps, or a list of weights for each step.
|
||||
weight: Union[float, List[float]] = 1.0
|
||||
|
@ -1,4 +1,5 @@
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, cast
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -9,6 +10,12 @@ from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import Regiona
|
||||
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPAdapterAttentionWeights:
|
||||
ip_adapter_weights: IPAttentionProcessorWeights
|
||||
skip: bool
|
||||
|
||||
|
||||
class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
"""A custom implementation of AttnProcessor2_0 that supports additional Invoke features.
|
||||
This implementation is based on
|
||||
@ -20,7 +27,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ip_adapter_weights: Optional[list[IPAttentionProcessorWeights]] = None,
|
||||
ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None,
|
||||
):
|
||||
"""Initialize a CustomAttnProcessor2_0.
|
||||
Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
|
||||
@ -30,23 +37,22 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
for the i'th IP-Adapter.
|
||||
"""
|
||||
super().__init__()
|
||||
self._ip_adapter_weights = ip_adapter_weights
|
||||
|
||||
def _is_ip_adapter_enabled(self) -> bool:
|
||||
return self._ip_adapter_weights is not None
|
||||
self._ip_adapter_attention_weights = ip_adapter_attention_weights
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
# For regional prompting:
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
# For Regional Prompting:
|
||||
regional_prompt_data: Optional[RegionalPromptData] = None,
|
||||
percent_through: Optional[torch.FloatTensor] = None,
|
||||
percent_through: Optional[torch.Tensor] = None,
|
||||
# For IP-Adapter:
|
||||
regional_ip_data: Optional[RegionalIPData] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""Apply attention.
|
||||
Args:
|
||||
@ -130,17 +136,19 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
|
||||
# Apply IP-Adapter conditioning.
|
||||
if is_cross_attention:
|
||||
if self._is_ip_adapter_enabled():
|
||||
if self._ip_adapter_attention_weights:
|
||||
assert regional_ip_data is not None
|
||||
ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
|
||||
|
||||
assert (
|
||||
len(regional_ip_data.image_prompt_embeds)
|
||||
== len(self._ip_adapter_weights)
|
||||
== len(self._ip_adapter_attention_weights)
|
||||
== len(regional_ip_data.scales)
|
||||
== ip_masks.shape[1]
|
||||
)
|
||||
|
||||
for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
|
||||
ipa_weights = self._ip_adapter_weights[ipa_index]
|
||||
ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights
|
||||
ipa_scale = regional_ip_data.scales[ipa_index]
|
||||
ip_mask = ip_masks[0, ipa_index, ...]
|
||||
|
||||
@ -153,29 +161,33 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
|
||||
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
|
||||
|
||||
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
|
||||
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
|
||||
if not self._ip_adapter_attention_weights[ipa_index].skip:
|
||||
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
|
||||
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
|
||||
|
||||
# Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
|
||||
# Expected ip_key and ip_value shape:
|
||||
# (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
|
||||
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
|
||||
# Expected ip_key and ip_value shape:
|
||||
# (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
|
||||
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
|
||||
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
|
||||
batch_size, -1, attn.heads * head_dim
|
||||
)
|
||||
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
|
||||
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
|
||||
|
||||
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
|
||||
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
|
||||
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
|
||||
else:
|
||||
# If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
|
||||
assert regional_ip_data is None
|
||||
@ -188,11 +200,15 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
# End of unmodified block from AttnProcessor2_0
|
||||
|
||||
return hidden_states
|
||||
# casting torch.Tensor to torch.FloatTensor to avoid type issues
|
||||
return cast(torch.FloatTensor, hidden_states)
|
||||
|
@ -1,17 +1,25 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional
|
||||
from typing import List, Optional, TypedDict
|
||||
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
|
||||
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import (
|
||||
CustomAttnProcessor2_0,
|
||||
IPAdapterAttentionWeights,
|
||||
)
|
||||
|
||||
|
||||
class UNetIPAdapterData(TypedDict):
|
||||
ip_adapter: IPAdapter
|
||||
target_blocks: List[str]
|
||||
|
||||
|
||||
class UNetAttentionPatcher:
|
||||
"""A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""
|
||||
|
||||
def __init__(self, ip_adapters: Optional[list[IPAdapter]]):
|
||||
self._ip_adapters = ip_adapters
|
||||
def __init__(self, ip_adapter_data: Optional[List[UNetIPAdapterData]]):
|
||||
self._ip_adapters = ip_adapter_data
|
||||
|
||||
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
|
||||
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
|
||||
@ -26,9 +34,22 @@ class UNetAttentionPatcher:
|
||||
attn_procs[name] = CustomAttnProcessor2_0()
|
||||
else:
|
||||
# Collect the weights from each IP Adapter for the idx'th attention processor.
|
||||
attn_procs[name] = CustomAttnProcessor2_0(
|
||||
[ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
|
||||
)
|
||||
ip_adapter_attention_weights_collection: list[IPAdapterAttentionWeights] = []
|
||||
|
||||
for ip_adapter in self._ip_adapters:
|
||||
ip_adapter_weights = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
|
||||
skip = True
|
||||
for block in ip_adapter["target_blocks"]:
|
||||
if block in name:
|
||||
skip = False
|
||||
break
|
||||
ip_adapter_attention_weights: IPAdapterAttentionWeights = IPAdapterAttentionWeights(
|
||||
ip_adapter_weights=ip_adapter_weights, skip=skip
|
||||
)
|
||||
ip_adapter_attention_weights_collection.append(ip_adapter_attention_weights)
|
||||
|
||||
attn_procs[name] = CustomAttnProcessor2_0(ip_adapter_attention_weights_collection)
|
||||
|
||||
return attn_procs
|
||||
|
||||
@contextmanager
|
||||
|
Reference in New Issue
Block a user