chore: change IPAdapterAttentionWeights to a dataclass

This commit is contained in:
blessedcoolant 2024-04-15 23:38:55 +05:30
parent cd76a31a8f
commit 5f6c6abf9c
2 changed files with 14 additions and 12 deletions

View File

@ -1,5 +1,6 @@
from dataclasses import dataclass
from itertools import cycle, islice from itertools import cycle, islice
from typing import List, Optional, TypedDict, cast from typing import List, Optional, cast
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -10,7 +11,8 @@ from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import Regiona
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
class IPAdapterAttentionWeights(TypedDict): @dataclass
class IPAdapterAttentionWeights:
ip_adapter_weights: List[IPAttentionProcessorWeights] ip_adapter_weights: List[IPAttentionProcessorWeights]
skip: bool skip: bool
@ -63,7 +65,6 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
is_cross_attention = encoder_hidden_states is not None is_cross_attention = encoder_hidden_states is not None
# Start unmodified block from AttnProcessor2_0. # Start unmodified block from AttnProcessor2_0.
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
residual = hidden_states residual = hidden_states
if attn.spatial_norm is not None: if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb) hidden_states = attn.spatial_norm(hidden_states, temb)
@ -77,7 +78,6 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
batch_size, sequence_length, _ = ( batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
) )
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# End unmodified block from AttnProcessor2_0. # End unmodified block from AttnProcessor2_0.
_, query_seq_len, _ = hidden_states.shape _, query_seq_len, _ = hidden_states.shape
@ -140,22 +140,22 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len) ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
# Pad weight tensor list to match size of regional embeds # Pad weight tensor list to match size of regional embeds
self._ip_adapter_attention_weights["ip_adapter_weights"] = list( self._ip_adapter_attention_weights.ip_adapter_weights = list(
islice( islice(
cycle(self._ip_adapter_attention_weights["ip_adapter_weights"]), cycle(self._ip_adapter_attention_weights.ip_adapter_weights),
len(regional_ip_data.image_prompt_embeds), len(regional_ip_data.image_prompt_embeds),
) )
) )
assert ( assert (
len(regional_ip_data.image_prompt_embeds) len(regional_ip_data.image_prompt_embeds)
== len(self._ip_adapter_attention_weights["ip_adapter_weights"]) == len(self._ip_adapter_attention_weights.ip_adapter_weights)
== len(regional_ip_data.scales) == len(regional_ip_data.scales)
== ip_masks.shape[1] == ip_masks.shape[1]
) )
for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds): for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
ipa_weights = self._ip_adapter_attention_weights["ip_adapter_weights"][ipa_index] ipa_weights = self._ip_adapter_attention_weights.ip_adapter_weights[ipa_index]
ipa_scale = regional_ip_data.scales[ipa_index] ipa_scale = regional_ip_data.scales[ipa_index]
ip_mask = ip_masks[0, ipa_index, ...] ip_mask = ip_masks[0, ipa_index, ...]
@ -168,7 +168,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding) # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
if not self._ip_adapter_attention_weights["skip"]: if not self._ip_adapter_attention_weights.skip:
ip_key = ipa_weights.to_k_ip(ip_hidden_states) ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states) ip_value = ipa_weights.to_v_ip(ip_hidden_states)
@ -215,5 +215,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
hidden_states = hidden_states + residual hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor hidden_states = hidden_states / attn.rescale_output_factor
# End of unmodified block from AttnProcessor2_0
# casting torch.Tensor to torch.FloatTensor to avoid type issues
return cast(torch.FloatTensor, hidden_states) return cast(torch.FloatTensor, hidden_states)

View File

@ -33,7 +33,7 @@ class UNetAttentionPatcher:
# "attn1" processors do not use IP-Adapters. # "attn1" processors do not use IP-Adapters.
attn_procs[name] = CustomAttnProcessor2_0() attn_procs[name] = CustomAttnProcessor2_0()
else: else:
ip_adapter_attention_weights: IPAdapterAttentionWeights = {"ip_adapter_weights": [], "skip": False} ip_adapter_attention_weights = IPAdapterAttentionWeights(ip_adapter_weights=[], skip=False)
for ip_adapter in self._ip_adapters: for ip_adapter in self._ip_adapters:
ip_adapter_weight = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx) ip_adapter_weight = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
skip = True skip = True
@ -41,8 +41,8 @@ class UNetAttentionPatcher:
if block in name: if block in name:
skip = False skip = False
break break
ip_adapter_attention_weights.ip_adapter_weights = [ip_adapter_weight]
ip_adapter_attention_weights.update({"ip_adapter_weights": [ip_adapter_weight], "skip": skip}) ip_adapter_attention_weights.skip = skip
# Collect the weights from each IP Adapter for the idx'th attention processor. # Collect the weights from each IP Adapter for the idx'th attention processor.