From 60b731e7abb151ff02ab8521bc565315eda4d54a Mon Sep 17 00:00:00 2001 From: Any-Winter-4079 <50542132+Any-Winter-4079@users.noreply.github.com> Date: Thu, 15 Sep 2022 17:02:17 +0200 Subject: [PATCH] Update dream.py. k_euler_a and k_dpm_2_a M1 fix (#579) * Update dream.py. k_euler_a and k_dpm_2_a M1 fix Make results reproducible (so runs with the same seed produce the same result). Implements fix by @wbowling referenced in https://github.com/lstein/stable-diffusion/issues/397#issuecomment-1240679294 * Update dream.py. Remove import torch from dream.py * generate.py: k_euler_a and k_dpm_2_a M1 fix #579 Co-authored-by: Lincoln Stein --- ldm/generate.py | 18 ++++++++++++++++++ scripts/dream.py | 1 - 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/ldm/generate.py b/ldm/generate.py index 1bb8e33eb9..fc622c7cad 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -28,6 +28,24 @@ from ldm.dream.image_util import InitImageResizer from ldm.dream.devices import choose_torch_device from ldm.dream.conditioning import get_uc_and_c +def fix_func(orig): + if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + def new_func(*args, **kw): + device = kw.get("device", "mps") + kw["device"]="cpu" + return orig(*args, **kw).to(device) + return new_func + return orig + +torch.rand = fix_func(torch.rand) +torch.rand_like = fix_func(torch.rand_like) +torch.randn = fix_func(torch.randn) +torch.randn_like = fix_func(torch.randn_like) +torch.randint = fix_func(torch.randint) +torch.randint_like = fix_func(torch.randint_like) +torch.bernoulli = fix_func(torch.bernoulli) +torch.multinomial = fix_func(torch.multinomial) + """Simplified text to image API for stable diffusion/latent diffusion Example Usage: diff --git a/scripts/dream.py b/scripts/dream.py index 4044af7cb8..e5a48f38a6 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -20,7 +20,6 @@ from omegaconf import OmegaConf # Just want to get the formatting look right for now. output_cntr = 0 - def main(): """Initialize command-line parsers and the diffusion model""" arg_parser = create_argv_parser()