diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py
index 4923e7daf5..058628ba1c 100644
--- a/ldm/invoke/generator/txt2img2img.py
+++ b/ldm/invoke/generator/txt2img2img.py
@@ -115,22 +115,13 @@ class Txt2Img2Img(Generator):
             scaled_width = width
             scaled_height = height
 
-        device = self.model.device
-
+        device = self.model.device
         channels = self.latent_channels
         if channels == 9:
             channels = 4 # we don't really want noise for all the mask channels
+        shape = (1, channels,
+                 scaled_height // self.downsampling_factor, scaled_width // self.downsampling_factor)
         if self.use_mps_noise or device.type == 'mps':
-            return torch.randn([1,
-                                channels,
-                                scaled_height // self.downsampling_factor,
-                                scaled_width // self.downsampling_factor],
-                               dtype=self.torch_dtype(),
-                               device='cpu').to(device)
+            return torch.randn(shape, dtype=self.torch_dtype(), device='cpu').to(device)
         else:
-            return torch.randn([1,
-                                channels,
-                                scaled_height // self.downsampling_factor,
-                                scaled_width // self.downsampling_factor],
-                               dtype=self.torch_dtype(),
-                               device=device)
+            return torch.randn(shape, dtype=self.torch_dtype(), device=device)
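
The hunk above deduplicates the two torch.randn calls by computing the latent noise shape once, keeping the existing MPS workaround of sampling on CPU and then moving the tensor to the device. A minimal standalone sketch of the same pattern follows; the function name, defaults, and parameter names here are illustrative only and are not part of the patch.

    import torch

    def make_latent_noise(width, height, device, channels=4, downsampling_factor=8,
                          dtype=torch.float32, use_mps_noise=False):
        # Latent-space noise: one batch, `channels` planes, image size divided
        # by the autoencoder's downsampling factor.
        shape = (1, channels, height // downsampling_factor, width // downsampling_factor)
        if use_mps_noise or device.type == 'mps':
            # Generate on CPU and copy to MPS, mirroring the workaround in the diff.
            return torch.randn(shape, dtype=dtype, device='cpu').to(device)
        return torch.randn(shape, dtype=dtype, device=device)

    # Example usage (hypothetical): noise = make_latent_noise(512, 512, torch.device('cpu'))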