Fixes PYTORCH_ENABLE_MPS_FALLBACK not set correctly (#2508)

`torch` wasn't seeing the environment variable. I suspect this is
because it was imported before the variable was set, so was running with
a different environment.

Many `torch` ops are supported on MPS so this wasn't noticed
immediately, but some samplers like k_dpm_2 still use unsupported
operations and need this fallback.
This commit is contained in:
Lincoln Stein 2023-02-04 11:32:52 -05:00 committed by GitHub
commit 3b58413d9f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 6 deletions

View File

@ -4,6 +4,9 @@ import sys
import shlex
import traceback
if sys.platform == "darwin":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
from ldm.invoke.globals import Globals
from ldm.generate import Generate
from ldm.invoke.prompt_parser import PromptParser
@ -21,9 +24,6 @@ import ldm.invoke
# global used in multiple functions (fix)
infile = None
if sys.platform == 'darwin':
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
def main():
"""Initialize command-line parsers and the diffusion model"""
global infile

View File

@ -1,7 +1,4 @@
#!/usr/bin/env python
import sys
import os
import ldm.invoke.CLI
ldm.invoke.CLI.main()