mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Allow user to generate images with initial noise as on M1 / mps system
This commit is contained in:
parent
3170c83d8d
commit
0b7ca6a326
@ -753,6 +753,13 @@ class Args(object):
|
||||
type=str,
|
||||
help='list of variations to apply, in the format `seed:weight,seed:weight,...'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--use_mps_noise',
|
||||
action='store_true',
|
||||
dest='use_mps_noise',
|
||||
help='Simulate noise on M1 systems to get the same results'
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
def format_metadata(**kwargs):
|
||||
|
@ -23,6 +23,7 @@ class Generator():
|
||||
self.downsampling_factor = downsampling # BUG: should come from model or config
|
||||
self.variation_amount = 0
|
||||
self.with_variations = []
|
||||
self.use_mps_noise = False
|
||||
|
||||
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
||||
def get_make_image(self,prompt,**kwargs):
|
||||
|
@ -59,7 +59,7 @@ class Txt2Img(Generator):
|
||||
# returns a tensor filled with random numbers from a normal distribution
|
||||
def get_noise(self,width,height):
|
||||
device = self.model.device
|
||||
if device.type == 'mps':
|
||||
if self.use_mps_noise or device.type == 'mps':
|
||||
x = torch.randn([1,
|
||||
self.latent_channels,
|
||||
height // self.downsampling_factor,
|
||||
|
@ -116,7 +116,7 @@ class Txt2Img2Img(Generator):
|
||||
scaled_height = height
|
||||
|
||||
device = self.model.device
|
||||
if device.type == 'mps':
|
||||
if self.use_mps_noise or device.type == 'mps':
|
||||
return torch.randn([1,
|
||||
self.latent_channels,
|
||||
scaled_height // self.downsampling_factor,
|
||||
|
@ -258,6 +258,7 @@ class Generate:
|
||||
# Set this True to handle KeyboardInterrupt internally
|
||||
catch_interrupts = False,
|
||||
hires_fix = False,
|
||||
use_mps_noise = False,
|
||||
**args,
|
||||
): # eat up additional cruft
|
||||
"""
|
||||
@ -386,6 +387,7 @@ class Generate:
|
||||
|
||||
generator.set_variation(
|
||||
self.seed, variation_amount, with_variations)
|
||||
generator.use_mps_noise = use_mps_noise
|
||||
results = generator.generate(
|
||||
prompt,
|
||||
iterations=iterations,
|
||||
|
Loading…
x
Reference in New Issue
Block a user