move special-casing test for precision on mps into T2I class

This commit is contained in:
Lincoln Stein 2022-09-03 09:43:18 -04:00
parent d0df894c9f
commit 91cce6b4c3
2 changed files with 2 additions and 5 deletions

View File

@ -151,7 +151,7 @@ class T2I:
self.grid = grid self.grid = grid
self.ddim_eta = ddim_eta self.ddim_eta = ddim_eta
self.precision = precision self.precision = precision
self.full_precision = full_precision self.full_precision = True if choose_torch_device() == 'mps' else full_precision
self.strength = strength self.strength = strength
self.embedding_path = embedding_path self.embedding_path = embedding_path
self.device_type = device_type self.device_type = device_type

View File

@ -9,7 +9,6 @@ import sys
import copy import copy
import warnings import warnings
import time import time
from ldm.dream.devices import choose_torch_device
import ldm.dream.readline import ldm.dream.readline
from ldm.dream.pngwriter import PngWriter, PromptFormatter from ldm.dream.pngwriter import PngWriter, PromptFormatter
from ldm.dream.server import DreamServer, ThreadingDreamServer from ldm.dream.server import DreamServer, ThreadingDreamServer
@ -391,9 +390,7 @@ def create_argv_parser():
'--full_precision', '--full_precision',
dest='full_precision', dest='full_precision',
action='store_true', action='store_true',
help='Use slower full precision math for calculations', help='Use more memory-intensive full precision math for calculations',
# MPS only functions with full precision, see https://github.com/lstein/stable-diffusion/issues/237
default=choose_torch_device() == 'mps',
) )
parser.add_argument( parser.add_argument(
'-g', '-g',