From 0bc6779361ffb8a31e68a4a817a14c7769631290 Mon Sep 17 00:00:00 2001 From: Mihai <299015+mh-dm@users.noreply.github.com> Date: Mon, 12 Sep 2022 23:55:21 +0300 Subject: [PATCH] Disable autocast for cpu to fix error. Remove unused precision arg. (#518) When running on just cpu (intel), a call to torch.layer_norm would error with RuntimeError: expected scalar type BFloat16 but found Float Fix buggy device handling in model.py. Tested with scripts/dream.py --full_precision on just cpu on intel laptop. Works but slow at ~10s/it. --- ldm/dream/devices.py | 5 +++-- ldm/generate.py | 2 -- ldm/modules/diffusionmodules/model.py | 6 ++---- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/ldm/dream/devices.py b/ldm/dream/devices.py index 90bc9e97dd..3b85a7420c 100644 --- a/ldm/dream/devices.py +++ b/ldm/dream/devices.py @@ -13,8 +13,9 @@ def choose_torch_device() -> str: def choose_autocast_device(device): '''Returns an autocast compatible device from a torch device''' device_type = device.type # this returns 'mps' on M1 - # autocast only supports cuda or cpu - if device_type in ('cuda','cpu'): + if device_type == 'cuda': return device_type,autocast + elif device_type == 'cpu': + return device_type,nullcontext else: return 'cpu',nullcontext diff --git a/ldm/generate.py b/ldm/generate.py index 89bbb470f5..52c8846d80 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -111,7 +111,6 @@ class Generate: height = 512, sampler_name = 'k_lms', ddim_eta = 0.0, # deterministic - precision = 'autocast', full_precision = False, strength = 0.75, # default in scripts/img2img.py seamless = False, @@ -129,7 +128,6 @@ class Generate: self.sampler_name = sampler_name self.grid = grid self.ddim_eta = ddim_eta - self.precision = precision self.full_precision = True if choose_torch_device() == 'mps' else full_precision self.strength = strength self.seamless = seamless diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index 970f6aad8f..5880452d47 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -209,8 +209,7 @@ class AttnBlock(nn.Module): h_ = torch.zeros_like(k, device=q.device) - device_type = 'mps' if q.device.type == 'mps' else 'cuda' - if device_type == 'cuda': + if q.device.type == 'cuda': stats = torch.cuda.memory_stats(q.device) mem_active = stats['active_bytes.all.current'] mem_reserved = stats['reserved_bytes.all.current'] @@ -612,9 +611,8 @@ class Decoder(nn.Module): del h3 # prepare for up sampling - device_type = 'mps' if h.device.type == 'mps' else 'cuda' gc.collect() - if device_type == 'cuda': + if h.device.type == 'cuda': torch.cuda.empty_cache() # upsampling