mirror of https://github.com/invoke-ai/InvokeAI
cleanup and clarify comments

commit 09f62032ec
parent 711ffd238f
@@ -74,6 +74,8 @@ class CrossAttentionControl:
         self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF)
         tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS)
         for m in self_attention_modules+tokens_attention_modules:
+            # clear out the saved slice in case the outermost dim changes
+            m.last_attn_slice = None
             m.save_last_attn_slice = True

     @classmethod
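For readers unfamiliar with the mechanism this hunk touches: get_attention_modules, last_attn_slice, and save_last_attn_slice implement a flag-based protocol in which the controller asks each attention module to cache its attention map on the next forward pass. A minimal sketch of that protocol, assuming hypothetical module discovery by name (an illustration, not InvokeAI's actual implementation):

    import torch.nn as nn

    def get_attention_modules_sketch(model: nn.Module, name_filter: str):
        # hypothetical discovery: in Stable Diffusion's unet, 'attn1' modules do
        # self-attention and 'attn2' modules do token (cross) attention
        return [m for name, m in model.named_modules() if name.endswith(name_filter)]

    def request_save_attention_maps_sketch(model: nn.Module):
        modules = get_attention_modules_sketch(model, 'attn1') + \
                  get_attention_modules_sketch(model, 'attn2')
        for m in modules:
            m.last_attn_slice = None       # clear the stale slice: the outermost dim may change
            m.save_last_attn_slice = True  # the module's forward() caches its map, then resets this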
@@ -91,6 +93,8 @@ class CrossAttentionControl:

         def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):

+            #print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim)
+
             attn_slice = suggested_attention_slice
             if dim is not None:
                 start = offset
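The wrangler signature in this hunk is the hook through which CrossAttentionControl intercepts attention: the module passes in the slice it just computed, plus dim/offset/slice_size describing which sub-slice of the full map it is (dim is None when attention is computed in one piece), and the wrangler returns the slice that should actually be used. A sketch of a wrangler body under that contract; the module attributes and the Tensor.narrow-based sub-slicing are assumptions for illustration:

    import torch

    def attention_slice_wrangler_sketch(module, attention_scores, suggested_attention_slice,
                                        dim, offset, slice_size):
        attn_slice = suggested_attention_slice  # default: pass attention through unchanged

        if getattr(module, 'use_last_attn_slice', False) and module.last_attn_slice is not None:
            saved = module.last_attn_slice
            if dim is not None:
                # attention was computed in slices: cut the matching window
                # [offset, offset + slice_size) out of the saved full map
                saved = saved.narrow(dim, offset, slice_size)
            attn_slice = saved

        if getattr(module, 'save_last_attn_slice', False):
            # cache for a later apply pass (the sliced case would need reassembly,
            # which this sketch omits)
            module.last_attn_slice = attn_slice

        return attn_slice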
@@ -50,23 +50,32 @@ class CFGDenoiser(nn.Module):

         CrossAttentionControl.clear_requests(self.inner_model)

-        #rint('generating unconditioned latents')
-        unconditioned_latents = self.inner_model(x, sigma, cond=uncond)
-
-        # process x using the original prompt, saving the attention maps if required
-        if self.edited_conditioning is not None:
-            # this is automatically toggled off after the model forward()
-            CrossAttentionControl.request_save_attention_maps(self.inner_model)
-        #print('generating conditioned latents')
-        conditioned_latents = self.inner_model(x, sigma, cond=cond)
-
-        if self.edited_conditioning is not None:
-            # process x again, using the saved attention maps but the new conditioning
-            # this is automatically toggled off after the model forward()
-            CrossAttentionControl.request_apply_saved_attention_maps(self.inner_model)
-            #print('generating edited conditioned latents')
-            conditioned_latents = self.inner_model(x, sigma, cond=self.edited_conditioning)
+        if self.edited_conditioning is None:
+            # faster batch path
+            x_twice = torch.cat([x]*2)
+            sigma_twice = torch.cat([sigma]*2)
+            both_conditionings = torch.cat([uncond, cond])
+            unconditioned_next_x, conditioned_next_x = self.inner_model(x_twice, sigma_twice, cond=both_conditionings).chunk(2)
+        else:
+            # slower non-batched path (20% slower on mac MPS)
+            # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
+            # unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
+            # This messes up their application later, due to mismatched shape of dim 0 (16 vs. 8)
+            # (For the batched invocation the `wrangler` function gets an attention tensor with shape[0]=16,
+            # representing batched uncond + cond, but then when it comes to applying the saved attention, the
+            # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
+            # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
+            unconditioned_next_x = self.inner_model(x, sigma, cond=uncond)
+
+            # process x using the original prompt, saving the attention maps
+            CrossAttentionControl.request_save_attention_maps(self.inner_model)
+            _ = self.inner_model(x, sigma, cond=cond)
+            CrossAttentionControl.clear_requests(self.inner_model)
+
+            # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
+            CrossAttentionControl.request_apply_saved_attention_maps(self.inner_model)
+            conditioned_next_x = self.inner_model(x, sigma, cond=self.edited_conditioning)
+            CrossAttentionControl.clear_requests(self.inner_model)

         if self.warmup < self.warmup_max:
             thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
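The "faster batch path" in this hunk is standard classifier-free guidance batching: a single model call covers both conditionings by stacking along dim 0, which is exactly why the wrangler sees shape[0]=16 there but only shape[0]=8 on a single-conditioning call. A self-contained sketch of the pattern (model stands for any k-diffusion-style denoiser taking (x, sigma, cond); not the actual inner_model):

    import torch

    def cfg_batched_sketch(model, x, sigma, uncond, cond):
        # stack latents and sigmas so a single forward() denoises both halves
        x_twice = torch.cat([x] * 2)
        sigma_twice = torch.cat([sigma] * 2)
        both_conditionings = torch.cat([uncond, cond])
        # first half of the output batch is unconditioned, second half conditioned
        out = model(x_twice, sigma_twice, cond=both_conditionings)
        unconditioned_next_x, conditioned_next_x = out.chunk(2)
        return unconditioned_next_x, conditioned_next_x

The non-batched else branch trades that speed for attention tensors whose batch dim matches between the save pass and the apply pass, which is what the saved-map replacement requires.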
@@ -75,8 +84,9 @@ class CFGDenoiser(nn.Module):
             thresh = self.threshold
         if thresh > self.threshold:
             thresh = self.threshold
-        delta = (conditioned_latents - unconditioned_latents)
-        return cfg_apply_threshold(unconditioned_latents + delta * cond_scale, thresh)
+        # to scale how much effect conditioning has, calculate the changes it does and then scale that
+        scaled_delta = (conditioned_next_x - unconditioned_next_x) * cond_scale
+        return cfg_apply_threshold(unconditioned_next_x + scaled_delta, thresh)


 class KSampler(Sampler):
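The new return statement is the usual CFG combination: measure the change conditioning makes, amplify it by cond_scale, then threshold. A worked sketch; the clamp-based body below is an assumption standing in for the module's actual cfg_apply_threshold:

    import torch

    def cfg_combine_sketch(unconditioned_next_x: torch.Tensor,
                           conditioned_next_x: torch.Tensor,
                           cond_scale: float, thresh: float) -> torch.Tensor:
        # scale how much effect conditioning has: take its delta and amplify it
        scaled_delta = (conditioned_next_x - unconditioned_next_x) * cond_scale
        result = unconditioned_next_x + scaled_delta
        # hypothetical stand-in for cfg_apply_threshold: clamp latents into [-thresh, thresh]
        if thresh is not None and thresh > 0:
            result = result.clamp(-thresh, thresh)
        return result

With cond_scale=7.5 (a common default), a latent element where the conditioned and unconditioned predictions differ by 0.1 moves 0.75 away from the unconditioned prediction before thresholding.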