fixes dream.py mps seed

This commit is contained in:
Lincoln Stein 2022-09-03 10:11:46 -04:00
parent 361cc42829
commit fe5cc79249

View File

@ -348,13 +348,23 @@ class T2I:
def get_noise(): def get_noise():
if init_img: if init_img:
return torch.randn_like(init_latent, device=self.device) if self.device.type == 'mps':
return torch.randn_like(init_latent, device='cpu').to(self.device)
else:
return torch.randn_like(init_latent, device=self.device)
else: else:
return torch.randn([1, if self.device.type == 'mps':
self.latent_channels, return torch.randn([1,
height // self.downsampling_factor, self.latent_channels,
width // self.downsampling_factor], height // self.downsampling_factor,
device=self.device) width // self.downsampling_factor],
device='cpu').to(self.device)
else:
return torch.randn([1,
self.latent_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
device=self.device)
initial_noise = None initial_noise = None
if variation_amount > 0 or len(with_variations) > 0: if variation_amount > 0 or len(with_variations) > 0:
@ -383,6 +393,8 @@ class T2I:
x_T = initial_noise x_T = initial_noise
else: else:
seed_everything(seed) seed_everything(seed)
if self.device.type == 'mps':
x_T = get_noise()
# make_image will do the equivalent of get_noise itself # make_image will do the equivalent of get_noise itself
image = make_image(x_T) image = make_image(x_T)
results.append([image, seed]) results.append([image, seed])