From 09bd9fa47eb96455d29fb575e028cfaad5b39ca7 Mon Sep 17 00:00:00 2001 From: Jason Toffaletti Date: Wed, 31 Aug 2022 22:21:14 -0700 Subject: [PATCH] move autocast device selection to a function --- ldm/dream/devices.py | 8 +++++++- ldm/simplet2i.py | 6 ++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/ldm/dream/devices.py b/ldm/dream/devices.py index 240754dd36..9581abe78c 100644 --- a/ldm/dream/devices.py +++ b/ldm/dream/devices.py @@ -8,4 +8,10 @@ def choose_torch_device() -> str: return 'mps' return 'cpu' - +def choose_autocast_device(device) -> str: + '''Returns an autocast compatible device from a torch device''' + device_type = device.type # this returns 'mps' on M1 + # autocast only supports cuda or cpu + if device_type != 'cuda' or device_type != 'cpu': + return 'cpu' + return device_type diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index c0d2886538..ccecee1a46 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -27,7 +27,7 @@ from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.ksampler import KSampler from ldm.dream.pngwriter import PngWriter -from ldm.dream.devices import choose_torch_device +from ldm.dream.devices import choose_autocast_device, choose_torch_device """Simplified text to image API for stable diffusion/latent diffusion @@ -315,9 +315,7 @@ class T2I: callback=step_callback, ) - device_type = self.device.type # this returns 'mps' on M1 - if device_type != 'cuda' or device_type != 'cpu': - device_type = 'cpu' + device_type = choose_autocast_device(self.device) with scope(device_type), self.model.ema_scope(): for n in trange(iterations, desc='Generating'): seed_everything(seed)