Merge branch 'development' into mkdocs-updates

This commit is contained in:
Lincoln Stein 2022-09-20 17:11:43 -04:00 committed by GitHub
commit 555f21cd25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
104 changed files with 5425 additions and 4364 deletions

View File

@ -5,8 +5,7 @@ SAMPLES_DIR=${OUT_DIR}
python scripts/dream.py \ python scripts/dream.py \
--from_file ${PROMPT_FILE} \ --from_file ${PROMPT_FILE} \
--outdir ${OUT_DIR} \ --outdir ${OUT_DIR} \
--sampler plms \ --sampler plms
--full_precision
# original output by CompVis/stable-diffusion # original output by CompVis/stable-diffusion
IMAGE1=".dev_scripts/images/v1_4_astronaut_rides_horse_plms_step50_seed42.png" IMAGE1=".dev_scripts/images/v1_4_astronaut_rides_horse_plms_step50_seed42.png"

View File

@ -85,9 +85,9 @@ jobs:
fi fi
# Utterly hacky, but I don't know how else to do this # Utterly hacky, but I don't know how else to do this
if [[ ${{ github.ref }} == 'refs/heads/master' ]]; then if [[ ${{ github.ref }} == 'refs/heads/master' ]]; then
time ${{ steps.vars.outputs.PYTHON_BIN }} scripts/dream.py --from_file tests/preflight_prompts.txt --full_precision time ${{ steps.vars.outputs.PYTHON_BIN }} scripts/dream.py --from_file tests/preflight_prompts.txt
elif [[ ${{ github.ref }} == 'refs/heads/development' ]]; then elif [[ ${{ github.ref }} == 'refs/heads/development' ]]; then
time ${{ steps.vars.outputs.PYTHON_BIN }} scripts/dream.py --from_file tests/dev_prompts.txt --full_precision time ${{ steps.vars.outputs.PYTHON_BIN }} scripts/dream.py --from_file tests/dev_prompts.txt
fi fi
mkdir -p outputs/img-samples mkdir -p outputs/img-samples
- name: Archive results - name: Archive results

View File

@ -86,17 +86,14 @@ You wil need one of the following:
- At least 6 GB of free disk space for the machine learning model, Python, and all its dependencies. - At least 6 GB of free disk space for the machine learning model, Python, and all its dependencies.
> Note #### Note
>
> If you have an Nvidia 10xx series card (e.g. the 1080ti), please run the dream script in
> full-precision mode as shown below.
Similarly, specify full-precision mode on Apple M1 hardware. Precision is auto configured based on the device. If however you encounter
errors like 'expected type Float but found Half' or 'not implemented for Half'
To run in full-precision mode, start `dream.py` with the `--full_precision` flag: you can try starting `dream.py` with the `--precision=float32` flag:
```bash ```bash
(ldm) ~/stable-diffusion$ python scripts/dream.py --full_precision (ldm) ~/stable-diffusion$ python scripts/dream.py --precision=float32
``` ```
### Features ### Features
@ -125,6 +122,11 @@ To run in full-precision mode, start `dream.py` with the `--full_precision` flag
### Latest Changes ### Latest Changes
- vNEXT (TODO 2022)
- Deprecated `--full_precision` / `-F`. Simply omit it and `dream.py` will auto
configure. To switch away from auto use the new flag like `--precision=float32`.
- v1.14 (11 September 2022) - v1.14 (11 September 2022)
- Memory optimizations for small-RAM cards. 512x512 now possible on 4 GB GPUs. - Memory optimizations for small-RAM cards. 512x512 now possible on 4 GB GPUs.

View File

@ -2,14 +2,14 @@ from modules.parse_seed_weights import parse_seed_weights
import argparse import argparse
SAMPLER_CHOICES = [ SAMPLER_CHOICES = [
'ddim', "ddim",
'k_dpm_2_a', "k_dpm_2_a",
'k_dpm_2', "k_dpm_2",
'k_euler_a', "k_euler_a",
'k_euler', "k_euler",
'k_heun', "k_heun",
'k_lms', "k_lms",
'plms', "plms",
] ]
@ -20,194 +20,42 @@ def parameters_to_command(params):
switches = list() switches = list()
if 'prompt' in params: if "prompt" in params:
switches.append(f'"{params["prompt"]}"') switches.append(f'"{params["prompt"]}"')
if 'steps' in params: if "steps" in params:
switches.append(f'-s {params["steps"]}') switches.append(f'-s {params["steps"]}')
if 'seed' in params: if "seed" in params:
switches.append(f'-S {params["seed"]}') switches.append(f'-S {params["seed"]}')
if 'width' in params: if "width" in params:
switches.append(f'-W {params["width"]}') switches.append(f'-W {params["width"]}')
if 'height' in params: if "height" in params:
switches.append(f'-H {params["height"]}') switches.append(f'-H {params["height"]}')
if 'cfg_scale' in params: if "cfg_scale" in params:
switches.append(f'-C {params["cfg_scale"]}') switches.append(f'-C {params["cfg_scale"]}')
if 'sampler_name' in params: if "sampler_name" in params:
switches.append(f'-A {params["sampler_name"]}') switches.append(f'-A {params["sampler_name"]}')
if 'seamless' in params and params["seamless"] == True: if "seamless" in params and params["seamless"] == True:
switches.append(f'--seamless') switches.append(f"--seamless")
if 'init_img' in params and len(params['init_img']) > 0: if "init_img" in params and len(params["init_img"]) > 0:
switches.append(f'-I {params["init_img"]}') switches.append(f'-I {params["init_img"]}')
if 'init_mask' in params and len(params['init_mask']) > 0: if "init_mask" in params and len(params["init_mask"]) > 0:
switches.append(f'-M {params["init_mask"]}') switches.append(f'-M {params["init_mask"]}')
if 'init_color' in params and len(params['init_color']) > 0: if "init_color" in params and len(params["init_color"]) > 0:
switches.append(f'--init_color {params["init_color"]}') switches.append(f'--init_color {params["init_color"]}')
if 'strength' in params and 'init_img' in params: if "strength" in params and "init_img" in params:
switches.append(f'-f {params["strength"]}') switches.append(f'-f {params["strength"]}')
if 'fit' in params and params["fit"] == True: if "fit" in params and params["fit"] == True:
switches.append(f'--fit') switches.append(f"--fit")
if 'gfpgan_strength' in params and params["gfpgan_strength"]: if "gfpgan_strength" in params and params["gfpgan_strength"]:
switches.append(f'-G {params["gfpgan_strength"]}') switches.append(f'-G {params["gfpgan_strength"]}')
if 'upscale' in params and params["upscale"]: if "upscale" in params and params["upscale"]:
switches.append(f'-U {params["upscale"][0]} {params["upscale"][1]}') switches.append(f'-U {params["upscale"][0]} {params["upscale"][1]}')
if 'variation_amount' in params and params['variation_amount'] > 0: if "variation_amount" in params and params["variation_amount"] > 0:
switches.append(f'-v {params["variation_amount"]}') switches.append(f'-v {params["variation_amount"]}')
if 'with_variations' in params: if "with_variations" in params:
seed_weight_pairs = ','.join(f'{seed}:{weight}' for seed, weight in params["with_variations"]) seed_weight_pairs = ",".join(
switches.append(f'-V {seed_weight_pairs}') f"{seed}:{weight}" for seed, weight in params["with_variations"]
)
switches.append(f"-V {seed_weight_pairs}")
return ' '.join(switches) return " ".join(switches)
def create_cmd_parser():
"""
This is simply a copy of the parser from `dream.py` with a change to give
prompt a default value. This is a temporary hack pending merge of #587 which
provides a better way to do this.
"""
parser = argparse.ArgumentParser(
description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12',
exit_on_error=True,
)
parser.add_argument('prompt', nargs='?', default='')
parser.add_argument('-s', '--steps', type=int, help='Number of steps')
parser.add_argument(
'-S',
'--seed',
type=int,
help='Image seed; a +ve integer, or use -1 for the previous seed, -2 for the one before that, etc',
)
parser.add_argument(
'-n',
'--iterations',
type=int,
default=1,
help='Number of samplings to perform (slower, but will provide seeds for individual images)',
)
parser.add_argument(
'-W', '--width', type=int, help='Image width, multiple of 64'
)
parser.add_argument(
'-H', '--height', type=int, help='Image height, multiple of 64'
)
parser.add_argument(
'-C',
'--cfg_scale',
default=7.5,
type=float,
help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
)
parser.add_argument(
'-g', '--grid', action='store_true', help='generate a grid'
)
parser.add_argument(
'--outdir',
'-o',
type=str,
default=None,
help='Directory to save generated images and a log of prompts and seeds',
)
parser.add_argument(
'--seamless',
action='store_true',
help='Change the model to seamless tiling (circular) mode',
)
parser.add_argument(
'-i',
'--individual',
action='store_true',
help='Generate individual files (default)',
)
parser.add_argument(
'-I',
'--init_img',
type=str,
help='Path to input image for img2img mode (supersedes width and height)',
)
parser.add_argument(
'-M',
'--init_mask',
type=str,
help='Path to input mask for inpainting mode (supersedes width and height)',
)
parser.add_argument(
'--init_color',
type=str,
help='Path to reference image for color correction (used for repeated img2img and inpainting)'
)
parser.add_argument(
'-T',
'-fit',
'--fit',
action='store_true',
help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)',
)
parser.add_argument(
'-f',
'--strength',
default=0.75,
type=float,
help='Strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
)
parser.add_argument(
'-G',
'--gfpgan_strength',
default=0,
type=float,
help='The strength at which to apply the GFPGAN model to the result, in order to improve faces.',
)
parser.add_argument(
'-U',
'--upscale',
nargs='+',
default=None,
type=float,
help='Scale factor (2, 4) for upscaling followed by upscaling strength (0-1.0). If strength not specified, defaults to 0.75'
)
parser.add_argument(
'-save_orig',
'--save_original',
action='store_true',
help='Save original. Use it when upscaling to save both versions.',
)
# variants is going to be superseded by a generalized "prompt-morph" function
# parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants")
parser.add_argument(
'-x',
'--skip_normalize',
action='store_true',
help='Skip subprompt weight normalization',
)
parser.add_argument(
'-A',
'-m',
'--sampler',
dest='sampler_name',
default=None,
type=str,
choices=SAMPLER_CHOICES,
metavar='SAMPLER_NAME',
help=f'Switch to a different sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
)
parser.add_argument(
'-t',
'--log_tokenization',
action='store_true',
help='shows how the prompt is split into tokens'
)
parser.add_argument(
'-v',
'--variation_amount',
default=0.0,
type=float,
help='If > 0, generates variations on the initial seed instead of random seeds per iteration. Must be between 0 and 1. Higher values will be more different.'
)
parser.add_argument(
'-V',
'--with_variations',
default=None,
type=str,
help='list of variations to apply, in the format `seed:weight,seed:weight,...'
)
return parser

View File

@ -6,7 +6,8 @@ import traceback
import eventlet import eventlet
import glob import glob
import shlex import shlex
import argparse import math
import shutil
from flask_socketio import SocketIO from flask_socketio import SocketIO
from flask import Flask, send_from_directory, url_for, jsonify from flask import Flask, send_from_directory, url_for, jsonify
@ -15,13 +16,16 @@ from PIL import Image
from pytorch_lightning import logging from pytorch_lightning import logging
from threading import Event from threading import Event
from uuid import uuid4 from uuid import uuid4
from send2trash import send2trash
from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
from ldm.gfpgan.gfpgan_tools import run_gfpgan from ldm.gfpgan.gfpgan_tools import run_gfpgan
from ldm.generate import Generate from ldm.generate import Generate
from ldm.dream.pngwriter import PngWriter, retrieve_metadata from ldm.dream.pngwriter import PngWriter, retrieve_metadata
from ldm.dream.args import APP_ID, APP_VERSION, calculate_init_img_hash
from ldm.dream.conditioning import split_weighted_subprompts
from modules.parameters import parameters_to_command, create_cmd_parser from modules.parameters import parameters_to_command
""" """
@ -29,12 +33,14 @@ USER CONFIG
""" """
output_dir = "outputs/" # Base output directory for images output_dir = "outputs/" # Base output directory for images
#host = 'localhost' # Web & socket.io host # host = 'localhost' # Web & socket.io host
host = '0.0.0.0' # Web & socket.io host host = "localhost" # Web & socket.io host
port = 9090 # Web & socket.io port port = 9090 # Web & socket.io port
verbose = False # enables copious socket.io logging verbose = False # enables copious socket.io logging
additional_allowed_origins = ['http://localhost:9090'] # additional CORS allowed origins additional_allowed_origins = [
"http://localhost:5173"
] # additional CORS allowed origins
model = "stable-diffusion-1.4"
""" """
END USER CONFIG END USER CONFIG
@ -47,26 +53,23 @@ SERVER SETUP
# fix missing mimetypes on windows due to registry wonkiness # fix missing mimetypes on windows due to registry wonkiness
mimetypes.add_type('application/javascript', '.js') mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type('text/css', '.css') mimetypes.add_type("text/css", ".css")
app = Flask(__name__, static_url_path='', static_folder='../frontend/dist/') app = Flask(__name__, static_url_path="", static_folder="../frontend/dist/")
app.config['OUTPUTS_FOLDER'] = "../outputs" app.config["OUTPUTS_FOLDER"] = "../outputs"
@app.route('/outputs/<path:filename>') @app.route("/outputs/<path:filename>")
def outputs(filename): def outputs(filename):
return send_from_directory( return send_from_directory(app.config["OUTPUTS_FOLDER"], filename)
app.config['OUTPUTS_FOLDER'],
filename
)
@app.route("/", defaults={'path': ''}) @app.route("/", defaults={"path": ""})
def serve(path): def serve(path):
return send_from_directory(app.static_folder, 'index.html') return send_from_directory(app.static_folder, "index.html")
logger = True if verbose else False logger = True if verbose else False
@ -78,12 +81,12 @@ max_http_buffer_size = 10000000
cors_allowed_origins = [f"http://{host}:{port}"] + additional_allowed_origins cors_allowed_origins = [f"http://{host}:{port}"] + additional_allowed_origins
socketio = SocketIO( socketio = SocketIO(
app, app,
logger=logger, logger=logger,
engineio_logger=engineio_logger, engineio_logger=engineio_logger,
max_http_buffer_size=max_http_buffer_size, max_http_buffer_size=max_http_buffer_size,
cors_allowed_origins=cors_allowed_origins, cors_allowed_origins=cors_allowed_origins,
) )
""" """
@ -104,29 +107,31 @@ canceled = Event()
# reduce logging outputs to error # reduce logging outputs to error
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR) logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
# Initialize and load model # Initialize and load model
model = Generate() generate = Generate(model)
model.load_model() generate.load_model()
# location for "finished" images # location for "finished" images
result_path = os.path.join(output_dir, 'img-samples/') result_path = os.path.join(output_dir, "img-samples/")
# temporary path for intermediates # temporary path for intermediates
intermediate_path = os.path.join(result_path, 'intermediates/') intermediate_path = os.path.join(result_path, "intermediates/")
# path for user-uploaded init images and masks # path for user-uploaded init images and masks
init_path = os.path.join(result_path, 'init-images/') init_image_path = os.path.join(result_path, "init-images/")
mask_path = os.path.join(result_path, 'mask-images/') mask_image_path = os.path.join(result_path, "mask-images/")
# txt log # txt log
log_path = os.path.join(result_path, 'dream_log.txt') log_path = os.path.join(result_path, "dream_log.txt")
# make all output paths # make all output paths
[os.makedirs(path, exist_ok=True) [
for path in [result_path, intermediate_path, init_path, mask_path]] os.makedirs(path, exist_ok=True)
for path in [result_path, intermediate_path, init_image_path, mask_image_path]
]
""" """
@ -139,126 +144,219 @@ SOCKET.IO LISTENERS
""" """
@socketio.on('requestAllImages') @socketio.on("requestSystemConfig")
def handle_request_capabilities():
print(f">> System config requested")
config = get_system_config()
socketio.emit("systemConfig", config)
@socketio.on("requestAllImages")
def handle_request_all_images(): def handle_request_all_images():
print(f'>> All images requested') print(f">> All images requested")
parser = create_cmd_parser()
paths = list(filter(os.path.isfile, glob.glob(result_path + "*.png"))) paths = list(filter(os.path.isfile, glob.glob(result_path + "*.png")))
paths.sort(key=lambda x: os.path.getmtime(x)) paths.sort(key=lambda x: os.path.getmtime(x))
image_array = [] image_array = []
for path in paths: for path in paths:
# image = Image.open(path) metadata = retrieve_metadata(path)
all_metadata = retrieve_metadata(path) image_array.append({"url": path, "metadata": metadata["sd-metadata"]})
if 'Dream' in all_metadata and not all_metadata['sd-metadata']: socketio.emit("galleryImages", {"images": image_array})
metadata = vars(parser.parse_args(shlex.split(all_metadata['Dream']))) eventlet.sleep(0)
else:
metadata = all_metadata['sd-metadata']
image_array.append({'path': path, 'metadata': metadata})
return make_response("OK", data=image_array)
@socketio.on('generateImage') @socketio.on("generateImage")
def handle_generate_image_event(generation_parameters, esrgan_parameters, gfpgan_parameters): def handle_generate_image_event(
print(f'>> Image generation requested: {generation_parameters}\nESRGAN parameters: {esrgan_parameters}\nGFPGAN parameters: {gfpgan_parameters}') generation_parameters, esrgan_parameters, gfpgan_parameters
generate_images( ):
generation_parameters, print(
esrgan_parameters, f">> Image generation requested: {generation_parameters}\nESRGAN parameters: {esrgan_parameters}\nGFPGAN parameters: {gfpgan_parameters}"
gfpgan_parameters
) )
return make_response("OK") generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
@socketio.on('runESRGAN') @socketio.on("runESRGAN")
def handle_run_esrgan_event(original_image, esrgan_parameters): def handle_run_esrgan_event(original_image, esrgan_parameters):
print(f'>> ESRGAN upscale requested for "{original_image["url"]}": {esrgan_parameters}') print(
f'>> ESRGAN upscale requested for "{original_image["url"]}": {esrgan_parameters}'
)
progress = {
"currentStep": 1,
"totalSteps": 1,
"currentIteration": 1,
"totalIterations": 1,
"currentStatus": "Preparing",
"isProcessing": True,
"currentStatusHasSteps": False,
}
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
image = Image.open(original_image["url"]) image = Image.open(original_image["url"])
seed = original_image['metadata']['seed'] if 'seed' in original_image['metadata'] else 'unknown_seed' seed = (
original_image["metadata"]["seed"]
if "seed" in original_image["metadata"]
else "unknown_seed"
)
progress["currentStatus"] = "Upscaling"
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
image = real_esrgan_upscale( image = real_esrgan_upscale(
image=image, image=image,
upsampler_scale=esrgan_parameters['upscale'][0], upsampler_scale=esrgan_parameters["upscale"][0],
strength=esrgan_parameters['upscale'][1], strength=esrgan_parameters["upscale"][1],
seed=seed seed=seed,
) )
esrgan_parameters['seed'] = seed progress["currentStatus"] = "Saving image"
path = save_image(image, esrgan_parameters, result_path, postprocessing='esrgan') socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
esrgan_parameters["seed"] = seed
metadata = parameters_to_post_processed_image_metadata(
parameters=esrgan_parameters,
original_image_path=original_image["url"],
type="esrgan",
)
command = parameters_to_command(esrgan_parameters) command = parameters_to_command(esrgan_parameters)
path = save_image(image, command, metadata, result_path, postprocessing="esrgan")
write_log_message(f'[Upscaled] "{original_image["url"]}" > "{path}": {command}') write_log_message(f'[Upscaled] "{original_image["url"]}" > "{path}": {command}')
progress["currentStatus"] = "Finished"
progress["currentStep"] = 0
progress["totalSteps"] = 0
progress["currentIteration"] = 0
progress["totalIterations"] = 0
progress["isProcessing"] = False
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
socketio.emit( socketio.emit(
'result', {'url': os.path.relpath(path), 'type': 'esrgan', 'uuid': original_image['uuid'],'metadata': esrgan_parameters}) "esrganResult",
{
"url": os.path.relpath(path),
"metadata": metadata,
},
)
@socketio.on("runGFPGAN")
@socketio.on('runGFPGAN')
def handle_run_gfpgan_event(original_image, gfpgan_parameters): def handle_run_gfpgan_event(original_image, gfpgan_parameters):
print(f'>> GFPGAN face fix requested for "{original_image["url"]}": {gfpgan_parameters}') print(
f'>> GFPGAN face fix requested for "{original_image["url"]}": {gfpgan_parameters}'
)
progress = {
"currentStep": 1,
"totalSteps": 1,
"currentIteration": 1,
"totalIterations": 1,
"currentStatus": "Preparing",
"isProcessing": True,
"currentStatusHasSteps": False,
}
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
image = Image.open(original_image["url"]) image = Image.open(original_image["url"])
seed = original_image['metadata']['seed'] if 'seed' in original_image['metadata'] else 'unknown_seed' seed = (
original_image["metadata"]["seed"]
if "seed" in original_image["metadata"]
else "unknown_seed"
)
progress["currentStatus"] = "Fixing faces"
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
image = run_gfpgan( image = run_gfpgan(
image=image, image=image,
strength=gfpgan_parameters['gfpgan_strength'], strength=gfpgan_parameters["gfpgan_strength"],
seed=seed, seed=seed,
upsampler_scale=1 upsampler_scale=1,
) )
gfpgan_parameters['seed'] = seed progress["currentStatus"] = "Saving image"
path = save_image(image, gfpgan_parameters, result_path, postprocessing='gfpgan') socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
gfpgan_parameters["seed"] = seed
metadata = parameters_to_post_processed_image_metadata(
parameters=gfpgan_parameters,
original_image_path=original_image["url"],
type="gfpgan",
)
command = parameters_to_command(gfpgan_parameters) command = parameters_to_command(gfpgan_parameters)
path = save_image(image, command, metadata, result_path, postprocessing="gfpgan")
write_log_message(f'[Fixed faces] "{original_image["url"]}" > "{path}": {command}') write_log_message(f'[Fixed faces] "{original_image["url"]}" > "{path}": {command}')
progress["currentStatus"] = "Finished"
progress["currentStep"] = 0
progress["totalSteps"] = 0
progress["currentIteration"] = 0
progress["totalIterations"] = 0
progress["isProcessing"] = False
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
socketio.emit( socketio.emit(
'result', {'url': os.path.relpath(path), 'type': 'gfpgan', 'uuid': original_image['uuid'],'metadata': gfpgan_parameters}) "gfpganResult",
{
"url": os.path.relpath(path),
"metadata": metadata,
},
)
@socketio.on('cancel') @socketio.on("cancel")
def handle_cancel(): def handle_cancel():
print(f'>> Cancel processing requested') print(f">> Cancel processing requested")
canceled.set() canceled.set()
return make_response("OK") socketio.emit("processingCanceled")
# TODO: I think this needs a safety mechanism. # TODO: I think this needs a safety mechanism.
@socketio.on('deleteImage') @socketio.on("deleteImage")
def handle_delete_image(path): def handle_delete_image(path, uuid):
print(f'>> Delete requested "{path}"') print(f'>> Delete requested "{path}"')
Path(path).unlink() send2trash(path)
return make_response("OK") socketio.emit("imageDeleted", {"url": path, "uuid": uuid})
# TODO: I think this needs a safety mechanism. # TODO: I think this needs a safety mechanism.
@socketio.on('uploadInitialImage') @socketio.on("uploadInitialImage")
def handle_upload_initial_image(bytes, name): def handle_upload_initial_image(bytes, name):
print(f'>> Init image upload requested "{name}"') print(f'>> Init image upload requested "{name}"')
uuid = uuid4().hex uuid = uuid4().hex
split = os.path.splitext(name) split = os.path.splitext(name)
name = f'{split[0]}.{uuid}{split[1]}' name = f"{split[0]}.{uuid}{split[1]}"
file_path = os.path.join(init_path, name) file_path = os.path.join(init_image_path, name)
os.makedirs(os.path.dirname(file_path), exist_ok=True) os.makedirs(os.path.dirname(file_path), exist_ok=True)
newFile = open(file_path, "wb") newFile = open(file_path, "wb")
newFile.write(bytes) newFile.write(bytes)
return make_response("OK", data=file_path) socketio.emit("initialImageUploaded", {"url": file_path, "uuid": ""})
# TODO: I think this needs a safety mechanism. # TODO: I think this needs a safety mechanism.
@socketio.on('uploadMaskImage') @socketio.on("uploadMaskImage")
def handle_upload_mask_image(bytes, name): def handle_upload_mask_image(bytes, name):
print(f'>> Mask image upload requested "{name}"') print(f'>> Mask image upload requested "{name}"')
uuid = uuid4().hex uuid = uuid4().hex
split = os.path.splitext(name) split = os.path.splitext(name)
name = f'{split[0]}.{uuid}{split[1]}' name = f"{split[0]}.{uuid}{split[1]}"
file_path = os.path.join(mask_path, name) file_path = os.path.join(mask_image_path, name)
os.makedirs(os.path.dirname(file_path), exist_ok=True) os.makedirs(os.path.dirname(file_path), exist_ok=True)
newFile = open(file_path, "wb") newFile = open(file_path, "wb")
newFile.write(bytes) newFile.write(bytes)
return make_response("OK", data=file_path) socketio.emit("maskImageUploaded", {"url": file_path, "uuid": ""})
""" """
@ -266,114 +364,343 @@ END SOCKET.IO LISTENERS
""" """
""" """
ADDITIONAL FUNCTIONS ADDITIONAL FUNCTIONS
""" """
def get_system_config():
return {
"model": "stable diffusion",
"model_id": model,
"model_hash": generate.model_hash,
"app_id": APP_ID,
"app_version": APP_VERSION,
}
def parameters_to_post_processed_image_metadata(parameters, original_image_path, type):
# top-level metadata minus `image` or `images`
metadata = get_system_config()
orig_hash = calculate_init_img_hash(original_image_path)
image = {"orig_path": original_image_path, "orig_hash": orig_hash}
if type == "esrgan":
image["type"] = "esrgan"
image["scale"] = parameters["upscale"][0]
image["strength"] = parameters["upscale"][1]
elif type == "gfpgan":
image["type"] = "gfpgan"
image["strength"] = parameters["gfpgan_strength"]
else:
raise TypeError(f"Invalid type: {type}")
metadata["image"] = image
return metadata
def parameters_to_generated_image_metadata(parameters):
# top-level metadata minus `image` or `images`
metadata = get_system_config()
# remove any image keys not mentioned in RFC #266
rfc266_img_fields = [
"type",
"postprocessing",
"sampler",
"prompt",
"seed",
"variations",
"steps",
"cfg_scale",
"step_number",
"width",
"height",
"extra",
"seamless",
]
rfc_dict = {}
for item in parameters.items():
key, value = item
if key in rfc266_img_fields:
rfc_dict[key] = value
postprocessing = []
# 'postprocessing' is either null or an
if "gfpgan_strength" in parameters:
postprocessing.append(
{"type": "gfpgan", "strength": float(parameters["gfpgan_strength"])}
)
if "upscale" in parameters:
postprocessing.append(
{
"type": "esrgan",
"scale": int(parameters["upscale"][0]),
"strength": float(parameters["upscale"][1]),
}
)
rfc_dict["postprocessing"] = postprocessing if len(postprocessing) > 0 else None
# semantic drift
rfc_dict["sampler"] = parameters["sampler_name"]
# display weighted subprompts (liable to change)
subprompts = split_weighted_subprompts(parameters["prompt"])
subprompts = [{"prompt": x[0], "weight": x[1]} for x in subprompts]
rfc_dict["prompt"] = subprompts
# 'variations' should always exist and be an array, empty or consisting of {'seed': seed, 'weight': weight} pairs
variations = []
if "with_variations" in parameters:
variations = [
{"seed": x[0], "weight": x[1]} for x in parameters["with_variations"]
]
rfc_dict["variations"] = variations
if "init_img" in parameters:
rfc_dict["type"] = "img2img"
rfc_dict["strength"] = parameters["strength"]
rfc_dict["fit"] = parameters["fit"] # TODO: Noncompliant
rfc_dict["orig_hash"] = calculate_init_img_hash(parameters["init_img"])
rfc_dict["init_image_path"] = parameters["init_img"] # TODO: Noncompliant
rfc_dict["sampler"] = "ddim" # TODO: FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS
if "init_mask" in parameters:
rfc_dict["mask_hash"] = calculate_init_img_hash(
parameters["init_mask"]
) # TODO: Noncompliant
rfc_dict["mask_image_path"] = parameters["init_mask"] # TODO: Noncompliant
else:
rfc_dict["type"] = "txt2img"
metadata["image"] = rfc_dict
return metadata
def make_unique_init_image_filename(name):
uuid = uuid4().hex
split = os.path.splitext(name)
name = f"{split[0]}.{uuid}{split[1]}"
return name
def write_log_message(message, log_path=log_path): def write_log_message(message, log_path=log_path):
"""Logs the filename and parameters used to generate or process that image to log file""" """Logs the filename and parameters used to generate or process that image to log file"""
message = f'{message}\n' message = f"{message}\n"
with open(log_path, 'a', encoding='utf-8') as file: with open(log_path, "a", encoding="utf-8") as file:
file.writelines(message) file.writelines(message)
def make_response(status, message=None, data=None): def save_image(
response = {'status': status} image, command, metadata, output_dir, step_index=None, postprocessing=False
if message is not None: ):
response['message'] = message
if data is not None:
response['data'] = data
return response
def save_image(image, parameters, output_dir, step_index=None, postprocessing=False):
seed = parameters['seed'] if 'seed' in parameters else 'unknown_seed'
pngwriter = PngWriter(output_dir) pngwriter = PngWriter(output_dir)
prefix = pngwriter.unique_prefix() prefix = pngwriter.unique_prefix()
filename = f'{prefix}.{seed}' seed = "unknown_seed"
if "image" in metadata:
if "seed" in metadata["image"]:
seed = metadata["image"]["seed"]
filename = f"{prefix}.{seed}"
if step_index: if step_index:
filename += f'.{step_index}' filename += f".{step_index}"
if postprocessing: if postprocessing:
filename += f'.postprocessed' filename += f".postprocessed"
filename += '.png' filename += ".png"
command = parameters_to_command(parameters) path = pngwriter.save_image_and_prompt_to_png(
image=image, dream_prompt=command, metadata=metadata, name=filename
path = pngwriter.save_image_and_prompt_to_png(image, command, metadata=parameters, name=filename) )
return path return path
def calculate_real_steps(steps, strength, has_init_image):
return math.floor(strength * steps) if has_init_image else steps
def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters): def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters):
canceled.clear() canceled.clear()
step_index = 1 step_index = 1
"""
If a result image is used as an init image, and then deleted, we will want to be
able to use it as an init image in the future. Need to copy it.
If the init/mask image doesn't exist in the init_image_path/mask_image_path,
make a unique filename for it and copy it there.
"""
if "init_img" in generation_parameters:
filename = os.path.basename(generation_parameters["init_img"])
if not os.path.exists(os.path.join(init_image_path, filename)):
unique_filename = make_unique_init_image_filename(filename)
new_path = os.path.join(init_image_path, unique_filename)
shutil.copy(generation_parameters["init_img"], new_path)
generation_parameters["init_img"] = new_path
if "init_mask" in generation_parameters:
filename = os.path.basename(generation_parameters["init_mask"])
if not os.path.exists(os.path.join(mask_image_path, filename)):
unique_filename = make_unique_init_image_filename(filename)
new_path = os.path.join(init_image_path, unique_filename)
shutil.copy(generation_parameters["init_img"], new_path)
generation_parameters["init_mask"] = new_path
totalSteps = calculate_real_steps(
steps=generation_parameters["steps"],
strength=generation_parameters["strength"]
if "strength" in generation_parameters
else None,
has_init_image="init_img" in generation_parameters,
)
progress = {
"currentStep": 1,
"totalSteps": totalSteps,
"currentIteration": 1,
"totalIterations": generation_parameters["iterations"],
"currentStatus": "Preparing",
"isProcessing": True,
"currentStatusHasSteps": False,
}
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
def image_progress(sample, step): def image_progress(sample, step):
if canceled.is_set(): if canceled.is_set():
raise CanceledException raise CanceledException
nonlocal step_index nonlocal step_index
nonlocal generation_parameters nonlocal generation_parameters
if generation_parameters["progress_images"] and step % 5 == 0 and step < generation_parameters['steps'] - 1: nonlocal progress
image = model.sample_to_image(sample)
path = save_image(image, generation_parameters, intermediate_path, step_index) progress["currentStep"] = step + 1
progress["currentStatus"] = "Generating"
progress["currentStatusHasSteps"] = True
if (
generation_parameters["progress_images"]
and step % 5 == 0
and step < generation_parameters["steps"] - 1
):
image = generate.sample_to_image(sample)
path = save_image(
image, generation_parameters, intermediate_path, step_index
)
step_index += 1 step_index += 1
socketio.emit('intermediateResult', { socketio.emit(
'url': os.path.relpath(path), 'metadata': generation_parameters}) "intermediateResult",
socketio.emit('progress', {'step': step + 1}) {"url": os.path.relpath(path), "metadata": generation_parameters},
)
socketio.emit("progressUpdate", progress)
eventlet.sleep(0) eventlet.sleep(0)
def image_done(image, seed): def image_done(image, seed):
nonlocal generation_parameters nonlocal generation_parameters
nonlocal esrgan_parameters nonlocal esrgan_parameters
nonlocal gfpgan_parameters nonlocal gfpgan_parameters
nonlocal progress
step_index = 1
progress["currentStatus"] = "Generation complete"
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
all_parameters = generation_parameters all_parameters = generation_parameters
postprocessing = False postprocessing = False
if esrgan_parameters: if esrgan_parameters:
progress["currentStatus"] = "Upscaling"
progress["currentStatusHasSteps"] = False
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
image = real_esrgan_upscale( image = real_esrgan_upscale(
image=image, image=image,
strength=esrgan_parameters['strength'], strength=esrgan_parameters["strength"],
upsampler_scale=esrgan_parameters['level'], upsampler_scale=esrgan_parameters["level"],
seed=seed seed=seed,
) )
postprocessing = True postprocessing = True
all_parameters["upscale"] = [esrgan_parameters['level'], esrgan_parameters['strength']] all_parameters["upscale"] = [
esrgan_parameters["level"],
esrgan_parameters["strength"],
]
if gfpgan_parameters: if gfpgan_parameters:
progress["currentStatus"] = "Fixing faces"
progress["currentStatusHasSteps"] = False
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
image = run_gfpgan( image = run_gfpgan(
image=image, image=image,
strength=gfpgan_parameters['strength'], strength=gfpgan_parameters["strength"],
seed=seed, seed=seed,
upsampler_scale=1, upsampler_scale=1,
) )
postprocessing = True postprocessing = True
all_parameters["gfpgan_strength"] = gfpgan_parameters['strength'] all_parameters["gfpgan_strength"] = gfpgan_parameters["strength"]
all_parameters['seed'] = seed all_parameters["seed"] = seed
progress["currentStatus"] = "Saving image"
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
path = save_image(image, all_parameters, result_path, postprocessing=postprocessing) metadata = parameters_to_generated_image_metadata(all_parameters)
command = parameters_to_command(all_parameters) command = parameters_to_command(all_parameters)
print(f'Image generated: "{path}"') path = save_image(
image, command, metadata, result_path, postprocessing=postprocessing
)
print(f'>> Image generated: "{path}"')
write_log_message(f'[Generated] "{path}": {command}') write_log_message(f'[Generated] "{path}": {command}')
if progress["totalIterations"] > progress["currentIteration"]:
progress["currentStep"] = 1
progress["currentIteration"] += 1
progress["currentStatus"] = "Iteration finished"
progress["currentStatusHasSteps"] = False
else:
progress["currentStep"] = 0
progress["totalSteps"] = 0
progress["currentIteration"] = 0
progress["totalIterations"] = 0
progress["currentStatus"] = "Finished"
progress["isProcessing"] = False
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
socketio.emit( socketio.emit(
'result', {'url': os.path.relpath(path), 'type': 'generation', 'metadata': all_parameters}) "generationResult",
{"url": os.path.relpath(path), "metadata": metadata},
)
eventlet.sleep(0) eventlet.sleep(0)
try: try:
model.prompt2image( generate.prompt2image(
**generation_parameters, **generation_parameters,
step_callback=image_progress, step_callback=image_progress,
image_callback=image_done image_callback=image_done,
) )
except KeyboardInterrupt: except KeyboardInterrupt:
@ -381,7 +708,7 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
except CanceledException: except CanceledException:
pass pass
except Exception as e: except Exception as e:
socketio.emit('error', (str(e))) socketio.emit("error", {"message": (str(e))})
print("\n") print("\n")
traceback.print_exc() traceback.print_exc()
print("\n") print("\n")
@ -392,6 +719,6 @@ END ADDITIONAL FUNCTIONS
""" """
if __name__ == '__main__': if __name__ == "__main__":
print(f'Starting server at http://{host}:{port}') print(f">> Starting server at http://{host}:{port}")
socketio.run(app, host=host, port=port) socketio.run(app, host=host, port=port)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 22 KiB

