mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
be more aggressive at clearing out saved_attn_slice
This commit is contained in:
parent
c7de2b2801
commit
214d276379
@ -116,6 +116,7 @@ class CrossAttentionControl:
|
|||||||
for m in self_attention_modules+tokens_attention_modules:
|
for m in self_attention_modules+tokens_attention_modules:
|
||||||
m.save_last_attn_slice = False
|
m.save_last_attn_slice = False
|
||||||
m.use_last_attn_slice = False
|
m.use_last_attn_slice = False
|
||||||
|
m.last_attn_slice = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def request_save_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
|
def request_save_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
|
||||||
|
@ -134,23 +134,31 @@ class InvokeAIDiffuserComponent:
|
|||||||
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
|
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
|
||||||
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
|
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
|
||||||
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
|
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
|
||||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
|
|
||||||
|
|
||||||
# process x using the original prompt, saving the attention maps
|
try:
|
||||||
for type in cross_attention_control_types_to_do:
|
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
|
||||||
CrossAttentionControl.request_save_attention_maps(self.model, type)
|
|
||||||
_ = self.model_forward_callback(x, sigma, conditioning)
|
|
||||||
CrossAttentionControl.clear_requests(self.model)
|
|
||||||
|
|
||||||
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
|
# process x using the original prompt, saving the attention maps
|
||||||
for type in cross_attention_control_types_to_do:
|
for type in cross_attention_control_types_to_do:
|
||||||
CrossAttentionControl.request_apply_saved_attention_maps(self.model, type)
|
CrossAttentionControl.request_save_attention_maps(self.model, type)
|
||||||
edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
|
_ = self.model_forward_callback(x, sigma, conditioning)
|
||||||
conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
|
CrossAttentionControl.clear_requests(self.model)
|
||||||
|
|
||||||
CrossAttentionControl.clear_requests(self.model)
|
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
|
||||||
|
for type in cross_attention_control_types_to_do:
|
||||||
|
CrossAttentionControl.request_apply_saved_attention_maps(self.model, type)
|
||||||
|
edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
|
||||||
|
conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
|
||||||
|
|
||||||
return unconditioned_next_x, conditioned_next_x
|
CrossAttentionControl.clear_requests(self.model)
|
||||||
|
|
||||||
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
|
except RuntimeError:
|
||||||
|
# make sure we clean out the attention slices we're storing on the model
|
||||||
|
# TODO don't store things on the model
|
||||||
|
CrossAttentionControl.clear_requests(self.model)
|
||||||
|
raise
|
||||||
|
|
||||||
def estimate_percent_through(self, step_index, sigma):
|
def estimate_percent_through(self, step_index, sigma):
|
||||||
if step_index is not None and self.cross_attention_control_context is not None:
|
if step_index is not None and self.cross_attention_control_context is not None:
|
||||||
|
Loading…
Reference in New Issue
Block a user