diff --git a/README.md b/README.md index 8f88a5c9d2..eec1fa34ca 100644 --- a/README.md +++ b/README.md @@ -1,21 +1,36 @@ -

Stable Diffusion Dream Script

+
-

- -

+# Stable Diffusion Dream Script -

- release - stars - forks -
- CI status on main - CI status on dev - last-dev-commit -
- open-issues - open-prs -

+![project logo](docs/assets/logo.png) + +[![discord badge]][discord link] + +[![latest release badge]][latest release link] [![github stars badge]][github stars link] [![github forks badge]][github forks link] + +[![CI checks on main badge]][CI checks on main link] [![CI checks on dev badge]][CI checks on dev link] [![latest commit to dev badge]][latest commit to dev link] + +[![github open issues badge]][github open issues link] [![github open prs badge]][github open prs link] + +[CI checks on dev badge]: https://flat.badgen.net/github/checks/lstein/stable-diffusion/development?label=CI%20status%20on%20dev&cache=900&icon=github +[CI checks on dev link]: https://github.com/lstein/stable-diffusion/actions?query=branch%3Adevelopment +[CI checks on main badge]: https://flat.badgen.net/github/checks/lstein/stable-diffusion/main?label=CI%20status%20on%20main&cache=900&icon=github +[CI checks on main link]: https://github.com/lstein/stable-diffusion/actions/workflows/test-dream-conda.yml +[discord badge]: https://flat.badgen.net/discord/members/htRgbc7e?icon=discord +[discord link]: https://discord.com/invite/htRgbc7e +[github forks badge]: https://flat.badgen.net/github/forks/lstein/stable-diffusion?icon=github +[github forks link]: https://useful-forks.github.io/?repo=lstein%2Fstable-diffusion +[github open issues badge]: https://flat.badgen.net/github/open-issues/lstein/stable-diffusion?icon=github +[github open issues link]: https://github.com/lstein/stable-diffusion/issues?q=is%3Aissue+is%3Aopen +[github open prs badge]: https://flat.badgen.net/github/open-prs/lstein/stable-diffusion?icon=github +[github open prs link]: https://github.com/lstein/stable-diffusion/pulls?q=is%3Apr+is%3Aopen +[github stars badge]: https://flat.badgen.net/github/stars/lstein/stable-diffusion?icon=github +[github stars link]: https://github.com/lstein/stable-diffusion/stargazers +[latest commit to dev badge]: https://flat.badgen.net/github/last-commit/lstein/stable-diffusion/development?icon=github&color=yellow&label=last%20dev%20commit&cache=900 +[latest commit to dev link]: https://github.com/lstein/stable-diffusion/commits/development +[latest release badge]: https://flat.badgen.net/github/release/lstein/stable-diffusion/development?icon=github +[latest release link]: https://github.com/lstein/stable-diffusion/releases +
This is a fork of [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion), the open source text-to-image generator. It provides a streamlined process with various new features and @@ -26,7 +41,7 @@ _Note: This fork is rapidly evolving. Please use the [Issues](https://github.com/lstein/stable-diffusion/issues) tab to report bugs and make feature requests. Be sure to use the provided templates. They will help diagnose issues faster._ -**Table of Contents** +## Table of Contents 1. [Installation](#installation) 2. [Hardware Requirements](#hardware-requirements) @@ -38,38 +53,38 @@ requests. Be sure to use the provided templates. They will help aid diagnose iss 8. [Support](#support) 9. [Further Reading](#further-reading) -## Installation +### Installation This fork is supported across multiple platforms. You can find individual installation instructions below. -- ### [Linux](docs/installation/INSTALL_LINUX.md) +- #### [Linux](docs/installation/INSTALL_LINUX.md) -- ### [Windows](docs/installation/INSTALL_WINDOWS.md) +- #### [Windows](docs/installation/INSTALL_WINDOWS.md) -- ### [Macintosh](docs/installation/INSTALL_MAC.md) +- #### [Macintosh](docs/installation/INSTALL_MAC.md) -## Hardware Requirements +### Hardware Requirements -**System** +#### System You will need one of the following: - An NVIDIA-based graphics card with 4 GB or more VRAM memory. - An Apple computer with an M1 chip. -**Memory** +#### Memory - At least 12 GB Main Memory RAM. -**Disk** +#### Disk - At least 6 GB of free disk space for the machine learning model, Python, and all its dependencies. -**Note** - -If you are have a Nvidia 10xx series card (e.g. the 1080ti), please run the dream script in -full-precision mode as shown below. +> Note +> +> If you have an Nvidia 10xx series card (e.g. the 1080ti), please run the dream script in +> full-precision mode as shown below. Similarly, specify full-precision mode on Apple M1 hardware.
@@ -79,43 +94,30 @@ To run in full-precision mode, start `dream.py` with the `--full_precision` flag (ldm) ~/stable-diffusion$ python scripts/dream.py --full_precision ``` -## Features +### Features -### Major Features +#### Major Features -- #### [Interactive Command Line Interface](docs/features/CLI.md) +- [Interactive Command Line Interface](docs/features/CLI.md) +- [Image To Image](docs/features/IMG2IMG.md) +- [Inpainting Support](docs/features/INPAINTING.md) +- [GFPGAN and Real-ESRGAN Support](docs/features/UPSCALE.md) +- [Seamless Tiling](docs/features/OTHER.md#seamless-tiling) +- [Google Colab](docs/features/OTHER.md#google-colab) +- [Web Server](docs/features/WEB.md) +- [Reading Prompts From File](docs/features/OTHER.md#reading-prompts-from-a-file) +- [Shortcut: Reusing Seeds](docs/features/OTHER.md#shortcuts-reusing-seeds) +- [Weighted Prompts](docs/features/OTHER.md#weighted-prompts) +- [Variations](docs/features/VARIATIONS.md) +- [Personalizing Text-to-Image Generation](docs/features/TEXTUAL_INVERSION.md) +- [Simplified API for text to image generation](docs/features/OTHER.md#simplified-api) -- #### [Image To Image](docs/features/IMG2IMG.md) +#### Other Features -- #### [Inpainting Support](docs/features/INPAINTING.md) +- [Creating Transparent Regions for Inpainting](docs/features/INPAINTING.md#creating-transparent-regions-for-inpainting) +- [Preload Models](docs/features/OTHER.md#preload-models) -- #### [GFPGAN and Real-ESRGAN Support](docs/features/UPSCALE.md) - -- #### [Seamless Tiling](docs/features/OTHER.md#seamless-tiling) - -- #### [Google Colab](docs/features/OTHER.md#google-colab) - -- #### [Web Server](docs/features/WEB.md) - -- #### [Reading Prompts From File](docs/features/OTHER.md#reading-prompts-from-a-file) - -- #### [Shortcut: Reusing Seeds](docs/features/OTHER.md#shortcuts-reusing-seeds) - -- #### [Weighted Prompts](docs/features/OTHER.md#weighted-prompts) - -- #### [Variations](docs/features/VARIATIONS.md) - -- #### [Personalizing Text-to-Image Generation](docs/features/TEXTUAL_INVERSION.md) - -- #### [Simplified API for text to image generation](docs/features/OTHER.md#simplified-api) - -### Other Features - -- #### [Creating Transparent Regions for Inpainting](docs/features/INPAINTING.md#creating-transparent-regions-for-inpainting) - -- #### [Preload Models](docs/features/OTHER.md#preload-models) - -## Latest Changes +### Latest Changes - v1.14 (11 September 2022) @@ -147,12 +149,12 @@ To run in full-precision mode, start `dream.py` with the `--full_precision` flag For older changelogs, please visit the **[CHANGELOG](docs/features/CHANGELOG.md)**. -## Troubleshooting +### Troubleshooting Please check out our **[Q&A](docs/help/TROUBLESHOOT.md)** to get solutions for common installation problems and other issues. -## Contributing +### Contributing Anyone who wishes to contribute to this project, whether documentation, features, bug fixes, code cleanup, testing, or code reviews, is very much encouraged to do so. If you are unfamiliar with how @@ -164,13 +166,13 @@ important thing is to **make your pull request against the "development" branch* "main". This will help keep public breakage to a minimum and will allow you to propose more radical changes. -## Contributors +### Contributors This fork is a combined effort of various people from across the world. [Check out the list of all these amazing people](docs/other/CONTRIBUTORS.md). We thank them for their time, hard work and effort. 
-## Support +### Support For support, please use this repository's GitHub Issues tracking service. Feel free to send me an email if you use and like the script. @@ -178,7 +180,7 @@ email if you use and like the script. Original portions of the software are Copyright (c) 2020 [Lincoln D. Stein](https://github.com/lstein) -## Further Reading +### Further Reading Please see the original README for more information on this software and underlying algorithm, located in the file [README-CompViz.md](docs/other/README-CompViz.md). diff --git a/environment.yaml b/environment.yaml index 621079f5ef..eaf4d0e02a 100644 --- a/environment.yaml +++ b/environment.yaml @@ -21,7 +21,7 @@ dependencies: - test-tube>=0.7.5 - streamlit==1.12.0 - send2trash==1.8.0 - - pillow==6.2.0 + - pillow==9.2.0 - einops==0.3.0 - torch-fidelity==0.3.0 - transformers==4.19.2 diff --git a/ldm/dream/args.py b/ldm/dream/args.py index e9b70a7199..fe7ca9ffe2 100644 --- a/ldm/dream/args.py +++ b/ldm/dream/args.py @@ -2,7 +2,10 @@ The Args class parses both the command line (shell) arguments, as well as the command string passed at the dream> prompt. It serves as the definitive repository -of all the arguments used by Generate and their default values. +of all the arguments used by Generate and their default values, and implements the +preliminary metadata standards discussed here: + +https://github.com/lstein/stable-diffusion/issues/266 To use: opt = Args() @@ -52,15 +55,38 @@ you wish to apply logic as to which one to use. For example: To add new attributes, edit the _create_arg_parser() and _create_dream_cmd_parser() methods. -We also export the function build_metadata +**Generating and retrieving sd-metadata** + +To generate a dict representing RFC266 metadata: + + metadata = metadata_dumps(opt,) + +This will generate an RFC266 dictionary that can then be turned into JSON +and written to the PNG file. The optional seeds, model_hash and +postprocessing arguments are not available to the opt object and so must be +provided externally. See how dream.py does it. + +Note that this function was originally called format_metadata() and a wrapper +is provided that issues a deprecation notice. + +To retrieve a (series of) opt objects corresponding to the metadata, do this: + + opt_list = metadata_loads(metadata) + +The metadata should be pulled out of the PNG image. pngwriter has a method +retrieve_metadata that will do this.
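A minimal sketch of this round trip, assuming the `retrieve_metadata` helper lives in `ldm.dream.pngwriter` and hands back the stored dict under its `sd-metadata` key (as `metadata_loads` expects); the prompt, seed, and model hash below are placeholder values rather than anything this changeset defines:

```python
# Hedged sketch of the dump -> write -> read -> load cycle described above.
# retrieve_metadata's location/return shape and all literal values here are
# assumptions; metadata_dumps, metadata_loads, PngWriter and
# save_image_and_prompt_to_png are the names used in this change.
from PIL import Image

from ldm.dream.args import Args, metadata_dumps, metadata_loads
from ldm.dream.pngwriter import PngWriter, retrieve_metadata  # location assumed

opt = Args()                          # normally filled in from the CLI or the dream> prompt
image = Image.new('RGB', (512, 512))  # stand-in for an image produced by Generate

# seeds and model_hash are not attributes of opt, so they are passed explicitly.
sd_metadata = metadata_dumps(opt, seeds=[42], model_hash='0123abcd')  # placeholders

writer = PngWriter('outputs/img-samples')
path = writer.save_image_and_prompt_to_png(
    image=image,
    dream_prompt='"a fantasy landscape" -s50 -W512 -H512 -S42',
    metadata=sd_metadata,
    name=f'{writer.unique_prefix()}.42.png',
)

# Later: pull the metadata back out of the PNG and rebuild one opt per image.
stored = retrieve_metadata(path)      # assumed to return a dict keyed by 'sd-metadata'
for recovered in metadata_loads(stored):
    print(recovered.seed, recovered.prompt)
```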
+ + """ import argparse +from argparse import Namespace import shlex import json import hashlib import os import copy +import base64 from ldm.dream.conditioning import split_weighted_subprompts SAMPLER_CHOICES = [ @@ -142,7 +168,7 @@ class Args(object): a = vars(self) a.update(kwargs) switches = list() - switches.append(f'"{a["prompt"]}') + switches.append(f'"{a["prompt"]}"') switches.append(f'-s {a["steps"]}') switches.append(f'-W {a["width"]}') switches.append(f'-H {a["height"]}') @@ -151,15 +177,13 @@ class Args(object): switches.append(f'-S {a["seed"]}') if a['grid']: switches.append('--grid') - if a['iterations'] and a['iterations']>0: - switches.append(f'-n {a["iterations"]}') if a['seamless']: switches.append('--seamless') if a['init_img'] and len(a['init_img'])>0: switches.append(f'-I {a["init_img"]}') if a['fit']: switches.append(f'--fit') - if a['strength'] and a['strength']>0: + if a['init_img'] and a['strength'] and a['strength']>0: switches.append(f'-f {a["strength"]}') if a['gfpgan_strength']: switches.append(f'-G {a["gfpgan_strength"]}') @@ -541,17 +565,20 @@ class Args(object): ) return parser -# very partial implementation of https://github.com/lstein/stable-diffusion/issues/266 -# it does not write all the required top-level metadata, writes too much image -# data, and doesn't support grids yet. But you gotta start somewhere, no? -def format_metadata(opt, - seeds=[], - weights=None, - model_hash=None, - postprocessing=None): +def format_metadata(**kwargs): + print(f'format_metadata() is deprecated. Please use metadata_dumps()') + return metadata_dumps(kwargs) + +def metadata_dumps(opt, + seeds=[], + model_hash=None, + postprocessing=None): ''' - Given an Args object, returns a partial implementation of - the stable diffusion metadata standard + Given an Args object, returns a dict containing the keys and + structure of the proposed stable diffusion metadata standard + https://github.com/lstein/stable-diffusion/discussions/392 + This is intended to be turned into JSON and stored in the + "sd ''' # add some RFC266 fields that are generated internally, and not as # user args @@ -593,12 +620,15 @@ def format_metadata(opt, if opt.init_img: rfc_dict['type'] = 'img2img' rfc_dict['strength_steps'] = rfc_dict.pop('strength') - rfc_dict['orig_hash'] = sha256(image_dict['init_img']) + rfc_dict['orig_hash'] = calculate_init_img_hash(opt.init_img) rfc_dict['sampler'] = 'ddim' # FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS else: rfc_dict['type'] = 'txt2img' images = [] + if len(seeds)==0 and opt.seed: + seeds=[seed] + for seed in seeds: rfc_dict['seed'] = seed images.append(copy.copy(rfc_dict)) @@ -612,6 +642,44 @@ def format_metadata(opt, 'images' : images, } +def metadata_loads(metadata): + ''' + Takes the dictionary corresponding to RFC266 (https://github.com/lstein/stable-diffusion/issues/266) + and returns a series of opt objects for each of the images described in the dictionary. 
+ ''' + results = [] + try: + images = metadata['sd-metadata']['images'] + for image in images: + # repack the prompt and variations + image['prompt'] = ','.join([':'.join([x['prompt'], str(x['weight'])]) for x in image['prompt']]) + image['variations'] = ','.join([':'.join([str(x['seed']),str(x['weight'])]) for x in image['variations']]) + opt = Args() + opt._cmd_switches = Namespace(**image) + results.append(opt) + except KeyError as e: + import sys, traceback + print('>> badly-formatted metadata',file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + return results + +# image can either be a file path on disk or a base64-encoded +# representation of the file's contents +def calculate_init_img_hash(image_string): + prefix = 'data:image/png;base64,' + hash = None + if image_string.startswith(prefix): + imagebase64 = image_string[len(prefix):] + imagedata = base64.b64decode(imagebase64) + with open('outputs/test.png','wb') as file: + file.write(imagedata) + sha = hashlib.sha256() + sha.update(imagedata) + hash = sha.hexdigest() + else: + hash = sha256(image_string) + return hash + # Bah. This should be moved somewhere else... def sha256(path): sha = hashlib.sha256() diff --git a/ldm/dream/server.py b/ldm/dream/server.py index 372d719052..9e37c070d1 100644 --- a/ldm/dream/server.py +++ b/ldm/dream/server.py @@ -4,7 +4,7 @@ import copy import base64 import mimetypes import os -from ldm.dream.args import Args, format_metadata +from ldm.dream.args import Args, metadata_dumps from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from ldm.dream.pngwriter import PngWriter from threading import Event @@ -76,7 +76,7 @@ class DreamServer(BaseHTTPRequestHandler): self.send_response(200) self.send_header("Content-type", "text/html") self.end_headers() - with open("./static/dream_web/index.html", "rb") as content: + with open("./static/legacy_web/index.html", "rb") as content: self.wfile.write(content.read()) elif self.path == "/config.js": # unfortunately this import can't be at the top level, since that would cause a circular import @@ -94,7 +94,7 @@ class DreamServer(BaseHTTPRequestHandler): self.end_headers() output = [] - log_file = os.path.join(self.outdir, "dream_web_log.txt") + log_file = os.path.join(self.outdir, "legacy_web_log.txt") if os.path.exists(log_file): with open(log_file, "r") as log: for line in log: @@ -114,7 +114,7 @@ class DreamServer(BaseHTTPRequestHandler): else: path_dir = os.path.dirname(self.path) out_dir = os.path.realpath(self.outdir.rstrip('/')) - if self.path.startswith('/static/dream_web/'): + if self.path.startswith('/static/legacy_web/'): path = '.' 
+ self.path elif out_dir.replace('\\', '/').endswith(path_dir): file = os.path.basename(self.path) @@ -145,7 +145,6 @@ class DreamServer(BaseHTTPRequestHandler): opt = build_opt(post_data, self.model.seed, gfpgan_model_exists) self.canceled.clear() - print(f">> Request to generate with prompt: {opt.prompt}") # In order to handle upscaled images, the PngWriter needs to maintain state # across images generated by each call to prompt2img(), so we define it in # the outer scope of image_done() @@ -176,10 +175,9 @@ class DreamServer(BaseHTTPRequestHandler): path = pngwriter.save_image_and_prompt_to_png( image, dream_prompt = formatted_prompt, - metadata = format_metadata(iter_opt, - seeds = [seed], - weights = self.model.weights, - model_hash = self.model.model_hash + metadata = metadata_dumps(iter_opt, + seeds = [seed], + model_hash = self.model.model_hash ), name = name, ) @@ -188,7 +186,7 @@ class DreamServer(BaseHTTPRequestHandler): config['seed'] = seed # Append post_data to log, but only once! if not upscaled: - with open(os.path.join(self.outdir, "dream_web_log.txt"), "a") as log: + with open(os.path.join(self.outdir, "legacy_web_log.txt"), "a") as log: log.write(f"{path}: {json.dumps(config)}\n") self.wfile.write(bytes(json.dumps( diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index ec96230b46..ef9c2d3e65 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -90,7 +90,7 @@ class LinearAttention(nn.Module): b, c, h, w = x.shape qkv = self.to_qkv(x) q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) - k = k.softmax(dim=-1) + k = k.softmax(dim=-1) context = torch.einsum('bhdn,bhen->bhde', k, v) out = torch.einsum('bhde,bhdn->bhen', context, q) out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) @@ -167,101 +167,85 @@ class CrossAttention(nn.Module): nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) ) - - if torch.cuda.is_available(): - self.einsum_op = self.einsum_op_cuda - else: - self.mem_total = psutil.virtual_memory().total / (1024**3) - self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2 - def einsum_op_compvis(self, q, k, v, r1): - s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # faster - s2 = s1.softmax(dim=-1, dtype=q.dtype) - del s1 - r1 = einsum('b i j, b j d -> b i d', s2, v) - del s2 - return r1 + self.mem_total_gb = psutil.virtual_memory().total // (1 << 30) - def einsum_op_mps_v1(self, q, k, v, r1): + def einsum_op_compvis(self, q, k, v): + s = einsum('b i d, b j d -> b i j', q, k) + s = s.softmax(dim=-1, dtype=s.dtype) + return einsum('b i j, b j d -> b i d', s, v) + + def einsum_op_slice_0(self, q, k, v, slice_size): + r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + for i in range(0, q.shape[0], slice_size): + end = i + slice_size + r[i:end] = self.einsum_op_compvis(q[i:end], k[i:end], v[i:end]) + return r + + def einsum_op_slice_1(self, q, k, v, slice_size): + r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + r[:, i:end] = self.einsum_op_compvis(q[:, i:end], k, v) + return r + + def einsum_op_mps_v1(self, q, k, v): if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 - r1 = self.einsum_op_compvis(q, k, v, r1) + return self.einsum_op_compvis(q, k, v) else: slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) - for i in range(0, q.shape[1], slice_size): - end = 
i + slice_size - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale - s2 = s1.softmax(dim=-1, dtype=r1.dtype) - del s1 - r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) - del s2 - return r1 + return self.einsum_op_slice_1(q, k, v, slice_size) - def einsum_op_mps_v2(self, q, k, v, r1): - if self.mem_total >= 8 and q.shape[1] <= 4096: - r1 = self.einsum_op_compvis(q, k, v, r1) + def einsum_op_mps_v2(self, q, k, v): + if self.mem_total_gb > 8 and q.shape[1] <= 4096: + return self.einsum_op_compvis(q, k, v) else: - slice_size = 1 - for i in range(0, q.shape[0], slice_size): - end = min(q.shape[0], i + slice_size) - s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) - s1 *= self.scale - s2 = s1.softmax(dim=-1, dtype=r1.dtype) - del s1 - r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) - del s2 - return r1 - - def einsum_op_cuda(self, q, k, v, r1): + return self.einsum_op_slice_0(q, k, v, 1) + + def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb): + size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) + if size_mb <= max_tensor_mb: + return self.einsum_op_compvis(q, k, v) + div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() + if div <= q.shape[0]: + return self.einsum_op_slice_0(q, k, v, q.shape[0] // div) + return self.einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1)) + + def einsum_op_cuda(self, q, k, v): stats = torch.cuda.memory_stats(q.device) mem_active = stats['active_bytes.all.current'] mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_cuda, _ = torch.cuda.mem_get_info(q.device) mem_free_torch = mem_reserved - mem_active mem_free_total = mem_free_cuda + mem_free_torch + # Divide factor of safety as there's copying and fragmentation + return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) - gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4 - mem_required = tensor_size * 2.5 - steps = 1 + def einsum_op(self, q, k, v): + if q.device.type == 'cuda': + return self.einsum_op_cuda(q, k, v) - if mem_required > mem_free_total: - steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + if q.device.type == 'mps': + if self.mem_total_gb >= 32: + return self.einsum_op_mps_v1(q, k, v) + return self.einsum_op_mps_v2(q, k, v) - if steps > 64: - max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 - raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' - f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') - - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = min(q.shape[1], i + slice_size) - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale - s2 = s1.softmax(dim=-1, dtype=r1.dtype) - del s1 - r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) - del s2 - return r1 + # Smaller slices are faster due to L2/L3/SLC caches. + # Tested on i7 with 8MB L3 cache. 
+ return self.einsum_op_tensor_mem(q, k, v, 32) def forward(self, x, context=None, mask=None): h = self.heads - q_in = self.to_q(x) + q = self.to_q(x) context = default(context, x) - k_in = self.to_k(context) - v_in = self.to_v(context) - device_type = 'mps' if x.device.type == 'mps' else 'cuda' + k = self.to_k(context) * self.scale + v = self.to_v(context) del context, x - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - r1 = self.einsum_op(q, k, v, r1) - del q, k, v - - r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) - del r1 - - return self.to_out(r2) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + r = self.einsum_op(q, k, v) + return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h)) class BasicTransformerBlock(nn.Module): diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index a3598c40ef..78876a0919 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -3,6 +3,7 @@ import gc import math import torch import torch.nn as nn +from torch.nn.functional import silu import numpy as np from einops import rearrange @@ -32,11 +33,6 @@ def get_timestep_embedding(timesteps, embedding_dim): return emb -def nonlinearity(x): - # swish - return x*torch.sigmoid(x) - - def Normalize(in_channels, num_groups=32): return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) @@ -122,14 +118,14 @@ class ResnetBlock(nn.Module): def forward(self, x, temb): h = self.norm1(x) - h = nonlinearity(h) + h = silu(h) h = self.conv1(h) if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + h = h + self.temb_proj(silu(temb))[:,:,None,None] h = self.norm2(h) - h = nonlinearity(h) + h = silu(h) h = self.dropout(h) h = self.conv2(h) @@ -368,7 +364,7 @@ class Model(nn.Module): assert t is not None temb = get_timestep_embedding(t, self.ch) temb = self.temb.dense[0](temb) - temb = nonlinearity(temb) + temb = silu(temb) temb = self.temb.dense[1](temb) else: temb = None @@ -402,7 +398,7 @@ class Model(nn.Module): # end h = self.norm_out(h) - h = nonlinearity(h) + h = silu(h) h = self.conv_out(h) return h @@ -499,7 +495,7 @@ class Encoder(nn.Module): # end h = self.norm_out(h) - h = nonlinearity(h) + h = silu(h) h = self.conv_out(h) return h @@ -611,7 +607,7 @@ class Decoder(nn.Module): return h h = self.norm_out(h) - h = nonlinearity(h) + h = silu(h) h = self.conv_out(h) if self.tanh_out: h = torch.tanh(h) @@ -649,7 +645,7 @@ class SimpleDecoder(nn.Module): x = layer(x) h = self.norm_out(x) - h = nonlinearity(h) + h = silu(h) x = self.conv_out(h) return x @@ -697,7 +693,7 @@ class UpsampleDecoder(nn.Module): if i_level != self.num_resolutions - 1: h = self.upsample_blocks[k](h) h = self.norm_out(h) - h = nonlinearity(h) + h = silu(h) h = self.conv_out(h) return h @@ -873,7 +869,7 @@ class FirstStagePostProcessor(nn.Module): z_fs = self.encode_with_pretrained(x) z = self.proj_norm(z_fs) z = self.proj(z) - z = nonlinearity(z) + z = silu(z) for submodel, downmodel in zip(self.model,self.downsampler): z = submodel(z,temb=None) diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py index 2cb56a14a0..60b4d8a028 100644 --- a/ldm/modules/diffusionmodules/util.py +++ b/ldm/modules/diffusionmodules/util.py @@ -252,12 +252,6 @@ def normalization(channels): return 
GroupNorm32(32, channels) -# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. -class SiLU(nn.Module): - def forward(self, x): - return x * torch.sigmoid(x) - - class GroupNorm32(nn.GroupNorm): def forward(self, x): return super().forward(x.float()).type(x.dtype) diff --git a/scripts/dream.py b/scripts/dream.py index f147008d78..289a89e8ad 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -8,7 +8,7 @@ import copy import warnings import time import ldm.dream.readline -from ldm.dream.args import Args, format_metadata +from ldm.dream.args import Args, metadata_dumps from ldm.dream.pngwriter import PngWriter from ldm.dream.server import DreamServer, ThreadingDreamServer from ldm.dream.image_util import make_grid @@ -218,10 +218,14 @@ def main_loop(gen, opt, infile): file_writer = PngWriter(current_outdir) prefix = file_writer.unique_prefix() results = [] # list of filename, prompt pairs - grid_images = dict() # seed -> Image, only used if `opt.grid` + grid_images = dict() # seed -> Image, only used if `opt.grid` + prior_variations = opt.with_variations or [] + first_seed = opt.seed def image_writer(image, seed, upscaled=False): path = None + nonlocal first_seed + nonlocal prior_variations if opt.grid: grid_images[seed] = image else: @@ -229,29 +233,21 @@ def main_loop(gen, opt, infile): filename = f'{prefix}.{seed}.postprocessed.png' else: filename = f'{prefix}.{seed}.png' - # the handling of variations is probably broken - # Also, given the ability to add stuff to the dream_prompt_str, it isn't - # necessary to make a copy of the opt option just to change its attributes if opt.variation_amount > 0: - iter_opt = copy.copy(opt) - this_variation = [[seed, opt.variation_amount]] - if opt.with_variations is None: - iter_opt.with_variations = this_variation - else: - iter_opt.with_variations = opt.with_variations + this_variation - iter_opt.variation_amount = 0 - formatted_dream_prompt = iter_opt.dream_prompt_str(seed=seed) - elif opt.with_variations is not None: - formatted_dream_prompt = opt.dream_prompt_str(seed=seed) + first_seed = first_seed or seed + this_variation = [[seed, opt.variation_amount]] + opt.with_variations = prior_variations + this_variation + formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed) + elif len(prior_variations) > 0: + formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed) else: formatted_dream_prompt = opt.dream_prompt_str(seed=seed) path = file_writer.save_image_and_prompt_to_png( image = image, dream_prompt = formatted_dream_prompt, - metadata = format_metadata( + metadata = metadata_dumps( opt, seeds = [seed], - weights = gen.weights, model_hash = gen.model_hash, ), name = filename, @@ -275,7 +271,7 @@ def main_loop(gen, opt, infile): filename = f'{prefix}.{first_seed}.png' formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed,grid=True,iterations=len(grid_images)) formatted_dream_prompt += f' # {grid_seeds}' - metadata = format_metadata( + metadata = metadata_dumps( opt, seeds = grid_seeds, weights = gen.weights, diff --git a/static/legacy_web/favicon.ico b/static/legacy_web/favicon.ico new file mode 100644 index 0000000000..51eb844a6a Binary files /dev/null and b/static/legacy_web/favicon.ico differ diff --git a/static/legacy_web/index.css b/static/legacy_web/index.css new file mode 100644 index 0000000000..51f0f267c3 --- /dev/null +++ b/static/legacy_web/index.css @@ -0,0 +1,152 @@ +* { + font-family: 'Arial'; + font-size: 100%; +} +body { + font-size: 1em; +} +textarea { + font-size: 0.95em; +} +header, form, #progress-section {
+ margin-left: auto; + margin-right: auto; + max-width: 1024px; + text-align: center; +} +fieldset { + border: none; + line-height: 2.2em; +} +select, input { + margin-right: 10px; + padding: 2px; +} +input[type=submit] { + background-color: #666; + color: white; +} +input[type=checkbox] { + margin-right: 0px; + width: 20px; + height: 20px; + vertical-align: middle; +} +input#seed { + margin-right: 0px; +} +div { + padding: 10px 10px 10px 10px; +} +header { + margin-bottom: 16px; +} +header h1 { + margin-bottom: 0; + font-size: 2em; +} +#search-box { + display: flex; +} +#scaling-inprocess-message { + font-weight: bold; + font-style: italic; + display: none; +} +#prompt { + flex-grow: 1; + padding: 5px 10px 5px 10px; + border: 1px solid #999; + outline: none; +} +#submit { + padding: 5px 10px 5px 10px; + border: 1px solid #999; +} +#reset-all, #remove-image { + margin-top: 12px; + font-size: 0.8em; + background-color: pink; + border: 1px solid #999; + border-radius: 4px; +} +#results { + text-align: center; + margin: auto; + padding-top: 10px; +} +#results figure { + display: inline-block; + margin: 10px; +} +#results figcaption { + font-size: 0.8em; + padding: 3px; + color: #888; + cursor: pointer; +} +#results img { + border-radius: 5px; + object-fit: cover; +} +#fieldset-config { + line-height:2em; + background-color: #F0F0F0; +} +input[type="number"] { + width: 60px; +} +#seed { + width: 150px; +} +button#reset-seed { + font-size: 1.7em; + background: #efefef; + border: 1px solid #999; + border-radius: 4px; + line-height: 0.8; + margin: 0 10px 0 0; + padding: 0 5px 3px; + vertical-align: middle; +} +label { + white-space: nowrap; +} +#progress-section { + display: none; +} +#progress-image { + width: 30vh; + height: 30vh; +} +#cancel-button { + cursor: pointer; + color: red; +} +#basic-parameters { + background-color: #EEEEEE; +} +#txt2img { + background-color: #DCDCDC; +} +#variations { + background-color: #EEEEEE; +} +#img2img { + background-color: #DCDCDC; +} +#gfpgan { + background-color: #EEEEEE; +} +#progress-section { + background-color: #F5F5F5; +} +.section-header { + text-align: left; + font-weight: bold; + padding: 0 0 0 0; +} +#no-results-message:not(:only-child) { + display: none; +} + diff --git a/static/legacy_web/index.html b/static/legacy_web/index.html new file mode 100644 index 0000000000..5ce8b45baf --- /dev/null +++ b/static/legacy_web/index.html @@ -0,0 +1,129 @@ + + + Stable Diffusion Dream Server + + + + + + + + +
[index.html body (markup lost): "Stable Diffusion Dream Server" heading; a note pointing to the project's GitHub site for news and support; a generation form with "Basic options", "Image-to-image options", and "Post-processing options" sections; a progress section with a "Postprocessing...1/3" indicator; and a results area that shows "No results..." until images arrive.]
+ + diff --git a/static/legacy_web/index.js b/static/legacy_web/index.js new file mode 100644 index 0000000000..ac68034920 --- /dev/null +++ b/static/legacy_web/index.js @@ -0,0 +1,213 @@ +function toBase64(file) { + return new Promise((resolve, reject) => { + const r = new FileReader(); + r.readAsDataURL(file); + r.onload = () => resolve(r.result); + r.onerror = (error) => reject(error); + }); +} + +function appendOutput(src, seed, config) { + let outputNode = document.createElement("figure"); + + let variations = config.with_variations; + if (config.variation_amount > 0) { + variations = (variations ? variations + ',' : '') + seed + ':' + config.variation_amount; + } + let baseseed = (config.with_variations || config.variation_amount > 0) ? config.seed : seed; + let altText = baseseed + ' | ' + (variations ? variations + ' | ' : '') + config.prompt; + + // img needs width and height for lazy loading to work + const figureContents = ` + + ${altText} + +
${seed}
+ `; + + outputNode.innerHTML = figureContents; + let figcaption = outputNode.querySelector('figcaption'); + + // Reload image config + figcaption.addEventListener('click', () => { + let form = document.querySelector("#generate-form"); + for (const [k, v] of new FormData(form)) { + if (k == 'initimg') { continue; } + form.querySelector(`*[name=${k}]`).value = config[k]; + } + + document.querySelector("#seed").value = baseseed; + document.querySelector("#with_variations").value = variations || ''; + if (document.querySelector("#variation_amount").value <= 0) { + document.querySelector("#variation_amount").value = 0.2; + } + + saveFields(document.querySelector("#generate-form")); + }); + + document.querySelector("#results").prepend(outputNode); +} + +function saveFields(form) { + for (const [k, v] of new FormData(form)) { + if (typeof v !== 'object') { // Don't save 'file' type + localStorage.setItem(k, v); + } + } +} + +function loadFields(form) { + for (const [k, v] of new FormData(form)) { + const item = localStorage.getItem(k); + if (item != null) { + form.querySelector(`*[name=${k}]`).value = item; + } + } +} + +function clearFields(form) { + localStorage.clear(); + let prompt = form.prompt.value; + form.reset(); + form.prompt.value = prompt; +} + +const BLANK_IMAGE_URL = 'data:image/svg+xml,'; +async function generateSubmit(form) { + const prompt = document.querySelector("#prompt").value; + + // Convert file data to base64 + let formData = Object.fromEntries(new FormData(form)); + formData.initimg_name = formData.initimg.name + formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null; + + let strength = formData.strength; + let totalSteps = formData.initimg ? Math.floor(strength * formData.steps) : formData.steps; + + let progressSectionEle = document.querySelector('#progress-section'); + progressSectionEle.style.display = 'initial'; + let progressEle = document.querySelector('#progress-bar'); + progressEle.setAttribute('max', totalSteps); + let progressImageEle = document.querySelector('#progress-image'); + progressImageEle.src = BLANK_IMAGE_URL; + + progressImageEle.style.display = {}.hasOwnProperty.call(formData, 'progress_images') ? 
'initial': 'none'; + + // Post as JSON, using Fetch streaming to get results + fetch(form.action, { + method: form.method, + body: JSON.stringify(formData), + }).then(async (response) => { + const reader = response.body.getReader(); + + let noOutputs = true; + while (true) { + let {value, done} = await reader.read(); + value = new TextDecoder().decode(value); + if (done) { + progressSectionEle.style.display = 'none'; + break; + } + + for (let event of value.split('\n').filter(e => e !== '')) { + const data = JSON.parse(event); + + if (data.event === 'result') { + noOutputs = false; + appendOutput(data.url, data.seed, data.config); + progressEle.setAttribute('value', 0); + progressEle.setAttribute('max', totalSteps); + } else if (data.event === 'upscaling-started') { + document.getElementById("processing_cnt").textContent=data.processed_file_cnt; + document.getElementById("scaling-inprocess-message").style.display = "block"; + } else if (data.event === 'upscaling-done') { + document.getElementById("scaling-inprocess-message").style.display = "none"; + } else if (data.event === 'step') { + progressEle.setAttribute('value', data.step); + if (data.url) { + progressImageEle.src = data.url; + } + } else if (data.event === 'canceled') { + // avoid alerting as if this were an error case + noOutputs = false; + } + } + } + + // Re-enable form, remove no-results-message + form.querySelector('fieldset').removeAttribute('disabled'); + document.querySelector("#prompt").value = prompt; + document.querySelector('progress').setAttribute('value', '0'); + + if (noOutputs) { + alert("Error occurred while generating."); + } + }); + + // Disable form while generating + form.querySelector('fieldset').setAttribute('disabled',''); + document.querySelector("#prompt").value = `Generating: "${prompt}"`; +} + +async function fetchRunLog() { + try { + let response = await fetch('/run_log.json') + const data = await response.json(); + for(let item of data.run_log) { + appendOutput(item.url, item.seed, item); + } + } catch (e) { + console.error(e); + } +} + +window.onload = async () => { + document.querySelector("#prompt").addEventListener("keydown", (e) => { + if (e.key === "Enter" && !e.shiftKey) { + const form = e.target.form; + generateSubmit(form); + } + }); + document.querySelector("#generate-form").addEventListener('submit', (e) => { + e.preventDefault(); + const form = e.target; + + generateSubmit(form); + }); + document.querySelector("#generate-form").addEventListener('change', (e) => { + saveFields(e.target.form); + }); + document.querySelector("#reset-seed").addEventListener('click', (e) => { + document.querySelector("#seed").value = -1; + saveFields(e.target.form); + }); + document.querySelector("#reset-all").addEventListener('click', (e) => { + clearFields(e.target.form); + }); + document.querySelector("#remove-image").addEventListener('click', (e) => { + initimg.value=null; + }); + loadFields(document.querySelector("#generate-form")); + + document.querySelector('#cancel-button').addEventListener('click', () => { + fetch('/cancel').catch(e => { + console.error(e); + }); + }); + document.documentElement.addEventListener('keydown', (e) => { + if (e.key === "Escape") + fetch('/cancel').catch(err => { + console.error(err); + }); + }); + + if (!config.gfpgan_model_exists) { + document.querySelector("#gfpgan").style.display = 'none'; + } + await fetchRunLog() +};
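
The script above drives the server through a stream of newline-delimited JSON events. For reference, the same protocol can be exercised outside the browser; in the sketch below the host, port, endpoint path, and required request fields are assumptions (the form markup is not shown here), while the event names and payload keys (`step`, `result`, `upscaling-started`, `upscaling-done`, `canceled`, `url`, `seed`) mirror what this index.js handles:

```python
# Hedged sketch of a non-browser client for the legacy web server's event
# stream. Host, port, endpoint path and the request fields are assumptions;
# the event handling mirrors static/legacy_web/index.js.
import json
import requests  # third-party: pip install requests

SERVER = 'http://localhost:9090'  # placeholder host and port

payload = {
    'prompt': 'a fantasy landscape',  # placeholder prompt
    'steps': 50,
    'seed': -1,             # what the UI's reset-seed button submits (assumed to mean "random")
    'initimg': None,        # base64 data URL for img2img; None for txt2img
    'variation_amount': 0,
}

# The browser posts the serialized form as JSON and reads results as they stream back.
with requests.post(f'{SERVER}/', json=payload, stream=True) as response:
    for raw in response.iter_lines():
        if not raw:
            continue
        data = json.loads(raw)
        if data['event'] == 'step':
            print('step', data['step'], data.get('url', ''))  # optional progress-image URL
        elif data['event'] == 'result':
            print('result:', data['url'], 'seed', data['seed'])
        elif data['event'] in ('upscaling-started', 'upscaling-done', 'canceled'):
            print(data['event'])
```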