Fixes issue with cuda/current mismatch

This commit is contained in:
Sean McLellan
2022-08-24 13:14:08 -04:00
parent c6b5e930dc
commit b5cdbd3b0b
5 changed files with 21 additions and 12 deletions

View File

@ -10,13 +10,17 @@ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, mak
class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
def __init__(self, model, schedule="linear", device="cuda", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device(self.device):
attr = attr.to(torch.device(self.device))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):

View File

@ -9,13 +9,18 @@ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, mak
class PLMSSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
def __init__(self, model, schedule="linear", device="cuda", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device(self.device):
attr = attr.to(torch.device(self.device))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):