From c2d43f007b00a45893937463eb94184cc507e0f9 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Thu, 7 Sep 2023 18:20:21 -0400
Subject: [PATCH] Specify the image_embedding_len in the IPAttnProcessor rather
 than the text embedding length. This enables the IPAttnProcessor to handle
 text embeddings of varying lengths.

---
 .../backend/ip_adapter/attention_processor.py | 28 ++++++++++---------
 invokeai/backend/ip_adapter/ip_adapter.py     |  5 +++-
 2 files changed, 19 insertions(+), 14 deletions(-)

diff --git a/invokeai/backend/ip_adapter/attention_processor.py b/invokeai/backend/ip_adapter/attention_processor.py
index 60aaeec7b0..99d9edc5dd 100644
--- a/invokeai/backend/ip_adapter/attention_processor.py
+++ b/invokeai/backend/ip_adapter/attention_processor.py
@@ -32,20 +32,21 @@ class IPAttnProcessor(nn.Module):
     Args:
         hidden_size (`int`):
             The hidden size of the attention layer.
+        image_embedding_len (`int`):
+            The length of the IP-Adapter image embedding. It is assumed that the last `image_embedding_len` 'tokens' of
+            the `encoder_hidden_states` are the IP-Adapter image embeddings.
         cross_attention_dim (`int`):
             The number of channels in the `encoder_hidden_states`.
-        text_context_len (`int`, defaults to 77):
-            The context length of the text features.
         scale (`float`, defaults to 1.0):
             the weight scale of image prompt.
     """
 
-    def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0):
+    def __init__(self, hidden_size, image_embedding_len, cross_attention_dim=None, scale=1.0):
         super().__init__()
 
         self.hidden_size = hidden_size
         self.cross_attention_dim = cross_attention_dim
-        self.text_context_len = text_context_len
+        self.image_embedding_len = image_embedding_len
         self.scale = scale
 
         self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
@@ -85,10 +86,10 @@ class IPAttnProcessor(nn.Module):
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
-        # split hidden states
+        # Split text encoder hidden states and image encoder hidden state.
         encoder_hidden_states, ip_hidden_states = (
-            encoder_hidden_states[:, : self.text_context_len, :],
-            encoder_hidden_states[:, self.text_context_len :, :],
+            encoder_hidden_states[:, : -self.image_embedding_len, :],
+            encoder_hidden_states[:, -self.image_embedding_len :, :],
         )
 
         key = attn.to_k(encoder_hidden_states)
@@ -137,15 +138,16 @@ class IPAttnProcessor2_0(torch.nn.Module):
     Args:
         hidden_size (`int`):
             The hidden size of the attention layer.
+        image_embedding_len (`int`):
+            The length of the IP-Adapter image embedding. It is assumed that the last `image_embedding_len` 'tokens' of
+            the `encoder_hidden_states` are the IP-Adapter image embeddings.
         cross_attention_dim (`int`):
             The number of channels in the `encoder_hidden_states`.
-        text_context_len (`int`, defaults to 77):
-            The context length of the text features.
         scale (`float`, defaults to 1.0):
             the weight scale of image prompt.
     """
 
-    def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0):
+    def __init__(self, hidden_size, image_embedding_len, cross_attention_dim=None, scale=1.0):
         super().__init__()
 
         if not hasattr(F, "scaled_dot_product_attention"):
@@ -198,10 +200,10 @@ class IPAttnProcessor2_0(torch.nn.Module):
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
-        # split hidden states
+        # Split text encoder hidden states and image encoder hidden state.
         encoder_hidden_states, ip_hidden_states = (
-            encoder_hidden_states[:, : self.text_context_len, :],
-            encoder_hidden_states[:, self.text_context_len :, :],
+            encoder_hidden_states[:, : -self.image_embedding_len, :],
+            encoder_hidden_states[:, -self.image_embedding_len :, :],
        )
 
         key = attn.to_k(encoder_hidden_states)
diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py
index 8261338058..a9fcc25539 100644
--- a/invokeai/backend/ip_adapter/ip_adapter.py
+++ b/invokeai/backend/ip_adapter/ip_adapter.py
@@ -91,7 +91,10 @@ class IPAdapter:
             else:
                 print("swapping in IPAttnProcessor for", name)
                 attn_procs[name] = IPAttnProcessor(
-                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
+                    hidden_size=hidden_size,
+                    image_embedding_len=self.num_tokens,
+                    cross_attention_dim=cross_attention_dim,
+                    scale=1.0,
                 ).to(self.device, dtype=torch.float16)
         unet.set_attn_processor(attn_procs)
         print("Modified UNet Attn Processors count:", len(unet.attn_processors))
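
For reference, a minimal standalone sketch (not part of the patch) of the end-relative split that the updated processors perform. The batch size, hidden dimension, and prompt length below are hypothetical; `image_embedding_len` corresponds to the `num_tokens` value passed in ip_adapter.py.

# Hypothetical sizes, for illustration only.
import torch

batch_size, hidden_dim = 2, 768
image_embedding_len = 4   # plays the role of IPAdapter.num_tokens
text_len = 154            # any prompt length; no longer assumed to be 77 tokens

text_embeds = torch.randn(batch_size, text_len, hidden_dim)
image_embeds = torch.randn(batch_size, image_embedding_len, hidden_dim)
encoder_hidden_states = torch.cat([text_embeds, image_embeds], dim=1)

# The same split the patched processors now perform: slice relative to the
# end of the sequence instead of assuming a fixed-length text context.
text_part = encoder_hidden_states[:, :-image_embedding_len, :]
ip_part = encoder_hidden_states[:, -image_embedding_len:, :]

assert text_part.shape == (batch_size, text_len, hidden_dim)
assert ip_part.shape == (batch_size, image_embedding_len, hidden_dim)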