mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Undo more 'cuda' hardcoding
This commit is contained in:
@ -17,9 +17,6 @@ class DDIMSampler(object):
|
|||||||
self.schedule = schedule
|
self.schedule = schedule
|
||||||
|
|
||||||
def register_buffer(self, name, attr):
|
def register_buffer(self, name, attr):
|
||||||
if type(attr) == torch.Tensor:
|
|
||||||
if attr.device != torch.device("cuda"):
|
|
||||||
attr = attr.to(torch.device("cuda"))
|
|
||||||
setattr(self, name, attr)
|
setattr(self, name, attr)
|
||||||
|
|
||||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||||
|
@ -16,9 +16,6 @@ class PLMSSampler(object):
|
|||||||
self.schedule = schedule
|
self.schedule = schedule
|
||||||
|
|
||||||
def register_buffer(self, name, attr):
|
def register_buffer(self, name, attr):
|
||||||
if type(attr) == torch.Tensor:
|
|
||||||
if attr.device != torch.device("cuda"):
|
|
||||||
attr = attr.to(torch.device("cuda"))
|
|
||||||
setattr(self, name, attr)
|
setattr(self, name, attr)
|
||||||
|
|
||||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||||
|
@ -432,6 +432,8 @@ The vast majority of these arguments default to reasonable values.
|
|||||||
self.device = torch.device(self.device) if torch.cuda.is_available() else torch.device("cpu")
|
self.device = torch.device(self.device) if torch.cuda.is_available() else torch.device("cpu")
|
||||||
model = self._load_model_from_config(config,self.weights)
|
model = self._load_model_from_config(config,self.weights)
|
||||||
self.model = model.to(self.device)
|
self.model = model.to(self.device)
|
||||||
|
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
||||||
|
self.model.cond_stage_model.device = self.device
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise SystemExit
|
raise SystemExit
|
||||||
|
|
||||||
|
1
src/clip
Submodule
1
src/clip
Submodule
Submodule src/clip added at d50d76daa6
1
src/k-diffusion
Submodule
1
src/k-diffusion
Submodule
Submodule src/k-diffusion added at db57990687
1
src/taming-transformers
Submodule
1
src/taming-transformers
Submodule
Submodule src/taming-transformers added at 24268930bf
Reference in New Issue
Block a user