From 8254ca9492df0efb45c62c3251b8c6d7927d44bb Mon Sep 17 00:00:00 2001 From: wfng92 <43742196+wfng92@users.noreply.github.com> Date: Sat, 22 Oct 2022 09:03:31 +0800 Subject: [PATCH] Removed duplicate fix_func for MPS --- ldm/generate.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/ldm/generate.py b/ldm/generate.py index fd43093e6f..cfd2f57b45 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -58,24 +58,6 @@ torch.multinomial = fix_func(torch.multinomial) # this is fallback model in case no default is defined FALLBACK_MODEL_NAME='stable-diffusion-1.4' -def fix_func(orig): - if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): - def new_func(*args, **kw): - device = kw.get("device", "mps") - kw["device"]="cpu" - return orig(*args, **kw).to(device) - return new_func - return orig - -torch.rand = fix_func(torch.rand) -torch.rand_like = fix_func(torch.rand_like) -torch.randn = fix_func(torch.randn) -torch.randn_like = fix_func(torch.randn_like) -torch.randint = fix_func(torch.randint) -torch.randint_like = fix_func(torch.randint_like) -torch.bernoulli = fix_func(torch.bernoulli) -torch.multinomial = fix_func(torch.multinomial) - """Simplified text to image API for stable diffusion/latent diffusion Example Usage: