diff --git a/ldm/dream/generator/base.py b/ldm/dream/generator/base.py index 08dcd3aa81..5f88c02c4e 100644 --- a/ldm/dream/generator/base.py +++ b/ldm/dream/generator/base.py @@ -124,8 +124,8 @@ class Generator(): raise NotImplementedError("get_noise() must be implemented in a descendent class") def get_perlin_noise(self,width,height): - return torch.stack([rand_perlin_2d((height, width), (8, 8)).to(self.model.device) for _ in range(self.latent_channels)], dim=0) - + fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device + return torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device) def new_seed(self): self.seed = random.randrange(0, np.iinfo(np.uint32).max) diff --git a/ldm/util.py b/ldm/util.py index 94ad492e22..298c3141d6 100644 --- a/ldm/util.py +++ b/ldm/util.py @@ -214,15 +214,19 @@ def parallel_data_prefetch( else: return gather_res -def rand_perlin_2d(shape, res, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3): +def rand_perlin_2d(shape, res, device, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3): delta = (res[0] / shape[0], res[1] / shape[1]) d = (shape[0] // res[0], shape[1] // res[1]) - grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1]), indexing='ij'), dim = -1) % 1 - angles = 2*math.pi*torch.rand(res[0]+1, res[1]+1) + grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1]), indexing='ij'), dim = -1).to(device) % 1 + + rand_val = torch.rand(res[0]+1, res[1]+1) + + angles = 2*math.pi*rand_val gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1) tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1) + dot = lambda grad, shift: (torch.stack((grid[:shape[0],:shape[1],0] + shift[0], grid[:shape[0],:shape[1], 1] + shift[1] ), dim = -1) * grad[:shape[0], :shape[1]]).sum(dim = -1) n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])