move autocast device selection to a function

This commit is contained in:
Jason Toffaletti 2022-08-31 22:21:14 -07:00
parent fa98601bfb
commit 09bd9fa47e
2 changed files with 9 additions and 5 deletions

View File

@ -8,4 +8,10 @@ def choose_torch_device() -> str:
return 'mps' return 'mps'
return 'cpu' return 'cpu'
def choose_autocast_device(device) -> str:
'''Returns an autocast compatible device from a torch device'''
device_type = device.type # this returns 'mps' on M1
# autocast only supports cuda or cpu
if device_type != 'cuda' or device_type != 'cpu':
return 'cpu'
return device_type

View File

@ -27,7 +27,7 @@ from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ksampler import KSampler from ldm.models.diffusion.ksampler import KSampler
from ldm.dream.pngwriter import PngWriter from ldm.dream.pngwriter import PngWriter
from ldm.dream.devices import choose_torch_device from ldm.dream.devices import choose_autocast_device, choose_torch_device
"""Simplified text to image API for stable diffusion/latent diffusion """Simplified text to image API for stable diffusion/latent diffusion
@ -315,9 +315,7 @@ class T2I:
callback=step_callback, callback=step_callback,
) )
device_type = self.device.type # this returns 'mps' on M1 device_type = choose_autocast_device(self.device)
if device_type != 'cuda' or device_type != 'cpu':
device_type = 'cpu'
with scope(device_type), self.model.ema_scope(): with scope(device_type), self.model.ema_scope():
for n in trange(iterations, desc='Generating'): for n in trange(iterations, desc='Generating'):
seed_everything(seed) seed_everything(seed)