Fixes PYTORCH_ENABLE_MPS_FALLBACK not set correctly

`torch` wasn't seeing the environment variable. I suspect this is because it was imported before the variable was set, so was running with a different environment.

Many `torch` ops are supported on MPS so this wasn't noticed immediately, but some samplers like k_dpm_2 still use unsupported operations and need this fallback.
This commit is contained in:
psychedelicious 2023-02-04 17:27:33 +11:00
parent b5160321bf
commit a1b1a48fb3
No known key found for this signature in database
2 changed files with 3 additions and 6 deletions

View File

@ -4,6 +4,9 @@ import sys
import shlex
import traceback
if sys.platform == "darwin":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
from ldm.invoke.globals import Globals
from ldm.generate import Generate
from ldm.invoke.prompt_parser import PromptParser
@ -21,9 +24,6 @@ import ldm.invoke
# global used in multiple functions (fix)
infile = None
if sys.platform == 'darwin':
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
def main():
"""Initialize command-line parsers and the diffusion model"""
global infile

View File

@ -1,7 +1,4 @@
#!/usr/bin/env python
import sys
import os
import ldm.invoke.CLI
ldm.invoke.CLI.main()