This commit is contained in:
Ryan Dick 2024-04-21 01:55:50 -04:00
parent d183aa823c
commit 8cd81e52be
2 changed files with 23 additions and 4 deletions

View File

@ -69,6 +69,11 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
prompt_masks = regional_prompt_data.get_masks(query_seq_len=query_seq_len) prompt_masks = regional_prompt_data.get_masks(query_seq_len=query_seq_len)
encoder_hidden_states = regional_prompt_data.text_embeds encoder_hidden_states = regional_prompt_data.text_embeds
hidden_states_stack = []
for batch_idx, prompt_count in enumerate(regional_prompt_data.prompt_count_by_batch_element):
hidden_states_stack.append(hidden_states[batch_idx : batch_idx + 1].repeat((prompt_count, 1, 1)))
hidden_states = torch.cat(hidden_states_stack, dim=0)
# Start unmodified block from AttnProcessor2_0. # Start unmodified block from AttnProcessor2_0.
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
residual = hidden_states residual = hidden_states
@ -148,8 +153,22 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# End unmodified block from AttnProcessor2_0. # End unmodified block from AttnProcessor2_0.
print("todo")
# Apply IP-Adapter conditioning. # Apply IP-Adapter conditioning.
if is_cross_attention: if is_cross_attention:
if regional_prompt_data is not None:
outputs = []
cur_idx = 0
for prompt_count in regional_prompt_data.prompt_count_by_batch_element:
cur_prompt_masks = prompt_masks[cur_idx : cur_idx + prompt_count]
cur_prompt_masks = cur_prompt_masks.view(-1, hidden_states.shape[1], 1)
masked_output = hidden_states[cur_idx : cur_idx + prompt_count] * cur_prompt_masks
masked_output = masked_output.sum(dim=0, keepdim=True)
outputs.append(masked_output)
cur_idx += prompt_count
hidden_states = torch.cat(outputs, dim=0)
if self._ip_adapter_attention_weights: if self._ip_adapter_attention_weights:
assert regional_ip_data is not None assert regional_ip_data is not None
ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len) ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)

View File

@ -13,7 +13,6 @@ class RegionalPromptData:
masks: list[list[torch.Tensor]], masks: list[list[torch.Tensor]],
device: torch.device, device: torch.device,
dtype: torch.dtype, dtype: torch.dtype,
max_downscale_factor: int = 8,
): ):
"""Initialize a `RegionalPromptData` object. """Initialize a `RegionalPromptData` object.
Args: Args:
@ -49,15 +48,16 @@ class RegionalPromptData:
for mask_batch in masks: for mask_batch in masks:
masks_flat_list.extend(mask_batch) masks_flat_list.extend(mask_batch)
self._masks = torch.cat(masks_flat_list, dim=0) self._masks = torch.cat(masks_flat_list, dim=0)
# TODO(ryand): Is this necessary? Do we need to do the same for text_embeds?
self._masks = self._masks.to(dtype=dtype, device=device)
self._device = device self._device = device
self._dtype = dtype self._dtype = dtype
def get_masks(self, query_seq_len: int): def get_masks(self, query_seq_len: int, max_downscale_factor: int = 8) -> torch.Tensor:
_, h, w = self._masks.shape _, _, h, w = self._masks.shape
# Determine the downscaling factor for the given query sequence length. # Determine the downscaling factor for the given query sequence length.
max_downscale_factor = 8
downscale_factor = 1 downscale_factor = 1
while downscale_factor <= max_downscale_factor: while downscale_factor <= max_downscale_factor:
if query_seq_len == (h // downscale_factor) * (w // downscale_factor): if query_seq_len == (h // downscale_factor) * (w // downscale_factor):