mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add missing files
This commit is contained in:
parent
8c9764476c
commit
d334f7f1f6
9
.gitignore
vendored
9
.gitignore
vendored
@ -214,9 +214,9 @@ gfpgan/
|
||||
configs/models.yaml
|
||||
|
||||
# weights (will be created by installer)
|
||||
models/ldm/stable-diffusion-v1/*.ckpt
|
||||
models/clipseg
|
||||
models/gfpgan
|
||||
# models/ldm/stable-diffusion-v1/*.ckpt
|
||||
# models/clipseg
|
||||
# models/gfpgan
|
||||
|
||||
# ignore initfile
|
||||
.invokeai
|
||||
@ -232,6 +232,3 @@ installer/install.bat
|
||||
installer/install.sh
|
||||
installer/update.bat
|
||||
installer/update.sh
|
||||
|
||||
# no longer stored in source directory
|
||||
models
|
||||
|
10
invokeai/models/__init__.py
Normal file
10
invokeai/models/__init__.py
Normal file
@ -0,0 +1,10 @@
|
||||
'''
|
||||
Initialization file for the invokeai.models package
|
||||
'''
|
||||
from .model_manager import ModelManager, SDLegacyType
|
||||
from .diffusion import InvokeAIDiffuserComponent
|
||||
from .diffusion.ddim import DDIMSampler
|
||||
from .diffusion.ksampler import KSampler
|
||||
from .diffusion.plms import PLMSSampler
|
||||
from .diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
10
invokeai/models/__init__.py~
Normal file
10
invokeai/models/__init__.py~
Normal file
@ -0,0 +1,10 @@
|
||||
'''
|
||||
Initialization file for the invokeai.models package
|
||||
'''
|
||||
from .model_manager import ModelManager, SDLegacyType
|
||||
from .diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from .diffusion.ddim import DDIMSampler
|
||||
from .diffusion.ksampler import KSampler
|
||||
from .diffusion.plms import PLMSSampler
|
||||
from .diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
BIN
invokeai/models/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
invokeai/models/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
invokeai/models/__pycache__/autoencoder.cpython-310.pyc
Normal file
BIN
invokeai/models/__pycache__/autoencoder.cpython-310.pyc
Normal file
Binary file not shown.
BIN
invokeai/models/__pycache__/model_manager.cpython-310.pyc
Normal file
BIN
invokeai/models/__pycache__/model_manager.cpython-310.pyc
Normal file
Binary file not shown.
596
invokeai/models/autoencoder.py
Normal file
596
invokeai/models/autoencoder.py
Normal file
@ -0,0 +1,596 @@
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
import torch.nn.functional as F
|
||||
from contextlib import contextmanager
|
||||
|
||||
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
||||
|
||||
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||
from ldm.modules.distributions.distributions import (
|
||||
DiagonalGaussianDistribution,
|
||||
)
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
|
||||
class VQModel(pl.LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
n_embed,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key='image',
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
batch_resize_range=None,
|
||||
scheduler_config=None,
|
||||
lr_g_factor=1.0,
|
||||
remap=None,
|
||||
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
||||
use_ema=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.n_embed = n_embed
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
self.quantize = VectorQuantizer(
|
||||
n_embed,
|
||||
embed_dim,
|
||||
beta=0.25,
|
||||
remap=remap,
|
||||
sane_index_shape=sane_index_shape,
|
||||
)
|
||||
self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(
|
||||
embed_dim, ddconfig['z_channels'], 1
|
||||
)
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels) == int
|
||||
self.register_buffer(
|
||||
'colorize', torch.randn(3, colorize_nlabels, 1, 1)
|
||||
)
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
self.batch_resize_range = batch_resize_range
|
||||
if self.batch_resize_range is not None:
|
||||
print(
|
||||
f'{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.'
|
||||
)
|
||||
|
||||
self.use_ema = use_ema
|
||||
if self.use_ema:
|
||||
self.model_ema = LitEma(self)
|
||||
print(f'>> Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
self.scheduler_config = scheduler_config
|
||||
self.lr_g_factor = lr_g_factor
|
||||
|
||||
@contextmanager
|
||||
def ema_scope(self, context=None):
|
||||
if self.use_ema:
|
||||
self.model_ema.store(self.parameters())
|
||||
self.model_ema.copy_to(self)
|
||||
if context is not None:
|
||||
print(f'{context}: Switched to EMA weights')
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if self.use_ema:
|
||||
self.model_ema.restore(self.parameters())
|
||||
if context is not None:
|
||||
print(f'{context}: Restored training weights')
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location='cpu')['state_dict']
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print('Deleting key {} from state_dict.'.format(k))
|
||||
del sd[k]
|
||||
missing, unexpected = self.load_state_dict(sd, strict=False)
|
||||
print(
|
||||
f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
|
||||
)
|
||||
if len(missing) > 0:
|
||||
print(f'Missing Keys: {missing}')
|
||||
print(f'Unexpected Keys: {unexpected}')
|
||||
|
||||
def on_train_batch_end(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
self.model_ema(self)
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
quant, emb_loss, info = self.quantize(h)
|
||||
return quant, emb_loss, info
|
||||
|
||||
def encode_to_prequant(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
return h
|
||||
|
||||
def decode(self, quant):
|
||||
quant = self.post_quant_conv(quant)
|
||||
dec = self.decoder(quant)
|
||||
return dec
|
||||
|
||||
def decode_code(self, code_b):
|
||||
quant_b = self.quantize.embed_code(code_b)
|
||||
dec = self.decode(quant_b)
|
||||
return dec
|
||||
|
||||
def forward(self, input, return_pred_indices=False):
|
||||
quant, diff, (_, _, ind) = self.encode(input)
|
||||
dec = self.decode(quant)
|
||||
if return_pred_indices:
|
||||
return dec, diff, ind
|
||||
return dec, diff
|
||||
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = (
|
||||
x.permute(0, 3, 1, 2)
|
||||
.to(memory_format=torch.contiguous_format)
|
||||
.float()
|
||||
)
|
||||
if self.batch_resize_range is not None:
|
||||
lower_size = self.batch_resize_range[0]
|
||||
upper_size = self.batch_resize_range[1]
|
||||
if self.global_step <= 4:
|
||||
# do the first few batches with max size to avoid later oom
|
||||
new_resize = upper_size
|
||||
else:
|
||||
new_resize = np.random.choice(
|
||||
np.arange(lower_size, upper_size + 16, 16)
|
||||
)
|
||||
if new_resize != x.shape[2]:
|
||||
x = F.interpolate(x, size=new_resize, mode='bicubic')
|
||||
x = x.detach()
|
||||
return x
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
# https://github.com/pytorch/pytorch/issues/37142
|
||||
# try not to fool the heuristics
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss, ind = self(x, return_pred_indices=True)
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# autoencode
|
||||
aeloss, log_dict_ae = self.loss(
|
||||
qloss,
|
||||
x,
|
||||
xrec,
|
||||
optimizer_idx,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='train',
|
||||
predicted_indices=ind,
|
||||
)
|
||||
|
||||
self.log_dict(
|
||||
log_dict_ae,
|
||||
prog_bar=False,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=True,
|
||||
)
|
||||
return aeloss
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# discriminator
|
||||
discloss, log_dict_disc = self.loss(
|
||||
qloss,
|
||||
x,
|
||||
xrec,
|
||||
optimizer_idx,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='train',
|
||||
)
|
||||
self.log_dict(
|
||||
log_dict_disc,
|
||||
prog_bar=False,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=True,
|
||||
)
|
||||
return discloss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
log_dict = self._validation_step(batch, batch_idx)
|
||||
with self.ema_scope():
|
||||
log_dict_ema = self._validation_step(
|
||||
batch, batch_idx, suffix='_ema'
|
||||
)
|
||||
return log_dict
|
||||
|
||||
def _validation_step(self, batch, batch_idx, suffix=''):
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss, ind = self(x, return_pred_indices=True)
|
||||
aeloss, log_dict_ae = self.loss(
|
||||
qloss,
|
||||
x,
|
||||
xrec,
|
||||
0,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='val' + suffix,
|
||||
predicted_indices=ind,
|
||||
)
|
||||
|
||||
discloss, log_dict_disc = self.loss(
|
||||
qloss,
|
||||
x,
|
||||
xrec,
|
||||
1,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='val' + suffix,
|
||||
predicted_indices=ind,
|
||||
)
|
||||
rec_loss = log_dict_ae[f'val{suffix}/rec_loss']
|
||||
self.log(
|
||||
f'val{suffix}/rec_loss',
|
||||
rec_loss,
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_step=False,
|
||||
on_epoch=True,
|
||||
sync_dist=True,
|
||||
)
|
||||
self.log(
|
||||
f'val{suffix}/aeloss',
|
||||
aeloss,
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_step=False,
|
||||
on_epoch=True,
|
||||
sync_dist=True,
|
||||
)
|
||||
if version.parse(pl.__version__) >= version.parse('1.4.0'):
|
||||
del log_dict_ae[f'val{suffix}/rec_loss']
|
||||
self.log_dict(log_dict_ae)
|
||||
self.log_dict(log_dict_disc)
|
||||
return self.log_dict
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr_d = self.learning_rate
|
||||
lr_g = self.lr_g_factor * self.learning_rate
|
||||
print('lr_d', lr_d)
|
||||
print('lr_g', lr_g)
|
||||
opt_ae = torch.optim.Adam(
|
||||
list(self.encoder.parameters())
|
||||
+ list(self.decoder.parameters())
|
||||
+ list(self.quantize.parameters())
|
||||
+ list(self.quant_conv.parameters())
|
||||
+ list(self.post_quant_conv.parameters()),
|
||||
lr=lr_g,
|
||||
betas=(0.5, 0.9),
|
||||
)
|
||||
opt_disc = torch.optim.Adam(
|
||||
self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9)
|
||||
)
|
||||
|
||||
if self.scheduler_config is not None:
|
||||
scheduler = instantiate_from_config(self.scheduler_config)
|
||||
|
||||
print('Setting up LambdaLR scheduler...')
|
||||
scheduler = [
|
||||
{
|
||||
'scheduler': LambdaLR(
|
||||
opt_ae, lr_lambda=scheduler.schedule
|
||||
),
|
||||
'interval': 'step',
|
||||
'frequency': 1,
|
||||
},
|
||||
{
|
||||
'scheduler': LambdaLR(
|
||||
opt_disc, lr_lambda=scheduler.schedule
|
||||
),
|
||||
'interval': 'step',
|
||||
'frequency': 1,
|
||||
},
|
||||
]
|
||||
return [opt_ae, opt_disc], scheduler
|
||||
return [opt_ae, opt_disc], []
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
if only_inputs:
|
||||
log['inputs'] = x
|
||||
return log
|
||||
xrec, _ = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log['inputs'] = x
|
||||
log['reconstructions'] = xrec
|
||||
if plot_ema:
|
||||
with self.ema_scope():
|
||||
xrec_ema, _ = self(x)
|
||||
if x.shape[1] > 3:
|
||||
xrec_ema = self.to_rgb(xrec_ema)
|
||||
log['reconstructions_ema'] = xrec_ema
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == 'segmentation'
|
||||
if not hasattr(self, 'colorize'):
|
||||
self.register_buffer(
|
||||
'colorize', torch.randn(3, x.shape[1], 1, 1).to(x)
|
||||
)
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
|
||||
return x
|
||||
|
||||
|
||||
class VQModelInterface(VQModel):
|
||||
def __init__(self, embed_dim, *args, **kwargs):
|
||||
super().__init__(embed_dim=embed_dim, *args, **kwargs)
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
return h
|
||||
|
||||
def decode(self, h, force_not_quantize=False):
|
||||
# also go through quantization layer
|
||||
if not force_not_quantize:
|
||||
quant, emb_loss, info = self.quantize(h)
|
||||
else:
|
||||
quant = h
|
||||
quant = self.post_quant_conv(quant)
|
||||
dec = self.decoder(quant)
|
||||
return dec
|
||||
|
||||
|
||||
class AutoencoderKL(pl.LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key='image',
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
assert ddconfig['double_z']
|
||||
self.quant_conv = torch.nn.Conv2d(
|
||||
2 * ddconfig['z_channels'], 2 * embed_dim, 1
|
||||
)
|
||||
self.post_quant_conv = torch.nn.Conv2d(
|
||||
embed_dim, ddconfig['z_channels'], 1
|
||||
)
|
||||
self.embed_dim = embed_dim
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels) == int
|
||||
self.register_buffer(
|
||||
'colorize', torch.randn(3, colorize_nlabels, 1, 1)
|
||||
)
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location='cpu')['state_dict']
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print('Deleting key {} from state_dict.'.format(k))
|
||||
del sd[k]
|
||||
self.load_state_dict(sd, strict=False)
|
||||
print(f'Restored from {path}')
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
return posterior
|
||||
|
||||
def decode(self, z):
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z)
|
||||
return dec
|
||||
|
||||
def forward(self, input, sample_posterior=True):
|
||||
posterior = self.encode(input)
|
||||
if sample_posterior:
|
||||
z = posterior.sample()
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z)
|
||||
return dec, posterior
|
||||
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = (
|
||||
x.permute(0, 3, 1, 2)
|
||||
.to(memory_format=torch.contiguous_format)
|
||||
.float()
|
||||
)
|
||||
return x
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
inputs = self.get_input(batch, self.image_key)
|
||||
reconstructions, posterior = self(inputs)
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# train encoder+decoder+logvar
|
||||
aeloss, log_dict_ae = self.loss(
|
||||
inputs,
|
||||
reconstructions,
|
||||
posterior,
|
||||
optimizer_idx,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='train',
|
||||
)
|
||||
self.log(
|
||||
'aeloss',
|
||||
aeloss,
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=True,
|
||||
)
|
||||
self.log_dict(
|
||||
log_dict_ae,
|
||||
prog_bar=False,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=False,
|
||||
)
|
||||
return aeloss
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# train the discriminator
|
||||
discloss, log_dict_disc = self.loss(
|
||||
inputs,
|
||||
reconstructions,
|
||||
posterior,
|
||||
optimizer_idx,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='train',
|
||||
)
|
||||
|
||||
self.log(
|
||||
'discloss',
|
||||
discloss,
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=True,
|
||||
)
|
||||
self.log_dict(
|
||||
log_dict_disc,
|
||||
prog_bar=False,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=False,
|
||||
)
|
||||
return discloss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
inputs = self.get_input(batch, self.image_key)
|
||||
reconstructions, posterior = self(inputs)
|
||||
aeloss, log_dict_ae = self.loss(
|
||||
inputs,
|
||||
reconstructions,
|
||||
posterior,
|
||||
0,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='val',
|
||||
)
|
||||
|
||||
discloss, log_dict_disc = self.loss(
|
||||
inputs,
|
||||
reconstructions,
|
||||
posterior,
|
||||
1,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='val',
|
||||
)
|
||||
|
||||
self.log('val/rec_loss', log_dict_ae['val/rec_loss'])
|
||||
self.log_dict(log_dict_ae)
|
||||
self.log_dict(log_dict_disc)
|
||||
return self.log_dict
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr = self.learning_rate
|
||||
opt_ae = torch.optim.Adam(
|
||||
list(self.encoder.parameters())
|
||||
+ list(self.decoder.parameters())
|
||||
+ list(self.quant_conv.parameters())
|
||||
+ list(self.post_quant_conv.parameters()),
|
||||
lr=lr,
|
||||
betas=(0.5, 0.9),
|
||||
)
|
||||
opt_disc = torch.optim.Adam(
|
||||
self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
|
||||
)
|
||||
return [opt_ae, opt_disc], []
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, only_inputs=False, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
if not only_inputs:
|
||||
xrec, posterior = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log['samples'] = self.decode(torch.randn_like(posterior.sample()))
|
||||
log['reconstructions'] = xrec
|
||||
log['inputs'] = x
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == 'segmentation'
|
||||
if not hasattr(self, 'colorize'):
|
||||
self.register_buffer(
|
||||
'colorize', torch.randn(3, x.shape[1], 1, 1).to(x)
|
||||
)
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
|
||||
return x
|
||||
|
||||
|
||||
class IdentityFirstStage(torch.nn.Module):
|
||||
def __init__(self, *args, vq_interface=False, **kwargs):
|
||||
self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
|
||||
super().__init__()
|
||||
|
||||
def encode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def decode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def quantize(self, x, *args, **kwargs):
|
||||
if self.vq_interface:
|
||||
return x, None, [None, None, None]
|
||||
return x
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
return x
|
4
invokeai/models/diffusion/__init__.py
Normal file
4
invokeai/models/diffusion/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
'''
|
||||
Initialization file for invokeai.models.diffusion
|
||||
'''
|
||||
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
4
invokeai/models/diffusion/__init__.py~
Normal file
4
invokeai/models/diffusion/__init__.py~
Normal file
@ -0,0 +1,4 @@
|
||||
'''
|
||||
Initialization file for invokeai.models.diffusion
|
||||
'''
|
||||
from shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
BIN
invokeai/models/diffusion/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
invokeai/models/diffusion/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
invokeai/models/diffusion/__pycache__/ddim.cpython-310.pyc
Normal file
BIN
invokeai/models/diffusion/__pycache__/ddim.cpython-310.pyc
Normal file
Binary file not shown.
BIN
invokeai/models/diffusion/__pycache__/ddpm.cpython-310.pyc
Normal file
BIN
invokeai/models/diffusion/__pycache__/ddpm.cpython-310.pyc
Normal file
Binary file not shown.
BIN
invokeai/models/diffusion/__pycache__/ksampler.cpython-310.pyc
Normal file
BIN
invokeai/models/diffusion/__pycache__/ksampler.cpython-310.pyc
Normal file
Binary file not shown.
BIN
invokeai/models/diffusion/__pycache__/plms.cpython-310.pyc
Normal file
BIN
invokeai/models/diffusion/__pycache__/plms.cpython-310.pyc
Normal file
Binary file not shown.
BIN
invokeai/models/diffusion/__pycache__/sampler.cpython-310.pyc
Normal file
BIN
invokeai/models/diffusion/__pycache__/sampler.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
355
invokeai/models/diffusion/classifier.py
Normal file
355
invokeai/models/diffusion/classifier.py
Normal file
@ -0,0 +1,355 @@
|
||||
import os
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
from omegaconf import OmegaConf
|
||||
from torch.nn import functional as F
|
||||
from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from copy import deepcopy
|
||||
from einops import rearrange
|
||||
from glob import glob
|
||||
from natsort import natsorted
|
||||
|
||||
from ldm.modules.diffusionmodules.openaimodel import (
|
||||
EncoderUNetModel,
|
||||
UNetModel,
|
||||
)
|
||||
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
|
||||
|
||||
__models__ = {'class_label': EncoderUNetModel, 'segmentation': UNetModel}
|
||||
|
||||
|
||||
def disabled_train(self, mode=True):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
does not change anymore."""
|
||||
return self
|
||||
|
||||
|
||||
class NoisyLatentImageClassifier(pl.LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
diffusion_path,
|
||||
num_classes,
|
||||
ckpt_path=None,
|
||||
pool='attention',
|
||||
label_key=None,
|
||||
diffusion_ckpt_path=None,
|
||||
scheduler_config=None,
|
||||
weight_decay=1.0e-2,
|
||||
log_steps=10,
|
||||
monitor='val/loss',
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.num_classes = num_classes
|
||||
# get latest config of diffusion model
|
||||
diffusion_config = natsorted(
|
||||
glob(os.path.join(diffusion_path, 'configs', '*-project.yaml'))
|
||||
)[-1]
|
||||
self.diffusion_config = OmegaConf.load(diffusion_config).model
|
||||
self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
|
||||
self.load_diffusion()
|
||||
|
||||
self.monitor = monitor
|
||||
self.numd = (
|
||||
self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
|
||||
)
|
||||
self.log_time_interval = (
|
||||
self.diffusion_model.num_timesteps // log_steps
|
||||
)
|
||||
self.log_steps = log_steps
|
||||
|
||||
self.label_key = (
|
||||
label_key
|
||||
if not hasattr(self.diffusion_model, 'cond_stage_key')
|
||||
else self.diffusion_model.cond_stage_key
|
||||
)
|
||||
|
||||
assert (
|
||||
self.label_key is not None
|
||||
), 'label_key neither in diffusion model nor in model.params'
|
||||
|
||||
if self.label_key not in __models__:
|
||||
raise NotImplementedError()
|
||||
|
||||
self.load_classifier(ckpt_path, pool)
|
||||
|
||||
self.scheduler_config = scheduler_config
|
||||
self.use_scheduler = self.scheduler_config is not None
|
||||
self.weight_decay = weight_decay
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
||||
sd = torch.load(path, map_location='cpu')
|
||||
if 'state_dict' in list(sd.keys()):
|
||||
sd = sd['state_dict']
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print('Deleting key {} from state_dict.'.format(k))
|
||||
del sd[k]
|
||||
missing, unexpected = (
|
||||
self.load_state_dict(sd, strict=False)
|
||||
if not only_model
|
||||
else self.model.load_state_dict(sd, strict=False)
|
||||
)
|
||||
print(
|
||||
f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
|
||||
)
|
||||
if len(missing) > 0:
|
||||
print(f'Missing Keys: {missing}')
|
||||
if len(unexpected) > 0:
|
||||
print(f'Unexpected Keys: {unexpected}')
|
||||
|
||||
def load_diffusion(self):
|
||||
model = instantiate_from_config(self.diffusion_config)
|
||||
self.diffusion_model = model.eval()
|
||||
self.diffusion_model.train = disabled_train
|
||||
for param in self.diffusion_model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def load_classifier(self, ckpt_path, pool):
|
||||
model_config = deepcopy(
|
||||
self.diffusion_config.params.unet_config.params
|
||||
)
|
||||
model_config.in_channels = (
|
||||
self.diffusion_config.params.unet_config.params.out_channels
|
||||
)
|
||||
model_config.out_channels = self.num_classes
|
||||
if self.label_key == 'class_label':
|
||||
model_config.pool = pool
|
||||
|
||||
self.model = __models__[self.label_key](**model_config)
|
||||
if ckpt_path is not None:
|
||||
print(
|
||||
'#####################################################################'
|
||||
)
|
||||
print(f'load from ckpt "{ckpt_path}"')
|
||||
print(
|
||||
'#####################################################################'
|
||||
)
|
||||
self.init_from_ckpt(ckpt_path)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_x_noisy(self, x, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x))
|
||||
continuous_sqrt_alpha_cumprod = None
|
||||
if self.diffusion_model.use_continuous_noise:
|
||||
continuous_sqrt_alpha_cumprod = (
|
||||
self.diffusion_model.sample_continuous_noise_level(
|
||||
x.shape[0], t + 1
|
||||
)
|
||||
)
|
||||
# todo: make sure t+1 is correct here
|
||||
|
||||
return self.diffusion_model.q_sample(
|
||||
x_start=x,
|
||||
t=t,
|
||||
noise=noise,
|
||||
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod,
|
||||
)
|
||||
|
||||
def forward(self, x_noisy, t, *args, **kwargs):
|
||||
return self.model(x_noisy, t)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = rearrange(x, 'b h w c -> b c h w')
|
||||
x = x.to(memory_format=torch.contiguous_format).float()
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def get_conditioning(self, batch, k=None):
|
||||
if k is None:
|
||||
k = self.label_key
|
||||
assert k is not None, 'Needs to provide label key'
|
||||
|
||||
targets = batch[k].to(self.device)
|
||||
|
||||
if self.label_key == 'segmentation':
|
||||
targets = rearrange(targets, 'b h w c -> b c h w')
|
||||
for down in range(self.numd):
|
||||
h, w = targets.shape[-2:]
|
||||
targets = F.interpolate(
|
||||
targets, size=(h // 2, w // 2), mode='nearest'
|
||||
)
|
||||
|
||||
# targets = rearrange(targets,'b c h w -> b h w c')
|
||||
|
||||
return targets
|
||||
|
||||
def compute_top_k(self, logits, labels, k, reduction='mean'):
|
||||
_, top_ks = torch.topk(logits, k, dim=1)
|
||||
if reduction == 'mean':
|
||||
return (
|
||||
(top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
|
||||
)
|
||||
elif reduction == 'none':
|
||||
return (top_ks == labels[:, None]).float().sum(dim=-1)
|
||||
|
||||
def on_train_epoch_start(self):
|
||||
# save some memory
|
||||
self.diffusion_model.model.to('cpu')
|
||||
|
||||
@torch.no_grad()
|
||||
def write_logs(self, loss, logits, targets):
|
||||
log_prefix = 'train' if self.training else 'val'
|
||||
log = {}
|
||||
log[f'{log_prefix}/loss'] = loss.mean()
|
||||
log[f'{log_prefix}/acc@1'] = self.compute_top_k(
|
||||
logits, targets, k=1, reduction='mean'
|
||||
)
|
||||
log[f'{log_prefix}/acc@5'] = self.compute_top_k(
|
||||
logits, targets, k=5, reduction='mean'
|
||||
)
|
||||
|
||||
self.log_dict(
|
||||
log,
|
||||
prog_bar=False,
|
||||
logger=True,
|
||||
on_step=self.training,
|
||||
on_epoch=True,
|
||||
)
|
||||
self.log(
|
||||
'loss', log[f'{log_prefix}/loss'], prog_bar=True, logger=False
|
||||
)
|
||||
self.log(
|
||||
'global_step',
|
||||
self.global_step,
|
||||
logger=False,
|
||||
on_epoch=False,
|
||||
prog_bar=True,
|
||||
)
|
||||
lr = self.optimizers().param_groups[0]['lr']
|
||||
self.log(
|
||||
'lr_abs',
|
||||
lr,
|
||||
on_step=True,
|
||||
logger=True,
|
||||
on_epoch=False,
|
||||
prog_bar=True,
|
||||
)
|
||||
|
||||
def shared_step(self, batch, t=None):
|
||||
x, *_ = self.diffusion_model.get_input(
|
||||
batch, k=self.diffusion_model.first_stage_key
|
||||
)
|
||||
targets = self.get_conditioning(batch)
|
||||
if targets.dim() == 4:
|
||||
targets = targets.argmax(dim=1)
|
||||
if t is None:
|
||||
t = torch.randint(
|
||||
0,
|
||||
self.diffusion_model.num_timesteps,
|
||||
(x.shape[0],),
|
||||
device=self.device,
|
||||
).long()
|
||||
else:
|
||||
t = torch.full(
|
||||
size=(x.shape[0],), fill_value=t, device=self.device
|
||||
).long()
|
||||
x_noisy = self.get_x_noisy(x, t)
|
||||
logits = self(x_noisy, t)
|
||||
|
||||
loss = F.cross_entropy(logits, targets, reduction='none')
|
||||
|
||||
self.write_logs(loss.detach(), logits.detach(), targets.detach())
|
||||
|
||||
loss = loss.mean()
|
||||
return loss, logits, x_noisy, targets
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
loss, *_ = self.shared_step(batch)
|
||||
return loss
|
||||
|
||||
def reset_noise_accs(self):
|
||||
self.noisy_acc = {
|
||||
t: {'acc@1': [], 'acc@5': []}
|
||||
for t in range(
|
||||
0,
|
||||
self.diffusion_model.num_timesteps,
|
||||
self.diffusion_model.log_every_t,
|
||||
)
|
||||
}
|
||||
|
||||
def on_validation_start(self):
|
||||
self.reset_noise_accs()
|
||||
|
||||
@torch.no_grad()
|
||||
def validation_step(self, batch, batch_idx):
|
||||
loss, *_ = self.shared_step(batch)
|
||||
|
||||
for t in self.noisy_acc:
|
||||
_, logits, _, targets = self.shared_step(batch, t)
|
||||
self.noisy_acc[t]['acc@1'].append(
|
||||
self.compute_top_k(logits, targets, k=1, reduction='mean')
|
||||
)
|
||||
self.noisy_acc[t]['acc@5'].append(
|
||||
self.compute_top_k(logits, targets, k=5, reduction='mean')
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = AdamW(
|
||||
self.model.parameters(),
|
||||
lr=self.learning_rate,
|
||||
weight_decay=self.weight_decay,
|
||||
)
|
||||
|
||||
if self.use_scheduler:
|
||||
scheduler = instantiate_from_config(self.scheduler_config)
|
||||
|
||||
print('Setting up LambdaLR scheduler...')
|
||||
scheduler = [
|
||||
{
|
||||
'scheduler': LambdaLR(
|
||||
optimizer, lr_lambda=scheduler.schedule
|
||||
),
|
||||
'interval': 'step',
|
||||
'frequency': 1,
|
||||
}
|
||||
]
|
||||
return [optimizer], scheduler
|
||||
|
||||
return optimizer
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, N=8, *args, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.diffusion_model.first_stage_key)
|
||||
log['inputs'] = x
|
||||
|
||||
y = self.get_conditioning(batch)
|
||||
|
||||
if self.label_key == 'class_label':
|
||||
y = log_txt_as_img((x.shape[2], x.shape[3]), batch['human_label'])
|
||||
log['labels'] = y
|
||||
|
||||
if ismap(y):
|
||||
log['labels'] = self.diffusion_model.to_rgb(y)
|
||||
|
||||
for step in range(self.log_steps):
|
||||
current_time = step * self.log_time_interval
|
||||
|
||||
_, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
|
||||
|
||||
log[f'inputs@t{current_time}'] = x_noisy
|
||||
|
||||
pred = F.one_hot(
|
||||
logits.argmax(dim=1), num_classes=self.num_classes
|
||||
)
|
||||
pred = rearrange(pred, 'b h w c -> b c h w')
|
||||
|
||||
log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(
|
||||
pred
|
||||
)
|
||||
|
||||
for key in log:
|
||||
log[key] = log[key][:N]
|
||||
|
||||
return log
|
642
invokeai/models/diffusion/cross_attention_control.py
Normal file
642
invokeai/models/diffusion/cross_attention_control.py
Normal file
@ -0,0 +1,642 @@
|
||||
|
||||
# adapted from bloc97's CrossAttentionControl colab
|
||||
# https://github.com/bloc97/CrossAttentionControl
|
||||
|
||||
|
||||
import enum
|
||||
import math
|
||||
from typing import Optional, Callable
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
import diffusers
|
||||
from torch import nn
|
||||
|
||||
from compel.cross_attention_control import Arguments
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.models.cross_attention import AttnProcessor
|
||||
from ldm.invoke.devices import torch_dtype
|
||||
|
||||
|
||||
class CrossAttentionType(enum.Enum):
|
||||
SELF = 1
|
||||
TOKENS = 2
|
||||
|
||||
|
||||
class Context:
|
||||
|
||||
cross_attention_mask: Optional[torch.Tensor]
|
||||
cross_attention_index_map: Optional[torch.Tensor]
|
||||
|
||||
class Action(enum.Enum):
|
||||
NONE = 0
|
||||
SAVE = 1,
|
||||
APPLY = 2
|
||||
|
||||
def __init__(self, arguments: Arguments, step_count: int):
|
||||
"""
|
||||
:param arguments: Arguments for the cross-attention control process
|
||||
:param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
|
||||
"""
|
||||
self.cross_attention_mask = None
|
||||
self.cross_attention_index_map = None
|
||||
self.self_cross_attention_action = Context.Action.NONE
|
||||
self.tokens_cross_attention_action = Context.Action.NONE
|
||||
self.arguments = arguments
|
||||
self.step_count = step_count
|
||||
|
||||
self.self_cross_attention_module_identifiers = []
|
||||
self.tokens_cross_attention_module_identifiers = []
|
||||
|
||||
self.saved_cross_attention_maps = {}
|
||||
|
||||
self.clear_requests(cleanup=True)
|
||||
|
||||
def register_cross_attention_modules(self, model):
|
||||
for name,module in get_cross_attention_modules(model, CrossAttentionType.SELF):
|
||||
if name in self.self_cross_attention_module_identifiers:
|
||||
assert False, f"name {name} cannot appear more than once"
|
||||
self.self_cross_attention_module_identifiers.append(name)
|
||||
for name,module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
|
||||
if name in self.tokens_cross_attention_module_identifiers:
|
||||
assert False, f"name {name} cannot appear more than once"
|
||||
self.tokens_cross_attention_module_identifiers.append(name)
|
||||
|
||||
def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
|
||||
if cross_attention_type == CrossAttentionType.SELF:
|
||||
self.self_cross_attention_action = Context.Action.SAVE
|
||||
else:
|
||||
self.tokens_cross_attention_action = Context.Action.SAVE
|
||||
|
||||
def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType):
|
||||
if cross_attention_type == CrossAttentionType.SELF:
|
||||
self.self_cross_attention_action = Context.Action.APPLY
|
||||
else:
|
||||
self.tokens_cross_attention_action = Context.Action.APPLY
|
||||
|
||||
def is_tokens_cross_attention(self, module_identifier) -> bool:
|
||||
return module_identifier in self.tokens_cross_attention_module_identifiers
|
||||
|
||||
def get_should_save_maps(self, module_identifier: str) -> bool:
|
||||
if module_identifier in self.self_cross_attention_module_identifiers:
|
||||
return self.self_cross_attention_action == Context.Action.SAVE
|
||||
elif module_identifier in self.tokens_cross_attention_module_identifiers:
|
||||
return self.tokens_cross_attention_action == Context.Action.SAVE
|
||||
return False
|
||||
|
||||
def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
|
||||
if module_identifier in self.self_cross_attention_module_identifiers:
|
||||
return self.self_cross_attention_action == Context.Action.APPLY
|
||||
elif module_identifier in self.tokens_cross_attention_module_identifiers:
|
||||
return self.tokens_cross_attention_action == Context.Action.APPLY
|
||||
return False
|
||||
|
||||
def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\
|
||||
-> list[CrossAttentionType]:
|
||||
"""
|
||||
Should cross-attention control be applied on the given step?
|
||||
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
|
||||
:return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
|
||||
"""
|
||||
if percent_through is None:
|
||||
return [CrossAttentionType.SELF, CrossAttentionType.TOKENS]
|
||||
|
||||
opts = self.arguments.edit_options
|
||||
to_control = []
|
||||
if opts['s_start'] <= percent_through < opts['s_end']:
|
||||
to_control.append(CrossAttentionType.SELF)
|
||||
if opts['t_start'] <= percent_through < opts['t_end']:
|
||||
to_control.append(CrossAttentionType.TOKENS)
|
||||
return to_control
|
||||
|
||||
def save_slice(self, identifier: str, slice: torch.Tensor, dim: Optional[int], offset: int,
|
||||
slice_size: Optional[int]):
|
||||
if identifier not in self.saved_cross_attention_maps:
|
||||
self.saved_cross_attention_maps[identifier] = {
|
||||
'dim': dim,
|
||||
'slice_size': slice_size,
|
||||
'slices': {offset or 0: slice}
|
||||
}
|
||||
else:
|
||||
self.saved_cross_attention_maps[identifier]['slices'][offset or 0] = slice
|
||||
|
||||
def get_slice(self, identifier: str, requested_dim: Optional[int], requested_offset: int, slice_size: int):
|
||||
saved_attention_dict = self.saved_cross_attention_maps[identifier]
|
||||
if requested_dim is None:
|
||||
if saved_attention_dict['dim'] is not None:
|
||||
raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}")
|
||||
return saved_attention_dict['slices'][0]
|
||||
|
||||
if saved_attention_dict['dim'] == requested_dim:
|
||||
if slice_size != saved_attention_dict['slice_size']:
|
||||
raise RuntimeError(
|
||||
f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}")
|
||||
return saved_attention_dict['slices'][requested_offset]
|
||||
|
||||
if saved_attention_dict['dim'] is None:
|
||||
whole_saved_attention = saved_attention_dict['slices'][0]
|
||||
if requested_dim == 0:
|
||||
return whole_saved_attention[requested_offset:requested_offset + slice_size]
|
||||
elif requested_dim == 1:
|
||||
return whole_saved_attention[:, requested_offset:requested_offset + slice_size]
|
||||
|
||||
raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
|
||||
|
||||
def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
|
||||
saved_attention = self.saved_cross_attention_maps.get(identifier, None)
|
||||
if saved_attention is None:
|
||||
return None, None
|
||||
return saved_attention['dim'], saved_attention['slice_size']
|
||||
|
||||
def clear_requests(self, cleanup=True):
|
||||
self.tokens_cross_attention_action = Context.Action.NONE
|
||||
self.self_cross_attention_action = Context.Action.NONE
|
||||
if cleanup:
|
||||
self.saved_cross_attention_maps = {}
|
||||
|
||||
def offload_saved_attention_slices_to_cpu(self):
|
||||
for key, map_dict in self.saved_cross_attention_maps.items():
|
||||
for offset, slice in map_dict['slices'].items():
|
||||
map_dict[offset] = slice.to('cpu')
|
||||
|
||||
|
||||
|
||||
class InvokeAICrossAttentionMixin:
|
||||
"""
|
||||
Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
|
||||
through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
|
||||
and dymamic slicing strategy selection.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||
self.attention_slice_wrangler = None
|
||||
self.slicing_strategy_getter = None
|
||||
self.attention_slice_calculated_callback = None
|
||||
|
||||
def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
|
||||
'''
|
||||
Set custom attention calculator to be called when attention is calculated
|
||||
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
|
||||
which returns either the suggested_attention_slice or an adjusted equivalent.
|
||||
`module` is the current CrossAttention module for which the callback is being invoked.
|
||||
`suggested_attention_slice` is the default-calculated attention slice
|
||||
`dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
|
||||
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
|
||||
|
||||
Pass None to use the default attention calculation.
|
||||
:return:
|
||||
'''
|
||||
self.attention_slice_wrangler = wrangler
|
||||
|
||||
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
|
||||
self.slicing_strategy_getter = getter
|
||||
|
||||
def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]):
|
||||
self.attention_slice_calculated_callback = callback
|
||||
|
||||
def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
|
||||
# calculate attention scores
|
||||
#attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
|
||||
attention_scores = torch.baddbmm(
|
||||
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
||||
query,
|
||||
key.transpose(-1, -2),
|
||||
beta=0,
|
||||
alpha=self.scale,
|
||||
)
|
||||
|
||||
# calculate attention slice by taking the best scores for each latent pixel
|
||||
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
|
||||
attention_slice_wrangler = self.attention_slice_wrangler
|
||||
if attention_slice_wrangler is not None:
|
||||
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
|
||||
else:
|
||||
attention_slice = default_attention_slice
|
||||
|
||||
if self.attention_slice_calculated_callback is not None:
|
||||
self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)
|
||||
|
||||
hidden_states = torch.bmm(attention_slice, value)
|
||||
return hidden_states
|
||||
|
||||
def einsum_op_slice_dim0(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
end = i + slice_size
|
||||
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
|
||||
return r
|
||||
|
||||
def einsum_op_slice_dim1(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
|
||||
return r
|
||||
|
||||
def einsum_op_mps_v1(self, q, k, v):
|
||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
else:
|
||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||
|
||||
def einsum_op_mps_v2(self, q, k, v):
|
||||
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
else:
|
||||
return self.einsum_op_slice_dim0(q, k, v, 1)
|
||||
|
||||
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
|
||||
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
||||
if size_mb <= max_tensor_mb:
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
|
||||
if div <= q.shape[0]:
|
||||
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
|
||||
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
|
||||
|
||||
def einsum_op_cuda(self, q, k, v):
|
||||
# check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
|
||||
slicing_strategy_getter = self.slicing_strategy_getter
|
||||
if slicing_strategy_getter is not None:
|
||||
(dim, slice_size) = slicing_strategy_getter(self)
|
||||
if dim is not None:
|
||||
# print("using saved slicing strategy with dim", dim, "slice size", slice_size)
|
||||
if dim == 0:
|
||||
return self.einsum_op_slice_dim0(q, k, v, slice_size)
|
||||
elif dim == 1:
|
||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||
|
||||
# fallback for when there is no saved strategy, or saved strategy does not slice
|
||||
mem_free_total = get_mem_free_total(q.device)
|
||||
# Divide factor of safety as there's copying and fragmentation
|
||||
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||
|
||||
|
||||
def get_invokeai_attention_mem_efficient(self, q, k, v):
|
||||
if q.device.type == 'cuda':
|
||||
#print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
|
||||
return self.einsum_op_cuda(q, k, v)
|
||||
|
||||
if q.device.type == 'mps' or q.device.type == 'cpu':
|
||||
if self.mem_total_gb >= 32:
|
||||
return self.einsum_op_mps_v1(q, k, v)
|
||||
return self.einsum_op_mps_v2(q, k, v)
|
||||
|
||||
# Smaller slices are faster due to L2/L3/SLC caches.
|
||||
# Tested on i7 with 8MB L3 cache.
|
||||
return self.einsum_op_tensor_mem(q, k, v, 32)
|
||||
|
||||
|
||||
|
||||
def restore_default_cross_attention(model, is_running_diffusers: bool, restore_attention_processor: Optional[AttnProcessor]=None):
|
||||
if is_running_diffusers:
|
||||
unet = model
|
||||
unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor())
|
||||
else:
|
||||
remove_attention_function(model)
|
||||
|
||||
|
||||
def override_cross_attention(model, context: Context, is_running_diffusers = False):
|
||||
"""
|
||||
Inject attention parameters and functions into the passed in model to enable cross attention editing.
|
||||
|
||||
:param model: The unet model to inject into.
|
||||
:return: None
|
||||
"""
|
||||
|
||||
# adapted from init_attention_edit
|
||||
device = context.arguments.edited_conditioning.device
|
||||
|
||||
# urgh. should this be hardcoded?
|
||||
max_length = 77
|
||||
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
|
||||
mask = torch.zeros(max_length, dtype=torch_dtype(device))
|
||||
indices_target = torch.arange(max_length, dtype=torch.long)
|
||||
indices = torch.arange(max_length, dtype=torch.long)
|
||||
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
|
||||
if b0 < max_length:
|
||||
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
|
||||
# these tokens have not been edited
|
||||
indices[b0:b1] = indices_target[a0:a1]
|
||||
mask[b0:b1] = 1
|
||||
|
||||
context.cross_attention_mask = mask.to(device)
|
||||
context.cross_attention_index_map = indices.to(device)
|
||||
if is_running_diffusers:
|
||||
unet = model
|
||||
old_attn_processors = unet.attn_processors
|
||||
if torch.backends.mps.is_available():
|
||||
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
|
||||
unet.set_attn_processor(SwapCrossAttnProcessor())
|
||||
else:
|
||||
# try to re-use an existing slice size
|
||||
default_slice_size = 4
|
||||
slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
|
||||
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
|
||||
return old_attn_processors
|
||||
else:
|
||||
context.register_cross_attention_modules(model)
|
||||
inject_attention_function(model, context)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
|
||||
def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
|
||||
from ldm.modules.attention import CrossAttention # avoid circular import
|
||||
cross_attention_class: type = InvokeAIDiffusersCrossAttention if isinstance(model,UNet2DConditionModel) else CrossAttention
|
||||
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
|
||||
attention_module_tuples = [(name,module) for name, module in model.named_modules() if
|
||||
isinstance(module, cross_attention_class) and which_attn in name]
|
||||
cross_attention_modules_in_model_count = len(attention_module_tuples)
|
||||
expected_count = 16
|
||||
if cross_attention_modules_in_model_count != expected_count:
|
||||
# non-fatal error but .swap() won't work.
|
||||
print(f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model " +
|
||||
f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed " +
|
||||
f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, " +
|
||||
f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows " +
|
||||
f"what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not " +
|
||||
f"work properly until it is fixed.")
|
||||
return attention_module_tuples
|
||||
|
||||
|
||||
def inject_attention_function(unet, context: Context):
|
||||
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
|
||||
|
||||
def attention_slice_wrangler(module, suggested_attention_slice:torch.Tensor, dim, offset, slice_size):
|
||||
|
||||
#memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()
|
||||
|
||||
attention_slice = suggested_attention_slice
|
||||
|
||||
if context.get_should_save_maps(module.identifier):
|
||||
#print(module.identifier, "saving suggested_attention_slice of shape",
|
||||
# suggested_attention_slice.shape, "dim", dim, "offset", offset)
|
||||
slice_to_save = attention_slice.to('cpu') if dim is not None else attention_slice
|
||||
context.save_slice(module.identifier, slice_to_save, dim=dim, offset=offset, slice_size=slice_size)
|
||||
elif context.get_should_apply_saved_maps(module.identifier):
|
||||
#print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
|
||||
saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)
|
||||
|
||||
# slice may have been offloaded to CPU
|
||||
saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)
|
||||
|
||||
if context.is_tokens_cross_attention(module.identifier):
|
||||
index_map = context.cross_attention_index_map
|
||||
remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
|
||||
this_attention_slice = suggested_attention_slice
|
||||
|
||||
mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device))
|
||||
saved_mask = mask
|
||||
this_mask = 1 - mask
|
||||
attention_slice = remapped_saved_attention_slice * saved_mask + \
|
||||
this_attention_slice * this_mask
|
||||
else:
|
||||
# just use everything
|
||||
attention_slice = saved_attention_slice
|
||||
|
||||
return attention_slice
|
||||
|
||||
cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
|
||||
for identifier, module in cross_attention_modules:
|
||||
module.identifier = identifier
|
||||
try:
|
||||
module.set_attention_slice_wrangler(attention_slice_wrangler)
|
||||
module.set_slicing_strategy_getter(
|
||||
lambda module: context.get_slicing_strategy(identifier)
|
||||
)
|
||||
except AttributeError as e:
|
||||
if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
|
||||
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def remove_attention_function(unet):
|
||||
cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
|
||||
for identifier, module in cross_attention_modules:
|
||||
try:
|
||||
# clear wrangler callback
|
||||
module.set_attention_slice_wrangler(None)
|
||||
module.set_slicing_strategy_getter(None)
|
||||
except AttributeError as e:
|
||||
if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
|
||||
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def is_attribute_error_about(error: AttributeError, attribute: str):
|
||||
if hasattr(error, 'name'): # Python 3.10
|
||||
return error.name == attribute
|
||||
else: # Python 3.9
|
||||
return attribute in str(error)
|
||||
|
||||
|
||||
|
||||
def get_mem_free_total(device):
|
||||
#only on cuda
|
||||
if not torch.cuda.is_available():
|
||||
return None
|
||||
stats = torch.cuda.memory_stats(device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(device)
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
return mem_free_total
|
||||
|
||||
|
||||
|
||||
class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
InvokeAICrossAttentionMixin.__init__(self)
|
||||
|
||||
def _attention(self, query, key, value, attention_mask=None):
|
||||
#default_result = super()._attention(query, key, value)
|
||||
if attention_mask is not None:
|
||||
print(f"{type(self).__name__} ignoring passed-in attention_mask")
|
||||
attention_result = self.get_invokeai_attention_mem_efficient(query, key, value)
|
||||
|
||||
hidden_states = self.reshape_batch_dim_to_heads(attention_result)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
## 🧨diffusers implementation follows
|
||||
|
||||
|
||||
"""
|
||||
# base implementation
|
||||
|
||||
class CrossAttnProcessor:
|
||||
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
query = attn.head_to_batch_dim(query)
|
||||
|
||||
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
"""
|
||||
from dataclasses import field, dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor
|
||||
|
||||
|
||||
@dataclass
|
||||
class SwapCrossAttnContext:
|
||||
modified_text_embeddings: torch.Tensor
|
||||
index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt
|
||||
mask: torch.Tensor # in the target space of the index_map
|
||||
cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)
|
||||
|
||||
def __int__(self,
|
||||
cac_types_to_do: [CrossAttentionType],
|
||||
modified_text_embeddings: torch.Tensor,
|
||||
index_map: torch.Tensor,
|
||||
mask: torch.Tensor):
|
||||
self.cross_attention_types_to_do = cac_types_to_do
|
||||
self.modified_text_embeddings = modified_text_embeddings
|
||||
self.index_map = index_map
|
||||
self.mask = mask
|
||||
|
||||
def wants_cross_attention_control(self, attn_type: CrossAttentionType) -> bool:
|
||||
return attn_type in self.cross_attention_types_to_do
|
||||
|
||||
@classmethod
|
||||
def make_mask_and_index_map(cls, edit_opcodes: list[tuple[str, int, int, int, int]], max_length: int) \
|
||||
-> tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
# mask=1 means use original prompt attention, mask=0 means use modified prompt attention
|
||||
mask = torch.zeros(max_length)
|
||||
indices_target = torch.arange(max_length, dtype=torch.long)
|
||||
indices = torch.arange(max_length, dtype=torch.long)
|
||||
for name, a0, a1, b0, b1 in edit_opcodes:
|
||||
if b0 < max_length:
|
||||
if name == "equal":
|
||||
# these tokens remain the same as in the original prompt
|
||||
indices[b0:b1] = indices_target[a0:a1]
|
||||
mask[b0:b1] = 1
|
||||
|
||||
return mask, indices
|
||||
|
||||
|
||||
class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
|
||||
|
||||
# TODO: dynamically pick slice size based on memory conditions
|
||||
|
||||
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
|
||||
# kwargs
|
||||
swap_cross_attn_context: SwapCrossAttnContext=None):
|
||||
|
||||
attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
|
||||
|
||||
# if cross-attention control is not in play, just call through to the base implementation.
|
||||
if attention_type is CrossAttentionType.SELF or \
|
||||
swap_cross_attn_context is None or \
|
||||
not swap_cross_attn_context.wants_cross_attention_control(attention_type):
|
||||
#print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
|
||||
return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
|
||||
#else:
|
||||
# print(f"SwapCrossAttnContext for {attention_type} active")
|
||||
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
attention_mask = attn.prepare_attention_mask(
|
||||
attention_mask=attention_mask, target_length=sequence_length,
|
||||
batch_size=batch_size)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
dim = query.shape[-1]
|
||||
query = attn.head_to_batch_dim(query)
|
||||
|
||||
original_text_embeddings = encoder_hidden_states
|
||||
modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
|
||||
original_text_key = attn.to_k(original_text_embeddings)
|
||||
modified_text_key = attn.to_k(modified_text_embeddings)
|
||||
original_value = attn.to_v(original_text_embeddings)
|
||||
modified_value = attn.to_v(modified_text_embeddings)
|
||||
|
||||
original_text_key = attn.head_to_batch_dim(original_text_key)
|
||||
modified_text_key = attn.head_to_batch_dim(modified_text_key)
|
||||
original_value = attn.head_to_batch_dim(original_value)
|
||||
modified_value = attn.head_to_batch_dim(modified_value)
|
||||
|
||||
# compute slices and prepare output tensor
|
||||
batch_size_attention = query.shape[0]
|
||||
hidden_states = torch.zeros(
|
||||
(batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
|
||||
)
|
||||
|
||||
# do slices
|
||||
for i in range(max(1,hidden_states.shape[0] // self.slice_size)):
|
||||
start_idx = i * self.slice_size
|
||||
end_idx = (i + 1) * self.slice_size
|
||||
|
||||
query_slice = query[start_idx:end_idx]
|
||||
original_key_slice = original_text_key[start_idx:end_idx]
|
||||
modified_key_slice = modified_text_key[start_idx:end_idx]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
||||
|
||||
original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice)
|
||||
modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice)
|
||||
|
||||
# because the prompt modifications may result in token sequences shifted forwards or backwards,
|
||||
# the original attention probabilities must be remapped to account for token index changes in the
|
||||
# modified prompt
|
||||
remapped_original_attn_slice = torch.index_select(original_attn_slice, -1,
|
||||
swap_cross_attn_context.index_map)
|
||||
|
||||
# only some tokens taken from the original attention probabilities. this is controlled by the mask.
|
||||
mask = swap_cross_attn_context.mask
|
||||
inverse_mask = 1 - mask
|
||||
attn_slice = \
|
||||
remapped_original_attn_slice * mask + \
|
||||
modified_attn_slice * inverse_mask
|
||||
|
||||
del remapped_original_attn_slice, modified_attn_slice
|
||||
|
||||
attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx])
|
||||
hidden_states[start_idx:end_idx] = attn_slice
|
||||
|
||||
|
||||
# done
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser):
|
||||
|
||||
def __init__(self):
|
||||
super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9)) # massive slice size = don't slice
|
||||
|
95
invokeai/models/diffusion/cross_attention_map_saving.py
Normal file
95
invokeai/models/diffusion/cross_attention_map_saving.py
Normal file
@ -0,0 +1,95 @@
|
||||
import math
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from torchvision.transforms.functional import resize as tv_resize, InterpolationMode
|
||||
|
||||
from .cross_attention_control import get_cross_attention_modules, CrossAttentionType
|
||||
|
||||
|
||||
class AttentionMapSaver():
|
||||
|
||||
def __init__(self, token_ids: range, latents_shape: torch.Size):
|
||||
self.token_ids = token_ids
|
||||
self.latents_shape = latents_shape
|
||||
#self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
|
||||
self.collated_maps = {}
|
||||
|
||||
def clear_maps(self):
|
||||
self.collated_maps = {}
|
||||
|
||||
def add_attention_maps(self, maps: torch.Tensor, key: str):
|
||||
"""
|
||||
Accumulate the given attention maps and store by summing with existing maps at the passed-in key (if any).
|
||||
:param maps: Attention maps to store. Expected shape [A, (H*W), N] where A is attention heads count, H and W are the map size (fixed per-key) and N is the number of tokens (typically 77).
|
||||
:param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the maps sizes (H and W) should match.
|
||||
:return: None
|
||||
"""
|
||||
key_and_size = f'{key}_{maps.shape[1]}'
|
||||
|
||||
# extract desired tokens
|
||||
maps = maps[:, :, self.token_ids]
|
||||
|
||||
# merge attention heads to a single map per token
|
||||
maps = torch.sum(maps, 0)
|
||||
|
||||
# store
|
||||
if key_and_size not in self.collated_maps:
|
||||
self.collated_maps[key_and_size] = torch.zeros_like(maps, device='cpu')
|
||||
self.collated_maps[key_and_size] += maps.cpu()
|
||||
|
||||
def write_maps_to_disk(self, path: str):
|
||||
pil_image = self.get_stacked_maps_image()
|
||||
pil_image.save(path, 'PNG')
|
||||
|
||||
def get_stacked_maps_image(self) -> PIL.Image:
|
||||
"""
|
||||
Scale all collected attention maps to the same size, blend them together and return as an image.
|
||||
:return: An image containing a vertical stack of blended attention maps, one for each requested token.
|
||||
"""
|
||||
num_tokens = len(self.token_ids)
|
||||
if num_tokens == 0:
|
||||
return None
|
||||
|
||||
latents_height = self.latents_shape[0]
|
||||
latents_width = self.latents_shape[1]
|
||||
|
||||
merged = None
|
||||
|
||||
for key, maps in self.collated_maps.items():
|
||||
|
||||
# maps has shape [(H*W), N] for N tokens
|
||||
# but we want [N, H, W]
|
||||
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
|
||||
this_maps_height = int(float(latents_height) * this_scale_factor)
|
||||
this_maps_width = int(float(latents_width) * this_scale_factor)
|
||||
# and we need to do some dimension juggling
|
||||
maps = torch.reshape(torch.swapdims(maps, 0, 1), [num_tokens, this_maps_height, this_maps_width])
|
||||
|
||||
# scale to output size if necessary
|
||||
if this_scale_factor != 1:
|
||||
maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)
|
||||
|
||||
# normalize
|
||||
maps_min = torch.min(maps)
|
||||
maps_range = torch.max(maps) - maps_min
|
||||
#print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
|
||||
maps_normalized = (maps - maps_min) / maps_range
|
||||
# expand to (-0.1, 1.1) and clamp
|
||||
maps_normalized_expanded = maps_normalized * 1.1 - 0.05
|
||||
maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)
|
||||
|
||||
# merge together, producing a vertical stack
|
||||
maps_stacked = torch.reshape(maps_normalized_expanded_clamped, [num_tokens * latents_height, latents_width])
|
||||
|
||||
if merged is None:
|
||||
merged = maps_stacked
|
||||
else:
|
||||
# screen blend
|
||||
merged = 1 - (1 - maps_stacked)*(1 - merged)
|
||||
|
||||
if merged is None:
|
||||
return None
|
||||
|
||||
merged_bytes = merged.mul(0xff).byte()
|
||||
return PIL.Image.fromarray(merged_bytes.numpy(), mode='L')
|
111
invokeai/models/diffusion/ddim.py
Normal file
111
invokeai/models/diffusion/ddim.py
Normal file
@ -0,0 +1,111 @@
|
||||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from .sampler import Sampler
|
||||
from ldm.modules.diffusionmodules.util import noise_like
|
||||
|
||||
class DDIMSampler(Sampler):
|
||||
def __init__(self, model, schedule='linear', device=None, **kwargs):
|
||||
super().__init__(model,schedule,model.num_timesteps,device)
|
||||
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
|
||||
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
|
||||
|
||||
def prepare_to_sample(self, t_enc, **kwargs):
|
||||
super().prepare_to_sample(t_enc, **kwargs)
|
||||
|
||||
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
|
||||
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
|
||||
|
||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = all_timesteps_count)
|
||||
else:
|
||||
self.invokeai_diffuser.restore_default_cross_attention()
|
||||
|
||||
|
||||
# This is the central routine
|
||||
@torch.no_grad()
|
||||
def p_sample(
|
||||
self,
|
||||
x,
|
||||
c,
|
||||
t,
|
||||
index,
|
||||
repeat_noise=False,
|
||||
use_original_steps=False,
|
||||
quantize_denoised=False,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
step_count:int=1000, # total number of steps
|
||||
**kwargs,
|
||||
):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
if (
|
||||
unconditional_conditioning is None
|
||||
or unconditional_guidance_scale == 1.0
|
||||
):
|
||||
# damian0815 would like to know when/if this code path is used
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
# step_index counts in the opposite direction to index
|
||||
step_index = step_count-(index+1)
|
||||
e_t = self.invokeai_diffuser.do_diffusion_step(
|
||||
x, t,
|
||||
unconditional_conditioning, c,
|
||||
unconditional_guidance_scale,
|
||||
step_index=step_index
|
||||
)
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == 'eps'
|
||||
e_t = score_corrector.modify_score(
|
||||
self.model, e_t, x, t, c, **corrector_kwargs
|
||||
)
|
||||
|
||||
alphas = (
|
||||
self.model.alphas_cumprod
|
||||
if use_original_steps
|
||||
else self.ddim_alphas
|
||||
)
|
||||
alphas_prev = (
|
||||
self.model.alphas_cumprod_prev
|
||||
if use_original_steps
|
||||
else self.ddim_alphas_prev
|
||||
)
|
||||
sqrt_one_minus_alphas = (
|
||||
self.model.sqrt_one_minus_alphas_cumprod
|
||||
if use_original_steps
|
||||
else self.ddim_sqrt_one_minus_alphas
|
||||
)
|
||||
sigmas = (
|
||||
self.model.ddim_sigmas_for_original_num_steps
|
||||
if use_original_steps
|
||||
else self.ddim_sigmas
|
||||
)
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full(
|
||||
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
||||
)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = (
|
||||
sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
)
|
||||
if noise_dropout > 0.0:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0, None
|
||||
|
2271
invokeai/models/diffusion/ddpm.py
Normal file
2271
invokeai/models/diffusion/ddpm.py
Normal file
File diff suppressed because it is too large
Load Diff
312
invokeai/models/diffusion/ksampler.py
Normal file
312
invokeai/models/diffusion/ksampler.py
Normal file
@ -0,0 +1,312 @@
|
||||
"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
|
||||
|
||||
import k_diffusion as K
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .cross_attention_map_saving import AttentionMapSaver
|
||||
from .sampler import Sampler
|
||||
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
|
||||
|
||||
# at this threshold, the scheduler will stop using the Karras
|
||||
# noise schedule and start using the model's schedule
|
||||
STEP_THRESHOLD = 30
|
||||
|
||||
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
|
||||
if threshold <= 0.0:
|
||||
return result
|
||||
maxval = 0.0 + torch.max(result).cpu().numpy()
|
||||
minval = 0.0 + torch.min(result).cpu().numpy()
|
||||
if maxval < threshold and minval > -threshold:
|
||||
return result
|
||||
if maxval > threshold:
|
||||
maxval = min(max(1, scale*maxval), threshold)
|
||||
if minval < -threshold:
|
||||
minval = max(min(-1, scale*minval), -threshold)
|
||||
return torch.clamp(result, min=minval, max=maxval)
|
||||
|
||||
|
||||
class CFGDenoiser(nn.Module):
|
||||
def __init__(self, model, threshold = 0, warmup = 0):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
self.threshold = threshold
|
||||
self.warmup_max = warmup
|
||||
self.warmup = max(warmup / 10, 1)
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(model,
|
||||
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
|
||||
|
||||
|
||||
def prepare_to_sample(self, t_enc, **kwargs):
|
||||
|
||||
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
|
||||
|
||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = t_enc)
|
||||
else:
|
||||
self.invokeai_diffuser.restore_default_cross_attention()
|
||||
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
|
||||
if self.warmup < self.warmup_max:
|
||||
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
|
||||
self.warmup += 1
|
||||
else:
|
||||
thresh = self.threshold
|
||||
if thresh > self.threshold:
|
||||
thresh = self.threshold
|
||||
return cfg_apply_threshold(next_x, thresh)
|
||||
|
||||
class KSampler(Sampler):
|
||||
def __init__(self, model, schedule='lms', device=None, **kwargs):
|
||||
denoiser = K.external.CompVisDenoiser(model)
|
||||
super().__init__(
|
||||
denoiser,
|
||||
schedule,
|
||||
steps=model.num_timesteps,
|
||||
)
|
||||
self.sigmas = None
|
||||
self.ds = None
|
||||
self.s_in = None
|
||||
self.karras_max = kwargs.get('karras_max',STEP_THRESHOLD)
|
||||
if self.karras_max is None:
|
||||
self.karras_max = STEP_THRESHOLD
|
||||
|
||||
def make_schedule(
|
||||
self,
|
||||
ddim_num_steps,
|
||||
ddim_discretize='uniform',
|
||||
ddim_eta=0.0,
|
||||
verbose=False,
|
||||
):
|
||||
outer_model = self.model
|
||||
self.model = outer_model.inner_model
|
||||
super().make_schedule(
|
||||
ddim_num_steps,
|
||||
ddim_discretize='uniform',
|
||||
ddim_eta=0.0,
|
||||
verbose=False,
|
||||
)
|
||||
self.model = outer_model
|
||||
self.ddim_num_steps = ddim_num_steps
|
||||
# we don't need both of these sigmas, but storing them here to make
|
||||
# comparison easier later on
|
||||
self.model_sigmas = self.model.get_sigmas(ddim_num_steps)
|
||||
self.karras_sigmas = K.sampling.get_sigmas_karras(
|
||||
n=ddim_num_steps,
|
||||
sigma_min=self.model.sigmas[0].item(),
|
||||
sigma_max=self.model.sigmas[-1].item(),
|
||||
rho=7.,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
if ddim_num_steps >= self.karras_max:
|
||||
print(f'>> Ksampler using model noise schedule (steps >= {self.karras_max})')
|
||||
self.sigmas = self.model_sigmas
|
||||
else:
|
||||
print(f'>> Ksampler using karras noise schedule (steps < {self.karras_max})')
|
||||
self.sigmas = self.karras_sigmas
|
||||
|
||||
# ALERT: We are completely overriding the sample() method in the base class, which
|
||||
# means that inpainting will not work. To get this to work we need to be able to
|
||||
# modify the inner loop of k_heun, k_lms, etc, as is done in an ugly way
|
||||
# in the lstein/k-diffusion branch.
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(
|
||||
self,
|
||||
z_enc,
|
||||
cond,
|
||||
t_enc,
|
||||
img_callback=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
use_original_steps=False,
|
||||
init_latent = None,
|
||||
mask = None,
|
||||
**kwargs
|
||||
):
|
||||
samples,_ = self.sample(
|
||||
batch_size = 1,
|
||||
S = t_enc,
|
||||
x_T = z_enc,
|
||||
shape = z_enc.shape[1:],
|
||||
conditioning = cond,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning = unconditional_conditioning,
|
||||
img_callback = img_callback,
|
||||
x0 = init_latent,
|
||||
mask = mask,
|
||||
**kwargs
|
||||
)
|
||||
return samples
|
||||
|
||||
# this is a no-op, provided here for compatibility with ddim and plms samplers
|
||||
@torch.no_grad()
|
||||
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||
return x0
|
||||
|
||||
# Most of these arguments are ignored and are only present for compatibility with
|
||||
# other samples
|
||||
@torch.no_grad()
|
||||
def sample(
|
||||
self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
attention_maps_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.0,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
|
||||
threshold = 0,
|
||||
perlin = 0,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs,
|
||||
):
|
||||
def route_callback(k_callback_values):
|
||||
if img_callback is not None:
|
||||
img_callback(k_callback_values['x'],k_callback_values['i'])
|
||||
|
||||
# if make_schedule() hasn't been called, we do it now
|
||||
if self.sigmas is None:
|
||||
self.make_schedule(
|
||||
ddim_num_steps=S,
|
||||
ddim_eta = eta,
|
||||
verbose = False,
|
||||
)
|
||||
|
||||
# sigmas are set up in make_schedule - we take the last steps items
|
||||
sigmas = self.sigmas[-S-1:]
|
||||
|
||||
# x_T is variation noise. When an init image is provided (in x0) we need to add
|
||||
# more randomness to the starting image.
|
||||
if x_T is not None:
|
||||
if x0 is not None:
|
||||
x = x_T + torch.randn_like(x0, device=self.device) * sigmas[0]
|
||||
else:
|
||||
x = x_T * sigmas[0]
|
||||
else:
|
||||
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
|
||||
|
||||
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
|
||||
model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
|
||||
|
||||
# setup attention maps saving. checks for None are because there are multiple code paths to get here.
|
||||
attention_map_saver = None
|
||||
if attention_maps_callback is not None and extra_conditioning_info is not None:
|
||||
eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
|
||||
attention_map_token_ids = range(1, eos_token_index)
|
||||
attention_map_saver = AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
|
||||
model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
|
||||
|
||||
extra_args = {
|
||||
'cond': conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': unconditional_guidance_scale,
|
||||
}
|
||||
print(f'>> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)')
|
||||
sampling_result = (
|
||||
K.sampling.__dict__[f'sample_{self.schedule}'](
|
||||
model_wrap_cfg, x, sigmas, extra_args=extra_args,
|
||||
callback=route_callback
|
||||
),
|
||||
None,
|
||||
)
|
||||
if attention_map_saver is not None:
|
||||
attention_maps_callback(attention_map_saver)
|
||||
return sampling_result
|
||||
|
||||
# this code will support inpainting if and when ksampler API modified or
|
||||
# a workaround is found.
|
||||
@torch.no_grad()
|
||||
def p_sample(
|
||||
self,
|
||||
img,
|
||||
cond,
|
||||
ts,
|
||||
index,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
extra_conditioning_info=None,
|
||||
**kwargs,
|
||||
):
|
||||
if self.model_wrap is None:
|
||||
self.model_wrap = CFGDenoiser(self.model)
|
||||
extra_args = {
|
||||
'cond': cond,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': unconditional_guidance_scale,
|
||||
}
|
||||
if self.s_in is None:
|
||||
self.s_in = img.new_ones([img.shape[0]])
|
||||
if self.ds is None:
|
||||
self.ds = []
|
||||
|
||||
# terrible, confusing names here
|
||||
steps = self.ddim_num_steps
|
||||
t_enc = self.t_enc
|
||||
|
||||
# sigmas is a full steps in length, but t_enc might
|
||||
# be less. We start in the middle of the sigma array
|
||||
# and work our way to the end after t_enc steps.
|
||||
# index starts at t_enc and works its way to zero,
|
||||
# so the actual formula for indexing into sigmas:
|
||||
# sigma_index = (steps-index)
|
||||
s_index = t_enc - index - 1
|
||||
self.model_wrap.prepare_to_sample(s_index, extra_conditioning_info=extra_conditioning_info)
|
||||
img = K.sampling.__dict__[f'_{self.schedule}'](
|
||||
self.model_wrap,
|
||||
img,
|
||||
self.sigmas,
|
||||
s_index,
|
||||
s_in = self.s_in,
|
||||
ds = self.ds,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
|
||||
return img, None, None
|
||||
|
||||
# REVIEW THIS METHOD: it has never been tested. In particular,
|
||||
# we should not be multiplying by self.sigmas[0] if we
|
||||
# are at an intermediate step in img2img. See similar in
|
||||
# sample() which does work.
|
||||
def get_initial_image(self,x_T,shape,steps):
|
||||
print(f'WARNING: ksampler.get_initial_image(): get_initial_image needs testing')
|
||||
x = (torch.randn(shape, device=self.device) * self.sigmas[0])
|
||||
if x_T is not None:
|
||||
return x_T + x
|
||||
else:
|
||||
return x
|
||||
|
||||
def prepare_to_sample(self,t_enc,**kwargs):
|
||||
self.t_enc = t_enc
|
||||
self.model_wrap = None
|
||||
self.ds = None
|
||||
self.s_in = None
|
||||
|
||||
def q_sample(self,x0,ts):
|
||||
'''
|
||||
Overrides parent method to return the q_sample of the inner model.
|
||||
'''
|
||||
return self.model.inner_model.q_sample(x0,ts)
|
||||
|
||||
def conditioning_key(self)->str:
|
||||
return self.model.inner_model.model.conditioning_key
|
||||
|
146
invokeai/models/diffusion/plms.py
Normal file
146
invokeai/models/diffusion/plms.py
Normal file
@ -0,0 +1,146 @@
|
||||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from .sampler import Sampler
|
||||
from ldm.modules.diffusionmodules.util import noise_like
|
||||
|
||||
|
||||
class PLMSSampler(Sampler):
|
||||
def __init__(self, model, schedule='linear', device=None, **kwargs):
|
||||
super().__init__(model,schedule,model.num_timesteps, device)
|
||||
|
||||
def prepare_to_sample(self, t_enc, **kwargs):
|
||||
super().prepare_to_sample(t_enc, **kwargs)
|
||||
|
||||
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
|
||||
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
|
||||
|
||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = all_timesteps_count)
|
||||
else:
|
||||
self.invokeai_diffuser.restore_default_cross_attention()
|
||||
|
||||
|
||||
# this is the essential routine
|
||||
@torch.no_grad()
|
||||
def p_sample(
|
||||
self,
|
||||
x, # image, called 'img' elsewhere
|
||||
c, # conditioning, called 'cond' elsewhere
|
||||
t, # timesteps, called 'ts' elsewhere
|
||||
index,
|
||||
repeat_noise=False,
|
||||
use_original_steps=False,
|
||||
quantize_denoised=False,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
old_eps=[],
|
||||
t_next=None,
|
||||
step_count:int=1000, # total number of steps
|
||||
**kwargs,
|
||||
):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
def get_model_output(x, t):
|
||||
if (
|
||||
unconditional_conditioning is None
|
||||
or unconditional_guidance_scale == 1.0
|
||||
):
|
||||
# damian0815 would like to know when/if this code path is used
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
# step_index counts in the opposite direction to index
|
||||
step_index = step_count-(index+1)
|
||||
e_t = self.invokeai_diffuser.do_diffusion_step(x, t,
|
||||
unconditional_conditioning, c,
|
||||
unconditional_guidance_scale,
|
||||
step_index=step_index)
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == 'eps'
|
||||
e_t = score_corrector.modify_score(
|
||||
self.model, e_t, x, t, c, **corrector_kwargs
|
||||
)
|
||||
|
||||
return e_t
|
||||
|
||||
alphas = (
|
||||
self.model.alphas_cumprod
|
||||
if use_original_steps
|
||||
else self.ddim_alphas
|
||||
)
|
||||
alphas_prev = (
|
||||
self.model.alphas_cumprod_prev
|
||||
if use_original_steps
|
||||
else self.ddim_alphas_prev
|
||||
)
|
||||
sqrt_one_minus_alphas = (
|
||||
self.model.sqrt_one_minus_alphas_cumprod
|
||||
if use_original_steps
|
||||
else self.ddim_sqrt_one_minus_alphas
|
||||
)
|
||||
sigmas = (
|
||||
self.model.ddim_sigmas_for_original_num_steps
|
||||
if use_original_steps
|
||||
else self.ddim_sigmas
|
||||
)
|
||||
|
||||
def get_x_prev_and_pred_x0(e_t, index):
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full(
|
||||
(b, 1, 1, 1), alphas_prev[index], device=device
|
||||
)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full(
|
||||
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
||||
)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = (
|
||||
sigma_t
|
||||
* noise_like(x.shape, device, repeat_noise)
|
||||
* temperature
|
||||
)
|
||||
if noise_dropout > 0.0:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
e_t = get_model_output(x, t)
|
||||
if len(old_eps) == 0:
|
||||
# Pseudo Improved Euler (2nd order)
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||
e_t_next = get_model_output(x_prev, t_next)
|
||||
e_t_prime = (e_t + e_t_next) / 2
|
||||
elif len(old_eps) == 1:
|
||||
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||
elif len(old_eps) == 2:
|
||||
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||
elif len(old_eps) >= 3:
|
||||
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (
|
||||
55 * e_t
|
||||
- 59 * old_eps[-1]
|
||||
+ 37 * old_eps[-2]
|
||||
- 9 * old_eps[-3]
|
||||
) / 24
|
||||
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||
|
||||
return x_prev, pred_x0, e_t
|
450
invokeai/models/diffusion/sampler.py
Normal file
450
invokeai/models/diffusion/sampler.py
Normal file
@ -0,0 +1,450 @@
|
||||
'''
|
||||
invokeai.models.diffusion.sampler
|
||||
|
||||
Base class for invokeai.models.diffusion.ddim, invokeai.models.diffusion.ksampler, etc
|
||||
'''
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
|
||||
from ldm.modules.diffusionmodules.util import (
|
||||
make_ddim_sampling_parameters,
|
||||
make_ddim_timesteps,
|
||||
noise_like,
|
||||
extract_into_tensor,
|
||||
)
|
||||
|
||||
class Sampler(object):
|
||||
def __init__(self, model, schedule='linear', steps=None, device=None, **kwargs):
|
||||
self.model = model
|
||||
self.ddim_timesteps = None
|
||||
self.ddpm_num_timesteps = steps
|
||||
self.schedule = schedule
|
||||
self.device = device or choose_torch_device()
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
|
||||
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device(self.device):
|
||||
attr = attr.to(torch.float32).to(torch.device(self.device))
|
||||
setattr(self, name, attr)
|
||||
|
||||
# This method was copied over from ddim.py and probably does stuff that is
|
||||
# ddim-specific. Disentangle at some point.
|
||||
def make_schedule(
|
||||
self,
|
||||
ddim_num_steps,
|
||||
ddim_discretize='uniform',
|
||||
ddim_eta=0.0,
|
||||
verbose=False,
|
||||
):
|
||||
self.total_steps = ddim_num_steps
|
||||
self.ddim_timesteps = make_ddim_timesteps(
|
||||
ddim_discr_method=ddim_discretize,
|
||||
num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
||||
verbose=verbose,
|
||||
)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert (
|
||||
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
||||
), 'alphas have to be defined for each timestep'
|
||||
to_torch = (
|
||||
lambda x: x.clone()
|
||||
.detach()
|
||||
.to(torch.float32)
|
||||
.to(self.model.device)
|
||||
)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer(
|
||||
'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)
|
||||
)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer(
|
||||
'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))
|
||||
)
|
||||
self.register_buffer(
|
||||
'sqrt_one_minus_alphas_cumprod',
|
||||
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
|
||||
)
|
||||
self.register_buffer(
|
||||
'log_one_minus_alphas_cumprod',
|
||||
to_torch(np.log(1.0 - alphas_cumprod.cpu())),
|
||||
)
|
||||
self.register_buffer(
|
||||
'sqrt_recip_alphas_cumprod',
|
||||
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
|
||||
)
|
||||
self.register_buffer(
|
||||
'sqrt_recipm1_alphas_cumprod',
|
||||
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
|
||||
)
|
||||
|
||||
# ddim sampling parameters
|
||||
(
|
||||
ddim_sigmas,
|
||||
ddim_alphas,
|
||||
ddim_alphas_prev,
|
||||
) = make_ddim_sampling_parameters(
|
||||
alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,
|
||||
verbose=verbose,
|
||||
)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer(
|
||||
'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
|
||||
)
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev)
|
||||
/ (1 - self.alphas_cumprod)
|
||||
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
||||
)
|
||||
self.register_buffer(
|
||||
'ddim_sigmas_for_original_num_steps',
|
||||
sigmas_for_original_sampling_steps,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||
# fast, but does not allow for exact reconstruction
|
||||
# t serves as an index to gather the correct alphas
|
||||
if use_original_steps:
|
||||
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
||||
else:
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
||||
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
return (
|
||||
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
|
||||
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
|
||||
* noise
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(
|
||||
self,
|
||||
S, # S is steps
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None, # TODO: this is very confusing because it is called "step_callback" elsewhere. Change.
|
||||
quantize_x0=False,
|
||||
eta=0.0,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=False,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list):
|
||||
ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
# check to see if make_schedule() has run, and if not, run it
|
||||
if self.ddim_timesteps is None:
|
||||
self.make_schedule(
|
||||
ddim_num_steps=S,
|
||||
ddim_eta = eta,
|
||||
verbose = False,
|
||||
)
|
||||
|
||||
ts = self.get_timesteps(S)
|
||||
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
shape = (batch_size, C, H, W)
|
||||
samples, intermediates = self.do_sampling(
|
||||
conditioning,
|
||||
shape,
|
||||
timesteps=ts,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
steps=S,
|
||||
**kwargs
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def do_sampling(
|
||||
self,
|
||||
cond,
|
||||
shape,
|
||||
timesteps=None,
|
||||
x_T=None,
|
||||
ddim_use_original_steps=False,
|
||||
callback=None,
|
||||
quantize_denoised=False,
|
||||
mask=None,
|
||||
x0=None,
|
||||
img_callback=None,
|
||||
log_every_t=100,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
steps=None,
|
||||
**kwargs
|
||||
):
|
||||
b = shape[0]
|
||||
time_range = (
|
||||
list(reversed(range(0, timesteps)))
|
||||
if ddim_use_original_steps
|
||||
else np.flip(timesteps)
|
||||
)
|
||||
|
||||
total_steps=steps
|
||||
|
||||
iterator = tqdm(
|
||||
time_range,
|
||||
desc=f'{self.__class__.__name__}',
|
||||
total=total_steps,
|
||||
dynamic_ncols=True,
|
||||
)
|
||||
old_eps = []
|
||||
self.prepare_to_sample(t_enc=total_steps,all_timesteps_count=steps,**kwargs)
|
||||
img = self.get_initial_image(x_T,shape,total_steps)
|
||||
|
||||
# probably don't need this at all
|
||||
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full(
|
||||
(b,),
|
||||
step,
|
||||
device=self.device,
|
||||
dtype=torch.long
|
||||
)
|
||||
ts_next = torch.full(
|
||||
(b,),
|
||||
time_range[min(i + 1, len(time_range) - 1)],
|
||||
device=self.device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(
|
||||
x0, ts
|
||||
) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1.0 - mask) * img
|
||||
|
||||
outs = self.p_sample(
|
||||
img,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
old_eps=old_eps,
|
||||
t_next=ts_next,
|
||||
step_count=steps
|
||||
)
|
||||
img, pred_x0, e_t = outs
|
||||
|
||||
old_eps.append(e_t)
|
||||
if len(old_eps) >= 4:
|
||||
old_eps.pop(0)
|
||||
if callback:
|
||||
callback(i)
|
||||
if img_callback:
|
||||
img_callback(img,i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
|
||||
return img, intermediates
|
||||
|
||||
# NOTE that decode() and sample() are almost the same code, and do the same thing.
|
||||
# The variable names are changed in order to be confusing.
|
||||
@torch.no_grad()
|
||||
def decode(
|
||||
self,
|
||||
x_latent,
|
||||
cond,
|
||||
t_start,
|
||||
img_callback=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
use_original_steps=False,
|
||||
init_latent = None,
|
||||
mask = None,
|
||||
all_timesteps_count = None,
|
||||
**kwargs
|
||||
):
|
||||
timesteps = (
|
||||
np.arange(self.ddpm_num_timesteps)
|
||||
if use_original_steps
|
||||
else self.ddim_timesteps
|
||||
)
|
||||
timesteps = timesteps[:t_start]
|
||||
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
print(f'>> Running {self.__class__.__name__} sampling starting at step {self.total_steps - t_start} of {self.total_steps} ({total_steps} new sampling steps)')
|
||||
|
||||
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||
x_dec = x_latent
|
||||
x0 = init_latent
|
||||
self.prepare_to_sample(t_enc=total_steps, all_timesteps_count=all_timesteps_count, **kwargs)
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full(
|
||||
(x_latent.shape[0],),
|
||||
step,
|
||||
device=x_latent.device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
ts_next = torch.full(
|
||||
(x_latent.shape[0],),
|
||||
time_range[min(i + 1, len(time_range) - 1)],
|
||||
device=self.device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
xdec_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||
x_dec = xdec_orig * mask + (1.0 - mask) * x_dec
|
||||
|
||||
outs = self.p_sample(
|
||||
x_dec,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=use_original_steps,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
t_next = ts_next,
|
||||
step_count=len(self.ddim_timesteps)
|
||||
)
|
||||
|
||||
x_dec, pred_x0, e_t = outs
|
||||
if img_callback:
|
||||
img_callback(x_dec,i)
|
||||
|
||||
return x_dec
|
||||
|
||||
def get_initial_image(self,x_T,shape,timesteps=None):
|
||||
if x_T is None:
|
||||
return torch.randn(shape, device=self.device)
|
||||
else:
|
||||
return x_T
|
||||
|
||||
def p_sample(
|
||||
self,
|
||||
img,
|
||||
cond,
|
||||
ts,
|
||||
index,
|
||||
repeat_noise=False,
|
||||
use_original_steps=False,
|
||||
quantize_denoised=False,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
old_eps=None,
|
||||
t_next=None,
|
||||
steps=None,
|
||||
):
|
||||
raise NotImplementedError("p_sample() must be implemented in a descendent class")
|
||||
|
||||
def prepare_to_sample(self,t_enc,**kwargs):
|
||||
'''
|
||||
Hook that will be called right before the very first invocation of p_sample()
|
||||
to allow subclass to do additional initialization. t_enc corresponds to the actual
|
||||
number of steps that will be run, and may be less than total steps if img2img is
|
||||
active.
|
||||
'''
|
||||
pass
|
||||
|
||||
def get_timesteps(self,ddim_steps):
|
||||
'''
|
||||
The ddim and plms samplers work on timesteps. This method is called after
|
||||
ddim_timesteps are created in make_schedule(), and selects the portion of
|
||||
timesteps that will be used for sampling, depending on the t_enc in img2img.
|
||||
'''
|
||||
return self.ddim_timesteps[:ddim_steps]
|
||||
|
||||
def q_sample(self,x0,ts):
|
||||
'''
|
||||
Returns self.model.q_sample(x0,ts). Is overridden in the k* samplers to
|
||||
return self.model.inner_model.q_sample(x0,ts)
|
||||
'''
|
||||
return self.model.q_sample(x0,ts)
|
||||
|
||||
def conditioning_key(self)->str:
|
||||
return self.model.model.conditioning_key
|
||||
|
||||
def uses_inpainting_model(self)->bool:
|
||||
return self.conditioning_key() in ('hybrid','concat')
|
||||
|
||||
def adjust_settings(self,**kwargs):
|
||||
'''
|
||||
This is a catch-all method for adjusting any instance variables
|
||||
after the sampler is instantiated. No type-checking performed
|
||||
here, so use with care!
|
||||
'''
|
||||
for k in kwargs.keys():
|
||||
try:
|
||||
setattr(self,k,kwargs[k])
|
||||
except AttributeError:
|
||||
print(f'** Warning: attempt to set unknown attribute {k} in sampler of type {type(self)}')
|
491
invokeai/models/diffusion/shared_invokeai_diffusion.py
Normal file
491
invokeai/models/diffusion/shared_invokeai_diffusion.py
Normal file
@ -0,0 +1,491 @@
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from math import ceil
|
||||
from typing import Callable, Optional, Union, Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.models.cross_attention import AttnProcessor
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from ldm.invoke.globals import Globals
|
||||
from .cross_attention_control import Arguments, \
|
||||
restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
|
||||
CrossAttentionType, SwapCrossAttnContext
|
||||
from .cross_attention_map_saving import AttentionMapSaver
|
||||
|
||||
ModelForwardCallback: TypeAlias = Union[
|
||||
# x, t, conditioning, Optional[cross-attention kwargs]
|
||||
Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]], torch.Tensor],
|
||||
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
|
||||
]
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PostprocessingSettings:
|
||||
threshold: float
|
||||
warmup: float
|
||||
h_symmetry_time_pct: Optional[float]
|
||||
v_symmetry_time_pct: Optional[float]
|
||||
|
||||
|
||||
class InvokeAIDiffuserComponent:
|
||||
'''
|
||||
The aim of this component is to provide a single place for code that can be applied identically to
|
||||
all InvokeAI diffusion procedures.
|
||||
|
||||
At the moment it includes the following features:
|
||||
* Cross attention control ("prompt2prompt")
|
||||
* Hybrid conditioning (used for inpainting)
|
||||
'''
|
||||
debug_thresholding = False
|
||||
sequential_guidance = False
|
||||
|
||||
@dataclass
|
||||
class ExtraConditioningInfo:
|
||||
|
||||
tokens_count_including_eos_bos: int
|
||||
cross_attention_control_args: Optional[Arguments] = None
|
||||
|
||||
@property
|
||||
def wants_cross_attention_control(self):
|
||||
return self.cross_attention_control_args is not None
|
||||
|
||||
|
||||
def __init__(self, model, model_forward_callback: ModelForwardCallback,
|
||||
is_running_diffusers: bool=False,
|
||||
):
|
||||
"""
|
||||
:param model: the unet model to pass through to cross attention control
|
||||
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
|
||||
"""
|
||||
self.conditioning = None
|
||||
self.model = model
|
||||
self.is_running_diffusers = is_running_diffusers
|
||||
self.model_forward_callback = model_forward_callback
|
||||
self.cross_attention_control_context = None
|
||||
self.sequential_guidance = Globals.sequential_guidance
|
||||
|
||||
@contextmanager
|
||||
def custom_attention_context(self,
|
||||
extra_conditioning_info: Optional[ExtraConditioningInfo],
|
||||
step_count: int):
|
||||
do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
|
||||
old_attn_processor = None
|
||||
if do_swap:
|
||||
old_attn_processor = self.override_cross_attention(extra_conditioning_info,
|
||||
step_count=step_count)
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if old_attn_processor is not None:
|
||||
self.restore_default_cross_attention(old_attn_processor)
|
||||
# TODO resuscitate attention map saving
|
||||
#self.remove_attention_map_saving()
|
||||
|
||||
def override_cross_attention(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
|
||||
"""
|
||||
setup cross attention .swap control. for diffusers this replaces the attention processor, so
|
||||
the previous attention processor is returned so that the caller can restore it later.
|
||||
"""
|
||||
self.conditioning = conditioning
|
||||
self.cross_attention_control_context = Context(
|
||||
arguments=self.conditioning.cross_attention_control_args,
|
||||
step_count=step_count
|
||||
)
|
||||
return override_cross_attention(self.model,
|
||||
self.cross_attention_control_context,
|
||||
is_running_diffusers=self.is_running_diffusers)
|
||||
|
||||
def restore_default_cross_attention(self, restore_attention_processor: Optional['AttnProcessor']=None):
|
||||
self.conditioning = None
|
||||
self.cross_attention_control_context = None
|
||||
restore_default_cross_attention(self.model,
|
||||
is_running_diffusers=self.is_running_diffusers,
|
||||
restore_attention_processor=restore_attention_processor)
|
||||
|
||||
def setup_attention_map_saving(self, saver: AttentionMapSaver):
|
||||
def callback(slice, dim, offset, slice_size, key):
|
||||
if dim is not None:
|
||||
# sliced tokens attention map saving is not implemented
|
||||
return
|
||||
saver.add_attention_maps(slice, key)
|
||||
|
||||
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
|
||||
for identifier, module in tokens_cross_attention_modules:
|
||||
key = ('down' if identifier.startswith('down') else
|
||||
'up' if identifier.startswith('up') else
|
||||
'mid')
|
||||
module.set_attention_slice_calculated_callback(
|
||||
lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key))
|
||||
|
||||
def remove_attention_map_saving(self):
|
||||
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
|
||||
for _, module in tokens_cross_attention_modules:
|
||||
module.set_attention_slice_calculated_callback(None)
|
||||
|
||||
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
|
||||
unconditioning: Union[torch.Tensor,dict],
|
||||
conditioning: Union[torch.Tensor,dict],
|
||||
unconditional_guidance_scale: float,
|
||||
step_index: Optional[int]=None,
|
||||
total_step_count: Optional[int]=None,
|
||||
):
|
||||
"""
|
||||
:param x: current latents
|
||||
:param sigma: aka t, passed to the internal model to control how much denoising will occur
|
||||
:param unconditioning: embeddings for unconditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
|
||||
:param conditioning: embeddings for conditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
|
||||
:param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
|
||||
:param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotically increase. If None, will be estimated by comparing sigma against self.model.sigmas .
|
||||
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
|
||||
"""
|
||||
|
||||
|
||||
cross_attention_control_types_to_do = []
|
||||
context: Context = self.cross_attention_control_context
|
||||
if self.cross_attention_control_context is not None:
|
||||
percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
|
||||
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
|
||||
|
||||
wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
|
||||
wants_hybrid_conditioning = isinstance(conditioning, dict)
|
||||
|
||||
if wants_hybrid_conditioning:
|
||||
unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(x, sigma, unconditioning,
|
||||
conditioning)
|
||||
elif wants_cross_attention_control:
|
||||
unconditioned_next_x, conditioned_next_x = self._apply_cross_attention_controlled_conditioning(x, sigma,
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do)
|
||||
elif self.sequential_guidance:
|
||||
unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning_sequentially(
|
||||
x, sigma, unconditioning, conditioning)
|
||||
|
||||
else:
|
||||
unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning(
|
||||
x, sigma, unconditioning, conditioning)
|
||||
|
||||
combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale)
|
||||
|
||||
return combined_next_x
|
||||
|
||||
def do_latent_postprocessing(
|
||||
self,
|
||||
postprocessing_settings: PostprocessingSettings,
|
||||
latents: torch.Tensor,
|
||||
sigma,
|
||||
step_index,
|
||||
total_step_count
|
||||
) -> torch.Tensor:
|
||||
if postprocessing_settings is not None:
|
||||
percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
|
||||
latents = self.apply_threshold(postprocessing_settings, latents, percent_through)
|
||||
latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
|
||||
return latents
|
||||
|
||||
def calculate_percent_through(self, sigma, step_index, total_step_count):
|
||||
if step_index is not None and total_step_count is not None:
|
||||
# 🧨diffusers codepath
|
||||
percent_through = step_index / total_step_count # will never reach 1.0 - this is deliberate
|
||||
else:
|
||||
# legacy compvis codepath
|
||||
# TODO remove when compvis codepath support is dropped
|
||||
if step_index is None and sigma is None:
|
||||
raise ValueError(
|
||||
f"Either step_index or sigma is required when doing cross attention control, but both are None.")
|
||||
percent_through = self.estimate_percent_through(step_index, sigma)
|
||||
return percent_through
|
||||
|
||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||
|
||||
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
|
||||
# fast batched path
|
||||
x_twice = torch.cat([x] * 2)
|
||||
sigma_twice = torch.cat([sigma] * 2)
|
||||
both_conditionings = torch.cat([unconditioning, conditioning])
|
||||
both_results = self.model_forward_callback(x_twice, sigma_twice, both_conditionings)
|
||||
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
||||
if conditioned_next_x.device.type == 'mps':
|
||||
# prevent a result filled with zeros. seems to be a torch bug.
|
||||
conditioned_next_x = conditioned_next_x.clone()
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
|
||||
def _apply_standard_conditioning_sequentially(self, x: torch.Tensor, sigma, unconditioning: torch.Tensor, conditioning: torch.Tensor):
|
||||
# low-memory sequential path
|
||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
|
||||
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning)
|
||||
if conditioned_next_x.device.type == 'mps':
|
||||
# prevent a result filled with zeros. seems to be a torch bug.
|
||||
conditioned_next_x = conditioned_next_x.clone()
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
|
||||
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
|
||||
assert isinstance(conditioning, dict)
|
||||
assert isinstance(unconditioning, dict)
|
||||
x_twice = torch.cat([x] * 2)
|
||||
sigma_twice = torch.cat([sigma] * 2)
|
||||
both_conditionings = dict()
|
||||
for k in conditioning:
|
||||
if isinstance(conditioning[k], list):
|
||||
both_conditionings[k] = [
|
||||
torch.cat([unconditioning[k][i], conditioning[k][i]])
|
||||
for i in range(len(conditioning[k]))
|
||||
]
|
||||
else:
|
||||
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
|
||||
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
|
||||
def _apply_cross_attention_controlled_conditioning(self,
|
||||
x: torch.Tensor,
|
||||
sigma,
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do):
|
||||
if self.is_running_diffusers:
|
||||
return self._apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do)
|
||||
else:
|
||||
return self._apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning,
|
||||
cross_attention_control_types_to_do)
|
||||
|
||||
def _apply_cross_attention_controlled_conditioning__diffusers(self,
|
||||
x: torch.Tensor,
|
||||
sigma,
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do):
|
||||
context: Context = self.cross_attention_control_context
|
||||
|
||||
cross_attn_processor_context = SwapCrossAttnContext(modified_text_embeddings=context.arguments.edited_conditioning,
|
||||
index_map=context.cross_attention_index_map,
|
||||
mask=context.cross_attention_mask,
|
||||
cross_attention_types_to_do=[])
|
||||
# no cross attention for unconditioning (negative prompt)
|
||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning,
|
||||
{"swap_cross_attn_context": cross_attn_processor_context})
|
||||
|
||||
# do requested cross attention types for conditioning (positive prompt)
|
||||
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
|
||||
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning,
|
||||
{"swap_cross_attn_context": cross_attn_processor_context})
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
|
||||
def _apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
|
||||
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
|
||||
# slower non-batched path (20% slower on mac MPS)
|
||||
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
|
||||
# unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
|
||||
# This messes app their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8)
|
||||
# (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16,
|
||||
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
|
||||
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
|
||||
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
|
||||
context:Context = self.cross_attention_control_context
|
||||
|
||||
try:
|
||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
|
||||
|
||||
# process x using the original prompt, saving the attention maps
|
||||
#print("saving attention maps for", cross_attention_control_types_to_do)
|
||||
for ca_type in cross_attention_control_types_to_do:
|
||||
context.request_save_attention_maps(ca_type)
|
||||
_ = self.model_forward_callback(x, sigma, conditioning)
|
||||
context.clear_requests(cleanup=False)
|
||||
|
||||
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
|
||||
#print("applying saved attention maps for", cross_attention_control_types_to_do)
|
||||
for ca_type in cross_attention_control_types_to_do:
|
||||
context.request_apply_saved_attention_maps(ca_type)
|
||||
edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
|
||||
conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
|
||||
context.clear_requests(cleanup=True)
|
||||
|
||||
except:
|
||||
context.clear_requests(cleanup=True)
|
||||
raise
|
||||
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
def _combine(self, unconditioned_next_x, conditioned_next_x, guidance_scale):
|
||||
# to scale how much effect conditioning has, calculate the changes it does and then scale that
|
||||
scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
|
||||
combined_next_x = unconditioned_next_x + scaled_delta
|
||||
return combined_next_x
|
||||
|
||||
def apply_threshold(
|
||||
self,
|
||||
postprocessing_settings: PostprocessingSettings,
|
||||
latents: torch.Tensor,
|
||||
percent_through: float
|
||||
) -> torch.Tensor:
|
||||
|
||||
if postprocessing_settings.threshold is None or postprocessing_settings.threshold == 0.0:
|
||||
return latents
|
||||
|
||||
threshold = postprocessing_settings.threshold
|
||||
warmup = postprocessing_settings.warmup
|
||||
|
||||
if percent_through < warmup:
|
||||
current_threshold = threshold + threshold * 5 * (1 - (percent_through / warmup))
|
||||
else:
|
||||
current_threshold = threshold
|
||||
|
||||
if current_threshold <= 0:
|
||||
return latents
|
||||
|
||||
maxval = latents.max().item()
|
||||
minval = latents.min().item()
|
||||
|
||||
scale = 0.7 # default value from #395
|
||||
|
||||
if self.debug_thresholding:
|
||||
std, mean = [i.item() for i in torch.std_mean(latents)]
|
||||
outside = torch.count_nonzero((latents < -current_threshold) | (latents > current_threshold))
|
||||
print(f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
|
||||
f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
|
||||
f" | {outside / latents.numel() * 100:.2f}% values outside threshold")
|
||||
|
||||
if maxval < current_threshold and minval > -current_threshold:
|
||||
return latents
|
||||
|
||||
num_altered = 0
|
||||
|
||||
# MPS torch.rand_like is fine because torch.rand_like is wrapped in generate.py!
|
||||
|
||||
if maxval > current_threshold:
|
||||
latents = torch.clone(latents)
|
||||
maxval = np.clip(maxval * scale, 1, current_threshold)
|
||||
num_altered += torch.count_nonzero(latents > maxval)
|
||||
latents[latents > maxval] = torch.rand_like(latents[latents > maxval]) * maxval
|
||||
|
||||
if minval < -current_threshold:
|
||||
latents = torch.clone(latents)
|
||||
minval = np.clip(minval * scale, -current_threshold, -1)
|
||||
num_altered += torch.count_nonzero(latents < minval)
|
||||
latents[latents < minval] = torch.rand_like(latents[latents < minval]) * minval
|
||||
|
||||
if self.debug_thresholding:
|
||||
print(f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
|
||||
f" | {num_altered / latents.numel() * 100:.2f}% values altered")
|
||||
|
||||
return latents
|
||||
|
||||
def apply_symmetry(
|
||||
self,
|
||||
postprocessing_settings: PostprocessingSettings,
|
||||
latents: torch.Tensor,
|
||||
percent_through: float
|
||||
) -> torch.Tensor:
|
||||
|
||||
# Reset our last percent through if this is our first step.
|
||||
if percent_through == 0.0:
|
||||
self.last_percent_through = 0.0
|
||||
|
||||
if postprocessing_settings is None:
|
||||
return latents
|
||||
|
||||
# Check for out of bounds
|
||||
h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
|
||||
if (h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0)):
|
||||
h_symmetry_time_pct = None
|
||||
|
||||
v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
|
||||
if (v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0)):
|
||||
v_symmetry_time_pct = None
|
||||
|
||||
dev = latents.device.type
|
||||
|
||||
latents.to(device='cpu')
|
||||
|
||||
if (
|
||||
h_symmetry_time_pct != None and
|
||||
self.last_percent_through < h_symmetry_time_pct and
|
||||
percent_through >= h_symmetry_time_pct
|
||||
):
|
||||
# Horizontal symmetry occurs on the 3rd dimension of the latent
|
||||
width = latents.shape[3]
|
||||
x_flipped = torch.flip(latents, dims=[3])
|
||||
latents = torch.cat([latents[:, :, :, 0:int(width/2)], x_flipped[:, :, :, int(width/2):int(width)]], dim=3)
|
||||
|
||||
if (
|
||||
v_symmetry_time_pct != None and
|
||||
self.last_percent_through < v_symmetry_time_pct and
|
||||
percent_through >= v_symmetry_time_pct
|
||||
):
|
||||
# Vertical symmetry occurs on the 2nd dimension of the latent
|
||||
height = latents.shape[2]
|
||||
y_flipped = torch.flip(latents, dims=[2])
|
||||
latents = torch.cat([latents[:, :, 0:int(height / 2)], y_flipped[:, :, int(height / 2):int(height)]], dim=2)
|
||||
|
||||
self.last_percent_through = percent_through
|
||||
return latents.to(device=dev)
|
||||
|
||||
def estimate_percent_through(self, step_index, sigma):
|
||||
if step_index is not None and self.cross_attention_control_context is not None:
|
||||
# percent_through will never reach 1.0 (but this is intended)
|
||||
return float(step_index) / float(self.cross_attention_control_context.step_count)
|
||||
# find the best possible index of the current sigma in the sigma sequence
|
||||
smaller_sigmas = torch.nonzero(self.model.sigmas <= sigma)
|
||||
sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0
|
||||
# flip because sigmas[0] is for the fully denoised image
|
||||
# percent_through must be <1
|
||||
return 1.0 - float(sigma_index + 1) / float(self.model.sigmas.shape[0])
|
||||
# print('estimated percent_through', percent_through, 'from sigma', sigma.item())
|
||||
|
||||
|
||||
# todo: make this work
|
||||
@classmethod
|
||||
def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2) # aka sigmas
|
||||
|
||||
deltas = None
|
||||
uncond_latents = None
|
||||
weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
|
||||
|
||||
# below is fugly omg
|
||||
num_actual_conditionings = len(c_or_weighted_c_list)
|
||||
conditionings = [uc] + [c for c,weight in weighted_cond_list]
|
||||
weights = [1] + [weight for c,weight in weighted_cond_list]
|
||||
chunk_count = ceil(len(conditionings)/2)
|
||||
deltas = None
|
||||
for chunk_index in range(chunk_count):
|
||||
offset = chunk_index*2
|
||||
chunk_size = min(2, len(conditionings)-offset)
|
||||
|
||||
if chunk_size == 1:
|
||||
c_in = conditionings[offset]
|
||||
latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
|
||||
latents_b = None
|
||||
else:
|
||||
c_in = torch.cat(conditionings[offset:offset+2])
|
||||
latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)
|
||||
|
||||
# first chunk is guaranteed to be 2 entries: uncond_latents + first conditioining
|
||||
if chunk_index == 0:
|
||||
uncond_latents = latents_a
|
||||
deltas = latents_b - uncond_latents
|
||||
else:
|
||||
deltas = torch.cat((deltas, latents_a - uncond_latents))
|
||||
if latents_b is not None:
|
||||
deltas = torch.cat((deltas, latents_b - uncond_latents))
|
||||
|
||||
# merge the weighted deltas together into a single merged delta
|
||||
per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
|
||||
normalize = False
|
||||
if normalize:
|
||||
per_delta_weights /= torch.sum(per_delta_weights)
|
||||
reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
|
||||
deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
|
||||
|
||||
# old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
|
||||
# assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
|
||||
|
||||
return uncond_latents + deltas_merged * global_guidance_scale
|
1221
invokeai/models/model_manager.py
Normal file
1221
invokeai/models/model_manager.py
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user