diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index bfe2c99cc4..b82550bf23 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -151,7 +151,7 @@ class T2I: self.grid = grid self.ddim_eta = ddim_eta self.precision = precision - self.full_precision = full_precision + self.full_precision = True if choose_torch_device() == 'mps' else full_precision self.strength = strength self.embedding_path = embedding_path self.device_type = device_type diff --git a/scripts/dream.py b/scripts/dream.py index 17e481eb33..0c101503a9 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -9,7 +9,6 @@ import sys import copy import warnings import time -from ldm.dream.devices import choose_torch_device import ldm.dream.readline from ldm.dream.pngwriter import PngWriter, PromptFormatter from ldm.dream.server import DreamServer, ThreadingDreamServer @@ -391,9 +390,7 @@ def create_argv_parser(): '--full_precision', dest='full_precision', action='store_true', - help='Use slower full precision math for calculations', - # MPS only functions with full precision, see https://github.com/lstein/stable-diffusion/issues/237 - default=choose_torch_device() == 'mps', + help='Use more memory-intensive full precision math for calculations', ) parser.add_argument( '-g',