mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
cleanup initial experiments
parent b0b1993918
commit 33d6603fef
ldm/models/diffusion/cross_attention.py  (new file, 19 lines)
@@ -0,0 +1,19 @@
+from enum import Enum
+
+
+class CrossAttention:
+
+    class AttentionType(Enum):
+        SELF = 1
+        TOKENS = 2
+
+    @classmethod
+    def get_attention_module(cls, model, which: AttentionType):
+        which_attn = "attn1" if which is cls.AttentionType.SELF else "attn2"
+        module = next(module for name, module in model.named_modules() if
+                      type(module).__name__ == "CrossAttention" and which_attn in name)
+        return module
+
+    @classmethod
+    def inject_attention_mask_capture(cls, model, callback):
+        pass
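A minimal usage sketch of the new helper, assuming `unet` is an already-loaded module tree containing the standard Stable Diffusion attn1/attn2 cross-attention blocks (the `unet` variable is an assumption, not part of this commit). Note the naming collision: this helper class shadows the name of the ldm CrossAttention module type it searches for.

from ldm.models.diffusion.cross_attention import CrossAttention

# Fetch the first token (prompt) cross-attention block, i.e. the first
# "attn2" module whose type is named "CrossAttention".
attn2 = CrossAttention.get_attention_module(
    unet, CrossAttention.AttentionType.TOKENS)
print(type(attn2).__name__)  # "CrossAttention"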
@@ -27,9 +27,6 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
     minval = max(min(-1, scale*minval), -threshold)
     return torch.clamp(result, min=minval, max=maxval)
 
-class AttentionLayer(Enum):
-    SELF = 1
-    TOKENS = 2
 
 class CFGDenoiser(nn.Module):
     def __init__(self, model, threshold = 0, warmup = 0):
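The surviving context line is a soft floor: the tensor minimum is rescaled toward zero, but the clamp bound is kept between -1 and -threshold. A tiny worked example with assumed values:

# Assumed values for illustration only: threshold=1.0, scale=0.7,
# and a raw tensor minimum of -2.0.
threshold, scale, minval = 1.0, 0.7, -2.0
minval = max(min(-1, scale * minval), -threshold)
# min(-1, -1.4) = -1.4; max(-1.4, -1.0) = -1.0, so the floor saturates
# at -threshold instead of following the raw scaled minimum.
print(minval)  # -1.0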
@@ -40,13 +37,6 @@ class CFGDenoiser(nn.Module):
         self.warmup = max(warmup / 10, 1)
 
-    def get_attention_module(self, which: AttentionLayer):
-        which_attn = "attn1" if which is AttentionLayer.SELF else "attn2"
-        module = next(module for name,module in self.inner_model.named_modules() if
-                      type(module).__name__ == "CrossAttention" and which_attn in name)
-        return module
-
-
     def forward(self, x, sigma, uncond, cond, cond_scale):
         x_in = torch.cat([x] * 2)
         sigma_in = torch.cat([sigma] * 2)
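The two torch.cat([...] * 2) lines batch the unconditional and conditional passes into a single model call. A minimal sketch of the standard classifier-free guidance combination such a forward typically completes with (the cond_in batching, the chunk(2) split, and the final mix follow the usual k-diffusion pattern and are assumptions, not this file's confirmed code):

import torch

def cfg_denoise(inner_model, x, sigma, uncond, cond, cond_scale):
    # Duplicate latents and noise levels so both passes share one call.
    x_in = torch.cat([x] * 2)
    sigma_in = torch.cat([sigma] * 2)
    cond_in = torch.cat([uncond, cond])
    # Assumed: inner_model accepts a `cond` keyword and returns a batch
    # that splits back into unconditional and conditional outputs.
    uncond_out, cond_out = inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
    # Push the prediction away from the unconditional output toward the
    # conditioned one, scaled by the guidance strength.
    return uncond_out + (cond_out - uncond_out) * cond_scale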