React web UI with flask-socketio API (#429)

* Implements rudimentary api
* Fixes blocking in API
* Adds UI to monorepo > src/frontend/
* Updates frontend/README
* Reverts conda env name to `ldm`
* Fixes environment yamls
* CORS config for testing
* Fixes LogViewer position
* API WID
* Adds actions to image viewer
* Increases vite chunkSizeWarningLimit to 1500
* Implements init image
* Implements state persistence in localStorage
* Improve progress data handling
* Final build
* Fixes mimetypes error on windows
* Adds error logging
* Fixes bugged img2img strength component
* Adds sourcemaps to dev build
* Fixes missing key
* Changes connection status indicator to text
* Adds ability to serve other hosts than localhost
* Adding Flask API server
* Removes source maps from config
* Fixes prop transfer
* Add missing packages and add CORS support
* Adding API doc
* Remove defaults from openapi doc
* Adds basic error handling for server config query
* Mostly working socket.io implementation.
* Fixes bug preventing mask upload
* Fixes bug with sampler name not written to metadata
* UI Overhaul, numerous fixes

Co-authored-by: Kyle Schouviller <kyle0654@hotmail.com>
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
This commit is contained in:
psychedelicious 2022-09-17 03:18:15 +10:00 committed by GitHub
parent 403d02d94f
commit d1a2c4cd8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
89 changed files with 9647 additions and 121 deletions

11
.gitignore vendored
View File

@ -77,9 +77,6 @@ db.sqlite3-journal
instance/ instance/
.webassets-cache .webassets-cache
# WebUI temp files:
img2img-tmp.png
# Scrapy stuff: # Scrapy stuff:
.scrapy .scrapy
@ -186,3 +183,11 @@ testtube
checkpoints checkpoints
# If it's a Mac # If it's a Mac
.DS_Store .DS_Store
# Let the frontend manage its own gitignore
!frontend/*
# Scratch folder
.scratch/
.vscode/
gfpgan/

View File

@ -0,0 +1,206 @@
from modules.parse_seed_weights import parse_seed_weights
import argparse
SAMPLER_CHOICES = [
'ddim',
'k_dpm_2_a',
'k_dpm_2',
'k_euler_a',
'k_euler',
'k_heun',
'k_lms',
'plms',
]
def parameters_to_command(params):
"""
Converts dict of parameters into a `dream.py` REPL command.
"""
switches = list()
if 'prompt' in params:
switches.append(f'"{params["prompt"]}"')
if 'steps' in params:
switches.append(f'-s {params["steps"]}')
if 'seed' in params:
switches.append(f'-S {params["seed"]}')
if 'width' in params:
switches.append(f'-W {params["width"]}')
if 'height' in params:
switches.append(f'-H {params["height"]}')
if 'cfg_scale' in params:
switches.append(f'-C {params["cfg_scale"]}')
if 'sampler_name' in params:
switches.append(f'-A {params["sampler_name"]}')
if 'seamless' in params and params["seamless"] == True:
switches.append(f'--seamless')
if 'init_img' in params and len(params['init_img']) > 0:
switches.append(f'-I {params["init_img"]}')
if 'init_mask' in params and len(params['init_mask']) > 0:
switches.append(f'-M {params["init_mask"]}')
if 'strength' in params and 'init_img' in params:
switches.append(f'-f {params["strength"]}')
if 'fit' in params and params["fit"] == True:
switches.append(f'--fit')
if 'gfpgan_strength' in params and params["gfpgan_strength"]:
switches.append(f'-G {params["gfpgan_strength"]}')
if 'upscale' in params and params["upscale"]:
switches.append(f'-U {params["upscale"][0]} {params["upscale"][1]}')
if 'variation_amount' in params and params['variation_amount'] > 0:
switches.append(f'-v {params["variation_amount"]}')
if 'with_variations' in params:
seed_weight_pairs = ','.join(f'{seed}:{weight}' for seed, weight in params["with_variations"])
switches.append(f'-V {seed_weight_pairs}')
return ' '.join(switches)
def create_cmd_parser():
"""
This is simply a copy of the parser from `dream.py` with a change to give
prompt a default value. This is a temporary hack pending merge of #587 which
provides a better way to do this.
"""
parser = argparse.ArgumentParser(
description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12',
exit_on_error=True,
)
parser.add_argument('prompt', nargs='?', default='')
parser.add_argument('-s', '--steps', type=int, help='Number of steps')
parser.add_argument(
'-S',
'--seed',
type=int,
help='Image seed; a +ve integer, or use -1 for the previous seed, -2 for the one before that, etc',
)
parser.add_argument(
'-n',
'--iterations',
type=int,
default=1,
help='Number of samplings to perform (slower, but will provide seeds for individual images)',
)
parser.add_argument(
'-W', '--width', type=int, help='Image width, multiple of 64'
)
parser.add_argument(
'-H', '--height', type=int, help='Image height, multiple of 64'
)
parser.add_argument(
'-C',
'--cfg_scale',
default=7.5,
type=float,
help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
)
parser.add_argument(
'-g', '--grid', action='store_true', help='generate a grid'
)
parser.add_argument(
'--outdir',
'-o',
type=str,
default=None,
help='Directory to save generated images and a log of prompts and seeds',
)
parser.add_argument(
'--seamless',
action='store_true',
help='Change the model to seamless tiling (circular) mode',
)
parser.add_argument(
'-i',
'--individual',
action='store_true',
help='Generate individual files (default)',
)
parser.add_argument(
'-I',
'--init_img',
type=str,
help='Path to input image for img2img mode (supersedes width and height)',
)
parser.add_argument(
'-M',
'--init_mask',
type=str,
help='Path to input mask for inpainting mode (supersedes width and height)',
)
parser.add_argument(
'-T',
'-fit',
'--fit',
action='store_true',
help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)',
)
parser.add_argument(
'-f',
'--strength',
default=0.75,
type=float,
help='Strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
)
parser.add_argument(
'-G',
'--gfpgan_strength',
default=0,
type=float,
help='The strength at which to apply the GFPGAN model to the result, in order to improve faces.',
)
parser.add_argument(
'-U',
'--upscale',
nargs='+',
default=None,
type=float,
help='Scale factor (2, 4) for upscaling followed by upscaling strength (0-1.0). If strength not specified, defaults to 0.75'
)
parser.add_argument(
'-save_orig',
'--save_original',
action='store_true',
help='Save original. Use it when upscaling to save both versions.',
)
# variants is going to be superseded by a generalized "prompt-morph" function
# parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants")
parser.add_argument(
'-x',
'--skip_normalize',
action='store_true',
help='Skip subprompt weight normalization',
)
parser.add_argument(
'-A',
'-m',
'--sampler',
dest='sampler_name',
default=None,
type=str,
choices=SAMPLER_CHOICES,
metavar='SAMPLER_NAME',
help=f'Switch to a different sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
)
parser.add_argument(
'-t',
'--log_tokenization',
action='store_true',
help='shows how the prompt is split into tokens'
)
parser.add_argument(
'-v',
'--variation_amount',
default=0.0,
type=float,
help='If > 0, generates variations on the initial seed instead of random seeds per iteration. Must be between 0 and 1. Higher values will be more different.'
)
parser.add_argument(
'-V',
'--with_variations',
default=None,
type=str,
help='list of variations to apply, in the format `seed:weight,seed:weight,...'
)
return parser

View File

@ -0,0 +1,47 @@
def parse_seed_weights(seed_weights):
"""
Accepts seed weights as string in "12345:0.1,23456:0.2,3456:0.3" format
Validates them
If valid: returns as [[12345, 0.1], [23456, 0.2], [3456, 0.3]]
If invalid: returns False
"""
# Must be a string
if not isinstance(seed_weights, str):
return False
# String must not be empty
if len(seed_weights) == 0:
return False
pairs = []
for pair in seed_weights.split(","):
split_values = pair.split(":")
# Seed and weight are required
if len(split_values) != 2:
return False
if len(split_values[0]) == 0 or len(split_values[1]) == 1:
return False
# Try casting the seed to int and weight to float
try:
seed = int(split_values[0])
weight = float(split_values[1])
except ValueError:
return False
# Seed must be 0 or above
if not seed >= 0:
return False
# Weight must be between 0 and 1
if not (weight >= 0 and weight <= 1):
return False
# This pair is valid
pairs.append([seed, weight])
# All pairs are valid
return pairs

399
backend/server.py Normal file
View File

@ -0,0 +1,399 @@
import mimetypes
import transformers
import json
import os
import traceback
import eventlet
import glob
import shlex
import argparse
from flask_socketio import SocketIO
from flask import Flask, send_from_directory, url_for, jsonify
from pathlib import Path
from PIL import Image
from pytorch_lightning import logging
from threading import Event
from uuid import uuid4
from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
from ldm.gfpgan.gfpgan_tools import run_gfpgan
from ldm.generate import Generate
from ldm.dream.pngwriter import PngWriter, PromptFormatter
from modules.parameters import parameters_to_command, create_cmd_parser
"""
USER CONFIG
"""
output_dir = "outputs/" # Base output directory for images
host = 'localhost' # Web & socket.io host
port = 9090 # Web & socket.io port
verbose = False # enables copious socket.io logging
additional_allowed_origins = ['http://localhost:5173'] # additional CORS allowed origins
"""
END USER CONFIG
"""
"""
SERVER SETUP
"""
# fix missing mimetypes on windows due to registry wonkiness
mimetypes.add_type('application/javascript', '.js')
mimetypes.add_type('text/css', '.css')
app = Flask(__name__, static_url_path='', static_folder='../frontend/dist/')
app.config['OUTPUTS_FOLDER'] = "../outputs"
@app.route('/outputs/<path:filename>')
def outputs(filename):
return send_from_directory(
app.config['OUTPUTS_FOLDER'],
filename
)
@app.route("/", defaults={'path': ''})
def serve(path):
return send_from_directory(app.static_folder, 'index.html')
logger = True if verbose else False
engineio_logger = True if verbose else False
# default 1,000,000, needs to be higher for socketio to accept larger images
max_http_buffer_size = 10000000
cors_allowed_origins = [f"http://{host}:{port}"] + additional_allowed_origins
socketio = SocketIO(
app,
logger=logger,
engineio_logger=engineio_logger,
max_http_buffer_size=max_http_buffer_size,
cors_allowed_origins=cors_allowed_origins,
)
"""
END SERVER SETUP
"""
"""
APP SETUP
"""
class CanceledException(Exception):
pass
canceled = Event()
# reduce logging outputs to error
transformers.logging.set_verbosity_error()
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
# Initialize and load model
model = Generate()
model.load_model()
# location for "finished" images
result_path = os.path.join(output_dir, 'img-samples/')
# temporary path for intermediates
intermediate_path = os.path.join(result_path, 'intermediates/')
# path for user-uploaded init images and masks
init_path = os.path.join(result_path, 'init-images/')
mask_path = os.path.join(result_path, 'mask-images/')
# txt log
log_path = os.path.join(result_path, 'dream_log.txt')
# make all output paths
[os.makedirs(path, exist_ok=True)
for path in [result_path, intermediate_path, init_path, mask_path]]
"""
END APP SETUP
"""
"""
SOCKET.IO LISTENERS
"""
@socketio.on('requestAllImages')
def handle_request_all_images():
print(f'>> All images requested')
parser = create_cmd_parser()
paths = list(filter(os.path.isfile, glob.glob(result_path + "*.png")))
paths.sort(key=lambda x: os.path.getmtime(x))
image_array = []
for path in paths:
image = Image.open(path)
metadata = {}
if 'Dream' in image.info:
try:
metadata = vars(parser.parse_args(shlex.split(image.info['Dream'])))
except SystemExit:
# TODO: Unable to parse metadata, ignore it for now,
# this can happen when metadata is missing a prompt
pass
image_array.append({'path': path, 'metadata': metadata})
return make_response("OK", data=image_array)
@socketio.on('generateImage')
def handle_generate_image_event(generation_parameters, esrgan_parameters, gfpgan_parameters):
print(f'>> Image generation requested: {generation_parameters}\nESRGAN parameters: {esrgan_parameters}\nGFPGAN parameters: {gfpgan_parameters}')
generate_images(
generation_parameters,
esrgan_parameters,
gfpgan_parameters
)
return make_response("OK")
@socketio.on('runESRGAN')
def handle_run_esrgan_event(original_image, esrgan_parameters):
print(f'>> ESRGAN upscale requested for "{original_image["url"]}": {esrgan_parameters}')
image = Image.open(original_image["url"])
seed = original_image['metadata']['seed'] if 'seed' in original_image['metadata'] else 'unknown_seed'
image = real_esrgan_upscale(
image=image,
upsampler_scale=esrgan_parameters['upscale'][0],
strength=esrgan_parameters['upscale'][1],
seed=seed
)
esrgan_parameters['seed'] = seed
path = save_image(image, esrgan_parameters, result_path, postprocessing='esrgan')
command = parameters_to_command(esrgan_parameters)
write_log_message(f'[Upscaled] "{original_image["url"]}" > "{path}": {command}')
socketio.emit(
'result', {'url': os.path.relpath(path), 'type': 'esrgan', 'uuid': original_image['uuid'],'metadata': esrgan_parameters})
@socketio.on('runGFPGAN')
def handle_run_gfpgan_event(original_image, gfpgan_parameters):
print(f'>> GFPGAN face fix requested for "{original_image["url"]}": {gfpgan_parameters}')
image = Image.open(original_image["url"])
seed = original_image['metadata']['seed'] if 'seed' in original_image['metadata'] else 'unknown_seed'
image = run_gfpgan(
image=image,
strength=gfpgan_parameters['gfpgan_strength'],
seed=seed,
upsampler_scale=1
)
gfpgan_parameters['seed'] = seed
path = save_image(image, gfpgan_parameters, result_path, postprocessing='gfpgan')
command = parameters_to_command(gfpgan_parameters)
write_log_message(f'[Fixed faces] "{original_image["url"]}" > "{path}": {command}')
socketio.emit(
'result', {'url': os.path.relpath(path), 'type': 'gfpgan', 'uuid': original_image['uuid'],'metadata': gfpgan_parameters})
@socketio.on('cancel')
def handle_cancel():
print(f'>> Cancel processing requested')
canceled.set()
return make_response("OK")
# TODO: I think this needs a safety mechanism.
@socketio.on('deleteImage')
def handle_delete_image(path):
print(f'>> Delete requested "{path}"')
Path(path).unlink()
return make_response("OK")
# TODO: I think this needs a safety mechanism.
@socketio.on('uploadInitialImage')
def handle_upload_initial_image(bytes, name):
print(f'>> Init image upload requested "{name}"')
uuid = uuid4().hex
split = os.path.splitext(name)
name = f'{split[0]}.{uuid}{split[1]}'
file_path = os.path.join(init_path, name)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
newFile = open(file_path, "wb")
newFile.write(bytes)
return make_response("OK", data=file_path)
# TODO: I think this needs a safety mechanism.
@socketio.on('uploadMaskImage')
def handle_upload_mask_image(bytes, name):
print(f'>> Mask image upload requested "{name}"')
uuid = uuid4().hex
split = os.path.splitext(name)
name = f'{split[0]}.{uuid}{split[1]}'
file_path = os.path.join(mask_path, name)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
newFile = open(file_path, "wb")
newFile.write(bytes)
return make_response("OK", data=file_path)
"""
END SOCKET.IO LISTENERS
"""
"""
ADDITIONAL FUNCTIONS
"""
def write_log_message(message, log_path=log_path):
"""Logs the filename and parameters used to generate or process that image to log file"""
message = f'{message}\n'
with open(log_path, 'a', encoding='utf-8') as file:
file.writelines(message)
def make_response(status, message=None, data=None):
response = {'status': status}
if message is not None:
response['message'] = message
if data is not None:
response['data'] = data
return response
def save_image(image, parameters, output_dir, step_index=None, postprocessing=False):
seed = parameters['seed'] if 'seed' in parameters else 'unknown_seed'
pngwriter = PngWriter(output_dir)
prefix = pngwriter.unique_prefix()
filename = f'{prefix}.{seed}'
if step_index:
filename += f'.{step_index}'
if postprocessing:
filename += f'.postprocessed'
filename += '.png'
command = parameters_to_command(parameters)
path = pngwriter.save_image_and_prompt_to_png(image, command, filename)
return path
def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters):
canceled.clear()
step_index = 1
def image_progress(sample, step):
if canceled.is_set():
raise CanceledException
nonlocal step_index
nonlocal generation_parameters
if generation_parameters["progress_images"] and step % 5 == 0 and step < generation_parameters['steps'] - 1:
image = model.sample_to_image(sample)
path = save_image(image, generation_parameters, intermediate_path, step_index)
step_index += 1
socketio.emit('intermediateResult', {
'url': os.path.relpath(path), 'metadata': generation_parameters})
socketio.emit('progress', {'step': step + 1})
eventlet.sleep(0)
def image_done(image, seed):
nonlocal generation_parameters
nonlocal esrgan_parameters
nonlocal gfpgan_parameters
all_parameters = generation_parameters
postprocessing = False
if esrgan_parameters:
image = real_esrgan_upscale(
image=image,
strength=esrgan_parameters['strength'],
upsampler_scale=esrgan_parameters['level'],
seed=seed
)
postprocessing = True
all_parameters["upscale"] = [esrgan_parameters['level'], esrgan_parameters['strength']]
if gfpgan_parameters:
image = run_gfpgan(
image=image,
strength=gfpgan_parameters['strength'],
seed=seed,
upsampler_scale=1,
)
postprocessing = True
all_parameters["gfpgan_strength"] = gfpgan_parameters['strength']
all_parameters['seed'] = seed
path = save_image(image, all_parameters, result_path, postprocessing=postprocessing)
command = parameters_to_command(all_parameters)
print(f'Image generated: "{path}"')
write_log_message(f'[Generated] "{path}": {command}')
socketio.emit(
'result', {'url': os.path.relpath(path), 'type': 'generation', 'metadata': all_parameters})
eventlet.sleep(0)
try:
model.prompt2image(
**generation_parameters,
step_callback=image_progress,
image_callback=image_done
)
except KeyboardInterrupt:
raise
except CanceledException:
pass
except Exception as e:
socketio.emit('error', (str(e)))
print("\n")
traceback.print_exc()
print("\n")
"""
END ADDITIONAL FUNCTIONS
"""
if __name__ == '__main__':
print(f'Starting server at http://{host}:{port}')
socketio.run(app, host=host, port=port)

19
docs/index.html Normal file
View File

@ -0,0 +1,19 @@
<!-- HTML for static distribution bundle build -->
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Swagger UI</title>
<link rel="stylesheet" type="text/css" href="swagger-ui/swagger-ui.css" />
<link rel="stylesheet" type="text/css" href="swagger-ui/index.css" />
<link rel="icon" type="image/png" href="swagger-ui/favicon-32x32.png" sizes="32x32" />
<link rel="icon" type="image/png" href="swagger-ui/favicon-16x16.png" sizes="16x16" />
</head>
<body>
<div id="swagger-ui"></div>
<script src="swagger-ui/swagger-ui-bundle.js" charset="UTF-8"> </script>
<script src="swagger-ui/swagger-ui-standalone-preset.js" charset="UTF-8"> </script>
<script src="swagger-ui/swagger-initializer.js" charset="UTF-8"> </script>
</body>
</html>

73
docs/openapi3_0.yaml Normal file
View File

@ -0,0 +1,73 @@
openapi: 3.0.3
info:
title: Stable Diffusion
description: |-
TODO: Description Here
Some useful links:
- [Stable Diffusion Dream Server](https://github.com/lstein/stable-diffusion)
license:
name: MIT License
url: https://github.com/lstein/stable-diffusion/blob/main/LICENSE
version: 1.0.0
servers:
- url: http://localhost:9090/api
tags:
- name: images
description: Retrieve and manage generated images
paths:
/images/{imageId}:
get:
tags:
- images
summary: Get image by ID
description: Returns a single image
operationId: getImageById
parameters:
- name: imageId
in: path
description: ID of image to return
required: true
schema:
type: string
responses:
'200':
description: successful operation
content:
image/png:
schema:
type: string
format: binary
'404':
description: Image not found
/intermediates/{intermediateId}/{step}:
get:
tags:
- images
summary: Get intermediate image by ID
description: Returns a single intermediate image
operationId: getIntermediateById
parameters:
- name: intermediateId
in: path
description: ID of intermediate to return
required: true
schema:
type: string
- name: step
in: path
description: The generation step of the intermediate
required: true
schema:
type: string
responses:
'200':
description: successful operation
content:
image/png:
schema:
type: string
format: binary
'404':
description: Intermediate not found

Binary file not shown.

After

Width:  |  Height:  |  Size: 665 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 628 B

16
docs/swagger-ui/index.css Normal file
View File

@ -0,0 +1,16 @@
html {
box-sizing: border-box;
overflow: -moz-scrollbars-vertical;
overflow-y: scroll;
}
*,
*:before,
*:after {
box-sizing: inherit;
}
body {
margin: 0;
background: #fafafa;
}

View File

@ -0,0 +1,79 @@
<!doctype html>
<html lang="en-US">
<head>
<title>Swagger UI: OAuth2 Redirect</title>
</head>
<body>
<script>
'use strict';
function run () {
var oauth2 = window.opener.swaggerUIRedirectOauth2;
var sentState = oauth2.state;
var redirectUrl = oauth2.redirectUrl;
var isValid, qp, arr;
if (/code|token|error/.test(window.location.hash)) {
qp = window.location.hash.substring(1).replace('?', '&');
} else {
qp = location.search.substring(1);
}
arr = qp.split("&");
arr.forEach(function (v,i,_arr) { _arr[i] = '"' + v.replace('=', '":"') + '"';});
qp = qp ? JSON.parse('{' + arr.join() + '}',
function (key, value) {
return key === "" ? value : decodeURIComponent(value);
}
) : {};
isValid = qp.state === sentState;
if ((
oauth2.auth.schema.get("flow") === "accessCode" ||
oauth2.auth.schema.get("flow") === "authorizationCode" ||
oauth2.auth.schema.get("flow") === "authorization_code"
) && !oauth2.auth.code) {
if (!isValid) {
oauth2.errCb({
authId: oauth2.auth.name,
source: "auth",
level: "warning",
message: "Authorization may be unsafe, passed state was changed in server. The passed state wasn't returned from auth server."
});
}
if (qp.code) {
delete oauth2.state;
oauth2.auth.code = qp.code;
oauth2.callback({auth: oauth2.auth, redirectUrl: redirectUrl});
} else {
let oauthErrorMsg;
if (qp.error) {
oauthErrorMsg = "["+qp.error+"]: " +
(qp.error_description ? qp.error_description+ ". " : "no accessCode received from the server. ") +
(qp.error_uri ? "More info: "+qp.error_uri : "");
}
oauth2.errCb({
authId: oauth2.auth.name,
source: "auth",
level: "error",
message: oauthErrorMsg || "[Authorization failed]: no accessCode received from the server."
});
}
} else {
oauth2.callback({auth: oauth2.auth, token: qp, isValid: isValid, redirectUrl: redirectUrl});
}
window.close();
}
if (document.readyState !== 'loading') {
run();
} else {
document.addEventListener('DOMContentLoaded', function () {
run();
});
}
</script>
</body>
</html>

View File

@ -0,0 +1,20 @@
window.onload = function() {
//<editor-fold desc="Changeable Configuration Block">
// the following lines will be replaced by docker/configurator, when it runs in a docker-container
window.ui = SwaggerUIBundle({
url: "openapi3_0.yaml",
dom_id: '#swagger-ui',
deepLinking: true,
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIStandalonePreset
],
plugins: [
SwaggerUIBundle.plugins.DownloadUrl
],
layout: "StandaloneLayout"
});
//</editor-fold>
};

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -40,6 +40,11 @@ dependencies:
- tensorboard==2.9.0 - tensorboard==2.9.0
- torchmetrics==0.9.3 - torchmetrics==0.9.3
- pip: - pip:
- flask==2.1.3
- flask_socketio==5.3.0
- flask_cors==3.0.10
- dependency_injector==4.40.0
- eventlet
- opencv-python==4.6.0 - opencv-python==4.6.0
- protobuf==3.20.1 - protobuf==3.20.1
- realesrgan==0.2.5.0 - realesrgan==0.2.5.0

View File

@ -25,6 +25,11 @@ dependencies:
- torch-fidelity==0.3.0 - torch-fidelity==0.3.0
- transformers==4.19.2 - transformers==4.19.2
- torchmetrics==0.6.0 - torchmetrics==0.6.0
- flask==2.1.3
- flask_socketio==5.3.0
- flask_cors==3.0.10
- dependency_injector==4.40.0
- eventlet
- kornia==0.6.0 - kornia==0.6.0
- -e git+https://github.com/openai/CLIP.git@main#egg=clip - -e git+https://github.com/openai/CLIP.git@main#egg=clip
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers

6
frontend/.eslintrc.cjs Normal file
View File

@ -0,0 +1,6 @@
module.exports = {
extends: ['eslint:recommended', 'plugin:@typescript-eslint/recommended', 'plugin:react-hooks/recommended'],
parser: '@typescript-eslint/parser',
plugins: ['@typescript-eslint', 'eslint-plugin-react-hooks'],
root: true,
};

25
frontend/.gitignore vendored Normal file
View File

@ -0,0 +1,25 @@
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*
node_modules
# We want to distribute the repo
# dist
dist-ssr
*.local
# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?

85
frontend/README.md Normal file
View File

@ -0,0 +1,85 @@
# Stable Diffusion Web UI
Demo at https://peaceful-otter-7a427f.netlify.app/ (not connected to back end)
much of this readme is just notes for myself during dev work
numpy rand: 0 to 4294967295
## Test and Build
from `frontend/`:
- `yarn dev` runs `tsc-watch`, which runs `vite build` on successful `tsc` transpilation
from `.`:
- `python backend/server.py` serves both frontend and backend at http://localhost:9090
## API
`backend/server.py` serves the UI and provides a [socket.io](https://github.com/socketio/socket.io) API via [flask-socketio](https://github.com/miguelgrinberg/flask-socketio).
### Server Listeners
The server listens for these socket.io events:
`cancel`
- Cancels in-progress image generation
- Returns ack only
`generateImage`
- Accepts object of image parameters
- Generates an image
- Returns ack only (image generation function sends progress and result via separate events)
`deleteImage`
- Accepts file path to image
- Deletes image
- Returns ack only
`deleteAllImages` WIP
- Deletes all images in `outputs/`
- Returns ack only
`requestAllImages`
- Returns array of all images in `outputs/`
`requestCapabilities` WIP
- Returns capabilities of server (torch device, GFPGAN and ESRGAN availability, ???)
`sendImage` WIP
- Accepts a File and attributes
- Saves image
- Used to save init images which are not generated images
### Server Emitters
`progress`
- Emitted during each step in generation
- Sends a number from 0 to 1 representing percentage of steps completed
`result` WIP
- Emitted when an image generation has completed
- Sends a object:
```
{
url: relative_file_path,
metadata: image_metadata_object
}
```
## TODO
- Search repo for "TODO"
- My one gripe with Chakra: no way to disable all animations right now and drop the dependence on `framer-motion`. I would prefer to save the ~30kb on bundle and have zero animations. This is on the Chakra roadmap. See https://github.com/chakra-ui/chakra-ui/pull/6368 for last discussion on this. Need to check in on this issue periodically.

View File

@ -0,0 +1 @@
.checkerboard{background-position:0px 0px,10px 10px;background-size:20px 20px;background-image:linear-gradient(45deg,#eee 25%,transparent 25%,transparent 75%,#eee 75%,#eee 100%),linear-gradient(45deg,#eee 25%,white 25%,white 75%,#eee 75%,#eee 100%)}

695
frontend/dist/assets/index.cc5cde43.js vendored Normal file

File diff suppressed because one or more lines are too long

14
frontend/dist/index.html vendored Normal file
View File

@ -0,0 +1,14 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Stable Diffusion Dream Server</title>
<script type="module" crossorigin src="/assets/index.cc5cde43.js"></script>
<link rel="stylesheet" href="/assets/index.447eb2a9.css">
</head>
<body>
<div id="root"></div>
</body>
</html>

1
frontend/index.d.ts vendored Normal file
View File

@ -0,0 +1 @@
declare module 'redux-socket.io-middleware';

12
frontend/index.html Normal file
View File

@ -0,0 +1,12 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Stable Diffusion Dream Server</title>
</head>
<body>
<div id="root"></div>
<script type="module" src="/src/main.tsx"></script>
</body>
</html>

46
frontend/package.json Normal file
View File

@ -0,0 +1,46 @@
{
"name": "sdui",
"private": true,
"version": "0.0.0",
"type": "module",
"scripts": {
"dev": "tsc-watch --onSuccess 'yarn run vite build -m development'",
"hmr": "vite dev",
"build": "tsc && vite build",
"build-dev": "tsc && vite build -m development",
"preview": "vite preview"
},
"dependencies": {
"@chakra-ui/react": "^2.3.1",
"@emotion/react": "^11.10.4",
"@emotion/styled": "^11.10.4",
"@reduxjs/toolkit": "^1.8.5",
"@types/uuid": "^8.3.4",
"dateformat": "^5.0.3",
"framer-motion": "^7.2.1",
"lodash": "^4.17.21",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-dropzone": "^14.2.2",
"react-icons": "^4.4.0",
"react-redux": "^8.0.2",
"redux-persist": "^6.0.0",
"socket.io-client": "^4.5.2",
"uuid": "^9.0.0"
},
"devDependencies": {
"@types/dateformat": "^5.0.0",
"@types/react": "^18.0.17",
"@types/react-dom": "^18.0.6",
"@typescript-eslint/eslint-plugin": "^5.36.2",
"@typescript-eslint/parser": "^5.36.2",
"@vitejs/plugin-react": "^2.0.1",
"eslint": "^8.23.0",
"eslint-plugin-prettier": "^4.2.1",
"eslint-plugin-react-hooks": "^4.6.0",
"tsc-watch": "^5.0.3",
"typescript": "^4.6.4",
"vite": "^3.0.7",
"vite-plugin-eslint": "^1.8.1"
}
}

60
frontend/src/App.tsx Normal file
View File

@ -0,0 +1,60 @@
import { Grid, GridItem } from '@chakra-ui/react';
import CurrentImage from './features/gallery/CurrentImage';
import LogViewer from './features/system/LogViewer';
import PromptInput from './features/sd/PromptInput';
import ProgressBar from './features/header/ProgressBar';
import { useEffect } from 'react';
import { useAppDispatch } from './app/hooks';
import { requestAllImages } from './app/socketio';
import ProcessButtons from './features/sd/ProcessButtons';
import ImageRoll from './features/gallery/ImageRoll';
import SiteHeader from './features/header/SiteHeader';
import OptionsAccordion from './features/sd/OptionsAccordion';
const App = () => {
const dispatch = useAppDispatch();
useEffect(() => {
dispatch(requestAllImages());
}, [dispatch]);
return (
<>
<Grid
width='100vw'
height='100vh'
templateAreas={`
"header header header header"
"progressBar progressBar progressBar progressBar"
"menu prompt processButtons imageRoll"
"menu currentImage currentImage imageRoll"`}
gridTemplateRows={'36px 10px 100px auto'}
gridTemplateColumns={'350px auto 100px 388px'}
gap={2}
>
<GridItem area={'header'} pt={1}>
<SiteHeader />
</GridItem>
<GridItem area={'progressBar'}>
<ProgressBar />
</GridItem>
<GridItem pl='2' area={'menu'} overflowY='scroll'>
<OptionsAccordion />
</GridItem>
<GridItem area={'prompt'}>
<PromptInput />
</GridItem>
<GridItem area={'processButtons'}>
<ProcessButtons />
</GridItem>
<GridItem area={'currentImage'}>
<CurrentImage />
</GridItem>
<GridItem pr='2' area={'imageRoll'} overflowY='scroll'>
<ImageRoll />
</GridItem>
</Grid>
<LogViewer />
</>
);
};
export default App;

22
frontend/src/Loading.tsx Normal file
View File

@ -0,0 +1,22 @@
import { Flex, Spinner } from '@chakra-ui/react';
const Loading = () => {
return (
<Flex
width={'100vw'}
height={'100vh'}
alignItems='center'
justifyContent='center'
>
<Spinner
thickness='2px'
speed='1s'
emptyColor='gray.200'
color='gray.400'
size='xl'
/>
</Flex>
);
};
export default Loading;

View File

@ -0,0 +1,55 @@
// TODO: use Enums?
// Valid samplers
export const SAMPLERS: Array<string> = [
'ddim',
'plms',
'k_lms',
'k_dpm_2',
'k_dpm_2_a',
'k_euler',
'k_euler_a',
'k_heun',
];
// Valid image widths
export const WIDTHS: Array<number> = [
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
1024,
];
// Valid image heights
export const HEIGHTS: Array<number> = [
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
1024,
];
// Valid upscaling levels
export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [
{ key: '2x', value: 2 },
{ key: '4x', value: 4 },
];
// Internal to human-readable parameters
export const PARAMETERS: { [key: string]: string } = {
prompt: 'Prompt',
iterations: 'Iterations',
steps: 'Steps',
cfgScale: 'CFG Scale',
height: 'Height',
width: 'Width',
sampler: 'Sampler',
seed: 'Seed',
img2imgStrength: 'img2img Strength',
gfpganStrength: 'GFPGAN Strength',
upscalingLevel: 'Upscaling Level',
upscalingStrength: 'Upscaling Strength',
initialImagePath: 'Initial Image',
maskPath: 'Initial Image Mask',
shouldFitToWidthHeight: 'Fit Initial Image',
seamless: 'Seamless Tiling',
};
export const NUMPY_RAND_MIN = 0;
export const NUMPY_RAND_MAX = 4294967295;

View File

@ -0,0 +1,7 @@
import { useDispatch, useSelector } from 'react-redux';
import type { TypedUseSelectorHook } from 'react-redux';
import type { RootState, AppDispatch } from './store';
// Use throughout your app instead of plain `useDispatch` and `useSelector`
export const useAppDispatch: () => AppDispatch = useDispatch;
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;

View File

@ -0,0 +1,182 @@
import { SDState } from '../features/sd/sdSlice';
import randomInt from '../features/sd/util/randomInt';
import {
seedWeightsToString,
stringToSeedWeights,
} from '../features/sd/util/seedWeightPairs';
import { SystemState } from '../features/system/systemSlice';
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from './constants';
/*
These functions translate frontend state into parameters
suitable for consumption by the backend, and vice-versa.
*/
export const frontendToBackendParameters = (
sdState: SDState,
systemState: SystemState
): { [key: string]: any } => {
const {
prompt,
iterations,
steps,
cfgScale,
height,
width,
sampler,
seed,
seamless,
shouldUseInitImage,
img2imgStrength,
initialImagePath,
maskPath,
shouldFitToWidthHeight,
shouldGenerateVariations,
variantAmount,
seedWeights,
shouldRunESRGAN,
upscalingLevel,
upscalingStrength,
shouldRunGFPGAN,
gfpganStrength,
shouldRandomizeSeed,
} = sdState;
const { shouldDisplayInProgress } = systemState;
const generationParameters: { [k: string]: any } = {
prompt,
iterations,
steps,
cfg_scale: cfgScale,
height,
width,
sampler_name: sampler,
seed,
seamless,
progress_images: shouldDisplayInProgress,
};
generationParameters.seed = shouldRandomizeSeed
? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX)
: seed;
if (shouldUseInitImage) {
generationParameters.init_img = initialImagePath;
generationParameters.strength = img2imgStrength;
generationParameters.fit = shouldFitToWidthHeight;
if (maskPath) {
generationParameters.init_mask = maskPath;
}
}
if (shouldGenerateVariations) {
generationParameters.variation_amount = variantAmount;
if (seedWeights) {
generationParameters.with_variations =
stringToSeedWeights(seedWeights);
}
} else {
generationParameters.variation_amount = 0;
}
let esrganParameters: false | { [k: string]: any } = false;
let gfpganParameters: false | { [k: string]: any } = false;
if (shouldRunESRGAN) {
esrganParameters = {
level: upscalingLevel,
strength: upscalingStrength,
};
}
if (shouldRunGFPGAN) {
gfpganParameters = {
strength: gfpganStrength,
};
}
return {
generationParameters,
esrganParameters,
gfpganParameters,
};
};
export const backendToFrontendParameters = (parameters: {
[key: string]: any;
}) => {
const {
prompt,
iterations,
steps,
cfg_scale,
height,
width,
sampler_name,
seed,
seamless,
progress_images,
variation_amount,
with_variations,
gfpgan_strength,
upscale,
init_img,
init_mask,
strength,
} = parameters;
const sd: { [key: string]: any } = {
shouldDisplayInProgress: progress_images,
// init
shouldGenerateVariations: false,
shouldRunESRGAN: false,
shouldRunGFPGAN: false,
initialImagePath: '',
maskPath: '',
};
if (variation_amount > 0) {
sd.shouldGenerateVariations = true;
sd.variantAmount = variation_amount;
if (with_variations) {
sd.seedWeights = seedWeightsToString(with_variations);
}
}
if (gfpgan_strength > 0) {
sd.shouldRunGFPGAN = true;
sd.gfpganStrength = gfpgan_strength;
}
if (upscale) {
sd.shouldRunESRGAN = true;
sd.upscalingLevel = upscale[0];
sd.upscalingStrength = upscale[1];
}
if (init_img) {
sd.shouldUseInitImage = true
sd.initialImagePath = init_img;
sd.strength = strength;
if (init_mask) {
sd.maskPath = init_mask;
}
}
// if we had a prompt, add all the metadata, but if we don't have a prompt,
// we must have only done ESRGAN or GFPGAN so do not add that metadata
if (prompt) {
sd.prompt = prompt;
sd.iterations = iterations;
sd.steps = steps;
sd.cfgScale = cfg_scale;
sd.height = height;
sd.width = width;
sd.sampler = sampler_name;
sd.seed = seed;
sd.seamless = seamless;
}
return sd;
};

