From 10db192cc4be66b3cebbdaa48a1806807578b56f Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Fri, 9 Sep 2022 09:26:10 -0400
Subject: [PATCH 1/2] changes to dogettx optimizations to run on m1

* Author @any-winter-4079
* Author @dogettx
Thanks to many individuals who contributed time and hardware to benchmarking and debugging these changes.
---
 ldm/generate.py                       |  27 +-
 ldm/modules/attention.py              |  94 ++-
 ldm/modules/diffusionmodules/model.py | 935 +++++++++++---------
 3 files changed, 482 insertions(+), 574 deletions(-)

diff --git a/ldm/generate.py b/ldm/generate.py
index ce54b04eba..3a81087d5f 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -35,17 +35,7 @@ Example Usage:
 from ldm.generate import Generate
 
 # Create an object with default values
-gr = Generate(model = // models/ldm/stable-diffusion-v1/model.ckpt
-              config = // configs/stable-diffusion/v1-inference.yaml
-              iterations = // how many times to run the sampling (1)
-              steps = // 50
-              seed = // current system time
-              sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
-              grid = // false
-              width = // image width, multiple of 64 (512)
-              height = // image height, multiple of 64 (512)
-              cfg_scale = // condition-free guidance scale (7.5)
-              )
+gr = Generate()
 
 # do the slow model initialization
 gr.load_model()
@@ -86,6 +76,21 @@ for row in results:
 Note that the old txt2img() and img2img() calls are deprecated but will still work.
