mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Hi res mode fix duplicates with img2img scaling
Add message about interpolation size Fix crash if sampler not set to DDIM, change parameter name to hires_fix Hi res mode fix duplicates with img2img scaling
This commit is contained in:
parent
33162355be
commit
0c354eccaa
@ -569,6 +569,12 @@ class Args(object):
|
||||
type=str,
|
||||
help='Directory to save generated images and a log of prompts and seeds',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--hires_fix',
|
||||
action='store_true',
|
||||
dest='hires_fix',
|
||||
help='Create hires image using img2img to prevent dupes'
|
||||
)
|
||||
img2img_group.add_argument(
|
||||
'-I',
|
||||
'--init_img',
|
||||
|
126
ldm/dream/generator/txt2img2img.py
Normal file
126
ldm/dream/generator/txt2img2img.py
Normal file
@ -0,0 +1,126 @@
|
||||
'''
|
||||
ldm.dream.generator.txt2img inherits from ldm.dream.generator
|
||||
'''
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import math
|
||||
from ldm.dream.generator.base import Generator
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
|
||||
|
||||
class Txt2Img2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
self.init_latent = None # for get_noise()
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,width,height,strength,step_callback=None,**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
uc, c = conditioning
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
|
||||
trained_square = 512 * 512
|
||||
actual_square = width * height
|
||||
scale = math.sqrt(trained_square / actual_square)
|
||||
|
||||
init_width = math.ceil(scale * width / 64) * 64
|
||||
init_height = math.ceil(scale * height / 64) * 64
|
||||
|
||||
shape = [
|
||||
self.latent_channels,
|
||||
init_height // self.downsampling_factor,
|
||||
init_width // self.downsampling_factor,
|
||||
]
|
||||
|
||||
x = self.get_noise(init_width, init_height)
|
||||
|
||||
if self.free_gpu_mem and self.model.model.device != self.model.device:
|
||||
self.model.model.to(self.model.device)
|
||||
|
||||
samples, _ = sampler.sample(
|
||||
batch_size = 1,
|
||||
S = steps,
|
||||
x_T = x,
|
||||
conditioning = c,
|
||||
shape = shape,
|
||||
verbose = False,
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc,
|
||||
eta = ddim_eta,
|
||||
img_callback = step_callback
|
||||
)
|
||||
|
||||
print(
|
||||
f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height}"
|
||||
)
|
||||
|
||||
# resizing
|
||||
samples = torch.nn.functional.interpolate(
|
||||
samples,
|
||||
size=(height // self.downsampling_factor, width // self.downsampling_factor),
|
||||
mode="bilinear"
|
||||
)
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
|
||||
x = None
|
||||
|
||||
# Other samplers not supported yet, so ignore previous sampler
|
||||
if not isinstance(sampler,DDIMSampler):
|
||||
print(
|
||||
f"\n>> Sampler '{sampler.__class__.__name__}' is not yet supported for img2img. Using DDIM sampler"
|
||||
)
|
||||
img_sampler = DDIMSampler(self.model, device=self.model.device)
|
||||
img_sampler.make_schedule(
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
else:
|
||||
img_sampler = sampler
|
||||
|
||||
z_enc = img_sampler.stochastic_encode(
|
||||
samples,
|
||||
torch.tensor([t_enc]).to(self.model.device),
|
||||
noise=x_T
|
||||
)
|
||||
|
||||
# decode it
|
||||
samples = img_sampler.decode(
|
||||
z_enc,
|
||||
c,
|
||||
t_enc,
|
||||
img_callback = step_callback,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
)
|
||||
|
||||
if self.free_gpu_mem:
|
||||
self.model.model.to("cpu")
|
||||
|
||||
return self.sample_to_image(samples)
|
||||
|
||||
return make_image
|
||||
|
||||
|
||||
# returns a tensor filled with random numbers from a normal distribution
|
||||
def get_noise(self,width,height):
|
||||
device = self.model.device
|
||||
if device.type == 'mps':
|
||||
return torch.randn([1,
|
||||
self.latent_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
device='cpu').to(device)
|
||||
else:
|
||||
return torch.randn([1,
|
||||
self.latent_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
device=device)
|
@ -287,6 +287,7 @@ class Generate:
|
||||
upscale = None,
|
||||
# Set this True to handle KeyboardInterrupt internally
|
||||
catch_interrupts = False,
|
||||
hires_fix = False,
|
||||
**args,
|
||||
): # eat up additional cruft
|
||||
"""
|
||||
@ -403,6 +404,8 @@ class Generate:
|
||||
generator = self._make_embiggen()
|
||||
elif init_image is not None:
|
||||
generator = self._make_img2img()
|
||||
elif hires_fix:
|
||||
generator = self._make_txt2img2img()
|
||||
else:
|
||||
generator = self._make_txt2img()
|
||||
|
||||
@ -660,6 +663,13 @@ class Generate:
|
||||
self.generators['txt2img'].free_gpu_mem = self.free_gpu_mem
|
||||
return self.generators['txt2img']
|
||||
|
||||
def _make_txt2img2img(self):
|
||||
if not self.generators.get('txt2img2'):
|
||||
from ldm.dream.generator.txt2img2img import Txt2Img2Img
|
||||
self.generators['txt2img2'] = Txt2Img2Img(self.model, self.precision)
|
||||
self.generators['txt2img2'].free_gpu_mem = self.free_gpu_mem
|
||||
return self.generators['txt2img2']
|
||||
|
||||
def _make_inpaint(self):
|
||||
if not self.generators.get('inpaint'):
|
||||
from ldm.dream.generator.inpaint import Inpaint
|
||||
|
Loading…
Reference in New Issue
Block a user