mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
wip
This commit is contained in:
parent
d183aa823c
commit
8cd81e52be
@ -69,6 +69,11 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
|||||||
prompt_masks = regional_prompt_data.get_masks(query_seq_len=query_seq_len)
|
prompt_masks = regional_prompt_data.get_masks(query_seq_len=query_seq_len)
|
||||||
encoder_hidden_states = regional_prompt_data.text_embeds
|
encoder_hidden_states = regional_prompt_data.text_embeds
|
||||||
|
|
||||||
|
hidden_states_stack = []
|
||||||
|
for batch_idx, prompt_count in enumerate(regional_prompt_data.prompt_count_by_batch_element):
|
||||||
|
hidden_states_stack.append(hidden_states[batch_idx : batch_idx + 1].repeat((prompt_count, 1, 1)))
|
||||||
|
hidden_states = torch.cat(hidden_states_stack, dim=0)
|
||||||
|
|
||||||
# Start unmodified block from AttnProcessor2_0.
|
# Start unmodified block from AttnProcessor2_0.
|
||||||
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
|
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@ -148,8 +153,22 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
|||||||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
# End unmodified block from AttnProcessor2_0.
|
# End unmodified block from AttnProcessor2_0.
|
||||||
|
|
||||||
|
print("todo")
|
||||||
|
|
||||||
# Apply IP-Adapter conditioning.
|
# Apply IP-Adapter conditioning.
|
||||||
if is_cross_attention:
|
if is_cross_attention:
|
||||||
|
if regional_prompt_data is not None:
|
||||||
|
outputs = []
|
||||||
|
cur_idx = 0
|
||||||
|
for prompt_count in regional_prompt_data.prompt_count_by_batch_element:
|
||||||
|
cur_prompt_masks = prompt_masks[cur_idx : cur_idx + prompt_count]
|
||||||
|
cur_prompt_masks = cur_prompt_masks.view(-1, hidden_states.shape[1], 1)
|
||||||
|
masked_output = hidden_states[cur_idx : cur_idx + prompt_count] * cur_prompt_masks
|
||||||
|
masked_output = masked_output.sum(dim=0, keepdim=True)
|
||||||
|
outputs.append(masked_output)
|
||||||
|
cur_idx += prompt_count
|
||||||
|
hidden_states = torch.cat(outputs, dim=0)
|
||||||
|
|
||||||
if self._ip_adapter_attention_weights:
|
if self._ip_adapter_attention_weights:
|
||||||
assert regional_ip_data is not None
|
assert regional_ip_data is not None
|
||||||
ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
|
ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
|
||||||
|
@ -13,7 +13,6 @@ class RegionalPromptData:
|
|||||||
masks: list[list[torch.Tensor]],
|
masks: list[list[torch.Tensor]],
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
max_downscale_factor: int = 8,
|
|
||||||
):
|
):
|
||||||
"""Initialize a `RegionalPromptData` object.
|
"""Initialize a `RegionalPromptData` object.
|
||||||
Args:
|
Args:
|
||||||
@ -49,15 +48,16 @@ class RegionalPromptData:
|
|||||||
for mask_batch in masks:
|
for mask_batch in masks:
|
||||||
masks_flat_list.extend(mask_batch)
|
masks_flat_list.extend(mask_batch)
|
||||||
self._masks = torch.cat(masks_flat_list, dim=0)
|
self._masks = torch.cat(masks_flat_list, dim=0)
|
||||||
|
# TODO(ryand): Is this necessary? Do we need to do the same for text_embeds?
|
||||||
|
self._masks = self._masks.to(dtype=dtype, device=device)
|
||||||
|
|
||||||
self._device = device
|
self._device = device
|
||||||
self._dtype = dtype
|
self._dtype = dtype
|
||||||
|
|
||||||
def get_masks(self, query_seq_len: int):
|
def get_masks(self, query_seq_len: int, max_downscale_factor: int = 8) -> torch.Tensor:
|
||||||
_, h, w = self._masks.shape
|
_, _, h, w = self._masks.shape
|
||||||
|
|
||||||
# Determine the downscaling factor for the given query sequence length.
|
# Determine the downscaling factor for the given query sequence length.
|
||||||
max_downscale_factor = 8
|
|
||||||
downscale_factor = 1
|
downscale_factor = 1
|
||||||
while downscale_factor <= max_downscale_factor:
|
while downscale_factor <= max_downscale_factor:
|
||||||
if query_seq_len == (h // downscale_factor) * (w // downscale_factor):
|
if query_seq_len == (h // downscale_factor) * (w // downscale_factor):
|
||||||
|
Loading…
Reference in New Issue
Block a user