backed out change from PR #44 that was causing ddim sampler to fail with the message 'sqrt _vml_cpu not implemented for 'Half'

This commit is contained in:
Lincoln Stein 2022-08-24 13:09:01 -04:00
parent 47a5da25b7
commit e043f238af

View File

@ -17,6 +17,9 @@ class DDIMSampler(object):
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):