From 6858c14d949784c6a6ed5ad31ca0a990449492cb Mon Sep 17 00:00:00 2001 From: Kyle Lacy Date: Fri, 23 Sep 2022 01:57:50 -0700 Subject: [PATCH] Allow `Generate` to take images as readers or `Image` instances --- ldm/generate.py | 40 +++++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/ldm/generate.py b/ldm/generate.py index ed3c2390be..78a691dbeb 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -591,8 +591,8 @@ class Generate: def _make_images( self, - img_path, - mask_path, + img, + mask, width, height, fit=False, @@ -600,11 +600,11 @@ class Generate: ): init_image = None init_mask = None - if not img_path: + if not img: return None, None image = self._load_img( - img_path, + img, width, height, fit=fit @@ -614,7 +614,7 @@ class Generate: init_image = self._create_init_image(image) # this returns a torch tensor # if image has a transparent area and no mask was provided, then try to generate mask - if self._has_transparency(image) and not mask_path: + if self._has_transparency(image) and not mask: print( '>> Initial image has transparent areas. Will inpaint in these regions.') if self._check_for_erasure(image): @@ -626,9 +626,9 @@ class Generate: # this returns a torch tensor init_mask = self._create_init_mask(image) - if mask_path: + if mask: mask_image = self._load_img( - mask_path, width, height, fit=fit) # this returns an Image + mask, width, height, fit=fit) # this returns an Image init_mask = self._create_init_mask(mask_image) return init_image, init_mask @@ -834,15 +834,25 @@ class Generate: return model - def _load_img(self, path, width, height, fit=False): - assert os.path.exists(path), f'>> {path}: File not found' + def _load_img(self, img, width, height, fit=False): + if isinstance(img, Image.Image): + image = img + print( + f'>> using provided input image of size {image.width}x{image.height}' + ) + elif isinstance(image, str): + assert os.path.exists(img), f'>> {img}: File not found' + + image = Image.open(img) + print( + f'>> loaded input image of size {image.width}x{image.height} from {img}' + ) + else: + image = Image.open(img) + print( + f'>> loaded input image of size {image.width}x{image.height}' + ) - # with Image.open(path) as img: - # image = img.convert('RGBA') - image = Image.open(path) - print( - f'>> loaded input image of size {image.width}x{image.height} from {path}' - ) if fit: image = self._fit_image(image, (width, height)) else: