PR revision: replace cuda call with dynamic type

This commit is contained in:
JigenD 2022-08-25 12:18:35 -04:00
parent eb58276a2c
commit e82c5eba18

View File

@ -494,7 +494,7 @@ The vast majority of these arguments default to reasonable values.
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
model.cuda()
model.to(self.device)
model.eval()
if self.full_precision:
print('Using slower but more accurate full-precision math (--full_precision)')