refactor(txt2img2img): factor out tensor shape

This commit is contained in:
Kevin Turner 2023-01-27 12:04:12 -08:00
parent 1bb5b4ab32
commit 09b6104bfd

View File

@ -115,22 +115,13 @@ class Txt2Img2Img(Generator):
scaled_width = width scaled_width = width
scaled_height = height scaled_height = height
device = self.model.device device = self.model.device
channels = self.latent_channels channels = self.latent_channels
if channels == 9: if channels == 9:
channels = 4 # we don't really want noise for all the mask channels channels = 4 # we don't really want noise for all the mask channels
shape = (1, channels,
scaled_height // self.downsampling_factor, scaled_width // self.downsampling_factor)
if self.use_mps_noise or device.type == 'mps': if self.use_mps_noise or device.type == 'mps':
return torch.randn([1, return torch.randn(shape, dtype=self.torch_dtype(), device='cpu').to(device)
channels,
scaled_height // self.downsampling_factor,
scaled_width // self.downsampling_factor],
dtype=self.torch_dtype(),
device='cpu').to(device)
else: else:
return torch.randn([1, return torch.randn(shape, dtype=self.torch_dtype(), device=device)
channels,
scaled_height // self.downsampling_factor,
scaled_width // self.downsampling_factor],
dtype=self.torch_dtype(),
device=device)