mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
cleanup
This commit is contained in:
parent
1f5ad1b05e
commit
34a3f4a820
@ -307,9 +307,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if is_xformers_available() and not Globals.disable_xformers:
|
||||
self.enable_xformers_memory_efficient_attention()
|
||||
else:
|
||||
slice_size = 2
|
||||
slice_size = 4 # or 2, or 8. i chose this arbitrarily.
|
||||
self.enable_attention_slicing(slice_size=slice_size)
|
||||
|
||||
|
||||
def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
|
||||
conditioning_data: ConditioningData,
|
||||
*,
|
||||
|
@ -560,77 +560,6 @@ class SwapCrossAttnContext:
|
||||
return mask, indices
|
||||
|
||||
|
||||
class SwapCrossAttnProcessor(CrossAttnProcessor):
|
||||
|
||||
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
|
||||
# kwargs
|
||||
swap_cross_attn_context: SwapCrossAttnContext=None):
|
||||
|
||||
attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
|
||||
|
||||
# if cross-attention control is not in play, just call through to the base implementation.
|
||||
if swap_cross_attn_context is None or not swap_cross_attn_context.wants_cross_attention_control(attention_type):
|
||||
#print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
|
||||
return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
|
||||
#else:
|
||||
# print(f"SwapCrossAttnContext for {attention_type} active")
|
||||
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
query = attn.head_to_batch_dim(query)
|
||||
|
||||
# helper function
|
||||
def get_attention_probs(embeddings):
|
||||
this_key = attn.to_k(embeddings)
|
||||
this_key = attn.head_to_batch_dim(this_key)
|
||||
return attn.get_attention_scores(query, this_key, attention_mask)
|
||||
|
||||
if attention_type == CrossAttentionType.SELF:
|
||||
# self attention has no remapping, it just bluntly copies the whole tensor
|
||||
attention_probs = get_attention_probs(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
else:
|
||||
# tokens (cross) attention
|
||||
# first, find attention probabilities for the "original" prompt
|
||||
original_text_embeddings = encoder_hidden_states
|
||||
original_attention_probs = get_attention_probs(original_text_embeddings)
|
||||
|
||||
# then, find attention probabilities for the "modified" prompt
|
||||
modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
|
||||
modified_attention_probs = get_attention_probs(modified_text_embeddings)
|
||||
|
||||
# because the prompt modifications may result in token sequences shifted forwards or backwards,
|
||||
# the original attention probabilities must be remapped to account for token index changes in the
|
||||
# modified prompt
|
||||
remapped_original_attention_probs = torch.index_select(original_attention_probs, -1,
|
||||
swap_cross_attn_context.index_map)
|
||||
|
||||
# only some tokens taken from the original attention probabilities. this is controlled by the mask.
|
||||
mask = swap_cross_attn_context.mask
|
||||
inverse_mask = 1 - mask
|
||||
attention_probs = \
|
||||
remapped_original_attention_probs * mask + \
|
||||
modified_attention_probs * inverse_mask
|
||||
|
||||
# for the "value" just use the modified text embeddings.
|
||||
value = attn.to_v(modified_text_embeddings)
|
||||
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
|
||||
|
||||
# TODO: dynamically pick slice size based on memory conditions
|
||||
@ -714,3 +643,9 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser):
|
||||
|
||||
def __init__(self):
|
||||
super(SwapCrossAttnProcessor, self).__init__(slice_size=1e6) # big number so we never slice
|
||||
|
Loading…
Reference in New Issue
Block a user