Fix bug in InvokeInt8Params that was causing it to use double the necessary VRAM.

This commit is contained in:
Ryan Dick 2024-08-21 19:01:11 +00:00 committed by Brandon
parent fd68a2475b
commit 19a68afb3a

View File

@@ -22,7 +22,7 @@ class InvokeInt8Params(bnb.nn.Int8Params):
return super().cuda(device)
elif self.CB is not None and self.SCB is not None:
self.data = self.data.cuda()
- self.CB = self.CB.cuda()
+ self.CB = self.data
self.SCB = self.SCB.cuda()
else:
# we store the 8-bit rows-major weight