mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix perlin noise and txt2img2img
This commit is contained in:
parent
7e8f364d8d
commit
ce00c9856f
@ -237,7 +237,8 @@ class Generator:
|
||||
|
||||
def get_perlin_noise(self,width,height):
|
||||
fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
|
||||
return torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
|
||||
noise = torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
|
||||
return noise
|
||||
|
||||
def new_seed(self):
|
||||
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
|
||||
|
@ -90,9 +90,9 @@ class Txt2Img2Img(Generator):
|
||||
def get_noise_like(self, like: torch.Tensor):
|
||||
device = like.device
|
||||
if device.type == 'mps':
|
||||
x = torch.randn_like(like, device='cpu').to(device)
|
||||
x = torch.randn_like(like, device='cpu', dtype=self.torch_dtype()).to(device)
|
||||
else:
|
||||
x = torch.randn_like(like, device=device)
|
||||
x = torch.randn_like(like, device=device, dtype=self.torch_dtype())
|
||||
if self.perlin > 0.0:
|
||||
shape = like.shape
|
||||
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
|
||||
@ -117,10 +117,12 @@ class Txt2Img2Img(Generator):
|
||||
self.latent_channels,
|
||||
scaled_height // self.downsampling_factor,
|
||||
scaled_width // self.downsampling_factor],
|
||||
device='cpu').to(device)
|
||||
dtype=self.torch_dtype(),
|
||||
device='cpu').to(device)
|
||||
else:
|
||||
return torch.randn([1,
|
||||
self.latent_channels,
|
||||
scaled_height // self.downsampling_factor,
|
||||
scaled_width // self.downsampling_factor],
|
||||
device=device)
|
||||
dtype=self.torch_dtype(),
|
||||
device=device)
|
||||
|
@ -349,7 +349,7 @@ class ModelManager(object):
|
||||
|
||||
if self.precision == 'float16':
|
||||
print(' | Using faster float16 precision')
|
||||
model.to(torch.float16)
|
||||
model = model.to(torch.float16)
|
||||
else:
|
||||
print(' | Using more accurate float32 precision')
|
||||
|
||||
|
@ -8,6 +8,7 @@ from threading import Thread
|
||||
from urllib import request
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from ldm.invoke.devices import torch_dtype
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -235,7 +236,8 @@ def rand_perlin_2d(shape, res, device, fade = lambda t: 6*t**5 - 15*t**4 + 10*t*
|
||||
n01 = dot(tile_grads([0, -1],[1, None]), [0, -1]).to(device)
|
||||
n11 = dot(tile_grads([1, None], [1, None]), [-1,-1]).to(device)
|
||||
t = fade(grid[:shape[0], :shape[1]])
|
||||
return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device)
|
||||
noise = math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device)
|
||||
return noise.to(dtype=torch_dtype(device))
|
||||
|
||||
def ask_user(question: str, answers: list):
|
||||
from itertools import chain, repeat
|
||||
|
Loading…
Reference in New Issue
Block a user