Merge branch 'development' of github.com:invoke-ai/InvokeAI into development

commit 6215592b12
Author: Lincoln Stein
Date:   2022-11-01 17:34:55 -04:00
2 changed files with 24 additions and 14 deletions


@@ -110,12 +110,14 @@ class CrossAttentionControl:
                 type(module).__name__ == "CrossAttention" and which_attn in name]
 
     @classmethod
-    def clear_requests(cls, model):
+    def clear_requests(cls, model, clear_attn_slice=True):
         self_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.SELF)
         tokens_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS)
         for m in self_attention_modules+tokens_attention_modules:
             m.save_last_attn_slice = False
             m.use_last_attn_slice = False
+            if clear_attn_slice:
+                m.last_attn_slice = None
 
     @classmethod
     def request_save_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
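
For context, the reason clear_requests() grew a clear_attn_slice flag: the save pass and the apply pass are two separate forward calls, so the request flags must be reset between them without dropping the cached slice. Below is a minimal, self-contained sketch of that lifecycle; FakeCrossAttention is a hypothetical stand-in, not InvokeAI's real CrossAttention module.

class FakeCrossAttention:
    # Hypothetical stand-in exposing the three attributes clear_requests() touches.
    def __init__(self):
        self.save_last_attn_slice = False
        self.use_last_attn_slice = False
        self.last_attn_slice = None

    def forward(self, attn):
        if self.save_last_attn_slice:
            self.last_attn_slice = attn        # cache this pass's attention map
        if self.use_last_attn_slice and self.last_attn_slice is not None:
            attn = self.last_attn_slice        # substitute the cached map
        return attn

m = FakeCrossAttention()

m.save_last_attn_slice = True                  # pass 1: original prompt, save maps
m.forward("original-attention")

m.save_last_attn_slice = False                 # clear_requests(..., clear_attn_slice=False):
m.use_last_attn_slice = False                  # reset the flags but keep the slice

m.use_last_attn_slice = True                   # pass 2: edited prompt, apply saved maps
assert m.forward("edited-attention") == "original-attention"

m.last_attn_slice = None                       # clear_requests(..., clear_attn_slice=True): release the cache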


@@ -134,13 +134,15 @@ class InvokeAIDiffuserComponent:
         # representing batched uncond + cond, but then when it comes to applying the saved attention, the
         # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
         # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
-        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
+        try:
+            unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
 
-        # process x using the original prompt, saving the attention maps
-        for type in cross_attention_control_types_to_do:
-            CrossAttentionControl.request_save_attention_maps(self.model, type)
-        _ = self.model_forward_callback(x, sigma, conditioning)
-        CrossAttentionControl.clear_requests(self.model)
+            # process x using the original prompt, saving the attention maps
+            for type in cross_attention_control_types_to_do:
+                CrossAttentionControl.request_save_attention_maps(self.model, type)
+            _ = self.model_forward_callback(x, sigma, conditioning)
+            CrossAttentionControl.clear_requests(self.model, clear_attn_slice=False)
 
-        # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
-        for type in cross_attention_control_types_to_do:
+            # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
+            for type in cross_attention_control_types_to_do:
@@ -152,6 +154,12 @@ class InvokeAIDiffuserComponent:
-        return unconditioned_next_x, conditioned_next_x
+            return unconditioned_next_x, conditioned_next_x
+
+        except RuntimeError:
+            # make sure we clean out the attention slices we're storing on the model
+            # TODO don't store things on the model
+            CrossAttentionControl.clear_requests(self.model)
+            raise
 
     def estimate_percent_through(self, step_index, sigma):
         if step_index is not None and self.cross_attention_control_context is not None:
             # percent_through will never reach 1.0 (but this is intended)
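
The new except RuntimeError handler matters because PyTorch reports CUDA failures (out-of-memory in particular) as RuntimeError; without it, an aborted step leaves stale attention slices attached to the model, pinning GPU memory. A rough sketch of the pattern, assuming hypothetical controlled_step/run_passes helpers that are not InvokeAI APIs:

def controlled_step(model, run_passes, clear_requests):
    # run_passes performs the save/apply forward calls; clear_requests is the
    # cleanup from the first hunk. Both parameter names are illustrative only.
    try:
        return run_passes(model)
    except RuntimeError:
        # e.g. CUDA OOM: make sure cached slices don't outlive the failed step
        clear_requests(model)  # default clear_attn_slice=True also drops the slices
        raise                  # re-raise so the caller still sees the failure

Catching and re-raising, rather than using a finally block, leaves the success path untouched: the happy path already resets its requests between the two forward calls and must keep the saved slices alive in the meantime.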