View File

@ -0,0 +1,393 @@
import { createAction, Middleware } from '@reduxjs/toolkit';
import { io } from 'socket.io-client';
import {
addImage,
clearIntermediateImage,
removeImage,
SDImage,
SDMetadata,
setGalleryImages,
setIntermediateImage,
} from '../features/gallery/gallerySlice';
import {
addLogEntry,
setCurrentStep,
setIsConnected,
setIsProcessing,
} from '../features/system/systemSlice';
import { v4 as uuidv4 } from 'uuid';
import { setInitialImagePath, setMaskPath } from '../features/sd/sdSlice';
import {
backendToFrontendParameters,
frontendToBackendParameters,
} from './parameterTranslation';
export interface SocketIOResponse {
status: 'OK' | 'ERROR';
message?: string;
data?: any;
}
export const socketioMiddleware = () => {
const { hostname, port } = new URL(window.location.href);
const socketio = io(`http://${hostname}:9090`);
let areListenersSet = false;
const middleware: Middleware = (store) => (next) => (action) => {
const { dispatch, getState } = store;
if (!areListenersSet) {
// CONNECT
socketio.on('connect', () => {
try {
dispatch(setIsConnected(true));
} catch (e) {
console.error(e);
}
});
// DISCONNECT
socketio.on('disconnect', () => {
try {
dispatch(setIsConnected(false));
dispatch(setIsProcessing(false));
dispatch(addLogEntry(`Disconnected from server`));
} catch (e) {
console.error(e);
}
});
// PROCESSING RESULT
socketio.on(
'result',
(data: {
url: string;
type: 'generation' | 'esrgan' | 'gfpgan';
uuid?: string;
metadata: { [key: string]: any };
}) => {
try {
const newUuid = uuidv4();
const { type, url, uuid, metadata } = data;
switch (type) {
case 'generation': {
const translatedMetadata =
backendToFrontendParameters(metadata);
dispatch(
addImage({
uuid: newUuid,
url,
metadata: translatedMetadata,
})
);
dispatch(
addLogEntry(`Image generated: ${url}`)
);
break;
}
case 'esrgan': {
const originalImage =
getState().gallery.images.find(
(i: SDImage) => i.uuid === uuid
);
const newMetadata = {
...originalImage.metadata,
};
newMetadata.shouldRunESRGAN = true;
newMetadata.upscalingLevel =
metadata.upscale[0];
newMetadata.upscalingStrength =
metadata.upscale[1];
dispatch(
addImage({
uuid: newUuid,
url,
metadata: newMetadata,
})
);
dispatch(
addLogEntry(`ESRGAN upscaled: ${url}`)
);
break;
}
case 'gfpgan': {
const originalImage =
getState().gallery.images.find(
(i: SDImage) => i.uuid === uuid
);
const newMetadata = {
...originalImage.metadata,
};
newMetadata.shouldRunGFPGAN = true;
newMetadata.gfpganStrength =
metadata.gfpgan_strength;
dispatch(
addImage({
uuid: newUuid,
url,
metadata: newMetadata,
})
);
dispatch(
addLogEntry(`GFPGAN fixed faces: ${url}`)
);
break;
}
}
dispatch(setIsProcessing(false));
} catch (e) {
console.error(e);
}
}
);
// PROGRESS UPDATE
socketio.on('progress', (data: { step: number }) => {
try {
dispatch(setIsProcessing(true));
dispatch(setCurrentStep(data.step));
} catch (e) {
console.error(e);
}
});
// INTERMEDIATE IMAGE
socketio.on(
'intermediateResult',
(data: { url: string; metadata: SDMetadata }) => {
try {
const uuid = uuidv4();
const { url, metadata } = data;
dispatch(
setIntermediateImage({
uuid,
url,
metadata,
})
);
dispatch(
addLogEntry(`Intermediate image generated: ${url}`)
);
} catch (e) {
console.error(e);
}
}
);
// ERROR FROM BACKEND
socketio.on('error', (message) => {
try {
dispatch(addLogEntry(`Server error: ${message}`));
dispatch(setIsProcessing(false));
dispatch(clearIntermediateImage());
} catch (e) {
console.error(e);
}
});
areListenersSet = true;
}
// HANDLE ACTIONS
switch (action.type) {
// GENERATE IMAGE
case 'socketio/generateImage': {
dispatch(setIsProcessing(true));
dispatch(setCurrentStep(-1));
const {
generationParameters,
esrganParameters,
gfpganParameters,
} = frontendToBackendParameters(
getState().sd,
getState().system
);
socketio.emit(
'generateImage',
generationParameters,
esrganParameters,
gfpganParameters
);
dispatch(
addLogEntry(
`Image generation requested: ${JSON.stringify({
...generationParameters,
...esrganParameters,
...gfpganParameters,
})}`
)
);
break;
}
// RUN ESRGAN (UPSCALING)
case 'socketio/runESRGAN': {
const imageToProcess = action.payload;
dispatch(setIsProcessing(true));
dispatch(setCurrentStep(-1));
const { upscalingLevel, upscalingStrength } = getState().sd;
const esrganParameters = {
upscale: [upscalingLevel, upscalingStrength],
};
socketio.emit('runESRGAN', imageToProcess, esrganParameters);
dispatch(
addLogEntry(
`ESRGAN upscale requested: ${JSON.stringify({
file: imageToProcess.url,
...esrganParameters,
})}`
)
);
break;
}
// RUN GFPGAN (FIX FACES)
case 'socketio/runGFPGAN': {
const imageToProcess = action.payload;
dispatch(setIsProcessing(true));
dispatch(setCurrentStep(-1));
const { gfpganStrength } = getState().sd;
const gfpganParameters = {
gfpgan_strength: gfpganStrength,
};
socketio.emit('runGFPGAN', imageToProcess, gfpganParameters);
dispatch(
addLogEntry(
`GFPGAN fix faces requested: ${JSON.stringify({
file: imageToProcess.url,
...gfpganParameters,
})}`
)
);
break;
}
// DELETE IMAGE
case 'socketio/deleteImage': {
const imageToDelete = action.payload;
const { url } = imageToDelete;
socketio.emit(
'deleteImage',
url,
(response: SocketIOResponse) => {
if (response.status === 'OK') {
dispatch(removeImage(imageToDelete));
dispatch(addLogEntry(`Image deleted: ${url}`));
}
}
);
break;
}
// GET ALL IMAGES FOR GALLERY
case 'socketio/requestAllImages': {
socketio.emit(
'requestAllImages',
(response: SocketIOResponse) => {
dispatch(setGalleryImages(response.data));
dispatch(
addLogEntry(`Loaded ${response.data.length} images`)
);
}
);
break;
}
// CANCEL PROCESSING
case 'socketio/cancelProcessing': {
socketio.emit('cancel', (response: SocketIOResponse) => {
const { intermediateImage } = getState().gallery;
if (response.status === 'OK') {
dispatch(setIsProcessing(false));
if (intermediateImage) {
dispatch(addImage(intermediateImage));
dispatch(
addLogEntry(
`Intermediate image saved: ${intermediateImage.url}`
)
);
dispatch(clearIntermediateImage());
}
dispatch(addLogEntry(`Processing canceled`));
}
});
break;
}
// UPLOAD INITIAL IMAGE
case 'socketio/uploadInitialImage': {
const file = action.payload;
socketio.emit(
'uploadInitialImage',
file,
file.name,
(response: SocketIOResponse) => {
if (response.status === 'OK') {
dispatch(setInitialImagePath(response.data));
dispatch(
addLogEntry(
`Initial image uploaded: ${response.data}`
)
);
}
}
);
break;
}
// UPLOAD MASK IMAGE
case 'socketio/uploadMaskImage': {
const file = action.payload;
socketio.emit(
'uploadMaskImage',
file,
file.name,
(response: SocketIOResponse) => {
if (response.status === 'OK') {
dispatch(setMaskPath(response.data));
dispatch(
addLogEntry(
`Mask image uploaded: ${response.data}`
)
);
}
}
);
break;
}
}
next(action);
};
return middleware;
};
// Actions to be used by app
export const generateImage = createAction<undefined>('socketio/generateImage');
export const runESRGAN = createAction<SDImage>('socketio/runESRGAN');
export const runGFPGAN = createAction<SDImage>('socketio/runGFPGAN');
export const deleteImage = createAction<SDImage>('socketio/deleteImage');
export const requestAllImages = createAction<undefined>(
'socketio/requestAllImages'
);
export const cancelProcessing = createAction<undefined>(
'socketio/cancelProcessing'
);
export const uploadInitialImage = createAction<File>(
'socketio/uploadInitialImage'
);
export const uploadMaskImage = createAction<File>('socketio/uploadMaskImage');

