InvokeAI/ldm/dream/generator/inpaint.py
Mihail Dumitrescu d176fb07cd Replace --full_precision with --precision that works even if not specified
Allowed values are 'auto', 'float32', 'autocast', 'float16'. If not specified, or if set to 'auto', a working precision is automatically selected based on the torch device.
Context: #526
Deprecated --full_precision / -F

Tested on both cuda and cpu by calling scripts/dream.py without arguments and checking that the auto configuration worked. With --precision=auto/float32/autocast/float16 it performs as expected, either working or failing with a reasonable error. Also checked Img2Img.
2022-09-20 17:08:00 -04:00

78 lines
2.6 KiB
Python

'''
ldm.dream.generator.inpaint descends from ldm.dream.generator
'''
import torch
import numpy as np
from einops import rearrange, repeat
from ldm.dream.devices import choose_autocast
from ldm.dream.generator.img2img import Img2Img
from ldm.models.diffusion.ddim import DDIMSampler
class Inpaint(Img2Img):
    """
    Img2Img-derived generator that hands a mask and the latent encoding of
    the initial image to the sampler's decode step.
    """

    def __init__(self, model, precision):
        # Latent encoding of the init image; populated by get_make_image().
        self.init_latent = None
        super().__init__(model, precision)

    @torch.no_grad()
    def get_make_image(self, prompt, sampler, steps, cfg_scale, ddim_eta,
                       conditioning, init_image, mask_image, strength,
                       step_callback=None, **kwargs):
        """
        Build and return a ``make_image(x_T)`` closure that produces an image
        from the conditioning, the initial image and the mask. The noise
        tensor ``x_T`` is supplied by the caller each time the closure runs.

        prompt        -- not read by this method
        sampler       -- used as-is when it is a DDIMSampler; otherwise a
                         message is printed and a new DDIMSampler is used
        steps         -- number of DDIM steps given to make_schedule()
        cfg_scale     -- forwarded as unconditional_guidance_scale
        ddim_eta      -- eta value given to make_schedule()
        conditioning  -- pair unpacked as (unconditional, conditional)
        init_image    -- encoded via the model's first stage and stored in
                         self.init_latent
        mask_image    -- indexed as mask_image[0][0]; that slice is repeated
                         4 times (presumably one per latent channel --
                         TODO confirm) and given a leading axis of size 1
        strength      -- int(strength * steps) is the encode/decode step count
        step_callback -- forwarded to the sampler as img_callback
        kwargs        -- accepted for call compatibility and ignored

        Side effects: sets self.init_latent and prints progress messages.
        """
        # Reshape the mask: take element [0][0], stack it 4 times, then add
        # a leading axis. The einops repeat with b=1 keeps that axis at 1.
        mask_plane = mask_image[0][0]
        stacked_mask = mask_plane.unsqueeze(0).repeat(4, 1, 1)
        latent_mask = repeat(stacked_mask.unsqueeze(0), '1 ... -> b ...', b=1)

        # Only DDIM is supported here (PLMS is not yet), so any other
        # sampler is swapped out for a fresh DDIMSampler.
        if isinstance(sampler, DDIMSampler):
            ddim = sampler
        else:
            print(
                f">> sampler '{sampler.__class__.__name__}' is not yet supported. Using DDIM sampler"
            )
            ddim = DDIMSampler(self.model, device=self.model.device)

        ddim.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)

        # Encode the init image into latent space under the precision scope.
        precision_scope = choose_autocast(self.precision)
        with precision_scope(self.model.device.type):
            first_stage = self.model.encode_first_stage(init_image)
            self.init_latent = self.model.get_first_stage_encoding(first_stage)

        encode_steps = int(strength * steps)
        uncond, cond = conditioning

        print(f">> target t_enc is {encode_steps} steps")

        @torch.no_grad()
        def make_image(x_T):
            # Noise the stored latent up to encode_steps using x_T as noise.
            timestep = torch.tensor([encode_steps]).to(self.model.device)
            noised_latent = ddim.stochastic_encode(
                self.init_latent, timestep, noise=x_T
            )

            # Run the sampler's decode, passing the mask and original latent.
            decoded = ddim.decode(
                noised_latent,
                cond,
                encode_steps,
                img_callback=step_callback,
                unconditional_guidance_scale=cfg_scale,
                unconditional_conditioning=uncond,
                mask=latent_mask,
                init_latent=self.init_latent,
            )
            return self.sample_to_image(decoded)

        return make_image