After

Width:  |  Height:  |  Size: 22 KiB

View File

@ -59,9 +59,7 @@ Once the model is trained, specify the trained .pt or .bin file when starting
dream using dream using
```bash ```bash
python3 ./scripts/dream.py \ python3 ./scripts/dream.py --embedding_path /path/to/embedding.pt
--embedding_path /path/to/embedding.pt \
--full_precision
``` ```
Then, to utilize your subject at the dream prompt Then, to utilize your subject at the dream prompt

View File

@ -97,6 +97,11 @@ You wil need one of the following:
``` ```
## :octicons-log-16: Latest Changes ## :octicons-log-16: Latest Changes
### vNEXT <small>(TODO 2022)</small>
- Deprecated `--full_precision` / `-F`. Simply omit it and `dream.py` will auto
configure. To switch away from auto use the new flag like `--precision=float32`.
### v1.14 <small>(11 September 2022)</small> ### v1.14 <small>(11 September 2022)</small>
- Memory optimizations for small-RAM cards. 512x512 now possible on 4 GB GPUs. - Memory optimizations for small-RAM cards. 512x512 now possible on 4 GB GPUs.

View File

@ -106,7 +106,6 @@ PATH_TO_CKPT="$HOME/Downloads" # (1)!
ln -s "$PATH_TO_CKPT/sd-v1-4.ckpt" \ ln -s "$PATH_TO_CKPT/sd-v1-4.ckpt" \
models/ldm/stable-diffusion-v1/model.ckpt models/ldm/stable-diffusion-v1/model.ckpt
``` ```
1. or wherever you saved sd-v1-4.ckpt 1. or wherever you saved sd-v1-4.ckpt
@ -548,5 +547,3 @@ Abort trap: 6
warnings.warn('resource_tracker: There appear to be %d ' warnings.warn('resource_tracker: There appear to be %d '
``` ```
Macs do not support `autocast/mixed-precision`, so you need to supply
`--full_precision` to use float32 everywhere.

View File

@ -32,6 +32,7 @@ dependencies:
- omegaconf==2.1.1 - omegaconf==2.1.1
- onnx==1.12.0 - onnx==1.12.0
- onnxruntime==1.12.1 - onnxruntime==1.12.1
- protobuf==3.20.1
- pudb==2022.1 - pudb==2022.1
- pytorch-lightning==1.6.5 - pytorch-lightning==1.6.5
- scipy==1.9.1 - scipy==1.9.1
@ -48,6 +49,7 @@ dependencies:
- opencv-python==4.6.0 - opencv-python==4.6.0
- protobuf==3.20.1 - protobuf==3.20.1
- realesrgan==0.2.5.0 - realesrgan==0.2.5.0
- send2trash==1.8.0
- test-tube==0.7.5 - test-tube==0.7.5
- transformers==4.21.2 - transformers==4.21.2
- torch-fidelity==0.3.0 - torch-fidelity==0.3.0

View File

@ -20,6 +20,7 @@ dependencies:
- realesrgan==0.2.5.0 - realesrgan==0.2.5.0
- test-tube>=0.7.5 - test-tube>=0.7.5
- streamlit==1.12.0 - streamlit==1.12.0
- send2trash==1.8.0
- pillow==9.2.0 - pillow==9.2.0
- einops==0.3.0 - einops==0.3.0
- torch-fidelity==0.3.0 - torch-fidelity==0.3.0

View File

@ -1,85 +1,37 @@
# Stable Diffusion Web UI # Stable Diffusion Web UI
Demo at https://peaceful-otter-7a427f.netlify.app/ (not connected to back end) ## Run
much of this readme is just notes for myself during dev work - `python backend/server.py` serves both frontend and backend at http://localhost:9090
numpy rand: 0 to 4294967295 ## Evironment
## Test and Build Install [node](https://nodejs.org/en/download/) (includes npm) and optionally
[yarn](https://yarnpkg.com/getting-started/install).
from `frontend/`: From `frontend/` run `npm install` / `yarn install` to install the frontend packages.
- `yarn dev` runs `tsc-watch`, which runs `vite build` on successful `tsc` transpilation ## Dev
from `.`: 1. From `frontend/`, run `npm dev` / `yarn dev` to start the dev server.
2. Note the address it starts up on (probably `http://localhost:5173/`).
3. Edit `backend/server.py`'s `additional_allowed_origins` to include this address, e.g.
`additional_allowed_origins = ['http://localhost:5173']`.
4. Leaving the dev server running, open a new terminal and go to the project root.
5. Run `python backend/server.py`.
6. Navigate to the dev server address e.g. `http://localhost:5173/`.
- `python backend/server.py` serves both frontend and backend at http://localhost:9090 To build for dev: `npm build-dev` / `yarn build-dev`
## API To build for production: `npm build` / `yarn build`
`backend/server.py` serves the UI and provides a [socket.io](https://github.com/socketio/socket.io) API via [flask-socketio](https://github.com/miguelgrinberg/flask-socketio).
### Server Listeners
The server listens for these socket.io events:
`cancel`
- Cancels in-progress image generation
- Returns ack only
`generateImage`
- Accepts object of image parameters
- Generates an image
- Returns ack only (image generation function sends progress and result via separate events)
`deleteImage`
- Accepts file path to image
- Deletes image
- Returns ack only
`deleteAllImages` WIP
- Deletes all images in `outputs/`
- Returns ack only
`requestAllImages`
- Returns array of all images in `outputs/`
`requestCapabilities` WIP
- Returns capabilities of server (torch device, GFPGAN and ESRGAN availability, ???)
`sendImage` WIP
- Accepts a File and attributes
- Saves image
- Used to save init images which are not generated images
### Server Emitters
`progress`
- Emitted during each step in generation
- Sends a number from 0 to 1 representing percentage of steps completed
`result` WIP
- Emitted when an image generation has completed
- Sends a object:
```
{
url: relative_file_path,
metadata: image_metadata_object
}
```
## TODO ## TODO
- Search repo for "TODO" - Search repo for "TODO"
- My one gripe with Chakra: no way to disable all animations right now and drop the dependence on `framer-motion`. I would prefer to save the ~30kb on bundle and have zero animations. This is on the Chakra roadmap. See https://github.com/chakra-ui/chakra-ui/pull/6368 for last discussion on this. Need to check in on this issue periodically. - My one gripe with Chakra: no way to disable all animations right now and drop the dependence on
`framer-motion`. I would prefer to save the ~30kb on bundle and have zero animations. This is on
the Chakra roadmap. See https://github.com/chakra-ui/chakra-ui/pull/6368 for last discussion on
this. Need to check in on this issue periodically.
- Mobile friendly layout
- Proper image gallery/viewer/manager
- Help tooltips and such

694
frontend/dist/assets/index.727a397b.js vendored Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -4,7 +4,7 @@
<meta charset="UTF-8" /> <meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Stable Diffusion Dream Server</title> <title>Stable Diffusion Dream Server</title>
<script type="module" crossorigin src="/assets/index.cc5cde43.js"></script> <script type="module" crossorigin src="/assets/index.727a397b.js"></script>
<link rel="stylesheet" href="/assets/index.447eb2a9.css"> <link rel="stylesheet" href="/assets/index.447eb2a9.css">
</head> </head>
<body> <body>

View File

@ -3,7 +3,7 @@
<head> <head>
<meta charset="UTF-8" /> <meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Stable Diffusion Dream Server</title> <title>InvokeAI Stable Diffusion Dream Server</title>
</head> </head>
<body> <body>
<div id="root"></div> <div id="root"></div>

View File

@ -1,16 +1,16 @@
{ {
"name": "sdui", "name": "invoke-ai-ui",
"private": true, "private": true,
"version": "0.0.0", "version": "0.0.1",
"type": "module", "type": "module",
"scripts": { "scripts": {
"dev": "tsc-watch --onSuccess 'yarn run vite build -m development'", "dev": "vite dev",
"hmr": "vite dev",
"build": "tsc && vite build", "build": "tsc && vite build",
"build-dev": "tsc && vite build -m development", "build-dev": "tsc && vite build -m development",
"preview": "vite preview" "preview": "vite preview"
}, },
"dependencies": { "dependencies": {
"@chakra-ui/icons": "^2.0.10",
"@chakra-ui/react": "^2.3.1", "@chakra-ui/react": "^2.3.1",
"@emotion/react": "^11.10.4", "@emotion/react": "^11.10.4",
"@emotion/styled": "^11.10.4", "@emotion/styled": "^11.10.4",

View File

@ -1,60 +0,0 @@
import { Grid, GridItem } from '@chakra-ui/react';
import CurrentImage from './features/gallery/CurrentImage';
import LogViewer from './features/system/LogViewer';
import PromptInput from './features/sd/PromptInput';
import ProgressBar from './features/header/ProgressBar';
import { useEffect } from 'react';
import { useAppDispatch } from './app/hooks';
import { requestAllImages } from './app/socketio';
import ProcessButtons from './features/sd/ProcessButtons';
import ImageRoll from './features/gallery/ImageRoll';
import SiteHeader from './features/header/SiteHeader';
import OptionsAccordion from './features/sd/OptionsAccordion';
const App = () => {
const dispatch = useAppDispatch();
useEffect(() => {
dispatch(requestAllImages());
}, [dispatch]);
return (
<>
<Grid
width='100vw'
height='100vh'
templateAreas={`
"header header header header"
"progressBar progressBar progressBar progressBar"
"menu prompt processButtons imageRoll"
"menu currentImage currentImage imageRoll"`}
gridTemplateRows={'36px 10px 100px auto'}
gridTemplateColumns={'350px auto 100px 388px'}
gap={2}
>
<GridItem area={'header'} pt={1}>
<SiteHeader />
</GridItem>
<GridItem area={'progressBar'}>
<ProgressBar />
</GridItem>
<GridItem pl='2' area={'menu'} overflowY='scroll'>
<OptionsAccordion />
</GridItem>
<GridItem area={'prompt'}>
<PromptInput />
</GridItem>
<GridItem area={'processButtons'}>
<ProcessButtons />
</GridItem>
<GridItem area={'currentImage'}>
<CurrentImage />
</GridItem>
<GridItem pr='2' area={'imageRoll'} overflowY='scroll'>
<ImageRoll />
</GridItem>
</Grid>
<LogViewer />
</>
);
};
export default App;

69
frontend/src/app/App.tsx Normal file
View File

@ -0,0 +1,69 @@
import { Grid, GridItem } from '@chakra-ui/react';
import { useEffect, useState } from 'react';
import CurrentImageDisplay from '../features/gallery/CurrentImageDisplay';
import ImageGallery from '../features/gallery/ImageGallery';
import ProgressBar from '../features/system/ProgressBar';
import SiteHeader from '../features/system/SiteHeader';
import OptionsAccordion from '../features/options/OptionsAccordion';
import ProcessButtons from '../features/options/ProcessButtons';
import PromptInput from '../features/options/PromptInput';
import LogViewer from '../features/system/LogViewer';
import Loading from '../Loading';
import { useAppDispatch } from './store';
import { requestAllImages, requestSystemConfig } from './socketio/actions';
const App = () => {
const dispatch = useAppDispatch();
const [isReady, setIsReady] = useState<boolean>(false);
// Load images from the gallery once
useEffect(() => {
dispatch(requestAllImages());
dispatch(requestSystemConfig());
setIsReady(true);
}, [dispatch]);
return isReady ? (
<>
<Grid
width="100vw"
height="100vh"
templateAreas={`
"header header header header"
"progressBar progressBar progressBar progressBar"
"menu prompt processButtons imageRoll"
"menu currentImage currentImage imageRoll"`}
gridTemplateRows={'36px 10px 100px auto'}
gridTemplateColumns={'350px auto 100px 388px'}
gap={2}
>
<GridItem area={'header'} pt={1}>
<SiteHeader />
</GridItem>
<GridItem area={'progressBar'}>
<ProgressBar />
</GridItem>
<GridItem pl="2" area={'menu'} overflowY="scroll">
<OptionsAccordion />
</GridItem>
<GridItem area={'prompt'}>
<PromptInput />
</GridItem>
<GridItem area={'processButtons'}>
<ProcessButtons />
</GridItem>
<GridItem area={'currentImage'}>
<CurrentImageDisplay />
</GridItem>
<GridItem pr="2" area={'imageRoll'} overflowY="scroll">
<ImageGallery />
</GridItem>
</Grid>
<LogViewer />
</>
) : (
<Loading />
);
};
export default App;

View File

@ -2,52 +2,52 @@
// Valid samplers // Valid samplers
export const SAMPLERS: Array<string> = [ export const SAMPLERS: Array<string> = [
'ddim', 'ddim',
'plms', 'plms',
'k_lms', 'k_lms',
'k_dpm_2', 'k_dpm_2',
'k_dpm_2_a', 'k_dpm_2_a',
'k_euler', 'k_euler',
'k_euler_a', 'k_euler_a',
'k_heun', 'k_heun',
]; ];
// Valid image widths // Valid image widths
export const WIDTHS: Array<number> = [ export const WIDTHS: Array<number> = [
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
1024, 1024,
]; ];
// Valid image heights // Valid image heights
export const HEIGHTS: Array<number> = [ export const HEIGHTS: Array<number> = [
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
1024, 1024,
]; ];
// Valid upscaling levels // Valid upscaling levels
export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [ export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [
{ key: '2x', value: 2 }, { key: '2x', value: 2 },
{ key: '4x', value: 4 }, { key: '4x', value: 4 },
]; ];
// Internal to human-readable parameters // Internal to human-readable parameters
export const PARAMETERS: { [key: string]: string } = { export const PARAMETERS: { [key: string]: string } = {
prompt: 'Prompt', prompt: 'Prompt',
iterations: 'Iterations', iterations: 'Iterations',
steps: 'Steps', steps: 'Steps',
cfgScale: 'CFG Scale', cfgScale: 'CFG Scale',
height: 'Height', height: 'Height',
width: 'Width', width: 'Width',
sampler: 'Sampler', sampler: 'Sampler',
seed: 'Seed', seed: 'Seed',
img2imgStrength: 'img2img Strength', img2imgStrength: 'img2img Strength',
gfpganStrength: 'GFPGAN Strength', gfpganStrength: 'GFPGAN Strength',
upscalingLevel: 'Upscaling Level', upscalingLevel: 'Upscaling Level',
upscalingStrength: 'Upscaling Strength', upscalingStrength: 'Upscaling Strength',
initialImagePath: 'Initial Image', initialImagePath: 'Initial Image',
maskPath: 'Initial Image Mask', maskPath: 'Initial Image Mask',
shouldFitToWidthHeight: 'Fit Initial Image', shouldFitToWidthHeight: 'Fit Initial Image',
seamless: 'Seamless Tiling', seamless: 'Seamless Tiling',
}; };
export const NUMPY_RAND_MIN = 0; export const NUMPY_RAND_MIN = 0;

View File

@ -1,7 +0,0 @@
import { useDispatch, useSelector } from 'react-redux';
import type { TypedUseSelectorHook } from 'react-redux';
import type { RootState, AppDispatch } from './store';
// Use throughout your app instead of plain `useDispatch` and `useSelector`
export const useAppDispatch: () => AppDispatch = useDispatch;
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;

170
frontend/src/app/invokeai.d.ts vendored Normal file
View File

@ -0,0 +1,170 @@
/**
* Types for images, the things they are made of, and the things
* they make up.
*
* Generated images are txt2img and img2img images. They may have
* had additional postprocessing done on them when they were first
* generated.
*
* Postprocessed images are images which were not generated here
* but only postprocessed by the app. They only get postprocessing
* metadata and have a different image type, e.g. 'esrgan' or
* 'gfpgan'.
*/
/**
* TODO:
* Once an image has been generated, if it is postprocessed again,
* additional postprocessing steps are added to its postprocessing
* array.
*
* TODO: Better documentation of types.
*/
export declare type PromptItem = {
prompt: string;
weight: number;
};
export declare type Prompt = Array<PromptItem>;
export declare type SeedWeightPair = {
seed: number;
weight: number;
};
export declare type SeedWeights = Array<SeedWeightPair>;
// All generated images contain these metadata.
export declare type CommonGeneratedImageMetadata = {
postprocessing: null | Array<ESRGANMetadata | GFPGANMetadata>;
sampler:
| 'ddim'
| 'k_dpm_2_a'
| 'k_dpm_2'
| 'k_euler_a'
| 'k_euler'
| 'k_heun'
| 'k_lms'
| 'plms';
prompt: Prompt;
seed: number;
variations: SeedWeights;
steps: number;
cfg_scale: number;
width: number;
height: number;
seamless: boolean;
extra: null | Record<string, never>; // Pending development of RFC #266
};
// txt2img and img2img images have some unique attributes.
export declare type Txt2ImgMetadata = GeneratedImageMetadata & {
type: 'txt2img';
};
export declare type Img2ImgMetadata = GeneratedImageMetadata & {
type: 'img2img';
orig_hash: string;
strength: number;
fit: boolean;
init_image_path: string;
mask_image_path?: string;
};
// Superset of generated image metadata types.
export declare type GeneratedImageMetadata = Txt2ImgMetadata | Img2ImgMetadata;
// All post processed images contain these metadata.
export declare type CommonPostProcessedImageMetadata = {
orig_path: string;
orig_hash: string;
};
// esrgan and gfpgan images have some unique attributes.
export declare type ESRGANMetadata = CommonPostProcessedImageMetadata & {
type: 'esrgan';
scale: 2 | 4;
strength: number;
};
export declare type GFPGANMetadata = CommonPostProcessedImageMetadata & {
type: 'gfpgan';
strength: number;
};
// Superset of all postprocessed image metadata types..
export declare type PostProcessedImageMetadata =
| ESRGANMetadata
| GFPGANMetadata;
// Metadata includes the system config and image metadata.
export declare type Metadata = SystemConfig & {
image: GeneratedImageMetadata | PostProcessedImageMetadata;
};
// An Image has a UUID, url (path?) and Metadata.
export declare type Image = {
uuid: string;
url: string;
metadata: Metadata;
};
// GalleryImages is an array of Image.
export declare type GalleryImages = {
images: Array<Image>;
};
/**
* Types related to the system status.
*/
// This represents the processing status of the backend.
export declare type SystemStatus = {
isProcessing: boolean;
currentStep: number;
totalSteps: number;
currentIteration: number;
totalIterations: number;
currentStatus: string;
currentStatusHasSteps: boolean;
};
export declare type SystemConfig = {
model: string;
model_id: string;
model_hash: string;
app_id: string;
app_version: string;
};
/**
* These types type data received from the server via socketio.
*/
export declare type SystemStatusResponse = SystemStatus;
export declare type SystemConfigResponse = SystemConfig;
export declare type ImageResultResponse = {
url: string;
metadata: Metadata;
};
export declare type ErrorResponse = {
message: string;
additionalData?: string;
};
export declare type GalleryImagesResponse = {
images: Array<{ url: string; metadata: Metadata }>;
};
export declare type ImageUrlAndUuidResponse = {
uuid: string;
url: string;
};
export declare type ImageUrlResponse = {
url: string;
};

View File

@ -1,182 +0,0 @@
import { SDState } from '../features/sd/sdSlice';
import randomInt from '../features/sd/util/randomInt';
import {
seedWeightsToString,
stringToSeedWeights,
} from '../features/sd/util/seedWeightPairs';
import { SystemState } from '../features/system/systemSlice';
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from './constants';
/*
These functions translate frontend state into parameters
suitable for consumption by the backend, and vice-versa.
*/
export const frontendToBackendParameters = (
sdState: SDState,
systemState: SystemState
): { [key: string]: any } => {
const {
prompt,
iterations,
steps,
cfgScale,
height,
width,
sampler,
seed,
seamless,
shouldUseInitImage,
img2imgStrength,
initialImagePath,
maskPath,
shouldFitToWidthHeight,
shouldGenerateVariations,
variantAmount,
seedWeights,
shouldRunESRGAN,
upscalingLevel,
upscalingStrength,
shouldRunGFPGAN,
gfpganStrength,
shouldRandomizeSeed,
} = sdState;
const { shouldDisplayInProgress } = systemState;
const generationParameters: { [k: string]: any } = {
prompt,
iterations,
steps,
cfg_scale: cfgScale,
height,
width,
sampler_name: sampler,
seed,
seamless,
progress_images: shouldDisplayInProgress,
};
generationParameters.seed = shouldRandomizeSeed
? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX)
: seed;
if (shouldUseInitImage) {
generationParameters.init_img = initialImagePath;
generationParameters.strength = img2imgStrength;
generationParameters.fit = shouldFitToWidthHeight;
if (maskPath) {
generationParameters.init_mask = maskPath;
}
}
if (shouldGenerateVariations) {
generationParameters.variation_amount = variantAmount;
if (seedWeights) {
generationParameters.with_variations =
stringToSeedWeights(seedWeights);
}
} else {
generationParameters.variation_amount = 0;
}
let esrganParameters: false | { [k: string]: any } = false;
let gfpganParameters: false | { [k: string]: any } = false;
if (shouldRunESRGAN) {
esrganParameters = {
level: upscalingLevel,
strength: upscalingStrength,
};
}
if (shouldRunGFPGAN) {
gfpganParameters = {
strength: gfpganStrength,
};
}
return {
generationParameters,
esrganParameters,
gfpganParameters,
};
};
export const backendToFrontendParameters = (parameters: {
[key: string]: any;
}) => {
const {
prompt,
iterations,
steps,
cfg_scale,
height,
width,
sampler_name,
seed,
seamless,
progress_images,
variation_amount,
with_variations,
gfpgan_strength,
upscale,
init_img,
init_mask,
strength,
} = parameters;
const sd: { [key: string]: any } = {
shouldDisplayInProgress: progress_images,
// init
shouldGenerateVariations: false,
shouldRunESRGAN: false,
shouldRunGFPGAN: false,
initialImagePath: '',
maskPath: '',
};
if (variation_amount > 0) {
sd.shouldGenerateVariations = true;
sd.variantAmount = variation_amount;
if (with_variations) {
sd.seedWeights = seedWeightsToString(with_variations);
}
}
if (gfpgan_strength > 0) {
sd.shouldRunGFPGAN = true;
sd.gfpganStrength = gfpgan_strength;
}
if (upscale) {
sd.shouldRunESRGAN = true;
sd.upscalingLevel = upscale[0];
sd.upscalingStrength = upscale[1];
}
if (init_img) {
sd.shouldUseInitImage = true
sd.initialImagePath = init_img;
sd.strength = strength;
if (init_mask) {
sd.maskPath = init_mask;
}
}
// if we had a prompt, add all the metadata, but if we don't have a prompt,
// we must have only done ESRGAN or GFPGAN so do not add that metadata
if (prompt) {
sd.prompt = prompt;
sd.iterations = iterations;
sd.steps = steps;
sd.cfgScale = cfg_scale;
sd.height = height;
sd.width = width;
sd.sampler = sampler_name;
sd.seed = seed;
sd.seamless = seamless;
}
return sd;
};

View File

@ -1,393 +0,0 @@
import { createAction, Middleware } from '@reduxjs/toolkit';
import { io } from 'socket.io-client';
import {
addImage,
clearIntermediateImage,
removeImage,
SDImage,
SDMetadata,
setGalleryImages,
setIntermediateImage,
} from '../features/gallery/gallerySlice';
import {
addLogEntry,
setCurrentStep,
setIsConnected,
setIsProcessing,
} from '../features/system/systemSlice';
import { v4 as uuidv4 } from 'uuid';
import { setInitialImagePath, setMaskPath } from '../features/sd/sdSlice';
import {
backendToFrontendParameters,
frontendToBackendParameters,
} from './parameterTranslation';
export interface SocketIOResponse {
status: 'OK' | 'ERROR';
message?: string;
data?: any;
}
export const socketioMiddleware = () => {
const { hostname, port } = new URL(window.location.href);
const socketio = io(`http://${hostname}:9090`);
let areListenersSet = false;
const middleware: Middleware = (store) => (next) => (action) => {
const { dispatch, getState } = store;
if (!areListenersSet) {
// CONNECT
socketio.on('connect', () => {
try {
dispatch(setIsConnected(true));
} catch (e) {
console.error(e);
}
});
// DISCONNECT
socketio.on('disconnect', () => {
try {
dispatch(setIsConnected(false));
dispatch(setIsProcessing(false));
dispatch(addLogEntry(`Disconnected from server`));
} catch (e) {
console.error(e);
}
});
// PROCESSING RESULT
socketio.on(
'result',
(data: {
url: string;
type: 'generation' | 'esrgan' | 'gfpgan';
uuid?: string;
metadata: { [key: string]: any };
}) => {
try {
const newUuid = uuidv4();
const { type, url, uuid, metadata } = data;
switch (type) {
case 'generation': {
const translatedMetadata =
backendToFrontendParameters(metadata);
dispatch(
addImage({
uuid: newUuid,
url,
metadata: translatedMetadata,
})
);
dispatch(
addLogEntry(`Image generated: ${url}`)
);
break;
}
case 'esrgan': {
const originalImage =
getState().gallery.images.find(
(i: SDImage) => i.uuid === uuid
);
const newMetadata = {
...originalImage.metadata,
};
newMetadata.shouldRunESRGAN = true;
newMetadata.upscalingLevel =
metadata.upscale[0];
newMetadata.upscalingStrength =
metadata.upscale[1];
dispatch(
addImage({
uuid: newUuid,
url,
metadata: newMetadata,
})
);
dispatch(
addLogEntry(`ESRGAN upscaled: ${url}`)
);
break;
}
case 'gfpgan': {
const originalImage =
getState().gallery.images.find(
(i: SDImage) => i.uuid === uuid
);
const newMetadata = {
...originalImage.metadata,
};
newMetadata.shouldRunGFPGAN = true;
newMetadata.gfpganStrength =
metadata.gfpgan_strength;
dispatch(
addImage({
uuid: newUuid,
url,
metadata: newMetadata,
})
);
dispatch(
addLogEntry(`GFPGAN fixed faces: ${url}`)
);
break;
}
}
dispatch(setIsProcessing(false));
} catch (e) {
console.error(e);
}
}
);
// PROGRESS UPDATE
socketio.on('progress', (data: { step: number }) => {
try {
dispatch(setIsProcessing(true));
dispatch(setCurrentStep(data.step));
} catch (e) {
console.error(e);
}
});
// INTERMEDIATE IMAGE
socketio.on(
'intermediateResult',
(data: { url: string; metadata: SDMetadata }) => {
try {
const uuid = uuidv4();
const { url, metadata } = data;
dispatch(
setIntermediateImage({
uuid,
url,
metadata,
})
);
dispatch(
addLogEntry(`Intermediate image generated: ${url}`)
);
} catch (e) {
console.error(e);
}
}
);
// ERROR FROM BACKEND
socketio.on('error', (message) => {
try {
dispatch(addLogEntry(`Server error: ${message}`));
dispatch(setIsProcessing(false));
dispatch(clearIntermediateImage());
} catch (e) {
console.error(e);
}
});
areListenersSet = true;
}
// HANDLE ACTIONS
switch (action.type) {
// GENERATE IMAGE
case 'socketio/generateImage': {
dispatch(setIsProcessing(true));
dispatch(setCurrentStep(-1));
const {
generationParameters,
esrganParameters,
gfpganParameters,
} = frontendToBackendParameters(
getState().sd,
getState().system
);
socketio.emit(
'generateImage',
generationParameters,
esrganParameters,
gfpganParameters
);
dispatch(
addLogEntry(
`Image generation requested: ${JSON.stringify({
...generationParameters,
...esrganParameters,
...gfpganParameters,
})}`
)
);
break;
}
// RUN ESRGAN (UPSCALING)
case 'socketio/runESRGAN': {
const imageToProcess = action.payload;
dispatch(setIsProcessing(true));
dispatch(setCurrentStep(-1));
const { upscalingLevel, upscalingStrength } = getState().sd;
const esrganParameters = {
upscale: [upscalingLevel, upscalingStrength],
};
socketio.emit('runESRGAN', imageToProcess, esrganParameters);
dispatch(
addLogEntry(
`ESRGAN upscale requested: ${JSON.stringify({
file: imageToProcess.url,
...esrganParameters,
})}`
)
);
break;
}
// RUN GFPGAN (FIX FACES)
case 'socketio/runGFPGAN': {
const imageToProcess = action.payload;
dispatch(setIsProcessing(true));
dispatch(setCurrentStep(-1));
const { gfpganStrength } = getState().sd;
const gfpganParameters = {
gfpgan_strength: gfpganStrength,
};
socketio.emit('runGFPGAN', imageToProcess, gfpganParameters);
dispatch(
addLogEntry(
`GFPGAN fix faces requested: ${JSON.stringify({
file: imageToProcess.url,
...gfpganParameters,
})}`
)
);
break;
}
// DELETE IMAGE
case 'socketio/deleteImage': {
const imageToDelete = action.payload;
const { url } = imageToDelete;
socketio.emit(
'deleteImage',
url,
(response: SocketIOResponse) => {
if (response.status === 'OK') {
dispatch(removeImage(imageToDelete));
dispatch(addLogEntry(`Image deleted: ${url}`));
}
}
);
break;
}
// GET ALL IMAGES FOR GALLERY
case 'socketio/requestAllImages': {
socketio.emit(
'requestAllImages',
(response: SocketIOResponse) => {
dispatch(setGalleryImages(response.data));
dispatch(
addLogEntry(`Loaded ${response.data.length} images`)
);
}
);
break;
}
// CANCEL PROCESSING
case 'socketio/cancelProcessing': {
socketio.emit('cancel', (response: SocketIOResponse) => {
const { intermediateImage } = getState().gallery;
if (response.status === 'OK') {
dispatch(setIsProcessing(false));
if (intermediateImage) {
dispatch(addImage(intermediateImage));
dispatch(
addLogEntry(
`Intermediate image saved: ${intermediateImage.url}`
)
);
dispatch(clearIntermediateImage());
}
dispatch(addLogEntry(`Processing canceled`));
}
});
break;
}
// UPLOAD INITIAL IMAGE
case 'socketio/uploadInitialImage': {
const file = action.payload;
socketio.emit(
'uploadInitialImage',
file,
file.name,
(response: SocketIOResponse) => {
if (response.status === 'OK') {
dispatch(setInitialImagePath(response.data));
dispatch(
addLogEntry(
`Initial image uploaded: ${response.data}`
)
);
}
}
);
break;
}
// UPLOAD MASK IMAGE
case 'socketio/uploadMaskImage': {
const file = action.payload;
socketio.emit(
'uploadMaskImage',
file,
file.name,
(response: SocketIOResponse) => {
if (response.status === 'OK') {
dispatch(setMaskPath(response.data));
dispatch(
addLogEntry(
`Mask image uploaded: ${response.data}`
)
);
}
}
);
break;
}
}
next(action);
};
return middleware;
};
// Actions to be used by app
export const generateImage = createAction<undefined>('socketio/generateImage');
export const runESRGAN = createAction<SDImage>('socketio/runESRGAN');
export const runGFPGAN = createAction<SDImage>('socketio/runGFPGAN');
export const deleteImage = createAction<SDImage>('socketio/deleteImage');
export const requestAllImages = createAction<undefined>(
'socketio/requestAllImages'
);
export const cancelProcessing = createAction<undefined>(
'socketio/cancelProcessing'
);
export const uploadInitialImage = createAction<File>(
'socketio/uploadInitialImage'
);
export const uploadMaskImage = createAction<File>('socketio/uploadMaskImage');

View File

@ -0,0 +1,26 @@
import { createAction } from '@reduxjs/toolkit';
import * as InvokeAI from '../invokeai';
/**
* We can't use redux-toolkit's createSlice() to make these actions,
* because they have no associated reducer. They only exist to dispatch
* requests to the server via socketio. These actions will be handled
* by the middleware.
*/
export const generateImage = createAction<undefined>('socketio/generateImage');
export const runESRGAN = createAction<InvokeAI.Image>('socketio/runESRGAN');
export const runGFPGAN = createAction<InvokeAI.Image>('socketio/runGFPGAN');
export const deleteImage = createAction<InvokeAI.Image>('socketio/deleteImage');
export const requestAllImages = createAction<undefined>(
'socketio/requestAllImages'
);
export const cancelProcessing = createAction<undefined>(
'socketio/cancelProcessing'
);
export const uploadInitialImage = createAction<File>(
'socketio/uploadInitialImage'
);
export const uploadMaskImage = createAction<File>('socketio/uploadMaskImage');
export const requestSystemConfig = createAction<undefined>('socketio/requestSystemConfig');

View File

@ -0,0 +1,104 @@
import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
import dateFormat from 'dateformat';
import { Socket } from 'socket.io-client';
import { frontendToBackendParameters } from '../../common/util/parameterTranslation';
import {
addLogEntry,
setIsProcessing,
} from '../../features/system/systemSlice';
import * as InvokeAI from '../invokeai';
/**
* Returns an object containing all functions which use `socketio.emit()`.
* i.e. those which make server requests.
*/
const makeSocketIOEmitters = (
store: MiddlewareAPI<Dispatch<AnyAction>, any>,
socketio: Socket
) => {
// We need to dispatch actions to redux and get pieces of state from the store.
const { dispatch, getState } = store;
return {
emitGenerateImage: () => {
dispatch(setIsProcessing(true));
const { generationParameters, esrganParameters, gfpganParameters } =
frontendToBackendParameters(getState().options, getState().system);
socketio.emit(
'generateImage',
generationParameters,
esrganParameters,
gfpganParameters
);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Image generation requested: ${JSON.stringify({
...generationParameters,
...esrganParameters,
...gfpganParameters,
})}`,
})
);
},
emitRunESRGAN: (imageToProcess: InvokeAI.Image) => {
dispatch(setIsProcessing(true));
const { upscalingLevel, upscalingStrength } = getState().options;
const esrganParameters = {
upscale: [upscalingLevel, upscalingStrength],
};
socketio.emit('runESRGAN', imageToProcess, esrganParameters);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `ESRGAN upscale requested: ${JSON.stringify({
file: imageToProcess.url,
...esrganParameters,
})}`,
})
);
},
emitRunGFPGAN: (imageToProcess: InvokeAI.Image) => {
dispatch(setIsProcessing(true));
const { gfpganStrength } = getState().options;
const gfpganParameters = {
gfpgan_strength: gfpganStrength,
};
socketio.emit('runGFPGAN', imageToProcess, gfpganParameters);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `GFPGAN fix faces requested: ${JSON.stringify({
file: imageToProcess.url,
...gfpganParameters,
})}`,
})
);
},
emitDeleteImage: (imageToDelete: InvokeAI.Image) => {
const { url, uuid } = imageToDelete;
socketio.emit('deleteImage', url, uuid);
},
emitRequestAllImages: () => {
socketio.emit('requestAllImages');
},
emitCancelProcessing: () => {
socketio.emit('cancel');
},
emitUploadInitialImage: (file: File) => {
socketio.emit('uploadInitialImage', file, file.name);
},
emitUploadMaskImage: (file: File) => {
socketio.emit('uploadMaskImage', file, file.name);
},
emitRequestSystemConfig: () => {
socketio.emit('requestSystemConfig')
}
};
};
export default makeSocketIOEmitters;

View File

@ -0,0 +1,300 @@
import { AnyAction, MiddlewareAPI, Dispatch } from '@reduxjs/toolkit';
import { v4 as uuidv4 } from 'uuid';
import dateFormat from 'dateformat';
import * as InvokeAI from '../invokeai';
import {
addLogEntry,
setIsConnected,
setIsProcessing,
setSystemStatus,
setCurrentStatus,
setSystemConfig,
} from '../../features/system/systemSlice';
import {
addImage,
clearIntermediateImage,
removeImage,
setGalleryImages,
setIntermediateImage,
} from '../../features/gallery/gallerySlice';
import {
setInitialImagePath,
setMaskPath,
} from '../../features/options/optionsSlice';
/**
* Returns an object containing listener callbacks for socketio events.
* TODO: This file is large, but simple. Should it be split up further?
*/
const makeSocketIOListeners = (
store: MiddlewareAPI<Dispatch<AnyAction>, any>
) => {
const { dispatch, getState } = store;
return {
/**
* Callback to run when we receive a 'connect' event.
*/
onConnect: () => {
try {
dispatch(setIsConnected(true));
dispatch(setCurrentStatus('Connected'));
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'disconnect' event.
*/
onDisconnect: () => {
try {
dispatch(setIsConnected(false));
dispatch(setIsProcessing(false));
dispatch(setCurrentStatus('Disconnected'));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Disconnected from server`,
level: 'warning',
})
);
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'generationResult' event.
*/
onGenerationResult: (data: InvokeAI.ImageResultResponse) => {
try {
const { url, metadata } = data;
const newUuid = uuidv4();
dispatch(
addImage({
uuid: newUuid,
url,
metadata: metadata,
})
);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Image generated: ${url}`,
})
);
dispatch(setIsProcessing(false));
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'intermediateResult' event.
*/
onIntermediateResult: (data: InvokeAI.ImageResultResponse) => {
try {
const uuid = uuidv4();
const { url, metadata } = data;
dispatch(
setIntermediateImage({
uuid,
url,
metadata,
})
);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Intermediate image generated: ${url}`,
})
);
dispatch(setIsProcessing(false));
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive an 'esrganResult' event.
*/
onESRGANResult: (data: InvokeAI.ImageResultResponse) => {
try {
const { url, metadata } = data;
dispatch(
addImage({
uuid: uuidv4(),
url,
metadata,
})
);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Upscaled: ${url}`,
})
);
dispatch(setIsProcessing(false));
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'gfpganResult' event.
*/
onGFPGANResult: (data: InvokeAI.ImageResultResponse) => {
try {
const { url, metadata } = data;
dispatch(
addImage({
uuid: uuidv4(),
url,
metadata,
})
);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Fixed faces: ${url}`,
})
);
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'progressUpdate' event.
* TODO: Add additional progress phases
*/
onProgressUpdate: (data: InvokeAI.SystemStatus) => {
try {
dispatch(setIsProcessing(true));
dispatch(setSystemStatus(data));
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'progressUpdate' event.
*/
onError: (data: InvokeAI.ErrorResponse) => {
const { message, additionalData } = data;
if (additionalData) {
// TODO: handle more data than short message
}
try {
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Server error: ${message}`,
level: 'error',
})
);
dispatch(setIsProcessing(false));
dispatch(clearIntermediateImage());
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'galleryImages' event.
*/
onGalleryImages: (data: InvokeAI.GalleryImagesResponse) => {
const { images } = data;
const preparedImages = images.map((image): InvokeAI.Image => {
const { url, metadata } = image;
return {
uuid: uuidv4(),
url,
metadata,
};
});
dispatch(setGalleryImages(preparedImages));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Loaded ${images.length} images`,
})
);
},
/**
* Callback to run when we receive a 'processingCanceled' event.
*/
onProcessingCanceled: () => {
dispatch(setIsProcessing(false));
const { intermediateImage } = getState().gallery;
if (intermediateImage) {
dispatch(addImage(intermediateImage));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Intermediate image saved: ${intermediateImage.url}`,
})
);
dispatch(clearIntermediateImage());
}
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Processing canceled`,
level: 'warning',
})
);
},
/**
* Callback to run when we receive a 'imageDeleted' event.
*/
onImageDeleted: (data: InvokeAI.ImageUrlAndUuidResponse) => {
const { url, uuid } = data;
dispatch(removeImage(uuid));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Image deleted: ${url}`,
})
);
},
/**
* Callback to run when we receive a 'initialImageUploaded' event.
*/
onInitialImageUploaded: (data: InvokeAI.ImageUrlResponse) => {
const { url } = data;
dispatch(setInitialImagePath(url));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Initial image uploaded: ${url}`,
})
);
},
/**
* Callback to run when we receive a 'maskImageUploaded' event.
*/
onMaskImageUploaded: (data: InvokeAI.ImageUrlResponse) => {
const { url } = data;
dispatch(setMaskPath(url));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Mask image uploaded: ${url}`,
})
);
},
onSystemConfig: (data: InvokeAI.SystemConfig) => {
dispatch(setSystemConfig(data));
},
};
};
export default makeSocketIOListeners;

View File

@ -0,0 +1,173 @@
import { Middleware } from '@reduxjs/toolkit';
import { io } from 'socket.io-client';
import makeSocketIOListeners from './listeners';
import makeSocketIOEmitters from './emitters';
import * as InvokeAI from '../invokeai';
/**
* Creates a socketio middleware to handle communication with server.
*
* Special `socketio/actionName` actions are created in actions.ts and
* exported for use by the application, which treats them like any old
* action, using `dispatch` to dispatch them.
*
* These actions are intercepted here, where `socketio.emit()` calls are
* made on their behalf - see `emitters.ts`. The emitter functions
* are the outbound communication to the server.
*
* Listeners are also established here - see `listeners.ts`. The listener
* functions receive communication from the server and usually dispatch
* some new action to handle whatever data was sent from the server.
*/
export const socketioMiddleware = () => {
const { hostname, port } = new URL(window.location.href);
const socketio = io(`http://${hostname}:9090`);
let areListenersSet = false;
const middleware: Middleware = (store) => (next) => (action) => {
const {
onConnect,
onDisconnect,
onError,
onESRGANResult,
onGFPGANResult,
onGenerationResult,
onIntermediateResult,
onProgressUpdate,
onGalleryImages,
onProcessingCanceled,
onImageDeleted,
onInitialImageUploaded,
onMaskImageUploaded,
onSystemConfig,
} = makeSocketIOListeners(store);
const {
emitGenerateImage,
emitRunESRGAN,
emitRunGFPGAN,
emitDeleteImage,
emitRequestAllImages,
emitCancelProcessing,
emitUploadInitialImage,
emitUploadMaskImage,
emitRequestSystemConfig,
} = makeSocketIOEmitters(store, socketio);
/**
* If this is the first time the middleware has been called (e.g. during store setup),
* initialize all our socket.io listeners.
*/
if (!areListenersSet) {
socketio.on('connect', () => onConnect());
socketio.on('disconnect', () => onDisconnect());
socketio.on('error', (data: InvokeAI.ErrorResponse) => onError(data));
socketio.on('generationResult', (data: InvokeAI.ImageResultResponse) =>
onGenerationResult(data)
);
socketio.on('esrganResult', (data: InvokeAI.ImageResultResponse) =>
onESRGANResult(data)
);
socketio.on('gfpganResult', (data: InvokeAI.ImageResultResponse) =>
onGFPGANResult(data)
);
socketio.on('intermediateResult', (data: InvokeAI.ImageResultResponse) =>
onIntermediateResult(data)
);
socketio.on('progressUpdate', (data: InvokeAI.SystemStatus) =>
onProgressUpdate(data)
);
socketio.on('galleryImages', (data: InvokeAI.GalleryImagesResponse) =>
onGalleryImages(data)
);
socketio.on('processingCanceled', () => {
onProcessingCanceled();
});
socketio.on('imageDeleted', (data: InvokeAI.ImageUrlAndUuidResponse) => {
onImageDeleted(data);
});
socketio.on('initialImageUploaded', (data: InvokeAI.ImageUrlResponse) => {
onInitialImageUploaded(data);
});
socketio.on('maskImageUploaded', (data: InvokeAI.ImageUrlResponse) => {
onMaskImageUploaded(data);
});
socketio.on('systemConfig', (data: InvokeAI.SystemConfig) => {
onSystemConfig(data);
});
areListenersSet = true;
}
/**
* Handle redux actions caught by middleware.
*/
switch (action.type) {
case 'socketio/generateImage': {
emitGenerateImage();
break;
}
case 'socketio/runESRGAN': {
emitRunESRGAN(action.payload);
break;
}
case 'socketio/runGFPGAN': {
emitRunGFPGAN(action.payload);
break;
}
case 'socketio/deleteImage': {
emitDeleteImage(action.payload);
break;
}
case 'socketio/requestAllImages': {
emitRequestAllImages();
break;
}
case 'socketio/cancelProcessing': {
emitCancelProcessing();
break;
}
case 'socketio/uploadInitialImage': {
emitUploadInitialImage(action.payload);
break;
}
case 'socketio/uploadMaskImage': {
emitUploadMaskImage(action.payload);
break;
}
case 'socketio/requestSystemConfig': {
emitRequestSystemConfig();
break;
}
}
next(action);
};
return middleware;
};

