fix crash in txt2img and img2img w/ inpainting models and perlin > 0 (#2544)

- get_perlin_noise() was returning 9 channels; fixed code to return
noise for just the 4 image channels and not the mask ones.

- Closes Issue #2541
This commit is contained in:
Lincoln Stein 2023-02-05 23:50:32 -05:00 committed by GitHub
commit 633f702b39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 42 deletions

View File

@ -240,7 +240,12 @@ class Generator:
def get_perlin_noise(self,width,height):
fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
noise = torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
# limit noise to only the diffusion image channels, not the mask channels
input_channels = min(self.latent_channels, 4)
noise = torch.stack([
rand_perlin_2d((height, width),
(8, 8),
device = self.model.device).to(fixdevice) for _ in range(input_channels)], dim=0).to(self.model.device)
return noise
def new_seed(self):
@ -341,3 +346,27 @@ class Generator:
def torch_dtype(self)->torch.dtype:
return torch.float16 if self.precision == 'float16' else torch.float32
# returns a tensor filled with random numbers from a normal distribution
def get_noise(self,width,height):
device = self.model.device
# limit noise to only the diffusion image channels, not the mask channels
input_channels = min(self.latent_channels, 4)
if self.use_mps_noise or device.type == 'mps':
x = torch.randn([1,
input_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
dtype=self.torch_dtype(),
device='cpu').to(device)
else:
x = torch.randn([1,
input_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
dtype=self.torch_dtype(),
device=device)
if self.perlin > 0.0:
perlin_noise = self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
x = (1-self.perlin)*x + self.perlin*perlin_noise
return x

View File

@ -63,22 +63,3 @@ class Img2Img(Generator):
shape = like.shape
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
return x
def get_noise(self,width,height):
# copy of the Txt2Img.get_noise
device = self.model.device
if self.use_mps_noise or device.type == 'mps':
x = torch.randn([1,
self.latent_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
device='cpu').to(device)
else:
x = torch.randn([1,
self.latent_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
device=device)
if self.perlin > 0.0:
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
return x

View File

@ -51,26 +51,4 @@ class Txt2Img(Generator):
return make_image
# returns a tensor filled with random numbers from a normal distribution
def get_noise(self,width,height):
device = self.model.device
# limit noise to only the diffusion image channels, not the mask channels
input_channels = min(self.latent_channels, 4)
if self.use_mps_noise or device.type == 'mps':
x = torch.randn([1,
input_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
dtype=self.torch_dtype(),
device='cpu').to(device)
else:
x = torch.randn([1,
input_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
dtype=self.torch_dtype(),
device=device)
if self.perlin > 0.0:
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
return x