From a7515624b28dafb68bbb3e3a500d953993e3b4fe Mon Sep 17 00:00:00 2001 From: spezialspezial <75758219+spezialspezial@users.noreply.github.com> Date: Fri, 7 Oct 2022 09:57:15 +0200 Subject: [PATCH] remove duplicated code --- ldm/generate.py | 32 -------------------------------- 1 file changed, 32 deletions(-) diff --git a/ldm/generate.py b/ldm/generate.py index c8ac16baa0..317e05b31d 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -34,23 +34,7 @@ from ldm.dream.image_util import InitImageResizer from ldm.dream.devices import choose_torch_device, choose_precision from ldm.dream.conditioning import get_uc_and_c -def fix_func(orig): - if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): - def new_func(*args, **kw): - device = kw.get("device", "mps") - kw["device"]="cpu" - return orig(*args, **kw).to(device) - return new_func - return orig -torch.rand = fix_func(torch.rand) -torch.rand_like = fix_func(torch.rand_like) -torch.randn = fix_func(torch.randn) -torch.randn_like = fix_func(torch.randn_like) -torch.randint = fix_func(torch.randint) -torch.randint_like = fix_func(torch.randint_like) -torch.bernoulli = fix_func(torch.bernoulli) -torch.multinomial = fix_func(torch.multinomial) def fix_func(orig): if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): @@ -70,23 +54,7 @@ torch.randint_like = fix_func(torch.randint_like) torch.bernoulli = fix_func(torch.bernoulli) torch.multinomial = fix_func(torch.multinomial) -def fix_func(orig): - if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): - def new_func(*args, **kw): - device = kw.get("device", "mps") - kw["device"]="cpu" - return orig(*args, **kw).to(device) - return new_func - return orig -torch.rand = fix_func(torch.rand) -torch.rand_like = fix_func(torch.rand_like) -torch.randn = fix_func(torch.randn) -torch.randn_like = fix_func(torch.randn_like) -torch.randint = fix_func(torch.randint) -torch.randint_like = fix_func(torch.randint_like) -torch.bernoulli = fix_func(torch.bernoulli) -torch.multinomial = fix_func(torch.multinomial) """Simplified text to image API for stable diffusion/latent diffusion