mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fixes issue with cuda/current mismatch
@@ -10,13 +10,17 @@ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, mak
 class DDIMSampler(object):
-    def __init__(self, model, schedule="linear", **kwargs):
+    def __init__(self, model, schedule="linear", device="cuda", **kwargs):
         super().__init__()
         self.model = model
         self.ddpm_num_timesteps = model.num_timesteps
         self.schedule = schedule
+        self.device = device
 
     def register_buffer(self, name, attr):
+        if type(attr) == torch.Tensor:
+            if attr.device != torch.device(self.device):
+                attr = attr.to(torch.device(self.device))
         setattr(self, name, attr)
 
     def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
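In plain terms, the DDIM sampler now records the device it was constructed with and moves every registered buffer onto that device, instead of assuming CUDA. Below is a minimal standalone sketch of that pattern; the BufferHolder class and the example buffer are illustrative stand-ins, not part of this commit.

import torch

class BufferHolder:
    # Illustrative stand-in for the sampler's buffer handling in this commit.
    def __init__(self, device="cuda"):
        self.device = device

    def register_buffer(self, name, attr):
        # Tensors are moved to the configured device; non-tensors are stored untouched.
        if type(attr) == torch.Tensor:
            if attr.device != torch.device(self.device):
                attr = attr.to(torch.device(self.device))
        setattr(self, name, attr)

# Usage: pick whichever device is actually available.
holder = BufferHolder(device="cuda" if torch.cuda.is_available() else "cpu")
holder.register_buffer("alphas_cumprod", torch.linspace(0.0, 1.0, steps=10))
print(holder.alphas_cumprod.device)  # cuda:0 on a GPU box, cpu otherwise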
@@ -9,13 +9,18 @@ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, mak
 class PLMSSampler(object):
-    def __init__(self, model, schedule="linear", **kwargs):
+    def __init__(self, model, schedule="linear", device="cuda", **kwargs):
         super().__init__()
         self.model = model
         self.ddpm_num_timesteps = model.num_timesteps
         self.schedule = schedule
+        self.device = device
 
     def register_buffer(self, name, attr):
+        if type(attr) == torch.Tensor:
+            if attr.device != torch.device(self.device):
+                attr = attr.to(torch.device(self.device))
+
         setattr(self, name, attr)
 
     def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
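The PLMS sampler receives the same treatment. For context, here is a hedged sketch of how calling code might pass the new device argument to both samplers; the module paths are assumed from the class names above, and _StubModel is a hypothetical stand-in for the real diffusion model (the samplers only read its num_timesteps attribute at construction time).

import torch
from ldm.models.diffusion.ddim import DDIMSampler   # path assumed, not shown in this diff
from ldm.models.diffusion.plms import PLMSSampler   # path assumed, not shown in this diff

class _StubModel:
    # Hypothetical stand-in exposing the one attribute the constructors read.
    num_timesteps = 1000

device = "cuda" if torch.cuda.is_available() else "cpu"
model = _StubModel()

# Both samplers now accept an explicit device instead of hard-coding CUDA.
ddim = DDIMSampler(model, schedule="linear", device=device)
plms = PLMSSampler(model, schedule="linear", device=device)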