diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index 78876a0919..d10676c841 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -49,9 +49,15 @@ class Upsample(nn.Module): padding=1) def forward(self, x): + cpu_m1_cond = True if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and \ + x.size()[0] * x.size()[1] * x.size()[2] * x.size()[3] % 2**27 == 0 else False + if cpu_m1_cond: + x = x.to('cpu') # send to cpu x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") if self.with_conv: x = self.conv(x) + if cpu_m1_cond: + x = x.to('mps') # return to mps return x @@ -117,6 +123,14 @@ class ResnetBlock(nn.Module): padding=0) def forward(self, x, temb): + if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + x_size = x.size() + if (x_size[0] * x_size[1] * x_size[2] * x_size[3]) % 2**29 == 0: + self.to('cpu') + x = x.to('cpu') + else: + self.to('mps') + x = x.to('mps') h = self.norm1(x) h = silu(h) h = self.conv1(h)