53
frontend/src/app/store.ts Normal file
View File

@ -0,0 +1,53 @@
import { combineReducers, configureStore } from '@reduxjs/toolkit';
import { persistReducer } from 'redux-persist';
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
import sdReducer from '../features/sd/sdSlice';
import galleryReducer from '../features/gallery/gallerySlice';
import systemReducer from '../features/system/systemSlice';
import { socketioMiddleware } from './socketio';
const reducers = combineReducers({
sd: sdReducer,
gallery: galleryReducer,
system: systemReducer,
});
const persistConfig = {
key: 'root',
storage,
};
const persistedReducer = persistReducer(persistConfig, reducers);
/*
The frontend needs to be distributed as a production build, so
we cannot reasonably ask users to edit the JS and specify the
host and port on which the socket.io server will run.
The solution is to allow server script to be run with arguments
(or just edited) providing the host and port. Then, the server
serves a route `/socketio_config` which responds with the host
and port.
When the frontend loads, it synchronously requests that route
and thus gets the host and port. This requires a suspicious
fetch somewhere, and the store setup seems like as good a place
as any to make this fetch request.
*/
// Continue with store setup
export const store = configureStore({
reducer: persistedReducer,
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
// redux-persist sometimes needs to have a function in redux, need to disable this check
serializableCheck: false,
}).concat(socketioMiddleware()),
});
// Infer the `RootState` and `AppDispatch` types from the store itself
export type RootState = ReturnType<typeof store.getState>;
// Inferred type: {posts: PostsState, comments: CommentsState, users: UsersState}
export type AppDispatch = typeof store.dispatch;

37
frontend/src/app/theme.ts Normal file
View File

@ -0,0 +1,37 @@
import { extendTheme } from '@chakra-ui/react';
import type { StyleFunctionProps } from '@chakra-ui/styled-system';
export const theme = extendTheme({
config: {
initialColorMode: 'dark',
useSystemColorMode: false,
},
components: {
Tooltip: {
baseStyle: (props: StyleFunctionProps) => ({
textColor: props.colorMode === 'dark' ? 'gray.800' : 'gray.100',
}),
},
Accordion: {
baseStyle: (props: StyleFunctionProps) => ({
button: {
fontWeight: 'bold',
_hover: {
bgColor:
props.colorMode === 'dark'
? 'rgba(255,255,255,0.05)'
: 'rgba(0,0,0,0.05)',
},
},
panel: {
paddingBottom: 2,
},
}),
},
FormLabel: {
baseStyle: {
fontWeight: 'light',
},
},
},
});

View File

@ -0,0 +1,16 @@
import { Button, ButtonProps } from '@chakra-ui/react';
interface Props extends ButtonProps {
label: string;
}
const SDButton = (props: Props) => {
const { label, size = 'sm', ...rest } = props;
return (
<Button size={size} {...rest}>
{label}
</Button>
);
};
export default SDButton;

View File

@ -0,0 +1,56 @@
import {
FormControl,
NumberInput,
NumberInputField,
NumberInputStepper,
NumberIncrementStepper,
NumberDecrementStepper,
Text,
FormLabel,
NumberInputProps,
Flex,
} from '@chakra-ui/react';
interface Props extends NumberInputProps {
label?: string;
width?: string | number;
}
const SDNumberInput = (props: Props) => {
const {
label,
isDisabled = false,
fontSize = 'md',
size = 'sm',
width,
isInvalid,
...rest
} = props;
return (
<FormControl isDisabled={isDisabled} width={width} isInvalid={isInvalid}>
<Flex gap={2} justifyContent={'space-between'} alignItems={'center'}>
{label && (
<FormLabel marginBottom={1}>
<Text fontSize={fontSize} whiteSpace='nowrap'>
{label}
</Text>
</FormLabel>
)}
<NumberInput
size={size}
{...rest}
keepWithinRange={false}
clampValueOnBlur={true}
>
<NumberInputField fontSize={'md'}/>
<NumberInputStepper>
<NumberIncrementStepper />
<NumberDecrementStepper />
</NumberInputStepper>
</NumberInput>
</Flex>
</FormControl>
);
};
export default SDNumberInput;

View File

@ -0,0 +1,57 @@
import {
Flex,
FormControl,
FormLabel,
Select,
SelectProps,
Text,
} from '@chakra-ui/react';
interface Props extends SelectProps {
label: string;
validValues:
| Array<number | string>
| Array<{ key: string; value: string | number }>;
}
const SDSelect = (props: Props) => {
const {
label,
isDisabled,
validValues,
size = 'sm',
fontSize = 'md',
marginBottom = 1,
whiteSpace = 'nowrap',
...rest
} = props;
return (
<FormControl isDisabled={isDisabled}>
<Flex justifyContent={'space-between'} alignItems={'center'}>
<FormLabel
marginBottom={marginBottom}
>
<Text fontSize={fontSize} whiteSpace={whiteSpace}>
{label}
</Text>
</FormLabel>
<Select fontSize={fontSize} size={size} {...rest}>
{validValues.map((opt) => {
return typeof opt === 'string' ||
typeof opt === 'number' ? (
<option key={opt} value={opt}>
{opt}
</option>
) : (
<option key={opt.value} value={opt.value}>
{opt.key}
</option>
);
})}
</Select>
</Flex>
</FormControl>
);
};
export default SDSelect;

View File

@ -0,0 +1,42 @@
import {
Flex,
FormControl,
FormLabel,
Switch,
SwitchProps,
} from '@chakra-ui/react';
interface Props extends SwitchProps {
label?: string;
width?: string | number;
}
const SDSwitch = (props: Props) => {
const {
label,
isDisabled = false,
fontSize = 'md',
size = 'md',
width,
...rest
} = props;
return (
<FormControl isDisabled={isDisabled} width={width}>
<Flex justifyContent={'space-between'} alignItems={'center'}>
{label && (
<FormLabel
fontSize={fontSize}
marginBottom={1}
flexGrow={2}
whiteSpace='nowrap'
>
{label}
</FormLabel>
)}
<Switch size={size} {...rest} />
</Flex>
</FormControl>
);
};
export default SDSwitch;

View File

@ -0,0 +1,161 @@
import { Center, Flex, Image, useColorModeValue } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import { setAllParameters, setInitialImagePath, setSeed } from '../sd/sdSlice';
import { useState } from 'react';
import ImageMetadataViewer from './ImageMetadataViewer';
import DeleteImageModalButton from './DeleteImageModalButton';
import SDButton from '../../components/SDButton';
import { runESRGAN, runGFPGAN } from '../../app/socketio';
import { createSelector } from '@reduxjs/toolkit';
import { SystemState } from '../system/systemSlice';
import { isEqual } from 'lodash';
const height = 'calc(100vh - 238px)';
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isProcessing: system.isProcessing,
isConnected: system.isConnected,
isGFPGANAvailable: system.isGFPGANAvailable,
isESRGANAvailable: system.isESRGANAvailable,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const CurrentImage = () => {
const { currentImage, intermediateImage } = useAppSelector(
(state: RootState) => state.gallery
);
const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } =
useAppSelector(systemSelector);
const dispatch = useAppDispatch();
const bgColor = useColorModeValue(
'rgba(255, 255, 255, 0.85)',
'rgba(0, 0, 0, 0.8)'
);
const [shouldShowImageDetails, setShouldShowImageDetails] =
useState<boolean>(false);
const imageToDisplay = intermediateImage || currentImage;
return (
<Flex direction={'column'} rounded={'md'} borderWidth={1} p={2} gap={2}>
{imageToDisplay && (
<Flex gap={2}>
<SDButton
label='Use as initial image'
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
onClick={() =>
dispatch(setInitialImagePath(imageToDisplay.url))
}
/>
<SDButton
label='Use all'
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
onClick={() =>
dispatch(setAllParameters(imageToDisplay.metadata))
}
/>
<SDButton
label='Use seed'
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
isDisabled={!imageToDisplay.metadata.seed}
onClick={() =>
dispatch(setSeed(imageToDisplay.metadata.seed!))
}
/>
<SDButton
label='Upscale'
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
isDisabled={
!isESRGANAvailable ||
Boolean(intermediateImage) ||
!(isConnected && !isProcessing)
}
onClick={() => dispatch(runESRGAN(imageToDisplay))}
/>
<SDButton
label='Fix faces'
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
isDisabled={
!isGFPGANAvailable ||
Boolean(intermediateImage) ||
!(isConnected && !isProcessing)
}
onClick={() => dispatch(runGFPGAN(imageToDisplay))}
/>
<SDButton
label='Details'
colorScheme={'gray'}
variant={shouldShowImageDetails ? 'solid' : 'outline'}
borderWidth={1}
flexGrow={1}
onClick={() =>
setShouldShowImageDetails(!shouldShowImageDetails)
}
/>
<DeleteImageModalButton image={imageToDisplay}>
<SDButton
label='Delete'
colorScheme={'red'}
flexGrow={1}
variant={'outline'}
isDisabled={Boolean(intermediateImage)}
/>
</DeleteImageModalButton>
</Flex>
)}
<Center height={height} position={'relative'}>
{imageToDisplay && (
<Image
src={imageToDisplay.url}
fit='contain'
maxWidth={'100%'}
maxHeight={'100%'}
/>
)}
{imageToDisplay && shouldShowImageDetails && (
<Flex
width={'100%'}
height={'100%'}
position={'absolute'}
top={0}
left={0}
p={3}
boxSizing='border-box'
backgroundColor={bgColor}
overflow='scroll'
>
<ImageMetadataViewer image={imageToDisplay} />
</Flex>
)}
</Center>
</Flex>
);
};
export default CurrentImage;

View File

