mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
wip
This commit is contained in:
parent
d183aa823c
commit
8cd81e52be
@ -69,6 +69,11 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
prompt_masks = regional_prompt_data.get_masks(query_seq_len=query_seq_len)
|
||||
encoder_hidden_states = regional_prompt_data.text_embeds
|
||||
|
||||
hidden_states_stack = []
|
||||
for batch_idx, prompt_count in enumerate(regional_prompt_data.prompt_count_by_batch_element):
|
||||
hidden_states_stack.append(hidden_states[batch_idx : batch_idx + 1].repeat((prompt_count, 1, 1)))
|
||||
hidden_states = torch.cat(hidden_states_stack, dim=0)
|
||||
|
||||
# Start unmodified block from AttnProcessor2_0.
|
||||
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
|
||||
residual = hidden_states
|
||||
@ -148,8 +153,22 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
# End unmodified block from AttnProcessor2_0.
|
||||
|
||||
print("todo")
|
||||
|
||||
# Apply IP-Adapter conditioning.
|
||||
if is_cross_attention:
|
||||
if regional_prompt_data is not None:
|
||||
outputs = []
|
||||
cur_idx = 0
|
||||
for prompt_count in regional_prompt_data.prompt_count_by_batch_element:
|
||||
cur_prompt_masks = prompt_masks[cur_idx : cur_idx + prompt_count]
|
||||
cur_prompt_masks = cur_prompt_masks.view(-1, hidden_states.shape[1], 1)
|
||||
masked_output = hidden_states[cur_idx : cur_idx + prompt_count] * cur_prompt_masks
|
||||
masked_output = masked_output.sum(dim=0, keepdim=True)
|
||||
outputs.append(masked_output)
|
||||
cur_idx += prompt_count
|
||||
hidden_states = torch.cat(outputs, dim=0)
|
||||
|
||||
if self._ip_adapter_attention_weights:
|
||||
assert regional_ip_data is not None
|
||||
ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
|
||||
|
@ -13,7 +13,6 @@ class RegionalPromptData:
|
||||
masks: list[list[torch.Tensor]],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
max_downscale_factor: int = 8,
|
||||
):
|
||||
"""Initialize a `RegionalPromptData` object.
|
||||
Args:
|
||||
@ -49,15 +48,16 @@ class RegionalPromptData:
|
||||
for mask_batch in masks:
|
||||
masks_flat_list.extend(mask_batch)
|
||||
self._masks = torch.cat(masks_flat_list, dim=0)
|
||||
# TODO(ryand): Is this necessary? Do we need to do the same for text_embeds?
|
||||
self._masks = self._masks.to(dtype=dtype, device=device)
|
||||
|
||||
self._device = device
|
||||
self._dtype = dtype
|
||||
|
||||
def get_masks(self, query_seq_len: int):
|
||||
_, h, w = self._masks.shape
|
||||
def get_masks(self, query_seq_len: int, max_downscale_factor: int = 8) -> torch.Tensor:
|
||||
_, _, h, w = self._masks.shape
|
||||
|
||||
# Determine the downscaling factor for the given query sequence length.
|
||||
max_downscale_factor = 8
|
||||
downscale_factor = 1
|
||||
while downscale_factor <= max_downscale_factor:
|
||||
if query_seq_len == (h // downscale_factor) * (w // downscale_factor):
|
||||
|
Loading…
Reference in New Issue
Block a user