InvokeAI/ldm/dream/generator/txt2img2img.py

'''
ldm.dream.generator.txt2img2img inherits from ldm.dream.generator
'''
import torch
import numpy as np
import math
from ldm.dream.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler


class Txt2Img2Img(Generator):
    def __init__(self, model, precision):
        super().__init__(model, precision)
        self.init_latent = None    # for get_noise()

    @torch.no_grad()
    def get_make_image(self, prompt, sampler, steps, cfg_scale, ddim_eta,
                       conditioning, width, height, strength, step_callback=None, **kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
kwargs are 'width' and 'height'
"""
uc, c = conditioning

        @torch.no_grad()
        def make_image(x_T):

            # First pass: render at (roughly) the model's trained resolution.
            # Scale the requested size down so its area matches 512x512,
            # rounding each side up to a multiple of 64.
            trained_square = 512 * 512
            actual_square  = width * height
            scale          = math.sqrt(trained_square / actual_square)

            init_width  = math.ceil(scale * width  / 64) * 64
            init_height = math.ceil(scale * height / 64) * 64

            shape = [
                self.latent_channels,
                init_height // self.downsampling_factor,
                init_width  // self.downsampling_factor,
            ]

            x = self.get_noise(init_width, init_height)

            if self.free_gpu_mem and self.model.model.device != self.model.device:
                self.model.model.to(self.model.device)

            samples, _ = sampler.sample(
                batch_size                   = 1,
                S                            = steps,
                x_T                          = x,
                conditioning                 = c,
                shape                        = shape,
                verbose                      = False,
                unconditional_guidance_scale = cfg_scale,
                unconditional_conditioning   = uc,
                eta                          = ddim_eta,
                img_callback                 = step_callback
            )

            print(
                f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height}"
            )

            # Upscale the first-pass latents to the requested latent resolution
            samples = torch.nn.functional.interpolate(
                samples,
                size=(height // self.downsampling_factor, width // self.downsampling_factor),
                mode="bilinear"
            )

            # Second pass: refine the upscaled latents img2img-style
            t_enc = int(strength * steps)
            x     = None

            # Only DDIM is supported for the img2img pass so far, so ignore
            # the sampler chosen above if it is anything else.
            if not isinstance(sampler, DDIMSampler):
                print(
                    f"\n>> Sampler '{sampler.__class__.__name__}' is not yet supported for img2img. Using DDIM sampler"
                )
                img_sampler = DDIMSampler(self.model, device=self.model.device)
                img_sampler.make_schedule(
                    ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
                )
            else:
                img_sampler = sampler

            # Noise the upscaled latents up to timestep t_enc ...
            z_enc = img_sampler.stochastic_encode(
                samples,
                torch.tensor([t_enc]).to(self.model.device),
                noise=x_T
            )

            # ... then denoise them back down under the same conditioning
            samples = img_sampler.decode(
                z_enc,
                c,
                t_enc,
                img_callback                 = step_callback,
                unconditional_guidance_scale = cfg_scale,
                unconditional_conditioning   = uc,
            )

            if self.free_gpu_mem:
                self.model.model.to("cpu")

            return self.sample_to_image(samples)

        return make_image

    # returns a tensor filled with random numbers from a normal distribution
    def get_noise(self, width, height):
        device = self.model.device
        if device.type == 'mps':
            # MPS workaround: draw the noise on the CPU, then move it to the device
            return torch.randn([1,
                                self.latent_channels,
                                height // self.downsampling_factor,
                                width  // self.downsampling_factor],
                               device='cpu').to(device)
        else:
            return torch.randn([1,
                                self.latent_channels,
                                height // self.downsampling_factor,
                                width  // self.downsampling_factor],
                               device=device)
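
For orientation, here is a minimal sketch of how the returned closure could be driven. The actual call site lives in InvokeAI's higher-level generate pipeline; the variable values below (model, sampler, conditioning, sizes) are illustrative assumptions, not the real API usage.

    # Hypothetical driver code -- names and values are illustrative only.
    generator = Txt2Img2Img(model, precision)   # model/precision come from the surrounding pipeline (assumed)

    make_image = generator.get_make_image(
        prompt       = prompt,
        sampler      = sampler,      # DDIM recommended; anything else falls back to DDIM for the second pass
        steps        = 50,
        cfg_scale    = 7.5,
        ddim_eta     = 0.0,
        conditioning = (uc, c),      # (unconditional, conditional) embeddings
        width        = 1024,         # first pass would run at 640x448 for this size
        height       = 768,
        strength     = 0.7,          # fraction of steps re-run in the img2img pass
    )

    x_T   = generator.get_noise(1024, 768)   # full-resolution noise, reused when re-noising the upscaled latents
    image = make_image(x_T)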