mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
move special-casing test for precision on mps into T2I class
This commit is contained in:
parent
d0df894c9f
commit
91cce6b4c3
@ -151,7 +151,7 @@ class T2I:
|
|||||||
self.grid = grid
|
self.grid = grid
|
||||||
self.ddim_eta = ddim_eta
|
self.ddim_eta = ddim_eta
|
||||||
self.precision = precision
|
self.precision = precision
|
||||||
self.full_precision = full_precision
|
self.full_precision = True if choose_torch_device() == 'mps' else full_precision
|
||||||
self.strength = strength
|
self.strength = strength
|
||||||
self.embedding_path = embedding_path
|
self.embedding_path = embedding_path
|
||||||
self.device_type = device_type
|
self.device_type = device_type
|
||||||
|
@ -9,7 +9,6 @@ import sys
|
|||||||
import copy
|
import copy
|
||||||
import warnings
|
import warnings
|
||||||
import time
|
import time
|
||||||
from ldm.dream.devices import choose_torch_device
|
|
||||||
import ldm.dream.readline
|
import ldm.dream.readline
|
||||||
from ldm.dream.pngwriter import PngWriter, PromptFormatter
|
from ldm.dream.pngwriter import PngWriter, PromptFormatter
|
||||||
from ldm.dream.server import DreamServer, ThreadingDreamServer
|
from ldm.dream.server import DreamServer, ThreadingDreamServer
|
||||||
@ -391,9 +390,7 @@ def create_argv_parser():
|
|||||||
'--full_precision',
|
'--full_precision',
|
||||||
dest='full_precision',
|
dest='full_precision',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help='Use slower full precision math for calculations',
|
help='Use more memory-intensive full precision math for calculations',
|
||||||
# MPS only functions with full precision, see https://github.com/lstein/stable-diffusion/issues/237
|
|
||||||
default=choose_torch_device() == 'mps',
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-g',
|
'-g',
|
||||||
|
Loading…
x
Reference in New Issue
Block a user