initial experiments

This commit is contained in:
Damian at mba 2022-10-12 23:29:48 +02:00
parent 07a3df6001
commit b0b1993918
2 changed files with 21 additions and 0 deletions

View File

@ -1,4 +1,6 @@
"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers""" """wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
from enum import Enum
import k_diffusion as K import k_diffusion as K
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -25,6 +27,9 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
minval = max(min(-1, scale*minval), -threshold) minval = max(min(-1, scale*minval), -threshold)
return torch.clamp(result, min=minval, max=maxval) return torch.clamp(result, min=minval, max=maxval)
class AttentionLayer(Enum):
SELF = 1
TOKENS = 2
class CFGDenoiser(nn.Module): class CFGDenoiser(nn.Module):
def __init__(self, model, threshold = 0, warmup = 0): def __init__(self, model, threshold = 0, warmup = 0):
@ -34,11 +39,22 @@ class CFGDenoiser(nn.Module):
self.warmup_max = warmup self.warmup_max = warmup
self.warmup = max(warmup / 10, 1) self.warmup = max(warmup / 10, 1)
def get_attention_module(self, which: AttentionLayer):
which_attn = "attn1" if which is AttentionLayer.SELF else "attn2"
module = next(module for name,module in self.inner_model.named_modules() if
type(module).__name__ == "CrossAttention" and which_attn in name)
return module
def forward(self, x, sigma, uncond, cond, cond_scale): def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2) x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2) sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond]) cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
module = self.get_attention_module(AttentionLayer.TOKENS)
if self.warmup < self.warmup_max: if self.warmup < self.warmup_max:
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max)) thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
self.warmup += 1 self.warmup += 1

View File

@ -4,6 +4,8 @@ ldm.models.diffusion.sampler
Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc
''' '''
from enum import Enum
import torch import torch
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
@ -411,3 +413,6 @@ class Sampler(object):
return self.model.inner_model.q_sample(x0,ts) return self.model.inner_model.q_sample(x0,ts)
''' '''
return self.model.q_sample(x0,ts) return self.model.q_sample(x0,ts)