diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
index 3386c72556..ac53bf911d 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
@@ -1,5 +1,6 @@
+from dataclasses import dataclass
 from itertools import cycle, islice
-from typing import List, Optional, TypedDict, cast
+from typing import List, Optional, cast
 
 import torch
 import torch.nn.functional as F
@@ -10,7 +11,8 @@ from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import Regiona
 from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
 
 
-class IPAdapterAttentionWeights(TypedDict):
+@dataclass
+class IPAdapterAttentionWeights:
     ip_adapter_weights: List[IPAttentionProcessorWeights]
     skip: bool
 
@@ -63,7 +65,6 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         is_cross_attention = encoder_hidden_states is not None
 
         # Start unmodified block from AttnProcessor2_0.
-        # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
         residual = hidden_states
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -77,7 +78,6 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         batch_size, sequence_length, _ = (
             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         )
-        # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
         # End unmodified block from AttnProcessor2_0.
 
         _, query_seq_len, _ = hidden_states.shape
@@ -140,22 +140,22 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
             ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
 
             # Pad weight tensor list to match size of regional embeds
-            self._ip_adapter_attention_weights["ip_adapter_weights"] = list(
+            self._ip_adapter_attention_weights.ip_adapter_weights = list(
                 islice(
-                    cycle(self._ip_adapter_attention_weights["ip_adapter_weights"]),
+                    cycle(self._ip_adapter_attention_weights.ip_adapter_weights),
                     len(regional_ip_data.image_prompt_embeds),
                 )
             )
 
             assert (
                 len(regional_ip_data.image_prompt_embeds)
-                == len(self._ip_adapter_attention_weights["ip_adapter_weights"])
+                == len(self._ip_adapter_attention_weights.ip_adapter_weights)
                 == len(regional_ip_data.scales)
                 == ip_masks.shape[1]
             )
 
             for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
-                ipa_weights = self._ip_adapter_attention_weights["ip_adapter_weights"][ipa_index]
+                ipa_weights = self._ip_adapter_attention_weights.ip_adapter_weights[ipa_index]
                 ipa_scale = regional_ip_data.scales[ipa_index]
                 ip_mask = ip_masks[0, ipa_index, ...]
 
@@ -168,7 +168,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
 
                 # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
 
-                if not self._ip_adapter_attention_weights["skip"]:
+                if not self._ip_adapter_attention_weights.skip:
                     ip_key = ipa_weights.to_k_ip(ip_hidden_states)
                     ip_value = ipa_weights.to_v_ip(ip_hidden_states)
 
@@ -215,5 +215,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         hidden_states = hidden_states + residual
 
         hidden_states = hidden_states / attn.rescale_output_factor
 
+        # End of unmodified block from AttnProcessor2_0
+        # casting torch.Tensor to torch.FloatTensor to avoid type issues
         return cast(torch.FloatTensor, hidden_states)
diff --git a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
index 52cfc2c573..df9b8d6386 100644
--- a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
+++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
@@ -33,7 +33,7 @@ class UNetAttentionPatcher:
                 # "attn1" processors do not use IP-Adapters.
                 attn_procs[name] = CustomAttnProcessor2_0()
             else:
-                ip_adapter_attention_weights: IPAdapterAttentionWeights = {"ip_adapter_weights": [], "skip": False}
+                ip_adapter_attention_weights = IPAdapterAttentionWeights(ip_adapter_weights=[], skip=False)
                 for ip_adapter in self._ip_adapters:
                     ip_adapter_weight = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
                     skip = True
@@ -41,8 +41,8 @@ class UNetAttentionPatcher:
                         if block in name:
                             skip = False
                             break
-
-                    ip_adapter_attention_weights.update({"ip_adapter_weights": [ip_adapter_weight], "skip": skip})
+                    ip_adapter_attention_weights.ip_adapter_weights = [ip_adapter_weight]
+                    ip_adapter_attention_weights.skip = skip
 
                 # Collect the weights from each IP Adapter for the idx'th attention processor.
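
For readers skimming the patch, here is a minimal, self-contained sketch (not the InvokeAI module itself) of the change this diff makes: IPAdapterAttentionWeights moves from a TypedDict to a dataclass, so call sites switch from dict subscripting and update() to attribute access. FakeProcessorWeights below is a hypothetical stand-in for IPAttentionProcessorWeights, which is defined elsewhere in the codebase.

    # Hedged sketch only: mirrors the dataclass introduced above.
    from dataclasses import dataclass
    from typing import List


    class FakeProcessorWeights:
        """Hypothetical placeholder for the per-layer IP-Adapter projection weights."""


    @dataclass
    class IPAdapterAttentionWeights:
        ip_adapter_weights: List[FakeProcessorWeights]
        skip: bool


    # Construction, as UNetAttentionPatcher now does it:
    weights = IPAdapterAttentionWeights(ip_adapter_weights=[], skip=False)

    # Updates become plain attribute assignments instead of dict.update():
    weights.ip_adapter_weights = [FakeProcessorWeights()]
    weights.skip = True

    # Unlike a TypedDict value, the instance is not a dict at runtime,
    # so weights["skip"] would raise a TypeError; use weights.skip instead.
    assert not isinstance(weights, dict)

A TypedDict exists only for the type checker (instances are ordinary dicts at runtime), so the dataclass makes the structure enforceable at runtime, which is why the attention processor now reads self._ip_adapter_attention_weights.skip rather than ["skip"].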