remove unneeded warnings from attention.py

This commit is contained in:
Lincoln Stein 2022-10-27 22:50:06 -04:00
parent 362b234cd1
commit 3033331f65

View File

@ -236,9 +236,7 @@ class CrossAttention(nn.Module):
return self.einsum_lowest_level(q, k, v, None, None, None)
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
if div <= q.shape[0]:
print("warning: untested call to einsum_op_slice_dim0")
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
print("warning: untested call to einsum_op_slice_dim1")
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
def einsum_op_cuda(self, q, k, v):