mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor(txt2img2img): factor out tensor shape
This commit is contained in:
parent
1bb5b4ab32
commit
09b6104bfd
@ -115,22 +115,13 @@ class Txt2Img2Img(Generator):
|
||||
scaled_width = width
|
||||
scaled_height = height
|
||||
|
||||
device = self.model.device
|
||||
|
||||
device = self.model.device
|
||||
channels = self.latent_channels
|
||||
if channels == 9:
|
||||
channels = 4 # we don't really want noise for all the mask channels
|
||||
shape = (1, channels,
|
||||
scaled_height // self.downsampling_factor, scaled_width // self.downsampling_factor)
|
||||
if self.use_mps_noise or device.type == 'mps':
|
||||
return torch.randn([1,
|
||||
channels,
|
||||
scaled_height // self.downsampling_factor,
|
||||
scaled_width // self.downsampling_factor],
|
||||
dtype=self.torch_dtype(),
|
||||
device='cpu').to(device)
|
||||
return torch.randn(shape, dtype=self.torch_dtype(), device='cpu').to(device)
|
||||
else:
|
||||
return torch.randn([1,
|
||||
channels,
|
||||
scaled_height // self.downsampling_factor,
|
||||
scaled_width // self.downsampling_factor],
|
||||
dtype=self.torch_dtype(),
|
||||
device=device)
|
||||
return torch.randn(shape, dtype=self.torch_dtype(), device=device)
|
||||
|
Loading…
Reference in New Issue
Block a user