mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
save VRAM by not recombining tensors that have been sliced to save VRAM
This commit is contained in:
parent
dff5681cf0
commit
4513320bf1
@@ -151,72 +151,33 @@ class CrossAttentionControl:
|
|||||||
#else:
|
#else:
|
||||||
# print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
|
# print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
|
||||||
|
|
||||||
|
|
||||||
if self.use_last_attn_slice:
|
if self.use_last_attn_slice:
|
||||||
this_attn_slice = attn_slice
|
|
||||||
if self.last_attn_slice_mask is not None:
|
|
||||||
# indices and mask operate on dim=2, no need to slice
|
|
||||||
base_attn_slice_full = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
|
|
||||||
base_attn_slice_mask = self.last_attn_slice_mask
|
|
||||||
if dim is None:
|
if dim is None:
|
||||||
base_attn_slice = base_attn_slice_full
|
last_attn_slice = self.last_attn_slice
|
||||||
#print("using whole base slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
|
|
||||||
elif dim == 0:
|
|
||||||
base_attn_slice = base_attn_slice_full[start:end]
|
|
||||||
#print("using base dim 0 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
|
|
||||||
elif dim == 1:
|
|
||||||
base_attn_slice = base_attn_slice_full[:, start:end]
|
|
||||||
#print("using base dim 1 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
|
|
||||||
|
|
||||||
attn_slice = this_attn_slice * (1 - base_attn_slice_mask) + \
|
|
||||||
base_attn_slice * base_attn_slice_mask
|
|
||||||
else:
|
|
||||||
if dim is None:
|
|
||||||
attn_slice = self.last_attn_slice
|
|
||||||
# print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
|
# print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
|
||||||
elif dim == 0:
|
else:
|
||||||
attn_slice = self.last_attn_slice[start:end]
|
last_attn_slice = self.last_attn_slice[offset]
|
||||||
#print("took dim 0 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
|
|
||||||
elif dim == 1:
|
if self.last_attn_slice_mask is None:
|
||||||
attn_slice = self.last_attn_slice[:, start:end]
|
# just use everything
|
||||||
#print("took dim 1 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
|
attn_slice = last_attn_slice
|
||||||
|
else:
|
||||||
|
last_attn_slice_mask = self.last_attn_slice_mask
|
||||||
|
remapped_last_attn_slice = torch.index_select(last_attn_slice, -1, self.last_attn_slice_indices)
|
||||||
|
|
||||||
|
this_attn_slice = attn_slice
|
||||||
|
this_attn_slice_mask = 1 - last_attn_slice_mask
|
||||||
|
attn_slice = this_attn_slice * this_attn_slice_mask + \
|
||||||
|
remapped_last_attn_slice * last_attn_slice_mask
|
||||||
|
|
||||||
if self.save_last_attn_slice:
|
if self.save_last_attn_slice:
|
||||||
if dim is None:
|
if dim is None:
|
||||||
self.last_attn_slice = attn_slice
|
self.last_attn_slice = attn_slice
|
||||||
elif dim == 0:
|
|
||||||
# dynamically grow last_attn_slice if needed
|
|
||||||
if self.last_attn_slice is None:
|
|
||||||
self.last_attn_slice = attn_slice
|
|
||||||
#print("no last_attn_slice: shape now", self.last_attn_slice.shape)
|
|
||||||
elif self.last_attn_slice.shape[0] == start:
|
|
||||||
self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=0)
|
|
||||||
assert(self.last_attn_slice.shape[0] == end)
|
|
||||||
#print("last_attn_slice too small, appended dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
|
|
||||||
else:
|
else:
|
||||||
# no need to grow
|
|
||||||
self.last_attn_slice[start:end] = attn_slice
|
|
||||||
#print("last_attn_slice shape is fine, setting dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
|
|
||||||
|
|
||||||
elif dim == 1:
|
|
||||||
# dynamically grow last_attn_slice if needed
|
|
||||||
if self.last_attn_slice is None:
|
if self.last_attn_slice is None:
|
||||||
self.last_attn_slice = attn_slice
|
self.last_attn_slice = { offset: attn_slice }
|
||||||
elif self.last_attn_slice.shape[1] == start:
|
|
||||||
self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=1)
|
|
||||||
assert(self.last_attn_slice.shape[1] == end)
|
|
||||||
else:
|
else:
|
||||||
# no need to grow
|
self.last_attn_slice[offset] = attn_slice
|
||||||
self.last_attn_slice[:, start:end] = attn_slice
|
|
||||||
|
|
||||||
if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
|
|
||||||
if dim is None:
|
|
||||||
weights = self.last_attn_slice_weights
|
|
||||||
elif dim == 0:
|
|
||||||
weights = self.last_attn_slice_weights[start:end]
|
|
||||||
elif dim == 1:
|
|
||||||
weights = self.last_attn_slice_weights[:, start:end]
|
|
||||||
attn_slice = attn_slice * weights
|
|
||||||
|
|
||||||
return attn_slice
|
return attn_slice
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user