View File

@ -1,53 +1,78 @@
import { combineReducers, configureStore } from '@reduxjs/toolkit'; import { combineReducers, configureStore } from '@reduxjs/toolkit';
import { useDispatch, useSelector } from 'react-redux';
import type { TypedUseSelectorHook } from 'react-redux';
import { persistReducer } from 'redux-persist'; import { persistReducer } from 'redux-persist';
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
import sdReducer from '../features/sd/sdSlice'; import optionsReducer from '../features/options/optionsSlice';
import galleryReducer from '../features/gallery/gallerySlice'; import galleryReducer from '../features/gallery/gallerySlice';
import systemReducer from '../features/system/systemSlice'; import systemReducer from '../features/system/systemSlice';
import { socketioMiddleware } from './socketio'; import { socketioMiddleware } from './socketio/middleware';
const reducers = combineReducers({ /**
sd: sdReducer, * redux-persist provides an easy and reliable way to persist state across reloads.
gallery: galleryReducer, *
system: systemReducer, * While we definitely want generation parameters to be persisted, there are a number
}); * of things we do *not* want to be persisted across reloads:
* - Gallery/selected image (user may add/delete images from disk between page loads)
* - Connection/processing status
* - Availability of external libraries like ESRGAN/GFPGAN
*
* These can be blacklisted in redux-persist.
*
* The necesssary nested persistors with blacklists are configured below.
*
* TODO: Do we blacklist initialImagePath? If the image is deleted from disk we get an
* ugly 404. But if we blacklist it, then this is a valuable parameter that is lost
* on reload. Need to figure out a good way to handle this.
*/
const persistConfig = { const rootPersistConfig = {
key: 'root', key: 'root',
storage, storage,
blacklist: ['gallery', 'system'],
}; };
const persistedReducer = persistReducer(persistConfig, reducers); const systemPersistConfig = {
key: 'system',
storage,
blacklist: [
'isConnected',
'isProcessing',
'currentStep',
'socketId',
'isESRGANAvailable',
'isGFPGANAvailable',
'currentStep',
'totalSteps',
'currentIteration',
'totalIterations',
'currentStatus',
],
};
/* const reducers = combineReducers({
The frontend needs to be distributed as a production build, so options: optionsReducer,
we cannot reasonably ask users to edit the JS and specify the gallery: galleryReducer,
host and port on which the socket.io server will run. system: persistReducer(systemPersistConfig, systemReducer),
});
The solution is to allow server script to be run with arguments
(or just edited) providing the host and port. Then, the server
serves a route `/socketio_config` which responds with the host
and port.
When the frontend loads, it synchronously requests that route
and thus gets the host and port. This requires a suspicious
fetch somewhere, and the store setup seems like as good a place
as any to make this fetch request.
*/
const persistedReducer = persistReducer(rootPersistConfig, reducers);
// Continue with store setup // Continue with store setup
export const store = configureStore({ export const store = configureStore({
reducer: persistedReducer, reducer: persistedReducer,
middleware: (getDefaultMiddleware) => middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({ getDefaultMiddleware({
// redux-persist sometimes needs to have a function in redux, need to disable this check // redux-persist sometimes needs to temporarily put a function in redux state, need to disable this check
serializableCheck: false, serializableCheck: false,
}).concat(socketioMiddleware()), }).concat(socketioMiddleware()),
}); });
// Infer the `RootState` and `AppDispatch` types from the store itself
export type RootState = ReturnType<typeof store.getState>; export type RootState = ReturnType<typeof store.getState>;
// Inferred type: {posts: PostsState, comments: CommentsState, users: UsersState}
export type AppDispatch = typeof store.dispatch; export type AppDispatch = typeof store.dispatch;
// Use throughout your app instead of plain `useDispatch` and `useSelector`
export const useAppDispatch: () => AppDispatch = useDispatch;
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;

View File

@ -33,5 +33,20 @@ export const theme = extendTheme({
fontWeight: 'light', fontWeight: 'light',
}, },
}, },
Button: {
variants: {
imageHoverIconButton: (props: StyleFunctionProps) => ({
bg: props.colorMode === 'dark' ? 'blackAlpha.700' : 'whiteAlpha.800',
color:
props.colorMode === 'dark' ? 'whiteAlpha.700' : 'blackAlpha.700',
_hover: {
bg:
props.colorMode === 'dark' ? 'blackAlpha.800' : 'whiteAlpha.800',
color:
props.colorMode === 'dark' ? 'whiteAlpha.900' : 'blackAlpha.900',
},
}),
},
},
}, },
}); });

View File

@ -0,0 +1,21 @@
import { Button, ButtonProps } from '@chakra-ui/react';
interface Props extends ButtonProps {
label: string;
}
/**
* Reusable customized button component. Originally was more customized - now probably unecessary.
*
* TODO: Get rid of this.
*/
const SDButton = (props: Props) => {
const { label, size = 'sm', ...rest } = props;
return (
<Button size={size} {...rest}>
{label}
</Button>
);
};
export default SDButton;

View File

