Merge branch 'lstein-fix-autocast' of github.com:invoke-ai/InvokeAI into lstein-fix-autocast

Lincoln Stein 2023-01-16 23:18:54 -05:00
commit bcc0110c59
4 changed files with 12 additions and 7 deletions


@@ -237,7 +237,8 @@ class Generator:
     def get_perlin_noise(self,width,height):
         fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
-        return torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
+        noise = torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
+        return noise

     def new_seed(self):
         self.seed = random.randrange(0, np.iinfo(np.uint32).max)
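For context on the `fixdevice` dance above: some tensor ops used by the Perlin generator are unreliable on Apple's MPS backend, so intermediates are bounced through the CPU before landing on the model device. A minimal sketch of that staging pattern (hypothetical helper, not from this repo):

import torch

def noise_for(device: torch.device, shape=(4, 64, 64)) -> torch.Tensor:
    # On MPS, materialize on CPU first, then move to the target device.
    staging = torch.device('cpu') if device.type == 'mps' else device
    return torch.randn(shape, device=staging).to(device)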


@@ -90,9 +90,9 @@ class Txt2Img2Img(Generator):
     def get_noise_like(self, like: torch.Tensor):
         device = like.device
         if device.type == 'mps':
-            x = torch.randn_like(like, device='cpu').to(device)
+            x = torch.randn_like(like, device='cpu', dtype=self.torch_dtype()).to(device)
         else:
-            x = torch.randn_like(like, device=device)
+            x = torch.randn_like(like, device=device, dtype=self.torch_dtype())
         if self.perlin > 0.0:
             shape = like.shape
             x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
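Worth noting about the two randn_like changes: by default torch.randn_like(like) inherits the dtype of `like`, so the explicit dtype=self.torch_dtype() pins the noise to the model's working precision even when `like` arrives in a different dtype. A quick demonstration in plain PyTorch, nothing repo-specific:

import torch

like = torch.zeros(1, 4, 64, 64, dtype=torch.float16)
print(torch.randn_like(like).dtype)                       # torch.float16, inherited from `like`
print(torch.randn_like(like, dtype=torch.float32).dtype)  # torch.float32, pinned explicitly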
@@ -117,10 +117,12 @@ class Txt2Img2Img(Generator):
                                 self.latent_channels,
                                 scaled_height // self.downsampling_factor,
                                 scaled_width // self.downsampling_factor],
+                                dtype=self.torch_dtype(),
                                 device='cpu').to(device)
         else:
             return torch.randn([1,
                                 self.latent_channels,
                                 scaled_height // self.downsampling_factor,
                                 scaled_width // self.downsampling_factor],
+                                dtype=self.torch_dtype(),
                                 device=device)
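To make the shape arithmetic concrete: with the usual Stable Diffusion values of 4 latent channels and a downsampling factor of 8 (assumed here, not stated in the diff), a 512x768 target yields a [1, 4, 64, 96] noise tensor:

import torch

latent_channels, downsampling_factor = 4, 8   # assumed SD defaults
scaled_height, scaled_width = 512, 768
noise = torch.randn([1,
                     latent_channels,
                     scaled_height // downsampling_factor,
                     scaled_width // downsampling_factor],
                    dtype=torch.float16)      # the dtype= this hunk threads through
print(noise.shape)                            # torch.Size([1, 4, 64, 96])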


@@ -349,7 +349,7 @@ class ModelManager(object):
         if self.precision == 'float16':
             print(' | Using faster float16 precision')
-            model.to(torch.float16)
+            model = model.to(torch.float16)
         else:
             print(' | Using more accurate float32 precision')
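The reassignment here reflects general PyTorch behavior rather than anything repo-specific: Tensor.to() is out-of-place, so its result must be captured, while nn.Module.to() happens to mutate in place and return self. Writing `model = model.to(...)` is therefore the idiom that is safe in both cases:

import torch

t = torch.zeros(2)
t.to(torch.float16)        # result discarded; t is unchanged
print(t.dtype)             # torch.float32
t = t.to(torch.float16)    # result captured
print(t.dtype)             # torch.float16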


@@ -8,6 +8,7 @@ from threading import Thread
 from urllib import request
 from tqdm import tqdm
 from pathlib import Path
+from ldm.invoke.devices import torch_dtype
 import numpy as np
 import torch

@@ -235,7 +236,8 @@ def rand_perlin_2d(shape, res, device, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3):
     n01 = dot(tile_grads([0, -1],[1, None]), [0, -1]).to(device)
     n11 = dot(tile_grads([1, None], [1, None]), [-1,-1]).to(device)
     t = fade(grid[:shape[0], :shape[1]])
-    return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device)
+    noise = math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device)
+    return noise.to(dtype=torch_dtype(device))

 def ask_user(question: str, answers: list):
     from itertools import chain, repeat
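The body of ldm.invoke.devices.torch_dtype is not part of this diff; below is a minimal sketch of what such a helper plausibly does, assuming it selects float16 on CUDA (the fast half-precision path) and float32 everywhere else. Treat this as illustration, not the repo's actual implementation:

import torch

def torch_dtype(device: torch.device) -> torch.dtype:
    # Assumed behavior: half precision only where it is well supported.
    return torch.float16 if device.type == 'cuda' else torch.float32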