Hi res mode fix duplicates with img2img scaling

Add message about interpolation size

Fix crash when the sampler is not set to DDIM; rename the parameter to hires_fix

Hi res mode fix duplicates with img2img scaling
This commit is contained in:
ArDiouscuros 2022-09-30 00:58:06 +02:00 committed by Lincoln Stein
parent 33162355be
commit 0c354eccaa
3 changed files with 142 additions and 0 deletions

View File

@ -569,6 +569,12 @@ class Args(object):
type=str,
help='Directory to save generated images and a log of prompts and seeds',
)
render_group.add_argument(
'--hires_fix',
action='store_true',
dest='hires_fix',
help='Create hires image using img2img to prevent dupes'
)
img2img_group.add_argument(
'-I',
'--init_img',

View File

@ -0,0 +1,126 @@
'''
ldm.dream.generator.txt2img2img inherits from ldm.dream.generator
'''
import torch
import numpy as np
import math
from ldm.dream.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler
class Txt2Img2Img(Generator):
    """
    Two-pass text-to-image generator.

    make_image() first samples at an intermediate size whose area is roughly
    512 * 512 (each side rounded up to a multiple of 64), interpolates those
    samples to the requested width/height, and then re-noises and decodes
    them with a DDIM sampler.
    """

    def __init__(self, model, precision):
        super().__init__(model, precision)
        # Kept from the original "for get_noise()" note; the get_noise()
        # defined in this class does not read it.
        self.init_latent = None

    @torch.no_grad()
    def get_make_image(self, prompt, sampler, steps, cfg_scale, ddim_eta,
                       conditioning, width, height, strength,
                       step_callback=None, **kwargs):
        """
        Build and return the closure that renders one image.

        prompt        -- accepted but not used by this method
        sampler       -- sampler for the first pass; reused for the second
                         pass only when it is a DDIMSampler
        steps         -- number of sampling steps (S) for the first pass and
                         the DDIM schedule length for the second
        cfg_scale     -- unconditional guidance scale for both passes
        ddim_eta      -- eta forwarded to the sampler(s)
        conditioning  -- (unconditional, conditional) pair
        width, height -- requested output size
        strength      -- sets the encode/decode depth of the second pass:
                         t_enc = int(strength * steps)
        step_callback -- forwarded as img_callback to both passes
        kwargs        -- accepted and ignored

        The returned make_image(x_T) takes the noise used when re-encoding
        the resized samples and returns self.sample_to_image(samples).
        """
        uncond, cond = conditioning

        @torch.no_grad()
        def make_image(x_T):
            # Scale both sides by the same factor so the first pass covers
            # about 512 * 512 pixels. The factor exceeds 1 when the requested
            # area is smaller than that, in which case the first pass is
            # larger than the target and the interpolation shrinks it.
            trained_area = 512 * 512
            requested_area = width * height
            ratio = math.sqrt(trained_area / requested_area)
            init_width = math.ceil(ratio * width / 64) * 64
            init_height = math.ceil(ratio * height / 64) * 64

            factor = self.downsampling_factor
            first_pass_shape = [
                self.latent_channels,
                init_height // factor,
                init_width // factor,
            ]
            first_pass_noise = self.get_noise(init_width, init_height)

            # With free_gpu_mem set, move the inner model onto the main
            # device when it is not already there.
            if self.free_gpu_mem and self.model.model.device != self.model.device:
                self.model.model.to(self.model.device)

            samples, _ = sampler.sample(
                batch_size=1,
                S=steps,
                x_T=first_pass_noise,
                conditioning=cond,
                shape=first_pass_shape,
                verbose=False,
                unconditional_guidance_scale=cfg_scale,
                unconditional_conditioning=uncond,
                eta=ddim_eta,
                img_callback=step_callback,
            )

            print(
                f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height}"
            )

            # Resize the first-pass samples to the requested size.
            target_size = (height // factor, width // factor)
            samples = torch.nn.functional.interpolate(
                samples,
                size=target_size,
                mode="bilinear",
            )

            t_enc = int(strength * steps)
            # Drop the reference to the first-pass noise before the second pass.
            first_pass_noise = None

            # Only DDIM is used for the second pass; any other sampler is
            # replaced by a freshly scheduled DDIMSampler.
            if isinstance(sampler, DDIMSampler):
                img_sampler = sampler
            else:
                print(
                    f"\n>> Sampler '{sampler.__class__.__name__}' is not yet supported for img2img. Using DDIM sampler"
                )
                img_sampler = DDIMSampler(self.model, device=self.model.device)
                img_sampler.make_schedule(
                    ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
                )

            encode_step = torch.tensor([t_enc]).to(self.model.device)
            z_enc = img_sampler.stochastic_encode(
                samples,
                encode_step,
                noise=x_T,
            )

            # Decode the re-noised samples.
            samples = img_sampler.decode(
                z_enc,
                cond,
                t_enc,
                img_callback=step_callback,
                unconditional_guidance_scale=cfg_scale,
                unconditional_conditioning=uncond,
            )

            if self.free_gpu_mem:
                self.model.model.to("cpu")

            return self.sample_to_image(samples)

        return make_image

    def get_noise(self, width, height):
        """
        Return a normally distributed random tensor of shape
        [1, latent_channels, height // downsampling_factor,
        width // downsampling_factor] on the model's device.
        """
        device = self.model.device
        factor = self.downsampling_factor
        noise_shape = [1, self.latent_channels, height // factor, width // factor]

        # For 'mps' devices the noise is drawn on the CPU and then moved.
        if device.type == 'mps':
            return torch.randn(noise_shape, device='cpu').to(device)
        return torch.randn(noise_shape, device=device)

View File

@ -287,6 +287,7 @@ class Generate:
upscale = None,
# Set this True to handle KeyboardInterrupt internally
catch_interrupts = False,
hires_fix = False,
**args,
): # eat up additional cruft
"""
@ -403,6 +404,8 @@ class Generate:
generator = self._make_embiggen()
elif init_image is not None:
generator = self._make_img2img()
elif hires_fix:
generator = self._make_txt2img2img()
else:
generator = self._make_txt2img()
@ -660,6 +663,13 @@ class Generate:
self.generators['txt2img'].free_gpu_mem = self.free_gpu_mem
return self.generators['txt2img']
def _make_txt2img2img(self):
    """
    Return the Txt2Img2Img generator stored under self.generators['txt2img2'],
    creating and storing it on first use.
    """
    existing = self.generators.get('txt2img2')
    if existing:
        return existing

    # Imported here, on first use, as in the original method.
    from ldm.dream.generator.txt2img2img import Txt2Img2Img

    generator = Txt2Img2Img(self.model, self.precision)
    self.generators['txt2img2'] = generator
    generator.free_gpu_mem = self.free_gpu_mem
    return generator
def _make_inpaint(self):
if not self.generators.get('inpaint'):
from ldm.dream.generator.inpaint import Inpaint