diff --git a/ldm/generate.py b/ldm/generate.py index 7f7bd43397..6c45c02d47 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -965,6 +965,8 @@ class Generate: # Obtain the mask from the transparency channel if mask_image.mode == 'L': mask = mask_image + elif mask_image.mode in ('RGB', 'P'): + mask = mask_image.convert('L') else: # Obtain the mask from the transparency channel mask = Image.new(mode="L", size=mask_image.size, color=255)