diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
index 2f7523dd46..667fcd9a64 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
@@ -3,7 +3,6 @@ from typing import Optional
 import torch
 import torch.nn.functional as F
 from diffusers.models.attention_processor import Attention, AttnProcessor2_0
-from diffusers.utils import USE_PEFT_BACKEND
 
 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
 from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
@@ -51,7 +50,6 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         temb: Optional[torch.FloatTensor] = None,
-        scale: float = 1.0,
         # For regional prompting:
         regional_prompt_data: Optional[RegionalPromptData] = None,
         percent_through: Optional[torch.FloatTensor] = None,
@@ -111,16 +109,15 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        args = () if USE_PEFT_BACKEND else (scale,)
-        query = attn.to_q(hidden_states, *args)
+        query = attn.to_q(hidden_states)
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
-        key = attn.to_k(encoder_hidden_states, *args)
-        value = attn.to_v(encoder_hidden_states, *args)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
 
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -187,7 +184,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         # Start unmodified block from AttnProcessor2_0.
         # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, *args)
+        hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
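
For reference, below is a minimal sketch (not part of the diff) of exercising the updated processor after `scale` is removed from `__call__` and from the `to_q`/`to_k`/`to_v`/`to_out` calls. It assumes `CustomAttnProcessor2_0()` can be constructed with no arguments (no IP-Adapter weights) and that per-call LoRA scaling is now handled by diffusers' PEFT backend rather than threaded through the processor; adjust for your invokeai version.

# Minimal sketch, not part of the diff. Assumes a no-argument CustomAttnProcessor2_0()
# constructor and torch >= 2.0 (required by AttnProcessor2_0).
import torch
from diffusers.models.attention_processor import Attention

from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0

attn = Attention(query_dim=64, heads=4, dim_head=16)
attn.set_processor(CustomAttnProcessor2_0())

hidden_states = torch.randn(1, 32, 64)
# `scale` is no longer accepted here; with the PEFT backend, LoRA scaling is applied
# on the adapter layers themselves instead of being passed through the processor.
out = attn(hidden_states)
print(out.shape)  # torch.Size([1, 32, 64])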