Pull the upstream changes from diffusers' AttnProcessor2_0 into CustomAttnProcessor2_0. This fixes a bug that was triggered when peft was not installed. The faulty code lived in a block that was originally copied from diffusers, and the bug appears to have been introduced during diffusers' migration to PEFT for its LoRA handling. The upstream fix landed in 531e719163.

Ryan Dick 2024-04-08 10:55:54 -04:00 committed by Kent Keirsey
parent 926b8d0efe
commit 75ef473748
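
For context, here is a minimal sketch (not part of the commit, with arbitrary tensor shapes) of the failure mode being fixed: after diffusers moved its LoRA handling to PEFT, the attention projections are plain torch.nn.Linear modules, so the old code path taken when peft was not installed (USE_PEFT_BACKEND being False) forwarded an extra scale positional argument and raised a TypeError.

    import torch

    # Hypothetical stand-in for attn.to_q: a plain nn.Linear, as in current diffusers.
    to_q = torch.nn.Linear(320, 320)
    hidden_states = torch.randn(1, 77, 320)
    scale = 1.0

    # New upstream behaviour: no scale argument is passed to the projection.
    query = to_q(hidden_states)

    # Old code path when peft was not installed (USE_PEFT_BACKEND == False):
    # args = (scale,) was forwarded, and nn.Linear rejects the extra argument.
    try:
        query = to_q(hidden_states, scale)
    except TypeError as err:
        print(f"TypeError: {err}")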


@@ -3,7 +3,6 @@ from typing import Optional
 import torch
 import torch.nn.functional as F
 from diffusers.models.attention_processor import Attention, AttnProcessor2_0
-from diffusers.utils import USE_PEFT_BACKEND
 
 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
 from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
@@ -51,7 +50,6 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         temb: Optional[torch.FloatTensor] = None,
-        scale: float = 1.0,
         # For regional prompting:
         regional_prompt_data: Optional[RegionalPromptData] = None,
         percent_through: Optional[torch.FloatTensor] = None,
@@ -111,16 +109,15 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        args = () if USE_PEFT_BACKEND else (scale,)
-        query = attn.to_q(hidden_states, *args)
+        query = attn.to_q(hidden_states)
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
-        key = attn.to_k(encoder_hidden_states, *args)
-        value = attn.to_v(encoder_hidden_states, *args)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
 
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -187,7 +184,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         # Start unmodified block from AttnProcessor2_0.
         # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, *args)
+        hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)