Fix performance issue introduced by torch cuda cache clear during generation

This commit is contained in:
Kyle Schouviller 2022-11-10 21:43:56 -08:00 committed by Lincoln Stein
parent fa3670270e
commit b116715490

View File

@ -282,7 +282,6 @@ class CrossAttention(nn.Module):
def get_attention_mem_efficient(self, q, k, v):
if q.device.type == 'cuda':
torch.cuda.empty_cache()
#print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
return self.einsum_op_cuda(q, k, v)