improved error reporting when a missing online dependency can't be downloaded

This commit is contained in:
Lincoln Stein 2022-08-17 18:06:30 -04:00
parent c477525036
commit 831bbd7a54
3 changed files with 11 additions and 4 deletions

View File

@ -17,6 +17,7 @@ from functools import partial
from tqdm import tqdm
from torchvision.utils import make_grid
from pytorch_lightning.utilities.distributed import rank_zero_only
import urllib
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
from ldm.modules.ema import LitEma
@ -524,7 +525,10 @@ class LatentDiffusion(DDPM):
else:
assert config != '__is_first_stage__'
assert config != '__is_unconditional__'
model = instantiate_from_config(config)
try:
model = instantiate_from_config(config)
except urllib.error.URLError:
raise SystemExit("* Couldn't load a dependency. Try running scripts/preload_models.py from an internet-conected machine.")
self.cond_stage_model = model
def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):

View File

@ -60,7 +60,10 @@ class BERTTokenizer(AbstractEncoder):
# by running:
# from transformers import BertTokenizerFast
# BertTokenizerFast.from_pretrained("bert-base-uncased")
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased",local_files_only=True)
try:
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased",local_files_only=True)
except OSError:
raise SystemExit("* Couldn't load Bert tokenizer files. Try running scripts/preload_models.py from an internet-conected machine.")
self.device = device
self.vq_interface = vq_interface
self.max_length = max_length

View File

@ -5,13 +5,13 @@
# two machines must share a common .cache directory.
# this will preload the Bert tokenizer fles
print("preloading bert tokenizer...",end='')
print("preloading bert tokenizer...")
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
print("...success")
# this will download requirements for Kornia
print("preloading Kornia requirements...",end='')
print("preloading Kornia requirements...")
import kornia
print("...success")