Undo more 'cuda' hardcoding

This commit is contained in:
Benjamin Warner 2022-08-24 00:39:25 -05:00
parent de1cea92ce
commit 886f1c0138
6 changed files with 5 additions and 6 deletions

View File

@ -17,9 +17,6 @@ class DDIMSampler(object):
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):

View File

@ -16,9 +16,6 @@ class PLMSSampler(object):
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):

View File

@ -432,6 +432,8 @@ The vast majority of these arguments default to reasonable values.
self.device = torch.device(self.device) if torch.cuda.is_available() else torch.device("cpu")
model = self._load_model_from_config(config,self.weights)
self.model = model.to(self.device)
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
self.model.cond_stage_model.device = self.device
except AttributeError:
raise SystemExit

1
src/clip Submodule

@ -0,0 +1 @@
Subproject commit d50d76daa670286dd6cacf3bcd80b5e4823fc8e1

1
src/k-diffusion Submodule

@ -0,0 +1 @@
Subproject commit db5799068749bf3a6d5845120ed32df16b7d883b

@ -0,0 +1 @@
Subproject commit 24268930bf1dce879235a7fddd0b2355b84d7ea6