mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Undo more 'cuda' hardcoding
commit 886f1c0138
parent de1cea92ce
@@ -17,9 +17,6 @@ class DDIMSampler(object):
         self.schedule = schedule
 
     def register_buffer(self, name, attr):
-        if type(attr) == torch.Tensor:
-            if attr.device != torch.device("cuda"):
-                attr = attr.to(torch.device("cuda"))
         setattr(self, name, attr)
 
     def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
@@ -16,9 +16,6 @@ class PLMSSampler(object):
         self.schedule = schedule
 
     def register_buffer(self, name, attr):
-        if type(attr) == torch.Tensor:
-            if attr.device != torch.device("cuda"):
-                attr = attr.to(torch.device("cuda"))
         setattr(self, name, attr)
 
     def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
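
The two hunks above stop register_buffer from forcing every tensor onto torch.device("cuda"). As a rough sketch of the device-agnostic pattern this enables (the MinimalSampler class and its device argument below are illustrative assumptions, not code from this repository), the target device can be resolved from the model and tensors moved only when needed:

    import torch

    class MinimalSampler:
        # Illustrative stand-in for DDIMSampler / PLMSSampler; not repository code.
        def __init__(self, model, device=None):
            self.model = model
            # Derive the device from the model instead of hardcoding "cuda",
            # so CPU (and other backends) keep working.
            self.device = device if device is not None else next(model.parameters()).device

        def register_buffer(self, name, attr):
            # Only move tensors that are not already on the chosen device.
            if isinstance(attr, torch.Tensor) and attr.device != self.device:
                attr = attr.to(self.device)
            setattr(self, name, attr)

    # Usage sketch: buffers follow whatever device the model lives on.
    model = torch.nn.Linear(4, 4)          # stays on CPU in this example
    sampler = MinimalSampler(model)
    sampler.register_buffer("alphas", torch.ones(10))
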
@@ -432,6 +432,8 @@ The vast majority of these arguments default to reasonable values.
             self.device = torch.device(self.device) if torch.cuda.is_available() else torch.device("cpu")
             model = self._load_model_from_config(config,self.weights)
             self.model = model.to(self.device)
+            # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
+            self.model.cond_stage_model.device = self.device
         except AttributeError:
             raise SystemExit
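
The pattern in this hunk is: pick the compute device once, falling back to CPU when CUDA is unavailable, then push it explicitly into components that model.to() does not update. A minimal sketch of that idea follows; choose_device, DummyCondStage, and DummyModel are hypothetical names used only for illustration, not code from this repository.

    import torch

    def choose_device(requested: str = "cuda") -> torch.device:
        # Fall back to CPU when the requested accelerator is not available.
        if requested == "cuda" and not torch.cuda.is_available():
            return torch.device("cpu")
        return torch.device(requested)

    class DummyCondStage:
        # Stand-in for a text encoder that keeps its own .device attribute.
        def __init__(self):
            self.device = torch.device("cpu")

    class DummyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)
            self.cond_stage_model = DummyCondStage()

    device = choose_device("cuda")
    model = DummyModel().to(device)
    # .to() moves parameters and registered buffers, but it does not touch a plain
    # attribute such as cond_stage_model.device, so that is assigned explicitly,
    # mirroring the comment in the hunk above.
    model.cond_stage_model.device = device
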
src/clip (Submodule)
@@ -0,0 +1 @@
+Subproject commit d50d76daa670286dd6cacf3bcd80b5e4823fc8e1

src/k-diffusion (Submodule)
@@ -0,0 +1 @@
+Subproject commit db5799068749bf3a6d5845120ed32df16b7d883b

src/taming-transformers (Submodule)
@@ -0,0 +1 @@
+Subproject commit 24268930bf1dce879235a7fddd0b2355b84d7ea6