diff --git a/ldm/generate.py b/ldm/generate.py index 1bb8e33eb9..fc622c7cad 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -28,6 +28,24 @@ from ldm.dream.image_util import InitImageResizer from ldm.dream.devices import choose_torch_device from ldm.dream.conditioning import get_uc_and_c +def fix_func(orig): + if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + def new_func(*args, **kw): + device = kw.get("device", "mps") + kw["device"]="cpu" + return orig(*args, **kw).to(device) + return new_func + return orig + +torch.rand = fix_func(torch.rand) +torch.rand_like = fix_func(torch.rand_like) +torch.randn = fix_func(torch.randn) +torch.randn_like = fix_func(torch.randn_like) +torch.randint = fix_func(torch.randint) +torch.randint_like = fix_func(torch.randint_like) +torch.bernoulli = fix_func(torch.bernoulli) +torch.multinomial = fix_func(torch.multinomial) + """Simplified text to image API for stable diffusion/latent diffusion Example Usage: diff --git a/scripts/dream.py b/scripts/dream.py index 4044af7cb8..e5a48f38a6 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -20,7 +20,6 @@ from omegaconf import OmegaConf # Just want to get the formatting look right for now. output_cntr = 0 - def main(): """Initialize command-line parsers and the diffusion model""" arg_parser = create_argv_parser()