resolve conflicts

This commit is contained in:
Lincoln Stein 2022-10-25 07:17:54 -04:00
commit 5e8d1ca19f
5 changed files with 12 additions and 2 deletions

View File

@ -292,6 +292,7 @@ class Generate:
# Set this True to handle KeyboardInterrupt internally
catch_interrupts = False,
hires_fix = False,
use_mps_noise = False,
**args,
): # eat up additional cruft
"""
@ -433,6 +434,7 @@ class Generate:
generator.set_variation(
self.seed, variation_amount, with_variations
)
generator.use_mps_noise = use_mps_noise
checker = {
'checker':self.safety_checker,

View File

@ -816,6 +816,13 @@ class Args(object):
type=str,
help='list of variations to apply, in the format `seed:weight,seed:weight,...'
)
render_group.add_argument(
'--use_mps_noise',
action='store_true',
dest='use_mps_noise',
help='Simulate noise on M1 systems to get the same results'
)
return parser
def format_metadata(**kwargs):

View File

@ -28,6 +28,7 @@ class Generator():
self.threshold = 0
self.variation_amount = 0
self.with_variations = []
self.use_mps_noise = False
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
def get_make_image(self,prompt,**kwargs):

View File

@ -59,7 +59,7 @@ class Txt2Img(Generator):
# returns a tensor filled with random numbers from a normal distribution
def get_noise(self,width,height):
device = self.model.device
if device.type == 'mps':
if self.use_mps_noise or device.type == 'mps':
x = torch.randn([1,
self.latent_channels,
height // self.downsampling_factor,

View File

@ -118,7 +118,7 @@ class Txt2Img2Img(Generator):
scaled_height = height
device = self.model.device
if device.type == 'mps':
if self.use_mps_noise or device.type == 'mps':
return torch.randn([1,
self.latent_channels,
scaled_height // self.downsampling_factor,