changes to dogettx optimizations to run on m1

* Author @any-winter-4079
* Author @dogettx
Thanks to many individuals who contributed time and hardware to
benchmarking and debugging these changes.
Lincoln Stein 2022-09-09 09:26:10 -04:00
parent c85ae00b33
commit 10db192cc4
3 changed files with 482 additions and 574 deletions


@@ -35,17 +35,7 @@ Example Usage:
 from ldm.generate import Generate

 # Create an object with default values
-gr = Generate(model       = <path>    // models/ldm/stable-diffusion-v1/model.ckpt
-              config      = <path>    // configs/stable-diffusion/v1-inference.yaml
-              iterations  = <integer> // how many times to run the sampling (1)
-              steps       = <integer> // 50
-              seed        = <integer> // current system time
-              sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms']  // k_lms
-              grid        = <boolean> // false
-              width       = <integer> // image width, multiple of 64 (512)
-              height      = <integer> // image height, multiple of 64 (512)
-              cfg_scale   = <float>   // condition-free guidance scale (7.5)
-              )
+gr = Generate()

 # do the slow model initialization
 gr.load_model()
@@ -86,6 +76,21 @@ for row in results:
 Note that the old txt2img() and img2img() calls are deprecated but will
 still work.
+
+The full list of arguments to Generate() are:
+gr = Generate(
+          weights     = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt')
+          config      = path to model configuration ('configs/stable-diffusion/v1-inference.yaml')
+          iterations  = <integer>  // how many times to run the sampling (1)
+          steps       = <integer>  // 50
+          seed        = <integer>  // current system time
+          sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms']  // k_lms
+          grid        = <boolean>  // false
+          width       = <integer>  // image width, multiple of 64 (512)
+          height      = <integer>  // image height, multiple of 64 (512)
+          cfg_scale   = <float>    // condition-free guidance scale (7.5)
+          )
+
 """


@@ -1,13 +1,13 @@
+import math
 from inspect import isfunction
-import math
 import torch
 import torch.nn.functional as F
+from einops import rearrange, repeat
 from torch import nn, einsum
-from einops import rearrange, repeat
 from ldm.modules.diffusionmodules.util import checkpoint
+import psutil

 def exists(val):
     return val is not None
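The newly added psutil import is used further down to estimate available memory on Apple silicon, where the CUDA memory statistics queried in the CUDA branch do not exist. A minimal standalone sketch of that check (estimate_free_memory is an illustrative helper name, not a function in this file):

import psutil
import torch

def estimate_free_memory(device: torch.device) -> int:
    # On MPS the GPU shares system RAM, so free system memory is the best proxy.
    if device.type == 'mps':
        return psutil.virtual_memory().available
    # On CUDA, combine driver-reported free memory with memory torch has
    # reserved but is not actively using.
    stats = torch.cuda.memory_stats(device)
    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
    mem_free_torch = stats['reserved_bytes.all.current'] - stats['active_bytes.all.current']
    return mem_free_cuda + mem_free_torch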
@@ -171,41 +171,66 @@ class CrossAttention(nn.Module):
     def forward(self, x, context=None, mask=None):
         h = self.heads

-        q = self.to_q(x)
+        q_in = self.to_q(x)
         context = default(context, x)
-        k = self.to_k(context)
-        v = self.to_v(context)
-        device_type = x.device.type
+        k_in = self.to_k(context)
+        v_in = self.to_v(context)
+        device_type = 'mps' if x.device.type == 'mps' else 'cuda'
         del context, x

-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+        del q_in, k_in, v_in

-        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale  # (8, 4096, 40)
-        del q, k
+        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)

-        if exists(mask):
-            mask = rearrange(mask, 'b ... -> b (...)')
-            max_neg_value = -torch.finfo(sim.dtype).max
-            mask = repeat(mask, 'b j -> (b h) () j', h=h)
-            sim.masked_fill_(~mask, max_neg_value)
-            del mask
-
-        if device_type == 'mps': #special case for M1 - disable neonsecret optimization
-            sim = sim.softmax(dim=-1)
+        if device_type == 'mps':
+            mem_free_total = psutil.virtual_memory().available
         else:
-            sim[4:] = sim[4:].softmax(dim=-1)
-            sim[:4] = sim[:4].softmax(dim=-1)
+            stats = torch.cuda.memory_stats(q.device)
+            mem_active = stats['active_bytes.all.current']
+            mem_reserved = stats['reserved_bytes.all.current']
+            mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+            mem_free_torch = mem_reserved - mem_active
+            mem_free_total = mem_free_cuda + mem_free_torch

-        sim = einsum('b i j, b j d -> b i d', sim, v)
-        sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h)
-        return self.to_out(sim)
+        gb = 1024 ** 3
+        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
+        mem_required = tensor_size * 2.5
+        steps = 1
+
+        if mem_required > mem_free_total:
+            steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
+            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
+            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
+
+        if steps > 64:
+            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
+            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
+                               f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
+
+        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+        for i in range(0, q.shape[1], slice_size):
+            end = i + slice_size
+            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
+            s2 = s1.softmax(dim=-1)
+            del s1
+
+            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+            del s2
+
+        del q, k, v
+
+        r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
+        del r1
+
+        return self.to_out(r2)


 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
         super().__init__()
-        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head,
-                                    dropout=dropout)  # is a self-attention
+        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
         self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
         self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                     heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
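The core of the change above is that the full attention matrix is never materialised: the query length is split into a power-of-two number of slices sized to the estimated free memory, and each slice's softmax(QK^T) @ V product is written into a preallocated result tensor. A standalone sketch of that slicing strategy (function and argument names here are illustrative, not the repository's API):

import math
import torch

def sliced_attention(q, k, v, scale, mem_free_total):
    # q, k, v: (batch*heads, seq_len, dim_head); mem_free_total: free bytes.
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

    # Bytes needed for the full attention matrix in float32, with the same
    # 2.5x headroom factor used in the commit above.
    mem_required = q.shape[0] * q.shape[1] * k.shape[1] * 4 * 2.5

    steps = 1
    if mem_required > mem_free_total:
        steps = 2 ** math.ceil(math.log2(mem_required / mem_free_total))

    # Slice the query dimension so each partial attention matrix fits in memory.
    slice_size = q.shape[1] // steps if q.shape[1] % steps == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        attn = torch.einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
        r[:, i:end] = torch.einsum('b i j, b j d -> b i d', attn.softmax(dim=-1), v)
    return r

# Example: 8 heads, 4096 tokens, 40-dim heads, pretending only 1 GiB is free.
q = k = v = torch.randn(8, 4096, 40)
out = sliced_attention(q, k, v, scale=40 ** -0.5, mem_free_total=1 * 1024 ** 3)
print(out.shape)   # torch.Size([8, 4096, 40])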
@@ -233,7 +258,6 @@ class SpatialTransformer(nn.Module):
     Then apply standard transformer action.
     Finally, reshape to image
     """
-
     def __init__(self, in_channels, n_heads, d_head,
                  depth=1, dropout=0., context_dim=None):
         super().__init__()

File diff suppressed because it is too large