From 101cac6a215ad21f2a444361ead4c4e3fb8674ac Mon Sep 17 00:00:00 2001 From: Jan Skurovec Date: Tue, 11 Oct 2022 22:10:29 +0200 Subject: [PATCH] reintroduce fix for m1 from PR#579 missing after merge Make results reproducible (so runs with the same seed produce the same result). Implements fix by @wbowling referenced in https://github.com/invoke-ai/InvokeAI/issues/397#issuecomment-1240679294 --- ldm/generate.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/ldm/generate.py b/ldm/generate.py index 22fb27e0a8..f9fc364cf3 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -34,6 +34,24 @@ from ldm.invoke.image_util import InitImageResizer from ldm.invoke.devices import choose_torch_device, choose_precision from ldm.invoke.conditioning import get_uc_and_c +def fix_func(orig): + if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + def new_func(*args, **kw): + device = kw.get("device", "mps") + kw["device"]="cpu" + return orig(*args, **kw).to(device) + return new_func + return orig + +torch.rand = fix_func(torch.rand) +torch.rand_like = fix_func(torch.rand_like) +torch.randn = fix_func(torch.randn) +torch.randn_like = fix_func(torch.randn_like) +torch.randint = fix_func(torch.randint) +torch.randint_like = fix_func(torch.randint_like) +torch.bernoulli = fix_func(torch.bernoulli) +torch.multinomial = fix_func(torch.multinomial) + """Simplified text to image API for stable diffusion/latent diffusion Example Usage: