stop crashes on non-square images

This commit is contained in:
Lincoln Stein 2022-10-25 13:17:06 -04:00
parent dd07392045
commit 4352eb6628
3 changed files with 9 additions and 2 deletions

View File

@ -71,6 +71,8 @@ class Omnibus(Img2Img,Txt2Img):
mask_image = torch.ones(1, 1, height, width, device=self.model.device)
masked_image = init_image
height = init_image.shape[2]
width = init_image.shape[3]
model = self.model
def make_image(x_T):
@ -88,7 +90,6 @@ class Omnibus(Img2Img,Txt2Img):
)
c = model.cond_stage_model.encode(batch["txt"])
c_cat = list()
for ck in model.concat_keys:
cc = batch[ck].float()

View File

@ -89,6 +89,9 @@ class Outcrop(object):
def _extend(self,image:Image,pixels:int)-> Image:
extended_img = Image.new('RGBA',(image.width,image.height+pixels))
mask_height = pixels if self.generate.model.model.conditioning_key in ('hybrid','concat') \
else pixels *2
# first paste places old image at top of extended image, stretch
# it, and applies a gaussian blur to it
# take the top half region, stretch and paste it
@ -105,7 +108,9 @@ class Outcrop(object):
# now make the top part transparent to use as a mask
alpha = extended_img.getchannel('A')
alpha.paste(0,(0,0,extended_img.width,pixels*2))
alpha.paste(0,(0,0,extended_img.width,mask_height))
extended_img.putalpha(alpha)
extended_img.save('outputs/curly_extended.png')
return extended_img

View File

@ -265,6 +265,7 @@ class Sampler(object):
)
if mask is not None:
print('DEBUG: in masking routine')
assert x0 is not None
img_orig = self.model.q_sample(
x0, ts