From a1b1a48fb372a643637c5ddf5ffd608f594e5b89 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 4 Feb 2023 17:27:33 +1100 Subject: [PATCH] Fixes `PYTORCH_ENABLE_MPS_FALLBACK` not set correctly `torch` wasn't seeing the environment variable. I suspect this is because it was imported before the variable was set, so was running with a different environment. Many `torch` ops are supported on MPS so this wasn't noticed immediately, but some samplers like k_dpm_2 still use unsupported operations and need this fallback. --- ldm/invoke/CLI.py | 6 +++--- scripts/invoke.py | 3 --- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/ldm/invoke/CLI.py b/ldm/invoke/CLI.py index 2d673a9112..ca4153f53d 100644 --- a/ldm/invoke/CLI.py +++ b/ldm/invoke/CLI.py @@ -4,6 +4,9 @@ import sys import shlex import traceback +if sys.platform == "darwin": + os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + from ldm.invoke.globals import Globals from ldm.generate import Generate from ldm.invoke.prompt_parser import PromptParser @@ -21,9 +24,6 @@ import ldm.invoke # global used in multiple functions (fix) infile = None -if sys.platform == 'darwin': - os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" - def main(): """Initialize command-line parsers and the diffusion model""" global infile diff --git a/scripts/invoke.py b/scripts/invoke.py index 710cea3830..7431300f15 100755 --- a/scripts/invoke.py +++ b/scripts/invoke.py @@ -1,7 +1,4 @@ #!/usr/bin/env python -import sys -import os - import ldm.invoke.CLI ldm.invoke.CLI.main()