mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add support for Apple hardware using MPS acceleration
This commit is contained in:
@ -18,6 +18,7 @@ from pytorch_lightning import seed_everything
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from ldm.dream.devices import choose_torch_device
|
||||
|
||||
|
||||
def chunk(it, size):
|
||||
@ -40,7 +41,7 @@ def load_model_from_config(config, ckpt, verbose=False):
|
||||
print("unexpected keys:")
|
||||
print(u)
|
||||
|
||||
model.cuda()
|
||||
model.to(choose_torch_device())
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
@ -199,7 +200,7 @@ def main():
|
||||
config = OmegaConf.load(f"{opt.config}")
|
||||
model = load_model_from_config(config, f"{opt.ckpt}")
|
||||
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
device = choose_torch_device()
|
||||
model = model.to(device)
|
||||
|
||||
if opt.plms:
|
||||
@ -241,8 +242,10 @@ def main():
|
||||
print(f"target t_enc is {t_enc} steps")
|
||||
|
||||
precision_scope = autocast if opt.precision == "autocast" else nullcontext
|
||||
if device.type in ['mps', 'cpu']:
|
||||
precision_scope = nullcontext # have to use f32 on mps
|
||||
with torch.no_grad():
|
||||
with precision_scope("cuda"):
|
||||
with precision_scope(device.type):
|
||||
with model.ema_scope():
|
||||
tic = time.time()
|
||||
all_samples = list()
|
||||
|
@ -15,10 +15,10 @@ from contextlib import contextmanager, nullcontext
|
||||
import k_diffusion as K
|
||||
import torch.nn as nn
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
|
||||
from ldm.dream.devices import choose_torch_device
|
||||
|
||||
def chunk(it, size):
|
||||
it = iter(it)
|
||||
@ -40,7 +40,7 @@ def load_model_from_config(config, ckpt, verbose=False):
|
||||
print("unexpected keys:")
|
||||
print(u)
|
||||
|
||||
model.cuda()
|
||||
model.to(choose_torch_device())
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
@ -190,13 +190,14 @@ def main():
|
||||
opt.ckpt = "models/ldm/text2img-large/model.ckpt"
|
||||
opt.outdir = "outputs/txt2img-samples-laion400m"
|
||||
|
||||
seed_everything(opt.seed)
|
||||
|
||||
config = OmegaConf.load(f"{opt.config}")
|
||||
model = load_model_from_config(config, f"{opt.ckpt}")
|
||||
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
model = model.to(device)
|
||||
seed_everything(opt.seed)
|
||||
|
||||
device = torch.device(choose_torch_device())
|
||||
model = model.to(device)
|
||||
|
||||
#for klms
|
||||
model_wrap = K.external.CompVisDenoiser(model)
|
||||
@ -240,11 +241,17 @@ def main():
|
||||
|
||||
start_code = None
|
||||
if opt.fixed_code:
|
||||
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
|
||||
shape = [opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f]
|
||||
if device.type == 'mps':
|
||||
start_code = torch.randn(shape, device='cpu').to(device)
|
||||
else:
|
||||
torch.randn(shape, device=device)
|
||||
|
||||
precision_scope = autocast if opt.precision=="autocast" else nullcontext
|
||||
if device.type in ['mps', 'cpu']:
|
||||
precision_scope = nullcontext # have to use f32 on mps
|
||||
with torch.no_grad():
|
||||
with precision_scope("cuda"):
|
||||
with precision_scope(device.type):
|
||||
with model.ema_scope():
|
||||
tic = time.time()
|
||||
all_samples = list()
|
||||
|
Reference in New Issue
Block a user