fix performance regression; closes issue #42

This commit is contained in:
Lincoln Stein 2022-08-25 09:41:12 -04:00
parent 0b4459b707
commit 49247b4aa4

View File

@ -486,6 +486,7 @@ The vast majority of these arguments default to reasonable values.
sd = pl_sd["state_dict"] sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model) model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False) m, u = model.load_state_dict(sd, strict=False)
model.cuda() # fixes performance issue
model.eval() model.eval()
if self.full_precision: if self.full_precision:
print('Using slower but more accurate full-precision math (--full_precision)') print('Using slower but more accurate full-precision math (--full_precision)')