Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

commit d334f7f1f6: add missing files
parent 8c9764476c
.gitignore (vendored): 9 lines changed
@@ -214,9 +214,9 @@ gfpgan/
 configs/models.yaml
 
 # weights (will be created by installer)
-models/ldm/stable-diffusion-v1/*.ckpt
-models/clipseg
-models/gfpgan
+# models/ldm/stable-diffusion-v1/*.ckpt
+# models/clipseg
+# models/gfpgan
 
 # ignore initfile
 .invokeai
@@ -232,6 +232,3 @@ installer/install.bat
 installer/install.sh
 installer/update.bat
 installer/update.sh
-
-# no longer stored in source directory
-models
invokeai/models/__init__.py (new file, 10 lines)
@@ -0,0 +1,10 @@
'''
Initialization file for the invokeai.models package
'''
from .model_manager import ModelManager, SDLegacyType
from .diffusion import InvokeAIDiffuserComponent
from .diffusion.ddim import DDIMSampler
from .diffusion.ksampler import KSampler
from .diffusion.plms import PLMSSampler
from .diffusion.cross_attention_map_saving import AttentionMapSaver
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings
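With this package initializer in place, callers can pull the public names directly from invokeai.models instead of reaching into the submodules. A minimal sketch of the resulting import surface (the usage is hypothetical; only the import paths come from the file above):

    from invokeai.models import ModelManager, DDIMSampler, KSampler, PLMSSampler
    from invokeai.models.diffusion.ddim import DDIMSampler as DDIMSamplerDirect

    # Both routes resolve to the same class; the first relies on the re-exports above.
    assert DDIMSampler is DDIMSamplerDirect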
invokeai/models/__init__.py~ (new file, 10 lines)
@@ -0,0 +1,10 @@
'''
Initialization file for the invokeai.models package
'''
from .model_manager import ModelManager, SDLegacyType
from .diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from .diffusion.ddim import DDIMSampler
from .diffusion.ksampler import KSampler
from .diffusion.plms import PLMSSampler
from .diffusion.cross_attention_map_saving import AttentionMapSaver
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings
BIN invokeai/models/__pycache__/__init__.cpython-310.pyc (new file; binary file not shown)
BIN invokeai/models/__pycache__/autoencoder.cpython-310.pyc (new file; binary file not shown)
BIN invokeai/models/__pycache__/model_manager.cpython-310.pyc (new file; binary file not shown)
invokeai/models/autoencoder.py (new file, 596 lines)
@@ -0,0 +1,596 @@
|
|||||||
|
import torch
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
||||||
|
|
||||||
|
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||||
|
from ldm.modules.distributions.distributions import (
|
||||||
|
DiagonalGaussianDistribution,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
|
|
||||||
|
class VQModel(pl.LightningModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ddconfig,
|
||||||
|
lossconfig,
|
||||||
|
n_embed,
|
||||||
|
embed_dim,
|
||||||
|
ckpt_path=None,
|
||||||
|
ignore_keys=[],
|
||||||
|
image_key='image',
|
||||||
|
colorize_nlabels=None,
|
||||||
|
monitor=None,
|
||||||
|
batch_resize_range=None,
|
||||||
|
scheduler_config=None,
|
||||||
|
lr_g_factor=1.0,
|
||||||
|
remap=None,
|
||||||
|
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
||||||
|
use_ema=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.n_embed = n_embed
|
||||||
|
self.image_key = image_key
|
||||||
|
self.encoder = Encoder(**ddconfig)
|
||||||
|
self.decoder = Decoder(**ddconfig)
|
||||||
|
self.loss = instantiate_from_config(lossconfig)
|
||||||
|
self.quantize = VectorQuantizer(
|
||||||
|
n_embed,
|
||||||
|
embed_dim,
|
||||||
|
beta=0.25,
|
||||||
|
remap=remap,
|
||||||
|
sane_index_shape=sane_index_shape,
|
||||||
|
)
|
||||||
|
self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1)
|
||||||
|
self.post_quant_conv = torch.nn.Conv2d(
|
||||||
|
embed_dim, ddconfig['z_channels'], 1
|
||||||
|
)
|
||||||
|
if colorize_nlabels is not None:
|
||||||
|
assert type(colorize_nlabels) == int
|
||||||
|
self.register_buffer(
|
||||||
|
'colorize', torch.randn(3, colorize_nlabels, 1, 1)
|
||||||
|
)
|
||||||
|
if monitor is not None:
|
||||||
|
self.monitor = monitor
|
||||||
|
self.batch_resize_range = batch_resize_range
|
||||||
|
if self.batch_resize_range is not None:
|
||||||
|
print(
|
||||||
|
f'{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.'
|
||||||
|
)
|
||||||
|
|
||||||
|
self.use_ema = use_ema
|
||||||
|
if self.use_ema:
|
||||||
|
self.model_ema = LitEma(self)
|
||||||
|
print(f'>> Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
|
||||||
|
|
||||||
|
if ckpt_path is not None:
|
||||||
|
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||||
|
self.scheduler_config = scheduler_config
|
||||||
|
self.lr_g_factor = lr_g_factor
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def ema_scope(self, context=None):
|
||||||
|
if self.use_ema:
|
||||||
|
self.model_ema.store(self.parameters())
|
||||||
|
self.model_ema.copy_to(self)
|
||||||
|
if context is not None:
|
||||||
|
print(f'{context}: Switched to EMA weights')
|
||||||
|
try:
|
||||||
|
yield None
|
||||||
|
finally:
|
||||||
|
if self.use_ema:
|
||||||
|
self.model_ema.restore(self.parameters())
|
||||||
|
if context is not None:
|
||||||
|
print(f'{context}: Restored training weights')
|
||||||
|
|
||||||
|
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||||
|
sd = torch.load(path, map_location='cpu')['state_dict']
|
||||||
|
keys = list(sd.keys())
|
||||||
|
for k in keys:
|
||||||
|
for ik in ignore_keys:
|
||||||
|
if k.startswith(ik):
|
||||||
|
print('Deleting key {} from state_dict.'.format(k))
|
||||||
|
del sd[k]
|
||||||
|
missing, unexpected = self.load_state_dict(sd, strict=False)
|
||||||
|
print(
|
||||||
|
f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
|
||||||
|
)
|
||||||
|
if len(missing) > 0:
|
||||||
|
print(f'Missing Keys: {missing}')
|
||||||
|
print(f'Unexpected Keys: {unexpected}')
|
||||||
|
|
||||||
|
def on_train_batch_end(self, *args, **kwargs):
|
||||||
|
if self.use_ema:
|
||||||
|
self.model_ema(self)
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
h = self.encoder(x)
|
||||||
|
h = self.quant_conv(h)
|
||||||
|
quant, emb_loss, info = self.quantize(h)
|
||||||
|
return quant, emb_loss, info
|
||||||
|
|
||||||
|
def encode_to_prequant(self, x):
|
||||||
|
h = self.encoder(x)
|
||||||
|
h = self.quant_conv(h)
|
||||||
|
return h
|
||||||
|
|
||||||
|
def decode(self, quant):
|
||||||
|
quant = self.post_quant_conv(quant)
|
||||||
|
dec = self.decoder(quant)
|
||||||
|
return dec
|
||||||
|
|
||||||
|
def decode_code(self, code_b):
|
||||||
|
quant_b = self.quantize.embed_code(code_b)
|
||||||
|
dec = self.decode(quant_b)
|
||||||
|
return dec
|
||||||
|
|
||||||
|
def forward(self, input, return_pred_indices=False):
|
||||||
|
quant, diff, (_, _, ind) = self.encode(input)
|
||||||
|
dec = self.decode(quant)
|
||||||
|
if return_pred_indices:
|
||||||
|
return dec, diff, ind
|
||||||
|
return dec, diff
|
||||||
|
|
||||||
|
def get_input(self, batch, k):
|
||||||
|
x = batch[k]
|
||||||
|
if len(x.shape) == 3:
|
||||||
|
x = x[..., None]
|
||||||
|
x = (
|
||||||
|
x.permute(0, 3, 1, 2)
|
||||||
|
.to(memory_format=torch.contiguous_format)
|
||||||
|
.float()
|
||||||
|
)
|
||||||
|
if self.batch_resize_range is not None:
|
||||||
|
lower_size = self.batch_resize_range[0]
|
||||||
|
upper_size = self.batch_resize_range[1]
|
||||||
|
if self.global_step <= 4:
|
||||||
|
# do the first few batches with max size to avoid later oom
|
||||||
|
new_resize = upper_size
|
||||||
|
else:
|
||||||
|
new_resize = np.random.choice(
|
||||||
|
np.arange(lower_size, upper_size + 16, 16)
|
||||||
|
)
|
||||||
|
if new_resize != x.shape[2]:
|
||||||
|
x = F.interpolate(x, size=new_resize, mode='bicubic')
|
||||||
|
x = x.detach()
|
||||||
|
return x
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||||
|
# https://github.com/pytorch/pytorch/issues/37142
|
||||||
|
# try not to fool the heuristics
|
||||||
|
x = self.get_input(batch, self.image_key)
|
||||||
|
xrec, qloss, ind = self(x, return_pred_indices=True)
|
||||||
|
|
||||||
|
if optimizer_idx == 0:
|
||||||
|
# autoencode
|
||||||
|
aeloss, log_dict_ae = self.loss(
|
||||||
|
qloss,
|
||||||
|
x,
|
||||||
|
xrec,
|
||||||
|
optimizer_idx,
|
||||||
|
self.global_step,
|
||||||
|
last_layer=self.get_last_layer(),
|
||||||
|
split='train',
|
||||||
|
predicted_indices=ind,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.log_dict(
|
||||||
|
log_dict_ae,
|
||||||
|
prog_bar=False,
|
||||||
|
logger=True,
|
||||||
|
on_step=True,
|
||||||
|
on_epoch=True,
|
||||||
|
)
|
||||||
|
return aeloss
|
||||||
|
|
||||||
|
if optimizer_idx == 1:
|
||||||
|
# discriminator
|
||||||
|
discloss, log_dict_disc = self.loss(
|
||||||
|
qloss,
|
||||||
|
x,
|
||||||
|
xrec,
|
||||||
|
optimizer_idx,
|
||||||
|
self.global_step,
|
||||||
|
last_layer=self.get_last_layer(),
|
||||||
|
split='train',
|
||||||
|
)
|
||||||
|
self.log_dict(
|
||||||
|
log_dict_disc,
|
||||||
|
prog_bar=False,
|
||||||
|
logger=True,
|
||||||
|
on_step=True,
|
||||||
|
on_epoch=True,
|
||||||
|
)
|
||||||
|
return discloss
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
log_dict = self._validation_step(batch, batch_idx)
|
||||||
|
with self.ema_scope():
|
||||||
|
log_dict_ema = self._validation_step(
|
||||||
|
batch, batch_idx, suffix='_ema'
|
||||||
|
)
|
||||||
|
return log_dict
|
||||||
|
|
||||||
|
def _validation_step(self, batch, batch_idx, suffix=''):
|
||||||
|
x = self.get_input(batch, self.image_key)
|
||||||
|
xrec, qloss, ind = self(x, return_pred_indices=True)
|
||||||
|
aeloss, log_dict_ae = self.loss(
|
||||||
|
qloss,
|
||||||
|
x,
|
||||||
|
xrec,
|
||||||
|
0,
|
||||||
|
self.global_step,
|
||||||
|
last_layer=self.get_last_layer(),
|
||||||
|
split='val' + suffix,
|
||||||
|
predicted_indices=ind,
|
||||||
|
)
|
||||||
|
|
||||||
|
discloss, log_dict_disc = self.loss(
|
||||||
|
qloss,
|
||||||
|
x,
|
||||||
|
xrec,
|
||||||
|
1,
|
||||||
|
self.global_step,
|
||||||
|
last_layer=self.get_last_layer(),
|
||||||
|
split='val' + suffix,
|
||||||
|
predicted_indices=ind,
|
||||||
|
)
|
||||||
|
rec_loss = log_dict_ae[f'val{suffix}/rec_loss']
|
||||||
|
self.log(
|
||||||
|
f'val{suffix}/rec_loss',
|
||||||
|
rec_loss,
|
||||||
|
prog_bar=True,
|
||||||
|
logger=True,
|
||||||
|
on_step=False,
|
||||||
|
on_epoch=True,
|
||||||
|
sync_dist=True,
|
||||||
|
)
|
||||||
|
self.log(
|
||||||
|
f'val{suffix}/aeloss',
|
||||||
|
aeloss,
|
||||||
|
prog_bar=True,
|
||||||
|
logger=True,
|
||||||
|
on_step=False,
|
||||||
|
on_epoch=True,
|
||||||
|
sync_dist=True,
|
||||||
|
)
|
||||||
|
if version.parse(pl.__version__) >= version.parse('1.4.0'):
|
||||||
|
del log_dict_ae[f'val{suffix}/rec_loss']
|
||||||
|
self.log_dict(log_dict_ae)
|
||||||
|
self.log_dict(log_dict_disc)
|
||||||
|
return self.log_dict
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
lr_d = self.learning_rate
|
||||||
|
lr_g = self.lr_g_factor * self.learning_rate
|
||||||
|
print('lr_d', lr_d)
|
||||||
|
print('lr_g', lr_g)
|
||||||
|
opt_ae = torch.optim.Adam(
|
||||||
|
list(self.encoder.parameters())
|
||||||
|
+ list(self.decoder.parameters())
|
||||||
|
+ list(self.quantize.parameters())
|
||||||
|
+ list(self.quant_conv.parameters())
|
||||||
|
+ list(self.post_quant_conv.parameters()),
|
||||||
|
lr=lr_g,
|
||||||
|
betas=(0.5, 0.9),
|
||||||
|
)
|
||||||
|
opt_disc = torch.optim.Adam(
|
||||||
|
self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.scheduler_config is not None:
|
||||||
|
scheduler = instantiate_from_config(self.scheduler_config)
|
||||||
|
|
||||||
|
print('Setting up LambdaLR scheduler...')
|
||||||
|
scheduler = [
|
||||||
|
{
|
||||||
|
'scheduler': LambdaLR(
|
||||||
|
opt_ae, lr_lambda=scheduler.schedule
|
||||||
|
),
|
||||||
|
'interval': 'step',
|
||||||
|
'frequency': 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'scheduler': LambdaLR(
|
||||||
|
opt_disc, lr_lambda=scheduler.schedule
|
||||||
|
),
|
||||||
|
'interval': 'step',
|
||||||
|
'frequency': 1,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
return [opt_ae, opt_disc], scheduler
|
||||||
|
return [opt_ae, opt_disc], []
|
||||||
|
|
||||||
|
def get_last_layer(self):
|
||||||
|
return self.decoder.conv_out.weight
|
||||||
|
|
||||||
|
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
|
||||||
|
log = dict()
|
||||||
|
x = self.get_input(batch, self.image_key)
|
||||||
|
x = x.to(self.device)
|
||||||
|
if only_inputs:
|
||||||
|
log['inputs'] = x
|
||||||
|
return log
|
||||||
|
xrec, _ = self(x)
|
||||||
|
if x.shape[1] > 3:
|
||||||
|
# colorize with random projection
|
||||||
|
assert xrec.shape[1] > 3
|
||||||
|
x = self.to_rgb(x)
|
||||||
|
xrec = self.to_rgb(xrec)
|
||||||
|
log['inputs'] = x
|
||||||
|
log['reconstructions'] = xrec
|
||||||
|
if plot_ema:
|
||||||
|
with self.ema_scope():
|
||||||
|
xrec_ema, _ = self(x)
|
||||||
|
if x.shape[1] > 3:
|
||||||
|
xrec_ema = self.to_rgb(xrec_ema)
|
||||||
|
log['reconstructions_ema'] = xrec_ema
|
||||||
|
return log
|
||||||
|
|
||||||
|
def to_rgb(self, x):
|
||||||
|
assert self.image_key == 'segmentation'
|
||||||
|
if not hasattr(self, 'colorize'):
|
||||||
|
self.register_buffer(
|
||||||
|
'colorize', torch.randn(3, x.shape[1], 1, 1).to(x)
|
||||||
|
)
|
||||||
|
x = F.conv2d(x, weight=self.colorize)
|
||||||
|
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class VQModelInterface(VQModel):
|
||||||
|
def __init__(self, embed_dim, *args, **kwargs):
|
||||||
|
super().__init__(embed_dim=embed_dim, *args, **kwargs)
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
h = self.encoder(x)
|
||||||
|
h = self.quant_conv(h)
|
||||||
|
return h
|
||||||
|
|
||||||
|
def decode(self, h, force_not_quantize=False):
|
||||||
|
# also go through quantization layer
|
||||||
|
if not force_not_quantize:
|
||||||
|
quant, emb_loss, info = self.quantize(h)
|
||||||
|
else:
|
||||||
|
quant = h
|
||||||
|
quant = self.post_quant_conv(quant)
|
||||||
|
dec = self.decoder(quant)
|
||||||
|
return dec
|
||||||
|
|
||||||
|
|
||||||
|
class AutoencoderKL(pl.LightningModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ddconfig,
|
||||||
|
lossconfig,
|
||||||
|
embed_dim,
|
||||||
|
ckpt_path=None,
|
||||||
|
ignore_keys=[],
|
||||||
|
image_key='image',
|
||||||
|
colorize_nlabels=None,
|
||||||
|
monitor=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.image_key = image_key
|
||||||
|
self.encoder = Encoder(**ddconfig)
|
||||||
|
self.decoder = Decoder(**ddconfig)
|
||||||
|
self.loss = instantiate_from_config(lossconfig)
|
||||||
|
assert ddconfig['double_z']
|
||||||
|
self.quant_conv = torch.nn.Conv2d(
|
||||||
|
2 * ddconfig['z_channels'], 2 * embed_dim, 1
|
||||||
|
)
|
||||||
|
self.post_quant_conv = torch.nn.Conv2d(
|
||||||
|
embed_dim, ddconfig['z_channels'], 1
|
||||||
|
)
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
if colorize_nlabels is not None:
|
||||||
|
assert type(colorize_nlabels) == int
|
||||||
|
self.register_buffer(
|
||||||
|
'colorize', torch.randn(3, colorize_nlabels, 1, 1)
|
||||||
|
)
|
||||||
|
if monitor is not None:
|
||||||
|
self.monitor = monitor
|
||||||
|
if ckpt_path is not None:
|
||||||
|
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||||
|
|
||||||
|
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||||
|
sd = torch.load(path, map_location='cpu')['state_dict']
|
||||||
|
keys = list(sd.keys())
|
||||||
|
for k in keys:
|
||||||
|
for ik in ignore_keys:
|
||||||
|
if k.startswith(ik):
|
||||||
|
print('Deleting key {} from state_dict.'.format(k))
|
||||||
|
del sd[k]
|
||||||
|
self.load_state_dict(sd, strict=False)
|
||||||
|
print(f'Restored from {path}')
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
h = self.encoder(x)
|
||||||
|
moments = self.quant_conv(h)
|
||||||
|
posterior = DiagonalGaussianDistribution(moments)
|
||||||
|
return posterior
|
||||||
|
|
||||||
|
def decode(self, z):
|
||||||
|
z = self.post_quant_conv(z)
|
||||||
|
dec = self.decoder(z)
|
||||||
|
return dec
|
||||||
|
|
||||||
|
def forward(self, input, sample_posterior=True):
|
||||||
|
posterior = self.encode(input)
|
||||||
|
if sample_posterior:
|
||||||
|
z = posterior.sample()
|
||||||
|
else:
|
||||||
|
z = posterior.mode()
|
||||||
|
dec = self.decode(z)
|
||||||
|
return dec, posterior
|
||||||
|
|
||||||
|
def get_input(self, batch, k):
|
||||||
|
x = batch[k]
|
||||||
|
if len(x.shape) == 3:
|
||||||
|
x = x[..., None]
|
||||||
|
x = (
|
||||||
|
x.permute(0, 3, 1, 2)
|
||||||
|
.to(memory_format=torch.contiguous_format)
|
||||||
|
.float()
|
||||||
|
)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||||
|
inputs = self.get_input(batch, self.image_key)
|
||||||
|
reconstructions, posterior = self(inputs)
|
||||||
|
|
||||||
|
if optimizer_idx == 0:
|
||||||
|
# train encoder+decoder+logvar
|
||||||
|
aeloss, log_dict_ae = self.loss(
|
||||||
|
inputs,
|
||||||
|
reconstructions,
|
||||||
|
posterior,
|
||||||
|
optimizer_idx,
|
||||||
|
self.global_step,
|
||||||
|
last_layer=self.get_last_layer(),
|
||||||
|
split='train',
|
||||||
|
)
|
||||||
|
self.log(
|
||||||
|
'aeloss',
|
||||||
|
aeloss,
|
||||||
|
prog_bar=True,
|
||||||
|
logger=True,
|
||||||
|
on_step=True,
|
||||||
|
on_epoch=True,
|
||||||
|
)
|
||||||
|
self.log_dict(
|
||||||
|
log_dict_ae,
|
||||||
|
prog_bar=False,
|
||||||
|
logger=True,
|
||||||
|
on_step=True,
|
||||||
|
on_epoch=False,
|
||||||
|
)
|
||||||
|
return aeloss
|
||||||
|
|
||||||
|
if optimizer_idx == 1:
|
||||||
|
# train the discriminator
|
||||||
|
discloss, log_dict_disc = self.loss(
|
||||||
|
inputs,
|
||||||
|
reconstructions,
|
||||||
|
posterior,
|
||||||
|
optimizer_idx,
|
||||||
|
self.global_step,
|
||||||
|
last_layer=self.get_last_layer(),
|
||||||
|
split='train',
|
||||||
|
)
|
||||||
|
|
||||||
|
self.log(
|
||||||
|
'discloss',
|
||||||
|
discloss,
|
||||||
|
prog_bar=True,
|
||||||
|
logger=True,
|
||||||
|
on_step=True,
|
||||||
|
on_epoch=True,
|
||||||
|
)
|
||||||
|
self.log_dict(
|
||||||
|
log_dict_disc,
|
||||||
|
prog_bar=False,
|
||||||
|
logger=True,
|
||||||
|
on_step=True,
|
||||||
|
on_epoch=False,
|
||||||
|
)
|
||||||
|
return discloss
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
inputs = self.get_input(batch, self.image_key)
|
||||||
|
reconstructions, posterior = self(inputs)
|
||||||
|
aeloss, log_dict_ae = self.loss(
|
||||||
|
inputs,
|
||||||
|
reconstructions,
|
||||||
|
posterior,
|
||||||
|
0,
|
||||||
|
self.global_step,
|
||||||
|
last_layer=self.get_last_layer(),
|
||||||
|
split='val',
|
||||||
|
)
|
||||||
|
|
||||||
|
discloss, log_dict_disc = self.loss(
|
||||||
|
inputs,
|
||||||
|
reconstructions,
|
||||||
|
posterior,
|
||||||
|
1,
|
||||||
|
self.global_step,
|
||||||
|
last_layer=self.get_last_layer(),
|
||||||
|
split='val',
|
||||||
|
)
|
||||||
|
|
||||||
|
self.log('val/rec_loss', log_dict_ae['val/rec_loss'])
|
||||||
|
self.log_dict(log_dict_ae)
|
||||||
|
self.log_dict(log_dict_disc)
|
||||||
|
return self.log_dict
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
lr = self.learning_rate
|
||||||
|
opt_ae = torch.optim.Adam(
|
||||||
|
list(self.encoder.parameters())
|
||||||
|
+ list(self.decoder.parameters())
|
||||||
|
+ list(self.quant_conv.parameters())
|
||||||
|
+ list(self.post_quant_conv.parameters()),
|
||||||
|
lr=lr,
|
||||||
|
betas=(0.5, 0.9),
|
||||||
|
)
|
||||||
|
opt_disc = torch.optim.Adam(
|
||||||
|
self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
|
||||||
|
)
|
||||||
|
return [opt_ae, opt_disc], []
|
||||||
|
|
||||||
|
def get_last_layer(self):
|
||||||
|
return self.decoder.conv_out.weight
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def log_images(self, batch, only_inputs=False, **kwargs):
|
||||||
|
log = dict()
|
||||||
|
x = self.get_input(batch, self.image_key)
|
||||||
|
x = x.to(self.device)
|
||||||
|
if not only_inputs:
|
||||||
|
xrec, posterior = self(x)
|
||||||
|
if x.shape[1] > 3:
|
||||||
|
# colorize with random projection
|
||||||
|
assert xrec.shape[1] > 3
|
||||||
|
x = self.to_rgb(x)
|
||||||
|
xrec = self.to_rgb(xrec)
|
||||||
|
log['samples'] = self.decode(torch.randn_like(posterior.sample()))
|
||||||
|
log['reconstructions'] = xrec
|
||||||
|
log['inputs'] = x
|
||||||
|
return log
|
||||||
|
|
||||||
|
def to_rgb(self, x):
|
||||||
|
assert self.image_key == 'segmentation'
|
||||||
|
if not hasattr(self, 'colorize'):
|
||||||
|
self.register_buffer(
|
||||||
|
'colorize', torch.randn(3, x.shape[1], 1, 1).to(x)
|
||||||
|
)
|
||||||
|
x = F.conv2d(x, weight=self.colorize)
|
||||||
|
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class IdentityFirstStage(torch.nn.Module):
|
||||||
|
def __init__(self, *args, vq_interface=False, **kwargs):
|
||||||
|
self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def encode(self, x, *args, **kwargs):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode(self, x, *args, **kwargs):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def quantize(self, x, *args, **kwargs):
|
||||||
|
if self.vq_interface:
|
||||||
|
return x, None, [None, None, None]
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x, *args, **kwargs):
|
||||||
|
return x
|
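The AutoencoderKL defined above is the KL-regularized VAE used as the Stable Diffusion first stage: encode() returns a DiagonalGaussianDistribution over latents and decode() maps a latent back to image space. A minimal round-trip sketch follows; the ddconfig values are assumed from the standard SD 1.x VAE configuration (they are not spelled out in this file) and the loss is stubbed out with torch.nn.Identity, so this illustrates the call pattern rather than a trainable setup:

    import torch
    from invokeai.models.autoencoder import AutoencoderKL

    ddconfig = dict(                    # assumed standard SD VAE settings
        double_z=True, z_channels=4, resolution=256, in_channels=3, out_ch=3,
        ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2, attn_resolutions=[], dropout=0.0,
    )
    vae = AutoencoderKL(
        ddconfig=ddconfig,
        lossconfig={'target': 'torch.nn.Identity'},  # stub; a real config would name an LPIPS/GAN loss
        embed_dim=4,
    )

    x = torch.randn(1, 3, 256, 256)     # dummy image batch in [-1, 1]
    posterior = vae.encode(x)           # DiagonalGaussianDistribution
    z = posterior.sample()              # (1, 4, 32, 32) latent
    recon = vae.decode(z)               # (1, 3, 256, 256) reconstruction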
invokeai/models/diffusion/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
'''
Initialization file for invokeai.models.diffusion
'''
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
invokeai/models/diffusion/__init__.py~ (new file, 4 lines)
@@ -0,0 +1,4 @@
'''
Initialization file for invokeai.models.diffusion
'''
from shared_invokeai_diffusion import InvokeAIDiffuserComponent
BIN invokeai/models/diffusion/__pycache__/__init__.cpython-310.pyc (new file; binary file not shown)
Binary file not shown.
Binary file not shown.
BIN invokeai/models/diffusion/__pycache__/ddim.cpython-310.pyc (new file; binary file not shown)
BIN invokeai/models/diffusion/__pycache__/ddpm.cpython-310.pyc (new file; binary file not shown)
BIN invokeai/models/diffusion/__pycache__/ksampler.cpython-310.pyc (new file; binary file not shown)
BIN invokeai/models/diffusion/__pycache__/plms.cpython-310.pyc (new file; binary file not shown)
BIN invokeai/models/diffusion/__pycache__/sampler.cpython-310.pyc (new file; binary file not shown)
Binary file not shown.
invokeai/models/diffusion/classifier.py (new file, 355 lines)
@@ -0,0 +1,355 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torch.optim import AdamW
|
||||||
|
from torch.optim.lr_scheduler import LambdaLR
|
||||||
|
from copy import deepcopy
|
||||||
|
from einops import rearrange
|
||||||
|
from glob import glob
|
||||||
|
from natsort import natsorted
|
||||||
|
|
||||||
|
from ldm.modules.diffusionmodules.openaimodel import (
|
||||||
|
EncoderUNetModel,
|
||||||
|
UNetModel,
|
||||||
|
)
|
||||||
|
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
|
||||||
|
|
||||||
|
__models__ = {'class_label': EncoderUNetModel, 'segmentation': UNetModel}
|
||||||
|
|
||||||
|
|
||||||
|
def disabled_train(self, mode=True):
|
||||||
|
"""Overwrite model.train with this function to make sure train/eval mode
|
||||||
|
does not change anymore."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class NoisyLatentImageClassifier(pl.LightningModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
diffusion_path,
|
||||||
|
num_classes,
|
||||||
|
ckpt_path=None,
|
||||||
|
pool='attention',
|
||||||
|
label_key=None,
|
||||||
|
diffusion_ckpt_path=None,
|
||||||
|
scheduler_config=None,
|
||||||
|
weight_decay=1.0e-2,
|
||||||
|
log_steps=10,
|
||||||
|
monitor='val/loss',
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.num_classes = num_classes
|
||||||
|
# get latest config of diffusion model
|
||||||
|
diffusion_config = natsorted(
|
||||||
|
glob(os.path.join(diffusion_path, 'configs', '*-project.yaml'))
|
||||||
|
)[-1]
|
||||||
|
self.diffusion_config = OmegaConf.load(diffusion_config).model
|
||||||
|
self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
|
||||||
|
self.load_diffusion()
|
||||||
|
|
||||||
|
self.monitor = monitor
|
||||||
|
self.numd = (
|
||||||
|
self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
|
||||||
|
)
|
||||||
|
self.log_time_interval = (
|
||||||
|
self.diffusion_model.num_timesteps // log_steps
|
||||||
|
)
|
||||||
|
self.log_steps = log_steps
|
||||||
|
|
||||||
|
self.label_key = (
|
||||||
|
label_key
|
||||||
|
if not hasattr(self.diffusion_model, 'cond_stage_key')
|
||||||
|
else self.diffusion_model.cond_stage_key
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
self.label_key is not None
|
||||||
|
), 'label_key neither in diffusion model nor in model.params'
|
||||||
|
|
||||||
|
if self.label_key not in __models__:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
self.load_classifier(ckpt_path, pool)
|
||||||
|
|
||||||
|
self.scheduler_config = scheduler_config
|
||||||
|
self.use_scheduler = self.scheduler_config is not None
|
||||||
|
self.weight_decay = weight_decay
|
||||||
|
|
||||||
|
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
||||||
|
sd = torch.load(path, map_location='cpu')
|
||||||
|
if 'state_dict' in list(sd.keys()):
|
||||||
|
sd = sd['state_dict']
|
||||||
|
keys = list(sd.keys())
|
||||||
|
for k in keys:
|
||||||
|
for ik in ignore_keys:
|
||||||
|
if k.startswith(ik):
|
||||||
|
print('Deleting key {} from state_dict.'.format(k))
|
||||||
|
del sd[k]
|
||||||
|
missing, unexpected = (
|
||||||
|
self.load_state_dict(sd, strict=False)
|
||||||
|
if not only_model
|
||||||
|
else self.model.load_state_dict(sd, strict=False)
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
|
||||||
|
)
|
||||||
|
if len(missing) > 0:
|
||||||
|
print(f'Missing Keys: {missing}')
|
||||||
|
if len(unexpected) > 0:
|
||||||
|
print(f'Unexpected Keys: {unexpected}')
|
||||||
|
|
||||||
|
def load_diffusion(self):
|
||||||
|
model = instantiate_from_config(self.diffusion_config)
|
||||||
|
self.diffusion_model = model.eval()
|
||||||
|
self.diffusion_model.train = disabled_train
|
||||||
|
for param in self.diffusion_model.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
def load_classifier(self, ckpt_path, pool):
|
||||||
|
model_config = deepcopy(
|
||||||
|
self.diffusion_config.params.unet_config.params
|
||||||
|
)
|
||||||
|
model_config.in_channels = (
|
||||||
|
self.diffusion_config.params.unet_config.params.out_channels
|
||||||
|
)
|
||||||
|
model_config.out_channels = self.num_classes
|
||||||
|
if self.label_key == 'class_label':
|
||||||
|
model_config.pool = pool
|
||||||
|
|
||||||
|
self.model = __models__[self.label_key](**model_config)
|
||||||
|
if ckpt_path is not None:
|
||||||
|
print(
|
||||||
|
'#####################################################################'
|
||||||
|
)
|
||||||
|
print(f'load from ckpt "{ckpt_path}"')
|
||||||
|
print(
|
||||||
|
'#####################################################################'
|
||||||
|
)
|
||||||
|
self.init_from_ckpt(ckpt_path)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_x_noisy(self, x, t, noise=None):
|
||||||
|
noise = default(noise, lambda: torch.randn_like(x))
|
||||||
|
continuous_sqrt_alpha_cumprod = None
|
||||||
|
if self.diffusion_model.use_continuous_noise:
|
||||||
|
continuous_sqrt_alpha_cumprod = (
|
||||||
|
self.diffusion_model.sample_continuous_noise_level(
|
||||||
|
x.shape[0], t + 1
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# todo: make sure t+1 is correct here
|
||||||
|
|
||||||
|
return self.diffusion_model.q_sample(
|
||||||
|
x_start=x,
|
||||||
|
t=t,
|
||||||
|
noise=noise,
|
||||||
|
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x_noisy, t, *args, **kwargs):
|
||||||
|
return self.model(x_noisy, t)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_input(self, batch, k):
|
||||||
|
x = batch[k]
|
||||||
|
if len(x.shape) == 3:
|
||||||
|
x = x[..., None]
|
||||||
|
x = rearrange(x, 'b h w c -> b c h w')
|
||||||
|
x = x.to(memory_format=torch.contiguous_format).float()
|
||||||
|
return x
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_conditioning(self, batch, k=None):
|
||||||
|
if k is None:
|
||||||
|
k = self.label_key
|
||||||
|
assert k is not None, 'Needs to provide label key'
|
||||||
|
|
||||||
|
targets = batch[k].to(self.device)
|
||||||
|
|
||||||
|
if self.label_key == 'segmentation':
|
||||||
|
targets = rearrange(targets, 'b h w c -> b c h w')
|
||||||
|
for down in range(self.numd):
|
||||||
|
h, w = targets.shape[-2:]
|
||||||
|
targets = F.interpolate(
|
||||||
|
targets, size=(h // 2, w // 2), mode='nearest'
|
||||||
|
)
|
||||||
|
|
||||||
|
# targets = rearrange(targets,'b c h w -> b h w c')
|
||||||
|
|
||||||
|
return targets
|
||||||
|
|
||||||
|
def compute_top_k(self, logits, labels, k, reduction='mean'):
|
||||||
|
_, top_ks = torch.topk(logits, k, dim=1)
|
||||||
|
if reduction == 'mean':
|
||||||
|
return (
|
||||||
|
(top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
|
||||||
|
)
|
||||||
|
elif reduction == 'none':
|
||||||
|
return (top_ks == labels[:, None]).float().sum(dim=-1)
|
||||||
|
|
||||||
|
def on_train_epoch_start(self):
|
||||||
|
# save some memory
|
||||||
|
self.diffusion_model.model.to('cpu')
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def write_logs(self, loss, logits, targets):
|
||||||
|
log_prefix = 'train' if self.training else 'val'
|
||||||
|
log = {}
|
||||||
|
log[f'{log_prefix}/loss'] = loss.mean()
|
||||||
|
log[f'{log_prefix}/acc@1'] = self.compute_top_k(
|
||||||
|
logits, targets, k=1, reduction='mean'
|
||||||
|
)
|
||||||
|
log[f'{log_prefix}/acc@5'] = self.compute_top_k(
|
||||||
|
logits, targets, k=5, reduction='mean'
|
||||||
|
)
|
||||||
|
|
||||||
|
self.log_dict(
|
||||||
|
log,
|
||||||
|
prog_bar=False,
|
||||||
|
logger=True,
|
||||||
|
on_step=self.training,
|
||||||
|
on_epoch=True,
|
||||||
|
)
|
||||||
|
self.log(
|
||||||
|
'loss', log[f'{log_prefix}/loss'], prog_bar=True, logger=False
|
||||||
|
)
|
||||||
|
self.log(
|
||||||
|
'global_step',
|
||||||
|
self.global_step,
|
||||||
|
logger=False,
|
||||||
|
on_epoch=False,
|
||||||
|
prog_bar=True,
|
||||||
|
)
|
||||||
|
lr = self.optimizers().param_groups[0]['lr']
|
||||||
|
self.log(
|
||||||
|
'lr_abs',
|
||||||
|
lr,
|
||||||
|
on_step=True,
|
||||||
|
logger=True,
|
||||||
|
on_epoch=False,
|
||||||
|
prog_bar=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def shared_step(self, batch, t=None):
|
||||||
|
x, *_ = self.diffusion_model.get_input(
|
||||||
|
batch, k=self.diffusion_model.first_stage_key
|
||||||
|
)
|
||||||
|
targets = self.get_conditioning(batch)
|
||||||
|
if targets.dim() == 4:
|
||||||
|
targets = targets.argmax(dim=1)
|
||||||
|
if t is None:
|
||||||
|
t = torch.randint(
|
||||||
|
0,
|
||||||
|
self.diffusion_model.num_timesteps,
|
||||||
|
(x.shape[0],),
|
||||||
|
device=self.device,
|
||||||
|
).long()
|
||||||
|
else:
|
||||||
|
t = torch.full(
|
||||||
|
size=(x.shape[0],), fill_value=t, device=self.device
|
||||||
|
).long()
|
||||||
|
x_noisy = self.get_x_noisy(x, t)
|
||||||
|
logits = self(x_noisy, t)
|
||||||
|
|
||||||
|
loss = F.cross_entropy(logits, targets, reduction='none')
|
||||||
|
|
||||||
|
self.write_logs(loss.detach(), logits.detach(), targets.detach())
|
||||||
|
|
||||||
|
loss = loss.mean()
|
||||||
|
return loss, logits, x_noisy, targets
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
loss, *_ = self.shared_step(batch)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def reset_noise_accs(self):
|
||||||
|
self.noisy_acc = {
|
||||||
|
t: {'acc@1': [], 'acc@5': []}
|
||||||
|
for t in range(
|
||||||
|
0,
|
||||||
|
self.diffusion_model.num_timesteps,
|
||||||
|
self.diffusion_model.log_every_t,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
def on_validation_start(self):
|
||||||
|
self.reset_noise_accs()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
loss, *_ = self.shared_step(batch)
|
||||||
|
|
||||||
|
for t in self.noisy_acc:
|
||||||
|
_, logits, _, targets = self.shared_step(batch, t)
|
||||||
|
self.noisy_acc[t]['acc@1'].append(
|
||||||
|
self.compute_top_k(logits, targets, k=1, reduction='mean')
|
||||||
|
)
|
||||||
|
self.noisy_acc[t]['acc@5'].append(
|
||||||
|
self.compute_top_k(logits, targets, k=5, reduction='mean')
|
||||||
|
)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
optimizer = AdamW(
|
||||||
|
self.model.parameters(),
|
||||||
|
lr=self.learning_rate,
|
||||||
|
weight_decay=self.weight_decay,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.use_scheduler:
|
||||||
|
scheduler = instantiate_from_config(self.scheduler_config)
|
||||||
|
|
||||||
|
print('Setting up LambdaLR scheduler...')
|
||||||
|
scheduler = [
|
||||||
|
{
|
||||||
|
'scheduler': LambdaLR(
|
||||||
|
optimizer, lr_lambda=scheduler.schedule
|
||||||
|
),
|
||||||
|
'interval': 'step',
|
||||||
|
'frequency': 1,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
return [optimizer], scheduler
|
||||||
|
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def log_images(self, batch, N=8, *args, **kwargs):
|
||||||
|
log = dict()
|
||||||
|
x = self.get_input(batch, self.diffusion_model.first_stage_key)
|
||||||
|
log['inputs'] = x
|
||||||
|
|
||||||
|
y = self.get_conditioning(batch)
|
||||||
|
|
||||||
|
if self.label_key == 'class_label':
|
||||||
|
y = log_txt_as_img((x.shape[2], x.shape[3]), batch['human_label'])
|
||||||
|
log['labels'] = y
|
||||||
|
|
||||||
|
if ismap(y):
|
||||||
|
log['labels'] = self.diffusion_model.to_rgb(y)
|
||||||
|
|
||||||
|
for step in range(self.log_steps):
|
||||||
|
current_time = step * self.log_time_interval
|
||||||
|
|
||||||
|
_, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
|
||||||
|
|
||||||
|
log[f'inputs@t{current_time}'] = x_noisy
|
||||||
|
|
||||||
|
pred = F.one_hot(
|
||||||
|
logits.argmax(dim=1), num_classes=self.num_classes
|
||||||
|
)
|
||||||
|
pred = rearrange(pred, 'b h w c -> b c h w')
|
||||||
|
|
||||||
|
log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(
|
||||||
|
pred
|
||||||
|
)
|
||||||
|
|
||||||
|
for key in log:
|
||||||
|
log[key] = log[key][:N]
|
||||||
|
|
||||||
|
return log
|
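The compute_top_k helper above reduces to a standard top-k accuracy computation. A small self-contained sketch of the same logic, with made-up logits and labels purely for illustration:

    import torch

    logits = torch.tensor([[0.1, 0.7, 0.2],   # best guess: class 1
                           [0.5, 0.1, 0.4]])  # best guess: class 0, runner-up: class 2
    labels = torch.tensor([1, 2])

    def top_k_acc(logits, labels, k):
        _, top_ks = torch.topk(logits, k, dim=1)
        return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()

    print(top_k_acc(logits, labels, k=1))   # 0.5  (only the first sample is right at k=1)
    print(top_k_acc(logits, labels, k=2))   # 1.0  (both labels fall inside the top-2)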
invokeai/models/diffusion/cross_attention_control.py (new file, 642 lines)
@@ -0,0 +1,642 @@
|
|||||||
|
|
||||||
|
# adapted from bloc97's CrossAttentionControl colab
|
||||||
|
# https://github.com/bloc97/CrossAttentionControl
|
||||||
|
|
||||||
|
|
||||||
|
import enum
|
||||||
|
import math
|
||||||
|
from typing import Optional, Callable
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
import torch
|
||||||
|
import diffusers
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from compel.cross_attention_control import Arguments
|
||||||
|
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||||
|
from diffusers.models.cross_attention import AttnProcessor
|
||||||
|
from ldm.invoke.devices import torch_dtype
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttentionType(enum.Enum):
|
||||||
|
SELF = 1
|
||||||
|
TOKENS = 2
|
||||||
|
|
||||||
|
|
||||||
|
class Context:
|
||||||
|
|
||||||
|
cross_attention_mask: Optional[torch.Tensor]
|
||||||
|
cross_attention_index_map: Optional[torch.Tensor]
|
||||||
|
|
||||||
|
class Action(enum.Enum):
|
||||||
|
NONE = 0
|
||||||
|
SAVE = 1,
|
||||||
|
APPLY = 2
|
||||||
|
|
||||||
|
def __init__(self, arguments: Arguments, step_count: int):
|
||||||
|
"""
|
||||||
|
:param arguments: Arguments for the cross-attention control process
|
||||||
|
:param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
|
||||||
|
"""
|
||||||
|
self.cross_attention_mask = None
|
||||||
|
self.cross_attention_index_map = None
|
||||||
|
self.self_cross_attention_action = Context.Action.NONE
|
||||||
|
self.tokens_cross_attention_action = Context.Action.NONE
|
||||||
|
self.arguments = arguments
|
||||||
|
self.step_count = step_count
|
||||||
|
|
||||||
|
self.self_cross_attention_module_identifiers = []
|
||||||
|
self.tokens_cross_attention_module_identifiers = []
|
||||||
|
|
||||||
|
self.saved_cross_attention_maps = {}
|
||||||
|
|
||||||
|
self.clear_requests(cleanup=True)
|
||||||
|
|
||||||
|
def register_cross_attention_modules(self, model):
|
||||||
|
for name,module in get_cross_attention_modules(model, CrossAttentionType.SELF):
|
||||||
|
if name in self.self_cross_attention_module_identifiers:
|
||||||
|
assert False, f"name {name} cannot appear more than once"
|
||||||
|
self.self_cross_attention_module_identifiers.append(name)
|
||||||
|
for name,module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
|
||||||
|
if name in self.tokens_cross_attention_module_identifiers:
|
||||||
|
assert False, f"name {name} cannot appear more than once"
|
||||||
|
self.tokens_cross_attention_module_identifiers.append(name)
|
||||||
|
|
||||||
|
def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
|
||||||
|
if cross_attention_type == CrossAttentionType.SELF:
|
||||||
|
self.self_cross_attention_action = Context.Action.SAVE
|
||||||
|
else:
|
||||||
|
self.tokens_cross_attention_action = Context.Action.SAVE
|
||||||
|
|
||||||
|
def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType):
|
||||||
|
if cross_attention_type == CrossAttentionType.SELF:
|
||||||
|
self.self_cross_attention_action = Context.Action.APPLY
|
||||||
|
else:
|
||||||
|
self.tokens_cross_attention_action = Context.Action.APPLY
|
||||||
|
|
||||||
|
def is_tokens_cross_attention(self, module_identifier) -> bool:
|
||||||
|
return module_identifier in self.tokens_cross_attention_module_identifiers
|
||||||
|
|
||||||
|
def get_should_save_maps(self, module_identifier: str) -> bool:
|
||||||
|
if module_identifier in self.self_cross_attention_module_identifiers:
|
||||||
|
return self.self_cross_attention_action == Context.Action.SAVE
|
||||||
|
elif module_identifier in self.tokens_cross_attention_module_identifiers:
|
||||||
|
return self.tokens_cross_attention_action == Context.Action.SAVE
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
|
||||||
|
if module_identifier in self.self_cross_attention_module_identifiers:
|
||||||
|
return self.self_cross_attention_action == Context.Action.APPLY
|
||||||
|
elif module_identifier in self.tokens_cross_attention_module_identifiers:
|
||||||
|
return self.tokens_cross_attention_action == Context.Action.APPLY
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\
|
||||||
|
-> list[CrossAttentionType]:
|
||||||
|
"""
|
||||||
|
Should cross-attention control be applied on the given step?
|
||||||
|
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
|
||||||
|
:return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
|
||||||
|
"""
|
||||||
|
if percent_through is None:
|
||||||
|
return [CrossAttentionType.SELF, CrossAttentionType.TOKENS]
|
||||||
|
|
||||||
|
opts = self.arguments.edit_options
|
||||||
|
to_control = []
|
||||||
|
if opts['s_start'] <= percent_through < opts['s_end']:
|
||||||
|
to_control.append(CrossAttentionType.SELF)
|
||||||
|
if opts['t_start'] <= percent_through < opts['t_end']:
|
||||||
|
to_control.append(CrossAttentionType.TOKENS)
|
||||||
|
return to_control
|
||||||
|
|
||||||
|
def save_slice(self, identifier: str, slice: torch.Tensor, dim: Optional[int], offset: int,
|
||||||
|
slice_size: Optional[int]):
|
||||||
|
if identifier not in self.saved_cross_attention_maps:
|
||||||
|
self.saved_cross_attention_maps[identifier] = {
|
||||||
|
'dim': dim,
|
||||||
|
'slice_size': slice_size,
|
||||||
|
'slices': {offset or 0: slice}
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
self.saved_cross_attention_maps[identifier]['slices'][offset or 0] = slice
|
||||||
|
|
||||||
|
def get_slice(self, identifier: str, requested_dim: Optional[int], requested_offset: int, slice_size: int):
|
||||||
|
saved_attention_dict = self.saved_cross_attention_maps[identifier]
|
||||||
|
if requested_dim is None:
|
||||||
|
if saved_attention_dict['dim'] is not None:
|
||||||
|
raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}")
|
||||||
|
return saved_attention_dict['slices'][0]
|
||||||
|
|
||||||
|
if saved_attention_dict['dim'] == requested_dim:
|
||||||
|
if slice_size != saved_attention_dict['slice_size']:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}")
|
||||||
|
return saved_attention_dict['slices'][requested_offset]
|
||||||
|
|
||||||
|
if saved_attention_dict['dim'] is None:
|
||||||
|
whole_saved_attention = saved_attention_dict['slices'][0]
|
||||||
|
if requested_dim == 0:
|
||||||
|
return whole_saved_attention[requested_offset:requested_offset + slice_size]
|
||||||
|
elif requested_dim == 1:
|
||||||
|
return whole_saved_attention[:, requested_offset:requested_offset + slice_size]
|
||||||
|
|
||||||
|
raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
|
||||||
|
|
||||||
|
def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
|
||||||
|
saved_attention = self.saved_cross_attention_maps.get(identifier, None)
|
||||||
|
if saved_attention is None:
|
||||||
|
return None, None
|
||||||
|
return saved_attention['dim'], saved_attention['slice_size']
|
||||||
|
|
||||||
|
def clear_requests(self, cleanup=True):
|
||||||
|
self.tokens_cross_attention_action = Context.Action.NONE
|
||||||
|
self.self_cross_attention_action = Context.Action.NONE
|
||||||
|
if cleanup:
|
||||||
|
self.saved_cross_attention_maps = {}
|
||||||
|
|
||||||
|
def offload_saved_attention_slices_to_cpu(self):
|
||||||
|
for key, map_dict in self.saved_cross_attention_maps.items():
|
||||||
|
for offset, slice in map_dict['slices'].items():
|
||||||
|
map_dict[offset] = slice.to('cpu')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class InvokeAICrossAttentionMixin:
|
||||||
|
"""
|
||||||
|
Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
|
||||||
|
through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
|
||||||
|
and dynamic slicing strategy selection.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||||
|
self.attention_slice_wrangler = None
|
||||||
|
self.slicing_strategy_getter = None
|
||||||
|
self.attention_slice_calculated_callback = None
|
||||||
|
|
||||||
|
def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
|
||||||
|
'''
|
||||||
|
Set custom attention calculator to be called when attention is calculated
|
||||||
|
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
|
||||||
|
which returns either the suggested_attention_slice or an adjusted equivalent.
|
||||||
|
`module` is the current CrossAttention module for which the callback is being invoked.
|
||||||
|
`suggested_attention_slice` is the default-calculated attention slice
|
||||||
|
`dim` is -1 if the attention map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
|
||||||
|
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
|
||||||
|
|
||||||
|
Pass None to use the default attention calculation.
|
||||||
|
:return:
|
||||||
|
'''
|
||||||
|
self.attention_slice_wrangler = wrangler
|
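# Illustrative only (not part of the original file): a minimal wrangler matching the
# signature documented above. It receives the already-softmaxed attention slice and must
# return a tensor of the same shape; this one passes the suggested slice through unchanged
# while printing its shape, which is a safe no-op starting point for experimentation.
#
#   def logging_wrangler(module, suggested_attention_slice, dim, offset, slice_size):
#       print(tuple(suggested_attention_slice.shape), dim, offset, slice_size)
#       return suggested_attention_slice
#
#   some_cross_attention_module.set_attention_slice_wrangler(logging_wrangler)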
||||||
|
|
||||||
|
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
|
||||||
|
self.slicing_strategy_getter = getter
|
||||||
|
|
||||||
|
def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]):
|
||||||
|
self.attention_slice_calculated_callback = callback
|
||||||
|
|
||||||
|
def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
|
||||||
|
# calculate attention scores
|
||||||
|
#attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
|
||||||
|
attention_scores = torch.baddbmm(
|
||||||
|
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
||||||
|
query,
|
||||||
|
key.transpose(-1, -2),
|
||||||
|
beta=0,
|
||||||
|
alpha=self.scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
# calculate attention slice by taking the best scores for each latent pixel
|
||||||
|
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
|
||||||
|
attention_slice_wrangler = self.attention_slice_wrangler
|
||||||
|
if attention_slice_wrangler is not None:
|
||||||
|
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
|
||||||
|
else:
|
||||||
|
attention_slice = default_attention_slice
|
||||||
|
|
||||||
|
if self.attention_slice_calculated_callback is not None:
|
||||||
|
self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)
|
||||||
|
|
||||||
|
hidden_states = torch.bmm(attention_slice, value)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def einsum_op_slice_dim0(self, q, k, v, slice_size):
|
||||||
|
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||||
|
for i in range(0, q.shape[0], slice_size):
|
||||||
|
end = i + slice_size
|
||||||
|
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
|
||||||
|
return r
|
||||||
|
|
||||||
|
def einsum_op_slice_dim1(self, q, k, v, slice_size):
|
||||||
|
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||||
|
for i in range(0, q.shape[1], slice_size):
|
||||||
|
end = i + slice_size
|
||||||
|
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
|
||||||
|
return r
|
||||||
|
|
||||||
|
def einsum_op_mps_v1(self, q, k, v):
|
||||||
|
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
||||||
|
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||||
|
else:
|
||||||
|
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||||
|
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||||
|
|
||||||
|
def einsum_op_mps_v2(self, q, k, v):
|
||||||
|
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
|
||||||
|
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||||
|
else:
|
||||||
|
return self.einsum_op_slice_dim0(q, k, v, 1)
|
||||||
|
|
||||||
|
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
|
||||||
|
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
||||||
|
if size_mb <= max_tensor_mb:
|
||||||
|
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||||
|
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
|
||||||
|
if div <= q.shape[0]:
|
||||||
|
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
|
||||||
|
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
|
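# Worked example (illustrative, not part of the original file): for fp16 tensors with
# q.shape == (16, 4096, 64) and k.shape == (16, 4096, 64), the attention matrix needs
# 16 * 4096 * 4096 * 2 bytes = 512 MB. With max_tensor_mb = 32 this gives
# div = 1 << int((512 - 1) / 32).bit_length() = 1 << 4 = 16; since div <= q.shape[0],
# the batch/heads dimension is sliced into chunks of q.shape[0] // div = 1 at a time.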
||||||
|
|
||||||
|
def einsum_op_cuda(self, q, k, v):
|
||||||
|
# check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
|
||||||
|
slicing_strategy_getter = self.slicing_strategy_getter
|
||||||
|
if slicing_strategy_getter is not None:
|
||||||
|
(dim, slice_size) = slicing_strategy_getter(self)
|
||||||
|
if dim is not None:
|
||||||
|
# print("using saved slicing strategy with dim", dim, "slice size", slice_size)
|
||||||
|
if dim == 0:
|
||||||
|
return self.einsum_op_slice_dim0(q, k, v, slice_size)
|
||||||
|
elif dim == 1:
|
||||||
|
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||||
|
|
||||||
|
# fallback for when there is no saved strategy, or saved strategy does not slice
|
||||||
|
mem_free_total = get_mem_free_total(q.device)
|
||||||
|
# Divide by a safety factor, since there is copying and fragmentation
|
||||||
|
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||||
|
|
||||||
|
|
||||||
|
def get_invokeai_attention_mem_efficient(self, q, k, v):
|
||||||
|
if q.device.type == 'cuda':
|
||||||
|
#print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
|
||||||
|
return self.einsum_op_cuda(q, k, v)
|
||||||
|
|
||||||
|
if q.device.type == 'mps' or q.device.type == 'cpu':
|
||||||
|
if self.mem_total_gb >= 32:
|
||||||
|
return self.einsum_op_mps_v1(q, k, v)
|
||||||
|
return self.einsum_op_mps_v2(q, k, v)
|
||||||
|
|
||||||
|
# Smaller slices are faster due to L2/L3/SLC caches.
|
||||||
|
# Tested on i7 with 8MB L3 cache.
|
||||||
|
return self.einsum_op_tensor_mem(q, k, v, 32)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def restore_default_cross_attention(model, is_running_diffusers: bool, restore_attention_processor: Optional[AttnProcessor]=None):
|
||||||
|
if is_running_diffusers:
|
||||||
|
unet = model
|
||||||
|
unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor())
|
||||||
|
else:
|
||||||
|
remove_attention_function(model)
|
||||||
|
|
||||||
|
|
||||||
|
def override_cross_attention(model, context: Context, is_running_diffusers = False):
|
||||||
|
"""
|
||||||
|
Inject attention parameters and functions into the passed in model to enable cross attention editing.
|
||||||
|
|
||||||
|
:param model: The unet model to inject into.
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
|
||||||
|
# adapted from init_attention_edit
|
||||||
|
device = context.arguments.edited_conditioning.device
|
||||||
|
|
||||||
|
# urgh. should this be hardcoded?
|
||||||
|
max_length = 77
|
||||||
|
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
|
||||||
|
mask = torch.zeros(max_length, dtype=torch_dtype(device))
|
||||||
|
indices_target = torch.arange(max_length, dtype=torch.long)
|
||||||
|
indices = torch.arange(max_length, dtype=torch.long)
|
||||||
|
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
|
||||||
|
if b0 < max_length:
|
||||||
|
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
|
||||||
|
# these tokens have not been edited
|
||||||
|
indices[b0:b1] = indices_target[a0:a1]
|
||||||
|
mask[b0:b1] = 1
|
||||||
|
|
||||||
|
context.cross_attention_mask = mask.to(device)
|
||||||
|
context.cross_attention_index_map = indices.to(device)
|
||||||
|
if is_running_diffusers:
|
||||||
|
unet = model
|
||||||
|
old_attn_processors = unet.attn_processors
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
|
||||||
|
unet.set_attn_processor(SwapCrossAttnProcessor())
|
||||||
|
else:
|
||||||
|
# try to re-use an existing slice size
|
||||||
|
default_slice_size = 4
|
||||||
|
slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
|
||||||
|
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
|
||||||
|
return old_attn_processors
|
||||||
|
else:
|
||||||
|
context.register_cross_attention_modules(model)
|
||||||
|
inject_attention_function(model, context)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
|
||||||
|
from ldm.modules.attention import CrossAttention # avoid circular import
|
||||||
|
cross_attention_class: type = InvokeAIDiffusersCrossAttention if isinstance(model,UNet2DConditionModel) else CrossAttention
|
||||||
|
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
|
||||||
|
attention_module_tuples = [(name,module) for name, module in model.named_modules() if
|
||||||
|
isinstance(module, cross_attention_class) and which_attn in name]
|
||||||
|
cross_attention_modules_in_model_count = len(attention_module_tuples)
|
||||||
|
expected_count = 16
|
||||||
|
if cross_attention_modules_in_model_count != expected_count:
|
||||||
|
# non-fatal error but .swap() won't work.
|
||||||
|
print(f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model " +
|
||||||
|
f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed " +
|
||||||
|
f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, " +
|
||||||
|
f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows " +
|
||||||
|
f"what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not " +
|
||||||
|
f"work properly until it is fixed.")
|
||||||
|
return attention_module_tuples
|
||||||
|
|
||||||
|
|
||||||
|
def inject_attention_function(unet, context: Context):
|
||||||
|
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
|
||||||
|
|
||||||
|
def attention_slice_wrangler(module, suggested_attention_slice:torch.Tensor, dim, offset, slice_size):
|
||||||
|
|
||||||
|
#memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()
|
||||||
|
|
||||||
|
attention_slice = suggested_attention_slice
|
||||||
|
|
||||||
|
if context.get_should_save_maps(module.identifier):
|
||||||
|
#print(module.identifier, "saving suggested_attention_slice of shape",
|
||||||
|
# suggested_attention_slice.shape, "dim", dim, "offset", offset)
|
||||||
|
slice_to_save = attention_slice.to('cpu') if dim is not None else attention_slice
|
||||||
|
context.save_slice(module.identifier, slice_to_save, dim=dim, offset=offset, slice_size=slice_size)
|
||||||
|
elif context.get_should_apply_saved_maps(module.identifier):
|
||||||
|
#print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
|
||||||
|
saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)
|
||||||
|
|
||||||
|
# slice may have been offloaded to CPU
|
||||||
|
saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)
|
||||||
|
|
||||||
|
if context.is_tokens_cross_attention(module.identifier):
|
||||||
|
index_map = context.cross_attention_index_map
|
||||||
|
remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
|
||||||
|
this_attention_slice = suggested_attention_slice
|
||||||
|
|
||||||
|
mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device))
|
||||||
|
saved_mask = mask
|
||||||
|
this_mask = 1 - mask
|
||||||
|
attention_slice = remapped_saved_attention_slice * saved_mask + \
|
||||||
|
this_attention_slice * this_mask
|
||||||
|
else:
|
||||||
|
# just use everything
|
||||||
|
attention_slice = saved_attention_slice
|
||||||
|
|
||||||
|
return attention_slice
|
||||||
|
|
||||||
|
cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
|
||||||
|
for identifier, module in cross_attention_modules:
|
||||||
|
module.identifier = identifier
|
||||||
|
try:
|
||||||
|
module.set_attention_slice_wrangler(attention_slice_wrangler)
|
||||||
|
module.set_slicing_strategy_getter(
|
||||||
|
lambda module: context.get_slicing_strategy(identifier)
|
||||||
|
)
|
||||||
|
except AttributeError as e:
|
||||||
|
if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
|
||||||
|
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def remove_attention_function(unet):
|
||||||
|
cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
|
||||||
|
for identifier, module in cross_attention_modules:
|
||||||
|
try:
|
||||||
|
# clear wrangler callback
|
||||||
|
module.set_attention_slice_wrangler(None)
|
||||||
|
module.set_slicing_strategy_getter(None)
|
||||||
|
except AttributeError as e:
|
||||||
|
if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
|
||||||
|
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def is_attribute_error_about(error: AttributeError, attribute: str):
|
||||||
|
if hasattr(error, 'name'): # Python 3.10
|
||||||
|
return error.name == attribute
|
||||||
|
else: # Python 3.9
|
||||||
|
return attribute in str(error)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_mem_free_total(device):
|
||||||
|
#only on cuda
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
return None
|
||||||
|
stats = torch.cuda.memory_stats(device)
|
||||||
|
mem_active = stats['active_bytes.all.current']
|
||||||
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
|
mem_free_cuda, _ = torch.cuda.mem_get_info(device)
|
||||||
|
mem_free_torch = mem_reserved - mem_active
|
||||||
|
mem_free_total = mem_free_cuda + mem_free_torch
|
||||||
|
return mem_free_total
|
||||||
|
|
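get_mem_free_total() adds the memory reported free by the CUDA driver to memory PyTorch has reserved but is not actively using. A quick illustrative check (a sketch that assumes a CUDA device; the device index is arbitrary):

import torch

if torch.cuda.is_available():
    free_bytes = get_mem_free_total(torch.device('cuda:0'))
    # rough headroom estimate used when choosing attention slice sizes
    print(f'approx. {free_bytes / 2**30:.1f} GiB available for attention slices')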
||||||
|
|
||||||
|
|
||||||
|
class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin):
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
InvokeAICrossAttentionMixin.__init__(self)
|
||||||
|
|
||||||
|
def _attention(self, query, key, value, attention_mask=None):
|
||||||
|
#default_result = super()._attention(query, key, value)
|
||||||
|
if attention_mask is not None:
|
||||||
|
print(f"{type(self).__name__} ignoring passed-in attention_mask")
|
||||||
|
attention_result = self.get_invokeai_attention_mem_efficient(query, key, value)
|
||||||
|
|
||||||
|
hidden_states = self.reshape_batch_dim_to_heads(attention_result)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## 🧨diffusers implementation follows
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
# base implementation
|
||||||
|
|
||||||
|
class CrossAttnProcessor:
|
||||||
|
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||||
|
batch_size, sequence_length, _ = hidden_states.shape
|
||||||
|
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
||||||
|
|
||||||
|
query = attn.to_q(hidden_states)
|
||||||
|
query = attn.head_to_batch_dim(query)
|
||||||
|
|
||||||
|
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
||||||
|
key = attn.to_k(encoder_hidden_states)
|
||||||
|
value = attn.to_v(encoder_hidden_states)
|
||||||
|
key = attn.head_to_batch_dim(key)
|
||||||
|
value = attn.head_to_batch_dim(value)
|
||||||
|
|
||||||
|
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||||
|
hidden_states = torch.bmm(attention_probs, value)
|
||||||
|
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||||
|
|
||||||
|
# linear proj
|
||||||
|
hidden_states = attn.to_out[0](hidden_states)
|
||||||
|
# dropout
|
||||||
|
hidden_states = attn.to_out[1](hidden_states)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
"""
|
||||||
|
from dataclasses import field, dataclass
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SwapCrossAttnContext:
|
||||||
|
modified_text_embeddings: torch.Tensor
|
||||||
|
index_map: torch.Tensor  # for each token position in the modified prompt, the index of the equivalent token in the original prompt
|
||||||
|
mask: torch.Tensor # in the target space of the index_map
|
||||||
|
cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
cac_types_to_do: list[CrossAttentionType],
|
||||||
|
modified_text_embeddings: torch.Tensor,
|
||||||
|
index_map: torch.Tensor,
|
||||||
|
mask: torch.Tensor):
|
||||||
|
self.cross_attention_types_to_do = cac_types_to_do
|
||||||
|
self.modified_text_embeddings = modified_text_embeddings
|
||||||
|
self.index_map = index_map
|
||||||
|
self.mask = mask
|
||||||
|
|
||||||
|
def wants_cross_attention_control(self, attn_type: CrossAttentionType) -> bool:
|
||||||
|
return attn_type in self.cross_attention_types_to_do
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def make_mask_and_index_map(cls, edit_opcodes: list[tuple[str, int, int, int, int]], max_length: int) \
|
||||||
|
-> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
|
# mask=1 means use original prompt attention, mask=0 means use modified prompt attention
|
||||||
|
mask = torch.zeros(max_length)
|
||||||
|
indices_target = torch.arange(max_length, dtype=torch.long)
|
||||||
|
indices = torch.arange(max_length, dtype=torch.long)
|
||||||
|
for name, a0, a1, b0, b1 in edit_opcodes:
|
||||||
|
if b0 < max_length:
|
||||||
|
if name == "equal":
|
||||||
|
# these tokens remain the same as in the original prompt
|
||||||
|
indices[b0:b1] = indices_target[a0:a1]
|
||||||
|
mask[b0:b1] = 1
|
||||||
|
|
||||||
|
return mask, indices
|
||||||
|
|
||||||
|
|
||||||
|
class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
|
||||||
|
|
||||||
|
# TODO: dynamically pick slice size based on memory conditions
|
||||||
|
|
||||||
|
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
|
||||||
|
# kwargs
|
||||||
|
swap_cross_attn_context: SwapCrossAttnContext=None):
|
||||||
|
|
||||||
|
attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
|
||||||
|
|
||||||
|
# if cross-attention control is not in play, just call through to the base implementation.
|
||||||
|
if attention_type is CrossAttentionType.SELF or \
|
||||||
|
swap_cross_attn_context is None or \
|
||||||
|
not swap_cross_attn_context.wants_cross_attention_control(attention_type):
|
||||||
|
#print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
|
||||||
|
return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
|
||||||
|
#else:
|
||||||
|
# print(f"SwapCrossAttnContext for {attention_type} active")
|
||||||
|
|
||||||
|
batch_size, sequence_length, _ = hidden_states.shape
|
||||||
|
attention_mask = attn.prepare_attention_mask(
|
||||||
|
attention_mask=attention_mask, target_length=sequence_length,
|
||||||
|
batch_size=batch_size)
|
||||||
|
|
||||||
|
query = attn.to_q(hidden_states)
|
||||||
|
dim = query.shape[-1]
|
||||||
|
query = attn.head_to_batch_dim(query)
|
||||||
|
|
||||||
|
original_text_embeddings = encoder_hidden_states
|
||||||
|
modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
|
||||||
|
original_text_key = attn.to_k(original_text_embeddings)
|
||||||
|
modified_text_key = attn.to_k(modified_text_embeddings)
|
||||||
|
original_value = attn.to_v(original_text_embeddings)
|
||||||
|
modified_value = attn.to_v(modified_text_embeddings)
|
||||||
|
|
||||||
|
original_text_key = attn.head_to_batch_dim(original_text_key)
|
||||||
|
modified_text_key = attn.head_to_batch_dim(modified_text_key)
|
||||||
|
original_value = attn.head_to_batch_dim(original_value)
|
||||||
|
modified_value = attn.head_to_batch_dim(modified_value)
|
||||||
|
|
||||||
|
# compute slices and prepare output tensor
|
||||||
|
batch_size_attention = query.shape[0]
|
||||||
|
hidden_states = torch.zeros(
|
||||||
|
(batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
# do slices
|
||||||
|
for i in range(max(1,hidden_states.shape[0] // self.slice_size)):
|
||||||
|
start_idx = i * self.slice_size
|
||||||
|
end_idx = (i + 1) * self.slice_size
|
||||||
|
|
||||||
|
query_slice = query[start_idx:end_idx]
|
||||||
|
original_key_slice = original_text_key[start_idx:end_idx]
|
||||||
|
modified_key_slice = modified_text_key[start_idx:end_idx]
|
||||||
|
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
||||||
|
|
||||||
|
original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice)
|
||||||
|
modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice)
|
||||||
|
|
||||||
|
# because the prompt modifications may result in token sequences shifted forwards or backwards,
|
||||||
|
# the original attention probabilities must be remapped to account for token index changes in the
|
||||||
|
# modified prompt
|
||||||
|
remapped_original_attn_slice = torch.index_select(original_attn_slice, -1,
|
||||||
|
swap_cross_attn_context.index_map)
|
||||||
|
|
||||||
|
# only some tokens taken from the original attention probabilities. this is controlled by the mask.
|
||||||
|
mask = swap_cross_attn_context.mask
|
||||||
|
inverse_mask = 1 - mask
|
||||||
|
attn_slice = \
|
||||||
|
remapped_original_attn_slice * mask + \
|
||||||
|
modified_attn_slice * inverse_mask
|
||||||
|
|
||||||
|
del remapped_original_attn_slice, modified_attn_slice
|
||||||
|
|
||||||
|
attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx])
|
||||||
|
hidden_states[start_idx:end_idx] = attn_slice
|
||||||
|
|
||||||
|
|
||||||
|
# done
|
||||||
|
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||||
|
|
||||||
|
# linear proj
|
||||||
|
hidden_states = attn.to_out[0](hidden_states)
|
||||||
|
# dropout
|
||||||
|
hidden_states = attn.to_out[1](hidden_states)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9)) # massive slice size = don't slice
|
||||||
|
|
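For context, make_mask_and_index_map() expects difflib-style opcodes describing how the modified prompt differs from the original. A minimal sketch of how it might be driven (the token ids and the 77-token max length are illustrative assumptions, not values taken from this code):

from difflib import SequenceMatcher

# hypothetical token id sequences for an original and an edited prompt;
# in practice these come from the text tokenizer
original_ids = [49406, 320, 2368, 525, 320, 3347, 49407]
modified_ids = [49406, 320, 1929, 525, 320, 3347, 49407]

# SequenceMatcher opcodes have the (name, a0, a1, b0, b1) shape used above
opcodes = SequenceMatcher(None, original_ids, modified_ids).get_opcodes()
mask, index_map = SwapCrossAttnContext.make_mask_and_index_map(opcodes, max_length=77)
# mask[i] == 1 where token i is covered by an "equal" opcode, so the original prompt's
# attention is kept there; everywhere else the modified prompt's attention is used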
95
invokeai/models/diffusion/cross_attention_map_saving.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
import PIL.Image
|
||||||
|
import torch
|
||||||
|
from torchvision.transforms.functional import resize as tv_resize, InterpolationMode
|
||||||
|
|
||||||
|
from .cross_attention_control import get_cross_attention_modules, CrossAttentionType
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionMapSaver():
|
||||||
|
|
||||||
|
def __init__(self, token_ids: range, latents_shape: torch.Size):
|
||||||
|
self.token_ids = token_ids
|
||||||
|
self.latents_shape = latents_shape
|
||||||
|
#self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
|
||||||
|
self.collated_maps = {}
|
||||||
|
|
||||||
|
def clear_maps(self):
|
||||||
|
self.collated_maps = {}
|
||||||
|
|
||||||
|
def add_attention_maps(self, maps: torch.Tensor, key: str):
|
||||||
|
"""
|
||||||
|
Accumulate the given attention maps and store by summing with existing maps at the passed-in key (if any).
|
||||||
|
:param maps: Attention maps to store. Expected shape [A, (H*W), N] where A is attention heads count, H and W are the map size (fixed per-key) and N is the number of tokens (typically 77).
|
||||||
|
:param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the maps sizes (H and W) should match.
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
key_and_size = f'{key}_{maps.shape[1]}'
|
||||||
|
|
||||||
|
# extract desired tokens
|
||||||
|
maps = maps[:, :, self.token_ids]
|
||||||
|
|
||||||
|
# merge attention heads to a single map per token
|
||||||
|
maps = torch.sum(maps, 0)
|
||||||
|
|
||||||
|
# store
|
||||||
|
if key_and_size not in self.collated_maps:
|
||||||
|
self.collated_maps[key_and_size] = torch.zeros_like(maps, device='cpu')
|
||||||
|
self.collated_maps[key_and_size] += maps.cpu()
|
||||||
|
|
||||||
|
def write_maps_to_disk(self, path: str):
|
||||||
|
pil_image = self.get_stacked_maps_image()
|
||||||
|
pil_image.save(path, 'PNG')
|
||||||
|
|
||||||
|
def get_stacked_maps_image(self) -> PIL.Image:
|
||||||
|
"""
|
||||||
|
Scale all collected attention maps to the same size, blend them together and return as an image.
|
||||||
|
:return: An image containing a vertical stack of blended attention maps, one for each requested token.
|
||||||
|
"""
|
||||||
|
num_tokens = len(self.token_ids)
|
||||||
|
if num_tokens == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
latents_height = self.latents_shape[0]
|
||||||
|
latents_width = self.latents_shape[1]
|
||||||
|
|
||||||
|
merged = None
|
||||||
|
|
||||||
|
for key, maps in self.collated_maps.items():
|
||||||
|
|
||||||
|
# maps has shape [(H*W), N] for N tokens
|
||||||
|
# but we want [N, H, W]
|
||||||
|
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
|
||||||
|
this_maps_height = int(float(latents_height) * this_scale_factor)
|
||||||
|
this_maps_width = int(float(latents_width) * this_scale_factor)
|
||||||
|
# and we need to do some dimension juggling
|
||||||
|
maps = torch.reshape(torch.swapdims(maps, 0, 1), [num_tokens, this_maps_height, this_maps_width])
|
||||||
|
|
||||||
|
# scale to output size if necessary
|
||||||
|
if this_scale_factor != 1:
|
||||||
|
maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)
|
||||||
|
|
||||||
|
# normalize
|
||||||
|
maps_min = torch.min(maps)
|
||||||
|
maps_range = torch.max(maps) - maps_min
|
||||||
|
#print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
|
||||||
|
maps_normalized = (maps - maps_min) / maps_range
|
||||||
|
# expand slightly to (-0.05, 1.05), then clamp back to (0, 1)
|
||||||
|
maps_normalized_expanded = maps_normalized * 1.1 - 0.05
|
||||||
|
maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)
|
||||||
|
|
||||||
|
# merge together, producing a vertical stack
|
||||||
|
maps_stacked = torch.reshape(maps_normalized_expanded_clamped, [num_tokens * latents_height, latents_width])
|
||||||
|
|
||||||
|
if merged is None:
|
||||||
|
merged = maps_stacked
|
||||||
|
else:
|
||||||
|
# screen blend
|
||||||
|
merged = 1 - (1 - maps_stacked)*(1 - merged)
|
||||||
|
|
||||||
|
if merged is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
merged_bytes = merged.mul(0xff).byte()
|
||||||
|
return PIL.Image.fromarray(merged_bytes.numpy(), mode='L')
|
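A rough usage sketch for AttentionMapSaver (the head count, map size, and token range below are invented for illustration; in this codebase the samplers wire it up through setup_attention_map_saving()):

import torch

saver = AttentionMapSaver(token_ids=range(1, 6), latents_shape=torch.Size([64, 64]))
# [heads, H*W, tokens]: e.g. 8 heads, a 16x16 attention map, 77 tokens
maps = torch.rand(8, 16 * 16, 77)
saver.add_attention_maps(maps, key='up_blocks.1')
stacked = saver.get_stacked_maps_image()  # greyscale maps for tokens 1..5, stacked vertically
if stacked is not None:
    stacked.save('attention_maps.png')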
111
invokeai/models/diffusion/ddim.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
"""SAMPLING ONLY."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||||
|
from .sampler import Sampler
|
||||||
|
from ldm.modules.diffusionmodules.util import noise_like
|
||||||
|
|
||||||
|
class DDIMSampler(Sampler):
|
||||||
|
def __init__(self, model, schedule='linear', device=None, **kwargs):
|
||||||
|
super().__init__(model,schedule,model.num_timesteps,device)
|
||||||
|
|
||||||
|
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
|
||||||
|
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
|
||||||
|
|
||||||
|
def prepare_to_sample(self, t_enc, **kwargs):
|
||||||
|
super().prepare_to_sample(t_enc, **kwargs)
|
||||||
|
|
||||||
|
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
|
||||||
|
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
|
||||||
|
|
||||||
|
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||||
|
self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = all_timesteps_count)
|
||||||
|
else:
|
||||||
|
self.invokeai_diffuser.restore_default_cross_attention()
|
||||||
|
|
||||||
|
|
||||||
|
# This is the central routine
|
||||||
|
@torch.no_grad()
|
||||||
|
def p_sample(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
c,
|
||||||
|
t,
|
||||||
|
index,
|
||||||
|
repeat_noise=False,
|
||||||
|
use_original_steps=False,
|
||||||
|
quantize_denoised=False,
|
||||||
|
temperature=1.0,
|
||||||
|
noise_dropout=0.0,
|
||||||
|
score_corrector=None,
|
||||||
|
corrector_kwargs=None,
|
||||||
|
unconditional_guidance_scale=1.0,
|
||||||
|
unconditional_conditioning=None,
|
||||||
|
step_count:int=1000, # total number of steps
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
b, *_, device = *x.shape, x.device
|
||||||
|
|
||||||
|
if (
|
||||||
|
unconditional_conditioning is None
|
||||||
|
or unconditional_guidance_scale == 1.0
|
||||||
|
):
|
||||||
|
# damian0815 would like to know when/if this code path is used
|
||||||
|
e_t = self.model.apply_model(x, t, c)
|
||||||
|
else:
|
||||||
|
# step_index counts in the opposite direction to index
|
||||||
|
step_index = step_count-(index+1)
|
||||||
|
e_t = self.invokeai_diffuser.do_diffusion_step(
|
||||||
|
x, t,
|
||||||
|
unconditional_conditioning, c,
|
||||||
|
unconditional_guidance_scale,
|
||||||
|
step_index=step_index
|
||||||
|
)
|
||||||
|
if score_corrector is not None:
|
||||||
|
assert self.model.parameterization == 'eps'
|
||||||
|
e_t = score_corrector.modify_score(
|
||||||
|
self.model, e_t, x, t, c, **corrector_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
alphas = (
|
||||||
|
self.model.alphas_cumprod
|
||||||
|
if use_original_steps
|
||||||
|
else self.ddim_alphas
|
||||||
|
)
|
||||||
|
alphas_prev = (
|
||||||
|
self.model.alphas_cumprod_prev
|
||||||
|
if use_original_steps
|
||||||
|
else self.ddim_alphas_prev
|
||||||
|
)
|
||||||
|
sqrt_one_minus_alphas = (
|
||||||
|
self.model.sqrt_one_minus_alphas_cumprod
|
||||||
|
if use_original_steps
|
||||||
|
else self.ddim_sqrt_one_minus_alphas
|
||||||
|
)
|
||||||
|
sigmas = (
|
||||||
|
self.model.ddim_sigmas_for_original_num_steps
|
||||||
|
if use_original_steps
|
||||||
|
else self.ddim_sigmas
|
||||||
|
)
|
||||||
|
# select parameters corresponding to the currently considered timestep
|
||||||
|
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||||
|
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||||
|
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||||
|
sqrt_one_minus_at = torch.full(
|
||||||
|
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
# current prediction for x_0
|
||||||
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
|
if quantize_denoised:
|
||||||
|
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||||
|
# direction pointing to x_t
|
||||||
|
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
|
||||||
|
noise = (
|
||||||
|
sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||||
|
)
|
||||||
|
if noise_dropout > 0.0:
|
||||||
|
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||||
|
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||||
|
return x_prev, pred_x0, None
|
||||||
|
|
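For reference, the tail of p_sample() is the standard DDIM update. With e_t the predicted noise, a_t and a_prev the cumulative alpha products selected above, and z the temperature-scaled noise term, it computes

    pred_x0 = (x - sqrt(1 - a_t) * e_t) / sqrt(a_t)
    x_prev  = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t**2) * e_t + z

so with ddim_eta = 0 every sigma_t is zero and the update is fully deterministic.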
2271
invokeai/models/diffusion/ddpm.py
Normal file
File diff suppressed because it is too large
312
invokeai/models/diffusion/ksampler.py
Normal file
@@ -0,0 +1,312 @@
|
|||||||
|
"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
|
||||||
|
|
||||||
|
import k_diffusion as K
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from .cross_attention_map_saving import AttentionMapSaver
|
||||||
|
from .sampler import Sampler
|
||||||
|
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||||
|
|
||||||
|
|
||||||
|
# at this threshold, the scheduler will stop using the Karras
|
||||||
|
# noise schedule and start using the model's schedule
|
||||||
|
STEP_THRESHOLD = 30
|
||||||
|
|
||||||
|
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
|
||||||
|
if threshold <= 0.0:
|
||||||
|
return result
|
||||||
|
maxval = 0.0 + torch.max(result).cpu().numpy()
|
||||||
|
minval = 0.0 + torch.min(result).cpu().numpy()
|
||||||
|
if maxval < threshold and minval > -threshold:
|
||||||
|
return result
|
||||||
|
if maxval > threshold:
|
||||||
|
maxval = min(max(1, scale*maxval), threshold)
|
||||||
|
if minval < -threshold:
|
||||||
|
minval = max(min(-1, scale*minval), -threshold)
|
||||||
|
return torch.clamp(result, min=minval, max=maxval)
|
||||||
|
|
||||||
|
|
||||||
|
class CFGDenoiser(nn.Module):
|
||||||
|
def __init__(self, model, threshold = 0, warmup = 0):
|
||||||
|
super().__init__()
|
||||||
|
self.inner_model = model
|
||||||
|
self.threshold = threshold
|
||||||
|
self.warmup_max = warmup
|
||||||
|
self.warmup = max(warmup / 10, 1)
|
||||||
|
self.invokeai_diffuser = InvokeAIDiffuserComponent(model,
|
||||||
|
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_to_sample(self, t_enc, **kwargs):
|
||||||
|
|
||||||
|
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
|
||||||
|
|
||||||
|
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||||
|
self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = t_enc)
|
||||||
|
else:
|
||||||
|
self.invokeai_diffuser.restore_default_cross_attention()
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||||
|
next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
|
||||||
|
if self.warmup < self.warmup_max:
|
||||||
|
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
|
||||||
|
self.warmup += 1
|
||||||
|
else:
|
||||||
|
thresh = self.threshold
|
||||||
|
if thresh > self.threshold:
|
||||||
|
thresh = self.threshold
|
||||||
|
return cfg_apply_threshold(next_x, thresh)
|
||||||
|
|
||||||
|
class KSampler(Sampler):
|
||||||
|
def __init__(self, model, schedule='lms', device=None, **kwargs):
|
||||||
|
denoiser = K.external.CompVisDenoiser(model)
|
||||||
|
super().__init__(
|
||||||
|
denoiser,
|
||||||
|
schedule,
|
||||||
|
steps=model.num_timesteps,
|
||||||
|
)
|
||||||
|
self.sigmas = None
|
||||||
|
self.ds = None
|
||||||
|
self.s_in = None
|
||||||
|
self.karras_max = kwargs.get('karras_max',STEP_THRESHOLD)
|
||||||
|
if self.karras_max is None:
|
||||||
|
self.karras_max = STEP_THRESHOLD
|
||||||
|
|
||||||
|
def make_schedule(
|
||||||
|
self,
|
||||||
|
ddim_num_steps,
|
||||||
|
ddim_discretize='uniform',
|
||||||
|
ddim_eta=0.0,
|
||||||
|
verbose=False,
|
||||||
|
):
|
||||||
|
outer_model = self.model
|
||||||
|
self.model = outer_model.inner_model
|
||||||
|
super().make_schedule(
|
||||||
|
ddim_num_steps,
|
||||||
|
ddim_discretize='uniform',
|
||||||
|
ddim_eta=0.0,
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
self.model = outer_model
|
||||||
|
self.ddim_num_steps = ddim_num_steps
|
||||||
|
# we don't need both of these sigmas, but storing them here to make
|
||||||
|
# comparison easier later on
|
||||||
|
self.model_sigmas = self.model.get_sigmas(ddim_num_steps)
|
||||||
|
self.karras_sigmas = K.sampling.get_sigmas_karras(
|
||||||
|
n=ddim_num_steps,
|
||||||
|
sigma_min=self.model.sigmas[0].item(),
|
||||||
|
sigma_max=self.model.sigmas[-1].item(),
|
||||||
|
rho=7.,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
if ddim_num_steps >= self.karras_max:
|
||||||
|
print(f'>> Ksampler using model noise schedule (steps >= {self.karras_max})')
|
||||||
|
self.sigmas = self.model_sigmas
|
||||||
|
else:
|
||||||
|
print(f'>> Ksampler using karras noise schedule (steps < {self.karras_max})')
|
||||||
|
self.sigmas = self.karras_sigmas
|
||||||
|
|
||||||
|
# ALERT: We are completely overriding the sample() method in the base class, which
|
||||||
|
# means that inpainting will not work. To get this to work we need to be able to
|
||||||
|
# modify the inner loop of k_heun, k_lms, etc, as is done in an ugly way
|
||||||
|
# in the lstein/k-diffusion branch.
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def decode(
|
||||||
|
self,
|
||||||
|
z_enc,
|
||||||
|
cond,
|
||||||
|
t_enc,
|
||||||
|
img_callback=None,
|
||||||
|
unconditional_guidance_scale=1.0,
|
||||||
|
unconditional_conditioning=None,
|
||||||
|
use_original_steps=False,
|
||||||
|
init_latent = None,
|
||||||
|
mask = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
samples,_ = self.sample(
|
||||||
|
batch_size = 1,
|
||||||
|
S = t_enc,
|
||||||
|
x_T = z_enc,
|
||||||
|
shape = z_enc.shape[1:],
|
||||||
|
conditioning = cond,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning = unconditional_conditioning,
|
||||||
|
img_callback = img_callback,
|
||||||
|
x0 = init_latent,
|
||||||
|
mask = mask,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return samples
|
||||||
|
|
||||||
|
# this is a no-op, provided here for compatibility with ddim and plms samplers
|
||||||
|
@torch.no_grad()
|
||||||
|
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||||
|
return x0
|
||||||
|
|
||||||
|
# Most of these arguments are ignored and are only present for compatibility with
|
||||||
|
# other samplers
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample(
|
||||||
|
self,
|
||||||
|
S,
|
||||||
|
batch_size,
|
||||||
|
shape,
|
||||||
|
conditioning=None,
|
||||||
|
callback=None,
|
||||||
|
normals_sequence=None,
|
||||||
|
img_callback=None,
|
||||||
|
attention_maps_callback=None,
|
||||||
|
quantize_x0=False,
|
||||||
|
eta=0.0,
|
||||||
|
mask=None,
|
||||||
|
x0=None,
|
||||||
|
temperature=1.0,
|
||||||
|
noise_dropout=0.0,
|
||||||
|
score_corrector=None,
|
||||||
|
corrector_kwargs=None,
|
||||||
|
verbose=True,
|
||||||
|
x_T=None,
|
||||||
|
log_every_t=100,
|
||||||
|
unconditional_guidance_scale=1.0,
|
||||||
|
unconditional_conditioning=None,
|
||||||
|
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
|
||||||
|
threshold = 0,
|
||||||
|
perlin = 0,
|
||||||
|
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
def route_callback(k_callback_values):
|
||||||
|
if img_callback is not None:
|
||||||
|
img_callback(k_callback_values['x'],k_callback_values['i'])
|
||||||
|
|
||||||
|
# if make_schedule() hasn't been called, we do it now
|
||||||
|
if self.sigmas is None:
|
||||||
|
self.make_schedule(
|
||||||
|
ddim_num_steps=S,
|
||||||
|
ddim_eta = eta,
|
||||||
|
verbose = False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# sigmas are set up in make_schedule - we take the last steps items
|
||||||
|
sigmas = self.sigmas[-S-1:]
|
||||||
|
|
||||||
|
# x_T is variation noise. When an init image is provided (in x0) we need to add
|
||||||
|
# more randomness to the starting image.
|
||||||
|
if x_T is not None:
|
||||||
|
if x0 is not None:
|
||||||
|
x = x_T + torch.randn_like(x0, device=self.device) * sigmas[0]
|
||||||
|
else:
|
||||||
|
x = x_T * sigmas[0]
|
||||||
|
else:
|
||||||
|
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
|
||||||
|
|
||||||
|
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
|
||||||
|
model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
|
||||||
|
|
||||||
|
# setup attention maps saving. checks for None are because there are multiple code paths to get here.
|
||||||
|
attention_map_saver = None
|
||||||
|
if attention_maps_callback is not None and extra_conditioning_info is not None:
|
||||||
|
eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
|
||||||
|
attention_map_token_ids = range(1, eos_token_index)
|
||||||
|
attention_map_saver = AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
|
||||||
|
model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
|
||||||
|
|
||||||
|
extra_args = {
|
||||||
|
'cond': conditioning,
|
||||||
|
'uncond': unconditional_conditioning,
|
||||||
|
'cond_scale': unconditional_guidance_scale,
|
||||||
|
}
|
||||||
|
print(f'>> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)')
|
||||||
|
sampling_result = (
|
||||||
|
K.sampling.__dict__[f'sample_{self.schedule}'](
|
||||||
|
model_wrap_cfg, x, sigmas, extra_args=extra_args,
|
||||||
|
callback=route_callback
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if attention_map_saver is not None:
|
||||||
|
attention_maps_callback(attention_map_saver)
|
||||||
|
return sampling_result
|
||||||
|
|
||||||
|
# this code will support inpainting if and when ksampler API modified or
|
||||||
|
# a workaround is found.
|
||||||
|
@torch.no_grad()
|
||||||
|
def p_sample(
|
||||||
|
self,
|
||||||
|
img,
|
||||||
|
cond,
|
||||||
|
ts,
|
||||||
|
index,
|
||||||
|
unconditional_guidance_scale=1.0,
|
||||||
|
unconditional_conditioning=None,
|
||||||
|
extra_conditioning_info=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if self.model_wrap is None:
|
||||||
|
self.model_wrap = CFGDenoiser(self.model)
|
||||||
|
extra_args = {
|
||||||
|
'cond': cond,
|
||||||
|
'uncond': unconditional_conditioning,
|
||||||
|
'cond_scale': unconditional_guidance_scale,
|
||||||
|
}
|
||||||
|
if self.s_in is None:
|
||||||
|
self.s_in = img.new_ones([img.shape[0]])
|
||||||
|
if self.ds is None:
|
||||||
|
self.ds = []
|
||||||
|
|
||||||
|
# terrible, confusing names here
|
||||||
|
steps = self.ddim_num_steps
|
||||||
|
t_enc = self.t_enc
|
||||||
|
|
||||||
|
# sigmas covers the full number of steps, but t_enc might
|
||||||
|
# be less. We start in the middle of the sigma array
|
||||||
|
# and work our way to the end after t_enc steps.
|
||||||
|
# index starts at t_enc and works its way to zero,
|
||||||
|
# so the actual formula for indexing into sigmas:
|
||||||
|
# sigma_index = (steps-index)
|
||||||
|
s_index = t_enc - index - 1
|
||||||
|
self.model_wrap.prepare_to_sample(s_index, extra_conditioning_info=extra_conditioning_info)
|
||||||
|
img = K.sampling.__dict__[f'_{self.schedule}'](
|
||||||
|
self.model_wrap,
|
||||||
|
img,
|
||||||
|
self.sigmas,
|
||||||
|
s_index,
|
||||||
|
s_in = self.s_in,
|
||||||
|
ds = self.ds,
|
||||||
|
extra_args=extra_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
return img, None, None
|
||||||
|
|
||||||
|
# REVIEW THIS METHOD: it has never been tested. In particular,
|
||||||
|
# we should not be multiplying by self.sigmas[0] if we
|
||||||
|
# are at an intermediate step in img2img. See similar in
|
||||||
|
# sample() which does work.
|
||||||
|
def get_initial_image(self,x_T,shape,steps):
|
||||||
|
print(f'WARNING: ksampler.get_initial_image(): get_initial_image needs testing')
|
||||||
|
x = (torch.randn(shape, device=self.device) * self.sigmas[0])
|
||||||
|
if x_T is not None:
|
||||||
|
return x_T + x
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
def prepare_to_sample(self,t_enc,**kwargs):
|
||||||
|
self.t_enc = t_enc
|
||||||
|
self.model_wrap = None
|
||||||
|
self.ds = None
|
||||||
|
self.s_in = None
|
||||||
|
|
||||||
|
def q_sample(self,x0,ts):
|
||||||
|
'''
|
||||||
|
Overrides parent method to return the q_sample of the inner model.
|
||||||
|
'''
|
||||||
|
return self.model.inner_model.q_sample(x0,ts)
|
||||||
|
|
||||||
|
def conditioning_key(self)->str:
|
||||||
|
return self.model.inner_model.model.conditioning_key
|
||||||
|
|
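To make the thresholding at the top of this file concrete, a small worked example of cfg_apply_threshold() (the values are arbitrary):

import torch

latents = torch.tensor([-3.0, -0.5, 0.5, 4.0])
clamped = cfg_apply_threshold(latents, threshold=2.0, scale=0.7)
# maxval =  4.0 -> min(max(1, 0.7 *  4.0),  2.0) =  2.0
# minval = -3.0 -> max(min(-1, 0.7 * -3.0), -2.0) = -2.0
# so the result is tensor([-2.0, -0.5, 0.5, 2.0])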
146
invokeai/models/diffusion/plms.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
"""SAMPLING ONLY."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
from functools import partial
|
||||||
|
from ldm.invoke.devices import choose_torch_device
|
||||||
|
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||||
|
from .sampler import Sampler
|
||||||
|
from ldm.modules.diffusionmodules.util import noise_like
|
||||||
|
|
||||||
|
|
||||||
|
class PLMSSampler(Sampler):
|
||||||
|
def __init__(self, model, schedule='linear', device=None, **kwargs):
|
||||||
|
super().__init__(model,schedule,model.num_timesteps, device)
|
||||||
|
|
||||||
|
def prepare_to_sample(self, t_enc, **kwargs):
|
||||||
|
super().prepare_to_sample(t_enc, **kwargs)
|
||||||
|
|
||||||
|
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
|
||||||
|
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
|
||||||
|
|
||||||
|
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||||
|
self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = all_timesteps_count)
|
||||||
|
else:
|
||||||
|
self.invokeai_diffuser.restore_default_cross_attention()
|
||||||
|
|
||||||
|
|
||||||
|
# this is the essential routine
|
||||||
|
@torch.no_grad()
|
||||||
|
def p_sample(
|
||||||
|
self,
|
||||||
|
x, # image, called 'img' elsewhere
|
||||||
|
c, # conditioning, called 'cond' elsewhere
|
||||||
|
t, # timesteps, called 'ts' elsewhere
|
||||||
|
index,
|
||||||
|
repeat_noise=False,
|
||||||
|
use_original_steps=False,
|
||||||
|
quantize_denoised=False,
|
||||||
|
temperature=1.0,
|
||||||
|
noise_dropout=0.0,
|
||||||
|
score_corrector=None,
|
||||||
|
corrector_kwargs=None,
|
||||||
|
unconditional_guidance_scale=1.0,
|
||||||
|
unconditional_conditioning=None,
|
||||||
|
old_eps=[],
|
||||||
|
t_next=None,
|
||||||
|
step_count:int=1000, # total number of steps
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
b, *_, device = *x.shape, x.device
|
||||||
|
|
||||||
|
def get_model_output(x, t):
|
||||||
|
if (
|
||||||
|
unconditional_conditioning is None
|
||||||
|
or unconditional_guidance_scale == 1.0
|
||||||
|
):
|
||||||
|
# damian0815 would like to know when/if this code path is used
|
||||||
|
e_t = self.model.apply_model(x, t, c)
|
||||||
|
else:
|
||||||
|
# step_index counts in the opposite direction to index
|
||||||
|
step_index = step_count-(index+1)
|
||||||
|
e_t = self.invokeai_diffuser.do_diffusion_step(x, t,
|
||||||
|
unconditional_conditioning, c,
|
||||||
|
unconditional_guidance_scale,
|
||||||
|
step_index=step_index)
|
||||||
|
if score_corrector is not None:
|
||||||
|
assert self.model.parameterization == 'eps'
|
||||||
|
e_t = score_corrector.modify_score(
|
||||||
|
self.model, e_t, x, t, c, **corrector_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return e_t
|
||||||
|
|
||||||
|
alphas = (
|
||||||
|
self.model.alphas_cumprod
|
||||||
|
if use_original_steps
|
||||||
|
else self.ddim_alphas
|
||||||
|
)
|
||||||
|
alphas_prev = (
|
||||||
|
self.model.alphas_cumprod_prev
|
||||||
|
if use_original_steps
|
||||||
|
else self.ddim_alphas_prev
|
||||||
|
)
|
||||||
|
sqrt_one_minus_alphas = (
|
||||||
|
self.model.sqrt_one_minus_alphas_cumprod
|
||||||
|
if use_original_steps
|
||||||
|
else self.ddim_sqrt_one_minus_alphas
|
||||||
|
)
|
||||||
|
sigmas = (
|
||||||
|
self.model.ddim_sigmas_for_original_num_steps
|
||||||
|
if use_original_steps
|
||||||
|
else self.ddim_sigmas
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_x_prev_and_pred_x0(e_t, index):
|
||||||
|
# select parameters corresponding to the currently considered timestep
|
||||||
|
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||||
|
a_prev = torch.full(
|
||||||
|
(b, 1, 1, 1), alphas_prev[index], device=device
|
||||||
|
)
|
||||||
|
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||||
|
sqrt_one_minus_at = torch.full(
|
||||||
|
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
# current prediction for x_0
|
||||||
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
|
if quantize_denoised:
|
||||||
|
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||||
|
# direction pointing to x_t
|
||||||
|
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
|
||||||
|
noise = (
|
||||||
|
sigma_t
|
||||||
|
* noise_like(x.shape, device, repeat_noise)
|
||||||
|
* temperature
|
||||||
|
)
|
||||||
|
if noise_dropout > 0.0:
|
||||||
|
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||||
|
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||||
|
return x_prev, pred_x0
|
||||||
|
|
||||||
|
e_t = get_model_output(x, t)
|
||||||
|
if len(old_eps) == 0:
|
||||||
|
# Pseudo Improved Euler (2nd order)
|
||||||
|
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||||
|
e_t_next = get_model_output(x_prev, t_next)
|
||||||
|
e_t_prime = (e_t + e_t_next) / 2
|
||||||
|
elif len(old_eps) == 1:
|
||||||
|
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||||
|
elif len(old_eps) == 2:
|
||||||
|
# 3rd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||||
|
elif len(old_eps) >= 3:
|
||||||
|
# 4th order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (
|
||||||
|
55 * e_t
|
||||||
|
- 59 * old_eps[-1]
|
||||||
|
+ 37 * old_eps[-2]
|
||||||
|
- 9 * old_eps[-3]
|
||||||
|
) / 24
|
||||||
|
|
||||||
|
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||||
|
|
||||||
|
return x_prev, pred_x0, e_t
|
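A hedged sketch of how PLMSSampler (or DDIMSampler) is typically driven; `model`, `cond` and `uncond` are placeholders for a loaded LatentDiffusion model and its encoded prompts, and the step count and latent shape are only examples. The heavy lifting happens in the shared Sampler.sample()/do_sampling() defined in the next file:

sampler = PLMSSampler(model)
samples, _ = sampler.sample(
    S=30,                        # number of PLMS steps
    batch_size=1,
    shape=(4, 64, 64),           # latent channels, height, width
    conditioning=cond,
    unconditional_conditioning=uncond,
    unconditional_guidance_scale=7.5,
)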
450
invokeai/models/diffusion/sampler.py
Normal file
@@ -0,0 +1,450 @@
|
|||||||
|
'''
|
||||||
|
invokeai.models.diffusion.sampler
|
||||||
|
|
||||||
|
Base class for invokeai.models.diffusion.ddim, invokeai.models.diffusion.ksampler, etc
|
||||||
|
'''
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
from functools import partial
|
||||||
|
from ldm.invoke.devices import choose_torch_device
|
||||||
|
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||||
|
|
||||||
|
from ldm.modules.diffusionmodules.util import (
|
||||||
|
make_ddim_sampling_parameters,
|
||||||
|
make_ddim_timesteps,
|
||||||
|
noise_like,
|
||||||
|
extract_into_tensor,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Sampler(object):
|
||||||
|
def __init__(self, model, schedule='linear', steps=None, device=None, **kwargs):
|
||||||
|
self.model = model
|
||||||
|
self.ddim_timesteps = None
|
||||||
|
self.ddpm_num_timesteps = steps
|
||||||
|
self.schedule = schedule
|
||||||
|
self.device = device or choose_torch_device()
|
||||||
|
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
|
||||||
|
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
|
||||||
|
|
||||||
|
def register_buffer(self, name, attr):
|
||||||
|
if type(attr) == torch.Tensor:
|
||||||
|
if attr.device != torch.device(self.device):
|
||||||
|
attr = attr.to(torch.float32).to(torch.device(self.device))
|
||||||
|
setattr(self, name, attr)
|
||||||
|
|
||||||
|
# This method was copied over from ddim.py and probably does stuff that is
|
||||||
|
# ddim-specific. Disentangle at some point.
|
||||||
|
def make_schedule(
|
||||||
|
self,
|
||||||
|
ddim_num_steps,
|
||||||
|
ddim_discretize='uniform',
|
||||||
|
ddim_eta=0.0,
|
||||||
|
verbose=False,
|
||||||
|
):
|
||||||
|
self.total_steps = ddim_num_steps
|
||||||
|
self.ddim_timesteps = make_ddim_timesteps(
|
||||||
|
ddim_discr_method=ddim_discretize,
|
||||||
|
num_ddim_timesteps=ddim_num_steps,
|
||||||
|
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
||||||
|
verbose=verbose,
|
||||||
|
)
|
||||||
|
alphas_cumprod = self.model.alphas_cumprod
|
||||||
|
assert (
|
||||||
|
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
||||||
|
), 'alphas have to be defined for each timestep'
|
||||||
|
to_torch = (
|
||||||
|
lambda x: x.clone()
|
||||||
|
.detach()
|
||||||
|
.to(torch.float32)
|
||||||
|
.to(self.model.device)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.register_buffer('betas', to_torch(self.model.betas))
|
||||||
|
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||||
|
self.register_buffer(
|
||||||
|
'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)
|
||||||
|
)
|
||||||
|
|
||||||
|
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||||
|
self.register_buffer(
|
||||||
|
'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
'sqrt_one_minus_alphas_cumprod',
|
||||||
|
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
'log_one_minus_alphas_cumprod',
|
||||||
|
to_torch(np.log(1.0 - alphas_cumprod.cpu())),
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
'sqrt_recip_alphas_cumprod',
|
||||||
|
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
'sqrt_recipm1_alphas_cumprod',
|
||||||
|
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ddim sampling parameters
|
||||||
|
(
|
||||||
|
ddim_sigmas,
|
||||||
|
ddim_alphas,
|
||||||
|
ddim_alphas_prev,
|
||||||
|
) = make_ddim_sampling_parameters(
|
||||||
|
alphacums=alphas_cumprod.cpu(),
|
||||||
|
ddim_timesteps=self.ddim_timesteps,
|
||||||
|
eta=ddim_eta,
|
||||||
|
verbose=verbose,
|
||||||
|
)
|
||||||
|
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||||
|
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||||
|
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||||
|
self.register_buffer(
|
||||||
|
'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
|
||||||
|
)
|
||||||
|
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||||
|
(1 - self.alphas_cumprod_prev)
|
||||||
|
/ (1 - self.alphas_cumprod)
|
||||||
|
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
'ddim_sigmas_for_original_num_steps',
|
||||||
|
sigmas_for_original_sampling_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||||
|
# fast, but does not allow for exact reconstruction
|
||||||
|
# t serves as an index to gather the correct alphas
|
||||||
|
if use_original_steps:
|
||||||
|
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||||
|
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
||||||
|
else:
|
||||||
|
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||||
|
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
||||||
|
|
||||||
|
if noise is None:
|
||||||
|
noise = torch.randn_like(x0)
|
||||||
|
return (
|
||||||
|
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
|
||||||
|
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
|
||||||
|
* noise
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample(
|
||||||
|
self,
|
||||||
|
S, # S is steps
|
||||||
|
batch_size,
|
||||||
|
shape,
|
||||||
|
conditioning=None,
|
||||||
|
callback=None,
|
||||||
|
normals_sequence=None,
|
||||||
|
img_callback=None, # TODO: this is very confusing because it is called "step_callback" elsewhere. Change.
|
||||||
|
quantize_x0=False,
|
||||||
|
eta=0.0,
|
||||||
|
mask=None,
|
||||||
|
x0=None,
|
||||||
|
temperature=1.0,
|
||||||
|
noise_dropout=0.0,
|
||||||
|
score_corrector=None,
|
||||||
|
corrector_kwargs=None,
|
||||||
|
verbose=False,
|
||||||
|
x_T=None,
|
||||||
|
log_every_t=100,
|
||||||
|
unconditional_guidance_scale=1.0,
|
||||||
|
unconditional_conditioning=None,
|
||||||
|
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
|
||||||
|
if conditioning is not None:
|
||||||
|
if isinstance(conditioning, dict):
|
||||||
|
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||||
|
while isinstance(ctmp, list):
|
||||||
|
ctmp = ctmp[0]
|
||||||
|
cbs = ctmp.shape[0]
|
||||||
|
if cbs != batch_size:
|
||||||
|
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||||
|
else:
|
||||||
|
if conditioning.shape[0] != batch_size:
|
||||||
|
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||||
|
|
||||||
|
# check to see if make_schedule() has run, and if not, run it
|
||||||
|
if self.ddim_timesteps is None:
|
||||||
|
self.make_schedule(
|
||||||
|
ddim_num_steps=S,
|
||||||
|
ddim_eta = eta,
|
||||||
|
verbose = False,
|
||||||
|
)
|
||||||
|
|
||||||
|
ts = self.get_timesteps(S)
|
||||||
|
|
||||||
|
# sampling
|
||||||
|
C, H, W = shape
|
||||||
|
shape = (batch_size, C, H, W)
|
||||||
|
samples, intermediates = self.do_sampling(
|
||||||
|
conditioning,
|
||||||
|
shape,
|
||||||
|
timesteps=ts,
|
||||||
|
callback=callback,
|
||||||
|
img_callback=img_callback,
|
||||||
|
quantize_denoised=quantize_x0,
|
||||||
|
mask=mask,
|
||||||
|
x0=x0,
|
||||||
|
ddim_use_original_steps=False,
|
||||||
|
noise_dropout=noise_dropout,
|
||||||
|
temperature=temperature,
|
||||||
|
score_corrector=score_corrector,
|
||||||
|
corrector_kwargs=corrector_kwargs,
|
||||||
|
x_T=x_T,
|
||||||
|
log_every_t=log_every_t,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
steps=S,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return samples, intermediates
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def do_sampling(
|
||||||
|
self,
|
||||||
|
cond,
|
||||||
|
shape,
|
||||||
|
timesteps=None,
|
||||||
|
x_T=None,
|
||||||
|
ddim_use_original_steps=False,
|
||||||
|
callback=None,
|
||||||
|
quantize_denoised=False,
|
||||||
|
mask=None,
|
||||||
|
x0=None,
|
||||||
|
img_callback=None,
|
||||||
|
log_every_t=100,
|
||||||
|
temperature=1.0,
|
||||||
|
noise_dropout=0.0,
|
||||||
|
score_corrector=None,
|
||||||
|
corrector_kwargs=None,
|
||||||
|
unconditional_guidance_scale=1.0,
|
||||||
|
unconditional_conditioning=None,
|
||||||
|
steps=None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
b = shape[0]
|
||||||
|
time_range = (
|
||||||
|
list(reversed(range(0, timesteps)))
|
||||||
|
if ddim_use_original_steps
|
||||||
|
else np.flip(timesteps)
|
||||||
|
)
|
||||||
|
|
||||||
|
total_steps=steps
|
||||||
|
|
||||||
|
iterator = tqdm(
|
||||||
|
time_range,
|
||||||
|
desc=f'{self.__class__.__name__}',
|
||||||
|
total=total_steps,
|
||||||
|
dynamic_ncols=True,
|
||||||
|
)
|
||||||
|
old_eps = []
|
||||||
|
self.prepare_to_sample(t_enc=total_steps,all_timesteps_count=steps,**kwargs)
|
||||||
|
img = self.get_initial_image(x_T,shape,total_steps)
|
||||||
|
|
||||||
|
# probably don't need this at all
|
||||||
|
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||||
|
|
||||||
|
for i, step in enumerate(iterator):
|
||||||
|
index = total_steps - i - 1
|
||||||
|
ts = torch.full(
|
||||||
|
(b,),
|
||||||
|
step,
|
||||||
|
device=self.device,
|
||||||
|
dtype=torch.long
|
||||||
|
)
|
||||||
|
ts_next = torch.full(
|
||||||
|
(b,),
|
||||||
|
time_range[min(i + 1, len(time_range) - 1)],
|
||||||
|
device=self.device,
|
||||||
|
dtype=torch.long,
|
||||||
|
)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
assert x0 is not None
|
||||||
|
img_orig = self.model.q_sample(
|
||||||
|
x0, ts
|
||||||
|
) # TODO: deterministic forward pass?
|
||||||
|
img = img_orig * mask + (1.0 - mask) * img
|
||||||
|
|
||||||
|
outs = self.p_sample(
|
||||||
|
img,
|
||||||
|
cond,
|
||||||
|
ts,
|
||||||
|
index=index,
|
||||||
|
use_original_steps=ddim_use_original_steps,
|
||||||
|
quantize_denoised=quantize_denoised,
|
||||||
|
temperature=temperature,
|
||||||
|
noise_dropout=noise_dropout,
|
||||||
|
score_corrector=score_corrector,
|
||||||
|
corrector_kwargs=corrector_kwargs,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
old_eps=old_eps,
|
||||||
|
t_next=ts_next,
|
||||||
|
step_count=steps
|
||||||
|
)
|
||||||
|
img, pred_x0, e_t = outs
|
||||||
|
|
||||||
|
old_eps.append(e_t)
|
||||||
|
if len(old_eps) >= 4:
|
||||||
|
old_eps.pop(0)
|
||||||
|
if callback:
|
||||||
|
callback(i)
|
||||||
|
if img_callback:
|
||||||
|
img_callback(img,i)
|
||||||
|
|
||||||
|
if index % log_every_t == 0 or index == total_steps - 1:
|
||||||
|
intermediates['x_inter'].append(img)
|
||||||
|
intermediates['pred_x0'].append(pred_x0)
|
||||||
|
|
||||||
|
return img, intermediates
|
||||||
|
|
||||||
|
# NOTE that decode() and sample() are almost the same code, and do the same thing.
|
||||||
|
# The variable names are changed in order to be confusing.
|
||||||
|
@torch.no_grad()
|
||||||
|
def decode(
|
||||||
|
self,
|
||||||
|
x_latent,
|
||||||
|
cond,
|
||||||
|
t_start,
|
||||||
|
img_callback=None,
|
||||||
|
unconditional_guidance_scale=1.0,
|
||||||
|
unconditional_conditioning=None,
|
||||||
|
use_original_steps=False,
|
||||||
|
init_latent = None,
|
||||||
|
mask = None,
|
||||||
|
all_timesteps_count = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
timesteps = (
|
||||||
|
np.arange(self.ddpm_num_timesteps)
|
||||||
|
if use_original_steps
|
||||||
|
else self.ddim_timesteps
|
||||||
|
)
|
||||||
|
timesteps = timesteps[:t_start]
|
||||||
|
|
||||||
|
time_range = np.flip(timesteps)
|
||||||
|
total_steps = timesteps.shape[0]
|
||||||
|
print(f'>> Running {self.__class__.__name__} sampling starting at step {self.total_steps - t_start} of {self.total_steps} ({total_steps} new sampling steps)')
|
||||||
|
|
||||||
|
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||||
|
x_dec = x_latent
|
||||||
|
x0 = init_latent
|
||||||
|
self.prepare_to_sample(t_enc=total_steps, all_timesteps_count=all_timesteps_count, **kwargs)
|
||||||
|
|
||||||
|
for i, step in enumerate(iterator):
|
||||||
|
index = total_steps - i - 1
|
||||||
|
ts = torch.full(
|
||||||
|
(x_latent.shape[0],),
|
||||||
|
step,
|
||||||
|
device=x_latent.device,
|
||||||
|
dtype=torch.long,
|
||||||
|
)
|
||||||
|
|
||||||
|
ts_next = torch.full(
|
||||||
|
(x_latent.shape[0],),
|
||||||
|
time_range[min(i + 1, len(time_range) - 1)],
|
||||||
|
device=self.device,
|
||||||
|
dtype=torch.long,
|
||||||
|
)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
assert x0 is not None
|
||||||
|
xdec_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||||
|
x_dec = xdec_orig * mask + (1.0 - mask) * x_dec
|
||||||
|
|
||||||
|
outs = self.p_sample(
|
||||||
|
x_dec,
|
||||||
|
cond,
|
||||||
|
ts,
|
||||||
|
index=index,
|
||||||
|
use_original_steps=use_original_steps,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
t_next = ts_next,
|
||||||
|
step_count=len(self.ddim_timesteps)
|
||||||
|
)
|
||||||
|
|
||||||
|
x_dec, pred_x0, e_t = outs
|
||||||
|
if img_callback:
|
||||||
|
img_callback(x_dec,i)
|
||||||
|
|
||||||
|
return x_dec
|
||||||
|
|
||||||
|
def get_initial_image(self,x_T,shape,timesteps=None):
|
||||||
|
if x_T is None:
|
||||||
|
return torch.randn(shape, device=self.device)
|
||||||
|
else:
|
||||||
|
return x_T
|
||||||
|
|
||||||
|
def p_sample(
|
||||||
|
self,
|
||||||
|
img,
|
||||||
|
cond,
|
||||||
|
ts,
|
||||||
|
index,
|
||||||
|
repeat_noise=False,
|
||||||
|
use_original_steps=False,
|
||||||
|
quantize_denoised=False,
|
||||||
|
temperature=1.0,
|
||||||
|
noise_dropout=0.0,
|
||||||
|
score_corrector=None,
|
||||||
|
corrector_kwargs=None,
|
||||||
|
unconditional_guidance_scale=1.0,
|
||||||
|
unconditional_conditioning=None,
|
||||||
|
old_eps=None,
|
||||||
|
t_next=None,
|
||||||
|
steps=None,
|
||||||
|
):
|
||||||
|
raise NotImplementedError("p_sample() must be implemented in a descendent class")
|
||||||
|
|
||||||
|
def prepare_to_sample(self,t_enc,**kwargs):
|
||||||
|
'''
|
||||||
|
Hook that will be called right before the very first invocation of p_sample()
|
||||||
|
to allow subclass to do additional initialization. t_enc corresponds to the actual
|
||||||
|
number of steps that will be run, and may be less than total steps if img2img is
|
||||||
|
active.
|
||||||
|
'''
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_timesteps(self,ddim_steps):
|
||||||
|
'''
|
||||||
|
The ddim and plms samplers work on timesteps. This method is called after
|
||||||
|
ddim_timesteps are created in make_schedule(), and selects the portion of
|
||||||
|
timesteps that will be used for sampling, depending on the t_enc in img2img.
|
||||||
|
'''
|
||||||
|
return self.ddim_timesteps[:ddim_steps]
|
||||||
|
|
||||||
|
def q_sample(self,x0,ts):
|
||||||
|
'''
|
||||||
|
Returns self.model.q_sample(x0,ts). Is overridden in the k* samplers to
|
||||||
|
return self.model.inner_model.q_sample(x0,ts)
|
||||||
|
'''
|
||||||
|
return self.model.q_sample(x0,ts)
|
||||||
|
|
||||||
|
def conditioning_key(self)->str:
|
||||||
|
return self.model.model.conditioning_key
|
||||||
|
|
||||||
|
def uses_inpainting_model(self)->bool:
|
||||||
|
return self.conditioning_key() in ('hybrid','concat')
|
||||||
|
|
||||||
|
def adjust_settings(self,**kwargs):
|
||||||
|
'''
|
||||||
|
This is a catch-all method for adjusting any instance variables
|
||||||
|
after the sampler is instantiated. No type-checking performed
|
||||||
|
here, so use with care!
|
||||||
|
'''
|
||||||
|
for k in kwargs.keys():
|
||||||
|
try:
|
||||||
|
setattr(self,k,kwargs[k])
|
||||||
|
except AttributeError:
|
||||||
|
print(f'** Warning: attempt to set unknown attribute {k} in sampler of type {type(self)}')
|
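The base class leaves p_sample() abstract; do_sampling() only requires that it hand back (img, pred_x0, e_t) each step. A deliberately trivial, purely hypothetical subclass to show that contract:

import torch

class NoOpSampler(Sampler):
    # performs no denoising at all - returns its input unchanged,
    # with a zero "noise prediction" so the old_eps bookkeeping still works
    def p_sample(self, img, cond, ts, index, **kwargs):
        return img, img, torch.zeros_like(img)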
491
invokeai/models/diffusion/shared_invokeai_diffusion.py
Normal file
@@ -0,0 +1,491 @@
from contextlib import contextmanager
from dataclasses import dataclass
from math import ceil
from typing import Callable, Optional, Union, Any, Dict

import numpy as np
import torch
from diffusers.models.cross_attention import AttnProcessor
from typing_extensions import TypeAlias

from ldm.invoke.globals import Globals
from .cross_attention_control import Arguments, \
    restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
    CrossAttentionType, SwapCrossAttnContext
from .cross_attention_map_saving import AttentionMapSaver

ModelForwardCallback: TypeAlias = Union[
    # x, t, conditioning, Optional[cross-attention kwargs]
    Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]], torch.Tensor],
    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
]

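# A minimal sketch of a callable satisfying ModelForwardCallback (illustrative only; the
# compvis-style apply_model() call is an assumption, and InvokeAI's generators supply their
# own callbacks rather than this one):
#
#     def make_forward_callback(model) -> ModelForwardCallback:
#         def forward(x: torch.Tensor, t: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
#             return model.apply_model(x, t, conditioning)
#         return forward
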
@dataclass(frozen=True)
class PostprocessingSettings:
    threshold: float
    warmup: float
    h_symmetry_time_pct: Optional[float]
    v_symmetry_time_pct: Optional[float]

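# Construction sketch (illustrative values, not defaults from this diff): thresholding is
# disabled with threshold=0.0, and each symmetry is applied once when sampling first crosses
# the given fraction of total steps.
#
#     settings = PostprocessingSettings(threshold=0.0, warmup=0.2,
#                                       h_symmetry_time_pct=0.25, v_symmetry_time_pct=None)
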
class InvokeAIDiffuserComponent:
    '''
    The aim of this component is to provide a single place for code that can be applied identically to
    all InvokeAI diffusion procedures.

    At the moment it includes the following features:
    * Cross attention control ("prompt2prompt")
    * Hybrid conditioning (used for inpainting)
    '''
    debug_thresholding = False
    sequential_guidance = False

    @dataclass
    class ExtraConditioningInfo:
        tokens_count_including_eos_bos: int
        cross_attention_control_args: Optional[Arguments] = None

        @property
        def wants_cross_attention_control(self):
            return self.cross_attention_control_args is not None

    def __init__(self, model, model_forward_callback: ModelForwardCallback,
                 is_running_diffusers: bool = False,
                 ):
        """
        :param model: the unet model to pass through to cross attention control
        :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
        """
        self.conditioning = None
        self.model = model
        self.is_running_diffusers = is_running_diffusers
        self.model_forward_callback = model_forward_callback
        self.cross_attention_control_context = None
        self.sequential_guidance = Globals.sequential_guidance

    @contextmanager
    def custom_attention_context(self,
                                 extra_conditioning_info: Optional[ExtraConditioningInfo],
                                 step_count: int):
        do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
        old_attn_processor = None
        if do_swap:
            old_attn_processor = self.override_cross_attention(extra_conditioning_info,
                                                               step_count=step_count)
        try:
            yield None
        finally:
            if old_attn_processor is not None:
                self.restore_default_cross_attention(old_attn_processor)
            # TODO resuscitate attention map saving
            # self.remove_attention_map_saving()

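    # Usage sketch (illustrative only): installing the prompt2prompt attention swap around a
    # denoising loop, so it is removed again even if sampling raises.
    #
    #     with component.custom_attention_context(extra_conditioning_info, step_count=steps):
    #         run_denoising_loop()   # hypothetical helper; each step calls do_diffusion_step()
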
    def override_cross_attention(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
        """
        setup cross attention .swap control. for diffusers this replaces the attention processor, so
        the previous attention processor is returned so that the caller can restore it later.
        """
        self.conditioning = conditioning
        self.cross_attention_control_context = Context(
            arguments=self.conditioning.cross_attention_control_args,
            step_count=step_count
        )
        return override_cross_attention(self.model,
                                        self.cross_attention_control_context,
                                        is_running_diffusers=self.is_running_diffusers)

    def restore_default_cross_attention(self, restore_attention_processor: Optional['AttnProcessor'] = None):
        self.conditioning = None
        self.cross_attention_control_context = None
        restore_default_cross_attention(self.model,
                                        is_running_diffusers=self.is_running_diffusers,
                                        restore_attention_processor=restore_attention_processor)

    def setup_attention_map_saving(self, saver: AttentionMapSaver):
        def callback(slice, dim, offset, slice_size, key):
            if dim is not None:
                # sliced tokens attention map saving is not implemented
                return
            saver.add_attention_maps(slice, key)

        tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
        for identifier, module in tokens_cross_attention_modules:
            key = ('down' if identifier.startswith('down') else
                   'up' if identifier.startswith('up') else
                   'mid')
            module.set_attention_slice_calculated_callback(
                lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key))

    def remove_attention_map_saving(self):
        tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
        for _, module in tokens_cross_attention_modules:
            module.set_attention_slice_calculated_callback(None)

    def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
                          unconditioning: Union[torch.Tensor, dict],
                          conditioning: Union[torch.Tensor, dict],
                          unconditional_guidance_scale: float,
                          step_index: Optional[int] = None,
                          total_step_count: Optional[int] = None,
                          ):
        """
        :param x: current latents
        :param sigma: aka t, passed to the internal model to control how much denoising will occur
        :param unconditioning: embeddings for unconditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
        :param conditioning: embeddings for conditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
        :param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
        :param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotonically increase. If None, will be estimated by comparing sigma against self.model.sigmas .
        :return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
        """

        cross_attention_control_types_to_do = []
        context: Context = self.cross_attention_control_context
        if self.cross_attention_control_context is not None:
            percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
            cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)

        wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
        wants_hybrid_conditioning = isinstance(conditioning, dict)

        if wants_hybrid_conditioning:
            unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(x, sigma, unconditioning,
                                                                                       conditioning)
        elif wants_cross_attention_control:
            unconditioned_next_x, conditioned_next_x = self._apply_cross_attention_controlled_conditioning(x, sigma,
                                                                                                            unconditioning,
                                                                                                            conditioning,
                                                                                                            cross_attention_control_types_to_do)
        elif self.sequential_guidance:
            unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning_sequentially(
                x, sigma, unconditioning, conditioning)
        else:
            unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning(
                x, sigma, unconditioning, conditioning)

        combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale)

        return combined_next_x

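    # Dispatch summary (descriptive note): a dict-valued `conditioning` selects the hybrid
    # (inpainting) path; an active prompt2prompt schedule selects cross-attention control;
    # `sequential_guidance` trades speed for memory by running uncond and cond separately;
    # otherwise uncond and cond are batched through the model in a single call.
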
    def do_latent_postprocessing(
        self,
        postprocessing_settings: PostprocessingSettings,
        latents: torch.Tensor,
        sigma,
        step_index,
        total_step_count
    ) -> torch.Tensor:
        if postprocessing_settings is not None:
            percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
            latents = self.apply_threshold(postprocessing_settings, latents, percent_through)
            latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
        return latents

    def calculate_percent_through(self, sigma, step_index, total_step_count):
        if step_index is not None and total_step_count is not None:
            # 🧨diffusers codepath
            percent_through = step_index / total_step_count  # will never reach 1.0 - this is deliberate
        else:
            # legacy compvis codepath
            # TODO remove when compvis codepath support is dropped
            if step_index is None and sigma is None:
                raise ValueError(
                    "Either step_index or sigma is required when doing cross attention control, but both are None.")
            percent_through = self.estimate_percent_through(step_index, sigma)
        return percent_through

    # methods below are called from do_diffusion_step and should be considered private to this class.

    def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
        # fast batched path
        x_twice = torch.cat([x] * 2)
        sigma_twice = torch.cat([sigma] * 2)
        both_conditionings = torch.cat([unconditioning, conditioning])
        both_results = self.model_forward_callback(x_twice, sigma_twice, both_conditionings)
        unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
        if conditioned_next_x.device.type == 'mps':
            # prevent a result filled with zeros. seems to be a torch bug.
            conditioned_next_x = conditioned_next_x.clone()
        return unconditioned_next_x, conditioned_next_x

    def _apply_standard_conditioning_sequentially(self, x: torch.Tensor, sigma, unconditioning: torch.Tensor, conditioning: torch.Tensor):
        # low-memory sequential path
        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning)
        if conditioned_next_x.device.type == 'mps':
            # prevent a result filled with zeros. seems to be a torch bug.
            conditioned_next_x = conditioned_next_x.clone()
        return unconditioned_next_x, conditioned_next_x

    def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
        assert isinstance(conditioning, dict)
        assert isinstance(unconditioning, dict)
        x_twice = torch.cat([x] * 2)
        sigma_twice = torch.cat([sigma] * 2)
        both_conditionings = dict()
        for k in conditioning:
            if isinstance(conditioning[k], list):
                both_conditionings[k] = [
                    torch.cat([unconditioning[k][i], conditioning[k][i]])
                    for i in range(len(conditioning[k]))
                ]
            else:
                both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
        unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
        return unconditioned_next_x, conditioned_next_x

    def _apply_cross_attention_controlled_conditioning(self,
                                                       x: torch.Tensor,
                                                       sigma,
                                                       unconditioning,
                                                       conditioning,
                                                       cross_attention_control_types_to_do):
        if self.is_running_diffusers:
            return self._apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning,
                                                                                  conditioning,
                                                                                  cross_attention_control_types_to_do)
        else:
            return self._apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning,
                                                                                cross_attention_control_types_to_do)

    def _apply_cross_attention_controlled_conditioning__diffusers(self,
                                                                  x: torch.Tensor,
                                                                  sigma,
                                                                  unconditioning,
                                                                  conditioning,
                                                                  cross_attention_control_types_to_do):
        context: Context = self.cross_attention_control_context

        cross_attn_processor_context = SwapCrossAttnContext(modified_text_embeddings=context.arguments.edited_conditioning,
                                                            index_map=context.cross_attention_index_map,
                                                            mask=context.cross_attention_mask,
                                                            cross_attention_types_to_do=[])
        # no cross attention for unconditioning (negative prompt)
        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning,
                                                           {"swap_cross_attn_context": cross_attn_processor_context})

        # do requested cross attention types for conditioning (positive prompt)
        cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning,
                                                         {"swap_cross_attn_context": cross_attn_processor_context})
        return unconditioned_next_x, conditioned_next_x

    def _apply_cross_attention_controlled_conditioning__compvis(self, x: torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
        # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
        # slower non-batched path (20% slower on mac MPS)
        # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
        # unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
        # This messes up their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8)
        # (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16,
        # representing batched uncond + cond, but then when it comes to applying the saved attention, the
        # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
        # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
        context: Context = self.cross_attention_control_context

        try:
            unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)

            # process x using the original prompt, saving the attention maps
            # print("saving attention maps for", cross_attention_control_types_to_do)
            for ca_type in cross_attention_control_types_to_do:
                context.request_save_attention_maps(ca_type)
            _ = self.model_forward_callback(x, sigma, conditioning)
            context.clear_requests(cleanup=False)

            # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
            # print("applying saved attention maps for", cross_attention_control_types_to_do)
            for ca_type in cross_attention_control_types_to_do:
                context.request_apply_saved_attention_maps(ca_type)
            edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
            conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
            context.clear_requests(cleanup=True)

        except:
            context.clear_requests(cleanup=True)
            raise

        return unconditioned_next_x, conditioned_next_x

    def _combine(self, unconditioned_next_x, conditioned_next_x, guidance_scale):
        # to scale how much effect conditioning has, calculate the changes it does and then scale that
        scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
        combined_next_x = unconditioned_next_x + scaled_delta
        return combined_next_x

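    # Worked example (descriptive note): with guidance_scale=7.5 and, for some latent element,
    # unconditioned prediction 0.10 and conditioned prediction 0.30, the combined value is
    # 0.10 + 7.5 * (0.30 - 0.10) = 1.60; i.e. the conditioned "direction" is amplified 7.5x.
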
    def apply_threshold(
        self,
        postprocessing_settings: PostprocessingSettings,
        latents: torch.Tensor,
        percent_through: float
    ) -> torch.Tensor:

        if postprocessing_settings.threshold is None or postprocessing_settings.threshold == 0.0:
            return latents

        threshold = postprocessing_settings.threshold
        warmup = postprocessing_settings.warmup

        if percent_through < warmup:
            current_threshold = threshold + threshold * 5 * (1 - (percent_through / warmup))
        else:
            current_threshold = threshold

        if current_threshold <= 0:
            return latents

        maxval = latents.max().item()
        minval = latents.min().item()

        scale = 0.7  # default value from #395

        if self.debug_thresholding:
            std, mean = [i.item() for i in torch.std_mean(latents)]
            outside = torch.count_nonzero((latents < -current_threshold) | (latents > current_threshold))
            print(f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
                  f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
                  f" | {outside / latents.numel() * 100:.2f}% values outside threshold")

        if maxval < current_threshold and minval > -current_threshold:
            return latents

        num_altered = 0

        # MPS torch.rand_like is fine because torch.rand_like is wrapped in generate.py!

        if maxval > current_threshold:
            latents = torch.clone(latents)
            maxval = np.clip(maxval * scale, 1, current_threshold)
            num_altered += torch.count_nonzero(latents > maxval)
            latents[latents > maxval] = torch.rand_like(latents[latents > maxval]) * maxval

        if minval < -current_threshold:
            latents = torch.clone(latents)
            minval = np.clip(minval * scale, -current_threshold, -1)
            num_altered += torch.count_nonzero(latents < minval)
            latents[latents < minval] = torch.rand_like(latents[latents < minval]) * minval

        if self.debug_thresholding:
            print(f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
                  f" | {num_altered / latents.numel() * 100:.2f}% values altered")

        return latents

    def apply_symmetry(
        self,
        postprocessing_settings: PostprocessingSettings,
        latents: torch.Tensor,
        percent_through: float
    ) -> torch.Tensor:

        # Reset our last percent through if this is our first step.
        if percent_through == 0.0:
            self.last_percent_through = 0.0

        if postprocessing_settings is None:
            return latents

        # Check for out of bounds
        h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
        if (h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0)):
            h_symmetry_time_pct = None

        v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
        if (v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0)):
            v_symmetry_time_pct = None

        dev = latents.device.type

        # torch.Tensor.to() is not in-place; keep the moved tensor.
        latents = latents.to(device='cpu')

        if (
            h_symmetry_time_pct is not None and
            self.last_percent_through < h_symmetry_time_pct and
            percent_through >= h_symmetry_time_pct
        ):
            # Horizontal symmetry occurs on the 3rd dimension of the latent
            width = latents.shape[3]
            x_flipped = torch.flip(latents, dims=[3])
            latents = torch.cat([latents[:, :, :, 0:int(width / 2)], x_flipped[:, :, :, int(width / 2):int(width)]], dim=3)

        if (
            v_symmetry_time_pct is not None and
            self.last_percent_through < v_symmetry_time_pct and
            percent_through >= v_symmetry_time_pct
        ):
            # Vertical symmetry occurs on the 2nd dimension of the latent
            height = latents.shape[2]
            y_flipped = torch.flip(latents, dims=[2])
            latents = torch.cat([latents[:, :, 0:int(height / 2)], y_flipped[:, :, int(height / 2):int(height)]], dim=2)

        self.last_percent_through = percent_through
        return latents.to(device=dev)

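    # Shape note (descriptive): latents are laid out [B, C, H, W], so horizontal mirroring
    # keeps the left half and replaces the right half with a mirrored copy of the left half,
    # e.g. for W=8 the output columns are [0..3] followed by columns [3..0] in reverse order.
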
    def estimate_percent_through(self, step_index, sigma):
        if step_index is not None and self.cross_attention_control_context is not None:
            # percent_through will never reach 1.0 (but this is intended)
            return float(step_index) / float(self.cross_attention_control_context.step_count)
        # find the best possible index of the current sigma in the sigma sequence
        smaller_sigmas = torch.nonzero(self.model.sigmas <= sigma)
        sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0
        # flip because sigmas[0] is for the fully denoised image
        # percent_through must be <1
        return 1.0 - float(sigma_index + 1) / float(self.model.sigmas.shape[0])
        # print('estimated percent_through', percent_through, 'from sigma', sigma.item())

    # todo: make this work
    @classmethod
    def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)  # aka sigmas

        deltas = None
        uncond_latents = None
        weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]

        # below is fugly omg
        num_actual_conditionings = len(c_or_weighted_c_list)
        conditionings = [uc] + [c for c, weight in weighted_cond_list]
        weights = [1] + [weight for c, weight in weighted_cond_list]
        chunk_count = ceil(len(conditionings) / 2)
        deltas = None
        for chunk_index in range(chunk_count):
            offset = chunk_index * 2
            chunk_size = min(2, len(conditionings) - offset)

            if chunk_size == 1:
                c_in = conditionings[offset]
                latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
                latents_b = None
            else:
                c_in = torch.cat(conditionings[offset:offset + 2])
                latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)

            # first chunk is guaranteed to be 2 entries: uncond_latents + first conditioning
            if chunk_index == 0:
                uncond_latents = latents_a
                deltas = latents_b - uncond_latents
            else:
                deltas = torch.cat((deltas, latents_a - uncond_latents))
                if latents_b is not None:
                    deltas = torch.cat((deltas, latents_b - uncond_latents))

        # merge the weighted deltas together into a single merged delta
        per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
        normalize = False
        if normalize:
            per_delta_weights /= torch.sum(per_delta_weights)
        reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
        deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)

        # old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
        # assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))

        return uncond_latents + deltas_merged * global_guidance_scale
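
# Usage sketch (illustrative only; the variable names and loop below are assumptions for
# illustration, not part of this diff). A sampler builds the component around its model and
# a forward callback, then calls do_diffusion_step() once per denoising step:
#
#     component = InvokeAIDiffuserComponent(
#         model,
#         model_forward_callback=lambda x, sigma, cond: model.forward(x, sigma, cond),
#     )
#     for i, sigma in enumerate(sigmas):
#         latents = component.do_diffusion_step(
#             latents, sigma, uncond_embeddings, cond_embeddings,
#             unconditional_guidance_scale=7.5,
#             step_index=i, total_step_count=len(sigmas))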
1221
invokeai/models/model_manager.py
Normal file
File diff suppressed because it is too large