@ -0,0 +1,94 @@
import {
IconButtonProps,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
Text,
useDisclosure,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import {
cloneElement,
ReactElement,
SyntheticEvent,
} from 'react';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { deleteImage } from '../../app/socketio';
import { RootState } from '../../app/store';
import SDButton from '../../components/SDButton';
import { setShouldConfirmOnDelete, SystemState } from '../system/systemSlice';
import { SDImage } from './gallerySlice';
interface Props extends IconButtonProps {
image: SDImage;
'aria-label': string;
children: ReactElement;
}
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => system.shouldConfirmOnDelete
);
/*
TODO: The modal and button to open it should be two different components,
but their state is closely related and I'm not sure how best to accomplish it.
*/
const DeleteImageModalButton = (props: Omit<Props, 'aria-label'>) => {
const { isOpen, onOpen, onClose } = useDisclosure();
const dispatch = useAppDispatch();
const shouldConfirmOnDelete = useAppSelector(systemSelector);
const handleClickDelete = (e: SyntheticEvent) => {
e.stopPropagation();
shouldConfirmOnDelete ? onOpen() : handleDelete();
};
const { image, children } = props;
const handleDelete = () => {
dispatch(deleteImage(image));
onClose();
};
const handleDeleteAndDontAsk = () => {
dispatch(deleteImage(image));
dispatch(setShouldConfirmOnDelete(false));
onClose();
};
return (
<>
{cloneElement(children, {
onClick: handleClickDelete,
})}
<Modal isOpen={isOpen} onClose={onClose}>
<ModalOverlay />
<ModalContent>
<ModalHeader>Are you sure you want to delete this image?</ModalHeader>
<ModalCloseButton />
<ModalBody>
<Text>It will be deleted forever!</Text>
</ModalBody>
<ModalFooter justifyContent={'space-between'}>
<SDButton label={'Yes'} colorScheme='red' onClick={handleDelete} />
<SDButton
label={"Yes, and don't ask me again"}
colorScheme='red'
onClick={handleDeleteAndDontAsk}
/>
<SDButton label='Cancel' colorScheme='blue' onClick={onClose} />
</ModalFooter>
</ModalContent>
</Modal>
</>
);
};
export default DeleteImageModalButton;

View File

@ -0,0 +1,124 @@
import {
Center,
Flex,
IconButton,
Link,
List,
ListItem,
Text,
} from '@chakra-ui/react';
import { FaPlus } from 'react-icons/fa';
import { PARAMETERS } from '../../app/constants';
import { useAppDispatch } from '../../app/hooks';
import SDButton from '../../components/SDButton';
import { setAllParameters, setParameter } from '../sd/sdSlice';
import { SDImage, SDMetadata } from './gallerySlice';
type Props = {
image: SDImage;
};
const ImageMetadataViewer = ({ image }: Props) => {
const dispatch = useAppDispatch();
const keys = Object.keys(PARAMETERS);
const metadata: Array<{
label: string;
key: string;
value: string | number | boolean;
}> = [];
keys.forEach((key) => {
const value = image.metadata[key as keyof SDMetadata];
if (value !== undefined) {
metadata.push({ label: PARAMETERS[key], key, value });
}
});
return (
<Flex gap={2} direction={'column'} overflowY={'scroll'} width={'100%'}>
<SDButton
label='Use all parameters'
colorScheme={'gray'}
padding={2}
isDisabled={metadata.length === 0}
onClick={() => dispatch(setAllParameters(image.metadata))}
/>
<Flex gap={2}>
<Text fontWeight={'semibold'}>File:</Text>
<Link href={image.url} isExternal>
<Text>{image.url}</Text>
</Link>
</Flex>
{metadata.length ? (
<>
<List>
{metadata.map((parameter, i) => {
const { label, key, value } = parameter;
return (
<ListItem key={i} pb={1}>
<Flex gap={2}>
<IconButton
aria-label='Use this parameter'
icon={<FaPlus />}
size={'xs'}
onClick={() =>
dispatch(
setParameter({
key,
value,
})
)
}
/>
<Text fontWeight={'semibold'}>
{label}:
</Text>
{value === undefined ||
value === null ||
value === '' ||
value === 0 ? (
<Text
maxHeight={100}
fontStyle={'italic'}
>
None
</Text>
) : (
<Text
maxHeight={100}
overflowY={'scroll'}
>
{value.toString()}
</Text>
)}
</Flex>
</ListItem>
);
})}
</List>
<Flex gap={2}>
<Text fontWeight={'semibold'}>Raw:</Text>
<Text
maxHeight={100}
overflowY={'scroll'}
wordBreak={'break-all'}
>
{JSON.stringify(image.metadata)}
</Text>
</Flex>
</>
) : (
<Center width={'100%'} pt={10}>
<Text fontSize={'lg'} fontWeight='semibold'>
No metadata available
</Text>
</Center>
)}
</Flex>
);
};
export default ImageMetadataViewer;

View File

@ -0,0 +1,150 @@
import {
Box,
Flex,
Icon,
IconButton,
Image,
useColorModeValue,
} from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { SDImage, setCurrentImage } from './gallerySlice';
import { FaCheck, FaCopy, FaSeedling, FaTrash } from 'react-icons/fa';
import DeleteImageModalButton from './DeleteImageModalButton';
import { memo, SyntheticEvent, useState } from 'react';
import { setAllParameters, setSeed } from '../sd/sdSlice';
interface HoverableImageProps {
image: SDImage;
isSelected: boolean;
}
const HoverableImage = memo(
(props: HoverableImageProps) => {
const [isHovered, setIsHovered] = useState<boolean>(false);
const dispatch = useAppDispatch();
const checkColor = useColorModeValue('green.600', 'green.300');
const bgColor = useColorModeValue('gray.200', 'gray.700');
const bgGradient = useColorModeValue(
'radial-gradient(circle, rgba(255,255,255,0.7) 0%, rgba(255,255,255,0.7) 20%, rgba(0,0,0,0) 100%)',
'radial-gradient(circle, rgba(0,0,0,0.7) 0%, rgba(0,0,0,0.7) 20%, rgba(0,0,0,0) 100%)'
);
const { image, isSelected } = props;
const { url, uuid, metadata } = image;
const handleMouseOver = () => setIsHovered(true);
const handleMouseOut = () => setIsHovered(false);
const handleClickSetAllParameters = (e: SyntheticEvent) => {
e.stopPropagation();
dispatch(setAllParameters(metadata));
};
const handleClickSetSeed = (e: SyntheticEvent) => {
e.stopPropagation();
dispatch(setSeed(image.metadata.seed!)); // component not rendered unless this exists
};
return (
<Box position={'relative'} key={uuid}>
<Image
width={120}
height={120}
objectFit='cover'
rounded={'md'}
src={url}
loading={'lazy'}
backgroundColor={bgColor}
/>
<Flex
cursor={'pointer'}
position={'absolute'}
top={0}
left={0}
rounded={'md'}
width='100%'
height='100%'
alignItems={'center'}
justifyContent={'center'}
background={isSelected ? bgGradient : undefined}
onClick={() => dispatch(setCurrentImage(image))}
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
>
{isSelected && (
<Icon
fill={checkColor}
width={'50%'}
height={'50%'}
as={FaCheck}
/>
)}
{isHovered && (
<Flex
direction={'column'}
gap={1}
position={'absolute'}
top={1}
right={1}
>
<DeleteImageModalButton image={image}>
<IconButton
colorScheme='red'
aria-label='Delete image'
icon={<FaTrash />}
size='xs'
fontSize={15}
/>
</DeleteImageModalButton>
<IconButton
aria-label='Use all parameters'
colorScheme={'blue'}
icon={<FaCopy />}
size='xs'
fontSize={15}
onClickCapture={handleClickSetAllParameters}
/>
{image.metadata.seed && (
<IconButton
aria-label='Use seed'
colorScheme={'blue'}
icon={<FaSeedling />}
size='xs'
fontSize={16}
onClickCapture={handleClickSetSeed}
/>
)}
</Flex>
)}
</Flex>
</Box>
);
},
(prev, next) =>
prev.image.uuid === next.image.uuid &&
prev.isSelected === next.isSelected
);
const ImageRoll = () => {
const { images, currentImageUuid } = useAppSelector(
(state: RootState) => state.gallery
);
return (
<Flex gap={2} wrap='wrap' pb={2}>
{[...images].reverse().map((image) => {
const { uuid } = image;
const isSelected = currentImageUuid === uuid;
return (
<HoverableImage
key={uuid}
image={image}
isSelected={isSelected}
/>
);
})}
</Flex>
);
};
export default ImageRoll;

View File

@ -0,0 +1,144 @@
import { createSlice } from '@reduxjs/toolkit';
import type { PayloadAction } from '@reduxjs/toolkit';
import { v4 as uuidv4 } from 'uuid';
import { UpscalingLevel } from '../sd/sdSlice';
import { backendToFrontendParameters } from '../../app/parameterTranslation';
// TODO: Revise pending metadata RFC: https://github.com/lstein/stable-diffusion/issues/266
export interface SDMetadata {
prompt?: string;
steps?: number;
cfgScale?: number;
height?: number;
width?: number;
sampler?: string;
seed?: number;
img2imgStrength?: number;
gfpganStrength?: number;
upscalingLevel?: UpscalingLevel;
upscalingStrength?: number;
initialImagePath?: string;
maskPath?: string;
seamless?: boolean;
shouldFitToWidthHeight?: boolean;
}
export interface SDImage {
// TODO: I have installed @types/uuid but cannot figure out how to use them here.
uuid: string;
url: string;
metadata: SDMetadata;
}
export interface GalleryState {
currentImageUuid: string;
images: Array<SDImage>;
intermediateImage?: SDImage;
currentImage?: SDImage;
}
const initialState: GalleryState = {
currentImageUuid: '',
images: [],
};
export const gallerySlice = createSlice({
name: 'gallery',
initialState,
reducers: {
setCurrentImage: (state, action: PayloadAction<SDImage>) => {
state.currentImage = action.payload;
state.currentImageUuid = action.payload.uuid;
},
removeImage: (state, action: PayloadAction<SDImage>) => {
const { uuid } = action.payload;
const newImages = state.images.filter((image) => image.uuid !== uuid);
const imageToDeleteIndex = state.images.findIndex(
(image) => image.uuid === uuid
);
const newCurrentImageIndex = Math.min(
Math.max(imageToDeleteIndex, 0),
newImages.length - 1
);
state.images = newImages;
state.currentImage = newImages.length
? newImages[newCurrentImageIndex]
: undefined;
state.currentImageUuid = newImages.length
? newImages[newCurrentImageIndex].uuid
: '';
},
addImage: (state, action: PayloadAction<SDImage>) => {
state.images.push(action.payload);
state.currentImageUuid = action.payload.uuid;
state.intermediateImage = undefined;
state.currentImage = action.payload;
},
setIntermediateImage: (state, action: PayloadAction<SDImage>) => {
state.intermediateImage = action.payload;
},
clearIntermediateImage: (state) => {
state.intermediateImage = undefined;
},
setGalleryImages: (
state,
action: PayloadAction<
Array<{
path: string;
metadata: { [key: string]: string | number | boolean };
}>
>
) => {
// TODO: Revise pending metadata RFC: https://github.com/lstein/stable-diffusion/issues/266
const images = action.payload;
if (images.length === 0) {
// there are no images on disk, clear the gallery
state.images = [];
state.currentImageUuid = '';
state.currentImage = undefined;
} else {
// Filter image urls that are already in the rehydrated state
const filteredImages = action.payload.filter(
(image) => !state.images.find((i) => i.url === image.path)
);
const preparedImages = filteredImages.map((image): SDImage => {
return {
uuid: uuidv4(),
url: image.path,
metadata: backendToFrontendParameters(image.metadata),
};
});
const newImages = [...state.images].concat(preparedImages);
// if previous currentimage no longer exists, set a new one
if (!newImages.find((image) => image.uuid === state.currentImageUuid)) {
const newCurrentImage = newImages[newImages.length - 1];
state.currentImage = newCurrentImage;
state.currentImageUuid = newCurrentImage.uuid;
}
state.images = newImages;
}
},
},
});
export const {
setCurrentImage,
removeImage,
addImage,
setGalleryImages,
setIntermediateImage,
clearIntermediateImage,
} = gallerySlice.actions;
export default gallerySlice.reducer;

View File

@ -0,0 +1,35 @@
import { Progress } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import { SDState } from '../sd/sdSlice';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
realSteps: sd.realSteps,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const ProgressBar = () => {
const { realSteps } = useAppSelector(sdSelector);
const { currentStep } = useAppSelector((state: RootState) => state.system);
const progress = Math.round((currentStep * 100) / realSteps);
return (
<Progress
height='10px'
value={progress}
isIndeterminate={progress < 0 || currentStep === realSteps}
/>
);
};
export default ProgressBar;

View File

@ -0,0 +1,93 @@
import {
Flex,
Heading,
IconButton,
Link,
Spacer,
Text,
useColorMode,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { FaSun, FaMoon, FaGithub } from 'react-icons/fa';
import { MdHelp, MdSettings } from 'react-icons/md';
import { useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import SettingsModal from '../system/SettingsModal';
import { SystemState } from '../system/systemSlice';
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return { isConnected: system.isConnected };
},
{
memoizeOptions: { resultEqualityCheck: isEqual },
}
);
const SiteHeader = () => {
const { colorMode, toggleColorMode } = useColorMode();
const { isConnected } = useAppSelector(systemSelector);
return (
<Flex minWidth='max-content' alignItems='center' gap='1' pl={2} pr={1}>
<Heading size={'lg'}>Stable Diffusion Dream Server</Heading>
<Spacer />
<Text textColor={isConnected ? 'green.500' : 'red.500'}>
{isConnected ? `Connected to server` : 'No connection to server'}
</Text>
<SettingsModal>
<IconButton
aria-label='Settings'
variant='link'
fontSize={24}
size={'sm'}
icon={<MdSettings />}
/>
</SettingsModal>
<IconButton
aria-label='Link to Github Issues'
variant='link'
fontSize={23}
size={'sm'}
icon={
<Link
isExternal
href='http://github.com/lstein/stable-diffusion/issues'
>
<MdHelp />
</Link>
}
/>
<IconButton
aria-label='Link to Github Repo'
variant='link'
fontSize={20}
size={'sm'}
icon={
<Link isExternal href='http://github.com/lstein/stable-diffusion'>
<FaGithub />
</Link>
}
/>
<IconButton
aria-label='Toggle Dark Mode'
onClick={toggleColorMode}
variant='link'
size={'sm'}
fontSize={colorMode == 'light' ? 18 : 20}
icon={colorMode == 'light' ? <FaMoon /> : <FaSun />}
/>
</Flex>
);
};
export default SiteHeader;

View File

@ -0,0 +1,84 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import {
setUpscalingLevel,
setUpscalingStrength,
UpscalingLevel,
SDState,
} from '../sd/sdSlice';
import SDNumberInput from '../../components/SDNumberInput';
import SDSelect from '../../components/SDSelect';
import { UPSCALING_LEVELS } from '../../app/constants';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { SystemState } from '../system/systemSlice';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
upscalingLevel: sd.upscalingLevel,
upscalingStrength: sd.upscalingStrength,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isESRGANAvailable: system.isESRGANAvailable,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const ESRGANOptions = () => {
const { upscalingLevel, upscalingStrength } = useAppSelector(sdSelector);
const { isESRGANAvailable } = useAppSelector(systemSelector);
const dispatch = useAppDispatch();
return (
<Flex direction={'column'} gap={2}>
<SDSelect
isDisabled={!isESRGANAvailable}
label='Scale'
value={upscalingLevel}
onChange={(e) =>
dispatch(
setUpscalingLevel(
Number(e.target.value) as UpscalingLevel
)
)
}
validValues={UPSCALING_LEVELS}
/>
<SDNumberInput
isDisabled={!isESRGANAvailable}
label='Strength'
step={0.05}
min={0}
max={1}
onChange={(v) => dispatch(setUpscalingStrength(Number(v)))}
value={upscalingStrength}
/>
</Flex>
);
};
export default ESRGANOptions;

View File

@ -0,0 +1,63 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { SDState, setGfpganStrength } from '../sd/sdSlice';
import SDNumberInput from '../../components/SDNumberInput';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { SystemState } from '../system/systemSlice';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
gfpganStrength: sd.gfpganStrength,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isGFPGANAvailable: system.isGFPGANAvailable,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const GFPGANOptions = () => {
const { gfpganStrength } = useAppSelector(sdSelector);
const { isGFPGANAvailable } = useAppSelector(systemSelector);
const dispatch = useAppDispatch();
return (
<Flex direction={'column'} gap={2}>
<SDNumberInput
isDisabled={!isGFPGANAvailable}
label='Strength'
step={0.05}
min={0}
max={1}
onChange={(v) => dispatch(setGfpganStrength(Number(v)))}
value={gfpganStrength}
/>
</Flex>
);
};
export default GFPGANOptions;

View File

@ -0,0 +1,54 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import SDNumberInput from '../../components/SDNumberInput';
import SDSwitch from '../../components/SDSwitch';
import InitImage from './InitImage';
import {
SDState,
setImg2imgStrength,
setShouldFitToWidthHeight,
} from './sdSlice';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
initialImagePath: sd.initialImagePath,
img2imgStrength: sd.img2imgStrength,
shouldFitToWidthHeight: sd.shouldFitToWidthHeight,
};
}
);
const ImageToImageOptions = () => {
const { initialImagePath, img2imgStrength, shouldFitToWidthHeight } =
useAppSelector(sdSelector);
const dispatch = useAppDispatch();
return (
<Flex direction={'column'} gap={2}>
<SDNumberInput
isDisabled={!initialImagePath}
label='Strength'
step={0.01}
min={0}
max={1}
onChange={(v) => dispatch(setImg2imgStrength(Number(v)))}
value={img2imgStrength}
/>
<SDSwitch
isDisabled={!initialImagePath}
label='Fit initial image to output size'
isChecked={shouldFitToWidthHeight}
onChange={(e) =>
dispatch(setShouldFitToWidthHeight(e.target.checked))
}
/>
<InitImage />
</Flex>
);
};
export default ImageToImageOptions;

View File

@ -0,0 +1,20 @@
.checkerboard {
background-position: 0px 0px, 10px 10px;
background-size: 20px 20px;
background-image: linear-gradient(
45deg,
#eee 25%,
transparent 25%,
transparent 75%,
#eee 75%,
#eee 100%
),
linear-gradient(
45deg,
#eee 25%,
white 25%,
white 75%,
#eee 75%,
#eee 100%
);
}

View File

