diff --git a/ldm/generate.py b/ldm/generate.py index 22fb27e0a8..f9fc364cf3 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -34,6 +34,24 @@ from ldm.invoke.image_util import InitImageResizer from ldm.invoke.devices import choose_torch_device, choose_precision from ldm.invoke.conditioning import get_uc_and_c +def fix_func(orig): + if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + def new_func(*args, **kw): + device = kw.get("device", "mps") + kw["device"]="cpu" + return orig(*args, **kw).to(device) + return new_func + return orig + +torch.rand = fix_func(torch.rand) +torch.rand_like = fix_func(torch.rand_like) +torch.randn = fix_func(torch.randn) +torch.randn_like = fix_func(torch.randn_like) +torch.randint = fix_func(torch.randint) +torch.randint_like = fix_func(torch.randint_like) +torch.bernoulli = fix_func(torch.bernoulli) +torch.multinomial = fix_func(torch.multinomial) + """Simplified text to image API for stable diffusion/latent diffusion Example Usage: