move special-casing test for precision on mps into T2I class

This commit is contained in:
Lincoln Stein 2022-09-03 09:43:18 -04:00
parent d0df894c9f
commit 91cce6b4c3
2 changed files with 2 additions and 5 deletions

View File

@ -151,7 +151,7 @@ class T2I:
self.grid = grid
self.ddim_eta = ddim_eta
self.precision = precision
self.full_precision = full_precision
self.full_precision = True if choose_torch_device() == 'mps' else full_precision
self.strength = strength
self.embedding_path = embedding_path
self.device_type = device_type

View File

@ -9,7 +9,6 @@ import sys
import copy
import warnings
import time
from ldm.dream.devices import choose_torch_device
import ldm.dream.readline
from ldm.dream.pngwriter import PngWriter, PromptFormatter
from ldm.dream.server import DreamServer, ThreadingDreamServer
@ -391,9 +390,7 @@ def create_argv_parser():
'--full_precision',
dest='full_precision',
action='store_true',
help='Use slower full precision math for calculations',
# MPS only functions with full precision, see https://github.com/lstein/stable-diffusion/issues/237
default=choose_torch_device() == 'mps',
help='Use more memory-intensive full precision math for calculations',
)
parser.add_argument(
'-g',