From 33d6603fef4839f7627eb817181c1cb59bb3b838 Mon Sep 17 00:00:00 2001
From: Damian at mba
Date: Sun, 16 Oct 2022 16:57:48 +0200
Subject: [PATCH] cleanup initial experiments

---
 ldm/models/diffusion/cross_attention.py | 19 +++++++++++++++++++
 ldm/models/diffusion/ksampler.py        | 10 ----------
 2 files changed, 19 insertions(+), 10 deletions(-)
 create mode 100644 ldm/models/diffusion/cross_attention.py

diff --git a/ldm/models/diffusion/cross_attention.py b/ldm/models/diffusion/cross_attention.py
new file mode 100644
index 0000000000..c39d8d5959
--- /dev/null
+++ b/ldm/models/diffusion/cross_attention.py
@@ -0,0 +1,19 @@
+from enum import Enum
+
+
+class CrossAttention:
+
+    class AttentionType(Enum):
+        SELF = 1
+        TOKENS = 2
+
+    @classmethod
+    def get_attention_module(cls, model, which: AttentionType):
+        which_attn = "attn1" if which is cls.AttentionType.SELF else "attn2"
+        module = next(module for name, module in model.named_modules() if
+                      type(module).__name__ == "CrossAttention" and which_attn in name)
+        return module
+
+    @classmethod
+    def inject_attention_mask_capture(cls, model, callback):
+        pass
\ No newline at end of file
diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py
index 55800d0a5c..8010b44d1d 100644
--- a/ldm/models/diffusion/ksampler.py
+++ b/ldm/models/diffusion/ksampler.py
@@ -27,9 +27,6 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
         minval = max(min(-1, scale*minval), -threshold)
     return torch.clamp(result, min=minval, max=maxval)
 
-class AttentionLayer(Enum):
-    SELF = 1
-    TOKENS = 2
 
 class CFGDenoiser(nn.Module):
     def __init__(self, model, threshold = 0, warmup = 0):
@@ -40,13 +37,6 @@ class CFGDenoiser(nn.Module):
         self.warmup = max(warmup / 10, 1)
 
 
-    def get_attention_module(self, which: AttentionLayer):
-        which_attn = "attn1" if which is AttentionLayer.SELF else "attn2"
-        module = next(module for name,module in self.inner_model.named_modules() if
-                      type(module).__name__ == "CrossAttention" and which_attn in name)
-        return module
-
-
     def forward(self, x, sigma, uncond, cond, cond_scale):
         x_in = torch.cat([x] * 2)
         sigma_in = torch.cat([sigma] * 2)
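
A minimal sketch of how the relocated helper might be called once this patch is applied. The CrossAttention stub and the Block container below are hypothetical stand-ins for ldm's real attention layers; get_attention_module locates a module purely by its class name and by "attn1"/"attn2" appearing in its dotted module path, so any structure matching that convention works:

    import torch.nn as nn

    from ldm.models.diffusion.cross_attention import CrossAttention as AttnHelper


    # Stand-in with the same class name as ldm's attention layer, so the
    # type(module).__name__ == "CrossAttention" check in get_attention_module matches.
    class CrossAttention(nn.Module):
        def forward(self, x, context=None):
            return x


    # Hypothetical container mimicking a transformer block that owns both an
    # attn1 (self-attention) and an attn2 (token cross-attention) submodule.
    class Block(nn.Module):
        def __init__(self):
            super().__init__()
            self.attn1 = CrossAttention()
            self.attn2 = CrossAttention()


    model = Block()

    # Fetch the first token cross-attention ("attn2") module in the model.
    tokens_attn = AttnHelper.get_attention_module(
        model, AttnHelper.AttentionType.TOKENS
    )
    assert tokens_attn is model.attn2

Note that get_attention_module returns only the first matching module (a real UNet contains many) and raises StopIteration if none is found; inject_attention_mask_capture is still an empty stub at this point in the experiment.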