diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py index bac7bbb333..25cd281cfe 100644 --- a/ldm/invoke/generator/base.py +++ b/ldm/invoke/generator/base.py @@ -237,7 +237,8 @@ class Generator: def get_perlin_noise(self,width,height): fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device - return torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device) + noise = torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device) + return noise def new_seed(self): self.seed = random.randrange(0, np.iinfo(np.uint32).max) diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py index e356f719c4..1dba0cfafb 100644 --- a/ldm/invoke/generator/txt2img2img.py +++ b/ldm/invoke/generator/txt2img2img.py @@ -90,9 +90,9 @@ class Txt2Img2Img(Generator): def get_noise_like(self, like: torch.Tensor): device = like.device if device.type == 'mps': - x = torch.randn_like(like, device='cpu').to(device) + x = torch.randn_like(like, device='cpu', dtype=self.torch_dtype()).to(device) else: - x = torch.randn_like(like, device=device) + x = torch.randn_like(like, device=device, dtype=self.torch_dtype()) if self.perlin > 0.0: shape = like.shape x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2]) @@ -117,10 +117,12 @@ class Txt2Img2Img(Generator): self.latent_channels, scaled_height // self.downsampling_factor, scaled_width // self.downsampling_factor], - device='cpu').to(device) + dtype=self.torch_dtype(), + device='cpu').to(device) else: return torch.randn([1, self.latent_channels, scaled_height // self.downsampling_factor, scaled_width // self.downsampling_factor], - device=device) + dtype=self.torch_dtype(), + device=device) diff --git a/ldm/invoke/model_manager.py b/ldm/invoke/model_manager.py index bc19ba1449..f74706aaef 100644 --- a/ldm/invoke/model_manager.py +++ b/ldm/invoke/model_manager.py @@ -349,7 +349,7 @@ class ModelManager(object): if self.precision == 'float16': print(' | Using faster float16 precision') - model.to(torch.float16) + model = model.to(torch.float16) else: print(' | Using more accurate float32 precision') diff --git a/ldm/util.py b/ldm/util.py index 282a56c3e5..7d44dcd266 100644 --- a/ldm/util.py +++ b/ldm/util.py @@ -8,6 +8,7 @@ from threading import Thread from urllib import request from tqdm import tqdm from pathlib import Path +from ldm.invoke.devices import torch_dtype import numpy as np import torch @@ -235,7 +236,8 @@ def rand_perlin_2d(shape, res, device, fade = lambda t: 6*t**5 - 15*t**4 + 10*t* n01 = dot(tile_grads([0, -1],[1, None]), [0, -1]).to(device) n11 = dot(tile_grads([1, None], [1, None]), [-1,-1]).to(device) t = fade(grid[:shape[0], :shape[1]]) - return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device) + noise = math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device) + return noise.to(dtype=torch_dtype(device)) def ask_user(question: str, answers: list): from itertools import chain, repeat