diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index 2ebaeabd22..065b32986a 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -17,6 +17,9 @@ class DDIMSampler(object): self.schedule = schedule def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) setattr(self, name, attr) def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):