move autocast device selection to a function

This commit is contained in:
Jason Toffaletti 2022-08-31 22:21:14 -07:00
parent fa98601bfb
commit 09bd9fa47e
2 changed files with 9 additions and 5 deletions

View File

@ -8,4 +8,10 @@ def choose_torch_device() -> str:
return 'mps'
return 'cpu'
def choose_autocast_device(device) -> str:
'''Returns an autocast compatible device from a torch device'''
device_type = device.type # this returns 'mps' on M1
# autocast only supports cuda or cpu
if device_type != 'cuda' or device_type != 'cpu':
return 'cpu'
return device_type

View File

@ -27,7 +27,7 @@ from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ksampler import KSampler
from ldm.dream.pngwriter import PngWriter
from ldm.dream.devices import choose_torch_device
from ldm.dream.devices import choose_autocast_device, choose_torch_device
"""Simplified text to image API for stable diffusion/latent diffusion
@ -315,9 +315,7 @@ class T2I:
callback=step_callback,
)
device_type = self.device.type # this returns 'mps' on M1
if device_type != 'cuda' or device_type != 'cpu':
device_type = 'cpu'
device_type = choose_autocast_device(self.device)
with scope(device_type), self.model.ema_scope():
for n in trange(iterations, desc='Generating'):
seed_everything(seed)