mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
initial experiments
This commit is contained in:
parent
07a3df6001
commit
b0b1993918
@ -1,4 +1,6 @@
|
|||||||
"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
|
"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
import k_diffusion as K
|
import k_diffusion as K
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -25,6 +27,9 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
|
|||||||
minval = max(min(-1, scale*minval), -threshold)
|
minval = max(min(-1, scale*minval), -threshold)
|
||||||
return torch.clamp(result, min=minval, max=maxval)
|
return torch.clamp(result, min=minval, max=maxval)
|
||||||
|
|
||||||
|
class AttentionLayer(Enum):
|
||||||
|
SELF = 1
|
||||||
|
TOKENS = 2
|
||||||
|
|
||||||
class CFGDenoiser(nn.Module):
|
class CFGDenoiser(nn.Module):
|
||||||
def __init__(self, model, threshold = 0, warmup = 0):
|
def __init__(self, model, threshold = 0, warmup = 0):
|
||||||
@ -34,11 +39,22 @@ class CFGDenoiser(nn.Module):
|
|||||||
self.warmup_max = warmup
|
self.warmup_max = warmup
|
||||||
self.warmup = max(warmup / 10, 1)
|
self.warmup = max(warmup / 10, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def get_attention_module(self, which: AttentionLayer):
|
||||||
|
which_attn = "attn1" if which is AttentionLayer.SELF else "attn2"
|
||||||
|
module = next(module for name,module in self.inner_model.named_modules() if
|
||||||
|
type(module).__name__ == "CrossAttention" and which_attn in name)
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||||
x_in = torch.cat([x] * 2)
|
x_in = torch.cat([x] * 2)
|
||||||
sigma_in = torch.cat([sigma] * 2)
|
sigma_in = torch.cat([sigma] * 2)
|
||||||
cond_in = torch.cat([uncond, cond])
|
cond_in = torch.cat([uncond, cond])
|
||||||
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
||||||
|
|
||||||
|
module = self.get_attention_module(AttentionLayer.TOKENS)
|
||||||
|
|
||||||
if self.warmup < self.warmup_max:
|
if self.warmup < self.warmup_max:
|
||||||
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
|
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
|
||||||
self.warmup += 1
|
self.warmup += 1
|
||||||
|
@ -4,6 +4,8 @@ ldm.models.diffusion.sampler
|
|||||||
Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc
|
Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@ -411,3 +413,6 @@ class Sampler(object):
|
|||||||
return self.model.inner_model.q_sample(x0,ts)
|
return self.model.inner_model.q_sample(x0,ts)
|
||||||
'''
|
'''
|
||||||
return self.model.q_sample(x0,ts)
|
return self.model.q_sample(x0,ts)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user