From 4513320bf161d3efdd99e2aef4f7c37472131373 Mon Sep 17 00:00:00 2001 From: damian0815 Date: Wed, 2 Nov 2022 00:31:58 +0100 Subject: [PATCH] save VRAM by not recombining tensors that have been sliced to save VRAM --- .../diffusion/cross_attention_control.py | 77 +++++-------------- 1 file changed, 19 insertions(+), 58 deletions(-) diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index 2f1470512f..1a161fbc86 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -151,72 +151,33 @@ class CrossAttentionControl: #else: # print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}") - if self.use_last_attn_slice: - this_attn_slice = attn_slice - if self.last_attn_slice_mask is not None: - # indices and mask operate on dim=2, no need to slice - base_attn_slice_full = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices) - base_attn_slice_mask = self.last_attn_slice_mask - if dim is None: - base_attn_slice = base_attn_slice_full - #print("using whole base slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) - elif dim == 0: - base_attn_slice = base_attn_slice_full[start:end] - #print("using base dim 0 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) - elif dim == 1: - base_attn_slice = base_attn_slice_full[:, start:end] - #print("using base dim 1 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) - - attn_slice = this_attn_slice * (1 - base_attn_slice_mask) + \ - base_attn_slice * base_attn_slice_mask + if dim is None: + last_attn_slice = self.last_attn_slice + # print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) else: - if dim is None: - attn_slice = self.last_attn_slice - #print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) - elif dim == 0: - attn_slice = self.last_attn_slice[start:end] - #print("took dim 0 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) - elif dim == 1: - attn_slice = self.last_attn_slice[:, start:end] - #print("took dim 1 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) + last_attn_slice = self.last_attn_slice[offset] + + if self.last_attn_slice_mask is None: + # just use everything + attn_slice = last_attn_slice + else: + last_attn_slice_mask = self.last_attn_slice_mask + remapped_last_attn_slice = torch.index_select(last_attn_slice, -1, self.last_attn_slice_indices) + + this_attn_slice = attn_slice + this_attn_slice_mask = 1 - last_attn_slice_mask + attn_slice = this_attn_slice * this_attn_slice_mask + \ + remapped_last_attn_slice * last_attn_slice_mask if self.save_last_attn_slice: if dim is None: self.last_attn_slice = attn_slice - elif dim == 0: - # dynamically grow last_attn_slice if needed + else: if self.last_attn_slice is None: - self.last_attn_slice = attn_slice - #print("no last_attn_slice: shape now", self.last_attn_slice.shape) - elif self.last_attn_slice.shape[0] == start: - self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=0) - assert(self.last_attn_slice.shape[0] == end) - #print("last_attn_slice too small, appended dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape) + self.last_attn_slice = { offset: attn_slice } else: - # no need to grow - self.last_attn_slice[start:end] = attn_slice - #print("last_attn_slice shape is fine, setting dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape) - - elif dim == 1: - # dynamically grow last_attn_slice if needed - if self.last_attn_slice is None: - self.last_attn_slice = attn_slice - elif self.last_attn_slice.shape[1] == start: - self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=1) - assert(self.last_attn_slice.shape[1] == end) - else: - # no need to grow - self.last_attn_slice[:, start:end] = attn_slice - - if self.use_last_attn_weights and self.last_attn_slice_weights is not None: - if dim is None: - weights = self.last_attn_slice_weights - elif dim == 0: - weights = self.last_attn_slice_weights[start:end] - elif dim == 1: - weights = self.last_attn_slice_weights[:, start:end] - attn_slice = attn_slice * weights + self.last_attn_slice[offset] = attn_slice return attn_slice