mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix crash in txt2img and img2img w/ inpainting models and perlin > 0
- get_perlin_noise() was returning 9 channels; fixed code to return noise for just the 4 image channels and not the mask ones. - Closes Issue #2541
This commit is contained in:
parent
05bb9e444b
commit
0240656361
@ -240,7 +240,12 @@ class Generator:
|
|||||||
|
|
||||||
def get_perlin_noise(self,width,height):
|
def get_perlin_noise(self,width,height):
|
||||||
fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
|
fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
|
||||||
noise = torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
|
# limit noise to only the diffusion image channels, not the mask channels
|
||||||
|
input_channels = min(self.latent_channels, 4)
|
||||||
|
noise = torch.stack([
|
||||||
|
rand_perlin_2d((height, width),
|
||||||
|
(8, 8),
|
||||||
|
device = self.model.device).to(fixdevice) for _ in range(input_channels)], dim=0).to(self.model.device)
|
||||||
return noise
|
return noise
|
||||||
|
|
||||||
def new_seed(self):
|
def new_seed(self):
|
||||||
@ -341,3 +346,27 @@ class Generator:
|
|||||||
|
|
||||||
def torch_dtype(self)->torch.dtype:
|
def torch_dtype(self)->torch.dtype:
|
||||||
return torch.float16 if self.precision == 'float16' else torch.float32
|
return torch.float16 if self.precision == 'float16' else torch.float32
|
||||||
|
|
||||||
|
# returns a tensor filled with random numbers from a normal distribution
|
||||||
|
def get_noise(self,width,height):
|
||||||
|
device = self.model.device
|
||||||
|
# limit noise to only the diffusion image channels, not the mask channels
|
||||||
|
input_channels = min(self.latent_channels, 4)
|
||||||
|
if self.use_mps_noise or device.type == 'mps':
|
||||||
|
x = torch.randn([1,
|
||||||
|
input_channels,
|
||||||
|
height // self.downsampling_factor,
|
||||||
|
width // self.downsampling_factor],
|
||||||
|
dtype=self.torch_dtype(),
|
||||||
|
device='cpu').to(device)
|
||||||
|
else:
|
||||||
|
x = torch.randn([1,
|
||||||
|
input_channels,
|
||||||
|
height // self.downsampling_factor,
|
||||||
|
width // self.downsampling_factor],
|
||||||
|
dtype=self.torch_dtype(),
|
||||||
|
device=device)
|
||||||
|
if self.perlin > 0.0:
|
||||||
|
perlin_noise = self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
|
||||||
|
x = (1-self.perlin)*x + self.perlin*perlin_noise
|
||||||
|
return x
|
||||||
|
@ -63,22 +63,3 @@ class Img2Img(Generator):
|
|||||||
shape = like.shape
|
shape = like.shape
|
||||||
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
|
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def get_noise(self,width,height):
|
|
||||||
# copy of the Txt2Img.get_noise
|
|
||||||
device = self.model.device
|
|
||||||
if self.use_mps_noise or device.type == 'mps':
|
|
||||||
x = torch.randn([1,
|
|
||||||
self.latent_channels,
|
|
||||||
height // self.downsampling_factor,
|
|
||||||
width // self.downsampling_factor],
|
|
||||||
device='cpu').to(device)
|
|
||||||
else:
|
|
||||||
x = torch.randn([1,
|
|
||||||
self.latent_channels,
|
|
||||||
height // self.downsampling_factor,
|
|
||||||
width // self.downsampling_factor],
|
|
||||||
device=device)
|
|
||||||
if self.perlin > 0.0:
|
|
||||||
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
|
|
||||||
return x
|
|
||||||
|
@ -51,26 +51,4 @@ class Txt2Img(Generator):
|
|||||||
return make_image
|
return make_image
|
||||||
|
|
||||||
|
|
||||||
# returns a tensor filled with random numbers from a normal distribution
|
|
||||||
def get_noise(self,width,height):
|
|
||||||
device = self.model.device
|
|
||||||
# limit noise to only the diffusion image channels, not the mask channels
|
|
||||||
input_channels = min(self.latent_channels, 4)
|
|
||||||
if self.use_mps_noise or device.type == 'mps':
|
|
||||||
x = torch.randn([1,
|
|
||||||
input_channels,
|
|
||||||
height // self.downsampling_factor,
|
|
||||||
width // self.downsampling_factor],
|
|
||||||
dtype=self.torch_dtype(),
|
|
||||||
device='cpu').to(device)
|
|
||||||
else:
|
|
||||||
x = torch.randn([1,
|
|
||||||
input_channels,
|
|
||||||
height // self.downsampling_factor,
|
|
||||||
width // self.downsampling_factor],
|
|
||||||
dtype=self.torch_dtype(),
|
|
||||||
device=device)
|
|
||||||
if self.perlin > 0.0:
|
|
||||||
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user