cleanup initial experiments

Damian at mba 2022-10-16 16:57:48 +02:00
parent b0b1993918
commit 33d6603fef
2 changed files with 19 additions and 10 deletions

@@ -0,0 +1,19 @@
from enum import Enum


class CrossAttention:

    class AttentionType(Enum):
        SELF = 1
        TOKENS = 2

    @classmethod
    def get_attention_module(cls, model, which: AttentionType):
        which_attn = "attn1" if which is cls.AttentionType.SELF else "attn2"
        module = next(module for name, module in model.named_modules() if
                      type(module).__name__ == "CrossAttention" and which_attn in name)
        return module

    @classmethod
    def inject_attention_mask_capture(cls, model, callback):
        pass
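
The new helper walks model.named_modules() and returns the first CrossAttention submodule whose qualified name contains "attn1" (self-attention) or "attn2" (token cross-attention); note that next() raises StopIteration if no module matches. A minimal usage sketch, assuming a hypothetical unet variable holding a torch.nn.Module that contains such blocks (the name is illustrative, not part of this commit):

# Hypothetical usage; `unet` stands in for any torch.nn.Module whose
# CrossAttention submodules are named e.g. "...attn1" / "...attn2".
attn2 = CrossAttention.get_attention_module(unet, CrossAttention.AttentionType.TOKENS)
print(type(attn2).__name__)  # -> "CrossAttention"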

@@ -27,9 +27,6 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
    minval = max(min(-1, scale*minval), -threshold)
    return torch.clamp(result, min=minval, max=maxval)

class AttentionLayer(Enum):
    SELF = 1
    TOKENS = 2

class CFGDenoiser(nn.Module):
    def __init__(self, model, threshold = 0, warmup = 0):
@@ -40,13 +37,6 @@ class CFGDenoiser(nn.Module):
        self.warmup = max(warmup / 10, 1)

    def get_attention_module(self, which: AttentionLayer):
        which_attn = "attn1" if which is AttentionLayer.SELF else "attn2"
        module = next(module for name, module in self.inner_model.named_modules() if
                      type(module).__name__ == "CrossAttention" and which_attn in name)
        return module

    def forward(self, x, sigma, uncond, cond, cond_scale):
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
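
The forward shown above is cut off by the hunk, but the doubled x and sigma are the standard classifier-free guidance batching trick: the unconditional and conditional predictions run as a single batch and are then recombined using cond_scale. A minimal sketch of how that pattern is typically completed, assuming a k-diffusion-style self.inner_model and that __init__ stores threshold on self.threshold (none of these lines are from this commit):

# Sketch only; assumes inner_model accepts a cond= batch and that
# self.threshold was set in __init__. Not the commit's actual code.
cond_in = torch.cat([uncond, cond])
uncond_out, cond_out = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
guided = uncond_out + (cond_out - uncond_out) * cond_scale
return cfg_apply_threshold(guided, self.threshold)  # clamp as in the hunk above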