fixes dream.py mps seed

This commit is contained in:
Lincoln Stein 2022-09-03 10:11:46 -04:00
parent 361cc42829
commit fe5cc79249

View File

@ -348,13 +348,23 @@ class T2I:
def get_noise():
if init_img:
return torch.randn_like(init_latent, device=self.device)
if self.device.type == 'mps':
return torch.randn_like(init_latent, device='cpu').to(self.device)
else:
return torch.randn_like(init_latent, device=self.device)
else:
return torch.randn([1,
self.latent_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
device=self.device)
if self.device.type == 'mps':
return torch.randn([1,
self.latent_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
device='cpu').to(self.device)
else:
return torch.randn([1,
self.latent_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
device=self.device)
initial_noise = None
if variation_amount > 0 or len(with_variations) > 0:
@ -383,6 +393,8 @@ class T2I:
x_T = initial_noise
else:
seed_everything(seed)
if self.device.type == 'mps':
x_T = get_noise()
# make_image will do the equivalent of get_noise itself
image = make_image(x_T)
results.append([image, seed])