add support for Apple hardware using MPS acceleration

This commit is contained in:
Lincoln Stein
2022-08-31 00:33:23 -04:00
parent 1714816fe2
commit bdb0651eb2
16 changed files with 361 additions and 52 deletions

View File

@ -18,6 +18,7 @@ from pytorch_lightning import seed_everything
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.dream.devices import choose_torch_device
def chunk(it, size):
@ -40,7 +41,7 @@ def load_model_from_config(config, ckpt, verbose=False):
print("unexpected keys:")
print(u)
model.cuda()
model.to(choose_torch_device())
model.eval()
return model
@ -199,7 +200,7 @@ def main():
config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = choose_torch_device()
model = model.to(device)
if opt.plms:
@ -241,8 +242,10 @@ def main():
print(f"target t_enc is {t_enc} steps")
precision_scope = autocast if opt.precision == "autocast" else nullcontext
if device.type in ['mps', 'cpu']:
precision_scope = nullcontext # have to use f32 on mps
with torch.no_grad():
with precision_scope("cuda"):
with precision_scope(device.type):
with model.ema_scope():
tic = time.time()
all_samples = list()

View File

@ -15,10 +15,10 @@ from contextlib import contextmanager, nullcontext
import k_diffusion as K
import torch.nn as nn
from ldm.util import instantiate_from_config
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.dream.devices import choose_torch_device
def chunk(it, size):
it = iter(it)
@ -40,7 +40,7 @@ def load_model_from_config(config, ckpt, verbose=False):
print("unexpected keys:")
print(u)
model.cuda()
model.to(choose_torch_device())
model.eval()
return model
@ -190,13 +190,14 @@ def main():
opt.ckpt = "models/ldm/text2img-large/model.ckpt"
opt.outdir = "outputs/txt2img-samples-laion400m"
seed_everything(opt.seed)
config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
seed_everything(opt.seed)
device = torch.device(choose_torch_device())
model = model.to(device)
#for klms
model_wrap = K.external.CompVisDenoiser(model)
@ -240,11 +241,17 @@ def main():
start_code = None
if opt.fixed_code:
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
shape = [opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f]
if device.type == 'mps':
start_code = torch.randn(shape, device='cpu').to(device)
else:
torch.randn(shape, device=device)
precision_scope = autocast if opt.precision=="autocast" else nullcontext
if device.type in ['mps', 'cpu']:
precision_scope = nullcontext # have to use f32 on mps
with torch.no_grad():
with precision_scope("cuda"):
with precision_scope(device.type):
with model.ema_scope():
tic = time.time()
all_samples = list()