diff --git a/README.md b/README.md
index 6b0daeab43..d102a2dfc3 100644
--- a/README.md
+++ b/README.md
@@ -1,21 +1,36 @@
-

-Stable Diffusion Dream Script
-
+# Stable Diffusion Dream Script
-
-[old HTML badge markup, lost in extraction: release / stars / forks, CI status on main / dev, last-dev-commit, open-issues / open-prs]
-
+![project logo](docs/assets/logo.png)
+
+[![discord badge]][discord link]
+
+[![latest release badge]][latest release link] [![github stars badge]][github stars link] [![github forks badge]][github forks link]
+
+[![CI checks on main badge]][CI checks on main link] [![CI checks on dev badge]][CI checks on dev link] [![latest commit to dev badge]][latest commit to dev link]
+
+[![github open issues badge]][github open issues link] [![github open prs badge]][github open prs link]
+
+[CI checks on dev badge]: https://flat.badgen.net/github/checks/lstein/stable-diffusion/development?label=CI%20status%20on%20dev&cache=900&icon=github
+[CI checks on dev link]: https://github.com/lstein/stable-diffusion/actions?query=branch%3Adevelopment
+[CI checks on main badge]: https://flat.badgen.net/github/checks/lstein/stable-diffusion/main?label=CI%20status%20on%20main&cache=900&icon=github
+[CI checks on main link]: https://github.com/lstein/stable-diffusion/actions/workflows/test-dream-conda.yml
+[discord badge]: https://flat.badgen.net/discord/members/htRgbc7e?icon=discord
+[discord link]: https://discord.com/invite/htRgbc7e
+[github forks badge]: https://flat.badgen.net/github/forks/lstein/stable-diffusion?icon=github
+[github forks link]: https://useful-forks.github.io/?repo=lstein%2Fstable-diffusion
+[github open issues badge]: https://flat.badgen.net/github/open-issues/lstein/stable-diffusion?icon=github
+[github open issues link]: https://github.com/lstein/stable-diffusion/issues?q=is%3Aissue+is%3Aopen
+[github open prs badge]: https://flat.badgen.net/github/open-prs/lstein/stable-diffusion?icon=github
+[github open prs link]: https://github.com/lstein/stable-diffusion/pulls?q=is%3Apr+is%3Aopen
+[github stars badge]: https://flat.badgen.net/github/stars/lstein/stable-diffusion?icon=github
+[github stars link]: https://github.com/lstein/stable-diffusion/stargazers
+[latest commit to dev badge]: https://flat.badgen.net/github/last-commit/lstein/stable-diffusion/development?icon=github&color=yellow&label=last%20dev%20commit&cache=900
+[latest commit to dev link]: https://github.com/lstein/stable-diffusion/commits/development
+[latest release badge]: https://flat.badgen.net/github/release/lstein/stable-diffusion/development?icon=github
+[latest release link]: https://github.com/lstein/stable-diffusion/releases
+
This is a fork of [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion), the open
source text-to-image generator. It provides a streamlined process with various new features and
@@ -26,7 +41,7 @@
_Note: This fork is rapidly evolving. Please use the
[Issues](https://github.com/lstein/stable-diffusion/issues) tab to report bugs and make feature
requests. Be sure to use the provided templates. They will help us diagnose issues faster._

-**Table of Contents**
+## Table of Contents

1. [Installation](#installation)
2. [Hardware Requirements](#hardware-requirements)
@@ -38,38 +53,38 @@ requests. Be sure to use the provided templates. They will help us diagnose iss
8. [Support](#support)
9. [Further Reading](#further-reading)

-## Installation
+### Installation

This fork is supported across multiple platforms. You can find individual installation instructions
below.

-- ### [Linux](docs/installation/INSTALL_LINUX.md)
+- #### [Linux](docs/installation/INSTALL_LINUX.md)

-- ### [Windows](docs/installation/INSTALL_WINDOWS.md)
+- #### [Windows](docs/installation/INSTALL_WINDOWS.md)

-- ### [Macintosh](docs/installation/INSTALL_MAC.md)
+- #### [Macintosh](docs/installation/INSTALL_MAC.md)

-## Hardware Requirements
+### Hardware Requirements

-**System**
+#### System

You will need one of the following:

- An NVIDIA-based graphics card with 4 GB or more of VRAM.
- An Apple computer with an M1 chip.

-**Memory**
+#### Memory

- At least 12 GB of main memory (RAM).

-**Disk**
+#### Disk

- At least 6 GB of free disk space for the machine learning model, Python, and all its dependencies.

-**Note**
-
-If you are have a Nvidia 10xx series card (e.g. the 1080ti), please run the dream script in
-full-precision mode as shown below.
+> Note
+>
+> If you have an Nvidia 10xx series card (e.g. the 1080ti), please run the dream script in
+> full-precision mode as shown below. Similarly, specify full-precision mode on Apple M1 hardware.
@@ -79,45 +94,31 @@ To run in full-precision mode, start `dream.py` with the `--full_precision` flag
(ldm) ~/stable-diffusion$ python scripts/dream.py --full_precision
```

-## Features
+### Features

-### Major Features
+#### Major Features

-- #### [Interactive Command Line Interface](docs/features/CLI.md)
+- [Interactive Command Line Interface](docs/features/CLI.md)
+- [Image To Image](docs/features/IMG2IMG.md)
+- [Inpainting Support](docs/features/INPAINTING.md)
+- [GFPGAN and Real-ESRGAN Support](docs/features/UPSCALE.md)
+- [Seamless Tiling](docs/features/OTHER.md#seamless-tiling)
+- [Google Colab](docs/features/OTHER.md#google-colab)
+- [Web Server](docs/features/WEB.md)
+- [Reading Prompts From File](docs/features/OTHER.md#reading-prompts-from-a-file)
+- [Shortcut: Reusing Seeds](docs/features/OTHER.md#shortcuts-reusing-seeds)
+- [Weighted Prompts](docs/features/OTHER.md#weighted-prompts)
+- [Thresholding and Perlin Noise Initialization Options](docs/features/OTHER.md#thresholding-and-perlin-noise-initialization-options)
+- [Variations](docs/features/VARIATIONS.md)
+- [Personalizing Text-to-Image Generation](docs/features/TEXTUAL_INVERSION.md)
+- [Simplified API for text to image generation](docs/features/OTHER.md#simplified-api)

-- #### [Image To Image](docs/features/IMG2IMG.md)
+#### Other Features

-- #### [Inpainting Support](docs/features/INPAINTING.md)
+- [Creating Transparent Regions for Inpainting](docs/features/INPAINTING.md#creating-transparent-regions-for-inpainting)
+- [Preload Models](docs/features/OTHER.md#preload-models)

-- #### [GFPGAN and Real-ESRGAN Support](docs/features/UPSCALE.md)
-
-- #### [Seamless Tiling](docs/features/OTHER.md#seamless-tiling)
-
-- #### [Google Colab](docs/features/OTHER.md#google-colab)
-
-- #### [Web Server](docs/features/WEB.md)
-
-- #### [Reading Prompts From File](docs/features/OTHER.md#reading-prompts-from-a-file)
-
-- #### [Shortcut: Reusing Seeds](docs/features/OTHER.md#shortcuts-reusing-seeds)
-
-- #### [Weighted Prompts](docs/features/OTHER.md#weighted-prompts)
-
-- #### [Thresholding and Perlin Noise Initialization Options](/docs/features/OTHER.md#thresholding-and-perlin-noise-initialization-options)
-
-- #### [Variations](docs/features/VARIATIONS.md)
-
-- #### [Personalizing Text-to-Image Generation](docs/features/TEXTUAL_INVERSION.md)
-
-- #### [Simplified API for text to image generation](docs/features/OTHER.md#simplified-api)
-
-### Other Features
-
-- #### [Creating Transparent Regions for Inpainting](docs/features/INPAINTING.md#creating-transparent-regions-for-inpainting)
-
-- #### [Preload Models](docs/features/OTHER.md#preload-models)
-
-## Latest Changes
+### Latest Changes

- v1.14 (11 September 2022)

@@ -149,12 +150,12 @@ To run in full-precision mode, start `dream.py` with the `--full_precision` flag
For older changelogs, please visit the **[CHANGELOG](docs/features/CHANGELOG.md)**.

-## Troubleshooting
+### Troubleshooting

Please check out our **[Q&A](docs/help/TROUBLESHOOT.md)** to get solutions for common installation
problems and other issues.

-## Contributing
+### Contributing

Anyone who wishes to contribute to this project, whether documentation, features, bug fixes, code
cleanup, testing, or code reviews, is very much encouraged to do so. If you are unfamiliar with how
@@ -166,13 +167,13 @@ important thing is to **make your pull request against the "development" branch*
"main". This will help keep public breakage to a minimum and will allow you to propose more radical
changes.
-## Contributors
+### Contributors

This fork is a combined effort of various people from across the world.
[Check out the list of all these amazing people](docs/other/CONTRIBUTORS.md). We thank them for
their time, hard work and effort.

-## Support
+### Support

For support, please use this repository's GitHub Issues tracking service. Feel free to send me an
email if you use and like the script.
@@ -180,7 +181,7 @@ email if you use and like the script.
Original portions of the software are Copyright (c) 2020
[Lincoln D. Stein](https://github.com/lstein)

-## Further Reading
+### Further Reading

Please see the original README for more information on this software and underlying algorithm,
located in the file [README-CompViz.md](docs/other/README-CompViz.md).
diff --git a/ldm/dream/args.py b/ldm/dream/args.py
index 0b6cfda4cc..fe7ca9ffe2 100644
--- a/ldm/dream/args.py
+++ b/ldm/dream/args.py
@@ -2,7 +2,10 @@
The Args class parses both the command line (shell) arguments, as well as the
command string passed at the dream> prompt. It serves as the definitive repository
-of all the arguments used by Generate and their default values.
+of all the arguments used by Generate and their default values, and implements the
+preliminary metadata standards discussed here:
+
+https://github.com/lstein/stable-diffusion/issues/266

To use:
  opt = Args()
@@ -52,10 +55,32 @@ you wish to apply logic as to which one to use. For example:
To add new attributes, edit the _create_arg_parser() and
_create_dream_cmd_parser() methods.

-We also export the function build_metadata
+**Generating and retrieving sd-metadata**
+
+To generate a dict representing RFC266 metadata:
+
+    metadata = metadata_dumps(opt)
+
+This will generate an RFC266 dictionary that can then be turned into JSON
+and written to the PNG file. The optional seeds, model_hash and
+postprocessing arguments are not available on the opt object and so must be
+provided externally. See how dream.py does it.
+
+Note that this function was originally called format_metadata() and a wrapper
+is provided that issues a deprecation notice.
+
+To retrieve a (series of) opt objects corresponding to the metadata, do this:
+
+    opt_list = metadata_loads(metadata)
+
+The metadata should be pulled out of the PNG image. pngwriter has a method
+retrieve_metadata that will do this.
+
"""

import argparse
+from argparse import Namespace
import shlex
import json
import hashlib
@@ -540,17 +565,20 @@ class Args(object):
        )
        return parser

-# very partial implementation of https://github.com/lstein/stable-diffusion/issues/266
-# it does not write all the required top-level metadata, writes too much image
-# data, and doesn't support grids yet. But you gotta start somewhere, no?
-def format_metadata(opt,
-                    seeds=[],
-                    weights=None,
-                    model_hash=None,
-                    postprocessing=None):
+def format_metadata(**kwargs):
+    print('format_metadata() is deprecated. Please use metadata_dumps()')
+    return metadata_dumps(**kwargs)
+
+def metadata_dumps(opt,
+                   seeds=[],
+                   model_hash=None,
+                   postprocessing=None):
    '''
-    Given an Args object, returns a partial implementation of
-    the stable diffusion metadata standard
+    Given an Args object, returns a dict containing the keys and
+    structure of the proposed stable diffusion metadata standard
+    https://github.com/lstein/stable-diffusion/discussions/392
+    This is intended to be turned into JSON and stored in the
+    "sd-metadata" field of the PNG file.
    '''
    # add some RFC266 fields that are generated internally, and not as
    # user args
@@ -598,6 +626,9 @@
    rfc_dict['type'] = 'txt2img'

    images = []
+    if len(seeds) == 0 and opt.seed:
+        seeds = [opt.seed]
+
    for seed in seeds:
        rfc_dict['seed'] = seed
        images.append(copy.copy(rfc_dict))
@@ -611,6 +642,27 @@
        'images' : images,
    }

+def metadata_loads(metadata):
+    '''
+    Takes the dictionary corresponding to RFC266 (https://github.com/lstein/stable-diffusion/issues/266)
+    and returns a series of opt objects for each of the images described in the dictionary.
+    '''
+    results = []
+    try:
+        images = metadata['sd-metadata']['images']
+        for image in images:
+            # repack the prompt and variations
+            image['prompt'] = ','.join([':'.join([x['prompt'], str(x['weight'])]) for x in image['prompt']])
+            image['variations'] = ','.join([':'.join([str(x['seed']), str(x['weight'])]) for x in image['variations']])
+            opt = Args()
+            opt._cmd_switches = Namespace(**image)
+            results.append(opt)
+    except KeyError:
+        import sys, traceback
+        print('>> badly-formatted metadata', file=sys.stderr)
+        print(traceback.format_exc(), file=sys.stderr)
+    return results
+
# image can either be a file path on disk or a base64-encoded
# representation of the file's contents
def calculate_init_img_hash(image_string):
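(Reviewer note, not part of the patch.) Here is a minimal sketch of the round trip the new docstring describes. It assumes `Args.parse_cmd()` accepts a dream>-style command string, that values stored in `_cmd_switches` are readable back as plain attributes, and it wraps the dump under the `sd-metadata` key because that is where `metadata_loads()` looks for it; the prompt, switches, and model hash below are placeholders.

```python
import json
from ldm.dream.args import Args, metadata_dumps, metadata_loads

# Parse a dream>-style command into an opt object (hypothetical prompt/switches).
opt = Args()
opt.parse_cmd('"a sunlit meadow" -W512 -H512 -S42')

# seeds and model_hash are not carried on opt, so they are passed in
# explicitly, just as dream.py and server.py do below.
sd_metadata = metadata_dumps(opt, seeds=[42], model_hash='0123abcd')  # placeholder hash
print(json.dumps(sd_metadata, indent=2))  # what ends up inside the PNG

# metadata_loads() expects the dict as it is stored in the PNG, i.e. keyed
# under 'sd-metadata', so wrap the dump before feeding it back.
for loaded_opt in metadata_loads({'sd-metadata': sd_metadata}):
    print(loaded_opt.seed)  # attribute access resolved through _cmd_switches
```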
diff --git a/ldm/dream/server.py b/ldm/dream/server.py
index 5a573c8aa1..7010829c22 100644
--- a/ldm/dream/server.py
+++ b/ldm/dream/server.py
@@ -4,7 +4,7 @@ import copy
import base64
import mimetypes
import os
-from ldm.dream.args import Args, format_metadata
+from ldm.dream.args import Args, metadata_dumps
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from ldm.dream.pngwriter import PngWriter
from threading import Event
@@ -177,10 +177,9 @@ class DreamServer(BaseHTTPRequestHandler):
                path = pngwriter.save_image_and_prompt_to_png(
                    image,
                    dream_prompt = formatted_prompt,
-                    metadata = format_metadata(iter_opt,
-                                               seeds      = [seed],
-                                               weights    = self.model.weights,
-                                               model_hash = self.model.model_hash
+                    metadata = metadata_dumps(iter_opt,
+                                              seeds      = [seed],
+                                              model_hash = self.model.model_hash
                    ),
                    name = name,
                )
diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
index ec96230b46..ef9c2d3e65 100644
--- a/ldm/modules/attention.py
+++ b/ldm/modules/attention.py
@@ -90,7 +90,7 @@ class LinearAttention(nn.Module):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
-        k = k.softmax(dim=-1)
+        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
@@ -167,101 +167,85 @@ class CrossAttention(nn.Module):
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )
-
-        if torch.cuda.is_available():
-            self.einsum_op = self.einsum_op_cuda
-        else:
-            self.mem_total = psutil.virtual_memory().total / (1024**3)
-            self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
-    def einsum_op_compvis(self, q, k, v, r1):
-        s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # faster
-        s2 = s1.softmax(dim=-1, dtype=q.dtype)
-        del s1
-        r1 = einsum('b i j, b j d -> b i d', s2, v)
-        del s2
-        return r1
+        self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)

-    def einsum_op_mps_v1(self, q, k, v, r1):
+    def einsum_op_compvis(self, q, k, v):
+        s = einsum('b i d, b j d -> b i j', q, k)
+        s = s.softmax(dim=-1, dtype=s.dtype)
+        return einsum('b i j, b j d -> b i d', s, v)
+
+    def einsum_op_slice_0(self, q, k, v, slice_size):
+        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        for i in range(0, q.shape[0], slice_size):
+            end = i + slice_size
+            r[i:end] = self.einsum_op_compvis(q[i:end], k[i:end], v[i:end])
+        return r
+
+    def einsum_op_slice_1(self, q, k, v, slice_size):
+        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        for i in range(0, q.shape[1], slice_size):
+            end = i + slice_size
+            r[:, i:end] = self.einsum_op_compvis(q[:, i:end], k, v)
+        return r
+
+    def einsum_op_mps_v1(self, q, k, v):
        if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
-            r1 = self.einsum_op_compvis(q, k, v, r1)
+            return self.einsum_op_compvis(q, k, v)
        else:
            slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
-            for i in range(0, q.shape[1], slice_size):
-                end = i + slice_size
-                s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
-                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
-                del s1
-                r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
-                del s2
-        return r1
+            return self.einsum_op_slice_1(q, k, v, slice_size)

-    def einsum_op_mps_v2(self, q, k, v, r1):
-        if self.mem_total >= 8 and q.shape[1] <= 4096:
-            r1 = self.einsum_op_compvis(q, k, v, r1)
+    def einsum_op_mps_v2(self, q, k, v):
+        if self.mem_total_gb > 8 and q.shape[1] <= 4096:
+            return self.einsum_op_compvis(q, k, v)
        else:
-            slice_size = 1
-            for i in range(0, q.shape[0], slice_size):
-                end = min(q.shape[0], i + slice_size)
-                s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
-                s1 *= self.scale
-                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
-                del s1
-                r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
-                del s2
-        return r1
-
-    def einsum_op_cuda(self, q, k, v, r1):
+            return self.einsum_op_slice_0(q, k, v, 1)
+
+    def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
+        size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
+        if size_mb <= max_tensor_mb:
+            return self.einsum_op_compvis(q, k, v)
+        div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
+        if div <= q.shape[0]:
+            return self.einsum_op_slice_0(q, k, v, q.shape[0] // div)
+        return self.einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
+
+    def einsum_op_cuda(self, q, k, v):
        stats = torch.cuda.memory_stats(q.device)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+        mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
        mem_free_torch = mem_reserved - mem_active
        mem_free_total = mem_free_cuda + mem_free_torch
+        # Divide by a safety factor, since there is copying and fragmentation.
+        return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))

-        gb = 1024 ** 3
-        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
-        mem_required = tensor_size * 2.5
-        steps = 1
+    def einsum_op(self, q, k, v):
+        if q.device.type == 'cuda':
+            return self.einsum_op_cuda(q, k, v)

-        if mem_required > mem_free_total:
-            steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
+        if q.device.type == 'mps':
+            if self.mem_total_gb >= 32:
+                return self.einsum_op_mps_v1(q, k, v)
+            return self.einsum_op_mps_v2(q, k, v)

-        if steps > 64:
-            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
-            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
-                               f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
-
-        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
-        for i in range(0, q.shape[1], slice_size):
-            end = min(q.shape[1], i + slice_size)
-            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
-            s2 = s1.softmax(dim=-1, dtype=r1.dtype)
-            del s1
-            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
-            del s2
-        return r1
+        # Smaller slices are faster due to L2/L3/SLC caches.
+        # Tested on i7 with 8MB L3 cache.
+        return self.einsum_op_tensor_mem(q, k, v, 32)

    def forward(self, x, context=None, mask=None):
        h = self.heads

-        q_in = self.to_q(x)
+        q = self.to_q(x)
        context = default(context, x)
-        k_in = self.to_k(context)
-        v_in = self.to_v(context)
-        device_type = 'mps' if x.device.type == 'mps' else 'cuda'
+        k = self.to_k(context) * self.scale
+        v = self.to_v(context)
        del context, x

-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
-        del q_in, k_in, v_in
-        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-        r1 = self.einsum_op(q, k, v, r1)
-        del q, k, v
-
-        r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
-        del r1
-
-        return self.to_out(r2)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+        r = self.einsum_op(q, k, v)
+        return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))

class BasicTransformerBlock(nn.Module):
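(Reviewer note, not part of the patch.) A standalone PyTorch check of the idea behind `einsum_op_slice_1`: softmax is taken independently per query row, so computing attention for a slice of query rows against all of k and v reproduces the corresponding rows of the full result exactly. Slicing therefore bounds peak memory without changing the output.

```python
import torch

def attn_full(q, k, v):
    # same shape convention as einsum_op_compvis: (batch*heads, tokens, dim)
    s = torch.einsum('b i d, b j d -> b i j', q, k).softmax(dim=-1)
    return torch.einsum('b i j, b j d -> b i d', s, v)

def attn_sliced(q, k, v, slice_size):
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], dtype=q.dtype)
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        # each query slice attends to all of k and v, so its softmax rows are complete
        r[:, i:end] = attn_full(q[:, i:end], k, v)
    return r

q, k, v = (torch.randn(2, 64, 8) for _ in range(3))
assert torch.allclose(attn_full(q, k, v), attn_sliced(q, k, v, slice_size=16), atol=1e-6)
```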
diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py
index a3598c40ef..78876a0919 100644
--- a/ldm/modules/diffusionmodules/model.py
+++ b/ldm/modules/diffusionmodules/model.py
@@ -3,6 +3,7 @@
import gc
import math
import torch
import torch.nn as nn
+from torch.nn.functional import silu
import numpy as np
from einops import rearrange
@@ -32,11 +33,6 @@ def get_timestep_embedding(timesteps, embedding_dim):
    return emb


-def nonlinearity(x):
-    # swish
-    return x*torch.sigmoid(x)
-
-
def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
@@ -122,14 +118,14 @@ class ResnetBlock(nn.Module):
    def forward(self, x, temb):
        h = self.norm1(x)
-        h = nonlinearity(h)
+        h = silu(h)
        h = self.conv1(h)

        if temb is not None:
-            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+            h = h + self.temb_proj(silu(temb))[:,:,None,None]

        h = self.norm2(h)
-        h = nonlinearity(h)
+        h = silu(h)
        h = self.dropout(h)
        h = self.conv2(h)
@@ -368,7 +364,7 @@ class Model(nn.Module):
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
-            temb = nonlinearity(temb)
+            temb = silu(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None
@@ -402,7 +398,7 @@ class Model(nn.Module):
        # end
        h = self.norm_out(h)
-        h = nonlinearity(h)
+        h = silu(h)
        h = self.conv_out(h)
        return h
@@ -499,7 +495,7 @@ class Encoder(nn.Module):
        # end
        h = self.norm_out(h)
-        h = nonlinearity(h)
+        h = silu(h)
        h = self.conv_out(h)
        return h
@@ -611,7 +607,7 @@ class Decoder(nn.Module):
            return h
        h = self.norm_out(h)
-        h = nonlinearity(h)
+        h = silu(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
@@ -649,7 +645,7 @@ class SimpleDecoder(nn.Module):
                x = layer(x)

        h = self.norm_out(x)
-        h = nonlinearity(h)
+        h = silu(h)
        x = self.conv_out(h)
        return x
@@ -697,7 +693,7 @@ class UpsampleDecoder(nn.Module):
            if i_level != self.num_resolutions - 1:
                h = self.upsample_blocks[k](h)
        h = self.norm_out(h)
-        h = nonlinearity(h)
+        h = silu(h)
        h = self.conv_out(h)
        return h
@@ -873,7 +869,7 @@ class FirstStagePostProcessor(nn.Module):
        z_fs = self.encode_with_pretrained(x)
        z = self.proj_norm(z_fs)
        z = self.proj(z)
-        z = nonlinearity(z)
+        z = silu(z)

        for submodel, downmodel in zip(self.model,self.downsampler):
            z = submodel(z,temb=None)
diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py
index 2cb56a14a0..60b4d8a028 100644
--- a/ldm/modules/diffusionmodules/util.py
+++ b/ldm/modules/diffusionmodules/util.py
@@ -252,12 +252,6 @@ def normalization(channels):
    return GroupNorm32(32, channels)


-# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
-class SiLU(nn.Module):
-    def forward(self, x):
-        return x * torch.sigmoid(x)
-
-
class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)
diff --git a/scripts/dream.py b/scripts/dream.py
index 20a2d87e65..289a89e8ad 100755
--- a/scripts/dream.py
+++ b/scripts/dream.py
@@ -8,7 +8,7 @@ import copy
import warnings
import time
import ldm.dream.readline
-from ldm.dream.args import Args, format_metadata
+from ldm.dream.args import Args, metadata_dumps
from ldm.dream.pngwriter import PngWriter
from ldm.dream.server import DreamServer, ThreadingDreamServer
from ldm.dream.image_util import make_grid
@@ -245,10 +245,9 @@ def main_loop(gen, opt, infile):
                path = file_writer.save_image_and_prompt_to_png(
                    image        = image,
                    dream_prompt = formatted_dream_prompt,
-                    metadata     = format_metadata(
+                    metadata     = metadata_dumps(
                        opt,
                        seeds      = [seed],
-                        weights    = gen.weights,
                        model_hash = gen.model_hash,
                    ),
                    name = filename,
@@ -272,7 +271,7 @@ def main_loop(gen, opt, infile):
            filename = f'{prefix}.{first_seed}.png'
            formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed,grid=True,iterations=len(grid_images))
            formatted_dream_prompt += f' # {grid_seeds}'
-            metadata = format_metadata(
+            metadata = metadata_dumps(
                opt,
                seeds = grid_seeds,
                weights = gen.weights,
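(Reviewer note, not part of the patch.) The `nonlinearity()` and `SiLU` removals in model.py and util.py above are pure refactors: swish with unit beta is by definition x * sigmoid(x), which is exactly `torch.nn.functional.silu` (available since PyTorch 1.7, hence dropping the 1.5 shim). A quick check:

```python
import torch
from torch.nn.functional import silu

x = torch.randn(4, 8)
assert torch.allclose(silu(x), x * torch.sigmoid(x))  # identical by definition
```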