Fix _negative_cross_attn_mask_score.

This commit is contained in:
Ryan Dick 2024-03-05 15:55:13 -05:00
parent 57266d36a2
commit b5c334d8ca

View File

@ -32,7 +32,7 @@ class RegionalPromptData:
self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks( self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks(
regions, max_downscale_factor regions, max_downscale_factor
) )
self._negative_cross_attn_mask_score = 0.0 self._negative_cross_attn_mask_score = -10000.0
def _prepare_spatial_masks( def _prepare_spatial_masks(
self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8 self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8