diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index b82550bf23..e7b7b524b0 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -348,13 +348,23 @@ class T2I: def get_noise(): if init_img: - return torch.randn_like(init_latent, device=self.device) + if self.device.type == 'mps': + return torch.randn_like(init_latent, device='cpu').to(self.device) + else: + return torch.randn_like(init_latent, device=self.device) else: - return torch.randn([1, - self.latent_channels, - height // self.downsampling_factor, - width // self.downsampling_factor], - device=self.device) + if self.device.type == 'mps': + return torch.randn([1, + self.latent_channels, + height // self.downsampling_factor, + width // self.downsampling_factor], + device='cpu').to(self.device) + else: + return torch.randn([1, + self.latent_channels, + height // self.downsampling_factor, + width // self.downsampling_factor], + device=self.device) initial_noise = None if variation_amount > 0 or len(with_variations) > 0: @@ -383,6 +393,8 @@ class T2I: x_T = initial_noise else: seed_everything(seed) + if self.device.type == 'mps': + x_T = get_noise() # make_image will do the equivalent of get_noise itself image = make_image(x_T) results.append([image, seed])