fixes dream.py mps seed

This commit is contained in:
Lincoln Stein 2022-09-03 10:11:46 -04:00
parent 361cc42829
commit fe5cc79249

View File

@ -348,7 +348,17 @@ class T2I:
def get_noise(): def get_noise():
if init_img: if init_img:
if self.device.type == 'mps':
return torch.randn_like(init_latent, device='cpu').to(self.device)
else:
return torch.randn_like(init_latent, device=self.device) return torch.randn_like(init_latent, device=self.device)
else:
if self.device.type == 'mps':
return torch.randn([1,
self.latent_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
device='cpu').to(self.device)
else: else:
return torch.randn([1, return torch.randn([1,
self.latent_channels, self.latent_channels,
@ -383,6 +393,8 @@ class T2I:
x_T = initial_noise x_T = initial_noise
else: else:
seed_everything(seed) seed_everything(seed)
if self.device.type == 'mps':
x_T = get_noise()
# make_image will do the equivalent of get_noise itself # make_image will do the equivalent of get_noise itself
image = make_image(x_T) image = make_image(x_T)
results.append([image, seed]) results.append([image, seed])