@ -0,0 +1,155 @@
import {
Button,
Flex,
IconButton,
Image,
useToast,
} from '@chakra-ui/react';
import { SyntheticEvent, useCallback, useState } from 'react';
import { FileRejection, useDropzone } from 'react-dropzone';
import { FaTrash } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import {
SDState,
setInitialImagePath,
setMaskPath,
} from '../../features/sd/sdSlice';
import MaskUploader from './MaskUploader';
import './InitImage.css';
import { uploadInitialImage } from '../../app/socketio';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
initialImagePath: sd.initialImagePath,
maskPath: sd.maskPath,
};
},
{ memoizeOptions: { resultEqualityCheck: isEqual } }
);
const InitImage = () => {
const toast = useToast();
const dispatch = useAppDispatch();
const { initialImagePath, maskPath } = useAppSelector(sdSelector);
const onDrop = useCallback(
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
fileRejections.forEach((rejection: FileRejection) => {
const msg = rejection.errors.reduce(
(acc: string, cur: { message: string }) => acc + '\n' + cur.message,
''
);
toast({
title: 'Upload failed',
description: msg,
status: 'error',
isClosable: true,
});
});
acceptedFiles.forEach((file: File) => {
dispatch(uploadInitialImage(file));
});
},
[dispatch, toast]
);
const { getRootProps, getInputProps, open } = useDropzone({
onDrop,
accept: {
'image/jpeg': ['.jpg', '.jpeg', '.png'],
},
});
const [shouldShowMask, setShouldShowMask] = useState<boolean>(false);
const handleClickUploadIcon = (e: SyntheticEvent) => {
e.stopPropagation();
open();
};
const handleClickResetInitialImageAndMask = (e: SyntheticEvent) => {
e.stopPropagation();
dispatch(setInitialImagePath(''));
dispatch(setMaskPath(''));
};
const handleMouseOverInitialImageUploadButton = () =>
setShouldShowMask(false);
const handleMouseOutInitialImageUploadButton = () => setShouldShowMask(true);
const handleMouseOverMaskUploadButton = () => setShouldShowMask(true);
const handleMouseOutMaskUploadButton = () => setShouldShowMask(true);
return (
<Flex
{...getRootProps({
onClick: initialImagePath ? (e) => e.stopPropagation() : undefined,
})}
direction={'column'}
alignItems={'center'}
gap={2}
>
<input {...getInputProps({ multiple: false })} />
<Flex gap={2} justifyContent={'space-between'} width={'100%'}>
<Button
size={'sm'}
fontSize={'md'}
fontWeight={'normal'}
onClick={handleClickUploadIcon}
onMouseOver={handleMouseOverInitialImageUploadButton}
onMouseOut={handleMouseOutInitialImageUploadButton}
>
Upload Image
</Button>
<MaskUploader>
<Button
size={'sm'}
fontSize={'md'}
fontWeight={'normal'}
onClick={handleClickUploadIcon}
onMouseOver={handleMouseOverMaskUploadButton}
onMouseOut={handleMouseOutMaskUploadButton}
>
Upload Mask
</Button>
</MaskUploader>
<IconButton
size={'sm'}
aria-label={'Reset initial image and mask'}
onClick={handleClickResetInitialImageAndMask}
icon={<FaTrash />}
/>
</Flex>
{initialImagePath && (
<Flex position={'relative'} width={'100%'}>
<Image
fit={'contain'}
src={initialImagePath}
rounded={'md'}
className={'checkerboard'}
/>
{shouldShowMask && maskPath && (
<Image
position={'absolute'}
top={0}
left={0}
fit={'contain'}
src={maskPath}
rounded={'md'}
zIndex={1}
className={'checkerboard'}
/>
)}
</Flex>
)}
</Flex>
);
};
export default InitImage;

View File

@ -0,0 +1,61 @@
import { useToast } from '@chakra-ui/react';
import { cloneElement, ReactElement, SyntheticEvent, useCallback } from 'react';
import { FileRejection, useDropzone } from 'react-dropzone';
import { useAppDispatch } from '../../app/hooks';
import { uploadMaskImage } from '../../app/socketio';
type Props = {
children: ReactElement;
};
const MaskUploader = ({ children }: Props) => {
const dispatch = useAppDispatch();
const toast = useToast();
const onDrop = useCallback(
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
fileRejections.forEach((rejection: FileRejection) => {
const msg = rejection.errors.reduce(
(acc: string, cur: { message: string }) =>
acc + '\n' + cur.message,
''
);
toast({
title: 'Upload failed',
description: msg,
status: 'error',
isClosable: true,
});
});
acceptedFiles.forEach((file: File) => {
dispatch(uploadMaskImage(file));
});
},
[dispatch, toast]
);
const { getRootProps, getInputProps, open } = useDropzone({
onDrop,
accept: {
'image/jpeg': ['.jpg', '.jpeg', '.png'],
},
});
const handleClickUploadIcon = (e: SyntheticEvent) => {
e.stopPropagation();
open();
};
return (
<div {...getRootProps()}>
<input {...getInputProps({ multiple: false })} />
{cloneElement(children, {
onClick: handleClickUploadIcon,
})}
</div>
);
};
export default MaskUploader;

View File

@ -0,0 +1,211 @@
import {
Flex,
Box,
Text,
Accordion,
AccordionItem,
AccordionButton,
AccordionIcon,
AccordionPanel,
Switch,
} from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import {
setShouldRunGFPGAN,
setShouldRunESRGAN,
SDState,
setShouldUseInitImage,
} from '../sd/sdSlice';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { setOpenAccordions, SystemState } from '../system/systemSlice';
import SeedVariationOptions from './SeedVariationOptions';
import SamplerOptions from './SamplerOptions';
import ESRGANOptions from './ESRGANOptions';
import GFPGANOptions from './GFPGANOptions';
import OutputOptions from './OutputOptions';
import ImageToImageOptions from './ImageToImageOptions';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
initialImagePath: sd.initialImagePath,
shouldUseInitImage: sd.shouldUseInitImage,
shouldRunESRGAN: sd.shouldRunESRGAN,
shouldRunGFPGAN: sd.shouldRunGFPGAN,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isGFPGANAvailable: system.isGFPGANAvailable,
isESRGANAvailable: system.isESRGANAvailable,
openAccordions: system.openAccordions,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const OptionsAccordion = () => {
const {
shouldRunESRGAN,
shouldRunGFPGAN,
shouldUseInitImage,
initialImagePath,
} = useAppSelector(sdSelector);
const { isGFPGANAvailable, isESRGANAvailable, openAccordions } =
useAppSelector(systemSelector);
const dispatch = useAppDispatch();
return (
<Accordion
defaultIndex={openAccordions}
allowMultiple
reduceMotion
onChange={(openAccordions) =>
dispatch(setOpenAccordions(openAccordions))
}
>
<AccordionItem>
<h2>
<AccordionButton>
<Box flex='1' textAlign='left'>
Seed & Variation
</Box>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<SeedVariationOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Box flex='1' textAlign='left'>
Sampler
</Box>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<SamplerOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Flex
justifyContent={'space-between'}
alignItems={'center'}
width={'100%'}
mr={2}
>
<Text>Upscale (ESRGAN)</Text>
<Switch
isDisabled={!isESRGANAvailable}
isChecked={shouldRunESRGAN}
onChange={(e) =>
dispatch(
setShouldRunESRGAN(e.target.checked)
)
}
/>
</Flex>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<ESRGANOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Flex
justifyContent={'space-between'}
alignItems={'center'}
width={'100%'}
mr={2}
>
<Text>Fix Faces (GFPGAN)</Text>
<Switch
isDisabled={!isGFPGANAvailable}
isChecked={shouldRunGFPGAN}
onChange={(e) =>
dispatch(
setShouldRunGFPGAN(e.target.checked)
)
}
/>
</Flex>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<GFPGANOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Flex
justifyContent={'space-between'}
alignItems={'center'}
width={'100%'}
mr={2}
>
<Text>Image to Image</Text>
<Switch
isDisabled={!initialImagePath}
isChecked={shouldUseInitImage}
onChange={(e) =>
dispatch(
setShouldUseInitImage(e.target.checked)
)
}
/>
</Flex>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<ImageToImageOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Box flex='1' textAlign='left'>
Output
</Box>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<OutputOptions />
</AccordionPanel>
</AccordionItem>
</Accordion>
);
};
export default OptionsAccordion;

View File

@ -0,0 +1,66 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { setHeight, setWidth, setSeamless, SDState } from '../sd/sdSlice';
import SDSelect from '../../components/SDSelect';
import { HEIGHTS, WIDTHS } from '../../app/constants';
import SDSwitch from '../../components/SDSwitch';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
height: sd.height,
width: sd.width,
seamless: sd.seamless,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const OutputOptions = () => {
const { height, width, seamless } = useAppSelector(sdSelector);
const dispatch = useAppDispatch();
return (
<Flex gap={2} direction={'column'}>
<Flex gap={2}>
<SDSelect
label='Width'
value={width}
flexGrow={1}
onChange={(e) => dispatch(setWidth(Number(e.target.value)))}
validValues={WIDTHS}
/>
<SDSelect
label='Height'
value={height}
flexGrow={1}
onChange={(e) =>
dispatch(setHeight(Number(e.target.value)))
}
validValues={HEIGHTS}
/>
</Flex>
<SDSwitch
label='Seamless tiling'
fontSize={'md'}
isChecked={seamless}
onChange={(e) => dispatch(setSeamless(e.target.checked))}
/>
</Flex>
);
};
export default OutputOptions;

View File

@ -0,0 +1,58 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { cancelProcessing, generateImage } from '../../app/socketio';
import { RootState } from '../../app/store';
import SDButton from '../../components/SDButton';
import { SystemState } from '../system/systemSlice';
import useCheckParameters from '../system/useCheckParameters';
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isProcessing: system.isProcessing,
isConnected: system.isConnected,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const ProcessButtons = () => {
const { isProcessing, isConnected } = useAppSelector(systemSelector);
const dispatch = useAppDispatch();
const isReady = useCheckParameters();
return (
<Flex gap={2} direction={'column'} alignItems={'space-between'} height={'100%'}>
<SDButton
label='Generate'
type='submit'
colorScheme='green'
flexGrow={1}
isDisabled={!isReady}
fontSize={'md'}
size={'md'}
onClick={() => dispatch(generateImage())}
/>
<SDButton
label='Cancel'
colorScheme='red'
flexGrow={1}
fontSize={'md'}
size={'md'}
isDisabled={!isConnected || !isProcessing}
onClick={() => dispatch(cancelProcessing())}
/>
</Flex>
);
};
export default ProcessButtons;

View File

@ -0,0 +1,25 @@
import { Textarea } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import { setPrompt } from '../sd/sdSlice';
const PromptInput = () => {
const { prompt } = useAppSelector((state: RootState) => state.sd);
const dispatch = useAppDispatch();
return (
<Textarea
id='prompt'
name='prompt'
resize='none'
size={'lg'}
height={'100%'}
isInvalid={!prompt.length}
onChange={(e) => dispatch(setPrompt(e.target.value))}
value={prompt}
placeholder="I'm dreaming of..."
/>
);
};
export default PromptInput;

View File

@ -0,0 +1,51 @@
import {
Slider,
SliderTrack,
SliderFilledTrack,
SliderThumb,
FormControl,
FormLabel,
Text,
Flex,
SliderProps,
} from '@chakra-ui/react';
interface Props extends SliderProps {
label: string;
value: number;
fontSize?: number | string;
}
const SDSlider = ({
label,
value,
fontSize = 'sm',
onChange,
...rest
}: Props) => {
return (
<FormControl>
<Flex gap={2}>
<FormLabel marginInlineEnd={0} marginBottom={1}>
<Text fontSize={fontSize} whiteSpace='nowrap'>
{label}
</Text>
</FormLabel>
<Slider
aria-label={label}
focusThumbOnChange={true}
value={value}
onChange={onChange}
{...rest}
>
<SliderTrack>
<SliderFilledTrack />
</SliderTrack>
<SliderThumb />
</Slider>
</Flex>
</FormControl>
);
};
export default SDSlider;

View File

@ -0,0 +1,62 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { setCfgScale, setSampler, setSteps, SDState } from '../sd/sdSlice';
import SDNumberInput from '../../components/SDNumberInput';
import SDSelect from '../../components/SDSelect';
import { SAMPLERS } from '../../app/constants';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
steps: sd.steps,
cfgScale: sd.cfgScale,
sampler: sd.sampler,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const SamplerOptions = () => {
const { steps, cfgScale, sampler } = useAppSelector(sdSelector);
const dispatch = useAppDispatch();
return (
<Flex gap={2} direction={'column'}>
<SDNumberInput
label='Steps'
min={1}
step={1}
precision={0}
onChange={(v) => dispatch(setSteps(Number(v)))}
value={steps}
/>
<SDNumberInput
label='CFG scale'
step={0.5}
onChange={(v) => dispatch(setCfgScale(Number(v)))}
value={cfgScale}
/>
<SDSelect
label='Sampler'
value={sampler}
onChange={(e) => dispatch(setSampler(e.target.value))}
validValues={SAMPLERS}
/>
</Flex>
);
};
export default SamplerOptions;

View File

