InvokeAI/main.py
Lincoln Stein 1e1f871ee1
Embedding merging (#1526)
* add whole <style token> to vocab for concept library embeddings

* add ability to load multiple concept .bin files

* make --log_tokenization respect custom tokens

* start working on concept downloading system

* preliminary support for dynamically loading and merging multiple embedding models

- The embedding_manager is now enhanced with ldm.invoke.concepts_lib,
  which handles dynamic downloading and caching of embedding models from
  the Hugging Face concepts library (https://huggingface.co/sd-concepts-library)

- Downloading of an embedding model is triggered by the presence of one or
  more <concept> tags in the prompt (see the sketch at the end of this list).

- Once the embedding model is downloaded, its trigger phrase is loaded
  into the embedding manager and the prompt's <concept> tag is replaced
  with the model's <trigger_phrase>

- The downloaded model stays on disk for fast loading later.

- The CLI autocomplete will complete partial <concept> tags for you. Type a
  '<' and hit tab to get all ~700 concepts.
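
- For illustration, the tag-replacement step boils down to something like
  the sketch below. This is not the actual implementation; the
  get_concept_trigger helper is hypothetical and stands in for
  concepts_lib's download-and-cache logic:

      import re

      CONCEPT_TAG = re.compile(r'<([\w\-]+)>')

      def substitute_concepts(prompt, get_concept_trigger):
          # get_concept_trigger(name) downloads/caches the embedding for
          # `name` and returns its trigger phrase, or None if not found
          def replace(match):
              trigger = get_concept_trigger(match.group(1))
              return trigger if trigger is not None else match.group(0)
          return CONCEPT_TAG.sub(replace, prompt)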

BUGS AND LIMITATIONS:

- MODEL NAME VS TRIGGER PHRASE

  You must use the name of the concept embedding model as it appears in
  the sd-concepts-library, not the trigger phrase itself. Usually these
  are the same, but not always. For example, the model named
  "hoi4-leaders" corresponds to the trigger "<HOI4-Leader>"

  One reason for this design choice is that there is no apparent
  constraint on the uniqueness of trigger phrases: a single trigger
  phrase may map onto multiple models. So we use the model name
  instead.

  The second reason is that there is no way I know of to search Hugging
  Face for models by trigger phrase. So we'd have to download all ~700
  models just to index the phrases.

  The problem this presents is that it may confuse users who want to
  reuse prompts from distributions that use the trigger phrase directly.
  Usually this will work, but not always.

- WON'T WORK ON A FIREWALLED SYSTEM

  If the host running InvokeAI has no internet connection, it can't
  download the concept models. I will add a script that allows users to
  preload a list of concept models; a sketch of what that might look
  like follows.
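
  A minimal sketch of such a preload script, assuming the huggingface_hub
  client and that each concept repo stores its weights as
  learned_embeds.bin (the cache location here is hypothetical):

      from huggingface_hub import hf_hub_download

      def preload_concepts(names, cache_dir='concepts-cache'):
          # fetch and cache each concept's embedding ahead of time
          for name in names:
              try:
                  path = hf_hub_download(
                      repo_id=f'sd-concepts-library/{name}',
                      filename='learned_embeds.bin',
                      cache_dir=cache_dir,
                  )
                  print(f'cached {name} -> {path}')
              except Exception as err:
                  print(f'could not fetch {name}: {err}')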

- BUG IN PROMPT REPLACEMENT WHEN MODEL NOT FOUND

  There's a small bug that occurs when the user provides an invalid
  model name: the <concept> tag gets replaced with <None> in the prompt.

* fix loading .pt embeddings; allow multi-vector embeddings; warn on dupes
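
  For context, a rough sketch of the two on-disk formats this touches
  (the key names follow common textual-inversion conventions and should
  be treated as assumptions, not a spec):

      import torch

      def extract_embeddings(path):
          # returns {trigger: tensor}; a multi-vector embedding is just a
          # tensor with more than one row
          data = torch.load(path, map_location='cpu')
          if 'string_to_param' in data:   # original textual-inversion .pt
              return dict(data['string_to_param'])
          # HF concept-library .bin: a plain {token: tensor} mapping
          return {k: v for k, v in data.items() if isinstance(v, torch.Tensor)}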

* simplify replacement logic and remove cuda assumption

* download list of concepts from hugging face

* remove misleading customization of '*' placeholder

The existing code as-is did not do anything, and it is unclear what it was
supposed to do.

The obvious alternative -- setting 'placeholder_strings' instead of
'placeholder_tokens' to match model.params.personalization_config.params.placeholder_strings --
caused a crash. I think this is because the passed string also needed to be
handed over on init of the PersonalizedBase as the 'placeholder_token'
argument. This is weird config-dict magic and I don't want to touch it. Put
a breakpoint in personalized.py line 116 (top of PersonalizedBase.__init__)
if you want to have a crack at it yourself.

* address all the issues raised by damian0815 in review of PR #1526

* actually resize the token_embeddings
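
  A hedged sketch of what resizing involves when a new trigger token is
  added, using the transformers tokenizer/text-encoder API (the argument
  names are assumptions about the surrounding code):

      import torch

      def add_trigger_token(tokenizer, text_encoder, trigger, embedding):
          if tokenizer.add_tokens(trigger) == 0:
              print(f'warning: {trigger!r} already in vocab (duplicate)')
          # grow the input embedding matrix to the new vocab size
          text_encoder.resize_token_embeddings(len(tokenizer))
          token_id = tokenizer.convert_tokens_to_ids(trigger)
          with torch.no_grad():
              text_encoder.get_input_embeddings().weight[token_id] = embedding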

* multiple improvements to the concept loader based on code reviews

1. Activated the --embedding_directory option (alias --embedding_path)
   to load a single embedding or an entire directory of embeddings at
   startup time (see the sketch after this list).

2. Can turn off automatic loading of embeddings using --no-embeddings.

3. Embedding checkpoints are scanned with the pickle scanner.

4. More informative error messages when a concept can't be loaded, due
   either to a 404 Not Found error or a network error.
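
   A minimal sketch of the startup-time loading described in item 1; the
   helper names are hypothetical, and per item 3 a real implementation
   should pickle-scan each file before handing it to torch.load:

       import glob, os, torch

       def load_embedding_files(path, apply_embedding, scan_ok):
           # accept a single file or a directory of .pt/.bin files
           if os.path.isfile(path):
               files = [path]
           else:
               files = glob.glob(os.path.join(path, '*.pt')) + \
                       glob.glob(os.path.join(path, '*.bin'))
           for f in files:
               if not scan_ok(f):   # e.g. a picklescan check
                   print(f'skipping {f}: failed pickle scan')
                   continue
               apply_embedding(torch.load(f, map_location='cpu'))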

* autocomplete terms now end with ">"

* fix startup error and network-unreachable handling

1. If the .invokeai file does not contain the --root and --outdir options,
   invoke.py will now add them.

2. Catch and handle network problems when downloading Hugging Face textual
   inversion concepts; a sketch of the idea follows.
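
   The network handling in item 2 amounts to catching the usual requests
   exceptions and reporting them; a hedged sketch (the function name and
   URL handling are hypothetical):

       import requests

       def fetch_concept_file(url):
           try:
               resp = requests.get(url, timeout=10)
               resp.raise_for_status()
               return resp.content
           except requests.HTTPError as err:
               print(f'concept not found ({err.response.status_code}): {url}')
           except requests.ConnectionError:
               print('network unreachable; concepts cannot be downloaded')
           return None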

* fix misformatted error string

Co-authored-by: Damian Stewart <d@damianstewart.com>
2022-11-28 02:40:24 -05:00


import argparse, os, sys, datetime, glob, importlib, csv
import numpy as np
import time
import torch
import torchvision
import pytorch_lightning as pl
from packaging import version
from omegaconf import OmegaConf
from torch.utils.data import random_split, DataLoader, Dataset, Subset
from functools import partial
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    Callback,
    LearningRateMonitor,
)
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities import rank_zero_info
from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config
def fix_func(orig):
    # On Apple Silicon, MPS random-number support is incomplete: generate
    # on the CPU, then move the result to the originally requested device.
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        def new_func(*args, **kw):
            device = kw.get("device", "mps")
            kw["device"] = "cpu"
            return orig(*args, **kw).to(device)
        return new_func
    return orig

torch.rand = fix_func(torch.rand)
torch.rand_like = fix_func(torch.rand_like)
torch.randn = fix_func(torch.randn)
torch.randn_like = fix_func(torch.randn_like)
torch.randint = fix_func(torch.randint)
torch.randint_like = fix_func(torch.randint_like)
torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial)
def load_model_from_config(config, ckpt, verbose=False):
    print(f'Loading model from {ckpt}')
    pl_sd = torch.load(ckpt, map_location='cpu')
    sd = pl_sd['state_dict']
    config.model.params.ckpt_path = ckpt
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print('missing keys:')
        print(m)
    if len(u) > 0 and verbose:
        print('unexpected keys:')
        print(u)

    if torch.cuda.is_available():
        model.cuda()
    return model
def get_parser(**parser_kwargs):
    def str2bool(v):
        if isinstance(v, bool):
            return v
        if v.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        elif v.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        else:
            raise argparse.ArgumentTypeError('Boolean value expected.')

    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument(
        '-n',
        '--name',
        type=str,
        const=True,
        default='',
        nargs='?',
        help='postfix for logdir',
    )
    parser.add_argument(
        '-r',
        '--resume',
        type=str,
        const=True,
        default='',
        nargs='?',
        help='resume from logdir or checkpoint in logdir',
    )
    parser.add_argument(
        '-b',
        '--base',
        nargs='*',
        metavar='base_config.yaml',
        help='paths to base configs. Loaded from left-to-right. '
        'Parameters can be overwritten or added with command-line options of the form `--key value`.',
        default=list(),
    )
    parser.add_argument(
        '-t',
        '--train',
        type=str2bool,
        const=True,
        default=False,
        nargs='?',
        help='train',
    )
    parser.add_argument(
        '--no-test',
        type=str2bool,
        const=True,
        default=False,
        nargs='?',
        help='disable test',
    )
    parser.add_argument(
        '-p', '--project', help='name of new or path to existing project'
    )
    parser.add_argument(
        '-d',
        '--debug',
        type=str2bool,
        nargs='?',
        const=True,
        default=False,
        help='enable post-mortem debugging',
    )
    parser.add_argument(
        '-s',
        '--seed',
        type=int,
        default=23,
        help='seed for seed_everything',
    )
    parser.add_argument(
        '-f',
        '--postfix',
        type=str,
        default='',
        help='post-postfix for default name',
    )
    parser.add_argument(
        '-l',
        '--logdir',
        type=str,
        default='logs',
        help='directory for logging',
    )
    parser.add_argument(
        '--scale_lr',
        type=str2bool,
        nargs='?',
        const=True,
        default=True,
        help='scale base-lr by ngpu * batch_size * n_accumulate',
    )
    parser.add_argument(
        '--datadir_in_name',
        type=str2bool,
        nargs='?',
        const=True,
        default=True,
        help='Prepend the final directory in the data_root to the output directory name',
    )
    parser.add_argument(
        '--actual_resume',
        type=str,
        default='',
        help='Path to model to actually resume from',
    )
    parser.add_argument(
        '--data_root',
        type=str,
        required=True,
        help='Path to directory with training images',
    )
    parser.add_argument(
        '--embedding_manager_ckpt',
        type=str,
        default='',
        help='Initialize embedding manager from a checkpoint',
    )
    parser.add_argument(
        '--init_word',
        type=str,
        help='Word to use as source for initial token embedding.',
    )
    return parser
def nondefault_trainer_args(opt):
    parser = argparse.ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args([])
    return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
class WrappedDataset(Dataset):
    """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""

    def __init__(self, dataset):
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
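
# worker_init_fn gives each DataLoader worker its own slice of an iterable
# dataset and its own numpy seed; otherwise forked workers would inherit
# identical RNG state and yield duplicate samples.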
def worker_init_fn(_):
    worker_info = torch.utils.data.get_worker_info()

    dataset = worker_info.dataset
    worker_id = worker_info.id

    if isinstance(dataset, Txt2ImgIterableBaseDataset):
        split_size = dataset.num_records // worker_info.num_workers
        # reset num_records to the true number to retain reliable length information
        dataset.sample_ids = dataset.valid_ids[
            worker_id * split_size : (worker_id + 1) * split_size
        ]
        current_id = np.random.choice(len(np.random.get_state()[1]), 1)
        return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
    else:
        return np.random.seed(np.random.get_state()[1][0] + worker_id)
class DataModuleFromConfig(pl.LightningDataModule):
    def __init__(
        self,
        batch_size,
        train=None,
        validation=None,
        test=None,
        predict=None,
        wrap=False,
        num_workers=None,
        shuffle_test_loader=False,
        use_worker_init_fn=False,
        shuffle_val_dataloader=False,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = (
            num_workers if num_workers is not None else batch_size * 2
        )
        self.use_worker_init_fn = use_worker_init_fn
        if train is not None:
            self.dataset_configs['train'] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs['validation'] = validation
            self.val_dataloader = partial(
                self._val_dataloader, shuffle=shuffle_val_dataloader
            )
        if test is not None:
            self.dataset_configs['test'] = test
            self.test_dataloader = partial(
                self._test_dataloader, shuffle=shuffle_test_loader
            )
        if predict is not None:
            self.dataset_configs['predict'] = predict
            self.predict_dataloader = self._predict_dataloader
        self.wrap = wrap

    def prepare_data(self):
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        self.datasets = dict(
            (k, instantiate_from_config(self.dataset_configs[k]))
            for k in self.dataset_configs
        )
        if self.wrap:
            for k in self.datasets:
                self.datasets[k] = WrappedDataset(self.datasets[k])

    def _train_dataloader(self):
        is_iterable_dataset = isinstance(
            self.datasets['train'], Txt2ImgIterableBaseDataset
        )
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(
            self.datasets['train'],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=not is_iterable_dataset,
            worker_init_fn=init_fn,
        )

    def _val_dataloader(self, shuffle=False):
        if (
            isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset)
            or self.use_worker_init_fn
        ):
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(
            self.datasets['validation'],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=init_fn,
            shuffle=shuffle,
        )

    def _test_dataloader(self, shuffle=False):
        is_iterable_dataset = isinstance(
            self.datasets['test'], Txt2ImgIterableBaseDataset
        )
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None

        # do not shuffle dataloader for iterable dataset
        shuffle = shuffle and (not is_iterable_dataset)

        return DataLoader(
            self.datasets['test'],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=init_fn,
            shuffle=shuffle,
        )

    def _predict_dataloader(self, shuffle=False):
        if (
            isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset)
            or self.use_worker_init_fn
        ):
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(
            self.datasets['predict'],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=init_fn,
        )
class SetupCallback(Callback):
    def __init__(
        self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config
    ):
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.lightning_config = lightning_config

    def on_keyboard_interrupt(self, trainer, pl_module):
        if trainer.global_rank == 0:
            print('Summoning checkpoint.')
            ckpt_path = os.path.join(self.ckptdir, 'last.ckpt')
            trainer.save_checkpoint(ckpt_path)

    def on_pretrain_routine_start(self, trainer, pl_module):
        if trainer.global_rank == 0:
            # Create logdirs and save configs
            os.makedirs(self.logdir, exist_ok=True)
            os.makedirs(self.ckptdir, exist_ok=True)
            os.makedirs(self.cfgdir, exist_ok=True)

            if 'callbacks' in self.lightning_config:
                if (
                    'metrics_over_trainsteps_checkpoint'
                    in self.lightning_config['callbacks']
                ):
                    os.makedirs(
                        os.path.join(self.ckptdir, 'trainstep_checkpoints'),
                        exist_ok=True,
                    )
            print('Project config')
            print(OmegaConf.to_yaml(self.config))
            OmegaConf.save(
                self.config,
                os.path.join(self.cfgdir, '{}-project.yaml'.format(self.now)),
            )

            print('Lightning config')
            print(OmegaConf.to_yaml(self.lightning_config))
            OmegaConf.save(
                OmegaConf.create({'lightning': self.lightning_config}),
                os.path.join(
                    self.cfgdir, '{}-lightning.yaml'.format(self.now)
                ),
            )
        else:
            # ModelCheckpoint callback created log directory --- remove it
            if not self.resume and os.path.exists(self.logdir):
                dst, name = os.path.split(self.logdir)
                dst = os.path.join(dst, 'child_runs', name)
                os.makedirs(os.path.split(dst)[0], exist_ok=True)
                try:
                    os.rename(self.logdir, dst)
                except FileNotFoundError:
                    pass
class ImageLogger(Callback):
    def __init__(
        self,
        batch_frequency,
        max_images,
        clamp=True,
        increase_log_steps=True,
        rescale=True,
        disabled=False,
        log_on_batch_idx=False,
        log_first_step=False,
        log_images_kwargs=None,
    ):
        super().__init__()
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        self.logger_log_images = {}
        self.log_steps = [
            2**n for n in range(int(np.log2(self.batch_freq)) + 1)
        ]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step

    @rank_zero_only
    def log_local(
        self, save_dir, split, images, global_step, current_epoch, batch_idx
    ):
        root = os.path.join(save_dir, 'images', split)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)
            if self.rescale:
                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = '{}_gs-{:06}_e-{:06}_b-{:06}.png'.format(
                k, global_step, current_epoch, batch_idx
            )
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            Image.fromarray(grid).save(path)

    def log_img(self, pl_module, batch, batch_idx, split='train'):
        check_idx = (
            batch_idx if self.log_on_batch_idx else pl_module.global_step
        )
        if (
            self.check_frequency(check_idx)  # batch_idx % self.batch_freq == 0
            and hasattr(pl_module, 'log_images')
            and callable(pl_module.log_images)
            and self.max_images > 0
        ):
            logger = type(pl_module.logger)

            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            with torch.no_grad():
                images = pl_module.log_images(
                    batch, split=split, **self.log_images_kwargs
                )

            for k in images:
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().cpu()
                    if self.clamp:
                        images[k] = torch.clamp(images[k], -1.0, 1.0)

            self.log_local(
                pl_module.logger.save_dir,
                split,
                images,
                pl_module.global_step,
                pl_module.current_epoch,
                batch_idx,
            )

            logger_log_images = self.logger_log_images.get(
                logger, lambda *args, **kwargs: None
            )
            logger_log_images(pl_module, images, pl_module.global_step, split)

            if is_train:
                pl_module.train()

    def check_frequency(self, check_idx):
        if (
            (check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)
        ) and (check_idx > 0 or self.log_first_step):
            try:
                self.log_steps.pop(0)
            except IndexError as e:
                print(e)
            return True
        return False

    def on_train_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None
    ):
        if not self.disabled and (
            pl_module.global_step > 0 or self.log_first_step
        ):
            self.log_img(pl_module, batch, batch_idx, split='train')

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None
    ):
        if not self.disabled and pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split='val')
        if hasattr(pl_module, 'calibrate_grad_norm'):
            if (
                pl_module.calibrate_grad_norm and batch_idx % 25 == 0
            ) and batch_idx > 0:
                self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
class CUDACallback(Callback):
    # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
    def on_train_epoch_start(self, trainer, pl_module):
        # Reset the memory use counter
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
            torch.cuda.synchronize(trainer.root_gpu)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module, outputs=None):
        if torch.cuda.is_available():
            torch.cuda.synchronize(trainer.root_gpu)
        epoch_time = time.time() - self.start_time

        try:
            epoch_time = trainer.training_type_plugin.reduce(epoch_time)
            rank_zero_info(f'Average Epoch time: {epoch_time:.2f} seconds')

            if torch.cuda.is_available():
                max_memory = (
                    torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20
                )
                max_memory = trainer.training_type_plugin.reduce(max_memory)
                rank_zero_info(f'Average Peak memory {max_memory:.2f}MiB')
        except AttributeError:
            pass
class ModeSwapCallback(Callback):
    def __init__(self, swap_step=2000):
        super().__init__()
        self.is_frozen = False
        self.swap_step = swap_step

    def on_train_epoch_start(self, trainer, pl_module):
        if trainer.global_step < self.swap_step and not self.is_frozen:
            self.is_frozen = True
            trainer.optimizers = [pl_module.configure_opt_embedding()]

        if trainer.global_step > self.swap_step and self.is_frozen:
            self.is_frozen = False
            trainer.optimizers = [pl_module.configure_opt_model()]
if __name__ == '__main__':
    # custom parser to specify config files, train, test and debug mode,
    # postfix, resume.
    # `--key value` arguments are interpreted as arguments to the trainer.
    # `nested.key=value` arguments are interpreted as config parameters.
    # configs are merged from left-to-right followed by command line parameters.

    # model:
    #   base_learning_rate: float
    #   target: path to lightning module
    #   params:
    #       key: value
    # data:
    #   target: main.DataModuleFromConfig
    #   params:
    #       batch_size: int
    #       wrap: bool
    #       train:
    #           target: path to train dataset
    #           params:
    #               key: value
    #       validation:
    #           target: path to validation dataset
    #           params:
    #               key: value
    #       test:
    #           target: path to test dataset
    #           params:
    #               key: value
    # lightning: (optional, has sane defaults and can be specified on cmdline)
    #   trainer:
    #       additional arguments to trainer
    #   logger:
    #       logger to instantiate
    #   modelcheckpoint:
    #       modelcheckpoint to instantiate
    #   callbacks:
    #       callback1:
    #           target: importpath
    #           params:
    #               key: value

    now = datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S')

    # add cwd for convenience and to make classes in this file available when
    # running as `python main.py`
    # (in particular `main.DataModuleFromConfig`)
    sys.path.append(os.getcwd())

    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)

    opt, unknown = parser.parse_known_args()
    if opt.name and opt.resume:
        raise ValueError(
            '-n/--name and -r/--resume cannot both be specified. '
            'If you want to resume training in a new log folder, '
            'use -n/--name in combination with --resume_from_checkpoint'
        )
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError('Cannot find {}'.format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split('/')
            # idx = len(paths)-paths[::-1].index("logs")+1
            # logdir = "/".join(paths[:idx])
            logdir = '/'.join(paths[:-2])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip('/')
            ckpt = os.path.join(logdir, 'checkpoints', 'last.ckpt')

        opt.resume_from_checkpoint = ckpt
        base_configs = sorted(
            glob.glob(os.path.join(logdir, 'configs/*.yaml'))
        )
        opt.base = base_configs + opt.base
        _tmp = logdir.split('/')
        nowname = _tmp[-1]
    else:
        if opt.name:
            name = '_' + opt.name
        elif opt.base:
            cfg_fname = os.path.split(opt.base[0])[-1]
            cfg_name = os.path.splitext(cfg_fname)[0]
            name = '_' + cfg_name
        else:
            name = ''

        if opt.datadir_in_name:
            now = os.path.basename(os.path.normpath(opt.data_root)) + now

        nowname = now + name + opt.postfix
        logdir = os.path.join(opt.logdir, nowname)

    ckptdir = os.path.join(logdir, 'checkpoints')
    cfgdir = os.path.join(logdir, 'configs')
    seed_everything(opt.seed)
    try:
        # init and save configs
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)
        lightning_config = config.pop('lightning', OmegaConf.create())

        # merge trainer cli with config
        trainer_config = lightning_config.get('trainer', OmegaConf.create())
        # default to the 'auto' accelerator and let Lightning pick the device
        trainer_config['accelerator'] = 'auto'
        for k in nondefault_trainer_args(opt):
            trainer_config[k] = getattr(opt, k)
        if 'gpus' not in trainer_config:
            del trainer_config['accelerator']
            cpu = True
        else:
            gpuinfo = trainer_config['gpus']
            print(f'Running on GPUs {gpuinfo}')
            cpu = False
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config

        # model
        # config.model.params.personalization_config.params.init_word = opt.init_word
        config.model.params.personalization_config.params.embedding_manager_ckpt = (
            opt.embedding_manager_ckpt
        )

        if opt.init_word:
            config.model.params.personalization_config.params.initializer_words = [
                opt.init_word
            ]

        if opt.actual_resume:
            model = load_model_from_config(config, opt.actual_resume)
        else:
            model = instantiate_from_config(config.model)
        # trainer and callbacks
        trainer_kwargs = dict()

        # default logger configs
        def_logger = 'csv'
        def_logger_target = 'CSVLogger'
        default_logger_cfgs = {
            'wandb': {
                'target': 'pytorch_lightning.loggers.WandbLogger',
                'params': {
                    'name': nowname,
                    'save_dir': logdir,
                    'offline': opt.debug,
                    'id': nowname,
                },
            },
            def_logger: {
                'target': 'pytorch_lightning.loggers.' + def_logger_target,
                'params': {
                    'name': def_logger,
                    'save_dir': logdir,
                },
            },
        }
        default_logger_cfg = default_logger_cfgs[def_logger]
        if 'logger' in lightning_config:
            logger_cfg = lightning_config.logger
        else:
            logger_cfg = OmegaConf.create()
        logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
        trainer_kwargs['logger'] = instantiate_from_config(logger_cfg)

        # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
        # specify which metric is used to determine best models
        default_modelckpt_cfg = {
            'target': 'pytorch_lightning.callbacks.ModelCheckpoint',
            'params': {
                'dirpath': ckptdir,
                'filename': '{epoch:06}',
                'verbose': True,
                'save_last': True,
            },
        }
        if hasattr(model, 'monitor'):
            print(f'Monitoring {model.monitor} as checkpoint metric.')
            default_modelckpt_cfg['params']['monitor'] = model.monitor
            default_modelckpt_cfg['params']['save_top_k'] = 1

        if 'modelcheckpoint' in lightning_config:
            modelckpt_cfg = lightning_config.modelcheckpoint
        else:
            modelckpt_cfg = OmegaConf.create()
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
        print(f'Merged modelckpt-cfg: \n{modelckpt_cfg}')
        if version.parse(pl.__version__) < version.parse('1.4.0'):
            trainer_kwargs['checkpoint_callback'] = instantiate_from_config(
                modelckpt_cfg
            )

        # add callback which sets up log directory
        default_callbacks_cfg = {
            'setup_callback': {
                'target': 'main.SetupCallback',
                'params': {
                    'resume': opt.resume,
                    'now': now,
                    'logdir': logdir,
                    'ckptdir': ckptdir,
                    'cfgdir': cfgdir,
                    'config': config,
                    'lightning_config': lightning_config,
                },
            },
            'image_logger': {
                'target': 'main.ImageLogger',
                'params': {
                    'batch_frequency': 750,
                    'max_images': 4,
                    'clamp': True,
                },
            },
            'learning_rate_logger': {
                'target': 'main.LearningRateMonitor',
                'params': {
                    'logging_interval': 'step',
                    # "log_momentum": True
                },
            },
            'cuda_callback': {'target': 'main.CUDACallback'},
        }
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            default_callbacks_cfg.update(
                {'checkpoint_callback': modelckpt_cfg}
            )

        if 'callbacks' in lightning_config:
            callbacks_cfg = lightning_config.callbacks
        else:
            callbacks_cfg = OmegaConf.create()

        if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
            print(
                'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.'
            )
            default_metrics_over_trainsteps_ckpt_dict = {
                'metrics_over_trainsteps_checkpoint': {
                    'target': 'pytorch_lightning.callbacks.ModelCheckpoint',
                    'params': {
                        'dirpath': os.path.join(
                            ckptdir, 'trainstep_checkpoints'
                        ),
                        'filename': '{epoch:06}-{step:09}',
                        'verbose': True,
                        'save_top_k': -1,
                        'every_n_train_steps': 10000,
                        'save_weights_only': True,
                    },
                }
            }
            default_callbacks_cfg.update(
                default_metrics_over_trainsteps_ckpt_dict
            )

        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
        if 'ignore_keys_callback' in callbacks_cfg and hasattr(
            trainer_opt, 'resume_from_checkpoint'
        ):
            callbacks_cfg.ignore_keys_callback.params[
                'ckpt_path'
            ] = trainer_opt.resume_from_checkpoint
        elif 'ignore_keys_callback' in callbacks_cfg:
            del callbacks_cfg['ignore_keys_callback']

        trainer_kwargs['callbacks'] = [
            instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
        ]
        trainer_kwargs['max_steps'] = trainer_opt.max_steps

        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            trainer_opt.accelerator = 'mps'
            trainer_opt.detect_anomaly = False

        trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
        trainer.logdir = logdir  ###
        # data
        config.data.params.train.params.data_root = opt.data_root
        config.data.params.validation.params.data_root = opt.data_root
        data = instantiate_from_config(config.data)

        # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
        # calling these ourselves should not be necessary but it is.
        # lightning still takes care of proper multiprocessing though
        data.prepare_data()
        data.setup()
        print('#### Data #####')
        for k in data.datasets:
            print(
                f'{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}'
            )

        # configure learning rate
        bs, base_lr = (
            config.data.params.batch_size,
            config.model.base_learning_rate,
        )
        if not cpu:
            gpus = str(lightning_config.trainer.gpus).strip(', ').split(',')
            ngpu = len(gpus)
        else:
            ngpu = 1
        if 'accumulate_grad_batches' in lightning_config.trainer:
            accumulate_grad_batches = (
                lightning_config.trainer.accumulate_grad_batches
            )
        else:
            accumulate_grad_batches = 1
        print(f'accumulate_grad_batches = {accumulate_grad_batches}')
        lightning_config.trainer.accumulate_grad_batches = (
            accumulate_grad_batches
        )
        if opt.scale_lr:
            model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
            print(
                'Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)'.format(
                    model.learning_rate,
                    accumulate_grad_batches,
                    ngpu,
                    bs,
                    base_lr,
                )
            )
        else:
            model.learning_rate = base_lr
            print('++++ NOT USING LR SCALING ++++')
            print(f'Setting learning rate to {model.learning_rate:.2e}')
        # allow checkpointing via USR1
        def melk(*args, **kwargs):
            # run all checkpoint hooks
            if trainer.global_rank == 0:
                print('Summoning checkpoint.')
                ckpt_path = os.path.join(ckptdir, 'last.ckpt')
                trainer.save_checkpoint(ckpt_path)

        def divein(*args, **kwargs):
            if trainer.global_rank == 0:
                import pudb

                pudb.set_trace()

        import signal

        signal.signal(signal.SIGUSR1, melk)
        signal.signal(signal.SIGUSR2, divein)
        # run
        if opt.train:
            try:
                trainer.fit(model, data)
            except Exception:
                melk()
                raise
        if not opt.no_test and not trainer.interrupted:
            trainer.test(model, data)
    except Exception:
        if opt.debug and trainer.global_rank == 0:
            try:
                import pudb as debugger
            except ImportError:
                import pdb as debugger
            debugger.post_mortem()
        raise
    finally:
        # move newly created debug project to debug_runs
        if opt.debug and not opt.resume and trainer.global_rank == 0:
            dst, name = os.path.split(logdir)
            dst = os.path.join(dst, 'debug_runs', name)
            os.makedirs(os.path.split(dst)[0], exist_ok=True)
            os.rename(logdir, dst)
        # if trainer.global_rank == 0:
        #     print(trainer.profiler.summary())