Allow Generate to take images as readers or Image instances

This commit is contained in:
Kyle Lacy 2022-09-23 01:57:50 -07:00 committed by Lincoln Stein
parent bf21a0bf02
commit 6858c14d94

View File

@ -591,8 +591,8 @@ class Generate:
def _make_images(
self,
img_path,
mask_path,
img,
mask,
width,
height,
fit=False,
@ -600,11 +600,11 @@ class Generate:
):
init_image = None
init_mask = None
if not img_path:
if not img:
return None, None
image = self._load_img(
img_path,
img,
width,
height,
fit=fit
@ -614,7 +614,7 @@ class Generate:
init_image = self._create_init_image(image) # this returns a torch tensor
# if image has a transparent area and no mask was provided, then try to generate mask
if self._has_transparency(image) and not mask_path:
if self._has_transparency(image) and not mask:
print(
'>> Initial image has transparent areas. Will inpaint in these regions.')
if self._check_for_erasure(image):
@ -626,9 +626,9 @@ class Generate:
# this returns a torch tensor
init_mask = self._create_init_mask(image)
if mask_path:
if mask:
mask_image = self._load_img(
mask_path, width, height, fit=fit) # this returns an Image
mask, width, height, fit=fit) # this returns an Image
init_mask = self._create_init_mask(mask_image)
return init_image, init_mask
@ -834,15 +834,25 @@ class Generate:
return model
def _load_img(self, path, width, height, fit=False):
assert os.path.exists(path), f'>> {path}: File not found'
def _load_img(self, img, width, height, fit=False):
if isinstance(img, Image.Image):
image = img
print(
f'>> using provided input image of size {image.width}x{image.height}'
)
elif isinstance(image, str):
assert os.path.exists(img), f'>> {img}: File not found'
image = Image.open(img)
print(
f'>> loaded input image of size {image.width}x{image.height} from {img}'
)
else:
image = Image.open(img)
print(
f'>> loaded input image of size {image.width}x{image.height}'
)
# with Image.open(path) as img:
# image = img.convert('RGBA')
image = Image.open(path)
print(
f'>> loaded input image of size {image.width}x{image.height} from {path}'
)
if fit:
image = self._fit_image(image, (width, height))
else: