mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
resolve conflicts
This commit is contained in:
commit
5e8d1ca19f
@ -292,6 +292,7 @@ class Generate:
|
|||||||
# Set this True to handle KeyboardInterrupt internally
|
# Set this True to handle KeyboardInterrupt internally
|
||||||
catch_interrupts = False,
|
catch_interrupts = False,
|
||||||
hires_fix = False,
|
hires_fix = False,
|
||||||
|
use_mps_noise = False,
|
||||||
**args,
|
**args,
|
||||||
): # eat up additional cruft
|
): # eat up additional cruft
|
||||||
"""
|
"""
|
||||||
@ -433,6 +434,7 @@ class Generate:
|
|||||||
generator.set_variation(
|
generator.set_variation(
|
||||||
self.seed, variation_amount, with_variations
|
self.seed, variation_amount, with_variations
|
||||||
)
|
)
|
||||||
|
generator.use_mps_noise = use_mps_noise
|
||||||
|
|
||||||
checker = {
|
checker = {
|
||||||
'checker':self.safety_checker,
|
'checker':self.safety_checker,
|
||||||
|
@ -816,6 +816,13 @@ class Args(object):
|
|||||||
type=str,
|
type=str,
|
||||||
help='list of variations to apply, in the format `seed:weight,seed:weight,...'
|
help='list of variations to apply, in the format `seed:weight,seed:weight,...'
|
||||||
)
|
)
|
||||||
|
render_group.add_argument(
|
||||||
|
'--use_mps_noise',
|
||||||
|
action='store_true',
|
||||||
|
dest='use_mps_noise',
|
||||||
|
help='Simulate noise on M1 systems to get the same results'
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
def format_metadata(**kwargs):
|
def format_metadata(**kwargs):
|
||||||
|
@ -28,6 +28,7 @@ class Generator():
|
|||||||
self.threshold = 0
|
self.threshold = 0
|
||||||
self.variation_amount = 0
|
self.variation_amount = 0
|
||||||
self.with_variations = []
|
self.with_variations = []
|
||||||
|
self.use_mps_noise = False
|
||||||
|
|
||||||
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
||||||
def get_make_image(self,prompt,**kwargs):
|
def get_make_image(self,prompt,**kwargs):
|
||||||
|
@ -59,7 +59,7 @@ class Txt2Img(Generator):
|
|||||||
# returns a tensor filled with random numbers from a normal distribution
|
# returns a tensor filled with random numbers from a normal distribution
|
||||||
def get_noise(self,width,height):
|
def get_noise(self,width,height):
|
||||||
device = self.model.device
|
device = self.model.device
|
||||||
if device.type == 'mps':
|
if self.use_mps_noise or device.type == 'mps':
|
||||||
x = torch.randn([1,
|
x = torch.randn([1,
|
||||||
self.latent_channels,
|
self.latent_channels,
|
||||||
height // self.downsampling_factor,
|
height // self.downsampling_factor,
|
||||||
|
@ -118,7 +118,7 @@ class Txt2Img2Img(Generator):
|
|||||||
scaled_height = height
|
scaled_height = height
|
||||||
|
|
||||||
device = self.model.device
|
device = self.model.device
|
||||||
if device.type == 'mps':
|
if self.use_mps_noise or device.type == 'mps':
|
||||||
return torch.randn([1,
|
return torch.randn([1,
|
||||||
self.latent_channels,
|
self.latent_channels,
|
||||||
scaled_height // self.downsampling_factor,
|
scaled_height // self.downsampling_factor,
|
||||||
|
Loading…
Reference in New Issue
Block a user