From 70aa674e9e10d03eb462249764695ef1d4e1e28c Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Sun, 11 Sep 2022 10:34:06 -0400
Subject: [PATCH] merge PR #495 - keep using float16 in ldm.modules.attention

---
 ldm/modules/attention.py | 4 ++--
 scripts/dream.py         | 1 -
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
index 0dd957b407..a8756c4875 100644
--- a/ldm/modules/attention.py
+++ b/ldm/modules/attention.py
@@ -181,7 +181,7 @@ class CrossAttention(nn.Module):
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
         del q_in, k_in, v_in

-        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

         if device_type == 'mps':
             mem_free_total = psutil.virtual_memory().available
@@ -213,7 +213,7 @@ class CrossAttention(nn.Module):
             end = i + slice_size
             s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale

-            s2 = s1.softmax(dim=-1)
+            s2 = s1.softmax(dim=-1, dtype=r1.dtype)
             del s1

             r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
diff --git a/scripts/dream.py b/scripts/dream.py
index b1c62dcbee..df02a4837f 100755
--- a/scripts/dream.py
+++ b/scripts/dream.py
@@ -185,7 +185,6 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
                 continue
             if opt.seed is not None and opt.seed < 0:   # retrieve previous value!
                 try:
-                    print(f'last seeds = {last_seeds}, opt.seed={opt.seed}')
                     opt.seed = last_seeds[opt.seed]
                     print(f'reusing previous seed {opt.seed}')
                 except IndexError: