From 0b7ca6a3267ffae35406904ffd7ab62fc252cdb7 Mon Sep 17 00:00:00 2001 From: ArDiouscuros <72071512+ArDiouscuros@users.noreply.github.com> Date: Fri, 7 Oct 2022 22:52:14 +0200 Subject: [PATCH] Allow user to generate images with initial noise as on M1 / mps system --- ldm/dream/args.py | 7 +++++++ ldm/dream/generator/base.py | 1 + ldm/dream/generator/txt2img.py | 2 +- ldm/dream/generator/txt2img2img.py | 2 +- ldm/generate.py | 2 ++ 5 files changed, 12 insertions(+), 2 deletions(-) diff --git a/ldm/dream/args.py b/ldm/dream/args.py index 2063286122..18d7499d80 100644 --- a/ldm/dream/args.py +++ b/ldm/dream/args.py @@ -753,6 +753,13 @@ class Args(object): type=str, help='list of variations to apply, in the format `seed:weight,seed:weight,...' ) + render_group.add_argument( + '--use_mps_noise', + action='store_true', + dest='use_mps_noise', + help='Simulate noise on M1 systems to get the same results' + ) + return parser def format_metadata(**kwargs): diff --git a/ldm/dream/generator/base.py b/ldm/dream/generator/base.py index ae1e6af555..969b8b47b5 100644 --- a/ldm/dream/generator/base.py +++ b/ldm/dream/generator/base.py @@ -23,6 +23,7 @@ class Generator(): self.downsampling_factor = downsampling # BUG: should come from model or config self.variation_amount = 0 self.with_variations = [] + self.use_mps_noise = False # this is going to be overridden in img2img.py, txt2img.py and inpaint.py def get_make_image(self,prompt,**kwargs): diff --git a/ldm/dream/generator/txt2img.py b/ldm/dream/generator/txt2img.py index b702edc92d..4b5df12cfc 100644 --- a/ldm/dream/generator/txt2img.py +++ b/ldm/dream/generator/txt2img.py @@ -59,7 +59,7 @@ class Txt2Img(Generator): # returns a tensor filled with random numbers from a normal distribution def get_noise(self,width,height): device = self.model.device - if device.type == 'mps': + if self.use_mps_noise or device.type == 'mps': x = torch.randn([1, self.latent_channels, height // self.downsampling_factor, diff --git a/ldm/dream/generator/txt2img2img.py b/ldm/dream/generator/txt2img2img.py index d6c0cdf168..9fcd3e8c65 100644 --- a/ldm/dream/generator/txt2img2img.py +++ b/ldm/dream/generator/txt2img2img.py @@ -116,7 +116,7 @@ class Txt2Img2Img(Generator): scaled_height = height device = self.model.device - if device.type == 'mps': + if self.use_mps_noise or device.type == 'mps': return torch.randn([1, self.latent_channels, scaled_height // self.downsampling_factor, diff --git a/ldm/generate.py b/ldm/generate.py index 37e84f05e6..55df76c5b7 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -258,6 +258,7 @@ class Generate: # Set this True to handle KeyboardInterrupt internally catch_interrupts = False, hires_fix = False, + use_mps_noise = False, **args, ): # eat up additional cruft """ @@ -386,6 +387,7 @@ class Generate: generator.set_variation( self.seed, variation_amount, with_variations) + generator.use_mps_noise = use_mps_noise results = generator.generate( prompt, iterations=iterations,