diff --git a/ldm/generate.py b/ldm/generate.py index 4f8f486524..7f7bd43397 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -947,7 +947,7 @@ class Generate: def _create_init_image(self, image: Image.Image, width, height, fit=True): if image.mode != 'RGBA': - image = image.convert('RGB') + image = image.convert('RGBA') image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image) return image