Undo more 'cuda' hardcoding

This commit is contained in:
Benjamin Warner
2022-08-24 00:39:25 -05:00
parent de1cea92ce
commit 886f1c0138
6 changed files with 5 additions and 6 deletions

View File

@ -17,9 +17,6 @@ class DDIMSampler(object):
self.schedule = schedule self.schedule = schedule
def register_buffer(self, name, attr): def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr) setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):

View File

@ -16,9 +16,6 @@ class PLMSSampler(object):
self.schedule = schedule self.schedule = schedule
def register_buffer(self, name, attr): def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr) setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):

View File

@ -432,6 +432,8 @@ The vast majority of these arguments default to reasonable values.
self.device = torch.device(self.device) if torch.cuda.is_available() else torch.device("cpu") self.device = torch.device(self.device) if torch.cuda.is_available() else torch.device("cpu")
model = self._load_model_from_config(config,self.weights) model = self._load_model_from_config(config,self.weights)
self.model = model.to(self.device) self.model = model.to(self.device)
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
self.model.cond_stage_model.device = self.device
except AttributeError: except AttributeError:
raise SystemExit raise SystemExit

1
src/clip Submodule

Submodule src/clip added at d50d76daa6

1
src/k-diffusion Submodule

Submodule src/k-diffusion added at db57990687

Submodule src/taming-transformers added at 24268930bf