Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
fix: make ip-adapter weights not be optional
parent d27907cc6d
commit f46bbaf8c4
@@ -12,7 +12,7 @@ from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import Reg
 
 @dataclass
 class IPAdapterAttentionWeights:
-    ip_adapter_weights: Optional[IPAttentionProcessorWeights]
+    ip_adapter_weights: IPAttentionProcessorWeights
     skip: bool
 
 
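For context, this is the shape of the dataclass after the change. The `Protocol` stand-in below is an assumption inferred from the `to_k_ip`/`to_v_ip` calls later in this diff, not the real `IPAttentionProcessorWeights` class:

```python
from dataclasses import dataclass
from typing import Protocol

import torch


class IPAttentionProcessorWeights(Protocol):
    """Stand-in interface for the real weights module (assumed from this diff)."""

    def to_k_ip(self, x: torch.Tensor) -> torch.Tensor: ...

    def to_v_ip(self, x: torch.Tensor) -> torch.Tensor: ...


@dataclass
class IPAdapterAttentionWeights:
    # No longer Optional: after this commit the patcher only constructs this
    # record once it has resolved real weights, so callers can project through
    # them without a None check.
    ip_adapter_weights: IPAttentionProcessorWeights
    # True when the current attention layer is outside this adapter's
    # target blocks, i.e. the adapter should not be applied there.
    skip: bool
```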
@@ -64,6 +64,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         is_cross_attention = encoder_hidden_states is not None
 
         # Start unmodified block from AttnProcessor2_0.
+        # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
         residual = hidden_states
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -77,6 +78,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         batch_size, sequence_length, _ = (
             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         )
+        # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
         # End unmodified block from AttnProcessor2_0.
 
         _, query_seq_len, _ = hidden_states.shape
@@ -160,7 +162,6 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
                     # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
 
                     if not self._ip_adapter_attention_weights[ipa_index].skip:
-                        if ipa_weights:
                         ip_key = ipa_weights.to_k_ip(ip_hidden_states)
                         ip_value = ipa_weights.to_v_ip(ip_hidden_states)
 
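With the field typed non-Optional, the `skip` flag alone now gates whether an adapter contributes to a layer; the `if ipa_weights:` truthiness guard becomes dead weight and is removed. A minimal sketch of that gating, assuming diffusers-style head shapes, `torch.nn.functional.scaled_dot_product_attention`, and the Protocol stand-in from the earlier sketch; the function name and signature here are illustrative, not the repo's API:

```python
import torch
import torch.nn.functional as F


def ip_adapter_attention(
    query: torch.Tensor,             # (batch, heads, query_len, head_dim)
    ip_hidden_states: torch.Tensor,  # (batch, ip_seq_len, embed_dim)
    ipa_weights: IPAttentionProcessorWeights,
    skip: bool,
    scale: float,
) -> torch.Tensor | None:
    if skip:
        # This layer is outside the adapter's target blocks; contribute nothing.
        return None
    # No `if ipa_weights:` guard anymore: the field is guaranteed non-None.
    batch, heads, _, head_dim = query.shape
    ip_key = ipa_weights.to_k_ip(ip_hidden_states)
    ip_value = ipa_weights.to_v_ip(ip_hidden_states)
    # (batch, ip_seq_len, heads * head_dim) -> (batch, heads, ip_seq_len, head_dim)
    ip_key = ip_key.view(batch, -1, heads, head_dim).transpose(1, 2)
    ip_value = ip_value.view(batch, -1, heads, head_dim).transpose(1, 2)
    # Attend the text-attention query against the image-prompt keys/values.
    ip_out = F.scaled_dot_product_attention(query, ip_key, ip_value)
    return scale * ip_out
```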
@@ -206,6 +207,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         hidden_states = hidden_states + residual
 
         hidden_states = hidden_states / attn.rescale_output_factor
+        # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
         # End of unmodified block from AttnProcessor2_0
 
         # casting torch.Tensor to torch.FloatTensor to avoid type issues
@@ -37,17 +37,15 @@ class UNetAttentionPatcher:
                 ip_adapter_attention_weights_collection: list[IPAdapterAttentionWeights] = []
 
                 for ip_adapter in self._ip_adapters:
-                    ip_adapter_attention_weights: IPAdapterAttentionWeights = IPAdapterAttentionWeights(
-                        ip_adapter_weights=None, skip=False
-                    )
                     ip_adapter_weights = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
                     skip = True
                     for block in ip_adapter["target_blocks"]:
                         if block in name:
                             skip = False
                             break
-                    ip_adapter_attention_weights.ip_adapter_weights = ip_adapter_weights
-                    ip_adapter_attention_weights.skip = skip
+                    ip_adapter_attention_weights: IPAdapterAttentionWeights = IPAdapterAttentionWeights(
+                        ip_adapter_weights=ip_adapter_weights, skip=skip
+                    )
                     ip_adapter_attention_weights_collection.append(ip_adapter_attention_weights)
 
                 attn_procs[name] = CustomAttnProcessor2_0(ip_adapter_attention_weights_collection)
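The patcher change is the mirror image of the dataclass change: instead of constructing the record with `ip_adapter_weights=None` and mutating it afterwards, it resolves the weights and the skip decision first and constructs the record once, fully populated. A condensed restatement of that loop as a standalone helper; the function name is made up for the example, and the `not any(...)` is an equivalent of the diff's skip/break loop, not the code as written:

```python
def build_layer_weights(
    name: str,
    idx: int,
    ip_adapters: list[dict],
) -> list[IPAdapterAttentionWeights]:
    """Sketch of the new construction order: resolve first, construct once."""
    collection: list[IPAdapterAttentionWeights] = []
    for ip_adapter in ip_adapters:
        weights = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
        # Skip this layer unless its name matches one of the adapter's target blocks.
        skip = not any(block in name for block in ip_adapter["target_blocks"])
        collection.append(IPAdapterAttentionWeights(ip_adapter_weights=weights, skip=skip))
    return collection
```

Building the record in one shot means a half-initialized `IPAdapterAttentionWeights` never exists, which is exactly what lets the field drop its `Optional` type.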