mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fixes dream.py mps seed
This commit is contained in:
parent
361cc42829
commit
fe5cc79249
@ -348,13 +348,23 @@ class T2I:
|
|||||||
|
|
||||||
def get_noise():
|
def get_noise():
|
||||||
if init_img:
|
if init_img:
|
||||||
return torch.randn_like(init_latent, device=self.device)
|
if self.device.type == 'mps':
|
||||||
|
return torch.randn_like(init_latent, device='cpu').to(self.device)
|
||||||
|
else:
|
||||||
|
return torch.randn_like(init_latent, device=self.device)
|
||||||
else:
|
else:
|
||||||
return torch.randn([1,
|
if self.device.type == 'mps':
|
||||||
self.latent_channels,
|
return torch.randn([1,
|
||||||
height // self.downsampling_factor,
|
self.latent_channels,
|
||||||
width // self.downsampling_factor],
|
height // self.downsampling_factor,
|
||||||
device=self.device)
|
width // self.downsampling_factor],
|
||||||
|
device='cpu').to(self.device)
|
||||||
|
else:
|
||||||
|
return torch.randn([1,
|
||||||
|
self.latent_channels,
|
||||||
|
height // self.downsampling_factor,
|
||||||
|
width // self.downsampling_factor],
|
||||||
|
device=self.device)
|
||||||
|
|
||||||
initial_noise = None
|
initial_noise = None
|
||||||
if variation_amount > 0 or len(with_variations) > 0:
|
if variation_amount > 0 or len(with_variations) > 0:
|
||||||
@ -383,6 +393,8 @@ class T2I:
|
|||||||
x_T = initial_noise
|
x_T = initial_noise
|
||||||
else:
|
else:
|
||||||
seed_everything(seed)
|
seed_everything(seed)
|
||||||
|
if self.device.type == 'mps':
|
||||||
|
x_T = get_noise()
|
||||||
# make_image will do the equivalent of get_noise itself
|
# make_image will do the equivalent of get_noise itself
|
||||||
image = make_image(x_T)
|
image = make_image(x_T)
|
||||||
results.append([image, seed])
|
results.append([image, seed])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user