+
+The full list of arguments to Generate() is:
+gr = Generate(
+    weights = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt')
+    config = path to model configuration ('configs/stable-diffusion/v1-inference.yaml')
+    iterations = // how many times to run the sampling (1)
+    steps = // 50
+    seed = // current system time
+    sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
+    grid = // false
+    width = // image width, multiple of 64 (512)
+    height = // image height, multiple of 64 (512)
+    cfg_scale = // condition-free guidance scale (7.5)
+    )
+
 """
diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
index c2f905688c..0dd957b407 100644
--- a/ldm/modules/attention.py
+++ b/ldm/modules/attention.py
@@ -1,20 +1,20 @@
-import math
 from inspect import isfunction
-
+import math
 import torch
 import torch.nn.functional as F
-from einops import rearrange, repeat
 from torch import nn, einsum
+from einops import rearrange, repeat
 
 from ldm.modules.diffusionmodules.util import checkpoint
+import psutil
 
 
 def exists(val):
     return val is not None
 
 
 def uniq(arr):
-    return {el: True for el in arr}.keys()
+    return{el: True for el in arr}.keys()
 
 
 def default(val, d):
@@ -83,14 +83,14 @@ class LinearAttention(nn.Module):
         super().__init__()
         self.heads = heads
         hidden_dim = dim_head * heads
-        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
         self.to_out = nn.Conv2d(hidden_dim, dim, 1)
 
     def forward(self, x):
         b, c, h, w = x.shape
         qkv = self.to_qkv(x)
-        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
-        k = k.softmax(dim=-1)
+        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
+        k = k.softmax(dim=-1)
         context = torch.einsum('bhdn,bhen->bhde', k, v)
         out = torch.einsum('bhde,bhdn->bhen', context, q)
         out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
@@ -132,12 +132,12 @@ class SpatialSelfAttention(nn.Module):
v = self.v(h_) # compute attention - b, c, h, w = q.shape + b,c,h,w = q.shape q = rearrange(q, 'b c h w -> b (h w) c') k = rearrange(k, 'b c h w -> b c (h w)') w_ = torch.einsum('bij,bjk->bik', q, k) - w_ = w_ * (int(c) ** (-0.5)) + w_ = w_ * (int(c)**(-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values @@ -147,7 +147,7 @@ class SpatialSelfAttention(nn.Module): h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) h_ = self.proj_out(h_) - return x + h_ + return x+h_ class CrossAttention(nn.Module): @@ -171,41 +171,66 @@ class CrossAttention(nn.Module): def forward(self, x, context=None, mask=None): h = self.heads - q = self.to_q(x) + q_in = self.to_q(x) context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) - device_type = x.device.type + k_in = self.to_k(context) + v_in = self.to_v(context) + device_type = 'mps' if x.device.type == 'mps' else 'cuda' del context, x - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale # (8, 4096, 40) - del q, k + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) - del mask - - if device_type == 'mps': #special case for M1 - disable neonsecret optimization - sim = sim.softmax(dim=-1) + if device_type == 'mps': + mem_free_total = psutil.virtual_memory().available else: - sim[4:] = sim[4:].softmax(dim=-1) - sim[:4] = sim[:4].softmax(dim=-1) + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch - sim = einsum('b i j, b j d -> b i d', sim, v) - sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h) - return self.to_out(sim) + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4 + mem_required = tensor_size * 2.5 + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). 
' + f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale + + s2 = s1.softmax(dim=-1) + del s1 + + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + + del q, k, v + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): super().__init__() - self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, - dropout=dropout) # is a self-attention + self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none @@ -233,7 +258,6 @@ class SpatialTransformer(nn.Module): Then apply standard transformer action. Finally, reshape to image """ - def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None): super().__init__() @@ -249,7 +273,7 @@ class SpatialTransformer(nn.Module): self.transformer_blocks = nn.ModuleList( [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) - for d in range(depth)] + for d in range(depth)] ) self.proj_out = zero_module(nn.Conv2d(inner_dim, diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index cd79e37565..a6cefc82ad 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -1,4 +1,5 @@ # pytorch_diffusion + derived encoder decoder +import gc import math import torch import torch.nn as nn @@ -8,6 +9,7 @@ from einops import rearrange from ldm.util import instantiate_from_config from ldm.modules.attention import LinearAttention +import psutil def get_timestep_embedding(timesteps, embedding_dim): """ @@ -26,19 +28,17 @@ def get_timestep_embedding(timesteps, embedding_dim): emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + emb = torch.nn.functional.pad(emb, (0,1,0,0)) return emb def nonlinearity(x): # swish - return x * torch.sigmoid(x) + return x*torch.sigmoid(x) def Normalize(in_channels, num_groups=32): - return torch.nn.GroupNorm( - num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True - ) + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) class Upsample(nn.Module): @@ -46,14 +46,14 @@ class Upsample(nn.Module): super().__init__() self.with_conv = with_conv if self.with_conv: - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=1, padding=1 - ) + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) def forward(self, x): - x = torch.nn.functional.interpolate( - x, scale_factor=2.0, mode='nearest' - ) + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") if self.with_conv: x = self.conv(x) return x @@ -65,14 +65,16 @@ class Downsample(nn.Module): self.with_conv = with_conv if self.with_conv: # no asymmetric padding in torch conv, 
must do it ourselves - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=2, padding=0 - ) + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) def forward(self, x): if self.with_conv: - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode='constant', value=0) + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) x = self.conv(x) else: x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) @@ -80,15 +82,8 @@ class Downsample(nn.Module): class ResnetBlock(nn.Module): - def __init__( - self, - *, - in_channels, - out_channels=None, - conv_shortcut=False, - dropout, - temb_channels=512, - ): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels @@ -96,47 +91,60 @@ class ResnetBlock(nn.Module): self.use_conv_shortcut = conv_shortcut self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) self.norm2 = Normalize(out_channels) self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d( - out_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d( - in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1, - ) + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) else: - self.nin_shortcut = torch.nn.Conv2d( - in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=0, - ) + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) + h1 = x + h2 = self.norm1(h1) + del h1 + + h3 = nonlinearity(h2) + del h2 + + h4 = self.conv1(h3) + del h3 if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None] - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) + h5 = self.norm2(h4) + del h4 + + h6 = nonlinearity(h5) + del h5 + + h7 = self.dropout(h6) + del h6 + + h8 = self.conv2(h7) + del h7 if self.in_channels != self.out_channels: if self.use_conv_shortcut: @@ -144,12 +152,10 @@ class ResnetBlock(nn.Module): else: x = self.nin_shortcut(x) - return x + h - + return x + h8 class LinAttnBlock(LinearAttention): """to match AttnBlock usage""" - def __init__(self, in_channels): super().__init__(dim=in_channels, heads=1, dim_head=in_channels) @@ -160,87 +166,116 @@ class AttnBlock(nn.Module): self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - 
) - self.proj_out = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + def forward(self, x): h_ = x h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) + q1 = self.q(h_) + k1 = self.k(h_) v = self.v(h_) # compute attention - b, c, h, w = q.shape - q = q.reshape(b, c, h * w) - q = q.permute(0, 2, 1) # b,hw,c - k = k.reshape(b, c, h * w) # b,c,hw - w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c) ** (-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) + b, c, h, w = q1.shape - # attend to values - v = v.reshape(b, c, h * w) - w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm( - v, w_ - ) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b, c, h, w) + q2 = q1.reshape(b, c, h*w) + del q1 - h_ = self.proj_out(h_) + q = q2.permute(0, 2, 1) # b,hw,c + del q2 - return x + h_ + k = k1.reshape(b, c, h*w) # b,c,hw + del k1 + + h_ = torch.zeros_like(k, device=q.device) + + device_type = 'mps' if q.device.type == 'mps' else 'cuda' + + if device_type == 'mps': + mem_free_total = psutil.virtual_memory().available + else: + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4 + mem_required = tensor_size * 2.5 + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + + w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w2 = w1 * (int(c)**(-0.5)) + del w1 + w3 = torch.nn.functional.softmax(w2, dim=2) + del w2 + + # attend to values + v1 = v.reshape(b, c, h*w) + w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + del w3 + + h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + del v1, w4 + + h2 = h_.reshape(b, c, h, w) + del h_ + + h3 = self.proj_out(h2) + del h2 + + h3 += x + + return h3 -def make_attn(in_channels, attn_type='vanilla'): - assert attn_type in [ - 'vanilla', - 'linear', - 'none', - ], f'attn_type {attn_type} unknown' - print( - f"making attention of type '{attn_type}' with {in_channels} in_channels" - ) - if attn_type == 'vanilla': +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": return AttnBlock(in_channels) - elif attn_type == 'none': + elif attn_type == "none": return nn.Identity(in_channels) else: return LinAttnBlock(in_channels) class Model(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - 
resamp_with_conv=True, - in_channels, - resolution, - use_timestep=True, - use_linear_attn=False, - attn_type='vanilla', - ): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): super().__init__() - if use_linear_attn: - attn_type = 'linear' + if use_linear_attn: attn_type = "linear" self.ch = ch - self.temb_ch = self.ch * 4 + self.temb_ch = self.ch*4 self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution @@ -250,80 +285,70 @@ class Model(nn.Module): if self.use_timestep: # timestep embedding self.temb = nn.Module() - self.temb.dense = nn.ModuleList( - [ - torch.nn.Linear(self.ch, self.temb_ch), - torch.nn.Linear(self.temb_ch, self.temb_ch), - ] - ) + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) # downsampling - self.conv_in = torch.nn.Conv2d( - in_channels, self.ch, kernel_size=3, stride=1, padding=1 - ) + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) + in_ch_mult = (1,)+tuple(ch_mult) self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) down = nn.Module() down.block = block down.attn = attn - if i_level != self.num_resolutions - 1: + if i_level != self.num_resolutions-1: down.downsample = Downsample(block_in, resamp_with_conv) curr_res = curr_res // 2 self.down.append(down) # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - skip_in = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): if i_block == self.num_res_blocks: - skip_in = ch * in_ch_mult[i_level] - block.append( - ResnetBlock( - in_channels=block_in + skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + 
out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) @@ -333,16 +358,18 @@ class Model(nn.Module): if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order + self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, out_ch, kernel_size=3, stride=1, padding=1 - ) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) def forward(self, x, t=None, context=None): - # assert x.shape[2] == x.shape[3] == self.resolution + #assert x.shape[2] == x.shape[3] == self.resolution if context is not None: # assume aligned context, cat along channel axis x = torch.cat((x, context), dim=1) @@ -364,7 +391,7 @@ class Model(nn.Module): if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) - if i_level != self.num_resolutions - 1: + if i_level != self.num_resolutions-1: hs.append(self.down[i_level].downsample(hs[-1])) # middle @@ -375,10 +402,9 @@ class Model(nn.Module): # upsampling for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): + for i_block in range(self.num_res_blocks+1): h = self.up[i_level].block[i_block]( - torch.cat([h, hs.pop()], dim=1), temb - ) + torch.cat([h, hs.pop()], dim=1), temb) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h) if i_level != 0: @@ -395,27 +421,12 @@ class Model(nn.Module): class Encoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - double_z=True, - use_linear_attn=False, - attn_type='vanilla', - **ignore_kwargs, - ): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): super().__init__() - if use_linear_attn: - attn_type = 'linear' + if use_linear_attn: attn_type = "linear" self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) @@ -424,64 +435,56 @@ class Encoder(nn.Module): self.in_channels = in_channels # downsampling - self.conv_in = torch.nn.Conv2d( - in_channels, self.ch, kernel_size=3, stride=1, padding=1 - ) + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) + in_ch_mult = (1,)+tuple(ch_mult) self.in_ch_mult = in_ch_mult self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) down = nn.Module() down.block = block down.attn = attn - if i_level != 
self.num_resolutions - 1: + if i_level != self.num_resolutions-1: down.downsample = Downsample(block_in, resamp_with_conv) curr_res = curr_res // 2 self.down.append(down) # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, - 2 * z_channels if double_z else z_channels, - kernel_size=3, - stride=1, - padding=1, - ) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) def forward(self, x): # timestep embedding @@ -495,7 +498,7 @@ class Encoder(nn.Module): if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) - if i_level != self.num_resolutions - 1: + if i_level != self.num_resolutions-1: hs.append(self.down[i_level].downsample(hs[-1])) # middle @@ -512,28 +515,12 @@ class Encoder(nn.Module): class Decoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - give_pre_end=False, - tanh_out=False, - use_linear_attn=False, - attn_type='vanilla', - **ignorekwargs, - ): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): super().__init__() - if use_linear_attn: - attn_type = 'linear' + if use_linear_attn: attn_type = "linear" self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) @@ -544,52 +531,43 @@ class Decoder(nn.Module): self.tanh_out = tanh_out # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,) + tuple(ch_mult) - block_in = ch * ch_mult[self.num_resolutions - 1] - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, z_channels, curr_res, curr_res) - print( - 'Working with z of shape {} = {} dimensions.'.format( - self.z_shape, np.prod(self.z_shape) - ) - ) + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) # z to block_in - self.conv_in = torch.nn.Conv2d( - z_channels, block_in, kernel_size=3, stride=1, padding=1 - ) + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 
= ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) @@ -599,87 +577,103 @@ class Decoder(nn.Module): if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order + self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, out_ch, kernel_size=3, stride=1, padding=1 - ) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) def forward(self, z): - # assert z.shape[1:] == self.z_shape[1:] + #assert z.shape[1:] == self.z_shape[1:] self.last_z_shape = z.shape # timestep embedding temb = None # z to block_in - h = self.conv_in(z) + h1 = self.conv_in(z) # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) + h2 = self.mid.block_1(h1, temb) + del h1 + + h3 = self.mid.attn_1(h2) + del h2 + + h = self.mid.block_2(h3, temb) + del h3 + + # prepare for up sampling + device_type = 'mps' if h.device.type == 'mps' else 'cuda' + gc.collect() + if device_type == 'cuda': + torch.cuda.empty_cache() # upsampling for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): + for i_block in range(self.num_res_blocks+1): h = self.up[i_level].block[i_block](h, temb) if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) + t = h + h = self.up[i_level].attn[i_block](t) + del t + if i_level != 0: - h = self.up[i_level].upsample(h) + t = h + h = self.up[i_level].upsample(t) + del t # end if self.give_pre_end: return h - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) + h1 = self.norm_out(h) + del h + + h2 = nonlinearity(h1) + del h1 + + h = self.conv_out(h2) + del h2 + if self.tanh_out: - h = torch.tanh(h) + t = h + h = torch.tanh(t) + del t + return h class SimpleDecoder(nn.Module): def __init__(self, in_channels, out_channels, *args, **kwargs): super().__init__() - self.model = nn.ModuleList( - [ - nn.Conv2d(in_channels, in_channels, 1), - ResnetBlock( - in_channels=in_channels, - out_channels=2 * in_channels, - temb_channels=0, - dropout=0.0, - ), - ResnetBlock( - in_channels=2 * in_channels, - out_channels=4 * in_channels, - temb_channels=0, - dropout=0.0, - ), - ResnetBlock( - in_channels=4 * in_channels, - out_channels=2 * in_channels, - temb_channels=0, - dropout=0.0, - ), - nn.Conv2d(2 * in_channels, in_channels, 1), - Upsample(in_channels, with_conv=True), - ] - ) + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, 
dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) # end self.norm_out = Normalize(in_channels) - self.conv_out = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) def forward(self, x): for i, layer in enumerate(self.model): - if i in [1, 2, 3]: + if i in [1,2,3]: x = layer(x, None) else: x = layer(x) @@ -691,16 +685,8 @@ class SimpleDecoder(nn.Module): class UpsampleDecoder(nn.Module): - def __init__( - self, - in_channels, - out_channels, - ch, - num_res_blocks, - resolution, - ch_mult=(2, 2), - dropout=0.0, - ): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): super().__init__() # upsampling self.temb_ch = 0 @@ -714,14 +700,10 @@ class UpsampleDecoder(nn.Module): res_block = [] block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks + 1): - res_block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) block_in = block_out self.res_blocks.append(nn.ModuleList(res_block)) if i_level != self.num_resolutions - 1: @@ -730,9 +712,11 @@ class UpsampleDecoder(nn.Module): # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, out_channels, kernel_size=3, stride=1, padding=1 - ) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) def forward(self, x): # upsampling @@ -749,56 +733,35 @@ class UpsampleDecoder(nn.Module): class LatentRescaler(nn.Module): - def __init__( - self, factor, in_channels, mid_channels, out_channels, depth=2 - ): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): super().__init__() # residual block, interpolate, residual block self.factor = factor - self.conv_in = nn.Conv2d( - in_channels, mid_channels, kernel_size=3, stride=1, padding=1 - ) - self.res_block1 = nn.ModuleList( - [ - ResnetBlock( - in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0, - ) - for _ in range(depth) - ] - ) + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) self.attn = AttnBlock(mid_channels) - self.res_block2 = nn.ModuleList( - [ - ResnetBlock( - in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0, - ) - for _ in range(depth) - ] - ) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) - self.conv_out = nn.Conv2d( - mid_channels, - out_channels, - kernel_size=1, - ) + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) def forward(self, x): x = self.conv_in(x) for block in self.res_block1: x = block(x, None) - x = torch.nn.functional.interpolate( - x, - size=( - int(round(x.shape[2] * self.factor)), - 
int(round(x.shape[3] * self.factor)), - ), - ) + x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) x = self.attn(x) for block in self.res_block2: x = block(x, None) @@ -807,42 +770,17 @@ class LatentRescaler(nn.Module): class MergedRescaleEncoder(nn.Module): - def __init__( - self, - in_channels, - ch, - resolution, - out_ch, - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - ch_mult=(1, 2, 4, 8), - rescale_factor=1.0, - rescale_module_depth=1, - ): + def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): super().__init__() intermediate_chn = ch * ch_mult[-1] - self.encoder = Encoder( - in_channels=in_channels, - num_res_blocks=num_res_blocks, - ch=ch, - ch_mult=ch_mult, - z_channels=intermediate_chn, - double_z=False, - resolution=resolution, - attn_resolutions=attn_resolutions, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - out_ch=None, - ) - self.rescaler = LatentRescaler( - factor=rescale_factor, - in_channels=intermediate_chn, - mid_channels=intermediate_chn, - out_channels=out_ch, - depth=rescale_module_depth, - ) + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, resolution=resolution, + attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) def forward(self, x): x = self.encoder(x) @@ -851,41 +789,15 @@ class MergedRescaleEncoder(nn.Module): class MergedRescaleDecoder(nn.Module): - def __init__( - self, - z_channels, - out_ch, - resolution, - num_res_blocks, - attn_resolutions, - ch, - ch_mult=(1, 2, 4, 8), - dropout=0.0, - resamp_with_conv=True, - rescale_factor=1.0, - rescale_module_depth=1, - ): + def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): super().__init__() - tmp_chn = z_channels * ch_mult[-1] - self.decoder = Decoder( - out_ch=out_ch, - z_channels=tmp_chn, - attn_resolutions=attn_resolutions, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - in_channels=None, - num_res_blocks=num_res_blocks, - ch_mult=ch_mult, - resolution=resolution, - ch=ch, - ) - self.rescaler = LatentRescaler( - factor=rescale_factor, - in_channels=z_channels, - mid_channels=tmp_chn, - out_channels=tmp_chn, - depth=rescale_module_depth, - ) + tmp_chn = z_channels*ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, resolution=resolution, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) def forward(self, x): x = self.rescaler(x) @@ -894,32 +806,17 @@ class MergedRescaleDecoder(nn.Module): class Upsampler(nn.Module): - def __init__( - self, in_size, out_size, in_channels, out_channels, ch_mult=2 - ): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): super().__init__() assert out_size >= 
in_size - num_blocks = int(np.log2(out_size // in_size)) + 1 - factor_up = 1.0 + (out_size % in_size) - print( - f'Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}' - ) - self.rescaler = LatentRescaler( - factor=factor_up, - in_channels=in_channels, - mid_channels=2 * in_channels, - out_channels=in_channels, - ) - self.decoder = Decoder( - out_ch=out_channels, - resolution=out_size, - z_channels=in_channels, - num_res_blocks=2, - attn_resolutions=[], - in_channels=None, - ch=in_channels, - ch_mult=[ch_mult for _ in range(num_blocks)], - ) + num_blocks = int(np.log2(out_size//in_size))+1 + factor_up = 1.+ (out_size % in_size) + print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, + attn_resolutions=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) def forward(self, x): x = self.rescaler(x) @@ -928,55 +825,42 @@ class Upsampler(nn.Module): class Resize(nn.Module): - def __init__(self, in_channels=None, learned=False, mode='bilinear'): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): super().__init__() self.with_conv = learned self.mode = mode if self.with_conv: - print( - f'Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode' - ) + print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") raise NotImplementedError() assert in_channels is not None # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=4, stride=2, padding=1 - ) + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) def forward(self, x, scale_factor=1.0): - if scale_factor == 1.0: + if scale_factor==1.0: return x else: - x = torch.nn.functional.interpolate( - x, - mode=self.mode, - align_corners=False, - scale_factor=scale_factor, - ) + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) return x - class FirstStagePostProcessor(nn.Module): - def __init__( - self, - ch_mult: list, - in_channels, - pretrained_model: nn.Module = None, - reshape=False, - n_channels=None, - dropout=0.0, - pretrained_config=None, - ): + + def __init__(self, ch_mult:list, in_channels, + pretrained_model:nn.Module=None, + reshape=False, + n_channels=None, + dropout=0., + pretrained_config=None): super().__init__() if pretrained_config is None: - assert ( - pretrained_model is not None - ), 'Either "pretrained_model" or "pretrained_config" must not be None' + assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' self.pretrained_model = pretrained_model else: - assert ( - pretrained_config is not None - ), 'Either "pretrained_model" or "pretrained_config" must not be None' + assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' self.instantiate_pretrained(pretrained_config) self.do_reshape = reshape @@ -984,28 +868,22 @@ class FirstStagePostProcessor(nn.Module): if n_channels is None: n_channels = self.pretrained_model.encoder.ch - self.proj_norm = Normalize(in_channels, 
num_groups=in_channels // 2) - self.proj = nn.Conv2d( - in_channels, n_channels, kernel_size=3, stride=1, padding=1 - ) + self.proj_norm = Normalize(in_channels,num_groups=in_channels//2) + self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, + stride=1,padding=1) blocks = [] downs = [] ch_in = n_channels for m in ch_mult: - blocks.append( - ResnetBlock( - in_channels=ch_in, - out_channels=m * n_channels, - dropout=dropout, - ) - ) + blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) ch_in = m * n_channels downs.append(Downsample(ch_in, with_conv=False)) self.model = nn.ModuleList(blocks) self.downsampler = nn.ModuleList(downs) + def instantiate_pretrained(self, config): model = instantiate_from_config(config) self.pretrained_model = model.eval() @@ -1013,23 +891,24 @@ class FirstStagePostProcessor(nn.Module): for param in self.pretrained_model.parameters(): param.requires_grad = False + @torch.no_grad() - def encode_with_pretrained(self, x): + def encode_with_pretrained(self,x): c = self.pretrained_model.encode(x) if isinstance(c, DiagonalGaussianDistribution): c = c.mode() - return c + return c - def forward(self, x): + def forward(self,x): z_fs = self.encode_with_pretrained(x) z = self.proj_norm(z_fs) z = self.proj(z) z = nonlinearity(z) - for submodel, downmodel in zip(self.model, self.downsampler): - z = submodel(z, temb=None) + for submodel, downmodel in zip(self.model,self.downsampler): + z = submodel(z,temb=None) z = downmodel(z) if self.do_reshape: - z = rearrange(z, 'b c h w -> b (h w) c') + z = rearrange(z,'b c h w -> b (h w) c') return z From 75f633cda887d7bfcca3ef529d25c52461e11d99 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 9 Sep 2022 12:03:45 -0400 Subject: [PATCH 2/2] re-add new logo --- static/logo.png | Bin 0 -> 22220 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 static/logo.png diff --git a/static/logo.png b/static/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..fa0548ff78c7171d850ee6451b6a36f1559f5f3e GIT binary patch literal 22220 zcmc$_WmKF&(w~M)@ zJ;a070%Bw9Bu;hc-_wSD932GRCa)3Nlmw-m9Wmz1Ct^A-gJ zI6^$kX}ujCoZLmd#p(YcR}^^t{4qB@?LV4$*o)IkKMzO?R@I<=?d%4j72@LMwB+I8 zp%oJ0f>@dJS@2tP(DL!{@pAJ5KOjzCK2d%iQ4k;Pzy8q!!?{^mi)zWr|7$Sdoj5(z z!^1_Co7>CFi_43j%h}C_n^#0cgqw$tn~#qZXu;|3K29DH;75l?M3fiMnh-mnJ|3QbHdO@-!OGmj{J+=O z*~-@1=fBfbRaI2c$=$=;$r7R{D^3qgip$p4N>l(MA|z-bV!;XG;}rz{fOt7AM0l+@ zg$4M{t*nJbtOTqC{xL70zpS&R=kxYG_y4EeS~*(+GXAqeqUM4S9&-ypOHN@SAxlnv zOOO?(1;`4-DIx^2;^!0K1pzPqElu6c7C1WQ4*#pv=UG_+68S9!tU=ZSe4IiemX@4W zf*^iQ5fMv4PHSFaAwGz(xw(jdFg>l6rKq*Do1-~kC$^5}HV|$XCmVX&|2do2&JNCQ zs?Juxp7TFjGawyU#T#39Ky^OnAj zZL1~J+{p$4cusEm=f!Y8yO@6}h5LWk*FV1dU)<5Zn*%=X`Q?ASEAZw&{u$x~h;{?K zugXS?KOCG!v7)TBj`x>?uRiHax;{^m-&32f+SJs&MkKHhyQ_ZDV+`aLG3K(YwOLr; z3oze~_b?*Rr$gz}NmXS#k#vaR;o%_)v}K4`-F$a#=e~EBi)kh9fk6|ksv|1uS1RgW z%YNTSwta3xb_|U@OPnSfmd>)gUg#U$(n8FrUHDuGtKguhTL?o^#v&;BqtISPcr3W) zU}`h=Dr#@)8R7A;YU{tkfh!Ix0p2p6VzkzOXUOBMPbc80I$SONFENfOT}Y1cvKm#{ z$2q|^2-)yU5jJwRpbUohma&rA@VOli9o)f^D%0?BwsS|Yj`0r5j_;!u1*4QZAOsqe zK-@AEi9yicL7k&xdNKYJ=9LN7r3U%DVw`xy!1(R;+uu5e-X%zTq+*cPg6k*pIQd$* zMx1K~mpPq%7dE6lnv=4UvqCdd^0=*E>sTwSD}L{r^SWP@5wDo9#B@`I%jbP2x|cK~ zs&c7$m0o7nitUXpY+r*q1qIb65`^f7T&Zs*ykr2=AX39HPbZ+DjMq$X3poBa8c)X> z$a~zu6u1|(z@Ypnl1Ps`8jmX)NPCza3Anbx&mT%GqqKQeQP7l`1t4Ab=SSv^LP6{9 zBFEFHwX2V7L4$;JMPFdC^FvC@jh%s6f}fYtBCcODF{BQ#8SUbmHWJ zuFNZ|ClZwERgXB%qD6otewf9|GVb|m2vRBL_2CaJ1)xpuwz9Xt+2y)asOR7hY`N!N 
zTlLj}BB^FyM+hN$_8{^hMgz3_*|v5{Q~j_<4FVToexpqoEhf9(5XV^4bPJ)`sqbTaowJ`6URSJEVUS>ju7F;uU<} zAvbs8EJ=IWs-=)Lq33y~M0V``5}*guur2>n+7ha>N1#Nj|N0*&2vr6x#GI5gG6WAr zrsYLyezJ^GbTXF`YqA!Ea_5ITxp$}ls_@SiOF#}>q3zn)&B+}i(0ZNl*Y?@QnF!^} z_sLGK$|1HHGxe&pcs=!wD&t+Qlx|HN5i-=Q5Q7m&j4iTS&9QG}G`b+5)bU^~jKLbT z_Y6RDwiF8}^kZj8AaEpj#3&iQs(l*V1jII0V6uecfm0(K@6b@f!ew5NQ)ECL?Y~!7 zf4;p%xqpP4b%^v3jYxLm-YZiQ8A6C>yJ7)l%tmU(NYTDgRps3ANGvxLsWeRs4Z z4HA9BNA@yp(2^9T(AFZu+Vw5V@jhlmG6t)$(KIIp8B&JOg_H8p`W%89F7?U~4|_X|AM`rTmU?XRQK8iR^3_ z;ZWNQ?}`LB)1V_UwN*3qgz95+qAN}^$*XN*4U}AET*Lg`MUSvx+mCINyAb_D7ob10u(9LGq z*0?r?dW5(npsk>zWMLq9ae;k)aZ%-SCZDKygi_vx3Dy&PSxQfc!5{P}Bf9mDmooO# z$aZf_D$@%n6iT0>RAt`%HZn3&z7TU0KPs_Xm=tYna&f5T5SdY1k?-^G7w+oGhj+cb z*14nbCcZm9dX*-y!^P&~8-H1P0dT6oX8f59ojUys1zK<;6#eBF7rs8~X8gfZUJwZE z$Q`-O@y)U@zA(9~{$WCvUhLtV;Qr)9AqpEi&RoGV1og1syNyrelgu=6`fY7 zB5QG_a}P*0%V$e5%m36I*iJ=8t=IPU6dgaxPiN~B3JZzzlO1MjqA&KQPuF5cPU;)J z4r(e;y%v*D5&c?{YQF%5=j8YNF&1|5ivjnzd3jOui@ng$#A*UK94}vsHPCOnoH5aL z>ERHQg%Bf8FS$j?Ut1fqvh9>l=<(DyoQLgj4wLSeb*-4oQ$UZ!YyTBQn$DvuZ5F}- z@gIsg$z34Tw}$gGb@r%N%TKRN0z9!pq=GLm<3BlCTTo^}7NfLc;xead9p|mpL^e6U z1YG)n1nThOa8_Dt+2)`{bu5uB>OxAalJsTvz+l)f~( z2Jwk2wK;7l^q+Vn+7tSy{h}@}ahA>twaF+cNvpre|?$Bh6_G zHee*#3QBT3q_Sv=N8fVu)c!R7k<#6c{oz0=h=t2mf2q6$3acbdI_tt5aXTOFGcypZ z*n?Zz9i^f1i^|E7o|_Zp!B!mkmfQPM*!nX(4&$gfjaDD2;GEw`t(i!acH}7v>@?10pYIZ%??qY|0;3#OkI^QXC&|M$H~c&`NA8 zvUsEnGksA0SfGp@m3A~oQBO+{&fcEAzH7Sm7b<}Ez8f%{{n^^gyab7QJ0xYLw7H)s zSagjyg-?)4%#lA66QRpTl|@p+iBLhIhL3N0nl@7@-Y^riJRBS`5>wD;kLX!kht)7I zS=#}p$cXFx-cPwC(lN5@^X!=)S#}-vToZcvCHh=innB6^rp8w84~xi@(`Sf^JBcIP zysAf${8(W!ot>V>9D{cQ?=4~~{U@w%ZipRN&g%4I-ZXSB2Tp9wy~40M!)PD5E6PY+ zE=Zo3n6N~w0AF*J`rBefmhdLIU2h>NARU zoL>mIIv_mtSvna$BO@m-d82IKi&RQjzfwTqJx+)vEr^Q6qJuI(gQ*h`1uw@p^a~N^ z1BRxC-$TWDE{moMAr{M8HVm-kY!o|mToqqL`MZEft0dxf%@}-9Nr@h%{1MH=? 
z&PP&cu7B1^x(G^S-;T~7J*cXdf6JR(`LT{QXNYBoV20+x)T0&S=tpSyuIU_$#f>Z) zf(d+s&^G2N)>wV_08dUqaq~cP*b7n}7P(r)!a85qOIHWxx+f)k*fS69?Q=Ef8}Z;B z%;Upb@_BUv>u=+9%g4;;K z>U66ci`WtWKO?Babzqkd=+f}E$gId@n-5QI*Wc``rKgpWY9&|J?ZaZWO+2nA#gUy_ z#781Ixu_0MR#$qubd^ZlDBr(2YKZ{cCougYQAX_u#m zD?&usm+@dBtdU0Jb{LEsNk5RY$GyFH)&2?3tP>yco$>?l2vT!Y&rU|V1D2fxhta+bTWA$u*j5}bIKT^|kb!0OCw zK}oW#a>A5!U)5Yk#M_Z+NwHL~T@0Cfv|~8ujk(17M$Wt@kB0_B!|6NrxJoKXZ-Sz< zEg`{d!E;>Zd4su|%+n?mTb&g4;3?ahz}VR2EL;aMCHfYd5w64bCu*4-o=9Hwgh$Xf zWaYfij528SpUA2G_Tir|lh?1;!%IU_(Z$0O+X7*dPpVP#fmAu})eTA&k%R*)#|#V+ zHm$|RSuU*y{;|4O6UU#tZf9${nG7{lK;21NV+#udCAuRuz87p685wN!9IpY*@uMee zO-u^`*Y6&1EhfQ!q@6tV60pmO;k4gKr#t%#p@res8LP!<9K>pV5p3f+neM7|iaDN4 zK%H$eOdR6Pff}U7K@7w$XApbVyOF{=S*cfY@}B38RDk!@Cd(($h)aPM&#Z+>e>}L; zTu!XwdNC;cro;5sygXKkCEcu^=~$*k4otg?oGVSci$u?ex7GOmqDtyOi-+ZAn&$oW zojUStQ(wVKk!!$PCB@x6ikY?mcJ8)oO2FSj8fpf(z?1di#?dO5;+?+@HqD*^s zi$ABq&6h;v6n)uT=i+F#*h!E!){REXF z-zn+5m97l|JJ4r@dmZ3_u*Tho$Jv{9-_vp3XiLt&8wmTk1UU2Op@Ecgd_&go%kp6r}y`ri_nq<$WZ12&K6?pSp9Gwe5= zB`>L{eR@~>(D|29KUgx=lE^7vGfqG9 z{NCkE+%|Csi;f7JiREj>yvdJTZ1y_cEiDyU|KR@1iVuRx{FS#eQOlfk>&>VLT?Y>l zuB7QoYuZdzNASK%+!k6VN@yr*mM04#VI51;%RXB}%GBbR@rxnz^_*phB>qQP%Krp?!z0*GrxT)PZw!nV-xs?79)=FZFKn!w*_xzME<- zxY%SBGIydbudW|GiFLFPVD|{_Koa)-CyHk573VLjGb99Qa*E7fo9r)73|J$Uk8IX^>ll;J zfmo~;+w#{I(Le~zlOtY+bZa17b_$|J*45bKSlamg-(JdZf4^9+I1BXMJMH2{St!9pY@UL+V@JxpOi(*+C4{m z^YxC_vo)5{uv`L@(jBv2q(N7ViRtND|C{PE!z*$vHviOb6)1r>Qt$@UofUytf9-_O zZaY&#z5UaHh!|5h9+Lxo{Vh;(GFG4#`sh zL7yCG*m5)ETu@t^nUixWMrmn2FJaOB+F`D-4LsuLuW+}un5h0(oUOgI9j1C10<&wA zwn$i0m&n{m;3jVQdmbFZ-WT*KS)vpjot!v*pG?F5ktV67es~l0iSpctH|#DpUyIh4 zY5Cee*mObP>?&Bc`ox3!aRI3!}Jrvq6hDtF!p}LB3XTYkhRzdxIJ)GMg_xf>%p^FF83K z0mX^FWlYh6jJ?8@GhL>ho8NYsGB~2m9?{oXJkenJmCBj9L-Ut`lyVvV>o9AmbXd@J zK8=w;ebOTGC>;wH>+r2hqV66BN7j+St3$ewP-&rHUD6S3>Lbt2RBT;7nz!!8G&w3I%nFbm)9dm$Y6 z5G3Fy;I5zD#LlNY&86KQa)PU(uG*jI0jC-zkIK%RfZb95aDI8noTQe(4wmVtFoE4I zIn|yI5$~to#ik%=BBqBXJfytTQ{Kva&!Lm&%7o+}T7i>9vy{A99&?+4*iMQN5z#()2H0;uo)m(cH5Prc!m^ zUQ{N45mQcy%ZZi7Kbg_~>Vhc?2W~2|6p+SaAEYU1mZaev$9uJ8%Ec1bS`XP-R2K<} zv#i%~e|8Zd_G3aTjd-U33})z{m%v*ikh zGflO9i-n6@Yd@_<%*m$yYMiWc%8V2R$fn0Z2GwicdS${{D$EC;ABgw_iivfqt#xSe zI(SspPJ|<59)o>MBh5c};2rK}jD#Y`_FM76tcf<8Z%xXzrD<3CY8>#Rg^P7+a(gg3 z>RYR^aW2=d?)nx5Kk5D60|`HYFmvsG9tC9-)XSS?pqh5u(dUP6iAsam zZklPVuMaZm(P)YS7E&Ag4+C1zg)U=I>urC2R2qj*;UA9N=OP*1Rdex8gUr z59w%}|LFq^mGK7->*%Qo4dE4IyH9@b2zQ|sn=o-8MSkA(=4<4(JGDVSk^N>(nbG== zjc_LBD{*5|{i+uG6@eyfq)`6AC!HmGmAqmNm2t*TQAyd~rkZqu>4bgTgal}`W67~h z4A)QZ-=*!#c@#t(&;Rt`JvG2^&D9uQ2k(lR$3-xDvx?5>c2#ec<@xua2L?FgOf zjqWoyV4ihI)nSW<#7S~y;rsw%NYrSpl9>qKd(?SXqE*@nx$x;DR}Lh=-7>0IrdTrG%Zm;H=-!-F^y!7-PBQi8QWD?)k|moA)l?To+QJ4RP&hA8 zKDW?=;`#4g@Vs%Z)`<;&cu<%mbXNg#dFLH$-^jTPNiV=hV!O(WY3q*wqzrpUNP zrdPai%D>JZNlOdrz1fB#q28ug-fc&=E02 zV9m=8(#`M`ix04ITr8F;2M=@=4ft7CN&ve-kC1_|t_kIZG8N?4*T?%y^c^%TcMfv5 zN8zCawgn&fW(g~=_}cgNk5KHOxCg@dzy^$n8_!0N2};vIo#|Fy$>eS}Po zhCklm15e2dO{lHyX0o>0c7Bo?kQpTnQP0>VYvE&|pht8)#~2iXWekFZFxN%J@2ziG zt2r|xd*xGU0yr>FT!xMEnnAwMSV>Bt&Rz#XSpWy zeZWW^*?x`I2`uopGIx1yUY>iUR-3l!4&7X>jnLDbU71hR2xy^>_ZHJYkzuJ z6uarM6t0LJ*=A>!SHJYd%=hTubQ8PV*Z$%+?9+1bp*E`Ty+WdSf)xg`dKLXAw06rV!dUbF_Lb^T2NG4x*h$JdNV~| zZ+vz(&K3g$8Foo6!bzks{}%b;g){(4`d#jCb)p81z(|4+xCp+uUnS<|=8kXfwJ$C3x&YkiWd2o9m~Fi`Fj-NyL3-h5RC%?+qC;6$5Nh9szi|+hMf+JV98o(-ANsk6UI#}6J}Bj4-FzgqG6T)#9LW#JgK zI9zbjEBCd|RXssQYLGBmQO?Gdih9T_OYhL@pO@HJKh>4yda#Z3`N~``V%@SMZTXo8 zPkW=rL0R8~Q&%~qY;1j|@`Fb-1BLboB)7!1zJB(@U>R%to2sqj$ zrGkREUyrinRC(1pFLT}uC_LP+^hy6>PcLDuMhBrUomnQ~`ZI)$Tq5x&De{8|A&>VF z`C3ayiTy|bK*>mo<;&!}7SqIuqW{zmcSeGyBhuWzA^`oB9(rW-b@Q-~YA!RUurPJ0 
zjr!f;LT9d=vLd5=8sf21Z_DK_Ib~nx(OHijZK3ycrV25Yqp$M(`&%hA4fxKCG$Xxo zFI%Tz4Hg;aBWe@VJ2vCRGn)`IX~E2b^75~5@>4_C&1h)h!SGZ=C<&5B$9<@nUv6H| zl+Bxhbl_#%Hw@EIXKNHrW1HPQhI!F$Vfh0?U}6AGTXXFn;lg#Fg4J(BX-cU{q-2km zy|$uZjIrC9>r?9kR%0GsLwUu;KKY~@D*Sl@i^@5exunQt`74eu-d za{Hp_mqhzj7c7^SoP1a`rwLr<88UQ|TFX${lI!ov4#TZIsPpXK&SMa^$oz$lw!PGa zH!#gIQM8n5X0ZyJuH^VtldSlppQI#1q<{;q`IgZJ$k>M|x=fmq2$S$R?4%rZTW;1_ zO=dZ4?US$lB>-eiffz3S{NsBS!PkcKlt9|?TuIs8-NkJzhxibm7EIgtY={fyqEy!3 zLNTXVE-`+r+&*Dwn>xD;VUyiUOM5kkeEEhB`{fjFlTyCCPmgzhdiO%&)`!L+#bmTC zAe$52gB(k2{kS{R4J1VNVZ7WSf6mTkT8#cvMokl(^~C$l2(BohETV3Eb=l(9X~Lzw zt_y@!dDKmxzU80@7gGn!C`BF9n*cCe*GFM3O$wB0Q!B~Fj=O7s3i$Oa~VyC^t! z|I^_(S!VG>=RuI^=m1H&u$(#bi;shT*eD$C*|@!PE!P&Lj`9V31DNE5AMTi`yhxcp zauj~~K(0`1x8op%1{BJyPS)~l?R-NW1jCktRxAP|gEPoG(<+BgD{38giOGsM-F};d zu$AE1rl|*XjzriP7zJ2n83tuXtx5 z+$*>`p!wDeQpTcwl*`&X{5?9!`?#Br7b+_@Tmu>0+09Td`Jnv-0I$~8=>sG{Fbiej)NQNY&kC8Y~`_bW>yf^Yq z4F(FMo$+XsL>736pBCrn#y=*cXiyVV6=OY|5S=ax771>2K_5&o={!ExFetY*E|<;| zkqyUV#P9x{-Tsk?J^m$etlGjo;@@4K>s3yJihExaQLG+sbAE(mxD=hD1N{0%9-Z?J z3MpynY+1*5n$1jXBiHKg)y*Ig;U`PQET74}5aBFvh>bB32C|yy4k$E-=f?l$UjHB9 zET~u~W?b!67=TOTNB4w#EDs%Y(XE} z-XP0_2+z^-fAhY(!(X`5$&x!`J`919qJ)J#a7f;Q<6j*jMMe%Rzts8%HY>)?J`8A8 z4e3PNH_kFRgP4!|ew1~*1tY0~ysbxnUU+hvV z1Xi#`x{j8L7wW+bRKPb5Y+4EEK3MqtJ0(^I zHX#opjn8mi5hgeq4BOVF-2~jC(cpO)BC2A-t13m-_cRR)&nImoG)B}0-6!LQ`5`7fCNJC?Rf`@z_y{; z??7Tq9tHQY(-55lF8;LTBod+2I~xW08mWiqm5}n`<1$Bko6yMmh%hNJN0xIgA%WAI@m?~p!;d2>RXv_g)HKV z;E#(KPbYg~)!`Ou?mN0nKnXn0iBSL_TjG=*zS8OPwjmnQJFzLGMhp)kcWCT0r$uqk zm}URz%b{>YLG@K3fEfvd`39}WD%65#Nz3=_(K&G16IW3z0fg0UX7(+ch+S<({K$s} zsCw`h&en{y3JMmOtuURbjv2rluQnSWu6$BCc}3QdV#n2a4CZYAj9UFMVn=(A1ja~t zf%Tmbu{-AAAyHpQ`LS_e^>Z54yZuD4;O`?)6-Vz7?!uLr67kx5iE~AA@{GSnN9Cl| zp=5+*MqJ83KzeZC?tQZfRPZoIXHgPb|DM*zYJuv>S#0b|1kf;RSZC(gZJmkJBWY9A z2%=@7&e{BljXT0PpIpL$sIddM=0S>$hh!=&(kQT&|wAYQl1=rCp|Js3g1%Lx7s7j$LB^hyV(=#QyPhELXW| z)|yDXKHv2ASqf)ZJHCbz9}poMF8dM#XCIZOB`q?tFbtpbIDGJkfFsLh?ajBhmc=?!>}K6(Zsn-pSYY-b_gY;3A|j__q`-{#}2!s`&?A%2h||D2H|MA#1Va zj!F-|I`W^J+XBvV{D*!Tv?`AszURWel458eA|0W=v8e&BsdnDCI>nX7to}d%2_Ksj zNbvUet~@ZBX40W41b|w7Q$JX#&Kz9x9n;aROh!FUckF|RY$rZKe(=|3#(HdQPlRbA zVYQIqVJC;X73$v%M07|%inzLJzI0J>m7BZ~f+{x3ri&m#8bc-wnnEYv>-jaoo1}^L z(nyeyum+lP;}9}(DJp2VwZwdu<4+W6jvp)MYWzMSVBZ}mmLs&b`r)kZ+r*Pyn8Pey zQ+iJDe(_qmS@4{8oS2x>Bq+{dwvvL1I=z!-PJ)w{Y$Z0=SR+wkWRlnIBLD4e!dyUS0ei*9-7-~1Wm7} z@l-`XoD{8w0aGHfa0EiXRbp?iJlE}t!p{D_wV4^5<%c$}9i7C~RK2CUCCr;m?n)cr z@ZLtE$iOFkdpfoPqlK(DV&PokV?{Qu1dzOlp1_McG-VGGqz?xoYdIy@tWA6D?XCP_ z<2v*0-5u1)3BCwTdE$Zo9Rb_h74=v%`Fg_V^gBMpW=7Ra;}jpv=h8vNR*Iwccp zrCCsqVAV`PV*9-!Rzl8^)1u;H$WjZkho@(f+IPk8+OUi*KZXKjM@Mlb)@MG9paWRp z8AlVan%$_y}hq72I9D|C&G#hZ*$W$9ftINQSUfD-A)A?ukI2tk0jAB*s?x#!TyjsI`k zBeTvfSbCu%hoH`mkdJ|;(&(Q*QR~8DkJqoCSylPtwm{{ln&U25m%0$Kn+GVl0SLm_ z(o(0$c1tXQx)@|5url-MoX$a9gq&pw)h}ET?Hka%-P$!d_lOlh+|kD^q>o;&Ajlp2Pjx& zMn3r8vjdp{-T`0=>pz|9{1e_*figaudDNPW{{S}}>H?`}BsRrVVt+pRdj_O1I;`$= z{1becm7iSwR-q1D%9gJQ2Gc5609dC0;Gy+feX-&*z}B0D;{T)uipPk_9O+|(zB^-^ z90oNXR5Pb*ZP$GT9p|KgRF3j{@W^)G5w4lR6diiZ-cP`TI0!1|XaEIsKIbI_V$Qb% ziA;lk_rE9rYCtzA;;TmRfR~K3mS?@OS1v8VB;i;sD(HJgB!<7v&o;SJG`~AE6Q*WJ zQA!+yuzJd{srD=1X(;6;ztjO@831W|&*1Uy-T0(Hc9VPavsi#~SFDp564HioGN?w) z2_{!1V~HL}wAIvo_AMDyu6iZDjLH49cDXRqb~z#BY60eYg?3Yd8R91FVk@5NE!J zt8v4%sR%u|fM&1;udBlq1E}%V%u?=n_*Op9us(yyY*>A+E|l+8dI`RcPHFBuiMO|i zmgz3r$v#V6O-8a6L>~_+YJ(sEv`_od(1KBD2@0@^xI()>|L@)u{fx1P9D{=1Czz)P%P2(vgv-sEq+KqFB%NdccLJbN zS^7GFqKJZuI#X>)aU%P{gH2T;8|F!QM3gfM#lf7ve1V1p5G^Mr1^&#;$XZ$wXtCi( zn^X1`H^#prq62Fw0cdVrT^*p0lZ9IK8rw6J%l%nZfGe3jvi_4xV^Nh>AVVn!#6;ND 
z@k1pvgcc~pdUy`m0Lj>MHUOYt096YkLI&m$%~cA_2w)*?ZRr8--Ur~T8Psmc`TD+N z2$yCd>^iyvc#d@%wxIEFEsSZ;ouebmyUTKbdZx``)EJ(^u5Ja3S%;VAAeM_djn$h4 zxKN4!?SnG@7~qGU{$dC+8_6;T*oOc|?t@3~B)}n!#X}gZsVYMQNCFjmH2|T}B1J3mCoeA(`h#ZLbI?vK`5D>m<9n7bB>>R)`Jr}{+2a(C>R=Qg zbC0AkJ+tg)etPl;oxY@5*+^RceSC~J>rn5Y-9tvg!Dj~mgmF?BC?jMe4#rL@;Y6^fwlM?~ouT{+WbSC*U z1UOM!=j>QKN{zDcU=_7bZ`f*(KjQy6dZqRDU>IGQ_lzTo$Qf5 znnxEQAPP7eX~&yRA;h>L%cdTFMjipxvM}9TR0eedVgWXm_Ls+l)^hXUy{=&sMY>9; z9iPY+OH!(6$O!%<=~3aR4i>^kDYox3i!Tk@brT1bc6X6Du~C!`I8C&a03s7mg%vyf z^$;#&J}k1m2vFgMuV-vjs<6fx0S0Ctl~6UX&0D`eH)m%np{bYE?C(#eE8J8_@i!UHQiEgYW~lVn8QiHVkOo0TC9owS_24T*>(dN&+svurM;3thy?+llH{>1HfaY0jPVBn@go? zUaVhMM9sd8K_{JXl8+{&b4U*FG!g$2I>uT5{KXo@9v=FWxpTUmYOQs9#{ar9v21Ch z$|@wY#hNCBIHNpyW+Pari?y_21-#;s6J)tSG$wni=JNs%Z31|765W>X(<8r z=yOuAJzE1I@bvWLCyoRixUF2F0=~7(;EX;;!}|#y&6)ffhKq9j{^-3~n(Rd}h9dbG z)cOUGj@J#tmN9T2Ya(Hb%!Xb^+;0IFZN=BQJ6Q_U9((mLzZ*Z}=6loaYU(3S08I68 zuV?6`tR~vKCD}$2DIFcd%F{QS18~kwI5a4jI$6nMrn$E00X~rS&k* z3S#!S!cYyIkNWrR3D?EZr1>>>@9CN7O>TSWOj+j=wabRK33j*O1fWogG0aP{RwFRE9 zMUZaT|Br7}f=Z_tD9sT9S082&eg_DC1hAos>-E3G*5pH0qoJNfEO>$N7m+WuTovu4 z#wa{VKd_^$Cpwtq1P60MO9Y^0L_Y*bCqS7R&a|f(7Wg#-_Q(~$wTHn8Mh!#A2^NKq zV?I|j@#gBQwWrGRY3JYoi}#s^OTqdQzo|7PZ}+Ud&AoCzSSrBvDL^@Y&x4Inb_=uwR6j z@0W{eg|4&tq8JPAyELBlb9kDELC!e2gvC1)hSC50n2^ExS(H8(q*Oc#mH!j!NlShFm?|@?xkl%1QNSZFx7DPWB(ibA5yT^sU$+;y6>j$I*Uqt<%D~XkSqqCy z6&5|q96~VjZ(BZLr%!V9oHxA*0D(AGlRAK9Ta|zTVWC)(Gp=xMlz0j%Z>Z}EOIc1h z%Wf#_XNR3Ab0nV@dTYC}AXkJIz70zi4a07xT6&73dt9wUGrg-t>-w${rltNuAVU+{ z$+j9=MDoX>W;dF0`QA0O&t$l^sym;{VMQegrj1{Le83$2VU6O6%4Sb3jmG%x`*h21 zm(bBMs1nfCKzEUuk?85R8|mqG5U6lp!-%zb9W4aN^<@}og#$kg)djub$4go`{_4BW zMNrp0nOiW*GamrX^VcoS=02|AmcRL)%>*qH0*g-fq$Hhbs}dJ?fXp{#VwssQJv#79> zye(VOT%j_JIh}L&0h`fb{eaq93)HGt8R40al)OEY3GsUJOm2U&4F!g`a?1RepQ+TYw0`(z_Lr(}-`r(0djE}}{92B( zC}0gkpM+A9h-PBnMlxmJJAE*I;4*!*z_$1$hrRR5duV^{7cqyf%E%rc9jfH5Yhq9n zog!bQuysg;nYA7qE=gZR>RZHirt4dEf^&H${lGk8b9S9<9KS!mspA(|b1>-{%|5~A z^f8Q%H(5;YMuc-dni-_+AcyFELYz<0Ai(Fyv^jk+O!7b6)gvS3*)$ z+F#Fo!vLID&W8GU_ENFU>vf{Hn(IRDZgokojMA||XiBU@;0-wPkz2{>!9@JghuN;# zaI%ty4o(tJ%ZwywXC=BCe@Q7krvuJ>HJ~5oBZ?bBL4a>IKIl|dr13JyO`>?>U* zhvIx@cnxT7(ujcZ{(O#5-z3|Ng~32`Q*RPnWbqdkb3dqil|LW{zQIl>1G-xcwNO8L9J^n6Av;Gzm;(GpAaMz-Eh0 z8JJ;wiU*elnxVALH#eQS$>Q1UU6L{_Bi;A?TO^qN>wV~lK5T=)LkE^`%Ne>#MDgM= z6}$``4qZi_ob!yXhJ1(pobjp$r6I*-@9WOBY%kj4Nc280TqJkK?jgUrLyCm=q?1$C z7{!xlH2#TyYguKR)_Ig2aw(r=Nuv08K9tC>=K3N|%GC8IxEP^Agt0`wN=aAhm8yoz z!APS26^UZ5tuyi;JcU$UPg)Ed4hQ~Oq7?tXhOh6pKW?mN<1qLd>Et4GQOA6`%Qcba zbnu^D`RS)Z?bAuCA3TKu3w@_Pgd<=!4f-ayflh_MC*tj$(#cSYr6%6H;eO;QvZ_P! 
z366*&Z>g}tpd5~x3z zb9mQab`OgqYx}$IH=PZ;cn2_i0sb4+G|F8n3^$i;W8(J4KC^W$X`50D#W10-nI&@; zd7lQ(oH?QTR&K{%b~s}YcwHU6@?6RI_?5V0XC?9neXsMto}{oxeaqQ zA*Hzbi%Y-tFHT12ffVsyE5QxQ$4%@GhqvYELj##v!($|db=o(?Q&GB*V7hDAQ-0&= zPb5E!#LY8KEX&Kcx|Vb)`mY}wk%=4bO;g*r6`C$g$%`_byWV--7Rutw%1SewJM&%d zEHbzIb#-V~K`XlO_!)HaG`;(b{#1~D=$16GKN*k@Ykx`(`=i}cl3&Bg(E!!uteI(= z-CQ@NTC05YBqae?e}n4o1b05Lh*YW@S)Kgu6|)I4|9DY~Ij)^6?>l5A_Z@-h8ph2j z;TbVhj!W?)nQ}LI`qtA3S`-pN6_ZKLd))&O+%pyUrcCGmSN4=Gmk(Jvgr(ps=<|Zh z>x|lV(#_$gLAP_U*wot{b_4%>jT)|0;D99j{Bo5-mleQ>ipj~1v+F`LW|1BW)&5iF zy~dlQZfbiLmZ@^8RsBQR+>qZ{mkN!wb&F>3ClQ4&Qg-ve*6=G%sAd$3xLg@An|hU1 zb?vD|?BYCQ_RTGx>9q-#0!#c?XS(qE5**~d2$;L-snW?(d#u|PC)wZn7Tu|S69x2K zy-p)o;RVX_&$%u3C#UK_63WSfq6D3I;4w~CXQQ*XL?OXzhLS=}r{okbU*fr^RP4S* z)``b0-6*va>fpP+;2!tC+eWX){Yfdl+y$zuw5t*3@@ODE$5E7UgW*lf8gt$FZ~Vz@ zh{w?Wb%9oK7uP?8n}KCxW|C>6soOo+lcGTr=^}Nq_a571(ZL;-{fa|Wjo$a1&4Njr zER^ZQAN?z_xIL!-*~{*Q(>IZ+@?q=5-lMyd$Fw?p@tEb3^T=O0V*q*|rgl14KEhz|<{4q8FLL%n{E)uH+UcXD zC!Vc_CKdE;bh2($a#v>RZ(b>~1rP0v!Ju6O-N7vb@LWlTmJ!KrS7?DTv;L!~qUpru z%d}~_t@I*sOOGwqf^iTFTIXP?N=577M)YK8M%3D0r-5|=SVW&OzL7^Upl^ImJTNZb*|`jav#<;KD-F~5*`7;0G6dx8ujVh+Li+Llw)nv z{uxYa+ypQR#(_(LYe%$OpdY{etLLFE{p=~tSxwLEIZSvK>r1M%;`-Pk`Z1?k`N)$d-49EL#d>?f08l`}`8FY_9S9v)UnEeNjd-?@_^;(W}XpqJHw+P=7q9Dy=DeI zSN8Xrg_qY0DS}>$wtGfBe0j?RdBpxhIz7s3j z%k{J88Sq~*R`o?mTu}f(Zx$1oRA!~tF%sYP-T6gsa+l3)?b{zpO;LlfrBT^!V_TQ= zoKV?H`e9lCw6&xU?{L0*V>RY*v5Bh&A;Z#&c4V+HtCfLASE8rT{u+;Q#Y}zJrO|s>DiRABM(em=3XK*biJyz zGEIj$%rur*%SV+b)suDV;D!I?fWHlcq%HyH#zQ3Kv8Vw<3=@G-oF8e+yd&zEO9ebA06fu?< zgH#OJ5}A=EyDUWvhL9~}%QmuQNIggt-#h(&|9$?rpSzrU&bjw>-|zFrj3x=d1`I%{ zX#)(vi9Z+U7&|-3v8j&7=9-nKeBV?Xeb4XV(hbec;Ht{D;~lZ-f`Zr1m1kE?-Gx}P z-G`MS_>011B^a#p$+#?L`_i-Z+6!dD*?p=l-q!exuO<$xPK1wLfs?VwTk(V_azbCv zkS(x3Z@0uYPa^V^Y@TamVK2Y^%)**&SGgfr-8GEID?1p@%jAYAD0672C^DIQ;rk|} zlJW^mBc^=djs30+pv<0^hhNSi4@&C4Sg%hSJc?y_F~_Zq*>ois&?a#e?I$>y4Zi{F-ZH2%PRGLon981{HkH;1`ra)c*pNRx3?fg8 zW+^9y#C9J?W8pK9ki1lTAP-{s@z zRF6%2Dv@M}4MzC;oiI^>!eA>7%VB8OE>{j)q zVN-K6M%uWKlKy+e*GIx#^CbkWqc>L1=iI0$!#oG^l&c6~pQ+_daszIW5PF8FDp1K0 zLB^TF zb&a#G)KvADcnpWE{?(vz1bqL%#&K2r-N{7E0P(jP^7 z^fGB39$qq4Bu?!?EZbd)7&771hHGvik=BbheEX7CO;D>2M8w2CN-$xh8Iym=loLI% z(l8IP=RNH+1i8wQO^^VNGu1L-2_w#=q+jvFs#jU?a>`t2os+d+@l3Odad-m;qkM%=9u$+a=2b!1GRv^|}|)xAU& z76c#b4i$<*V6P+ex2LW(=VhOh9lj1$Spt21^0cMo)QH~8!Y2Y1e*_)jfxC9R@^9z9 z%;-s5@~+Al`#qNd35q}gEzP`@0%c#Hth={-){YETD;oFfR!f*_QA%Y&$)~EmdDrf=Ffj^Cvx%y-DL4s07JGadA2Hdf4vwK& zIoe8N-G6KngFd?`4%gRTOtrl0 zMh|}NhmXbM%QfeGrI6K8A?TP{%Rs|RedFP>YP;;ZvG{@Do5XhnnxEUIFuc0D*|o>X z?UrRa<~t>v`6x`a0&6{)FX4I=41!Fi_itBDEe>+>RPj7a-K}Rl;7FLp9eu{}o7`>2 zq^{h($Ez3PSR+qq=_~rXA3Cm+2^lEOfxUOSBI{oxJx(fsqs6&xVd!1Jt-~PPjY?p^ z6-@D%WuSA_O{vIgs%_;82pB}rl9%jle&B5@n2z8Xh)6QXXm5-M3@B;-cSECUTn469 z$u3>&`Lsbqu1bT%a2Y>2H*W^Rbp#)VPj)haNXdzSs{LfY(EQS2r2eYG^N#X|P~e&a zyvYoKqlRLWs>40{(Sk+muTK)Ig)HHUe7@5Jw)>Erp2HBOa~9b?u|m*j?nONrs3rN%SPtle^UDqO6PMm;|7oN#VBzUOI2>8F|-S~bQDMU_Nbep)lrJ%1@R zhd?PinB#NAcU&rW!0-nr;C+ecxizr|OT%-M7&?)OaRGbg5;@qwD`aWmlOW|Mw2hw+ zkU~0YQWsA=DGyNHjK^mg#(v8)?cs<5rUV-?aR?Gg7|$z)6NX~=936%15olt&0%eRc>D%y~XA|k{Wl9t8J+CRvIycmuHF~UI zMTE)F_~Jda-D{r$(BV4Ep91>wOg%~q9c}=pNs@A~G?1A6BL@(t-@2l4`X?p|kF59o z_+dRzP!W!q7d7r{wQ0Riw~h{HN-7?JS$Hot2Dw0ism-);d9#~cOkT(>U0W}ne?o8+xls7v9@ke^Sl~^f&FSr&8+TddwpMG7GK+Z52(n(+H|y9y{{K5 zQheXdW$x_9+CrdMqa6dKKIv(9ZIS-o=02P|id3~2>S+n^VV88Dzuz9MYhr;obLQ4t zfaxnM5r93!nbyOxc(a_*u`%tHD?$VbA=QvDhS4&(nN6-BI5<<>Blmt5GwnavK8snz zkSh!Q4x;Lq;#XW3uNbm{QZbGdd5^(@W-q06A)!o?O4SH{)DP5%0+O3#ig9$CU$Oh0 z!Jm^iIq<{T*#r+L?nO%3XyCU^a;Rj?#{0B!tTF81Ze8H?iJsN@wuZDN^~r{TKAaT2 
zM;rXLp9fOg1jy55T593|l>KpiY_0^~UN2t6%eh(o@S5ANj4o&WTczLzq;yB#PS*Lo zsN!Z<7Q-mt*x2;WF7l3d1O!~^(p%HmIu!&(s6M}b@B^MsSWg+^0QFKKZLGcLz^>IRo-d>%YpmFEhq# zVF{Ph?~A)F9jFJ`@f1V$89}{wW-pVfYDCLs&?-fJ(A?3UPJ?5Ob((N^xwBJoP<;tN zpi}QRs@hDW4h#NRj{T7R%)q%>&o*Wnuga(OLP0Jl6nMeIp^^o4Dt7myyV}T!^74}F zD$8V;w0yiye#F|wz;K<`fVGO1n6HPUhh9w>A{nrmbVC-CY1~??p)Db^etJHE>eK3x zni&#_VmbYMY)p5K59N_=3nl=V2F#P%Mqg;~vNaTroD|DZsAk}ssZt0>q$FH&#w#bc zXPiT*T}t-DTs9xoAJE74x!bguxY$4HfAjB2X1){#g5QY;D4vvl&1^s_(+w zX3WnFC~&f{9IR2PE2Bk63>WA2Zvoc1uitlMspqc-&XWqng>-!GxFDK*6o(ss;4b(j zOan;o*i~w$q+U1>tlW5PD(U0&impMp{UBGn#9a^|vm=06W~O@#2OIyB2o zbU-D_oXJp-lM4>Yk7<%J(L2wJn3=xshcrJO=b~WEM5#!YidTF9J`@%XMdSd<7&O4c z3&J(wMn)$awuH(?veK`$|J)8q1N^f=-9rS8)RJOcV16E9*-4S3^}saL=cF`hk^}0E zl0YK!0jM6(GFWJN-G3xQ6_$ny@dXfUWq zi6fC>O0$+y z;hyCFVUAB!cmVP|t0MxD_1GaBE)7AF;>j?_FbJq?HFI`mP3cm&cTWemP1WHDRw-`c z-#6B{boC8e&E$zB_u*NcB6w;x4Fhhg$|0@M0ZAo?Kuol*eSwTWfFjOktlm`Q%mj-t zGCBt+Jt`Mp4}YAt1VjX!>F`9O&!v7*t9Aq(9$L)&FIr#2L`C>7!J;}y)D+5o|MQC% zPQ{+-DtT9~h4JJQ;(`vb-XA1?MLz|c^(FP?210j9iid~*n(;G-i^IF`R11)eiEHzx z&-;-Pm(vL{p(4qFYH#Nm?3uCS$BVzJ_r99#f4lzBZL^t6j9*E$`PF6nTQJ^<)P3Z(Nq^*J znhbta@%&F?0ry=^{fP=S3N3Hd`&Vr(@ExPe#gR*_&KD_q=-~gCU<-_4`aMbtL`mPj zS*f3z4G?Chqci&-NjIZk;F>Gxj#f?_(?r<1DO&9}MEgX|30VuO?Qnf}zE~m{hc33; zz>?8Kd1|_>jkF<9VS@q|DX!*MXKspAij7t$=S?v7(CcZZ-2=9p>WpX)TWIdxa+7)4 zAIpr=+U|f*&49t~v~d51x)#{3ghZ;?ZYrE*wCs?zsPn<@0qpK&zToz4l->9@f=9B= zY}=|{92^x#$S0uY0aHkG^1gk)zxcMiP4m*auX;9#Lj}PZ&1U6THyR=W09TEsc1Bc6 zKL2f^;4gtEwH{r+^R0>#Tfj4R7sM145saS%;>@x+2; QOVgXZ0=G1-G<1vqA1xx5L;wH) literal 0 HcmV?d00001