From b0b19939183c58e2a502aafad4bceddcc33575fe Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Wed, 12 Oct 2022 23:29:48 +0200 Subject: [PATCH] initial experiments --- ldm/models/diffusion/ksampler.py | 16 ++++++++++++++++ ldm/models/diffusion/sampler.py | 5 +++++ 2 files changed, 21 insertions(+) diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index ac0615b30c..55800d0a5c 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -1,4 +1,6 @@ """wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers""" +from enum import Enum + import k_diffusion as K import torch import torch.nn as nn @@ -25,6 +27,9 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7): minval = max(min(-1, scale*minval), -threshold) return torch.clamp(result, min=minval, max=maxval) +class AttentionLayer(Enum): + SELF = 1 + TOKENS = 2 class CFGDenoiser(nn.Module): def __init__(self, model, threshold = 0, warmup = 0): @@ -34,11 +39,22 @@ class CFGDenoiser(nn.Module): self.warmup_max = warmup self.warmup = max(warmup / 10, 1) + + def get_attention_module(self, which: AttentionLayer): + which_attn = "attn1" if which is AttentionLayer.SELF else "attn2" + module = next(module for name,module in self.inner_model.named_modules() if + type(module).__name__ == "CrossAttention" and which_attn in name) + return module + + def forward(self, x, sigma, uncond, cond, cond_scale): x_in = torch.cat([x] * 2) sigma_in = torch.cat([sigma] * 2) cond_in = torch.cat([uncond, cond]) uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + + module = self.get_attention_module(AttentionLayer.TOKENS) + if self.warmup < self.warmup_max: thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max)) self.warmup += 1 diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py index ff705513f8..eb7caebba0 100644 --- a/ldm/models/diffusion/sampler.py +++ b/ldm/models/diffusion/sampler.py @@ -4,6 +4,8 @@ ldm.models.diffusion.sampler Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc ''' +from enum import Enum + import torch import numpy as np from tqdm import tqdm @@ -411,3 +413,6 @@ class Sampler(object): return self.model.inner_model.q_sample(x0,ts) ''' return self.model.q_sample(x0,ts) + + +