do not encode init image in starting latent

This commit is contained in:
Lincoln Stein 2022-10-25 22:44:42 -04:00
parent 8d5a225011
commit d3047c7cb0

View File

@ -64,7 +64,6 @@ class Omnibus(Img2Img,Txt2Img):
mask_image = torch.ones(1, 1, init_image.shape[2], init_image.shape[3], device=self.model.device)
# and the masked image is just a copy of the original
masked_image = init_image
t_enc = int(strength * steps)
else: # txt2img
init_image = torch.zeros(1, 3, height, width, device=self.model.device)
@ -111,7 +110,7 @@ class Omnibus(Img2Img,Txt2Img):
samples, _ = sampler.sample(
batch_size = 1,
S = t_enc,
S = steps,
x_T = x_T,
conditioning = cond,
shape = shape,
@ -145,7 +144,4 @@ class Omnibus(Img2Img,Txt2Img):
return batch
def get_noise(self, width:int, height:int):
if self.init_latent:
return super(Img2Img,self).get_noise(width,height)
else:
return super(Txt2Img,self).get_noise(width,height)
return super(Txt2Img,self).get_noise(width,height)