From 8cd81e52be8d08f34cbb9e3951e615d90314625f Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Sun, 21 Apr 2024 01:55:50 -0400
Subject: [PATCH] wip

---
 .../diffusion/custom_atttention.py    | 19 +++++++++++++++++++
 .../diffusion/regional_prompt_data.py |  8 ++++----
 2 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
index 5ff7b9a2ca..9857bab0b9 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
@@ -69,6 +69,11 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
             prompt_masks = regional_prompt_data.get_masks(query_seq_len=query_seq_len)
             encoder_hidden_states = regional_prompt_data.text_embeds
 
+            hidden_states_stack = []
+            for batch_idx, prompt_count in enumerate(regional_prompt_data.prompt_count_by_batch_element):
+                hidden_states_stack.append(hidden_states[batch_idx : batch_idx + 1].repeat((prompt_count, 1, 1)))
+            hidden_states = torch.cat(hidden_states_stack, dim=0)
+
         # Start unmodified block from AttnProcessor2_0.
         # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
         residual = hidden_states
@@ -148,8 +153,22 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
         # End unmodified block from AttnProcessor2_0.
 
+        print("todo")
+
         # Apply IP-Adapter conditioning.
         if is_cross_attention:
+            if regional_prompt_data is not None:
+                outputs = []
+                cur_idx = 0
+                for prompt_count in regional_prompt_data.prompt_count_by_batch_element:
+                    cur_prompt_masks = prompt_masks[cur_idx : cur_idx + prompt_count]
+                    cur_prompt_masks = cur_prompt_masks.view(-1, hidden_states.shape[1], 1)
+                    masked_output = hidden_states[cur_idx : cur_idx + prompt_count] * cur_prompt_masks
+                    masked_output = masked_output.sum(dim=0, keepdim=True)
+                    outputs.append(masked_output)
+                    cur_idx += prompt_count
+                hidden_states = torch.cat(outputs, dim=0)
+
             if self._ip_adapter_attention_weights:
                 assert regional_ip_data is not None
                 ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py
index af396e7c2f..3dd75b82a5 100644
--- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py
@@ -13,7 +13,6 @@ class RegionalPromptData:
         masks: list[list[torch.Tensor]],
         device: torch.device,
         dtype: torch.dtype,
-        max_downscale_factor: int = 8,
     ):
         """Initialize a `RegionalPromptData` object.
         Args:
@@ -49,15 +48,16 @@ class RegionalPromptData:
         for mask_batch in masks:
             masks_flat_list.extend(mask_batch)
         self._masks = torch.cat(masks_flat_list, dim=0)
+        # TODO(ryand): Is this necessary? Do we need to do the same for text_embeds?
+        self._masks = self._masks.to(dtype=dtype, device=device)
 
         self._device = device
         self._dtype = dtype
 
-    def get_masks(self, query_seq_len: int):
-        _, h, w = self._masks.shape
+    def get_masks(self, query_seq_len: int, max_downscale_factor: int = 8) -> torch.Tensor:
+        _, _, h, w = self._masks.shape
 
         # Determine the downscaling factor for the given query sequence length.
-        max_downscale_factor = 8
         downscale_factor = 1
         while downscale_factor <= max_downscale_factor:
             if query_seq_len == (h // downscale_factor) * (w // downscale_factor):
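
Note (reviewer illustration, not part of the patch): the diff touches cross-attention in two places. Before the unmodified AttnProcessor2_0 block, each batch element's hidden states are repeated once per regional prompt, so every prompt attends against its own copy of the query sequence; after the block, the per-prompt outputs are collapsed back to one output per batch element by masking each prompt's output with its region mask and summing. The sketch below shows that expand/attend/merge pattern in isolation under assumed tensor shapes: `regional_cross_attention` and `attn_fn` are hypothetical names standing in for the processor body, while `prompt_masks` and `prompt_count_by_batch_element` follow the names used in the patch.

```python
import torch


def regional_cross_attention(
    hidden_states: torch.Tensor,  # (batch, query_seq_len, dim)
    prompt_masks: torch.Tensor,  # (sum(prompt_counts), query_seq_len), 1.0 inside a prompt's region
    prompt_count_by_batch_element: list[int],
    attn_fn,  # stand-in for the unmodified attention body: (N, query_seq_len, dim) -> same shape
) -> torch.Tensor:
    # 1) Expand: repeat each batch element once per regional prompt.
    expanded = torch.cat(
        [
            hidden_states[i : i + 1].repeat((count, 1, 1))
            for i, count in enumerate(prompt_count_by_batch_element)
        ],
        dim=0,
    )

    # 2) Attend: run the (unchanged) attention body on the expanded batch.
    attended = attn_fn(expanded)

    # 3) Merge: mask each prompt's output by its region and sum the prompts that
    #    belong to the same batch element back into a single sequence.
    outputs = []
    cur_idx = 0
    for count in prompt_count_by_batch_element:
        masks = prompt_masks[cur_idx : cur_idx + count].view(-1, attended.shape[1], 1)
        merged = (attended[cur_idx : cur_idx + count] * masks).sum(dim=0, keepdim=True)
        outputs.append(merged)
        cur_idx += count
    return torch.cat(outputs, dim=0)


# Example: batch of 2, with 2 regional prompts for the first element and 1 for the second.
hs = torch.randn(2, 64, 320)
region_masks = torch.rand(3, 64)
out = regional_cross_attention(hs, region_masks, [2, 1], attn_fn=lambda x: x)
assert out.shape == (2, 64, 320)
```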