stop crashes on non-square images

This commit is contained in:
Lincoln Stein
2022-10-25 13:17:06 -04:00
parent dd07392045
commit 4352eb6628
3 changed files with 9 additions and 2 deletions

View File

@ -71,6 +71,8 @@ class Omnibus(Img2Img,Txt2Img):
mask_image = torch.ones(1, 1, height, width, device=self.model.device)
masked_image = init_image
height = init_image.shape[2]
width = init_image.shape[3]
model = self.model
def make_image(x_T):
@ -88,7 +90,6 @@ class Omnibus(Img2Img,Txt2Img):
)
c = model.cond_stage_model.encode(batch["txt"])
c_cat = list()
for ck in model.concat_keys:
cc = batch[ck].float()