Mirror of https://github.com/invoke-ai/InvokeAI
merge PR #495 - keep using float16 in ldm.modules.attention
commit 70aa674e9e (parent 8748370f44)
@@ -181,7 +181,7 @@ class CrossAttention(nn.Module):
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
         del q_in, k_in, v_in

-        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

         if device_type == 'mps':
             mem_free_total = psutil.virtual_memory().available
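The point of adding dtype=q.dtype: torch.zeros allocates in the global default dtype (normally float32) regardless of what q and v are, so a half-precision run silently got a result buffer twice the size of its inputs. A minimal standalone sketch of that behaviour (not part of the diff; shapes are made up):

import torch

# torch.zeros ignores the dtype of surrounding tensors and uses the
# global default (float32) unless told otherwise.
q = torch.randn(2, 4096, 40, dtype=torch.float16)
v = torch.randn(2, 4096, 40, dtype=torch.float16)

buf_old = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
buf_new = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

print(buf_old.dtype)  # torch.float32 -- double the memory of the inputs
print(buf_new.dtype)  # torch.float16 -- matches q/v, keeps attention in half precision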
@@ -213,7 +213,7 @@ class CrossAttention(nn.Module):
             end = i + slice_size
             s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale

-            s2 = s1.softmax(dim=-1)
+            s2 = s1.softmax(dim=-1, dtype=r1.dtype)
             del s1

             r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
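The dtype= argument on the softmax works the same way: Tensor.softmax(dim, dtype=...) casts the input to that dtype before the reduction and returns the result in it, so tying it to r1.dtype keeps the attention weights consistent with the result buffer whether the model runs in float16 or float32. A small illustrative sketch (assumes a PyTorch build with half-precision CPU ops; not part of the diff):

import torch

s1 = torch.randn(2, 64, 64, dtype=torch.float16)

print(s1.softmax(dim=-1).dtype)                       # torch.float16
print(s1.softmax(dim=-1, dtype=torch.float32).dtype)  # torch.float32

# passing dtype is equivalent to casting first, then taking the softmax
assert torch.allclose(
    s1.softmax(dim=-1, dtype=torch.float32),
    s1.float().softmax(dim=-1),
)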
@@ -185,7 +185,6 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
            continue
        if opt.seed is not None and opt.seed < 0: # retrieve previous value!
            try:
-               print(f'last seeds = {last_seeds}, opt.seed={opt.seed}')
                opt.seed = last_seeds[opt.seed]
                print(f'reusing previous seed {opt.seed}')
            except IndexError:
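For context, the block this hunk touches (it only drops a debug print) implements seed reuse: a negative opt.seed is treated as an index back into the session's last_seeds list. A hypothetical standalone sketch of that lookup, with opt.seed flattened to a plain variable for illustration:

# negative seed -> reuse a seed from earlier in the session
last_seeds = [42, 1337, 2024]   # seeds of previous generations (made-up values)

opt_seed = -1                   # -1 means "most recent seed"
if opt_seed is not None and opt_seed < 0:
    try:
        opt_seed = last_seeds[opt_seed]
        print(f'reusing previous seed {opt_seed}')
    except IndexError:
        print('no previous seed to reuse')

print(opt_seed)                 # 2024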
|