@ -16,6 +16,9 @@ interface Props extends NumberInputProps {
width?: string | number; width?: string | number;
} }
/**
* Customized Chakra FormControl + NumberInput multi-part component.
*/
const SDNumberInput = (props: Props) => { const SDNumberInput = (props: Props) => {
const { const {
label, label,
@ -31,7 +34,7 @@ const SDNumberInput = (props: Props) => {
<Flex gap={2} justifyContent={'space-between'} alignItems={'center'}> <Flex gap={2} justifyContent={'space-between'} alignItems={'center'}>
{label && ( {label && (
<FormLabel marginBottom={1}> <FormLabel marginBottom={1}>
<Text fontSize={fontSize} whiteSpace='nowrap'> <Text fontSize={fontSize} whiteSpace="nowrap">
{label} {label}
</Text> </Text>
</FormLabel> </FormLabel>
@ -42,7 +45,7 @@ const SDNumberInput = (props: Props) => {
keepWithinRange={false} keepWithinRange={false}
clampValueOnBlur={true} clampValueOnBlur={true}
> >
<NumberInputField fontSize={'md'}/> <NumberInputField fontSize={'md'} />
<NumberInputStepper> <NumberInputStepper>
<NumberIncrementStepper /> <NumberIncrementStepper />
<NumberDecrementStepper /> <NumberDecrementStepper />

View File

@ -0,0 +1,56 @@
import {
Flex,
FormControl,
FormLabel,
Select,
SelectProps,
Text,
} from '@chakra-ui/react';
interface Props extends SelectProps {
label: string;
validValues:
| Array<number | string>
| Array<{ key: string; value: string | number }>;
}
/**
* Customized Chakra FormControl + Select multi-part component.
*/
const SDSelect = (props: Props) => {
const {
label,
isDisabled,
validValues,
size = 'sm',
fontSize = 'md',
marginBottom = 1,
whiteSpace = 'nowrap',
...rest
} = props;
return (
<FormControl isDisabled={isDisabled}>
<Flex justifyContent={'space-between'} alignItems={'center'}>
<FormLabel marginBottom={marginBottom}>
<Text fontSize={fontSize} whiteSpace={whiteSpace}>
{label}
</Text>
</FormLabel>
<Select fontSize={fontSize} size={size} {...rest}>
{validValues.map((opt) => {
return typeof opt === 'string' || typeof opt === 'number' ? (
<option key={opt} value={opt}>
{opt}
</option>
) : (
<option key={opt.value} value={opt.value}>
{opt.key}
</option>
);
})}
</Select>
</Flex>
</FormControl>
);
};
export default SDSelect;

View File

@ -11,6 +11,9 @@ interface Props extends SwitchProps {
width?: string | number; width?: string | number;
} }
/**
* Customized Chakra FormControl + Switch multi-part component.
*/
const SDSwitch = (props: Props) => { const SDSwitch = (props: Props) => {
const { const {
label, label,
@ -28,7 +31,7 @@ const SDSwitch = (props: Props) => {
fontSize={fontSize} fontSize={fontSize}
marginBottom={1} marginBottom={1}
flexGrow={2} flexGrow={2}
whiteSpace='nowrap' whiteSpace="nowrap"
> >
{label} {label}
</FormLabel> </FormLabel>

View File

@ -0,0 +1,104 @@
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useMemo } from 'react';
import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import { OptionsState } from '../../features/options/optionsSlice';
import { SystemState } from '../../features/system/systemSlice';
import { validateSeedWeights } from '../util/seedWeightPairs';
const optionsSelector = createSelector(
(state: RootState) => state.options,
(options: OptionsState) => {
return {
prompt: options.prompt,
shouldGenerateVariations: options.shouldGenerateVariations,
seedWeights: options.seedWeights,
maskPath: options.maskPath,
initialImagePath: options.initialImagePath,
seed: options.seed,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isProcessing: system.isProcessing,
isConnected: system.isConnected,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/**
* Checks relevant pieces of state to confirm generation will not deterministically fail.
* This is used to prevent the 'Generate' button from being clicked.
*/
const useCheckParameters = (): boolean => {
const {
prompt,
shouldGenerateVariations,
seedWeights,
maskPath,
initialImagePath,
seed,
} = useAppSelector(optionsSelector);
const { isProcessing, isConnected } = useAppSelector(systemSelector);
return useMemo(() => {
// Cannot generate without a prompt
if (!prompt) {
return false;
}
// Cannot generate with a mask without img2img
if (maskPath && !initialImagePath) {
return false;
}
// TODO: job queue
// Cannot generate if already processing an image
if (isProcessing) {
return false;
}
// Cannot generate if not connected
if (!isConnected) {
return false;
}
// Cannot generate variations without valid seed weights
if (
shouldGenerateVariations &&
(!(validateSeedWeights(seedWeights) || seedWeights === '') || seed === -1)
) {
return false;
}
// All good
return true;
}, [
prompt,
maskPath,
initialImagePath,
isProcessing,
isConnected,
shouldGenerateVariations,
seedWeights,
seed,
]);
};
export default useCheckParameters;

View File

@ -0,0 +1,182 @@
/*
These functions translate frontend state into parameters
suitable for consumption by the backend, and vice-versa.
*/
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from '../../app/constants';
import { OptionsState } from '../../features/options/optionsSlice';
import { SystemState } from '../../features/system/systemSlice';
import {
seedWeightsToString,
stringToSeedWeightsArray,
} from './seedWeightPairs';
import randomInt from './randomInt';
export const frontendToBackendParameters = (
optionsState: OptionsState,
systemState: SystemState
): { [key: string]: any } => {
const {
prompt,
iterations,
steps,
cfgScale,
height,
width,
sampler,
seed,
seamless,
shouldUseInitImage,
img2imgStrength,
initialImagePath,
maskPath,
shouldFitToWidthHeight,
shouldGenerateVariations,
variationAmount,
seedWeights,
shouldRunESRGAN,
upscalingLevel,
upscalingStrength,
shouldRunGFPGAN,
gfpganStrength,
shouldRandomizeSeed,
} = optionsState;
const { shouldDisplayInProgress } = systemState;
const generationParameters: { [k: string]: any } = {
prompt,
iterations,
steps,
cfg_scale: cfgScale,
height,
width,
sampler_name: sampler,
seed,
seamless,
progress_images: shouldDisplayInProgress,
};
generationParameters.seed = shouldRandomizeSeed
? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX)
: seed;
if (shouldUseInitImage) {
generationParameters.init_img = initialImagePath;
generationParameters.strength = img2imgStrength;
generationParameters.fit = shouldFitToWidthHeight;
if (maskPath) {
generationParameters.init_mask = maskPath;
}
}
if (shouldGenerateVariations) {
generationParameters.variation_amount = variationAmount;
if (seedWeights) {
generationParameters.with_variations =
stringToSeedWeightsArray(seedWeights);
}
} else {
generationParameters.variation_amount = 0;
}
let esrganParameters: false | { [k: string]: any } = false;
let gfpganParameters: false | { [k: string]: any } = false;
if (shouldRunESRGAN) {
esrganParameters = {
level: upscalingLevel,
strength: upscalingStrength,
};
}
if (shouldRunGFPGAN) {
gfpganParameters = {
strength: gfpganStrength,
};
}
return {
generationParameters,
esrganParameters,
gfpganParameters,
};
};
export const backendToFrontendParameters = (parameters: {
[key: string]: any;
}) => {
const {
prompt,
iterations,
steps,
cfg_scale,
height,
width,
sampler_name,
seed,
seamless,
progress_images,
variation_amount,
with_variations,
gfpgan_strength,
upscale,
init_img,
init_mask,
strength,
} = parameters;
const options: { [key: string]: any } = {
shouldDisplayInProgress: progress_images,
// init
shouldGenerateVariations: false,
shouldRunESRGAN: false,
shouldRunGFPGAN: false,
initialImagePath: '',
maskPath: '',
};
if (variation_amount > 0) {
options.shouldGenerateVariations = true;
options.variationAmount = variation_amount;
if (with_variations) {
options.seedWeights = seedWeightsToString(with_variations);
}
}
if (gfpgan_strength > 0) {
options.shouldRunGFPGAN = true;
options.gfpganStrength = gfpgan_strength;
}
if (upscale) {
options.shouldRunESRGAN = true;
options.upscalingLevel = upscale[0];
options.upscalingStrength = upscale[1];
}
if (init_img) {
options.shouldUseInitImage = true;
options.initialImagePath = init_img;
options.strength = strength;
if (init_mask) {
options.maskPath = init_mask;
}
}
// if we had a prompt, add all the metadata, but if we don't have a prompt,
// we must have only done ESRGAN or GFPGAN so do not add that metadata
if (prompt) {
options.prompt = prompt;
options.iterations = iterations;
options.steps = steps;
options.cfgScale = cfg_scale;
options.height = height;
options.width = width;
options.sampler = sampler_name;
options.seed = seed;
options.seamless = seamless;
}
return options;
};

View File

@ -0,0 +1,16 @@
import * as InvokeAI from '../../app/invokeai';
const promptToString = (prompt: InvokeAI.Prompt): string => {
if (prompt.length === 1) {
return prompt[0].prompt;
}
return prompt
.map(
(promptItem: InvokeAI.PromptItem): string =>
`${promptItem.prompt}:${promptItem.weight}`
)
.join(' ');
};
export default promptToString;

View File

@ -0,0 +1,68 @@
import * as InvokeAI from '../../app/invokeai';
export const stringToSeedWeights = (
string: string
): InvokeAI.SeedWeights | boolean => {
const stringPairs = string.split(',');
const arrPairs = stringPairs.map((p) => p.split(':'));
const pairs = arrPairs.map((p: Array<string>): InvokeAI.SeedWeightPair => {
return { seed: parseInt(p[0]), weight: parseFloat(p[1]) };
});
if (!validateSeedWeights(pairs)) {
return false;
}
return pairs;
};
export const validateSeedWeights = (
seedWeights: InvokeAI.SeedWeights | string
): boolean => {
return typeof seedWeights === 'string'
? Boolean(stringToSeedWeights(seedWeights))
: Boolean(
seedWeights.length &&
!seedWeights.some((pair: InvokeAI.SeedWeightPair) => {
const { seed, weight } = pair;
const isSeedValid = !isNaN(parseInt(seed.toString(), 10));
const isWeightValid =
!isNaN(parseInt(weight.toString(), 10)) &&
weight >= 0 &&
weight <= 1;
return !(isSeedValid && isWeightValid);
})
);
};
export const seedWeightsToString = (
seedWeights: InvokeAI.SeedWeights
): string => {
return seedWeights.reduce((acc, pair, i, arr) => {
const { seed, weight } = pair;
acc += `${seed}:${weight}`;
if (i !== arr.length - 1) {
acc += ',';
}
return acc;
}, '');
};
export const seedWeightsToArray = (
seedWeights: InvokeAI.SeedWeights
): Array<Array<number>> => {
return seedWeights.map((pair: InvokeAI.SeedWeightPair) => [
pair.seed,
pair.weight,
]);
};
export const stringToSeedWeightsArray = (
string: string
): Array<Array<number>> => {
const stringPairs = string.split(',');
const arrPairs = stringPairs.map((p) => p.split(':'));
return arrPairs.map(
(p: Array<string>): Array<number> => [parseInt(p[0]), parseFloat(p[1])]
);
};

View File

@ -1,16 +0,0 @@
import { Button, ButtonProps } from '@chakra-ui/react';
interface Props extends ButtonProps {
label: string;
}
const SDButton = (props: Props) => {
const { label, size = 'sm', ...rest } = props;
return (
<Button size={size} {...rest}>
{label}
</Button>
);
};
export default SDButton;

View File

@ -1,57 +0,0 @@
import {
Flex,
FormControl,
FormLabel,
Select,
SelectProps,
Text,
} from '@chakra-ui/react';
interface Props extends SelectProps {
label: string;
validValues:
| Array<number | string>
| Array<{ key: string; value: string | number }>;
}
const SDSelect = (props: Props) => {
const {
label,
isDisabled,
validValues,
size = 'sm',
fontSize = 'md',
marginBottom = 1,
whiteSpace = 'nowrap',
...rest
} = props;
return (
<FormControl isDisabled={isDisabled}>
<Flex justifyContent={'space-between'} alignItems={'center'}>
<FormLabel
marginBottom={marginBottom}
>
<Text fontSize={fontSize} whiteSpace={whiteSpace}>
{label}
</Text>
</FormLabel>
<Select fontSize={fontSize} size={size} {...rest}>
{validValues.map((opt) => {
return typeof opt === 'string' ||
typeof opt === 'number' ? (
<option key={opt} value={opt}>
{opt}
</option>
) : (
<option key={opt.value} value={opt.value}>
{opt.key}
</option>
);
})}
</Select>
</Flex>
</FormControl>
);
};
export default SDSelect;

View File

@ -1,161 +0,0 @@
import { Center, Flex, Image, useColorModeValue } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import { setAllParameters, setInitialImagePath, setSeed } from '../sd/sdSlice';
import { useState } from 'react';
import ImageMetadataViewer from './ImageMetadataViewer';
import DeleteImageModalButton from './DeleteImageModalButton';
import SDButton from '../../components/SDButton';
import { runESRGAN, runGFPGAN } from '../../app/socketio';
import { createSelector } from '@reduxjs/toolkit';
import { SystemState } from '../system/systemSlice';
import { isEqual } from 'lodash';
const height = 'calc(100vh - 238px)';
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isProcessing: system.isProcessing,
isConnected: system.isConnected,
isGFPGANAvailable: system.isGFPGANAvailable,
isESRGANAvailable: system.isESRGANAvailable,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const CurrentImage = () => {
const { currentImage, intermediateImage } = useAppSelector(
(state: RootState) => state.gallery
);
const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } =
useAppSelector(systemSelector);
const dispatch = useAppDispatch();
const bgColor = useColorModeValue(
'rgba(255, 255, 255, 0.85)',
'rgba(0, 0, 0, 0.8)'
);
const [shouldShowImageDetails, setShouldShowImageDetails] =
useState<boolean>(false);
const imageToDisplay = intermediateImage || currentImage;
return (
<Flex direction={'column'} rounded={'md'} borderWidth={1} p={2} gap={2}>
{imageToDisplay && (
<Flex gap={2}>
<SDButton
label='Use as initial image'
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
onClick={() =>
dispatch(setInitialImagePath(imageToDisplay.url))
}
/>
<SDButton
label='Use all'
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
onClick={() =>
dispatch(setAllParameters(imageToDisplay.metadata))
}
/>
<SDButton
label='Use seed'
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
isDisabled={!imageToDisplay.metadata.seed}
onClick={() =>
dispatch(setSeed(imageToDisplay.metadata.seed!))
}
/>
<SDButton
label='Upscale'
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
isDisabled={
!isESRGANAvailable ||
Boolean(intermediateImage) ||
!(isConnected && !isProcessing)
}
onClick={() => dispatch(runESRGAN(imageToDisplay))}
/>
<SDButton
label='Fix faces'
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
isDisabled={
!isGFPGANAvailable ||
Boolean(intermediateImage) ||
!(isConnected && !isProcessing)
}
onClick={() => dispatch(runGFPGAN(imageToDisplay))}
/>
<SDButton
label='Details'
colorScheme={'gray'}
variant={shouldShowImageDetails ? 'solid' : 'outline'}
borderWidth={1}
flexGrow={1}
onClick={() =>
setShouldShowImageDetails(!shouldShowImageDetails)
}
/>
<DeleteImageModalButton image={imageToDisplay}>
<SDButton
label='Delete'
colorScheme={'red'}
flexGrow={1}
variant={'outline'}
isDisabled={Boolean(intermediateImage)}
/>
</DeleteImageModalButton>
</Flex>
)}
<Center height={height} position={'relative'}>
{imageToDisplay && (
<Image
src={imageToDisplay.url}
fit='contain'
maxWidth={'100%'}
maxHeight={'100%'}
/>
)}
{imageToDisplay && shouldShowImageDetails && (
<Flex
width={'100%'}
height={'100%'}
position={'absolute'}
top={0}
left={0}
p={3}
boxSizing='border-box'
backgroundColor={bgColor}
overflow='scroll'
>
<ImageMetadataViewer image={imageToDisplay} />
</Flex>
)}
</Center>
</Flex>
);
};
export default CurrentImage;

View File

@ -0,0 +1,155 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import * as InvokeAI from '../../app/invokeai';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import {
setAllParameters,
setInitialImagePath,
setSeed,
} from '../options/optionsSlice';
import DeleteImageModal from './DeleteImageModal';
import { SystemState } from '../system/systemSlice';
import SDButton from '../../common/components/SDButton';
import { runESRGAN, runGFPGAN } from '../../app/socketio/actions';
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isProcessing: system.isProcessing,
isConnected: system.isConnected,
isGFPGANAvailable: system.isGFPGANAvailable,
isESRGANAvailable: system.isESRGANAvailable,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
type CurrentImageButtonsProps = {
image: InvokeAI.Image;
shouldShowImageDetails: boolean;
setShouldShowImageDetails: (b: boolean) => void;
};
/**
* Row of buttons for common actions:
* Use as init image, use all params, use seed, upscale, fix faces, details, delete.
*/
const CurrentImageButtons = ({
image,
shouldShowImageDetails,
setShouldShowImageDetails,
}: CurrentImageButtonsProps) => {
const dispatch = useAppDispatch();
const { intermediateImage } = useAppSelector(
(state: RootState) => state.gallery
);
const { upscalingLevel, gfpganStrength } = useAppSelector(
(state: RootState) => state.options
);
const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } =
useAppSelector(systemSelector);
const handleClickUseAsInitialImage = () =>
dispatch(setInitialImagePath(image.url));
const handleClickUseAllParameters = () =>
dispatch(setAllParameters(image.metadata));
// Non-null assertion: this button is disabled if there is no seed.
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const handleClickUseSeed = () => dispatch(setSeed(image.metadata.image.seed));
const handleClickUpscale = () => dispatch(runESRGAN(image));
const handleClickFixFaces = () => dispatch(runGFPGAN(image));
const handleClickShowImageDetails = () =>
setShouldShowImageDetails(!shouldShowImageDetails);
return (
<Flex gap={2}>
<SDButton
label="Use as initial image"
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
onClick={handleClickUseAsInitialImage}
/>
<SDButton
label="Use all"
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
isDisabled={!['txt2img', 'img2img'].includes(image.metadata.image.type)}
onClick={handleClickUseAllParameters}
/>
<SDButton
label="Use seed"
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
isDisabled={!image.metadata.image.seed}
onClick={handleClickUseSeed}
/>
<SDButton
label="Upscale"
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
isDisabled={
!isESRGANAvailable ||
Boolean(intermediateImage) ||
!(isConnected && !isProcessing) ||
!upscalingLevel
}
onClick={handleClickUpscale}
/>
<SDButton
label="Fix faces"
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
isDisabled={
!isGFPGANAvailable ||
Boolean(intermediateImage) ||
!(isConnected && !isProcessing) ||
!gfpganStrength
}
onClick={handleClickFixFaces}
/>
<SDButton
label="Details"
colorScheme={'gray'}
variant={shouldShowImageDetails ? 'solid' : 'outline'}
borderWidth={1}
flexGrow={1}
onClick={handleClickShowImageDetails}
/>
<DeleteImageModal image={image}>
<SDButton
label="Delete"
colorScheme={'red'}
flexGrow={1}
variant={'outline'}
isDisabled={Boolean(intermediateImage)}
/>
</DeleteImageModal>
</Flex>
);
};
export default CurrentImageButtons;

View File

@ -0,0 +1,67 @@
import { Center, Flex, Image, Text, useColorModeValue } from '@chakra-ui/react';
import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import { useState } from 'react';
import ImageMetadataViewer from './ImageMetadataViewer';
import CurrentImageButtons from './CurrentImageButtons';
// TODO: With CSS Grid I had a hard time centering the image in a grid item. This is needed for that.
const height = 'calc(100vh - 238px)';
/**
* Displays the current image if there is one, plus associated actions.
*/
const CurrentImageDisplay = () => {
const { currentImage, intermediateImage } = useAppSelector(
(state: RootState) => state.gallery
);
const bgColor = useColorModeValue(
'rgba(255, 255, 255, 0.85)',
'rgba(0, 0, 0, 0.8)'
);
const [shouldShowImageDetails, setShouldShowImageDetails] =
useState<boolean>(false);
const imageToDisplay = intermediateImage || currentImage;
return imageToDisplay ? (
<Flex direction={'column'} borderWidth={1} rounded={'md'} p={2} gap={2}>
<CurrentImageButtons
image={imageToDisplay}
shouldShowImageDetails={shouldShowImageDetails}
setShouldShowImageDetails={setShouldShowImageDetails}
/>
<Center height={height} position={'relative'}>
<Image
src={imageToDisplay.url}
fit="contain"
maxWidth={'100%'}
maxHeight={'100%'}
/>
{shouldShowImageDetails && (
<Flex
width={'100%'}
height={'100%'}
position={'absolute'}
top={0}
left={0}
p={3}
boxSizing="border-box"
backgroundColor={bgColor}
overflow="scroll"
>
<ImageMetadataViewer image={imageToDisplay} />
</Flex>
)}
</Center>
</Flex>
) : (
<Center height={'100%'} position={'relative'}>
<Text size={'xl'}>No image selected</Text>
</Center>
);
};
export default CurrentImageDisplay;

View File

@ -0,0 +1,125 @@
import {
Text,
AlertDialog,
AlertDialogBody,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogContent,
AlertDialogOverlay,
useDisclosure,
Button,
Switch,
FormControl,
FormLabel,
Flex,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import {
ChangeEvent,
cloneElement,
forwardRef,
ReactElement,
SyntheticEvent,
useRef,
} from 'react';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { deleteImage } from '../../app/socketio/actions';
import { RootState } from '../../app/store';
import { setShouldConfirmOnDelete, SystemState } from '../system/systemSlice';
import * as InvokeAI from '../../app/invokeai';
interface DeleteImageModalProps {
/**
* Component which, on click, should delete the image/open the modal.
*/
children: ReactElement;
/**
* The image to delete.
*/
image: InvokeAI.Image;
}
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => system.shouldConfirmOnDelete
);
/**
* Needs a child, which will act as the button to delete an image.
* If system.shouldConfirmOnDelete is true, a confirmation modal is displayed.
* If it is false, the image is deleted immediately.
* The confirmation modal has a "Don't ask me again" switch to set the boolean.
*/
const DeleteImageModal = forwardRef(
({ image, children }: DeleteImageModalProps, ref) => {
const { isOpen, onOpen, onClose } = useDisclosure();
const dispatch = useAppDispatch();
const shouldConfirmOnDelete = useAppSelector(systemSelector);
const cancelRef = useRef<HTMLButtonElement>(null);
const handleClickDelete = (e: SyntheticEvent) => {
e.stopPropagation();
shouldConfirmOnDelete ? onOpen() : handleDelete();
};
const handleDelete = () => {
dispatch(deleteImage(image));
onClose();
};
const handleChangeShouldConfirmOnDelete = (
e: ChangeEvent<HTMLInputElement>
) => dispatch(setShouldConfirmOnDelete(!e.target.checked));
return (
<>
{cloneElement(children, {
// TODO: This feels wrong.
onClick: handleClickDelete,
ref: ref,
})}
<AlertDialog
isOpen={isOpen}
leastDestructiveRef={cancelRef}
onClose={onClose}
>
<AlertDialogOverlay>
<AlertDialogContent>
<AlertDialogHeader fontSize="lg" fontWeight="bold">
Delete image
</AlertDialogHeader>
<AlertDialogBody>
<Flex direction={'column'} gap={5}>
<Text>
Are you sure? You can't undo this action afterwards.
</Text>
<FormControl>
<Flex alignItems={'center'}>
<FormLabel mb={0}>Don't ask me again</FormLabel>
<Switch
checked={!shouldConfirmOnDelete}
onChange={handleChangeShouldConfirmOnDelete}
/>
</Flex>
</FormControl>
</Flex>
</AlertDialogBody>
<AlertDialogFooter>
<Button ref={cancelRef} onClick={onClose}>
Cancel
</Button>
<Button colorScheme="red" onClick={handleDelete} ml={3}>
Delete
</Button>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialogOverlay>
</AlertDialog>
</>
);
}
);
export default DeleteImageModal;

View File

@ -1,94 +0,0 @@
import {
IconButtonProps,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
Text,
useDisclosure,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import {
cloneElement,
ReactElement,
SyntheticEvent,
} from 'react';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { deleteImage } from '../../app/socketio';
import { RootState } from '../../app/store';
import SDButton from '../../components/SDButton';
import { setShouldConfirmOnDelete, SystemState } from '../system/systemSlice';
import { SDImage } from './gallerySlice';
interface Props extends IconButtonProps {
image: SDImage;
'aria-label': string;
children: ReactElement;
}
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => system.shouldConfirmOnDelete
);
/*
TODO: The modal and button to open it should be two different components,
but their state is closely related and I'm not sure how best to accomplish it.
*/
const DeleteImageModalButton = (props: Omit<Props, 'aria-label'>) => {
const { isOpen, onOpen, onClose } = useDisclosure();
const dispatch = useAppDispatch();
const shouldConfirmOnDelete = useAppSelector(systemSelector);
const handleClickDelete = (e: SyntheticEvent) => {
e.stopPropagation();
shouldConfirmOnDelete ? onOpen() : handleDelete();
};
const { image, children } = props;
const handleDelete = () => {
dispatch(deleteImage(image));
onClose();
};
const handleDeleteAndDontAsk = () => {
dispatch(deleteImage(image));
dispatch(setShouldConfirmOnDelete(false));
onClose();
};
return (
<>
{cloneElement(children, {
onClick: handleClickDelete,
})}
<Modal isOpen={isOpen} onClose={onClose}>
<ModalOverlay />
<ModalContent>
<ModalHeader>Are you sure you want to delete this image?</ModalHeader>
<ModalCloseButton />
<ModalBody>
<Text>It will be deleted forever!</Text>
</ModalBody>
<ModalFooter justifyContent={'space-between'}>
<SDButton label={'Yes'} colorScheme='red' onClick={handleDelete} />
<SDButton
label={"Yes, and don't ask me again"}
colorScheme='red'
onClick={handleDeleteAndDontAsk}
/>
<SDButton label='Cancel' colorScheme='blue' onClick={onClose} />
</ModalFooter>
</ModalContent>
</Modal>
</>
);
};
export default DeleteImageModalButton;

View File

@ -0,0 +1,143 @@
import {
Box,
Flex,
Icon,
IconButton,
Image,
Tooltip,
useColorModeValue,
} from '@chakra-ui/react';
import { useAppDispatch } from '../../app/store';
import { setCurrentImage } from './gallerySlice';
import { FaCheck, FaSeedling, FaTrashAlt } from 'react-icons/fa';
import DeleteImageModal from './DeleteImageModal';
import { memo, SyntheticEvent, useState } from 'react';
import { setAllParameters, setSeed } from '../options/optionsSlice';
import * as InvokeAI from '../../app/invokeai';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
interface HoverableImageProps {
image: InvokeAI.Image;
isSelected: boolean;
}
const memoEqualityCheck = (
prev: HoverableImageProps,
next: HoverableImageProps
) => prev.image.uuid === next.image.uuid && prev.isSelected === next.isSelected;
/**
* Gallery image component with delete/use all/use seed buttons on hover.
*/
const HoverableImage = memo((props: HoverableImageProps) => {
const [isHovered, setIsHovered] = useState<boolean>(false);
const dispatch = useAppDispatch();
const checkColor = useColorModeValue('green.600', 'green.300');
const bgColor = useColorModeValue('gray.200', 'gray.700');
const bgGradient = useColorModeValue(
'radial-gradient(circle, rgba(255,255,255,0.7) 0%, rgba(255,255,255,0.7) 20%, rgba(0,0,0,0) 100%)',
'radial-gradient(circle, rgba(0,0,0,0.7) 0%, rgba(0,0,0,0.7) 20%, rgba(0,0,0,0) 100%)'
);
const { image, isSelected } = props;
const { url, uuid, metadata } = image;
const handleMouseOver = () => setIsHovered(true);
const handleMouseOut = () => setIsHovered(false);
const handleClickSetAllParameters = (e: SyntheticEvent) => {
e.stopPropagation();
dispatch(setAllParameters(metadata));
};
const handleClickSetSeed = (e: SyntheticEvent) => {
e.stopPropagation();
// Non-null assertion: this button is not rendered unless this exists
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
dispatch(setSeed(image.metadata.image.seed));
};
const handleClickImage = () => dispatch(setCurrentImage(image));
return (
<Box position={'relative'} key={uuid}>
<Image
width={120}
height={120}
objectFit="cover"
rounded={'md'}
src={url}
loading={'lazy'}
backgroundColor={bgColor}
/>
<Flex
cursor={'pointer'}
position={'absolute'}
top={0}
left={0}
rounded={'md'}
width="100%"
height="100%"
alignItems={'center'}
justifyContent={'center'}
background={isSelected ? bgGradient : undefined}
onClick={handleClickImage}
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
>
{isSelected && (
<Icon fill={checkColor} width={'50%'} height={'50%'} as={FaCheck} />
)}
{isHovered && (
<Flex
direction={'column'}
gap={1}
position={'absolute'}
top={1}
right={1}
>
<Tooltip label={'Delete image'}>
<DeleteImageModal image={image}>
<IconButton
colorScheme="red"
aria-label="Delete image"
icon={<FaTrashAlt />}
size="xs"
variant={'imageHoverIconButton'}
fontSize={14}
/>
</DeleteImageModal>
</Tooltip>
{['txt2img', 'img2img'].includes(image.metadata.image.type) && (
<Tooltip label="Use all parameters">
<IconButton
aria-label="Use all parameters"
icon={<IoArrowUndoCircleOutline />}
size="xs"
fontSize={18}
variant={'imageHoverIconButton'}
onClickCapture={handleClickSetAllParameters}
/>
</Tooltip>
)}
{image.metadata.image.seed && (
<Tooltip label="Use seed">
<IconButton
aria-label="Use seed"
icon={<FaSeedling />}
size="xs"
fontSize={16}
variant={'imageHoverIconButton'}
onClickCapture={handleClickSetSeed}
/>
</Tooltip>
)}
</Flex>
)}
</Flex>
</Box>
);
}, memoEqualityCheck);
export default HoverableImage;

View File

@ -0,0 +1,39 @@
import { Center, Flex, Text } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppSelector } from '../../app/store';
import HoverableImage from './HoverableImage';
/**
* Simple image gallery.
*/
const ImageGallery = () => {
const { images, currentImageUuid } = useAppSelector(
(state: RootState) => state.gallery
);
/**
* I don't like that this needs to rerender whenever the current image is changed.
* What if we have a large number of images? I suppose pagination (planned) will
* mitigate this issue.
*
* TODO: Refactor if performance complaints, or after migrating to new API which supports pagination.
*/
return images.length ? (
<Flex gap={2} wrap="wrap" pb={2}>
{[...images].reverse().map((image) => {
const { uuid } = image;
const isSelected = currentImageUuid === uuid;
return (
<HoverableImage key={uuid} image={image} isSelected={isSelected} />
);
})}
</Flex>
) : (
<Center height={'100%'} position={'relative'}>
<Text size={'xl'}>No images in gallery</Text>
</Center>
);
};
export default ImageGallery;

View File

@ -1,124 +1,326 @@
import { import {
Center, Box,
Flex, Center,
IconButton, Flex,
Link, IconButton,
List, Link,
ListItem, Text,
Text, Tooltip,
useColorModeValue,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { FaPlus } from 'react-icons/fa'; import { ExternalLinkIcon } from '@chakra-ui/icons';
import { PARAMETERS } from '../../app/constants'; import { memo } from 'react';
import { useAppDispatch } from '../../app/hooks'; import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import SDButton from '../../components/SDButton'; import { useAppDispatch } from '../../app/store';
import { setAllParameters, setParameter } from '../sd/sdSlice'; import * as InvokeAI from '../../app/invokeai';
import { SDImage, SDMetadata } from './gallerySlice'; import {
setCfgScale,
setGfpganStrength,
setHeight,
setImg2imgStrength,
setInitialImagePath,
setMaskPath,
setPrompt,
setSampler,
setSeed,
setSeedWeights,
setShouldFitToWidthHeight,
setSteps,
setUpscalingLevel,
setUpscalingStrength,
setWidth,
} from '../options/optionsSlice';
import promptToString from '../../common/util/promptToString';
import { seedWeightsToString } from '../../common/util/seedWeightPairs';
import { FaCopy } from 'react-icons/fa';
type Props = { type MetadataItemProps = {
image: SDImage; isLink?: boolean;
label: string;
onClick?: () => void;
value: number | string | boolean;
}; };
const ImageMetadataViewer = ({ image }: Props) => { /**
const dispatch = useAppDispatch(); * Component to display an individual metadata item or parameter.
*/
const MetadataItem = ({ label, value, onClick, isLink }: MetadataItemProps) => {
return (
<Flex gap={2}>
{onClick && (
<Tooltip label={`Recall ${label}`}>
<IconButton
aria-label="Use this parameter"
icon={<IoArrowUndoCircleOutline />}
size={'xs'}
variant={'ghost'}
fontSize={20}
onClick={onClick}
/>
</Tooltip>
)}
<Text fontWeight={'semibold'} whiteSpace={'nowrap'}>
{label}:
</Text>
{isLink ? (
<Link href={value.toString()} isExternal wordBreak={'break-all'}>
{value.toString()} <ExternalLinkIcon mx="2px" />
</Link>
) : (
<Text maxHeight={100} overflowY={'scroll'} wordBreak={'break-all'}>
{value.toString()}
</Text>
)}
</Flex>
);
};
const keys = Object.keys(PARAMETERS); type ImageMetadataViewerProps = {
image: InvokeAI.Image;
};
const metadata: Array<{ // TODO: I don't know if this is needed.
label: string; const memoEqualityCheck = (
key: string; prev: ImageMetadataViewerProps,
value: string | number | boolean; next: ImageMetadataViewerProps
}> = []; ) => prev.image.uuid === next.image.uuid;
keys.forEach((key) => { // TODO: Show more interesting information in this component.
const value = image.metadata[key as keyof SDMetadata];
if (value !== undefined) {
metadata.push({ label: PARAMETERS[key], key, value });
}
});
return ( /**
<Flex gap={2} direction={'column'} overflowY={'scroll'} width={'100%'}> * Image metadata viewer overlays currently selected image and provides
<SDButton * access to any of its metadata for use in processing.
label='Use all parameters' */
colorScheme={'gray'} const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
padding={2} const dispatch = useAppDispatch();
isDisabled={metadata.length === 0} const jsonBgColor = useColorModeValue('blackAlpha.100', 'whiteAlpha.100');
onClick={() => dispatch(setAllParameters(image.metadata))}
const metadata = image.metadata.image;
const {
type,
postprocessing,
sampler,
prompt,
seed,
variations,
steps,
cfg_scale,
seamless,
width,
height,
strength,
fit,
init_image_path,
mask_image_path,
orig_path,
scale,
} = metadata;
const metadataJSON = JSON.stringify(metadata, null, 2);
return (
<Flex
gap={1}
direction={'column'}
overflowY={'scroll'}
width={'100%'}
>
<Flex gap={2}>
<Text fontWeight={'semibold'}>File:</Text>
<Link href={image.url} isExternal>
{image.url}
<ExternalLinkIcon mx="2px" />
</Link>
</Flex>
{Object.keys(metadata).length ? (
<>
{type && <MetadataItem label="Type" value={type} />}
{['esrgan', 'gfpgan'].includes(type) && (
<MetadataItem label="Original image" value={orig_path} isLink />
)}
{type === 'gfpgan' && strength && (
<MetadataItem
label="Fix faces strength"
value={strength}
onClick={() => dispatch(setGfpganStrength(strength))}
/> />
<Flex gap={2}> )}
<Text fontWeight={'semibold'}>File:</Text> {type === 'esrgan' && scale && (
<Link href={image.url} isExternal> <MetadataItem
<Text>{image.url}</Text> label="Upscaling scale"
</Link> value={scale}
</Flex> onClick={() => dispatch(setUpscalingLevel(scale))}
{metadata.length ? ( />
<> )}
<List> {type === 'esrgan' && strength && (
{metadata.map((parameter, i) => { <MetadataItem
const { label, key, value } = parameter; label="Upscaling strength"
return ( value={strength}
<ListItem key={i} pb={1}> onClick={() => dispatch(setUpscalingStrength(strength))}
<Flex gap={2}> />
<IconButton )}
aria-label='Use this parameter' {prompt && (
icon={<FaPlus />} <MetadataItem
size={'xs'} label="Prompt"
onClick={() => value={promptToString(prompt)}
dispatch( onClick={() => dispatch(setPrompt(prompt))}
setParameter({ />
key, )}
value, {seed && (
}) <MetadataItem
) label="Seed"
} value={seed}
/> onClick={() => dispatch(setSeed(seed))}
<Text fontWeight={'semibold'}> />
{label}: )}
</Text> {sampler && (
<MetadataItem
{value === undefined || label="Sampler"
value === null || value={sampler}
value === '' || onClick={() => dispatch(setSampler(sampler))}
value === 0 ? ( />
<Text )}
maxHeight={100} {steps && (
fontStyle={'italic'} <MetadataItem
> label="Steps"
None value={steps}
</Text> onClick={() => dispatch(setSteps(steps))}
) : ( />
<Text )}
maxHeight={100} {cfg_scale && (
overflowY={'scroll'} <MetadataItem
> label="CFG scale"
{value.toString()} value={cfg_scale}
</Text> onClick={() => dispatch(setCfgScale(cfg_scale))}
)} />
</Flex> )}
</ListItem> {variations && variations.length > 0 && (
); <MetadataItem
})} label="Seed-weight pairs"
</List> value={seedWeightsToString(variations)}
<Flex gap={2}> onClick={() =>
<Text fontWeight={'semibold'}>Raw:</Text> dispatch(setSeedWeights(seedWeightsToString(variations)))
<Text }
maxHeight={100} />
overflowY={'scroll'} )}
wordBreak={'break-all'} {seamless && (
> <MetadataItem
{JSON.stringify(image.metadata)} label="Seamless"
</Text> value={seamless}
</Flex> onClick={() => dispatch(setWidth(seamless))}
</> />
) : ( )}
<Center width={'100%'} pt={10}> {width && (
<Text fontSize={'lg'} fontWeight='semibold'> <MetadataItem
No metadata available label="Width"
</Text> value={width}
</Center> onClick={() => dispatch(setWidth(width))}
/>
)}
{height && (
<MetadataItem
label="Height"
value={height}
onClick={() => dispatch(setHeight(height))}
/>
)}
{init_image_path && (
<MetadataItem
label="Initial image"
value={init_image_path}
isLink
onClick={() => dispatch(setInitialImagePath(init_image_path))}
/>
)}
{mask_image_path && (
<MetadataItem
label="Mask image"
value={mask_image_path}
isLink
onClick={() => dispatch(setMaskPath(mask_image_path))}
/>
)}
{type === 'img2img' && strength && (
<MetadataItem
label="Image to image strength"
value={strength}
onClick={() => dispatch(setImg2imgStrength(strength))}
/>
)}
{fit && (
<MetadataItem
label="Image to image fit"
value={fit}
onClick={() => dispatch(setShouldFitToWidthHeight(fit))}
/>
)}
{postprocessing &&
postprocessing.length > 0 &&
postprocessing.map(
(postprocess: InvokeAI.PostProcessedImageMetadata) => {
if (postprocess.type === 'esrgan') {
const { scale, strength } = postprocess;
return (
<>
<MetadataItem
label="Upscaling scale"
value={scale}
onClick={() => dispatch(setUpscalingLevel(scale))}
/>
<MetadataItem
label="Upscaling strength"
value={strength}
onClick={() => dispatch(setUpscalingStrength(strength))}
/>
</>
);
} else if (postprocess.type === 'gfpgan') {
const { strength } = postprocess;
return (
<MetadataItem
label="Fix faces strength"
value={strength}
onClick={() => dispatch(setGfpganStrength(strength))}
/>
);
}
}
)} )}
</Flex> <Flex gap={2} direction={'column'}>
); <Flex gap={2}>
}; <Tooltip label={`Copy JSON`}>
<IconButton
aria-label="Copy JSON"
icon={<FaCopy />}
size={'xs'}
variant={'ghost'}
fontSize={14}
onClick={() => navigator.clipboard.writeText(metadataJSON)}
/>
</Tooltip>
<Text fontWeight={'semibold'}>JSON:</Text>
</Flex>
<Box
// maxHeight={200}
overflow={'scroll'}
flexGrow={3}
wordBreak={'break-all'}
bgColor={jsonBgColor}
padding={2}
>
<pre>{metadataJSON}</pre>
</Box>
</Flex>
</>
) : (
<Center width={'100%'} pt={10}>
<Text fontSize={'lg'} fontWeight="semibold">
No metadata available
</Text>
</Center>
)}
</Flex>
);
}, memoEqualityCheck);
export default ImageMetadataViewer; export default ImageMetadataViewer;

View File

@ -1,150 +0,0 @@
import {
Box,
Flex,
Icon,
IconButton,
Image,
useColorModeValue,
} from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { SDImage, setCurrentImage } from './gallerySlice';
import { FaCheck, FaCopy, FaSeedling, FaTrash } from 'react-icons/fa';
import DeleteImageModalButton from './DeleteImageModalButton';
import { memo, SyntheticEvent, useState } from 'react';
import { setAllParameters, setSeed } from '../sd/sdSlice';
interface HoverableImageProps {
image: SDImage;
isSelected: boolean;
}
const HoverableImage = memo(
(props: HoverableImageProps) => {
const [isHovered, setIsHovered] = useState<boolean>(false);
const dispatch = useAppDispatch();
const checkColor = useColorModeValue('green.600', 'green.300');
const bgColor = useColorModeValue('gray.200', 'gray.700');
const bgGradient = useColorModeValue(
'radial-gradient(circle, rgba(255,255,255,0.7) 0%, rgba(255,255,255,0.7) 20%, rgba(0,0,0,0) 100%)',
'radial-gradient(circle, rgba(0,0,0,0.7) 0%, rgba(0,0,0,0.7) 20%, rgba(0,0,0,0) 100%)'
);
const { image, isSelected } = props;
const { url, uuid, metadata } = image;
const handleMouseOver = () => setIsHovered(true);
const handleMouseOut = () => setIsHovered(false);
const handleClickSetAllParameters = (e: SyntheticEvent) => {
e.stopPropagation();
dispatch(setAllParameters(metadata));
};
const handleClickSetSeed = (e: SyntheticEvent) => {
e.stopPropagation();
dispatch(setSeed(image.metadata.seed!)); // component not rendered unless this exists
};
return (
<Box position={'relative'} key={uuid}>
<Image
width={120}
height={120}
objectFit='cover'
rounded={'md'}
src={url}
loading={'lazy'}
backgroundColor={bgColor}
/>
<Flex
cursor={'pointer'}
position={'absolute'}
top={0}
left={0}
rounded={'md'}
width='100%'
height='100%'
alignItems={'center'}
justifyContent={'center'}
background={isSelected ? bgGradient : undefined}
onClick={() => dispatch(setCurrentImage(image))}
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
>
{isSelected && (
<Icon
fill={checkColor}
width={'50%'}
height={'50%'}
as={FaCheck}
/>
)}
{isHovered && (
<Flex
direction={'column'}
gap={1}
position={'absolute'}
top={1}
right={1}
>
<DeleteImageModalButton image={image}>
<IconButton
colorScheme='red'
aria-label='Delete image'
icon={<FaTrash />}
size='xs'
fontSize={15}
/>
</DeleteImageModalButton>
<IconButton
aria-label='Use all parameters'
colorScheme={'blue'}
icon={<FaCopy />}
size='xs'
fontSize={15}
onClickCapture={handleClickSetAllParameters}
/>
{image.metadata.seed && (
<IconButton
aria-label='Use seed'
colorScheme={'blue'}
icon={<FaSeedling />}
size='xs'
fontSize={16}
onClickCapture={handleClickSetSeed}
/>
)}
</Flex>
)}
</Flex>
</Box>
);
},
(prev, next) =>
prev.image.uuid === next.image.uuid &&
prev.isSelected === next.isSelected
);
const ImageRoll = () => {
const { images, currentImageUuid } = useAppSelector(
(state: RootState) => state.gallery
);
return (
<Flex gap={2} wrap='wrap' pb={2}>
{[...images].reverse().map((image) => {
const { uuid } = image;
const isSelected = currentImageUuid === uuid;
return (
<HoverableImage
key={uuid}
image={image}
isSelected={isSelected}
/>
);
})}
</Flex>
);
};
export default ImageRoll;

View File

@ -1,40 +1,13 @@
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { v4 as uuidv4 } from 'uuid'; import { clamp } from 'lodash';
import { UpscalingLevel } from '../sd/sdSlice'; import * as InvokeAI from '../../app/invokeai';
import { backendToFrontendParameters } from '../../app/parameterTranslation';
// TODO: Revise pending metadata RFC: https://github.com/lstein/stable-diffusion/issues/266
export interface SDMetadata {
prompt?: string;
steps?: number;
cfgScale?: number;
height?: number;
width?: number;
sampler?: string;
seed?: number;
img2imgStrength?: number;
gfpganStrength?: number;
upscalingLevel?: UpscalingLevel;
upscalingStrength?: number;
initialImagePath?: string;
maskPath?: string;
seamless?: boolean;
shouldFitToWidthHeight?: boolean;
}
export interface SDImage {
// TODO: I have installed @types/uuid but cannot figure out how to use them here.
uuid: string;
url: string;
metadata: SDMetadata;
}
export interface GalleryState { export interface GalleryState {
currentImage?: InvokeAI.Image;
currentImageUuid: string; currentImageUuid: string;
images: Array<SDImage>; images: Array<InvokeAI.Image>;
intermediateImage?: SDImage; intermediateImage?: InvokeAI.Image;
currentImage?: SDImage;
} }
const initialState: GalleryState = { const initialState: GalleryState = {
@ -46,99 +19,84 @@ export const gallerySlice = createSlice({
name: 'gallery', name: 'gallery',
initialState, initialState,
reducers: { reducers: {
setCurrentImage: (state, action: PayloadAction<SDImage>) => { setCurrentImage: (state, action: PayloadAction<InvokeAI.Image>) => {
state.currentImage = action.payload; state.currentImage = action.payload;
state.currentImageUuid = action.payload.uuid; state.currentImageUuid = action.payload.uuid;
}, },
removeImage: (state, action: PayloadAction<SDImage>) => { removeImage: (state, action: PayloadAction<string>) => {
const { uuid } = action.payload; const uuid = action.payload;
const newImages = state.images.filter((image) => image.uuid !== uuid); const newImages = state.images.filter((image) => image.uuid !== uuid);
const imageToDeleteIndex = state.images.findIndex( if (uuid === state.currentImageUuid) {
(image) => image.uuid === uuid /**
); * We are deleting the currently selected image.
*
* We want the new currentl selected image to be under the cursor in the
* gallery, so we need to do some fanagling. The currently selected image
* is set by its UUID, not its index in the image list.
*
* Get the currently selected image's index.
*/
const imageToDeleteIndex = state.images.findIndex(
(image) => image.uuid === uuid
);
const newCurrentImageIndex = Math.min( /**
Math.max(imageToDeleteIndex, 0), * New current image needs to be in the same spot, but because the gallery
newImages.length - 1 * is sorted in reverse order, the new current image's index will actuall be
); * one less than the deleted image's index.
*
* Clamp the new index to ensure it is valid..
*/
const newCurrentImageIndex = clamp(
imageToDeleteIndex - 1,
0,
newImages.length - 1
);
state.currentImage = newImages.length
? newImages[newCurrentImageIndex]
: undefined;
state.currentImageUuid = newImages.length
? newImages[newCurrentImageIndex].uuid
: '';
}
state.images = newImages; state.images = newImages;
state.currentImage = newImages.length
? newImages[newCurrentImageIndex]
: undefined;
state.currentImageUuid = newImages.length
? newImages[newCurrentImageIndex].uuid
: '';
}, },
addImage: (state, action: PayloadAction<SDImage>) => { addImage: (state, action: PayloadAction<InvokeAI.Image>) => {
state.images.push(action.payload); state.images.push(action.payload);
state.currentImageUuid = action.payload.uuid; state.currentImageUuid = action.payload.uuid;
state.intermediateImage = undefined; state.intermediateImage = undefined;
state.currentImage = action.payload; state.currentImage = action.payload;
}, },
setIntermediateImage: (state, action: PayloadAction<SDImage>) => { setIntermediateImage: (state, action: PayloadAction<InvokeAI.Image>) => {
state.intermediateImage = action.payload; state.intermediateImage = action.payload;
}, },
clearIntermediateImage: (state) => { clearIntermediateImage: (state) => {
state.intermediateImage = undefined; state.intermediateImage = undefined;
}, },
setGalleryImages: ( setGalleryImages: (state, action: PayloadAction<Array<InvokeAI.Image>>) => {
state, const newImages = action.payload;
action: PayloadAction< if (newImages.length) {
Array<{ const newCurrentImage = newImages[newImages.length - 1];
path: string;
metadata: { [key: string]: string | number | boolean };
}>
>
) => {
// TODO: Revise pending metadata RFC: https://github.com/lstein/stable-diffusion/issues/266
const images = action.payload;
if (images.length === 0) {
// there are no images on disk, clear the gallery
state.images = [];
state.currentImageUuid = '';
state.currentImage = undefined;
} else {
// Filter image urls that are already in the rehydrated state
const filteredImages = action.payload.filter(
(image) => !state.images.find((i) => i.url === image.path)
);
const preparedImages = filteredImages.map((image): SDImage => {
return {
uuid: uuidv4(),
url: image.path,
metadata: backendToFrontendParameters(image.metadata),
};
});
const newImages = [...state.images].concat(preparedImages);
// if previous currentimage no longer exists, set a new one
if (!newImages.find((image) => image.uuid === state.currentImageUuid)) {
const newCurrentImage = newImages[newImages.length - 1];
state.currentImage = newCurrentImage;
state.currentImageUuid = newCurrentImage.uuid;
}
state.images = newImages; state.images = newImages;
state.currentImage = newCurrentImage;
state.currentImageUuid = newCurrentImage.uuid;
} }
}, },
}, },
}); });
export const { export const {
setCurrentImage,
removeImage,
addImage, addImage,
clearIntermediateImage,
removeImage,
setCurrentImage,
setGalleryImages, setGalleryImages,
setIntermediateImage, setIntermediateImage,
clearIntermediateImage,
} = gallerySlice.actions; } = gallerySlice.actions;
export default gallerySlice.reducer; export default gallerySlice.reducer;

View File

@ -1,35 +0,0 @@
import { Progress } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import { SDState } from '../sd/sdSlice';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
realSteps: sd.realSteps,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const ProgressBar = () => {
const { realSteps } = useAppSelector(sdSelector);
const { currentStep } = useAppSelector((state: RootState) => state.system);
const progress = Math.round((currentStep * 100) / realSteps);
return (
<Progress
height='10px'
value={progress}
isIndeterminate={progress < 0 || currentStep === realSteps}
/>
);
};
export default ProgressBar;

View File

@ -1,93 +0,0 @@
import {
Flex,
Heading,
IconButton,
Link,
Spacer,
Text,
useColorMode,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { FaSun, FaMoon, FaGithub } from 'react-icons/fa';
import { MdHelp, MdSettings } from 'react-icons/md';
import { useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import SettingsModal from '../system/SettingsModal';
import { SystemState } from '../system/systemSlice';
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return { isConnected: system.isConnected };
},
{
memoizeOptions: { resultEqualityCheck: isEqual },
}
);
const SiteHeader = () => {
const { colorMode, toggleColorMode } = useColorMode();
const { isConnected } = useAppSelector(systemSelector);
return (
<Flex minWidth='max-content' alignItems='center' gap='1' pl={2} pr={1}>
<Heading size={'lg'}>Stable Diffusion Dream Server</Heading>
<Spacer />
<Text textColor={isConnected ? 'green.500' : 'red.500'}>
{isConnected ? `Connected to server` : 'No connection to server'}
</Text>
<SettingsModal>
<IconButton
aria-label='Settings'
variant='link'
fontSize={24}
size={'sm'}
icon={<MdSettings />}
/>
</SettingsModal>
<IconButton
aria-label='Link to Github Issues'
variant='link'
fontSize={23}
size={'sm'}
icon={
<Link
isExternal
href='http://github.com/lstein/stable-diffusion/issues'
>
<MdHelp />
</Link>
}
/>
<IconButton
aria-label='Link to Github Repo'
variant='link'
fontSize={20}
size={'sm'}
icon={
<Link isExternal href='http://github.com/lstein/stable-diffusion'>
<FaGithub />
</Link>
}
/>
<IconButton
aria-label='Toggle Dark Mode'
onClick={toggleColorMode}
variant='link'
size={'sm'}
fontSize={colorMode == 'light' ? 18 : 20}
icon={colorMode == 'light' ? <FaMoon /> : <FaSun />}
/>
</Flex>
);
};
export default SiteHeader;

View File

@ -0,0 +1,87 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/store';
import {
setUpscalingLevel,
setUpscalingStrength,
UpscalingLevel,
OptionsState,
} from '../options/optionsSlice';
import { UPSCALING_LEVELS } from '../../app/constants';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { SystemState } from '../system/systemSlice';
import { ChangeEvent } from 'react';
import SDNumberInput from '../../common/components/SDNumberInput';
import SDSelect from '../../common/components/SDSelect';
const optionsSelector = createSelector(
(state: RootState) => state.options,
(options: OptionsState) => {
return {
upscalingLevel: options.upscalingLevel,
upscalingStrength: options.upscalingStrength,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isESRGANAvailable: system.isESRGANAvailable,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/**
* Displays upscaling/ESRGAN options (level and strength).
*/
const ESRGANOptions = () => {
const dispatch = useAppDispatch();
const { upscalingLevel, upscalingStrength } = useAppSelector(optionsSelector);
const { isESRGANAvailable } = useAppSelector(systemSelector);
const handleChangeLevel = (e: ChangeEvent<HTMLSelectElement>) =>
dispatch(setUpscalingLevel(Number(e.target.value) as UpscalingLevel));
const handleChangeStrength = (v: string | number) =>
dispatch(setUpscalingStrength(Number(v)));
return (
<Flex direction={'column'} gap={2}>
<SDSelect
isDisabled={!isESRGANAvailable}
label="Scale"
value={upscalingLevel}
onChange={handleChangeLevel}
validValues={UPSCALING_LEVELS}
/>
<SDNumberInput
isDisabled={!isESRGANAvailable}
label="Strength"
step={0.05}
min={0}
max={1}
onChange={handleChangeStrength}
value={upscalingStrength}
/>
</Flex>
);
};
export default ESRGANOptions;

View File

@ -0,0 +1,68 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { OptionsState, setGfpganStrength } from '../options/optionsSlice';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { SystemState } from '../system/systemSlice';
import SDNumberInput from '../../common/components/SDNumberInput';
const optionsSelector = createSelector(
(state: RootState) => state.options,
(options: OptionsState) => {
return {
gfpganStrength: options.gfpganStrength,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isGFPGANAvailable: system.isGFPGANAvailable,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/**
* Displays face-fixing/GFPGAN options (strength).
*/
const GFPGANOptions = () => {
const dispatch = useAppDispatch();
const { gfpganStrength } = useAppSelector(optionsSelector);
const { isGFPGANAvailable } = useAppSelector(systemSelector);
const handleChangeStrength = (v: string | number) =>
dispatch(setGfpganStrength(Number(v)));
return (
<Flex direction={'column'} gap={2}>
<SDNumberInput
isDisabled={!isGFPGANAvailable}
label="Strength"
step={0.05}
min={0}
max={1}
onChange={handleChangeStrength}
value={gfpganStrength}
/>
</Flex>
);
};
export default GFPGANOptions;

View File

@ -0,0 +1,59 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { ChangeEvent } from 'react';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import SDNumberInput from '../../common/components/SDNumberInput';
import SDSwitch from '../../common/components/SDSwitch';
import InitAndMaskImage from './InitAndMaskImage';
import {
OptionsState,
setImg2imgStrength,
setShouldFitToWidthHeight,
} from './optionsSlice';
const optionsSelector = createSelector(
(state: RootState) => state.options,
(options: OptionsState) => {
return {
img2imgStrength: options.img2imgStrength,
shouldFitToWidthHeight: options.shouldFitToWidthHeight,
};
}
);
/**
* Options for img2img generation (strength, fit, init/mask upload).
*/
const ImageToImageOptions = () => {
const dispatch = useAppDispatch();
const { img2imgStrength, shouldFitToWidthHeight } =
useAppSelector(optionsSelector);
const handleChangeStrength = (v: string | number) =>
dispatch(setImg2imgStrength(Number(v)));
const handleChangeFit = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldFitToWidthHeight(e.target.checked));
return (
<Flex direction={'column'} gap={2}>
<SDNumberInput
label="Strength"
step={0.01}
min={0}
max={1}
onChange={handleChangeStrength}
value={img2imgStrength}
/>
<SDSwitch
label="Fit initial image to output size"
isChecked={shouldFitToWidthHeight}
onChange={handleChangeFit}
/>
<InitAndMaskImage />
</Flex>
);
};
export default ImageToImageOptions;

View File

@ -0,0 +1,64 @@
import { Box } from '@chakra-ui/react';
import { cloneElement, ReactElement, SyntheticEvent, useCallback } from 'react';
import { FileRejection, useDropzone } from 'react-dropzone';
type ImageUploaderProps = {
/**
* Component which, on click, should open the upload interface.
*/
children: ReactElement;
/**
* Callback to handle uploading the selected file.
*/
fileAcceptedCallback: (file: File) => void;
/**
* Callback to handle a file being rejected.
*/
fileRejectionCallback: (rejection: FileRejection) => void;
};
/**
* File upload using react-dropzone.
* Needs a child to be the button to activate the upload interface.
*/
const ImageUploader = ({
children,
fileAcceptedCallback,
fileRejectionCallback,
}: ImageUploaderProps) => {
const onDrop = useCallback(
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
fileRejections.forEach((rejection: FileRejection) => {
fileRejectionCallback(rejection);
});
acceptedFiles.forEach((file: File) => {
fileAcceptedCallback(file);
});
},
[fileAcceptedCallback, fileRejectionCallback]
);
const { getRootProps, getInputProps, open } = useDropzone({
onDrop,
accept: {
'image/jpeg': ['.jpg', '.jpeg', '.png'],
},
});
const handleClickUploadIcon = (e: SyntheticEvent) => {
e.stopPropagation();
open();
};
return (
<Box {...getRootProps()} flexGrow={3}>
<input {...getInputProps({ multiple: false })} />
{cloneElement(children, {
onClick: handleClickUploadIcon,
})}
</Box>
);
};
export default ImageUploader;

View File

@ -0,0 +1,57 @@
import { Flex, Image } from '@chakra-ui/react';
import { useState } from 'react';
import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import { OptionsState } from '../../features/options/optionsSlice';
import './InitAndMaskImage.css';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import InitAndMaskUploadButtons from './InitAndMaskUploadButtons';
const optionsSelector = createSelector(
(state: RootState) => state.options,
(options: OptionsState) => {
return {
initialImagePath: options.initialImagePath,
maskPath: options.maskPath,
};
},
{ memoizeOptions: { resultEqualityCheck: isEqual } }
);
/**
* Displays init and mask images and buttons to upload/delete them.
*/
const InitAndMaskImage = () => {
const { initialImagePath, maskPath } = useAppSelector(optionsSelector);
const [shouldShowMask, setShouldShowMask] = useState<boolean>(false);
return (
<Flex direction={'column'} alignItems={'center'} gap={2}>
<InitAndMaskUploadButtons setShouldShowMask={setShouldShowMask} />
{initialImagePath && (
<Flex position={'relative'} width={'100%'}>
<Image
fit={'contain'}
src={initialImagePath}
rounded={'md'}
className={'checkerboard'}
/>
{shouldShowMask && maskPath && (
<Image
position={'absolute'}
top={0}
left={0}
fit={'contain'}
src={maskPath}
rounded={'md'}
zIndex={1}
/>
)}
</Flex>
)}
</Flex>
);
};
export default InitAndMaskImage;

View File

@ -0,0 +1,151 @@
import { Button, Flex, IconButton, useToast } from '@chakra-ui/react';
import { SyntheticEvent, useCallback } from 'react';
import { FaTrash, FaUpload } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import {
OptionsState,
setInitialImagePath,
setMaskPath,
} from '../../features/options/optionsSlice';
import {
uploadInitialImage,
uploadMaskImage,
} from '../../app/socketio/actions';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import ImageUploader from './ImageUploader';
import { FileRejection } from 'react-dropzone';
const optionsSelector = createSelector(
(state: RootState) => state.options,
(options: OptionsState) => {
return {
initialImagePath: options.initialImagePath,
maskPath: options.maskPath,
};
},
{ memoizeOptions: { resultEqualityCheck: isEqual } }
);
type InitAndMaskUploadButtonsProps = {
setShouldShowMask: (b: boolean) => void;
};
/**
* Init and mask image upload buttons.
*/
const InitAndMaskUploadButtons = ({
setShouldShowMask,
}: InitAndMaskUploadButtonsProps) => {
const dispatch = useAppDispatch();
const { initialImagePath, maskPath } = useAppSelector(optionsSelector);
// Use a toast to alert user when a file upload is rejected
const toast = useToast();
// Clear the init and mask images
const handleClickResetInitialImage = (e: SyntheticEvent) => {
e.stopPropagation();
dispatch(setInitialImagePath(''));
};
// Clear the init and mask images
const handleClickResetMask = (e: SyntheticEvent) => {
e.stopPropagation();
dispatch(setMaskPath(''));
};
// Handle hover to view initial image and mask image
const handleMouseOverInitialImageUploadButton = () =>
setShouldShowMask(false);
const handleMouseOutInitialImageUploadButton = () => setShouldShowMask(true);
const handleMouseOverMaskUploadButton = () => setShouldShowMask(true);
const handleMouseOutMaskUploadButton = () => setShouldShowMask(true);
// Callbacks to for handling file upload attempts
const initImageFileAcceptedCallback = useCallback(
(file: File) => dispatch(uploadInitialImage(file)),
[dispatch]
);
const maskImageFileAcceptedCallback = useCallback(
(file: File) => dispatch(uploadMaskImage(file)),
[dispatch]
);
const fileRejectionCallback = useCallback(
(rejection: FileRejection) => {
const msg = rejection.errors.reduce(
(acc: string, cur: { message: string }) => acc + '\n' + cur.message,
''
);
toast({
title: 'Upload failed',
description: msg,
status: 'error',
isClosable: true,
});
},
[toast]
);
return (
<Flex gap={2} justifyContent={'space-between'} width={'100%'}>
<ImageUploader
fileAcceptedCallback={initImageFileAcceptedCallback}
fileRejectionCallback={fileRejectionCallback}
>
<Button
size={'sm'}
fontSize={'md'}
fontWeight={'normal'}
onMouseOver={handleMouseOverInitialImageUploadButton}
onMouseOut={handleMouseOutInitialImageUploadButton}
leftIcon={<FaUpload />}
width={'100%'}
>
Image
</Button>
</ImageUploader>
<IconButton
isDisabled={!initialImagePath}
size={'sm'}
aria-label={'Reset mask'}
onClick={handleClickResetInitialImage}
icon={<FaTrash />}
/>
<ImageUploader
fileAcceptedCallback={maskImageFileAcceptedCallback}
fileRejectionCallback={fileRejectionCallback}
>
<Button
isDisabled={!initialImagePath}
size={'sm'}
fontSize={'md'}
fontWeight={'normal'}
onMouseOver={handleMouseOverMaskUploadButton}
onMouseOut={handleMouseOutMaskUploadButton}
leftIcon={<FaUpload />}
width={'100%'}
>
Mask
</Button>
</ImageUploader>
<IconButton
isDisabled={!maskPath}
size={'sm'}
aria-label={'Reset mask'}
onClick={handleClickResetMask}
icon={<FaTrash />}
/>
</Flex>
);
};
export default InitAndMaskUploadButtons;

View File

@ -0,0 +1,217 @@
import {
Flex,
Box,
Text,
Accordion,
AccordionItem,
AccordionButton,
AccordionIcon,
AccordionPanel,
Switch,
ExpandedIndex,
} from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/store';
import {
setShouldRunGFPGAN,
setShouldRunESRGAN,
OptionsState,
setShouldUseInitImage,
} from '../options/optionsSlice';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { setOpenAccordions, SystemState } from '../system/systemSlice';
import SeedVariationOptions from './SeedVariationOptions';
import SamplerOptions from './SamplerOptions';
import ESRGANOptions from './ESRGANOptions';
import GFPGANOptions from './GFPGANOptions';
import OutputOptions from './OutputOptions';
import ImageToImageOptions from './ImageToImageOptions';
import { ChangeEvent } from 'react';
const optionsSelector = createSelector(
(state: RootState) => state.options,
(options: OptionsState) => {
return {
initialImagePath: options.initialImagePath,
shouldUseInitImage: options.shouldUseInitImage,
shouldRunESRGAN: options.shouldRunESRGAN,
shouldRunGFPGAN: options.shouldRunGFPGAN,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isGFPGANAvailable: system.isGFPGANAvailable,
isESRGANAvailable: system.isESRGANAvailable,
openAccordions: system.openAccordions,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/**
* Main container for generation and processing parameters.
*/
const OptionsAccordion = () => {
const {
shouldRunESRGAN,
shouldRunGFPGAN,
shouldUseInitImage,
initialImagePath,
} = useAppSelector(optionsSelector);
const { isGFPGANAvailable, isESRGANAvailable, openAccordions } =
useAppSelector(systemSelector);
const dispatch = useAppDispatch();
/**
* Stores accordion state in redux so preferred UI setup is retained.
*/
const handleChangeAccordionState = (openAccordions: ExpandedIndex) =>
dispatch(setOpenAccordions(openAccordions));
const handleChangeShouldRunESRGAN = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldRunESRGAN(e.target.checked));
const handleChangeShouldRunGFPGAN = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldRunGFPGAN(e.target.checked));
const handleChangeShouldUseInitImage = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldUseInitImage(e.target.checked));
return (
<Accordion
defaultIndex={openAccordions}
allowMultiple
reduceMotion
onChange={handleChangeAccordionState}
>
<AccordionItem>
<h2>
<AccordionButton>
<Box flex="1" textAlign="left">
Seed & Variation
</Box>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<SeedVariationOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Box flex="1" textAlign="left">
Sampler
</Box>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<SamplerOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Flex
justifyContent={'space-between'}
alignItems={'center'}
width={'100%'}
mr={2}
>
<Text>Upscale (ESRGAN)</Text>
<Switch
isDisabled={!isESRGANAvailable}
isChecked={shouldRunESRGAN}
onChange={handleChangeShouldRunESRGAN}
/>
</Flex>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<ESRGANOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Flex
justifyContent={'space-between'}
alignItems={'center'}
width={'100%'}
mr={2}
>
<Text>Fix Faces (GFPGAN)</Text>
<Switch
isDisabled={!isGFPGANAvailable}
isChecked={shouldRunGFPGAN}
onChange={handleChangeShouldRunGFPGAN}
/>
</Flex>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<GFPGANOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Flex
justifyContent={'space-between'}
alignItems={'center'}
width={'100%'}
mr={2}
>
<Text>Image to Image</Text>
<Switch
isDisabled={!initialImagePath}
isChecked={shouldUseInitImage}
onChange={handleChangeShouldUseInitImage}
/>
</Flex>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<ImageToImageOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Box flex="1" textAlign="left">
Output
</Box>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<OutputOptions />
</AccordionPanel>
</AccordionItem>
</Accordion>
);
};
export default OptionsAccordion;

View File

@ -0,0 +1,76 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { setHeight, setWidth, setSeamless, OptionsState } from '../options/optionsSlice';
import { HEIGHTS, WIDTHS } from '../../app/constants';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { ChangeEvent } from 'react';
import SDSelect from '../../common/components/SDSelect';
import SDSwitch from '../../common/components/SDSwitch';
const optionsSelector = createSelector(
(state: RootState) => state.options,
(options: OptionsState) => {
return {
height: options.height,
width: options.width,
seamless: options.seamless,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/**
* Image output options. Includes width, height, seamless tiling.
*/
const OutputOptions = () => {
const dispatch = useAppDispatch();
const { height, width, seamless } = useAppSelector(optionsSelector);
const handleChangeWidth = (e: ChangeEvent<HTMLSelectElement>) =>
dispatch(setWidth(Number(e.target.value)));
const handleChangeHeight = (e: ChangeEvent<HTMLSelectElement>) =>
dispatch(setHeight(Number(e.target.value)));
const handleChangeSeamless = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setSeamless(e.target.checked));
return (
<Flex gap={2} direction={'column'}>
<Flex gap={2}>
<SDSelect
label="Width"
value={width}
flexGrow={1}
onChange={handleChangeWidth}
validValues={WIDTHS}
/>
<SDSelect
label="Height"
value={height}
flexGrow={1}
onChange={handleChangeHeight}
validValues={HEIGHTS}
/>
</Flex>
<SDSwitch
label="Seamless tiling"
fontSize={'md'}
isChecked={seamless}
onChange={handleChangeSeamless}
/>
</Flex>
);
};
export default OutputOptions;

View File

@ -0,0 +1,68 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { cancelProcessing, generateImage } from '../../app/socketio/actions';
import { RootState } from '../../app/store';
import SDButton from '../../common/components/SDButton';
import useCheckParameters from '../../common/hooks/useCheckParameters';
import { SystemState } from '../system/systemSlice';
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isProcessing: system.isProcessing,
isConnected: system.isConnected,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/**
* Buttons to start and cancel image generation.
*/
const ProcessButtons = () => {
const dispatch = useAppDispatch();
const { isProcessing, isConnected } = useAppSelector(systemSelector);
const isReady = useCheckParameters();
const handleClickGenerate = () => dispatch(generateImage());
const handleClickCancel = () => dispatch(cancelProcessing());
return (
<Flex
gap={2}
direction={'column'}
alignItems={'space-between'}
height={'100%'}
>
<SDButton
label="Generate"
type="submit"
colorScheme="green"
flexGrow={1}
isDisabled={!isReady}
fontSize={'md'}
size={'md'}
onClick={handleClickGenerate}
/>
<SDButton
label="Cancel"
colorScheme="red"
flexGrow={1}
fontSize={'md'}
size={'md'}
isDisabled={!isConnected || !isProcessing}
onClick={handleClickCancel}
/>
</Flex>
);
};
export default ProcessButtons;

View File

@ -0,0 +1,44 @@
import { Textarea } from '@chakra-ui/react';
import {
ChangeEvent,
KeyboardEvent,
} from 'react';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { generateImage } from '../../app/socketio/actions';
import { RootState } from '../../app/store';
import { setPrompt } from '../options/optionsSlice';
/**
* Prompt input text area.
*/
const PromptInput = () => {
const { prompt } = useAppSelector((state: RootState) => state.options);
const dispatch = useAppDispatch();
const handleChangePrompt = (e: ChangeEvent<HTMLTextAreaElement>) =>
dispatch(setPrompt(e.target.value));
const handleKeyDown = (e: KeyboardEvent<HTMLTextAreaElement>) => {
if (e.key === 'Enter' && e.shiftKey === false) {
e.preventDefault();
dispatch(generateImage())
}
};
return (
<Textarea
id="prompt"
name="prompt"
resize="none"
size={'lg'}
height={'100%'}
isInvalid={!prompt.length}
onChange={handleChangePrompt}
onKeyDown={handleKeyDown}
value={prompt}
placeholder="I'm dreaming of..."
/>
);
};
export default PromptInput;

View File

@ -0,0 +1,74 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { setCfgScale, setSampler, setSteps, OptionsState } from '../options/optionsSlice';
import { SAMPLERS } from '../../app/constants';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { ChangeEvent } from 'react';
import SDNumberInput from '../../common/components/SDNumberInput';
import SDSelect from '../../common/components/SDSelect';
const optionsSelector = createSelector(
(state: RootState) => state.options,
(options: OptionsState) => {
return {
steps: options.steps,
cfgScale: options.cfgScale,
sampler: options.sampler,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/**
* Sampler options. Includes steps, CFG scale, sampler.
*/
const SamplerOptions = () => {
const dispatch = useAppDispatch();
const { steps, cfgScale, sampler } = useAppSelector(optionsSelector);
const handleChangeSteps = (v: string | number) =>
dispatch(setSteps(Number(v)));
const handleChangeCfgScale = (v: string | number) =>
dispatch(setCfgScale(Number(v)));
const handleChangeSampler = (e: ChangeEvent<HTMLSelectElement>) =>
dispatch(setSampler(e.target.value));
return (
<Flex gap={2} direction={'column'}>
<SDNumberInput
label="Steps"
min={1}
step={1}
precision={0}
onChange={handleChangeSteps}
value={steps}
/>
<SDNumberInput
label="CFG scale"
step={0.5}
onChange={handleChangeCfgScale}
value={cfgScale}
/>
<SDSelect
label="Sampler"
value={sampler}
onChange={handleChangeSampler}
validValues={SAMPLERS}
/>
</Flex>
);
};
export default SamplerOptions;

View File

@ -0,0 +1,159 @@
import {
Flex,
Input,
HStack,
FormControl,
FormLabel,
Text,
Button,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { ChangeEvent } from 'react';
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from '../../app/constants';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import SDNumberInput from '../../common/components/SDNumberInput';
import SDSwitch from '../../common/components/SDSwitch';
import randomInt from '../../common/util/randomInt';
import { validateSeedWeights } from '../../common/util/seedWeightPairs';
import {
OptionsState,
setIterations,
setSeed,
setSeedWeights,
setShouldGenerateVariations,
setShouldRandomizeSeed,
setVariationAmount,
} from './optionsSlice';
const optionsSelector = createSelector(
(state: RootState) => state.options,
(options: OptionsState) => {
return {
variationAmount: options.variationAmount,
seedWeights: options.seedWeights,
shouldGenerateVariations: options.shouldGenerateVariations,
shouldRandomizeSeed: options.shouldRandomizeSeed,
seed: options.seed,
iterations: options.iterations,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/**
* Seed & variation options. Includes iteration, seed, seed randomization, variation options.
*/
const SeedVariationOptions = () => {
const {
shouldGenerateVariations,
variationAmount,
seedWeights,
shouldRandomizeSeed,
seed,
iterations,
} = useAppSelector(optionsSelector);
const dispatch = useAppDispatch();
const handleChangeIterations = (v: string | number) =>
dispatch(setIterations(Number(v)));
const handleChangeShouldRandomizeSeed = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldRandomizeSeed(e.target.checked));
const handleChangeSeed = (v: string | number) => dispatch(setSeed(Number(v)));
const handleClickRandomizeSeed = () =>
dispatch(setSeed(randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX)));
const handleChangeShouldGenerateVariations = (
e: ChangeEvent<HTMLInputElement>
) => dispatch(setShouldGenerateVariations(e.target.checked));
const handleChangevariationAmount = (v: string | number) =>
dispatch(setVariationAmount(Number(v)));
const handleChangeSeedWeights = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setSeedWeights(e.target.value));
return (
<Flex gap={2} direction={'column'}>
<SDNumberInput
label="Images to generate"
step={1}
min={1}
precision={0}
onChange={handleChangeIterations}
value={iterations}
/>
<SDSwitch
label="Randomize seed on generation"
isChecked={shouldRandomizeSeed}
onChange={handleChangeShouldRandomizeSeed}
/>
<Flex gap={2}>
<SDNumberInput
label="Seed"
step={1}
precision={0}
flexGrow={1}
min={NUMPY_RAND_MIN}
max={NUMPY_RAND_MAX}
isDisabled={shouldRandomizeSeed}
isInvalid={seed < 0 && shouldGenerateVariations}
onChange={handleChangeSeed}
value={seed}
/>
<Button
size={'sm'}
isDisabled={shouldRandomizeSeed}
onClick={handleClickRandomizeSeed}
>
<Text pl={2} pr={2}>
Shuffle
</Text>
</Button>
</Flex>
<SDSwitch
label="Generate variations"
isChecked={shouldGenerateVariations}
width={'auto'}
onChange={handleChangeShouldGenerateVariations}
/>
<SDNumberInput
label="Variation amount"
value={variationAmount}
step={0.01}
min={0}
max={1}
onChange={handleChangevariationAmount}
/>
<FormControl
isInvalid={
shouldGenerateVariations &&
!(validateSeedWeights(seedWeights) || seedWeights === '')
}
flexGrow={1}
>
<HStack>
<FormLabel marginInlineEnd={0} marginBottom={1}>
<Text whiteSpace="nowrap">Seed Weights</Text>
</FormLabel>
<Input
size={'sm'}
value={seedWeights}
onChange={handleChangeSeedWeights}
/>
</HStack>
</FormControl>
</Flex>
);
};
export default SeedVariationOptions;

View File

@ -1,24 +1,15 @@
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { SDMetadata } from '../gallery/gallerySlice'; import * as InvokeAI from '../../app/invokeai';
import randomInt from './util/randomInt'; import promptToString from '../../common/util/promptToString';
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from '../../app/constants'; import { seedWeightsToString } from '../../common/util/seedWeightPairs';
const calculateRealSteps = ( export type UpscalingLevel = 2 | 4;
steps: number,
strength: number,
hasInitImage: boolean
): number => {
return hasInitImage ? Math.floor(strength * steps) : steps;
};
export type UpscalingLevel = 0 | 2 | 3 | 4; export interface OptionsState {
export interface SDState {
prompt: string; prompt: string;
iterations: number; iterations: number;
steps: number; steps: number;
realSteps: number;
cfgScale: number; cfgScale: number;
height: number; height: number;
width: number; width: number;
@ -34,18 +25,17 @@ export interface SDState {
seamless: boolean; seamless: boolean;
shouldFitToWidthHeight: boolean; shouldFitToWidthHeight: boolean;
shouldGenerateVariations: boolean; shouldGenerateVariations: boolean;
variantAmount: number; variationAmount: number;
seedWeights: string; seedWeights: string;
shouldRunESRGAN: boolean; shouldRunESRGAN: boolean;
shouldRunGFPGAN: boolean; shouldRunGFPGAN: boolean;
shouldRandomizeSeed: boolean; shouldRandomizeSeed: boolean;
} }
const initialSDState: SDState = { const initialOptionsState: OptionsState = {
prompt: '', prompt: '',
iterations: 1, iterations: 1,
steps: 50, steps: 50,
realSteps: 50,
cfgScale: 7.5, cfgScale: 7.5,
height: 512, height: 512,
width: 512, width: 512,
@ -58,7 +48,7 @@ const initialSDState: SDState = {
maskPath: '', maskPath: '',
shouldFitToWidthHeight: true, shouldFitToWidthHeight: true,
shouldGenerateVariations: false, shouldGenerateVariations: false,
variantAmount: 0.1, variationAmount: 0.1,
seedWeights: '', seedWeights: '',
shouldRunESRGAN: false, shouldRunESRGAN: false,
upscalingLevel: 4, upscalingLevel: 4,
@ -68,27 +58,25 @@ const initialSDState: SDState = {
shouldRandomizeSeed: true, shouldRandomizeSeed: true,
}; };
const initialState: SDState = initialSDState; const initialState: OptionsState = initialOptionsState;
export const sdSlice = createSlice({ export const optionsSlice = createSlice({
name: 'sd', name: 'options',
initialState, initialState,
reducers: { reducers: {
setPrompt: (state, action: PayloadAction<string>) => { setPrompt: (state, action: PayloadAction<string | InvokeAI.Prompt>) => {
state.prompt = action.payload; const newPrompt = action.payload;
if (typeof newPrompt === 'string') {
state.prompt = newPrompt;
} else {
state.prompt = promptToString(newPrompt);
}
}, },
setIterations: (state, action: PayloadAction<number>) => { setIterations: (state, action: PayloadAction<number>) => {
state.iterations = action.payload; state.iterations = action.payload;
}, },
setSteps: (state, action: PayloadAction<number>) => { setSteps: (state, action: PayloadAction<number>) => {
const { img2imgStrength, initialImagePath } = state; state.steps = action.payload;
const steps = action.payload;
state.steps = steps;
state.realSteps = calculateRealSteps(
steps,
img2imgStrength,
Boolean(initialImagePath)
);
}, },
setCfgScale: (state, action: PayloadAction<number>) => { setCfgScale: (state, action: PayloadAction<number>) => {
state.cfgScale = action.payload; state.cfgScale = action.payload;
@ -107,14 +95,7 @@ export const sdSlice = createSlice({
state.shouldRandomizeSeed = false; state.shouldRandomizeSeed = false;
}, },
setImg2imgStrength: (state, action: PayloadAction<number>) => { setImg2imgStrength: (state, action: PayloadAction<number>) => {
const img2imgStrength = action.payload; state.img2imgStrength = action.payload;
const { steps, initialImagePath } = state;
state.img2imgStrength = img2imgStrength;
state.realSteps = calculateRealSteps(
steps,
img2imgStrength,
Boolean(initialImagePath)
);
}, },
setGfpganStrength: (state, action: PayloadAction<number>) => { setGfpganStrength: (state, action: PayloadAction<number>) => {
state.gfpganStrength = action.payload; state.gfpganStrength = action.payload;
@ -129,15 +110,9 @@ export const sdSlice = createSlice({
state.shouldUseInitImage = action.payload; state.shouldUseInitImage = action.payload;
}, },
setInitialImagePath: (state, action: PayloadAction<string>) => { setInitialImagePath: (state, action: PayloadAction<string>) => {
const initialImagePath = action.payload; const newInitialImagePath = action.payload;
const { steps, img2imgStrength } = state; state.shouldUseInitImage = newInitialImagePath ? true : false;
state.shouldUseInitImage = initialImagePath ? true : false; state.initialImagePath = newInitialImagePath;
state.initialImagePath = initialImagePath;
state.realSteps = calculateRealSteps(
steps,
img2imgStrength,
Boolean(initialImagePath)
);
}, },
setMaskPath: (state, action: PayloadAction<string>) => { setMaskPath: (state, action: PayloadAction<string>) => {
state.maskPath = action.payload; state.maskPath = action.payload;
@ -151,13 +126,11 @@ export const sdSlice = createSlice({
resetSeed: (state) => { resetSeed: (state) => {
state.seed = -1; state.seed = -1;
}, },
randomizeSeed: (state) => {
state.seed = randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX);
},
setParameter: ( setParameter: (
state, state,
action: PayloadAction<{ key: string; value: string | number | boolean }> action: PayloadAction<{ key: string; value: string | number | boolean }>
) => { ) => {
// TODO: This probably needs to be refactored.
const { key, value } = action.payload; const { key, value } = action.payload;
const temp = { ...state, [key]: value }; const temp = { ...state, [key]: value };
if (key === 'seed') { if (key === 'seed') {
@ -171,70 +144,95 @@ export const sdSlice = createSlice({
setShouldGenerateVariations: (state, action: PayloadAction<boolean>) => { setShouldGenerateVariations: (state, action: PayloadAction<boolean>) => {
state.shouldGenerateVariations = action.payload; state.shouldGenerateVariations = action.payload;
}, },
setVariantAmount: (state, action: PayloadAction<number>) => { setVariationAmount: (state, action: PayloadAction<number>) => {
state.variantAmount = action.payload; state.variationAmount = action.payload;
}, },
setSeedWeights: (state, action: PayloadAction<string>) => { setSeedWeights: (state, action: PayloadAction<string>) => {
state.seedWeights = action.payload; state.seedWeights = action.payload;
}, },
setAllParameters: (state, action: PayloadAction<SDMetadata>) => { setAllParameters: (state, action: PayloadAction<InvokeAI.Metadata>) => {
const { const {
prompt, type,
steps, postprocessing,
cfgScale,
height,
width,
sampler, sampler,
prompt,
seed, seed,
img2imgStrength, variations,
gfpganStrength, steps,
upscalingLevel, cfg_scale,
upscalingStrength,
initialImagePath,
maskPath,
seamless, seamless,
shouldFitToWidthHeight, width,
} = action.payload; height,
strength,
fit,
init_image_path,
mask_image_path,
} = action.payload.image;
// ?? = falsy values ('', 0, etc) are used if (type === 'img2img') {
// || = falsy values not used if (init_image_path) state.initialImagePath = init_image_path;
state.prompt = prompt ?? state.prompt; if (mask_image_path) state.maskPath = mask_image_path;
state.steps = steps || state.steps; if (strength) state.img2imgStrength = strength;
state.cfgScale = cfgScale || state.cfgScale; if (typeof fit === 'boolean') state.shouldFitToWidthHeight = fit;
state.width = width || state.width; state.shouldUseInitImage = true;
state.height = height || state.height; } else {
state.sampler = sampler || state.sampler; state.shouldUseInitImage = false;
state.seed = seed ?? state.seed; }
state.seamless = seamless ?? state.seamless;
state.shouldFitToWidthHeight = if (variations && variations.length > 0) {
shouldFitToWidthHeight ?? state.shouldFitToWidthHeight; state.seedWeights = seedWeightsToString(variations);
state.img2imgStrength = img2imgStrength ?? state.img2imgStrength; state.shouldGenerateVariations = true;
state.gfpganStrength = gfpganStrength ?? state.gfpganStrength; } else {
state.upscalingLevel = upscalingLevel ?? state.upscalingLevel; state.shouldGenerateVariations = false;
state.upscalingStrength = upscalingStrength ?? state.upscalingStrength; }
state.initialImagePath = initialImagePath ?? state.initialImagePath;
state.maskPath = maskPath ?? state.maskPath;
// If the image whose parameters we are using has a seed, disable randomizing the seed
if (seed) { if (seed) {
state.seed = seed;
state.shouldRandomizeSeed = false; state.shouldRandomizeSeed = false;
} }
// if we have a gfpgan strength, enable it let postprocessingNotDone = ['gfpgan', 'esrgan'];
state.shouldRunGFPGAN = gfpganStrength ? true : false; if (postprocessing && postprocessing.length > 0) {
postprocessing.forEach(
(postprocess: InvokeAI.PostProcessedImageMetadata) => {
if (postprocess.type === 'gfpgan') {
const { strength } = postprocess;
if (strength) state.gfpganStrength = strength;
state.shouldRunGFPGAN = true;
postprocessingNotDone = postprocessingNotDone.filter(
(p) => p !== 'gfpgan'
);
}
if (postprocess.type === 'esrgan') {
const { scale, strength } = postprocess;
if (scale) state.upscalingLevel = scale;
if (strength) state.upscalingStrength = strength;
state.shouldRunESRGAN = true;
postprocessingNotDone = postprocessingNotDone.filter(
(p) => p !== 'esrgan'
);
}
}
);
}
// if we have a esrgan strength, enable it postprocessingNotDone.forEach((p) => {
state.shouldRunESRGAN = upscalingLevel ? true : false; if (p === 'esrgan') state.shouldRunESRGAN = false;
if (p === 'gfpgan') state.shouldRunGFPGAN = false;
});
// if we want to recreate an image exactly, we disable variations if (prompt) state.prompt = promptToString(prompt);
state.shouldGenerateVariations = false; if (sampler) state.sampler = sampler;
if (steps) state.steps = steps;
state.shouldUseInitImage = initialImagePath ? true : false; if (cfg_scale) state.cfgScale = cfg_scale;
if (typeof seamless === 'boolean') state.seamless = seamless;
if (width) state.width = width;
if (height) state.height = height;
}, },
resetSDState: (state) => { resetOptionsState: (state) => {
return { return {
...state, ...state,
...initialSDState, ...initialOptionsState,
}; };
}, },
setShouldRunGFPGAN: (state, action: PayloadAction<boolean>) => { setShouldRunGFPGAN: (state, action: PayloadAction<boolean>) => {
@ -267,17 +265,16 @@ export const {
setInitialImagePath, setInitialImagePath,
setMaskPath, setMaskPath,
resetSeed, resetSeed,
randomizeSeed, resetOptionsState,
resetSDState,
setShouldFitToWidthHeight, setShouldFitToWidthHeight,
setParameter, setParameter,
setShouldGenerateVariations, setShouldGenerateVariations,
setSeedWeights, setSeedWeights,
setVariantAmount, setVariationAmount,
setAllParameters, setAllParameters,
setShouldRunGFPGAN, setShouldRunGFPGAN,
setShouldRunESRGAN, setShouldRunESRGAN,
setShouldRandomizeSeed, setShouldRandomizeSeed,
} = sdSlice.actions; } = optionsSlice.actions;
export default sdSlice.reducer; export default optionsSlice.reducer;

View File

@ -1,84 +0,0 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import {
setUpscalingLevel,
setUpscalingStrength,
UpscalingLevel,
SDState,
} from '../sd/sdSlice';
import SDNumberInput from '../../components/SDNumberInput';
import SDSelect from '../../components/SDSelect';
import { UPSCALING_LEVELS } from '../../app/constants';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { SystemState } from '../system/systemSlice';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
upscalingLevel: sd.upscalingLevel,
upscalingStrength: sd.upscalingStrength,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isESRGANAvailable: system.isESRGANAvailable,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const ESRGANOptions = () => {
const { upscalingLevel, upscalingStrength } = useAppSelector(sdSelector);
const { isESRGANAvailable } = useAppSelector(systemSelector);
const dispatch = useAppDispatch();
return (
<Flex direction={'column'} gap={2}>
<SDSelect
isDisabled={!isESRGANAvailable}
label='Scale'
value={upscalingLevel}
onChange={(e) =>
dispatch(
setUpscalingLevel(
Number(e.target.value) as UpscalingLevel
)
)
}
validValues={UPSCALING_LEVELS}
/>
<SDNumberInput
isDisabled={!isESRGANAvailable}
label='Strength'
step={0.05}
min={0}
max={1}
onChange={(v) => dispatch(setUpscalingStrength(Number(v)))}
value={upscalingStrength}
/>
</Flex>
);
};
export default ESRGANOptions;

View File

@ -1,63 +0,0 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { SDState, setGfpganStrength } from '../sd/sdSlice';
import SDNumberInput from '../../components/SDNumberInput';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { SystemState } from '../system/systemSlice';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
gfpganStrength: sd.gfpganStrength,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isGFPGANAvailable: system.isGFPGANAvailable,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const GFPGANOptions = () => {
const { gfpganStrength } = useAppSelector(sdSelector);
const { isGFPGANAvailable } = useAppSelector(systemSelector);
const dispatch = useAppDispatch();
return (
<Flex direction={'column'} gap={2}>
<SDNumberInput
isDisabled={!isGFPGANAvailable}
label='Strength'
step={0.05}
min={0}
max={1}
onChange={(v) => dispatch(setGfpganStrength(Number(v)))}
value={gfpganStrength}
/>
</Flex>
);
};
export default GFPGANOptions;

View File

@ -1,54 +0,0 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import SDNumberInput from '../../components/SDNumberInput';
import SDSwitch from '../../components/SDSwitch';
import InitImage from './InitImage';
import {
SDState,
setImg2imgStrength,
setShouldFitToWidthHeight,
} from './sdSlice';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
initialImagePath: sd.initialImagePath,
img2imgStrength: sd.img2imgStrength,
shouldFitToWidthHeight: sd.shouldFitToWidthHeight,
};
}
);
const ImageToImageOptions = () => {
const { initialImagePath, img2imgStrength, shouldFitToWidthHeight } =
useAppSelector(sdSelector);
const dispatch = useAppDispatch();
return (
<Flex direction={'column'} gap={2}>
<SDNumberInput
isDisabled={!initialImagePath}
label='Strength'
step={0.01}
min={0}
max={1}
onChange={(v) => dispatch(setImg2imgStrength(Number(v)))}
value={img2imgStrength}
/>
<SDSwitch
isDisabled={!initialImagePath}
label='Fit initial image to output size'
isChecked={shouldFitToWidthHeight}
onChange={(e) =>
dispatch(setShouldFitToWidthHeight(e.target.checked))
}
/>
<InitImage />
</Flex>
);
};
export default ImageToImageOptions;

View File

@ -1,155 +0,0 @@
import {
Button,
Flex,
IconButton,
Image,
useToast,
} from '@chakra-ui/react';
import { SyntheticEvent, useCallback, useState } from 'react';
import { FileRejection, useDropzone } from 'react-dropzone';
import { FaTrash } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import {
SDState,
setInitialImagePath,
setMaskPath,
} from '../../features/sd/sdSlice';
import MaskUploader from './MaskUploader';
import './InitImage.css';
import { uploadInitialImage } from '../../app/socketio';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
initialImagePath: sd.initialImagePath,
maskPath: sd.maskPath,
};
},
{ memoizeOptions: { resultEqualityCheck: isEqual } }
);
const InitImage = () => {
const toast = useToast();
const dispatch = useAppDispatch();
const { initialImagePath, maskPath } = useAppSelector(sdSelector);
const onDrop = useCallback(
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
fileRejections.forEach((rejection: FileRejection) => {
const msg = rejection.errors.reduce(
(acc: string, cur: { message: string }) => acc + '\n' + cur.message,
''
);
toast({
title: 'Upload failed',
description: msg,
status: 'error',
isClosable: true,
});
});
acceptedFiles.forEach((file: File) => {
dispatch(uploadInitialImage(file));
});
},
[dispatch, toast]
);
const { getRootProps, getInputProps, open } = useDropzone({
onDrop,
accept: {
'image/jpeg': ['.jpg', '.jpeg', '.png'],
},
});
const [shouldShowMask, setShouldShowMask] = useState<boolean>(false);
const handleClickUploadIcon = (e: SyntheticEvent) => {
e.stopPropagation();
open();
};
const handleClickResetInitialImageAndMask = (e: SyntheticEvent) => {
e.stopPropagation();
dispatch(setInitialImagePath(''));
dispatch(setMaskPath(''));
};
const handleMouseOverInitialImageUploadButton = () =>
setShouldShowMask(false);
const handleMouseOutInitialImageUploadButton = () => setShouldShowMask(true);
const handleMouseOverMaskUploadButton = () => setShouldShowMask(true);
const handleMouseOutMaskUploadButton = () => setShouldShowMask(true);
return (
<Flex
{...getRootProps({
onClick: initialImagePath ? (e) => e.stopPropagation() : undefined,
})}
direction={'column'}
alignItems={'center'}
gap={2}
>
<input {...getInputProps({ multiple: false })} />
<Flex gap={2} justifyContent={'space-between'} width={'100%'}>
<Button
size={'sm'}
fontSize={'md'}
fontWeight={'normal'}
onClick={handleClickUploadIcon}
onMouseOver={handleMouseOverInitialImageUploadButton}
onMouseOut={handleMouseOutInitialImageUploadButton}
>
Upload Image
</Button>
<MaskUploader>
<Button
size={'sm'}
fontSize={'md'}
fontWeight={'normal'}
onClick={handleClickUploadIcon}
onMouseOver={handleMouseOverMaskUploadButton}
onMouseOut={handleMouseOutMaskUploadButton}
>
Upload Mask
</Button>
</MaskUploader>
<IconButton
size={'sm'}
aria-label={'Reset initial image and mask'}
onClick={handleClickResetInitialImageAndMask}
icon={<FaTrash />}
/>
</Flex>
{initialImagePath && (
<Flex position={'relative'} width={'100%'}>
<Image
fit={'contain'}
src={initialImagePath}
rounded={'md'}
className={'checkerboard'}
/>
{shouldShowMask && maskPath && (
<Image
position={'absolute'}
top={0}
left={0}
fit={'contain'}
src={maskPath}
rounded={'md'}
zIndex={1}
className={'checkerboard'}
/>
)}
</Flex>
)}
</Flex>
);
};
export default InitImage;

View File

@ -1,61 +0,0 @@
import { useToast } from '@chakra-ui/react';
import { cloneElement, ReactElement, SyntheticEvent, useCallback } from 'react';
import { FileRejection, useDropzone } from 'react-dropzone';
import { useAppDispatch } from '../../app/hooks';
import { uploadMaskImage } from '../../app/socketio';
type Props = {
children: ReactElement;
};
const MaskUploader = ({ children }: Props) => {
const dispatch = useAppDispatch();
const toast = useToast();
const onDrop = useCallback(
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
fileRejections.forEach((rejection: FileRejection) => {
const msg = rejection.errors.reduce(
(acc: string, cur: { message: string }) =>
acc + '\n' + cur.message,
''
);
toast({
title: 'Upload failed',
description: msg,
status: 'error',
isClosable: true,
});
});
acceptedFiles.forEach((file: File) => {
dispatch(uploadMaskImage(file));
});
},
[dispatch, toast]
);
const { getRootProps, getInputProps, open } = useDropzone({
onDrop,
accept: {
'image/jpeg': ['.jpg', '.jpeg', '.png'],
},
});
const handleClickUploadIcon = (e: SyntheticEvent) => {
e.stopPropagation();
open();
};
return (
<div {...getRootProps()}>
<input {...getInputProps({ multiple: false })} />
{cloneElement(children, {
onClick: handleClickUploadIcon,
})}
</div>
);
};
export default MaskUploader;

View File

@ -1,211 +0,0 @@
import {
Flex,
Box,
Text,
Accordion,
AccordionItem,
AccordionButton,
AccordionIcon,
AccordionPanel,
Switch,
} from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import {
setShouldRunGFPGAN,
setShouldRunESRGAN,
SDState,
setShouldUseInitImage,
} from '../sd/sdSlice';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { setOpenAccordions, SystemState } from '../system/systemSlice';
import SeedVariationOptions from './SeedVariationOptions';
import SamplerOptions from './SamplerOptions';
import ESRGANOptions from './ESRGANOptions';
import GFPGANOptions from './GFPGANOptions';
import OutputOptions from './OutputOptions';
import ImageToImageOptions from './ImageToImageOptions';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
initialImagePath: sd.initialImagePath,
shouldUseInitImage: sd.shouldUseInitImage,
shouldRunESRGAN: sd.shouldRunESRGAN,
shouldRunGFPGAN: sd.shouldRunGFPGAN,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isGFPGANAvailable: system.isGFPGANAvailable,
isESRGANAvailable: system.isESRGANAvailable,
openAccordions: system.openAccordions,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const OptionsAccordion = () => {
const {
shouldRunESRGAN,
shouldRunGFPGAN,
shouldUseInitImage,
initialImagePath,
} = useAppSelector(sdSelector);
const { isGFPGANAvailable, isESRGANAvailable, openAccordions } =
useAppSelector(systemSelector);
const dispatch = useAppDispatch();
return (
<Accordion
defaultIndex={openAccordions}
allowMultiple
reduceMotion
onChange={(openAccordions) =>
dispatch(setOpenAccordions(openAccordions))
}
>
<AccordionItem>
<h2>
<AccordionButton>
<Box flex='1' textAlign='left'>
Seed & Variation
</Box>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<SeedVariationOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Box flex='1' textAlign='left'>
Sampler
</Box>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<SamplerOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Flex
justifyContent={'space-between'}
alignItems={'center'}
width={'100%'}
mr={2}
>
<Text>Upscale (ESRGAN)</Text>
<Switch
isDisabled={!isESRGANAvailable}
isChecked={shouldRunESRGAN}
onChange={(e) =>
dispatch(
setShouldRunESRGAN(e.target.checked)
)
}
/>
</Flex>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<ESRGANOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Flex
justifyContent={'space-between'}
alignItems={'center'}
width={'100%'}
mr={2}
>
<Text>Fix Faces (GFPGAN)</Text>
<Switch
isDisabled={!isGFPGANAvailable}
isChecked={shouldRunGFPGAN}
onChange={(e) =>
dispatch(
setShouldRunGFPGAN(e.target.checked)
)
}
/>
</Flex>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<GFPGANOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Flex
justifyContent={'space-between'}
alignItems={'center'}
width={'100%'}
mr={2}
>
<Text>Image to Image</Text>
<Switch
isDisabled={!initialImagePath}
isChecked={shouldUseInitImage}
onChange={(e) =>
dispatch(
setShouldUseInitImage(e.target.checked)
)
}
/>
</Flex>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<ImageToImageOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Box flex='1' textAlign='left'>
Output
</Box>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<OutputOptions />
</AccordionPanel>
</AccordionItem>
</Accordion>
);
};
export default OptionsAccordion;

View File

@ -1,66 +0,0 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { setHeight, setWidth, setSeamless, SDState } from '../sd/sdSlice';
import SDSelect from '../../components/SDSelect';
import { HEIGHTS, WIDTHS } from '../../app/constants';
import SDSwitch from '../../components/SDSwitch';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
height: sd.height,
width: sd.width,
seamless: sd.seamless,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const OutputOptions = () => {
const { height, width, seamless } = useAppSelector(sdSelector);
const dispatch = useAppDispatch();
return (
<Flex gap={2} direction={'column'}>
<Flex gap={2}>
<SDSelect
label='Width'
value={width}
flexGrow={1}
onChange={(e) => dispatch(setWidth(Number(e.target.value)))}
validValues={WIDTHS}
/>
<SDSelect
label='Height'
value={height}
flexGrow={1}
onChange={(e) =>
dispatch(setHeight(Number(e.target.value)))
}
validValues={HEIGHTS}
/>
</Flex>
<SDSwitch
label='Seamless tiling'
fontSize={'md'}
isChecked={seamless}
onChange={(e) => dispatch(setSeamless(e.target.checked))}
/>
</Flex>
);
};
export default OutputOptions;

View File

@ -1,58 +0,0 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { cancelProcessing, generateImage } from '../../app/socketio';
import { RootState } from '../../app/store';
import SDButton from '../../components/SDButton';
import { SystemState } from '../system/systemSlice';
import useCheckParameters from '../system/useCheckParameters';
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isProcessing: system.isProcessing,
isConnected: system.isConnected,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const ProcessButtons = () => {
const { isProcessing, isConnected } = useAppSelector(systemSelector);
const dispatch = useAppDispatch();
const isReady = useCheckParameters();
return (
<Flex gap={2} direction={'column'} alignItems={'space-between'} height={'100%'}>
<SDButton
label='Generate'
type='submit'
colorScheme='green'
flexGrow={1}
isDisabled={!isReady}
fontSize={'md'}
size={'md'}
onClick={() => dispatch(generateImage())}
/>
<SDButton
label='Cancel'
colorScheme='red'
flexGrow={1}
fontSize={'md'}
size={'md'}
isDisabled={!isConnected || !isProcessing}
onClick={() => dispatch(cancelProcessing())}
/>
</Flex>
);
};
export default ProcessButtons;

View File

@ -1,25 +0,0 @@
import { Textarea } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import { setPrompt } from '../sd/sdSlice';
const PromptInput = () => {
const { prompt } = useAppSelector((state: RootState) => state.sd);
const dispatch = useAppDispatch();
return (
<Textarea
id='prompt'
name='prompt'
resize='none'
size={'lg'}
height={'100%'}
isInvalid={!prompt.length}
onChange={(e) => dispatch(setPrompt(e.target.value))}
value={prompt}
placeholder="I'm dreaming of..."
/>
);
};
export default PromptInput;

View File

@ -1,51 +0,0 @@
import {
Slider,
SliderTrack,
SliderFilledTrack,
SliderThumb,
FormControl,
FormLabel,
Text,
Flex,
SliderProps,
} from '@chakra-ui/react';
interface Props extends SliderProps {
label: string;
value: number;
fontSize?: number | string;
}
const SDSlider = ({
label,
value,
fontSize = 'sm',
onChange,
...rest
}: Props) => {
return (
<FormControl>
<Flex gap={2}>
<FormLabel marginInlineEnd={0} marginBottom={1}>
<Text fontSize={fontSize} whiteSpace='nowrap'>
{label}
</Text>
</FormLabel>
<Slider
aria-label={label}
focusThumbOnChange={true}
value={value}
onChange={onChange}
{...rest}
>
<SliderTrack>
<SliderFilledTrack />
</SliderTrack>
<SliderThumb />
</Slider>
</Flex>
</FormControl>
);
};
export default SDSlider;

View File

@ -1,62 +0,0 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { setCfgScale, setSampler, setSteps, SDState } from '../sd/sdSlice';
import SDNumberInput from '../../components/SDNumberInput';
import SDSelect from '../../components/SDSelect';
import { SAMPLERS } from '../../app/constants';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
steps: sd.steps,
cfgScale: sd.cfgScale,
sampler: sd.sampler,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const SamplerOptions = () => {
const { steps, cfgScale, sampler } = useAppSelector(sdSelector);
const dispatch = useAppDispatch();
return (
<Flex gap={2} direction={'column'}>
<SDNumberInput
label='Steps'
min={1}
step={1}
precision={0}
onChange={(v) => dispatch(setSteps(Number(v)))}
value={steps}
/>
<SDNumberInput
label='CFG scale'
step={0.5}
onChange={(v) => dispatch(setCfgScale(Number(v)))}
value={cfgScale}
/>
<SDSelect
label='Sampler'
value={sampler}
onChange={(e) => dispatch(setSampler(e.target.value))}
validValues={SAMPLERS}
/>
</Flex>
);
};
export default SamplerOptions;

View File

@ -1,144 +0,0 @@
import {
Flex,
Input,
HStack,
FormControl,
FormLabel,
Text,
Button,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from '../../app/constants';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import SDNumberInput from '../../components/SDNumberInput';
import SDSwitch from '../../components/SDSwitch';
import {
randomizeSeed,
SDState,
setIterations,
setSeed,
setSeedWeights,
setShouldGenerateVariations,
setShouldRandomizeSeed,
setVariantAmount,
} from './sdSlice';
import { validateSeedWeights } from './util/seedWeightPairs';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
variantAmount: sd.variantAmount,
seedWeights: sd.seedWeights,
shouldGenerateVariations: sd.shouldGenerateVariations,
shouldRandomizeSeed: sd.shouldRandomizeSeed,
seed: sd.seed,
iterations: sd.iterations,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const SeedVariationOptions = () => {
const {
shouldGenerateVariations,
variantAmount,
seedWeights,
shouldRandomizeSeed,
seed,
iterations,
} = useAppSelector(sdSelector);
const dispatch = useAppDispatch();
return (
<Flex gap={2} direction={'column'}>
<SDNumberInput
label='Images to generate'
step={1}
min={1}
precision={0}
onChange={(v) => dispatch(setIterations(Number(v)))}
value={iterations}
/>
<SDSwitch
label='Randomize seed on generation'
isChecked={shouldRandomizeSeed}
onChange={(e) =>
dispatch(setShouldRandomizeSeed(e.target.checked))
}
/>
<Flex gap={2}>
<SDNumberInput
label='Seed'
step={1}
precision={0}
flexGrow={1}
min={NUMPY_RAND_MIN}
max={NUMPY_RAND_MAX}
isDisabled={shouldRandomizeSeed}
isInvalid={seed < 0 && shouldGenerateVariations}
onChange={(v) => dispatch(setSeed(Number(v)))}
value={seed}
/>
<Button
size={'sm'}
isDisabled={shouldRandomizeSeed}
onClick={() => dispatch(randomizeSeed())}
>
<Text pl={2} pr={2}>
Shuffle
</Text>
</Button>
</Flex>
<SDSwitch
label='Generate variations'
isChecked={shouldGenerateVariations}
width={'auto'}
onChange={(e) =>
dispatch(setShouldGenerateVariations(e.target.checked))
}
/>
<SDNumberInput
label='Variation amount'
value={variantAmount}
step={0.01}
min={0}
max={1}
isDisabled={!shouldGenerateVariations}
onChange={(v) => dispatch(setVariantAmount(Number(v)))}
/>
<FormControl
isInvalid={
shouldGenerateVariations &&
!(validateSeedWeights(seedWeights) || seedWeights === '')
}
flexGrow={1}
isDisabled={!shouldGenerateVariations}
>
<HStack>
<FormLabel marginInlineEnd={0} marginBottom={1}>
<Text whiteSpace='nowrap'>
Seed Weights
</Text>
</FormLabel>
<Input
size={'sm'}
value={seedWeights}
onChange={(e) =>
dispatch(setSeedWeights(e.target.value))
}
/>
</HStack>
</FormControl>
</Flex>
);
};
export default SeedVariationOptions;

View File

@ -1,92 +0,0 @@
import {
Flex,
FormControl,
FormLabel,
HStack,
Input,
Text,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import SDNumberInput from '../../components/SDNumberInput';
import SDSwitch from '../../components/SDSwitch';
import {
SDState,
setSeedWeights,
setShouldGenerateVariations,
setVariantAmount,
} from './sdSlice';
import { validateSeedWeights } from './util/seedWeightPairs';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
variantAmount: sd.variantAmount,
seedWeights: sd.seedWeights,
shouldGenerateVariations: sd.shouldGenerateVariations,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const Variant = () => {
const { shouldGenerateVariations, variantAmount, seedWeights } =
useAppSelector(sdSelector);
const dispatch = useAppDispatch();
return (
<Flex gap={2} alignItems={'center'} pl={1}>
<SDSwitch
label='Generate variations'
isChecked={shouldGenerateVariations}
width={'auto'}
onChange={(e) =>
dispatch(setShouldGenerateVariations(e.target.checked))
}
/>
<SDNumberInput
label='Amount'
value={variantAmount}
step={0.01}
min={0}
max={1}
width={240}
isDisabled={!shouldGenerateVariations}
onChange={(v) => dispatch(setVariantAmount(Number(v)))}
/>
<FormControl
isInvalid={
shouldGenerateVariations &&
!(validateSeedWeights(seedWeights) || seedWeights === '')
}
flexGrow={1}
isDisabled={!shouldGenerateVariations}
>
<HStack>
<FormLabel marginInlineEnd={0} marginBottom={1}>
<Text fontSize={'sm'} whiteSpace='nowrap'>
Seed Weights
</Text>
</FormLabel>
<Input
size={'sm'}
value={seedWeights}
onChange={(e) =>
dispatch(setSeedWeights(e.target.value))
}
/>
</HStack>
</FormControl>
</Flex>
);
};
export default Variant;

View File

@ -1,56 +0,0 @@
export interface SeedWeightPair {
seed: number;
weight: number;
}
export type SeedWeights = Array<Array<number>>;
export const stringToSeedWeights = (string: string): SeedWeights | boolean => {
const stringPairs = string.split(',');
const arrPairs = stringPairs.map((p) => p.split(':'));
const pairs = arrPairs.map((p) => {
return [parseInt(p[0]), parseFloat(p[1])];
});
if (!validateSeedWeights(pairs)) {
return false;
}
return pairs;
};
export const validateSeedWeights = (
seedWeights: SeedWeights | string
): boolean => {
return typeof seedWeights === 'string'
? Boolean(stringToSeedWeights(seedWeights))
: Boolean(
seedWeights.length &&
!seedWeights.some((pair) => {
const [seed, weight] = pair;
const isSeedValid = !isNaN(parseInt(seed.toString(), 10));
const isWeightValid =
!isNaN(parseInt(weight.toString(), 10)) &&
weight >= 0 &&
weight <= 1;
return !(isSeedValid && isWeightValid);
})
);
};
export const seedWeightsToString = (
seedWeights: SeedWeights
): string | boolean => {
if (!validateSeedWeights(seedWeights)) {
return false;
}
return seedWeights.reduce((acc, pair, i, arr) => {
const [seed, weight] = pair;
acc += `${seed}:${weight}`;
if (i !== arr.length - 1) {
acc += ',';
}
return acc;
}, '');
};

View File

@ -1,11 +1,11 @@
import { import {
IconButton, IconButton,
useColorModeValue, useColorModeValue,
Flex, Flex,
Text, Text,
Tooltip, Tooltip,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from '../../app/hooks'; import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { setShouldShowLogViewer, SystemState } from './systemSlice'; import { setShouldShowLogViewer, SystemState } from './systemSlice';
import { useLayoutEffect, useRef, useState } from 'react'; import { useLayoutEffect, useRef, useState } from 'react';
@ -14,112 +14,138 @@ import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash'; import { isEqual } from 'lodash';
const logSelector = createSelector( const logSelector = createSelector(
(state: RootState) => state.system, (state: RootState) => state.system,
(system: SystemState) => system.log, (system: SystemState) => system.log,
{ {
memoizeOptions: { memoizeOptions: {
resultEqualityCheck: (a, b) => a.length === b.length, // We don't need a deep equality check for this selector.
}, resultEqualityCheck: (a, b) => a.length === b.length,
} },
}
); );
const systemSelector = createSelector( const systemSelector = createSelector(
(state: RootState) => state.system, (state: RootState) => state.system,
(system: SystemState) => { (system: SystemState) => {
return { shouldShowLogViewer: system.shouldShowLogViewer }; return { shouldShowLogViewer: system.shouldShowLogViewer };
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
}, },
{ }
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
); );
/**
* Basic log viewer, floats on bottom of page.
*/
const LogViewer = () => { const LogViewer = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const bg = useColorModeValue('gray.50', 'gray.900'); const log = useAppSelector(logSelector);
const borderColor = useColorModeValue('gray.500', 'gray.500'); const { shouldShowLogViewer } = useAppSelector(systemSelector);
const [shouldAutoscroll, setShouldAutoscroll] = useState<boolean>(true);
const log = useAppSelector(logSelector); // Set colors based on dark/light mode
const { shouldShowLogViewer } = useAppSelector(systemSelector); const bg = useColorModeValue('gray.50', 'gray.900');
const borderColor = useColorModeValue('gray.500', 'gray.500');
const logTextColors = useColorModeValue(
{
info: undefined,
warning: 'yellow.500',
error: 'red.500',
},
{
info: undefined,
warning: 'yellow.300',
error: 'red.300',
}
);
const viewerRef = useRef<HTMLDivElement>(null); // Rudimentary autoscroll
const [shouldAutoscroll, setShouldAutoscroll] = useState<boolean>(true);
const viewerRef = useRef<HTMLDivElement>(null);
useLayoutEffect(() => { /**
if (viewerRef.current !== null && shouldAutoscroll) { * If autoscroll is on, scroll to the bottom when:
viewerRef.current.scrollTop = viewerRef.current.scrollHeight; * - log updates
} * - viewer is toggled
}); *
* Also scroll to the bottom whenever autoscroll is turned on.
*/
useLayoutEffect(() => {
if (viewerRef.current !== null && shouldAutoscroll) {
viewerRef.current.scrollTop = viewerRef.current.scrollHeight;
}
}, [shouldAutoscroll, log, shouldShowLogViewer]);
return ( const handleClickLogViewerToggle = () => {
<> dispatch(setShouldShowLogViewer(!shouldShowLogViewer));
{shouldShowLogViewer && ( };
<Flex
position={'fixed'} return (
left={0} <>
bottom={0} {shouldShowLogViewer && (
height='200px' <Flex
width='100vw' position={'fixed'}
overflow='auto' left={0}
direction='column' bottom={0}
fontFamily='monospace' height="200px" // TODO: Make the log viewer resizeable.
fontSize='sm' width="100vw"
pl={12} overflow="auto"
pr={2} direction="column"
pb={2} fontFamily="monospace"
borderTopWidth='4px' fontSize="sm"
borderColor={borderColor} pl={12}
background={bg} pr={2}
ref={viewerRef} pb={2}
> borderTopWidth="4px"
{log.map((entry, i) => ( borderColor={borderColor}
<Flex gap={2} key={i}> background={bg}
<Text fontSize='sm' fontWeight={'semibold'}> ref={viewerRef}
{entry.timestamp}: >
</Text> {log.map((entry, i) => {
<Text fontSize='sm' wordBreak={'break-all'}> const { timestamp, message, level } = entry;
{entry.message} return (
</Text> <Flex gap={2} key={i} textColor={logTextColors[level]}>
</Flex> <Text fontSize="sm" fontWeight={'semibold'}>
))} {timestamp}:
</Flex> </Text>
)} <Text fontSize="sm" wordBreak={'break-all'}>
{shouldShowLogViewer && ( {message}
<Tooltip </Text>
label={ </Flex>
shouldAutoscroll ? 'Autoscroll on' : 'Autoscroll off' );
} })}
> </Flex>
<IconButton )}
size='sm' {shouldShowLogViewer && (
position={'fixed'} <Tooltip label={shouldAutoscroll ? 'Autoscroll on' : 'Autoscroll off'}>
left={2} <IconButton
bottom={12} size="sm"
aria-label='Toggle autoscroll' position={'fixed'}
variant={'solid'} left={2}
colorScheme={shouldAutoscroll ? 'blue' : 'gray'} bottom={12}
icon={<FaAngleDoubleDown />} aria-label="Toggle autoscroll"
onClick={() => setShouldAutoscroll(!shouldAutoscroll)} variant={'solid'}
/> colorScheme={shouldAutoscroll ? 'blue' : 'gray'}
</Tooltip> icon={<FaAngleDoubleDown />}
)} onClick={() => setShouldAutoscroll(!shouldAutoscroll)}
<Tooltip label={shouldShowLogViewer ? 'Hide logs' : 'Show logs'}> />
<IconButton </Tooltip>
size='sm' )}
position={'fixed'} <Tooltip label={shouldShowLogViewer ? 'Hide logs' : 'Show logs'}>
left={2} <IconButton
bottom={2} size="sm"
variant={'solid'} position={'fixed'}
aria-label='Toggle Log Viewer' left={2}
icon={shouldShowLogViewer ? <FaMinus /> : <FaCode />} bottom={2}
onClick={() => variant={'solid'}
dispatch(setShouldShowLogViewer(!shouldShowLogViewer)) aria-label="Toggle Log Viewer"
} icon={shouldShowLogViewer ? <FaMinus /> : <FaCode />}
/> onClick={handleClickLogViewerToggle}
</Tooltip> />
</> </Tooltip>
); </>
);
}; };
export default LogViewer; export default LogViewer;

View File

@ -0,0 +1,38 @@
import { Progress } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import { SystemState } from '../system/systemSlice';
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isProcessing: system.isProcessing,
currentStep: system.currentStep,
totalSteps: system.totalSteps,
currentStatusHasSteps: system.currentStatusHasSteps,
};
},
{
memoizeOptions: { resultEqualityCheck: isEqual },
}
);
const ProgressBar = () => {
const { isProcessing, currentStep, totalSteps, currentStatusHasSteps } =
useAppSelector(systemSelector);
const value = currentStep ? Math.round((currentStep * 100) / totalSteps) : 0;
return (
<Progress
height="10px"
value={value}
isIndeterminate={isProcessing && !currentStatusHasSteps}
/>
);
};
export default ProgressBar;

View File

@ -1,170 +1,164 @@
import { import {
Flex, Button,
FormControl, Flex,
FormLabel, FormControl,
Heading, FormLabel,
HStack, Heading,
Modal, HStack,
ModalBody, Modal,
ModalCloseButton, ModalBody,
ModalContent, ModalCloseButton,
ModalFooter, ModalContent,
ModalHeader, ModalFooter,
ModalOverlay, ModalHeader,
Switch, ModalOverlay,
Text, Switch,
useDisclosure, Text,
useDisclosure,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from '../../app/hooks'; import { useAppDispatch, useAppSelector } from '../../app/store';
import { import {
setShouldConfirmOnDelete, setShouldConfirmOnDelete,
setShouldDisplayInProgress, setShouldDisplayInProgress,
SystemState, SystemState,
} from './systemSlice'; } from './systemSlice';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import SDButton from '../../components/SDButton';
import { persistor } from '../../main'; import { persistor } from '../../main';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash'; import { isEqual } from 'lodash';
import { cloneElement, ReactElement } from 'react'; import { cloneElement, ReactElement } from 'react';
const systemSelector = createSelector( const systemSelector = createSelector(
(state: RootState) => state.system, (state: RootState) => state.system,
(system: SystemState) => { (system: SystemState) => {
const { shouldDisplayInProgress, shouldConfirmOnDelete } = system; const { shouldDisplayInProgress, shouldConfirmOnDelete } = system;
return { shouldDisplayInProgress, shouldConfirmOnDelete }; return { shouldDisplayInProgress, shouldConfirmOnDelete };
}, },
{ {
memoizeOptions: { resultEqualityCheck: isEqual }, memoizeOptions: { resultEqualityCheck: isEqual },
} }
); );
type Props = { type SettingsModalProps = {
children: ReactElement; /* The button to open the Settings Modal */
children: ReactElement;
}; };
const SettingsModal = ({ children }: Props) => { /**
const { * Modal for app settings. Also provides Reset functionality in which the
isOpen: isSettingsModalOpen, * app's localstorage is wiped via redux-persist.
onOpen: onSettingsModalOpen, *
onClose: onSettingsModalClose, * Secondary post-reset modal is included here.
} = useDisclosure(); */
const SettingsModal = ({ children }: SettingsModalProps) => {
const {
isOpen: isSettingsModalOpen,
onOpen: onSettingsModalOpen,
onClose: onSettingsModalClose,
} = useDisclosure();
const { const {
isOpen: isRefreshModalOpen, isOpen: isRefreshModalOpen,
onOpen: onRefreshModalOpen, onOpen: onRefreshModalOpen,
onClose: onRefreshModalClose, onClose: onRefreshModalClose,
} = useDisclosure(); } = useDisclosure();
const { shouldDisplayInProgress, shouldConfirmOnDelete } = const { shouldDisplayInProgress, shouldConfirmOnDelete } =
useAppSelector(systemSelector); useAppSelector(systemSelector);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const handleClickResetWebUI = () => { /**
persistor.purge().then(() => { * Resets localstorage, then opens a secondary modal informing user to
onSettingsModalClose(); * refresh their browser.
onRefreshModalOpen(); * */
}); const handleClickResetWebUI = () => {
}; persistor.purge().then(() => {
onSettingsModalClose();
onRefreshModalOpen();
});
};
return ( return (
<> <>
{cloneElement(children, { {cloneElement(children, {
onClick: onSettingsModalOpen, onClick: onSettingsModalOpen,
})} })}
<Modal isOpen={isSettingsModalOpen} onClose={onSettingsModalClose}> <Modal isOpen={isSettingsModalOpen} onClose={onSettingsModalClose}>
<ModalOverlay /> <ModalOverlay />
<ModalContent> <ModalContent>
<ModalHeader>Settings</ModalHeader> <ModalHeader>Settings</ModalHeader>
<ModalCloseButton /> <ModalCloseButton />
<ModalBody> <ModalBody>
<Flex gap={5} direction='column'> <Flex gap={5} direction="column">
<FormControl> <FormControl>
<HStack> <HStack>
<FormLabel marginBottom={1}> <FormLabel marginBottom={1}>
Display in-progress images (slower) Display in-progress images (slower)
</FormLabel> </FormLabel>
<Switch <Switch
isChecked={shouldDisplayInProgress} isChecked={shouldDisplayInProgress}
onChange={(e) => onChange={(e) =>
dispatch( dispatch(setShouldDisplayInProgress(e.target.checked))
setShouldDisplayInProgress( }
e.target.checked />
) </HStack>
) </FormControl>
} <FormControl>
/> <HStack>
</HStack> <FormLabel marginBottom={1}>Confirm on delete</FormLabel>
</FormControl> <Switch
<FormControl> isChecked={shouldConfirmOnDelete}
<HStack> onChange={(e) =>
<FormLabel marginBottom={1}> dispatch(setShouldConfirmOnDelete(e.target.checked))
Confirm on delete }
</FormLabel> />
<Switch </HStack>
isChecked={shouldConfirmOnDelete} </FormControl>
onChange={(e) =>
dispatch(
setShouldConfirmOnDelete(
e.target.checked
)
)
}
/>
</HStack>
</FormControl>
<Heading size={'md'}>Reset Web UI</Heading> <Heading size={'md'}>Reset Web UI</Heading>
<Text> <Text>
Resetting the web UI only resets the browser's Resetting the web UI only resets the browser's local cache of
local cache of your images and remembered your images and remembered settings. It does not delete any
settings. It does not delete any images from images from disk.
disk. </Text>
</Text> <Text>
<Text> If images aren't showing up in the gallery or something else
If images aren't showing up in the gallery or isn't working, please try resetting before submitting an issue
something else isn't working, please try on GitHub.
resetting before submitting an issue on GitHub. </Text>
</Text> <Button colorScheme="red" onClick={handleClickResetWebUI}>
<SDButton Reset Web UI
label='Reset Web UI' </Button>
colorScheme='red' </Flex>
onClick={handleClickResetWebUI} </ModalBody>
/>
</Flex>
</ModalBody>
<ModalFooter> <ModalFooter>
<SDButton <Button onClick={onSettingsModalClose}>Close</Button>
label='Close' </ModalFooter>
onClick={onSettingsModalClose} </ModalContent>
/> </Modal>
</ModalFooter>
</ModalContent>
</Modal>
<Modal <Modal
closeOnOverlayClick={false} closeOnOverlayClick={false}
isOpen={isRefreshModalOpen} isOpen={isRefreshModalOpen}
onClose={onRefreshModalClose} onClose={onRefreshModalClose}
isCentered isCentered
> >
<ModalOverlay bg='blackAlpha.300' backdropFilter='blur(40px)' /> <ModalOverlay bg="blackAlpha.300" backdropFilter="blur(40px)" />
<ModalContent> <ModalContent>
<ModalBody pb={6} pt={6}> <ModalBody pb={6} pt={6}>
<Flex justifyContent={'center'}> <Flex justifyContent={'center'}>
<Text fontSize={'lg'}> <Text fontSize={'lg'}>
Web UI has been reset. Refresh the page to Web UI has been reset. Refresh the page to reload.
reload. </Text>
</Text> </Flex>
</Flex> </ModalBody>
</ModalBody> </ModalContent>
</ModalContent> </Modal>
</Modal> </>
</> );
);
}; };
export default SettingsModal; export default SettingsModal;

View File

@ -0,0 +1,120 @@
import {
Flex,
Heading,
IconButton,
Link,
Spacer,
Text,
useColorMode,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { FaSun, FaMoon, FaGithub } from 'react-icons/fa';
import { MdHelp, MdSettings } from 'react-icons/md';
import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import SettingsModal from '../system/SettingsModal';
import { SystemState } from '../system/systemSlice';
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isConnected: system.isConnected,
isProcessing: system.isProcessing,
currentIteration: system.currentIteration,
totalIterations: system.totalIterations,
currentStatus: system.currentStatus,
};
},
{
memoizeOptions: { resultEqualityCheck: isEqual },
}
);
/**
* Header, includes color mode toggle, settings button, status message.
*/
const SiteHeader = () => {
const { colorMode, toggleColorMode } = useColorMode();
const {
isConnected,
isProcessing,
currentIteration,
totalIterations,
currentStatus,
} = useAppSelector(systemSelector);
const statusMessageTextColor = isConnected ? 'green.500' : 'red.500';
const colorModeIcon = colorMode == 'light' ? <FaMoon /> : <FaSun />;
// Make FaMoon and FaSun icon apparent size consistent
const colorModeIconFontSize = colorMode == 'light' ? 18 : 20;
let statusMessage = currentStatus;
if (isProcessing) {
if (totalIterations > 1) {
statusMessage += ` [${currentIteration}/${totalIterations}]`;
}
}
return (
<Flex minWidth="max-content" alignItems="center" gap="1" pl={2} pr={1}>
<Heading size={'lg'}>InvokeUI</Heading>
<Spacer />
<Text textColor={statusMessageTextColor}>{statusMessage}</Text>
<SettingsModal>
<IconButton
aria-label="Settings"
variant="link"
fontSize={24}
size={'sm'}
icon={<MdSettings />}
/>
</SettingsModal>
<IconButton
aria-label="Link to Github Issues"
variant="link"
fontSize={23}
size={'sm'}
icon={
<Link
isExternal
href="http://github.com/lstein/stable-diffusion/issues"
>
<MdHelp />
</Link>
}
/>
<IconButton
aria-label="Link to Github Repo"
variant="link"
fontSize={20}
size={'sm'}
icon={
<Link isExternal href="http://github.com/lstein/stable-diffusion">
<FaGithub />
</Link>
}
/>
<IconButton
aria-label="Toggle Dark Mode"
onClick={toggleColorMode}
variant="link"
size={'sm'}
fontSize={colorModeIconFontSize}
icon={colorModeIcon}
/>
</Flex>
);
};
export default SiteHeader;

View File

@ -1,10 +1,13 @@
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import dateFormat from 'dateformat';
import { ExpandedIndex } from '@chakra-ui/react'; import { ExpandedIndex } from '@chakra-ui/react';
import * as InvokeAI from '../../app/invokeai'
export type LogLevel = 'info' | 'warning' | 'error';
export interface LogEntry { export interface LogEntry {
timestamp: string; timestamp: string;
level: LogLevel;
message: string; message: string;
} }
@ -12,10 +15,8 @@ export interface Log {
[index: number]: LogEntry; [index: number]: LogEntry;
} }
export interface SystemState { export interface SystemState extends InvokeAI.SystemStatus, InvokeAI.SystemConfig {
shouldDisplayInProgress: boolean; shouldDisplayInProgress: boolean;
isProcessing: boolean;
currentStep: number;
log: Array<LogEntry>; log: Array<LogEntry>;
shouldShowLogViewer: boolean; shouldShowLogViewer: boolean;
isGFPGANAvailable: boolean; isGFPGANAvailable: boolean;
@ -24,12 +25,17 @@ export interface SystemState {
socketId: string; socketId: string;
shouldConfirmOnDelete: boolean; shouldConfirmOnDelete: boolean;
openAccordions: ExpandedIndex; openAccordions: ExpandedIndex;
currentStep: number;
totalSteps: number;
currentIteration: number;
totalIterations: number;
currentStatus: string;
currentStatusHasSteps: boolean;
} }
const initialSystemState = { const initialSystemState = {
isConnected: false, isConnected: false,
isProcessing: false, isProcessing: false,
currentStep: 0,
log: [], log: [],
shouldShowLogViewer: false, shouldShowLogViewer: false,
shouldDisplayInProgress: false, shouldDisplayInProgress: false,
@ -38,6 +44,17 @@ const initialSystemState = {
socketId: '', socketId: '',
shouldConfirmOnDelete: true, shouldConfirmOnDelete: true,
openAccordions: [0], openAccordions: [0],
currentStep: 0,
totalSteps: 0,
currentIteration: 0,
totalIterations: 0,
currentStatus: '',
currentStatusHasSteps: false,
model: '',
model_id: '',
model_hash: '',
app_id: '',
app_version: '',
}; };
const initialState: SystemState = initialSystemState; const initialState: SystemState = initialSystemState;
@ -51,18 +68,35 @@ export const systemSlice = createSlice({
}, },
setIsProcessing: (state, action: PayloadAction<boolean>) => { setIsProcessing: (state, action: PayloadAction<boolean>) => {
state.isProcessing = action.payload; state.isProcessing = action.payload;
if (action.payload === false) {
state.currentStep = 0;
}
}, },
setCurrentStep: (state, action: PayloadAction<number>) => { setCurrentStatus: (state, action: PayloadAction<string>) => {
state.currentStep = action.payload; state.currentStatus = action.payload;
}, },
addLogEntry: (state, action: PayloadAction<string>) => { setSystemStatus: (state, action: PayloadAction<InvokeAI.SystemStatus>) => {
const currentStatus =
!action.payload.isProcessing && state.isConnected
? 'Connected'
: action.payload.currentStatus;
return { ...state, ...action.payload, currentStatus };
},
addLogEntry: (
state,
action: PayloadAction<{
timestamp: string;
message: string;
level?: LogLevel;
}>
) => {
const { timestamp, message, level } = action.payload;
const logLevel = level || 'info';
const entry: LogEntry = { const entry: LogEntry = {
timestamp: dateFormat(new Date(), 'isoDateTime'), timestamp,
message: action.payload, message,
level: logLevel,
}; };
state.log.push(entry); state.log.push(entry);
}, },
setShouldShowLogViewer: (state, action: PayloadAction<boolean>) => { setShouldShowLogViewer: (state, action: PayloadAction<boolean>) => {
@ -80,19 +114,24 @@ export const systemSlice = createSlice({
setOpenAccordions: (state, action: PayloadAction<ExpandedIndex>) => { setOpenAccordions: (state, action: PayloadAction<ExpandedIndex>) => {
state.openAccordions = action.payload; state.openAccordions = action.payload;
}, },
setSystemConfig: (state, action: PayloadAction<InvokeAI.SystemConfig>) => {
return { ...state, ...action.payload };
},
}, },
}); });
export const { export const {
setShouldDisplayInProgress, setShouldDisplayInProgress,
setIsProcessing, setIsProcessing,
setCurrentStep,
addLogEntry, addLogEntry,
setShouldShowLogViewer, setShouldShowLogViewer,
setIsConnected, setIsConnected,
setSocketId, setSocketId,
setShouldConfirmOnDelete, setShouldConfirmOnDelete,
setOpenAccordions, setOpenAccordions,
setSystemStatus,
setCurrentStatus,
setSystemConfig,
} = systemSlice.actions; } = systemSlice.actions;
export default systemSlice.reducer; export default systemSlice.reducer;

View File

@ -1,108 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useMemo } from 'react';
import { useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import { SDState } from '../sd/sdSlice';
import { validateSeedWeights } from '../sd/util/seedWeightPairs';
import { SystemState } from './systemSlice';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
prompt: sd.prompt,
shouldGenerateVariations: sd.shouldGenerateVariations,
seedWeights: sd.seedWeights,
maskPath: sd.maskPath,
initialImagePath: sd.initialImagePath,
seed: sd.seed,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isProcessing: system.isProcessing,
isConnected: system.isConnected,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/*
Checks relevant pieces of state to confirm generation will not deterministically fail.
This is used to prevent the 'Generate' button from being clicked.
Other parameter values may cause failure but we rely on input validation for those.
*/
const useCheckParameters = () => {
const {
prompt,
shouldGenerateVariations,
seedWeights,
maskPath,
initialImagePath,
seed,
} = useAppSelector(sdSelector);
const { isProcessing, isConnected } = useAppSelector(systemSelector);
return useMemo(() => {
// Cannot generate without a prompt
if (!prompt) {
return false;
}
// Cannot generate with a mask without img2img
if (maskPath && !initialImagePath) {
return false;
}
// TODO: job queue
// Cannot generate if already processing an image
if (isProcessing) {
return false;
}
// Cannot generate if not connected
if (!isConnected) {
return false;
}
// Cannot generate variations without valid seed weights
if (
shouldGenerateVariations &&
(!(validateSeedWeights(seedWeights) || seedWeights === '') ||
seed === -1)
) {
return false;
}
// All good
return true;
}, [
prompt,
maskPath,
initialImagePath,
isProcessing,
isConnected,
shouldGenerateVariations,
seedWeights,
seed,
]);
};
export default useCheckParameters;

View File

@ -8,9 +8,9 @@ import { persistStore } from 'redux-persist';
export const persistor = persistStore(store); export const persistor = persistStore(store);
import App from './App';
import { theme } from './app/theme'; import { theme } from './app/theme';
import Loading from './Loading'; import Loading from './Loading';
import App from './app/App';
ReactDOM.createRoot(document.getElementById('root') as HTMLElement).render( ReactDOM.createRoot(document.getElementById('root') as HTMLElement).render(
<React.StrictMode> <React.StrictMode>

View File

@ -419,6 +419,13 @@
compute-scroll-into-view "1.0.14" compute-scroll-into-view "1.0.14"
copy-to-clipboard "3.3.1" copy-to-clipboard "3.3.1"
"@chakra-ui/icon@3.0.10":
version "3.0.10"
resolved "https://registry.yarnpkg.com/@chakra-ui/icon/-/icon-3.0.10.tgz#1a11b5edb42a8af7aa5b6dec2bf2c6c4df1869fc"
integrity sha512-utO569d9bptEraJrEhuImfNzQ8v+a8PsQh8kTsodCzg8B16R3t5TTuoqeJqS6Nq16Vq6w87QbX3/4A73CNK5fw==
dependencies:
"@chakra-ui/shared-utils" "2.0.1"
"@chakra-ui/icon@3.0.9": "@chakra-ui/icon@3.0.9":
version "3.0.9" version "3.0.9"
resolved "https://registry.yarnpkg.com/@chakra-ui/icon/-/icon-3.0.9.tgz#ba127d9eefd727f62e9bce07a23eca39ae506744" resolved "https://registry.yarnpkg.com/@chakra-ui/icon/-/icon-3.0.9.tgz#ba127d9eefd727f62e9bce07a23eca39ae506744"
@ -426,6 +433,13 @@
dependencies: dependencies:
"@chakra-ui/shared-utils" "2.0.1" "@chakra-ui/shared-utils" "2.0.1"
"@chakra-ui/icons@^2.0.10":
version "2.0.10"
resolved "https://registry.yarnpkg.com/@chakra-ui/icons/-/icons-2.0.10.tgz#61aeb44c913c10e7ff77addc798494e50d66c760"
integrity sha512-hxMspvysOay2NsJyadM611F/Y4vVzJU/YkXTxsyBjm6v/DbENhpVmPnUf+kwwyl7dINNb9iOF+kuGxnuIEO1Tw==
dependencies:
"@chakra-ui/icon" "3.0.10"
"@chakra-ui/image@2.0.10": "@chakra-ui/image@2.0.10":
version "2.0.10" version "2.0.10"
resolved "https://registry.yarnpkg.com/@chakra-ui/image/-/image-2.0.10.tgz#712c0e1c579d959225bd8316d8d8f66cbeb95bb8" resolved "https://registry.yarnpkg.com/@chakra-ui/image/-/image-2.0.10.tgz#712c0e1c579d959225bd8316d8d8f66cbeb95bb8"

View File

@ -100,6 +100,13 @@ SAMPLER_CHOICES = [
'plms', 'plms',
] ]
PRECISION_CHOICES = [
'auto',
'float32',
'autocast',
'float16',
]
# is there a way to pick this up during git commits? # is there a way to pick this up during git commits?
APP_ID = 'lstein/stable-diffusion' APP_ID = 'lstein/stable-diffusion'
APP_VERSION = 'v1.15' APP_VERSION = 'v1.15'
@ -174,31 +181,37 @@ class Args(object):
switches.append(f'-W {a["width"]}') switches.append(f'-W {a["width"]}')
switches.append(f'-H {a["height"]}') switches.append(f'-H {a["height"]}')
switches.append(f'-C {a["cfg_scale"]}') switches.append(f'-C {a["cfg_scale"]}')
switches.append(f'-A {a["sampler_name"]}')
if a['grid']: if a['grid']:
switches.append('--grid') switches.append('--grid')
if a['seamless']: if a['seamless']:
switches.append('--seamless') switches.append('--seamless')
# img2img generations have parameters relevant only to them and have special handling
if a['init_img'] and len(a['init_img'])>0: if a['init_img'] and len(a['init_img'])>0:
switches.append(f'-I {a["init_img"]}') switches.append(f'-I {a["init_img"]}')
if a['init_mask'] and len(a['init_mask'])>0: switches.append(f'-A ddim') # TODO: FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS
switches.append(f'-M {a["init_mask"]}') if a['fit']:
if a['init_color'] and len(a['init_color'])>0: switches.append(f'--fit')
switches.append(f'--init_color {a["init_color"]}') if a['init_mask'] and len(a['init_mask'])>0:
if a['fit']: switches.append(f'-M {a["init_mask"]}')
switches.append(f'--fit') if a['init_color'] and len(a['init_color'])>0:
if a['init_img'] and a['strength'] and a['strength']>0: switches.append(f'--init_color {a["init_color"]}')
switches.append(f'-f {a["strength"]}') if a['strength'] and a['strength']>0:
switches.append(f'-f {a["strength"]}')
else:
switches.append(f'-A {a["sampler_name"]}')
# gfpgan-specific parameters
if a['gfpgan_strength']: if a['gfpgan_strength']:
switches.append(f'-G {a["gfpgan_strength"]}') switches.append(f'-G {a["gfpgan_strength"]}')
# esrgan-specific parameters
if a['upscale']: if a['upscale']:
switches.append(f'-U {" ".join([str(u) for u in a["upscale"]])}') switches.append(f'-U {" ".join([str(u) for u in a["upscale"]])}')
if a['embiggen']: if a['embiggen']:
switches.append(f'--embiggen {" ".join([str(u) for u in a["embiggen"]])}') switches.append(f'--embiggen {" ".join([str(u) for u in a["embiggen"]])}')
if a['embiggen_tiles']: if a['embiggen_tiles']:
switches.append(f'--embiggen_tiles {" ".join([str(u) for u in a["embiggen_tiles"]])}') switches.append(f'--embiggen_tiles {" ".join([str(u) for u in a["embiggen_tiles"]])}')
if a['variation_amount'] > 0:
switches.append(f'-v {a["variation_amount"]}')
if a['with_variations']: if a['with_variations']:
formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in (a["with_variations"])) formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in (a["with_variations"]))
switches.append(f'-V {formatted_variations}') switches.append(f'-V {formatted_variations}')
@ -316,7 +329,16 @@ class Args(object):
'--full_precision', '--full_precision',
dest='full_precision', dest='full_precision',
action='store_true', action='store_true',
help='Use more memory-intensive full precision math for calculations', help='Deprecated way to set --precision=float32',
)
model_group.add_argument(
'--precision',
dest='precision',
type=str,
choices=PRECISION_CHOICES,
metavar='PRECISION',
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
default='auto',
) )
file_group.add_argument( file_group.add_argument(
'--from_file', '--from_file',
@ -618,18 +640,24 @@ def metadata_dumps(opt,
postprocessing=postprocessing postprocessing=postprocessing
) )
# TODO: This is just a hack until postprocessing pipeline work completed # 'postprocessing' is either null or an array of postprocessing metadatal
image_dict['postprocessing'] = [] if postprocessing:
if image_dict['gfpgan_strength'] and image_dict['gfpgan_strength'] > 0: # TODO: This is just a hack until postprocessing pipeline work completed
image_dict['postprocessing'].append('GFPGAN (not RFC compliant)') image_dict['postprocessing'] = []
if image_dict['upscale'] and image_dict['upscale'][0] > 0:
image_dict['postprocessing'].append('ESRGAN (not RFC compliant)') if image_dict['gfpgan_strength'] and image_dict['gfpgan_strength'] > 0:
image_dict['postprocessing'].append('GFPGAN (not RFC compliant)')
if image_dict['upscale'] and image_dict['upscale'][0] > 0:
image_dict['postprocessing'].append('ESRGAN (not RFC compliant)')
else:
image_dict['postprocessing'] = None
# remove any image keys not mentioned in RFC #266 # remove any image keys not mentioned in RFC #266
rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps', rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
'cfg_scale','step_number','width','height','extra','strength'] 'cfg_scale','step_number','width','height','extra','strength']
rfc_dict ={} rfc_dict ={}
for item in image_dict.items(): for item in image_dict.items():
key,value = item key,value = item
if key in rfc266_img_fields: if key in rfc266_img_fields:
@ -644,18 +672,17 @@ def metadata_dumps(opt,
subprompts = [{'prompt':x[0],'weight':x[1]} for x in subprompts] subprompts = [{'prompt':x[0],'weight':x[1]} for x in subprompts]
rfc_dict['prompt'] = subprompts rfc_dict['prompt'] = subprompts
# variations # 'variations' should always exist and be an array, empty or consisting of {'seed': seed, 'weight': weight} pairs
if opt.with_variations: rfc_dict['variations'] = [{'seed':x[0],'weight':x[1]} for x in opt.with_variations] if opt.with_variations else []
variations = [{'seed':x[0],'weight':x[1]} for x in opt.with_variations]
rfc_dict['variations'] = variations
if opt.init_img: if opt.init_img:
rfc_dict['type'] = 'img2img' rfc_dict['type'] = 'img2img'
rfc_dict['strength_steps'] = rfc_dict.pop('strength') rfc_dict['strength_steps'] = rfc_dict.pop('strength')
rfc_dict['orig_hash'] = calculate_init_img_hash(opt.init_img) rfc_dict['orig_hash'] = calculate_init_img_hash(opt.init_img)
rfc_dict['sampler'] = 'ddim' # FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS rfc_dict['sampler'] = 'ddim' # TODO: FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS
else: else:
rfc_dict['type'] = 'txt2img' rfc_dict['type'] = 'txt2img'
rfc_dict.pop('strength')
if len(seeds)==0 and opt.seed: if len(seeds)==0 and opt.seed:
seeds=[seed] seeds=[seed]

View File

@ -1,6 +1,6 @@
import torch import torch
from torch import autocast from torch import autocast
from contextlib import contextmanager, nullcontext from contextlib import nullcontext
def choose_torch_device() -> str: def choose_torch_device() -> str:
'''Convenience routine for guessing which GPU device to run model on''' '''Convenience routine for guessing which GPU device to run model on'''
@ -10,15 +10,18 @@ def choose_torch_device() -> str:
return 'mps' return 'mps'
return 'cpu' return 'cpu'
def choose_autocast_device(device): def choose_precision(device) -> str:
'''Returns an autocast compatible device from a torch device''' '''Returns an appropriate precision for the given torch device'''
device_type = device.type # this returns 'mps' on M1 if device.type == 'cuda':
# autocast only for cuda, but GTX 16xx have issues with it device_name = torch.cuda.get_device_name(device)
if device_type == 'cuda': if not ('GeForce GTX 1660' in device_name or 'GeForce GTX 1650' in device_name):
device_name = torch.cuda.get_device_name() return 'float16'
if 'GeForce GTX 1660' in device_name or 'GeForce GTX 1650' in device_name: return 'float32'
return device_type,nullcontext
else: def choose_autocast(precision):
return device_type,autocast '''Returns an autocast context or nullcontext for the given precision string'''
else: # float16 currently requires autocast to avoid errors like:
return 'cpu',nullcontext # 'expected scalar type Half but found Float'
if precision == 'autocast' or precision == 'float16':
return autocast
return nullcontext

View File

@ -9,13 +9,14 @@ from tqdm import tqdm, trange
from PIL import Image from PIL import Image
from einops import rearrange, repeat from einops import rearrange, repeat
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
from ldm.dream.devices import choose_autocast_device from ldm.dream.devices import choose_autocast
downsampling = 8 downsampling = 8
class Generator(): class Generator():
def __init__(self,model): def __init__(self, model, precision):
self.model = model self.model = model
self.precision = precision
self.seed = None self.seed = None
self.latent_channels = model.channels self.latent_channels = model.channels
self.downsampling_factor = downsampling # BUG: should come from model or config self.downsampling_factor = downsampling # BUG: should come from model or config
@ -38,7 +39,7 @@ class Generator():
def generate(self,prompt,init_image,width,height,iterations=1,seed=None, def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
image_callback=None, step_callback=None, image_callback=None, step_callback=None,
**kwargs): **kwargs):
device_type,scope = choose_autocast_device(self.model.device) scope = choose_autocast(self.precision)
make_image = self.get_make_image( make_image = self.get_make_image(
prompt, prompt,
init_image = init_image, init_image = init_image,
@ -51,7 +52,7 @@ class Generator():
results = [] results = []
seed = seed if seed else self.new_seed() seed = seed if seed else self.new_seed()
seed, initial_noise = self.generate_initial_noise(seed, width, height) seed, initial_noise = self.generate_initial_noise(seed, width, height)
with scope(device_type), self.model.ema_scope(): with scope(self.model.device.type), self.model.ema_scope():
for n in trange(iterations, desc='Generating'): for n in trange(iterations, desc='Generating'):
x_T = None x_T = None
if self.variation_amount > 0: if self.variation_amount > 0:

View File

@ -11,8 +11,8 @@ from ldm.models.diffusion.ddim import DDIMSampler
from ldm.dream.generator.img2img import Img2Img from ldm.dream.generator.img2img import Img2Img
class Embiggen(Generator): class Embiggen(Generator):
def __init__(self,model): def __init__(self, model, precision):
super().__init__(model) super().__init__(model, precision)
self.init_latent = None self.init_latent = None
@torch.no_grad() @torch.no_grad()

View File

@ -4,13 +4,13 @@ ldm.dream.generator.img2img descends from ldm.dream.generator
import torch import torch
import numpy as np import numpy as np
from ldm.dream.devices import choose_autocast_device from ldm.dream.devices import choose_autocast
from ldm.dream.generator.base import Generator from ldm.dream.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
class Img2Img(Generator): class Img2Img(Generator):
def __init__(self,model): def __init__(self, model, precision):
super().__init__(model) super().__init__(model, precision)
self.init_latent = None # by get_noise() self.init_latent = None # by get_noise()
@torch.no_grad() @torch.no_grad()
@ -32,8 +32,8 @@ class Img2Img(Generator):
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
) )
device_type,scope = choose_autocast_device(self.model.device) scope = choose_autocast(self.precision)
with scope(device_type): with scope(self.model.device.type):
self.init_latent = self.model.get_first_stage_encoding( self.init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image) self.model.encode_first_stage(init_image)
) # move to latent space ) # move to latent space

View File

@ -5,14 +5,14 @@ ldm.dream.generator.inpaint descends from ldm.dream.generator
import torch import torch
import numpy as np import numpy as np
from einops import rearrange, repeat from einops import rearrange, repeat
from ldm.dream.devices import choose_autocast_device from ldm.dream.devices import choose_autocast
from ldm.dream.generator.img2img import Img2Img from ldm.dream.generator.img2img import Img2Img
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
class Inpaint(Img2Img): class Inpaint(Img2Img):
def __init__(self,model): def __init__(self, model, precision):
self.init_latent = None self.init_latent = None
super().__init__(model) super().__init__(model, precision)
@torch.no_grad() @torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
@ -38,8 +38,8 @@ class Inpaint(Img2Img):
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
) )
device_type,scope = choose_autocast_device(self.model.device) scope = choose_autocast(self.precision)
with scope(device_type): with scope(self.model.device.type):
self.init_latent = self.model.get_first_stage_encoding( self.init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image) self.model.encode_first_stage(init_image)
) # move to latent space ) # move to latent space

View File

@ -7,8 +7,8 @@ import numpy as np
from ldm.dream.generator.base import Generator from ldm.dream.generator.base import Generator
class Txt2Img(Generator): class Txt2Img(Generator):
def __init__(self,model): def __init__(self, model, precision):
super().__init__(model) super().__init__(model, precision)
@torch.no_grad() @torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,

61
ldm/dream/log.py Normal file
View File

@ -0,0 +1,61 @@
"""
Functions for better format logging
write_log -- logs the name of the output image, prompt, and prompt args to the terminal and different types of file
1 write_log_message -- Writes a message to the console
2 write_log_files -- Writes a message to files
2.1 write_log_default -- File in plain text
2.2 write_log_txt -- File in txt format
2.3 write_log_markdown -- File in markdown format
"""
import os
def write_log(results, log_path, file_types, output_cntr):
"""
logs the name of the output image, prompt, and prompt args to the terminal and files
"""
output_cntr = write_log_message(results, output_cntr)
write_log_files(results, log_path, file_types)
return output_cntr
def write_log_message(results, output_cntr):
"""logs to the terminal"""
log_lines = [f"{path}: {prompt}\n" for path, prompt in results]
for l in log_lines:
output_cntr += 1
print(f"[{output_cntr}] {l}", end="")
return output_cntr
def write_log_files(results, log_path, file_types):
for file_type in file_types:
if file_type == "txt":
write_log_txt(log_path, results)
elif file_type == "md" or file_type == "markdown":
write_log_markdown(log_path, results)
else:
print(f"'{file_type}' format is not supported, so write in plain text")
write_log_default(log_path, results, file_type)
def write_log_default(log_path, results, file_type):
plain_txt_lines = [f"{path}: {prompt}\n" for path, prompt in results]
with open(log_path + "." + file_type, "a", encoding="utf-8") as file:
file.writelines(plain_txt_lines)
def write_log_txt(log_path, results):
txt_lines = [f"{path}: {prompt}\n" for path, prompt in results]
with open(log_path + ".txt", "a", encoding="utf-8") as file:
file.writelines(txt_lines)
def write_log_markdown(log_path, results):
md_lines = []
for path, prompt in results:
file_name = os.path.basename(path)
md_lines.append(f"## {file_name}\n![]({file_name})\n\n{prompt}\n")
with open(log_path + ".md", "a", encoding="utf-8") as file:
file.writelines(md_lines)

View File

@ -34,6 +34,7 @@ class PngWriter:
# saves image named _image_ to outdir/name, writing metadata from prompt # saves image named _image_ to outdir/name, writing metadata from prompt
# returns full path of output # returns full path of output
def save_image_and_prompt_to_png(self, image, dream_prompt, name, metadata=None): def save_image_and_prompt_to_png(self, image, dream_prompt, name, metadata=None):
print(f'self.outdir={self.outdir}, name={name}')
path = os.path.join(self.outdir, name) path = os.path.join(self.outdir, name)
info = PngImagePlugin.PngInfo() info = PngImagePlugin.PngInfo()
info.add_text('Dream', dream_prompt) info.add_text('Dream', dream_prompt)

View File

@ -29,7 +29,7 @@ from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ksampler import KSampler from ldm.models.diffusion.ksampler import KSampler
from ldm.dream.pngwriter import PngWriter from ldm.dream.pngwriter import PngWriter
from ldm.dream.image_util import InitImageResizer from ldm.dream.image_util import InitImageResizer
from ldm.dream.devices import choose_torch_device from ldm.dream.devices import choose_torch_device, choose_precision
from ldm.dream.conditioning import get_uc_and_c from ldm.dream.conditioning import get_uc_and_c
def fix_func(orig): def fix_func(orig):
@ -104,7 +104,7 @@ gr = Generate(
# these values are set once and shouldn't be changed # these values are set once and shouldn't be changed
conf = path to configuration file ('configs/models.yaml') conf = path to configuration file ('configs/models.yaml')
model = symbolic name of the model in the configuration file model = symbolic name of the model in the configuration file
full_precision = False precision = float precision to be used
# this value is sticky and maintained between generation calls # this value is sticky and maintained between generation calls
sampler_name = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms sampler_name = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
@ -130,6 +130,7 @@ class Generate:
sampler_name = 'k_lms', sampler_name = 'k_lms',
ddim_eta = 0.0, # deterministic ddim_eta = 0.0, # deterministic
full_precision = False, full_precision = False,
precision = 'auto',
# these are deprecated; if present they override values in the conf file # these are deprecated; if present they override values in the conf file
weights = None, weights = None,
config = None, config = None,
@ -145,7 +146,7 @@ class Generate:
self.cfg_scale = 7.5 self.cfg_scale = 7.5
self.sampler_name = sampler_name self.sampler_name = sampler_name
self.ddim_eta = 0.0 # same seed always produces same image self.ddim_eta = 0.0 # same seed always produces same image
self.full_precision = True if choose_torch_device() == 'mps' else full_precision self.precision = precision
self.strength = 0.75 self.strength = 0.75
self.seamless = False self.seamless = False
self.embedding_path = embedding_path self.embedding_path = embedding_path
@ -162,6 +163,14 @@ class Generate:
# it wasn't actually doing anything. This logic could be reinstated. # it wasn't actually doing anything. This logic could be reinstated.
device_type = choose_torch_device() device_type = choose_torch_device()
self.device = torch.device(device_type) self.device = torch.device(device_type)
if full_precision:
if self.precision != 'auto':
raise ValueError('Remove --full_precision / -F if using --precision')
print('Please remove deprecated --full_precision / -F')
print('If auto config does not work you can use --precision=float32')
self.precision = 'float32'
if self.precision == 'auto':
self.precision = choose_precision(self.device)
# for VRAM usage statistics # for VRAM usage statistics
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
@ -440,25 +449,25 @@ class Generate:
def _make_img2img(self): def _make_img2img(self):
if not self.generators.get('img2img'): if not self.generators.get('img2img'):
from ldm.dream.generator.img2img import Img2Img from ldm.dream.generator.img2img import Img2Img
self.generators['img2img'] = Img2Img(self.model) self.generators['img2img'] = Img2Img(self.model, self.precision)
return self.generators['img2img'] return self.generators['img2img']
def _make_embiggen(self): def _make_embiggen(self):
if not self.generators.get('embiggen'): if not self.generators.get('embiggen'):
from ldm.dream.generator.embiggen import Embiggen from ldm.dream.generator.embiggen import Embiggen
self.generators['embiggen'] = Embiggen(self.model) self.generators['embiggen'] = Embiggen(self.model, self.precision)
return self.generators['embiggen'] return self.generators['embiggen']
def _make_txt2img(self): def _make_txt2img(self):
if not self.generators.get('txt2img'): if not self.generators.get('txt2img'):
from ldm.dream.generator.txt2img import Txt2Img from ldm.dream.generator.txt2img import Txt2Img
self.generators['txt2img'] = Txt2Img(self.model) self.generators['txt2img'] = Txt2Img(self.model, self.precision)
return self.generators['txt2img'] return self.generators['txt2img']
def _make_inpaint(self): def _make_inpaint(self):
if not self.generators.get('inpaint'): if not self.generators.get('inpaint'):
from ldm.dream.generator.inpaint import Inpaint from ldm.dream.generator.inpaint import Inpaint
self.generators['inpaint'] = Inpaint(self.model) self.generators['inpaint'] = Inpaint(self.model, self.precision)
return self.generators['inpaint'] return self.generators['inpaint']
def load_model(self): def load_model(self):
@ -469,7 +478,7 @@ class Generate:
model = self._load_model_from_config(self.config, self.weights) model = self._load_model_from_config(self.config, self.weights)
if self.embedding_path is not None: if self.embedding_path is not None:
model.embedding_manager.load( model.embedding_manager.load(
self.embedding_path, self.full_precision self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
) )
self.model = model.to(self.device) self.model = model.to(self.device)
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
@ -620,15 +629,12 @@ class Generate:
model = instantiate_from_config(c.model) model = instantiate_from_config(c.model)
m, u = model.load_state_dict(sd, strict=False) m, u = model.load_state_dict(sd, strict=False)
if self.full_precision: if self.precision == 'float16':
print( print('Using faster float16 precision')
'>> Using slower but more accurate full-precision math (--full_precision)' model.to(torch.float16)
)
else: else:
print( print('Using more accurate float32 precision')
'>> Using half precision math. Call with --full_precision to use more accurate but VRAM-intensive full precision.'
)
model.half()
model.to(self.device) model.to(self.device)
model.eval() model.eval()

3
pyproject.toml Normal file
View File

@ -0,0 +1,3 @@
[tool.blue]
line-length = 90
target-version = ['py310']

Some files were not shown because too many files have changed in this diff Show More