@ -0,0 +1,144 @@
import {
Flex,
Input,
HStack,
FormControl,
FormLabel,
Text,
Button,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from '../../app/constants';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import SDNumberInput from '../../components/SDNumberInput';
import SDSwitch from '../../components/SDSwitch';
import {
randomizeSeed,
SDState,
setIterations,
setSeed,
setSeedWeights,
setShouldGenerateVariations,
setShouldRandomizeSeed,
setVariantAmount,
} from './sdSlice';
import { validateSeedWeights } from './util/seedWeightPairs';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
variantAmount: sd.variantAmount,
seedWeights: sd.seedWeights,
shouldGenerateVariations: sd.shouldGenerateVariations,
shouldRandomizeSeed: sd.shouldRandomizeSeed,
seed: sd.seed,
iterations: sd.iterations,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const SeedVariationOptions = () => {
const {
shouldGenerateVariations,
variantAmount,
seedWeights,
shouldRandomizeSeed,
seed,
iterations,
} = useAppSelector(sdSelector);
const dispatch = useAppDispatch();
return (
<Flex gap={2} direction={'column'}>
<SDNumberInput
label='Images to generate'
step={1}
min={1}
precision={0}
onChange={(v) => dispatch(setIterations(Number(v)))}
value={iterations}
/>
<SDSwitch
label='Randomize seed on generation'
isChecked={shouldRandomizeSeed}
onChange={(e) =>
dispatch(setShouldRandomizeSeed(e.target.checked))
}
/>
<Flex gap={2}>
<SDNumberInput
label='Seed'
step={1}
precision={0}
flexGrow={1}
min={NUMPY_RAND_MIN}
max={NUMPY_RAND_MAX}
isDisabled={shouldRandomizeSeed}
isInvalid={seed < 0 && shouldGenerateVariations}
onChange={(v) => dispatch(setSeed(Number(v)))}
value={seed}
/>
<Button
size={'sm'}
isDisabled={shouldRandomizeSeed}
onClick={() => dispatch(randomizeSeed())}
>
<Text pl={2} pr={2}>
Shuffle
</Text>
</Button>
</Flex>
<SDSwitch
label='Generate variations'
isChecked={shouldGenerateVariations}
width={'auto'}
onChange={(e) =>
dispatch(setShouldGenerateVariations(e.target.checked))
}
/>
<SDNumberInput
label='Variation amount'
value={variantAmount}
step={0.01}
min={0}
max={1}
isDisabled={!shouldGenerateVariations}
onChange={(v) => dispatch(setVariantAmount(Number(v)))}
/>
<FormControl
isInvalid={
shouldGenerateVariations &&
!(validateSeedWeights(seedWeights) || seedWeights === '')
}
flexGrow={1}
isDisabled={!shouldGenerateVariations}
>
<HStack>
<FormLabel marginInlineEnd={0} marginBottom={1}>
<Text whiteSpace='nowrap'>
Seed Weights
</Text>
</FormLabel>
<Input
size={'sm'}
value={seedWeights}
onChange={(e) =>
dispatch(setSeedWeights(e.target.value))
}
/>
</HStack>
</FormControl>
</Flex>
);
};
export default SeedVariationOptions;

View File

@ -0,0 +1,92 @@
import {
Flex,
FormControl,
FormLabel,
HStack,
Input,
Text,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import SDNumberInput from '../../components/SDNumberInput';
import SDSwitch from '../../components/SDSwitch';
import {
SDState,
setSeedWeights,
setShouldGenerateVariations,
setVariantAmount,
} from './sdSlice';
import { validateSeedWeights } from './util/seedWeightPairs';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
variantAmount: sd.variantAmount,
seedWeights: sd.seedWeights,
shouldGenerateVariations: sd.shouldGenerateVariations,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const Variant = () => {
const { shouldGenerateVariations, variantAmount, seedWeights } =
useAppSelector(sdSelector);
const dispatch = useAppDispatch();
return (
<Flex gap={2} alignItems={'center'} pl={1}>
<SDSwitch
label='Generate variations'
isChecked={shouldGenerateVariations}
width={'auto'}
onChange={(e) =>
dispatch(setShouldGenerateVariations(e.target.checked))
}
/>
<SDNumberInput
label='Amount'
value={variantAmount}
step={0.01}
min={0}
max={1}
width={240}
isDisabled={!shouldGenerateVariations}
onChange={(v) => dispatch(setVariantAmount(Number(v)))}
/>
<FormControl
isInvalid={
shouldGenerateVariations &&
!(validateSeedWeights(seedWeights) || seedWeights === '')
}
flexGrow={1}
isDisabled={!shouldGenerateVariations}
>
<HStack>
<FormLabel marginInlineEnd={0} marginBottom={1}>
<Text fontSize={'sm'} whiteSpace='nowrap'>
Seed Weights
</Text>
</FormLabel>
<Input
size={'sm'}
value={seedWeights}
onChange={(e) =>
dispatch(setSeedWeights(e.target.value))
}
/>
</HStack>
</FormControl>
</Flex>
);
};
export default Variant;

View File

@ -0,0 +1,283 @@
import { createSlice } from '@reduxjs/toolkit';
import type { PayloadAction } from '@reduxjs/toolkit';
import { SDMetadata } from '../gallery/gallerySlice';
import randomInt from './util/randomInt';
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from '../../app/constants';
const calculateRealSteps = (
steps: number,
strength: number,
hasInitImage: boolean
): number => {
return hasInitImage ? Math.floor(strength * steps) : steps;
};
export type UpscalingLevel = 0 | 2 | 3 | 4;
export interface SDState {
prompt: string;
iterations: number;
steps: number;
realSteps: number;
cfgScale: number;
height: number;
width: number;
sampler: string;
seed: number;
img2imgStrength: number;
gfpganStrength: number;
upscalingLevel: UpscalingLevel;
upscalingStrength: number;
shouldUseInitImage: boolean;
initialImagePath: string;
maskPath: string;
seamless: boolean;
shouldFitToWidthHeight: boolean;
shouldGenerateVariations: boolean;
variantAmount: number;
seedWeights: string;
shouldRunESRGAN: boolean;
shouldRunGFPGAN: boolean;
shouldRandomizeSeed: boolean;
}
const initialSDState: SDState = {
prompt: '',
iterations: 1,
steps: 50,
realSteps: 50,
cfgScale: 7.5,
height: 512,
width: 512,
sampler: 'k_lms',
seed: 0,
seamless: false,
shouldUseInitImage: false,
img2imgStrength: 0.75,
initialImagePath: '',
maskPath: '',
shouldFitToWidthHeight: true,
shouldGenerateVariations: false,
variantAmount: 0.1,
seedWeights: '',
shouldRunESRGAN: false,
upscalingLevel: 4,
upscalingStrength: 0.75,
shouldRunGFPGAN: false,
gfpganStrength: 0.8,
shouldRandomizeSeed: true,
};
const initialState: SDState = initialSDState;
export const sdSlice = createSlice({
name: 'sd',
initialState,
reducers: {
setPrompt: (state, action: PayloadAction<string>) => {
state.prompt = action.payload;
},
setIterations: (state, action: PayloadAction<number>) => {
state.iterations = action.payload;
},
setSteps: (state, action: PayloadAction<number>) => {
const { img2imgStrength, initialImagePath } = state;
const steps = action.payload;
state.steps = steps;
state.realSteps = calculateRealSteps(
steps,
img2imgStrength,
Boolean(initialImagePath)
);
},
setCfgScale: (state, action: PayloadAction<number>) => {
state.cfgScale = action.payload;
},
setHeight: (state, action: PayloadAction<number>) => {
state.height = action.payload;
},
setWidth: (state, action: PayloadAction<number>) => {
state.width = action.payload;
},
setSampler: (state, action: PayloadAction<string>) => {
state.sampler = action.payload;
},
setSeed: (state, action: PayloadAction<number>) => {
state.seed = action.payload;
state.shouldRandomizeSeed = false;
},
setImg2imgStrength: (state, action: PayloadAction<number>) => {
const img2imgStrength = action.payload;
const { steps, initialImagePath } = state;
state.img2imgStrength = img2imgStrength;
state.realSteps = calculateRealSteps(
steps,
img2imgStrength,
Boolean(initialImagePath)
);
},
setGfpganStrength: (state, action: PayloadAction<number>) => {
state.gfpganStrength = action.payload;
},
setUpscalingLevel: (state, action: PayloadAction<UpscalingLevel>) => {
state.upscalingLevel = action.payload;
},
setUpscalingStrength: (state, action: PayloadAction<number>) => {
state.upscalingStrength = action.payload;
},
setShouldUseInitImage: (state, action: PayloadAction<boolean>) => {
state.shouldUseInitImage = action.payload;
},
setInitialImagePath: (state, action: PayloadAction<string>) => {
const initialImagePath = action.payload;
const { steps, img2imgStrength } = state;
state.shouldUseInitImage = initialImagePath ? true : false;
state.initialImagePath = initialImagePath;
state.realSteps = calculateRealSteps(
steps,
img2imgStrength,
Boolean(initialImagePath)
);
},
setMaskPath: (state, action: PayloadAction<string>) => {
state.maskPath = action.payload;
},
setSeamless: (state, action: PayloadAction<boolean>) => {
state.seamless = action.payload;
},
setShouldFitToWidthHeight: (state, action: PayloadAction<boolean>) => {
state.shouldFitToWidthHeight = action.payload;
},
resetSeed: (state) => {
state.seed = -1;
},
randomizeSeed: (state) => {
state.seed = randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX);
},
setParameter: (
state,
action: PayloadAction<{ key: string; value: string | number | boolean }>
) => {
const { key, value } = action.payload;
const temp = { ...state, [key]: value };
if (key === 'seed') {
temp.shouldRandomizeSeed = false;
}
if (key === 'initialImagePath' && value === '') {
temp.shouldUseInitImage = false;
}
return temp;
},
setShouldGenerateVariations: (state, action: PayloadAction<boolean>) => {
state.shouldGenerateVariations = action.payload;
},
setVariantAmount: (state, action: PayloadAction<number>) => {
state.variantAmount = action.payload;
},
setSeedWeights: (state, action: PayloadAction<string>) => {
state.seedWeights = action.payload;
},
setAllParameters: (state, action: PayloadAction<SDMetadata>) => {
const {
prompt,
steps,
cfgScale,
height,
width,
sampler,
seed,
img2imgStrength,
gfpganStrength,
upscalingLevel,
upscalingStrength,
initialImagePath,
maskPath,
seamless,
shouldFitToWidthHeight,
} = action.payload;
// ?? = falsy values ('', 0, etc) are used
// || = falsy values not used
state.prompt = prompt ?? state.prompt;
state.steps = steps || state.steps;
state.cfgScale = cfgScale || state.cfgScale;
state.width = width || state.width;
state.height = height || state.height;
state.sampler = sampler || state.sampler;
state.seed = seed ?? state.seed;
state.seamless = seamless ?? state.seamless;
state.shouldFitToWidthHeight =
shouldFitToWidthHeight ?? state.shouldFitToWidthHeight;
state.img2imgStrength = img2imgStrength ?? state.img2imgStrength;
state.gfpganStrength = gfpganStrength ?? state.gfpganStrength;
state.upscalingLevel = upscalingLevel ?? state.upscalingLevel;
state.upscalingStrength = upscalingStrength ?? state.upscalingStrength;
state.initialImagePath = initialImagePath ?? state.initialImagePath;
state.maskPath = maskPath ?? state.maskPath;
// If the image whose parameters we are using has a seed, disable randomizing the seed
if (seed) {
state.shouldRandomizeSeed = false;
}
// if we have a gfpgan strength, enable it
state.shouldRunGFPGAN = gfpganStrength ? true : false;
// if we have a esrgan strength, enable it
state.shouldRunESRGAN = upscalingLevel ? true : false;
// if we want to recreate an image exactly, we disable variations
state.shouldGenerateVariations = false;
state.shouldUseInitImage = initialImagePath ? true : false;
},
resetSDState: (state) => {
return {
...state,
...initialSDState,
};
},
setShouldRunGFPGAN: (state, action: PayloadAction<boolean>) => {
state.shouldRunGFPGAN = action.payload;
},
setShouldRunESRGAN: (state, action: PayloadAction<boolean>) => {
state.shouldRunESRGAN = action.payload;
},
setShouldRandomizeSeed: (state, action: PayloadAction<boolean>) => {
state.shouldRandomizeSeed = action.payload;
},
},
});
export const {
setPrompt,
setIterations,
setSteps,
setCfgScale,
setHeight,
setWidth,
setSampler,
setSeed,
setSeamless,
setImg2imgStrength,
setGfpganStrength,
setUpscalingLevel,
setUpscalingStrength,
setShouldUseInitImage,
setInitialImagePath,
setMaskPath,
resetSeed,
randomizeSeed,
resetSDState,
setShouldFitToWidthHeight,
setParameter,
setShouldGenerateVariations,
setSeedWeights,
setVariantAmount,
setAllParameters,
setShouldRunGFPGAN,
setShouldRunESRGAN,
setShouldRandomizeSeed,
} = sdSlice.actions;
export default sdSlice.reducer;

View File

@ -0,0 +1,5 @@
const randomInt = (min: number, max: number): number => {
return Math.floor(Math.random() * (max - min + 1) + min);
};
export default randomInt;

View File

@ -0,0 +1,56 @@
export interface SeedWeightPair {
seed: number;
weight: number;
}
export type SeedWeights = Array<Array<number>>;
export const stringToSeedWeights = (string: string): SeedWeights | boolean => {
const stringPairs = string.split(',');
const arrPairs = stringPairs.map((p) => p.split(':'));
const pairs = arrPairs.map((p) => {
return [parseInt(p[0]), parseFloat(p[1])];
});
if (!validateSeedWeights(pairs)) {
return false;
}
return pairs;
};
export const validateSeedWeights = (
seedWeights: SeedWeights | string
): boolean => {
return typeof seedWeights === 'string'
? Boolean(stringToSeedWeights(seedWeights))
: Boolean(
seedWeights.length &&
!seedWeights.some((pair) => {
const [seed, weight] = pair;
const isSeedValid = !isNaN(parseInt(seed.toString(), 10));
const isWeightValid =
!isNaN(parseInt(weight.toString(), 10)) &&
weight >= 0 &&
weight <= 1;
return !(isSeedValid && isWeightValid);
})
);
};
export const seedWeightsToString = (
seedWeights: SeedWeights
): string | boolean => {
if (!validateSeedWeights(seedWeights)) {
return false;
}
return seedWeights.reduce((acc, pair, i, arr) => {
const [seed, weight] = pair;
acc += `${seed}:${weight}`;
if (i !== arr.length - 1) {
acc += ',';
}
return acc;
}, '');
};

View File

@ -0,0 +1,125 @@
import {
IconButton,
useColorModeValue,
Flex,
Text,
Tooltip,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import { setShouldShowLogViewer, SystemState } from './systemSlice';
import { useLayoutEffect, useRef, useState } from 'react';
import { FaAngleDoubleDown, FaCode, FaMinus } from 'react-icons/fa';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
const logSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => system.log,
{
memoizeOptions: {
resultEqualityCheck: (a, b) => a.length === b.length,
},
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return { shouldShowLogViewer: system.shouldShowLogViewer };
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const LogViewer = () => {
const dispatch = useAppDispatch();
const bg = useColorModeValue('gray.50', 'gray.900');
const borderColor = useColorModeValue('gray.500', 'gray.500');
const [shouldAutoscroll, setShouldAutoscroll] = useState<boolean>(true);
const log = useAppSelector(logSelector);
const { shouldShowLogViewer } = useAppSelector(systemSelector);
const viewerRef = useRef<HTMLDivElement>(null);
useLayoutEffect(() => {
if (viewerRef.current !== null && shouldAutoscroll) {
viewerRef.current.scrollTop = viewerRef.current.scrollHeight;
}
});
return (
<>
{shouldShowLogViewer && (
<Flex
position={'fixed'}
left={0}
bottom={0}
height='200px'
width='100vw'
overflow='auto'
direction='column'
fontFamily='monospace'
fontSize='sm'
pl={12}
pr={2}
pb={2}
borderTopWidth='4px'
borderColor={borderColor}
background={bg}
ref={viewerRef}
>
{log.map((entry, i) => (
<Flex gap={2} key={i}>
<Text fontSize='sm' fontWeight={'semibold'}>
{entry.timestamp}:
</Text>
<Text fontSize='sm' wordBreak={'break-all'}>
{entry.message}
</Text>
</Flex>
))}
</Flex>
)}
{shouldShowLogViewer && (
<Tooltip
label={
shouldAutoscroll ? 'Autoscroll on' : 'Autoscroll off'
}
>
<IconButton
size='sm'
position={'fixed'}
left={2}
bottom={12}
aria-label='Toggle autoscroll'
variant={'solid'}
colorScheme={shouldAutoscroll ? 'blue' : 'gray'}
icon={<FaAngleDoubleDown />}
onClick={() => setShouldAutoscroll(!shouldAutoscroll)}
/>
</Tooltip>
)}
<Tooltip label={shouldShowLogViewer ? 'Hide logs' : 'Show logs'}>
<IconButton
size='sm'
position={'fixed'}
left={2}
bottom={2}
variant={'solid'}
aria-label='Toggle Log Viewer'
icon={shouldShowLogViewer ? <FaMinus /> : <FaCode />}
onClick={() =>
dispatch(setShouldShowLogViewer(!shouldShowLogViewer))
}
/>
</Tooltip>
</>
);
};
export default LogViewer;

View File

@ -0,0 +1,170 @@
import {
Flex,
FormControl,
FormLabel,
Heading,
HStack,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
Switch,
Text,
useDisclosure,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import {
setShouldConfirmOnDelete,
setShouldDisplayInProgress,
SystemState,
} from './systemSlice';
import { RootState } from '../../app/store';
import SDButton from '../../components/SDButton';
import { persistor } from '../../main';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { cloneElement, ReactElement } from 'react';
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
const { shouldDisplayInProgress, shouldConfirmOnDelete } = system;
return { shouldDisplayInProgress, shouldConfirmOnDelete };
},
{
memoizeOptions: { resultEqualityCheck: isEqual },
}
);
type Props = {
children: ReactElement;
};
const SettingsModal = ({ children }: Props) => {
const {
isOpen: isSettingsModalOpen,
onOpen: onSettingsModalOpen,
onClose: onSettingsModalClose,
} = useDisclosure();
const {
isOpen: isRefreshModalOpen,
onOpen: onRefreshModalOpen,
onClose: onRefreshModalClose,
} = useDisclosure();
const { shouldDisplayInProgress, shouldConfirmOnDelete } =
useAppSelector(systemSelector);
const dispatch = useAppDispatch();
const handleClickResetWebUI = () => {
persistor.purge().then(() => {
onSettingsModalClose();
onRefreshModalOpen();
});
};
return (
<>
{cloneElement(children, {
onClick: onSettingsModalOpen,
})}
<Modal isOpen={isSettingsModalOpen} onClose={onSettingsModalClose}>
<ModalOverlay />
<ModalContent>
<ModalHeader>Settings</ModalHeader>
<ModalCloseButton />
<ModalBody>
<Flex gap={5} direction='column'>
<FormControl>
<HStack>
<FormLabel marginBottom={1}>
Display in-progress images (slower)
</FormLabel>
<Switch
isChecked={shouldDisplayInProgress}
onChange={(e) =>
dispatch(
setShouldDisplayInProgress(
e.target.checked
)
)
}
/>
</HStack>
</FormControl>
<FormControl>
<HStack>
<FormLabel marginBottom={1}>
Confirm on delete
</FormLabel>
<Switch
isChecked={shouldConfirmOnDelete}
onChange={(e) =>
dispatch(
setShouldConfirmOnDelete(
e.target.checked
)
)
}
/>
</HStack>
</FormControl>
<Heading size={'md'}>Reset Web UI</Heading>
<Text>
Resetting the web UI only resets the browser's
local cache of your images and remembered
settings. It does not delete any images from
disk.
</Text>
<Text>
If images aren't showing up in the gallery or
something else isn't working, please try
resetting before submitting an issue on GitHub.
</Text>
<SDButton
label='Reset Web UI'
colorScheme='red'
onClick={handleClickResetWebUI}
/>
</Flex>
</ModalBody>
<ModalFooter>
<SDButton
label='Close'
onClick={onSettingsModalClose}
/>
</ModalFooter>
</ModalContent>
</Modal>
<Modal
closeOnOverlayClick={false}
isOpen={isRefreshModalOpen}
onClose={onRefreshModalClose}
isCentered
>
<ModalOverlay bg='blackAlpha.300' backdropFilter='blur(40px)' />
<ModalContent>
<ModalBody pb={6} pt={6}>
<Flex justifyContent={'center'}>
<Text fontSize={'lg'}>
Web UI has been reset. Refresh the page to
reload.
</Text>
</Flex>
</ModalBody>
</ModalContent>
</Modal>
</>
);
};
export default SettingsModal;

View File

@ -0,0 +1,98 @@
import { createSlice } from '@reduxjs/toolkit';
import type { PayloadAction } from '@reduxjs/toolkit';
import dateFormat from 'dateformat';
import { ExpandedIndex } from '@chakra-ui/react';
export interface LogEntry {
timestamp: string;
message: string;
}
export interface Log {
[index: number]: LogEntry;
}
export interface SystemState {
shouldDisplayInProgress: boolean;
isProcessing: boolean;
currentStep: number;
log: Array<LogEntry>;
shouldShowLogViewer: boolean;
isGFPGANAvailable: boolean;
isESRGANAvailable: boolean;
isConnected: boolean;
socketId: string;
shouldConfirmOnDelete: boolean;
openAccordions: ExpandedIndex;
}
const initialSystemState = {
isConnected: false,
isProcessing: false,
currentStep: 0,
log: [],
shouldShowLogViewer: false,
shouldDisplayInProgress: false,
isGFPGANAvailable: true,
isESRGANAvailable: true,
socketId: '',
shouldConfirmOnDelete: true,
openAccordions: [0],
};
const initialState: SystemState = initialSystemState;
export const systemSlice = createSlice({
name: 'system',
initialState,
reducers: {
setShouldDisplayInProgress: (state, action: PayloadAction<boolean>) => {
state.shouldDisplayInProgress = action.payload;
},
setIsProcessing: (state, action: PayloadAction<boolean>) => {
state.isProcessing = action.payload;
if (action.payload === false) {
state.currentStep = 0;
}
},
setCurrentStep: (state, action: PayloadAction<number>) => {
state.currentStep = action.payload;
},
addLogEntry: (state, action: PayloadAction<string>) => {
const entry: LogEntry = {
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: action.payload,
};
state.log.push(entry);
},
setShouldShowLogViewer: (state, action: PayloadAction<boolean>) => {
state.shouldShowLogViewer = action.payload;
},
setIsConnected: (state, action: PayloadAction<boolean>) => {
state.isConnected = action.payload;
},
setSocketId: (state, action: PayloadAction<string>) => {
state.socketId = action.payload;
},
setShouldConfirmOnDelete: (state, action: PayloadAction<boolean>) => {
state.shouldConfirmOnDelete = action.payload;
},
setOpenAccordions: (state, action: PayloadAction<ExpandedIndex>) => {
state.openAccordions = action.payload;
},
},
});
export const {
setShouldDisplayInProgress,
setIsProcessing,
setCurrentStep,
addLogEntry,
setShouldShowLogViewer,
setIsConnected,
setSocketId,
setShouldConfirmOnDelete,
setOpenAccordions,
} = systemSlice.actions;
export default systemSlice.reducer;

View File

@ -0,0 +1,108 @@
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useMemo } from 'react';
import { useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import { SDState } from '../sd/sdSlice';
import { validateSeedWeights } from '../sd/util/seedWeightPairs';
import { SystemState } from './systemSlice';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
prompt: sd.prompt,
shouldGenerateVariations: sd.shouldGenerateVariations,
seedWeights: sd.seedWeights,
maskPath: sd.maskPath,
initialImagePath: sd.initialImagePath,
seed: sd.seed,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isProcessing: system.isProcessing,
isConnected: system.isConnected,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/*
Checks relevant pieces of state to confirm generation will not deterministically fail.
This is used to prevent the 'Generate' button from being clicked.
Other parameter values may cause failure but we rely on input validation for those.
*/
const useCheckParameters = () => {
const {
prompt,
shouldGenerateVariations,
seedWeights,
maskPath,
initialImagePath,
seed,
} = useAppSelector(sdSelector);
const { isProcessing, isConnected } = useAppSelector(systemSelector);
return useMemo(() => {
// Cannot generate without a prompt
if (!prompt) {
return false;
}
// Cannot generate with a mask without img2img
if (maskPath && !initialImagePath) {
return false;
}
// TODO: job queue
// Cannot generate if already processing an image
if (isProcessing) {
return false;
}
// Cannot generate if not connected
if (!isConnected) {
return false;
}
// Cannot generate variations without valid seed weights
if (
shouldGenerateVariations &&
(!(validateSeedWeights(seedWeights) || seedWeights === '') ||
seed === -1)
) {
return false;
}
// All good
return true;
}, [
prompt,
maskPath,
initialImagePath,
isProcessing,
isConnected,
shouldGenerateVariations,
seedWeights,
seed,
]);
};
export default useCheckParameters;

26
frontend/src/main.tsx Normal file
View File

@ -0,0 +1,26 @@
import React from 'react';
import ReactDOM from 'react-dom/client';
import { ChakraProvider, ColorModeScript } from '@chakra-ui/react';
import { store } from './app/store';
import { Provider } from 'react-redux';
import { PersistGate } from 'redux-persist/integration/react';
import { persistStore } from 'redux-persist';
export const persistor = persistStore(store);
import App from './App';
import { theme } from './app/theme';
import Loading from './Loading';
ReactDOM.createRoot(document.getElementById('root') as HTMLElement).render(
<React.StrictMode>
<Provider store={store}>
<PersistGate loading={<Loading />} persistor={persistor}>
<ChakraProvider theme={theme}>
<ColorModeScript initialColorMode={theme.config.initialColorMode} />
<App />
</ChakraProvider>
</PersistGate>
</Provider>
</React.StrictMode>
);

1
frontend/src/vite-env.d.ts vendored Normal file
View File

@ -0,0 +1 @@
/// <reference types="vite/client" />

21
frontend/tsconfig.json Normal file
View File

@ -0,0 +1,21 @@
{
"compilerOptions": {
"target": "ESNext",
"useDefineForClassFields": true,
"lib": ["DOM", "DOM.Iterable", "ESNext"],
"allowJs": false,
"skipLibCheck": true,
"esModuleInterop": false,
"allowSyntheticDefaultImports": true,
"strict": true,
"forceConsistentCasingInFileNames": true,
"module": "ESNext",
"moduleResolution": "Node",
"resolveJsonModule": true,
"isolatedModules": true,
"noEmit": true,
"jsx": "react-jsx"
},
"include": ["src", "index.d.ts"],
"references": [{ "path": "./tsconfig.node.json" }]
}

View File

@ -0,0 +1,9 @@
{
"compilerOptions": {
"composite": true,
"module": "ESNext",
"moduleResolution": "Node",
"allowSyntheticDefaultImports": true
},
"include": ["vite.config.ts"]
}

36
frontend/vite.config.ts Normal file
View File

@ -0,0 +1,36 @@
import { defineConfig } from 'vite';
import react from '@vitejs/plugin-react';
import eslint from 'vite-plugin-eslint';
// https://vitejs.dev/config/
export default defineConfig(({ mode }) => {
const common = {
plugins: [react(), eslint()],
server: {
proxy: {
'/outputs': {
target: 'http://localhost:9090/outputs',
changeOrigin: true,
rewrite: (path) => path.replace(/^\/outputs/, ''),
},
},
},
build: {
target: 'esnext',
chunkSizeWarningLimit: 1500, // we don't really care about chunk size
},
};
if (mode == 'development') {
return {
...common,
build: {
...common.build,
// sourcemap: true, // this can be enabled if needed, it adds ovwer 15MB to the commit
},
};
} else {
return {
...common,
};
}
});

3149
frontend/yarn.lock Normal file

File diff suppressed because it is too large Load Diff

View File

@ -58,6 +58,3 @@ def retrieve_metadata(img_path):
md = im.text.get('sd-metadata',{}) md = im.text.get('sd-metadata',{})
return json.loads(md) return json.loads(md)

0
server/__init__.py Normal file
View File

149
server/application.py Normal file
View File

@ -0,0 +1,149 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
"""Application module."""
import argparse
import json
import os
import sys
from flask import Flask
from flask_cors import CORS
from flask_socketio import SocketIO, join_room, leave_room
from omegaconf import OmegaConf
from dependency_injector.wiring import inject, Provide
from server import views
from server.containers import Container
from server.services import GeneratorService, SignalService
# The socketio_service is injected here (rather than created in run_app) to initialize it
@inject
def initialize_app(
app: Flask,
socketio: SocketIO = Provide[Container.socketio]
) -> SocketIO:
socketio.init_app(app)
return socketio
# The signal and generator services are injected to warm up the processing queues
# TODO: Initialize these a better way?
@inject
def initialize_generator(
signal_service: SignalService = Provide[Container.signal_service],
generator_service: GeneratorService = Provide[Container.generator_service]
):
pass
def run_app(config, host, port) -> Flask:
app = Flask(__name__, static_url_path='')
# Set up dependency injection container
container = Container()
container.config.from_dict(config)
container.wire(modules=[__name__])
app.container = container
# Set up CORS
CORS(app, resources={r'/api/*': {'origins': '*'}})
# Web Routes
app.add_url_rule('/', view_func=views.WebIndex.as_view('web_index', 'index.html'))
app.add_url_rule('/index.css', view_func=views.WebIndex.as_view('web_index_css', 'index.css'))
app.add_url_rule('/index.js', view_func=views.WebIndex.as_view('web_index_js', 'index.js'))
app.add_url_rule('/config.js', view_func=views.WebConfig.as_view('web_config'))
# API Routes
app.add_url_rule('/api/jobs', view_func=views.ApiJobs.as_view('api_jobs'))
app.add_url_rule('/api/cancel', view_func=views.ApiCancel.as_view('api_cancel'))
# TODO: Get storage root from config
app.add_url_rule('/api/images/<string:dreamId>', view_func=views.ApiImages.as_view('api_images', '../'))
app.add_url_rule('/api/intermediates/<string:dreamId>/<string:step>', view_func=views.ApiIntermediates.as_view('api_intermediates', '../'))
app.static_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../static/dream_web/'))
# Initialize
socketio = initialize_app(app)
initialize_generator()
print(">> Started Stable Diffusion api server!")
if host == '0.0.0.0':
print(f"Point your browser at http://localhost:{port} or use the host's DNS name or IP address.")
else:
print(">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address.")
print(f">> Point your browser at http://{host}:{port}.")
# Run the app
socketio.run(app, host, port)
def main():
"""Initialize command-line parsers and the diffusion model"""
from scripts.dream import create_argv_parser
arg_parser = create_argv_parser()
opt = arg_parser.parse_args()
if opt.laion400m:
print('--laion400m flag has been deprecated. Please use --model laion400m instead.')
sys.exit(-1)
if opt.weights != 'model':
print('--weights argument has been deprecated. Please configure ./configs/models.yaml, and call it using --model instead.')
sys.exit(-1)
try:
models = OmegaConf.load(opt.config)
width = models[opt.model].width
height = models[opt.model].height
config = models[opt.model].config
weights = models[opt.model].weights
except (FileNotFoundError, IOError, KeyError) as e:
print(f'{e}. Aborting.')
sys.exit(-1)
print('* Initializing, be patient...\n')
sys.path.append('.')
from pytorch_lightning import logging
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
import transformers
transformers.logging.set_verbosity_error()
appConfig = {
"model": {
"width": width,
"height": height,
"sampler_name": opt.sampler_name,
"weights": weights,
"full_precision": opt.full_precision,
"config": config,
"grid": opt.grid,
"latent_diffusion_weights": opt.laion400m,
"embedding_path": opt.embedding_path,
"device_type": opt.device
}
}
# make sure the output directory exists
if not os.path.exists(opt.outdir):
os.makedirs(opt.outdir)
# gets rid of annoying messages about random seed
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
print('\n* starting api server...')
# Change working directory to the stable-diffusion directory
os.chdir(
os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
)
# Start server
try:
run_app(appConfig, opt.host, opt.port)
except KeyboardInterrupt:
pass
if __name__ == '__main__':
main()

75
server/containers.py Normal file
View File

@ -0,0 +1,75 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
"""Containers module."""
from dependency_injector import containers, providers
from flask_socketio import SocketIO
from ldm.generate import Generate
from server import services
class Container(containers.DeclarativeContainer):
wiring_config = containers.WiringConfiguration(packages=['server'])
config = providers.Configuration()
socketio = providers.ThreadSafeSingleton(
SocketIO,
app = None
)
model_singleton = providers.ThreadSafeSingleton(
Generate,
width = config.model.width,
height = config.model.height,
sampler_name = config.model.sampler_name,
weights = config.model.weights,
full_precision = config.model.full_precision,
config = config.model.config,
grid = config.model.grid,
seamless = config.model.seamless,
embedding_path = config.model.embedding_path,
device_type = config.model.device_type
)
# TODO: get location from config
image_storage_service = providers.ThreadSafeSingleton(
services.ImageStorageService,
'./outputs/img-samples/'
)
# TODO: get location from config
image_intermediates_storage_service = providers.ThreadSafeSingleton(
services.ImageStorageService,
'./outputs/intermediates/'
)
signal_queue_service = providers.ThreadSafeSingleton(
services.SignalQueueService
)
signal_service = providers.ThreadSafeSingleton(
services.SignalService,
socketio = socketio,
queue = signal_queue_service
)
generation_queue_service = providers.ThreadSafeSingleton(
services.JobQueueService
)
# TODO: get locations from config
log_service = providers.ThreadSafeSingleton(
services.LogService,
'./outputs/img-samples/',
'dream_web_log.txt'
)
generator_service = providers.ThreadSafeSingleton(
services.GeneratorService,
model = model_singleton,
queue = generation_queue_service,
imageStorage = image_storage_service,
intermediateStorage = image_intermediates_storage_service,
log = log_service,
signal_service = signal_service
)

128
server/models.py Normal file
View File

@ -0,0 +1,128 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import json
import string
from copy import deepcopy
from datetime import datetime, timezone
from enum import Enum
class DreamRequest():
prompt: string
initimg: string
strength: float
iterations: int
steps: int
width: int
height: int
fit = None
cfgscale: float
sampler_name: string
gfpgan_strength: float
upscale_level: int
upscale_strength: float
upscale: None
progress_images = None
seed: int
time: int
# TODO: use something else for state tracking
images_generated: int = 0
images_upscaled: int = 0
def id(self, seed = None, upscaled = False) -> str:
return f"{self.time}.{seed or self.seed}{'.u' if upscaled else ''}"
# TODO: handle this more cleanly (probably by splitting this into a Job and Result class)
# TODO: Set iterations to 1 or remove it from the dream result? And just keep it on the job?
def clone_without_image(self, seed = None):
data = deepcopy(self)
data.initimg = None
if seed:
data.seed = seed
return data
def to_json(self, seed: int = None):
copy = self.clone_without_image(seed)
return json.dumps(copy.__dict__)
@staticmethod
def from_json(j, newTime: bool = False):
d = DreamRequest()
d.prompt = j.get('prompt')
d.initimg = j.get('initimg')
d.strength = float(j.get('strength'))
d.iterations = int(j.get('iterations'))
d.steps = int(j.get('steps'))
d.width = int(j.get('width'))
d.height = int(j.get('height'))
d.fit = 'fit' in j
d.seamless = 'seamless' in j
d.cfgscale = float(j.get('cfgscale'))
d.sampler_name = j.get('sampler')
d.variation_amount = float(j.get('variation_amount'))
d.with_variations = j.get('with_variations')
d.gfpgan_strength = float(j.get('gfpgan_strength'))
d.upscale_level = j.get('upscale_level')
d.upscale_strength = j.get('upscale_strength')
d.upscale = [int(d.upscale_level),float(d.upscale_strength)] if d.upscale_level != '' else None
d.progress_images = 'progress_images' in j
d.seed = int(j.get('seed'))
d.time = int(datetime.now(timezone.utc).timestamp()) if newTime else int(j.get('time'))
return d
class ProgressType(Enum):
GENERATION = 1
UPSCALING_STARTED = 2
UPSCALING_DONE = 3
class Signal():
event: str
data = None
room: str = None
broadcast: bool = False
def __init__(self, event: str, data, room: str = None, broadcast: bool = False):
self.event = event
self.data = data
self.room = room
self.broadcast = broadcast
@staticmethod
def image_progress(jobId: str, dreamId: str, step: int, totalSteps: int, progressType: ProgressType = ProgressType.GENERATION, hasProgressImage: bool = False):
return Signal('dream_progress', {
'jobId': jobId,
'dreamId': dreamId,
'step': step,
'totalSteps': totalSteps,
'hasProgressImage': hasProgressImage,
'progressType': progressType.name
}, room=jobId, broadcast=True)
# TODO: use a result id or something? Like a sub-job
@staticmethod
def image_result(jobId: str, dreamId: str, dreamRequest: DreamRequest):
return Signal('dream_result', {
'jobId': jobId,
'dreamId': dreamId,
'dreamRequest': dreamRequest.__dict__
}, room=jobId, broadcast=True)
@staticmethod
def job_started(jobId: str):
return Signal('job_started', {
'jobId': jobId
}, room=jobId, broadcast=True)
@staticmethod
def job_done(jobId: str):
return Signal('job_done', {
'jobId': jobId
}, room=jobId, broadcast=True)
@staticmethod
def job_canceled(jobId: str):
return Signal('job_canceled', {
'jobId': jobId
}, room=jobId, broadcast=True)

265
server/services.py Normal file
View File

@ -0,0 +1,265 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import base64
import os
from queue import Empty, Queue
from threading import Thread
import time
from flask import app, url_for
from flask_socketio import SocketIO, join_room, leave_room
from ldm.dream.pngwriter import PngWriter
from ldm.dream.server import CanceledException
from ldm.generate import Generate
from server.models import DreamRequest, ProgressType, Signal
class JobQueueService:
__queue: Queue = Queue()
def push(self, dreamRequest: DreamRequest):
self.__queue.put(dreamRequest)
def get(self, timeout: float = None) -> DreamRequest:
return self.__queue.get(timeout= timeout)
class SignalQueueService:
__queue: Queue = Queue()
def push(self, signal: Signal):
self.__queue.put(signal)
def get(self) -> Signal:
return self.__queue.get(block=False)
class SignalService:
__socketio: SocketIO
__queue: SignalQueueService
def __init__(self, socketio: SocketIO, queue: SignalQueueService):
self.__socketio = socketio
self.__queue = queue
def on_join(data):
room = data['room']
join_room(room)
self.__socketio.emit("test", "something", room=room)
def on_leave(data):
room = data['room']
leave_room(room)
self.__socketio.on_event('join_room', on_join)
self.__socketio.on_event('leave_room', on_leave)
self.__socketio.start_background_task(self.__process)
def __process(self):
# preload the model
print('Started signal queue processor')
try:
while True:
try:
signal = self.__queue.get()
self.__socketio.emit(signal.event, signal.data, room=signal.room, broadcast=signal.broadcast)
except Empty:
pass
finally:
self.__socketio.sleep(0.001)
except KeyboardInterrupt:
print('Signal queue processor stopped')
def emit(self, signal: Signal):
self.__queue.push(signal)
# TODO: Name this better?
# TODO: Logging and signals should probably be event based (multiple listeners for an event)
class LogService:
__location: str
__logFile: str
def __init__(self, location:str, file:str):
self.__location = location
self.__logFile = file
def log(self, dreamRequest: DreamRequest, seed = None, upscaled = False):
with open(os.path.join(self.__location, self.__logFile), "a") as log:
log.write(f"{dreamRequest.id(seed, upscaled)}: {dreamRequest.to_json(seed)}\n")
class ImageStorageService:
__location: str
__pngWriter: PngWriter
def __init__(self, location):
self.__location = location
self.__pngWriter = PngWriter(self.__location)
def __getName(self, dreamId: str, postfix: str = '') -> str:
return f'{dreamId}{postfix}.png'
def save(self, image, dreamRequest, seed = None, upscaled = False, postfix: str = '', metadataPostfix: str = '') -> str:
name = self.__getName(dreamRequest.id(seed, upscaled), postfix)
path = self.__pngWriter.save_image_and_prompt_to_png(image, f'{dreamRequest.prompt} -S{seed or dreamRequest.seed}{metadataPostfix}', name)
return path
def path(self, dreamId: str, postfix: str = '') -> str:
name = self.__getName(dreamId, postfix)
path = os.path.join(self.__location, name)
return path
class GeneratorService:
__model: Generate
__queue: JobQueueService
__imageStorage: ImageStorageService
__intermediateStorage: ImageStorageService
__log: LogService
__thread: Thread
__cancellationRequested: bool = False
__signal_service: SignalService
def __init__(self, model: Generate, queue: JobQueueService, imageStorage: ImageStorageService, intermediateStorage: ImageStorageService, log: LogService, signal_service: SignalService):
self.__model = model
self.__queue = queue
self.__imageStorage = imageStorage
self.__intermediateStorage = intermediateStorage
self.__log = log
self.__signal_service = signal_service
# Create the background thread
self.__thread = Thread(target=self.__process, name = "GeneratorService")
self.__thread.daemon = True
self.__thread.start()
# Request cancellation of the current job
def cancel(self):
self.__cancellationRequested = True
# TODO: Consider moving this to its own service if there's benefit in separating the generator
def __process(self):
# preload the model
print('Preloading model')
tic = time.time()
self.__model.load_model()
print(
f'>> model loaded in', '%4.2fs' % (time.time() - tic)
)
print('Started generation queue processor')
try:
while True:
dreamRequest = self.__queue.get()
self.__generate(dreamRequest)
except KeyboardInterrupt:
print('Generation queue processor stopped')
def __start(self, dreamRequest: DreamRequest):
if dreamRequest.start_callback:
dreamRequest.start_callback()
self.__signal_service.emit(Signal.job_started(dreamRequest.id()))
def __done(self, dreamRequest: DreamRequest, image, seed, upscaled=False):
self.__imageStorage.save(image, dreamRequest, seed, upscaled)
# TODO: handle upscaling logic better (this is appending data to log, but only on first generation)
if not upscaled:
self.__log.log(dreamRequest, seed, upscaled)
self.__signal_service.emit(Signal.image_result(dreamRequest.id(), dreamRequest.id(seed, upscaled), dreamRequest.clone_without_image(seed)))
upscaling_requested = dreamRequest.upscale or dreamRequest.gfpgan_strength>0
if upscaled:
dreamRequest.images_upscaled += 1
else:
dreamRequest.images_generated +=1
if upscaling_requested:
# action = None
if dreamRequest.images_generated >= dreamRequest.iterations:
progressType = ProgressType.UPSCALING_DONE
if dreamRequest.images_upscaled < dreamRequest.iterations:
progressType = ProgressType.UPSCALING_STARTED
self.__signal_service.emit(Signal.image_progress(dreamRequest.id(), dreamRequest.id(seed), dreamRequest.images_upscaled+1, dreamRequest.iterations, progressType))
def __progress(self, dreamRequest, sample, step):
if self.__cancellationRequested:
self.__cancellationRequested = False
raise CanceledException
hasProgressImage = False
if dreamRequest.progress_images and step % 5 == 0 and step < dreamRequest.steps - 1:
image = self.__model._sample_to_image(sample)
self.__intermediateStorage.save(image, dreamRequest, self.__model.seed, postfix=f'.{step}', metadataPostfix=f' [intermediate]')
hasProgressImage = True
self.__signal_service.emit(Signal.image_progress(dreamRequest.id(), dreamRequest.id(self.__model.seed), step, dreamRequest.steps, ProgressType.GENERATION, hasProgressImage))
def __generate(self, dreamRequest: DreamRequest):
try:
initimgfile = None
if dreamRequest.initimg is not None:
with open("./img2img-tmp.png", "wb") as f:
initimg = dreamRequest.initimg.split(",")[1] # Ignore mime type
f.write(base64.b64decode(initimg))
initimgfile = "./img2img-tmp.png"
# Get a random seed if we don't have one yet
# TODO: handle "previous" seed usage?
if dreamRequest.seed == -1:
dreamRequest.seed = self.__model.seed
# Zero gfpgan strength if the model doesn't exist
# TODO: determine if this could be at the top now? Used to cause circular import
from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
if not gfpgan_model_exists:
dreamRequest.gfpgan_strength = 0
self.__start(dreamRequest)
self.__model.prompt2image(
prompt = dreamRequest.prompt,
init_img = initimgfile, # TODO: ensure this works
strength = None if initimgfile is None else dreamRequest.strength,
fit = None if initimgfile is None else dreamRequest.fit,
iterations = dreamRequest.iterations,
cfg_scale = dreamRequest.cfgscale,
width = dreamRequest.width,
height = dreamRequest.height,
seed = dreamRequest.seed,
steps = dreamRequest.steps,
variation_amount = dreamRequest.variation_amount,
with_variations = dreamRequest.with_variations,
gfpgan_strength = dreamRequest.gfpgan_strength,
upscale = dreamRequest.upscale,
sampler_name = dreamRequest.sampler_name,
seamless = dreamRequest.seamless,
step_callback = lambda sample, step: self.__progress(dreamRequest, sample, step),
image_callback = lambda image, seed, upscaled=False: self.__done(dreamRequest, image, seed, upscaled))
except CanceledException:
if dreamRequest.cancelled_callback:
dreamRequest.cancelled_callback()
self.__signal_service.emit(Signal.job_canceled(dreamRequest.id()))
finally:
if dreamRequest.done_callback:
dreamRequest.done_callback()
self.__signal_service.emit(Signal.job_done(dreamRequest.id()))
# Remove the temp file
if (initimgfile is not None):
os.remove("./img2img-tmp.png")

99
server/views.py Normal file
View File

@ -0,0 +1,99 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
"""Views module."""
import json
import os
from queue import Queue
from flask import current_app, jsonify, request, Response, send_from_directory, stream_with_context, url_for
from flask.views import MethodView
from dependency_injector.wiring import inject, Provide
from server.models import DreamRequest
from server.services import GeneratorService, ImageStorageService, JobQueueService
from server.containers import Container
class ApiJobs(MethodView):
@inject
def post(self, job_queue_service: JobQueueService = Provide[Container.generation_queue_service]):
dreamRequest = DreamRequest.from_json(request.json, newTime = True)
#self.canceled.clear()
print(f">> Request to generate with prompt: {dreamRequest.prompt}")
q = Queue()
dreamRequest.start_callback = None
dreamRequest.image_callback = None
dreamRequest.progress_callback = None
dreamRequest.cancelled_callback = None
dreamRequest.done_callback = None
# Push the request
job_queue_service.push(dreamRequest)
return { 'dreamId': dreamRequest.id() }
class WebIndex(MethodView):
init_every_request = False
__file: str = None
def __init__(self, file):
self.__file = file
def get(self):
return current_app.send_static_file(self.__file)
class WebConfig(MethodView):
init_every_request = False
def get(self):
# unfortunately this import can't be at the top level, since that would cause a circular import
from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
config = {
'gfpgan_model_exists': gfpgan_model_exists
}
js = f"let config = {json.dumps(config)};\n"
return Response(js, mimetype="application/javascript")
class ApiCancel(MethodView):
init_every_request = False
@inject
def get(self, generator_service: GeneratorService = Provide[Container.generator_service]):
generator_service.cancel()
return Response(status=204)
class ApiImages(MethodView):
init_every_request = False
__pathRoot = None
__storage: ImageStorageService
@inject
def __init__(self, pathBase, storage: ImageStorageService = Provide[Container.image_storage_service]):
self.__pathRoot = os.path.abspath(os.path.join(os.path.dirname(__file__), pathBase))
self.__storage = storage
def get(self, dreamId):
name = self.__storage.path(dreamId)
fullpath=os.path.join(self.__pathRoot, name)
return send_from_directory(os.path.dirname(fullpath), os.path.basename(fullpath))
class ApiIntermediates(MethodView):
init_every_request = False
__pathRoot = None
__storage: ImageStorageService = Provide[Container.image_intermediates_storage_service]
@inject
def __init__(self, pathBase, storage: ImageStorageService = Provide[Container.image_intermediates_storage_service]):
self.__pathRoot = os.path.abspath(os.path.join(os.path.dirname(__file__), pathBase))
self.__storage = storage
def get(self, dreamId, step):
name = self.__storage.path(dreamId, postfix=f'.{step}')
fullpath=os.path.join(self.__pathRoot, name)
return send_from_directory(os.path.dirname(fullpath), os.path.basename(fullpath))

View File

@ -2,11 +2,13 @@
<head> <head>
<title>Stable Diffusion Dream Server</title> <title>Stable Diffusion Dream Server</title>
<meta charset="utf-8"> <meta charset="utf-8">
<link rel="icon" type="image/x-icon" href="static/dream_web/favicon.ico" /> <link rel="icon" href="data:,">
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="stylesheet" href="static/dream_web/index.css">
<link rel="stylesheet" href="index.css">
<script src="config.js"></script> <script src="config.js"></script>
<script src="static/dream_web/index.js"></script> <script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.0.1/socket.io.js" integrity="sha512-q/dWJ3kcmjBLU4Qc47E4A9kTB4m3wuTY7vkFJDTZKjTs8jhyGQnaUrxa0Ytd0ssMZhbNua9hE+E7Qv1j+DyZwA==" crossorigin="anonymous"></script>
<script src="index.js"></script>
</head> </head>
<body> <body>
<header> <header>
@ -17,7 +19,7 @@
</header> </header>
<main> <main>
<form id="generate-form" method="post" action="#"> <form id="generate-form" method="post" action="api/jobs">
<fieldset id="txt2img"> <fieldset id="txt2img">
<div id="search-box"> <div id="search-box">
<textarea rows="3" id="prompt" name="prompt"></textarea> <textarea rows="3" id="prompt" name="prompt"></textarea>
@ -30,10 +32,10 @@
<input value="1" type="number" id="iterations" name="iterations" size="4"> <input value="1" type="number" id="iterations" name="iterations" size="4">
<label for="steps">Steps:</label> <label for="steps">Steps:</label>
<input value="50" type="number" id="steps" name="steps"> <input value="50" type="number" id="steps" name="steps">
<label for="cfg_scale">Cfg Scale:</label> <label for="cfgscale">Cfg Scale:</label>
<input value="7.5" type="number" id="cfg_scale" name="cfg_scale" step="any"> <input value="7.5" type="number" id="cfgscale" name="cfgscale" step="any">
<label for="sampler_name">Sampler:</label> <label for="sampler">Sampler:</label>
<select id="sampler_name" name="sampler_name" value="k_lms"> <select id="sampler" name="sampler" value="k_lms">
<option value="ddim">DDIM</option> <option value="ddim">DDIM</option>
<option value="plms">PLMS</option> <option value="plms">PLMS</option>
<option value="k_lms" selected>KLMS</option> <option value="k_lms" selected>KLMS</option>
@ -68,8 +70,8 @@
<option value="832">832</option> <option value="896">896</option> <option value="832">832</option> <option value="896">896</option>
<option value="960">960</option> <option value="1024">1024</option> <option value="960">960</option> <option value="1024">1024</option>
</select> </select>
<label title="Set to -1 for random seed" for="seed">Seed:</label> <label title="Set to 0 for random seed" for="seed">Seed:</label>
<input value="-1" type="number" id="seed" name="seed"> <input value="0" type="number" id="seed" name="seed">
<button type="button" id="reset-seed">&olarr;</button> <button type="button" id="reset-seed">&olarr;</button>
<input type="checkbox" name="progress_images" id="progress_images"> <input type="checkbox" name="progress_images" id="progress_images">
<label for="progress_images">Display in-progress images (slower)</label> <label for="progress_images">Display in-progress images (slower)</label>
@ -114,7 +116,7 @@
<br> <br>
<img id="progress-image" src='data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg"/>'> <img id="progress-image" src='data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg"/>'>
<div id="scaling-inprocess-message"> <div id="scaling-inprocess-message">
<i><span>Postprocessing...</span><span id="processing_cnt">1/3</span></i> <i><span>Postprocessing...</span><span id="processing_cnt">1</span>/<span id="processing_total">3</span></i>
</div> </div>
</span> </span>
</section> </section>

View File

@ -1,3 +1,41 @@
const socket = io();
function resetForm() {
var form = document.getElementById('generate-form');
form.querySelector('fieldset').removeAttribute('disabled');
}
function initProgress(totalSteps, showProgressImages) {
// TODO: Progress could theoretically come from multiple jobs at the same time (in the future)
let progressSectionEle = document.querySelector('#progress-section');
progressSectionEle.style.display = 'initial';
let progressEle = document.querySelector('#progress-bar');
progressEle.setAttribute('max', totalSteps);
let progressImageEle = document.querySelector('#progress-image');
progressImageEle.src = BLANK_IMAGE_URL;
progressImageEle.style.display = showProgressImages ? 'initial': 'none';
}
function setProgress(step, totalSteps, src) {
let progressEle = document.querySelector('#progress-bar');
progressEle.setAttribute('value', step);
if (src) {
let progressImageEle = document.querySelector('#progress-image');
progressImageEle.src = src;
}
}
function resetProgress(hide = true) {
if (hide) {
let progressSectionEle = document.querySelector('#progress-section');
progressSectionEle.style.display = 'none';
}
let progressEle = document.querySelector('#progress-bar');
progressEle.setAttribute('value', 0);
}
function toBase64(file) { function toBase64(file) {
return new Promise((resolve, reject) => { return new Promise((resolve, reject) => {
const r = new FileReader(); const r = new FileReader();
@ -9,29 +47,17 @@ function toBase64(file) {
function appendOutput(src, seed, config) { function appendOutput(src, seed, config) {
let outputNode = document.createElement("figure"); let outputNode = document.createElement("figure");
let altText = seed.toString() + " | " + config.prompt;
let variations = config.with_variations;
if (config.variation_amount > 0) {
variations = (variations ? variations + ',' : '') + seed + ':' + config.variation_amount;
}
let baseseed = (config.with_variations || config.variation_amount > 0) ? config.seed : seed;
let altText = baseseed + ' | ' + (variations ? variations + ' | ' : '') + config.prompt;
// img needs width and height for lazy loading to work
const figureContents = ` const figureContents = `
<a href="${src}" target="_blank"> <a href="${src}" target="_blank">
<img src="${src}" <img src="${src}" alt="${altText}" title="${altText}">
alt="${altText}"
title="${altText}"
loading="lazy"
width="256"
height="256">
</a> </a>
<figcaption>${seed}</figcaption> <figcaption>${seed}</figcaption>
`; `;
outputNode.innerHTML = figureContents; outputNode.innerHTML = figureContents;
let figcaption = outputNode.querySelector('figcaption'); let figcaption = outputNode.querySelector('figcaption')
// Reload image config // Reload image config
figcaption.addEventListener('click', () => { figcaption.addEventListener('click', () => {
@ -40,17 +66,28 @@ function appendOutput(src, seed, config) {
if (k == 'initimg') { continue; } if (k == 'initimg') { continue; }
form.querySelector(`*[name=${k}]`).value = config[k]; form.querySelector(`*[name=${k}]`).value = config[k];
} }
if (config.variation_amount > 0 || config.with_variations != '') {
document.querySelector("#seed").value = config.seed;
} else {
document.querySelector("#seed").value = seed;
}
document.querySelector("#seed").value = baseseed; if (config.variation_amount > 0) {
document.querySelector("#with_variations").value = variations || ''; let oldVarAmt = document.querySelector("#variation_amount").value
if (document.querySelector("#variation_amount").value <= 0) { let oldVariations = document.querySelector("#with_variations").value
document.querySelector("#variation_amount").value = 0.2; let varSep = ''
document.querySelector("#variation_amount").value = 0;
if (document.querySelector("#with_variations").value != '') {
varSep = ","
}
document.querySelector("#with_variations").value = oldVariations + varSep + seed + ':' + config.variation_amount
} }
saveFields(document.querySelector("#generate-form")); saveFields(document.querySelector("#generate-form"));
}); });
document.querySelector("#results").prepend(outputNode); document.querySelector("#results").prepend(outputNode);
document.querySelector("#no-results-message")?.remove();
} }
function saveFields(form) { function saveFields(form) {
@ -79,9 +116,8 @@ function clearFields(form) {
const BLANK_IMAGE_URL = 'data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg"/>'; const BLANK_IMAGE_URL = 'data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg"/>';
async function generateSubmit(form) { async function generateSubmit(form) {
const prompt = document.querySelector("#prompt").value;
// Convert file data to base64 // Convert file data to base64
// TODO: Should probably uplaod files with formdata or something, and store them in the backend?
let formData = Object.fromEntries(new FormData(form)); let formData = Object.fromEntries(new FormData(form));
formData.initimg_name = formData.initimg.name formData.initimg_name = formData.initimg.name
formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null; formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null;
@ -89,90 +125,75 @@ async function generateSubmit(form) {
let strength = formData.strength; let strength = formData.strength;
let totalSteps = formData.initimg ? Math.floor(strength * formData.steps) : formData.steps; let totalSteps = formData.initimg ? Math.floor(strength * formData.steps) : formData.steps;
let progressSectionEle = document.querySelector('#progress-section'); // Initialize the progress bar
progressSectionEle.style.display = 'initial'; initProgress(totalSteps);
let progressEle = document.querySelector('#progress-bar');
progressEle.setAttribute('max', totalSteps);
let progressImageEle = document.querySelector('#progress-image');
progressImageEle.src = BLANK_IMAGE_URL;
progressImageEle.style.display = {}.hasOwnProperty.call(formData, 'progress_images') ? 'initial': 'none'; // POST, use response to listen for events
// Post as JSON, using Fetch streaming to get results
fetch(form.action, { fetch(form.action, {
method: form.method, method: form.method,
headers: new Headers({'content-type': 'application/json'}),
body: JSON.stringify(formData), body: JSON.stringify(formData),
}).then(async (response) => { })
const reader = response.body.getReader(); .then(response => response.json())
.then(data => {
let noOutputs = true; var dreamId = data.dreamId;
while (true) { socket.emit('join_room', { 'room': dreamId });
let {value, done} = await reader.read();
value = new TextDecoder().decode(value);
if (done) {
progressSectionEle.style.display = 'none';
break;
}
for (let event of value.split('\n').filter(e => e !== '')) {
const data = JSON.parse(event);
if (data.event === 'result') {
noOutputs = false;
appendOutput(data.url, data.seed, data.config);
progressEle.setAttribute('value', 0);
progressEle.setAttribute('max', totalSteps);
} else if (data.event === 'upscaling-started') {
document.getElementById("processing_cnt").textContent=data.processed_file_cnt;
document.getElementById("scaling-inprocess-message").style.display = "block";
} else if (data.event === 'upscaling-done') {
document.getElementById("scaling-inprocess-message").style.display = "none";
} else if (data.event === 'step') {
progressEle.setAttribute('value', data.step);
if (data.url) {
progressImageEle.src = data.url;
}
} else if (data.event === 'canceled') {
// avoid alerting as if this were an error case
noOutputs = false;
}
}
}
// Re-enable form, remove no-results-message
form.querySelector('fieldset').removeAttribute('disabled');
document.querySelector("#prompt").value = prompt;
document.querySelector('progress').setAttribute('value', '0');
if (noOutputs) {
alert("Error occurred while generating.");
}
}); });
// Disable form while generating
form.querySelector('fieldset').setAttribute('disabled',''); form.querySelector('fieldset').setAttribute('disabled','');
document.querySelector("#prompt").value = `Generating: "${prompt}"`;
} }
async function fetchRunLog() { // Socket listeners
try { socket.on('job_started', (data) => {})
let response = await fetch('/run_log.json')
const data = await response.json();
for(let item of data.run_log) {
appendOutput(item.url, item.seed, item);
}
} catch (e) {
console.error(e);
}
}
window.onload = async () => { socket.on('dream_result', (data) => {
document.querySelector("#prompt").addEventListener("keydown", (e) => { var jobId = data.jobId;
if (e.key === "Enter" && !e.shiftKey) { var dreamId = data.dreamId;
const form = e.target.form; var dreamRequest = data.dreamRequest;
generateSubmit(form); var src = 'api/images/' + dreamId;
appendOutput(src, dreamRequest.seed, dreamRequest);
resetProgress(false);
})
socket.on('dream_progress', (data) => {
// TODO: it'd be nice if we could get a seed reported here, but the generator would need to be updated
var step = data.step;
var totalSteps = data.totalSteps;
var jobId = data.jobId;
var dreamId = data.dreamId;
var progressType = data.progressType
if (progressType === 'GENERATION') {
var src = data.hasProgressImage ?
'api/intermediates/' + dreamId + '/' + step
: null;
setProgress(step, totalSteps, src);
} else if (progressType === 'UPSCALING_STARTED') {
// step and totalSteps are used for upscale count on this message
document.getElementById("processing_cnt").textContent = step;
document.getElementById("processing_total").textContent = totalSteps;
document.getElementById("scaling-inprocess-message").style.display = "block";
} else if (progressType == 'UPSCALING_DONE') {
document.getElementById("scaling-inprocess-message").style.display = "none";
} }
}); })
socket.on('job_canceled', (data) => {
resetForm();
resetProgress();
})
socket.on('job_done', (data) => {
jobId = data.jobId
socket.emit('leave_room', { 'room': jobId });
resetForm();
resetProgress();
})
window.onload = () => {
document.querySelector("#generate-form").addEventListener('submit', (e) => { document.querySelector("#generate-form").addEventListener('submit', (e) => {
e.preventDefault(); e.preventDefault();
const form = e.target; const form = e.target;
@ -183,7 +204,7 @@ window.onload = async () => {
saveFields(e.target.form); saveFields(e.target.form);
}); });
document.querySelector("#reset-seed").addEventListener('click', (e) => { document.querySelector("#reset-seed").addEventListener('click', (e) => {
document.querySelector("#seed").value = -1; document.querySelector("#seed").value = 0;
saveFields(e.target.form); saveFields(e.target.form);
}); });
document.querySelector("#reset-all").addEventListener('click', (e) => { document.querySelector("#reset-all").addEventListener('click', (e) => {
@ -199,15 +220,8 @@ window.onload = async () => {
console.error(e); console.error(e);
}); });
}); });
document.documentElement.addEventListener('keydown', (e) => {
if (e.key === "Escape")
fetch('/cancel').catch(err => {
console.error(err);
});
});
if (!config.gfpgan_model_exists) { if (!config.gfpgan_model_exists) {
document.querySelector("#gfpgan").style.display = 'none'; document.querySelector("#gfpgan").style.display = 'none';
} }
await fetchRunLog()
}; };