make variations work with inpainting model

This commit is contained in:
Lincoln Stein 2022-10-26 00:18:31 -04:00
parent d3047c7cb0
commit 906dafe3cd
3 changed files with 10 additions and 3 deletions

View File

@ -6,6 +6,7 @@ import torch
import numpy as np import numpy as np
import random import random
import os import os
import traceback
from tqdm import tqdm, trange from tqdm import tqdm, trange
from PIL import Image, ImageFilter from PIL import Image, ImageFilter
from einops import rearrange, repeat from einops import rearrange, repeat
@ -82,7 +83,9 @@ class Generator():
try: try:
x_T = self.get_noise(width,height) x_T = self.get_noise(width,height)
except: except:
pass print('** An error occurred while getting initial noise **')
print(traceback.format_exc())
image = make_image(x_T) image = make_image(x_T)
if self.safety_checker is not None: if self.safety_checker is not None:

View File

@ -14,7 +14,7 @@ from ldm.models.diffusion.ddim import DDIMSampler
class Img2Img(Generator): class Img2Img(Generator):
def __init__(self, model, precision): def __init__(self, model, precision):
super().__init__(model, precision) super().__init__(model, precision)
self.init_latent = None # by get_noise() self.init_latent = None # by get_noise()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs): conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):

View File

@ -70,6 +70,7 @@ class Omnibus(Img2Img,Txt2Img):
mask_image = torch.ones(1, 1, height, width, device=self.model.device) mask_image = torch.ones(1, 1, height, width, device=self.model.device)
masked_image = init_image masked_image = init_image
self.init_latent = init_image
height = init_image.shape[2] height = init_image.shape[2]
width = init_image.shape[3] width = init_image.shape[3]
model = self.model model = self.model
@ -144,4 +145,7 @@ class Omnibus(Img2Img,Txt2Img):
return batch return batch
def get_noise(self, width:int, height:int): def get_noise(self, width:int, height:int):
return super(Txt2Img,self).get_noise(width,height) if self.init_latent is not None:
height = self.init_latent.shape[2]
width = self.init_latent.shape[3]
return Txt2Img.get_noise(self,width,height)