Fix the padding behavior when max-pooling regional prompt masks to mirror the downscaling behavior of SD and SDXL. Prior to this change, denoising with input latent dimensions that were not evenly divisible by 8 would raise an exception.

This commit is contained in:
Ryan Dick 2024-04-09 15:15:12 -04:00 committed by Kent Keirsey
parent 69f6c24f52
commit fba40eb1bd

View File

@ -61,9 +61,12 @@ class RegionalPromptData:
if downscale_factor <= max_downscale_factor: if downscale_factor <= max_downscale_factor:
# We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt # We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt
# regions to be lost entirely. # regions to be lost entirely.
#
# ceil_mode=True is set to mirror the downsampling behavior of SD and SDXL.
#
# TODO(ryand): In the future, we may want to experiment with other downsampling methods (e.g. # TODO(ryand): In the future, we may want to experiment with other downsampling methods (e.g.
# nearest interpolation), and could potentially use a weighted mask rather than a binary mask. # nearest interpolation), and could potentially use a weighted mask rather than a binary mask.
batch_sample_masks = F.max_pool2d(batch_sample_masks, kernel_size=2, stride=2) batch_sample_masks = F.max_pool2d(batch_sample_masks, kernel_size=2, stride=2, ceil_mode=True)
return batch_sample_masks_by_seq_len return batch_sample_masks_by_seq_len