From 4513320bf161d3efdd99e2aef4f7c37472131373 Mon Sep 17 00:00:00 2001 From: damian0815 Date: Wed, 2 Nov 2022 00:31:58 +0100 Subject: [PATCH 1/4] save VRAM by not recombining tensors that have been sliced to save VRAM --- .../diffusion/cross_attention_control.py | 77 +++++-------------- 1 file changed, 19 insertions(+), 58 deletions(-) diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index 2f1470512f..1a161fbc86 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -151,72 +151,33 @@ class CrossAttentionControl: #else: # print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}") - if self.use_last_attn_slice: - this_attn_slice = attn_slice - if self.last_attn_slice_mask is not None: - # indices and mask operate on dim=2, no need to slice - base_attn_slice_full = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices) - base_attn_slice_mask = self.last_attn_slice_mask - if dim is None: - base_attn_slice = base_attn_slice_full - #print("using whole base slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) - elif dim == 0: - base_attn_slice = base_attn_slice_full[start:end] - #print("using base dim 0 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) - elif dim == 1: - base_attn_slice = base_attn_slice_full[:, start:end] - #print("using base dim 1 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) - - attn_slice = this_attn_slice * (1 - base_attn_slice_mask) + \ - base_attn_slice * base_attn_slice_mask + if dim is None: + last_attn_slice = self.last_attn_slice + # print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) else: - if dim is None: - attn_slice = self.last_attn_slice - 
#print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) - elif dim == 0: - attn_slice = self.last_attn_slice[start:end] - #print("took dim 0 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) - elif dim == 1: - attn_slice = self.last_attn_slice[:, start:end] - #print("took dim 1 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) + last_attn_slice = self.last_attn_slice[offset] + + if self.last_attn_slice_mask is None: + # just use everything + attn_slice = last_attn_slice + else: + last_attn_slice_mask = self.last_attn_slice_mask + remapped_last_attn_slice = torch.index_select(last_attn_slice, -1, self.last_attn_slice_indices) + + this_attn_slice = attn_slice + this_attn_slice_mask = 1 - last_attn_slice_mask + attn_slice = this_attn_slice * this_attn_slice_mask + \ + remapped_last_attn_slice * last_attn_slice_mask if self.save_last_attn_slice: if dim is None: self.last_attn_slice = attn_slice - elif dim == 0: - # dynamically grow last_attn_slice if needed + else: if self.last_attn_slice is None: - self.last_attn_slice = attn_slice - #print("no last_attn_slice: shape now", self.last_attn_slice.shape) - elif self.last_attn_slice.shape[0] == start: - self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=0) - assert(self.last_attn_slice.shape[0] == end) - #print("last_attn_slice too small, appended dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape) + self.last_attn_slice = { offset: attn_slice } else: - # no need to grow - self.last_attn_slice[start:end] = attn_slice - #print("last_attn_slice shape is fine, setting dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape) - - elif dim == 1: - # dynamically grow last_attn_slice if needed - if self.last_attn_slice is None: - self.last_attn_slice = attn_slice - elif self.last_attn_slice.shape[1] == start: - self.last_attn_slice = 
torch.cat([self.last_attn_slice, attn_slice], dim=1) - assert(self.last_attn_slice.shape[1] == end) - else: - # no need to grow - self.last_attn_slice[:, start:end] = attn_slice - - if self.use_last_attn_weights and self.last_attn_slice_weights is not None: - if dim is None: - weights = self.last_attn_slice_weights - elif dim == 0: - weights = self.last_attn_slice_weights[start:end] - elif dim == 1: - weights = self.last_attn_slice_weights[:, start:end] - attn_slice = attn_slice * weights + self.last_attn_slice[offset] = attn_slice return attn_slice From 688d7258f185035a5924bee23322b001574da931 Mon Sep 17 00:00:00 2001 From: damian0815 Date: Wed, 2 Nov 2022 00:33:00 +0100 Subject: [PATCH 2/4] fix a bug that broke cross attention control index mapping --- ldm/models/diffusion/cross_attention_control.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index 1a161fbc86..9c8c597869 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -68,6 +68,8 @@ class CrossAttentionControl: indices[b0:b1] = indices_target[a0:a1] mask[b0:b1] = 1 + cls.inject_attention_function(model) + for m in cls.get_attention_modules(model, cls.CrossAttentionType.SELF): m.last_attn_slice_mask = None m.last_attn_slice_indices = None @@ -76,8 +78,6 @@ class CrossAttentionControl: m.last_attn_slice_mask = mask.to(device) m.last_attn_slice_indices = indices.to(device) - cls.inject_attention_function(model) - class CrossAttentionType(Enum): SELF = 1 From 0cc39f01a3c67f32873b7888e1e032d33626e269 Mon Sep 17 00:00:00 2001 From: damian0815 Date: Wed, 2 Nov 2022 01:18:50 +0100 Subject: [PATCH 3/4] report full size for fast latents and update conversion matrix for v1.5 --- backend/invoke_ai_web_server.py | 4 +++- ldm/invoke/generator/base.py | 18 +++++++++--------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git 
a/backend/invoke_ai_web_server.py b/backend/invoke_ai_web_server.py index 6f6837301d..b9fc147911 100644 --- a/backend/invoke_ai_web_server.py +++ b/backend/invoke_ai_web_server.py @@ -640,9 +640,11 @@ class InvokeAIWebServer: if generation_parameters['progress_latents']: image = self.generate.sample_to_lowres_estimated_image(sample) (width, height) = image.size + width *= 8 + height *= 8 buffered = io.BytesIO() image.save(buffered, format="PNG") - img_base64 = "data:image/jpeg;base64," + base64.b64encode(buffered.getvalue()).decode('UTF-8') + img_base64 = "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode('UTF-8') self.socketio.emit( "intermediateResult", { diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py index 534fe2547c..3c6eca08a2 100644 --- a/ldm/invoke/generator/base.py +++ b/ldm/invoke/generator/base.py @@ -119,19 +119,19 @@ class Generator(): # write an approximate RGB image from latent samples for a single step to PNG def sample_to_lowres_estimated_image(self,samples): - # adapted from code by @erucipe and @keturn here: + # originally adapted from code by @erucipe and @keturn here: # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7 - # these numbers were determined empirically by @keturn - v1_4_latent_rgb_factors = torch.tensor([ - # R G B - [ 0.298, 0.207, 0.208], # L1 - [ 0.187, 0.286, 0.173], # L2 - [-0.158, 0.189, 0.264], # L3 - [-0.184, -0.271, -0.473], # L4 + # these updated numbers for v1.5 are from @torridgristle + v1_5_latent_rgb_factors = torch.tensor([ + # R G B + [ 0.3444, 0.1385, 0.0670], # L1 + [ 0.1247, 0.4027, 0.1494], # L2 + [-0.3192, 0.2513, 0.2103], # L3 + [-0.1307, -0.1874, -0.7445] # L4 ], dtype=samples.dtype, device=samples.device) - latent_image = samples[0].permute(1, 2, 0) @ v1_4_latent_rgb_factors + latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors latents_ubyte = (((latent_image + 1) / 2) .clamp(0, 1) # change scale from -1..1 to
0..1 .mul(0xFF) # to 0..255 From ef17aae8abe03fd87dc6f91b3c6d2fe03375f6c5 Mon Sep 17 00:00:00 2001 From: damian0815 Date: Wed, 2 Nov 2022 00:39:48 +0100 Subject: [PATCH 4/4] add damian0815 to contributors list --- docs/other/CONTRIBUTORS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/other/CONTRIBUTORS.md b/docs/other/CONTRIBUTORS.md index 4f16d0fe0b..e016685405 100644 --- a/docs/other/CONTRIBUTORS.md +++ b/docs/other/CONTRIBUTORS.md @@ -60,6 +60,7 @@ We thank them for all of their time and hard work. - [Dmitry T.](https://github.com/ArDiouscuros) - [Kent Keirsey](https://github.com/hipsterusername) - [psychedelicious](https://github.com/psychedelicious) +- [damian0815](https://github.com/damian0815) ## **Original CompVis Authors:**