Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
* add whole <style token> to vocab for concept library embeddings
* add ability to load multiple concept .bin files
* make --log_tokenization respect custom tokens
* start working on concept downloading system
* preliminary support for dynamic loading and merging of multiple embedded models
  - The embedding_manager is now enhanced with ldm.invoke.concepts_lib, which handles dynamic downloading and caching of embedded models from the Hugging Face concepts library (https://huggingface.co/sd-concepts-library).
  - Downloading of an embedded model is triggered by the presence of one or more <concept> tags in the prompt.
  - Once the embedded model is downloaded, its trigger phrase is loaded into the embedding manager and the prompt's <concept> tag is replaced with the <trigger_phrase>.
  - The downloaded model stays on disk for fast loading later.
  - The CLI autocomplete will complete partial <concept> tags for you. Type a '<' and hit tab to get all ~700 concepts.

  BUGS AND LIMITATIONS:
  - MODEL NAME VS TRIGGER PHRASE: You must use the name of the concept embedding model from the SD concepts library, not the trigger phrase itself. Usually these are the same, but not always. For example, the model named "hoi4-leaders" corresponds to the trigger "<HOI4-Leader>". One reason for this design choice is that there is no apparent constraint on the uniqueness of trigger phrases, so one trigger phrase may map onto multiple models; the model name is unique, so we use it instead. The second reason is that I know of no way to search Hugging Face for models with a given trigger phrase, so we would have to download all ~700 models just to index the phrases. The downside is that this may confuse users who want to reuse prompts from distributions that use the trigger phrase directly. Usually that will work, but not always.
  - WON'T WORK ON A FIREWALLED SYSTEM: If the host running IAI has no internet connection, it cannot download the concept libraries. I will add a script that allows users to preload a list of concept models.
  - BUG IN PROMPT REPLACEMENT WHEN MODEL NOT FOUND: There is a small bug when the user provides an invalid model name: the <concept> tag gets replaced with <None> in the prompt.
* fix loading .pt embeddings; allow multi-vector embeddings; warn on dupes
* simplify replacement logic and remove cuda assumption
* download list of concepts from hugging face
* remove misleading customization of '*' placeholder: the existing code did nothing, and it is unclear what it was supposed to do. The obvious alternative (using 'placeholder_strings' instead of 'placeholder_tokens' to match model.params.personalization_config.params.placeholder_strings) caused a crash. I think this is because the passed string also needs to be handed to PersonalizedBase on init as the 'placeholder_token' argument. This is weird config-dict magic and I don't want to touch it; put a breakpoint at the top of PersonalizedBase.__init__ in personalized.py (around line 116) if you want to have a crack at it yourself.
* address all the issues raised by damian0815 in review of PR #1526
* actually resize the token_embeddings
* multiple improvements to the concept loader based on code reviews:
  1. Activated the --embedding_directory option (alias --embedding_path) to load a single embedding or an entire directory of embeddings at startup time.
  2. Automatic loading of embeddings can be turned off with --no-embeddings.
  3. Embedding checkpoints are scanned with the pickle scanner.
  4. More informative error messages when a concept can't be loaded due either to a 404 not-found error or a network error.
* autocomplete terms end with ">" now
* fix startup error and network unreachable:
  1. If the .invokeai file does not contain the --root and --outdir options, invoke.py will now fix it.
  2. Catch and handle network problems when downloading Hugging Face textual inversion concepts.
* fix misformatted error string

Co-authored-by: Damian Stewart <d@damianstewart.com>
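A minimal sketch of the tag-to-trigger flow described above. The helper names (resolve_concept_tags, embedding_manager.load returning a trigger phrase) are illustrative assumptions; the actual implementation lives in ldm.invoke.concepts_lib and the embedding manager.

import re
from huggingface_hub import hf_hub_download

CONCEPTS_LIBRARY = 'sd-concepts-library'  # https://huggingface.co/sd-concepts-library

def resolve_concept_tags(prompt: str, embedding_manager) -> str:
    """Replace each <concept> tag with the concept's trigger phrase,
    downloading and loading the embedding on first use."""
    def _substitute(match):
        concept_name = match.group(1)
        try:
            # each concept repo ships a learned_embeds.bin file
            bin_path = hf_hub_download(
                repo_id=f'{CONCEPTS_LIBRARY}/{concept_name}',
                filename='learned_embeds.bin',
            )
        except Exception:
            return '<None>'  # current behaviour when the model is not found
        # hypothetical: load() registers the embedding and returns its trigger phrase
        trigger = embedding_manager.load(bin_path)
        return trigger
    return re.sub(r'<([\w\- ]+)>', _substitute, prompt)

With this, a prompt containing "<hoi4-leaders>" would be rewritten to use that model's trigger phrase (for example "<HOI4-Leader>") before tokenization.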
972 lines
32 KiB
Python
import argparse, os, sys, datetime, glob, importlib, csv
import numpy as np
import time
import torch

import torchvision
import pytorch_lightning as pl

from packaging import version
from omegaconf import OmegaConf
from torch.utils.data import random_split, DataLoader, Dataset, Subset
from functools import partial
from PIL import Image

from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    Callback,
    LearningRateMonitor,
)
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities import rank_zero_info

from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config

def fix_func(orig):
    # On Apple Silicon, some random-number ops are not implemented on the MPS
    # backend; run them on the CPU and move the result back to the target device.
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        def new_func(*args, **kw):
            device = kw.get("device", "mps")
            kw["device"] = "cpu"
            return orig(*args, **kw).to(device)
        return new_func
    return orig


torch.rand = fix_func(torch.rand)
torch.rand_like = fix_func(torch.rand_like)
torch.randn = fix_func(torch.randn)
torch.randn_like = fix_func(torch.randn_like)
torch.randint = fix_func(torch.randint)
torch.randint_like = fix_func(torch.randint_like)
torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial)

def load_model_from_config(config, ckpt, verbose=False):
    print(f'Loading model from {ckpt}')
    pl_sd = torch.load(ckpt, map_location='cpu')
    sd = pl_sd['state_dict']
    config.model.params.ckpt_path = ckpt
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print('missing keys:')
        print(m)
    if len(u) > 0 and verbose:
        print('unexpected keys:')
        print(u)

    if torch.cuda.is_available():
        model.cuda()
    return model


def get_parser(**parser_kwargs):
    def str2bool(v):
        if isinstance(v, bool):
            return v
        if v.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        elif v.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        else:
            raise argparse.ArgumentTypeError('Boolean value expected.')

    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument(
        '-n',
        '--name',
        type=str,
        const=True,
        default='',
        nargs='?',
        help='postfix for logdir',
    )
    parser.add_argument(
        '-r',
        '--resume',
        type=str,
        const=True,
        default='',
        nargs='?',
        help='resume from logdir or checkpoint in logdir',
    )
    parser.add_argument(
        '-b',
        '--base',
        nargs='*',
        metavar='base_config.yaml',
        help='paths to base configs. Loaded from left-to-right. '
        'Parameters can be overwritten or added with command-line options of the form `--key value`.',
        default=list(),
    )
    parser.add_argument(
        '-t',
        '--train',
        type=str2bool,
        const=True,
        default=False,
        nargs='?',
        help='train',
    )
    parser.add_argument(
        '--no-test',
        type=str2bool,
        const=True,
        default=False,
        nargs='?',
        help='disable test',
    )
    parser.add_argument(
        '-p', '--project', help='name of new or path to existing project'
    )
    parser.add_argument(
        '-d',
        '--debug',
        type=str2bool,
        nargs='?',
        const=True,
        default=False,
        help='enable post-mortem debugging',
    )
    parser.add_argument(
        '-s',
        '--seed',
        type=int,
        default=23,
        help='seed for seed_everything',
    )
    parser.add_argument(
        '-f',
        '--postfix',
        type=str,
        default='',
        help='post-postfix for default name',
    )
    parser.add_argument(
        '-l',
        '--logdir',
        type=str,
        default='logs',
        help='directory for logging data',
    )
    parser.add_argument(
        '--scale_lr',
        type=str2bool,
        nargs='?',
        const=True,
        default=True,
        help='scale base-lr by ngpu * batch_size * n_accumulate',
    )

    parser.add_argument(
        '--datadir_in_name',
        type=str2bool,
        nargs='?',
        const=True,
        default=True,
        help='Prepend the final directory in the data_root to the output directory name',
    )

    parser.add_argument(
        '--actual_resume',
        type=str,
        default='',
        help='Path to model to actually resume from',
    )
    parser.add_argument(
        '--data_root',
        type=str,
        required=True,
        help='Path to directory with training images',
    )

    parser.add_argument(
        '--embedding_manager_ckpt',
        type=str,
        default='',
        help='Initialize embedding manager from a checkpoint',
    )
    parser.add_argument(
        '--init_word',
        type=str,
        help='Word to use as source for initial token embedding.',
    )

    return parser
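
# Example invocation (a sketch based on the arguments defined above; the config
# path, checkpoint path and init word are illustrative assumptions, not values
# taken from this repository; --gpus is added later by Trainer.add_argparse_args):
#
#   python main.py -t \
#       --base configs/stable-diffusion/v1-finetune.yaml \
#       --actual_resume models/ldm/stable-diffusion-v1/model.ckpt \
#       -n my_new_concept \
#       --data_root /path/to/training/images \
#       --init_word 'toy' \
#       --gpus 0,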


def nondefault_trainer_args(opt):
    parser = argparse.ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args([])
    return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))


class WrappedDataset(Dataset):
    """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""

    def __init__(self, dataset):
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def worker_init_fn(_):
    worker_info = torch.utils.data.get_worker_info()

    dataset = worker_info.dataset
    worker_id = worker_info.id

    if isinstance(dataset, Txt2ImgIterableBaseDataset):
        split_size = dataset.num_records // worker_info.num_workers
        # reset num_records to the true number to retain reliable length information
        dataset.sample_ids = dataset.valid_ids[
            worker_id * split_size : (worker_id + 1) * split_size
        ]
        current_id = np.random.choice(len(np.random.get_state()[1]), 1)
        return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
    else:
        return np.random.seed(np.random.get_state()[1][0] + worker_id)

class DataModuleFromConfig(pl.LightningDataModule):
    def __init__(
        self,
        batch_size,
        train=None,
        validation=None,
        test=None,
        predict=None,
        wrap=False,
        num_workers=None,
        shuffle_test_loader=False,
        use_worker_init_fn=False,
        shuffle_val_dataloader=False,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = (
            num_workers if num_workers is not None else batch_size * 2
        )
        self.use_worker_init_fn = use_worker_init_fn
        if train is not None:
            self.dataset_configs['train'] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs['validation'] = validation
            self.val_dataloader = partial(
                self._val_dataloader, shuffle=shuffle_val_dataloader
            )
        if test is not None:
            self.dataset_configs['test'] = test
            self.test_dataloader = partial(
                self._test_dataloader, shuffle=shuffle_test_loader
            )
        if predict is not None:
            self.dataset_configs['predict'] = predict
            self.predict_dataloader = self._predict_dataloader
        self.wrap = wrap

    def prepare_data(self):
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        self.datasets = dict(
            (k, instantiate_from_config(self.dataset_configs[k]))
            for k in self.dataset_configs
        )
        if self.wrap:
            for k in self.datasets:
                self.datasets[k] = WrappedDataset(self.datasets[k])

    def _train_dataloader(self):
        is_iterable_dataset = isinstance(
            self.datasets['train'], Txt2ImgIterableBaseDataset
        )
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(
            self.datasets['train'],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False if is_iterable_dataset else True,
            worker_init_fn=init_fn,
        )

    def _val_dataloader(self, shuffle=False):
        if (
            isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset)
            or self.use_worker_init_fn
        ):
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(
            self.datasets['validation'],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=init_fn,
            shuffle=shuffle,
        )

    def _test_dataloader(self, shuffle=False):
        is_iterable_dataset = isinstance(
            self.datasets['train'], Txt2ImgIterableBaseDataset
        )
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None

        # do not shuffle dataloader for iterable dataset
        shuffle = shuffle and (not is_iterable_dataset)

        return DataLoader(
            self.datasets['test'],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=init_fn,
            shuffle=shuffle,
        )

    def _predict_dataloader(self, shuffle=False):
        if (
            isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset)
            or self.use_worker_init_fn
        ):
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(
            self.datasets['predict'],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=init_fn,
        )

class SetupCallback(Callback):
    def __init__(
        self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config
    ):
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.lightning_config = lightning_config

    def on_keyboard_interrupt(self, trainer, pl_module):
        if trainer.global_rank == 0:
            print('Summoning checkpoint.')
            ckpt_path = os.path.join(self.ckptdir, 'last.ckpt')
            trainer.save_checkpoint(ckpt_path)

    def on_pretrain_routine_start(self, trainer, pl_module):
        if trainer.global_rank == 0:
            # Create logdirs and save configs
            os.makedirs(self.logdir, exist_ok=True)
            os.makedirs(self.ckptdir, exist_ok=True)
            os.makedirs(self.cfgdir, exist_ok=True)

            if 'callbacks' in self.lightning_config:
                if (
                    'metrics_over_trainsteps_checkpoint'
                    in self.lightning_config['callbacks']
                ):
                    os.makedirs(
                        os.path.join(self.ckptdir, 'trainstep_checkpoints'),
                        exist_ok=True,
                    )
            print('Project config')
            print(OmegaConf.to_yaml(self.config))
            OmegaConf.save(
                self.config,
                os.path.join(self.cfgdir, '{}-project.yaml'.format(self.now)),
            )

            print('Lightning config')
            print(OmegaConf.to_yaml(self.lightning_config))
            OmegaConf.save(
                OmegaConf.create({'lightning': self.lightning_config}),
                os.path.join(
                    self.cfgdir, '{}-lightning.yaml'.format(self.now)
                ),
            )

        else:
            # ModelCheckpoint callback created log directory --- remove it
            if not self.resume and os.path.exists(self.logdir):
                dst, name = os.path.split(self.logdir)
                dst = os.path.join(dst, 'child_runs', name)
                os.makedirs(os.path.split(dst)[0], exist_ok=True)
                try:
                    os.rename(self.logdir, dst)
                except FileNotFoundError:
                    pass

class ImageLogger(Callback):
    def __init__(
        self,
        batch_frequency,
        max_images,
        clamp=True,
        increase_log_steps=True,
        rescale=True,
        disabled=False,
        log_on_batch_idx=False,
        log_first_step=False,
        log_images_kwargs=None,
    ):
        super().__init__()
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        self.logger_log_images = {}
        self.log_steps = [
            2**n for n in range(int(np.log2(self.batch_freq)) + 1)
        ]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step

    @rank_zero_only
    def log_local(
        self, save_dir, split, images, global_step, current_epoch, batch_idx
    ):
        root = os.path.join(save_dir, 'images', split)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)
            if self.rescale:
                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = '{}_gs-{:06}_e-{:06}_b-{:06}.png'.format(
                k, global_step, current_epoch, batch_idx
            )
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            Image.fromarray(grid).save(path)

    def log_img(self, pl_module, batch, batch_idx, split='train'):
        check_idx = (
            batch_idx if self.log_on_batch_idx else pl_module.global_step
        )
        if (
            self.check_frequency(check_idx)  # batch_idx % self.batch_freq == 0
            and hasattr(pl_module, 'log_images')
            and callable(pl_module.log_images)
            and self.max_images > 0
        ):
            logger = type(pl_module.logger)

            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            with torch.no_grad():
                images = pl_module.log_images(
                    batch, split=split, **self.log_images_kwargs
                )

            for k in images:
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().cpu()
                    if self.clamp:
                        images[k] = torch.clamp(images[k], -1.0, 1.0)

            self.log_local(
                pl_module.logger.save_dir,
                split,
                images,
                pl_module.global_step,
                pl_module.current_epoch,
                batch_idx,
            )

            logger_log_images = self.logger_log_images.get(
                logger, lambda *args, **kwargs: None
            )
            logger_log_images(pl_module, images, pl_module.global_step, split)

            if is_train:
                pl_module.train()

    def check_frequency(self, check_idx):
        if (
            (check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)
        ) and (check_idx > 0 or self.log_first_step):
            try:
                self.log_steps.pop(0)
            except IndexError as e:
                print(e)
                pass
            return True
        return False

    def on_train_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None
    ):
        if not self.disabled and (
            pl_module.global_step > 0 or self.log_first_step
        ):
            self.log_img(pl_module, batch, batch_idx, split='train')

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None
    ):
        if not self.disabled and pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split='val')
        if hasattr(pl_module, 'calibrate_grad_norm'):
            if (
                pl_module.calibrate_grad_norm and batch_idx % 25 == 0
            ) and batch_idx > 0:
                self.log_gradients(trainer, pl_module, batch_idx=batch_idx)


class CUDACallback(Callback):
    # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
    def on_train_epoch_start(self, trainer, pl_module):
        # Reset the memory use counter
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
            torch.cuda.synchronize(trainer.root_gpu)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module, outputs=None):
        if torch.cuda.is_available():
            torch.cuda.synchronize(trainer.root_gpu)
        epoch_time = time.time() - self.start_time

        try:
            epoch_time = trainer.training_type_plugin.reduce(epoch_time)
            rank_zero_info(f'Average Epoch time: {epoch_time:.2f} seconds')

            if torch.cuda.is_available():
                max_memory = (
                    torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20
                )
                max_memory = trainer.training_type_plugin.reduce(max_memory)
                rank_zero_info(f'Average Peak memory {max_memory:.2f}MiB')
        except AttributeError:
            pass


class ModeSwapCallback(Callback):

    def __init__(self, swap_step=2000):
        super().__init__()
        self.is_frozen = False
        self.swap_step = swap_step

    def on_train_epoch_start(self, trainer, pl_module):
        if trainer.global_step < self.swap_step and not self.is_frozen:
            self.is_frozen = True
            trainer.optimizers = [pl_module.configure_opt_embedding()]

        if trainer.global_step > self.swap_step and self.is_frozen:
            self.is_frozen = False
            trainer.optimizers = [pl_module.configure_opt_model()]


if __name__ == '__main__':
    # custom parser to specify config files, train, test and debug mode,
    # postfix, resume.
    # `--key value` arguments are interpreted as arguments to the trainer.
    # `nested.key=value` arguments are interpreted as config parameters.
    # configs are merged from left-to-right followed by command line parameters.

    # model:
    #   base_learning_rate: float
    #   target: path to lightning module
    #   params:
    #       key: value
    # data:
    #   target: main.DataModuleFromConfig
    #   params:
    #      batch_size: int
    #      wrap: bool
    #      train:
    #          target: path to train dataset
    #          params:
    #              key: value
    #      validation:
    #          target: path to validation dataset
    #          params:
    #              key: value
    #      test:
    #          target: path to test dataset
    #          params:
    #              key: value
    # lightning: (optional, has sane defaults and can be specified on cmdline)
    #   trainer:
    #       additional arguments to trainer
    #   logger:
    #       logger to instantiate
    #   modelcheckpoint:
    #       modelcheckpoint to instantiate
    #   callbacks:
    #       callback1:
    #           target: importpath
    #           params:
    #               key: value
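    #
    # A minimal concrete sketch of such a config for textual-inversion training
    # (the class paths and values below are illustrative assumptions, not copied
    # from this repository's shipped configs):
    #
    # model:
    #   base_learning_rate: 5.0e-03
    #   target: ldm.models.diffusion.ddpm.LatentDiffusion
    #   params:
    #     personalization_config:
    #       target: ldm.modules.embedding_manager.EmbeddingManager
    #       params:
    #         placeholder_strings: ["*"]
    # data:
    #   target: main.DataModuleFromConfig
    #   params:
    #     batch_size: 2
    #     wrap: false
    #     train:
    #       target: ldm.data.personalized.PersonalizedBase
    #       params:
    #         size: 512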

    now = datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S')

    # add cwd for convenience and to make classes in this file available when
    # running as `python main.py`
    # (in particular `main.DataModuleFromConfig`)
    sys.path.append(os.getcwd())

    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)

    opt, unknown = parser.parse_known_args()
    if opt.name and opt.resume:
        raise ValueError(
            '-n/--name and -r/--resume cannot both be specified. '
            'If you want to resume training in a new log folder, '
            'use -n/--name in combination with --resume_from_checkpoint'
        )
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError('Cannot find {}'.format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split('/')
            # idx = len(paths)-paths[::-1].index("logs")+1
            # logdir = "/".join(paths[:idx])
            logdir = '/'.join(paths[:-2])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip('/')
            ckpt = os.path.join(logdir, 'checkpoints', 'last.ckpt')

        opt.resume_from_checkpoint = ckpt
        base_configs = sorted(
            glob.glob(os.path.join(logdir, 'configs/*.yaml'))
        )
        opt.base = base_configs + opt.base
        _tmp = logdir.split('/')
        nowname = _tmp[-1]
    else:
        if opt.name:
            name = '_' + opt.name
        elif opt.base:
            cfg_fname = os.path.split(opt.base[0])[-1]
            cfg_name = os.path.splitext(cfg_fname)[0]
            name = '_' + cfg_name
        else:
            name = ''

        if opt.datadir_in_name:
            now = os.path.basename(os.path.normpath(opt.data_root)) + now

        nowname = now + name + opt.postfix
        logdir = os.path.join(opt.logdir, nowname)

    ckptdir = os.path.join(logdir, 'checkpoints')
    cfgdir = os.path.join(logdir, 'configs')
    seed_everything(opt.seed)

    try:
        # init and save configs
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)
        lightning_config = config.pop('lightning', OmegaConf.create())
        # merge trainer cli with config
        trainer_config = lightning_config.get('trainer', OmegaConf.create())
        # default accelerator choice; removed again below when no GPUs are requested
        trainer_config['accelerator'] = 'auto'
        for k in nondefault_trainer_args(opt):
            trainer_config[k] = getattr(opt, k)
        if 'gpus' not in trainer_config:
            del trainer_config['accelerator']
            cpu = True
        else:
            gpuinfo = trainer_config['gpus']
            print(f'Running on GPUs {gpuinfo}')
            cpu = False
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config

        # model

        # config.model.params.personalization_config.params.init_word = opt.init_word
        config.model.params.personalization_config.params.embedding_manager_ckpt = (
            opt.embedding_manager_ckpt
        )

        if opt.init_word:
            config.model.params.personalization_config.params.initializer_words = [opt.init_word]

        if opt.actual_resume:
            model = load_model_from_config(config, opt.actual_resume)
        else:
            model = instantiate_from_config(config.model)

        # trainer and callbacks
        trainer_kwargs = dict()

        # default logger configs
        def_logger = 'csv'
        def_logger_target = 'CSVLogger'
        default_logger_cfgs = {
            'wandb': {
                'target': 'pytorch_lightning.loggers.WandbLogger',
                'params': {
                    'name': nowname,
                    'save_dir': logdir,
                    'offline': opt.debug,
                    'id': nowname,
                },
            },
            def_logger: {
                'target': 'pytorch_lightning.loggers.' + def_logger_target,
                'params': {
                    'name': def_logger,
                    'save_dir': logdir,
                },
            },
        }
        default_logger_cfg = default_logger_cfgs[def_logger]
        if 'logger' in lightning_config:
            logger_cfg = lightning_config.logger
        else:
            logger_cfg = OmegaConf.create()
        logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
        trainer_kwargs['logger'] = instantiate_from_config(logger_cfg)

        # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
        # specify which metric is used to determine best models
        default_modelckpt_cfg = {
            'target': 'pytorch_lightning.callbacks.ModelCheckpoint',
            'params': {
                'dirpath': ckptdir,
                'filename': '{epoch:06}',
                'verbose': True,
                'save_last': True,
            },
        }
        if hasattr(model, 'monitor'):
            print(f'Monitoring {model.monitor} as checkpoint metric.')
            default_modelckpt_cfg['params']['monitor'] = model.monitor
            default_modelckpt_cfg['params']['save_top_k'] = 1

        if 'modelcheckpoint' in lightning_config:
            modelckpt_cfg = lightning_config.modelcheckpoint
        else:
            modelckpt_cfg = OmegaConf.create()
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
        print(f'Merged modelckpt-cfg: \n{modelckpt_cfg}')
        if version.parse(pl.__version__) < version.parse('1.4.0'):
            trainer_kwargs['checkpoint_callback'] = instantiate_from_config(
                modelckpt_cfg
            )

        # add callback which sets up log directory
        default_callbacks_cfg = {
            'setup_callback': {
                'target': 'main.SetupCallback',
                'params': {
                    'resume': opt.resume,
                    'now': now,
                    'logdir': logdir,
                    'ckptdir': ckptdir,
                    'cfgdir': cfgdir,
                    'config': config,
                    'lightning_config': lightning_config,
                },
            },
            'image_logger': {
                'target': 'main.ImageLogger',
                'params': {
                    'batch_frequency': 750,
                    'max_images': 4,
                    'clamp': True,
                },
            },
            'learning_rate_logger': {
                'target': 'main.LearningRateMonitor',
                'params': {
                    'logging_interval': 'step',
                    # "log_momentum": True
                },
            },
            'cuda_callback': {'target': 'main.CUDACallback'},
        }
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            default_callbacks_cfg.update(
                {'checkpoint_callback': modelckpt_cfg}
            )

        if 'callbacks' in lightning_config:
            callbacks_cfg = lightning_config.callbacks
        else:
            callbacks_cfg = OmegaConf.create()

        if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
            print(
                'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.'
            )
            default_metrics_over_trainsteps_ckpt_dict = {
                'metrics_over_trainsteps_checkpoint': {
                    'target': 'pytorch_lightning.callbacks.ModelCheckpoint',
                    'params': {
                        'dirpath': os.path.join(
                            ckptdir, 'trainstep_checkpoints'
                        ),
                        'filename': '{epoch:06}-{step:09}',
                        'verbose': True,
                        'save_top_k': -1,
                        'every_n_train_steps': 10000,
                        'save_weights_only': True,
                    },
                }
            }
            default_callbacks_cfg.update(
                default_metrics_over_trainsteps_ckpt_dict
            )

        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
        if 'ignore_keys_callback' in callbacks_cfg and hasattr(
            trainer_opt, 'resume_from_checkpoint'
        ):
            callbacks_cfg.ignore_keys_callback.params[
                'ckpt_path'
            ] = trainer_opt.resume_from_checkpoint
        elif 'ignore_keys_callback' in callbacks_cfg:
            del callbacks_cfg['ignore_keys_callback']

        trainer_kwargs['callbacks'] = [
            instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
        ]
        trainer_kwargs['max_steps'] = trainer_opt.max_steps

        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            trainer_opt.accelerator = 'mps'
            trainer_opt.detect_anomaly = False

        trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
        trainer.logdir = logdir  ###

        # data
        config.data.params.train.params.data_root = opt.data_root
        config.data.params.validation.params.data_root = opt.data_root
        data = instantiate_from_config(config.data)

        # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
        # calling these ourselves should not be necessary but it is.
        # lightning still takes care of proper multiprocessing though
        data.prepare_data()
        data.setup()
        print('#### Data #####')
        for k in data.datasets:
            print(
                f'{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}'
            )

        # configure learning rate
        bs, base_lr = (
            config.data.params.batch_size,
            config.model.base_learning_rate,
        )
        if not cpu:
            gpus = str(lightning_config.trainer.gpus).strip(', ').split(',')
            ngpu = len(gpus)
        else:
            ngpu = 1
        if 'accumulate_grad_batches' in lightning_config.trainer:
            accumulate_grad_batches = (
                lightning_config.trainer.accumulate_grad_batches
            )
        else:
            accumulate_grad_batches = 1
        print(f'accumulate_grad_batches = {accumulate_grad_batches}')
        lightning_config.trainer.accumulate_grad_batches = (
            accumulate_grad_batches
        )
        if opt.scale_lr:
            model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
            print(
                'Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)'.format(
                    model.learning_rate,
                    accumulate_grad_batches,
                    ngpu,
                    bs,
                    base_lr,
                )
            )
        else:
            model.learning_rate = base_lr
            print('++++ NOT USING LR SCALING ++++')
            print(f'Setting learning rate to {model.learning_rate:.2e}')

        # allow checkpointing via USR1
        def melk(*args, **kwargs):
            # run all checkpoint hooks
            if trainer.global_rank == 0:
                print('Summoning checkpoint.')
                ckpt_path = os.path.join(ckptdir, 'last.ckpt')
                trainer.save_checkpoint(ckpt_path)

        def divein(*args, **kwargs):
            if trainer.global_rank == 0:
                import pudb

                pudb.set_trace()

        import signal

        # register distinct signals: registering both handlers on the same
        # signal would make the second overwrite the first, so melk would never fire
        signal.signal(signal.SIGUSR1, melk)
        signal.signal(signal.SIGUSR2, divein)

        # run
        if opt.train:
            try:
                trainer.fit(model, data)
            except Exception:
                melk()
                raise
        if not opt.no_test and not trainer.interrupted:
            trainer.test(model, data)
    except Exception:
        if opt.debug and trainer.global_rank == 0:
            try:
                import pudb as debugger
            except ImportError:
                import pdb as debugger
            debugger.post_mortem()
        raise
    finally:
        # move newly created debug project to debug_runs
        if opt.debug and not opt.resume and trainer.global_rank == 0:
            dst, name = os.path.split(logdir)
            dst = os.path.join(dst, 'debug_runs', name)
            os.makedirs(os.path.split(dst)[0], exist_ok=True)
            os.rename(logdir, dst)
        # if trainer.global_rank == 0:
        #     print(trainer.profiler.summary())