cleanup and clarify comments

Damian at mba 2022-10-18 19:49:25 +02:00
parent 711ffd238f
commit 09f62032ec
2 changed files with 29 additions and 15 deletions

@@ -74,6 +74,8 @@ class CrossAttentionControl:
self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF)
tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS)
for m in self_attention_modules+tokens_attention_modules:
# clear out the saved slice in case the outermost dim changes
m.last_attn_slice = None
m.save_last_attn_slice = True
@classmethod
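
Note on the flags set above: request_save_attention_maps() clears each attention module's last_attn_slice and raises save_last_attn_slice, and a later pass is expected to consume what was saved. Below is a minimal sketch of how a module might honour such flags; the class and the use_last_attn_slice counterpart flag are illustrative assumptions, not the actual CrossAttentionControl implementation.

import torch

class AttentionModuleSketch(torch.nn.Module):
    # Illustrative only: shows how save/apply flags on an attention module
    # could interact across two forward passes.
    def __init__(self):
        super().__init__()
        self.last_attn_slice = None        # cleared before each save request
        self.save_last_attn_slice = False  # raised by a save request
        self.use_last_attn_slice = False   # hypothetical flag raised by an apply request

    def wrangle_attention(self, attn_slice: torch.Tensor) -> torch.Tensor:
        if self.use_last_attn_slice and self.last_attn_slice is not None:
            # substitute the attention captured during the original-prompt pass
            attn_slice = self.last_attn_slice
        if self.save_last_attn_slice:
            # remember this pass's attention for the edited-prompt pass
            self.last_attn_slice = attn_slice
        return attn_slice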
@@ -91,6 +93,8 @@ class CrossAttentionControl:
def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):
#print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim)
attn_slice = suggested_attention_slice
if dim is not None:
start = offset
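
The attention_slice_wrangler hook above receives either the full attention tensor (dim is None) or a window of it described by dim, offset and slice_size. A minimal pass-through wrangler written against just the signature shown here could look like the following; the captured_slices storage attribute is an assumption for illustration only.

def passthrough_wrangler(module, attention_scores, suggested_attention_slice,
                         dim, offset, slice_size):
    # Sketch only: record the suggested slice and return it unchanged.
    # A real wrangler would substitute or remap saved attention here.
    if not hasattr(module, "captured_slices"):
        module.captured_slices = {}  # hypothetical storage, keyed by (dim, offset)
    if dim is not None:
        # sliced invocation: a window of length slice_size starting at offset along dim
        module.captured_slices[(dim, offset)] = suggested_attention_slice
    else:
        # unsliced invocation: the full attention tensor at once
        module.captured_slices[(None, 0)] = suggested_attention_slice
    return suggested_attention_slice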

@@ -50,23 +50,32 @@ class CFGDenoiser(nn.Module):
CrossAttentionControl.clear_requests(self.inner_model)
#print('generating unconditioned latents')
unconditioned_latents = self.inner_model(x, sigma, cond=uncond)
if self.edited_conditioning is None:
# faster batch path
x_twice = torch.cat([x]*2)
sigma_twice = torch.cat([sigma]*2)
both_conditionings = torch.cat([uncond, cond])
unconditioned_next_x, conditioned_next_x = self.inner_model(x_twice, sigma_twice, cond=both_conditionings).chunk(2)
else:
# slower non-batched path (20% slower on mac MPS)
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
# unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
# This messes up their application later, due to mismatched shape of dim 0 (16 vs. 8)
# (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16,
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditioning.)
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
unconditioned_next_x = self.inner_model(x, sigma, cond=uncond)
# process x using the original prompt, saving the attention maps if required
if self.edited_conditioning is not None:
# this is automatically toggled off after the model forward()
# process x using the original prompt, saving the attention maps
CrossAttentionControl.request_save_attention_maps(self.inner_model)
#print('generating conditioned latents')
conditioned_latents = self.inner_model(x, sigma, cond=cond)
if self.edited_conditioning is not None:
# process x again, using the saved attention maps but the new conditioning
# this is automatically toggled off after the model forward()
_ = self.inner_model(x, sigma, cond=cond)
CrossAttentionControl.clear_requests(self.inner_model)
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
CrossAttentionControl.request_apply_saved_attention_maps(self.inner_model)
#print('generating edited conditioned latents')
conditioned_latents = self.inner_model(x, sigma, cond=self.edited_conditioning)
conditioned_next_x = self.inner_model(x, sigma, cond=self.edited_conditioning)
CrossAttentionControl.clear_requests(self.inner_model)
if self.warmup < self.warmup_max:
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
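
The two branches above trade speed for control: the batched path stacks x, sigma and both conditionings so a single forward pass produces both predictions, while the non-batched path keeps the original batch size so that attention maps saved during the cond pass have matching shapes when they are applied during the edited-conditioning pass. A condensed sketch of just that control flow, assuming a generic inner_model(x, sigma, cond=...) callable that returns a tensor shaped like x:

import torch

def cfg_predictions(inner_model, x, sigma, uncond, cond, batched=True):
    if batched:
        # one forward pass over a doubled batch; attention tensors inside the
        # model also have a doubled dim 0 (e.g. 16 instead of 8), which is why
        # saved attention maps cannot be reused on this path
        x_twice = torch.cat([x] * 2)
        sigma_twice = torch.cat([sigma] * 2)
        both_conditionings = torch.cat([uncond, cond])
        unconditioned_next_x, conditioned_next_x = inner_model(
            x_twice, sigma_twice, cond=both_conditionings).chunk(2)
    else:
        # two forward passes with the original batch size, so attention saved
        # during the cond pass lines up with the later edited-cond pass
        unconditioned_next_x = inner_model(x, sigma, cond=uncond)
        conditioned_next_x = inner_model(x, sigma, cond=cond)
    return unconditioned_next_x, conditioned_next_x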
@@ -75,8 +84,9 @@ class CFGDenoiser(nn.Module):
thresh = self.threshold
if thresh > self.threshold:
thresh = self.threshold
delta = (conditioned_latents - unconditioned_latents)
return cfg_apply_threshold(unconditioned_latents + delta * cond_scale, thresh)
# to control how much effect the conditioning has, compute the change it causes and scale that delta
scaled_delta = (conditioned_next_x - unconditioned_next_x) * cond_scale
return cfg_apply_threshold(unconditioned_next_x + scaled_delta, thresh)
class KSampler(Sampler):
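
The return statement is standard classifier-free guidance: start from the unconditioned prediction and add the conditioned-minus-unconditioned delta scaled by cond_scale, with cfg_apply_threshold limiting the result while thresh is ramped up during warmup and capped at self.threshold. A sketch of just that arithmetic; the symmetric clamp below merely stands in for cfg_apply_threshold, whose real behaviour may differ:

import torch

def guided_next_x(unconditioned_next_x, conditioned_next_x, cond_scale, thresh=0.0):
    # classifier-free guidance: push the prediction away from the unconditioned
    # result along the direction the conditioning moves it
    scaled_delta = (conditioned_next_x - unconditioned_next_x) * cond_scale
    guided = unconditioned_next_x + scaled_delta
    if thresh <= 0.0:
        return guided
    # assumption: a simple symmetric clamp in place of cfg_apply_threshold
    return guided.clamp(-thresh, thresh)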