From 8c8b34a889f416e282755804929bec2b7f9934a9 Mon Sep 17 00:00:00 2001
From: Peter Baylies
Date: Mon, 5 Sep 2022 22:57:33 -0400
Subject: [PATCH] * Update to resolve conflicts.

---
 ldm/dream/generator/base.py    | 13 +++++++++----
 ldm/dream/generator/img2img.py | 12 ++++++++----
 ldm/dream/generator/txt2img.py | 14 ++++++++++----
 ldm/generate.py                |  6 +++++-
 4 files changed, 32 insertions(+), 13 deletions(-)

diff --git a/ldm/dream/generator/base.py b/ldm/dream/generator/base.py
index 9bed3df719..38d2e373af 100644
--- a/ldm/dream/generator/base.py
+++ b/ldm/dream/generator/base.py
@@ -10,6 +10,7 @@ from PIL import Image
 from einops import rearrange, repeat
 from pytorch_lightning import seed_everything
 from ldm.dream.devices import choose_autocast_device
+from ldm.util import rand_perlin_2d
 
 downsampling = 8
 
@@ -36,7 +37,7 @@ class Generator():
         self.with_variations = with_variations
 
     def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
-                 image_callback=None, step_callback=None,
+                 image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                  **kwargs):
         device_type,scope = choose_autocast_device(self.model.device)
         make_image = self.get_make_image(
@@ -45,6 +46,8 @@ class Generator():
             width = width,
             height = height,
             step_callback = step_callback,
+            threshold = threshold,
+            perlin = perlin,
             **kwargs
         )
 
@@ -63,10 +66,8 @@ class Generator():
                 x_T = initial_noise
             else:
                 seed_everything(seed)
-                if self.model.device.type == 'mps':
-                    x_T = self.get_noise(width,height)
+                x_T = self.get_noise(width,height)
 
-            # make_image will do the equivalent of get_noise itself
             image = make_image(x_T)
             results.append([image, seed])
             if image_callback is not None:
@@ -115,6 +116,10 @@ class Generator():
         """
         raise NotImplementedError("get_noise() must be implemented in a descendent class")
 
+    def get_perlin_noise(self,width,height):
+        return torch.stack([rand_perlin_2d((height, width), (8, 8)).to(self.model.device) for _ in range(self.latent_channels)], dim=0)
+
+
     def new_seed(self):
         self.seed = random.randrange(0, np.iinfo(np.uint32).max)
         return self.seed
diff --git a/ldm/dream/generator/img2img.py b/ldm/dream/generator/img2img.py
index 242912d0eb..4b2590398c 100644
--- a/ldm/dream/generator/img2img.py
+++ b/ldm/dream/generator/img2img.py
@@ -15,12 +15,12 @@ class Img2Img(Generator):
 
     @torch.no_grad()
     def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
-                       conditioning,init_image,strength,step_callback=None,**kwargs):
+                       conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
         """
         Returns a function returning an image derived from the prompt and the initial image
         Return value depends on the seed at the time you call it.
         """
-
+        self.perlin = perlin
         # PLMS sampler not supported yet, so ignore previous sampler
         if not isinstance(sampler,DDIMSampler):
             print(
@@ -67,6 +67,10 @@ class Img2Img(Generator):
         init_latent = self.init_latent
         assert init_latent is not None,'call to get_noise() when init_latent not set'
         if device.type == 'mps':
-            return torch.randn_like(init_latent, device='cpu').to(device)
+            x = torch.randn_like(init_latent, device='cpu').to(device)
         else:
-            return torch.randn_like(init_latent, device=device)
+            x = torch.randn_like(init_latent, device=device)
+        if self.perlin > 0.0:
+            shape = init_latent.shape
+            x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
+        return x
diff --git a/ldm/dream/generator/txt2img.py b/ldm/dream/generator/txt2img.py
index d4cd25cb51..6e9c5149cd 100644
--- a/ldm/dream/generator/txt2img.py
+++ b/ldm/dream/generator/txt2img.py
@@ -12,12 +12,13 @@ class Txt2Img(Generator):
 
     @torch.no_grad()
     def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
-                       conditioning,width,height,step_callback=None,**kwargs):
+                       conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
         """
         Returns a function returning an image derived from the prompt and the initial image
         Return value depends on the seed at the time you call it
         kwargs are 'width' and 'height'
         """
+        self.perlin = perlin
         uc, c = conditioning
 
         @torch.no_grad()
@@ -37,7 +38,8 @@ class Txt2Img(Generator):
                 unconditional_guidance_scale = cfg_scale,
                 unconditional_conditioning = uc,
                 eta = ddim_eta,
-                img_callback = step_callback
+                img_callback = step_callback,
+                threshold = threshold,
             )
             return self.sample_to_image(samples)
 
@@ -48,14 +50,18 @@ class Txt2Img(Generator):
     def get_noise(self,width,height):
         device = self.model.device
         if device.type == 'mps':
-            return torch.randn([1,
+            x = torch.randn([1,
                                 self.latent_channels,
                                 height // self.downsampling_factor,
                                 width // self.downsampling_factor],
                                device='cpu').to(device)
         else:
-            return torch.randn([1,
+            x = torch.randn([1,
                                 self.latent_channels,
                                 height // self.downsampling_factor,
                                 width // self.downsampling_factor],
                                device=device)
+        print(self.perlin)
+        if self.perlin > 0.0:
+            x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
+        return x
diff --git a/ldm/generate.py b/ldm/generate.py
index 9ba72c3676..3cae680724 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -193,6 +193,8 @@ class Generate:
             log_tokenization= False,
             with_variations = None,
             variation_amount = 0.0,
+            threshold = 0.0,
+            perlin = 0.0,
             # these are specific to img2img
             init_img = None,
             mask = None,
@@ -335,7 +337,9 @@ class Generate:
                     height = height,
                     init_image = init_image,      # notice that init_image is different from init_img
                     init_mask = init_mask_image,
-                    strength = strength
+                    strength = strength,
+                    threshold = threshold,
+                    perlin = perlin,
                 )
 
                 if upscale is not None or gfpgan_strength > 0: