From 0c354eccaa04db52f1a165092e5a6cc74e69f74c Mon Sep 17 00:00:00 2001
From: ArDiouscuros <72071512+ArDiouscuros@users.noreply.github.com>
Date: Fri, 30 Sep 2022 00:58:06 +0200
Subject: [PATCH] Hi res mode fix duplicates with img2img scaling

Add message about interpolation size

Fix crash if sampler not set to DDIM, change parameter name to hires_fix

Hi res mode fix duplicates with img2img scaling
---
 ldm/dream/args.py                  |   6 ++
 ldm/dream/generator/txt2img2img.py | 126 +++++++++++++++++++++++++++++
 ldm/generate.py                    |  10 +++
 3 files changed, 142 insertions(+)
 create mode 100644 ldm/dream/generator/txt2img2img.py

diff --git a/ldm/dream/args.py b/ldm/dream/args.py
index aabd900213..b0ac1b906a 100644
--- a/ldm/dream/args.py
+++ b/ldm/dream/args.py
@@ -569,6 +569,12 @@ class Args(object):
             type=str,
             help='Directory to save generated images and a log of prompts and seeds',
         )
+        render_group.add_argument(
+            '--hires_fix',
+            action='store_true',
+            dest='hires_fix',
+            help='Create hires image using img2img to prevent dupes'
+        )
         img2img_group.add_argument(
             '-I',
             '--init_img',
diff --git a/ldm/dream/generator/txt2img2img.py b/ldm/dream/generator/txt2img2img.py
new file mode 100644
index 0000000000..502a2bdca3
--- /dev/null
+++ b/ldm/dream/generator/txt2img2img.py
@@ -0,0 +1,126 @@
+'''
+ldm.dream.generator.txt2img2img inherits from ldm.dream.generator
+'''
+
+import torch
+import numpy as np
+import math
+from ldm.dream.generator.base import Generator
+from ldm.models.diffusion.ddim import DDIMSampler
+
+
+class Txt2Img2Img(Generator):
+    def __init__(self, model, precision):
+        super().__init__(model, precision)
+        self.init_latent = None    # for get_noise()
+
+    @torch.no_grad()
+    def get_make_image(self, prompt, sampler, steps, cfg_scale, ddim_eta,
+                       conditioning, width, height, strength, step_callback=None, **kwargs):
+        """
+        Returns a function returning an image derived from the prompt and the initial image
+        Return value depends on the seed at the time you call it
+        kwargs are 'width' and 'height'
+        """
+        uc, c = conditioning
+
+        @torch.no_grad()
+        def make_image(x_T):
+
+            trained_square = 512 * 512
+            actual_square = width * height
+            scale = math.sqrt(trained_square / actual_square)
+
+            init_width = math.ceil(scale * width / 64) * 64
+            init_height = math.ceil(scale * height / 64) * 64
+
+            shape = [
+                self.latent_channels,
+                init_height // self.downsampling_factor,
+                init_width // self.downsampling_factor,
+            ]
+
+            x = self.get_noise(init_width, init_height)
+
+            if self.free_gpu_mem and self.model.model.device != self.model.device:
+                self.model.model.to(self.model.device)
+
+            samples, _ = sampler.sample(
+                batch_size                   = 1,
+                S                            = steps,
+                x_T                          = x,
+                conditioning                 = c,
+                shape                        = shape,
+                verbose                      = False,
+                unconditional_guidance_scale = cfg_scale,
+                unconditional_conditioning   = uc,
+                eta                          = ddim_eta,
+                img_callback                 = step_callback
+            )
+
+            print(
+                f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height}"
+            )
+
+            # resizing
+            samples = torch.nn.functional.interpolate(
+                samples,
+                size=(height // self.downsampling_factor, width // self.downsampling_factor),
+                mode="bilinear"
+            )
+
+            t_enc = int(strength * steps)
+
+            x = None
+
+            # Other samplers not supported yet, so ignore previous sampler
+            if not isinstance(sampler, DDIMSampler):
+                print(
+                    f"\n>> Sampler '{sampler.__class__.__name__}' is not yet supported for img2img. Using DDIM sampler"
+                )
+                img_sampler = DDIMSampler(self.model, device=self.model.device)
+                img_sampler.make_schedule(
+                    ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
+                )
+            else:
+                img_sampler = sampler
+
+            z_enc = img_sampler.stochastic_encode(
+                samples,
+                torch.tensor([t_enc]).to(self.model.device),
+                noise=x_T
+            )
+
+            # decode it
+            samples = img_sampler.decode(
+                z_enc,
+                c,
+                t_enc,
+                img_callback = step_callback,
+                unconditional_guidance_scale=cfg_scale,
+                unconditional_conditioning=uc,
+            )
+
+            if self.free_gpu_mem:
+                self.model.model.to("cpu")
+
+            return self.sample_to_image(samples)
+
+        return make_image
+
+
+    # returns a tensor filled with random numbers from a normal distribution
+    def get_noise(self, width, height):
+        device = self.model.device
+        if device.type == 'mps':
+            return torch.randn([1,
+                                self.latent_channels,
+                                height // self.downsampling_factor,
+                                width // self.downsampling_factor],
+                               device='cpu').to(device)
+        else:
+            return torch.randn([1,
+                                self.latent_channels,
+                                height // self.downsampling_factor,
+                                width // self.downsampling_factor],
+                               device=device)
diff --git a/ldm/generate.py b/ldm/generate.py
index 8961bbb9f7..c0936fdccf 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -287,6 +287,7 @@ class Generate:
             upscale = None,
             # Set this True to handle KeyboardInterrupt internally
             catch_interrupts = False,
+            hires_fix = False,
             **args,
     ):   # eat up additional cruft
         """
@@ -403,6 +404,8 @@ class Generate:
                 generator = self._make_embiggen()
             elif init_image is not None:
                 generator = self._make_img2img()
+            elif hires_fix:
+                generator = self._make_txt2img2img()
             else:
                 generator = self._make_txt2img()
 
@@ -660,6 +663,13 @@ class Generate:
             self.generators['txt2img'].free_gpu_mem = self.free_gpu_mem
         return self.generators['txt2img']
 
+    def _make_txt2img2img(self):
+        if not self.generators.get('txt2img2'):
+            from ldm.dream.generator.txt2img2img import Txt2Img2Img
+            self.generators['txt2img2'] = Txt2Img2Img(self.model, self.precision)
+            self.generators['txt2img2'].free_gpu_mem = self.free_gpu_mem
+        return self.generators['txt2img2']
+
     def _make_inpaint(self):
         if not self.generators.get('inpaint'):
            from ldm.dream.generator.inpaint import Inpaint
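
Usage sketch: the snippet below is a minimal, illustrative example of exercising the
new hires_fix path through the existing Generate.prompt2image() API. The prompt text,
output size, filenames, and the assumption that Generate() resolves a local model with
its default arguments (and that results come back as (image, seed) pairs) are all
placeholders to adapt, not part of the patch itself.

    # hypothetical example -- routes generation through the Txt2Img2Img generator added above
    from ldm.generate import Generate

    gr = Generate()                 # assumes the default model configuration loads locally
    results = gr.prompt2image(
        prompt    = 'a red barn in a snowy field',
        width     = 768,
        height    = 768,
        steps     = 50,
        hires_fix = True,           # render near 512x512 first, then img2img-refine the upscale
    )
    for image, seed in results:     # assumes (image, seed) result pairs
        image.save(f'barn-{seed}.png')

From the interactive prompt, the equivalent switch added in ldm/dream/args.py is
--hires_fix, used alongside the existing width/height options.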