Allow user to generate images with initial noise as on M1 / mps system

This commit is contained in:
ArDiouscuros 2022-10-07 22:52:14 +02:00 committed by tildebyte
parent 3170c83d8d
commit 0b7ca6a326
5 changed files with 12 additions and 2 deletions

View File

@ -753,6 +753,13 @@ class Args(object):
type=str, type=str,
help='list of variations to apply, in the format `seed:weight,seed:weight,...' help='list of variations to apply, in the format `seed:weight,seed:weight,...'
) )
render_group.add_argument(
'--use_mps_noise',
action='store_true',
dest='use_mps_noise',
help='Simulate noise on M1 systems to get the same results'
)
return parser return parser
def format_metadata(**kwargs): def format_metadata(**kwargs):

View File

@ -23,6 +23,7 @@ class Generator():
self.downsampling_factor = downsampling # BUG: should come from model or config self.downsampling_factor = downsampling # BUG: should come from model or config
self.variation_amount = 0 self.variation_amount = 0
self.with_variations = [] self.with_variations = []
self.use_mps_noise = False
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py # this is going to be overridden in img2img.py, txt2img.py and inpaint.py
def get_make_image(self,prompt,**kwargs): def get_make_image(self,prompt,**kwargs):

View File

@ -59,7 +59,7 @@ class Txt2Img(Generator):
# returns a tensor filled with random numbers from a normal distribution # returns a tensor filled with random numbers from a normal distribution
def get_noise(self,width,height): def get_noise(self,width,height):
device = self.model.device device = self.model.device
if device.type == 'mps': if self.use_mps_noise or device.type == 'mps':
x = torch.randn([1, x = torch.randn([1,
self.latent_channels, self.latent_channels,
height // self.downsampling_factor, height // self.downsampling_factor,

View File

@ -116,7 +116,7 @@ class Txt2Img2Img(Generator):
scaled_height = height scaled_height = height
device = self.model.device device = self.model.device
if device.type == 'mps': if self.use_mps_noise or device.type == 'mps':
return torch.randn([1, return torch.randn([1,
self.latent_channels, self.latent_channels,
scaled_height // self.downsampling_factor, scaled_height // self.downsampling_factor,

View File

@ -258,6 +258,7 @@ class Generate:
# Set this True to handle KeyboardInterrupt internally # Set this True to handle KeyboardInterrupt internally
catch_interrupts = False, catch_interrupts = False,
hires_fix = False, hires_fix = False,
use_mps_noise = False,
**args, **args,
): # eat up additional cruft ): # eat up additional cruft
""" """
@ -386,6 +387,7 @@ class Generate:
generator.set_variation( generator.set_variation(
self.seed, variation_amount, with_variations) self.seed, variation_amount, with_variations)
generator.use_mps_noise = use_mps_noise
results = generator.generate( results = generator.generate(
prompt, prompt,
iterations=iterations, iterations=iterations,