mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
chore: change IPAdapterAttentionWeights to a dataclass
This commit is contained in:
parent
cd76a31a8f
commit
5f6c6abf9c
@ -1,5 +1,6 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
from itertools import cycle, islice
|
from itertools import cycle, islice
|
||||||
from typing import List, Optional, TypedDict, cast
|
from typing import List, Optional, cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -10,7 +11,8 @@ from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import Regiona
|
|||||||
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
|
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterAttentionWeights(TypedDict):
|
@dataclass
|
||||||
|
class IPAdapterAttentionWeights:
|
||||||
ip_adapter_weights: List[IPAttentionProcessorWeights]
|
ip_adapter_weights: List[IPAttentionProcessorWeights]
|
||||||
skip: bool
|
skip: bool
|
||||||
|
|
||||||
@ -63,7 +65,6 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
|||||||
is_cross_attention = encoder_hidden_states is not None
|
is_cross_attention = encoder_hidden_states is not None
|
||||||
|
|
||||||
# Start unmodified block from AttnProcessor2_0.
|
# Start unmodified block from AttnProcessor2_0.
|
||||||
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
if attn.spatial_norm is not None:
|
if attn.spatial_norm is not None:
|
||||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||||
@ -77,7 +78,6 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
|||||||
batch_size, sequence_length, _ = (
|
batch_size, sequence_length, _ = (
|
||||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||||
)
|
)
|
||||||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
|
||||||
# End unmodified block from AttnProcessor2_0.
|
# End unmodified block from AttnProcessor2_0.
|
||||||
|
|
||||||
_, query_seq_len, _ = hidden_states.shape
|
_, query_seq_len, _ = hidden_states.shape
|
||||||
@ -140,22 +140,22 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
|||||||
ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
|
ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
|
||||||
|
|
||||||
# Pad weight tensor list to match size of regional embeds
|
# Pad weight tensor list to match size of regional embeds
|
||||||
self._ip_adapter_attention_weights["ip_adapter_weights"] = list(
|
self._ip_adapter_attention_weights.ip_adapter_weights = list(
|
||||||
islice(
|
islice(
|
||||||
cycle(self._ip_adapter_attention_weights["ip_adapter_weights"]),
|
cycle(self._ip_adapter_attention_weights.ip_adapter_weights),
|
||||||
len(regional_ip_data.image_prompt_embeds),
|
len(regional_ip_data.image_prompt_embeds),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
len(regional_ip_data.image_prompt_embeds)
|
len(regional_ip_data.image_prompt_embeds)
|
||||||
== len(self._ip_adapter_attention_weights["ip_adapter_weights"])
|
== len(self._ip_adapter_attention_weights.ip_adapter_weights)
|
||||||
== len(regional_ip_data.scales)
|
== len(regional_ip_data.scales)
|
||||||
== ip_masks.shape[1]
|
== ip_masks.shape[1]
|
||||||
)
|
)
|
||||||
|
|
||||||
for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
|
for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
|
||||||
ipa_weights = self._ip_adapter_attention_weights["ip_adapter_weights"][ipa_index]
|
ipa_weights = self._ip_adapter_attention_weights.ip_adapter_weights[ipa_index]
|
||||||
ipa_scale = regional_ip_data.scales[ipa_index]
|
ipa_scale = regional_ip_data.scales[ipa_index]
|
||||||
ip_mask = ip_masks[0, ipa_index, ...]
|
ip_mask = ip_masks[0, ipa_index, ...]
|
||||||
|
|
||||||
@ -168,7 +168,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
|||||||
|
|
||||||
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
|
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
|
||||||
|
|
||||||
if not self._ip_adapter_attention_weights["skip"]:
|
if not self._ip_adapter_attention_weights.skip:
|
||||||
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
|
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
|
||||||
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
|
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
|
||||||
|
|
||||||
@ -215,5 +215,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
|||||||
hidden_states = hidden_states + residual
|
hidden_states = hidden_states + residual
|
||||||
|
|
||||||
hidden_states = hidden_states / attn.rescale_output_factor
|
hidden_states = hidden_states / attn.rescale_output_factor
|
||||||
|
# End of unmodified block from AttnProcessor2_0
|
||||||
|
|
||||||
|
# casting torch.Tensor to torch.FloatTensor to avoid type issues
|
||||||
return cast(torch.FloatTensor, hidden_states)
|
return cast(torch.FloatTensor, hidden_states)
|
||||||
|
@ -33,7 +33,7 @@ class UNetAttentionPatcher:
|
|||||||
# "attn1" processors do not use IP-Adapters.
|
# "attn1" processors do not use IP-Adapters.
|
||||||
attn_procs[name] = CustomAttnProcessor2_0()
|
attn_procs[name] = CustomAttnProcessor2_0()
|
||||||
else:
|
else:
|
||||||
ip_adapter_attention_weights: IPAdapterAttentionWeights = {"ip_adapter_weights": [], "skip": False}
|
ip_adapter_attention_weights = IPAdapterAttentionWeights(ip_adapter_weights=[], skip=False)
|
||||||
for ip_adapter in self._ip_adapters:
|
for ip_adapter in self._ip_adapters:
|
||||||
ip_adapter_weight = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
|
ip_adapter_weight = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
|
||||||
skip = True
|
skip = True
|
||||||
@ -41,8 +41,8 @@ class UNetAttentionPatcher:
|
|||||||
if block in name:
|
if block in name:
|
||||||
skip = False
|
skip = False
|
||||||
break
|
break
|
||||||
|
ip_adapter_attention_weights.ip_adapter_weights = [ip_adapter_weight]
|
||||||
ip_adapter_attention_weights.update({"ip_adapter_weights": [ip_adapter_weight], "skip": skip})
|
ip_adapter_attention_weights.skip = skip
|
||||||
|
|
||||||
# Collect the weights from each IP Adapter for the idx'th attention processor.
|
# Collect the weights from each IP Adapter for the idx'th attention processor.
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user