Merge remote-tracking branch 'upstream/development' into development

commit 09bf6dd7c1 by psychedelicious, 2022-09-18 18:44:40 +10:00
12 changed files with 739 additions and 207 deletions

README.md
@@ -1,21 +1,36 @@
-<h1 align='center'><b>Stable Diffusion Dream Script</b></h1>
-<p align='center'>
-<img src="docs/assets/logo.png"/>
-</p>
-<p align="center">
-<a href="https://github.com/lstein/stable-diffusion/releases"><img src="https://flat.badgen.net/github/release/lstein/stable-diffusion/development?icon=github" alt="release"/></a>
-<a href="https://github.com/lstein/stable-diffusion/stargazers"><img src="https://flat.badgen.net/github/stars/lstein/stable-diffusion?icon=github" alt="stars"/></a>
-<a href="https://useful-forks.github.io/?repo=lstein%2Fstable-diffusion"><img src="https://flat.badgen.net/github/forks/lstein/stable-diffusion?icon=github" alt="forks"/></a>
-<br />
-<a href="https://github.com/lstein/stable-diffusion/actions/workflows/test-dream-conda.yml"><img src="https://flat.badgen.net/github/checks/lstein/stable-diffusion/main?label=CI%20status%20on%20main&cache=900&icon=github" alt="CI status on main"/></a>
-<a href="https://github.com/lstein/stable-diffusion/actions/workflows/test-dream-conda.yml"><img src="https://flat.badgen.net/github/checks/lstein/stable-diffusion/development?label=CI%20status%20on%20dev&cache=900&icon=github" alt="CI status on dev"/></a>
-<a href="https://github.com/lstein/stable-diffusion/commits/development"><img src="https://flat.badgen.net/github/last-commit/lstein/stable-diffusion/development?icon=github&color=yellow&label=last%20dev%20commit&cache=900" alt="last-dev-commit"/></a>
-<br />
-<a href="https://github.com/lstein/stable-diffusion/issues?q=is%3Aissue+is%3Aopen"><img src="https://flat.badgen.net/github/open-issues/lstein/stable-diffusion?icon=github" alt="open-issues"/></a>
-<a href="https://github.com/lstein/stable-diffusion/pulls?q=is%3Apr+is%3Aopen"><img src="https://flat.badgen.net/github/open-prs/lstein/stable-diffusion?icon=github" alt="open-prs"/></a>
-</p>
+<div align="center">
+
+# Stable Diffusion Dream Script
+
+![project logo](docs/assets/logo.png)
+
+[![discord badge]][discord link]
+
+[![latest release badge]][latest release link] [![github stars badge]][github stars link] [![github forks badge]][github forks link]
+
+[![CI checks on main badge]][CI checks on main link] [![CI checks on dev badge]][CI checks on dev link] [![latest commit to dev badge]][latest commit to dev link]
+
+[![github open issues badge]][github open issues link] [![github open prs badge]][github open prs link]
+
+[CI checks on dev badge]: https://flat.badgen.net/github/checks/lstein/stable-diffusion/development?label=CI%20status%20on%20dev&cache=900&icon=github
+[CI checks on dev link]: https://github.com/lstein/stable-diffusion/actions?query=branch%3Adevelopment
+[CI checks on main badge]: https://flat.badgen.net/github/checks/lstein/stable-diffusion/main?label=CI%20status%20on%20main&cache=900&icon=github
+[CI checks on main link]: https://github.com/lstein/stable-diffusion/actions/workflows/test-dream-conda.yml
+[discord badge]: https://flat.badgen.net/discord/members/htRgbc7e?icon=discord
+[discord link]: https://discord.com/invite/htRgbc7e
+[github forks badge]: https://flat.badgen.net/github/forks/lstein/stable-diffusion?icon=github
+[github forks link]: https://useful-forks.github.io/?repo=lstein%2Fstable-diffusion
+[github open issues badge]: https://flat.badgen.net/github/open-issues/lstein/stable-diffusion?icon=github
+[github open issues link]: https://github.com/lstein/stable-diffusion/issues?q=is%3Aissue+is%3Aopen
+[github open prs badge]: https://flat.badgen.net/github/open-prs/lstein/stable-diffusion?icon=github
+[github open prs link]: https://github.com/lstein/stable-diffusion/pulls?q=is%3Apr+is%3Aopen
+[github stars badge]: https://flat.badgen.net/github/stars/lstein/stable-diffusion?icon=github
+[github stars link]: https://github.com/lstein/stable-diffusion/stargazers
+[latest commit to dev badge]: https://flat.badgen.net/github/last-commit/lstein/stable-diffusion/development?icon=github&color=yellow&label=last%20dev%20commit&cache=900
+[latest commit to dev link]: https://github.com/lstein/stable-diffusion/commits/development
+[latest release badge]: https://flat.badgen.net/github/release/lstein/stable-diffusion/development?icon=github
+[latest release link]: https://github.com/lstein/stable-diffusion/releases
+</div>
 
 This is a fork of [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion), the open
 source text-to-image generator. It provides a streamlined process with various new features and
@@ -26,7 +41,7 @@ _Note: This fork is rapidly evolving. Please use the
 [Issues](https://github.com/lstein/stable-diffusion/issues) tab to report bugs and make feature
 requests. Be sure to use the provided templates. They will help aid diagnose issues faster._
 
-**Table of Contents**
+## Table of Contents
 
 1. [Installation](#installation)
 2. [Hardware Requirements](#hardware-requirements)
@@ -38,38 +53,38 @@ requests. Be sure to use the provided templates. They will help aid diagnose iss
 8. [Support](#support)
 9. [Further Reading](#further-reading)
 
-## Installation
+### Installation
 
 This fork is supported across multiple platforms. You can find individual installation instructions
 below.
 
-- ### [Linux](docs/installation/INSTALL_LINUX.md)
-- ### [Windows](docs/installation/INSTALL_WINDOWS.md)
-- ### [Macintosh](docs/installation/INSTALL_MAC.md)
+- #### [Linux](docs/installation/INSTALL_LINUX.md)
+- #### [Windows](docs/installation/INSTALL_WINDOWS.md)
+- #### [Macintosh](docs/installation/INSTALL_MAC.md)
 
-## Hardware Requirements
+### Hardware Requirements
 
-**System**
+#### System
 
 You wil need one of the following:
 
 - An NVIDIA-based graphics card with 4 GB or more VRAM memory.
 - An Apple computer with an M1 chip.
 
-**Memory**
+#### Memory
 
 - At least 12 GB Main Memory RAM.
 
-**Disk**
+#### Disk
 
 - At least 6 GB of free disk space for the machine learning model, Python, and all its dependencies.
 
-**Note**
-
-If you are have a Nvidia 10xx series card (e.g. the 1080ti), please run the dream script in
-full-precision mode as shown below.
+> Note
+>
+> If you have an Nvidia 10xx series card (e.g. the 1080ti), please run the dream script in
+> full-precision mode as shown below.
 
 Similarly, specify full-precision mode on Apple M1 hardware.
@@ -79,43 +94,30 @@ To run in full-precision mode, start `dream.py` with the `--full_precision` flag
 (ldm) ~/stable-diffusion$ python scripts/dream.py --full_precision
 ```
 
-## Features
+### Features
 
-### Major Features
+#### Major Features
 
-- #### [Interactive Command Line Interface](docs/features/CLI.md)
-
-- #### [Image To Image](docs/features/IMG2IMG.md)
-
-- #### [Inpainting Support](docs/features/INPAINTING.md)
-
-- #### [GFPGAN and Real-ESRGAN Support](docs/features/UPSCALE.md)
-
-- #### [Seamless Tiling](docs/features/OTHER.md#seamless-tiling)
-
-- #### [Google Colab](docs/features/OTHER.md#google-colab)
-
-- #### [Web Server](docs/features/WEB.md)
-
-- #### [Reading Prompts From File](docs/features/OTHER.md#reading-prompts-from-a-file)
-
-- #### [Shortcut: Reusing Seeds](docs/features/OTHER.md#shortcuts-reusing-seeds)
-
-- #### [Weighted Prompts](docs/features/OTHER.md#weighted-prompts)
-
-- #### [Variations](docs/features/VARIATIONS.md)
-
-- #### [Personalizing Text-to-Image Generation](docs/features/TEXTUAL_INVERSION.md)
-
-- #### [Simplified API for text to image generation](docs/features/OTHER.md#simplified-api)
+- [Interactive Command Line Interface](docs/features/CLI.md)
+- [Image To Image](docs/features/IMG2IMG.md)
+- [Inpainting Support](docs/features/INPAINTING.md)
+- [GFPGAN and Real-ESRGAN Support](docs/features/UPSCALE.md)
+- [Seamless Tiling](docs/features/OTHER.md#seamless-tiling)
+- [Google Colab](docs/features/OTHER.md#google-colab)
+- [Web Server](docs/features/WEB.md)
+- [Reading Prompts From File](docs/features/OTHER.md#reading-prompts-from-a-file)
+- [Shortcut: Reusing Seeds](docs/features/OTHER.md#shortcuts-reusing-seeds)
+- [Weighted Prompts](docs/features/OTHER.md#weighted-prompts)
+- [Variations](docs/features/VARIATIONS.md)
+- [Personalizing Text-to-Image Generation](docs/features/TEXTUAL_INVERSION.md)
+- [Simplified API for text to image generation](docs/features/OTHER.md#simplified-api)
 
-### Other Features
+#### Other Features
 
-- #### [Creating Transparent Regions for Inpainting](docs/features/INPAINTING.md#creating-transparent-regions-for-inpainting)
-- #### [Preload Models](docs/features/OTHER.md#preload-models)
+- [Creating Transparent Regions for Inpainting](docs/features/INPAINTING.md#creating-transparent-regions-for-inpainting)
+- [Preload Models](docs/features/OTHER.md#preload-models)
 
-## Latest Changes
+### Latest Changes
 
 - v1.14 (11 September 2022)
@@ -147,12 +149,12 @@ To run in full-precision mode, start `dream.py` with the `--full_precision` flag
 
 For older changelogs, please visit the **[CHANGELOG](docs/features/CHANGELOG.md)**.
 
-## Troubleshooting
+### Troubleshooting
 
 Please check out our **[Q&A](docs/help/TROUBLESHOOT.md)** to get solutions for common installation
 problems and other issues.
 
-## Contributing
+### Contributing
 
 Anyone who wishes to contribute to this project, whether documentation, features, bug fixes, code
 cleanup, testing, or code reviews, is very much encouraged to do so. If you are unfamiliar with how

@@ -164,13 +166,13 @@ important thing is to **make your pull request against the "development" branch*
 "main". This will help keep public breakage to a minimum and will allow you to propose more radical
 changes.
 
-## Contributors
+### Contributors
 
 This fork is a combined effort of various people from across the world.
 [Check out the list of all these amazing people](docs/other/CONTRIBUTORS.md). We thank them for
 their time, hard work and effort.
 
-## Support
+### Support
 
 For support, please use this repository's GitHub Issues tracking service. Feel free to send me an
 email if you use and like the script.

@@ -178,7 +180,7 @@ email if you use and like the script.
 
 Original portions of the software are Copyright (c) 2020
 [Lincoln D. Stein](https://github.com/lstein)
 
-## Further Reading
+### Further Reading
 
 Please see the original README for more information on this software and underlying algorithm,
 located in the file [README-CompViz.md](docs/other/README-CompViz.md).

(conda environment file)

@@ -21,7 +21,7 @@ dependencies:
     - test-tube>=0.7.5
     - streamlit==1.12.0
     - send2trash==1.8.0
-    - pillow==6.2.0
+    - pillow==9.2.0
     - einops==0.3.0
     - torch-fidelity==0.3.0
     - transformers==4.19.2

ldm/dream/args.py

@@ -2,7 +2,10 @@
 The Args class parses both the command line (shell) arguments, as well as the
 command string passed at the dream> prompt. It serves as the definitive repository
-of all the arguments used by Generate and their default values.
+of all the arguments used by Generate and their default values, and implements the
+preliminary metadata standards discussed here:
+
+https://github.com/lstein/stable-diffusion/issues/266
 
 To use:
   opt = Args()
@@ -52,15 +55,38 @@ you wish to apply logic as to which one to use. For example:
 To add new attributes, edit the _create_arg_parser() and
 _create_dream_cmd_parser() methods.
 
-We also export the function build_metadata
+**Generating and retrieving sd-metadata**
+
+To generate a dict representing RFC266 metadata:
+
+   metadata = metadata_dumps(opt,<seeds,model_hash,postprocesser>)
+
+This will generate an RFC266 dictionary that can then be turned into a JSON
+and written to the PNG file. The optional seeds, weights, model_hash and
+postprocesser arguments are not available to the opt object and so must be
+provided externally. See how dream.py does it.
+
+Note that this function was originally called format_metadata() and a wrapper
+is provided that issues a deprecation notice.
+
+To retrieve a (series of) opt objects corresponding to the metadata, do this:
+
+   opt_list = metadata_loads(metadata)
+
+The metadata should be pulled out of the PNG image. pngwriter has a method
+retrieve_metadata that will do this.
 """
 
 import argparse
+from argparse import Namespace
 import shlex
 import json
 import hashlib
 import os
 import copy
+import base64
 from ldm.dream.conditioning import split_weighted_subprompts
 
 SAMPLER_CHOICES = [
@@ -142,7 +168,7 @@ class Args(object):
         a = vars(self)
         a.update(kwargs)
         switches = list()
-        switches.append(f'"{a["prompt"]}')
+        switches.append(f'"{a["prompt"]}"')
         switches.append(f'-s {a["steps"]}')
         switches.append(f'-W {a["width"]}')
         switches.append(f'-H {a["height"]}')
@@ -151,15 +177,13 @@ class Args(object):
         switches.append(f'-S {a["seed"]}')
         if a['grid']:
             switches.append('--grid')
-        if a['iterations'] and a['iterations']>0:
-            switches.append(f'-n {a["iterations"]}')
         if a['seamless']:
             switches.append('--seamless')
         if a['init_img'] and len(a['init_img'])>0:
             switches.append(f'-I {a["init_img"]}')
         if a['fit']:
             switches.append(f'--fit')
-        if a['strength'] and a['strength']>0:
+        if a['init_img'] and a['strength'] and a['strength']>0:
            switches.append(f'-f {a["strength"]}')
         if a['gfpgan_strength']:
             switches.append(f'-G {a["gfpgan_strength"]}')
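
Two small behavioral fixes land in dream_prompt_str here: the prompt's closing quote is restored, and the img2img strength switch is only emitted when an init image is actually set. A quick sketch with sample values (illustrative, not repo code):

```python
# Sample values only; mirrors the two fixes in dream_prompt_str above.
a = {'prompt': 'blue bird', 'init_img': None, 'strength': 0.75}

switches = [f'"{a["prompt"]}"']          # closing quote restored
assert switches[0] == '"blue bird"'

# -f is only meaningful for img2img, so it is now gated on init_img too
if a['init_img'] and a['strength'] and a['strength'] > 0:
    switches.append(f'-f {a["strength"]}')

assert switches == ['"blue bird"']       # no -f without an init image
```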
@@ -541,17 +565,20 @@ class Args(object):
         )
         return parser
 
-# very partial implementation of https://github.com/lstein/stable-diffusion/issues/266
-# it does not write all the required top-level metadata, writes too much image
-# data, and doesn't support grids yet. But you gotta start somewhere, no?
-def format_metadata(opt,
+def format_metadata(**kwargs):
+    print(f'format_metadata() is deprecated. Please use metadata_dumps()')
+    return metadata_dumps(kwargs)
+
+def metadata_dumps(opt,
                     seeds=[],
-                    weights=None,
                     model_hash=None,
                     postprocessing=None):
     '''
-    Given an Args object, returns a partial implementation of
-    the stable diffusion metadata standard
+    Given an Args object, returns a dict containing the keys and
+    structure of the proposed stable diffusion metadata standard
+    https://github.com/lstein/stable-diffusion/discussions/392
+    This is intended to be turned into JSON and stored in the
+    "sd
     '''
     # add some RFC266 fields that are generated internally, and not as
     # user args
@@ -593,12 +620,15 @@ def format_metadata(opt,
     if opt.init_img:
         rfc_dict['type'] = 'img2img'
         rfc_dict['strength_steps'] = rfc_dict.pop('strength')
-        rfc_dict['orig_hash'] = sha256(image_dict['init_img'])
+        rfc_dict['orig_hash'] = calculate_init_img_hash(opt.init_img)
         rfc_dict['sampler'] = 'ddim' # FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS
     else:
         rfc_dict['type'] = 'txt2img'
 
     images = []
+    if len(seeds)==0 and opt.seed:
+        seeds=[seed]
     for seed in seeds:
         rfc_dict['seed'] = seed
         images.append(copy.copy(rfc_dict))
@@ -612,6 +642,44 @@ def format_metadata(opt,
         'images' : images,
     }
 
+def metadata_loads(metadata):
+    '''
+    Takes the dictionary corresponding to RFC266 (https://github.com/lstein/stable-diffusion/issues/266)
+    and returns a series of opt objects for each of the images described in the dictionary.
+    '''
+    results = []
+    try:
+        images = metadata['sd-metadata']['images']
+        for image in images:
+            # repack the prompt and variations
+            image['prompt'] = ','.join([':'.join([x['prompt'], str(x['weight'])]) for x in image['prompt']])
+            image['variations'] = ','.join([':'.join([str(x['seed']),str(x['weight'])]) for x in image['variations']])
+            opt = Args()
+            opt._cmd_switches = Namespace(**image)
+            results.append(opt)
+    except KeyError as e:
+        import sys, traceback
+        print('>> badly-formatted metadata', file=sys.stderr)
+        print(traceback.format_exc(), file=sys.stderr)
+    return results
+
+# image can either be a file path on disk or a base64-encoded
+# representation of the file's contents
+def calculate_init_img_hash(image_string):
+    prefix = 'data:image/png;base64,'
+    hash = None
+    if image_string.startswith(prefix):
+        imagebase64 = image_string[len(prefix):]
+        imagedata = base64.b64decode(imagebase64)
+        with open('outputs/test.png','wb') as file:
+            file.write(imagedata)
+        sha = hashlib.sha256()
+        sha.update(imagedata)
+        hash = sha.hexdigest()
+    else:
+        hash = sha256(image_string)
+    return hash
+
 # Bah. This should be moved somewhere else...
 def sha256(path):
     sha = hashlib.sha256()
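
Putting the new pieces of args.py together: metadata_dumps() produces the dict, the writers serialize it to JSON inside the PNG, and metadata_loads() expects that JSON back under an "sd-metadata" wrapper (the key it indexes into above). A hedged sketch of the round trip; opt stands for a populated Args instance, and the seed and hash are sample values:

```python
# Illustrative composition of the functions added above; not repo code.
# `opt` is assumed to be a populated Args instance.
metadata = metadata_dumps(opt, seeds=[42], model_hash='0123abcd')

# The PNG stores this dict as JSON under an "sd-metadata" key, which is
# exactly the wrapper metadata_loads() indexes into.
png_payload = {'sd-metadata': metadata}

for recovered in metadata_loads(png_payload):  # one Args object per image
    print(recovered.prompt)                    # attribute access proxies _cmd_switches
```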

ldm/dream/server.py

@@ -4,7 +4,7 @@ import copy
 import base64
 import mimetypes
 import os
-from ldm.dream.args import Args, format_metadata
+from ldm.dream.args import Args, metadata_dumps
 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
 from ldm.dream.pngwriter import PngWriter
 from threading import Event
@@ -76,7 +76,7 @@ class DreamServer(BaseHTTPRequestHandler):
             self.send_response(200)
             self.send_header("Content-type", "text/html")
             self.end_headers()
-            with open("./static/dream_web/index.html", "rb") as content:
+            with open("./static/legacy_web/index.html", "rb") as content:
                 self.wfile.write(content.read())
         elif self.path == "/config.js":
             # unfortunately this import can't be at the top level, since that would cause a circular import
@@ -94,7 +94,7 @@ class DreamServer(BaseHTTPRequestHandler):
             self.end_headers()
             output = []
 
-            log_file = os.path.join(self.outdir, "dream_web_log.txt")
+            log_file = os.path.join(self.outdir, "legacy_web_log.txt")
             if os.path.exists(log_file):
                 with open(log_file, "r") as log:
                     for line in log:
@@ -114,7 +114,7 @@ class DreamServer(BaseHTTPRequestHandler):
         else:
             path_dir = os.path.dirname(self.path)
             out_dir = os.path.realpath(self.outdir.rstrip('/'))
-            if self.path.startswith('/static/dream_web/'):
+            if self.path.startswith('/static/legacy_web/'):
                 path = '.' + self.path
             elif out_dir.replace('\\', '/').endswith(path_dir):
                 file = os.path.basename(self.path)
@@ -145,7 +145,6 @@ class DreamServer(BaseHTTPRequestHandler):
         opt = build_opt(post_data, self.model.seed, gfpgan_model_exists)
 
         self.canceled.clear()
-        print(f">> Request to generate with prompt: {opt.prompt}")
         # In order to handle upscaled images, the PngWriter needs to maintain state
         # across images generated by each call to prompt2img(), so we define it in
         # the outer scope of image_done()
@@ -176,9 +175,8 @@ class DreamServer(BaseHTTPRequestHandler):
                 path = pngwriter.save_image_and_prompt_to_png(
                     image,
                     dream_prompt = formatted_prompt,
-                    metadata = format_metadata(iter_opt,
+                    metadata = metadata_dumps(iter_opt,
                                                seeds = [seed],
-                                               weights = self.model.weights,
                                                model_hash = self.model.model_hash
                                                ),
                     name = name,
@@ -188,7 +186,7 @@ class DreamServer(BaseHTTPRequestHandler):
                 config['seed'] = seed
                 # Append post_data to log, but only once!
                 if not upscaled:
-                    with open(os.path.join(self.outdir, "dream_web_log.txt"), "a") as log:
+                    with open(os.path.join(self.outdir, "legacy_web_log.txt"), "a") as log:
                         log.write(f"{path}: {json.dumps(config)}\n")
 
                 self.wfile.write(bytes(json.dumps(

ldm/modules/attention.py

@@ -168,100 +168,84 @@ class CrossAttention(nn.Module):
             nn.Dropout(dropout)
         )
 
-        if torch.cuda.is_available():
-            self.einsum_op = self.einsum_op_cuda
-        else:
-            self.mem_total = psutil.virtual_memory().total / (1024**3)
-            self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
+        self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
 
-    def einsum_op_compvis(self, q, k, v, r1):
-        s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # faster
-        s2 = s1.softmax(dim=-1, dtype=q.dtype)
-        del s1
-        r1 = einsum('b i j, b j d -> b i d', s2, v)
-        del s2
-        return r1
+    def einsum_op_compvis(self, q, k, v):
+        s = einsum('b i d, b j d -> b i j', q, k)
+        s = s.softmax(dim=-1, dtype=s.dtype)
+        return einsum('b i j, b j d -> b i d', s, v)
 
+    def einsum_op_slice_0(self, q, k, v, slice_size):
+        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        for i in range(0, q.shape[0], slice_size):
+            end = i + slice_size
+            r[i:end] = self.einsum_op_compvis(q[i:end], k[i:end], v[i:end])
+        return r
+
+    def einsum_op_slice_1(self, q, k, v, slice_size):
+        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        for i in range(0, q.shape[1], slice_size):
+            end = i + slice_size
+            r[:, i:end] = self.einsum_op_compvis(q[:, i:end], k, v)
+        return r
+
-    def einsum_op_mps_v1(self, q, k, v, r1):
-        if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
-            r1 = self.einsum_op_compvis(q, k, v, r1)
-        else:
-            slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
-            for i in range(0, q.shape[1], slice_size):
-                end = i + slice_size
-                s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
-                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
-                del s1
-                r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
-                del s2
-        return r1
+    def einsum_op_mps_v1(self, q, k, v):
+        if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
+            return self.einsum_op_compvis(q, k, v)
+        else:
+            slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+            return self.einsum_op_slice_1(q, k, v, slice_size)
 
-    def einsum_op_mps_v2(self, q, k, v, r1):
-        if self.mem_total >= 8 and q.shape[1] <= 4096:
-            r1 = self.einsum_op_compvis(q, k, v, r1)
-        else:
-            slice_size = 1
-            for i in range(0, q.shape[0], slice_size):
-                end = min(q.shape[0], i + slice_size)
-                s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
-                s1 *= self.scale
-                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
-                del s1
-                r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
-                del s2
-        return r1
+    def einsum_op_mps_v2(self, q, k, v):
+        if self.mem_total_gb > 8 and q.shape[1] <= 4096:
+            return self.einsum_op_compvis(q, k, v)
+        else:
+            return self.einsum_op_slice_0(q, k, v, 1)
 
+    def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
+        size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
+        if size_mb <= max_tensor_mb:
+            return self.einsum_op_compvis(q, k, v)
+        div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
+        if div <= q.shape[0]:
+            return self.einsum_op_slice_0(q, k, v, q.shape[0] // div)
+        return self.einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
+
-    def einsum_op_cuda(self, q, k, v, r1):
+    def einsum_op_cuda(self, q, k, v):
         stats = torch.cuda.memory_stats(q.device)
         mem_active = stats['active_bytes.all.current']
         mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+        mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
         mem_free_torch = mem_reserved - mem_active
         mem_free_total = mem_free_cuda + mem_free_torch
+        # Divide factor of safety as there's copying and fragmentation
+        return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
 
-        gb = 1024 ** 3
-        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
-        mem_required = tensor_size * 2.5
-        steps = 1
-
-        if mem_required > mem_free_total:
-            steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
-
-        if steps > 64:
-            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
-            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
-                               f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
-
-        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
-        for i in range(0, q.shape[1], slice_size):
-            end = min(q.shape[1], i + slice_size)
-            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
-            s2 = s1.softmax(dim=-1, dtype=r1.dtype)
-            del s1
-            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
-            del s2
-        return r1
+    def einsum_op(self, q, k, v):
+        if q.device.type == 'cuda':
+            return self.einsum_op_cuda(q, k, v)
+
+        if q.device.type == 'mps':
+            if self.mem_total_gb >= 32:
+                return self.einsum_op_mps_v1(q, k, v)
+            return self.einsum_op_mps_v2(q, k, v)
+
+        # Smaller slices are faster due to L2/L3/SLC caches.
+        # Tested on i7 with 8MB L3 cache.
+        return self.einsum_op_tensor_mem(q, k, v, 32)
 
     def forward(self, x, context=None, mask=None):
         h = self.heads
 
-        q_in = self.to_q(x)
+        q = self.to_q(x)
         context = default(context, x)
-        k_in = self.to_k(context)
-        v_in = self.to_v(context)
-        device_type = 'mps' if x.device.type == 'mps' else 'cuda'
+        k = self.to_k(context) * self.scale
+        v = self.to_v(context)
         del context, x
 
-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
-        del q_in, k_in, v_in
-        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-        r1 = self.einsum_op(q, k, v, r1)
-        del q, k, v
-
-        r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
-        del r1
-        return self.to_out(r2)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+        r = self.einsum_op(q, k, v)
+        return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
 
 
 class BasicTransformerBlock(nn.Module):
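
The load-balancing logic above funnels every device through einsum_op_tensor_mem(): it estimates the attention matrix's footprint and, when that exceeds the budget, rounds the overshoot up to a power of two and slices along axis 0 or 1 accordingly. A self-contained sketch of just that divisor arithmetic (illustrative, not repo code):

```python
# Standalone sketch of the power-of-two divisor used by einsum_op_tensor_mem.
def pick_divisor(size_mb: float, max_tensor_mb: float) -> int:
    if size_mb <= max_tensor_mb:
        return 1                           # whole tensor fits in budget
    return 1 << int((size_mb - 1) / max_tensor_mb).bit_length()

assert pick_divisor(16, 32) == 1           # fits outright, no slicing
assert pick_divisor(100, 32) == 4          # four slices of 25 MB each
assert pick_divisor(129, 32) == 8          # just past 4x the budget, round up to 8
```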

ldm/modules/diffusionmodules/model.py

@@ -3,6 +3,7 @@ import gc
 import math
 
 import torch
 import torch.nn as nn
+from torch.nn.functional import silu
 import numpy as np
 from einops import rearrange
@@ -32,11 +33,6 @@ def get_timestep_embedding(timesteps, embedding_dim):
     return emb
 
-def nonlinearity(x):
-    # swish
-    return x*torch.sigmoid(x)
-
 def Normalize(in_channels, num_groups=32):
     return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
@@ -122,14 +118,14 @@ class ResnetBlock(nn.Module):
     def forward(self, x, temb):
         h = self.norm1(x)
-        h = nonlinearity(h)
+        h = silu(h)
         h = self.conv1(h)
 
         if temb is not None:
-            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+            h = h + self.temb_proj(silu(temb))[:,:,None,None]
 
         h = self.norm2(h)
-        h = nonlinearity(h)
+        h = silu(h)
         h = self.dropout(h)
         h = self.conv2(h)
@@ -368,7 +364,7 @@ class Model(nn.Module):
             assert t is not None
             temb = get_timestep_embedding(t, self.ch)
             temb = self.temb.dense[0](temb)
-            temb = nonlinearity(temb)
+            temb = silu(temb)
             temb = self.temb.dense[1](temb)
         else:
             temb = None
@@ -402,7 +398,7 @@ class Model(nn.Module):
         # end
         h = self.norm_out(h)
-        h = nonlinearity(h)
+        h = silu(h)
         h = self.conv_out(h)
         return h
@@ -499,7 +495,7 @@ class Encoder(nn.Module):
         # end
         h = self.norm_out(h)
-        h = nonlinearity(h)
+        h = silu(h)
         h = self.conv_out(h)
         return h
@@ -611,7 +607,7 @@ class Decoder(nn.Module):
             return h
 
         h = self.norm_out(h)
-        h = nonlinearity(h)
+        h = silu(h)
         h = self.conv_out(h)
         if self.tanh_out:
             h = torch.tanh(h)
@@ -649,7 +645,7 @@ class SimpleDecoder(nn.Module):
                 x = layer(x)
 
         h = self.norm_out(x)
-        h = nonlinearity(h)
+        h = silu(h)
         x = self.conv_out(h)
         return x
@@ -697,7 +693,7 @@ class UpsampleDecoder(nn.Module):
                 if i_level != self.num_resolutions - 1:
                     h = self.upsample_blocks[k](h)
         h = self.norm_out(h)
-        h = nonlinearity(h)
+        h = silu(h)
         h = self.conv_out(h)
         return h
@@ -873,7 +869,7 @@ class FirstStagePostProcessor(nn.Module):
         z_fs = self.encode_with_pretrained(x)
         z = self.proj_norm(z_fs)
         z = self.proj(z)
-        z = nonlinearity(z)
+        z = silu(z)
 
         for submodel, downmodel in zip(self.model,self.downsampler):
             z = submodel(z,temb=None)

ldm/modules/diffusionmodules/util.py

@@ -252,12 +252,6 @@ def normalization(channels):
     return GroupNorm32(32, channels)
 
-# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
-class SiLU(nn.Module):
-    def forward(self, x):
-        return x * torch.sigmoid(x)
-
 class GroupNorm32(nn.GroupNorm):
     def forward(self, x):
         return super().forward(x.float()).type(x.dtype)
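
Both removals rely on the same identity: swish/SiLU is x * sigmoid(x), which ships as torch.nn.functional.silu from PyTorch 1.7 onward (the deleted class was a compatibility shim for 1.5). A two-line check, assuming a recent PyTorch install:

```python
# Quick equivalence check for the hand-rolled swish -> F.silu swap (illustrative).
import torch
from torch.nn.functional import silu

x = torch.randn(4, 8)
assert torch.allclose(silu(x), x * torch.sigmoid(x), atol=1e-6)
```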

scripts/dream.py

@@ -8,7 +8,7 @@ import copy
 import warnings
 import time
 import ldm.dream.readline
-from ldm.dream.args import Args, format_metadata
+from ldm.dream.args import Args, metadata_dumps
 from ldm.dream.pngwriter import PngWriter
 from ldm.dream.server import DreamServer, ThreadingDreamServer
 from ldm.dream.image_util import make_grid
@@ -219,9 +219,13 @@ def main_loop(gen, opt, infile):
         prefix = file_writer.unique_prefix()
         results = []  # list of filename, prompt pairs
         grid_images = dict()  # seed -> Image, only used if `opt.grid`
+        prior_variations = opt.with_variations or []
+        first_seed = opt.seed
 
         def image_writer(image, seed, upscaled=False):
             path = None
+            nonlocal first_seed
+            nonlocal prior_variations
             if opt.grid:
                 grid_images[seed] = image
             else:
@@ -229,29 +233,21 @@ def main_loop(gen, opt, infile):
                     filename = f'{prefix}.{seed}.postprocessed.png'
                 else:
                     filename = f'{prefix}.{seed}.png'
-                # the handling of variations is probably broken
-                # Also, given the ability to add stuff to the dream_prompt_str, it isn't
-                # necessary to make a copy of the opt option just to change its attributes
                 if opt.variation_amount > 0:
-                    iter_opt = copy.copy(opt)
+                    first_seed = first_seed or seed
                     this_variation = [[seed, opt.variation_amount]]
-                    if opt.with_variations is None:
-                        iter_opt.with_variations = this_variation
-                    else:
-                        iter_opt.with_variations = opt.with_variations + this_variation
-                    iter_opt.variation_amount = 0
-                    formatted_dream_prompt = iter_opt.dream_prompt_str(seed=seed)
-                elif opt.with_variations is not None:
-                    formatted_dream_prompt = opt.dream_prompt_str(seed=seed)
+                    opt.with_variations = prior_variations + this_variation
+                    formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed)
+                elif len(prior_variations) > 0:
+                    formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed)
                 else:
                     formatted_dream_prompt = opt.dream_prompt_str(seed=seed)
                 path = file_writer.save_image_and_prompt_to_png(
                     image = image,
                     dream_prompt = formatted_dream_prompt,
-                    metadata = format_metadata(
+                    metadata = metadata_dumps(
                         opt,
                         seeds = [seed],
-                        weights = gen.weights,
                         model_hash = gen.model_hash,
                     ),
                     name = filename,
@@ -275,7 +271,7 @@ def main_loop(gen, opt, infile):
             filename = f'{prefix}.{first_seed}.png'
             formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed,grid=True,iterations=len(grid_images))
             formatted_dream_prompt += f' # {grid_seeds}'
-            metadata = format_metadata(
+            metadata = metadata.dumps(
                 opt,
                 seeds = grid_seeds,
                 weights = gen.weights,
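
The reworked image_writer threads two pieces of state through the closure: first_seed, the seed that reproduces the whole run, and prior_variations, the -V chain the user started with. Each newly generated variation is appended to that chain so the recorded dream prompt stays reproducible. A small sketch of how the chain accumulates, with sample seeds and weights:

```python
# Illustrative only: how the new loop builds up a reproducible -V chain.
prior_variations = [[1111, 0.3]]      # user passed -V1111:0.3 on the command line
this_variation = [[2222, 0.2]]        # seed/amount of the image just generated
with_variations = prior_variations + this_variation

# dream_prompt_str would render this as -V 1111:0.3,2222:0.2 (plus -S first_seed)
print(','.join(f'{seed}:{weight}' for seed, weight in with_variations))
```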

static/legacy_web/favicon.ico (new binary file, 1.1 KiB; image content not shown)

static/legacy_web/index.css (new file)

@@ -0,0 +1,152 @@
* {
font-family: 'Arial';
font-size: 100%;
}
body {
font-size: 1em;
}
textarea {
font-size: 0.95em;
}
header, form, #progress-section {
margin-left: auto;
margin-right: auto;
max-width: 1024px;
text-align: center;
}
fieldset {
border: none;
line-height: 2.2em;
}
select, input {
margin-right: 10px;
padding: 2px;
}
input[type=submit] {
background-color: #666;
color: white;
}
input[type=checkbox] {
margin-right: 0px;
width: 20px;
height: 20px;
vertical-align: middle;
}
input#seed {
margin-right: 0px;
}
div {
padding: 10px 10px 10px 10px;
}
header {
margin-bottom: 16px;
}
header h1 {
margin-bottom: 0;
font-size: 2em;
}
#search-box {
display: flex;
}
#scaling-inprocess-message {
font-weight: bold;
font-style: italic;
display: none;
}
#prompt {
flex-grow: 1;
padding: 5px 10px 5px 10px;
border: 1px solid #999;
outline: none;
}
#submit {
padding: 5px 10px 5px 10px;
border: 1px solid #999;
}
#reset-all, #remove-image {
margin-top: 12px;
font-size: 0.8em;
background-color: pink;
border: 1px solid #999;
border-radius: 4px;
}
#results {
text-align: center;
margin: auto;
padding-top: 10px;
}
#results figure {
display: inline-block;
margin: 10px;
}
#results figcaption {
font-size: 0.8em;
padding: 3px;
color: #888;
cursor: pointer;
}
#results img {
border-radius: 5px;
object-fit: cover;
}
#fieldset-config {
line-height:2em;
background-color: #F0F0F0;
}
input[type="number"] {
width: 60px;
}
#seed {
width: 150px;
}
button#reset-seed {
font-size: 1.7em;
background: #efefef;
border: 1px solid #999;
border-radius: 4px;
line-height: 0.8;
margin: 0 10px 0 0;
padding: 0 5px 3px;
vertical-align: middle;
}
label {
white-space: nowrap;
}
#progress-section {
display: none;
}
#progress-image {
width: 30vh;
height: 30vh;
}
#cancel-button {
cursor: pointer;
color: red;
}
#basic-parameters {
background-color: #EEEEEE;
}
#txt2img {
background-color: #DCDCDC;
}
#variations {
background-color: #EEEEEE;
}
#img2img {
background-color: #DCDCDC;
}
#gfpgan {
background-color: #EEEEEE;
}
#progress-section {
background-color: #F5F5F5;
}
.section-header {
text-align: left;
font-weight: bold;
padding: 0 0 0 0;
}
#no-results-message:not(:only-child) {
display: none;
}

static/legacy_web/index.html (new file)

@@ -0,0 +1,129 @@
<html lang="en">
<head>
<title>Stable Diffusion Dream Server</title>
<meta charset="utf-8">
<link rel="icon" type="image/x-icon" href="static/legacy_web/favicon.ico" />
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="stylesheet" href="static/legacy_web/index.css">
<script src="config.js"></script>
<script src="static/legacy_web/index.js"></script>
</head>
<body>
<header>
<h1>Stable Diffusion Dream Server</h1>
<div id="about">
For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub site</a>
</div>
</header>
<main>
<form id="generate-form" method="post" action="#">
<fieldset id="txt2img">
<div id="search-box">
<textarea rows="3" id="prompt" name="prompt"></textarea>
<input type="submit" id="submit" value="Generate">
</div>
</fieldset>
<fieldset id="fieldset-config">
<div class="section-header">Basic options</div>
<label for="iterations">Images to generate:</label>
<input value="1" type="number" id="iterations" name="iterations" size="4">
<label for="steps">Steps:</label>
<input value="50" type="number" id="steps" name="steps">
<label for="cfg_scale">Cfg Scale:</label>
<input value="7.5" type="number" id="cfg_scale" name="cfg_scale" step="any">
<label for="sampler_name">Sampler:</label>
<select id="sampler_name" name="sampler_name" value="k_lms">
<option value="ddim">DDIM</option>
<option value="plms">PLMS</option>
<option value="k_lms" selected>KLMS</option>
<option value="k_dpm_2">KDPM_2</option>
<option value="k_dpm_2_a">KDPM_2A</option>
<option value="k_euler">KEULER</option>
<option value="k_euler_a">KEULER_A</option>
<option value="k_heun">KHEUN</option>
</select>
<input type="checkbox" name="seamless" id="seamless">
<label for="seamless">Seamless circular tiling</label>
<br>
<label title="Set to multiple of 64" for="width">Width:</label>
<select id="width" name="width" value="512">
<option value="64">64</option> <option value="128">128</option>
<option value="192">192</option> <option value="256">256</option>
<option value="320">320</option> <option value="384">384</option>
<option value="448">448</option> <option value="512" selected>512</option>
<option value="576">576</option> <option value="640">640</option>
<option value="704">704</option> <option value="768">768</option>
<option value="832">832</option> <option value="896">896</option>
<option value="960">960</option> <option value="1024">1024</option>
</select>
<label title="Set to multiple of 64" for="height">Height:</label>
<select id="height" name="height" value="512">
<option value="64">64</option> <option value="128">128</option>
<option value="192">192</option> <option value="256">256</option>
<option value="320">320</option> <option value="384">384</option>
<option value="448">448</option> <option value="512" selected>512</option>
<option value="576">576</option> <option value="640">640</option>
<option value="704">704</option> <option value="768">768</option>
<option value="832">832</option> <option value="896">896</option>
<option value="960">960</option> <option value="1024">1024</option>
</select>
<label title="Set to -1 for random seed" for="seed">Seed:</label>
<input value="-1" type="number" id="seed" name="seed">
<button type="button" id="reset-seed">&olarr;</button>
<input type="checkbox" name="progress_images" id="progress_images">
<label for="progress_images">Display in-progress images (slower)</label>
<button type="button" id="reset-all">Reset to Defaults</button>
<span id="variations">
<label title="If > 0, generates variations on the initial seed instead of random seeds per iteration. Must be between 0 and 1. Higher values will be more different." for="variation_amount">Variation amount (0 to disable):</label>
<input value="0" type="number" id="variation_amount" name="variation_amount" step="0.01" min="0" max="1">
<label title="list of variations to apply, in the format `seed:weight,seed:weight,..." for="with_variations">With variations (seed:weight,seed:weight,...):</label>
<input value="" type="text" id="with_variations" name="with_variations">
</span>
</fieldset>
<fieldset id="img2img">
<div class="section-header">Image-to-image options</div>
<label title="Upload an image to use img2img" for="initimg">Initial image:</label>
<input type="file" id="initimg" name="initimg" accept=".jpg, .jpeg, .png">
<button type="button" id="remove-image">Remove Image</button>
<br>
<label for="strength">Img2Img Strength:</label>
<input value="0.75" type="number" id="strength" name="strength" step="0.01" min="0" max="1">
<input type="checkbox" id="fit" name="fit" checked>
<label title="Rescale image to fit within requested width and height" for="fit">Fit to width/height</label>
</fieldset>
<fieldset id="gfpgan">
<div class="section-header">Post-processing options</div>
<label title="Strength of the gfpgan (face fixing) algorithm." for="gfpgan_strength">GPFGAN Strength (0 to disable):</label>
<input value="0.0" min="0" max="1" type="number" id="gfpgan_strength" name="gfpgan_strength" step="0.1">
<label title="Upscaling to perform using ESRGAN." for="upscale_level">Upscaling Level</label>
<select id="upscale_level" name="upscale_level" value="">
<option value="" selected>None</option>
<option value="2">2x</option>
<option value="4">4x</option>
</select>
<label title="Strength of the esrgan (upscaling) algorithm." for="upscale_strength">Upscale Strength:</label>
<input value="0.75" min="0" max="1" type="number" id="upscale_strength" name="upscale_strength" step="0.05">
</fieldset>
</form>
<br>
<section id="progress-section">
<div id="progress-container">
<progress id="progress-bar" value="0" max="1"></progress>
<span id="cancel-button" title="Cancel">&#10006;</span>
<br>
<img id="progress-image" src='data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg"/>'>
<div id="scaling-inprocess-message">
<i><span>Postprocessing...</span><span id="processing_cnt">1/3</span></i>
</div>
</span>
</section>
<div id="results">
<div id="no-results-message">
<i><p>No results...</p></i>
</div>
</div>
</main>
</body>
</html>

static/legacy_web/index.js (new file)

@@ -0,0 +1,213 @@
function toBase64(file) {
return new Promise((resolve, reject) => {
const r = new FileReader();
r.readAsDataURL(file);
r.onload = () => resolve(r.result);
r.onerror = (error) => reject(error);
});
}
function appendOutput(src, seed, config) {
let outputNode = document.createElement("figure");
let variations = config.with_variations;
if (config.variation_amount > 0) {
variations = (variations ? variations + ',' : '') + seed + ':' + config.variation_amount;
}
let baseseed = (config.with_variations || config.variation_amount > 0) ? config.seed : seed;
let altText = baseseed + ' | ' + (variations ? variations + ' | ' : '') + config.prompt;
// img needs width and height for lazy loading to work
const figureContents = `
<a href="${src}" target="_blank">
<img src="${src}"
alt="${altText}"
title="${altText}"
loading="lazy"
width="256"
height="256">
</a>
<figcaption>${seed}</figcaption>
`;
outputNode.innerHTML = figureContents;
let figcaption = outputNode.querySelector('figcaption');
// Reload image config
figcaption.addEventListener('click', () => {
let form = document.querySelector("#generate-form");
for (const [k, v] of new FormData(form)) {
if (k == 'initimg') { continue; }
form.querySelector(`*[name=${k}]`).value = config[k];
}
document.querySelector("#seed").value = baseseed;
document.querySelector("#with_variations").value = variations || '';
if (document.querySelector("#variation_amount").value <= 0) {
document.querySelector("#variation_amount").value = 0.2;
}
saveFields(document.querySelector("#generate-form"));
});
document.querySelector("#results").prepend(outputNode);
}
function saveFields(form) {
for (const [k, v] of new FormData(form)) {
if (typeof v !== 'object') { // Don't save 'file' type
localStorage.setItem(k, v);
}
}
}
function loadFields(form) {
for (const [k, v] of new FormData(form)) {
const item = localStorage.getItem(k);
if (item != null) {
form.querySelector(`*[name=${k}]`).value = item;
}
}
}
function clearFields(form) {
localStorage.clear();
let prompt = form.prompt.value;
form.reset();
form.prompt.value = prompt;
}
const BLANK_IMAGE_URL = 'data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg"/>';
async function generateSubmit(form) {
const prompt = document.querySelector("#prompt").value;
// Convert file data to base64
let formData = Object.fromEntries(new FormData(form));
formData.initimg_name = formData.initimg.name
formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null;
let strength = formData.strength;
let totalSteps = formData.initimg ? Math.floor(strength * formData.steps) : formData.steps;
let progressSectionEle = document.querySelector('#progress-section');
progressSectionEle.style.display = 'initial';
let progressEle = document.querySelector('#progress-bar');
progressEle.setAttribute('max', totalSteps);
let progressImageEle = document.querySelector('#progress-image');
progressImageEle.src = BLANK_IMAGE_URL;
progressImageEle.style.display = {}.hasOwnProperty.call(formData, 'progress_images') ? 'initial': 'none';
// Post as JSON, using Fetch streaming to get results
fetch(form.action, {
method: form.method,
body: JSON.stringify(formData),
}).then(async (response) => {
const reader = response.body.getReader();
let noOutputs = true;
while (true) {
let {value, done} = await reader.read();
value = new TextDecoder().decode(value);
if (done) {
progressSectionEle.style.display = 'none';
break;
}
for (let event of value.split('\n').filter(e => e !== '')) {
const data = JSON.parse(event);
if (data.event === 'result') {
noOutputs = false;
appendOutput(data.url, data.seed, data.config);
progressEle.setAttribute('value', 0);
progressEle.setAttribute('max', totalSteps);
} else if (data.event === 'upscaling-started') {
document.getElementById("processing_cnt").textContent=data.processed_file_cnt;
document.getElementById("scaling-inprocess-message").style.display = "block";
} else if (data.event === 'upscaling-done') {
document.getElementById("scaling-inprocess-message").style.display = "none";
} else if (data.event === 'step') {
progressEle.setAttribute('value', data.step);
if (data.url) {
progressImageEle.src = data.url;
}
} else if (data.event === 'canceled') {
// avoid alerting as if this were an error case
noOutputs = false;
}
}
}
// Re-enable form, remove no-results-message
form.querySelector('fieldset').removeAttribute('disabled');
document.querySelector("#prompt").value = prompt;
document.querySelector('progress').setAttribute('value', '0');
if (noOutputs) {
alert("Error occurred while generating.");
}
});
// Disable form while generating
form.querySelector('fieldset').setAttribute('disabled','');
document.querySelector("#prompt").value = `Generating: "${prompt}"`;
}
async function fetchRunLog() {
try {
let response = await fetch('/run_log.json')
const data = await response.json();
for(let item of data.run_log) {
appendOutput(item.url, item.seed, item);
}
} catch (e) {
console.error(e);
}
}
window.onload = async () => {
document.querySelector("#prompt").addEventListener("keydown", (e) => {
if (e.key === "Enter" && !e.shiftKey) {
const form = e.target.form;
generateSubmit(form);
}
});
document.querySelector("#generate-form").addEventListener('submit', (e) => {
e.preventDefault();
const form = e.target;
generateSubmit(form);
});
document.querySelector("#generate-form").addEventListener('change', (e) => {
saveFields(e.target.form);
});
document.querySelector("#reset-seed").addEventListener('click', (e) => {
document.querySelector("#seed").value = -1;
saveFields(e.target.form);
});
document.querySelector("#reset-all").addEventListener('click', (e) => {
clearFields(e.target.form);
});
document.querySelector("#remove-image").addEventListener('click', (e) => {
initimg.value=null;
});
loadFields(document.querySelector("#generate-form"));
document.querySelector('#cancel-button').addEventListener('click', () => {
fetch('/cancel').catch(e => {
console.error(e);
});
});
document.documentElement.addEventListener('keydown', (e) => {
if (e.key === "Escape")
fetch('/cancel').catch(err => {
console.error(err);
});
});
if (!config.gfpgan_model_exists) {
document.querySelector("#gfpgan").style.display = 'none';
}
await fetchRunLog()
};
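
For completeness, the wire format generateSubmit() consumes above is a stream of newline-delimited JSON events ('step', 'result', 'upscaling-started', 'upscaling-done', 'canceled'). A hedged Python client sketch, assuming the dream server's default local address and a request body shaped like the form fields above:

```python
# Illustrative client for the event stream above; the endpoint and port are assumptions.
import json
import requests  # third-party: pip install requests

payload = {'prompt': 'a sunset over the ocean', 'steps': 50, 'seed': -1}
with requests.post('http://localhost:9090/', json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        event = json.loads(line)        # one JSON event per line
        if event['event'] == 'step':
            print('step', event['step'])
        elif event['event'] == 'result':
            print('image at', event['url'], 'seed', event['seed'])
```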